"""Helpers for reporting process/GPU memory usage and freeing cached memory."""

import functools
import gc
import os

import psutil
import torch

# Divisor used to convert a byte count into mebibytes for display.
_BYTES_PER_MIB = 1024 ** 2


def print_memory_usage() -> None:
    """Print the current process RSS and the CUDA memory allocated by torch.

    Both values are reported in MB (bytes / 1024**2) with two decimal places.
    """
    process = psutil.Process(os.getpid())
    rss_mib = process.memory_info().rss / _BYTES_PER_MIB
    gpu_mib = torch.cuda.memory_allocated() / _BYTES_PER_MIB
    # NOTE: the format spec must be ".2f" (two decimals); "2f" would mean
    # "minimum width 2" and print the default six decimal places.
    print(f"Memory usage: {rss_mib:.2f} MB")
    print(f"GPU usage: {gpu_mib:.2f} MB")


def clear_cuda_and_gc() -> None:
    """Run garbage collection and empty the CUDA cache.

    Memory usage is printed before and after, so the effect of the cleanup
    is visible in the output.
    """
    print_memory_usage()
    print("Clearing cuda and gc")
    # Collect Python garbage first so that anything it releases can then be
    # handed back by the CUDA cache flush.
    clear_gc()
    clear_cuda()
    print_memory_usage()


def clear_cuda() -> None:
    """Release cached CUDA memory via ``torch.cuda.empty_cache()``."""
    # No ``torch.no_grad()`` context is needed here: emptying the allocator
    # cache performs no autograd-tracked operations.
    torch.cuda.empty_cache()


def clear_gc() -> None:
    """Run a full Python garbage collection pass."""
    gc.collect()


def auto_clear_cuda_and_gc(controlnet):
    """Build a decorator that cleans up memory when the wrapped call fails.

    Args:
        controlnet: Object exposing a ``cleanup()`` method; it is called
            before the memory is cleared whenever the wrapped function
            raises an ``Exception``.

    Returns:
        A decorator. The decorated function returns whatever the original
        returns; if the original raises an ``Exception``, ``controlnet``
        is cleaned up, ``clear_cuda_and_gc()`` runs, and the same exception
        is re-raised to the caller.
    """

    def auto_clear_cuda_and_gc_wrapper(func):
        # functools.wraps keeps the decorated function's __name__/__doc__.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                controlnet.cleanup()
                clear_cuda_and_gc()
                # Bare ``raise`` re-raises the active exception unchanged.
                raise

        return wrapper

    return auto_clear_cuda_and_gc_wrapper