import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    SiglipVisionConfig,
    SiglipVisionModel,
    XLMRobertaConfig,
    XLMRobertaModel,
)


class MexmaSigLIPConfig(PretrainedConfig):
    def __init__(
        self,
        optimized: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # When True, use faster attention kernels (SDPA for the text tower,
        # FlashAttention-2 for the vision tower).
        self.optimized = optimized


class MexmaSigLIP(PreTrainedModel):
    config_class = MexmaSigLIPConfig

    def __init__(self, config: MexmaSigLIPConfig):
        super().__init__(config)
        self.config = config
        text_config = XLMRobertaConfig.from_pretrained("facebook/MEXMA")
        if self.config.optimized:
            text_config._attn_implementation = "sdpa"
        self.text_model = XLMRobertaModel(text_config, add_pooling_layer=False)
        # Project 1024-dim MEXMA text embeddings into the 1152-dim SigLIP space.
        self.text_projector = torch.nn.Linear(1024, 1152, bias=False)
        vision_config = SiglipVisionConfig.from_pretrained(
            "google/siglip-so400m-patch14-384"
        )
        if self.config.optimized:
            vision_config._attn_implementation = "flash_attention_2"
        self.vision_model = SiglipVisionModel(vision_config).vision_model
        # SigLIP-style learnable temperature and bias applied to the pairwise logits.
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_bias = torch.nn.Parameter(torch.ones([]) * -10)

    def forward(self, image_inputs, input_ids, attention_mask, normalize=False):
        text_features = self.encode_texts(input_ids, attention_mask, normalize)
        image_features = self.encode_images(image_inputs, normalize)
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    def encode_images(
        self,
        pixel_values,
        normalize=False,
    ):
        # Use SigLIP's attention-pooled image representation.
        features = self.vision_model(pixel_values).pooler_output
        return F.normalize(features, dim=-1) if normalize else features

    def encode_texts(
        self,
        input_ids,
        attention_mask,
        normalize=False,
    ):
        # MEXMA sentence embedding: the <s> (CLS) token of the last hidden state.
        features = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        features = self.text_projector(features)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        input_ids,
        attention_mask,
        pixel_values,
    ):
        image_features = self.encode_images(pixel_values, normalize=True)
        text_features = self.encode_texts(input_ids, attention_mask, normalize=True)
        # Scaled cosine similarities plus bias, as in the SigLIP sigmoid loss.
        image_logits = (
            self.logit_scale.exp() * image_features @ text_features.T + self.logit_bias
        )
        text_logits = image_logits.T
        return image_logits, text_logits
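

# Minimal usage sketch (not part of the original module). It assumes the text
# tower is paired with the "facebook/MEXMA" (XLM-R) tokenizer and the vision
# tower with the "google/siglip-so400m-patch14-384" image processor; both
# checkpoint choices mirror the configs loaded in __init__ above. Weights here
# are randomly initialized, so outputs are illustrative only.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer, SiglipImageProcessor

    model = MexmaSigLIP(MexmaSigLIPConfig()).eval()
    tokenizer = AutoTokenizer.from_pretrained("facebook/MEXMA")
    processor = SiglipImageProcessor.from_pretrained(
        "google/siglip-so400m-patch14-384"
    )

    texts = ["a photo of a cat", "a photo of a dog"]
    text_inputs = tokenizer(texts, padding=True, return_tensors="pt")
    image = Image.new("RGB", (384, 384))  # placeholder image for illustration
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        image_logits, text_logits = model.get_logits(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
            pixel_values=pixel_values,
        )
    # Sigmoid turns the pairwise logits into match probabilities (SigLIP-style).
    print(torch.sigmoid(image_logits))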