Spaces:
Paused
Paused
# -*- coding: utf-8 -*- | |
# Copyright (c) Alibaba, Inc. and its affiliates. | |
import torch | |
from PIL.Image import Image | |
from collections import OrderedDict | |
from scepter.modules.utils.distribute import we | |
from scepter.modules.utils.config import Config | |
from scepter.modules.utils.logger import get_logger | |
from scepter.studio.utils.env import get_available_memory | |
from scepter.modules.model.registry import MODELS, BACKBONES, EMBEDDERS | |
from scepter.modules.utils.registry import Registry, build_from_config | |
def get_model(model_tuple):
    """Return the object stored under the 'model' key of a component dict.

    Args:
        model_tuple: mapping holding a 'model' entry (e.g. the dicts that
            BaseInference.infer_model creates).

    Returns:
        The value of ``model_tuple['model']`` (may be None if not loaded).

    Raises:
        KeyError: if ``model_tuple`` has no 'model' entry. An explicit raise
            is used instead of ``assert`` so the check survives ``python -O``.
    """
    if 'model' not in model_tuple:
        raise KeyError("component dict has no 'model' entry")
    return model_tuple['model']
class BaseInference():
    '''
    Base class for inference pipelines whose components are loaded lazily.

    Each component is tracked as a plain dict ("module") with the keys
    'model', 'cfg', 'device', 'name', 'function_info' and 'paras' (see
    infer_model). A module starts as 'offline' (not built). load() builds
    it, marks it 'cpu' and then moves it to we.device_id. unload() either
    keeps it on 'cpu' or drops it back to 'offline', depending on what
    get_available_memory() reports.

    NOTE(review): dynamic_load/dynamic_unload read self.loaded_model and
    self.loaded_model_name, which this class never initializes; presumably
    subclasses set them. Confirm this.
    '''

    def __init__(self, cfg, logger=None):
        """Store the logger (default: the 'scepter' logger) and cfg.NAME."""
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.name = cfg.NAME

    def init_from_modules(self, modules):
        """Set every (name, value) pair in ``modules`` as an attribute on self."""
        for attr_name, value in modules.items():
            setattr(self, attr_name, value)

    def infer_model(self, cfg, module_paras=None):
        """Create the bookkeeping dict for a lazily-loaded component.

        Args:
            cfg: component config. ``cfg.NAME`` is recorded as the name.
            module_paras: optional config with a 'PARAS' mapping and a
                'FUNCTION' list. Each FUNCTION entry has NAME, plus optional
                DTYPE (default 'float32') and INPUT (list of input names).

        Returns:
            dict with 'model' (None until load()), 'cfg', 'device'
            ('offline'), 'name', 'function_info' and 'paras'.

        NOTE: when FUNCTION entries declare INPUTs, this reads self.input,
        which load_default() sets. Call that first.
        """
        module = {
            'model': None,
            'cfg': cfg,
            'device': 'offline',
            'name': cfg.NAME,
            'function_info': {},
            'paras': {}
        }
        if module_paras is None:
            return module
        function_info = {}
        # Parameter names are normalised to lower case.
        paras = {
            k.lower(): v
            for k, v in module_paras.get('PARAS', {}).items()
        }
        for function in module_paras.get('FUNCTION', []):
            # Keep only declared inputs that appear in self.input, with their
            # current values.
            input_dict = {}
            for inp in function.get('INPUT', []):
                if inp.lower() in self.input:
                    input_dict[inp.lower()] = self.input[inp.lower()]
            function_info[function.NAME] = {
                'dtype': function.get('DTYPE', 'float32'),
                'input': input_dict
            }
        module['paras'] = paras
        module['function_info'] = function_info
        return module

    def init_from_ckpt(self, path, model, ignore_keys=None):
        """Load a checkpoint into ``model`` non-strictly.

        Args:
            path: checkpoint file. Paths ending in 'safetensors' are read
                with safetensors. Anything else is read with torch.load
                (mapped to CPU, weights_only=True).
            model: module that receives the filtered state dict.
            ignore_keys: iterable of substrings. Any state-dict key that
                contains one of them is skipped. Defaults to no filtering.
                A None default replaces the shared mutable ``list()`` default.
        """
        if ignore_keys is None:
            ignore_keys = ()
        if path.endswith('safetensors'):
            from safetensors.torch import load_file as load_safetensors
            sd = load_safetensors(path)
        else:
            sd = torch.load(path, map_location='cpu', weights_only=True)
        new_sd = OrderedDict()
        for k, v in sd.items():
            for ik in ignore_keys:
                if ik in k:
                    if we.rank == 0:
                        self.logger.info(
                            'Ignore key {} from state_dict.'.format(k))
                    break
            else:
                # No ignore pattern matched: keep the entry.
                new_sd[k] = v
        missing, unexpected = model.load_state_dict(new_sd, strict=False)
        if we.rank == 0:
            self.logger.info(
                f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
            )
            if len(missing) > 0:
                self.logger.info(f'Missing Keys:\n {missing}')
            if len(unexpected) > 0:
                self.logger.info(f'\nUnexpected Keys:\n {unexpected}')

    def load(self, module):
        """Make sure ``module['model']`` is built and placed on we.device_id.

        An 'offline' module is built from module['cfg'] by the first of
        MODELS / BACKBONES / EMBEDDERS that knows cfg.NAME. If none does,
        NotImplementedError is raised. The built model is put in eval mode,
        cast to cfg.DTYPE when set, and reloaded from cfg.RELOAD_MODEL when
        set. Any 'cpu' module, including one just built, is then moved to
        we.device_id. Other device values are left unchanged. The dict is
        updated in place and returned.
        """
        if module['device'] == 'offline':
            from scepter.modules.utils.import_utils import LazyImportModule
            cfg = module['cfg']
            candidates = (('MODELS', MODELS), ('BACKBONES', BACKBONES),
                          ('EMBEDDERS', EMBEDDERS))
            for registry_name, registry in candidates:
                if (LazyImportModule.get_module_type((registry_name, cfg.NAME))
                        or cfg.NAME in registry.class_map):
                    model = registry.build(cfg, logger=self.logger).eval()
                    break
            else:
                raise NotImplementedError
            if 'DTYPE' in cfg and cfg['DTYPE'] is not None:
                model = model.to(getattr(torch, cfg.DTYPE))
            if cfg.get('RELOAD_MODEL', None):
                self.init_from_ckpt(cfg.RELOAD_MODEL, model)
            module['model'] = model
            module['device'] = 'cpu'
        if module['device'] == 'cpu':
            module['device'] = we.device_id
            module['model'] = module['model'].to(we.device_id)
        return module

    def unload(self, module):
        """Release a component and return the updated module dict.

        When get_available_memory() reports less than half of the total as
        available, the model is dropped and the module is marked 'offline',
        so the next load() rebuilds it. Otherwise the model is kept on CPU
        ('cpu'), or marked 'offline' if it was never built. The CUDA cache
        is cleared afterwards when CUDA is available. ``None`` is returned
        unchanged.
        """
        if module is None:
            return module
        mem = get_available_memory()
        free_mem = int(mem['available'] / (1024**2))
        total_mem = int(mem['total'] / (1024**2))
        if free_mem < 0.5 * total_mem:
            if module['model'] is not None:
                # Move weights off the accelerator before dropping the
                # reference, in case the module object is shared elsewhere.
                module['model'].to('cpu')
            module['model'] = None
            module['device'] = 'offline'
            self.logger.info('delete module')
        elif module['model'] is not None:
            module['model'] = module['model'].to('cpu')
            module['device'] = 'cpu'
        else:
            module['device'] = 'offline'
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return module

    def dynamic_load(self, module=None, name=''):
        """Load ``module``, tracking it in self.loaded_model by ``name``.

        name == 'all': (re)loads every name in self.loaded_model_name from
            the same-named attribute on self, and returns None.
        tracked name already in self.loaded_model: if its cfg changed, the
            old module is unloaded and the new one is loaded and recorded.
            Otherwise it is loaded only if it is on 'cpu' or 'offline'.
        tracked name not yet loaded: loads and records it.
        untracked name: just loads it.
        """
        self.logger.info('Loading {} model'.format(name))
        if name == 'all':
            for subname in self.loaded_model_name:
                self.loaded_model[subname] = self.dynamic_load(
                    getattr(self, subname), subname)
            return None
        if name not in self.loaded_model_name:
            return self.load(module)
        if name not in self.loaded_model:
            module = self.load(module)
            self.loaded_model[name] = module
            return module
        if module['cfg'] != self.loaded_model[name]['cfg']:
            self.unload(self.loaded_model[name])
            module = self.load(module)
            self.loaded_model[name] = module
            return module
        if module['device'] == 'cpu' or module['device'] == 'offline':
            return self.load(module)
        return module

    def dynamic_unload(self, module=None, name='', skip_loaded=False):
        """Unload ``module`` (or every tracked module when name == 'all').

        For a tracked name present in self.loaded_model, the recorded module
        is unloaded and replaced, unless ``skip_loaded`` is True, in which
        case nothing happens. Any other case unloads ``module`` directly.
        """
        self.logger.info('Unloading {} model'.format(name))
        if name == 'all':
            # Distinct loop names so the parameters are not shadowed.
            for key in list(self.loaded_model):
                self.loaded_model[key] = self.unload(self.loaded_model[key])
        elif name in self.loaded_model_name and name in self.loaded_model:
            if not skip_loaded:
                self.loaded_model[name] = self.unload(self.loaded_model[name])
        else:
            self.unload(module)

    def load_default(self, cfg):
        """Read input/output defaults from ``cfg`` and return MODULES_PARAS.

        When ``cfg`` is not None, this sets:
            self.paras: cfg.PARAS.
            self.input_cfg: cfg.INPUT with lower-cased keys.
            self.input: for each input, the 'DEFAULT' entry of a dict,
                OrderedDict or Config value, otherwise the value itself.
            self.output: cfg.OUTPUT with lower-cased keys.
        Returns an empty dict when ``cfg`` is None.
        """
        module_paras = {}
        if cfg is not None:
            self.paras = cfg.PARAS
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {
                k.lower(): dict(v).get('DEFAULT', None)
                if isinstance(v, (dict, OrderedDict, Config)) else v
                for k, v in cfg.INPUT.items()
            }
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
            module_paras = cfg.MODULES_PARAS
        return module_paras

    def load_image(self, image, num_samples=1):
        """Image-preprocessing hook. It is currently a no-op that returns None.

        The previous tensor/PIL type dispatch had only empty branches, one of
        them duplicated and unreachable, so it is omitted here.
        NOTE(review): presumably subclasses override this. Confirm.
        """
        return None

    def get_function_info(self, module, function_name=None):
        """Return (name, dtype) for a function in ``module['function_info']``.

        If ``function_name`` is registered, its entry is used. If it is None
        and exactly one function is registered, that one is used. In any
        other case None is returned.
        """
        all_function = module['function_info']
        if function_name in all_function:
            return function_name, all_function[function_name]['dtype']
        if function_name is None and len(all_function) == 1:
            (only_name, info), = all_function.items()
            return only_name, info['dtype']
        return None

    def __call__(self,
                 input,
                 **kwargs):
        """Run inference. The base implementation does nothing and returns None."""
        return None
def build_inference(cfg, registry, logger=None, *args, **kwargs):
    """Build an inference object described by ``cfg`` from ``registry``.

    This is a thin wrapper around ``build_from_config`` that validates the
    config type first. It does not load any checkpoint itself.

    Args:
        cfg (Config): config describing the object to build.
        registry (Registry): registry used to look up the object.
        logger: optional logger forwarded to ``build_from_config``.
        *args, **kwargs: forwarded unchanged to ``build_from_config``.

    Returns:
        Whatever ``build_from_config`` returns for ``cfg``.

    Raises:
        TypeError: if ``cfg`` is not a ``Config`` instance.
    """
    if not isinstance(cfg, Config):
        # The message now names the type that is actually required.
        raise TypeError(f'Config must be type Config, got {type(cfg)}')
    # Positional extras are listed before the keyword. This is equivalent to
    # the old call order but reads unambiguously.
    return build_from_config(cfg, registry, *args, logger=logger, **kwargs)
# Registry for inference classes. build_inference is supplied as the
# registry's build function.
INFERENCES = Registry(
    'INFERENCE',
    build_func=build_inference,
)