import collections |
|
import datetime |
|
import os |
|
import random |
|
import subprocess |
|
import time |
|
from multiprocessing import JoinableQueue, Process |
|
|
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
import yaml |
|
from mmcv import Config |
|
from mmcv.runner import get_dist_info |
|
|
|
from diffusion.utils.dist_utils import get_rank |
|
from diffusion.utils.logger import get_root_logger |
|
|
|
os.environ["MOX_SILENT_MODE"] = "1" |
|
|
|
|
|
class SafeLoaderWithTuple(yaml.SafeLoader): |
|
"""A yaml safe loader with python tuple loading capabilities.""" |
|
|
|
def construct_python_tuple(self, node): |
|
return tuple(self.construct_sequence(node)) |
|
|
|
|
|
SafeLoaderWithTuple.add_constructor("tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple) |
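# A minimal sketch of the loader in action: with the constructor registered above, YAML
# documents using the `!!python/tuple` tag parse into Python tuples while the loader stays
# "safe". The document below is illustrative.
#
#   cfg = yaml.load("size: !!python/tuple [256, 256]", Loader=SafeLoaderWithTuple)
#   assert cfg["size"] == (256, 256)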
|
|
|
|
|
def read_yaml_config(file, base_dir=None):
    """Read a YAML config file, recursively merging any files listed under `_base_`."""
|
|
|
with open(file, encoding="utf-8") as f: |
|
yaml_config = yaml.load(f, Loader=SafeLoaderWithTuple) |
|
|
|
if "_base_" in yaml_config and yaml_config["_base_"]: |
|
if base_dir is None: |
|
base_dir = os.path.dirname(file) |
|
|
|
base_files = yaml_config["_base_"] |
|
if isinstance(base_files, str): |
|
base_files = [base_files] |
|
|
|
base_config = None |
|
for base_file in base_files: |
|
base_file = os.path.join(base_dir, base_file) |
|
if base_config is None: |
|
base_config = read_config(base_file, base_dir) |
|
else: |
|
curr_config = read_config(base_file, base_dir) |
|
base_config.merge_from_dict(curr_config) |
|
|
|
base_config.merge_from_dict(yaml_config) |
|
return base_config |
|
else: |
|
return Config(yaml_config) |
|
|
|
|
|
def read_config(file, base_dir=None):
    """Read a config from a YAML file (with `_base_` support) or an mmcv `Config` python file.

    Retries while the parsed config comes back empty, e.g. if the file is still being written.
    """
|
while True: |
|
if file.endswith(".yaml") or file.endswith(".yml"): |
|
config = read_yaml_config(file, base_dir) |
|
else: |
|
config = Config.fromfile(file) |
|
|
|
        if len(config) == 0:
            # the file may still be being written; wait briefly and retry
            time.sleep(0.1)
            continue
|
break |
|
return config |
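# A minimal sketch of `_base_` inheritance, with hypothetical file names:
#
#   # base.yaml
#   lr: 1.0e-4
#   batch_size: 32
#
#   # exp.yaml
#   _base_: base.yaml
#   batch_size: 64
#
#   cfg = read_config("exp.yaml")
#   # cfg.lr == 1e-4 (inherited), cfg.batch_size == 64 (overridden)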
|
|
|
|
|
def init_random_seed(seed=None, device="cuda"): |
|
"""Initialize random seed. |
|
|
|
If the seed is not set, the seed will be automatically randomized, |
|
and then broadcast to all processes to prevent some potential bugs. |
|
|
|
Args: |
|
seed (int, Optional): The seed. Default to None. |
|
device (str): The device where the seed will be put on. |
|
Default to 'cuda'. |
|
|
|
Returns: |
|
int: Seed to be used. |
|
""" |
|
if seed is not None: |
|
return seed |
|
|
|
|
|
|
|
|
|
    # seed not set: draw one on rank 0 and broadcast it so every process agrees
    rank, world_size = get_dist_info()
|
seed = np.random.randint(2**31) |
|
if world_size == 1: |
|
return seed |
|
|
|
if rank == 0: |
|
random_num = torch.tensor(seed, dtype=torch.int32, device=device) |
|
else: |
|
random_num = torch.tensor(0, dtype=torch.int32, device=device) |
|
dist.broadcast(random_num, src=0) |
|
return random_num.item() |
|
|
|
|
|
def set_random_seed(seed, deterministic=False): |
|
"""Set random seed. |
|
|
|
Args: |
|
seed (int): Seed to be used. |
|
deterministic (bool): Whether to set the deterministic option for |
|
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` |
|
to True and `torch.backends.cudnn.benchmark` to False. |
|
Default: False. |
|
""" |
|
random.seed(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
if deterministic: |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = False |
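# A minimal sketch of how the two seeding helpers above are typically combined at startup
# (`args` is a hypothetical argparse namespace):
#
#   seed = init_random_seed(args.seed)
#   set_random_seed(seed, deterministic=args.deterministic)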
|
|
|
|
|
class SimpleTimer: |
|
def __init__(self, num_tasks, log_interval=1, desc="Process"): |
|
self.num_tasks = num_tasks |
|
self.desc = desc |
|
self.count = 0 |
|
self.log_interval = log_interval |
|
self.start_time = time.time() |
|
self.logger = get_root_logger() |
|
|
|
def log(self, n=1): |
|
self.count += n |
|
if (self.count % self.log_interval) == 0 or self.count == self.num_tasks: |
|
time_elapsed = time.time() - self.start_time |
|
avg_time = time_elapsed / self.count |
|
eta_sec = avg_time * (self.num_tasks - self.count) |
|
eta_str = str(datetime.timedelta(seconds=int(eta_sec))) |
|
elapsed_str = str(datetime.timedelta(seconds=int(time_elapsed))) |
|
log_info = ( |
|
f"{self.desc} [{self.count}/{self.num_tasks}], elapsed_time:{elapsed_str}," |
|
f" avg_time: {avg_time}, eta: {eta_str}." |
|
) |
|
self.logger.info(log_info) |
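# A minimal sketch of SimpleTimer in a processing loop (`items` and `process` are
# hypothetical):
#
#   timer = SimpleTimer(len(items), log_interval=10, desc="Extract features")
#   for item in items:
#       process(item)
#       timer.log()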
|
|
|
|
|
class DebugUnderflowOverflow: |
|
""" |
|
This debug class helps detect and understand where the model starts getting very large or very small, and more |
|
importantly `nan` or `inf` weight and activation elements. |
|
There are 2 working modes: |
|
1. Underflow/overflow detection (default) |
|
2. Specific batch absolute min/max tracing without detection |
|
Mode 1: Underflow/overflow detection |
|
    To activate the underflow/overflow detection, initialize the object with the model:
|
```python |
|
debug_overflow = DebugUnderflowOverflow(model) |
|
``` |
|
    then run the training as normal. If `nan` or `inf` gets detected in at least one of the weight, input or output
    elements, this module will throw an exception and will print the last `max_frames_to_save` frames that led to
    this event, each frame reporting
|
1. the fully qualified module name plus the class name whose `forward` was run |
|
2. the absolute min and max value of all elements for each module weights, and the inputs and output |
|
    For example, here is the header and the last few frames of the detection report for `google/mt5-small` run in
    fp16 mixed precision:
|
``` |
|
Detected inf/nan during batch_number=0 |
|
Last 21 forward frames: |
|
abs min abs max metadata |
|
[...] |
|
encoder.block.2.layer.1.DenseReluDense.wi_0 Linear |
|
2.17e-07 4.50e+00 weight |
|
1.79e-06 4.65e+00 input[0] |
|
2.68e-06 3.70e+01 output |
|
encoder.block.2.layer.1.DenseReluDense.wi_1 Linear |
|
8.08e-07 2.66e+01 weight |
|
1.79e-06 4.65e+00 input[0] |
|
1.27e-04 2.37e+02 output |
|
encoder.block.2.layer.1.DenseReluDense.wo Linear |
|
1.01e-06 6.44e+00 weight |
|
0.00e+00 9.74e+03 input[0] |
|
3.18e-04 6.27e+04 output |
|
encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense |
|
1.79e-06 4.65e+00 input[0] |
|
3.18e-04 6.27e+04 output |
|
encoder.block.2.layer.1.dropout Dropout |
|
3.18e-04 6.27e+04 input[0] |
|
0.00e+00 inf output |
|
``` |
|
    You can see here that `T5DenseGatedGeluDense.forward` produced output activations whose absolute max value was
    around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout`, which
    renormalizes the weights after zeroing some of the elements, pushing the absolute max value above 64K and
    causing an overflow.
    As you can see, it's the previous frames that we need to look into when the numbers start getting very large
    for fp16.
|
The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed. |
|
    By default the last 21 frames are printed. You can change the default to adjust for your needs. For example:
|
```python |
|
debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100) |
|
``` |
|
    If you intend to use this debugging feature in a training run that may take hours to complete, first validate
    that you have set it up correctly by running it with normal tracing enabled for one or a few batches, as
    explained in the next section.
|
    Mode 2: Specific batch absolute min/max tracing without detection
    The second working mode is per-batch tracing with the underflow/overflow detection feature turned off.
|
Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a |
|
    given batch, and only do that for batches 1 and 3. Then you instantiate this class as:
|
```python |
|
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3]) |
|
``` |
|
And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed. |
|
This is helpful if you know that the program starts misbehaving after a certain batch number, so you can |
|
fast-forward right to that area. |
|
Early stopping: |
|
    You can also specify the batch number after which to stop the training, with:
|
```python |
|
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3], abort_after_batch_num=3) |
|
``` |
|
This feature is mainly useful in the tracing mode, but you can use it for any mode. |
|
**Performance**: |
|
    As this module measures the absolute `min`/`max` of each weight of the model on every forward pass, it'll slow
    the training down. Therefore remember to turn it off once the debugging needs have been met.
|
Args: |
|
model (`nn.Module`): |
|
The model to debug. |
|
max_frames_to_save (`int`, *optional*, defaults to 21): |
|
How many frames back to record |
|
        trace_batch_nums (`List[int]`, *optional*, defaults to `[]`):
|
Which batch numbers to trace (turns detection off) |
|
        abort_after_batch_num (`int`, *optional*):
            The batch number after which to abort the training
|
""" |
|
|
|
    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None):
        self.model = model
        # avoid a shared mutable default argument
        self.trace_batch_nums = trace_batch_nums if trace_batch_nums is not None else []
        self.abort_after_batch_num = abort_after_batch_num

        # keep only the last `max_frames_to_save` frames
        self.frames = collections.deque([], max_frames_to_save)
|
self.frame = [] |
|
self.batch_number = 0 |
|
self.total_calls = 0 |
|
self.detected_overflow = False |
|
self.prefix = " " |
|
|
|
self.analyse_model() |
|
|
|
self.register_forward_hook() |
|
|
|
def save_frame(self, frame=None): |
|
if frame is not None: |
|
self.expand_frame(frame) |
|
self.frames.append("\n".join(self.frame)) |
|
self.frame = [] |
|
|
|
def expand_frame(self, line): |
|
self.frame.append(line) |
|
|
|
    def trace_frames(self):
        print("\n".join(self.frames))
        # clear in place so that `self.frames` stays a bounded deque
        self.frames.clear()
|
|
|
    def reset_saved_frames(self):
        self.frames.clear()
|
|
|
def dump_saved_frames(self): |
|
print(f"\nDetected inf/nan during batch_number={self.batch_number}") |
|
print(f"Last {len(self.frames)} forward frames:") |
|
print(f"{'abs min':8} {'abs max':8} metadata") |
|
print("\n".join(self.frames)) |
|
print("\n\n") |
|
        self.frames.clear()
|
|
|
    def analyse_model(self):
        # extract the fully qualified module names to be able to report them at run time,
        # e.g. `encoder.block.2.layer.0.SelfAttention.o`; for shared weights only the
        # first shared module name is registered
        self.module_names = {m: name for name, m in self.model.named_modules()}
|
|
|
|
|
def analyse_variable(self, var, ctx): |
|
if torch.is_tensor(var): |
|
self.expand_frame(self.get_abs_min_max(var, ctx)) |
|
if self.detect_overflow(var, ctx): |
|
self.detected_overflow = True |
|
elif var is None: |
|
self.expand_frame(f"{'None':>17} {ctx}") |
|
else: |
|
self.expand_frame(f"{'not a tensor':>17} {ctx}") |
|
|
|
def batch_start_frame(self): |
|
self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***") |
|
self.expand_frame(f"{'abs min':8} {'abs max':8} metadata") |
|
|
|
def batch_end_frame(self): |
|
self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number - 1} ***\n\n") |
|
|
|
def create_frame(self, module, input, output): |
|
self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}") |
|
|
|
|
|
        # parameters of this module only (children report their own)
        for name, p in module.named_parameters(recurse=False):
|
self.analyse_variable(p, name) |
|
|
|
|
|
        # inputs arrive as a tuple of packed positional args (possibly non-tensors)
        if isinstance(input, tuple):
|
for i, x in enumerate(input): |
|
self.analyse_variable(x, f"input[{i}]") |
|
else: |
|
self.analyse_variable(input, "input") |
|
|
|
|
|
        # output can be a tensor or an arbitrarily nested tuple
        if isinstance(output, tuple):
|
for i, x in enumerate(output): |
|
|
|
if isinstance(x, tuple): |
|
for j, y in enumerate(x): |
|
self.analyse_variable(y, f"output[{i}][{j}]") |
|
else: |
|
self.analyse_variable(x, f"output[{i}]") |
|
else: |
|
self.analyse_variable(output, "output") |
|
|
|
self.save_frame() |
|
|
|
def register_forward_hook(self): |
|
self.model.apply(self._register_forward_hook) |
|
|
|
def _register_forward_hook(self, module): |
|
module.register_forward_hook(self.forward_hook) |
|
|
|
    def forward_hook(self, module, input, output):
        # - input is a tuple of packed inputs (could contain non-tensors)
        # - output could be a tensor or a tuple of tensors and non-tensors
|
|
|
|
|
|
|
last_frame_of_batch = False |
|
|
|
        trace_mode = self.batch_number in self.trace_batch_nums
|
if trace_mode: |
|
self.reset_saved_frames() |
|
|
|
if self.total_calls == 0: |
|
self.batch_start_frame() |
|
self.total_calls += 1 |
|
|
|
|
|
|
|
        # the hook on the top-level model fires last within a batch, so once it runs we
        # know this batch has finished
        if module == self.model:
|
self.batch_number += 1 |
|
last_frame_of_batch = True |
|
|
|
self.create_frame(module, input, output) |
|
|
|
|
|
|
|
|
|
if trace_mode: |
|
self.trace_frames() |
|
|
|
if last_frame_of_batch: |
|
self.batch_start_frame() |
|
|
|
if self.detected_overflow and not trace_mode: |
|
self.dump_saved_frames() |
|
|
|
|
|
            # abort, as it's pointless to continue running
            raise ValueError(
|
"DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. " |
|
"Please scroll up above this traceback to see the activation values prior to this event." |
|
) |
|
|
|
|
|
        # abort after a certain batch number if requested
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
|
raise ValueError( |
|
f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to" |
|
f" `abort_after_batch_num={self.abort_after_batch_num}` arg" |
|
) |
|
|
|
@staticmethod |
|
def get_abs_min_max(var, ctx): |
|
abs_var = var.abs() |
|
return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}" |
|
|
|
@staticmethod |
|
def detect_overflow(var, ctx): |
|
""" |
|
Report whether the tensor contains any `nan` or `inf` entries. |
|
        This is useful for detecting overflows/underflows; it is best called right after the operation that modified
        the tensor in question.
|
This function contains a few other helper features that you can enable and tweak directly if you want to track |
|
various other things. |
|
Args: |
|
var: the tensor variable to check |
|
ctx: the message to print as a context |
|
Return: |
|
`True` if `inf` or `nan` was detected, `False` otherwise |
|
""" |
|
detected = False |
|
if torch.isnan(var).any().item(): |
|
detected = True |
|
print(f"{ctx} has nans") |
|
if torch.isinf(var).any().item(): |
|
detected = True |
|
print(f"{ctx} has infs") |
|
        # values this large in an fp32 tensor would overflow when cast to fp16 (max ~64K),
        # which matters when parts of the model run in mixed precision
        if var.dtype == torch.float32 and torch.ge(var.abs(), 65535).any().item():
|
detected = True |
|
print(f"{ctx} has overflow values {var.abs().max().item()}.") |
|
|
|
        # flip to `if 1` to count elements above various magnitude thresholds
        if 0:
|
n100 = var[torch.ge(var.abs(), 100)] |
|
if n100.numel() > 0: |
|
print(f"{ctx}: n100={n100.numel()}") |
|
n1000 = var[torch.ge(var.abs(), 1000)] |
|
if n1000.numel() > 0: |
|
print(f"{ctx}: n1000={n1000.numel()}") |
|
n10000 = var[torch.ge(var.abs(), 10000)] |
|
if n10000.numel() > 0: |
|
print(f"{ctx}: n10000={n10000.numel()}") |
|
|
|
        # flip to `if 1` to print the raw min/max of the tensor
        if 0:
            print(f"min={var.min():9.2e} max={var.max():9.2e}")

        # flip to `if 1` to also print variance and mean
        if 0:
            print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")
|
|
|
return detected |
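# A minimal sketch of standalone use: since `detect_overflow` is a staticmethod, it can be
# called right after a suspect operation (`hidden_states` is hypothetical):
#
#   DebugUnderflowOverflow.detect_overflow(hidden_states, "after attention block")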
|
|