"""Benchmarking and measurement utilities"""

import pynvml
import torch
|
|
def gpu_memory_usage(device):
    """Return the total memory in use on ``device``, in GiB, as reported by NVML."""
    # Normalize the device argument to a bare integer index.
    if isinstance(device, torch.device):
        device = device.index
    if isinstance(device, str) and device.startswith("cuda:"):
        device = int(device[5:])
    if device is None:
        # torch.device("cuda") carries no index; fall back to the current device.
        device = torch.cuda.current_device()

    # NVML reports device-wide usage, i.e. memory held by all processes on the GPU,
    # not just the current one. Calling nvmlInit() repeatedly is safe.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return info.used / 1024.0**3
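

# Illustrative companion, not part of the original utilities: the NVML figure above is
# device-wide, while torch's caching allocator tracks only the current process. A
# hypothetical helper contrasting the two views might look like this sketch.
def torch_memory_usage(device=None):
    """Return (allocated, reserved) memory for this process on ``device``, in GiB."""
    allocated = torch.cuda.memory_allocated(device) / 1024.0**3
    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
    return allocated, reserved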
|
|
|
|
|
def log_gpu_memory_usage(log, msg, device):
    """Log the current GPU memory usage of ``device`` via ``log`` at INFO level."""
    # stacklevel=2 attributes the log record to the caller rather than to this helper.
    log.info(
        f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
    )
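

# Example usage (illustrative): assumes a CUDA-capable machine and a standard
# ``logging`` logger; the names below are placeholders.
#
#     import logging
#
#     logging.basicConfig(level=logging.INFO)
#     log = logging.getLogger(__name__)
#     log_gpu_memory_usage(log, "after model load", "cuda:0")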
|
|