# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import OrderedDict

import torch
from PIL.Image import Image

from scepter.modules.model.registry import BACKBONES, EMBEDDERS, MODELS
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from scepter.modules.utils.logger import get_logger
from scepter.modules.utils.registry import Registry, build_from_config
from scepter.studio.utils.env import get_available_memory


def get_model(model_tuple):
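    """Return the built model from a module record created by
    BaseInference.infer_model (a dict with 'model', 'cfg', 'device', ...)."""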
assert 'model' in model_tuple
return model_tuple['model']


class BaseInference():
    '''
    Supports loading components dynamically: each sub-model is created and
    loaded only the first time it is actually run, not at construction time.
    '''
def __init__(self, cfg, logger=None):
if logger is None:
logger = get_logger(name='scepter')
self.logger = logger
self.name = cfg.NAME
def init_from_modules(self, modules):
for k, v in modules.items():
self.__setattr__(k, v)
def infer_model(self, cfg, module_paras=None):
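        """Create the offline record for one component: a dict tracking the
        (not yet built) model, its config, device state, per-function dtype
        and input defaults, and the parameters parsed from module_paras."""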
module = {
'model': None,
'cfg': cfg,
'device': 'offline',
'name': cfg.NAME,
'function_info': {},
'paras': {}
}
if module_paras is None:
return module
function_info = {}
paras = {
k.lower(): v
for k, v in module_paras.get('PARAS', {}).items()
}
for function in module_paras.get('FUNCTION', []):
input_dict = {}
for inp in function.get('INPUT', []):
if inp.lower() in self.input:
input_dict[inp.lower()] = self.input[inp.lower()]
function_info[function.NAME] = {
'dtype': function.get('DTYPE', 'float32'),
'input': input_dict
}
module['paras'] = paras
module['function_info'] = function_info
return module
def init_from_ckpt(self, path, model, ignore_keys=list()):
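        """Load a .safetensors or torch checkpoint into `model`, skipping keys
        that contain any substring in ignore_keys; the load is non-strict, so
        missing and unexpected keys are only logged on rank 0."""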
        if path.endswith('.safetensors'):
from safetensors.torch import load_file as load_safetensors
sd = load_safetensors(path)
else:
sd = torch.load(path, map_location='cpu', weights_only=True)
new_sd = OrderedDict()
for k, v in sd.items():
ignored = False
for ik in ignore_keys:
if ik in k:
if we.rank == 0:
self.logger.info(
'Ignore key {} from state_dict.'.format(k))
ignored = True
break
if not ignored:
new_sd[k] = v
missing, unexpected = model.load_state_dict(new_sd, strict=False)
if we.rank == 0:
self.logger.info(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
self.logger.info(f'Missing Keys:\n {missing}')
if len(unexpected) > 0:
self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
def load(self, module):
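        """Build the module's network from the MODELS / BACKBONES / EMBEDDERS
        registries on first use, optionally cast its dtype and reload weights,
        then move it onto the current device."""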
if module['device'] == 'offline':
from scepter.modules.utils.import_utils import LazyImportModule
if (LazyImportModule.get_module_type(('MODELS', module['cfg'].NAME)) or
module['cfg'].NAME in MODELS.class_map):
model = MODELS.build(module['cfg'], logger=self.logger).eval()
elif (LazyImportModule.get_module_type(('BACKBONES', module['cfg'].NAME)) or
module['cfg'].NAME in BACKBONES.class_map):
model = BACKBONES.build(module['cfg'],
logger=self.logger).eval()
elif (LazyImportModule.get_module_type(('EMBEDDERS', module['cfg'].NAME)) or
module['cfg'].NAME in EMBEDDERS.class_map):
model = EMBEDDERS.build(module['cfg'],
logger=self.logger).eval()
else:
raise NotImplementedError
if 'DTYPE' in module['cfg'] and module['cfg']['DTYPE'] is not None:
model = model.to(getattr(torch, module['cfg'].DTYPE))
if module['cfg'].get('RELOAD_MODEL', None):
self.init_from_ckpt(module['cfg'].RELOAD_MODEL, model)
module['model'] = model
module['device'] = 'cpu'
if module['device'] == 'cpu':
module['device'] = we.device_id
module['model'] = module['model'].to(we.device_id)
return module
def unload(self, module):
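        """Move the module back to CPU, or drop it entirely when less than
        half of the host memory is free, then release cached CUDA memory."""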
if module is None:
return module
mem = get_available_memory()
free_mem = int(mem['available'] / (1024**2))
total_mem = int(mem['total'] / (1024**2))
if free_mem < 0.5 * total_mem:
if module['model'] is not None:
module['model'] = module['model'].to('cpu')
del module['model']
module['model'] = None
module['device'] = 'offline'
            self.logger.info('Deleted model {} to free host memory.'.format(
                module['name']))
else:
if module['model'] is not None:
module['model'] = module['model'].to('cpu')
module['device'] = 'cpu'
else:
module['device'] = 'offline'
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
return module
def dynamic_load(self, module=None, name=''):
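        """Ensure the named module is loaded on the current device, rebuilding
        it when its config changed; name='all' loads every module listed in
        self.loaded_model_name."""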
self.logger.info('Loading {} model'.format(name))
if name == 'all':
for subname in self.loaded_model_name:
self.loaded_model[subname] = self.dynamic_load(
getattr(self, subname), subname)
elif name in self.loaded_model_name:
if name in self.loaded_model:
if module['cfg'] != self.loaded_model[name]['cfg']:
self.unload(self.loaded_model[name])
module = self.load(module)
self.loaded_model[name] = module
return module
elif module['device'] == 'cpu' or module['device'] == 'offline':
module = self.load(module)
return module
else:
return module
else:
module = self.load(module)
self.loaded_model[name] = module
return module
else:
return self.load(module)
def dynamic_unload(self, module=None, name='', skip_loaded=False):
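        """Unload the named module, or every tracked module when name='all';
        with skip_loaded=True, modules already tracked in self.loaded_model
        are left in place."""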
self.logger.info('Unloading {} model'.format(name))
if name == 'all':
            for subname in list(self.loaded_model.keys()):
                self.loaded_model[subname] = self.unload(
                    self.loaded_model[subname])
elif name in self.loaded_model_name:
if name in self.loaded_model:
if not skip_loaded:
module = self.unload(self.loaded_model[name])
self.loaded_model[name] = module
else:
self.unload(module)
else:
self.unload(module)
def load_default(self, cfg):
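        """Cache inference-level defaults from `cfg`: parameter presets plus
        lower-cased input/output specs; return the per-module parameter
        overrides (MODULES_PARAS) consumed by infer_model."""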
module_paras = {}
if cfg is not None:
self.paras = cfg.PARAS
self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {
                k.lower():
                dict(v).get('DEFAULT', None)
                if isinstance(v, (dict, OrderedDict, Config)) else v
                for k, v in cfg.INPUT.items()
            }
self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
module_paras = cfg.MODULES_PARAS
return module_paras
    def load_image(self, image, num_samples=1):
        # Stub: tensors pass through unchanged; subclasses implement the
        # actual PIL preprocessing (resize, normalize, to-tensor, batching).
        if isinstance(image, torch.Tensor):
            return image
        elif isinstance(image, Image):
            raise NotImplementedError(
                'PIL.Image preprocessing must be implemented by subclasses.')
        else:
            raise TypeError(f'Unsupported image type: {type(image)}')
def get_function_info(self, module, function_name=None):
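        """Resolve which module function to call: return (name, dtype) for the
        requested function, or for the single registered one when
        function_name is None; otherwise (None, None)."""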
all_function = module['function_info']
if function_name in all_function:
return function_name, all_function[function_name]['dtype']
        if function_name is None and len(all_function) == 1:
            only_name = next(iter(all_function))
            return only_name, all_function[only_name]['dtype']
        return None, None
    @torch.no_grad()
    def __call__(self, input, **kwargs):
        # Abstract entry point: subclasses implement the actual inference.
        raise NotImplementedError


def build_inference(cfg, registry, logger=None, *args, **kwargs):
    """Build an inference object from `cfg` via `registry`.

    `cfg` must be a scepter Config whose NAME selects a registered class;
    any extra arguments are forwarded to that class's constructor.
    """
    if not isinstance(cfg, Config):
        raise TypeError(f'cfg must be a Config instance, got {type(cfg)}')
    return build_from_config(cfg, registry, logger=logger, *args, **kwargs)


# Register the inference classes used for diffusion.
INFERENCES = Registry('INFERENCE', build_func=build_inference)
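
# A minimal usage sketch (hypothetical file name and config, not part of this
# module), assuming 'inference.yaml' sets NAME to a registered inference class:
#
#   cfg = Config(cfg_file='inference.yaml')
#   inference = INFERENCES.build(cfg, logger=get_logger(name='scepter'))
#   inference.dynamic_load(name='all')    # create and load every sub-module
#   output = inference(inference.input)   # run with the default inputs
#   inference.dynamic_unload(name='all')  # move or drop models to free memory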