gotta catch em all
Browse files
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
viclip.py
CHANGED
@@ -11,20 +11,37 @@ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
|
11 |
from .viclip_vision import clip_joint_l14, clip_joint_b16
|
12 |
from .viclip_text import clip_text_l14, clip_text_b16
|
13 |
|
14 |
-
|
|
|
|
|
15 |
|
|
|
16 |
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
def __init__(self,
|
21 |
-
tokenizer=None,
|
22 |
-
size='l',
|
23 |
-
pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth"),
|
24 |
-
freeze_text=True):
|
25 |
-
super(ViCLIP, self).__init__()
|
26 |
if tokenizer:
|
27 |
self.tokenizer = tokenizer
|
|
|
|
|
28 |
else:
|
29 |
self.tokenizer = _Tokenizer()
|
30 |
self.max_txt_l = 32
|
@@ -217,6 +234,7 @@ class ViCLIP(nn.Module):
|
|
217 |
context_length=self.max_txt_l,
|
218 |
vocab_size=self.text_encoder_vocab_size,
|
219 |
checkpoint_num=0,
|
|
|
220 |
)
|
221 |
elif encoder_name == "vit_b16":
|
222 |
text_encoder = clip_text_b16(
|
@@ -224,6 +242,7 @@ class ViCLIP(nn.Module):
|
|
224 |
context_length=self.max_txt_l,
|
225 |
vocab_size=self.text_encoder_vocab_size,
|
226 |
checkpoint_num=0,
|
|
|
227 |
)
|
228 |
else:
|
229 |
raise NotImplementedError(f"Not implemented: {encoder_name}")
|
@@ -253,10 +272,10 @@ class ViCLIP(nn.Module):
|
|
253 |
return clip_feat
|
254 |
|
255 |
def get_predict_label(self, clip_feature, text_feats_tensor, top=5):
|
256 |
-
label_probs = (100.0 * clip_feature @ text_feats_tensor.T)
|
257 |
top_probs, top_labels = label_probs.cpu().topk(top, dim=-1)
|
258 |
return top_probs, top_labels
|
259 |
|
260 |
|
261 |
if __name__ =="__main__":
|
262 |
-
tokenizer = _Tokenizer()
|
|
|
11 |
from .viclip_vision import clip_joint_l14, clip_joint_b16
|
12 |
from .viclip_text import clip_text_l14, clip_text_b16
|
13 |
|
14 |
+
# from transformers import AutoModel
|
15 |
+
from transformers import PreTrainedModel #new
|
16 |
+
from transformers import PretrainedConfig
|
17 |
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
|
20 |
+
from .configuration_viclip import Config
|
21 |
+
# class ViCLIP(nn.Module):
|
22 |
+
class ViCLIP(PreTrainedModel):
|
23 |
+
_auto_class="AutoModel"
|
24 |
+
config_class=Config
|
25 |
+
|
26 |
+
def __init__(self,
|
27 |
+
# tokenizer=None, # config:PretrainedConfig is the only parameter
|
28 |
+
# size='l',
|
29 |
+
# pretrain=None,
|
30 |
+
# freeze_text=True,
|
31 |
+
config=PretrainedConfig()):
|
32 |
+
super(ViCLIP, self).__init__(config)
|
33 |
+
self.config=config
|
34 |
+
if 'size' in config.to_dict(): ###########
|
35 |
+
size=config.size
|
36 |
+
pretrain=None
|
37 |
+
tokenizer_path=config.tokenizer_path
|
38 |
+
tokenizer=None
|
39 |
+
freeze_text=True
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
if tokenizer:
|
42 |
self.tokenizer = tokenizer
|
43 |
+
elif tokenizer_path:
|
44 |
+
self.tokenizer = _Tokenizer(tokenizer_path)
|
45 |
else:
|
46 |
self.tokenizer = _Tokenizer()
|
47 |
self.max_txt_l = 32
|
|
|
234 |
context_length=self.max_txt_l,
|
235 |
vocab_size=self.text_encoder_vocab_size,
|
236 |
checkpoint_num=0,
|
237 |
+
tokenizer_path=None if not 'tokenizer_path' in self.config.to_dict() else self.config.tokenizer_path
|
238 |
)
|
239 |
elif encoder_name == "vit_b16":
|
240 |
text_encoder = clip_text_b16(
|
|
|
242 |
context_length=self.max_txt_l,
|
243 |
vocab_size=self.text_encoder_vocab_size,
|
244 |
checkpoint_num=0,
|
245 |
+
tokenizer_path=None if not 'tokenizer_path' in self.config.to_dict() else self.config.tokenizer_path
|
246 |
)
|
247 |
else:
|
248 |
raise NotImplementedError(f"Not implemented: {encoder_name}")
|
|
|
272 |
return clip_feat
|
273 |
|
274 |
def get_predict_label(self, clip_feature, text_feats_tensor, top=5):
|
275 |
+
label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1)
|
276 |
top_probs, top_labels = label_probs.cpu().topk(top, dim=-1)
|
277 |
return top_probs, top_labels
|
278 |
|
279 |
|
280 |
if __name__ =="__main__":
|
281 |
+
tokenizer = _Tokenizer()
|