# Safetensors
# custom_code
# RADIO-H / extra_models.py
# gheinrich's picture
# Upload model
# fec8d08 verified
# raw
# history blame
# 5.69 kB
from distutils.version import LooseVersion
from types import MethodType
from typing import List, Optional, Tuple, Union
import warnings
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.registry import register_model
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .forward_intermediates import forward_intermediates
from .input_conditioner import InputConditioner
# True when the installed torch.nn.functional exposes
# `scaled_dot_product_attention`. `_load_dino_v2` checks this flag before
# swapping the attention modules' forward for `dv2_sdpa`.
_has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')
class PaliGemmaWrapper(nn.Module):
    """Adapt a PaliGemma vision tower to a ``(summary, features)`` interface.

    The wrapped ``vis_model`` is called with ``return_dict=False`` and
    ``interpolate_pos_encoding=True``; the first element of its output is
    used as the per-token feature tensor.

    Args:
        vis_model: The vision model to wrap. It must expose
            ``embeddings.patch_size`` and ``encoder.layers``.
        embed_dim: Embedding width to report through :attr:`embed_dim`.
            When ``None``, the value is read from
            ``vis_model.embeddings.embed_dim`` instead.
    """

    def __init__(self, vis_model: nn.Module, embed_dim: int):
        super().__init__()
        self.vis_model = vis_model
        # `embed_dim` is a read-only property on this class, so assigning
        # `self.embed_dim` raises AttributeError. Keep the caller-supplied
        # value under a private name and let the property serve it.
        self._embed_dim = embed_dim

    @property
    def patch_size(self):
        """Patch size reported by the wrapped model's embedding layer."""
        return self.vis_model.embeddings.patch_size

    @property
    def blocks(self):
        """The wrapped model's encoder layers."""
        return self.vis_model.encoder.layers

    @property
    def embed_dim(self):
        """Embedding width: the constructor value if given, else the tower's."""
        if self._embed_dim is not None:
            return self._embed_dim
        return self.vis_model.embeddings.embed_dim

    def forward(self, x: torch.Tensor):
        """Run the vision tower and return ``(summary, features)``.

        ``features`` is the first output of the wrapped model cast to
        float32; ``summary`` is the mean of ``features`` over dimension 1.
        """
        outputs = self.vis_model(
            x,
            return_dict=False,
            interpolate_pos_encoding=True,
        )
        features = outputs[0].to(torch.float32)
        summary = features.mean(dim=1)
        return summary, features

    def forward_features(self, x: torch.Tensor):
        """Alias for calling the module; returns ``(summary, features)``."""
        return self(x)
def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16):
    """Load a PaliGemma checkpoint and return its vision model, wrapped.

    Args:
        repo: Identifier passed to
            ``PaliGemmaForConditionalGeneration.from_pretrained``.
        embed_dim: Forwarded to :class:`PaliGemmaWrapper`.
        dtype: When not ``None``, passed as ``torch_dtype`` and also used to
            derive the ``revision`` argument (the part of ``str(dtype)``
            after the last dot, e.g. ``"bfloat16"``). When ``None``, neither
            argument is passed.

    Returns:
        A :class:`PaliGemmaWrapper` around ``model.vision_tower.vision_model``.
    """
    # Imported lazily so that `transformers` is only needed when this
    # function is actually called.
    from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version

    installed_version = LooseVersion(tx_version)
    if installed_version > LooseVersion('4.44.2'):
        warnings.warn(f'Your transformers version "{tx_version}" is higher than 4.44.2, and for whatever reason, PaliGemma might be broken.')

    load_kwargs = {}
    if dtype is not None:
        load_kwargs['torch_dtype'] = dtype
        # str(torch.bfloat16) == 'torch.bfloat16' -> 'bfloat16'
        load_kwargs['revision'] = str(dtype).split('.')[-1]

    full_model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **load_kwargs)
    return PaliGemmaWrapper(full_model.vision_tower.vision_model, embed_dim)
@register_model
def paligemma_896_student(**kwargs):
    """Build the wrapped vision model of ``google/paligemma-3b-pt-896``.

    ``kwargs`` are accepted but not used. ``dtype=None`` is passed so the
    loader adds no ``torch_dtype`` or ``revision`` argument.
    """
    return _get_paligemma_model(
        'google/paligemma-3b-pt-896',
        embed_dim=1152,
        dtype=None,
    )
def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
x = F.scaled_dot_product_attention(
q, k, v,
is_causal=False,
dropout_p=self.attn_drop.p if self.training else 0.,
scale=self.scale,
)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):
    """Load a DINOv2 model from torch.hub and optionally patch its attention.

    Args:
        dino_v2_model: Entry-point name within ``facebookresearch/dinov2``.
        cache_dir: If truthy, set as the torch.hub directory before loading.
        pretrained: Forwarded to ``torch.hub.load``.
        **kwargs: Accepted but not forwarded to ``torch.hub.load``.

    Returns:
        The loaded module. When ``_has_torch_sdpa`` is true, every submodule
        whose qualified name ends in ``.attn`` has its ``forward`` replaced
        by ``dv2_sdpa``.
    """
    if cache_dir:
        torch.hub.set_dir(cache_dir)

    model: nn.Module = torch.hub.load(
        'facebookresearch/dinov2',
        dino_v2_model,
        pretrained=pretrained,
    )

    if not _has_torch_sdpa:
        return model

    attention_modules = [
        module for name, module in model.named_modules() if name.endswith('.attn')
    ]
    for module in attention_modules:
        # Bind per instance so `self` inside dv2_sdpa is this module.
        module.forward = MethodType(dv2_sdpa, module)
    return model
class DinoWrapper(nn.Module):
    """Wrap a DINOv2-style model behind a ``(cls_token, features)`` interface.

    On construction the wrapped model's ``blocks`` attribute is replaced
    in place by an ``nn.Sequential`` over the same blocks, so they can be
    run with a single call.
    """

    def __init__(self, dino_model: nn.Module):
        super().__init__()
        # Note: this mutates the model that was passed in.
        dino_model.blocks = nn.Sequential(*dino_model.blocks)
        self.inner = dino_model

    @property
    def embed_dim(self):
        """Embedding width reported by the wrapped model."""
        return self.inner.embed_dim

    @property
    def patch_size(self):
        """Patch size reported by the wrapped model."""
        return self.inner.patch_size

    @property
    def num_cls_tokens(self):
        """The wrapped model's ``num_tokens``, or 1 if it has none."""
        return getattr(self.inner, 'num_tokens', 1)

    @property
    def num_registers(self):
        """The wrapped model's ``num_register_tokens``, or 0 if it has none."""
        return getattr(self.inner, 'num_register_tokens', 0)

    @property
    def num_summary_tokens(self):
        """Count of leading non-patch tokens: class tokens plus registers."""
        return self.num_cls_tokens + self.num_registers

    @property
    def blocks(self):
        """The wrapped model's blocks (an ``nn.Sequential`` after wrapping)."""
        return self.inner.blocks

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Call the wrapped model's ``forward_features`` and unpack its dict.

        Returns:
            ``(x_norm_clstoken, x_norm_patchtokens)`` taken from that dict.
        """
        outputs = self.inner.forward_features(*args, **kwargs)
        return outputs['x_norm_clstoken'], outputs['x_norm_patchtokens']

    def forward_features(self, x: torch.Tensor):
        """Tokenize, run all blocks, normalize, and split summary from patches.

        Returns:
            The token at index 0, and all tokens after the first
            ``num_summary_tokens`` positions.
        """
        tokens = self.inner.prepare_tokens_with_masks(x)
        normed = self.inner.norm(self.inner.blocks(tokens))
        first_patch = self.num_summary_tokens
        return normed[:, 0], normed[:, first_patch:]

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's ``prepare_tokens_with_masks(x)``."""
        return self.inner.prepare_tokens_with_masks(x)

    def forward_intermediates(self,
                              x: torch.Tensor,
                              norm: bool = False,
                              **kwargs,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Delegate to the module-level ``forward_intermediates`` helper.

        Args:
            x: Input tensor handed to the helper.
            norm: If true, the wrapped model's ``norm`` is supplied as the
                normalization callable; otherwise an identity function is.
            **kwargs: Forwarded unchanged to the helper.
        """
        norm_fn = self.inner.norm if norm else (lambda y: y)
        return forward_intermediates(
            self,
            patch_extractor=self.inner.prepare_tokens_with_masks,
            num_summary_tokens=self.num_summary_tokens,
            num_cls_tokens=self.num_cls_tokens,
            norm=norm_fn,
            x=x,
            **kwargs,
        )
def _dino_student(arch: str, **kwargs):
    """Build a :class:`DinoWrapper` around a model from ``dinov2_arch``.

    Args:
        arch: Name of a factory attribute on the local ``dinov2_arch``
            module; it is called with no arguments.
        **kwargs: Accepted but not used.

    Returns:
        The wrapper, with an ``input_conditioner`` attribute attached that
        uses the ImageNet default mean and std and an input scale of 1.0.
    """
    from . import dinov2_arch

    backbone = getattr(dinov2_arch, arch)()
    student = DinoWrapper(backbone)
    student.input_conditioner = InputConditioner(
        input_scale=1.0,
        norm_mean=IMAGENET_DEFAULT_MEAN,
        norm_std=IMAGENET_DEFAULT_STD,
    )
    return student
@register_model
def dino_v2_l_student(**kwargs):
    """Registered factory for the ``dinov2_vitl14_reg`` student model."""
    arch_name = 'dinov2_vitl14_reg'
    return _dino_student(arch_name, **kwargs)
@register_model
def dino_v2_g_student(**kwargs):
    """Registered factory for the ``dinov2_vitg14_reg`` student model."""
    arch_name = 'dinov2_vitg14_reg'
    return _dino_student(arch_name, **kwargs)