|
from .imagebase import imagebase_ae, imagebase_ae_stride, imagebase_ae_channel |
|
from .videobase import videobase_ae, videobase_ae_stride, videobase_ae_channel |
|
from .videobase import ( |
|
VQVAEConfiguration, |
|
VQVAEModel, |
|
VQVAETrainer, |
|
CausalVQVAEModel, |
|
CausalVQVAEConfiguration, |
|
CausalVQVAETrainer |
|
) |
|
|
|
# Combined stride registry. Start from a copy of the image-based entries,
# then layer the video-based entries on top, so a video-based entry wins
# if both registries define the same key.
ae_stride_config = dict(imagebase_ae_stride)
ae_stride_config.update(videobase_ae_stride)
|
|
|
# Combined channel registry. Start from a copy of the image-based entries,
# then layer the video-based entries on top, so a video-based entry wins
# if both registries define the same key.
ae_channel_config = dict(imagebase_ae_channel)
ae_channel_config.update(videobase_ae_channel)
|
|
|
def getae(args):
    """Look up ``args.ae`` in the registries and call the matching entry.

    Deprecated (the original docstring marks this helper as "deprecation").

    The name is searched in ``imagebase_ae`` first and in ``videobase_ae``
    second. The entry that is found is called with the name itself as its
    only argument and the result of that call is returned.

    Args:
        args: Object exposing an ``ae`` attribute holding the registry key.

    Returns:
        Whatever the registered entry returns when called with ``args.ae``.

    Raises:
        AssertionError: If neither registry has a usable entry for
            ``args.ae``.
    """
    name = args.ae
    # `or` falls through to the video registry when the image registry has
    # no (or a falsy) entry for this name.
    entry = imagebase_ae.get(name, None) or videobase_ae.get(name, None)
    if entry is None:
        # Raised explicitly instead of via `assert`, so the check is not
        # stripped under `python -O`; AssertionError is kept so existing
        # callers observe the same exception type as before.
        raise AssertionError(
            f"no entry registered for {name!r} in imagebase_ae or videobase_ae"
        )
    return entry(name)
|
|
|
def getae_wrapper(ae):
    """Return the registry entry stored under the name ``ae``.

    Deprecated (the original docstring marks this helper as "deprecation").

    The name is searched in ``imagebase_ae`` first and in ``videobase_ae``
    second. Unlike ``getae``, the entry is returned as-is and is not called.

    Args:
        ae: Registry key to look up.

    Returns:
        The entry registered under ``ae``.

    Raises:
        AssertionError: If neither registry has a usable entry for ``ae``.
    """
    # `or` falls through to the video registry when the image registry has
    # no (or a falsy) entry for this name.
    entry = imagebase_ae.get(ae, None) or videobase_ae.get(ae, None)
    if entry is None:
        # Raised explicitly instead of via `assert`, so the check is not
        # stripped under `python -O` (which would silently return None);
        # AssertionError is kept so existing callers observe the same
        # exception type as before.
        raise AssertionError(
            f"no entry registered for {ae!r} in imagebase_ae or videobase_ae"
        )
    return entry