sadjava commited on
Commit
b2a333c
1 Parent(s): 1376c2b

Update README.md

Browse files
detection_models/yolo_stamp/train.ipynb DELETED
@@ -1,185 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "from model import *\n",
10
- "from loss import *\n",
11
- "from data import *\n",
12
- "from torch import optim\n",
13
- "from tqdm import tqdm\n",
14
- "\n",
15
- "import pytorch_lightning as pl\n",
16
- "from torchmetrics.detection import MeanAveragePrecision\n",
17
- "from pytorch_lightning.loggers import TensorBoardLogger"
18
- ]
19
- },
20
- {
21
- "cell_type": "code",
22
- "execution_count": 2,
23
- "metadata": {},
24
- "outputs": [],
25
- "source": [
26
- "_, _, test_dataset = get_datasets()"
27
- ]
28
- },
29
- {
30
- "cell_type": "code",
31
- "execution_count": 3,
32
- "metadata": {},
33
- "outputs": [],
34
- "source": [
35
- "class LitModel(pl.LightningModule):\n",
36
- " def __init__(self):\n",
37
- " super().__init__()\n",
38
- " self.model = YOLOStamp()\n",
39
- " self.criterion = YOLOLoss()\n",
40
- " self.val_map = MeanAveragePrecision(box_format='xywh', iou_type='bbox')\n",
41
- " \n",
42
- " def forward(self, x):\n",
43
- " return self.model(x)\n",
44
- "\n",
45
- " def configure_optimizers(self):\n",
46
- " optimizer = optim.AdamW(self.parameters(), lr=1e-3)\n",
47
- " # return optimizer\n",
48
- " scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n",
49
- " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
50
- "\n",
51
- " def training_step(self, batch, batch_idx):\n",
52
- " images, targets = batch\n",
53
- " tensor_images = torch.stack(images)\n",
54
- " tensor_targets = torch.stack(targets)\n",
55
- " output = self.model(tensor_images)\n",
56
- " loss = self.criterion(output, tensor_targets)\n",
57
- " self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n",
58
- " return loss\n",
59
- "\n",
60
- " def validation_step(self, batch, batch_idx):\n",
61
- " images, targets = batch\n",
62
- " tensor_images = torch.stack(images)\n",
63
- " tensor_targets = torch.stack(targets)\n",
64
- " output = self.model(tensor_images)\n",
65
- " loss = self.criterion(output, tensor_targets)\n",
66
- " self.log(\"val_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n",
67
- "\n",
68
- " for i in range(len(images)):\n",
69
- " boxes = output_tensor_to_boxes(output[i].detach().cpu())\n",
70
- " boxes = nonmax_suppression(boxes)\n",
71
- " target = target_tensor_to_boxes(targets[i])[::BOX]\n",
72
- " if not boxes:\n",
73
- " boxes = torch.zeros((1, 5))\n",
74
- " preds = [\n",
75
- " dict(\n",
76
- " boxes=torch.tensor(boxes)[:, :4].clone().detach(),\n",
77
- " scores=torch.tensor(boxes)[:, 4].clone().detach(),\n",
78
- " labels=torch.zeros(len(boxes)),\n",
79
- " )\n",
80
- " ]\n",
81
- " target = [\n",
82
- " dict(\n",
83
- " boxes=torch.tensor(target),\n",
84
- " labels=torch.zeros(len(target)),\n",
85
- " )\n",
86
- " ]\n",
87
- " self.val_map.update(preds, target)\n",
88
- " \n",
89
- " def on_validation_epoch_end(self):\n",
90
- " mAPs = {\"val_\" + k: v for k, v in self.val_map.compute().items()}\n",
91
- " mAPs_per_class = mAPs.pop(\"val_map_per_class\")\n",
92
- " mARs_per_class = mAPs.pop(\"val_mar_100_per_class\")\n",
93
- " self.log_dict(mAPs)\n",
94
- " self.val_map.reset()\n",
95
- "\n",
96
- " image = test_dataset[randint(0, len(test_dataset) - 1)][0].to(self.device)\n",
97
- " output = self.model(image.unsqueeze(0))\n",
98
- " boxes = output_tensor_to_boxes(output[0].detach().cpu())\n",
99
- " boxes = nonmax_suppression(boxes)\n",
100
- " img = image.permute(1, 2, 0).cpu().numpy()\n",
101
- " img = visualize_bbox(img.copy(), boxes=boxes)\n",
102
- " img = (255. * (img * np.array(STD) + np.array(MEAN))).astype(np.uint8)\n",
103
- " \n",
104
- " self.logger.experiment.add_image(\"detected boxes\", torch.tensor(img).permute(2, 0, 1), self.current_epoch)\n"
105
- ]
106
- },
107
- {
108
- "cell_type": "code",
109
- "execution_count": 4,
110
- "metadata": {},
111
- "outputs": [],
112
- "source": [
113
- "litmodel = LitModel()"
114
- ]
115
- },
116
- {
117
- "cell_type": "code",
118
- "execution_count": 5,
119
- "metadata": {},
120
- "outputs": [],
121
- "source": [
122
- "logger = TensorBoardLogger(\"detection_logs\")"
123
- ]
124
- },
125
- {
126
- "cell_type": "code",
127
- "execution_count": 7,
128
- "metadata": {},
129
- "outputs": [],
130
- "source": [
131
- "epochs = 100"
132
- ]
133
- },
134
- {
135
- "cell_type": "code",
136
- "execution_count": 8,
137
- "metadata": {},
138
- "outputs": [],
139
- "source": [
140
- "train_loader, val_loader = get_loaders(batch_size=8)"
141
- ]
142
- },
143
- {
144
- "cell_type": "code",
145
- "execution_count": null,
146
- "metadata": {},
147
- "outputs": [],
148
- "source": [
149
- "trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n",
150
- "trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)"
151
- ]
152
- },
153
- {
154
- "cell_type": "code",
155
- "execution_count": null,
156
- "metadata": {},
157
- "outputs": [],
158
- "source": [
159
- "%tensorboard"
160
- ]
161
- }
162
- ],
163
- "metadata": {
164
- "kernelspec": {
165
- "display_name": "Python 3",
166
- "language": "python",
167
- "name": "python3"
168
- },
169
- "language_info": {
170
- "codemirror_mode": {
171
- "name": "ipython",
172
- "version": 3
173
- },
174
- "file_extension": ".py",
175
- "mimetype": "text/x-python",
176
- "name": "python",
177
- "nbconvert_exporter": "python",
178
- "pygments_lexer": "ipython3",
179
- "version": "3.9.0"
180
- },
181
- "orig_nbformat": 4
182
- },
183
- "nbformat": 4,
184
- "nbformat_minor": 2
185
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
detection_models/yolo_stamp/utils.py CHANGED
@@ -1,5 +1,4 @@
1
  import torch
2
- import cv2
3
  import pandas as pd
4
  import numpy as np
5
  from pathlib import Path
@@ -53,33 +52,6 @@ def plot_normalized_img(img, std=STD, mean=MEAN, size=(7,7)):
53
  plt.figure(figsize=size)
54
  plt.imshow((255. * (img * std + mean)).astype(np.uint))
55
  plt.show()
56
-
57
-
58
- def visualize_bbox(img, boxes, thickness=2, color=BOX_COLOR, draw_center=True):
59
- """
60
- Draws boxes on the given image.
61
-
62
- Arguments:
63
- img -- torch.Tensor of shape (3, W, H) or numpy.ndarray of shape (W, H, 3)
64
- boxes -- list of shape (None, 5)
65
- thickness -- number specifying the thickness of box border
66
- color -- RGB tuple of shape (3,) specifying the color of boxes
67
- draw_center -- boolean specifying whether to draw center or not
68
-
69
- Returns:
70
- img_copy -- numpy.ndarray of shape(W, H, 3) containing image with bouning boxes
71
- """
72
- img_copy = img.cpu().permute(1,2,0).numpy() if isinstance(img, torch.Tensor) else img.copy()
73
- for box in boxes:
74
- x,y,w,h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
75
- img_copy = cv2.rectangle(
76
- img_copy,
77
- (x,y),(x+w, y+h),
78
- color, thickness)
79
- if draw_center:
80
- center = (x+w//2, y+h//2)
81
- img_copy = cv2.circle(img_copy, center=center, radius=3, color=(0,255,0), thickness=2)
82
- return img_copy
83
 
84
 
85
  def read_data(annotations=Path(ANNOTATIONS_PATH)):
 
1
  import torch
 
2
  import pandas as pd
3
  import numpy as np
4
  from pathlib import Path
 
52
  plt.figure(figsize=size)
53
  plt.imshow((255. * (img * std + mean)).astype(np.uint))
54
  plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  def read_data(annotations=Path(ANNOTATIONS_PATH)):
requirements.txt CHANGED
@@ -1,7 +1,14 @@
1
- torch==1.12.0
2
- torchvision==0.13.0
3
- scikit-learn==1.1.3
 
4
  matplotlib==3.6.0
5
- pillow==9.3.0
6
  pandas==1.5.1
7
- gradio>=3.36.1
 
 
 
 
 
 
 
1
+ albumentations==1.3.0
2
+ click==8.0.4
3
+ gradio==3.36.1
4
+ huggingface_hub==0.14.1
5
  matplotlib==3.6.0
6
+ numpy==1.23.4
7
  pandas==1.5.1
8
+ Pillow==9.3.0
9
+ Pillow==10.0.0
10
+ pytorch_lightning==2.0.2
11
+ scikit_learn==1.1.3
12
+ torch==1.12.0+cu116
13
+ torchvision==0.13.0+cu116
14
+ tqdm==4.64.1