File size: 689 Bytes
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from enum import Enum

import yaml
from easydict import EasyDict as edict
import torch.nn as nn
import torch


def load_yaml(path):
    """Load a YAML file and return its contents as an ``EasyDict``.

    The file is parsed with ``yaml.safe_load`` (no arbitrary Python object
    construction) and the result is wrapped in ``EasyDict`` so keys can be
    read as attributes (``cfg.key`` as well as ``cfg['key']``).

    Args:
        path: path of the YAML file to read.

    Returns:
        ``EasyDict`` built from the parsed document.
    """
    # Decode explicitly as UTF-8: without ``encoding`` the platform locale
    # is used, which can mis-decode non-ASCII configs (e.g. cp1252 on Windows).
    with open(path, 'r', encoding='utf-8') as f:
        return edict(yaml.safe_load(f))


def move_to_device(obj, device):
    """Recursively move tensors and modules contained in ``obj`` to ``device``.

    Args:
        obj: an ``nn.Module``, a tensor, or an arbitrarily nested
            list / tuple / dict of those.
        device: target handed straight to ``.to()`` (e.g. ``'cpu'``,
            ``'cuda:0'`` or a ``torch.device``).

    Returns:
        The moved object. Modules and tensors are returned by their own
        ``.to(device)`` call. Lists and dicts are rebuilt as plain ``list`` /
        ``dict``; tuples stay tuples, and namedtuples keep their own type.

    Raises:
        ValueError: if a (possibly nested) element is none of the supported
            types.
    """
    if isinstance(obj, nn.Module) or torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, tuple):
        moved = [move_to_device(el, device) for el in obj]
        # Namedtuples take their fields positionally; plain tuples take an
        # iterable. Either way the caller gets back the container type it
        # passed in instead of a list.
        if hasattr(obj, '_fields'):
            return type(obj)(*moved)
        return tuple(moved)
    if isinstance(obj, list):
        return [move_to_device(el, device) for el in obj]
    if isinstance(obj, dict):
        return {name: move_to_device(val, device) for name, val in obj.items()}
    raise ValueError(f'Unexpected type {type(obj)}')


class SmallMode(Enum):
    """Two-valued mode selector with string values ``"drop"`` and ``"upscale"``.

    NOTE(review): judging only by the names, this presumably chooses how
    inputs that are too small are handled (discarded vs. upscaled). The
    consuming code is not visible here; confirm against callers.
    """
    # Members carry lowercase string values, so e.g. SmallMode("drop") is
    # SmallMode.DROP (handy when the mode comes from a config string).
    DROP = "drop"
    UPSCALE = "upscale"