Create optimization_tf.py
optimization_tf.py  (ADDED, +371 -0)
@@ -0,0 +1,371 @@
# Copyright 2019 The TensorFlow Authors, The Hugging Face Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""


import re
from typing import Callable, List, Optional, Union

import tensorflow as tf


# Newer Keras releases moved the original Adam implementation to the `legacy`
# namespace; fall back to the top-level import on older TensorFlow versions.
try:
    from tensorflow.keras.optimizers.legacy import Adam
except ImportError:
    from tensorflow.keras.optimizers import Adam

class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Applies a warmup schedule on a given learning rate decay schedule.

    Args:
        initial_learning_rate (`float`):
            The initial learning rate for the schedule after the warmup (so this will be the learning rate at the end
            of the warmup).
        decay_schedule_fn (`Callable`):
            The schedule function to apply after the warmup for the rest of training.
        warmup_steps (`int`):
            The number of steps for the warmup part of training.
        power (`float`, *optional*, defaults to 1):
            The power to use for the polynomial warmup (the default is a linear warmup).
        name (`str`, *optional*):
            Optional name prefix for the returned tensors during the schedule.
    """

    def __init__(
        self,
        initial_learning_rate: float,
        decay_schedule_fn: Callable,
        warmup_steps: int,
        power: float = 1.0,
        name: Optional[str] = None,
    ):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.power = power
        self.decay_schedule_fn = decay_schedule_fn
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or "WarmUp") as name:
            # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
            # learning rate will be `global_step/num_warmup_steps * init_lr`.
            global_step_float = tf.cast(step, tf.float32)
            warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
            warmup_percent_done = global_step_float / warmup_steps_float
            warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power)
            return tf.cond(
                global_step_float < warmup_steps_float,
                lambda: warmup_learning_rate,
                lambda: self.decay_schedule_fn(step - self.warmup_steps),
                name=name,
            )

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_schedule_fn": self.decay_schedule_fn,
            "warmup_steps": self.warmup_steps,
            "power": self.power,
            "name": self.name,
        }

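
# A minimal usage sketch (illustrative, not part of the original file): wrapping a
# decay schedule in `WarmUp`. With `initial_learning_rate=1e-4`, `warmup_steps=100`
# and the default `power=1.0`, step 50 yields (50 / 100) * 1e-4 = 5e-5; from step
# 100 onward the wrapped decay schedule takes over, shifted back by 100 steps.
#
#     decay_fn = tf.keras.optimizers.schedules.PolynomialDecay(
#         initial_learning_rate=1e-4, decay_steps=900, end_learning_rate=0.0
#     )
#     schedule = WarmUp(initial_learning_rate=1e-4, decay_schedule_fn=decay_fn, warmup_steps=100)
#     print(float(schedule(50)))  # ~5e-5, halfway through the linear warmup
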
def create_optimizer(
    init_lr: float,
    num_train_steps: int,
    num_warmup_steps: int,
    min_lr_ratio: float = 0.0,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_epsilon: float = 1e-8,
    adam_clipnorm: Optional[float] = None,
    adam_global_clipnorm: Optional[float] = None,
    weight_decay_rate: float = 0.0,
    power: float = 1.0,
    include_in_weight_decay: Optional[List[str]] = None,
):
    """
    Creates an optimizer with a learning rate schedule using a warmup phase followed by a linear decay.

    Args:
        init_lr (`float`):
            The desired learning rate at the end of the warmup phase.
        num_train_steps (`int`):
            The total number of training steps.
        num_warmup_steps (`int`):
            The number of warmup steps.
        min_lr_ratio (`float`, *optional*, defaults to 0):
            The final learning rate at the end of the linear decay will be `init_lr * min_lr_ratio`.
        adam_beta1 (`float`, *optional*, defaults to 0.9):
            The beta1 to use in Adam.
        adam_beta2 (`float`, *optional*, defaults to 0.999):
            The beta2 to use in Adam.
        adam_epsilon (`float`, *optional*, defaults to 1e-8):
            The epsilon to use in Adam.
        adam_clipnorm (`float`, *optional*, defaults to `None`):
            If not `None`, clip the gradient norm for each weight tensor to this value.
        adam_global_clipnorm (`float`, *optional*, defaults to `None`):
            If not `None`, clip gradient norm to this value. When using this argument, the norm is computed over all
            weight tensors, as if they were concatenated into a single vector.
        weight_decay_rate (`float`, *optional*, defaults to 0):
            The weight decay to use.
        power (`float`, *optional*, defaults to 1.0):
            The power to use for PolynomialDecay.
        include_in_weight_decay (`List[str]`, *optional*):
            List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
            applied to all parameters except bias and layer norm parameters.
    """
    # Implements linear decay of the learning rate.
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=init_lr,
        decay_steps=num_train_steps - num_warmup_steps,
        end_learning_rate=init_lr * min_lr_ratio,
        power=power,
    )
    if num_warmup_steps:
        lr_schedule = WarmUp(
            initial_learning_rate=init_lr,
            decay_schedule_fn=lr_schedule,
            warmup_steps=num_warmup_steps,
        )
    if weight_decay_rate > 0.0:
        optimizer = AdamWeightDecay(
            learning_rate=lr_schedule,
            weight_decay_rate=weight_decay_rate,
            beta_1=adam_beta1,
            beta_2=adam_beta2,
            epsilon=adam_epsilon,
            clipnorm=adam_clipnorm,
            global_clipnorm=adam_global_clipnorm,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
            include_in_weight_decay=include_in_weight_decay,
        )
    else:
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=lr_schedule,
            beta_1=adam_beta1,
            beta_2=adam_beta2,
            epsilon=adam_epsilon,
            clipnorm=adam_clipnorm,
            global_clipnorm=adam_global_clipnorm,
        )
    # We return the optimizer and the LR scheduler in order to better track the
    # evolution of the LR independently of the optimizer.
    return optimizer, lr_schedule

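
# A minimal usage sketch (illustrative, not part of the original file; the step
# counts and `model` are hypothetical stand-ins): with `weight_decay_rate > 0`
# the returned optimizer is the `AdamWeightDecay` subclass defined below.
#
#     optimizer, lr_schedule = create_optimizer(
#         init_lr=5e-5,
#         num_train_steps=10000,
#         num_warmup_steps=1000,
#         weight_decay_rate=0.01,
#     )
#     model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy")
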
class AdamWeightDecay(Adam):
    """
    A variant of Adam that enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of
    the weights to the loss function is *not* the correct way of using L2 regularization/weight decay with Adam,
    since that will interact with the m and v parameters in strange ways as shown in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101).

    Instead we want to decay the weights in a manner that doesn't interact with the m/v parameters. This is equivalent
    to adding the square of the weights to the loss with plain (non-momentum) SGD.

    Args:
        learning_rate (`Union[float, tf.keras.optimizers.schedules.LearningRateSchedule]`, *optional*, defaults to 1e-3):
            The learning rate to use or a schedule.
        beta_1 (`float`, *optional*, defaults to 0.9):
            The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates.
        beta_2 (`float`, *optional*, defaults to 0.999):
            The beta2 parameter in Adam, which is the exponential decay rate for the 2nd momentum estimates.
        epsilon (`float`, *optional*, defaults to 1e-7):
            The epsilon parameter in Adam, which is a small constant for numerical stability.
        amsgrad (`bool`, *optional*, defaults to `False`):
            Whether to apply the AMSGrad variant of this algorithm or not, see [On the Convergence of Adam and
            Beyond](https://arxiv.org/abs/1904.09237).
        weight_decay_rate (`float`, *optional*, defaults to 0):
            The weight decay to apply.
        include_in_weight_decay (`List[str]`, *optional*):
            List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
            applied to all parameters by default (unless they are in `exclude_from_weight_decay`).
        exclude_from_weight_decay (`List[str]`, *optional*):
            List of the parameter names (or re patterns) to exclude from applying weight decay to. If an
            `include_in_weight_decay` is passed, the names in it will supersede this list.
        name (`str`, *optional*, defaults to `"AdamWeightDecay"`):
            Optional name for the operations created when applying gradients.
        kwargs (*optional*):
            Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. `clipnorm` clips gradients by
            norm; `clipvalue` clips gradients by value; `decay` is included for backward compatibility to allow time
            inverse decay of the learning rate. `lr` is included for backward compatibility; `learning_rate` is
            recommended instead.
    """

    def __init__(
        self,
        learning_rate: Union[float, tf.keras.optimizers.schedules.LearningRateSchedule] = 0.001,
        beta_1: float = 0.9,
        beta_2: float = 0.999,
        epsilon: float = 1e-7,
        amsgrad: bool = False,
        weight_decay_rate: float = 0.0,
        include_in_weight_decay: Optional[List[str]] = None,
        exclude_from_weight_decay: Optional[List[str]] = None,
        name: str = "AdamWeightDecay",
        **kwargs,
    ):
        super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
        self.weight_decay_rate = weight_decay_rate
        self._include_in_weight_decay = include_in_weight_decay
        self._exclude_from_weight_decay = exclude_from_weight_decay

    @classmethod
    def from_config(cls, config):
        """Creates an optimizer from its config with WarmUp custom object."""
        custom_objects = {"WarmUp": WarmUp}
        return super(AdamWeightDecay, cls).from_config(config, custom_objects=custom_objects)

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant(
            self.weight_decay_rate, name="adam_weight_decay_rate"
        )

    def _decay_weights_op(self, var, learning_rate, apply_state):
        do_decay = self._do_use_weight_decay(var.name)
        if do_decay:
            return var.assign_sub(
                learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"],
                use_locking=self._use_locking,
            )
        return tf.no_op()

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        grads, tvars = list(zip(*grads_and_vars))
        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name, **kwargs)

    def _get_lr(self, var_device, var_dtype, apply_state):
        """Retrieves the learning rate with the given state."""
        if apply_state is None:
            return self._decayed_lr_t[var_dtype], {}

        apply_state = apply_state or {}
        coefficients = apply_state.get((var_device, var_dtype))
        if coefficients is None:
            coefficients = self._fallback_apply_state(var_device, var_dtype)
            apply_state[(var_device, var_dtype)] = coefficients

        return coefficients["lr_t"], {"apply_state": apply_state}

    def _resource_apply_dense(self, grad, var, apply_state=None):
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_dense(grad, var, **kwargs)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_sparse(grad, var, indices, **kwargs)

    def get_config(self):
        config = super().get_config()
        config.update({"weight_decay_rate": self.weight_decay_rate})
        return config

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if self.weight_decay_rate == 0:
            return False

        if self._include_in_weight_decay:
            for r in self._include_in_weight_decay:
                if re.search(r, param_name) is not None:
                    return True

        if self._exclude_from_weight_decay:
            for r in self._exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

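
# Sketch of how the include/exclude patterns resolve (illustrative; the variable
# names are hypothetical). `_do_use_weight_decay` checks `include_in_weight_decay`
# first, so an include pattern wins over an exclude pattern for the same variable:
#
#     opt = AdamWeightDecay(
#         learning_rate=1e-4,
#         weight_decay_rate=0.01,
#         include_in_weight_decay=[r"special_layer_norm"],
#         exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
#     )
#     opt._do_use_weight_decay("encoder/special_layer_norm/gamma:0")  # True (include wins)
#     opt._do_use_weight_decay("encoder/layer_norm/gamma:0")          # False (excluded)
#     opt._do_use_weight_decay("encoder/dense/kernel:0")              # True (no pattern matches)
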
# Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
class GradientAccumulator(object):
    """
    Gradient accumulation utility. When used with a distribution strategy, the accumulator should be called in a
    replica context. Gradients will be accumulated locally on each replica and without synchronization. Users should
    then call `.gradients`, scale the gradients if required, and pass the result to `apply_gradients`.
    """

    # We use the ON_READ synchronization policy so that no synchronization is
    # performed on assignment. To get the value, we call .value() which returns the
    # value on the current replica without synchronization.

    def __init__(self):
        """Initializes the accumulator."""
        self._gradients = []
        self._accum_steps = None

    @property
    def step(self):
        """Number of accumulated steps."""
        if self._accum_steps is None:
            self._accum_steps = tf.Variable(
                tf.constant(0, dtype=tf.int64),
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )

        return self._accum_steps.value()

    @property
    def gradients(self):
        """The accumulated gradients on the current replica."""
        if not self._gradients:
            raise ValueError("The accumulator should be called first to initialize the gradients")
        return [gradient.value() if gradient is not None else gradient for gradient in self._gradients]

    def __call__(self, gradients):
        """Accumulates `gradients` on the current replica."""
        if not self._gradients:
            _ = self.step  # Create the step variable.
            self._gradients.extend(
                [
                    tf.Variable(
                        tf.zeros_like(gradient),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ,
                        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                    )
                    if gradient is not None
                    else gradient
                    for gradient in gradients
                ]
            )
        if len(gradients) != len(self._gradients):
            raise ValueError(f"Expected {len(self._gradients)} gradients, but got {len(gradients)}")

        for accum_gradient, gradient in zip(self._gradients, gradients):
            if accum_gradient is not None and gradient is not None:
                accum_gradient.assign_add(gradient)

        self._accum_steps.assign_add(1)

    def reset(self):
        """Resets the accumulated gradients on the current replica."""
        if not self._gradients:
            return
        self._accum_steps.assign(0)
        for gradient in self._gradients:
            if gradient is not None:
                gradient.assign(tf.zeros_like(gradient))
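
# A minimal usage sketch (illustrative, not part of the original file; `model`,
# `loss_fn`, `optimizer`, `batches` and `accum_steps` are hypothetical stand-ins):
# accumulate gradients over `accum_steps` micro-batches, then apply one update.
#
#     accumulator = GradientAccumulator()
#     for i, (x, y) in enumerate(batches):
#         with tf.GradientTape() as tape:
#             loss = loss_fn(y, model(x, training=True))
#         accumulator(tape.gradient(loss, model.trainable_variables))
#         if (i + 1) % accum_steps == 0:
#             grads = [g / accum_steps if g is not None else g for g in accumulator.gradients]
#             optimizer.apply_gradients(zip(grads, model.trainable_variables))
#             accumulator.reset()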