calpt commited on
Commit
330fe7c
·
1 Parent(s): a47964c

Upload 2 files

Browse files
Files changed (2) hide show
  1. config.json +3 -12
  2. modeling_clip.py +129 -0
config.json CHANGED
@@ -4,17 +4,14 @@
4
  "architectures": [
5
  "OpenCLIPVisionTextDualEncoderModel"
6
  ],
 
 
 
7
  "logit_scale_init_value": 2.6592,
8
  "model_type": "vision-text-dual-encoder",
9
  "projection_dim": 512,
10
  "text_config": {
11
  "_name_or_path": "xlm-roberta-base",
12
- "adapters": {
13
- "adapters": {},
14
- "config_map": {},
15
- "fusion_config_map": {},
16
- "fusions": {}
17
- },
18
  "add_cross_attention": false,
19
  "architectures": [
20
  "XLMRobertaForMaskedLM"
@@ -99,12 +96,6 @@
99
  "transformers_version": null,
100
  "vision_config": {
101
  "_name_or_path": "",
102
- "adapters": {
103
- "adapters": {},
104
- "config_map": {},
105
- "fusion_config_map": {},
106
- "fusions": {}
107
- },
108
  "add_cross_attention": false,
109
  "architectures": null,
110
  "attention_dropout": 0.0,
 
4
  "architectures": [
5
  "OpenCLIPVisionTextDualEncoderModel"
6
  ],
7
+ "auto_map": {
8
+ "AutoModel": "modeling_clip.OpenCLIPVisionTextDualEncoderModel"
9
+ },
10
  "logit_scale_init_value": 2.6592,
11
  "model_type": "vision-text-dual-encoder",
12
  "projection_dim": 512,
13
  "text_config": {
14
  "_name_or_path": "xlm-roberta-base",
 
 
 
 
 
 
15
  "add_cross_attention": false,
16
  "architectures": [
17
  "XLMRobertaForMaskedLM"
 
96
  "transformers_version": null,
97
  "vision_config": {
98
  "_name_or_path": "",
 
 
 
 
 
 
99
  "add_cross_attention": false,
100
  "architectures": null,
101
  "attention_dropout": 0.0,
modeling_clip.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import PreTrainedModel, VisionTextDualEncoderConfig, VisionTextDualEncoderModel
6
+ from transformers.models.vision_text_dual_encoder.modeling_vision_text_dual_encoder import clip_loss, CLIPOutput
7
+
8
+
9
+ class MeanPooler(nn.Module):
10
+ """Mean pooling"""
11
+
12
+ def forward(self, x, attention_mask):
13
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
14
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
15
+
16
+
17
+ class OpenCLIPVisionTextDualEncoderModel(VisionTextDualEncoderModel):
18
+ def __init__(
19
+ self,
20
+ config: Optional[VisionTextDualEncoderConfig] = None,
21
+ vision_model: Optional[PreTrainedModel] = None,
22
+ text_model: Optional[PreTrainedModel] = None,
23
+ add_text_model_pooling_layer: bool = False,
24
+ ):
25
+ super().__init__(config, vision_model, text_model)
26
+
27
+ # Remove text pooling layer
28
+ if not add_text_model_pooling_layer:
29
+ self.text_model.pooler = None
30
+
31
+ # Add mean pooling
32
+ self.pooler = MeanPooler()
33
+ # Overwrite text projection
34
+ hidden_size = (self.text_embed_dim + self.projection_dim) // 2
35
+ self.text_projection = nn.Sequential(
36
+ nn.Linear(self.text_embed_dim, hidden_size, bias=False),
37
+ nn.GELU(),
38
+ nn.Linear(hidden_size, self.projection_dim, bias=False),
39
+ )
40
+
41
+ def get_text_features(
42
+ self,
43
+ input_ids=None,
44
+ attention_mask=None,
45
+ position_ids=None,
46
+ token_type_ids=None,
47
+ output_attentions=None,
48
+ output_hidden_states=None,
49
+ return_dict=None,
50
+ ):
51
+ text_outputs = self.text_model(
52
+ input_ids=input_ids,
53
+ attention_mask=attention_mask,
54
+ position_ids=position_ids,
55
+ token_type_ids=token_type_ids,
56
+ output_attentions=output_attentions,
57
+ output_hidden_states=output_hidden_states,
58
+ return_dict=return_dict,
59
+ )
60
+
61
+ pooled_output = self.pooler(text_outputs, attention_mask)
62
+ text_features = self.text_projection(pooled_output)
63
+
64
+ return text_features
65
+
66
+ def forward(
67
+ self,
68
+ input_ids: Optional[torch.LongTensor] = None,
69
+ pixel_values: Optional[torch.FloatTensor] = None,
70
+ attention_mask: Optional[torch.Tensor] = None,
71
+ position_ids: Optional[torch.LongTensor] = None,
72
+ return_loss: Optional[bool] = None,
73
+ token_type_ids: Optional[torch.LongTensor] = None,
74
+ output_attentions: Optional[bool] = None,
75
+ output_hidden_states: Optional[bool] = None,
76
+ return_dict: Optional[bool] = None,
77
+ ) -> Union[Tuple[torch.Tensor], CLIPOutput]:
78
+
79
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
80
+
81
+ vision_outputs = self.vision_model(
82
+ pixel_values=pixel_values,
83
+ output_attentions=output_attentions,
84
+ output_hidden_states=output_hidden_states,
85
+ return_dict=return_dict,
86
+ )
87
+
88
+ text_outputs = self.text_model(
89
+ input_ids=input_ids,
90
+ attention_mask=attention_mask,
91
+ token_type_ids=token_type_ids,
92
+ position_ids=position_ids,
93
+ output_attentions=output_attentions,
94
+ output_hidden_states=output_hidden_states,
95
+ return_dict=return_dict,
96
+ )
97
+
98
+ image_embeds = vision_outputs[1] # pooler_output
99
+ image_embeds = self.visual_projection(image_embeds)
100
+
101
+ pooled_output = self.pooler(text_outputs, attention_mask)
102
+ text_embeds = self.text_projection(pooled_output)
103
+
104
+ # normalized features
105
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
106
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
107
+
108
+ # cosine similarity as logits
109
+ logit_scale = self.logit_scale.exp()
110
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
111
+ logits_per_image = logits_per_text.T
112
+
113
+ loss = None
114
+ if return_loss:
115
+ loss = clip_loss(logits_per_text)
116
+
117
+ if not return_dict:
118
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
119
+ return ((loss,) + output) if loss is not None else output
120
+
121
+ return CLIPOutput(
122
+ loss=loss,
123
+ logits_per_image=logits_per_image,
124
+ logits_per_text=logits_per_text,
125
+ text_embeds=text_embeds,
126
+ image_embeds=image_embeds,
127
+ text_model_output=text_outputs,
128
+ vision_model_output=vision_outputs,
129
+ )