import torch import torch.nn as nn import numpy as np import math import torch.nn.functional as F # LORA finetuning originally by edwardjhu class LoRALayer(): def __init__( self, r: int, lora_alpha: int, lora_dropout: float, merge_weights: bool, ): self.r = r self.lora_alpha = lora_alpha # Optional dropout if lora_dropout > 0.: self.lora_dropout = nn.Dropout(p=lora_dropout) else: self.lora_dropout = lambda x: x # Mark the weight as unmerged self.merged = False self.merge_weights = merge_weights class LoRALinear(nn.Linear, LoRALayer): # LoRA implemented in a dense layer def __init__( self, in_features: int, out_features: int, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0., fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) merge_weights: bool = True, **kwargs ): nn.Linear.__init__(self, in_features, out_features, **kwargs) LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) self.fan_in_fan_out = fan_in_fan_out # Actual trainable parameters if r > 0: self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) self.scaling = self.lora_alpha / self.r # Freezing the pre-trained weight matrix self.weight.requires_grad = False self.reset_parameters() if fan_in_fan_out: self.weight.data = self.weight.data.transpose(0, 1) def reset_parameters(self): #nn.Linear.reset_parameters(self) if hasattr(self, 'lora_A'): # initialize B the same way as the default for nn.Linear and A to zero # this is different than what is described in the paper but should not affect performance nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) # def train(self, mode: bool = True): # def T(w): # return w.transpose(0, 1) if self.fan_in_fan_out else w # nn.Linear.train(self, mode) # if mode: # if self.merge_weights and self.merged: # # Make sure that the weights are not merged # if self.r > 0: # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling # self.merged = False # else: # if self.merge_weights and not self.merged: # # Merge the weights and mark it # if self.r > 0: # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling # self.merged = True def forward(self, x: torch.Tensor): def T(w): return w.transpose(0, 1) if self.fan_in_fan_out else w if self.r > 0 and not self.merged: result = F.linear(x, T(self.weight), bias=self.bias) result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling return result else: return F.linear(x, T(self.weight), bias=self.bias) class ConvLoRA(nn.Conv2d, LoRALayer): def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs) nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) assert isinstance(kernel_size, int) # Actual trainable parameters if r > 0: self.lora_A = nn.Parameter( self.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) ) self.lora_B = nn.Parameter( self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) ) self.scaling = self.lora_alpha / self.r # Freezing the pre-trained weight matrix self.weight.requires_grad = False self.reset_parameters() self.merged = False def reset_parameters(self): #self.conv.reset_parameters() if hasattr(self, 'lora_A'): # initialize A the same way as the default for nn.Linear and B to zero nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) # def train(self, mode=True): # super(ConvLoRA, self).train(mode) # if mode: # if self.merge_weights and self.merged: # if self.r > 0: # # Make sure that the weights are not merged # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling # self.merged = False # else: # if self.merge_weights and not self.merged: # if self.r > 0: # # Merge the weights and mark it # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling # self.merged = True def forward(self, x): if self.r > 0 and not self.merged: # return self.conv._conv_forward( # x, # self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling, # self.conv.bias # ) weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling bias = self.bias return F.conv2d(x, weight, bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) else: return F.conv2d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) class ConvTransposeLoRA(nn.ConvTranspose2d, LoRALayer): def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs) nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) assert isinstance(kernel_size, int) # Actual trainable parameters if r > 0: self.lora_A = nn.Parameter( self.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) ) self.lora_B = nn.Parameter( self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) ) self.scaling = self.lora_alpha / self.r # Freezing the pre-trained weight matrix self.weight.requires_grad = False self.reset_parameters() self.merged = False def reset_parameters(self): #self.conv.reset_parameters() if hasattr(self, 'lora_A'): # initialize A the same way as the default for nn.Linear and B to zero nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) # def train(self, mode=True): # super(ConvTransposeLoRA, self).train(mode) # if mode: # if self.merge_weights and self.merged: # if self.r > 0: # # Make sure that the weights are not merged # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling # self.merged = False # else: # if self.merge_weights and not self.merged: # if self.r > 0: # # Merge the weights and mark it # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling # self.merged = True def forward(self, x): if self.r > 0 and not self.merged: weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling bias = self.bias return F.conv_transpose2d(x, weight, bias=bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding, groups=self.groups, dilation=self.dilation) else: return F.conv_transpose2d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding, groups=self.groups, dilation=self.dilation) #return self.conv(x) class Conv2dLoRA(ConvLoRA): def __init__(self, *args, **kwargs): super(Conv2dLoRA, self).__init__(*args, **kwargs) class ConvTranspose2dLoRA(ConvTransposeLoRA): def __init__(self, *args, **kwargs): super(ConvTranspose2dLoRA, self).__init__(*args, **kwargs) def compute_depth_expectation(prob, depth_values): depth_values = depth_values.view(*depth_values.shape, 1, 1) depth = torch.sum(prob * depth_values, 1) return depth def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None): #with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False): return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners) # def upflow8(flow, mode='bilinear'): # new_size = (8 * flow.shape[2], 8 * flow.shape[3]) # return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) def upflow4(flow, mode='bilinear'): new_size = (4 * flow.shape[2], 4 * flow.shape[3]) #with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False): return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) def coords_grid(batch, ht, wd): # coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) coords = (torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd))) coords = torch.stack(coords[::-1], dim=0).float() return coords[None].repeat(batch, 1, 1, 1) def norm_normalize(norm_out): min_kappa = 0.01 norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1) norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10 kappa = F.elu(kappa) + 1.0 + min_kappa final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1) return final_out # uncertainty-guided sampling (only used during training) @torch.no_grad() def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta): device = init_normal.device B, _, H, W = init_normal.shape N = int(sampling_ratio * H * W) beta = beta # uncertainty map uncertainty_map = -1 * init_normal[:, -1, :, :] # B, H, W # gt_invalid_mask (B, H, W) if gt_norm_mask is not None: gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest') gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5 uncertainty_map[gt_invalid_mask] = -1e4 # (B, H*W) _, idx = uncertainty_map.view(B, -1).sort(1, descending=True) # importance sampling if int(beta * N) > 0: importance = idx[:, :int(beta * N)] # B, beta*N # remaining remaining = idx[:, int(beta * N):] # B, H*W - beta*N # coverage num_coverage = N - int(beta * N) if num_coverage <= 0: samples = importance else: coverage_list = [] for i in range(B): idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N samples = torch.cat((importance, coverage), dim=1) # B, N else: # remaining remaining = idx[:, :] # B, H*W # coverage num_coverage = N coverage_list = [] for i in range(B): idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N samples = coverage # point coordinates rows_int = samples // W # 0 for first row, H-1 for last row rows_float = rows_int / float(H-1) # 0 to 1.0 rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0 cols_int = samples % W # 0 for first column, W-1 for last column cols_float = cols_int / float(W-1) # 0 to 1.0 cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0 point_coords = torch.zeros(B, 1, N, 2) point_coords[:, 0, :, 0] = cols_float # x coord point_coords[:, 0, :, 1] = rows_float # y coord point_coords = point_coords.to(device) return point_coords, rows_int, cols_int class FlowHead(nn.Module): def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4, tuning_mode=None): super(FlowHead, self).__init__() self.conv1d = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) self.conv2d = Conv2dLoRA(hidden_dim // 2, output_dim_depth, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) self.conv1n = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) self.conv2n = Conv2dLoRA(hidden_dim // 2, output_dim_norm, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) self.relu = nn.ReLU(inplace=True) def forward(self, x): depth = self.conv2d(self.relu(self.conv1d(x))) normal = self.conv2n(self.relu(self.conv1n(x))) return torch.cat((depth, normal), dim=1) class ConvGRU(nn.Module): def __init__(self, hidden_dim, input_dim, kernel_size=3, tuning_mode=None): super(ConvGRU, self).__init__() self.convz = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) self.convr = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) self.convq = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) def forward(self, h, cz, cr, cq, *x_list): x = torch.cat(x_list, dim=1) hx = torch.cat([h, x], dim=1) z = torch.sigmoid((self.convz(hx) + cz)) r = torch.sigmoid((self.convr(hx) + cr)) q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq)) # z = torch.sigmoid((self.convz(hx) + cz).float()) # r = torch.sigmoid((self.convr(hx) + cr).float()) # q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq).float()) h = (1-z) * h + z * q return h def pool2x(x): return F.avg_pool2d(x, 3, stride=2, padding=1) def pool4x(x): return F.avg_pool2d(x, 5, stride=4, padding=1) def interp(x, dest): interp_args = {'mode': 'bilinear', 'align_corners': True} return interpolate_float32(x, dest.shape[2:], **interp_args) class BasicMultiUpdateBlock(nn.Module): def __init__(self, args, hidden_dims=[], out_dims=2, tuning_mode=None): super().__init__() self.args = args self.n_gru_layers = args.model.decode_head.n_gru_layers # 3 self.n_downsample = args.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K) # self.encoder = BasicMotionEncoder(args) # encoder_output_dim = 128 # if there is corr volume encoder_output_dim = 6 # no corr volume self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1), tuning_mode=tuning_mode) self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2], tuning_mode=tuning_mode) self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1], tuning_mode=tuning_mode) self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2], tuning_mode=tuning_mode) factor = 2**self.n_downsample self.mask = nn.Sequential( Conv2dLoRA(hidden_dims[2], hidden_dims[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True), Conv2dLoRA(hidden_dims[2], (factor**2)*9, 1, padding=0, r = 8 if tuning_mode == 'lora' else 0)) def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True): if iter32: net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1])) if iter16: if self.n_gru_layers > 2: net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]), interp(net[2], net[1])) else: net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1])) if iter08: if corr is not None: motion_features = self.encoder(flow, corr) else: motion_features = flow if self.n_gru_layers > 1: net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0])) else: net[0] = self.gru08(net[0], *(inp[0]), motion_features) if not update: return net delta_flow = self.flow_head(net[0]) # scale mask to balence gradients mask = .25 * self.mask(net[0]) return net, mask, delta_flow class LayerNorm2d(nn.LayerNorm): def __init__(self, dim): super(LayerNorm2d, self).__init__(dim) def forward(self, x): x = x.permute(0, 2, 3, 1).contiguous() x = super(LayerNorm2d, self).forward(x) x = x.permute(0, 3, 1, 2).contiguous() return x class ResidualBlock(nn.Module): def __init__(self, in_planes, planes, norm_fn='group', stride=1, tuning_mode=None): super(ResidualBlock, self).__init__() self.conv1 = Conv2dLoRA(in_planes, planes, kernel_size=3, padding=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0) self.conv2 = Conv2dLoRA(planes, planes, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0) self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 if norm_fn == 'group': self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if not (stride == 1 and in_planes == planes): self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) elif norm_fn == 'batch': self.norm1 = nn.BatchNorm2d(planes) self.norm2 = nn.BatchNorm2d(planes) if not (stride == 1 and in_planes == planes): self.norm3 = nn.BatchNorm2d(planes) elif norm_fn == 'instance': self.norm1 = nn.InstanceNorm2d(planes) self.norm2 = nn.InstanceNorm2d(planes) if not (stride == 1 and in_planes == planes): self.norm3 = nn.InstanceNorm2d(planes) elif norm_fn == 'layer': self.norm1 = LayerNorm2d(planes) self.norm2 = LayerNorm2d(planes) if not (stride == 1 and in_planes == planes): self.norm3 = LayerNorm2d(planes) elif norm_fn == 'none': self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() if not (stride == 1 and in_planes == planes): self.norm3 = nn.Sequential() if stride == 1 and in_planes == planes: self.downsample = None else: self.downsample = nn.Sequential( Conv2dLoRA(in_planes, planes, kernel_size=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0), self.norm3) def forward(self, x): y = x y = self.conv1(y) y = self.norm1(y) y = self.relu(y) y = self.conv2(y) y = self.norm2(y) y = self.relu(y) if self.downsample is not None: x = self.downsample(x) return self.relu(x+y) class ContextFeatureEncoder(nn.Module): ''' Encoder features are used to: 1. initialize the hidden state of the update operator 2. and also injected into the GRU during each iteration of the update operator ''' def __init__(self, in_dim, output_dim, tuning_mode=None): ''' in_dim = [x4, x8, x16, x32] output_dim = [hindden_dims, context_dims] [[x4,x8,x16,x32],[x4,x8,x16,x32]] ''' super().__init__() output_list = [] for dim in output_dim: conv_out = nn.Sequential( ResidualBlock(in_dim[0], dim[0], 'layer', stride=1, tuning_mode=tuning_mode), Conv2dLoRA(dim[0], dim[0], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) output_list.append(conv_out) self.outputs04 = nn.ModuleList(output_list) output_list = [] for dim in output_dim: conv_out = nn.Sequential( ResidualBlock(in_dim[1], dim[1], 'layer', stride=1, tuning_mode=tuning_mode), Conv2dLoRA(dim[1], dim[1], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) output_list.append(conv_out) self.outputs08 = nn.ModuleList(output_list) output_list = [] for dim in output_dim: conv_out = nn.Sequential( ResidualBlock(in_dim[2], dim[2], 'layer', stride=1, tuning_mode=tuning_mode), Conv2dLoRA(dim[2], dim[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) output_list.append(conv_out) self.outputs16 = nn.ModuleList(output_list) # output_list = [] # for dim in output_dim: # conv_out = Conv2dLoRA(in_dim[3], dim[3], 3, padding=1) # output_list.append(conv_out) # self.outputs32 = nn.ModuleList(output_list) def forward(self, encoder_features): x_4, x_8, x_16, x_32 = encoder_features outputs04 = [f(x_4) for f in self.outputs04] outputs08 = [f(x_8) for f in self.outputs08] outputs16 = [f(x_16)for f in self.outputs16] # outputs32 = [f(x_32) for f in self.outputs32] return (outputs04, outputs08, outputs16) class ConvBlock(nn.Module): # reimplementation of DPT def __init__(self, channels, tuning_mode=None): super(ConvBlock, self).__init__() self.act = nn.ReLU(inplace=True) self.conv1 = Conv2dLoRA( channels, channels, kernel_size=3, stride=1, padding=1, r = 8 if tuning_mode == 'lora' else 0 ) self.conv2 = Conv2dLoRA( channels, channels, kernel_size=3, stride=1, padding=1, r = 8 if tuning_mode == 'lora' else 0 ) def forward(self, x): out = self.act(x) out = self.conv1(out) out = self.act(out) out = self.conv2(out) return x + out class FuseBlock(nn.Module): # reimplementation of DPT def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2, tuning_mode=None): super(FuseBlock, self).__init__() self.fuse = fuse self.scale_factor = scale_factor self.way_trunk = ConvBlock(in_channels, tuning_mode=tuning_mode) if self.fuse: self.way_branch = ConvBlock(in_channels, tuning_mode=tuning_mode) self.out_conv = Conv2dLoRA( in_channels, out_channels, kernel_size=1, stride=1, padding=0, r = 8 if tuning_mode == 'lora' else 0 ) self.upsample = upsample def forward(self, x1, x2=None): if x2 is not None: x2 = self.way_branch(x2) x1 = x1 + x2 out = self.way_trunk(x1) if self.upsample: out = interpolate_float32( out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True ) out = self.out_conv(out) return out class Readout(nn.Module): # From DPT def __init__(self, in_features, use_cls_token=True, num_register_tokens=0, tuning_mode=None): super(Readout, self).__init__() self.use_cls_token = use_cls_token if self.use_cls_token == True: self.project_patch = LoRALinear(in_features, in_features, r = 8 if tuning_mode == 'lora' else 0) self.project_learn = LoRALinear((1 + num_register_tokens) * in_features, in_features, bias=False, r = 8 if tuning_mode == 'lora' else 0) self.act = nn.GELU() else: self.project = nn.Identity() def forward(self, x): if self.use_cls_token == True: x_patch = self.project_patch(x[0]) x_learn = self.project_learn(x[1]) x_learn = x_learn.expand_as(x_patch).contiguous() features = x_patch + x_learn return self.act(features) else: return self.project(x) class Token2Feature(nn.Module): # From DPT def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0, tuning_mode=None): super(Token2Feature, self).__init__() self.scale_factor = scale_factor self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) if scale_factor > 1 and isinstance(scale_factor, int): self.sample = ConvTranspose2dLoRA(r = 8 if tuning_mode == 'lora' else 0, in_channels=vit_channel, out_channels=feature_channel, kernel_size=scale_factor, stride=scale_factor, padding=0, ) elif scale_factor > 1: self.sample = nn.Sequential( # Upsample2(upscale=scale_factor), # nn.Upsample(scale_factor=scale_factor), Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0, in_channels=vit_channel, out_channels=feature_channel, kernel_size=1, stride=1, padding=0, ), ) elif scale_factor < 1: scale_factor = int(1.0 / scale_factor) self.sample = Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0, in_channels=vit_channel, out_channels=feature_channel, kernel_size=scale_factor+1, stride=scale_factor, padding=1, ) else: self.sample = nn.Identity() def forward(self, x): x = self.readoper(x) #if use_cls_token == True: x = x.permute(0, 3, 1, 2).contiguous() if isinstance(self.scale_factor, float): x = interpolate_float32(x.float(), scale_factor=self.scale_factor, mode='nearest') x = self.sample(x) return x class EncoderFeature(nn.Module): def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0, tuning_mode=None): super(EncoderFeature, self).__init__() self.vit_channel = vit_channel self.num_ch_dec = num_ch_dec self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) def forward(self, ref_feature): x = self.read_3(ref_feature[3]) # 1/14 x2 = self.read_2(ref_feature[2]) # 1/14 x1 = self.read_1(ref_feature[1]) # 1/7 x0 = self.read_0(ref_feature[0]) # 1/4 return x, x2, x1, x0 class DecoderFeature(nn.Module): def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True, tuning_mode=None): super(DecoderFeature, self).__init__() self.vit_channel = vit_channel self.num_ch_dec = num_ch_dec self.upconv_3 = FuseBlock( self.num_ch_dec[4], self.num_ch_dec[3], fuse=False, upsample=False, tuning_mode=tuning_mode) self.upconv_2 = FuseBlock( self.num_ch_dec[3], self.num_ch_dec[2], tuning_mode=tuning_mode) self.upconv_1 = FuseBlock( self.num_ch_dec[2], self.num_ch_dec[1] + 2, scale_factor=7/4, tuning_mode=tuning_mode) # self.upconv_0 = FuseBlock( # self.num_ch_dec[1], # self.num_ch_dec[0] + 1, # ) def forward(self, ref_feature): x, x2, x1, x0 = ref_feature # 1/14 1/14 1/7 1/4 x = self.upconv_3(x) # 1/14 x = self.upconv_2(x, x2) # 1/7 x = self.upconv_1(x, x1) # 1/4 # x = self.upconv_0(x, x0) # 4/7 return x class RAFTDepthNormalDPT5(nn.Module): def __init__(self, cfg): super().__init__() self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024] self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14] self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14] self.use_cls_token = cfg.model.decode_head.use_cls_token self.up_scale = cfg.model.decode_head.up_scale self.num_register_tokens = cfg.model.decode_head.num_register_tokens self.min_val = cfg.data_basic.depth_normalize[0] self.max_val = cfg.data_basic.depth_normalize[1] self.regress_scale = 100.0\ try: tuning_mode = cfg.model.decode_head.tuning_mode except: tuning_mode = None self.tuning_mode = tuning_mode self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128] self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3 self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K) self.iters = cfg.model.decode_head.iters # 22 self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True self.num_depth_regressor_anchor = 256 # 512 self.used_res_channel = self.decoder_channels[1] # now, use 2/7 res self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens, tuning_mode=tuning_mode) self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels, tuning_mode=tuning_mode) self.depth_regressor = nn.Sequential( Conv2dLoRA(self.used_res_channel, self.num_depth_regressor_anchor, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0), # nn.BatchNorm2d(self.num_depth_regressor_anchor), nn.ReLU(inplace=True), Conv2dLoRA(self.num_depth_regressor_anchor, self.num_depth_regressor_anchor, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), ) self.normal_predictor = nn.Sequential( Conv2dLoRA(self.used_res_channel, 128, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0,), # nn.BatchNorm2d(128), nn.ReLU(inplace=True), Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True), Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True), Conv2dLoRA(128, 3, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), ) self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims], tuning_mode=tuning_mode) self.context_zqr_convs = nn.ModuleList([Conv2dLoRA(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2, r = 8 if tuning_mode == 'lora' else 0) for i in range(self.n_gru_layers)]) self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6, tuning_mode=tuning_mode) self.relu = nn.ReLU(inplace=True) def get_bins(self, bins_num): depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda") #depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cpu") depth_bins_vec = torch.exp(depth_bins_vec) return depth_bins_vec def register_depth_expectation_anchor(self, bins_num, B): depth_bins_vec = self.get_bins(bins_num) depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1) self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False) def clamp(self, x): y = self.relu(x - self.min_val) + self.min_val y = self.max_val - self.relu(self.max_val - y) return y def regress_depth(self, feature_map_d): prob_feature = self.depth_regressor(feature_map_d) prob = prob_feature.softmax(dim=1) #prob = prob_feature.float().softmax(dim=1) ## Error logging if torch.isnan(prob).any(): print('prob_feat_nan!!!') if torch.isinf(prob).any(): print('prob_feat_inf!!!') # h = prob[0,:,0,0].cpu().numpy().reshape(-1) # import matplotlib.pyplot as plt # plt.bar(range(len(h)), h) B = prob.shape[0] if "depth_expectation_anchor" not in self._buffers: self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B) d = compute_depth_expectation( prob, self.depth_expectation_anchor[:B, ...]).unsqueeze(1) ## Error logging if torch.isnan(d ).any(): print('d_nan!!!') if torch.isinf(d ).any(): print('d_inf!!!') return (self.clamp(d) - self.max_val)/ self.regress_scale, prob_feature def pred_normal(self, feature_map, confidence): normal_out = self.normal_predictor(feature_map) ## Error logging if torch.isnan(normal_out).any(): print('norm_nan!!!') if torch.isinf(normal_out).any(): print('norm_feat_inf!!!') return norm_normalize(torch.cat([normal_out, confidence], dim=1)) #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float()) #def create_mesh_grid(self, height, width, batch, device="cpu", set_buffer=True): def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True): y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device), torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij') meshgrid = torch.stack((x, y)) meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1) #self.register_buffer('meshgrid', meshgrid, persistent=False) return meshgrid def upsample_flow(self, flow, mask): """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ N, D, H, W = flow.shape factor = 2 ** self.n_downsample mask = mask.view(N, 1, 9, factor, factor, H, W) mask = torch.softmax(mask, dim=2) #mask = torch.softmax(mask.float(), dim=2) #up_flow = F.unfold(factor * flow, [3,3], padding=1) up_flow = F.unfold(flow, [3,3], padding=1) up_flow = up_flow.view(N, D, 9, 1, 1, H, W) up_flow = torch.sum(mask * up_flow, dim=2) up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) return up_flow.reshape(N, D, factor*H, factor*W) def initialize_flow(self, img): """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" N, _, H, W = img.shape coords0 = coords_grid(N, H, W).to(img.device) coords1 = coords_grid(N, H, W).to(img.device) return coords0, coords1 def upsample(self, x, scale_factor=2): """Upsample input tensor by a factor of 2 """ return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest") def forward(self, vit_features, **kwargs): ## read vit token to multi-scale features B, H, W, _, _, num_register_tokens = vit_features[1] vit_features = vit_features[0] ## Error logging if torch.isnan(vit_features[0]).any(): print('vit_feature_nan!!!') if torch.isinf(vit_features[0]).any(): print('vit_feature_inf!!!') if self.use_cls_token == True: vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]), \ ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))] for ft in vit_features] else: vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features] encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4 ## Error logging for en_ft in encoder_features: if torch.isnan(en_ft).any(): print('decoder_feature_nan!!!') print(en_ft.shape) if torch.isinf(en_ft).any(): print('decoder_feature_inf!!!') print(en_ft.shape) ## decode features to init-depth (and confidence) ref_feat= self.decoder_mono(encoder_features) # now, 1/4 for depth ## Error logging if torch.isnan(ref_feat).any(): print('ref_feat_nan!!!') if torch.isinf(ref_feat).any(): print('ref_feat_inf!!!') feature_map = ref_feat[:, :-2, :, :] # feature map share of depth and normal prediction depth_confidence_map = ref_feat[:, -2:-1, :, :] normal_confidence_map = ref_feat[:, -1:, :, :] depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W) ## encoder features to context-feature for init-hidden-state and contex-features cnet_list = self.context_feature_encoder(encoder_features[::-1]) net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)] coords0, coords1 = self.initialize_flow(net_list[0]) if depth_init is not None: coords1 = coords1 + depth_init if self.training: low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())] init_depth = upflow4(depth_init) flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)] conf_predictions = [init_depth[:,1:2]] normal_outs = [norm_normalize(init_depth[:,2:].clone())] else: flow_predictions = [] conf_predictions = [] samples_pred_list = [] coord_list = [] normal_outs = [] low_resolution_init = [] for itr in range(self.iters): # coords1 = coords1.detach() flow = coords1 - coords0 if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False) if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False) net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2) # F(t+1) = F(t) + \Delta(t) coords1 = coords1 + delta_flow # We do not need to upsample or output intermediate results in test_mode #if (not self.training) and itr < self.iters-1: #continue # upsample predictions if up_mask is None: flow_up = self.upsample(coords1-coords0, 4) else: flow_up = self.upsample_flow(coords1 - coords0, up_mask) # flow_up = self.upsample(coords1-coords0, 4) flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val)) conf_predictions.append(flow_up[:,1:2]) normal_outs.append(norm_normalize(flow_up[:,2:].clone())) outputs=dict( prediction=flow_predictions[-1], predictions_list=flow_predictions, confidence=conf_predictions[-1], confidence_list=conf_predictions, pred_logit=None, # samples_pred_list=samples_pred_list, # coord_list=coord_list, prediction_normal=normal_outs[-1], normal_out_list=normal_outs, low_resolution_init=low_resolution_init, ) return outputs if __name__ == "__main__": try: from mmcv.utils import Config except: from mmengine import Config cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py') cfg.model.decode_head.in_channels = [384, 384, 384, 384] cfg.model.decode_head.feature_channels = [96, 192, 384, 768] cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384] cfg.model.decode_head.hidden_channels = [48, 48, 48, 48, 48] cfg.model.decode_head.up_scale = 7 # cfg.model.decode_head.use_cls_token = True # vit_feature = [[torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()]] cfg.model.decode_head.use_cls_token = True cfg.model.decode_head.num_register_tokens = 4 vit_feature = [[torch.rand((2, (74 * 74) + 5, 384)).cuda(),\ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \ torch.rand((2, (74 * 74) + 5, 384)).cuda()], (2, 74, 74, 1036, 1036, 4)] decoder = RAFTDepthNormalDPT5(cfg).cuda() output = decoder(vit_feature) temp = 1