JUNJIE99 committed
Commit 0dd0ddf · verified · 1 Parent(s): 6feb350

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/cir_candi_2.png filter=lfs diff=lfs merge=lfs -text
+ assets/cir_query.png filter=lfs diff=lfs merge=lfs -text
+ assets/res-ft-mmeb.png filter=lfs diff=lfs merge=lfs -text
+ assets/res-scaling.png filter=lfs diff=lfs merge=lfs -text
+ assets/res-zs-cir.png filter=lfs diff=lfs merge=lfs -text
+ assets/res-zs-mmeb.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,22 @@
+ MIT License
+
+ Copyright (c) 2024 JUNJIE99
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
README.md CHANGED
@@ -1,3 +1,120 @@
- ---
- license: mit
- ---
+ <h1 align="center">MegaPairs: Massive Data Synthesis For Universal Multimodal Retrieval</h1>
+
+ <p align="center">
+     <a href="https://arxiv.org/abs/2412.14475">
+         <img alt="Build" src="http://img.shields.io/badge/cs.CV-arXiv%3A2412.14475-B31B1B.svg">
+     </a>
+     <a href="https://github.com/VectorSpaceLab/MegaPairs">
+         <img alt="Build" src="https://img.shields.io/badge/Github-Code-blue">
+     </a>
+     <a href="https://huggingface.co/datasets/JUNJIE99/MegaPairs">
+         <img alt="Build" src="https://img.shields.io/badge/🤗 Datasets-MegaPairs-yellow">
+     </a>
+ </p>
+
+ <p align="center">
+     <a href="https://huggingface.co/JUNJIE99/MMRet-base">
+         <img alt="Build" src="https://img.shields.io/badge/🤗 Model-MMRet_base-yellow">
+     </a>
+     <a href="https://huggingface.co/JUNJIE99/MMRet-large">
+         <img alt="Build" src="https://img.shields.io/badge/🤗 Model-MMRet_large-yellow">
+     </a>
+     <a href="https://huggingface.co/JUNJIE99/MMRet-MLLM">
+         <img alt="Build" src="https://img.shields.io/badge/🤗 Model-MMRet_MLLM-yellow">
+     </a>
+ </p>
+
+ ## News
+ ```2024-12-27``` 🚀🚀 MMRet-CLIP models released on Hugging Face: [MMRet-base](https://huggingface.co/JUNJIE99/MMRet-base) and [MMRet-large](https://huggingface.co/JUNJIE99/MMRet-large).
+
+ ```2024-12-19``` 🎉🎉 Released our paper: [MegaPairs: Massive Data Synthesis For Universal Multimodal Retrieval](https://arxiv.org/pdf/2412.14475).
+
+ ## Release Plan
+ - [x] Paper
+ - [x] MMRet-base and MMRet-large models
+ - [ ] MMRet-MLLM model
+ - [ ] MegaPairs Dataset
+ - [ ] Evaluation code
+ - [ ] Fine-tuning code
+
+
+ ## Introduction
+ In this project, we introduce **MegaPairs**, a novel data synthesis method that leverages open-domain images to create *heterogeneous KNN triplets* for universal multimodal retrieval. The MegaPairs dataset contains over 26 million triplets, and we have trained a series of multimodal retrieval models, **MMRets**, on it, including MMRet-CLIP (base and large) and MMRet-MLLM.
+
+ MMRets achieve state-of-the-art performance on four popular zero-shot composed image retrieval benchmarks and on the Massive Multimodal Embedding Benchmark (MMEB). Extensive experiments demonstrate the ***efficiency, scalability, and generalization*** of MegaPairs. Please refer to our [paper](https://arxiv.org/abs/2412.14475) for more details.
+
+ ## Model Usage
+
+ ### 1. MMRet-CLIP Models
+ You can easily use the MMRet-CLIP models with ```transformers```:
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ MODEL_NAME = "JUNJIE99/MMRet-base" # or "JUNJIE99/MMRet-large"
+
+ model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) # You must set trust_remote_code=True
+ model.set_processor(MODEL_NAME)
+ model.eval()
+
+ with torch.no_grad():
+     query = model.encode(
+         images = "./assets/cir_query.png",
+         text = "Make the background dark, as if the camera has taken the photo at night"
+     )
+
+     candidates = model.encode(
+         images = ["./assets/cir_candi_1.png", "./assets/cir_candi_2.png"]
+     )
+
+     scores = query @ candidates.T
+ print(scores)
+ ```
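The example above scores candidates with a plain inner product between the query and candidate embeddings (`query @ candidates.T`). As a minimal, standalone sketch of the same ranking step — random unit vectors stand in for real MMRet embeddings here, so no model download is needed; the 512-dimensional size mirrors the `projection_dim` in this repository's config — the scores can be turned into a top-k ordering:

```python
import torch

# Standalone sketch: rank candidates by inner-product score, mirroring
# `scores = query @ candidates.T` from the example above.
# Random unit vectors stand in for real MMRet embeddings.
torch.manual_seed(0)
dim = 512                                   # matches the model's projection_dim
query = torch.nn.functional.normalize(torch.randn(1, dim), dim=-1)
candidates = torch.nn.functional.normalize(torch.randn(4, dim), dim=-1)

scores = query @ candidates.T               # shape: (1, num_candidates)
topk = torch.topk(scores.squeeze(0), k=2)   # highest-scoring candidates first
for rank, (idx, score) in enumerate(zip(topk.indices.tolist(), topk.values.tolist()), start=1):
    print(f"rank {rank}: candidate {idx} (score={score:.4f})")
```

With the real model, the candidate index returned by `torch.topk` maps back to the list of image paths passed to `encode`.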
+
+ ### 2. MMRet-MLLM Models
+ ```Will be released soon.```
+
+ ## Model Performance
+ ### Zero-Shot Composed Image Retrieval
+
+ MMRet sets a new performance benchmark in zero-shot composed image retrieval tasks. On the CIRCO benchmark, our MMRet-base model, with only 149 million parameters, surpasses all previous models, including those with 50 times more parameters. Additionally, MMRet-MLLM achieves an 8.1% improvement over the previous state-of-the-art model.
+
+ <img src="./assets/res-zs-cir.png" width="800">
+
+ ### Zero-Shot Performance on MMEB
+
+ MMRet-MLLM achieves state-of-the-art zero-shot performance on the Massive Multimodal Embedding Benchmark (MMEB), despite being trained only on the ImageText-to-Image paradigm. This demonstrates the excellent generalization capability of MegaPairs for multimodal embedding.
+
+ <img src="./assets/res-zs-mmeb.png" width="800">
+
+ ### Fine-Tuning Performance on MMEB
+
+ After fine-tuning on downstream tasks, MMRet-MLLM maintains its leading performance. Notably, it surpasses the previous state-of-the-art by 7.1% on the MMEB out-of-distribution (OOD) set. These results demonstrate the robust generalization capability of MMRet-MLLM and highlight the potential of MegaPairs as foundational training data for universal multimodal embedding.
+
+ <img src="./assets/res-ft-mmeb.png" width="800">
+
+ ### Performance Scaling
+ MegaPairs showcases **scalability**: MMRet-base improves as the amount of training data increases. It also demonstrates **efficiency**: with just 0.5M training samples, MMRet-base significantly outperforms MagicLens, which uses the same CLIP-base backbone but was trained on 36.7M samples.
+
+ <img src="./assets/res-scaling.png" width="800">
+
+
+ ## License
+ The annotations for MegaPairs and the MMRet models are released under the [MIT License](LICENSE). The images in MegaPairs originate from the [Recap-DataComp](https://huggingface.co/datasets/UCSC-VLAA/Recap-DataComp-1B) dataset, which is released under the CC BY 4.0 license.
+
+
+ ## Citation
+ If you find this repository useful, please consider giving it a star ⭐ and a citation:
+
+ ```
+ @article{zhou2024megapairs,
+   title={MegaPairs: Massive Data Synthesis For Universal Multimodal Retrieval},
+   author={Zhou, Junjie and Liu, Zheng and Liu, Ze and Xiao, Shitao and Wang, Yueze and Zhao, Bo and Zhang, Chen Jason and Lian, Defu and Xiong, Yongping},
+   journal={arXiv preprint arXiv:2412.14475},
+   year={2024}
+ }
+ ```
assets/cir_candi_1.png ADDED
assets/cir_candi_2.png ADDED

Git LFS Details

  • SHA256: e3c4b3debd66d4789545d52bdd4c18bdbabaa072b69110ad08ff54b7b4503557
  • Pointer size: 131 Bytes
  • Size of remote file: 901 kB
assets/cir_query.png ADDED

Git LFS Details

  • SHA256: 42ad0e1200ba5c8767ea752f6358f6b51719674b064666b415b6ab0ea90ae62b
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
assets/res-ft-mmeb.png ADDED

Git LFS Details

  • SHA256: 7011ea195b7e51c1e206c735072ec5d2acb96a5aa22fb7ff280523be21c2ba0f
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB
assets/res-scaling.png ADDED

Git LFS Details

  • SHA256: ad36953de26851baac1ac625467d87517d5151714104a92cc89f421d2c7326f0
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
assets/res-zs-cir.png ADDED

Git LFS Details

  • SHA256: 61da2bc03b36c97c48d3e52bb96ccaecbb99d5d913847f665cbcd2d6a84435df
  • Pointer size: 131 Bytes
  • Size of remote file: 506 kB
assets/res-zs-mmeb.png ADDED

Git LFS Details

  • SHA256: 9e3fb656f33f9b2ff4290d9a900a6bcec40c3df97ff308cdfef82a62209bbf97
  • Pointer size: 131 Bytes
  • Size of remote file: 251 kB
config.json ADDED
@@ -0,0 +1,160 @@
+ {
+   "architectures": [
+     "CLIPModel"
+   ],
+   "initializer_factor": 1.0,
+   "logit_scale_init_value": 2.6592,
+   "model_type": "clip",
+   "projection_dim": 512,
+   "auto_map": {
+     "AutoModel": "modeling_MMRet_CLIP.CLIPModel"
+   },
+   "text_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": null,
+     "attention_dropout": 0.0,
+     "bad_words_ids": null,
+     "bos_token_id": 0,
+     "chunk_size_feed_forward": 0,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "dropout": 0.0,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": 2,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "quick_gelu",
+     "hidden_size": 512,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_factor": 1.0,
+     "initializer_range": 0.02,
+     "intermediate_size": 2048,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-05,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 77,
+     "min_length": 0,
+     "model_type": "clip_text_model",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 8,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 12,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": 1,
+     "prefix": null,
+     "problem_type": null,
+     "projection_dim": 512,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": null,
+     "torchscript": false,
+     "transformers_version": "4.12.0.dev0",
+     "use_bfloat16": false,
+     "vocab_size": 49408
+   },
+   "text_config_dict": null,
+   "torch_dtype": "bfloat16",
+   "transformers_version": null,
+   "vision_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": null,
+     "attention_dropout": 0.0,
+     "bad_words_ids": null,
+     "bos_token_id": null,
+     "chunk_size_feed_forward": 0,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "dropout": 0.0,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "quick_gelu",
+     "hidden_size": 768,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "image_size": 224,
+     "initializer_factor": 1.0,
+     "initializer_range": 0.02,
+     "intermediate_size": 3072,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-05,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "min_length": 0,
+     "model_type": "clip_vision_model",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 12,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 12,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": null,
+     "patch_size": 16,
+     "prefix": null,
+     "problem_type": null,
+     "projection_dim": 512,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": null,
+     "torchscript": false,
+     "transformers_version": "4.12.0.dev0",
+     "use_bfloat16": false
+   },
+   "vision_config_dict": {
+     "patch_size": 16
+   }
+ }
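For orientation, the configuration above describes a CLIP-style dual encoder: a ViT-B/16-style vision tower (`patch_size` 16, `image_size` 224, `hidden_size` 768) and a 12-layer text tower (`hidden_size` 512) projecting into a shared 512-dimensional embedding space, with `auto_map` routing `AutoModel` to the custom `modeling_MMRet_CLIP.CLIPModel`, which is why the README usage passes `trust_remote_code=True`. A minimal sketch (an assumption-level illustration, not part of this commit) of inspecting these fields without downloading the weights:

```python
from transformers import AutoConfig

# Sketch: read the configuration shown above straight from the Hub.
# trust_remote_code=True matches the README usage; the config itself resolves
# to the stock CLIPConfig class (model_type: "clip").
config = AutoConfig.from_pretrained("JUNJIE99/MMRet-base", trust_remote_code=True)

print(config.projection_dim)              # 512 — shared embedding dimension
print(config.text_config.hidden_size)     # 512 — text tower width
print(config.vision_config.hidden_size)   # 768 — vision tower width
print(config.vision_config.patch_size)    # 16  — ViT-B/16-style patching
```

Loading the full model as in the README then instantiates the custom `CLIPModel` defined in `modeling_MMRet_CLIP.py` below.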
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:300ab945304bfa6d6e26046db1867815d326d7156c019fb39ba725472bc6c846
+ size 299289098
modeling_MMRet_CLIP.py ADDED
@@ -0,0 +1,1678 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch CLIP model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Any, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+ from PIL import Image
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
27
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
28
+ from transformers.modeling_utils import PreTrainedModel
29
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
30
+ from transformers.utils import (
31
+ ModelOutput,
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ is_flash_attn_2_available,
36
+ is_flash_attn_greater_or_equal_2_10,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
41
+ from transformers import CLIPProcessor
42
+
43
+ if is_flash_attn_2_available():
44
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ # General docstring
50
+ _CONFIG_FOR_DOC = "MMRet_CLIP"
51
+
52
+ # Image classification docstring
53
+ _IMAGE_CLASS_CHECKPOINT = "JUNJIE99/MMRet-base"
54
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
55
+
56
+
57
+ # contrastive loss function, adapted from
58
+ # https://sachinruk.github.io/blog/2021-03-07-clip.html
59
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
60
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
61
+
62
+
63
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
64
+ caption_loss = contrastive_loss(similarity)
65
+ image_loss = contrastive_loss(similarity.t())
66
+ return (caption_loss + image_loss) / 2.0
67
+
68
+
69
+ def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
70
+ """
71
+ This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
72
+ model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
73
+ """
74
+ square_tensor = torch.pow(tensor, 2)
75
+ sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
76
+ normed_tensor = torch.pow(sum_tensor, 0.5)
77
+ return normed_tensor
78
+
79
+
80
+ @dataclass
81
+ class CLIPVisionModelOutput(ModelOutput):
82
+ """
83
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
84
+
85
+ Args:
86
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
87
+ The image embeddings obtained by applying the projection layer to the pooler_output.
88
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
89
+ Sequence of hidden-states at the output of the last layer of the model.
90
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
91
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
92
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
93
+
94
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
95
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
96
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
97
+ sequence_length)`.
98
+
99
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
100
+ heads.
101
+ """
102
+
103
+ image_embeds: Optional[torch.FloatTensor] = None
104
+ last_hidden_state: torch.FloatTensor = None
105
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
106
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
107
+
108
+
109
+ @dataclass
110
+ class CLIPTextModelOutput(ModelOutput):
111
+ """
112
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
113
+
114
+ Args:
115
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
116
+ The text embeddings obtained by applying the projection layer to the pooler_output.
117
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
118
+ Sequence of hidden-states at the output of the last layer of the model.
119
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
120
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
121
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
122
+
123
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
124
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
125
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
126
+ sequence_length)`.
127
+
128
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
129
+ heads.
130
+ """
131
+
132
+ text_embeds: Optional[torch.FloatTensor] = None
133
+ last_hidden_state: torch.FloatTensor = None
134
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
135
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
136
+
137
+
138
+ @dataclass
139
+ class CLIPOutput(ModelOutput):
140
+ """
141
+ Args:
142
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
143
+ Contrastive loss for image-text similarity.
144
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
145
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
146
+ similarity scores.
147
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
148
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
149
+ similarity scores.
150
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
151
+ The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
152
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
153
+ The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
154
+ text_model_output (`BaseModelOutputWithPooling`):
155
+ The output of the [`CLIPTextModel`].
156
+ vision_model_output (`BaseModelOutputWithPooling`):
157
+ The output of the [`CLIPVisionModel`].
158
+ """
159
+
160
+ loss: Optional[torch.FloatTensor] = None
161
+ logits_per_image: torch.FloatTensor = None
162
+ logits_per_text: torch.FloatTensor = None
163
+ text_embeds: torch.FloatTensor = None
164
+ image_embeds: torch.FloatTensor = None
165
+ text_model_output: BaseModelOutputWithPooling = None
166
+ vision_model_output: BaseModelOutputWithPooling = None
167
+
168
+ def to_tuple(self) -> Tuple[Any]:
169
+ return tuple(
170
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
171
+ for k in self.keys()
172
+ )
173
+
174
+
175
+ class CLIPVisionEmbeddings(nn.Module):
176
+ def __init__(self, config: CLIPVisionConfig):
177
+ super().__init__()
178
+ self.config = config
179
+ self.embed_dim = config.hidden_size
180
+ self.image_size = config.image_size
181
+ self.patch_size = config.patch_size
182
+
183
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
184
+
185
+ self.patch_embedding = nn.Conv2d(
186
+ in_channels=config.num_channels,
187
+ out_channels=self.embed_dim,
188
+ kernel_size=self.patch_size,
189
+ stride=self.patch_size,
190
+ bias=False,
191
+ )
192
+
193
+ self.num_patches = (self.image_size // self.patch_size) ** 2
194
+ self.num_positions = self.num_patches + 1
195
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
196
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
197
+
198
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
199
+ batch_size = pixel_values.shape[0]
200
+ target_dtype = self.patch_embedding.weight.dtype
201
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
202
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
203
+
204
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
205
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
206
+ embeddings = embeddings + self.position_embedding(self.position_ids)
207
+ return embeddings
208
+
209
+
210
+ class CLIPTextEmbeddings(nn.Module):
211
+ def __init__(self, config: CLIPTextConfig):
212
+ super().__init__()
213
+ embed_dim = config.hidden_size
214
+
215
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
216
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
217
+
218
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
219
+ self.register_buffer(
220
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
221
+ )
222
+
223
+ def forward(
224
+ self,
225
+ input_ids: Optional[torch.LongTensor] = None,
226
+ position_ids: Optional[torch.LongTensor] = None,
227
+ inputs_embeds: Optional[torch.FloatTensor] = None,
228
+ ) -> torch.Tensor:
229
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
230
+
231
+ if position_ids is None:
232
+ position_ids = self.position_ids[:, :seq_length]
233
+
234
+ if inputs_embeds is None:
235
+ inputs_embeds = self.token_embedding(input_ids)
236
+
237
+ position_embeddings = self.position_embedding(position_ids)
238
+ embeddings = inputs_embeds + position_embeddings
239
+
240
+ return embeddings
241
+
242
+
243
+ class CLIPAttention(nn.Module):
244
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
245
+
246
+ def __init__(self, config):
247
+ super().__init__()
248
+ self.config = config
249
+ self.embed_dim = config.hidden_size
250
+ self.num_heads = config.num_attention_heads
251
+ self.head_dim = self.embed_dim // self.num_heads
252
+ if self.head_dim * self.num_heads != self.embed_dim:
253
+ raise ValueError(
254
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
255
+ f" {self.num_heads})."
256
+ )
257
+ self.scale = self.head_dim**-0.5
258
+ self.dropout = config.attention_dropout
259
+
260
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
261
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
262
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
263
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
264
+
265
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
266
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
267
+
268
+ def forward(
269
+ self,
270
+ hidden_states: torch.Tensor,
271
+ attention_mask: Optional[torch.Tensor] = None,
272
+ causal_attention_mask: Optional[torch.Tensor] = None,
273
+ output_attentions: Optional[bool] = False,
274
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
275
+ """Input shape: Batch x Time x Channel"""
276
+
277
+ bsz, tgt_len, embed_dim = hidden_states.size()
278
+
279
+ # get query proj
280
+ query_states = self.q_proj(hidden_states) * self.scale
281
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
282
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
283
+
284
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
285
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
286
+ key_states = key_states.view(*proj_shape)
287
+ value_states = value_states.view(*proj_shape)
288
+
289
+ src_len = key_states.size(1)
290
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
291
+
292
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
293
+ raise ValueError(
294
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
295
+ f" {attn_weights.size()}"
296
+ )
297
+
298
+ # apply the causal_attention_mask first
299
+ if causal_attention_mask is not None:
300
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
301
+ raise ValueError(
302
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
303
+ f" {causal_attention_mask.size()}"
304
+ )
305
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
306
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
307
+
308
+ if attention_mask is not None:
309
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
310
+ raise ValueError(
311
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
312
+ )
313
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
314
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
315
+
316
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
317
+
318
+ if output_attentions:
319
+ # this operation is a bit akward, but it's required to
320
+ # make sure that attn_weights keeps its gradient.
321
+ # In order to do so, attn_weights have to reshaped
322
+ # twice and have to be reused in the following
323
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
324
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
325
+ else:
326
+ attn_weights_reshaped = None
327
+
328
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
329
+
330
+ attn_output = torch.bmm(attn_probs, value_states)
331
+
332
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
333
+ raise ValueError(
334
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
335
+ f" {attn_output.size()}"
336
+ )
337
+
338
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
339
+ attn_output = attn_output.transpose(1, 2)
340
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
341
+
342
+ attn_output = self.out_proj(attn_output)
343
+
344
+ return attn_output, attn_weights_reshaped
345
+
346
+
347
+ class CLIPFlashAttention2(CLIPAttention):
348
+ """
349
+ CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays
350
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
351
+ flash attention and deal with padding tokens in case the input contains any of them.
352
+ """
353
+
354
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
355
+ def __init__(self, *args, **kwargs):
356
+ super().__init__(*args, **kwargs)
357
+
358
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
359
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
360
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
361
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
362
+
363
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ attention_mask: Optional[torch.Tensor] = None,
368
+ causal_attention_mask: Optional[torch.Tensor] = None,
369
+ output_attentions: Optional[bool] = False,
370
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
371
+ output_attentions = False
372
+
373
+ batch_size, q_len, _ = hidden_states.size()
374
+
375
+ query_states = self.q_proj(hidden_states)
376
+ key_states = self.k_proj(hidden_states)
377
+ value_states = self.v_proj(hidden_states)
378
+
379
+ # Flash attention requires the input to have the shape
380
+ # batch_size x seq_length x head_dim x hidden_dim
381
+ # therefore we just need to keep the original shape
382
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
383
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
384
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
385
+
386
+ dropout_rate = self.dropout if self.training else 0.0
387
+
388
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
389
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
390
+ # cast them back in the correct dtype just to be sure everything works as expected.
391
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
392
+ # in fp32.
393
+
394
+ input_dtype = query_states.dtype
395
+ if input_dtype == torch.float32:
396
+ if torch.is_autocast_enabled():
397
+ target_dtype = torch.get_autocast_gpu_dtype()
398
+ # Handle the case where the model is quantized
399
+ elif hasattr(self.config, "_pre_quantization_dtype"):
400
+ target_dtype = self.config._pre_quantization_dtype
401
+ else:
402
+ target_dtype = self.q_proj.weight.dtype
403
+
404
+ logger.warning_once(
405
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
406
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
407
+ f" {target_dtype}."
408
+ )
409
+
410
+ query_states = query_states.to(target_dtype)
411
+ key_states = key_states.to(target_dtype)
412
+ value_states = value_states.to(target_dtype)
413
+
414
+ attn_output = _flash_attention_forward(
415
+ query_states,
416
+ key_states,
417
+ value_states,
418
+ attention_mask,
419
+ q_len,
420
+ dropout=dropout_rate,
421
+ is_causal=causal_attention_mask is not None,
422
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
423
+ )
424
+
425
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
426
+ attn_output = self.out_proj(attn_output)
427
+
428
+ if not output_attentions:
429
+ attn_weights = None
430
+
431
+ return attn_output, attn_weights
432
+
433
+
434
+ class CLIPSdpaAttention(CLIPAttention):
435
+ """
436
+ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
437
+ `CLIPAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
438
+ SDPA API.
439
+ """
440
+
441
+ # Adapted from CLIPAttention.forward
442
+ def forward(
443
+ self,
444
+ hidden_states: torch.Tensor,
445
+ attention_mask: Optional[torch.Tensor] = None,
446
+ causal_attention_mask: Optional[torch.Tensor] = None,
447
+ output_attentions: Optional[bool] = False,
448
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
449
+ if output_attentions:
450
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
451
+ logger.warning_once(
452
+ "CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
453
+ "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
454
+ "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
455
+ 'be removed using the argument `attn_implementation="eager"` when loading the model.'
456
+ )
457
+ return super().forward(
458
+ hidden_states=hidden_states,
459
+ attention_mask=attention_mask,
460
+ causal_attention_mask=causal_attention_mask,
461
+ output_attentions=output_attentions,
462
+ )
463
+
464
+ # CLIP text model uses both `causal_attention_mask` and `attention_mask`
465
+ if attention_mask is not None and causal_attention_mask is not None:
466
+ attn_mask = attention_mask + causal_attention_mask
467
+ elif causal_attention_mask is not None:
468
+ attn_mask = causal_attention_mask
469
+ else:
470
+ attn_mask = attention_mask
471
+
472
+ bsz, tgt_len, embed_dim = hidden_states.size()
473
+
474
+ query_states = self.q_proj(hidden_states)
475
+ key_states = self.k_proj(hidden_states)
476
+ value_states = self.v_proj(hidden_states)
477
+
478
+ query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
479
+ key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
480
+ value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
481
+
482
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
483
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
484
+ if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
485
+ query_states = query_states.contiguous()
486
+ key_states = key_states.contiguous()
487
+ value_states = value_states.contiguous()
488
+
489
+ # CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially.
490
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
491
+ query_states,
492
+ key_states,
493
+ value_states,
494
+ attn_mask=attn_mask,
495
+ dropout_p=self.dropout if self.training else 0.0,
496
+ scale=self.scale,
497
+ )
498
+
499
+ attn_output = attn_output.transpose(1, 2)
500
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
501
+
502
+ attn_output = self.out_proj(attn_output)
503
+
504
+ return attn_output, None
505
+
506
+
507
+ CLIP_ATTENTION_CLASSES = {
508
+ "eager": CLIPAttention,
509
+ "sdpa": CLIPSdpaAttention,
510
+ "flash_attention_2": CLIPFlashAttention2,
511
+ }
512
+
513
+
514
+ class CLIPMLP(nn.Module):
515
+ def __init__(self, config):
516
+ super().__init__()
517
+ self.config = config
518
+ self.activation_fn = ACT2FN[config.hidden_act]
519
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
520
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
521
+
522
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
523
+ hidden_states = self.fc1(hidden_states)
524
+ hidden_states = self.activation_fn(hidden_states)
525
+ hidden_states = self.fc2(hidden_states)
526
+ return hidden_states
527
+
528
+
529
+ class CLIPEncoderLayer(nn.Module):
530
+ def __init__(self, config: CLIPConfig):
531
+ super().__init__()
532
+ self.embed_dim = config.hidden_size
533
+ self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config)
534
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
535
+ self.mlp = CLIPMLP(config)
536
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
537
+
538
+ def forward(
539
+ self,
540
+ hidden_states: torch.Tensor,
541
+ attention_mask: torch.Tensor,
542
+ causal_attention_mask: torch.Tensor,
543
+ output_attentions: Optional[bool] = False,
544
+ ) -> Tuple[torch.FloatTensor]:
545
+ """
546
+ Args:
547
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
548
+ attention_mask (`torch.FloatTensor`): attention mask of size
549
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
550
+ `(config.encoder_attention_heads,)`.
551
+ output_attentions (`bool`, *optional*):
552
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
553
+ returned tensors for more detail.
554
+ """
555
+ residual = hidden_states
556
+
557
+ hidden_states = self.layer_norm1(hidden_states)
558
+ hidden_states, attn_weights = self.self_attn(
559
+ hidden_states=hidden_states,
560
+ attention_mask=attention_mask,
561
+ causal_attention_mask=causal_attention_mask,
562
+ output_attentions=output_attentions,
563
+ )
564
+ hidden_states = residual + hidden_states
565
+
566
+ residual = hidden_states
567
+ hidden_states = self.layer_norm2(hidden_states)
568
+ hidden_states = self.mlp(hidden_states)
569
+ hidden_states = residual + hidden_states
570
+
571
+ outputs = (hidden_states,)
572
+
573
+ if output_attentions:
574
+ outputs += (attn_weights,)
575
+
576
+ return outputs
577
+
578
+
579
+ class CLIPPreTrainedModel(PreTrainedModel):
580
+ """
581
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
582
+ models.
583
+ """
584
+
585
+ config_class = CLIPConfig
586
+ base_model_prefix = "clip"
587
+ supports_gradient_checkpointing = True
588
+ _supports_sdpa = True
589
+ _supports_flash_attn_2 = True
590
+
591
+ def _init_weights(self, module):
592
+ """Initialize the weights"""
593
+ factor = self.config.initializer_factor
594
+ if isinstance(module, CLIPTextEmbeddings):
595
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
596
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
597
+ elif isinstance(module, CLIPVisionEmbeddings):
598
+ factor = self.config.initializer_factor
599
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
600
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
601
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
602
+ elif isinstance(module, CLIPAttention):
603
+ factor = self.config.initializer_factor
604
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
605
+ out_proj_std = (module.embed_dim**-0.5) * factor
606
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
607
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
608
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
609
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
610
+ elif isinstance(module, CLIPMLP):
611
+ factor = self.config.initializer_factor
612
+ in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
613
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
614
+ nn.init.normal_(module.fc1.weight, std=fc_std)
615
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
616
+ elif isinstance(module, CLIPModel):
617
+ nn.init.normal_(
618
+ module.text_projection.weight,
619
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
620
+ )
621
+ nn.init.normal_(
622
+ module.visual_projection.weight,
623
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
624
+ )
625
+ elif isinstance(module, CLIPVisionModelWithProjection):
626
+ nn.init.normal_(
627
+ module.visual_projection.weight,
628
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
629
+ )
630
+ elif isinstance(module, CLIPTextModelWithProjection):
631
+ nn.init.normal_(
632
+ module.text_projection.weight,
633
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
634
+ )
635
+ elif isinstance(module, CLIPForImageClassification):
636
+ nn.init.normal_(
637
+ module.classifier.weight,
638
+ std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
639
+ )
640
+
641
+ if isinstance(module, nn.LayerNorm):
642
+ module.bias.data.zero_()
643
+ module.weight.data.fill_(1.0)
644
+ if isinstance(module, nn.Linear) and module.bias is not None:
645
+ module.bias.data.zero_()
646
+
647
+
648
+ CLIP_START_DOCSTRING = r"""
649
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
650
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
651
+ etc.)
652
+
653
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
654
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
655
+ and behavior.
656
+
657
+ Parameters:
658
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
659
+ Initializing with a config file does not load the weights associated with the model, only the
660
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
661
+ """
662
+
663
+ CLIP_TEXT_INPUTS_DOCSTRING = r"""
664
+ Args:
665
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
666
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
667
+ it.
668
+
669
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
670
+ [`PreTrainedTokenizer.__call__`] for details.
671
+
672
+ [What are input IDs?](../glossary#input-ids)
673
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
674
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
675
+
676
+ - 1 for tokens that are **not masked**,
677
+ - 0 for tokens that are **masked**.
678
+
679
+ [What are attention masks?](../glossary#attention-mask)
680
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
681
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
682
+ config.max_position_embeddings - 1]`.
683
+
684
+ [What are position IDs?](../glossary#position-ids)
685
+ output_attentions (`bool`, *optional*):
686
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
687
+ tensors for more detail.
688
+ output_hidden_states (`bool`, *optional*):
689
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
690
+ more detail.
691
+ return_dict (`bool`, *optional*):
692
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
693
+ """
694
+
695
+ CLIP_VISION_INPUTS_DOCSTRING = r"""
696
+ Args:
697
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
698
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
699
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
700
+ output_attentions (`bool`, *optional*):
701
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
702
+ tensors for more detail.
703
+ output_hidden_states (`bool`, *optional*):
704
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
705
+ more detail.
706
+ return_dict (`bool`, *optional*):
707
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
708
+ """
709
+
710
+ CLIP_INPUTS_DOCSTRING = r"""
711
+ Args:
712
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
713
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
714
+ it.
715
+
716
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
717
+ [`PreTrainedTokenizer.__call__`] for details.
718
+
719
+ [What are input IDs?](../glossary#input-ids)
720
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
721
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
722
+
723
+ - 1 for tokens that are **not masked**,
724
+ - 0 for tokens that are **masked**.
725
+
726
+ [What are attention masks?](../glossary#attention-mask)
727
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
728
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
729
+ config.max_position_embeddings - 1]`.
730
+
731
+ [What are position IDs?](../glossary#position-ids)
732
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
733
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
734
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
735
+ return_loss (`bool`, *optional*):
736
+ Whether or not to return the contrastive loss.
737
+ output_attentions (`bool`, *optional*):
738
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
739
+ tensors for more detail.
740
+ output_hidden_states (`bool`, *optional*):
741
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
742
+ more detail.
743
+ return_dict (`bool`, *optional*):
744
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
745
+ """
746
+
747
+
748
+ class CLIPEncoder(nn.Module):
749
+ """
750
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
751
+ [`CLIPEncoderLayer`].
752
+
753
+ Args:
754
+ config: CLIPConfig
755
+ """
756
+
757
+ def __init__(self, config: CLIPConfig):
758
+ super().__init__()
759
+ self.config = config
760
+ self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
761
+ self.gradient_checkpointing = False
762
+
763
+ def forward(
764
+ self,
765
+ inputs_embeds,
766
+ attention_mask: Optional[torch.Tensor] = None,
767
+ causal_attention_mask: Optional[torch.Tensor] = None,
768
+ output_attentions: Optional[bool] = None,
769
+ output_hidden_states: Optional[bool] = None,
770
+ return_dict: Optional[bool] = None,
771
+ ) -> Union[Tuple, BaseModelOutput]:
772
+ r"""
773
+ Args:
774
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
775
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
776
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
777
+ than the model's internal embedding lookup matrix.
778
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
779
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
780
+
781
+ - 1 for tokens that are **not masked**,
782
+ - 0 for tokens that are **masked**.
783
+
784
+ [What are attention masks?](../glossary#attention-mask)
785
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
786
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
787
+
788
+ - 1 for tokens that are **not masked**,
789
+ - 0 for tokens that are **masked**.
790
+
791
+ [What are attention masks?](../glossary#attention-mask)
792
+ output_attentions (`bool`, *optional*):
793
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
794
+ returned tensors for more detail.
795
+ output_hidden_states (`bool`, *optional*):
796
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
797
+ for more detail.
798
+ return_dict (`bool`, *optional*):
799
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
800
+ """
801
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
802
+ output_hidden_states = (
803
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
804
+ )
805
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
806
+
807
+ encoder_states = () if output_hidden_states else None
808
+ all_attentions = () if output_attentions else None
809
+
810
+ hidden_states = inputs_embeds
811
+ for idx, encoder_layer in enumerate(self.layers):
812
+ if output_hidden_states:
813
+ encoder_states = encoder_states + (hidden_states,)
814
+ if self.gradient_checkpointing and self.training:
815
+ layer_outputs = self._gradient_checkpointing_func(
816
+ encoder_layer.__call__,
817
+ hidden_states,
818
+ attention_mask,
819
+ causal_attention_mask,
820
+ output_attentions,
821
+ )
822
+ else:
823
+ layer_outputs = encoder_layer(
824
+ hidden_states,
825
+ attention_mask,
826
+ causal_attention_mask,
827
+ output_attentions=output_attentions,
828
+ )
829
+
830
+ hidden_states = layer_outputs[0]
831
+
832
+ if output_attentions:
833
+ all_attentions = all_attentions + (layer_outputs[1],)
834
+
835
+ if output_hidden_states:
836
+ encoder_states = encoder_states + (hidden_states,)
837
+
838
+ if not return_dict:
839
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
840
+ return BaseModelOutput(
841
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
842
+ )
843
+
844
+
845
+ class CLIPTextTransformer(nn.Module):
846
+ def __init__(self, config: CLIPTextConfig):
847
+ super().__init__()
848
+ self.config = config
849
+ embed_dim = config.hidden_size
850
+ self.embeddings = CLIPTextEmbeddings(config)
851
+ self.encoder = CLIPEncoder(config)
852
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
853
+
854
+ # For `pooled_output` computation
855
+ self.eos_token_id = config.eos_token_id
856
+
857
+ # For attention mask, it differs between `flash_attention_2` and other attention implementations
858
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
859
+
860
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
861
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
862
+ def forward(
863
+ self,
864
+ input_ids: Optional[torch.Tensor] = None,
865
+ attention_mask: Optional[torch.Tensor] = None,
866
+ position_ids: Optional[torch.Tensor] = None,
867
+ output_attentions: Optional[bool] = None,
868
+ output_hidden_states: Optional[bool] = None,
869
+ return_dict: Optional[bool] = None,
870
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
871
+ r"""
872
+ Returns:
873
+
874
+ """
875
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
876
+ output_hidden_states = (
877
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
878
+ )
879
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
880
+
881
+ if input_ids is None:
882
+ raise ValueError("You have to specify input_ids")
883
+
884
+ input_shape = input_ids.size()
885
+ input_ids = input_ids.view(-1, input_shape[-1])
886
+
887
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
888
+
889
+ # CLIP's text model uses causal mask, prepare it here.
890
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
891
+ causal_attention_mask = _create_4d_causal_attention_mask(
892
+ input_shape, hidden_states.dtype, device=hidden_states.device
893
+ )
894
+
895
+ # expand attention_mask
896
+ if attention_mask is not None and not self._use_flash_attention_2:
897
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
898
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
899
+
900
+ encoder_outputs = self.encoder(
901
+ inputs_embeds=hidden_states,
902
+ attention_mask=attention_mask,
903
+ causal_attention_mask=causal_attention_mask,
904
+ output_attentions=output_attentions,
905
+ output_hidden_states=output_hidden_states,
906
+ return_dict=return_dict,
907
+ )
908
+
909
+ last_hidden_state = encoder_outputs[0]
910
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
911
+
912
+ if self.eos_token_id == 2:
913
+ # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
914
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
915
+ # ------------------------------------------------------------
916
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
917
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
918
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
919
+ pooled_output = last_hidden_state[
920
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
921
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
922
+ ]
923
+ else:
924
+ # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
925
+ pooled_output = last_hidden_state[
926
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
927
+ # We need to get the first position of the `eos_token_id` value (`pad_token_id` might be equal to `eos_token_id`)
928
+ # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
929
+ (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
930
+ .int()
931
+ .argmax(dim=-1),
932
+ ]
933
+
934
+ if not return_dict:
935
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
936
+
937
+ return BaseModelOutputWithPooling(
938
+ last_hidden_state=last_hidden_state,
939
+ pooler_output=pooled_output,
940
+ hidden_states=encoder_outputs.hidden_states,
941
+ attentions=encoder_outputs.attentions,
942
+ )
943
+
944
+
945
+ @add_start_docstrings(
946
+ """The text model from CLIP without any head or projection on top.""",
947
+ CLIP_START_DOCSTRING,
948
+ )
949
+ class CLIPTextModel(CLIPPreTrainedModel):
950
+ config_class = CLIPTextConfig
951
+
952
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
953
+
954
+ def __init__(self, config: CLIPTextConfig):
955
+ super().__init__(config)
956
+ self.text_model = CLIPTextTransformer(config)
957
+ # Initialize weights and apply final processing
958
+ self.post_init()
959
+
960
+ def get_input_embeddings(self) -> nn.Module:
961
+ return self.text_model.embeddings.token_embedding
962
+
963
+ def set_input_embeddings(self, value):
964
+ self.text_model.embeddings.token_embedding = value
965
+
966
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
967
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
968
+ def forward(
969
+ self,
970
+ input_ids: Optional[torch.Tensor] = None,
971
+ attention_mask: Optional[torch.Tensor] = None,
972
+ position_ids: Optional[torch.Tensor] = None,
973
+ output_attentions: Optional[bool] = None,
974
+ output_hidden_states: Optional[bool] = None,
975
+ return_dict: Optional[bool] = None,
976
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
977
+ r"""
978
+ Returns:
979
+
980
+ Examples:
981
+
982
+ ```python
983
+ >>> from transformers import AutoTokenizer, CLIPTextModel
984
+
985
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
986
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
987
+
988
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
989
+
990
+ >>> outputs = model(**inputs)
991
+ >>> last_hidden_state = outputs.last_hidden_state
992
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
993
+ ```"""
994
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
995
+
996
+ return self.text_model(
997
+ input_ids=input_ids,
998
+ attention_mask=attention_mask,
999
+ position_ids=position_ids,
1000
+ output_attentions=output_attentions,
1001
+ output_hidden_states=output_hidden_states,
1002
+ return_dict=return_dict,
1003
+ )
1004
+
1005
+
1006
+ class CLIPVisionTransformer(nn.Module):
1007
+ def __init__(self, config: CLIPVisionConfig):
1008
+ super().__init__()
1009
+ self.config = config
1010
+ embed_dim = config.hidden_size
1011
+
1012
+ self.embeddings = CLIPVisionEmbeddings(config)
1013
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1014
+ self.encoder = CLIPEncoder(config)
1015
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1016
+
1017
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1018
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
1019
+ def forward(
1020
+ self,
1021
+ pixel_values: Optional[torch.FloatTensor] = None,
1022
+ output_attentions: Optional[bool] = None,
1023
+ output_hidden_states: Optional[bool] = None,
1024
+ return_dict: Optional[bool] = None,
1025
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1026
+ r"""
1027
+ Returns:
1028
+
1029
+ """
1030
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1031
+ output_hidden_states = (
1032
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1033
+ )
1034
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1035
+
1036
+ if pixel_values is None:
1037
+ raise ValueError("You have to specify pixel_values")
1038
+
1039
+ hidden_states = self.embeddings(pixel_values)
1040
+ hidden_states = self.pre_layrnorm(hidden_states)
1041
+
1042
+ encoder_outputs = self.encoder(
1043
+ inputs_embeds=hidden_states,
1044
+ output_attentions=output_attentions,
1045
+ output_hidden_states=output_hidden_states,
1046
+ return_dict=return_dict,
1047
+ )
1048
+
1049
+ last_hidden_state = encoder_outputs[0]
1050
+ pooled_output = last_hidden_state[:, 0, :]
1051
+ pooled_output = self.post_layernorm(pooled_output)
1052
+
1053
+ if not return_dict:
1054
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1055
+
1056
+ return BaseModelOutputWithPooling(
1057
+ last_hidden_state=last_hidden_state,
1058
+ pooler_output=pooled_output,
1059
+ hidden_states=encoder_outputs.hidden_states,
1060
+ attentions=encoder_outputs.attentions,
1061
+ )
1062
+
1063
+
1064
+ @add_start_docstrings(
1065
+ """The vision model from CLIP without any head or projection on top.""",
1066
+ CLIP_START_DOCSTRING,
1067
+ )
1068
+ class CLIPVisionModel(CLIPPreTrainedModel):
1069
+ config_class = CLIPVisionConfig
1070
+ main_input_name = "pixel_values"
1071
+ _no_split_modules = ["CLIPEncoderLayer"]
1072
+
1073
+ def __init__(self, config: CLIPVisionConfig):
1074
+ super().__init__(config)
1075
+ self.vision_model = CLIPVisionTransformer(config)
1076
+ # Initialize weights and apply final processing
1077
+ self.post_init()
1078
+
1079
+ def get_input_embeddings(self) -> nn.Module:
1080
+ return self.vision_model.embeddings.patch_embedding
1081
+
1082
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1083
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
1084
+ def forward(
1085
+ self,
1086
+ pixel_values: Optional[torch.FloatTensor] = None,
1087
+ output_attentions: Optional[bool] = None,
1088
+ output_hidden_states: Optional[bool] = None,
1089
+ return_dict: Optional[bool] = None,
1090
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1091
+ r"""
1092
+ Returns:
1093
+
1094
+ Examples:
1095
+
1096
+ ```python
1097
+ >>> from PIL import Image
1098
+ >>> import requests
1099
+ >>> from transformers import AutoProcessor, CLIPVisionModel
1100
+
1101
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
1102
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1103
+
1104
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1105
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1106
+
1107
+ >>> inputs = processor(images=image, return_tensors="pt")
1108
+
1109
+ >>> outputs = model(**inputs)
1110
+ >>> last_hidden_state = outputs.last_hidden_state
1111
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
1112
+ ```"""
1113
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1114
+
1115
+ return self.vision_model(
1116
+ pixel_values=pixel_values,
1117
+ output_attentions=output_attentions,
1118
+ output_hidden_states=output_hidden_states,
1119
+ return_dict=return_dict,
1120
+ )
1121
+
1122
+
1123
+ @add_start_docstrings(CLIP_START_DOCSTRING)
1124
+ class CLIPModel(CLIPPreTrainedModel):
1125
+ config_class = CLIPConfig
1126
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
1127
+
1128
+ def __init__(self, config: CLIPConfig):
1129
+ super().__init__(config)
1130
+
1131
+ if not isinstance(config.text_config, CLIPTextConfig):
1132
+ raise TypeError(
1133
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
1134
+ f" {type(config.text_config)}."
1135
+ )
1136
+
1137
+ if not isinstance(config.vision_config, CLIPVisionConfig):
1138
+ raise TypeError(
1139
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
1140
+ f" {type(config.vision_config)}."
1141
+ )
1142
+
1143
+ text_config = config.text_config
1144
+ vision_config = config.vision_config
1145
+
1146
+ self.projection_dim = config.projection_dim
1147
+ self.text_embed_dim = text_config.hidden_size
1148
+ self.vision_embed_dim = vision_config.hidden_size
1149
+
1150
+ text_model = CLIPTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
1151
+ self.text_model = text_model.text_model
1152
+
1153
+ vision_model = CLIPVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
1154
+ self.vision_model = vision_model.vision_model
1155
+
1156
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
1157
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
1158
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
1159
+
1160
+ # Initialize weights and apply final processing
1161
+ self.post_init()
1162
+
1163
+ def set_processor(self, model_name):
1164
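+ # Attach a CLIPProcessor (tokenizer + image processor) so that data_process()/encode() can prepare raw inputs.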
+ self.processor = CLIPProcessor.from_pretrained(model_name)
1165
+
1166
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1167
+ def get_text_features(
1168
+ self,
1169
+ input_ids: Optional[torch.Tensor] = None,
1170
+ attention_mask: Optional[torch.Tensor] = None,
1171
+ position_ids: Optional[torch.Tensor] = None,
1172
+ output_attentions: Optional[bool] = None,
1173
+ output_hidden_states: Optional[bool] = None,
1174
+ return_dict: Optional[bool] = None,
1175
+ ) -> torch.FloatTensor:
1176
+ r"""
1177
+ Returns:
1178
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1179
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
1180
+
1181
+ Examples:
1182
+
1183
+ ```python
1184
+ >>> from transformers import AutoTokenizer, CLIPModel
1185
+
1186
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1187
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1188
+
1189
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1190
+ >>> text_features = model.get_text_features(**inputs)
1191
+ ```"""
1192
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1193
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1194
+ output_hidden_states = (
1195
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1196
+ )
1197
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1198
+
1199
+ text_outputs = self.text_model(
1200
+ input_ids=input_ids,
1201
+ attention_mask=attention_mask,
1202
+ position_ids=position_ids,
1203
+ output_attentions=output_attentions,
1204
+ output_hidden_states=output_hidden_states,
1205
+ return_dict=return_dict,
1206
+ )
1207
+
1208
+ pooled_output = text_outputs[1]
1209
+ text_features = self.text_projection(pooled_output)
1210
+
1211
+ return text_features
1212
+
1213
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1214
+ def get_image_features(
1215
+ self,
1216
+ pixel_values: Optional[torch.FloatTensor] = None,
1217
+ output_attentions: Optional[bool] = None,
1218
+ output_hidden_states: Optional[bool] = None,
1219
+ return_dict: Optional[bool] = None,
1220
+ ) -> torch.FloatTensor:
1221
+ r"""
1222
+ Returns:
1223
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1224
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
1225
+
1226
+ Examples:
1227
+
1228
+ ```python
1229
+ >>> from PIL import Image
1230
+ >>> import requests
1231
+ >>> from transformers import AutoProcessor, CLIPModel
1232
+
1233
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1234
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1235
+
1236
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1237
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1238
+
1239
+ >>> inputs = processor(images=image, return_tensors="pt")
1240
+
1241
+ >>> image_features = model.get_image_features(**inputs)
1242
+ ```"""
1243
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1244
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1245
+ output_hidden_states = (
1246
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1247
+ )
1248
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1249
+
1250
+ vision_outputs = self.vision_model(
1251
+ pixel_values=pixel_values,
1252
+ output_attentions=output_attentions,
1253
+ output_hidden_states=output_hidden_states,
1254
+ return_dict=return_dict,
1255
+ )
1256
+
1257
+ pooled_output = vision_outputs[1] # pooled_output
1258
+ image_features = self.visual_projection(pooled_output)
1259
+
1260
+ return image_features
1261
+
1262
+
1263
+ def encode_image(self, images):
1264
+ embeddings = self.get_image_features(images)
1265
+ embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
1266
+ return embeddings
1267
+
1268
+ def encode_text(self, text):
1269
+ embeddings = self.get_text_features(**text)
1270
+ embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
1271
+ return embeddings
1272
+
1273
+ def encode_multimodal(self, images, text):
1274
+ text_embeddings = self.get_text_features(**text)
1275
+ image_embeddings = self.get_image_features(images)
1276
+
1277
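+ # Fuse the two modalities by element-wise addition of the text and image embeddings, then L2-normalize the sum.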
+ embeddings = text_embeddings + image_embeddings
1278
+ embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
1279
+
1280
+ return embeddings.contiguous()
1281
+
1282
+ def data_process(self, images=None, text=None):
1283
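+ # Dispatch on which inputs were provided: text-only, image-only, or paired image + text (multimodal).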
+ if images is None and text is not None:
1284
+ text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
1285
+
1286
+ return images, text, "text"
1287
+ elif images is not None and text is None:
1288
+ if isinstance(images, str):
1289
+ images = Image.open(images).convert("RGB")
1290
+ elif isinstance(images, list):
1291
+ images = [Image.open(image).convert("RGB") for image in images]
1292
+ images = self.processor(images=images, return_tensors="pt").to(self.device)
1293
+ images = images["pixel_values"]
1294
+ return images, text, "images"
1295
+ elif images is not None and text is not None:
1296
+ assert type(images) == type(text), "images and text must be the same type: list or str"
1297
+ if isinstance(images, str):
1298
+ images = Image.open(images).convert("RGB")
1299
+ elif isinstance(images, list):
1300
+ assert len(images) == len(text), "images and text must be lists of the same length when passed as lists"
1301
+ images = [Image.open(image).convert("RGB") for image in images]
1302
+ images = self.processor(images=images, return_tensors="pt").to(self.device)
1303
+ images = images["pixel_values"]
1304
+ text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
1305
+ return images, text, "multimodal"
1306
+ else:
1307
+ raise ValueError("images and text cannot both be None")
1308
+
1309
+ def encode(self, images=None, text=None):
1310
+ images, text, data_type = self.data_process(images, text)
1311
+ if data_type == "images":
1312
+ return self.encode_image(images)
1313
+ elif data_type == "text":
1314
+ return self.encode_text(text)
1315
+ elif data_type == "multimodal":
1316
+ return self.encode_multimodal(images, text)
1317
+
1318
+
1319
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
1320
+ @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
1321
+ def forward(
1322
+ self,
1323
+ input_ids: Optional[torch.LongTensor] = None,
1324
+ pixel_values: Optional[torch.FloatTensor] = None,
1325
+ attention_mask: Optional[torch.Tensor] = None,
1326
+ position_ids: Optional[torch.LongTensor] = None,
1327
+ return_loss: Optional[bool] = None,
1328
+ output_attentions: Optional[bool] = None,
1329
+ output_hidden_states: Optional[bool] = None,
1330
+ return_dict: Optional[bool] = None,
1331
+ ) -> Union[Tuple, CLIPOutput]:
1332
+ r"""
1333
+ Returns:
1334
+
1335
+ Examples:
1336
+
1337
+ ```python
1338
+ >>> from PIL import Image
1339
+ >>> import requests
1340
+ >>> from transformers import AutoProcessor, CLIPModel
1341
+
1342
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1343
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1344
+
1345
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1346
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1347
+
1348
+ >>> inputs = processor(
1349
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1350
+ ... )
1351
+
1352
+ >>> outputs = model(**inputs)
1353
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1354
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1355
+ ```"""
1356
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1357
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1358
+ output_hidden_states = (
1359
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1360
+ )
1361
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1362
+
1363
+ vision_outputs = self.vision_model(
1364
+ pixel_values=pixel_values,
1365
+ output_attentions=output_attentions,
1366
+ output_hidden_states=output_hidden_states,
1367
+ return_dict=return_dict,
1368
+ )
1369
+
1370
+ text_outputs = self.text_model(
1371
+ input_ids=input_ids,
1372
+ attention_mask=attention_mask,
1373
+ position_ids=position_ids,
1374
+ output_attentions=output_attentions,
1375
+ output_hidden_states=output_hidden_states,
1376
+ return_dict=return_dict,
1377
+ )
1378
+
1379
+ image_embeds = vision_outputs[1]
1380
+ image_embeds = self.visual_projection(image_embeds)
1381
+
1382
+ text_embeds = text_outputs[1]
1383
+ text_embeds = self.text_projection(text_embeds)
1384
+
1385
+ # normalized features
1386
+ image_embeds = image_embeds / _get_vector_norm(image_embeds)
1387
+ text_embeds = text_embeds / _get_vector_norm(text_embeds)
1388
+
1389
+ # cosine similarity as logits
1390
+ logit_scale = self.logit_scale.exp()
1391
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * logit_scale.to(
1392
+ text_embeds.device
1393
+ )
1394
+ logits_per_image = logits_per_text.t()
1395
+
1396
+ loss = None
1397
+ if return_loss:
1398
+ loss = clip_loss(logits_per_text)
1399
+
1400
+ if not return_dict:
1401
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1402
+ return ((loss,) + output) if loss is not None else output
1403
+
1404
+ return CLIPOutput(
1405
+ loss=loss,
1406
+ logits_per_image=logits_per_image,
1407
+ logits_per_text=logits_per_text,
1408
+ text_embeds=text_embeds,
1409
+ image_embeds=image_embeds,
1410
+ text_model_output=text_outputs,
1411
+ vision_model_output=vision_outputs,
1412
+ )
1413
+
1414
+
1415
+ @add_start_docstrings(
1416
+ """
1417
+ CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output).
1418
+ """,
1419
+ CLIP_START_DOCSTRING,
1420
+ )
1421
+ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
1422
+ config_class = CLIPTextConfig
1423
+
1424
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
1425
+
1426
+ def __init__(self, config: CLIPTextConfig):
1427
+ super().__init__(config)
1428
+
1429
+ text_model = CLIPTextModel._from_config(config, attn_implementation=config._attn_implementation)
1430
+ self.text_model = text_model.text_model
1431
+
1432
+ self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1433
+
1434
+ # Initialize weights and apply final processing
1435
+ self.post_init()
1436
+
1437
+ def get_input_embeddings(self) -> nn.Module:
1438
+ return self.text_model.embeddings.token_embedding
1439
+
1440
+ def set_input_embeddings(self, value):
1441
+ self.text_model.embeddings.token_embedding = value
1442
+
1443
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1444
+ @replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig)
1445
+ def forward(
1446
+ self,
1447
+ input_ids: Optional[torch.Tensor] = None,
1448
+ attention_mask: Optional[torch.Tensor] = None,
1449
+ position_ids: Optional[torch.Tensor] = None,
1450
+ output_attentions: Optional[bool] = None,
1451
+ output_hidden_states: Optional[bool] = None,
1452
+ return_dict: Optional[bool] = None,
1453
+ ) -> Union[Tuple, CLIPTextModelOutput]:
1454
+ r"""
1455
+ Returns:
1456
+
1457
+ Examples:
1458
+
1459
+ ```python
1460
+ >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection
1461
+
1462
+ >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1463
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1464
+
1465
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1466
+
1467
+ >>> outputs = model(**inputs)
1468
+ >>> text_embeds = outputs.text_embeds
1469
+ ```"""
1470
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1471
+
1472
+ text_outputs = self.text_model(
1473
+ input_ids=input_ids,
1474
+ attention_mask=attention_mask,
1475
+ position_ids=position_ids,
1476
+ output_attentions=output_attentions,
1477
+ output_hidden_states=output_hidden_states,
1478
+ return_dict=return_dict,
1479
+ )
1480
+
1481
+ pooled_output = text_outputs[1]
1482
+
1483
+ text_embeds = self.text_projection(pooled_output)
1484
+
1485
+ if not return_dict:
1486
+ outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
1487
+ return tuple(output for output in outputs if output is not None)
1488
+
1489
+ return CLIPTextModelOutput(
1490
+ text_embeds=text_embeds,
1491
+ last_hidden_state=text_outputs.last_hidden_state,
1492
+ hidden_states=text_outputs.hidden_states,
1493
+ attentions=text_outputs.attentions,
1494
+ )
1495
+
1496
+
1497
+ @add_start_docstrings(
1498
+ """
1499
+ CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
1500
+ """,
1501
+ CLIP_START_DOCSTRING,
1502
+ )
1503
+ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
1504
+ config_class = CLIPVisionConfig
1505
+ main_input_name = "pixel_values"
1506
+
1507
+ def __init__(self, config: CLIPVisionConfig):
1508
+ super().__init__(config)
1509
+
1510
+ vision_model = CLIPVisionModel._from_config(config, attn_implementation=config._attn_implementation)
1511
+ self.vision_model = vision_model.vision_model
1512
+
1513
+ self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1514
+
1515
+ # Initialize weights and apply final processing
1516
+ self.post_init()
1517
+
1518
+ def get_input_embeddings(self) -> nn.Module:
1519
+ return self.vision_model.embeddings.patch_embedding
1520
+
1521
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1522
+ @replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig)
1523
+ def forward(
1524
+ self,
1525
+ pixel_values: Optional[torch.FloatTensor] = None,
1526
+ output_attentions: Optional[bool] = None,
1527
+ output_hidden_states: Optional[bool] = None,
1528
+ return_dict: Optional[bool] = None,
1529
+ ) -> Union[Tuple, CLIPVisionModelOutput]:
1530
+ r"""
1531
+ Returns:
1532
+
1533
+ Examples:
1534
+
1535
+ ```python
1536
+ >>> from PIL import Image
1537
+ >>> import requests
1538
+ >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection
1539
+
1540
+ >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1541
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1542
+
1543
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1544
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1545
+
1546
+ >>> inputs = processor(images=image, return_tensors="pt")
1547
+
1548
+ >>> outputs = model(**inputs)
1549
+ >>> image_embeds = outputs.image_embeds
1550
+ ```"""
1551
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1552
+
1553
+ vision_outputs = self.vision_model(
1554
+ pixel_values=pixel_values,
1555
+ output_attentions=output_attentions,
1556
+ output_hidden_states=output_hidden_states,
1557
+ return_dict=return_dict,
1558
+ )
1559
+
1560
+ pooled_output = vision_outputs[1] # pooled_output
1561
+
1562
+ image_embeds = self.visual_projection(pooled_output)
1563
+
1564
+ if not return_dict:
1565
+ outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
1566
+ return tuple(output for output in outputs if output is not None)
1567
+
1568
+ return CLIPVisionModelOutput(
1569
+ image_embeds=image_embeds,
1570
+ last_hidden_state=vision_outputs.last_hidden_state,
1571
+ hidden_states=vision_outputs.hidden_states,
1572
+ attentions=vision_outputs.attentions,
1573
+ )
1574
+
1575
+
1576
+ @add_start_docstrings(
1577
+ """
1578
+ CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
1579
+ the patch tokens) e.g. for ImageNet.
1580
+ """,
1581
+ CLIP_START_DOCSTRING,
1582
+ )
1583
+ class CLIPForImageClassification(CLIPPreTrainedModel):
1584
+ main_input_name = "pixel_values"
1585
+
1586
+ def __init__(self, config: CLIPConfig) -> None:
1587
+ super().__init__(config)
1588
+
1589
+ self.num_labels = config.num_labels
1590
+ vision_model = CLIPVisionModel._from_config(
1591
+ config.vision_config, attn_implementation=config._attn_implementation
1592
+ )
1593
+ self.vision_model = vision_model.vision_model
1594
+
1595
+ # Classifier head
1596
+ self.classifier = (
1597
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1598
+ )
1599
+
1600
+ # Initialize weights and apply final processing
1601
+ self.post_init()
1602
+
1603
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
1604
+ @add_code_sample_docstrings(
1605
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
1606
+ output_type=ImageClassifierOutput,
1607
+ config_class=_CONFIG_FOR_DOC,
1608
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1609
+ )
1610
+ def forward(
1611
+ self,
1612
+ pixel_values: Optional[torch.Tensor] = None,
1613
+ labels: Optional[torch.Tensor] = None,
1614
+ output_attentions: Optional[bool] = None,
1615
+ output_hidden_states: Optional[bool] = None,
1616
+ return_dict: Optional[bool] = None,
1617
+ ) -> Union[tuple, ImageClassifierOutput]:
1618
+ r"""
1619
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1620
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1621
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1622
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1623
+ """
1624
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1625
+ output_hidden_states = (
1626
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1627
+ )
1628
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1629
+
1630
+ outputs = self.vision_model(
1631
+ pixel_values,
1632
+ output_attentions=output_attentions,
1633
+ output_hidden_states=output_hidden_states,
1634
+ return_dict=return_dict,
1635
+ )
1636
+
1637
+ sequence_output = outputs[0]
1638
+
1639
+ # average pool the patch tokens
1640
+ sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
1641
+ # apply classifier
1642
+ logits = self.classifier(sequence_output)
1643
+
1644
+ loss = None
1645
+ if labels is not None:
1646
+ # move labels to correct device to enable model parallelism
1647
+ labels = labels.to(logits.device)
1648
+ if self.config.problem_type is None:
1649
+ if self.num_labels == 1:
1650
+ self.config.problem_type = "regression"
1651
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1652
+ self.config.problem_type = "single_label_classification"
1653
+ else:
1654
+ self.config.problem_type = "multi_label_classification"
1655
+
1656
+ if self.config.problem_type == "regression":
1657
+ loss_fct = MSELoss()
1658
+ if self.num_labels == 1:
1659
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1660
+ else:
1661
+ loss = loss_fct(logits, labels)
1662
+ elif self.config.problem_type == "single_label_classification":
1663
+ loss_fct = CrossEntropyLoss()
1664
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1665
+ elif self.config.problem_type == "multi_label_classification":
1666
+ loss_fct = BCEWithLogitsLoss()
1667
+ loss = loss_fct(logits, labels)
1668
+
1669
+ if not return_dict:
1670
+ output = (logits,) + outputs[2:]
1671
+ return ((loss,) + output) if loss is not None else output
1672
+
1673
+ return ImageClassifierOutput(
1674
+ loss=loss,
1675
+ logits=logits,
1676
+ hidden_states=outputs.hidden_states,
1677
+ attentions=outputs.attentions,
1678
+ )
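
The `set_processor`, `encode_image`, `encode_text`, `encode_multimodal`, `data_process`, and `encode` methods above are the additions this checkpoint makes on top of the stock Hugging Face `CLIPModel`; they return L2-normalized embeddings for text, images, or fused image + text queries. A minimal usage sketch under stated assumptions: the module name, the repository id, and the image paths below are placeholders, and it assumes the modified `CLIPModel` defined in this file is the class that gets loaded.

```python
import torch

# Assumption: this modified CLIPModel (with encode()/set_processor()) is importable,
# e.g. because this file has been saved locally as modeling_clip.py next to the checkpoint.
from modeling_clip import CLIPModel

MODEL_NAME = "JUNJIE99/MMRet-base"  # placeholder checkpoint id; a local folder path also works

model = CLIPModel.from_pretrained(MODEL_NAME)
model.set_processor(MODEL_NAME)  # attaches the CLIPProcessor used by data_process()
model.eval()

with torch.no_grad():
    text_emb = model.encode(text=["a photo of a cat"])                       # text-only
    image_emb = model.encode(images=["cat.jpg"])                             # image-only (path is a placeholder)
    fused_emb = model.encode(images=["cat.jpg"], text=["wearing a hat"])     # composed image + text query

# All outputs are L2-normalized, so a dot product is a cosine similarity.
print(text_emb.shape, fused_emb @ image_emb.T)
```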
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "crop_size": 224,
3
+ "do_center_crop": true,
4
+ "do_normalize": true,
5
+ "do_resize": true,
6
+ "feature_extractor_type": "CLIPFeatureExtractor",
7
+ "image_mean": [
8
+ 0.48145466,
9
+ 0.4578275,
10
+ 0.40821073
11
+ ],
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "resample": 3,
18
+ "size": 224
19
+ }
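
The `preprocessor_config.json` above is the standard CLIP image pipeline: resize to 224 px with bicubic resampling (`resample: 3`), center-crop to 224×224, and normalize with the OpenAI CLIP mean/std. A small loading sketch, assuming the placeholder repository id below (a local path to this folder works the same way):

```python
from PIL import Image
from transformers import CLIPImageProcessor

# Placeholder id; point this at the folder containing preprocessor_config.json if working locally.
processor = CLIPImageProcessor.from_pretrained("JUNJIE99/MMRet-base")

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # expected: torch.Size([1, 3, 224, 224])
```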
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "unk_token": {
3
+ "content": "<|endoftext|>",
4
+ "single_word": false,
5
+ "lstrip": false,
6
+ "rstrip": false,
7
+ "normalized": true,
8
+ "__type": "AddedToken"
9
+ },
10
+ "bos_token": {
11
+ "content": "<|startoftext|>",
12
+ "single_word": false,
13
+ "lstrip": false,
14
+ "rstrip": false,
15
+ "normalized": true,
16
+ "__type": "AddedToken"
17
+ },
18
+ "eos_token": {
19
+ "content": "<|endoftext|>",
20
+ "single_word": false,
21
+ "lstrip": false,
22
+ "rstrip": false,
23
+ "normalized": true,
24
+ "__type": "AddedToken"
25
+ },
26
+ "pad_token": "<|endoftext|>",
27
+ "errors": "replace",
28
+ "add_prefix_space": false,
29
+ "do_lower_case": true,
30
+ "name_or_path": "openai/clip-vit-base-patch16",
31
+ "model_max_length": 77,
32
+ "special_tokens_map_file": "./special_tokens_map.json",
33
+ "tokenizer_class": "CLIPTokenizer"
34
+ }
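
The tokenizer config pins the standard CLIP BPE tokenizer: lower-cased input, a 77-token context window (`model_max_length`), `<|startoftext|>`/`<|endoftext|>` as BOS/EOS, and the EOS token reused for UNK and padding. A quick check, again with a placeholder repository id:

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("JUNJIE99/MMRet-base")  # placeholder id; a local folder also works

enc = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
print(tokenizer.model_max_length)  # 77
print(tokenizer.convert_ids_to_tokens(enc["input_ids"][0]))
# roughly: ['<|startoftext|>', 'a</w>', 'photo</w>', 'of</w>', 'a</w>', 'cat</w>', '<|endoftext|>']
```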
vocab.json ADDED
The diff for this file is too large to render. See raw diff