Safetensors
custom_code
RADIO-g / radio_model.py
gheinrich's picture
Upload model (#1)
b240ec6 verified
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
import torch
from torch import nn
from timm.models import create_model, VisionTransformer
from .enable_cpe_support import enable_cpe
from .input_conditioner import InputConditioner
from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
from . import eradio_model
from .enable_spectral_reparam import configure_spectral_reparam_from_args
from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
class Resolution(NamedTuple):
    """Immutable ``(height, width)`` pair describing an image resolution."""

    # First tuple element: extent along the height axis.
    height: int
    # Second tuple element: extent along the width axis.
    width: int
class RADIOModel(nn.Module):
    """Wrap a backbone with input conditioning, feature normalization and optional adaptors.

    The wrapped ``model`` must provide ``forward_features`` (used by
    :meth:`forward`) and, when :meth:`forward_intermediates` is called, a
    ``forward_intermediates`` method accepting the keyword arguments passed
    there.
    """

    def __init__(
        self,
        model: nn.Module,
        input_conditioner: InputConditioner,
        patch_size: int,
        max_resolution: int,
        preferred_resolution: Resolution,
        summary_idxs: Optional[torch.Tensor] = None,
        window_size: Optional[int] = None,
        adaptors: Optional[Dict[str, AdaptorBase]] = None,
        feature_normalizer: Optional[FeatureNormalizer] = None,
        inter_feature_normalizer: Optional[IntermediateFeatureNormalizer] = None,
    ):
        """
        Args:
            model: Backbone network.
            input_conditioner: Module applied to the input before the backbone.
            patch_size: Patch size; when not ``None`` it takes precedence over
                whatever the backbone reports (see :attr:`patch_size`).
            max_resolution: Value exposed through :attr:`max_resolution`.
            preferred_resolution: Value exposed through :attr:`preferred_resolution`.
            summary_idxs: Optional index tensor used to select which summary
                tokens form the backbone summary. Registered as a buffer when given.
            window_size: Optional factor that multiplies the patch size to
                obtain :attr:`min_resolution_step`.
            adaptors: Optional mapping from name to adaptor module. When
                non-empty, the forward output becomes a dict (see :meth:`forward`).
            feature_normalizer: Module applied to the spatial features.
                Defaults to ``nn.Identity()``.
            inter_feature_normalizer: Passed through to the backbone's
                ``forward_intermediates``.
        """
        super().__init__()

        self.model = model
        self.input_conditioner = input_conditioner
        if summary_idxs is not None:
            self.register_buffer('summary_idxs', summary_idxs)
        else:
            self.summary_idxs = None

        self._preferred_resolution = preferred_resolution
        self._patch_size = patch_size
        self._max_resolution = max_resolution
        self._window_size = window_size

        adaptors = adaptors or dict()
        self.adaptors = nn.ModuleDict(adaptors)

        if feature_normalizer is None:
            feature_normalizer = nn.Identity()
        self.feature_normalizer = feature_normalizer
        self.inter_feature_normalizer = inter_feature_normalizer

    @property
    def num_summary_tokens(self) -> int:
        """Number of leading tokens to skip to reach the spatial features."""
        if hasattr(self.model, 'num_summary_tokens'):
            return self.model.num_summary_tokens

        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is not None:
            return patch_gen.num_skip
        elif self.model.global_pool == 'avg':
            return 0
        return 1

    @property
    def num_cls_tokens(self) -> int:
        """Number of leading tokens treated as summary (CLS) tokens."""
        if hasattr(self.model, 'num_cls_tokens'):
            return self.model.num_cls_tokens

        patch_gen = getattr(self.model, 'patch_generator', None)
        if patch_gen is not None:
            return patch_gen.num_cls_tokens
        elif self.model.global_pool == 'avg':
            return 0
        return 1

    @property
    def patch_size(self) -> Optional[int]:
        """Patch size: the explicit constructor value, else the backbone's, else ``None``."""
        if self._patch_size is not None:
            return self._patch_size
        if hasattr(self.model, "patch_size"):
            return self.model.patch_size
        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is not None:
            return patch_gen.patch_size
        return None

    @property
    def max_resolution(self) -> int:
        """The ``max_resolution`` value given at construction."""
        return self._max_resolution

    @property
    def preferred_resolution(self) -> Resolution:
        """The ``preferred_resolution`` value given at construction."""
        return self._preferred_resolution

    @property
    def window_size(self) -> Optional[int]:
        """The ``window_size`` value given at construction (may be ``None``)."""
        return self._window_size

    @property
    def min_resolution_step(self) -> Optional[int]:
        """Step that input height/width must be a multiple of.

        This is the patch size, multiplied by the window size when one is set.
        Returns ``None`` when the patch size is unknown; :meth:`forward` treats
        that as "no constraint".
        """
        res = self.patch_size
        if res is None:
            # Without a patch size there is nothing to scale; avoid `None * int`.
            return None
        if self.window_size is not None:
            res *= self.window_size
        return res

    @property
    def blocks(self) -> Optional[Iterable[nn.Module]]:
        """The backbone's ``blocks`` attribute, or ``None`` when it has none."""
        return getattr(self.model, 'blocks', None)

    @property
    def embed_dim(self) -> int:
        """The backbone's ``embed_dim``."""
        return self.model.embed_dim

    def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Detach the input conditioner from this model and return it.

        After this call the model applies ``nn.Identity()`` to its input, so the
        caller is responsible for applying the returned conditioner.
        """
        ret = self.input_conditioner
        self.input_conditioner = nn.Identity()
        return ret

    def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
        """Round ``height`` and ``width`` to the nearest multiple of :attr:`min_resolution_step`.

        Each dimension is at least one step. When the step is unknown
        (``None``) the requested size is returned unchanged.
        """
        step = self.min_resolution_step
        if step is None:
            # No step constraint is known, so every resolution is acceptable.
            return Resolution(height=height, width=width)

        height = int(round(height / step) * step)
        width = int(round(width / step) * step)

        # Rounding can yield 0 for small inputs; clamp to one step.
        height = max(height, step)
        width = max(width, step)

        return Resolution(height=height, width=width)

    def switch_to_deploy(self):
        """Call the backbone's ``switch_to_deploy`` hook if it defines one."""
        fn = getattr(self.model, 'switch_to_deploy', None)
        if fn is not None:
            fn()

    def forward(self, x: torch.Tensor, feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        '''
        Forward process for model.

        Args:
            x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`,
               otherwise `x` is expected to be mean centered with unit standard deviation.
            feature_fmt: ['NLC', 'NCHW'] - The output format for the features.

        Returns:
            A `RadioOutput` built from the flattened backbone summary and the formatted features.
            When adaptors are configured, a dict holding that output under `'backbone'`
            plus one entry per adaptor name.

        Raises:
            ValueError: If the last two dimensions of `x` are not multiples of
                `self.min_resolution_step`, or if `feature_fmt` is unsupported.
        '''
        res_step = self.min_resolution_step
        if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0):
            raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
                             '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
                             f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')

        x = self.input_conditioner(x)
        y = self.model.forward_features(x)
        ret = self._extract_final(x, y, feature_fmt=feature_fmt)
        return ret

    def _extract_final(self, x: torch.Tensor, y: torch.Tensor, feature_fmt: str = 'NLC'):
        """Split backbone output ``y`` into summary and spatial features and package them.

        Args:
            x: The (already conditioned) input; its last two dimensions are used
                for the ``'NCHW'`` reshape and it is handed to the adaptors.
            y: Raw backbone output. Its layout depends on the backbone type.
            feature_fmt: ``'NLC'`` or ``'NCHW'``.

        Raises:
            ValueError: If ``feature_fmt`` is neither ``'NLC'`` nor ``'NCHW'``.
        """
        if isinstance(self.model, VisionTransformer):
            patch_gen = getattr(self.model, "patch_generator", None)
            if patch_gen is not None:
                all_summary = y[:, : patch_gen.num_cls_tokens]
                if self.summary_idxs is not None:
                    bb_summary = all_summary[:, self.summary_idxs]
                else:
                    bb_summary = all_summary
                all_feat = y[:, patch_gen.num_skip :]
            elif self.model.global_pool == "avg":
                all_summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
                bb_summary = all_summary
                # NOTE(review): prefix tokens are skipped for the summary above but
                # kept in `all_feat` - confirm avg-pooled backbones have none.
                all_feat = y
            else:
                all_summary = y[:, 0]
                bb_summary = all_summary
                all_feat = y[:, 1:]
        elif isinstance(self.model, eradio_model.ERADIO):
            _, f = y
            all_feat = f.flatten(2).transpose(1, 2)
            all_summary = all_feat.mean(dim=1)
            bb_summary = all_summary
        elif isinstance(y, (list, tuple)):
            all_summary, all_feat = y
            bb_summary = all_summary
        else:
            all_summary = y[:, :self.num_cls_tokens]
            # Index selection only makes sense with more than one summary token.
            if self.summary_idxs is not None and all_summary.shape[1] > 1:
                bb_summary = all_summary[:, self.summary_idxs]
            else:
                bb_summary = all_summary
            all_feat = y[:, self.num_summary_tokens:]

        all_feat = self.feature_normalizer(all_feat)

        if feature_fmt == 'NCHW':
            patch_size = self.patch_size
            rows = x.shape[-2] // patch_size
            cols = x.shape[-1] // patch_size
            fmt_feat = (
                all_feat.reshape(all_feat.shape[0], rows, cols, all_feat.shape[2])
                        .permute(0, 3, 1, 2)
            )
        elif feature_fmt == 'NLC':
            fmt_feat = all_feat
        else:
            raise ValueError(f'Unsupported feature_fmt: {feature_fmt}. Must be one of ["NLC", "NCHW"]')

        ret = RadioOutput(bb_summary.flatten(1), fmt_feat)

        if self.adaptors:
            ret = dict(backbone=ret)
            for name, adaptor in self.adaptors.items():
                # With a token axis present, each adaptor reads its own summary token(s).
                if all_summary.ndim == 3:
                    summary = all_summary[:, adaptor.head_idx]
                else:
                    summary = all_summary
                ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size)
                v = adaptor(ada_input).to(torch.float32)
                ret[name] = v

        return ret

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
            return_prefix_tokens: bool = False,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
            aggregation: Optional[str] = "sparse",
            norm_alpha_scheme: Optional[str] = "post-alpha",
    ) -> List[RadioOutput]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, select matching indices if sequence
            return_prefix_tokens: Return both prefix and spatial intermediate tokens
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs. Options: NCHW, NLC
            intermediates_only: Only return intermediate features
            aggregation: intermediate layer aggregation method (sparse or dense).
                Dense accumulation is done by averaging the features in each group.
            norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha"), or don't normalize ("none")
                Only affects dense aggregation
        Returns:
            When ``intermediates_only`` is set, the list of intermediates (each wrapped
            in a RadioOutput when ``return_prefix_tokens`` is set). Otherwise a
            ``(final, intermediates)`` tuple, where ``final`` is the backbone's final
            output processed the same way as in ``forward``.
        """
        x = self.input_conditioner(x)
        intermediates = self.model.forward_intermediates(
            x,
            indices=indices,
            return_prefix_tokens=return_prefix_tokens,
            norm=norm,
            stop_early=stop_early,
            output_fmt=output_fmt,
            intermediates_only=intermediates_only,
            aggregation=aggregation,
            inter_feature_normalizer=self.inter_feature_normalizer,
            norm_alpha_scheme=norm_alpha_scheme,
        )

        if not intermediates_only:
            # The backbone returns (final, intermediates) in this mode.
            final, intermediates = intermediates

        def prepare_summary(summ: Optional[torch.Tensor]):
            # Select the configured summary tokens (if any) and flatten them per sample.
            if summ is None:
                return summ
            if self.summary_idxs is not None and summ.shape[1] > 1:
                summ = summ[:, self.summary_idxs]
            return summ.flatten(1)

        if return_prefix_tokens:
            radio_outputs = [
                RadioOutput(prepare_summary(summary), features)
                for summary, features in intermediates
            ]
        else:
            radio_outputs = intermediates

        if intermediates_only:
            return radio_outputs

        final = self._extract_final(x, final, feature_fmt=output_fmt)
        return final, radio_outputs
def create_model_from_args(args) -> nn.Module:
    """Build the backbone described by ``args`` via ``create_model``.

    Reads from ``args``: ``in_chans``, ``input_size``, ``model_kwargs``,
    ``model``, ``pretrained``, ``num_classes``, ``drop``, ``drop_path``,
    ``drop_block``, ``gp``, ``bn_momentum``, ``bn_eps``, ``torchscript``,
    ``initial_checkpoint``, ``cls_token_per_teacher``, ``cpe_max_size`` and,
    when CPE is enabled, ``teachers``. The optional attributes ``model_norm``,
    ``register_multiple`` and ``cpe_num_registers`` are read with defaults.

    ``args.model_kwargs`` is not modified.

    Returns:
        The created model with its ``head`` replaced by ``nn.Identity()``, its
        ``norm`` replaced as well unless ``args.model_norm`` is truthy, and CPE
        enabled when ``args.cpe_max_size`` is not ``None``.
    """
    # Channel count: explicit value first, then the first entry of input_size, then 3.
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]
    else:
        in_chans = 3

    # Pop from a copy: mutating the caller's dict would make a second call with
    # the same `args` silently lose an explicitly requested `weight_init`.
    model_kwargs = dict(args.model_kwargs)

    # Skip weight initialization unless it's explicitly requested.
    weight_init = model_kwargs.pop("weight_init", "skip")

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        weight_init=weight_init,
        **model_kwargs,
    )

    if hasattr(model, 'norm') and not getattr(args, 'model_norm', False):
        model.norm = nn.Identity()

    model.head = nn.Identity()

    assert (
        not args.cls_token_per_teacher or args.cpe_max_size is not None
    ), "CPE must be enabled for multiple CLS tokens!"

    if args.cpe_max_size is not None:
        # One CLS token per distinct teacher name when requested, otherwise a single one.
        uq_teachers = set(t['name'] for t in args.teachers)
        enable_cpe(
            model,
            args.cpe_max_size,
            num_cls_tokens=len(uq_teachers) if args.cls_token_per_teacher else 1,
            register_multiple=getattr(args, 'register_multiple', None),
            num_registers=getattr(args, 'cpe_num_registers', None),
        )

    return model