mexma-siglip / mexma_siglip.py
visheratin's picture
Upload folder using huggingface_hub
48ecfae verified
raw
history blame
2.87 kB
import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
PretrainedConfig,
PreTrainedModel,
SiglipVisionConfig,
SiglipVisionModel,
XLMRobertaConfig,
XLMRobertaModel,
)
class MexmaSigLIPConfig(PretrainedConfig):
    """Configuration for the MexmaSigLIP dual-encoder model.

    Attributes:
        optimized: when True, the model is constructed with faster attention
            backends (SDPA for the text tower, FlashAttention-2 for the
            vision tower — see ``MexmaSigLIP.__init__``).
    """

    def __init__(self, optimized: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.optimized = optimized
class MexmaSigLIP(PreTrainedModel):
    """CLIP-style dual encoder pairing a MEXMA (XLM-RoBERTa) text tower with
    a SigLIP vision tower.

    The text [CLS] embedding (1024-d) is linearly projected to SigLIP's
    embedding width (1152-d) so both modalities live in one space.
    ``logit_scale`` and ``logit_bias`` are the learned temperature and bias
    used by the SigLIP pairwise-sigmoid objective.
    """

    config_class = MexmaSigLIPConfig

    def __init__(self, config: MexmaSigLIPConfig):
        super().__init__(config)
        self.config = config
        # NOTE: from_pretrained here fetches configs from the Hub at
        # construction time, so building this model requires network/cache.
        text_config = XLMRobertaConfig.from_pretrained("facebook/MEXMA")
        if self.config.optimized:
            text_config._attn_implementation = "sdpa"
        # No pooling layer: encode_texts takes the [CLS] token directly.
        self.text_model = XLMRobertaModel(text_config, add_pooling_layer=False)
        # Project MEXMA's 1024-d hidden size onto SigLIP's 1152-d space.
        self.text_projector = torch.nn.Linear(1024, 1152, bias=False)
        vision_config = SiglipVisionConfig.from_pretrained(  # fixed typo: was `vision_congig`
            "google/siglip-so400m-patch14-384"
        )
        if self.config.optimized:
            vision_config._attn_implementation = "flash_attention_2"
        self.vision_model = SiglipVisionModel(vision_config).vision_model
        # CLIP-style temperature init (log of 1/0.07); SigLIP bias init of -10.
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_bias = torch.nn.Parameter(torch.ones([]) * -10)

    def forward(self, image_inputs, input_ids, attention_mask, normalize=False):
        """Encode both modalities in one call.

        Returns a dict with ``image_features``, ``text_features`` and the
        learned ``logit_scale`` / ``logit_bias`` parameters, as expected by a
        SigLIP-style training loss.
        """
        text_features = self.encode_texts(input_ids, attention_mask, normalize)
        image_features = self.encode_images(image_inputs, normalize)
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    def encode_images(
        self,
        pixel_values,
        normalize=False,
    ):
        """Return pooled SigLIP image embeddings; L2-normalized if requested."""
        features = self.vision_model(pixel_values).pooler_output
        return F.normalize(features, dim=-1) if normalize else features

    def encode_texts(
        self,
        input_ids,
        attention_mask,
        normalize=False,
    ):
        """Return projected [CLS] text embeddings; L2-normalized if requested."""
        # [:, 0] selects the [CLS] position of the last hidden state.
        features = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        features = self.text_projector(features)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        input_ids,
        attention_mask,
        pixel_values,
    ):
        """Return ``(image_logits, text_logits)``.

        Both encoders run with normalization on, so the matmul yields cosine
        similarities, scaled by ``exp(logit_scale)`` and shifted by
        ``logit_bias``; ``text_logits`` is simply the transpose.
        """
        image_features = self.encode_images(pixel_values, normalize=True)
        text_features = self.encode_texts(input_ids, attention_mask, normalize=True)
        image_logits = (
            self.logit_scale.exp() * image_features @ text_features.T + self.logit_bias
        )
        text_logits = image_logits.T
        return image_logits, text_logits