import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from common.ld_utils import unsort as unsort_list


# PyTorch implementation of np.nanmean
# https://github.com/pytorch/pytorch/issues/21987#issuecomment-539402619
def nanmean(v, *args, inplace=False, **kwargs):
    if not inplace:
        v = v.clone()
    is_nan = torch.isnan(v)
    v[is_nan] = 0
    return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
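
# Illustrative usage (added example, not part of the original module):
#   >>> t = torch.tensor([1.0, float("nan"), 3.0])
#   >>> nanmean(t)
#   tensor(2.)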


def grad_norm(model):
    # Compute the global L2 norm of all parameter gradients of the model.
    total_norm = None
    for p in model.parameters():
        if p.grad is not None:
            if total_norm is None:
                total_norm = 0
            param_norm = p.grad.detach().data.norm(2)
            total_norm += param_norm.item() ** 2

    if total_norm is not None:
        total_norm = total_norm ** (1.0 / 2)
    else:
        total_norm = 0.0
    return total_norm
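
# Illustrative usage (assumed training-loop snippet, added for clarity): after
# loss.backward(), the gradient norm can be logged as
#   >>> norm = grad_norm(model)  # 0.0 if no parameter has a gradient yet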


def pad_tensor_list(v_list: list):
    """Pad a list of variable-length tensors into one (N, max_len, ...) tensor."""
    dev = v_list[0].device
    num_meshes = len(v_list)
    num_dim = 1 if len(v_list[0].shape) == 1 else v_list[0].shape[1]
    v_len_list = []
    for verts in v_list:
        v_len_list.append(verts.shape[0])

    pad_len = max(v_len_list)
    dtype = v_list[0].dtype
    if num_dim == 1:
        padded_tensor = torch.zeros(num_meshes, pad_len, dtype=dtype)
    else:
        padded_tensor = torch.zeros(num_meshes, pad_len, num_dim, dtype=dtype)
    for idx, (verts, v_len) in enumerate(zip(v_list, v_len_list)):
        padded_tensor[idx, :v_len] = verts
    padded_tensor = padded_tensor.to(dev)
    v_len_list = torch.LongTensor(v_len_list).to(dev)
    return padded_tensor, v_len_list


def unpad_vtensor(vtensor: torch.Tensor, lens: torch.Tensor):
    """Split a padded tensor back into a list of variable-length tensors."""
    tensors_list = []
    for verts, vlen in zip(vtensor, lens):
        tensors_list.append(verts[:vlen])
    return tensors_list
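
# Illustrative round trip (added example with assumed shapes): pad a ragged
# list of per-mesh tensors into one batch, then recover the originals.
#   >>> a, b = torch.rand(5, 3), torch.rand(2, 3)
#   >>> padded, lens = pad_tensor_list([a, b])   # padded: (2, 5, 3), lens: [5, 2]
#   >>> a2, b2 = unpad_vtensor(padded, lens)     # equal in value to a and b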


def one_hot_embedding(labels, num_classes):
    """Embedding labels to one-hot form.
    Args:
      labels: (LongTensor) class labels, sized [N, D1, D2, ..].
      num_classes: (int) number of classes.
    Returns:
      (tensor) encoded labels, sized [N, D1, D2, .., Dk, #classes].
    """
    y = torch.eye(num_classes).float()
    return y[labels]
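
# Illustrative usage (added example): indices map to rows of an identity matrix.
#   >>> one_hot_embedding(torch.tensor([0, 2]), num_classes=3)
#   tensor([[1., 0., 0.],
#           [0., 0., 1.]])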


def unsort(ten, sort_idx):
    """
    Unsort a tensor of shape (N, *) using the sort_idx list(N).
    Return a tensor of the pre-sorting order in shape (N, *)
    """
    assert isinstance(ten, torch.Tensor)
    assert isinstance(sort_idx, list)
    assert ten.shape[0] == len(sort_idx)

    out_list = list(torch.chunk(ten, ten.size(0), dim=0))
    out_list = unsort_list(out_list, sort_idx)
    out_list = torch.cat(out_list, dim=0)
    return out_list
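
# Illustrative usage (added note; the permutation convention follows
# common.ld_utils.unsort): given a batch that was reordered with sort_idx,
#   >>> restored = unsort(sorted_batch, sort_idx)  # sorted_batch: (N, *), sort_idx: list of length N
# returns the rows in their original, pre-sorting order.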


def all_comb(X, Y):
    """
    Returns all possible combinations of elements in X and Y.
    X: (n_x, d_x)
    Y: (n_y, d_y)
    Output: Z: (n_x*n_y, d_x+d_y)
    Example:
    X = tensor([[8, 8, 8],
                [7, 5, 9]])
    Y = tensor([[3, 8, 7, 7],
                [3, 7, 9, 9],
                [6, 4, 3, 7]])
    Z = tensor([[8, 8, 8, 3, 8, 7, 7],
                [8, 8, 8, 3, 7, 9, 9],
                [8, 8, 8, 6, 4, 3, 7],
                [7, 5, 9, 3, 8, 7, 7],
                [7, 5, 9, 3, 7, 9, 9],
                [7, 5, 9, 6, 4, 3, 7]])
    """
    assert len(X.size()) == 2
    assert len(Y.size()) == 2
    X1 = X.unsqueeze(1)
    Y1 = Y.unsqueeze(0)
    X2 = X1.repeat(1, Y.shape[0], 1)
    Y2 = Y1.repeat(X.shape[0], 1, 1)
    Z = torch.cat([X2, Y2], -1)
    Z = Z.view(-1, Z.shape[-1])
    return Z


def toggle_parameters(model, requires_grad):
    """
    Enable or disable gradients for all parameters of a model.
    """
    for param in model.parameters():
        param.requires_grad = requires_grad


def detach_tensor(ten):
    """This function move tensor to cpu and convert to numpy"""
    if isinstance(ten, torch.Tensor):
        return ten.cpu().detach().numpy()
    return ten


def count_model_parameters(model):
    """
    Return the number of parameters that require gradients.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def reset_all_seeds(seed):
    """Reset all seeds for reproduciability."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


def get_activation(name):
    """This function return an activation constructor by name."""
    if name == "tanh":
        return nn.Tanh()
    elif name == "sigmoid":
        return nn.Sigmoid()
    elif name == "relu":
        return nn.ReLU()
    elif name == "selu":
        return nn.SELU()
    elif name == "relu6":
        return nn.ReLU6()
    elif name == "softplus":
        return nn.Softplus()
    elif name == "softshrink":
        return nn.Softshrink()
    else:
        raise ValueError("Undefined activation: %s" % (name))


def stack_ll_tensors(tensor_list_list):
    """
    Recursively stack a nested list (of lists of lists, ..) whose leaf elements are tensors of the same shape.
    """
    if isinstance(tensor_list_list, torch.Tensor):
        return tensor_list_list
    assert isinstance(tensor_list_list, list)
    if isinstance(tensor_list_list[0], torch.Tensor):
        return torch.stack(tensor_list_list)

    stacked_tensor = []
    for tensor_list in tensor_list_list:
        stacked_tensor.append(stack_ll_tensors(tensor_list))
    stacked_tensor = torch.stack(stacked_tensor)
    return stacked_tensor
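
# Illustrative usage (added example): each nesting level becomes one leading
# dimension of the stacked result.
#   >>> nested = [[torch.zeros(4), torch.zeros(4)], [torch.zeros(4), torch.zeros(4)]]
#   >>> stack_ll_tensors(nested).shape
#   torch.Size([2, 2, 4])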


def get_optim(name):
    """This function return an optimizer constructor by name."""
    if name == "adam":
        return optim.Adam
    elif name == "rmsprop":
        return optim.RMSprop
    elif name == "sgd":
        return optim.SGD
    else:
        raise ValueError("Undefined optim: %s" % (name))


def decay_lr(optimizer, gamma):
    """
    Decay the learning rate by gamma
    """
    assert isinstance(gamma, float)
    assert 0.0 <= gamma <= 1.0
    for param_group in optimizer.param_groups:
        param_group["lr"] *= gamma