JUGGHM commited on
Commit
8a32844
·
1 Parent(s): 4b42627

Upload 62 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. mono/configs/HourglassDecoder/convlarge.0.3_150.py +25 -0
  2. mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py +25 -0
  3. mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py +25 -0
  4. mono/configs/HourglassDecoder/vit.raft5.large.py +33 -0
  5. mono/configs/HourglassDecoder/vit.raft5.small.py +33 -0
  6. mono/configs/__init__.py +1 -0
  7. mono/configs/_base_/_data_base_.py +13 -0
  8. mono/configs/_base_/datasets/_data_base_.py +12 -0
  9. mono/configs/_base_/default_runtime.py +4 -0
  10. mono/configs/_base_/models/backbones/convnext_large.py +16 -0
  11. mono/configs/_base_/models/backbones/dino_vit_large.py +7 -0
  12. mono/configs/_base_/models/backbones/dino_vit_large_reg.py +7 -0
  13. mono/configs/_base_/models/backbones/dino_vit_small_reg.py +7 -0
  14. mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py +10 -0
  15. mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py +20 -0
  16. mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py +19 -0
  17. mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py +19 -0
  18. mono/model/__init__.py +5 -0
  19. mono/model/__pycache__/__init__.cpython-39.pyc +0 -0
  20. mono/model/__pycache__/monodepth_model.cpython-39.pyc +0 -0
  21. mono/model/backbones/ConvNeXt.py +271 -0
  22. mono/model/backbones/ViT_DINO.py +1504 -0
  23. mono/model/backbones/ViT_DINO_reg.py +1293 -0
  24. mono/model/backbones/__init__.py +11 -0
  25. mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc +0 -0
  26. mono/model/backbones/__pycache__/__init__.cpython-39.pyc +0 -0
  27. mono/model/decode_heads/HourGlassDecoder.py +274 -0
  28. mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py +1033 -0
  29. mono/model/decode_heads/__init__.py +4 -0
  30. mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc +0 -0
  31. mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc +0 -0
  32. mono/model/model_pipelines/__base_model__.py +20 -0
  33. mono/model/model_pipelines/__init__.py +6 -0
  34. mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc +0 -0
  35. mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc +0 -0
  36. mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc +0 -0
  37. mono/model/model_pipelines/dense_pipeline.py +16 -0
  38. mono/model/monodepth_model.py +37 -0
  39. mono/tools/test_scale_cano.py +158 -0
  40. mono/utils/__init__.py +1 -0
  41. mono/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  42. mono/utils/__pycache__/avg_meter.cpython-39.pyc +0 -0
  43. mono/utils/__pycache__/comm.cpython-39.pyc +0 -0
  44. mono/utils/__pycache__/custom_data.cpython-39.pyc +0 -0
  45. mono/utils/__pycache__/do_test.cpython-39.pyc +0 -0
  46. mono/utils/__pycache__/logger.cpython-39.pyc +0 -0
  47. mono/utils/__pycache__/mldb.cpython-39.pyc +0 -0
  48. mono/utils/__pycache__/running.cpython-39.pyc +0 -0
  49. mono/utils/__pycache__/transform.cpython-39.pyc +0 -0
  50. mono/utils/__pycache__/unproj_pcd.cpython-39.pyc +0 -0
mono/configs/HourglassDecoder/convlarge.0.3_150.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_=[
2
+ '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
3
+ '../_base_/datasets/_data_base_.py',
4
+ '../_base_/default_runtime.py',
5
+ ]
6
+
7
+ model = dict(
8
+ backbone=dict(
9
+ pretrained=False,
10
+ )
11
+ )
12
+
13
+ # configs of the canonical space
14
+ data_basic=dict(
15
+ canonical_space = dict(
16
+ img_size=(512, 960),
17
+ focal_length=1000.0,
18
+ ),
19
+ depth_range=(0, 1),
20
+ depth_normalize=(0.3, 150),
21
+ crop_size = (544, 1216),
22
+ )
23
+
24
+ batchsize_per_gpu = 2
25
+ thread_per_gpu = 4
mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_=[
2
+ '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
3
+ '../_base_/datasets/_data_base_.py',
4
+ '../_base_/default_runtime.py',
5
+ ]
6
+
7
+ model = dict(
8
+ backbone=dict(
9
+ pretrained=False,
10
+ )
11
+ )
12
+
13
+ # configs of the canonical space
14
+ data_basic=dict(
15
+ canonical_space = dict(
16
+ img_size=(512, 960),
17
+ focal_length=1000.0,
18
+ ),
19
+ depth_range=(0, 1),
20
+ depth_normalize=(0.3, 150),
21
+ crop_size = (512, 1088),
22
+ )
23
+
24
+ batchsize_per_gpu = 2
25
+ thread_per_gpu = 4
mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_=[
2
+ '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
3
+ '../_base_/datasets/_data_base_.py',
4
+ '../_base_/default_runtime.py',
5
+ ]
6
+
7
+ model = dict(
8
+ backbone=dict(
9
+ pretrained=False,
10
+ )
11
+ )
12
+
13
+ # configs of the canonical space
14
+ data_basic=dict(
15
+ canonical_space = dict(
16
+ img_size=(512, 960),
17
+ focal_length=1000.0,
18
+ ),
19
+ depth_range=(0, 1),
20
+ depth_normalize=(0.3, 150),
21
+ crop_size = (480, 1216),
22
+ )
23
+
24
+ batchsize_per_gpu = 2
25
+ thread_per_gpu = 4
mono/configs/HourglassDecoder/vit.raft5.large.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_=[
2
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
3
+ '../_base_/datasets/_data_base_.py',
4
+ '../_base_/default_runtime.py',
5
+ ]
6
+
7
+ import numpy as np
8
+ model=dict(
9
+ decode_head=dict(
10
+ type='RAFTDepthNormalDPT5',
11
+ iters=8,
12
+ n_downsample=2,
13
+ detach=False,
14
+ )
15
+ )
16
+
17
+
18
+ max_value = 200
19
+ # configs of the canonical space
20
+ data_basic=dict(
21
+ canonical_space = dict(
22
+ # img_size=(540, 960),
23
+ focal_length=1000.0,
24
+ ),
25
+ depth_range=(0, 1),
26
+ depth_normalize=(0.1, max_value),
27
+ crop_size = (616, 1064), # %28 = 0
28
+ clip_depth_range=(0.1, 200),
29
+ vit_size=(616,1064)
30
+ )
31
+
32
+ batchsize_per_gpu = 1
33
+ thread_per_gpu = 1
mono/configs/HourglassDecoder/vit.raft5.small.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_=[
2
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
3
+ '../_base_/datasets/_data_base_.py',
4
+ '../_base_/default_runtime.py',
5
+ ]
6
+
7
+ import numpy as np
8
+ model=dict(
9
+ decode_head=dict(
10
+ type='RAFTDepthNormalDPT5',
11
+ iters=4,
12
+ n_downsample=2,
13
+ detach=False,
14
+ )
15
+ )
16
+
17
+
18
+ max_value = 200
19
+ # configs of the canonical space
20
+ data_basic=dict(
21
+ canonical_space = dict(
22
+ # img_size=(540, 960),
23
+ focal_length=1000.0,
24
+ ),
25
+ depth_range=(0, 1),
26
+ depth_normalize=(0.1, max_value),
27
+ crop_size = (616, 1064), # %28 = 0
28
+ clip_depth_range=(0.1, 200),
29
+ vit_size=(616,1064)
30
+ )
31
+
32
+ batchsize_per_gpu = 1
33
+ thread_per_gpu = 1
mono/configs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
mono/configs/_base_/_data_base_.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # canonical camera setting and basic data setting
2
+ # we set it same as the E300 camera (crop version)
3
+ #
4
+ data_basic=dict(
5
+ canonical_space = dict(
6
+ img_size=(540, 960),
7
+ focal_length=1196.0,
8
+ ),
9
+ depth_range=(0.9, 150),
10
+ depth_normalize=(0.006, 1.001),
11
+ crop_size = (512, 960),
12
+ clip_depth_range=(0.9, 150),
13
+ )
mono/configs/_base_/datasets/_data_base_.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # canonical camera setting and basic data setting
2
+ #
3
+ data_basic=dict(
4
+ canonical_space = dict(
5
+ img_size=(540, 960),
6
+ focal_length=1196.0,
7
+ ),
8
+ depth_range=(0.9, 150),
9
+ depth_normalize=(0.006, 1.001),
10
+ crop_size = (512, 960),
11
+ clip_depth_range=(0.9, 150),
12
+ )
mono/configs/_base_/default_runtime.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ load_from = None
3
+ cudnn_benchmark = True
4
+ test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3','rmse_log', 'log10', 'sq_rel']
mono/configs/_base_/models/backbones/convnext_large.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #_base_ = ['./_model_base_.py',]
2
+
3
+ #'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth'
4
+ model = dict(
5
+ #type='EncoderDecoderAuxi',
6
+ backbone=dict(
7
+ type='convnext_large',
8
+ pretrained=True,
9
+ in_22k=True,
10
+ out_indices=[0, 1, 2, 3],
11
+ drop_path_rate=0.4,
12
+ layer_scale_init_value=1.0,
13
+ checkpoint='data/pretrained_weight_repo/convnext/convnext_large_22k_1k_384.pth',
14
+ prefix='backbones.',
15
+ out_channels=[192, 384, 768, 1536]),
16
+ )
mono/configs/_base_/models/backbones/dino_vit_large.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ model = dict(
2
+ backbone=dict(
3
+ type='vit_large',
4
+ prefix='backbones.',
5
+ out_channels=[1024, 1024, 1024, 1024],
6
+ drop_path_rate = 0.0),
7
+ )
mono/configs/_base_/models/backbones/dino_vit_large_reg.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ model = dict(
2
+ backbone=dict(
3
+ type='vit_large_reg',
4
+ prefix='backbones.',
5
+ out_channels=[1024, 1024, 1024, 1024],
6
+ drop_path_rate = 0.0),
7
+ )
mono/configs/_base_/models/backbones/dino_vit_small_reg.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ model = dict(
2
+ backbone=dict(
3
+ type='vit_small_reg',
4
+ prefix='backbones.',
5
+ out_channels=[384, 384, 384, 384],
6
+ drop_path_rate = 0.0),
7
+ )
mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ _base_ = ['../backbones/convnext_large.py',]
3
+ model = dict(
4
+ type='DensePredModel',
5
+ decode_head=dict(
6
+ type='HourglassDecoder',
7
+ in_channels=[192, 384, 768, 1536],
8
+ decoder_channel=[128, 128, 256, 512],
9
+ prefix='decode_heads.'),
10
+ )
mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ _base_ = ['../backbones/dino_vit_large.py']
3
+ model = dict(
4
+ type='DensePredModel',
5
+ decode_head=dict(
6
+ type='RAFTDepthDPT',
7
+ in_channels=[1024, 1024, 1024, 1024],
8
+ use_cls_token=True,
9
+ feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
10
+ decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
11
+ up_scale = 7,
12
+ hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
13
+ n_gru_layers=3,
14
+ n_downsample=2,
15
+ iters=12,
16
+ slow_fast_gru=True,
17
+ corr_radius=4,
18
+ corr_levels=4,
19
+ prefix='decode_heads.'),
20
+ )
mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ _base_ = ['../backbones/dino_vit_large_reg.py']
3
+ model = dict(
4
+ type='DensePredModel',
5
+ decode_head=dict(
6
+ type='RAFTDepthDPT',
7
+ in_channels=[1024, 1024, 1024, 1024],
8
+ use_cls_token=True,
9
+ feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
10
+ decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
11
+ up_scale = 7,
12
+ hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
13
+ n_gru_layers=3,
14
+ n_downsample=2,
15
+ iters=3,
16
+ slow_fast_gru=True,
17
+ num_register_tokens=4,
18
+ prefix='decode_heads.'),
19
+ )
mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ _base_ = ['../backbones/dino_vit_small_reg.py']
3
+ model = dict(
4
+ type='DensePredModel',
5
+ decode_head=dict(
6
+ type='RAFTDepthDPT',
7
+ in_channels=[384, 384, 384, 384],
8
+ use_cls_token=True,
9
+ feature_channels = [96, 192, 384, 768], # [2/7, 1/7, 1/14, 1/14]
10
+ decoder_channels = [48, 96, 192, 384, 384], # [-, 1/4, 1/7, 1/14, 1/14]
11
+ up_scale = 7,
12
+ hidden_channels=[48, 48, 48, 48], # [x_4, x_8, x_16, x_32] [1/4, 1/7, 1/14, -]
13
+ n_gru_layers=3,
14
+ n_downsample=2,
15
+ iters=3,
16
+ slow_fast_gru=True,
17
+ num_register_tokens=4,
18
+ prefix='decode_heads.'),
19
+ )
mono/model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .monodepth_model import DepthModel
2
+ # from .__base_model__ import BaseDepthModel
3
+
4
+
5
+ __all__ = ['DepthModel', 'BaseDepthModel']
mono/model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (250 Bytes). View file
 
mono/model/__pycache__/monodepth_model.cpython-39.pyc ADDED
Binary file (1.62 kB). View file
 
mono/model/backbones/ConvNeXt.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from timm.models.layers import trunc_normal_, DropPath
5
+ from timm.models.registry import register_model
6
+
7
+ class Block(nn.Module):
8
+ r""" ConvNeXt Block. There are two equivalent implementations:
9
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
10
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
11
+ We use (2) as we find it slightly faster in PyTorch
12
+
13
+ Args:
14
+ dim (int): Number of input channels.
15
+ drop_path (float): Stochastic depth rate. Default: 0.0
16
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
17
+ """
18
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
19
+ super().__init__()
20
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
21
+ self.norm = LayerNorm(dim, eps=1e-6)
22
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
23
+ self.act = nn.GELU()
24
+ self.pwconv2 = nn.Linear(4 * dim, dim)
25
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
26
+ requires_grad=True) if layer_scale_init_value > 0 else None
27
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
28
+
29
+ def forward(self, x):
30
+ input = x
31
+ x = self.dwconv(x)
32
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
33
+ x = self.norm(x)
34
+ x = self.pwconv1(x)
35
+ x = self.act(x)
36
+ x = self.pwconv2(x)
37
+ if self.gamma is not None:
38
+ x = self.gamma * x
39
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
40
+
41
+ x = input + self.drop_path(x)
42
+ return x
43
+
44
+ class ConvNeXt(nn.Module):
45
+ r""" ConvNeXt
46
+ A PyTorch impl of : `A ConvNet for the 2020s` -
47
+ https://arxiv.org/pdf/2201.03545.pdf
48
+ Args:
49
+ in_chans (int): Number of input image channels. Default: 3
50
+ num_classes (int): Number of classes for classification head. Default: 1000
51
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
52
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
53
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
54
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
55
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
56
+ """
57
+ def __init__(self, in_chans=3, num_classes=1000,
58
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
59
+ layer_scale_init_value=1e-6, head_init_scale=1.,
60
+ **kwargs,):
61
+ super().__init__()
62
+
63
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
64
+ stem = nn.Sequential(
65
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
66
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
67
+ )
68
+ self.downsample_layers.append(stem)
69
+ for i in range(3):
70
+ downsample_layer = nn.Sequential(
71
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
72
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
73
+ )
74
+ self.downsample_layers.append(downsample_layer)
75
+
76
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
77
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
78
+ cur = 0
79
+ for i in range(4):
80
+ stage = nn.Sequential(
81
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
82
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
83
+ )
84
+ self.stages.append(stage)
85
+ cur += depths[i]
86
+
87
+ #self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
88
+ #self.head = nn.Linear(dims[-1], num_classes)
89
+
90
+ self.apply(self._init_weights)
91
+ #self.head.weight.data.mul_(head_init_scale)
92
+ #self.head.bias.data.mul_(head_init_scale)
93
+
94
+ def _init_weights(self, m):
95
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
96
+ trunc_normal_(m.weight, std=.02)
97
+ nn.init.constant_(m.bias, 0)
98
+
99
+ def forward_features(self, x):
100
+ features = []
101
+ for i in range(4):
102
+ x = self.downsample_layers[i](x)
103
+ x = self.stages[i](x)
104
+ features.append(x)
105
+ return features # global average pooling, (N, C, H, W) -> (N, C)
106
+
107
+ def forward(self, x):
108
+ #x = self.forward_features(x)
109
+ #x = self.head(x)
110
+ features = self.forward_features(x)
111
+ return features
112
+
113
+ class LayerNorm(nn.Module):
114
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
115
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
116
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
117
+ with shape (batch_size, channels, height, width).
118
+ """
119
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
120
+ super().__init__()
121
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
122
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
123
+ self.eps = eps
124
+ self.data_format = data_format
125
+ if self.data_format not in ["channels_last", "channels_first"]:
126
+ raise NotImplementedError
127
+ self.normalized_shape = (normalized_shape, )
128
+
129
+ def forward(self, x):
130
+ if self.data_format == "channels_last":
131
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
132
+ elif self.data_format == "channels_first":
133
+ u = x.mean(1, keepdim=True)
134
+ s = (x - u).pow(2).mean(1, keepdim=True)
135
+ x = (x - u) / torch.sqrt(s + self.eps)
136
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
137
+ return x
138
+
139
+
140
+ model_urls = {
141
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
142
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
143
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
144
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
145
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
146
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
147
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
148
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
149
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
150
+ }
151
+
152
+ def convnext_tiny(pretrained=True,in_22k=False, **kwargs):
153
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
154
+ if pretrained:
155
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
156
+ #url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
157
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
158
+ model_dict = model.state_dict()
159
+ pretrained_dict = {}
160
+ unmatched_pretrained_dict = {}
161
+ for k, v in checkpoint['model'].items():
162
+ if k in model_dict:
163
+ pretrained_dict[k] = v
164
+ else:
165
+ unmatched_pretrained_dict[k] = v
166
+ model_dict.update(pretrained_dict)
167
+ model.load_state_dict(model_dict)
168
+ print(
169
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
170
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
171
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
172
+ return model
173
+
174
+ def convnext_small(pretrained=True,in_22k=False, **kwargs):
175
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
176
+ if pretrained:
177
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
178
+ #url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
179
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
180
+ model_dict = model.state_dict()
181
+ pretrained_dict = {}
182
+ unmatched_pretrained_dict = {}
183
+ for k, v in checkpoint['model'].items():
184
+ if k in model_dict:
185
+ pretrained_dict[k] = v
186
+ else:
187
+ unmatched_pretrained_dict[k] = v
188
+ model_dict.update(pretrained_dict)
189
+ model.load_state_dict(model_dict)
190
+ print(
191
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
192
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
193
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
194
+ return model
195
+
196
+ def convnext_base(pretrained=True, in_22k=False, **kwargs):
197
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
198
+ if pretrained:
199
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
200
+ #url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
201
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
202
+ model_dict = model.state_dict()
203
+ pretrained_dict = {}
204
+ unmatched_pretrained_dict = {}
205
+ for k, v in checkpoint['model'].items():
206
+ if k in model_dict:
207
+ pretrained_dict[k] = v
208
+ else:
209
+ unmatched_pretrained_dict[k] = v
210
+ model_dict.update(pretrained_dict)
211
+ model.load_state_dict(model_dict)
212
+ print(
213
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
214
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
215
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
216
+ return model
217
+
218
+ def convnext_large(pretrained=True, in_22k=False, **kwargs):
219
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
220
+ if pretrained:
221
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
222
+ #url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
223
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
224
+ model_dict = model.state_dict()
225
+ pretrained_dict = {}
226
+ unmatched_pretrained_dict = {}
227
+ for k, v in checkpoint['model'].items():
228
+ if k in model_dict:
229
+ pretrained_dict[k] = v
230
+ else:
231
+ unmatched_pretrained_dict[k] = v
232
+ model_dict.update(pretrained_dict)
233
+ model.load_state_dict(model_dict)
234
+ print(
235
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
236
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
237
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
238
+ return model
239
+
240
+ def convnext_xlarge(pretrained=True, in_22k=False, **kwargs):
241
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
242
+ if pretrained:
243
+ assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
244
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
245
+ #url = model_urls['convnext_xlarge_22k']
246
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
247
+ model_dict = model.state_dict()
248
+ pretrained_dict = {}
249
+ unmatched_pretrained_dict = {}
250
+ for k, v in checkpoint['model'].items():
251
+ if k in model_dict:
252
+ pretrained_dict[k] = v
253
+ else:
254
+ unmatched_pretrained_dict[k] = v
255
+ model_dict.update(pretrained_dict)
256
+ model.load_state_dict(model_dict)
257
+ print(
258
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
259
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
260
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
261
+ return model
262
+
263
+ if __name__ == '__main__':
264
+ import torch
265
+ model = convnext_base(True, in_22k=False).cuda()
266
+
267
+ rgb = torch.rand((2, 3, 256, 256)).cuda()
268
+ out = model(rgb)
269
+ print(len(out))
270
+ for i, ft in enumerate(out):
271
+ print(i, ft.shape)
mono/model/backbones/ViT_DINO.py ADDED
@@ -0,0 +1,1504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ from functools import partial
12
+ import math
13
+ import logging
14
+ from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch import Tensor
19
+ import torch.utils.checkpoint
20
+ from torch.nn.init import trunc_normal_
21
+
22
+ #from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+ class ConvBlock(nn.Module):
27
+ def __init__(self, channels):
28
+ super(ConvBlock, self).__init__()
29
+
30
+ self.act = nn.ReLU(inplace=True)
31
+ self.conv1 = nn.Conv2d(
32
+ channels,
33
+ channels,
34
+ kernel_size=3,
35
+ stride=1,
36
+ padding=1
37
+ )
38
+ self.norm1 = nn.BatchNorm2d(channels)
39
+ self.conv2 = nn.Conv2d(
40
+ channels,
41
+ channels,
42
+ kernel_size=3,
43
+ stride=1,
44
+ padding=1
45
+ )
46
+ self.norm2 = nn.BatchNorm2d(channels)
47
+
48
+ def forward(self, x):
49
+
50
+ out = self.norm1(x)
51
+ out = self.act(out)
52
+ out = self.conv1(out)
53
+ out = self.norm2(out)
54
+ out = self.act(out)
55
+ out = self.conv2(out)
56
+ return x + out
57
+
58
+ def make_2tuple(x):
59
+ if isinstance(x, tuple):
60
+ assert len(x) == 2
61
+ return x
62
+
63
+ assert isinstance(x, int)
64
+ return (x, x)
65
+
66
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
67
+ if drop_prob == 0.0 or not training:
68
+ return x
69
+ keep_prob = 1 - drop_prob
70
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
71
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
72
+ if keep_prob > 0.0:
73
+ random_tensor.div_(keep_prob)
74
+ output = x * random_tensor
75
+ return output
76
+
77
+ class DropPath(nn.Module):
78
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
79
+
80
+ def __init__(self, drop_prob=None):
81
+ super(DropPath, self).__init__()
82
+ self.drop_prob = drop_prob
83
+
84
+ def forward(self, x):
85
+ return drop_path(x, self.drop_prob, self.training)
86
+
87
+ class LayerScale(nn.Module):
88
+ def __init__(
89
+ self,
90
+ dim: int,
91
+ init_values: Union[float, Tensor] = 1e-5,
92
+ inplace: bool = False,
93
+ ) -> None:
94
+ super().__init__()
95
+ self.inplace = inplace
96
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
97
+
98
+ def forward(self, x: Tensor) -> Tensor:
99
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
100
+
101
+
102
+ class PatchEmbed(nn.Module):
103
+ """
104
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
105
+
106
+ Args:
107
+ img_size: Image size.
108
+ patch_size: Patch token size.
109
+ in_chans: Number of input image channels.
110
+ embed_dim: Number of linear projection output channels.
111
+ norm_layer: Normalization layer.
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ img_size: Union[int, Tuple[int, int]] = 224,
117
+ patch_size: Union[int, Tuple[int, int]] = 16,
118
+ in_chans: int = 3,
119
+ embed_dim: int = 768,
120
+ norm_layer: Optional[Callable] = None,
121
+ flatten_embedding: bool = True,
122
+ ) -> None:
123
+ super().__init__()
124
+
125
+ image_HW = make_2tuple(img_size)
126
+ patch_HW = make_2tuple(patch_size)
127
+ patch_grid_size = (
128
+ image_HW[0] // patch_HW[0],
129
+ image_HW[1] // patch_HW[1],
130
+ )
131
+
132
+ self.img_size = image_HW
133
+ self.patch_size = patch_HW
134
+ self.patches_resolution = patch_grid_size
135
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
136
+
137
+ self.in_chans = in_chans
138
+ self.embed_dim = embed_dim
139
+
140
+ self.flatten_embedding = flatten_embedding
141
+
142
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
143
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
144
+
145
+ def forward(self, x: Tensor) -> Tensor:
146
+ _, _, H, W = x.shape
147
+ patch_H, patch_W = self.patch_size
148
+
149
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
150
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
151
+
152
+ x = self.proj(x) # B C H W
153
+ H, W = x.size(2), x.size(3)
154
+ x = x.flatten(2).transpose(1, 2) # B HW C
155
+ x = self.norm(x)
156
+ if not self.flatten_embedding:
157
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
158
+ return x
159
+
160
+ def flops(self) -> float:
161
+ Ho, Wo = self.patches_resolution
162
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
163
+ if self.norm is not None:
164
+ flops += Ho * Wo * self.embed_dim
165
+ return flops
166
+
167
+ class Mlp(nn.Module):
168
+ def __init__(
169
+ self,
170
+ in_features: int,
171
+ hidden_features: Optional[int] = None,
172
+ out_features: Optional[int] = None,
173
+ act_layer: Callable[..., nn.Module] = nn.GELU,
174
+ drop: float = 0.0,
175
+ bias: bool = True,
176
+ ) -> None:
177
+ super().__init__()
178
+ out_features = out_features or in_features
179
+ hidden_features = hidden_features or in_features
180
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
181
+ self.act = act_layer()
182
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
183
+ self.drop = nn.Dropout(drop)
184
+
185
+ def forward(self, x: Tensor) -> Tensor:
186
+ x = self.fc1(x)
187
+ x = self.act(x)
188
+ x = self.drop(x)
189
+ x = self.fc2(x)
190
+ x = self.drop(x)
191
+ return x
192
+
193
+
194
+ class SwiGLUFFN(nn.Module):
195
+ def __init__(
196
+ self,
197
+ in_features: int,
198
+ hidden_features: Optional[int] = None,
199
+ out_features: Optional[int] = None,
200
+ act_layer: Callable[..., nn.Module] = None,
201
+ drop: float = 0.0,
202
+ bias: bool = True,
203
+ ) -> None:
204
+ super().__init__()
205
+ out_features = out_features or in_features
206
+ hidden_features = hidden_features or in_features
207
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
208
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
209
+
210
+ def forward(self, x: Tensor) -> Tensor:
211
+ x12 = self.w12(x)
212
+ x1, x2 = x12.chunk(2, dim=-1)
213
+ hidden = F.silu(x1) * x2
214
+ return self.w3(hidden)
215
+
216
+
217
+ try:
218
+ from xformers.ops import SwiGLU
219
+ #import numpy.bool
220
+ XFORMERS_AVAILABLE = True
221
+ except ImportError:
222
+ SwiGLU = SwiGLUFFN
223
+ XFORMERS_AVAILABLE = False
224
+
225
+ class SwiGLUFFNFused(SwiGLU):
226
+ def __init__(
227
+ self,
228
+ in_features: int,
229
+ hidden_features: Optional[int] = None,
230
+ out_features: Optional[int] = None,
231
+ act_layer: Callable[..., nn.Module] = None,
232
+ drop: float = 0.0,
233
+ bias: bool = True,
234
+ ) -> None:
235
+ out_features = out_features or in_features
236
+ hidden_features = hidden_features or in_features
237
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
238
+ super().__init__(
239
+ in_features=in_features,
240
+ hidden_features=hidden_features,
241
+ out_features=out_features,
242
+ bias=bias,
243
+ )
244
+
245
+
246
+ try:
247
+ from xformers.ops import memory_efficient_attention, unbind, fmha
248
+ from xformers.components.attention import ScaledDotProduct
249
+ from xformers.components import MultiHeadDispatch
250
+ #import numpy.bool
251
+ XFORMERS_AVAILABLE = True
252
+ except ImportError:
253
+ logger.warning("xFormers not available")
254
+ XFORMERS_AVAILABLE = False
255
+
256
+
257
+ class Attention(nn.Module):
258
+ def __init__(
259
+ self,
260
+ dim: int,
261
+ num_heads: int = 8,
262
+ qkv_bias: bool = False,
263
+ proj_bias: bool = True,
264
+ attn_drop: float = 0.0,
265
+ proj_drop: float = 0.0,
266
+ window_size: int = 0,
267
+ ) -> None:
268
+ super().__init__()
269
+ self.num_heads = num_heads
270
+ head_dim = dim // num_heads
271
+ self.scale = head_dim**-0.5
272
+
273
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
274
+ self.attn_drop = nn.Dropout(attn_drop)
275
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
276
+ self.proj_drop = nn.Dropout(proj_drop)
277
+
278
+ #if not self.training:
279
+ #
280
+ # self.attn = ScaledDotProduct()
281
+ #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
282
+
283
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
284
+ B, N, C = x.shape
285
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
286
+
287
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
288
+ attn = q @ k.transpose(-2, -1)
289
+
290
+ if attn_bias is not None:
291
+ attn = attn + attn_bias[:, :, :N]
292
+
293
+ attn = attn.softmax(dim=-1)
294
+ attn = self.attn_drop(attn)
295
+
296
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
297
+ x = self.proj(x)
298
+ x = self.proj_drop(x)
299
+ return x
300
+
301
+
302
+ class MemEffAttention(Attention):
303
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
304
+ if not XFORMERS_AVAILABLE:
305
+ #if True:
306
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
307
+ return super().forward(x, attn_bias)
308
+
309
+ B, N, C = x.shape
310
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
311
+
312
+ q, k, v = unbind(qkv, 2)
313
+ if attn_bias is not None:
314
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
315
+ else:
316
+ x = memory_efficient_attention(q, k, v)
317
+ x = x.reshape([B, N, C])
318
+
319
+ x = self.proj(x)
320
+ x = self.proj_drop(x)
321
+ return x
322
+
323
+ try:
324
+ from xformers.ops import fmha
325
+ from xformers.ops import scaled_index_add, index_select_cat
326
+ #import numpy.bool
327
+ XFORMERS_AVAILABLE = True
328
+ except ImportError:
329
+ logger.warning("xFormers not available")
330
+ XFORMERS_AVAILABLE = False
331
+
332
+ class Block(nn.Module):
333
+ def __init__(
334
+ self,
335
+ dim: int,
336
+ num_heads: int,
337
+ mlp_ratio: float = 4.0,
338
+ qkv_bias: bool = False,
339
+ proj_bias: bool = True,
340
+ ffn_bias: bool = True,
341
+ drop: float = 0.0,
342
+ attn_drop: float = 0.0,
343
+ init_values = None,
344
+ drop_path: float = 0.0,
345
+ act_layer: Callable[..., nn.Module] = nn.GELU,
346
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
347
+ attn_class: Callable[..., nn.Module] = Attention,
348
+ ffn_layer: Callable[..., nn.Module] = Mlp,
349
+ ) -> None:
350
+ super().__init__()
351
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
352
+ self.norm1 = norm_layer(dim)
353
+ self.attn = attn_class(
354
+ dim,
355
+ num_heads=num_heads,
356
+ qkv_bias=qkv_bias,
357
+ proj_bias=proj_bias,
358
+ attn_drop=attn_drop,
359
+ proj_drop=drop,
360
+ )
361
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
362
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
363
+
364
+ self.norm2 = norm_layer(dim)
365
+ mlp_hidden_dim = int(dim * mlp_ratio)
366
+ self.mlp = ffn_layer(
367
+ in_features=dim,
368
+ hidden_features=mlp_hidden_dim,
369
+ act_layer=act_layer,
370
+ drop=drop,
371
+ bias=ffn_bias,
372
+ )
373
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
374
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
375
+
376
+ self.sample_drop_ratio = drop_path
377
+
378
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
379
+ def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
380
+ return self.ls1(self.attn(self.norm1(x), attn_bias))
381
+
382
+ def ffn_residual_func(x: Tensor) -> Tensor:
383
+ return self.ls2(self.mlp(self.norm2(x)))
384
+
385
+ if self.training and self.sample_drop_ratio > 0.1:
386
+ # the overhead is compensated only for a drop path rate larger than 0.1
387
+ x = drop_add_residual_stochastic_depth(
388
+ x,
389
+ residual_func=attn_residual_func,
390
+ sample_drop_ratio=self.sample_drop_ratio,
391
+ attn_bias=attn_bias
392
+ )
393
+ x = drop_add_residual_stochastic_depth(
394
+ x,
395
+ residual_func=ffn_residual_func,
396
+ sample_drop_ratio=self.sample_drop_ratio,
397
+ )
398
+ elif self.training and self.sample_drop_ratio > 0.0:
399
+ x = x + self.drop_path1(attn_residual_func(x, attn_bias))
400
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
401
+ else:
402
+ x = x + attn_residual_func(x, attn_bias)
403
+ x = x + ffn_residual_func(x)
404
+ return x
405
+
406
+
407
+ def drop_add_residual_stochastic_depth(
408
+ x: Tensor,
409
+ residual_func: Callable[[Tensor], Tensor],
410
+ sample_drop_ratio: float = 0.0, attn_bias=None
411
+ ) -> Tensor:
412
+ # 1) extract subset using permutation
413
+ b, n, d = x.shape
414
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
415
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
416
+ x_subset = x[brange]
417
+
418
+ # 2) apply residual_func to get residual
419
+ residual = residual_func(x_subset, attn_bias)
420
+
421
+ x_flat = x.flatten(1)
422
+ residual = residual.flatten(1)
423
+
424
+ residual_scale_factor = b / sample_subset_size
425
+
426
+ # 3) add the residual
427
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
428
+ return x_plus_residual.view_as(x)
429
+
430
+
431
+ def get_branges_scales(x, sample_drop_ratio=0.0):
432
+ b, n, d = x.shape
433
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
434
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
435
+ residual_scale_factor = b / sample_subset_size
436
+ return brange, residual_scale_factor
437
+
438
+
439
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
440
+ if scaling_vector is None:
441
+ x_flat = x.flatten(1)
442
+ residual = residual.flatten(1)
443
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
444
+ else:
445
+ x_plus_residual = scaled_index_add(
446
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
447
+ )
448
+ return x_plus_residual
449
+
450
+
451
+ attn_bias_cache: Dict[Tuple, Any] = {}
452
+
453
+
454
+ def get_attn_bias_and_cat(x_list, branges=None):
455
+ """
456
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
457
+ """
458
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
459
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
460
+ if all_shapes not in attn_bias_cache.keys():
461
+ seqlens = []
462
+ for b, x in zip(batch_sizes, x_list):
463
+ for _ in range(b):
464
+ seqlens.append(x.shape[1])
465
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
466
+ attn_bias._batch_sizes = batch_sizes
467
+ attn_bias_cache[all_shapes] = attn_bias
468
+
469
+ if branges is not None:
470
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
471
+ else:
472
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
473
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
474
+
475
+ return attn_bias_cache[all_shapes], cat_tensors
476
+
477
+
478
+ def drop_add_residual_stochastic_depth_list(
479
+ x_list: List[Tensor],
480
+ residual_func: Callable[[Tensor, Any], Tensor],
481
+ sample_drop_ratio: float = 0.0,
482
+ scaling_vector=None,
483
+ ) -> Tensor:
484
+ # 1) generate random set of indices for dropping samples in the batch
485
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
486
+ branges = [s[0] for s in branges_scales]
487
+ residual_scale_factors = [s[1] for s in branges_scales]
488
+
489
+ # 2) get attention bias and index+concat the tensors
490
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
491
+
492
+ # 3) apply residual_func to get residual, and split the result
493
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
494
+
495
+ outputs = []
496
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
497
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
498
+ return outputs
499
+
500
+
501
+ class NestedTensorBlock(Block):
502
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
503
+ """
504
+ x_list contains a list of tensors to nest together and run
505
+ """
506
+ assert isinstance(self.attn, MemEffAttention)
507
+
508
+ if self.training and self.sample_drop_ratio > 0.0:
509
+
510
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
511
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
512
+
513
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
514
+ return self.mlp(self.norm2(x))
515
+
516
+ x_list = drop_add_residual_stochastic_depth_list(
517
+ x_list,
518
+ residual_func=attn_residual_func,
519
+ sample_drop_ratio=self.sample_drop_ratio,
520
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
521
+ )
522
+ x_list = drop_add_residual_stochastic_depth_list(
523
+ x_list,
524
+ residual_func=ffn_residual_func,
525
+ sample_drop_ratio=self.sample_drop_ratio,
526
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
527
+ )
528
+ return x_list
529
+ else:
530
+
531
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
532
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
533
+
534
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
535
+ return self.ls2(self.mlp(self.norm2(x)))
536
+
537
+ attn_bias, x = get_attn_bias_and_cat(x_list)
538
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
539
+ x = x + ffn_residual_func(x)
540
+ return attn_bias.split(x)
541
+
542
+ def forward(self, x_or_x_list, attn_bias=None):
543
+ if isinstance(x_or_x_list, Tensor):
544
+ return super().forward(x_or_x_list, attn_bias)
545
+ elif isinstance(x_or_x_list, list):
546
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
547
+ return self.forward_nested(x_or_x_list)
548
+ else:
549
+ raise AssertionError
550
+
551
+
552
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
553
+ if not depth_first and include_root:
554
+ fn(module=module, name=name)
555
+ for child_name, child_module in module.named_children():
556
+ child_name = ".".join((name, child_name)) if name else child_name
557
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
558
+ if depth_first and include_root:
559
+ fn(module=module, name=name)
560
+ return module
561
+
562
+
563
+ class BlockChunk(nn.ModuleList):
564
+ def forward(self, x, others=None):
565
+ for b in self:
566
+ if others == None:
567
+ x = b(x)
568
+ else:
569
+ x = b(x, others)
570
+ return x
571
+
572
+
573
+ class DinoVisionTransformer(nn.Module):
574
+ def __init__(
575
+ self,
576
+ img_size=224,
577
+ patch_size=16,
578
+ in_chans=3,
579
+ embed_dim=768,
580
+ depth=12,
581
+ num_heads=12,
582
+ mlp_ratio=4.0,
583
+ qkv_bias=True,
584
+ ffn_bias=True,
585
+ proj_bias=True,
586
+ drop_path_rate=0.0,
587
+ drop_path_uniform=False,
588
+ #init_values=None, # for layerscale: None or 0 => no layerscale
589
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
590
+ embed_layer=PatchEmbed,
591
+ act_layer=nn.GELU,
592
+ block_fn=NestedTensorBlock,
593
+ ffn_layer="mlp",
594
+ block_chunks=1,
595
+ window_size=37,
596
+ **kwargs
597
+ ):
598
+ """
599
+ Args:
600
+ img_size (int, tuple): input image size
601
+ patch_size (int, tuple): patch size
602
+ in_chans (int): number of input channels
603
+ embed_dim (int): embedding dimension
604
+ depth (int): depth of transformer
605
+ num_heads (int): number of attention heads
606
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
607
+ qkv_bias (bool): enable bias for qkv if True
608
+ proj_bias (bool): enable bias for proj in attn if True
609
+ ffn_bias (bool): enable bias for ffn if True
610
+ drop_path_rate (float): stochastic depth rate
611
+ drop_path_uniform (bool): apply uniform drop rate across blocks
612
+ weight_init (str): weight init scheme
613
+ init_values (float): layer-scale init values
614
+ embed_layer (nn.Module): patch embedding layer
615
+ act_layer (nn.Module): MLP activation layer
616
+ block_fn (nn.Module): transformer block class
617
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
618
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
619
+ """
620
+ super().__init__()
621
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
622
+
623
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
624
+ self.num_tokens = 1
625
+ self.n_blocks = depth
626
+ self.num_heads = num_heads
627
+ self.patch_size = patch_size
628
+ self.window_size = window_size
629
+
630
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
631
+ num_patches = self.patch_embed.num_patches
632
+
633
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
634
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
635
+
636
+ if drop_path_uniform is True:
637
+ dpr = [drop_path_rate] * depth
638
+ else:
639
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
640
+
641
+ if ffn_layer == "mlp":
642
+ logger.info("using MLP layer as FFN")
643
+ ffn_layer = Mlp
644
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
645
+ logger.info("using SwiGLU layer as FFN")
646
+ ffn_layer = SwiGLUFFNFused
647
+ elif ffn_layer == "identity":
648
+ logger.info("using Identity layer as FFN")
649
+
650
+ def f(*args, **kwargs):
651
+ return nn.Identity()
652
+
653
+ ffn_layer = f
654
+ else:
655
+ raise NotImplementedError
656
+
657
+ blocks_list = [
658
+ block_fn(
659
+ dim=embed_dim,
660
+ num_heads=num_heads,
661
+ mlp_ratio=mlp_ratio,
662
+ qkv_bias=qkv_bias,
663
+ proj_bias=proj_bias,
664
+ ffn_bias=ffn_bias,
665
+ drop_path=dpr[i],
666
+ norm_layer=norm_layer,
667
+ act_layer=act_layer,
668
+ ffn_layer=ffn_layer,
669
+ init_values=init_values,
670
+ )
671
+ for i in range(depth)
672
+ ]
673
+ if block_chunks > 0:
674
+ self.chunked_blocks = True
675
+ chunked_blocks = []
676
+ chunksize = depth // block_chunks
677
+ for i in range(0, depth, chunksize):
678
+ # this is to keep the block index consistent if we chunk the block list
679
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
680
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
681
+ else:
682
+ self.chunked_blocks = False
683
+ self.blocks = nn.ModuleList(blocks_list)
684
+
685
+ self.norm = norm_layer(embed_dim)
686
+ self.head = nn.Identity()
687
+
688
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
689
+
690
+ self.init_weights()
691
+
692
+ def init_weights(self):
693
+ trunc_normal_(self.pos_embed, std=0.02)
694
+ nn.init.normal_(self.cls_token, std=1e-6)
695
+ named_apply(init_weights_vit_timm, self)
696
+
697
+ def interpolate_pos_encoding(self, x, w, h):
698
+ previous_dtype = x.dtype
699
+ npatch = x.shape[1] - 1
700
+ N = self.pos_embed.shape[1] - 1
701
+ if npatch == N and w == h:
702
+ return self.pos_embed
703
+ pos_embed = self.pos_embed.float()
704
+ class_pos_embed = pos_embed[:, 0]
705
+ patch_pos_embed = pos_embed[:, 1:]
706
+ dim = x.shape[-1]
707
+ w0 = w // self.patch_size
708
+ h0 = h // self.patch_size
709
+ # we add a small number to avoid floating point error in the interpolation
710
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
711
+ w0, h0 = w0 + 0.1, h0 + 0.1
712
+
713
+ patch_pos_embed = nn.functional.interpolate(
714
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
715
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
716
+ mode="bicubic",
717
+ )
718
+
719
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
720
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
721
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
722
+
723
+ def prepare_tokens_with_masks(self, x, masks=None):
724
+ B, nc, w, h = x.shape
725
+ x = self.patch_embed(x)
726
+ if masks is not None:
727
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
728
+
729
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
730
+ x = x + self.interpolate_pos_encoding(x, w, h)
731
+
732
+ return x
733
+
734
+ def forward_features_list(self, x_list, masks_list):
735
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
736
+ for blk in self.blocks:
737
+ x = blk(x)
738
+
739
+ all_x = x
740
+ output = []
741
+ for x, masks in zip(all_x, masks_list):
742
+ x_norm = self.norm(x)
743
+ output.append(
744
+ {
745
+ "x_norm_clstoken": x_norm[:, 0],
746
+ "x_norm_patchtokens": x_norm[:, 1:],
747
+ "x_prenorm": x,
748
+ "masks": masks,
749
+ }
750
+ )
751
+ return output
752
+
753
+ def forward_features(self, x, masks=None):
754
+ if isinstance(x, list):
755
+ return self.forward_features_list(x, masks)
756
+
757
+ B, C, H, W = x.size()
758
+ pad_h = (self.patch_size - H % self.patch_size)
759
+ pad_w = (self.patch_size - W % self.patch_size)
760
+ if pad_h == self.patch_size:
761
+ pad_h = 0
762
+ if pad_w == self.patch_size:
763
+ pad_w = 0
764
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
765
+ if pad_h + pad_w > 0:
766
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
767
+
768
+ x = self.prepare_tokens_with_masks(x, masks)
769
+
770
+ features = []
771
+ for blk in self.blocks:
772
+ x = blk(x)
773
+ # for idx in range(len(self.blocks[0])):
774
+ # x = self.blocks[0][idx](x)
775
+ # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
776
+ # features.append(x)
777
+
778
+ #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
779
+
780
+ x_norm = self.norm(x)
781
+ # return {
782
+ # "x_norm_clstoken": x_norm[:, 0],
783
+ # "x_norm_patchtokens": x_norm[:, 1:],
784
+ # "x_prenorm": x,
785
+ # "masks": masks,
786
+ # }
787
+ features = []
788
+ features.append(x_norm)
789
+ features.append(x_norm)
790
+ features.append(x_norm)
791
+ features.append(x_norm)
792
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
793
+
794
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
795
+ x = self.prepare_tokens_with_masks(x)
796
+ # If n is an int, take the n last blocks. If it's a list, take them
797
+ output, total_block_len = [], len(self.blocks)
798
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
799
+ for i, blk in enumerate(self.blocks):
800
+ x = blk(x)
801
+ if i in blocks_to_take:
802
+ output.append(x)
803
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
804
+ return output
805
+
806
+ def _get_intermediate_layers_chunked(self, x, n=1):
807
+ x = self.prepare_tokens_with_masks(x)
808
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
809
+ # If n is an int, take the n last blocks. If it's a list, take them
810
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
811
+ for block_chunk in self.blocks:
812
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
813
+ x = blk(x)
814
+ if i in blocks_to_take:
815
+ output.append(x)
816
+ i += 1
817
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
818
+ return output
819
+
820
+ def get_intermediate_layers(
821
+ self,
822
+ x: torch.Tensor,
823
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
824
+ reshape: bool = False,
825
+ return_class_token: bool = False,
826
+ norm=True,
827
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
828
+ if self.chunked_blocks:
829
+ outputs = self._get_intermediate_layers_chunked(x, n)
830
+ else:
831
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
832
+ if norm:
833
+ outputs = [self.norm(out) for out in outputs]
834
+ class_tokens = [out[:, 0] for out in outputs]
835
+ outputs = [out[:, 1:] for out in outputs]
836
+ if reshape:
837
+ B, _, w, h = x.shape
838
+ outputs = [
839
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
840
+ for out in outputs
841
+ ]
842
+ if return_class_token:
843
+ return tuple(zip(outputs, class_tokens))
844
+ return tuple(outputs)
845
+
846
+ def forward(self, *args, is_training=False, **kwargs):
847
+ ret = self.forward_features(*args, **kwargs)
848
+ return ret
849
+ # if is_training:
850
+ # return ret
851
+ # else:
852
+ # return self.head(ret["x_norm_clstoken"])
853
+
854
+
855
+ class PosConv(nn.Module):
856
+ # PEG from https://arxiv.org/abs/2102.10882
857
+ def __init__(self, in_chans, embed_dim=768, stride=1):
858
+ super(PosConv, self).__init__()
859
+ self.proj = nn.Sequential(
860
+ nn.Conv2d(in_chans, embed_dim, 37, stride, 18, bias=True, groups=embed_dim),
861
+ )
862
+ self.stride = stride
863
+
864
+ def forward(self, x, size):
865
+ B, N, C = x.shape
866
+ cnn_feat_token = x.transpose(1, 2).view(B, C, *size)
867
+ x = self.proj(cnn_feat_token)
868
+ if self.stride == 1:
869
+ x += cnn_feat_token
870
+ x = x.flatten(2).transpose(1, 2)
871
+ return x
872
+
873
+ #def no_weight_decay(self):
874
+ #return ['proj.%d.weight' % i for i in range(4)]
875
+
876
+ class DinoWindowVisionTransformer(nn.Module):
877
+ def __init__(
878
+ self,
879
+ img_size=224,
880
+ patch_size=16,
881
+ in_chans=3,
882
+ embed_dim=768,
883
+ depth=12,
884
+ num_heads=12,
885
+ mlp_ratio=4.0,
886
+ qkv_bias=True,
887
+ ffn_bias=True,
888
+ proj_bias=True,
889
+ drop_path_rate=0.0,
890
+ drop_path_uniform=False,
891
+ #init_values=None, # for layerscale: None or 0 => no layerscale
892
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
893
+ embed_layer=PatchEmbed,
894
+ act_layer=nn.GELU,
895
+ block_fn=NestedTensorBlock,
896
+ ffn_layer="mlp",
897
+ block_chunks=1,
898
+ window_size=7,
899
+ **kwargs
900
+ ):
901
+ """
902
+ Args:
903
+ img_size (int, tuple): input image size
904
+ patch_size (int, tuple): patch size
905
+ in_chans (int): number of input channels
906
+ embed_dim (int): embedding dimension
907
+ depth (int): depth of transformer
908
+ num_heads (int): number of attention heads
909
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
910
+ qkv_bias (bool): enable bias for qkv if True
911
+ proj_bias (bool): enable bias for proj in attn if True
912
+ ffn_bias (bool): enable bias for ffn if True
913
+ drop_path_rate (float): stochastic depth rate
914
+ drop_path_uniform (bool): apply uniform drop rate across blocks
915
+ weight_init (str): weight init scheme
916
+ init_values (float): layer-scale init values
917
+ embed_layer (nn.Module): patch embedding layer
918
+ act_layer (nn.Module): MLP activation layer
919
+ block_fn (nn.Module): transformer block class
920
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
921
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
922
+ """
923
+ super().__init__()
924
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
925
+
926
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
927
+ self.num_tokens = 1
928
+ self.n_blocks = depth
929
+ self.num_heads = num_heads
930
+ self.patch_size = patch_size
931
+
932
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
933
+ num_patches = self.patch_embed.num_patches
934
+
935
+ #self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
936
+ #self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
937
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
938
+
939
+ self.pos_conv = PosConv(self.embed_dim, self.embed_dim)
940
+
941
+ self.window_size = window_size
942
+ #self.conv_block = nn.ModuleList([ConvBlock(embed_dim) for i in range(4)])
943
+ #self.conv_block = nn.ModuleList([nn.Identity() for i in range(4)])
944
+
945
+ if drop_path_uniform is True:
946
+ dpr = [drop_path_rate] * depth
947
+ else:
948
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
949
+
950
+ if ffn_layer == "mlp":
951
+ logger.info("using MLP layer as FFN")
952
+ ffn_layer = Mlp
953
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
954
+ logger.info("using SwiGLU layer as FFN")
955
+ ffn_layer = SwiGLUFFNFused
956
+ elif ffn_layer == "identity":
957
+ logger.info("using Identity layer as FFN")
958
+
959
+ def f(*args, **kwargs):
960
+ return nn.Identity()
961
+
962
+ ffn_layer = f
963
+ else:
964
+ raise NotImplementedError
965
+
966
+ blocks_list = [
967
+ block_fn(
968
+ dim=embed_dim,
969
+ num_heads=num_heads,
970
+ mlp_ratio=mlp_ratio,
971
+ qkv_bias=qkv_bias,
972
+ proj_bias=proj_bias,
973
+ ffn_bias=ffn_bias,
974
+ drop_path=dpr[i],
975
+ norm_layer=norm_layer,
976
+ act_layer=act_layer,
977
+ ffn_layer=ffn_layer,
978
+ init_values=init_values,
979
+ )
980
+ for i in range(depth)
981
+ ]
982
+ if block_chunks > 0:
983
+ self.chunked_blocks = True
984
+ chunked_blocks = []
985
+ chunksize = depth // block_chunks
986
+ for i in range(0, depth, chunksize):
987
+ # this is to keep the block index consistent if we chunk the block list
988
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
989
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
990
+ else:
991
+ self.chunked_blocks = False
992
+ self.blocks = nn.ModuleList(blocks_list)
993
+
994
+ self.norm = norm_layer(embed_dim)
995
+ self.head = nn.Identity()
996
+
997
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
998
+
999
+ self.nh = -1
1000
+ self.nw = -1
1001
+ try:
1002
+ H = cfg.data_basic['crop_size'][0]
1003
+ W = cfg.data_basic['crop_size'][1]
1004
+ pad_h = (self.patch_size - H % self.patch_size)
1005
+ pad_w = (self.patch_size - W % self.patch_size)
1006
+ if pad_h == self.patch_size:
1007
+ pad_h = 0
1008
+ if pad_w == self.patch_size:
1009
+ pad_w = 0
1010
+ self.nh = (H + pad_h) // self.patch_size
1011
+ self.nw = (W + pad_w) // self.patch_size
1012
+ self.prepare_attn_bias((self.nh, self.nw))
1013
+ except:
1014
+ pass
1015
+ self.init_weights()
1016
+
1017
+ self.total_step = 10000 # For PE -> GPE transfer
1018
+ self.start_step = 2000
1019
+ self.current_step = 20000
1020
+
1021
+ def init_weights(self):
1022
+ #trunc_normal_(self.pos_embed, std=0.02)
1023
+ #nn.init.normal_(self.cls_token, std=1e-6)
1024
+ named_apply(init_weights_vit_timm, self)
1025
+ for i in range(4):
1026
+ try:
1027
+ nn.init.constant_(self.conv_block[i].conv2.weight, 0.0)
1028
+ except:
1029
+ pass
1030
+
1031
+ def interpolate_pos_encoding(self, x, w, h):
1032
+ previous_dtype = x.dtype
1033
+ #npatch = x.shape[1] - 1
1034
+ #N = self.pos_embed.shape[1] - 1
1035
+ npatch = x.shape[1]
1036
+ N = self.pos_embed.shape[1]
1037
+ if npatch == N and w == h:
1038
+ return self.pos_embed
1039
+ pos_embed = self.pos_embed.float()
1040
+ #class_pos_embed = pos_embed[:, 0]
1041
+ #patch_pos_embed = pos_embed[:, 1:]
1042
+ patch_pos_embed = pos_embed
1043
+ dim = x.shape[-1]
1044
+ w0 = w // self.patch_size
1045
+ h0 = h // self.patch_size
1046
+ # we add a small number to avoid floating point error in the interpolation
1047
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
1048
+ w0, h0 = w0 + 0.1, h0 + 0.1
1049
+
1050
+ patch_pos_embed = nn.functional.interpolate(
1051
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
1052
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
1053
+ mode="bicubic",
1054
+ )
1055
+
1056
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
1057
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
1058
+ return patch_pos_embed.to(previous_dtype)
1059
+ #return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
1060
+
1061
+ def window_partition(self, x: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False) -> Tuple[torch.Tensor, Tuple[int, int]]:
1062
+ """
1063
+ Partition into non-overlapping windows with padding if needed.
1064
+ Args:
1065
+ x (tensor): input tokens with [B, H, W, C].
1066
+ window_size (int): window size.
1067
+
1068
+ Returns:
1069
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
1070
+ (Hp, Wp): padded height and width before partition
1071
+ """
1072
+ if conv_feature == False:
1073
+ B, N, C = x.shape
1074
+ H, W = hw[0], hw[1]
1075
+
1076
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
1077
+
1078
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
1079
+ else:
1080
+ B, C, H, W = x.shape
1081
+
1082
+ x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
1083
+
1084
+ windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size * window_size, C)
1085
+
1086
+ #y = torch.cat((x_cls, windows), dim=1)
1087
+ return windows #, (Hp, Wp)
1088
+
1089
+
1090
+ def window_unpartition(self,
1091
+ windows: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False
1092
+ ) -> torch.Tensor:
1093
+ """
1094
+ Window unpartition into original sequences and removing padding.
1095
+ Args:
1096
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
1097
+ window_size (int): window size.
1098
+ pad_hw (Tuple): padded height and width (Hp, Wp).
1099
+ hw (Tuple): original height and width (H, W) before padding.
1100
+
1101
+ Returns:
1102
+ x: unpartitioned sequences with [B, H, W, C].
1103
+ """
1104
+ H, W = hw
1105
+
1106
+ B = windows.shape[0] // (H * W // window_size // window_size)
1107
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
1108
+
1109
+ if conv_feature == False:
1110
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp * Wp, -1)
1111
+ else:
1112
+ C = windows.shape[-1]
1113
+ x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W)
1114
+
1115
+ # if Hp > H or Wp > W:
1116
+ # x = x[:, :H, :W, :].contiguous()
1117
+ return x
1118
+
1119
+ def prepare_tokens_with_masks(self, x, masks=None, step=-1):
1120
+ B, nc, w, h = x.shape
1121
+ x = self.patch_embed(x)
1122
+ if masks is not None:
1123
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
1124
+
1125
+ #x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
1126
+ if step == -1:
1127
+ step = self.current_step
1128
+ else:
1129
+ self.current_step = step
1130
+
1131
+ if step < self.start_step:
1132
+ coef = 0.0
1133
+ elif step < self.total_step:
1134
+ coef = (step - self.start_step) / (self.total_step - self.start_step)
1135
+ else:
1136
+ coef = 1.0
1137
+
1138
+ x = x + (1 - coef) * self.interpolate_pos_encoding(x, w, h) + coef * self.pos_conv(x, (self.nh, self.nw))
1139
+
1140
+ return x
1141
+
1142
+ def prepare_attn_bias(self, shape):
1143
+ window_size = self.window_size
1144
+ if window_size <= 0:
1145
+ return
1146
+
1147
+ import xformers.components.attention.attention_patterns as AP
1148
+
1149
+ nh, nw = shape
1150
+ radius = (window_size-1)//2
1151
+ mask_ori = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
1152
+
1153
+ pad = (8 - (nh * nw) % 8)
1154
+ if pad == 8:
1155
+ pad = 0
1156
+ mask_pad = nn.functional.pad(mask_ori, (0, pad)).contiguous()
1157
+ if pad > 0:
1158
+ mask = mask_pad[:, :-pad].view(nh, nw, nh, nw)
1159
+ else:
1160
+ mask = mask_pad[:, :].view(nh, nw, nh, nw)
1161
+
1162
+ # angle
1163
+ mask[:radius+1, :radius+1, :window_size, :window_size] = True
1164
+ mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
1165
+ mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
1166
+ mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
1167
+
1168
+ # edge
1169
+ mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
1170
+ mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
1171
+ mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
1172
+ mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
1173
+
1174
+ mask = mask.view(nh*nw, nh*nw)
1175
+ bias_pad = torch.log(mask_pad)
1176
+ #bias = bias_pad[:, :-pad]
1177
+ self.register_buffer('attn_bias', bias_pad)
1178
+
1179
+ return bias_pad
1180
+
1181
+ def forward_features_list(self, x_list, masks_list):
1182
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
1183
+ for blk in self.blocks:
1184
+ x = blk(x)
1185
+
1186
+ all_x = x
1187
+ output = []
1188
+ for x, masks in zip(all_x, masks_list):
1189
+ x_norm = self.norm(x)
1190
+ output.append(
1191
+ {
1192
+ "x_norm_clstoken": x_norm[:, 0],
1193
+ "x_norm_patchtokens": x_norm[:, 1:],
1194
+ "x_prenorm": x,
1195
+ "masks": masks,
1196
+ }
1197
+ )
1198
+ return output
1199
+
1200
+ def forward_features(self, x, masks=None, **kwargs):
1201
+ if isinstance(x, list):
1202
+ return self.forward_features_list(x, masks)
1203
+
1204
+ B, C, H, W = x.size()
1205
+ pad_h = (self.patch_size - H % self.patch_size)
1206
+ pad_w = (self.patch_size - W % self.patch_size)
1207
+ if pad_h == self.patch_size:
1208
+ pad_h = 0
1209
+ if pad_w == self.patch_size:
1210
+ pad_w = 0
1211
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
1212
+ if pad_h + pad_w > 0:
1213
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
1214
+
1215
+ nh = (H+pad_h)//self.patch_size
1216
+ nw = (W+pad_w)//self.patch_size
1217
+
1218
+ if self.window_size > 0:
1219
+ if nh == self.nh and nw == self.nw:
1220
+ attn_bias = self.attn_bias
1221
+ else:
1222
+ attn_bias = self.prepare_attn_bias(((H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size))
1223
+ self.nh = nh
1224
+ self.nw = nw
1225
+ attn_bias = attn_bias.unsqueeze(0).repeat(B * self.num_heads, 1, 1)
1226
+ else:
1227
+ attn_bias = None
1228
+
1229
+ x = self.prepare_tokens_with_masks(x, masks)
1230
+ #x = self.patch_embed(x)
1231
+
1232
+ features = []
1233
+ #x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
1234
+ for blk in self.blocks:
1235
+ x = blk(x, attn_bias)
1236
+ #x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
1237
+
1238
+ # for idx in range(len(self.blocks[0])):
1239
+ # x = self.blocks[0][idx](x, attn_bias)
1240
+
1241
+ # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
1242
+ # x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
1243
+ # x = self.conv_block[idx // (len(self.blocks[0]) // 4)](x)
1244
+ # if idx + 1 != len(self.blocks[0]):
1245
+ # x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
1246
+ # else:
1247
+ # b, c, h, w = x.size()
1248
+ # x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, c)
1249
+ #features.append(x)
1250
+
1251
+ #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
1252
+
1253
+ x_norm = self.norm(x)
1254
+ # return {
1255
+ # "x_norm_clstoken": x_norm[:, 0],
1256
+ # "x_norm_patchtokens": x_norm[:, 1:],
1257
+ # "x_prenorm": x,
1258
+ # "masks": masks,
1259
+ # }
1260
+ features = []
1261
+ features.append(x_norm)
1262
+ features.append(x_norm)
1263
+ features.append(x_norm)
1264
+ features.append(x_norm)
1265
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
1266
+
1267
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
1268
+ x = self.prepare_tokens_with_masks(x)
1269
+ # If n is an int, take the n last blocks. If it's a list, take them
1270
+ output, total_block_len = [], len(self.blocks)
1271
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
1272
+ for i, blk in enumerate(self.blocks):
1273
+ x = blk(x)
1274
+ if i in blocks_to_take:
1275
+ output.append(x)
1276
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
1277
+ return output
1278
+
1279
+ def _get_intermediate_layers_chunked(self, x, n=1):
1280
+ x = self.prepare_tokens_with_masks(x)
1281
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
1282
+ # If n is an int, take the n last blocks. If it's a list, take them
1283
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
1284
+ for block_chunk in self.blocks:
1285
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
1286
+ x = blk(x)
1287
+ if i in blocks_to_take:
1288
+ output.append(x)
1289
+ i += 1
1290
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
1291
+ return output
1292
+
1293
+ def get_intermediate_layers(
1294
+ self,
1295
+ x: torch.Tensor,
1296
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
1297
+ reshape: bool = False,
1298
+ return_class_token: bool = False,
1299
+ norm=True,
1300
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
1301
+ if self.chunked_blocks:
1302
+ outputs = self._get_intermediate_layers_chunked(x, n)
1303
+ else:
1304
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
1305
+ if norm:
1306
+ outputs = [self.norm(out) for out in outputs]
1307
+ class_tokens = [out[:, 0] for out in outputs]
1308
+ outputs = [out[:, 1:] for out in outputs]
1309
+ if reshape:
1310
+ B, _, w, h = x.shape
1311
+ outputs = [
1312
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
1313
+ for out in outputs
1314
+ ]
1315
+ if return_class_token:
1316
+ return tuple(zip(outputs, class_tokens))
1317
+ return tuple(outputs)
1318
+
1319
+ def forward(self, *args, is_training=False, **kwargs):
1320
+ ret = self.forward_features(*args, **kwargs)
1321
+ return ret
1322
+ # if is_training:
1323
+ # return ret
1324
+ # else:
1325
+ # return self.head(ret["x_norm_clstoken"])
1326
+
1327
+
1328
+
1329
+
1330
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
1331
+ """ViT weight initialization, original timm impl (for reproducibility)"""
1332
+ if isinstance(module, nn.Linear):
1333
+ trunc_normal_(module.weight, std=0.02)
1334
+ if module.bias is not None:
1335
+ nn.init.zeros_(module.bias)
1336
+
1337
+
1338
+ def vit_small(patch_size=14, **kwargs):
1339
+ model = DinoVisionTransformer(
1340
+ patch_size=patch_size,
1341
+ embed_dim=384,
1342
+ depth=12,
1343
+ num_heads=6,
1344
+ mlp_ratio=4,
1345
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
1346
+ **kwargs,
1347
+ )
1348
+ return model
1349
+
1350
+
1351
+ def vit_base(patch_size=14, **kwargs):
1352
+ model = DinoWindowVisionTransformer(
1353
+ patch_size=patch_size,
1354
+ embed_dim=768,
1355
+ depth=12,
1356
+ num_heads=12,
1357
+ mlp_ratio=4,
1358
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
1359
+ **kwargs,
1360
+ )
1361
+ return model
1362
+
1363
+
1364
+ def vit_large(patch_size=14, checkpoint=None, **kwargs):
1365
+ model = DinoVisionTransformer(
1366
+ img_size = 518,
1367
+ patch_size=patch_size,
1368
+ embed_dim=1024,
1369
+ depth=24,
1370
+ num_heads=16,
1371
+ mlp_ratio=4,
1372
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
1373
+ **kwargs,
1374
+ )
1375
+
1376
+ if checkpoint is not None:
1377
+ with open(checkpoint, "rb") as f:
1378
+ state_dict = torch.load(f)
1379
+ try:
1380
+ model.load_state_dict(state_dict, strict=True)
1381
+ except:
1382
+ new_state_dict = {}
1383
+ for key, value in state_dict.items():
1384
+ if 'blocks' in key:
1385
+ key_new = 'blocks.0' + key[len('blocks'):]
1386
+ else:
1387
+ key_new = key
1388
+ new_state_dict[key_new] = value
1389
+
1390
+ model.load_state_dict(new_state_dict, strict=True)
1391
+ #del model.norm
1392
+ del model.mask_token
1393
+ return model
1394
+
1395
+ # model = DinoWindowVisionTransformer(
1396
+ # img_size = 518,
1397
+ # patch_size=patch_size,
1398
+ # embed_dim=1024,
1399
+ # depth=24,
1400
+ # num_heads=16,
1401
+ # mlp_ratio=4,
1402
+ # block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
1403
+ # window_size=37,
1404
+ # **kwargs,
1405
+ # )
1406
+
1407
+ # if checkpoint is not None:
1408
+ # with open(checkpoint, "rb") as f:
1409
+ # state_dict = torch.load(f)
1410
+ # try:
1411
+ # model.load_state_dict(state_dict, strict=True)
1412
+ # except:
1413
+ # new_state_dict = {}
1414
+ # for key, value in state_dict.items():
1415
+ # if 'blocks' in key:
1416
+ # key_new = 'blocks.0' + key[len('blocks'):]
1417
+ # else:
1418
+ # key_new = key
1419
+ # if 'pos_embed' in key:
1420
+ # value = value[:, 1:, :]
1421
+ # new_state_dict[key_new] = value
1422
+
1423
+ # model.load_state_dict(new_state_dict, strict=False)
1424
+ # #del model.norm
1425
+ # del model.mask_token
1426
+ return model
1427
+
1428
+
1429
+ def vit_giant2(patch_size=16, **kwargs):
1430
+ """
1431
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
1432
+ """
1433
+ model = DinoVisionTransformer(
1434
+ patch_size=patch_size,
1435
+ embed_dim=1536,
1436
+ depth=40,
1437
+ num_heads=24,
1438
+ mlp_ratio=4,
1439
+ block_fn=partial(Block, attn_class=MemEffAttention),
1440
+ **kwargs,
1441
+ )
1442
+ return model
1443
+
1444
+ if __name__ == '__main__':
1445
+ try:
1446
+ from mmcv.utils import Config
1447
+ except:
1448
+ from mmengine import Config
1449
+
1450
+ #rgb = torch.rand((2, 3, 518, 518)).cuda()
1451
+
1452
+ #cfg.data_basic['crop_size']['0']
1453
+ #cfg.data_basic['crop_size']['1']
1454
+ cfg = Config.fromfile('/cpfs01/user/mu.hu/monodepth/mono/configs/HourglassDecoder/pub12.convlarge.0.3_150.py')
1455
+
1456
+ #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
1457
+ rgb = torch.zeros(1, 3, 1400, 1680).cuda()
1458
+ model = vit_large(checkpoint="/cpfs02/shared/public/custom/group_local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth", kwarg=cfg).cuda()
1459
+
1460
+ #import timm
1461
+ #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
1462
+ #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
1463
+
1464
+ out1 = model(rgb)
1465
+ #out2 = model2(rgb)
1466
+ temp = 0
1467
+
1468
+
1469
+
1470
+ # import time
1471
+ # window_size = 37
1472
+ # def prepare_window_masks(shape):
1473
+ # if window_size <= 0:
1474
+ # return None
1475
+ # import xformers.components.attention.attention_patterns as AP
1476
+
1477
+ # B, nh, nw, _, _ = shape
1478
+ # radius = (window_size-1)//2
1479
+ # #time0 = time.time()
1480
+ # d = AP.local_nd_distance(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
1481
+ # #mask = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
1482
+ # # mask = mask.view(nh, nw, nh, nw)
1483
+ # # #time1 = time.time() - time0
1484
+
1485
+ # # # angle
1486
+ # # mask[:radius+1, :radius+1, :window_size, :window_size] = True
1487
+ # # mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
1488
+ # # mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
1489
+ # # mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
1490
+ # # time2 = time.time() - time0 - time1
1491
+
1492
+ # # # edge
1493
+ # # mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
1494
+ # # mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
1495
+ # # mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
1496
+ # # mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
1497
+ # # time3 = time.time() - time0 - time2
1498
+ # # print(time1, time2, time3)
1499
+
1500
+ # # return mask.view(nw*nw, nh*nw).unsqueeze(0).repeat(B, 1)
1501
+
1502
+ # shape = (1, 55, 55, None, None)
1503
+ # mask = prepare_window_masks(shape)
1504
+ # # temp = 1
mono/model/backbones/ViT_DINO_reg.py ADDED
@@ -0,0 +1,1293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ from functools import partial
12
+ import math
13
+ import logging
14
+ from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch import Tensor
19
+ import torch.utils.checkpoint
20
+ from torch.nn.init import trunc_normal_
21
+ import torch.nn.init
22
+ import torch.nn.functional as F
23
+
24
+ #from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
25
+
26
+ logger = logging.getLogger("dinov2")
27
+
28
+ # SSF finetuning originally by dongzelian
29
+ def init_ssf_scale_shift(dim):
30
+ scale = nn.Parameter(torch.ones(dim))
31
+ shift = nn.Parameter(torch.zeros(dim))
32
+
33
+ nn.init.normal_(scale, mean=1, std=.02)
34
+ nn.init.normal_(shift, std=.02)
35
+
36
+ return scale, shift
37
+
38
+ def ssf_ada(x, scale, shift):
39
+ assert scale.shape == shift.shape
40
+ if x.shape[-1] == scale.shape[0]:
41
+ return x * scale + shift
42
+ elif x.shape[1] == scale.shape[0]:
43
+ return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
44
+ else:
45
+ raise ValueError('the input tensor shape does not match the shape of the scale factor.')
46
+
47
+ # LoRA finetuning originally by edwardjhu
48
+ class LoRALayer():
49
+ def __init__(
50
+ self,
51
+ r: int,
52
+ lora_alpha: int,
53
+ lora_dropout: float,
54
+ merge_weights: bool,
55
+ ):
56
+ self.r = r
57
+ self.lora_alpha = lora_alpha
58
+ # Optional dropout
59
+ if lora_dropout > 0.:
60
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
61
+ else:
62
+ self.lora_dropout = lambda x: x
63
+ # Mark the weight as unmerged
64
+ self.merged = False
65
+ self.merge_weights = merge_weights
66
+
67
+ class LoRALinear(nn.Linear, LoRALayer):
68
+ # LoRA implemented in a dense layer
69
+ def __init__(
70
+ self,
71
+ in_features: int,
72
+ out_features: int,
73
+ r: int = 0,
74
+ lora_alpha: int = 1,
75
+ lora_dropout: float = 0.,
76
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
77
+ merge_weights: bool = True,
78
+ **kwargs
79
+ ):
80
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
81
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
82
+ merge_weights=merge_weights)
83
+
84
+ self.fan_in_fan_out = fan_in_fan_out
85
+ # Actual trainable parameters
86
+ if r > 0:
87
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
88
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
89
+ self.scaling = self.lora_alpha / self.r
90
+ # Freezing the pre-trained weight matrix
91
+ self.weight.requires_grad = False
92
+ self.reset_parameters()
93
+ if fan_in_fan_out:
94
+ self.weight.data = self.weight.data.transpose(0, 1)
95
+
96
+ def reset_parameters(self):
97
+ #nn.Linear.reset_parameters(self)
98
+ if hasattr(self, 'lora_A'):
99
+ # initialize B the same way as the default for nn.Linear and A to zero
100
+ # this is different than what is described in the paper but should not affect performance
101
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
102
+ nn.init.zeros_(self.lora_B)
103
+
104
+ # def train(self, mode: bool = True):
105
+ # def T(w):
106
+ # return w.transpose(0, 1) if self.fan_in_fan_out else w
107
+ # nn.Linear.train(self, mode)
108
+ # if mode:
109
+ # if self.merge_weights and self.merged:
110
+ # # Make sure that the weights are not merged
111
+ # if self.r > 0:
112
+ # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
113
+ # self.merged = False
114
+ # else:
115
+ # if self.merge_weights and not self.merged:
116
+ # # Merge the weights and mark it
117
+ # if self.r > 0:
118
+ # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
119
+ # self.merged = True
120
+
121
+ def forward(self, x: torch.Tensor):
122
+ def T(w):
123
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
124
+ if self.r > 0 and not self.merged:
125
+ result = F.linear(x, T(self.weight), bias=self.bias)
126
+ result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
127
+ return result
128
+ else:
129
+ return F.linear(x, T(self.weight), bias=self.bias)
130
+
131
+
132
+
133
+ def make_2tuple(x):
134
+ if isinstance(x, tuple):
135
+ assert len(x) == 2
136
+ return x
137
+
138
+ assert isinstance(x, int)
139
+ return (x, x)
140
+
141
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
142
+ if drop_prob == 0.0 or not training:
143
+ return x
144
+ keep_prob = 1 - drop_prob
145
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
146
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
147
+ if keep_prob > 0.0:
148
+ random_tensor.div_(keep_prob)
149
+ output = x * random_tensor
150
+ return output
151
+
152
+ class DropPath(nn.Module):
153
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
154
+
155
+ def __init__(self, drop_prob=None):
156
+ super(DropPath, self).__init__()
157
+ self.drop_prob = drop_prob
158
+
159
+ def forward(self, x):
160
+ return drop_path(x, self.drop_prob, self.training)
161
+
162
+ class LayerScale(nn.Module):
163
+ def __init__(
164
+ self,
165
+ dim: int,
166
+ init_values: Union[float, Tensor] = 1e-5,
167
+ inplace: bool = False,
168
+ ) -> None:
169
+ super().__init__()
170
+ self.inplace = inplace
171
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
172
+
173
+ def forward(self, x: Tensor) -> Tensor:
174
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
175
+
176
+
177
+ class PatchEmbed(nn.Module):
178
+ """
179
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
180
+
181
+ Args:
182
+ img_size: Image size.
183
+ patch_size: Patch token size.
184
+ in_chans: Number of input image channels.
185
+ embed_dim: Number of linear projection output channels.
186
+ norm_layer: Normalization layer.
187
+ """
188
+
189
+ def __init__(
190
+ self,
191
+ img_size: Union[int, Tuple[int, int]] = 224,
192
+ patch_size: Union[int, Tuple[int, int]] = 16,
193
+ in_chans: int = 3,
194
+ embed_dim: int = 768,
195
+ norm_layer: Optional[Callable] = None,
196
+ flatten_embedding: bool = True,
197
+ tuning_mode: Optional[str] = None
198
+ ) -> None:
199
+ super().__init__()
200
+
201
+ image_HW = make_2tuple(img_size)
202
+ patch_HW = make_2tuple(patch_size)
203
+ patch_grid_size = (
204
+ image_HW[0] // patch_HW[0],
205
+ image_HW[1] // patch_HW[1],
206
+ )
207
+
208
+ self.img_size = image_HW
209
+ self.patch_size = patch_HW
210
+ self.patches_resolution = patch_grid_size
211
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
212
+
213
+ self.in_chans = in_chans
214
+ self.embed_dim = embed_dim
215
+
216
+ self.flatten_embedding = flatten_embedding
217
+
218
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
219
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
220
+
221
+ if tuning_mode != None:
222
+ self.tuning_mode = tuning_mode
223
+ if tuning_mode == 'ssf':
224
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)
225
+ else:
226
+ pass
227
+ #raise NotImplementedError()
228
+ else:
229
+ self.tuning_mode = None
230
+
231
+ def forward(self, x: Tensor) -> Tensor:
232
+ _, _, H, W = x.shape
233
+ patch_H, patch_W = self.patch_size
234
+
235
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
236
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
237
+
238
+ x = self.proj(x) # B C H W
239
+ H, W = x.size(2), x.size(3)
240
+ x = x.flatten(2).transpose(1, 2) # B HW C
241
+ x = self.norm(x)
242
+ if self.tuning_mode == 'ssf':
243
+ x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
244
+ if not self.flatten_embedding:
245
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
246
+ return x
247
+
248
+ def flops(self) -> float:
249
+ Ho, Wo = self.patches_resolution
250
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
251
+ if self.norm is not None:
252
+ flops += Ho * Wo * self.embed_dim
253
+ return flops
254
+
255
+ class Mlp(nn.Module):
256
+ def __init__(
257
+ self,
258
+ in_features: int,
259
+ hidden_features: Optional[int] = None,
260
+ out_features: Optional[int] = None,
261
+ act_layer: Callable[..., nn.Module] = nn.GELU,
262
+ drop: float = 0.0,
263
+ bias: bool = True,
264
+ tuning_mode: Optional[int] = None
265
+ ) -> None:
266
+ super().__init__()
267
+ out_features = out_features or in_features
268
+ hidden_features = hidden_features or in_features
269
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
270
+ self.act = act_layer()
271
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
272
+ self.drop = nn.Dropout(drop)
273
+
274
+ if tuning_mode != None:
275
+ self.tuning_mode = tuning_mode
276
+ if tuning_mode == 'ssf':
277
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)
278
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)
279
+ else:
280
+ pass
281
+ #raise NotImplementedError()
282
+ else:
283
+ self.tuning_mode = None
284
+
285
+ def forward(self, x: Tensor) -> Tensor:
286
+ x = self.fc1(x)
287
+ if self.tuning_mode == 'ssf':
288
+ x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
289
+
290
+ x = self.act(x)
291
+ x = self.drop(x)
292
+ x = self.fc2(x)
293
+ if self.tuning_mode == 'ssf':
294
+ x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
295
+
296
+ x = self.drop(x)
297
+ return x
298
+
299
+
300
+ class SwiGLUFFN(nn.Module):
301
+ def __init__(
302
+ self,
303
+ in_features: int,
304
+ hidden_features: Optional[int] = None,
305
+ out_features: Optional[int] = None,
306
+ act_layer: Callable[..., nn.Module] = None,
307
+ drop: float = 0.0,
308
+ bias: bool = True,
309
+ tuning_mode: Optional[int] = None
310
+ ) -> None:
311
+ super().__init__()
312
+ out_features = out_features or in_features
313
+ hidden_features = hidden_features or in_features
314
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
315
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
316
+
317
+ if tuning_mode != None:
318
+ self.tuning_mode = tuning_mode
319
+ if tuning_mode == 'ssf':
320
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(2 * hidden_features)
321
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)
322
+ else:
323
+ pass
324
+ #raise NotImplementedError()
325
+ else:
326
+ self.tuning_mode = None
327
+
328
+
329
+ def forward(self, x: Tensor) -> Tensor:
330
+ x12 = self.w12(x)
331
+ if self.tuning_mode == 'ssf':
332
+ x12 = ssf_ada(x12, self.ssf_scale_1, self.ssf_shift_1)
333
+
334
+ x1, x2 = x12.chunk(2, dim=-1)
335
+ hidden = F.silu(x1) * x2
336
+ out = self.w3(hidden)
337
+
338
+ if self.tuning_mode == 'ssf':
339
+ out = ssf_ada(out, self.ssf_scale_2, self.ssf_scale_2)
340
+
341
+ return out
342
+
343
+
344
+ try:
345
+ from xformers.ops import SwiGLU
346
+ #import numpy.bool
347
+ XFORMERS_AVAILABLE = True
348
+ except ImportError:
349
+ SwiGLU = SwiGLUFFN
350
+ XFORMERS_AVAILABLE = False
351
+
352
+ class SwiGLUFFNFused(SwiGLU):
353
+ def __init__(
354
+ self,
355
+ in_features: int,
356
+ hidden_features: Optional[int] = None,
357
+ out_features: Optional[int] = None,
358
+ act_layer: Callable[..., nn.Module] = None,
359
+ drop: float = 0.0,
360
+ bias: bool = True,
361
+ ) -> None:
362
+ out_features = out_features or in_features
363
+ hidden_features = hidden_features or in_features
364
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
365
+ super().__init__(
366
+ in_features=in_features,
367
+ hidden_features=hidden_features,
368
+ out_features=out_features,
369
+ bias=bias,
370
+ )
371
+
372
+
373
+ try:
374
+ from xformers.ops import memory_efficient_attention, unbind, fmha
375
+ from xformers.components.attention import ScaledDotProduct
376
+ from xformers.components import MultiHeadDispatch
377
+ #import numpy.bool
378
+ XFORMERS_AVAILABLE = True
379
+ except ImportError:
380
+ logger.warning("xFormers not available")
381
+ XFORMERS_AVAILABLE = False
382
+
383
+
384
+ class Attention(nn.Module):
385
+ def __init__(
386
+ self,
387
+ dim: int,
388
+ num_heads: int = 8,
389
+ qkv_bias: bool = False,
390
+ proj_bias: bool = True,
391
+ attn_drop: float = 0.0,
392
+ proj_drop: float = 0.0,
393
+ window_size: int = 0,
394
+ tuning_mode: Optional[int] = None
395
+ ) -> None:
396
+ super().__init__()
397
+ self.num_heads = num_heads
398
+ head_dim = dim // num_heads
399
+ self.scale = head_dim**-0.5
400
+
401
+ if tuning_mode == 'lora':
402
+ self.tuning_mode = tuning_mode
403
+ self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=8)
404
+ else:
405
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
406
+
407
+ self.attn_drop = nn.Dropout(attn_drop)
408
+
409
+ if tuning_mode == 'lora':
410
+ self.tuning_mode = tuning_mode
411
+ self.proj = LoRALinear(dim, dim, bias=proj_bias, r=8)
412
+ else:
413
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
414
+ self.proj_drop = nn.Dropout(proj_drop)
415
+
416
+ if tuning_mode != None:
417
+ self.tuning_mode = tuning_mode
418
+ if tuning_mode == 'ssf':
419
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3)
420
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
421
+ else:
422
+ pass
423
+ #raise NotImplementedError()
424
+ else:
425
+ self.tuning_mode = None
426
+
427
+ #if not self.training:
428
+ #
429
+ # self.attn = ScaledDotProduct()
430
+ #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
431
+
432
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
433
+ B, N, C = x.shape
434
+ if self.tuning_mode == 'ssf':
435
+ qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
436
+ else:
437
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
438
+
439
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
440
+ attn = q @ k.transpose(-2, -1)
441
+
442
+ if attn_bias is not None:
443
+ attn = attn + attn_bias[:, :, :N]
444
+
445
+ attn = attn.softmax(dim=-1)
446
+ attn = self.attn_drop(attn)
447
+
448
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
449
+ x = self.proj(x)
450
+
451
+ if self.tuning_mode == 'ssf':
452
+ x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
453
+
454
+ x = self.proj_drop(x)
455
+ return x
456
+
457
+
458
+ class MemEffAttention(Attention):
459
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
460
+ if not XFORMERS_AVAILABLE:
461
+ #if True:
462
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
463
+ return super().forward(x, attn_bias)
464
+
465
+ B, N, C = x.shape
466
+ if self.tuning_mode == 'ssf':
467
+ qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads)
468
+ else:
469
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
470
+
471
+ q, k, v = unbind(qkv, 2)
472
+ if attn_bias is not None:
473
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
474
+ else:
475
+ x = memory_efficient_attention(q, k, v)
476
+ x = x.reshape([B, N, C])
477
+
478
+ x = self.proj(x)
479
+ if self.tuning_mode == 'ssf':
480
+ x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
481
+
482
+ x = self.proj_drop(x)
483
+ return x
484
+
485
+ try:
486
+ from xformers.ops import fmha
487
+ from xformers.ops import scaled_index_add, index_select_cat
488
+ #import numpy.bool
489
+ XFORMERS_AVAILABLE = True
490
+ except ImportError:
491
+ logger.warning("xFormers not available")
492
+ XFORMERS_AVAILABLE = False
493
+
494
+ class Block(nn.Module):
495
+ def __init__(
496
+ self,
497
+ dim: int,
498
+ num_heads: int,
499
+ mlp_ratio: float = 4.0,
500
+ qkv_bias: bool = False,
501
+ proj_bias: bool = True,
502
+ ffn_bias: bool = True,
503
+ drop: float = 0.0,
504
+ attn_drop: float = 0.0,
505
+ init_values = None,
506
+ drop_path: float = 0.0,
507
+ act_layer: Callable[..., nn.Module] = nn.GELU,
508
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
509
+ attn_class: Callable[..., nn.Module] = Attention,
510
+ ffn_layer: Callable[..., nn.Module] = Mlp,
511
+ tuning_mode: Optional[int] = None
512
+ ) -> None:
513
+ super().__init__()
514
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
515
+ self.norm1 = norm_layer(dim)
516
+ self.attn = attn_class(
517
+ dim,
518
+ num_heads=num_heads,
519
+ qkv_bias=qkv_bias,
520
+ proj_bias=proj_bias,
521
+ attn_drop=attn_drop,
522
+ proj_drop=drop,
523
+ tuning_mode=tuning_mode
524
+ )
525
+
526
+ if tuning_mode != None:
527
+ self.tuning_mode = tuning_mode
528
+ if tuning_mode == 'ssf':
529
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
530
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
531
+ else:
532
+ pass
533
+ #raise NotImplementedError()
534
+ else:
535
+ self.tuning_mode = None
536
+
537
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
538
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
539
+
540
+ self.norm2 = norm_layer(dim)
541
+ mlp_hidden_dim = int(dim * mlp_ratio)
542
+ self.mlp = ffn_layer(
543
+ in_features=dim,
544
+ hidden_features=mlp_hidden_dim,
545
+ act_layer=act_layer,
546
+ drop=drop,
547
+ bias=ffn_bias,
548
+ )
549
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
550
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
551
+
552
+ self.sample_drop_ratio = drop_path
553
+
554
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
555
+ def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
556
+ if self.tuning_mode == 'ssf':
557
+ return self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1), attn_bias))
558
+ else:
559
+ return self.ls1(self.attn(self.norm1(x), attn_bias))
560
+
561
+ def ffn_residual_func(x: Tensor) -> Tensor:
562
+ if self.tuning_mode == 'ssf':
563
+ return self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))
564
+ else:
565
+ return self.ls2(self.mlp(self.norm2(x)))
566
+
567
+ if self.training and self.sample_drop_ratio > 0.1:
568
+ # the overhead is compensated only for a drop path rate larger than 0.1
569
+ x = drop_add_residual_stochastic_depth(
570
+ x,
571
+ residual_func=attn_residual_func,
572
+ sample_drop_ratio=self.sample_drop_ratio,
573
+ attn_bias=attn_bias
574
+ )
575
+ x = drop_add_residual_stochastic_depth(
576
+ x,
577
+ residual_func=ffn_residual_func,
578
+ sample_drop_ratio=self.sample_drop_ratio,
579
+ )
580
+ elif self.training and self.sample_drop_ratio > 0.0:
581
+ x = x + self.drop_path1(attn_residual_func(x, attn_bias))
582
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
583
+ else:
584
+ x = x + attn_residual_func(x, attn_bias)
585
+ x = x + ffn_residual_func(x)
586
+ return x
587
+
588
+
589
+ def drop_add_residual_stochastic_depth(
590
+ x: Tensor,
591
+ residual_func: Callable[[Tensor], Tensor],
592
+ sample_drop_ratio: float = 0.0, attn_bias=None
593
+ ) -> Tensor:
594
+ # 1) extract subset using permutation
595
+ b, n, d = x.shape
596
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
597
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
598
+ x_subset = x[brange]
599
+
600
+ # 2) apply residual_func to get residual
601
+ residual = residual_func(x_subset, attn_bias)
602
+
603
+ x_flat = x.flatten(1)
604
+ residual = residual.flatten(1)
605
+
606
+ residual_scale_factor = b / sample_subset_size
607
+
608
+ # 3) add the residual
609
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
610
+ return x_plus_residual.view_as(x)
611
+
612
+
613
+ def get_branges_scales(x, sample_drop_ratio=0.0):
614
+ b, n, d = x.shape
615
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
616
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
617
+ residual_scale_factor = b / sample_subset_size
618
+ return brange, residual_scale_factor
619
+
620
+
621
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
622
+ if scaling_vector is None:
623
+ x_flat = x.flatten(1)
624
+ residual = residual.flatten(1)
625
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
626
+ else:
627
+ x_plus_residual = scaled_index_add(
628
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
629
+ )
630
+ return x_plus_residual
631
+
632
+
633
+ attn_bias_cache: Dict[Tuple, Any] = {}
634
+
635
+
636
+ def get_attn_bias_and_cat(x_list, branges=None):
637
+ """
638
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
639
+ """
640
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
641
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
642
+ if all_shapes not in attn_bias_cache.keys():
643
+ seqlens = []
644
+ for b, x in zip(batch_sizes, x_list):
645
+ for _ in range(b):
646
+ seqlens.append(x.shape[1])
647
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
648
+ attn_bias._batch_sizes = batch_sizes
649
+ attn_bias_cache[all_shapes] = attn_bias
650
+
651
+ if branges is not None:
652
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
653
+ else:
654
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
655
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
656
+
657
+ return attn_bias_cache[all_shapes], cat_tensors
658
+
659
+
660
+ def drop_add_residual_stochastic_depth_list(
661
+ x_list: List[Tensor],
662
+ residual_func: Callable[[Tensor, Any], Tensor],
663
+ sample_drop_ratio: float = 0.0,
664
+ scaling_vector=None,
665
+ ) -> Tensor:
666
+ # 1) generate random set of indices for dropping samples in the batch
667
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
668
+ branges = [s[0] for s in branges_scales]
669
+ residual_scale_factors = [s[1] for s in branges_scales]
670
+
671
+ # 2) get attention bias and index+concat the tensors
672
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
673
+
674
+ # 3) apply residual_func to get residual, and split the result
675
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
676
+
677
+ outputs = []
678
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
679
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
680
+ return outputs
681
+
682
+
683
+ class NestedTensorBlock(Block):
684
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
685
+ """
686
+ x_list contains a list of tensors to nest together and run
687
+ """
688
+ assert isinstance(self.attn, MemEffAttention)
689
+
690
+ if self.training and self.sample_drop_ratio > 0.0:
691
+
692
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
693
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
694
+
695
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
696
+ return self.mlp(self.norm2(x))
697
+
698
+ x_list = drop_add_residual_stochastic_depth_list(
699
+ x_list,
700
+ residual_func=attn_residual_func,
701
+ sample_drop_ratio=self.sample_drop_ratio,
702
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
703
+ )
704
+ x_list = drop_add_residual_stochastic_depth_list(
705
+ x_list,
706
+ residual_func=ffn_residual_func,
707
+ sample_drop_ratio=self.sample_drop_ratio,
708
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
709
+ )
710
+ return x_list
711
+ else:
712
+
713
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
714
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
715
+
716
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
717
+ return self.ls2(self.mlp(self.norm2(x)))
718
+
719
+ attn_bias, x = get_attn_bias_and_cat(x_list)
720
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
721
+ x = x + ffn_residual_func(x)
722
+ return attn_bias.split(x)
723
+
724
+ def forward(self, x_or_x_list, attn_bias=None):
725
+ if isinstance(x_or_x_list, Tensor):
726
+ return super().forward(x_or_x_list, attn_bias)
727
+ elif isinstance(x_or_x_list, list):
728
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
729
+ return self.forward_nested(x_or_x_list)
730
+ else:
731
+ raise AssertionError
732
+
733
+
734
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
735
+ if not depth_first and include_root:
736
+ fn(module=module, name=name)
737
+ for child_name, child_module in module.named_children():
738
+ child_name = ".".join((name, child_name)) if name else child_name
739
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
740
+ if depth_first and include_root:
741
+ fn(module=module, name=name)
742
+ return module
743
+
744
+
745
+ class BlockChunk(nn.ModuleList):
746
+ def forward(self, x, others=None):
747
+ for b in self:
748
+ if others == None:
749
+ x = b(x)
750
+ else:
751
+ x = b(x, others)
752
+ return x
753
+
754
+
755
+ class DinoVisionTransformer(nn.Module):
756
+ def __init__(
757
+ self,
758
+ img_size=518,
759
+ patch_size=16,
760
+ in_chans=3,
761
+ embed_dim=768,
762
+ depth=12,
763
+ num_heads=12,
764
+ mlp_ratio=4.0,
765
+ qkv_bias=True,
766
+ ffn_bias=True,
767
+ proj_bias=True,
768
+ drop_path_rate=0.0,
769
+ drop_path_uniform=False,
770
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
771
+ embed_layer=PatchEmbed,
772
+ act_layer=nn.GELU,
773
+ block_fn=Block,
774
+ ffn_layer="mlp",
775
+ block_chunks=1,
776
+ num_register_tokens=0,
777
+ interpolate_antialias=False,
778
+ interpolate_offset=0.1,
779
+ tuning_mode=None,
780
+ **kwargs
781
+ ):
782
+ """
783
+ Args:
784
+ img_size (int, tuple): input image size
785
+ patch_size (int, tuple): patch size
786
+ in_chans (int): number of input channels
787
+ embed_dim (int): embedding dimension
788
+ depth (int): depth of transformer
789
+ num_heads (int): number of attention heads
790
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
791
+ qkv_bias (bool): enable bias for qkv if True
792
+ proj_bias (bool): enable bias for proj in attn if True
793
+ ffn_bias (bool): enable bias for ffn if True
794
+ drop_path_rate (float): stochastic depth rate
795
+ drop_path_uniform (bool): apply uniform drop rate across blocks
796
+ weight_init (str): weight init scheme
797
+ init_values (float): layer-scale init values
798
+ embed_layer (nn.Module): patch embedding layer
799
+ act_layer (nn.Module): MLP activation layer
800
+ block_fn (nn.Module): transformer block class
801
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
802
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
803
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
804
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
805
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
806
+ """
807
+ super().__init__()
808
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
809
+
810
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
811
+ self.num_tokens = 1
812
+ self.n_blocks = depth
813
+ self.num_heads = num_heads
814
+ self.patch_size = patch_size
815
+ self.num_register_tokens = num_register_tokens
816
+ self.interpolate_antialias = interpolate_antialias
817
+ self.interpolate_offset = interpolate_offset
818
+
819
+ if tuning_mode != None:
820
+ self.tuning_mode = tuning_mode
821
+ if tuning_mode == 'ssf':
822
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)
823
+ else:
824
+ pass
825
+ #raise NotImplementedError()
826
+ else:
827
+ self.tuning_mode = None
828
+ tuning_mode_list = [tuning_mode] * depth
829
+
830
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode)
831
+ num_patches = self.patch_embed.num_patches
832
+
833
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
834
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
835
+ assert num_register_tokens >= 0
836
+ self.register_tokens = (
837
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
838
+ )
839
+
840
+ if drop_path_uniform is True:
841
+ dpr = [drop_path_rate] * depth
842
+ else:
843
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
844
+
845
+ if ffn_layer == "mlp":
846
+ logger.info("using MLP layer as FFN")
847
+ ffn_layer = Mlp
848
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
849
+ logger.info("using SwiGLU layer as FFN")
850
+ ffn_layer = SwiGLUFFNFused
851
+ elif ffn_layer == "identity":
852
+ logger.info("using Identity layer as FFN")
853
+
854
+ def f(*args, **kwargs):
855
+ return nn.Identity()
856
+
857
+ ffn_layer = f
858
+ else:
859
+ raise NotImplementedError
860
+
861
+ blocks_list = [
862
+ block_fn(
863
+ dim=embed_dim,
864
+ num_heads=num_heads,
865
+ mlp_ratio=mlp_ratio,
866
+ qkv_bias=qkv_bias,
867
+ proj_bias=proj_bias,
868
+ ffn_bias=ffn_bias,
869
+ drop_path=dpr[i],
870
+ norm_layer=norm_layer,
871
+ act_layer=act_layer,
872
+ ffn_layer=ffn_layer,
873
+ init_values=init_values,
874
+ tuning_mode=tuning_mode_list[i]
875
+ )
876
+ for i in range(depth)
877
+ ]
878
+ if block_chunks > 0:
879
+ self.chunked_blocks = True
880
+ chunked_blocks = []
881
+ chunksize = depth // block_chunks
882
+ for i in range(0, depth, chunksize):
883
+ # this is to keep the block index consistent if we chunk the block list
884
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
885
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
886
+ else:
887
+ self.chunked_blocks = False
888
+ self.blocks = nn.ModuleList(blocks_list)
889
+
890
+ self.norm = norm_layer(embed_dim)
891
+ self.head = nn.Identity()
892
+
893
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
894
+
895
+ self.init_weights()
896
+
897
+ def init_weights(self):
898
+ trunc_normal_(self.pos_embed, std=0.02)
899
+ nn.init.normal_(self.cls_token, std=1e-6)
900
+ if self.register_tokens is not None:
901
+ nn.init.normal_(self.register_tokens, std=1e-6)
902
+ named_apply(init_weights_vit_timm, self)
903
+
904
+ def interpolate_pos_encoding(self, x, w, h):
905
+ previous_dtype = x.dtype
906
+ npatch = x.shape[1] - 1
907
+ N = self.pos_embed.shape[1] - 1
908
+ if npatch == N and w == h:
909
+ return self.pos_embed
910
+ pos_embed = self.pos_embed.float()
911
+ class_pos_embed = pos_embed[:, 0]
912
+ patch_pos_embed = pos_embed[:, 1:]
913
+ dim = x.shape[-1]
914
+ w0 = w // self.patch_size
915
+ h0 = h // self.patch_size
916
+ # we add a small number to avoid floating point error in the interpolation
917
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
918
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
919
+
920
+ sqrt_N = math.sqrt(N)
921
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
922
+ patch_pos_embed = nn.functional.interpolate(
923
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
924
+ scale_factor=(sx, sy),
925
+ mode="bicubic",
926
+ antialias=self.interpolate_antialias,
927
+ )
928
+
929
+ assert int(w0) == patch_pos_embed.shape[-2]
930
+ assert int(h0) == patch_pos_embed.shape[-1]
931
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
932
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
933
+
934
+ def prepare_tokens_with_masks(self, x, masks=None):
935
+ B, nc, w, h = x.shape
936
+ x = self.patch_embed(x)
937
+ if masks is not None:
938
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
939
+
940
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
941
+ x = x + self.interpolate_pos_encoding(x, w, h)
942
+
943
+ if self.register_tokens is not None:
944
+ x = torch.cat(
945
+ (
946
+ x[:, :1],
947
+ self.register_tokens.expand(x.shape[0], -1, -1),
948
+ x[:, 1:],
949
+ ),
950
+ dim=1,
951
+ )
952
+
953
+ return x
954
+
955
+ def forward_features_list(self, x_list, masks_list):
956
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
957
+ for blk in self.blocks:
958
+ x = blk(x)
959
+
960
+ all_x = x
961
+ output = []
962
+ for x, masks in zip(all_x, masks_list):
963
+ x_norm = self.norm(x)
964
+ output.append(
965
+ {
966
+ "x_norm_clstoken": x_norm[:, 0],
967
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
968
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
969
+ "x_prenorm": x,
970
+ "masks": masks,
971
+ }
972
+ )
973
+ return output
974
+
975
+ def forward_features(self, x, masks=None):
976
+ if isinstance(x, list):
977
+ return self.forward_features_list(x, masks)
978
+
979
+ B, C, H, W = x.size()
980
+ pad_h = (self.patch_size - H % self.patch_size)
981
+ pad_w = (self.patch_size - W % self.patch_size)
982
+ if pad_h == self.patch_size:
983
+ pad_h = 0
984
+ if pad_w == self.patch_size:
985
+ pad_w = 0
986
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
987
+ if pad_h + pad_w > 0:
988
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
989
+
990
+ x = self.prepare_tokens_with_masks(x, masks)
991
+
992
+ for blk in self.blocks:
993
+ x = blk(x)
994
+
995
+ x_norm = self.norm(x)
996
+ if self.tuning_mode == 'ssf':
997
+ x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1)
998
+
999
+ # return {
1000
+ # "x_norm_clstoken": x_norm[:, 0],
1001
+ # "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
1002
+ # "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
1003
+ # "x_prenorm": x,
1004
+ # "masks": masks,
1005
+ # }
1006
+ features = []
1007
+ features.append(x_norm)
1008
+ features.append(x_norm)
1009
+ features.append(x_norm)
1010
+ features.append(x_norm)
1011
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)]
1012
+
1013
+
1014
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
1015
+ x = self.prepare_tokens_with_masks(x)
1016
+ # If n is an int, take the n last blocks. If it's a list, take them
1017
+ output, total_block_len = [], len(self.blocks)
1018
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
1019
+ for i, blk in enumerate(self.blocks):
1020
+ x = blk(x)
1021
+ if i in blocks_to_take:
1022
+ output.append(x)
1023
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
1024
+ return output
1025
+
1026
+ def _get_intermediate_layers_chunked(self, x, n=1):
1027
+ x = self.prepare_tokens_with_masks(x)
1028
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
1029
+ # If n is an int, take the n last blocks. If it's a list, take them
1030
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
1031
+ for block_chunk in self.blocks:
1032
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
1033
+ x = blk(x)
1034
+ if i in blocks_to_take:
1035
+ output.append(x)
1036
+ i += 1
1037
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
1038
+ return output
1039
+
1040
+ def get_intermediate_layers(
1041
+ self,
1042
+ x: torch.Tensor,
1043
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
1044
+ reshape: bool = False,
1045
+ return_class_token: bool = False,
1046
+ norm=True,
1047
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
1048
+ if self.chunked_blocks:
1049
+ outputs = self._get_intermediate_layers_chunked(x, n)
1050
+ else:
1051
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
1052
+ if norm:
1053
+ outputs = [self.norm(out) for out in outputs]
1054
+ class_tokens = [out[:, 0] for out in outputs]
1055
+ outputs = [out[:, 1:] for out in outputs]
1056
+ if reshape:
1057
+ B, _, w, h = x.shape
1058
+ outputs = [
1059
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
1060
+ for out in outputs
1061
+ ]
1062
+ if return_class_token:
1063
+ return tuple(zip(outputs, class_tokens))
1064
+ return tuple(outputs)
1065
+
1066
+ def forward(self, *args, is_training=False, **kwargs):
1067
+ ret = self.forward_features(*args, **kwargs)
1068
+ return ret
1069
+ # if is_training:
1070
+ # return ret
1071
+ # else:
1072
+ # return self.head(ret["x_norm_clstoken"])
1073
+
1074
+
1075
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
1076
+ """ViT weight initialization, original timm impl (for reproducibility)"""
1077
+ if isinstance(module, nn.Linear):
1078
+ trunc_normal_(module.weight, std=0.02)
1079
+ if module.bias is not None:
1080
+ nn.init.zeros_(module.bias)
1081
+
1082
+
1083
+ def load_ckpt_dino(checkpoint, model):
1084
+ if checkpoint is not None:
1085
+ try:
1086
+ with open(checkpoint, "rb") as f:
1087
+ state_dict = torch.load(f)
1088
+ except:
1089
+ print('NO pretrained imagenet ckpt available! Check your path!')
1090
+ del model.mask_token
1091
+ return
1092
+
1093
+ try:
1094
+ model.load_state_dict(state_dict, strict=True)
1095
+ except:
1096
+ new_state_dict = {}
1097
+ for key, value in state_dict.items():
1098
+ if 'blocks' in key:
1099
+ key_new = 'blocks.0' + key[len('blocks'):]
1100
+ else:
1101
+ key_new = key
1102
+ new_state_dict[key_new] = value
1103
+
1104
+ model.load_state_dict(new_state_dict, strict=True)
1105
+ del model.mask_token
1106
+ return
1107
+ else:
1108
+ return
1109
+
1110
+
1111
+ def vit_small(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
1112
+ model = DinoVisionTransformer(
1113
+ patch_size=patch_size,
1114
+ embed_dim=384,
1115
+ depth=12,
1116
+ num_heads=6,
1117
+ mlp_ratio=4,
1118
+ block_fn=partial(Block, attn_class=MemEffAttention),
1119
+ num_register_tokens=num_register_tokens,
1120
+ **kwargs,
1121
+ )
1122
+
1123
+ load_ckpt_dino(checkpoint, model)
1124
+
1125
+ return model
1126
+
1127
+
1128
+ def vit_base(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
1129
+ model = DinoVisionTransformer(
1130
+ patch_size=patch_size,
1131
+ embed_dim=768,
1132
+ depth=12,
1133
+ num_heads=12,
1134
+ mlp_ratio=4,
1135
+ block_fn=partial(Block, attn_class=MemEffAttention),
1136
+ num_register_tokens=num_register_tokens,
1137
+ **kwargs,
1138
+ )
1139
+ return model
1140
+
1141
+
1142
+ def vit_large(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
1143
+ model = DinoVisionTransformer(
1144
+ patch_size=patch_size,
1145
+ embed_dim=1024,
1146
+ depth=24,
1147
+ num_heads=16,
1148
+ mlp_ratio=4,
1149
+ block_fn=partial(Block, attn_class=MemEffAttention),
1150
+ num_register_tokens=num_register_tokens,
1151
+ **kwargs,
1152
+ )
1153
+
1154
+ if checkpoint is not None:
1155
+ with open(checkpoint, "rb") as f:
1156
+ state_dict = torch.load(f)
1157
+ try:
1158
+ model.load_state_dict(state_dict, strict=True)
1159
+ except:
1160
+ new_state_dict = {}
1161
+ for key, value in state_dict.items():
1162
+ if 'blocks' in key:
1163
+ key_new = 'blocks.0' + key[len('blocks'):]
1164
+ else:
1165
+ key_new = key
1166
+ new_state_dict[key_new] = value
1167
+
1168
+ model.load_state_dict(new_state_dict, strict=True)
1169
+ del model.mask_token
1170
+ return model
1171
+
1172
+
1173
+ def vit_giant2(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
1174
+ """
1175
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
1176
+ """
1177
+ model = DinoVisionTransformer(
1178
+ patch_size=patch_size,
1179
+ embed_dim=1536,
1180
+ depth=40,
1181
+ num_heads=24,
1182
+ mlp_ratio=4,
1183
+ block_fn=partial(Block, attn_class=MemEffAttention),
1184
+ num_register_tokens=num_register_tokens,
1185
+ ffn_layer='swiglu',
1186
+ **kwargs,
1187
+ )
1188
+ return model
1189
+
1190
+
1191
+
1192
+ def vit_small_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
1193
+ model = DinoVisionTransformer(
1194
+ patch_size=patch_size,
1195
+ embed_dim=384,
1196
+ depth=12,
1197
+ num_heads=6,
1198
+ mlp_ratio=4,
1199
+ block_fn=partial(Block, attn_class=MemEffAttention),
1200
+ num_register_tokens=num_register_tokens,
1201
+ tuning_mode=tuning_mode,
1202
+ **kwargs,
1203
+ )
1204
+
1205
+ load_ckpt_dino(checkpoint, model)
1206
+
1207
+ return model
1208
+
1209
+
1210
+ def vit_base_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs):
1211
+ model = DinoVisionTransformer(
1212
+ patch_size=patch_size,
1213
+ embed_dim=768,
1214
+ depth=12,
1215
+ num_heads=12,
1216
+ mlp_ratio=4,
1217
+ block_fn=partial(Block, attn_class=MemEffAttention),
1218
+ num_register_tokens=num_register_tokens,
1219
+ **kwargs,
1220
+ )
1221
+
1222
+ load_ckpt_dino(checkpoint, model)
1223
+
1224
+ return model
1225
+
1226
+
1227
+ def vit_large_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
1228
+ model = DinoVisionTransformer(
1229
+ img_size = 518,
1230
+ patch_size=patch_size,
1231
+ embed_dim=1024,
1232
+ depth=24,
1233
+ num_heads=16,
1234
+ mlp_ratio=4,
1235
+ block_fn=partial(Block, attn_class=MemEffAttention),
1236
+ num_register_tokens=num_register_tokens,
1237
+ tuning_mode=tuning_mode,
1238
+ **kwargs,
1239
+ )
1240
+
1241
+ load_ckpt_dino(checkpoint, model)
1242
+
1243
+ return model
1244
+
1245
+
1246
+ def vit_giant2_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
1247
+ """
1248
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
1249
+ """
1250
+ model = DinoVisionTransformer(
1251
+ patch_size=patch_size,
1252
+ embed_dim=1536,
1253
+ depth=40,
1254
+ num_heads=24,
1255
+ mlp_ratio=4,
1256
+ block_fn=partial(Block, attn_class=MemEffAttention),
1257
+ num_register_tokens=num_register_tokens,
1258
+ ffn_layer='swiglu',
1259
+ tuning_mode=tuning_mode,
1260
+ **kwargs,
1261
+ )
1262
+
1263
+ load_ckpt_dino(checkpoint, model)
1264
+
1265
+ return model
1266
+
1267
+ if __name__ == '__main__':
1268
+ try:
1269
+ from mmcv.utils import Config
1270
+ except:
1271
+ from mmengine import Config
1272
+
1273
+ #rgb = torch.rand((2, 3, 518, 518)).cuda()
1274
+
1275
+ #cfg.data_basic['crop_size']['0']
1276
+ #cfg.data_basic['crop_size']['1']
1277
+ cfg = Config.fromfile('/opt/ml/project/mu.hu/projects/monodepth_vit/mono/configs/RAFTDecoder/vit.raft5.large.kitti.py')
1278
+
1279
+ #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
1280
+ rgb = torch.zeros(1, 3, 616, 1064).cuda()
1281
+ cfg['tuning_mode'] = 'ssf'
1282
+ #model = vit_large_reg(checkpoint="/cpfs02/shared/public/groups/local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_reg4_pretrain.pth", kwarg=cfg).cuda()
1283
+ model = vit_large_reg(tuning_mode='ssf').cuda()
1284
+
1285
+ #import timm
1286
+ #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
1287
+ #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
1288
+
1289
+ out1 = model(rgb)
1290
+ #out2 = model2(rgb)
1291
+ temp = 0
1292
+
1293
+
mono/model/backbones/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .ConvNeXt import convnext_xlarge
2
+ from .ConvNeXt import convnext_small
3
+ from .ConvNeXt import convnext_base
4
+ from .ConvNeXt import convnext_large
5
+ from .ConvNeXt import convnext_tiny
6
+ from .ViT_DINO import vit_large
7
+ from .ViT_DINO_reg import vit_small_reg, vit_large_reg
8
+
9
+ __all__ = [
10
+ 'convnext_xlarge', 'convnext_small', 'convnext_base', 'convnext_large', 'convnext_tiny', 'vit_small_reg', 'vit_large_reg'
11
+ ]
mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc ADDED
Binary file (9.37 kB). View file
 
mono/model/backbones/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (410 Bytes). View file
 
mono/model/decode_heads/HourGlassDecoder.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import math
5
+ import torch.nn.functional as F
6
+
7
+ def compute_depth_expectation(prob, depth_values):
8
+ depth_values = depth_values.view(*depth_values.shape, 1, 1)
9
+ depth = torch.sum(prob * depth_values, 1)
10
+ return depth
11
+
12
+ class ConvBlock(nn.Module):
13
+ def __init__(self, in_channels, out_channels, kernel_size=3):
14
+ super(ConvBlock, self).__init__()
15
+
16
+ if kernel_size == 3:
17
+ self.conv = nn.Sequential(
18
+ nn.ReflectionPad2d(1),
19
+ nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
20
+ )
21
+ elif kernel_size == 1:
22
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
23
+
24
+ self.nonlin = nn.ELU(inplace=True)
25
+
26
+ def forward(self, x):
27
+ out = self.conv(x)
28
+ out = self.nonlin(out)
29
+ return out
30
+
31
+
32
+ class ConvBlock_double(nn.Module):
33
+ def __init__(self, in_channels, out_channels, kernel_size=3):
34
+ super(ConvBlock_double, self).__init__()
35
+
36
+ if kernel_size == 3:
37
+ self.conv = nn.Sequential(
38
+ nn.ReflectionPad2d(1),
39
+ nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
40
+ )
41
+ elif kernel_size == 1:
42
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
43
+
44
+ self.nonlin = nn.ELU(inplace=True)
45
+ self.conv_2 = nn.Conv2d(out_channels, out_channels, 1, padding=0, stride=1)
46
+ self.nonlin_2 =nn.ELU(inplace=True)
47
+
48
+ def forward(self, x):
49
+ out = self.conv(x)
50
+ out = self.nonlin(out)
51
+ out = self.conv_2(out)
52
+ out = self.nonlin_2(out)
53
+ return out
54
+
55
+ class DecoderFeature(nn.Module):
56
+ def __init__(self, feat_channels, num_ch_dec=[64, 64, 128, 256]):
57
+ super(DecoderFeature, self).__init__()
58
+ self.num_ch_dec = num_ch_dec
59
+ self.feat_channels = feat_channels
60
+
61
+ self.upconv_3_0 = ConvBlock(self.feat_channels[3], self.num_ch_dec[3], kernel_size=1)
62
+ self.upconv_3_1 = ConvBlock_double(
63
+ self.feat_channels[2] + self.num_ch_dec[3],
64
+ self.num_ch_dec[3],
65
+ kernel_size=1)
66
+
67
+ self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2], kernel_size=3)
68
+ self.upconv_2_1 = ConvBlock_double(
69
+ self.feat_channels[1] + self.num_ch_dec[2],
70
+ self.num_ch_dec[2],
71
+ kernel_size=3)
72
+
73
+ self.upconv_1_0 = ConvBlock(self.num_ch_dec[2], self.num_ch_dec[1], kernel_size=3)
74
+ self.upconv_1_1 = ConvBlock_double(
75
+ self.feat_channels[0] + self.num_ch_dec[1],
76
+ self.num_ch_dec[1],
77
+ kernel_size=3)
78
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
79
+
80
+ def forward(self, ref_feature):
81
+ x = ref_feature[3]
82
+
83
+ x = self.upconv_3_0(x)
84
+ x = torch.cat((self.upsample(x), ref_feature[2]), 1)
85
+ x = self.upconv_3_1(x)
86
+
87
+ x = self.upconv_2_0(x)
88
+ x = torch.cat((self.upsample(x), ref_feature[1]), 1)
89
+ x = self.upconv_2_1(x)
90
+
91
+ x = self.upconv_1_0(x)
92
+ x = torch.cat((self.upsample(x), ref_feature[0]), 1)
93
+ x = self.upconv_1_1(x)
94
+ return x
95
+
96
+
97
+ class UNet(nn.Module):
98
+ def __init__(self, inp_ch=32, output_chal=1, down_sample_times=3, channel_mode='v0'):
99
+ super(UNet, self).__init__()
100
+ basic_block = ConvBnReLU
101
+ num_depth = 128
102
+
103
+ self.conv0 = basic_block(inp_ch, num_depth)
104
+ if channel_mode == 'v0':
105
+ channels = [num_depth, num_depth//2, num_depth//4, num_depth//8, num_depth // 8]
106
+ elif channel_mode == 'v1':
107
+ channels = [num_depth, num_depth, num_depth, num_depth, num_depth, num_depth]
108
+ self.down_sample_times = down_sample_times
109
+ for i in range(down_sample_times):
110
+ setattr(
111
+ self, 'conv_%d' % i,
112
+ nn.Sequential(
113
+ basic_block(channels[i], channels[i+1], stride=2),
114
+ basic_block(channels[i+1], channels[i+1])
115
+ )
116
+ )
117
+ for i in range(down_sample_times-1,-1,-1):
118
+ setattr(self, 'deconv_%d' % i,
119
+ nn.Sequential(
120
+ nn.ConvTranspose2d(
121
+ channels[i+1],
122
+ channels[i],
123
+ kernel_size=3,
124
+ padding=1,
125
+ output_padding=1,
126
+ stride=2,
127
+ bias=False),
128
+ nn.BatchNorm2d(channels[i]),
129
+ nn.ReLU(inplace=True)
130
+ )
131
+ )
132
+ self.prob = nn.Conv2d(num_depth, output_chal, 1, stride=1, padding=0)
133
+
134
+ def forward(self, x):
135
+ features = {}
136
+ conv0 = self.conv0(x)
137
+ x = conv0
138
+ features[0] = conv0
139
+ for i in range(self.down_sample_times):
140
+ x = getattr(self, 'conv_%d' % i)(x)
141
+ features[i+1] = x
142
+ for i in range(self.down_sample_times-1,-1,-1):
143
+ x = features[i] + getattr(self, 'deconv_%d' % i)(x)
144
+ x = self.prob(x)
145
+ return x
146
+
147
+ class ConvBnReLU(nn.Module):
148
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
149
+ super(ConvBnReLU, self).__init__()
150
+ self.conv = nn.Conv2d(
151
+ in_channels,
152
+ out_channels,
153
+ kernel_size,
154
+ stride=stride,
155
+ padding=pad,
156
+ bias=False
157
+ )
158
+ self.bn = nn.BatchNorm2d(out_channels)
159
+
160
+ def forward(self, x):
161
+ return F.relu(self.bn(self.conv(x)), inplace=True)
162
+
163
+
164
+ class HourglassDecoder(nn.Module):
165
+ def __init__(self, cfg):
166
+ super(HourglassDecoder, self).__init__()
167
+ self.inchannels = cfg.model.decode_head.in_channels # [256, 512, 1024, 2048]
168
+ self.decoder_channels = cfg.model.decode_head.decoder_channel # [64, 64, 128, 256]
169
+ self.min_val = cfg.data_basic.depth_normalize[0]
170
+ self.max_val = cfg.data_basic.depth_normalize[1]
171
+
172
+ self.num_ch_dec = self.decoder_channels # [64, 64, 128, 256]
173
+ self.num_depth_regressor_anchor = 512
174
+ self.feat_channels = self.inchannels
175
+ unet_in_channel = self.num_ch_dec[1]
176
+ unet_out_channel = 256
177
+
178
+ self.decoder_mono = DecoderFeature(self.feat_channels, self.num_ch_dec)
179
+ self.conv_out_2 = UNet(inp_ch=unet_in_channel,
180
+ output_chal=unet_out_channel + 1,
181
+ down_sample_times=3,
182
+ channel_mode='v0',
183
+ )
184
+
185
+ self.depth_regressor_2 = nn.Sequential(
186
+ nn.Conv2d(unet_out_channel,
187
+ self.num_depth_regressor_anchor,
188
+ kernel_size=3,
189
+ padding=1,
190
+ ),
191
+ nn.BatchNorm2d(self.num_depth_regressor_anchor),
192
+ nn.ReLU(inplace=True),
193
+ nn.Conv2d(
194
+ self.num_depth_regressor_anchor,
195
+ self.num_depth_regressor_anchor,
196
+ kernel_size=1,
197
+ )
198
+ )
199
+ self.residual_channel = 16
200
+ self.conv_up_2 = nn.Sequential(
201
+ nn.Conv2d(1 + 2 + unet_out_channel, self.residual_channel, 3, padding=1),
202
+ nn.BatchNorm2d(self.residual_channel),
203
+ nn.ReLU(),
204
+ nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
205
+ nn.Upsample(scale_factor=4),
206
+ nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
207
+ nn.ReLU(),
208
+ nn.Conv2d(self.residual_channel, 1, 1, padding=0),
209
+ )
210
+
211
+ def get_bins(self, bins_num):
212
+ depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device='cuda')
213
+ depth_bins_vec = torch.exp(depth_bins_vec)
214
+ return depth_bins_vec
215
+
216
+ def register_depth_expectation_anchor(self, bins_num, B):
217
+ depth_bins_vec = self.get_bins(bins_num)
218
+ depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
219
+ self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)
220
+
221
+ def upsample(self, x, scale_factor=2):
222
+ return F.interpolate(x, scale_factor=scale_factor, mode='nearest')
223
+
224
+ def regress_depth_2(self, feature_map_d):
225
+ prob = self.depth_regressor_2(feature_map_d).softmax(dim=1)
226
+ B = prob.shape[0]
227
+ if "depth_expectation_anchor" not in self._buffers:
228
+ self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
229
+ d = compute_depth_expectation(
230
+ prob,
231
+ self.depth_expectation_anchor[:B, ...]
232
+ ).unsqueeze(1)
233
+ return d
234
+
235
+ def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
236
+ y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
237
+ torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
238
+ meshgrid = torch.stack((x, y))
239
+ meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
240
+ return meshgrid
241
+
242
+ def forward(self, features_mono, **kwargs):
243
+ '''
244
+ trans_ref2src: list of transformation matrix from the reference view to source view. [B, 4, 4]
245
+ inv_intrinsic_pool: list of inverse intrinsic matrix.
246
+ features_mono: features of reference and source views. [[ref_f1, ref_f2, ref_f3, ref_f4],[src1_f1, src1_f2, src1_f3, src1_f4], ...].
247
+ '''
248
+ outputs = {}
249
+ # get encoder feature of the reference view
250
+ ref_feat = features_mono
251
+
252
+ feature_map_mono = self.decoder_mono(ref_feat)
253
+ feature_map_mono_pred = self.conv_out_2(feature_map_mono)
254
+ confidence_map_2 = feature_map_mono_pred[:, -1:, :, :]
255
+ feature_map_d_2 = feature_map_mono_pred[:, :-1, :, :]
256
+
257
+ depth_pred_2 = self.regress_depth_2(feature_map_d_2)
258
+
259
+ B, _, H, W = depth_pred_2.shape
260
+
261
+ meshgrid = self.create_mesh_grid(H, W, B)
262
+
263
+ depth_pred_mono = self.upsample(depth_pred_2, scale_factor=4) + 1e-1 * \
264
+ self.conv_up_2(
265
+ torch.cat((depth_pred_2, meshgrid[:B, ...], feature_map_d_2), 1)
266
+ )
267
+ confidence_map_mono = self.upsample(confidence_map_2, scale_factor=4)
268
+
269
+ outputs=dict(
270
+ prediction=depth_pred_mono,
271
+ confidence=confidence_map_mono,
272
+ pred_logit=None,
273
+ )
274
+ return outputs
mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py ADDED
@@ -0,0 +1,1033 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import math
5
+ import torch.nn.functional as F
6
+
7
+ # LORA finetuning originally by edwardjhu
8
+ class LoRALayer():
9
+ def __init__(
10
+ self,
11
+ r: int,
12
+ lora_alpha: int,
13
+ lora_dropout: float,
14
+ merge_weights: bool,
15
+ ):
16
+ self.r = r
17
+ self.lora_alpha = lora_alpha
18
+ # Optional dropout
19
+ if lora_dropout > 0.:
20
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
21
+ else:
22
+ self.lora_dropout = lambda x: x
23
+ # Mark the weight as unmerged
24
+ self.merged = False
25
+ self.merge_weights = merge_weights
26
+
27
+ class LoRALinear(nn.Linear, LoRALayer):
28
+ # LoRA implemented in a dense layer
29
+ def __init__(
30
+ self,
31
+ in_features: int,
32
+ out_features: int,
33
+ r: int = 0,
34
+ lora_alpha: int = 1,
35
+ lora_dropout: float = 0.,
36
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
37
+ merge_weights: bool = True,
38
+ **kwargs
39
+ ):
40
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
41
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
42
+ merge_weights=merge_weights)
43
+
44
+ self.fan_in_fan_out = fan_in_fan_out
45
+ # Actual trainable parameters
46
+ if r > 0:
47
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
48
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
49
+ self.scaling = self.lora_alpha / self.r
50
+ # Freezing the pre-trained weight matrix
51
+ self.weight.requires_grad = False
52
+ self.reset_parameters()
53
+ if fan_in_fan_out:
54
+ self.weight.data = self.weight.data.transpose(0, 1)
55
+
56
+ def reset_parameters(self):
57
+ #nn.Linear.reset_parameters(self)
58
+ if hasattr(self, 'lora_A'):
59
+ # initialize B the same way as the default for nn.Linear and A to zero
60
+ # this is different than what is described in the paper but should not affect performance
61
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
62
+ nn.init.zeros_(self.lora_B)
63
+
64
+ # def train(self, mode: bool = True):
65
+ # def T(w):
66
+ # return w.transpose(0, 1) if self.fan_in_fan_out else w
67
+ # nn.Linear.train(self, mode)
68
+ # if mode:
69
+ # if self.merge_weights and self.merged:
70
+ # # Make sure that the weights are not merged
71
+ # if self.r > 0:
72
+ # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
73
+ # self.merged = False
74
+ # else:
75
+ # if self.merge_weights and not self.merged:
76
+ # # Merge the weights and mark it
77
+ # if self.r > 0:
78
+ # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
79
+ # self.merged = True
80
+
81
+ def forward(self, x: torch.Tensor):
82
+ def T(w):
83
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
84
+ if self.r > 0 and not self.merged:
85
+ result = F.linear(x, T(self.weight), bias=self.bias)
86
+ result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
87
+ return result
88
+ else:
89
+ return F.linear(x, T(self.weight), bias=self.bias)
90
+
91
+ class ConvLoRA(nn.Conv2d, LoRALayer):
92
+ def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
93
+ #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
94
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
95
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
96
+ assert isinstance(kernel_size, int)
97
+
98
+ # Actual trainable parameters
99
+ if r > 0:
100
+ self.lora_A = nn.Parameter(
101
+ self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
102
+ )
103
+ self.lora_B = nn.Parameter(
104
+ self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
105
+ )
106
+ self.scaling = self.lora_alpha / self.r
107
+ # Freezing the pre-trained weight matrix
108
+ self.weight.requires_grad = False
109
+ self.reset_parameters()
110
+ self.merged = False
111
+
112
+ def reset_parameters(self):
113
+ #self.conv.reset_parameters()
114
+ if hasattr(self, 'lora_A'):
115
+ # initialize A the same way as the default for nn.Linear and B to zero
116
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
117
+ nn.init.zeros_(self.lora_B)
118
+
119
+ # def train(self, mode=True):
120
+ # super(ConvLoRA, self).train(mode)
121
+ # if mode:
122
+ # if self.merge_weights and self.merged:
123
+ # if self.r > 0:
124
+ # # Make sure that the weights are not merged
125
+ # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
126
+ # self.merged = False
127
+ # else:
128
+ # if self.merge_weights and not self.merged:
129
+ # if self.r > 0:
130
+ # # Merge the weights and mark it
131
+ # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
132
+ # self.merged = True
133
+
134
+ def forward(self, x):
135
+ if self.r > 0 and not self.merged:
136
+ # return self.conv._conv_forward(
137
+ # x,
138
+ # self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
139
+ # self.conv.bias
140
+ # )
141
+ weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
142
+ bias = self.bias
143
+
144
+ return F.conv2d(x, weight, bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
145
+ else:
146
+ return F.conv2d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
147
+
148
+ class ConvTransposeLoRA(nn.ConvTranspose2d, LoRALayer):
149
+ def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
150
+ #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
151
+ nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
152
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
153
+ assert isinstance(kernel_size, int)
154
+
155
+ # Actual trainable parameters
156
+ if r > 0:
157
+ self.lora_A = nn.Parameter(
158
+ self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
159
+ )
160
+ self.lora_B = nn.Parameter(
161
+ self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
162
+ )
163
+ self.scaling = self.lora_alpha / self.r
164
+ # Freezing the pre-trained weight matrix
165
+ self.weight.requires_grad = False
166
+ self.reset_parameters()
167
+ self.merged = False
168
+
169
+ def reset_parameters(self):
170
+ #self.conv.reset_parameters()
171
+ if hasattr(self, 'lora_A'):
172
+ # initialize A the same way as the default for nn.Linear and B to zero
173
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
174
+ nn.init.zeros_(self.lora_B)
175
+
176
+ # def train(self, mode=True):
177
+ # super(ConvTransposeLoRA, self).train(mode)
178
+ # if mode:
179
+ # if self.merge_weights and self.merged:
180
+ # if self.r > 0:
181
+ # # Make sure that the weights are not merged
182
+ # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
183
+ # self.merged = False
184
+ # else:
185
+ # if self.merge_weights and not self.merged:
186
+ # if self.r > 0:
187
+ # # Merge the weights and mark it
188
+ # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
189
+ # self.merged = True
190
+
191
+ def forward(self, x):
192
+ if self.r > 0 and not self.merged:
193
+ weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
194
+ bias = self.bias
195
+ return F.conv_transpose2d(x, weight,
196
+ bias=bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding,
197
+ groups=self.groups, dilation=self.dilation)
198
+ else:
199
+ return F.conv_transpose2d(x, self.weight,
200
+ bias=self.bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding,
201
+ groups=self.groups, dilation=self.dilation)
202
+ #return self.conv(x)
203
+
204
+ class Conv2dLoRA(ConvLoRA):
205
+ def __init__(self, *args, **kwargs):
206
+ super(Conv2dLoRA, self).__init__(*args, **kwargs)
207
+
208
+ class ConvTranspose2dLoRA(ConvTransposeLoRA):
209
+ def __init__(self, *args, **kwargs):
210
+ super(ConvTranspose2dLoRA, self).__init__(*args, **kwargs)
211
+
212
+
213
+ def compute_depth_expectation(prob, depth_values):
214
+ depth_values = depth_values.view(*depth_values.shape, 1, 1)
215
+ depth = torch.sum(prob * depth_values, 1)
216
+ return depth
217
+
218
+ def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
219
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
220
+ return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
221
+
222
+ # def upflow8(flow, mode='bilinear'):
223
+ # new_size = (8 * flow.shape[2], 8 * flow.shape[3])
224
+ # return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
225
+
226
+ def upflow4(flow, mode='bilinear'):
227
+ new_size = (4 * flow.shape[2], 4 * flow.shape[3])
228
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
229
+ return F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
230
+
231
+ def coords_grid(batch, ht, wd):
232
+ # coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
233
+ coords = (torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)))
234
+ coords = torch.stack(coords[::-1], dim=0).float()
235
+ return coords[None].repeat(batch, 1, 1, 1)
236
+
237
+ def norm_normalize(norm_out):
238
+ min_kappa = 0.01
239
+ norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
240
+ norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
241
+ kappa = F.elu(kappa) + 1.0 + min_kappa
242
+ final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
243
+ return final_out
244
+
245
+ # uncertainty-guided sampling (only used during training)
246
+ @torch.no_grad()
247
+ def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
248
+ device = init_normal.device
249
+ B, _, H, W = init_normal.shape
250
+ N = int(sampling_ratio * H * W)
251
+ beta = beta
252
+
253
+ # uncertainty map
254
+ uncertainty_map = -1 * init_normal[:, -1, :, :] # B, H, W
255
+
256
+ # gt_invalid_mask (B, H, W)
257
+ if gt_norm_mask is not None:
258
+ gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
259
+ gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
260
+ uncertainty_map[gt_invalid_mask] = -1e4
261
+
262
+ # (B, H*W)
263
+ _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
264
+
265
+ # importance sampling
266
+ if int(beta * N) > 0:
267
+ importance = idx[:, :int(beta * N)] # B, beta*N
268
+
269
+ # remaining
270
+ remaining = idx[:, int(beta * N):] # B, H*W - beta*N
271
+
272
+ # coverage
273
+ num_coverage = N - int(beta * N)
274
+
275
+ if num_coverage <= 0:
276
+ samples = importance
277
+ else:
278
+ coverage_list = []
279
+ for i in range(B):
280
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
281
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
282
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
283
+ samples = torch.cat((importance, coverage), dim=1) # B, N
284
+
285
+ else:
286
+ # remaining
287
+ remaining = idx[:, :] # B, H*W
288
+
289
+ # coverage
290
+ num_coverage = N
291
+
292
+ coverage_list = []
293
+ for i in range(B):
294
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
295
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
296
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
297
+ samples = coverage
298
+
299
+ # point coordinates
300
+ rows_int = samples // W # 0 for first row, H-1 for last row
301
+ rows_float = rows_int / float(H-1) # 0 to 1.0
302
+ rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
303
+
304
+ cols_int = samples % W # 0 for first column, W-1 for last column
305
+ cols_float = cols_int / float(W-1) # 0 to 1.0
306
+ cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
307
+
308
+ point_coords = torch.zeros(B, 1, N, 2)
309
+ point_coords[:, 0, :, 0] = cols_float # x coord
310
+ point_coords[:, 0, :, 1] = rows_float # y coord
311
+ point_coords = point_coords.to(device)
312
+ return point_coords, rows_int, cols_int
313
+
314
+ class FlowHead(nn.Module):
315
+ def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4, tuning_mode=None):
316
+ super(FlowHead, self).__init__()
317
+ self.conv1d = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
318
+ self.conv2d = Conv2dLoRA(hidden_dim // 2, output_dim_depth, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
319
+
320
+ self.conv1n = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
321
+ self.conv2n = Conv2dLoRA(hidden_dim // 2, output_dim_norm, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
322
+ self.relu = nn.ReLU(inplace=True)
323
+
324
+ def forward(self, x):
325
+ depth = self.conv2d(self.relu(self.conv1d(x)))
326
+ normal = self.conv2n(self.relu(self.conv1n(x)))
327
+ return torch.cat((depth, normal), dim=1)
328
+
329
+
330
+ class ConvGRU(nn.Module):
331
+ def __init__(self, hidden_dim, input_dim, kernel_size=3, tuning_mode=None):
332
+ super(ConvGRU, self).__init__()
333
+ self.convz = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
334
+ self.convr = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
335
+ self.convq = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
336
+
337
+ def forward(self, h, cz, cr, cq, *x_list):
338
+ x = torch.cat(x_list, dim=1)
339
+ hx = torch.cat([h, x], dim=1)
340
+
341
+ z = torch.sigmoid((self.convz(hx) + cz))
342
+ r = torch.sigmoid((self.convr(hx) + cr))
343
+ q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq))
344
+
345
+ # z = torch.sigmoid((self.convz(hx) + cz).float())
346
+ # r = torch.sigmoid((self.convr(hx) + cr).float())
347
+ # q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq).float())
348
+
349
+ h = (1-z) * h + z * q
350
+ return h
351
+
352
+ def pool2x(x):
353
+ return F.avg_pool2d(x, 3, stride=2, padding=1)
354
+
355
+ def pool4x(x):
356
+ return F.avg_pool2d(x, 5, stride=4, padding=1)
357
+
358
+ def interp(x, dest):
359
+ interp_args = {'mode': 'bilinear', 'align_corners': True}
360
+ return interpolate_float32(x, dest.shape[2:], **interp_args)
361
+
362
+ class BasicMultiUpdateBlock(nn.Module):
363
+ def __init__(self, args, hidden_dims=[], out_dims=2, tuning_mode=None):
364
+ super().__init__()
365
+ self.args = args
366
+ self.n_gru_layers = args.model.decode_head.n_gru_layers # 3
367
+ self.n_downsample = args.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
368
+
369
+ # self.encoder = BasicMotionEncoder(args)
370
+ # encoder_output_dim = 128 # if there is corr volume
371
+ encoder_output_dim = 6 # no corr volume
372
+
373
+ self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1), tuning_mode=tuning_mode)
374
+ self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2], tuning_mode=tuning_mode)
375
+ self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1], tuning_mode=tuning_mode)
376
+ self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2], tuning_mode=tuning_mode)
377
+ factor = 2**self.n_downsample
378
+
379
+ self.mask = nn.Sequential(
380
+ Conv2dLoRA(hidden_dims[2], hidden_dims[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0),
381
+ nn.ReLU(inplace=True),
382
+ Conv2dLoRA(hidden_dims[2], (factor**2)*9, 1, padding=0, r = 8 if tuning_mode == 'lora' else 0))
383
+
384
+ def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True):
385
+
386
+ if iter32:
387
+ net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1]))
388
+ if iter16:
389
+ if self.n_gru_layers > 2:
390
+ net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]), interp(net[2], net[1]))
391
+ else:
392
+ net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]))
393
+ if iter08:
394
+ if corr is not None:
395
+ motion_features = self.encoder(flow, corr)
396
+ else:
397
+ motion_features = flow
398
+ if self.n_gru_layers > 1:
399
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0]))
400
+ else:
401
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features)
402
+
403
+ if not update:
404
+ return net
405
+
406
+ delta_flow = self.flow_head(net[0])
407
+
408
+ # scale mask to balence gradients
409
+ mask = .25 * self.mask(net[0])
410
+ return net, mask, delta_flow
411
+
412
+ class LayerNorm2d(nn.LayerNorm):
413
+ def __init__(self, dim):
414
+ super(LayerNorm2d, self).__init__(dim)
415
+
416
+ def forward(self, x):
417
+ x = x.permute(0, 2, 3, 1).contiguous()
418
+ x = super(LayerNorm2d, self).forward(x)
419
+ x = x.permute(0, 3, 1, 2).contiguous()
420
+ return x
421
+
422
+ class ResidualBlock(nn.Module):
423
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1, tuning_mode=None):
424
+ super(ResidualBlock, self).__init__()
425
+
426
+ self.conv1 = Conv2dLoRA(in_planes, planes, kernel_size=3, padding=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0)
427
+ self.conv2 = Conv2dLoRA(planes, planes, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
428
+ self.relu = nn.ReLU(inplace=True)
429
+
430
+ num_groups = planes // 8
431
+
432
+ if norm_fn == 'group':
433
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
434
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
435
+ if not (stride == 1 and in_planes == planes):
436
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
437
+
438
+ elif norm_fn == 'batch':
439
+ self.norm1 = nn.BatchNorm2d(planes)
440
+ self.norm2 = nn.BatchNorm2d(planes)
441
+ if not (stride == 1 and in_planes == planes):
442
+ self.norm3 = nn.BatchNorm2d(planes)
443
+
444
+ elif norm_fn == 'instance':
445
+ self.norm1 = nn.InstanceNorm2d(planes)
446
+ self.norm2 = nn.InstanceNorm2d(planes)
447
+ if not (stride == 1 and in_planes == planes):
448
+ self.norm3 = nn.InstanceNorm2d(planes)
449
+
450
+ elif norm_fn == 'layer':
451
+ self.norm1 = LayerNorm2d(planes)
452
+ self.norm2 = LayerNorm2d(planes)
453
+ if not (stride == 1 and in_planes == planes):
454
+ self.norm3 = LayerNorm2d(planes)
455
+
456
+ elif norm_fn == 'none':
457
+ self.norm1 = nn.Sequential()
458
+ self.norm2 = nn.Sequential()
459
+ if not (stride == 1 and in_planes == planes):
460
+ self.norm3 = nn.Sequential()
461
+
462
+ if stride == 1 and in_planes == planes:
463
+ self.downsample = None
464
+
465
+ else:
466
+ self.downsample = nn.Sequential(
467
+ Conv2dLoRA(in_planes, planes, kernel_size=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0), self.norm3)
468
+
469
+ def forward(self, x):
470
+ y = x
471
+ y = self.conv1(y)
472
+ y = self.norm1(y)
473
+ y = self.relu(y)
474
+ y = self.conv2(y)
475
+ y = self.norm2(y)
476
+ y = self.relu(y)
477
+
478
+ if self.downsample is not None:
479
+ x = self.downsample(x)
480
+
481
+ return self.relu(x+y)
482
+
483
+
484
+ class ContextFeatureEncoder(nn.Module):
485
+ '''
486
+ Encoder features are used to:
487
+ 1. initialize the hidden state of the update operator
488
+ 2. and also injected into the GRU during each iteration of the update operator
489
+ '''
490
+ def __init__(self, in_dim, output_dim, tuning_mode=None):
491
+ '''
492
+ in_dim = [x4, x8, x16, x32]
493
+ output_dim = [hindden_dims, context_dims]
494
+ [[x4,x8,x16,x32],[x4,x8,x16,x32]]
495
+ '''
496
+ super().__init__()
497
+
498
+ output_list = []
499
+ for dim in output_dim:
500
+ conv_out = nn.Sequential(
501
+ ResidualBlock(in_dim[0], dim[0], 'layer', stride=1, tuning_mode=tuning_mode),
502
+ Conv2dLoRA(dim[0], dim[0], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
503
+ output_list.append(conv_out)
504
+
505
+ self.outputs04 = nn.ModuleList(output_list)
506
+
507
+ output_list = []
508
+ for dim in output_dim:
509
+ conv_out = nn.Sequential(
510
+ ResidualBlock(in_dim[1], dim[1], 'layer', stride=1, tuning_mode=tuning_mode),
511
+ Conv2dLoRA(dim[1], dim[1], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
512
+ output_list.append(conv_out)
513
+
514
+ self.outputs08 = nn.ModuleList(output_list)
515
+
516
+ output_list = []
517
+ for dim in output_dim:
518
+ conv_out = nn.Sequential(
519
+ ResidualBlock(in_dim[2], dim[2], 'layer', stride=1, tuning_mode=tuning_mode),
520
+ Conv2dLoRA(dim[2], dim[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
521
+ output_list.append(conv_out)
522
+
523
+ self.outputs16 = nn.ModuleList(output_list)
524
+
525
+ # output_list = []
526
+ # for dim in output_dim:
527
+ # conv_out = Conv2dLoRA(in_dim[3], dim[3], 3, padding=1)
528
+ # output_list.append(conv_out)
529
+
530
+ # self.outputs32 = nn.ModuleList(output_list)
531
+
532
+ def forward(self, encoder_features):
533
+ x_4, x_8, x_16, x_32 = encoder_features
534
+
535
+ outputs04 = [f(x_4) for f in self.outputs04]
536
+ outputs08 = [f(x_8) for f in self.outputs08]
537
+ outputs16 = [f(x_16)for f in self.outputs16]
538
+ # outputs32 = [f(x_32) for f in self.outputs32]
539
+
540
+ return (outputs04, outputs08, outputs16)
541
+
542
+ class ConvBlock(nn.Module):
543
+ # reimplementation of DPT
544
+ def __init__(self, channels, tuning_mode=None):
545
+ super(ConvBlock, self).__init__()
546
+
547
+ self.act = nn.ReLU(inplace=True)
548
+ self.conv1 = Conv2dLoRA(
549
+ channels,
550
+ channels,
551
+ kernel_size=3,
552
+ stride=1,
553
+ padding=1,
554
+ r = 8 if tuning_mode == 'lora' else 0
555
+ )
556
+ self.conv2 = Conv2dLoRA(
557
+ channels,
558
+ channels,
559
+ kernel_size=3,
560
+ stride=1,
561
+ padding=1,
562
+ r = 8 if tuning_mode == 'lora' else 0
563
+ )
564
+
565
+ def forward(self, x):
566
+ out = self.act(x)
567
+ out = self.conv1(out)
568
+ out = self.act(out)
569
+ out = self.conv2(out)
570
+ return x + out
571
+
572
+ class FuseBlock(nn.Module):
573
+ # reimplementation of DPT
574
+ def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2, tuning_mode=None):
575
+ super(FuseBlock, self).__init__()
576
+
577
+ self.fuse = fuse
578
+ self.scale_factor = scale_factor
579
+ self.way_trunk = ConvBlock(in_channels, tuning_mode=tuning_mode)
580
+ if self.fuse:
581
+ self.way_branch = ConvBlock(in_channels, tuning_mode=tuning_mode)
582
+
583
+ self.out_conv = Conv2dLoRA(
584
+ in_channels,
585
+ out_channels,
586
+ kernel_size=1,
587
+ stride=1,
588
+ padding=0,
589
+ r = 8 if tuning_mode == 'lora' else 0
590
+ )
591
+ self.upsample = upsample
592
+
593
+ def forward(self, x1, x2=None):
594
+ if x2 is not None:
595
+ x2 = self.way_branch(x2)
596
+ x1 = x1 + x2
597
+
598
+ out = self.way_trunk(x1)
599
+
600
+ if self.upsample:
601
+ out = interpolate_float32(
602
+ out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
603
+ )
604
+ out = self.out_conv(out)
605
+ return out
606
+
607
+ class Readout(nn.Module):
608
+ # From DPT
609
+ def __init__(self, in_features, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
610
+ super(Readout, self).__init__()
611
+ self.use_cls_token = use_cls_token
612
+ if self.use_cls_token == True:
613
+ self.project_patch = LoRALinear(in_features, in_features, r = 8 if tuning_mode == 'lora' else 0)
614
+ self.project_learn = LoRALinear((1 + num_register_tokens) * in_features, in_features, bias=False, r = 8 if tuning_mode == 'lora' else 0)
615
+ self.act = nn.GELU()
616
+ else:
617
+ self.project = nn.Identity()
618
+
619
+ def forward(self, x):
620
+
621
+ if self.use_cls_token == True:
622
+ x_patch = self.project_patch(x[0])
623
+ x_learn = self.project_learn(x[1])
624
+ x_learn = x_learn.expand_as(x_patch).contiguous()
625
+ features = x_patch + x_learn
626
+ return self.act(features)
627
+ else:
628
+ return self.project(x)
629
+
630
+ class Token2Feature(nn.Module):
631
+ # From DPT
632
+ def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
633
+ super(Token2Feature, self).__init__()
634
+ self.scale_factor = scale_factor
635
+ self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
636
+ if scale_factor > 1 and isinstance(scale_factor, int):
637
+ self.sample = ConvTranspose2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
638
+ in_channels=vit_channel,
639
+ out_channels=feature_channel,
640
+ kernel_size=scale_factor,
641
+ stride=scale_factor,
642
+ padding=0,
643
+ )
644
+
645
+ elif scale_factor > 1:
646
+ self.sample = nn.Sequential(
647
+ # Upsample2(upscale=scale_factor),
648
+ # nn.Upsample(scale_factor=scale_factor),
649
+ Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
650
+ in_channels=vit_channel,
651
+ out_channels=feature_channel,
652
+ kernel_size=1,
653
+ stride=1,
654
+ padding=0,
655
+ ),
656
+ )
657
+
658
+
659
+ elif scale_factor < 1:
660
+ scale_factor = int(1.0 / scale_factor)
661
+ self.sample = Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
662
+ in_channels=vit_channel,
663
+ out_channels=feature_channel,
664
+ kernel_size=scale_factor+1,
665
+ stride=scale_factor,
666
+ padding=1,
667
+ )
668
+
669
+ else:
670
+ self.sample = nn.Identity()
671
+
672
+ def forward(self, x):
673
+ x = self.readoper(x)
674
+ #if use_cls_token == True:
675
+ x = x.permute(0, 3, 1, 2).contiguous()
676
+ if isinstance(self.scale_factor, float):
677
+ x = interpolate_float32(x.float(), scale_factor=self.scale_factor, mode='nearest')
678
+ x = self.sample(x)
679
+ return x
680
+
681
+ class EncoderFeature(nn.Module):
682
+ def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0, tuning_mode=None):
683
+ super(EncoderFeature, self).__init__()
684
+ self.vit_channel = vit_channel
685
+ self.num_ch_dec = num_ch_dec
686
+
687
+ self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
688
+ self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
689
+ self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
690
+ self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
691
+
692
+ def forward(self, ref_feature):
693
+ x = self.read_3(ref_feature[3]) # 1/14
694
+ x2 = self.read_2(ref_feature[2]) # 1/14
695
+ x1 = self.read_1(ref_feature[1]) # 1/7
696
+ x0 = self.read_0(ref_feature[0]) # 1/4
697
+
698
+ return x, x2, x1, x0
699
+
700
+ class DecoderFeature(nn.Module):
701
+ def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True, tuning_mode=None):
702
+ super(DecoderFeature, self).__init__()
703
+ self.vit_channel = vit_channel
704
+ self.num_ch_dec = num_ch_dec
705
+
706
+ self.upconv_3 = FuseBlock(
707
+ self.num_ch_dec[4],
708
+ self.num_ch_dec[3],
709
+ fuse=False, upsample=False, tuning_mode=tuning_mode)
710
+
711
+ self.upconv_2 = FuseBlock(
712
+ self.num_ch_dec[3],
713
+ self.num_ch_dec[2],
714
+ tuning_mode=tuning_mode)
715
+
716
+ self.upconv_1 = FuseBlock(
717
+ self.num_ch_dec[2],
718
+ self.num_ch_dec[1] + 2,
719
+ scale_factor=7/4,
720
+ tuning_mode=tuning_mode)
721
+
722
+ # self.upconv_0 = FuseBlock(
723
+ # self.num_ch_dec[1],
724
+ # self.num_ch_dec[0] + 1,
725
+ # )
726
+
727
+ def forward(self, ref_feature):
728
+ x, x2, x1, x0 = ref_feature # 1/14 1/14 1/7 1/4
729
+
730
+ x = self.upconv_3(x) # 1/14
731
+ x = self.upconv_2(x, x2) # 1/7
732
+ x = self.upconv_1(x, x1) # 1/4
733
+ # x = self.upconv_0(x, x0) # 4/7
734
+ return x
735
+
736
+ class RAFTDepthNormalDPT5(nn.Module):
737
+ def __init__(self, cfg):
738
+ super().__init__()
739
+ self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024]
740
+ self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14]
741
+ self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14]
742
+ self.use_cls_token = cfg.model.decode_head.use_cls_token
743
+ self.up_scale = cfg.model.decode_head.up_scale
744
+ self.num_register_tokens = cfg.model.decode_head.num_register_tokens
745
+ self.min_val = cfg.data_basic.depth_normalize[0]
746
+ self.max_val = cfg.data_basic.depth_normalize[1]
747
+ self.regress_scale = 100.0\
748
+
749
+ try:
750
+ tuning_mode = cfg.model.decode_head.tuning_mode
751
+ except:
752
+ tuning_mode = None
753
+ self.tuning_mode = tuning_mode
754
+
755
+ self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128]
756
+ self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3
757
+ self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
758
+ self.iters = cfg.model.decode_head.iters # 22
759
+ self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True
760
+
761
+ self.num_depth_regressor_anchor = 256 # 512
762
+ self.used_res_channel = self.decoder_channels[1] # now, use 2/7 res
763
+ self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens, tuning_mode=tuning_mode)
764
+ self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels, tuning_mode=tuning_mode)
765
+ self.depth_regressor = nn.Sequential(
766
+ Conv2dLoRA(self.used_res_channel,
767
+ self.num_depth_regressor_anchor,
768
+ kernel_size=3,
769
+ padding=1, r = 8 if tuning_mode == 'lora' else 0),
770
+ # nn.BatchNorm2d(self.num_depth_regressor_anchor),
771
+ nn.ReLU(inplace=True),
772
+ Conv2dLoRA(self.num_depth_regressor_anchor,
773
+ self.num_depth_regressor_anchor,
774
+ kernel_size=1, r = 8 if tuning_mode == 'lora' else 0),
775
+ )
776
+ self.normal_predictor = nn.Sequential(
777
+ Conv2dLoRA(self.used_res_channel,
778
+ 128,
779
+ kernel_size=3,
780
+ padding=1, r = 8 if tuning_mode == 'lora' else 0,),
781
+ # nn.BatchNorm2d(128),
782
+ nn.ReLU(inplace=True),
783
+ Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True),
784
+ Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True),
785
+ Conv2dLoRA(128, 3, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0),
786
+ )
787
+
788
+ self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims], tuning_mode=tuning_mode)
789
+ self.context_zqr_convs = nn.ModuleList([Conv2dLoRA(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2, r = 8 if tuning_mode == 'lora' else 0) for i in range(self.n_gru_layers)])
790
+ self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6, tuning_mode=tuning_mode)
791
+
792
+ self.relu = nn.ReLU(inplace=True)
793
+
794
+ def get_bins(self, bins_num):
795
+ depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda")
796
+ depth_bins_vec = torch.exp(depth_bins_vec)
797
+ return depth_bins_vec
798
+
799
+ def register_depth_expectation_anchor(self, bins_num, B):
800
+ depth_bins_vec = self.get_bins(bins_num)
801
+ depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
802
+ self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)
803
+
804
+ def clamp(self, x):
805
+ y = self.relu(x - self.min_val) + self.min_val
806
+ y = self.max_val - self.relu(self.max_val - y)
807
+ return y
808
+
809
+ def regress_depth(self, feature_map_d):
810
+ prob_feature = self.depth_regressor(feature_map_d)
811
+ prob = prob_feature.softmax(dim=1)
812
+ #prob = prob_feature.float().softmax(dim=1)
813
+
814
+ ## Error logging
815
+ if torch.isnan(prob).any():
816
+ print('prob_feat_nan!!!')
817
+ if torch.isinf(prob).any():
818
+ print('prob_feat_inf!!!')
819
+
820
+ # h = prob[0,:,0,0].cpu().numpy().reshape(-1)
821
+ # import matplotlib.pyplot as plt
822
+ # plt.bar(range(len(h)), h)
823
+ B = prob.shape[0]
824
+ if "depth_expectation_anchor" not in self._buffers:
825
+ self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
826
+ d = compute_depth_expectation(
827
+ prob,
828
+ self.depth_expectation_anchor[:B, ...]).unsqueeze(1)
829
+
830
+ ## Error logging
831
+ if torch.isnan(d ).any():
832
+ print('d_nan!!!')
833
+ if torch.isinf(d ).any():
834
+ print('d_inf!!!')
835
+
836
+ return (self.clamp(d) - self.max_val)/ self.regress_scale, prob_feature
837
+
838
+ def pred_normal(self, feature_map, confidence):
839
+ normal_out = self.normal_predictor(feature_map)
840
+
841
+ ## Error logging
842
+ if torch.isnan(normal_out).any():
843
+ print('norm_nan!!!')
844
+ if torch.isinf(normal_out).any():
845
+ print('norm_feat_inf!!!')
846
+
847
+ return norm_normalize(torch.cat([normal_out, confidence], dim=1))
848
+ #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
849
+
850
+ def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
851
+ y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
852
+ torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
853
+ meshgrid = torch.stack((x, y))
854
+ meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
855
+ #self.register_buffer('meshgrid', meshgrid, persistent=False)
856
+ return meshgrid
857
+
858
+ def upsample_flow(self, flow, mask):
859
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
860
+ N, D, H, W = flow.shape
861
+ factor = 2 ** self.n_downsample
862
+ mask = mask.view(N, 1, 9, factor, factor, H, W)
863
+ mask = torch.softmax(mask, dim=2)
864
+ #mask = torch.softmax(mask.float(), dim=2)
865
+
866
+ #up_flow = F.unfold(factor * flow, [3,3], padding=1)
867
+ up_flow = F.unfold(flow, [3,3], padding=1)
868
+ up_flow = up_flow.view(N, D, 9, 1, 1, H, W)
869
+
870
+ up_flow = torch.sum(mask * up_flow, dim=2)
871
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
872
+ return up_flow.reshape(N, D, factor*H, factor*W)
873
+
874
+ def initialize_flow(self, img):
875
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
876
+ N, _, H, W = img.shape
877
+
878
+ coords0 = coords_grid(N, H, W).to(img.device)
879
+ coords1 = coords_grid(N, H, W).to(img.device)
880
+
881
+ return coords0, coords1
882
+
883
+ def upsample(self, x, scale_factor=2):
884
+ """Upsample input tensor by a factor of 2
885
+ """
886
+ return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest")
887
+
888
+ def forward(self, vit_features, **kwargs):
889
+ ## read vit token to multi-scale features
890
+ B, H, W, _, _, num_register_tokens = vit_features[1]
891
+ vit_features = vit_features[0]
892
+
893
+ ## Error logging
894
+ if torch.isnan(vit_features[0]).any():
895
+ print('vit_feature_nan!!!')
896
+ if torch.isinf(vit_features[0]).any():
897
+ print('vit_feature_inf!!!')
898
+
899
+ if self.use_cls_token == True:
900
+ vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]), \
901
+ ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))] for ft in vit_features]
902
+ else:
903
+ vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features]
904
+ encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4
905
+
906
+ ## Error logging
907
+ for en_ft in encoder_features:
908
+ if torch.isnan(en_ft).any():
909
+ print('decoder_feature_nan!!!')
910
+ print(en_ft.shape)
911
+ if torch.isinf(en_ft).any():
912
+ print('decoder_feature_inf!!!')
913
+ print(en_ft.shape)
914
+
915
+ ## decode features to init-depth (and confidence)
916
+ ref_feat= self.decoder_mono(encoder_features) # now, 1/4 for depth
917
+
918
+ ## Error logging
919
+ if torch.isnan(ref_feat).any():
920
+ print('ref_feat_nan!!!')
921
+ if torch.isinf(ref_feat).any():
922
+ print('ref_feat_inf!!!')
923
+
924
+ feature_map = ref_feat[:, :-2, :, :] # feature map share of depth and normal prediction
925
+ depth_confidence_map = ref_feat[:, -2:-1, :, :]
926
+ normal_confidence_map = ref_feat[:, -1:, :, :]
927
+ depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth
928
+ normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal
929
+
930
+ depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W)
931
+
932
+ ## encoder features to context-feature for init-hidden-state and contex-features
933
+ cnet_list = self.context_feature_encoder(encoder_features[::-1])
934
+ net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state
935
+ inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features
936
+
937
+ # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
938
+ inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]
939
+
940
+ coords0, coords1 = self.initialize_flow(net_list[0])
941
+ if depth_init is not None:
942
+ coords1 = coords1 + depth_init
943
+
944
+ if self.training:
945
+ low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())]
946
+ init_depth = upflow4(depth_init)
947
+ flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)]
948
+ conf_predictions = [init_depth[:,1:2]]
949
+ normal_outs = [norm_normalize(init_depth[:,2:].clone())]
950
+
951
+ else:
952
+ flow_predictions = []
953
+ conf_predictions = []
954
+ samples_pred_list = []
955
+ coord_list = []
956
+ normal_outs = []
957
+ low_resolution_init = []
958
+
959
+ for itr in range(self.iters):
960
+ # coords1 = coords1.detach()
961
+ flow = coords1 - coords0
962
+ if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU
963
+ net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
964
+ if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU
965
+ net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False)
966
+ net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2)
967
+
968
+ # F(t+1) = F(t) + \Delta(t)
969
+ coords1 = coords1 + delta_flow
970
+
971
+ # We do not need to upsample or output intermediate results in test_mode
972
+ #if (not self.training) and itr < self.iters-1:
973
+ #continue
974
+
975
+ # upsample predictions
976
+ if up_mask is None:
977
+ flow_up = self.upsample(coords1-coords0, 4)
978
+ else:
979
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
980
+ # flow_up = self.upsample(coords1-coords0, 4)
981
+
982
+ flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val))
983
+ conf_predictions.append(flow_up[:,1:2])
984
+ normal_outs.append(norm_normalize(flow_up[:,2:].clone()))
985
+
986
+ outputs=dict(
987
+ prediction=flow_predictions[-1],
988
+ predictions_list=flow_predictions,
989
+ confidence=conf_predictions[-1],
990
+ confidence_list=conf_predictions,
991
+ pred_logit=None,
992
+ # samples_pred_list=samples_pred_list,
993
+ # coord_list=coord_list,
994
+ prediction_normal=normal_outs[-1],
995
+ normal_out_list=normal_outs,
996
+ low_resolution_init=low_resolution_init,
997
+ )
998
+
999
+ return outputs
1000
+
1001
+
1002
+ if __name__ == "__main__":
1003
+ try:
1004
+ from mmcv.utils import Config
1005
+ except:
1006
+ from mmengine import Config
1007
+ cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py')
1008
+ cfg.model.decode_head.in_channels = [384, 384, 384, 384]
1009
+ cfg.model.decode_head.feature_channels = [96, 192, 384, 768]
1010
+ cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384]
1011
+ cfg.model.decode_head.hidden_channels = [48, 48, 48, 48, 48]
1012
+ cfg.model.decode_head.up_scale = 7
1013
+
1014
+ # cfg.model.decode_head.use_cls_token = True
1015
+ # vit_feature = [[torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
1016
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
1017
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
1018
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()]]
1019
+
1020
+ cfg.model.decode_head.use_cls_token = True
1021
+ cfg.model.decode_head.num_register_tokens = 4
1022
+ vit_feature = [[torch.rand((2, (74 * 74) + 5, 384)).cuda(),\
1023
+ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
1024
+ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
1025
+ torch.rand((2, (74 * 74) + 5, 384)).cuda()], (2, 74, 74, 1036, 1036, 4)]
1026
+
1027
+ decoder = RAFTDepthNormalDPT5(cfg).cuda()
1028
+ output = decoder(vit_feature)
1029
+ temp = 1
1030
+
1031
+
1032
+
1033
+
mono/model/decode_heads/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .HourGlassDecoder import HourglassDecoder
2
+ from .RAFTDepthNormalDPTDecoder5 import RAFTDepthNormalDPT5
3
+
4
+ __all__=['HourglassDecoder', 'RAFTDepthNormalDPT5']
mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc ADDED
Binary file (8.65 kB). View file
 
mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (252 Bytes). View file
 
mono/model/model_pipelines/__base_model__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from mono.utils.comm import get_func
4
+
5
+
6
+ class BaseDepthModel(nn.Module):
7
+ def __init__(self, cfg, **kwargs) -> None:
8
+ super(BaseDepthModel, self).__init__()
9
+ model_type = cfg.model.type
10
+ self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg)
11
+
12
+ def forward(self, data):
13
+ output = self.depth_model(**data)
14
+
15
+ return output['prediction'], output['confidence'], output
16
+
17
+ def inference(self, data):
18
+ with torch.no_grad():
19
+ pred_depth, confidence, _ = self.forward(data)
20
+ return pred_depth, confidence
mono/model/model_pipelines/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ from .dense_pipeline import DensePredModel
3
+ from .__base_model__ import BaseDepthModel
4
+ __all__ = [
5
+ 'DensePredModel', 'BaseDepthModel',
6
+ ]
mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc ADDED
Binary file (1.19 kB). View file
 
mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (313 Bytes). View file
 
mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc ADDED
Binary file (1.01 kB). View file
 
mono/model/model_pipelines/dense_pipeline.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from mono.utils.comm import get_func
4
+
5
+ class DensePredModel(nn.Module):
6
+ def __init__(self, cfg) -> None:
7
+ super(DensePredModel, self).__init__()
8
+
9
+ self.encoder = get_func('mono.model.' + cfg.model.backbone.prefix + cfg.model.backbone.type)(**cfg.model.backbone)
10
+ self.decoder = get_func('mono.model.' + cfg.model.decode_head.prefix + cfg.model.decode_head.type)(cfg)
11
+
12
+ def forward(self, input, **kwargs):
13
+ # [f_32, f_16, f_8, f_4]
14
+ features = self.encoder(input)
15
+ out = self.decoder(features, **kwargs)
16
+ return out
mono/model/monodepth_model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .model_pipelines.__base_model__ import BaseDepthModel
4
+
5
+ class DepthModel(BaseDepthModel):
6
+ def __init__(self, cfg, **kwards):
7
+ super(DepthModel, self).__init__(cfg)
8
+ model_type = cfg.model.type
9
+
10
+ def inference(self, data):
11
+ with torch.no_grad():
12
+ pred_depth, confidence, output_dict = self.forward(data)
13
+ return pred_depth, confidence, output_dict
14
+
15
+ def get_monodepth_model(
16
+ cfg : dict,
17
+ **kwargs
18
+ ) -> nn.Module:
19
+ # config depth model
20
+ model = DepthModel(cfg, **kwargs)
21
+ #model.init_weights(load_imagenet_model, imagenet_ckpt_fpath)
22
+ assert isinstance(model, nn.Module)
23
+ return model
24
+
25
+ def get_configured_monodepth_model(
26
+ cfg: dict,
27
+ ) -> nn.Module:
28
+ """
29
+ Args:
30
+ @ configs: configures for the network.
31
+ @ load_imagenet_model: whether to initialize from ImageNet-pretrained model.
32
+ @ imagenet_ckpt_fpath: string representing path to file with weights to initialize model with.
33
+ Returns:
34
+ # model: depth model.
35
+ """
36
+ model = get_monodepth_model(cfg)
37
+ return model
mono/tools/test_scale_cano.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import cv2
4
+ import time
5
+ import sys
6
+ CODE_SPACE=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
+ sys.path.append(CODE_SPACE)
8
+ import argparse
9
+ import mmcv
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.multiprocessing as mp
13
+
14
+ try:
15
+ from mmcv.utils import Config, DictAction
16
+ except:
17
+ from mmengine import Config, DictAction
18
+ from datetime import timedelta
19
+ import random
20
+ import numpy as np
21
+ from mono.utils.logger import setup_logger
22
+ import glob
23
+ from mono.utils.comm import init_env
24
+ from mono.model.monodepth_model import get_configured_monodepth_model
25
+ from mono.utils.running import load_ckpt
26
+ from mono.utils.do_test import do_scalecano_test_with_custom_data
27
+ from mono.utils.mldb import load_data_info, reset_ckpt_path
28
+ from mono.utils.custom_data import load_from_annos, load_data
29
+
30
+ def parse_args():
31
+ parser = argparse.ArgumentParser(description='Train a segmentor')
32
+ parser.add_argument('config', help='train config file path')
33
+ parser.add_argument('--show-dir', help='the dir to save logs and visualization results')
34
+ parser.add_argument('--load-from', help='the checkpoint file to load weights from')
35
+ parser.add_argument('--node_rank', type=int, default=0)
36
+ parser.add_argument('--nnodes', type=int, default=1, help='number of nodes')
37
+ parser.add_argument('--options', nargs='+', action=DictAction, help='custom options')
38
+ parser.add_argument('--launcher', choices=['None', 'pytorch', 'slurm', 'mpi', 'ror'], default='slurm', help='job launcher')
39
+ parser.add_argument('--test_data_path', default='None', type=str, help='the path of test data')
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+ def main(args):
44
+ os.chdir(CODE_SPACE)
45
+ cfg = Config.fromfile(args.config)
46
+
47
+ if args.options is not None:
48
+ cfg.merge_from_dict(args.options)
49
+
50
+ # show_dir is determined in this priority: CLI > segment in file > filename
51
+ if args.show_dir is not None:
52
+ # update configs according to CLI args if args.show_dir is not None
53
+ cfg.show_dir = args.show_dir
54
+ else:
55
+ # use condig filename + timestamp as default show_dir if args.show_dir is None
56
+ cfg.show_dir = osp.join('./show_dirs',
57
+ osp.splitext(osp.basename(args.config))[0],
58
+ args.timestamp)
59
+
60
+ # ckpt path
61
+ if args.load_from is None:
62
+ raise RuntimeError('Please set model path!')
63
+ cfg.load_from = args.load_from
64
+
65
+ # load data info
66
+ data_info = {}
67
+ load_data_info('data_info', data_info=data_info)
68
+ cfg.mldb_info = data_info
69
+ # update check point info
70
+ reset_ckpt_path(cfg.model, data_info)
71
+
72
+ # create show dir
73
+ os.makedirs(osp.abspath(cfg.show_dir), exist_ok=True)
74
+
75
+ # init the logger before other steps
76
+ cfg.log_file = osp.join(cfg.show_dir, f'{args.timestamp}.log')
77
+ logger = setup_logger(cfg.log_file)
78
+
79
+ # log some basic info
80
+ logger.info(f'Config:\n{cfg.pretty_text}')
81
+
82
+ # init distributed env dirst, since logger depends on the dist info
83
+ if args.launcher == 'None':
84
+ cfg.distributed = False
85
+ else:
86
+ cfg.distributed = True
87
+ init_env(args.launcher, cfg)
88
+ logger.info(f'Distributed training: {cfg.distributed}')
89
+
90
+ # dump config
91
+ cfg.dump(osp.join(cfg.show_dir, osp.basename(args.config)))
92
+ test_data_path = args.test_data_path
93
+ if not os.path.isabs(test_data_path):
94
+ test_data_path = osp.join(CODE_SPACE, test_data_path)
95
+
96
+ if 'json' in test_data_path:
97
+ test_data = load_from_annos(test_data_path)
98
+ else:
99
+ test_data = load_data(args.test_data_path)
100
+
101
+ if not cfg.distributed:
102
+ main_worker(0, cfg, args.launcher, test_data)
103
+ else:
104
+ # distributed training
105
+ if args.launcher == 'ror':
106
+ local_rank = cfg.dist_params.local_rank
107
+ main_worker(local_rank, cfg, args.launcher, test_data)
108
+ else:
109
+ mp.spawn(main_worker, nprocs=cfg.dist_params.num_gpus_per_node, args=(cfg, args.launcher, test_data))
110
+
111
+ def main_worker(local_rank: int, cfg: dict, launcher: str, test_data: list):
112
+ if cfg.distributed:
113
+ cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank
114
+ cfg.dist_params.local_rank = local_rank
115
+
116
+ if launcher == 'ror':
117
+ init_torch_process_group(use_hvd=False)
118
+ else:
119
+ torch.cuda.set_device(local_rank)
120
+ default_timeout = timedelta(minutes=30)
121
+ dist.init_process_group(
122
+ backend=cfg.dist_params.backend,
123
+ init_method=cfg.dist_params.dist_url,
124
+ world_size=cfg.dist_params.world_size,
125
+ rank=cfg.dist_params.global_rank,
126
+ timeout=default_timeout)
127
+
128
+ logger = setup_logger(cfg.log_file)
129
+ # build model
130
+ model = get_configured_monodepth_model(cfg, )
131
+
132
+ # config distributed training
133
+ if cfg.distributed:
134
+ model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
135
+ device_ids=[local_rank],
136
+ output_device=local_rank,
137
+ find_unused_parameters=True)
138
+ else:
139
+ model = torch.nn.DataParallel(model).cuda()
140
+
141
+ # load ckpt
142
+ model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False)
143
+ model.eval()
144
+
145
+ do_scalecano_test_with_custom_data(
146
+ model,
147
+ cfg,
148
+ test_data,
149
+ logger,
150
+ cfg.distributed,
151
+ local_rank
152
+ )
153
+
154
+ if __name__ == '__main__':
155
+ args = parse_args()
156
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
157
+ args.timestamp = timestamp
158
+ main(args)
mono/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
mono/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (160 Bytes). View file
 
mono/utils/__pycache__/avg_meter.cpython-39.pyc ADDED
Binary file (10.1 kB). View file
 
mono/utils/__pycache__/comm.cpython-39.pyc ADDED
Binary file (9.72 kB). View file
 
mono/utils/__pycache__/custom_data.cpython-39.pyc ADDED
Binary file (1.21 kB). View file
 
mono/utils/__pycache__/do_test.cpython-39.pyc ADDED
Binary file (8.71 kB). View file
 
mono/utils/__pycache__/logger.cpython-39.pyc ADDED
Binary file (3.17 kB). View file
 
mono/utils/__pycache__/mldb.cpython-39.pyc ADDED
Binary file (1.18 kB). View file
 
mono/utils/__pycache__/running.cpython-39.pyc ADDED
Binary file (2.09 kB). View file
 
mono/utils/__pycache__/transform.cpython-39.pyc ADDED
Binary file (11.5 kB). View file
 
mono/utils/__pycache__/unproj_pcd.cpython-39.pyc ADDED
Binary file (2.61 kB). View file