sadjava committed
Commit
fd52b7f
1 Parent(s): a5018b2

changed to pipelines

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. __pycache__/constants.cpython-39.pyc +0 -0
  2. __pycache__/models.cpython-39.pyc +0 -0
  3. __pycache__/utils.cpython-39.pyc +0 -0
  4. app.py +55 -60
  5. detection_models/__init__.py +0 -0
  6. detection_models/__pycache__/__init__.cpython-39.pyc +0 -0
  7. detection_models/yolo_stamp/__init__.py +0 -0
  8. detection_models/yolo_stamp/__pycache__/__init__.cpython-39.pyc +0 -0
  9. detection_models/yolo_stamp/__pycache__/constants.cpython-39.pyc +0 -0
  10. detection_models/yolo_stamp/__pycache__/model.cpython-39.pyc +0 -0
  11. detection_models/yolo_stamp/__pycache__/utils.cpython-39.pyc +0 -0
  12. constants.py → detection_models/yolo_stamp/constants.py +0 -8
  13. detection_models/yolo_stamp/data.py +141 -0
  14. detection_models/yolo_stamp/loss.py +52 -0
  15. detection_models/yolo_stamp/model.py +80 -0
  16. detection_models/yolo_stamp/train.ipynb +185 -0
  17. detection_models/yolo_stamp/utils.py +275 -0
  18. detection_models/yolov8/__init__.py +0 -0
  19. detection_models/yolov8/train.ipynb +144 -0
  20. embedding_models/__init__.py +0 -0
  21. embedding_models/__pycache__/__init__.cpython-39.pyc +0 -0
  22. embedding_models/vae/__init__.py +0 -0
  23. embedding_models/vae/__pycache__/__init__.cpython-39.pyc +0 -0
  24. embedding_models/vae/__pycache__/constants.cpython-39.pyc +0 -0
  25. embedding_models/vae/__pycache__/model.cpython-39.pyc +0 -0
  26. embedding_models/vae/constants.py +6 -0
  27. embedding_models/vae/losses.py +77 -0
  28. embedding_models/vae/model.py +147 -0
  29. embedding_models/vae/train.ipynb +393 -0
  30. embedding_models/vits8/__init__.py +0 -0
  31. embedding_models/vits8/example.py +10 -0
  32. embedding_models/vits8/model.py +13 -0
  33. embedding_models/vits8/oml/__init__.py +0 -0
  34. embedding_models/vits8/oml/create_dataset.py +71 -0
  35. embedding_models/vits8/oml/data/test/images/99d_15.bmp +0 -0
  36. embedding_models/vits8/oml/data/test/images/99e_20.bmp +0 -0
  37. embedding_models/vits8/oml/data/test/images/99f_25.bmp +0 -0
  38. embedding_models/vits8/oml/data/test/images/99g_30.bmp +0 -0
  39. embedding_models/vits8/oml/data/test/images/99h_35.bmp +0 -0
  40. embedding_models/vits8/oml/data/test/images/99i_40.bmp +0 -0
  41. embedding_models/vits8/oml/data/train_val/df_stamps.csv +41 -0
  42. embedding_models/vits8/oml/data/train_val/images/circle12_1239.png +0 -0
  43. embedding_models/vits8/oml/data/train_val/images/circle15_1244.png +0 -0
  44. embedding_models/vits8/oml/data/train_val/images/circle19_1236.png +0 -0
  45. embedding_models/vits8/oml/data/train_val/images/circle21_1249.png +0 -0
  46. embedding_models/vits8/oml/data/train_val/images/circle24_1234.png +0 -0
  47. embedding_models/vits8/oml/data/train_val/images/circle24_1246.png +0 -0
  48. embedding_models/vits8/oml/data/train_val/images/circle2_1233.png +0 -0
  49. embedding_models/vits8/oml/data/train_val/images/circle2_1246.png +0 -0
  50. embedding_models/vits8/oml/data/train_val/images/circle2_1249.png +0 -0
__pycache__/constants.cpython-39.pyc ADDED
Binary file (660 Bytes)
__pycache__/models.cpython-39.pyc ADDED
Binary file (5.22 kB)
__pycache__/utils.cpython-39.pyc ADDED
Binary file (7.73 kB)
app.py CHANGED
@@ -1,102 +1,97 @@
  import gradio as gr
- import numpy as np
- from ultralytics import YOLO
- from torchvision.transforms.functional import to_tensor
- from huggingface_hub import hf_hub_download
  import torch
- import albumentations as A
- from albumentations.pytorch.transforms import ToTensorV2
- import pandas as pd
  from sklearn.metrics.pairwise import cosine_similarity

- from utils import *
- from models import YOLOStamp, Encoder
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-
- yolov8 = YOLO(hf_hub_download('stamps-labs/yolov8-finetuned', filename='best.torchscript'), task='detect')

- yolo_stamp = YOLOStamp()
- yolo_stamp.load_state_dict(torch.load(hf_hub_download('stamps-labs/yolo-stamp', filename='state_dict.pth'), map_location='cpu'))
- yolo_stamp = yolo_stamp.to(device)
- yolo_stamp.eval()
- transform = A.Compose([
-         A.Normalize(),
-         ToTensorV2(p=1.0),
-     ])

- vits8 = torch.jit.load(hf_hub_download('stamps-labs/vits8-stamp', filename='vits8stamp-torchscript.pth'), map_location='cpu')
- vits8 = vits8.to(device)
- vits8.eval()

- encoder = Encoder()
- encoder.load_state_dict(torch.load(hf_hub_download('stamps-labs/vae-encoder', filename='encoder.pth'), map_location='cpu'))
- encoder = encoder.to(device)
- encoder.eval()


- def predict(image, det_choice, emb_choice):
-     shape = torch.tensor(image.size)
      image = image.convert('RGB')

      if det_choice == 'yolov8':
-         coef = torch.hstack((shape, shape)) / 640
-         image = image.resize((640, 640))
-         boxes = yolov8(image)[0].boxes.xyxy.cpu()
-         image_with_boxes = visualize_bbox(image, boxes)

      elif det_choice == 'yolo-stamp':
-         coef = torch.hstack((shape, shape)) / 448
-         image = image.resize((448, 448))
-         image_tensor = transform(image=np.array(image))['image']
-         output = yolo_stamp(image_tensor.unsqueeze(0).to(device))
-
-         boxes = output_tensor_to_boxes(output[0].detach().cpu())
-         boxes = nonmax_suppression(boxes)
-         boxes = xywh2xyxy(torch.tensor(boxes)[:, :4])
-         image_with_boxes = visualize_bbox(image, boxes)
      else:
          return
-

      embeddings = []
      if emb_choice == 'vits8':
-         for box in boxes:
-             cropped_stamp = to_tensor(image.crop(box.tolist()))
-             embeddings.append(vits8(cropped_stamp.unsqueeze(0).to(device))[0].detach().cpu())

      elif emb_choice == 'vae-encoder':
-         for box in boxes:
-             cropped_stamp = to_tensor(image.crop(box.tolist()).resize((118, 118)))
-             embeddings.append(np.array(encoder(cropped_stamp.unsqueeze(0).to(device))[0][0].detach().cpu()))

      embeddings = np.stack(embeddings)

      similarities = cosine_similarity(embeddings)

-     boxes = boxes * coef
      df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])

      fig, ax = plt.subplots()
      im, cbar = heatmap(similarities, range(1, len(embeddings) + 1), range(1, len(embeddings) + 1), ax=ax,
                         cmap="YlGn", cbarlabel="Embeddings similarities")
      texts = annotate_heatmap(im, valfmt="{x:.3f}")
-     return image_with_boxes, df_boxes, embeddings, fig


- examples = [['./examples/1.jpg', 'yolov8', 'vits8'], ['./examples/2.jpg', 'yolov8', 'vae-encoder'], ['./examples/3.jpg', 'yolo-stamp', 'vits8']]
- inputs = [
-     gr.Image(type="pil"),
      gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model'),
      gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model'),
  ]
- outputs = [
-     gr.Image(type="pil"),
      gr.DataFrame(type='pandas', label="Bounding boxes"),
      gr.DataFrame(type='numpy', label="Embeddings"),
      gr.Plot(label="Cosine Similarities")
  ]
- app = gr.Interface(predict, inputs, outputs, examples=examples)
- app.launch()
  import gradio as gr
  import torch
+ import numpy as np
  from sklearn.metrics.pairwise import cosine_similarity
+ import pandas as pd
+ from PIL import Image, ImageDraw
+ import matplotlib.pyplot as plt
+ import matplotlib

+ from pipelines.detection.yolo_v8 import Yolov8Pipeline
+ from pipelines.detection.yolo_stamp import YoloStampPipeline
+ from pipelines.segmentation.deeplabv3 import DeepLabv3Pipeline
+ from pipelines.feature_extraction.vae import VaePipeline
+ from pipelines.feature_extraction.vits8 import Vits8Pipeline

+ from utils import *


+ yolov8 = Yolov8Pipeline.from_pretrained(local_model_path='yolov8_old_backup.pt')
+ yolo_stamp = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', 'weights.pt')
+ vae = VaePipeline.from_pretrained('stamps-labs/vae-encoder', 'weights.pt')
+ vits8 = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', 'weights.pt')
+ dlv3 = DeepLabv3Pipeline.from_pretrained('stamps-labs/deeplabv3-finetuned', 'weights.pt')


+ def doc_predict(image, det_choice, seg_choice, emb_choice):
      image = image.convert('RGB')

      if det_choice == 'yolov8':
+         boxes = yolov8(image)

      elif det_choice == 'yolo-stamp':
+         boxes = yolo_stamp(image)
      else:
          return
+     image_with_boxes = visualize_bbox(image, boxes)
+
+     segmented_stamps = []
+     for box in boxes:
+         cropped_stamp = image.crop(box.tolist())
+         segmented_stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp)
+
+     widths, heights = zip(*(i.size for i in segmented_stamps))
+
+     total_width = sum(widths)
+     max_height = max(heights)
+
+     concatenated_stamps = Image.new('RGB', (total_width, max_height))
+
+     x_offset = 0
+     for im in segmented_stamps:
+         concatenated_stamps.paste(im, (x_offset, 0))
+         x_offset += im.size[0]

      embeddings = []
      if emb_choice == 'vits8':
+         for stamp in segmented_stamps:
+             embeddings.append(vits8(stamp))

      elif emb_choice == 'vae-encoder':
+         for stamp in segmented_stamps:
+             embeddings.append(vae(stamp))

      embeddings = np.stack(embeddings)

      similarities = cosine_similarity(embeddings)

      df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])

      fig, ax = plt.subplots()
      im, cbar = heatmap(similarities, range(1, len(embeddings) + 1), range(1, len(embeddings) + 1), ax=ax,
                         cmap="YlGn", cbarlabel="Embeddings similarities")
      texts = annotate_heatmap(im, valfmt="{x:.3f}")
+     return image_with_boxes, df_boxes, concatenated_stamps, embeddings, fig


+ doc_examples = [['examples/1.jpg', 'yolov8', True, 'vits8'], ['examples/2.jpg', 'yolo-stamp', False, 'vae-encoder'], ['examples/3.jpg', 'yolov8', True, 'vits8']]
+ doc_inputs = [
+     gr.Image(label="Document image", type="pil"),
      gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model'),
+     gr.Checkbox(label="Use segmentation model"),
      gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model'),
  ]
+ doc_outputs = [
+     gr.Image(label="Document with bounding boxes", type="pil"),
      gr.DataFrame(type='pandas', label="Bounding boxes"),
+     gr.Image(label="Segmented stamps", type="pil"),
      gr.DataFrame(type='numpy', label="Embeddings"),
      gr.Plot(label="Cosine Similarities")
  ]
+
+ with gr.Blocks() as demo:
+     with gr.Tab("Single document"):
+         gr.Interface(doc_predict, doc_inputs, doc_outputs, examples=doc_examples)
+
+ demo.launch(inline=False)
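For orientation, a minimal sketch of how the new pipeline classes are meant to be used outside Gradio, based only on the calls visible in app.py above (the example image path is hypothetical; the repo ids and 'weights.pt' filenames are the ones app.py passes to from_pretrained):

from PIL import Image
from pipelines.detection.yolo_stamp import YoloStampPipeline
from pipelines.feature_extraction.vits8 import Vits8Pipeline

# Load pipelines from the Hub, exactly as app.py does above.
detector = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', 'weights.pt')
embedder = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', 'weights.pt')

image = Image.open('examples/1.jpg').convert('RGB')   # hypothetical document scan
boxes = detector(image)                               # xyxy boxes, as doc_predict assumes
crops = [image.crop(box.tolist()) for box in boxes]   # one crop per detected stamp
embeddings = [embedder(crop) for crop in crops]       # one embedding vector per stamp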
detection_models/__init__.py ADDED
File without changes
detection_models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (165 Bytes)
detection_models/yolo_stamp/__init__.py ADDED
File without changes
detection_models/yolo_stamp/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (176 Bytes)
detection_models/yolo_stamp/__pycache__/constants.cpython-39.pyc ADDED
Binary file (611 Bytes)
detection_models/yolo_stamp/__pycache__/model.cpython-39.pyc ADDED
Binary file (3.07 kB)
detection_models/yolo_stamp/__pycache__/utils.cpython-39.pyc ADDED
Binary file (9.31 kB)
constants.py → detection_models/yolo_stamp/constants.py RENAMED
@@ -23,11 +23,3 @@ STD = (0.229, 0.224, 0.225)
  MEAN = (0.485, 0.456, 0.406)
  # box color to show the bounding box on image
  BOX_COLOR = (0, 0, 255)
-
-
- # dimenstion of image embedding
- Z_DIM = 128
- # hidden dimensions for encoder model
- ENC_HIDDEN_DIM = 16
- # hidden dimensions for decoder model
- DEC_HIDDEN_DIM = 64
detection_models/yolo_stamp/data.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import numpy as np
4
+ from sklearn.model_selection import train_test_split
5
+ import albumentations as A
6
+ from albumentations.pytorch.transforms import ToTensorV2
7
+ from PIL import Image
8
+
9
+ from pathlib import Path
10
+ from random import randint
11
+
12
+ from utils import *
13
+
14
+ """
15
+ Dataset class for storing stamps data.
16
+
17
+ Arguments:
18
+ data -- list of dictionaries containing file_path (path to the image), box_nb (number of boxes on the image), and boxes of shape (4,)
19
+ image_folder -- path to folder containing images
20
+ transforms -- transforms from albumentations package
21
+ """
22
+ class StampDataset(Dataset):
23
+ def __init__(
24
+ self,
25
+ data=read_data(),
26
+ image_folder=Path(IMAGE_FOLDER),
27
+ transforms=None):
28
+ self.data = data
29
+ self.image_folder = image_folder
30
+ self.transforms = transforms
31
+
32
+ def __getitem__(self, idx):
33
+ item = self.data[idx]
34
+ image_fn = self.image_folder / item['file_path']
35
+ boxes = item['boxes']
36
+ box_nb = item['box_nb']
37
+ labels = torch.zeros((box_nb, 2), dtype=torch.int64)
38
+ labels[:, 0] = 1
39
+
40
+ img = np.array(Image.open(image_fn))
41
+
42
+ try:
43
+ if self.transforms:
44
+ sample = self.transforms(**{
45
+ "image":img,
46
+ "bboxes": boxes,
47
+ "labels": labels,
48
+ })
49
+ img = sample['image']
50
+ boxes = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)
51
+ except:
52
+ return self.__getitem__(randint(0, len(self.data)-1))
53
+
54
+ target_tensor = boxes_to_tensor(boxes.type(torch.float32))
55
+ return img, target_tensor
56
+
57
+ def __len__(self):
58
+ return len(self.data)
59
+
60
+ def collate_fn(batch):
61
+ return tuple(zip(*batch))
62
+
63
+
64
+ def get_datasets(data_path=ANNOTATIONS_PATH, train_transforms=None, val_transforms=None):
65
+ """
66
+ Creates StampDataset objects.
67
+
68
+ Arguments:
69
+ data_path -- string or Path, specifying path to annotations file
70
+ train_transforms -- transforms to be applied during training
71
+ val_transforms -- transforms to be applied during validation
72
+
73
+ Returns:
74
+ (train_dataset, val_dataset) -- tuple of StampDataset for training and validation
75
+ """
76
+ data = read_data(data_path)
77
+ if train_transforms is None:
78
+ train_transforms = A.Compose([
79
+ A.RandomCropNearBBox(max_part_shift=0.6, p=0.4),
80
+ A.Resize(height=448, width=448),
81
+ A.HorizontalFlip(p=0.5),
82
+ A.VerticalFlip(p=0.5),
83
+ # A.Affine(scale=(0.9, 1.1), translate_percent=(0.05, 0.1), rotate=(-45, 45), shear=(-30, 30), p=0.3),
84
+ # A.Blur(blur_limit=4, p=0.3),
85
+ A.Normalize(),
86
+ ToTensorV2(p=1.0),
87
+ ],
88
+ bbox_params={
89
+ "format":"coco",
90
+ 'label_fields': ['labels']
91
+ })
92
+
93
+ if val_transforms is None:
94
+ val_transforms = A.Compose([
95
+ A.Resize(height=448, width=448),
96
+ A.Normalize(),
97
+ ToTensorV2(p=1.0),
98
+ ],
99
+ bbox_params={
100
+ "format":"coco",
101
+ 'label_fields': ['labels']
102
+ })
103
+ train, test_data = train_test_split(data, test_size=0.1, shuffle=True)
104
+
105
+ train_data, val_data = train_test_split(train, test_size=0.2, shuffle=True)
106
+
107
+ train_dataset = StampDataset(train_data, transforms=train_transforms)
108
+ val_dataset = StampDataset(val_data, transforms=val_transforms)
109
+ test_dataset = StampDataset(test_data, transforms=val_transforms)
110
+
111
+ return train_dataset, val_dataset, test_dataset
112
+
113
+
114
+ def get_loaders(batch_size=8, data_path=ANNOTATIONS_PATH, num_workers=0, train_transforms=None, val_transforms=None):
115
+ """
116
+ Creates train and validation DataLoader objects.
117
+
118
+ Arguments:
119
+ batch_size -- integer specifying the number of images in the batch
120
+ data_path -- string or Path, specifying path to annotations file
121
+ train_transforms -- transforms to be applied during training
122
+ val_transforms -- transforms to be applied during validation
123
+
124
+ Returns:
125
+ (train_loader, val_loader) -- tuple of DataLoader for training and validation
126
+ """
127
+ train_dataset, val_dataset, _ = get_datasets(data_path)
128
+
129
+ train_loader = DataLoader(
130
+ train_dataset,
131
+ batch_size=batch_size,
132
+ shuffle=True,
133
+ num_workers=num_workers,
134
+ collate_fn=collate_fn, drop_last=True)
135
+
136
+ val_loader = DataLoader(
137
+ val_dataset,
138
+ batch_size=batch_size,
139
+ collate_fn=collate_fn)
140
+
141
+ return train_loader, val_loader
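A quick sketch of how these pieces fit together, assuming the annotations CSV and image folder configured in constants.py exist and the repository root is on the Python path (the batch size and printed shapes are illustrative only):

from detection_models.yolo_stamp.data import get_loaders

# The default albumentations transforms defined above are applied automatically.
train_loader, val_loader = get_loaders(batch_size=8)

images, targets = next(iter(train_loader))  # tuples of tensors, because of collate_fn
print(len(images), targets[0].shape)        # e.g. 8 and (S, S, BOX, 5)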
detection_models/yolo_stamp/loss.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from utils import *
4
+
5
+ """
6
+ Class for loss for training YOLO model.
7
+
8
+ Arguments:
9
+ h_coord: weight for loss related to coordinates and shapes of box
10
+ h_noobj: weight for loss of predicting presence of box when it is absent.
11
+ """
12
+ class YOLOLoss(nn.Module):
13
+ def __init__(self, h_coord=0.5, h_noobj=2., h_shape=2., h_obj=10.):
14
+ super().__init__()
15
+ self.h_coord = h_coord
16
+ self.h_noobj = h_noobj
17
+ self.h_shape = h_shape
18
+ self.h_obj = h_obj
19
+
20
+ def square_error(self, output, target):
21
+ return (output - target) ** 2
22
+
23
+ def forward(self, output, target):
24
+
25
+ pred_xy, pred_wh, pred_obj = yolo_head(output)
26
+ gt_xy, gt_wh, gt_obj = process_target(target)
27
+
28
+ pred_ul = pred_xy - 0.5 * pred_wh
29
+ pred_br = pred_xy + 0.5 * pred_wh
30
+ pred_area = pred_wh[..., 0] * pred_wh[..., 1]
31
+
32
+ gt_ul = gt_xy - 0.5 * gt_wh
33
+ gt_br = gt_xy + 0.5 * gt_wh
34
+ gt_area = gt_wh[..., 0] * gt_wh[..., 1]
35
+
36
+ intersect_ul = torch.max(pred_ul, gt_ul)
37
+ intersect_br = torch.min(pred_br, gt_br)
38
+ intersect_wh = intersect_br - intersect_ul
39
+ intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
40
+
41
+ iou = intersect_area / (pred_area + gt_area - intersect_area)
42
+ max_iou = torch.max(iou, dim=3, keepdim=True)[0]
43
+ best_box_index = torch.unsqueeze(torch.eq(iou, max_iou).float(), dim=-1)
44
+ gt_box_conf = best_box_index * gt_obj
45
+
46
+ xy_loss = (self.square_error(pred_xy, gt_xy) * gt_box_conf).sum()
47
+ wh_loss = (self.square_error(pred_wh, gt_wh) * gt_box_conf).sum()
48
+ obj_loss = (self.square_error(pred_obj, gt_obj) * gt_box_conf).sum()
49
+ noobj_loss = (self.square_error(pred_obj, gt_obj) * (1 - gt_box_conf)).sum()
50
+
51
+ total_loss = self.h_coord * xy_loss + self.h_shape * wh_loss + self.h_obj * obj_loss + self.h_noobj * noobj_loss
52
+ return total_loss
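A minimal sketch of exercising YOLOLoss on random tensors; the (batch, S, S, BOX, 5) layout follows yolo_head/process_target above, while the concrete S and BOX values live in constants.py and are only assumed here:

import torch
from detection_models.yolo_stamp.loss import YOLOLoss

S, BOX = 7, 3                          # assumed grid size and anchors per cell
criterion = YOLOLoss()
output = torch.randn(2, S, S, BOX, 5)  # raw network output
target = torch.rand(2, S, S, BOX, 5)   # encoded ground truth (see boxes_to_tensor in utils.py)
print(criterion(output, target).item())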
detection_models/yolo_stamp/model.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .constants import *
5
+
6
+ """
7
+ Class for custom activation.
8
+ """
9
+ class SymReLU(nn.Module):
10
+ def __init__(self, inplace: bool = False):
11
+ super().__init__()
12
+ self.inplace = inplace
13
+
14
+ def forward(self, input):
15
+ return torch.min(torch.max(input, -torch.ones_like(input)), torch.ones_like(input))
16
+
17
+ def extra_repr(self) -> str:
18
+ inplace_str = 'inplace=True' if self.inplace else ''
19
+ return inplace_str
20
+
21
+
22
+ """
23
+ Class implementing YOLO-Stamp architecture described in https://link.springer.com/article/10.1134/S1054661822040046.
24
+ """
25
+ class YOLOStamp(nn.Module):
26
+ def __init__(
27
+ self,
28
+ anchors=ANCHORS,
29
+ in_channels=3,
30
+ ):
31
+ super().__init__()
32
+
33
+ self.register_buffer('anchors', torch.tensor(anchors))
34
+
35
+ self.act = SymReLU()
36
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
37
+ self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
38
+ self.norm1 = nn.BatchNorm2d(num_features=8)
39
+ self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
40
+ self.norm2 = nn.BatchNorm2d(num_features=16)
41
+ self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
42
+ self.norm3 = nn.BatchNorm2d(num_features=16)
43
+ self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
44
+ self.norm4 = nn.BatchNorm2d(num_features=16)
45
+ self.conv5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
46
+ self.norm5 = nn.BatchNorm2d(num_features=16)
47
+ self.conv6 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
48
+ self.norm6 = nn.BatchNorm2d(num_features=24)
49
+ self.conv7 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
50
+ self.norm7 = nn.BatchNorm2d(num_features=24)
51
+ self.conv8 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
52
+ self.norm8 = nn.BatchNorm2d(num_features=48)
53
+ self.conv9 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
54
+ self.norm9 = nn.BatchNorm2d(num_features=48)
55
+ self.conv10 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
56
+ self.norm10 = nn.BatchNorm2d(num_features=48)
57
+ self.conv11 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
58
+ self.norm11 = nn.BatchNorm2d(num_features=64)
59
+ self.conv12 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
60
+ self.norm12 = nn.BatchNorm2d(num_features=256)
61
+ self.conv13 = nn.Conv2d(in_channels=256, out_channels=len(anchors) * 5, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
62
+
63
+ def forward(self, x, head=True):
64
+ x = x.type(self.conv1.weight.dtype)
65
+ x = self.act(self.pool(self.norm1(self.conv1(x))))
66
+ x = self.act(self.pool(self.norm2(self.conv2(x))))
67
+ x = self.act(self.pool(self.norm3(self.conv3(x))))
68
+ x = self.act(self.pool(self.norm4(self.conv4(x))))
69
+ x = self.act(self.pool(self.norm5(self.conv5(x))))
70
+ x = self.act(self.norm6(self.conv6(x)))
71
+ x = self.act(self.norm7(self.conv7(x)))
72
+ x = self.act(self.pool(self.norm8(self.conv8(x))))
73
+ x = self.act(self.norm9(self.conv9(x)))
74
+ x = self.act(self.norm10(self.conv10(x)))
75
+ x = self.act(self.norm11(self.conv11(x)))
76
+ x = self.act(self.norm12(self.conv12(x)))
77
+ x = self.conv13(x)
78
+ nb, _, nh, nw= x.shape
79
+ x = x.permute(0, 2, 3, 1).view(nb, nh, nw, self.anchors.shape[0], 5)
80
+ return x
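A small shape check for YOLOStamp: with the 448x448 inputs used in data.py and the six pooling stages above, the output grid works out to 7x7, with one 5-vector (x, y, w, h, objectness) per anchor. A sketch, assuming the repository root is importable:

import torch
from detection_models.yolo_stamp.model import YOLOStamp

model = YOLOStamp().eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 448, 448))
print(out.shape)  # (1, 7, 7, len(ANCHORS), 5)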
detection_models/yolo_stamp/train.ipynb ADDED
@@ -0,0 +1,185 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from model import *\n",
10
+ "from loss import *\n",
11
+ "from data import *\n",
12
+ "from torch import optim\n",
13
+ "from tqdm import tqdm\n",
14
+ "\n",
15
+ "import pytorch_lightning as pl\n",
16
+ "from torchmetrics.detection import MeanAveragePrecision\n",
17
+ "from pytorch_lightning.loggers import TensorBoardLogger"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": 2,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "_, _, test_dataset = get_datasets()"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 3,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "class LitModel(pl.LightningModule):\n",
36
+ " def __init__(self):\n",
37
+ " super().__init__()\n",
38
+ " self.model = YOLOStamp()\n",
39
+ " self.criterion = YOLOLoss()\n",
40
+ " self.val_map = MeanAveragePrecision(box_format='xywh', iou_type='bbox')\n",
41
+ " \n",
42
+ " def forward(self, x):\n",
43
+ " return self.model(x)\n",
44
+ "\n",
45
+ " def configure_optimizers(self):\n",
46
+ " optimizer = optim.AdamW(self.parameters(), lr=1e-3)\n",
47
+ " # return optimizer\n",
48
+ " scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n",
49
+ " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
50
+ "\n",
51
+ " def training_step(self, batch, batch_idx):\n",
52
+ " images, targets = batch\n",
53
+ " tensor_images = torch.stack(images)\n",
54
+ " tensor_targets = torch.stack(targets)\n",
55
+ " output = self.model(tensor_images)\n",
56
+ " loss = self.criterion(output, tensor_targets)\n",
57
+ " self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n",
58
+ " return loss\n",
59
+ "\n",
60
+ " def validation_step(self, batch, batch_idx):\n",
61
+ " images, targets = batch\n",
62
+ " tensor_images = torch.stack(images)\n",
63
+ " tensor_targets = torch.stack(targets)\n",
64
+ " output = self.model(tensor_images)\n",
65
+ " loss = self.criterion(output, tensor_targets)\n",
66
+ " self.log(\"val_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n",
67
+ "\n",
68
+ " for i in range(len(images)):\n",
69
+ " boxes = output_tensor_to_boxes(output[i].detach().cpu())\n",
70
+ " boxes = nonmax_suppression(boxes)\n",
71
+ " target = target_tensor_to_boxes(targets[i])[::BOX]\n",
72
+ " if not boxes:\n",
73
+ " boxes = torch.zeros((1, 5))\n",
74
+ " preds = [\n",
75
+ " dict(\n",
76
+ " boxes=torch.tensor(boxes)[:, :4].clone().detach(),\n",
77
+ " scores=torch.tensor(boxes)[:, 4].clone().detach(),\n",
78
+ " labels=torch.zeros(len(boxes)),\n",
79
+ " )\n",
80
+ " ]\n",
81
+ " target = [\n",
82
+ " dict(\n",
83
+ " boxes=torch.tensor(target),\n",
84
+ " labels=torch.zeros(len(target)),\n",
85
+ " )\n",
86
+ " ]\n",
87
+ " self.val_map.update(preds, target)\n",
88
+ " \n",
89
+ " def on_validation_epoch_end(self):\n",
90
+ " mAPs = {\"val_\" + k: v for k, v in self.val_map.compute().items()}\n",
91
+ " mAPs_per_class = mAPs.pop(\"val_map_per_class\")\n",
92
+ " mARs_per_class = mAPs.pop(\"val_mar_100_per_class\")\n",
93
+ " self.log_dict(mAPs)\n",
94
+ " self.val_map.reset()\n",
95
+ "\n",
96
+ " image = test_dataset[randint(0, len(test_dataset) - 1)][0].to(self.device)\n",
97
+ " output = self.model(image.unsqueeze(0))\n",
98
+ " boxes = output_tensor_to_boxes(output[0].detach().cpu())\n",
99
+ " boxes = nonmax_suppression(boxes)\n",
100
+ " img = image.permute(1, 2, 0).cpu().numpy()\n",
101
+ " img = visualize_bbox(img.copy(), boxes=boxes)\n",
102
+ " img = (255. * (img * np.array(STD) + np.array(MEAN))).astype(np.uint8)\n",
103
+ " \n",
104
+ " self.logger.experiment.add_image(\"detected boxes\", torch.tensor(img).permute(2, 0, 1), self.current_epoch)\n"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 4,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "litmodel = LitModel()"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": 5,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "logger = TensorBoardLogger(\"detection_logs\")"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": 7,
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "epochs = 100"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": 8,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "train_loader, val_loader = get_loaders(batch_size=8)"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n",
150
+ "trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "metadata": {},
157
+ "outputs": [],
158
+ "source": [
159
+ "%tensorboard"
160
+ ]
161
+ }
162
+ ],
163
+ "metadata": {
164
+ "kernelspec": {
165
+ "display_name": "Python 3",
166
+ "language": "python",
167
+ "name": "python3"
168
+ },
169
+ "language_info": {
170
+ "codemirror_mode": {
171
+ "name": "ipython",
172
+ "version": 3
173
+ },
174
+ "file_extension": ".py",
175
+ "mimetype": "text/x-python",
176
+ "name": "python",
177
+ "nbconvert_exporter": "python",
178
+ "pygments_lexer": "ipython3",
179
+ "version": "3.9.0"
180
+ },
181
+ "orig_nbformat": 4
182
+ },
183
+ "nbformat": 4,
184
+ "nbformat_minor": 2
185
+ }
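After training, the LightningModule wraps a plain YOLOStamp, so exporting weights for the app is just a state_dict save. A sketch; the filename is hypothetical (the old app.py loaded a file called state_dict.pth from the Hub):

import torch
from detection_models.yolo_stamp.model import YOLOStamp

# litmodel is the trained LitModel from the cells above.
torch.save(litmodel.model.state_dict(), 'yolo_stamp_state_dict.pth')

# Later, for inference:
model = YOLOStamp()
model.load_state_dict(torch.load('yolo_stamp_state_dict.pth', map_location='cpu'))
model.eval()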
detection_models/yolo_stamp/utils.py ADDED
@@ -0,0 +1,275 @@
1
+ import torch
2
+ import cv2
3
+ import pandas as pd
4
+ import numpy as np
5
+ from pathlib import Path
6
+ import matplotlib.pyplot as plt
7
+ from .constants import *
8
+
9
+
10
+ def output_tensor_to_boxes(boxes_tensor):
11
+ """
12
+ Converts the YOLO output tensor to a list of boxes with probabilities.
13
+
14
+ Arguments:
15
+ boxes_tensor -- tensor of shape (S, S, BOX, 5)
16
+
17
+ Returns:
18
+ boxes -- list of shape (None, 5)
19
+
20
+ Note: "None" is here because you don't know the exact number of selected boxes, as it depends on the threshold.
21
+ For example, the actual output size of scores would be (10, 5) if there are 10 boxes
22
+ """
23
+ cell_w, cell_h = W/S, H/S
24
+ boxes = []
25
+
26
+ for i in range(S):
27
+ for j in range(S):
28
+ for b in range(BOX):
29
+ anchor_wh = torch.tensor(ANCHORS[b])
30
+ data = boxes_tensor[i,j,b]
31
+ xy = torch.sigmoid(data[:2])
32
+ wh = torch.exp(data[2:4])*anchor_wh
33
+ obj_prob = torch.sigmoid(data[4])
34
+
35
+ if obj_prob > OUTPUT_THRESH:
36
+ x_center, y_center, w, h = xy[0], xy[1], wh[0], wh[1]
37
+ x, y = x_center+j-w/2, y_center+i-h/2
38
+ x,y,w,h = x*cell_w, y*cell_h, w*cell_w, h*cell_h
39
+ box = [x,y,w,h, obj_prob]
40
+ boxes.append(box)
41
+ return boxes
42
+
43
+
44
+ def plot_img(img, size=(7,7)):
45
+ plt.figure(figsize=size)
46
+ plt.imshow(img)
47
+ plt.show()
48
+
49
+
50
+ def plot_normalized_img(img, std=STD, mean=MEAN, size=(7,7)):
51
+ mean = mean if isinstance(mean, np.ndarray) else np.array(mean)
52
+ std = std if isinstance(std, np.ndarray) else np.array(std)
53
+ plt.figure(figsize=size)
54
+ plt.imshow((255. * (img * std + mean)).astype(np.uint))
55
+ plt.show()
56
+
57
+
58
+ def visualize_bbox(img, boxes, thickness=2, color=BOX_COLOR, draw_center=True):
59
+ """
60
+ Draws boxes on the given image.
61
+
62
+ Arguments:
63
+ img -- torch.Tensor of shape (3, W, H) or numpy.ndarray of shape (W, H, 3)
64
+ boxes -- list of shape (None, 5)
65
+ thickness -- number specifying the thickness of box border
66
+ color -- RGB tuple of shape (3,) specifying the color of boxes
67
+ draw_center -- boolean specifying whether to draw center or not
68
+
69
+ Returns:
70
+ img_copy -- numpy.ndarray of shape (W, H, 3) containing image with bounding boxes
71
+ """
72
+ img_copy = img.cpu().permute(1,2,0).numpy() if isinstance(img, torch.Tensor) else img.copy()
73
+ for box in boxes:
74
+ x,y,w,h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
75
+ img_copy = cv2.rectangle(
76
+ img_copy,
77
+ (x,y),(x+w, y+h),
78
+ color, thickness)
79
+ if draw_center:
80
+ center = (x+w//2, y+h//2)
81
+ img_copy = cv2.circle(img_copy, center=center, radius=3, color=(0,255,0), thickness=2)
82
+ return img_copy
83
+
84
+
85
+ def read_data(annotations=Path(ANNOTATIONS_PATH)):
86
+ """
87
+ Reads annotations data from .csv file. Must contain columns: image_name, bbox_x, bbox_y, bbox_width, bbox_height.
88
+
89
+ Arguments:
90
+ annotations -- string or Path specifying path of annotations file
91
+
92
+ Returns:
93
+ data -- list of dictionaries containing path, number of boxes and boxes itself
94
+ """
95
+ data = []
96
+
97
+ boxes = pd.read_csv(annotations)
98
+ image_names = boxes['image_name'].unique()
99
+
100
+ for image_name in image_names:
101
+ cur_boxes = boxes[boxes['image_name'] == image_name]
102
+ img_data = {
103
+ 'file_path': image_name,
104
+ 'box_nb': len(cur_boxes),
105
+ 'boxes': []}
106
+ stamp_nb = img_data['box_nb']
107
+ if stamp_nb <= STAMP_NB_MAX:
108
+ img_data['boxes'] = cur_boxes[['bbox_x', 'bbox_y','bbox_width','bbox_height']].values
109
+ data.append(img_data)
110
+ return data
111
+
112
+ def xywh2xyxy(x):
113
+ """
114
+ Converts xywh format to xyxy
115
+
116
+ Arguments:
117
+ x -- torch.Tensor or np.array (xywh format)
118
+
119
+ Returns:
120
+ y -- torch.Tensor or np.array (xyxy)
121
+ """
122
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
123
+ y[..., 0] = x[..., 0]
124
+ y[..., 1] = x[..., 1]
125
+ y[..., 2] = x[..., 0] + x[..., 2]
126
+ y[..., 3] = x[..., 1] + x[..., 3]
127
+ return y
128
+
129
+ def boxes_to_tensor(boxes):
130
+ """
131
+ Convert list of boxes (and labels) to tensor format
132
+
133
+ Arguments:
134
+ boxes -- list of boxes
135
+
136
+ Returns:
137
+ boxes_tensor -- tensor of shape (S, S, BOX, 5)
138
+ """
139
+ boxes_tensor = torch.zeros((S, S, BOX, 5))
140
+ cell_w, cell_h = W/S, H/S
141
+ for i, box in enumerate(boxes):
142
+ x, y, w, h = box
143
+ # normalize xywh with cell_size
144
+ x, y, w, h = x / cell_w, y / cell_h, w / cell_w, h / cell_h
145
+ center_x, center_y = x + w / 2, y + h / 2
146
+ grid_x = int(np.floor(center_x))
147
+ grid_y = int(np.floor(center_y))
148
+
149
+ if grid_x < S and grid_y < S:
150
+ boxes_tensor[grid_y, grid_x, :, 0:4] = torch.tensor(BOX * [[center_x - grid_x, center_y - grid_y, w, h]])
151
+ boxes_tensor[grid_y, grid_x, :, 4] = torch.tensor(BOX * [1.])
152
+ return boxes_tensor
153
+
154
+
155
+ def target_tensor_to_boxes(boxes_tensor, output_threshold=OUTPUT_THRESH):
156
+ """
157
+ Recover target tensor (tensor output of dataset) to bboxes.
158
+ Arguments:
159
+ boxes_tensor -- tensor of shape (S, S, BOX, 5)
160
+ Returns:
161
+ boxes -- list of boxes, each box is [x, y, w, h]
162
+ """
163
+ cell_w, cell_h = W/S, H/S
164
+ boxes = []
165
+ for i in range(S):
166
+ for j in range(S):
167
+ for b in range(BOX):
168
+ data = boxes_tensor[i,j,b]
169
+ x_center,y_center, w, h, obj_prob = data[0], data[1], data[2], data[3], data[4]
170
+ if obj_prob > output_threshold:
171
+ x, y = x_center+j-w/2, y_center+i-h/2
172
+ x,y,w,h = x*cell_w, y*cell_h, w*cell_w, h*cell_h
173
+ box = [x,y,w,h]
174
+ boxes.append(box)
175
+ return boxes
176
+
177
+
178
+ def overlap(interval_1, interval_2):
179
+ """
180
+ Calculates length of overlap between two intervals.
181
+
182
+ Arguments:
183
+ interval_1 -- list or tuple of shape (2,) containing endpoints of the first interval
184
+ interval_2 -- list or tuple of shape (2,) containing endpoints of the second interval
185
+
186
+ Returns:
187
+ overlap -- length of overlap
188
+ """
189
+ x1, x2 = interval_1
190
+ x3, x4 = interval_2
191
+ if x3 < x1:
192
+ if x4 < x1:
193
+ return 0
194
+ else:
195
+ return min(x2,x4) - x1
196
+ else:
197
+ if x2 < x3:
198
+ return 0
199
+ else:
200
+ return min(x2,x4) - x3
201
+
202
+
203
+ def compute_iou(box1, box2):
204
+ """
205
+ Compute IOU between box1 and box2.
206
+
207
+ Arguments:
208
+ box1 -- list of shape (5, ). Represents the first box
209
+ box2 -- list of shape (5, ). Represents the second box
210
+ Each box is [x, y, w, h, prob]
211
+
212
+ Returns:
213
+ iou -- intersection over union score between two boxes
214
+ """
215
+ x1,y1,w1,h1 = box1[0], box1[1], box1[2], box1[3]
216
+ x2,y2,w2,h2 = box2[0], box2[1], box2[2], box2[3]
217
+
218
+ area1, area2 = w1*h1, w2*h2
219
+ intersect_w = overlap((x1,x1+w1), (x2,x2+w2))
220
+ intersect_h = overlap((y1,y1+h1), (y2,y2+h2))
221
+ if intersect_w == w1 and intersect_h == h1 or intersect_w == w2 and intersect_h == h2:
222
+ return 1.
223
+ intersect_area = intersect_w*intersect_h
224
+ iou = intersect_area/(area1 + area2 - intersect_area)
225
+ return iou
226
+
227
+
228
+ def nonmax_suppression(boxes, iou_thresh = IOU_THRESH):
229
+ """
230
+ Removes overlapping bboxes
231
+
232
+ Arguments:
233
+ boxes -- list of shape (None, 5)
234
+ iou_thresh -- maximal value of iou when boxes are considered different
235
+ Each box is [x, y, w, h, prob]
236
+
237
+ Returns:
238
+ boxes -- list of shape (None, 5) with removed overlapping boxes
239
+ """
240
+ boxes = sorted(boxes, key=lambda x: x[4], reverse=True)
241
+ for i, current_box in enumerate(boxes):
242
+ if current_box[4] <= 0:
243
+ continue
244
+ for j in range(i+1, len(boxes)):
245
+ iou = compute_iou(current_box, boxes[j])
246
+ if iou > iou_thresh:
247
+ boxes[j][4] = 0
248
+ boxes = [box for box in boxes if box[4] > 0]
249
+ return boxes
250
+
251
+
252
+
253
+ def yolo_head(yolo_output):
254
+ """
255
+ Converts a yolo output tensor to separate tensors of coordinates, shapes and probabilities.
256
+
257
+ Arguments:
258
+ yolo_output -- tensor of shape (batch_size, S, S, BOX, 5)
259
+
260
+ Returns:
261
+ xy -- tensor of shape (batch_size, S, S, BOX, 2) containing coordinates of centers of found boxes for each anchor in each grid cell
262
+ wh -- tensor of shape (batch_size, S, S, BOX, 2) containing width and height of found boxes for each anchor in each grid cell
263
+ prob -- tensor of shape (batch_size, S, S, BOX, 1) containing the probability of presence of boxes for each anchor in each grid cell
264
+ """
265
+ xy = torch.sigmoid(yolo_output[..., 0:2])
266
+ anchors_wh = torch.tensor(ANCHORS, device=yolo_output.device).view(1, 1, 1, len(ANCHORS), 2)
267
+ wh = torch.exp(yolo_output[..., 2:4]) * anchors_wh
268
+ prob = torch.sigmoid(yolo_output[..., 4:5])
269
+ return xy, wh, prob
270
+
271
+ def process_target(target):
272
+ xy = target[..., 0:2]
273
+ wh = target[..., 2:4]
274
+ prob = target[..., 4:5]
275
+ return xy, wh, prob
detection_models/yolov8/__init__.py ADDED
File without changes
detection_models/yolov8/train.ipynb ADDED
@@ -0,0 +1,144 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "HOME = os.getcwd()\n",
11
+ "print(HOME)"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "# Pip install method (recommended)\n",
21
+ "\n",
22
+ "%pip install ultralytics==8.0.20\n",
23
+ "\n",
24
+ "from IPython import display\n",
25
+ "display.clear_output()\n",
26
+ "\n",
27
+ "import ultralytics\n",
28
+ "ultralytics.checks()"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "from ultralytics import YOLO\n",
38
+ "\n",
39
+ "from IPython.display import display, Image"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "!mkdir {HOME}/datasets\n",
49
+ "%cd {HOME}/datasets\n",
50
+ "\n",
51
+ "%pip install roboflow --quiet\n",
52
+ "\n",
53
+ "from roboflow import Roboflow\n",
54
+ "rf = Roboflow(api_key=\"YOUR_API_KEY\")\n",
55
+ "project = rf.workspace(\"WORKSPACE\").project(\"PROJECT\")\n",
56
+ "dataset = project.version(1).download(\"yolov8\")"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "%cd {HOME}\n",
66
+ "\n",
67
+ "!yolo task=detect mode=train model=yolov8s.pt data={dataset.location}/data.yaml epochs=25 imgsz=800 plots=True"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "%cd {HOME}\n",
77
+ "Image(filename=f'{HOME}/runs/detect/train/confusion_matrix.png', width=600)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "%cd {HOME}\n",
87
+ "Image(filename=f'{HOME}/runs/detect/train/results.png', width=600)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "%cd {HOME}\n",
97
+ "Image(filename=f'{HOME}/runs/detect/train/val_batch0_pred.jpg', width=600)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "%cd {HOME}\n",
107
+ "\n",
108
+ "!yolo task=detect mode=val model={HOME}/runs/detect/train/weights/best.pt data={dataset.location}/data.yaml"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "%cd {HOME}\n",
118
+ "!yolo task=detect mode=predict model={HOME}/runs/detect/train/weights/best.pt conf=0.25 source={dataset.location}/test/images save=True"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "import glob\n",
128
+ "from IPython.display import Image, display\n",
129
+ "\n",
130
+ "for image_path in glob.glob(f'{HOME}/runs/detect/predict3/*.jpg')[:3]:\n",
131
+ " display(Image(filename=image_path, width=600))\n",
132
+ " print(\"\\n\")"
133
+ ]
134
+ }
135
+ ],
136
+ "metadata": {
137
+ "language_info": {
138
+ "name": "python"
139
+ },
140
+ "orig_nbformat": 4
141
+ },
142
+ "nbformat": 4,
143
+ "nbformat_minor": 2
144
+ }
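The notebook drives training through the yolo CLI; the equivalent ultralytics Python API call would look roughly like the sketch below (the data.yaml path is assumed to match the Roboflow download above):

from ultralytics import YOLO

model = YOLO('yolov8s.pt')
model.train(data=f'{dataset.location}/data.yaml', epochs=25, imgsz=800, plots=True)
metrics = model.val()  # validate the best checkpoint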
embedding_models/__init__.py ADDED
File without changes
embedding_models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (165 Bytes). View file
 
embedding_models/vae/__init__.py ADDED
File without changes
embedding_models/vae/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (169 Bytes). View file
 
embedding_models/vae/__pycache__/constants.cpython-39.pyc ADDED
Binary file (237 Bytes). View file
 
embedding_models/vae/__pycache__/model.cpython-39.pyc ADDED
Binary file (5.87 kB). View file
 
embedding_models/vae/constants.py ADDED
@@ -0,0 +1,6 @@
1
+ # dimension of image embedding
2
+ Z_DIM = 128
3
+ # hidden dimensions for encoder model
4
+ ENC_HIDDEN_DIM = 16
5
+ # hidden dimensions for decoder model
6
+ DEC_HIDDEN_DIM = 64
embedding_models/vae/losses.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.distributions.kl import kl_divergence
4
+ from torch.distributions.normal import Normal
5
+ from torch.nn.functional import relu
6
+
7
+
8
+
9
+ class BatchHardTripletLoss(nn.Module):
10
+ def __init__(self, margin=1., squared=False, agg='sum'):
11
+ """
12
+ Initialize the loss function with a margin parameter, whether or not to consider
13
+ squared Euclidean distance and how to aggregate the loss in a batch
14
+ """
15
+ super().__init__()
16
+ self.margin = margin
17
+ self.squared = squared
18
+ self.agg = agg
19
+ self.eps = 1e-8
20
+
21
+ def get_pairwise_distances(self, embeddings):
22
+ """
23
+ Computing Euclidean distance for all possible pairs of embeddings.
24
+ """
25
+ ab = embeddings.mm(embeddings.t())
26
+ a_squared = ab.diag().unsqueeze(1)
27
+ b_squared = ab.diag().unsqueeze(0)
28
+ distances = a_squared - 2 * ab + b_squared
29
+ distances = relu(distances)
30
+
31
+ if not self.squared:
32
+ distances = torch.sqrt(distances + self.eps)
33
+
34
+ return distances
35
+
36
+ def hardest_triplet_mining(self, dist_mat, labels):
37
+
38
+ assert len(dist_mat.size()) == 2
39
+ assert dist_mat.size(0) == dist_mat.size(1)
40
+ N = dist_mat.size(0)
41
+
42
+ is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
43
+ is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
44
+
45
+ dist_ap, relative_p_inds = torch.max(
46
+ (dist_mat * is_pos), 1, keepdim=True)
47
+
48
+ dist_an, relative_n_inds = torch.min(
49
+ (dist_mat * is_neg), 1, keepdim=True)
50
+
51
+ return dist_ap, dist_an
52
+
53
+ def forward(self, embeddings, labels):
54
+
55
+ distances = self.get_pairwise_distances(embeddings)
56
+ dist_ap, dist_an = self.hardest_triplet_mining(distances, labels)
57
+
58
+ triplet_loss = relu(dist_ap - dist_an + self.margin).sum()
59
+ return triplet_loss
60
+
61
+
62
+ class VAELoss(nn.Module):
63
+ def __init__(self):
64
+ super().__init__()
65
+ self.reconstruction_loss = nn.BCELoss(reduction='sum')
66
+
67
+ def kl_divergence_loss(self, q_dist):
68
+ return kl_divergence(
69
+ q_dist, Normal(torch.zeros_like(q_dist.mean), torch.ones_like(q_dist.stddev))
70
+ ).sum(-1)
71
+
72
+
73
+ def forward(self, output, target, encoding):
74
+ loss = self.kl_divergence_loss(encoding).sum() + self.reconstruction_loss(output, target)
75
+ return loss
76
+
77
+
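A minimal sketch of BatchHardTripletLoss on dummy data; labels are expected as an (N, 1) column so the expand(N, N) comparisons in hardest_triplet_mining work (the embedding size and batch size are arbitrary):

import torch
from embedding_models.vae.losses import BatchHardTripletLoss

triplet = BatchHardTripletLoss(margin=1.)
embeddings = torch.randn(8, 128)      # e.g. Z_DIM-sized embeddings
labels = torch.randint(0, 3, (8, 1))  # class id per embedding
print(triplet(embeddings, labels).item())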
embedding_models/vae/model.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch.nn as nn
2
+ from torch.distributions.normal import Normal
3
+
4
+ from .constants import *
5
+
6
+
7
+ class Encoder(nn.Module):
8
+ '''
9
+ Encoder Class
10
+ Values:
11
+ im_chan: the number of channels of the output image, a scalar
12
+ hidden_dim: the inner dimension, a scalar
13
+ '''
14
+
15
+ def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
16
+ super(Encoder, self).__init__()
17
+ self.z_dim = output_chan
18
+ self.disc = nn.Sequential(
19
+ self.make_disc_block(im_chan, hidden_dim),
20
+ self.make_disc_block(hidden_dim, hidden_dim * 2),
21
+ self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
22
+ self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
23
+ self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
24
+ )
25
+
26
+ def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
27
+ '''
28
+ Function to return a sequence of operations corresponding to a encoder block of the VAE,
29
+ corresponding to a convolution, a batchnorm (except for in the last layer), and an activation
30
+ Parameters:
31
+ input_channels: how many channels the input feature representation has
32
+ output_channels: how many channels the output feature representation should have
33
+ kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
34
+ stride: the stride of the convolution
35
+ final_layer: whether we're on the final layer (affects activation and batchnorm)
36
+ '''
37
+ if not final_layer:
38
+ return nn.Sequential(
39
+ nn.Conv2d(input_channels, output_channels, kernel_size, stride),
40
+ nn.BatchNorm2d(output_channels),
41
+ nn.LeakyReLU(0.2, inplace=True),
42
+ )
43
+ else:
44
+ return nn.Sequential(
45
+ nn.Conv2d(input_channels, output_channels, kernel_size, stride),
46
+ )
47
+
48
+ def forward(self, image):
49
+ '''
50
+ Function for completing a forward pass of the Encoder: Given an image tensor,
51
+ returns a 1-dimension tensor representing fake/real.
52
+ Parameters:
53
+ image: a flattened image tensor with dimension (im_dim)
54
+ '''
55
+ disc_pred = self.disc(image)
56
+ encoding = disc_pred.view(len(disc_pred), -1)
57
+ # The stddev output is treated as the log of the variance of the normal
58
+ # distribution by convention and for numerical stability
59
+ return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()
60
+
61
+
62
+ class Decoder(nn.Module):
63
+ '''
64
+ Decoder Class
65
+ Values:
66
+ z_dim: the dimension of the noise vector, a scalar
67
+ im_chan: the number of channels of the output image, a scalar
68
+ hidden_dim: the inner dimension, a scalar
69
+ '''
70
+
71
+ def __init__(self, z_dim=Z_DIM, im_chan=3, hidden_dim=DEC_HIDDEN_DIM):
72
+ super(Decoder, self).__init__()
73
+ self.z_dim = z_dim
74
+ self.gen = nn.Sequential(
75
+ self.make_gen_block(z_dim, hidden_dim * 16),
76
+ self.make_gen_block(hidden_dim * 16, hidden_dim * 8, kernel_size=4, stride=1),
77
+ self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
78
+ self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4),
79
+ self.make_gen_block(hidden_dim * 2, hidden_dim, kernel_size=4),
80
+ self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
81
+ )
82
+
83
+ def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
84
+ '''
85
+ Function to return a sequence of operations corresponding to a Decoder block of the VAE,
86
+ corresponding to a transposed convolution, a batchnorm (except for in the last layer), and an activation
87
+ Parameters:
88
+ input_channels: how many channels the input feature representation has
89
+ output_channels: how many channels the output feature representation should have
90
+ kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
91
+ stride: the stride of the convolution
92
+ final_layer: whether we're on the final layer (affects activation and batchnorm)
93
+ '''
94
+ if not final_layer:
95
+ return nn.Sequential(
96
+ nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
97
+ nn.BatchNorm2d(output_channels),
98
+ nn.ReLU(inplace=True),
99
+ )
100
+ else:
101
+ return nn.Sequential(
102
+ nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
103
+ nn.Sigmoid(),
104
+ )
105
+
106
+ def forward(self, noise):
107
+ '''
108
+ Function for completing a forward pass of the Decoder: Given a noise vector,
109
+ returns a generated image.
110
+ Parameters:
111
+ noise: a noise tensor with dimensions (batch_size, z_dim)
112
+ '''
113
+ x = noise.view(len(noise), self.z_dim, 1, 1)
114
+ return self.gen(x)
115
+
116
+
117
+ class VAE(nn.Module):
118
+ '''
119
+ VAE Class
120
+ Values:
121
+ z_dim: the dimension of the noise vector, a scalar
122
+ im_chan: the number of channels of the output image, a scalar
123
+ MNIST is black-and-white, so that's our default
124
+ hidden_dim: the inner dimension, a scalar
125
+ '''
126
+
127
+ def __init__(self, z_dim=Z_DIM, im_chan=3):
128
+ super(VAE, self).__init__()
129
+ self.z_dim = z_dim
130
+ self.encode = Encoder(im_chan, z_dim)
131
+ self.decode = Decoder(z_dim, im_chan)
132
+
133
+ def forward(self, images):
134
+ '''
135
+ Function for completing a forward pass of the Decoder: Given a noise vector,
136
+ returns a generated image.
137
+ Parameters:
138
+ images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
139
+ Returns:
140
+ decoding: the autoencoded image
141
+ q_dist: the z-distribution of the encoding
142
+ '''
143
+ q_mean, q_stddev = self.encode(images)
144
+ q_dist = Normal(q_mean, q_stddev)
145
+ z_sample = q_dist.rsample() # Sample once from each distribution, using the `rsample` notation
146
+ decoding = self.decode(z_sample)
147
+ return decoding, q_dist
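A short sketch of a VAE forward pass: with the layer stack above, 118x118 inputs reconstruct to 118x118, and the first encoder output (the latent mean) is what the old app.py used as the stamp embedding:

import torch
from embedding_models.vae.model import VAE

vae = VAE().eval()
images = torch.rand(2, 3, 118, 118)    # values in [0, 1], matching the Sigmoid/BCE setup
with torch.no_grad():
    decoding, q_dist = vae(images)     # reconstruction and latent Normal distribution
    mean, stddev = vae.encode(images)  # mean has shape (2, Z_DIM)
print(decoding.shape, mean.shape)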
embedding_models/vae/train.ipynb ADDED
@@ -0,0 +1,393 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import torch.nn as nn\n",
11
+ "import numpy as np\n",
12
+ "\n",
13
+ "from pathlib import Path\n",
14
+ "import os\n",
15
+ "from PIL import Image\n",
16
+ "\n",
17
+ "from model import VAE\n",
18
+ "from losses import *"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "from torch.utils.data import DataLoader, Dataset\n",
28
+ "from torchvision import transforms\n",
29
+ "import pandas as pd\n",
30
+ "import re\n",
31
+ "from sklearn.model_selection import train_test_split"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 1,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "IMAGE_FOLDER = './data/images/'"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 5,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "image_names = os.listdir(IMAGE_FOLDER)\n",
50
+ "data = pd.DataFrame({'image_name': image_names})\n",
51
+ "data['label'] = data['image_name'].apply(lambda x: int(re.match('^\\d+', x)[0]))"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "class StampDataset(Dataset):\n",
61
+ " def __init__(self, data, image_folder=Path(IMAGE_FOLDER), transform=None):\n",
62
+ " super().__init__()\n",
63
+ " self.image_folder = image_folder\n",
64
+ " self.data = data\n",
65
+ " self.transform = transform\n",
66
+ "\n",
67
+ " def __getitem__(self, idx):\n",
68
+ " image = Image.open(self.image_folder / self.data.iloc[idx]['image_name'])\n",
69
+ " label = self.data.iloc[idx]['label']\n",
70
+ " if self.transform:\n",
71
+ " image = self.transform(image)\n",
72
+ "\n",
73
+ " return image, label\n",
74
+ "\n",
75
+ " \n",
76
+ " def __len__(self):\n",
77
+ " return len(self.data)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 6,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "train_data, val_data = train_test_split(data, test_size=0.3, shuffle=True, stratify=data['label'])"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": [
95
+ "train_transform = transforms.Compose([\n",
96
+ " transforms.Resize((118, 118)),\n",
97
+ " transforms.RandomHorizontalFlip(0.5),\n",
98
+ " transforms.RandomVerticalFlip(0.5),\n",
99
+ " transforms.ToTensor(),\n",
100
+ " # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n",
101
+ "])\n",
102
+ "\n",
103
+ "val_transform = transforms.Compose([\n",
104
+ " transforms.Resize((118, 118)),\n",
105
+ " transforms.ToTensor(),\n",
106
+ " # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n",
107
+ "])\n",
108
+ "train_dataset = StampDataset(train_data, transform=train_transform)\n",
109
+ "val_dataset = StampDataset(val_data, transform=val_transform)\n",
110
+ "\n",
111
+ "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=256)\n",
112
+ "val_loader = DataLoader(val_dataset, shuffle=True, batch_size=256)"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 8,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "import pytorch_lightning as pl\n",
122
+ "from torch import optim\n",
123
+ "from pytorch_lightning.loggers import TensorBoardLogger\n",
124
+ "\n",
125
+ "from torchvision.utils import make_grid"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 9,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "MEAN = torch.tensor((0.76302232, 0.77820438, 0.81879729)).view(3, 1, 1)\n",
135
+ "STD = torch.tensor((0.16563211, 0.14949341, 0.1055889)).view(3, 1, 1)"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": 9,
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "class LitModel(pl.LightningModule):\n",
145
+ " def __init__(self, alpha=1e-3):\n",
146
+ " super().__init__()\n",
147
+ " self.vae = VAE()\n",
148
+ " self.vae_loss = VAELoss()\n",
149
+ " self.triplet_loss = BatchHardTripletLoss(margin=1.)\n",
150
+ " self.alpha = alpha\n",
151
+ " \n",
152
+ " def forward(self, x):\n",
153
+ " return self.vae(x)\n",
154
+ " \n",
155
+ " def configure_optimizers(self):\n",
156
+ " optimizer = optim.AdamW(self.parameters(), lr=3e-4)\n",
157
+ " return optimizer\n",
158
+ " # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n",
159
+ " # return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
160
+ "\n",
161
+ " def training_step(self, batch, batch_idx):\n",
162
+ " images, labels = batch\n",
163
+ " labels = labels.unsqueeze(1)\n",
164
+ " recon_images, encoding = self.vae(images)\n",
165
+ " vae_loss = self.vae_loss(recon_images, images, encoding)\n",
166
+ " self.log(\"train_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
167
+ " triplet_loss = self.triplet_loss(encoding.mean, labels)\n",
168
+ " self.log(\"train_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
169
+ " loss = self.alpha * triplet_loss + vae_loss\n",
170
+ " self.log(\"train_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n",
171
+ " return loss\n",
172
+ "\n",
173
+ " def validation_step(self, batch, batch_idx):\n",
174
+ " images, labels = batch\n",
175
+ " labels = labels.unsqueeze(1)\n",
176
+ " recon_images, encoding = self.vae(images)\n",
177
+ " vae_loss = self.vae_loss(recon_images, images, encoding)\n",
178
+ " self.log(\"val_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
179
+ " triplet_loss = self.triplet_loss(encoding.mean, labels)\n",
180
+ " self.log(\"val_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n",
181
+ " loss = self.alpha * triplet_loss + vae_loss\n",
182
+ " self.log(\"val_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n",
183
+ " return loss\n",
184
+ "\n",
185
+ " def on_validation_epoch_end(self):\n",
186
+ "        images, _ = next(iter(val_loader))\n",
187
+ " image_unflat = images.detach().cpu()\n",
188
+ " image_grid = make_grid(image_unflat[:16], nrow=4)\n",
189
+ " self.logger.experiment.add_image('original images', image_grid, self.current_epoch)\n",
190
+ "\n",
191
+ " recon_images, _ = self.vae(images.to(self.device))\n",
192
+ " image_unflat = recon_images.detach().cpu()\n",
193
+ " image_grid = make_grid(image_unflat[:16], nrow=4)\n",
194
+ " self.logger.experiment.add_image('reconstructed images', image_grid, self.current_epoch)"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": 10,
200
+ "metadata": {},
201
+ "outputs": [],
202
+ "source": [
203
+ "litmodel = LitModel()"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": 11,
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "logger = TensorBoardLogger(\"reconstruction_logs\")"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": 12,
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "epochs = 100"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "metadata": {},
228
+ "outputs": [],
229
+ "source": [
230
+ "trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n",
231
+ "trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "metadata": {},
238
+ "outputs": [],
239
+ "source": [
240
+ "%load_ext tensorboard\n%tensorboard --logdir reconstruction_logs"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 8,
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 11,
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "from huggingface_hub import hf_hub_download"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 12,
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "emb_model = torch.jit.load(hf_hub_download(repo_id=\"stamps-labs/vits8-stamp\", filename=\"vits8stamp-torchscript.pth\")).to(device)"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 21,
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "val_transform = transforms.Compose([\n",
277
+ " transforms.Resize((224, 224)),\n",
278
+ " transforms.ToTensor(),\n",
279
+ " # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n",
280
+ "])"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": 28,
286
+ "metadata": {},
287
+ "outputs": [],
288
+ "source": [
289
+ "train_data['embed'] = train_data['image_name'].apply(lambda x: emb_model(val_transform(Image.open(Path(IMAGE_FOLDER) / x)).unsqueeze(0).to(device))[0].tolist())\nval_data['embed'] = val_data['image_name'].apply(lambda x: emb_model(val_transform(Image.open(Path(IMAGE_FOLDER) / x)).unsqueeze(0).to(device))[0].tolist())"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": 34,
295
+ "metadata": {},
296
+ "outputs": [
297
+ {
298
+ "name": "stderr",
299
+ "output_type": "stream",
300
+ "text": [
301
+ "C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:1: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
302
+ " embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n",
303
+ "C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:2: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
304
+ " labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)\n"
305
+ ]
306
+ }
307
+ ],
308
+ "source": [
309
+ "embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n",
310
+ "labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 35,
316
+ "metadata": {},
317
+ "outputs": [],
318
+ "source": [
319
+ "embeds.to_csv('./all_embeds.tsv', sep='\\t', index=False, header=False)"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": 36,
325
+ "metadata": {},
326
+ "outputs": [],
327
+ "source": [
328
+ "labels.to_csv('./all_labels.tsv', sep='\\t', index=False, header=False)"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": 126,
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "torch.save(litmodel.vae.encode.state_dict(), './models/encoder.pth')"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": 129,
343
+ "metadata": {},
344
+ "outputs": [],
345
+ "source": [
346
+ "im = train_dataset[0]"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": 132,
352
+ "metadata": {},
353
+ "outputs": [
354
+ {
355
+ "data": {
356
+ "text/plain": [
357
+ "<All keys matched successfully>"
358
+ ]
359
+ },
360
+ "execution_count": 132,
361
+ "metadata": {},
362
+ "output_type": "execute_result"
363
+ }
364
+ ],
365
+ "source": [
366
+ "model = Encoder()\n",
367
+ "model.load_state_dict(torch.load('./models/encoder.pth'))"
368
+ ]
369
+ }
370
+ ],
371
+ "metadata": {
372
+ "kernelspec": {
373
+ "display_name": "Python 3",
374
+ "language": "python",
375
+ "name": "python3"
376
+ },
377
+ "language_info": {
378
+ "codemirror_mode": {
379
+ "name": "ipython",
380
+ "version": 3
381
+ },
382
+ "file_extension": ".py",
383
+ "mimetype": "text/x-python",
384
+ "name": "python",
385
+ "nbconvert_exporter": "python",
386
+ "pygments_lexer": "ipython3",
387
+ "version": "3.9.0"
388
+ },
389
+ "orig_nbformat": 4
390
+ },
391
+ "nbformat": 4,
392
+ "nbformat_minor": 2
393
+ }
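The notebook above saves only the encoder weights (torch.save(litmodel.vae.encode.state_dict(), './models/encoder.pth')) and reloads them into an Encoder module. A minimal reuse sketch, assuming Encoder is the encoder class defined alongside VAE in embedding_models/vae/model.py and using the notebook's 118x118 validation preprocessing; the image file name is a placeholder:

import torch
from PIL import Image
from torchvision import transforms

from model import Encoder  # assumption: Encoder is exported by embedding_models/vae/model.py next to VAE

# Same preprocessing as the notebook's validation transform (resize to 118x118, no normalization).
preprocess = transforms.Compose([
    transforms.Resize((118, 118)),
    transforms.ToTensor(),
])

encoder = Encoder()
encoder.load_state_dict(torch.load('./models/encoder.pth', map_location='cpu'))
encoder.eval()

image = preprocess(Image.open('./data/images/0_1.png')).unsqueeze(0)  # placeholder file name
with torch.no_grad():
    embedding = encoder(image)  # exact return type depends on Encoder's definition (a tensor or distribution parameters)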
embedding_models/vits8/__init__.py ADDED
File without changes
embedding_models/vits8/example.py ADDED
@@ -0,0 +1,10 @@
1
+ from PIL import Image
2
+ from model import ViTStamp
3
+ def get_embeddings(img_path: str):
4
+     model = ViTStamp()
5
+     image = Image.open(img_path)
6
+     embeddings = model(image=image)
7
+     return embeddings
8
+
9
+ if __name__ == "__main__":
10
+     print(get_embeddings("oml/data/test/images/99d_15.bmp"))
embedding_models/vits8/model.py ADDED
@@ -0,0 +1,13 @@
1
+ import torch
2
+ from torchvision import transforms
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ class ViTStamp():
6
+     def __init__(self):
7
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
+         self.model = torch.jit.load(hf_hub_download(repo_id="stamps-labs/vits8-stamp", filename="vits8stamp-torchscript.pth")).to(self.device)
9
+         self.transform = transforms.ToTensor()
10
+     def __call__(self, image) -> torch.Tensor:
11
+         img_tensor = self.transform(image).unsqueeze(0).to(self.device)
12
+         features = self.model(img_tensor)
13
+         return features
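A minimal usage sketch for ViTStamp, comparing two of the test stamps added under oml/data/test/images/ by cosine similarity; the embedding size is whatever the published TorchScript checkpoint returns:

import torch.nn.functional as F
from PIL import Image

from model import ViTStamp

model = ViTStamp()

# Two stamp crops from the test folder added in this commit.
emb_a = model(image=Image.open("oml/data/test/images/99d_15.bmp").convert("RGB"))
emb_b = model(image=Image.open("oml/data/test/images/99e_20.bmp").convert("RGB"))

# Embeddings have shape (1, dim); a similarity close to 1 suggests the same stamp.
print(F.cosine_similarity(emb_a, emb_b).item())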
embedding_models/vits8/oml/__init__.py ADDED
File without changes
embedding_models/vits8/oml/create_dataset.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+ from PIL import Image
3
+ import pandas as pd
4
+
5
+ import argparse
6
+
7
+ parser = argparse.ArgumentParser("Create a dataset for training with OML",
8
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
9
+
10
+ parser.add_argument("--root-data-path", help="Path to images for dataset", default="data/train_val/")
11
+ parser.add_argument("--image-data-path", help="Image folder in root data path", default="images/")
12
+ parser.add_argument("--train-val-split",
13
+ help="In which ratio to split data in format train:val (For example 80:20)", default="80:20")
14
+ parser.add_argument("--separator",
15
+ help="What separator is used in image name to separate class name and instance (E.g. circle1_5, separator=_)",
16
+ default="_")
17
+
18
+ args = parser.parse_args()
19
+ config = vars(args)
20
+
21
+ root_path = config["root_data_path"]
22
+ image_path = config["image_data_path"]
23
+ separator = config["separator"]
24
+
25
+ train_prc, val_prc = tuple(int(num)/100 for num in config["train_val_split"].split(":"))
26
+
27
+ class_names = set()
28
+ for image in os.listdir(root_path+image_path):
29
+     if image.endswith(("png", "jpg", "bmp", "webp")):
30
+         img_name = image.split(".")[0]
31
+         Image.open(root_path+image_path+image).resize((224,224)).save(root_path+image_path+img_name+".png", "PNG")
32
+         if not image.endswith("png"):
33
+             os.remove(root_path+image_path+image)
34
+         img_name = img_name.split(separator)
35
+         class_name = img_name[0]+img_name[1]
36
+         class_names.add(class_name)
37
+     else:
38
+         print("Not all of the images are in a supported format")
39
+
40
+
41
+ # For each class in the set, assign its index in the set as its class label.
42
+ class_label_dict = {}
43
+ for ind, name in enumerate(class_names):
44
+     class_label_dict[name] = ind
45
+
46
+ class_count = len(class_names)
47
+ train_class_count = int(class_count*train_prc)
48
+ print(train_class_count)
49
+
50
+ df_dict = {"label": [],
51
+ "path": [],
52
+ "split": [],
53
+ "is_query": [],
54
+ "is_gallery": []}
55
+ for image in os.listdir(root_path+image_path):
56
+     if image.endswith((".png", ".jpg", ".bmp", ".webp")):
57
+         img_name = image.split(".")[0].split(separator)
58
+         class_name = img_name[0]+img_name[1]
59
+         label = class_label_dict[class_name]
60
+         path = image_path+image
61
+         split = "train" if label <= train_class_count else "validation"
62
+         is_query, is_gallery = (1, 1) if split=="validation" else (None, None)
63
+         df_dict["label"].append(label)
64
+         df_dict["path"].append(path)
65
+         df_dict["split"].append(split)
66
+         df_dict["is_query"].append(is_query)
67
+         df_dict["is_gallery"].append(is_gallery)
68
+
69
+ df = pd.DataFrame(df_dict)
70
+
71
+ df.to_csv(root_path+"df_stamps.csv", index=False)
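Note that the class name is the concatenation of the two parts around the separator, so with the default "_" a file such as circle2_1246.png maps to class circle21246; with file names like these, each stamp image effectively gets its own label, which matches the generated df_stamps.csv below. The script is run from embedding_models/vits8/oml/, e.g. python create_dataset.py --root-data-path data/train_val/ --image-data-path images/ --separator _ (the defaults). A standalone illustration of the labeling rule on made-up file names:

# Illustration only: the label-assignment rule from create_dataset.py applied to hypothetical file names.
separator = "_"
files = ["circle2_1246.png", "circle2_1249.png", "oval7_1241.png"]

class_names = set()
for image in files:
    stem = image.split(".")[0].split(separator)
    class_names.add(stem[0] + stem[1])  # "circle2" + "1246" -> "circle21246"

class_label_dict = {name: ind for ind, name in enumerate(class_names)}
print(class_label_dict)  # three distinct classes, one per file (set order is not deterministic)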
embedding_models/vits8/oml/data/test/images/99d_15.bmp ADDED
embedding_models/vits8/oml/data/test/images/99e_20.bmp ADDED
embedding_models/vits8/oml/data/test/images/99f_25.bmp ADDED
embedding_models/vits8/oml/data/test/images/99g_30.bmp ADDED
embedding_models/vits8/oml/data/test/images/99h_35.bmp ADDED
embedding_models/vits8/oml/data/test/images/99i_40.bmp ADDED
embedding_models/vits8/oml/data/train_val/df_stamps.csv ADDED
@@ -0,0 +1,41 @@
1
+ label,path,split,is_query,is_gallery
2
+ 0,images/circle6_1239.png,train,,
3
+ 8,images/triangle19_1242.png,train,,
4
+ 21,images/rectangle11_1248.png,train,,
5
+ 39,images/triangle10_1232.png,validation,1.0,1.0
6
+ 33,images/word14_1241.png,validation,1.0,1.0
7
+ 38,images/word5_1233.png,validation,1.0,1.0
8
+ 15,images/circle19_1236.png,train,,
9
+ 22,images/circle15_1244.png,train,,
10
+ 32,images/circle21_1249.png,train,,
11
+ 26,images/oval20_1242.png,train,,
12
+ 6,images/oval5_1237.png,train,,
13
+ 23,images/word9_1241.png,train,,
14
+ 9,images/triangle22_1238.png,train,,
15
+ 31,images/circle12_1239.png,train,,
16
+ 11,images/word21_1231.png,train,,
17
+ 4,images/oval2_1235.png,train,,
18
+ 20,images/rectangle18_1246.png,train,,
19
+ 12,images/circle24_1234.png,train,,
20
+ 5,images/circle2_1249.png,train,,
21
+ 37,images/word22_1238.png,validation,1.0,1.0
22
+ 34,images/triangle18_1247.png,validation,1.0,1.0
23
+ 1,images/oval7_1241.png,train,,
24
+ 10,images/triangle13_1240.png,train,,
25
+ 14,images/rectangle12_1236.png,train,,
26
+ 36,images/circle8_1237.png,validation,1.0,1.0
27
+ 24,images/triangle9_1245.png,train,,
28
+ 29,images/word23_1243.png,train,,
29
+ 28,images/triangle11_1244.png,train,,
30
+ 16,images/circle2_1246.png,train,,
31
+ 30,images/circle3_1247.png,train,,
32
+ 18,images/oval24_1248.png,train,,
33
+ 2,images/oval12_1231.png,train,,
34
+ 3,images/oval18_1234.png,train,,
35
+ 25,images/rectangle11_1245.png,train,,
36
+ 17,images/word9_1244.png,train,,
37
+ 13,images/triangle14_1237.png,train,,
38
+ 35,images/circle2_1233.png,validation,1.0,1.0
39
+ 7,images/word18_1239.png,train,,
40
+ 19,images/rectangle13_1236.png,train,,
41
+ 27,images/circle24_1246.png,train,,
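The dataframe above uses the column layout OML expects: label, path, split, and is_query / is_gallery flags that are filled only for validation rows. A quick pandas sanity check, assuming it is run from embedding_models/vits8/oml/:

import pandas as pd

df = pd.read_csv("data/train_val/df_stamps.csv")

print(df["split"].value_counts())                        # train vs. validation row counts
val = df[df["split"] == "validation"]
print(val[["label", "path", "is_query", "is_gallery"]])  # every validation row is both query and gallery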
embedding_models/vits8/oml/data/train_val/images/circle12_1239.png ADDED
embedding_models/vits8/oml/data/train_val/images/circle15_1244.png ADDED
embedding_models/vits8/oml/data/train_val/images/circle19_1236.png ADDED
embedding_models/vits8/oml/data/train_val/images/circle21_1249.png ADDED
embedding_models/vits8/oml/data/train_val/images/circle24_1234.png ADDED
embedding_models/vits8/oml/data/train_val/images/circle24_1246.png ADDED
embedding_models/vits8/oml/data/train_val/images/circle2_1233.png ADDED
embedding_models/vits8/oml/data/train_val/images/circle2_1246.png ADDED
embedding_models/vits8/oml/data/train_val/images/circle2_1249.png ADDED