sonoisa committed
Commit f5d1536 · 1 Parent(s): d15c1bd

Remove unused model

Files changed (1)
  1. app.py +137 -130
app.py CHANGED
@@ -173,130 +173,134 @@ class ClipTextModel(nn.Module):
         torch.save(self.output_linear.state_dict(), os.path.join(output_dir, "output_linear.bin"))


-class ClipVisionModel(nn.Module):
-    def __init__(self, model_name_or_path, device=None):
-        super(ClipVisionModel, self).__init__()
-
-        if os.path.exists(model_name_or_path):
-            # load from file system
-            visual_projection_state_dict = torch.load(os.path.join(model_name_or_path, "visual_projection.bin"))
-        else:
-            # download from the Hugging Face model hub
-            filename = hf_hub_download(repo_id=model_name_or_path, filename="visual_projection.bin")
-            visual_projection_state_dict = torch.load(filename)
-
-        self.model = transformers.CLIPVisionModel.from_pretrained(model_name_or_path)
-        config = self.model.config
-
-        self.feature_extractor = transformers.CLIPFeatureExtractor.from_pretrained(model_name_or_path)
-
-        vision_embed_dim = config.hidden_size
-        projection_dim = 512
-
-        self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
-        self.visual_projection.load_state_dict(visual_projection_state_dict)
-
-        self.eval()
-
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = torch.device(device)
-        self.to(self.device)
-
-    def forward(
-        self,
-        pixel_values=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        output_states = self.model(
-            pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        image_embeds = self.visual_projection(output_states[1])
-
-        return image_embeds
-
-    @torch.no_grad()
-    def encode_image(self, images, batch_size=8):
-        all_embeddings = []
-        iterator = range(0, len(images), batch_size)
-        for batch_idx in iterator:
-            batch = images[batch_idx:batch_idx + batch_size]
-
-            encoded_input = self.feature_extractor(batch, return_tensors="pt").to(self.device)
-            model_output = self(**encoded_input)
-            image_embeddings = model_output.cpu()
-
-            all_embeddings.extend(image_embeddings)
-
-        # return torch.stack(all_embeddings).numpy()
-        return torch.stack(all_embeddings)
-
-    @staticmethod
-    def remove_alpha_channel(image):
-        image.convert("RGBA")
-        alpha = image.convert('RGBA').split()[-1]
-        background = Image.new("RGBA", image.size, (255, 255, 255))
-        background.paste(image, mask=alpha)
-        image = background.convert("RGB")
-        return image
-
-    def save(self, output_dir):
-        self.model.save_pretrained(output_dir)
-        self.feature_extractor.save_pretrained(output_dir)
-        torch.save(self.visual_projection.state_dict(), os.path.join(output_dir, "visual_projection.bin"))
-
-
-class ClipModel(nn.Module):
-    def __init__(self, model_name_or_path, device=None):
-        super(ClipModel, self).__init__()
-
-        if os.path.exists(model_name_or_path):
-            # load from file system
-            repo_dir = model_name_or_path
-        else:
-            # download from the Hugging Face model hub
-            repo_dir = snapshot_download(model_name_or_path)
-
-        self.text_model = ClipTextModel(repo_dir, device=device)
-        self.vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"), device=device)
-
-        with torch.no_grad():
-            logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
-            logit_scale.set_(torch.load(os.path.join(repo_dir, "logit_scale.bin")).clone().cpu())
-            self.logit_scale = logit_scale
-
-        self.eval()
-
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = torch.device(device)
-        self.to(self.device)
-
-    def forward(self, pixel_values, input_ids, attention_mask, token_type_ids):
-        image_features = self.vision_model(pixel_values=pixel_values)
-        text_features = self.text_model(input_ids=input_ids,
-                                        attention_mask=attention_mask,
-                                        token_type_ids=token_type_ids)[0]
-
-        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
-
-        logit_scale = self.logit_scale.exp()
-        logits_per_image = logit_scale * image_features @ text_features.t()
-        logits_per_text = logits_per_image.t()
-
-        return logits_per_image, logits_per_text
-
-    def save(self, output_dir):
-        torch.save(self.logit_scale, os.path.join(output_dir, "logit_scale.bin"))
-        self.text_model.save(output_dir)
-        self.vision_model.save(os.path.join(output_dir, "vision_model"))
-
+# class ClipVisionModel(nn.Module):
+#     def __init__(self, model_name_or_path, device=None):
+#         super(ClipVisionModel, self).__init__()
+
+#         if os.path.exists(model_name_or_path):
+#             # load from file system
+#             visual_projection_state_dict = torch.load(os.path.join(model_name_or_path, "visual_projection.bin"))
+#         else:
+#             # download from the Hugging Face model hub
+#             filename = hf_hub_download(repo_id=model_name_or_path, filename="visual_projection.bin")
+#             visual_projection_state_dict = torch.load(filename)
+
+#         self.model = transformers.CLIPVisionModel.from_pretrained(model_name_or_path)
+#         config = self.model.config
+
+#         self.feature_extractor = transformers.CLIPFeatureExtractor.from_pretrained(model_name_or_path)
+
+#         vision_embed_dim = config.hidden_size
+#         projection_dim = 512
+
+#         self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
+#         self.visual_projection.load_state_dict(visual_projection_state_dict)
+
+#         self.eval()
+
+#         if device is None:
+#             device = "cuda" if torch.cuda.is_available() else "cpu"
+#         self.device = torch.device(device)
+#         self.to(self.device)
+
+#     def forward(
+#         self,
+#         pixel_values=None,
+#         output_attentions=None,
+#         output_hidden_states=None,
+#         return_dict=None,
+#     ):
+#         output_states = self.model(
+#             pixel_values=pixel_values,
+#             output_attentions=output_attentions,
+#             output_hidden_states=output_hidden_states,
+#             return_dict=return_dict,
+#         )
+#         image_embeds = self.visual_projection(output_states[1])
+
+#         return image_embeds
+
+#     @torch.no_grad()
+#     def encode_image(self, images, batch_size=8):
+#         all_embeddings = []
+#         iterator = range(0, len(images), batch_size)
+#         for batch_idx in iterator:
+#             batch = images[batch_idx:batch_idx + batch_size]
+
+#             encoded_input = self.feature_extractor(batch, return_tensors="pt").to(self.device)
+#             model_output = self(**encoded_input)
+#             image_embeddings = model_output.cpu()
+
+#             all_embeddings.extend(image_embeddings)
+
+#         # return torch.stack(all_embeddings).numpy()
+#         return torch.stack(all_embeddings)
+
+#     @staticmethod
+#     def remove_alpha_channel(image):
+#         image.convert("RGBA")
+#         alpha = image.convert('RGBA').split()[-1]
+#         background = Image.new("RGBA", image.size, (255, 255, 255))
+#         background.paste(image, mask=alpha)
+#         image = background.convert("RGB")
+#         return image
+
+#     def save(self, output_dir):
+#         self.model.save_pretrained(output_dir)
+#         self.feature_extractor.save_pretrained(output_dir)
+#         torch.save(self.visual_projection.state_dict(), os.path.join(output_dir, "visual_projection.bin"))
+
+
+# class ClipModel(nn.Module):
+#     def __init__(self, model_name_or_path, device=None):
+#         super(ClipModel, self).__init__()
+
+#         if os.path.exists(model_name_or_path):
+#             # load from file system
+#             repo_dir = model_name_or_path
+#         else:
+#             # download from the Hugging Face model hub
+#             repo_dir = snapshot_download(model_name_or_path)
+
+#         self.text_model = ClipTextModel(repo_dir, device=device)
+#         self.vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"), device=device)
+
+#         with torch.no_grad():
+#             logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
+#             logit_scale.set_(torch.load(os.path.join(repo_dir, "logit_scale.bin")).clone().cpu())
+#             self.logit_scale = logit_scale
+
+#         self.eval()
+
+#         if device is None:
+#             device = "cuda" if torch.cuda.is_available() else "cpu"
+#         self.device = torch.device(device)
+#         self.to(self.device)
+
+#     def forward(self, pixel_values, input_ids, attention_mask, token_type_ids):
+#         image_features = self.vision_model(pixel_values=pixel_values)
+#         text_features = self.text_model(input_ids=input_ids,
+#                                         attention_mask=attention_mask,
+#                                         token_type_ids=token_type_ids)[0]
+
+#         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+#         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+#         logit_scale = self.logit_scale.exp()
+#         logits_per_image = logit_scale * image_features @ text_features.t()
+#         logits_per_text = logits_per_image.t()
+
+#         return logits_per_image, logits_per_text
+
+#     def save(self, output_dir):
+#         torch.save(self.logit_scale, os.path.join(output_dir, "logit_scale.bin"))
+#         self.text_model.save(output_dir)
+#         self.vision_model.save(os.path.join(output_dir, "vision_model"))
+
+
+class DummyClipModel:
+    def __init__(self, text_model):
+        self.text_model = text_model
 
 def encode_text(text, model):
     text = normalize_text(text)
@@ -304,10 +308,10 @@ def encode_text(text, model):
     return text_embedding


-def encode_image(image_filename, model):
-    image = Image.open(image_filename)
-    image_embedding = model.vision_model.encode_image([image]).numpy()
-    return image_embedding
+# def encode_image(image_filename, model):
+#     image = Image.open(image_filename)
+#     image_embedding = model.vision_model.encode_image([image]).numpy()
+#     return image_embedding


 st.title("いらすと検索(日本語CLIPゼロショット)")
@@ -316,7 +320,9 @@ description_text = st.empty()
 if "model" not in st.session_state:
     description_text.text("日本語CLIPモデル読み込み中... ")
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
+    text_model = ClipTextModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
+    # model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
+    model = DummyClipModel(text_model)
     st.session_state.model = model

     print("extract dataset")
@@ -325,7 +331,8 @@ if "model" not in st.session_state:
     )

     print("loading dataset")
-    df = pq.read_table("clip_zeroshot_irasuto_items_20210224.parquet").to_pandas()
+    df = pq.read_table("clip_zeroshot_irasuto_items_20210224.parquet",
+                       columns=["page", "description", "image_url", "image_vector"]).to_pandas()
     st.session_state.df = df

     # sentence_vectors = np.stack(df["sentence_vector"])
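
After this commit only the text encoder is loaded at runtime: image embeddings are read precomputed from the parquet's image_vector column, so zero-shot search reduces to comparing the query's text embedding against those stored vectors. The sketch below illustrates that matching step; it is not code from this commit. The function name search_images and the top_k parameter are illustrative, and encode_text is assumed to return a (1, 512) NumPy array (the removed encode_image returned NumPy arrays the same way).

import numpy as np

def search_images(query, model, df, top_k=10):
    # Embed the query with the retained Japanese CLIP text encoder via app.py's encode_text helper
    # (assumed here to return a (1, 512) NumPy array).
    text_embedding = encode_text(query, model)
    text_embedding = text_embedding / np.linalg.norm(text_embedding, axis=1, keepdims=True)

    # Precomputed image embeddings loaded from clip_zeroshot_irasuto_items_20210224.parquet.
    image_vectors = np.stack(df["image_vector"])
    image_vectors = image_vectors / np.linalg.norm(image_vectors, axis=1, keepdims=True)

    # Cosine similarity is the dot product of L2-normalized vectors.
    scores = image_vectors @ text_embedding[0]
    top_indices = np.argsort(-scores)[:top_k]
    return df.iloc[top_indices], scores[top_indices]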