haoningwu committed on
Commit 60b48b9 · verified · 1 Parent(s): 9bc9c76

Upload 14 files
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_class_name": "PNDMScheduler",
+   "_diffusers_version": "0.6.0",
+   "beta_end": 0.012,
+   "beta_schedule": "scaled_linear",
+   "beta_start": 0.00085,
+   "num_train_timesteps": 1000,
+   "set_alpha_to_one": false,
+   "skip_prk_steps": true,
+   "steps_offset": 1,
+   "trained_betas": null,
+   "clip_sample": false
+ }
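
The scheduler is a stock diffusers PNDM setup: scaled-linear betas from 0.00085 to 0.012 over 1000 training timesteps, with the PRK warm-up steps skipped. A minimal, hedged sketch of loading it back from this subfolder — the repo id `your-org/your-repo` is a placeholder, not the actual upload target:

```python
# Minimal sketch: rebuild the PNDM scheduler from the config above.
# "your-org/your-repo" is a placeholder for wherever these 14 files are hosted.
from diffusers import PNDMScheduler

scheduler = PNDMScheduler.from_pretrained("your-org/your-repo", subfolder="scheduler")
scheduler.set_timesteps(num_inference_steps=50)  # skip_prk_steps=True -> plain PLMS steps
print(scheduler.timesteps[:5])
```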
text_encoder_BiomedCLIP/config.json ADDED
@@ -0,0 +1,103 @@
+ {
+   "_name_or_path": "",
+   "architectures": [
+     "BiomedCLIPModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_biomed_clip.BiomedCLIPConfig",
+     "AutoProcessor": "processing_biomed_clip.BiomedCLIPProcessor",
+     "AutoModel": "modeling_biomed_clip.BiomedCLIPModel",
+     "AutoModelForImageClassification": "modeling_biomed_clip.BiomedCLIPForImageClassification"
+   },
+   "initializer_factor": 1.0,
+   "logit_scale_init_value": 4.4454,
+   "model_type": "clip",
+   "projection_dim": 512,
+   "text_config": {
+     "attention_probs_dropout_prob": 0.1,
+     "gradient_checkpointing": false,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.1,
+     "hidden_size": 768,
+     "initializer_range": 0.02,
+     "intermediate_size": 3072,
+     "layer_norm_eps": 1e-12,
+     "max_position_embeddings": 512,
+     "model_type": "bert",
+     "num_attention_heads": 12,
+     "num_hidden_layers": 12,
+     "pad_token_id": 0,
+     "position_embedding_type": "absolute",
+     "transformers_version": "4.6.0.dev0",
+     "type_vocab_size": 2,
+     "use_cache": true,
+     "vocab_size": 30522
+   },
+   "text_config_dict": {
+     "attention_probs_dropout_prob": 0.1,
+     "gradient_checkpointing": false,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.1,
+     "hidden_size": 768,
+     "initializer_range": 0.02,
+     "intermediate_size": 3072,
+     "layer_norm_eps": 1e-12,
+     "max_position_embeddings": 512,
+     "model_type": "bert",
+     "num_attention_heads": 12,
+     "num_hidden_layers": 12,
+     "pad_token_id": 0,
+     "position_embedding_type": "absolute",
+     "transformers_version": "4.6.0.dev0",
+     "type_vocab_size": 2,
+     "use_cache": true,
+     "vocab_size": 30522
+   },
+   "text_projection_config": {
+     "hidden_size": 768,
+     "intermediate_size": 640,
+     "projection_dim": 512,
+     "hidden_act": "gelu"
+   },
+   "text_projection_config_dict": {
+     "hidden_size": 768,
+     "intermediate_size": 640,
+     "projection_dim": 512,
+     "hidden_act": "gelu",
+     "num_hidden_layers": 2
+   },
+   "torch_dtype": "float32",
+   "transformers_version": null,
+   "vision_config": {
+     "attention_probs_dropout_prob": 0.0,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.0,
+     "hidden_size": 768,
+     "image_size": 224,
+     "initializer_range": 0.02,
+     "intermediate_size": 3072,
+     "layer_norm_eps": 1e-12,
+     "model_type": "vit",
+     "num_attention_heads": 12,
+     "num_channels": 3,
+     "num_hidden_layers": 12,
+     "patch_size": 16,
+     "qkv_bias": true
+   },
+   "vision_config_dict": {
+     "attention_probs_dropout_prob": 0.0,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.0,
+     "hidden_size": 768,
+     "image_size": 224,
+     "initializer_range": 0.02,
+     "intermediate_size": 3072,
+     "layer_norm_eps": 1e-12,
+     "model_type": "vit",
+     "num_attention_heads": 12,
+     "num_channels": 3,
+     "num_hidden_layers": 12,
+     "patch_size": 16,
+     "qkv_bias": true
+   }
+ }
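
The `auto_map` block is what lets `AutoConfig`/`AutoModel`/`AutoProcessor` resolve to the custom BiomedCLIP classes shipped alongside this config, so `trust_remote_code=True` is required when loading. A hedged sketch — placeholder repo id, and note that subfolder handling for remote code can vary across transformers versions:

```python
# Hedged sketch: read the nested text/vision/projection configs defined above.
# "your-org/your-repo" is a placeholder repo id.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "your-org/your-repo",
    subfolder="text_encoder_BiomedCLIP",
    trust_remote_code=True,
)
print(config.text_config.hidden_size)    # 768 (BERT-style text tower)
print(config.vision_config.patch_size)   # 16  (ViT-B/16 vision tower)
```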
text_encoder_BiomedCLIP/configuration_biomed_clip.py ADDED
@@ -0,0 +1,60 @@
+ import os
+ from typing import *
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+ from transformers.utils import logging  # added: `logger` is used in from_pretrained below but was not defined in the uploaded file
+ 
+ logger = logging.get_logger(__name__)
+ 
+ 
+ class BiomedCLIPTextProjectionConfig(PretrainedConfig):
+     def __init__(
+         self,
+         hidden_size=768,
+         intermediate_size=640,
+         projection_dim=512,
+         num_hidden_layers=2,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+ 
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.projection_dim = projection_dim
+         self.num_hidden_layers = num_hidden_layers
+ 
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+         cls._set_token_in_kwargs(kwargs)
+ 
+         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+ 
+         # get the text projection config dict if we are loading from a full BiomedCLIP config
+         if config_dict.get("model_type") == "clip":
+             config_dict = config_dict["text_projection_config"]
+ 
+         if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+             logger.warning(
+                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                 f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+             )
+ 
+         return cls.from_dict(config_dict, **kwargs)
+ 
+ 
+ class BiomedCLIPConfig(CLIPConfig):
+     def __init__(
+         self, text_config=None, text_projection_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
+     ):
+         # If `_config_dict` exists, we use it for backward compatibility.
+         # We pop out these attributes before calling `super().__init__` to avoid them being saved (which causes a lot
+         # of confusion!).
+         super().__init__(text_config, vision_config, projection_dim, logit_scale_init_value, **kwargs)
+ 
+         text_projection_config_dict = kwargs.pop("text_projection_config_dict", None)
+         if text_projection_config is None:
+             if text_projection_config_dict is not None:
+                 text_projection_config = {}
+ 
+                 _text_projection_config_dict = BiomedCLIPTextProjectionConfig(**text_projection_config_dict)
+ 
+                 text_projection_config.update(_text_projection_config_dict)
+         else:
+             text_projection_config = BiomedCLIPTextProjectionConfig(**text_projection_config)
+ 
+         self.text_projection_config = text_projection_config
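
A small sketch of the projection-head config on its own. It assumes the script is run from inside `text_encoder_BiomedCLIP/` so the module can be imported directly; the defaults are expected to match the `text_projection_config` block in `config.json` above:

```python
# Minimal sketch: the projection head defaults to a 768 -> 640 -> 512 MLP config.
from configuration_biomed_clip import BiomedCLIPTextProjectionConfig

proj_cfg = BiomedCLIPTextProjectionConfig()
assert proj_cfg.hidden_size == 768
assert proj_cfg.intermediate_size == 640
assert proj_cfg.projection_dim == 512
print(proj_cfg.to_dict())
```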
text_encoder_BiomedCLIP/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d6e9edd227566b056eef0c682bb13f9e501364c1c29b8ccb06986a9ffb5cfbb
+ size 783658420
text_encoder_BiomedCLIP/modeling_biomed_clip.py ADDED
@@ -0,0 +1,918 @@
+ # coding=utf-8
+ # Modified by chuhac for a timm-free implementation
+ # Model can be directly imported with ``from_pretrained`` and ``trust_remote_code = True`` in the huggingface format
+ # Diff from HF CLIP Implementation:
+ # 1. pre-norm instead of post-norm in Vision Tower (the original implementation is right but the module registration order is misleading)
+ # 2. CLS Pooling with MLP in Text Tower
+ # 3. Remove pre norm in Vision Tower
+ # 4. CNN bias in Vision Tower
+ # 5. Change layer_norm eps from 1e-5 to 1e-12, which introduces small numerical variations (at the 1e-5 level)
+ ## ******************************** ##
+ # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ PyTorch BiomedCLIP model """
+ """ No need for timm or open-clip-torch """
+ 
+ 
+ from dataclasses import dataclass
+ from typing import Any, Optional, Tuple, Union, List
+ 
+ import math
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ 
+ from transformers.activations import ACT2FN
+ from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPooling,
+     ImageClassifierOutput,
+     BaseModelOutputWithPoolingAndCrossAttentions,
+     BaseModelOutputWithPastAndCrossAttentions
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import (
+     ModelOutput,
+     add_code_sample_docstrings,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+     replace_return_docstrings,
+ )
+ from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+ from transformers.models.clip.modeling_clip import *
+ 
+ from .configuration_biomed_clip import BiomedCLIPTextProjectionConfig, BiomedCLIPConfig
+ 
+ 
+ logger = logging.get_logger(__name__)
+ 
+ 
+ 
+ # contrastive loss function, adapted from
+ # https://sachinruk.github.io/blog/2021-03-07-clip.html
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
+     return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
+ 
+ 
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
+     caption_loss = contrastive_loss(similarity)
+     image_loss = contrastive_loss(similarity.t())
+     return (caption_loss + image_loss) / 2.0
+ 
+ 
+ class BiomedCLIPVisionEmbeddings(CLIPVisionEmbeddings):
+     def __init__(self, config: CLIPVisionConfig):
+         super().__init__(config)
+ 
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             # True in open_clip
+             bias=True,
+         )
+ 
+ # TODO
+ class BiomedCLIPTextEmbeddings(nn.Module):
+     def __init__(self, config: CLIPTextConfig):
+         super().__init__()
+         embed_dim = config.hidden_size
+ 
+         self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+         self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+         self.token_type_embedding = nn.Embedding(config.type_vocab_size, embed_dim)
+ 
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ 
+         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+         self.register_buffer(
+             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+         )
+         self.register_buffer(
+             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+         )
+ 
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         past_key_values_length: int = 0,
+     ) -> torch.Tensor:
+ 
+         if input_ids is not None:
+             input_shape = input_ids.size()
+         else:
+             input_shape = inputs_embeds.size()[:-1]
+ 
+         seq_length = input_shape[1]
+ 
+         if position_ids is None:
+             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+ 
+         # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+         # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+         # issue #5664
+         if token_type_ids is None:
+             if hasattr(self, "token_type_ids"):
+                 buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                 buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                 token_type_ids = buffered_token_type_ids_expanded
+             else:
+                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+ 
+         if inputs_embeds is None:
+             inputs_embeds = self.token_embedding(input_ids)
+         token_type_embeddings = self.token_type_embedding(token_type_ids)
+ 
+         embeddings = inputs_embeds + token_type_embeddings
+         if self.position_embedding_type == "absolute":
+             position_embeddings = self.position_embedding(position_ids)
+             embeddings += position_embeddings
+ 
+         embeddings = self.layer_norm(embeddings)
+         embeddings = self.dropout(embeddings)
+         return embeddings
+ 
+ 
+ class BiomedCLIPAttention(nn.Module):
+     def __init__(self, config, position_embedding_type=None):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.embed_dim // self.num_heads
+         if self.head_dim * self.num_heads != self.embed_dim:
+             raise ValueError(
+                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                 f" {self.num_heads})."
+             )
+         self.scale = self.head_dim**-0.5
+         self.dropout = nn.Dropout(config.attention_dropout)
+ 
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ 
+     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+         new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
+         x = x.view(new_x_shape)
+         return x.permute(0, 2, 1, 3)
+ 
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor]:
+ 
+         mixed_query_layer = self.q_proj(hidden_states)
+ 
+         # If this is instantiated as a cross-attention module, the keys
+         # and values come from an encoder; the attention mask needs to be
+         # such that the encoder's padding tokens are not attended to.
+         is_cross_attention = encoder_hidden_states is not None
+ 
+         key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
+         value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
+ 
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+ 
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ 
+         attention_scores = attention_scores / math.sqrt(self.head_dim)
+         if attention_mask is not None:
+             # Apply the attention mask (precomputed for all layers in the forward() function)
+             attention_scores = attention_scores + attention_mask
+ 
+         # Normalize the attention scores to probabilities.
+         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+ 
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = self.dropout(attention_probs)
+ 
+         context_layer = torch.matmul(attention_probs, value_layer)
+ 
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
+         context_layer = context_layer.view(new_context_layer_shape).contiguous()
+ 
+         outputs = self.out_proj(context_layer)
+         return outputs, attention_probs
+ 
+ 
+ class BiomedCLIPEncoderLayer(nn.Module):
+     def __init__(self, config: BiomedCLIPConfig, norm='pre'):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         # pre-norm
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.self_attn = BiomedCLIPAttention(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = CLIPMLP(config)
+         self.norm = norm
+ 
+         if self.norm == 'pre':
+             self.forward = self.pre_norm_forward
+         elif self.norm == 'post':
+             self.forward = self.post_norm_forward
+ 
+     def pre_norm_forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.FloatTensor]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+         """
+         residual = hidden_states
+ 
+         hidden_states = self.layer_norm1(hidden_states)
+         hidden_states, attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+         )
+         hidden_states = residual + hidden_states
+ 
+         residual = hidden_states
+         hidden_states = self.layer_norm2(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+ 
+         outputs = (hidden_states,)
+ 
+         if output_attentions:
+             outputs += (attn_weights,)
+ 
+         return outputs
+ 
+     def post_norm_forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.FloatTensor]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+         """
+         residual = hidden_states
+ 
+         hidden_states, attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+         )
+         hidden_states = residual + hidden_states
+ 
+         hidden_states = self.layer_norm1(hidden_states)
+ 
+         residual = hidden_states
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.layer_norm2(hidden_states)
+         outputs = (hidden_states,)
+ 
+         if output_attentions:
+             outputs += (attn_weights,)
+ 
+         return outputs
+ 
+ 
+ class BiomedCLIPTextProjection(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.activation_fn = ACT2FN[config.hidden_act]
+         self.fc2 = nn.Linear(config.intermediate_size, config.projection_dim, bias=False)
+ 
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.activation_fn(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+         return hidden_states
+ 
+ 
+ class BiomedCLIPEncoder(nn.Module):
+     """
+     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+     [`BiomedCLIPEncoderLayer`].
+     Args:
+         config: BiomedCLIPConfig
+     """
+     def __init__(self, config, norm='pre'):
+         super().__init__()
+         self.config = config
+         self.norm = norm
+         self.layers = nn.ModuleList([BiomedCLIPEncoderLayer(config, norm) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+ 
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         return_dict: Optional[bool] = True,
+     ):
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ 
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+ 
+         next_decoder_cache = () if use_cache else None
+         for i, layer_module in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+ 
+             layer_head_mask = head_mask[i] if head_mask is not None else None
+             past_key_value = past_key_values[i] if past_key_values is not None else None
+ 
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer_module.__call__,
+                     hidden_states,
+                     attention_mask,
+                     output_attentions,
+                 )
+             else:
+                 layer_outputs = layer_module(
+                     hidden_states,
+                     attention_mask,
+                     output_attentions,
+                 )
+ 
+             hidden_states = layer_outputs[0]
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[-1],)
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                 if self.config.add_cross_attention:
+                     all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+ 
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+ 
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     next_decoder_cache,
+                     all_hidden_states,
+                     all_self_attentions,
+                     all_cross_attentions,
+                 ]
+                 if v is not None
+             )
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=next_decoder_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+             cross_attentions=all_cross_attentions,
+         )
+ 
+ 
+ class BiomedCLIPTextTransformer(CLIPPreTrainedModel):
+     def __init__(self, config: CLIPTextConfig):
+         super().__init__(config)
+         self.config = config
+         embed_dim = config.hidden_size
+         self.embeddings = BiomedCLIPTextEmbeddings(config)
+         self.encoder = BiomedCLIPEncoder(config, norm='post')
+         # no final_ln
+         # self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ 
+         # For `pooled_output` computation
+ 
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+         r"""
+         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+             the model is configured as a decoder.
+         encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+         use_cache (`bool`, *optional*):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+             `past_key_values`).
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         if self.config.is_decoder:
+             use_cache = use_cache if use_cache is not None else self.config.use_cache
+         else:
+             use_cache = False
+ 
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             input_shape = input_ids.size()
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+ 
+         batch_size, seq_length = input_shape
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+ 
+         # past_key_values_length
+         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ 
+         if token_type_ids is None:
+             if hasattr(self.embeddings, "token_type_ids"):
+                 buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                 buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                 token_type_ids = buffered_token_type_ids_expanded
+             else:
+                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+ 
+         embedding_output = self.embeddings(
+             input_ids=input_ids,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+             inputs_embeds=inputs_embeds,
+             past_key_values_length=past_key_values_length,
+         )
+ 
+         if attention_mask is None:
+             attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+ 
+         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+         # ourselves in which case we just need to make it broadcastable to all heads.
+         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+ 
+         # If a 2D or 3D attention mask is provided for the cross-attention
+         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if self.config.is_decoder and encoder_hidden_states is not None:
+             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+             if encoder_attention_mask is None:
+                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ 
+             # NOTE: `use_sdpa_attention_masks` was referenced but never defined in the uploaded file;
+             # it is set to False here so the eager-attention mask path below is always taken.
+             use_sdpa_attention_masks = False
+             if use_sdpa_attention_masks:
+                 # Expand the attention mask for SDPA.
+                 # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+                 encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                     encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
+                 )
+             else:
+                 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+         else:
+             encoder_extended_attention_mask = None
+ 
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask=extended_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = encoder_outputs[0]
+ 
+         return (sequence_output, sequence_output[:, 0, :])
+ 
+ 
+ class BiomedCLIPVisionTransformer(nn.Module):
+     def __init__(self, config: CLIPVisionConfig):
+         super().__init__()
+         self.config = config
+         embed_dim = config.hidden_size
+ 
+         self.embeddings = BiomedCLIPVisionEmbeddings(config)
+         # No pre_norm in open_clip Vision Tower
+         # self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+         self.encoder = BiomedCLIPEncoder(config)
+         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ 
+     def forward(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         r"""
+         Returns:
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         if pixel_values is None:
+             raise ValueError("You have to specify pixel_values")
+ 
+         hidden_states = self.embeddings(pixel_values)
+         # hidden_states = self.pre_layrnorm(hidden_states)
+ 
+         encoder_outputs = self.encoder(
+             hidden_states=hidden_states,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         last_hidden_state = encoder_outputs[0]
+         pooled_output = last_hidden_state[:, 0, :]
+         pooled_output = self.post_layernorm(pooled_output)
+ 
+         if not return_dict:
+             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+ 
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+ 
+ 
+ class BiomedCLIPModel(CLIPPreTrainedModel):
+     config_class = BiomedCLIPConfig
+     _no_split_modules = ["BiomedCLIPTextEmbeddings", "BiomedCLIPEncoderLayer"]
+ 
+     def __init__(self, config: BiomedCLIPConfig):
+         super().__init__(config)
+ 
+         if not isinstance(config.text_config, CLIPTextConfig):
+             raise ValueError(
+                 "config.text_config is expected to be of type CLIPTextConfig but is of type"
+                 f" {type(config.text_config)}."
+             )
+ 
+         if not isinstance(config.vision_config, CLIPVisionConfig):
+             raise ValueError(
+                 "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
+                 f" {type(config.vision_config)}."
+             )
+ 
+         text_config = config.text_config
+         text_projection_config = config.text_projection_config
+         vision_config = config.vision_config
+ 
+         self.projection_dim = config.projection_dim
+         self.text_embed_dim = text_config.hidden_size
+         self.vision_embed_dim = vision_config.hidden_size
+ 
+         self.text_model = BiomedCLIPTextTransformer(text_config)
+         self.vision_model = BiomedCLIPVisionTransformer(vision_config)
+ 
+         self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
+ 
+         self.text_projection = BiomedCLIPTextProjection(text_projection_config)
+ 
+         self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_text_features(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> torch.FloatTensor:
+         r"""
+         Returns:
+             text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+             applying the projection layer to the pooled output of [`CLIPTextModel`].
+         Examples:
+         ```python
+         >>> from transformers import AutoTokenizer, CLIPModel
+         >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+         >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+         >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+         >>> text_features = model.get_text_features(**inputs)
+         ```"""
+         # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         pooled_output = text_outputs[1]
+         text_features = self.text_projection(pooled_output)
+ 
+         return text_features
+ 
+     def get_image_features(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> torch.FloatTensor:
+         r"""
+         Returns:
+             image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+             applying the projection layer to the pooled output of [`CLIPVisionModel`].
+         Examples:
+         ```python
+         >>> from PIL import Image
+         >>> import requests
+         >>> from transformers import AutoProcessor, CLIPModel
+         >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+         >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+         >>> inputs = processor(images=image, return_tensors="pt")
+         >>> image_features = model.get_image_features(**inputs)
+         ```"""
+         # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         pooled_output = vision_outputs[1]  # pooled_output
+         image_features = self.visual_projection(pooled_output)
+ 
+         return image_features
+ 
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         return_loss: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CLIPOutput]:
+         r"""
+         Returns:
+         Examples:
+         ```python
+         >>> from PIL import Image
+         >>> import requests
+         >>> from transformers import AutoProcessor, CLIPModel
+         >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+         >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+         >>> inputs = processor(
+         ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
+         ... )
+         >>> outputs = model(**inputs)
+         >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+         >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
+         ```"""
+         # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             token_type_ids=token_type_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         image_embeds = vision_outputs[1]
+         image_embeds = self.visual_projection(image_embeds)
+ 
+         text_embeds = text_outputs[1]
+         text_embeds = self.text_projection(text_embeds)
+ 
+         # normalized features
+         image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+         text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+ 
+         # cosine similarity as logits
+         logit_scale = self.logit_scale.exp()
+         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+         logits_per_image = logits_per_text.t()
+ 
+         loss = None
+         if return_loss:
+             loss = clip_loss(logits_per_text)
+ 
+         if not return_dict:
+             output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+             return ((loss,) + output) if loss is not None else output
+ 
+         return CLIPOutput(
+             loss=loss,
+             logits_per_image=logits_per_image,
+             logits_per_text=logits_per_text,
+             text_embeds=text_embeds,
+             image_embeds=image_embeds,
+             text_model_output=text_outputs,
+             vision_model_output=vision_outputs,
+         )
+ 
+ 
+ class BiomedCLIPForImageClassification(CLIPPreTrainedModel):
+     main_input_name = "pixel_values"
+ 
+     def __init__(self, config: BiomedCLIPConfig) -> None:
+         super().__init__(config)
+ 
+         self.num_labels = config.num_labels
+         self.vision_model = BiomedCLIPVisionTransformer(config.vision_config)
+ 
+         # Classifier head
+         self.classifier = (
+             nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+         )
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[tuple, ImageClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         outputs = self.vision_model(
+             pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         sequence_output = outputs[0]
+ 
+         # average pool the patch tokens
+         sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
+         # apply classifier
+         logits = self.classifier(sequence_output)
+ 
+         loss = None
+         if labels is not None:
+             # move labels to correct device to enable model parallelism
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+ 
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+ 
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+ 
+         return ImageClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
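
Putting the pieces together: a hedged usage sketch for the model class above, pairing the text tower with the BERT tokenizer stored under `tokenizer_BiomedCLIP/`. The repo id is a placeholder, the prompts are illustrative only, and it assumes remote-code loading from a subfolder works in the installed transformers version:

```python
# Hedged sketch: load the BiomedCLIP text tower defined above and embed two prompts.
# "your-org/your-repo" is a placeholder for the repository these files were uploaded to.
import torch
from transformers import AutoModel, AutoTokenizer

repo = "your-org/your-repo"
model = AutoModel.from_pretrained(repo, subfolder="text_encoder_BiomedCLIP", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer_BiomedCLIP")

inputs = tokenizer(["T1-weighted brain MRI", "abdominal CT"], padding=True, return_tensors="pt")
with torch.no_grad():
    text_features = model.get_text_features(**inputs)  # CLS pooling + 768 -> 640 -> 512 MLP
print(text_features.shape)  # (2, 512)
```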
text_encoder_BiomedCLIP/preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "crop_size": 224,
+   "do_center_crop": true,
+   "do_normalize": true,
+   "do_resize": true,
+   "image_processor_type": "CLIPImageProcessor",
+   "tokenizer_type": "BertTokenizer",
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "resample": 3,
+   "size": 224
+ }
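
The preprocessing is plain CLIP-style: resize and center-crop to 224, then normalize with the OpenAI CLIP mean/std (`resample: 3` is bicubic). A short loading sketch, placeholder repo id as before:

```python
# Hedged sketch: read the CLIP-style image preprocessing config above.
from transformers import CLIPImageProcessor

image_processor = CLIPImageProcessor.from_pretrained(
    "your-org/your-repo", subfolder="text_encoder_BiomedCLIP"
)
print(image_processor.image_mean)  # [0.48145466, 0.4578275, 0.40821073]
```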
text_encoder_BiomedCLIP/processing_biomed_clip.py ADDED
@@ -0,0 +1,146 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Image/Text processor class for CLIP
+ """
+ 
+ import warnings
+ 
+ from transformers.processing_utils import ProcessorMixin
+ from transformers.tokenization_utils_base import BatchEncoding
+ 
+ 
+ class BiomedCLIPProcessor(ProcessorMixin):
+     r"""
+     Constructs a processor which wraps a CLIP image processor and a BERT tokenizer into a single processor.
+     [`BiomedCLIPProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BertTokenizerFast`]. See the
+     [`~CLIPProcessor.__call__`] and [`~CLIPProcessor.decode`] for more information.
+     Args:
+         image_processor ([`CLIPImageProcessor`], *optional*):
+             The image processor is a required input.
+         tokenizer ([`BertTokenizerFast`], *optional*):
+             The tokenizer is a required input.
+     """
+ 
+     attributes = ["image_processor", "tokenizer"]
+     image_processor_class = "CLIPImageProcessor"
+     tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
+ 
+     def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+         feature_extractor = None
+         if "feature_extractor" in kwargs:
+             warnings.warn(
+                 "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
+                 " instead.",
+                 FutureWarning,
+             )
+             feature_extractor = kwargs.pop("feature_extractor")
+ 
+         image_processor = image_processor if image_processor is not None else feature_extractor
+         if image_processor is None:
+             raise ValueError("You need to specify an `image_processor`.")
+         if tokenizer is None:
+             raise ValueError("You need to specify a `tokenizer`.")
+ 
+         super().__init__(image_processor, tokenizer)
+ 
+     def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
+         """
+         Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+         and `kwargs` arguments to the tokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not `None` to encode
+         the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+         CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+         of the above two methods for more information.
+         Args:
+             text (`str`, `List[str]`, `List[List[str]]`):
+                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+             images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                 tensor. Both channels-first and channels-last formats are supported.
+             return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                 If set, will return tensors of a particular framework. Acceptable values are:
+                 - `'tf'`: Return TensorFlow `tf.constant` objects.
+                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                 - `'np'`: Return NumPy `np.ndarray` objects.
+                 - `'jax'`: Return JAX `jnp.ndarray` objects.
+         Returns:
+             [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+             - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+             - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+               `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+               `None`).
+             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+         """
+         tokenizer_kwargs, image_processor_kwargs = {}, {}
+         if kwargs:
+             tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys}
+             image_processor_kwargs = {
+                 k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys
+             }
+ 
+         if text is None and images is None:
+             raise ValueError("You have to specify either text or images. Both cannot be none.")
+ 
+         if text is not None:
+             encoding = self.tokenizer(text, return_tensors=return_tensors, **tokenizer_kwargs)
+ 
+         if images is not None:
+             image_features = self.image_processor(images, return_tensors=return_tensors, **image_processor_kwargs)
+ 
+         if text is not None and images is not None:
+             encoding["pixel_values"] = image_features.pixel_values
+             return encoding
+         elif text is not None:
+             return encoding
+         else:
+             return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
+ 
+     def batch_decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+         refer to the docstring of this method for more information.
+         """
+         return self.tokenizer.batch_decode(*args, **kwargs)
+ 
+     def decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
+         the docstring of this method for more information.
+         """
+         return self.tokenizer.decode(*args, **kwargs)
+ 
+     @property
+     def model_input_names(self):
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+ 
+     @property
+     def feature_extractor_class(self):
+         warnings.warn(
+             "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
+             FutureWarning,
+         )
+         return self.image_processor_class
+ 
+     @property
+     def feature_extractor(self):
+         warnings.warn(
+             "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
+             FutureWarning,
+         )
+         return self.image_processor
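
Since the tokenizer lives in `tokenizer_BiomedCLIP/` while the image-processor config sits next to this file, the processor is easiest to assemble manually. A hedged sketch — it assumes the file above is importable locally, uses a placeholder repo id and a dummy image, and assumes a transformers version in which `CLIPImageProcessor` exposes `_valid_processor_keys`, which the `__call__` above relies on:

```python
# Hedged sketch: build BiomedCLIPProcessor from the two components stored in this commit.
from PIL import Image
from transformers import BertTokenizerFast, CLIPImageProcessor
from processing_biomed_clip import BiomedCLIPProcessor

repo = "your-org/your-repo"  # placeholder repo id
image_processor = CLIPImageProcessor.from_pretrained(repo, subfolder="text_encoder_BiomedCLIP")
tokenizer = BertTokenizerFast.from_pretrained(repo, subfolder="tokenizer_BiomedCLIP")

processor = BiomedCLIPProcessor(image_processor=image_processor, tokenizer=tokenizer)
batch = processor(
    text=["brain MRI", "chest CT"],
    images=Image.new("RGB", (512, 512)),  # dummy image just to show the output keys
    return_tensors="pt",
    padding=True,
)
print(batch.keys())  # input_ids, token_type_ids, attention_mask, pixel_values
```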
text_encoder_BiomedCLIP/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5bdc400de59a85620ddc7584d06913dc901c47f22647899c6addec71b9a5c9a2
+ size 783733062
tokenizer_BiomedCLIP/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer_BiomedCLIP/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_BiomedCLIP/tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
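
The text side is tokenized with a plain lower-casing WordPiece `BertTokenizer` (30,522-token vocabulary, `[CLS]`/`[SEP]` framing, effectively unlimited `model_max_length`). A quick sketch, placeholder repo id:

```python
# Hedged sketch: inspect how the BERT tokenizer in this folder splits a report phrase.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("your-org/your-repo", subfolder="tokenizer_BiomedCLIP")
ids = tok("Diffuse axonal injury on T2-FLAIR")["input_ids"]
print(tok.convert_ids_to_tokens(ids))  # starts with [CLS], ends with [SEP], all lower-cased
```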
tokenizer_BiomedCLIP/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
vae/config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.29.2",
+   "_name_or_path": "/mnt/petrelfs/wuhaoning/MedicalGen/MRI_diffusion/train_vae/MedGen_MRI_CT_240930-000404/checkpoint_lastest/",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "in_channels": 1,
+   "latent_channels": 16,
+   "layers_per_block": 2,
+   "norm_num_groups": 32,
+   "out_channels": 1,
+   "sample_size": 512,
+   "scaling_factor": 1.003,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
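
The VAE is a single-channel (grayscale) `AutoencoderKL` with 16 latent channels; its four encoder blocks apply three 2x downsamples, so a 512x512 slice maps to 16x64x64 latents, and `scaling_factor` is 1.003. A hedged round-trip sketch with random data and a placeholder repo id:

```python
# Hedged sketch: encode/decode one grayscale 512x512 slice with the VAE configured above.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("your-org/your-repo", subfolder="vae")

image = torch.randn(1, 1, 512, 512)                         # (batch, in_channels=1, H, W)
posterior = vae.encode(image).latent_dist
latents = posterior.sample() * vae.config.scaling_factor    # scaling_factor = 1.003
print(latents.shape)                                        # torch.Size([1, 16, 64, 64])
recon = vae.decode(latents / vae.config.scaling_factor).sample
```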
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9eb7d0e0e2f647d82f02eb285d8fb885b1247d1ec9342c5770855a27809ec6f8
+ size 335293452