qingy2019 committed on
Commit
e3d2cee
1 Parent(s): e5d6152

gotta catch em all

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. viclip.py +30 -11
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
viclip.py CHANGED
@@ -11,20 +11,37 @@ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
11
  from .viclip_vision import clip_joint_l14, clip_joint_b16
12
  from .viclip_text import clip_text_l14, clip_text_b16
13
 
14
- logger = logging.getLogger(__name__)
 
 
15
 
 
16
 
17
- class ViCLIP(nn.Module):
18
- """docstring for ViCLIP"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- def __init__(self,
21
- tokenizer=None,
22
- size='l',
23
- pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth"),
24
- freeze_text=True):
25
- super(ViCLIP, self).__init__()
26
  if tokenizer:
27
  self.tokenizer = tokenizer
 
 
28
  else:
29
  self.tokenizer = _Tokenizer()
30
  self.max_txt_l = 32
@@ -217,6 +234,7 @@ class ViCLIP(nn.Module):
217
  context_length=self.max_txt_l,
218
  vocab_size=self.text_encoder_vocab_size,
219
  checkpoint_num=0,
 
220
  )
221
  elif encoder_name == "vit_b16":
222
  text_encoder = clip_text_b16(
@@ -224,6 +242,7 @@ class ViCLIP(nn.Module):
224
  context_length=self.max_txt_l,
225
  vocab_size=self.text_encoder_vocab_size,
226
  checkpoint_num=0,
 
227
  )
228
  else:
229
  raise NotImplementedError(f"Not implemented: {encoder_name}")
@@ -253,10 +272,10 @@ class ViCLIP(nn.Module):
253
  return clip_feat
254
 
255
  def get_predict_label(self, clip_feature, text_feats_tensor, top=5):
256
- label_probs = (100.0 * clip_feature @ text_feats_tensor.T)
257
  top_probs, top_labels = label_probs.cpu().topk(top, dim=-1)
258
  return top_probs, top_labels
259
 
260
 
261
  if __name__ =="__main__":
262
- tokenizer = _Tokenizer()
 
11
  from .viclip_vision import clip_joint_l14, clip_joint_b16
12
  from .viclip_text import clip_text_l14, clip_text_b16
13
 
14
+ # from transformers import AutoModel
15
+ from transformers import PreTrainedModel #new
16
+ from transformers import PretrainedConfig
17
 
18
+ logger = logging.getLogger(__name__)
19
 
20
+ from .configuration_viclip import Config
21
+ # class ViCLIP(nn.Module):
22
+ class ViCLIP(PreTrainedModel):
23
+ _auto_class="AutoModel"
24
+ config_class=Config
25
+
26
+ def __init__(self,
27
+ # tokenizer=None, # config:PretrainedConfig is the only parameter
28
+ # size='l',
29
+ # pretrain=None,
30
+ # freeze_text=True,
31
+ config=PretrainedConfig()):
32
+ super(ViCLIP, self).__init__(config)
33
+ self.config=config
34
+ if 'size' in config.to_dict(): ###########
35
+ size=config.size
36
+ pretrain=None
37
+ tokenizer_path=config.tokenizer_path
38
+ tokenizer=None
39
+ freeze_text=True
40
 
 
 
 
 
 
 
41
  if tokenizer:
42
  self.tokenizer = tokenizer
43
+ elif tokenizer_path:
44
+ self.tokenizer = _Tokenizer(tokenizer_path)
45
  else:
46
  self.tokenizer = _Tokenizer()
47
  self.max_txt_l = 32
 
234
  context_length=self.max_txt_l,
235
  vocab_size=self.text_encoder_vocab_size,
236
  checkpoint_num=0,
237
+ tokenizer_path=None if not 'tokenizer_path' in self.config.to_dict() else self.config.tokenizer_path
238
  )
239
  elif encoder_name == "vit_b16":
240
  text_encoder = clip_text_b16(
 
242
  context_length=self.max_txt_l,
243
  vocab_size=self.text_encoder_vocab_size,
244
  checkpoint_num=0,
245
+ tokenizer_path=None if not 'tokenizer_path' in self.config.to_dict() else self.config.tokenizer_path
246
  )
247
  else:
248
  raise NotImplementedError(f"Not implemented: {encoder_name}")
 
272
  return clip_feat
273
 
274
  def get_predict_label(self, clip_feature, text_feats_tensor, top=5):
275
+ label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1)
276
  top_probs, top_labels = label_probs.cpu().topk(top, dim=-1)
277
  return top_probs, top_labels
278
 
279
 
280
  if __name__ =="__main__":
281
+ tokenizer = _Tokenizer()