Thouph committed
Commit 38d5616
1 Parent(s): d5e3cf2

Upload batched_inference.py

Files changed (1)
  1. batched_inference.py +193 -0
batched_inference.py ADDED
@@ -0,0 +1,193 @@
+ import csv
+ import torch.multiprocessing as multiprocessing
+ import pandas as pd
+ import numpy as np
+ import torchvision.transforms as transforms
+ from torch import autocast
+ from torch.utils.data import Dataset, DataLoader
+ from PIL import Image
+ import torch
+ from torchvision.transforms import InterpolationMode
+ from tqdm import tqdm
+ import random
+ import json
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ torch.autograd.set_detect_anomaly(False)
+
+ torch.autograd.profiler.emit_nvtx(enabled=False)
+ torch.autograd.profiler.profile(enabled=False)
+ torch.backends.cudnn.benchmark = True
+
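+ # The settings above enable TF32 matmuls and cuDNN autotuning, favouring inference
+ # throughput over bit-exact float32 results.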
+ class ImageDataset(Dataset):
+     def __init__(self, csv_file, train, base_path):
+
+         self.csv_file = csv_file
+         self.train = train
+         self.all_image_names = self.csv_file['md5'].apply(str)
+         self.all_image_ext = self.csv_file['file_ext'].apply(str)
+         self.train_size = len(self.csv_file)
+         self.base_path = base_path
+         if self.train:
+             print(f"Number of training images: {self.train_size}")
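+         # The mean/std values below are the CLIP image normalization constants. Two
+         # pipelines are kept: normal_transform stretches the image straight to 224x224,
+         # while thin_transform (resize shorter side, then center crop) is applied in
+         # __getitem__ to very tall or very wide images so they are cropped instead of
+         # being heavily distorted.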
+         self.thin_transform = transforms.Compose([
+             transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[
+                 0.48145466,
+                 0.4578275,
+                 0.40821073
+             ], std=[
+                 0.26862954,
+                 0.26130258,
+                 0.27577711
+             ])  # Normalize image
+         ])
+         self.normal_transform = transforms.Compose([
+             transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[
+                 0.48145466,
+                 0.4578275,
+                 0.40821073
+             ], std=[
+                 0.26862954,
+                 0.26130258,
+                 0.27577711
+             ])  # Normalize image
+         ])
+
+     def __len__(self):
+         return len(self.all_image_names)
+
+     def __getitem__(self, index):
+         image = Image.open(self.base_path + "/" + str(self.all_image_names[index]) + str(self.all_image_ext[index])).convert("RGB")
+         ratio = image.height / image.width
+         if ratio > 2.0 or ratio < 0.5:
+             image = self.thin_transform(image)
+         else:
+             image = self.normal_transform(image)
+
+         return {
+             'image': image,
+             "image_name": self.all_image_names[index]
+         }
+
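+ # Note: the CSV handed to ImageDataset must provide 'md5' and 'file_ext' columns; the
+ # image path is built as base_path + "/" + md5 + file_ext, so the extension is expected
+ # to include its leading dot (e.g. ".jpg").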
+ def prepare_model():
+     model = torch.load("path/to/your/model.pth").to("cuda")
+     model.to(memory_format=torch.channels_last)
+     model = model.eval()
+     return model
+
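+ # prepare_model expects a fully pickled nn.Module (saved with torch.save(model, path)),
+ # not a state_dict; "path/to/your/model.pth" is a placeholder. The channels_last memory
+ # format generally speeds up convolutions under mixed precision on recent GPUs.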
+ def train(tagging_is_running, model, dataloader, train_data, output_queue):
+     print('Begin tagging')
+     model.eval()
+     counter = 0
+
+     with torch.no_grad():
+         for i, data in tqdm(enumerate(dataloader), total=int(len(train_data) / dataloader.batch_size)):
+
+             data, image_names = data['image'].to("cuda"), data["image_name"]
+             with autocast(device_type='cuda', dtype=torch.bfloat16):
+                 outputs = model(data)
+
+             probabilities = torch.nn.functional.sigmoid(outputs)
+             output_queue.put((probabilities.to("cpu"), image_names))
+
+             counter += 1
+     _ = tagging_is_running.get()
+     print("Tagging finished!")
+
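+ # The function above is the producer despite its name: it pushes (probabilities,
+ # image_names) batches onto output_queue and, once the dataloader is exhausted, pops the
+ # token from tagging_is_running so the writer below knows no more batches are coming.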
+ def tag_writer(tagging_is_running, output_queue, output_file_name):
+     with open("tags.json", "r") as file:
+         tags = json.load(file)
+     allowed_tags = sorted(tags)
+     del tags
+     allowed_tags.extend(["placeholder0", "placeholder1", "placeholder2"])
+     tag_count = len(allowed_tags)
+     assert tag_count == 7704, f"Wrong number of tags. Expected: 7704, current: {tag_count}"
+
+     with open(output_file_name, "w", newline="") as output_csv:
+         writer = csv.writer(output_csv)
+         writer.writerow(["image_name", "tags", "tag_probs"])
+         # Keep consuming while the tagger still holds its token or results remain queued.
+         while tagging_is_running.qsize() > 0 or output_queue.qsize() > 0:
+             tag_probabilities, image_names = output_queue.get()
+             tag_probabilities = tag_probabilities.tolist()
+
+             for per_image_tag_probabilities, image_name in zip(tag_probabilities, image_names, strict=True):
+                 this_image_tags = []
+                 this_image_tag_probabilities = []
+                 for index, per_tag_probability in enumerate(per_image_tag_probabilities):
+                     if per_tag_probability > 0.3:
+                         tag = allowed_tags[index]
+                         if "placeholder" not in tag:
+                             this_image_tags.append(tag)
+                             this_image_tag_probabilities.append(str(int(round(per_tag_probability, 3) * 1000)))
+                 image_row = [image_name, " ".join(this_image_tags), " ".join(this_image_tag_probabilities)]
+                 writer.writerow(image_row)
+
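+ # Output format: one CSV row per image holding the image name, the tags whose sigmoid
+ # probability exceeds 0.3 (space separated), and the matching probabilities rescaled to
+ # integers in the 0-1000 range.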
+ def set_seed(seed: int = 42) -> None:
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     # When running on the CuDNN backend, two further options must be set
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     # Set a fixed value for the hash seed
+     print(f"Random seed set as {seed}")
+
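+ # Note that set_seed turns torch.backends.cudnn.benchmark back off, overriding the True
+ # value set near the top of the file, so determinism takes priority here.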
+ if __name__ == "__main__":
+     steps = 0
+     output_file_name = "your_file.csv"
+     set_seed()
+     multiprocessing.set_start_method('spawn')
+     output_queue = multiprocessing.Queue()
+     tagging_is_running = multiprocessing.Queue(maxsize=5)
+     tagging_is_running.put("Running!")
+
+     # initialize the computation device
+     if torch.cuda.is_available():
+         device = torch.device('cuda')
+     else:
+         raise RuntimeError("CUDA is not available!")
+
+     model = prepare_model().to("cuda")
+     batch_size = 128
+
+     # read the csv file listing the images (md5 and file extension) to tag
+     train_csv = pd.read_csv('/path/to/a/list/of/files/and/their/extensions.csv')
+     # dataset of images to tag
+     train_data = ImageDataset(
+         train_csv, train=True, base_path="/path/to/your/image/folder"  # base_path is required; placeholder path
+     )
+
+     train_loader = DataLoader(
+         train_data,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=6,
+         pin_memory=True
+     )
+     process_writer = multiprocessing.Process(target=tag_writer, args=(tagging_is_running, output_queue, output_file_name))
+     process_writer.start()
+     process_tagger = multiprocessing.Process(target=train, args=(tagging_is_running, model, train_loader, train_data, output_queue,))
+     process_tagger.start()
+     process_writer.join()
+     process_tagger.join()