Upload batched_inference.py
Browse files- batched_inference.py +193 -0
batched_inference.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import torch.multiprocessing as multiprocessing
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
from torch import autocast
|
7 |
+
from torch.utils.data import Dataset, DataLoader
|
8 |
+
from PIL import Image
|
9 |
+
import torch
|
10 |
+
from torchvision.transforms import InterpolationMode
|
11 |
+
from tqdm import tqdm
|
12 |
+
import random
|
13 |
+
import json
|
14 |
+
|
15 |
+
# Global PyTorch switches for this inference-only script.

# Permit TF32 math in CUDA matmuls and in cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Make sure autograd anomaly detection is switched off.
torch.autograd.set_detect_anomaly(False)

# NOTE: the earlier `torch.autograd.profiler.emit_nvtx(enabled=False)` and
# `torch.autograd.profiler.profile(enabled=False)` statements only built
# context-manager objects that were never entered, so they did nothing and
# have been dropped.

# Let cuDNN benchmark kernels. set_seed() in this file sets this flag back to
# False when it is called.
torch.backends.cudnn.benchmark = True
|
23 |
+
|
24 |
+
class ImageDataset(Dataset):
    """Map-style dataset yielding normalized 224x224 image tensors.

    Each row of ``csv_file`` names one image through its ``md5`` and
    ``file_ext`` columns; the file is opened from
    ``<base_path>/<md5><file_ext>``.

    Args:
        csv_file: A pandas DataFrame (despite the name) with ``md5`` and
            ``file_ext`` columns.
        train: When truthy, the number of listed images is printed.
        base_path: Directory that contains the image files.
    """

    # Per-channel normalization statistics shared by both transforms.
    # NOTE(review): these look like the published CLIP preprocessing
    # values -- confirm they match how the model was trained.
    _NORM_MEAN = [0.48145466, 0.4578275, 0.40821073]
    _NORM_STD = [0.26862954, 0.26130258, 0.27577711]

    def __init__(self, csv_file, train, base_path):
        self.csv_file = csv_file
        self.train = train
        # reset_index makes label == position, so the integer indices handed
        # to __getitem__ stay valid even if the DataFrame was filtered or
        # re-indexed before being passed in.
        self.all_image_names = csv_file['md5'].apply(str).reset_index(drop=True)
        self.all_image_ext = csv_file['file_ext'].apply(str).reset_index(drop=True)
        self.train_size = len(self.csv_file)
        self.base_path = base_path
        if self.train:
            print(f"Number of training images: {self.train_size}")

        normalize = transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD)
        # Very elongated images: scale the short side to 224 and crop the
        # centre square instead of squashing the whole picture.
        self.thin_transform = transforms.Compose([
            transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        # Everything else: resize straight to 224x224 (aspect ratio is not
        # preserved).
        self.normal_transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

    def __len__(self):
        """Return the number of images listed in the DataFrame."""
        return len(self.all_image_names)

    def __getitem__(self, index):
        """Load, convert to RGB and preprocess the image at ``index``.

        Returns:
            dict with ``'image'`` (the transformed tensor) and
            ``'image_name'`` (the row's ``md5`` string).
        """
        image_name = self.all_image_names[index]
        # presumably file_ext already carries its leading dot; verify
        # against the CSV.
        image_path = f"{self.base_path}/{image_name}{self.all_image_ext[index]}"

        # convert() returns an independent image, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")

        ratio = image.height / image.width
        if ratio > 2.0 or ratio < 0.5:
            image = self.thin_transform(image)
        else:
            image = self.normal_transform(image)

        return {
            'image': image,
            "image_name": image_name,
        }
|
81 |
+
|
82 |
+
|
83 |
+
def prepare_model():
    """Load the serialized model, place it on the GPU and return it in eval mode.

    Returns:
        The object stored in ``path/to/your/model.pth`` after being moved to
        ``"cuda"``, switched to ``channels_last`` memory format and put into
        evaluation mode.
    """
    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code -- only point this at checkpoints you trust.
    checkpoint = torch.load("path/to/your/model.pth")
    net = checkpoint.to("cuda")
    net.to(memory_format=torch.channels_last)
    return net.eval()
|
88 |
+
|
89 |
+
|
90 |
+
def train(tagging_is_running, model, dataloader, train_data, output_queue):
    """Run inference over ``dataloader`` and stream the results to a queue.

    Despite its name this function does no training: the model is put in
    eval mode and run under ``torch.no_grad()``.

    Args:
        tagging_is_running: Queue holding one token while tagging is in
            progress; the token is removed when this function is done.
        model: Module called on each image batch.
        dataloader: Yields dicts with ``'image'`` and ``'image_name'``.
        train_data: Unused; kept so existing callers keep working.
        output_queue: Receives one ``(probabilities_on_cpu, image_names)``
            tuple per batch.
    """
    print('Begin tagging')
    model.eval()

    try:
        with torch.no_grad():
            # len(dataloader) includes the final partial batch, so the
            # progress bar total is exact.
            for batch in tqdm(dataloader, total=len(dataloader)):
                images = batch['image'].to("cuda")
                image_names = batch["image_name"]
                with autocast(device_type='cuda', dtype=torch.bfloat16):
                    outputs = model(images)

                probabilities = torch.sigmoid(outputs)
                output_queue.put((probabilities.to("cpu"), image_names))
    finally:
        # Always clear the "running" token, even when inference raised, so a
        # consumer watching the flag is not left waiting forever.
        _ = tagging_is_running.get()
    print("Tagging finished!")
|
108 |
+
|
109 |
+
|
110 |
+
def tag_writer(tagging_is_running, output_queue, output_file_name):
|
111 |
+
with open("tags.json", "r") as file:
|
112 |
+
tags = json.load(file)
|
113 |
+
allowed_tags = sorted(tags)
|
114 |
+
del tags
|
115 |
+
allowed_tags.extend(["placeholder0", "placeholder1", "placeholder2"])
|
116 |
+
tag_count = len(allowed_tags)
|
117 |
+
assert tag_count == 7704, f"The length of loss scaling factor is not correct. Correct: 7704, current: {tag_count}"
|
118 |
+
|
119 |
+
with open(output_file_name, "w") as output_csv:
|
120 |
+
writer = csv.writer(output_csv)
|
121 |
+
writer.writerow(["image_name", "tags", "tag_probs"])
|
122 |
+
while not (tagging_is_running.qsize()>0 and output_queue.qsize()>0):
|
123 |
+
tag_probabilities, image_names = output_queue.get()
|
124 |
+
tag_probabilities = tag_probabilities.tolist()
|
125 |
+
|
126 |
+
for per_image_tag_probabilities,image_name in zip(tag_probabilities, image_names, strict=True):
|
127 |
+
this_image_tags = []
|
128 |
+
this_image_tag_probabilities = []
|
129 |
+
for index, per_tag_probability in enumerate(per_image_tag_probabilities):
|
130 |
+
if per_tag_probability > 0.3:
|
131 |
+
tag = allowed_tags[index]
|
132 |
+
if "placeholder" not in tag:
|
133 |
+
this_image_tags.append(tag)
|
134 |
+
this_image_tag_probabilities.append(str(int(round(per_tag_probability, 3) * 1000)))
|
135 |
+
image_row = [image_name," ".join(this_image_tags)," ".join(this_image_tag_probabilities)]
|
136 |
+
writer.writerow(image_row)
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
def set_seed(seed: int = 42) -> None:
    """Seed NumPy, ``random`` and PyTorch, and make cuDNN deterministic.

    Args:
        seed: Value passed to every seeding function.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # cuDNN needs both switches: deterministic kernels on, benchmarking off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f"Random seed set as {seed}")
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
    # Script entry point: one process runs the model, a second one writes the
    # resulting tags to CSV; they communicate through two queues.
    output_file_name = "your_file.csv"
    # Directory holding the image files; ImageDataset opens
    # "<base_path>/<md5><file_ext>" for every CSV row.
    base_path = "/path/to/your/images"
    batch_size = 128

    set_seed()
    multiprocessing.set_start_method('spawn')
    output_queue = multiprocessing.Queue()
    # Holds one token while tagging is in progress; train() removes it when
    # it is done.
    tagging_is_running = multiprocessing.Queue(maxsize=5)
    tagging_is_running.put("Running!")

    # The whole pipeline is GPU-only.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available!")

    # prepare_model() already returns the model on the GPU.
    model = prepare_model()

    # read the csv file listing the images and their extensions
    train_csv = pd.read_csv('/path/to/a/list/of/files/and/their/extensions.csv')
    # base_path is a required ImageDataset argument; leaving it out raised a
    # TypeError before any image was processed.
    train_data = ImageDataset(
        train_csv, train=True, base_path=base_path
    )

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=6,
        pin_memory=True
    )

    process_writer = multiprocessing.Process(
        target=tag_writer,
        args=(tagging_is_running, output_queue, output_file_name),
    )
    process_writer.start()
    process_tagger = multiprocessing.Process(
        target=train,
        args=(tagging_is_running, model, train_loader, train_data, output_queue),
    )
    process_tagger.start()
    process_writer.join()
    process_tagger.join()
|