diff --git a/__pycache__/constants.cpython-39.pyc b/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7052fc115eac1aac81b65e24b03aa7650bcbd6f5 Binary files /dev/null and b/__pycache__/constants.cpython-39.pyc differ diff --git a/__pycache__/models.cpython-39.pyc b/__pycache__/models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b50f30610463e3caaa6b71c1f6049e04ef3c0d5 Binary files /dev/null and b/__pycache__/models.cpython-39.pyc differ diff --git a/__pycache__/utils.cpython-39.pyc b/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2258e2f44c110958847898310efea86c5e6d3f1 Binary files /dev/null and b/__pycache__/utils.cpython-39.pyc differ diff --git a/app.py b/app.py index 365b744512c5764325a49d9b9fc399bf1528d319..96512b0b1d13a3fa0e859ebd2aaed9508a8dc9d7 100644 --- a/app.py +++ b/app.py @@ -1,102 +1,97 @@ import gradio as gr -import numpy as np -from ultralytics import YOLO -from torchvision.transforms.functional import to_tensor -from huggingface_hub import hf_hub_download import torch -import albumentations as A -from albumentations.pytorch.transforms import ToTensorV2 -import pandas as pd +import numpy as np from sklearn.metrics.pairwise import cosine_similarity +import pandas as pd +from PIL import Image, ImageDraw +import matplotlib.pyplot as plt +import matplotlib -from utils import * -from models import YOLOStamp, Encoder - -device = 'cuda' if torch.cuda.is_available() else 'cpu' - - -yolov8 = YOLO(hf_hub_download('stamps-labs/yolov8-finetuned', filename='best.torchscript'), task='detect') +from pipelines.detection.yolo_v8 import Yolov8Pipeline +from pipelines.detection.yolo_stamp import YoloStampPipeline +from pipelines.segmentation.deeplabv3 import DeepLabv3Pipeline +from pipelines.feature_extraction.vae import VaePipeline +from pipelines.feature_extraction.vits8 import Vits8Pipeline -yolo_stamp = YOLOStamp() -yolo_stamp.load_state_dict(torch.load(hf_hub_download('stamps-labs/yolo-stamp', filename='state_dict.pth'), map_location='cpu')) -yolo_stamp = yolo_stamp.to(device) -yolo_stamp.eval() -transform = A.Compose([ - A.Normalize(), - ToTensorV2(p=1.0), -]) +from utils import * -vits8 = torch.jit.load(hf_hub_download('stamps-labs/vits8-stamp', filename='vits8stamp-torchscript.pth'), map_location='cpu') -vits8 = vits8.to(device) -vits8.eval() -encoder = Encoder() -encoder.load_state_dict(torch.load(hf_hub_download('stamps-labs/vae-encoder', filename='encoder.pth'), map_location='cpu')) -encoder = encoder.to(device) -encoder.eval() +yolov8 = Yolov8Pipeline.from_pretrained(local_model_path='yolov8_old_backup.pt') +yolo_stamp = YoloStampPipeline.from_pretrained('stamps-labs/yolo-stamp', 'weights.pt') +vae = VaePipeline.from_pretrained('stamps-labs/vae-encoder', 'weights.pt') +vits8 = Vits8Pipeline.from_pretrained('stamps-labs/vits8-stamp', 'weights.pt') +dlv3 = DeepLabv3Pipeline.from_pretrained('stamps-labs/deeplabv3-finetuned', 'weights.pt') -def predict(image, det_choice, emb_choice): +def doc_predict(image, det_choice, seg_choice, emb_choice): - shape = torch.tensor(image.size) image = image.convert('RGB') if det_choice == 'yolov8': - coef = torch.hstack((shape, shape)) / 640 - image = image.resize((640, 640)) - boxes = yolov8(image)[0].boxes.xyxy.cpu() - image_with_boxes = visualize_bbox(image, boxes) + boxes = yolov8(image) elif det_choice == 'yolo-stamp': - coef = torch.hstack((shape, shape)) / 448 - image = image.resize((448, 448)) - 
image_tensor = transform(image=np.array(image))['image'] - output = yolo_stamp(image_tensor.unsqueeze(0).to(device)) - - boxes = output_tensor_to_boxes(output[0].detach().cpu()) - boxes = nonmax_suppression(boxes) - boxes = xywh2xyxy(torch.tensor(boxes)[:, :4]) - image_with_boxes = visualize_bbox(image, boxes) + boxes = yolo_stamp(image) else: return - + image_with_boxes = visualize_bbox(image, boxes) + + segmented_stamps = [] + for box in boxes: + cropped_stamp = image.crop(box.tolist()) + segmented_stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp) + + widths, heights = zip(*(i.size for i in segmented_stamps)) + + total_width = sum(widths) + max_height = max(heights) + + concatenated_stamps = Image.new('RGB', (total_width, max_height)) + + x_offset = 0 + for im in segmented_stamps: + concatenated_stamps.paste(im, (x_offset,0)) + x_offset += im.size[0] embeddings = [] if emb_choice == 'vits8': - for box in boxes: - cropped_stamp = to_tensor(image.crop(box.tolist())) - embeddings.append(vits8(cropped_stamp.unsqueeze(0).to(device))[0].detach().cpu()) + for stamp in segmented_stamps: + embeddings.append(vits8(stamp)) elif emb_choice == 'vae-encoder': - for box in boxes: - cropped_stamp = to_tensor(image.crop(box.tolist()).resize((118, 118))) - embeddings.append(np.array(encoder(cropped_stamp.unsqueeze(0).to(device))[0][0].detach().cpu())) + for stamp in segmented_stamps: + embeddings.append(vae(stamp)) embeddings = np.stack(embeddings) similarities = cosine_similarity(embeddings) - boxes = boxes * coef df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2']) fig, ax = plt.subplots() im, cbar = heatmap(similarities, range(1, len(embeddings) + 1), range(1, len(embeddings) + 1), ax=ax, cmap="YlGn", cbarlabel="Embeddings similarities") texts = annotate_heatmap(im, valfmt="{x:.3f}") - return image_with_boxes, df_boxes, embeddings, fig + return image_with_boxes, df_boxes, concatenated_stamps, embeddings, fig -examples = [['./examples/1.jpg', 'yolov8', 'vits8'], ['./examples/2.jpg', 'yolov8', 'vae-encoder'], ['./examples/3.jpg', 'yolo-stamp', 'vits8']] -inputs = [ - gr.Image(type="pil"), +doc_examples = [['examples/1.jpg', 'yolov8', True, 'vits8'], ['examples/2.jpg', 'yolo-stamp', False, 'vae-encoder'], ['examples/3.jpg', 'yolov8', True, 'vits8']] +doc_inputs = [ + gr.Image(label="Document image", type="pil"), gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model'), + gr.Checkbox(label="Use segmentation model"), gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model'), ] -outputs = [ - gr.Image(type="pil"), +doc_outputs = [ + gr.Image(label="Document with bounding boxes", type="pil"), gr.DataFrame(type='pandas', label="Bounding boxes"), + gr.Image(label="Segmented stamps", type="pil"), gr.DataFrame(type='numpy', label="Embeddings"), gr.Plot(label="Cosine Similarities") ] -app = gr.Interface(predict, inputs, outputs, examples=examples) -app.launch() \ No newline at end of file + +with gr.Blocks() as demo: + with gr.Tab("Signle document"): + gr.Interface(doc_predict, doc_inputs, doc_outputs, examples=doc_examples) + +demo.launch(inline=False) \ No newline at end of file diff --git a/detection_models/__init__.py b/detection_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/detection_models/__pycache__/__init__.cpython-39.pyc b/detection_models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2103ff3f9a3baa2d71175ed10627228d0c94364c Binary files /dev/null and b/detection_models/__pycache__/__init__.cpython-39.pyc differ diff --git a/detection_models/yolo_stamp/__init__.py b/detection_models/yolo_stamp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/detection_models/yolo_stamp/__pycache__/__init__.cpython-39.pyc b/detection_models/yolo_stamp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e6039c1068185fd51d9232fa6d0056c66c2bbe0 Binary files /dev/null and b/detection_models/yolo_stamp/__pycache__/__init__.cpython-39.pyc differ diff --git a/detection_models/yolo_stamp/__pycache__/constants.cpython-39.pyc b/detection_models/yolo_stamp/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35b1e74f3cd28712038a18fba39cdaf4ca9b5dd5 Binary files /dev/null and b/detection_models/yolo_stamp/__pycache__/constants.cpython-39.pyc differ diff --git a/detection_models/yolo_stamp/__pycache__/model.cpython-39.pyc b/detection_models/yolo_stamp/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8187750270e9ea1e3acbe03867654d211a799a7d Binary files /dev/null and b/detection_models/yolo_stamp/__pycache__/model.cpython-39.pyc differ diff --git a/detection_models/yolo_stamp/__pycache__/utils.cpython-39.pyc b/detection_models/yolo_stamp/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7055a9d6b9413d31c998fd9bcbd2ab57ed3e6b1 Binary files /dev/null and b/detection_models/yolo_stamp/__pycache__/utils.cpython-39.pyc differ diff --git a/constants.py b/detection_models/yolo_stamp/constants.py similarity index 84% rename from constants.py rename to detection_models/yolo_stamp/constants.py index 7c6b41e35c445ff37789678e5dd32e79090f2f19..6a6013ac285b7bf5cb6dae65b6ea7199d74f45b1 100644 --- a/constants.py +++ b/detection_models/yolo_stamp/constants.py @@ -23,11 +23,3 @@ STD = (0.229, 0.224, 0.225) MEAN = (0.485, 0.456, 0.406) # box color to show the bounding box on image BOX_COLOR = (0, 0, 255) - - -# dimenstion of image embedding -Z_DIM = 128 -# hidden dimensions for encoder model -ENC_HIDDEN_DIM = 16 -# hidden dimensions for decoder model -DEC_HIDDEN_DIM = 64 \ No newline at end of file diff --git a/detection_models/yolo_stamp/data.py b/detection_models/yolo_stamp/data.py new file mode 100644 index 0000000000000000000000000000000000000000..aa90995ee881f4f1e9e420d9a260f83dca4acb95 --- /dev/null +++ b/detection_models/yolo_stamp/data.py @@ -0,0 +1,141 @@ +import torch +from torch.utils.data import Dataset, DataLoader +import numpy as np +from sklearn.model_selection import train_test_split +import albumentations as A +from albumentations.pytorch.transforms import ToTensorV2 +from PIL import Image + +from pathlib import Path +from random import randint + +from utils import * + +""" + Dataset class for storing stamps data. 
+ + Arguments: + data -- list of dictionaries containing file_path (path to the image), box_nb (number of boxes on the image), and boxes of shape (4,) + image_folder -- path to folder containing images + transforms -- transforms from albumentations package +""" +class StampDataset(Dataset): + def __init__( + self, + data=read_data(), + image_folder=Path(IMAGE_FOLDER), + transforms=None): + self.data = data + self.image_folder = image_folder + self.transforms = transforms + + def __getitem__(self, idx): + item = self.data[idx] + image_fn = self.image_folder / item['file_path'] + boxes = item['boxes'] + box_nb = item['box_nb'] + labels = torch.zeros((box_nb, 2), dtype=torch.int64) + labels[:, 0] = 1 + + img = np.array(Image.open(image_fn)) + + try: + if self.transforms: + sample = self.transforms(**{ + "image":img, + "bboxes": boxes, + "labels": labels, + }) + img = sample['image'] + boxes = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0) + except: + return self.__getitem__(randint(0, len(self.data)-1)) + + target_tensor = boxes_to_tensor(boxes.type(torch.float32)) + return img, target_tensor + + def __len__(self): + return len(self.data) + +def collate_fn(batch): + return tuple(zip(*batch)) + + +def get_datasets(data_path=ANNOTATIONS_PATH, train_transforms=None, val_transforms=None): + """ + Creates StampDataset objects. + + Arguments: + data_path -- string or Path, specifying path to annotations file + train_transforms -- transforms to be applied during training + val_transforms -- transforms to be applied during validation + + Returns: + (train_dataset, val_dataset) -- tuple of StampDataset for training and validation + """ + data = read_data(data_path) + if train_transforms is None: + train_transforms = A.Compose([ + A.RandomCropNearBBox(max_part_shift=0.6, p=0.4), + A.Resize(height=448, width=448), + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.5), + # A.Affine(scale=(0.9, 1.1), translate_percent=(0.05, 0.1), rotate=(-45, 45), shear=(-30, 30), p=0.3), + # A.Blur(blur_limit=4, p=0.3), + A.Normalize(), + ToTensorV2(p=1.0), + ], + bbox_params={ + "format":"coco", + 'label_fields': ['labels'] + }) + + if val_transforms is None: + val_transforms = A.Compose([ + A.Resize(height=448, width=448), + A.Normalize(), + ToTensorV2(p=1.0), + ], + bbox_params={ + "format":"coco", + 'label_fields': ['labels'] + }) + train, test_data = train_test_split(data, test_size=0.1, shuffle=True) + + train_data, val_data = train_test_split(train, test_size=0.2, shuffle=True) + + train_dataset = StampDataset(train_data, transforms=train_transforms) + val_dataset = StampDataset(val_data, transforms=val_transforms) + test_dataset = StampDataset(test_data, transforms=val_transforms) + + return train_dataset, val_dataset, test_dataset + + +def get_loaders(batch_size=8, data_path=ANNOTATIONS_PATH, num_workers=0, train_transforms=None, val_transforms=None): + """ + Creates StampDataset objects. 
+ + Arguments: + batch_size -- integer specifying the number of images in the batch + data_path -- string or Path, specifying path to annotations file + train_transforms -- transforms to be applied during training + val_transforms -- transforms to be applied during validation + + Returns: + (train_loader, val_loader) -- tuple of DataLoader for training and validation + """ + train_dataset, val_dataset, _ = get_datasets(data_path) + + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + collate_fn=collate_fn, drop_last=True) + + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + collate_fn=collate_fn) + + return train_loader, val_loader diff --git a/detection_models/yolo_stamp/loss.py b/detection_models/yolo_stamp/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2a85aabbcc007e400db02f505f618231c57609ac --- /dev/null +++ b/detection_models/yolo_stamp/loss.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from utils import * + +""" + Class for loss for training YOLO model. + + Argmunets: + h_coord: weight for loss related to coordinates and shapes of box + h__noobj: weight for loss of predicting presence of box when it is absent. +""" +class YOLOLoss(nn.Module): + def __init__(self, h_coord=0.5, h_noobj=2., h_shape=2., h_obj=10.): + super().__init__() + self.h_coord = h_coord + self.h_noobj = h_noobj + self.h_shape = h_shape + self.h_obj = h_obj + + def square_error(self, output, target): + return (output - target) ** 2 + + def forward(self, output, target): + + pred_xy, pred_wh, pred_obj = yolo_head(output) + gt_xy, gt_wh, gt_obj = process_target(target) + + pred_ul = pred_xy - 0.5 * pred_wh + pred_br = pred_xy + 0.5 * pred_wh + pred_area = pred_wh[..., 0] * pred_wh[..., 1] + + gt_ul = gt_xy - 0.5 * gt_wh + gt_br = gt_xy + 0.5 * gt_wh + gt_area = gt_wh[..., 0] * gt_wh[..., 1] + + intersect_ul = torch.max(pred_ul, gt_ul) + intersect_br = torch.min(pred_br, gt_br) + intersect_wh = intersect_br - intersect_ul + intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] + + iou = intersect_area / (pred_area + gt_area - intersect_area) + max_iou = torch.max(iou, dim=3, keepdim=True)[0] + best_box_index = torch.unsqueeze(torch.eq(iou, max_iou).float(), dim=-1) + gt_box_conf = best_box_index * gt_obj + + xy_loss = (self.square_error(pred_xy, gt_xy) * gt_box_conf).sum() + wh_loss = (self.square_error(pred_wh, gt_wh) * gt_box_conf).sum() + obj_loss = (self.square_error(pred_obj, gt_obj) * gt_box_conf).sum() + noobj_loss = (self.square_error(pred_obj, gt_obj) * (1 - gt_box_conf)).sum() + + total_loss = self.h_coord * xy_loss + self.h_shape * wh_loss + self.h_obj * obj_loss + self.h_noobj * noobj_loss + return total_loss \ No newline at end of file diff --git a/detection_models/yolo_stamp/model.py b/detection_models/yolo_stamp/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2e8728f446f7abe0eadbf6ff595e19024e5c6d --- /dev/null +++ b/detection_models/yolo_stamp/model.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn + +from .constants import * + +""" + Class for custom activation. 
+""" +class SymReLU(nn.Module): + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace + + def forward(self, input): + return torch.min(torch.max(input, -torch.ones_like(input)), torch.ones_like(input)) + + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str + + +""" + Class implementing YOLO-Stamp architecture described in https://link.springer.com/article/10.1134/S1054661822040046. +""" +class YOLOStamp(nn.Module): + def __init__( + self, + anchors=ANCHORS, + in_channels=3, + ): + super().__init__() + + self.register_buffer('anchors', torch.tensor(anchors)) + + self.act = SymReLU() + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm1 = nn.BatchNorm2d(num_features=8) + self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm2 = nn.BatchNorm2d(num_features=16) + self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm3 = nn.BatchNorm2d(num_features=16) + self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm4 = nn.BatchNorm2d(num_features=16) + self.conv5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm5 = nn.BatchNorm2d(num_features=16) + self.conv6 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm6 = nn.BatchNorm2d(num_features=24) + self.conv7 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm7 = nn.BatchNorm2d(num_features=24) + self.conv8 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm8 = nn.BatchNorm2d(num_features=48) + self.conv9 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm9 = nn.BatchNorm2d(num_features=48) + self.conv10 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm10 = nn.BatchNorm2d(num_features=48) + self.conv11 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.norm11 = nn.BatchNorm2d(num_features=64) + self.conv12 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)) + self.norm12 = nn.BatchNorm2d(num_features=256) + self.conv13 = nn.Conv2d(in_channels=256, out_channels=len(anchors) * 5, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)) + + def forward(self, x, head=True): + x = x.type(self.conv1.weight.dtype) + x = self.act(self.pool(self.norm1(self.conv1(x)))) + x = self.act(self.pool(self.norm2(self.conv2(x)))) + x = self.act(self.pool(self.norm3(self.conv3(x)))) + x = self.act(self.pool(self.norm4(self.conv4(x)))) + x = self.act(self.pool(self.norm5(self.conv5(x)))) + x = self.act(self.norm6(self.conv6(x))) + x = self.act(self.norm7(self.conv7(x))) + x = self.act(self.pool(self.norm8(self.conv8(x)))) + x = self.act(self.norm9(self.conv9(x))) + x = self.act(self.norm10(self.conv10(x))) + x = self.act(self.norm11(self.conv11(x))) + x = self.act(self.norm12(self.conv12(x))) + x = self.conv13(x) + nb, _, nh, nw= x.shape + x = x.permute(0, 2, 3, 1).view(nb, nh, nw, self.anchors.shape[0], 5) + return x diff --git 
a/detection_models/yolo_stamp/train.ipynb b/detection_models/yolo_stamp/train.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..78ebbd291106f5070058048affd31dd4c55ef0b4 --- /dev/null +++ b/detection_models/yolo_stamp/train.ipynb @@ -0,0 +1,185 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from model import *\n", + "from loss import *\n", + "from data import *\n", + "from torch import optim\n", + "from tqdm import tqdm\n", + "\n", + "import pytorch_lightning as pl\n", + "from torchmetrics.detection import MeanAveragePrecision\n", + "from pytorch_lightning.loggers import TensorBoardLogger" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "_, _, test_dataset = get_datasets()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class LitModel(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.model = YOLOStamp()\n", + " self.criterion = YOLOLoss()\n", + " self.val_map = MeanAveragePrecision(box_format='xywh', iou_type='bbox')\n", + " \n", + " def forward(self, x):\n", + " return self.model(x)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = optim.AdamW(self.parameters(), lr=1e-3)\n", + " # return optimizer\n", + " scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n", + " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " images, targets = batch\n", + " tensor_images = torch.stack(images)\n", + " tensor_targets = torch.stack(targets)\n", + " output = self.model(tensor_images)\n", + " loss = self.criterion(output, tensor_targets)\n", + " self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " images, targets = batch\n", + " tensor_images = torch.stack(images)\n", + " tensor_targets = torch.stack(targets)\n", + " output = self.model(tensor_images)\n", + " loss = self.criterion(output, tensor_targets)\n", + " self.log(\"val_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "\n", + " for i in range(len(images)):\n", + " boxes = output_tensor_to_boxes(output[i].detach().cpu())\n", + " boxes = nonmax_suppression(boxes)\n", + " target = target_tensor_to_boxes(targets[i])[::BOX]\n", + " if not boxes:\n", + " boxes = torch.zeros((1, 5))\n", + " preds = [\n", + " dict(\n", + " boxes=torch.tensor(boxes)[:, :4].clone().detach(),\n", + " scores=torch.tensor(boxes)[:, 4].clone().detach(),\n", + " labels=torch.zeros(len(boxes)),\n", + " )\n", + " ]\n", + " target = [\n", + " dict(\n", + " boxes=torch.tensor(target),\n", + " labels=torch.zeros(len(target)),\n", + " )\n", + " ]\n", + " self.val_map.update(preds, target)\n", + " \n", + " def on_validation_epoch_end(self):\n", + " mAPs = {\"val_\" + k: v for k, v in self.val_map.compute().items()}\n", + " mAPs_per_class = mAPs.pop(\"val_map_per_class\")\n", + " mARs_per_class = mAPs.pop(\"val_mar_100_per_class\")\n", + " self.log_dict(mAPs)\n", + " self.val_map.reset()\n", + "\n", + " image = test_dataset[randint(0, len(test_dataset) - 1)][0].to(self.device)\n", + " output = self.model(image.unsqueeze(0))\n", + " boxes = output_tensor_to_boxes(output[0].detach().cpu())\n", + " boxes = nonmax_suppression(boxes)\n", + " img = image.permute(1, 2, 
0).cpu().numpy()\n", + " img = visualize_bbox(img.copy(), boxes=boxes)\n", + " img = (255. * (img * np.array(STD) + np.array(MEAN))).astype(np.uint8)\n", + " \n", + " self.logger.experiment.add_image(\"detected boxes\", torch.tensor(img).permute(2, 0, 1), self.current_epoch)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "litmodel = LitModel()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "logger = TensorBoardLogger(\"detection_logs\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "train_loader, val_loader = get_loaders(batch_size=8)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n", + "trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%tensorboard" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/detection_models/yolo_stamp/utils.py b/detection_models/yolo_stamp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a46bd4152aec02c2fc8f5229ec2970a62cd5379 --- /dev/null +++ b/detection_models/yolo_stamp/utils.py @@ -0,0 +1,275 @@ +import torch +import cv2 +import pandas as pd +import numpy as np +from pathlib import Path +import matplotlib.pyplot as plt +from .constants import * + + +def output_tensor_to_boxes(boxes_tensor): + """ + Converts the YOLO output tensor to list of boxes with probabilites. + + Arguments: + boxes_tensor -- tensor of shape (S, S, BOX, 5) + + Returns: + boxes -- list of shape (None, 5) + + Note: "None" is here because you don't know the exact number of selected boxes, as it depends on the threshold. + For example, the actual output size of scores would be (10, 5) if there are 10 boxes + """ + cell_w, cell_h = W/S, H/S + boxes = [] + + for i in range(S): + for j in range(S): + for b in range(BOX): + anchor_wh = torch.tensor(ANCHORS[b]) + data = boxes_tensor[i,j,b] + xy = torch.sigmoid(data[:2]) + wh = torch.exp(data[2:4])*anchor_wh + obj_prob = torch.sigmoid(data[4]) + + if obj_prob > OUTPUT_THRESH: + x_center, y_center, w, h = xy[0], xy[1], wh[0], wh[1] + x, y = x_center+j-w/2, y_center+i-h/2 + x,y,w,h = x*cell_w, y*cell_h, w*cell_w, h*cell_h + box = [x,y,w,h, obj_prob] + boxes.append(box) + return boxes + + +def plot_img(img, size=(7,7)): + plt.figure(figsize=size) + plt.imshow(img) + plt.show() + + +def plot_normalized_img(img, std=STD, mean=MEAN, size=(7,7)): + mean = mean if isinstance(mean, np.ndarray) else np.array(mean) + std = std if isinstance(std, np.ndarray) else np.array(std) + plt.figure(figsize=size) + plt.imshow((255. 
* (img * std + mean)).astype(np.uint)) + plt.show() + + +def visualize_bbox(img, boxes, thickness=2, color=BOX_COLOR, draw_center=True): + """ + Draws boxes on the given image. + + Arguments: + img -- torch.Tensor of shape (3, W, H) or numpy.ndarray of shape (W, H, 3) + boxes -- list of shape (None, 5) + thickness -- number specifying the thickness of box border + color -- RGB tuple of shape (3,) specifying the color of boxes + draw_center -- boolean specifying whether to draw center or not + + Returns: + img_copy -- numpy.ndarray of shape(W, H, 3) containing image with bouning boxes + """ + img_copy = img.cpu().permute(1,2,0).numpy() if isinstance(img, torch.Tensor) else img.copy() + for box in boxes: + x,y,w,h = int(box[0]), int(box[1]), int(box[2]), int(box[3]) + img_copy = cv2.rectangle( + img_copy, + (x,y),(x+w, y+h), + color, thickness) + if draw_center: + center = (x+w//2, y+h//2) + img_copy = cv2.circle(img_copy, center=center, radius=3, color=(0,255,0), thickness=2) + return img_copy + + +def read_data(annotations=Path(ANNOTATIONS_PATH)): + """ + Reads annotations data from .csv file. Must contain columns: image_name, bbox_x, bbox_y, bbox_width, bbox_height. + + Arguments: + annotations_path -- string or Path specifying path of annotations file + + Returns: + data -- list of dictionaries containing path, number of boxes and boxes itself + """ + data = [] + + boxes = pd.read_csv(annotations) + image_names = boxes['image_name'].unique() + + for image_name in image_names: + cur_boxes = boxes[boxes['image_name'] == image_name] + img_data = { + 'file_path': image_name, + 'box_nb': len(cur_boxes), + 'boxes': []} + stamp_nb = img_data['box_nb'] + if stamp_nb <= STAMP_NB_MAX: + img_data['boxes'] = cur_boxes[['bbox_x', 'bbox_y','bbox_width','bbox_height']].values + data.append(img_data) + return data + +def xywh2xyxy(x): + """ + Converts xywh format to xyxy + + Arguments: + x -- torch.Tensor or np.array (xywh format) + + Returns: + y -- torch.Tensor or np.array (xyxy) + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] + y[..., 1] = x[..., 1] + y[..., 2] = x[..., 0] + x[..., 2] + y[..., 3] = x[..., 1] + x[..., 3] + return y + +def boxes_to_tensor(boxes): + """ + Convert list of boxes (and labels) to tensor format + + Arguments: + boxes -- list of boxes + + Returns: + boxes_tensor -- tensor of shape (S, S, BOX, 5) + """ + boxes_tensor = torch.zeros((S, S, BOX, 5)) + cell_w, cell_h = W/S, H/S + for i, box in enumerate(boxes): + x, y, w, h = box + # normalize xywh with cell_size + x, y, w, h = x / cell_w, y / cell_h, w / cell_w, h / cell_h + center_x, center_y = x + w / 2, y + h / 2 + grid_x = int(np.floor(center_x)) + grid_y = int(np.floor(center_y)) + + if grid_x < S and grid_y < S: + boxes_tensor[grid_y, grid_x, :, 0:4] = torch.tensor(BOX * [[center_x - grid_x, center_y - grid_y, w, h]]) + boxes_tensor[grid_y, grid_x, :, 4] = torch.tensor(BOX * [1.]) + return boxes_tensor + + +def target_tensor_to_boxes(boxes_tensor, output_threshold=OUTPUT_THRESH): + """ + Recover target tensor (tensor output of dataset) to bboxes. 
+ Arguments: + boxes_tensor -- tensor of shape (S, S, BOX, 5) + Returns: + boxes -- list of boxes, each box is [x, y, w, h] + """ + cell_w, cell_h = W/S, H/S + boxes = [] + for i in range(S): + for j in range(S): + for b in range(BOX): + data = boxes_tensor[i,j,b] + x_center,y_center, w, h, obj_prob = data[0], data[1], data[2], data[3], data[4] + if obj_prob > output_threshold: + x, y = x_center+j-w/2, y_center+i-h/2 + x,y,w,h = x*cell_w, y*cell_h, w*cell_w, h*cell_h + box = [x,y,w,h] + boxes.append(box) + return boxes + + +def overlap(interval_1, interval_2): + """ + Calculates length of overlap between two intervals. + + Arguments: + interval_1 -- list or tuple of shape (2,) containing endpoints of the first interval + interval_2 -- list or tuple of shape (2,) containing endpoints of the second interval + + Returns: + overlap -- length of overlap + """ + x1, x2 = interval_1 + x3, x4 = interval_2 + if x3 < x1: + if x4 < x1: + return 0 + else: + return min(x2,x4) - x1 + else: + if x2 < x3: + return 0 + else: + return min(x2,x4) - x3 + + +def compute_iou(box1, box2): + """ + Compute IOU between box1 and box2. + + Arguments: + box1 -- list of shape (5, ). Represents the first box + box2 -- list of shape (5, ). Represents the second box + Each box is [x, y, w, h, prob] + + Returns: + iou -- intersection over union score between two boxes + """ + x1,y1,w1,h1 = box1[0], box1[1], box1[2], box1[3] + x2,y2,w2,h2 = box2[0], box2[1], box2[2], box2[3] + + area1, area2 = w1*h1, w2*h2 + intersect_w = overlap((x1,x1+w1), (x2,x2+w2)) + intersect_h = overlap((y1,y1+h1), (y2,y2+h2)) + if intersect_w == w1 and intersect_h == h1 or intersect_w == w2 and intersect_h == h2: + return 1. + intersect_area = intersect_w*intersect_h + iou = intersect_area/(area1 + area2 - intersect_area) + return iou + + +def nonmax_suppression(boxes, iou_thresh = IOU_THRESH): + """ + Removes overlapping bboxes + + Arguments: + boxes -- list of shape (None, 5) + iou_thresh -- maximal value of iou when boxes are considered different + Each box is [x, y, w, h, prob] + + Returns: + boxes -- list of shape (None, 5) with removed overlapping boxes + """ + boxes = sorted(boxes, key=lambda x: x[4], reverse=True) + for i, current_box in enumerate(boxes): + if current_box[4] <= 0: + continue + for j in range(i+1, len(boxes)): + iou = compute_iou(current_box, boxes[j]) + if iou > iou_thresh: + boxes[j][4] = 0 + boxes = [box for box in boxes if box[4] > 0] + return boxes + + + +def yolo_head(yolo_output): + """ + Converts a yolo output tensor to separate tensors of coordinates, shapes and probabilities.
+ + Arguments: + yolo_output -- tensor of shape (batch_size, S, S, BOX, 5) + + Returns: + xy -- tensor of shape (batch_size, S, S, BOX, 2) containing coordinates of centers of found boxes for each anchor in each grid cell + wh -- tensor of shape (batch_size, S, S, BOX, 2) containing width and height of found boxes for each anchor in each grid cell + prob -- tensor of shape (batch_size, S, S, BOX, 1) containing the probability of presence of boxes for each anchor in each grid cell + """ + xy = torch.sigmoid(yolo_output[..., 0:2]) + anchors_wh = torch.tensor(ANCHORS, device=yolo_output.device).view(1, 1, 1, len(ANCHORS), 2) + wh = torch.exp(yolo_output[..., 2:4]) * anchors_wh + prob = torch.sigmoid(yolo_output[..., 4:5]) + return xy, wh, prob + +def process_target(target): + xy = target[..., 0:2] + wh = target[..., 2:4] + prob = target[..., 4:5] + return xy, wh, prob \ No newline at end of file diff --git a/detection_models/yolov8/__init__.py b/detection_models/yolov8/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/detection_models/yolov8/train.ipynb b/detection_models/yolov8/train.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..79f6774e2bba15bbbf8a57f9e4026579051c22f4 --- /dev/null +++ b/detection_models/yolov8/train.ipynb @@ -0,0 +1,144 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "HOME = os.getcwd()\n", + "print(HOME)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Pip install method (recommended)\n", + "\n", + "%pip install ultralytics==8.0.20\n", + "\n", + "from IPython import display\n", + "display.clear_output()\n", + "\n", + "import ultralytics\n", + "ultralytics.checks()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ultralytics import YOLO\n", + "\n", + "from IPython.display import display, Image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir {HOME}/datasets\n", + "%cd {HOME}/datasets\n", + "\n", + "%pip install roboflow --quiet\n", + "\n", + "from roboflow import Roboflow\n", + "rf = Roboflow(api_key=\"YOUR_API_KEY\")\n", + "project = rf.workspace(\"WORKSPACE\").project(\"PROJECT\")\n", + "dataset = project.version(1).download(\"yolov8\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd {HOME}\n", + "\n", + "!yolo task=detect mode=train model=yolov8s.pt data={dataset.location}/data.yaml epochs=25 imgsz=800 plots=True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd {HOME}\n", + "Image(filename=f'{HOME}/runs/detect/train/confusion_matrix.png', width=600)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd {HOME}\n", + "Image(filename=f'{HOME}/runs/detect/train/results.png', width=600)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd {HOME}\n", + "Image(filename=f'{HOME}/runs/detect/train/val_batch0_pred.jpg', width=600)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd {HOME}\n", + "\n", + "!yolo task=detect 
mode=val model={HOME}/runs/detect/train/weights/best.pt data={dataset.location}/data.yaml" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd {HOME}\n", + "!yolo task=detect mode=predict model={HOME}/runs/detect/train/weights/best.pt conf=0.25 source={dataset.location}/test/images save=True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "from IPython.display import Image, display\n", + "\n", + "for image_path in glob.glob(f'{HOME}/runs/detect/predict3/*.jpg')[:3]:\n", + " display(Image(filename=image_path, width=600))\n", + " print(\"\\n\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/embedding_models/__init__.py b/embedding_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/embedding_models/__pycache__/__init__.cpython-39.pyc b/embedding_models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1a5d3d462e34baebb508ddad602ac6c0227cc2d Binary files /dev/null and b/embedding_models/__pycache__/__init__.cpython-39.pyc differ diff --git a/embedding_models/vae/__init__.py b/embedding_models/vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/embedding_models/vae/__pycache__/__init__.cpython-39.pyc b/embedding_models/vae/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c489f3717b2efe7490c72200e8e0926203685cf3 Binary files /dev/null and b/embedding_models/vae/__pycache__/__init__.cpython-39.pyc differ diff --git a/embedding_models/vae/__pycache__/constants.cpython-39.pyc b/embedding_models/vae/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9caddca7993ca090f61a751bde15bf092d93b6b3 Binary files /dev/null and b/embedding_models/vae/__pycache__/constants.cpython-39.pyc differ diff --git a/embedding_models/vae/__pycache__/model.cpython-39.pyc b/embedding_models/vae/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eae12c58b2af03bde3e9cd44bdcca6d0cef37de6 Binary files /dev/null and b/embedding_models/vae/__pycache__/model.cpython-39.pyc differ diff --git a/embedding_models/vae/constants.py b/embedding_models/vae/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..bda9b9ed2f8d5d46ff9072d8f8ae5b9f94c923cf --- /dev/null +++ b/embedding_models/vae/constants.py @@ -0,0 +1,6 @@ +# dimenstion of image embedding +Z_DIM = 128 +# hidden dimensions for encoder model +ENC_HIDDEN_DIM = 16 +# hidden dimensions for decoder model +DEC_HIDDEN_DIM = 64 \ No newline at end of file diff --git a/embedding_models/vae/losses.py b/embedding_models/vae/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..4c73941b1d9d4a3c551ea7d116a0082fe602d94f --- /dev/null +++ b/embedding_models/vae/losses.py @@ -0,0 +1,77 @@ +import torch +import torch.nn as nn +from torch.distributions.kl import kl_divergence +from torch.distributions.normal import Normal +from torch.nn.functional import relu + + + +class BatchHardTripletLoss(nn.Module): + def __init__(self, margin=1., squared=False, agg='sum'): + """ + Initalize the loss function 
with a margin parameter, whether or not to consider + squared Euclidean distance and how to aggregate the loss in a batch + """ + super().__init__() + self.margin = margin + self.squared = squared + self.agg = agg + self.eps = 1e-8 + + def get_pairwise_distances(self, embeddings): + """ + Computing Euclidean distance for all possible pairs of embeddings. + """ + ab = embeddings.mm(embeddings.t()) + a_squared = ab.diag().unsqueeze(1) + b_squared = ab.diag().unsqueeze(0) + distances = a_squared - 2 * ab + b_squared + distances = relu(distances) + + if not self.squared: + distances = torch.sqrt(distances + self.eps) + + return distances + + def hardest_triplet_mining(self, dist_mat, labels): + + assert len(dist_mat.size()) == 2 + assert dist_mat.size(0) == dist_mat.size(1) + N = dist_mat.size(0) + + is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) + is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) + + dist_ap, relative_p_inds = torch.max( + (dist_mat * is_pos), 1, keepdim=True) + + dist_an, relative_n_inds = torch.min( + (dist_mat * is_neg), 1, keepdim=True) + + return dist_ap, dist_an + + def forward(self, embeddings, labels): + + distances = self.get_pairwise_distances(embeddings) + dist_ap, dist_an = self.hardest_triplet_mining(distances, labels) + + triplet_loss = relu(dist_ap - dist_an + self.margin).sum() + return triplet_loss + + +class VAELoss(nn.Module): + def __init__(self): + super().__init__() + self.reconstruction_loss = nn.BCELoss(reduction='sum') + + def kl_divergence_loss(self, q_dist): + return kl_divergence( + q_dist, Normal(torch.zeros_like(q_dist.mean), torch.ones_like(q_dist.stddev)) + ).sum(-1) + + + def forward(self, output, target, encoding): + loss = self.kl_divergence_loss(encoding).sum() + self.reconstruction_loss(output, target) + return loss + + diff --git a/embedding_models/vae/model.py b/embedding_models/vae/model.py new file mode 100644 index 0000000000000000000000000000000000000000..965c8815f95aee365aab639962385bd147cf5d96 --- /dev/null +++ b/embedding_models/vae/model.py @@ -0,0 +1,147 @@ +import torch.nn as nn +from torch.distributions.normal import Normal + +from .constants import * + + +class Encoder(nn.Module): + ''' + Encoder Class + Values: + im_chan: the number of channels of the output image, a scalar + hidden_dim: the inner dimension, a scalar + ''' + + def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM): + super(Encoder, self).__init__() + self.z_dim = output_chan + self.disc = nn.Sequential( + self.make_disc_block(im_chan, hidden_dim), + self.make_disc_block(hidden_dim, hidden_dim * 2), + self.make_disc_block(hidden_dim * 2, hidden_dim * 4), + self.make_disc_block(hidden_dim * 4, hidden_dim * 8), + self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True), + ) + + def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False): + ''' + Function to return a sequence of operations corresponding to a encoder block of the VAE, + corresponding to a convolution, a batchnorm (except for in the last layer), and an activation + Parameters: + input_channels: how many channels the input feature representation has + output_channels: how many channels the output feature representation should have + kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size) + stride: the stride of the convolution + final_layer: whether we're on the final layer (affects activation and batchnorm) + ''' + if not final_layer: + return nn.Sequential( + 
nn.Conv2d(input_channels, output_channels, kernel_size, stride), + nn.BatchNorm2d(output_channels), + nn.LeakyReLU(0.2, inplace=True), + ) + else: + return nn.Sequential( + nn.Conv2d(input_channels, output_channels, kernel_size, stride), + ) + + def forward(self, image): + ''' + Function for completing a forward pass of the Encoder: Given an image tensor, + returns a 1-dimension tensor representing fake/real. + Parameters: + image: a flattened image tensor with dimension (im_dim) + ''' + disc_pred = self.disc(image) + encoding = disc_pred.view(len(disc_pred), -1) + # The stddev output is treated as the log of the variance of the normal + # distribution by convention and for numerical stability + return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp() + + +class Decoder(nn.Module): + ''' + Decoder Class + Values: + z_dim: the dimension of the noise vector, a scalar + im_chan: the number of channels of the output image, a scalar + hidden_dim: the inner dimension, a scalar + ''' + + def __init__(self, z_dim=Z_DIM, im_chan=3, hidden_dim=DEC_HIDDEN_DIM): + super(Decoder, self).__init__() + self.z_dim = z_dim + self.gen = nn.Sequential( + self.make_gen_block(z_dim, hidden_dim * 16), + self.make_gen_block(hidden_dim * 16, hidden_dim * 8, kernel_size=4, stride=1), + self.make_gen_block(hidden_dim * 8, hidden_dim * 4), + self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4), + self.make_gen_block(hidden_dim * 2, hidden_dim, kernel_size=4), + self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True), + ) + + def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False): + ''' + Function to return a sequence of operations corresponding to a Decoder block of the VAE, + corresponding to a transposed convolution, a batchnorm (except for in the last layer), and an activation + Parameters: + input_channels: how many channels the input feature representation has + output_channels: how many channels the output feature representation should have + kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size) + stride: the stride of the convolution + final_layer: whether we're on the final layer (affects activation and batchnorm) + ''' + if not final_layer: + return nn.Sequential( + nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride), + nn.BatchNorm2d(output_channels), + nn.ReLU(inplace=True), + ) + else: + return nn.Sequential( + nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride), + nn.Sigmoid(), + ) + + def forward(self, noise): + ''' + Function for completing a forward pass of the Decoder: Given a noise vector, + returns a generated image. + Parameters: + noise: a noise tensor with dimensions (batch_size, z_dim) + ''' + x = noise.view(len(noise), self.z_dim, 1, 1) + return self.gen(x) + + +class VAE(nn.Module): + ''' + VAE Class + Values: + z_dim: the dimension of the noise vector, a scalar + im_chan: the number of channels of the output image, a scalar + MNIST is black-and-white, so that's our default + hidden_dim: the inner dimension, a scalar + ''' + + def __init__(self, z_dim=Z_DIM, im_chan=3): + super(VAE, self).__init__() + self.z_dim = z_dim + self.encode = Encoder(im_chan, z_dim) + self.decode = Decoder(z_dim, im_chan) + + def forward(self, images): + ''' + Function for completing a forward pass of the Decoder: Given a noise vector, + returns a generated image. 
+ Parameters: + images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width) + Returns: + decoding: the autoencoded image + q_dist: the z-distribution of the encoding + ''' + q_mean, q_stddev = self.encode(images) + q_dist = Normal(q_mean, q_stddev) + z_sample = q_dist.rsample() # Sample once from each distribution, using the `rsample` notation + decoding = self.decode(z_sample) + return decoding, q_dist \ No newline at end of file diff --git a/embedding_models/vae/train.ipynb b/embedding_models/vae/train.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ec47f49927f8fb39c80088e3f0640be1456905fa --- /dev/null +++ b/embedding_models/vae/train.ipynb @@ -0,0 +1,393 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import numpy as np\n", + "\n", + "from pathlib import Path\n", + "import os\n", + "from PIL import Image\n", + "\n", + "from model import VAE\n", + "from losses import *" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader, Dataset\n", + "from torchvision import transforms\n", + "import pandas as pd\n", + "import re\n", + "from sklearn.model_selection import train_test_split" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "IMAGE_FOLDER = './data/images/'" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "image_names = os.listdir(IMAGE_FOLDER)\n", + "data = pd.DataFrame({'image_name': image_names})\n", + "data['label'] = data['image_name'].apply(lambda x: int(re.match('^\\d+', x)[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class StampDataset(Dataset):\n", + " def __init__(self, data, image_folder=Path(IMAGE_FOLDER), transform=None):\n", + " super().__init__()\n", + " self.image_folder = image_folder\n", + " self.data = data\n", + " self.transform = transform\n", + "\n", + " def __getitem__(self, idx):\n", + " image = Image.open(self.image_folder / self.data.iloc[idx]['image_name'])\n", + " label = self.data.iloc[idx]['label']\n", + " if self.transform:\n", + " image = self.transform(image)\n", + "\n", + " return image, label\n", + "\n", + " \n", + " def __len__(self):\n", + " return len(self.data)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "train_data, val_data = train_test_split(data, test_size=0.3, shuffle=True, stratify=data['label'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_transform = transforms.Compose([\n", + " transforms.Resize((118, 118)),\n", + " transforms.RandomHorizontalFlip(0.5),\n", + " transforms.RandomVerticalFlip(0.5),\n", + " transforms.ToTensor(),\n", + " # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n", + "])\n", + "\n", + "val_transform = transforms.Compose([\n", + " transforms.Resize((118, 118)),\n", + " transforms.ToTensor(),\n", + " # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n", + "])\n", + "train_dataset = StampDataset(train_data, transform=train_transform)\n", + "val_dataset = StampDataset(val_data, transform=val_transform)\n", + "\n", + 
"train_loader = DataLoader(train_dataset, shuffle=True, batch_size=256)\n", + "val_loader = DataLoader(val_dataset, shuffle=True, batch_size=256)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "from torch import optim\n", + "from pytorch_lightning.loggers import TensorBoardLogger\n", + "\n", + "from torchvision.utils import make_grid" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "MEAN = torch.tensor((0.76302232, 0.77820438, 0.81879729)).view(3, 1, 1)\n", + "STD = torch.tensor((0.16563211, 0.14949341, 0.1055889)).view(3, 1, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "class LitModel(pl.LightningModule):\n", + " def __init__(self, alpha=1e-3):\n", + " super().__init__()\n", + " self.vae = VAE()\n", + " self.vae_loss = VAELoss()\n", + " self.triplet_loss = BatchHardTripletLoss(margin=1.)\n", + " self.alpha = alpha\n", + " \n", + " def forward(self, x):\n", + " return self.vae(x)\n", + " \n", + " def configure_optimizers(self):\n", + " optimizer = optim.AdamW(self.parameters(), lr=3e-4)\n", + " return optimizer\n", + " # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n", + " # return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " images, labels = batch\n", + " labels = labels.unsqueeze(1)\n", + " recon_images, encoding = self.vae(images)\n", + " vae_loss = self.vae_loss(recon_images, images, encoding)\n", + " self.log(\"train_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n", + " triplet_loss = self.triplet_loss(encoding.mean, labels)\n", + " self.log(\"train_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n", + " loss = self.alpha * triplet_loss + vae_loss\n", + " self.log(\"train_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " images, labels = batch\n", + " labels = labels.unsqueeze(1)\n", + " recon_images, encoding = self.vae(images)\n", + " vae_loss = self.vae_loss(recon_images, images, encoding)\n", + " self.log(\"val_vae_loss\", vae_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n", + " triplet_loss = self.triplet_loss(encoding.mean, labels)\n", + " self.log(\"val_triplet_loss\", triplet_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)\n", + " loss = self.alpha * triplet_loss + vae_loss\n", + " self.log(\"val_loss\", loss, on_epoch=True, prog_bar=True, logger=True)\n", + " return loss\n", + "\n", + " def on_validation_epoch_end(self):\n", + " images, _ = iter(val_loader).next()\n", + " image_unflat = images.detach().cpu()\n", + " image_grid = make_grid(image_unflat[:16], nrow=4)\n", + " self.logger.experiment.add_image('original images', image_grid, self.current_epoch)\n", + "\n", + " recon_images, _ = self.vae(images.to(self.device))\n", + " image_unflat = recon_images.detach().cpu()\n", + " image_grid = make_grid(image_unflat[:16], nrow=4)\n", + " self.logger.experiment.add_image('reconstructed images', image_grid, self.current_epoch)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "litmodel = LitModel()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + 
"outputs": [], + "source": [ + "logger = TensorBoardLogger(\"reconstruction_logs\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n", + "trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%tensorboard" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "device = 'cuda' if torch.cuda.is_available() else 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import hf_hub_download" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "emb_model = torch.jit.load(hf_hub_download(repo_id=\"stamps-labs/vits8-stamp\", filename=\"vits8stamp-torchscript.pth\")).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "val_transform = transforms.Compose([\n", + " transforms.Resize((224, 224)),\n", + " transforms.ToTensor(),\n", + " # transforms.Normalize((0.76302232, 0.77820438, 0.81879729), (0.16563211, 0.14949341, 0.1055889)),\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "train_data['embed'] = train_data['image_name'].apply(lambda x: emb_model(val_transform(Image.open(Path(IMAGE_FOLDER) / x)).unsqueeze(0).to(device))[0].tolist())" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:1: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n", + " embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n", + "C:\\Users\\javid\\AppData\\Local\\Temp\\ipykernel_23064\\1572292890.py:2: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. 
Use pandas.concat instead.\n", + " labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)\n" + ] + } + ], + "source": [ + "embeds = pd.DataFrame(train_data['embed'].tolist()).append(pd.DataFrame(val_data['embed'].tolist()), ignore_index=True)\n", + "labels = pd.DataFrame(train_data['label']).append(pd.DataFrame(val_data['label']), ignore_index=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "embeds.to_csv('./all_embeds.tsv', sep='\\t', index=False, header=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "labels.to_csv('./all_labels.tsv', sep='\\t', index=False, header=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 126, + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(litmodel.vae.encode.state_dict(), './models/encoder.pth')" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "metadata": {}, + "outputs": [], + "source": [ + "im = train_dataset[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 132, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = Encoder()\n", + "model.load_state_dict(torch.load('./models/encoder.pth'))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/embedding_models/vits8/__init__.py b/embedding_models/vits8/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/embedding_models/vits8/example.py b/embedding_models/vits8/example.py new file mode 100644 index 0000000000000000000000000000000000000000..779cdea9749f55aae321fffcfff501b0a07780d7 --- /dev/null +++ b/embedding_models/vits8/example.py @@ -0,0 +1,10 @@ +from PIL import Image +from model import ViTStamp +def get_embeddings(img_path: str): + model = ViTStamp() + image = Image.open(img_path) + embeddings = model(image=image) + return embeddings + +if __name__ == "__main__": + print(get_embeddings("oml/data/test/images/99d_15.bmp")) \ No newline at end of file diff --git a/embedding_models/vits8/model.py b/embedding_models/vits8/model.py new file mode 100644 index 0000000000000000000000000000000000000000..684bb59476db2b411c8d82da25427f57ff7fdb6e --- /dev/null +++ b/embedding_models/vits8/model.py @@ -0,0 +1,13 @@ +import torch +from torchvision import transforms +from huggingface_hub import hf_hub_download + +class ViTStamp(): + def __init__(self): + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = torch.jit.load(hf_hub_download(repo_id="stamps-labs/vits8-stamp", filename="vits8stamp-torchscript.pth")) + self.transform = transforms.ToTensor() + def __call__(self, image) -> torch.Tensor(): + img_tensor = self.transform(image).cuda().unsqueeze(0) if self.device == "cuda" else self.transform(image).unsqueeze(0) + features = self.model(img_tensor) + return features \ No newline at end of file diff --git 
a/embedding_models/vits8/oml/__init__.py b/embedding_models/vits8/oml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/embedding_models/vits8/oml/create_dataset.py b/embedding_models/vits8/oml/create_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc410156205c0d6ffec417ab5e8415e3a5ee28c --- /dev/null +++ b/embedding_models/vits8/oml/create_dataset.py @@ -0,0 +1,71 @@ +import os +from PIL import Image +import pandas as pd + +import argparse + +parser = argparse.ArgumentParser("Create a dataset for training with OML", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +parser.add_argument("--root-data-path", help="Path to images for dataset", default="data/train_val/") +parser.add_argument("--image-data-path", help="Image folder in root data path", default="images/") +parser.add_argument("--train-val-split", + help="In which ratio to split data in format train:val (For example 80:20)", default="80:20") +parser.add_argument("--separator", + help="What separator is used in image name to separate class name and instance (E.g. circle1_5, separator=_)", + default="_") + +args = parser.parse_args() +config = vars(args) + +root_path = config["root_data_path"] +image_path = config["image_data_path"] +separator = config["separator"] + +train_prc, val_prc = tuple(int(num)/100 for num in config["train_val_split"].split(":")) + +class_names = set() +for image in os.listdir(root_path+image_path): + if image.endswith(("png", "jpg", "bmp", "webp")): + img_name = image.split(".")[0] + Image.open(root_path+image_path+image).resize((224,224)).save(root_path+image_path+img_name+".png", "PNG") + if not image.endswith("png"): + os.remove(root_path+image_path+image) + img_name = img_name.split(separator) + class_name = img_name[0]+img_name[1] + class_names.add(class_name) + else: + print("Not all of the images are in supported format") + + +#For each class in set assign its index in a set as a class label. 
+class_label_dict = {} +for ind, name in enumerate(class_names): + class_label_dict[name] = ind + +class_count = len(class_names) +train_class_count = int(class_count*train_prc) +print(train_class_count) + +df_dict = {"label": [], + "path": [], + "split": [], + "is_query": [], + "is_gallery": []} +for image in os.listdir(root_path+image_path): + if image.endswith((".png", ".jpg", ".bmp", ".webp")): + img_name = image.split(".")[0].split(separator) + class_name = img_name[0]+img_name[1] + label = class_label_dict[class_name] + path = image_path+image + split = "train" if label <= train_class_count else "validation" + is_query, is_gallery = (1, 1) if split=="validation" else (None, None) + df_dict["label"].append(label) + df_dict["path"].append(path) + df_dict["split"].append(split) + df_dict["is_query"].append(is_query) + df_dict["is_gallery"].append(is_gallery) + +df = pd.DataFrame(df_dict) + +df.to_csv(root_path+"df_stamps.csv", index=False) \ No newline at end of file diff --git a/embedding_models/vits8/oml/data/test/images/99d_15.bmp b/embedding_models/vits8/oml/data/test/images/99d_15.bmp new file mode 100644 index 0000000000000000000000000000000000000000..fe1d3753b6701699f50374b0f2cd63f00a92ed2f Binary files /dev/null and b/embedding_models/vits8/oml/data/test/images/99d_15.bmp differ diff --git a/embedding_models/vits8/oml/data/test/images/99e_20.bmp b/embedding_models/vits8/oml/data/test/images/99e_20.bmp new file mode 100644 index 0000000000000000000000000000000000000000..61536e9ddb3853f55a8b5dd3954cfc00396a6e21 Binary files /dev/null and b/embedding_models/vits8/oml/data/test/images/99e_20.bmp differ diff --git a/embedding_models/vits8/oml/data/test/images/99f_25.bmp b/embedding_models/vits8/oml/data/test/images/99f_25.bmp new file mode 100644 index 0000000000000000000000000000000000000000..d821bcaf8f5c28ccc0e3b62ae5dbce413ccc76cf Binary files /dev/null and b/embedding_models/vits8/oml/data/test/images/99f_25.bmp differ diff --git a/embedding_models/vits8/oml/data/test/images/99g_30.bmp b/embedding_models/vits8/oml/data/test/images/99g_30.bmp new file mode 100644 index 0000000000000000000000000000000000000000..1743a97580629ee703e497a0270fcafbc32bfbe5 Binary files /dev/null and b/embedding_models/vits8/oml/data/test/images/99g_30.bmp differ diff --git a/embedding_models/vits8/oml/data/test/images/99h_35.bmp b/embedding_models/vits8/oml/data/test/images/99h_35.bmp new file mode 100644 index 0000000000000000000000000000000000000000..d5aab89cff54b1643925688b55809ad721fe2c09 Binary files /dev/null and b/embedding_models/vits8/oml/data/test/images/99h_35.bmp differ diff --git a/embedding_models/vits8/oml/data/test/images/99i_40.bmp b/embedding_models/vits8/oml/data/test/images/99i_40.bmp new file mode 100644 index 0000000000000000000000000000000000000000..228590d7d1738707d0d072e13986232589f518d3 Binary files /dev/null and b/embedding_models/vits8/oml/data/test/images/99i_40.bmp differ diff --git a/embedding_models/vits8/oml/data/train_val/df_stamps.csv b/embedding_models/vits8/oml/data/train_val/df_stamps.csv new file mode 100644 index 0000000000000000000000000000000000000000..a0e52c7119b61b9c72a74523832a9d2e9b6a4976 --- /dev/null +++ b/embedding_models/vits8/oml/data/train_val/df_stamps.csv @@ -0,0 +1,41 @@ +label,path,split,is_query,is_gallery +0,images/circle6_1239.png,train,, +8,images/triangle19_1242.png,train,, +21,images/rectangle11_1248.png,train,, +39,images/triangle10_1232.png,validation,1.0,1.0 +33,images/word14_1241.png,validation,1.0,1.0 
+38,images/word5_1233.png,validation,1.0,1.0 +15,images/circle19_1236.png,train,, +22,images/circle15_1244.png,train,, +32,images/circle21_1249.png,train,, +26,images/oval20_1242.png,train,, +6,images/oval5_1237.png,train,, +23,images/word9_1241.png,train,, +9,images/triangle22_1238.png,train,, +31,images/circle12_1239.png,train,, +11,images/word21_1231.png,train,, +4,images/oval2_1235.png,train,, +20,images/rectangle18_1246.png,train,, +12,images/circle24_1234.png,train,, +5,images/circle2_1249.png,train,, +37,images/word22_1238.png,validation,1.0,1.0 +34,images/triangle18_1247.png,validation,1.0,1.0 +1,images/oval7_1241.png,train,, +10,images/triangle13_1240.png,train,, +14,images/rectangle12_1236.png,train,, +36,images/circle8_1237.png,validation,1.0,1.0 +24,images/triangle9_1245.png,train,, +29,images/word23_1243.png,train,, +28,images/triangle11_1244.png,train,, +16,images/circle2_1246.png,train,, +30,images/circle3_1247.png,train,, +18,images/oval24_1248.png,train,, +2,images/oval12_1231.png,train,, +3,images/oval18_1234.png,train,, +25,images/rectangle11_1245.png,train,, +17,images/word9_1244.png,train,, +13,images/triangle14_1237.png,train,, +35,images/circle2_1233.png,validation,1.0,1.0 +7,images/word18_1239.png,train,, +19,images/rectangle13_1236.png,train,, +27,images/circle24_1246.png,train,, diff --git a/embedding_models/vits8/oml/data/train_val/images/circle12_1239.png b/embedding_models/vits8/oml/data/train_val/images/circle12_1239.png new file mode 100644 index 0000000000000000000000000000000000000000..2319cadef0042b28c5a420440b919d83f8d7d7a3 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle12_1239.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle15_1244.png b/embedding_models/vits8/oml/data/train_val/images/circle15_1244.png new file mode 100644 index 0000000000000000000000000000000000000000..8ac8e02cc01982fd538e4e75c42dd7c79bf7b752 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle15_1244.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle19_1236.png b/embedding_models/vits8/oml/data/train_val/images/circle19_1236.png new file mode 100644 index 0000000000000000000000000000000000000000..828506b9a8f1360b795fc6057d315ca7eea93381 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle19_1236.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle21_1249.png b/embedding_models/vits8/oml/data/train_val/images/circle21_1249.png new file mode 100644 index 0000000000000000000000000000000000000000..9bf3ddb90eb106bc7d096276eeeb7d11376344c3 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle21_1249.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle24_1234.png b/embedding_models/vits8/oml/data/train_val/images/circle24_1234.png new file mode 100644 index 0000000000000000000000000000000000000000..9bc8d64f59c2c8c5befe7a4be5ec44a91f4204e2 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle24_1234.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle24_1246.png b/embedding_models/vits8/oml/data/train_val/images/circle24_1246.png new file mode 100644 index 0000000000000000000000000000000000000000..61765f2e35ff5843f91d9668c7e0874f44f50e00 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle24_1246.png differ diff --git 
a/embedding_models/vits8/oml/data/train_val/images/circle2_1233.png b/embedding_models/vits8/oml/data/train_val/images/circle2_1233.png new file mode 100644 index 0000000000000000000000000000000000000000..227cf52e2c039b299c1a970233fa7dd62e9f4835 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle2_1233.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle2_1246.png b/embedding_models/vits8/oml/data/train_val/images/circle2_1246.png new file mode 100644 index 0000000000000000000000000000000000000000..496b76d892be7a4dc8b42ab483cd0061fd4bb800 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle2_1246.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle2_1249.png b/embedding_models/vits8/oml/data/train_val/images/circle2_1249.png new file mode 100644 index 0000000000000000000000000000000000000000..2cae10f7be87f1ab24dd88d094c9e9ee66f587aa Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle2_1249.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle3_1247.png b/embedding_models/vits8/oml/data/train_val/images/circle3_1247.png new file mode 100644 index 0000000000000000000000000000000000000000..6f876153b325b4950b4e981d865b890ed1b56ddc Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle3_1247.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle6_1239.png b/embedding_models/vits8/oml/data/train_val/images/circle6_1239.png new file mode 100644 index 0000000000000000000000000000000000000000..da8098ca176df7dd26672244e26594318055a3a1 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle6_1239.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/circle8_1237.png b/embedding_models/vits8/oml/data/train_val/images/circle8_1237.png new file mode 100644 index 0000000000000000000000000000000000000000..95cde0c30f4f28f5fc234559d732ce88171fa532 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/circle8_1237.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/oval12_1231.png b/embedding_models/vits8/oml/data/train_val/images/oval12_1231.png new file mode 100644 index 0000000000000000000000000000000000000000..70de05de54b3ab12dcbe83c91b6e9f0ea8dcc2f8 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/oval12_1231.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/oval18_1234.png b/embedding_models/vits8/oml/data/train_val/images/oval18_1234.png new file mode 100644 index 0000000000000000000000000000000000000000..beaf1e414a0666824fdd6b58365e1d595fbe89ac Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/oval18_1234.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/oval20_1242.png b/embedding_models/vits8/oml/data/train_val/images/oval20_1242.png new file mode 100644 index 0000000000000000000000000000000000000000..f0fcb2954c0e974a8698f7fda6fb2c6d02ac9c86 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/oval20_1242.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/oval24_1248.png b/embedding_models/vits8/oml/data/train_val/images/oval24_1248.png new file mode 100644 index 0000000000000000000000000000000000000000..cb87c10111ebae10ab5b877d046f634b98e37df8 Binary files /dev/null and 
b/embedding_models/vits8/oml/data/train_val/images/oval24_1248.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/oval2_1235.png b/embedding_models/vits8/oml/data/train_val/images/oval2_1235.png new file mode 100644 index 0000000000000000000000000000000000000000..8c0f53a03413b10e7b6b4190ba8cddf66b989c44 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/oval2_1235.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/oval5_1237.png b/embedding_models/vits8/oml/data/train_val/images/oval5_1237.png new file mode 100644 index 0000000000000000000000000000000000000000..c7d339d501d7826d6383a232af9fc11f0c675323 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/oval5_1237.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/oval7_1241.png b/embedding_models/vits8/oml/data/train_val/images/oval7_1241.png new file mode 100644 index 0000000000000000000000000000000000000000..7e381a540b52d1ae452c626705b931180412fd90 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/oval7_1241.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/rectangle11_1245.png b/embedding_models/vits8/oml/data/train_val/images/rectangle11_1245.png new file mode 100644 index 0000000000000000000000000000000000000000..5720461186314262c4a2ea0a23975ff23426366c Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/rectangle11_1245.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/rectangle11_1248.png b/embedding_models/vits8/oml/data/train_val/images/rectangle11_1248.png new file mode 100644 index 0000000000000000000000000000000000000000..6c9af8d675e02fbc3f42aadfd87cc1bb90892a96 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/rectangle11_1248.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/rectangle12_1236.png b/embedding_models/vits8/oml/data/train_val/images/rectangle12_1236.png new file mode 100644 index 0000000000000000000000000000000000000000..811ff0041f0643c146e43069d73fdc11b8ef29e8 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/rectangle12_1236.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/rectangle13_1236.png b/embedding_models/vits8/oml/data/train_val/images/rectangle13_1236.png new file mode 100644 index 0000000000000000000000000000000000000000..448beb6b6cdd6a581cfe83cf876a205ee6c3d07b Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/rectangle13_1236.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/rectangle18_1246.png b/embedding_models/vits8/oml/data/train_val/images/rectangle18_1246.png new file mode 100644 index 0000000000000000000000000000000000000000..3594a45b9697af43dfc58a3e06d4d5f3e8253d6f Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/rectangle18_1246.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/triangle10_1232.png b/embedding_models/vits8/oml/data/train_val/images/triangle10_1232.png new file mode 100644 index 0000000000000000000000000000000000000000..677a77d556f85128b7d0fe9a3bc6c1d5dbb800e3 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/triangle10_1232.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/triangle11_1244.png b/embedding_models/vits8/oml/data/train_val/images/triangle11_1244.png new file mode 100644 index 
0000000000000000000000000000000000000000..02ec882a8eedd044e8929f7f12793f3d190f5320 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/triangle11_1244.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/triangle13_1240.png b/embedding_models/vits8/oml/data/train_val/images/triangle13_1240.png new file mode 100644 index 0000000000000000000000000000000000000000..b8b95c1aca3303c2624beb16ba77d5547660ac1e Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/triangle13_1240.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/triangle14_1237.png b/embedding_models/vits8/oml/data/train_val/images/triangle14_1237.png new file mode 100644 index 0000000000000000000000000000000000000000..3b0df5e3876a4d5a410ae5534ade993933e9ec39 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/triangle14_1237.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/triangle18_1247.png b/embedding_models/vits8/oml/data/train_val/images/triangle18_1247.png new file mode 100644 index 0000000000000000000000000000000000000000..3d8dd09dbd1cd708f3b9454183e369c15b31b2c4 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/triangle18_1247.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/triangle19_1242.png b/embedding_models/vits8/oml/data/train_val/images/triangle19_1242.png new file mode 100644 index 0000000000000000000000000000000000000000..147d222c00d14fdaea7844f8dab56bd53f8a4570 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/triangle19_1242.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/triangle22_1238.png b/embedding_models/vits8/oml/data/train_val/images/triangle22_1238.png new file mode 100644 index 0000000000000000000000000000000000000000..9db2431d5983ccac88331b975c700eb25c2da295 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/triangle22_1238.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/triangle9_1245.png b/embedding_models/vits8/oml/data/train_val/images/triangle9_1245.png new file mode 100644 index 0000000000000000000000000000000000000000..3c604e7f974dbc996b8ec85407d99128aab2a8b9 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/triangle9_1245.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/word14_1241.png b/embedding_models/vits8/oml/data/train_val/images/word14_1241.png new file mode 100644 index 0000000000000000000000000000000000000000..ec263c3be3e1aeaae7d03a44c34563bf3d37d19b Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/word14_1241.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/word18_1239.png b/embedding_models/vits8/oml/data/train_val/images/word18_1239.png new file mode 100644 index 0000000000000000000000000000000000000000..b2d8acf5d474ac571c8c1b13a588d13a50f60d3e Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/word18_1239.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/word21_1231.png b/embedding_models/vits8/oml/data/train_val/images/word21_1231.png new file mode 100644 index 0000000000000000000000000000000000000000..d1408ee764645e98ccbb8dc528908ac06f217f1c Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/word21_1231.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/word22_1238.png 
b/embedding_models/vits8/oml/data/train_val/images/word22_1238.png new file mode 100644 index 0000000000000000000000000000000000000000..9bec8b4d7d8299661d82f7332774db0bacd8a92f Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/word22_1238.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/word23_1243.png b/embedding_models/vits8/oml/data/train_val/images/word23_1243.png new file mode 100644 index 0000000000000000000000000000000000000000..c3b96e9b6695ea2f51dec3441de01bd6e5bb7d79 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/word23_1243.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/word5_1233.png b/embedding_models/vits8/oml/data/train_val/images/word5_1233.png new file mode 100644 index 0000000000000000000000000000000000000000..298a2c5d25d36803d00b3be7051619e1ac1e61cc Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/word5_1233.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/word9_1241.png b/embedding_models/vits8/oml/data/train_val/images/word9_1241.png new file mode 100644 index 0000000000000000000000000000000000000000..f22c895133a9d9fe0b8614e09b4fa8ee6a3e9a0b Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/word9_1241.png differ diff --git a/embedding_models/vits8/oml/data/train_val/images/word9_1244.png b/embedding_models/vits8/oml/data/train_val/images/word9_1244.png new file mode 100644 index 0000000000000000000000000000000000000000..cefb650bc49faea9627f1b6ed0b7255b7d78f3e3 Binary files /dev/null and b/embedding_models/vits8/oml/data/train_val/images/word9_1244.png differ diff --git a/embedding_models/vits8/oml/example.py b/embedding_models/vits8/oml/example.py new file mode 100644 index 0000000000000000000000000000000000000000..034bcd77068e8be1f82e7a6abe524336f1981176 --- /dev/null +++ b/embedding_models/vits8/oml/example.py @@ -0,0 +1,12 @@ +from PIL import Image +from model_oml import EmbeddingModelOML + +def get_embeddings(img_path: str): + model = EmbeddingModelOML() + image = Image.open(img_path) + embeddings = model(image=image) + return embeddings + + +if __name__ == "__main__": + print(get_embeddings("data/test/images/99d_15.bmp")) \ No newline at end of file diff --git a/embedding_models/vits8/oml/model_oml.py b/embedding_models/vits8/oml/model_oml.py new file mode 100644 index 0000000000000000000000000000000000000000..71bf088b9448f10c9cf77243c5e9bf443a57d43f --- /dev/null +++ b/embedding_models/vits8/oml/model_oml.py @@ -0,0 +1,14 @@ +from oml.models.vit.vit import ViTExtractor +from oml.registry.transforms import get_transforms_for_pretrained +import torch +from PIL import Image + +class EmbeddingModelOML: + def __init__(self, model_path: str = "../models/vits8-stamp.ckpt", arch: str = "vits8", normalise_features: bool = False): + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.extractor = ViTExtractor(model_path, arch=arch, normalise_features=normalise_features).to(self.device) + self.transform, _ = get_transforms_for_pretrained("vits8_dino") + def __call__(self, image: Image.Image) -> torch.Tensor: + img_tensor = self.transform(image).cuda().unsqueeze(0) if self.device == "cuda" else self.transform(image).unsqueeze(0) + features = self.extractor(img_tensor) + return features diff --git a/embedding_models/vits8/oml/pack_jit.py b/embedding_models/vits8/oml/pack_jit.py new file mode 100644 index 
0000000000000000000000000000000000000000..72bae178f2cb016db24653eec1d2f7b0b2354f28 --- /dev/null +++ b/embedding_models/vits8/oml/pack_jit.py @@ -0,0 +1,34 @@ +import torch +from model_oml import EmbeddingModelOML +from huggingface_hub import HfApi +import argparse + +parser = argparse.ArgumentParser("Packing checkpoint to JIT and serving to HF repo", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +parser.add_argument("--upload-to-hf", action="store_true", help="Whether to upload model to hf hub, REQUIRES LOGGING") +parser.add_argument("--path-to-save", help="Where to save the model file", default="../models/") +parser.add_argument("--model-name", help="Which model name to save in folder", default="vits8stamp-torchscript.pth") +parser.add_argument("--repo-id", help="repository id on huggingface", default="stamps-labs/vits8-stamp") + +args = parser.parse_args() +config = vars(args) + +if __name__ == "__main__": + model = EmbeddingModelOML().extractor.cuda() + + model.eval() + + with torch.no_grad(): + model_ts = torch.jit.script(model) + + model_ts.save(config["path_to_save"]+config["model_name"]) + if config["upload_to_hf"]: + api = HfApi() + api.upload_file( + path_or_fileobj=config["path_to_save"]+config["model_name"], + path_in_repo=config["model_name"], + repo_id=config["repo_id"], + repo_type="model" + ) + \ No newline at end of file diff --git a/embedding_models/vits8/oml/requirements.txt b/embedding_models/vits8/oml/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..198176b473dbc57038358c2530753748ad2b92d5 --- /dev/null +++ b/embedding_models/vits8/oml/requirements.txt @@ -0,0 +1,5 @@ +huggingface_hub==0.15.1 +pandas==1.5.2 +Pillow==10.0.0 +pytorch_lightning==1.6.5 +open-metric-learning==0.4.2 \ No newline at end of file diff --git a/embedding_models/vits8/oml/train.py b/embedding_models/vits8/oml/train.py new file mode 100644 index 0000000000000000000000000000000000000000..328047ad418a1648e75b2fe8397a9dd33990830a --- /dev/null +++ b/embedding_models/vits8/oml/train.py @@ -0,0 +1,64 @@ +import pytorch_lightning as pl +import torch +import pandas as pd + +from oml.datasets.base import DatasetQueryGallery, DatasetWithLabels +from oml.lightning.modules.extractor import ExtractorModule +from oml.lightning.callbacks.metric import MetricValCallback +from oml.losses.triplet import TripletLossWithMiner +from oml.metrics.embeddings import EmbeddingMetrics +from oml.miners.inbatch_all_tri import AllTripletsMiner +from oml.models.vit.vit import ViTExtractor +from oml.samplers.balance import BalanceSampler +from pytorch_lightning.loggers import TensorBoardLogger + +import argparse + +parser = argparse.ArgumentParser("Train model with OML", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument("--root-dir", help="root directory for train data", default="data/train_val/") +parser.add_argument("--train-dataframe-name", help="name of dataframe in root directory", default="df_stamps.csv") +parser.add_argument("--train-images", help="name of directory with images", default="images/") +parser.add_argument("--num-epochs", help="number of epochs to train model", default=100) +parser.add_argument("--model-arch", help="which model architecture to use, check model zoo", default="vits8") +parser.add_argument("--weights", + help=""" + pretrained weights for model, choose from model zoo + https://open-metric-learning.readthedocs.io/en/latest/feature_extraction/zoo.html + """, + default="vits8_dino") +parser.add_argument("--checkpoint", 
help="resume training from checkpoint, provide path", default=None) +parser.add_argument("--num-labels", help="number of labels in dataset, set less if cuda out of memory", default=6) +parser.add_argument("--num-instances", help="number of instances for each label in batch, set less if cuda out of memory", default=2) +parser.add_argument("--val-batch-size", help="batch size for validation", default=4) +parser.add_argument("--log-data", action="store_true", help="Whether to log data") + +args = parser.parse_args() +config = vars(args) + +dataset_root = config['root_dir'] +df = pd.read_csv(f"{dataset_root}{config['train_dataframe_name']}") + +df_train = df[df["split"] == "train"].reset_index(drop=True) +df_val = df[df["split"] == "validation"].reset_index(drop=True) +df_val["is_query"] = df_val["is_query"].astype(bool) +df_val["is_gallery"] = df_val["is_gallery"].astype(bool) + +extractor = ViTExtractor(config['weights'], arch=config['model_arch'], normalise_features=False) + +optimizer = torch.optim.SGD(extractor.parameters(), lr=1e-6) +train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root) +criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner()) +batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=config['num_labels'], n_instances=config['num_instances']) +train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler) + +val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root) +val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config['val_batch_size']) +metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,], cmc_top_k=(5, 3, 1)), log_images=True) + +if config['log_data']: + logger = TensorBoardLogger(".") + +pl_model = ExtractorModule(extractor, criterion, optimizer) +trainer = pl.Trainer(max_epochs=config['num_epochs'], callbacks=[metric_callback], num_sanity_val_steps=0, accelerator='gpu', devices=1, resume_from_checkpoint=config['checkpoint']) +trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader) \ No newline at end of file diff --git a/embedding_models/vits8/requirements.txt b/embedding_models/vits8/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..adcf56de13f4483dec06657e140db1358c33bfb3 --- /dev/null +++ b/embedding_models/vits8/requirements.txt @@ -0,0 +1,6 @@ +huggingface_hub==0.15.1 +pandas==1.5.2 +Pillow==10.0.0 +pytorch_lightning==1.6.5 +torch==1.13.0 +torchvision==0.14.0 diff --git a/pipelines/__init__.py b/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipelines/__pycache__/__init__.cpython-39.pyc b/pipelines/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c561fd4e698c41cf33282130ec3cee653f99f36 Binary files /dev/null and b/pipelines/__pycache__/__init__.cpython-39.pyc differ diff --git a/pipelines/detection/__init__.py b/pipelines/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipelines/detection/__pycache__/__init__.cpython-39.pyc b/pipelines/detection/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fbf3855f751ef8453e9a9a2563cc704aae2500a Binary files /dev/null and b/pipelines/detection/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/pipelines/detection/__pycache__/yolo_stamp.cpython-39.pyc b/pipelines/detection/__pycache__/yolo_stamp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28577e5420ce1b6eeacae0fa2c02b082fe42a391 Binary files /dev/null and b/pipelines/detection/__pycache__/yolo_stamp.cpython-39.pyc differ diff --git a/pipelines/detection/__pycache__/yolo_v8.cpython-39.pyc b/pipelines/detection/__pycache__/yolo_v8.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14e790ab410fcf7f8095dab4e42277716965a690 Binary files /dev/null and b/pipelines/detection/__pycache__/yolo_v8.cpython-39.pyc differ diff --git a/pipelines/detection/yolo_stamp.py b/pipelines/detection/yolo_stamp.py new file mode 100644 index 0000000000000000000000000000000000000000..5d212e9d6fe3eec04370567df4cae53efc720283 --- /dev/null +++ b/pipelines/detection/yolo_stamp.py @@ -0,0 +1,42 @@ +from typing import Any +from detection_models.yolo_stamp.constants import * +from detection_models.yolo_stamp.utils import * +import albumentations as A +from albumentations.pytorch.transforms import ToTensorV2 +import torch +from huggingface_hub import hf_hub_download +import numpy as np + +class YoloStampPipeline: + def __init__(self): + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = None + self.transform = A.Compose([ + A.Normalize(), + ToTensorV2(p=1.0), + ]) + + @classmethod + def from_pretrained(cls, model_path_hf: str = None, filename_hf: str = "weights.pt", local_model_path: str = None): + yolo = cls() + if model_path_hf is not None and filename_hf is not None: + yolo.model = torch.load(hf_hub_download(model_path_hf, filename=filename_hf), map_location="cpu") + yolo.model.to(yolo.device) + yolo.model.eval() + elif local_model_path is not None: + yolo.model = torch.load(local_model_path, map_location="cpu") + yolo.model.to(yolo.device) + yolo.model.eval() + return yolo + + def __call__(self, image) -> torch.Tensor: + shape = torch.tensor(image.size) + coef = torch.hstack((shape, shape)) / 448 + image = image.convert("RGB").resize((448, 448)) + image_tensor = self.transform(image=np.array(image))["image"] + output = self.model(image_tensor.unsqueeze(0).to(self.device)) + boxes = output_tensor_to_boxes(output[0].detach().cpu()) + boxes = nonmax_suppression(boxes=boxes) + boxes = xywh2xyxy(torch.tensor(boxes)[:, :4]) + boxes = boxes * coef + return boxes \ No newline at end of file diff --git a/pipelines/detection/yolo_v8.py b/pipelines/detection/yolo_v8.py new file mode 100644 index 0000000000000000000000000000000000000000..ecef268e4d6e7906568db10c92a7782564725592 --- /dev/null +++ b/pipelines/detection/yolo_v8.py @@ -0,0 +1,42 @@ +from typing import Any +from huggingface_hub import hf_hub_download +import torchvision +from torchvision.transforms import ToTensor +import torch + +class Yolov8Pipeline: + def __init__(self): + self.model = None + self.transform = ToTensor() + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + @classmethod + def from_pretrained(cls, model_path_hf: str = None, filename_hf: str = "weights.pt", local_model_path: str = None): + yolo = cls() + if model_path_hf is not None and filename_hf is not None: + yolo.model = torch.jit.load(hf_hub_download(model_path_hf, filename=filename_hf), map_location='cpu').to(yolo.device).eval() + elif local_model_path is not None: + yolo.model = torch.jit.load(local_model_path, map_location='cpu').to(yolo.device).eval() + return yolo + + def __call__(self, image, nms_threshold: float = 0.45, conf_threshold: float = 0.15): + shape = torch.tensor(image.size) + coef = torch.hstack((shape, shape)) / 640 + img = image.convert("RGB").resize((640, 640)) + img_tensor = self.transform(img).unsqueeze(0).to(self.device) + pred, boxes, scores = self.model(img_tensor, conf_thres = conf_threshold) + selected = torchvision.ops.nms(boxes, scores, nms_threshold) + predictions_new = list() + for i in selected: + #remove prob and class + pred_i = pred[i][:4].detach().cpu().clone() + #Loop through coordinates + for j in range(4): + #If any are negative, map to 0 + if pred_i[j] < 0: + pred_i[j] = 0 + #multiply by coef + pred_i *= coef + predictions_new.append(pred_i) + predictions_new = torch.stack(predictions_new, dim=0) + return predictions_new \ No newline at end of file diff --git a/pipelines/feature_extraction/__init__.py b/pipelines/feature_extraction/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipelines/feature_extraction/__pycache__/__init__.cpython-39.pyc b/pipelines/feature_extraction/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecc36c715a0b806c689909c2195eeb28a6c4d13b Binary files /dev/null and b/pipelines/feature_extraction/__pycache__/__init__.cpython-39.pyc differ diff --git a/pipelines/feature_extraction/__pycache__/vae.cpython-39.pyc b/pipelines/feature_extraction/__pycache__/vae.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89193fea715cbc38b040dbebbbfa43b6268eb4c6 Binary files /dev/null and b/pipelines/feature_extraction/__pycache__/vae.cpython-39.pyc differ diff --git a/pipelines/feature_extraction/__pycache__/vits8.cpython-39.pyc b/pipelines/feature_extraction/__pycache__/vits8.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5747531def54d4d9b0ff34b933d12382defc8c53 Binary files /dev/null and b/pipelines/feature_extraction/__pycache__/vits8.cpython-39.pyc differ diff --git a/pipelines/feature_extraction/vae.py b/pipelines/feature_extraction/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..ef6fdbfbfba4047fe2e66fac83baab3f006920da --- /dev/null +++ b/pipelines/feature_extraction/vae.py @@ -0,0 +1,26 @@ +from huggingface_hub import hf_hub_download +import torch +from torchvision.transforms.functional import to_tensor + +class VaePipeline: + def __init__(self): + self.encoder = None + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + @classmethod + def from_pretrained(cls, model_path_hf: str = None, filename_hf: str = "weights.pt", local_model_path: str = None): + vae = cls() + if model_path_hf is not None and filename_hf is not None: + vae.encoder = torch.load(hf_hub_download(model_path_hf, filename_hf), map_location='cpu') + vae.encoder.to(vae.device) + vae.encoder.eval() + elif local_model_path is not None: + vae.encoder = torch.load(local_model_path, map_location='cpu') + vae.encoder.to(vae.device) + vae.encoder.eval() + return vae + + def __call__(self, image) -> torch.Tensor: + image = image.convert("RGB") + img_tensor = to_tensor(image.resize((118, 118))) + return self.encoder(img_tensor.unsqueeze(0).to(self.device))[0][0].detach().cpu() \ No newline at end of file diff --git a/pipelines/feature_extraction/vits8.py b/pipelines/feature_extraction/vits8.py new file mode 100644 index 0000000000000000000000000000000000000000..c20d9b1e52319901547063abb8d181461feabb20 --- /dev/null +++ b/pipelines/feature_extraction/vits8.py @@ -0,0 +1,27 @@ +import torch +from
torchvision import transforms +from huggingface_hub import hf_hub_download + +class Vits8Pipeline: + def __init__(self): + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = None # Initialized upon loading torchscript + self.transform = transforms.ToTensor() + + @classmethod + def from_pretrained(cls, model_path_hf: str = None, filename_hf: str = "weights.pt", local_model_path: str = None): + vit = cls() + if model_path_hf is not None and filename_hf is not None: + vit.model = torch.jit.load(hf_hub_download(model_path_hf, filename=filename_hf), map_location='cpu') + vit.model.to(vit.device) + vit.model.eval() + elif local_model_path is not None: + vit.model = torch.jit.load(local_model_path, map_location='cpu') + vit.model.to(vit.device) + vit.model.eval() + return vit + + def __call__(self, image) -> torch.Tensor: + image = image.convert("RGB") + img_tensor = self.transform(image).to(self.device).unsqueeze(0) + return self.model(img_tensor)[0].detach().cpu() \ No newline at end of file diff --git a/pipelines/segmentation/__init__.py b/pipelines/segmentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipelines/segmentation/__pycache__/__init__.cpython-39.pyc b/pipelines/segmentation/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a99c7ee118ace11a7919c24b2f2239540130e621 Binary files /dev/null and b/pipelines/segmentation/__pycache__/__init__.cpython-39.pyc differ diff --git a/pipelines/segmentation/__pycache__/deeplabv3.cpython-39.pyc b/pipelines/segmentation/__pycache__/deeplabv3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4d148de1d04eb752271061e6a7ebb5cc8118b01 Binary files /dev/null and b/pipelines/segmentation/__pycache__/deeplabv3.cpython-39.pyc differ diff --git a/pipelines/segmentation/deeplabv3.py b/pipelines/segmentation/deeplabv3.py new file mode 100644 index 0000000000000000000000000000000000000000..e31a62c614386ff3d719efd59f9f960ea23fd6ab --- /dev/null +++ b/pipelines/segmentation/deeplabv3.py @@ -0,0 +1,37 @@ +from typing import Any +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +import torchvision.transforms as transforms +import numpy as np + +class DeepLabv3Pipeline: + + def __init__(self): + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.transforms = transforms.Compose( + [ + transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.NEAREST), + transforms.ToTensor() + ] + ) + self.model = None + + @classmethod + def from_pretrained(cls, model_path_hf: str = None, filename_hf: str = "weights.pt", local_model_path: str = None): + dl = cls() + if model_path_hf is not None and filename_hf is not None: + dl.model = torch.load(hf_hub_download(model_path_hf, filename=filename_hf), map_location='cpu') + dl.model.to(dl.device) + dl.model.eval() + elif local_model_path is not None: + dl.model = torch.load(local_model_path, map_location='cpu') + dl.model.to(dl.device) + dl.model.eval() + return dl + + def __call__(self, image: Image.Image, threshold: float = 0) -> Image.Image: + image = image.convert("RGB") + output = self.model(self.transforms(image).unsqueeze(0).to(self.device)) + return Image.fromarray((255 * np.where(output['out'][0].permute(1, 2, 0).detach().cpu() > threshold, + self.transforms(image).permute(1, 2, 0), 1)).astype(np.uint8)) diff --git a/requirements.txt b/requirements.txt index 
67605647d101ff00cb1ea8f0796e9f68526f5306..21007267169cb3d3ec88de0c123390d6d58371f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,5 @@ torch==1.12.0 torchvision==0.13.0 -ultralytics==8.0.57 -albumentations==1.3.0 scikit-learn==1.1.3 matplotlib==3.6.0 pillow==9.3.0 diff --git a/segmentation_models/__init__.py b/segmentation_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/segmentation_models/deeplabv3/__init__.py b/segmentation_models/deeplabv3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/segmentation_models/deeplabv3/data.py b/segmentation_models/deeplabv3/data.py new file mode 100644 index 0000000000000000000000000000000000000000..6a53ae93563598e27853ccb3ae4dddc8dcbd5e84 --- /dev/null +++ b/segmentation_models/deeplabv3/data.py @@ -0,0 +1,114 @@ +from pathlib import Path +from typing import Any, Callable, Optional + +import numpy as np +from PIL import Image +from torchvision.datasets.vision import VisionDataset + + +class SegmentationDataset(VisionDataset): + """A PyTorch dataset for image segmentation task. + The dataset is compatible with torchvision transforms. + The transforms passed would be applied to both the Images and Masks. + """ + def __init__(self, + root: str, + image_folder: str, + mask_folder: str, + transforms: Optional[Callable] = None, + seed: int = None, + fraction: float = None, + subset: str = None, + image_color_mode: str = "rgb", + mask_color_mode: str = "grayscale") -> None: + """ + Args: + root (str): Root directory path. + image_folder (str): Name of the folder that contains the images in the root directory. + mask_folder (str): Name of the folder that contains the masks in the root directory. + transforms (Optional[Callable], optional): A function/transform that takes in + a sample and returns a transformed version. + E.g, ``transforms.ToTensor`` for images. Defaults to None. + seed (int, optional): Specify a seed for the train and test split for reproducible results. Defaults to None. + fraction (float, optional): A float value from 0 to 1 which specifies the validation split fraction. Defaults to None. + subset (str, optional): 'Train' or 'Test' to select the appropriate set. Defaults to None. + image_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'rgb'. + mask_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'grayscale'. + + Raises: + OSError: If image folder doesn't exist in root. + OSError: If mask folder doesn't exist in root. + ValueError: If subset is not either 'Train' or 'Test' + ValueError: If image_color_mode and mask_color_mode are either 'rgb' or 'grayscale' + """ + super().__init__(root, transforms) + image_folder_path = Path(self.root) / image_folder + mask_folder_path = Path(self.root) / mask_folder + if not image_folder_path.exists(): + raise OSError(f"{image_folder_path} does not exist.") + if not mask_folder_path.exists(): + raise OSError(f"{mask_folder_path} does not exist.") + + if image_color_mode not in ["rgb", "grayscale"]: + raise ValueError( + f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale." + ) + if mask_color_mode not in ["rgb", "grayscale"]: + raise ValueError( + f"{mask_color_mode} is an invalid choice. Please enter from rgb grayscale." 
+ ) + + self.image_color_mode = image_color_mode + self.mask_color_mode = mask_color_mode + + if not fraction: + self.image_names = sorted(image_folder_path.glob("*")) + self.mask_names = sorted(mask_folder_path.glob("*")) + else: + if subset not in ["Train", "Test"]: + raise (ValueError( + f"{subset} is not a valid input. Acceptable values are Train and Test." + )) + self.fraction = fraction + self.image_list = np.array(sorted(image_folder_path.glob("*"))) + self.mask_list = np.array(sorted(mask_folder_path.glob("*"))) + if seed: + np.random.seed(seed) + indices = np.arange(len(self.image_list)) + np.random.shuffle(indices) + self.image_list = self.image_list[indices] + self.mask_list = self.mask_list[indices] + if subset == "Train": + self.image_names = self.image_list[:int( + np.ceil(len(self.image_list) * (1 - self.fraction)))] + self.mask_names = self.mask_list[:int( + np.ceil(len(self.mask_list) * (1 - self.fraction)))] + else: + self.image_names = self.image_list[ + int(np.ceil(len(self.image_list) * (1 - self.fraction))):] + self.mask_names = self.mask_list[ + int(np.ceil(len(self.mask_list) * (1 - self.fraction))):] + + def __len__(self) -> int: + return len(self.image_names) + + def __getitem__(self, index: int) -> Any: + image_path = self.image_names[index] + mask_path = self.mask_names[index] + with open(image_path, "rb") as image_file, open(mask_path, + "rb") as mask_file: + image = Image.open(image_file) + if self.image_color_mode == "rgb": + image = image.convert("RGB") + elif self.image_color_mode == "grayscale": + image = image.convert("L") + mask = Image.open(mask_file) + if self.mask_color_mode == "rgb": + mask = mask.convert("RGB") + elif self.mask_color_mode == "grayscale": + mask = mask.convert("L") + sample = {"image": image, "mask": mask} + if self.transforms: + sample["image"] = self.transforms(sample["image"]) + sample["mask"] = self.transforms(sample["mask"]) + return sample \ No newline at end of file diff --git a/segmentation_models/deeplabv3/main.py b/segmentation_models/deeplabv3/main.py new file mode 100644 index 0000000000000000000000000000000000000000..0720e38c8a579f14891357f43b6a9b79f3a6ea37 --- /dev/null +++ b/segmentation_models/deeplabv3/main.py @@ -0,0 +1,64 @@ +from pathlib import Path + +import click +import torch +from sklearn.metrics import f1_score +from torch.utils import data + +from utils import * +from model import createDeepLabv3 +from trainer import train_model + + +@click.command() +@click.option("--data-directory", + required=True, + help="Specify the data directory.") +@click.option("--exp_directory", + required=True, + help="Specify the experiment directory.") +@click.option( + "--epochs", + default=25, + type=int, + help="Specify the number of epochs you want to run the experiment for.") +@click.option("--batch-size", + default=4, + type=int, + help="Specify the batch size for the dataloader.") +def main(data_directory, exp_directory, epochs, batch_size): + # Create the deeplabv3 resnet101 model which is pretrained on a subset + # of COCO train2017, on the 20 categories that are present in the Pascal VOC dataset. 
+ model = createDeepLabv3() + model.train() + data_directory = Path(data_directory) + # Create the experiment directory if not present + exp_directory = Path(exp_directory) + if not exp_directory.exists(): + exp_directory.mkdir() + + # Specify the loss function + criterion = torch.nn.MSELoss(reduction='mean') + # Specify the optimizer with a lower learning rate + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + + # Specify the evaluation metrics + metrics = {'f1_score': f1_score, 'iou': iou} + + # Create the dataloader + dataloaders = get_dataloader_single_folder( + data_directory, batch_size=batch_size) + _ = train_model(model, + criterion, + dataloaders, + optimizer, + bpath=exp_directory, + metrics=metrics, + num_epochs=epochs) + + # Save the trained model + torch.save(model, exp_directory / 'weights.pt') + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/segmentation_models/deeplabv3/model.py b/segmentation_models/deeplabv3/model.py new file mode 100644 index 0000000000000000000000000000000000000000..719ff07526c9be8f0df79a7979dd8d0195ce2c89 --- /dev/null +++ b/segmentation_models/deeplabv3/model.py @@ -0,0 +1,20 @@ +""" DeepLabv3 Model download and change the head for your prediction""" +from torchvision.models.segmentation.deeplabv3 import DeepLabHead +from torchvision import models + + +def createDeepLabv3(outputchannels=1): + """DeepLabv3 class with custom head + + Args: + outputchannels (int, optional): The number of output channels + in your dataset masks. Defaults to 1. + + Returns: + model: Returns the DeepLabv3 model with the ResNet101 backbone. + """ + model = models.segmentation.deeplabv3_resnet50(pretrained=True) + model.classifier = DeepLabHead(2048, outputchannels) + # Set the model in training mode + model.train() + return model \ No newline at end of file diff --git a/segmentation_models/deeplabv3/trainer.py b/segmentation_models/deeplabv3/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..42e0d7f1efce70f0b43d34d908933aa70c145d7d --- /dev/null +++ b/segmentation_models/deeplabv3/trainer.py @@ -0,0 +1,88 @@ +import copy +import csv +import os +import time + +import numpy as np +import torch +from tqdm import tqdm + + +def train_model(model, criterion, dataloaders, optimizer, metrics, bpath, + num_epochs): + since = time.time() + best_model_wts = copy.deepcopy(model.state_dict()) + best_loss = 1e10 + # Use gpu if available + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model.to(device) + # Initialize the log file for training and testing loss and metrics + fieldnames = ['epoch', 'Train_loss', 'Test_loss'] + \ + [f'Train_{m}' for m in metrics.keys()] + \ + [f'Test_{m}' for m in metrics.keys()] + with open(os.path.join(bpath, 'log.csv'), 'w', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + + for epoch in range(1, num_epochs + 1): + print('Epoch {}/{}'.format(epoch, num_epochs)) + print('-' * 10) + # Each epoch has a training and validation phase + # Initialize batch summary + batchsummary = {a: [0] for a in fieldnames} + + for phase in ['Train', 'Test']: + if phase == 'Train': + model.train() # Set model to training mode + else: + model.eval() # Set model to evaluate mode + + # Iterate over data. 
+ for sample in tqdm(iter(dataloaders[phase])): + inputs = sample['image'].to(device) + masks = sample['mask'].to(device) + # zero the parameter gradients + optimizer.zero_grad() + + # track history if only in train + with torch.set_grad_enabled(phase == 'Train'): + outputs = model(inputs) + loss = criterion(outputs['out'], masks) + y_pred = outputs['out'].data.cpu().numpy().ravel() + y_true = masks.data.cpu().numpy().ravel() + for name, metric in metrics.items(): + if name == 'f1_score': + # Use a classification threshold of 0.1 + batchsummary[f'{phase}_{name}'].append( + metric(y_true > 0, y_pred > 0.1)) + else: + batchsummary[f'{phase}_{name}'].append( + metric(y_true.astype('uint8'), y_pred)) + + # backward + optimize only if in training phase + if phase == 'Train': + loss.backward() + optimizer.step() + batchsummary['epoch'] = epoch + epoch_loss = loss + batchsummary[f'{phase}_loss'] = epoch_loss.item() + print('{} Loss: {:.4f}'.format(phase, loss)) + for field in fieldnames[3:]: + batchsummary[field] = np.mean(batchsummary[field]) + print(batchsummary) + with open(os.path.join(bpath, 'log.csv'), 'a', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writerow(batchsummary) + # deep copy the model + if phase == 'Test' and loss < best_loss: + best_loss = loss + best_model_wts = copy.deepcopy(model.state_dict()) + + time_elapsed = time.time() - since + print('Training complete in {:.0f}m {:.0f}s'.format( + time_elapsed // 60, time_elapsed % 60)) + print('Lowest Loss: {:4f}'.format(best_loss)) + + # load best model weights + model.load_state_dict(best_model_wts) + return model \ No newline at end of file diff --git a/segmentation_models/deeplabv3/utils.py b/segmentation_models/deeplabv3/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef377da502166dc14f5ef5e6237b48371a68d3c --- /dev/null +++ b/segmentation_models/deeplabv3/utils.py @@ -0,0 +1,58 @@ +from pathlib import Path + +from torch.utils.data import DataLoader +from torchvision import transforms + +from data import SegmentationDataset + + +def get_dataloader_single_folder(data_dir: str, + image_folder: str = 'images', + mask_folder: str = 'masks', + fraction: float = 0.2, + batch_size: int = 4): + """Create train and test dataloader from a single directory containing + the image and mask folders. + + Args: + data_dir (str): Data directory path or root + image_folder (str, optional): Image folder name. Defaults to 'Images'. + mask_folder (str, optional): Mask folder name. Defaults to 'Masks'. + fraction (float, optional): Fraction of Test set. Defaults to 0.2. + batch_size (int, optional): Dataloader batch size. Defaults to 4. + + Returns: + dataloaders: Returns dataloaders dictionary containing the + Train and Test dataloaders. 
+ """ + data_transforms = transforms.Compose([transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()]) + + image_datasets = { + x: SegmentationDataset(data_dir, + image_folder=image_folder, + mask_folder=mask_folder, + seed=100, + fraction=fraction, + subset=x, + transforms=data_transforms) + for x in ['Train', 'Test'] + } + dataloaders = { + x: DataLoader(image_datasets[x], + batch_size=batch_size, + shuffle=True, + num_workers=0) + for x in ['Train', 'Test'] + } + return dataloaders + + +def iou(gt_mask, pred_mask, threshold): + + pred_mask = (pred_mask > threshold) * 1 + gt_mask = (gt_mask == 1) * 1 + + overlap = pred_mask * gt_mask # Logical AND + union = (pred_mask + gt_mask)>0 # Logical OR + iou = overlap.sum() / float(union.sum()) + return iou \ No newline at end of file diff --git a/utils.py b/utils.py index 22d259dbcd17b392fe17ce8519f8a712b43d7d17..c64e9aa5ce27f2b45da30eba56715992fa041c48 100644 --- a/utils.py +++ b/utils.py @@ -1,142 +1,7 @@ -from PIL import Image, ImageDraw -import torch -import numpy as np -import matplotlib.pyplot as plt -import matplotlib -from constants import * - -def visualize_bbox(image: Image, prediction): - img = image.copy() - draw = ImageDraw.Draw(img) - for i, box in enumerate(prediction): - x1, y1, x2, y2 = box.cpu() - draw = ImageDraw.Draw(img) - text_w, text_h = draw.textsize(str(i + 1)) - label_y = y1 if y1 <= text_h else y1 - text_h - draw.rectangle((x1, y1, x2, y2), outline='red') - draw.rectangle((x1, label_y, x1+text_w, label_y+text_h), outline='red', fill='red') - draw.text((x1, label_y), str(i + 1), fill='white') - return img - -def xywh2xyxy(x): - y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) - y[..., 0] = x[..., 0] - y[..., 1] = x[..., 1] - y[..., 2] = x[..., 0] + x[..., 2] - y[..., 3] = x[..., 1] + x[..., 3] - return y - -def output_tensor_to_boxes(boxes_tensor): - """ - Converts the YOLO output tensor to list of boxes with probabilites. - - Arguments: - boxes_tensor -- tensor of shape (S, S, BOX, 5) - - Returns: - boxes -- list of shape (None, 5) - - Note: "None" is here because you don't know the exact number of selected boxes, as it depends on the threshold. - For example, the actual output size of scores would be (10, 5) if there are 10 boxes - """ - cell_w, cell_h = W/S, H/S - boxes = [] - - for i in range(S): - for j in range(S): - for b in range(BOX): - anchor_wh = torch.tensor(ANCHORS[b]) - data = boxes_tensor[i,j,b] - xy = torch.sigmoid(data[:2]) - wh = torch.exp(data[2:4])*anchor_wh - obj_prob = torch.sigmoid(data[4]) - - if obj_prob > OUTPUT_THRESH: - x_center, y_center, w, h = xy[0], xy[1], wh[0], wh[1] - x, y = x_center+j-w/2, y_center+i-h/2 - x,y,w,h = x*cell_w, y*cell_h, w*cell_w, h*cell_h - box = [x,y,w,h, obj_prob] - boxes.append(box) - return boxes - -def overlap(interval_1, interval_2): - """ - Calculates length of overlap between two intervals. - - Arguments: - interval_1 -- list or tuple of shape (2,) containing endpoints of the first interval - interval_2 -- list or tuple of shape (2, 2) containing endpoints of the second interval - - Returns: - overlap -- length of overlap - """ - x1, x2 = interval_1 - x3, x4 = interval_2 - if x3 < x1: - if x4 < x1: - return 0 - else: - return min(x2,x4) - x1 - else: - if x2 < x3: - return 0 - else: - return min(x2,x4) - x3 - - -def compute_iou(box1, box2): - """ - Compute IOU between box1 and box2. - - Argmunets: - box1 -- list of shape (5, ). Represents the first box - box2 -- list of shape (5, ). 
Represents the second box - Each box is [x, y, w, h, prob] - - Returns: - iou -- intersection over union score between two boxes - """ - x1,y1,w1,h1 = box1[0], box1[1], box1[2], box1[3] - x2,y2,w2,h2 = box2[0], box2[1], box2[2], box2[3] - - area1, area2 = w1*h1, w2*h2 - intersect_w = overlap((x1,x1+w1), (x2,x2+w2)) - intersect_h = overlap((y1,y1+h1), (y2,y2+w2)) - if intersect_w == w1 and intersect_h == h1 or intersect_w == w2 and intersect_h == h2: - return 1. - intersect_area = intersect_w*intersect_h - iou = intersect_area/(area1 + area2 - intersect_area) - return iou - - -def nonmax_suppression(boxes, iou_thresh = IOU_THRESH): - """ - Removes ovelap bboxes - - Arguments: - boxes -- list of shape (None, 5) - iou_thresh -- maximal value of iou when boxes are considered different - Each box is [x, y, w, h, prob] - - Returns: - boxes -- list of shape (None, 5) with removed overlapping boxes - """ - boxes = sorted(boxes, key=lambda x: x[4], reverse=True) - for i, current_box in enumerate(boxes): - if current_box[4] <= 0: - continue - for j in range(i+1, len(boxes)): - iou = compute_iou(current_box, boxes[j]) - if iou > iou_thresh: - boxes[j][4] = 0 - boxes = [box for box in boxes if box[4] > 0] - return boxes - def heatmap(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs): """ Create a heatmap from a numpy array and two lists of labels. - Parameters ---------- data @@ -196,7 +61,6 @@ def annotate_heatmap(im, data=None, valfmt="{x:.2f}", threshold=None, **textkw): """ A function to annotate a heatmap. - Parameters ---------- im @@ -247,4 +111,17 @@ def annotate_heatmap(im, data=None, valfmt="{x:.2f}", text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) texts.append(text) - return texts \ No newline at end of file + return texts + +def visualize_bbox(image: Image, prediction): + img = image.copy() + draw = ImageDraw.Draw(img) + for i, box in enumerate(prediction): + x1, y1, x2, y2 = box.cpu() + draw = ImageDraw.Draw(img) + text_w, text_h = draw.textsize(str(i + 1)) + label_y = y1 if y1 <= text_h else y1 - text_h + draw.rectangle((x1, y1, x2, y2), outline='red') + draw.rectangle((x1, label_y, x1+text_w, label_y+text_h), outline='red', fill='red') + draw.text((x1, label_y), str(i + 1), fill='white') + return img \ No newline at end of file
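
Editor's note: the diff above only records the new pipeline classes themselves; the sketch below (not part of the commit) shows one way they could be wired together, based on the from_pretrained/__call__ signatures in pipelines/detection/yolo_stamp.py and pipelines/feature_extraction/vits8.py. The local detector path and the input image path are hypothetical placeholders; the "stamps-labs/vits8-stamp" repo id and "vits8stamp-torchscript.pth" filename are the ones referenced in embedding_models/vits8/model.py; cosine similarity is just one convenient way to compare the resulting embeddings.

from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from pipelines.detection.yolo_stamp import YoloStampPipeline
from pipelines.feature_extraction.vits8 import Vits8Pipeline

# Hypothetical weight locations -- substitute your own checkpoint / HF repo.
detector = YoloStampPipeline.from_pretrained(local_model_path="models/yolo_stamp.pt")
embedder = Vits8Pipeline.from_pretrained("stamps-labs/vits8-stamp", "vits8stamp-torchscript.pth")

doc = Image.open("document.jpg")  # hypothetical input scan
boxes = detector(doc)             # (N, 4) xyxy boxes scaled back to original image coordinates

# Crop each detected stamp and embed it with the ViT-S/8 extractor.
embeddings = [embedder(doc.crop(tuple(box.tolist()))) for box in boxes]

# Pairwise cosine similarity between the stamps found on the same document.
if len(embeddings) > 1:
    sims = cosine_similarity([emb.numpy() for emb in embeddings])
    print(sims)

VaePipeline and DeepLabv3Pipeline follow the same from_pretrained/__call__ pattern, so they can be dropped into the same flow to swap the embedder or to segment the crops before embedding.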