changed to pipelines
- __pycache__/constants.cpython-39.pyc +0 -0
- __pycache__/models.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- app.py +55 -60
- detection_models/__init__.py +0 -0
- detection_models/__pycache__/__init__.cpython-39.pyc +0 -0
- detection_models/yolo_stamp/__init__.py +0 -0
- detection_models/yolo_stamp/__pycache__/__init__.cpython-39.pyc +0 -0
- detection_models/yolo_stamp/__pycache__/constants.cpython-39.pyc +0 -0
- detection_models/yolo_stamp/__pycache__/model.cpython-39.pyc +0 -0
- detection_models/yolo_stamp/__pycache__/utils.cpython-39.pyc +0 -0
- constants.py → detection_models/yolo_stamp/constants.py +0 -8
- detection_models/yolo_stamp/data.py +141 -0
- detection_models/yolo_stamp/loss.py +52 -0
- detection_models/yolo_stamp/model.py +80 -0
- detection_models/yolo_stamp/train.ipynb +185 -0
- detection_models/yolo_stamp/utils.py +275 -0
- detection_models/yolov8/__init__.py +0 -0
- detection_models/yolov8/train.ipynb +144 -0
- embedding_models/__init__.py +0 -0
- embedding_models/__pycache__/__init__.cpython-39.pyc +0 -0
- embedding_models/vae/__init__.py +0 -0
- embedding_models/vae/__pycache__/__init__.cpython-39.pyc +0 -0
- embedding_models/vae/__pycache__/constants.cpython-39.pyc +0 -0
- embedding_models/vae/__pycache__/model.cpython-39.pyc +0 -0
- embedding_models/vae/constants.py +6 -0
- embedding_models/vae/losses.py +77 -0
- embedding_models/vae/model.py +147 -0
- embedding_models/vae/train.ipynb +393 -0
- embedding_models/vits8/__init__.py +0 -0
- embedding_models/vits8/example.py +10 -0
- embedding_models/vits8/model.py +13 -0
- embedding_models/vits8/oml/__init__.py +0 -0
- embedding_models/vits8/oml/create_dataset.py +71 -0
- embedding_models/vits8/oml/data/test/images/99d_15.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99e_20.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99f_25.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99g_30.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99h_35.bmp +0 -0
- embedding_models/vits8/oml/data/test/images/99i_40.bmp +0 -0
- embedding_models/vits8/oml/data/train_val/df_stamps.csv +41 -0
- embedding_models/vits8/oml/data/train_val/images/circle12_1239.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle15_1244.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle19_1236.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle21_1249.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle24_1234.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle24_1246.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle2_1233.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle2_1246.png +0 -0
- embedding_models/vits8/oml/data/train_val/images/circle2_1249.png +0 -0
__pycache__/constants.cpython-39.pyc
ADDED
Binary file (660 Bytes)

__pycache__/models.cpython-39.pyc
ADDED
Binary file (5.22 kB)

__pycache__/utils.cpython-39.pyc
ADDED
Binary file (7.73 kB)
app.py
CHANGED
@@ -1,102 +1,97 @@
 import gradio as gr
-import numpy as np
-from ultralytics import YOLO
-from torchvision.transforms.functional import to_tensor
-from huggingface_hub import hf_hub_download
 import torch
-import
-from albumentations.pytorch.transforms import ToTensorV2
-import pandas as pd
+import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
+import pandas as pd
+from PIL import Image, ImageDraw
+import matplotlib.pyplot as plt
+import matplotlib
 
-from
-from
-
-
-
-
-yolov8 = YOLO(hf_hub_download('stamps-labs/yolov8-finetuned', filename='best.torchscript'), task='detect')
+from pipelines.detection.yolo_v8 import Yolov8Pipeline
+from pipelines.detection.yolo_stamp import YoloStampPipeline
+from pipelines.segmentation.deeplabv3 import DeepLabv3Pipeline
+from pipelines.feature_extraction.vae import VaePipeline
+from pipelines.feature_extraction.vits8 import Vits8Pipeline
 
-
-yolo_stamp.load_state_dict(torch.load(hf_hub_download('stamps-labs/yolo-stamp', filename='state_dict.pth'), map_location='cpu'))
-yolo_stamp = yolo_stamp.to(device)
-yolo_stamp.eval()
-transform = A.Compose([
-    A.Normalize(),
-    ToTensorV2(p=1.0),
-])
+from utils import *
 
-vits8 = torch.jit.load(hf_hub_download('stamps-labs/vits8-stamp', filename='vits8stamp-torchscript.pth'), map_location='cpu')
-vits8 = vits8.to(device)
-vits8.eval()
-
-
-
-
+yolov8 = Yolov8Pipeline.from_pretrained(local_model_path='yolov8_old_backup.pt')
+yolo_stamp = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', 'weights.pt')
+vae = VaePipeline.from_pretrained('stamps-labs/vae-encoder', 'weights.pt')
+vits8 = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', 'weights.pt')
+dlv3 = DeepLabv3Pipeline.from_pretrained('stamps-labs/deeplabv3-finetuned', 'weights.pt')
 
 
-def
+def doc_predict(image, det_choice, seg_choice, emb_choice):
 
-    shape = torch.tensor(image.size)
     image = image.convert('RGB')
 
     if det_choice == 'yolov8':
-        image = image.resize((640, 640))
-        boxes = yolov8(image)[0].boxes.xyxy.cpu()
-        image_with_boxes = visualize_bbox(image, boxes)
+        boxes = yolov8(image)
 
     elif det_choice == 'yolo-stamp':
-        image = image.resize((448, 448))
-        image_tensor = transform(image=np.array(image))['image']
-        output = yolo_stamp(image_tensor.unsqueeze(0).to(device))
-
-        boxes = output_tensor_to_boxes(output[0].detach().cpu())
-        boxes = nonmax_suppression(boxes)
-        boxes = xywh2xyxy(torch.tensor(boxes)[:, :4])
-        image_with_boxes = visualize_bbox(image, boxes)
     else:
         return
-
+        boxes = yolo_stamp(image)
+    image_with_boxes = visualize_bbox(image, boxes)
+
+    segmented_stamps = []
+    for box in boxes:
+        cropped_stamp = image.crop(box.tolist())
+        segmented_stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp)
+
+    widths, heights = zip(*(i.size for i in segmented_stamps))
+
+    total_width = sum(widths)
+    max_height = max(heights)
+
+    concatenated_stamps = Image.new('RGB', (total_width, max_height))
+
+    x_offset = 0
+    for im in segmented_stamps:
+        concatenated_stamps.paste(im, (x_offset, 0))
+        x_offset += im.size[0]
 
     embeddings = []
     if emb_choice == 'vits8':
-        for
-            embeddings.append(vits8(cropped_stamp.unsqueeze(0).to(device))[0].detach().cpu())
+        for stamp in segmented_stamps:
+            embeddings.append(vits8(stamp))
 
     elif emb_choice == 'vae-encoder':
-        for
-            embeddings.append(np.array(encoder(cropped_stamp.unsqueeze(0).to(device))[0][0].detach().cpu()))
+        for stamp in segmented_stamps:
+            embeddings.append(vae(stamp))
 
     embeddings = np.stack(embeddings)
 
     similarities = cosine_similarity(embeddings)
 
-    boxes = boxes * coef
     df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])
 
     fig, ax = plt.subplots()
     im, cbar = heatmap(similarities, range(1, len(embeddings) + 1), range(1, len(embeddings) + 1), ax=ax,
                        cmap="YlGn", cbarlabel="Embeddings similarities")
     texts = annotate_heatmap(im, valfmt="{x:.3f}")
-    return image_with_boxes, df_boxes, embeddings, fig
+    return image_with_boxes, df_boxes, concatenated_stamps, embeddings, fig
 
 
-
-
-    gr.Image(type="pil"),
+doc_examples = [['examples/1.jpg', 'yolov8', True, 'vits8'], ['examples/2.jpg', 'yolo-stamp', False, 'vae-encoder'], ['examples/3.jpg', 'yolov8', True, 'vits8']]
+doc_inputs = [
+    gr.Image(label="Document image", type="pil"),
     gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model'),
+    gr.Checkbox(label="Use segmentation model"),
     gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model'),
 ]
-
-    gr.Image(type="pil"),
+doc_outputs = [
+    gr.Image(label="Document with bounding boxes", type="pil"),
     gr.DataFrame(type='pandas', label="Bounding boxes"),
+    gr.Image(label="Segmented stamps", type="pil"),
     gr.DataFrame(type='numpy', label="Embeddings"),
     gr.Plot(label="Cosine Similarities")
 ]
-
-
+
+with gr.Blocks() as demo:
+    with gr.Tab("Signle document"):
+        gr.Interface(doc_predict, doc_inputs, doc_outputs, examples=doc_examples)
+
+demo.launch(inline=False)
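The new app.py loads every model through pipeline wrappers that share a from_pretrained constructor and a callable interface. The following is a minimal sketch of using two of those pipelines outside Gradio, assuming only what the new app.py above shows; the example image path is a placeholder and the pipelines package comes from this repository.

# Minimal sketch, assuming the pipelines package from this repo is importable;
# 'examples/1.jpg' is a placeholder path.
from PIL import Image
from pipelines.detection.yolo_stamp import YoloStampPipeline
from pipelines.feature_extraction.vits8 import Vits8Pipeline

detector = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', 'weights.pt')
embedder = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', 'weights.pt')

image = Image.open('examples/1.jpg').convert('RGB')
boxes = detector(image)                                   # per-stamp boxes, croppable via box.tolist()
embeddings = [embedder(image.crop(box.tolist())) for box in boxes]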
detection_models/__init__.py
ADDED
File without changes

detection_models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (165 Bytes)

detection_models/yolo_stamp/__init__.py
ADDED
File without changes

detection_models/yolo_stamp/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (176 Bytes)

detection_models/yolo_stamp/__pycache__/constants.cpython-39.pyc
ADDED
Binary file (611 Bytes)

detection_models/yolo_stamp/__pycache__/model.cpython-39.pyc
ADDED
Binary file (3.07 kB)

detection_models/yolo_stamp/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (9.31 kB)
constants.py → detection_models/yolo_stamp/constants.py
RENAMED
@@ -23,11 +23,3 @@ STD = (0.229, 0.224, 0.225)
 MEAN = (0.485, 0.456, 0.406)
 # box color to show the bounding box on image
 BOX_COLOR = (0, 0, 255)
-
-
-# dimenstion of image embedding
-Z_DIM = 128
-# hidden dimensions for encoder model
-ENC_HIDDEN_DIM = 16
-# hidden dimensions for decoder model
-DEC_HIDDEN_DIM = 64
detection_models/yolo_stamp/data.py
ADDED
@@ -0,0 +1,141 @@
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image

from pathlib import Path
from random import randint

from utils import *

"""
Dataset class for storing stamps data.

Arguments:
    data -- list of dictionaries containing file_path (path to the image), box_nb (number of boxes on the image), and boxes of shape (4,)
    image_folder -- path to folder containing images
    transforms -- transforms from albumentations package
"""
class StampDataset(Dataset):
    def __init__(
            self,
            data=read_data(),
            image_folder=Path(IMAGE_FOLDER),
            transforms=None):
        self.data = data
        self.image_folder = image_folder
        self.transforms = transforms

    def __getitem__(self, idx):
        item = self.data[idx]
        image_fn = self.image_folder / item['file_path']
        boxes = item['boxes']
        box_nb = item['box_nb']
        labels = torch.zeros((box_nb, 2), dtype=torch.int64)
        labels[:, 0] = 1

        img = np.array(Image.open(image_fn))

        try:
            if self.transforms:
                sample = self.transforms(**{
                    "image": img,
                    "bboxes": boxes,
                    "labels": labels,
                })
                img = sample['image']
                boxes = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)
        except:
            return self.__getitem__(randint(0, len(self.data)-1))

        target_tensor = boxes_to_tensor(boxes.type(torch.float32))
        return img, target_tensor

    def __len__(self):
        return len(self.data)

def collate_fn(batch):
    return tuple(zip(*batch))


def get_datasets(data_path=ANNOTATIONS_PATH, train_transforms=None, val_transforms=None):
    """
    Creates StampDataset objects.

    Arguments:
    data_path -- string or Path, specifying path to annotations file
    train_transforms -- transforms to be applied during training
    val_transforms -- transforms to be applied during validation

    Returns:
    (train_dataset, val_dataset) -- tuple of StampDataset for training and validation
    """
    data = read_data(data_path)
    if train_transforms is None:
        train_transforms = A.Compose([
            A.RandomCropNearBBox(max_part_shift=0.6, p=0.4),
            A.Resize(height=448, width=448),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # A.Affine(scale=(0.9, 1.1), translate_percent=(0.05, 0.1), rotate=(-45, 45), shear=(-30, 30), p=0.3),
            # A.Blur(blur_limit=4, p=0.3),
            A.Normalize(),
            ToTensorV2(p=1.0),
        ],
        bbox_params={
            "format": "coco",
            'label_fields': ['labels']
        })

    if val_transforms is None:
        val_transforms = A.Compose([
            A.Resize(height=448, width=448),
            A.Normalize(),
            ToTensorV2(p=1.0),
        ],
        bbox_params={
            "format": "coco",
            'label_fields': ['labels']
        })
    train, test_data = train_test_split(data, test_size=0.1, shuffle=True)

    train_data, val_data = train_test_split(train, test_size=0.2, shuffle=True)

    train_dataset = StampDataset(train_data, transforms=train_transforms)
    val_dataset = StampDataset(val_data, transforms=val_transforms)
    test_dataset = StampDataset(test_data, transforms=val_transforms)

    return train_dataset, val_dataset, test_dataset


def get_loaders(batch_size=8, data_path=ANNOTATIONS_PATH, num_workers=0, train_transforms=None, val_transforms=None):
    """
    Creates DataLoader objects for the stamp datasets.

    Arguments:
    batch_size -- integer specifying the number of images in the batch
    data_path -- string or Path, specifying path to annotations file
    train_transforms -- transforms to be applied during training
    val_transforms -- transforms to be applied during validation

    Returns:
    (train_loader, val_loader) -- tuple of DataLoader for training and validation
    """
    train_dataset, val_dataset, _ = get_datasets(data_path)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn, drop_last=True)

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn)

    return train_loader, val_loader
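A minimal sketch of how these loaders are consumed, assuming ANNOTATIONS_PATH and IMAGE_FOLDER point at real data; because collate_fn zips the batch, each batch is a tuple of images and a tuple of grid targets rather than stacked tensors.

# Minimal sketch, assuming the annotations CSV and image folder exist.
from data import get_loaders

train_loader, val_loader = get_loaders(batch_size=8)
images, targets = next(iter(train_loader))   # tuples of length 8, see collate_fn
print(images[0].shape)                       # (3, 448, 448) after A.Resize + ToTensorV2
print(targets[0].shape)                      # (S, S, BOX, 5) grid target from boxes_to_tensor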
detection_models/yolo_stamp/loss.py
ADDED
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
from utils import *

"""
Class for loss for training YOLO model.

Argmunets:
    h_coord: weight for loss related to coordinates and shapes of box
    h__noobj: weight for loss of predicting presence of box when it is absent.
"""
class YOLOLoss(nn.Module):
    def __init__(self, h_coord=0.5, h_noobj=2., h_shape=2., h_obj=10.):
        super().__init__()
        self.h_coord = h_coord
        self.h_noobj = h_noobj
        self.h_shape = h_shape
        self.h_obj = h_obj

    def square_error(self, output, target):
        return (output - target) ** 2

    def forward(self, output, target):

        pred_xy, pred_wh, pred_obj = yolo_head(output)
        gt_xy, gt_wh, gt_obj = process_target(target)

        pred_ul = pred_xy - 0.5 * pred_wh
        pred_br = pred_xy + 0.5 * pred_wh
        pred_area = pred_wh[..., 0] * pred_wh[..., 1]

        gt_ul = gt_xy - 0.5 * gt_wh
        gt_br = gt_xy + 0.5 * gt_wh
        gt_area = gt_wh[..., 0] * gt_wh[..., 1]

        intersect_ul = torch.max(pred_ul, gt_ul)
        intersect_br = torch.min(pred_br, gt_br)
        intersect_wh = intersect_br - intersect_ul
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]

        iou = intersect_area / (pred_area + gt_area - intersect_area)
        max_iou = torch.max(iou, dim=3, keepdim=True)[0]
        best_box_index = torch.unsqueeze(torch.eq(iou, max_iou).float(), dim=-1)
        gt_box_conf = best_box_index * gt_obj

        xy_loss = (self.square_error(pred_xy, gt_xy) * gt_box_conf).sum()
        wh_loss = (self.square_error(pred_wh, gt_wh) * gt_box_conf).sum()
        obj_loss = (self.square_error(pred_obj, gt_obj) * gt_box_conf).sum()
        noobj_loss = (self.square_error(pred_obj, gt_obj) * (1 - gt_box_conf)).sum()

        total_loss = self.h_coord * xy_loss + self.h_shape * wh_loss + self.h_obj * obj_loss + self.h_noobj * noobj_loss
        return total_loss
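A shape-level sketch of calling this loss on dummy tensors. The real grid size S and anchor count BOX come from constants.py (not shown in this diff), so the concrete numbers below are illustrative assumptions; the anchor dimension of the dummy tensors must match len(ANCHORS) for yolo_head to broadcast.

# Shape-level sketch with dummy tensors; S=7 and BOX=3 are illustrative assumptions.
import torch
from loss import YOLOLoss

S, BOX = 7, 3
criterion = YOLOLoss()
pred = torch.randn(2, S, S, BOX, 5)     # raw network output, decoded inside yolo_head
target = torch.zeros(2, S, S, BOX, 5)   # grid target as produced by boxes_to_tensor
print(criterion(pred, target).item())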
detection_models/yolo_stamp/model.py
ADDED
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn

from .constants import *

"""
Class for custom activation.
"""
class SymReLU(nn.Module):
    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input):
        return torch.min(torch.max(input, -torch.ones_like(input)), torch.ones_like(input))

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str


"""
Class implementing YOLO-Stamp architecture described in https://link.springer.com/article/10.1134/S1054661822040046.
"""
class YOLOStamp(nn.Module):
    def __init__(
            self,
            anchors=ANCHORS,
            in_channels=3,
    ):
        super().__init__()

        self.register_buffer('anchors', torch.tensor(anchors))

        self.act = SymReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm1 = nn.BatchNorm2d(num_features=8)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm2 = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm3 = nn.BatchNorm2d(num_features=16)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm4 = nn.BatchNorm2d(num_features=16)
        self.conv5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm5 = nn.BatchNorm2d(num_features=16)
        self.conv6 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm6 = nn.BatchNorm2d(num_features=24)
        self.conv7 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm7 = nn.BatchNorm2d(num_features=24)
        self.conv8 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm8 = nn.BatchNorm2d(num_features=48)
        self.conv9 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm9 = nn.BatchNorm2d(num_features=48)
        self.conv10 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm10 = nn.BatchNorm2d(num_features=48)
        self.conv11 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm11 = nn.BatchNorm2d(num_features=64)
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.norm12 = nn.BatchNorm2d(num_features=256)
        self.conv13 = nn.Conv2d(in_channels=256, out_channels=len(anchors) * 5, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

    def forward(self, x, head=True):
        x = x.type(self.conv1.weight.dtype)
        x = self.act(self.pool(self.norm1(self.conv1(x))))
        x = self.act(self.pool(self.norm2(self.conv2(x))))
        x = self.act(self.pool(self.norm3(self.conv3(x))))
        x = self.act(self.pool(self.norm4(self.conv4(x))))
        x = self.act(self.pool(self.norm5(self.conv5(x))))
        x = self.act(self.norm6(self.conv6(x)))
        x = self.act(self.norm7(self.conv7(x)))
        x = self.act(self.pool(self.norm8(self.conv8(x))))
        x = self.act(self.norm9(self.conv9(x)))
        x = self.act(self.norm10(self.conv10(x)))
        x = self.act(self.norm11(self.conv11(x)))
        x = self.act(self.norm12(self.conv12(x)))
        x = self.conv13(x)
        nb, _, nh, nw = x.shape
        x = x.permute(0, 2, 3, 1).view(nb, nh, nw, self.anchors.shape[0], 5)
        return x
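The network applies six stride-2 max-pools, so the 448x448 inputs used throughout this repo are reduced to a 7x7 grid, and the head reshapes the last convolution into one (x, y, w, h, objectness) vector per anchor per cell. A quick sketch of that output shape; the anchor count itself comes from ANCHORS in constants.py.

# Sketch: forward a dummy batch through YOLOStamp and inspect the output layout.
import torch
from model import YOLOStamp

model = YOLOStamp()                 # anchors default to ANCHORS from constants.py
out = model(torch.randn(1, 3, 448, 448))
print(out.shape)                    # (1, 7, 7, len(ANCHORS), 5)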
detection_models/yolo_stamp/train.ipynb
ADDED
@@ -0,0 +1,185 @@
# %%
from model import *
from loss import *
from data import *
from torch import optim
from tqdm import tqdm

import pytorch_lightning as pl
from torchmetrics.detection import MeanAveragePrecision
from pytorch_lightning.loggers import TensorBoardLogger

# %%
_, _, test_dataset = get_datasets()

# %%
class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = YOLOStamp()
        self.criterion = YOLOLoss()
        self.val_map = MeanAveragePrecision(box_format='xywh', iou_type='bbox')

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=1e-3)
        # return optimizer
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def training_step(self, batch, batch_idx):
        images, targets = batch
        tensor_images = torch.stack(images)
        tensor_targets = torch.stack(targets)
        output = self.model(tensor_images)
        loss = self.criterion(output, tensor_targets)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        tensor_images = torch.stack(images)
        tensor_targets = torch.stack(targets)
        output = self.model(tensor_images)
        loss = self.criterion(output, tensor_targets)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        for i in range(len(images)):
            boxes = output_tensor_to_boxes(output[i].detach().cpu())
            boxes = nonmax_suppression(boxes)
            target = target_tensor_to_boxes(targets[i])[::BOX]
            if not boxes:
                boxes = torch.zeros((1, 5))
            preds = [
                dict(
                    boxes=torch.tensor(boxes)[:, :4].clone().detach(),
                    scores=torch.tensor(boxes)[:, 4].clone().detach(),
                    labels=torch.zeros(len(boxes)),
                )
            ]
            target = [
                dict(
                    boxes=torch.tensor(target),
                    labels=torch.zeros(len(target)),
                )
            ]
            self.val_map.update(preds, target)

    def on_validation_epoch_end(self):
        mAPs = {"val_" + k: v for k, v in self.val_map.compute().items()}
        mAPs_per_class = mAPs.pop("val_map_per_class")
        mARs_per_class = mAPs.pop("val_mar_100_per_class")
        self.log_dict(mAPs)
        self.val_map.reset()

        image = test_dataset[randint(0, len(test_dataset) - 1)][0].to(self.device)
        output = self.model(image.unsqueeze(0))
        boxes = output_tensor_to_boxes(output[0].detach().cpu())
        boxes = nonmax_suppression(boxes)
        img = image.permute(1, 2, 0).cpu().numpy()
        img = visualize_bbox(img.copy(), boxes=boxes)
        img = (255. * (img * np.array(STD) + np.array(MEAN))).astype(np.uint8)

        self.logger.experiment.add_image("detected boxes", torch.tensor(img).permute(2, 0, 1), self.current_epoch)

# %%
litmodel = LitModel()

# %%
logger = TensorBoardLogger("detection_logs")

# %%
epochs = 100

# %%
train_loader, val_loader = get_loaders(batch_size=8)

# %%
trainer = pl.Trainer(accelerator="auto", max_epochs=epochs, logger=logger)
trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)

# %%
%tensorboard
detection_models/yolo_stamp/utils.py
ADDED
@@ -0,0 +1,275 @@
import torch
import cv2
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from .constants import *


def output_tensor_to_boxes(boxes_tensor):
    """
    Converts the YOLO output tensor to list of boxes with probabilites.

    Arguments:
    boxes_tensor -- tensor of shape (S, S, BOX, 5)

    Returns:
    boxes -- list of shape (None, 5)

    Note: "None" is here because you don't know the exact number of selected boxes, as it depends on the threshold.
    For example, the actual output size of scores would be (10, 5) if there are 10 boxes
    """
    cell_w, cell_h = W/S, H/S
    boxes = []

    for i in range(S):
        for j in range(S):
            for b in range(BOX):
                anchor_wh = torch.tensor(ANCHORS[b])
                data = boxes_tensor[i, j, b]
                xy = torch.sigmoid(data[:2])
                wh = torch.exp(data[2:4]) * anchor_wh
                obj_prob = torch.sigmoid(data[4])

                if obj_prob > OUTPUT_THRESH:
                    x_center, y_center, w, h = xy[0], xy[1], wh[0], wh[1]
                    x, y = x_center + j - w/2, y_center + i - h/2
                    x, y, w, h = x*cell_w, y*cell_h, w*cell_w, h*cell_h
                    box = [x, y, w, h, obj_prob]
                    boxes.append(box)
    return boxes


def plot_img(img, size=(7,7)):
    plt.figure(figsize=size)
    plt.imshow(img)
    plt.show()


def plot_normalized_img(img, std=STD, mean=MEAN, size=(7,7)):
    mean = mean if isinstance(mean, np.ndarray) else np.array(mean)
    std = std if isinstance(std, np.ndarray) else np.array(std)
    plt.figure(figsize=size)
    plt.imshow((255. * (img * std + mean)).astype(np.uint))
    plt.show()


def visualize_bbox(img, boxes, thickness=2, color=BOX_COLOR, draw_center=True):
    """
    Draws boxes on the given image.

    Arguments:
    img -- torch.Tensor of shape (3, W, H) or numpy.ndarray of shape (W, H, 3)
    boxes -- list of shape (None, 5)
    thickness -- number specifying the thickness of box border
    color -- RGB tuple of shape (3,) specifying the color of boxes
    draw_center -- boolean specifying whether to draw center or not

    Returns:
    img_copy -- numpy.ndarray of shape(W, H, 3) containing image with bouning boxes
    """
    img_copy = img.cpu().permute(1,2,0).numpy() if isinstance(img, torch.Tensor) else img.copy()
    for box in boxes:
        x, y, w, h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
        img_copy = cv2.rectangle(
            img_copy,
            (x, y), (x+w, y+h),
            color, thickness)
        if draw_center:
            center = (x+w//2, y+h//2)
            img_copy = cv2.circle(img_copy, center=center, radius=3, color=(0,255,0), thickness=2)
    return img_copy


def read_data(annotations=Path(ANNOTATIONS_PATH)):
    """
    Reads annotations data from .csv file. Must contain columns: image_name, bbox_x, bbox_y, bbox_width, bbox_height.

    Arguments:
    annotations_path -- string or Path specifying path of annotations file

    Returns:
    data -- list of dictionaries containing path, number of boxes and boxes itself
    """
    data = []

    boxes = pd.read_csv(annotations)
    image_names = boxes['image_name'].unique()

    for image_name in image_names:
        cur_boxes = boxes[boxes['image_name'] == image_name]
        img_data = {
            'file_path': image_name,
            'box_nb': len(cur_boxes),
            'boxes': []}
        stamp_nb = img_data['box_nb']
        if stamp_nb <= STAMP_NB_MAX:
            img_data['boxes'] = cur_boxes[['bbox_x', 'bbox_y', 'bbox_width', 'bbox_height']].values
            data.append(img_data)
    return data

def xywh2xyxy(x):
    """
    Converts xywh format to xyxy

    Arguments:
    x -- torch.Tensor or np.array (xywh format)

    Returns:
    y -- torch.Tensor or np.array (xyxy)
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0]
    y[..., 1] = x[..., 1]
    y[..., 2] = x[..., 0] + x[..., 2]
    y[..., 3] = x[..., 1] + x[..., 3]
    return y

def boxes_to_tensor(boxes):
    """
    Convert list of boxes (and labels) to tensor format

    Arguments:
    boxes -- list of boxes

    Returns:
    boxes_tensor -- tensor of shape (S, S, BOX, 5)
    """
    boxes_tensor = torch.zeros((S, S, BOX, 5))
    cell_w, cell_h = W/S, H/S
    for i, box in enumerate(boxes):
        x, y, w, h = box
        # normalize xywh with cell_size
        x, y, w, h = x / cell_w, y / cell_h, w / cell_w, h / cell_h
        center_x, center_y = x + w / 2, y + h / 2
        grid_x = int(np.floor(center_x))
        grid_y = int(np.floor(center_y))

        if grid_x < S and grid_y < S:
            boxes_tensor[grid_y, grid_x, :, 0:4] = torch.tensor(BOX * [[center_x - grid_x, center_y - grid_y, w, h]])
            boxes_tensor[grid_y, grid_x, :, 4] = torch.tensor(BOX * [1.])
    return boxes_tensor


def target_tensor_to_boxes(boxes_tensor, output_threshold=OUTPUT_THRESH):
    """
    Recover target tensor (tensor output of dataset) to bboxes.
    Arguments:
    boxes_tensor -- tensor of shape (S, S, BOX, 5)
    Returns:
    boxes -- list of boxes, each box is [x, y, w, h]
    """
    cell_w, cell_h = W/S, H/S
    boxes = []
    for i in range(S):
        for j in range(S):
            for b in range(BOX):
                data = boxes_tensor[i, j, b]
                x_center, y_center, w, h, obj_prob = data[0], data[1], data[2], data[3], data[4]
                if obj_prob > output_threshold:
                    x, y = x_center + j - w/2, y_center + i - h/2
                    x, y, w, h = x*cell_w, y*cell_h, w*cell_w, h*cell_h
                    box = [x, y, w, h]
                    boxes.append(box)
    return boxes


def overlap(interval_1, interval_2):
    """
    Calculates length of overlap between two intervals.

    Arguments:
    interval_1 -- list or tuple of shape (2,) containing endpoints of the first interval
    interval_2 -- list or tuple of shape (2, 2) containing endpoints of the second interval

    Returns:
    overlap -- length of overlap
    """
    x1, x2 = interval_1
    x3, x4 = interval_2
    if x3 < x1:
        if x4 < x1:
            return 0
        else:
            return min(x2, x4) - x1
    else:
        if x2 < x3:
            return 0
        else:
            return min(x2, x4) - x3


def compute_iou(box1, box2):
    """
    Compute IOU between box1 and box2.

    Argmunets:
    box1 -- list of shape (5, ). Represents the first box
    box2 -- list of shape (5, ). Represents the second box
    Each box is [x, y, w, h, prob]

    Returns:
    iou -- intersection over union score between two boxes
    """
    x1, y1, w1, h1 = box1[0], box1[1], box1[2], box1[3]
    x2, y2, w2, h2 = box2[0], box2[1], box2[2], box2[3]

    area1, area2 = w1*h1, w2*h2
    intersect_w = overlap((x1, x1+w1), (x2, x2+w2))
    intersect_h = overlap((y1, y1+h1), (y2, y2+w2))
    if intersect_w == w1 and intersect_h == h1 or intersect_w == w2 and intersect_h == h2:
        return 1.
    intersect_area = intersect_w*intersect_h
    iou = intersect_area/(area1 + area2 - intersect_area)
    return iou


def nonmax_suppression(boxes, iou_thresh=IOU_THRESH):
    """
    Removes ovelap bboxes

    Arguments:
    boxes -- list of shape (None, 5)
    iou_thresh -- maximal value of iou when boxes are considered different
    Each box is [x, y, w, h, prob]

    Returns:
    boxes -- list of shape (None, 5) with removed overlapping boxes
    """
    boxes = sorted(boxes, key=lambda x: x[4], reverse=True)
    for i, current_box in enumerate(boxes):
        if current_box[4] <= 0:
            continue
        for j in range(i+1, len(boxes)):
            iou = compute_iou(current_box, boxes[j])
            if iou > iou_thresh:
                boxes[j][4] = 0
    boxes = [box for box in boxes if box[4] > 0]
    return boxes


def yolo_head(yolo_output):
    """
    Converts a yolo output tensor to separate tensors of coordinates, shapes and probabilities.

    Arguments:
    yolo_output -- tensor of shape (batch_size, S, S, BOX, 5)

    Returns:
    xy -- tensor of shape (batch_size, S, S, BOX, 2) containing coordinates of centers of found boxes for each anchor in each grid cell
    wh -- tensor of shape (batch_size, S, S, BOX, 2) containing width and height of found boxes for each anchor in each grid cell
    prob -- tensor of shape (batch_size, S, S, BOX, 1) containing the probability of presence of boxes for each anchor in each grid cell
    """
    xy = torch.sigmoid(yolo_output[..., 0:2])
    anchors_wh = torch.tensor(ANCHORS, device=yolo_output.device).view(1, 1, 1, len(ANCHORS), 2)
    wh = torch.exp(yolo_output[..., 2:4]) * anchors_wh
    prob = torch.sigmoid(yolo_output[..., 4:5])
    return xy, wh, prob

def process_target(target):
    xy = target[..., 0:2]
    wh = target[..., 2:4]
    prob = target[..., 4:5]
    return xy, wh, prob
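These helpers form the post-processing chain used by the old app.py and the training notebook: threshold the raw grid output, suppress overlapping boxes, then convert to corner format for cropping or drawing. A short sketch of that chain, assuming a YOLOStamp instance named model and a normalized (3, 448, 448) tensor named image_tensor already exist.

# Sketch of the post-processing chain; `model` and `image_tensor` are assumed to be defined.
import torch

output = model(image_tensor.unsqueeze(0))
boxes = output_tensor_to_boxes(output[0].detach().cpu())   # [x, y, w, h, prob] above OUTPUT_THRESH
boxes = nonmax_suppression(boxes)                          # drop overlaps above IOU_THRESH
xyxy = xywh2xyxy(torch.tensor(boxes)[:, :4])               # corner format, e.g. for PIL Image.crop
drawn = visualize_bbox(image_tensor, boxes)                # draws the xywh boxes, returns an ndarray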
detection_models/yolov8/__init__.py
ADDED
File without changes
detection_models/yolov8/train.ipynb
ADDED
@@ -0,0 +1,144 @@
# %%
import os
HOME = os.getcwd()
print(HOME)

# %%
# Pip install method (recommended)

%pip install ultralytics==8.0.20

from IPython import display
display.clear_output()

import ultralytics
ultralytics.checks()

# %%
from ultralytics import YOLO

from IPython.display import display, Image

# %%
!mkdir {HOME}/datasets
%cd {HOME}/datasets

%pip install roboflow --quiet

from roboflow import Roboflow
rf = Roboflow(api_key="YOUR_API_KEY")
project = rf.workspace("WORKSPACE").project("PROJECT")
dataset = project.version(1).download("yolov8")

# %%
%cd {HOME}

!yolo task=detect mode=train model=yolov8s.pt data={dataset.location}/data.yaml epochs=25 imgsz=800 plots=True

# %%
%cd {HOME}
Image(filename=f'{HOME}/runs/detect/train/confusion_matrix.png', width=600)

# %%
%cd {HOME}
Image(filename=f'{HOME}/runs/detect/train/results.png', width=600)

# %%
%cd {HOME}
Image(filename=f'{HOME}/runs/detect/train/val_batch0_pred.jpg', width=600)

# %%
%cd {HOME}

!yolo task=detect mode=val model={HOME}/runs/detect/train/weights/best.pt data={dataset.location}/data.yaml

# %%
%cd {HOME}
!yolo task=detect mode=predict model={HOME}/runs/detect/train/weights/best.pt conf=0.25 source={dataset.location}/test/images save=True

# %%
import glob
from IPython.display import Image, display

for image_path in glob.glob(f'{HOME}/runs/detect/predict3/*.jpg')[:3]:
    display(Image(filename=image_path, width=600))
    print("\n")
embedding_models/__init__.py
ADDED
File without changes

embedding_models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (165 Bytes)

embedding_models/vae/__init__.py
ADDED
File without changes

embedding_models/vae/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (169 Bytes)

embedding_models/vae/__pycache__/constants.cpython-39.pyc
ADDED
Binary file (237 Bytes)

embedding_models/vae/__pycache__/model.cpython-39.pyc
ADDED
Binary file (5.87 kB)

embedding_models/vae/constants.py
ADDED
@@ -0,0 +1,6 @@
# dimenstion of image embedding
Z_DIM = 128
# hidden dimensions for encoder model
ENC_HIDDEN_DIM = 16
# hidden dimensions for decoder model
DEC_HIDDEN_DIM = 64
embedding_models/vae/losses.py
ADDED
@@ -0,0 +1,77 @@
import torch
import torch.nn as nn
from torch.distributions.kl import kl_divergence
from torch.distributions.normal import Normal
from torch.nn.functional import relu


class BatchHardTripletLoss(nn.Module):
    def __init__(self, margin=1., squared=False, agg='sum'):
        """
        Initalize the loss function with a margin parameter, whether or not to consider
        squared Euclidean distance and how to aggregate the loss in a batch
        """
        super().__init__()
        self.margin = margin
        self.squared = squared
        self.agg = agg
        self.eps = 1e-8

    def get_pairwise_distances(self, embeddings):
        """
        Computing Euclidean distance for all possible pairs of embeddings.
        """
        ab = embeddings.mm(embeddings.t())
        a_squared = ab.diag().unsqueeze(1)
        b_squared = ab.diag().unsqueeze(0)
        distances = a_squared - 2 * ab + b_squared
        distances = relu(distances)

        if not self.squared:
            distances = torch.sqrt(distances + self.eps)

        return distances

    def hardest_triplet_mining(self, dist_mat, labels):

        assert len(dist_mat.size()) == 2
        assert dist_mat.size(0) == dist_mat.size(1)
        N = dist_mat.size(0)

        is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
        is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

        dist_ap, relative_p_inds = torch.max(
            (dist_mat * is_pos), 1, keepdim=True)

        dist_an, relative_n_inds = torch.min(
            (dist_mat * is_neg), 1, keepdim=True)

        return dist_ap, dist_an

    def forward(self, embeddings, labels):

        distances = self.get_pairwise_distances(embeddings)
        dist_ap, dist_an = self.hardest_triplet_mining(distances, labels)

        triplet_loss = relu(dist_ap - dist_an + self.margin).sum()
        return triplet_loss


class VAELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.reconstruction_loss = nn.BCELoss(reduction='sum')

    def kl_divergence_loss(self, q_dist):
        return kl_divergence(
            q_dist, Normal(torch.zeros_like(q_dist.mean), torch.ones_like(q_dist.stddev))
        ).sum(-1)

    def forward(self, output, target, encoding):
        loss = self.kl_divergence_loss(encoding).sum() + self.reconstruction_loss(output, target)
        return loss
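A quick sketch of the batch-hard triplet loss on dummy embeddings, assuming only the class above: labels mark which embeddings come from the same stamp, and the loss takes the hardest positive and hardest negative per row of the pairwise-distance matrix. The batch size, embedding size and label layout below are illustrative.

# Sketch: batch-hard triplet loss on random embeddings; 128 matches Z_DIM but any size works.
import torch
from losses import BatchHardTripletLoss

criterion = BatchHardTripletLoss(margin=1.)
embeddings = torch.randn(8, 128, requires_grad=True)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])   # pairs of same-stamp crops
loss = criterion(embeddings, labels)
loss.backward()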
embedding_models/vae/model.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from torch.distributions.normal import Normal
|
3 |
+
|
4 |
+
from .constants import *
|
5 |
+
|
6 |
+
|
7 |
+
class Encoder(nn.Module):
|
8 |
+
'''
|
9 |
+
Encoder Class
|
10 |
+
Values:
|
11 |
+
im_chan: the number of channels of the output image, a scalar
|
12 |
+
hidden_dim: the inner dimension, a scalar
|
13 |
+
'''
|
14 |
+
|
15 |
+
def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
|
16 |
+
super(Encoder, self).__init__()
|
17 |
+
self.z_dim = output_chan
|
18 |
+
self.disc = nn.Sequential(
|
19 |
+
self.make_disc_block(im_chan, hidden_dim),
|
20 |
+
self.make_disc_block(hidden_dim, hidden_dim * 2),
|
21 |
+
self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
|
22 |
+
self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
|
23 |
+
            self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to an encoder block of the VAE:
        a convolution, a batchnorm (except in the last layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the Encoder: given an image tensor,
        returns the mean and standard deviation of the latent distribution.
        Parameters:
            image: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
        '''
        disc_pred = self.disc(image)
        encoding = disc_pred.view(len(disc_pred), -1)
        # The stddev output is treated as the log of the variance of the normal
        # distribution by convention and for numerical stability
        return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()


class Decoder(nn.Module):
    '''
    Decoder Class
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels of the output image, a scalar
        hidden_dim: the inner dimension, a scalar
    '''

    def __init__(self, z_dim=Z_DIM, im_chan=3, hidden_dim=DEC_HIDDEN_DIM):
        super(Decoder, self).__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 16),
            self.make_gen_block(hidden_dim * 16, hidden_dim * 8, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4),
            self.make_gen_block(hidden_dim * 2, hidden_dim, kernel_size=4),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a decoder block of the VAE:
        a transposed convolution, a batchnorm (except in the last layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Sigmoid(),
            )

    def forward(self, noise):
        '''
        Function for completing a forward pass of the Decoder: given a noise vector,
        returns a generated image.
        Parameters:
            noise: a noise tensor with dimensions (batch_size, z_dim)
        '''
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)


class VAE(nn.Module):
    '''
    VAE Class
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels of the output image, a scalar (3 for the RGB stamp crops used here)
        hidden_dim: the inner dimension, a scalar
    '''

    def __init__(self, z_dim=Z_DIM, im_chan=3):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        self.encode = Encoder(im_chan, z_dim)
        self.decode = Decoder(z_dim, im_chan)

    def forward(self, images):
        '''
        Function for completing a forward pass of the VAE: given an image tensor,
        returns its reconstruction and the latent distribution.
        Parameters:
            images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
        Returns:
            decoding: the autoencoded image
            q_dist: the z-distribution of the encoding
        '''
        q_mean, q_stddev = self.encode(images)
        q_dist = Normal(q_mean, q_stddev)
        z_sample = q_dist.rsample()  # Sample once from each distribution, using the reparameterization trick (rsample)
        decoding = self.decode(z_sample)
        return decoding, q_dist
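For orientation, a minimal smoke-test sketch for the VAE above (not part of the file). The 118x118 input size is taken from the training notebook below, and `from model import VAE` assumes the snippet runs next to embedding_models/vae/model.py with its constants (Z_DIM, hidden dims) available.

import torch
from model import VAE  # the VAE defined in embedding_models/vae/model.py

vae = VAE().eval()
images = torch.rand(4, 3, 118, 118)  # fake RGB batch in [0, 1]; 118x118 matches the notebook's Resize
with torch.no_grad():
    reconstructions, q_dist = vae(images)
print(reconstructions.shape)  # expected: torch.Size([4, 3, 118, 118])
print(q_dist.mean.shape)      # expected: torch.Size([4, Z_DIM]); the mean is used as the stamp embedding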
embedding_models/vae/train.ipynb
ADDED
@@ -0,0 +1,393 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "\n",
    "from pathlib import Path\n",
    "import os\n",
    "from PIL import Image\n",
    "\n",
    "from model import VAE, Encoder\n",
    "from losses import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torchvision import transforms\n",
    "import pandas as pd\n",
    "import re\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_FOLDER = './data/images/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_names = os.listdir(IMAGE_FOLDER)\n",
    "data = pd.DataFrame({'image_name': image_names})\n",
    "data['label'] = data['image_name'].apply(lambda x: int(re.match('^\\d+', x)[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StampDataset(Dataset):\n",
    "    def __init__(self, data, image_folder=Path(IMAGE_FOLDER), transform=None):\n",
    "        super().__init__()\n",
    "        self.image_folder = image_folder\n",
    "        self.data = data\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image = Image.open(self.image_folder / self.data.iloc[idx]['image_name'])\n",
    "        label = self.data.iloc[idx]['label']\n",
    "        if self.transform:\n",
    "            image = self.transform(image)\n",
    "\n",
    "        return image, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, val_data = train_test_split(data, test_size=0.3, shuffle=True, stratify=data['label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_transform = transforms.Compose([\n",
    "    transforms.Resize((118, 118)),\n",
    "    transforms.RandomHorizontalFlip(0.5),\n",
    "    transforms.RandomVerticalFlip(0.5),\n",
    "    transforms.ToTensor(),\n",
    "    # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n",
    "])\n",
    "\n",
    "val_transform = transforms.Compose([\n",
    "    transforms.Resize((118, 118)),\n",
    "    transforms.ToTensor(),\n",
    "    # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n",
    "])\n",
    "train_dataset = StampDataset(train_data, transform=train_transform)\n",
    "val_dataset = StampDataset(val_data, transform=val_transform)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=256)\n",
    "val_loader = DataLoader(val_dataset, shuffle=True, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "from torch import optim\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "\n",
    "from torchvision.utils import make_grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "MEAN = torch.tensor((0.76302232, 0.77820438, 0.81879729)).view(3, 1, 1)\n",
    "STD = torch.tensor((0.16563211, 0.14949341, 0.1055889)).view(3, 1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LitModel(pl.LightningModule):\n",
    "    def __init__(self, alpha=1e-3):\n",
    "        super().__init__()\n",
    "        self.vae = VAE()\n",
    "        self.vae_loss = VAELoss()\n",
    "        self.triplet_loss = BatchHardTripletLoss(margin=1.)\n",
    "        self.alpha = alpha\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.vae(x)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = optim.AdamW(self.parameters(), lr=3e-4)\n",
    "        return optimizer\n",
    "        # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n",
    "        # return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        images, labels = batch\n",
    "        labels = labels.unsqueeze(1)\n",
    "        recon_images, encoding = self.vae(images)\n",
    "        vae_loss = self.vae_loss(recon_images, images, encoding)\n",
    "        self.log(\"train_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
    "        triplet_loss = self.triplet_loss(encoding.mean, labels)\n",
    "        self.log(\"train_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
    "        loss = self.alpha * triplet_loss + vae_loss\n",
    "        self.log(\"train_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        images, labels = batch\n",
    "        labels = labels.unsqueeze(1)\n",
    "        recon_images, encoding = self.vae(images)\n",
    "        vae_loss = self.vae_loss(recon_images, images, encoding)\n",
    "        self.log(\"val_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
    "        triplet_loss = self.triplet_loss(encoding.mean, labels)\n",
    "        self.log(\"val_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
    "        loss = self.alpha * triplet_loss + vae_loss\n",
    "        self.log(\"val_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n",
    "        return loss\n",
    "\n",
    "    def on_validation_epoch_end(self):\n",
    "        images, _ = next(iter(val_loader))\n",
    "        image_unflat = images.detach().cpu()\n",
    "        image_grid = make_grid(image_unflat[:16], nrow=4)\n",
    "        self.logger.experiment.add_image('original images', image_grid, self.current_epoch)\n",
    "\n",
    "        recon_images, _ = self.vae(images.to(self.device))\n",
    "        image_unflat = recon_images.detach().cpu()\n",
    "        image_grid = make_grid(image_unflat[:16], nrow=4)\n",
    "        self.logger.experiment.add_image('reconstructed images', image_grid, self.current_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "litmodel = LitModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger = TensorBoardLogger(\"reconstruction_logs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n",
    "trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import hf_hub_download"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_model = torch.jit.load(hf_hub_download(repo_id=\"stamps-labs/vits8-stamp\", filename=\"vits8stamp-torchscript.pth\")).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data['embed'] = train_data['image_name'].apply(lambda x: emb_model(val_transform(Image.open(Path(IMAGE_FOLDER) / x)).unsqueeze(0).to(device))[0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:1: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
      "  embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n",
      "C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:2: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
      "  labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)\n"
     ]
    }
   ],
   "source": [
    "embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n",
    "labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeds.to_csv('./all_embeds.tsv', sep='\\t', index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels.to_csv('./all_labels.tsv', sep='\\t', index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(litmodel.vae.encode.state_dict(), './models/encoder.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [],
   "source": [
    "im = train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 132,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Encoder()\n",
    "model.load_state_dict(torch.load('./models/encoder.pth'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
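A hedged sketch of reusing the exported encoder weights outside the notebook; the ./models/encoder.pth path and the 118x118 preprocessing mirror the cells above, while the image path is a placeholder.

import torch
from PIL import Image
from torchvision import transforms
from model import Encoder  # embedding_models/vae/model.py

encoder = Encoder()
encoder.load_state_dict(torch.load('./models/encoder.pth', map_location='cpu'))
encoder.eval()

preprocess = transforms.Compose([transforms.Resize((118, 118)), transforms.ToTensor()])
image = preprocess(Image.open('some_stamp.png').convert('RGB')).unsqueeze(0)  # placeholder path
with torch.no_grad():
    mean, stddev = encoder(image)  # the mean vector serves as the stamp embedding
print(mean.shape)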
embedding_models/vits8/__init__.py
ADDED
File without changes
embedding_models/vits8/example.py
ADDED
@@ -0,0 +1,10 @@
from PIL import Image
from model import ViTStamp
def get_embeddings(img_path: str):
    model = ViTStamp()
    image = Image.open(img_path)
    embeddings = model(image=image)
    return embeddings

if __name__ == "__main__":
    print(get_embeddings("oml/data/test/images/99d_15.bmp"))
embedding_models/vits8/model.py
ADDED
@@ -0,0 +1,13 @@
import torch
from torchvision import transforms
from huggingface_hub import hf_hub_download

class ViTStamp():
    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = torch.jit.load(hf_hub_download(repo_id="stamps-labs/vits8-stamp", filename="vits8stamp-torchscript.pth")).to(self.device)
        self.transform = transforms.ToTensor()

    def __call__(self, image) -> torch.Tensor:
        # Keep the input on the same device as the TorchScript model.
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        features = self.model(img_tensor)
        return features
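For context, one way the embeddings returned by ViTStamp could be compared is cosine similarity; a short sketch with placeholder image paths (this pairing logic is an illustration, not part of the pipeline).

import torch.nn.functional as F
from PIL import Image
from model import ViTStamp

model = ViTStamp()
emb_a = model(image=Image.open("stamp_a.png").convert("RGB"))  # placeholder paths
emb_b = model(image=Image.open("stamp_b.png").convert("RGB"))
similarity = F.cosine_similarity(emb_a, emb_b).item()  # higher values suggest the same stamp
print(similarity)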
embedding_models/vits8/oml/__init__.py
ADDED
File without changes
embedding_models/vits8/oml/create_dataset.py
ADDED
@@ -0,0 +1,71 @@
import os
from PIL import Image
import pandas as pd

import argparse

parser = argparse.ArgumentParser("Create a dataset for training with OML",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("--root-data-path", help="Path to images for dataset", default="data/train_val/")
parser.add_argument("--image-data-path", help="Image folder in root data path", default="images/")
parser.add_argument("--train-val-split",
                    help="In which ratio to split data in format train:val (For example 80:20)", default="80:20")
parser.add_argument("--separator",
                    help="What separator is used in image name to separate class name and instance (E.g. circle1_5, separator=_)",
                    default="_")

args = parser.parse_args()
config = vars(args)

root_path = config["root_data_path"]
image_path = config["image_data_path"]
separator = config["separator"]

train_prc, val_prc = tuple(int(num)/100 for num in config["train_val_split"].split(":"))

class_names = set()
for image in os.listdir(root_path+image_path):
    if image.endswith(("png", "jpg", "bmp", "webp")):
        img_name = image.split(".")[0]
        Image.open(root_path+image_path+image).resize((224,224)).save(root_path+image_path+img_name+".png", "PNG")
        if not image.endswith("png"):
            os.remove(root_path+image_path+image)
        img_name = img_name.split(separator)
        class_name = img_name[0]+img_name[1]
        class_names.add(class_name)
    else:
        print("Not all of the images are in a supported format")


# For each class in the set, assign its index in the set as its class label.
class_label_dict = {}
for ind, name in enumerate(class_names):
    class_label_dict[name] = ind

class_count = len(class_names)
train_class_count = int(class_count*train_prc)
print(train_class_count)

df_dict = {"label": [],
           "path": [],
           "split": [],
           "is_query": [],
           "is_gallery": []}
for image in os.listdir(root_path+image_path):
    if image.endswith((".png", ".jpg", ".bmp", ".webp")):
        img_name = image.split(".")[0].split(separator)
        class_name = img_name[0]+img_name[1]
        label = class_label_dict[class_name]
        path = image_path+image
        split = "train" if label <= train_class_count else "validation"
        is_query, is_gallery = (1, 1) if split=="validation" else (None, None)
        df_dict["label"].append(label)
        df_dict["path"].append(path)
        df_dict["split"].append(split)
        df_dict["is_query"].append(is_query)
        df_dict["is_gallery"].append(is_gallery)

df = pd.DataFrame(df_dict)

df.to_csv(root_path+"df_stamps.csv", index=False)
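A hedged usage sketch for the script above; the invocation mirrors the argparse defaults, and the pandas check only illustrates the expected df_stamps.csv layout.

# Intended invocation (defaults written out explicitly):
#   python create_dataset.py --root-data-path data/train_val/ --image-data-path images/ --train-val-split 80:20 --separator _
import pandas as pd

df = pd.read_csv("data/train_val/df_stamps.csv")
print(df["split"].value_counts())              # classes split roughly 80:20 between train and validation
print(df[df["split"] == "validation"].head())  # validation rows carry is_query / is_gallery = 1.0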
embedding_models/vits8/oml/data/test/images/99d_15.bmp
ADDED
embedding_models/vits8/oml/data/test/images/99e_20.bmp
ADDED
embedding_models/vits8/oml/data/test/images/99f_25.bmp
ADDED
embedding_models/vits8/oml/data/test/images/99g_30.bmp
ADDED
embedding_models/vits8/oml/data/test/images/99h_35.bmp
ADDED
embedding_models/vits8/oml/data/test/images/99i_40.bmp
ADDED
embedding_models/vits8/oml/data/train_val/df_stamps.csv
ADDED
@@ -0,0 +1,41 @@
label,path,split,is_query,is_gallery
0,images/circle6_1239.png,train,,
8,images/triangle19_1242.png,train,,
21,images/rectangle11_1248.png,train,,
39,images/triangle10_1232.png,validation,1.0,1.0
33,images/word14_1241.png,validation,1.0,1.0
38,images/word5_1233.png,validation,1.0,1.0
15,images/circle19_1236.png,train,,
22,images/circle15_1244.png,train,,
32,images/circle21_1249.png,train,,
26,images/oval20_1242.png,train,,
6,images/oval5_1237.png,train,,
23,images/word9_1241.png,train,,
9,images/triangle22_1238.png,train,,
31,images/circle12_1239.png,train,,
11,images/word21_1231.png,train,,
4,images/oval2_1235.png,train,,
20,images/rectangle18_1246.png,train,,
12,images/circle24_1234.png,train,,
5,images/circle2_1249.png,train,,
37,images/word22_1238.png,validation,1.0,1.0
34,images/triangle18_1247.png,validation,1.0,1.0
1,images/oval7_1241.png,train,,
10,images/triangle13_1240.png,train,,
14,images/rectangle12_1236.png,train,,
36,images/circle8_1237.png,validation,1.0,1.0
24,images/triangle9_1245.png,train,,
29,images/word23_1243.png,train,,
28,images/triangle11_1244.png,train,,
16,images/circle2_1246.png,train,,
30,images/circle3_1247.png,train,,
18,images/oval24_1248.png,train,,
2,images/oval12_1231.png,train,,
3,images/oval18_1234.png,train,,
25,images/rectangle11_1245.png,train,,
17,images/word9_1244.png,train,,
13,images/triangle14_1237.png,train,,
35,images/circle2_1233.png,validation,1.0,1.0
7,images/word18_1239.png,train,,
19,images/rectangle13_1236.png,train,,
27,images/circle24_1246.png,train,,
embedding_models/vits8/oml/data/train_val/images/circle12_1239.png
ADDED
embedding_models/vits8/oml/data/train_val/images/circle15_1244.png
ADDED
embedding_models/vits8/oml/data/train_val/images/circle19_1236.png
ADDED
embedding_models/vits8/oml/data/train_val/images/circle21_1249.png
ADDED
embedding_models/vits8/oml/data/train_val/images/circle24_1234.png
ADDED
embedding_models/vits8/oml/data/train_val/images/circle24_1246.png
ADDED
embedding_models/vits8/oml/data/train_val/images/circle2_1233.png
ADDED
embedding_models/vits8/oml/data/train_val/images/circle2_1246.png
ADDED
embedding_models/vits8/oml/data/train_val/images/circle2_1249.png
ADDED