File size: 758 Bytes
1e3b872 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from r_chainner.archs.face.gfpganv1_clean_arch import GFPGANv1Clean
from r_chainner.types import PyTorchModel
class UnsupportedModel(Exception):
    """Error type for state dicts / models that this loader does not support."""
def load_state_dict(state_dict) -> PyTorchModel:
    """Instantiate a supported model from a (possibly wrapped) state dict.

    If the mapping contains one of the wrapper keys ``"params_ema"``,
    ``"params-ema"`` or ``"params"``, the first one found in that order is
    unwrapped (one level only) before the architecture is detected from the
    parameter names.

    Args:
        state_dict: Mapping of parameter names to weights, optionally nested
            under one of the wrapper keys listed above.

    Returns:
        A ``GFPGANv1Clean`` model built from the (unwrapped) state dict.

    Raises:
        UnsupportedModel: If the parameter names do not match any known
            architecture. (Previously this surfaced as an opaque
            ``UnboundLocalError``.)
    """
    # Wrapper keys are checked in priority order; only the first match is
    # unwrapped, mirroring the original if/elif chain.
    for wrapper_key in ("params_ema", "params-ema", "params"):
        if wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
            break

    # GFPGAN: identified by its RGB output head plus the StyleGAN decoder MLP.
    if (
        "toRGB.0.weight" in state_dict
        and "stylegan_decoder.style_mlp.1.weight" in state_dict
    ):
        return GFPGANv1Clean(state_dict)

    # No architecture matched: fail with the module's dedicated error instead
    # of falling through to an unassigned local.
    raise UnsupportedModel("State dict does not match any supported model architecture")
|