File size: 972 Bytes
19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import logging
from saicinpainting.training.modules.ffc import FFCResNetGenerator
from saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \
NLayerDiscriminator, MultidilatedNLayerDiscriminator
def make_generator(config, kind, **kwargs):
    """Build the generator network selected by ``kind``.

    Args:
        config: Accepted for interface compatibility; this function does
            not read it.
        kind: Name of the generator to build. One of
            ``'pix2pixhd_multidilated'``, ``'pix2pixhd_global'`` or
            ``'ffc_resnet'``.
        **kwargs: Forwarded unchanged to the selected class constructor.

    Returns:
        A new instance of the selected generator class.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    logging.info(f'Make generator {kind}')

    # Pick the class first, then instantiate once at the end. Each class
    # name is only looked up in the branch that actually matches.
    if kind == 'pix2pixhd_multidilated':
        generator_cls = MultiDilatedGlobalGenerator
    elif kind == 'pix2pixhd_global':
        generator_cls = GlobalGenerator
    elif kind == 'ffc_resnet':
        generator_cls = FFCResNetGenerator
    else:
        raise ValueError(f'Unknown generator kind {kind}')

    return generator_cls(**kwargs)
def make_discriminator(kind, **kwargs):
    """Build the discriminator network selected by ``kind``.

    Args:
        kind: Name of the discriminator to build. One of
            ``'pix2pixhd_nlayer_multidilated'`` or ``'pix2pixhd_nlayer'``.
        **kwargs: Forwarded unchanged to the selected class constructor.

    Returns:
        A new instance of the selected discriminator class.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    logging.info(f'Make discriminator {kind}')

    # Pick the class first, then instantiate once at the end. Each class
    # name is only looked up in the branch that actually matches.
    if kind == 'pix2pixhd_nlayer_multidilated':
        discriminator_cls = MultidilatedNLayerDiscriminator
    elif kind == 'pix2pixhd_nlayer':
        discriminator_cls = NLayerDiscriminator
    else:
        raise ValueError(f'Unknown discriminator kind {kind}')

    return discriminator_cls(**kwargs)
|