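# NOTE(review): the next three lines look like residue from the hosting web
# page rather than part of the module; kept verbatim as comments.
# radames's picture
# first
# bb0f5a9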
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Helper wrapper for a Tensorflow optimizer."""
import numpy as np
import tensorflow as tf
from collections import OrderedDict
from typing import List, Union
from . import autosummary
from . import tfutil
from .. import util
from .tfutil import TfExpression, TfExpressionEx
# Locate the NCCL collective ops, whose import path differs across TensorFlow
# releases. Only an import failure should trigger the fallback; any other
# exception (including KeyboardInterrupt/SystemExit) must propagate.
try:
    # TensorFlow 1.13
    from tensorflow.python.ops import nccl_ops
except ImportError:
    # Older TensorFlow versions
    import tensorflow.contrib.nccl as nccl_ops
class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Gradient accumulation for arbitrarily large minibatches.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.

    Call order enforced by the assertions in the methods: register_gradients()
    may only be called before apply_updates(), and apply_updates() may only be
    called once per instance.
    """

    def __init__(self,
                 # Name string that will appear in TensorFlow graph.
                 name: str = "Train",
                 # Fully qualified name of the underlying optimizer class
                 # (resolved with util.get_obj_by_name).
                 tf_optimizer: str = "tf.train.AdamOptimizer",
                 # Learning rate. Can vary over time.
                 learning_rate: TfExpressionEx = 0.001,
                 # Treat N consecutive minibatches as one by accumulating gradients.
                 minibatch_multiplier: TfExpressionEx = None,
                 # Share internal state with a previously created optimizer?
                 share: "Optimizer" = None,
                 # Enable dynamic loss scaling for robust mixed-precision training?
                 use_loss_scaling: bool = False,
                 # Log2 of initial loss scaling factor.
                 loss_scaling_init: float = 64.0,
                 # Log2 of per-minibatch loss scaling increment when there is no overflow.
                 loss_scaling_inc: float = 0.0005,
                 # Log2 of per-minibatch loss scaling decrement when there is an overflow.
                 loss_scaling_dec: float = 1.0,
                 # Report fine-grained memory usage statistics in TensorBoard?
                 report_mem_usage: bool = False,
                 **kwargs):
        """Record the configuration and resolve the underlying optimizer class.

        Per-device TensorFlow objects (the optimizer instance and the loss
        scaling variable) are not created here; _get_device() creates them the
        first time a device is used.

        Any extra keyword arguments are stored and later forwarded to the
        constructor of the underlying optimizer class.
        """
        # Public fields.
        self.name = name
        self.learning_rate = learning_rate
        self.minibatch_multiplier = minibatch_multiplier
        # Prefix for the name scopes and autosummary tags created by this class.
        self.id = self.name.replace("/", ".")
        # Graph-unique variant of the id; used as an absolute name scope.
        self.scope = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec

        # Private fields.
        # Set by apply_updates(); guards against late or repeated calls.
        self._updates_applied = False
        self._devices = OrderedDict() # device_name => EasyDict()
        self._shared_optimizers = OrderedDict() # device_name => optimizer_class
        # Variable shapes seen by the first register_gradients() call; every
        # later call is checked against them.
        self._gradient_shapes = None # [shape, ...]
        self._report_mem_usage = report_mem_usage

        # Validate arguments.
        assert callable(self.optimizer_class)

        # Share internal state if requested.
        if share is not None:
            assert isinstance(share, Optimizer)
            assert self.optimizer_class is share.optimizer_class
            assert self.learning_rate is share.learning_rate
            assert self.optimizer_kwargs == share.optimizer_kwargs
            # The dict object itself is shared (not copied), so optimizer
            # instances created later through either wrapper are seen by both.
            self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access

    def _get_device(self, device_name: str):
        """Get internal state for the given TensorFlow device.

        Returns the cached util.EasyDict for device_name. On first use for a
        device, the state is created: an optimizer instance is constructed (or
        taken from the shared optimizer table if one already exists for this
        device) and, when loss scaling is enabled, a non-trainable loss scaling
        variable is created.
        """
        tfutil.assert_tf_initialized()
        if device_name in self._devices:
            return self._devices[device_name]

        # Initialize fields.
        device = util.EasyDict()
        device.name = device_name
        device.optimizer = None # Underlying optimizer: optimizer_class
        device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable
        # Raw gradients: var => [grad, ...]
        device.grad_raw = OrderedDict()
        device.grad_clean = OrderedDict() # Clean gradients: var => grad
        # Accumulation sums: var => tf.Variable
        device.grad_acc_vars = OrderedDict()
        device.grad_acc_count = None # Accumulation counter: tf.Variable
        device.grad_acc = OrderedDict() # Accumulated gradients: var => grad

        # Setup TensorFlow objects.
        with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
            if device_name not in self._shared_optimizers:
                # The numeric suffix is the number of optimizers created so far
                # in the (possibly shared) table.
                optimizer_name = self.scope.replace(
                    "/", "_") + "_opt%d" % len(self._shared_optimizers)
                self._shared_optimizers[device_name] = self.optimizer_class(
                    name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
            device.optimizer = self._shared_optimizers[device_name]
            if self.use_loss_scaling:
                device.loss_scaling_var = tf.Variable(np.float32(
                    self.loss_scaling_init), trainable=False, name="loss_scaling_var")

        # Register device.
        self._devices[device_name] = device
        return device

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.

        Intended to be called once per GPU. Must be called before
        apply_updates().

        Args:
            loss: Loss expression. Its `.device` selects which per-device state
                the gradients are registered with.
            trainable_vars: Non-empty list of variables, or a dict whose values
                are the variables. Every variable must be placed on the same
                device as the loss, and the list must match (in length and
                per-position shape) the list given to the first call.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Validate trainables.
        if isinstance(trainable_vars, dict):
            # allow passing in Network.trainables as vars
            trainable_vars = list(trainable_vars.values())
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr)
                   for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # Validate shapes.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list()
                                     for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == var_shape for var,
                   var_shape in zip(trainable_vars, self._gradient_shapes))

        # Report memory usage if requested.
        deps = []
        if self._report_mem_usage:
            # Cleared immediately, so the memory summary is attempted at most
            # once per Optimizer instance, even if the attempt fails.
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    deps.append(autosummary.autosummary(
                        self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
            except tf.errors.NotFoundError:
                # Best effort: skip the statistic if TensorFlow reports
                # NotFoundError while building it.
                pass

        # Compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            # The loss is cast to float32 and multiplied by the loss scaling
            # factor (a no-op when loss scaling is disabled).
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage
            grad_list = device.optimizer.compute_gradients(
                loss=loss, var_list=trainable_vars, gate_gradients=gate)

        # Register gradients.
        # Entries may be None here; they are filtered out in apply_updates().
        for grad, var in grad_list:
            if var not in device.grad_raw:
                device.grad_raw[var] = []
            device.grad_raw[var].append(grad)

    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        May be called only once. Besides building graph ops, this method also
        runs initializers for the optimizer state, the loss scaling variables
        and the gradient accumulation variables through tfutil.

        Args:
            allow_no_op: If True and no gradients were registered on any
                device, return a plain no-op named 'TrainingOp' instead of
                building the update graph.

        Returns:
            A single grouped operation named "TrainingOp".
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():
                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape) # No gradients => zero.
                    elif len(grad) == 1:
                        # Single gradient => use as is.
                        grad = grad[0]
                    else:
                        # Multiple gradients => sum.
                        grad = tf.add_n(grad)

                    # Scale as needed.
                    # The divisor uses the unfiltered list, i.e. it counts every
                    # gradient registered for this variable (including None
                    # entries), times the number of devices.
                    scale = 1.0 / \
                        len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    # Folds the inverse loss scaling factor into the scale.
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                # Variables are matched across devices purely by their position
                # in each device's grad_clean dict.
                for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
                    # NCCL does not support zero-sized tensors.
                    if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()):
                        all_grads = [device.grad_clean[var] for device, var in zip(
                            self._devices.values(), all_vars)]
                        all_grads = nccl_ops.all_sum(all_grads)
                        for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                            device.grad_clean[var] = grad

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # The helper functions defined below close over loop variables;
                # each is handed to tf.cond within the same iteration.
                # pylint: disable=cell-var-from-loop

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    # No accumulation: the cleaned gradients are used directly.
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(
                                tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(
                            tf.zeros([]), trainable=False, name="grad_acc_count")

                    # Track counter.
                    # acc_ok is true once the counter (including the current
                    # step) reaches minibatch_multiplier; the counter is then
                    # reset to zero, otherwise it is incremented.
                    count_cur = device.grad_acc_count + 1.0
                    def count_inc_op(): return tf.assign(device.grad_acc_count, count_cur)
                    def count_reset_op(): return tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(
                        self.minibatch_multiplier, tf.float32))
                    all_ops.append(
                        tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    # grad_acc holds "accumulator + current gradient"; the
                    # accumulator variable is zeroed when acc_ok is true and
                    # overwritten with the new sum otherwise.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            def acc_inc_op(): return tf.assign(acc_var, acc_cur)
                            def acc_reset_op(): return tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(
                                tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                # all_ok requires acc_ok AND every accumulated gradient to be
                # finite, so it is also false while accumulation is unfinished.
                all_ok = tf.reduce_all(tf.stack(
                    [acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                # Gradients are cast back to each variable's dtype before they
                # are handed to the underlying optimizer.
                def apply_op(): return device.optimizer.apply_gradients(
                    [(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                # Only when acc_ok is true: raise the log2 factor by
                # loss_scaling_inc if all_ok, otherwise lower it by
                # loss_scaling_dec.
                if self.use_loss_scaling:
                    def ls_inc_op(): return tf.assign_add(
                        device.loss_scaling_var, self.loss_scaling_inc)
                    def ls_dec_op(): return tf.assign_sub(
                        device.loss_scaling_var, self.loss_scaling_dec)
                    def ls_update_op(): return tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                # The statistics use the last device's all_ok / acc_ok /
                # loss_scaling_var.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(autosummary.autosummary(
                        self.id + "/learning_rate", self.learning_rate))
                    # Value is 0 when all_ok is true and 1 otherwise.
                    # NOTE(review): `condition=acc_ok` presumably restricts the
                    # summary to steps where accumulation completed - confirm
                    # against autosummary.
                    all_ops.append(autosummary.autosummary(
                        self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(autosummary.autosummary(
                            self.id + "/loss_scaling_log2", device.loss_scaling_var))

        # Initialize variables.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars(
                [device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([var.initializer for device in self._devices.values() for var in list(
                device.grad_acc_vars.values()) + [device.grad_acc_count]])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer.

        Runs the initializer of every variable reported by `variables()` of
        the optimizer attached to each registered device.
        """
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for device in self._devices.values()
                    for var in device.optimizer.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor.

        Creates the per-device state for `device` if it does not exist yet.
        Returns None when loss scaling is disabled, because the variable is
        only created when use_loss_scaling is set.
        """
        return self._get_device(device).loss_scaling_var

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression.

        Returns `value` unchanged when loss scaling is disabled. Otherwise the
        value is multiplied by tfutil.exp2() of the loss scaling variable of
        the device that `value` is placed on.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        # NOTE(review): tfutil.exp2 presumably computes 2**x - confirm in tfutil.
        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression.

        Returns `value` unchanged when loss scaling is disabled. Otherwise the
        value is multiplied by tfutil.exp2() of the negated loss scaling
        variable of the device that `value` is placed on.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
class SimpleAdam:
    """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer.

    Only the methods that the wrapper calls are provided: variables(),
    compute_gradients() and apply_gradients(). State variables are created
    inside apply_gradients() and collected in `all_state_vars`.
    """

    def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        """Store the hyperparameters; no TensorFlow objects are created here."""
        self.name = name
        self.learning_rate = learning_rate
        self.beta1, self.beta2 = beta1, beta2
        self.epsilon = epsilon
        # Filled by apply_gradients(); returned by variables().
        self.all_state_vars = []

    def variables(self):
        """Return the list of state variables created so far."""
        return self.all_state_vars

    def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
        """Return [(gradient, variable), ...] for `loss` with respect to `var_list`.

        Only ungated gradient computation is supported.
        """
        assert gate_gradients == tf.train.Optimizer.GATE_NONE
        gradients = tf.gradients(loss, var_list)
        return [(grad, var) for grad, var in zip(gradients, var_list)]

    def apply_gradients(self, grads_and_vars):
        """Build and return one grouped op performing an Adam step for every (gradient, variable) pair.

        Each call creates fresh state variables (two beta-power scalars plus
        two moment variables per pair) and appends them to `all_state_vars`.
        """
        created_vars = []
        assign_ops = []
        with tf.name_scope(self.name):
            # Running powers of beta1/beta2, used to correct the startup bias
            # of the learning rate.
            with tf.control_dependencies(None):
                beta1_power = tf.Variable(
                    dtype=tf.float32, initial_value=1, trainable=False)
                beta2_power = tf.Variable(
                    dtype=tf.float32, initial_value=1, trainable=False)
            created_vars.extend((beta1_power, beta2_power))

            beta1_power_next = beta1_power * self.beta1
            beta2_power_next = beta2_power * self.beta2
            assign_ops.append(tf.assign(beta1_power, beta1_power_next))
            assign_ops.append(tf.assign(beta2_power, beta2_power_next))

            bias_numerator = tf.sqrt(1 - beta2_power_next)
            step_size = self.learning_rate * bias_numerator / (1 - beta1_power_next)

            # Per-variable moment estimates and parameter update.
            for grad, var in grads_and_vars:
                with tf.control_dependencies(None):
                    moment1 = tf.Variable(
                        dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                    moment2 = tf.Variable(
                        dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                created_vars.extend((moment1, moment2))

                moment1_next = self.beta1 * moment1 + (1 - self.beta1) * grad
                moment2_decayed = self.beta2 * moment2
                moment2_fresh = (1 - self.beta2) * tf.square(grad)
                moment2_next = moment2_decayed + moment2_fresh

                numerator = step_size * moment1_next
                denominator = tf.sqrt(moment2_next) + self.epsilon
                delta = numerator / denominator

                assign_ops.append(tf.assign(moment1, moment1_next))
                assign_ops.append(tf.assign(moment2, moment2_next))
                assign_ops.append(tf.assign_sub(var, delta))

            # Publish the new state variables and bundle all assignments.
            self.all_state_vars.extend(created_vars)
            return tf.group(*assign_ops)