aapot commited on
Commit
2e1433f
1 Parent(s): 8981e8d

Saving weights and logs of step 10000

Browse files
__pycache__/distributed_shampoo.cpython-38.pyc ADDED
Binary file (51.7 kB). View file
 
distributed_shampoo.py ADDED
@@ -0,0 +1,1611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py
2
+
3
+ # coding=utf-8
4
+ # Copyright 2021 The Google Research Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ # An implementation of distributed Shampoo optimizer from:
19
+ #
20
+ # Scalable Second Order Optimization for Deep Learning
21
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
22
+ # Preprint Paper: https://arxiv.org/abs/2002.09018
23
+ #
24
+ # This implementation moves computation of inverse pth root back to the
25
+ # accelerator (if higher precision is available).
26
+ #
27
+ # Authors: Rohan Anil (rohananil at google dot com)
28
+ # & Vineet Gupta (vineet at google dot com)
29
+ #
30
+
31
+ """Distributed Shampoo Implementation."""
32
+
33
+ import enum
34
+ import functools
35
+ import itertools
36
+ from typing import Any, List, NamedTuple
37
+
38
+ import chex
39
+ from flax import struct
40
+ import jax
41
+ from jax import lax
42
+ import jax.experimental.pjit as pjit
43
+ import jax.numpy as jnp
44
+ import numpy as np
45
+ import optax
46
+
47
+
48
+ # pylint:disable=no-value-for-parameter
49
+ @struct.dataclass
50
+ class QuantizedValue:
51
+ """State associated with quantized value."""
52
+ quantized: chex.Array
53
+ diagonal: chex.Array # Diagonal (if extract_diagonal is set)
54
+ bucket_size: chex.Array
55
+ quantized_dtype: jnp.dtype = struct.field(
56
+ pytree_node=False) # Dtype for the quantized value.
57
+ extract_diagonal: bool = struct.field(
58
+ pytree_node=False) # In case its centered.
59
+ shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
60
+
61
+ @classmethod
62
+ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
63
+ if isinstance(fvalue, list) and not fvalue:
64
+ return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
65
+ quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
66
+ fvalue, quantized_dtype, extract_diagonal)
67
+ return QuantizedValue(quantized, diagonal_fvalue, bucket_size,
68
+ quantized_dtype, extract_diagonal,
69
+ list(quantized.shape))
70
+
71
+ # Quantization is from Lingvo JAX optimizers.
72
+ # We extend it for int16 quantization of PSD matrices.
73
+ @classmethod
74
+ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
75
+ """Returns quantized value and the bucket."""
76
+ if quantized_dtype == jnp.float32:
77
+ return fvalue, [], []
78
+ elif quantized_dtype == jnp.bfloat16:
79
+ return fvalue.astype(jnp.bfloat16), [], []
80
+
81
+ float_dtype = fvalue.dtype
82
+ if quantized_dtype == jnp.int8:
83
+ # value -128 is not used.
84
+ num_buckets = jnp.array(127.0, dtype=float_dtype)
85
+ elif quantized_dtype == jnp.int16:
86
+ # value -32768 is not used.
87
+ num_buckets = jnp.array(32767.0, dtype=float_dtype)
88
+ else:
89
+ raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
90
+ # max value is mapped to num_buckets
91
+
92
+ if extract_diagonal and fvalue.ndim != 2:
93
+ raise ValueError(
94
+ f'Input array {fvalue} must be 2D to work with extract_diagonal.')
95
+
96
+ diagonal_fvalue = []
97
+ if extract_diagonal:
98
+ diagonal_fvalue = jnp.diag(fvalue)
99
+ # Remove the diagonal entries.
100
+ fvalue = fvalue - jnp.diag(diagonal_fvalue)
101
+
102
+ # TODO(rohananil): Extend this by making use of information about the blocks
103
+ # SM3 style which will be useful for diagonal statistics
104
+ # We first decide the scale.
105
+ if fvalue.ndim < 1:
106
+ raise ValueError(
107
+ f'Input array {fvalue} must have a strictly positive number of '
108
+ 'dimensions.')
109
+
110
+ max_abs = jnp.max(jnp.abs(fvalue), axis=0)
111
+ bucket_size = max_abs / num_buckets
112
+ bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
113
+ # To avoid divide by 0.0
114
+ bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
115
+ jnp.ones_like(bs_expanded))
116
+ ratio = fvalue / bs_nonzero
117
+ # We use rounding to remove bias.
118
+ quantized = jnp.round(ratio)
119
+ return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
120
+
121
+ def to_float(self):
122
+ """Returns the float value."""
123
+ if isinstance(self.quantized, list) and not self.quantized:
124
+ return self.quantized
125
+
126
+ if self.quantized_dtype == jnp.float32:
127
+ return self.quantized
128
+
129
+ if self.quantized_dtype == jnp.bfloat16:
130
+ return self.quantized.astype(jnp.float32)
131
+
132
+ float_dtype = self.bucket_size.dtype
133
+ bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
134
+ val = self.quantized.astype(float_dtype) * bucket_size
135
+ if self.extract_diagonal:
136
+ val += jnp.diag(self.diagonal)
137
+ return val
138
+
139
+
140
+ # Per parameter optimizer state used in data-parallel training.
141
+ class ParameterStats(NamedTuple):
142
+ """State associated to each parameter of the model being trained."""
143
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
144
+ statistics: List[Any] # Statistics (QuantizedValue, chex.Array)
145
+ preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
146
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
147
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
148
+
149
+
150
+ # For training extremely large model; We keep a global state with a concatenated
151
+ # statistics and preconditioner states for all vars. This is so that we can
152
+ # annotate the leading axis to be sharded to save memory at the cost of
153
+ # communication.
154
+ @struct.dataclass
155
+ class GlobalShardedParameterStats:
156
+ statistics: chex.Array # Statistics
157
+ preconditioners: chex.Array # Preconditioners
158
+
159
+
160
+ # These are per-parameter local states; All statistics here mirror the parameter
161
+ # Thus the sharding is copied over from the param specification.
162
+ @struct.dataclass
163
+ class LocalShardedParameterStats:
164
+ """State associated to each parameter of the model being trained."""
165
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
166
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
167
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
168
+ index_start: np.int32 = struct.field(
169
+ pytree_node=False) # Index into global statistics array
170
+ sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
171
+
172
+
173
+ class ShardedShampooStats(NamedTuple):
174
+ """Shampoo state in sharded mode."""
175
+ global_stats: Any
176
+ local_stats: Any
177
+
178
+
179
+ class ShampooState(NamedTuple):
180
+ count: chex.Array
181
+ stats: Any
182
+
183
+
184
+ class GraftingType(enum.IntEnum):
185
+ SGD = 1
186
+ ADAGRAD = 2
187
+ RMSPROP = 3
188
+ RMSPROP_NORMALIZED = 4
189
+
190
+
191
+ def power_iteration(
192
+ matrix,
193
+ num_iters=100,
194
+ error_tolerance=1e-6,
195
+ precision=lax.Precision.HIGHEST):
196
+ r"""Power iteration algorithm.
197
+
198
+ The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
199
+ a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
200
+ of `A`, and a vector v, which is the corresponding eigenvector of `A`.
201
+
202
+ References:
203
+ [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
204
+
205
+ Args:
206
+ matrix: the symmetric PSD matrix.
207
+ num_iters: Number of iterations.
208
+ error_tolerance: Iterative exit condition.
209
+ precision: precision XLA related flag, the available options are:
210
+ a) lax.Precision.DEFAULT (better step time, but not precise)
211
+ b) lax.Precision.HIGH (increased precision, slower)
212
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
213
+
214
+ Returns:
215
+ eigen vector, eigen value
216
+ """
217
+ matrix_size = matrix.shape[-1]
218
+ def _iter_condition(state):
219
+ i, unused_v, unused_s, unused_s_v, run_step = state
220
+ return jnp.logical_and(i < num_iters, run_step)
221
+
222
+ def _iter_body(state):
223
+ """One step of power iteration."""
224
+ i, new_v, s, s_v, unused_run_step = state
225
+ new_v = new_v / jnp.linalg.norm(new_v)
226
+
227
+ s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
228
+ s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
229
+ return (i + 1, s_v, s_new, s_v,
230
+ jnp.greater(jnp.abs(s_new - s), error_tolerance))
231
+
232
+ # Figure out how to use step as seed for random.
233
+ v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0,
234
+ matrix_size).astype(matrix.dtype)
235
+
236
+ init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
237
+ _, v_out, s_out, _, _ = lax.while_loop(
238
+ _iter_condition, _iter_body, init_state)
239
+ v_out = v_out / jnp.linalg.norm(v_out)
240
+ return v_out, s_out
241
+
242
+
243
+ def matrix_inverse_pth_root(
244
+ matrix,
245
+ p,
246
+ num_iters=100,
247
+ ridge_epsilon=1e-6,
248
+ error_tolerance=1e-6,
249
+ precision=lax.Precision.HIGHEST):
250
+ """Computes `matrix^(-1/p)`, where `p` is a positive integer.
251
+
252
+ This function uses the Coupled newton iterations algorithm for
253
+ the computation of a matrix's inverse pth root.
254
+
255
+
256
+ References:
257
+ [Functions of Matrices, Theory and Computation,
258
+ Nicholas J Higham, Pg 184, Eq 7.18](
259
+ https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
260
+
261
+ Args:
262
+ matrix: the symmetric PSD matrix whose power it to be computed
263
+ p: exponent, for p a positive integer.
264
+ num_iters: Maximum number of iterations.
265
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
266
+ error_tolerance: Error indicator, useful for early termination.
267
+ precision: precision XLA related flag, the available options are:
268
+ a) lax.Precision.DEFAULT (better step time, but not precise)
269
+ b) lax.Precision.HIGH (increased precision, slower)
270
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
271
+
272
+ Returns:
273
+ matrix^(-1/p)
274
+ """
275
+
276
+ # We use float32 for the matrix inverse pth root.
277
+ # Switch to f64 if you have hardware that supports it.
278
+ matrix_size = matrix.shape[0]
279
+ alpha = jnp.asarray(-1.0 / p, jnp.float32)
280
+ identity = jnp.eye(matrix_size, dtype=jnp.float32)
281
+ _, max_ev = power_iteration(
282
+ matrix=matrix, num_iters=100,
283
+ error_tolerance=1e-6, precision=precision)
284
+ ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
285
+
286
+ def _unrolled_mat_pow_1(mat_m):
287
+ """Computes mat_m^1."""
288
+ return mat_m
289
+
290
+ def _unrolled_mat_pow_2(mat_m):
291
+ """Computes mat_m^2."""
292
+ return jnp.matmul(mat_m, mat_m, precision=precision)
293
+
294
+ def _unrolled_mat_pow_4(mat_m):
295
+ """Computes mat_m^4."""
296
+ mat_pow_2 = _unrolled_mat_pow_2(mat_m)
297
+ return jnp.matmul(
298
+ mat_pow_2, mat_pow_2, precision=precision)
299
+
300
+ def _unrolled_mat_pow_8(mat_m):
301
+ """Computes mat_m^4."""
302
+ mat_pow_4 = _unrolled_mat_pow_4(mat_m)
303
+ return jnp.matmul(
304
+ mat_pow_4, mat_pow_4, precision=precision)
305
+
306
+ def mat_power(mat_m, p):
307
+ """Computes mat_m^p, for p == 1, 2, 4 or 8.
308
+
309
+ Args:
310
+ mat_m: a square matrix
311
+ p: a positive integer
312
+
313
+ Returns:
314
+ mat_m^p
315
+ """
316
+ # We unrolled the loop for performance reasons.
317
+ exponent = jnp.round(jnp.log2(p))
318
+ return lax.switch(
319
+ jnp.asarray(exponent, jnp.int32), [
320
+ _unrolled_mat_pow_1,
321
+ _unrolled_mat_pow_2,
322
+ _unrolled_mat_pow_4,
323
+ _unrolled_mat_pow_8,
324
+ ], (mat_m))
325
+
326
+ def _iter_condition(state):
327
+ (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
328
+ run_step) = state
329
+ error_above_threshold = jnp.logical_and(
330
+ error > error_tolerance, run_step)
331
+ return jnp.logical_and(i < num_iters, error_above_threshold)
332
+
333
+ def _iter_body(state):
334
+ (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
335
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
336
+ new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
337
+ new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
338
+ new_error = jnp.max(jnp.abs(new_mat_m - identity))
339
+ # sometimes error increases after an iteration before decreasing and
340
+ # converging. 1.2 factor is used to bound the maximal allowed increase.
341
+ return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
342
+ new_error < error * 1.2)
343
+
344
+ if matrix_size == 1:
345
+ resultant_mat_h = (matrix + ridge_epsilon)**alpha
346
+ error = 0
347
+ else:
348
+ damped_matrix = matrix + ridge_epsilon * identity
349
+
350
+ z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
351
+ new_mat_m_0 = damped_matrix * z
352
+ new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
353
+ new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
354
+ init_state = tuple(
355
+ [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
356
+ _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
357
+ _iter_condition, _iter_body, init_state)
358
+ error = jnp.max(jnp.abs(mat_m - identity))
359
+ is_converged = jnp.asarray(convergence, old_mat_h.dtype)
360
+ resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
361
+ resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
362
+ return resultant_mat_h, error
363
+
364
+
365
+ def merge_small_dims(shape_to_merge, max_dim):
366
+ """Merge small dimensions.
367
+
368
+ If there are some small dimensions, we collapse them:
369
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
370
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
371
+
372
+ Args:
373
+ shape_to_merge: Shape to merge small dimensions.
374
+ max_dim: Maximal dimension of output shape used in merging.
375
+
376
+ Returns:
377
+ Merged shape.
378
+ """
379
+ resulting_shape = []
380
+ product = 1
381
+ for d in shape_to_merge:
382
+ if product * d <= max_dim:
383
+ product *= d
384
+ else:
385
+ if product > 1:
386
+ resulting_shape.append(product)
387
+ product = d
388
+ if product > 1:
389
+ resulting_shape.append(product)
390
+ return resulting_shape
391
+
392
+
393
+ def pad_matrix(mat, max_size):
394
+ """Pad a matrix to a max_size.
395
+
396
+ Args:
397
+ mat: a matrix to pad.
398
+ max_size: matrix size requested.
399
+
400
+ Returns:
401
+ Given M returns [[M, 0], [0, I]]
402
+ """
403
+ size = mat.shape[0]
404
+ assert size <= max_size
405
+ if size == max_size:
406
+ return mat
407
+ pad_size = max_size - size
408
+ zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
409
+ zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
410
+ eye = jnp.eye(pad_size, dtype=mat.dtype)
411
+ mat = jnp.concatenate([mat, zs1], 1)
412
+ mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
413
+ return mat
414
+
415
+
416
+ def pad_vector(vec, max_size):
417
+ """Pad a vector to a max_size.
418
+
419
+ Args:
420
+ vec: a vector to pad.
421
+ max_size: matrix size requested.
422
+
423
+ Returns:
424
+ Given V returns [V, 0]
425
+ """
426
+ size = vec.shape[0]
427
+ assert size <= max_size
428
+ if size == max_size:
429
+ return vec
430
+ pad_size = max_size - size
431
+ zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
432
+ return jnp.concatenate([vec, zs1], 0)
433
+
434
+
435
+ def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
436
+ """Avoids wasteful buffer allocation with XLA."""
437
+
438
+ def _iter_body(unused_state):
439
+ results = compute_fn(*args, **kwargs)
440
+ return tuple([False] + list(results))
441
+
442
+ def _iter_condition(state):
443
+ return state[0]
444
+
445
+ results = jax.lax.while_loop(_iter_condition, _iter_body,
446
+ tuple([predicate] + init_state))
447
+ return tuple(results[1:])
448
+
449
+
450
+ class BlockPartitioner:
451
+ """Partitions a tensor into smaller tensors."""
452
+
453
+ def __init__(self, param, block_size):
454
+ self._shape = param.shape
455
+ self._splits = []
456
+ split_sizes = []
457
+ # We split params into smaller blocks. Here we store the metadata to make
458
+ # that split.
459
+ for i, d in enumerate(param.shape):
460
+ if 0 < block_size < d:
461
+ # d-1, otherwise split appends a 0-size array.
462
+ nsplit = (d - 1) // block_size
463
+ indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
464
+ sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
465
+ sizes[-1] = d - indices[-1]
466
+ self._splits.append((i, indices))
467
+ split_sizes.append(sizes)
468
+ else:
469
+ split_sizes.append(np.array([d], dtype=np.int32))
470
+ self._num_splits = len(split_sizes)
471
+ self._preconditioner_shapes = []
472
+ for t in itertools.product(*split_sizes):
473
+ self._preconditioner_shapes.extend([[d, d] for d in t])
474
+
475
+ def shapes_for_preconditioners(self):
476
+ return self._preconditioner_shapes
477
+
478
+ def num_splits(self):
479
+ return self._num_splits
480
+
481
+ def partition(self, tensor):
482
+ """Partition tensor into blocks."""
483
+
484
+ assert tensor.shape == self._shape
485
+ tensors = [tensor]
486
+ for (i, indices) in self._splits:
487
+ tensors_local = []
488
+ for t in tensors:
489
+ tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
490
+ tensors = tensors_local
491
+ return tensors
492
+
493
+ def merge_partitions(self, partitions):
494
+ """Merge partitions back to original shape."""
495
+
496
+ for (i, indices) in reversed(self._splits):
497
+ n = len(indices) + 1
498
+ partial_merged_tensors = []
499
+ ind = 0
500
+ while ind < len(partitions):
501
+ partial_merged_tensors.append(
502
+ jnp.concatenate(partitions[ind:ind + n], axis=i))
503
+ ind += n
504
+ partitions = partial_merged_tensors
505
+ assert len(partitions) == 1
506
+ return partitions[0]
507
+
508
+
509
+ class Preconditioner:
510
+ """Compute statistics/shape from gradients for preconditioning."""
511
+
512
+ def __init__(self, param, block_size, best_effort_shape_interpretation):
513
+ self._original_shape = param.shape
514
+ self._transformed_shape = param.shape
515
+ if best_effort_shape_interpretation:
516
+ self._transformed_shape = merge_small_dims(self._original_shape,
517
+ block_size)
518
+ reshaped_param = jnp.reshape(param, self._transformed_shape)
519
+ self._partitioner = BlockPartitioner(reshaped_param, block_size)
520
+
521
+ def statistics_from_grad(self, grad):
522
+ """Compute statistics from gradients.
523
+
524
+ Args:
525
+ grad: Gradient to compute statistics from.
526
+
527
+ Returns:
528
+ A list of gradient statistics for each partition.
529
+ """
530
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
531
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
532
+ stats = []
533
+ for g in partitioned_grads:
534
+ g_stats = []
535
+ rank = len(g.shape)
536
+ for i in range(rank):
537
+ axes = list(range(i)) + list(range(i + 1, rank))
538
+ stat = jnp.tensordot(g, g, axes=(axes, axes))
539
+ g_stats.append(stat)
540
+ stats.extend(g_stats)
541
+ return stats
542
+
543
+ def shapes_for_preconditioners(self):
544
+ """Returns shape from statistics."""
545
+ return self._partitioner.shapes_for_preconditioners()
546
+
547
+ def exponent_for_preconditioner(self):
548
+ """Returns exponent to use for inverse-pth root M^{-1/p}."""
549
+ return 2 * len(self._transformed_shape)
550
+
551
+ def preconditioned_grad(self, grad, preconditioners):
552
+ """Precondition the gradient.
553
+
554
+ Args:
555
+ grad: A gradient tensor to precondition.
556
+ preconditioners: A list of preconditioners to apply.
557
+
558
+ Returns:
559
+ A preconditioned gradient.
560
+ """
561
+
562
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
563
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
564
+ preconditioned_partitioned_grads = []
565
+ num_splits = self._partitioner.num_splits()
566
+ for i, g in enumerate(partitioned_grads):
567
+ preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) *
568
+ num_splits]
569
+ rank = len(g.shape)
570
+ precond_g = g
571
+ for j in range(rank):
572
+ precond_g = jnp.tensordot(
573
+ precond_g, preconditioners_for_grad[j], axes=[[0], [0]])
574
+ preconditioned_partitioned_grads.append(precond_g)
575
+ merged_grad = self._partitioner.merge_partitions(
576
+ preconditioned_partitioned_grads)
577
+ return jnp.reshape(merged_grad, self._original_shape)
578
+
579
+
580
+ def _convert_to_parameter_stats(global_stats, local_stat):
581
+ """Creates parameter stats from sharded stats."""
582
+ index_start = int(local_stat.index_start)
583
+ index_end = int(len(local_stat.sizes)) + index_start
584
+ statistics = global_stats.statistics[index_start:index_end, :, :]
585
+ preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
586
+ new_statistics = []
587
+ new_preconditioners = []
588
+ for i, size in enumerate(local_stat.sizes):
589
+ new_statistics.append(statistics[i][:size, :size])
590
+ new_preconditioners.append(preconditioners[i][:size, :size])
591
+ return ParameterStats(local_stat.diagonal_statistics, new_statistics,
592
+ new_preconditioners, local_stat.diagonal_momentum,
593
+ local_stat.momentum)
594
+
595
+
596
+ def _convert_from_parameter_stats(parameter_stats, local_stats):
597
+ """Creates sharded stats from paramter stats."""
598
+ return LocalShardedParameterStats(parameter_stats.diagonal_statistics,
599
+ parameter_stats.diagonal_momentum,
600
+ parameter_stats.momentum,
601
+ local_stats.index_start, local_stats.sizes)
602
+
603
+
604
+ def batch(x, num_devices):
605
+ """Batch `x` so that so that leading axis is num_devices."""
606
+ n = len(x)
607
+ b = int(n / num_devices)
608
+ return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)])
609
+
610
+
611
+ def unbatch(batched_values):
612
+ """Unbatch values across leading axis and return a list of elements."""
613
+ b1, b2 = batched_values.shape[0], batched_values.shape[1]
614
+ results = []
615
+ for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
616
+ v_array = jnp.squeeze(v_array)
617
+ # b2 = batches (number of preconditioner computation) per core.
618
+ if b2 > 1:
619
+ for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
620
+ results.append(jnp.squeeze(v))
621
+ else:
622
+ results.append(v_array)
623
+ return results
624
+
625
+
626
+ def distributed_shampoo(
627
+ learning_rate,
628
+ block_size,
629
+ beta1=0.9,
630
+ beta2=0.999,
631
+ diagonal_epsilon=1e-10,
632
+ matrix_epsilon=1e-6,
633
+ weight_decay=0.0,
634
+ start_preconditioning_step=5,
635
+ preconditioning_compute_steps=1,
636
+ statistics_compute_steps=1,
637
+ best_effort_shape_interpretation=True,
638
+ graft_type=GraftingType.SGD,
639
+ nesterov=True,
640
+ exponent_override=0,
641
+ # Pass pmap 'batch axis name' in pmap mode.
642
+ batch_axis_name=None,
643
+ ### Only set following 3 params in pjit/spmd mode.
644
+ ### WARNING: Experimental
645
+ mesh_axis_names=None,
646
+ num_devices_for_pjit=None,
647
+ shard_optimizer_states=False,
648
+ ###
649
+ ### Experimental memory reduction mode
650
+ best_effort_memory_usage_reduction=False,
651
+ ###
652
+ inverse_failure_threshold=0.1,
653
+ moving_average_for_momentum=False,
654
+ skip_preconditioning_dim_size_gt=4096,
655
+ clip_by_scaled_gradient_norm=None,
656
+ precision=lax.Precision.HIGHEST):
657
+ """Distributed Shampoo optimizer.
658
+
659
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
660
+ variant of full-matrix Adagrad), that provides significant convergence and
661
+ wall-clock time improvements compared to conventional first-order methods,
662
+ and that has been shown to scale to large state-of-the-art deep learning
663
+ models.
664
+
665
+ References:
666
+ Scalable Second Order Optimization for Deep Learning,
667
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
668
+
669
+ Preprint: https://arxiv.org/abs/2002.09018
670
+
671
+ Args:
672
+ learning_rate: the step size used to update the parameters.
673
+ block_size: Block size for large layers (if > 0). Preconditioning compute
674
+ operation is cubic in the dimension of the tensor. Block size allows us to
675
+ chunk the layers into sub-layers of maximal dimension dictated by this
676
+ value. Use 128 as default (increase if you have compute budget).
677
+ beta1: momentum parameter.
678
+ beta2: second moment averaging parameter.
679
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
680
+ to AdaGrad is enabled).
681
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
682
+ root. If you are running in f32 precision for inverse pth root
683
+ (recommended today) this can go upto 1e-6. If you have latest hardware
684
+ with native f64 precision, set this upto 1e-12.
685
+ weight_decay: Weight decay for regularization.
686
+ start_preconditioning_step: When to start Shampoo update before which
687
+ diagonal update is used. This is because we dont have enough information
688
+ to do stable inverse.
689
+ preconditioning_compute_steps: How often to compute preconditioner.
690
+ Performance tuning params for controlling memory and compute requirements.
691
+ Ideally set this and statistics_compute_steps params to 1.
692
+ statistics_compute_steps: How often to compute statistics.
693
+ best_effort_shape_interpretation: If there are some small dimensions,
694
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
695
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
696
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
697
+ optimizer. This allows us to plugin the Shampoo optimizer into settings
698
+ where SGD/AdaGrad is already well tuned. Available options are:
699
+ GraftingType.SGD and GraftingType.ADAGRAD.
700
+ nesterov: Nesterov momentum.
701
+ exponent_override: Override the exponent used in matrix inverse.
702
+ batch_axis_name: labeled axis over pmap for data-parallel training the
703
+ optimizer used for.
704
+ mesh_axis_names: Axis names for the mesh (used in pjit).
705
+ num_devices_for_pjit: Number of devices to parallelize over when using pjit.
706
+ shard_optimizer_states: Shard optimizer states to save memory in model
707
+ parallel training.
708
+ best_effort_memory_usage_reduction: Best effort memory usage reduction.
709
+ diagonal_statistics -> jnp.bfloat16
710
+ momentum buffers (2x) -> jnp.int8
711
+ statistics, preconditioners -> jnp.int16 + diagonals
712
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
713
+ determine that using this threshold.
714
+ moving_average_for_momentum: Whether to use moving average for momentum
715
+ instead of exponential moving average.
716
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
717
+ greater than this value.
718
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
719
+ when using RMSProp Grafting).
720
+ precision: precision XLA related flag, the available options are: a)
721
+ lax.Precision.DEFAULT (better step time, but not precise) b)
722
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
723
+ (best possible precision, slowest)
724
+
725
+ Returns:
726
+ a GradientTransformation.
727
+ """
728
+
729
+ def quantized_dtype_for_momentum_buffers():
730
+ return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
731
+
732
+ # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
733
+ def quantized_dtype_for_diagonal_statistics_buffers():
734
+ return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32
735
+
736
+ # Preconditioner and statistics are both stores as int16 in this mode.
737
+ # We take out the diagonal to make quantization easier.
738
+ def quantized_dtype_for_second_moment_statistics_buffers():
739
+ return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
740
+
741
+ # Preconditioner and statistics are both stores as int16 in this mode.
742
+ # We take out the diagonal to make quantization easier.
743
+ def quantized_dtype_for_second_moment_preconditioner_buffers():
744
+ return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
745
+
746
+ def _to_float(maybe_quantized):
747
+ if isinstance(maybe_quantized, QuantizedValue):
748
+ return maybe_quantized.to_float()
749
+ else:
750
+ return maybe_quantized
751
+
752
+ def _maybe_quantize_statistics(statistics_list):
753
+ return _maybe_quantize_matrices_with_dtype(
754
+ statistics_list, quantized_dtype_for_second_moment_statistics_buffers())
755
+
756
+ def _maybe_quantize_preconditioners(statistics_list):
757
+ return _maybe_quantize_matrices_with_dtype(
758
+ statistics_list,
759
+ quantized_dtype_for_second_moment_preconditioner_buffers())
760
+
761
+ def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
762
+ if quantized_dtype != jnp.float32:
763
+ return ([
764
+ QuantizedValue.from_float_value(
765
+ s, quantized_dtype, extract_diagonal=True)
766
+ for s in statistics_list
767
+ ])
768
+ else:
769
+ return statistics_list
770
+
771
+ def _maybe_dequantize_preconditioners(preconditioner_list):
772
+ return _maybe_dequantize_matrices_with_dtype(
773
+ preconditioner_list,
774
+ quantized_dtype_for_second_moment_preconditioner_buffers())
775
+
776
+ def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
777
+ if quantized_dtype != jnp.float32:
778
+ return [s.to_float() for s in statistics_list]
779
+ else:
780
+ return statistics_list
781
+
782
+ def _quantize_diagonal_statistics(diagonal_statistics):
783
+ return QuantizedValue.from_float_value(
784
+ diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers())
785
+
786
+ def _quantize_momentum(momentum_statistics):
787
+ return QuantizedValue.from_float_value(
788
+ momentum_statistics, quantized_dtype_for_momentum_buffers())
789
+
790
+ def sharded_init_fn(params):
791
+ params_flat, treedef = jax.tree_flatten(params)
792
+ # Find max size to pad to.
793
+ max_size = 0
794
+ for param in params_flat:
795
+ preconditioner = Preconditioner(param, block_size,
796
+ best_effort_shape_interpretation)
797
+ if not _skip_preconditioning(param):
798
+ shapes = preconditioner.shapes_for_preconditioners()
799
+ sizes = [s[0] for s in shapes]
800
+ max_size = max(max(sizes), max_size)
801
+
802
+ padded_statistics = []
803
+ padded_preconditioners = []
804
+ local_stats_flat = []
805
+ for param in params_flat:
806
+ preconditioner = Preconditioner(param, block_size,
807
+ best_effort_shape_interpretation)
808
+ shapes = preconditioner.shapes_for_preconditioners()
809
+ sizes = []
810
+
811
+ statistics = []
812
+ preconditioners = []
813
+ index_start = len(padded_statistics)
814
+ if not _skip_preconditioning(param):
815
+ sizes = [s[0] for s in shapes]
816
+ shapes = preconditioner.shapes_for_preconditioners()
817
+ statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
818
+ preconditioners = [jnp.eye(max_size) for s in shapes]
819
+ padded_statistics.extend(statistics)
820
+ padded_preconditioners.extend(preconditioners)
821
+
822
+ diagonal_statistics = []
823
+ if graft_type != GraftingType.SGD:
824
+ diagonal_statistics = jnp.zeros_like(param)
825
+ local_stats_flat.append(
826
+ LocalShardedParameterStats(
827
+ _quantize_diagonal_statistics(diagonal_statistics),
828
+ _quantize_momentum(jnp.zeros_like(param)),
829
+ _quantize_momentum(jnp.zeros_like(param)), index_start, sizes))
830
+
831
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
832
+ # Pad the statistics and preconditioner matrices to be a multiple of
833
+ # num devices.
834
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
835
+ # is split on.
836
+ to_pad = -len(padded_statistics) % num_devices_for_pjit
837
+ padded_statistics.extend([
838
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
839
+ for _ in range(to_pad)
840
+ ])
841
+ padded_preconditioners.extend([
842
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
843
+ for _ in range(to_pad)
844
+ ])
845
+ global_stats = GlobalShardedParameterStats(
846
+ jnp.stack(padded_statistics), jnp.stack(padded_preconditioners))
847
+ return ShampooState(
848
+ count=jnp.zeros([], jnp.int32),
849
+ stats=ShardedShampooStats(global_stats, local_stats))
850
+
851
+ def sharded_update_fn(grads, state, params):
852
+ """Transform the input gradient and update all statistics in sharded mode.
853
+
854
+ Args:
855
+ grads: the gradient tensors for the parameters.
856
+ state: a named tuple containing the state of the optimizer
857
+ params: the parameters that should be updated.
858
+
859
+ Returns:
860
+ A tuple containing the new parameters and the new optimizer state.
861
+ """
862
+ params_flat, treedef = jax.tree_flatten(params)
863
+ grads_flat = treedef.flatten_up_to(grads)
864
+
865
+ global_stats = state.stats.global_stats
866
+ local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
867
+ stats_flat = [
868
+ _convert_to_parameter_stats(global_stats, local_stat)
869
+ for local_stat in local_stats_flat
870
+ ]
871
+ new_stats_flat = jax.tree_multimap(
872
+ lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
873
+ stats_flat, params_flat)
874
+
875
+ exponents = []
876
+ for stat, param in zip(new_stats_flat, params_flat):
877
+ num_statistics = len(stat.statistics)
878
+ if num_statistics > 0:
879
+ preconditioner = Preconditioner(param, block_size,
880
+ best_effort_shape_interpretation)
881
+ exponent = (
882
+ preconditioner.exponent_for_preconditioner()
883
+ if exponent_override == 0 else exponent_override)
884
+ exponents.extend([exponent] * num_statistics)
885
+
886
+ outputs = jax.tree_multimap(
887
+ lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
888
+ new_stats_flat, params_flat)
889
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
890
+
891
+ updates = jax.tree_unflatten(treedef, updates_flat)
892
+ # Create new local_stats
893
+ new_local_stats_flat = [
894
+ _convert_from_parameter_stats(new_stat, local_stat)
895
+ for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
896
+ ]
897
+ new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
898
+
899
+ max_size = global_stats.statistics.shape[1]
900
+ new_padded_statistics = []
901
+ for stat in new_stats_flat:
902
+ new_padded_statistics.extend(
903
+ [pad_matrix(stat, max_size) for stat in stat.statistics])
904
+
905
+ # Create global stats
906
+ # TODO(rohananil): Preconditioner is not updated every step, so cost of
907
+ # stack/pad can be obviated away.
908
+ # Pad the statistics and preconditioner matrices to be a multiple of
909
+ # num devices.
910
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
911
+ # is split on.
912
+ to_pad = -len(new_padded_statistics) % num_devices_for_pjit
913
+ new_padded_statistics.extend([
914
+ jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
915
+ for _ in range(to_pad)
916
+ ])
917
+ exponents.extend([1 for _ in range(to_pad)])
918
+ new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
919
+ new_stacked_exponents = jnp.stack(exponents)
920
+ def _matrix_inverse_pth_root_vmap(xs, ps):
921
+ mi_pth_root = functools.partial(
922
+ matrix_inverse_pth_root,
923
+ ridge_epsilon=matrix_epsilon,
924
+ precision=precision)
925
+ preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
926
+ return preconditioners, errors
927
+
928
+ def _internal_inverse_pth_root_all():
929
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
930
+ new_stacked_padded_statistics, new_stacked_exponents)
931
+ return preconditioners, errors
932
+
933
+ if preconditioning_compute_steps == 1:
934
+ new_preconditioners, errors = _internal_inverse_pth_root_all()
935
+ else:
936
+ # Passing statistics instead of preconditioners as they are similarly
937
+ # shaped tensors. Note statistics will be ignored as we are passing in
938
+ # a large init value for error.
939
+ preconditioners_init = new_stacked_padded_statistics
940
+ errors_init = np.stack([inverse_failure_threshold] * len(exponents))
941
+ init_state = [preconditioners_init, errors_init]
942
+ perform_step = state.count % preconditioning_compute_steps == 0
943
+ new_preconditioners, errors = efficient_cond(
944
+ perform_step, _internal_inverse_pth_root_all, init_state)
945
+
946
+ errors = errors.reshape((-1, 1, 1))
947
+ predicate = jnp.logical_or(
948
+ jnp.isnan(errors),
949
+ errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
950
+ # TODO(rohananil): Check for numerical instabilities.
951
+ new_conditional_preconditioners = (
952
+ predicate * global_stats.preconditioners +
953
+ (1.0 - predicate) * new_preconditioners)
954
+ new_global_stats = GlobalShardedParameterStats(
955
+ new_stacked_padded_statistics, new_conditional_preconditioners)
956
+ new_shampoo_state = ShampooState(
957
+ count=state.count + 1,
958
+ stats=ShardedShampooStats(new_global_stats, new_local_stats))
959
+ return updates, new_shampoo_state
960
+
961
+ def init_fn(params):
962
+ """Initialise the optimiser's state."""
963
+
964
+ def _init(param):
965
+ preconditioner = Preconditioner(param, block_size,
966
+ best_effort_shape_interpretation)
967
+ statistics = []
968
+ preconditioners = []
969
+ if not _skip_preconditioning(param):
970
+ shapes = preconditioner.shapes_for_preconditioners()
971
+ statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
972
+ preconditioners = [jnp.eye(s[0]) for s in shapes]
973
+
974
+ diagonal_statistics = []
975
+ if graft_type != GraftingType.SGD:
976
+ diagonal_statistics = jnp.zeros_like(param)
977
+ return ParameterStats(
978
+ _quantize_diagonal_statistics(diagonal_statistics),
979
+ _maybe_quantize_statistics(statistics),
980
+ _maybe_quantize_preconditioners(preconditioners),
981
+ _quantize_momentum(jnp.zeros_like(param)),
982
+ _quantize_momentum(jnp.zeros_like(param)))
983
+ return ShampooState(
984
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
985
+
986
+ def _skip_preconditioning(param):
987
+ return len(param.shape) < 1 or any(
988
+ [s > skip_preconditioning_dim_size_gt for s in param.shape])
989
+
990
+ def _compute_stats(grad, state, param, step):
991
+ """Compute per-parameter statistics."""
992
+ preconditioner = Preconditioner(param, block_size,
993
+ best_effort_shape_interpretation)
994
+ new_statistics = [[]] * len(state.statistics)
995
+ w1 = beta2
996
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
997
+ if not _skip_preconditioning(param):
998
+
999
+ def compute_updated_statistics():
1000
+ new_stats = preconditioner.statistics_from_grad(grad)
1001
+ new_stats_accumulators = []
1002
+ for stat, stat_accumulator in zip(new_stats, state.statistics):
1003
+ new_stats_accumulators.append(w1 * _to_float(stat_accumulator) +
1004
+ w2 * stat)
1005
+ return _maybe_quantize_statistics(new_stats_accumulators)
1006
+
1007
+ if statistics_compute_steps > 1:
1008
+ perform_step = step % statistics_compute_steps == 0
1009
+ init_state = state.statistics
1010
+ new_statistics = list(
1011
+ efficient_cond(perform_step, compute_updated_statistics,
1012
+ init_state))
1013
+ else:
1014
+ new_statistics = compute_updated_statistics()
1015
+ return ParameterStats(state.diagonal_statistics, new_statistics,
1016
+ state.preconditioners, state.diagonal_momentum,
1017
+ state.momentum)
1018
+
1019
+ def _matrix_inverse_pth_root_vmap(xs, ps):
1020
+ mi_pth_root = functools.partial(
1021
+ matrix_inverse_pth_root,
1022
+ ridge_epsilon=matrix_epsilon,
1023
+ precision=precision)
1024
+ return jax.vmap(mi_pth_root)(xs, ps)
1025
+
1026
+ def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
1027
+
1028
+ def _quantized_to_float(qx, qd, qb):
1029
+ qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
1030
+ return qv.to_float()
1031
+
1032
+ def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
1033
+ v = _quantized_to_float(qx, qd, qb)
1034
+ preconditioner, error = matrix_inverse_pth_root(
1035
+ v, p, ridge_epsilon=matrix_epsilon, precision=precision)
1036
+ qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
1037
+ return qp.quantized, qp.diagonal, qp.bucket_size, error
1038
+
1039
+ return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1040
+
1041
+ def _matrix_inverse_pth_root_pjit(xs, ps):
1042
+ mesh_axis_names_tuple = tuple(mesh_axis_names)
1043
+ # Partition the concatenated statistics matrix across all cores.
1044
+ partitioned_xs, partitioned_ps = pjit.pjit(
1045
+ lambda x, y: (x, y),
1046
+ in_axis_resources=None,
1047
+ out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps)
1048
+ # Run matrix inverse pth root on each shard.
1049
+ partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
1050
+ partitioned_xs, partitioned_ps)
1051
+ # Recombine the outputs at each core.
1052
+ preconditioners, errors = pjit.pjit(
1053
+ lambda x, y: (x, y),
1054
+ in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
1055
+ pjit.PartitionSpec(mesh_axis_names_tuple,)),
1056
+ out_axis_resources=(None, None))(partitioned_preconditioners,
1057
+ partitioned_errors)
1058
+ return preconditioners, errors
1059
+
1060
+ def _pmap_compute_preconditioners(states, step, statistics,
1061
+ num_statistics_per_state, original_shapes,
1062
+ exponents, max_size, prev_preconditioners):
1063
+ """Computes preconditioners for given statistics in states in PMAP mode.
1064
+
1065
+ Args:
1066
+ states: A list of optimizer states.
1067
+ step: Current step number
1068
+ statistics: A list of statistics for all variables (for every dim)
1069
+ num_statistics_per_state: Number of statistis per state to reconstruct
1070
+ output states.
1071
+ original_shapes: A list of shapes of the statistics.
1072
+ exponents: Exponent power to use for inverse-pth roots.
1073
+ max_size: Maximum dim of the statistics to pad.
1074
+ prev_preconditioners: Previously available preconditioner.
1075
+
1076
+ Returns:
1077
+ New optimizer states after computing the preconditioner.
1078
+ """
1079
+ num_devices = lax.psum(1, batch_axis_name)
1080
+ num_statistics = len(statistics)
1081
+ # Pad statistics and exponents to next multiple of num_devices.
1082
+ packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1083
+ to_pad = -num_statistics % num_devices
1084
+ packed_statistics.extend([
1085
+ jnp.eye(max_size, dtype=packed_statistics[0].dtype)
1086
+ for _ in range(to_pad)
1087
+ ])
1088
+ exponents.extend([1 for _ in range(to_pad)])
1089
+
1090
+ if not packed_statistics:
1091
+ return states
1092
+
1093
+ all_statistics = batch(packed_statistics, num_devices)
1094
+ all_exponents = batch(exponents, num_devices)
1095
+
1096
+ def _internal_inverse_pth_root_all():
1097
+ current_replica = lax.axis_index(batch_axis_name)
1098
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
1099
+ all_statistics[current_replica], all_exponents[current_replica])
1100
+ preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
1101
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1102
+ preconditioners_flat = unbatch(preconditioners)
1103
+ errors_flat = unbatch(errors)
1104
+ return preconditioners_flat, errors_flat
1105
+
1106
+ if preconditioning_compute_steps == 1:
1107
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1108
+ else:
1109
+ # Passing statistics instead of preconditioners as they are similarly
1110
+ # shaped tensors. Note statistics will be ignored as we are passing in
1111
+ # a large init value for error.
1112
+ preconditioners_init = packed_statistics
1113
+ errors_init = ([inverse_failure_threshold] * len(packed_statistics))
1114
+ init_state = [preconditioners_init, errors_init]
1115
+ perform_step = step % preconditioning_compute_steps == 0
1116
+ preconditioners_flat, errors_flat = efficient_cond(
1117
+ perform_step, _internal_inverse_pth_root_all, init_state)
1118
+
1119
+ def _skip(error):
1120
+ condition = jnp.logical_or(
1121
+ jnp.isnan(error), error >= inverse_failure_threshold)
1122
+ return condition.astype(error.dtype)
1123
+
1124
+ def _select_preconditioner(error, new_p, old_p):
1125
+ return lax.cond(
1126
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1127
+
1128
+ new_preconditioners_flat = []
1129
+ for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
1130
+ prev_preconditioners, errors_flat):
1131
+ new_preconditioners_flat.append(
1132
+ _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
1133
+
1134
+ assert len(states) == len(num_statistics_per_state)
1135
+ assert len(new_preconditioners_flat) == num_statistics
1136
+
1137
+ # Add back empty preconditioners so we that we can set the optimizer state.
1138
+ preconditioners_for_states = []
1139
+ idx = 0
1140
+ for num_statistics, state in zip(num_statistics_per_state, states):
1141
+ if num_statistics == 0:
1142
+ preconditioners_for_states.append([])
1143
+ else:
1144
+ preconditioners_for_state = new_preconditioners_flat[idx:idx +
1145
+ num_statistics]
1146
+ assert len(state.statistics) == len(preconditioners_for_state)
1147
+ preconditioners_for_states.append(preconditioners_for_state)
1148
+ idx += num_statistics
1149
+ new_states = []
1150
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1151
+ new_states.append(
1152
+ ParameterStats(state.diagonal_statistics, state.statistics,
1153
+ new_preconditioners, state.diagonal_momentum,
1154
+ state.momentum))
1155
+
1156
+ return new_states
1157
+
1158
+ def _pmap_quantized_compute_preconditioners(states, step, statistics,
1159
+ num_statistics_per_state,
1160
+ original_shapes, exponents,
1161
+ max_size, prev_preconditioners):
1162
+ """Computes preconditioners for given statistics in states in PMAP mode.
1163
+
1164
+ For quantization, each statistic is represented by three values:
1165
+ quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots
1166
+ without ever recreating the original matrix in f32.
1167
+
1168
+ Args:
1169
+ states: A list of optimizer states.
1170
+ step: Current step number
1171
+ statistics: A list of statistics for all variables (for every dim)
1172
+ num_statistics_per_state: Number of statistis per state to reconstruct
1173
+ output states.
1174
+ original_shapes: A list of shapes of the statistics.
1175
+ exponents: Exponent power to use for inverse-pth roots.
1176
+ max_size: Maximum dim of the statistics to pad.
1177
+ prev_preconditioners: Previously available preconditioner.
1178
+
1179
+ Returns:
1180
+ New optimizer states after computing the preconditioner.
1181
+ """
1182
+ num_devices = lax.psum(1, batch_axis_name)
1183
+ num_statistics = len(statistics)
1184
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1185
+ # Complexity here is around: shapes needing be statically shaped,
1186
+ # our custom quantization type requires a different type of packing.
1187
+
1188
+ # Parallel tensors:
1189
+ # quantized [dxd]
1190
+ # diagonals [d] f32
1191
+ # bucket_sizes [d] f32
1192
+ packed_quantized_statistics = [
1193
+ pad_matrix(stat.quantized, max_size) for stat in statistics
1194
+ ]
1195
+ packed_quantized_diagonals = [
1196
+ pad_vector(stat.diagonal, max_size) for stat in statistics
1197
+ ]
1198
+ packed_quantized_bucket_sizes = [
1199
+ pad_vector(stat.bucket_size, max_size) for stat in statistics
1200
+ ]
1201
+
1202
+ to_pad = -num_statistics % num_devices
1203
+ padded_eye = jnp.eye(max_size, dtype=jnp.float32)
1204
+ quantized_eye = QuantizedValue.from_float_value(padded_eye, quantized_dtype,
1205
+ True)
1206
+ packed_quantized_statistics.extend(
1207
+ [quantized_eye.quantized for _ in range(to_pad)])
1208
+ packed_quantized_diagonals.extend(
1209
+ [quantized_eye.diagonal for _ in range(to_pad)])
1210
+ packed_quantized_bucket_sizes.extend(
1211
+ [quantized_eye.bucket_size for _ in range(to_pad)])
1212
+ exponents.extend([1 for _ in range(to_pad)])
1213
+
1214
+ if not packed_quantized_statistics:
1215
+ return states
1216
+
1217
+ all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
1218
+ all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
1219
+ all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes,
1220
+ num_devices)
1221
+ all_exponents = batch(exponents, num_devices)
1222
+
1223
+ def _internal_inverse_pth_root_all():
1224
+ current_replica = lax.axis_index(batch_axis_name)
1225
+ quantized_preconditioners, quantized_diagonals, quantized_bucket_sizes, errors = (
1226
+ _quantized_matrix_inverse_pth_root_vmap(
1227
+ all_quantized_statistics[current_replica],
1228
+ all_quantized_diagonals[current_replica],
1229
+ all_quantized_bucket_sizes[current_replica],
1230
+ all_exponents[current_replica]))
1231
+ quantized_preconditioners = jax.lax.all_gather(quantized_preconditioners,
1232
+ batch_axis_name)
1233
+ quantized_diagonals = jax.lax.all_gather(quantized_diagonals,
1234
+ batch_axis_name)
1235
+ quantized_bucket_sizes = jax.lax.all_gather(quantized_bucket_sizes,
1236
+ batch_axis_name)
1237
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1238
+ quantized_preconditioners_flat = unbatch(quantized_preconditioners)
1239
+ quantized_diagonals_flat = unbatch(quantized_diagonals)
1240
+ quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
1241
+ errors_flat = unbatch(errors)
1242
+ return (quantized_preconditioners_flat, quantized_diagonals_flat,
1243
+ quantized_bucket_sizes_flat, errors_flat)
1244
+
1245
+ if preconditioning_compute_steps == 1:
1246
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1247
+ quantized_bucket_sizes_flat, errors_flat) = (
1248
+ _internal_inverse_pth_root_all())
1249
+ else:
1250
+ # Passing statistics instead of preconditioners as they are similarly
1251
+ # shaped tensors. Note statistics will be ignored as we are passing in
1252
+ # a large init value for error.
1253
+ quantized_preconditioners_init = packed_quantized_statistics
1254
+ quantized_diagonals_init = packed_quantized_diagonals
1255
+ quantized_bucket_sizes_init = packed_quantized_bucket_sizes
1256
+ errors_init = ([inverse_failure_threshold] *
1257
+ len(quantized_preconditioners_init))
1258
+ init_state = [
1259
+ quantized_preconditioners_init, quantized_diagonals_init,
1260
+ quantized_bucket_sizes_init, errors_init
1261
+ ]
1262
+ perform_step = step % preconditioning_compute_steps == 0
1263
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1264
+ quantized_bucket_sizes_flat, errors_flat) = (
1265
+ efficient_cond(perform_step, _internal_inverse_pth_root_all,
1266
+ init_state))
1267
+
1268
+ def _skip(error):
1269
+ condition = jnp.logical_or(
1270
+ jnp.isnan(error), error >= inverse_failure_threshold)
1271
+ return condition.astype(error.dtype)
1272
+
1273
+ def _select_preconditioner(error, new_p, old_p):
1274
+ return lax.cond(
1275
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1276
+
1277
+ new_quantized_preconditioners_flat = []
1278
+ new_quantized_diagonals_flat = []
1279
+ new_quantized_bucket_sizes_flat = []
1280
+ for p, d, b, shape, prev_p, error in zip(quantized_preconditioners_flat,
1281
+ quantized_diagonals_flat,
1282
+ quantized_bucket_sizes_flat,
1283
+ original_shapes,
1284
+ prev_preconditioners, errors_flat):
1285
+ new_quantized_preconditioners_flat.append(
1286
+ _select_preconditioner(error, p[:shape[0], :shape[1]],
1287
+ prev_p.quantized))
1288
+ new_quantized_diagonals_flat.append(
1289
+ _select_preconditioner(error, d[:shape[0]], prev_p.diagonal))
1290
+ new_quantized_bucket_sizes_flat.append(
1291
+ _select_preconditioner(error, b[:shape[0]], prev_p.bucket_size))
1292
+
1293
+ assert len(states) == len(num_statistics_per_state)
1294
+ assert len(new_quantized_preconditioners_flat) == num_statistics
1295
+ assert len(new_quantized_diagonals_flat) == num_statistics
1296
+ assert len(new_quantized_bucket_sizes_flat) == num_statistics
1297
+
1298
+ # Add back empty preconditioners so we that we can set the optimizer state.
1299
+ preconditioners_for_states = []
1300
+ idx = 0
1301
+ for num_statistics, state in zip(num_statistics_per_state, states):
1302
+ if num_statistics == 0:
1303
+ preconditioners_for_states.append([])
1304
+ else:
1305
+ quantized_preconditioners_for_state = new_quantized_preconditioners_flat[
1306
+ idx:idx + num_statistics]
1307
+ quantized_diagonals_for_state = new_quantized_diagonals_flat[
1308
+ idx:idx + num_statistics]
1309
+ quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1310
+ idx:idx + num_statistics]
1311
+
1312
+ assert len(state.statistics) == len(quantized_preconditioners_for_state)
1313
+ assert len(state.statistics) == len(quantized_diagonals_for_state)
1314
+ assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
1315
+
1316
+ quantized_preconditioners = []
1317
+ for qv, qd, qb in zip(quantized_preconditioners_for_state,
1318
+ quantized_diagonals_for_state,
1319
+ quantized_bucket_sizes_for_state):
1320
+ quantized_preconditioners.append(
1321
+ QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape)))
1322
+ preconditioners_for_states.append(quantized_preconditioners)
1323
+ idx += num_statistics
1324
+ new_states = []
1325
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1326
+ new_states.append(
1327
+ ParameterStats(state.diagonal_statistics, state.statistics,
1328
+ new_preconditioners, state.diagonal_momentum,
1329
+ state.momentum))
1330
+
1331
+ return new_states
1332
+
1333
+ def _pjit_compute_preconditioners(states, step, statistics,
1334
+ num_statistics_per_state, original_shapes,
1335
+ exponents, max_size, prev_preconditioners):
1336
+ """Computes preconditioners for given statistics in states in PJIT mode.
1337
+
1338
+ Args:
1339
+ states: A list of optimizer states.
1340
+ step: Current step number
1341
+ statistics: A list of statistics for all variables (for every dim)
1342
+ num_statistics_per_state: Number of statistis per state to reconstruct
1343
+ output states.
1344
+ original_shapes: A list of shapes of the statistics.
1345
+ exponents: Exponent power to use for inverse-pth roots.
1346
+ max_size: Maximum dim of the statistics to pad.
1347
+ prev_preconditioners: Previously available preconditioner.
1348
+
1349
+ Returns:
1350
+ New optimizer states after computing the preconditioner.
1351
+ """
1352
+ num_statistics = len(statistics)
1353
+ to_pad = -num_statistics % num_devices_for_pjit
1354
+ padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1355
+ padded_statistics.extend([
1356
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
1357
+ for _ in range(to_pad)
1358
+ ])
1359
+ exponents.extend([1 for _ in range(to_pad)])
1360
+ all_statistics = jnp.stack(padded_statistics)
1361
+ all_exponents = jnp.stack(exponents)
1362
+
1363
+ def _internal_inverse_pth_root_all():
1364
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1365
+ all_statistics, all_exponents)
1366
+ b1 = preconditioners.shape[0]
1367
+
1368
+ def split(batched_values):
1369
+ return [
1370
+ jnp.squeeze(v)
1371
+ for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
1372
+ ]
1373
+
1374
+ return split(preconditioners), split(errors)
1375
+
1376
+ if preconditioning_compute_steps == 1:
1377
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1378
+ else:
1379
+ # Passing statistics instead of preconditioners as they are similarly
1380
+ # shaped tensors. Note statistics will be ignored as we are passing in
1381
+ # a large init value for error.
1382
+ preconditioners_init = padded_statistics
1383
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
1384
+ init_state = [preconditioners_init, errors_init]
1385
+ perform_step = step % preconditioning_compute_steps == 0
1386
+ preconditioners_flat, errors_flat = efficient_cond(
1387
+ perform_step, _internal_inverse_pth_root_all, init_state)
1388
+
1389
+ def _skip(error):
1390
+ condition = jnp.logical_or(
1391
+ jnp.isnan(error), error >= inverse_failure_threshold)
1392
+ return condition.astype(error.dtype)
1393
+
1394
+ def _select_preconditioner(error, new_p, old_p):
1395
+ return lax.cond(
1396
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1397
+
1398
+ new_preconditioners_flat = []
1399
+ for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
1400
+ prev_preconditioners, errors_flat):
1401
+ new_preconditioners_flat.append(
1402
+ _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
1403
+
1404
+ assert len(states) == len(num_statistics_per_state)
1405
+ assert len(new_preconditioners_flat) == num_statistics
1406
+
1407
+ # Add back empty preconditioners so we that we can set the optimizer state.
1408
+ preconditioners_for_states = []
1409
+ idx = 0
1410
+ for num_statistics, state in zip(num_statistics_per_state, states):
1411
+ if num_statistics == 0:
1412
+ preconditioners_for_states.append([])
1413
+ else:
1414
+ preconditioners_for_state = new_preconditioners_flat[idx:idx +
1415
+ num_statistics]
1416
+ assert len(state.statistics) == len(preconditioners_for_state)
1417
+ preconditioners_for_states.append(preconditioners_for_state)
1418
+ idx += num_statistics
1419
+ new_states = []
1420
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1421
+ new_states.append(
1422
+ ParameterStats(state.diagonal_statistics, state.statistics,
1423
+ new_preconditioners, state.diagonal_momentum,
1424
+ state.momentum))
1425
+
1426
+ return new_states
1427
+
1428
+ def _compute_preconditioners(states, params, step):
1429
+ """Computes preconditioners for given statistics in states.
1430
+
1431
+ Args:
1432
+ states: A list of optimizer states.
1433
+ params: A list of params.
1434
+ step: Current step number
1435
+
1436
+ Returns:
1437
+ New optimizer states after computing the preconditioner.
1438
+ """
1439
+ statistics = []
1440
+ num_statistics_per_state = []
1441
+ original_shapes = []
1442
+ exponents = []
1443
+ max_size = 0
1444
+ prev_preconditioners = []
1445
+
1446
+ for state, param in zip(states, params):
1447
+ num_statistics = len(state.statistics)
1448
+ num_statistics_per_state.append(num_statistics)
1449
+ original_shapes_for_state = []
1450
+ if num_statistics > 0:
1451
+ preconditioner = Preconditioner(param, block_size,
1452
+ best_effort_shape_interpretation)
1453
+ for statistic in state.statistics:
1454
+ exponents.append(preconditioner.exponent_for_preconditioner(
1455
+ ) if exponent_override == 0 else exponent_override)
1456
+ original_shapes_for_state.append(statistic.shape)
1457
+ max_size = max(max_size, statistic.shape[0])
1458
+
1459
+ statistics.extend(state.statistics)
1460
+ prev_preconditioners.extend(state.preconditioners)
1461
+ original_shapes.extend(original_shapes_for_state)
1462
+
1463
+ if batch_axis_name:
1464
+ # Quantization is only enabled if batch_axis_name is not set.
1465
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1466
+
1467
+ if quantized_dtype == jnp.float32:
1468
+ return _pmap_compute_preconditioners(states, step, statistics,
1469
+ num_statistics_per_state,
1470
+ original_shapes, exponents,
1471
+ max_size, prev_preconditioners)
1472
+ else:
1473
+ return _pmap_quantized_compute_preconditioners(
1474
+ states, step, statistics, num_statistics_per_state, original_shapes,
1475
+ exponents, max_size, prev_preconditioners)
1476
+
1477
+ else:
1478
+ return _pjit_compute_preconditioners(states, step, statistics,
1479
+ num_statistics_per_state,
1480
+ original_shapes, exponents, max_size,
1481
+ prev_preconditioners)
1482
+
1483
+ def _transform_grad(grad, state, param, step):
1484
+ """Transform per-parameter gradients."""
1485
+ preconditioner = Preconditioner(param, block_size,
1486
+ best_effort_shape_interpretation)
1487
+ sgd_update = grad
1488
+ new_diagonal_statistics = state.diagonal_statistics.to_float()
1489
+ if graft_type == GraftingType.ADAGRAD:
1490
+ new_diagonal_statistics = state.diagonal_statistics.to_float(
1491
+ ) + jnp.square(grad)
1492
+ adagrad_update = grad / (
1493
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1494
+ grafting_update = adagrad_update
1495
+ elif (graft_type == GraftingType.RMSPROP or
1496
+ graft_type == GraftingType.RMSPROP_NORMALIZED):
1497
+
1498
+ scaled_grad = grad
1499
+ if graft_type == GraftingType.RMSPROP_NORMALIZED:
1500
+ scaled_grad = grad / jnp.linalg.norm(grad)
1501
+
1502
+ w1 = beta2
1503
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1504
+
1505
+ new_diagonal_statistics = (
1506
+ w1 * state.diagonal_statistics.to_float() +
1507
+ w2 * jnp.square(scaled_grad))
1508
+ rmsprop_update = scaled_grad / (
1509
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1510
+
1511
+ if clip_by_scaled_gradient_norm:
1512
+ scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
1513
+ jnp.sqrt(float(rmsprop_update.size)))
1514
+ clipping_denom = jnp.maximum(
1515
+ 1., scaled_grad_norm / clip_by_scaled_gradient_norm)
1516
+ rmsprop_update /= clipping_denom
1517
+
1518
+ grafting_update = rmsprop_update
1519
+ else:
1520
+ grafting_update = sgd_update
1521
+
1522
+ precond_grad = grad
1523
+ if not _skip_preconditioning(param):
1524
+ precond_grad = preconditioner.preconditioned_grad(
1525
+ precond_grad,
1526
+ _maybe_dequantize_preconditioners(state.preconditioners))
1527
+ else:
1528
+ precond_grad = grafting_update
1529
+
1530
+ grafting_update_norm = jnp.linalg.norm(grafting_update)
1531
+ precond_grad_norm = jnp.linalg.norm(precond_grad)
1532
+
1533
+ multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16))
1534
+ shampoo_update = precond_grad * multiplier
1535
+
1536
+ shampoo_update_with_wd = shampoo_update
1537
+ grafting_update_with_wd = grafting_update
1538
+ if weight_decay != 0:
1539
+ shampoo_update_with_wd = shampoo_update + weight_decay * param
1540
+ grafting_update_with_wd = grafting_update + weight_decay * param
1541
+
1542
+ w = (1.0 - beta1) if moving_average_for_momentum else 1.0
1543
+ shampoo_update_with_wd_momentum = (
1544
+ state.momentum.to_float() * beta1 + w * shampoo_update_with_wd)
1545
+ grafting_update_with_wd_momentum = (
1546
+ state.diagonal_momentum.to_float() * beta1 +
1547
+ w * grafting_update_with_wd)
1548
+
1549
+ run_shampoo = (step >= start_preconditioning_step).astype(
1550
+ grafting_update_with_wd_momentum.dtype)
1551
+
1552
+ momentum_update = (
1553
+ run_shampoo * shampoo_update_with_wd_momentum +
1554
+ (1.0 - run_shampoo) * grafting_update_with_wd_momentum)
1555
+
1556
+ wd_update = (
1557
+ run_shampoo * shampoo_update_with_wd +
1558
+ (1.0 - run_shampoo) * grafting_update_with_wd)
1559
+
1560
+ if nesterov:
1561
+ momentum_update = w * wd_update + beta1 * momentum_update
1562
+
1563
+ lr = learning_rate
1564
+ if callable(learning_rate):
1565
+ lr = learning_rate(step)
1566
+ transformed_update = -1.0 * lr * momentum_update
1567
+
1568
+ param_stats = ParameterStats(
1569
+ _quantize_diagonal_statistics(new_diagonal_statistics),
1570
+ state.statistics, state.preconditioners,
1571
+ _quantize_momentum(grafting_update_with_wd_momentum),
1572
+ _quantize_momentum(shampoo_update_with_wd_momentum))
1573
+ return transformed_update, param_stats
1574
+
1575
+ def update_fn(grads, state, params):
1576
+ """Transform the input gradient and update all statistics.
1577
+
1578
+ Args:
1579
+ grads: the gradient tensors for the parameters.
1580
+ state: a named tuple containing the state of the optimizer
1581
+ params: the parameters that should be updated.
1582
+
1583
+ Returns:
1584
+ A tuple containing the new parameters and the new optimizer state.
1585
+ """
1586
+ params_flat, treedef = jax.tree_flatten(params)
1587
+ stats_flat = treedef.flatten_up_to(state.stats)
1588
+ grads_flat = treedef.flatten_up_to(grads)
1589
+
1590
+ new_stats_flat = jax.tree_multimap(
1591
+ lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
1592
+ stats_flat, params_flat)
1593
+ new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
1594
+ state.count)
1595
+
1596
+ outputs = jax.tree_multimap(
1597
+ lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
1598
+ new_stats_flat, params_flat)
1599
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1600
+
1601
+ updates = jax.tree_unflatten(treedef, updates_flat)
1602
+ new_stats = jax.tree_unflatten(treedef, new_stats_flat)
1603
+
1604
+ new_state = ShampooState(
1605
+ count=state.count+1, stats=new_stats)
1606
+ return updates, new_state
1607
+
1608
+ if shard_optimizer_states:
1609
+ return optax.GradientTransformation(sharded_init_fn, sharded_update_fn)
1610
+ else:
1611
+ return optax.GradientTransformation(init_fn, update_fn)
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1cc8ebedc59040ce75a0f2580660ea2f4740461de531d9c84bdef97630b6aa1e
3
  size 497764120
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8785c0613f57cbd0fdb77c6d7cd033dcf7f09564f92d30a51b9adcf591da8ef6
3
  size 497764120
run_clm_flax.py CHANGED
@@ -61,6 +61,8 @@ from transformers import (
61
  from transformers.file_utils import get_full_repo_name
62
  from transformers.testing_utils import CaptureLogger
63
 
 
 
64
 
65
  logger = logging.getLogger(__name__)
66
 
@@ -96,6 +98,9 @@ class TrainingArguments:
96
  adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
97
  adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
98
  adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
 
 
 
99
  num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
100
  warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
101
  warmup_ratio: float = field(default=0.0, metadata={"help": "Linear warmup ratio of total train steps."})
@@ -652,6 +657,33 @@ def main():
652
  optimizer = optax.adafactor(
653
  learning_rate=lr_schedule_fn,
654
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
  else:
656
  optimizer = optax.adamw(
657
  learning_rate=lr_schedule_fn,
 
61
  from transformers.file_utils import get_full_repo_name
62
  from transformers.testing_utils import CaptureLogger
63
 
64
+ from distributed_shampoo import distributed_shampoo, GraftingType
65
+
66
 
67
  logger = logging.getLogger(__name__)
68
 
 
98
  adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
99
  adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
100
  adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
101
+ distributed_shampoo: bool = field(
102
+ default=False, metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
103
+ )
104
  num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
105
  warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
106
  warmup_ratio: float = field(default=0.0, metadata={"help": "Linear warmup ratio of total train steps."})
 
657
  optimizer = optax.adafactor(
658
  learning_rate=lr_schedule_fn,
659
  )
660
+
661
+ elif training_args.distributed_shampoo:
662
+ # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
663
+ # Notes:
664
+ # - mask for weight decay is not implemented but we don't use it anyway
665
+ optimizer = distributed_shampoo(
666
+ lr_schedule_fn,
667
+ block_size=1024, # recommended default for large LM is 1536
668
+ beta1=0.9,
669
+ beta2=0.999,
670
+ diagonal_epsilon=1e-10,
671
+ matrix_epsilon=1e-8,
672
+ weight_decay=0.0,
673
+ start_preconditioning_step=1001,
674
+ preconditioning_compute_steps=10,
675
+ statistics_compute_steps=1,
676
+ best_effort_shape_interpretation=True,
677
+ graft_type=GraftingType.RMSPROP_NORMALIZED,
678
+ nesterov=False,
679
+ exponent_override=0,
680
+ batch_axis_name="batch",
681
+ inverse_failure_threshold=0.1,
682
+ moving_average_for_momentum=True,
683
+ skip_preconditioning_dim_size_gt=4096,
684
+ clip_by_scaled_gradient_norm=None,
685
+ precision=jax.lax.Precision.HIGHEST,
686
+ )
687
  else:
688
  optimizer = optax.adamw(
689
  learning_rate=lr_schedule_fn,
runs/events.out.tfevents.1642099734.t1v-n-42145f73-w-0.2317757.0.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:dbbf5988858cf76da919fb1758c11b0ecc821048cb24a60adc8a53b7f332f86b
3
- size 15255487
 
 
 
 
runs/{events.out.tfevents.1642208918.t1v-n-42145f73-w-0.2567321.0.v2 → events.out.tfevents.1642236904.t1v-n-42145f73-w-0.2775834.0.v2} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:26da4254e7068396f8f600c76f7aac5deda5b8f741de6eb8aa560fea2b2ed1e8
3
- size 2950215
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3e58f1ebeab5e80fee00375c5559d4f2276213fa5a9e715e12a9a59a491f1e0
3
+ size 1471449
start_train.sh CHANGED
@@ -13,13 +13,10 @@ python3 run_clm_flax.py \
13
  --per_device_train_batch_size="32" \
14
  --per_device_eval_batch_size="32" \
15
  --preprocessing_num_workers="1" \
16
- --learning_rate="5e-3" \
17
- --warmup_ratio="0.01" \
 
18
  --cosine_decay \
19
- --adam_beta1="0.9" \
20
- --adam_beta2="0.98" \
21
- --adam_epsilon="1e-8" \
22
- --weight_decay="0.01" \
23
  --overwrite_output_dir \
24
  --logging_steps="500" \
25
  --eval_steps="10000" \
 
13
  --per_device_train_batch_size="32" \
14
  --per_device_eval_batch_size="32" \
15
  --preprocessing_num_workers="1" \
16
+ --distributed_shampoo \
17
+ --learning_rate="1e-4" \
18
+ --warmup_steps="4000" \
19
  --cosine_decay \
 
 
 
 
20
  --overwrite_output_dir \
21
  --logging_steps="500" \
22
  --eval_steps="10000" \