|
import gc |
|
import os |
|
|
|
import psutil |
|
import torch |
|
|
|
|
|
def print_memory_usage():
    """Print the current process RSS and the CUDA memory allocated by torch.

    Both figures are reported in MiB (bytes / 1024**2), rounded to two
    decimal places.
    """
    process = psutil.Process(os.getpid())
    bytes_per_mb = 1024 ** 2
    # ":.2f" = two decimals. The previous ":2f" spec meant "minimum width 2"
    # and silently fell back to the default six decimal places.
    print(f"Memory usage: {process.memory_info().rss / bytes_per_mb:.2f} MB")
    print(f"GPU usage: {torch.cuda.memory_allocated() / bytes_per_mb:.2f} MB")
|
|
|
|
|
def clear_cuda_and_gc():
    """Run the garbage collector, then empty the CUDA cache.

    Memory usage is printed before and after so the effect of the cleanup
    is visible in the output.
    """
    print_memory_usage()
    print("Clearing cuda and gc")
    # Order matters: collect Python garbage first, then empty the CUDA cache.
    for release in (clear_gc, clear_cuda):
        release()
    print_memory_usage()
|
|
|
|
|
def clear_cuda():
    """Ask torch to empty its CUDA memory cache."""
    # The call is made with gradient tracking disabled, as in the original.
    grad_disabled = torch.no_grad()
    with grad_disabled:
        torch.cuda.empty_cache()
|
|
|
|
|
def clear_gc():
    """Trigger a full garbage-collection pass; the result is discarded."""
    _ = gc.collect()
|
|
|
|
|
def auto_clear_cuda_and_gc(controlnet):
    """Build a decorator that frees resources when the wrapped call fails.

    Args:
        controlnet: Object whose ``cleanup()`` method is invoked when the
            decorated function raises.

    Returns:
        A decorator. The decorated function returns whatever the original
        returns. If the original raises an ``Exception``, the wrapper calls
        ``controlnet.cleanup()``, then ``clear_cuda_and_gc()``, and re-raises
        the same exception.
    """
    # Local import keeps this block self-contained.
    import functools

    def auto_clear_cuda_and_gc_wrapper(func):
        # Preserve the wrapped function's name, docstring and other metadata.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                try:
                    controlnet.cleanup()
                finally:
                    # Clear memory even if cleanup() itself fails.
                    clear_cuda_and_gc()
                # Bare raise re-raises the original exception with its
                # traceback untouched.
                raise

        return wrapper

    return auto_clear_cuda_and_gc_wrapper
|
|