import os

import requests
import torch
from torchvision import transforms
from tqdm import tqdm

from .videomaev2_finetune import vit_giant_patch14_224


def to_normalized_float_tensor(vid):
    # (T, H, W, C) uint8 -> (C, T, H, W) float32 in [0, 1].
    return vid.permute(3, 0, 1, 2).to(torch.float32) / 255


# NOTE: these helpers would normally expect mini-batches; we keep the video as
# a single 4-D tensor (C, T, H, W) so interpolation treats it like a batch of
# images and the transform is applied in the spatial domain only.
def resize(vid, size, interpolation='bilinear'):
    # NOTE: bilinear interpolation is used because we are not working on
    # mini-batches at this level.
    scale = None
    if isinstance(size, int):
        # Scale the shorter spatial side to `size`, preserving aspect ratio.
        scale = float(size) / min(vid.shape[-2:])
        size = None
    return torch.nn.functional.interpolate(
        vid, size=size, scale_factor=scale, mode=interpolation,
        align_corners=False)


class ToFloatTensorInZeroOne(object):

    def __call__(self, vid):
        return to_normalized_float_tensor(vid)


class Resize(object):

    def __init__(self, size):
        self.size = size

    def __call__(self, vid):
        return resize(vid, self.size)


def preprocess_videomae(videos):
    """Convert a uint8 NumPy batch of clips (N, T, H, W, C) into a
    normalized (N, C, T, H, W) float tensor resized to 224x224."""
    transform = transforms.Compose(
        [ToFloatTensorInZeroOne(), Resize((224, 224))])
    return torch.stack([transform(f) for f in torch.from_numpy(videos)])


def load_videomae_model(device, ckpt_path=None):
    if ckpt_path is None:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        ckpt_path = os.path.join(current_dir,
                                 'vit_g_hybrid_pt_1200e_ssv2_ft.pth')

    if not os.path.exists(ckpt_path):
        # Download the checkpoint to ckpt_path, streaming with a progress bar.
        ckpt_url = 'https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth'
        response = requests.get(ckpt_url, stream=True, allow_redirects=True)
        response.raise_for_status()  # fail early rather than save an error page
        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024

        with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
            with open(ckpt_path, 'wb') as fw:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    fw.write(data)

    model = vit_giant_patch14_224(
        img_size=224,
        pretrained=False,
        num_classes=174,  # Something-Something V2 label set
        all_frames=16,
        tubelet_size=2,
        drop_path_rate=0.3,
        use_mean_pooling=True)

    ckpt = torch.load(ckpt_path, map_location='cpu')
    # Checkpoints may nest the state dict under a 'model' or 'module' key.
    for model_key in ['model', 'module']:
        if model_key in ckpt:
            ckpt = ckpt[model_key]
            break
    model.load_state_dict(ckpt)
    return model.to(device)
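
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module API). It assumes
# a uint8 NumPy array of shape (batch, 16, H, W, 3) -- 16 RGB frames per
# clip, matching all_frames=16 above -- and uses random frames as dummy
# input. Because this module uses a relative import, run it from the package
# root with `python -m <package>.<module>`. Note the first call downloads the
# large ViT-giant checkpoint.
if __name__ == '__main__':
    import numpy as np

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Dummy clip: one video of 16 random 256x256 RGB frames.
    videos = np.random.randint(0, 256, (1, 16, 256, 256, 3), dtype=np.uint8)

    model = load_videomae_model(device).eval()
    inputs = preprocess_videomae(videos).to(device)  # (1, 3, 16, 224, 224)
    with torch.no_grad():
        logits = model(inputs)  # (1, 174) Something-Something V2 logits
    print(logits.shape)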