Thouph committed
Commit fc68fe3
1 Parent(s): 5033331

Upload batched_inference.py

Files changed (1):
  1. batched_inference.py (+104 -112)
batched_inference.py CHANGED
@@ -1,7 +1,4 @@
-import csv
 import torch.multiprocessing as multiprocessing
-import pandas as pd
-import numpy as np
 import torchvision.transforms as transforms
 from torch import autocast
 from torch.utils.data import Dataset, DataLoader
@@ -9,79 +6,84 @@ from PIL import Image
 import torch
 from torchvision.transforms import InterpolationMode
 from tqdm import tqdm
-import random
 import json
+import os
 
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
-
 torch.autograd.set_detect_anomaly(False)
-
 torch.autograd.profiler.emit_nvtx(enabled=False)
 torch.autograd.profiler.profile(enabled=False)
 torch.backends.cudnn.benchmark = True
 
+
 class ImageDataset(Dataset):
-    def __init__(self, csv_file, train, base_path):
-
-        self.csv_file = csv_file
-        self.train = train
-        self.all_image_names = self.csv_file[:]['md5'].apply(str)
-        self.all_image_ext = self.csv_file[:]['file_ext'].apply(str)
-        self.train_size = len(self.csv_file)
-        self.base_path = base_path
-        if self.train == True:
-            print(f"Number of training images: {self.train_size}")
-        self.thin_transform = transforms.Compose([
-            transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[
-                0.48145466,
-                0.4578275,
-                0.40821073
-            ], std=[
-                0.26862954,
-                0.26130258,
-                0.27577711
-            ])  # Normalize image
-
-        ])
-        self.normal_transform = transforms.Compose([
-            transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[
-                0.48145466,
-                0.4578275,
-                0.40821073
-            ], std=[
-                0.26862954,
-                0.26130258,
-                0.27577711
-            ])  # Normalize image
-
-        ])
+    def __init__(self, image_folder_path, allowed_extensions):
+        self.allowed_extensions = allowed_extensions
+        self.all_image_paths, self.all_image_names, self.image_base_paths = self.get_image_paths(image_folder_path)
+        self.train_size = len(self.all_image_paths)
+        print(f"Number of images to be tagged: {self.train_size}")
+        self.thin_transform = transforms.Compose([
+            transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[
+                0.48145466,
+                0.4578275,
+                0.40821073
+            ], std=[
+                0.26862954,
+                0.26130258,
+                0.27577711
+            ])  # Normalize image
+        ])
+        self.normal_transform = transforms.Compose([
+            transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[
+                0.48145466,
+                0.4578275,
+                0.40821073
+            ], std=[
+                0.26862954,
+                0.26130258,
+                0.27577711
+            ])  # Normalize image
+
+        ])
+
+    def get_image_paths(self, folder_path):
+        image_paths = []
+        image_file_names = []
+        image_base_paths = []
+        for root, dirs, files in os.walk(folder_path):
+            for file in files:
+                if file.lower().split(".")[-1] in self.allowed_extensions:
+                    image_paths.append(os.path.abspath(os.path.join(root, file)))
+                    image_file_names.append(file.split(".")[0])
+                    image_base_paths.append(root)
+        return image_paths, image_file_names, image_base_paths
 
     def __len__(self):
-        return len(self.all_image_names)
+        return len(self.all_image_paths)
 
     def __getitem__(self, index):
-        image = Image.open(self.base_path+"/"+str(self.all_image_names[index])+str(self.all_image_ext[index])).convert("RGB")
-        ratio = image.height/image.width
+        image = Image.open(self.all_image_paths[index]).convert("RGB")
+        ratio = image.height / image.width
         if ratio > 2.0 or ratio < 0.5:
             image = self.thin_transform(image)
         else:
             image = self.normal_transform(image)
 
-
         return {
             'image': image,
-            "image_name": self.all_image_names[index]
+            "image_name": self.all_image_names[index],
+            "image_root": self.image_base_paths[index]
        }
 
 
-def prepare_model():
-    model = torch.load("path/to/your/model.pth").to("cuda")
+def prepare_model(model_path: str):
+    model = torch.load(model_path)
     model.to(memory_format=torch.channels_last)
     model = model.eval()
     return model
@@ -94,20 +96,19 @@ def train(tagging_is_running, model, dataloader, train_data, output_queue):
 
     with torch.no_grad():
         for i, data in tqdm(enumerate(dataloader), total=int(len(train_data) / dataloader.batch_size)):
-
-            data, image_names = data['image'].to("cuda"), data["image_name"]
+            this_data = data['image'].to("cuda")
             with autocast(device_type='cuda', dtype=torch.bfloat16):
-                outputs = model(data)
+                outputs = model(this_data)
 
             probabilities = torch.nn.functional.sigmoid(outputs)
-            output_queue.put((probabilities.to("cpu"), image_names))
+            output_queue.put((probabilities.to("cpu"), data["image_name"], data["image_root"]))
 
             counter += 1
     _ = tagging_is_running.get()
     print("Tagging finished!")
 
 
-def tag_writer(tagging_is_running, output_queue, output_file_name):
+def tag_writer(tagging_is_running, output_queue, threshold):
     with open("tags.json", "r") as file:
         tags = json.load(file)
     allowed_tags = sorted(tags)
@@ -116,78 +117,69 @@ def tag_writer(tagging_is_running, output_queue, output_file_name):
     tag_count = len(allowed_tags)
     assert tag_count == 7704, f"The length of loss scaling factor is not correct. Correct: 7704, current: {tag_count}"
 
-    with open(output_file_name, "w") as output_csv:
-        writer = csv.writer(output_csv)
-        writer.writerow(["image_name", "tags", "tag_probs"])
-        while not (tagging_is_running.qsize()>0 and output_queue.qsize()>0):
-            tag_probabilities, image_names = output_queue.get()
-            tag_probabilities = tag_probabilities.tolist()
-
-            for per_image_tag_probabilities,image_name in zip(tag_probabilities, image_names, strict=True):
-                this_image_tags = []
-                this_image_tag_probabilities = []
-                for index, per_tag_probability in enumerate(per_image_tag_probabilities):
-                    if per_tag_probability > 0.3:
-                        tag = allowed_tags[index]
-                        if "placeholder" not in tag:
-                            this_image_tags.append(tag)
-                            this_image_tag_probabilities.append(str(int(round(per_tag_probability, 3) * 1000)))
-                image_row = [image_name," ".join(this_image_tags)," ".join(this_image_tag_probabilities)]
-                writer.writerow(image_row)
-
-
-
-
+    while tagging_is_running.qsize() > 0 or output_queue.qsize() > 0:  # drain until tagging ends and the queue is empty
+        tag_probabilities, image_names, image_roots = output_queue.get()
+        tag_probabilities = tag_probabilities.tolist()
+
+        for per_image_tag_probabilities, image_name, image_root in zip(tag_probabilities, image_names, image_roots,
+                                                                       strict=True):
+            this_image_tags = []
+            this_image_tag_probabilities = []
+            for index, per_tag_probability in enumerate(per_image_tag_probabilities):
+                if per_tag_probability > threshold:
+                    tag = allowed_tags[index]
+                    if "placeholder" not in tag:
+                        this_image_tags.append(tag)
+                        this_image_tag_probabilities.append(str(int(round(per_tag_probability, 3) * 1000)))
+            output_file = os.path.join(image_root, os.path.splitext(image_name)[0] + ".txt")
+            with open(output_file, "w", encoding="utf-8") as this_output:
+                this_output.write(" ".join(this_image_tags))
+                this_output.write("\n")
+                this_output.write(" ".join(this_image_tag_probabilities))
+
+
+def main():
+    image_folder_path = "/path/to/your/folder/"
+    # all images should be in this folder and/or its subfolders.
+    # I will generate a text file for every image.
+    model_path = "/path/to/your/model.pth"
+    allowed_extensions = {"jpg", "jpeg", "png", "webp"}
+    batch_size = 64
+    # if you have a 24GB card, you can try 256
+    threshold = 0.3
 
-
-def set_seed(seed: int = 42) -> None:
-    np.random.seed(seed)
-    random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    # When running on the CuDNN backend, two further options must be set
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-    # Set a fixed value for the hash seed
-    print(f"Random seed set as {seed}")
-
-
-if __name__ == "__main__":
-    steps = 0
-    output_file_name = "your_file.csv"
-    set_seed()
     multiprocessing.set_start_method('spawn')
     output_queue = multiprocessing.Queue()
     tagging_is_running = multiprocessing.Queue(maxsize=5)
     tagging_is_running.put("Running!")
 
     # initialize the computation device
-    if torch.cuda.is_available():
-        device = torch.device('cuda')
-    else:
+    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available!")
 
-    model = prepare_model().to("cuda")
-    batch_size = 128
-
+    model = prepare_model(model_path).to("cuda")
 
     # read the training csv file
-    train_csv = pd.read_csv('/path/to/a/list/of/files/and/their/extensions.csv')
     # train dataset
-    train_data = ImageDataset(
-        train_csv, train=True
-    )
+    dataset = ImageDataset(image_folder_path, allowed_extensions)
 
-    train_loader = DataLoader(
-        train_data,
+    batched_loader = DataLoader(
+        dataset,
         batch_size=batch_size,
         shuffle=False,
-        num_workers=6,
-        pin_memory=True
+        num_workers=6,  # if you have a big batch size, a good cpu, and enough cpu memory, try 12
+        pin_memory=True,
+        drop_last=False,
     )
-    process_writer = multiprocessing.Process(target=tag_writer, args=(tagging_is_running, output_queue, output_file_name))
+    process_writer = multiprocessing.Process(target=tag_writer,
+                                             args=(tagging_is_running, output_queue, threshold))
     process_writer.start()
-    process_tagger = multiprocessing.Process(target=train, args=(tagging_is_running, model, train_loader, train_data, output_queue,))
+    process_tagger = multiprocessing.Process(target=train,
+                                             args=(tagging_is_running, model, batched_loader, dataset, output_queue,))
     process_tagger.start()
     process_writer.join()
     process_tagger.join()
+
+
+if __name__ == "__main__":
+    main()
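A note on the handoff this rewrite keeps: the tagger process pushes (probabilities, image_names, image_roots) tuples onto output_queue while tag_writer drains it, and a second queue acts purely as a "still running" flag that the tagger empties when it finishes. Polling qsize() for termination is inherently racy (the writer can observe an empty data queue in the instant before the flag is cleared), and qsize() is not implemented on every platform. A common, more robust variant ends the stream with a sentinel value on the data queue itself. Below is a minimal, self-contained sketch of that variant; the producer/consumer names and toy payloads are illustrative, not part of the commit.

import torch.multiprocessing as multiprocessing

SENTINEL = None  # reserved value that marks the end of the stream


def producer(out_q):
    for item in range(3):  # stand-in for batched model outputs
        out_q.put(item)
    out_q.put(SENTINEL)  # tell the consumer no more data is coming


def consumer(out_q):
    while True:
        item = out_q.get()  # blocks until something arrives
        if item is SENTINEL:
            break
        print("got", item)


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")
    queue = multiprocessing.Queue()
    writer = multiprocessing.Process(target=consumer, args=(queue,))
    tagger = multiprocessing.Process(target=producer, args=(queue,))
    writer.start()
    tagger.start()
    tagger.join()
    writer.join()

The sentinel needs no second queue and cannot miss the shutdown signal, at the cost of reserving one value (None here) that real payloads never use.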
 
 
 
 
 
 
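For each image, the new tag_writer drops a sidecar .txt next to the source file: the first line holds the space-separated tags that cleared the threshold, the second line the matching probabilities rounded to three decimals and scaled by 1000 into integers. A minimal sketch of reading one such file back into (tag, probability) pairs; the helper name and example path are hypothetical:

def read_sidecar(path):
    # Parse one tag_writer output file: tags on line 1, scaled probabilities on line 2.
    with open(path, "r", encoding="utf-8") as f:
        tags = f.readline().split()
        scaled = f.readline().split()
    # Undo the int(round(p, 3) * 1000) scaling applied by tag_writer.
    return [(tag, int(s) / 1000.0) for tag, s in zip(tags, scaled)]


if __name__ == "__main__":
    for tag, prob in read_sidecar("/path/to/your/folder/example.txt"):
        print(f"{tag}\t{prob:.3f}")

Note that the space-joined format only round-trips cleanly if individual tags contain no spaces.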