File size: 655 Bytes
e303d64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
"""Benchmarking and measurement utilities"""

import pynvml
import torch


def gpu_memory_usage(device):
    """Return the amount of memory in use on a GPU, as reported by NVML.

    Args:
        device: The GPU to query. Accepts an integer device index, a
            ``torch.device`` of type ``cuda``, a string of the form
            ``"cuda"`` or ``"cuda:<index>"``, or ``None``. A CUDA device
            without an explicit index (and ``None``) resolves to
            ``torch.cuda.current_device()``.

    Returns:
        The ``used`` field of the NVML memory info divided by ``1024**3``,
        as a float.

    Raises:
        ValueError: If ``device`` is a ``torch.device`` or string that does
            not name a CUDA device, or if the index after ``"cuda:"`` is not
            an integer.
    """
    if isinstance(device, torch.device):
        if device.type != "cuda":
            raise ValueError(f"Expected a CUDA device, got {device!r}")
        # May be None for an index-less device such as torch.device("cuda").
        device = device.index
    elif isinstance(device, str):
        if device == "cuda":
            device = None
        elif device.startswith("cuda:"):
            device = int(device[5:])
        else:
            raise ValueError(f"Expected a CUDA device, got {device!r}")

    if device is None:
        # An index-less CUDA device means "whatever torch currently uses".
        device = torch.cuda.current_device()

    # NB torch.cuda.memory_usage returns zero so we use the lower level API.
    # NOTE(review): NVML enumerates physical devices; if CUDA_VISIBLE_DEVICES
    # remaps devices, this index may not match the torch index - confirm.
    # nvmlInit() is called on every invocation and never paired with a
    # shutdown here, matching the original behaviour.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return info.used / 1024.0**3


def log_gpu_memory_usage(log, msg, device):
    """Log the current GPU memory usage of ``device`` at INFO level.

    Args:
        log: Logger-like object; its ``info`` method is called with a
            ``stacklevel`` keyword argument.
        msg: Label inserted into the log line after "GPU memory usage".
        device: Forwarded unchanged to ``gpu_memory_usage``.
    """
    text = f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB"
    # stacklevel=2 asks the logger to attribute the record to our caller
    # instead of this helper.
    log.info(text, stacklevel=2)