Spaces:
Runtime error
Runtime error
seokju cho
commited on
Commit
·
f8f62f3
1
Parent(s):
7722584
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- INSTALL.md +20 -0
- R-101.pkl +3 -0
- README.md +48 -12
- app.py +130 -0
- assets/fig1.png +0 -0
- cat_seg/__init__.py +19 -0
- cat_seg/__pycache__/__init__.cpython-38.pyc +0 -0
- cat_seg/__pycache__/cat_sam_model.cpython-38.pyc +0 -0
- cat_seg/__pycache__/cat_seg_model.cpython-38.pyc +0 -0
- cat_seg/__pycache__/cat_seg_panoptic.cpython-38.pyc +0 -0
- cat_seg/__pycache__/config.cpython-38.pyc +0 -0
- cat_seg/__pycache__/pancat_model.cpython-38.pyc +0 -0
- cat_seg/__pycache__/test_time_augmentation.cpython-38.pyc +0 -0
- cat_seg/cat_seg_model.py +386 -0
- cat_seg/config.py +93 -0
- cat_seg/data/__init__.py +2 -0
- cat_seg/data/__pycache__/__init__.cpython-38.pyc +0 -0
- cat_seg/data/dataset_mappers/__init__.py +1 -0
- cat_seg/data/dataset_mappers/__pycache__/__init__.cpython-38.pyc +0 -0
- cat_seg/data/dataset_mappers/__pycache__/detr_panoptic_dataset_mapper.cpython-38.pyc +0 -0
- cat_seg/data/dataset_mappers/__pycache__/mask_former_panoptic_dataset_mapper.cpython-38.pyc +0 -0
- cat_seg/data/dataset_mappers/__pycache__/mask_former_semantic_dataset_mapper.cpython-38.pyc +0 -0
- cat_seg/data/dataset_mappers/detr_panoptic_dataset_mapper.py +180 -0
- cat_seg/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py +165 -0
- cat_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py +186 -0
- cat_seg/data/datasets/__init__.py +8 -0
- cat_seg/data/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- cat_seg/data/datasets/__pycache__/register_ade20k_150.cpython-38.pyc +0 -0
- cat_seg/data/datasets/__pycache__/register_ade20k_847.cpython-38.pyc +0 -0
- cat_seg/data/datasets/__pycache__/register_ade_panoptic.cpython-38.pyc +0 -0
- cat_seg/data/datasets/__pycache__/register_coco_panoptic.cpython-38.pyc +0 -0
- cat_seg/data/datasets/__pycache__/register_coco_stuff.cpython-38.pyc +0 -0
- cat_seg/data/datasets/__pycache__/register_pascal_20.cpython-38.pyc +0 -0
- cat_seg/data/datasets/__pycache__/register_pascal_59.cpython-38.pyc +0 -0
- cat_seg/data/datasets/__pycache__/register_pascal_context.cpython-38.pyc +0 -0
- cat_seg/data/datasets/register_ade20k_150.py +28 -0
- cat_seg/data/datasets/register_ade20k_847.py +0 -0
- cat_seg/data/datasets/register_coco_stuff.py +216 -0
- cat_seg/data/datasets/register_pascal_20.py +53 -0
- cat_seg/data/datasets/register_pascal_59.py +81 -0
- cat_seg/modeling/__init__.py +3 -0
- cat_seg/modeling/__pycache__/__init__.cpython-38.pyc +0 -0
- cat_seg/modeling/__pycache__/criterion.cpython-38.pyc +0 -0
- cat_seg/modeling/__pycache__/matcher.cpython-38.pyc +0 -0
- cat_seg/modeling/backbone/__init__.py +1 -0
- cat_seg/modeling/backbone/__pycache__/__init__.cpython-38.pyc +0 -0
- cat_seg/modeling/backbone/__pycache__/image_encoder.cpython-38.pyc +0 -0
- cat_seg/modeling/backbone/__pycache__/swin.cpython-38.pyc +0 -0
- cat_seg/modeling/backbone/swin.py +768 -0
- cat_seg/modeling/heads/__init__.py +1 -0
INSTALL.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Installation
|
2 |
+
|
3 |
+
### Requirements
|
4 |
+
- Linux or macOS with Python ≥ 3.6
|
5 |
+
- PyTorch ≥ 1.7 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
|
6 |
+
Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note, please check
|
7 |
+
PyTorch version matches that is required by Detectron2.
|
8 |
+
- Detectron2: follow [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html).
|
9 |
+
- OpenCV is optional but needed by demo and visualization
|
10 |
+
- `pip install -r requirements.txt`
|
11 |
+
|
12 |
+
An example of installation is shown below:
|
13 |
+
|
14 |
+
```
|
15 |
+
git clone https://github.com/~~~/CAT-Seg.git
|
16 |
+
cd CAT-Seg
|
17 |
+
conda create -n catseg python=3.8
|
18 |
+
conda activate catseg
|
19 |
+
pip install -r requirements.txt
|
20 |
+
```
|
R-101.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1156c77bff95ecb027060b5c83391b45bf159acd7f5bf7eacb656be0c1f0ab55
|
3 |
+
size 178666803
|
README.md
CHANGED
@@ -1,12 +1,48 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CAT-Seg🐱: Cost Aggregation for Open-Vocabulary Semantic Segmentation
|
2 |
+
|
3 |
+
This is our official implementation of CAT-Seg🐱!
|
4 |
+
|
5 |
+
[[arXiv](#)] [[Project](#)]<br>
|
6 |
+
by [Seokju Cho](https://seokju-cho.github.io/)\*, [Heeseong Shin](https://github.com/hsshin98)\*, [Sunghwan Hong](https://sunghwanhong.github.io), Seungjun An, Seungjun Lee, [Anurag Arnab](https://anuragarnab.github.io), [Paul Hongsuck Seo](https://phseo.github.io), [Seungryong Kim](https://cvlab.korea.ac.kr)
|
7 |
+
|
8 |
+
|
9 |
+
## Introduction
|
10 |
+
![](assets/fig1.png)
|
11 |
+
We introduce cost aggregation to open-vocabulary semantic segmentation, which jointly aggregates both image and text modalities within the matching cost.
|
12 |
+
|
13 |
+
## Installation
|
14 |
+
Install required packages.
|
15 |
+
|
16 |
+
```bash
|
17 |
+
conda create --name catseg python=3.8
|
18 |
+
conda activate catseg
|
19 |
+
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
|
20 |
+
pip install -r requirements.txt
|
21 |
+
```
|
22 |
+
|
23 |
+
## Data Preparation
|
24 |
+
|
25 |
+
|
26 |
+
## Training
|
27 |
+
### Preparation
|
28 |
+
you have to blah
|
29 |
+
### Training script
|
30 |
+
```bash
|
31 |
+
python train.py --config configs/eval/{a847 | pc459 | a150 | pc59 | pas20 | pas20b}.yaml
|
32 |
+
```
|
33 |
+
|
34 |
+
## Evaluation
|
35 |
+
```bash
|
36 |
+
python eval.py --config configs/eval/{a847 | pc459 | a150 | pc59 | pas20 | pas20b}.yaml
|
37 |
+
```
|
38 |
+
|
39 |
+
## Citing CAT-Seg🐱 :pray:
|
40 |
+
|
41 |
+
```BibTeX
|
42 |
+
@article{liang2022open,
|
43 |
+
title={Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP},
|
44 |
+
author={Liang, Feng and Wu, Bichen and Dai, Xiaoliang and Li, Kunpeng and Zhao, Yinan and Zhang, Hang and Zhang, Peizhao and Vajda, Peter and Marculescu, Diana},
|
45 |
+
journal={arXiv preprint arXiv:2210.04150},
|
46 |
+
year={2022}
|
47 |
+
}
|
48 |
+
```
|
app.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
|
3 |
+
import argparse
|
4 |
+
import glob
|
5 |
+
import multiprocessing as mp
|
6 |
+
import os
|
7 |
+
#os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
8 |
+
try:
|
9 |
+
import detectron2
|
10 |
+
except ModuleNotFoundError:
|
11 |
+
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
|
12 |
+
|
13 |
+
try:
|
14 |
+
import segment_anything
|
15 |
+
except ModuleNotFoundError:
|
16 |
+
os.system('pip install git+https://github.com/facebookresearch/segment-anything.git')
|
17 |
+
|
18 |
+
# fmt: off
|
19 |
+
import sys
|
20 |
+
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
21 |
+
# fmt: on
|
22 |
+
|
23 |
+
import tempfile
|
24 |
+
import time
|
25 |
+
import warnings
|
26 |
+
|
27 |
+
import cv2
|
28 |
+
import numpy as np
|
29 |
+
import tqdm
|
30 |
+
|
31 |
+
from detectron2.config import get_cfg
|
32 |
+
from detectron2.data.detection_utils import read_image
|
33 |
+
from detectron2.projects.deeplab import add_deeplab_config
|
34 |
+
from detectron2.utils.logger import setup_logger
|
35 |
+
|
36 |
+
from cat_seg import add_cat_seg_config
|
37 |
+
from demo.predictor import VisualizationDemo
|
38 |
+
import gradio as gr
|
39 |
+
import torch
|
40 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
|
41 |
+
|
42 |
+
# constants
|
43 |
+
WINDOW_NAME = "MaskFormer demo"
|
44 |
+
|
45 |
+
|
46 |
+
def setup_cfg(args):
|
47 |
+
# load config from file and command-line arguments
|
48 |
+
cfg = get_cfg()
|
49 |
+
add_deeplab_config(cfg)
|
50 |
+
add_cat_seg_config(cfg)
|
51 |
+
cfg.merge_from_file(args.config_file)
|
52 |
+
cfg.merge_from_list(args.opts)
|
53 |
+
if torch.cuda.is_available():
|
54 |
+
cfg.MODEL.DEVICE = "cuda"
|
55 |
+
cfg.freeze()
|
56 |
+
return cfg
|
57 |
+
|
58 |
+
|
59 |
+
def get_parser():
|
60 |
+
parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
|
61 |
+
parser.add_argument(
|
62 |
+
"--config-file",
|
63 |
+
default="configs/vitl_swinb_384.yaml",
|
64 |
+
metavar="FILE",
|
65 |
+
help="path to config file",
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--input",
|
69 |
+
nargs="+",
|
70 |
+
help="A list of space separated input images; "
|
71 |
+
"or a single glob pattern such as 'directory/*.jpg'",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--opts",
|
75 |
+
help="Modify config options using the command-line 'KEY VALUE' pairs",
|
76 |
+
default=(
|
77 |
+
[
|
78 |
+
"MODEL.WEIGHTS", "model_final_cls.pth",
|
79 |
+
"MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
|
80 |
+
"MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
|
81 |
+
"TEST.SLIDING_WINDOW", "True",
|
82 |
+
"MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
|
83 |
+
"MODEL.PROMPT_ENSEMBLE_TYPE", "single",
|
84 |
+
"MODEL.DEVICE", "cpu",
|
85 |
+
]),
|
86 |
+
nargs=argparse.REMAINDER,
|
87 |
+
)
|
88 |
+
return parser
|
89 |
+
|
90 |
+
def save_masks(preds, text):
|
91 |
+
preds = preds['sem_seg'].argmax(dim=0).cpu().numpy() # C H W
|
92 |
+
for i, t in enumerate(text):
|
93 |
+
dir = f"mask_{t}.png"
|
94 |
+
mask = preds == i
|
95 |
+
cv2.imwrite(dir, mask * 255)
|
96 |
+
|
97 |
+
def predict(image, text, model_type):
|
98 |
+
#import pdb; pdb.set_trace()
|
99 |
+
#use_sam = True #
|
100 |
+
use_sam = model_type != "CAT-Seg"
|
101 |
+
|
102 |
+
predictions, visualized_output = demo.run_on_image(image, text, use_sam)
|
103 |
+
#save_masks(predictions, text.split(','))
|
104 |
+
canvas = fc(visualized_output.fig)
|
105 |
+
canvas.draw()
|
106 |
+
out = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(canvas.get_width_height()[::-1] + (3,))
|
107 |
+
|
108 |
+
return out[..., ::-1]
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
args = get_parser().parse_args()
|
112 |
+
cfg = setup_cfg(args)
|
113 |
+
global demo
|
114 |
+
demo = VisualizationDemo(cfg)
|
115 |
+
|
116 |
+
iface = gr.Interface(
|
117 |
+
fn=predict,
|
118 |
+
inputs=[gr.Image(), gr.Textbox(placeholder='background, cat, person'), ], #gr.Radio(["CAT-Seg", "Segment Anycat"], value="CAT-Seg")],
|
119 |
+
outputs="image",
|
120 |
+
description="""## Segment Anything with CAT-Seg!
|
121 |
+
Welcome to the Segment Anything with CAT-Seg!
|
122 |
+
|
123 |
+
In this demo, we combine state-of-the-art open-vocabulary semantic segmentation model, CAT-Seg with SAM(Segment Anything) for semantically labelling mask predictions from SAM.
|
124 |
+
|
125 |
+
Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model.
|
126 |
+
|
127 |
+
Also, the demo might run on a CPU depending on the demand, so it may take a little time to process your image.
|
128 |
+
|
129 |
+
To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""")
|
130 |
+
iface.launch()
|
assets/fig1.png
ADDED
cat_seg/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from . import data # register all new datasets
|
3 |
+
from . import modeling
|
4 |
+
|
5 |
+
# config
|
6 |
+
from .config import add_cat_seg_config
|
7 |
+
|
8 |
+
# dataset loading
|
9 |
+
from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper
|
10 |
+
from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import (
|
11 |
+
MaskFormerPanopticDatasetMapper,
|
12 |
+
)
|
13 |
+
from .data.dataset_mappers.mask_former_semantic_dataset_mapper import (
|
14 |
+
MaskFormerSemanticDatasetMapper,
|
15 |
+
)
|
16 |
+
|
17 |
+
# models
|
18 |
+
from .cat_seg_model import CATSeg
|
19 |
+
from .test_time_augmentation import SemanticSegmentorWithTTA
|
cat_seg/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (693 Bytes). View file
|
|
cat_seg/__pycache__/cat_sam_model.cpython-38.pyc
ADDED
Binary file (13.7 kB). View file
|
|
cat_seg/__pycache__/cat_seg_model.cpython-38.pyc
ADDED
Binary file (12.6 kB). View file
|
|
cat_seg/__pycache__/cat_seg_panoptic.cpython-38.pyc
ADDED
Binary file (10 kB). View file
|
|
cat_seg/__pycache__/config.cpython-38.pyc
ADDED
Binary file (2.39 kB). View file
|
|
cat_seg/__pycache__/pancat_model.cpython-38.pyc
ADDED
Binary file (11.4 kB). View file
|
|
cat_seg/__pycache__/test_time_augmentation.cpython-38.pyc
ADDED
Binary file (4.41 kB). View file
|
|
cat_seg/cat_seg_model.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from detectron2.config import configurable
|
9 |
+
from detectron2.data import MetadataCatalog
|
10 |
+
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
|
11 |
+
from detectron2.modeling.backbone import Backbone
|
12 |
+
from detectron2.modeling.postprocessing import sem_seg_postprocess
|
13 |
+
from detectron2.structures import ImageList
|
14 |
+
from detectron2.utils.memory import _ignore_torch_cuda_oom
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
from einops import rearrange
|
18 |
+
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
|
19 |
+
|
20 |
+
@META_ARCH_REGISTRY.register()
|
21 |
+
class CATSeg(nn.Module):
|
22 |
+
@configurable
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
*,
|
26 |
+
backbone: Backbone,
|
27 |
+
sem_seg_head: nn.Module,
|
28 |
+
size_divisibility: int,
|
29 |
+
pixel_mean: Tuple[float],
|
30 |
+
pixel_std: Tuple[float],
|
31 |
+
clip_pixel_mean: Tuple[float],
|
32 |
+
clip_pixel_std: Tuple[float],
|
33 |
+
train_class_json: str,
|
34 |
+
test_class_json: str,
|
35 |
+
sliding_window: bool,
|
36 |
+
clip_finetune: str,
|
37 |
+
backbone_multiplier: float,
|
38 |
+
clip_pretrained: str,
|
39 |
+
):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
backbone: a backbone module, must follow detectron2's backbone interface
|
43 |
+
sem_seg_head: a module that predicts semantic segmentation from backbone features
|
44 |
+
"""
|
45 |
+
super().__init__()
|
46 |
+
self.backbone = backbone
|
47 |
+
self.sem_seg_head = sem_seg_head
|
48 |
+
if size_divisibility < 0:
|
49 |
+
size_divisibility = self.backbone.size_divisibility
|
50 |
+
self.size_divisibility = size_divisibility
|
51 |
+
|
52 |
+
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
53 |
+
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
54 |
+
self.register_buffer("clip_pixel_mean", torch.Tensor(clip_pixel_mean).view(-1, 1, 1), False)
|
55 |
+
self.register_buffer("clip_pixel_std", torch.Tensor(clip_pixel_std).view(-1, 1, 1), False)
|
56 |
+
|
57 |
+
self.train_class_json = train_class_json
|
58 |
+
self.test_class_json = test_class_json
|
59 |
+
|
60 |
+
self.clip_finetune = clip_finetune
|
61 |
+
for name, params in self.sem_seg_head.predictor.clip_model.named_parameters():
|
62 |
+
if "visual" in name:
|
63 |
+
if clip_finetune == "prompt":
|
64 |
+
params.requires_grad = True if "prompt" in name else False
|
65 |
+
elif clip_finetune == "attention":
|
66 |
+
params.requires_grad = True if "attn" in name or "position" in name else False
|
67 |
+
elif clip_finetune == "full":
|
68 |
+
params.requires_grad = True
|
69 |
+
else:
|
70 |
+
params.requires_grad = False
|
71 |
+
else:
|
72 |
+
params.requires_grad = False
|
73 |
+
|
74 |
+
finetune_backbone = backbone_multiplier > 0.
|
75 |
+
for name, params in self.backbone.named_parameters():
|
76 |
+
if "norm0" in name:
|
77 |
+
params.requires_grad = False
|
78 |
+
else:
|
79 |
+
params.requires_grad = finetune_backbone
|
80 |
+
|
81 |
+
self.sliding_window = sliding_window
|
82 |
+
self.clip_resolution = (384, 384) if clip_pretrained == "ViT-B/16" else (336, 336)
|
83 |
+
self.sequential = False
|
84 |
+
|
85 |
+
self.use_sam = False
|
86 |
+
self.sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(self.device)
|
87 |
+
|
88 |
+
amg_kwargs = {
|
89 |
+
"points_per_side": 32,
|
90 |
+
"points_per_batch": None,
|
91 |
+
#"pred_iou_thresh": 0.0,
|
92 |
+
#"stability_score_thresh": 0.0,
|
93 |
+
"stability_score_offset": None,
|
94 |
+
"box_nms_thresh": None,
|
95 |
+
"crop_n_layers": None,
|
96 |
+
"crop_nms_thresh": None,
|
97 |
+
"crop_overlap_ratio": None,
|
98 |
+
"crop_n_points_downscale_factor": None,
|
99 |
+
"min_mask_region_area": None,
|
100 |
+
}
|
101 |
+
amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
|
102 |
+
self.mask = SamAutomaticMaskGenerator(self.sam, output_mode="binary_mask", **amg_kwargs)
|
103 |
+
self.overlap_threshold = 0.8
|
104 |
+
self.panoptic_on = False
|
105 |
+
|
106 |
+
@classmethod
|
107 |
+
def from_config(cls, cfg):
|
108 |
+
backbone = build_backbone(cfg)
|
109 |
+
sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
|
110 |
+
|
111 |
+
return {
|
112 |
+
"backbone": backbone,
|
113 |
+
"sem_seg_head": sem_seg_head,
|
114 |
+
"size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
|
115 |
+
"pixel_mean": cfg.MODEL.PIXEL_MEAN,
|
116 |
+
"pixel_std": cfg.MODEL.PIXEL_STD,
|
117 |
+
"clip_pixel_mean": cfg.MODEL.CLIP_PIXEL_MEAN,
|
118 |
+
"clip_pixel_std": cfg.MODEL.CLIP_PIXEL_STD,
|
119 |
+
"train_class_json": cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON,
|
120 |
+
"test_class_json": cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON,
|
121 |
+
"sliding_window": cfg.TEST.SLIDING_WINDOW,
|
122 |
+
"clip_finetune": cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE,
|
123 |
+
"backbone_multiplier": cfg.SOLVER.BACKBONE_MULTIPLIER,
|
124 |
+
"clip_pretrained": cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED,
|
125 |
+
}
|
126 |
+
|
127 |
+
@property
|
128 |
+
def device(self):
|
129 |
+
return self.pixel_mean.device
|
130 |
+
|
131 |
+
def forward(self, batched_inputs):
|
132 |
+
"""
|
133 |
+
Args:
|
134 |
+
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
|
135 |
+
Each item in the list contains the inputs for one image.
|
136 |
+
For now, each item in the list is a dict that contains:
|
137 |
+
* "image": Tensor, image in (C, H, W) format.
|
138 |
+
* "instances": per-region ground truth
|
139 |
+
* Other information that's included in the original dicts, such as:
|
140 |
+
"height", "width" (int): the output resolution of the model (may be different
|
141 |
+
from input resolution), used in inference.
|
142 |
+
Returns:
|
143 |
+
list[dict]:
|
144 |
+
each dict has the results for one image. The dict contains the following keys:
|
145 |
+
|
146 |
+
* "sem_seg":
|
147 |
+
A Tensor that represents the
|
148 |
+
per-pixel segmentation prediced by the head.
|
149 |
+
The prediction has shape KxHxW that represents the logits of
|
150 |
+
each class for each pixel.
|
151 |
+
"""
|
152 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
153 |
+
sam_images = images
|
154 |
+
if not self.training and self.sliding_window:
|
155 |
+
if not self.sequential:
|
156 |
+
with _ignore_torch_cuda_oom():
|
157 |
+
return self.inference_sliding_window(batched_inputs)
|
158 |
+
self.sequential = True
|
159 |
+
return self.inference_sliding_window(batched_inputs)
|
160 |
+
|
161 |
+
clip_images = [(x - self.clip_pixel_mean) / self.clip_pixel_std for x in images]
|
162 |
+
clip_images = ImageList.from_tensors(clip_images, self.size_divisibility)
|
163 |
+
|
164 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
165 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
166 |
+
|
167 |
+
clip_images = F.interpolate(clip_images.tensor, size=self.clip_resolution, mode='bilinear', align_corners=False, )
|
168 |
+
clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
|
169 |
+
|
170 |
+
images_resized = F.interpolate(images.tensor, size=(384, 384), mode='bilinear', align_corners=False,)
|
171 |
+
features = self.backbone(images_resized)
|
172 |
+
|
173 |
+
outputs = self.sem_seg_head(clip_features, features)
|
174 |
+
|
175 |
+
if self.training:
|
176 |
+
targets = torch.stack([x["sem_seg"].to(self.device) for x in batched_inputs], dim=0)
|
177 |
+
outputs = F.interpolate(outputs, size=(targets.shape[-2], targets.shape[-1]), mode="bilinear", align_corners=False)
|
178 |
+
|
179 |
+
num_classes = outputs.shape[1]
|
180 |
+
mask = targets != self.sem_seg_head.ignore_value
|
181 |
+
|
182 |
+
outputs = outputs.permute(0,2,3,1)
|
183 |
+
_targets = torch.zeros(outputs.shape, device=self.device)
|
184 |
+
_onehot = F.one_hot(targets[mask], num_classes=num_classes).float()
|
185 |
+
_targets[mask] = _onehot
|
186 |
+
|
187 |
+
loss = F.binary_cross_entropy_with_logits(outputs, _targets)
|
188 |
+
losses = {"loss_sem_seg" : loss}
|
189 |
+
return losses
|
190 |
+
else:
|
191 |
+
#outputs = outputs.sigmoid()
|
192 |
+
image_size = images.image_sizes[0]
|
193 |
+
if self.use_sam:
|
194 |
+
masks = self.mask.generate(np.uint8(sam_images[0].permute(1, 2, 0).cpu().numpy()))
|
195 |
+
outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, image_size)
|
196 |
+
#outputs, sam_cls = self.continuous_semantic_inference(outputs, masks, image_size)
|
197 |
+
#outputs, sam_cls = self.continuous_semantic_inference2(outputs, masks, image_size, img=img, text=text)
|
198 |
+
height = batched_inputs[0].get("height", image_size[0])
|
199 |
+
width = batched_inputs[0].get("width", image_size[1])
|
200 |
+
|
201 |
+
output = sem_seg_postprocess(outputs[0], image_size, height, width)
|
202 |
+
processed_results = [{'sem_seg': output}]
|
203 |
+
return processed_results
|
204 |
+
|
205 |
+
|
206 |
+
@torch.no_grad()
|
207 |
+
def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=[640, 640]):
|
208 |
+
|
209 |
+
images = [x["image"].to(self.device, dtype=torch.float32) for x in batched_inputs]
|
210 |
+
stride = int(kernel * (1 - overlap))
|
211 |
+
unfold = nn.Unfold(kernel_size=kernel, stride=stride)
|
212 |
+
fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)
|
213 |
+
|
214 |
+
image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode='bilinear', align_corners=False).squeeze()
|
215 |
+
sam_images = [image]
|
216 |
+
image = rearrange(unfold(image), "(C H W) L-> L C H W", C=3, H=kernel)
|
217 |
+
global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode='bilinear', align_corners=False)
|
218 |
+
image = torch.cat((image, global_image), dim=0)
|
219 |
+
|
220 |
+
images = (image - self.pixel_mean) / self.pixel_std
|
221 |
+
clip_images = (image - self.clip_pixel_mean) / self.clip_pixel_std
|
222 |
+
clip_images = F.interpolate(clip_images, size=self.clip_resolution, mode='bilinear', align_corners=False, )
|
223 |
+
clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
|
224 |
+
|
225 |
+
if self.sequential:
|
226 |
+
outputs = []
|
227 |
+
for clip_feat, image in zip(clip_features, images):
|
228 |
+
feature = self.backbone(image.unsqueeze(0))
|
229 |
+
output = self.sem_seg_head(clip_feat.unsqueeze(0), feature)
|
230 |
+
outputs.append(output[0])
|
231 |
+
outputs = torch.stack(outputs, dim=0)
|
232 |
+
else:
|
233 |
+
features = self.backbone(images)
|
234 |
+
outputs = self.sem_seg_head(clip_features, features)
|
235 |
+
|
236 |
+
outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False)
|
237 |
+
outputs = outputs.sigmoid()
|
238 |
+
|
239 |
+
global_output = outputs[-1:]
|
240 |
+
global_output = F.interpolate(global_output, size=out_res, mode='bilinear', align_corners=False,)
|
241 |
+
outputs = outputs[:-1]
|
242 |
+
outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
|
243 |
+
outputs = (outputs + global_output) / 2.
|
244 |
+
|
245 |
+
height = batched_inputs[0].get("height", out_res[0])
|
246 |
+
width = batched_inputs[0].get("width", out_res[1])
|
247 |
+
catseg_outputs = sem_seg_postprocess(outputs[0], out_res, height, width)
|
248 |
+
#catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
|
249 |
+
|
250 |
+
masks = self.mask.generate(np.uint8(sam_images[0].permute(1, 2, 0).cpu().numpy()))
|
251 |
+
if self.use_sam:
|
252 |
+
outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, out_res)
|
253 |
+
#outputs, sam_cls = self.continuous_semantic_inference(outputs, masks, out_res)
|
254 |
+
|
255 |
+
output = sem_seg_postprocess(outputs[0], out_res, height, width)
|
256 |
+
|
257 |
+
ret = [{'sem_seg': output}]
|
258 |
+
if self.panoptic_on:
|
259 |
+
panoptic_r = self.panoptic_inference(catseg_outputs, masks, sam_cls, size=output.shape[-2:])
|
260 |
+
ret[0]['panoptic_seg'] = panoptic_r
|
261 |
+
|
262 |
+
return ret
|
263 |
+
|
264 |
+
def discrete_semantic_inference(self, outputs, masks, image_size):
|
265 |
+
catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True) #.argmax(dim=1)[0].cpu()
|
266 |
+
sam_outputs = torch.zeros_like(catseg_outputs).cpu()
|
267 |
+
catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
|
268 |
+
sam_classes = torch.zeros(len(masks))
|
269 |
+
for i in range(len(masks)):
|
270 |
+
m = masks[i]['segmentation']
|
271 |
+
s = masks[i]['stability_score']
|
272 |
+
idx = catseg_outputs[m].bincount().argmax()
|
273 |
+
sam_outputs[0, idx][m] = s
|
274 |
+
sam_classes[i] = idx
|
275 |
+
|
276 |
+
return sam_outputs, sam_classes
|
277 |
+
|
278 |
+
def continuous_semantic_inference(self, outputs, masks, image_size, scale=100/7.):
|
279 |
+
#import pdb; pdb.set_trace()
|
280 |
+
catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
|
281 |
+
sam_outputs = torch.zeros_like(catseg_outputs)
|
282 |
+
#catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
|
283 |
+
sam_classes = torch.zeros(len(masks))
|
284 |
+
#import pdb; pdb.set_trace()
|
285 |
+
mask_pred = torch.tensor(np.asarray([x['segmentation'] for x in masks]), dtype=torch.float32) # N H W
|
286 |
+
mask_score = torch.tensor(np.asarray([x['predicted_iou'] for x in masks]), dtype=torch.float32) # N
|
287 |
+
|
288 |
+
mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, catseg_outputs)
|
289 |
+
mask_norm = mask_pred.sum(-1).sum(-1)
|
290 |
+
mask_cls = mask_cls / mask_norm[:, None]
|
291 |
+
mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]
|
292 |
+
|
293 |
+
mask_logits = mask_pred * mask_score[:, None, None]
|
294 |
+
output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
|
295 |
+
|
296 |
+
return output.unsqueeze(0), mask_cls
|
297 |
+
|
298 |
+
def continuous_semantic_inference2(self, outputs, masks, image_size, scale=100/7., img=None, text=None):
|
299 |
+
assert img is not None and text is not None
|
300 |
+
import pdb; pdb.set_trace()
|
301 |
+
#catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
|
302 |
+
img = F.interpolate(img, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
|
303 |
+
img = img.permute(1, 2, 0)
|
304 |
+
|
305 |
+
#sam_outputs = torch.zeros_like(catseg_outputs)
|
306 |
+
#catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
|
307 |
+
sam_classes = torch.zeros(len(masks))
|
308 |
+
#import pdb; pdb.set_trace()
|
309 |
+
mask_pred = torch.tensor(np.asarray([x['segmentation'] for x in masks]), dtype=torch.float32) # N H W
|
310 |
+
mask_score = torch.tensor(np.asarray([x['predicted_iou'] for x in masks]), dtype=torch.float32) # N
|
311 |
+
|
312 |
+
mask_pool = torch.einsum("nhw, hwd -> nd ", mask_pred, img)
|
313 |
+
mask_pool = mask_pool / mask_pool.norm(dim=1, keepdim=True)
|
314 |
+
mask_cls = torch.einsum("nd, cd -> nc", 100 * mask_pool, text.cpu())
|
315 |
+
mask_cls = mask_cls.softmax(dim=1)
|
316 |
+
|
317 |
+
#mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, catseg_outputs)
|
318 |
+
mask_norm = mask_pred.sum(-1).sum(-1)
|
319 |
+
mask_cls = mask_cls / mask_norm[:, None]
|
320 |
+
mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]
|
321 |
+
|
322 |
+
mask_logits = mask_pred * mask_score[:, None, None]
|
323 |
+
output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
|
324 |
+
|
325 |
+
return output.unsqueeze(0), sam_classes
|
326 |
+
|
327 |
+
def panoptic_inference(self, outputs, masks, sam_classes, size=None):
|
328 |
+
#import pdb; pdb.set_trace()
|
329 |
+
scores = np.asarray([x['predicted_iou'] for x in masks])
|
330 |
+
mask_pred = np.asarray([x['segmentation'] for x in masks])
|
331 |
+
|
332 |
+
#keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
|
333 |
+
cur_scores = torch.tensor(scores)
|
334 |
+
cur_masks = torch.tensor(mask_pred)
|
335 |
+
cur_masks = F.interpolate(cur_masks.unsqueeze(0).float(), size=outputs.shape[-2:], mode="nearest")[0]
|
336 |
+
cur_classes = sam_classes.argmax(dim=-1)
|
337 |
+
#cur_mask_cls = mask_cls#[keep]
|
338 |
+
#cur_mask_cls = cur_mask_cls[:, :-1]
|
339 |
+
|
340 |
+
#import pdb; pdb.set_trace()
|
341 |
+
cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
|
342 |
+
|
343 |
+
h, w = cur_masks.shape[-2:]
|
344 |
+
panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
|
345 |
+
segments_info = []
|
346 |
+
|
347 |
+
current_segment_id = 0
|
348 |
+
if cur_masks.shape[0] == 0:
|
349 |
+
# We didn't detect any mask :(
|
350 |
+
return panoptic_seg, segments_info
|
351 |
+
else:
|
352 |
+
# take argmax
|
353 |
+
cur_mask_ids = cur_prob_masks.argmax(0)
|
354 |
+
stuff_memory_list = {}
|
355 |
+
for k in range(cur_classes.shape[0]):
|
356 |
+
pred_class = cur_classes[k].item()
|
357 |
+
#isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
|
358 |
+
isthing = pred_class in [3, 6] #[i for i in range(10)]#self.metadata.thing_dataset_id_to_contiguous_id.values()
|
359 |
+
mask = cur_mask_ids == k
|
360 |
+
mask_area = mask.sum().item()
|
361 |
+
original_area = (cur_masks[k] >= 0.5).sum().item()
|
362 |
+
|
363 |
+
if mask_area > 0 and original_area > 0:
|
364 |
+
if mask_area / original_area < self.overlap_threshold:
|
365 |
+
continue
|
366 |
+
|
367 |
+
# merge stuff regions
|
368 |
+
if not isthing:
|
369 |
+
if int(pred_class) in stuff_memory_list.keys():
|
370 |
+
panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
|
371 |
+
continue
|
372 |
+
else:
|
373 |
+
stuff_memory_list[int(pred_class)] = current_segment_id + 1
|
374 |
+
|
375 |
+
current_segment_id += 1
|
376 |
+
panoptic_seg[mask] = current_segment_id
|
377 |
+
|
378 |
+
segments_info.append(
|
379 |
+
{
|
380 |
+
"id": current_segment_id,
|
381 |
+
"isthing": bool(isthing),
|
382 |
+
"category_id": int(pred_class),
|
383 |
+
}
|
384 |
+
)
|
385 |
+
|
386 |
+
return panoptic_seg, segments_info
|
cat_seg/config.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
from detectron2.config import CfgNode as CN
|
4 |
+
|
5 |
+
|
6 |
+
def add_cat_seg_config(cfg):
|
7 |
+
"""
|
8 |
+
Add config for MASK_FORMER.
|
9 |
+
"""
|
10 |
+
# data config
|
11 |
+
# select the dataset mapper
|
12 |
+
cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
|
13 |
+
|
14 |
+
cfg.DATASETS.VAL_ALL = ("coco_2017_val_all_stuff_sem_seg",)
|
15 |
+
|
16 |
+
# Color augmentation
|
17 |
+
cfg.INPUT.COLOR_AUG_SSD = False
|
18 |
+
# We retry random cropping until no single category in semantic segmentation GT occupies more
|
19 |
+
# than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
|
20 |
+
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
|
21 |
+
# Pad image and segmentation GT in dataset mapper.
|
22 |
+
cfg.INPUT.SIZE_DIVISIBILITY = -1
|
23 |
+
|
24 |
+
# solver config
|
25 |
+
# weight decay on embedding
|
26 |
+
cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
|
27 |
+
# optimizer
|
28 |
+
cfg.SOLVER.OPTIMIZER = "ADAMW"
|
29 |
+
cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
|
30 |
+
|
31 |
+
# mask_former model config
|
32 |
+
cfg.MODEL.MASK_FORMER = CN()
|
33 |
+
|
34 |
+
# Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
|
35 |
+
# you can use this config to override
|
36 |
+
cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
|
37 |
+
|
38 |
+
# swin transformer backbone
|
39 |
+
cfg.MODEL.SWIN = CN()
|
40 |
+
cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
|
41 |
+
cfg.MODEL.SWIN.PATCH_SIZE = 4
|
42 |
+
cfg.MODEL.SWIN.EMBED_DIM = 96
|
43 |
+
cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
|
44 |
+
cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
|
45 |
+
cfg.MODEL.SWIN.WINDOW_SIZE = 7
|
46 |
+
cfg.MODEL.SWIN.MLP_RATIO = 4.0
|
47 |
+
cfg.MODEL.SWIN.QKV_BIAS = True
|
48 |
+
cfg.MODEL.SWIN.QK_SCALE = None
|
49 |
+
cfg.MODEL.SWIN.DROP_RATE = 0.0
|
50 |
+
cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
|
51 |
+
cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
|
52 |
+
cfg.MODEL.SWIN.APE = False
|
53 |
+
cfg.MODEL.SWIN.PATCH_NORM = True
|
54 |
+
cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|
55 |
+
|
56 |
+
# zero shot config
|
57 |
+
cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json"
|
58 |
+
cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json"
|
59 |
+
cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_INDEXES = "datasets/coco/coco_stuff/split/seen_indexes.json"
|
60 |
+
cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_INDEXES = "datasets/coco/coco_stuff/split/unseen_indexes.json"
|
61 |
+
|
62 |
+
cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED = "ViT-B/16"
|
63 |
+
|
64 |
+
cfg.MODEL.PROMPT_ENSEMBLE = False
|
65 |
+
cfg.MODEL.PROMPT_ENSEMBLE_TYPE = "single"
|
66 |
+
|
67 |
+
cfg.MODEL.CLIP_PIXEL_MEAN = [122.7709383, 116.7460125, 104.09373615]
|
68 |
+
cfg.MODEL.CLIP_PIXEL_STD = [68.5005327, 66.6321579, 70.3231630]
|
69 |
+
# three styles for clip classification, crop, mask, cropmask
|
70 |
+
|
71 |
+
cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM = 512
|
72 |
+
cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM = 128
|
73 |
+
cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM = 512
|
74 |
+
cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM = 128
|
75 |
+
|
76 |
+
cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS = [64, 32]
|
77 |
+
cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS = [256, 128]
|
78 |
+
cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS = [32, 16]
|
79 |
+
|
80 |
+
cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS = 4
|
81 |
+
cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS = 4
|
82 |
+
cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS = 128
|
83 |
+
cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES = [6, 6]
|
84 |
+
cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION = [24, 24]
|
85 |
+
cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES = 12
|
86 |
+
cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE = "linear"
|
87 |
+
|
88 |
+
cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH = 0
|
89 |
+
cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH = 0
|
90 |
+
cfg.SOLVER.CLIP_MULTIPLIER = 0.01
|
91 |
+
|
92 |
+
cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE = "attention"
|
93 |
+
cfg.TEST.SLIDING_WINDOW = False
|
cat_seg/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from . import datasets
|
cat_seg/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (184 Bytes). View file
|
|
cat_seg/data/dataset_mappers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
cat_seg/data/dataset_mappers/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (167 Bytes). View file
|
|
cat_seg/data/dataset_mappers/__pycache__/detr_panoptic_dataset_mapper.cpython-38.pyc
ADDED
Binary file (4.88 kB). View file
|
|
cat_seg/data/dataset_mappers/__pycache__/mask_former_panoptic_dataset_mapper.cpython-38.pyc
ADDED
Binary file (4.41 kB). View file
|
|
cat_seg/data/dataset_mappers/__pycache__/mask_former_semantic_dataset_mapper.cpython-38.pyc
ADDED
Binary file (5.05 kB). View file
|
|
cat_seg/data/dataset_mappers/detr_panoptic_dataset_mapper.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
|
3 |
+
import copy
|
4 |
+
import logging
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from detectron2.config import configurable
|
10 |
+
from detectron2.data import detection_utils as utils
|
11 |
+
from detectron2.data import transforms as T
|
12 |
+
from detectron2.data.transforms import TransformGen
|
13 |
+
from detectron2.structures import BitMasks, Instances
|
14 |
+
|
15 |
+
__all__ = ["DETRPanopticDatasetMapper"]
|
16 |
+
|
17 |
+
|
18 |
+
def build_transform_gen(cfg, is_train):
|
19 |
+
"""
|
20 |
+
Create a list of :class:`TransformGen` from config.
|
21 |
+
Returns:
|
22 |
+
list[TransformGen]
|
23 |
+
"""
|
24 |
+
if is_train:
|
25 |
+
min_size = cfg.INPUT.MIN_SIZE_TRAIN
|
26 |
+
max_size = cfg.INPUT.MAX_SIZE_TRAIN
|
27 |
+
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
|
28 |
+
else:
|
29 |
+
min_size = cfg.INPUT.MIN_SIZE_TEST
|
30 |
+
max_size = cfg.INPUT.MAX_SIZE_TEST
|
31 |
+
sample_style = "choice"
|
32 |
+
if sample_style == "range":
|
33 |
+
assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
|
34 |
+
len(min_size)
|
35 |
+
)
|
36 |
+
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
tfm_gens = []
|
39 |
+
if is_train:
|
40 |
+
tfm_gens.append(T.RandomFlip())
|
41 |
+
tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
|
42 |
+
if is_train:
|
43 |
+
logger.info("TransformGens used in training: " + str(tfm_gens))
|
44 |
+
return tfm_gens
|
45 |
+
|
46 |
+
|
47 |
+
# This is specifically designed for the COCO dataset.
|
48 |
+
class DETRPanopticDatasetMapper:
|
49 |
+
"""
|
50 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
51 |
+
and map it into a format used by MaskFormer.
|
52 |
+
|
53 |
+
This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
|
54 |
+
|
55 |
+
The callable currently does the following:
|
56 |
+
|
57 |
+
1. Read the image from "file_name"
|
58 |
+
2. Applies geometric transforms to the image and annotation
|
59 |
+
3. Find and applies suitable cropping to the image and annotation
|
60 |
+
4. Prepare image and annotation to Tensors
|
61 |
+
"""
|
62 |
+
|
63 |
+
@configurable
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
is_train=True,
|
67 |
+
*,
|
68 |
+
crop_gen,
|
69 |
+
tfm_gens,
|
70 |
+
image_format,
|
71 |
+
):
|
72 |
+
"""
|
73 |
+
NOTE: this interface is experimental.
|
74 |
+
Args:
|
75 |
+
is_train: for training or inference
|
76 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
77 |
+
crop_gen: crop augmentation
|
78 |
+
tfm_gens: data augmentation
|
79 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
80 |
+
"""
|
81 |
+
self.crop_gen = crop_gen
|
82 |
+
self.tfm_gens = tfm_gens
|
83 |
+
logging.getLogger(__name__).info(
|
84 |
+
"[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format(
|
85 |
+
str(self.tfm_gens), str(self.crop_gen)
|
86 |
+
)
|
87 |
+
)
|
88 |
+
|
89 |
+
self.img_format = image_format
|
90 |
+
self.is_train = is_train
|
91 |
+
|
92 |
+
@classmethod
|
93 |
+
def from_config(cls, cfg, is_train=True):
|
94 |
+
# Build augmentation
|
95 |
+
if cfg.INPUT.CROP.ENABLED and is_train:
|
96 |
+
crop_gen = [
|
97 |
+
T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
|
98 |
+
T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
|
99 |
+
]
|
100 |
+
else:
|
101 |
+
crop_gen = None
|
102 |
+
|
103 |
+
tfm_gens = build_transform_gen(cfg, is_train)
|
104 |
+
|
105 |
+
ret = {
|
106 |
+
"is_train": is_train,
|
107 |
+
"crop_gen": crop_gen,
|
108 |
+
"tfm_gens": tfm_gens,
|
109 |
+
"image_format": cfg.INPUT.FORMAT,
|
110 |
+
}
|
111 |
+
return ret
|
112 |
+
|
113 |
+
def __call__(self, dataset_dict):
|
114 |
+
"""
|
115 |
+
Args:
|
116 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
dict: a format that builtin models in detectron2 accept
|
120 |
+
"""
|
121 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
122 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
123 |
+
utils.check_image_size(dataset_dict, image)
|
124 |
+
|
125 |
+
if self.crop_gen is None:
|
126 |
+
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
|
127 |
+
else:
|
128 |
+
if np.random.rand() > 0.5:
|
129 |
+
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
|
130 |
+
else:
|
131 |
+
image, transforms = T.apply_transform_gens(
|
132 |
+
self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
|
133 |
+
)
|
134 |
+
|
135 |
+
image_shape = image.shape[:2] # h, w
|
136 |
+
|
137 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
138 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
139 |
+
# Therefore it's important to use torch.Tensor.
|
140 |
+
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
141 |
+
|
142 |
+
if not self.is_train:
|
143 |
+
# USER: Modify this if you want to keep them for some reason.
|
144 |
+
dataset_dict.pop("annotations", None)
|
145 |
+
return dataset_dict
|
146 |
+
|
147 |
+
if "pan_seg_file_name" in dataset_dict:
|
148 |
+
pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
|
149 |
+
segments_info = dataset_dict["segments_info"]
|
150 |
+
|
151 |
+
# apply the same transformation to panoptic segmentation
|
152 |
+
pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
|
153 |
+
|
154 |
+
from panopticapi.utils import rgb2id
|
155 |
+
|
156 |
+
pan_seg_gt = rgb2id(pan_seg_gt)
|
157 |
+
|
158 |
+
instances = Instances(image_shape)
|
159 |
+
classes = []
|
160 |
+
masks = []
|
161 |
+
for segment_info in segments_info:
|
162 |
+
class_id = segment_info["category_id"]
|
163 |
+
if not segment_info["iscrowd"]:
|
164 |
+
classes.append(class_id)
|
165 |
+
masks.append(pan_seg_gt == segment_info["id"])
|
166 |
+
|
167 |
+
classes = np.array(classes)
|
168 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
169 |
+
if len(masks) == 0:
|
170 |
+
# Some image does not have annotation (all ignored)
|
171 |
+
instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
172 |
+
else:
|
173 |
+
masks = BitMasks(
|
174 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
175 |
+
)
|
176 |
+
instances.gt_masks = masks.tensor
|
177 |
+
|
178 |
+
dataset_dict["instances"] = instances
|
179 |
+
|
180 |
+
return dataset_dict
|
cat_seg/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import copy
|
3 |
+
import logging
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from detectron2.config import configurable
|
10 |
+
from detectron2.data import detection_utils as utils
|
11 |
+
from detectron2.data import transforms as T
|
12 |
+
from detectron2.structures import BitMasks, Instances
|
13 |
+
|
14 |
+
from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
|
15 |
+
|
16 |
+
__all__ = ["MaskFormerPanopticDatasetMapper"]
|
17 |
+
|
18 |
+
|
19 |
+
class MaskFormerPanopticDatasetMapper(MaskFormerSemanticDatasetMapper):
|
20 |
+
"""
|
21 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
22 |
+
and map it into a format used by MaskFormer for panoptic segmentation.
|
23 |
+
|
24 |
+
The callable currently does the following:
|
25 |
+
|
26 |
+
1. Read the image from "file_name"
|
27 |
+
2. Applies geometric transforms to the image and annotation
|
28 |
+
3. Find and applies suitable cropping to the image and annotation
|
29 |
+
4. Prepare image and annotation to Tensors
|
30 |
+
"""
|
31 |
+
|
32 |
+
@configurable
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
is_train=True,
|
36 |
+
*,
|
37 |
+
augmentations,
|
38 |
+
image_format,
|
39 |
+
ignore_label,
|
40 |
+
size_divisibility,
|
41 |
+
):
|
42 |
+
"""
|
43 |
+
NOTE: this interface is experimental.
|
44 |
+
Args:
|
45 |
+
is_train: for training or inference
|
46 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
47 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
48 |
+
ignore_label: the label that is ignored to evaluation
|
49 |
+
size_divisibility: pad image size to be divisible by this value
|
50 |
+
"""
|
51 |
+
super().__init__(
|
52 |
+
is_train,
|
53 |
+
augmentations=augmentations,
|
54 |
+
image_format=image_format,
|
55 |
+
ignore_label=ignore_label,
|
56 |
+
size_divisibility=size_divisibility,
|
57 |
+
)
|
58 |
+
|
59 |
+
def __call__(self, dataset_dict):
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
dict: a format that builtin models in detectron2 accept
|
66 |
+
"""
|
67 |
+
assert self.is_train, "MaskFormerPanopticDatasetMapper should only be used for training!"
|
68 |
+
|
69 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
70 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
71 |
+
utils.check_image_size(dataset_dict, image)
|
72 |
+
|
73 |
+
# semantic segmentation
|
74 |
+
if "sem_seg_file_name" in dataset_dict:
|
75 |
+
# PyTorch transformation not implemented for uint16, so converting it to double first
|
76 |
+
sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
|
77 |
+
else:
|
78 |
+
sem_seg_gt = None
|
79 |
+
|
80 |
+
# panoptic segmentation
|
81 |
+
if "pan_seg_file_name" in dataset_dict:
|
82 |
+
pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
|
83 |
+
segments_info = dataset_dict["segments_info"]
|
84 |
+
else:
|
85 |
+
pan_seg_gt = None
|
86 |
+
segments_info = None
|
87 |
+
|
88 |
+
if pan_seg_gt is None:
|
89 |
+
raise ValueError(
|
90 |
+
"Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format(
|
91 |
+
dataset_dict["file_name"]
|
92 |
+
)
|
93 |
+
)
|
94 |
+
|
95 |
+
aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
|
96 |
+
aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
|
97 |
+
image = aug_input.image
|
98 |
+
if sem_seg_gt is not None:
|
99 |
+
sem_seg_gt = aug_input.sem_seg
|
100 |
+
|
101 |
+
# apply the same transformation to panoptic segmentation
|
102 |
+
pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
|
103 |
+
|
104 |
+
from panopticapi.utils import rgb2id
|
105 |
+
|
106 |
+
pan_seg_gt = rgb2id(pan_seg_gt)
|
107 |
+
|
108 |
+
# Pad image and segmentation label here!
|
109 |
+
image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
110 |
+
if sem_seg_gt is not None:
|
111 |
+
sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
|
112 |
+
pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long"))
|
113 |
+
|
114 |
+
if self.size_divisibility > 0:
|
115 |
+
image_size = (image.shape[-2], image.shape[-1])
|
116 |
+
padding_size = [
|
117 |
+
0,
|
118 |
+
self.size_divisibility - image_size[1],
|
119 |
+
0,
|
120 |
+
self.size_divisibility - image_size[0],
|
121 |
+
]
|
122 |
+
image = F.pad(image, padding_size, value=128).contiguous()
|
123 |
+
if sem_seg_gt is not None:
|
124 |
+
sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
|
125 |
+
pan_seg_gt = F.pad(
|
126 |
+
pan_seg_gt, padding_size, value=0
|
127 |
+
).contiguous() # 0 is the VOID panoptic label
|
128 |
+
|
129 |
+
image_shape = (image.shape[-2], image.shape[-1]) # h, w
|
130 |
+
|
131 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
132 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
133 |
+
# Therefore it's important to use torch.Tensor.
|
134 |
+
dataset_dict["image"] = image
|
135 |
+
if sem_seg_gt is not None:
|
136 |
+
dataset_dict["sem_seg"] = sem_seg_gt.long()
|
137 |
+
|
138 |
+
if "annotations" in dataset_dict:
|
139 |
+
raise ValueError("Pemantic segmentation dataset should not have 'annotations'.")
|
140 |
+
|
141 |
+
# Prepare per-category binary masks
|
142 |
+
pan_seg_gt = pan_seg_gt.numpy()
|
143 |
+
instances = Instances(image_shape)
|
144 |
+
classes = []
|
145 |
+
masks = []
|
146 |
+
for segment_info in segments_info:
|
147 |
+
class_id = segment_info["category_id"]
|
148 |
+
if not segment_info["iscrowd"]:
|
149 |
+
classes.append(class_id)
|
150 |
+
masks.append(pan_seg_gt == segment_info["id"])
|
151 |
+
|
152 |
+
classes = np.array(classes)
|
153 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
154 |
+
if len(masks) == 0:
|
155 |
+
# Some image does not have annotation (all ignored)
|
156 |
+
instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
157 |
+
else:
|
158 |
+
masks = BitMasks(
|
159 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
160 |
+
)
|
161 |
+
instances.gt_masks = masks.tensor
|
162 |
+
|
163 |
+
dataset_dict["instances"] = instances
|
164 |
+
|
165 |
+
return dataset_dict
|
cat_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import copy
|
3 |
+
import logging
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from detectron2.config import configurable
|
10 |
+
from detectron2.data import MetadataCatalog
|
11 |
+
from detectron2.data import detection_utils as utils
|
12 |
+
from detectron2.data import transforms as T
|
13 |
+
from detectron2.projects.point_rend import ColorAugSSDTransform
|
14 |
+
from detectron2.structures import BitMasks, Instances
|
15 |
+
|
16 |
+
__all__ = ["MaskFormerSemanticDatasetMapper"]
|
17 |
+
|
18 |
+
|
19 |
+
class MaskFormerSemanticDatasetMapper:
|
20 |
+
"""
|
21 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
22 |
+
and map it into a format used by MaskFormer for semantic segmentation.
|
23 |
+
|
24 |
+
The callable currently does the following:
|
25 |
+
|
26 |
+
1. Read the image from "file_name"
|
27 |
+
2. Applies geometric transforms to the image and annotation
|
28 |
+
3. Find and applies suitable cropping to the image and annotation
|
29 |
+
4. Prepare image and annotation to Tensors
|
30 |
+
"""
|
31 |
+
|
32 |
+
@configurable
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
is_train=True,
|
36 |
+
*,
|
37 |
+
augmentations,
|
38 |
+
image_format,
|
39 |
+
ignore_label,
|
40 |
+
size_divisibility,
|
41 |
+
):
|
42 |
+
"""
|
43 |
+
NOTE: this interface is experimental.
|
44 |
+
Args:
|
45 |
+
is_train: for training or inference
|
46 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
47 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
48 |
+
ignore_label: the label that is ignored to evaluation
|
49 |
+
size_divisibility: pad image size to be divisible by this value
|
50 |
+
"""
|
51 |
+
self.is_train = is_train
|
52 |
+
self.tfm_gens = augmentations
|
53 |
+
self.img_format = image_format
|
54 |
+
self.ignore_label = ignore_label
|
55 |
+
self.size_divisibility = size_divisibility
|
56 |
+
|
57 |
+
logger = logging.getLogger(__name__)
|
58 |
+
mode = "training" if is_train else "inference"
|
59 |
+
logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")
|
60 |
+
|
61 |
+
@classmethod
|
62 |
+
def from_config(cls, cfg, is_train=True):
|
63 |
+
# Build augmentation
|
64 |
+
augs = [
|
65 |
+
T.ResizeShortestEdge(
|
66 |
+
cfg.INPUT.MIN_SIZE_TRAIN,
|
67 |
+
cfg.INPUT.MAX_SIZE_TRAIN,
|
68 |
+
cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
|
69 |
+
)
|
70 |
+
]
|
71 |
+
if cfg.INPUT.CROP.ENABLED:
|
72 |
+
augs.append(
|
73 |
+
T.RandomCrop_CategoryAreaConstraint(
|
74 |
+
cfg.INPUT.CROP.TYPE,
|
75 |
+
cfg.INPUT.CROP.SIZE,
|
76 |
+
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
|
77 |
+
cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
|
78 |
+
)
|
79 |
+
)
|
80 |
+
if cfg.INPUT.COLOR_AUG_SSD:
|
81 |
+
augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
|
82 |
+
augs.append(T.RandomFlip())
|
83 |
+
|
84 |
+
# Assume always applies to the training set.
|
85 |
+
dataset_names = cfg.DATASETS.TRAIN
|
86 |
+
meta = MetadataCatalog.get(dataset_names[0])
|
87 |
+
ignore_label = meta.ignore_label
|
88 |
+
|
89 |
+
ret = {
|
90 |
+
"is_train": is_train,
|
91 |
+
"augmentations": augs,
|
92 |
+
"image_format": cfg.INPUT.FORMAT,
|
93 |
+
"ignore_label": ignore_label,
|
94 |
+
"size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
|
95 |
+
}
|
96 |
+
return ret
|
97 |
+
|
98 |
+
def __call__(self, dataset_dict):
|
99 |
+
"""
|
100 |
+
Args:
|
101 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
dict: a format that builtin models in detectron2 accept
|
105 |
+
"""
|
106 |
+
assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!"
|
107 |
+
|
108 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
109 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
110 |
+
utils.check_image_size(dataset_dict, image)
|
111 |
+
|
112 |
+
if "sem_seg_file_name" in dataset_dict:
|
113 |
+
# PyTorch transformation not implemented for uint16, so converting it to double first
|
114 |
+
sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
|
115 |
+
else:
|
116 |
+
sem_seg_gt = None
|
117 |
+
|
118 |
+
if sem_seg_gt is None:
|
119 |
+
raise ValueError(
|
120 |
+
"Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
|
121 |
+
dataset_dict["file_name"]
|
122 |
+
)
|
123 |
+
)
|
124 |
+
|
125 |
+
aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
|
126 |
+
aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
|
127 |
+
image = aug_input.image
|
128 |
+
sem_seg_gt = aug_input.sem_seg
|
129 |
+
|
130 |
+
# Pad image and segmentation label here!
|
131 |
+
image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
132 |
+
if sem_seg_gt is not None:
|
133 |
+
sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
|
134 |
+
# import ipdb; ipdb.set_trace()
|
135 |
+
if self.size_divisibility > 0:
|
136 |
+
image_size = (image.shape[-2], image.shape[-1])
|
137 |
+
# The ori_size is not the real original size, but size before padding
|
138 |
+
dataset_dict['ori_size'] = image_size
|
139 |
+
padding_size = [
|
140 |
+
0,
|
141 |
+
self.size_divisibility - image_size[1], # w: (left, right)
|
142 |
+
0,
|
143 |
+
self.size_divisibility - image_size[0], # h: 0,(top, bottom)
|
144 |
+
]
|
145 |
+
image = F.pad(image, padding_size, value=128).contiguous()
|
146 |
+
if sem_seg_gt is not None:
|
147 |
+
sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
|
148 |
+
|
149 |
+
image_shape = (image.shape[-2], image.shape[-1]) # h, w
|
150 |
+
|
151 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
152 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
153 |
+
# Therefore it's important to use torch.Tensor.
|
154 |
+
dataset_dict["image"] = image
|
155 |
+
# print('#########################################################################################')
|
156 |
+
if sem_seg_gt is not None:
|
157 |
+
dataset_dict["sem_seg"] = sem_seg_gt.long()
|
158 |
+
|
159 |
+
if "annotations" in dataset_dict:
|
160 |
+
raise ValueError("Semantic segmentation dataset should not have 'annotations'.")
|
161 |
+
|
162 |
+
# Prepare per-category binary masks
|
163 |
+
if sem_seg_gt is not None:
|
164 |
+
sem_seg_gt = sem_seg_gt.numpy()
|
165 |
+
instances = Instances(image_shape)
|
166 |
+
classes = np.unique(sem_seg_gt)
|
167 |
+
# remove ignored region
|
168 |
+
classes = classes[classes != self.ignore_label]
|
169 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
170 |
+
|
171 |
+
masks = []
|
172 |
+
for class_id in classes:
|
173 |
+
masks.append(sem_seg_gt == class_id)
|
174 |
+
|
175 |
+
if len(masks) == 0:
|
176 |
+
# Some image does not have annotation (all ignored)
|
177 |
+
instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
|
178 |
+
else:
|
179 |
+
masks = BitMasks(
|
180 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
181 |
+
)
|
182 |
+
instances.gt_masks = masks.tensor
|
183 |
+
|
184 |
+
dataset_dict["instances"] = instances
|
185 |
+
|
186 |
+
return dataset_dict
|
cat_seg/data/datasets/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from . import (
|
3 |
+
register_coco_stuff,
|
4 |
+
register_ade20k_150,
|
5 |
+
register_ade20k_847,
|
6 |
+
register_pascal_20,
|
7 |
+
register_pascal_59,
|
8 |
+
)
|
cat_seg/data/datasets/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (322 Bytes). View file
|
|
cat_seg/data/datasets/__pycache__/register_ade20k_150.cpython-38.pyc
ADDED
Binary file (2.88 kB). View file
|
|
cat_seg/data/datasets/__pycache__/register_ade20k_847.cpython-38.pyc
ADDED
Binary file (51.8 kB). View file
|
|
cat_seg/data/datasets/__pycache__/register_ade_panoptic.cpython-38.pyc
ADDED
Binary file (11.6 kB). View file
|
|
cat_seg/data/datasets/__pycache__/register_coco_panoptic.cpython-38.pyc
ADDED
Binary file (4.75 kB). View file
|
|
cat_seg/data/datasets/__pycache__/register_coco_stuff.cpython-38.pyc
ADDED
Binary file (7.85 kB). View file
|
|
cat_seg/data/datasets/__pycache__/register_pascal_20.cpython-38.pyc
ADDED
Binary file (2.47 kB). View file
|
|
cat_seg/data/datasets/__pycache__/register_pascal_59.cpython-38.pyc
ADDED
Binary file (9.57 kB). View file
|
|
cat_seg/data/datasets/__pycache__/register_pascal_context.cpython-38.pyc
ADDED
Binary file (9.56 kB). View file
|
|
cat_seg/data/datasets/register_ade20k_150.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
4 |
+
from detectron2.data.datasets import load_sem_seg
|
5 |
+
import copy
|
6 |
+
|
7 |
+
def _get_ade20k_150_meta():
|
8 |
+
ade20k_150_classes = ["wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed ", "windowpane", "grass", "cabinet", "sidewalk", "person", "earth", "door", "table", "mountain", "plant", "curtain", "chair", "car", "water", "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub", "railing", "cushion", "base", "box", "column", "signboard", "chest of drawers", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator", "grandstand", "path", "stairs", "runway", "case", "pool table", "pillow", "screen door", "stairway", "river", "bridge", "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill", "bench", "countertop", "stove", "palm", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel", "bus", "towel", "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth", "television receiver", "airplane", "dirt track", "apparel", "pole", "land", "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage", "van", "ship", "fountain", "conveyer belt", "canopy", "washer", "plaything", "swimming pool", "stool", "barrel", "basket", "waterfall", "tent", "bag", "minibike", "cradle", "oven", "ball", "food", "step", "tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray", "ashcan", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass", "clock", "flag"]
|
9 |
+
|
10 |
+
ret = {
|
11 |
+
"stuff_classes" : ade20k_150_classes,
|
12 |
+
}
|
13 |
+
return ret
|
14 |
+
|
15 |
+
def register_ade20k_150(root):
|
16 |
+
root = os.path.join(root, "ADEChallengeData2016")
|
17 |
+
meta = _get_ade20k_150_meta()
|
18 |
+
for name, image_dirname, sem_seg_dirname in [
|
19 |
+
("test", "images/validation", "annotations_detectron2/validation"),
|
20 |
+
]:
|
21 |
+
image_dir = os.path.join(root, image_dirname)
|
22 |
+
gt_dir = os.path.join(root, sem_seg_dirname)
|
23 |
+
name = f"ade20k_150_{name}_sem_seg"
|
24 |
+
DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
|
25 |
+
MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
|
26 |
+
|
27 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
28 |
+
register_ade20k_150(_root)
|
cat_seg/data/datasets/register_ade20k_847.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
cat_seg/data/datasets/register_coco_stuff.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
4 |
+
from detectron2.data.datasets import load_sem_seg
|
5 |
+
|
6 |
+
COCO_CATEGORIES = [
|
7 |
+
{"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
|
8 |
+
{"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
|
9 |
+
{"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
|
10 |
+
{"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
|
11 |
+
{"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
|
12 |
+
{"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
|
13 |
+
{"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
|
14 |
+
{"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
|
15 |
+
{"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
|
16 |
+
{"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
|
17 |
+
{"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
|
18 |
+
{"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
|
19 |
+
{"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
|
20 |
+
{"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
|
21 |
+
{"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
|
22 |
+
{"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
|
23 |
+
{"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
|
24 |
+
{"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
|
25 |
+
{"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
|
26 |
+
{"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
|
27 |
+
{"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
|
28 |
+
{"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
|
29 |
+
{"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
|
30 |
+
{"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
|
31 |
+
{"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
|
32 |
+
{"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
|
33 |
+
{"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
|
34 |
+
{"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
|
35 |
+
{"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
|
36 |
+
{"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
|
37 |
+
{"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
|
38 |
+
{"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
|
39 |
+
{"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
|
40 |
+
{"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
|
41 |
+
{"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
|
42 |
+
{"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
|
43 |
+
{"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
|
44 |
+
{"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
|
45 |
+
{"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
|
46 |
+
{"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
|
47 |
+
{"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
|
48 |
+
{"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
|
49 |
+
{"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
|
50 |
+
{"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
|
51 |
+
{"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
|
52 |
+
{"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
|
53 |
+
{"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
|
54 |
+
{"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
|
55 |
+
{"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
|
56 |
+
{"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
|
57 |
+
{"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
|
58 |
+
{"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
|
59 |
+
{"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
|
60 |
+
{"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
|
61 |
+
{"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
|
62 |
+
{"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
|
63 |
+
{"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
|
64 |
+
{"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
|
65 |
+
{"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
|
66 |
+
{"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
|
67 |
+
{"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
|
68 |
+
{"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
|
69 |
+
{"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
|
70 |
+
{"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
|
71 |
+
{"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
|
72 |
+
{"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
|
73 |
+
{"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
|
74 |
+
{"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
|
75 |
+
{"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
|
76 |
+
{"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
|
77 |
+
{"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
|
78 |
+
{"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
|
79 |
+
{"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
|
80 |
+
{"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
|
81 |
+
{"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
|
82 |
+
{"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
|
83 |
+
{"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
|
84 |
+
{"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
|
85 |
+
{"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
|
86 |
+
{"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
|
87 |
+
{"id": 92, "name": "banner", "supercategory": "textile"},
|
88 |
+
{"id": 93, "name": "blanket", "supercategory": "textile"},
|
89 |
+
{"id": 94, "name": "branch", "supercategory": "plant"},
|
90 |
+
{"id": 95, "name": "bridge", "supercategory": "building"},
|
91 |
+
{"id": 96, "name": "building-other", "supercategory": "building"},
|
92 |
+
{"id": 97, "name": "bush", "supercategory": "plant"},
|
93 |
+
{"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
|
94 |
+
{"id": 99, "name": "cage", "supercategory": "structural"},
|
95 |
+
{"id": 100, "name": "cardboard", "supercategory": "raw-material"},
|
96 |
+
{"id": 101, "name": "carpet", "supercategory": "floor"},
|
97 |
+
{"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
|
98 |
+
{"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
|
99 |
+
{"id": 104, "name": "cloth", "supercategory": "textile"},
|
100 |
+
{"id": 105, "name": "clothes", "supercategory": "textile"},
|
101 |
+
{"id": 106, "name": "clouds", "supercategory": "sky"},
|
102 |
+
{"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
|
103 |
+
{"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
|
104 |
+
{"id": 109, "name": "curtain", "supercategory": "textile"},
|
105 |
+
{"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
|
106 |
+
{"id": 111, "name": "dirt", "supercategory": "ground"},
|
107 |
+
{"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
|
108 |
+
{"id": 113, "name": "fence", "supercategory": "structural"},
|
109 |
+
{"id": 114, "name": "floor-marble", "supercategory": "floor"},
|
110 |
+
{"id": 115, "name": "floor-other", "supercategory": "floor"},
|
111 |
+
{"id": 116, "name": "floor-stone", "supercategory": "floor"},
|
112 |
+
{"id": 117, "name": "floor-tile", "supercategory": "floor"},
|
113 |
+
{"id": 118, "name": "floor-wood", "supercategory": "floor"},
|
114 |
+
{"id": 119, "name": "flower", "supercategory": "plant"},
|
115 |
+
{"id": 120, "name": "fog", "supercategory": "water"},
|
116 |
+
{"id": 121, "name": "food-other", "supercategory": "food-stuff"},
|
117 |
+
{"id": 122, "name": "fruit", "supercategory": "food-stuff"},
|
118 |
+
{"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
|
119 |
+
{"id": 124, "name": "grass", "supercategory": "plant"},
|
120 |
+
{"id": 125, "name": "gravel", "supercategory": "ground"},
|
121 |
+
{"id": 126, "name": "ground-other", "supercategory": "ground"},
|
122 |
+
{"id": 127, "name": "hill", "supercategory": "solid"},
|
123 |
+
{"id": 128, "name": "house", "supercategory": "building"},
|
124 |
+
{"id": 129, "name": "leaves", "supercategory": "plant"},
|
125 |
+
{"id": 130, "name": "light", "supercategory": "furniture-stuff"},
|
126 |
+
{"id": 131, "name": "mat", "supercategory": "textile"},
|
127 |
+
{"id": 132, "name": "metal", "supercategory": "raw-material"},
|
128 |
+
{"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
|
129 |
+
{"id": 134, "name": "moss", "supercategory": "plant"},
|
130 |
+
{"id": 135, "name": "mountain", "supercategory": "solid"},
|
131 |
+
{"id": 136, "name": "mud", "supercategory": "ground"},
|
132 |
+
{"id": 137, "name": "napkin", "supercategory": "textile"},
|
133 |
+
{"id": 138, "name": "net", "supercategory": "structural"},
|
134 |
+
{"id": 139, "name": "paper", "supercategory": "raw-material"},
|
135 |
+
{"id": 140, "name": "pavement", "supercategory": "ground"},
|
136 |
+
{"id": 141, "name": "pillow", "supercategory": "textile"},
|
137 |
+
{"id": 142, "name": "plant-other", "supercategory": "plant"},
|
138 |
+
{"id": 143, "name": "plastic", "supercategory": "raw-material"},
|
139 |
+
{"id": 144, "name": "platform", "supercategory": "ground"},
|
140 |
+
{"id": 145, "name": "playingfield", "supercategory": "ground"},
|
141 |
+
{"id": 146, "name": "railing", "supercategory": "structural"},
|
142 |
+
{"id": 147, "name": "railroad", "supercategory": "ground"},
|
143 |
+
{"id": 148, "name": "river", "supercategory": "water"},
|
144 |
+
{"id": 149, "name": "road", "supercategory": "ground"},
|
145 |
+
{"id": 150, "name": "rock", "supercategory": "solid"},
|
146 |
+
{"id": 151, "name": "roof", "supercategory": "building"},
|
147 |
+
{"id": 152, "name": "rug", "supercategory": "textile"},
|
148 |
+
{"id": 153, "name": "salad", "supercategory": "food-stuff"},
|
149 |
+
{"id": 154, "name": "sand", "supercategory": "ground"},
|
150 |
+
{"id": 155, "name": "sea", "supercategory": "water"},
|
151 |
+
{"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
|
152 |
+
{"id": 157, "name": "sky-other", "supercategory": "sky"},
|
153 |
+
{"id": 158, "name": "skyscraper", "supercategory": "building"},
|
154 |
+
{"id": 159, "name": "snow", "supercategory": "ground"},
|
155 |
+
{"id": 160, "name": "solid-other", "supercategory": "solid"},
|
156 |
+
{"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
|
157 |
+
{"id": 162, "name": "stone", "supercategory": "solid"},
|
158 |
+
{"id": 163, "name": "straw", "supercategory": "plant"},
|
159 |
+
{"id": 164, "name": "structural-other", "supercategory": "structural"},
|
160 |
+
{"id": 165, "name": "table", "supercategory": "furniture-stuff"},
|
161 |
+
{"id": 166, "name": "tent", "supercategory": "building"},
|
162 |
+
{"id": 167, "name": "textile-other", "supercategory": "textile"},
|
163 |
+
{"id": 168, "name": "towel", "supercategory": "textile"},
|
164 |
+
{"id": 169, "name": "tree", "supercategory": "plant"},
|
165 |
+
{"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
|
166 |
+
{"id": 171, "name": "wall-brick", "supercategory": "wall"},
|
167 |
+
{"id": 172, "name": "wall-concrete", "supercategory": "wall"},
|
168 |
+
{"id": 173, "name": "wall-other", "supercategory": "wall"},
|
169 |
+
{"id": 174, "name": "wall-panel", "supercategory": "wall"},
|
170 |
+
{"id": 175, "name": "wall-stone", "supercategory": "wall"},
|
171 |
+
{"id": 176, "name": "wall-tile", "supercategory": "wall"},
|
172 |
+
{"id": 177, "name": "wall-wood", "supercategory": "wall"},
|
173 |
+
{"id": 178, "name": "water-other", "supercategory": "water"},
|
174 |
+
{"id": 179, "name": "waterdrops", "supercategory": "water"},
|
175 |
+
{"id": 180, "name": "window-blind", "supercategory": "window"},
|
176 |
+
{"id": 181, "name": "window-other", "supercategory": "window"},
|
177 |
+
{"id": 182, "name": "wood", "supercategory": "solid"},
|
178 |
+
]
|
179 |
+
|
180 |
+
|
181 |
+
def _get_coco_stuff_meta():
|
182 |
+
stuff_ids = [k["id"] for k in COCO_CATEGORIES]
|
183 |
+
assert len(stuff_ids) == 171, len(stuff_ids)
|
184 |
+
|
185 |
+
stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
|
186 |
+
stuff_classes = [k["name"] for k in COCO_CATEGORIES]
|
187 |
+
|
188 |
+
ret = {
|
189 |
+
"stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
|
190 |
+
"stuff_classes": stuff_classes,
|
191 |
+
}
|
192 |
+
return ret
|
193 |
+
|
194 |
+
def register_all_coco_stuff_10k(root):
|
195 |
+
root = os.path.join(root, "coco-stuff")
|
196 |
+
meta = _get_coco_stuff_meta()
|
197 |
+
for name, image_dirname, sem_seg_dirname in [
|
198 |
+
("train", "images/train2017", "annotations_detectron2/train2017"),
|
199 |
+
("test", "images/val2017", "annotations_detectron2/val2017"),
|
200 |
+
]:
|
201 |
+
image_dir = os.path.join(root, image_dirname)
|
202 |
+
gt_dir = os.path.join(root, sem_seg_dirname)
|
203 |
+
name = f"coco_2017_{name}_stuff_all_sem_seg"
|
204 |
+
DatasetCatalog.register(
|
205 |
+
name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
|
206 |
+
)
|
207 |
+
MetadataCatalog.get(name).set(
|
208 |
+
image_root=image_dir,
|
209 |
+
sem_seg_root=gt_dir,
|
210 |
+
evaluator_type="sem_seg",
|
211 |
+
ignore_label=255,
|
212 |
+
**meta,
|
213 |
+
)
|
214 |
+
|
215 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
216 |
+
register_all_coco_stuff_10k(_root)
|
cat_seg/data/datasets/register_pascal_20.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
4 |
+
from detectron2.data.datasets import load_sem_seg
|
5 |
+
import copy
|
6 |
+
|
7 |
+
def _get_pascal_voc_meta():
|
8 |
+
voc_classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
|
9 |
+
voc_colors = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
|
10 |
+
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
|
11 |
+
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
|
12 |
+
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
|
13 |
+
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
|
14 |
+
ret = {
|
15 |
+
"stuff_classes" : voc_classes,
|
16 |
+
"stuff_colors" : voc_colors,
|
17 |
+
}
|
18 |
+
return ret
|
19 |
+
|
20 |
+
def register_all_pascal_voc(root):
|
21 |
+
root = os.path.join(root, "VOCdevkit/VOC2012")
|
22 |
+
meta = _get_pascal_voc_meta()
|
23 |
+
for name, image_dirname, sem_seg_dirname in [
|
24 |
+
("test", "JPEGImages", "annotations_detectron2"),
|
25 |
+
("test_background", "JPEGImages", "annotations_detectron2_bg"),
|
26 |
+
]:
|
27 |
+
image_dir = os.path.join(root, image_dirname)
|
28 |
+
gt_dir = os.path.join(root, sem_seg_dirname, 'val')
|
29 |
+
name = f"voc_2012_{name}_sem_seg"
|
30 |
+
|
31 |
+
DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
|
32 |
+
if "background" in name:
|
33 |
+
MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg_background", ignore_label=255,
|
34 |
+
stuff_classes=meta["stuff_classes"] + ["background"], stuff_colors=meta["stuff_colors"])
|
35 |
+
else:
|
36 |
+
MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
|
37 |
+
|
38 |
+
def register_all_pascal_voc_background(root):
|
39 |
+
root = os.path.join(root, "VOCdevkit/VOC2012")
|
40 |
+
meta = _get_pascal_voc_meta()
|
41 |
+
meta["stuff_classes"] = meta["stuff_classes"] + ["background"]
|
42 |
+
for name, image_dirname, sem_seg_dirname in [
|
43 |
+
("test_background", "image", "label_openseg_background20"),
|
44 |
+
]:
|
45 |
+
image_dir = os.path.join(root, image_dirname, 'validation')
|
46 |
+
gt_dir = os.path.join(root, sem_seg_dirname, 'validation')
|
47 |
+
name = f"voc_2012_{name}_sem_seg"
|
48 |
+
DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
|
49 |
+
MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg_background", ignore_label=255, **meta,)
|
50 |
+
|
51 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
52 |
+
register_all_pascal_voc(_root)
|
53 |
+
#register_all_pascal_voc_background(_root)
|
cat_seg/data/datasets/register_pascal_59.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
4 |
+
from detectron2.data.datasets import load_sem_seg
|
5 |
+
import copy
|
6 |
+
|
7 |
+
|
8 |
+
stuff_colors = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
|
9 |
+
[0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
|
10 |
+
[0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
|
11 |
+
[64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
|
12 |
+
[0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
|
13 |
+
[0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
|
14 |
+
[64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
|
15 |
+
[0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
|
16 |
+
[128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
|
17 |
+
[0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
|
18 |
+
[0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
|
19 |
+
[0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
|
20 |
+
[128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
|
21 |
+
[64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
|
22 |
+
[0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
|
23 |
+
[0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
|
24 |
+
[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
25 |
+
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
|
26 |
+
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
|
27 |
+
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
|
28 |
+
[0, 0, 230], [119, 11, 32],
|
29 |
+
[64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
|
30 |
+
[0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
|
31 |
+
[64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
|
32 |
+
[192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
|
33 |
+
[64, 192, 96], [64, 160, 64], [64, 64, 0]]
|
34 |
+
|
35 |
+
def _get_pascal_context_59_meta():
|
36 |
+
#context_classes = ["aeroplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird", "boat", "book", "bottle", "building", "bus", "cabinet", "car", "cat", "ceiling", "chair", "cloth", "computer", "cow", "cup", "curtain", "dog", "door", "fence", "floor", "flower", "food", "grass", "ground", "horse", "keyboard", "light", "motorbike", "mountain", "mouse", "person", "plate", "platform", "pottedplant", "road", "rock", "sheep", "shelves", "sidewalk", "sign", "sky", "snow", "sofa", "diningtable", "track", "train", "tree", "truck", "tvmonitor", "wall", "water", "window", "wood"]#, "background"]
|
37 |
+
context_classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor", "bag", "bed", "bench", "book", "building", "cabinet", "ceiling", "cloth", "computer", "cup", "door", "fence", "floor", "flower", "food", "grass", "ground", "keyboard", "light", "mountain", "mouse", "curtain", "platform", "sign", "plate", "road", "rock", "shelves", "sidewalk", "sky", "snow", "bedclothes", "track", "tree", "truck", "wall", "water", "window", "wood"]
|
38 |
+
context_colors = [stuff_colors[i % len(stuff_colors)] for i in range(len(context_classes))]
|
39 |
+
ret = {
|
40 |
+
"stuff_colors" : context_colors,
|
41 |
+
"stuff_classes" : context_classes,
|
42 |
+
}
|
43 |
+
return ret
|
44 |
+
|
45 |
+
def register_pascal_context_59(root):
|
46 |
+
root = os.path.join(root, "VOCdevkit", "VOC2010")
|
47 |
+
meta = _get_pascal_context_59_meta()
|
48 |
+
for name, image_dirname, sem_seg_dirname in [
|
49 |
+
("test", "JPEGImages", "annotations_detectron2/pc59_val"),
|
50 |
+
]:
|
51 |
+
image_dir = os.path.join(root, image_dirname)
|
52 |
+
gt_dir = os.path.join(root, sem_seg_dirname)
|
53 |
+
name = f"context_59_{name}_sem_seg"
|
54 |
+
DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
|
55 |
+
MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
|
56 |
+
|
57 |
+
def _get_pascal_context_459_meta():
|
58 |
+
context_459_classes = ["accordion", "aeroplane", "airconditioner", "antenna", "artillery", "ashtray", "atrium", "babycarriage", "bag", "ball", "balloon", "bambooweaving", "barrel", "baseballbat", "basket", "basketballbackboard", "bathtub", "bed", "bedclothes", "beer", "bell", "bench", "bicycle", "binoculars", "bird", "birdcage", "birdfeeder", "birdnest", "blackboard", "board", "boat", "bone", "book", "bottle", "bottleopener", "bowl", "box", "bracelet", "brick", "bridge", "broom", "brush", "bucket", "building", "bus", "cabinet", "cabinetdoor", "cage", "cake", "calculator", "calendar", "camel", "camera", "cameralens", "can", "candle", "candleholder", "cap", "car", "card", "cart", "case", "casetterecorder", "cashregister", "cat", "cd", "cdplayer", "ceiling", "cellphone", "cello", "chain", "chair", "chessboard", "chicken", "chopstick", "clip", "clippers", "clock", "closet", "cloth", "clothestree", "coffee", "coffeemachine", "comb", "computer", "concrete", "cone", "container", "controlbooth", "controller", "cooker", "copyingmachine", "coral", "cork", "corkscrew", "counter", "court", "cow", "crabstick", "crane", "crate", "cross", "crutch", "cup", "curtain", "cushion", "cuttingboard", "dais", "disc", "disccase", "dishwasher", "dock", "dog", "dolphin", "door", "drainer", "dray", "drinkdispenser", "drinkingmachine", "drop", "drug", "drum", "drumkit", "duck", "dumbbell", "earphone", "earrings", "egg", "electricfan", "electriciron", "electricpot", "electricsaw", "electronickeyboard", "engine", "envelope", "equipment", "escalator", "exhibitionbooth", "extinguisher", "eyeglass", "fan", "faucet", "faxmachine", "fence", "ferriswheel", "fireextinguisher", "firehydrant", "fireplace", "fish", "fishtank", "fishbowl", "fishingnet", "fishingpole", "flag", "flagstaff", "flame", "flashlight", "floor", "flower", "fly", "foam", "food", "footbridge", "forceps", "fork", "forklift", "fountain", "fox", "frame", "fridge", "frog", "fruit", "funnel", "furnace", "gamecontroller", "gamemachine", "gascylinder", "gashood", "gasstove", "giftbox", "glass", "glassmarble", "globe", "glove", "goal", "grandstand", "grass", "gravestone", "ground", "guardrail", "guitar", "gun", "hammer", "handcart", "handle", "handrail", "hanger", "harddiskdrive", "hat", "hay", "headphone", "heater", "helicopter", "helmet", "holder", "hook", "horse", "horse-drawncarriage", "hot-airballoon", "hydrovalve", "ice", "inflatorpump", "ipod", "iron", "ironingboard", "jar", "kart", "kettle", "key", "keyboard", "kitchenrange", "kite", "knife", "knifeblock", "ladder", "laddertruck", "ladle", "laptop", "leaves", "lid", "lifebuoy", "light", "lightbulb", "lighter", "line", "lion", "lobster", "lock", "machine", "mailbox", "mannequin", "map", "mask", "mat", "matchbook", "mattress", "menu", "metal", "meterbox", "microphone", "microwave", "mirror", "missile", "model", "money", "monkey", "mop", "motorbike", "mountain", "mouse", "mousepad", "musicalinstrument", "napkin", "net", "newspaper", "oar", "ornament", "outlet", "oven", "oxygenbottle", "pack", "pan", "paper", "paperbox", "papercutter", "parachute", "parasol", "parterre", "patio", "pelage", "pen", "pencontainer", "pencil", "person", "photo", "piano", "picture", "pig", "pillar", "pillow", "pipe", "pitcher", "plant", "plastic", "plate", "platform", "player", "playground", "pliers", "plume", "poker", "pokerchip", "pole", "pooltable", "postcard", "poster", "pot", "pottedplant", "printer", "projector", "pumpkin", "rabbit", "racket", "radiator", "radio", "rail", "rake", "ramp", "rangehood", "receiver", "recorder", "recreationalmachines", "remotecontrol", "road", "robot", "rock", "rocket", "rockinghorse", "rope", "rug", "ruler", "runway", "saddle", "sand", "saw", "scale", "scanner", "scissors", "scoop", "screen", "screwdriver", "sculpture", "scythe", "sewer", "sewingmachine", "shed", "sheep", "shell", "shelves", "shoe", "shoppingcart", "shovel", "sidecar", "sidewalk", "sign", "signallight", "sink", "skateboard", "ski", "sky", "sled", "slippers", "smoke", "snail", "snake", "snow", "snowmobiles", "sofa", "spanner", "spatula", "speaker", "speedbump", "spicecontainer", "spoon", "sprayer", "squirrel", "stage", "stair", "stapler", "stick", "stickynote", "stone", "stool", "stove", "straw", "stretcher", "sun", "sunglass", "sunshade", "surveillancecamera", "swan", "sweeper", "swimring", "swimmingpool", "swing", "switch", "table", "tableware", "tank", "tap", "tape", "tarp", "telephone", "telephonebooth", "tent", "tire", "toaster", "toilet", "tong", "tool", "toothbrush", "towel", "toy", "toycar", "track", "train", "trampoline", "trashbin", "tray", "tree", "tricycle", "tripod", "trophy", "truck", "tube", "turtle", "tvmonitor", "tweezers", "typewriter", "umbrella", "unknown", "vacuumcleaner", "vendingmachine", "videocamera", "videogameconsole", "videoplayer", "videotape", "violin", "wakeboard", "wall", "wallet", "wardrobe", "washingmachine", "watch", "water", "waterdispenser", "waterpipe", "waterskateboard", "watermelon", "whale", "wharf", "wheel", "wheelchair", "window", "windowblinds", "wineglass", "wire", "wood", "wool"]
|
59 |
+
context_colors = [stuff_colors[i % len(stuff_colors)] for i in range(len(context_459_classes))]
|
60 |
+
ret = {
|
61 |
+
"stuff_colors" : context_colors,
|
62 |
+
"stuff_classes" : context_459_classes,
|
63 |
+
}
|
64 |
+
return ret
|
65 |
+
|
66 |
+
def register_pascal_context_459(root):
|
67 |
+
root = os.path.join(root, "VOCdevkit", "VOC2010")
|
68 |
+
meta = _get_pascal_context_459_meta()
|
69 |
+
for name, image_dirname, sem_seg_dirname in [
|
70 |
+
("test", "JPEGImages", "annotations_detectron2/pc459_val"),
|
71 |
+
]:
|
72 |
+
image_dir = os.path.join(root, image_dirname)
|
73 |
+
gt_dir = os.path.join(root, sem_seg_dirname)
|
74 |
+
name = f"context_459_{name}_sem_seg"
|
75 |
+
DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='tif', image_ext='jpg'))
|
76 |
+
MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=459, **meta,)
|
77 |
+
|
78 |
+
|
79 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
80 |
+
register_pascal_context_59(_root)
|
81 |
+
register_pascal_context_459(_root)
|
cat_seg/modeling/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from .backbone.swin import D2SwinTransformer
|
3 |
+
from .heads.cat_seg_head import CATSegHead
|
cat_seg/modeling/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (263 Bytes). View file
|
|
cat_seg/modeling/__pycache__/criterion.cpython-38.pyc
ADDED
Binary file (8.26 kB). View file
|
|
cat_seg/modeling/__pycache__/matcher.cpython-38.pyc
ADDED
Binary file (6.94 kB). View file
|
|
cat_seg/modeling/backbone/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
cat_seg/modeling/backbone/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (164 Bytes). View file
|
|
cat_seg/modeling/backbone/__pycache__/image_encoder.cpython-38.pyc
ADDED
Binary file (20 kB). View file
|
|
cat_seg/modeling/backbone/__pycache__/swin.cpython-38.pyc
ADDED
Binary file (21.5 kB). View file
|
|
cat_seg/modeling/backbone/swin.py
ADDED
@@ -0,0 +1,768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Swin Transformer
|
3 |
+
# Copyright (c) 2021 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Written by Ze Liu, Yutong Lin, Yixuan Wei
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
9 |
+
# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torch.utils.checkpoint as checkpoint
|
16 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
17 |
+
|
18 |
+
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
|
19 |
+
|
20 |
+
|
21 |
+
class Mlp(nn.Module):
|
22 |
+
"""Multilayer perceptron."""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
|
26 |
+
):
|
27 |
+
super().__init__()
|
28 |
+
out_features = out_features or in_features
|
29 |
+
hidden_features = hidden_features or in_features
|
30 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
31 |
+
self.act = act_layer()
|
32 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
33 |
+
self.drop = nn.Dropout(drop)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
x = self.fc1(x)
|
37 |
+
x = self.act(x)
|
38 |
+
x = self.drop(x)
|
39 |
+
x = self.fc2(x)
|
40 |
+
x = self.drop(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
def window_partition(x, window_size):
|
45 |
+
"""
|
46 |
+
Args:
|
47 |
+
x: (B, H, W, C)
|
48 |
+
window_size (int): window size
|
49 |
+
Returns:
|
50 |
+
windows: (num_windows*B, window_size, window_size, C)
|
51 |
+
"""
|
52 |
+
B, H, W, C = x.shape
|
53 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
54 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
55 |
+
return windows
|
56 |
+
|
57 |
+
|
58 |
+
def window_reverse(windows, window_size, H, W):
|
59 |
+
"""
|
60 |
+
Args:
|
61 |
+
windows: (num_windows*B, window_size, window_size, C)
|
62 |
+
window_size (int): Window size
|
63 |
+
H (int): Height of image
|
64 |
+
W (int): Width of image
|
65 |
+
Returns:
|
66 |
+
x: (B, H, W, C)
|
67 |
+
"""
|
68 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
69 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
70 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
71 |
+
return x
|
72 |
+
|
73 |
+
|
74 |
+
class WindowAttention(nn.Module):
|
75 |
+
"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
76 |
+
It supports both of shifted and non-shifted window.
|
77 |
+
Args:
|
78 |
+
dim (int): Number of input channels.
|
79 |
+
window_size (tuple[int]): The height and width of the window.
|
80 |
+
num_heads (int): Number of attention heads.
|
81 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
82 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
83 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
84 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
dim,
|
90 |
+
window_size,
|
91 |
+
num_heads,
|
92 |
+
qkv_bias=True,
|
93 |
+
qk_scale=None,
|
94 |
+
attn_drop=0.0,
|
95 |
+
proj_drop=0.0,
|
96 |
+
):
|
97 |
+
|
98 |
+
super().__init__()
|
99 |
+
self.dim = dim
|
100 |
+
self.window_size = window_size # Wh, Ww
|
101 |
+
self.num_heads = num_heads
|
102 |
+
head_dim = dim // num_heads
|
103 |
+
self.scale = qk_scale or head_dim ** -0.5
|
104 |
+
|
105 |
+
# define a parameter table of relative position bias
|
106 |
+
self.relative_position_bias_table = nn.Parameter(
|
107 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
108 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
109 |
+
|
110 |
+
# get pair-wise relative position index for each token inside the window
|
111 |
+
coords_h = torch.arange(self.window_size[0])
|
112 |
+
coords_w = torch.arange(self.window_size[1])
|
113 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
114 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
115 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
116 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
117 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
118 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
119 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
120 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
121 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
122 |
+
|
123 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
124 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
125 |
+
self.proj = nn.Linear(dim, dim)
|
126 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
127 |
+
|
128 |
+
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
129 |
+
self.softmax = nn.Softmax(dim=-1)
|
130 |
+
|
131 |
+
def forward(self, x, mask=None):
|
132 |
+
"""Forward function.
|
133 |
+
Args:
|
134 |
+
x: input features with shape of (num_windows*B, N, C)
|
135 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
136 |
+
"""
|
137 |
+
B_, N, C = x.shape
|
138 |
+
qkv = (
|
139 |
+
self.qkv(x)
|
140 |
+
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
|
141 |
+
.permute(2, 0, 3, 1, 4)
|
142 |
+
)
|
143 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
144 |
+
|
145 |
+
q = q * self.scale
|
146 |
+
attn = q @ k.transpose(-2, -1)
|
147 |
+
|
148 |
+
relative_position_bias = self.relative_position_bias_table[
|
149 |
+
self.relative_position_index.view(-1)
|
150 |
+
].view(
|
151 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
|
152 |
+
) # Wh*Ww,Wh*Ww,nH
|
153 |
+
relative_position_bias = relative_position_bias.permute(
|
154 |
+
2, 0, 1
|
155 |
+
).contiguous() # nH, Wh*Ww, Wh*Ww
|
156 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
157 |
+
|
158 |
+
if mask is not None:
|
159 |
+
nW = mask.shape[0]
|
160 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
161 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
162 |
+
attn = self.softmax(attn)
|
163 |
+
else:
|
164 |
+
attn = self.softmax(attn)
|
165 |
+
|
166 |
+
attn = self.attn_drop(attn)
|
167 |
+
|
168 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
169 |
+
x = self.proj(x)
|
170 |
+
x = self.proj_drop(x)
|
171 |
+
return x
|
172 |
+
|
173 |
+
|
174 |
+
class SwinTransformerBlock(nn.Module):
|
175 |
+
"""Swin Transformer Block.
|
176 |
+
Args:
|
177 |
+
dim (int): Number of input channels.
|
178 |
+
num_heads (int): Number of attention heads.
|
179 |
+
window_size (int): Window size.
|
180 |
+
shift_size (int): Shift size for SW-MSA.
|
181 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
182 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
183 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
184 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
185 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
186 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
187 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
188 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
189 |
+
"""
|
190 |
+
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
dim,
|
194 |
+
num_heads,
|
195 |
+
window_size=7,
|
196 |
+
shift_size=0,
|
197 |
+
mlp_ratio=4.0,
|
198 |
+
qkv_bias=True,
|
199 |
+
qk_scale=None,
|
200 |
+
drop=0.0,
|
201 |
+
attn_drop=0.0,
|
202 |
+
drop_path=0.0,
|
203 |
+
act_layer=nn.GELU,
|
204 |
+
norm_layer=nn.LayerNorm,
|
205 |
+
):
|
206 |
+
super().__init__()
|
207 |
+
self.dim = dim
|
208 |
+
self.num_heads = num_heads
|
209 |
+
self.window_size = window_size
|
210 |
+
self.shift_size = shift_size
|
211 |
+
self.mlp_ratio = mlp_ratio
|
212 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
213 |
+
|
214 |
+
self.norm1 = norm_layer(dim)
|
215 |
+
self.attn = WindowAttention(
|
216 |
+
dim,
|
217 |
+
window_size=to_2tuple(self.window_size),
|
218 |
+
num_heads=num_heads,
|
219 |
+
qkv_bias=qkv_bias,
|
220 |
+
qk_scale=qk_scale,
|
221 |
+
attn_drop=attn_drop,
|
222 |
+
proj_drop=drop,
|
223 |
+
)
|
224 |
+
|
225 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
226 |
+
self.norm2 = norm_layer(dim)
|
227 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
228 |
+
self.mlp = Mlp(
|
229 |
+
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
|
230 |
+
)
|
231 |
+
|
232 |
+
self.H = None
|
233 |
+
self.W = None
|
234 |
+
|
235 |
+
def forward(self, x, mask_matrix):
|
236 |
+
"""Forward function.
|
237 |
+
Args:
|
238 |
+
x: Input feature, tensor size (B, H*W, C).
|
239 |
+
H, W: Spatial resolution of the input feature.
|
240 |
+
mask_matrix: Attention mask for cyclic shift.
|
241 |
+
"""
|
242 |
+
B, L, C = x.shape
|
243 |
+
H, W = self.H, self.W
|
244 |
+
assert L == H * W, "input feature has wrong size"
|
245 |
+
|
246 |
+
shortcut = x
|
247 |
+
x = self.norm1(x)
|
248 |
+
x = x.view(B, H, W, C)
|
249 |
+
|
250 |
+
# pad feature maps to multiples of window size
|
251 |
+
pad_l = pad_t = 0
|
252 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
253 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
254 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
255 |
+
_, Hp, Wp, _ = x.shape
|
256 |
+
|
257 |
+
# cyclic shift
|
258 |
+
if self.shift_size > 0:
|
259 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
260 |
+
attn_mask = mask_matrix
|
261 |
+
else:
|
262 |
+
shifted_x = x
|
263 |
+
attn_mask = None
|
264 |
+
|
265 |
+
# partition windows
|
266 |
+
x_windows = window_partition(
|
267 |
+
shifted_x, self.window_size
|
268 |
+
) # nW*B, window_size, window_size, C
|
269 |
+
x_windows = x_windows.view(
|
270 |
+
-1, self.window_size * self.window_size, C
|
271 |
+
) # nW*B, window_size*window_size, C
|
272 |
+
|
273 |
+
# W-MSA/SW-MSA
|
274 |
+
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
275 |
+
|
276 |
+
# merge windows
|
277 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
278 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
279 |
+
|
280 |
+
# reverse cyclic shift
|
281 |
+
if self.shift_size > 0:
|
282 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
283 |
+
else:
|
284 |
+
x = shifted_x
|
285 |
+
|
286 |
+
if pad_r > 0 or pad_b > 0:
|
287 |
+
x = x[:, :H, :W, :].contiguous()
|
288 |
+
|
289 |
+
x = x.view(B, H * W, C)
|
290 |
+
|
291 |
+
# FFN
|
292 |
+
x = shortcut + self.drop_path(x)
|
293 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
294 |
+
|
295 |
+
return x
|
296 |
+
|
297 |
+
|
298 |
+
class PatchMerging(nn.Module):
|
299 |
+
"""Patch Merging Layer
|
300 |
+
Args:
|
301 |
+
dim (int): Number of input channels.
|
302 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
303 |
+
"""
|
304 |
+
|
305 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
306 |
+
super().__init__()
|
307 |
+
self.dim = dim
|
308 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
309 |
+
self.norm = norm_layer(4 * dim)
|
310 |
+
|
311 |
+
def forward(self, x, H, W):
|
312 |
+
"""Forward function.
|
313 |
+
Args:
|
314 |
+
x: Input feature, tensor size (B, H*W, C).
|
315 |
+
H, W: Spatial resolution of the input feature.
|
316 |
+
"""
|
317 |
+
B, L, C = x.shape
|
318 |
+
assert L == H * W, "input feature has wrong size"
|
319 |
+
|
320 |
+
x = x.view(B, H, W, C)
|
321 |
+
|
322 |
+
# padding
|
323 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
324 |
+
if pad_input:
|
325 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
326 |
+
|
327 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
328 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
329 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
330 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
331 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
332 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
333 |
+
|
334 |
+
x = self.norm(x)
|
335 |
+
x = self.reduction(x)
|
336 |
+
|
337 |
+
return x
|
338 |
+
|
339 |
+
|
340 |
+
class BasicLayer(nn.Module):
|
341 |
+
"""A basic Swin Transformer layer for one stage.
|
342 |
+
Args:
|
343 |
+
dim (int): Number of feature channels
|
344 |
+
depth (int): Depths of this stage.
|
345 |
+
num_heads (int): Number of attention head.
|
346 |
+
window_size (int): Local window size. Default: 7.
|
347 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
348 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
349 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
350 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
351 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
352 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
353 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
354 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
355 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
356 |
+
"""
|
357 |
+
|
358 |
+
def __init__(
|
359 |
+
self,
|
360 |
+
dim,
|
361 |
+
depth,
|
362 |
+
num_heads,
|
363 |
+
window_size=7,
|
364 |
+
mlp_ratio=4.0,
|
365 |
+
qkv_bias=True,
|
366 |
+
qk_scale=None,
|
367 |
+
drop=0.0,
|
368 |
+
attn_drop=0.0,
|
369 |
+
drop_path=0.0,
|
370 |
+
norm_layer=nn.LayerNorm,
|
371 |
+
downsample=None,
|
372 |
+
use_checkpoint=False,
|
373 |
+
):
|
374 |
+
super().__init__()
|
375 |
+
self.window_size = window_size
|
376 |
+
self.shift_size = window_size // 2
|
377 |
+
self.depth = depth
|
378 |
+
self.use_checkpoint = use_checkpoint
|
379 |
+
|
380 |
+
# build blocks
|
381 |
+
self.blocks = nn.ModuleList(
|
382 |
+
[
|
383 |
+
SwinTransformerBlock(
|
384 |
+
dim=dim,
|
385 |
+
num_heads=num_heads,
|
386 |
+
window_size=window_size,
|
387 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
388 |
+
mlp_ratio=mlp_ratio,
|
389 |
+
qkv_bias=qkv_bias,
|
390 |
+
qk_scale=qk_scale,
|
391 |
+
drop=drop,
|
392 |
+
attn_drop=attn_drop,
|
393 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
394 |
+
norm_layer=norm_layer,
|
395 |
+
)
|
396 |
+
for i in range(depth)
|
397 |
+
]
|
398 |
+
)
|
399 |
+
|
400 |
+
# patch merging layer
|
401 |
+
if downsample is not None:
|
402 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
403 |
+
else:
|
404 |
+
self.downsample = None
|
405 |
+
|
406 |
+
def forward(self, x, H, W):
|
407 |
+
"""Forward function.
|
408 |
+
Args:
|
409 |
+
x: Input feature, tensor size (B, H*W, C).
|
410 |
+
H, W: Spatial resolution of the input feature.
|
411 |
+
"""
|
412 |
+
|
413 |
+
# calculate attention mask for SW-MSA
|
414 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
415 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
416 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
417 |
+
h_slices = (
|
418 |
+
slice(0, -self.window_size),
|
419 |
+
slice(-self.window_size, -self.shift_size),
|
420 |
+
slice(-self.shift_size, None),
|
421 |
+
)
|
422 |
+
w_slices = (
|
423 |
+
slice(0, -self.window_size),
|
424 |
+
slice(-self.window_size, -self.shift_size),
|
425 |
+
slice(-self.shift_size, None),
|
426 |
+
)
|
427 |
+
cnt = 0
|
428 |
+
for h in h_slices:
|
429 |
+
for w in w_slices:
|
430 |
+
img_mask[:, h, w, :] = cnt
|
431 |
+
cnt += 1
|
432 |
+
|
433 |
+
mask_windows = window_partition(
|
434 |
+
img_mask, self.window_size
|
435 |
+
) # nW, window_size, window_size, 1
|
436 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
437 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
438 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
|
439 |
+
attn_mask == 0, float(0.0)
|
440 |
+
)
|
441 |
+
|
442 |
+
for blk in self.blocks:
|
443 |
+
blk.H, blk.W = H, W
|
444 |
+
if self.use_checkpoint:
|
445 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
446 |
+
else:
|
447 |
+
x = blk(x, attn_mask)
|
448 |
+
if self.downsample is not None:
|
449 |
+
x_down = self.downsample(x, H, W)
|
450 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
451 |
+
return x, H, W, x_down, Wh, Ww
|
452 |
+
else:
|
453 |
+
return x, H, W, x, H, W
|
454 |
+
|
455 |
+
|
456 |
+
class PatchEmbed(nn.Module):
|
457 |
+
"""Image to Patch Embedding
|
458 |
+
Args:
|
459 |
+
patch_size (int): Patch token size. Default: 4.
|
460 |
+
in_chans (int): Number of input image channels. Default: 3.
|
461 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
462 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
463 |
+
"""
|
464 |
+
|
465 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
466 |
+
super().__init__()
|
467 |
+
patch_size = to_2tuple(patch_size)
|
468 |
+
self.patch_size = patch_size
|
469 |
+
|
470 |
+
self.in_chans = in_chans
|
471 |
+
self.embed_dim = embed_dim
|
472 |
+
|
473 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
474 |
+
if norm_layer is not None:
|
475 |
+
self.norm = norm_layer(embed_dim)
|
476 |
+
else:
|
477 |
+
self.norm = None
|
478 |
+
|
479 |
+
def forward(self, x):
|
480 |
+
"""Forward function."""
|
481 |
+
# padding
|
482 |
+
_, _, H, W = x.size()
|
483 |
+
if W % self.patch_size[1] != 0:
|
484 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
485 |
+
if H % self.patch_size[0] != 0:
|
486 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
487 |
+
|
488 |
+
x = self.proj(x) # B C Wh Ww
|
489 |
+
if self.norm is not None:
|
490 |
+
Wh, Ww = x.size(2), x.size(3)
|
491 |
+
x = x.flatten(2).transpose(1, 2)
|
492 |
+
x = self.norm(x)
|
493 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
494 |
+
|
495 |
+
return x
|
496 |
+
|
497 |
+
|
498 |
+
class SwinTransformer(nn.Module):
|
499 |
+
"""Swin Transformer backbone.
|
500 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
501 |
+
https://arxiv.org/pdf/2103.14030
|
502 |
+
Args:
|
503 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
504 |
+
used in absolute postion embedding. Default 224.
|
505 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
506 |
+
in_chans (int): Number of input image channels. Default: 3.
|
507 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
508 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
509 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
510 |
+
window_size (int): Window size. Default: 7.
|
511 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
512 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
513 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
514 |
+
drop_rate (float): Dropout rate.
|
515 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
516 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
517 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
518 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
519 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
520 |
+
out_indices (Sequence[int]): Output from which stages.
|
521 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
522 |
+
-1 means not freezing any parameters.
|
523 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
524 |
+
"""
|
525 |
+
|
526 |
+
def __init__(
|
527 |
+
self,
|
528 |
+
pretrain_img_size=224,
|
529 |
+
patch_size=4,
|
530 |
+
in_chans=3,
|
531 |
+
embed_dim=96,
|
532 |
+
depths=[2, 2, 6, 2],
|
533 |
+
num_heads=[3, 6, 12, 24],
|
534 |
+
window_size=7,
|
535 |
+
mlp_ratio=4.0,
|
536 |
+
qkv_bias=True,
|
537 |
+
qk_scale=None,
|
538 |
+
drop_rate=0.0,
|
539 |
+
attn_drop_rate=0.0,
|
540 |
+
drop_path_rate=0.2,
|
541 |
+
norm_layer=nn.LayerNorm,
|
542 |
+
ape=False,
|
543 |
+
patch_norm=True,
|
544 |
+
out_indices=(0, 1, 2), #3),
|
545 |
+
frozen_stages=-1,
|
546 |
+
use_checkpoint=False,
|
547 |
+
):
|
548 |
+
super().__init__()
|
549 |
+
|
550 |
+
self.pretrain_img_size = pretrain_img_size
|
551 |
+
self.num_layers = len(depths)
|
552 |
+
self.embed_dim = embed_dim
|
553 |
+
self.ape = ape
|
554 |
+
self.patch_norm = patch_norm
|
555 |
+
self.out_indices = out_indices
|
556 |
+
self.frozen_stages = frozen_stages
|
557 |
+
|
558 |
+
# split image into non-overlapping patches
|
559 |
+
self.patch_embed = PatchEmbed(
|
560 |
+
patch_size=patch_size,
|
561 |
+
in_chans=in_chans,
|
562 |
+
embed_dim=embed_dim,
|
563 |
+
norm_layer=norm_layer if self.patch_norm else None,
|
564 |
+
)
|
565 |
+
|
566 |
+
# absolute position embedding
|
567 |
+
if self.ape:
|
568 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
569 |
+
patch_size = to_2tuple(patch_size)
|
570 |
+
patches_resolution = [
|
571 |
+
pretrain_img_size[0] // patch_size[0],
|
572 |
+
pretrain_img_size[1] // patch_size[1],
|
573 |
+
]
|
574 |
+
|
575 |
+
self.absolute_pos_embed = nn.Parameter(
|
576 |
+
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
|
577 |
+
)
|
578 |
+
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
579 |
+
|
580 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
581 |
+
|
582 |
+
# stochastic depth
|
583 |
+
dpr = [
|
584 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
585 |
+
] # stochastic depth decay rule
|
586 |
+
|
587 |
+
# build layers
|
588 |
+
self.layers = nn.ModuleList()
|
589 |
+
for i_layer in range(self.num_layers):
|
590 |
+
layer = BasicLayer(
|
591 |
+
dim=int(embed_dim * 2 ** i_layer),
|
592 |
+
depth=depths[i_layer],
|
593 |
+
num_heads=num_heads[i_layer],
|
594 |
+
window_size=window_size,
|
595 |
+
mlp_ratio=mlp_ratio,
|
596 |
+
qkv_bias=qkv_bias,
|
597 |
+
qk_scale=qk_scale,
|
598 |
+
drop=drop_rate,
|
599 |
+
attn_drop=attn_drop_rate,
|
600 |
+
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
601 |
+
norm_layer=norm_layer,
|
602 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
603 |
+
use_checkpoint=use_checkpoint,
|
604 |
+
)
|
605 |
+
self.layers.append(layer)
|
606 |
+
|
607 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
608 |
+
self.num_features = num_features
|
609 |
+
|
610 |
+
# add a norm layer for each output
|
611 |
+
for i_layer in out_indices:
|
612 |
+
layer = norm_layer(num_features[i_layer])
|
613 |
+
layer_name = f"norm{i_layer}"
|
614 |
+
self.add_module(layer_name, layer)
|
615 |
+
|
616 |
+
self._freeze_stages()
|
617 |
+
|
618 |
+
def _freeze_stages(self):
|
619 |
+
if self.frozen_stages >= 0:
|
620 |
+
self.patch_embed.eval()
|
621 |
+
for param in self.patch_embed.parameters():
|
622 |
+
param.requires_grad = False
|
623 |
+
|
624 |
+
if self.frozen_stages >= 1 and self.ape:
|
625 |
+
self.absolute_pos_embed.requires_grad = False
|
626 |
+
|
627 |
+
if self.frozen_stages >= 2:
|
628 |
+
self.pos_drop.eval()
|
629 |
+
for i in range(0, self.frozen_stages - 1):
|
630 |
+
m = self.layers[i]
|
631 |
+
m.eval()
|
632 |
+
for param in m.parameters():
|
633 |
+
param.requires_grad = False
|
634 |
+
|
635 |
+
def init_weights(self, pretrained=None):
|
636 |
+
"""Initialize the weights in backbone.
|
637 |
+
Args:
|
638 |
+
pretrained (str, optional): Path to pre-trained weights.
|
639 |
+
Defaults to None.
|
640 |
+
"""
|
641 |
+
|
642 |
+
def _init_weights(m):
|
643 |
+
if isinstance(m, nn.Linear):
|
644 |
+
trunc_normal_(m.weight, std=0.02)
|
645 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
646 |
+
nn.init.constant_(m.bias, 0)
|
647 |
+
elif isinstance(m, nn.LayerNorm):
|
648 |
+
nn.init.constant_(m.bias, 0)
|
649 |
+
nn.init.constant_(m.weight, 1.0)
|
650 |
+
|
651 |
+
def forward(self, x):
|
652 |
+
"""Forward function."""
|
653 |
+
x = self.patch_embed(x)
|
654 |
+
|
655 |
+
Wh, Ww = x.size(2), x.size(3)
|
656 |
+
if self.ape:
|
657 |
+
# interpolate the position embedding to the corresponding size
|
658 |
+
absolute_pos_embed = F.interpolate(
|
659 |
+
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
|
660 |
+
)
|
661 |
+
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
662 |
+
else:
|
663 |
+
x = x.flatten(2).transpose(1, 2)
|
664 |
+
x = self.pos_drop(x)
|
665 |
+
|
666 |
+
outs = {}
|
667 |
+
for i in range(self.num_layers):
|
668 |
+
layer = self.layers[i]
|
669 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
670 |
+
|
671 |
+
if i in self.out_indices:
|
672 |
+
norm_layer = getattr(self, f"norm{i}")
|
673 |
+
x_out = norm_layer(x_out)
|
674 |
+
|
675 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
676 |
+
outs["res{}".format(i + 2)] = out
|
677 |
+
|
678 |
+
return outs
|
679 |
+
|
680 |
+
def train(self, mode=True):
|
681 |
+
"""Convert the model into training mode while keep layers freezed."""
|
682 |
+
super(SwinTransformer, self).train(mode)
|
683 |
+
self._freeze_stages()
|
684 |
+
|
685 |
+
|
686 |
+
@BACKBONE_REGISTRY.register()
|
687 |
+
class D2SwinTransformer(SwinTransformer, Backbone):
|
688 |
+
def __init__(self, cfg, input_shape):
|
689 |
+
|
690 |
+
pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
|
691 |
+
patch_size = cfg.MODEL.SWIN.PATCH_SIZE
|
692 |
+
in_chans = 3
|
693 |
+
embed_dim = cfg.MODEL.SWIN.EMBED_DIM
|
694 |
+
depths = cfg.MODEL.SWIN.DEPTHS
|
695 |
+
num_heads = cfg.MODEL.SWIN.NUM_HEADS
|
696 |
+
window_size = cfg.MODEL.SWIN.WINDOW_SIZE
|
697 |
+
mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
|
698 |
+
qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
|
699 |
+
qk_scale = cfg.MODEL.SWIN.QK_SCALE
|
700 |
+
drop_rate = cfg.MODEL.SWIN.DROP_RATE
|
701 |
+
attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
|
702 |
+
drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
|
703 |
+
norm_layer = nn.LayerNorm
|
704 |
+
ape = cfg.MODEL.SWIN.APE
|
705 |
+
patch_norm = cfg.MODEL.SWIN.PATCH_NORM
|
706 |
+
|
707 |
+
super().__init__(
|
708 |
+
pretrain_img_size,
|
709 |
+
patch_size,
|
710 |
+
in_chans,
|
711 |
+
embed_dim,
|
712 |
+
depths,
|
713 |
+
num_heads,
|
714 |
+
window_size,
|
715 |
+
mlp_ratio,
|
716 |
+
qkv_bias,
|
717 |
+
qk_scale,
|
718 |
+
drop_rate,
|
719 |
+
attn_drop_rate,
|
720 |
+
drop_path_rate,
|
721 |
+
norm_layer,
|
722 |
+
ape,
|
723 |
+
patch_norm,
|
724 |
+
)
|
725 |
+
|
726 |
+
self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
|
727 |
+
|
728 |
+
self._out_feature_strides = {
|
729 |
+
"res2": 4,
|
730 |
+
"res3": 8,
|
731 |
+
"res4": 16,
|
732 |
+
#"res5": 32,
|
733 |
+
}
|
734 |
+
self._out_feature_channels = {
|
735 |
+
"res2": self.num_features[0],
|
736 |
+
"res3": self.num_features[1],
|
737 |
+
"res4": self.num_features[2],
|
738 |
+
#"res5": self.num_features[3],
|
739 |
+
}
|
740 |
+
|
741 |
+
def forward(self, x):
|
742 |
+
"""
|
743 |
+
Args:
|
744 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
745 |
+
Returns:
|
746 |
+
dict[str->Tensor]: names and the corresponding features
|
747 |
+
"""
|
748 |
+
assert (
|
749 |
+
x.dim() == 4
|
750 |
+
), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
|
751 |
+
outputs = {}
|
752 |
+
y = super().forward(x)
|
753 |
+
for k in y.keys():
|
754 |
+
if k in self._out_features:
|
755 |
+
outputs[k] = y[k]
|
756 |
+
return outputs
|
757 |
+
|
758 |
+
def output_shape(self):
|
759 |
+
return {
|
760 |
+
name: ShapeSpec(
|
761 |
+
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
|
762 |
+
)
|
763 |
+
for name in self._out_features
|
764 |
+
}
|
765 |
+
|
766 |
+
@property
|
767 |
+
def size_divisibility(self):
|
768 |
+
return 32
|
cat_seg/modeling/heads/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|