Spaces:
Runtime error
Runtime error
File size: 3,071 Bytes
d6d3a5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import itertools
import numpy as np
import torch
def sort_dict(disordered):
    """
    Return a new dict holding the entries of `disordered`, inserted in
    ascending key order.
    """
    ordered = {}
    for key in sorted(disordered):
        ordered[key] = disordered[key]
    return ordered
def prefix_dict(mydict, prefix):
    """
    Return a copy of `mydict` in which every key has `prefix` prepended.
    """
    renamed = {}
    for key, value in mydict.items():
        renamed[prefix + key] = value
    return renamed
def postfix_dict(mydict, postfix):
    """
    Return a copy of `mydict` in which every key has `postfix` appended.
    """
    renamed_pairs = ((key + postfix, value) for key, value in mydict.items())
    return dict(renamed_pairs)
def unsort(L, sort_idx):
    """
    Reorder the elements of `L` by ascending value of the paired entry in
    `sort_idx`.

    Element `L[i]` is paired with `sort_idx[i]`; the pairs are sorted on the
    index only (stable for ties), and the elements are returned in that order.
    If the two lists differ in length, the extra entries of the longer one
    are ignored (`zip` stops at the shorter input).

    Both arguments must be lists. Empty input yields an empty list.
    """
    assert isinstance(sort_idx, list)
    assert isinstance(L, list)
    # Sort on the index alone so elements of L never have to be comparable.
    ordered_pairs = sorted(zip(sort_idx, L), key=lambda pair: pair[0])
    # A comprehension (rather than unpacking zip(*pairs)) also works when
    # there are no pairs at all.
    return [item for _, item in ordered_pairs]
def cat_dl(out_list, dim, verbose=True, squeeze=True):
    """
    Concatenate every list in a dict of lists into a single object.

    The type of the first element of each list selects the operation:
      * torch.Tensor -> `torch.cat` along `dim`
      * np.ndarray   -> `np.concatenate` along axis `dim`
      * list         -> the sub-lists are flattened one level into one list
    Entries whose first element has any other type, and entries whose list
    is empty, are left out of the result; a message is printed for them when
    `verbose` is true.

    When `squeeze` is true, size-1 dimensions are removed from concatenated
    tensors and arrays.

    Returns a new dict; `out_list` is not modified.
    """
    out = {}
    for key, val in out_list.items():
        try:
            first = val[0]
        except IndexError:
            # Nothing to inspect or concatenate: skip instead of crashing.
            if verbose:
                print(f"Ignoring {key} empty list")
            continue

        if isinstance(first, torch.Tensor):
            out[key] = torch.cat(val, dim=dim)
            if squeeze:
                out[key] = out[key].squeeze()
        elif isinstance(first, np.ndarray):
            out[key] = np.concatenate(val, axis=dim)
            if squeeze:
                out[key] = np.squeeze(out[key])
        elif isinstance(first, list):
            # Linear-time flatten; sum(val, []) re-copies the accumulator on
            # every addition and is quadratic in the total length.
            out[key] = list(itertools.chain.from_iterable(val))
        else:
            if verbose:
                print(f"Ignoring {key} undefined type {type(first)}")
    return out
def stack_dl(out_list, dim, verbose=True, squeeze=True):
    """
    Stack every list in a dict of lists into a single object.

    The type of the first element of each list selects the operation:
      * torch.Tensor -> `torch.stack` along a new dimension `dim`
      * np.ndarray   -> `np.stack` along a new axis `dim`
      * list         -> the sub-lists are flattened one level into one list
    Entries whose first element has any other type, and entries whose list
    is empty, are copied to the result unchanged; a message is printed for
    them when `verbose` is true.

    When `squeeze` is true, size-1 dimensions are removed from stacked
    tensors and arrays.

    Returns a new dict; `out_list` is not modified.
    """
    out = {}
    for key, val in out_list.items():
        try:
            first = val[0]
        except IndexError:
            # Nothing to inspect or stack: pass the empty value through
            # unchanged, as is done for unsupported types below.
            out[key] = val
            if verbose:
                print(f"Processing {key} empty list")
            continue

        if isinstance(first, torch.Tensor):
            out[key] = torch.stack(val, dim=dim)
            if squeeze:
                out[key] = out[key].squeeze()
        elif isinstance(first, np.ndarray):
            out[key] = np.stack(val, axis=dim)
            if squeeze:
                out[key] = np.squeeze(out[key])
        elif isinstance(first, list):
            # Linear-time flatten; sum(val, []) re-copies the accumulator on
            # every addition and is quadratic in the total length.
            out[key] = list(itertools.chain.from_iterable(val))
        else:
            out[key] = val
            if verbose:
                print(f"Processing {key} undefined type {type(first)}")
    return out
def add_prefix_postfix(mydict, prefix="", postfix=""):
    """
    Return a copy of `mydict` whose keys are wrapped as
    `prefix + key + postfix`. `mydict` must be a dict.
    """
    assert isinstance(mydict, dict)
    return {prefix + name + postfix: item for name, item in mydict.items()}
def ld2dl(LD):
    """
    Convert a list of dicts (same keys) to a dict of lists.

    The keys are taken from the first dict. For each key the result holds
    the list of values found under that key in every dict, in list order.
    A later dict that lacks one of those keys raises KeyError; keys that
    appear only in later dicts are ignored.

    An empty list yields an empty dict.
    """
    assert isinstance(LD, list)
    if not LD:
        # No first element to take keys from.
        return {}
    assert isinstance(LD[0], dict)
    dict_list = {k: [dic[k] for dic in LD] for k in LD[0]}
    return dict_list
class NameSpace:
    """
    Plain object whose attributes are populated from a mapping at
    construction time.
    """

    def __init__(self, adict):
        # Copy the entries straight into the instance attribute dict.
        vars(self).update(adict)
def dict2ns(mydict):
    """
    Wrap a dict in a NameSpace object so its entries can be read as
    attributes.
    """
    namespace = NameSpace(mydict)
    return namespace
def ld2dev(ld, dev):
    """
    Recursively move every tensor found in `ld` to device `dev`.

    A tensor is returned as the result of `.to(dev)`. A dict is updated in
    place (each value replaced by its converted form) and the same dict is
    returned. A list produces a new list of converted elements. Any other
    object is returned unchanged.
    """
    if isinstance(ld, torch.Tensor):
        moved = ld.to(dev)
    elif isinstance(ld, dict):
        # Only existing keys are reassigned, so the dict's size never changes
        # while it is being iterated.
        for name, entry in ld.items():
            ld[name] = ld2dev(entry, dev)
        moved = ld
    elif isinstance(ld, list):
        moved = [ld2dev(entry, dev) for entry in ld]
    else:
        moved = ld
    return moved
def all_comb_dict(hyper_dict):
    """
    Expand a dict of option lists into every combination of options.

    `hyper_dict` maps each key to an iterable of candidate values. The
    result is a list of dicts, one per element of the Cartesian product of
    those iterables, each mapping every key to one chosen value. The order
    follows `itertools.product`: the last key varies fastest.

    An empty `hyper_dict` yields `[{}]` (the single empty combination). If
    any key has no candidate values, the result is an empty list.
    """
    assert isinstance(hyper_dict, dict)
    # Read keys and values separately rather than unpacking
    # zip(*hyper_dict.items()), which fails when the dict is empty.
    keys = list(hyper_dict)
    permute_dicts = [
        dict(zip(keys, combo)) for combo in itertools.product(*hyper_dict.values())
    ]
    return permute_dicts
|