Qihang Yu committed on
Commit a06fad0
1 Parent(s): f6d10ab

Add kMaX-DeepLab
Files changed (40)
  1. app.py +71 -4
  2. configs/coco/panoptic-segmentation/kmax_convnext_base.yaml +13 -0
  3. configs/coco/panoptic-segmentation/kmax_convnext_large.yaml +13 -0
  4. configs/coco/panoptic-segmentation/kmax_convnext_small.yaml +13 -0
  5. configs/coco/panoptic-segmentation/kmax_convnext_tiny.yaml +13 -0
  6. configs/coco/panoptic-segmentation/kmax_r50.yaml +91 -0
  7. convert-pretrained-model-to-d2.py +36 -0
  8. convert-tf-weights-to-d2.py +400 -0
  9. demo/demo.ipynb +213 -0
  10. demo/demo.py +156 -0
  11. demo/predictor.py +166 -0
  12. docs/clustering_view_of_mask_transformer.png +0 -0
  13. docs/kmax_decoder.png +0 -0
  14. kmax_deeplab/__init__.py +15 -0
  15. kmax_deeplab/config.py +96 -0
  16. kmax_deeplab/data/__init__.py +1 -0
  17. kmax_deeplab/data/dataset_mappers/__init__.py +0 -0
  18. kmax_deeplab/data/dataset_mappers/coco_panoptic_kmaxdeeplab_dataset_mapper.py +326 -0
  19. kmax_deeplab/data/datasets/__init__.py +3 -0
  20. kmax_deeplab/data/datasets/register_coco_panoptic_annos_semseg.py +182 -0
  21. kmax_deeplab/evaluation/__init__.py +0 -0
  22. kmax_deeplab/evaluation/instance_evaluation.py +107 -0
  23. kmax_deeplab/evaluation/panoptic_evaluation.py +269 -0
  24. kmax_deeplab/kmax_model.py +446 -0
  25. kmax_deeplab/modeling/__init__.py +4 -0
  26. kmax_deeplab/modeling/backbone/__init__.py +0 -0
  27. kmax_deeplab/modeling/backbone/convnext.py +210 -0
  28. kmax_deeplab/modeling/backbone/resnet.py +697 -0
  29. kmax_deeplab/modeling/criterion.py +432 -0
  30. kmax_deeplab/modeling/matcher.py +128 -0
  31. kmax_deeplab/modeling/meta_arch/__init__.py +0 -0
  32. kmax_deeplab/modeling/meta_arch/kmax_deeplab_head.py +88 -0
  33. kmax_deeplab/modeling/pixel_decoder/__init__.py +0 -0
  34. kmax_deeplab/modeling/pixel_decoder/kmax_pixel_decoder.py +370 -0
  35. kmax_deeplab/modeling/transformer_decoder/__init__.py +1 -0
  36. kmax_deeplab/modeling/transformer_decoder/kmax_transformer_decoder.py +453 -0
  37. pakages.txt +4 -0
  38. requirements.txt +34 -0
  39. train_net.py +266 -0
  40. train_net_utils.py +225 -0
app.py CHANGED
@@ -1,7 +1,74 @@
- import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import os
+ import sys
+
+ os.system("pip install gdown")
+
+ os.system("pip install imutils")
+
+ os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
+
+ os.system("pip install git+https://github.com/cocodataset/panopticapi.git")
+
+ import gradio as gr
+ # check pytorch installation:
+ import detectron2
+ from detectron2.utils.logger import setup_logger
+
+ # import some common libraries
+ import numpy as np
+ import cv2
+ import torch
+
+ # import some common detectron2 utilities
+ from detectron2 import model_zoo
+ from detectron2.engine import DefaultPredictor
+ from detectron2.config import get_cfg
+ from detectron2.utils.visualizer import Visualizer, ColorMode
+ from detectron2.data import MetadataCatalog
+ from detectron2.projects.deeplab import add_deeplab_config
+ coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic")
+
+ # import kMaXDeepLab project
+ from kmax_deeplab import add_kmax_deeplab_config
+
+ from PIL import Image
+ import imutils
+
+ cfg = get_cfg()
+ cfg.MODEL.DEVICE='cpu'
+ add_deeplab_config(cfg)
+ add_kmax_deeplab_config(cfg)
+ cfg.merge_from_file("configs/coco/panoptic-segmentation/kmax_convnext_large.yaml")
+ os.system("gdown 1b6rEnKw4PNTdqSdWpmb0P9dsvN0pkOiN")
+ cfg.MODEL.WEIGHTS = './kmax_convnext_large.pth'
+ cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON = True
+ cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON = True
+ cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON = True
+ predictor = DefaultPredictor(cfg)
+
+ os.system("wget https://i.imgur.com/Vj17K5z.jpg")
+
+ def inference(img):
+     im = cv2.imread(img)
+     im = imutils.resize(im, width=512)
+     outputs = predictor(im)
+     v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
+     panoptic_result = v.draw_panoptic_seg(outputs["panoptic_seg"][0].to("cpu"), outputs["panoptic_seg"][1]).get_image()
+     v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
+     instance_result = v.draw_instance_predictions(outputs["instances"].to("cpu")).get_image()
+     v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
+     semantic_result = v.draw_sem_seg(outputs["sem_seg"].argmax(0).to("cpu")).get_image()
+     return Image.fromarray(np.uint8(panoptic_result)).convert('RGB'),Image.fromarray(np.uint8(instance_result)).convert('RGB'),Image.fromarray(np.uint8(semantic_result)).convert('RGB')
+
+
+ title = "kMaX-DeepLab"
+ description = "Gradio demo for kMaX-DeepLab. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.01527' target='_blank'>kMaX-DeepLab</a> | <a href='https://github.com/google-research/deeplab2' target='_blank'>Github Repo</a></p>"
+
+ examples = [['Vj17K5z.jpg']]
+
+ gr.Interface(inference, inputs=gr.inputs.Image(type="filepath"), outputs=[gr.outputs.Image(label="Panoptic segmentation",type="pil"),gr.outputs.Image(label="instance segmentation",type="pil"),gr.outputs.Image(label="semantic segmentation",type="pil")], title=title,
+              description=description,
+              article=article,
+              examples=examples).launch(enable_queue=True,cache_examples=True)
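
The Gradio interface above wires `inference` into a web UI, but the same function can be exercised directly once app.py has run (so `predictor`, `inference`, and the downloaded `Vj17K5z.jpg` exist). A minimal sketch, not part of the commit:

    # Hypothetical smoke test of the demo's inference() helper (assumes app.py was executed first).
    panoptic_img, instance_img, semantic_img = inference("Vj17K5z.jpg")
    panoptic_img.save("panoptic.png")    # each return value is a PIL RGB image
    instance_img.save("instance.png")
    semantic_img.save("semantic.png")
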
configs/coco/panoptic-segmentation/kmax_convnext_base.yaml ADDED
@@ -0,0 +1,13 @@
+ _BASE_: kmax_r50.yaml
+ MODEL:
+   # backbone part.
+   BACKBONE:
+     NAME: "D2ConvNeXt"
+   WEIGHTS: "./convnext_base_22k_1k_384.pkl"
+   CONVNEXT:
+     IN_CHANNELS: 3
+     DEPTHS: [3, 3, 27, 3]
+     DIMS: [128, 256, 512, 1024]
+     # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_convnext_base_os32.textproto#L28
+     DROP_PATH_RATE: 0.5
+     OUT_INDICES: [0, 1, 2, 3]
configs/coco/panoptic-segmentation/kmax_convnext_large.yaml ADDED
@@ -0,0 +1,13 @@
+ _BASE_: kmax_r50.yaml
+ MODEL:
+   # backbone part.
+   BACKBONE:
+     NAME: "D2ConvNeXt"
+   WEIGHTS: "./convnext_large_22k_1k_384.pkl"
+   CONVNEXT:
+     IN_CHANNELS: 3
+     DEPTHS: [3, 3, 27, 3]
+     DIMS: [192, 384, 768, 1536]
+     # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_convnext_large_os32.textproto#L28
+     DROP_PATH_RATE: 0.6
+     OUT_INDICES: [0, 1, 2, 3]
configs/coco/panoptic-segmentation/kmax_convnext_small.yaml ADDED
@@ -0,0 +1,13 @@
+ _BASE_: kmax_r50.yaml
+ MODEL:
+   # backbone part.
+   BACKBONE:
+     NAME: "D2ConvNeXt"
+   WEIGHTS: "./convnext_small_22k_1k_384.pkl"
+   CONVNEXT:
+     IN_CHANNELS: 3
+     DEPTHS: [3, 3, 27, 3]
+     DIMS: [96, 192, 384, 768]
+     # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_convnext_small_os32.textproto#L28
+     DROP_PATH_RATE: 0.4
+     OUT_INDICES: [0, 1, 2, 3]
configs/coco/panoptic-segmentation/kmax_convnext_tiny.yaml ADDED
@@ -0,0 +1,13 @@
+ _BASE_: kmax_r50.yaml
+ MODEL:
+   # backbone part.
+   BACKBONE:
+     NAME: "D2ConvNeXt"
+   WEIGHTS: "./convnext_tiny_22k_1k_384.pkl"
+   CONVNEXT:
+     IN_CHANNELS: 3
+     DEPTHS: [3, 3, 9, 3]
+     DIMS: [96, 192, 384, 768]
+     # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_convnext_tiny_os32.textproto#L28
+     DROP_PATH_RATE: 0.3
+     OUT_INDICES: [0, 1, 2, 3]
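
The four ConvNeXt configs above differ only in the backbone fields they override on top of kmax_r50.yaml. For reference, a small Python summary with values copied from the files (not part of the commit):

    # ConvNeXt backbone variants used by the configs above.
    CONVNEXT_VARIANTS = {
        "tiny":  dict(depths=[3, 3, 9, 3],  dims=[96, 192, 384, 768],   drop_path_rate=0.3),
        "small": dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768],   drop_path_rate=0.4),
        "base":  dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], drop_path_rate=0.5),
        "large": dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], drop_path_rate=0.6),
    }
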
configs/coco/panoptic-segmentation/kmax_r50.yaml ADDED
@@ -0,0 +1,91 @@
+ MODEL:
+   # backbone part.
+   BACKBONE:
+     FREEZE_AT: 0
+     NAME: "custom_bn_build_resnet_backbone" # we customize the momentum and eps in syncbn, to align with tf implementation.
+   WEIGHTS: "../R-50.pkl"
+   PIXEL_MEAN: [127.5, 127.5, 127.5]
+   PIXEL_STD: [127.5, 127.5, 127.5]
+   RESNETS:
+     DEPTH: 50
+     STEM_TYPE: "basic" # not used
+     STEM_OUT_CHANNELS: 64
+     STRIDE_IN_1X1: False
+     OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+     NORM: "SyncBN"
+     RES5_MULTI_GRID: [1, 1, 1] # not used
+
+   # kmax part.
+   META_ARCHITECTURE: "kMaXDeepLab"
+   SEM_SEG_HEAD:
+     NAME: "kMaXDeepLabHead"
+     IGNORE_VALUE: 255
+     NUM_CLASSES: 133
+     LOSS_WEIGHT: 1.0
+
+   KMAX_DEEPLAB:
+     SAVE_VIS_NUM: 0
+     SHARE_FINAL_MATCHING: True
+     DEEP_SUPERVISION: True
+     NO_OBJECT_WEIGHT: 1e-5
+     CLASS_WEIGHT: 3.0
+     DICE_WEIGHT: 3.0
+     MASK_WEIGHT: 0.3
+     INSDIS_WEIGHT: 1.0
+     AUX_SEMANTIC_WEIGHT: 1.0
+
+     PIXEL_DEC:
+       NAME: "kMaXPixelDecoder"
+       IN_FEATURES: ["res2", "res3", "res4", "res5"]
+       DEC_LAYERS: [1, 5, 1, 1]
+       LAYER_TYPES: ["axial", "axial", "bottleneck", "bottleneck"]
+       DEC_CHANNELS: [512, 256, 128, 64]
+
+     TRANS_DEC:
+       NAME: "kMaXTransformerDecoder"
+       DEC_LAYERS: [2, 2, 2]
+       NUM_OBJECT_QUERIES: 128
+       IN_CHANNELS: [2048, 1024, 512] # [512 * 4, 256 * 4, 128 * 4]
+       DROP_PATH_PROB: 0.2
+
+     TEST:
+       SEMANTIC_ON: False
+       INSTANCE_ON: False # Save some time :)
+       PANOPTIC_ON: True
+       OBJECT_MASK_THRESHOLD: 0.4
+       CLASS_THRESHOLD_THING: 0.7
+       CLASS_THRESHOLD_STUFF: 0.5
+       REORDER_CLASS_WEIGHT: 1.0
+       REORDER_MASK_WEIGHT: 1.0
+       OVERLAP_THRESHOLD: 0.8
+
+ DATASETS:
+   TRAIN: ("coco_2017_train_panoptic",)
+   TEST: ("coco_2017_val_panoptic",)
+ SOLVER:
+   IMS_PER_BATCH: 64
+   BASE_LR: 0.0005
+   LR_SCHEDULER_NAME: "TF2WarmupPolyLR"
+   MAX_ITER: 150000
+   WARMUP_ITERS: 5000
+   WEIGHT_DECAY: 0.05
+   OPTIMIZER: "ADAMW"
+   BACKBONE_MULTIPLIER: 0.1
+   CLIP_GRADIENTS:
+     ENABLED: False
+   AMP:
+     ENABLED: True
+ INPUT:
+   IMAGE_SIZE: [1281, 1281]
+   MIN_SCALE: 0.2
+   MAX_SCALE: 2.0
+   FORMAT: "RGB"
+   DATASET_MAPPER_NAME: "coco_panoptic_lsj"
+   MIN_SIZE_TEST: 1281
+   MAX_SIZE_TEST: 1281
+ TEST:
+   EVAL_PERIOD: 5000
+ DATALOADER:
+   FILTER_EMPTY_ANNOTATIONS: True
+   NUM_WORKERS: 4
+ VERSION: 2
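
kmax_r50.yaml is the base config that the ConvNeXt variants inherit from via `_BASE_`. A minimal sketch of loading it in Python, mirroring what app.py and demo/demo.py in this commit do (the checkpoint path is a placeholder):

    from detectron2.config import get_cfg
    from detectron2.projects.deeplab import add_deeplab_config
    from kmax_deeplab import add_kmax_deeplab_config

    cfg = get_cfg()
    add_deeplab_config(cfg)       # register DeepLab keys used by the base config
    add_kmax_deeplab_config(cfg)  # register MODEL.KMAX_DEEPLAB / MODEL.CONVNEXT keys
    cfg.merge_from_file("configs/coco/panoptic-segmentation/kmax_r50.yaml")
    cfg.MODEL.WEIGHTS = "/path/to/kmax_r50.pth"  # placeholder checkpoint path
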
convert-pretrained-model-to-d2.py ADDED
@@ -0,0 +1,36 @@
+ #!/usr/bin/env python
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ import pickle as pkl
+ import sys
+
+ import torch
+
+ """
+ Usage:
+   # download pretrained swin model:
+   wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
+   # run the conversion
+   ./convert-pretrained-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl
+   # Then, use swin_tiny_patch4_window7_224.pkl with the following changes in config:
+ MODEL:
+   WEIGHTS: "/path/to/swin_tiny_patch4_window7_224.pkl"
+ INPUT:
+   FORMAT: "RGB"
+ """
+
+ if __name__ == "__main__":
+     input = sys.argv[1]
+
+     obj = torch.load(input, map_location="cpu")["model"]
+
+     # Clean unused convnext weight
+     if "norm.weight" in obj:
+         del obj["norm.weight"]
+     if "norm.bias" in obj:
+         del obj["norm.bias"]
+
+     res = {"model": obj, "__author__": "third_party", "matching_heuristics": True}
+
+     with open(sys.argv[2], "wb") as f:
+         pkl.dump(res, f)
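
The ConvNeXt configs above reference local files such as ./convnext_large_22k_1k_384.pkl, which this script would produce from an ImageNet-pretrained checkpoint, analogous to the Swin example in its docstring. A hedged sketch in the style of app.py (the download URL is an assumption based on the official ConvNeXt release naming; verify it before use):

    import os

    # Download an ImageNet-22k pre-trained ConvNeXt-Large checkpoint (URL is an assumption).
    os.system("wget https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth")
    # Convert to the detectron2 pickle referenced by kmax_convnext_large.yaml
    # (WEIGHTS: "./convnext_large_22k_1k_384.pkl").
    os.system("python convert-pretrained-model-to-d2.py convnext_large_22k_1k_384.pth convnext_large_22k_1k_384.pkl")
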
convert-tf-weights-to-d2.py ADDED
@@ -0,0 +1,400 @@
1
+ import tensorflow as tf
2
+ import pickle as pkl
3
+ import sys
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ def load_tf_weights(ckpt_path):
9
+ # https://stackoverflow.com/questions/40118062/how-to-read-weights-saved-in-tensorflow-checkpoint-file
10
+ from tensorflow.python.training import py_checkpoint_reader
11
+ reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)
12
+ state_dict = {}
13
+ for k in reader.get_variable_to_shape_map():
14
+ if '.OPTIMIZER_SLOT' in k or 'optimizer' in k or '_CHECKPOINTABLE_OBJECT_GRAPH' in k or 'save_counter' in k or 'global_step' in k:
15
+ continue
16
+ v = reader.get_tensor(k)
17
+ state_dict[k.replace('/.ATTRIBUTES/VARIABLE_VALUE', '')] = v
18
+ for k in sorted(state_dict.keys()):
19
+ print(k, state_dict[k].shape)
20
+ return state_dict
21
+
22
+ def map_bn(name1, name2):
23
+ res = {}
24
+ res[name1 + '/gamma'] = name2 + ".weight"
25
+ res[name1 + '/beta'] = name2 + ".bias"
26
+ res[name1 + '/moving_mean'] = name2 + ".running_mean"
27
+ res[name1 + '/moving_variance'] = name2 + ".running_var"
28
+ return res
29
+
30
+
31
+ def map_conv(name1, name2, dw=False, bias=False):
32
+ res = {}
33
+ if dw:
34
+ res[name1 + '/depthwise_kernel'] = name2 + ".weight"
35
+ else:
36
+ res[name1 + '/kernel'] = name2 + ".weight"
37
+ if bias:
38
+ res[name1 + '/bias'] = name2 + ".bias"
39
+ return res
40
+
41
+
42
+ def tf_2_torch_mapping_r50():
43
+ res = {}
44
+ res.update(map_conv('encoder/_stem/_conv', 'backbone.stem.conv1'))
45
+ res.update(map_bn('encoder/_stem/_batch_norm', 'backbone.stem.conv1.norm'))
46
+ block_num = {2: 3, 3: 4, 4: 6, 5: 3}
47
+ for stage_idx in range(2, 6):
48
+ for block_idx in range(1, block_num[stage_idx] + 1):
49
+ res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
50
+ f'backbone.res{stage_idx}.{block_idx-1}.conv1'))
51
+ res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
52
+ f'backbone.res{stage_idx}.{block_idx-1}.conv1.norm'))
53
+ res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
54
+ f'backbone.res{stage_idx}.{block_idx-1}.conv2'))
55
+ res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
56
+ f'backbone.res{stage_idx}.{block_idx-1}.conv2.norm'))
57
+ res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
58
+ f'backbone.res{stage_idx}.{block_idx-1}.conv3'))
59
+ res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
60
+ f'backbone.res{stage_idx}.{block_idx-1}.conv3.norm'))
61
+ res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_conv',
62
+ f'backbone.res{stage_idx}.{block_idx-1}.shortcut'))
63
+ res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
64
+ f'backbone.res{stage_idx}.{block_idx-1}.shortcut.norm'))
65
+ return res
66
+
67
+ def tf_2_torch_mapping_convnext():
68
+ res = {}
69
+ for i in range(4):
70
+ if i == 0:
71
+ res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-0',
72
+ f'backbone.downsample_layers.{i}.0', bias=True))
73
+ res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-1',
74
+ f'backbone.downsample_layers.{i}.1'))
75
+ else:
76
+ res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-1',
77
+ f'backbone.downsample_layers.{i}.1', bias=True))
78
+ res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-0',
79
+ f'backbone.downsample_layers.{i}.0'))
80
+
81
+ block_num = {0: 3, 1: 3, 2: 27, 3: 3}
82
+ for stage_idx in range(4):
83
+ for block_idx in range(block_num[stage_idx]):
84
+ res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/depthwise_conv',
85
+ f'backbone.stages.{stage_idx}.{block_idx}.dwconv', bias=True))
86
+ res.update(map_bn(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/norm',
87
+ f'backbone.stages.{stage_idx}.{block_idx}.norm'))
88
+ res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv1',
89
+ f'backbone.stages.{stage_idx}.{block_idx}.pwconv1', bias=True))
90
+ res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv2',
91
+ f'backbone.stages.{stage_idx}.{block_idx}.pwconv2', bias=True))
92
+ res[f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/layer_scale'] = f'backbone.stages.{stage_idx}.{block_idx}.gamma'
93
+
94
+ return res
95
+
96
+ def tf_2_torch_mapping_pixel_dec():
97
+ res = {}
98
+ for i in range(4):
99
+ res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))
100
+ res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))
101
+ res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))
102
+ res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))
103
+
104
+ for i in range(3):
105
+ res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_conv',
106
+ f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.conv'))
107
+ res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_batch_norm',
108
+ f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.norm'))
109
+ res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_conv',
110
+ f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.conv'))
111
+ res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_batch_norm',
112
+ f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.norm'))
113
+
114
+ num_blocks = {0: 1, 1:5, 2:1, 3:1}
115
+ for stage_idx in range(4):
116
+ for block_idx in range(1, 1+num_blocks[stage_idx]):
117
+ res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_conv',
118
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.conv'))
119
+ res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
120
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.norm'))
121
+ res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
122
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.conv'))
123
+ res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
124
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.norm'))
125
+ res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
126
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.conv'))
127
+ res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
128
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.norm'))
129
+ if stage_idx <= 1:
130
+ for attn in ['height', 'width']:
131
+ res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_qkv',
132
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_qkv'))
133
+ res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_retrieved_output',
134
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_retrieved_output'))
135
+ res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_similarity',
136
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_similarity'))
137
+ res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_key_rpe/embeddings'] = (
138
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._key_rpe._embeddings.weight')
139
+ res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_query_rpe/embeddings'] = (
140
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._query_rpe._embeddings.weight')
141
+ res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_value_rpe/embeddings'] = (
142
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._value_rpe._embeddings.weight')
143
+ res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/qkv_kernel'] = (
144
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis.qkv_transform.conv.weight')
145
+ else:
146
+ res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
147
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.conv'))
148
+ res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
149
+ f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.norm'))
150
+ return res
151
+
152
+
153
+ def tf_2_torch_mapping_predcitor(prefix_tf, prefix_torch):
154
+ res = {}
155
+ res.update(map_bn(prefix_tf + 'pixel_space_feature_batch_norm',
156
+ prefix_torch + '_pixel_space_head_last_convbn.norm'))
157
+ res[prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel'] = (
158
+ prefix_torch + '_pixel_space_head_conv0bnact.conv.weight'
159
+ )
160
+ res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_batch_norm',
161
+ prefix_torch + '_pixel_space_head_conv0bnact.norm'))
162
+ res.update(map_conv(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_conv',
163
+ prefix_torch + '_pixel_space_head_conv1bnact.conv'))
164
+ res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_batch_norm',
165
+ prefix_torch + '_pixel_space_head_conv1bnact.norm'))
166
+ res.update(map_conv(prefix_tf + 'pixel_space_head/final_conv',
167
+ prefix_torch + '_pixel_space_head_last_convbn.conv', bias=True))
168
+ res.update(map_bn(prefix_tf + 'pixel_space_mask_batch_norm',
169
+ prefix_torch + '_pixel_space_mask_batch_norm'))
170
+ res.update(map_conv(prefix_tf + 'transformer_class_head/_conv',
171
+ prefix_torch + '_transformer_class_head.conv', bias=True))
172
+ res.update(map_conv(prefix_tf + 'transformer_mask_head/_conv',
173
+ prefix_torch + '_transformer_mask_head.conv'))
174
+ res.update(map_bn(prefix_tf + 'transformer_mask_head/_batch_norm',
175
+ prefix_torch + '_transformer_mask_head.norm'))
176
+
177
+ return res
178
+
179
+
180
+ def tf_2_torch_mapping_trans_dec():
181
+ res = {}
182
+
183
+ res.update(map_bn('transformer_decoder/_class_embedding_projection/_batch_norm',
184
+ 'sem_seg_head.predictor._class_embedding_projection.norm'))
185
+ res.update(map_conv('transformer_decoder/_class_embedding_projection/_conv',
186
+ 'sem_seg_head.predictor._class_embedding_projection.conv'))
187
+ res.update(map_bn('transformer_decoder/_mask_embedding_projection/_batch_norm',
188
+ 'sem_seg_head.predictor._mask_embedding_projection.norm'))
189
+ res.update(map_conv('transformer_decoder/_mask_embedding_projection/_conv',
190
+ 'sem_seg_head.predictor._mask_embedding_projection.conv'))
191
+
192
+ res['transformer_decoder/cluster_centers'] = 'sem_seg_head.predictor._cluster_centers.weight'
193
+
194
+ res.update(tf_2_torch_mapping_predcitor(
195
+ prefix_tf = '',
196
+ prefix_torch = 'sem_seg_head.predictor._predcitor.'
197
+ ))
198
+ for kmax_idx in range(6):
199
+ res.update(tf_2_torch_mapping_predcitor(
200
+ prefix_tf = f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/_auxiliary_clustering_predictor/_',
201
+ prefix_torch = f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}._predcitor.'
202
+ ))
203
+ common_prefix_tf = f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/'
204
+ common_prefix_torch = f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}.'
205
+ res.update(map_bn(common_prefix_tf + '_kmeans_memory_batch_norm_retrieved_value',
206
+ common_prefix_torch + '_kmeans_query_batch_norm_retrieved_value'))
207
+ res.update(map_bn(common_prefix_tf + '_kmeans_memory_conv3_bn/_batch_norm',
208
+ common_prefix_torch + '_kmeans_query_conv3_bn.norm'))
209
+ res.update(map_conv(common_prefix_tf + '_kmeans_memory_conv3_bn/_conv',
210
+ common_prefix_torch + '_kmeans_query_conv3_bn.conv'))
211
+ res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_retrieved_value',
212
+ common_prefix_torch + '_query_self_attention._batch_norm_retrieved_value'))
213
+ res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_similarity',
214
+ common_prefix_torch + '_query_self_attention._batch_norm_similarity'))
215
+
216
+ res.update(map_bn(common_prefix_tf + '_memory_conv1_bn_act/_batch_norm',
217
+ common_prefix_torch + '_query_conv1_bn_act.norm'))
218
+ res.update(map_conv(common_prefix_tf + '_memory_conv1_bn_act/_conv',
219
+ common_prefix_torch + '_query_conv1_bn_act.conv'))
220
+
221
+ res.update(map_bn(common_prefix_tf + '_memory_conv3_bn/_batch_norm',
222
+ common_prefix_torch + '_query_conv3_bn.norm'))
223
+ res.update(map_conv(common_prefix_tf + '_memory_conv3_bn/_conv',
224
+ common_prefix_torch + '_query_conv3_bn.conv'))
225
+
226
+ res.update(map_bn(common_prefix_tf + '_memory_ffn_conv1_bn_act/_batch_norm',
227
+ common_prefix_torch + '_query_ffn_conv1_bn_act.norm'))
228
+ res.update(map_conv(common_prefix_tf + '_memory_ffn_conv1_bn_act/_conv',
229
+ common_prefix_torch + '_query_ffn_conv1_bn_act.conv'))
230
+
231
+ res.update(map_bn(common_prefix_tf + '_memory_ffn_conv2_bn/_batch_norm',
232
+ common_prefix_torch + '_query_ffn_conv2_bn.norm'))
233
+ res.update(map_conv(common_prefix_tf + '_memory_ffn_conv2_bn/_conv',
234
+ common_prefix_torch + '_query_ffn_conv2_bn.conv'))
235
+
236
+ res.update(map_bn(common_prefix_tf + '_memory_qkv_conv_bn/_batch_norm',
237
+ common_prefix_torch + '_query_qkv_conv_bn.norm'))
238
+ res.update(map_conv(common_prefix_tf + '_memory_qkv_conv_bn/_conv',
239
+ common_prefix_torch + '_query_qkv_conv_bn.conv'))
240
+
241
+ res.update(map_bn(common_prefix_tf + '_pixel_conv1_bn_act/_batch_norm',
242
+ common_prefix_torch + '_pixel_conv1_bn_act.norm'))
243
+ res.update(map_conv(common_prefix_tf + '_pixel_conv1_bn_act/_conv',
244
+ common_prefix_torch + '_pixel_conv1_bn_act.conv'))
245
+
246
+ res.update(map_bn(common_prefix_tf + '_pixel_v_conv_bn/_batch_norm',
247
+ common_prefix_torch + '_pixel_v_conv_bn.norm'))
248
+ res.update(map_conv(common_prefix_tf + '_pixel_v_conv_bn/_conv',
249
+ common_prefix_torch + '_pixel_v_conv_bn.conv'))
250
+
251
+ return res
252
+
253
+
254
+ def tf_2_torch_mapping_aux_semanic_dec():
255
+ res = {}
256
+ res.update(map_conv('semantic_decoder/_aspp/_conv_bn_act/_conv',
257
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.conv'))
258
+ res.update(map_bn('semantic_decoder/_aspp/_conv_bn_act/_batch_norm',
259
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.norm'))
260
+
261
+ res.update(map_conv('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_conv',
262
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.conv'))
263
+ res.update(map_bn('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_batch_norm',
264
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.norm'))
265
+
266
+ res.update(map_conv('semantic_decoder/_aspp/_proj_conv_bn_act/_conv',
267
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.conv'))
268
+ res.update(map_bn('semantic_decoder/_aspp/_proj_conv_bn_act/_batch_norm',
269
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.norm'))
270
+ for i in range(1, 4):
271
+ res.update(map_conv(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_conv',
272
+ f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.conv'))
273
+ res.update(map_bn(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_batch_norm',
274
+ f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.norm'))
275
+
276
+ res.update({
277
+ 'semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
278
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.conv.weight'})
279
+ res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_batch_norm',
280
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.norm'))
281
+ res.update({
282
+ 'semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_conv/kernel':
283
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.conv.weight'})
284
+ res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_batch_norm',
285
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.norm'))
286
+
287
+ res.update({
288
+ 'semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
289
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.conv.weight'})
290
+ res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_batch_norm',
291
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.norm'))
292
+ res.update({
293
+ 'semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_conv/kernel':
294
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.conv.weight'})
295
+ res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_batch_norm',
296
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.norm'))
297
+
298
+ res.update({
299
+ 'semantic_decoder/_low_level_conv1/_conv/kernel':
300
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.conv.weight'})
301
+ res.update(map_bn('semantic_decoder/_low_level_conv1/_batch_norm',
302
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.norm'))
303
+ res.update({
304
+ 'semantic_decoder/_low_level_conv2/_conv/kernel':
305
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.conv.weight'})
306
+ res.update(map_bn('semantic_decoder/_low_level_conv2/_batch_norm',
307
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.norm'))
308
+
309
+
310
+ res.update({
311
+ 'semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
312
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight'})
313
+ res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_batch_norm',
314
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.norm'))
315
+ res.update({
316
+ 'semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_conv/kernel':
317
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.conv.weight'})
318
+ res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_batch_norm',
319
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.norm'))
320
+
321
+ res.update({
322
+ 'semantic_last_layer/kernel':
323
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.weight'})
324
+ res.update({
325
+ 'semantic_last_layer/bias':
326
+ 'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.bias'})
327
+ return res
328
+
329
+
330
+ # python3 convert-tf-weights-to-d2.py kmax_resnet50_coco_train/ckpt-150000 tf_kmax_r50.pkl
331
+
332
+ if __name__ == "__main__":
333
+ input = sys.argv[1]
334
+
335
+ state_dict = load_tf_weights(input)
336
+ #exit()
337
+
338
+ state_dict_torch = {}
339
+
340
+ mapping_key = {}
341
+ if 'resnet50' in input:
342
+ mapping_key.update(tf_2_torch_mapping_r50())
343
+ elif 'convnext' in input:
344
+ mapping_key.update(tf_2_torch_mapping_convnext())
345
+ mapping_key.update(tf_2_torch_mapping_pixel_dec())
346
+ mapping_key.update(tf_2_torch_mapping_trans_dec())
347
+
348
+ mapping_key.update(tf_2_torch_mapping_aux_semanic_dec())
349
+
350
+ for k in state_dict.keys():
351
+ value = state_dict[k]
352
+ k2 = mapping_key[k]
353
+ rank = len(value.shape)
354
+
355
+ if '_batch_norm_retrieved_output' in k2 or '_batch_norm_similarity' in k2 or '_batch_norm_retrieved_value' in k2:
356
+ value = np.reshape(value, [-1])
357
+ elif 'qkv_transform.conv.weight' in k2:
358
+ # (512, 1024) -> (1024, 512, 1)
359
+ value = np.transpose(value, (1, 0))[:, :, None]
360
+ elif '_cluster_centers.weight' in k2:
361
+ # (1, 128, 256) -> (256, 128)
362
+ value = np.transpose(value[0], (1, 0))
363
+ elif '_pixel_conv1_bn_act.conv.weight' in k2:
364
+ # (1, 512, 256) -> (256, 512, 1, 1)
365
+ value = np.transpose(value, (2, 1, 0))[:, :, :, None]
366
+ elif '_pixel_v_conv_bn.conv.weight' in k2:
367
+ # (1, 256, 256) -> (256, 256, 1, 1)
368
+ value = np.transpose(value, (2, 1, 0))[:, :, :, None]
369
+ elif '_pixel_space_head_conv0bnact.conv.weight' in k2:
370
+ # (5, 5, 256, 1) -> (256, 1, 5, 5)
371
+ value = np.transpose(value, (2, 3, 0, 1))
372
+ elif '/layer_scale' in k:
373
+ value = np.reshape(value, [-1])
374
+ elif 'pwconv1.weight' in k2 or 'pwconv2.weight' in k2:
375
+ # (128, 512) -> (512, 128)
376
+ value = np.transpose(value, (1, 0))
377
+ elif ('_low_level_fusion_os4_conv0_bn_act.conv.weight' in k2
378
+ or '_low_level_fusion_os8_conv0_bn_act.conv.weight' in k2
379
+ or 'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight' in k2):
380
+ value = np.transpose(value, (2, 3, 0, 1))
381
+ else:
382
+ if rank == 1: # bias, norm etc
383
+ pass
384
+ elif rank == 2: # _query_rpe
385
+ pass
386
+ elif rank == 3: # conv 1d kernel, etc
387
+ value = np.transpose(value, (2, 1, 0))
388
+ elif rank == 4: # conv 2d kernel, etc
389
+ value = np.transpose(value, (3, 2, 0, 1))
390
+
391
+ state_dict_torch[k2] = value
392
+
393
+ res = {"model": state_dict_torch, "__author__": "third_party", "matching_heuristics": True}
394
+
395
+ with open(sys.argv[2], "wb") as f:
396
+ pkl.dump(res, f)
397
+
398
+
399
+ # r50: 52.85 -> 52.71 w/ eps 1e-3
400
+ # convnext-base: 56.85 -> 56.97 w/ eps 1e-3
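
Most of the special cases in the conversion loop above are axis re-orderings between TensorFlow and PyTorch layout conventions; the default rank-4 branch handles ordinary 2D convolutions. A minimal illustration with toy shapes (not tied to any particular layer of the model):

    import numpy as np

    # TF stores 2D conv kernels as (H, W, C_in, C_out); PyTorch nn.Conv2d expects (C_out, C_in, H, W).
    tf_kernel = np.zeros((3, 3, 64, 256))
    torch_kernel = np.transpose(tf_kernel, (3, 2, 0, 1))  # same permutation as the rank-4 case above
    assert torch_kernel.shape == (256, 64, 3, 3)
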
demo/demo.ipynb ADDED
@@ -0,0 +1,213 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "# kMaX-DeepLab Demo\n",
9
+ "This notebook is modified by Qihang Yu, with reference from [Mask2Former's script](https://colab.research.google.com/drive/1uIWE5KbGFSjrxey2aRd5pWkKNY1_SaNq)"
10
+ ]
11
+ },
12
+ {
13
+ "attachments": {},
14
+ "cell_type": "markdown",
15
+ "metadata": {},
16
+ "source": [
17
+ "# Install detectron2"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "# Install detectron2\n",
27
+ "import torch\n",
28
+ "TORCH_VERSION = \".\".join(torch.__version__.split(\".\")[:2])\n",
29
+ "CUDA_VERSION = torch.__version__.split(\"+\")[-1]\n",
30
+ "print(\"torch: \", TORCH_VERSION, \"; cuda: \", CUDA_VERSION)\n",
31
+ "# Install detectron2 that matches the above pytorch version\n",
32
+ "# See https://detectron2.readthedocs.io/tutorials/install.html for instructions\n",
33
+ "!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/$CUDA_VERSION/torch$TORCH_VERSION/index.html"
34
+ ]
35
+ },
36
+ {
37
+ "attachments": {},
38
+ "cell_type": "markdown",
39
+ "metadata": {},
40
+ "source": [
41
+ "# Install kMaX-DeepLab"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "# clone and install kMaX-DeepLab\n",
51
+ "!git clone https://github.com/yucornetto/kmaxdeeplab_detectron2.git\n",
52
+ "%cd kmaxdeeplab_detectron2\n",
53
+ "!pip install -U opencv-python\n",
54
+ "!pip install git+https://github.com/cocodataset/panopticapi.git\n",
55
+ "!pip install -r requirements.txt"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "# You may need to restart your runtime prior to this, to let your installation take effect\n",
65
+ "%cd /content/kmaxdeeplab_detectron2\n",
66
+ "# Some basic setup:\n",
67
+ "# Setup detectron2 logger\n",
68
+ "import detectron2\n",
69
+ "from detectron2.utils.logger import setup_logger\n",
70
+ "setup_logger()\n",
71
+ "setup_logger(name=\"kmax_deeplab\")\n",
72
+ "\n",
73
+ "# import some common libraries\n",
74
+ "import numpy as np\n",
75
+ "import cv2\n",
76
+ "import torch\n",
77
+ "from google.colab.patches import cv2_imshow\n",
78
+ "\n",
79
+ "# import some common detectron2 utilities\n",
80
+ "from detectron2 import model_zoo\n",
81
+ "from detectron2.engine import DefaultPredictor\n",
82
+ "from detectron2.config import get_cfg\n",
83
+ "from detectron2.utils.visualizer import Visualizer, ColorMode\n",
84
+ "from detectron2.data import MetadataCatalog\n",
85
+ "from detectron2.projects.deeplab import add_deeplab_config\n",
86
+ "coco_metadata = MetadataCatalog.get(\"coco_2017_val_panoptic\")\n",
87
+ "\n",
88
+ "# import Mask2Former project\n",
89
+ "from kmax_deeplab import add_kmax_deeplab_config"
90
+ ]
91
+ },
92
+ {
93
+ "attachments": {},
94
+ "cell_type": "markdown",
95
+ "metadata": {},
96
+ "source": [
97
+ "# Run a pre-trained Mask2Former model\n",
98
+ "We first download an image from the COCO dataset:"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "metadata": {},
105
+ "outputs": [],
106
+ "source": [
107
+ "!wget http://images.cocodataset.org/val2017/000000005477.jpg -q -O input.jpg\n",
108
+ "im = cv2.imread(\"./input.jpg\")\n",
109
+ "cv2_imshow(im)"
110
+ ]
111
+ },
112
+ {
113
+ "attachments": {},
114
+ "cell_type": "markdown",
115
+ "metadata": {},
116
+ "source": [
117
+ "Then, we create a detectron2 config and a detectron2 `DefaultPredictor` to run inference on this image."
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {},
124
+ "outputs": [],
125
+ "source": [
126
+ "cfg = get_cfg()\n",
127
+ "add_deeplab_config(cfg)\n",
128
+ "add_kmax_deeplab_config(cfg)\n",
129
+ "cfg.merge_from_file(\"configs/coco/panoptic-segmentation/kmax_convnext_large.yaml\")\n",
130
+ "cfg.MODEL.WEIGHTS = 'https://drive.google.com/uc?id=1b6rEnKw4PNTdqSdWpmb0P9dsvN0pkOiN&export=download'\n",
131
+ "cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON = True\n",
132
+ "cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON = True\n",
133
+ "cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON = True\n",
134
+ "predictor = DefaultPredictor(cfg)\n",
135
+ "outputs = predictor(im)"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "# Show panoptic/instance/semantic predictions: \n",
145
+ "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
146
+ "panoptic_result = v.draw_panoptic_seg(outputs[\"panoptic_seg\"][0].to(\"cpu\"), outputs[\"panoptic_seg\"][1]).get_image()\n",
147
+ "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
148
+ "instance_result = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\")).get_image()\n",
149
+ "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
150
+ "semantic_result = v.draw_sem_seg(outputs[\"sem_seg\"].argmax(0).to(\"cpu\")).get_image()\n",
151
+ "print(\"Panoptic segmentation (top), instance segmentation (middle), semantic segmentation (bottom)\")\n",
152
+ "cv2_imshow(np.concatenate((panoptic_result, instance_result, semantic_result), axis=0)[:, :, ::-1])"
153
+ ]
154
+ },
155
+ {
156
+ "attachments": {},
157
+ "cell_type": "markdown",
158
+ "metadata": {},
159
+ "source": [
160
+ "Let's try an image not from COCO as well:"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "# Download a sample image and display. Replace path here to try your own images!\n",
170
+ "!wget https://web.eecs.umich.edu/~fouhey/fun/desk/desk.jpg\n",
171
+ "im = cv2.imread(\"./desk.jpg\")\n",
172
+ "cv2_imshow(im)"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": [
181
+ "outputs = predictor(im)\n",
182
+ "# Show panoptic/instance/semantic predictions: \n",
183
+ "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
184
+ "panoptic_result = v.draw_panoptic_seg(outputs[\"panoptic_seg\"][0].to(\"cpu\"), outputs[\"panoptic_seg\"][1]).get_image()\n",
185
+ "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
186
+ "instance_result = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\")).get_image()\n",
187
+ "v = Visualizer(im[:, :, ::-1], coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)\n",
188
+ "semantic_result = v.draw_sem_seg(outputs[\"sem_seg\"].argmax(0).to(\"cpu\")).get_image()\n",
189
+ "print(\"Panoptic segmentation (top), instance segmentation (middle), semantic segmentation (bottom)\")\n",
190
+ "cv2_imshow(np.concatenate((panoptic_result, instance_result, semantic_result), axis=0)[:, :, ::-1])"
191
+ ]
192
+ }
193
+ ],
194
+ "metadata": {
195
+ "kernelspec": {
196
+ "display_name": "Python 3",
197
+ "language": "python",
198
+ "name": "python3"
199
+ },
200
+ "language_info": {
201
+ "name": "python",
202
+ "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
203
+ },
204
+ "orig_nbformat": 4,
205
+ "vscode": {
206
+ "interpreter": {
207
+ "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
208
+ }
209
+ }
210
+ },
211
+ "nbformat": 4,
212
+ "nbformat_minor": 2
213
+ }
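
The notebook visualizes the predictions directly; for completeness, a hypothetical snippet showing how the raw `outputs` dict returned by the predictor could be inspected (field names follow the standard detectron2 panoptic output format):

    # Hypothetical inspection of the predictor outputs; not part of the notebook.
    panoptic_seg, segments_info = outputs["panoptic_seg"]
    print("panoptic map shape:", tuple(panoptic_seg.shape))
    print("number of predicted segments:", len(segments_info))
    print("category ids present:", sorted({s["category_id"] for s in segments_info}))
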
demo/demo.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
3
+ import argparse
4
+ import glob
5
+ import multiprocessing as mp
6
+ import os
7
+
8
+ # fmt: off
9
+ import sys
10
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
11
+ # fmt: on
12
+
13
+ import tempfile
14
+ import time
15
+ import warnings
16
+
17
+ import cv2
18
+ import numpy as np
19
+ import tqdm
20
+
21
+ from detectron2.config import get_cfg
22
+ from detectron2.data.detection_utils import read_image
23
+ from detectron2.projects.deeplab import add_deeplab_config
24
+ from detectron2.utils.logger import setup_logger
25
+
26
+ from kmax_deeplab import add_kmax_deeplab_config
27
+ from predictor import VisualizationDemo
28
+
29
+
30
+ # constants
31
+ WINDOW_NAME = "kmaxdeeplab demo"
32
+
33
+
34
+ def setup_cfg(args):
35
+ # load config from file and command-line arguments
36
+ cfg = get_cfg()
37
+ add_deeplab_config(cfg)
38
+ add_kmax_deeplab_config(cfg)
39
+ cfg.merge_from_file(args.config_file)
40
+ cfg.merge_from_list(args.opts)
41
+ cfg.freeze()
42
+ return cfg
43
+
44
+
45
+ def get_parser():
46
+ parser = argparse.ArgumentParser(description="kmaxdeeplab demo for builtin configs")
47
+ parser.add_argument(
48
+ "--config-file",
49
+ default="configs/coco/panoptic-segmentation/kmax_convnext_large.yaml",
50
+ metavar="FILE",
51
+ help="path to config file",
52
+ )
53
+ parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
54
+ parser.add_argument("--video-input", help="Path to video file.")
55
+ parser.add_argument(
56
+ "--input",
57
+ nargs="+",
58
+ help="A list of space separated input images; "
59
+ "or a single glob pattern such as 'directory/*.jpg'",
60
+ )
61
+ parser.add_argument(
62
+ "--output",
63
+ help="A file or directory to save output visualizations. "
64
+ "If not given, will show output in an OpenCV window.",
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--confidence-threshold",
69
+ type=float,
70
+ default=0.5,
71
+ help="Minimum score for instance predictions to be shown",
72
+ )
73
+ parser.add_argument(
74
+ "--opts",
75
+ help="Modify config options using the command-line 'KEY VALUE' pairs",
76
+ default=[],
77
+ nargs=argparse.REMAINDER,
78
+ )
79
+ return parser
80
+
81
+
82
+ def test_opencv_video_format(codec, file_ext):
83
+ with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:
84
+ filename = os.path.join(dir, "test_file" + file_ext)
85
+ writer = cv2.VideoWriter(
86
+ filename=filename,
87
+ fourcc=cv2.VideoWriter_fourcc(*codec),
88
+ fps=float(30),
89
+ frameSize=(10, 10),
90
+ isColor=True,
91
+ )
92
+ [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]
93
+ writer.release()
94
+ if os.path.isfile(filename):
95
+ return True
96
+ return False
97
+
98
+
99
+ if __name__ == "__main__":
100
+ mp.set_start_method("spawn", force=True)
101
+ args = get_parser().parse_args()
102
+ setup_logger(name="fvcore")
103
+ logger = setup_logger()
104
+ logger.info("Arguments: " + str(args))
105
+
106
+ cfg = setup_cfg(args)
107
+
108
+ demo = VisualizationDemo(cfg)
109
+
110
+ if args.input:
111
+ if len(args.input) == 1:
112
+ args.input = glob.glob(os.path.expanduser(args.input[0]))
113
+ assert args.input, "The input path(s) was not found"
114
+ for path in tqdm.tqdm(args.input, disable=not args.output):
115
+ # use PIL, to be consistent with evaluation
116
+ img = read_image(path, format="BGR")
117
+ start_time = time.time()
118
+ predictions, visualized_output = demo.run_on_image(img)
119
+ logger.info(
120
+ "{}: {} in {:.2f}s".format(
121
+ path,
122
+ "detected {} instances".format(len(predictions["instances"]))
123
+ if "instances" in predictions
124
+ else "finished",
125
+ time.time() - start_time,
126
+ )
127
+ )
128
+
129
+ ## Below are raw outputs.
130
+ # panoptic_seg, segments_info = predictions["panoptic_seg"]
131
+ # print(panoptic_seg.shape, segments_info)
132
+
133
+ if args.output:
134
+ if os.path.isdir(args.output):
135
+ assert os.path.isdir(args.output), args.output
136
+ out_filename = os.path.join(args.output, os.path.basename(path))
137
+ else:
138
+ assert len(args.input) == 1, "Please specify a directory with args.output"
139
+ out_filename = args.output
140
+ visualized_output.save(out_filename)
141
+ else:
142
+ cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
143
+ cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
144
+ if cv2.waitKey(0) == 27:
145
+ break # esc to quit
146
+ elif args.webcam:
147
+ assert args.input is None, "Cannot have both --input and --webcam!"
148
+ assert args.output is None, "output not yet supported with --webcam!"
149
+ cam = cv2.VideoCapture(0)
150
+ for vis in tqdm.tqdm(demo.run_on_video(cam)):
151
+ cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
152
+ cv2.imshow(WINDOW_NAME, vis)
153
+ if cv2.waitKey(1) == 27:
154
+ break # esc to quit
155
+ cam.release()
156
+ cv2.destroyAllWindows()
demo/predictor.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py
3
+ import atexit
4
+ import bisect
5
+ import multiprocessing as mp
6
+ from collections import deque
7
+
8
+ import cv2
9
+ import torch
10
+
11
+ from detectron2.data import MetadataCatalog
12
+ from detectron2.engine.defaults import DefaultPredictor
13
+ from detectron2.utils.video_visualizer import VideoVisualizer
14
+ from detectron2.utils.visualizer import ColorMode, Visualizer
15
+
16
+
17
+ class VisualizationDemo(object):
18
+ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
19
+ """
20
+ Args:
21
+ cfg (CfgNode):
22
+ instance_mode (ColorMode):
23
+ parallel (bool): whether to run the model in different processes from visualization.
24
+ Useful since the visualization logic can be slow.
25
+ """
26
+ self.metadata = MetadataCatalog.get(
27
+ cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
28
+ )
29
+ self.cpu_device = torch.device("cpu")
30
+ self.instance_mode = instance_mode
31
+
32
+ self.parallel = parallel
33
+ if parallel:
34
+ num_gpu = torch.cuda.device_count()
35
+ self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
36
+ else:
37
+ self.predictor = DefaultPredictor(cfg)
38
+
39
+ def run_on_image(self, image):
40
+ """
41
+ Args:
42
+ image (np.ndarray): an image of shape (H, W, C) (in BGR order).
43
+ This is the format used by OpenCV.
44
+ Returns:
45
+ predictions (dict): the output of the model.
46
+ vis_output (VisImage): the visualized image output.
47
+ """
48
+ vis_output = None
49
+ predictions = self.predictor(image)
50
+ # Convert image from OpenCV BGR format to Matplotlib RGB format.
51
+ image = image[:, :, ::-1]
52
+ visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
53
+ if "panoptic_seg" in predictions:
54
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
55
+ vis_output = visualizer.draw_panoptic_seg_predictions(
56
+ panoptic_seg.to(self.cpu_device), segments_info
57
+ )
58
+ else:
59
+ if "sem_seg" in predictions:
60
+ vis_output = visualizer.draw_sem_seg(
61
+ predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
62
+ )
63
+ if "instances" in predictions:
64
+ instances = predictions["instances"].to(self.cpu_device)
65
+ vis_output = visualizer.draw_instance_predictions(predictions=instances)
66
+
67
+ return predictions, vis_output
68
+
69
+ def _frame_from_video(self, video):
70
+ while video.isOpened():
71
+ success, frame = video.read()
72
+ if success:
73
+ yield frame
74
+ else:
75
+ break
76
+
77
+
78
+ class AsyncPredictor:
79
+ """
80
+ A predictor that runs the model asynchronously, possibly on >1 GPUs.
81
+ Because rendering the visualization takes considerably amount of time,
82
+ this helps improve throughput a little bit when rendering videos.
83
+ """
84
+
85
+ class _StopToken:
86
+ pass
87
+
88
+ class _PredictWorker(mp.Process):
89
+ def __init__(self, cfg, task_queue, result_queue):
90
+ self.cfg = cfg
91
+ self.task_queue = task_queue
92
+ self.result_queue = result_queue
93
+ super().__init__()
94
+
95
+ def run(self):
96
+ predictor = DefaultPredictor(self.cfg)
97
+
98
+ while True:
99
+ task = self.task_queue.get()
100
+ if isinstance(task, AsyncPredictor._StopToken):
101
+ break
102
+ idx, data = task
103
+ result = predictor(data)
104
+ self.result_queue.put((idx, result))
105
+
106
+ def __init__(self, cfg, num_gpus: int = 1):
107
+ """
108
+ Args:
109
+ cfg (CfgNode):
110
+ num_gpus (int): if 0, will run on CPU
111
+ """
112
+ num_workers = max(num_gpus, 1)
113
+ self.task_queue = mp.Queue(maxsize=num_workers * 3)
114
+ self.result_queue = mp.Queue(maxsize=num_workers * 3)
115
+ self.procs = []
116
+ for gpuid in range(max(num_gpus, 1)):
117
+ cfg = cfg.clone()
118
+ cfg.defrost()
119
+ cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
120
+ self.procs.append(
121
+ AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
122
+ )
123
+
124
+ self.put_idx = 0
125
+ self.get_idx = 0
126
+ self.result_rank = []
127
+ self.result_data = []
128
+
129
+ for p in self.procs:
130
+ p.start()
131
+ atexit.register(self.shutdown)
132
+
133
+ def put(self, image):
134
+ self.put_idx += 1
135
+ self.task_queue.put((self.put_idx, image))
136
+
137
+ def get(self):
138
+ self.get_idx += 1 # the index needed for this request
139
+ if len(self.result_rank) and self.result_rank[0] == self.get_idx:
140
+ res = self.result_data[0]
141
+ del self.result_data[0], self.result_rank[0]
142
+ return res
143
+
144
+ while True:
145
+ # make sure the results are returned in the correct order
146
+ idx, res = self.result_queue.get()
147
+ if idx == self.get_idx:
148
+ return res
149
+ insert = bisect.bisect(self.result_rank, idx)
150
+ self.result_rank.insert(insert, idx)
151
+ self.result_data.insert(insert, res)
152
+
153
+ def __len__(self):
154
+ return self.put_idx - self.get_idx
155
+
156
+ def __call__(self, image):
157
+ self.put(image)
158
+ return self.get()
159
+
160
+ def shutdown(self):
161
+ for _ in self.procs:
162
+ self.task_queue.put(AsyncPredictor._StopToken())
163
+
164
+ @property
165
+ def default_buffer_size(self):
166
+ return len(self.procs) * 5
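
Beyond demo/demo.py's command line, VisualizationDemo can also be driven programmatically. A hedged sketch, assuming it is run from inside the demo/ directory (as demo.py arranges via sys.path) and that `cfg` has been built as in setup_cfg:

    # Hypothetical programmatic use of VisualizationDemo; paths are placeholders.
    import cv2
    from predictor import VisualizationDemo

    demo = VisualizationDemo(cfg)        # cfg built as in demo/demo.py's setup_cfg
    img = cv2.imread("input.jpg")        # BGR image, as run_on_image expects
    predictions, vis_output = demo.run_on_image(img)
    vis_output.save("visualized.jpg")
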
docs/clustering_view_of_mask_transformer.png ADDED
docs/kmax_decoder.png ADDED
kmax_deeplab/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from . import data # register all new datasets
+ from . import modeling
+
+ # config
+ from .config import add_kmax_deeplab_config
+
+ # dataset loading
+ from .data.dataset_mappers.coco_panoptic_kmaxdeeplab_dataset_mapper import COCOPanoptickMaXDeepLabDatasetMapper
+
+
+ # models
+ from .kmax_model import kMaXDeepLab
+
+ # evaluation
+ from .evaluation.instance_evaluation import InstanceSegEvaluator
kmax_deeplab/config.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from detectron2.config import CfgNode as CN
3
+
4
+
5
+ def add_kmax_deeplab_config(cfg):
6
+ """
7
+ Add config for KMAX_DEEPLAB.
8
+ """
9
+ # NOTE: configs from original maskformer
10
+ # data config
11
+ # select the dataset mapper
12
+ cfg.INPUT.DATASET_MAPPER_NAME = "coco_panoptic_kmaxdeeplab"
13
+ # Color augmentation
14
+ # Pad image and segmentation GT in dataset mapper.
15
+ cfg.INPUT.SIZE_DIVISIBILITY = -1
16
+
17
+ # solver config
18
+ # weight decay on embedding
19
+ cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.05
20
+ # optimizer
21
+ cfg.SOLVER.OPTIMIZER = "ADAMW"
22
+ cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
23
+
24
+ # kMaX-DeepLab model config
25
+ cfg.MODEL.KMAX_DEEPLAB = CN()
26
+
27
+ # whether to share matching results
28
+ cfg.MODEL.KMAX_DEEPLAB.SHARE_FINAL_MATCHING = True
29
+
30
+ # vis
31
+ cfg.MODEL.KMAX_DEEPLAB.SAVE_VIS_NUM = 0
32
+
33
+ # loss
34
+ cfg.MODEL.KMAX_DEEPLAB.DEEP_SUPERVISION = True
35
+ cfg.MODEL.KMAX_DEEPLAB.SKIP_CONN_INIT_VALUE = 0.0
36
+ cfg.MODEL.KMAX_DEEPLAB.NO_OBJECT_WEIGHT = 1e-5
37
+ cfg.MODEL.KMAX_DEEPLAB.CLASS_WEIGHT = 3.0
38
+ cfg.MODEL.KMAX_DEEPLAB.DICE_WEIGHT = 3.0
39
+ cfg.MODEL.KMAX_DEEPLAB.MASK_WEIGHT = 0.3
40
+ cfg.MODEL.KMAX_DEEPLAB.INSDIS_WEIGHT = 1.0
41
+ cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_WEIGHT = 1.0
42
+
43
+ cfg.MODEL.KMAX_DEEPLAB.PIXEL_INSDIS_TEMPERATURE = 1.5
44
+ cfg.MODEL.KMAX_DEEPLAB.PIXEL_INSDIS_SAMPLE_K = 4096
45
+ cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_TEMPERATURE = 2.0
46
+ cfg.MODEL.KMAX_DEEPLAB.UX_SEMANTIC_SAMPLE_K = 4096
47
+
48
+
49
+ # pixel decoder config
50
+ cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC = CN()
51
+ cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.NAME = "kMaXPixelDecoder"
52
+ cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.IN_FEATURES = ['res2', 'res3', 'res4', 'res5']
53
+ cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_LAYERS = [1, 5, 1, 1]
54
+ cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.LAYER_TYPES = ["axial", "axial", "bottleneck", "bottleneck"]
55
+ cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_CHANNELS = [512, 256, 128, 64]
56
+ cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DROP_PATH_PROB = 0.0
57
+
58
+ # transformer decoder config
59
+ cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC = CN()
60
+ cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NAME = "kMaXTransformerDecoder"
61
+ cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DEC_LAYERS = [2, 2, 2]
62
+ cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NUM_OBJECT_QUERIES = 128
63
+ cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.IN_CHANNELS = [2048, 1024, 512]
64
+ cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DROP_PATH_PROB = 0.0
65
+
66
+ # kMaX-DeepLab inference config
67
+ cfg.MODEL.KMAX_DEEPLAB.TEST = CN()
68
+ cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON = False
69
+ cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON = False
70
+ cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON = True
71
+ cfg.MODEL.KMAX_DEEPLAB.TEST.OBJECT_MASK_THRESHOLD = 0.4
72
+ cfg.MODEL.KMAX_DEEPLAB.TEST.CLASS_THRESHOLD_THING = 0.7
73
+ cfg.MODEL.KMAX_DEEPLAB.TEST.CLASS_THRESHOLD_STUFF = 0.5
74
+ cfg.MODEL.KMAX_DEEPLAB.TEST.REORDER_CLASS_WEIGHT = 1.0
75
+ cfg.MODEL.KMAX_DEEPLAB.TEST.REORDER_MASK_WEIGHT = 1.0
76
+ cfg.MODEL.KMAX_DEEPLAB.TEST.OVERLAP_THRESHOLD = 0.8
77
+ cfg.MODEL.KMAX_DEEPLAB.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
78
+
79
+ # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
80
+ # you can use this config to override
81
+ cfg.MODEL.KMAX_DEEPLAB.SIZE_DIVISIBILITY = -1
82
+
83
+ # https://github.com/SHI-Labs/OneFormer/blob/main/oneformer/config.py#L197
84
+ cfg.MODEL.CONVNEXT = CN()
85
+ cfg.MODEL.CONVNEXT.IN_CHANNELS = 3
86
+ cfg.MODEL.CONVNEXT.DEPTHS = [3, 3, 27, 3]
87
+ cfg.MODEL.CONVNEXT.DIMS = [192, 384, 768, 1536]
88
+ cfg.MODEL.CONVNEXT.DROP_PATH_RATE = 0.6
89
+ cfg.MODEL.CONVNEXT.LSIT = 1e-6
90
+ cfg.MODEL.CONVNEXT.OUT_INDICES = [0, 1, 2, 3]
91
+ cfg.MODEL.CONVNEXT.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
92
+
93
+ cfg.INPUT.IMAGE_SIZE = [1281, 1281]
94
+ cfg.INPUT.MIN_SCALE = 0.2
95
+ cfg.INPUT.MAX_SCALE = 2.0
96
+
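
A minimal sketch of how these defaults are typically consumed: build a base detectron2 config, add the DeepLab and kMaX-DeepLab keys, then merge one of the YAMLs under configs/coco/panoptic-segmentation/ (the checkpoint path below is a placeholder):

from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from kmax_deeplab import add_kmax_deeplab_config

cfg = get_cfg()
add_deeplab_config(cfg)       # DeepLab keys expected by the backbone / sem-seg head
add_kmax_deeplab_config(cfg)  # the keys registered in this file
cfg.merge_from_file("configs/coco/panoptic-segmentation/kmax_r50.yaml")
cfg.MODEL.WEIGHTS = "path/to/weights.pth"  # placeholder checkpoint path
cfg.freeze()
print(cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NUM_OBJECT_QUERIES)  # 128 by default
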
kmax_deeplab/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import datasets
kmax_deeplab/data/dataset_mappers/__init__.py ADDED
File without changes
kmax_deeplab/data/dataset_mappers/coco_panoptic_kmaxdeeplab_dataset_mapper.py ADDED
@@ -0,0 +1,326 @@
1
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/coco_panoptic_new_baseline_dataset_mapper.py
2
+ # modified by Qihang Yu
3
+ import copy
4
+ import logging
5
+
6
+ import numpy as np
7
+ import torch
8
+ import random
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.data import detection_utils as utils
12
+ from detectron2.data import transforms as T
13
+ from detectron2.projects.point_rend import ColorAugSSDTransform
14
+ from detectron2.structures import BitMasks, Boxes, Instances
15
+
16
+ import os
17
+
18
+ __all__ = ["COCOPanoptickMaXDeepLabDatasetMapper"]
19
+
20
+
21
+ def build_transform_gen(cfg, is_train, scale_ratio=1.0):
22
+ """
23
+ Create a list of default :class:`Augmentation` from config.
24
+ Now it includes resizing and flipping.
25
+ Returns:
26
+ list[Augmentation]
27
+ """
28
+ image_size = cfg.INPUT.IMAGE_SIZE
29
+ assert is_train
30
+
31
+ min_scale = cfg.INPUT.MIN_SCALE * scale_ratio
32
+ max_scale = cfg.INPUT.MAX_SCALE * scale_ratio
33
+
34
+
35
+ augmentation = [
36
+ T.ResizeScale(
37
+ min_scale=min_scale, max_scale=max_scale, target_height=image_size[0], target_width=image_size[1]
38
+ ),
39
+ ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT),
40
+ T.RandomCrop(crop_type="absolute", crop_size=(image_size[0], image_size[1])),
41
+ T.RandomFlip(),
42
+ ]
43
+
44
+ return augmentation
45
+
46
+
47
+ class COCOPanoptickMaXDeepLabDatasetMapper:
48
+ """
49
+ A callable which takes a dataset dict in Detectron2 Dataset format,
50
+ and maps it into a format used by kMaX-DeepLab.
51
+
52
+ The callable currently does the following:
53
+
54
+ 1. Reads the image from "file_name"
55
+ 2. Applies geometric transforms to the image and annotation
56
+ 3. Finds and applies suitable cropping to the image and annotation
57
+ 4. Prepares the image and annotation as Tensors
58
+ """
59
+
60
+ @configurable
61
+ def __init__(
62
+ self,
63
+ is_train=True,
64
+ *,
65
+ tfm_gens,
66
+ tfm_gens_copy_paste,
67
+ image_format,
68
+ image_size,
69
+ ):
70
+ """
71
+ NOTE: this interface is experimental.
72
+ Args:
73
+ is_train: for training or inference
74
+ augmentations: a list of augmentations or deterministic transforms to apply
75
+ tfm_gens: data augmentation
76
+ tfm_gens_copy_paste: data augmentation
77
+ image_format: an image format supported by :func:`detection_utils.read_image`
78
+ image_size: expected image size
79
+ """
80
+ self.tfm_gens = tfm_gens
81
+ self.tfm_gens_copy_paste = tfm_gens_copy_paste
82
+ if is_train:
83
+ logging.getLogger(__name__).info(
84
+ "[COCOPanopticDeepLab2DatasetMapper] Full TransformGens used in training: {}, {}".format(
85
+ str(self.tfm_gens), str(self.tfm_gens_copy_paste)
86
+ )
87
+ )
88
+ else:
89
+ logging.getLogger(__name__).info(
90
+ "[COCOPanopticDeepLab2DatasetMapper] Full TransformGens used in testing: {}".format(
91
+ str(self.tfm_gens)
92
+ )
93
+ )
94
+ self.img_format = image_format
95
+ self.is_train = is_train
96
+ self.image_size = image_size
97
+
98
+ dataset_root = os.getenv("DETECTRON2_DATASETS", "datasets")
99
+ image_dir = os.path.join(dataset_root, "coco/train2017")
100
+ gt_dir = os.path.join(dataset_root, "coco/panoptic_train2017")
101
+ semseg_dir = os.path.join(dataset_root, "coco/panoptic_semseg_train2017")
102
+ json_file = os.path.join(dataset_root, "coco/annotations/panoptic_train2017.json")
103
+ from ..datasets import register_coco_panoptic_annos_semseg
104
+ meta_data = register_coco_panoptic_annos_semseg.get_metadata()
105
+ self.dataset_dict_all = register_coco_panoptic_annos_semseg.load_coco_panoptic_json(
106
+ json_file, image_dir, gt_dir, semseg_dir, meta_data
107
+ )
108
+ self.filename2idx = {}
109
+ for idx, dataset_dict in enumerate(self.dataset_dict_all):
110
+ self.filename2idx[dataset_dict["file_name"].split('/')[-1].replace('.jpg', '')] = idx
111
+
112
+
113
+ @classmethod
114
+ def from_config(cls, cfg, is_train=True):
115
+ # Build augmentation
116
+ tfm_gens = build_transform_gen(cfg, is_train)
117
+ tfm_gens_copy_paste = build_transform_gen(cfg, is_train, scale_ratio=0.5)
118
+ ret = {
119
+ "is_train": is_train,
120
+ "tfm_gens": tfm_gens,
121
+ "tfm_gens_copy_paste": tfm_gens_copy_paste,
122
+ "image_format": cfg.INPUT.FORMAT,
123
+ "image_size": cfg.INPUT.IMAGE_SIZE
124
+ }
125
+ return ret
126
+
127
+ def read_dataset_dict(self, dataset_dict, is_copy_paste=False):
128
+ """
129
+ Args:
130
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
131
+
132
+ Returns:
133
+ dict: a format that builtin models in detectron2 accept
134
+ """
135
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
136
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
137
+ utils.check_image_size(dataset_dict, image)
138
+
139
+ if not is_copy_paste:
140
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
141
+ else:
142
+ image, transforms = T.apply_transform_gens(self.tfm_gens_copy_paste, image)
143
+
144
+ dataset_dict["image"] = np.ascontiguousarray(image.transpose(2, 0, 1))
145
+
146
+ if not self.is_train:
147
+ dataset_dict.pop("annotations", None)
148
+ return dataset_dict, None
149
+
150
+ # We pad the image manually, for copy-paste purpose.
151
+ padded_image = np.zeros((3, self.image_size[0], self.image_size[1]), dtype=dataset_dict["image"].dtype)
152
+ new_h, new_w = dataset_dict["image"].shape[1:]
153
+ offset_h, offset_w = 0, 0 # following the d2 panoptic deeplab implementation to only perform bottom/right padding.
154
+ padded_image[:, offset_h:offset_h+new_h, offset_w:offset_w+new_w] = dataset_dict["image"]
155
+ dataset_dict["image"] = padded_image
156
+ if "pan_seg_file_name" in dataset_dict:
157
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
158
+
159
+ # apply the same transformation to panoptic segmentation
160
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
161
+
162
+ from panopticapi.utils import rgb2id
163
+
164
+ pan_seg_gt = rgb2id(pan_seg_gt) # int32 # H x W
165
+ # Similarly, we manually pad the label, and we use label -1 to indicate those padded pixels.
166
+ # In this way, we can mask out the padded pixel values to 0 after normalization, which aligns the
167
+ # behavior between training and testing.
168
+ padded_pan_seg_gt = np.zeros((self.image_size[0], self.image_size[1]), dtype=pan_seg_gt.dtype)
169
+ is_real_pixels = np.zeros((self.image_size[0], self.image_size[1]), dtype=bool)  # builtin bool: np.bool is removed in newer NumPy
170
+ padded_pan_seg_gt[offset_h:offset_h+new_h, offset_w:offset_w+new_w] = pan_seg_gt
171
+ is_real_pixels[offset_h:offset_h+new_h, offset_w:offset_w+new_w] = True
172
+ dataset_dict["is_real_pixels"] = is_real_pixels
173
+ pan_seg_gt = padded_pan_seg_gt
174
+ return dataset_dict, pan_seg_gt
175
+
176
+ # This should never happen.
177
+ raise NotImplementedError
178
+
179
+ def call_copypaste(self, dataset_dict):
180
+ """
181
+ Args:
182
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
183
+
184
+ Returns:
185
+ dict: a format that builtin models in detectron2 accept
186
+ """
187
+ # Read main image.
188
+ dataset_dict, pan_seg_gt = self.read_dataset_dict(dataset_dict, is_copy_paste=False)
189
+ # Read copy-paste image.
190
+ # We add the main image index as a bias to the random number, in case the same random numbers are generated across devices.
191
+ main_image_idx = self.filename2idx[dataset_dict["file_name"].split('/')[-1].replace('.jpg', '')]
192
+ random_image_idx = main_image_idx + random.randint(0, len(self.dataset_dict_all) - 1)
193
+ random_image_idx = random_image_idx % len(self.dataset_dict_all)
194
+ dataset_dict_copy_paste = copy.deepcopy(self.dataset_dict_all[random_image_idx])
195
+ dataset_dict_copy_paste, pan_seg_gt_copy_paste = self.read_dataset_dict(dataset_dict_copy_paste, is_copy_paste=True)
196
+
197
+ # Copy data_dict_copy_paste onto data_dict. 0 means keep original pixel, 1 means use copy-paste pixel.
198
+ copy_paste_masks = np.zeros((pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
199
+
200
+ segments_info_copy_paste = dataset_dict_copy_paste["segments_info"]
201
+ all_ids = []
202
+ thing_ids = []
203
+ for segment_info_copy_paste in segments_info_copy_paste:
204
+ class_id = segment_info_copy_paste["category_id"]
205
+ if not segment_info_copy_paste["iscrowd"]:
206
+ # -1 is reserved for padded pixels.
207
+ if segment_info_copy_paste["id"] in [-1, 0]:
208
+ print(segment_info_copy_paste)
209
+ raise ValueError("id should not be -1, 0")
210
+ all_ids.append(segment_info_copy_paste["id"])
211
+ if segment_info_copy_paste["isthing"]: # All thing classes are copy-pasted.
212
+ thing_ids.append(segment_info_copy_paste["id"])
213
+
214
+ # Shuffle and randomly select kept label ids.
215
+ random.shuffle(all_ids)
216
+ keep_number = random.randint(0, len(all_ids))
217
+
218
+ for index, label_id in enumerate(all_ids):
219
+ # randomly copy labels, but keep all thing classes.
220
+ if index < keep_number or label_id in thing_ids:
221
+ copy_paste_masks[pan_seg_gt_copy_paste == label_id] = 1
222
+
223
+ # We merge the image and copy-paste image based on the copy-paste mask.
224
+ dataset_dict["image"] = (dataset_dict["image"] * (1.0 - copy_paste_masks).astype(dataset_dict["image"].dtype) +
225
+ dataset_dict_copy_paste["image"] * copy_paste_masks.astype(dataset_dict["image"].dtype))
226
+ dataset_dict["image"] = torch.as_tensor(dataset_dict["image"])
227
+
228
+ dataset_dict["is_real_pixels"] = (dataset_dict["is_real_pixels"] * (1.0 - copy_paste_masks).astype(dataset_dict["is_real_pixels"].dtype) +
229
+ dataset_dict_copy_paste["is_real_pixels"] * copy_paste_masks.astype(dataset_dict["is_real_pixels"].dtype))
230
+ dataset_dict["is_real_pixels"] = torch.as_tensor(dataset_dict["is_real_pixels"])
231
+ # We set all ids in copy-paste image to be negative, so that there will be no overlap between original id and copy-paste id.
232
+ pan_seg_gt_copy_paste = -pan_seg_gt_copy_paste
233
+ pan_seg_gt = (pan_seg_gt * (1.0 - copy_paste_masks).astype(pan_seg_gt.dtype) +
234
+ pan_seg_gt_copy_paste * copy_paste_masks.astype(pan_seg_gt.dtype))
235
+
236
+ # We use 4x downsampled gt for final supervision.
237
+ pan_seg_gt = pan_seg_gt[::4, ::4]
238
+ sem_seg_gt = -np.ones_like(pan_seg_gt) # H x W, init with -1
239
+
240
+ # We then process the obtained pan_seg_gt to training format.
241
+ image_shape = dataset_dict["image"].shape[1:] # h, w
242
+ segments_info = dataset_dict["segments_info"]
243
+ instances = Instances(image_shape)
244
+ classes = []
245
+ masks = []
246
+ valid_pixel_num = 0
247
+ # As the two images may share same stuff classes, we use a dict to track existing stuff and merge them.
248
+ stuff_class_to_idx = {}
249
+ for segment_info in segments_info:
250
+ class_id = segment_info["category_id"]
251
+ if not segment_info["iscrowd"]:
252
+ # -1 is reserved to indicate padded pixels.
253
+ if segment_info["id"] in [-1, 0]:
254
+ print(segment_info)
255
+ raise ValueError("id should not be -1, 0")
256
+ binary_mask = (pan_seg_gt == segment_info["id"])
257
+ # As it is possible that some masks are removed during the copy-paste process, we need
258
+ # to double check that the mask still exists.
259
+ valid_pixel_num_ = binary_mask.sum()
260
+ valid_pixel_num += valid_pixel_num_
261
+ if valid_pixel_num_ > 0:
262
+ sem_seg_gt[binary_mask] = class_id
263
+ if not segment_info["isthing"]:
264
+ # For original image, stuff should only appear once.
265
+ if class_id in stuff_class_to_idx:
266
+ raise ValueError('class_id should not already be in stuff_class_to_idx!')
267
+ else:
268
+ stuff_class_to_idx[class_id] = len(masks)
269
+ classes.append(class_id)
270
+ masks.append(binary_mask)
271
+
272
+ for segment_info in segments_info_copy_paste:
273
+ class_id = segment_info["category_id"]
274
+ if not segment_info["iscrowd"]:
275
+ # -1 is reserved to indicate padded pixels.
276
+ if segment_info["id"] in [-1, 0]:
277
+ print(segment_info)
278
+ raise ValueError("id should not be -1, 0")
279
+ # Note that copy-paste id is negative.
280
+ binary_mask = (pan_seg_gt == -segment_info["id"])
281
+ valid_pixel_num_ = binary_mask.sum()
282
+ valid_pixel_num += valid_pixel_num_
283
+ if valid_pixel_num_ > 0:
284
+ sem_seg_gt[binary_mask] = class_id
285
+ if not segment_info["isthing"]:
286
+ # The stuff in copy-paste image already appeared in original image.
287
+ if class_id in stuff_class_to_idx:
288
+ # Merge into original stuff masks.
289
+ masks[stuff_class_to_idx[class_id]] = np.logical_or(masks[stuff_class_to_idx[class_id]], binary_mask)
290
+ continue
291
+ else:
292
+ stuff_class_to_idx[class_id] = len(masks)
293
+ classes.append(class_id)
294
+ masks.append(binary_mask)
295
+
296
+ classes = np.array(classes)
297
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
298
+ sem_seg_gt = torch.tensor(sem_seg_gt, dtype=torch.int64)
299
+
300
+ if len(masks) == 0:
301
+ # Some image does not have annotation (all ignored)
302
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
303
+ instances.gt_boxes = Boxes(torch.zeros((0, 4)))
304
+ else:
305
+ masks = BitMasks(
306
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
307
+ )
308
+ instances.gt_masks = masks.tensor
309
+ instances.gt_boxes = masks.get_bounding_boxes()
310
+
311
+ dataset_dict["instances"] = instances
312
+ dataset_dict["sem_seg_gt"] = sem_seg_gt
313
+ dataset_dict["valid_pixel_num"] = valid_pixel_num
314
+ return dataset_dict
315
+
316
+ def __call__(self, dataset_dict):
317
+ res = self.call_copypaste(dataset_dict)
318
+ while ("instances" in res and res["instances"].gt_masks.shape[0] == 0) or ("valid_pixel_num" in res and res["valid_pixel_num"] <= 4096):
319
+ # this gt is empty or contains too many void pixels, let's re-generate one.
320
+ main_image_idx = self.filename2idx[dataset_dict["file_name"].split('/')[-1].replace('.jpg', '')]
321
+ random_image_idx = main_image_idx + random.randint(0, len(self.dataset_dict_all) - 1)
322
+ random_image_idx = random_image_idx % len(self.dataset_dict_all)
323
+ dataset_dict = self.dataset_dict_all[random_image_idx]
324
+ res = self.call_copypaste(dataset_dict)
325
+
326
+ return res
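
The heart of the copy-paste augmentation above is a per-pixel blend driven by a binary mask, with the pasted segment ids negated so they can never collide with the ids of the main image. A toy NumPy illustration of that arithmetic (all values invented):

import numpy as np

pan_main = np.array([[1, 1], [2, 2]], dtype=np.int32)     # segment ids, main image
pan_paste = np.array([[7, 7], [7, 8]], dtype=np.int32)    # segment ids, copy-paste image
copy_mask = np.array([[0, 1], [0, 1]], dtype=np.float64)  # 1 = take the copy-paste pixel

pan_paste = -pan_paste  # negate pasted ids so they stay disjoint from the originals
merged = (pan_main * (1.0 - copy_mask).astype(pan_main.dtype)
          + pan_paste * copy_mask.astype(pan_main.dtype))
print(merged)  # [[ 1 -7]
               #  [ 2 -8]]
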
kmax_deeplab/data/datasets/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import (
2
+ register_coco_panoptic_annos_semseg,
3
+ )
kmax_deeplab/data/datasets/register_coco_panoptic_annos_semseg.py ADDED
@@ -0,0 +1,182 @@
1
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_coco_panoptic_annos_semseg.py
2
+
3
+ import json
4
+ import os
5
+
6
+ from detectron2.data import DatasetCatalog, MetadataCatalog
7
+ from detectron2.data.datasets import load_sem_seg
8
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
9
+ from detectron2.utils.file_io import PathManager
10
+
11
+
12
+ _PREDEFINED_SPLITS_COCO_PANOPTIC = {
13
+ "coco_2017_train_panoptic": (
14
+ # This is the original panoptic annotation directory
15
+ "coco/panoptic_train2017",
16
+ "coco/annotations/panoptic_train2017.json",
17
+ # This directory contains semantic annotations that are
18
+ # converted from panoptic annotations.
19
+ # It is used by PanopticFPN.
20
+ # You can use the script at detectron2/datasets/prepare_panoptic_fpn.py
21
+ # to create these directories.
22
+ "coco/panoptic_semseg_train2017",
23
+ ),
24
+ "coco_2017_val_panoptic": (
25
+ "coco/panoptic_val2017",
26
+ "coco/annotations/panoptic_val2017.json",
27
+ "coco/panoptic_semseg_val2017",
28
+ ),
29
+ }
30
+
31
+
32
+ def get_metadata():
33
+ meta = {}
34
+ # The following metadata maps contiguous id from [0, #thing categories +
35
+ # #stuff categories) to their names and colors. We have to replica of the
36
+ # same name and color under "thing_*" and "stuff_*" because the current
37
+ # visualization function in D2 handles thing and class classes differently
38
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
39
+ # enable reusing existing visualization functions.
40
+ thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
41
+ thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
42
+ stuff_classes = [k["name"] for k in COCO_CATEGORIES]
43
+ stuff_colors = [k["color"] for k in COCO_CATEGORIES]
44
+
45
+ meta["thing_classes"] = thing_classes
46
+ meta["thing_colors"] = thing_colors
47
+ meta["stuff_classes"] = stuff_classes
48
+ meta["stuff_colors"] = stuff_colors
49
+
50
+ # Convert category id for training:
51
+ # category id: like semantic segmentation, it is the class id for each
52
+ # pixel. Since there are some classes not used in evaluation, the category
53
+ # id is not always contiguous and thus we have two set of category ids:
54
+ # - original category id: category id in the original dataset, mainly
55
+ # used for evaluation.
56
+ # - contiguous category id: [0, #classes), in order to train the linear
57
+ # softmax classifier.
58
+ thing_dataset_id_to_contiguous_id = {}
59
+ stuff_dataset_id_to_contiguous_id = {}
60
+
61
+ for i, cat in enumerate(COCO_CATEGORIES):
62
+ if cat["isthing"]:
63
+ thing_dataset_id_to_contiguous_id[cat["id"]] = i
64
+ # else:
65
+ # stuff_dataset_id_to_contiguous_id[cat["id"]] = i
66
+
67
+ # in order to use sem_seg evaluator
68
+ stuff_dataset_id_to_contiguous_id[cat["id"]] = i
69
+
70
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
71
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
72
+
73
+ return meta
74
+
75
+
76
+ def load_coco_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, meta):
77
+ """
78
+ Args:
79
+ image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
80
+ gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
81
+ json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
82
+ Returns:
83
+ list[dict]: a list of dicts in Detectron2 standard format. (See
84
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
85
+ """
86
+
87
+ def _convert_category_id(segment_info, meta):
88
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
89
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
90
+ segment_info["category_id"]
91
+ ]
92
+ segment_info["isthing"] = True
93
+ else:
94
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
95
+ segment_info["category_id"]
96
+ ]
97
+ segment_info["isthing"] = False
98
+ return segment_info
99
+
100
+ with PathManager.open(json_file) as f:
101
+ json_info = json.load(f)
102
+
103
+ ret = []
104
+ for ann in json_info["annotations"]:
105
+ image_id = int(ann["image_id"])
106
+ # TODO: currently we assume image and label has the same filename but
107
+ # different extension, and images have extension ".jpg" for COCO. Need
108
+ # to make image extension a user-provided argument if we extend this
109
+ # function to support other COCO-like datasets.
110
+ image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
111
+ label_file = os.path.join(gt_dir, ann["file_name"])
112
+ sem_label_file = os.path.join(semseg_dir, ann["file_name"])
113
+ segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
114
+ ret.append(
115
+ {
116
+ "file_name": image_file,
117
+ "image_id": image_id,
118
+ "pan_seg_file_name": label_file,
119
+ "sem_seg_file_name": sem_label_file,
120
+ "segments_info": segments_info,
121
+ }
122
+ )
123
+ assert len(ret), f"No images found in {image_dir}!"
124
+ assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
125
+ assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
126
+ assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
127
+ return ret
128
+
129
+
130
+ def register_coco_panoptic_annos_sem_seg(
131
+ name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json
132
+ ):
133
+ panoptic_name = name
134
+ delattr(MetadataCatalog.get(panoptic_name), "thing_classes")
135
+ delattr(MetadataCatalog.get(panoptic_name), "thing_colors")
136
+ MetadataCatalog.get(panoptic_name).set(
137
+ thing_classes=metadata["thing_classes"],
138
+ thing_colors=metadata["thing_colors"],
139
+ # thing_dataset_id_to_contiguous_id=metadata["thing_dataset_id_to_contiguous_id"],
140
+ )
141
+
142
+ # the name is "coco_2017_train_panoptic_with_sem_seg" and "coco_2017_val_panoptic_with_sem_seg"
143
+ semantic_name = name + "_with_sem_seg"
144
+ DatasetCatalog.register(
145
+ semantic_name,
146
+ lambda: load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, sem_seg_root, metadata),
147
+ )
148
+ MetadataCatalog.get(semantic_name).set(
149
+ sem_seg_root=sem_seg_root,
150
+ panoptic_root=panoptic_root,
151
+ image_root=image_root,
152
+ panoptic_json=panoptic_json,
153
+ json_file=instances_json,
154
+ evaluator_type="coco_panoptic_seg",
155
+ ignore_label=255,
156
+ label_divisor=1000,
157
+ **metadata,
158
+ )
159
+
160
+
161
+ def register_all_coco_panoptic_annos_sem_seg(root):
162
+ for (
163
+ prefix,
164
+ (panoptic_root, panoptic_json, semantic_root),
165
+ ) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items():
166
+ prefix_instances = prefix[: -len("_panoptic")]
167
+ instances_meta = MetadataCatalog.get(prefix_instances)
168
+ image_root, instances_json = instances_meta.image_root, instances_meta.json_file
169
+
170
+ register_coco_panoptic_annos_sem_seg(
171
+ prefix,
172
+ get_metadata(),
173
+ image_root,
174
+ os.path.join(root, panoptic_root),
175
+ os.path.join(root, panoptic_json),
176
+ os.path.join(root, semantic_root),
177
+ instances_json,
178
+ )
179
+
180
+
181
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
182
+ register_all_coco_panoptic_annos_sem_seg(_root)
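
Once this module has been imported (e.g. via import kmax_deeplab), the splits above are reachable through detectron2's catalogs under the "_with_sem_seg" suffix. A small sketch; materializing the dataset dicts assumes DETECTRON2_DATASETS points at a prepared COCO layout:

from detectron2.data import DatasetCatalog, MetadataCatalog
import kmax_deeplab  # noqa: F401 -- triggers the registration at the bottom of this file

meta = MetadataCatalog.get("coco_2017_val_panoptic_with_sem_seg")
print(len(meta.thing_classes), len(meta.stuff_classes))  # 80 thing / 133 total categories

dicts = DatasetCatalog.get("coco_2017_val_panoptic_with_sem_seg")  # reads panoptic_val2017.json
print(sorted(dicts[0].keys()))
# ['file_name', 'image_id', 'pan_seg_file_name', 'segments_info', 'sem_seg_file_name']
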
kmax_deeplab/evaluation/__init__.py ADDED
File without changes
kmax_deeplab/evaluation/instance_evaluation.py ADDED
@@ -0,0 +1,107 @@
1
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/evaluation/instance_evaluation.py
2
+ import contextlib
3
+ import copy
4
+ import io
5
+ import itertools
6
+ import json
7
+ import logging
8
+ import numpy as np
9
+ import os
10
+ import pickle
11
+ from collections import OrderedDict
12
+ import pycocotools.mask as mask_util
13
+ import torch
14
+ from pycocotools.coco import COCO
15
+ from pycocotools.cocoeval import COCOeval
16
+ from tabulate import tabulate
17
+
18
+ import detectron2.utils.comm as comm
19
+ from detectron2.config import CfgNode
20
+ from detectron2.data import MetadataCatalog
21
+ from detectron2.data.datasets.coco import convert_to_coco_json
22
+ from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
23
+ from detectron2.evaluation.fast_eval_api import COCOeval_opt
24
+ from detectron2.structures import Boxes, BoxMode, pairwise_iou
25
+ from detectron2.utils.file_io import PathManager
26
+ from detectron2.utils.logger import create_small_table
27
+
28
+
29
+ # Modified from COCOEvaluator for instance segmentation
30
+ class InstanceSegEvaluator(COCOEvaluator):
31
+ """
32
+ Evaluate AR for object proposals, AP for instance detection/segmentation, AP
33
+ for keypoint detection outputs using COCO's metrics.
34
+ See http://cocodataset.org/#detection-eval and
35
+ http://cocodataset.org/#keypoints-eval to understand its metrics.
36
+ The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
37
+ the metric cannot be computed (e.g. due to no predictions made).
38
+
39
+ In addition to COCO, this evaluator is able to support any bounding box detection,
40
+ instance segmentation, or keypoint detection dataset.
41
+ """
42
+
43
+ def _eval_predictions(self, predictions, img_ids=None):
44
+ """
45
+ Evaluate predictions. Fill self._results with the metrics of the tasks.
46
+ """
47
+ self._logger.info("Preparing results for COCO format ...")
48
+ coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
49
+ tasks = self._tasks or self._tasks_from_predictions(coco_results)
50
+
51
+ # unmap the category ids for COCO
52
+ if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
53
+ dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
54
+ # all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
55
+ # num_classes = len(all_contiguous_ids)
56
+ # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
57
+
58
+ reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
59
+ for result in coco_results:
60
+ category_id = result["category_id"]
61
+ # assert category_id < num_classes, (
62
+ # f"A prediction has class={category_id}, "
63
+ # f"but the dataset only has {num_classes} classes and "
64
+ # f"predicted class id should be in [0, {num_classes - 1}]."
65
+ # )
66
+ assert category_id in reverse_id_mapping, (
67
+ f"A prediction has class={category_id}, "
68
+ f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
69
+ )
70
+ result["category_id"] = reverse_id_mapping[category_id]
71
+
72
+ if self._output_dir:
73
+ file_path = os.path.join(self._output_dir, "coco_instances_results.json")
74
+ self._logger.info("Saving results to {}".format(file_path))
75
+ with PathManager.open(file_path, "w") as f:
76
+ f.write(json.dumps(coco_results))
77
+ f.flush()
78
+
79
+ if not self._do_evaluation:
80
+ self._logger.info("Annotations are not available for evaluation.")
81
+ return
82
+
83
+ self._logger.info(
84
+ "Evaluating predictions with {} COCO API...".format(
85
+ "unofficial" if self._use_fast_impl else "official"
86
+ )
87
+ )
88
+ for task in sorted(tasks):
89
+ assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
90
+ coco_eval = (
91
+ _evaluate_predictions_on_coco(
92
+ self._coco_api,
93
+ coco_results,
94
+ task,
95
+ kpt_oks_sigmas=self._kpt_oks_sigmas,
96
+ use_fast_impl=self._use_fast_impl,
97
+ img_ids=img_ids,
98
+ max_dets_per_image=self._max_dets_per_image,
99
+ )
100
+ if len(coco_results) > 0
101
+ else None # cocoapi does not handle empty results very well
102
+ )
103
+
104
+ res = self._derive_coco_results(
105
+ coco_eval, task, class_names=self._metadata.get("thing_classes")
106
+ )
107
+ self._results[task] = res
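
The key difference from the stock COCOEvaluator is that the contiguous-to-dataset id inversion above is guarded by a membership check instead of assuming a dense [0, num_classes) range. A toy illustration of that inversion with made-up ids:

# dataset id -> contiguous training id (sparse, as in COCO where ids have gaps)
thing_dataset_id_to_contiguous_id = {1: 0, 3: 1, 7: 2}
reverse_id_mapping = {v: k for k, v in thing_dataset_id_to_contiguous_id.items()}

predicted_contiguous_ids = [2, 0, 1]
restored = [reverse_id_mapping[c] for c in predicted_contiguous_ids]
print(restored)  # [7, 1, 3] -- dataset category ids expected by the COCO API
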
kmax_deeplab/evaluation/panoptic_evaluation.py ADDED
@@ -0,0 +1,269 @@
1
+ # Reference: https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py
2
+ # Reference: https://github.com/open-mmlab/mmdetection/pull/7538
3
+
4
+ #!/usr/bin/env python
5
+ from __future__ import absolute_import
6
+ from __future__ import division
7
+ from __future__ import print_function
8
+ from __future__ import unicode_literals
9
+ import os, sys
10
+ import numpy as np
11
+ import json
12
+ import time
13
+ from datetime import timedelta
14
+ from collections import defaultdict
15
+ import argparse
16
+ import multiprocessing
17
+
18
+ import PIL.Image as Image
19
+
20
+ from panopticapi.utils import get_traceback, rgb2id
21
+
22
+ OFFSET = 256 * 256 * 256
23
+ VOID = 0
24
+
25
+ class PQStatCat():
26
+ def __init__(self):
27
+ self.iou = 0.0
28
+ self.tp = 0
29
+ self.fp = 0
30
+ self.fn = 0
31
+
32
+ def __iadd__(self, pq_stat_cat):
33
+ self.iou += pq_stat_cat.iou
34
+ self.tp += pq_stat_cat.tp
35
+ self.fp += pq_stat_cat.fp
36
+ self.fn += pq_stat_cat.fn
37
+ return self
38
+
39
+
40
+ class PQStat():
41
+ def __init__(self):
42
+ self.pq_per_cat = defaultdict(PQStatCat)
43
+
44
+ def __getitem__(self, i):
45
+ return self.pq_per_cat[i]
46
+
47
+ def __iadd__(self, pq_stat):
48
+ for label, pq_stat_cat in pq_stat.pq_per_cat.items():
49
+ self.pq_per_cat[label] += pq_stat_cat
50
+ return self
51
+
52
+ def pq_average(self, categories, isthing):
53
+ pq, sq, rq, n = 0, 0, 0, 0
54
+ per_class_results = {}
55
+ for label, label_info in categories.items():
56
+ if isthing is not None:
57
+ cat_isthing = label_info['isthing'] == 1
58
+ if isthing != cat_isthing:
59
+ continue
60
+ iou = self.pq_per_cat[label].iou
61
+ tp = self.pq_per_cat[label].tp
62
+ fp = self.pq_per_cat[label].fp
63
+ fn = self.pq_per_cat[label].fn
64
+ if tp + fp + fn == 0:
65
+ per_class_results[label] = {'pq': 0.0, 'sq': 0.0, 'rq': 0.0}
66
+ continue
67
+ n += 1
68
+ pq_class = iou / (tp + 0.5 * fp + 0.5 * fn)
69
+ sq_class = iou / tp if tp != 0 else 0
70
+ rq_class = tp / (tp + 0.5 * fp + 0.5 * fn)
71
+ per_class_results[label] = {'pq': pq_class, 'sq': sq_class, 'rq': rq_class}
72
+ pq += pq_class
73
+ sq += sq_class
74
+ rq += rq_class
75
+
76
+ return {'pq': pq / n, 'sq': sq / n, 'rq': rq / n, 'n': n}, per_class_results
77
+
78
+
79
+ @get_traceback
80
+ def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, categories):
81
+ pq_stat = PQStat()
82
+
83
+ idx = 0
84
+ for gt_ann, pred_ann in annotation_set:
85
+ if idx % 100 == 0:
86
+ print('Core: {}, {} from {} images processed'.format(proc_id, idx, len(annotation_set)))
87
+ idx += 1
88
+
89
+ pan_gt = np.array(Image.open(os.path.join(gt_folder, gt_ann['file_name'])), dtype=np.uint32)
90
+ pan_gt = rgb2id(pan_gt)
91
+ pan_pred = np.array(Image.open(os.path.join(pred_folder, pred_ann['file_name'])), dtype=np.uint32)
92
+ pan_pred = rgb2id(pan_pred)
93
+
94
+ gt_segms = {el['id']: el for el in gt_ann['segments_info']}
95
+ pred_segms = {el['id']: el for el in pred_ann['segments_info']}
96
+
97
+ # predicted segments area calculation + prediction sanity checks
98
+ pred_labels_set = set(el['id'] for el in pred_ann['segments_info'])
99
+ labels, labels_cnt = np.unique(pan_pred, return_counts=True)
100
+ for label, label_cnt in zip(labels, labels_cnt):
101
+ if label not in pred_segms:
102
+ if label == VOID:
103
+ continue
104
+ raise KeyError('In the image with ID {} segment with ID {} is presented in PNG and not presented in JSON.'.format(gt_ann['image_id'], label))
105
+ pred_segms[label]['area'] = label_cnt
106
+ pred_labels_set.remove(label)
107
+ if pred_segms[label]['category_id'] not in categories:
108
+ raise KeyError('In the image with ID {} segment with ID {} has unknown category_id {}.'.format(gt_ann['image_id'], label, pred_segms[label]['category_id']))
109
+ if len(pred_labels_set) != 0:
110
+ raise KeyError('In the image with ID {} the following segment IDs {} are presented in JSON and not presented in PNG.'.format(gt_ann['image_id'], list(pred_labels_set)))
111
+
112
+ # confusion matrix calculation
113
+ pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(np.uint64)
114
+ gt_pred_map = {}
115
+ labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
116
+ for label, intersection in zip(labels, labels_cnt):
117
+ gt_id = label // OFFSET
118
+ pred_id = label % OFFSET
119
+ gt_pred_map[(gt_id, pred_id)] = intersection
120
+
121
+ # count all matched pairs
122
+ gt_matched = set()
123
+ pred_matched = set()
124
+ for label_tuple, intersection in gt_pred_map.items():
125
+ gt_label, pred_label = label_tuple
126
+ if gt_label not in gt_segms:
127
+ continue
128
+ if pred_label not in pred_segms:
129
+ continue
130
+ if gt_segms[gt_label]['iscrowd'] == 1:
131
+ continue
132
+ if gt_segms[gt_label]['category_id'] != pred_segms[pred_label]['category_id']:
133
+ continue
134
+
135
+ union = pred_segms[pred_label]['area'] + gt_segms[gt_label]['area'] - intersection - gt_pred_map.get((VOID, pred_label), 0)
136
+ iou = intersection / union
137
+ if iou > 0.5:
138
+ pq_stat[gt_segms[gt_label]['category_id']].tp += 1
139
+ pq_stat[gt_segms[gt_label]['category_id']].iou += iou
140
+ gt_matched.add(gt_label)
141
+ pred_matched.add(pred_label)
142
+
143
+ # count false positives
144
+ crowd_labels_dict = {}
145
+ for gt_label, gt_info in gt_segms.items():
146
+ if gt_label in gt_matched:
147
+ continue
148
+ # crowd segments are ignored
149
+ if gt_info['iscrowd'] == 1:
150
+ crowd_labels_dict[gt_info['category_id']] = gt_label
151
+ continue
152
+ pq_stat[gt_info['category_id']].fn += 1
153
+
154
+ # count false positives
155
+ for pred_label, pred_info in pred_segms.items():
156
+ if pred_label in pred_matched:
157
+ continue
158
+ # intersection of the segment with VOID
159
+ intersection = gt_pred_map.get((VOID, pred_label), 0)
160
+ # plus intersection with corresponding CROWD region if it exists
161
+ if pred_info['category_id'] in crowd_labels_dict:
162
+ intersection += gt_pred_map.get((crowd_labels_dict[pred_info['category_id']], pred_label), 0)
163
+ # predicted segment is ignored if more than half of the segment correspond to VOID and CROWD regions
164
+ if intersection / pred_info['area'] > 0.5:
165
+ continue
166
+ pq_stat[pred_info['category_id']].fp += 1
167
+ print('Core: {}, all {} images processed'.format(proc_id, len(annotation_set)))
168
+ return pq_stat
169
+
170
+
171
+ def pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories):
172
+ cpu_num = multiprocessing.cpu_count()
173
+ annotations_split = np.array_split(matched_annotations_list, cpu_num)
174
+ print("Number of cores: {}, images per core: {}".format(cpu_num, len(annotations_split[0])))
175
+ workers = multiprocessing.Pool(processes=cpu_num)
176
+ processes = []
177
+ for proc_id, annotation_set in enumerate(annotations_split):
178
+ p = workers.apply_async(pq_compute_single_core,
179
+ (proc_id, annotation_set, gt_folder, pred_folder, categories))
180
+ processes.append(p)
181
+
182
+ # https://github.com/open-mmlab/mmdetection/pull/7538
183
+ # Close the process pool, otherwise it will lead to memory
184
+ # leaking problems.
185
+ workers.close()
186
+ workers.join()
187
+
188
+
189
+ pq_stat = PQStat()
190
+ for p in processes:
191
+ pq_stat += p.get()
192
+ return pq_stat
193
+
194
+
195
+ def pq_compute(gt_json_file, pred_json_file, gt_folder=None, pred_folder=None):
196
+
197
+ start_time = time.time()
198
+ with open(gt_json_file, 'r') as f:
199
+ gt_json = json.load(f)
200
+ with open(pred_json_file, 'r') as f:
201
+ pred_json = json.load(f)
202
+
203
+ if gt_folder is None:
204
+ gt_folder = gt_json_file.replace('.json', '')
205
+ if pred_folder is None:
206
+ pred_folder = pred_json_file.replace('.json', '')
207
+ categories = {el['id']: el for el in gt_json['categories']}
208
+
209
+ print("Evaluation panoptic segmentation metrics:")
210
+ print("Ground truth:")
211
+ print("\tSegmentation folder: {}".format(gt_folder))
212
+ print("\tJSON file: {}".format(gt_json_file))
213
+ print("Prediction:")
214
+ print("\tSegmentation folder: {}".format(pred_folder))
215
+ print("\tJSON file: {}".format(pred_json_file))
216
+
217
+ if not os.path.isdir(gt_folder):
218
+ raise Exception("Folder {} with ground truth segmentations doesn't exist".format(gt_folder))
219
+ if not os.path.isdir(pred_folder):
220
+ raise Exception("Folder {} with predicted segmentations doesn't exist".format(pred_folder))
221
+
222
+ pred_annotations = {el['image_id']: el for el in pred_json['annotations']}
223
+ matched_annotations_list = []
224
+ for gt_ann in gt_json['annotations']:
225
+ image_id = gt_ann['image_id']
226
+ if image_id not in pred_annotations:
227
+ raise Exception('no prediction for the image with id: {}'.format(image_id))
228
+ matched_annotations_list.append((gt_ann, pred_annotations[image_id]))
229
+
230
+ pq_stat = pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories)
231
+
232
+ metrics = [("All", None), ("Things", True), ("Stuff", False)]
233
+ results = {}
234
+ for name, isthing in metrics:
235
+ results[name], per_class_results = pq_stat.pq_average(categories, isthing=isthing)
236
+ if name == 'All':
237
+ results['per_class'] = per_class_results
238
+ print("{:10s}| {:>5s} {:>5s} {:>5s} {:>5s}".format("", "PQ", "SQ", "RQ", "N"))
239
+ print("-" * (10 + 7 * 4))
240
+
241
+ for name, _isthing in metrics:
242
+ print("{:10s}| {:5.1f} {:5.1f} {:5.1f} {:5d}".format(
243
+ name,
244
+ 100 * results[name]['pq'],
245
+ 100 * results[name]['sq'],
246
+ 100 * results[name]['rq'],
247
+ results[name]['n'])
248
+ )
249
+
250
+ t_delta = time.time() - start_time
251
+ print("Time elapsed: {:0.2f} seconds".format(t_delta))
252
+
253
+ return results
254
+
255
+
256
+ if __name__ == "__main__":
257
+ parser = argparse.ArgumentParser()
258
+ parser.add_argument('--gt_json_file', type=str,
259
+ help="JSON file with ground truth data")
260
+ parser.add_argument('--pred_json_file', type=str,
261
+ help="JSON file with predictions data")
262
+ parser.add_argument('--gt_folder', type=str, default=None,
263
+ help="Folder with ground turth COCO format segmentations. \
264
+ Default: X if the corresponding json file is X.json")
265
+ parser.add_argument('--pred_folder', type=str, default=None,
266
+ help="Folder with prediction COCO format segmentations. \
267
+ Default: X if the corresponding json file is X.json")
268
+ args = parser.parse_args()
269
+ pq_compute(args.gt_json_file, args.pred_json_file, args.gt_folder, args.pred_folder)
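
For reference, the per-category statistics accumulated above reduce to PQ = sum(IoU) / (TP + 0.5*FP + 0.5*FN), SQ = sum(IoU) / TP and RQ = TP / (TP + 0.5*FP + 0.5*FN), so PQ = SQ * RQ. A worked toy example for a single category (numbers invented):

iou_sum, tp, fp, fn = 1.6, 2, 1, 1  # two matches with IoU 0.9 and 0.7, one FP, one FN
rq = tp / (tp + 0.5 * fp + 0.5 * fn)       # 2 / 3
sq = iou_sum / tp                          # 0.8
pq = iou_sum / (tp + 0.5 * fp + 0.5 * fn)  # 1.6 / 3
assert abs(pq - sq * rq) < 1e-9            # PQ factorizes into SQ * RQ
print(round(pq, 3), round(sq, 3), round(rq, 3))  # 0.533 0.8 0.667
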
kmax_deeplab/kmax_model.py ADDED
@@ -0,0 +1,446 @@
1
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/maskformer_model.py
2
+ # Reference: https://github.com/google-research/deeplab2/blob/main/model/kmax_deeplab.py
3
+ # Reference: https://github.com/google-research/deeplab2/blob/main/model/post_processor/max_deeplab.py
4
+ # Modified by Qihang Yu
5
+
6
+ from typing import Tuple, List
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from detectron2.config import configurable
13
+ from detectron2.data import MetadataCatalog
14
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
15
+ from detectron2.modeling.backbone import Backbone
16
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
17
+ from detectron2.structures import Boxes, ImageList, Instances
18
+ from detectron2.utils.memory import retry_if_cuda_oom
19
+
20
+ from .modeling.criterion import SetCriterion
21
+ from .modeling.matcher import HungarianMatcher
22
+ from torch.cuda.amp import autocast
23
+
24
+
25
+ @META_ARCH_REGISTRY.register()
26
+ class kMaXDeepLab(nn.Module):
27
+ """
28
+ Main class for the kMaX-DeepLab mask classification segmentation architecture.
29
+ """
30
+
31
+ @configurable
32
+ def __init__(
33
+ self,
34
+ *,
35
+ backbone: Backbone,
36
+ sem_seg_head: nn.Module,
37
+ criterion: nn.Module,
38
+ num_queries: int,
39
+ object_mask_threshold: float,
40
+ class_threshold_thing: float,
41
+ class_threshold_stuff: float,
42
+ overlap_threshold: float,
43
+ reorder_class_weight: float,
44
+ reorder_mask_weight: float,
45
+ metadata,
46
+ size_divisibility: int,
47
+ sem_seg_postprocess_before_inference: bool,
48
+ pixel_mean: Tuple[float],
49
+ pixel_std: Tuple[float],
50
+ # inference
51
+ semantic_on: bool,
52
+ panoptic_on: bool,
53
+ instance_on: bool,
54
+ test_topk_per_image: int,
55
+ input_shape: List[int]
56
+ ):
57
+ """
58
+ Args:
59
+ backbone: a backbone module, must follow detectron2's backbone interface
60
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
61
+ criterion: a module that defines the loss
62
+ num_queries: int, number of queries
63
+ object_mask_threshold: float, threshold to filter query based on classification score
64
+ for panoptic segmentation inference
65
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
66
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
67
+ segmentation inference
68
+ size_divisibility: Some backbones require the input height and width to be divisible by a
69
+ specific integer. We can use this to override such requirement.
70
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
71
+ to original input size before semantic segmentation inference or after.
72
+ For high-resolution dataset like Mapillary, resizing predictions before
73
+ inference will cause OOM error.
74
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
75
+ the per-channel mean and std to be used to normalize the input image
76
+ semantic_on: bool, whether to output semantic segmentation prediction
77
+ instance_on: bool, whether to output instance segmentation prediction
78
+ panoptic_on: bool, whether to output panoptic segmentation prediction
79
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
80
+ """
81
+ super().__init__()
82
+ self.backbone = backbone
83
+ self.sem_seg_head = sem_seg_head
84
+ self.criterion = criterion
85
+ self.num_queries = num_queries
86
+ self.overlap_threshold = overlap_threshold
87
+ self.object_mask_threshold = object_mask_threshold
88
+ self.class_threshold_thing = class_threshold_thing
89
+ self.class_threshold_stuff = class_threshold_stuff
90
+ self.reorder_class_weight = reorder_class_weight
91
+ self.reorder_mask_weight = reorder_mask_weight
92
+ self.metadata = metadata
93
+ if size_divisibility < 0:
94
+ # use backbone size_divisibility if not set
95
+ size_divisibility = self.backbone.size_divisibility
96
+ self.size_divisibility = size_divisibility
97
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
98
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
99
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
100
+
101
+ # additional args
102
+ self.semantic_on = semantic_on
103
+ self.instance_on = instance_on
104
+ self.panoptic_on = panoptic_on
105
+ self.test_topk_per_image = test_topk_per_image
106
+
107
+ if not self.semantic_on:
108
+ assert self.sem_seg_postprocess_before_inference
109
+
110
+ self.input_shape = input_shape
111
+
112
+ @classmethod
113
+ def from_config(cls, cfg):
114
+ backbone = build_backbone(cfg)
115
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
116
+
117
+ # Loss parameters:
118
+ deep_supervision = cfg.MODEL.KMAX_DEEPLAB.DEEP_SUPERVISION
119
+ no_object_weight = cfg.MODEL.KMAX_DEEPLAB.NO_OBJECT_WEIGHT
120
+ share_final_matching = cfg.MODEL.KMAX_DEEPLAB.SHARE_FINAL_MATCHING
121
+
122
+ # loss weights
123
+ class_weight = cfg.MODEL.KMAX_DEEPLAB.CLASS_WEIGHT
124
+ dice_weight = cfg.MODEL.KMAX_DEEPLAB.DICE_WEIGHT
125
+ mask_weight = cfg.MODEL.KMAX_DEEPLAB.MASK_WEIGHT
126
+ insdis_weight = cfg.MODEL.KMAX_DEEPLAB.INSDIS_WEIGHT
127
+ aux_semantic_weight = cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_WEIGHT
128
+
129
+ # building criterion
130
+ matcher = HungarianMatcher()
131
+
132
+ weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight,
133
+ "loss_pixel_insdis": insdis_weight, "loss_aux_semantic": aux_semantic_weight}
134
+
135
+ if deep_supervision:
136
+ dec_layers = sum(cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DEC_LAYERS)
137
+ aux_weight_dict = {}
138
+ for i in range(dec_layers):
139
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
140
+ weight_dict.update(aux_weight_dict)
141
+
142
+ losses = ["labels", "masks"]
143
+ if insdis_weight > 0:
144
+ losses += ["pixels"]
145
+ if aux_semantic_weight > 0:
146
+ losses += ["aux_semantic"]
147
+
148
+ criterion = SetCriterion(
149
+ sem_seg_head.num_classes,
150
+ matcher=matcher,
151
+ weight_dict=weight_dict,
152
+ eos_coef=no_object_weight,
153
+ losses=losses,
154
+ share_final_matching=share_final_matching,
155
+ pixel_insdis_temperature=cfg.MODEL.KMAX_DEEPLAB.PIXEL_INSDIS_TEMPERATURE,
156
+ pixel_insdis_sample_k=cfg.MODEL.KMAX_DEEPLAB.PIXEL_INSDIS_SAMPLE_K,
157
+ aux_semantic_temperature=cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_TEMPERATURE,
158
+ aux_semantic_sample_k=cfg.MODEL.KMAX_DEEPLAB.UX_SEMANTIC_SAMPLE_K
159
+ )
160
+
161
+ return {
162
+ "backbone": backbone,
163
+ "sem_seg_head": sem_seg_head,
164
+ "criterion": criterion,
165
+ "num_queries": cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NUM_OBJECT_QUERIES,
166
+ "object_mask_threshold": cfg.MODEL.KMAX_DEEPLAB.TEST.OBJECT_MASK_THRESHOLD,
167
+ "class_threshold_thing": cfg.MODEL.KMAX_DEEPLAB.TEST.CLASS_THRESHOLD_THING,
168
+ "class_threshold_stuff": cfg.MODEL.KMAX_DEEPLAB.TEST.CLASS_THRESHOLD_STUFF,
169
+ "overlap_threshold": cfg.MODEL.KMAX_DEEPLAB.TEST.OVERLAP_THRESHOLD,
170
+ "reorder_class_weight": cfg.MODEL.KMAX_DEEPLAB.TEST.REORDER_CLASS_WEIGHT,
171
+ "reorder_mask_weight": cfg.MODEL.KMAX_DEEPLAB.TEST.REORDER_MASK_WEIGHT,
172
+ "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
173
+ "size_divisibility": cfg.MODEL.KMAX_DEEPLAB.SIZE_DIVISIBILITY,
174
+ "sem_seg_postprocess_before_inference": (
175
+ cfg.MODEL.KMAX_DEEPLAB.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
176
+ or cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON
177
+ or cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON
178
+ ),
179
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
180
+ "pixel_std": cfg.MODEL.PIXEL_STD,
181
+ # inference
182
+ "semantic_on": cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON,
183
+ "instance_on": cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON,
184
+ "panoptic_on": cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON,
185
+ "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
186
+ "input_shape": cfg.INPUT.IMAGE_SIZE
187
+ }
188
+
189
+ @property
190
+ def device(self):
191
+ return self.pixel_mean.device
192
+
193
+ def forward(self, batched_inputs):
194
+ """
195
+ Args:
196
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
197
+ Each item in the list contains the inputs for one image.
198
+ For now, each item in the list is a dict that contains:
199
+ * "image": Tensor, image in (C, H, W) format.
200
+ * "instances": per-region ground truth
201
+ * Other information that's included in the original dicts, such as:
202
+ "height", "width" (int): the output resolution of the model (may be different
203
+ from input resolution), used in inference.
204
+ Returns:
205
+ list[dict]:
206
+ each dict has the results for one image. The dict contains the following keys:
207
+
208
+ * "sem_seg":
209
+ A Tensor that represents the
210
+ per-pixel segmentation prediced by the head.
211
+ The prediction has shape KxHxW that represents the logits of
212
+ each class for each pixel.
213
+ * "panoptic_seg":
214
+ A tuple that represent panoptic output
215
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
216
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
217
+ Each dict contains keys "id", "category_id", "isthing".
218
+ """
219
+ images = [x["image"].to(self.device) for x in batched_inputs]
220
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
221
+ if "is_real_pixels" in batched_inputs[0]:
222
+ is_real_pixels = [x["is_real_pixels"] for x in batched_inputs]
223
+ # Set all padded pixel values to 0.
224
+ images = [x * y.to(x) for x, y in zip(images, is_real_pixels)]
225
+
226
+ # We perform zero padding to ensure input shape equal to self.input_shape.
227
+ # The padding is done on the right and bottom sides.
228
+ for idx in range(len(images)):
229
+ cur_height, cur_width = images[idx].shape[-2:]
230
+ padding = (0, max(0, self.input_shape[1] - cur_width), 0, max(0, self.input_shape[0] - cur_height), 0, 0)
231
+ images[idx] = F.pad(images[idx], padding, value=0)
232
+ images = ImageList.from_tensors(images, -1)
233
+
234
+ if self.training:
235
+ # mask classification target
236
+ if "instances" in batched_inputs[0]:
237
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
238
+ gt_semantic = [x["sem_seg_gt"].to(self.device) for x in batched_inputs]
239
+ targets = self.prepare_targets(gt_instances, gt_semantic, images)
240
+ else:
241
+ targets = None
242
+
243
+ features = self.backbone(images.tensor)
244
+ outputs = self.sem_seg_head(features)
245
+
246
+ if self.training:
247
+
248
+ with autocast(enabled=False):
249
+ # bipartite matching-based loss
250
+ for output_key in ["pixel_feature", "pred_masks", "pred_logits", "aux_semantic_pred"]:
251
+ if output_key in outputs:
252
+ outputs[output_key] = outputs[output_key].float()
253
+ for i in range(len(outputs["aux_outputs"])):
254
+ for output_key in ["pixel_feature", "pred_masks", "pred_logits"]:
255
+ outputs["aux_outputs"][i][output_key] = outputs["aux_outputs"][i][output_key].float()
256
+
257
+ losses = self.criterion(outputs, targets)
258
+
259
+ for k in list(losses.keys()):
260
+ if k in self.criterion.weight_dict:
261
+ losses[k] *= self.criterion.weight_dict[k]
262
+ else:
263
+ # remove this loss if not specified in `weight_dict`
264
+ losses.pop(k)
265
+ return losses
266
+ else:
267
+ mask_cls_results = outputs["pred_logits"]
268
+ mask_pred_results = outputs["pred_masks"]
269
+
270
+ align_corners = (images.tensor.shape[-1] % 2 == 1)
271
+ # upsample masks
272
+ mask_pred_results = F.interpolate(
273
+ mask_pred_results,
274
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
275
+ mode="bilinear",
276
+ align_corners=align_corners,
277
+ )
278
+
279
+ del outputs
280
+
281
+ processed_results = []
282
+ for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
283
+ mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
284
+ ):
285
+ height = input_per_image.get("height", image_size[0])
286
+ width = input_per_image.get("width", image_size[1])
287
+ cur_image = input_per_image["image"].to(self.device)
288
+ processed_results.append({})
289
+ scale_factor = max(images.tensor.shape[-2:]) / max(height, width)
290
+ ori_height, ori_width = round(height * scale_factor), round(width * scale_factor)
291
+ mask_pred_result = mask_pred_result[:, :ori_height, :ori_width].expand(1, -1, -1, -1)
292
+ cur_image = cur_image[:, :ori_height, :ori_width].expand(1, -1, -1, -1)
293
+ mask_pred_result = F.interpolate(
294
+ mask_pred_result, size=(height, width), mode="bilinear", align_corners=align_corners
295
+ )[0]
296
+ cur_image = F.interpolate(
297
+ cur_image.float(), size=(height, width), mode="bilinear", align_corners=align_corners
298
+ )[0].to(torch.uint8)
299
+
300
+ if self.sem_seg_postprocess_before_inference:
301
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
302
+
303
+ # semantic segmentation inference
304
+ if self.semantic_on:
305
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
306
+ if not self.sem_seg_postprocess_before_inference:
307
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
308
+ processed_results[-1]["sem_seg"] = r
309
+
310
+ # panoptic segmentation inference
311
+ if self.panoptic_on:
312
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
313
+ processed_results[-1]["panoptic_seg"] = panoptic_r
314
+ processed_results[-1]["original_image"] = cur_image
315
+
316
+ # instance segmentation inference
317
+ if self.instance_on:
318
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result)
319
+ processed_results[-1]["instances"] = instance_r
320
+
321
+ return processed_results
322
+
323
+ def prepare_targets(self, targets, targets_semantic, images):
324
+ new_targets = []
325
+ for targets_per_image, semantic_gt_mask in zip(targets, targets_semantic):
326
+ gt_masks = targets_per_image.gt_masks
327
+ new_targets.append(
328
+ {
329
+ "labels": targets_per_image.gt_classes,
330
+ "masks": gt_masks,
331
+ "semantic_masks": semantic_gt_mask
332
+ }
333
+ )
334
+ return new_targets
335
+
336
+ def semantic_inference(self, mask_cls, mask_pred):
337
+ # For the class probabilities, we exclude the void class, following
338
+ # https://github.com/google-research/deeplab2/blob/main/model/post_processor/max_deeplab.py#L199
339
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
340
+ mask_pred = F.softmax(mask_pred, dim=0)
341
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
342
+ return semseg
343
+
344
+ def panoptic_inference(self, mask_cls, mask_pred):
345
+ # mask_cls: N x C
346
+ # mask_pred: N x H x W
347
+ # some hyper-params
348
+ num_mask_slots = mask_pred.shape[0]
349
+ cls_threshold_thing = self.class_threshold_thing
350
+ cls_threshold_stuff = self.class_threshold_stuff
351
+ object_mask_threshold = self.object_mask_threshold
352
+ overlap_threshold = self.overlap_threshold
353
+ reorder_class_weight = self.reorder_class_weight
354
+ reorder_mask_weight = self.reorder_mask_weight
355
+
356
+ # https://github.com/google-research/deeplab2/blob/main/model/post_processor/max_deeplab.py#L675
357
+ # https://github.com/google-research/deeplab2/blob/main/model/post_processor/max_deeplab.py#L199
358
+ cls_scores, cls_labels = F.softmax(mask_cls, dim=-1)[..., :-1].max(-1) # N
359
+ mask_scores = F.softmax(mask_pred, dim=0)
360
+ binary_masks = mask_scores > object_mask_threshold # N x H x W
361
+ mask_scores_flat = mask_scores.flatten(1) # N x HW
362
+ binary_masks_flat = binary_masks.flatten(1).float() # N x HW
363
+ pixel_number_flat = binary_masks_flat.sum(1) # N
364
+ mask_scores_flat = (mask_scores_flat * binary_masks_flat).sum(1) / torch.clamp(pixel_number_flat, min=1.0) # N
365
+
366
+ reorder_score = (cls_scores ** reorder_class_weight) * (mask_scores_flat ** reorder_mask_weight) # N
367
+ reorder_indices = torch.argsort(reorder_score, dim=-1, descending=True)
368
+
369
+ panoptic_seg = torch.zeros((mask_pred.shape[1], mask_pred.shape[2]),
370
+ dtype=torch.int32, device=mask_pred.device)
371
+ segments_info = []
372
+
373
+ current_segment_id = 0
374
+ stuff_memory_list = {}
375
+ for i in range(num_mask_slots):
376
+ cur_idx = reorder_indices[i].item() # 1
377
+ cur_binary_mask = binary_masks[cur_idx] # H x W
378
+ cur_cls_score = cls_scores[cur_idx].item() # 1
379
+ cur_cls_label = cls_labels[cur_idx].item() # 1
380
+ is_thing = cur_cls_label in self.metadata.thing_dataset_id_to_contiguous_id.values()
381
+ is_confident = (is_thing and cur_cls_score > cls_threshold_thing) or (
382
+ (not is_thing) and cur_cls_score > cls_threshold_stuff)
383
+
384
+ original_pixel_number = cur_binary_mask.float().sum()
385
+ new_binary_mask = torch.logical_and(cur_binary_mask, (panoptic_seg == 0))
386
+ new_pixel_number = new_binary_mask.float().sum()
387
+ is_not_overlap_too_much = new_pixel_number > (original_pixel_number * overlap_threshold)
388
+
389
+ if is_confident and is_not_overlap_too_much:
390
+ # merge stuff regions
391
+ if not is_thing:
392
+ if int(cur_cls_label) in stuff_memory_list.keys():
393
+ panoptic_seg[new_binary_mask] = stuff_memory_list[int(cur_cls_label)]
394
+ continue
395
+ else:
396
+ stuff_memory_list[int(cur_cls_label)] = current_segment_id + 1
397
+
398
+ current_segment_id += 1
399
+ panoptic_seg[new_binary_mask] = current_segment_id
400
+
401
+ segments_info.append(
402
+ {
403
+ "id": current_segment_id,
404
+ "isthing": bool(is_thing),
405
+ "category_id": int(cur_cls_label),
406
+ }
407
+ )
408
+
409
+ return panoptic_seg, segments_info
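# Illustrative sketch (not part of the committed file): how the reorder score above ranks mask slots.
# With reorder_class_weight = reorder_mask_weight = 1.0:
#   cls_scores       = torch.tensor([0.9, 0.5])
#   mask_scores_flat = torch.tensor([0.4, 0.6])
#   reorder_score    = cls_scores ** 1.0 * mask_scores_flat ** 1.0   # tensor([0.36, 0.30])
#   torch.argsort(reorder_score, descending=True)                    # tensor([0, 1])
# so the slot with the higher combined confidence claims overlapping pixels first in the greedy loop.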
410
+
411
+
412
+ def instance_inference(self, mask_cls, mask_pred):
413
+ # mask_pred is already processed to have the same shape as original input
414
+ image_size = mask_pred.shape[-2:]
415
+
416
+ mask_pred = mask_pred.softmax(dim=0)
417
+ # [Q, K]
418
+ scores = F.softmax(mask_cls[:, :-1], dim=-1)
419
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
420
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
421
+ labels_per_image = labels[topk_indices]
422
+
423
+ topk_indices = topk_indices // self.sem_seg_head.num_classes
424
+ mask_pred = mask_pred[topk_indices]
425
+
426
+ # if this is panoptic segmentation, we only keep the "thing" classes
427
+ if self.panoptic_on:
428
+ keep = torch.zeros_like(scores_per_image).bool()
429
+ for i, lab in enumerate(labels_per_image):
430
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
431
+
432
+ scores_per_image = scores_per_image[keep]
433
+ labels_per_image = labels_per_image[keep]
434
+ mask_pred = mask_pred[keep]
435
+
436
+ result = Instances(image_size)
437
+ result.pred_masks = (mask_pred > self.object_mask_threshold).float()
438
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
439
+ # Uncomment the following to get boxes from masks (this is slow)
440
+ # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
441
+
442
+ # calculate average mask prob
443
+ mask_scores_per_image = (mask_pred.flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
444
+ result.scores = scores_per_image * mask_scores_per_image
445
+ result.pred_classes = labels_per_image
446
+ return result
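A minimal, self-contained sketch (illustrative only, not part of the committed file; it assumes just PyTorch and made-up sizes) of the flattened top-k trick used in instance_inference above, showing how `topk_indices // num_classes` recovers the query index and `labels[topk_indices]` the class label:

import torch

num_queries, num_classes, k = 3, 4, 2                      # hypothetical sizes
scores = torch.rand(num_queries, num_classes)               # stand-in for F.softmax(mask_cls[:, :-1], dim=-1)
labels = torch.arange(num_classes).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)

topk_scores, topk_indices = scores.flatten(0, 1).topk(k, sorted=False)
topk_labels = labels[topk_indices]                          # class id of each kept prediction
topk_queries = topk_indices // num_classes                  # which mask slot each prediction came from
print(topk_labels, topk_queries)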
kmax_deeplab/modeling/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .backbone.convnext import D2ConvNeXt
2
+ from .backbone.resnet import custom_bn_build_resnet_backbone
3
+ from .pixel_decoder.kmax_pixel_decoder import kMaXPixelDecoder
4
+ from .meta_arch.kmax_deeplab_head import kMaXDeepLabHead
kmax_deeplab/modeling/backbone/__init__.py ADDED
File without changes
kmax_deeplab/modeling/backbone/convnext.py ADDED
@@ -0,0 +1,210 @@
1
+ # reference: https://github.com/SHI-Labs/OneFormer/blob/main/oneformer/modeling/backbone/convnext.py
2
+
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from timm.models.layers import DropPath
10
+
11
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
12
+ from torch.cuda.amp import autocast
13
+
14
+
15
+ class Block(nn.Module):
16
+ r""" ConvNeXt Block. There are two equivalent implementations:
17
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
18
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
19
+ We use (2) as we find it slightly faster in PyTorch
20
+
21
+ Args:
22
+ dim (int): Number of input channels.
23
+ drop_path (float): Stochastic depth rate. Default: 0.0
24
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
25
+ """
26
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
27
+ super().__init__()
28
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
29
+ self.norm = LayerNorm(dim, eps=1e-6)
30
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
31
+ self.act = nn.GELU()
32
+ self.pwconv2 = nn.Linear(4 * dim, dim)
33
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
34
+ requires_grad=True) if layer_scale_init_value > 0 else None
35
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
36
+
37
+ def forward(self, x):
38
+ input = x
39
+ x = self.dwconv(x)
40
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
41
+ x = self.norm(x)
42
+ x = self.pwconv1(x)
43
+ x = self.act(x)
44
+ x = self.pwconv2(x)
45
+ if self.gamma is not None:
46
+ x = self.gamma * x
47
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
48
+
49
+ x = input + self.drop_path(x)
50
+ return x
51
+
52
+ class LayerNorm(nn.Module):
53
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
54
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
55
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
56
+ with shape (batch_size, channels, height, width).
57
+ """
58
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
59
+ super().__init__()
60
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
61
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
62
+ self.eps = eps
63
+ self.data_format = data_format
64
+ if self.data_format not in ["channels_last", "channels_first"]:
65
+ raise NotImplementedError
66
+ self.normalized_shape = (normalized_shape, )
67
+
68
+ def forward(self, x):
69
+ with autocast(enabled=False):
70
+ x = x.float()
71
+ if self.data_format == "channels_last":
72
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
73
+ elif self.data_format == "channels_first":
74
+ u = x.mean(1, keepdim=True)
75
+ s = (x - u).pow(2).mean(1, keepdim=True)
76
+ x = (x - u) / torch.sqrt(s + self.eps)
77
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
78
+ return x
79
+
80
+
81
+ class ConvNeXt(nn.Module):
82
+ r""" ConvNeXt
83
+ A PyTorch impl of : `A ConvNet for the 2020s` -
84
+ https://arxiv.org/pdf/2201.03545.pdf
85
+
86
+ Args:
87
+ in_chans (int): Number of input image channels. Default: 3
88
+ num_classes (int): Number of classes for classification head. Default: 1000
89
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
90
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
91
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
92
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
93
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
94
+ """
95
+ def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
96
+ drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3],
97
+ ):
98
+ super().__init__()
99
+
100
+ self.num_features = dims
101
+
102
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
103
+ stem = nn.Sequential(
104
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
105
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
106
+ )
107
+ self.downsample_layers.append(stem)
108
+ for i in range(3):
109
+ downsample_layer = nn.Sequential(
110
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
111
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
112
+ )
113
+ self.downsample_layers.append(downsample_layer)
114
+
115
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
116
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
117
+ cur = 0
118
+ for i in range(4):
119
+ stage = nn.Sequential(
120
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
121
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
122
+ )
123
+ self.stages.append(stage)
124
+ cur += depths[i]
125
+
126
+ self.out_indices = out_indices
127
+
128
+ def forward_features(self, x):
129
+ outs = {}
130
+ for i in range(4):
131
+ # We add zero padding here for downstream tasks.
132
+ # ref: https://github.com/google-research/deeplab2/blob/main/model/pixel_encoder/convnext.py#L128
133
+ if i == 0:
134
+ x = F.pad(x, (1, 2, 1, 2, 0, 0, 0, 0), "constant", 0)
135
+ else:
136
+ x = F.pad(x, (0, 1, 0, 1, 0, 0, 0, 0), "constant", 0)
137
+ x = self.downsample_layers[i](x)
138
+ x = self.stages[i](x)
139
+ if i in self.out_indices:
140
+ outs["res{}".format(i + 2)] = x
141
+
142
+ return outs
143
+
144
+ def forward(self, x):
145
+ x = self.forward_features(x)
146
+ return x
147
+
148
+ @BACKBONE_REGISTRY.register()
149
+ class D2ConvNeXt(ConvNeXt, Backbone):
150
+ def __init__(self, cfg, input_shape):
151
+
152
+ in_chans = cfg.MODEL.CONVNEXT.IN_CHANNELS
153
+ depths = cfg.MODEL.CONVNEXT.DEPTHS
154
+ dims = cfg.MODEL.CONVNEXT.DIMS
155
+ drop_path_rate = cfg.MODEL.CONVNEXT.DROP_PATH_RATE
156
+ layer_scale_init_value = cfg.MODEL.CONVNEXT.LSIT
157
+ out_indices = cfg.MODEL.CONVNEXT.OUT_INDICES
158
+
159
+ super().__init__(
160
+ in_chans=in_chans,
161
+ depths=depths,
162
+ dims=dims,
163
+ drop_path_rate=drop_path_rate,
164
+ layer_scale_init_value=layer_scale_init_value,
165
+ out_indices=out_indices,
166
+ )
167
+
168
+ self._out_features = cfg.MODEL.CONVNEXT.OUT_FEATURES
169
+
170
+ self._out_feature_strides = {
171
+ "res2": 4,
172
+ "res3": 8,
173
+ "res4": 16,
174
+ "res5": 32,
175
+ }
176
+ self._out_feature_channels = {
177
+ "res2": self.num_features[0],
178
+ "res3": self.num_features[1],
179
+ "res4": self.num_features[2],
180
+ "res5": self.num_features[3],
181
+ }
182
+
183
+ def forward(self, x):
184
+ """
185
+ Args:
186
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
187
+ Returns:
188
+ dict[str->Tensor]: names and the corresponding features
189
+ """
190
+ assert (
191
+ x.dim() == 4
192
+ ), f"ConvNeXt takes an input of shape (N, C, H, W). Got {x.shape} instead!"
193
+ outputs = {}
194
+ y = super().forward(x)
195
+ for k in y.keys():
196
+ if k in self._out_features:
197
+ outputs[k] = y[k]
198
+ return outputs
199
+
200
+ def output_shape(self):
201
+ return {
202
+ name: ShapeSpec(
203
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
204
+ )
205
+ for name in self._out_features
206
+ }
207
+
208
+ @property
209
+ def size_divisibility(self):
210
+ return -1
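A quick usage sketch for the backbone defined above (illustrative only, not part of the committed file; it assumes the module path added in this commit and that timm/detectron2 from requirements.txt are installed):

import torch
from kmax_deeplab.modeling.backbone.convnext import ConvNeXt

backbone = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])   # ConvNeXt-Tiny sizes
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
for name, f in feats.items():
    print(name, tuple(f.shape))   # res2..res5 at roughly strides 4/8/16/32 (the extra zero padding shifts sizes slightly)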
kmax_deeplab/modeling/backbone/resnet.py ADDED
@@ -0,0 +1,697 @@
1
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py
2
+ # Modified by Qihang Yu
3
+
4
+ import numpy as np
5
+ import fvcore.nn.weight_init as weight_init
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from detectron2.layers import (
11
+ CNNBlockBase,
12
+ Conv2d,
13
+ DeformConv,
14
+ ModulatedDeformConv,
15
+ #ShapeSpec,
16
+ #get_norm,
17
+ )
18
+
19
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
20
+
21
+ from ..pixel_decoder.kmax_pixel_decoder import get_norm
22
+
23
+ __all__ = [
24
+ "ResNetBlockBase",
25
+ "BasicBlock",
26
+ "BottleneckBlock",
27
+ "DeformBottleneckBlock",
28
+ "BasicStem",
29
+ "ResNet",
30
+ "make_stage",
31
+ "custom_bn_build_resnet_backbone",
32
+ ]
33
+
34
+
35
+ class BasicBlock(CNNBlockBase):
36
+ """
37
+ The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`,
38
+ with two 3x3 conv layers and a projection shortcut if needed.
39
+ """
40
+
41
+ def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
42
+ """
43
+ Args:
44
+ in_channels (int): Number of input channels.
45
+ out_channels (int): Number of output channels.
46
+ stride (int): Stride for the first conv.
47
+ norm (str or callable): normalization for all conv layers.
48
+ See :func:`layers.get_norm` for supported format.
49
+ """
50
+ super().__init__(in_channels, out_channels, stride)
51
+
52
+ if in_channels != out_channels:
53
+ self.shortcut = Conv2d(
54
+ in_channels,
55
+ out_channels,
56
+ kernel_size=1,
57
+ stride=stride,
58
+ bias=False,
59
+ norm=get_norm(norm, out_channels),
60
+ )
61
+ else:
62
+ self.shortcut = None
63
+
64
+ self.conv1 = Conv2d(
65
+ in_channels,
66
+ out_channels,
67
+ kernel_size=3,
68
+ stride=stride,
69
+ padding=1,
70
+ bias=False,
71
+ norm=get_norm(norm, out_channels),
72
+ )
73
+
74
+ self.conv2 = Conv2d(
75
+ out_channels,
76
+ out_channels,
77
+ kernel_size=3,
78
+ stride=1,
79
+ padding=1,
80
+ bias=False,
81
+ norm=get_norm(norm, out_channels),
82
+ )
83
+
84
+ for layer in [self.conv1, self.conv2, self.shortcut]:
85
+ if layer is not None: # shortcut can be None
86
+ weight_init.c2_msra_fill(layer)
87
+
88
+ def forward(self, x):
89
+ out = self.conv1(x)
90
+ out = F.relu_(out)
91
+ out = self.conv2(out)
92
+
93
+ if self.shortcut is not None:
94
+ shortcut = self.shortcut(x)
95
+ else:
96
+ shortcut = x
97
+
98
+ out += shortcut
99
+ out = F.relu_(out)
100
+ return out
101
+
102
+
103
+ class BottleneckBlock(CNNBlockBase):
104
+ """
105
+ The standard bottleneck residual block used by ResNet-50, 101 and 152
106
+ defined in :paper:`ResNet`. It contains 3 conv layers with kernels
107
+ 1x1, 3x3, 1x1, and a projection shortcut if needed.
108
+ """
109
+
110
+ def __init__(
111
+ self,
112
+ in_channels,
113
+ out_channels,
114
+ *,
115
+ bottleneck_channels,
116
+ stride=1,
117
+ num_groups=1,
118
+ norm="BN",
119
+ stride_in_1x1=False,
120
+ dilation=1,
121
+ ):
122
+ """
123
+ Args:
124
+ bottleneck_channels (int): number of output channels for the 3x3
125
+ "bottleneck" conv layers.
126
+ num_groups (int): number of groups for the 3x3 conv layer.
127
+ norm (str or callable): normalization for all conv layers.
128
+ See :func:`layers.get_norm` for supported format.
129
+ stride_in_1x1 (bool): when stride>1, whether to put stride in the
130
+ first 1x1 convolution or the bottleneck 3x3 convolution.
131
+ dilation (int): the dilation rate of the 3x3 conv layer.
132
+ """
133
+ super().__init__(in_channels, out_channels, stride)
134
+
135
+ if in_channels != out_channels:
136
+ self.shortcut = Conv2d(
137
+ in_channels,
138
+ out_channels,
139
+ kernel_size=1,
140
+ stride=stride,
141
+ bias=False,
142
+ norm=get_norm(norm, out_channels),
143
+ )
144
+ else:
145
+ self.shortcut = None
146
+
147
+ # The original MSRA ResNet models have stride in the first 1x1 conv
148
+ # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
149
+ # stride in the 3x3 conv
150
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
151
+
152
+ self.conv1 = Conv2d(
153
+ in_channels,
154
+ bottleneck_channels,
155
+ kernel_size=1,
156
+ stride=stride_1x1,
157
+ bias=False,
158
+ norm=get_norm(norm, bottleneck_channels),
159
+ )
160
+
161
+ self.conv2 = Conv2d(
162
+ bottleneck_channels,
163
+ bottleneck_channels,
164
+ kernel_size=3,
165
+ stride=stride_3x3,
166
+ padding=1 * dilation,
167
+ bias=False,
168
+ groups=num_groups,
169
+ dilation=dilation,
170
+ norm=get_norm(norm, bottleneck_channels),
171
+ )
172
+
173
+ self.conv3 = Conv2d(
174
+ bottleneck_channels,
175
+ out_channels,
176
+ kernel_size=1,
177
+ bias=False,
178
+ norm=get_norm(norm, out_channels),
179
+ )
180
+
181
+ for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
182
+ if layer is not None: # shortcut can be None
183
+ weight_init.c2_msra_fill(layer)
184
+
185
+ # Zero-initialize the last normalization in each residual branch,
186
+ # so that at the beginning, the residual branch starts with zeros,
187
+ # and each residual block behaves like an identity.
188
+ # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
189
+ # "For BN layers, the learnable scaling coefficient γ is initialized
190
+ # to be 1, except for each residual block's last BN
191
+ # where γ is initialized to be 0."
192
+
193
+ # nn.init.constant_(self.conv3.norm.weight, 0)
194
+ # TODO this somehow hurts performance when training GN models from scratch.
195
+ # Add it as an option when we need to use this code to train a backbone.
196
+
197
+ def forward(self, x):
198
+ out = self.conv1(x)
199
+ out = F.relu_(out)
200
+
201
+ out = self.conv2(out)
202
+ out = F.relu_(out)
203
+
204
+ out = self.conv3(out)
205
+
206
+ if self.shortcut is not None:
207
+ shortcut = self.shortcut(x)
208
+ else:
209
+ shortcut = x
210
+
211
+ out += shortcut
212
+ out = F.relu_(out)
213
+ return out
214
+
215
+
216
+ class DeformBottleneckBlock(CNNBlockBase):
217
+ """
218
+ Similar to :class:`BottleneckBlock`, but with :paper:`deformable conv <deformconv>`
219
+ in the 3x3 convolution.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ in_channels,
225
+ out_channels,
226
+ *,
227
+ bottleneck_channels,
228
+ stride=1,
229
+ num_groups=1,
230
+ norm="BN",
231
+ stride_in_1x1=False,
232
+ dilation=1,
233
+ deform_modulated=False,
234
+ deform_num_groups=1,
235
+ ):
236
+ super().__init__(in_channels, out_channels, stride)
237
+ self.deform_modulated = deform_modulated
238
+
239
+ if in_channels != out_channels:
240
+ self.shortcut = Conv2d(
241
+ in_channels,
242
+ out_channels,
243
+ kernel_size=1,
244
+ stride=stride,
245
+ bias=False,
246
+ norm=get_norm(norm, out_channels),
247
+ )
248
+ else:
249
+ self.shortcut = None
250
+
251
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
252
+
253
+ self.conv1 = Conv2d(
254
+ in_channels,
255
+ bottleneck_channels,
256
+ kernel_size=1,
257
+ stride=stride_1x1,
258
+ bias=False,
259
+ norm=get_norm(norm, bottleneck_channels),
260
+ )
261
+
262
+ if deform_modulated:
263
+ deform_conv_op = ModulatedDeformConv
264
+ # offset channels are 2 (or 3 if modulated) * kernel_size * kernel_size
265
+ offset_channels = 27
266
+ else:
267
+ deform_conv_op = DeformConv
268
+ offset_channels = 18
269
+
270
+ self.conv2_offset = Conv2d(
271
+ bottleneck_channels,
272
+ offset_channels * deform_num_groups,
273
+ kernel_size=3,
274
+ stride=stride_3x3,
275
+ padding=1 * dilation,
276
+ dilation=dilation,
277
+ )
278
+ self.conv2 = deform_conv_op(
279
+ bottleneck_channels,
280
+ bottleneck_channels,
281
+ kernel_size=3,
282
+ stride=stride_3x3,
283
+ padding=1 * dilation,
284
+ bias=False,
285
+ groups=num_groups,
286
+ dilation=dilation,
287
+ deformable_groups=deform_num_groups,
288
+ norm=get_norm(norm, bottleneck_channels),
289
+ )
290
+
291
+ self.conv3 = Conv2d(
292
+ bottleneck_channels,
293
+ out_channels,
294
+ kernel_size=1,
295
+ bias=False,
296
+ norm=get_norm(norm, out_channels),
297
+ )
298
+
299
+ for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
300
+ if layer is not None: # shortcut can be None
301
+ weight_init.c2_msra_fill(layer)
302
+
303
+ nn.init.constant_(self.conv2_offset.weight, 0)
304
+ nn.init.constant_(self.conv2_offset.bias, 0)
305
+
306
+ def forward(self, x):
307
+ out = self.conv1(x)
308
+ out = F.relu_(out)
309
+
310
+ if self.deform_modulated:
311
+ offset_mask = self.conv2_offset(out)
312
+ offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
313
+ offset = torch.cat((offset_x, offset_y), dim=1)
314
+ mask = mask.sigmoid()
315
+ out = self.conv2(out, offset, mask)
316
+ else:
317
+ offset = self.conv2_offset(out)
318
+ out = self.conv2(out, offset)
319
+ out = F.relu_(out)
320
+
321
+ out = self.conv3(out)
322
+
323
+ if self.shortcut is not None:
324
+ shortcut = self.shortcut(x)
325
+ else:
326
+ shortcut = x
327
+
328
+ out += shortcut
329
+ out = F.relu_(out)
330
+ return out
331
+
332
+
333
+ class BasicStem(CNNBlockBase):
334
+ """
335
+ The standard ResNet stem (layers before the first residual block),
336
+ with a conv, relu and max_pool.
337
+ """
338
+
339
+ def __init__(self, in_channels=3, out_channels=64, norm="BN"):
340
+ """
341
+ Args:
342
+ norm (str or callable): norm after the first conv layer.
343
+ See :func:`layers.get_norm` for supported format.
344
+ """
345
+ super().__init__(in_channels, out_channels, 4)
346
+ self.in_channels = in_channels
347
+ self.conv1 = Conv2d(
348
+ in_channels,
349
+ out_channels,
350
+ kernel_size=7,
351
+ stride=2,
352
+ padding=3,
353
+ bias=False,
354
+ norm=get_norm(norm, out_channels),
355
+ )
356
+ weight_init.c2_msra_fill(self.conv1)
357
+
358
+ def forward(self, x):
359
+ x = self.conv1(x)
360
+ x = F.relu_(x)
361
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
362
+ return x
363
+
364
+
365
+ class ResNet(Backbone):
366
+ """
367
+ Implement :paper:`ResNet`.
368
+ """
369
+
370
+ def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
371
+ """
372
+ Args:
373
+ stem (nn.Module): a stem module
374
+ stages (list[list[CNNBlockBase]]): several (typically 4) stages,
375
+ each contains multiple :class:`CNNBlockBase`.
376
+ num_classes (None or int): if None, will not perform classification.
377
+ Otherwise, will create a linear layer.
378
+ out_features (list[str]): name of the layers whose outputs should
379
+ be returned in forward. Can be anything in "stem", "linear", or "res2" ...
380
+ If None, will return the output of the last layer.
381
+ freeze_at (int): The number of stages at the beginning to freeze.
382
+ see :meth:`freeze` for detailed explanation.
383
+ """
384
+ super().__init__()
385
+ self.stem = stem
386
+ self.num_classes = num_classes
387
+
388
+ current_stride = self.stem.stride
389
+ self._out_feature_strides = {"stem": current_stride}
390
+ self._out_feature_channels = {"stem": self.stem.out_channels}
391
+
392
+ self.stage_names, self.stages = [], []
393
+
394
+ if out_features is not None:
395
+ # Avoid keeping unused layers in this module. They consume extra memory
396
+ # and may cause allreduce to fail
397
+ num_stages = max(
398
+ [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
399
+ )
400
+ stages = stages[:num_stages]
401
+ for i, blocks in enumerate(stages):
402
+ assert len(blocks) > 0, len(blocks)
403
+ for block in blocks:
404
+ assert isinstance(block, CNNBlockBase), block
405
+
406
+ name = "res" + str(i + 2)
407
+ stage = nn.Sequential(*blocks)
408
+
409
+ self.add_module(name, stage)
410
+ self.stage_names.append(name)
411
+ self.stages.append(stage)
412
+
413
+ self._out_feature_strides[name] = current_stride = int(
414
+ current_stride * np.prod([k.stride for k in blocks])
415
+ )
416
+ self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
417
+ self.stage_names = tuple(self.stage_names) # Make it static for scripting
418
+
419
+ if num_classes is not None:
420
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
421
+ self.linear = nn.Linear(curr_channels, num_classes)
422
+
423
+ # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
424
+ # "The 1000-way fully-connected layer is initialized by
425
+ # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
426
+ nn.init.normal_(self.linear.weight, std=0.01)
427
+ name = "linear"
428
+
429
+ if out_features is None:
430
+ out_features = [name]
431
+ self._out_features = out_features
432
+ assert len(self._out_features)
433
+ children = [x[0] for x in self.named_children()]
434
+ for out_feature in self._out_features:
435
+ assert out_feature in children, "Available children: {}".format(", ".join(children))
436
+ self.freeze(freeze_at)
437
+
438
+ def forward(self, x):
439
+ """
440
+ Args:
441
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
442
+
443
+ Returns:
444
+ dict[str->Tensor]: names and the corresponding features
445
+ """
446
+ assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
447
+ outputs = {}
448
+ x = self.stem(x)
449
+ if "stem" in self._out_features:
450
+ outputs["stem"] = x
451
+ for name, stage in zip(self.stage_names, self.stages):
452
+ x = stage(x)
453
+ if name in self._out_features:
454
+ outputs[name] = x
455
+ if self.num_classes is not None:
456
+ x = self.avgpool(x)
457
+ x = torch.flatten(x, 1)
458
+ x = self.linear(x)
459
+ if "linear" in self._out_features:
460
+ outputs["linear"] = x
461
+ return outputs
462
+
463
+ def output_shape(self):
464
+ return {
465
+ name: ShapeSpec(
466
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
467
+ )
468
+ for name in self._out_features
469
+ }
470
+
471
+ def freeze(self, freeze_at=0):
472
+ """
473
+ Freeze the first several stages of the ResNet. Commonly used in
474
+ fine-tuning.
475
+
476
+ Layers that produce the same feature map spatial size are defined as one
477
+ "stage" by :paper:`FPN`.
478
+
479
+ Args:
480
+ freeze_at (int): number of stages to freeze.
481
+ `1` means freezing the stem. `2` means freezing the stem and
482
+ one residual stage, etc.
483
+
484
+ Returns:
485
+ nn.Module: this ResNet itself
486
+ """
487
+ if freeze_at >= 1:
488
+ self.stem.freeze()
489
+ for idx, stage in enumerate(self.stages, start=2):
490
+ if freeze_at >= idx:
491
+ for block in stage.children():
492
+ block.freeze()
493
+ return self
494
+
495
+ @staticmethod
496
+ def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
497
+ """
498
+ Create a list of blocks of the same type that forms one ResNet stage.
499
+
500
+ Args:
501
+ block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
502
+ stage. A module of this type must not change spatial resolution of inputs unless its
503
+ stride != 1.
504
+ num_blocks (int): number of blocks in this stage
505
+ in_channels (int): input channels of the entire stage.
506
+ out_channels (int): output channels of **every block** in the stage.
507
+ kwargs: other arguments passed to the constructor of
508
+ `block_class`. If the argument name is "xx_per_block", the
509
+ argument is a list of values to be passed to each block in the
510
+ stage. Otherwise, the same argument is passed to every block
511
+ in the stage.
512
+
513
+ Returns:
514
+ list[CNNBlockBase]: a list of block module.
515
+
516
+ Examples:
517
+ ::
518
+ stage = ResNet.make_stage(
519
+ BottleneckBlock, 3, in_channels=16, out_channels=64,
520
+ bottleneck_channels=16, num_groups=1,
521
+ stride_per_block=[2, 1, 1],
522
+ dilations_per_block=[1, 1, 2]
523
+ )
524
+
525
+ Usually, layers that produce the same feature map spatial size are defined as one
526
+ "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
527
+ all be 1.
528
+ """
529
+ blocks = []
530
+ for i in range(num_blocks):
531
+ curr_kwargs = {}
532
+ for k, v in kwargs.items():
533
+ if k.endswith("_per_block"):
534
+ assert len(v) == num_blocks, (
535
+ f"Argument '{k}' of make_stage should have the "
536
+ f"same length as num_blocks={num_blocks}."
537
+ )
538
+ newk = k[: -len("_per_block")]
539
+ assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
540
+ curr_kwargs[newk] = v[i]
541
+ else:
542
+ curr_kwargs[k] = v
543
+
544
+ blocks.append(
545
+ block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
546
+ )
547
+ in_channels = out_channels
548
+ return blocks
549
+
550
+ @staticmethod
551
+ def make_default_stages(depth, block_class=None, **kwargs):
552
+ """
553
+ Create a list of ResNet stages from a pre-defined depth (one of 18, 34, 50, 101, 152).
554
+ If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
555
+ instead for fine-grained customization.
556
+
557
+ Args:
558
+ depth (int): depth of ResNet
559
+ block_class (type): the CNN block class. Has to accept
560
+ `bottleneck_channels` argument for depth > 50.
561
+ By default it is BasicBlock or BottleneckBlock, based on the
562
+ depth.
563
+ kwargs:
564
+ other arguments to pass to `make_stage`. Should not contain
565
+ stride and channels, as they are predefined for each depth.
566
+
567
+ Returns:
568
+ list[list[CNNBlockBase]]: modules in all stages; see arguments of
569
+ :class:`ResNet.__init__`.
570
+ """
571
+ num_blocks_per_stage = {
572
+ 18: [2, 2, 2, 2],
573
+ 34: [3, 4, 6, 3],
574
+ 50: [3, 4, 6, 3],
575
+ 101: [3, 4, 23, 3],
576
+ 152: [3, 8, 36, 3],
577
+ }[depth]
578
+ if block_class is None:
579
+ block_class = BasicBlock if depth < 50 else BottleneckBlock
580
+ if depth < 50:
581
+ in_channels = [64, 64, 128, 256]
582
+ out_channels = [64, 128, 256, 512]
583
+ else:
584
+ in_channels = [64, 256, 512, 1024]
585
+ out_channels = [256, 512, 1024, 2048]
586
+ ret = []
587
+ for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
588
+ if depth >= 50:
589
+ kwargs["bottleneck_channels"] = o // 4
590
+ ret.append(
591
+ ResNet.make_stage(
592
+ block_class=block_class,
593
+ num_blocks=n,
594
+ stride_per_block=[s] + [1] * (n - 1),
595
+ in_channels=i,
596
+ out_channels=o,
597
+ **kwargs,
598
+ )
599
+ )
600
+ return ret
601
+
602
+
603
+ ResNetBlockBase = CNNBlockBase
604
+ """
605
+ Alias for backward compatibility.
606
+ """
607
+
608
+
609
+ def make_stage(*args, **kwargs):
610
+ """
611
+ Deprecated alias for backward compatibility.
612
+ """
613
+ return ResNet.make_stage(*args, **kwargs)
614
+
615
+
616
+ @BACKBONE_REGISTRY.register()
617
+ def custom_bn_build_resnet_backbone(cfg, input_shape):
618
+ """
619
+ Create a ResNet instance from config.
620
+
621
+ Returns:
622
+ ResNet: a :class:`ResNet` instance.
623
+ """
624
+ # need registration of new blocks/stems?
625
+ norm = cfg.MODEL.RESNETS.NORM
626
+ stem = BasicStem(
627
+ in_channels=input_shape.channels,
628
+ out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
629
+ norm=norm,
630
+ )
631
+
632
+ # fmt: off
633
+ freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
634
+ out_features = cfg.MODEL.RESNETS.OUT_FEATURES
635
+ depth = cfg.MODEL.RESNETS.DEPTH
636
+ num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
637
+ width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
638
+ bottleneck_channels = num_groups * width_per_group
639
+ in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
640
+ out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
641
+ stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
642
+ res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
643
+ deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
644
+ deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
645
+ deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
646
+ # fmt: on
647
+ assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
648
+
649
+ num_blocks_per_stage = {
650
+ 18: [2, 2, 2, 2],
651
+ 34: [3, 4, 6, 3],
652
+ 50: [3, 4, 6, 3],
653
+ 101: [3, 4, 23, 3],
654
+ 152: [3, 8, 36, 3],
655
+ }[depth]
656
+
657
+ if depth in [18, 34]:
658
+ assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
659
+ assert not any(
660
+ deform_on_per_stage
661
+ ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
662
+ assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
663
+ assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"
664
+
665
+ stages = []
666
+
667
+ for idx, stage_idx in enumerate(range(2, 6)):
668
+ # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
669
+ dilation = res5_dilation if stage_idx == 5 else 1
670
+ first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
671
+ stage_kargs = {
672
+ "num_blocks": num_blocks_per_stage[idx],
673
+ "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1),
674
+ "in_channels": in_channels,
675
+ "out_channels": out_channels,
676
+ "norm": norm,
677
+ }
678
+ # Use BasicBlock for R18 and R34.
679
+ if depth in [18, 34]:
680
+ stage_kargs["block_class"] = BasicBlock
681
+ else:
682
+ stage_kargs["bottleneck_channels"] = bottleneck_channels
683
+ stage_kargs["stride_in_1x1"] = stride_in_1x1
684
+ stage_kargs["dilation"] = dilation
685
+ stage_kargs["num_groups"] = num_groups
686
+ if deform_on_per_stage[idx]:
687
+ stage_kargs["block_class"] = DeformBottleneckBlock
688
+ stage_kargs["deform_modulated"] = deform_modulated
689
+ stage_kargs["deform_num_groups"] = deform_num_groups
690
+ else:
691
+ stage_kargs["block_class"] = BottleneckBlock
692
+ blocks = ResNet.make_stage(**stage_kargs)
693
+ in_channels = out_channels
694
+ out_channels *= 2
695
+ bottleneck_channels *= 2
696
+ stages.append(blocks)
697
+ return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)
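A stand-alone trace (illustrative only, not part of the committed file; it assumes the default depth=50, STEM_OUT_CHANNELS=64, RES2_OUT_CHANNELS=256, NUM_GROUPS=1, WIDTH_PER_GROUP=64 config) of how the loop above grows channels and strides per stage:

num_blocks_per_stage = [3, 4, 6, 3]            # depth = 50
in_channels, out_channels, bottleneck_channels = 64, 256, 64
for idx, stage_idx in enumerate(range(2, 6)):
    first_stride = 1 if idx == 0 else 2
    print(f"res{stage_idx}: blocks={num_blocks_per_stage[idx]}, in={in_channels}, "
          f"out={out_channels}, bottleneck={bottleneck_channels}, first_stride={first_stride}")
    in_channels = out_channels
    out_channels *= 2
    bottleneck_channels *= 2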
kmax_deeplab/modeling/criterion.py ADDED
@@ -0,0 +1,432 @@
1
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
2
+ # Reference: https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py
3
+ # Modified by Qihang Yu
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ _SOFTMAX_MASKING_CONSTANT = -99999.0
10
+
11
+ # https://www.tensorflow.org/api_docs/python/tf/math/divide_no_nan
12
+ def divide_no_nan(x: torch.Tensor, y: torch.Tensor):
13
+ return torch.nan_to_num(x / y, nan=0.0, posinf=0.0, neginf=0.0)
14
+
15
+
16
+ # https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L393
17
+ def focal_cross_entropy_loss(
18
+ pred: torch.Tensor,
19
+ gt: torch.Tensor,
20
+ weight: torch.Tensor, # This is for PQ-loss weighting
21
+ focal_loss_alpha: float = 0.75,
22
+ focal_loss_gamma: float = 0.0,
23
+ background_channel_index: int = -1):
24
+ """
25
+ pred: B x N x C
26
+ gt: B x N
27
+ weight: B x N
28
+ """
29
+ pred = pred.transpose(1, 2) # B x C x N
30
+ gt = F.one_hot(gt, num_classes=pred.shape[1]).transpose(1, 2).to(pred) # B x C x N
31
+ loss = F.cross_entropy(pred, gt, reduction="none") # B x N
32
+ if focal_loss_gamma == 0.0:
33
+ focal_loss = loss
34
+ else:
35
+ pred = F.softmax(pred, dim=1) # B x C x N
36
+ pt = (pred * gt).sum(1) # B x N
37
+ focal_loss = torch.pow(1.0 - pt, focal_loss_gamma) * loss # B x N
38
+
39
+ if focal_loss_alpha >= 0:
40
+ alpha_weights = (
41
+ focal_loss_alpha * (1.0 - gt[:, background_channel_index])
42
+ + (1 - focal_loss_alpha) * gt[:, background_channel_index]) # B x N
43
+ focal_loss = alpha_weights * focal_loss # B x N
44
+
45
+ focal_loss = focal_loss * weight # B x N
46
+ focal_loss = focal_loss.flatten(1)
47
+ num_non_zero = (focal_loss != 0.0).to(focal_loss).sum(-1) # B
48
+ num_non_zero = torch.clamp(num_non_zero, min=1.0)
49
+ loss_sum_per_sample = focal_loss.sum(-1) # B
50
+ return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1
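# Illustrative only: with the default focal_loss_alpha = 0.75, a prediction slot whose ground-truth class
# is a real category is weighted by 0.75, while a slot assigned to the void/background channel is weighted
# by 0.25, before the per-sample averaging over non-zero entries above.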
51
+
52
+
53
+ # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L50
54
+ def _gumbel_topk_sample(logits: torch.Tensor, k: int):
55
+ """Samples k points from the softmax distribution with Gumbel-Top-k trick."""
56
+ # Note that torch.rand is [0, 1), we need to make it (0, 1) to ensure the log is valid.
57
+ gumbel_noise = torch.rand(size=logits.shape, dtype=logits.dtype, device=logits.device)
58
+ gumbel_noise = -torch.log(-torch.log(gumbel_noise))
59
+ _, indices = torch.topk(logits + gumbel_noise, k)
60
+ return indices
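# Illustrative only: with logits = torch.log(torch.tensor([[1., 1., 2., 4.]])), _gumbel_topk_sample(logits, k=2)
# returns two distinct column indices per row, drawn without replacement with probabilities proportional to
# softmax(logits), so index 3 is the most frequent first pick.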
61
+
62
+
63
+ # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L576
64
+ def pixelwise_insdis_loss(
65
+ pixel_feature: torch.Tensor,
66
+ gt_mask: torch.Tensor,
67
+ sample_temperature: float,
68
+ sample_k: int,
69
+ instance_discrimination_temperature: float,
70
+ pixel_gt_void_mask: torch.Tensor,
71
+ inverse_gt_mask_area: torch.Tensor
72
+ ):
73
+
74
+ # pixel_feature: B x C x H x W
75
+ # gt_mask: B x N x H x W
76
+ pixel_feature = pixel_feature.flatten(2) # B x C x HW
77
+ gt_mask = gt_mask.flatten(2) # B x N x HW
78
+ pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW
79
+ inverse_gt_mask_area = inverse_gt_mask_area.flatten(1) # B x HW
80
+
81
+ sample_logits = torch.log(inverse_gt_mask_area) * sample_temperature # B x HW
82
+ # sample_logits.masked_fill_(pixel_gt_void_mask, float('-inf'))
83
+ sample_logits += pixel_gt_void_mask.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT
84
+
85
+ sample_indices = _gumbel_topk_sample(sample_logits, sample_k) # B x K
86
+ # Sample ground truth one-hot encodings and compute gt_similarity.
87
+ pixel_gt_sampled_feature = torch.gather(gt_mask, dim=2, index=sample_indices.unsqueeze(1).repeat(1, gt_mask.shape[1], 1)) # B x N x K
88
+ sampled_gt_similarity = torch.einsum('bnk,bnj->bkj', pixel_gt_sampled_feature, pixel_gt_sampled_feature) # B x K x K
89
+
90
+ # Normalize the ground truth similarity into a distribution (sum to 1).
91
+ pixel_normalizing_constant = sampled_gt_similarity.sum(dim=1, keepdim=True) # B x 1 x K
92
+ sampled_gt_similarity /= torch.clamp(pixel_normalizing_constant, min=1.0) # B x K x K
93
+
94
+ # Sample predicted features and compute pred_similarity.
95
+ pixel_pred_sampled_feature = torch.gather(pixel_feature, dim=2, index=sample_indices.unsqueeze(1).repeat(1, pixel_feature.shape[1], 1)) # B x C x K
96
+ sampled_pred_similarity = torch.einsum('bck,bcj->bkj', pixel_pred_sampled_feature, pixel_pred_sampled_feature) # B x K x K
97
+ sampled_pred_similarity /= instance_discrimination_temperature # B x K x K
98
+ loss = F.cross_entropy(sampled_pred_similarity, sampled_gt_similarity, reduction="none") # B x K
99
+
100
+ num_non_zero = (loss != 0.0).to(loss).sum(-1) # B
101
+ num_non_zero = torch.clamp(num_non_zero, min=1.0)
102
+ loss_sum_per_sample = loss.sum(-1) # B
103
+ return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1
104
+
105
+
106
+ def aux_semantic_loss(
107
+ pred_semantic_logits: torch.Tensor,
108
+ ground_truth_semantic: torch.Tensor,
109
+ sample_temperature: float,
110
+ sample_k: int,
111
+ pixel_gt_void_mask: torch.Tensor,
112
+ inverse_gt_mask_area: torch.Tensor,
113
+ num_classes: int):
114
+
115
+ pred_semantic_logits = pred_semantic_logits.flatten(2) # B x C x HW
116
+ ground_truth_semantic = ground_truth_semantic.flatten(1) # B x HW
117
+ pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW
118
+ inverse_gt_mask_area = inverse_gt_mask_area.flatten(1) # B x HW
119
+
120
+ sample_logits = torch.log(inverse_gt_mask_area) * sample_temperature # B x HW
121
+ sample_logits += pixel_gt_void_mask.to(sample_logits) * _SOFTMAX_MASKING_CONSTANT
122
+
123
+ sample_indices = _gumbel_topk_sample(sample_logits, sample_k) # B x K
124
+ sampled_ground_truth_semantic = torch.gather(ground_truth_semantic, dim=1, index=sample_indices) # B x K
125
+ sampled_pred_semantic_logits = torch.gather(pred_semantic_logits, dim=2, index=sample_indices.unsqueeze(1).repeat(1, pred_semantic_logits.shape[1], 1)) # B x C x K
126
+ # ignore the class index num_classes.
127
+ keep_mask = (sampled_ground_truth_semantic != num_classes) # B x K
128
+ loss = F.cross_entropy(sampled_pred_semantic_logits, sampled_ground_truth_semantic, ignore_index=num_classes, reduction='none') # B x K
129
+ loss = loss * keep_mask.to(loss)
130
+ num_non_zero = (loss != 0.0).to(loss).sum(-1) # B
131
+ num_non_zero = torch.clamp(num_non_zero, min=1.0)
132
+ loss_sum_per_sample = loss.sum(-1) # B
133
+ return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1
134
+
135
+
136
+ # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L56
137
+ # https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L510
138
+ def dice_loss(
139
+ inputs: torch.Tensor,
140
+ targets: torch.Tensor,
141
+ pixel_gt_void_mask: torch.Tensor,
142
+ matched_cls_prob: torch.Tensor
143
+ ):
144
+ """
145
+ Compute the DICE loss, similar to generalized IOU for masks
146
+ Args:
147
+ inputs: A float tensor of arbitrary shape.
148
+ The predictions for each example.
149
+ targets: A float tensor with the same shape as inputs. Stores the binary
150
+ classification label for each element in inputs
151
+ (0 for the negative class and 1 for the positive class).
152
+ """
153
+ inputs = inputs.softmax(1) # B N HW
154
+ # https://github.com/google-research/deeplab2/blob/main/model/loss/base_loss.py#L111
155
+ inputs = inputs.masked_fill(pixel_gt_void_mask.unsqueeze(1), 0) # remove void pixels.
156
+ smooth = 1.0
157
+ intersection = 2 * (inputs * targets).sum(-1) + smooth # B x N
158
+ denominator = inputs.sum(-1) + targets.sum(-1) + smooth # B x N
159
+ loss = 1.0 - divide_no_nan(intersection, denominator)
160
+ loss *= matched_cls_prob
161
+ # Note: kMaX-DeepLab sum over num_masks and avg over batches. But here batch and num_mask are one
162
+ # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/base_loss.py#L559
163
+ # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L402
164
+ # As the existing of modifer, it equals to multiplier by 0.75
165
+ return (loss.sum(1) * 0.75/128).mean() # sum over masks and mean over batches.
166
+
167
+
168
+ def softmax_ce_loss(
169
+ inputs: torch.Tensor,
170
+ targets: torch.Tensor,
171
+ pixel_gt_void_mask: torch.Tensor,
172
+ ):
173
+ """
174
+ Args:
175
+ inputs: A float tensor of arbitrary shape.
176
+ The predictions for each example.
177
+ targets: A float tensor with the same shape as inputs. Stores the binary
178
+ classification label for each element in inputs
179
+ (0 for the negative class and 1 for the positive class).
180
+ Returns:
181
+ Loss tensor
182
+ """
183
+ loss = F.cross_entropy(inputs, targets, reduction="none") # B x HW
184
+ loss = loss.masked_fill(pixel_gt_void_mask, 0) # remove void pixels.
185
+
186
+ num_non_zero = (loss != 0.0).to(loss).sum(-1) # B
187
+ num_non_zero = torch.clamp(num_non_zero, min=1.0)
188
+ loss_sum_per_sample = loss.sum(-1) # B
189
+ return divide_no_nan(loss_sum_per_sample, num_non_zero).mean() # 1
190
+
191
+
192
+ class SetCriterion(nn.Module):
193
+ """This class computes the loss for DETR.
194
+ The process happens in two steps:
195
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
196
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
197
+ """
198
+
199
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, share_final_matching,
200
+ pixel_insdis_temperature=1.5, pixel_insdis_sample_k=4096,
201
+ aux_semantic_temperature=2.0, aux_semantic_sample_k=4096):
202
+ """Create the criterion.
203
+ Parameters:
204
+ num_classes: number of object categories, omitting the special no-object category
205
+ matcher: module able to compute a matching between targets and proposals
206
+ eos_coef: relative classification weight applied to the no-object category
207
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
208
+ """
209
+ super().__init__()
210
+ self.num_classes = num_classes
211
+ self.matcher = matcher
212
+ self.weight_dict = weight_dict
213
+ self.eos_coef = eos_coef
214
+ self.losses = losses
215
+ self.share_final_matching = share_final_matching
216
+ self.pixel_insdis_temperature = pixel_insdis_temperature
217
+ self.pixel_insdis_sample_k = pixel_insdis_sample_k
218
+ self.aux_semantic_temperature = aux_semantic_temperature
219
+ self.aux_semantic_sample_k = aux_semantic_sample_k
220
+
221
+ def loss_labels(self, outputs, targets):
222
+ """Classification loss (NLL)
223
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
224
+ """
225
+ assert "pred_logits" in outputs
226
+ src_logits = outputs["pred_logits"] # B x N x C
227
+ target_classes = targets["labels"] # B x N
228
+ pq_loss_class_weight = targets["pq_loss_class_weight"]
229
+ losses = {"loss_ce": focal_cross_entropy_loss(src_logits, target_classes, pq_loss_class_weight)}
230
+ return losses
231
+
232
+ def loss_masks(self, outputs, targets):
233
+ """Compute the losses related to the masks: the focal loss and the dice loss.
234
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
235
+ """
236
+ src_masks = outputs["pred_masks"] # B x N x H x W
237
+ target_masks = targets["masks"]
238
+ pq_loss_mask_weight = targets["pq_loss_mask_weight"]
239
+ pixel_gt_void_mask = targets["pixel_gt_void_mask"]
240
+
241
+ src_masks = src_masks.flatten(2) # B x N x HW
242
+ target_masks = target_masks.flatten(2) # B x N x HW
243
+ pixel_gt_void_mask = pixel_gt_void_mask.flatten(1) # B x HW
244
+
245
+ losses = {
246
+ "loss_mask": softmax_ce_loss(src_masks, target_masks, pixel_gt_void_mask),
247
+ "loss_dice": dice_loss(src_masks, target_masks, pixel_gt_void_mask, pq_loss_mask_weight),
248
+ }
249
+
250
+ return losses
251
+
252
+ def loss_pixels(self, outputs, targets):
253
+ pixel_feature = outputs["pixel_feature"]
254
+ target_masks = targets["masks"]
255
+ pixel_gt_void_mask = targets["pixel_gt_void_mask"]
256
+ inverse_gt_mask_area = targets["inverse_gt_mask_area"]
257
+
258
+ losses = {"loss_pixel_insdis": pixelwise_insdis_loss(
259
+ pixel_feature=pixel_feature,
260
+ gt_mask=target_masks,
261
+ sample_temperature=self.pixel_insdis_temperature,
262
+ sample_k=self.pixel_insdis_sample_k,
263
+ instance_discrimination_temperature=0.3,
264
+ pixel_gt_void_mask=pixel_gt_void_mask,
265
+ inverse_gt_mask_area=inverse_gt_mask_area
266
+ )}
267
+
268
+ del target_masks
269
+ return losses
270
+
271
+ def loss_semantic(self, outputs, targets):
272
+ pred_semantic_logits = outputs["aux_semantic_pred"]
273
+ ground_truth_semantic = targets["ground_truth_semantic"]
274
+ pixel_gt_void_mask = targets["pixel_gt_void_mask"].flatten(1)
275
+ inverse_gt_mask_area = targets["inverse_gt_mask_area"].flatten(1)
276
+
277
+ losses = {"loss_aux_semantic": aux_semantic_loss(
278
+ pred_semantic_logits=pred_semantic_logits,
279
+ ground_truth_semantic=ground_truth_semantic,
280
+ sample_temperature=self.aux_semantic_temperature,
281
+ sample_k=self.aux_semantic_sample_k,
282
+ pixel_gt_void_mask=pixel_gt_void_mask,
283
+ inverse_gt_mask_area=inverse_gt_mask_area,
284
+ num_classes=self.num_classes
285
+ )}
286
+ return losses
287
+
288
+ @torch.no_grad()
289
+ def _get_src_permutation_idx(self, indices):
290
+ # permute predictions following indices
291
+ # torch.full_like gives a tensor full of i in shape of src.shape
292
+ # at each iter, i is the index, src is the src ind in shape of (N)
293
+ # so batch_idx is concat of (0,0,...), (1,1,...), with shape (N0+N1+N2+...+Nb)
294
+ # so if we flatten gt/pred across bathces, this gives the batch_id of each sample
295
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
296
+ # src_idx is src_ind concated to shape (N0+N1+N2+...+Nb)
297
+ # it is a flattened concat of mask_id at each batch
298
+ src_idx = torch.cat([src for (src, _) in indices])
299
+ return batch_idx, src_idx
300
+
301
+
302
+ def get_loss(self, loss, outputs, targets):
303
+ loss_map = {
304
+ 'labels': self.loss_labels,
305
+ 'masks': self.loss_masks,
306
+ 'pixels': self.loss_pixels,
307
+ 'aux_semantic': self.loss_semantic,
308
+ }
309
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
310
+ return loss_map[loss](outputs, targets)
311
+
312
+ @torch.no_grad()
313
+ def process_gt(self, outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=False):
314
+ # Permute&Pad Pred&GT for loss compuation.
315
+ # By controling process_gt, we can share the matching results for all preds.
316
+ src_idx = self._get_src_permutation_idx(indices)
317
+
318
+ src_masks = outputs["pred_masks"].detach() # B x N x H x W
319
+
320
+ # Pad and permute the target_mask to B x N x H x W
321
+ target_masks = torch.zeros_like(src_masks)
322
+ target_masks_o = torch.cat([t["masks"][J] for t, (_, J) in zip(targets, indices)]).to(target_masks)
323
+ target_masks[src_idx] = target_masks_o
324
+
325
+ # Pad and permute the matched_cls_prob to B x N
326
+ matched_cls_prob_o = torch.cat([cls_prob for cls_prob in matched_cls_prob])
327
+ matched_cls_prob_o = torch.clamp(matched_cls_prob_o, min=self.eos_coef)
328
+ # https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py#L1034
329
+ # no penalty for unmatched masks.
330
+ matched_cls_prob = torch.full(
331
+ src_masks.shape[:2], 0, dtype=src_masks.dtype, device=src_masks.device
332
+ ) # B x N
333
+ matched_cls_prob[src_idx] = matched_cls_prob_o.to(matched_cls_prob)
334
+
335
+ # pixel_gt_void_mask is used to indicate those pixels without labels.
336
+ pixel_gt_void_mask = (target_masks.sum(1) < 1) # B x H x W
337
+
338
+ # inverse_gt_mask_area is used to sample pixels.
339
+ mask_gt_area = target_masks.sum(2).sum(2) # B x N
340
+ pixel_gt_area = torch.einsum('bnhw,bn->bhw', target_masks, mask_gt_area) # B x H x W
341
+ inverse_gt_mask_area = (pixel_gt_area.shape[1] * pixel_gt_area.shape[2]) / torch.clamp(pixel_gt_area, min=1.0) # B x H x W
342
+
343
+ src_logits = outputs["pred_logits"] # B x N x C
344
+ # Pad and permute the target_classes to B x N
345
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
346
+ # This serves as a padding.
347
+ target_classes = torch.full(
348
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
349
+ )
350
+ # We put real GT to those corresponds to src_idx, and put void into other places.
351
+ target_classes[src_idx] = target_classes_o
352
+
353
+ src_masks_prob = src_masks.softmax(1)
354
+ void_mask = pixel_gt_void_mask.to(src_masks_prob) # B x H x W
355
+ # compute iou instead of dice for void overlapping.
356
+ def computer_iou_score(x, y):
357
+ # x : B x N x H x W
358
+ # y : B x H x W
359
+ x = x.flatten(2) # B x N x L
360
+ y = y.flatten(1) # B x L
361
+ intersection = torch.einsum('bnl,bl->bn', x, y) # B x N
362
+ denominator = x.sum(-1) # B x N
363
+ return intersection / (denominator + 1e-5) # B x N
364
+
365
+ # Pad and permute the matched_dice to B x N
366
+ matched_dice_o = torch.cat([dice for dice in matched_dice])
367
+ matched_dice = computer_iou_score(src_masks_prob, void_mask) # unmatched masks use their dice with void
368
+ matched_dice[src_idx] = matched_dice_o.to(matched_dice)
369
+ matched_dice = torch.clamp(matched_dice, min=self.eos_coef)
370
+
371
+
372
+ processed_gt = {"masks": target_masks, "labels": target_classes,
373
+ "pq_loss_mask_weight": matched_cls_prob,
374
+ "pq_loss_class_weight": matched_dice,
375
+ "pixel_gt_void_mask": pixel_gt_void_mask,
376
+ "inverse_gt_mask_area": inverse_gt_mask_area,}
377
+
378
+ if process_semantic:
379
+ # To obtain semantic gt
380
+ ground_truth_semantic = [t["semantic_masks"] for t in targets]
381
+ ground_truth_semantic = torch.stack(ground_truth_semantic, dim=0) # B x H x W
382
+ # self.num_classes is set to ignore label
383
+ ground_truth_semantic[ground_truth_semantic==-1] = self.num_classes
384
+ processed_gt.update({"ground_truth_semantic": ground_truth_semantic})
385
+
386
+ return processed_gt
387
+
388
+
389
+ def forward(self, outputs, targets):
390
+ """This performs the loss computation.
391
+ Parameters:
392
+ outputs: dict of tensors, see the output specification of the model for the format
393
+ targets: list of dicts, such that len(targets) == batch_size.
394
+ The expected keys in each dict depends on the losses applied, see each loss' doc
395
+ """
396
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
397
+ indices, matched_dice, matched_cls_prob = self.matcher(outputs_without_aux, targets)
398
+ # Pad GT to the same number as the predictions.
399
+ processed_targets = self.process_gt(outputs, targets, indices, matched_dice, matched_cls_prob, process_semantic=True)
400
+ # Compute all the requested losses
401
+ losses = {}
402
+ for loss in self.losses:
403
+ losses.update(self.get_loss(loss, outputs, processed_targets))
404
+
405
+ if "aux_outputs" in outputs:
406
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
407
+ # If share_final_matching is set, the matching from the final output is reused for all auxiliary predictions.
408
+ if not self.share_final_matching:
409
+ indices, matched_dice, matched_cls_prob = self.matcher(aux_outputs, targets)
410
+ if not self.share_final_matching:
411
+ processed_targets = self.process_gt(aux_outputs, targets, indices, matched_dice, matched_cls_prob)
412
+ for loss in self.losses:
413
+ if loss in ['aux_semantic']:
414
+ # Only for final output.
415
+ continue
416
+ l_dict = self.get_loss(loss, aux_outputs, processed_targets)
417
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
418
+ losses.update(l_dict)
419
+ return losses
420
+
421
+ def __repr__(self):
422
+ head = "Criterion " + self.__class__.__name__
423
+ body = [
424
+ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
425
+ "losses: {}".format(self.losses),
426
+ "weight_dict: {}".format(self.weight_dict),
427
+ "num_classes: {}".format(self.num_classes),
428
+ "eos_coef: {}".format(self.eos_coef),
429
+ ]
430
+ _repr_indent = 4
431
+ lines = [head] + [" " * _repr_indent + line for line in body]
432
+ return "\n".join(lines)
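The comments in process_gt above explain how void pixels and inverse-area pixel weights are derived from the padded ground-truth masks. Below is a minimal, self-contained sketch (not part of this commit) that reproduces just those two steps on a toy batch so the tensor shapes are easy to follow; all numbers are made up for illustration.

import torch

# Toy batch: B=1 image, N=3 (padded) GT masks on a 4x4 grid.
B, N, H, W = 1, 3, 4, 4
target_masks = torch.zeros(B, N, H, W)
target_masks[0, 0, :2, :] = 1.0   # mask 0: top half
target_masks[0, 1, 2:, :2] = 1.0  # mask 1: bottom-left quarter
# mask 2 stays empty; the bottom-right quarter belongs to no mask (void).

# Pixels covered by no GT mask are "void" and are excluded from the mask losses.
pixel_gt_void_mask = (target_masks.sum(1) < 1)                              # B x H x W
# Inverse-area weights: pixels belonging to small masks get larger sampling weights.
mask_gt_area = target_masks.sum(2).sum(2)                                   # B x N
pixel_gt_area = torch.einsum('bnhw,bn->bhw', target_masks, mask_gt_area)    # B x H x W
inverse_gt_mask_area = (H * W) / torch.clamp(pixel_gt_area, min=1.0)        # B x H x W

print(pixel_gt_void_mask[0])    # True exactly on the uncovered bottom-right quarter
print(inverse_gt_mask_area[0])  # larger values on the smaller mask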
kmax_deeplab/modeling/matcher.py ADDED
@@ -0,0 +1,128 @@
1
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py
2
+ # Reference: https://github.com/google-research/deeplab2/blob/main/model/loss/max_deeplab_loss.py
3
+ # Modified by Qihang Yu
4
+
5
+ """
6
+ Modules to compute the matching cost and solve the corresponding LSAP.
7
+ """
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from scipy.optimize import linear_sum_assignment
11
+ from torch import nn
12
+ from torch.cuda.amp import autocast
13
+ import numpy as np
14
+
15
+
16
+ # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L158
17
+ @torch.no_grad()
18
+ def compute_mask_similarity(inputs: torch.Tensor, targets: torch.Tensor):
19
+ """
20
+ Compute a dice-style mask similarity (higher means more similar), used as the matching score.
21
+ Args:
22
+ inputs: A float tensor of arbitrary shape.
23
+ The predictions for each example.
24
+ targets: A float tensor with the same shape as inputs. Stores the binary
25
+ classification label for each element in inputs
26
+ (0 for the negative class and 1 for the positive class).
27
+ """
28
+ denominator_epsilon = 1e-5
29
+ inputs = F.softmax(inputs, dim=0)
30
+ inputs = inputs.flatten(1) # N x HW
31
+
32
+ pixel_gt_non_void_mask = (targets.sum(0, keepdim=True) > 0).to(inputs) # 1xHW
33
+ inputs = inputs * pixel_gt_non_void_mask
34
+
35
+ intersection = torch.einsum("nc,mc->nm", inputs, targets)
36
+ denominator = (inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]) / 2.0
37
+ return intersection / (denominator + denominator_epsilon)
38
+
39
+
40
+ # https://github.com/google-research/deeplab2/blob/c4a533c14fac1a1071a6d24c5379c31a69a3e5e6/model/loss/max_deeplab_loss.py#L941
41
+ @torch.no_grad()
42
+ def compute_class_similarity(inputs: torch.Tensor, targets: torch.Tensor):
43
+ pred_class_prob = inputs.softmax(-1)[..., :-1] # exclude the void class
44
+ return pred_class_prob[:, targets]
45
+
46
+
47
+ class HungarianMatcher(nn.Module):
48
+ """This class computes an assignment between the targets and the predictions of the network
49
+
50
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
51
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
52
+ while the others are un-matched (and thus treated as non-objects).
53
+ """
54
+
55
+ def __init__(self):
56
+ """Creates the matcher
57
+
58
+ Note:
59
+ Unlike the Mask2Former matcher, this matcher takes no cost weights:
60
+ the matching cost is simply the product of the predicted class probability
61
+ and the mask (dice-style) similarity, so no relative weights are needed.
62
+ """
63
+ super().__init__()
64
+
65
+ @torch.no_grad()
66
+ def memory_efficient_forward(self, outputs, targets):
67
+ """More memory-friendly matching"""
68
+ bs, num_queries = outputs["pred_logits"].shape[:2]
69
+
70
+ indices = []
71
+ matched_dice = []
72
+ matched_cls_prob = []
73
+ # Iterate through batch size
74
+ for b in range(bs):
75
+ with autocast(enabled=False):
76
+ class_similarity = compute_class_similarity(outputs["pred_logits"][b].float(), targets[b]["labels"])
77
+ out_mask = outputs["pred_masks"][b].flatten(1) # [num_queries, H_pred, W_pred]
78
+ # gt masks are already padded when preparing target
79
+ tgt_mask = targets[b]["masks"].to(out_mask).flatten(1)
80
+ with autocast(enabled=False):
81
+ mask_similarity = compute_mask_similarity(out_mask.float(), tgt_mask.float())
82
+
83
+ # Final cost matrix
84
+ C = - mask_similarity * class_similarity
85
+ C = C.reshape(num_queries, -1).cpu() # N x M , N = num_queries, M = num_gt
86
+
87
+ # the assignment will be truncated to a square matrix.
88
+ row_ind, col_ind = linear_sum_assignment(C)
89
+ matched_dice.append(mask_similarity[row_ind, col_ind].detach())
90
+ matched_cls_prob.append(class_similarity[row_ind, col_ind].detach())
91
+ indices.append((row_ind, col_ind)) # row_ind: matched query (prediction) indices; col_ind: matched GT indices
92
+
93
+ indices = [
94
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
95
+ for i, j in indices
96
+ ]
97
+
98
+ return indices, matched_dice, matched_cls_prob
99
+
100
+
101
+ @torch.no_grad()
102
+ def forward(self, outputs, targets):
103
+ """Performs the matching
104
+
105
+ Params:
106
+ outputs: This is a dict that contains at least these entries:
107
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
108
+ "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
109
+
110
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
111
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
112
+ objects in the target) containing the class labels
113
+ "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
114
+
115
+ Returns:
116
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
117
+ - index_i is the indices of the selected predictions (in order)
118
+ - index_j is the indices of the corresponding selected targets (in order)
119
+ For each batch element, it holds:
120
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
121
+ """
122
+ return self.memory_efficient_forward(outputs, targets)
123
+
124
+ def __repr__(self, _repr_indent=4):
125
+ head = "Matcher " + self.__class__.__name__
126
+ body = []
127
+ lines = [head] + [" " * _repr_indent + line for line in body]
128
+ return "\n".join(lines)
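To make the matching concrete, here is a small sketch (not part of the commit) of how the product cost above turns into a one-to-one assignment. The numbers are invented; with num_queries > num_gt, linear_sum_assignment selects one query per ground-truth mask and the remaining queries stay unmatched.

import torch
from scipy.optimize import linear_sum_assignment

# 4 queries, 2 GT masks; rows are queries, columns are GT masks.
mask_similarity = torch.tensor([[0.9, 0.1],
                                [0.2, 0.8],
                                [0.3, 0.3],
                                [0.1, 0.2]])
class_similarity = torch.tensor([[0.7, 0.2],
                                 [0.1, 0.9],
                                 [0.5, 0.5],
                                 [0.3, 0.3]])
C = -(mask_similarity * class_similarity)            # lower cost = better match
row_ind, col_ind = linear_sum_assignment(C.numpy())
row_ind, col_ind = torch.as_tensor(row_ind), torch.as_tensor(col_ind)
print(row_ind, col_ind)                              # tensor([0, 1]) tensor([0, 1]): query 0 -> GT 0, query 1 -> GT 1
print(mask_similarity[row_ind, col_ind])             # the values returned as matched_dice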
kmax_deeplab/modeling/meta_arch/__init__.py ADDED
File without changes
kmax_deeplab/modeling/meta_arch/kmax_deeplab_head.py ADDED
@@ -0,0 +1,88 @@
1
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/meta_arch/mask_former_head.py
2
+ # Modified by Qihang Yu
3
+
4
+ from typing import Dict
5
+
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.layers import ShapeSpec
11
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
12
+
13
+ from ..transformer_decoder.kmax_transformer_decoder import build_transformer_decoder
14
+
15
+
16
+ def build_pixel_decoder(cfg, input_shape):
17
+ """
18
+ Build a pixel decoder from `cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.NAME`.
19
+ """
20
+ name = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.NAME
21
+ model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
22
+ forward_features = getattr(model, "forward_features", None)
23
+ if not callable(forward_features):
24
+ raise ValueError(
25
+ "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
26
+ f"Please implement forward_features for {name} to only return mask features."
27
+ )
28
+ return model
29
+
30
+
31
+ @SEM_SEG_HEADS_REGISTRY.register()
32
+ class kMaXDeepLabHead(nn.Module):
33
+
34
+ @configurable
35
+ def __init__(
36
+ self,
37
+ input_shape: Dict[str, ShapeSpec],
38
+ *,
39
+ num_classes: int,
40
+ pixel_decoder: nn.Module,
41
+ loss_weight: float = 1.0,
42
+ ignore_value: int = -1,
43
+ transformer_predictor: nn.Module,
44
+ ):
45
+ """
46
+ NOTE: this interface is experimental.
47
+ Args:
48
+ input_shape: shapes (channels and stride) of the input features
49
+ num_classes: number of classes to predict
50
+ pixel_decoder: the pixel decoder module
51
+ loss_weight: loss weight
52
+ ignore_value: category id to be ignored during training.
53
+ transformer_predictor: the transformer decoder that makes prediction
54
+ transformer_in_feature: input feature name to the transformer_predictor
55
+ """
56
+ super().__init__()
57
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
58
+ self.in_features = [k for k, v in input_shape]
59
+
60
+ self.ignore_value = ignore_value
61
+ self.common_stride = 4
62
+ self.loss_weight = loss_weight
63
+
64
+ self.pixel_decoder = pixel_decoder
65
+ self.predictor = transformer_predictor
66
+
67
+ self.num_classes = num_classes
68
+
69
+ @classmethod
70
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
71
+ return {
72
+ "input_shape": {
73
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.IN_FEATURES
74
+ },
75
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
76
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
77
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
78
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
79
+ "transformer_predictor": build_transformer_decoder(cfg, input_shape),
80
+ }
81
+
82
+ def forward(self, features):
83
+ return self.layers(features)
84
+
85
+ def layers(self, features):
86
+ panoptic_features, semantic_features, multi_scale_features = self.pixel_decoder.forward_features(features)
87
+ predictions = self.predictor(multi_scale_features, panoptic_features, semantic_features)
88
+ return predictions
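For context, kMaXDeepLabHead is looked up through detectron2's SEM_SEG_HEADS_REGISTRY, just like the pixel decoder above. A minimal sketch of building the head from a config, assuming cfg.MODEL.SEM_SEG_HEAD.NAME is set to "kMaXDeepLabHead" in the yaml configs (this mirrors what detectron2's build_sem_seg_head does internally):

from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

def build_head(cfg, input_shape):
    # Same registry lookup that detectron2's build_sem_seg_head performs.
    name = cfg.MODEL.SEM_SEG_HEAD.NAME          # e.g. "kMaXDeepLabHead" (assumed)
    return SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)

The returned head then exposes forward(features) -> predictions as defined above.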
kmax_deeplab/modeling/pixel_decoder/__init__.py ADDED
File without changes
kmax_deeplab/modeling/pixel_decoder/kmax_pixel_decoder.py ADDED
@@ -0,0 +1,370 @@
1
+ # Reference: https://github.com/google-research/deeplab2/blob/main/model/pixel_decoder/kmax.py
2
+ # Modified by Qihang Yu
3
+
4
+ from typing import Dict, List
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from timm.models.layers import DropPath
11
+ from timm.models.layers import trunc_normal_tf_ as trunc_normal_
12
+
13
+ from detectron2.config import configurable
14
+ from detectron2.layers import ShapeSpec
15
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
16
+ from torch.cuda.amp import autocast
17
+
18
+ from ..backbone.convnext import LayerNorm
19
+
20
+ import math
21
+
22
+
23
+ def get_activation(name):
24
+ if name is None or name.lower() == 'none':
25
+ return nn.Identity()
26
+ if name == 'relu':
27
+ return nn.ReLU()
28
+ elif name == 'gelu':
29
+ return nn.GELU()
30
+
31
+
32
+ def get_norm(name, channels):
33
+ if name is None or name.lower() == 'none':
34
+ return nn.Identity()
35
+
36
+ if name.lower() == 'syncbn':
37
+ return nn.SyncBatchNorm(channels, eps=1e-3, momentum=0.01)
38
+
39
+
40
+ class ConvBN(nn.Module):
41
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, norm=None, act=None,
42
+ conv_type='2d', conv_init='he_normal', norm_init=1.0):
43
+ super().__init__()
44
+
45
+ if conv_type == '2d':
46
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
47
+ elif conv_type == '1d':
48
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
49
+
50
+ self.norm = get_norm(norm, out_channels)
51
+ self.act = get_activation(act)
52
+
53
+ if conv_init == 'normal':
54
+ nn.init.normal_(self.conv.weight, std=.02)
55
+ elif conv_init == 'trunc_normal':
56
+ trunc_normal_(self.conv.weight, std=.02)
57
+ elif conv_init == 'he_normal':
58
+ # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal
59
+ trunc_normal_(self.conv.weight, std=math.sqrt(2.0 / in_channels))
60
+ elif conv_init == 'xavier_uniform':
61
+ nn.init.xavier_uniform_(self.conv.weight)
62
+ if bias:
63
+ nn.init.zeros_(self.conv.bias)
64
+
65
+ if norm is not None:
66
+ nn.init.constant_(self.norm.weight, norm_init)
67
+
68
+ def forward(self, x):
69
+ return self.act(self.norm(self.conv(x)))
70
+
71
+
72
+ MAX_SPAN = 255
73
+ def _compute_relative_distance_matrix(query_length, key_length):
74
+ if (key_length - query_length) % 2:
75
+ raise ValueError('Key_length should be query_length + 2 * memory_flange.')
76
+ key_index = torch.arange(key_length)
77
+ query_index = torch.arange(query_length) + (key_length - query_length) // 2
78
+ distance_matrix = key_index[None, :] - query_index[:, None]
79
+ # Shift the distance_matrix so that it is >= 0. Each entry of the
81
+ # shifted distance_matrix then indexes a relative positional embedding.
81
+ distance_matrix = distance_matrix + MAX_SPAN - 1
82
+ return distance_matrix
83
+
84
+
85
+ class RelativePositionalEncoding(nn.Module):
86
+ def __init__(self, query_length, key_length, depth):
87
+ super().__init__()
88
+ self._embeddings = nn.Embedding(MAX_SPAN * 2 - 1, depth)
89
+ trunc_normal_(self._embeddings.weight, std=1.0)
90
+ self._relative_distance_matrix = _compute_relative_distance_matrix(query_length, key_length)
91
+ self.query_length = query_length
92
+ self.key_length = key_length
93
+ self.depth = depth
94
+
95
+ def forward(self):
96
+ return self._embeddings.weight[self._relative_distance_matrix.reshape(-1)].reshape(self.query_length, self.key_length, self.depth)
97
+
98
+
99
+ # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L36
100
+ class AxialAttention(nn.Module):
101
+ def __init__(self, in_planes, query_shape=56, total_key_depth=512, total_value_depth=1024, num_heads=8):
102
+ assert (total_key_depth % num_heads == 0) and (total_value_depth % num_heads == 0)
103
+ super().__init__()
104
+ self._in_planes = in_planes
105
+ self._query_shape = query_shape
106
+ self._total_key_depth = total_key_depth
107
+ self._total_value_depth = total_value_depth
108
+ self._num_heads = num_heads
109
+ self._key_depth_per_head = total_key_depth // num_heads
110
+
111
+ self.qkv_transform = ConvBN(in_planes, self._total_key_depth * 2 + self._total_value_depth, kernel_size=1, stride=1,
112
+ padding=0, bias=False, norm=None, act=None, conv_type='1d')
113
+ trunc_normal_(self.qkv_transform.conv.weight, std=in_planes ** -0.5)
114
+
115
+ self._query_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head)
116
+ self._key_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head)
117
+ self._value_rpe = RelativePositionalEncoding(query_shape, query_shape, total_value_depth // num_heads)
118
+
119
+ self._batch_norm_qkv = get_norm('syncbn', self._total_key_depth * 2 + self._total_value_depth)
120
+ self._batch_norm_similarity = get_norm('syncbn', num_heads * 3)
121
+ self._batch_norm_retrieved_output = get_norm('syncbn', self._total_value_depth * 2)
122
+
123
+
124
+ def forward(self, x):
125
+ N, C, L = x.shape
126
+ qkv = self._batch_norm_qkv(self.qkv_transform(x))
127
+ q, k, v = torch.split(qkv, [self._total_key_depth, self._total_key_depth, self._total_value_depth], dim=1)
128
+ q = q.reshape(N, self._num_heads, self._total_key_depth // self._num_heads, L)
129
+ k = k.reshape(N, self._num_heads, self._total_key_depth // self._num_heads, L)
130
+ v = v.reshape(N, self._num_heads, self._total_value_depth // self._num_heads, L)
131
+
132
+ similarity_logits = []
133
+ content_similarity = torch.einsum('bhdl,bhdm->bhlm', q, k)
134
+ query_rpe = self._query_rpe()
135
+ query_rpe_similarity = torch.einsum('bhdl,lmd->bhlm', q, query_rpe)
136
+ key_rpe = self._key_rpe()
137
+ key_rpe_similarity = torch.einsum('bhdm,lmd->bhlm', k, key_rpe)
138
+ similarity_logits = torch.cat([content_similarity, query_rpe_similarity, key_rpe_similarity], dim=1)
139
+ similarity_logits = self._batch_norm_similarity(similarity_logits).reshape(N, 3, self._num_heads, L, L).sum(dim=1)
140
+
141
+ with autocast(enabled=False):
142
+ weights = F.softmax(similarity_logits.float(), dim=-1)
143
+
144
+ retrieved_content = torch.einsum('bhlm,bhdm->bhdl', weights, v)
145
+ value_rpe = self._value_rpe()
146
+ retrieved_rpe = torch.einsum('bhlm,lmd->bhdl', weights, value_rpe)
147
+
148
+ retrieved_output = torch.cat([retrieved_content, retrieved_rpe], dim=1).reshape(N, 2*self._total_value_depth, L)
149
+ retrieved_output = self._batch_norm_retrieved_output(retrieved_output).reshape(N, 2, self._total_value_depth, L).sum(1)
150
+
151
+ return retrieved_output
152
+
153
+
154
+ # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L316
155
+ class AxialAttention2D(nn.Module):
156
+ def __init__(self, in_planes, query_shape=[56, 56], filters=512, key_expansion=1, value_expansion=2, num_heads=8):
157
+ super().__init__()
158
+ total_key_depth = int(round(filters * key_expansion))
159
+ total_value_depth = int(round(filters * value_expansion))
160
+ self._total_key_depth = total_key_depth
161
+ self._total_value_depth = total_value_depth
162
+ self._height_axis = AxialAttention(
163
+ in_planes=in_planes,
164
+ query_shape=query_shape[0],
165
+ total_key_depth=total_key_depth,
166
+ total_value_depth=total_value_depth,
167
+ num_heads=num_heads)
168
+ self._width_axis = AxialAttention(
169
+ in_planes=total_value_depth,
170
+ query_shape=query_shape[1],
171
+ total_key_depth=total_key_depth,
172
+ total_value_depth=total_value_depth,
173
+ num_heads=num_heads)
174
+
175
+ def forward(self, x):
176
+ # N C H W -> N W C H
177
+ N, C, H, W = x.shape
178
+ x = x.permute(0, 3, 1, 2).contiguous()
179
+ x = x.reshape(N*W, C, H)
180
+ x = self._height_axis(x)
181
+ # N W C H -> N H C W
182
+ x = x.reshape(N, W, self._total_value_depth, H).permute(0, 3, 2, 1).contiguous()
183
+ x = x.reshape(N*H, self._total_value_depth, W)
184
+ x = self._width_axis(x)
185
+ x = x.reshape(N, H, self._total_value_depth, W).permute(0, 2, 1, 3).contiguous()
186
+ x = x.reshape(N, self._total_value_depth, H, W)
187
+ return x
188
+
189
+
190
+ # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_blocks.py#L36
191
+ class SingleBlock(nn.Module):
192
+
193
+ def __init__(self, inplanes, filter_list, block_type, query_shape=[56, 56], key_expansion=1, value_expansion=2, num_heads=8, drop_path_prob=0.0):
194
+ super(SingleBlock, self).__init__()
195
+ self._block_type = block_type.lower()
196
+ self._filter_list = filter_list
197
+ self._conv1_bn_act = ConvBN(inplanes, self._filter_list[0], kernel_size=1, bias=False, norm='syncbn', act='gelu')
198
+ if self._block_type == 'axial':
199
+ self._attention = AxialAttention2D(in_planes=self._filter_list[0], query_shape=query_shape, filters=self._filter_list[1],
200
+ key_expansion=key_expansion, value_expansion=value_expansion, num_heads=num_heads)
201
+ output_channel = filter_list[1] * value_expansion
202
+ elif self._block_type == 'bottleneck':
203
+ self._conv2_bn_act = ConvBN(self._filter_list[0], self._filter_list[1], kernel_size=3, padding=1, bias=False, norm='syncbn', act='gelu')
204
+ output_channel = filter_list[1]
205
+ self._conv3_bn = ConvBN(output_channel, self._filter_list[2], kernel_size=1, bias=False, norm='syncbn', act=None, norm_init=0.0)
206
+
207
+ self._shortcut = None
208
+ if inplanes != self._filter_list[-1]:
209
+ self._shortcut = ConvBN(inplanes, self._filter_list[-1], kernel_size=1, bias=False, norm='syncbn', act=None)
210
+ self.drop_path = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()
211
+
212
+ def forward(self, x):
213
+ x = F.gelu(x)
214
+
215
+ shortcut = x
216
+ if self._shortcut is not None:
217
+ shortcut = self._shortcut(shortcut)
218
+
219
+ x = self._conv1_bn_act(x)
220
+ if self._block_type == 'axial':
221
+ x = self._attention(x)
222
+ x = F.gelu(x)
223
+ elif self._block_type == 'bottleneck':
224
+ x = self._conv2_bn_act(x)
225
+ x = self._conv3_bn(x)
226
+
227
+ x = self.drop_path(x) + shortcut
228
+
229
+ return x
230
+
231
+
232
+ # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L42
233
+ class BlockGroup(nn.Module):
234
+ def __init__(self, inplanes, base_filter, num_blocks, block_type, **kwargs):
235
+ super().__init__()
236
+ self._num_blocks = num_blocks
237
+ block_type = block_type.lower()
238
+ if block_type == 'axial':
239
+ # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L247
240
+ filter_list = [base_filter * 2, base_filter, base_filter * 4]
241
+ elif block_type == 'bottleneck':
242
+ # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L250
243
+ filter_list = [base_filter, base_filter, base_filter * 4]
244
+
245
+ self._blocks = nn.ModuleList()
246
+ for i in range(num_blocks):
247
+ self._blocks.append(SingleBlock(inplanes=inplanes, filter_list=filter_list, block_type=block_type, **kwargs))
248
+ inplanes = filter_list[-1]
249
+
250
+ def forward(self, x):
251
+ for i in range(self._num_blocks):
252
+ x = self._blocks[i](x)
253
+ return x
254
+
255
+
256
+ # https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/layers/resized_fuse.py#L31
257
+ class ResizedFuse(nn.Module):
258
+ def __init__(self, low_in_channels, high_in_channels, out_channels):
259
+ super().__init__()
260
+ self.low_in_channels = low_in_channels
261
+ self.high_in_channels = high_in_channels
262
+ self.out_channels = out_channels
263
+ if low_in_channels != out_channels:
264
+ self._conv_bn_low = ConvBN(low_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)
265
+ if high_in_channels != out_channels:
266
+ self._conv_bn_high = ConvBN(high_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)
267
+
268
+ def forward(self, lowres_x, highres_x):
269
+
270
+ align_corners = (lowres_x.shape[-1] % 2 == 1)
271
+ if self.low_in_channels != self.out_channels:
272
+ lowres_x = F.gelu(lowres_x)
273
+ lowres_x = self._conv_bn_low(lowres_x)
274
+ lowres_x = F.interpolate(lowres_x, size=highres_x.shape[2:], mode='bilinear', align_corners=align_corners)
275
+ else:
276
+ lowres_x = F.interpolate(lowres_x, size=highres_x.shape[2:], mode='bilinear', align_corners=align_corners)
277
+
278
+ if self.high_in_channels != self.out_channels:
279
+ highres_x = F.gelu(highres_x)
280
+ highres_x = self._conv_bn_high(highres_x)
281
+
282
+ return lowres_x + highres_x
283
+
284
+
285
+ @SEM_SEG_HEADS_REGISTRY.register()
286
+ class kMaXPixelDecoder(nn.Module):
287
+ @configurable
288
+ def __init__(
289
+ self,
290
+ input_shape: Dict[str, ShapeSpec],
291
+ *,
292
+ dec_layers: List[int],
293
+ dec_channels: List[int],
294
+ layer_types: List[str],
295
+ drop_path_prob: float,
296
+ spatial_shape: List[int],
297
+ ):
298
+ """
299
+ NOTE: this interface is experimental.
300
+ Args:
301
+ """
302
+ super().__init__()
303
+ self.num_stages = len(input_shape)
304
+ assert self.num_stages == len(dec_layers) and self.num_stages == len(dec_channels) and self.num_stages == len(layer_types)
305
+ # For now, we hard code all hyper-parameters.
306
+ block_types = ['axial', 'axial', 'bottleneck', 'bottleneck']
307
+ input_shape = sorted(input_shape.items(), key=lambda x: -x[1].stride)
308
+ self.in_features = [k for k, v in input_shape] # starting from "res5" to "res2"
309
+ in_channels = [v.channels for k, v in input_shape]
310
+
311
+ add_one = (spatial_shape[0] % 2, spatial_shape[1] % 2)
312
+ query_shape = [
313
+ (spatial_shape[0]//32+add_one[0], spatial_shape[1]//32+add_one[1]),
314
+ (spatial_shape[0]//16+add_one[0], spatial_shape[1]//16+add_one[1]),
315
+ (spatial_shape[0]//8+add_one[0], spatial_shape[1]//8+add_one[1]),
316
+ (spatial_shape[0]//4+add_one[0], spatial_shape[1]//4+add_one[1])]
317
+
318
+ self._in_norms = nn.ModuleList()
319
+ self._stages = nn.ModuleList()
320
+ self._resized_fuses = nn.ModuleList()
321
+
322
+ for i in range(self.num_stages):
323
+ self._in_norms.append(LayerNorm(in_channels[i], data_format="channels_first"))
324
+ inplanes = in_channels[i] if i == 0 else dec_channels[i]
325
+ self._stages.append(BlockGroup(inplanes=inplanes,
326
+ base_filter=dec_channels[i], num_blocks=dec_layers[i], block_type=block_types[i],
327
+ query_shape=query_shape[i], key_expansion=1, value_expansion=2, num_heads=8, drop_path_prob=0.0))
328
+
329
+ if i > 0:
330
+ self._resized_fuses.append(ResizedFuse(
331
+ low_in_channels=dec_channels[i-1] * 4,
332
+ high_in_channels=in_channels[i],
333
+ out_channels=dec_channels[i]))
334
+
335
+
336
+ @classmethod
337
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
338
+ ret = {}
339
+ ret["input_shape"] = {
340
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.IN_FEATURES
341
+ }
342
+ ret["dec_layers"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_LAYERS
343
+ ret["dec_channels"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_CHANNELS
344
+ ret["layer_types"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.LAYER_TYPES
345
+ ret["drop_path_prob"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DROP_PATH_PROB
346
+ ret["spatial_shape"] = cfg.INPUT.IMAGE_SIZE # We expect the height == width
347
+ return ret
348
+
349
+
350
+ def forward_features(self, features):
351
+ out = []
352
+ multi_scale_features = []
353
+
354
+ x = self._in_norms[0](features[self.in_features[0]])
355
+
356
+ for idx in range(self.num_stages - 1):
357
+ x = self._stages[idx](x)
358
+ out.append(x)
359
+ x = self._resized_fuses[idx](
360
+ lowres_x=x,
361
+ highres_x=self._in_norms[idx+1](features[self.in_features[idx+1]]))
362
+
363
+ x = self._stages[-1](x)
364
+ out.append(x)
365
+ multi_scale_features = out[:3] # OS32, 16, 8, they are used for kmax_transformer_decoder.
366
+ panoptic_features = out[-1] # OS4, it is used for final mask prediction.
367
+ # OS 32, 8, 4
368
+ semantic_features = [features[self.in_features[0]], features[self.in_features[2]], features[self.in_features[3]]]
369
+ return panoptic_features, semantic_features, multi_scale_features
370
+
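The relative positional encoding above indexes a shared embedding table by relative offset. A tiny sketch (not part of the commit) of what _compute_relative_distance_matrix returns for a short axis:

import torch

MAX_SPAN = 255
def compute_relative_distance_matrix(query_length, key_length):
    key_index = torch.arange(key_length)
    query_index = torch.arange(query_length) + (key_length - query_length) // 2
    return key_index[None, :] - query_index[:, None] + MAX_SPAN - 1

print(compute_relative_distance_matrix(4, 4))
# tensor([[254, 255, 256, 257],
#         [253, 254, 255, 256],
#         [252, 253, 254, 255],
#         [251, 252, 253, 254]])
# Positions at the same relative offset share the same row of the
# (2 * MAX_SPAN - 1)-row embedding table, which is what AxialAttention
# relies on for its query/key/value relative-position terms.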
kmax_deeplab/modeling/transformer_decoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .kmax_transformer_decoder import kMaXTransformerDecoder
kmax_deeplab/modeling/transformer_decoder/kmax_transformer_decoder.py ADDED
@@ -0,0 +1,453 @@
1
+ # Reference: https://github.com/google-research/deeplab2/blob/main/model/transformer_decoder/kmax.py
2
+ # Modified by Qihang Yu
3
+
4
+ from typing import List
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from torch.cuda.amp import autocast
9
+
10
+ from timm.models.layers import DropPath
11
+ from timm.models.layers import trunc_normal_tf_ as trunc_normal_
12
+
13
+ from detectron2.config import configurable
14
+ from detectron2.utils.registry import Registry
15
+
16
+ from ..pixel_decoder.kmax_pixel_decoder import get_norm, ConvBN
17
+
18
+ import math
19
+
20
+
21
+ TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
22
+ TRANSFORMER_DECODER_REGISTRY.__doc__ = """
23
+ Registry for transformer module.
24
+ """
25
+ def build_transformer_decoder(cfg, input_shape_from_backbone):
26
+ """
27
+ Build a transformer decoder from `cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NAME`.
28
+ """
29
+ name = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NAME
30
+ return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, input_shape_from_backbone)
31
+
32
+
33
+ # https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/decoder/max_deeplab.py#L60
34
+ def add_bias_towards_void(query_class_logits, void_prior_prob=0.9):
35
+ class_logits_shape = query_class_logits.shape
36
+ init_bias = [0.0] * class_logits_shape[-1]
37
+ init_bias[-1] = math.log(
38
+ (class_logits_shape[-1] - 1) * void_prior_prob / (1 - void_prior_prob))
39
+ return query_class_logits + torch.tensor(init_bias, dtype=query_class_logits.dtype).to(query_class_logits)
40
+
41
+
42
+ # https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/layers/dual_path_transformer.py#L41
43
+ class AttentionOperation(nn.Module):
44
+ def __init__(self, channels_v, num_heads):
45
+ super().__init__()
46
+ self._batch_norm_similarity = get_norm('syncbn', num_heads)
47
+ self._batch_norm_retrieved_value = get_norm('syncbn', channels_v)
48
+
49
+ def forward(self, query, key, value):
50
+ N, _, _, L = query.shape
51
+ _, num_heads, C, _ = value.shape
52
+ similarity_logits = torch.einsum('bhdl,bhdm->bhlm', query, key)
53
+ similarity_logits = self._batch_norm_similarity(similarity_logits)
54
+
55
+ with autocast(enabled=False):
56
+ attention_weights = F.softmax(similarity_logits.float(), dim=-1)
57
+ retrieved_value = torch.einsum(
58
+ 'bhlm,bhdm->bhdl', attention_weights, value)
59
+ retrieved_value = retrieved_value.reshape(N, num_heads * C, L)
60
+ retrieved_value = self._batch_norm_retrieved_value(
61
+ retrieved_value)
62
+ retrieved_value = F.gelu(retrieved_value)
63
+ return retrieved_value
64
+
65
+
66
+ # https://github.com/google-research/deeplab2/blob/main/model/kmax_deeplab.py#L32
67
+ class kMaXPredictor(nn.Module):
68
+ def __init__(self, in_channel_pixel, in_channel_query, num_classes=133+1):
69
+ super().__init__()
70
+ self._pixel_space_head_conv0bnact = ConvBN(in_channel_pixel, in_channel_pixel, kernel_size=5, groups=in_channel_pixel, padding=2, bias=False,
71
+ norm='syncbn', act='gelu', conv_init='xavier_uniform')
72
+ self._pixel_space_head_conv1bnact = ConvBN(in_channel_pixel, 256, kernel_size=1, bias=False, norm='syncbn', act='gelu')
73
+ self._pixel_space_head_last_convbn = ConvBN(256, 128, kernel_size=1, bias=True, norm='syncbn', act=None)
74
+ trunc_normal_(self._pixel_space_head_last_convbn.conv.weight, std=0.01)
75
+
76
+ self._transformer_mask_head = ConvBN(256, 128, kernel_size=1, bias=False, norm='syncbn', act=None, conv_type='1d')
77
+ self._transformer_class_head = ConvBN(256, num_classes, kernel_size=1, norm=None, act=None, conv_type='1d')
78
+ trunc_normal_(self._transformer_class_head.conv.weight, std=0.01)
79
+
80
+ self._pixel_space_mask_batch_norm = get_norm('syncbn', channels=1)
81
+ nn.init.constant_(self._pixel_space_mask_batch_norm.weight, 0.1)
82
+
83
+
84
+ def forward(self, mask_embeddings, class_embeddings, pixel_feature):
85
+ # mask_embeddings/class_embeddings: B x C x N
86
+ # pixel feature: B x C x H x W
87
+ pixel_space_feature = self._pixel_space_head_conv0bnact(pixel_feature)
88
+ pixel_space_feature = self._pixel_space_head_conv1bnact(pixel_space_feature)
89
+ pixel_space_feature = self._pixel_space_head_last_convbn(pixel_space_feature)
90
+ pixel_space_normalized_feature = F.normalize(pixel_space_feature, p=2, dim=1)
91
+
92
+ cluster_class_logits = self._transformer_class_head(class_embeddings).permute(0, 2, 1).contiguous()
93
+ cluster_class_logits = add_bias_towards_void(cluster_class_logits)
94
+ cluster_mask_kernel = self._transformer_mask_head(mask_embeddings)
95
+ mask_logits = torch.einsum('bchw,bcn->bnhw',
96
+ pixel_space_normalized_feature, cluster_mask_kernel)
97
+
98
+ mask_logits = self._pixel_space_mask_batch_norm(mask_logits.unsqueeze(dim=1)).squeeze(dim=1)
99
+
100
+
101
+ return {
102
+ 'class_logits': cluster_class_logits,
103
+ 'mask_logits': mask_logits,
104
+ 'pixel_feature': pixel_space_normalized_feature}
105
+
106
+
107
+ # https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/layers/dual_path_transformer.py#L107
108
+ class kMaXTransformerLayer(nn.Module):
109
+ def __init__(
110
+ self,
111
+ num_classes=133,
112
+ in_channel_pixel=2048,
113
+ in_channel_query=256,
114
+ base_filters=128,
115
+ num_heads=8,
116
+ bottleneck_expansion=2,
117
+ key_expansion=1,
118
+ value_expansion=2,
119
+ drop_path_prob=0.0,
120
+ ):
121
+ super().__init__()
122
+
123
+ self._num_classes = num_classes
124
+ self._num_heads = num_heads
125
+ self._bottleneck_channels = int(round(base_filters * bottleneck_expansion))
126
+ self._total_key_depth = int(round(base_filters * key_expansion))
127
+ self._total_value_depth = int(round(base_filters * value_expansion))
128
+
129
+ # Per the TF2 implementation, the same drop path prob is applied to:
130
+ # 1. k-means update for object query
131
+ # 2. self/cross-attention for object query
132
+ # 3. ffn for object query
133
+ self.drop_path_kmeans = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()
134
+ self.drop_path_attn = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()
135
+ self.drop_path_ffn = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()
136
+
137
+ initialization_std = self._bottleneck_channels ** -0.5
138
+ self._query_conv1_bn_act = ConvBN(in_channel_query, self._bottleneck_channels, kernel_size=1, bias=False,
139
+ norm='syncbn', act='gelu', conv_type='1d')
140
+
141
+ self._pixel_conv1_bn_act = ConvBN(in_channel_pixel, self._bottleneck_channels, kernel_size=1, bias=False,
142
+ norm='syncbn', act='gelu')
143
+
144
+ self._query_qkv_conv_bn = ConvBN(self._bottleneck_channels, self._total_key_depth * 2 + self._total_value_depth, kernel_size=1, bias=False,
145
+ norm='syncbn', act=None, conv_type='1d')
146
+ trunc_normal_(self._query_qkv_conv_bn.conv.weight, std=initialization_std)
147
+
148
+ self._pixel_v_conv_bn = ConvBN(self._bottleneck_channels, self._total_value_depth, kernel_size=1, bias=False,
149
+ norm='syncbn', act=None)
150
+ trunc_normal_(self._pixel_v_conv_bn.conv.weight, std=initialization_std)
151
+
152
+ self._query_self_attention = AttentionOperation(channels_v=self._total_value_depth, num_heads=num_heads)
153
+
154
+ self._query_conv3_bn = ConvBN(self._total_value_depth, in_channel_query, kernel_size=1, bias=False,
155
+ norm='syncbn', act=None, conv_type='1d', norm_init=0.0)
156
+
157
+ self._query_ffn_conv1_bn_act = ConvBN(in_channel_query, 2048, kernel_size=1, bias=False,
158
+ norm='syncbn', act='gelu', conv_type='1d')
159
+ self._query_ffn_conv2_bn = ConvBN(2048, in_channel_query, kernel_size=1, bias=False,
160
+ norm='syncbn', act=None, conv_type='1d', norm_init=0.0)
161
+
162
+ self._predcitor = kMaXPredictor(in_channel_pixel=self._bottleneck_channels,
163
+ in_channel_query=self._bottleneck_channels, num_classes=num_classes)
164
+ self._kmeans_query_batch_norm_retrieved_value = get_norm('syncbn', self._total_value_depth)
165
+ self._kmeans_query_conv3_bn = ConvBN(self._total_value_depth, in_channel_query, kernel_size=1, bias=False,
166
+ norm='syncbn', act=None, conv_type='1d', norm_init=0.0)
167
+
168
+
169
+ def forward(self, pixel_feature, query_feature):
170
+ N, C, H, W = pixel_feature.shape
171
+ _, D, L = query_feature.shape
172
+ pixel_space = self._pixel_conv1_bn_act(F.gelu(pixel_feature)) # N C H W
173
+ query_space = self._query_conv1_bn_act(query_feature) # N x C x L
174
+
175
+ # k-means cross-attention.
176
+ pixel_value = self._pixel_v_conv_bn(pixel_space) # N C H W
177
+ pixel_value = pixel_value.reshape(N, self._total_value_depth, H*W)
178
+ # k-means assignment.
179
+ prediction_result = self._predcitor(
180
+ mask_embeddings=query_space, class_embeddings=query_space, pixel_feature=pixel_space)
181
+ clustering_result = prediction_result['mask_logits'].flatten(2).detach() # N L HW
182
+
183
+ with torch.no_grad():
184
+ clustering_result = prediction_result['mask_logits'].flatten(2).detach() # N L HW
185
+ index = clustering_result.max(1, keepdim=True)[1]
186
+ clustering_result = torch.zeros_like(clustering_result, memory_format=torch.legacy_contiguous_format).scatter_(1, index, 1.0)
187
+
188
+ with autocast(enabled=False):
189
+ # k-means update.
190
+ kmeans_update = torch.einsum('blm,bdm->bdl', clustering_result.float(), pixel_value.float()) # N x C x L
191
+
192
+ kmeans_update = self._kmeans_query_batch_norm_retrieved_value(kmeans_update)
193
+ kmeans_update = self._kmeans_query_conv3_bn(kmeans_update)
194
+ query_feature = query_feature + self.drop_path_kmeans(kmeans_update)
195
+
196
+ # query self-attention.
197
+ query_qkv = self._query_qkv_conv_bn(query_space)
198
+ query_q, query_k, query_v = torch.split(query_qkv,
199
+ [self._total_key_depth, self._total_key_depth, self._total_value_depth], dim=1)
200
+ query_q = query_q.reshape(N, self._num_heads, self._total_key_depth//self._num_heads, L)
201
+ query_k = query_k.reshape(N, self._num_heads, self._total_key_depth//self._num_heads, L)
202
+ query_v = query_v.reshape(N, self._num_heads, self._total_value_depth//self._num_heads, L)
203
+ self_attn_update = self._query_self_attention(query_q, query_k, query_v)
204
+ self_attn_update = self._query_conv3_bn(self_attn_update)
205
+ query_feature = query_feature + self.drop_path_attn(self_attn_update)
206
+ query_feature = F.gelu(query_feature)
207
+
208
+ # FFN.
209
+ ffn_update = self._query_ffn_conv1_bn_act(query_feature)
210
+ ffn_update = self._query_ffn_conv2_bn(ffn_update)
211
+ query_feature = query_feature + self.drop_path_ffn(ffn_update)
212
+ query_feature = F.gelu(query_feature)
213
+
214
+ return query_feature, prediction_result
215
+
216
+
217
+ class ASPP(nn.Module):
218
+ def __init__(self, in_channels, output_channels, atrous_rates):
219
+ super().__init__()
220
+
221
+ self._aspp_conv0 = ConvBN(in_channels, output_channels, kernel_size=1, bias=False,
222
+ norm='syncbn', act='gelu')
223
+
224
+ rate1, rate2, rate3 = atrous_rates
225
+ self._aspp_conv1 = ConvBN(in_channels, output_channels, kernel_size=3, dilation=rate1, padding=rate1, bias=False,
226
+ norm='syncbn', act='gelu')
227
+
228
+ self._aspp_conv2 = ConvBN(in_channels, output_channels, kernel_size=3, dilation=rate2, padding=rate2, bias=False,
229
+ norm='syncbn', act='gelu')
230
+
231
+ self._aspp_conv3 = ConvBN(in_channels, output_channels, kernel_size=3, dilation=rate3, padding=rate3, bias=False,
232
+ norm='syncbn', act='gelu')
233
+
234
+ self._avg_pool = nn.AdaptiveAvgPool2d(1)
235
+ self._aspp_pool = ConvBN(in_channels, output_channels, kernel_size=1, bias=False,
236
+ norm='syncbn', act='gelu')
237
+
238
+ self._proj_conv_bn_act = ConvBN(output_channels * 5, output_channels, kernel_size=1, bias=False,
239
+ norm='syncbn', act='gelu')
240
+ # https://github.com/google-research/deeplab2/blob/main/model/decoder/aspp.py#L249
241
+ self._proj_drop = nn.Dropout(p=0.1)
242
+
243
+ def forward(self, x):
244
+ results = []
245
+ results.append(self._aspp_conv0(x))
246
+ results.append(self._aspp_conv1(x))
247
+ results.append(self._aspp_conv2(x))
248
+ results.append(self._aspp_conv3(x))
249
+ align_corners = (x.shape[-1] % 2 == 1)
250
+ results.append(F.interpolate(self._aspp_pool(self._avg_pool(x)), size=x.shape[-2:], mode='bilinear', align_corners=align_corners))
251
+
252
+ x = torch.cat(results, dim=1)
253
+ x = self._proj_conv_bn_act(x)
254
+ x = self._proj_drop(x)
255
+
256
+ return x
257
+
258
+
259
+ class SemanticPredictor(nn.Module):
260
+ def __init__(self, in_channels, os8_channels, os4_channels, num_classes):
261
+ super().__init__()
262
+
263
+ # Below is PanopticDeepLabSingleDecoder
264
+ self._aspp = ASPP(
265
+ in_channels=in_channels,
266
+ # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_r50_os32.textproto#L35
267
+ output_channels=256,
268
+ # https://github.com/google-research/deeplab2/blob/main/configs/coco/kmax_deeplab/kmax_meta_r50_os32.textproto#L36
269
+ atrous_rates=[6,12,18])
270
+
271
+ self._low_level_projection_os8 = ConvBN(os8_channels, 64, kernel_size=1, bias=False,
272
+ norm='syncbn', act='gelu')
273
+
274
+ self._low_level_fusion_os8_conv0_bn_act = ConvBN(256 + 64, 256 + 64, groups=256 + 64, kernel_size=5, padding=2, bias=False,
275
+ norm='syncbn', act='gelu', conv_init='xavier_uniform')
276
+ self._low_level_fusion_os8_conv1_bn_act = ConvBN(256 + 64, 256, kernel_size=1,bias=False,
277
+ norm='syncbn', act='gelu')
278
+
279
+ self._low_level_projection_os4 = ConvBN(os4_channels, 32, kernel_size=1, bias=False,
280
+ norm='syncbn', act='gelu')
281
+
282
+ self._low_level_fusion_os4_conv0_bn_act = ConvBN(256 + 32, 256 + 32, groups=256 + 32, kernel_size=5, padding=2, bias=False,
283
+ norm='syncbn', act='gelu', conv_init='xavier_uniform')
284
+ self._low_level_fusion_os4_conv1_bn_act = ConvBN(256 + 32, 256, kernel_size=1,bias=False,
285
+ norm='syncbn', act='gelu')
286
+
287
+ # Below is PanopticDeepLabSingleHead
288
+ self.conv_block_0 = ConvBN(256, 256, groups=256, kernel_size=5, padding=2, bias=False,
289
+ norm='syncbn', act='gelu', conv_init='xavier_uniform')
290
+ self.conv_block_1 = ConvBN(256, 256, kernel_size=1,bias=False,
291
+ norm='syncbn', act='gelu')
292
+ self.final_conv = ConvBN(256, num_classes, kernel_size=1, norm=None, act=None)
293
+ trunc_normal_(self.final_conv.conv.weight, std=0.01)
294
+
295
+ def forward(self, x, low_features_os8, low_features_os4):
296
+ x = self._aspp(x)
297
+ align_corners = (x.shape[-1] % 2 == 1)
298
+ low_features_os8 = self._low_level_projection_os8(low_features_os8)
299
+ x = F.interpolate(x, size=low_features_os8.shape[-2:], mode='bilinear', align_corners=align_corners)
300
+ x = torch.concat([x, low_features_os8], dim=1)
301
+ x = self._low_level_fusion_os8_conv0_bn_act(x)
302
+ x = self._low_level_fusion_os8_conv1_bn_act(x)
303
+
304
+ low_features_os4 = self._low_level_projection_os4(low_features_os4)
305
+ x = F.interpolate(x, size=low_features_os4.shape[-2:], mode='bilinear', align_corners=align_corners)
306
+ x = torch.concat([x, low_features_os4], dim=1)
307
+ x = self._low_level_fusion_os4_conv0_bn_act(x)
308
+ x = self._low_level_fusion_os4_conv1_bn_act(x)
309
+
310
+ x = self.conv_block_0(x)
311
+ x = self.conv_block_1(x)
312
+ x = self.final_conv(x)
313
+ return x
314
+
315
+
316
+ @TRANSFORMER_DECODER_REGISTRY.register()
317
+ class kMaXTransformerDecoder(nn.Module):
318
+
319
+ @configurable
320
+ def __init__(
321
+ self,
322
+ *,
323
+ dec_layers: List[int],
324
+ in_channels: List[int],
325
+ num_classes: int,
326
+ num_queries: int,
327
+ drop_path_prob: float,
328
+ add_aux_semantic_pred: bool,
329
+ input_shape_from_backbone,
330
+ ):
331
+ """
332
+ NOTE: this interface is experimental.
333
+ Args:
334
+ """
335
+ super().__init__()
336
+
337
+ # define Transformer decoder here
338
+ self._kmax_transformer_layers = nn.ModuleList()
339
+ self._num_blocks = dec_layers
340
+ os2channels = {32: in_channels[0], 16: in_channels[1], 8: in_channels[2]}
341
+
342
+ for index, output_stride in enumerate([32, 16, 8]):
343
+ for _ in range(self._num_blocks[index]):
344
+ self._kmax_transformer_layers.append(
345
+ kMaXTransformerLayer(num_classes=num_classes+1,
346
+ in_channel_pixel=os2channels[output_stride],
347
+ in_channel_query=256,
348
+ base_filters=128,
349
+ num_heads=8,
350
+ bottleneck_expansion=2,
351
+ key_expansion=1,
352
+ value_expansion=2,
353
+ drop_path_prob=drop_path_prob)
354
+ )
355
+
356
+
357
+ self._num_queries = num_queries
358
+ # learnable query features
359
+ self._cluster_centers = nn.Embedding(256, num_queries)
360
+ trunc_normal_(self._cluster_centers.weight, std=1.0)
361
+
362
+ self._class_embedding_projection = ConvBN(256, 256, kernel_size=1, bias=False, norm='syncbn', act='gelu',
363
+ conv_type='1d')
364
+
365
+ self._mask_embedding_projection = ConvBN(256, 256, kernel_size=1, bias=False, norm='syncbn', act='gelu',
366
+ conv_type='1d')
367
+
368
+ self._predcitor = kMaXPredictor(in_channel_pixel=256,
369
+ in_channel_query=256, num_classes=num_classes+1)
370
+
371
+
372
+ self._add_aux_semantic_pred = add_aux_semantic_pred
373
+ if add_aux_semantic_pred:
374
+ self._auxiliary_semantic_predictor = SemanticPredictor(
375
+ in_channels=input_shape_from_backbone['res5'].channels,
376
+ os8_channels=input_shape_from_backbone['res3'].channels,
377
+ os4_channels=input_shape_from_backbone['res2'].channels,
378
+ # +1 for void.
379
+ num_classes=num_classes+1)
380
+
381
+
382
+ @classmethod
383
+ def from_config(cls, cfg, input_shape_from_backbone):
384
+ ret = {}
385
+ ret["dec_layers"] = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DEC_LAYERS
386
+ ret["in_channels"] = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.IN_CHANNELS
387
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
388
+ ret["num_queries"] = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.NUM_OBJECT_QUERIES
389
+ ret["drop_path_prob"] = cfg.MODEL.KMAX_DEEPLAB.TRANS_DEC.DROP_PATH_PROB
390
+ ret["add_aux_semantic_pred"] = (cfg.MODEL.KMAX_DEEPLAB.AUX_SEMANTIC_WEIGHT > 0)
391
+ ret["input_shape_from_backbone"] = input_shape_from_backbone
392
+ return ret
393
+
394
+
395
+ def forward(self, x, panoptic_features, semantic_features):
396
+ B = x[0].shape[0]
397
+ cluster_centers = self._cluster_centers.weight.unsqueeze(0).repeat(B, 1, 1) # B x C x L
398
+
399
+ current_transformer_idx = 0
400
+
401
+ predictions_class = []
402
+ predictions_mask = []
403
+ predictions_pixel_feature = []
404
+
405
+ for i, feat in enumerate(x):
406
+ for _ in range(self._num_blocks[i]):
407
+ cluster_centers, prediction_result = self._kmax_transformer_layers[current_transformer_idx](
408
+ pixel_feature=feat, query_feature=cluster_centers
409
+ )
410
+ predictions_class.append(prediction_result['class_logits'])
411
+ predictions_mask.append(prediction_result['mask_logits'])
412
+ predictions_pixel_feature.append(prediction_result['pixel_feature'])
413
+ current_transformer_idx += 1
414
+
415
+ class_embeddings = self._class_embedding_projection(cluster_centers)
416
+ mask_embeddings = self._mask_embedding_projection(cluster_centers)
417
+
418
+ # Final predictions.
419
+ prediction_result = self._predcitor(
420
+ class_embeddings=class_embeddings,
421
+ mask_embeddings=mask_embeddings,
422
+ pixel_feature=panoptic_features,
423
+ )
424
+ predictions_class.append(prediction_result['class_logits'])
425
+ predictions_mask.append(prediction_result['mask_logits'])
426
+ predictions_pixel_feature.append(prediction_result['pixel_feature'])
427
+
428
+ out = {
429
+ 'pred_logits': predictions_class[-1],
430
+ 'pred_masks': predictions_mask[-1],
431
+ 'pixel_feature': predictions_pixel_feature[-1],
432
+ 'aux_outputs': self._set_aux_loss(
433
+ predictions_class, predictions_mask, predictions_pixel_feature
434
+ ),
435
+ }
436
+
437
+ if self._add_aux_semantic_pred and self.training:
438
+ semantic_features, low_features_os8, low_features_os4 = semantic_features
439
+ aux_semantic_prediction = self._auxiliary_semantic_predictor(
440
+ x=semantic_features, low_features_os8=low_features_os8, low_features_os4=low_features_os4)
441
+ out.update({'aux_semantic_pred': aux_semantic_prediction,})
442
+ return out
443
+
444
+
445
+ @torch.jit.unused
446
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_pixel_feature):
447
+ target_size = outputs_seg_masks[-1].shape[-2:]
448
+ align_corners = (target_size[0] % 2 == 1)
449
+ return [
450
+ {"pred_logits": a, "pred_masks": F.interpolate(b, size=target_size, mode="bilinear", align_corners=align_corners),
451
+ "pixel_feature": F.interpolate(c, size=target_size, mode="bilinear", align_corners=align_corners),}
452
+ for a, b, c in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_pixel_feature[:-1])
453
+ ]
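The core of kMaXTransformerLayer above is the k-means cross-attention: pixels are hard-assigned to their best cluster (query) via an argmax over the mask logits, and each cluster center is then updated by pooling the pixel values assigned to it. A stripped-down sketch of just that update (not part of the commit; shapes and values are illustrative):

import torch

B, L, HW, D = 1, 3, 6, 4                 # batch, clusters (queries), pixels, channels
mask_logits = torch.randn(B, L, HW)      # N x L x HW, as in the layer above
pixel_value = torch.randn(B, D, HW)      # N x C x HW

with torch.no_grad():
    index = mask_logits.max(1, keepdim=True)[1]                          # winning cluster per pixel
    assignment = torch.zeros_like(mask_logits).scatter_(1, index, 1.0)   # one-hot over clusters

# Pool the values of the pixels assigned to each cluster (the k-means update).
kmeans_update = torch.einsum('blm,bdm->bdl', assignment, pixel_value)    # N x C x L
print(assignment.sum(1))    # all ones: each pixel belongs to exactly one cluster
print(kmeans_update.shape)  # torch.Size([1, 4, 3])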
pakages.txt ADDED
@@ -0,0 +1,4 @@
1
+ libtinfo5
2
+ libsm6
3
+ libxext6
4
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,34 @@
1
+ pyyaml==5.1
2
+ torch==1.9.0
3
+ torchvision==0.10.0
4
+
5
+ docutils==0.16
6
+ # https://github.com/sphinx-doc/sphinx/commit/7acd3ada3f38076af7b2b5c9f3b60bb9c2587a3d
7
+ sphinx==3.2.0
8
+ recommonmark==0.6.0
9
+ sphinx_rtd_theme
10
+ # Dependencies here are only those required by import
11
+ termcolor
12
+ numpy
13
+ tqdm
14
+ matplotlib
15
+ termcolor
16
+ yacs
17
+ tabulate
18
+ cloudpickle
19
+ Pillow
20
+ future
21
+ fvcore
22
+ omegaconf>=2.1.0.dev24
23
+ hydra-core>=1.1.0.dev5
24
+
25
+ opencv-python-headless
26
+
27
+
28
+ cython
29
+ scipy
30
+ shapely
31
+ timm
32
+ h5py
33
+ submitit
34
+ scikit-image
train_net.py ADDED
@@ -0,0 +1,266 @@
1
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/train_net.py
2
+ # Modified by Qihang Yu
3
+
4
+ try:
5
+ # ignore ShapelyDeprecationWarning from fvcore
6
+ from shapely.errors import ShapelyDeprecationWarning
7
+ import warnings
8
+ warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
9
+ except:
10
+ pass
11
+
12
+ import copy
13
+ import itertools
14
+ import os
15
+
16
+ from typing import Any, Dict, List, Set
17
+
18
+ import torch
19
+
20
+ import detectron2.utils.comm as comm
21
+ from detectron2.checkpoint import DetectionCheckpointer
22
+ from detectron2.config import get_cfg
23
+ from detectron2.data import MetadataCatalog, build_detection_train_loader, build_detection_test_loader
24
+ from detectron2.engine import (
25
+ DefaultTrainer,
26
+ default_argument_parser,
27
+ default_setup,
28
+ launch,
29
+ )
30
+ from detectron2.evaluation import (
31
+ COCOEvaluator,
32
+ DatasetEvaluators,
33
+ SemSegEvaluator,
34
+ verify_results,
35
+ )
36
+ from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
37
+ from detectron2.solver.build import maybe_add_gradient_clipping
38
+ from detectron2.utils.logger import setup_logger
39
+
40
+ # MaskFormer
41
+ from kmax_deeplab import (
42
+ COCOPanoptickMaXDeepLabDatasetMapper,
43
+ add_kmax_deeplab_config,
44
+ )
45
+
46
+ from detectron2.data import MetadataCatalog
47
+
48
+ import train_net_utils
49
+
50
+
51
+ class Trainer(DefaultTrainer):
52
+ """
53
+ Extension of the Trainer class adapted to MaskFormer.
54
+ """
55
+
56
+ @classmethod
57
+ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
58
+ """
59
+ Create evaluator(s) for a given dataset.
60
+ This uses the special metadata "evaluator_type" associated with each
61
+ builtin dataset. For your own dataset, you can simply create an
62
+ evaluator manually in your script and do not have to worry about the
63
+ hacky if-else logic here.
64
+ """
65
+ if output_folder is None:
66
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
67
+ evaluator_list = []
68
+ evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
69
+ # panoptic segmentation
70
+ if evaluator_type in [
71
+ "coco_panoptic_seg",
72
+ ]:
73
+ if cfg.MODEL.KMAX_DEEPLAB.TEST.PANOPTIC_ON:
74
+ evaluator_list.append(train_net_utils.COCOPanopticEvaluatorwithVis(dataset_name, output_folder, save_vis_num=cfg.MODEL.KMAX_DEEPLAB.SAVE_VIS_NUM))
75
+ # COCO
76
+ if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.KMAX_DEEPLAB.TEST.INSTANCE_ON:
77
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
78
+ if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.KMAX_DEEPLAB.TEST.SEMANTIC_ON:
79
+ evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
80
+ elif len(evaluator_list) == 1:
81
+ return evaluator_list[0]
82
+ return DatasetEvaluators(evaluator_list)
83
+
84
+ @classmethod
85
+ def build_train_loader(cls, cfg):
86
+ # Semantic segmentation dataset mapper
87
+ if cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj":
88
+ mapper = COCOPanoptickMaXDeepLabDatasetMapper(cfg, True)
89
+ return build_detection_train_loader(cfg, mapper=mapper)
90
+ else:
91
+ mapper = None
92
+ return build_detection_train_loader(cfg, mapper=mapper)
93
+
94
+
95
+ @classmethod
96
+ def build_lr_scheduler(cls, cfg, optimizer):
97
+ """
98
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
99
+ Overwrite it if you'd like a different scheduler.
100
+ """
101
+ name = cfg.SOLVER.LR_SCHEDULER_NAME
102
+ if name == "TF2WarmupPolyLR":
103
+ return train_net_utils.TF2WarmupPolyLR(
104
+ optimizer,
105
+ cfg.SOLVER.MAX_ITER,
106
+ warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
107
+ warmup_iters=cfg.SOLVER.WARMUP_ITERS,
108
+ warmup_method=cfg.SOLVER.WARMUP_METHOD,
109
+ power=cfg.SOLVER.POLY_LR_POWER,
110
+ constant_ending=cfg.SOLVER.POLY_LR_CONSTANT_ENDING,
111
+ )
112
+ else:
113
+ return build_lr_scheduler(cfg, optimizer)
114
+
115
+ @classmethod
116
+ def build_optimizer(cls, cfg, model):
117
+ weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
118
+ weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED
119
+
120
+ defaults = {}
121
+ defaults["lr"] = cfg.SOLVER.BASE_LR
122
+ defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY
123
+
124
+ from kmax_deeplab.modeling.backbone.convnext import LayerNorm
125
+
126
+ norm_module_types = (
127
+ torch.nn.BatchNorm1d,
128
+ torch.nn.BatchNorm2d,
129
+ torch.nn.BatchNorm3d,
130
+ torch.nn.SyncBatchNorm,
131
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
132
+ torch.nn.GroupNorm,
133
+ torch.nn.InstanceNorm1d,
134
+ torch.nn.InstanceNorm2d,
135
+ torch.nn.InstanceNorm3d,
136
+ torch.nn.LayerNorm,
137
+ torch.nn.LocalResponseNorm,
138
+ LayerNorm
139
+ )
140
+
141
+ params: List[Dict[str, Any]] = []
142
+ memo: Set[torch.nn.parameter.Parameter] = set()
143
+ for module_name, module in model.named_modules():
144
+ for module_param_name, value in module.named_parameters(recurse=False):
145
+ if not value.requires_grad:
146
+ continue
147
+ # Avoid duplicating parameters
148
+ if value in memo:
149
+ continue
150
+ memo.add(value)
151
+
152
+ hyperparams = copy.copy(defaults)
153
+ hyperparams["name"] = (module_name, module_param_name)
154
+ if "backbone" in module_name:
155
+ hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
156
+ if (
157
+ "relative_position_bias_table" in module_param_name
158
+ or "absolute_pos_embed" in module_param_name
159
+ ):
160
+ print(module_param_name)
161
+ hyperparams["weight_decay"] = 0.0
162
+ if isinstance(module, norm_module_types):
163
+ hyperparams["weight_decay"] = weight_decay_norm
164
+ if isinstance(module, torch.nn.Embedding):
165
+ hyperparams["weight_decay"] = weight_decay_embed
166
+ # Rule for kMaX.
167
+ if "_rpe" in module_name:
168
+ # relative positional embedding in axial attention.
169
+ hyperparams["weight_decay"] = 0.0
170
+ if "_cluster_centers" in module_name:
171
+ # cluster center embeddings.
172
+ hyperparams["weight_decay"] = 0.0
173
+ if "bias" in module_param_name:
174
+ # any bias terms.
175
+ hyperparams["weight_decay"] = 0.0
176
+ if "gamma" in module_param_name:
177
+ # gamma term in convnext
178
+ hyperparams["weight_decay"] = 0.0
179
+
180
+ params.append({"params": [value], **hyperparams})
181
+ for param_ in params:
182
+ print(param_["name"], param_["lr"], param_["weight_decay"])
183
+
184
+ def maybe_add_full_model_gradient_clipping(optim):
185
+ # detectron2 doesn't have full model gradient clipping now
186
+ clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
187
+ enable = (
188
+ cfg.SOLVER.CLIP_GRADIENTS.ENABLED
189
+ and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
190
+ and clip_norm_val > 0.0
191
+ )
192
+
193
+ class FullModelGradientClippingOptimizer(optim):
194
+ def step(self, closure=None):
195
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
196
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
197
+ super().step(closure=closure)
198
+
199
+ return FullModelGradientClippingOptimizer if enable else optim
200
+
201
+ optimizer_type = cfg.SOLVER.OPTIMIZER
202
+ if optimizer_type == "SGD":
203
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
204
+ params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
205
+ )
206
+ elif optimizer_type == "ADAMW":
207
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
208
+ params, cfg.SOLVER.BASE_LR
209
+ )
210
+ elif optimizer_type == "ADAM":
211
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.Adam)(
212
+ params, cfg.SOLVER.BASE_LR
213
+ )
214
+ else:
215
+ raise NotImplementedError(f"no optimizer type {optimizer_type}")
216
+ if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
217
+ optimizer = maybe_add_gradient_clipping(cfg, optimizer)
218
+ return optimizer
219
+
220
+
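Editor's note: a quick sanity-check sketch for the per-parameter rules above (not part of the repo). It assumes cfg and model have already been built, e.g. cfg = setup(args) and model = Trainer.build_model(cfg), and relies on torch optimizers preserving the extra "name" key stored in each param group:

    optimizer = Trainer.build_optimizer(cfg, model)
    no_decay = [g["name"] for g in optimizer.param_groups if g["weight_decay"] == 0.0]
    print(f"{len(no_decay)}/{len(optimizer.param_groups)} param groups use zero weight decay")
    # Expect _rpe, _cluster_centers, bias, and ConvNeXt gamma parameters to show up here.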
+ def setup(args):
+     """
+     Create configs and perform basic setups.
+     """
+     cfg = get_cfg()
+     # for poly lr schedule
+     add_deeplab_config(cfg)
+     add_kmax_deeplab_config(cfg)
+     cfg.merge_from_file(args.config_file)
+     cfg.merge_from_list(args.opts)
+     cfg.freeze()
+     default_setup(cfg, args)
+     setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="kmax_deeplab")
+     return cfg
+
+
+ def main(args):
+     cfg = setup(args)
+
+     torch.backends.cudnn.enabled = True
+     if args.eval_only:
+         model = Trainer.build_model(cfg)
+         DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
+             cfg.MODEL.WEIGHTS, resume=args.resume
+         )
+         res = Trainer.test(cfg, model)
+         if comm.is_main_process():
+             verify_results(cfg, res)
+         return res
+
+     trainer = Trainer(cfg)
+     trainer.resume_or_load(resume=args.resume)
+     return trainer.train()
+
+
+ if __name__ == "__main__":
+     args = default_argument_parser().parse_args()
+     print("Command Line Args:", args)
+     launch(
+         main,
+         args.num_gpus,
+         num_machines=args.num_machines,
+         machine_rank=args.machine_rank,
+         dist_url=args.dist_url,
+         args=(args,),
+     )
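Editor's note: for reference, the eval-only path can also be driven programmatically in a single process. The flags below come from detectron2's default_argument_parser, the config file ships with this commit, and the checkpoint path is a placeholder; this is a hedged sketch roughly equivalent to a single-GPU --eval-only run, not the repo's documented entry point:

    from detectron2.engine import default_argument_parser

    args = default_argument_parser().parse_args([
        "--config-file", "configs/coco/panoptic-segmentation/kmax_r50.yaml",
        "--eval-only",
        "MODEL.WEIGHTS", "/path/to/kmax_r50_checkpoint.pth",  # placeholder path
    ])
    main(args)  # bypasses launch(); fine for a single, non-distributed process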
train_net_utils.py ADDED
@@ -0,0 +1,225 @@
+ import itertools
+ import os
+
+ from typing import List, Optional
+
+ import torch
+ import numpy as np
+ import tempfile
+ from collections import OrderedDict
+ from PIL import Image
+ from tabulate import tabulate
+ import json
+ import contextlib
+
+ import detectron2.utils.comm as comm
+ from detectron2.utils.file_io import PathManager
+ from detectron2.data import MetadataCatalog
+ from detectron2.evaluation import COCOPanopticEvaluator
+
+ from detectron2.utils.visualizer import ColorMode, Visualizer
+ from detectron2.data import MetadataCatalog
+ import io
+ import math
+ from PIL import Image
+
+ from detectron2.solver.lr_scheduler import _get_warmup_factor_at_iter
+
+
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class TF2WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
+     """
+     Poly learning rate schedule used in TF DeepLab2.
+     Reference: https://github.com/google-research/deeplab2/blob/main/trainer/trainer_utils.py#L23
+     """
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         max_iters: int,
+         warmup_factor: float = 0.001,
+         warmup_iters: int = 1000,
+         warmup_method: str = "linear",
+         last_epoch: int = -1,
+         power: float = 0.9,
+         constant_ending: float = 0.0,
+     ):
+         self.max_iters = max_iters
+         self.warmup_factor = warmup_factor
+         self.warmup_iters = warmup_iters
+         self.warmup_method = warmup_method
+         self.power = power
+         self.constant_ending = constant_ending
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self) -> List[float]:
+         warmup_factor = _get_warmup_factor_at_iter(
+             self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
+         )
+         if self.constant_ending > 0 and warmup_factor == 1.0:
+             # Constant ending lr.
+             if (
+                 math.pow((1.0 - self.last_epoch / self.max_iters), self.power)
+                 < self.constant_ending
+             ):
+                 return [base_lr * self.constant_ending for base_lr in self.base_lrs]
+         if self.last_epoch < self.warmup_iters:
+             return [
+                 base_lr * warmup_factor
+                 for base_lr in self.base_lrs
+             ]
+         else:
+             return [
+                 base_lr * math.pow((1.0 - self.last_epoch / self.max_iters), self.power)
+                 for base_lr in self.base_lrs
+             ]
+
+     def _compute_values(self) -> List[float]:
+         # The new interface
+         return self.get_lr()
+
+
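Editor's note: a small standalone sketch (not from the repo) of how this schedule behaves: a linear warmup over warmup_iters steps, then lr = base_lr * (1 - iter / max_iters) ** power. It assumes train_net_utils is importable and simply attaches the scheduler to a dummy SGD optimizer:

    import torch
    from train_net_utils import TF2WarmupPolyLR

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.001)
    sched = TF2WarmupPolyLR(opt, max_iters=1000, warmup_iters=100, power=0.9)
    for it in range(1000):
        opt.step()
        sched.step()
        if it in (0, 99, 499, 999):
            print(it, sched.get_last_lr()[0])  # ramps up toward 1e-3, then decays polynomially to 0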
+ class COCOPanopticEvaluatorwithVis(COCOPanopticEvaluator):
+     """
+     COCO Panoptic Evaluator that supports saving visualizations.
+     TODO(qihangyu): Note that original implementation will also write all predictions to a tmp folder
+     and then run official evaluation script, we may also check how to copy from the tmp folder for visualization.
+     """
+
+     def __init__(self, dataset_name: str, output_dir: Optional[str] = None, save_vis_num=0):
+         super().__init__(dataset_name=dataset_name, output_dir=output_dir)
+         self.metadata = MetadataCatalog.get("coco_2017_val_panoptic_with_sem_seg")
+         self.output_dir = output_dir
+         self.save_vis_num = save_vis_num
+
+     def process(self, inputs, outputs):
+         from panopticapi.utils import id2rgb
+
+         cur_save_num = 0
+         for input, output in zip(inputs, outputs):
+             panoptic_img, segments_info = output["panoptic_seg"]
+             panoptic_seg = panoptic_img.cpu()
+             panoptic_img = panoptic_seg.numpy()
+
+             file_name = os.path.basename(input["file_name"])
+             file_name_png = os.path.splitext(file_name)[0] + ".png"
+             if cur_save_num < self.save_vis_num:
+                 image = output["original_image"]
+                 image = image.permute(1, 2, 0).cpu().numpy()  # [:, :, ::-1]
+                 visualizer = Visualizer(image, self.metadata, instance_mode=ColorMode.IMAGE)
+                 vis_output = visualizer.draw_panoptic_seg_predictions(
+                     panoptic_seg, segments_info
+                 )
+                 if not os.path.exists(os.path.join(self.output_dir, 'vis')):
+                     os.makedirs(os.path.join(self.output_dir, 'vis'))
+                 out_filename = os.path.join(self.output_dir, 'vis', file_name_png)
+                 vis_output.save(out_filename)
+                 cur_save_num += 1
+
+             if segments_info is None:
+                 # If "segments_info" is None, we assume "panoptic_img" is a
+                 # H*W int32 image storing the panoptic_id in the format of
+                 # category_id * label_divisor + instance_id. We reserve -1 for
+                 # VOID label, and add 1 to panoptic_img since the official
+                 # evaluation script uses 0 for VOID label.
+                 label_divisor = self._metadata.label_divisor
+                 segments_info = []
+                 for panoptic_label in np.unique(panoptic_img):
+                     if panoptic_label == -1:
+                         # VOID region.
+                         continue
+                     pred_class = panoptic_label // label_divisor
+                     isthing = (
+                         pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
+                     )
+                     segments_info.append(
+                         {
+                             "id": int(panoptic_label) + 1,
+                             "category_id": int(pred_class),
+                             "isthing": bool(isthing),
+                         }
+                     )
+                 # Official evaluation script uses 0 for VOID label.
+                 panoptic_img += 1
+
+
+             with io.BytesIO() as out:
+                 Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
+                 segments_info = [self._convert_category_id(x) for x in segments_info]
+                 self._predictions.append(
+                     {
+                         "image_id": input["image_id"],
+                         "file_name": file_name_png,
+                         "png_string": out.getvalue(),
+                         "segments_info": segments_info,
+                     }
+                 )
+
+     def evaluate(self):
+         comm.synchronize()
+
+         self._predictions = comm.gather(self._predictions)
+         self._predictions = list(itertools.chain(*self._predictions))
+         if not comm.is_main_process():
+             return
+
+         # PanopticApi requires local files
+         gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
+         gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)
+
+         with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
+             logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
+             for p in self._predictions:
+                 with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
+                     f.write(p.pop("png_string"))
+
+             with open(gt_json, "r") as f:
+                 json_data = json.load(f)
+             json_data["annotations"] = self._predictions
+
+             output_dir = self._output_dir or pred_dir
+             predictions_json = os.path.join(output_dir, "predictions.json")
+             with PathManager.open(predictions_json, "w") as f:
+                 f.write(json.dumps(json_data))
+
+             from kmax_deeplab.evaluation.panoptic_evaluation import pq_compute
+
+             with contextlib.redirect_stdout(io.StringIO()):
+                 pq_res = pq_compute(
+                     gt_json,
+                     PathManager.get_local_path(predictions_json),
+                     gt_folder=gt_folder,
+                     pred_folder=pred_dir,
+                 )
+
+         res = {}
+         res["PQ"] = 100 * pq_res["All"]["pq"]
+         res["SQ"] = 100 * pq_res["All"]["sq"]
+         res["RQ"] = 100 * pq_res["All"]["rq"]
+         res["PQ_th"] = 100 * pq_res["Things"]["pq"]
+         res["SQ_th"] = 100 * pq_res["Things"]["sq"]
+         res["RQ_th"] = 100 * pq_res["Things"]["rq"]
+         res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
+         res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
+         res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]
+
+         results = OrderedDict({"panoptic_seg": res})
+         _print_panoptic_results(pq_res)
+
+         return results
+
+
+ def _print_panoptic_results(pq_res):
+     headers = ["", "PQ", "SQ", "RQ", "#categories"]
+     data = []
+     for name in ["All", "Things", "Stuff"]:
+         row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
+         data.append(row)
+     table = tabulate(
+         data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
+     )
+     logger.info("Panoptic Evaluation Results:\n" + table)
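Editor's note: a hedged end-to-end usage sketch for this evaluator (not part of the repo). It assumes cfg and a trained kMaX-DeepLab model are already built (for example via Trainer.build_model plus DetectionCheckpointer) and that the model returns "original_image" in its outputs, which the visualization branch above relies on:

    import os
    from detectron2.data import build_detection_test_loader
    from detectron2.evaluation import inference_on_dataset
    from train_net_utils import COCOPanopticEvaluatorwithVis

    os.makedirs("./output/inference", exist_ok=True)
    evaluator = COCOPanopticEvaluatorwithVis(
        "coco_2017_val_panoptic_with_sem_seg", output_dir="./output/inference", save_vis_num=8
    )
    val_loader = build_detection_test_loader(cfg, "coco_2017_val_panoptic_with_sem_seg")
    results = inference_on_dataset(model, val_loader, evaluator)
    print(results["panoptic_seg"]["PQ"])  # PQ/SQ/RQ are reported on a 0-100 scale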