Afrinetwork7 commited on
Commit
aa48ffa
1 Parent(s): 03f9951

Update whisper_jax/layers.py

Browse files
Files changed (1) hide show
  1. whisper_jax/layers.py +65 -0
whisper_jax/layers.py CHANGED
@@ -56,6 +56,71 @@ NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxi
56
  default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def dot_product_attention(
60
  query: Array,
61
  key: Array,
 
56
  default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
57
 
58
 
59
+ # ------------------------------------------------------------------------------
60
+ # Temporary inlined JAX N-d initializer code
61
+ # TODO(levskaya): remove once new JAX release is out.
62
+ # ------------------------------------------------------------------------------
63
+ def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
64
+ """Inlined JAX `nn.initializer._compute_fans`."""
65
+ if isinstance(in_axis, int):
66
+ in_size = shape[in_axis]
67
+ else:
68
+ in_size = int(np.prod([shape[i] for i in in_axis]))
69
+ if isinstance(out_axis, int):
70
+ out_size = shape[out_axis]
71
+ else:
72
+ out_size = int(np.prod([shape[i] for i in out_axis]))
73
+ receptive_field_size = shape.total / in_size / out_size
74
+ fan_in = in_size * receptive_field_size
75
+ fan_out = out_size * receptive_field_size
76
+ return fan_in, fan_out
77
+
78
+
79
+ def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_):
80
+ """Inlined JAX `nn.initializer.variance_scaling`."""
81
+
82
+ def init(key, shape, dtype=dtype):
83
+ return jnp.zeros(shape, dtype=dtype)
84
+ dtype = jax.dtypes.canonicalize_dtype(dtype)
85
+ shape = jax.core.as_named_shape(shape)
86
+ fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
87
+ if mode == "fan_in":
88
+ denominator = fan_in
89
+ elif mode == "fan_out":
90
+ denominator = fan_out
91
+ elif mode == "fan_avg":
92
+ denominator = (fan_in + fan_out) / 2
93
+ else:
94
+ raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
95
+ variance = jnp.array(scale / denominator, dtype=dtype)
96
+
97
+ if distribution == "truncated_normal":
98
+ # constant is stddev of standard normal truncated to (-2, 2)
99
+ stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
100
+ return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
101
+ elif distribution == "normal":
102
+ return random.normal(key, shape, dtype) * jnp.sqrt(variance)
103
+ elif distribution == "uniform":
104
+ return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
105
+ else:
106
+ raise ValueError("invalid distribution for variance scaling " "initializer: {}".format(distribution))
107
+
108
+ return init
109
+
110
+
111
+ # ------------------------------------------------------------------------------
112
+
113
+
114
+ def nd_dense_init(scale, mode, distribution):
115
+ """Initializer with in_axis, out_axis set at call time."""
116
+
117
+ def init_fn(key, shape, dtype, in_axis, out_axis):
118
+ fn = variance_scaling(scale, mode, distribution, in_axis, out_axis)
119
+ return fn(key, shape, dtype)
120
+
121
+ return init_fn
122
+
123
+
124
  def dot_product_attention(
125
  query: Array,
126
  key: Array,