diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d51fd670841e11be55bb304e3f1ee04edb458907
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Pietro Mazzaglia
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 0bdb6af55988f2104f7d2740cce3bced17c9103e..5cc933730f4af5eca5454fbb9488831a936f2e84 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,177 @@
----
-title: Genrl
-emoji: 💻
-colorFrom: blue
-colorTo: pink
-sdk: gradio
-sdk_version: 4.37.1
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# GenRL: Multimodal foundation world models for generalist embodied agents
+
+
+
+
+
+
+ Website   | Models 🤗   | Datasets 🤗   | Gradio demo   | Notebooks  
+
+
+## Get started
+
+### Creating the environment
+
+We recommend using `conda` to create the environment
+
+```
+conda create --name genrl python=3.10
+
+conda activate genrl
+
+pip install -r requirements.txt
+```
+
+### Downloading InternVideo2
+
+Download InternVideo 2 [[here]](https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4/blob/main/InternVideo2-stage2_1b-224p-f4.pt).
+
+Place in the `models` folder.
+
+Note: the file access is restricted, so you'll need an HuggingFace account to request access to the file.
+
+Note: By default, the code expects the model to be placed in the `models` folder. The variable `MODELS_ROOT_PATH` indicating where the model should be place is set in `tools/genrl_utils.py`.
+
+## Data
+
+### Download datasets
+
+The datasets used to pre-trained the models can be downloaded [[here]](https://huggingface.co/datasets/mazpie/genrl_datasets).
+
+The file are `tar.gz` and can be extracted using the `tar` utility on Linux. For example:
+
+```
+tar -zxvf walker_data.tar.gz
+```
+
+### Collecting and pre-processing data
+
+If you don't want to download our datasets, you collect and pre-process the data on your own.
+
+Data can be collected running a DreamerV3 agent on a task, by running:
+
+```
+python3 collect_data.py agent=dreamer task=stickman_walk
+```
+
+or the Plan2Explore agent, by running:
+
+```
+python3 collect_data.py agent=plan2explore conf/defaults=dreamer_v2 task=stickman_walk
+```
+
+A repo for the experiment will be created under the directory `exp_local`, such as: `exp_local/YYYY.MM.DD/HHMMSS_agentname`. The data can then be found in the `buffer` subdirectory.
+
+
+After obtaining the data, it should be processed to obtain the video embeddings for each frame sequence in the episodes. The processing can be done by running:
+
+```
+python3 process_dataset.py dataset_dir=data/stickman_example
+```
+
+where `data/stickman_example` is replaced by the folder of the data you want to process.
+
+## Agents
+
+### Downloading pre-trained models
+
+If you want to test our work, without having to pre-train the models, you can do this by using our pre-trained models.
+
+Pretrained models can be found [[here]](https://huggingface.co/mazpie/genrl_models)
+
+Here's a snippet to download them easily:
+
+```
+import os
+from huggingface_hub import hf_hub_download
+
+def download_model(model_folder, model_filename):
+ REPO_ID = 'mazpie/genrl_models'
+ filename_list = [model_filename]
+ if not os.path.exists(model_folder):
+ os.makedirs(model_folder)
+ for filename in filename_list:
+ local_file = os.path.join(model_folder, filename)
+ if not os.path.exists(local_file):
+ hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=model_folder, local_dir_use_symlinks=False)
+
+download_model('models', 'genrl_stickman_500k_2.pt')
+```
+
+Pre-trained models can be used by setting `snapshot_load_dir=...` when running `train.py`.
+
+Note: the pre-trained models are not trained to solve any tasks. They only contain a pre-trained multimodal foundation world model (world model + connector and aligner).
+
+### Training multimodal foundation world models
+
+In order to train a multimodal foundation world model from data, you should run something like:
+
+```
+# Note: frames = update steps
+
+
+python3 train.py task=stickman_walk replay_load_dir=data/stickman_example num_train_frames=500_010 visual_every_frames=25_000 train_world_model=True train_connector=True reset_world_model=True reset_connector=True
+```
+
+### Behavior learning
+
+After pre-training a model, you can train the behavior for a task using:
+
+```
+python3 train.py task=stickman_walk snapshot_load_dir=models/genrl_stickman_500k_2.pt num_train_frames=50_010 batch_size=32 batch_length=32 agent.imag_reward_fn=video_text_reward eval_modality=task_imag
+```
+
+Data-free RL can be performed by additionaly passing the option:
+
+`train_from_data=False`
+
+The prompts for each task can be found and edited in `tools/genrl_utils.py`. However, you can also pass a custom prompt for a task by passing the option:
+
+`+agent.imag_reward_args.task_prompt=custom_prompt`
+
+## Other utilities
+
+### Gradio demo
+
+There's a gradio demo that can be found at `demo/app.py`.
+
+If launching demo like a standard Python program with:
+
+```
+python3 demo/app.py
+```
+
+it will return a local endpoint (e.g. http://127.0.0.1:7860) where to access a dashboard to play with GenRL.
+
+
+
+
+
+### Notebooks
+
+You can find several notebooks to test our code in the `notebooks` directory.
+
+`demo_videoclip` : can be used to test the correct functioning of the InternVideo2 component
+
+`text2video` : utility to generate video reconstructions from text prompts
+
+`video2video` : utility to generate video reconstructions from video prompts
+
+`visualize_dataset_episodes` : utility to generate videos from the episodes in a given dataset
+
+`visualize_env` : used to play with the environment and, for instance, understand how the reward function of each task works
+
+### Stickman environment
+
+We introduced the Stickman environment as a simplified 2D version of the Humanoid environment.
+
+This can be found in the `envs/custom_dmc_tasks` folder. You will find an `.xml` model and a `.py` files containing the tasks.
+
+## Acknowledgments
+
+We would like to thank the authors of the following repositories for their useful code and models:
+
+* [InternVideo2](https://github.com/OpenGVLab/InternVideo)
+* [Franka Kitchen](https://github.com/google-research/relay-policy-learning)
+* [DreamerV3](https://github.com/danijar/dreamerv3)
+* [DreamerV3-torch](https://github.com/NM512/dreamerv3-torch)
\ No newline at end of file
diff --git a/agent/dreamer.py b/agent/dreamer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb171b57b5b9c180b7fea7c6f5fd53919162ffd
--- /dev/null
+++ b/agent/dreamer.py
@@ -0,0 +1,462 @@
+import torch.nn as nn
+import torch
+
+import tools.utils as utils
+import agent.dreamer_utils as common
+from collections import OrderedDict
+import numpy as np
+
+from tools.genrl_utils import *
+
+def stop_gradient(x):
+ return x.detach()
+
+Module = nn.Module
+
+def env_reward(agent, seq):
+ return agent.wm.heads['reward'](seq['feat']).mean
+
+class DreamerAgent(Module):
+
+ def __init__(self,
+ name, cfg, obs_space, act_spec, **kwargs):
+ super().__init__()
+ self.name = name
+ self.cfg = cfg
+ self.cfg.update(**kwargs)
+ self.obs_space = obs_space
+ self.act_spec = act_spec
+ self._use_amp = (cfg.precision == 16)
+ self.device = cfg.device
+ self.act_dim = act_spec.shape[0]
+ self.wm = WorldModel(cfg, obs_space, self.act_dim,)
+ self.instantiate_acting_behavior()
+
+ self.to(cfg.device)
+ self.requires_grad_(requires_grad=False)
+
+ def instantiate_acting_behavior(self,):
+ self._acting_behavior = ActorCritic(self.cfg, self.act_spec, self.wm.inp_size).to(self.device)
+
+ def act(self, obs, meta, step, eval_mode, state):
+ if self.cfg.only_random_actions:
+ return np.random.uniform(-1, 1, self.act_dim,).astype(self.act_spec.dtype), (None, None)
+ obs = {k : torch.as_tensor(np.copy(v), device=self.device).unsqueeze(0) for k, v in obs.items()}
+ if state is None:
+ latent = self.wm.rssm.initial(len(obs['reward']))
+ action = torch.zeros((len(obs['reward']),) + self.act_spec.shape, device=self.device)
+ else:
+ latent, action = state
+ embed = self.wm.encoder(self.wm.preprocess(obs))
+ should_sample = (not eval_mode) or (not self.cfg.eval_state_mean)
+ latent, _ = self.wm.rssm.obs_step(latent, action, embed, obs['is_first'], should_sample)
+ feat = self.wm.rssm.get_feat(latent)
+ if eval_mode:
+ actor = self._acting_behavior.actor(feat)
+ try:
+ action = actor.mean
+ except:
+ action = actor._mean
+ else:
+ actor = self._acting_behavior.actor(feat)
+ action = actor.sample()
+ new_state = (latent, action)
+ return action.cpu().numpy()[0], new_state
+
+ def update_wm(self, data, step):
+ metrics = {}
+ state, outputs, mets = self.wm.update(data, state=None)
+ outputs['is_terminal'] = data['is_terminal']
+ metrics.update(mets)
+ return state, outputs, metrics
+
+ def update_acting_behavior(self, state=None, outputs=None, metrics={}, data=None, reward_fn=None):
+ if self.cfg.only_random_actions:
+ return {}, metrics
+ if outputs is not None:
+ post = outputs['post']
+ is_terminal = outputs['is_terminal']
+ else:
+ data = self.wm.preprocess(data)
+ embed = self.wm.encoder(data)
+ post, _ = self.wm.rssm.observe(
+ embed, data['action'], data['is_first'])
+ is_terminal = data['is_terminal']
+ #
+ start = {k: stop_gradient(v) for k,v in post.items()}
+ if reward_fn is None:
+ acting_reward_fn = lambda seq: globals()[self.cfg.acting_reward_fn](self, seq) #.mode()
+ else:
+ acting_reward_fn = lambda seq: reward_fn(self, seq) #.mode()
+ metrics.update(self._acting_behavior.update(self.wm, start, is_terminal, acting_reward_fn))
+ return start, metrics
+
+ def update(self, data, step):
+ state, outputs, metrics = self.update_wm(data, step)
+ start, metrics = self.update_acting_behavior(state, outputs, metrics, data)
+ return state, metrics
+
+ def report(self, data):
+ report = {}
+ data = self.wm.preprocess(data)
+ for key in self.wm.heads['decoder'].cnn_keys:
+ name = key.replace('/', '_')
+ report[f'openl_{name}'] = self.wm.video_pred(data, key)
+ for fn in getattr(self.cfg, 'additional_report_fns', []):
+ call_fn = globals()[fn]
+ additional_report = call_fn(self, data)
+ report.update(additional_report)
+ return report
+
+ def get_meta_specs(self):
+ return tuple()
+
+ def init_meta(self):
+ return OrderedDict()
+
+ def update_meta(self, meta, global_step, time_step, finetune=False):
+ return meta
+
+class WorldModel(Module):
+ def __init__(self, config, obs_space, act_dim,):
+ super().__init__()
+ shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
+ self.shapes = shapes
+ self.cfg = config
+ self.device = config.device
+ self.encoder = common.Encoder(shapes, **config.encoder)
+ # Computing embed dim
+ with torch.no_grad():
+ zeros = {k: torch.zeros( (1,) + v) for k, v in shapes.items()}
+ outs = self.encoder(zeros)
+ embed_dim = outs.shape[1]
+ self.embed_dim = embed_dim
+ self.rssm = common.EnsembleRSSM(**config.rssm, action_dim=act_dim, embed_dim=embed_dim, device=self.device,)
+ self.heads = {}
+ self._use_amp = (config.precision == 16)
+ self.inp_size = self.rssm.get_feat_size()
+ self.decoder_input_fn = getattr(self.rssm, f'get_{config.decoder_inputs}')
+ self.decoder_input_size = getattr(self.rssm, f'get_{config.decoder_inputs}_size')()
+ self.heads['decoder'] = common.Decoder(shapes, **config.decoder, embed_dim=self.decoder_input_size, image_dist=config.image_dist)
+ self.heads['reward'] = common.MLP(self.inp_size, (1,), **config.reward_head)
+ # zero init
+ with torch.no_grad():
+ for p in self.heads['reward']._out.parameters():
+ p.data = p.data * 0
+ #
+ if config.pred_discount:
+ self.heads['discount'] = common.MLP(self.inp_size, (1,), **config.discount_head)
+ for name in config.grad_heads:
+ assert name in self.heads, name
+ self.grad_heads = config.grad_heads
+ self.heads = nn.ModuleDict(self.heads)
+ self.model_opt = common.Optimizer('model', self.parameters(), **config.model_opt, use_amp=self._use_amp)
+ self.e2e_update_fns = {}
+ self.detached_update_fns = {}
+ self.eval()
+
+ def add_module_to_update(self, name, module, update_fn, detached=False):
+ self.add_module(name, module)
+ if detached:
+ self.detached_update_fns[name] = update_fn
+ else:
+ self.e2e_update_fns[name] = update_fn
+ self.model_opt = common.Optimizer('model', self.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)
+
+ def update(self, data, state=None):
+ self.train()
+ with common.RequiresGrad(self):
+ with torch.cuda.amp.autocast(enabled=self._use_amp):
+ if getattr(self.cfg, "freeze_decoder", False):
+ self.heads['decoder'].requires_grad_(False)
+ if getattr(self.cfg, "freeze_post", False) or getattr(self.cfg, "freeze_model", False):
+ self.heads['decoder'].requires_grad_(False)
+ self.encoder.requires_grad_(False)
+ # Updating only prior
+ self.grad_heads = []
+ self.rssm.requires_grad_(False)
+ if not getattr(self.cfg, "freeze_model", False):
+ self.rssm._ensemble_img_out.requires_grad_(True)
+ self.rssm._ensemble_img_dist.requires_grad_(True)
+ model_loss, state, outputs, metrics = self.loss(data, state)
+ model_loss, metrics = self.update_additional_e2e_modules(data, outputs, model_loss, metrics)
+ metrics.update(self.model_opt(model_loss, self.parameters()))
+ if len(self.detached_update_fns) > 0:
+ detached_loss, metrics = self.update_additional_detached_modules(data, outputs, metrics)
+ self.eval()
+ return state, outputs, metrics
+
+ def update_additional_detached_modules(self, data, outputs, metrics):
+ # additional detached losses
+ detached_loss = 0
+ for k in self.detached_update_fns:
+ detached_module = getattr(self, k)
+ with common.RequiresGrad(detached_module):
+ with torch.cuda.amp.autocast(enabled=self._use_amp):
+ add_loss, add_metrics = self.detached_update_fns[k](self, k, data, outputs, metrics)
+ metrics.update(add_metrics)
+ opt_metrics = self.model_opt(add_loss, detached_module.parameters())
+ metrics.update({ f'{k}_{m}' : opt_metrics[m] for m in opt_metrics})
+ return detached_loss, metrics
+
+ def update_additional_e2e_modules(self, data, outputs, model_loss, metrics):
+ # additional e2e losses
+ for k in self.e2e_update_fns:
+ add_loss, add_metrics = self.e2e_update_fns[k](self, k, data, outputs, metrics)
+ model_loss += add_loss
+ metrics.update(add_metrics)
+ return model_loss, metrics
+
+ def observe_data(self, data, state=None):
+ data = self.preprocess(data)
+ embed = self.encoder(data)
+ post, prior = self.rssm.observe(
+ embed, data['action'], data['is_first'], state)
+ kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
+ outs = dict(embed=embed, post=post, prior=prior, is_terminal=data['is_terminal'])
+ return outs, { 'model_kl' : kl_value.mean() }
+
+ def loss(self, data, state=None):
+ data = self.preprocess(data)
+ embed = self.encoder(data)
+ post, prior = self.rssm.observe(
+ embed, data['action'], data['is_first'], state)
+ kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
+ assert len(kl_loss.shape) == 0 or (len(kl_loss.shape) == 1 and kl_loss.shape[0] == 1), kl_loss.shape
+ likes = {}
+ losses = {'kl': kl_loss}
+ feat = self.rssm.get_feat(post)
+ for name, head in self.heads.items():
+ grad_head = (name in self.grad_heads)
+ if name == 'decoder':
+ inp = self.decoder_input_fn(post)
+ else:
+ inp = feat
+ inp = inp if grad_head else stop_gradient(inp)
+ out = head(inp)
+ dists = out if isinstance(out, dict) else {name: out}
+ for key, dist in dists.items():
+ like = dist.log_prob(data[key])
+ likes[key] = like
+ losses[key] = -like.mean()
+ model_loss = sum(
+ self.cfg.loss_scales.get(k, 1.0) * v for k, v in losses.items())
+ outs = dict(
+ embed=embed, feat=feat, post=post,
+ prior=prior, likes=likes, kl=kl_value)
+ metrics = {f'{name}_loss': value for name, value in losses.items()}
+ metrics['model_kl'] = kl_value.mean()
+ metrics['prior_ent'] = self.rssm.get_dist(prior).entropy().mean()
+ metrics['post_ent'] = self.rssm.get_dist(post).entropy().mean()
+ last_state = {k: v[:, -1] for k, v in post.items()}
+ return model_loss, last_state, outs, metrics
+
+ def imagine(self, policy, start, is_terminal, horizon, task_cond=None, eval_policy=False):
+ flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
+ start = {k: flatten(v) for k, v in start.items()}
+ start['feat'] = self.rssm.get_feat(start)
+ inp = start['feat'] if task_cond is None else torch.cat([start['feat'], task_cond], dim=-1)
+ policy_dist = policy(inp)
+ start['action'] = torch.zeros_like(policy_dist.sample(), device=self.device) #.mode())
+ seq = {k: [v] for k, v in start.items()}
+ if task_cond is not None: seq['task'] = [task_cond]
+ for _ in range(horizon):
+ inp = seq['feat'][-1] if task_cond is None else torch.cat([seq['feat'][-1], task_cond], dim=-1)
+ policy_dist = policy(stop_gradient(inp))
+ action = policy_dist.sample() if not eval_policy else policy_dist.mean
+ state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action)
+ feat = self.rssm.get_feat(state)
+ for key, value in {**state, 'action': action, 'feat': feat}.items():
+ seq[key].append(value)
+ if task_cond is not None: seq['task'].append(task_cond)
+ # shape will be (T, B, *DIMS)
+ seq = {k: torch.stack(v, 0) for k, v in seq.items()}
+ if 'discount' in self.heads:
+ disc = self.heads['discount'](seq['feat']).mean()
+ if is_terminal is not None:
+ # Override discount prediction for the first step with the true
+ # discount factor from the replay buffer.
+ true_first = 1.0 - flatten(is_terminal)
+ disc = torch.cat([true_first[None], disc[1:]], 0)
+ else:
+ disc = torch.ones(list(seq['feat'].shape[:-1]) + [1], device=self.device)
+ seq['discount'] = disc * self.cfg.discount
+ # Shift discount factors because they imply whether the following state
+ # will be valid, not whether the current state is valid.
+ seq['weight'] = torch.cumprod(torch.cat([torch.ones_like(disc[:1], device=self.device), disc[:-1]], 0), 0)
+ return seq
+
+ def preprocess(self, obs):
+ obs = obs.copy()
+ for key, value in obs.items():
+ if key.startswith('log_'):
+ continue
+ if value.dtype in [np.uint8, torch.uint8]:
+ value = value / 255.0 - 0.5
+ obs[key] = value
+ obs['reward'] = {
+ 'identity': nn.Identity(),
+ 'sign': torch.sign,
+ 'tanh': torch.tanh,
+ }[self.cfg.clip_rewards](obs['reward'])
+ obs['discount'] = (1.0 - obs['is_terminal'].float())
+ if len(obs['discount'].shape) < len(obs['reward'].shape):
+ obs['discount'] = obs['discount'].unsqueeze(-1)
+ return obs
+
+ def video_pred(self, data, key, nvid=8):
+ decoder = self.heads['decoder'] # B, T, C, H, W
+ truth = data[key][:nvid] + 0.5
+ embed = self.encoder(data)
+ states, _ = self.rssm.observe(
+ embed[:nvid, :5], data['action'][:nvid, :5], data['is_first'][:nvid, :5])
+ recon = decoder(self.decoder_input_fn(states))[key].mean[:nvid] # mode
+ init = {k: v[:, -1] for k, v in states.items()}
+ prior = self.rssm.imagine(data['action'][:nvid, 5:], init)
+ prior_recon = decoder(self.decoder_input_fn(prior))[key].mean # mode
+ model = torch.clip(torch.cat([recon[:, :5] + 0.5, prior_recon + 0.5], 1), 0, 1)
+ error = (model - truth + 1) / 2
+ video = torch.cat([truth, model, error], 3)
+ B, T, C, H, W = video.shape
+ return video
+
+class ActorCritic(Module):
+ def __init__(self, config, act_spec, feat_size, name=''):
+ super().__init__()
+ self.name = name
+ self.cfg = config
+ self.act_spec = act_spec
+ self._use_amp = (config.precision == 16)
+ self.device = config.device
+
+ if getattr(self.cfg, 'discrete_actions', False):
+ self.cfg.actor.dist = 'onehot'
+
+ self.actor_grad = getattr(self.cfg, f'{self.name}_actor_grad'.strip('_'))
+
+ inp_size = feat_size
+ self.actor = common.MLP(inp_size, act_spec.shape[0], **self.cfg.actor)
+ self.critic = common.MLP(inp_size, (1,), **self.cfg.critic)
+ if self.cfg.slow_target:
+ self._target_critic = common.MLP(inp_size, (1,), **self.cfg.critic)
+ self._updates = 0 # tf.Variable(0, tf.int64)
+ else:
+ self._target_critic = self.critic
+ self.actor_opt = common.Optimizer('actor', self.actor.parameters(), **self.cfg.actor_opt, use_amp=self._use_amp)
+ self.critic_opt = common.Optimizer('critic', self.critic.parameters(), **self.cfg.critic_opt, use_amp=self._use_amp)
+
+ if self.cfg.reward_ema:
+ # register ema_vals to nn.Module for enabling torch.save and torch.load
+ self.register_buffer("ema_vals", torch.zeros((2,)).to(self.device))
+ self.reward_ema = common.RewardEMA(device=self.device)
+ self.rewnorm = common.StreamNorm(momentum=1, scale=1.0, device=self.device)
+ else:
+ self.rewnorm = common.StreamNorm(**self.cfg.reward_norm, device=self.device)
+
+ # zero init
+ with torch.no_grad():
+ for p in self.critic._out.parameters():
+ p.data = p.data * 0
+ # hard copy critic initial params
+ for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
+ d.data = s.data
+ #
+
+
+ def update(self, world_model, start, is_terminal, reward_fn):
+ metrics = {}
+ hor = self.cfg.imag_horizon
+ # The weights are is_terminal flags for the imagination start states.
+ # Technically, they should multiply the losses from the second trajectory
+ # step onwards, which is the first imagined step. However, we are not
+ # training the action that led into the first step anyway, so we can use
+ # them to scale the whole sequence.
+ with common.RequiresGrad(self.actor):
+ with torch.cuda.amp.autocast(enabled=self._use_amp):
+ seq = world_model.imagine(self.actor, start, is_terminal, hor)
+ reward = reward_fn(seq)
+ seq['reward'], mets1 = self.rewnorm(reward)
+ mets1 = {f'reward_{k}': v for k, v in mets1.items()}
+ target, mets2, baseline = self.target(seq)
+ actor_loss, mets3 = self.actor_loss(seq, target, baseline)
+ metrics.update(self.actor_opt(actor_loss, self.actor.parameters()))
+ with common.RequiresGrad(self.critic):
+ with torch.cuda.amp.autocast(enabled=self._use_amp):
+ seq = {k: stop_gradient(v) for k,v in seq.items()}
+ critic_loss, mets4 = self.critic_loss(seq, target)
+ metrics.update(self.critic_opt(critic_loss, self.critic.parameters()))
+ metrics.update(**mets1, **mets2, **mets3, **mets4)
+ self.update_slow_target() # Variables exist after first forward pass.
+ return { f'{self.name}_{k}'.strip('_') : v for k,v in metrics.items() }
+
+ def actor_loss(self, seq, target, baseline): #, step):
+ # Two state-actions are lost at the end of the trajectory, one for the boostrap
+ # value prediction and one because the corresponding action does not lead
+ # anywhere anymore. One target is lost at the start of the trajectory
+ # because the initial state comes from the replay buffer.
+ policy = self.actor(stop_gradient(seq['feat'][:-2])) # actions are the ones in [1:-1]
+
+ metrics = {}
+ if self.cfg.reward_ema:
+ offset, scale = self.reward_ema(target, self.ema_vals)
+ normed_target = (target - offset) / scale
+ normed_baseline = (baseline - offset) / scale
+ # adv = normed_target - normed_baseline
+ metrics['normed_target_mean'] = normed_target.mean()
+ metrics['normed_target_std'] = normed_target.std()
+ metrics["reward_ema_005"] = self.ema_vals[0]
+ metrics["reward_ema_095"] = self.ema_vals[1]
+ else:
+ normed_target = target
+ normed_baseline = baseline
+
+ if self.actor_grad == 'dynamics':
+ objective = normed_target[1:]
+ elif self.actor_grad == 'reinforce':
+ advantage = normed_target[1:] - normed_baseline[1:]
+ objective = policy.log_prob(stop_gradient(seq['action'][1:-1]))[:,:,None] * advantage
+ else:
+ raise NotImplementedError(self.actor_grad)
+
+ ent = policy.entropy()[:,:,None]
+ ent_scale = self.cfg.actor_ent
+ objective += ent_scale * ent
+ metrics['actor_ent'] = ent.mean()
+ metrics['actor_ent_scale'] = ent_scale
+
+ weight = stop_gradient(seq['weight'])
+ actor_loss = -(weight[:-2] * objective).mean()
+ return actor_loss, metrics
+
+ def critic_loss(self, seq, target):
+ feat = seq['feat'][:-1]
+ target = stop_gradient(target)
+ weight = stop_gradient(seq['weight'])
+ dist = self.critic(feat)
+ critic_loss = -(dist.log_prob(target)[:,:,None] * weight[:-1]).mean()
+ metrics = {'critic': dist.mean.mean() }
+ return critic_loss, metrics
+
+ def target(self, seq):
+ reward = seq['reward']
+ disc = seq['discount']
+ value = self._target_critic(seq['feat']).mean
+ # Skipping last time step because it is used for bootstrapping.
+ target = common.lambda_return(
+ reward[:-1], value[:-1], disc[:-1],
+ bootstrap=value[-1],
+ lambda_=self.cfg.discount_lambda,
+ axis=0)
+ metrics = {}
+ metrics['critic_slow'] = value.mean()
+ metrics['critic_target'] = target.mean()
+ return target, metrics, value[:-1]
+
+ def update_slow_target(self):
+ if self.cfg.slow_target:
+ if self._updates % self.cfg.slow_target_update == 0:
+ mix = 1.0 if self._updates == 0 else float(
+ self.cfg.slow_target_fraction)
+ for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
+ d.data = mix * s.data + (1 - mix) * d.data
+ self._updates += 1
\ No newline at end of file
diff --git a/agent/dreamer.yaml b/agent/dreamer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c9d9c8333bd9a7830759f9f2e66a3bc8ea951853
--- /dev/null
+++ b/agent/dreamer.yaml
@@ -0,0 +1,9 @@
+# @package agent
+_target_: agent.dreamer.DreamerAgent
+name: dreamer
+cfg: ???
+obs_space: ???
+act_spec: ???
+grad_heads: [decoder, reward]
+reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
+actor_ent: 3e-4
\ No newline at end of file
diff --git a/agent/dreamer_utils.py b/agent/dreamer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..30d560cd14cd0df61a0fceee37279a7b05850b91
--- /dev/null
+++ b/agent/dreamer_utils.py
@@ -0,0 +1,1040 @@
+import re
+
+import numpy as np
+
+import tools.utils as utils
+import torch.nn as nn
+import torch
+import torch.distributions as D
+import torch.nn.functional as F
+
+Module = nn.Module
+
+def symlog(x):
+ return torch.sign(x) * torch.log(torch.abs(x) + 1.0)
+
+def symexp(x):
+ return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)
+
+def signed_hyperbolic(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
+ """Signed hyperbolic transform, inverse of signed_parabolic."""
+ return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x
+
+def signed_parabolic(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
+ """Signed parabolic transform, inverse of signed_hyperbolic."""
+ z = torch.sqrt(1 + 4 * eps * (eps + 1 + torch.abs(x))) / 2 / eps - 1 / 2 / eps
+ return torch.sign(x) * (torch.square(z) - 1)
+
+class SampleDist:
+ def __init__(self, dist: D.Distribution, samples=100):
+ self._dist = dist
+ self._samples = samples
+
+ @property
+ def name(self):
+ return 'SampleDist'
+
+ def __getattr__(self, name):
+ return getattr(self._dist, name)
+
+ @property
+ def mean(self):
+ sample = self._dist.rsample((self._samples,))
+ return torch.mean(sample, 0)
+
+ def mode(self):
+ dist = self._dist.expand((self._samples, *self._dist.batch_shape))
+ sample = dist.rsample()
+ logprob = dist.log_prob(sample)
+ batch_size = sample.size(1)
+ feature_size = sample.size(2)
+ indices = torch.argmax(logprob, dim=0).reshape(1, batch_size, 1).expand(1, batch_size, feature_size)
+ return torch.gather(sample, 0, indices).squeeze(0)
+
+ def entropy(self):
+ sample = self._dist.rsample((self._samples,))
+ logprob = self._dist.log_prob(sample)
+ return -torch.mean(logprob, 0)
+
+ def sample(self):
+ return self._dist.rsample()
+
+class MSEDist:
+ def __init__(self, mode, agg="sum"):
+ self._mode = mode
+ self._agg = agg
+
+ @property
+ def mean(self):
+ return self._mode
+
+ def mode(self):
+ return self._mode
+
+ def log_prob(self, value):
+ assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
+ distance = (self._mode - value) ** 2
+ if self._agg == "mean":
+ loss = distance.mean(list(range(len(distance.shape)))[2:])
+ elif self._agg == "sum":
+ loss = distance.sum(list(range(len(distance.shape)))[2:])
+ else:
+ raise NotImplementedError(self._agg)
+ return -loss
+
+class SymlogDist:
+
+ def __init__(self, mode, dims, dist='mse', agg='sum', tol=1e-8):
+ self._mode = mode
+ self._dims = tuple([-x for x in range(1, dims + 1)])
+ self._dist = dist
+ self._agg = agg
+ self._tol = tol
+ self.batch_shape = mode.shape[:len(mode.shape) - dims]
+ self.event_shape = mode.shape[len(mode.shape) - dims:]
+
+ def mode(self):
+ return symexp(self._mode)
+
+ def mean(self):
+ return symexp(self._mode)
+
+ def log_prob(self, value):
+ assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
+ if self._dist == 'mse':
+ distance = (self._mode - symlog(value)) ** 2
+ distance = torch.where(distance < self._tol, torch.tensor([0.], dtype=distance.dtype, device=distance.device), distance)
+ elif self._dist == 'abs':
+ distance = torch.abs(self._mode - symlog(value))
+ distance = torch.where(distance < self._tol, torch.tensor([0.], dtype=distance.dtype, device=distance.device), distance)
+ else:
+ raise NotImplementedError(self._dist)
+ if self._agg == 'mean':
+ loss = distance.mean(self._dims)
+ elif self._agg == 'sum':
+ loss = distance.sum(self._dims)
+ else:
+ raise NotImplementedError(self._agg)
+ return -loss
+
+class TwoHotDist:
+ def __init__(
+ self,
+ logits,
+ low=-20.0,
+ high=20.0,
+ transfwd=symlog,
+ transbwd=symexp,
+ ):
+ assert logits.shape[-1] == 255
+ self.logits = logits
+ self.probs = torch.softmax(logits, -1)
+ self.buckets = torch.linspace(low, high, steps=255).to(logits.device)
+ self.width = (self.buckets[-1] - self.buckets[0]) / 255
+ self.transfwd = transfwd
+ self.transbwd = transbwd
+
+ @property
+ def mean(self):
+ _mean = self.probs * self.buckets
+ return self.transbwd(torch.sum(_mean, dim=-1, keepdim=True))
+
+ @property
+ def mode(self):
+ return self.mean
+
+ # Inside OneHotCategorical, log_prob is calculated using only max element in targets
+ def log_prob(self, x):
+ x = self.transfwd(x)
+ # x(time, batch, 1)
+ below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
+ above = len(self.buckets) - torch.sum(
+ (self.buckets > x[..., None]).to(torch.int32), dim=-1
+ )
+ # this is implemented using clip at the original repo as the gradients are not backpropagated for the out of limits.
+ below = torch.clip(below, 0, len(self.buckets) - 1)
+ above = torch.clip(above, 0, len(self.buckets) - 1)
+ equal = below == above
+
+ dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
+ dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
+ total = dist_to_below + dist_to_above
+ weight_below = dist_to_above / total
+ weight_above = dist_to_below / total
+ target = (
+ F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None]
+ + F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
+ )
+ log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
+ target = target.squeeze(-2)
+
+ return (target * log_pred).sum(-1)
+
+ def log_prob_target(self, target):
+ log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True)
+ return (target * log_pred).sum(-1)
+
+class OneHotDist(D.OneHotCategorical):
+
+ def __init__(self, logits=None, probs=None, unif_mix=0.99):
+ super().__init__(logits=logits, probs=probs)
+ probs = super().probs
+ probs = unif_mix * probs + (1 - unif_mix) * torch.ones_like(probs, device=probs.device) / probs.shape[-1]
+ super().__init__(probs=probs)
+
+ def mode(self):
+ _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1])
+ return _mode.detach() + super().logits - super().logits.detach()
+
+ def sample(self, sample_shape=(), seed=None):
+ if seed is not None:
+ raise ValueError('need to check')
+ sample = super().sample(sample_shape)
+ probs = super().probs
+ while len(probs.shape) < len(sample.shape):
+ probs = probs[None]
+ sample += probs - probs.detach() # ST-gradients
+ return sample
+
+class BernoulliDist(D.Bernoulli):
+ def __init__(self, logits=None, probs=None):
+ super().__init__(logits=logits, probs=probs)
+
+ def sample(self, sample_shape=(), seed=None):
+ if seed is not None:
+ raise ValueError('need to check')
+ sample = super().sample(sample_shape)
+ probs = super().probs
+ while len(probs.shape) < len(sample.shape):
+ probs = probs[None]
+ sample += probs - probs.detach() # ST-gradients
+ return sample
+
+def static_scan_for_lambda_return(fn, inputs, start):
+ last = start
+ indices = range(inputs[0].shape[0])
+ indices = reversed(indices)
+ flag = True
+ for index in indices:
+ inp = lambda x: (_input[x].unsqueeze(0) for _input in inputs)
+ last = fn(last, *inp(index))
+ if flag:
+ outputs = last
+ flag = False
+ else:
+ outputs = torch.cat([last, outputs], dim=0)
+ return outputs
+
+def lambda_return(
+ reward, value, pcont, bootstrap, lambda_, axis):
+ # Setting lambda=1 gives a discounted Monte Carlo return.
+ # Setting lambda=0 gives a fixed 1-step return.
+ #assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
+ assert len(reward.shape) == len(value.shape), (reward.shape, value.shape)
+ if isinstance(pcont, (int, float)):
+ pcont = pcont * torch.ones_like(reward, device=reward.device)
+ dims = list(range(len(reward.shape)))
+ dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:]
+ if axis != 0:
+ reward = reward.permute(dims)
+ value = value.permute(dims)
+ pcont = pcont.permute(dims)
+ if bootstrap is None:
+ bootstrap = torch.zeros_like(value[-1], device=reward.device)
+ if len(bootstrap.shape) < len(value.shape):
+ bootstrap = bootstrap[None]
+ next_values = torch.cat([value[1:], bootstrap], 0)
+ inputs = reward + pcont * next_values * (1 - lambda_)
+ returns = static_scan_for_lambda_return(
+ lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg,
+ (inputs, pcont), bootstrap)
+ if axis != 0:
+ returns = returns.permute(dims)
+ return returns
+
+def static_scan(fn, inputs, start, reverse=False, unpack=False):
+ last = start
+ indices = range(inputs[0].shape[0])
+ flag = True
+ for index in indices:
+ inp = lambda x: (_input[x] for _input in inputs)
+ if unpack:
+ last = fn(last, *[inp[index] for inp in inputs])
+ else:
+ last = fn(last, inp(index))
+ if flag:
+ if type(last) == type({}):
+ outputs = {key: [value] for key, value in last.items()}
+ else:
+ outputs = []
+ for _last in last:
+ if type(_last) == type({}):
+ outputs.append({key: [value] for key, value in _last.items()})
+ else:
+ outputs.append([_last])
+ flag = False
+ else:
+ if type(last) == type({}):
+ for key in last.keys():
+ outputs[key].append(last[key])
+ else:
+ for j in range(len(outputs)):
+ if type(last[j]) == type({}):
+ for key in last[j].keys():
+ outputs[j][key].append(last[j][key])
+ else:
+ outputs[j].append(last[j])
+ # Stack everything at the end
+ if type(last) == type({}):
+ for key in last.keys():
+ outputs[key] = torch.stack(outputs[key], dim=0)
+ else:
+ for j in range(len(outputs)):
+ if type(last[j]) == type({}):
+ for key in last[j].keys():
+ outputs[j][key] = torch.stack(outputs[j][key], dim=0)
+ else:
+ outputs[j] = torch.stack(outputs[j], dim=0)
+ if type(last) == type({}):
+ outputs = [outputs]
+ return outputs
+
+class EnsembleRSSM(Module):
+
+ def __init__(
+ self, ensemble=5, stoch=30, deter=200, hidden=200, discrete=False,
+ act='SiLU', norm='none', std_act='softplus', min_std=0.1, action_dim=None, embed_dim=1536, device='cuda',
+ single_obs_posterior=False, cell_input='stoch', cell_type='gru',):
+ super().__init__()
+ assert action_dim is not None
+ self.device = device
+ self._embed_dim = embed_dim
+ self._action_dim = action_dim
+ self._ensemble = ensemble
+ self._stoch = stoch
+ self._deter = deter
+ self._hidden = hidden
+ self._discrete = discrete
+ self._act = get_act(act)
+ self._norm = norm
+ self._std_act = std_act
+ self._min_std = min_std
+ self._cell_type = cell_type
+ self.cell_input = cell_input
+ if cell_type == 'gru':
+ self._cell = GRUCell(self._hidden, self._deter, norm=True, device=self.device)
+ else:
+ raise NotImplementedError(f"{cell_type} not implemented")
+ self.single_obs_posterior = single_obs_posterior
+
+ if discrete:
+ self._ensemble_img_dist = nn.ModuleList([ nn.Linear(hidden, stoch*discrete) for _ in range(ensemble)])
+ self._obs_dist = nn.Linear(hidden, stoch*discrete)
+ else:
+ self._ensemble_img_dist = nn.ModuleList([ nn.Linear(hidden, 2*stoch) for _ in range(ensemble)])
+ self._obs_dist = nn.Linear(hidden, 2*stoch)
+
+ # Layer that projects (stoch, input) to cell_state space
+ cell_state_input_size = getattr(self, f'get_{self.cell_input}_size')()
+ self._img_in = nn.Sequential(nn.Linear(cell_state_input_size + action_dim, hidden, bias=norm != 'none'), NormLayer(norm, hidden))
+ # Layer that project deter -> hidden [before projecting hidden -> stoch]
+ self._ensemble_img_out = nn.ModuleList([ nn.Sequential(nn.Linear(self.get_deter_size(), hidden, bias=norm != 'none'), NormLayer(norm, hidden)) for _ in range(ensemble)])
+
+ if self.single_obs_posterior:
+ self._obs_out = nn.Sequential(nn.Linear(embed_dim, hidden, bias=norm != 'none'), NormLayer(norm, hidden))
+ else:
+ self._obs_out = nn.Sequential(nn.Linear(deter + embed_dim, hidden, bias=norm != 'none'), NormLayer(norm, hidden))
+
+ def initial(self, batch_size):
+ if self._discrete:
+ state = dict(
+ logit=torch.zeros([batch_size, self._stoch, self._discrete], device=self.device),
+ stoch=torch.zeros([batch_size, self._stoch, self._discrete], device=self.device),
+ deter=self._cell.get_initial_state(None, batch_size))
+ else:
+ state = dict(
+ mean=torch.zeros([batch_size, self._stoch], device=self.device),
+ std=torch.zeros([batch_size, self._stoch], device=self.device),
+ stoch=torch.zeros([batch_size, self._stoch], device=self.device),
+ deter=self._cell.get_initial_state(None, batch_size))
+ return state
+
+ def observe(self, embed, action, is_first, state=None):
+ swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
+ if state is None: state = self.initial(action.shape[0])
+
+ post, prior = static_scan(
+ lambda prev, inputs: self.obs_step(prev[0], *inputs),
+ (swap(action), swap(embed), swap(is_first)), (state, state))
+ post = {k: swap(v) for k, v in post.items()}
+ prior = {k: swap(v) for k, v in prior.items()}
+ return post, prior
+
+ def imagine(self, action, state=None, sample=True):
+ swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
+ if state is None:
+ state = self.initial(action.shape[0])
+ assert isinstance(state, dict), state
+ action = swap(action)
+ prior = static_scan(self.img_step, [action, float(sample) + torch.zeros(action.shape[0])], state, unpack=True)[0]
+ prior = {k: swap(v) for k, v in prior.items()}
+ return prior
+
+ def get_stoch_size(self,):
+ if self._discrete:
+ return self._stoch * self._discrete
+ else:
+ return self._stoch
+
+ def get_deter_size(self,):
+ return self._cell.state_size
+
+ def get_feat_size(self,):
+ return self.get_deter_size() + self.get_stoch_size()
+
+ def get_stoch(self, state):
+ stoch = state['stoch']
+ if self._discrete:
+ shape = list(stoch.shape[:-2]) + [self._stoch * self._discrete]
+ stoch = stoch.reshape(shape)
+ return stoch
+
+ def get_deter(self, state):
+ return state['deter']
+
+ def get_feat(self, state):
+ deter = self.get_deter(state)
+ stoch = self.get_stoch(state)
+ return torch.cat([stoch, deter], -1)
+
+ def get_dist(self, state, ensemble=False):
+ if ensemble:
+ state = self._suff_stats_ensemble(state['deter'])
+ if self._discrete:
+ logit = state['logit']
+ dist = D.Independent(OneHotDist(logit.float()), 1)
+ else:
+ mean, std = state['mean'], state['std']
+ dist = D.Independent(D.Normal(mean, std), 1)
+ dist.sample = dist.rsample
+ return dist
+
+ def get_unif_dist(self, state):
+ if self._discrete:
+ logit = state['logit']
+ dist = D.Independent(OneHotDist(torch.ones_like(logit, device=logit.device)), 1)
+ else:
+ mean, std = state['mean'], state['std']
+ dist = D.Independent(D.Normal(torch.zeros_like(mean, device=mean.device), torch.ones_like(std, device=std.device)), 1)
+ dist.sample = dist.rsample
+ return dist
+
+ def obs_step(self, prev_state, prev_action, embed, is_first, should_sample=True):
+ if is_first.any():
+ prev_state = { k: torch.einsum('b,b...->b...', 1.0 - is_first.float(), x) for k, x in prev_state.items() }
+ prev_action = torch.einsum('b,b...->b...', 1.0 - is_first.float(), prev_action)
+ #
+ prior = self.img_step(prev_state, prev_action, should_sample)
+ stoch, stats = self.get_post_stoch(embed, prior, should_sample)
+ post = {'stoch': stoch, 'deter': prior['deter'], **stats}
+ return post, prior
+
+ def get_post_stoch(self, embed, prior, should_sample=True):
+ if self.single_obs_posterior:
+ x = embed
+ else:
+ x = torch.cat([prior['deter'], embed], -1)
+ x = self._obs_out(x)
+ x = self._act(x)
+
+ bs = list(x.shape[:-1])
+ x = x.reshape([-1, x.shape[-1]])
+ stats = self._suff_stats_layer('_obs_dist', x)
+ stats = { k: v.reshape( bs + list(v.shape[1:])) for k, v in stats.items()}
+
+ dist = self.get_dist(stats)
+ stoch = dist.sample() if should_sample else dist.mode()
+ return stoch, stats
+
+ def img_step(self, prev_state, prev_action, sample=True,):
+ prev_state_input = getattr(self, f'get_{self.cell_input}')(prev_state)
+ x = torch.cat([prev_state_input, prev_action], -1)
+ x = self._img_in(x)
+ x = self._act(x)
+ deter = prev_state['deter']
+ if self._cell_type == 'gru':
+ x, deter = self._cell(x, [deter])
+ temp_state = {'deter' : deter[0] }
+ else:
+ raise NotImplementedError(f"no {self._cell_type} cell method")
+ deter = deter[0] # It's wrapped in a list.
+ stoch, stats = self.get_stoch_stats_from_deter_state(temp_state, sample)
+ prior = {'stoch': stoch, 'deter': deter, **stats}
+ return prior
+
+ def get_stoch_stats_from_deter_state(self, temp_state, sample=True):
+ stats = self._suff_stats_ensemble(self.get_deter(temp_state))
+ index = torch.randint(0, self._ensemble, ())
+ stats = {k: v[index] for k, v in stats.items()}
+ dist = self.get_dist(stats)
+ if sample:
+ stoch = dist.sample()
+ else:
+ try:
+ stoch = dist.mode()
+ except:
+ stoch = dist.mean
+ return stoch, stats
+
+ def _suff_stats_ensemble(self, inp):
+ bs = list(inp.shape[:-1])
+ inp = inp.reshape([-1, inp.shape[-1]])
+ stats = []
+ for k in range(self._ensemble):
+ x = self._ensemble_img_out[k](inp)
+ x = self._act(x)
+ stats.append(self._suff_stats_layer('_ensemble_img_dist', x, k=k))
+ stats = {
+ k: torch.stack([x[k] for x in stats], 0)
+ for k, v in stats[0].items()}
+ stats = {
+ k: v.reshape([v.shape[0]] + bs + list(v.shape[2:]))
+ for k, v in stats.items()}
+ return stats
+
+ def _suff_stats_layer(self, name, x, k=None):
+ layer = getattr(self, name)
+ if k is not None:
+ layer = layer[k]
+ x = layer(x)
+ if self._discrete:
+ logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete])
+ return {'logit': logit}
+ else:
+ mean, std = torch.chunk(x, 2, -1)
+ std = {
+ 'softplus': lambda: F.softplus(std),
+ 'sigmoid': lambda: torch.sigmoid(std),
+ 'sigmoid2': lambda: 2 * torch.sigmoid(std / 2),
+ }[self._std_act]()
+ std = std + self._min_std
+ return {'mean': mean, 'std': std}
+
+ def vq_loss(self, post, prior, balance):
+ dim_repr = prior['output'].shape[-1]
+ # Vectors and codes are the same, but vectors have gradients
+ dyn_loss = balance * F.mse_loss(prior['output'], post['vectors'].detach()) + (1 - balance) * F.mse_loss(prior['output'].detach(), post['vectors'])
+ dyn_loss += balance * F.mse_loss(prior['output'], post['codes'].detach()) + (1 - balance) * F.mse_loss(prior['output'].detach(), post['codes'])
+ dyn_loss /= 2
+ vq_loss = 0.25 * F.mse_loss(post['output'], post['codes'].detach()) + F.mse_loss(post['output'].detach(), post['codes'])
+
+ loss = vq_loss + dyn_loss
+ return loss * dim_repr, dyn_loss * dim_repr
+
+ def kl_loss(self, post, prior, forward, balance, free, free_avg,):
+ kld = D.kl_divergence
+ sg = lambda x: {k: v.detach() for k, v in x.items()}
+ lhs, rhs = (prior, post) if forward else (post, prior)
+ mix = balance if forward else (1 - balance)
+ dtype = post['stoch'].dtype
+ device = post['stoch'].device
+ free_tensor = torch.tensor([free], dtype=dtype, device=device)
+ if balance == 0.5:
+ value = kld(self.get_dist(lhs), self.get_dist(rhs))
+ loss = torch.maximum(value, free_tensor).mean()
+ else:
+ value_lhs = value = kld(self.get_dist(lhs), self.get_dist(sg(rhs)))
+ value_rhs = kld(self.get_dist(sg(lhs)), self.get_dist(rhs))
+ if free_avg:
+ loss_lhs = torch.maximum(value_lhs.mean(), free_tensor)
+ loss_rhs = torch.maximum(value_rhs.mean(), free_tensor)
+ else:
+ loss_lhs = torch.maximum(value_lhs, free_tensor).mean()
+ loss_rhs = torch.maximum(value_rhs, free_tensor).mean()
+ loss = mix * loss_lhs + (1 - mix) * loss_rhs
+ return loss, value
+
+
+class Encoder(Module):
+
+ def __init__(
+ self, shapes, cnn_keys=r'.*', mlp_keys=r'.*', act='SiLU', norm='none',
+ cnn_depth=48, cnn_kernels=(4, 4, 4, 4), mlp_layers=[400, 400, 400, 400], symlog_inputs=False,):
+ super().__init__()
+ self.shapes = shapes
+ self.cnn_keys = [
+ k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3]
+ self.mlp_keys = [
+ k for k, v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1]
+ print('Encoder CNN inputs:', list(self.cnn_keys))
+ print('Encoder MLP inputs:', list(self.mlp_keys))
+ self._act = get_act(act)
+ self._norm = norm
+ self._cnn_depth = cnn_depth
+ self._cnn_kernels = cnn_kernels
+ self._mlp_layers = mlp_layers
+ self._symlog_inputs = symlog_inputs
+
+ if len(self.cnn_keys) > 0:
+ self._conv_model = []
+ for i, kernel in enumerate(self._cnn_kernels):
+ if i == 0:
+ prev_depth = 3
+ else:
+ prev_depth = 2 ** (i-1) * self._cnn_depth
+ depth = 2 ** i * self._cnn_depth
+ self._conv_model.append(nn.Conv2d(prev_depth, depth, kernel, stride=2))
+ self._conv_model.append(ImgChLayerNorm(depth) if norm == 'layer' else NormLayer(norm,depth))
+ self._conv_model.append(self._act)
+ self._conv_model = nn.Sequential(*self._conv_model)
+ if len(self.mlp_keys) > 0:
+ self._mlp_model = []
+ for i, width in enumerate(self._mlp_layers):
+ if i == 0:
+ prev_width = np.sum([shapes[k] for k in self.mlp_keys])
+ else:
+ prev_width = self._mlp_layers[i-1]
+ self._mlp_model.append(nn.Linear(prev_width, width, bias=norm != 'none'))
+ self._mlp_model.append(NormLayer(norm, width))
+ self._mlp_model.append(self._act)
+ if len(self._mlp_model) == 0:
+ self._mlp_model.append(nn.Identity())
+ self._mlp_model = nn.Sequential(*self._mlp_model)
+
+ def forward(self, data):
+ key, shape = list(self.shapes.items())[0]
+ batch_dims = data[key].shape[:-len(shape)]
+ data = {
+ k: v.reshape((-1,) + tuple(v.shape)[len(batch_dims):])
+ for k, v in data.items()}
+ outputs = []
+ if self.cnn_keys:
+ outputs.append(self._cnn({k: data[k] for k in self.cnn_keys}))
+ if self.mlp_keys:
+ outputs.append(self._mlp({k: data[k] for k in self.mlp_keys}))
+ output = torch.cat(outputs, -1)
+ return output.reshape(batch_dims + output.shape[1:])
+
+ def _cnn(self, data):
+ x = torch.cat(list(data.values()), -1)
+ x = self._conv_model(x)
+ return x.reshape(tuple(x.shape[:-3]) + (-1,))
+
+ def _mlp(self, data):
+ x = torch.cat(list(data.values()), -1)
+ if self._symlog_inputs:
+ x = symlog(x)
+ x = self._mlp_model(x)
+ return x
+
+
+class Decoder(Module):
+
+ def __init__(
+ self, shapes, cnn_keys=r'.*', mlp_keys=r'.*', act='SiLU', norm='none',
+ cnn_depth=48, cnn_kernels=(4, 4, 4, 4), mlp_layers=[400, 400, 400, 400], embed_dim=1024, mlp_dist='mse', image_dist='mse'):
+ super().__init__()
+ self._embed_dim = embed_dim
+ self._shapes = shapes
+ self.cnn_keys = [
+ k for k, v in shapes.items() if re.match(cnn_keys, k) and len(v) == 3]
+ self.mlp_keys = [
+ k for k, v in shapes.items() if re.match(mlp_keys, k) and len(v) == 1]
+ print('Decoder CNN outputs:', list(self.cnn_keys))
+ print('Decoder MLP outputs:', list(self.mlp_keys))
+ self._act = get_act(act)
+ self._norm = norm
+ self._cnn_depth = cnn_depth
+ self._cnn_kernels = cnn_kernels
+ self._mlp_layers = mlp_layers
+ self.channels = {k: self._shapes[k][0] for k in self.cnn_keys}
+ self._mlp_dist = mlp_dist
+ self._image_dist = image_dist
+
+ if len(self.cnn_keys) > 0:
+
+ self._conv_in = nn.Sequential(nn.Linear(embed_dim, 32*self._cnn_depth))
+ self._conv_model = []
+ for i, kernel in enumerate(self._cnn_kernels):
+ if i == 0:
+ prev_depth = 32*self._cnn_depth
+ else:
+ prev_depth = 2 ** (len(self._cnn_kernels) - (i - 1) - 2) * self._cnn_depth
+ depth = 2 ** (len(self._cnn_kernels) - i - 2) * self._cnn_depth
+ act, norm = self._act, self._norm
+ # Last layer is dist layer
+ if i == len(self._cnn_kernels) - 1:
+ depth, act, norm = sum(self.channels.values()), nn.Identity(), 'none'
+ self._conv_model.append(nn.ConvTranspose2d(prev_depth, depth, kernel, stride=2))
+ self._conv_model.append(ImgChLayerNorm(depth) if norm == 'layer' else NormLayer(norm, depth))
+ self._conv_model.append(act)
+ self._conv_model = nn.Sequential(*self._conv_model)
+ if len(self.mlp_keys) > 0:
+ self._mlp_model = []
+ for i, width in enumerate(self._mlp_layers):
+ if i == 0:
+ prev_width = embed_dim
+ else:
+ prev_width = self._mlp_layers[i-1]
+ self._mlp_model.append(nn.Linear(prev_width, width, bias=self._norm != 'none'))
+ self._mlp_model.append(NormLayer(self._norm, width))
+ self._mlp_model.append(self._act)
+ self._mlp_model = nn.Sequential(*self._mlp_model)
+ for key, shape in { k : shapes[k] for k in self.mlp_keys }.items():
+ self.add_module(f'dense_{key}', DistLayer(width, shape, dist=self._mlp_dist))
+
+ def forward(self, features):
+ outputs = {}
+
+ if self.cnn_keys:
+ outputs.update(self._cnn(features))
+ if self.mlp_keys:
+ outputs.update(self._mlp(features))
+ return outputs
+
+ def _cnn(self, features):
+ x = self._conv_in(features)
+ x = x.reshape([-1, 32 * self._cnn_depth, 1, 1,])
+ x = self._conv_model(x)
+ x = x.reshape(list(features.shape[:-1]) + list(x.shape[1:]))
+ if len(x.shape) == 5:
+ means = torch.split(x, list(self.channels.values()), 2)
+ else:
+ means = torch.split(x, list(self.channels.values()), 1)
+ image_dist = dict(mse=lambda x : MSEDist(x), normal_unit_std=lambda x : D.Independent(D.Normal(x, 1.0), 3))[self._image_dist]
+ dists = { key: image_dist(mean) for (key, shape), mean in zip(self.channels.items(), means)}
+ return dists
+
+ def _mlp(self, features):
+ shapes = {k: self._shapes[k] for k in self.mlp_keys}
+ x = features
+ x = self._mlp_model(x)
+ dists = {}
+ for key, shape in shapes.items():
+ dists[key] = getattr(self, f'dense_{key}')(x)
+ return dists
+
+
+class MLP(Module):
+
+ def __init__(self, in_shape, shape, layers, units, act='SiLU', norm='none', **out):
+ super().__init__()
+ self._in_shape = in_shape
+ if out['dist'] == 'twohot':
+ shape = 255
+ self._shape = (shape,) if isinstance(shape, int) else shape
+ self._layers = layers
+ self._units = units
+ self._norm = norm
+ self._act = get_act(act)
+ self._out = out
+
+ last_units = in_shape
+ for index in range(self._layers):
+ self.add_module(f'dense{index}', nn.Linear(last_units, units, bias=norm != 'none'))
+ self.add_module(f'norm{index}', NormLayer(norm, units))
+ last_units = units
+ self._out = DistLayer(units, shape, **out)
+
+ def forward(self, features):
+ x = features
+ x = x.reshape([-1, x.shape[-1]])
+ for index in range(self._layers):
+ x = getattr(self, f'dense{index}')(x)
+ x = getattr(self, f'norm{index}')(x)
+ x = self._act(x)
+ x = x.reshape(list(features.shape[:-1]) + [x.shape[-1]])
+ return self._out(x)
+
+
+class GRUCell(Module):
+
+ def __init__(self, inp_size, size, norm=False, act='Tanh', update_bias=-1, device='cuda', **kwargs):
+ super().__init__()
+ self._inp_size = inp_size
+ self._size = size
+ self._act = get_act(act)
+ self._norm = norm
+ self._update_bias = update_bias
+ self.device = device
+ self._layer = nn.Linear(inp_size + size, 3 * size, bias=(not norm), **kwargs)
+ if norm:
+ self._norm = nn.LayerNorm(3*size)
+
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ return torch.zeros((batch_size), self._size, device=self.device)
+
+ @property
+ def state_size(self):
+ return self._size
+
+ def forward(self, inputs, deter_state):
+ """
+ inputs : non-linear combination of previous stoch and action
+ deter_state : prev hidden state of the cell
+ """
+ deter_state = deter_state[0] # State is wrapped in a list.
+ parts = self._layer(torch.cat([inputs, deter_state], -1))
+ if self._norm:
+ parts = self._norm(parts)
+ reset, cand, update = torch.chunk(parts, 3, -1)
+ reset = torch.sigmoid(reset)
+ cand = self._act(reset * cand)
+ update = torch.sigmoid(update + self._update_bias)
+ output = update * cand + (1 - update) * deter_state
+ return output, [output]
+
+class DistLayer(Module):
+
+ def __init__(
+ self, in_dim, shape, dist='mse', min_std=0.1, max_std=1.0, init_std=0.0, bias=True):
+ super().__init__()
+ self._in_dim = in_dim
+ self._shape = shape if type(shape) in [list,tuple] else [shape]
+ self._dist = dist
+ self._min_std = min_std
+ self._init_std = init_std
+ self._max_std = max_std
+ self._out = nn.Linear(in_dim, int(np.prod(shape)) , bias=bias)
+ if dist in ('normal', 'tanh_normal', 'trunc_normal'):
+ self._std = nn.Linear(in_dim, int(np.prod(shape)) )
+
+ def forward(self, inputs):
+ out = self._out(inputs)
+ out = out.reshape(list(inputs.shape[:-1]) + list(self._shape))
+ if self._dist in ('normal', 'tanh_normal', 'trunc_normal'):
+ std = self._std(inputs)
+ std = std.reshape(list(inputs.shape[:-1]) + list(self._shape))
+ if self._dist == 'mse':
+ return MSEDist(out,)
+ if self._dist == 'normal_unit_std':
+ dist = D.Normal(out, 1.0)
+ dist.sample = dist.rsample
+ return D.Independent(dist, len(self._shape))
+ if self._dist == 'normal':
+ mean = torch.tanh(out)
+ std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std
+ dist = D.Normal(mean, std)
+ dist.sample = dist.rsample
+ return D.Independent(dist, len(self._shape))
+ if self._dist == 'binary':
+ out = torch.sigmoid(out)
+ dist = BernoulliDist(out)
+ return D.Independent(dist, len(self._shape))
+ if self._dist == 'tanh_normal':
+ mean = 5 * torch.tanh(out / 5)
+ std = F.softplus(std + self._init_std) + self._min_std
+ dist = utils.SquashedNormal(mean, std)
+ dist = D.Independent(dist, len(self._shape))
+ return SampleDist(dist)
+ if self._dist == 'trunc_normal':
+ mean = torch.tanh(out)
+ std = 2 * torch.sigmoid((std + self._init_std) / 2) + self._min_std
+ dist = utils.TruncatedNormal(mean, std)
+ return D.Independent(dist, 1)
+ if self._dist == 'onehot':
+ return OneHotDist(out.float())
+ if self._dist == 'twohot':
+ return TwoHotDist(out.float())
+ if self._dist == 'symlog_mse':
+ return SymlogDist(out, len(self._shape), 'mse')
+ raise NotImplementedError(self._dist)
+
+
+class NormLayer(Module):
+
+ def __init__(self, name, dim=None):
+ super().__init__()
+ if name == 'none':
+ self._layer = None
+ elif name == 'layer':
+ assert dim != None
+ self._layer = nn.LayerNorm(dim)
+ else:
+ raise NotImplementedError(name)
+
+ def forward(self, features):
+ if self._layer is None:
+ return features
+ return self._layer(features)
+
+
+def get_act(name):
+ if name == 'none':
+ return nn.Identity()
+ elif hasattr(nn, name):
+ return getattr(nn, name)()
+ else:
+ raise NotImplementedError(name)
+
+
+class Optimizer:
+
+ def __init__(
+ self, name, parameters, lr, eps=1e-4, clip=None, wd=None,
+ opt='adam', wd_pattern=r'.*', use_amp=False):
+ assert 0 <= wd < 1
+ assert not clip or 1 <= clip
+ self._name = name
+ self._clip = clip
+ self._wd = wd
+ self._wd_pattern = wd_pattern
+ self._opt = {
+ 'adam': lambda: torch.optim.Adam(parameters, lr, eps=eps),
+ 'nadam': lambda: torch.optim.Nadam(parameters, lr, eps=eps),
+ 'adamax': lambda: torch.optim.Adamax(parameters, lr, eps=eps),
+ 'sgd': lambda: torch.optim.SGD(parameters, lr),
+ 'momentum': lambda: torch.optim.SGD(lr, momentum=0.9),
+ }[opt]()
+ self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
+ self._once = True
+
+ def __call__(self, loss, params):
+ params = list(params)
+ assert len(loss.shape) == 0 or (len(loss.shape) == 1 and loss.shape[0] == 1), (self._name, loss.shape)
+ metrics = {}
+
+ # Count parameters.
+ if self._once:
+ count = sum(p.numel() for p in params if p.requires_grad)
+ print(f'Found {count} {self._name} parameters.')
+ self._once = False
+
+ # Check loss.
+ metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy()
+
+ # Compute scaled gradient.
+ self._scaler.scale(loss).backward()
+ self._scaler.unscale_(self._opt)
+
+ # Gradient clipping.
+ if self._clip:
+ norm = torch.nn.utils.clip_grad_norm_(params, self._clip)
+ metrics[f'{self._name}_grad_norm'] = norm.item()
+
+ # Weight decay.
+ if self._wd:
+ self._apply_weight_decay(params)
+
+ # # Apply gradients.
+ self._scaler.step(self._opt)
+ self._scaler.update()
+
+ self._opt.zero_grad()
+ return metrics
+
+ def _apply_weight_decay(self, varibs):
+ nontrivial = (self._wd_pattern != r'.*')
+ if nontrivial:
+ raise NotImplementedError('Non trivial weight decay')
+ else:
+ for var in varibs:
+ var.data = (1 - self._wd) * var.data
+
+class StreamNorm:
+
+ def __init__(self, shape=(), momentum=0.99, scale=1.0, eps=1e-8, device='cuda'):
+ # Momentum of 0 normalizes only based on the current batch.
+ # Momentum of 1 disables normalization.
+ self.device = device
+ self._shape = tuple(shape)
+ self._momentum = momentum
+ self._scale = scale
+ self._eps = eps
+ self.mag = None # torch.ones(shape).to(self.device)
+
+ self.step = 0
+ self.mean = None # torch.zeros(shape).to(self.device)
+ self.square_mean = None # torch.zeros(shape).to(self.device)
+
+ def reset(self):
+ self.step = 0
+ self.mag = None # torch.ones_like(self.mag).to(self.device)
+ self.mean = None # torch.zeros_like(self.mean).to(self.device)
+ self.square_mean = None # torch.zeros_like(self.square_mean).to(self.device)
+
+ def __call__(self, inputs):
+ metrics = {}
+ self.update(inputs)
+ metrics['mean'] = inputs.mean()
+ metrics['std'] = inputs.std()
+ outputs = self.transform(inputs)
+ metrics['normed_mean'] = outputs.mean()
+ metrics['normed_std'] = outputs.std()
+ return outputs, metrics
+
+ def update(self, inputs):
+ self.step += 1
+ batch = inputs.reshape((-1,) + self._shape)
+
+ mag = torch.abs(batch).mean(0)
+ if self.mag is not None:
+ self.mag.data = self._momentum * self.mag.data + (1 - self._momentum) * mag
+ else:
+ self.mag = mag.clone().detach()
+
+ mean = torch.mean(batch)
+ if self.mean is not None:
+ self.mean.data = self._momentum * self.mean.data + (1 - self._momentum) * mean
+ else:
+ self.mean = mean.clone().detach()
+
+ square_mean = torch.mean(batch * batch)
+ if self.square_mean is not None:
+ self.square_mean.data = self._momentum * self.square_mean.data + (1 - self._momentum) * square_mean
+ else:
+ self.square_mean = square_mean.clone().detach()
+
+ def transform(self, inputs):
+ if self._momentum == 1:
+ return inputs
+ values = inputs.reshape((-1,) + self._shape)
+ values /= self.mag[None] + self._eps
+ values *= self._scale
+ return values.reshape(inputs.shape)
+
+ def corrected_mean_var_std(self,):
+ corr = 1 # 1 - self._momentum ** self.step # NOTE: this led to exploding values for first few iterations
+ corr_mean = self.mean / corr
+ corr_var = (self.square_mean / corr) - self.mean ** 2
+ corr_std = torch.sqrt(torch.maximum(corr_var, torch.zeros_like(corr_var, device=self.device)) + self._eps)
+ return corr_mean, corr_var, corr_std
+
+class RequiresGrad:
+
+ def __init__(self, model):
+ self._model = model
+
+ def __enter__(self):
+ self._model.requires_grad_(requires_grad=True)
+
+ def __exit__(self, *args):
+ self._model.requires_grad_(requires_grad=False)
+
+class RewardEMA:
+ """running mean and std"""
+
+ def __init__(self, device, alpha=1e-2):
+ self.device = device
+ self.alpha = alpha
+ self.range = torch.tensor([0.05, 0.95]).to(device)
+
+ def __call__(self, x, ema_vals):
+ flat_x = torch.flatten(x.detach())
+ x_quantile = torch.quantile(input=flat_x, q=self.range)
+ # this should be in-place operation
+ ema_vals[:] = self.alpha * x_quantile + (1 - self.alpha) * ema_vals
+ scale = torch.clip(ema_vals[1] - ema_vals[0], min=1.0)
+ offset = ema_vals[0]
+ return offset.detach(), scale.detach()
+
+class ImgChLayerNorm(nn.Module):
+ def __init__(self, ch, eps=1e-03):
+ super(ImgChLayerNorm, self).__init__()
+ self.norm = torch.nn.LayerNorm(ch, eps=eps)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 3, 1)
+ x = self.norm(x)
+ x = x.permute(0, 3, 1, 2)
+ return x
\ No newline at end of file
diff --git a/agent/genrl.py b/agent/genrl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8bf95670616ef291dba25d942570cdae409c863
--- /dev/null
+++ b/agent/genrl.py
@@ -0,0 +1,124 @@
+import torch
+from agent.dreamer import DreamerAgent, ActorCritic, stop_gradient, env_reward
+import agent.dreamer_utils as common
+import agent.video_utils as video_utils
+from tools.genrl_utils import *
+
+def connector_update_fn(self, module_name, data, outputs, metrics):
+ connector = getattr(self, module_name)
+ n_frames = connector.n_frames
+ B, T = data['observation'].shape[:2]
+
+ # video embed are actions
+ if getattr(self.cfg, "viclip_encode", False):
+ video_embed = data['clip_video']
+ else:
+ # Obtaining video embed
+ with torch.no_grad():
+ viclip_model = getattr(self, 'viclip_model')
+ processed_obs = viclip_model.preprocess_transf(data['observation'].reshape(B*T, *data['observation'].shape[2:]) / 255)
+ reshaped_obs = processed_obs.reshape(B * (T // n_frames), n_frames, 3,224,224)
+ video_embed = viclip_model.get_vid_features(reshaped_obs.to(viclip_model.device))
+
+ # Get posterior states from original model
+ wm_post = outputs['post']
+ return connector.update(video_embed, wm_post)
+
+class GenRLAgent(DreamerAgent):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.n_frames = 8 # NOTE: this should become an hyperparam if changing the model
+ self.viclip_emb_dim = 512 # NOTE: this should become an hyperparam if changing the model
+
+ assert self.cfg.batch_length % self.n_frames == 0, "Fix batch length param"
+
+ if 'clip_video' in self.obs_space:
+ self.viclip_emb_dim = self.obs_space['clip_video'].shape[0]
+
+ connector = video_utils.VideoSSM(**self.cfg.connector, **self.cfg.connector_rssm, connector_kl=self.cfg.connector_kl,
+ n_frames=self.n_frames, action_dim=self.viclip_emb_dim + self.n_frames,
+ clip_add_noise=self.cfg.clip_add_noise, clip_lafite_noise=self.cfg.clip_lafite_noise,
+ device=self.device, cell_input='stoch')
+
+ connector.to(self.device)
+
+ self.wm.add_module_to_update('connector', connector, connector_update_fn, detached=self.cfg.connector.detached_post)
+
+ if getattr(self.cfg, 'imag_reward_fn', None) is not None:
+ self.instantiate_imag_behavior()
+
+ def instantiate_imag_behavior(self):
+ self._imag_behavior = ActorCritic(self.cfg, self.act_spec, self.wm.inp_size, name='imag').to(self.device)
+ self._imag_behavior.rewnorm = common.StreamNorm(**self.cfg.imag_reward_norm, device=self.device)
+
+ def finetune_mode(self,):
+ self._acting_behavior = self._imag_behavior
+ self.wm.detached_update_fns = {}
+ self.wm.e2e_update_fns = {}
+ self.wm.grad_heads.append('reward')
+
+ def update_wm(self, data, step):
+ return super().update_wm(data, step)
+
+ def report(self, data, key='observation', nvid=8):
+ # Redefine data with trim
+ n_frames = self.wm.connector.n_frames
+ obs = data['observation'][:nvid, n_frames:]
+ B, T = obs.shape[:2]
+
+ report_data = super().report(data)
+ wm = self.wm
+ n_frames = wm.connector.n_frames
+
+ # Init is same as Dreamer for reporting
+ truth = data[key][:nvid] / 255
+ decoder = wm.heads['decoder'] # B, T, C, H, W
+ preprocessed_data = self.wm.preprocess(data)
+
+ embed = wm.encoder(preprocessed_data)
+ states, _ = wm.rssm.observe(embed[:nvid, :n_frames], data['action'][:nvid, :n_frames], data['is_first'][:nvid, :n_frames])
+ recon = decoder(wm.decoder_input_fn(states))[key].mean[:nvid] # mode
+ dreamer_init = {k: v[:, -1] for k, v in states.items()}
+
+ # video embed are actions
+ if getattr(self.cfg, "viclip_encode", False):
+ video_embed = data['clip_video'][:nvid,n_frames*2-1::n_frames]
+ else:
+ # Obtain embed
+ processed_obs = wm.viclip_model.preprocess_transf(obs.reshape(B*T, *obs.shape[2:]) / 255)
+ reshaped_obs = processed_obs.reshape(B * (T // n_frames), n_frames, 3,224,224)
+ video_embed = wm.viclip_model.get_vid_features(reshaped_obs.to(wm.viclip_model.device))
+
+ video_embed = video_embed.to(self.device)
+
+ # Get actions
+ video_embed = video_embed.reshape(B, T // n_frames, -1).unsqueeze(2).repeat(1,1,n_frames, 1).reshape(B, T, -1)
+ prior = wm.connector.video_imagine(video_embed, dreamer_init, reset_every_n_frames=False)
+ prior_recon = decoder(wm.decoder_input_fn(prior))[key].mean # mode
+ model = torch.clip(torch.cat([recon[:, :n_frames] + 0.5, prior_recon + 0.5], 1), 0, 1)
+ error = (model - truth + 1) / 2
+
+ # Add video to logs
+ video = torch.cat([truth, model, error], 3)
+ report_data['video_clip_pred'] = video
+
+ return report_data
+
+ def update_imag_behavior(self, state=None, outputs=None, metrics={}, seq_data=None,):
+ if getattr(self.cfg, 'imag_reward_fn', None) is None:
+ return outputs['post'], metrics
+ if outputs is not None:
+ post = outputs['post']
+ is_terminal = outputs['is_terminal']
+ else:
+ seq_data = self.wm.preprocess(seq_data)
+ embed = self.wm.encoder(seq_data)
+ post, _ = self.wm.rssm.observe(
+ embed, seq_data['action'], seq_data['is_first'])
+ is_terminal = seq_data['is_terminal']
+ #
+ start = {k: stop_gradient(v) for k,v in post.items()}
+ imag_reward_fn = lambda seq: globals()[self.cfg.imag_reward_fn](self, seq, **self.cfg.imag_reward_args)
+ metrics.update(self._imag_behavior.update(self.wm, start, is_terminal, imag_reward_fn,))
+ return start, metrics
\ No newline at end of file
diff --git a/agent/genrl.yaml b/agent/genrl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9a06c0842e83c91fd6addc11fe23b5dc228e5c95
--- /dev/null
+++ b/agent/genrl.yaml
@@ -0,0 +1,22 @@
+# @package agent
+_target_: agent.genrl.GenRLAgent
+name: genrl
+cfg: ???
+obs_space: ???
+act_spec: ???
+grad_heads: [decoder]
+reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
+actor_ent: 0
+additional_report_fns: ['report_text2video']
+
+clip_add_noise: 0.0
+clip_lafite_noise: 0.5
+
+connector: { token_dropout: 0, loss_scale: 1, denoising_ae: True, detached_post: True, temporal_embeds: False, rescale_embeds: True}
+connector_rssm: {ensemble: 1, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, norm: layer, std_act: softplus, min_std: 0.1, single_obs_posterior: false, learn_initial: True } # act: elu,
+connector_kl: {free: 0.0, forward: True, balance: 0.8, free_avg: False, } # note forward is true by default
+
+imag_reward_fn: null
+imag_reward_norm: {momentum: 1.00, scale: 1.0, eps: 1e-8}
+imag_reward_args: {score_fn: 'max_cosine', sample_for_target: False, align_initial : False, weighted_align : False, align_sequence: True, skip_first_target: True }
+# +imag_reward_args.task_prompt
\ No newline at end of file
diff --git a/agent/plan2explore.py b/agent/plan2explore.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdc38b5071bae1dba0330d37705133edb5303f18
--- /dev/null
+++ b/agent/plan2explore.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from agent.dreamer import DreamerAgent, stop_gradient
+import agent.dreamer_utils as common
+
+class Disagreement(nn.Module):
+ def __init__(self, obs_dim, action_dim, hidden_dim, n_models=5, pred_dim=None):
+ super().__init__()
+ if pred_dim is None: pred_dim = obs_dim
+ self.ensemble = nn.ModuleList([
+ nn.Sequential(nn.Linear(obs_dim + action_dim, hidden_dim),
+ nn.ReLU(), nn.Linear(hidden_dim, pred_dim))
+ for _ in range(n_models)
+ ])
+
+ def forward(self, obs, action, next_obs):
+ assert obs.shape[0] == next_obs.shape[0]
+ assert obs.shape[0] == action.shape[0]
+
+ errors = []
+ for model in self.ensemble:
+ next_obs_hat = model(torch.cat([obs, action], dim=-1))
+ model_error = torch.norm(next_obs - next_obs_hat,
+ dim=-1,
+ p=2,
+ keepdim=True)
+ errors.append(model_error)
+
+ return torch.cat(errors, dim=1)
+
+ def get_disagreement(self, obs, action):
+ assert obs.shape[0] == action.shape[0]
+
+ preds = []
+ for model in self.ensemble:
+ next_obs_hat = model(torch.cat([obs, action], dim=-1))
+ preds.append(next_obs_hat)
+ preds = torch.stack(preds, dim=0)
+ return torch.var(preds, dim=0).mean(dim=-1)
+
+
+class Plan2Explore(DreamerAgent):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ in_dim = self.wm.inp_size
+ pred_dim = self.wm.embed_dim
+ self.hidden_dim = pred_dim
+ self.reward_free = True
+
+ self.disagreement = Disagreement(in_dim, self.act_dim,
+ self.hidden_dim, pred_dim=pred_dim).to(self.device)
+
+ # optimizers
+ self.disagreement_opt = common.Optimizer('disagreement', self.disagreement.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)
+ self.disagreement.train()
+ self.requires_grad_(requires_grad=False)
+
+ def update_disagreement(self, obs, action, next_obs, step):
+ metrics = dict()
+
+ error = self.disagreement(obs, action, next_obs)
+
+ loss = error.mean()
+
+ metrics.update(self.disagreement_opt(loss, self.disagreement.parameters()))
+
+ metrics['disagreement_loss'] = loss.item()
+
+ return metrics
+
+ def compute_intr_reward(self, seq):
+ obs, action = seq['feat'][:-1], stop_gradient(seq['action'][1:])
+ intr_rew = torch.zeros(list(seq['action'].shape[:-1]) + [1], device=self.device)
+ if len(action.shape) > 2:
+ B, T, _ = action.shape
+ obs = obs.reshape(B*T, -1)
+ action = action.reshape(B*T, -1)
+ reward = self.disagreement.get_disagreement(obs, action).reshape(B, T, 1)
+ else:
+ reward = self.disagreement.get_disagreement(obs, action).unsqueeze(-1)
+ intr_rew[1:] = reward
+ return intr_rew
+
+ def update(self, data, step):
+ metrics = {}
+ B, T, _ = data['action'].shape
+ state, outputs, mets = self.wm.update(data, state=None)
+ metrics.update(mets)
+ start = outputs['post']
+ start = {k: stop_gradient(v) for k,v in start.items()}
+ if self.reward_free:
+ T = T-1
+ inp = stop_gradient(outputs['feat'][:, :-1]).reshape(B*T, -1)
+ action = data['action'][:, 1:].reshape(B*T, -1)
+ out = stop_gradient(outputs['embed'][:,1:]).reshape(B*T,-1)
+ with common.RequiresGrad(self.disagreement):
+ with torch.cuda.amp.autocast(enabled=self._use_amp):
+ metrics.update(
+ self.update_disagreement(inp, action, out, step))
+ metrics.update(self._acting_behavior.update(
+ self.wm, start, data['is_terminal'], reward_fn=self.compute_intr_reward))
+ else:
+ reward_fn = lambda seq: self.wm.heads['reward'](seq['feat']).mean
+ metrics.update(self._acting_behavior.update(
+ self.wm, start, data['is_terminal'], reward_fn))
+ return state, metrics
\ No newline at end of file
diff --git a/agent/plan2explore.yaml b/agent/plan2explore.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..abf306f183a76bff2e65141af1cf0f4873b505d3
--- /dev/null
+++ b/agent/plan2explore.yaml
@@ -0,0 +1,9 @@
+# @package agent
+_target_: agent.plan2explore.Plan2Explore
+name: plan2explore
+cfg: ???
+obs_space: ???
+act_spec: ???
+grad_heads: [decoder]
+reward_norm: {momentum: 0.95, scale: 1.0, eps: 1e-8}
+actor_ent: 0
\ No newline at end of file
diff --git a/agent/video_utils.py b/agent/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f7b81907c83820484127b381e872271539e1bab
--- /dev/null
+++ b/agent/video_utils.py
@@ -0,0 +1,240 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import agent.dreamer_utils as common
+from collections import defaultdict
+import numpy as np
+
+class ResidualLinear(nn.Module):
+ def __init__(self, in_channels, out_channels, norm='layer', act='SiLU', prenorm=False):
+ super().__init__()
+ self.norm_layer = common.NormLayer(norm, in_channels if prenorm else out_channels)
+ self.act = common.get_act(act)
+ self.layer = nn.Linear(in_channels, out_channels)
+ self.prenorm = prenorm
+ self.res_proj = nn.Identity() if in_channels == out_channels else nn.Linear(in_channels, out_channels)
+
+ def forward(self, x):
+ if self.prenorm:
+ h = self.norm_layer(x)
+ h = self.layer(h)
+ else:
+ h = self.layer(x)
+ h = self.norm_layer(h)
+ h = self.act(h)
+ return h + self.res_proj(x)
+
+class UNetDenoiser(nn.Module):
+ def __init__(self, in_channels : int, mid_channels : int, n_layers : int, norm='layer', act= 'SiLU', ):
+ super().__init__()
+ out_channels = in_channels
+ self.down = nn.ModuleList()
+ for i in range(n_layers):
+ if i == (n_layers - 1):
+ self.down.append(ResidualLinear(in_channels, mid_channels, norm=norm, act=act))
+ else:
+ self.down.append(ResidualLinear(in_channels, in_channels, norm=norm, act=act))
+
+ self.mid = nn.ModuleList()
+ for i in range(n_layers):
+ self.mid.append(ResidualLinear(mid_channels, mid_channels, norm=norm, act=act))
+
+ self.up = nn.ModuleList()
+ for i in range(n_layers):
+ if i == 0:
+ self.up.append(ResidualLinear(mid_channels * 2, out_channels, norm='none', act='Identity'))
+ else:
+ self.up.append(ResidualLinear(out_channels * 2, out_channels, norm=norm, act=act))
+
+ def forward(self, x):
+ down_res = []
+ for down_layer in self.down:
+ x = down_layer(x)
+ down_res.append(x)
+
+ for mid_layer in self.mid:
+ x = mid_layer(x)
+
+ down_res.reverse()
+ for up_layer, res in zip(self.up, down_res):
+ x = up_layer(torch.cat([x, res], dim=-1))
+ return x
+
+
+class VideoSSM(common.EnsembleRSSM):
+ def __init__(self, *args,
+ connector_kl={}, temporal_embeds=False, detached_post=True, n_frames=8,
+ token_dropout=0., loss_scale=1, clip_add_noise=0, clip_lafite_noise=0,
+ rescale_embeds=False, denoising_ae=False, learn_initial=True, **kwargs,):
+ super().__init__(*args, **kwargs)
+ #
+ self.n_frames = n_frames
+ # by default, adding the n_frames in actions (doesn't hurt and easier to test whether it's useful or not)
+ self.viclip_emb_dim = kwargs['action_dim'] - self.n_frames
+ #
+ self.temporal_embeds = temporal_embeds
+ self.detached_post = detached_post
+ self.connector_kl = connector_kl
+ self.token_dropout = token_dropout
+ self.loss_scale = loss_scale
+ self.rescale_embeds = rescale_embeds
+ self.clip_add_noise = clip_add_noise
+ self.clip_lafite_noise = clip_lafite_noise
+ self.clip_const = np.sqrt(self.viclip_emb_dim).item()
+ self.denoising_ae = denoising_ae
+ if self.denoising_ae:
+ self.aligner = UNetDenoiser(self.viclip_emb_dim, self.viclip_emb_dim // 2, n_layers=2, norm='layer', act='SiLU')
+ self.learn_initial = learn_initial
+ if self.learn_initial:
+ self.initial_state_pred = nn.Sequential(
+ nn.Linear(kwargs['action_dim'], kwargs['hidden']),
+ common.NormLayer(kwargs['norm'],kwargs['hidden']), common.get_act('SiLU'),
+ nn.Linear(kwargs['hidden'], kwargs['hidden']),
+ common.NormLayer(kwargs['norm'],kwargs['hidden']), common.get_act('SiLU'),
+ nn.Linear(kwargs['hidden'], kwargs['deter'])
+ )
+ # Deleting non-useful models
+ del self._obs_out
+ del self._obs_dist
+
+ def initial(self, batch_size, init_embed=None, ignore_learned=False):
+ init = super().initial(batch_size)
+ if self.learn_initial and not ignore_learned and hasattr(self, 'initial_state_pred'):
+ assert init_embed is not None
+ # patcher to avoid edge cases
+ if init_embed.shape[-1] == self.viclip_emb_dim:
+ patcher = torch.zeros((*init_embed.shape[:-1], 8), device=self.device)
+ init_embed = torch.cat([init_embed, patcher], dim=-1)
+ init['deter'] = self.initial_state_pred(init_embed)
+ stoch, stats = self.get_stoch_stats_from_deter_state(init)
+ init['stoch'] = stoch
+ init.update(stats)
+ return init
+
+ def get_action(self, video_embed):
+ n_frames = self.n_frames
+ B, T = video_embed.shape[:2]
+
+ if self.rescale_embeds:
+ video_embed = video_embed * self.clip_const
+
+ temporal_embeds = F.one_hot(torch.arange(T).to(video_embed.device) % n_frames, n_frames).reshape(1, T, n_frames,).repeat(B, 1, 1,)
+ if not self.temporal_embeds:
+ temporal_embeds *= 0
+
+ return torch.cat([video_embed, temporal_embeds],dim=-1)
+
+ def update(self, video_embed, wm_post):
+ n_frames = self.n_frames
+ B, T = video_embed.shape[:2]
+ loss = 0
+ metrics = {}
+
+ # NOVEL
+ video_embed = video_embed[:,n_frames-1::n_frames] # tested
+ video_embed = video_embed.to(self.device)
+ video_embed = video_embed.reshape(B, T // n_frames, 1, -1).repeat(1,1, n_frames, 1).reshape(B, T, -1)
+
+ orig_video_embed = video_embed
+
+ if self.clip_add_noise > 0:
+ video_embed = video_embed + torch.randn_like(video_embed, device=video_embed.device) * self.clip_add_noise
+ video_embed = nn.functional.normalize(video_embed, dim=-1)
+ if self.clip_lafite_noise > 0:
+ normed_noise = F.normalize(torch.randn_like(video_embed, device=video_embed.device), dim=-1)
+ video_embed = (1 - self.clip_lafite_noise) * video_embed + self.clip_lafite_noise * normed_noise
+ video_embed = nn.functional.normalize(video_embed, dim=-1)
+
+ if self.denoising_ae:
+ assert (self.clip_lafite_noise + self.clip_add_noise) > 0, "Nothing to denoise"
+ denoised_embed = self.aligner(video_embed)
+ denoised_embed = F.normalize(denoised_embed, dim=-1)
+ denoising_loss = 1 - F.cosine_similarity(denoised_embed, orig_video_embed, dim=-1).mean() # works same as F.mse_loss(denoised_embed, orig_video_embed).mean()
+ loss += denoising_loss
+ metrics['aligner_cosine_distance'] = denoising_loss
+ # if using a denoiser, it's the denoiser's duty to denoise the video embed
+ video_embed = orig_video_embed # could also be denoised_embed for e2e training
+
+ embed_actions = self.get_action(video_embed)
+
+ if self.detached_post:
+ wm_post = { k : v.reshape(B, T, *v.shape[2:]).detach() for k,v in wm_post.items() }
+ else:
+ wm_post = { k : v.reshape(B, T, *v.shape[2:]) for k,v in wm_post.items() }
+
+ # Get prior states
+ prior_states = defaultdict(list)
+ for t in range(T):
+ # Get video action
+ action = embed_actions[:, t]
+
+ if t == 0:
+ prev_state = self.initial(batch_size=wm_post['stoch'].shape[0], init_embed=action)
+ else:
+ # Get deter from prior, get stoch from wm_post
+ prev_state = prior
+ prev_state[self.cell_input] = wm_post[self.cell_input][:, t-1]
+
+ if self.token_dropout > 0:
+ prev_state['stoch'] = torch.einsum('b...,b->b...', prev_state['stoch'], (torch.rand(B, device=action.device) > self.token_dropout).float() )
+
+ prior = self.img_step(prev_state, action)
+ for k in prior:
+ prior_states[k].append(prior[k])
+
+ # Aggregate
+ for k in prior_states:
+ prior_states[k] = torch.stack(prior_states[k], dim=1)
+
+ # Compute loss
+ prior = prior_states
+
+ kl_loss, kl_value = self.kl_loss(wm_post, prior, **self.connector_kl)
+ video_loss = self.loss_scale * kl_loss
+ metrics['connector_kl'] = kl_value.mean()
+ loss += video_loss
+
+ # Compute initial KL
+ video_embed = video_embed.reshape(B, T // n_frames, n_frames, -1)[:,1:,0].reshape(B * (T//n_frames-1), 1, -1) # taking only one (0) and skipping first temporal step
+ embed_actions = self.get_action(video_embed)
+ wm_post = { k : v.reshape(B, T // n_frames, n_frames, *v.shape[2:])[:,1:,0].reshape(B * (T//n_frames-1), *v.shape[2:]) for k,v in wm_post.items() }
+ action = embed_actions[:, 0]
+ prev_state = self.initial(batch_size=wm_post['stoch'].shape[0], init_embed=action)
+ prior = self.img_step(prev_state, action)
+ kl_loss, kl_value = self.kl_loss(wm_post, prior, **self.connector_kl)
+ metrics['connector_initial_kl'] = kl_value.mean()
+
+ return loss, metrics
+
+ def video_imagine(self, video_embed, dreamer_init=None, sample=True, reset_every_n_frames=True, denoise=False):
+ n_frames = self.n_frames
+ B, T = video_embed.shape[:2]
+
+ if self.denoising_ae and denoise:
+ denoised_embed = self.aligner(video_embed)
+ video_embed = F.normalize(denoised_embed, dim=-1)
+
+ action = self.get_action(video_embed)
+ # Imagine
+ init = self.initial(batch_size=B, init_embed=action[:, 0]) # -> this ensures only stoch is used from the current frame
+ if dreamer_init is not None:
+ init[self.cell_input] = dreamer_init[self.cell_input]
+
+ if reset_every_n_frames:
+ prior_states = defaultdict(list)
+ for action_chunk in torch.chunk(action, T // n_frames, dim=1):
+ prior = self.imagine(action_chunk, init, sample=sample)
+ for k in prior:
+ prior_states[k].append(prior[k])
+
+ # -> this ensures only stoch is used from the current frame
+ init = self.initial(batch_size=B, ignore_learned=True)
+ init[self.cell_input] = prior[self.cell_input][:, -1]
+
+ # Agg
+ for k in prior_states:
+ prior_states[k] = torch.cat(prior_states[k], dim=1)
+ prior = prior_states
+ else:
+ prior = self.imagine(action, init, sample=sample)
+ return prior
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f2cf4c8d8726c1105d452457850f0a64f9e3de0
--- /dev/null
+++ b/app.py
@@ -0,0 +1,80 @@
+import os
+import sys
+import gradio as gr
+
+# prototyping
+# from demo_test import Text2Video, Video2Video
+
+from demo.t2v import Text2Video
+
+t2v_examples = [
+ ['walk fast clean',16,],
+ ['run fast clean',16,],
+ ['standing up',16],
+ ['doing the splits',16],
+ ['doing backflips',16],
+ ['a headstand',16],
+ ['karate kick',16],
+ ['crunch abs',16],
+ ['doing push ups',16],
+]
+
+def do_nothing():
+ return
+
+def videocrafter_demo(result_dir='./tmp/'):
+ text2video = Text2Video(result_dir)
+ # video2video = Video2Video(result_dir)
+
+ # tex
+ with gr.Blocks(analytics_enabled=False) as videocrafter_iface:
+ gr.Markdown(" \
+ \
+ [Models] ")
+
+ gr.Markdown(" Notes: ")
+ gr.Markdown(" - Low quality of the videos generated is expected, as the work focuses on visual-language alignment for behavior learning, not on video generation quality. ")
+ gr.Markdown(" - The model is trained on small 64x64 images, and the videos are generated only from a small 512-dimensional embedding. ")
+ gr.Markdown(" - Some prompts require styling instructions, e.g. fast, clean, in order to work well. See some of the examples. ")
+
+ #######t2v#######
+ with gr.Tab(label="Text2Video"):
+ with gr.Column():
+ with gr.Row(): # .style(equal_height=False)
+ with gr.Column():
+ input_text = gr.Text(label='prompt')
+ duration = gr.Slider(minimum=8, maximum=32, elem_id=f"duration", label="duration", value=16, step=8)
+ send_btn = gr.Button("Send")
+ with gr.Column(): # label='result',
+ pass
+ with gr.Column(): # label='result',
+ output_video_1 = gr.Video(autoplay=True, width=256, height=256)
+ with gr.Row():
+ gr.Examples(examples=t2v_examples,
+ inputs=[input_text,duration],
+ outputs=[output_video_1],
+ fn=text2video.get_prompt,
+ cache_examples=False)
+ #cache_examples=os.getenv('SYSTEM') == 'spaces')
+ send_btn.click(
+ fn=text2video.get_prompt,
+ inputs=[input_text,duration],
+ outputs=[output_video_1],
+ )
+ input_text.submit(
+ fn=text2video.get_prompt,
+ inputs=[input_text,duration],
+ outputs=[output_video_1],
+ )
+
+ return videocrafter_iface
+
+if __name__ == "__main__":
+ result_dir = os.path.join('./', 'results')
+ videocrafter_iface = videocrafter_demo(result_dir)
+ videocrafter_iface.queue() # concurrency_count=1, max_size=10
+ videocrafter_iface.launch()
+ # videocrafter_iface.launch(server_name='0.0.0.0', server_port=80)
\ No newline at end of file
diff --git a/assets/GenRL_fig1.png b/assets/GenRL_fig1.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0636d0ab6b97d11ba4c050d2f6d891b27db9c2a
Binary files /dev/null and b/assets/GenRL_fig1.png differ
diff --git a/assets/dashboard.png b/assets/dashboard.png
new file mode 100644
index 0000000000000000000000000000000000000000..95f6af9507f4cb69cba94b8e2a7872a3d32dd85a
Binary files /dev/null and b/assets/dashboard.png differ
diff --git a/assets/video_samples/a_spider_walking_on_the_floor.mp4 b/assets/video_samples/a_spider_walking_on_the_floor.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..065f3f172043e829bb9b44d44b6eab571ec7278a
Binary files /dev/null and b/assets/video_samples/a_spider_walking_on_the_floor.mp4 differ
diff --git a/assets/video_samples/backflip.mp4 b/assets/video_samples/backflip.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8b4bf26f0fc820039fc0e0fc165a4e5da597f695
Binary files /dev/null and b/assets/video_samples/backflip.mp4 differ
diff --git a/assets/video_samples/dancing.mp4 b/assets/video_samples/dancing.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2abc5c1e50cd8fc1bdb245231de1e8f4fa054833
Binary files /dev/null and b/assets/video_samples/dancing.mp4 differ
diff --git a/assets/video_samples/dead_spider_white.gif b/assets/video_samples/dead_spider_white.gif
new file mode 100644
index 0000000000000000000000000000000000000000..0e1d748ff1ca55e2c106ce280f9313fe1b4a904a
Binary files /dev/null and b/assets/video_samples/dead_spider_white.gif differ
diff --git a/assets/video_samples/dog_running_seen_from_the_side.mp4 b/assets/video_samples/dog_running_seen_from_the_side.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..20b735e898c9a6e4b02f2f98a792469a5ed52d3a
Binary files /dev/null and b/assets/video_samples/dog_running_seen_from_the_side.mp4 differ
diff --git a/assets/video_samples/doing_splits.mp4 b/assets/video_samples/doing_splits.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2f2756b29a0065faa7ca3c4c460d47e76e3d1729
Binary files /dev/null and b/assets/video_samples/doing_splits.mp4 differ
diff --git a/assets/video_samples/flex.mp4 b/assets/video_samples/flex.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..bc98c07d88bf15a354b39b4c19e29b440fc4d7ff
Binary files /dev/null and b/assets/video_samples/flex.mp4 differ
diff --git a/assets/video_samples/headstand.mp4 b/assets/video_samples/headstand.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..dbf01583ee275c4f07d6c2a1c8ed59c1536046d0
Binary files /dev/null and b/assets/video_samples/headstand.mp4 differ
diff --git a/assets/video_samples/karate_kick.mp4 b/assets/video_samples/karate_kick.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b3c7c399acbafc88ce4d662c81ee2309c216f7c2
Binary files /dev/null and b/assets/video_samples/karate_kick.mp4 differ
diff --git a/assets/video_samples/lying_down_with_legs_up.mp4 b/assets/video_samples/lying_down_with_legs_up.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7b5ffead69ef85bd83eee6a20f54f7be44f34443
Binary files /dev/null and b/assets/video_samples/lying_down_with_legs_up.mp4 differ
diff --git a/assets/video_samples/person_standing_up_with_hands_up_seen_from_the_side.mp4 b/assets/video_samples/person_standing_up_with_hands_up_seen_from_the_side.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6bcc52764d54321bbff0ea401b91ece2bdd4a4e7
Binary files /dev/null and b/assets/video_samples/person_standing_up_with_hands_up_seen_from_the_side.mp4 differ
diff --git a/assets/video_samples/punching.mp4 b/assets/video_samples/punching.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..eb868f8272892ecb3bef7429ed07c41ec7843e48
Binary files /dev/null and b/assets/video_samples/punching.mp4 differ
diff --git a/collect_data.py b/collect_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..a48f677b6f2dc0adc284fc3093d211830bfe077d
--- /dev/null
+++ b/collect_data.py
@@ -0,0 +1,326 @@
+import warnings
+
+warnings.filterwarnings('ignore', category=DeprecationWarning)
+
+import os
+
+os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
+
+from pathlib import Path
+
+import hydra
+import numpy as np
+import torch
+import wandb
+from dm_env import specs
+
+import tools.utils as utils
+from tools.logger import Logger
+from tools.replay import ReplayBuffer, make_replay_loader
+
+torch.backends.cudnn.benchmark = True
+
+# os.environ['WANDB_API_KEY'] = 'local-1b6c1e2a2fd8d4c98b8c049eb2914dbceccd4b7c' # local-1b6c1e2a2fd8d4c98b8c049eb2914dbceccd4b7c
+# os.environ['WANDB_BASE_URL'] = 'https://192.168.170.90:443'
+# os.environ['REQUESTS_CA_BUNDLE'] = '/etc/ssl/certs/ca-certificates.crt'
+
+def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg):
+ cfg.obs_type = obs_type
+ cfg.obs_shape = obs_spec.shape
+ cfg.action_shape = action_spec.shape
+ cfg.num_expl_steps = num_expl_steps
+ return hydra.utils.instantiate(cfg)
+
+
+def make_dreamer_agent(obs_space, action_spec, cur_config, cfg):
+ from copy import deepcopy
+ cur_config = deepcopy(cur_config)
+ del cur_config.agent
+ return hydra.utils.instantiate(cfg, cfg=cur_config, obs_space=obs_space, act_spec=action_spec)
+
+class Workspace:
+ def __init__(self, cfg, savedir=None, workdir=None):
+ self.workdir = Path.cwd() if workdir is None else workdir
+ print(f'workspace: {self.workdir}')
+ self.cfg = cfg
+
+ utils.set_seed_everywhere(cfg.seed)
+ self.device = torch.device(cfg.device)
+
+ # create logger
+ self.logger = Logger(self.workdir,
+ use_tb=cfg.use_tb,
+ use_wandb=cfg.use_wandb)
+ # create envs
+ self.task = task = cfg.task
+ img_size = cfg.img_size
+
+ import envs.main as envs
+ self.train_env = envs.make(task, cfg.obs_type, cfg.action_repeat, cfg.seed, img_size=img_size, viclip_encode=cfg.viclip_encode, clip_hd_rendering=cfg.clip_hd_rendering)
+
+ # # create agent
+ self.agent = make_dreamer_agent(self.train_env.obs_space, self.train_env.act_space['action'], cfg, cfg.agent)
+
+ # get meta specs
+ meta_specs = self.agent.get_meta_specs()
+ # create replay buffer
+ data_specs = (self.train_env.obs_space,
+ self.train_env.act_space,
+ specs.Array((1,), np.float32, 'reward'),
+ specs.Array((1,), np.float32, 'discount'))
+
+ # create data storage
+ self.replay_storage = ReplayBuffer(data_specs, meta_specs,
+ self.workdir / 'buffer',
+ length=cfg.batch_length, **cfg.replay,
+ device=cfg.device)
+
+ # create replay buffer
+ self.replay_loader = make_replay_loader(self.replay_storage,
+ cfg.batch_size,)
+ self._replay_iter = None
+
+ self.timer = utils.Timer()
+ self._global_step = 0
+ self._global_episode = 0
+
+ @property
+ def global_step(self):
+ return self._global_step
+
+ @property
+ def global_episode(self):
+ return self._global_episode
+
+ @property
+ def global_frame(self):
+ return self.global_step * self.cfg.action_repeat
+
+ @property
+ def replay_iter(self):
+ if self._replay_iter is None:
+ self._replay_iter = iter(self.replay_loader)
+ return self._replay_iter
+
+ def eval(self):
+ import envs.main as envs
+ eval_env = envs.make(self.task, self.cfg.obs_type, self.cfg.action_repeat, self.cfg.seed, img_size=64,)
+ step, episode, total_reward = 0, 0, 0
+ eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
+ meta = self.agent.init_meta()
+ while eval_until_episode(episode):
+ time_step, dreamer_obs = eval_env.reset()
+ agent_state = None
+ while not time_step.last():
+ with torch.no_grad(), utils.eval_mode(self.agent):
+ action, agent_state = self.agent.act(dreamer_obs,
+ meta,
+ self.global_step,
+ eval_mode=True,
+ state=agent_state)
+ time_step, dreamer_obs = eval_env.step(action)
+ total_reward += time_step.reward
+ step += 1
+
+ episode += 1
+
+ with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
+ log('episode_reward', total_reward / episode)
+ log('episode_length', step * self.cfg.action_repeat / episode)
+ log('episode', self.global_episode)
+ log('step', self.global_step)
+
+ def eval_imag_behavior(self,):
+ self.agent._backup_acting_behavior = self.agent._acting_behavior
+ self.agent._acting_behavior = self.agent._imag_behavior
+ self.eval()
+ self.agent._acting_behavior = self.agent._backup_acting_behavior
+
+ def train(self):
+ # predicates
+ train_until_step = utils.Until(self.cfg.num_train_frames, self.cfg.action_repeat)
+ seed_until_step = utils.Until(self.cfg.num_seed_frames, self.cfg.action_repeat)
+ eval_every_step = utils.Every(self.cfg.eval_every_frames, self.cfg.action_repeat)
+ train_every_n_steps = max(self.cfg.train_every_actions // self.cfg.action_repeat, 1)
+ should_train_step = utils.Every(train_every_n_steps * self.cfg.action_repeat, self.cfg.action_repeat)
+ should_log_scalars = utils.Every(self.cfg.log_every_frames, self.cfg.action_repeat)
+ should_log_visual = utils.Every(self.cfg.visual_every_frames, self.cfg.action_repeat)
+ should_save_model = utils.Every(self.cfg.save_every_frames, self.cfg.action_repeat)
+
+ episode_step, episode_reward = 0, 0
+ time_step, dreamer_obs = self.train_env.reset()
+ agent_state = None
+ meta = self.agent.init_meta()
+ data = dreamer_obs
+ self.replay_storage.add(data, meta)
+ metrics = None
+ while train_until_step(self.global_step):
+ if time_step.last():
+ self._global_episode += 1
+ # wait until all the metrics schema is populated
+ if metrics is not None:
+ # log stats
+ elapsed_time, total_time = self.timer.reset()
+ episode_frame = episode_step * self.cfg.action_repeat
+ with self.logger.log_and_dump_ctx(self.global_frame,
+ ty='train') as log:
+ log('fps', episode_frame / elapsed_time)
+ log('total_time', total_time)
+ log('episode_reward', episode_reward)
+ log('episode_length', episode_frame)
+ log('episode', self.global_episode)
+ log('buffer_size', len(self.replay_storage))
+ log('step', self.global_step)
+ if should_save_model(self.global_step):
+ # save last model
+ self.save_last_model()
+
+ # reset env
+ time_step, dreamer_obs = self.train_env.reset()
+ # Updating agent
+ agent_state = None # Resetting agent's latent state
+ meta = self.agent.init_meta()
+ data = dreamer_obs
+ self.replay_storage.add(data, meta)
+ episode_step = 0
+ episode_reward = 0
+
+ # try to evaluate
+ if eval_every_step(self.global_step):
+ if self.cfg.eval_modality == 'task':
+ self.eval()
+ if self.cfg.eval_modality == 'task_imag':
+ self.eval_imag_behavior()
+ if self.cfg.eval_modality == 'from_text':
+ self.logger.log('eval_total_time', self.timer.total_time(),
+ self.global_frame)
+ self.eval_from_text()
+
+ meta = self.agent.update_meta(meta, self.global_step, time_step)
+ # sample action
+ with torch.no_grad(), utils.eval_mode(self.agent):
+ if seed_until_step(self.global_step):
+ action = self.train_env.act_space['action'].sample()
+ if getattr(self.cfg, 'discrete_actions', False):
+ action = (action == np.max(action)).astype(np.float32) # one-hot
+ else:
+ action, agent_state = self.agent.act(dreamer_obs, # time_step.observation
+ meta,
+ self.global_step,
+ eval_mode=False,
+ state=agent_state)
+
+ # try to update the agent
+ if not seed_until_step(self.global_step):
+ if should_train_step(self.global_step):
+ # prof.step()
+ # Sampling data
+ batch_data = next(self.replay_iter)
+ if hasattr(self.agent, ' update_wm'):
+ state, outputs, metrics = self.agent.update_wm(batch_data, self.global_step)
+ if hasattr(self.agent, "update_acting_behavior"):
+ metrics = self.agent.update_acting_behavior(state=state, outputs=outputs, metrics=metrics, data=batch_data)[1]
+ if hasattr(self.agent, "update_imag_behavior"):
+ metrics.update(self.agent.update_imag_behavior(state=state, outputs=outputs, metrics=metrics, seq_data=batch_data,)[1])
+ else:
+ outputs, metrics = self.agent.update(batch_data, self.global_step)
+
+ if should_log_scalars(self.global_step):
+ self.logger.log_metrics(metrics, self.global_frame, ty='train')
+ if self.global_step > 0 and should_log_visual(self.global_step):
+ if hasattr(self.agent, 'report'):
+ with torch.no_grad(), utils.eval_mode(self.agent):
+ videos = self.agent.report(next(self.replay_iter))
+ self.logger.log_visual(videos, self.global_frame)
+
+ # take env step
+ time_step, dreamer_obs = self.train_env.step(action)
+ episode_reward += time_step.reward
+ data = dreamer_obs
+ if time_step.last():
+ if getattr(self.train_env, "accumulate", False):
+ assert not self.replay_storage._ongoing
+ # NOTE: this is ok as it comes right after adding to the repl
+ accumulated_data, accumulated_key = self.train_env.process_accumulate()
+ data[accumulated_key] = accumulated_data[-1]
+ self.replay_storage._ongoing_eps[0][accumulated_key][-len(accumulated_data[:-1]):] = accumulated_data[:-1]
+ self.replay_storage.add(data, meta)
+ episode_step += 1
+ self._global_step += 1
+
+ @utils.retry
+ def save_snapshot(self):
+ snapshot = self.get_snapshot_dir() / f'snapshot_{self.global_frame}.pt'
+ keys_to_save = ['agent', '_global_step', '_global_episode']
+ payload = {k: self.__dict__[k] for k in keys_to_save}
+ with snapshot.open('wb') as f:
+ torch.save(payload, f)
+
+ def setup_wandb(self):
+ cfg = self.cfg
+ exp_name = '_'.join([
+ cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type,
+ str(cfg.seed)
+ ])
+ wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name)
+ flat_cfg = utils.flatten_dict(cfg)
+ wandb.config.update(flat_cfg)
+ self.wandb_run_id = wandb.run.id
+
+ @utils.retry
+ def save_last_model(self):
+ snapshot = self.root_dir / 'last_snapshot.pt'
+ if snapshot.is_file():
+ temp = Path(str(snapshot).replace("last_snapshot.pt", "second_last_snapshot.pt"))
+ os.replace(snapshot, temp)
+ keys_to_save = ['agent', '_global_step', '_global_episode']
+ if self.cfg.use_wandb:
+ keys_to_save.append('wandb_run_id')
+ payload = {k: self.__dict__[k] for k in keys_to_save}
+ with snapshot.open('wb') as f:
+ torch.save(payload, f)
+
+ def load_snapshot(self, snapshot_dir):
+ try:
+ snapshot = snapshot_dir / 'last_snapshot.pt'
+ with snapshot.open('rb') as f:
+ payload = torch.load(f)
+ except:
+ snapshot = snapshot_dir / 'second_last_snapshot.pt'
+ with snapshot.open('rb') as f:
+ payload = torch.load(f)
+ for k,v in payload.items():
+ setattr(self, k, v)
+ if k == 'wandb_run_id':
+ assert wandb.run is None
+ cfg = self.cfg
+ exp_name = '_'.join([
+ cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type,
+ str(cfg.seed)
+ ])
+ wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name, id=v, resume="must")
+
+ def get_snapshot_dir(self):
+ snap_dir = self.cfg.snapshot_dir
+ snapshot_dir = self.workdir / Path(snap_dir)
+ snapshot_dir.mkdir(exist_ok=True, parents=True)
+ return snapshot_dir
+
+@hydra.main(config_path='.', config_name='collect_data')
+def main(cfg):
+ from collect_data import Workspace as W
+ root_dir = Path.cwd()
+ cfg.workdir = str(root_dir)
+ workspace = W(cfg)
+ workspace.root_dir = root_dir
+ snapshot = workspace.root_dir / 'last_snapshot.pt'
+ if snapshot.exists():
+ print(f'resuming: {snapshot}')
+ workspace.load_snapshot(workspace.root_dir)
+ if cfg.use_wandb and wandb.run is None:
+ # otherwise it was resumed
+ workspace.setup_wandb()
+ workspace.train()
+
+if __name__ == '__main__':
+ main()
diff --git a/collect_data.yaml b/collect_data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..647ee9cd762af31da4ea387e394f229407fdcc52
--- /dev/null
+++ b/collect_data.yaml
@@ -0,0 +1,54 @@
+defaults:
+ - agent: dreamer
+ - conf/env: dmc_pixels
+ - conf/defaults: dreamer_v3
+ - override hydra/launcher: submitit_local
+
+# mode
+label: default
+# task settings
+task: stickman_walk
+# train settings
+num_train_frames: 2000010
+num_seed_frames: 4000
+# eval
+eval_every_frames: 100000
+eval_modality: null
+num_eval_episodes: 3
+# snapshot
+snapshot_dir: ../../../trained_models/${obs_type}/${task}/${agent.name}/${seed}
+save_every_frames: 10_000
+# misc
+seed: 1
+device: cuda:0
+use_tb: true
+use_wandb: true
+
+# Clip stuff
+viclip_encode: false
+viclip_model: internvideo2
+clip_hd_rendering: false
+
+# experiment
+experiment: data
+project_name: genrl
+
+# log settings
+log_every_frames: 1000
+visual_every_frames: 100000000 # edit for debug
+workdir: ???
+
+hydra:
+ run:
+ dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${agent.name}
+ sweep:
+ dir: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}
+ subdir: ${hydra.job.num}
+ launcher:
+ timeout_min: 4300
+ cpus_per_task: 10
+ gpus_per_node: 1
+ tasks_per_node: 1
+ mem_gb: 160
+ nodes: 1
+ submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}/.slurm
diff --git a/conf/defaults/dreamer_v2.yaml b/conf/defaults/dreamer_v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bddff4df7da506934fac3b67891263a87234a799
--- /dev/null
+++ b/conf/defaults/dreamer_v2.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+img_size: 64
+
+# Dreamer defaults
+rssm: {ensemble: 1, hidden: 512, deter: 512, stoch: 32, discrete: 32, norm: none, std_act: softplus, min_std: 0.1, single_obs_posterior: false, } # act: elu,
+discount_head: {layers: 4, units: 512, norm: none, dist: binary} # act: elu
+reward_head: {layers: 4, units: 512, norm: none, dist: mse} # act: elu
+kl: {free: 1.0, forward: False, balance: 0.8, free_avg: False, }
+loss_scales: {kl: 1.0, reward: 1.0, discount: 1.0, proprio: 1.0}
+model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 1000, wd: 1e-6}
+replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: False}
+decoder_inputs: feat
+image_dist: normal_unit_std
+
+actor: {layers: 4, units: 512, norm: none, dist: trunc_normal, min_std: 0.1 } # act: elu
+critic: {layers: 4, units: 512, norm: none, dist: mse} # act: elu,
+actor_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6}
+critic_opt: {opt: adam, lr: 8e-5, eps: 1e-5, clip: 100, wd: 1e-6}
+discount: 0.99
+discount_lambda: 0.95
+slow_target: True
+slow_target_update: 100
+slow_target_fraction: 1
+slow_baseline: True
+reward_ema: False
+
+acting_reward_fn: env_reward
+clip_rewards: identity
+
+batch_size: 50
+batch_length: 50
+imag_horizon: 15
+eval_state_mean: False
+
+precision: 16
+train_every_actions: 10
+only_random_actions: False
+#
\ No newline at end of file
diff --git a/conf/defaults/dreamer_v3.yaml b/conf/defaults/dreamer_v3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dbd60b085ed8f2faa255f001acff2800ed7deee9
--- /dev/null
+++ b/conf/defaults/dreamer_v3.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+img_size: 64
+
+# Dreamer defaults
+rssm: {ensemble: 1, hidden: 512, deter: 512, stoch: 32, discrete: 32, norm: layer, std_act: softplus, min_std: 0.1, single_obs_posterior: false, } # act: elu,
+discount_head: {layers: 4, units: 512, norm: layer, dist: binary} # act: elu
+reward_head: {layers: 4, units: 512, norm: layer, dist: twohot} # act: elu
+kl: { free: 1.0, forward: False, balance: 0.85, free_avg: False,}
+loss_scales: {kl: 0.6, reward: 1.0, discount: 1.0, proprio: 1.0}
+model_opt: {opt: adam, lr: 1e-4, eps: 1e-8, clip: 1000, wd: 1e-6}
+replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: False}
+decoder_inputs: feat
+image_dist: mse
+# Actor Critic
+actor: {layers: 4, units: 512, norm: layer, dist: normal, min_std: 0.1 } # act: elu
+critic: {layers: 4, units: 512, norm: layer, dist: twohot } # act: elu,
+actor_opt: {opt: adam, lr: 3e-5, eps: 1e-5, clip: 100, wd: 1e-6}
+critic_opt: {opt: adam, lr: 3e-5, eps: 1e-5, clip: 100, wd: 1e-6}
+discount: 0.99
+discount_lambda: 0.95
+slow_target: True
+slow_target_update: 100
+slow_target_fraction: 1
+slow_baseline: True
+reward_ema: True
+
+acting_reward_fn: env_reward
+clip_rewards: identity
+
+batch_size: 50
+batch_length: 50
+imag_horizon: 15
+eval_state_mean: False
+
+precision: 16
+train_every_actions: 10
+only_random_actions: False
+#
\ No newline at end of file
diff --git a/conf/defaults/genrl.yaml b/conf/defaults/genrl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2000579983e3797f044dbae73ddf98c55ee86e87
--- /dev/null
+++ b/conf/defaults/genrl.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+img_size: 64
+
+# Dreamer defaults
+rssm: {ensemble: 1, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, norm: layer, std_act: softplus, min_std: 0.1, single_obs_posterior: true, } # act: elu,
+discount_head: {layers: 4, units: 512, norm: none, dist: binary} # act: elu
+reward_head: {layers: 4, units: 1024, norm: layer, dist: twohot} # act: elu
+kl: { free: 1.0, forward: False, balance: 0.85, free_avg: False, }
+loss_scales: {kl: 0.6, reward: 1.0, discount: 1.0, proprio: 1.0}
+model_opt: {opt: adam, lr: 1e-4, eps: 1e-8, clip: 1000, wd: 1e-6}
+replay: {capacity: 20e6, ongoing: False, minlen: 48, maxlen: 48, prioritize_ends: False}
+decoder_inputs: stoch
+image_dist: mse
+# Actor Critic
+actor: {layers: 4, units: 1024, norm: layer, dist: normal, min_std: 0.1 } # act: elu
+critic: {layers: 4, units: 1024, norm: layer, dist: twohot } # act: elu,
+actor_opt: {opt: adam, lr: 3e-5, eps: 1e-5, clip: 100, wd: 1e-6}
+critic_opt: {opt: adam, lr: 3e-5, eps: 1e-5, clip: 100, wd: 1e-6}
+discount: 0.99
+discount_lambda: 0.95
+slow_target: True
+slow_target_update: 100
+slow_target_fraction: 1
+slow_baseline: True
+reward_ema: True
+
+acting_reward_fn: env_reward
+clip_rewards: identity
+
+batch_size: 48
+batch_length: 48
+imag_horizon: 16
+eval_state_mean: False
+
+precision: 16
+train_every_actions: 10
+only_random_actions: False
\ No newline at end of file
diff --git a/conf/env/dmc_pixels.yaml b/conf/env/dmc_pixels.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c578a38c790d331517b2c9d8d50ed3c35dbade3a
--- /dev/null
+++ b/conf/env/dmc_pixels.yaml
@@ -0,0 +1,8 @@
+# @package _global_
+obs_type: pixels
+action_repeat: 2
+encoder: {mlp_keys: '$^', cnn_keys: 'observation', norm: layer, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]} # act: elu
+decoder: {mlp_keys: '$^', cnn_keys: 'observation', norm: layer, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400], } # act: elu
+pred_discount: False
+imag_actor_grad: dynamics
+actor_grad: dynamics
\ No newline at end of file
diff --git a/conf/train_mode/train_behavior.yaml b/conf/train_mode/train_behavior.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5dfc3f5866d435c7326c66cf0e5a1c604bd5aaa3
--- /dev/null
+++ b/conf/train_mode/train_behavior.yaml
@@ -0,0 +1,5 @@
+num_train_frames: 500_010
+batch_size: 32
+batch_length: 32
+agent.imag_reward_fn: video_text_reward
+eval_modality: task_imag
\ No newline at end of file
diff --git a/conf/train_mode/train_model.yaml b/conf/train_mode/train_model.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8c795e5c4a30c342215d8df9a7ca2bdfaf045943
--- /dev/null
+++ b/conf/train_mode/train_model.yaml
@@ -0,0 +1,6 @@
+num_train_frames: 5_000_010
+visual_every_frames: 250_000
+train_world_model: True
+train_connector: True
+reset_world_model: True
+reset_connector: True
\ No newline at end of file
diff --git a/demo/demo_test.py b/demo/demo_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cf875834d991a8715e5af65d70fd2242659a093
--- /dev/null
+++ b/demo/demo_test.py
@@ -0,0 +1,23 @@
+import os
+from pathlib import Path
+VIDEO_PATH = Path(os.path.abspath('')) / 'assets' / 'video_samples'
+
+class Text2Video():
+ def __init__(self, result_dir='./tmp/') -> None:
+ pass
+
+ def get_prompt(self, input_text, steps=50, cfg_scale=15.0, eta=1.0, fps=16):
+
+ return str(VIDEO_PATH / 'headstand.mp4')
+
+class Video2Video:
+ def __init__(self, result_dir='./tmp/') -> None:
+ pass
+
+ def get_image(self, input_image, input_prompt, i2v_steps=50, i2v_cfg_scale=15.0, i2v_eta=1.0, i2v_fps=16):
+
+ return str(VIDEO_PATH / 'dancing.mp4')
+
+if __name__ == '__main__':
+ t2v = Text2Video()
+ print(t2v.get_prompt('test'))
\ No newline at end of file
diff --git a/demo/t2v.py b/demo/t2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..34e64168ef918f953e3effc89449a824ad6546bc
--- /dev/null
+++ b/demo/t2v.py
@@ -0,0 +1,115 @@
+from pathlib import Path
+import os
+import sys
+sys.path.append(str(Path(os.path.abspath(''))))
+
+import torch
+import numpy as np
+from tools.genrl_utils import ViCLIPGlobalInstance
+
+import time
+import torchvision
+from huggingface_hub import hf_hub_download
+
+def save_videos(batch_tensors, savedir, filenames, fps=10):
+ # b,samples,c,t,h,w
+ n_samples = batch_tensors.shape[1]
+ for idx, vid_tensor in enumerate(batch_tensors):
+ video = vid_tensor.detach().cpu()
+ video = torch.clamp(video.float(), 0., 1.)
+ video = video.permute(1, 0, 2, 3, 4) # t,n,c,h,w
+ frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w]
+ grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w]
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
+ torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
+
+class Text2Video():
+ def __init__(self,result_dir='./tmp/',gpu_num=1) -> None:
+ model_folder = str(Path(os.path.abspath('')) / 'models')
+ model_filename = 'genrl_stickman_500k_2.pt'
+
+ if not os.path.isfile(os.path.join(model_folder, model_filename)):
+ self.download_model(model_folder, model_filename)
+ if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
+ self.download_internvideo2(model_folder)
+ self.agent = torch.load(os.path.join(model_folder, model_filename))
+ model_name = 'internvideo2'
+
+ # Get ViCLIP
+ viclip_global_instance = ViCLIPGlobalInstance(model_name)
+ if not viclip_global_instance._instantiated:
+ print("Instantiating InternVideo2")
+ viclip_global_instance.instantiate()
+ self.clip = viclip_global_instance.viclip
+ self.tokenizer = viclip_global_instance.viclip_tokenizer
+
+ self.result_dir = result_dir
+ if not os.path.exists(self.result_dir):
+ os.mkdir(self.result_dir)
+
+ def get_prompt(self, prompt, duration):
+ torch.cuda.empty_cache()
+ print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
+ start = time.time()
+
+ prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
+ prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
+
+ labels_list = [prompt_str]
+ with torch.no_grad():
+ wm = world_model = self.agent.wm
+ connector = self.agent.wm.connector
+ decoder = world_model.heads['decoder']
+ n_frames = connector.n_frames
+
+ # Get text(video) embed
+ text_feat = []
+ for text in labels_list:
+ with torch.no_grad():
+ text_feat.append(self.clip.get_txt_feat(text,))
+ text_feat = torch.stack(text_feat, dim=0).to(self.clip.device)
+
+ video_embed = text_feat
+
+ B = video_embed.shape[0]
+ T = 1
+
+ # Get actions
+ video_embed = video_embed.repeat(1, duration, 1)
+ with torch.no_grad():
+ # Imagine
+ prior = wm.connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=True)
+ # Decode
+ prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5
+
+ save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
+ print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
+ return os.path.join(self.result_dir, f"{prompt_str}.mp4")
+
+ def download_model(self, model_folder, model_filename):
+ REPO_ID = 'mazpie/genrl_models'
+ filename_list = [model_filename]
+ if not os.path.exists(model_folder):
+ os.makedirs(model_folder)
+ for filename in filename_list:
+ local_file = os.path.join(model_folder, filename)
+
+ if not os.path.exists(local_file):
+ hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=model_folder, local_dir_use_symlinks=False)
+
+ def download_internvideo2(self, model_folder):
+ REPO_ID = 'OpenGVLab/InternVideo2-Stage2_1B-224p-f4'
+ filename_list = ['InternVideo2-stage2_1b-224p-f4.pt']
+ if not os.path.exists(model_folder):
+ os.makedirs(model_folder)
+ for filename in filename_list:
+ local_file = os.path.join(model_folder, filename)
+
+ if not os.path.exists(local_file):
+ hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=model_folder, local_dir_use_symlinks=False)
+
+if __name__ == '__main__':
+ t2v = Text2Video()
+ video_path = t2v.get_prompt('a black swan swims on the pond', 8)
+ print('done', video_path)
\ No newline at end of file
diff --git a/envs/__init__.py b/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/envs/custom_dmc_tasks/__init__.py b/envs/custom_dmc_tasks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a47945d5fe196099562ea67367c649e82af2b3f7
--- /dev/null
+++ b/envs/custom_dmc_tasks/__init__.py
@@ -0,0 +1,13 @@
+from . import cheetah
+from . import walker
+from . import quadruped
+from . import jaco
+from . import stickman
+from dm_control import suite
+
+suite._DOMAINS['stickman'] = stickman
+suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
+suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
+
+def make_jaco(task, obs_type, seed, img_size, ):
+ return jaco.make(task, obs_type, seed, img_size, )
\ No newline at end of file
diff --git a/envs/custom_dmc_tasks/cheetah.py b/envs/custom_dmc_tasks/cheetah.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb2110ab050fb8b00a8ba8185eb159c1b80e5e44
--- /dev/null
+++ b/envs/custom_dmc_tasks/cheetah.py
@@ -0,0 +1,247 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Cheetah Domain."""
+
+import collections
+import os
+
+from dm_control.suite import cheetah
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import io as resources
+
+# How long the simulation will run, in seconds.
+_DEFAULT_TIME_LIMIT = 10
+
+_DOWN_HEIGHT = 0.15
+_HIGH_HEIGHT = 1.00
+_MID_HEIGHT = 0.45
+
+
+# Running speed above which reward is 1.
+_RUN_SPEED = 10
+_SPIN_SPEED = 5
+
+def make(task,
+ task_kwargs=None,
+ environment_kwargs=None,
+ visualize_reward=False):
+ task_kwargs = task_kwargs or {}
+ if environment_kwargs is not None:
+ task_kwargs = task_kwargs.copy()
+ task_kwargs['environment_kwargs'] = environment_kwargs
+ env = SUITE[task](**task_kwargs)
+ env.task.visualize_reward = visualize_reward
+ return env
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ root_dir = os.path.dirname(os.path.dirname(__file__))
+ xml = resources.GetResource(
+ os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml'))
+ return xml, common.ASSETS
+
+
+@cheetah.SUITE.add('custom')
+def flipping(time_limit=_DEFAULT_TIME_LIMIT,
+ random=None,
+ environment_kwargs=None):
+ """Returns the run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Cheetah(forward=False, flip=False, random=random, goal='flipping')
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics,
+ task,
+ time_limit=time_limit,
+ **environment_kwargs)
+
+@cheetah.SUITE.add('custom')
+def standing(time_limit=_DEFAULT_TIME_LIMIT,
+ random=None,
+ environment_kwargs=None):
+ """Returns the run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Cheetah(forward=False, flip=False, random=random, goal='standing')
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics,
+ task,
+ time_limit=time_limit,
+ **environment_kwargs)
+
+
+@cheetah.SUITE.add('custom')
+def lying_down(time_limit=_DEFAULT_TIME_LIMIT,
+ random=None,
+ environment_kwargs=None):
+ """Returns the run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Cheetah(forward=False, flip=False, random=random, goal='lying_down')
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics,
+ task,
+ time_limit=time_limit,
+ **environment_kwargs)
+
+
+@cheetah.SUITE.add('custom')
+def run_backward(time_limit=_DEFAULT_TIME_LIMIT,
+ random=None,
+ environment_kwargs=None):
+ """Returns the run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Cheetah(forward=False, flip=False, random=random, goal='run_backward')
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics,
+ task,
+ time_limit=time_limit,
+ **environment_kwargs)
+
+
+@cheetah.SUITE.add('custom')
+def flip(time_limit=_DEFAULT_TIME_LIMIT,
+ random=None,
+ environment_kwargs=None):
+ """Returns the run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Cheetah(forward=True, flip=True, random=random, goal='flip')
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics,
+ task,
+ time_limit=time_limit,
+ **environment_kwargs)
+
+
+@cheetah.SUITE.add('custom')
+def flip_backward(time_limit=_DEFAULT_TIME_LIMIT,
+ random=None,
+ environment_kwargs=None):
+ """Returns the run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Cheetah(forward=False, flip=True, random=random, goal='flip_backward')
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics,
+ task,
+ time_limit=time_limit,
+ **environment_kwargs)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Cheetah domain."""
+ def speed(self):
+ """Returns the horizontal speed of the Cheetah."""
+ return self.named.data.sensordata['torso_subtreelinvel'][0]
+
+ def angmomentum(self):
+ """Returns the angular momentum of torso of the Cheetah about Y axis."""
+ return self.named.data.subtree_angmom['torso'][1]
+
+
+class Cheetah(base.Task):
+ """A `Task` to train a running Cheetah."""
+ def __init__(self, goal=None, forward=True, flip=False, random=None):
+ self._forward = 1 if forward else -1
+ self._flip = flip
+ self._goal = goal
+ super(Cheetah, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ # The indexing below assumes that all joints have a single DOF.
+ assert physics.model.nq == physics.model.njnt
+ is_limited = physics.model.jnt_limited == 1
+ lower, upper = physics.model.jnt_range[is_limited].T
+ physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
+
+ # Stabilize the model before the actual simulation.
+ for _ in range(200):
+ physics.step()
+
+ physics.data.time = 0
+ self._timeout_progress = 0
+ super().initialize_episode(physics)
+
+ def _get_lying_down_reward(self, physics):
+ torso = physics.named.data.xpos['torso', 'z']
+
+ torso_down = rewards.tolerance(torso,
+ bounds=(-float('inf'), _DOWN_HEIGHT),
+ margin=_DOWN_HEIGHT * 1.5,)
+
+ feet = physics.named.data.xpos['bfoot', 'z'] + physics.named.data.xpos['ffoot', 'z']
+
+ feet_up = rewards.tolerance(feet,
+ bounds=(_MID_HEIGHT, float('inf')),
+ margin=_MID_HEIGHT / 2,)
+ return (torso_down + feet_up) / 2
+
+ def _get_standing_reward(self, physics):
+ bfoot = physics.named.data.xpos['bfoot', 'z']
+ ffoot = physics.named.data.xpos['ffoot', 'z']
+ max_foot = bfoot if bfoot > ffoot else ffoot
+ min_foot = bfoot if bfoot <= ffoot else ffoot
+
+ low_foot_low = rewards.tolerance(min_foot,
+ bounds=(-float('inf'), _DOWN_HEIGHT),
+ margin=_DOWN_HEIGHT * 1.5,)
+
+ high_foot_high = rewards.tolerance(max_foot,
+ bounds=(_HIGH_HEIGHT, float('inf')),
+ margin=_HIGH_HEIGHT / 2,)
+ return high_foot_high * low_foot_low
+
+ def _get_flip_reward(self, physics):
+ return rewards.tolerance(self._forward * physics.angmomentum(),
+ bounds=(_SPIN_SPEED, float('inf')),
+ margin=_SPIN_SPEED,
+ value_at_margin=0,
+ sigmoid='linear')
+
+ def get_observation(self, physics):
+ """Returns an observation of the state, ignoring horizontal position."""
+ obs = collections.OrderedDict()
+ # Ignores horizontal position to maintain translational invariance.
+ obs['position'] = physics.data.qpos[1:].copy()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ if self._goal in ['run', 'flip', 'run_backward', 'flip_backward']:
+ if self._flip:
+ return self._get_flip_reward(physics)
+ else:
+ reward = rewards.tolerance(self._forward * physics.speed(),
+ bounds=(_RUN_SPEED, float('inf')),
+ margin=_RUN_SPEED,
+ value_at_margin=0,
+ sigmoid='linear')
+ return reward
+ elif self._goal == 'lying_down':
+ return self._get_lying_down_reward(physics)
+ elif self._goal == 'flipping':
+ self._forward = True
+ fwd_reward = self._get_flip_reward(physics)
+ self._forward = False
+ back_reward = self._get_flip_reward(physics)
+ return max(fwd_reward, back_reward)
+ elif self._goal == 'standing':
+ return self._get_standing_reward(physics)
+ else:
+ raise NotImplementedError(self._goal)
diff --git a/envs/custom_dmc_tasks/cheetah.xml b/envs/custom_dmc_tasks/cheetah.xml
new file mode 100644
index 0000000000000000000000000000000000000000..e55c95a2621e412a3688c3244c109d6cd6453044
--- /dev/null
+++ b/envs/custom_dmc_tasks/cheetah.xml
@@ -0,0 +1,74 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/envs/custom_dmc_tasks/jaco.py b/envs/custom_dmc_tasks/jaco.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c162d05c5dba5620de3a2f51c5ac4362456e495
--- /dev/null
+++ b/envs/custom_dmc_tasks/jaco.py
@@ -0,0 +1,222 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A task where the goal is to move the hand close to a target prop or site."""
+
+import collections
+
+from dm_control import composer
+from dm_control.composer import initializers
+from dm_control.composer.observation import observable
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+from dm_control.manipulation.shared import arenas
+from dm_control.manipulation.shared import cameras
+from dm_control.manipulation.shared import constants
+from dm_control.manipulation.shared import observations
+from dm_control.manipulation.shared import registry
+from dm_control.manipulation.shared import robots
+from dm_control.manipulation.shared import tags
+from dm_control.manipulation.shared import workspaces
+from dm_control.utils import rewards
+import numpy as np
+
+
+_ReachWorkspace = collections.namedtuple(
+ '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])
+
+# Ensures that the props are not touching the table before settling.
+_PROP_Z_OFFSET = 0.001
+
+_DUPLO_WORKSPACE = _ReachWorkspace(
+ target_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, _PROP_Z_OFFSET),
+ upper=(0.1, 0.1, _PROP_Z_OFFSET)),
+ tcp_bbox=workspaces.BoundingBox(
+ lower=(-0.1, -0.1, 0.2),
+ upper=(0.1, 0.1, 0.4)),
+ arm_offset=robots.ARM_OFFSET)
+
+_SITE_WORKSPACE = _ReachWorkspace(
+ target_bbox=workspaces.BoundingBox(
+ lower=(-0.2, -0.2, 0.02),
+ upper=(0.2, 0.2, 0.4)),
+ tcp_bbox=workspaces.BoundingBox(
+ lower=(-0.2, -0.2, 0.02),
+ upper=(0.2, 0.2, 0.4)),
+ arm_offset=robots.ARM_OFFSET)
+
+_TARGET_RADIUS = 0.05
+_TIME_LIMIT = 10
+
+TASKS = {
+ 'reach_top_left': workspaces.BoundingBox(
+ lower=(-0.09, 0.09, _PROP_Z_OFFSET),
+ upper=(-0.09, 0.09, _PROP_Z_OFFSET)),
+ 'reach_top_right': workspaces.BoundingBox(
+ lower=(0.09, 0.09, _PROP_Z_OFFSET),
+ upper=(0.09, 0.09, _PROP_Z_OFFSET)),
+ 'reach_bottom_left': workspaces.BoundingBox(
+ lower=(-0.09, -0.09, _PROP_Z_OFFSET),
+ upper=(-0.09, -0.09, _PROP_Z_OFFSET)),
+ 'reach_bottom_right': workspaces.BoundingBox(
+ lower=(0.09, -0.09, _PROP_Z_OFFSET),
+ upper=(0.09, -0.09, _PROP_Z_OFFSET)),
+}
+
+
+def make(task_id, obs_type, seed, img_size=64,):
+ obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
+ obs_settings = obs_settings._replace(camera=obs_settings[-1]._replace(width=img_size))
+ obs_settings = obs_settings._replace(camera=obs_settings[-1]._replace(height=img_size))
+ if obs_type == 'states':
+ global _TIME_LIMIT
+ _TIME_LIMIT = 10.04
+ # Note: Adding this fixes the problem of having 249 steps with action repeat = 1
+ task = _reach(task_id, obs_settings=obs_settings, use_site=False)
+ return composer.Environment(task, time_limit=_TIME_LIMIT, random_state=seed)
+
+
+class MTReach(composer.Task):
+ """Bring the hand close to a target prop or site."""
+
+ def __init__(
+ self, task_id, arena, arm, hand, prop, obs_settings, workspace, control_timestep):
+ """Initializes a new `Reach` task.
+
+ Args:
+ arena: `composer.Entity` instance.
+ arm: `robot_base.RobotArm` instance.
+ hand: `robot_base.RobotHand` instance.
+ prop: `composer.Entity` instance specifying the prop to reach to, or None
+ in which case the target is a fixed site whose position is specified by
+ the workspace.
+ obs_settings: `observations.ObservationSettings` instance.
+ workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
+ control_timestep: Float specifying the control timestep in seconds.
+ """
+ self._task_id = task_id
+ self._arena = arena
+ self._arm = arm
+ self._hand = hand
+ self._arm.attach(self._hand)
+ self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
+ self.control_timestep = control_timestep
+ self._tcp_initializer = initializers.ToolCenterPointInitializer(
+ self._hand, self._arm,
+ position=distributions.Uniform(*workspace.tcp_bbox),
+ quaternion=workspaces.DOWN_QUATERNION)
+
+ # Add custom camera observable.
+ self._task_observables = cameras.add_camera_observables(
+ arena, obs_settings, cameras.FRONT_CLOSE)
+
+ target_pos_distribution = distributions.Uniform(*TASKS[task_id])
+ self._prop = prop
+ if prop:
+ # The prop itself is used to visualize the target location.
+ self._make_target_site(parent_entity=prop, visible=False)
+ self._target = self._arena.add_free_entity(prop)
+ self._prop_placer = initializers.PropPlacer(
+ props=[prop],
+ position=target_pos_distribution,
+ quaternion=workspaces.uniform_z_rotation,
+ settle_physics=True)
+ else:
+ self._target = self._make_target_site(parent_entity=arena, visible=True)
+ self._target_placer = target_pos_distribution
+
+ # Commented to match EXORL
+ # obs = observable.MJCFFeature('pos', self._target)
+ # obs.configure(**obs_settings.prop_pose._asdict())
+ # self._task_observables['target_position'] = obs
+
+ # Add sites for visualizing the prop and target bounding boxes.
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.tcp_bbox.lower, upper=workspace.tcp_bbox.upper,
+ rgba=constants.GREEN, name='tcp_spawn_area')
+ workspaces.add_bbox_site(
+ body=self.root_entity.mjcf_model.worldbody,
+ lower=workspace.target_bbox.lower, upper=workspace.target_bbox.upper,
+ rgba=constants.BLUE, name='target_spawn_area')
+
+ def _make_target_site(self, parent_entity, visible):
+ return workspaces.add_target_site(
+ body=parent_entity.mjcf_model.worldbody,
+ radius=_TARGET_RADIUS, visible=visible,
+ rgba=constants.RED, name='target_site')
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ @property
+ def arm(self):
+ return self._arm
+
+ @property
+ def hand(self):
+ return self._hand
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ def get_reward(self, physics):
+ hand_pos = physics.bind(self._hand.tool_center_point).xpos
+ target_pos = physics.bind(self._target).xpos
+ # This was used exceptionally for the PT reward predictor experiments
+ # target_pos = distributions.Uniform(*TASKS[self._task_id])()
+ distance = np.linalg.norm(hand_pos - target_pos)
+ return rewards.tolerance(
+ distance, bounds=(0, _TARGET_RADIUS), margin=_TARGET_RADIUS)
+
+ def initialize_episode(self, physics, random_state):
+ self._hand.set_grasp(physics, close_factors=random_state.uniform())
+ self._tcp_initializer(physics, random_state)
+ if self._prop:
+ self._prop_placer(physics, random_state)
+ else:
+ physics.bind(self._target).pos = (
+ self._target_placer(random_state=random_state))
+
+
+def _reach(task_id, obs_settings, use_site):
+ """Configure and instantiate a `Reach` task.
+
+ Args:
+ obs_settings: An `observations.ObservationSettings` instance.
+ use_site: Boolean, if True then the target will be a fixed site, otherwise
+ it will be a moveable Duplo brick.
+
+ Returns:
+ An instance of `reach.Reach`.
+ """
+ arena = arenas.Standard()
+ arm = robots.make_arm(obs_settings=obs_settings)
+ hand = robots.make_hand(obs_settings=obs_settings)
+ if use_site:
+ workspace = _SITE_WORKSPACE
+ prop = None
+ else:
+ workspace = _DUPLO_WORKSPACE
+ prop = props.Duplo(observable_options=observations.make_options(
+ obs_settings, observations.FREEPROP_OBSERVABLES))
+ task = MTReach(task_id, arena=arena, arm=arm, hand=hand, prop=prop,
+ obs_settings=obs_settings,
+ workspace=workspace,
+ control_timestep=constants.CONTROL_TIMESTEP)
+ return task
\ No newline at end of file
diff --git a/envs/custom_dmc_tasks/quadruped.py b/envs/custom_dmc_tasks/quadruped.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fed164dfed1f4648de59cd373bf36ac2551f590
--- /dev/null
+++ b/envs/custom_dmc_tasks/quadruped.py
@@ -0,0 +1,683 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Quadruped Domain."""
+
+import collections
+
+from dm_control.suite import quadruped
+from dm_control import mujoco
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import xml_tools
+from dm_control.utils import io as resources
+from lxml import etree
+import numpy as np
+from scipy import ndimage
+import os
+
+enums = mjbindings.enums
+mjlib = mjbindings.mjlib
+
+
+_DEFAULT_TIME_LIMIT = 20
+_CONTROL_TIMESTEP = .02
+
+# Horizontal speeds above which the move reward is 1.
+_RUN_SPEED = 5
+_WALK_SPEED = 0.5
+
+_JUMP_HEIGHT = 1.0 # -also good for foot up
+_LIE_DOWN_HEIGHT = 0.2
+_FOOT_DOWN_HEIGHT = 0.2
+_FOOT_UP_HEIGHT = 0.8
+
+# Constants related to terrain generation.
+_HEIGHTFIELD_ID = 0
+_TERRAIN_SMOOTHNESS = 0.15 # 0.0: maximally bumpy; 1.0: completely smooth.
+_TERRAIN_BUMP_SCALE = 2 # Spatial scale of terrain bumps (in meters).
+
+# Named model elements.
+_TOES = ['toe_front_left', 'toe_back_left', 'toe_back_right', 'toe_front_right']
+_WALLS = ['wall_px', 'wall_py', 'wall_nx', 'wall_ny']
+
+def make(task,
+ task_kwargs=None,
+ environment_kwargs=None,
+ visualize_reward=False):
+ task_kwargs = task_kwargs or {}
+ if environment_kwargs is not None:
+ task_kwargs = task_kwargs.copy()
+ task_kwargs['environment_kwargs'] = environment_kwargs
+ env = SUITE[task](**task_kwargs)
+ env.task.visualize_reward = visualize_reward
+ return env
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ root_dir = os.path.dirname(os.path.dirname(__file__))
+ xml = resources.GetResource(
+ os.path.join(root_dir, 'custom_dmc_tasks', 'quadruped.xml'))
+ return xml, common.ASSETS
+
+
+def make_model(floor_size=None, terrain=False, rangefinders=False,
+ walls_and_ball=False):
+ """Returns the model XML string."""
+ root_dir = os.path.dirname(os.path.dirname(__file__))
+ xml_string = common.read_model(os.path.join(root_dir, 'custom_dmc_tasks', 'quadruped.xml'))
+ parser = etree.XMLParser(remove_blank_text=True)
+ mjcf = etree.XML(xml_string, parser)
+
+ # Set floor size.
+ if floor_size is not None:
+ floor_geom = mjcf.find('.//geom[@name=\'floor\']')
+ floor_geom.attrib['size'] = f'{floor_size} {floor_size} .5'
+
+ # Remove walls, ball and target.
+ if not walls_and_ball:
+ for wall in _WALLS:
+ wall_geom = xml_tools.find_element(mjcf, 'geom', wall)
+ wall_geom.getparent().remove(wall_geom)
+
+ # Remove ball.
+ ball_body = xml_tools.find_element(mjcf, 'body', 'ball')
+ ball_body.getparent().remove(ball_body)
+
+ # Remove target.
+ target_site = xml_tools.find_element(mjcf, 'site', 'target')
+ target_site.getparent().remove(target_site)
+
+ # Remove terrain.
+ if not terrain:
+ terrain_geom = xml_tools.find_element(mjcf, 'geom', 'terrain')
+ terrain_geom.getparent().remove(terrain_geom)
+
+ # Remove rangefinders if they're not used, as range computations can be
+ # expensive, especially in a scene with heightfields.
+ if not rangefinders:
+ rangefinder_sensors = mjcf.findall('.//rangefinder')
+ for rf in rangefinder_sensors:
+ rf.getparent().remove(rf)
+
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+@quadruped.SUITE.add('custom')
+def lie_down(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Stand(goal='lie_down', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@quadruped.SUITE.add('custom')
+def two_legs(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Stand(goal='two_legs', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@quadruped.SUITE.add('custom')
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Stand(goal='stand', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@quadruped.SUITE.add('custom')
+def jump(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Jump(desired_height=_JUMP_HEIGHT, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@quadruped.SUITE.add('custom')
+def roll(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Roll(desired_speed=_WALK_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@quadruped.SUITE.add('custom')
+def roll_fast(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
+ physics = Physics.from_xml_string(xml_string, common.ASSETS)
+ task = Roll(desired_speed=_RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(physics, task, time_limit=time_limit,
+ control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Quadruped domain."""
+
+ def _reload_from_data(self, data):
+ super()._reload_from_data(data)
+ # Clear cached sensor names when the physics is reloaded.
+ self._sensor_types_to_names = {}
+ self._hinge_names = []
+
+ def _get_sensor_names(self, *sensor_types):
+ try:
+ sensor_names = self._sensor_types_to_names[sensor_types]
+ except KeyError:
+ [sensor_ids] = np.where(np.in1d(self.model.sensor_type, sensor_types))
+ sensor_names = [self.model.id2name(s_id, 'sensor') for s_id in sensor_ids]
+ self._sensor_types_to_names[sensor_types] = sensor_names
+ return sensor_names
+
+ def torso_upright(self):
+ """Returns the dot-product of the torso z-axis and the global z-axis."""
+ return np.asarray(self.named.data.xmat['torso', 'zz'])
+
+ def torso_velocity(self):
+ """Returns the velocity of the torso, in the local frame."""
+ return self.named.data.sensordata['velocimeter'].copy()
+
+ def com_height(self):
+ return self.named.data.sensordata['center_of_mass'].copy()[2]
+
+ def egocentric_state(self):
+ """Returns the state without global orientation or position."""
+ if not self._hinge_names:
+ [hinge_ids] = np.nonzero(self.model.jnt_type ==
+ enums.mjtJoint.mjJNT_HINGE)
+ self._hinge_names = [self.model.id2name(j_id, 'joint')
+ for j_id in hinge_ids]
+ return np.hstack((self.named.data.qpos[self._hinge_names],
+ self.named.data.qvel[self._hinge_names],
+ self.data.act))
+
+ def toe_positions(self):
+ """Returns toe positions in egocentric frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ torso_to_toe = self.named.data.xpos[_TOES] - torso_pos
+ return torso_to_toe.dot(torso_frame)
+
+ def force_torque(self):
+ """Returns scaled force/torque sensor readings at the toes."""
+ force_torque_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_FORCE,
+ enums.mjtSensor.mjSENS_TORQUE)
+ return np.arcsinh(self.named.data.sensordata[force_torque_sensors])
+
+ def imu(self):
+ """Returns IMU-like sensor readings."""
+ imu_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_GYRO,
+ enums.mjtSensor.mjSENS_ACCELEROMETER)
+ return self.named.data.sensordata[imu_sensors]
+
+ def rangefinder(self):
+ """Returns scaled rangefinder sensor readings."""
+ rf_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_RANGEFINDER)
+ rf_readings = self.named.data.sensordata[rf_sensors]
+ no_intersection = -1.0
+ return np.where(rf_readings == no_intersection, 1.0, np.tanh(rf_readings))
+
+ def origin_distance(self):
+ """Returns the distance from the origin to the workspace."""
+ return np.asarray(np.linalg.norm(self.named.data.site_xpos['workspace']))
+
+ def origin(self):
+ """Returns origin position in the torso frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ return -torso_pos.dot(torso_frame)
+
+ def ball_state(self):
+ """Returns ball position and velocity relative to the torso frame."""
+ data = self.named.data
+ torso_frame = data.xmat['torso'].reshape(3, 3)
+ ball_rel_pos = data.xpos['ball'] - data.xpos['torso']
+ ball_rel_vel = data.qvel['ball_root'][:3] - data.qvel['root'][:3]
+ ball_rot_vel = data.qvel['ball_root'][3:]
+ ball_state = np.vstack((ball_rel_pos, ball_rel_vel, ball_rot_vel))
+ return ball_state.dot(torso_frame).ravel()
+
+ def target_position(self):
+ """Returns target position in torso frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ torso_to_target = self.named.data.site_xpos['target'] - torso_pos
+ return torso_to_target.dot(torso_frame)
+
+ def ball_to_target_distance(self):
+ """Returns horizontal distance from the ball to the target."""
+ ball_to_target = (self.named.data.site_xpos['target'] -
+ self.named.data.xpos['ball'])
+ return np.linalg.norm(ball_to_target[:2])
+
+ def self_to_ball_distance(self):
+ """Returns horizontal distance from the quadruped workspace to the ball."""
+ self_to_ball = (self.named.data.site_xpos['workspace']
+ -self.named.data.xpos['ball'])
+ return np.linalg.norm(self_to_ball[:2])
+
+
+def _find_non_contacting_height(physics, orientation, x_pos=0.0, y_pos=0.0):
+ """Find a height with no contacts given a body orientation.
+ Args:
+ physics: An instance of `Physics`.
+ orientation: A quaternion.
+ x_pos: A float. Position along global x-axis.
+ y_pos: A float. Position along global y-axis.
+ Raises:
+ RuntimeError: If a non-contacting configuration has not been found after
+ 10,000 attempts.
+ """
+ z_pos = 0.0 # Start embedded in the floor.
+ num_contacts = 1
+ num_attempts = 0
+ # Move up in 1cm increments until no contacts.
+ while num_contacts > 0:
+ try:
+ with physics.reset_context():
+ physics.named.data.qpos['root'][:3] = x_pos, y_pos, z_pos
+ physics.named.data.qpos['root'][3:] = orientation
+ except control.PhysicsError:
+ # We may encounter a PhysicsError here due to filling the contact
+ # buffer, in which case we simply increment the height and continue.
+ pass
+ num_contacts = physics.data.ncon
+ z_pos += 0.01
+ num_attempts += 1
+ if num_attempts > 10000:
+ raise RuntimeError('Failed to find a non-contacting configuration.')
+
+
+def _common_observations(physics):
+ """Returns the observations common to all tasks."""
+ obs = collections.OrderedDict()
+ obs['egocentric_state'] = physics.egocentric_state()
+ obs['torso_velocity'] = physics.torso_velocity()
+ obs['torso_upright'] = physics.torso_upright()
+ obs['imu'] = physics.imu()
+ obs['force_torque'] = physics.force_torque()
+ return obs
+
+def _lie_down_reward(physics, deviation_angle=0):
+ """Returns a reward proportional to how upright the torso is.
+ Args:
+ physics: an instance of `Physics`.
+ deviation_angle: A float, in degrees. The reward is 0 when the torso is
+ exactly upside-down and 1 when the torso's z-axis is less than
+ `deviation_angle` away from the global z-axis.
+ """
+ torso = physics.named.data.xpos['torso', 'z']
+ return rewards.tolerance(
+ torso,
+ bounds=(-float('inf'), _LIE_DOWN_HEIGHT),
+ margin=_LIE_DOWN_HEIGHT * 1.5)
+
+
+def _two_legs_reward(physics, deviation_angle=0):
+ """Returns a reward proportional to how upright the torso is.
+ Args:
+ physics: an instance of `Physics`.
+ deviation_angle: A float, in degrees. The reward is 0 when the torso is
+ exactly upside-down and 1 when the torso's z-axis is less than
+ `deviation_angle` away from the global z-axis.
+ """
+ toes = []
+ for t in ['toe_front_left', 'toe_front_right', 'toe_back_left', 'toe_back_right']:
+ toe = physics.named.data.xpos[t, 'z']
+ toes.append(toe)
+ toes = sorted(toes)
+ min_toes = sum(toes[:2]) / 2
+ max_toes = sum(toes[2:]) / 2
+ toes_up = rewards.tolerance(
+ max_toes,
+ bounds=(_FOOT_UP_HEIGHT, float('inf')),
+ margin=_FOOT_UP_HEIGHT // 2)
+ toes_down = rewards.tolerance(
+ min_toes,
+ bounds=(-float('inf'), _FOOT_DOWN_HEIGHT),
+ margin=_FOOT_DOWN_HEIGHT * 1.5)
+ return toes_down * toes_up
+
+
+def _upright_reward(physics, deviation_angle=0):
+ """Returns a reward proportional to how upright the torso is.
+ Args:
+ physics: an instance of `Physics`.
+ deviation_angle: A float, in degrees. The reward is 0 when the torso is
+ exactly upside-down and 1 when the torso's z-axis is less than
+ `deviation_angle` away from the global z-axis.
+ """
+ deviation = np.cos(np.deg2rad(deviation_angle))
+ return rewards.tolerance(
+ physics.torso_upright(),
+ bounds=(deviation, float('inf')),
+ sigmoid='linear',
+ margin=1 + deviation,
+ value_at_margin=0)
+
+
+class Move(base.Task):
+ """A quadruped task solved by moving forward at a designated speed."""
+
+ def __init__(self, desired_speed, random=None):
+ """Initializes an instance of `Move`.
+ Args:
+ desired_speed: A float. If this value is zero, reward is given simply
+ for standing upright. Otherwise this specifies the horizontal velocity
+ at which the velocity-dependent reward component is maximized.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._desired_speed = desired_speed
+ super().__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+ Args:
+ physics: An instance of `Physics`.
+ """
+ # Initial configuration.
+ orientation = self.random.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, orientation)
+ super().initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ return _common_observations(physics)
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Move reward term.
+ move_reward = rewards.tolerance(
+ physics.torso_velocity()[0],
+ bounds=(self._desired_speed, float('inf')),
+ margin=self._desired_speed,
+ value_at_margin=0.5,
+ sigmoid='linear')
+
+ return _upright_reward(physics) * move_reward
+
+
+class Stand(base.Task):
+ """A quadruped task solved by moving forward at a designated speed."""
+
+ def __init__(self, random=None, goal='stand'):
+ """Initializes an instance of `Move`.
+ Args:
+ desired_speed: A float. If this value is zero, reward is given simply
+ for standing upright. Otherwise this specifies the horizontal velocity
+ at which the velocity-dependent reward component is maximized.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super().__init__(random=random)
+ self._goal = goal
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+ Args:
+ physics: An instance of `Physics`.
+ """
+ # Initial configuration.
+ orientation = self.random.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, orientation)
+ super().initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ return _common_observations(physics)
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ if self._goal == 'stand':
+ return _upright_reward(physics)
+ elif self._goal == 'lie_down':
+ return _lie_down_reward(physics)
+ elif self._goal == 'two_legs':
+ return _two_legs_reward(physics)
+
+class Jump(base.Task):
+ """A quadruped task solved by moving forward at a designated speed."""
+
+ def __init__(self, desired_height, random=None):
+ """Initializes an instance of `Move`.
+ Args:
+ desired_speed: A float. If this value is zero, reward is given simply
+ for standing upright. Otherwise this specifies the horizontal velocity
+ at which the velocity-dependent reward component is maximized.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._desired_height = desired_height
+ super().__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+ Args:
+ physics: An instance of `Physics`.
+ """
+ # Initial configuration.
+ orientation = self.random.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, orientation)
+ super().initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ return _common_observations(physics)
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Move reward term.
+ jump_up = rewards.tolerance(
+ physics.com_height(),
+ bounds=(self._desired_height, float('inf')),
+ margin=self._desired_height,
+ value_at_margin=0.5,
+ sigmoid='linear')
+
+ return _upright_reward(physics) * jump_up
+
+
+class Roll(base.Task):
+ """A quadruped task solved by moving forward at a designated speed."""
+
+ def __init__(self, desired_speed, random=None):
+ """Initializes an instance of `Move`.
+ Args:
+ desired_speed: A float. If this value is zero, reward is given simply
+ for standing upright. Otherwise this specifies the horizontal velocity
+ at which the velocity-dependent reward component is maximized.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._desired_speed = desired_speed
+ super().__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+ Args:
+ physics: An instance of `Physics`.
+ """
+ # Initial configuration.
+ orientation = self.random.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, orientation)
+ super().initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ return _common_observations(physics)
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ # Move reward term.
+ move_reward = rewards.tolerance(
+ np.linalg.norm(physics.torso_velocity()),
+ bounds=(self._desired_speed, float('inf')),
+ margin=self._desired_speed,
+ value_at_margin=0.5,
+ sigmoid='linear')
+
+ return _upright_reward(physics) * move_reward
+
+
+class Escape(base.Task):
+ """A quadruped task solved by escaping a bowl-shaped terrain."""
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+ Args:
+ physics: An instance of `Physics`.
+ """
+ # Get heightfield resolution, assert that it is square.
+ res = physics.model.hfield_nrow[_HEIGHTFIELD_ID]
+ assert res == physics.model.hfield_ncol[_HEIGHTFIELD_ID]
+ # Sinusoidal bowl shape.
+ row_grid, col_grid = np.ogrid[-1:1:res*1j, -1:1:res*1j]
+ radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .04, 1)
+ bowl_shape = .5 - np.cos(2*np.pi*radius)/2
+ # Random smooth bumps.
+ terrain_size = 2 * physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
+ bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE)
+ bumps = self.random.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res))
+ smooth_bumps = ndimage.zoom(bumps, res / float(bump_res))
+ # Terrain is elementwise product.
+ terrain = bowl_shape * smooth_bumps
+ start_idx = physics.model.hfield_adr[_HEIGHTFIELD_ID]
+ physics.model.hfield_data[start_idx:start_idx+res**2] = terrain.ravel()
+ super().initialize_episode(physics)
+
+ # If we have a rendering context, we need to re-upload the modified
+ # heightfield data.
+ if physics.contexts:
+ with physics.contexts.gl.make_current() as ctx:
+ ctx.call(mjlib.mjr_uploadHField,
+ physics.model.ptr,
+ physics.contexts.mujoco.ptr,
+ _HEIGHTFIELD_ID)
+
+ # Initial configuration.
+ orientation = self.random.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, orientation)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ obs = _common_observations(physics)
+ obs['origin'] = physics.origin()
+ obs['rangefinder'] = physics.rangefinder()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Escape reward term.
+ terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
+ escape_reward = rewards.tolerance(
+ physics.origin_distance(),
+ bounds=(terrain_size, float('inf')),
+ margin=terrain_size,
+ value_at_margin=0,
+ sigmoid='linear')
+
+ return _upright_reward(physics, deviation_angle=20) * escape_reward
+
+
+class Fetch(base.Task):
+ """A quadruped task solved by bringing a ball to the origin."""
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+ Args:
+ physics: An instance of `Physics`.
+ """
+ # Initial configuration, random azimuth and horizontal position.
+ azimuth = self.random.uniform(0, 2*np.pi)
+ orientation = np.array((np.cos(azimuth/2), 0, 0, np.sin(azimuth/2)))
+ spawn_radius = 0.9 * physics.named.model.geom_size['floor', 0]
+ x_pos, y_pos = self.random.uniform(-spawn_radius, spawn_radius, size=(2,))
+ _find_non_contacting_height(physics, orientation, x_pos, y_pos)
+
+ # Initial ball state.
+ physics.named.data.qpos['ball_root'][:2] = self.random.uniform(
+ -spawn_radius, spawn_radius, size=(2,))
+ physics.named.data.qpos['ball_root'][2] = 2
+ physics.named.data.qvel['ball_root'][:2] = 5*self.random.randn(2)
+ super().initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation to the agent."""
+ obs = _common_observations(physics)
+ obs['ball_state'] = physics.ball_state()
+ obs['target_position'] = physics.target_position()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+
+ # Reward for moving close to the ball.
+ arena_radius = physics.named.model.geom_size['floor', 0] * np.sqrt(2)
+ workspace_radius = physics.named.model.site_size['workspace', 0]
+ ball_radius = physics.named.model.geom_size['ball', 0]
+ reach_reward = rewards.tolerance(
+ physics.self_to_ball_distance(),
+ bounds=(0, workspace_radius+ball_radius),
+ sigmoid='linear',
+ margin=arena_radius, value_at_margin=0)
+
+ # Reward for bringing the ball to the target.
+ target_radius = physics.named.model.site_size['target', 0]
+ fetch_reward = rewards.tolerance(
+ physics.ball_to_target_distance(),
+ bounds=(0, target_radius),
+ sigmoid='linear',
+ margin=arena_radius, value_at_margin=0)
+
+ reach_then_fetch = reach_reward * (0.5 + 0.5*fetch_reward)
+
+ return _upright_reward(physics) * reach_then_fetch
\ No newline at end of file
diff --git a/envs/custom_dmc_tasks/quadruped.xml b/envs/custom_dmc_tasks/quadruped.xml
new file mode 100644
index 0000000000000000000000000000000000000000..4024197d477fbe6423557d38d1eefe2a53c4862a
--- /dev/null
+++ b/envs/custom_dmc_tasks/quadruped.xml
@@ -0,0 +1,328 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/envs/custom_dmc_tasks/stickman.py b/envs/custom_dmc_tasks/stickman.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4460bce752b46c71da0f2d1404971449df6550
--- /dev/null
+++ b/envs/custom_dmc_tasks/stickman.py
@@ -0,0 +1,647 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Stickman Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+import numpy as np
+import types
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import io as resources
+from dm_control import suite
+
+class StickmanYogaPoses:
+ lie_back = [ -1.2 , 0. , -1.57, 0, 0. , 0.0, 0, -0., 0.0]
+ lie_front = [-1.2, -0, 1.57, 0, 0, 0, 0, 0., 0.]
+ legs_up = [ -1.24 , 0. , -1.57, 1.57, 0. , 0.0, 1.57, -0., 0.0]
+
+ kneel = [ -0.5 , 0. , 0, 0, -1.57, -0.8, 1.57, -1.57, 0.0]
+ side_angle = [ -0.3 , 0. , 0.9, 0, 0, -0.7, 1.87, -1.07, 0.0]
+ stand_up = [-0.15, 0., 0.34, 0.74, -1.34, -0., 1.1, -0.66, -0.1]
+
+ lean_back = [-0.27, 0., -0.45, 0.22, -1.5, 0.86, 0.6, -0.8, -0.4]
+ boat = [ -1.04 , 0. , -0.8, 1.6, 0. , 0.0, 1.6, -0., 0.0]
+ bridge = [-1.1, 0., -2.2, -0.3, -1.5, 0., -0.3, -0.8, -0.4]
+
+ head_stand = [-1, 0., -3, 0.6, -1, -0.3, 0.9, -0.5, 0.3]
+ one_feet = [-0.2, 0., 0, 0.7, -1.34, 0.5, 1.5, -0.6, 0.1]
+ arabesque = [-0.34, 0., 1.57, 1.57, 0, 0., 0, -0., 0.]
+
+ # new
+ high_kick = [-0.165, 3.3 , 5.55 , 1.35 ,-0, +0.5 , -0.7, 0. , 0.2,]
+ splits = [-0.7, 0., 0.5, -0.7, -1. , 0, 1.75, 0., -0.45 ]
+ sit_knees = [-0.6, -0.2, 0.2, 0.95, -2.5, 0 , 0.95, -2.5, 0 ]
+
+
+_DEFAULT_TIME_LIMIT = 25
+_CONTROL_TIMESTEP = .025
+
+# Minimal height of torso over foot above which stand reward is 1.
+_STAND_HEIGHT = 1.15
+
+# Horizontal speeds (meters/second) above which move reward is 1.
+_WALK_SPEED = 1
+_RUN_SPEED = 8
+
+# Copied from walker:
+_YOGA_HANDS_UP_HEIGHT = 1.75
+_YOGA_STAND_HEIGHT = 1.0 # lower than stan height = 1.2
+_YOGA_LIE_DOWN_HEIGHT = 0.1
+_YOGA_LEGS_UP_HEIGHT = 1.1
+
+_YOGA_FEET_UP_HEIGHT = 0.5
+_YOGA_FEET_UP_LIE_DOWN_HEIGHT = 0.35
+
+_YOGA_KNEE_HEIGHT = 0.25
+_YOGA_KNEESTAND_HEIGHT = 0.75
+
+_YOGA_SITTING_HEIGHT = 0.55
+_YOGA_SITTING_LEGS_HEIGHT = 0.15
+
+# speed from: https://github.com/rll-research/url_benchmark/blob/710c3eb/custom_dmc_tasks/py
+_SPIN_SPEED = 5.0
+#
+_PUNCH_SPEED = 5.0
+_PUNCH_DIST = 0.29
+
+
+SUITE = containers.TaggedTasks()
+
+def make(task,
+ task_kwargs=None,
+ environment_kwargs=None,
+ visualize_reward=False):
+ task_kwargs = task_kwargs or {}
+ if environment_kwargs is not None:
+ task_kwargs = task_kwargs.copy()
+ task_kwargs['environment_kwargs'] = environment_kwargs
+ env = SUITE[task](**task_kwargs)
+ env.task.visualize_reward = visualize_reward
+ return env
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ root_dir = os.path.dirname(os.path.dirname(__file__))
+ xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks', 'stickman.xml'))
+ return xml, common.ASSETS
+
+@SUITE.add('custom')
+def hands_up(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the hands_up task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='hands_up', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('custom')
+def boxing(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the boxing task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='boxing', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@SUITE.add('custom')
+def arabesque(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Arabesque task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='arabesque', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('custom')
+def lying_down(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Lie Down task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='lying_down', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('custom')
+def legs_up(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Legs Up task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='legs_up', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@SUITE.add('custom')
+def high_kick(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the High Kick task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='high_kick', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@SUITE.add('custom')
+def one_foot(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the High Kick task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='one_foot', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@SUITE.add('custom')
+def lunge_pose(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the High Kick task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='lunge_pose', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@SUITE.add('custom')
+def sit_knees(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the High Kick task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='sit_knees', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@SUITE.add('custom')
+def headstand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Headstand task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='flip', move_speed=0, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('custom')
+def urlb_flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Flip task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='urlb_flip', move_speed=_SPIN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@SUITE.add('custom')
+def flipping(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Flipping task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='flipping', move_speed=2 * _RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('custom')
+def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Flip task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='flip', move_speed=2 * _RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('custom')
+def backflip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Backflip task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(goal='flip', move_speed=-2 * _RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@SUITE.add('custom')
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Stand task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(move_speed=0, goal='stand', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('custom')
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(move_speed=_WALK_SPEED, goal='walk', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@SUITE.add('custom')
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Stickman(move_speed=_RUN_SPEED, goal='run', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the stickman domain."""
+ def torso_upright(self):
+ """Returns projection from z-axes of torso to the z-axes of world."""
+ return self.named.data.xmat['torso', 'zz']
+
+ def torso_height(self):
+ """Returns the height of the torso."""
+ return self.named.data.xpos['torso', 'z']
+
+ def horizontal_velocity(self):
+ """Returns the horizontal velocity of the center-of-mass."""
+ return self.named.data.sensordata['torso_subtreelinvel'][0]
+
+ def orientations(self):
+ """Returns planar orientations of all bodies."""
+ return self.named.data.xmat[1:, ['xx', 'xz']].ravel()
+
+ def angmomentum(self):
+ """Returns the angular momentum of torso of the stickman about Y axis."""
+ return self.named.data.subtree_angmom['torso'][1]
+
+
+class Stickman(base.Task):
+ """A planar stickman task."""
+ def __init__(self, move_speed=0., goal='walk', forward=True, random=None):
+ """Initializes an instance of `Stickman`.
+
+ Args:
+ move_speed: A float. If this value is zero, reward is given simply for
+ standing up. Otherwise this specifies a target horizontal velocity for
+ the walking task.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._move_speed = move_speed
+ self._forward = 1 if forward else -1
+ self._goal = goal
+ super().__init__(random=random)
+
+ def _hands_up_reward(self, physics):
+ standing = self._stand_reward(physics)
+ left_hand_height = physics.named.data.xpos['left_hand', 'z']
+ right_hand_height = physics.named.data.xpos['right_hand', 'z']
+
+ hand_height = (left_hand_height + right_hand_height) / 2
+
+ hands_up = rewards.tolerance(hand_height,
+ bounds=(_YOGA_HANDS_UP_HEIGHT, float('inf')),
+ margin=_YOGA_HANDS_UP_HEIGHT/2)
+ return standing * hands_up
+
+ def _boxing_reward(self, physics):
+ # torso up, but lower than standing
+ # foot up, higher than torso
+ # foot down
+ standing = self._stand_reward(physics)
+
+ left_hand_velocity = abs(physics.named.data.subtree_linvel['left_hand'][0])
+ right_hand_velocity = abs(physics.named.data.subtree_linvel['right_hand'][0])
+ punch_reward = rewards.tolerance(
+ max(left_hand_velocity, right_hand_velocity),
+ bounds=(_PUNCH_SPEED, float('inf')),
+ margin=_PUNCH_SPEED / 2,
+ value_at_margin=0.5,
+ sigmoid='linear')
+
+ # left_hand_dist = physics.named.data.xpos['left_hand', 'x'] - physics.named.data.xpos['torso', 'x']
+ # right_hand_dist = physics.named.data.xpos['right_hand', 'x'] - physics.named.data.xpos['torso', 'x']
+ # punch_reward = rewards.tolerance(
+ # max(left_hand_dist, right_hand_dist),
+ # bounds=(_PUNCH_DIST, float('inf')),
+ # margin=_PUNCH_DIST / 2,)
+
+ return standing * punch_reward
+
+ def _arabesque_reward(self, physics):
+ # standing horizontal
+ # one foot up, same height as torso
+ # one foot down
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+
+ left_foot_height = physics.named.data.xpos['left_foot', 'z']
+ right_foot_height = physics.named.data.xpos['right_foot', 'z']
+
+ max_foot = 'right_foot' if right_foot_height > left_foot_height else 'left_foot'
+ min_foot = 'right_foot' if right_foot_height <= left_foot_height else 'left_foot'
+
+ min_foot_height = physics.named.data.xpos[min_foot, 'z']
+ max_foot_height = physics.named.data.xpos[max_foot, 'z']
+
+ min_foot_down = rewards.tolerance(min_foot_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ max_foot_up = rewards.tolerance(max_foot_height,
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+
+ min_foot_x = physics.named.data.xpos[min_foot, 'x']
+ max_foot_x = physics.named.data.xpos[max_foot, 'x']
+
+ correct_foot_pose = 0.1 if max_foot_x > min_foot_x else 1.0
+
+ feet_pose = (min_foot_down + max_foot_up * 2) / 3
+ return standing * feet_pose * correct_foot_pose
+
+ def _lying_down_reward(self, physics):
+ # torso down and horizontal
+ # thigh and feet down
+ torso_down = rewards.tolerance(physics.torso_height(),
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ horizontal = 1 - abs(physics.torso_upright())
+
+ thigh_height = (physics.named.data.xpos['left_thigh', 'z'] + physics.named.data.xpos['right_thigh', 'z']) / 2
+ thigh_down = rewards.tolerance(thigh_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ leg_height = (physics.named.data.xpos['left_leg', 'z'] + physics.named.data.xpos['right_leg', 'z']) / 2
+ leg_down = rewards.tolerance(leg_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
+ feet_down = rewards.tolerance(feet_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ return (3*torso_down + horizontal + thigh_down + feet_down + leg_down) / 7
+
+ def _legs_up_reward(self, physics):
+ # torso down and horizontal
+ # legs up with thigh down
+ torso_down = rewards.tolerance(physics.torso_height(),
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ horizontal = 1 - abs(physics.torso_upright())
+ torso_down = (3*torso_down +horizontal) / 4
+
+ feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
+ feet_up = rewards.tolerance(feet_height,
+ bounds=(_YOGA_FEET_UP_LIE_DOWN_HEIGHT, float('inf')),
+ margin=_YOGA_FEET_UP_LIE_DOWN_HEIGHT/2)
+
+ return torso_down * feet_up
+
+ def _high_kick_reward(self, physics):
+ # torso up, but lower than standing
+ # foot up, higher than torso
+ # foot down
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+
+ left_foot_height = physics.named.data.xpos['left_foot', 'z']
+ right_foot_height = physics.named.data.xpos['right_foot', 'z']
+
+ min_foot_height = min(left_foot_height, right_foot_height)
+ max_foot_height = max(left_foot_height, right_foot_height)
+
+ min_foot_down = rewards.tolerance(min_foot_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ max_foot_up = rewards.tolerance(max_foot_height,
+ bounds=(_STAND_HEIGHT, float('inf')),
+ margin=_STAND_HEIGHT/2)
+
+ feet_pose = (3 * max_foot_up + min_foot_down) / 4
+
+ return standing * feet_pose
+
+ def _one_foot_reward(self, physics):
+ # torso up, standing
+ # foot up higher than foot down
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+
+ left_foot_height = physics.named.data.xpos['left_foot', 'z']
+ right_foot_height = physics.named.data.xpos['right_foot', 'z']
+
+ min_foot_height = min(left_foot_height, right_foot_height)
+ max_foot_height = max(left_foot_height, right_foot_height)
+
+ min_foot_down = rewards.tolerance(min_foot_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ max_foot_up = rewards.tolerance(max_foot_height,
+ bounds=(_YOGA_FEET_UP_HEIGHT, float('inf')),
+ margin=_YOGA_FEET_UP_HEIGHT/2)
+
+ return standing * max_foot_up * min_foot_down
+
+ def _lunge_pose_reward(self, physics):
+ # torso up, standing, but lower
+ # leg up higher than leg down
+ # horiontal thigh and leg
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_KNEESTAND_HEIGHT, float('inf')),
+ margin=_YOGA_KNEESTAND_HEIGHT/2)
+ upright = (1 + physics.torso_upright()) / 2
+ torso = (3*standing + upright) / 4
+
+ left_leg_height = physics.named.data.xpos['left_leg', 'z']
+ right_leg_height = physics.named.data.xpos['right_leg', 'z']
+
+ min_leg_height = min(left_leg_height, right_leg_height)
+ max_leg_height = max(left_leg_height, right_leg_height)
+
+ min_leg_down = rewards.tolerance(min_leg_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ max_leg_up = rewards.tolerance(max_leg_height,
+ bounds=(_YOGA_KNEE_HEIGHT, float('inf')),
+ margin=_YOGA_KNEE_HEIGHT / 2)
+
+ max_thigh = 'left_thigh' if max_leg_height == left_leg_height else 'right_thigh'
+ min_leg = 'left_leg' if min_leg_height == left_leg_height else 'right_leg'
+
+ max_thigh_horiz = 1 - abs(physics.named.data.xmat[max_thigh, 'zz'])
+ min_leg_horiz = 1 - abs(physics.named.data.xmat[min_leg, 'zz'])
+
+ legs = (min_leg_down + max_leg_up + max_thigh_horiz + min_leg_horiz) / 4
+
+ return torso * legs
+
+ def _sit_knees_reward(self, physics):
+ # torso up, standing, but lower
+ # foot up higher than foot down
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_SITTING_HEIGHT, float('inf')),
+ margin=_YOGA_SITTING_HEIGHT/2)
+ upright = (1 + physics.torso_upright()) / 2
+ torso_up = (3*standing + upright) / 4
+
+ legs_height = (physics.named.data.xpos['left_leg', 'z'] + physics.named.data.xpos['right_leg', 'z']) / 2
+ legs_down = rewards.tolerance(legs_height,
+ bounds=(-float('inf'), _YOGA_SITTING_LEGS_HEIGHT),
+ margin=_YOGA_SITTING_LEGS_HEIGHT*1.5)
+ feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
+ feet_down = rewards.tolerance(feet_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+
+ l_thigh_foot_distance = max(0.1, abs(physics.named.data.xpos['left_foot', 'x'] - physics.named.data.xpos['left_thigh', 'x'])) - 0.1
+ r_thigh_foot_distance = max(0.1, abs(physics.named.data.xpos['right_foot', 'x'] - physics.named.data.xpos['right_thigh', 'x'])) - 0.1
+ close = np.exp(-(l_thigh_foot_distance + r_thigh_foot_distance)/2)
+
+ legs = (3 * legs_down + feet_down) / 4
+ return torso_up * legs * close
+
+ def _urlb_flip_reward(self, physics):
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_STAND_HEIGHT, float('inf')),
+ margin=_STAND_HEIGHT / 2)
+ upright = (1 + physics.torso_upright()) / 2
+ stand_reward = (3 * standing + upright) / 4
+ move_reward = rewards.tolerance(self._forward *
+ physics.named.data.subtree_angmom['torso'][1], # physics.angmomentum(),
+ bounds=(_SPIN_SPEED, float('inf')),
+ margin=_SPIN_SPEED,
+ value_at_margin=0,
+ sigmoid='linear')
+ return stand_reward * (5 * move_reward + 1) / 6
+
+ def _flip_reward(self, physics):
+ thigh_height = (physics.named.data.xpos['left_thigh', 'z'] + physics.named.data.xpos['right_thigh', 'z']) / 2
+ thigh_up = rewards.tolerance(thigh_height,
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+ feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
+ legs_up = rewards.tolerance(feet_height,
+ bounds=(_YOGA_LEGS_UP_HEIGHT, float('inf')),
+ margin=_YOGA_LEGS_UP_HEIGHT/2)
+ upside_down_reward = (3*legs_up + 2*thigh_up) / 5
+ if self._move_speed == 0:
+ return upside_down_reward
+ move_reward = rewards.tolerance(physics.named.data.subtree_angmom['torso'][1], # physics.angmomentum(),
+ bounds=(self._move_speed, float('inf')) if self._move_speed > 0 else (-float('inf'), self._move_speed),
+ margin=abs(self._move_speed)/2,
+ value_at_margin=0.5,
+ sigmoid='linear')
+ return upside_down_reward * (5*move_reward + 1) / 6
+
+
+ def _stand_reward(self, physics):
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_STAND_HEIGHT, float('inf')),
+ margin=_STAND_HEIGHT / 2)
+ upright = (1 + physics.torso_upright()) / 2
+ return (3 * standing + upright) / 4
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ In 'standing' mode, use initial orientation and small velocities.
+ In 'random' mode, randomize joint angles and let fall to the floor.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ super().initialize_episode(physics)
+
+ def get_observation(self, physics):
+ """Returns an observation of body orientations, height and velocites."""
+ obs = collections.OrderedDict()
+ obs['orientations'] = physics.orientations()
+ obs['height'] = physics.torso_height()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ if self._goal in ['stand', 'walk', 'run']:
+ stand_reward = self._stand_reward(physics)
+ move_reward = rewards.tolerance(
+ self._forward * physics.horizontal_velocity(),
+ bounds=(self._move_speed, float('inf')),
+ margin=self._move_speed / 2,
+ value_at_margin=0.5,
+ sigmoid='linear')
+ return stand_reward * (5 * move_reward + 1) / 6
+ if self._goal == 'flipping':
+ self._move_speed = abs(self._move_speed)
+ pos_rew = self._flip_reward(physics)
+ self._move_speed = -abs(self._move_speed)
+ neg_rew = self._flip_reward(physics)
+ return max(pos_rew, neg_rew)
+ try:
+ reward_fn = getattr(self, f'_{self._goal}_reward')
+ return reward_fn(physics)
+ except Exception as e:
+ print(e)
+ raise NotImplementedError(f'Goal {self._goal} or function "_{self._goal}_reward" not implemented.')
+
+if __name__ == '__main__':
+ from dm_control import viewer
+ import numpy as np
+
+ env = boxing()
+ env.task.visualize_reward = True
+
+ action_spec = env.action_spec()
+
+ def zero_policy(time_step):
+ print(time_step.reward)
+ return np.zeros(action_spec.shape)
+
+ ts = env.reset()
+ while True:
+ ts = env.step(zero_policy(ts))
+
+ viewer.launch(env, policy=zero_policy)
+
+ # obs = env.reset()
+ # next_obs, reward, done, info = env.step(np.zeros(6))
\ No newline at end of file
diff --git a/envs/custom_dmc_tasks/stickman.xml b/envs/custom_dmc_tasks/stickman.xml
new file mode 100644
index 0000000000000000000000000000000000000000..6ce060adf0caec18dc42c2d4f553b3a1e48e0edb
--- /dev/null
+++ b/envs/custom_dmc_tasks/stickman.xml
@@ -0,0 +1,108 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/envs/custom_dmc_tasks/walker.py b/envs/custom_dmc_tasks/walker.py
new file mode 100644
index 0000000000000000000000000000000000000000..827e73ba5fbabd5f6a24e11de6c8c91de64b66f1
--- /dev/null
+++ b/envs/custom_dmc_tasks/walker.py
@@ -0,0 +1,489 @@
+import os
+
+import numpy as np
+from dm_control.rl import control
+from dm_control.suite import common
+from dm_control.suite import walker
+from dm_control.utils import rewards
+from dm_control.utils import io as resources
+
+_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'custom_dmc_tasks')
+
+_YOGA_STAND_HEIGHT = 1.0 # lower than stan height = 1.2
+_YOGA_LIE_DOWN_HEIGHT = 0.1
+_YOGA_LEGS_UP_HEIGHT = 1.1
+
+_YOGA_FEET_UP_HEIGHT = 0.5
+_YOGA_FEET_UP_LIE_DOWN_HEIGHT = 0.35
+
+_YOGA_KNEE_HEIGHT = 0.25
+_YOGA_KNEESTAND_HEIGHT = 0.75
+
+_YOGA_SITTING_HEIGHT = 0.55
+_YOGA_SITTING_LEGS_HEIGHT = 0.15
+
+# speed from: https://github.com/rll-research/url_benchmark/blob/710c3eb/custom_dmc_tasks/walker.py
+_SPIN_SPEED = 5.0
+#
+
+class WalkerYogaPoses:
+ """
+ Joint positions for some yoga poses
+ """
+ lie_back = [ -1.2 , 0. , -1.57, 0, 0. , 0.0, 0, -0., 0.0]
+ lie_front = [-1.2, -0, 1.57, 0, -0.2, 0, 0, -0.2, 0.]
+ legs_up = [ -1.24 , 0. , -1.57, 1.57, 0. , 0.0, 1.57, -0., 0.0]
+
+ kneel = [ -0.5 , 0. , 0, 0, -1.57, -0.8, 1.57, -1.57, 0.0]
+ side_angle = [ -0.3 , 0. , 0.9, 0, 0, -0.7, 1.87, -1.07, 0.0]
+ stand_up = [-0.15, 0., 0.34, 0.74, -1.34, -0., 1.1, -0.66, -0.1]
+
+ lean_back = [-0.27, 0., -0.45, 0.22, -1.5, 0.86, 0.6, -0.8, -0.4]
+ boat = [ -1.04 , 0. , -0.8, 1.6, 0. , 0.0, 1.6, -0., 0.0]
+ bridge = [-1.1, 0., -2.2, -0.3, -1.5, 0., -0.3, -0.8, -0.4]
+
+ head_stand = [-1, 0., -3, 0.6, -1, -0.3, 0.9, -0.5, 0.3]
+ one_foot = [-0.2, 0., 0, 0.7, -1.34, 0.5, 1.5, -0.6, 0.1]
+
+ arabesque = [-0.34, 0., 1.57, 1.57, 0, 0., 0, -0., 0.]
+
+ # new
+ high_kick = [-0.165, 3.3 , 5.55 , 1.35 ,-0, +0.5 , -0.7, 0. , 0.2,]
+ splits = [-0.7, 0., 0.5, -0.7, -1. , 0, 1.75, 0., -0.45 ]
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return resources.GetResource(os.path.join(_TASKS_DIR, 'walker.xml')), common.ASSETS
+
+
+@walker.SUITE.add('custom')
+def walk_backwards(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Walk Backwards task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = BackwardsPlanarWalker(move_speed=walker._WALK_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@walker.SUITE.add('custom')
+def run_backwards(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Run Backwards task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = BackwardsPlanarWalker(move_speed=walker._RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@walker.SUITE.add('custom')
+def arabesque(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Arabesque task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='arabesque', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@walker.SUITE.add('custom')
+def lying_down(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Lie Down task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='lying_down', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@walker.SUITE.add('custom')
+def legs_up(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Legs Up task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='legs_up', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@walker.SUITE.add('custom')
+def high_kick(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the High Kick task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='high_kick', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@walker.SUITE.add('custom')
+def one_foot(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the High Kick task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='one_foot', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@walker.SUITE.add('custom')
+def lunge_pose(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the High Kick task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='lunge_pose', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@walker.SUITE.add('custom')
+def sit_knees(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the High Kick task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='sit_knees', random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@walker.SUITE.add('custom')
+def headstand(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Headstand task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='flip', move_speed=0, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@walker.SUITE.add('custom')
+def urlb_flip(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Flip task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='urlb_flip', move_speed=_SPIN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@walker.SUITE.add('custom')
+def flipping(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the flipping task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='flipping', move_speed=2* walker._RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+@walker.SUITE.add('custom')
+def flip(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Flip task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='flip', move_speed=2* walker._RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+@walker.SUITE.add('custom')
+def backflip(time_limit=walker._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
+ """Returns the Backflip task."""
+ physics = walker.Physics.from_xml_string(*get_model_and_assets())
+ task = YogaPlanarWalker(goal='flip', move_speed=-2 * walker._RUN_SPEED, random=random)
+ environment_kwargs = environment_kwargs or {}
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=walker._CONTROL_TIMESTEP,
+ **environment_kwargs)
+
+
+class BackwardsPlanarWalker(walker.PlanarWalker):
+ """Backwards PlanarWalker task."""
+ def __init__(self, move_speed, random=None):
+ super().__init__(move_speed, random)
+
+ def get_reward(self, physics):
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+ upright = (1 + physics.torso_upright()) / 2
+ stand_reward = (3*standing + upright) / 4
+ if self._move_speed == 0:
+ return stand_reward
+ else:
+ move_reward = rewards.tolerance(physics.horizontal_velocity(),
+ bounds=(-float('inf'), -self._move_speed),
+ margin=self._move_speed/2,
+ value_at_margin=0.5,
+ sigmoid='linear')
+ return stand_reward * (5*move_reward + 1) / 6
+
+
+class YogaPlanarWalker(walker.PlanarWalker):
+ """Yoga PlanarWalker tasks."""
+
+ def __init__(self, goal='arabesque', move_speed=0, random=None):
+ super().__init__(0, random)
+ self._goal = goal
+ self._move_speed = move_speed
+
+ def _arabesque_reward(self, physics):
+ # standing horizontal
+ # one foot up, same height as torso
+ # one foot down
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+
+ left_foot_height = physics.named.data.xpos['left_foot', 'z']
+ right_foot_height = physics.named.data.xpos['right_foot', 'z']
+
+ max_foot = 'right_foot' if right_foot_height > left_foot_height else 'left_foot'
+ min_foot = 'right_foot' if right_foot_height <= left_foot_height else 'left_foot'
+
+ min_foot_height = physics.named.data.xpos[min_foot, 'z']
+ max_foot_height = physics.named.data.xpos[max_foot, 'z']
+
+ min_foot_down = rewards.tolerance(min_foot_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ max_foot_up = rewards.tolerance(max_foot_height,
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+
+ min_foot_x = physics.named.data.xpos[min_foot, 'x']
+ max_foot_x = physics.named.data.xpos[max_foot, 'x']
+
+ correct_foot_pose = 0.1 if max_foot_x > min_foot_x else 1.0
+
+ feet_pose = (min_foot_down + max_foot_up * 2) / 3
+ return standing * feet_pose * correct_foot_pose
+
+ def _lying_down_reward(self, physics):
+ # torso down and horizontal
+ # thigh and feet down
+ torso_down = rewards.tolerance(physics.torso_height(),
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ horizontal = 1 - abs(physics.torso_upright())
+
+ thigh_height = (physics.named.data.xpos['left_thigh', 'z'] + physics.named.data.xpos['right_thigh', 'z']) / 2
+ thigh_down = rewards.tolerance(thigh_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ leg_height = (physics.named.data.xpos['left_leg', 'z'] + physics.named.data.xpos['right_leg', 'z']) / 2
+ leg_down = rewards.tolerance(leg_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
+ feet_down = rewards.tolerance(feet_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ return (3*torso_down + horizontal + thigh_down + feet_down + leg_down) / 7
+
+ def _legs_up_reward(self, physics):
+ # torso down and horizontal
+ # legs up with thigh down
+ torso_down = rewards.tolerance(physics.torso_height(),
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ horizontal = 1 - abs(physics.torso_upright())
+ torso_down = (3*torso_down +horizontal) / 4
+
+ feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
+ feet_up = rewards.tolerance(feet_height,
+ bounds=(_YOGA_FEET_UP_LIE_DOWN_HEIGHT, float('inf')),
+ margin=_YOGA_FEET_UP_LIE_DOWN_HEIGHT/2)
+
+ return torso_down * feet_up
+
+ def _high_kick_reward(self, physics):
+ # torso up, but lower than standing
+ # foot up, higher than torso
+ # foot down
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+
+ left_foot_height = physics.named.data.xpos['left_foot', 'z']
+ right_foot_height = physics.named.data.xpos['right_foot', 'z']
+
+ min_foot_height = min(left_foot_height, right_foot_height)
+ max_foot_height = max(left_foot_height, right_foot_height)
+
+ min_foot_down = rewards.tolerance(min_foot_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ max_foot_up = rewards.tolerance(max_foot_height,
+ bounds=(walker._STAND_HEIGHT, float('inf')),
+ margin=walker._STAND_HEIGHT/2)
+
+ feet_pose = (3 * max_foot_up + min_foot_down) / 4
+
+ return standing * feet_pose
+
+ def _one_foot_reward(self, physics):
+ # torso up, standing
+ # foot up higher than foot down
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+
+ left_foot_height = physics.named.data.xpos['left_foot', 'z']
+ right_foot_height = physics.named.data.xpos['right_foot', 'z']
+
+ min_foot_height = min(left_foot_height, right_foot_height)
+ max_foot_height = max(left_foot_height, right_foot_height)
+
+ min_foot_down = rewards.tolerance(min_foot_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ max_foot_up = rewards.tolerance(max_foot_height,
+ bounds=(_YOGA_FEET_UP_HEIGHT, float('inf')),
+ margin=_YOGA_FEET_UP_HEIGHT/2)
+
+ return standing * max_foot_up * min_foot_down
+
+ def _lunge_pose_reward(self, physics):
+ # torso up, standing, but lower
+ # leg up higher than leg down
+ # horiontal thigh and leg
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_KNEESTAND_HEIGHT, float('inf')),
+ margin=_YOGA_KNEESTAND_HEIGHT/2)
+ upright = (1 + physics.torso_upright()) / 2
+ torso = (3*standing + upright) / 4
+
+ left_leg_height = physics.named.data.xpos['left_leg', 'z']
+ right_leg_height = physics.named.data.xpos['right_leg', 'z']
+
+ min_leg_height = min(left_leg_height, right_leg_height)
+ max_leg_height = max(left_leg_height, right_leg_height)
+
+ min_leg_down = rewards.tolerance(min_leg_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+ max_leg_up = rewards.tolerance(max_leg_height,
+ bounds=(_YOGA_KNEE_HEIGHT, float('inf')),
+ margin=_YOGA_KNEE_HEIGHT / 2)
+
+ max_thigh = 'left_thigh' if max_leg_height == left_leg_height else 'right_thigh'
+ min_leg = 'left_leg' if min_leg_height == left_leg_height else 'right_leg'
+
+ max_thigh_horiz = 1 - abs(physics.named.data.xmat[max_thigh, 'zz'])
+ min_leg_horiz = 1 - abs(physics.named.data.xmat[min_leg, 'zz'])
+
+ legs = (min_leg_down + max_leg_up + max_thigh_horiz + min_leg_horiz) / 4
+
+ return torso * legs
+
+ def _sit_knees_reward(self, physics):
+ # torso up, standing, but lower
+ # foot up higher than foot down
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_YOGA_SITTING_HEIGHT, float('inf')),
+ margin=_YOGA_SITTING_HEIGHT/2)
+ upright = (1 + physics.torso_upright()) / 2
+ torso_up = (3*standing + upright) / 4
+
+ legs_height = (physics.named.data.xpos['left_leg', 'z'] + physics.named.data.xpos['right_leg', 'z']) / 2
+ legs_down = rewards.tolerance(legs_height,
+ bounds=(-float('inf'), _YOGA_SITTING_LEGS_HEIGHT),
+ margin=_YOGA_SITTING_LEGS_HEIGHT*1.5)
+ feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
+ feet_down = rewards.tolerance(feet_height,
+ bounds=(-float('inf'), _YOGA_LIE_DOWN_HEIGHT),
+ margin=_YOGA_LIE_DOWN_HEIGHT*1.5)
+
+ l_thigh_foot_distance = max(0.1, abs(physics.named.data.xpos['left_foot', 'x'] - physics.named.data.xpos['left_thigh', 'x'])) - 0.1
+ r_thigh_foot_distance = max(0.1, abs(physics.named.data.xpos['right_foot', 'x'] - physics.named.data.xpos['right_thigh', 'x'])) - 0.1
+ close = np.exp(-(l_thigh_foot_distance + r_thigh_foot_distance)/2)
+
+ legs = (3 * legs_down + feet_down) / 4
+ return torso_up * legs * close
+
+ def _urlb_flip_reward(self, physics):
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(walker._STAND_HEIGHT, float('inf')),
+ margin=walker._STAND_HEIGHT / 2)
+ upright = (1 + physics.torso_upright()) / 2
+ stand_reward = (3 * standing + upright) / 4
+ move_reward = rewards.tolerance(physics.named.data.subtree_angmom['torso'][1], # physics.angmomentum(),
+ bounds=(_SPIN_SPEED, float('inf')),
+ margin=_SPIN_SPEED,
+ value_at_margin=0,
+ sigmoid='linear')
+ return stand_reward * (5 * move_reward + 1) / 6
+
+ def _flip_reward(self, physics):
+ thigh_height = (physics.named.data.xpos['left_thigh', 'z'] + physics.named.data.xpos['right_thigh', 'z']) / 2
+ thigh_up = rewards.tolerance(thigh_height,
+ bounds=(_YOGA_STAND_HEIGHT, float('inf')),
+ margin=_YOGA_STAND_HEIGHT/2)
+ feet_height = (physics.named.data.xpos['left_foot', 'z'] + physics.named.data.xpos['right_foot', 'z']) / 2
+ legs_up = rewards.tolerance(feet_height,
+ bounds=(_YOGA_LEGS_UP_HEIGHT, float('inf')),
+ margin=_YOGA_LEGS_UP_HEIGHT/2)
+ upside_down_reward = (3*legs_up + 2*thigh_up) / 5
+ if self._move_speed == 0:
+ return upside_down_reward
+ move_reward = rewards.tolerance(physics.named.data.subtree_angmom['torso'][1], # physics.angmomentum(),
+ bounds=(self._move_speed, float('inf')) if self._move_speed > 0 else (-float('inf'), self._move_speed),
+ margin=abs(self._move_speed)/2,
+ value_at_margin=0.5,
+ sigmoid='linear')
+ return upside_down_reward * (5*move_reward + 1) / 6
+
+ def get_reward(self, physics):
+ if self._goal == 'arabesque':
+ return self._arabesque_reward(physics)
+ elif self._goal == 'lying_down':
+ return self._lying_down_reward(physics)
+ elif self._goal == 'legs_up':
+ return self._legs_up_reward(physics)
+ elif self._goal == 'flip':
+ return self._flip_reward(physics)
+ elif self._goal == 'flipping':
+ self._move_speed = abs(self._move_speed)
+ pos_rew = self._flip_reward(physics)
+ self._move_speed = -abs(self._move_speed)
+ neg_rew = self._flip_reward(physics)
+ return max(pos_rew, neg_rew)
+ elif self._goal == 'high_kick':
+ return self._high_kick_reward(physics)
+ elif self._goal == 'one_foot':
+ return self._one_foot_reward(physics)
+ elif self._goal == 'lunge_pose':
+ return self._lunge_pose_reward(physics)
+ elif self._goal == 'sit_knees':
+ return self._sit_knees_reward(physics)
+ elif self._goal == 'urlb_flip':
+ return self._urlb_flip_reward(physics)
+ else:
+ raise NotImplementedError(f'Goal {self._goal} is not implemented.')
+
+
+if __name__ == '__main__':
+ from dm_control import viewer
+ import numpy as np
+
+ env = sit_knees()
+ env.task.visualize_reward = True
+
+ action_spec = env.action_spec()
+
+ def zero_policy(time_step):
+ print(time_step.reward)
+ return np.zeros(action_spec.shape)
+ viewer.launch(env, policy=zero_policy)
+
+ # obs = env.reset()
+ # next_obs, reward, done, info = env.step(np.zeros(6))
\ No newline at end of file
diff --git a/envs/custom_dmc_tasks/walker.xml b/envs/custom_dmc_tasks/walker.xml
new file mode 100644
index 0000000000000000000000000000000000000000..6c23ded5503365cefd8c1f32743813b5eec5f1ed
--- /dev/null
+++ b/envs/custom_dmc_tasks/walker.xml
@@ -0,0 +1,71 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/envs/kitchen_extra.py b/envs/kitchen_extra.py
new file mode 100644
index 0000000000000000000000000000000000000000..e16f57bf650e7e42619b70537346014a2bdfa38e
--- /dev/null
+++ b/envs/kitchen_extra.py
@@ -0,0 +1,299 @@
+"""Environments using kitchen and Franka robot."""
+import logging
+import sys
+from pathlib import Path
+sys.path.append((Path(__file__).parent.parent / 'third_party' / 'relay-policy-learning' / 'adept_envs').__str__())
+import adept_envs
+from adept_envs.franka.kitchen_multitask_v0 import KitchenTaskRelaxV1
+import os
+import numpy as np
+from dm_control.mujoco import engine
+
+OBS_ELEMENT_INDICES = {
+ "bottom burner": np.array([11, 12]),
+ "top burner": np.array([15, 16]),
+ "light switch": np.array([17, 18]),
+ "slide cabinet": np.array([19]),
+ "hinge cabinet": np.array([20, 21]),
+ "microwave": np.array([22]),
+ "kettle": np.array([23, 24, 25, 26, 27, 28, 29]),
+}
+OBS_ELEMENT_GOALS = {
+ "bottom burner": np.array([-0.88, -0.01]),
+ "top burner": np.array([-0.92, -0.01]),
+ "light switch": np.array([-0.69, -0.05]),
+ "slide cabinet": np.array([0.37]),
+ "hinge cabinet": np.array([0.0, 1.45]),
+ "microwave": np.array([-0.75]),
+ "kettle": np.array([-0.23, 0.75, 1.62, 0.99, 0.0, 0.0, -0.06]),
+}
+BONUS_THRESH = 0.3
+
+logging.basicConfig(
+ level="INFO",
+ format="%(asctime)s [%(levelname)s] %(message)s",
+ filemode="w",
+)
+logger = logging.getLogger()
+
+XPOS_NAMES = {
+ "light switch" : "lightswitchroot",
+ "slide cabinet" : "slidelink",
+ "microwave" : "microdoorroot",
+ "kettle" : "kettle",
+}
+
+class KitchenBase(KitchenTaskRelaxV1):
+ # A string of element names. The robot's task is then to modify each of
+ # these elements appropriately.
+ TASK_ELEMENTS = []
+ ALL_TASKS = [
+ "bottom burner",
+ "top burner",
+ "light switch",
+ "slide cabinet",
+ "hinge cabinet",
+ "microwave",
+ "kettle",
+ ]
+ REMOVE_TASKS_WHEN_COMPLETE = True
+ TERMINATE_ON_TASK_COMPLETE = True
+ TERMINATE_ON_WRONG_COMPLETE = False
+ COMPLETE_IN_ANY_ORDER = (
+ True # This allows for the tasks to be completed in arbitrary order.
+ )
+ GRIPPER_DISTANCE_REW = False
+
+ def __init__(
+ self, dense=True, dataset_url=None, ref_max_score=None, ref_min_score=None, **kwargs
+ ):
+ self.tasks_to_complete = list(self.TASK_ELEMENTS)
+ self.goal_masking = True
+ self.dense = dense
+ self.use_grasp_rewards = False
+
+ super(KitchenBase, self).__init__(**kwargs)
+
+ def set_goal_masking(self, goal_masking=True):
+ """Sets goal masking for goal-conditioned approaches (like RPL)."""
+ self.goal_masking = goal_masking
+
+ def _get_task_goal(self, task=None, actually_return_goal=False):
+ if task is None:
+ task = ["microwave", "kettle", "bottom burner", "light switch"]
+ new_goal = np.zeros_like(self.goal)
+ if self.goal_masking and not actually_return_goal:
+ return new_goal
+ for element in task:
+ element_idx = OBS_ELEMENT_INDICES[element]
+ element_goal = OBS_ELEMENT_GOALS[element]
+ new_goal[element_idx] = element_goal
+
+ return new_goal
+
+ def reset_model(self):
+ self.tasks_to_complete = list(self.TASK_ELEMENTS)
+ return super(KitchenBase, self).reset_model()
+
+ def _get_reward_n_score(self, obs_dict):
+ reward_dict, score = super(KitchenBase, self)._get_reward_n_score(obs_dict)
+ next_q_obs = obs_dict["qp"]
+ next_obj_obs = obs_dict["obj_qp"]
+ idx_offset = len(next_q_obs)
+ completions = []
+ dense = 0
+ if self.GRIPPER_DISTANCE_REW:
+ assert len(self.tasks_to_complete) == 1
+ element = next(iter(self.tasks_to_complete))
+ gripper_pos = (self.sim.named.data.xpos['panda0_leftfinger'] + self.sim.named.data.xpos['panda0_rightfinger']) / 2
+ object_pos = self.sim.named.data.xpos[XPOS_NAMES[element]]
+ gripper_obj_dist = np.linalg.norm(object_pos - gripper_pos)
+ if self.dense:
+ reward_dict["bonus"] = -gripper_obj_dist
+ reward_dict["r_total"] = -gripper_obj_dist
+ score = -gripper_obj_dist
+ else:
+ reward_dict["bonus"] = gripper_obj_dist < 0.15
+ reward_dict["r_total"] = gripper_obj_dist < 0.15
+ score = gripper_obj_dist < 0.15
+ return reward_dict, score
+ for element in self.tasks_to_complete:
+ element_idx = OBS_ELEMENT_INDICES[element]
+ distance = np.linalg.norm(
+ next_obj_obs[..., element_idx - idx_offset] - OBS_ELEMENT_GOALS[element]
+ )
+ dense += -1 * distance # reward must be negative distance for RL
+ is_grasped = True
+ if not self.initializing and self.use_grasp_rewards:
+ if element == "slide cabinet":
+ is_grasped = False
+ for i in range(1, 6):
+ obj_pos = self.get_site_xpos("schandle{}".format(i))
+ left_pad = self.get_site_xpos("leftpad")
+ right_pad = self.get_site_xpos("rightpad")
+ within_sphere_left = np.linalg.norm(obj_pos - left_pad) < 0.07
+ within_sphere_right = np.linalg.norm(obj_pos - right_pad) < 0.07
+ right = right_pad[0] < obj_pos[0]
+ left = obj_pos[0] < left_pad[0]
+ if (
+ right
+ and left
+ and within_sphere_right
+ and within_sphere_left
+ ):
+ is_grasped = True
+ if element == "top left burner":
+ is_grasped = False
+ obj_pos = self.get_site_xpos("tlbhandle")
+ left_pad = self.get_site_xpos("leftpad")
+ right_pad = self.get_site_xpos("rightpad")
+ within_sphere_left = np.linalg.norm(obj_pos - left_pad) < 0.035
+ within_sphere_right = np.linalg.norm(obj_pos - right_pad) < 0.04
+ right = right_pad[0] < obj_pos[0]
+ left = obj_pos[0] < left_pad[0]
+ if within_sphere_right and within_sphere_left and right and left:
+ is_grasped = True
+ if element == "microwave":
+ is_grasped = False
+ for i in range(1, 6):
+ obj_pos = self.get_site_xpos("mchandle{}".format(i))
+ left_pad = self.get_site_xpos("leftpad")
+ right_pad = self.get_site_xpos("rightpad")
+ within_sphere_left = np.linalg.norm(obj_pos - left_pad) < 0.05
+ within_sphere_right = np.linalg.norm(obj_pos - right_pad) < 0.05
+ if (
+ right_pad[0] < obj_pos[0]
+ and obj_pos[0] < left_pad[0]
+ and within_sphere_right
+ and within_sphere_left
+ ):
+ is_grasped = True
+ if element == "hinge cabinet":
+ is_grasped = False
+ for i in range(1, 6):
+ obj_pos = self.get_site_xpos("hchandle{}".format(i))
+ left_pad = self.get_site_xpos("leftpad")
+ right_pad = self.get_site_xpos("rightpad")
+ within_sphere_left = np.linalg.norm(obj_pos - left_pad) < 0.06
+ within_sphere_right = np.linalg.norm(obj_pos - right_pad) < 0.06
+ if (
+ right_pad[0] < obj_pos[0]
+ and obj_pos[0] < left_pad[0]
+ and within_sphere_right
+ ):
+ is_grasped = True
+ if element == "light switch":
+ is_grasped = False
+ for i in range(1, 4):
+ obj_pos = self.get_site_xpos("lshandle{}".format(i))
+ left_pad = self.get_site_xpos("leftpad")
+ right_pad = self.get_site_xpos("rightpad")
+ within_sphere_left = np.linalg.norm(obj_pos - left_pad) < 0.045
+ within_sphere_right = np.linalg.norm(obj_pos - right_pad) < 0.03
+ if within_sphere_right and within_sphere_left:
+ is_grasped = True
+ complete = distance < BONUS_THRESH # and is_grasped
+ if complete:
+ completions.append(element)
+ if self.REMOVE_TASKS_WHEN_COMPLETE:
+ [self.tasks_to_complete.remove(element) for element in completions]
+ bonus = float(len(completions))
+ reward_dict["bonus"] = bonus
+ reward_dict["r_total"] = bonus
+ if self.dense:
+ reward_dict["r_total"] = dense
+ score = bonus
+ return reward_dict, score
+
+ def step(self, a, b=None):
+ obs, reward, done, env_info = super(KitchenBase, self).step(a, b=b)
+ if self.TERMINATE_ON_TASK_COMPLETE:
+ done = not self.tasks_to_complete
+ if self.TERMINATE_ON_WRONG_COMPLETE:
+ all_goal = self._get_task_goal(task=self.ALL_TASKS)
+ for wrong_task in list(set(self.ALL_TASKS) - set(self.TASK_ELEMENTS)):
+ element_idx = OBS_ELEMENT_INDICES[wrong_task]
+ distance = np.linalg.norm(obs[..., element_idx] - all_goal[element_idx])
+ complete = distance < BONUS_THRESH
+ if complete:
+ done = True
+ break
+ env_info["completed_tasks"] = set(self.TASK_ELEMENTS) - set(
+ self.tasks_to_complete
+ )
+ return obs, reward, done, env_info
+
+ def get_goal(self):
+ """Loads goal state from dataset for goal-conditioned approaches (like RPL)."""
+ raise NotImplementedError
+
+ def _split_data_into_seqs(self, data):
+ """Splits dataset object into list of sequence dicts."""
+ seq_end_idxs = np.where(data["terminals"])[0]
+ start = 0
+ seqs = []
+ for end_idx in seq_end_idxs:
+ seqs.append(
+ dict(
+ states=data["observations"][start : end_idx + 1],
+ actions=data["actions"][start : end_idx + 1],
+ )
+ )
+ start = end_idx + 1
+ return seqs
+
+ def render(self, mode='rgb_array', resolution=(64,64)):
+ if mode =='rgb_array':
+ camera = engine.MovableCamera(self.sim, *resolution)
+ camera.set_pose(distance=2.2, lookat=[-0.2, .5, 2.], azimuth=70, elevation=-35)
+ img = camera.render()
+ return img
+ else:
+ super(KitchenTaskRelaxV1, self).render()
+
+
+class KitchenSlideV0(KitchenBase):
+ TASK_ELEMENTS = ["slide cabinet",]
+ COMPLETE_IN_ANY_ORDER = False
+
+class KitchenHingeV0(KitchenBase):
+ TASK_ELEMENTS = ["hinge cabinet",]
+ COMPLETE_IN_ANY_ORDER = False
+
+class KitchenLightV0(KitchenBase):
+ TASK_ELEMENTS = ["light switch",]
+ COMPLETE_IN_ANY_ORDER = False
+
+class KitchenKettleV0(KitchenBase):
+ TASK_ELEMENTS = ["kettle",]
+ COMPLETE_IN_ANY_ORDER = False
+
+class KitchenMicrowaveV0(KitchenBase):
+ TASK_ELEMENTS = ["microwave",]
+ COMPLETE_IN_ANY_ORDER = False
+
+class KitchenBurnerV0(KitchenBase):
+ TASK_ELEMENTS = ["bottom burner",]
+ COMPLETE_IN_ANY_ORDER = False
+
+class KitchenTopBurnerV0(KitchenBase):
+ TASK_ELEMENTS = ["top burner",]
+ COMPLETE_IN_ANY_ORDER = False
+
+class KitchenMicrowaveKettleBottomBurnerLightV0(KitchenBase):
+ TASK_ELEMENTS = ["microwave", "kettle", "bottom burner", "light switch"]
+ COMPLETE_IN_ANY_ORDER = False
+
+
+class KitchenMicrowaveKettleLightSliderV0(KitchenBase):
+ TASK_ELEMENTS = ["microwave", "kettle", "light switch", "slide cabinet"]
+ COMPLETE_IN_ANY_ORDER = False
+
+
+class KitchenKettleMicrowaveLightSliderV0(KitchenBase):
+ TASK_ELEMENTS = ["kettle", "microwave", "light switch", "slide cabinet"]
+ COMPLETE_IN_ANY_ORDER = False
+
+
+class KitchenAllV0(KitchenBase):
+ TASK_ELEMENTS = KitchenBase.ALL_TASKS
\ No newline at end of file
diff --git a/envs/main.py b/envs/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..c873d66d26315f92ba6196f3ccd095539267d457
--- /dev/null
+++ b/envs/main.py
@@ -0,0 +1,743 @@
+from collections import OrderedDict, deque
+from typing import Any, NamedTuple
+import os
+
+import dm_env
+import numpy as np
+from dm_env import StepType, specs
+
+import gym
+import torch
+
+class ExtendedTimeStep(NamedTuple):
+ step_type: Any
+ reward: Any
+ discount: Any
+ observation: Any
+ action: Any
+
+ def first(self):
+ return self.step_type == StepType.FIRST
+
+ def mid(self):
+ return self.step_type == StepType.MID
+
+ def last(self):
+ return self.step_type == StepType.LAST
+
+ def __getitem__(self, attr):
+ return getattr(self, attr)
+
+
+class FlattenJacoObservationWrapper(dm_env.Environment):
+ def __init__(self, env):
+ self._env = env
+ self._obs_spec = OrderedDict()
+ wrapped_obs_spec = env.observation_spec().copy()
+ if 'front_close' in wrapped_obs_spec:
+ spec = wrapped_obs_spec['front_close']
+ # drop batch dim
+ self._obs_spec['pixels'] = specs.BoundedArray(shape=spec.shape[1:],
+ dtype=spec.dtype,
+ minimum=spec.minimum,
+ maximum=spec.maximum,
+ name='pixels')
+ wrapped_obs_spec.pop('front_close')
+
+ for key, spec in wrapped_obs_spec.items():
+ assert spec.dtype == np.float64
+ assert type(spec) == specs.Array
+ dim = np.sum(
+ np.fromiter((int(np.prod(spec.shape))
+ for spec in wrapped_obs_spec.values()), np.int32))
+
+ self._obs_spec['observations'] = specs.Array(shape=(dim,),
+ dtype=np.float32,
+ name='observations')
+
+ def _transform_observation(self, time_step):
+ obs = OrderedDict()
+
+ if 'front_close' in time_step.observation:
+ pixels = time_step.observation['front_close']
+ time_step.observation.pop('front_close')
+ pixels = np.squeeze(pixels)
+ obs['pixels'] = pixels
+
+ features = []
+ for feature in time_step.observation.values():
+ features.append(feature.ravel())
+ obs['observations'] = np.concatenate(features, axis=0)
+ return time_step._replace(observation=obs)
+
+ def reset(self):
+ time_step = self._env.reset()
+ return self._transform_observation(time_step)
+
+ def step(self, action):
+ time_step = self._env.step(action)
+ return self._transform_observation(time_step)
+
+ def observation_spec(self):
+ return self._obs_spec
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
+
+
+class ActionRepeatWrapper(dm_env.Environment):
+ def __init__(self, env, num_repeats):
+ self._env = env
+ self._num_repeats = num_repeats
+
+ def step(self, action):
+ reward = 0.0
+ discount = 1.0
+ for i in range(self._num_repeats):
+ time_step = self._env.step(action)
+ reward += (time_step.reward or 0.0) * discount
+ discount *= time_step.discount
+ if time_step.last():
+ break
+
+ return time_step._replace(reward=reward, discount=discount)
+
+ def observation_spec(self):
+ return self._env.observation_spec()
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def reset(self):
+ return self._env.reset()
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
+
+
+class FramesWrapper(dm_env.Environment):
+ def __init__(self, env, num_frames=1, pixels_key='pixels'):
+ self._env = env
+ self._num_frames = num_frames
+ self._frames = deque([], maxlen=num_frames)
+ self._pixels_key = pixels_key
+
+ wrapped_obs_spec = env.observation_spec()
+ assert pixels_key in wrapped_obs_spec
+
+ pixels_shape = wrapped_obs_spec[pixels_key].shape
+ # remove batch dim
+ if len(pixels_shape) == 4:
+ pixels_shape = pixels_shape[1:]
+ self._obs_spec = specs.BoundedArray(shape=np.concatenate(
+ [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0),
+ dtype=np.uint8,
+ minimum=0,
+ maximum=255,
+ name='observation')
+
+ def _transform_observation(self, time_step):
+ assert len(self._frames) == self._num_frames
+ obs = np.concatenate(list(self._frames), axis=0)
+ return time_step._replace(observation=obs)
+
+ def _extract_pixels(self, time_step):
+ pixels = time_step.observation[self._pixels_key]
+ # remove batch dim
+ if len(pixels.shape) == 4:
+ pixels = pixels[0]
+ return pixels.transpose(2, 0, 1).copy()
+
+ def reset(self):
+ time_step = self._env.reset()
+ pixels = self._extract_pixels(time_step)
+ for _ in range(self._num_frames):
+ self._frames.append(pixels)
+ return self._transform_observation(time_step)
+
+ def step(self, action):
+ time_step = self._env.step(action)
+ pixels = self._extract_pixels(time_step)
+ self._frames.append(pixels)
+ return self._transform_observation(time_step)
+
+ def observation_spec(self):
+ return self._obs_spec
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
+
+class OneHotAction(gym.Wrapper):
+ def __init__(self, env):
+ assert isinstance(env.action_space, gym.spaces.Discrete)
+ super().__init__(env)
+ self._random = np.random.RandomState()
+ shape = (self.env.action_space.n,)
+ space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
+ space.discrete = True
+ self.action_space = space
+
+ def step(self, action):
+ index = np.argmax(action).astype(int)
+ reference = np.zeros_like(action)
+ reference[index] = 1
+ if not np.allclose(reference, action):
+ raise ValueError(f"Invalid one-hot action:\n{action}")
+ return self.env.step(index)
+
+ def reset(self):
+ return self.env.reset()
+
+ def _sample_action(self):
+ actions = self.env.action_space.n
+ index = self._random.randint(0, actions)
+ reference = np.zeros(actions, dtype=np.float32)
+ reference[index] = 1.0
+ return reference
+
+class ActionDTypeWrapper(dm_env.Environment):
+ def __init__(self, env, dtype):
+ self._env = env
+ wrapped_action_spec = env.action_spec()
+ self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
+ dtype,
+ wrapped_action_spec.minimum,
+ wrapped_action_spec.maximum,
+ 'action')
+
+ def step(self, action):
+ action = action.astype(self._env.action_spec().dtype)
+ return self._env.step(action)
+
+ def observation_spec(self):
+ return self._env.observation_spec()
+
+ def action_spec(self):
+ return self._action_spec
+
+ def reset(self):
+ return self._env.reset()
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
+
+
+class ObservationDTypeWrapper(dm_env.Environment):
+ def __init__(self, env, dtype):
+ self._env = env
+ self._dtype = dtype
+ wrapped_obs_spec = env.observation_spec()['observations']
+ self._obs_spec = specs.Array(wrapped_obs_spec.shape, dtype,
+ 'observation')
+
+ def _transform_observation(self, time_step):
+ obs = time_step.observation['observations'].astype(self._dtype)
+ return time_step._replace(observation=obs)
+
+ def reset(self):
+ time_step = self._env.reset()
+ return self._transform_observation(time_step)
+
+ def step(self, action):
+ time_step = self._env.step(action)
+ return self._transform_observation(time_step)
+
+ def observation_spec(self):
+ return self._obs_spec
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
+
+
+class ExtendedTimeStepWrapper(dm_env.Environment):
+ def __init__(self, env):
+ self._env = env
+
+ def reset(self):
+ time_step = self._env.reset()
+ return self._augment_time_step(time_step)
+
+ def step(self, action):
+ time_step = self._env.step(action)
+ return self._augment_time_step(time_step, action)
+
+ def _augment_time_step(self, time_step, action=None):
+ if action is None:
+ action_spec = self.action_spec()
+ action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
+ return ExtendedTimeStep(observation=time_step.observation,
+ step_type=time_step.step_type,
+ action=action,
+ reward=time_step.reward or 0.0,
+ discount=time_step.discount or 1.0)
+
+ def observation_spec(self):
+ return self._env.observation_spec()
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
+
+class DMC:
+ def __init__(self, env):
+ self._env = env
+ self._ignored_keys = []
+
+ def step(self, action):
+ time_step = self._env.step(action)
+ assert time_step.discount in (0, 1)
+ obs = {
+ 'reward': time_step.reward,
+ 'is_first': False,
+ 'is_last': time_step.last(),
+ 'is_terminal': time_step.discount == 0,
+ 'observation': time_step.observation,
+ 'action' : action,
+ 'discount': time_step.discount
+ }
+ return time_step, obs
+
+ def reset(self):
+ time_step = self._env.reset()
+ obs = {
+ 'reward': 0.0,
+ 'is_first': True,
+ 'is_last': False,
+ 'is_terminal': False,
+ 'observation': time_step.observation,
+ 'action' : np.zeros_like(self.act_space['action'].sample()),
+ 'discount': time_step.discount
+ }
+ return time_step, obs
+
+ def __getattr__(self, name):
+ if name == 'obs_space':
+ obs_spaces = {
+ 'observation': self._env.observation_spec(),
+ 'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
+ 'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
+ 'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
+ }
+ return obs_spaces
+ if name == 'act_space':
+ spec = self._env.action_spec()
+ action = gym.spaces.Box((spec.minimum)*spec.shape[0], (spec.maximum)*spec.shape[0], shape=spec.shape, dtype=np.float32)
+ act_space = {'action': action}
+ return act_space
+ return getattr(self._env, name)
+
+
+class OneHotAction(gym.Wrapper):
+ def __init__(self, env):
+ assert isinstance(env.action_space, gym.spaces.Discrete)
+ super().__init__(env)
+ self._random = np.random.RandomState()
+ shape = (self.env.action_space.n,)
+ space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
+ space.discrete = True
+ self.action_space = space
+
+ def step(self, action):
+ index = np.argmax(action).astype(int)
+ reference = np.zeros_like(action)
+ reference[index] = 1
+ if not np.allclose(reference, action):
+ raise ValueError(f"Invalid one-hot action:\n{action}")
+ return self.env.step(index)
+
+ def reset(self):
+ return self.env.reset()
+
+ def _sample_action(self):
+ actions = self.env.action_space.n
+ index = self._random.randint(0, actions)
+ reference = np.zeros(actions, dtype=np.float32)
+ reference[index] = 1.0
+ return reference
+
+class KitchenWrapper:
+ def __init__(
+ self,
+ name,
+ seed=0,
+ action_repeat=1,
+ size=(64, 64),
+ ):
+ import envs.kitchen_extra as kitchen_extra
+ self._env = {
+ 'microwave' : kitchen_extra.KitchenMicrowaveV0,
+ 'kettle' : kitchen_extra.KitchenKettleV0,
+ 'burner' : kitchen_extra.KitchenBurnerV0,
+ 'light' : kitchen_extra.KitchenLightV0,
+ 'hinge' : kitchen_extra.KitchenHingeV0,
+ 'slide' : kitchen_extra.KitchenSlideV0,
+ 'top_burner' : kitchen_extra.KitchenTopBurnerV0,
+ }[name]()
+
+ self._size = size
+ self._action_repeat = action_repeat
+ self._seed = seed
+ self._eval = False
+
+ def eval_mode(self,):
+ self._env.dense = False
+ self._eval = True
+
+ @property
+ def obs_space(self):
+ spaces = {
+ "observation": gym.spaces.Box(0, 255, (3,) + self._size, dtype=np.uint8),
+ "is_first": gym.spaces.Box(0, 1, (), dtype=bool),
+ "is_last": gym.spaces.Box(0, 1, (), dtype=bool),
+ "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
+ "state": self._env.observation_space,
+ }
+ return spaces
+
+ @property
+ def act_space(self):
+ action = self._env.action_space
+ return {"action": action}
+
+ def step(self, action):
+ # assert np.isfinite(action["action"]).all(), action["action"]
+ reward = 0.0
+ for _ in range(self._action_repeat):
+ state, rew, done, info = self._env.step(action.copy())
+ reward += rew
+ obs = {
+ "reward": reward,
+ "is_first": False,
+ "is_last": False, # will be handled by timelimit wrapper
+ "is_terminal": False, # will be handled by per_episode function
+ "observation": info['images'].transpose(2, 0, 1).copy(),
+ "state": state.astype(np.float32),
+ 'action' : action,
+ 'discount' : 1
+ }
+ if self._eval:
+ obs['reward'] = min(obs['reward'], 1)
+ if obs['reward'] > 0:
+ obs['is_last'] = True
+ return dm_env.TimeStep(
+ step_type=dm_env.StepType.MID if not obs['is_last'] else dm_env.StepType.LAST,
+ reward=obs['reward'],
+ discount=1,
+ observation=obs['observation']), obs
+
+ def reset(self,):
+ state = self._env.reset()
+ obs = {
+ "reward": 0.0,
+ "is_first": True,
+ "is_last": False,
+ "is_terminal": False,
+ "observation": self.get_visual_obs(self._size),
+ "state": state.astype(np.float32),
+ 'action' : np.zeros_like(self.act_space['action'].sample()),
+ 'discount' : 1
+ }
+ return dm_env.TimeStep(
+ step_type=dm_env.StepType.FIRST,
+ reward=None,
+ discount=None,
+ observation=obs['observation']), obs
+
+ def __getattr__(self, name):
+ if name == 'obs_space':
+ return self.obs_space
+ if name == 'act_space':
+ return self.act_space
+ return getattr(self._env, name)
+
+ def get_visual_obs(self, resolution):
+ img = self._env.render(resolution=resolution,).transpose(2, 0, 1).copy()
+ return img
+
+class ViClipWrapper:
+ def __init__(self, env, hd_rendering=False, device='cuda'):
+ self._env = env
+ try:
+ from tools.genrl_utils import viclip_global_instance
+ except:
+ from tools.genrl_utils import ViCLIPGlobalInstance
+ viclip_global_instance = ViCLIPGlobalInstance()
+
+ if not viclip_global_instance._instantiated:
+ viclip_global_instance.instantiate(device)
+ self.viclip_model = viclip_global_instance.viclip
+ self.n_frames = self.viclip_model.n_frames
+ self.viclip_emb_dim = viclip_global_instance.viclip_emb_dim
+ self.n_frames = self.viclip_model.n_frames
+ self.buffer = deque(maxlen=self.n_frames)
+ # NOTE: these are hardcoded for now, as they are the best settings
+ self.accumulate = True
+ self.accumulate_buffer = []
+ self.anticipate_conv1 = False
+ self.hd_rendering = hd_rendering
+
+ def hd_render(self, obs):
+ if not self.hd_rendering:
+ return obs['observation']
+ if self._env._domain_name in ['mw', 'kitchen', 'mujoco']:
+ return self.get_visual_obs((224,224,))
+ else:
+ render_kwargs = {**getattr(self, '_render_kwargs', {})}
+ render_kwargs.update({'width' : 224, 'height' : 224})
+ return self._env.physics.render(**render_kwargs).transpose(2,0,1)
+
+ def preprocess(self, x):
+ return x
+
+ def process_accumulate(self, process_at_once=4): # NOTE: this could be varied for increasing FPS, depending on the size of the GPU
+ self.accumulate = False
+ x = np.stack(self.accumulate_buffer, axis=0)
+ # Splitting in chunks
+ chunks = []
+ chunk_idxs = list(range(0, x.shape[0] + 1, process_at_once))
+ if chunk_idxs[-1] != x.shape[0]:
+ chunk_idxs.append(x.shape[0])
+ start = 0
+ for end in chunk_idxs[1:]:
+ embeds = self.clip_process(x[start:end], bypass=True)
+ chunks.append(embeds.cpu())
+ start = end
+ embeds = torch.cat(chunks, dim=0)
+ assert embeds.shape[0] == len(self.accumulate_buffer)
+ self.accumulate = True
+ self.accumulate_buffer = []
+ return [*embeds.cpu().numpy()], 'clip_video'
+
+ def process_episode(self, obs, process_at_once=8):
+ self.accumulate = False
+ sequences = []
+ for j in range(obs.shape[0] - self.n_frames + 1):
+ sequences.append(obs[j:j+self.n_frames].copy())
+ sequences = np.stack(sequences, axis=0)
+
+ idx_start = 0
+ clip_vid = []
+ for idx_end in range(process_at_once, sequences.shape[0] + process_at_once, process_at_once):
+ x = sequences[idx_start:idx_end]
+ with torch.no_grad(): # , torch.cuda.amp.autocast():
+ x = self.clip_process(x, bypass=True)
+ clip_vid.append(x)
+ idx_start = idx_end
+ if len(clip_vid) == 1: # process all at once
+ embeds = clip_vid[0]
+ else:
+ embeds = torch.cat(clip_vid, dim=0)
+ pad = torch.zeros( (self.n_frames - 1, *embeds.shape[1:]), device=embeds.device, dtype=embeds.dtype)
+ embeds = torch.cat([pad, embeds], dim=0)
+ assert embeds.shape[0] == obs.shape[0], f"Shapes are different {embeds.shape[0]} {obs.shape[0]}"
+ return embeds.cpu().numpy()
+
+ def get_sequence(self,):
+ return np.expand_dims(np.stack(self.buffer, axis=0), axis=0)
+
+ def clip_process(self, x, bypass=False):
+ if len(self.buffer) == self.n_frames or bypass:
+ if self.accumulate:
+ self.accumulate_buffer.append(self.preprocess(x)[0])
+ return torch.zeros(self.viclip_emb_dim)
+ with torch.no_grad():
+ B, n_frames, C, H, W = x.shape
+ obs = torch.from_numpy(x.copy().reshape(B * n_frames, C, H, W)).to(self.viclip_model.device)
+ processed_obs = self.viclip_model.preprocess_transf(obs / 255)
+ reshaped_obs = processed_obs.reshape(B, n_frames, 3,processed_obs.shape[-2],processed_obs.shape[-1])
+ video_embed = self.viclip_model.get_vid_features(reshaped_obs)
+ return video_embed.detach()
+ else:
+ return torch.zeros(self.viclip_emb_dim)
+
+ def step(self, action):
+ ts, obs = self._env.step(action)
+ self.buffer.append(self.hd_render(obs))
+ obs['clip_video'] = self.clip_process(self.get_sequence()).cpu().numpy()
+ return ts, obs
+
+ def reset(self,):
+ # Important to reset the buffer
+ self.buffer = deque(maxlen=self.n_frames)
+
+ ts, obs = self._env.reset()
+ self.buffer.append(self.hd_render(obs))
+ obs['clip_video'] = self.clip_process(self.get_sequence()).cpu().numpy()
+ return ts, obs
+
+ def __getattr__(self, name):
+ if name == 'obs_space':
+ space = self._env.obs_space
+ space['clip_video'] = gym.spaces.Box(-np.inf, np.inf, (self.viclip_emb_dim,), dtype=np.float32)
+ return space
+ return getattr(self._env, name)
+
+class TimeLimit:
+
+ def __init__(self, env, duration):
+ self._env = env
+ self._duration = duration
+ self._step = None
+
+ def __getattr__(self, name):
+ if name.startswith('__'):
+ raise AttributeError(name)
+ return getattr(self._env, name)
+
+ def step(self, action):
+ assert self._step is not None, 'Must reset environment.'
+ ts, obs = self._env.step(action)
+ self._step += 1
+ if self._duration and self._step >= self._duration:
+ ts = dm_env.TimeStep(dm_env.StepType.LAST, ts.reward, ts.discount, ts.observation)
+ obs['is_last'] = True
+ self._step = None
+ return ts, obs
+
+ def reset(self):
+ self._step = 0
+ return self._env.reset()
+
+ def reset_with_task_id(self, task_id):
+ self._step = 0
+ return self._env.reset_with_task_id(task_id)
+
+class ClipActionWrapper:
+
+ def __init__(self, env, low=-1.0, high=1.0):
+ self._env = env
+ self._low = low
+ self._high = high
+
+ def __getattr__(self, name):
+ if name.startswith('__'):
+ raise AttributeError(name)
+ return getattr(self._env, name)
+
+ def step(self, action):
+ clipped_action = np.clip(action, self._low, self._high)
+ return self._env.step(clipped_action)
+
+ def reset(self):
+ self._step = 0
+ return self._env.reset()
+
+ def reset_with_task_id(self, task_id):
+ self._step = 0
+ return self._env.reset_with_task_id(task_id)
+
+class NormalizeAction:
+
+ def __init__(self, env, key='action'):
+ self._env = env
+ self._key = key
+ space = env.act_space[key]
+ self._mask = np.isfinite(space.low) & np.isfinite(space.high)
+ self._low = np.where(self._mask, space.low, -1)
+ self._high = np.where(self._mask, space.high, 1)
+
+ def __getattr__(self, name):
+ if name.startswith('__'):
+ raise AttributeError(name)
+ try:
+ return getattr(self._env, name)
+ except AttributeError:
+ raise ValueError(name)
+
+ @property
+ def act_space(self):
+ low = np.where(self._mask, -np.ones_like(self._low), self._low)
+ high = np.where(self._mask, np.ones_like(self._low), self._high)
+ space = gym.spaces.Box(low, high, dtype=np.float32)
+ return {**self._env.act_space, self._key: space}
+
+ def step(self, action):
+ orig = (action[self._key] + 1) / 2 * (self._high - self._low) + self._low
+ orig = np.where(self._mask, orig, action[self._key])
+ return self._env.step({**action, self._key: orig})
+
+def _make_jaco(obs_type, domain, task, action_repeat, seed, img_size,):
+ import envs.custom_dmc_tasks as cdmc
+ env = cdmc.make_jaco(task, obs_type, seed, img_size,)
+ env = ActionDTypeWrapper(env, np.float32)
+ env = ActionRepeatWrapper(env, action_repeat)
+ env = FlattenJacoObservationWrapper(env)
+ env._size = (img_size, img_size)
+ return env
+
+
+def _make_dmc(obs_type, domain, task, action_repeat, seed, img_size,):
+ visualize_reward = False
+ from dm_control import manipulation, suite
+ import envs.custom_dmc_tasks as cdmc
+
+ if (domain, task) in suite.ALL_TASKS:
+ env = suite.load(domain,
+ task,
+ task_kwargs=dict(random=seed),
+ environment_kwargs=dict(flat_observation=True),
+ visualize_reward=visualize_reward)
+ else:
+ env = cdmc.make(domain,
+ task,
+ task_kwargs=dict(random=seed),
+ environment_kwargs=dict(flat_observation=True),
+ visualize_reward=visualize_reward)
+ env = ActionDTypeWrapper(env, np.float32)
+ env = ActionRepeatWrapper(env, action_repeat)
+ if obs_type == 'pixels':
+ from dm_control.suite.wrappers import pixels
+ # zoom in camera for quadruped
+ camera_id = dict(locom_rodent=1,quadruped=2).get(domain, 0)
+ render_kwargs = dict(height=img_size, width=img_size, camera_id=camera_id)
+ env = pixels.Wrapper(env,
+ pixels_only=True,
+ render_kwargs=render_kwargs)
+ env._size = (img_size, img_size)
+ env._camera = camera_id
+ return env
+
+
+def make(name, obs_type, action_repeat, seed, img_size=64, viclip_encode=False, clip_hd_rendering=False, device='cuda'):
+ assert obs_type in ['states', 'pixels']
+ domain, task = name.split('_', 1)
+ if domain == 'kitchen':
+ env = TimeLimit(KitchenWrapper(task, seed=seed, action_repeat=action_repeat, size=(img_size,img_size)), 280 // action_repeat)
+ else:
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
+ os.environ['MUJOCO_GL'] = 'egl'
+
+ domain = dict(cup='ball_in_cup', point='point_mass').get(domain, domain)
+
+ make_fn = _make_jaco if domain == 'jaco' else _make_dmc
+ env = make_fn(obs_type, domain, task, action_repeat, seed, img_size,)
+
+ if obs_type == 'pixels':
+ env = FramesWrapper(env,)
+ else:
+ env = ObservationDTypeWrapper(env, np.float32)
+
+ from dm_control.suite.wrappers import action_scale
+ env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
+ env = ExtendedTimeStepWrapper(env)
+
+ env = DMC(env)
+ env._domain_name = domain
+
+ if isinstance(env.act_space['action'], gym.spaces.Box):
+ env = ClipActionWrapper(env,)
+
+ if viclip_encode:
+ env = ViClipWrapper(env, hd_rendering=clip_hd_rendering, device=device)
+ return env
diff --git a/notebooks/demo_videoclip.ipynb b/notebooks/demo_videoclip.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..6fcb0937aef833f3f0d2271168ea7b0f73b02e27
--- /dev/null
+++ b/notebooks/demo_videoclip.ipynb
@@ -0,0 +1,124 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# InternVideo 2 demo\n",
+ "\n",
+ "It can be used to test the capabilities of InternVideo2 and to verify that the models are loaded correctly"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "import pathlib\n",
+ "import sys\n",
+ "import os\n",
+ "sys.path.append(str(pathlib.Path(os.path.abspath('')).parent))\n",
+ "\n",
+ "from tools.genrl_utils import viclip_global_instance\n",
+ "viclip_global_instance.instantiate()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import cv2\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from tools.genrl_utils import INTERNVIDEO_PATH\n",
+ "\n",
+ "def _frame_from_video(video):\n",
+ " while video.isOpened():\n",
+ " success, frame = video.read()\n",
+ " if success:\n",
+ " yield frame\n",
+ " else:\n",
+ " break\n",
+ "\n",
+ "ASSET_PATH = pathlib.Path(os.path.abspath('')).parent / 'assets'\n",
+ "\n",
+ "# 83 % - A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run.\n",
+ "video = cv2.VideoCapture( str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality/demo/example1.mp4') )\n",
+ "# # 99 % - A karate kick\n",
+ "# video = cv2.VideoCapture( str( ASSET_PATH / 'video_samples/karate_kick.mp4') ) \n",
+ "# # 99 % - A headstand\n",
+ "# video = cv2.VideoCapture( str( ASSET_PATH / 'video_samples/headstand.mp4') ) \n",
+ "\n",
+ "frames = [x for x in _frame_from_video(video)]\n",
+ "processed_frames = viclip_global_instance.viclip.preprocess_transf(torch.from_numpy(np.stack(frames[:8], axis=0)).permute(0,3,1,2) / 255)\n",
+ "frames_tensor = processed_frames.reshape(1, 8, 3, 224,224)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "text_candidates = [\"A playful dog and its owner wrestle in the snowy yard, chasing each other with joyous abandon.\",\n",
+ " \"A man in a gray coat walks through the snowy landscape, pulling a sleigh loaded with toys.\",\n",
+ " \"A person dressed in a blue jacket shovels the snow-covered pavement outside their house.\",\n",
+ " \"A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner.\",\n",
+ " \"A person stands on the snowy floor, pushing a sled loaded with blankets, preparing for a fun-filled ride.\",\n",
+ " \"A man in a gray hat and coat walks through the snowy yard, carefully navigating around the trees.\",\n",
+ " \"A playful dog slides down a snowy hill, wagging its tail with delight.\",\n",
+ " \"A person in a blue jacket walks their pet on a leash, enjoying a peaceful winter walk among the trees.\",\n",
+ " \"A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run.\",\n",
+ " \"A person bundled up in a blanket walks through the snowy landscape, enjoying the serene winter scenery.\",\n",
+ " \"A person playing with a kid in the street\",\n",
+ " \"A group of friends playing bowling.\",\n",
+ " \"A japanese girl eating noodles\",\n",
+ " \"A painting by Monet\",\n",
+ " \"A karate kick\",\n",
+ " \"A headstand\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "text_feat = viclip_global_instance.viclip.get_txt_feat(text_candidates)\n",
+ "video_feat = viclip_global_instance.viclip.get_vid_features(frames_tensor.to(viclip_global_instance.viclip.device))\n",
+ "\n",
+ "sorted_probs, sorted_idxs = (100.0 * video_feat @ text_feat.T).softmax(dim=-1)[0].topk(len(text_feat))\n",
+ "\n",
+ "for p, i in zip(sorted_probs, sorted_idxs):\n",
+ " if p > 0.01:\n",
+ " print(int(p * 100), '% - ', text_candidates[i])"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/text2video.ipynb b/notebooks/text2video.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..02ffb33836a89d652c32852edaabc5284b083496
--- /dev/null
+++ b/notebooks/text2video.ipynb
@@ -0,0 +1,161 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path \n",
+ "import os\n",
+ "import sys\n",
+ "sys.path.append(str(Path(os.path.abspath('')).parent))\n",
+ "\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.animation as animation\n",
+ "\n",
+ "agent_path = Path(os.path.abspath('')).parent / 'models' / 'genrl_stickman_500k_2.pt'\n",
+ "print(\"Model path\", agent_path)\n",
+ "\n",
+ "agent = torch.load(agent_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tools.genrl_utils import ViCLIPGlobalInstance, DOMAIN2PREDICATES\n",
+ "model_name = getattr(agent.cfg, 'viclip_model', 'viclip')\n",
+ "# Get ViCLIP\n",
+ "if 'viclip_global_instance' not in locals() or model_name != viclip_global_instance._model:\n",
+ " viclip_global_instance = ViCLIPGlobalInstance(model_name)\n",
+ " if not viclip_global_instance._instantiated:\n",
+ " print(\"Instantiating\")\n",
+ " viclip_global_instance.instantiate()\n",
+ " clip = viclip_global_instance.viclip\n",
+ " tokenizer = viclip_global_instance.viclip_tokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "SAVE = True\n",
+ "DENOISE = True\n",
+ "REVERSE = False\n",
+ "REPEAT_TIME = 2 # standard is n_frames for = 1 \n",
+ "TEXT_OVERLAY = True\n",
+ "\n",
+ "domain = agent.cfg.task.split('_')\n",
+ "\n",
+ "labels_list = ['high kick', 'stand up straight', 'doing splits']\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " wm = world_model = agent.wm\n",
+ " connector = agent.wm.connector\n",
+ " decoder = world_model.heads['decoder']\n",
+ " n_frames = connector.n_frames\n",
+ " \n",
+ " # Get text(video) embed\n",
+ " text_feat = []\n",
+ " for text in labels_list:\n",
+ " with torch.no_grad():\n",
+ " text_feat.append(clip.get_txt_feat(text,))\n",
+ " text_feat = torch.stack(text_feat, dim=0).to(clip.device)\n",
+ "\n",
+ " video_embed = text_feat\n",
+ "\n",
+ " B = video_embed.shape[0]\n",
+ " T = 1\n",
+ "\n",
+ " # Get initial state\n",
+ " init = connector.initial(B, init_embed=video_embed)\n",
+ "\n",
+ " # Get actions\n",
+ " video_embed = video_embed.repeat(1,n_frames, 1)\n",
+ " action = wm.connector.get_action(video_embed)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " # Imagine\n",
+ " prior = wm.connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=DENOISE)\n",
+ " # Decode\n",
+ " prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5\n",
+ "\n",
+ " # Plotting video\n",
+ " R = int(np.sqrt(B))\n",
+ " C = min((B + (R-1)) // R, B) \n",
+ "\n",
+ " fig, axes = plt.subplots(R, C, figsize=(3.5 * C, 4 * R))\n",
+ " fig.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)\n",
+ " fig.set_size_inches(4,4)\n",
+ " \n",
+ " if B == 1:\n",
+ " axes = [[axes]]\n",
+ " elif R == 1:\n",
+ " axes = [axes] \n",
+ " axes = [ a for row in axes for a in row]\n",
+ "\n",
+ " file_path = f'temp_text2video.gif'\n",
+ "\n",
+ " if SAVE:\n",
+ " ims = []\n",
+ " for t in range(prior_recon.shape[1]):\n",
+ " if t == 0 :\n",
+ " continue\n",
+ " toadd = []\n",
+ " for b in range(prior_recon.shape[0]):\n",
+ " ax = axes[b]\n",
+ " ax.set_axis_off()\n",
+ " img = np.clip(prior_recon[b, t if not REVERSE else -t].cpu().permute(1,2,0), 0, 1)\n",
+ " frame = ax.imshow(img)\n",
+ " if TEXT_OVERLAY: \n",
+ " test = ax.text(0,5, labels_list[b], color='white')\n",
+ " toadd.append(frame) # add both the image and the text to the list of artists \n",
+ " ims.append(toadd)\n",
+ "\n",
+ " # Save GIFs\n",
+ " anim = animation.ArtistAnimation(fig, ims, interval=700, blit=True, repeat_delay=700)\n",
+ " writer = animation.PillowWriter(fps=15, metadata=dict(artist='Me'), bitrate=1800)\n",
+ " domain = agent.cfg.task.split('_')[0]\n",
+ " os.makedirs(f'videos/{domain}/text2video', exist_ok=True)\n",
+ " file_path = f'videos/{domain}/text2video/{\"_\".join(labels_list).replace(\" \",\"_\")}.gif'\n",
+ " print(\"GIF path: \", Path(os.path.abspath('')) / file_path)\n",
+ " anim.save(file_path, writer=writer)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.8.10 ('base')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/video2video.ipynb b/notebooks/video2video.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..326acc34ba20b640f955e833c05fcadfe76262a4
--- /dev/null
+++ b/notebooks/video2video.ipynb
@@ -0,0 +1,211 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path \n",
+ "import os\n",
+ "import glob\n",
+ "import json\n",
+ "import sys\n",
+ "sys.path.append(str(Path(os.path.abspath('')).parent))\n",
+ "\n",
+ "import torch\n",
+ "import torch.distributions as D\n",
+ "import numpy as np\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.cm as cm\n",
+ "import matplotlib.animation as animation\n",
+ "\n",
+ "import wandb\n",
+ "from tqdm import tqdm\n",
+ "api = wandb.Api()\n",
+ "\n",
+ "agent_path = Path(os.path.abspath('')).parent / 'models' / 'genrl_stickman_500k_2.pt'\n",
+ "print(\"Model path\", agent_path)\n",
+ "\n",
+ "agent = torch.load(agent_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tools.genrl_utils import ViCLIPGlobalInstance, DOMAIN2PREDICATES\n",
+ "model_name = getattr(agent.cfg, 'viclip_model', 'viclip')\n",
+ "# Get ViCLIP\n",
+ "if 'viclip_global_instance' not in locals() or model_name != viclip_global_instance._model:\n",
+ " viclip_global_instance = ViCLIPGlobalInstance(model_name)\n",
+ " if not viclip_global_instance._instantiated:\n",
+ " print(\"Instantiating\")\n",
+ " viclip_global_instance.instantiate()\n",
+ " clip = viclip_global_instance.viclip\n",
+ " tokenizer = viclip_global_instance.viclip_tokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import cv2\n",
+ "\n",
+ "def get_vid_feat(frames, clip):\n",
+ " return clip.get_vid_features(frames,)\n",
+ "\n",
+ "def _frame_from_video(video):\n",
+ " while video.isOpened():\n",
+ " success, frame = video.read()\n",
+ " if success:\n",
+ " yield frame\n",
+ " else:\n",
+ " break\n",
+ "\n",
+ "v_mean = np.array([0.485, 0.456, 0.406]).reshape(1,1,3)\n",
+ "v_std = np.array([0.229, 0.224, 0.225]).reshape(1,1,3)\n",
+ "def normalize(data):\n",
+ " return (data/255.0-v_mean)/v_std\n",
+ "\n",
+ "def denormalize(data):\n",
+ " return (((data * v_std) + v_mean) * 255) \n",
+ "\n",
+ "def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):\n",
+ " vid_list = [*vid_list[0]]\n",
+ " assert(len(vid_list) >= fnum)\n",
+ " vid_list = [cv2.resize(x, target_size) for x in vid_list]\n",
+ " vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list]\n",
+ " vid_tube = np.concatenate(vid_tube, axis=1)\n",
+ " vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))\n",
+ " vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float()\n",
+ " return vid_tube\n",
+ "\n",
+ "\n",
+ "def get_video_feat(frames, device=torch.device('cuda'), flip=False):\n",
+ " # Image\n",
+ " if frames.shape[1] == 1:\n",
+ " frames = frames.transpose(1,0,2,3,4).repeat(8, axis=0).transpose(1,0,2,3,4)\n",
+ "\n",
+ " # Short video\n",
+ " if frames.shape[1] == 4:\n",
+ " frames = frames.transpose(1,0,2,3,4).repeat(2, axis=0).transpose(1,0,2,3,4)\n",
+ "\n",
+ " k = max(frames.shape[1] // 128, 1)\n",
+ " frames = frames[:, ::k]\n",
+ " \n",
+ " # Horizontally flip\n",
+ " if flip:\n",
+ " frames = np.flip(frames, axis=-2)\n",
+ "\n",
+ " print(frames.shape,)\n",
+ " chosen_frames = frames[:, :8]\n",
+ " chosen_frames = frames2tensor(chosen_frames, device=device)\n",
+ " vid_feat = get_vid_feat(chosen_frames, clip,)\n",
+ " return vid_feat, chosen_frames\n",
+ "\n",
+ "VIDEO_PATH = Path(os.path.abspath('')).parent / 'assets' / 'video_samples'\n",
+ "video_name = 'headstand.mp4'\n",
+ "\n",
+ "video_file_path = str(VIDEO_PATH / video_name)\n",
+ "print(video_file_path)\n",
+ "video = cv2.VideoCapture(video_file_path)\n",
+ "frames = np.expand_dims(np.stack([ cv2.cvtColor(x, cv2.COLOR_BGR2RGB) for x in _frame_from_video(video)], axis=0), axis=0)\n",
+ "print('Video length:', frames.shape[1])\n",
+ "with torch.no_grad():\n",
+ " vid_feat, frames_feat = get_video_feat(frames, flip=False)\n",
+ "print(vid_feat.shape)\n",
+ "plt.imshow(frames[0,0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "video_embed = vid_feat\n",
+ "DENOISE = True\n",
+ "\n",
+ "T = video_embed.shape[0]\n",
+ "\n",
+ "from torchvision.transforms import transforms as vision_trans\n",
+ "trasnf = vision_trans.Resize(size=(64, 64), interpolation=vision_trans.InterpolationMode.NEAREST)\n",
+ "\n",
+ "wm = world_model = agent.wm\n",
+ "connector = agent.wm.connector\n",
+ "decoder = world_model.heads['decoder']\n",
+ "n_frames = connector.n_frames\n",
+ "\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " # Get actions\n",
+ " video_embed = video_embed.unsqueeze(1).repeat(1,n_frames, 1).reshape(1, n_frames * T, -1)\n",
+ " action = wm.connector.get_action(video_embed)\n",
+ "\n",
+ " # Imagine\n",
+ " prior = wm.connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=DENOISE)\n",
+ " prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5\n",
+ "\n",
+ " # Plotting video\n",
+ " ims = []\n",
+ " fig, axes = plt.subplots(1, 1, figsize=(4, 8), frameon=False)\n",
+ " fig.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)\n",
+ " fig.set_size_inches(4,2)\n",
+ "\n",
+ " for t in range(prior_recon.shape[1]):\n",
+ " toadd = []\n",
+ " for b in range(prior_recon.shape[0]):\n",
+ " ax = axes\n",
+ " ax.set_axis_off()\n",
+ " img = cv2.resize((np.clip(prior_recon[b, t].cpu().permute(1,2,0), 0, 1).numpy() *255).astype(np.uint8), (224,224))\n",
+ " orig_img = denormalize(frames_feat[b, t].cpu().permute(1,2,0) ).numpy().astype(np.uint8)\n",
+ " frame = ax.imshow(np.concatenate([orig_img, img], axis=1)) \n",
+ " toadd.append(frame) # add both the image and the text to the list of artists \n",
+ " ims.append(toadd)\n",
+ "\n",
+ " anim = animation.ArtistAnimation(fig, ims, interval=700, blit=True, repeat_delay=700, )\n",
+ "\n",
+ " # Save GIFs\n",
+ " writer = animation.PillowWriter(fps=15, metadata=dict(artist='Me'), bitrate=1800,)\n",
+ " domain = agent.cfg.task.split('_')[0]\n",
+ " os.makedirs(f'videos/{domain}/video2video', exist_ok=True)\n",
+ " file_path = f'videos/{domain}/video2video/{video_name[:-4].replace(\" \",\"_\")}.gif'\n",
+ " anim.save(file_path, writer=writer, )\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.8.10 ('base')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/visualize_dataset_episodes.ipynb b/notebooks/visualize_dataset_episodes.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..f4fd82361492066b9739839bc5a46563f814eaee
--- /dev/null
+++ b/notebooks/visualize_dataset_episodes.ipynb
@@ -0,0 +1,117 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Visualize dataset\n",
+ "\n",
+ "Utilities to visualize episodes from a dataset. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "from matplotlib import pyplot as plt\n",
+ "from matplotlib import animation\n",
+ "import pathlib\n",
+ "from IPython.display import Video\n",
+ "import numpy as np\n",
+ "import os\n",
+ "\n",
+ "dataset_path = pathlib.Path(os.path.abspath('')).parent / 'data/stickman_example'\n",
+ "\n",
+ "directory = dataset_path.expanduser()\n",
+ "filenames = sorted(directory.glob('*.npz'))\n",
+ "if len(filenames) == 0:\n",
+ " raise ValueError(\"Empty directory (or no episodes)\")\n",
+ "\n",
+ "try:\n",
+ " filenames_dict = { int(str(f).replace(str(dataset_path), \"\").split(\"-\")[0][1:]) : f for f in filenames}\n",
+ "except Exception as e:\n",
+ " print(\"Error:\", e)\n",
+ "\n",
+ "print(directory)\n",
+ "print(len(filenames))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ep_num = next(iter(filenames_dict))\n",
+ "\n",
+ "filename = filenames_dict[ep_num]\n",
+ "with filename.open('rb') as f:\n",
+ " episode = np.load(f)\n",
+ " episode = {k: episode[k] for k in episode.keys()}\n",
+ "\n",
+ "# Show reward on top with red/green bar\n",
+ "pix_rew_max = np.round(episode['reward'] / 2 * 64)\n",
+ "for ob, pix_n in zip(episode['observation'], pix_rew_max):\n",
+ " if pix_n < 0:\n",
+ " pix_n = abs(pix_n)\n",
+ " ob[:, 0, :int(pix_n+1)] = np.array([255,0,0]).reshape(3,1)\n",
+ " else:\n",
+ " ob[:, 0, :int(pix_n+1)] = np.array([0,255,0]).reshape(3,1)\n",
+ "\n",
+ "# # np array with shape (frames, height, width, channels)\n",
+ "video = np.transpose(episode['observation'], axes=[0,2,3,1])\n",
+ "\n",
+ "fig = plt.figure(frameon=False)\n",
+ "ax = plt.Axes(fig, [0., 0., 1., 1.])\n",
+ "ax.set_axis_off()\n",
+ "fig.add_axes(ax)\n",
+ "fig.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)\n",
+ "fig.set_size_inches(2,2)\n",
+ "im = ax.imshow(video[0,:,:,:])\n",
+ "plt.close() # this is required to not display the generated image\n",
+ "\n",
+ "def init():\n",
+ " im.set_data(video[0,:,:,:])\n",
+ "\n",
+ "def animate(i):\n",
+ " im.set_data(video[i,:,:,:])\n",
+ " return im\n",
+ "\n",
+ "print('Episode reward', np.sum(episode['reward']))\n",
+ "anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],interval=45)\n",
+ "file_path = str(pathlib.Path(os.path.abspath('')) / 'videos/temp.mp4')\n",
+ "anim.save(file_path)\n",
+ "print('Video file', file_path)\n",
+ "Video(file_path)"
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf"
+ },
+ "kernelspec": {
+ "display_name": "Python 3.8.10 ('base')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/visualize_env.ipynb b/notebooks/visualize_env.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..9523bfb7064f0993e5cd43251b18df4fca749a64
--- /dev/null
+++ b/notebooks/visualize_env.ipynb
@@ -0,0 +1,152 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Visualize environment and custom tasks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pathlib\n",
+ "import sys\n",
+ "import os\n",
+ "sys.path.append(str(pathlib.Path(os.path.abspath('')).parent))\n",
+ "\n",
+ "from envs.custom_dmc_tasks import *\n",
+ "from dm_control import suite\n",
+ "import numpy as np\n",
+ "\n",
+ "domain = 'stickman'\n",
+ "task = 'sit_knees'\n",
+ "\n",
+ "env = suite.load(domain_name=domain, task_name=task, visualize_reward=True)\n",
+ "\n",
+ "action_spec = env.action_spec()\n",
+ "\n",
+ "# Define a uniform random policy.\n",
+ "def random_policy(time_step):\n",
+ " del time_step # Unused.\n",
+ " return np.random.uniform(low=action_spec.minimum,\n",
+ " high=action_spec.maximum,\n",
+ " size=action_spec.shape)\n",
+ "\n",
+ "def zero_policy(time_step):\n",
+ " del time_step\n",
+ " return np.zeros(action_spec.shape)\n",
+ " \n",
+ "\n",
+ "class GoalSetWrapper:\n",
+ " def __init__(self, env, goal=None, goal_idx=None):\n",
+ " self._env = env\n",
+ " self._env._step_limit = float('inf')\n",
+ " self._goal = goal\n",
+ " self._goal_idx = goal_idx\n",
+ "\n",
+ " def step(self, *args, **kwargs):\n",
+ " if self._goal is not None:\n",
+ " self.set_goal(self._goal)\n",
+ " if self._goal_idx is not None:\n",
+ " self.set_goal_by_idx(self._goal_idx)\n",
+ " return self._env.step(*args, **kwargs)\n",
+ " \n",
+ " def set_goal_by_idx(self, idx_goal):\n",
+ " cur = self._env.physics.get_state().copy()\n",
+ " for idx, goal in idx_goal:\n",
+ " cur[idx] = goal\n",
+ " self._env.physics.set_state(cur)\n",
+ " self._env.step(np.zeros_like(self.action_spec().shape))\n",
+ "\n",
+ " def set_goal(self, goal):\n",
+ " goal = np.array(goal)\n",
+ " size = self._env.physics.get_state().shape[0] - goal.shape[0]\n",
+ " self._env.physics.set_state(np.concatenate((goal, np.zeros([size]))))\n",
+ " self._env.step(np.zeros_like(self.action_spec().shape))\n",
+ "\n",
+ " def __getattr__(self, name: str):\n",
+ " return getattr(self._env, name)\n",
+ "\n",
+ "\n",
+ "env = GoalSetWrapper(env)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "from envs.custom_dmc_tasks.stickman import StickmanYogaPoses\n",
+ "\n",
+ "obs = env.reset()\n",
+ "\n",
+ "for _ in range(1):\n",
+ " env.set_goal(StickmanYogaPoses.sit_knees)\n",
+ "\n",
+ "# for _ in range(20):\n",
+ "# obs = env.step(np.random.randn(*env.action_spec().shape))\n",
+ "print('Rew', obs.reward)\n",
+ "\n",
+ "print('Upright', env.physics.torso_upright())\n",
+ "print('Torso height', env.physics.torso_height())\n",
+ "\n",
+ "plt.imshow(env.physics.render(camera_id=0))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for _ in range(1):\n",
+ " obs = env.step(np.random.randn(*env.action_spec().shape))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "env.physics.named.data.qpos"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "env.physics.named.data.xpos"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "mine_new",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/process_dataset.py b/process_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..147db4ca2b7d23c85981c08a13469daa17268353
--- /dev/null
+++ b/process_dataset.py
@@ -0,0 +1,140 @@
+import warnings
+
+warnings.filterwarnings('ignore', category=DeprecationWarning)
+
+import io
+import os
+from tqdm import tqdm
+
+os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
+
+from pathlib import Path
+from collections import OrderedDict
+
+import hydra
+import numpy as np
+import torch
+
+import tools.utils as utils
+from tools.replay import load_episode
+
+torch.backends.cudnn.benchmark = True
+
+if os.name == "nt":
+ import msvcrt
+
+ def portable_lock(fp):
+ fp.seek(0)
+ msvcrt.locking(fp, msvcrt.LK_LOCK, 1)
+
+ def portable_unlock(fp):
+ fp.seek(0)
+ msvcrt.locking(fp, msvcrt.LK_UNLCK, 1)
+else:
+ import fcntl
+
+ def portable_lock(fp):
+ fcntl.flock(fp, fcntl.LOCK_EX | fcntl.LOCK_NB)
+
+ def portable_unlock(fp):
+ fcntl.flock(fp, fcntl.LOCK_UN)
+
+
+class Locker:
+ def __init__(self, lock_name):
+ # e.g. lock_name = "./lockfile.lck"
+ self.lock_name = lock_name
+
+ def __enter__(self,):
+ open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC
+ self.fd = os.open(self.lock_name, open_mode)
+ portable_lock(self.fd)
+
+ def __exit__(self, _type, value, tb):
+ portable_unlock(self.fd)
+ os.close(self.fd)
+ try:
+ os.remove(self.lock_name)
+ except:
+ pass
+
+class Workspace:
+ def __init__(self, cfg, savedir=None, workdir=None,):
+ self.workdir = Path.cwd() if workdir is None else workdir
+ print(f'workspace: {self.workdir}')
+
+ assert int(cfg.viclip_encode) == 1, "encoding only one (video or img)"
+
+ if cfg.viclip_encode:
+ self.key_to_add = 'clip_video'
+
+ self.key_to_process = getattr(cfg, 'key_to_process', 'observation')
+
+ self.cfg = cfg
+ self.device = torch.device(cfg.device)
+
+ # create envs
+ task = cfg.task
+ self.task = task
+ img_size = cfg.img_size
+
+ import envs.main as envs
+ self.train_env = envs.make(task, cfg.obs_type, cfg.action_repeat, cfg.seed, img_size=img_size, viclip_encode=cfg.viclip_encode, device='cuda')
+
+ self.dataset_path = Path(cfg.dataset_dir)
+
+ self.timer = utils.Timer()
+ self._global_step = 0
+ self._global_episode = 0
+
+ def process(self):
+ filenames = sorted(self.dataset_path.glob('**/*.npz'))
+ print(f"Found {len(filenames)} files")
+ episodes_to_process = {}
+
+ for idx, fname in tqdm(enumerate(filenames)):
+ lockname = str(fname.absolute()) + ".lck"
+ try:
+ with Locker(lockname):
+ episode = load_episode(fname)
+
+ # validate before continuing
+ if type(episode[self.key_to_add]) == np.ndarray and episode[self.key_to_add].size > 1 and episode[self.key_to_add].shape[0] == episode[self.key_to_process].shape[0]:
+ continue
+ else:
+ del episode[self.key_to_add]
+
+ add_data = self.train_env.process_episode(episode[self.key_to_process]) # .cpu().numpy()
+ if idx == 0:
+ print(add_data.shape)
+ episode[self.key_to_add] = add_data
+
+ # save episode
+ with io.BytesIO() as f1:
+ np.savez_compressed(f1, **episode)
+ f1.seek(0)
+ with fname.open('wb') as f2:
+ f2.write(f1.read())
+ except BlockingIOError:
+ print(f"File busy: {str(fname)}")
+ continue
+
+
+def start_processing(cfg, savedir, workdir):
+ from process_dataset import Workspace as W
+ root_dir = Path.cwd()
+ cfg.workdir = str(root_dir)
+ workspace = W(cfg, savedir, workdir)
+ workspace.root_dir = root_dir
+ snapshot = workspace.root_dir / 'last_snapshot.pt'
+ if snapshot.exists():
+ print(f'resuming: {snapshot}')
+ workspace.load_snapshot(workspace.root_dir)
+ workspace.process()
+
+@hydra.main(config_path='.', config_name='process_dataset')
+def main(cfg):
+ start_processing(cfg, None, None)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/process_dataset.yaml b/process_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b164fa7b15953de22cb22e8e1b4e49a91c989300
--- /dev/null
+++ b/process_dataset.yaml
@@ -0,0 +1,39 @@
+defaults:
+ - conf/env: dmc_pixels
+ - override hydra/launcher: submitit_local
+
+# task settings
+task: stickman_walk
+# misc
+seed: 1
+device: cuda:0
+img_size: 64
+
+# CLIP-related
+viclip_encode: true
+viclip_model: internvideo2
+
+# dataset-related
+dataset_dir: null
+key_to_process: observation
+
+# experiment
+project_name: genrl
+# log settings
+workdir: ???
+
+
+hydra:
+ run:
+ dir: ./exp_local/${now:%Y.%m.%d}/process_data_${now:%H%M%S}
+ sweep:
+ dir: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_process_data
+ subdir: ${hydra.job.num}
+ launcher:
+ timeout_min: 4300
+ cpus_per_task: 10
+ gpus_per_node: 1
+ tasks_per_node: 1
+ mem_gb: 160
+ nodes: 1
+ submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_process_data_${experiment}/.slurm
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..211a0da4c02098bf7c97a250813fb250f07c4f24
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,37 @@
+bunch==1.0.1
+dm_control
+dm_env==1.6
+einops==0.7.0
+glfw==2.6.5
+gym==0.23.0
+haven_ai==0.7.3
+hydra-core==1.1.0
+imageio==2.9.0
+lxml==5.1.0
+matplotlib==3.9.0
+memory_profiler==0.61.0
+mujoco_py==2.1.2.14
+numpy==1.24.4
+omegaconf==2.1.2
+open_clip_torch==2.24.0
+opencv_python==4.9.0.80
+pandas==2.2.2
+pytest==8.0.2
+robosuite==1.4.1
+scipy==1.9.3
+setuptools==59.5.0
+tensorboard
+tensorboard_data_server
+tensorflow==2.15.0.post1
+termcolor==1.1.0
+torch==2.2.0
+torchvision
+tqdm==4.66.1
+wandb==0.16.4
+flash_attn==2.5.7
+peft==0.4.0
+transformers==4.31.0
+hydra-submitit-launcher==1.1.5
+moviepy==1.0.3
+imageio[ffmpeg]
+# for the demo: gradio==4.36.1
\ No newline at end of file
diff --git a/test/pytest.ini b/test/pytest.ini
new file mode 100644
index 0000000000000000000000000000000000000000..f861f05e293d04e0cf991d75a483aa857583a9ce
--- /dev/null
+++ b/test/pytest.ini
@@ -0,0 +1,5 @@
+[pytest]
+log_cli = 1
+log_cli_level = INFO
+log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)
+log_cli_date_format=%Y-%m-%d %H:%M:%S
\ No newline at end of file
diff --git a/test/test_env.py b/test/test_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..b62833ffb33d95f8957999eaaccb4dc01aa01f8a
--- /dev/null
+++ b/test/test_env.py
@@ -0,0 +1,22 @@
+import pathlib
+import sys
+sys.path.append(str(pathlib.Path(__file__).parent.parent))
+
+import logging
+LOGGER = logging.getLogger(__name__)
+
+import pytest
+
+import envs.main as envs
+from tools.task_scores import MAX as task_max_scores
+
+@pytest.mark.filterwarnings('ignore::UserWarning')
+@pytest.mark.filterwarnings('ignore::DeprecationWarning')
+# @pytest.mark.filterwarnings('ignore:The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives')
+def test_envs():
+ for task_name in task_max_scores.keys():
+ LOGGER.info(task_name)
+ env = envs.make(task_name, 'pixels', action_repeat=2, seed=0)
+ env.reset()
+ env.step(env.act_space['action'].sample())
+
diff --git a/third_party/InternVideo/.gitignore b/third_party/InternVideo/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..0e7227c68f3d80a10c75b0e06292109848f3f530
--- /dev/null
+++ b/third_party/InternVideo/.gitignore
@@ -0,0 +1,161 @@
+# Docker file from Python is inspired from here :
+# https://github.com/github/gitignore/blob/master/Python.gitignore
+
+events*
+*log.txt
+*.pth
+log_*.txt
+log
+latest
+*.pt
+checkpoint*
+batchscript*
+logs/
+debug*
+
+# custom
+.vscode
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+tests/report/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+
+# Cython debug symbols
+cython_debug/
+
+
+# others
+work_dir/
+batchscript-*
+phoenix-slurm-*
+wandb
diff --git a/third_party/InternVideo/.gitmodules b/third_party/InternVideo/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..da27c2465547a14d16cce8a4260a7f0a3ae20617
--- /dev/null
+++ b/third_party/InternVideo/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "InternVideo1/Pretrain/UniFormerV2"]
+ path = InternVideo1/Pretrain/UniFormerV2
+ url = https://github.com/OpenGVLab/UniFormerV2.git
+[submodule "InternVideo1/Downstream/Ego-Tasks"]
+ path = InternVideo1/Downstream/Ego-Tasks
+ url = https://github.com/OpenGVLab/ego4d-eccv2022-solutions.git
diff --git a/third_party/InternVideo/Data/InternVid/README.md b/third_party/InternVideo/Data/InternVid/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..49d308d3677eeef8907f0d670b31feb4aaaaed7d
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/README.md
@@ -0,0 +1,109 @@
+# InternVid: A Large-scale Video-Text Dataset for Multimodal Understanding and Generation \[[Paper](https://arxiv.org/pdf/2307.06942.pdf)\]
+
+[![Dataset meta](https://img.shields.io/badge/%F0%9F%A4%97%20InternVid-Dataset-blue)](https://huggingface.co/datasets/OpenGVLab/InternVid) | [![Model Checkpoint](https://img.shields.io/badge/%F0%9F%A4%97%20ViCLIP-Model-purple)](https://huggingface.co/OpenGVLab/ViCLIP)
+
+\[[中文版本](README_CN.md)\]
+
+# :fire: News
+- The implementation of ViCLIP is given [here](https://github.com/OpenGVLab/InternVideo/tree/main/InternVideo1/Pretrain/ViCLIP).
+
+- InternVid has been accepted for spotlight presentation of ICLR 2024.
+
+- We release a subset [InternVid-Aesthetics-18M](https://huggingface.co/datasets/OpenGVLab/InternVid/viewer/InternVid-10M/AES). It consists of 18 million video clips that have been assigned high aesthetic scores. For more details on the aesthetic scoring, please refer to [laion aesthetic predictor](https://github.com/LAION-AI/aesthetic-predictor).
+
+- We enhance InternVid-10M-FLT dataset annotations by incorporating video language and type information sourced from YouTube's metainfo. You can find the updated annotations at [this link](https://huggingface.co/datasets/OpenGVLab/InternVid-10M-FLT-INFO).
+
+- We release ViCLIP models trained on different subsets of InternVid. Check their performance [here](#model-performance) and download them [here](#pretrained-data--model).
+
+- We are excited to announce the partial release of a large-scale video-text dataset aimed at facilitating multimodal understanding and generation. As part of this release, we are making available a subset [InternVid-10M-FLT](https://huggingface.co/datasets/OpenGVLab/InternVid) of the dataset, which comprises 10 million video clips. Additionally, we have provided a [ViCLIP](https://huggingface.co/OpenGVLab/ViCLIP) model trained on this subset, using the ViT-L architecture. It achieves SOTA zero-shot action recognition performance on Kinetics.
+
+- We give a step-by-step instructions and clarify the process of accessing and utilizing ViClip in [demo.ipynb](https://github.com/OpenGVLab/InternVideo/blob/main/Data/InternVid/demo.ipynb).
+
+- Some model weights and the corresponding data are released at [Pretrained Data & Model](#pretrained-data--model). Their performance is given at [Model Performance](#model-performance).
+
+Stay tuned for updates!
+
+# Introduction
+
+### Data
+
+We collected videos from 16 popular categories with varying percentages. We ensured diversity by selecting videos from countries with different languages instead of relying on a dominant language environment. The countries we sampled from include the UK, USA, Australia, Japan, Korea, China, Russia, and France, among others. In terms of duration, every video lasts 351.9s on average. Almost half (49%) of the videos are five minutes or less, while a quarter (26%) fall between five and ten minutes. Only 8% of the videos are over 20 minutes long. Among the curated videos, 85% were high-resolution (720P), while the remaining 15% had lower resolutions ranging from 360P to 720P. Although the lower-resolution videos may not perform as well as the high-resolution ones in content generation tasks, they can still be useful in video-language representation learning, provided that they have appropriate captions.
+
+![b469e00b43d46a6b3f89899483abcf6](https://github.com/OpenGVLab/InternVideo/assets/43169235/7d6aca7d-362a-425d-9ef2-ec0189491b52)
+
+InternVid exhibits diverse clip durations and caption lengths in the segmented clip level. The aesthetic scores and clip-caption similarities are distributed uniformly. The majority of clips are 0-10 seconds in length, accounting for 85% of all clips. Approximately half of the clips have captions with 10-20 words, while one-third of the clip captions have fewer than 10 words. About 11% of clips have long captions with more than 20 words.
+
+![429af4993adb77478c000c865ae5a1b](https://github.com/OpenGVLab/InternVideo/assets/43169235/f64588c3-81e8-43de-b771-46500474d2ff)
+
+### ViCLIP: a simple video CLIP for transferrable video-text representation
+
+Built upon CLIP , we make a simple video-text pretraining baseline ViCLIP. It consists of a video encoder (ViT) and a text encoder, as given below. Both modules are initialized from the corresponding CLIP components. We update the native attention in the video encoder to spatiotemporal attention while maintaining other design elements. For efficient learning, we apply masking to videos in pre-training.
+
+
+
+### Model Performance
+
+**Table 1: Zero-shot action recognition results on Kinetics 400/600/700. We report the top-1 accuracy of the compared methods on each dataset.**
+|Method | Training Data | K400 | | K600 | | K700 | |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| | |top-1 | AVG | top-1 | AVG | top-1 | AVG |
+CLIP | CLIP400M | 58.42 | 70.14 | 55.11| 67.16| 46.12| 58.38
+CLIP | DataComp-1B |56.14| 67.67| 54.15| 65.83| 45.36| 57.01
+EVA-CLIP-L | Merged-2B | - | 65.00| - |64.90| - |59.10
+EVA-CLIP-E | LAION-2B |- |69.80|-| 69.30| -| 63.40
+ViCLIP-B | +InternVid-10M-FLT | 58.52 | 71.11 | 55.37 | 68.27 | 47.09 | 59.98
+ViCLIP-B | +InternVid-200M | 56.58 | 69.20 | 53.57 | 66.20 | 45.82 | 58.28
+ViCLIP-L| +WebVid10M |59.88| 71.03| 58.66| 69.84| 50.23| 61.86
+ViCLIP-L| +InternVid-10M-DIV| 63.00| 74.15| 60.68| 72.07| 52.50| 64.59
+ViCLIP-L| +InternVid-10M-FLT| **64.80** | **75.70** | **62.20** | **73.53** | **54.30** | **66.38**
+ViCLIP-L | +InternVid-200M | 59.80 | 71.09 | 57.80 | 69.34 | 49.30 | 61.25
+
+**Table 2: Fine-tuned action recognition results on Kinetics 400 and SomethingSomethingV2.**
+|Method | Training Data | K400 | | SthSthV2 | |
+|:---:|:---:|:---:|:---:|:---:|:---:|
+| | |top-1 | top-5 | top-1 | top-5|
+CLIP | CLIP400M | 86.7 | 97.2 | 70.1 | 92.5
+CLIP | DataComp-1B |85.6| 96.8| 68.9| 91.8
+ViCLIP-L| +WebVid10M |85.0| 96.8| 68.7| 91.9
+ViCLIP-L| +InternVid-10M-FLT| 86.8 |97.5| 71.2| 93.2
+ViCLIP-L| +InternVid-10M-FLT+K710| 88.0| 97.8| 71.8| 93.6
+ViCLIP-L | +InternVid-200M | 87.9 |97.9| 73.6| 94.9
+ViCLIP-L | +InternVid-200M+K710 | **88.7** | **98.2** | **74.2** | **95.0**
+
+# Data & Model Zoo
+
+### Pretrained Data & Model
+
+
+| Model | Training Data | Descriptions |
+| :-----------------: | :----------------------: | :---------------------------------------------------------------------------------------------------: |
+| ViCLIP-L-14 \[[HuggingFace](https://huggingface.co/OpenGVLab/ViCLIP) \| [Aliyun](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/viclip/ViClip-InternVid-10M-FLT.pth)\] | InternVid-10M-FLT \[[HuggingFace](https://huggingface.co/datasets/OpenGVLab/InternVid) \| [OpenDataLab](https://opendatalab.com/shepshep/InternVid)\] | - |
+| ViCLIP-L-14 \[[Aliyun](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/viclip/ViCLIP-L_InternVid-DIV-10M.pth)\] | InternVid-10M-DIV | - |
+| ViCLIP-L-14 \[[Aliyun](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/viclip/ViCLIP-L_WebVid-10M.pth)\] | WebVid-10M | - |
+| ViCLIP-L-14 \[[Aliyun](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/viclip/ViCLIP-L_InternVid-10M.pth)\] | InternVid-10M | - |
+| ViCLIP-L-14 \[[Aliyun](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/viclip/ViCLIP-L_InternVid-50M.pth)\] | InternVid-50M | - |
+| ViCLIP-L-14 \[[Aliyun](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/viclip/ViCLIP-L_InternVid-200M.pth)\] | InternVid-200M | - |
+| ViCLIP-B-16 \[[OneDrive](https://pjlab-my.sharepoint.cn/:u:/g/personal/wangyi_pjlab_org_cn/EY6ac22ZVzJLm1-wm_9gPaMBm5MFg36GKTxlkwTemgmKzQ?e=mH6u6A)\] | InternVid-10M-FLT | - |
+| ViCLIP-B-16 \[[OneDrive](https://pjlab-my.sharepoint.cn/:u:/g/personal/wangyi_pjlab_org_cn/EVGBg6kq4M1MjbeSdqiXsaMBaBduhR7CQCT11JR4edmZ8Q?e=ILtTfM)\] | InternVid-200M | - |
+
+
+
+## Citation
+
+If you find this work useful for your research, please consider citing InternVid. Your acknowledgement would greatly help us in continuing to contribute resources to the research community.
+
+```
+@inproceedings{wang2023internvid,
+ title={InternVid: A Large-scale Video-Text Dataset for Multimodal Understanding and Generation},
+ author={Wang, Yi and He, Yinan and Li, Yizhuo and Li, Kunchang and Yu, Jiashuo and Ma, Xin and Li, Xinhao and Chen, Guo and Chen, Xinyuan and Wang, Yaohui and others},
+ booktitle={The Twelfth International Conference on Learning Representations},
+ year={2023}
+}
+
+@article{wang2022internvideo,
+ title={InternVideo: General Video Foundation Models via Generative and Discriminative Learning},
+ author={Wang, Yi and Li, Kunchang and Li, Yizhuo and He, Yinan and Huang, Bingkun and Zhao, Zhiyu and Zhang, Hongjie and Xu, Jilan and Liu, Yi and Wang, Zun and Xing, Sen and Chen, Guo and Pan, Junting and Yu, Jiashuo and Wang, Yali and Wang, Limin and Qiao, Yu},
+ journal={arXiv preprint arXiv:2212.03191},
+ year={2022}
+}
+```
diff --git a/third_party/InternVideo/Data/InternVid/README_CN.md b/third_party/InternVideo/Data/InternVid/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..10f24ba5377d2d21e723270d24633e899cdbd932
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/README_CN.md
@@ -0,0 +1,63 @@
+# InternVid \[[论文](https://arxiv.org/pdf/2307.06942.pdf)\]
+
+[![数据集](https://img.shields.io/badge/%F0%9F%A4%97%20InternVid-Dataset-blue)](https://huggingface.co/datasets/OpenGVLab/InternVid) | [![模型](https://img.shields.io/badge/%F0%9F%A4%97%20ViCLIP-Model-purple)](https://huggingface.co/OpenGVLab/ViCLIP)
+
+\[[English verision](README.md)\]
+
+# :fire: 新闻
+我们很高兴宣布部分发布一个大规模的视频文本数据集,旨在促进多模态理解和生成。作为此次发布的一部分,我们提供了该数据集的[子集](https://huggingface.co/datasets/OpenGVLab/InternVid)包含1000万个视频剪辑。此外,我们还提供了一个使用ViT-L架构在这个子集上训练的[ViCLIP](https://huggingface.co/OpenGVLab/ViCLIP)。该模型在Kinetics上实现了SOTA的零样本动作识别性能。
+
+我们提供了示例代码,阐明如何使用ViClip的过程,在[demo.ipynb](https://github.com/OpenGVLab/InternVideo/blob/main/Data/InternVid/demo.ipynb)中有详述。
+
+请关注我们的更新!
+
+# 简介
+
+**数据**
+
+我们从16个流行类别中收集了各种百分比的视频。为了确保多样性,我们选择了来自不同语言的国家的视频,而非依赖于一个主导语言环境。我们采样的国家包括英国、美国、澳大利亚、日本、韩国、中国、俄罗斯和法国等。在时长方面,每个视频平均持续351.9秒。几乎一半(49%)的视频时长不超过五分钟,而四分之一(26%)的视频时长在五到十分钟之间。只有8%的视频超过20分钟。在策划的视频中,85%是高分辨率(720P),其余15%的分辨率从360P至720P不等。虽然低分辨率的视频在内容生成任务中可能表现不如高分辨率的视频,但只要配有适当的字幕,它们仍可用于视频-语言表示学习。
+
+![b469e00b43d46a6b3f89899483abcf6](https://github.com/OpenGVLab/InternVideo/assets/43169235/7d6aca7d-362a-425d-9ef2-ec0189491b52)
+
+InternVid展示了在分割剪辑级别上具有不同剪辑时长和字幕长度的多样性。美学分数和剪辑-字幕相似度均匀分布。大部分剪辑的长度在0-10秒之间,占所有剪辑的85%。大约一半的剪辑字幕含有10-20个单词,而三分之一的剪辑字幕含有少于10个单词。大约11%的剪辑具有超过20个单词的长字幕。
+
+![429af4993adb77478c000c865ae5a1b](https://github.com/OpenGVLab/InternVideo/assets/43169235/f64588c3-81e8-43de-b771-46500474d2ff)
+
+**ViCLIP: 一个简单的用于转移视频-文本表示的视频CLIP**
+
+基于CLIP , 我们构建了一个简单的视频-文本预训练基线ViCLIP。它由视频编码器(ViT)和文本编码器组成,如下所示。这两个模块都是从相应的CLIP组件初始化的。我们将视频编码器中的原生注意力更新为时空注意力,同时保持其他设计元素不变。为了高效学习,我们在预训练中对视频进行了掩蔽处理。
+
+
+
+
+# 数据 & 模型库
+
+### 预训练数据 & 模型
+
+
+
+| 模型 | 训练数据 | 描述 |
+| :-----------------: | :----------------------: | :---------------------------------------------------------------------------------------------------: |
+| ViCLIP-L-14 \[[HuggingFace](https://huggingface.co/OpenGVLab/ViCLIP) \| [Aliyun](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/viclip/ViClip-InternVid-10M-FLT.pth)\] | InternVid-10M-FLT \[[HuggingFace](https://huggingface.co/datasets/OpenGVLab/InternVid) \| [OpenDataLab](https://opendatalab.com/shepshep/InternVid)\] | |
+
+
+
+## Citation
+
+如果您发现这项工作对您的研究有所帮助,请考虑引用InternVid。您的肯定将极大地帮助我们继续为研究社区贡献资源。
+
+```
+@article{wang2023internvid,
+ title={InternVid: A Large-scale Video-Text Dataset for Multimodal Understanding and Generation},
+ author={Wang, Yi and He, Yinan and Li, Yizhuo and Li, Kunchang and Yu, Jiashuo and Ma, Xin and Chen, Xinyuan and Wang, Yaohui and Luo, Ping and Liu, Ziwei and Wang, Yali and Wang, Limin and Qiao, Yu},
+ journal={arXiv preprint arXiv:2307.06942},
+ year={2023}
+}
+
+@article{wang2022internvideo,
+ title={InternVideo: General Video Foundation Models via Generative and Discriminative Learning},
+ author={Wang, Yi and Li, Kunchang and Li, Yizhuo and He, Yinan and Huang, Bingkun and Zhao, Zhiyu and Zhang, Hongjie and Xu, Jilan and Liu, Yi and Wang, Zun and Xing, Sen and Chen, Guo and Pan, Junting and Yu, Jiashuo and Wang, Yali and Wang, Limin and Qiao, Yu},
+ journal={arXiv preprint arXiv:2212.03191},
+ year={2022}
+}
+```
\ No newline at end of file
diff --git a/third_party/InternVideo/Data/InternVid/demo.ipynb b/third_party/InternVideo/Data/InternVid/demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..fe7557a22d83e84c016493ad913e81b1702bedd9
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/demo.ipynb
@@ -0,0 +1,176 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "f86bc499",
+ "metadata": {},
+ "source": [
+ "## download ViCILP weights and put its pth file in viclip folder. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "e7a90379-d9ee-45d9-9073-7ed5132fa6b1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/mnt/petrelfs/wangyi/.conda/envs/pt13/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import os\n",
+ "import cv2\n",
+ "\n",
+ "from viclip import get_viclip, retrieve_text, _frame_from_video"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "a425a5da-ceaf-4b89-9845-c8ba576902d8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "video = cv2.VideoCapture('example1.mp4')\n",
+ "frames = [x for x in _frame_from_video(video)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "e6c1cd7a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# modify xxx to the path of the pretrained model\n",
+ "model_cfgs = {\n",
+ " 'viclip-l-internvid-10m-flt': {\n",
+ " 'size': 'l',\n",
+ " 'pretrained': 'xxx/ViCLIP-L_InternVid-FLT-10M.pth',\n",
+ " },\n",
+ " 'viclip-l-internvid-200m': {\n",
+ " 'size': 'l',\n",
+ " 'pretrained': 'xxx/ViCLIP-L_InternVid-200M.pth',\n",
+ " },\n",
+ " 'viclip-b-internvid-10m-flt': {\n",
+ " 'size': 'b',\n",
+ " 'pretrained': 'xxx/ViCLIP-B_InternVid-FLT-10M.pth',\n",
+ " },\n",
+ " 'viclip-b-internvid-200m': {\n",
+ " 'size': 'b',\n",
+ " 'pretrained': 'xxx/ViCLIP-B_InternVid-200M.pth',\n",
+ " },\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "3fb7397a-02ef-41b5-9ffe-f2363b277778",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/mnt/petrelfs/wangyi/.conda/envs/pt13/lib/python3.9/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+ " warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "text: A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run. ~ prob: 0.8333\n",
+ "text: A playful dog and its owner wrestle in the snowy yard, chasing each other with joyous abandon. ~ prob: 0.1266\n",
+ "text: A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner. ~ prob: 0.0368\n",
+ "text: A person dressed in a blue jacket shovels the snow-covered pavement outside their house. ~ prob: 0.0030\n",
+ "text: A playful dog slides down a snowy hill, wagging its tail with delight. ~ prob: 0.0003\n"
+ ]
+ }
+ ],
+ "source": [
+ "text_candidates = [\"A playful dog and its owner wrestle in the snowy yard, chasing each other with joyous abandon.\",\n",
+ " \"A man in a gray coat walks through the snowy landscape, pulling a sleigh loaded with toys.\",\n",
+ " \"A person dressed in a blue jacket shovels the snow-covered pavement outside their house.\",\n",
+ " \"A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner.\",\n",
+ " \"A person stands on the snowy floor, pushing a sled loaded with blankets, preparing for a fun-filled ride.\",\n",
+ " \"A man in a gray hat and coat walks through the snowy yard, carefully navigating around the trees.\",\n",
+ " \"A playful dog slides down a snowy hill, wagging its tail with delight.\",\n",
+ " \"A person in a blue jacket walks their pet on a leash, enjoying a peaceful winter walk among the trees.\",\n",
+ " \"A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run.\",\n",
+ " \"A person bundled up in a blanket walks through the snowy landscape, enjoying the serene winter scenery.\"]\n",
+ "\n",
+ "cfg = model_cfgs['viclip-l-internvid-10m-flt']\n",
+ "model_l = get_viclip(cfg['size'], cfg['pretrained'])\n",
+ "texts, probs = retrieve_text(frames, text_candidates, models=model_l, topk=5)\n",
+ "\n",
+ "for t, p in zip(texts, probs):\n",
+ " print(f'text: {t} ~ prob: {p:.4f}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "a2969ba6-19d0-4893-b071-b82fa046c312",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "text: A playful dog and its owner wrestle in the snowy yard, chasing each other with joyous abandon. ~ prob: 0.8192\n",
+ "text: A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run. ~ prob: 0.1084\n",
+ "text: A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner. ~ prob: 0.0676\n",
+ "text: A playful dog slides down a snowy hill, wagging its tail with delight. ~ prob: 0.0047\n",
+ "text: A person dressed in a blue jacket shovels the snow-covered pavement outside their house. ~ prob: 0.0002\n"
+ ]
+ }
+ ],
+ "source": [
+ "cfg = model_cfgs['viclip-b-internvid-10m-flt']\n",
+ "model_b = get_viclip(cfg['size'], cfg['pretrained'])\n",
+ "texts, probs = retrieve_text(frames, text_candidates, models=model_b, topk=5)\n",
+ "\n",
+ "for t, p in zip(texts, probs):\n",
+ " print(f'text: {t} ~ prob: {p:.4f}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ebdae1be-0dc4-4f3c-9856-5e0fd27aa368",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/third_party/InternVideo/Data/InternVid/div_sampling.py b/third_party/InternVideo/Data/InternVid/div_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..06cd929e85db4f4c4807ace33ef2a873762a96f7
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/div_sampling.py
@@ -0,0 +1,14 @@
+from collections import Counter
+import json
+import random
+import numpy as np
+data = json.load(open("/path/to/to_sample"))
+video_id = set([x["video"].split("/")[-1][:11] for x in data])
+video_id_counter = Counter([x["video"].split("/")[-1][:11] for x in data])
+sampling_weights = [1.0 / video_id_counter[x["video"].split("/")[-1][:11]] for x in data]
+np.random.seed(42)
+sampling_weights = np.array(sampling_weights)
+sampling_weights = sampling_weights / sampling_weights.sum()
+sampled_index = np.random.choice(len(data), 10647458, replace=False, p=sampling_weights)
+data = [data[i] for i in sampled_index]
+json.dump(data, open("/path/to/sampled", "w"))
\ No newline at end of file
diff --git a/third_party/InternVideo/Data/InternVid/example1.mp4 b/third_party/InternVideo/Data/InternVid/example1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..815f8412d86b0b9d2200efc7d28d5e454e121c0d
Binary files /dev/null and b/third_party/InternVideo/Data/InternVid/example1.mp4 differ
diff --git a/third_party/InternVideo/Data/InternVid/start_annotation_prototype.sh b/third_party/InternVideo/Data/InternVid/start_annotation_prototype.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e798773fe30f0a89d8936fb4f1aafae81f8337c3
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/start_annotation_prototype.sh
@@ -0,0 +1,16 @@
+unset http_proxy; unset https_proxy; unset HTTP_PROXY; unset HTTPS_PROXY
+JOB_NAME='data-annotate_check'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="$(dirname $0)/logs/${JOB_NAME}"
+PARTITION='Video-aigc-general'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ jupyter lab --ip=0.0.0.0
\ No newline at end of file
diff --git a/third_party/InternVideo/Data/InternVid/utils/basic_utils.py b/third_party/InternVideo/Data/InternVid/utils/basic_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb453d35c852741bf1ad6dfe27e604d9fef6557e
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/utils/basic_utils.py
@@ -0,0 +1,286 @@
+import numpy as np
+import io
+import os
+import json
+import logging
+import random
+import time
+from collections import defaultdict, deque
+import datetime
+from pathlib import Path
+from typing import List, Union
+
+import torch
+import torch.distributed as dist
+from .distributed import is_dist_avail_and_initialized
+
+
+logger = logging.getLogger(__name__)
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total],
+ dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ if meter.count == 0: # skip empty meter
+ loss_str.append(
+ "{}: {}".format(name, "No data")
+ )
+ else:
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def global_avg(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ if meter.count == 0:
+ loss_str.append(
+ "{}: {}".format(name, "No data")
+ )
+ else:
+ loss_str.append(
+ "{}: {:.4f}".format(name, meter.global_avg)
+ )
+ return self.delimiter.join(loss_str)
+
+ def get_global_avg_dict(self, prefix=""):
+ """include a separator (e.g., `/`, or "_") at the end of `prefix`"""
+ d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. for k, m in self.meters.items()}
+ return d
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, log_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % log_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ logger.info(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ res_mem=torch.cuda.max_memory_reserved() / MB,
+ ))
+ else:
+ logger.info(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def compute_acc(logits, label, reduction='mean'):
+ ret = (torch.argmax(logits, dim=1) == label).float()
+ if reduction == 'none':
+ return ret.detach()
+ elif reduction == 'mean':
+ return ret.mean().item()
+
+
+def compute_n_params(model, return_str=True):
+ tot = 0
+ for p in model.parameters():
+ w = 1
+ for x in p.shape:
+ w *= x
+ tot += w
+ if return_str:
+ if tot >= 1e6:
+ return '{:.1f}M'.format(tot / 1e6)
+ else:
+ return '{:.1f}K'.format(tot / 1e3)
+ else:
+ return tot
+
+
+def setup_seed(seed):
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def remove_files_if_exist(file_paths):
+ for fp in file_paths:
+ if os.path.isfile(fp):
+ os.remove(fp)
+
+
+def save_json(data, filename, save_pretty=False, sort_keys=False):
+ with open(filename, "w") as f:
+ if save_pretty:
+ f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
+ else:
+ json.dump(data, f)
+
+
+def load_json(filename):
+ with open(filename, "r") as f:
+ return json.load(f)
+
+
+def flat_list_of_lists(l):
+ """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
+ return [item for sublist in l for item in sublist]
+
+
+def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]):
+ """
+ Args:
+ root: path to the directory to start search files
+ suffix: any str as suffix, or can match multiple such strings
+ when input is List[str].
+ Example 1, e.g., suffix: `.jpg` or [`.jpg`, `.png`]
+ Example 2, e.g., use a `*` in the `suffix`: `START*.jpg.`.
+ """
+ if isinstance(suffix, str):
+ suffix = [suffix, ]
+ filepaths = flat_list_of_lists(
+ [list(Path(root).rglob(f"*{e}")) for e in suffix])
+ return filepaths
+
+
+def match_key_and_shape(state_dict1, state_dict2):
+ keys1 = set(state_dict1.keys())
+ keys2 = set(state_dict2.keys())
+ print(f"keys1 - keys2: {keys1 - keys2}")
+ print(f"keys2 - keys1: {keys2 - keys1}")
+
+ mismatch = 0
+ for k in list(keys1):
+ if state_dict1[k].shape != state_dict2[k].shape:
+ print(
+ f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}")
+ mismatch += 1
+ print(f"mismatch {mismatch}")
+
+
+def merge_dicts(list_dicts):
+ merged_dict = list_dicts[0].copy()
+ for i in range(1, len(list_dicts)):
+ merged_dict.update(list_dicts[i])
+ return merged_dict
diff --git a/third_party/InternVideo/Data/InternVid/utils/config.py b/third_party/InternVideo/Data/InternVid/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f9ef375b37daa6926f2259502913e38f22e6e2
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/utils/config.py
@@ -0,0 +1,281 @@
+from __future__ import annotations
+
+import argparse
+import ast
+import json
+import os
+import os.path as osp
+import re
+import shutil
+import sys
+import tempfile
+from copy import deepcopy
+from importlib import import_module
+
+import yaml
+
+from .easydict import EasyDict
+
+__all__ = ["Config", "pretty_text"]
+
+
+BASE_KEY = "_base_"
+# BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"}
+BASE_CONFIG = {}
+
+cfg = None
+
+
+class Config(object):
+ """config"""
+
+ @classmethod
+ def pretty_text(cls, cfg: dict, indent=2) -> str:
+ """format dict to a string
+
+ Args:
+ cfg (EasyDict): the params.
+
+ Returns: The string to display.
+
+ """
+ msg = "{\n"
+ for i, (k, v) in enumerate(cfg.items()):
+ if isinstance(v, dict):
+ v = cls.pretty_text(v, indent + 4)
+ spaces = " " * indent
+ msg += spaces + "{}: {}".format(k, v)
+ if i == len(cfg) - 1:
+ msg += " }"
+ else:
+ msg += "\n"
+ return msg
+
+ @classmethod
+ def dump(cls, cfg, savepath=None):
+ """dump cfg to `json` file.
+
+ Args:
+ cfg (dict): The dict to dump.
+ savepath (str): The filepath to save the dumped dict.
+
+ Returns: TODO
+
+ """
+ if savepath is None:
+ savepath = osp.join(cfg.WORKSPACE, "config.json")
+ json.dump(cfg, open(savepath, "w"), indent=2)
+
+ @classmethod
+ def get_config(cls, default_config: dict = None):
+ """get a `Config` instance.
+
+ Args:
+ default_config (dict): The default config. `default_config` will be overrided
+ by config file `--cfg`, `--cfg` will be overrided by commandline args.
+
+ Returns: an EasyDict.
+ """
+ global cfg
+ if cfg is not None:
+ return cfg
+
+ # define arg parser.
+ parser = argparse.ArgumentParser()
+ # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str)
+ parser.add_argument(
+ "config_file", help="the configuration file to load. support: .yaml, .json, .py"
+ )
+ parser.add_argument(
+ "opts",
+ default=None,
+ nargs="*",
+ help="overrided configs. List. Format: 'key1 name1 key2 name2'",
+ )
+ args = parser.parse_args()
+
+ cfg = EasyDict(BASE_CONFIG)
+ if osp.isfile(args.config_file):
+ cfg_from_file = cls.from_file(args.config_file)
+ cfg = merge_a_into_b(cfg_from_file, cfg)
+ cfg = cls.merge_list(cfg, args.opts)
+ cfg = eval_dict_leaf(cfg)
+
+ # update some keys to make them show at the last
+ for k in BASE_CONFIG:
+ cfg[k] = cfg.pop(k)
+ return cfg
+
+ @classmethod
+ def from_file(cls, filepath: str) -> EasyDict:
+ """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`.
+
+ Args:
+ filepath (str): The config file path.
+
+ Returns: TODO
+
+ """
+ filepath = osp.abspath(osp.expanduser(filepath))
+ if not osp.isfile(filepath):
+ raise IOError(f"File does not exist: {filepath}")
+ if filepath.endswith(".py"):
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+
+ shutil.copytree(osp.dirname(filepath), osp.join(temp_config_dir, "tmp_config"))
+ sys.path.insert(0, temp_config_dir)
+ mod = import_module("tmp_config." + osp.splitext(osp.basename(filepath))[0])
+ # mod = import_module(temp_module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith("__")
+ }
+ for k in list(sys.modules.keys()):
+ if "tmp_config" in k:
+ del sys.modules[k]
+ elif filepath.endswith((".yml", ".yaml")):
+ cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader)
+ elif filepath.endswith(".json"):
+ cfg_dict = json.load(open(filepath, "r"))
+ else:
+ raise IOError("Only py/yml/yaml/json type are supported now!")
+
+ cfg_text = filepath + "\n"
+ with open(filepath, "r") as f:
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict: # load configs in `BASE_KEY`
+ cfg_dir = osp.dirname(filepath)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = (
+ base_filename if isinstance(base_filename, list) else [base_filename]
+ )
+
+ cfg_dict_list = list()
+ for f in base_filename:
+ _cfg_dict = Config.from_file(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
+ raise KeyError("Duplicate key is not allowed among bases")
+ base_cfg_dict.update(c)
+
+ cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict)
+
+ return EasyDict(cfg_dict)
+
+ @classmethod
+ def merge_list(cls, cfg, opts: list):
+ """merge commandline opts.
+
+ Args:
+ cfg: (dict): The config to be merged.
+ opts (list): The list to merge. Format: [key1, name1, key2, name2,...].
+ The keys can be nested. For example, ["a.b", v] will be considered
+ as `dict(a=dict(b=v))`.
+
+ Returns: dict.
+
+ """
+ assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}"
+ for i in range(0, len(opts), 2):
+ full_k, v = opts[i], opts[i + 1]
+ keys = full_k.split(".")
+ sub_d = cfg
+ for i, k in enumerate(keys):
+ if not hasattr(sub_d, k):
+ raise ValueError(f"The key {k} not exist in the config. Full key:{full_k}")
+ if i != len(keys) - 1:
+ sub_d = sub_d[k]
+ else:
+ sub_d[k] = v
+ return cfg
+
+
+def merge_a_into_b(a, b, inplace=False):
+ """The values in a will override values in b.
+
+ Args:
+ a (dict): source dict.
+ b (dict): target dict.
+
+ Returns: dict. recursively merge dict a into dict b.
+
+ """
+ if not inplace:
+ b = deepcopy(b)
+ for key in a:
+ if key in b:
+ if isinstance(a[key], dict) and isinstance(b[key], dict):
+ b[key] = merge_a_into_b(a[key], b[key], inplace=True)
+ else:
+ b[key] = a[key]
+ else:
+ b[key] = a[key]
+ return b
+
+
+def eval_dict_leaf(d, orig_dict=None):
+ """eval values of dict leaf.
+
+ Args:
+ d (dict): The dict to eval.
+
+ Returns: dict.
+
+ """
+ if orig_dict is None:
+ orig_dict = d
+ for k, v in d.items():
+ if not isinstance(v, dict):
+ d[k] = eval_string(v, orig_dict)
+ else:
+ eval_dict_leaf(v, orig_dict)
+ return d
+
+
+def eval_string(string, d):
+ """automatically evaluate string to corresponding types.
+
+ For example:
+ not a string -> return the original input
+ '0' -> 0
+ '0.2' -> 0.2
+ '[0, 1, 2]' -> [0,1,2]
+ 'eval(1+2)' -> 3
+ 'eval(range(5))' -> [0,1,2,3,4]
+ '${a}' -> d.a
+
+
+
+ Args:
+ string (str): The value to evaluate.
+ d (dict): The
+
+ Returns: the corresponding type
+
+ """
+ if not isinstance(string, str):
+ return string
+ # if len(string) > 1 and string[0] == "[" and string[-1] == "]":
+ # return eval(string)
+ if string[0:5] == "eval(":
+ return eval(string[5:-1])
+
+ s0 = string
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
+ if s1 != s0:
+ while s1 != s0:
+ s0 = s1
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
+ return eval(s1)
+
+ try:
+ v = ast.literal_eval(string)
+ except:
+ v = string
+ return v
diff --git a/third_party/InternVideo/Data/InternVid/utils/config_utils.py b/third_party/InternVideo/Data/InternVid/utils/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..72e31c7c922e811e62e2b92e708ab087651c40c2
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/utils/config_utils.py
@@ -0,0 +1,60 @@
+import logging
+import os
+import sys
+from os.path import dirname, join
+
+from utils.config import Config
+from utils.distributed import init_distributed_mode, is_main_process
+from utils.logger import setup_logger
+
+logger = logging.getLogger(__name__)
+
+
+def setup_config():
+ """Conbine yaml config and command line config with OmegaConf.
+ Also converts types, e.g., `'None'` (str) --> `None` (None)
+ """
+ config = Config.get_config()
+ if config.debug:
+ config.wandb.enable = False
+ return config
+
+
+def setup_evaluate_config(config):
+ """setup evaluation default settings, e.g., disable wandb"""
+ assert config.evaluate
+ config.wandb.enable = False
+ if config.output_dir is None:
+ config.output_dir = join(dirname(config.pretrained_path), "eval")
+ return config
+
+
+def setup_output_dir(output_dir, excludes=["code"]):
+ """ensure not overwritting an exisiting/non-empty output dir"""
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir, exist_ok=False)
+ else:
+ existing_dirs_files = os.listdir(output_dir) # list
+ remaining = set(existing_dirs_files) - set(excludes)
+ remaining = [e for e in remaining if "slurm" not in e]
+ remaining = [e for e in remaining if ".out" not in e]
+ # assert len(remaining) == 0, f"remaining dirs or files: {remaining}"
+ logger.warn(f"remaining dirs or files: {remaining}")
+
+
+def setup_main():
+ """
+ Setup config, logger, output_dir, etc.
+ Shared for pretrain and all downstream tasks.
+ """
+ config = setup_config()
+ if hasattr(config, "evaluate") and config.evaluate:
+ config = setup_evaluate_config(config)
+ init_distributed_mode(config)
+
+ if is_main_process():
+ setup_output_dir(config.output_dir, excludes=["code"])
+ setup_logger(output=config.output_dir, color=True, name="vindlu")
+ logger.info(f"config: {Config.pretty_text(config)}")
+ Config.dump(config, os.path.join(config.output_dir, "config.json"))
+ return config
diff --git a/third_party/InternVideo/Data/InternVid/utils/distributed.py b/third_party/InternVideo/Data/InternVid/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/InternVideo/Data/InternVid/utils/easydict.py b/third_party/InternVideo/Data/InternVid/utils/easydict.py
new file mode 100644
index 0000000000000000000000000000000000000000..241aca41c9f1b0677be4bf6070c077fa24501816
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/utils/easydict.py
@@ -0,0 +1,149 @@
+class EasyDict(dict):
+ """
+ Get attributes
+
+ >>> d = EasyDict({'foo':3})
+ >>> d['foo']
+ 3
+ >>> d.foo
+ 3
+ >>> d.bar
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'bar'
+
+ Works recursively
+
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+ >>> isinstance(d.bar, dict)
+ True
+ >>> d.bar.x
+ 1
+
+ Bullet-proof
+
+ >>> EasyDict({})
+ {}
+ >>> EasyDict(d={})
+ {}
+ >>> EasyDict(None)
+ {}
+ >>> d = {'a': 1}
+ >>> EasyDict(**d)
+ {'a': 1}
+
+ Set attributes
+
+ >>> d = EasyDict()
+ >>> d.foo = 3
+ >>> d.foo
+ 3
+ >>> d.bar = {'prop': 'value'}
+ >>> d.bar.prop
+ 'value'
+ >>> d
+ {'foo': 3, 'bar': {'prop': 'value'}}
+ >>> d.bar.prop = 'newer'
+ >>> d.bar.prop
+ 'newer'
+
+
+ Values extraction
+
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
+ >>> isinstance(d.bar, list)
+ True
+ >>> from operator import attrgetter
+ >>> map(attrgetter('x'), d.bar)
+ [1, 3]
+ >>> map(attrgetter('y'), d.bar)
+ [2, 4]
+ >>> d = EasyDict()
+ >>> d.keys()
+ []
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
+ >>> d.foo
+ 3
+ >>> d.bar.x
+ 1
+
+ Still like a dict though
+
+ >>> o = EasyDict({'clean':True})
+ >>> o.items()
+ [('clean', True)]
+
+ And like a class
+
+ >>> class Flower(EasyDict):
+ ... power = 1
+ ...
+ >>> f = Flower()
+ >>> f.power
+ 1
+ >>> f = Flower({'height': 12})
+ >>> f.height
+ 12
+ >>> f['power']
+ 1
+ >>> sorted(f.keys())
+ ['height', 'power']
+
+ update and pop items
+ >>> d = EasyDict(a=1, b='2')
+ >>> e = EasyDict(c=3.0, a=9.0)
+ >>> d.update(e)
+ >>> d.c
+ 3.0
+ >>> d['c']
+ 3.0
+ >>> d.get('c')
+ 3.0
+ >>> d.update(a=4, b=4)
+ >>> d.b
+ 4
+ >>> d.pop('a')
+ 4
+ >>> d.a
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'a'
+ """
+
+ def __init__(self, d=None, **kwargs):
+ if d is None:
+ d = {}
+ if kwargs:
+ d.update(**kwargs)
+ for k, v in d.items():
+ setattr(self, k, v)
+ # Class attributes
+ for k in self.__class__.__dict__.keys():
+ if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
+ setattr(self, k, getattr(self, k))
+
+ def __setattr__(self, name, value):
+ if isinstance(value, (list, tuple)):
+ value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
+ value = self.__class__(value)
+ super(EasyDict, self).__setattr__(name, value)
+ super(EasyDict, self).__setitem__(name, value)
+
+ __setitem__ = __setattr__
+
+ def update(self, e=None, **f):
+ d = e or dict()
+ d.update(f)
+ for k in d:
+ setattr(self, k, d[k])
+
+ def pop(self, k, d=None):
+ if hasattr(self, k):
+ delattr(self, k)
+ return super(EasyDict, self).pop(k, d)
+
+
+if __name__ == "__main__":
+ import doctest
+
diff --git a/third_party/InternVideo/Data/InternVid/utils/logger.py b/third_party/InternVideo/Data/InternVid/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/InternVideo/Data/InternVid/utils/optimizer.py b/third_party/InternVideo/Data/InternVid/utils/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..679483b72556c83d6ff19bc51fe4db41c656b56d
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/utils/optimizer.py
@@ -0,0 +1,133 @@
+""" Optimizer Factory w/ Custom Weight Decay
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import re
+import torch
+from torch import optim as optim
+from utils.distributed import is_main_process
+import logging
+logger = logging.getLogger(__name__)
+try:
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
+ has_apex = True
+except ImportError:
+ has_apex = False
+
+
+def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
+ named_param_tuples = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
+ named_param_tuples.append([name, param, 0])
+ elif name in no_decay_list:
+ named_param_tuples.append([name, param, 0])
+ else:
+ named_param_tuples.append([name, param, weight_decay])
+ return named_param_tuples
+
+
+def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
+ """use lr=diff_lr for modules named found in diff_lr_names,
+ otherwise use lr=default_lr
+
+ Args:
+ named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module
+ diff_lr_names: List(str)
+ diff_lr: float
+ default_lr: float
+ Returns:
+ named_param_tuples_with_lr: List([name, param, weight_decay, lr])
+ """
+ named_param_tuples_with_lr = []
+ logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
+ for name, p, wd in named_param_tuples_or_model:
+ use_diff_lr = False
+ for diff_name in diff_lr_names:
+ # if diff_name in name:
+ if re.search(diff_name, name) is not None:
+ logger.info(f"param {name} use different_lr: {diff_lr}")
+ use_diff_lr = True
+ break
+
+ named_param_tuples_with_lr.append(
+ [name, p, wd, diff_lr if use_diff_lr else default_lr]
+ )
+
+ if is_main_process():
+ for name, _, wd, diff_lr in named_param_tuples_with_lr:
+ logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}")
+
+ return named_param_tuples_with_lr
+
+
+def create_optimizer_params_group(named_param_tuples_with_lr):
+ """named_param_tuples_with_lr: List([name, param, weight_decay, lr])"""
+ group = {}
+ for name, p, wd, lr in named_param_tuples_with_lr:
+ if wd not in group:
+ group[wd] = {}
+ if lr not in group[wd]:
+ group[wd][lr] = []
+ group[wd][lr].append(p)
+
+ optimizer_params_group = []
+ for wd, lr_groups in group.items():
+ for lr, p in lr_groups.items():
+ optimizer_params_group.append(dict(
+ params=p,
+ weight_decay=wd,
+ lr=lr
+ ))
+ logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
+ return optimizer_params_group
+
+
+def create_optimizer(args, model, filter_bias_and_bn=True):
+ opt_lower = args.opt.lower()
+ weight_decay = args.weight_decay
+ # check for modules that requires different lr
+ if hasattr(args, "different_lr") and args.different_lr.enable:
+ diff_lr_module_names = args.different_lr.module_names
+ diff_lr = args.different_lr.lr
+ else:
+ diff_lr_module_names = []
+ diff_lr = None
+
+ no_decay = {}
+ if hasattr(model, 'no_weight_decay'):
+ no_decay = model.no_weight_decay()
+ named_param_tuples = add_weight_decay(
+ model, weight_decay, no_decay, filter_bias_and_bn)
+ named_param_tuples = add_different_lr(
+ named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
+ parameters = create_optimizer_params_group(named_param_tuples)
+
+ if 'fused' in opt_lower:
+ assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
+
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
+ if hasattr(args, 'opt_eps') and args.opt_eps is not None:
+ opt_args['eps'] = args.opt_eps
+ if hasattr(args, 'opt_betas') and args.opt_betas is not None:
+ opt_args['betas'] = args.opt_betas
+ if hasattr(args, 'opt_args') and args.opt_args is not None:
+ opt_args.update(args.opt_args)
+
+ opt_split = opt_lower.split('_')
+ opt_lower = opt_split[-1]
+ if opt_lower == 'sgd' or opt_lower == 'nesterov':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'momentum':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+ elif opt_lower == 'adam':
+ optimizer = optim.Adam(parameters, **opt_args)
+ elif opt_lower == 'adamw':
+ optimizer = optim.AdamW(parameters, **opt_args)
+ else:
+ assert False and "Invalid optimizer"
+ raise ValueError
+ return optimizer
diff --git a/third_party/InternVideo/Data/InternVid/utils/scheduler.py b/third_party/InternVideo/Data/InternVid/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/InternVideo/Data/InternVid/viclip/__init__.py b/third_party/InternVideo/Data/InternVid/viclip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d72c0be64f7140f7898da87fbf581e8885418a9
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/viclip/__init__.py
@@ -0,0 +1,72 @@
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+from .viclip import ViCLIP
+import torch
+import numpy as np
+import cv2
+import os
+
+
+def get_viclip(size='l',
+ pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth")):
+
+ tokenizer = _Tokenizer()
+ vclip = ViCLIP(tokenizer=tokenizer, size=size, pretrain=pretrain)
+ m = {'viclip':vclip, 'tokenizer':tokenizer}
+
+ return m
+
+def get_text_feat_dict(texts, clip, tokenizer, text_feat_d={}):
+ for t in texts:
+ feat = clip.get_text_features(t, tokenizer, text_feat_d)
+ text_feat_d[t] = feat
+ return text_feat_d
+
+def get_vid_feat(frames, clip):
+ return clip.get_vid_features(frames)
+
+def _frame_from_video(video):
+ while video.isOpened():
+ success, frame = video.read()
+ if success:
+ yield frame
+ else:
+ break
+
+v_mean = np.array([0.485, 0.456, 0.406]).reshape(1,1,3)
+v_std = np.array([0.229, 0.224, 0.225]).reshape(1,1,3)
+def normalize(data):
+ return (data/255.0-v_mean)/v_std
+
+def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):
+ assert(len(vid_list) >= fnum)
+ step = len(vid_list) // fnum
+ vid_list = vid_list[::step][:fnum]
+ vid_list = [cv2.resize(x[:,:,::-1], target_size) for x in vid_list]
+ vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list]
+ vid_tube = np.concatenate(vid_tube, axis=1)
+ vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))
+ vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float()
+ return vid_tube
+
+def retrieve_text(frames,
+ texts,
+ models={'viclip':None,
+ 'tokenizer':None},
+ topk=5,
+ device=torch.device('cuda')):
+ # clip, tokenizer = get_clip(name, model_cfg['size'], model_cfg['pretrained'], model_cfg['reload'])
+ assert(type(models)==dict and models['viclip'] is not None and models['tokenizer'] is not None)
+ clip, tokenizer = models['viclip'], models['tokenizer']
+ clip = clip.to(device)
+ frames_tensor = frames2tensor(frames, device=device)
+ vid_feat = get_vid_feat(frames_tensor, clip)
+
+ text_feat_d = {}
+ text_feat_d = get_text_feat_dict(texts, clip, tokenizer, text_feat_d)
+ text_feats = [text_feat_d[t] for t in texts]
+ text_feats_tensor = torch.cat(text_feats, 0)
+
+ probs, idxs = clip.get_predict_label(vid_feat, text_feats_tensor, top=topk)
+
+ ret_texts = [texts[i] for i in idxs.numpy()[0].tolist()]
+ return ret_texts, probs.numpy()[0]
\ No newline at end of file
diff --git a/third_party/InternVideo/Data/InternVid/viclip/simple_tokenizer.py b/third_party/InternVideo/Data/InternVid/viclip/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f2feb796910c3ca7c29ec088d39d6fc75e46323
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/viclip/simple_tokenizer.py
@@ -0,0 +1,135 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+# @lru_cache()
+# def default_bpe():
+# return "bpe_simple_vocab_16e6.txt.gz"
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+''
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
diff --git a/third_party/InternVideo/Data/InternVid/viclip/viclip.py b/third_party/InternVideo/Data/InternVid/viclip/viclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b919f4f3649bf694588518fc270fc098664ec0b
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/viclip/viclip.py
@@ -0,0 +1,262 @@
+import os
+import logging
+
+import torch
+from einops import rearrange
+from torch import nn
+import math
+
+# from .criterions import VTC_VTM_Loss
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+from .viclip_vision import clip_joint_l14, clip_joint_b16
+from .viclip_text import clip_text_l14, clip_text_b16
+
+logger = logging.getLogger(__name__)
+
+
+class ViCLIP(nn.Module):
+ """docstring for ViCLIP"""
+
+ def __init__(self,
+ tokenizer=None,
+ size='l',
+ pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth"),
+ freeze_text=True):
+ super(ViCLIP, self).__init__()
+ if tokenizer:
+ self.tokenizer = tokenizer
+ else:
+ self.tokenizer = _Tokenizer()
+ self.max_txt_l = 32
+
+ if size.lower() == 'l':
+ self.vision_encoder_name = 'vit_l14'
+ elif size.lower() == 'b':
+ self.vision_encoder_name = 'vit_b16'
+ else:
+ raise NotImplementedError(f"Size {size} not implemented")
+
+ self.vision_encoder_pretrained = False
+ self.inputs_image_res = 224
+ self.vision_encoder_kernel_size = 1
+ self.vision_encoder_center = True
+ self.video_input_num_frames = 8
+ self.vision_encoder_drop_path_rate = 0.1
+ self.vision_encoder_checkpoint_num = 24
+ self.is_pretrain = pretrain
+ self.vision_width = 1024
+ self.text_width = 768
+ self.embed_dim = 768
+ self.masking_prob = 0.9
+
+ if size.lower() == 'l':
+ self.text_encoder_name = 'vit_l14'
+ elif size.lower() == 'b':
+ self.text_encoder_name = 'vit_b16'
+ else:
+ raise NotImplementedError(f"Size {size} not implemented")
+
+ self.text_encoder_pretrained = False#'bert-base-uncased'
+ self.text_encoder_d_model = 768
+
+ self.text_encoder_vocab_size = 49408
+
+ # create modules.
+ self.vision_encoder = self.build_vision_encoder()
+ self.text_encoder = self.build_text_encoder()
+
+ self.temp = nn.parameter.Parameter(torch.ones([]) * 1 / 100.0)
+ self.temp_min = 1 / 100.0
+
+ if pretrain:
+ logger.info(f"Load pretrained weights from {pretrain}")
+ state_dict = torch.load(pretrain, map_location='cpu')['model']
+ self.load_state_dict(state_dict)
+
+ # Freeze weights
+ if freeze_text:
+ self.freeze_text()
+
+
+ def freeze_text(self):
+ """freeze text encoder"""
+ for p in self.text_encoder.parameters():
+ p.requires_grad = False
+
+ def no_weight_decay(self):
+ ret = {"temp"}
+ ret.update(
+ {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
+ )
+ ret.update(
+ {"text_encoder." + k for k in self.text_encoder.no_weight_decay()}
+ )
+
+ return ret
+
+ def forward(self, image, text, raw_text, idx, log_generation=None, return_sims=False):
+ """forward and calculate loss.
+
+ Args:
+ image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
+ text (dict): TODO
+ idx (torch.Tensor): TODO
+
+ Returns: TODO
+
+ """
+ self.clip_contrastive_temperature()
+
+ vision_embeds = self.encode_vision(image)
+ text_embeds = self.encode_text(raw_text)
+ if return_sims:
+ sims = torch.nn.functional.normalize(vision_embeds, dim=-1) @ \
+ torch.nn.functional.normalize(text_embeds, dim=-1).transpose(0, 1)
+ return sims
+
+ # calculate loss
+
+ ## VTC loss
+ loss_vtc = self.clip_loss.vtc_loss(
+ vision_embeds, text_embeds, idx, self.temp, all_gather=True
+ )
+
+ return dict(
+ loss_vtc=loss_vtc,
+ )
+
+ def encode_vision(self, image, test=False):
+ """encode image / videos as features.
+
+ Args:
+ image (torch.Tensor): The input images.
+ test (bool): Whether testing.
+
+ Returns: tuple.
+ - vision_embeds (torch.Tensor): The features of all patches. Shape: [B,T,L,C].
+ - pooled_vision_embeds (torch.Tensor): The pooled features. Shape: [B,T,C].
+
+ """
+ if image.ndim == 5:
+ image = image.permute(0, 2, 1, 3, 4).contiguous()
+ else:
+ image = image.unsqueeze(2)
+
+ if not test and self.masking_prob > 0.0:
+ return self.vision_encoder(
+ image, masking_prob=self.masking_prob
+ )
+
+ return self.vision_encoder(image)
+
+ def encode_text(self, text):
+ """encode text.
+ Args:
+ text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
+ - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
+ - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token.
+ - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".
+ Returns: tuple.
+ - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C].
+ - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C].
+
+ """
+ device = next(self.text_encoder.parameters()).device
+ text = self.text_encoder.tokenize(
+ text, context_length=self.max_txt_l
+ ).to(device)
+ text_embeds = self.text_encoder(text)
+ return text_embeds
+
+ @torch.no_grad()
+ def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
+ """Seems only used during pre-training"""
+ self.temp.clamp_(min=self.temp_min)
+
+ def build_vision_encoder(self):
+ """build vision encoder
+ Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
+
+ """
+ encoder_name = self.vision_encoder_name
+ if encoder_name == "vit_l14":
+ vision_encoder = clip_joint_l14(
+ pretrained=self.vision_encoder_pretrained,
+ input_resolution=self.inputs_image_res,
+ kernel_size=self.vision_encoder_kernel_size,
+ center=self.vision_encoder_center,
+ num_frames=self.video_input_num_frames,
+ drop_path=self.vision_encoder_drop_path_rate,
+ checkpoint_num=self.vision_encoder_checkpoint_num,
+ )
+ elif encoder_name == "vit_b16":
+ vision_encoder = clip_joint_b16(
+ pretrained=self.vision_encoder_pretrained,
+ input_resolution=self.inputs_image_res,
+ kernel_size=self.vision_encoder_kernel_size,
+ center=self.vision_encoder_center,
+ num_frames=self.video_input_num_frames,
+ drop_path=self.vision_encoder_drop_path_rate,
+ checkpoint_num=self.vision_encoder_checkpoint_num,
+ )
+ else:
+ raise NotImplementedError(f"Not implemented: {encoder_name}")
+
+ return vision_encoder
+
+ def build_text_encoder(self):
+ """build text_encoder and possiblly video-to-text multimodal fusion encoder.
+ Returns: nn.Module. The text encoder
+
+ """
+ encoder_name = self.text_encoder_name
+
+ if encoder_name == "vit_l14":
+ text_encoder = clip_text_l14(
+ pretrained=self.text_encoder_pretrained,
+ context_length=self.max_txt_l,
+ vocab_size=self.text_encoder_vocab_size,
+ checkpoint_num=0,
+ )
+ elif encoder_name == "vit_b16":
+ text_encoder = clip_text_b16(
+ pretrained=self.text_encoder_pretrained,
+ context_length=self.max_txt_l,
+ vocab_size=self.text_encoder_vocab_size,
+ checkpoint_num=0,
+ )
+ else:
+ raise NotImplementedError(f"Not implemented: {encoder_name}")
+
+ return text_encoder
+
+ def get_text_encoder(self):
+ """get text encoder, used for text and cross-modal encoding"""
+ encoder = self.text_encoder
+ return encoder.bert if hasattr(encoder, "bert") else encoder
+
+ def get_text_features(self, input_text, tokenizer, text_feature_dict={}):
+ if input_text in text_feature_dict:
+ return text_feature_dict[input_text]
+ text_template= f"{input_text}"
+ with torch.no_grad():
+ # text_token = tokenizer.encode(text_template).cuda()
+ text_features = self.encode_text(text_template).float()
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+ text_feature_dict[input_text] = text_features
+ return text_features
+
+ def get_vid_features(self, input_frames):
+ with torch.no_grad():
+ clip_feat = self.encode_vision(input_frames,test=True).float()
+ clip_feat /= clip_feat.norm(dim=-1, keepdim=True)
+ return clip_feat
+
+ def get_predict_label(self, clip_feature, text_feats_tensor, top=5):
+ label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1)
+ top_probs, top_labels = label_probs.cpu().topk(top, dim=-1)
+ return top_probs, top_labels
+
+
+if __name__ =="__main__":
+ tokenizer = _Tokenizer()
diff --git a/third_party/InternVideo/Data/InternVid/viclip/viclip_text.py b/third_party/InternVideo/Data/InternVid/viclip/viclip_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..680b2430844ef63da12a90afe40cf2e612fb04a2
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/viclip/viclip_text.py
@@ -0,0 +1,297 @@
+import os
+import logging
+from collections import OrderedDict
+from pkg_resources import packaging
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+import torch.utils.checkpoint as checkpoint
+import functools
+
+logger = logging.getLogger(__name__)
+
+
+# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
+MODEL_PATH = 'https://huggingface.co/laion'
+_MODELS = {
+ "ViT-L/14": os.path.join(MODEL_PATH, "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"),
+ "ViT-B/16": os.path.join(MODEL_PATH, "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"),
+}
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
+ checkpoint_num: int = 0):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+ self.checkpoint_num = checkpoint_num
+
+ def forward(self, x: torch.Tensor):
+ if self.checkpoint_num > 0:
+ segments = min(self.checkpoint_num, len(self.resblocks))
+ return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
+ else:
+ return self.resblocks(x)
+
+
+class CLIP_TEXT(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int,
+ checkpoint_num: int,
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+ self._tokenizer = _Tokenizer()
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask(),
+ checkpoint_num=checkpoint_num,
+ )
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+
+ def no_weight_decay(self):
+ return {'token_embedding', 'positional_embedding'}
+
+ @functools.lru_cache(maxsize=None)
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def tokenize(self, texts, context_length=77, truncate=True):
+ """
+ Returns the tokenized representation of given input string(s)
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+ truncate: bool
+ Whether to truncate the text in case its encoding is longer than the context length
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = self._tokenizer.encoder["<|startoftext|>"]
+ eot_token = self._tokenizer.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+ else:
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ if truncate:
+ tokens = tokens[:context_length]
+ tokens[-1] = eot_token
+ else:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
+
+ def forward(self, text):
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+
+def clip_text_b16(
+ embed_dim=512,
+ context_length=77,
+ vocab_size=49408,
+ transformer_width=512,
+ transformer_heads=8,
+ transformer_layers=12,
+ checkpoint_num=0,
+ pretrained=True,
+):
+ # raise NotImplementedError
+ model = CLIP_TEXT(
+ embed_dim,
+ context_length,
+ vocab_size,
+ transformer_width,
+ transformer_heads,
+ transformer_layers,
+ checkpoint_num,
+ )
+ # pretrained = _MODELS["ViT-B/16"]
+ # logger.info(f"Load pretrained weights from {pretrained}")
+ # state_dict = torch.load(pretrained, map_location='cpu')
+ # model.load_state_dict(state_dict, strict=False)
+ # return model.eval()
+ if pretrained:
+ if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
+ pretrained = _MODELS[pretrained]
+ else:
+ pretrained = _MODELS["ViT-B/16"]
+ logger.info(f"Load pretrained weights from {pretrained}")
+ state_dict = torch.load(pretrained, map_location='cpu')
+ if context_length != state_dict["positional_embedding"].size(0):
+ # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
+ print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
+ if context_length < state_dict["positional_embedding"].size(0):
+ state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
+ else:
+ state_dict["positional_embedding"] = F.pad(
+ state_dict["positional_embedding"],
+ (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
+ value=0,
+ )
+
+ message = model.load_state_dict(state_dict, strict=False)
+ print(f"Load pretrained weights from {pretrained}: {message}")
+ return model.eval()
+
+
+def clip_text_l14(
+ embed_dim=768,
+ context_length=77,
+ vocab_size=49408,
+ transformer_width=768,
+ transformer_heads=12,
+ transformer_layers=12,
+ checkpoint_num=0,
+ pretrained=True,
+):
+ model = CLIP_TEXT(
+ embed_dim,
+ context_length,
+ vocab_size,
+ transformer_width,
+ transformer_heads,
+ transformer_layers,
+ checkpoint_num,
+ )
+ if pretrained:
+ if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
+ pretrained = _MODELS[pretrained]
+ else:
+ pretrained = _MODELS["ViT-L/14"]
+ logger.info(f"Load pretrained weights from {pretrained}")
+ state_dict = torch.load(pretrained, map_location='cpu')
+ if context_length != state_dict["positional_embedding"].size(0):
+ # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
+ print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
+ if context_length < state_dict["positional_embedding"].size(0):
+ state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
+ else:
+ state_dict["positional_embedding"] = F.pad(
+ state_dict["positional_embedding"],
+ (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
+ value=0,
+ )
+
+ message = model.load_state_dict(state_dict, strict=False)
+ print(f"Load pretrained weights from {pretrained}: {message}")
+ return model.eval()
+
+
+def clip_text_l14_336(
+ embed_dim=768,
+ context_length=77,
+ vocab_size=49408,
+ transformer_width=768,
+ transformer_heads=12,
+ transformer_layers=12,
+):
+ raise NotImplementedError
+ model = CLIP_TEXT(
+ embed_dim,
+ context_length,
+ vocab_size,
+ transformer_width,
+ transformer_heads,
+ transformer_layers
+ )
+ pretrained = _MODELS["ViT-L/14_336"]
+ logger.info(f"Load pretrained weights from {pretrained}")
+ state_dict = torch.load(pretrained, map_location='cpu')
+ model.load_state_dict(state_dict, strict=False)
+ return model.eval()
+
+
+def build_clip(config):
+ model_cls = config.text_encoder.clip_teacher
+ model = eval(model_cls)()
+ return model
diff --git a/third_party/InternVideo/Data/InternVid/viclip/viclip_vision.py b/third_party/InternVideo/Data/InternVid/viclip/viclip_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..afa9a1903ccb3aa2b94ba3d1b1da68a8752df3b3
--- /dev/null
+++ b/third_party/InternVideo/Data/InternVid/viclip/viclip_vision.py
@@ -0,0 +1,362 @@
+#!/usr/bin/env python
+import os
+import logging
+from collections import OrderedDict
+
+import torch
+from torch import nn
+from einops import rearrange
+from timm.models.layers import DropPath
+from timm.models.registry import register_model
+
+import torch.utils.checkpoint as checkpoint
+
+# from models.utils import load_temp_embed_with_mismatch
+
+logger = logging.getLogger(__name__)
+
+def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
+ """
+ Add/Remove extra temporal_embeddings as needed.
+ https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
+
+ temp_embed_old: (1, num_frames_old, 1, d)
+ temp_embed_new: (1, num_frames_new, 1, d)
+ add_zero: bool, if True, add zero, else, interpolate trained embeddings.
+ """
+ # TODO zero pad
+ num_frms_new = temp_embed_new.shape[1]
+ num_frms_old = temp_embed_old.shape[1]
+ logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
+ if num_frms_new > num_frms_old:
+ if add_zero:
+ temp_embed_new[
+ :, :num_frms_old
+ ] = temp_embed_old # untrained embeddings are zeros.
+ else:
+ temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
+ elif num_frms_new < num_frms_old:
+ temp_embed_new = temp_embed_old[:, :num_frms_new]
+ else: # =
+ temp_embed_new = temp_embed_old
+ return temp_embed_new
+
+
+# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
+MODEL_PATH = ''
+_MODELS = {
+ "ViT-L/14": os.path.join(MODEL_PATH, "ViCLIP-L_InternVid-FLT-10M.pth"),
+ "ViT-B/16": os.path.join(MODEL_PATH, "ViCLIP-B-InternVid-FLT-10M.pth"),
+}
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model, n_head, drop_path=0., attn_mask=None, dropout=0.):
+ super().__init__()
+
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ # logger.info(f'Droppath: {drop_path}')
+ self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
+ self.ln_1 = nn.LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("drop1", nn.Dropout(dropout)),
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
+ ("drop2", nn.Dropout(dropout)),
+ ]))
+ self.ln_2 = nn.LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x):
+ x = x + self.drop_path1(self.attention(self.ln_1(x)))
+ x = x + self.drop_path2(self.mlp(self.ln_2(x)))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width, layers, heads, drop_path=0., checkpoint_num=0, dropout=0.):
+ super().__init__()
+ dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
+ self.resblocks = nn.ModuleList()
+ for idx in range(layers):
+ self.resblocks.append(ResidualAttentionBlock(width, heads, drop_path=dpr[idx], dropout=dropout))
+ self.checkpoint_num = checkpoint_num
+
+ def forward(self, x):
+ for idx, blk in enumerate(self.resblocks):
+ if idx < self.checkpoint_num:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ return x
+
+
+class VisionTransformer(nn.Module):
+ def __init__(
+ self, input_resolution, patch_size, width, layers, heads, output_dim=None,
+ kernel_size=1, num_frames=8, drop_path=0, checkpoint_num=0, dropout=0.,
+ temp_embed=True,
+ ):
+ super().__init__()
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv3d(
+ 3, width,
+ (kernel_size, patch_size, patch_size),
+ (kernel_size, patch_size, patch_size),
+ (0, 0, 0), bias=False
+ )
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+ self.ln_pre = nn.LayerNorm(width)
+ if temp_embed:
+ self.temporal_positional_embedding = nn.Parameter(torch.zeros(1, num_frames, width))
+
+ self.transformer = Transformer(
+ width, layers, heads, drop_path=drop_path, checkpoint_num=checkpoint_num,
+ dropout=dropout)
+
+ self.ln_post = nn.LayerNorm(width)
+ if output_dim is not None:
+ self.proj = nn.Parameter(torch.empty(width, output_dim))
+ else:
+ self.proj = None
+
+ self.dropout = nn.Dropout(dropout)
+
+ def get_num_layers(self):
+ return len(self.transformer.resblocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'positional_embedding', 'class_embedding', 'temporal_positional_embedding'}
+
+ def mask_tokens(self, inputs, masking_prob=0.0):
+ B, L, _ = inputs.shape
+
+ # This is different from text as we are masking a fix number of tokens
+ Lm = int(masking_prob * L)
+ masked_indices = torch.zeros(B, L)
+ indices = torch.argsort(torch.rand_like(masked_indices), dim=-1)[:, :Lm]
+ batch_indices = (
+ torch.arange(masked_indices.shape[0]).unsqueeze(-1).expand_as(indices)
+ )
+ masked_indices[batch_indices, indices] = 1
+
+ masked_indices = masked_indices.bool()
+
+ return inputs[~masked_indices].reshape(B, -1, inputs.shape[-1])
+
+ def forward(self, x, masking_prob=0.0):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ B, C, T, H, W = x.shape
+ x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C)
+
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+
+ # temporal pos
+ cls_tokens = x[:B, :1, :]
+ x = x[:, 1:]
+ x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
+ if hasattr(self, 'temporal_positional_embedding'):
+ if x.size(1) == 1:
+ # This is a workaround for unused parameter issue
+ x = x + self.temporal_positional_embedding.mean(1)
+ else:
+ x = x + self.temporal_positional_embedding
+ x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)
+
+ if masking_prob > 0.0:
+ x = self.mask_tokens(x, masking_prob)
+
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) #BND -> NBD
+ x = self.transformer(x)
+
+ x = self.ln_post(x)
+
+ if self.proj is not None:
+ x = self.dropout(x[0]) @ self.proj
+ else:
+ x = x.permute(1, 0, 2) #NBD -> BND
+
+ return x
+
+
+def inflate_weight(weight_2d, time_dim, center=True):
+ logger.info(f'Init center: {center}')
+ if center:
+ weight_3d = torch.zeros(*weight_2d.shape)
+ weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+ middle_idx = time_dim // 2
+ weight_3d[:, :, middle_idx, :, :] = weight_2d
+ else:
+ weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+ weight_3d = weight_3d / time_dim
+ return weight_3d
+
+
+def load_state_dict(model, state_dict, input_resolution=224, patch_size=16, center=True):
+ state_dict_3d = model.state_dict()
+ for k in state_dict.keys():
+ if k in state_dict_3d.keys() and state_dict[k].shape != state_dict_3d[k].shape:
+ if len(state_dict_3d[k].shape) <= 2:
+ logger.info(f'Ignore: {k}')
+ continue
+ logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
+ time_dim = state_dict_3d[k].shape[2]
+ state_dict[k] = inflate_weight(state_dict[k], time_dim, center=center)
+
+ pos_embed_checkpoint = state_dict['positional_embedding']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = (input_resolution // patch_size) ** 2
+ orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5)
+ new_size = int(num_patches ** 0.5)
+ if orig_size != new_size:
+ logger.info(f'Pos_emb from {orig_size} to {new_size}')
+ extra_tokens = pos_embed_checkpoint[:1]
+ pos_tokens = pos_embed_checkpoint[1:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
+ state_dict['positional_embedding'] = new_pos_embed
+
+ message = model.load_state_dict(state_dict, strict=False)
+ logger.info(f"Load pretrained weights: {message}")
+
+
+@register_model
+def clip_joint_b16(
+ pretrained=False, input_resolution=224, kernel_size=1,
+ center=True, num_frames=8, drop_path=0., checkpoint_num=0,
+ dropout=0.,
+):
+ model = VisionTransformer(
+ input_resolution=input_resolution, patch_size=16,
+ width=768, layers=12, heads=12, output_dim=512,
+ kernel_size=kernel_size, num_frames=num_frames,
+ drop_path=drop_path, checkpoint_num=checkpoint_num,
+ dropout=dropout,
+ )
+ # raise NotImplementedError
+ if pretrained:
+ if isinstance(pretrained, str):
+ model_name = pretrained
+ else:
+ model_name = "ViT-B/16"
+
+ logger.info('load pretrained weights')
+ state_dict = torch.load(_MODELS[model_name], map_location='cpu')
+ load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=16, center=center)
+ return model.eval()
+
+
+@register_model
+def clip_joint_l14(
+ pretrained=False, input_resolution=224, kernel_size=1,
+ center=True, num_frames=8, drop_path=0., checkpoint_num=0,
+ dropout=0.,
+):
+ model = VisionTransformer(
+ input_resolution=input_resolution, patch_size=14,
+ width=1024, layers=24, heads=16, output_dim=768,
+ kernel_size=kernel_size, num_frames=num_frames,
+ drop_path=drop_path, checkpoint_num=checkpoint_num,
+ dropout=dropout,
+ )
+
+ if pretrained:
+ if isinstance(pretrained, str):
+ model_name = pretrained
+ else:
+ model_name = "ViT-L/14"
+ logger.info('load pretrained weights')
+ state_dict = torch.load(_MODELS[model_name], map_location='cpu')
+ load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
+ return model.eval()
+
+
+@register_model
+def clip_joint_l14_336(
+ pretrained=True, input_resolution=336, kernel_size=1,
+ center=True, num_frames=8, drop_path=0.
+):
+ raise NotImplementedError
+ model = VisionTransformer(
+ input_resolution=input_resolution, patch_size=14,
+ width=1024, layers=24, heads=16, output_dim=768,
+ kernel_size=kernel_size, num_frames=num_frames,
+ drop_path=drop_path,
+ )
+ if pretrained:
+ logger.info('load pretrained weights')
+ state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu')
+ load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
+ return model.eval()
+
+
+def interpolate_pos_embed_vit(state_dict, new_model):
+ key = "vision_encoder.temporal_positional_embedding"
+ if key in state_dict:
+ vision_temp_embed_new = new_model.state_dict()[key]
+ vision_temp_embed_new = vision_temp_embed_new.unsqueeze(2) # [1, n, d] -> [1, n, 1, d]
+ vision_temp_embed_old = state_dict[key]
+ vision_temp_embed_old = vision_temp_embed_old.unsqueeze(2)
+
+ state_dict[key] = load_temp_embed_with_mismatch(
+ vision_temp_embed_old, vision_temp_embed_new, add_zero=False
+ ).squeeze(2)
+
+ key = "text_encoder.positional_embedding"
+ if key in state_dict:
+ text_temp_embed_new = new_model.state_dict()[key]
+ text_temp_embed_new = text_temp_embed_new.unsqueeze(0).unsqueeze(2) # [n, d] -> [1, n, 1, d]
+ text_temp_embed_old = state_dict[key]
+ text_temp_embed_old = text_temp_embed_old.unsqueeze(0).unsqueeze(2)
+
+ state_dict[key] = load_temp_embed_with_mismatch(
+ text_temp_embed_old, text_temp_embed_new, add_zero=False
+ ).squeeze(2).squeeze(0)
+ return state_dict
+
+
+if __name__ == '__main__':
+ import time
+ from fvcore.nn import FlopCountAnalysis
+ from fvcore.nn import flop_count_table
+ import numpy as np
+
+ seed = 4217
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ num_frames = 8
+
+ # model = clip_joint_b16(pretrained=True, kernel_size=1, num_frames=8, num_classes=400, drop_path=0.1)
+ # logger.info(model)
+ model = clip_joint_l14(pretrained=False)
+
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
+ s = time.time()
+ logger.info(flop_count_table(flops, max_depth=1))
+ logger.info(time.time()-s)
+ # logger.info(model(torch.rand(1, 3, num_frames, 224, 224)).shape)
diff --git a/third_party/InternVideo/Data/instruction_data/README.md b/third_party/InternVideo/Data/instruction_data/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9abf8f34b04f6c4b0945398f074b962c770a0bd8
--- /dev/null
+++ b/third_party/InternVideo/Data/instruction_data/README.md
@@ -0,0 +1,41 @@
+# Instruction data for [VideoChat](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat)
+
+# :fire: Updates
+- **2023/05/11**: Release the **V1**: [Google Drive](https://drive.google.com/file/d/1C-7xmf42QUEi4ApXTcxBHr5nLvTWXyUi/view?usp=sharing) | [Aliyun OSS](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/data/videochat/videochat_instruct_11k.json)
+
+# :speech_balloon: V1: 7K detailed descriptions + 4K multi-turn conversations
+
+ We build a video-centric multimodal instruction data based on WebVid-10M. The corresponding detailed descriptions and multi-turn conversations generations are produced by ChatGPT based on video text (aided by [**VideoChat-Text**](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat_with_ChatGPT)) with several prompts concerning **spatiotemporal features**. Compared with detailed video descriptions, video conversations are introduced to further improve data diversity by introducing **temporal and casual features** in the video instruction data.
+
+
+Example of detailed video description
+
+
+
+
+
+
+Example of video conversation
+
+
+
+
+
+# :page_facing_up: Citation
+
+If you find this project useful in your research, please consider cite:
+```BibTeX
+@article{2023videochat,
+ title={VideoChat: Chat-Centric Video Understanding},
+ author={Li, Kunchang and He, Yinan and Wang, Yi and Li, Yizhuo and Wang, Wenhai and Luo, Ping and Wang, Yali and Wang, Limin and Qiao, Yu},
+ journal={arXiv preprint arXiv:2305.06355},
+ year={2023}
+}
+
+@article{wang2022internvideo,
+ title={InternVideo: General Video Foundation Models via Generative and Discriminative Learning},
+ author={Wang, Yi and Li, Kunchang and Li, Yizhuo and He, Yinan and Huang, Bingkun and Zhao, Zhiyu and Zhang, Hongjie and Xu, Jilan and Liu, Yi and Wang, Zun and Xing, Sen and Chen, Guo and Pan, Junting and Yu, Jiashuo and Wang, Yali and Wang, Limin and Qiao, Yu},
+ journal={arXiv preprint arXiv:2212.03191},
+ year={2022}
+}
+```
diff --git a/third_party/InternVideo/Data/instruction_data/assert/conversation.png b/third_party/InternVideo/Data/instruction_data/assert/conversation.png
new file mode 100644
index 0000000000000000000000000000000000000000..64220b0cb0a2338912e3977834d8650cfad9756f
Binary files /dev/null and b/third_party/InternVideo/Data/instruction_data/assert/conversation.png differ
diff --git a/third_party/InternVideo/Data/instruction_data/assert/detailed_description.png b/third_party/InternVideo/Data/instruction_data/assert/detailed_description.png
new file mode 100644
index 0000000000000000000000000000000000000000..6a9307699cc2ccfe805c64cc4361d67aa31e623a
Binary files /dev/null and b/third_party/InternVideo/Data/instruction_data/assert/detailed_description.png differ
diff --git a/third_party/InternVideo/InternVideo2/README.md b/third_party/InternVideo/InternVideo2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ca67229006081b6779042cdce732799a05bc2cb
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/README.md
@@ -0,0 +1,58 @@
+# InternVideo2 \[[Paper\]](https://arxiv.org/abs/2403.15377)
+
+
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/action-classification-on-kinetics-400)](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/action-classification-on-kinetics-600)](https://paperswithcode.com/sota/action-classification-on-kinetics-600?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/action-classification-on-kinetics-700)](https://paperswithcode.com/sota/action-classification-on-kinetics-700?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/action-recognition-in-videos-on-something)](https://paperswithcode.com/sota/action-recognition-in-videos-on-something?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/action-recognition-in-videos-on-activitynet)](https://paperswithcode.com/sota/action-recognition-in-videos-on-activitynet?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/action-classification-on-moments-in-time)](https://paperswithcode.com/sota/action-classification-on-moments-in-time?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/action-recognition-on-hacs)](https://paperswithcode.com/sota/action-recognition-on-hacs?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/zero-shot-video-retrieval-on-msr-vtt)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-msr-vtt?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/zero-shot-video-retrieval-on-msvd)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-msvd?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/zero-shot-video-retrieval-on-lsmdc)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-lsmdc?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/zero-shot-video-retrieval-on-didemo)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-didemo?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/zero-shot-video-retrieval-on-vatex)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-vatex?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/zero-shot-video-retrieval-on-activitynet)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-activitynet?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/video-retrieval-on-msr-vtt)](https://paperswithcode.com/sota/video-retrieval-on-msr-vtt?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/video-retrieval-on-didemo)](https://paperswithcode.com/sota/video-retrieval-on-didemo?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/video-retrieval-on-msvd)](https://paperswithcode.com/sota/video-retrieval-on-msvd?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/video-retrieval-on-lsmdc)](https://paperswithcode.com/sota/video-retrieval-on-lsmdc?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/video-retrieval-on-activitynet)](https://paperswithcode.com/sota/video-retrieval-on-activitynet?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/video-retrieval-on-vatex)](https://paperswithcode.com/sota/video-retrieval-on-vatex?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/text-to-audio-retrieval-on-audiocaps)](https://paperswithcode.com/sota/text-to-audio-retrieval-on-audiocaps?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/text-to-audio-retrieval-on-clotho)](https://paperswithcode.com/sota/text-to-audio-retrieval-on-clotho?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/zero-shot-text-to-audio-retrieval-on)](https://paperswithcode.com/sota/zero-shot-text-to-audio-retrieval-on?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/zero-shot-text-to-audio-retrieval-on-clotho)](https://paperswithcode.com/sota/zero-shot-text-to-audio-retrieval-on-clotho?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/audio-classification-on-esc-50)](https://paperswithcode.com/sota/audio-classification-on-esc-50?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/video-grounding-on-qvhighlights)](https://paperswithcode.com/sota/video-grounding-on-qvhighlights?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/temporal-action-localization-on-fineaction)](https://paperswithcode.com/sota/temporal-action-localization-on-fineaction?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/temporal-action-localization-on-hacs)](https://paperswithcode.com/sota/temporal-action-localization-on-hacs?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/temporal-action-localization-on-thumos14)](https://paperswithcode.com/sota/temporal-action-localization-on-thumos14?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/temporal-action-localization-on-activitynet)](https://paperswithcode.com/sota/temporal-action-localization-on-activitynet?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/zero-shot-video-question-answer-on-egoschema-1)](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-egoschema-1?p=internvideo2-scaling-video-foundation-models)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/internvideo2-scaling-video-foundation-models/video-instance-segmentation-on-youtube-vis-1)](https://paperswithcode.com/sota/video-instance-segmentation-on-youtube-vis-1?p=internvideo2-scaling-video-foundation-models)
+
+This repo will give the code and models of '[InternVideo2: Scaling Video Foundation Models for Multimodal Video Understanding](https://arxiv.org/abs/2403.15377)' soon.
+
+- **Achieved `92.1%` Top1 accuracy in Kinetics 400.**
+- **Achieved `SOTA` performance on over `60` video/audio-related tasks (including action recognition, temporal localization, retrieval, etc) when released.**
+
+## Updates
+- `2024/04/15`: Update the code and scripts for InternVideo2 CLIP.
+- `2024/04/13`: Update the code and scripts for InternVideo2 Stage1 & 2.
+- `2024/03/22`: The technical report of InternVideo2 is released.
+
+## Citation
+
+If this work is helpful for your research, please consider citing InternVideo.
+
+```
+@article{wang2024internvideo2,
+ title={InternVideo2: Scaling Video Foundation Models for Multimodal Video Understanding},
+ author={Wang, Yi and Li, Kunchang and Li, Xinhao and Yu, Jiashuo and He, Yinan and Chen, Guo and Pei, Baoqi and Zheng, Rongkun and Xu, Jilan and Wang, Zun and Shi, Yansong and Jiang, Tianxiang and Li, Songze and Zhang, Hongjie and Huang, Yifei and Qiao, Yu and Wang, Yali and Wang, Limin},
+ journal={arXiv preprint arXiv:2403.15377},
+ year={2024}
+}
+```
diff --git a/third_party/InternVideo/InternVideo2/figs/wechatgrp.png b/third_party/InternVideo/InternVideo2/figs/wechatgrp.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a2ee907e5d75e1432a15023fc513217010c857a
Binary files /dev/null and b/third_party/InternVideo/InternVideo2/figs/wechatgrp.png differ
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/.gitignore b/third_party/InternVideo/InternVideo2/multi_modality/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5a942c785450c074ab6933b0ebda6fdba1fbefcb
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/.gitignore
@@ -0,0 +1,58 @@
+# local #
+tmp*/
+cache/*
+*/cache*/
+tmp*.py
+tmp*
+*pickle
+data/
+mix_ckpt/
+
+# Model ckpts
+*.pth
+*.pt
+
+# Zip Files/Packages #
+*.7z
+*.dmg
+*.gz
+*.iso
+*.jar
+*.rar
+*.tar
+*.zip
+
+# Logs and databases #
+*.sql
+*.sqlite
+.ipynb_checkpoints/
+*.swp
+*.vscode/
+*.idea/
+*.pyc
+__pycache__
+slurm*out
+
+# OS files #
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+
+
+.vim-arsync
+scratch.norg
+sync_to_red.sh
+
+anno/
+wandb/
+logs/
+*.pth
+
+exp
+
+batchscript-*
+*.out
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/DATASET.md b/third_party/InternVideo/InternVideo2/multi_modality/DATASET.md
new file mode 100644
index 0000000000000000000000000000000000000000..57959edfd98f544d523d9863e2e341bc7a7bb614
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/DATASET.md
@@ -0,0 +1,42 @@
+# Dataset Preparation
+
+
+# Stage2——Video-language Alignment
+
+
+## Pretraining
+
+The public portion of the pre-trained dataset we use is as follows:
+- [CC3M images](https://github.com/google-research-datasets/conceptual-captions)
+- [CC12M images](https://github.com/google-research-datasets/conceptual-12m)
+- [SBU images](https://www.cs.rice.edu/~vo9/sbucaptions/)
+- [VG images](https://visualgenome.org/api/v0/api_home.html)
+- [COCO images](https://cocodataset.org/#download)
+- [WebVid videos](https://github.com/m-bain/webvid)
+- [InternVid videos](https://github.com/OpenGVLab/InternVideo/tree/main/Data/InternVid)
+
+## Evaluation
+
+For evaluation, we follow [VINDLU](https://github.com/klauscc/VindLU/) to prepare the datasets, but we **DO NOT** compress the videos and images. We use the original data and load the JSON files. And We use the same **JSON** files provided by [VINDLU](https://drive.google.com/drive/folders/12bC7WotvwyTG4pVvYeU4iZzmBLP1-6d9).
+
+
+### Video-Text Retrieval
+
+- [MSRVTT videos](https://www.robots.ox.ac.uk/~maxbain/frozen-in-time/data/MSRVTT.zip)
+- [MSVD videos](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/)
+- [ActivityNet videos](http://activity-net.org/download.html)
+- [DiDeMo videos](https://github.com/LisaAnne/LocalizingMoments)
+
+
+# Stage3——VideoChat
+
+## Pretraining
+
+- [VideoChat-IT](https://huggingface.co/datasets/OpenGVLab/VideoChat2-IT)
+
+
+## Evaluation
+### MVBench
+
+Please refer to [MVBench](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2)
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/INSTALL.md b/third_party/InternVideo/InternVideo2/multi_modality/INSTALL.md
new file mode 100644
index 0000000000000000000000000000000000000000..a456016344a863d05709ce6cc3f5c8e67c8fa150
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/INSTALL.md
@@ -0,0 +1,20 @@
+# Installation
+
+## Requirements
+
+We mainly follow [UMT](https://github.com/OpenGVLab/Unmasked_Teacher) to prepare the enviroment.
+
+```shell
+pip install -r requirements.txt
+```
+In addition, in order to support the InternVideo2-6B pre-training, you also need to install [Flash Attention](https://github.com/Dao-AILab/flash-attention) and [DeepSpeed](https://github.com/microsoft/DeepSpeed).
+
+
+## Note
+
+To run InternVideo2 pretraining, you have to prepare the weights of the **[InternVL-6B visual encoder](https://huggingface.co/OpenGVLab/InternVL/blob/main/internvl_c_13b_224px.pth)**, and set the `your_model_path` in [internvl_clip_vision.py](./models/backbones/internvideo2/internvl_clip_vision.py).
+
+## Key Dependencies Installation for FlashAttention2
+
+Some modules (FusedMLP and DropoutLayerNorm) from FlashAttention2 used in our models rely on CUDA extensions.
+TBD
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/MODEL_ZOO.md b/third_party/InternVideo/InternVideo2/multi_modality/MODEL_ZOO.md
new file mode 100644
index 0000000000000000000000000000000000000000..084d2c0896bf8dcaddbc411d6a732ac5c2169b21
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/MODEL_ZOO.md
@@ -0,0 +1,76 @@
+# Model Zoo
+
+
+## Pretraining
+For $\text{InternVideo2}_{s2}$, we load those models of $\text{InternVideo2}_{s1}$ and further pretrain them on multi-modality datasets.
+
+For $\text{InternVideo2}_{clip}$, we load those models of $\text{InternVideo2}_{s2}$.
+
+
+| Model | Setting | Model | Pretraining Script |
+| -------- | ----------- | ------ | ------------- |
+| $\text{InternVideo2}_{s2}$-1B | IV-25.5M | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2/blob/main/InternVideo2-stage2_1b-224p-f4.pt) | [script](scripts/pretraining/stage2/1B/run.sh) |
+| $\text{InternVideo2}_{clip}$-1B | IV-25.5M | TBD | [script](scripts/pretraining/clip/1B/run.sh) |
+| $\text{InternVideo2}_{s2}$-6B | IV-400M | TBD | [script](scripts/pretraining/stage2/6B/run.sh) |
+| $\text{InternVideo2}_{clip}$-6B | IV-400M | TBD | [script](scripts/pretraining/clip/6B/run.sh) |
+
+
+### Zero-shot Evaluation
+
+## Zero-Shot Video-Text Retrieval
+
+| Model | Dataset | T2V | V2T | Evaluation Script |
+| -------- | ----------- | ------ | ------- | ------- |
+| $\text{InternVideo2}_{s2}$-1B | MSRVTT | 51.9 | 50.9 | [script](scripts/evaluation/stage2/zero_shot/1B/eval_msrvtt.sh) |
+| | LSMDC | 32.0 | 27.3 | [script](scripts/evaluation/stage2/zero_shot/1B/eval_lsmdc.sh) |
+| | DiDeMo | 57.0 | 54.3 | [script](scripts/evaluation/stage2/zero_shot/1B/eval_didemo.sh) |
+| | MSVD | 58.1 | 83.3 | [script](scripts/evaluation/stage2/zero_shot/1B/eval_msvd.sh) |
+| | ANet | 60.4 | 54.8 | [script](scripts/evaluation/stage2/zero_shot/1B/eval_anet.sh) |
+| | VATEX | 70.4 | 85.4 | [script](scripts/evaluation/stage2/zero_shot/1B/eval_vatex.sh) |
+| $\text{InternVideo2}_{s2}$-6B | MSRVTT | 55.9 | 53.7 | TBD |
+| | LSMDC | 33.8 | 30.1 | TBD |
+| | DiDeMo | 57.9 | 57.1 | TBD |
+| | MSVD | 59.3 | 83.1 | TBD |
+| | ANet | 63.2 | 56.5 | TBD |
+| | VATEX | 71.5 | 85.3 | TBD |
+
+
+| Model | Dataset | T2V | V2T | Evaluation Script |
+| -------- | ----------- | ------ | ------- | ------- |
+| $\text{InternVideo2}_{clip}$-1B | MSRVTT | 50.0 | 48.4 | [script](scripts/evaluation/clip/zero_shot/1B/eval_msrvtt.sh) |
+| | LSMDC | 26.4 | 23.1 | [script](scripts/evaluation/clip/zero_shot/1B/eval_lsmdc.sh) |
+| | DiDeMo | 47.8 | 46.4 | [script](scripts/evaluation/clip/zero_shot/1B/eval_didemo.sh) |
+| | ANet | 49.4 | 46.2 | [script](scripts/evaluation/clip/zero_shot/1B/eval_anet.sh) |
+| | VATEX_en | 63.5 | 81.2 | [script](scripts/evaluation/clip/zero_shot/1B/eval_vatex_en.sh) |
+| | VATEX_ch | 54.9 | 76.4 | [script](scripts/evaluation/clip/zero_shot/1B/eval_vatex_ch.sh) |
+| $\text{InternVideo2}_{clip}$-6B | MSRVTT | 50.9 | 50.6 | [script](scripts/evaluation/clip/zero_shot/6B/eval_msrvtt.sh) |
+| | LSMDC | 29.4 | 26.3 | [script](scripts/evaluation/clip/zero_shot/6B/eval_lsmdc.sh) |
+| | DiDeMo | 50.5 | 46.8| [script](scripts/evaluation/clip/zero_shot/6B/eval_didemo.sh) |
+| | ANet | 50.2 | 47.5 | [script](scripts/evaluation/clip/zero_shot/6B/eval_anet.sh) |
+| | VATEX_en | 64.1 | 82.6 | [script](scripts/evaluation/clip/zero_shot/6B/eval_vatex_en.sh) |
+| | VATEX_ch | 54.6 | 76.9 | [script](scripts/evaluation/clip/zero_shot/6B/eval_vatex_ch.sh) |
+
+
+## Zero-Shot Action Recognition
+
+| Model | Dataset | top-1 | AVG | Script |
+| -------- | ----------- | ------ | ------- | ------- |
+| $\text{InternVideo2}_{clip}$-1B | K400 | 73.1 | 82.4 | [script](scripts/evaluation/clip/zero_shot/1B/eval_k400.sh) |
+| | K600 | 72.8 | 81.8 | [script](scripts/evaluation/clip/zero_shot/1B/eval_k600.sh) |
+| | K700 | 64.9 | 75.2 | [script](scripts/evaluation/clip/zero_shot/1B/eval_k700.sh) |
+| | UCF101 | 88.8 | - | [script](scripts/evaluation/clip/zero_shot/1B/eval_ucf101.sh) |
+| | HMDB51 | 53.9 | - | [script](scripts/evaluation/clip/zero_shot/1B/eval_hmdb51.sh) |
+| | MiT | 31.6 | - | [script](scripts/evaluation/clip/zero_shot/1B/eval_mit.sh) |
+| | SSv2-MC | 61.5 | - | [script](scripts/evaluation/clip/zero_shot/1B/eval_ssv2_mc.sh) |
+| $\text{InternVideo2}_{clip}$-6B | K400 | 72.7 | 82.2 | [script](scripts/evaluation/clip/zero_shot/1B/eval_k400.sh) |
+| | K600 | 71.7 | 81.2 | [script](scripts/evaluation/clip/zero_shot/1B/eval_k600.sh) |
+| | K700 | 64.2 | 75.2 | [script](scripts/evaluation/clip/zero_shot/1B/eval_k700.sh) |
+| | UCF101 | 89.5 | - | [script](scripts/evaluation/clip/zero_shot/1B/eval_ucf101.sh) |
+| | HMDB51 | 56.7 | - | [script](scripts/evaluation/clip/zero_shot/1B/eval_hmdb51.sh) |
+| | MiT | 32.9 | - | [script](scripts/evaluation/clip/zero_shot/1B/eval_mit.sh) |
+| | SSv2-MC | 63.5 | - | [script](scripts/evaluation/clip/zero_shot/1B/eval_ssv2_mc.sh) |
+
+| Model | Dataset | mAP | Script |
+| -------- | ----------- | ------ | ------- |
+| $\text{InternVideo2}_{clip}$-1B | Charades | 32.9 | [script](scripts/evaluation/clip/zero_shot/1B/eval_charades_mc.sh) |
+| $\text{InternVideo2}_{clip}$-6B | Charades | 34.6 | [script](scripts/evaluation/clip/zero_shot/6B/eval_charades_mc.sh) |
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/README.md b/third_party/InternVideo/InternVideo2/multi_modality/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e77a55056a5433f51620ff01320b1f930b98e961
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/README.md
@@ -0,0 +1,61 @@
+# Multi-modality of InternVideo2
+
+## Installation
+
+Please follow the installation instructions in [INSTALL](./INSTALL.md).
+
+>The codebase support using [wandb](https://wandb.ai/) to monitor training. If you want to use wandb, you will need to set up it following [this very short instruction](https://docs.wandb.ai/quickstart#1.-set-up-wandb), and also set `wandb.enable` in the config to be `True`. `wandb.entity` and `wandb.project` should also be set.
+
+## Datasets
+
+You can find the dataset instructions in [DATASET](DATASET.md).
+
+## Model ZOO
+
+You can find all the models and the scripts in [MODEL_ZOO](./MODEL_ZOO.md).
+
+## Demo of Using InternVideo2 in Your Work
+We give a short instructions of accessing and utilizing InternVideo2-stage2 in [demo.ipynb](./demo.ipynb).
+
+## Pre-Training
+
+We use [InternVL](https://github.com/OpenGVLab/InternVL/) pretrained model as the teacher by default
+
+For training, you can simply run the pretraining scripts in `scripts/pretraining` as follows:
+```shell
+bash scripts/pretraining/stage2/1B/run.sh
+```
+
+:warning: **Notes:**
+1. Set `data_dir` and `your_data_path` like `your_webvid_path` in [data.py](./configs/data.py) before running the scripts.
+2. Set `vision_encoder.pretrained` in `vision_encoder.pretrained` in the corresponding config files.
+3. Set `--rdzv_endpoint` to your `MASTER_NODE:MASTER_PORT`. You can also use the following commond to automatically set it:
+ ```shell
+ MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+ ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
+ MASTER_PORT=$((10000 + $RANDOM % 100))
+ torchrun --rdzv_endpoint=${MASTER_NODE}:10068 $@
+ ```
+4. `save_latest=True` will automatically save the latest checkpoint while training.
+5. `auto_resume=True` will automatically loaded the best or latest checkpoint while training.
+
+
+## Zero-shot Evaluation
+
+For zero-shot evaluation, you can simply run the pretraining scripts in `scripts/evaluation` as follows:
+```shell
+bash scripts/evaluation/stage2/zero_shot/1B/eval_msrvtt.sh
+```
+When evaluating, you can choose to turn off deepspeed and the performance will fluctuate slightly from the reported result (around 0.2):
+```shell
+bash scripts/evaluation/stage2/zero_shot/1B/eval_msrvtt_no_deepspeed.sh
+```
+
+:warning: **Notes:**
+1. Set `pretrained_path=your_model_path` in the running scripts before running the scripts.
+2. Set `zero_shot=True` and `evaluate=True` for zero-shot evaluation
+
+## Finetuning
+
+Coming soon.
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/configs/config_bert.json b/third_party/InternVideo/InternVideo2/multi_modality/configs/config_bert.json
new file mode 100644
index 0000000000000000000000000000000000000000..d52d5ea8761fec6d19eaaff086299adf40370add
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/configs/config_bert.json
@@ -0,0 +1,22 @@
+{
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "type_vocab_size": 2,
+ "vocab_size": 30522,
+ "fusion_layer": 9,
+ "encoder_width": 768,
+ "cross_module": "ca"
+}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/configs/config_bert_large.json b/third_party/InternVideo/InternVideo2/multi_modality/configs/config_bert_large.json
new file mode 100644
index 0000000000000000000000000000000000000000..7b4578d22d569a22044366533323fbc144e1bfd6
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/configs/config_bert_large.json
@@ -0,0 +1,25 @@
+{
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 1024,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 24,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522,
+ "fusion_layer": 19,
+ "encoder_width": 768,
+ "cross_module": "ca"
+}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/configs/data.py b/third_party/InternVideo/InternVideo2/multi_modality/configs/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa2eff15888e38e38ba67597922f1f64f18172c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/configs/data.py
@@ -0,0 +1,347 @@
+import os as __os # add "__" if not want to be exported
+from copy import deepcopy as __deepcopy
+
+
+# ============== pretraining datasets=================
+available_corpus = dict(
+ # pretraining image datasets
+ cc3m=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="image"
+ ),
+ cc12m=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="image"
+ ),
+ sbu=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="image"
+ ),
+ vg=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="image",
+ jump_filter=True
+ ),
+ coco=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="image",
+ jump_filter=True
+ ),
+ laion_2b=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="image",
+ jump_filter=True
+ ),
+ laion_coco=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="image",
+ jump_filter=True
+ ),
+ laion_pop=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="image",
+ jump_filter=True
+ ),
+ # pretraining video datasets
+ webvid_fuse_10m=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ jump_filter=True
+ ),
+ internvid_v1=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ jump_filter=True
+ ),
+ internvid_v2_avs_private=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="audio_video",
+ read_clip_from_video=False,
+ read_audio_from_video=True,
+ zero_audio_padding_for_video=True,
+ caption_augmentation=dict(caption_sample_type='avs_all'),
+ jump_filter=True
+ ),
+ webvid=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video"
+ ),
+ webvid_10m=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ ),
+ # audio-text
+ wavcaps_400k=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="audio"
+ ),
+ # debug
+ cc3m_debug=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="image"
+ ),
+ webvid_debug=dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video"
+ )
+)
+
+available_corpus["pretrain_example_data_1B"] = [
+ available_corpus['cc3m'],
+ available_corpus['webvid']
+]
+
+available_corpus["pretrain_example_data_6B"] = [
+ available_corpus['cc3m'],
+ available_corpus['webvid'],
+ available_corpus['internvid_v2_avs_private']
+]
+
+available_corpus["data_25m"] = [
+ available_corpus["webvid_10m"],
+ available_corpus["cc3m"],
+ available_corpus["coco"],
+ available_corpus["vg"],
+ available_corpus["sbu"],
+ available_corpus["cc12m"],
+]
+
+available_corpus["debug"] = [
+ available_corpus["cc3m_debug"],
+ available_corpus["webvid_debug"],
+]
+
+
+# ============== for validation =================
+available_corpus["msrvtt_1k_test"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video"
+)
+
+available_corpus["didemo_ret_test"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_paragraph_retrieval=True,
+ trimmed30=True,
+ max_txt_l=64
+)
+
+available_corpus["anet_ret_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_paragraph_retrieval=True,
+ max_txt_l = 150
+)
+
+available_corpus["lsmdc_ret_test_1000"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video"
+)
+
+available_corpus["vatex_ch_ret_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video"
+)
+
+available_corpus["vatex_en_ret_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video"
+)
+
+available_corpus["k400_act_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ is_act_rec=True,
+)
+
+available_corpus["k600_act_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_act_rec=True,
+)
+
+available_corpus["k700_act_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_act_rec=True,
+)
+
+available_corpus["mit_act_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_act_rec=True,
+)
+
+available_corpus["ucf101_act_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_act_rec=True,
+)
+
+available_corpus["hmdb51_act_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_act_rec=True,
+)
+
+available_corpus["ssv2_mc_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+)
+
+available_corpus["charades_mc_test"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+)
+
+
+available_corpus["anet_ret_train"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_paragraph_retrieval=True,
+ max_txt_l = 150
+)
+
+available_corpus["didemo_ret_train"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_paragraph_retrieval=True,
+ trimmed30=True,
+ max_txt_l=64
+)
+
+available_corpus["didemo_ret_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ is_paragraph_retrieval=True,
+ trimmed30=True,
+ max_txt_l=64
+)
+
+available_corpus["lsmdc_ret_train"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ max_txt_l=96
+)
+
+available_corpus["lsmdc_ret_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ max_txt_l=96
+)
+
+available_corpus["msrvtt_ret_train9k"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+)
+
+available_corpus["msrvtt_ret_test1k"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+)
+
+available_corpus["msvd_ret_train"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ max_txt_l=64,
+ has_multi_txt_gt=True
+)
+
+available_corpus["msvd_ret_val"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ max_txt_l=64
+)
+
+available_corpus["msvd_ret_test"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ max_txt_l=64
+)
+
+
+available_corpus["vatex_en_ret_train"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="video",
+ has_multi_txt_gt=True
+)
+
+
+# audio-text
+
+available_corpus["audiocaps_ret_train"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="audio",
+)
+
+available_corpus["audiocaps_ret_test"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="audio",
+)
+
+
+available_corpus["clothov1_ret_train"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="audio",
+)
+
+available_corpus["clothov1_ret_test"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="audio",
+)
+
+available_corpus["clothov2_ret_train"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="audio",
+)
+
+available_corpus["clothov2_ret_test"] = dict(
+ anno_path="your_path",
+ data_root="",
+ media_type="audio",
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/configs/med_config.json b/third_party/InternVideo/InternVideo2/multi_modality/configs/med_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..33b4a892cbe5cb8664fa369d7dcee752ff1843cf
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/configs/med_config.json
@@ -0,0 +1,22 @@
+{
+ "architectures": [
+ "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "add_type_embeddings": false,
+ "vocab_size": 30522,
+ "encoder_width": 768,
+ "add_cross_attention": true,
+ "cross_freq": 0
+}
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/configs/med_config_fusion.json b/third_party/InternVideo/InternVideo2/multi_modality/configs/med_config_fusion.json
new file mode 100644
index 0000000000000000000000000000000000000000..d8f464126c4b99be97fba06fce6783c714bd9964
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/configs/med_config_fusion.json
@@ -0,0 +1,23 @@
+{
+ "architectures": [
+ "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "add_type_embeddings": false,
+ "vocab_size": 30522,
+ "encoder_width": 768,
+ "fusion_layer": 9,
+ "add_cross_attention": true,
+ "cross_freq": 0
+}
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/configs/med_large_config.json b/third_party/InternVideo/InternVideo2/multi_modality/configs/med_large_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..137815767f509dbc4726af5940bb3cb9ac21005f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/configs/med_large_config.json
@@ -0,0 +1,22 @@
+{
+ "architectures": [
+ "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "add_type_embeddings": false,
+ "vocab_size": 30522,
+ "encoder_width": 1024,
+ "add_cross_attention": true,
+ "cross_freq": 0
+}
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/configs/model.py b/third_party/InternVideo/InternVideo2/multi_modality/configs/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f355549fc7d45761b7e4a354a55bba2e9b3f0af9
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/configs/model.py
@@ -0,0 +1,31 @@
+VisionEncoders = dict()
+
+
+TextEncoders = dict()
+TextEncoders["bert"] = dict(
+ name="bert_base",
+ pretrained="bert-base-uncased",
+ config="configs/config_bert.json",
+ d_model=768,
+ fusion_layer=9,
+)
+TextEncoders["bert_large"] = dict(
+ name="bert_large",
+ pretrained="bert-large-uncased",
+ config="configs/config_bert_large.json",
+ d_model=1024,
+ fusion_layer=19,
+)
+TextEncoders["med_bert"] = dict(
+ name="med_bert_base",
+ pretrained="bert-base-uncased",
+ config="configs/med_config.json",
+ d_model=768,
+)
+
+TextEncoders["med_bert_large"] = dict(
+ name="med_bert_large",
+ pretrained="bert-base-uncased", # not a bug, it just follows BLIP.
+ config="configs/med_large_config.json",
+ d_model=768
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/__init__.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b66a920dd0731a6d10a558cdba496ba847aa418e
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/__init__.py
@@ -0,0 +1,465 @@
+import logging
+import torch
+from torch.utils.data import ConcatDataset, DataLoader
+from dataset.resample_concat_dataset import ResampleConcatDataset
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+import copy
+
+from dataset.sampler import StatefulDistributedSampler
+from dataset.dataloader import MetaLoader, MetaLoader_rs # NOTE keep it
+from dataset.ret_dataset import (ImgTxtRetTrainDataset,
+ VidTxtRetTrainDataset,
+ ImgTxtRetEvalDataset,
+ VidTxtRetEvalDataset,
+ VidTxtRetMCEvalDataset,
+ VidTxtRetMCNewEvalDataset)
+# from dataset.ret_dataset import (ImgTxtRetTrainDataset,
+# VidTxtRetTrainDataset,
+# ImgTxtRetEvalDataset,
+# VidTxtRetEvalDataset,
+# AudioVidTxtRetTrainDataset,
+# AudioVidTxtRetEvalDataset,
+# VidTxtRetMCEvalDataset,
+# AudioTxtRetTrainDataset,
+# AudioTxtRetEvalDataset)
+
+from dataset.qa_dataset import ImageQADataset, VideoQADataset
+from dataset.pt_dataset import (ImgTxtPtTrainDataset,
+ VidTxtPtTrainDataset,)
+# from dataset.pt_dataset import (ImgTxtPtTrainDataset,
+# VidTxtPtTrainDataset,
+# AudioVidTxtPtTrainDataset,
+# AudioTxtPtTrainDataset)
+
+logger = logging.getLogger(__name__)
+
+def get_media_type(dataset_config):
+ return dataset_config['media_type']
+
+def get_dataset_cls(dataset_type, media_type, data_cfg):
+ if dataset_type == "pt_train":
+ if media_type == "image":
+ dataset_cls = ImgTxtPtTrainDataset
+ elif media_type == "video":
+ dataset_cls = VidTxtPtTrainDataset
+ # elif media_type == "audio_video":
+ # dataset_cls = AudioVidTxtPtTrainDataset
+ # elif media_type == "audio":
+ # dataset_cls = AudioTxtPtTrainDataset
+ else:
+ raise NotImplementedError(f"dataset_type={dataset_type}, media_type={media_type}")
+ elif dataset_type == "ret_train":
+ if media_type == "image":
+ dataset_cls = ImgTxtRetTrainDataset
+ elif media_type == "video":
+ dataset_cls = VidTxtRetTrainDataset
+ # elif media_type == 'audio':
+ # dataset_cls = AudioTxtRetTrainDataset
+ # elif media_type == "audio_video":
+ # dataset_cls = AudioVidTxtRetTrainDataset
+ else:
+ raise NotImplementedError(f"dataset_type={dataset_type}, media_type={media_type}")
+ elif dataset_type == "ret_eval":
+ if media_type == "image":
+ dataset_cls = ImgTxtRetEvalDataset
+ elif media_type == "video":
+ dataset_cls = VidTxtRetEvalDataset
+ # elif media_type == "audio":
+ # dataset_cls = AudioTxtRetEvalDataset
+ # elif media_type == "audio_video":
+ # dataset_cls = AudioVidTxtRetEvalDataset
+ else:
+ raise NotImplementedError(f"dataset_type={dataset_type}, media_type={media_type}")
+ elif dataset_type in ["qa_train", 'qa_eval']:
+ if media_type == "image":
+ dataset_cls = ImageQADataset
+ elif media_type == "video":
+ dataset_cls = VideoQADataset
+ else:
+ raise NotImplementedError(f"dataset_type={dataset_type}, media_type={media_type}")
+ else:
+ raise NotImplementedError(f"dataset_type={dataset_type}, media_type={media_type}")
+
+ print(f"\033[31m dataset_type: {dataset_type} media_type: {media_type} dataset_cls: {dataset_cls}\033[0m")
+ logger.info(f"dataset_type: {dataset_type} media_type: {media_type} dataset_cls: {dataset_cls}")
+
+ return dataset_cls
+
+def get_train_transform(config, train_file):
+ vision_enc_name = config.model.vision_encoder.name
+ if "internvideo" in vision_enc_name or "vit" in vision_enc_name or "umt" in vision_enc_name:
+ mean = (0.485, 0.456, 0.406)
+ std = (0.229, 0.224, 0.225)
+ elif "clip" in vision_enc_name:
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ std = (0.26862954, 0.26130258, 0.27577711)
+ else:
+ raise NotImplementedError(vision_enc_name)
+
+ normalize = transforms.Normalize(mean, std)
+
+ # loaded images and videos are torch.Tensor of torch.uint8 format,
+ # ordered as (T, 1 or 3, H, W) where T=1 for image
+ type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
+
+ if config.inputs.video_input.random_aug:
+ aug_transform = transforms.RandAugment()
+ else:
+ aug_transform = transforms.Lambda(lambda x: x)
+
+ train_transform = transforms.Compose(
+ [
+ aug_transform,
+ transforms.RandomResizedCrop(
+ config.inputs.image_res,
+ scale=(0.5, 1.0),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.RandomHorizontalFlip(),
+ type_transform,
+ normalize,
+ ]
+ )
+
+ return train_transform
+
+def get_test_transform(config, test_file):
+ vision_enc_name = config.model.vision_encoder.name
+ if "internvideo" in vision_enc_name or "vit" in vision_enc_name or "umt" in vision_enc_name:
+ mean = (0.485, 0.456, 0.406)
+ std = (0.229, 0.224, 0.225)
+ elif "clip" in vision_enc_name:
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ std = (0.26862954, 0.26130258, 0.27577711)
+ else:
+ raise NotImplementedError(vision_enc_name)
+
+ normalize = transforms.Normalize(mean, std)
+
+ # loaded images and videos are torch.Tensor of torch.uint8 format,
+ # ordered as (T, 1 or 3, H, W) where T=1 for image
+ type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
+
+ test_transform = transforms.Compose(
+ [
+ transforms.Resize(
+ (config.inputs.image_res, config.inputs.image_res),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ type_transform,
+ normalize,
+ ]
+ )
+ return test_transform
+
+
+def create_dataset(dataset_type, config):
+ ##########################################################
+ # Shared setting for all datasets
+ if config.inputs.get('video_input', None) is not None:
+ video_reader_type = config.inputs.video_input.get("video_reader_type", "decord")
+ video_only_dataset_kwargs_train = dict(
+ video_reader_type=video_reader_type,
+ sample_type=config.inputs.video_input.sample_type,
+ num_frames=config.inputs.video_input.num_frames,
+ num_tries=10, # false tolerance
+ )
+ video_only_dataset_kwargs_eval = dict(
+ video_reader_type=video_reader_type,
+ sample_type=config.inputs.video_input.sample_type_test,
+ num_frames=config.inputs.video_input.num_frames_test,
+ num_tries=1, # we want to have predictions for all videos
+ )
+ else:
+ logger.warn("Make sure that you don't need video input!!!")
+ if config.inputs.get('audio_input', None) is not None:
+ audio_reader_type = config.inputs.audio_input.get("audio_reader_type", "torchaudio")
+ audio_only_dataset_kwargs_train = dict(
+ audio_reader_type=audio_reader_type,
+ audio_sample_rate=config.inputs.audio_input.get('audio_sample_rate', 16000),
+ max_audio_length=config.inputs.audio_input.get('max_audio_length', 10),
+ num_tries=10,
+ )
+ audio_only_dataset_kwargs_eval = dict(
+ audio_reader_type=audio_reader_type,
+ audio_sample_rate=config.inputs.audio_input.get('audio_sample_rate_test', 16000),
+ max_audio_length=config.inputs.audio_input.get('max_audio_length', 10),
+ num_tries=1,
+ )
+ else:
+ logger.warn("Make sure that you don't need audio input!!!")
+
+
+ if dataset_type == "pt_train":
+ # convert to list of lists
+ train_files = (
+ [config.train_file] if isinstance(config.train_file, dict) else config.train_file
+ )
+ train_media_types = sorted(list({get_media_type(e) for e in train_files}))
+
+ train_datasets = []
+ for m in train_media_types:
+
+ # dataset of the same media_type will be mixed in a single Dataset object
+ _train_files = [e for e in train_files if get_media_type(e) == m]
+
+ datasets = []
+ sample_weights = []
+ for train_file in _train_files:
+ dataset_cls = get_dataset_cls(dataset_type=dataset_type, media_type=m, data_cfg=train_file)
+ if m == "audio":
+ train_transform = None
+ else:
+ train_transform = get_train_transform(config, train_file)
+ dataset_kwargs = dict(
+ ann_file=train_file,
+ transform=train_transform,
+ num_epochs=config.scheduler.epochs)
+
+ if m == "audio_video":
+ dataset_kwargs.update(video_only_dataset_kwargs_train)
+ dataset_kwargs.update(audio_only_dataset_kwargs_train)
+ elif m == "video":
+ dataset_kwargs.update(video_only_dataset_kwargs_train)
+ elif m == 'audio':
+ dataset_kwargs.update(audio_only_dataset_kwargs_train)
+ elif m != 'image':
+ raise NotImplementedError(m)
+ logger.info(f"dataset_type={dataset_type}, train_file={train_file}")
+ logger.info(dataset_kwargs)
+ logger.info('train_transform:')
+ logger.info(str(train_transform))
+
+ datasets.append(dataset_cls(**dataset_kwargs))
+ sample_weights.append(train_file.get("sample_weight", 1))
+ # assert train_file.get("sample_weight", 1) == 1, train_file
+
+ if sum(sample_weights) > len(sample_weights):
+ logger.info(f'Use ResampleConcatDataset for {m}, sample_weights={sample_weights}')
+ dataset = ResampleConcatDataset(datasets, sample_weights=sample_weights)
+ else:
+ logger.info(f'Use ConcatDataset for {m}')
+ dataset = ConcatDataset(datasets)
+
+ train_datasets.append(dataset)
+
+ return train_datasets
+
+ elif dataset_type == "ret_train":
+ assert isinstance(config.train_file, dict), config.train_file
+ train_transform = get_train_transform(config, config.train_file)
+ dataset_cls = get_dataset_cls(dataset_type=dataset_type,
+ media_type=config.train_file.media_type,
+ data_cfg=config.train_file)
+ if config.train_file.media_type == "video":
+ dataset_kwargs = dict(
+ ann_file=config.train_file,
+ transform=train_transform)
+ dataset_kwargs.update(video_only_dataset_kwargs_train)
+ elif config.train_file.media_type == 'audio':
+ dataset_kwargs = dict(
+ ann_file=config.train_file,
+ transform=None)
+ dataset_kwargs.update(audio_only_dataset_kwargs_train)
+ elif config.train_file.media_type == 'audio_video':
+ dataset_kwargs = dict(
+ ann_file=config.train_file,
+ transform=train_transform)
+ dataset_kwargs.update(video_only_dataset_kwargs_train)
+ dataset_kwargs.update(audio_only_dataset_kwargs_train)
+ else:
+ raise NotImplementedError(config.train_file.media_type)
+
+ logger.info(f"dataset_type={dataset_type}, train_file={config.train_file}")
+ logger.info(dataset_kwargs)
+ logger.info('train_transform:')
+ logger.info(str(train_transform))
+
+ return [dataset_cls(**dataset_kwargs)]
+
+ elif dataset_type == "qa_train":
+ assert type(config.train_file) is dict, f"assuming single train media type but get {config.train_file}"
+
+ media_type = get_media_type(config.train_file[0]) # assuming single train media type
+ if media_type == "audio":
+ train_transform = None
+ else:
+ train_transform = get_train_transform(config, config.train_file)
+
+ dataset_cls = get_dataset_cls(dataset_type=dataset_type,
+ media_type=media_type,
+ data_cfg=config.train_file)
+ dataset_kwargs = dict(
+ ann_file=config.train_file, transform=train_transform, eos=config.eos, mode="train"
+ )
+ if media_type == "video":
+ dataset_kwargs.update(video_only_dataset_kwargs_train)
+ train_dataset = dataset_cls(**dataset_kwargs)
+
+ logger.info(f"dataset_type={dataset_type}, train_file={config.train_file}")
+ logger.info(dataset_kwargs)
+ logger.info('train_transform:')
+ logger.info(str(train_transform))
+
+ return train_dataset
+
+ elif dataset_type in ["pt_eval", "ret_eval", "qa_eval"]:
+ test_datasets = []
+ test_dataset_names = []
+ # multiple test datasets, all separate
+ for name, data_cfg in config.test_file.items():
+ media_type = get_media_type(data_cfg)
+ test_dataset_names.append(name)
+ test_transform = get_test_transform(config, data_cfg)
+
+ if dataset_type == "qa_eval" or (dataset_type == "pt_eval" and "_qa_" in name):
+ test_dataset_cls = get_dataset_cls(dataset_type='qa_eval',
+ media_type=media_type,
+ data_cfg=data_cfg)
+ dataset_kwargs = dict(
+ ann_file=data_cfg,
+ transform=test_transform,
+ eos=config.eos,
+ mode="eval",
+ answer_list=config.answer_list,
+ )
+ if media_type == "video":
+ dataset_kwargs.update(video_only_dataset_kwargs_eval)
+ else:
+ raise NotImplementedError(f"media_type={media_type}")
+ else: # ret
+ test_dataset_cls = get_dataset_cls(dataset_type='ret_eval',
+ media_type=media_type,
+ data_cfg=data_cfg)
+ if media_type == "video":
+ dataset_kwargs = dict(
+ ann_file=data_cfg,
+ transform=test_transform
+ )
+ if "hmdb" in name or 'frame' in name: # read image for video
+ dataset_kwargs["video_reader_type"] = 'img'
+ dataset_kwargs.update(video_only_dataset_kwargs_eval)
+ elif media_type == "audio":
+ dataset_kwargs = dict(
+ ann_file=data_cfg,
+ transform=None)
+ dataset_kwargs.update(audio_only_dataset_kwargs_eval)
+ elif media_type == 'audio_video':
+ dataset_kwargs = dict(
+ ann_file=data_cfg,
+ transform=test_transform)
+ dataset_kwargs.update(video_only_dataset_kwargs_eval)
+ dataset_kwargs.update(audio_only_dataset_kwargs_eval)
+ elif media_type != 'image':
+ raise NotImplementedError(f"media_type={media_type}")
+
+ logger.info(f"dataset_type={dataset_type}, test_file={data_cfg}")
+ logger.info(dataset_kwargs)
+ logger.info('test_transform:')
+ logger.info(str(test_transform))
+
+ test_datasets.append(test_dataset_cls(**dataset_kwargs))
+ return test_datasets, test_dataset_names
+
+ elif dataset_type == "mc_test":
+ test_transform = get_test_transform(config, config.test_file.mc_test)
+ dataset_kwargs = dict(ann_file=[config.test_file.mc_test], transform=test_transform)
+ dataset_kwargs.update(video_only_dataset_kwargs_eval)
+
+ logger.info(f"dataset_type={dataset_type}, test_file={config.test_file}")
+ logger.info(dataset_kwargs)
+ logger.info('test_transform:')
+ logger.info(str(test_transform))
+
+ return VidTxtRetMCEvalDataset(**dataset_kwargs)
+
+ elif dataset_type == "mc_new_test":
+ test_transform = get_test_transform(config, config.test_file.mc_test)
+ dataset_kwargs = dict(ann_file=[config.test_file.mc_test], transform=test_transform)
+ dataset_kwargs.update(video_only_dataset_kwargs_eval)
+
+ logger.info(f"dataset_type={dataset_type}, test_file={config.test_file}")
+ logger.info(dataset_kwargs)
+ logger.info('test_transform:')
+ logger.info(str(test_transform))
+
+ return VidTxtRetMCNewEvalDataset(**dataset_kwargs)
+
+ else:
+ raise NotImplementedError(f"dataset_type={dataset_type}")
+
+def vqa_collate_fn(batch):
+ image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
+ for image, question, answer, weights in batch:
+ image_list.append(image)
+ question_list.append(question)
+ weight_list += weights
+ answer_list += answer
+ n.append(len(answer))
+ return (
+ torch.stack(image_list, dim=0),
+ question_list,
+ answer_list,
+ torch.Tensor(weight_list),
+ n,
+ )
+
+
+def create_sampler(datasets, shuffles, num_tasks, global_rank):
+ samplers = []
+ for dataset, shuffle in zip(datasets, shuffles):
+ sampler = torch.utils.data.DistributedSampler(
+ dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
+ )
+ samplers.append(sampler)
+ return samplers
+
+def create_stateful_sampler(datasets, batch_size):
+ samplers = []
+ for dataset, bs in zip(datasets, batch_size):
+ sampler = StatefulDistributedSampler(dataset, batch_size=bs)
+ samplers.append(sampler)
+ return samplers
+
+def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
+ loaders = []
+ for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
+ datasets, samplers, batch_size, num_workers, is_trains, collate_fns
+ ):
+ if is_train:
+ shuffle = sampler is None
+ drop_last = True
+ pin_memory = True
+ persistent_workers = True if n_worker > 0 else False
+ else:
+ shuffle = False
+ drop_last = False
+ pin_memory = False
+ persistent_workers = False
+ loader = DataLoader(
+ dataset,
+ batch_size=bs,
+ num_workers=n_worker,
+ pin_memory=pin_memory,
+ sampler=sampler,
+ shuffle=shuffle,
+ collate_fn=collate_fn,
+ drop_last=drop_last,
+ persistent_workers=persistent_workers,
+ )
+ loaders.append(loader)
+ return loaders
+
+
+def iterate_dataloaders(dataloaders):
+ """Alternatively generate data from multiple dataloaders,
+ since we use `zip` to concat multiple dataloaders,
+ the loop will end when the smaller dataloader runs out.
+
+ Args:
+ dataloaders List(DataLoader): can be a single or multiple dataloaders
+ """
+ for data_tuples in zip(*dataloaders):
+ for idx, data in enumerate(data_tuples):
+ yield dataloaders[idx].dataset.media_type, data
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/av_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/av_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f65f03f53e3de4576df7ca347c242e4f9b05d4ee
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/av_utils.py
@@ -0,0 +1,184 @@
+import av
+import gc
+import torch
+import torchaudio
+import numpy as np
+import random
+import logging
+import io
+from torchvision.transforms.functional import pil_to_tensor
+
+logger = logging.getLogger(__name__)
+
+
+
+def get_index(num_frames, num_segments):
+ seg_size = float(num_frames - 1) / num_segments
+ start = int(seg_size / 2)
+ offsets = np.array([
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+ ])
+ return offsets
+
+
+def lazy_load_s3video(s3path_video, num_frames, video_start_frame, video_end_frame, client):
+ # load video from ceph
+ assert client is not None
+ video_bytes_stream = client.get(s3path_video, enable_stream_lazyloding=True)
+ container = av.open(video_bytes_stream)
+ stream = container.streams.video[0]
+ # duration = stream.duration
+ real_fps = container.streams.video[0].average_rate
+ time_base = container.streams.video[0].time_base
+ start, end = video_start_frame, video_end_frame
+ # Convert time to pts
+ duration_frams = end - start + 1
+ frames_index = get_index(duration_frams, num_frames)
+
+ pts_list = []
+
+ start_pts = int((start/real_fps) / time_base)
+ end_pts = int((end/real_fps) / time_base)
+ for frame_index in frames_index:
+ pts_list.append(int((frame_index / real_fps)) / time_base)
+
+ # Seek to nearest key frame from the start
+ container.seek(max(start_pts, 0), stream=stream)
+
+ frames = []
+ for frame in container.decode(**{"video":0}):
+ if frame.pts < start_pts:
+ continue
+ # if frame.pts <= end_pts:
+ if len(pts_list) >0:
+ if frame.pts >= pts_list[0]:
+ frames.append(frame)
+ pts_list.pop(0)
+ else:
+ break
+ frames = [pil_to_tensor(frames[idx].to_rgb().to_image()).unsqueeze(0) for idx in range(len(frames))]
+ container.close()
+ del video_bytes_stream # T C H W
+
+ return torch.cat(frames, dim=0) # , start, end, float(real_fps)
+
+
+def load_audio_av(video_path, video_start_frame, video_end_frame, sr, max_audio_length, client): # sr should be 16000
+ assert client is not None
+ video_bytes_stream = client.get(video_path, enable_stream_lazyloding=True)
+ try:
+ container = av.open(video_bytes_stream)
+ except:
+ logger.warn(f"Something wrong when av.open (video_path: {video_path})!")
+ return None
+ if len(container.streams.audio) == 0:
+ logger.warn(f"There is no audio! (video_path: {video_path})!")
+ return None
+ audio_stream = container.streams.audio[0]
+ real_fps = container.streams.video[0].average_rate
+ time_base = audio_stream.time_base
+ csr = audio_stream.sample_rate
+ start_frame, end_frame = video_start_frame, video_end_frame
+ start_pts = int((start_frame/real_fps) / time_base)
+ end_pts = int((end_frame/real_fps) / time_base)
+ frames = []
+ container.seek(max(start_pts, 0), stream=audio_stream)
+ try:
+ for frame in container.decode(**{"audio":0}):
+ if frame.pts < start_pts:
+ continue
+ frames.append(frame.to_ndarray())
+ if frame.pts > end_pts:
+ break
+ except:
+ gc.collect()
+ pass
+ # gc.collect()
+ container.close()
+ del video_bytes_stream
+
+ audio_raw = np.concatenate(frames, 1)
+ audio = torch.from_numpy(audio_raw)
+ if audio.size(0) == 2:
+ audio = torch.mean(audio, dim=0, keepdim=True)
+ if len(audio.shape) == 1:
+ audio = audio.unsqueeze(0)
+ assert max_audio_length == 10, max_audio_length
+ max_length = max_audio_length * sr
+ if csr != sr:
+ trans = torchaudio.transforms.Resample(csr, sr)
+ audio = trans(audio)
+ if audio.shape[1] >= max_length:
+ max_start = audio.shape[1] - max_length
+ start = random.randint(0, max_start)
+ audio = audio[:, start: start + max_length]
+ audio = audio * 2 ** 15
+ fbank = torchaudio.compliance.kaldi.fbank(audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10)
+ fbank_mean = 15.41663
+ fbank_std = 6.55582
+ fbank = (fbank - fbank_mean) / (fbank_std * 2) # 998, 64
+
+ src_length = fbank.shape[0]
+ pad_len = 998 - src_length
+ fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
+ padding_mask = torch.cat((torch.zeros(1, src_length), torch.ones(1, pad_len)), -1).bool()
+
+ return fbank#, padding_mask
+
+def load_full_audio_av(video_path, sr, max_audio_length, client):
+ assert client is not None
+ video_bytes_stream = client.get(video_path) #, enable_stream_lazyloding=False)
+ try:
+ container = av.open(io.BytesIO(video_bytes_stream))
+ except Exception as e:
+ logger.warn(f"Something wrong {e} when av.open (video_path: {video_path})!")
+ return None
+ if len(container.streams.audio) == 0:
+ logger.warn(f"There is no audio! (video_path: {video_path})!")
+ return None
+ audio_stream = container.streams.audio[0]
+ csr = audio_stream.sample_rate
+ frames = []
+ try:
+ for frame in container.decode(**{"audio":0}):
+ frames.append(frame.to_ndarray())
+ except:
+ gc.collect()
+ pass
+ # gc.collect()
+ container.close()
+ del video_bytes_stream
+
+ audio_raw = np.concatenate(frames, 1)
+ audio = torch.from_numpy(audio_raw)
+ if audio.size(0) == 2:
+ audio = torch.mean(audio, dim=0, keepdim=True)
+ if len(audio.shape)==1:
+ audio = audio.unsqueeze(0)
+ assert max_audio_length == 10, max_audio_length
+ max_length = max_audio_length * sr
+ if csr != sr:
+ trans = torchaudio.transforms.Resample(csr, sr)
+ audio = trans(audio)
+ if audio.shape[1] >= max_length:
+ max_start = audio.shape[1] - max_length
+ start = random.randint(0, max_start)
+ audio = audio[:, start: start + max_length]
+ audio = audio * 2 ** 15
+ fbank = torchaudio.compliance.kaldi.fbank(audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10)
+ fbank_mean = 15.41663
+ fbank_std = 6.55582
+ fbank = (fbank - fbank_mean) / (fbank_std * 2) # 998, 64
+
+ src_length = fbank.shape[0]
+ pad_len = 998 - src_length
+ fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
+ padding_mask = torch.cat((torch.zeros(1, src_length), torch.ones(1, pad_len)), -1).bool()
+
+ return fbank #, padding_mask
+
+
+ # frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8
+ # # frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
+
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/base_dataset.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..31877874e8e0bd6f1fa3cfa8ba3ac8c8d956552e
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/base_dataset.py
@@ -0,0 +1,111 @@
+import logging
+import os
+import random
+try:
+ from petrel_client.client import Client
+except:
+ Client = None
+from torch.utils.data import Dataset
+from .utils import load_image_from_path
+from .av_utils import lazy_load_s3video
+
+logger = logging.getLogger(__name__)
+
+
+class BaseDataset(Dataset):
+ """Base class that implements the image and video loading methods"""
+
+ media_type = "video"
+
+ def __init__(self):
+ assert self.media_type in ["audio", "image", "video", "audio_video"]
+ self.data_root = None
+ self.data_root_prefix = ""
+ self.anno_list = (
+ None # list(dict), each dict contains {"image": str, # image or video path}
+ )
+ self.transform = None
+ self.audio_reader_type = None
+ self.audio_sample_rate = None
+ self.max_audio_length = None
+ self.video_reader = None
+ self.num_tries = None
+ self.client = Client('~/petreloss.conf') if Client is not None else None
+ self.trimmed30 = False
+
+ def __getitem__(self, index):
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+ def get_anno(self, index): # NOTE used for most ret_dataset
+ """obtain the annotation for one media (video or image)
+
+ Args:
+ index (int): The media index.
+
+ Returns: dict.
+ - "image": the filename, video also use "image".
+ - "caption": The caption for this file.
+
+ """
+ anno = self.anno_list[index]
+ if self.data_root is not None:
+ if self.media_type == "audio":
+ anno["audio"] = self.data_root_prefix + os.path.join(self.data_root, anno["audio"])
+ else:
+ anno["image"] = self.data_root_prefix + os.path.join(self.data_root, anno["image"])
+ return anno
+
+ def load_and_transform_media_data(self, index, data_path):
+ try:
+ if self.media_type == "image":
+ return self.load_and_transform_media_data_image(index, data_path)
+ elif self.media_type == "audio":
+ return self.load_and_transform_media_data_audio(index, data_path)
+ elif self.media_type == "video":
+ return self.load_and_transform_media_data_video(index, data_path)
+ elif self.media_type == "audio_video":
+ return self.load_and_transform_media_data_audio_video(index, data_path)
+ else:
+ raise NotImplementedError(self.media_type)
+ except Exception as e:
+ logger.info(f"Something wrong when read {data_path}")
+ raise e
+
+ def load_and_transform_media_data_image(self, index, data_path):
+ if type(data_path) is dict:
+ image = load_image_from_path(data_path["image"], client=self.client)
+ if "crop_bbox" in data_path.keys():
+ bbox = data_path["crop_bbox"]
+ x0, y0, x1, y1 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+ image = image[:, :, y0:y1, x0:x1]
+
+ image = self.transform(image)
+ else:
+
+ image = load_image_from_path(data_path, client=self.client)
+ image = self.transform(image)
+ return image, index
+
+ def load_and_transform_media_data_video(self, index, data_path):
+ if type(data_path) is dict:
+ if data_path['read_clip_from_video']:
+ if self.trimmed30:
+ raise NotImplementedError("lazy_load_s3video does not support trimmed30")
+ frames = lazy_load_s3video(data_path['video'], self.num_frames, data_path['video_start_frame'], data_path['video_end_frame'], self.client)
+ else:
+ raise NotImplementedError(data_path)
+ else:
+ max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1
+ frames, frame_indices, video_duration = self.video_reader(
+ data_path, self.num_frames, self.sample_type,
+ max_num_frames=max_num_frames, client=self.client,
+ trimmed30=self.trimmed30
+ )
+
+ # NOTE shared aug for video frames
+ frames = self.transform(frames)
+ return frames, index
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/dataloader.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f958d0fdfbf1a39ce082497f2700ef6789c8c45
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/dataloader.py
@@ -0,0 +1,121 @@
+import torch
+import torch.distributed as dist
+from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process
+import random
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class MetaLoader(object):
+ """ wraps multiple data loader """
+ def __init__(self, name2loader):
+ """Iterates over multiple dataloaders, it ensures all processes
+ work on data from the same dataloader. This loader will end when
+ the shorter dataloader raises StopIteration exception.
+
+ loaders: Dict, {name: dataloader}
+ """
+ self.name2loader = name2loader
+ self.name2iter = {name: iter(l) for name, l in name2loader.items()}
+ name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
+ index2name = {v: k for k, v in name2index.items()}
+
+ iter_order = []
+ for n, l in name2loader.items():
+ iter_order.extend([name2index[n]]*len(l))
+
+ random.shuffle(iter_order)
+ iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
+
+ # sync
+ if is_dist_avail_and_initialized():
+ # make sure all processes have the same order so that
+ # each step they will have data from the same loader
+ dist.broadcast(iter_order, src=0)
+ self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
+
+ logger.info(str(self))
+
+ def __str__(self):
+ output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
+ for idx, (name, loader) in enumerate(self.name2loader.items()):
+ output.append(
+ f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
+ )
+ return "\n".join(output)
+
+ def __len__(self):
+ return len(self.iter_order)
+
+ def __iter__(self):
+ """ this iterator will run indefinitely """
+ for name in self.iter_order:
+ _iter = self.name2iter[name]
+ batch = next(_iter)
+ yield name, batch
+
+
+class MetaLoader_rs(object):
+ """ wraps multiple data loader """
+ def __init__(self, name2loader, skip_num=0):
+ """Iterates over multiple dataloaders, it ensures all processes
+ work on data from the same dataloader. This loader will end when
+ the shorter dataloader raises StopIteration exception.
+
+ loaders: Dict, {name: dataloader}
+ """
+ self.name2loader = name2loader
+ name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
+ index2name = {v: k for k, v in name2index.items()}
+
+ iter_order = []
+ for n, l in name2loader.items():
+ iter_order.extend([name2index[n]]*len(l))
+
+ random.shuffle(iter_order)
+ iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
+
+ # sync
+ if is_dist_avail_and_initialized():
+ # make sure all processes have the same order so that
+ # each step they will have data from the same loader
+ dist.broadcast(iter_order, src=0)
+
+ if skip_num > 0:
+ iter_order_skip = iter_order[:skip_num]
+ for k, v in index2name.items():
+ media_step = (iter_order_skip == k).sum().item()
+ name2loader[v].sampler.set_start_iter(media_step)
+ logger.info(f"{v} dataloder skip steps: {media_step}")
+ iter_order = iter_order[skip_num:]
+ self.name2loader = name2loader
+ else:
+ logger.info("Do not skip steps for any dataloader!")
+ for k, v in index2name.items():
+ name2loader[v].sampler.set_start_iter(0)
+
+ self.name2iter = {name: iter(l) for name, l in name2loader.items()}
+ self.iter_idx = iter_order
+ self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
+
+ logger.info(str(self))
+
+ def __str__(self):
+ output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
+ for idx, (name, loader) in enumerate(self.name2loader.items()):
+ length = (self.iter_idx == idx).sum()
+ output.append(
+ f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={length} "
+ )
+ return "\n".join(output)
+
+ def __len__(self):
+ return len(self.iter_order)
+
+ def __iter__(self):
+ """ this iterator will run indefinitely """
+ for name in self.iter_order:
+ _iter = self.name2iter[name]
+ batch = next(_iter)
+ yield name, batch
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/pt_dataset.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/pt_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd1f148fccdc50c0dbfc9ea093db9105a0bdf7c7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/pt_dataset.py
@@ -0,0 +1,489 @@
+import logging
+import os
+import json
+import random
+import io
+import torch
+import numpy as np
+
+from dataset.base_dataset import BaseDataset
+from dataset.text_prompt import kinetics_templates, imagenet_templates
+from dataset.utils import pre_text
+from dataset.video_utils import VIDEO_READER_FUNCS
+from dataset.serialize import get_local_rank, TorchShmSerializedList
+
+logger = logging.getLogger(__name__)
+
+
+class ImgTxtPtTrainDataset(BaseDataset):
+ media_type = "image"
+
+ def __init__(self, ann_file, transform, num_epochs=1):
+ super().__init__()
+
+ logger.info(f"ann_file: {ann_file}")
+
+ self.media_type = ann_file.media_type
+ self.label_file = ann_file.anno_path
+ self.data_root = ann_file.data_root
+ self.data_root_prefix = ann_file.get("data_root_prefix", "")
+ self.min_caption_length = ann_file.get("min_caption_length", 2)
+ self.caption_augmentation = ann_file.get("caption_augmentation", None)
+ self.transform = transform
+ # each caption has multiple image as ground_truth, e.g., ssv2
+ self.has_multi_vision_gt = ann_file.get("has_multi_vision_gt", False)
+ assert not self.has_multi_vision_gt
+
+ self.crop_img = ann_file.get("crop_img", False)
+
+ self.use_prompt = ann_file.get("prompt", "") != ""
+ if self.use_prompt:
+ if ann_file.prompt == "imagenet":
+ self.prompt = imagenet_templates
+ logger.info(f"Use prompt for ImageNet")
+ elif ann_file.prompt == "kinetics":
+ self.prompt = kinetics_templates
+ logger.info(f"Use prompt for Kinetics")
+ else:
+ raise NotImplementedError(ann_file.prompt)
+ logger.info(self.prompt)
+
+
+ if self.use_prompt and self.caption_augmentation is not None:
+ raise NotImplementedError("You can't use prompt because of multiple captions!")
+
+
+ if '.json' in self.label_file:
+ logger.info(f"Loading json file {self.label_file}")
+
+ if get_local_rank() == 0: # Only one rank need to read the file
+ with io.BytesIO(self.client.get(self.label_file)) as f:
+ # with open(self.label_file, 'r') as f:
+ annos = json.load(f)
+
+ if ann_file.get("jump_filter", False):
+ logger.info("Jump filter!")
+ else:
+ if self.caption_augmentation is not None:
+ # filter out the caption with length less than min_caption_length
+ new_annos = []
+ if self.media_type == "audio_video" and self.caption_augmentation.caption_sample_type == 'avs_all':
+ for anno in annos:
+ ok = True
+ if not anno['video'].endswith('.mp4'):
+ ok = False
+ for k in anno.keys():
+ if "caption" in k and 'asr' not in k:
+ tmp_c = pre_text(anno[k])
+ if len(tmp_c.split()) < self.min_caption_length:
+ ok = False
+ break
+ if ok:
+ new_annos.append(anno)
+ elif self.caption_augmentation.caption_sample_type == 'uniform':
+ for anno in annos:
+ if "captions" in anno.keys():
+ caption_key = "captions"
+ else:
+ caption_key = "caption"
+
+ assert type(anno[caption_key]) is list, type(anno[caption_key])
+ caption_list = []
+ for c in anno[caption_key]:
+ tmp_c = pre_text(c)
+ if len(tmp_c.split()) >= self.min_caption_length:
+ caption_list.append(tmp_c)
+
+ if len(caption_list) > 0:
+ new_annos.append(anno)
+ else:
+ raise NotImplementedError(ann_file)
+
+ logger.info(f"Num samples: {len(annos)}")
+ logger.info(f"Num samples not too short: {len(new_annos)} min_caption_length={self.min_caption_length}")
+ annos = new_annos
+ else:
+ # filter out the caption with length less than min_caption_length
+ captions = [pre_text(anno["caption"]) for anno in annos]
+ captions_len = [len(caption.split()) for caption in captions]
+ logger.info("Num samples: {}".format(len(captions)))
+ logger.info("Num samples too short: {}".format(sum([l < self.min_caption_length for l in captions_len])))
+ annos = [anno for anno, l in zip(annos, captions_len) if l >= self.min_caption_length]
+ if num_epochs < 1:
+ raise NotImplementedError
+ else:
+ annos = []
+
+ self.anno = TorchShmSerializedList(annos)
+ self.num_examples = len(self.anno)
+ logger.info(f"num_examples: {self.num_examples}")
+
+ else:
+ raise NotImplementedError("We need json file!!!")
+
+ def __len__(self):
+ return self.num_examples
+
+ def get_caption(self, index):
+ if '.json' in self.label_file:
+ if self.caption_augmentation is not None:
+ if self.caption_augmentation.caption_sample_type == 'avs_all':
+ caption_dict = {}
+ for k in self.anno[index].keys():
+ if 'caption' in k:
+ caption_dict[k] = self.anno[index][k]
+ else:
+ if "captions" in self.anno[index].keys():
+ captions = self.anno[index]["captions"]
+ else:
+ captions = self.anno[index]["caption"]
+ else:
+ caption = self.anno[index]["caption"]
+ else:
+ raise NotImplementedError
+
+ if self.caption_augmentation is not None:
+ if self.caption_augmentation.caption_sample_type == 'uniform':
+ caption = random.choice(captions)
+ elif self.caption_augmentation.caption_sample_type == 'avs_all':
+ caption = caption_dict
+ else:
+ raise NotImplementedError
+ return caption
+
+ def get_anno(self, index):
+ assert self.media_type == 'image', self.media_type
+ anno = {"caption": self.get_caption(index)}
+ anno["image"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["image"])
+ if self.use_prompt:
+ anno["caption"] = random.choice(self.prompt).format(anno["caption"])
+ if self.crop_img:
+ anno["crop_bbox"] = self.anno[index]["crop_bbox"]
+ return anno
+
+ def pre_caption(self, caption):
+ if type(caption) is str:
+ return pre_text(caption)
+ elif type(caption) is dict:
+ assert self.caption_augmentation.caption_sample_type == 'avs_all'
+ caption_dict = {}
+ for k in caption.keys():
+ caption_dict[k] = pre_text(caption[k])
+ return caption_dict
+ else:
+ raise NotImplementedError(caption)
+
+ def __getitem__(self, index):
+ try:
+ ann = self.get_anno(index)
+ caption = self.pre_caption(ann["caption"])
+ # key = ann["caption"] if self.has_multi_vision_gt else basename(ann["image"])
+ if self.crop_img:
+ data_path = {"image":ann["image"], "crop_bbox":ann["crop_bbox"]}
+ image, index = self.load_and_transform_media_data(index, data_path)
+ else:
+ image, index = self.load_and_transform_media_data(index, ann["image"])
+ return image, caption, index
+ except Exception as e:
+ logger.warning(f"Caught exception {e} when loading image {ann}")
+ # raise e
+ print(e)
+ index = np.random.randint(0, len(self))
+ return self.__getitem__(index)
+
+
+class VidTxtPtTrainDataset(ImgTxtPtTrainDataset):
+ media_type = "video"
+
+ def __init__(
+ self,
+ ann_file,
+ transform,
+ num_frames=4,
+ video_reader_type="decord",
+ sample_type="rand",
+ num_tries=3,
+ num_epochs=1
+ ):
+ super().__init__(ann_file, transform, num_epochs)
+ self.num_frames = num_frames
+ self.video_reader_type = video_reader_type
+ self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
+ self.sample_type = sample_type
+ self.num_tries = num_tries
+
+ self.is_paragraph_retrieval = ann_file.get("is_paragraph_retrieval", False)
+ self.read_clip_from_video = ann_file.get("read_clip_from_video", False)
+
+ if self.is_paragraph_retrieval:
+ raise NotImplementedError
+
+ def get_anno(self, index):
+ assert self.media_type == "video", self.media_type
+ anno = {"caption": self.get_caption(index)}
+ anno["video"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["video"])
+ if self.read_clip_from_video:
+ anno["video_start_frame"] = self.anno[index]["video_start_frame"]
+ anno["video_end_frame"] = self.anno[index]["video_end_frame"]
+ if self.use_prompt:
+ anno["caption"] = random.choice(self.prompt).format(anno["caption"])
+
+ return anno
+
+ def __getitem__(self, index):
+ try:
+ ann = self.get_anno(index)
+ caption = self.pre_caption(ann["caption"])
+
+ if self.read_clip_from_video:
+ data_path = {
+ "video": ann["video"],
+ "video_start_frame": ann["video_start_frame"],
+ "video_end_frame": ann["video_end_frame"],
+ "read_clip_from_video": True
+ }
+ else:
+ data_path = ann["video"]
+
+ video, index = self.load_and_transform_media_data(index, data_path)
+
+ return video, caption, index
+
+ except Exception as e:
+ logger.warning(f"Caught exception {e} when loading video {ann}")
+ # raise e
+ print(e)
+ index = np.random.randint(0, len(self))
+ return self.__getitem__(index)
+
+
+class AudioVidTxtPtTrainDataset(VidTxtPtTrainDataset):
+ media_type = "audio_video"
+
+ def __init__(
+ self,
+ ann_file,
+ transform,
+ audio_sample_rate=16000,
+ audio_reader_type='torchaudio',
+ max_audio_length=10,
+ num_frames=4,
+ video_reader_type="decord",
+ sample_type="rand",
+ num_tries=3,
+ num_epochs=1
+ ):
+ super().__init__(ann_file, transform, num_epochs=num_epochs, num_frames=num_frames, video_reader_type=video_reader_type, sample_type=sample_type, num_tries=num_tries)
+
+ assert self.media_type == 'audio_video', self.media_type
+ self.audio_sample_rate = audio_sample_rate
+ self.audio_reader_type = audio_reader_type
+ self.max_audio_length = max_audio_length
+
+ self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False)
+ self.read_audio_from_video = ann_file.get("read_audio_from_video", False)
+ self.zero_audio_padding_for_video = ann_file.get("zero_audio_padding_for_video", False)
+
+ self.now_tries = 0
+
+ def get_anno(self, index):
+ anno = {"caption": self.get_caption(index)}
+ anno["video"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["video"])
+ if self.read_clip_from_video:
+ anno["video_start_frame"] = self.anno[index]["video_start_frame"]
+ anno["video_end_frame"] = self.anno[index]["video_end_frame"]
+
+ if "audio" in self.anno[index].keys():
+ anno["audio"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["audio"])
+
+ if self.use_prompt:
+ anno["caption"] = random.choice(self.prompt).format(anno["caption"])
+
+ return anno
+
+ def __getitem__(self, index):
+ try:
+ ann = self.get_anno(index)
+ caption = self.pre_caption(ann["caption"])
+ data_path = {'video': ann["video"]}
+
+ if self.read_clip_from_video:
+ data_path["video_start_frame"] = ann["video_start_frame"]
+ data_path["video_end_frame"] = ann["video_end_frame"]
+
+ if "audio" in ann.keys():
+ data_path["read_audio_from_video"] = False
+ data_path["audio"] = ann["audio"]
+ else:
+ data_path["read_audio_from_video"] = self.read_audio_from_video
+
+ data_path["read_clip_from_video"] = self.read_clip_from_video
+
+ media, index = self.load_and_transform_media_data(index, data_path)
+ self.now_tries = 0
+
+ audio = media[0]
+ if audio is None and self.zero_audio_padding_for_video:
+ logger.warning(f"No audio in {data_path}")
+ media[0] = torch.zeros((998, 64), dtype=torch.float32)
+
+ return media, caption, index
+
+ except Exception as e:
+ # print(e)
+ if self.num_tries < self.now_tries:
+ raise e
+ else:
+ self.now_tries += 1
+ logger.warning(f"Caught exception {e} when loading audio-video {ann}")
+ # logger.warning(f"Caught exception when loading audio-video {ann['video']}")
+
+ index = np.random.randint(0, len(self))
+ return self.__getitem__(index)
+
+
+class AudioTxtPtTrainDataset(BaseDataset):
+ media_type = "audio"
+
+ def __init__(self, ann_file, transform,
+ audio_sample_rate=16000,
+ audio_reader_type='torchaudio',
+ max_audio_length=10,
+ num_tries=3,
+ num_epochs=1):
+ super().__init__()
+
+ logger.info(f"ann_file: {ann_file}")
+
+ self.media_type = ann_file.media_type
+ self.label_file = ann_file.anno_path
+ self.data_root = ann_file.data_root
+ self.data_root_prefix = ann_file.get("data_root_prefix", "")
+ self.min_caption_length = ann_file.get("min_caption_length", 2)
+ self.caption_augmentation = ann_file.get("caption_augmentation", None)
+ self.transform = transform
+ self.audio_sample_rate = audio_sample_rate
+ self.max_audio_length = max_audio_length
+ self.audio_reader_type = audio_reader_type
+ self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False)
+ assert not self.has_multi_audio_gt
+
+ self.use_prompt = ann_file.get("prompt", "") != ""
+ if self.use_prompt:
+ if ann_file.prompt == "imagenet":
+ self.prompt = imagenet_templates
+ logger.info(f"Use prompt for ImageNet")
+ elif ann_file.prompt == "kinetics":
+ self.prompt = kinetics_templates
+ logger.info(f"Use prompt for Kinetics")
+ else:
+ raise NotImplementedError(ann_file.prompt)
+ logger.info(self.prompt)
+
+
+ if self.use_prompt and self.caption_augmentation is not None:
+ raise NotImplementedError("You can't use prompt because of multiple captions!")
+
+ if '.json' in self.label_file:
+ logger.info(f"Loading json file {self.label_file}")
+
+ if get_local_rank() == 0: # Only one rank need to read the file
+ with io.BytesIO(self.client.get(self.label_file)) as f:
+ # with open(self.label_file, 'r') as f:
+ annos = json.load(f)
+
+ if ann_file.get("jump_filter", False):
+ logger.info("Jump filter!")
+ else:
+ if self.caption_augmentation is not None:
+ # filter out the caption with length less than min_caption_length
+ new_annos = []
+ if self.caption_augmentation.caption_sample_type == 'uniform':
+ for anno in annos:
+ if "captions" in anno.keys():
+ caption_key = "captions"
+ else:
+ caption_key = "caption"
+
+ assert type(anno[caption_key]) is list, type(anno[caption_key])
+ caption_list = []
+ for c in anno[caption_key]:
+ tmp_c = pre_text(c)
+ if len(tmp_c.split()) >= self.min_caption_length:
+ caption_list.append(tmp_c)
+
+ if len(caption_list) > 0:
+ new_annos.append(anno)
+ else:
+ raise NotImplementedError(ann_file)
+
+ logger.info(f"Num samples: {len(annos)}")
+ logger.info(f"Num samples not too short: {len(new_annos)} min_caption_length={self.min_caption_length}")
+ annos = new_annos
+ else:
+ # filter out the caption with length less than min_caption_length
+ captions = [pre_text(anno["caption"]) for anno in annos]
+ captions_len = [len(caption.split()) for caption in captions]
+ logger.info("Num samples: {}".format(len(captions)))
+ logger.info("Num samples too short: {}".format(sum([l < self.min_caption_length for l in captions_len])))
+ annos = [anno for anno, l in zip(annos, captions_len) if l >= self.min_caption_length]
+ if num_epochs < 1:
+ raise NotImplementedError
+ else:
+ annos = []
+
+ self.anno = TorchShmSerializedList(annos)
+ self.num_examples = len(self.anno)
+ logger.info(f"num_examples: {self.num_examples}")
+
+ else:
+ raise NotImplementedError("We need json file!!!")
+
+ def __len__(self):
+ return self.num_examples
+
+ def get_caption(self, index):
+ if '.json' in self.label_file:
+ if self.caption_augmentation is not None:
+ if "captions" in self.anno[index].keys():
+ captions = self.anno[index]["captions"]
+ else:
+ captions = self.anno[index]["caption"]
+ else:
+ caption = self.anno[index]["caption"]
+ else:
+ raise NotImplementedError
+
+ if self.caption_augmentation is not None:
+ if self.caption_augmentation.caption_sample_type == 'uniform':
+ caption = random.choice(captions)
+ else:
+ raise NotImplementedError
+ return caption
+
+ def get_anno(self, index):
+ assert self.media_type == 'audio', self.media_type
+ anno = {"caption": self.get_caption(index)}
+ anno["audio"] = self.data_root_prefix + os.path.join(self.data_root, self.anno[index]["audio"])
+ if self.use_prompt:
+ anno["caption"] = random.choice(self.prompt).format(anno["caption"])
+
+ return anno
+
+ def pre_caption(self, caption):
+ if type(caption) is str:
+ return pre_text(caption)
+ else:
+ raise NotImplementedError(caption)
+
+ def __getitem__(self, index):
+ try:
+ ann = self.get_anno(index)
+ caption = self.pre_caption(ann["caption"])
+ audio, index = self.load_and_transform_media_data(index, ann["audio"])
+ return audio, caption, index
+ except Exception as e:
+ logger.warning(f"Caught exception {e} when loading audio {ann}")
+ print(e)
+ index = np.random.randint(0, len(self))
+ return self.__getitem__(index)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/qa_dataset.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/qa_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b62fb5ceb10a05a50cd7a9d001df8fd2ec129b6
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/qa_dataset.py
@@ -0,0 +1,70 @@
+import json
+from dataset.base_dataset import BaseDataset
+from dataset.utils import pre_text, load_anno
+from dataset.video_utils import VIDEO_READER_FUNCS
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class ImageQADataset(BaseDataset):
+ media_type = "image"
+
+ def __init__(self, ann_file, transform, eos="[SEP]", mode="train", answer_list=None):
+ super(ImageQADataset, self).__init__()
+ assert mode in ["train", "eval"]
+ self.mode = mode
+ self.transform = transform
+ self.eos = eos
+
+ self.anno_list = load_anno(ann_file)
+
+ if mode == "eval":
+ self.answer_list = json.load(open(answer_list, "r"))
+
+
+ def __len__(self):
+ return len(self.anno_list)
+
+ def get_answers_with_weights(self, raw_answers):
+ if isinstance(raw_answers, str):
+ raw_answers = [raw_answers]
+ answer_weight = {}
+ for answer in raw_answers:
+ if answer in answer_weight.keys():
+ answer_weight[answer] += 1/len(raw_answers)
+ else:
+ answer_weight[answer] = 1/len(raw_answers)
+
+ answers = list(answer_weight.keys())
+ weights = [answer_weight[a] for a in answers]
+ answers = [answer + " " + self.eos for answer in answers]
+ return answers, weights
+
+ def __getitem__(self, index):
+ ann = self.anno_list[index]
+ image, index = self.load_and_transform_media_data(index, ann["media"])
+
+ question = pre_text(ann["question"])
+ if self.mode == "train":
+ answers, weights = self.get_answers_with_weights(ann["answer"])
+ return image, question, answers, weights
+ else: # self.mode == "eval":
+ question_id = ann["question_id"]
+ return image, question, question_id
+
+
+class VideoQADataset(ImageQADataset):
+ media_type = "video"
+
+ def __init__(
+ self, ann_file, transform, eos="[SEP]", mode="train", answer_list=None,
+ num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=1
+ ):
+ super(VideoQADataset, self).__init__(
+ ann_file, transform, eos, mode, answer_list)
+ self.num_frames = num_frames
+ self.video_reader_type = video_reader_type
+ self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
+ self.sample_type = sample_type
+ self.num_tries = num_tries
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/resample_concat_dataset.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/resample_concat_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..608bd32a317045167ecdf36c4a6c14fc22c4647d
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/resample_concat_dataset.py
@@ -0,0 +1,74 @@
+import bisect
+import warnings
+import logging
+from typing import (
+ Iterable,
+ List,
+ TypeVar,
+)
+
+logger = logging.getLogger(__name__)
+
+T_co = TypeVar('T_co', covariant=True)
+T = TypeVar('T')
+
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ResampleConcatDataset(Dataset[T_co]):
+ r"""Dataset as a concatenation of multiple datasets.
+
+ This class is useful to assemble different existing datasets.
+
+ Args:
+ datasets (sequence): List of datasets to be concatenated
+ """
+ datasets: List[Dataset[T_co]]
+ cumulative_sizes: List[int]
+
+ @staticmethod
+ def cumsum_with_sample_weight(sequence, sample_weights):
+ r, s = [], 0
+ for i, e in enumerate(sequence):
+ l = int(len(e) * sample_weights[i]) # NOTE
+ r.append(l + s)
+ s += l
+ return r
+
+ def __init__(self, datasets: Iterable[Dataset], sample_weights: List[int]) -> None:
+ super(ResampleConcatDataset, self).__init__()
+
+ self.datasets = list(datasets)
+ self.sample_weights = sample_weights
+ assert len(self.datasets) == len(self.sample_weights), f"{len(self.datasets)} != {len(self.sample_weights)}"
+ logging.info(f"datasets: {self.datasets} sample weight: {self.sample_weights}")
+ for i in range(len(self.sample_weights)):
+ assert self.sample_weights[i] >= 1
+
+ assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type]
+ for d in self.datasets:
+ assert not isinstance(d, IterableDataset), "ResampleConcatDataset does not support IterableDataset"
+ self.cumulative_sizes = self.cumsum_with_sample_weight(self.datasets, self.sample_weights)
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError("absolute value of index should not exceed dataset length")
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ sample_idx = sample_idx // self.sample_weights[dataset_idx] # NOTE
+ return self.datasets[dataset_idx][sample_idx]
+
+ @property
+ def cummulative_sizes(self):
+ warnings.warn("cummulative_sizes attribute is renamed to "
+ "cumulative_sizes", DeprecationWarning, stacklevel=2)
+ return self.cumulative_sizes
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/ret_dataset.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/ret_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..22bf4cffa497adbdf4d748b7b71f38cd6a9fd905
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/ret_dataset.py
@@ -0,0 +1,490 @@
+from os.path import basename
+import numpy as np
+import logging
+import torch
+
+from dataset.base_dataset import BaseDataset
+from dataset.utils import load_anno, pre_text
+from dataset.video_utils import VIDEO_READER_FUNCS
+from dataset.text_prompt import kinetics_templates_action_clip as kinetics_templates
+
+logger = logging.getLogger(__name__)
+
+
+class AudioTxtRetTrainDataset(BaseDataset):
+ media_type = "audio"
+
+ def __init__(
+ self, ann_file, transform, audio_sample_rate,
+ audio_reader_type='librosa', max_audio_length=0, num_tries=3):
+ super(AudioTxtRetTrainDataset, self).__init__()
+ self.anno_list = load_anno(ann_file)
+ self.transform = transform
+ self.audio_reader_type = audio_reader_type
+ self.num_tries = num_tries
+ self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False)
+ self.trimmed30 = ann_file.get("trimmed30", False)
+
+ self.max_audio_length = max_audio_length
+ self.audio_sample_rate = audio_sample_rate
+ self.match_ids = {}
+
+ n = 0
+ for ann in self.anno_list:
+ key = ann["caption"] if self.has_multi_audio_gt else basename(ann["image"])
+ if key not in self.match_ids:
+ self.match_ids[key] = n
+ n += 1
+
+ def __len__(self):
+ return len(self.anno_list)
+
+ def __getitem__(self, index):
+ try:
+ ann = self.anno_list[index]
+ audio, index = self.load_and_transform_media_data(index, ann['image'])
+ caption = pre_text(ann["caption"])
+ key = ann["caption"] if self.has_multi_audio_gt else basename(ann["image"])
+ return audio, caption, self.match_ids[key]
+ except Exception as e:
+ logger.error(e)
+ print(e, flush=True)
+ index = np.random.randint(0, len(self))
+ return self.__getitem__(index)
+
+
+class AudioTxtRetEvalDataset(BaseDataset):
+ media_type = "audio"
+
+ def __init__(
+ self, ann_file, transform, audio_sample_rate,
+ audio_reader_type='librosa', max_audio_length=0, num_tries=3):
+ super(AudioTxtRetEvalDataset, self).__init__()
+ self.anno_list = load_anno(ann_file)
+ self.transform = transform
+ self.audio_sample_rate = audio_sample_rate
+ self.max_audio_length = max_audio_length
+ self.audio_reader_type = audio_reader_type
+ self.num_tries = num_tries
+ self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False)
+ self.trimmed30 = ann_file.get("trimmed30", False)
+ self.max_txt_l = ann_file.get("max_txt_l", 32)
+
+ self.text = None
+ self.audio = None
+ self.txt2img = None
+ self.img2txt = None
+
+ self.build_data()
+
+ def build_data(self):
+ self.text = []
+ self.audio = []
+ self.txt2img = {}
+ self.img2txt = {}
+ if self.has_multi_audio_gt:
+ self.build_data_multi_audio_gt()
+ else:
+ self.build_data_multi_txt_gt()
+
+ def build_data_multi_audio_gt(self):
+ """each text may have multiple ground_truth audio, e.g., ssv2"""
+ audio_id = 0
+ for txt_id, ann in enumerate(self.anno_list):
+ self.text.append(pre_text(ann["caption"]))
+ self.txt2img[txt_id] = []
+ _audios = ann["image"] \
+ if isinstance(ann["image"], list) else [ann["image"], ]
+ for i, audio in enumerate(_audios):
+ self.audio.append(audio)
+ self.txt2img[txt_id].append(audio_id)
+ self.img2txt[audio_id] = txt_id
+ audio_id += 1
+
+ def build_data_multi_txt_gt(self):
+ """each audio may have multiple ground_truth text, e.g., COCO and Flickr30K"""
+ txt_id = 0
+ for audio_id, ann in enumerate(self.anno_list):
+ self.audio.append(ann["image"])
+ self.img2txt[audio_id] = []
+ _captions = ann["caption"] \
+ if isinstance(ann["caption"], list) else [ann["caption"], ]
+ for i, caption in enumerate(_captions):
+ self.text.append(pre_text(caption))
+ self.img2txt[audio_id].append(txt_id)
+ self.txt2img[txt_id] = audio_id
+ txt_id += 1
+
+ def __len__(self):
+ return len(self.anno_list)
+
+ def __getitem__(self, index):
+ ann = self.anno_list[index]
+ audio, index = self.load_and_transform_media_data(index, ann["image"])
+ return audio, index
+
+
+
+class ImgTxtRetTrainDataset(BaseDataset):
+ media_type = "image"
+
+ def __init__(self, ann_file, transform):
+ super(ImgTxtRetTrainDataset, self).__init__()
+ self.anno_list = load_anno(ann_file)
+ self.transform = transform
+ # each caption has multiple image as ground_truth, e.g., ssv2
+ self.has_multi_txt_gt = ann_file.get("has_multi_txt_gt", False)
+ self.has_multi_vision_gt = ann_file.get("has_multi_vision_gt", False)
+
+ if self.has_multi_txt_gt:
+ logger.info("The dataset has multiple ground truth for a image/video!")
+ tmp_anno_list = []
+ for ann in self.anno_list:
+ img_path = ann["image"]
+ for caption in ann["caption"]:
+ tmp_anno_list.append({
+ "image": img_path,
+ "caption": caption
+ })
+ self.anno_list = tmp_anno_list
+
+ self.match_ids = {}
+ n = 0
+ for ann in self.anno_list:
+ key = ann["caption"] if self.has_multi_vision_gt else basename(ann["image"])
+ if key not in self.match_ids:
+ self.match_ids[key] = n
+ n += 1
+
+ def __len__(self):
+ return len(self.anno_list)
+
+ def __getitem__(self, index):
+ try:
+ ann = self.anno_list[index]
+ image, index = self.load_and_transform_media_data(index, ann["image"])
+ caption = pre_text(ann["caption"])
+ key = ann["caption"] if self.has_multi_vision_gt else basename(ann["image"])
+ return image, caption, self.match_ids[key]
+ except Exception as e:
+ logger.error(e)
+ print(e, flush=True)
+ index = np.random.randint(0, len(self))
+ return self.__getitem__(index)
+
+
+class VidTxtRetTrainDataset(ImgTxtRetTrainDataset):
+ media_type = "video"
+
+ def __init__(
+ self, ann_file, transform, num_frames=4,
+ video_reader_type="decord", sample_type="rand", num_tries=3):
+ super(VidTxtRetTrainDataset, self).__init__(ann_file, transform)
+ self.num_frames = num_frames
+ self.video_reader_type = video_reader_type
+ self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
+ self.sample_type = sample_type
+ self.num_tries = num_tries
+ self.read_clip_from_video = ann_file.get("read_clip_from_video", False)
+ if self.read_clip_from_video:
+ raise NotImplementedError("key for match_ids is not implemented!")
+ self.is_paragraph_retrieval = ann_file.get("is_paragraph_retrieval", False)
+ if self.is_paragraph_retrieval:
+ self.anno_list = preprocess_para_retrieval_data(self.anno_list)
+ self.trimmed30 = ann_file.get("trimmed30", False)
+ if self.trimmed30:
+ logger.info("Trimming the video, only use the first 30s!")
+
+
+class AudioVidTxtRetTrainDataset(VidTxtRetTrainDataset):
+ media_type = "audio_video"
+
+ def __init__(
+ self, ann_file, transform,
+ audio_sample_rate=16000,
+ audio_reader_type='torchaudio',
+ max_audio_length=10,
+ num_frames=4,
+ video_reader_type="decord", sample_type="rand", num_tries=3):
+ super(AudioVidTxtRetTrainDataset, self).__init__(ann_file, transform,
+ num_frames=num_frames, video_reader_type=video_reader_type, sample_type=sample_type, num_tries=num_tries)
+
+ assert self.media_type == 'audio_video', self.media_type
+ self.audio_sample_rate = audio_sample_rate
+ self.audio_reader_type = audio_reader_type
+ self.max_audio_length = max_audio_length
+
+ self.has_multi_audio_gt = ann_file.get("has_multi_audio_gt", False)
+ self.read_audio_from_video = ann_file.get("read_audio_from_video", False)
+ self.zero_audio_padding_for_video = ann_file.get("zero_audio_padding_for_video", False)
+
+ def __getitem__(self, index):
+ try:
+ ann = self.anno_list[index]
+ caption = pre_text(ann["caption"])
+
+ data_path = {'video': ann["image"]}
+ data_path["read_clip_from_video"] = self.read_clip_from_video
+ if "audio" in ann.keys():
+ data_path["read_audio_from_video"] = False
+ data_path["audio"] = ann["audio"]
+ else:
+ data_path["read_audio_from_video"] = self.read_audio_from_video
+
+ media, index = self.load_and_transform_media_data(index, data_path)
+
+ audio = media[0]
+ if audio is None and self.zero_audio_padding_for_video:
+ logger.warning(f"No audio in {data_path}")
+ media[0] = torch.zeros((998, 64), dtype=torch.float32)
+
+ key = ann["caption"] if self.has_multi_vision_gt else basename(ann["image"])
+ return media, caption, self.match_ids[key]
+
+ except Exception as e:
+ logger.error(e)
+ print(e, flush=True)
+ index = np.random.randint(0, len(self))
+ return self.__getitem__(index)
+
+
+class ImgTxtRetEvalDataset(BaseDataset):
+ media_type = "image"
+
+ def __init__(self, ann_file, transform):
+ super(ImgTxtRetEvalDataset, self).__init__()
+ self.raw_anno_list = load_anno(ann_file)
+
+ self.transform = transform
+ self.has_multi_vision_gt = ann_file.get("has_multi_vision_gt", False) # each caption has multiple image as ground_truth
+
+ self.is_act_rec = ann_file.get("is_act_rec", False)
+ self.max_txt_l = ann_file.get("max_txt_l", 32) # NOTE
+
+ self.text = None
+ self.image = None
+ self.txt2img = None
+ self.img2txt = None
+ self.build_data()
+
+ def build_data(self):
+ self.text = []
+ self.image = []
+ self.txt2img = {}
+ self.img2txt = {}
+ if self.is_act_rec:
+ self.build_data_act_rec()
+ elif self.has_multi_vision_gt:
+ self.build_data_multi_img_gt()
+ else:
+ self.build_data_multi_txt_gt()
+ self.anno_list = [dict(image=e) for e in self.image]
+
+ def build_data_act_rec(self):
+ """action recognition task, e.g., kinetics400"""
+ text = list(set([e["caption"] for e in self.raw_anno_list]))
+ text2label = {e: i for i, e in enumerate(text)}
+ text = [[t.format(e) for t in kinetics_templates] for e in text]
+ text = [e for l in text for e in l]
+ self.text = [pre_text(e) for e in text]
+ self.num_prompts = len(kinetics_templates)
+ self.img2txt = {i: text2label[e["caption"]] for i, e in enumerate(self.raw_anno_list)}
+ self.txt2img = [[] for _ in range(len(text) // len(kinetics_templates))]
+ for i, e in enumerate(self.raw_anno_list):
+ self.image.append(e["image"])
+ self.txt2img[text2label[e["caption"]]].append(i)
+ logger.info(f"Action recognition, number of prompts: {self.num_prompts}")
+ logger.info(f"Action recognition, number of classes: {len(self.text)}")
+
+ def build_data_multi_img_gt(self):
+ """each text may have multiple ground_truth image, e.g., ssv2"""
+ img_id = 0
+ for txt_id, ann in enumerate(self.raw_anno_list):
+ self.text.append(pre_text(ann["caption"]))
+ self.txt2img[txt_id] = []
+ _images = ann["image"] \
+ if isinstance(ann["image"], list) else [ann["image"], ]
+ for i, image in enumerate(_images):
+ self.image.append(image)
+ self.txt2img[txt_id].append(img_id)
+ self.img2txt[img_id] = txt_id
+ img_id += 1
+
+ def build_data_multi_txt_gt(self):
+ """each image may have multiple ground_truth text, e.g., COCO and Flickr30K"""
+ txt_id = 0
+ for img_id, ann in enumerate(self.raw_anno_list):
+ self.image.append(ann["image"])
+ self.img2txt[img_id] = []
+ _captions = ann["caption"] \
+ if isinstance(ann["caption"], list) else [ann["caption"], ]
+ for i, caption in enumerate(_captions):
+ self.text.append(pre_text(caption))
+ self.img2txt[img_id].append(txt_id)
+ self.txt2img[txt_id] = img_id
+ txt_id += 1
+
+ def __len__(self):
+ return len(self.anno_list)
+
+ def __getitem__(self, index):
+ ann = self.anno_list[index]
+ image, index = self.load_and_transform_media_data(index, ann["image"])
+ return image, index
+
+
+class VidTxtRetEvalDataset(ImgTxtRetEvalDataset):
+ media_type = "video"
+
+ def __init__(
+ self, ann_file, transform, num_frames=4,
+ video_reader_type="decord", sample_type="rand", num_tries=1):
+ super(VidTxtRetEvalDataset, self).__init__(ann_file, transform)
+ self.num_frames = num_frames
+ self.video_reader_type = video_reader_type
+ self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
+ self.sample_type = sample_type
+ self.num_tries = num_tries
+ self.is_paragraph_retrieval = ann_file.get("is_paragraph_retrieval", False)
+ if self.is_paragraph_retrieval:
+ logger.info("Preprocess paragraph retrieval data!!!")
+ self.anno_list = preprocess_para_retrieval_data(self.raw_anno_list)
+ self.trimmed30 = ann_file.get("trimmed30", False)
+ if self.trimmed30:
+ logger.info("Trimming the video, only use the first 30s!!!")
+ self.read_clip_from_video = ann_file.get("read_clip_from_video", False)
+ self.use_subtitle = ann_file.get("use_subtitle", False)
+ if self.use_subtitle:
+ if self.is_act_rec:
+ raise NotImplementedError
+ self.build_subtitle_data()
+
+ self.build_data()
+
+ def __getitem__(self, index):
+ ann = self.anno_list[index]
+ if self.read_clip_from_video:
+ raise NotImplementedError("key for match_ids is not implemented!")
+ else:
+ data_path = ann["image"]
+ image, index = self.load_and_transform_media_data(index, data_path)
+ return image, index
+
+ def build_subtitle_data(self):
+ self.subtitle = []
+ for _, ann in enumerate(self.raw_anno_list):
+ if self.trimmed30:
+ if "asr_trimmed_30" in ann.keys():
+ self.subtitle.append(pre_text(ann["asr_trimmed_30"]))
+ else:
+ self.subtitle.append("")
+ else:
+ if "asr" in ann.keys():
+ self.subtitle.append(pre_text(ann["asr"]))
+ else:
+ self.subtitle.append("")
+
+
+def preprocess_para_retrieval_data(anno_list):
+ processed_anno_list = []
+ for d in anno_list:
+ d["caption"] = " ".join(d.pop("caption"))
+ processed_anno_list.append(d)
+ return processed_anno_list
+
+
+class VidTxtRetMCEvalDataset(BaseDataset):
+ """For MSRVTT-MC test task"""
+ media_type = "video"
+
+ def __init__(self, ann_file, transform, num_frames=4,
+ video_reader_type="decord", sample_type="rand", num_tries=1):
+ super(VidTxtRetMCEvalDataset, self).__init__()
+ self.anno_list = load_anno(ann_file)
+ self.transform = transform
+ # video args
+ self.num_frames = num_frames
+ self.video_reader_type = video_reader_type
+ self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
+ self.sample_type = sample_type
+ self.num_tries = num_tries
+
+ def __len__(self):
+ return len(self.anno_list)
+
+ def __getitem__(self, index):
+ ann = self.anno_list[index]
+ image, index = self.load_and_transform_media_data(index, ann["image"])
+ caption = [pre_text(e) for e in ann["caption"]] # len=5
+ answer = ann["answer"]
+ return image, caption, answer, ann
+
+
+class VidTxtRetMCNewEvalDataset(BaseDataset):
+ """For SSV2-MC and Charades-MC test task"""
+ media_type = "video"
+
+ def __init__(self, ann_file, transform, num_frames=4,
+ video_reader_type="decord", sample_type="rand", num_tries=1):
+ super(VidTxtRetMCNewEvalDataset, self).__init__()
+ self.anno_list = load_anno(ann_file)
+ self.transform = transform
+ # video args
+ self.num_frames = num_frames
+ self.video_reader_type = video_reader_type
+ self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
+ self.sample_type = sample_type
+ self.num_tries = num_tries
+
+ def __len__(self):
+ return len(self.anno_list)
+
+ def __getitem__(self, index):
+ ann = self.anno_list[index]
+ image, index = self.load_and_transform_media_data(index, ann["image"])
+ option = [pre_text(e) for e in ann["option"]] # len=174
+ answer = ann["answer"]
+ if isinstance(answer, list):
+ answer = torch.Tensor(answer)
+ return image, option, answer, ann
+
+
+class AudioVidTxtRetEvalDataset(VidTxtRetEvalDataset):
+ media_type = "audio_video"
+
+ def __init__(
+ self, ann_file, transform, num_frames=4,
+ video_reader_type="decord", sample_type="rand", num_tries=1,
+ audio_sample_rate=16000,
+ audio_reader_type='torchaudio',
+ max_audio_length=10):
+ super(AudioVidTxtRetEvalDataset, self).__init__(ann_file, transform,
+ num_frames=num_frames, video_reader_type=video_reader_type,
+ sample_type=sample_type, num_tries=num_tries)
+
+ self.audio_sample_rate = audio_sample_rate
+ self.audio_reader_type = audio_reader_type
+ self.max_audio_length = max_audio_length
+ self.read_clip_from_video = ann_file.get("read_clip_from_video", False)
+ self.read_audio_from_video = ann_file.get("read_audio_from_video", False)
+ self.zero_audio_padding_for_video = ann_file.get("zero_audio_padding_for_video", False)
+
+ def __getitem__(self, index):
+ ann = self.anno_list[index]
+ data_path = {'video': ann["image"]}
+
+ if self.read_clip_from_video:
+ raise NotImplementedError("Need to modify load_anno!")
+
+ if not self.read_audio_from_video:
+ raise NotImplementedError("Need to modify load_anno!")
+
+ data_path["read_clip_from_video"] = self.read_clip_from_video
+ data_path["read_audio_from_video"] = self.read_audio_from_video
+
+ media, index = self.load_and_transform_media_data(index, data_path)
+ audio = media[0]
+ if audio is None and self.zero_audio_padding_for_video:
+ media[0] = torch.zeros((998, 64), dtype=torch.float32)
+
+ return media, index
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/sampler.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fa8cc8e00bf3acbb4f031d5586df56782f71f00
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/sampler.py
@@ -0,0 +1,64 @@
+import numpy as np
+import torch
+from torch.utils.data.distributed import DistributedSampler
+
+
+# stolen from https://github.com/facebookresearch/vissl/blob/94def58538d3c7037f5e093196494331eea1a2a2/vissl/data/data_helper.py#L93
+class StatefulDistributedSampler(DistributedSampler):
+ """
+ More fine-grained state DataSampler that uses training iteration and epoch
+ both for shuffling data. PyTorch DistributedSampler only uses epoch
+ for the shuffling and starts sampling data from the start. In case of training
+ on very large data, we train for one epoch only and when we resume training,
+ we want to resume the data sampler from the training iteration.
+ """
+
+ def __init__(self, dataset, batch_size=None, seed: int = 0):
+ """
+ Initializes the instance of StatefulDistributedSampler. Random seed is set
+ for the epoch set and data is shuffled. For starting the sampling, use
+ the start_iter (set to 0 or set by checkpointing resuming) to
+ sample data from the remaining images.
+
+ Args:
+ dataset (Dataset): Pytorch dataset that sampler will shuffle
+ batch_size (int): batch size we want the sampler to sample
+ seed (int): Seed for the torch generator.
+ """
+ super().__init__(dataset, shuffle=False, seed=seed)
+
+ self.start_iter = 0
+ self.batch_size = batch_size
+ self.total_size = len(dataset) - (len(dataset) % self.num_replicas)
+ self.num_samples = self.total_size // self.num_replicas
+ print(f"rank: {self.rank}: Sampler created...")
+
+ def __iter__(self):
+ # partition data into num_replicas and optionally shuffle within a rank
+ g = torch.Generator()
+ g.manual_seed(self.epoch + self.seed)
+ shuffling = torch.randperm(self.num_samples, generator=g).tolist()
+ indices = np.array(
+ list(
+ range(
+ (self.rank * self.num_samples), (self.rank + 1) * self.num_samples
+ )
+ )
+ )[shuffling].tolist()
+
+ # make sure we have correct number of samples per replica
+ assert len(indices) == self.num_samples
+ assert self.batch_size > 0, "batch_size not set for the sampler"
+
+ # resume the sampler
+ start_index = self.start_iter * self.batch_size
+ indices = indices[start_index:]
+ return iter(indices)
+
+ def set_start_iter(self, start_iter):
+ """
+ Set the iteration number from which the sampling should start. This is
+ used to find the marker in the data permutation order from where the
+ sampler should start sampling.
+ """
+ self.start_iter = start_iter
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/serialize.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/serialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7666434382e0a81c8ec636ea8215c1a9eacf52c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/serialize.py
@@ -0,0 +1,199 @@
+# Description: This file contains the code for serializing the dataset.
+# From https://github.com/ppwwyyxx/RAM-multiprocess-dataloader/blob/795868a37446d61412b9a58dbb1b7c76e75d39c4/serialize.py
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+"""
+List serialization code adopted from
+https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/common.py
+"""
+
+import multiprocessing as mp
+
+from typing import List, Any, Optional
+
+import pickle
+import numpy as np
+import torch
+import torch.distributed as dist
+
+import functools
+import os
+
+from datetime import timedelta
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+
+ # this is not guaranteed to be set
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ return int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ return int(os.environ['SLURM_LOCALID'])
+ else:
+ raise RuntimeError("Unable to get local rank")
+
+
+def get_local_size() -> int:
+ return torch.cuda.device_count()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo", timeout=timedelta(minutes=60))
+ else:
+ return dist.group.WORLD
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = (
+ _get_global_gloo_group()
+ ) # use CPU group by default, to reduce GPU RAM usage.
+ world_size = dist.get_world_size(group)
+ if world_size == 1:
+ return [data]
+
+ output = [None for _ in range(world_size)]
+ dist.all_gather_object(output, data, group=group)
+ return output
+
+
+class NumpySerializedList:
+ def __init__(self, lst: list):
+ def _serialize(data):
+ buffer = pickle.dumps(data, protocol=-1)
+ return np.frombuffer(buffer, dtype=np.uint8)
+
+ print(
+ "Serializing {} elements to byte tensors and concatenating them all ...".format(
+ len(lst)
+ )
+ )
+ self._lst = [_serialize(x) for x in lst]
+ self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
+ self._addr = np.cumsum(self._addr)
+ self._lst = np.concatenate(self._lst)
+ print("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
+
+ def __len__(self):
+ return len(self._addr)
+
+ def __getitem__(self, idx):
+ start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
+ end_addr = self._addr[idx].item()
+ bytes = memoryview(self._lst[start_addr:end_addr])
+ return pickle.loads(bytes)
+
+
+class TorchSerializedList(NumpySerializedList):
+ def __init__(self, lst: list):
+ super().__init__(lst)
+ self._addr = torch.from_numpy(self._addr)
+ self._lst = torch.from_numpy(self._lst)
+
+ def __getitem__(self, idx):
+ start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
+ end_addr = self._addr[idx].item()
+ bytes = memoryview(self._lst[start_addr:end_addr].numpy())
+ return pickle.loads(bytes)
+
+
+def local_scatter(array: Optional[List[Any]]):
+ """
+ Scatter an array from local leader to all local workers.
+ The i-th local worker gets array[i].
+
+ Args:
+ array: Array with same size of #local workers.
+ """
+ if get_local_size() <= 1:
+ # Just one worker. Do nothing.
+ return array[0]
+ if get_local_rank() == 0:
+ assert len(array) == get_local_size()
+ all_gather(array)
+ else:
+ all_data = all_gather(None)
+ array = all_data[get_rank() - get_local_rank()]
+ return array[get_local_rank()]
+
+
+# NOTE: https://github.com/facebookresearch/mobile-vision/pull/120
+# has another implementation that does not use tensors.
+class TorchShmSerializedList(TorchSerializedList):
+ def __init__(self, lst: list):
+ if get_local_rank() == 0:
+ super().__init__(lst)
+
+ if get_local_rank() == 0:
+ # Move data to shared memory, obtain a handle to send to each local worker.
+ # This is cheap because a tensor will only be moved to shared memory once.
+ handles = [None] + [
+ bytes(mp.reduction.ForkingPickler.dumps((self._addr, self._lst)))
+ for _ in range(get_local_size() - 1)
+ ]
+ else:
+ handles = None
+ # Each worker receives the handle from local leader.
+ handle = local_scatter(handles)
+
+ if get_local_rank() > 0:
+ # Materialize the tensor from shared memory.
+ self._addr, self._lst = mp.reduction.ForkingPickler.loads(handle)
+ print(
+ f"Worker {get_rank()} obtains a dataset of length="
+ f"{len(self)} from its local leader."
+ )
+
+
+# From https://github.com/ppwwyyxx/RAM-multiprocess-dataloader/issues/5#issuecomment-1510676170
+def local_broadcast_process_authkey():
+ if int(os.environ['LOCAL_WORLD_SIZE']) == 1:
+ return
+ local_rank = int(os.environ['LOCAL_RANK'])
+ authkey = bytes(mp.current_process().authkey)
+ all_keys = all_gather(authkey)
+ local_leader_key = all_keys[get_rank() - local_rank]
+ if authkey != local_leader_key:
+ print("Process authkey is different from the key of local leader. This might happen when "
+ "workers are launched independently.")
+ print("Overwriting local authkey ...")
+ mp.current_process().authkey = local_leader_key
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/text_prompt.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/text_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb6e835e0df0f43d52b35edd809d594552353eec
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/text_prompt.py
@@ -0,0 +1,132 @@
+kinetics_templates_action_clip = [
+ f"a photo of action {{}}",
+ f"a picture of action {{}}",
+ f"Human action of {{}}",
+ f"{{}}, an action",
+ f"{{}} this is an action",
+ f"{{}}, a video of action",
+ f"Playing action of {{}}",
+ f"{{}}",
+ f"Playing a kind of action, {{}}",
+ f"Doing a kind of action, {{}}",
+ f"Look, the human is {{}}",
+ f"Can you recognize the action of {{}}?",
+ f"Video classification of {{}}",
+ f"A video of {{}}",
+ f"The man is {{}}",
+ f"The woman is {{}}",
+]
+
+kinetics_templates = [
+ "A photo of action {}.",
+ "A video of action {}.",
+ "He or she is {}.",
+ "A person is doing {}.",
+ "Look, the human is {}.",
+ "Human action of {}.",
+ "Playing action of {}.",
+ "Video classification of {}.",
+ "Doing a kind of action, {}.",
+ "Playing a kind of action, {}.",
+ "Can you recognize the action of {}?",
+ "{}, an action.",
+ "{} this is an action.",
+ "{}, a video of action.",
+ "An action of {} is in the video.",
+ "There is a person doing {} in the video.",
+ "A photo of a person doing {}.",
+ "A photo of a person performing {}.",
+ "A photo of a person practicing {}.",
+ "A video of a person doing {}.",
+ "A video of a person performing {}.",
+ "A video of a person practicing {}.",
+ "A example of a person doing {}.",
+ "A example of a person performing {}.",
+ "A example of a person practicing {}.",
+ "A demonstration of a person doing {}.",
+ "A demonstration of a person performing {}.",
+ "A demonstration of a person practicing {}.",
+]
+
+imagenet_templates = [
+ "a bad photo of a {}.",
+ "a photo of many {}.",
+ "a sculpture of a {}.",
+ "a photo of the hard to see {}.",
+ "a low resolution photo of the {}.",
+ "a rendering of a {}.",
+ "graffiti of a {}.",
+ "a bad photo of the {}.",
+ "a cropped photo of the {}.",
+ "a tattoo of a {}.",
+ "the embroidered {}.",
+ "a photo of a hard to see {}.",
+ "a bright photo of a {}.",
+ "a photo of a clean {}.",
+ "a photo of a dirty {}.",
+ "a dark photo of the {}.",
+ "a drawing of a {}.",
+ "a photo of my {}.",
+ "the plastic {}.",
+ "a photo of the cool {}.",
+ "a close-up photo of a {}.",
+ "a black and white photo of the {}.",
+ "a painting of the {}.",
+ "a painting of a {}.",
+ "a pixelated photo of the {}.",
+ "a sculpture of the {}.",
+ "a bright photo of the {}.",
+ "a cropped photo of a {}.",
+ "a plastic {}.",
+ "a photo of the dirty {}.",
+ "a jpeg corrupted photo of a {}.",
+ "a blurry photo of the {}.",
+ "a photo of the {}.",
+ "a good photo of the {}.",
+ "a rendering of the {}.",
+ "a {} in a video game.",
+ "a photo of one {}.",
+ "a doodle of a {}.",
+ "a close-up photo of the {}.",
+ "a photo of a {}.",
+ "the origami {}.",
+ "the {} in a video game.",
+ "a sketch of a {}.",
+ "a doodle of the {}.",
+ "a origami {}.",
+ "a low resolution photo of a {}.",
+ "the toy {}.",
+ "a rendition of the {}.",
+ "a photo of the clean {}.",
+ "a photo of a large {}.",
+ "a rendition of a {}.",
+ "a photo of a nice {}.",
+ "a photo of a weird {}.",
+ "a blurry photo of a {}.",
+ "a cartoon {}.",
+ "art of a {}.",
+ "a sketch of the {}.",
+ "a embroidered {}.",
+ "a pixelated photo of a {}.",
+ "itap of the {}.",
+ "a jpeg corrupted photo of the {}.",
+ "a good photo of a {}.",
+ "a plushie {}.",
+ "a photo of the nice {}.",
+ "a photo of the small {}.",
+ "a photo of the weird {}.",
+ "the cartoon {}.",
+ "art of the {}.",
+ "a drawing of the {}.",
+ "a photo of the large {}.",
+ "a black and white photo of a {}.",
+ "the plushie {}.",
+ "a dark photo of a {}.",
+ "itap of a {}.",
+ "graffiti of the {}.",
+ "a toy {}.",
+ "itap of my {}.",
+ "a photo of a cool {}.",
+ "a photo of a small {}.",
+ "a tattoo of the {}.",
+]
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/utils.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..725fa1333773db6a6581308b67161c3b75ea167f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/utils.py
@@ -0,0 +1,279 @@
+from utils.distributed import is_main_process, get_rank, get_world_size
+import logging
+import torch.distributed as dist
+import torch
+import io
+import os
+import json
+import re
+import random
+import numpy as np
+from os.path import join
+from tqdm import trange
+from PIL import Image
+from PIL import ImageFile
+from torchvision.transforms import PILToTensor
+import librosa
+import torchaudio
+# import soundfile as sf
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+Image.MAX_IMAGE_PIXELS = None
+
+logger = logging.getLogger(__name__)
+
+
+def load_audio_from_path(audio_path, client, sr, audio_reader_type, max_length=0):
+ # print(f"audio_path: {audio_path}, client: {client}, sr: {sr}, audio_reader_type: {audio_reader_type}")
+ if "s3://" in audio_path and client is not None:
+ audio_bytes = client.get(audio_path)
+ buff = io.BytesIO(audio_bytes)
+ else:
+ buff = audio_path
+ if audio_reader_type == 'librosa':
+ audio, _ = librosa.load(buff, sr=sr)
+ audio = torch.from_numpy(audio)
+ # audio = normalize(audio) # normalize waveform to -1,1 due to specified sr in librosa.load
+ # elif audio_reader_type == 'soundfile':
+ # audio, _ = sf.read(buff, sr=sr)
+ # audio = torch.from_numpy(audio)
+ elif audio_reader_type == 'torchaudio':
+ torchaudio.set_audio_backend('soundfile') # for flac files
+ audio, csr = torchaudio.load(buff)
+ if csr != sr:
+ trans = torchaudio.transforms.Resample(csr, sr)
+ audio = trans(audio)
+ if audio.size(0) == 2:
+ audio = torch.mean(audio, dim=0, keepdim=False)
+ else:
+ raise NotImplementedError
+ if max_length != 0:
+ # if audio length is longer than max_length, we randomly crop it to uta length
+ if audio.shape[0] >= max_length:
+ max_start = audio.shape[0] - max_length
+ start = random.randint(0, max_start)
+ audio = audio[start: start + max_length]
+ # padding = torch.zeros(audio.shape).long()
+ else:
+ # padding = torch.cat((torch.zeros(audio.shape), torch.ones(max_length-audio.shape[0])), -1).long()
+ audio = torch.nn.functional.pad(audio, (0, max_length-audio.shape[-1]), 'constant')
+ # print(f"post audio max: {audio.max()}, audio min: {audio.min()}, audio shape: {audio.shape}")
+ if len(audio.shape) == 1:
+ audio = audio.unsqueeze(0)
+ fbank = audio * 2 ** 15
+ fbank = torchaudio.compliance.kaldi.fbank(fbank, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10)
+ fbank_mean = 15.41663
+ fbank_std = 6.55582
+ fbank = (fbank - fbank_mean) / (fbank_std * 2) # 998, 64
+ return fbank
+
+
+def load_image_from_path(image_path, client):
+ if "s3://" in image_path and client is not None:
+ value = client.Get(image_path)
+ if value is None:
+ logger.warning(f"Failed to load {image_path}")
+ img_bytes = np.frombuffer(value, dtype=np.uint8)
+ buff = io.BytesIO(img_bytes)
+ image = Image.open(buff).convert('RGB')
+ else:
+ image = Image.open(image_path).convert('RGB') # PIL Image
+ image = PILToTensor()(image).unsqueeze(0) # (1, C, H, W), torch.uint8
+ return image
+
+
+def load_anno(ann_file_list):
+ """[summary]
+
+ Args:
+ ann_file_list (List[List[str, str]] or List[str, str]):
+ the latter will be automatically converted to the former.
+ Each sublist contains [anno_path, image_root], (or [anno_path, video_root, 'video'])
+ which specifies the data type, video or image
+
+ Returns:
+ List(dict): each dict is {
+ image: str or List[str], # image_path,
+ caption: str or List[str] # caption text string
+ }
+ """
+ if isinstance(ann_file_list, dict):
+ ann_file_list = [ann_file_list]
+
+ ann = []
+ for d in ann_file_list:
+
+ data_root = d.data_root
+ data_root_prefix = d.get("data_root_prefix", "")
+ fp = d.anno_path
+
+ cur_ann = json.load(open(fp, "r"))
+ iterator = trange(len(cur_ann), desc=f"Loading {fp}") \
+ if is_main_process() else range(len(cur_ann))
+ for idx in iterator:
+ if d.media_type == "image":
+ key = "image"
+ elif d.media_type in ["video", "audio_video"]:
+ key = "video"
+ elif d.media_type == "audio":
+ key = "audio"
+ else:
+ raise NotImplementedError(key)
+
+ # unified to have the same key for data path
+ if isinstance(cur_ann[idx][key], str):
+ cur_ann[idx]["image"] = data_root_prefix + join(data_root, cur_ann[idx][key])
+ else: # list
+ cur_ann[idx]["image"] = [data_root_prefix + join(data_root, e) for e in cur_ann[idx][key]]
+ ann += cur_ann
+ return ann
+
+
+def pre_text(text, max_l=None):
+ assert type(text) is str, text
+ text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
+ text = text.replace('-', ' ').replace('/', ' ').replace('', 'person')
+
+ text = re.sub(r"\s{2,}", ' ', text)
+ text = text.rstrip('\n').strip(' ')
+
+ if max_l: # truncate
+ words = text.split(' ')
+ if len(words) > max_l:
+ text = ' '.join(words[:max_l])
+ return text
+
+
+def collect_result(result, result_dir, filename, is_json=True, is_list=True):
+ if is_json:
+ result_file = os.path.join(
+ result_dir, '%s_rank%d.json' % (filename, get_rank()))
+ final_result_file = os.path.join(result_dir, '%s.json' % filename)
+ json.dump(result, open(result_file, 'w'))
+ else:
+ result_file = os.path.join(
+ result_dir, '%s_rank%d.pth' % (filename, get_rank()))
+ final_result_file = os.path.join(result_dir, '%s.pth' % filename)
+ torch.save(result, result_file)
+
+ dist.barrier()
+
+ result = None
+ if is_main_process():
+ # combine results from all processes
+ if is_list:
+ result = []
+ else:
+ result = {}
+ for rank in range(get_world_size()):
+ if is_json:
+ result_file = os.path.join(
+ result_dir, '%s_rank%d.json' % (filename, rank))
+ res = json.load(open(result_file, 'r'))
+ else:
+ result_file = os.path.join(
+ result_dir, '%s_rank%d.pth' % (filename, rank))
+ res = torch.load(result_file)
+ if is_list:
+ result += res
+ else:
+ result.update(res)
+
+ return result
+
+
+def sync_save_result(result, result_dir, filename, is_json=True, is_list=True):
+ """gather results from multiple GPUs"""
+ if is_json:
+ result_file = os.path.join(
+ result_dir, "dist_res", '%s_rank%d.json' % (filename, get_rank()))
+ final_result_file = os.path.join(result_dir, '%s.json' % filename)
+ os.makedirs(os.path.dirname(result_file), exist_ok=True)
+ json.dump(result, open(result_file, 'w'))
+ else:
+ result_file = os.path.join(
+ result_dir, "dist_res", '%s_rank%d.pth' % (filename, get_rank()))
+ os.makedirs(os.path.dirname(result_file), exist_ok=True)
+ final_result_file = os.path.join(result_dir, '%s.pth' % filename)
+ torch.save(result, result_file)
+
+ dist.barrier()
+
+ if is_main_process():
+ # combine results from all processes
+ if is_list:
+ result = []
+ else:
+ result = {}
+ for rank in range(get_world_size()):
+ if is_json:
+ result_file = os.path.join(
+ result_dir, "dist_res", '%s_rank%d.json' % (filename, rank))
+ res = json.load(open(result_file, 'r'))
+ else:
+ result_file = os.path.join(
+ result_dir, "dist_res", '%s_rank%d.pth' % (filename, rank))
+ res = torch.load(result_file)
+ if is_list:
+ result += res
+ else:
+ result.update(res)
+ if is_json:
+ json.dump(result, open(final_result_file, 'w'))
+ else:
+ torch.save(result, final_result_file)
+
+ logger.info('result file saved to %s' % final_result_file)
+ dist.barrier()
+ return final_result_file, result
+
+
+def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None):
+ """ Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray)
+ into a (n+1)-d array, only allow the first dim has variable lengths.
+ Args:
+ sequences: list(n-d tensor or list)
+ dtype: np.dtype or torch.dtype
+ device:
+ fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length.
+ return will be of shape [len(sequences), fixed_length, ...]
+ Returns:
+ padded_seqs: ((n+1)-d tensor) padded with zeros
+ mask: (2d tensor) of the same shape as the first two dims of padded_seqs,
+ 1 indicate valid, 0 otherwise
+ Examples:
+ >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
+ >>> pad_sequences_1d(test_data_list, dtype=torch.long)
+ >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)]
+ >>> pad_sequences_1d(test_data_3d, dtype=torch.float)
+ >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
+ >>> pad_sequences_1d(test_data_list, dtype=np.float32)
+ >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)]
+ >>> pad_sequences_1d(test_data_3d, dtype=np.float32)
+ """
+ if isinstance(sequences[0], list):
+ if "torch" in str(dtype):
+ sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences]
+ else:
+ sequences = [np.asarray(s, dtype=dtype) for s in sequences]
+
+ extra_dims = sequences[0].shape[1:] # the extra dims should be the same for all elements
+ lengths = [len(seq) for seq in sequences]
+ if fixed_length is not None:
+ max_length = fixed_length
+ else:
+ max_length = max(lengths)
+ if isinstance(sequences[0], torch.Tensor):
+ assert "torch" in str(dtype), "dtype and input type does not match"
+ padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device)
+ mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device)
+ else: # np
+ assert "numpy" in str(dtype), "dtype and input type does not match"
+ padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype)
+ mask = np.zeros((len(sequences), max_length), dtype=np.float32)
+
+ for idx, seq in enumerate(sequences):
+ end = lengths[idx]
+ padded_seqs[idx, :end] = seq
+ mask[idx, :end] = 1
+ return padded_seqs, mask # , lengths
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/dataset/video_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/dataset/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..25d19052b9ec51c8c99793c52df6690a7ae93fc5
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/dataset/video_utils.py
@@ -0,0 +1,209 @@
+"""
+Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
+"""
+import os
+import random
+import io
+import av
+import cv2
+import decord
+import imageio
+from decord import VideoReader
+import torch
+import numpy as np
+import math
+decord.bridge.set_bridge("torch")
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
+ """
+ Converts a present time with the given time base and start_pts offset to seconds.
+
+ Returns:
+ time_in_seconds (float): The corresponding time in seconds.
+
+ https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
+ """
+ if pts == math.inf:
+ return math.inf
+
+ return int(pts - start_pts) * time_base
+
+
+def get_pyav_video_duration(video_reader):
+ video_stream = video_reader.streams.video[0]
+ video_duration = pts_to_secs(
+ video_stream.duration,
+ video_stream.time_base,
+ video_stream.start_time
+ )
+ return float(video_duration)
+
+
+def get_frame_indices_by_fps():
+ pass
+
+
+def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
+ if sample in ["rand", "middle"]: # uniform sampling
+ acc_samples = min(num_frames, vlen)
+ # split the video into `acc_samples` intervals, and sample from each interval.
+ intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
+ ranges = []
+ for idx, interv in enumerate(intervals[:-1]):
+ ranges.append((interv, intervals[idx + 1] - 1))
+ if sample == 'rand':
+ try:
+ frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
+ except:
+ frame_indices = np.random.permutation(vlen)[:acc_samples]
+ frame_indices.sort()
+ frame_indices = list(frame_indices)
+ elif fix_start is not None:
+ frame_indices = [x[0] + fix_start for x in ranges]
+ elif sample == 'middle':
+ frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
+ else:
+ raise NotImplementedError
+
+ if len(frame_indices) < num_frames: # padded with last frame
+ padded_frame_indices = [frame_indices[-1]] * num_frames
+ padded_frame_indices[:len(frame_indices)] = frame_indices
+ frame_indices = padded_frame_indices
+ elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
+ output_fps = float(sample[3:])
+ duration = float(vlen) / input_fps
+ delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
+ frame_indices = np.around(frame_seconds * input_fps).astype(int)
+ frame_indices = [e for e in frame_indices if e < vlen]
+ if max_num_frames > 0 and len(frame_indices) > max_num_frames:
+ frame_indices = frame_indices[:max_num_frames]
+ # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
+ else:
+ raise ValueError
+ return frame_indices
+
+
+def read_frames_av(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1):
+ reader = av.open(video_path)
+ frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]
+ vlen = len(frames)
+ duration = get_pyav_video_duration(reader)
+ fps = vlen / float(duration)
+ frame_indices = get_frame_indices(
+ num_frames, vlen, sample=sample, fix_start=fix_start,
+ input_fps=fps, max_num_frames=max_num_frames
+ )
+ frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8
+ frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
+ return frames, frame_indices, duration
+
+
+def read_frames_gif(
+ video_path, num_frames, sample='rand', fix_start=None,
+ max_num_frames=-1, client=None, trimmed30=False,
+ ):
+ if 's3://' in video_path:
+ video_bytes = client.get(video_path)
+ gif = imageio.get_reader(io.BytesIO(video_bytes))
+ else:
+ gif = imageio.get_reader(video_path)
+ vlen = len(gif)
+ frame_indices = get_frame_indices(
+ num_frames, vlen, sample=sample, fix_start=fix_start,
+ max_num_frames=max_num_frames
+ )
+ frames = []
+ for index, frame in enumerate(gif):
+ # for index in frame_idxs:
+ if index in frame_indices:
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
+ frame = torch.from_numpy(frame).byte()
+ # # (H x W x C) to (C x H x W)
+ frame = frame.permute(2, 0, 1)
+ frames.append(frame)
+ frames = torch.stack(frames) # .float() / 255
+ return frames, frame_indices, None
+
+
+def read_frames_decord(
+ video_path, num_frames, sample='rand', fix_start=None,
+ max_num_frames=-1, client=None, trimmed30=False
+ ):
+ num_threads = 1 if video_path.endswith('.webm') else 0 # make ssv2 happy
+ if "s3://" in video_path:
+ video_bytes = client.get(video_path)
+ # print(f"\033[1;31;40m {video_path} ok: {video_bytes is None} \033[0m")
+ if video_bytes is None:
+ logger.warning(f"Failed to load {video_path}")
+ video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=num_threads)
+ else:
+ video_reader = VideoReader(video_path, num_threads=num_threads)
+ vlen = len(video_reader)
+
+ fps = video_reader.get_avg_fps()
+ duration = vlen / float(fps)
+
+ # only use top 30 seconds
+ if trimmed30 and duration > 30:
+ duration = 30
+ vlen = int(30 * float(fps))
+
+ frame_indices = get_frame_indices(
+ num_frames, vlen, sample=sample, fix_start=fix_start,
+ input_fps=fps, max_num_frames=max_num_frames
+ )
+
+ frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8
+ frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
+ return frames, frame_indices, duration
+
+
+def read_frames_img(
+ video_path, num_frames, sample='rand', fix_start=None,
+ max_num_frames=-1, client=None, trimmed30=False
+ ):
+ img_list=[]
+ if "s3://" in video_path:
+ for path in client.list(video_path):
+ if path.startswith('img'):
+ img_list.append(path)
+ else:
+ for path in os.listdir(video_path):
+ if path.startswith('img'):
+ img_list.append(path)
+
+ vlen = len(img_list)
+
+ frame_indices = get_frame_indices(
+ num_frames, vlen, sample=sample, fix_start=fix_start,
+ max_num_frames=max_num_frames
+ )
+
+ imgs = []
+ for idx in frame_indices:
+ frame_fname = os.path.join(video_path, img_list[idx])
+ if "s3://" in video_path:
+ img_bytes = client.get(frame_fname)
+ else:
+ with open(frame_fname, 'rb') as f:
+ img_bytes = f.read()
+ img_np = np.frombuffer(img_bytes, np.uint8)
+ img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ imgs.append(img)
+
+ frames = torch.tensor(np.array(imgs), dtype=torch.uint8).permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
+ return frames, frame_indices, None
+
+
+VIDEO_READER_FUNCS = {
+ 'av': read_frames_av,
+ 'decord': read_frames_decord,
+ 'gif': read_frames_gif,
+ 'img': read_frames_img,
+}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/demo/demo.ipynb b/third_party/InternVideo/InternVideo2/multi_modality/demo/demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..de6edf24a92a46beee1d7400e2633c6a1879a2d3
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/demo/demo.ipynb
@@ -0,0 +1,584 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/toolkit/.conda/envs/urlb_test/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "DropoutAddRMSNorm of flash_attn is not installed!!!\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[2024-04-16 22:03:29,983] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import os\n",
+ "import io\n",
+ "import cv2\n",
+ "\n",
+ "import torch\n",
+ "import sys\n",
+ "sys.path.insert(0, '/home/toolkit/eai_urlb/InternVideo/InternVideo2/multi_modality/demo/')\n",
+ "sys.path.insert(0, '/home/toolkit/eai_urlb/InternVideo/InternVideo2/multi_modality')\n",
+ "\n",
+ "from small_config import (Config, eval_dict_leaf)\n",
+ "from small_utils import (retrieve_text, _frame_from_video, setup_internvideo2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# video = cv2.VideoCapture('example1.mp4')\n",
+ "video = cv2.VideoCapture('../../../../video_samples/person_walking_video.mp4')\n",
+ "frames = [x for x in _frame_from_video(video)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "text_candidates = [\"A playful dog and its owner wrestle in the snowy yard, chasing each other with joyous abandon.\",\n",
+ " \"A man in a gray coat walks through the snowy landscape, pulling a sleigh loaded with toys.\",\n",
+ " \"A person dressed in a blue jacket shovels the snow-covered pavement outside their house.\",\n",
+ " \"A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner.\",\n",
+ " \"A person stands on the snowy floor, pushing a sled loaded with blankets, preparing for a fun-filled ride.\",\n",
+ " \"A man in a gray hat and coat walks through the snowy yard, carefully navigating around the trees.\",\n",
+ " \"A playful dog slides down a snowy hill, wagging its tail with delight.\",\n",
+ " \"A person in a blue jacket walks their pet on a leash, enjoying a peaceful winter walk among the trees.\",\n",
+ " \"A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run.\",\n",
+ " \"A person bundled up in a blanket walks through the snowy landscape, enjoying the serene winter scenery.\",\n",
+ " \"A person playing with a kid in the street\",\n",
+ " \"A group of friends playing bowling.\",\n",
+ " \"A japanese girl eating noodles\",\n",
+ " \"A painting by Monet\",\n",
+ " \"A person lying in bed\",\n",
+ " \"A person lying down on the grass\",\n",
+ " \"A person with a hat\",\n",
+ " \"Playing with hat\",\n",
+ " \"Somebody walking\",\n",
+ " \"Fidget spinner\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "load_state_dict: _IncompatibleKeys(missing_keys=['text_encoder.embeddings.position_ids', 'text_encoder.embeddings.word_embeddings.weight', 'text_encoder.embeddings.position_embeddings.weight', 'text_encoder.embeddings.token_type_embeddings.weight', 'text_encoder.embeddings.LayerNorm.weight', 'text_encoder.embeddings.LayerNorm.bias', 'text_encoder.encoder.layer.0.attention.self.query.weight', 'text_encoder.encoder.layer.0.attention.self.query.bias', 'text_encoder.encoder.layer.0.attention.self.key.weight', 'text_encoder.encoder.layer.0.attention.self.key.bias', 'text_encoder.encoder.layer.0.attention.self.value.weight', 'text_encoder.encoder.layer.0.attention.self.value.bias', 'text_encoder.encoder.layer.0.attention.output.dense.weight', 'text_encoder.encoder.layer.0.attention.output.dense.bias', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.0.intermediate.dense.weight', 'text_encoder.encoder.layer.0.intermediate.dense.bias', 'text_encoder.encoder.layer.0.output.dense.weight', 'text_encoder.encoder.layer.0.output.dense.bias', 'text_encoder.encoder.layer.0.output.LayerNorm.weight', 'text_encoder.encoder.layer.0.output.LayerNorm.bias', 'text_encoder.encoder.layer.1.attention.self.query.weight', 'text_encoder.encoder.layer.1.attention.self.query.bias', 'text_encoder.encoder.layer.1.attention.self.key.weight', 'text_encoder.encoder.layer.1.attention.self.key.bias', 'text_encoder.encoder.layer.1.attention.self.value.weight', 'text_encoder.encoder.layer.1.attention.self.value.bias', 'text_encoder.encoder.layer.1.attention.output.dense.weight', 'text_encoder.encoder.layer.1.attention.output.dense.bias', 'text_encoder.encoder.layer.1.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.1.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.1.intermediate.dense.weight', 'text_encoder.encoder.layer.1.intermediate.dense.bias', 'text_encoder.encoder.layer.1.output.dense.weight', 'text_encoder.encoder.layer.1.output.dense.bias', 'text_encoder.encoder.layer.1.output.LayerNorm.weight', 'text_encoder.encoder.layer.1.output.LayerNorm.bias', 'text_encoder.encoder.layer.2.attention.self.query.weight', 'text_encoder.encoder.layer.2.attention.self.query.bias', 'text_encoder.encoder.layer.2.attention.self.key.weight', 'text_encoder.encoder.layer.2.attention.self.key.bias', 'text_encoder.encoder.layer.2.attention.self.value.weight', 'text_encoder.encoder.layer.2.attention.self.value.bias', 'text_encoder.encoder.layer.2.attention.output.dense.weight', 'text_encoder.encoder.layer.2.attention.output.dense.bias', 'text_encoder.encoder.layer.2.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.2.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.2.intermediate.dense.weight', 'text_encoder.encoder.layer.2.intermediate.dense.bias', 'text_encoder.encoder.layer.2.output.dense.weight', 'text_encoder.encoder.layer.2.output.dense.bias', 'text_encoder.encoder.layer.2.output.LayerNorm.weight', 'text_encoder.encoder.layer.2.output.LayerNorm.bias', 'text_encoder.encoder.layer.3.attention.self.query.weight', 'text_encoder.encoder.layer.3.attention.self.query.bias', 'text_encoder.encoder.layer.3.attention.self.key.weight', 'text_encoder.encoder.layer.3.attention.self.key.bias', 'text_encoder.encoder.layer.3.attention.self.value.weight', 'text_encoder.encoder.layer.3.attention.self.value.bias', 'text_encoder.encoder.layer.3.attention.output.dense.weight', 'text_encoder.encoder.layer.3.attention.output.dense.bias', 'text_encoder.encoder.layer.3.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.3.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.3.intermediate.dense.weight', 'text_encoder.encoder.layer.3.intermediate.dense.bias', 'text_encoder.encoder.layer.3.output.dense.weight', 'text_encoder.encoder.layer.3.output.dense.bias', 'text_encoder.encoder.layer.3.output.LayerNorm.weight', 'text_encoder.encoder.layer.3.output.LayerNorm.bias', 'text_encoder.encoder.layer.4.attention.self.query.weight', 'text_encoder.encoder.layer.4.attention.self.query.bias', 'text_encoder.encoder.layer.4.attention.self.key.weight', 'text_encoder.encoder.layer.4.attention.self.key.bias', 'text_encoder.encoder.layer.4.attention.self.value.weight', 'text_encoder.encoder.layer.4.attention.self.value.bias', 'text_encoder.encoder.layer.4.attention.output.dense.weight', 'text_encoder.encoder.layer.4.attention.output.dense.bias', 'text_encoder.encoder.layer.4.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.4.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.4.intermediate.dense.weight', 'text_encoder.encoder.layer.4.intermediate.dense.bias', 'text_encoder.encoder.layer.4.output.dense.weight', 'text_encoder.encoder.layer.4.output.dense.bias', 'text_encoder.encoder.layer.4.output.LayerNorm.weight', 'text_encoder.encoder.layer.4.output.LayerNorm.bias', 'text_encoder.encoder.layer.5.attention.self.query.weight', 'text_encoder.encoder.layer.5.attention.self.query.bias', 'text_encoder.encoder.layer.5.attention.self.key.weight', 'text_encoder.encoder.layer.5.attention.self.key.bias', 'text_encoder.encoder.layer.5.attention.self.value.weight', 'text_encoder.encoder.layer.5.attention.self.value.bias', 'text_encoder.encoder.layer.5.attention.output.dense.weight', 'text_encoder.encoder.layer.5.attention.output.dense.bias', 'text_encoder.encoder.layer.5.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.5.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.5.intermediate.dense.weight', 'text_encoder.encoder.layer.5.intermediate.dense.bias', 'text_encoder.encoder.layer.5.output.dense.weight', 'text_encoder.encoder.layer.5.output.dense.bias', 'text_encoder.encoder.layer.5.output.LayerNorm.weight', 'text_encoder.encoder.layer.5.output.LayerNorm.bias', 'text_encoder.encoder.layer.6.attention.self.query.weight', 'text_encoder.encoder.layer.6.attention.self.query.bias', 'text_encoder.encoder.layer.6.attention.self.key.weight', 'text_encoder.encoder.layer.6.attention.self.key.bias', 'text_encoder.encoder.layer.6.attention.self.value.weight', 'text_encoder.encoder.layer.6.attention.self.value.bias', 'text_encoder.encoder.layer.6.attention.output.dense.weight', 'text_encoder.encoder.layer.6.attention.output.dense.bias', 'text_encoder.encoder.layer.6.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.6.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.6.intermediate.dense.weight', 'text_encoder.encoder.layer.6.intermediate.dense.bias', 'text_encoder.encoder.layer.6.output.dense.weight', 'text_encoder.encoder.layer.6.output.dense.bias', 'text_encoder.encoder.layer.6.output.LayerNorm.weight', 'text_encoder.encoder.layer.6.output.LayerNorm.bias', 'text_encoder.encoder.layer.7.attention.self.query.weight', 'text_encoder.encoder.layer.7.attention.self.query.bias', 'text_encoder.encoder.layer.7.attention.self.key.weight', 'text_encoder.encoder.layer.7.attention.self.key.bias', 'text_encoder.encoder.layer.7.attention.self.value.weight', 'text_encoder.encoder.layer.7.attention.self.value.bias', 'text_encoder.encoder.layer.7.attention.output.dense.weight', 'text_encoder.encoder.layer.7.attention.output.dense.bias', 'text_encoder.encoder.layer.7.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.7.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.7.intermediate.dense.weight', 'text_encoder.encoder.layer.7.intermediate.dense.bias', 'text_encoder.encoder.layer.7.output.dense.weight', 'text_encoder.encoder.layer.7.output.dense.bias', 'text_encoder.encoder.layer.7.output.LayerNorm.weight', 'text_encoder.encoder.layer.7.output.LayerNorm.bias', 'text_encoder.encoder.layer.8.attention.self.query.weight', 'text_encoder.encoder.layer.8.attention.self.query.bias', 'text_encoder.encoder.layer.8.attention.self.key.weight', 'text_encoder.encoder.layer.8.attention.self.key.bias', 'text_encoder.encoder.layer.8.attention.self.value.weight', 'text_encoder.encoder.layer.8.attention.self.value.bias', 'text_encoder.encoder.layer.8.attention.output.dense.weight', 'text_encoder.encoder.layer.8.attention.output.dense.bias', 'text_encoder.encoder.layer.8.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.8.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.8.intermediate.dense.weight', 'text_encoder.encoder.layer.8.intermediate.dense.bias', 'text_encoder.encoder.layer.8.output.dense.weight', 'text_encoder.encoder.layer.8.output.dense.bias', 'text_encoder.encoder.layer.8.output.LayerNorm.weight', 'text_encoder.encoder.layer.8.output.LayerNorm.bias', 'text_encoder.encoder.layer.9.attention.self.query.weight', 'text_encoder.encoder.layer.9.attention.self.query.bias', 'text_encoder.encoder.layer.9.attention.self.key.weight', 'text_encoder.encoder.layer.9.attention.self.key.bias', 'text_encoder.encoder.layer.9.attention.self.value.weight', 'text_encoder.encoder.layer.9.attention.self.value.bias', 'text_encoder.encoder.layer.9.attention.output.dense.weight', 'text_encoder.encoder.layer.9.attention.output.dense.bias', 'text_encoder.encoder.layer.9.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.9.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.9.intermediate.dense.weight', 'text_encoder.encoder.layer.9.intermediate.dense.bias', 'text_encoder.encoder.layer.9.output.dense.weight', 'text_encoder.encoder.layer.9.output.dense.bias', 'text_encoder.encoder.layer.9.output.LayerNorm.weight', 'text_encoder.encoder.layer.9.output.LayerNorm.bias', 'text_encoder.encoder.layer.10.attention.self.query.weight', 'text_encoder.encoder.layer.10.attention.self.query.bias', 'text_encoder.encoder.layer.10.attention.self.key.weight', 'text_encoder.encoder.layer.10.attention.self.key.bias', 'text_encoder.encoder.layer.10.attention.self.value.weight', 'text_encoder.encoder.layer.10.attention.self.value.bias', 'text_encoder.encoder.layer.10.attention.output.dense.weight', 'text_encoder.encoder.layer.10.attention.output.dense.bias', 'text_encoder.encoder.layer.10.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.10.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.10.intermediate.dense.weight', 'text_encoder.encoder.layer.10.intermediate.dense.bias', 'text_encoder.encoder.layer.10.output.dense.weight', 'text_encoder.encoder.layer.10.output.dense.bias', 'text_encoder.encoder.layer.10.output.LayerNorm.weight', 'text_encoder.encoder.layer.10.output.LayerNorm.bias', 'text_encoder.encoder.layer.11.attention.self.query.weight', 'text_encoder.encoder.layer.11.attention.self.query.bias', 'text_encoder.encoder.layer.11.attention.self.key.weight', 'text_encoder.encoder.layer.11.attention.self.key.bias', 'text_encoder.encoder.layer.11.attention.self.value.weight', 'text_encoder.encoder.layer.11.attention.self.value.bias', 'text_encoder.encoder.layer.11.attention.output.dense.weight', 'text_encoder.encoder.layer.11.attention.output.dense.bias', 'text_encoder.encoder.layer.11.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.11.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.11.intermediate.dense.weight', 'text_encoder.encoder.layer.11.intermediate.dense.bias', 'text_encoder.encoder.layer.11.output.dense.weight', 'text_encoder.encoder.layer.11.output.dense.bias', 'text_encoder.encoder.layer.11.output.LayerNorm.weight', 'text_encoder.encoder.layer.11.output.LayerNorm.bias', 'text_encoder.encoder.layer.12.attention.self.query.weight', 'text_encoder.encoder.layer.12.attention.self.query.bias', 'text_encoder.encoder.layer.12.attention.self.key.weight', 'text_encoder.encoder.layer.12.attention.self.key.bias', 'text_encoder.encoder.layer.12.attention.self.value.weight', 'text_encoder.encoder.layer.12.attention.self.value.bias', 'text_encoder.encoder.layer.12.attention.output.dense.weight', 'text_encoder.encoder.layer.12.attention.output.dense.bias', 'text_encoder.encoder.layer.12.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.12.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.12.intermediate.dense.weight', 'text_encoder.encoder.layer.12.intermediate.dense.bias', 'text_encoder.encoder.layer.12.output.dense.weight', 'text_encoder.encoder.layer.12.output.dense.bias', 'text_encoder.encoder.layer.12.output.LayerNorm.weight', 'text_encoder.encoder.layer.12.output.LayerNorm.bias', 'text_encoder.encoder.layer.13.attention.self.query.weight', 'text_encoder.encoder.layer.13.attention.self.query.bias', 'text_encoder.encoder.layer.13.attention.self.key.weight', 'text_encoder.encoder.layer.13.attention.self.key.bias', 'text_encoder.encoder.layer.13.attention.self.value.weight', 'text_encoder.encoder.layer.13.attention.self.value.bias', 'text_encoder.encoder.layer.13.attention.output.dense.weight', 'text_encoder.encoder.layer.13.attention.output.dense.bias', 'text_encoder.encoder.layer.13.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.13.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.13.intermediate.dense.weight', 'text_encoder.encoder.layer.13.intermediate.dense.bias', 'text_encoder.encoder.layer.13.output.dense.weight', 'text_encoder.encoder.layer.13.output.dense.bias', 'text_encoder.encoder.layer.13.output.LayerNorm.weight', 'text_encoder.encoder.layer.13.output.LayerNorm.bias', 'text_encoder.encoder.layer.14.attention.self.query.weight', 'text_encoder.encoder.layer.14.attention.self.query.bias', 'text_encoder.encoder.layer.14.attention.self.key.weight', 'text_encoder.encoder.layer.14.attention.self.key.bias', 'text_encoder.encoder.layer.14.attention.self.value.weight', 'text_encoder.encoder.layer.14.attention.self.value.bias', 'text_encoder.encoder.layer.14.attention.output.dense.weight', 'text_encoder.encoder.layer.14.attention.output.dense.bias', 'text_encoder.encoder.layer.14.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.14.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.14.intermediate.dense.weight', 'text_encoder.encoder.layer.14.intermediate.dense.bias', 'text_encoder.encoder.layer.14.output.dense.weight', 'text_encoder.encoder.layer.14.output.dense.bias', 'text_encoder.encoder.layer.14.output.LayerNorm.weight', 'text_encoder.encoder.layer.14.output.LayerNorm.bias', 'text_encoder.encoder.layer.15.attention.self.query.weight', 'text_encoder.encoder.layer.15.attention.self.query.bias', 'text_encoder.encoder.layer.15.attention.self.key.weight', 'text_encoder.encoder.layer.15.attention.self.key.bias', 'text_encoder.encoder.layer.15.attention.self.value.weight', 'text_encoder.encoder.layer.15.attention.self.value.bias', 'text_encoder.encoder.layer.15.attention.output.dense.weight', 'text_encoder.encoder.layer.15.attention.output.dense.bias', 'text_encoder.encoder.layer.15.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.15.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.15.intermediate.dense.weight', 'text_encoder.encoder.layer.15.intermediate.dense.bias', 'text_encoder.encoder.layer.15.output.dense.weight', 'text_encoder.encoder.layer.15.output.dense.bias', 'text_encoder.encoder.layer.15.output.LayerNorm.weight', 'text_encoder.encoder.layer.15.output.LayerNorm.bias', 'text_encoder.encoder.layer.16.attention.self.query.weight', 'text_encoder.encoder.layer.16.attention.self.query.bias', 'text_encoder.encoder.layer.16.attention.self.key.weight', 'text_encoder.encoder.layer.16.attention.self.key.bias', 'text_encoder.encoder.layer.16.attention.self.value.weight', 'text_encoder.encoder.layer.16.attention.self.value.bias', 'text_encoder.encoder.layer.16.attention.output.dense.weight', 'text_encoder.encoder.layer.16.attention.output.dense.bias', 'text_encoder.encoder.layer.16.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.16.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.16.intermediate.dense.weight', 'text_encoder.encoder.layer.16.intermediate.dense.bias', 'text_encoder.encoder.layer.16.output.dense.weight', 'text_encoder.encoder.layer.16.output.dense.bias', 'text_encoder.encoder.layer.16.output.LayerNorm.weight', 'text_encoder.encoder.layer.16.output.LayerNorm.bias', 'text_encoder.encoder.layer.17.attention.self.query.weight', 'text_encoder.encoder.layer.17.attention.self.query.bias', 'text_encoder.encoder.layer.17.attention.self.key.weight', 'text_encoder.encoder.layer.17.attention.self.key.bias', 'text_encoder.encoder.layer.17.attention.self.value.weight', 'text_encoder.encoder.layer.17.attention.self.value.bias', 'text_encoder.encoder.layer.17.attention.output.dense.weight', 'text_encoder.encoder.layer.17.attention.output.dense.bias', 'text_encoder.encoder.layer.17.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.17.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.17.intermediate.dense.weight', 'text_encoder.encoder.layer.17.intermediate.dense.bias', 'text_encoder.encoder.layer.17.output.dense.weight', 'text_encoder.encoder.layer.17.output.dense.bias', 'text_encoder.encoder.layer.17.output.LayerNorm.weight', 'text_encoder.encoder.layer.17.output.LayerNorm.bias', 'text_encoder.encoder.layer.18.attention.self.query.weight', 'text_encoder.encoder.layer.18.attention.self.query.bias', 'text_encoder.encoder.layer.18.attention.self.key.weight', 'text_encoder.encoder.layer.18.attention.self.key.bias', 'text_encoder.encoder.layer.18.attention.self.value.weight', 'text_encoder.encoder.layer.18.attention.self.value.bias', 'text_encoder.encoder.layer.18.attention.output.dense.weight', 'text_encoder.encoder.layer.18.attention.output.dense.bias', 'text_encoder.encoder.layer.18.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.18.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.18.intermediate.dense.weight', 'text_encoder.encoder.layer.18.intermediate.dense.bias', 'text_encoder.encoder.layer.18.output.dense.weight', 'text_encoder.encoder.layer.18.output.dense.bias', 'text_encoder.encoder.layer.18.output.LayerNorm.weight', 'text_encoder.encoder.layer.18.output.LayerNorm.bias', 'text_encoder.encoder.layer.19.attention.self.query.weight', 'text_encoder.encoder.layer.19.attention.self.query.bias', 'text_encoder.encoder.layer.19.attention.self.key.weight', 'text_encoder.encoder.layer.19.attention.self.key.bias', 'text_encoder.encoder.layer.19.attention.self.value.weight', 'text_encoder.encoder.layer.19.attention.self.value.bias', 'text_encoder.encoder.layer.19.attention.output.dense.weight', 'text_encoder.encoder.layer.19.attention.output.dense.bias', 'text_encoder.encoder.layer.19.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.19.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.19.crossattention.self.query.weight', 'text_encoder.encoder.layer.19.crossattention.self.query.bias', 'text_encoder.encoder.layer.19.crossattention.self.key.weight', 'text_encoder.encoder.layer.19.crossattention.self.key.bias', 'text_encoder.encoder.layer.19.crossattention.self.value.weight', 'text_encoder.encoder.layer.19.crossattention.self.value.bias', 'text_encoder.encoder.layer.19.crossattention.output.dense.weight', 'text_encoder.encoder.layer.19.crossattention.output.dense.bias', 'text_encoder.encoder.layer.19.crossattention.output.LayerNorm.weight', 'text_encoder.encoder.layer.19.crossattention.output.LayerNorm.bias', 'text_encoder.encoder.layer.19.intermediate.dense.weight', 'text_encoder.encoder.layer.19.intermediate.dense.bias', 'text_encoder.encoder.layer.19.output.dense.weight', 'text_encoder.encoder.layer.19.output.dense.bias', 'text_encoder.encoder.layer.19.output.LayerNorm.weight', 'text_encoder.encoder.layer.19.output.LayerNorm.bias', 'text_encoder.encoder.layer.20.attention.self.query.weight', 'text_encoder.encoder.layer.20.attention.self.query.bias', 'text_encoder.encoder.layer.20.attention.self.key.weight', 'text_encoder.encoder.layer.20.attention.self.key.bias', 'text_encoder.encoder.layer.20.attention.self.value.weight', 'text_encoder.encoder.layer.20.attention.self.value.bias', 'text_encoder.encoder.layer.20.attention.output.dense.weight', 'text_encoder.encoder.layer.20.attention.output.dense.bias', 'text_encoder.encoder.layer.20.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.20.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.20.crossattention.self.query.weight', 'text_encoder.encoder.layer.20.crossattention.self.query.bias', 'text_encoder.encoder.layer.20.crossattention.self.key.weight', 'text_encoder.encoder.layer.20.crossattention.self.key.bias', 'text_encoder.encoder.layer.20.crossattention.self.value.weight', 'text_encoder.encoder.layer.20.crossattention.self.value.bias', 'text_encoder.encoder.layer.20.crossattention.output.dense.weight', 'text_encoder.encoder.layer.20.crossattention.output.dense.bias', 'text_encoder.encoder.layer.20.crossattention.output.LayerNorm.weight', 'text_encoder.encoder.layer.20.crossattention.output.LayerNorm.bias', 'text_encoder.encoder.layer.20.intermediate.dense.weight', 'text_encoder.encoder.layer.20.intermediate.dense.bias', 'text_encoder.encoder.layer.20.output.dense.weight', 'text_encoder.encoder.layer.20.output.dense.bias', 'text_encoder.encoder.layer.20.output.LayerNorm.weight', 'text_encoder.encoder.layer.20.output.LayerNorm.bias', 'text_encoder.encoder.layer.21.attention.self.query.weight', 'text_encoder.encoder.layer.21.attention.self.query.bias', 'text_encoder.encoder.layer.21.attention.self.key.weight', 'text_encoder.encoder.layer.21.attention.self.key.bias', 'text_encoder.encoder.layer.21.attention.self.value.weight', 'text_encoder.encoder.layer.21.attention.self.value.bias', 'text_encoder.encoder.layer.21.attention.output.dense.weight', 'text_encoder.encoder.layer.21.attention.output.dense.bias', 'text_encoder.encoder.layer.21.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.21.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.21.crossattention.self.query.weight', 'text_encoder.encoder.layer.21.crossattention.self.query.bias', 'text_encoder.encoder.layer.21.crossattention.self.key.weight', 'text_encoder.encoder.layer.21.crossattention.self.key.bias', 'text_encoder.encoder.layer.21.crossattention.self.value.weight', 'text_encoder.encoder.layer.21.crossattention.self.value.bias', 'text_encoder.encoder.layer.21.crossattention.output.dense.weight', 'text_encoder.encoder.layer.21.crossattention.output.dense.bias', 'text_encoder.encoder.layer.21.crossattention.output.LayerNorm.weight', 'text_encoder.encoder.layer.21.crossattention.output.LayerNorm.bias', 'text_encoder.encoder.layer.21.intermediate.dense.weight', 'text_encoder.encoder.layer.21.intermediate.dense.bias', 'text_encoder.encoder.layer.21.output.dense.weight', 'text_encoder.encoder.layer.21.output.dense.bias', 'text_encoder.encoder.layer.21.output.LayerNorm.weight', 'text_encoder.encoder.layer.21.output.LayerNorm.bias', 'text_encoder.encoder.layer.22.attention.self.query.weight', 'text_encoder.encoder.layer.22.attention.self.query.bias', 'text_encoder.encoder.layer.22.attention.self.key.weight', 'text_encoder.encoder.layer.22.attention.self.key.bias', 'text_encoder.encoder.layer.22.attention.self.value.weight', 'text_encoder.encoder.layer.22.attention.self.value.bias', 'text_encoder.encoder.layer.22.attention.output.dense.weight', 'text_encoder.encoder.layer.22.attention.output.dense.bias', 'text_encoder.encoder.layer.22.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.22.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.22.crossattention.self.query.weight', 'text_encoder.encoder.layer.22.crossattention.self.query.bias', 'text_encoder.encoder.layer.22.crossattention.self.key.weight', 'text_encoder.encoder.layer.22.crossattention.self.key.bias', 'text_encoder.encoder.layer.22.crossattention.self.value.weight', 'text_encoder.encoder.layer.22.crossattention.self.value.bias', 'text_encoder.encoder.layer.22.crossattention.output.dense.weight', 'text_encoder.encoder.layer.22.crossattention.output.dense.bias', 'text_encoder.encoder.layer.22.crossattention.output.LayerNorm.weight', 'text_encoder.encoder.layer.22.crossattention.output.LayerNorm.bias', 'text_encoder.encoder.layer.22.intermediate.dense.weight', 'text_encoder.encoder.layer.22.intermediate.dense.bias', 'text_encoder.encoder.layer.22.output.dense.weight', 'text_encoder.encoder.layer.22.output.dense.bias', 'text_encoder.encoder.layer.22.output.LayerNorm.weight', 'text_encoder.encoder.layer.22.output.LayerNorm.bias', 'text_encoder.encoder.layer.23.attention.self.query.weight', 'text_encoder.encoder.layer.23.attention.self.query.bias', 'text_encoder.encoder.layer.23.attention.self.key.weight', 'text_encoder.encoder.layer.23.attention.self.key.bias', 'text_encoder.encoder.layer.23.attention.self.value.weight', 'text_encoder.encoder.layer.23.attention.self.value.bias', 'text_encoder.encoder.layer.23.attention.output.dense.weight', 'text_encoder.encoder.layer.23.attention.output.dense.bias', 'text_encoder.encoder.layer.23.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.23.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.23.crossattention.self.query.weight', 'text_encoder.encoder.layer.23.crossattention.self.query.bias', 'text_encoder.encoder.layer.23.crossattention.self.key.weight', 'text_encoder.encoder.layer.23.crossattention.self.key.bias', 'text_encoder.encoder.layer.23.crossattention.self.value.weight', 'text_encoder.encoder.layer.23.crossattention.self.value.bias', 'text_encoder.encoder.layer.23.crossattention.output.dense.weight', 'text_encoder.encoder.layer.23.crossattention.output.dense.bias', 'text_encoder.encoder.layer.23.crossattention.output.LayerNorm.weight', 'text_encoder.encoder.layer.23.crossattention.output.LayerNorm.bias', 'text_encoder.encoder.layer.23.intermediate.dense.weight', 'text_encoder.encoder.layer.23.intermediate.dense.bias', 'text_encoder.encoder.layer.23.output.dense.weight', 'text_encoder.encoder.layer.23.output.dense.bias', 'text_encoder.encoder.layer.23.output.LayerNorm.weight', 'text_encoder.encoder.layer.23.output.LayerNorm.bias'], unexpected_keys=['temp', 'itm_head.weight', 'itm_head.bias', 'text_encoder.bert.embeddings.position_ids', 'text_encoder.bert.embeddings.word_embeddings.weight', 'text_encoder.bert.embeddings.position_embeddings.weight', 'text_encoder.bert.embeddings.token_type_embeddings.weight', 'text_encoder.bert.embeddings.LayerNorm.weight', 'text_encoder.bert.embeddings.LayerNorm.bias', 'text_encoder.bert.encoder.layer.0.attention.self.query.weight', 'text_encoder.bert.encoder.layer.0.attention.self.query.bias', 'text_encoder.bert.encoder.layer.0.attention.self.key.weight', 'text_encoder.bert.encoder.layer.0.attention.self.key.bias', 'text_encoder.bert.encoder.layer.0.attention.self.value.weight', 'text_encoder.bert.encoder.layer.0.attention.self.value.bias', 'text_encoder.bert.encoder.layer.0.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.0.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.0.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.0.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.0.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.0.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.0.output.dense.weight', 'text_encoder.bert.encoder.layer.0.output.dense.bias', 'text_encoder.bert.encoder.layer.0.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.0.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.1.attention.self.query.weight', 'text_encoder.bert.encoder.layer.1.attention.self.query.bias', 'text_encoder.bert.encoder.layer.1.attention.self.key.weight', 'text_encoder.bert.encoder.layer.1.attention.self.key.bias', 'text_encoder.bert.encoder.layer.1.attention.self.value.weight', 'text_encoder.bert.encoder.layer.1.attention.self.value.bias', 'text_encoder.bert.encoder.layer.1.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.1.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.1.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.1.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.1.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.1.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.1.output.dense.weight', 'text_encoder.bert.encoder.layer.1.output.dense.bias', 'text_encoder.bert.encoder.layer.1.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.1.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.2.attention.self.query.weight', 'text_encoder.bert.encoder.layer.2.attention.self.query.bias', 'text_encoder.bert.encoder.layer.2.attention.self.key.weight', 'text_encoder.bert.encoder.layer.2.attention.self.key.bias', 'text_encoder.bert.encoder.layer.2.attention.self.value.weight', 'text_encoder.bert.encoder.layer.2.attention.self.value.bias', 'text_encoder.bert.encoder.layer.2.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.2.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.2.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.2.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.2.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.2.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.2.output.dense.weight', 'text_encoder.bert.encoder.layer.2.output.dense.bias', 'text_encoder.bert.encoder.layer.2.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.2.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.3.attention.self.query.weight', 'text_encoder.bert.encoder.layer.3.attention.self.query.bias', 'text_encoder.bert.encoder.layer.3.attention.self.key.weight', 'text_encoder.bert.encoder.layer.3.attention.self.key.bias', 'text_encoder.bert.encoder.layer.3.attention.self.value.weight', 'text_encoder.bert.encoder.layer.3.attention.self.value.bias', 'text_encoder.bert.encoder.layer.3.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.3.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.3.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.3.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.3.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.3.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.3.output.dense.weight', 'text_encoder.bert.encoder.layer.3.output.dense.bias', 'text_encoder.bert.encoder.layer.3.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.3.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.4.attention.self.query.weight', 'text_encoder.bert.encoder.layer.4.attention.self.query.bias', 'text_encoder.bert.encoder.layer.4.attention.self.key.weight', 'text_encoder.bert.encoder.layer.4.attention.self.key.bias', 'text_encoder.bert.encoder.layer.4.attention.self.value.weight', 'text_encoder.bert.encoder.layer.4.attention.self.value.bias', 'text_encoder.bert.encoder.layer.4.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.4.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.4.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.4.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.4.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.4.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.4.output.dense.weight', 'text_encoder.bert.encoder.layer.4.output.dense.bias', 'text_encoder.bert.encoder.layer.4.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.4.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.5.attention.self.query.weight', 'text_encoder.bert.encoder.layer.5.attention.self.query.bias', 'text_encoder.bert.encoder.layer.5.attention.self.key.weight', 'text_encoder.bert.encoder.layer.5.attention.self.key.bias', 'text_encoder.bert.encoder.layer.5.attention.self.value.weight', 'text_encoder.bert.encoder.layer.5.attention.self.value.bias', 'text_encoder.bert.encoder.layer.5.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.5.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.5.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.5.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.5.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.5.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.5.output.dense.weight', 'text_encoder.bert.encoder.layer.5.output.dense.bias', 'text_encoder.bert.encoder.layer.5.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.5.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.6.attention.self.query.weight', 'text_encoder.bert.encoder.layer.6.attention.self.query.bias', 'text_encoder.bert.encoder.layer.6.attention.self.key.weight', 'text_encoder.bert.encoder.layer.6.attention.self.key.bias', 'text_encoder.bert.encoder.layer.6.attention.self.value.weight', 'text_encoder.bert.encoder.layer.6.attention.self.value.bias', 'text_encoder.bert.encoder.layer.6.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.6.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.6.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.6.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.6.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.6.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.6.output.dense.weight', 'text_encoder.bert.encoder.layer.6.output.dense.bias', 'text_encoder.bert.encoder.layer.6.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.6.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.7.attention.self.query.weight', 'text_encoder.bert.encoder.layer.7.attention.self.query.bias', 'text_encoder.bert.encoder.layer.7.attention.self.key.weight', 'text_encoder.bert.encoder.layer.7.attention.self.key.bias', 'text_encoder.bert.encoder.layer.7.attention.self.value.weight', 'text_encoder.bert.encoder.layer.7.attention.self.value.bias', 'text_encoder.bert.encoder.layer.7.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.7.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.7.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.7.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.7.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.7.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.7.output.dense.weight', 'text_encoder.bert.encoder.layer.7.output.dense.bias', 'text_encoder.bert.encoder.layer.7.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.7.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.8.attention.self.query.weight', 'text_encoder.bert.encoder.layer.8.attention.self.query.bias', 'text_encoder.bert.encoder.layer.8.attention.self.key.weight', 'text_encoder.bert.encoder.layer.8.attention.self.key.bias', 'text_encoder.bert.encoder.layer.8.attention.self.value.weight', 'text_encoder.bert.encoder.layer.8.attention.self.value.bias', 'text_encoder.bert.encoder.layer.8.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.8.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.8.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.8.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.8.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.8.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.8.output.dense.weight', 'text_encoder.bert.encoder.layer.8.output.dense.bias', 'text_encoder.bert.encoder.layer.8.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.8.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.9.attention.self.query.weight', 'text_encoder.bert.encoder.layer.9.attention.self.query.bias', 'text_encoder.bert.encoder.layer.9.attention.self.key.weight', 'text_encoder.bert.encoder.layer.9.attention.self.key.bias', 'text_encoder.bert.encoder.layer.9.attention.self.value.weight', 'text_encoder.bert.encoder.layer.9.attention.self.value.bias', 'text_encoder.bert.encoder.layer.9.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.9.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.9.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.9.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.9.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.9.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.9.output.dense.weight', 'text_encoder.bert.encoder.layer.9.output.dense.bias', 'text_encoder.bert.encoder.layer.9.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.9.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.10.attention.self.query.weight', 'text_encoder.bert.encoder.layer.10.attention.self.query.bias', 'text_encoder.bert.encoder.layer.10.attention.self.key.weight', 'text_encoder.bert.encoder.layer.10.attention.self.key.bias', 'text_encoder.bert.encoder.layer.10.attention.self.value.weight', 'text_encoder.bert.encoder.layer.10.attention.self.value.bias', 'text_encoder.bert.encoder.layer.10.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.10.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.10.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.10.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.10.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.10.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.10.output.dense.weight', 'text_encoder.bert.encoder.layer.10.output.dense.bias', 'text_encoder.bert.encoder.layer.10.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.10.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.11.attention.self.query.weight', 'text_encoder.bert.encoder.layer.11.attention.self.query.bias', 'text_encoder.bert.encoder.layer.11.attention.self.key.weight', 'text_encoder.bert.encoder.layer.11.attention.self.key.bias', 'text_encoder.bert.encoder.layer.11.attention.self.value.weight', 'text_encoder.bert.encoder.layer.11.attention.self.value.bias', 'text_encoder.bert.encoder.layer.11.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.11.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.11.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.11.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.11.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.11.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.11.output.dense.weight', 'text_encoder.bert.encoder.layer.11.output.dense.bias', 'text_encoder.bert.encoder.layer.11.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.11.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.12.attention.self.query.weight', 'text_encoder.bert.encoder.layer.12.attention.self.query.bias', 'text_encoder.bert.encoder.layer.12.attention.self.key.weight', 'text_encoder.bert.encoder.layer.12.attention.self.key.bias', 'text_encoder.bert.encoder.layer.12.attention.self.value.weight', 'text_encoder.bert.encoder.layer.12.attention.self.value.bias', 'text_encoder.bert.encoder.layer.12.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.12.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.12.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.12.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.12.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.12.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.12.output.dense.weight', 'text_encoder.bert.encoder.layer.12.output.dense.bias', 'text_encoder.bert.encoder.layer.12.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.12.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.13.attention.self.query.weight', 'text_encoder.bert.encoder.layer.13.attention.self.query.bias', 'text_encoder.bert.encoder.layer.13.attention.self.key.weight', 'text_encoder.bert.encoder.layer.13.attention.self.key.bias', 'text_encoder.bert.encoder.layer.13.attention.self.value.weight', 'text_encoder.bert.encoder.layer.13.attention.self.value.bias', 'text_encoder.bert.encoder.layer.13.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.13.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.13.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.13.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.13.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.13.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.13.output.dense.weight', 'text_encoder.bert.encoder.layer.13.output.dense.bias', 'text_encoder.bert.encoder.layer.13.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.13.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.14.attention.self.query.weight', 'text_encoder.bert.encoder.layer.14.attention.self.query.bias', 'text_encoder.bert.encoder.layer.14.attention.self.key.weight', 'text_encoder.bert.encoder.layer.14.attention.self.key.bias', 'text_encoder.bert.encoder.layer.14.attention.self.value.weight', 'text_encoder.bert.encoder.layer.14.attention.self.value.bias', 'text_encoder.bert.encoder.layer.14.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.14.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.14.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.14.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.14.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.14.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.14.output.dense.weight', 'text_encoder.bert.encoder.layer.14.output.dense.bias', 'text_encoder.bert.encoder.layer.14.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.14.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.15.attention.self.query.weight', 'text_encoder.bert.encoder.layer.15.attention.self.query.bias', 'text_encoder.bert.encoder.layer.15.attention.self.key.weight', 'text_encoder.bert.encoder.layer.15.attention.self.key.bias', 'text_encoder.bert.encoder.layer.15.attention.self.value.weight', 'text_encoder.bert.encoder.layer.15.attention.self.value.bias', 'text_encoder.bert.encoder.layer.15.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.15.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.15.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.15.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.15.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.15.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.15.output.dense.weight', 'text_encoder.bert.encoder.layer.15.output.dense.bias', 'text_encoder.bert.encoder.layer.15.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.15.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.16.attention.self.query.weight', 'text_encoder.bert.encoder.layer.16.attention.self.query.bias', 'text_encoder.bert.encoder.layer.16.attention.self.key.weight', 'text_encoder.bert.encoder.layer.16.attention.self.key.bias', 'text_encoder.bert.encoder.layer.16.attention.self.value.weight', 'text_encoder.bert.encoder.layer.16.attention.self.value.bias', 'text_encoder.bert.encoder.layer.16.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.16.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.16.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.16.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.16.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.16.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.16.output.dense.weight', 'text_encoder.bert.encoder.layer.16.output.dense.bias', 'text_encoder.bert.encoder.layer.16.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.16.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.17.attention.self.query.weight', 'text_encoder.bert.encoder.layer.17.attention.self.query.bias', 'text_encoder.bert.encoder.layer.17.attention.self.key.weight', 'text_encoder.bert.encoder.layer.17.attention.self.key.bias', 'text_encoder.bert.encoder.layer.17.attention.self.value.weight', 'text_encoder.bert.encoder.layer.17.attention.self.value.bias', 'text_encoder.bert.encoder.layer.17.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.17.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.17.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.17.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.17.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.17.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.17.output.dense.weight', 'text_encoder.bert.encoder.layer.17.output.dense.bias', 'text_encoder.bert.encoder.layer.17.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.17.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.18.attention.self.query.weight', 'text_encoder.bert.encoder.layer.18.attention.self.query.bias', 'text_encoder.bert.encoder.layer.18.attention.self.key.weight', 'text_encoder.bert.encoder.layer.18.attention.self.key.bias', 'text_encoder.bert.encoder.layer.18.attention.self.value.weight', 'text_encoder.bert.encoder.layer.18.attention.self.value.bias', 'text_encoder.bert.encoder.layer.18.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.18.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.18.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.18.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.18.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.18.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.18.output.dense.weight', 'text_encoder.bert.encoder.layer.18.output.dense.bias', 'text_encoder.bert.encoder.layer.18.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.18.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.19.attention.self.query.weight', 'text_encoder.bert.encoder.layer.19.attention.self.query.bias', 'text_encoder.bert.encoder.layer.19.attention.self.key.weight', 'text_encoder.bert.encoder.layer.19.attention.self.key.bias', 'text_encoder.bert.encoder.layer.19.attention.self.value.weight', 'text_encoder.bert.encoder.layer.19.attention.self.value.bias', 'text_encoder.bert.encoder.layer.19.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.19.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.19.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.19.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.19.crossattention.self.query.weight', 'text_encoder.bert.encoder.layer.19.crossattention.self.query.bias', 'text_encoder.bert.encoder.layer.19.crossattention.self.key.weight', 'text_encoder.bert.encoder.layer.19.crossattention.self.key.bias', 'text_encoder.bert.encoder.layer.19.crossattention.self.value.weight', 'text_encoder.bert.encoder.layer.19.crossattention.self.value.bias', 'text_encoder.bert.encoder.layer.19.crossattention.output.dense.weight', 'text_encoder.bert.encoder.layer.19.crossattention.output.dense.bias', 'text_encoder.bert.encoder.layer.19.crossattention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.19.crossattention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.19.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.19.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.19.output.dense.weight', 'text_encoder.bert.encoder.layer.19.output.dense.bias', 'text_encoder.bert.encoder.layer.19.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.19.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.20.attention.self.query.weight', 'text_encoder.bert.encoder.layer.20.attention.self.query.bias', 'text_encoder.bert.encoder.layer.20.attention.self.key.weight', 'text_encoder.bert.encoder.layer.20.attention.self.key.bias', 'text_encoder.bert.encoder.layer.20.attention.self.value.weight', 'text_encoder.bert.encoder.layer.20.attention.self.value.bias', 'text_encoder.bert.encoder.layer.20.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.20.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.20.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.20.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.20.crossattention.self.query.weight', 'text_encoder.bert.encoder.layer.20.crossattention.self.query.bias', 'text_encoder.bert.encoder.layer.20.crossattention.self.key.weight', 'text_encoder.bert.encoder.layer.20.crossattention.self.key.bias', 'text_encoder.bert.encoder.layer.20.crossattention.self.value.weight', 'text_encoder.bert.encoder.layer.20.crossattention.self.value.bias', 'text_encoder.bert.encoder.layer.20.crossattention.output.dense.weight', 'text_encoder.bert.encoder.layer.20.crossattention.output.dense.bias', 'text_encoder.bert.encoder.layer.20.crossattention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.20.crossattention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.20.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.20.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.20.output.dense.weight', 'text_encoder.bert.encoder.layer.20.output.dense.bias', 'text_encoder.bert.encoder.layer.20.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.20.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.21.attention.self.query.weight', 'text_encoder.bert.encoder.layer.21.attention.self.query.bias', 'text_encoder.bert.encoder.layer.21.attention.self.key.weight', 'text_encoder.bert.encoder.layer.21.attention.self.key.bias', 'text_encoder.bert.encoder.layer.21.attention.self.value.weight', 'text_encoder.bert.encoder.layer.21.attention.self.value.bias', 'text_encoder.bert.encoder.layer.21.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.21.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.21.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.21.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.21.crossattention.self.query.weight', 'text_encoder.bert.encoder.layer.21.crossattention.self.query.bias', 'text_encoder.bert.encoder.layer.21.crossattention.self.key.weight', 'text_encoder.bert.encoder.layer.21.crossattention.self.key.bias', 'text_encoder.bert.encoder.layer.21.crossattention.self.value.weight', 'text_encoder.bert.encoder.layer.21.crossattention.self.value.bias', 'text_encoder.bert.encoder.layer.21.crossattention.output.dense.weight', 'text_encoder.bert.encoder.layer.21.crossattention.output.dense.bias', 'text_encoder.bert.encoder.layer.21.crossattention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.21.crossattention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.21.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.21.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.21.output.dense.weight', 'text_encoder.bert.encoder.layer.21.output.dense.bias', 'text_encoder.bert.encoder.layer.21.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.21.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.22.attention.self.query.weight', 'text_encoder.bert.encoder.layer.22.attention.self.query.bias', 'text_encoder.bert.encoder.layer.22.attention.self.key.weight', 'text_encoder.bert.encoder.layer.22.attention.self.key.bias', 'text_encoder.bert.encoder.layer.22.attention.self.value.weight', 'text_encoder.bert.encoder.layer.22.attention.self.value.bias', 'text_encoder.bert.encoder.layer.22.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.22.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.22.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.22.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.22.crossattention.self.query.weight', 'text_encoder.bert.encoder.layer.22.crossattention.self.query.bias', 'text_encoder.bert.encoder.layer.22.crossattention.self.key.weight', 'text_encoder.bert.encoder.layer.22.crossattention.self.key.bias', 'text_encoder.bert.encoder.layer.22.crossattention.self.value.weight', 'text_encoder.bert.encoder.layer.22.crossattention.self.value.bias', 'text_encoder.bert.encoder.layer.22.crossattention.output.dense.weight', 'text_encoder.bert.encoder.layer.22.crossattention.output.dense.bias', 'text_encoder.bert.encoder.layer.22.crossattention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.22.crossattention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.22.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.22.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.22.output.dense.weight', 'text_encoder.bert.encoder.layer.22.output.dense.bias', 'text_encoder.bert.encoder.layer.22.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.22.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.23.attention.self.query.weight', 'text_encoder.bert.encoder.layer.23.attention.self.query.bias', 'text_encoder.bert.encoder.layer.23.attention.self.key.weight', 'text_encoder.bert.encoder.layer.23.attention.self.key.bias', 'text_encoder.bert.encoder.layer.23.attention.self.value.weight', 'text_encoder.bert.encoder.layer.23.attention.self.value.bias', 'text_encoder.bert.encoder.layer.23.attention.output.dense.weight', 'text_encoder.bert.encoder.layer.23.attention.output.dense.bias', 'text_encoder.bert.encoder.layer.23.attention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.23.attention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.23.crossattention.self.query.weight', 'text_encoder.bert.encoder.layer.23.crossattention.self.query.bias', 'text_encoder.bert.encoder.layer.23.crossattention.self.key.weight', 'text_encoder.bert.encoder.layer.23.crossattention.self.key.bias', 'text_encoder.bert.encoder.layer.23.crossattention.self.value.weight', 'text_encoder.bert.encoder.layer.23.crossattention.self.value.bias', 'text_encoder.bert.encoder.layer.23.crossattention.output.dense.weight', 'text_encoder.bert.encoder.layer.23.crossattention.output.dense.bias', 'text_encoder.bert.encoder.layer.23.crossattention.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.23.crossattention.output.LayerNorm.bias', 'text_encoder.bert.encoder.layer.23.intermediate.dense.weight', 'text_encoder.bert.encoder.layer.23.intermediate.dense.bias', 'text_encoder.bert.encoder.layer.23.output.dense.weight', 'text_encoder.bert.encoder.layer.23.output.dense.bias', 'text_encoder.bert.encoder.layer.23.output.LayerNorm.weight', 'text_encoder.bert.encoder.layer.23.output.LayerNorm.bias', 'text_encoder.cls.predictions.bias', 'text_encoder.cls.predictions.transform.dense.weight', 'text_encoder.cls.predictions.transform.dense.bias', 'text_encoder.cls.predictions.transform.LayerNorm.weight', 'text_encoder.cls.predictions.transform.LayerNorm.bias', 'text_encoder.cls.predictions.decoder.weight', 'text_encoder.cls.predictions.decoder.bias'])\n"
+ ]
+ }
+ ],
+ "source": [
+ "if 'intern_model' in locals():\n",
+ " del intern_model\n",
+ " del tokenizer\n",
+ "config = Config.from_file('/home/toolkit/eai_urlb/InternVideo/InternVideo2/multi_modality/demo/internvideo2_stage2_config.py')\n",
+ "config = eval_dict_leaf(config)\n",
+ "config.model.vision_encoder.num_frames = 8\n",
+ "config.num_frames = 8\n",
+ "config.num_frames_test = 8\n",
+ "config.model.text_encoder.pretrained = '/home/toolkit/.cache/huggingface/hub/models--bert-large-uncased/snapshots/6da4b6a26a1877e173fca3225479512db81a5e5b/'\n",
+ "config.model.text_encoder.config = '/home/toolkit/eai_urlb/InternVideo/InternVideo2/multi_modality/' + config.model.text_encoder.config\n",
+ "model_pth = '/home/toolkit/eai_urlb/InternVideo/InternVideo2/download_models/InternVideo2-stage2_1b-224p-f4.pt'\n",
+ "config.pretrained_path = model_pth\n",
+ "config['model']['vision_encoder']['pretrained'] = model_pth\n",
+ "intern_model, tokenizer = setup_internvideo2(config) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 64,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Video tensor([0.0030], device='cuda:0')\n",
+ "Text tensor([-0.0008, -0.0001, -0.0013, -0.0014, 0.0005, -0.0004, -0.0004, -0.0006,\n",
+ " 0.0001, -0.0003, 0.0003, 0.0012, -0.0004, 0.0007, -0.0014, -0.0017,\n",
+ " -0.0007, -0.0018, -0.0006, -0.0024], device='cuda:0')\n",
+ "text: Somebody walking ~ prob: 0.6945\n",
+ "text: Playing with hat ~ prob: 0.1198\n",
+ "text: A person stands on the snowy floor, pushing a sled loaded with blankets, preparing for a fun-filled ride. ~ prob: 0.0297\n",
+ "text: A person with a hat ~ prob: 0.0245\n",
+ "text: A person dressed in a blue jacket shovels the snow-covered pavement outside their house. ~ prob: 0.0226\n",
+ "text: A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner. ~ prob: 0.0222\n",
+ "text: A group of friends playing bowling. ~ prob: 0.0212\n",
+ "text: A person lying in bed ~ prob: 0.0208\n",
+ "text: A person bundled up in a blanket walks through the snowy landscape, enjoying the serene winter scenery. ~ prob: 0.0186\n",
+ "text: A person in a blue jacket walks their pet on a leash, enjoying a peaceful winter walk among the trees. ~ prob: 0.0102\n",
+ "text: A person playing with a kid in the street ~ prob: 0.0045\n",
+ "text: A man in a gray coat walks through the snowy landscape, pulling a sleigh loaded with toys. ~ prob: 0.0025\n",
+ "text: A playful dog slides down a snowy hill, wagging its tail with delight. ~ prob: 0.0024\n",
+ "text: A man in a gray hat and coat walks through the snowy yard, carefully navigating around the trees. ~ prob: 0.0015\n"
+ ]
+ },
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
+ "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
+ "\u001b[1;31mClick here for more info. \n",
+ "\u001b[1;31mView Jupyter log for further details."
+ ]
+ }
+ ],
+ "source": [
+ "intern_model.eval()\n",
+ "texts, probs = retrieve_text(frames, text_candidates, model=intern_model, topk=14, config=config)\n",
+ "\n",
+ "# Video tensor([0.0023], device='cuda:0')\n",
+ "# Text tensor([-0.0008, -0.0001, -0.0013, -0.0014, 0.0005, -0.0004, -0.0004, -0.0006,\n",
+ "# 0.0001, -0.0003, 0.0003, 0.0012, -0.0004, 0.0007, -0.0014, -0.0017,\n",
+ "# -0.0007, -0.0018, -0.0006], device='cuda:0')\n",
+ "# text: A person bundled up in a blanket walks through the snowy landscape, enjoying the serene winter scenery. ~ prob: 0.4592\n",
+ "# text: A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner. ~ prob: 0.1335\n",
+ "# text: A japanese girl eating noodles ~ prob: 0.1089\n",
+ "\n",
+ "for t, p in zip(texts, probs):\n",
+ " print(f'text: {t} ~ prob: {p:.4f}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Parameter containing:\n",
+ "tensor([[[[[-4.7913e-03, -2.1515e-03, -2.0447e-03, ..., 3.2997e-04,\n",
+ " -3.0212e-03, -7.9727e-04],\n",
+ " [ 9.7656e-04, 2.4567e-03, 9.8419e-04, ..., -1.8845e-03,\n",
+ " 2.3193e-03, 3.6621e-03],\n",
+ " [-3.5095e-04, 2.5940e-03, -2.7618e-03, ..., -3.7956e-04,\n",
+ " -3.1948e-05, 1.7166e-03],\n",
+ " ...,\n",
+ " [-3.8605e-03, -1.1215e-03, -9.0790e-04, ..., 6.5994e-04,\n",
+ " 1.0071e-03, 1.2894e-03],\n",
+ " [-2.2278e-03, 1.2589e-03, -1.0204e-04, ..., 3.7079e-03,\n",
+ " 1.5354e-04, -8.3160e-04],\n",
+ " [ 6.3324e-04, 1.4114e-03, 9.5367e-04, ..., -3.4485e-03,\n",
+ " -1.8234e-03, -4.0283e-03]]],\n",
+ "\n",
+ "\n",
+ " [[[ 7.0190e-04, -1.3657e-03, -6.5994e-04, ..., 1.4725e-03,\n",
+ " -8.5831e-04, 1.6212e-04],\n",
+ " [ 1.7262e-04, 8.0872e-04, 4.1485e-05, ..., -6.4850e-04,\n",
+ " 5.5695e-04, 1.7242e-03],\n",
+ " [ 1.3504e-03, 3.2959e-03, -1.3275e-03, ..., 2.2736e-03,\n",
+ " 4.2725e-04, 1.9150e-03],\n",
+ " ...,\n",
+ " [-2.3041e-03, -6.4850e-04, -2.8839e-03, ..., 2.9755e-04,\n",
+ " -3.0518e-04, 1.2817e-03],\n",
+ " [ 9.3079e-04, -1.2512e-03, -1.5335e-03, ..., 1.9455e-03,\n",
+ " -3.4142e-04, -1.2054e-03],\n",
+ " [ 9.1553e-03, 3.6774e-03, 2.2125e-03, ..., -5.3883e-05,\n",
+ " 3.2234e-04, 2.3499e-03]]],\n",
+ "\n",
+ "\n",
+ " [[[ 2.0752e-03, 7.4768e-04, 2.6512e-04, ..., 2.3193e-03,\n",
+ " -3.3379e-04, -9.2983e-05],\n",
+ " [ 1.4725e-03, 1.0986e-03, -8.8692e-05, ..., -2.8229e-04,\n",
+ " 7.2098e-04, -2.2888e-03],\n",
+ " [ 1.3809e-03, 1.5945e-03, 6.5231e-04, ..., 3.3112e-03,\n",
+ " 2.1515e-03, -1.4114e-03],\n",
+ " ...,\n",
+ " [-1.2512e-03, 1.0605e-03, 5.6744e-05, ..., -4.7112e-04,\n",
+ " -3.4714e-04, -1.6861e-03],\n",
+ " [-5.4550e-04, 1.1978e-03, 1.9531e-03, ..., 7.6675e-04,\n",
+ " -1.9150e-03, -1.6937e-03],\n",
+ " [-4.5776e-03, -3.0212e-03, -1.4648e-03, ..., -1.0757e-03,\n",
+ " 1.0061e-04, 2.9449e-03]]]],\n",
+ "\n",
+ "\n",
+ "\n",
+ " [[[[-9.9487e-03, -5.9814e-03, 3.9673e-03, ..., 7.8125e-03,\n",
+ " 4.5776e-03, -4.7607e-03],\n",
+ " [-7.5989e-03, 2.5940e-04, -6.0730e-03, ..., -1.4725e-03,\n",
+ " -3.8300e-03, -2.4567e-03],\n",
+ " [ 6.9427e-04, 3.1090e-04, -2.1515e-03, ..., -1.2779e-04,\n",
+ " -6.0120e-03, -1.4191e-03],\n",
+ " ...,\n",
+ " [ 1.1597e-02, -8.3447e-05, -1.3428e-03, ..., -4.4556e-03,\n",
+ " -4.4823e-04, -1.6861e-03],\n",
+ " [ 4.3640e-03, -2.0447e-03, -1.3123e-03, ..., -4.4556e-03,\n",
+ " -4.0283e-03, -4.6387e-03],\n",
+ " [ 9.2163e-03, -5.1880e-03, 1.3351e-03, ..., 4.7112e-04,\n",
+ " 2.6550e-03, 4.9744e-03]]],\n",
+ "\n",
+ "\n",
+ " [[[ 4.3030e-03, -6.3171e-03, -1.2436e-03, ..., 2.1210e-03,\n",
+ " -9.4250e-07, -1.0559e-02],\n",
+ " [-1.2436e-03, -4.1504e-03, -9.7046e-03, ..., -2.8687e-03,\n",
+ " -6.9885e-03, -9.7046e-03],\n",
+ " [ 6.2561e-04, -5.7678e-03, -3.9978e-03, ..., -1.9989e-03,\n",
+ " -4.5166e-03, -5.5542e-03],\n",
+ " ...,\n",
+ " [ 7.6904e-03, -2.9144e-03, -2.0905e-03, ..., -3.9368e-03,\n",
+ " 2.1515e-03, -3.9062e-03],\n",
+ " [-6.9427e-04, -2.9907e-03, -1.2512e-03, ..., 1.6785e-03,\n",
+ " 5.8594e-03, -2.0294e-03],\n",
+ " [ 7.8678e-05, -6.0730e-03, 1.0834e-03, ..., 2.9564e-04,\n",
+ " 3.1738e-03, -8.4839e-03]]],\n",
+ "\n",
+ "\n",
+ " [[[ 4.2419e-03, -7.5073e-03, -2.8381e-03, ..., -7.7515e-03,\n",
+ " -6.6223e-03, 2.1667e-03],\n",
+ " [ 8.1787e-03, 5.8899e-03, 1.0376e-03, ..., -1.8463e-03,\n",
+ " -3.1281e-03, 5.8899e-03],\n",
+ " [ 1.7776e-03, -4.2915e-04, 8.6975e-04, ..., -4.2915e-05,\n",
+ " -3.2043e-03, 9.5825e-03],\n",
+ " ...,\n",
+ " [ 2.7847e-04, -1.9989e-03, -5.2490e-03, ..., -5.7068e-03,\n",
+ " -4.5776e-04, -3.5095e-03],\n",
+ " [-5.3406e-03, -6.9427e-04, -4.9133e-03, ..., -1.0910e-03,\n",
+ " -6.4468e-04, -5.1880e-03],\n",
+ " [-7.3853e-03, -2.1210e-03, 4.7302e-03, ..., 2.0752e-03,\n",
+ " -2.0447e-03, -1.2329e-02]]]],\n",
+ "\n",
+ "\n",
+ "\n",
+ " [[[[ 4.4861e-03, -2.8992e-03, -4.7302e-03, ..., -5.3406e-03,\n",
+ " -4.6692e-03, -4.6387e-03],\n",
+ " [ 1.8921e-03, -5.6458e-03, -3.7079e-03, ..., -2.5482e-03,\n",
+ " -4.8218e-03, 2.1515e-03],\n",
+ " [ 4.5471e-03, 2.9755e-04, -3.7842e-03, ..., 3.6774e-03,\n",
+ " -2.6550e-03, -1.8845e-03],\n",
+ " ...,\n",
+ " [ 7.2098e-04, 3.1281e-03, 2.0027e-04, ..., 2.7924e-03,\n",
+ " 1.0986e-03, 3.4943e-03],\n",
+ " [ 1.4496e-03, -2.8229e-04, 7.0801e-03, ..., 1.0071e-03,\n",
+ " -3.9978e-03, 3.7689e-03],\n",
+ " [ 9.9945e-04, 7.3624e-04, 9.7046e-03, ..., 3.9673e-03,\n",
+ " 6.7139e-03, 1.1414e-02]]],\n",
+ "\n",
+ "\n",
+ " [[[-3.5400e-03, -6.9809e-04, 3.9673e-03, ..., 7.1716e-04,\n",
+ " 2.3651e-03, 1.6098e-03],\n",
+ " [-1.7319e-03, 8.0109e-04, 2.7466e-03, ..., -1.7262e-04,\n",
+ " -1.6937e-03, 6.1340e-03],\n",
+ " [-3.9978e-03, 2.0599e-03, -2.4414e-03, ..., 2.2888e-03,\n",
+ " 2.2736e-03, 4.1809e-03],\n",
+ " ...,\n",
+ " [-6.6223e-03, -1.0529e-03, -3.0823e-03, ..., 1.2894e-03,\n",
+ " 1.7624e-03, -6.0425e-03],\n",
+ " [ 7.5531e-04, -2.0599e-03, 2.0142e-03, ..., 3.3569e-03,\n",
+ " 1.8215e-04, -7.1411e-03],\n",
+ " [ 5.5237e-03, 2.3842e-04, 7.2937e-03, ..., -4.1809e-03,\n",
+ " -4.4861e-03, -1.7700e-02]]],\n",
+ "\n",
+ "\n",
+ " [[[-7.1716e-04, -1.5488e-03, -2.5635e-03, ..., 1.8692e-03,\n",
+ " 5.4016e-03, 3.8300e-03],\n",
+ " [ 1.9531e-03, -9.1934e-04, 2.0981e-05, ..., -5.3024e-04,\n",
+ " -1.9989e-03, 1.1778e-04],\n",
+ " [ 2.5635e-03, 4.4556e-03, -3.9978e-03, ..., 1.7548e-03,\n",
+ " 1.4114e-04, -1.2817e-03],\n",
+ " ...,\n",
+ " [-1.9226e-03, -4.1389e-04, -4.2114e-03, ..., 8.2016e-04,\n",
+ " 5.0964e-03, 2.5330e-03],\n",
+ " [ 8.4229e-03, 1.8539e-03, 1.4038e-03, ..., 2.4109e-03,\n",
+ " 1.8616e-03, -1.0300e-03],\n",
+ " [-8.6784e-05, -7.3547e-03, -1.5182e-03, ..., 1.5335e-03,\n",
+ " 2.2736e-03, -8.4839e-03]]]],\n",
+ "\n",
+ "\n",
+ "\n",
+ " ...,\n",
+ "\n",
+ "\n",
+ "\n",
+ " [[[[-1.2207e-02, 4.0771e-02, -2.9419e-02, ..., 6.5918e-02,\n",
+ " -2.6978e-02, 3.0640e-02],\n",
+ " [ 7.8125e-02, 4.8828e-02, -5.3955e-02, ..., 4.3945e-02,\n",
+ " -3.3447e-02, 3.4424e-02],\n",
+ " [-1.3489e-02, 7.2021e-03, -5.0293e-02, ..., -1.9043e-02,\n",
+ " -4.4189e-02, -3.7354e-02],\n",
+ " ...,\n",
+ " [-3.1250e-02, -1.1047e-02, 5.7617e-02, ..., -1.9287e-02,\n",
+ " 6.2500e-02, -6.5308e-03],\n",
+ " [-3.1738e-02, 5.6152e-03, -1.0986e-02, ..., -5.6763e-03,\n",
+ " 2.3804e-02, 6.2500e-02],\n",
+ " [-3.5400e-02, -4.4861e-03, 3.7109e-02, ..., 3.3691e-02,\n",
+ " -7.1777e-02, 9.3750e-02]]],\n",
+ "\n",
+ "\n",
+ " [[[-1.5332e-01, 3.3203e-02, -7.8125e-02, ..., 6.9824e-02,\n",
+ " -1.1902e-02, 6.1340e-03],\n",
+ " [ 5.1117e-04, 2.8809e-02, -7.9102e-02, ..., 5.4932e-02,\n",
+ " -8.6670e-03, 2.2827e-02],\n",
+ " [-7.3853e-03, 6.6528e-03, -4.9561e-02, ..., 5.6076e-04,\n",
+ " -3.0029e-02, -1.6724e-02],\n",
+ " ...,\n",
+ " [ 6.7871e-02, 5.0049e-02, 7.0801e-02, ..., -6.1646e-03,\n",
+ " 7.9102e-02, -4.5410e-02],\n",
+ " [ 3.5706e-03, 6.5308e-03, -9.0942e-03, ..., -2.1210e-03,\n",
+ " 4.9561e-02, 4.6143e-02],\n",
+ " [ 5.6152e-02, 1.8311e-02, 4.5898e-02, ..., 2.8564e-02,\n",
+ " -9.3262e-02, -4.9316e-02]]],\n",
+ "\n",
+ "\n",
+ " [[[-5.1270e-02, 9.4727e-02, 2.1973e-02, ..., 4.9072e-02,\n",
+ " -6.0547e-02, 6.9580e-03],\n",
+ " [-1.7578e-02, 1.5869e-02, -5.0293e-02, ..., 9.4604e-03,\n",
+ " -6.9336e-02, 7.2098e-04],\n",
+ " [-6.4453e-02, -1.0620e-02, -7.0801e-02, ..., 3.5156e-02,\n",
+ " -4.1016e-02, -3.4912e-02],\n",
+ " ...,\n",
+ " [-5.2185e-03, 3.0640e-02, 7.5195e-02, ..., 4.1260e-02,\n",
+ " 8.6426e-02, -2.6367e-02],\n",
+ " [ 2.8076e-03, 5.4626e-03, 2.0874e-02, ..., 1.0452e-03,\n",
+ " -1.2207e-02, 2.1973e-03],\n",
+ " [-7.7148e-02, -5.0781e-02, 3.3936e-02, ..., 1.7334e-02,\n",
+ " -1.2988e-01, -5.0781e-02]]]],\n",
+ "\n",
+ "\n",
+ "\n",
+ " [[[[-5.6885e-02, -1.1035e-01, -2.1118e-02, ..., 5.2002e-02,\n",
+ " -7.9346e-03, -5.3711e-02],\n",
+ " [ 2.9053e-02, 1.7944e-02, -1.0315e-02, ..., 3.6621e-02,\n",
+ " 3.3936e-02, -1.4587e-02],\n",
+ " [ 1.5259e-04, -2.6245e-02, -9.5703e-02, ..., 9.5825e-03,\n",
+ " 6.3965e-02, 2.9907e-02],\n",
+ " ...,\n",
+ " [ 3.8818e-02, 2.9907e-02, 2.7710e-02, ..., -1.0938e-01,\n",
+ " -4.6387e-02, 5.1575e-03],\n",
+ " [-1.7212e-02, 1.6235e-02, -6.0547e-02, ..., -2.7710e-02,\n",
+ " -5.9204e-03, 2.5024e-02],\n",
+ " [-4.1504e-02, 1.3794e-02, -8.2520e-02, ..., 4.7852e-02,\n",
+ " 8.2520e-02, -9.4238e-02]]],\n",
+ "\n",
+ "\n",
+ " [[[ 3.7842e-02, -9.7656e-02, -4.6143e-02, ..., 2.5635e-02,\n",
+ " -1.6479e-02, 1.9531e-03],\n",
+ " [ 7.7148e-02, 4.0771e-02, -3.4027e-03, ..., 1.4648e-02,\n",
+ " -2.0142e-02, -8.6670e-03],\n",
+ " [ 6.1035e-02, 6.5308e-03, -2.5635e-02, ..., 5.0049e-02,\n",
+ " 4.8828e-03, -2.4780e-02],\n",
+ " ...,\n",
+ " [ 9.9487e-03, -6.6528e-03, -1.1353e-02, ..., -1.1572e-01,\n",
+ " 1.9043e-02, 5.9082e-02],\n",
+ " [ 3.8086e-02, 4.2725e-02, -4.7363e-02, ..., 7.5989e-03,\n",
+ " 6.4087e-03, 7.3853e-03],\n",
+ " [-7.2632e-03, 9.0820e-02, -5.2979e-02, ..., 4.4678e-02,\n",
+ " 6.6895e-02, -1.0303e-01]]],\n",
+ "\n",
+ "\n",
+ " [[[ 1.4526e-02, -1.1475e-01, -3.8818e-02, ..., 8.1055e-02,\n",
+ " 3.8818e-02, 2.1118e-02],\n",
+ " [ 1.2354e-01, 5.1025e-02, 2.1973e-03, ..., 1.8677e-02,\n",
+ " 1.5991e-02, -3.3203e-02],\n",
+ " [ 7.7148e-02, 6.9336e-02, -3.7842e-02, ..., -3.0151e-02,\n",
+ " 3.9062e-02, -2.2217e-02],\n",
+ " ...,\n",
+ " [ 4.3701e-02, 8.2520e-02, 1.0156e-01, ..., -8.5449e-02,\n",
+ " -3.0060e-03, 1.0547e-01],\n",
+ " [ 3.0640e-02, 8.0078e-02, 2.1606e-02, ..., -2.0264e-02,\n",
+ " 1.9287e-02, 7.8613e-02],\n",
+ " [ 4.4678e-02, 9.7168e-02, -4.9561e-02, ..., 3.6377e-02,\n",
+ " 1.3477e-01, -2.7771e-03]]]],\n",
+ "\n",
+ "\n",
+ "\n",
+ " [[[[-2.0996e-01, -4.6387e-02, -5.6458e-03, ..., 1.7334e-02,\n",
+ " 4.6082e-03, -1.4038e-02],\n",
+ " [-2.8931e-02, 2.0020e-02, -8.7891e-03, ..., -8.2520e-02,\n",
+ " -5.2002e-02, -1.5869e-03],\n",
+ " [ 7.1289e-02, 4.3335e-03, 1.1047e-02, ..., -7.5684e-03,\n",
+ " -1.7456e-02, 1.5137e-02],\n",
+ " ...,\n",
+ " [ 5.0781e-02, 4.3213e-02, -8.1055e-02, ..., -3.9062e-02,\n",
+ " -1.0693e-01, -2.9175e-02],\n",
+ " [ 6.2500e-02, -3.9062e-03, -7.8735e-03, ..., -3.6377e-02,\n",
+ " -4.8340e-02, 4.8340e-02],\n",
+ " [-8.4473e-02, -3.3447e-02, -6.7383e-02, ..., 4.0527e-02,\n",
+ " -6.9885e-03, 1.0547e-01]]],\n",
+ "\n",
+ "\n",
+ " [[[-2.2266e-01, -4.6143e-02, -2.9541e-02, ..., 6.8054e-03,\n",
+ " 3.4180e-02, -2.3682e-02],\n",
+ " [-2.4170e-02, 5.7129e-02, 4.0771e-02, ..., -4.5898e-02,\n",
+ " -1.1536e-02, 9.0942e-03],\n",
+ " [ 7.8613e-02, 4.6631e-02, 9.9182e-04, ..., 4.2236e-02,\n",
+ " 2.5879e-02, 4.2236e-02],\n",
+ " ...,\n",
+ " [ 5.7129e-02, 3.1250e-02, -7.7148e-02, ..., -2.1362e-02,\n",
+ " -2.8809e-02, -6.3171e-03],\n",
+ " [ 1.0352e-01, 6.2988e-02, -1.6602e-02, ..., -3.4668e-02,\n",
+ " -5.9128e-04, 6.2988e-02],\n",
+ " [-1.0193e-02, 5.0537e-02, -1.6479e-02, ..., 3.1738e-02,\n",
+ " -4.3945e-02, -2.5146e-02]]],\n",
+ "\n",
+ "\n",
+ " [[[ 5.5176e-02, -2.3804e-02, -2.0020e-02, ..., -5.1270e-03,\n",
+ " -5.8899e-03, -2.4414e-02],\n",
+ " [ 5.4199e-02, 4.0894e-03, 8.1787e-03, ..., -1.5320e-02,\n",
+ " 1.3885e-03, 3.7842e-02],\n",
+ " [ 1.0938e-01, 5.5847e-03, -1.3184e-02, ..., 4.1260e-02,\n",
+ " 2.1484e-02, 5.2734e-02],\n",
+ " ...,\n",
+ " [-4.3297e-04, -1.1169e-02, -8.5449e-02, ..., 2.0264e-02,\n",
+ " -4.7363e-02, -3.6774e-03],\n",
+ " [ 5.5847e-03, -1.7212e-02, -4.9805e-02, ..., -7.6660e-02,\n",
+ " -2.7466e-02, 4.8340e-02],\n",
+ " [-9.3750e-02, -3.6377e-02, -6.6895e-02, ..., -3.0029e-02,\n",
+ " -5.8350e-02, -6.1768e-02]]]]], device='cuda:0')"
+ ]
+ },
+ "execution_count": 46,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\"\"\"\n",
+ "Parameter containing:\n",
+ "tensor([[[[[-4.7913e-03, -2.1515e-03, -2.0447e-03, ..., 3.2997e-04,\n",
+ " -3.0212e-03, -7.9727e-04],\n",
+ " [ 9.7656e-04, 2.4567e-03, 9.8419e-04, ..., -1.8845e-03,\n",
+ " 2.3193e-03, 3.6621e-03],\n",
+ " [-3.5095e-04, 2.5940e-03, -2.7618e-03, ..., -3.7956e-04,\n",
+ " -3.1948e-05, 1.7166e-03],\n",
+ " ...,\n",
+ " [-3.8605e-03, -1.1215e-03, -9.0790e-04, ..., 6.5994e-04,\n",
+ " 1.0071e-03, 1.2894e-03],\n",
+ " [-2.2278e-03, 1.2589e-03, -1.0204e-04, ..., 3.7079e-03,\n",
+ " 1.5354e-04, -8.3160e-04],\n",
+ " [ 6.3324e-04, 1.4114e-03, 9.5367e-04, ..., -3.4485e-03,\n",
+ " -1.8234e-03, -4.0283e-03]]],\n",
+ "\"\"\"\n",
+ "intern_model.vision_encoder.patch_embed.proj.weight\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Parameter containing:\n",
+ "tensor([[ 0.0366, 0.0135, 0.0492, ..., -0.0274, 0.0493, 0.0242],\n",
+ " [-0.0113, 0.0504, 0.0469, ..., -0.0269, -0.0224, -0.0305],\n",
+ " [ 0.0192, -0.0152, 0.0119, ..., 0.0115, -0.0182, -0.0063],\n",
+ " ...,\n",
+ " [-0.0370, -0.0460, 0.0203, ..., 0.0157, -0.0529, 0.0139],\n",
+ " [-0.0523, -0.0192, -0.0612, ..., -0.0515, 0.0169, 0.0098],\n",
+ " [ 0.0277, -0.0029, -0.0349, ..., 0.0014, -0.0453, 0.0052]],\n",
+ " device='cuda:0')"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\"\"\"\n",
+ "Parameter containing:\n",
+ "tensor([[ 0.0366, 0.0135, 0.0492, ..., -0.0274, 0.0493, 0.0242],\n",
+ " [-0.0113, 0.0504, 0.0469, ..., -0.0269, -0.0224, -0.0305],\n",
+ " [ 0.0192, -0.0152, 0.0119, ..., 0.0115, -0.0182, -0.0063],\n",
+ " ...,\n",
+ " [-0.0370, -0.0460, 0.0203, ..., 0.0157, -0.0529, 0.0139],\n",
+ " [-0.0523, -0.0192, -0.0612, ..., -0.0515, 0.0169, 0.0098],\n",
+ " [ 0.0277, -0.0029, -0.0349, ..., 0.0014, -0.0453, 0.0052]],\n",
+ " device='cuda:0')\n",
+ "\"\"\"\n",
+ "intern_model.text_encoder.encoder.layer[0].output.dense.weight"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Video tensor([0.0005], device='cuda:0')\n",
+ "# Text tensor([-1.1985e-03, 5.7084e-04, -7.3242e-05, -2.1923e-04, 1.3280e-03,\n",
+ "# 6.7617e-05, -5.6482e-04, 1.3007e-03, \n",
+ "# 9.1326e-04, 5.7684e-04],\n",
+ "# device='cuda:0')\n",
+ "# text: A man in a gray hat and coat walks through the snowy yard, carefully navigating around the trees. ~ prob: 0.5572\n",
+ "# text: A man in a gray coat walks through the snowy landscape, pulling a sleigh loaded with toys. ~ prob: 0.1044\n",
+ "# text: A playful dog and its owner wrestle in the snowy yard, chasing each other with joyous abandon. ~ prob: 0.0958\n",
+ "# text: A person stands on the snowy floor, pushing a sled loaded with blankets, preparing for a fun-filled ride. ~ prob: 0.0936\n",
+ "# text: A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run. ~ prob: 0.0404"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# /home/toolkit/.conda/envs/urlb_test/lib/python3.8/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+ "# warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n",
+ "# Video tensor([-0.0014], device='cuda:0')\n",
+ "# Text tensor([-1.8553e-03, -1.8098e-03, 5.9901e-04, -1.9457e-03, 4.7702e-05,\n",
+ "# -2.8283e-03, -2.2676e-03, 7.7966e-04, -2.1556e-04, -3.8074e-04],\n",
+ "# device='cuda:0')\n",
+ "# text: A person stands on the snowy floor, pushing a sled loaded with blankets, preparing for a fun-filled ride. ~ prob: 0.3186\n",
+ "# text: A playful dog slides down a snowy hill, wagging its tail with delight. ~ prob: 0.1871\n",
+ "# text: A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner. ~ prob: 0.1405\n",
+ "# text: A man in a gray coat walks through the snowy landscape, pulling a sleigh loaded with toys. ~ prob: 0.1344\n",
+ "# text: A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run. ~ prob: 0.0955"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/demo/easydict.py b/third_party/InternVideo/InternVideo2/multi_modality/demo/easydict.py
new file mode 100644
index 0000000000000000000000000000000000000000..241aca41c9f1b0677be4bf6070c077fa24501816
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/demo/easydict.py
@@ -0,0 +1,149 @@
+class EasyDict(dict):
+ """
+ Get attributes
+
+ >>> d = EasyDict({'foo':3})
+ >>> d['foo']
+ 3
+ >>> d.foo
+ 3
+ >>> d.bar
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'bar'
+
+ Works recursively
+
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+ >>> isinstance(d.bar, dict)
+ True
+ >>> d.bar.x
+ 1
+
+ Bullet-proof
+
+ >>> EasyDict({})
+ {}
+ >>> EasyDict(d={})
+ {}
+ >>> EasyDict(None)
+ {}
+ >>> d = {'a': 1}
+ >>> EasyDict(**d)
+ {'a': 1}
+
+ Set attributes
+
+ >>> d = EasyDict()
+ >>> d.foo = 3
+ >>> d.foo
+ 3
+ >>> d.bar = {'prop': 'value'}
+ >>> d.bar.prop
+ 'value'
+ >>> d
+ {'foo': 3, 'bar': {'prop': 'value'}}
+ >>> d.bar.prop = 'newer'
+ >>> d.bar.prop
+ 'newer'
+
+
+ Values extraction
+
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
+ >>> isinstance(d.bar, list)
+ True
+ >>> from operator import attrgetter
+ >>> map(attrgetter('x'), d.bar)
+ [1, 3]
+ >>> map(attrgetter('y'), d.bar)
+ [2, 4]
+ >>> d = EasyDict()
+ >>> d.keys()
+ []
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
+ >>> d.foo
+ 3
+ >>> d.bar.x
+ 1
+
+ Still like a dict though
+
+ >>> o = EasyDict({'clean':True})
+ >>> o.items()
+ [('clean', True)]
+
+ And like a class
+
+ >>> class Flower(EasyDict):
+ ... power = 1
+ ...
+ >>> f = Flower()
+ >>> f.power
+ 1
+ >>> f = Flower({'height': 12})
+ >>> f.height
+ 12
+ >>> f['power']
+ 1
+ >>> sorted(f.keys())
+ ['height', 'power']
+
+ update and pop items
+ >>> d = EasyDict(a=1, b='2')
+ >>> e = EasyDict(c=3.0, a=9.0)
+ >>> d.update(e)
+ >>> d.c
+ 3.0
+ >>> d['c']
+ 3.0
+ >>> d.get('c')
+ 3.0
+ >>> d.update(a=4, b=4)
+ >>> d.b
+ 4
+ >>> d.pop('a')
+ 4
+ >>> d.a
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'a'
+ """
+
+ def __init__(self, d=None, **kwargs):
+ if d is None:
+ d = {}
+ if kwargs:
+ d.update(**kwargs)
+ for k, v in d.items():
+ setattr(self, k, v)
+ # Class attributes
+ for k in self.__class__.__dict__.keys():
+ if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
+ setattr(self, k, getattr(self, k))
+
+ def __setattr__(self, name, value):
+ if isinstance(value, (list, tuple)):
+ value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
+ value = self.__class__(value)
+ super(EasyDict, self).__setattr__(name, value)
+ super(EasyDict, self).__setitem__(name, value)
+
+ __setitem__ = __setattr__
+
+ def update(self, e=None, **f):
+ d = e or dict()
+ d.update(f)
+ for k in d:
+ setattr(self, k, d[k])
+
+ def pop(self, k, d=None):
+ if hasattr(self, k):
+ delattr(self, k)
+ return super(EasyDict, self).pop(k, d)
+
+
+if __name__ == "__main__":
+ import doctest
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/demo/example1.mp4 b/third_party/InternVideo/InternVideo2/multi_modality/demo/example1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..815f8412d86b0b9d2200efc7d28d5e454e121c0d
Binary files /dev/null and b/third_party/InternVideo/InternVideo2/multi_modality/demo/example1.mp4 differ
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/demo/internvideo2_stage2_config.py b/third_party/InternVideo/InternVideo2/multi_modality/demo/internvideo2_stage2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..30060b3fd636aaf5fdbde5c3cce78c7fdb4dabae
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/demo/internvideo2_stage2_config.py
@@ -0,0 +1,113 @@
+from configs.data import *
+from configs.model import *
+# ========================= data ==========================
+# NOTE The train_file will not be used during the evaluation
+
+num_workers = 6
+
+# ========================= input ==========================
+num_frames = 4
+num_frames_test = 4
+batch_size = 8
+batch_size_test = 4
+size_t = 224
+max_txt_l = 40
+
+origin_num_frames = 4
+
+use_half_precision = False
+use_bf16 = False
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+text_enc = "bert_large"
+model = dict(
+ model_cls="InternVideo2_Stage2",
+ vision_encoder=dict(
+ # backbone
+ name="pretrain_internvideo2_1b_patch14_224",
+ img_size=224,
+ num_frames="${num_frames}",
+ tubelet_size=1,
+ patch_size=14,
+ d_model=1408,
+ clip_embed_dim=768,
+ clip_teacher_embed_dim=3200,
+ clip_teacher_final_dim=768,
+ clip_norm_type='l2',
+ clip_return_layer=6,
+ clip_student_return_interval=1,
+ pretrained='your_model_path/1B_stage2_pt.pth',
+ use_checkpoint=True,
+ checkpoint_num=40,
+ use_flash_attn=use_half_precision,
+ use_fused_rmsnorm=use_half_precision,
+ use_fused_mlp=use_half_precision,
+ # clip teacher
+ clip_teacher=None,
+ clip_input_resolution=224,
+ clip_teacher_return_interval=1,
+ # mask
+ video_mask_type="random",
+ video_mask_ratio=0.8,
+ image_mask_type="random",
+ image_mask_ratio=0.5,
+ sep_image_video_pos_embed=True,
+ keep_temporal=False,
+ only_mask=True
+ ),
+ text_encoder="${TextEncoders[${text_enc}]}",
+ multimodal=dict(enable=True),
+ embed_dim=512,
+ temp=0.07,
+ find_unused_parameters=False
+)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+gradient_checkpointing = True # for text encoder
+use_flash_sdp = False
+use_mem_efficient_sdp = False and not use_flash_sdp
+compile_model = False
+
+# ========================= optimizer ==========================
+dist_url = "env://"
+device = "cpu"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 100
+seed = 42
+
+save_latest = False
+auto_resume = True
+jump_evaluate = False
+pretrained_path = ""
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/demo/small_config.py b/third_party/InternVideo/InternVideo2/multi_modality/demo/small_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e52b98a8ea5090699b48be9acbc0430258247319
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/demo/small_config.py
@@ -0,0 +1,274 @@
+from __future__ import annotations
+
+import argparse
+import ast
+import json
+import os
+import os.path as osp
+import re
+import shutil
+import sys
+import tempfile
+from copy import deepcopy
+from importlib import import_module
+
+import yaml
+
+from easydict import EasyDict
+
+__all__ = ["Config", "pretty_text"]
+
+
+BASE_KEY = "_base_"
+# BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"}
+BASE_CONFIG = {}
+
+cfg = None
+
+
+class Config(object):
+ """config"""
+
+ @classmethod
+ def pretty_text(cls, cfg: dict, indent=2) -> str:
+ """format dict to a string
+
+ Args:
+ cfg (EasyDict): the params.
+
+ Returns: The string to display.
+
+ """
+ msg = "{\n"
+ for i, (k, v) in enumerate(cfg.items()):
+ if isinstance(v, dict):
+ v = cls.pretty_text(v, indent + 4)
+ spaces = " " * indent
+ msg += spaces + "{}: {}".format(k, v)
+ if i == len(cfg) - 1:
+ msg += " }"
+ else:
+ msg += "\n"
+ return msg
+
+ @classmethod
+ def dump(cls, cfg, savepath=None):
+ """dump cfg to `json` file.
+
+ Args:
+ cfg (dict): The dict to dump.
+ savepath (str): The filepath to save the dumped dict.
+
+ Returns: TODO
+
+ """
+ if savepath is None:
+ savepath = osp.join(cfg.WORKSPACE, "config.json")
+ json.dump(cfg, open(savepath, "w"), indent=2)
+
+ @classmethod
+ def get_config(cls, default_config: dict = None):
+ """get a `Config` instance.
+
+ Args:
+ default_config (dict): The default config. `default_config` will be overrided
+ by config file `--cfg`, `--cfg` will be overrided by commandline args.
+
+ Returns: an EasyDict.
+ """
+ global cfg
+ if cfg is not None:
+ return cfg
+
+ # define arg parser.
+ parser = argparse.ArgumentParser()
+ # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str)
+ parser.add_argument(
+ "config_file", help="the configuration file to load. support: .yaml, .json, .py"
+ )
+ parser.add_argument(
+ "opts",
+ default=None,
+ nargs="*",
+ help="overrided configs. List. Format: 'key1 name1 key2 name2'",
+ )
+ args = parser.parse_args()
+
+ cfg = EasyDict(BASE_CONFIG)
+ if osp.isfile(args.config_file):
+ cfg_from_file = cls.from_file(args.config_file)
+ cfg = merge_a_into_b(cfg_from_file, cfg)
+ cfg = cls.merge_list(cfg, args.opts)
+ cfg = eval_dict_leaf(cfg)
+
+ # update some keys to make them show at the last
+ for k in BASE_CONFIG:
+ cfg[k] = cfg.pop(k)
+ return cfg
+
+ @classmethod
+ def from_file(cls, filepath: str) -> EasyDict:
+ """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`.
+
+ Args:
+ filepath (str): The config file path.
+
+ Returns: TODO
+
+ """
+ filepath = osp.abspath(osp.expanduser(filepath))
+ if not osp.isfile(filepath):
+ raise IOError(f"File does not exist: {filepath}")
+ if filepath.endswith(".py"):
+ sys.path.insert(0, osp.dirname(filepath))
+ mod = import_module(osp.splitext(osp.basename(filepath))[0])
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith("__")
+ }
+
+ elif filepath.endswith((".yml", ".yaml")):
+ cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader)
+ elif filepath.endswith(".json"):
+ cfg_dict = json.load(open(filepath, "r"))
+ else:
+ raise IOError("Only py/yml/yaml/json type are supported now!")
+
+ cfg_text = filepath + "\n"
+ with open(filepath, "r") as f:
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict: # load configs in `BASE_KEY`
+ cfg_dir = osp.dirname(filepath)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = (
+ base_filename if isinstance(base_filename, list) else [base_filename]
+ )
+
+ cfg_dict_list = list()
+ for f in base_filename:
+ _cfg_dict = Config.from_file(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
+ raise KeyError("Duplicate key is not allowed among bases")
+ base_cfg_dict.update(c)
+
+ cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict)
+
+ return EasyDict(cfg_dict)
+
+ @classmethod
+ def merge_list(cls, cfg, opts: list):
+ """merge commandline opts.
+
+ Args:
+ cfg: (dict): The config to be merged.
+ opts (list): The list to merge. Format: [key1, name1, key2, name2,...].
+ The keys can be nested. For example, ["a.b", v] will be considered
+ as `dict(a=dict(b=v))`.
+
+ Returns: dict.
+
+ """
+ assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}"
+ for i in range(0, len(opts), 2):
+ full_k, v = opts[i], opts[i + 1]
+ keys = full_k.split(".")
+ sub_d = cfg
+ for i, k in enumerate(keys):
+ if not hasattr(sub_d, k):
+ raise ValueError(f"The key {k} not exist in the config. Full key:{full_k}")
+ if i != len(keys) - 1:
+ sub_d = sub_d[k]
+ else:
+ sub_d[k] = v
+ return cfg
+
+
+def merge_a_into_b(a, b, inplace=False):
+ """The values in a will override values in b.
+
+ Args:
+ a (dict): source dict.
+ b (dict): target dict.
+
+ Returns: dict. recursively merge dict a into dict b.
+
+ """
+ if not inplace:
+ b = deepcopy(b)
+ for key in a:
+ if key in b:
+ if isinstance(a[key], dict) and isinstance(b[key], dict):
+ b[key] = merge_a_into_b(a[key], b[key], inplace=True)
+ else:
+ b[key] = a[key]
+ else:
+ b[key] = a[key]
+ return b
+
+
+def eval_dict_leaf(d, orig_dict=None):
+ """eval values of dict leaf.
+
+ Args:
+ d (dict): The dict to eval.
+
+ Returns: dict.
+
+ """
+ if orig_dict is None:
+ orig_dict = d
+ for k, v in d.items():
+ if not isinstance(v, dict):
+ d[k] = eval_string(v, orig_dict)
+ else:
+ eval_dict_leaf(v, orig_dict)
+ return d
+
+
+def eval_string(string, d):
+ """automatically evaluate string to corresponding types.
+
+ For example:
+ not a string -> return the original input
+ '0' -> 0
+ '0.2' -> 0.2
+ '[0, 1, 2]' -> [0,1,2]
+ 'eval(1+2)' -> 3
+ 'eval(range(5))' -> [0,1,2,3,4]
+ '${a}' -> d.a
+
+
+
+ Args:
+ string (str): The value to evaluate.
+ d (dict): The
+
+ Returns: the corresponding type
+
+ """
+ if not isinstance(string, str):
+ return string
+ # if len(string) > 1 and string[0] == "[" and string[-1] == "]":
+ # return eval(string)
+ if string[0:5] == "eval(":
+ return eval(string[5:-1])
+
+ s0 = string
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
+ if s1 != s0:
+ while s1 != s0:
+ s0 = s1
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
+ return eval(s1)
+
+ try:
+ v = ast.literal_eval(string)
+ except:
+ v = string
+ return v
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/demo/small_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/demo/small_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..628bb366a53f3b94b2b8b90206b40438a7c4bc13
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/demo/small_utils.py
@@ -0,0 +1,318 @@
+import numpy as np
+import cv2
+import os
+import io
+
+import torch
+from torch import nn
+
+import sys
+from models.backbones.internvideo2 import pretrain_internvideo2_1b_patch14_224
+from models.backbones.bert.builder import build_bert
+# from models.criterions import get_sim
+from models.backbones.internvideo2.pos_embed import interpolate_pos_embed_internvideo2_new
+from models.backbones.bert.tokenization_bert import BertTokenizer
+
+
+def _frame_from_video(video):
+ while video.isOpened():
+ success, frame = video.read()
+ if success:
+ yield frame
+ else:
+ break
+
+v_mean = np.array([0.485, 0.456, 0.406]).reshape(1,1,3)
+v_std = np.array([0.229, 0.224, 0.225]).reshape(1,1,3)
+def normalize(data):
+ return (data/255.0-v_mean)/v_std
+
+
+def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):
+ assert(len(vid_list) >= fnum)
+ step = len(vid_list) // fnum
+ vid_list = vid_list[::step][:fnum]
+ vid_list = [cv2.resize(x[:,:,::-1], target_size) for x in vid_list]
+ vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list]
+ vid_tube = np.concatenate(vid_tube, axis=1)
+ vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))
+ vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float()
+ return vid_tube
+
+
+def get_text_feat_dict(texts, clip, text_feat_d={}):
+ for t in texts:
+ feat = clip.get_txt_feat(t)
+ text_feat_d[t] = feat
+ return text_feat_d
+
+
+def get_vid_feat(frames, vlm):
+ return vlm.get_vid_features(frames)
+
+
+def retrieve_text(frames,
+ texts,
+ model,
+ topk:int=5,
+ config: dict={},
+ device=torch.device('cuda')):
+
+ vlm = model
+ vlm = vlm.to(device)
+
+ fn = config.get('num_frames', 8)
+ size_t = config.get('size_t', 224)
+ frames_tensor = frames2tensor(frames, fnum=fn, target_size=(size_t, size_t), device=device)
+ vid_feat = vlm.get_vid_features(frames_tensor)
+ print('Video', vid_feat.mean(dim=-1))
+
+ text_feat_d = {}
+ text_feat_d = get_text_feat_dict(texts, vlm, text_feat_d)
+ text_feats = [text_feat_d[t] for t in texts]
+ text_feats_tensor = torch.cat(text_feats, 0)
+ print('Text', text_feats_tensor.mean(dim=-1))
+
+ probs, idxs = vlm.predict_label(vid_feat, text_feats_tensor, top=topk)
+
+ ret_texts = [texts[i] for i in idxs.long().numpy()[0].tolist()]
+ return ret_texts, probs.float().numpy()[0]
+
+
+def setup_internvideo2(config: dict):
+ if "bert" in config.model.text_encoder.name:
+ tokenizer = BertTokenizer.from_pretrained(config.model.text_encoder.pretrained, local_files_only=True)
+ model = InternVideo2_Stage2(config=config, tokenizer=tokenizer, is_pretrain=True)
+ else:
+ model = InternVideo2_Stage2(config=config, is_pretrain=True)
+ tokenizer = model.tokenizer
+
+ if config.get('compile_model', False):
+ torch.set_float32_matmul_precision('high')
+ model = torch.compile(model)
+
+ model = model.to(torch.device(config.device))
+ model_without_ddp = model
+
+ if (config.pretrained_path.strip() and (os.path.isfile(config.pretrained_path)) or "s3://" in config.pretrained_path):
+ checkpoint = torch.load(config.pretrained_path, map_location="cpu")
+ try:
+ if "model" in checkpoint.keys():
+ state_dict = checkpoint["model"]
+ else:
+ state_dict = checkpoint["module"] # This is a deepspeed stage 1 model
+ except:
+ state_dict = checkpoint
+
+ # Note: this was a temporary fix due to the bug caused by is_pretrain=False
+ # from collections import OrderedDict
+ # state_dict = OrderedDict({ k.replace('text_encoder.bert', 'text_encoder') : state_dict[k] for k in state_dict})
+
+ if config.get('origin_num_frames', None) is not None:
+ a = len(state_dict)
+ interpolate_pos_embed_internvideo2_new(state_dict, model_without_ddp.vision_encoder, orig_t_size=config.origin_num_frames)
+ assert a == len(state_dict), state_dict.keys()
+
+ msg = model_without_ddp.load_state_dict(state_dict, strict=False)
+ print(f"load_state_dict: {msg}")
+
+ if config.get('use_bf16', False):
+ model_without_ddp = model_without_ddp.to(torch.bfloat16)
+ elif config.get('use_half_precision', False):
+ model_without_ddp = model_without_ddp.to(torch.float16)
+ else:
+ model_without_ddp = model_without_ddp.to(torch.float32)
+
+ return (model_without_ddp, tokenizer,)
+
+
+class InternVideo2_Stage2(nn.Module):
+ """docstring for InternVideo2_Stage2"""
+
+ def __init__(self,
+ config,
+ tokenizer,
+ is_pretrain: bool=True):
+ super(InternVideo2_Stage2, self).__init__()
+
+ self.config = config
+ self.tokenizer = tokenizer
+
+ self.is_pretrain = is_pretrain
+ self.vision_width = config.model.vision_encoder.clip_embed_dim
+ self.text_width = config.model.text_encoder.d_model
+ self.embed_dim = config.model.embed_dim
+
+ # create modules.
+ self.vision_encoder = self.build_vision_encoder()
+ self.freeze_vision()
+
+ self.text_encoder = self.build_text_encoder()
+ self.freeze_text()
+
+ self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
+ self.text_proj = nn.Linear(self.text_width, self.embed_dim)
+
+ def freeze_vision(self):
+ """freeze vision encoder"""
+ for p in self.vision_encoder.parameters():
+ p.requires_grad = False
+
+ def freeze_text(self):
+ """freeze text encoder"""
+ for p in self.text_encoder.parameters():
+ p.requires_grad = False
+
+ @property
+ def dtype(self):
+ return self.vision_encoder.patch_embed.proj.weight.dtype
+
+ def encode_vision(self,
+ image: torch.Tensor,
+ test: bool=False):
+ """encode image / videos as features.
+
+ Args:
+ image (torch.Tensor): The input images.
+ test (bool): Whether testing.
+
+ Returns: tuple.
+ - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C].
+ - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C].
+ - student_output (torch.Tensor): The features of alignment. Shape: [K,B,N,C].
+ - clip_output (torch.Tensor): The features of clip. Shape: [K,B,N,C].
+
+ """
+
+ T = image.shape[1]
+ use_image = True if T == 1 else False
+ image = image.permute(0, 2, 1, 3, 4).to(self.dtype) # [B,T,C,H,W] -> [B,C,T,H,W]
+ # whether save temporal dimension
+ # keep_temporal=self.config.model.vision_encoder.keep_temporal
+ if test:
+ vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
+ image, None, use_image)
+ return vision_embeds, pooled_vision_embeds
+ else:
+ mask, targets_clip_middle_vis, targets_clip_final_vis = self.encode_teacher(image)
+ # if mask is not None and (self.video_mask_type != 'tube' or self.image_mask_type != 'tube'):
+ # keep_temporal = False
+ # print(f"\033[31mmask is {type(mask)}\033[0m")
+ vision_embeds, pooled_vision_embeds, student_output, student_output_final = self.vision_encoder(
+ image, mask, use_image)
+ return vision_embeds, pooled_vision_embeds, student_output, student_output_final, targets_clip_middle_vis, targets_clip_final_vis
+
+ def encode_text(self,
+ text: dict):
+ """encode text.
+ Args:
+ text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
+ - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
+ - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token.
+ - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".
+ Returns: tuple.
+ - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C].
+ - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C].
+
+ """
+ text_output = self.get_text_encoder()(
+ text.input_ids,
+ attention_mask=text.attention_mask,
+ return_dict=True,
+ mode="text",
+ )
+ text_embeds = text_output.last_hidden_state
+ pooled_text_embeds = text_embeds[:, 0]
+ return text_embeds, pooled_text_embeds
+
+ def build_vision_encoder(self):
+ """build vision encoder
+ Returns: (vision_encoder, clip_teacher). Each is a `nn.Module`.
+
+ """
+ encoder_name = self.config.model.vision_encoder.name
+
+ if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
+ vision_encoder = pretrain_internvideo2_1b_patch14_224(self.config.model)
+ else:
+ raise ValueError(f"Not implemented: {encoder_name}")
+
+ # parameters for mask
+ img_size = self.config.model.vision_encoder.img_size
+ num_frames = self.config.model.vision_encoder.num_frames
+ tublet_size = self.config.model.vision_encoder.tubelet_size
+ patch_size = self.config.model.vision_encoder.patch_size
+ self.clip_img_size = self.config.model.vision_encoder.clip_input_resolution
+ self.video_mask_type = self.config.model.vision_encoder.video_mask_type
+ self.video_window_size = (num_frames // tublet_size, img_size // patch_size, img_size // patch_size)
+ self.video_mask_ratio = self.config.model.vision_encoder.video_mask_ratio
+ self.image_mask_type = self.config.model.vision_encoder.image_mask_type
+ self.image_window_size = (1, img_size // patch_size, img_size // patch_size)
+ self.image_mask_ratio = self.config.model.vision_encoder.image_mask_ratio
+
+ return vision_encoder
+
+ def build_text_encoder(self):
+ """build text_encoder and possiblly video-to-text multimodal fusion encoder.
+ Returns: nn.Module. The text encoder
+
+ """
+ encoder_name = self.config.model.text_encoder.name
+
+ if "bert" in encoder_name:
+ text_encoder = build_bert(
+ self.config.model,
+ self.is_pretrain,
+ self.config.gradient_checkpointing,
+ )
+ else:
+ raise ValueError(f"Not implemented: {encoder_name}")
+
+ return text_encoder
+
+ def get_text_encoder(self):
+ """get text encoder, used for text and cross-modal encoding"""
+ encoder = self.text_encoder
+ return encoder.bert if hasattr(encoder, "bert") else encoder
+
+ def get_vid_features(self,
+ frames: torch.Tensor):
+ """get the video features for the given frames.
+
+ Args:
+ frames (torch.Tensor): The input frames. Shape: [B,T,C,H,W].
+
+ Returns: tuple.
+ - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C].
+ - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C].
+
+ """
+ with torch.no_grad():
+ _, vfeat = self.encode_vision(frames, test=True)
+ vfeat = self.vision_proj(vfeat)
+ vfeat /= vfeat.norm(dim=-1, keepdim=True)
+ return vfeat
+
+ def get_txt_feat(self,
+ text: str):
+ """get the text features for the given text."""
+ device = next(self.parameters()).device
+ with torch.no_grad():
+ text = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.config.max_txt_l,
+ return_tensors="pt",).to(device)
+ _, tfeat = self.encode_text(text)
+ tfeat = self.text_proj(tfeat)
+ tfeat /= tfeat.norm(dim=-1, keepdim=True)
+ return tfeat
+
+ def predict_label(self,
+ vid_feat: torch.Tensor,
+ txt_feat: torch.Tensor,
+ top: int=5):
+ label_probs = (100.0 * vid_feat @ txt_feat.T).softmax(dim=-1)
+ top_probs, top_labels = label_probs.float().cpu().topk(top, dim=-1)
+ return top_probs, top_labels
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/miscs/test_flops.py b/third_party/InternVideo/InternVideo2/multi_modality/miscs/test_flops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab4d812083d7bba5d00e1ecb0de418dbee7d2135
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/miscs/test_flops.py
@@ -0,0 +1,69 @@
+import torch
+from fvcore.nn import FlopCountAnalysis, flop_count_table
+from torch.nn import MultiheadAttention
+
+from models.beit.st_beit import BeitConfig, BeitModel
+from models.temporal_model import (STAdapter, TemporalAttention,
+ WindowTemporalAttention)
+
+
+def mem_stat():
+ mem = torch.cuda.max_memory_allocated() / 1024 / 1024
+ print(f"max memory allocated: {mem}MB")
+
+
+def build_backbone(tm_block="timesformer"):
+ """TODO: Docstring for build_backbone.
+ Returns: TODO
+
+ """
+ if tm_block == "timesformer":
+ other_cfg = dict(
+ num_frames=12, temporal_model_block="timesformer", temporal_model_config={}
+ )
+ elif tm_block == "st_adapter":
+ other_cfg = dict(
+ num_frames=12, temporal_model_block="st_adapter", temporal_model_config={}
+ )
+ elif tm_block == "xclip":
+ other_cfg = dict(
+ num_frames=12, temporal_model_block="xclip", temporal_model_config={}
+ )
+ elif tm_block == "none":
+ other_cfg = dict(num_frames=12, temporal_model_block="none", temporal_model_config={})
+ elif tm_block == "wa_2x2":
+ other_cfg = dict(
+ num_frames=12,
+ temporal_model_block="window_attention",
+ temporal_model_config=dict(window_size=(2, 2)),
+ )
+ elif tm_block == "wa_7x7":
+ other_cfg = dict(
+ num_frames=12,
+ temporal_model_block="window_attention",
+ temporal_model_config=dict(window_size=(7, 7)),
+ )
+ else:
+ raise ValueError("not exist")
+
+ model_card = "microsoft/beit-base-patch16-224-pt22k-ft22k"
+ model_config = BeitConfig.from_pretrained(model_card, image_size=224, **other_cfg)
+ model = BeitModel(model_config)
+ return model
+
+
+# model = TemporalAttention()
+model = build_backbone("st_adapter")
+model.gradient_checkpointing_enable()
+model.cuda()
+for i in range(3):
+ x = torch.rand(32, 12, 3, 224, 224, requires_grad=True)
+ x = x.cuda()
+ x = x.requires_grad_()
+ y = model(x)
+ loss = y[0].mean()
+ loss.backward()
+ mem_stat()
+
+# flops = FlopCountAnalysis(model, x)
+# print(flop_count_table(flops))
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/__init__.py b/third_party/InternVideo/InternVideo2/multi_modality/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c064a3682d41deb9810d194593fdfdbc87f23488
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/__init__.py
@@ -0,0 +1,9 @@
+from .internvideo2_clip import InternVideo2_CLIP
+from .internvideo2_stage2 import InternVideo2_Stage2
+# from .internvideo2_stage2_audio import InternVideo2_Stage2_audio
+
+__all__ = [
+ 'InternVideo2_CLIP',
+ 'InternVideo2_Stage2',
+ # 'InternVideo2_Stage2_audio'
+]
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/__init__.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/BEATs.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/BEATs.py
new file mode 100644
index 0000000000000000000000000000000000000000..f441b5480eec1d0e574be83c1dfafb8a731c72b7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/BEATs.py
@@ -0,0 +1,210 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+import torchaudio.compliance.kaldi as ta_kaldi
+
+try:
+ from .backbone import (
+ TransformerEncoder,
+ )
+except:
+ from backbone import (
+ TransformerEncoder,
+ )
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class BEATsConfig:
+ def __init__(self, cfg=None):
+ self.input_patch_size: int = -1 # path size of patch embedding
+ self.embed_dim: int = 512 # patch embedding dimension
+ self.conv_bias: bool = False # include bias in conv encoder
+
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
+ self.activation_fn: str = "gelu" # activation function to use
+
+ self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
+
+ # dropouts
+ self.dropout: float = 0.1 # dropout probability for the transformer
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
+
+ # positional embeddings
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
+
+ # relative position embedding
+ self.relative_position_embedding: bool = False # apply relative position embedding
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
+
+ # label predictor
+ self.finetuned_model: bool = False # whether the model is a fine-tuned model.
+ self.predictor_dropout: float = 0.1 # dropout probability for the predictor
+ self.predictor_class: int = 527 # target class number for the predictor
+
+ if cfg is not None:
+ self.update(cfg)
+
+ def update(self, cfg: dict):
+ self.__dict__.update(cfg)
+
+
+class BEATs(nn.Module):
+ def __init__(
+ self,
+ cfg: BEATsConfig,
+ ) -> None:
+ super().__init__()
+ logger.info(f"BEATs Config: {cfg.__dict__}")
+
+ self.cfg = cfg
+
+ self.embed = cfg.embed_dim
+ self.post_extract_proj = (
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
+ if self.embed != cfg.encoder_embed_dim
+ else None
+ )
+
+ self.input_patch_size = cfg.input_patch_size
+ self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
+ bias=cfg.conv_bias)
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+
+ assert not cfg.deep_norm or not cfg.layer_norm_first
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.embed)
+
+ if cfg.finetuned_model:
+ self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
+ self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
+ else:
+ self.predictor = None
+
+ def forward_padding_mask(
+ self,
+ features: torch.Tensor,
+ padding_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ extra = padding_mask.size(1) % features.size(1)
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ padding_mask = padding_mask.view(
+ padding_mask.size(0), features.size(1), -1
+ )
+ padding_mask = padding_mask.all(-1)
+ return padding_mask
+
+ def preprocess(
+ self,
+ source: torch.Tensor,
+ fbank_mean: float = 15.41663,
+ fbank_std: float = 6.55582,
+ ) -> torch.Tensor:
+ fbanks = []
+ for waveform in source:
+ waveform = waveform.unsqueeze(0) * 2 ** 15
+ # print(waveform.max(), waveform.min(), waveform.shape, waveform.dtype)
+ # waveform = waveform.unsqueeze(0)
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
+ # print(fbank.max(), fbank.min(), fbank.shape, fbank.dtype)
+ fbanks.append(fbank)
+ fbank = torch.stack(fbanks, dim=0)
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
+ # print(fbank.max(), fbank.min(), fbank.shape, fbank.dtype)
+ return fbank
+
+ def forward(self, fbank):
+ fbank = fbank.unsqueeze(1)
+ ### fbank b,1,128,1024
+ features = self.patch_embedding(fbank)
+ ### b,512,8,64
+ features = features.reshape(features.shape[0], features.shape[1], -1).contiguous()
+ ### b, 512 , 512
+ features = features.transpose(1, 2).contiguous() ##b,512,512
+ features = self.layer_norm(features)
+ x = self.dropout_input(features)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ x = self.dropout_input(features)
+ x, layer_results = self.encoder(
+ x,
+ padding_mask=None,
+ )
+
+ return x
+
+ def extract_features(
+ self,
+ source: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ fbank_mean: float = 15.41663,
+ fbank_std: float = 6.55582,
+ ):
+ fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
+
+ fbank = fbank.unsqueeze(1)
+ features = self.patch_embedding(fbank)
+ features = features.reshape(features.shape[0], features.shape[1], -1)
+ features = features.transpose(1, 2).contiguous()
+ features = self.layer_norm(features)
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ x = self.dropout_input(features)
+
+ x, layer_results = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ )
+
+ if self.predictor is not None:
+ x = self.predictor_dropout(x)
+ logits = self.predictor(x)
+
+ if padding_mask is not None and padding_mask.any():
+ logits[padding_mask] = 0
+ logits = logits.sum(dim=1)
+ logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
+ else:
+ logits = logits.mean(dim=1)
+
+ lprobs = torch.sigmoid(logits)
+
+ return lprobs, padding_mask
+ else:
+ return x, padding_mask
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/README.md b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fd54e797a6ae5152c92fd2b37a8531f13e103f00
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/README.md
@@ -0,0 +1,127 @@
+
+# BEATs
+
+[**BEATs**](https://arxiv.org/abs/2212.09058): **Audio Pre-Training with Acoustic Tokenizers**
+
+Official PyTorch implementation and pretrained models of BEATs
+
+## Pre-Trained and Fine-Tuned Tokenizers and Models
+Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2
+|---|---|---|---|---
+Iter1 | Random Projection | [BEATs_iter1](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
+Iter2 | [Tokenizer_iter2](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter2](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
+Iter3 | [Tokenizer_iter3](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
+Iter3+ | [Tokenizer_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
+Iter3+ | [Tokenizer_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
+
+
+### Load Tokenizers
+
+```python
+import torch
+from Tokenizers import TokenizersConfig, Tokenizers
+
+# load the pre-trained checkpoints
+checkpoint = torch.load('/path/to/tokenizer.pt')
+
+cfg = TokenizersConfig(checkpoint['cfg'])
+BEATs_tokenizer = Tokenizers(cfg)
+BEATs_tokenizer.load_state_dict(checkpoint['model'])
+BEATs_tokenizer.eval()
+
+# tokenize the audio and generate the labels
+audio_input_16khz = torch.randn(1, 10000)
+padding_mask = torch.zeros(1, 10000).bool()
+
+labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
+```
+
+
+### Load Pre-Trained Models
+
+```python
+import torch
+from BEATs import BEATs, BEATsConfig
+
+# load the pre-trained checkpoints
+checkpoint = torch.load('/path/to/model.pt')
+
+cfg = BEATsConfig(checkpoint['cfg'])
+BEATs_model = BEATs(cfg)
+BEATs_model.load_state_dict(checkpoint['model'])
+BEATs_model.eval()
+
+# extract the the audio representation
+audio_input_16khz = torch.randn(1, 10000)
+padding_mask = torch.zeros(1, 10000).bool()
+
+representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
+```
+
+
+### Load Fine-tuned Models
+
+```python
+import torch
+from BEATs import BEATs, BEATsConfig
+
+# load the fine-tuned checkpoints
+checkpoint = torch.load('/path/to/model.pt')
+
+cfg = BEATsConfig(checkpoint['cfg'])
+BEATs_model = BEATs(cfg)
+BEATs_model.load_state_dict(checkpoint['model'])
+BEATs_model.eval()
+
+# predict the classification probability of each class
+audio_input_16khz = torch.randn(3, 10000)
+padding_mask = torch.zeros(3, 10000).bool()
+
+probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
+
+for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
+ top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
+ print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}')
+```
+
+## Evaluation Results
+
+### Comparing with the SOTA Single Models
+![alt text](Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png)
+
+
+### Comparing with the SOTA Ensemble Models
+![alt text](Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png)
+
+
+### Comparing Different BEATS Tokenizers
+![alt text](Evaluation_Results/Comparing_Different_BEATS_Tokenizers.png)
+
+
+### Comparing Different Pre-Training Targets
+![alt text](Evaluation_Results/Comparing_Different_Pre-Training_Targets.png)
+
+
+## License
+This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
+Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [VQGAN](https://github.com/CompVis/taming-transformers) project.
+
+[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
+
+
+### Reference
+If you find our work is useful in your research, please cite the following paper:
+``` latex
+@article{Chen2022beats,
+ title = {BEATs: Audio Pre-Training with Acoustic Tokenizers},
+ author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei},
+ eprint={2212.09058},
+ archivePrefix={arXiv},
+ year={2022}
+}
+```
+### Contact Information
+
+For help or issues using BEATs models, please submit a GitHub issue.
+
+For other communications related to BEATs, please contact Yu Wu (`yuwu1@microsoft.com`).
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/Tokenizers.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/Tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..eafe212d8a2ce70157f3841374873a57c5bbed0b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/Tokenizers.py
@@ -0,0 +1,173 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+import torchaudio.compliance.kaldi as ta_kaldi
+
+from backbone import (
+ TransformerEncoder,
+)
+from quantizer import (
+ NormEMAVectorQuantizer,
+)
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class TokenizersConfig:
+ def __init__(self, cfg=None):
+ self.input_patch_size: int = -1 # path size of patch embedding
+ self.embed_dim: int = 512 # patch embedding dimension
+ self.conv_bias: bool = False # include bias in conv encoder
+
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
+ self.activation_fn: str = "gelu" # activation function to use
+
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
+
+ # dropouts
+ self.dropout: float = 0.1 # dropout probability for the transformer
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
+
+ # positional embeddings
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
+
+ # relative position embedding
+ self.relative_position_embedding: bool = False # apply relative position embedding
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
+
+ # quantizer
+ self.quant_n: int = 1024 # codebook number in quantizer
+ self.quant_dim: int = 256 # codebook dimension in quantizer
+
+ if cfg is not None:
+ self.update(cfg)
+
+ def update(self, cfg: dict):
+ self.__dict__.update(cfg)
+
+
+class Tokenizers(nn.Module):
+ def __init__(
+ self,
+ cfg: TokenizersConfig,
+ ) -> None:
+ super().__init__()
+ logger.info(f"Tokenizers Config: {cfg.__dict__}")
+
+ self.cfg = cfg
+
+ self.embed = cfg.embed_dim
+ self.post_extract_proj = (
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
+ if self.embed != cfg.encoder_embed_dim
+ else None
+ )
+
+ self.input_patch_size = cfg.input_patch_size
+ self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
+ bias=cfg.conv_bias)
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+
+ assert not cfg.deep_norm or not cfg.layer_norm_first
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.embed)
+
+ self.quantize = NormEMAVectorQuantizer(
+ n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
+ )
+ self.quant_n = cfg.quant_n
+ self.quantize_layer = nn.Sequential(
+ nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
+ nn.Tanh(),
+ nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
+ )
+
+ def forward_padding_mask(
+ self,
+ features: torch.Tensor,
+ padding_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ extra = padding_mask.size(1) % features.size(1)
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ padding_mask = padding_mask.view(
+ padding_mask.size(0), features.size(1), -1
+ )
+ padding_mask = padding_mask.all(-1)
+ return padding_mask
+
+ def preprocess(
+ self,
+ source: torch.Tensor,
+ fbank_mean: float = 15.41663,
+ fbank_std: float = 6.55582,
+ ) -> torch.Tensor:
+ fbanks = []
+ for waveform in source:
+ waveform = waveform.unsqueeze(0) * 2 ** 15
+ fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
+ fbanks.append(fbank)
+ fbank = torch.stack(fbanks, dim=0)
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
+ return fbank
+
+ def extract_labels(
+ self,
+ source: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ fbank_mean: float = 15.41663,
+ fbank_std: float = 6.55582,
+ ):
+ fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
+
+ fbank = fbank.unsqueeze(1)
+ features = self.patch_embedding(fbank)
+ features = features.reshape(features.shape[0], features.shape[1], -1)
+ features = features.transpose(1, 2)
+ features = self.layer_norm(features)
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ x = self.dropout_input(features)
+
+ x, layer_results = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ )
+
+ quantize_input = self.quantize_layer(x)
+ quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
+
+ return embed_ind
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/__init__.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/backbone.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..21a14c3ea30053bc79855d1907ad371ccaa231d7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/backbone.py
@@ -0,0 +1,960 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+r"""
+Weight Normalization from https://arxiv.org/abs/1602.07868
+"""
+from torch.nn.parameter import Parameter, UninitializedParameter
+from typing import Any, TypeVar
+import torch
+
+def _weight_norm(v_in, g_in, dim):
+ assert v_in.device == g_in.device, "weight_norm: expected v_in and g_in to be on the same device, but v_in is on {} and g_in is on {}".format(v_in.device, g_in.device)
+
+ v = v_in.contiguous()
+ g = g_in.contiguous()
+
+
+ # has_half_dtype = v.dtype == torch.float16 or g.dtype == torch.float16
+ # can_use_fused = not has_half_dtype and (dim == 0 or dim == v.dim() - 1)
+
+ # if can_use_fused:
+
+ # return torch._weight_norm_interface(v, g, dim)[0]
+ # else:
+
+ return v * (g / torch.norm_except_dim(v, 2, dim))
+
+def norm_except_dim(v, pow, dim):
+
+ if dim == -1:
+ return v.norm(pow)
+ elif dim == 0:
+ output_size = [1] * v.dim()
+ output_size[0] = v.size(0)
+ return v.contiguous().view(v.size(0), -1).norm(pow, 1).view(output_size)
+ elif dim == v.dim() - 1:
+ output_size = [1] * v.dim()
+ output_size[v.dim() - 1] = v.size(v.dim() - 1)
+ return v.contiguous().view(-1, v.size(v.dim() - 1)).norm(pow, 0).view(output_size)
+ else:
+ return norm_except_dim(v.transpose(0, dim), pow, 0).transpose(0, dim)
+
+import torch.nn as nn
+from torch.nn import Module
+__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']
+
+class WeightNorm(object):
+ name: str
+ dim: int
+
+ def __init__(self, name: str, dim: int) -> None:
+ if dim is None:
+ dim = -1
+ self.name = name
+ self.dim = dim
+
+ def compute_weight(self, module: Module) -> Any:
+ g = getattr(module, self.name + '_g')
+ v = getattr(module, self.name + '_v')
+ return _weight_norm(v, g, self.dim)
+
+ def apply(module, name: str, dim: int) -> 'WeightNorm':
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, WeightNorm) and hook.name == name:
+ raise RuntimeError("Cannot register two weight_norm hooks on "
+ "the same parameter {}".format(name))
+
+ if dim is None:
+ dim = -1
+
+ fn = WeightNorm(name, dim)
+
+ weight = getattr(module, name)
+ if isinstance(weight, UninitializedParameter):
+ raise ValueError(
+ 'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
+ 'Make sure to run the dummy forward before applying weight normalization')
+ # remove w from parameter list
+ del module._parameters[name]
+
+ # add g and v as new parameters and express w as g/||v|| * v
+ module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
+ module.register_parameter(name + '_v', Parameter(weight.data))
+ setattr(module, name, fn.compute_weight(module))
+
+ # recompute weight before every forward()
+ module.register_forward_pre_hook(fn)
+
+ return fn
+
+ def remove(self, module: Module) -> None:
+ weight = self.compute_weight(module)
+ delattr(module, self.name)
+ del module._parameters[self.name + '_g']
+ del module._parameters[self.name + '_v']
+ setattr(module, self.name, Parameter(weight.data))
+
+ def __call__(self, module: Module, inputs: Any) -> None:
+ setattr(module, self.name, self.compute_weight(module))
+
+
+T_module = TypeVar('T_module', bound=Module)
+
+def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
+ r"""Applies weight normalization to a parameter in the given module.
+
+ .. math::
+ \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
+
+ Weight normalization is a reparameterization that decouples the magnitude
+ of a weight tensor from its direction. This replaces the parameter specified
+ by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
+ (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
+ Weight normalization is implemented via a hook that recomputes the weight
+ tensor from the magnitude and direction before every :meth:`~Module.forward`
+ call.
+
+ By default, with ``dim=0``, the norm is computed independently per output
+ channel/plane. To compute a norm over the entire weight tensor, use
+ ``dim=None``.
+
+ See https://arxiv.org/abs/1602.07868
+
+ Args:
+ module (Module): containing module
+ name (str, optional): name of weight parameter
+ dim (int, optional): dimension over which to compute the norm
+
+ Returns:
+ The original module with the weight norm hook
+
+ Example::
+
+ >>> m = weight_norm(nn.Linear(20, 40), name='weight')
+ >>> m
+ Linear(in_features=20, out_features=40, bias=True)
+ >>> m.weight_g.size()
+ torch.Size([40, 1])
+ >>> m.weight_v.size()
+ torch.Size([40, 20])
+
+ """
+ WeightNorm.apply(module, name, dim)
+ return module
+
+
+def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
+ r"""Removes the weight normalization reparameterization from a module.
+
+ Args:
+ module (Module): containing module
+ name (str, optional): name of weight parameter
+
+ Example:
+ >>> m = weight_norm(nn.Linear(20, 40))
+ >>> remove_weight_norm(m)
+ """
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, WeightNorm) and hook.name == name:
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+ return module
+
+ raise ValueError("weight_norm of '{}' not found in {}"
+ .format(name, module))
+
+import math
+import numpy as np
+from typing import Dict, Optional, Tuple
+import torch
+from torch import Tensor, nn
+import torch.nn.functional as F
+from torch.nn import LayerNorm, Parameter
+#from weight_norm import *
+
+try:
+ from .modules import (
+ GradMultiply,
+ SamePad,
+ get_activation_fn,
+ GLU_Linear,
+ quant_noise,
+ )
+except:
+ from modules import (
+ GradMultiply,
+ SamePad,
+ get_activation_fn,
+ GLU_Linear,
+ quant_noise,
+ )
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+
+ self.dropout = args.dropout
+ self.embedding_dim = args.encoder_embed_dim
+
+ self.pos_conv = nn.Conv1d(
+ self.embedding_dim,
+ self.embedding_dim,
+ kernel_size=args.conv_pos,
+ padding=args.conv_pos // 2,
+ groups=args.conv_pos_groups,
+ )
+ dropout = 0
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+ nn.init.constant_(self.pos_conv.bias, 0)
+
+ # self.pos_conv = weight_norm(self.pos_conv, name="weight", dim=2)
+ # self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
+
+ if hasattr(args, "relative_position_embedding"):
+ self.relative_position_embedding = args.relative_position_embedding
+ self.num_buckets = args.num_buckets
+ self.max_distance = args.max_distance
+ else:
+ self.relative_position_embedding = False
+ self.num_buckets = 0
+ self.max_distance = 0
+
+ self.layers = nn.ModuleList(
+ [
+ TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ deep_norm=args.deep_norm,
+ has_relative_attention_bias=self.relative_position_embedding,
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ gru_rel_pos=args.gru_rel_pos,
+ encoder_layers=args.encoder_layers,
+ )
+ for i in range(args.encoder_layers)
+ ]
+ )
+ if self.relative_position_embedding:
+ for i in range(1, args.encoder_layers):
+ del self.layers[i].self_attn.relative_attention_bias
+ self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
+
+ self.layer_norm_first = args.layer_norm_first
+ self.layer_norm = LayerNorm(self.embedding_dim)
+ self.layerdrop = args.encoder_layerdrop
+
+ self.apply(init_bert_params)
+
+ if args.deep_norm:
+ deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
+ for i in range(args.encoder_layers):
+ nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
+ nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
+ nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
+ nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
+ nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
+ nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
+
+ self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
+
+ def forward(self, x, padding_mask=None, layer=None):
+
+ x, layer_results = self.extract_features(x, padding_mask, layer)
+
+ if self.layer_norm_first and layer is None:
+ x = self.layer_norm(x)
+
+ return x, layer_results
+
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
+
+ if padding_mask is not None:
+ x[padding_mask] = 0
+
+ origin_type = x.dtype
+ #self.pos_conv.float() # NOTE force_fp32 to make nn.utils.weight_norm happy
+ x_conv = self.pos_conv(x.transpose(1, 2).contiguous()).to(dtype=origin_type)
+ #self.pos_conv.to(origin_type)
+ x_conv = x_conv.transpose(1, 2).contiguous()
+ x = x + x_conv
+
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1).contiguous()
+
+ layer_results = []
+ z = None
+ if tgt_layer is not None:
+ layer_results.append((x, z))
+ r = None
+ pos_bias = None
+ for i, layer in enumerate(self.layers):
+ if self.layer_wise_gradient_decay_ratio != 1.0:
+ x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
+ dropout_probability = np.random.random()
+ if not self.training or (dropout_probability > self.layerdrop):
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
+ if tgt_layer is not None:
+ layer_results.append((x, z))
+ if i == tgt_layer:
+ r = x
+ break
+
+ if r is not None:
+ x = r
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1).contiguous()
+
+ return x, layer_results
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: float = 768,
+ ffn_embedding_dim: float = 3072,
+ num_attention_heads: float = 8,
+ dropout: float = 0.1,
+ attention_dropout: float = 0.1,
+ activation_dropout: float = 0.1,
+ activation_fn: str = "relu",
+ layer_norm_first: bool = False,
+ deep_norm: bool = False,
+ has_relative_attention_bias: bool = False,
+ num_buckets: int = 0,
+ max_distance: int = 0,
+ rescale_init: bool = False,
+ gru_rel_pos: bool = False,
+ encoder_layers: int = 0,
+ ) -> None:
+
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+
+ self.activation_name = activation_fn
+ self.activation_fn = get_activation_fn(activation_fn)
+ self.self_attn = MultiheadAttention(
+ self.embedding_dim,
+ num_attention_heads,
+ dropout=attention_dropout,
+ self_attention=True,
+ has_relative_attention_bias=has_relative_attention_bias,
+ num_buckets=num_buckets,
+ max_distance=max_distance,
+ rescale_init=rescale_init,
+ gru_rel_pos=gru_rel_pos,
+ )
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(self.activation_dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.layer_norm_first = layer_norm_first
+
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+
+ if self.activation_name == "glu":
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
+ else:
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+ self.deep_norm = deep_norm
+ if self.deep_norm:
+ self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
+ else:
+ self.deep_norm_alpha = 1
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ self_attn_mask: torch.Tensor = None,
+ self_attn_padding_mask: torch.Tensor = None,
+ need_weights: bool = False,
+ pos_bias=None
+ ):
+ residual = x
+
+ if self.layer_norm_first:
+ x = self.self_attn_layer_norm(x)
+ x, attn, pos_bias = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ need_weights=False,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias
+ )
+ x = self.dropout1(x)
+ x = residual + x
+
+ residual = x
+ x = self.final_layer_norm(x)
+ if self.activation_name == "glu":
+ x = self.fc1(x)
+ else:
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ else:
+ x, attn, pos_bias = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ need_weights=need_weights,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias
+ )
+
+ x = self.dropout1(x)
+ x = residual * self.deep_norm_alpha + x
+
+ x = self.self_attn_layer_norm(x)
+
+ residual = x
+ if self.activation_name == "glu":
+ x = self.fc1(x)
+ else:
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual * self.deep_norm_alpha + x
+ x = self.final_layer_norm(x)
+
+ return x, attn, pos_bias
+
+
+class MultiheadAttention(nn.Module):
+ """Multi-headed attention.
+
+ See "Attention Is All You Need" for more details.
+ """
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False,
+ q_noise=0.0,
+ qn_block_size=8,
+ has_relative_attention_bias=False,
+ num_buckets=32,
+ max_distance=128,
+ gru_rel_pos=False,
+ rescale_init=False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout_module = nn.Dropout(dropout)
+
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
+
+ self.head_dim = embed_dim // num_heads
+ self.q_head_dim = self.head_dim
+ self.k_head_dim = self.head_dim
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, (
+ "Self-attention requires query, key and " "value to be of the same size"
+ )
+
+ k_bias = True
+ if rescale_init:
+ k_bias = False
+
+ k_embed_dim = embed_dim
+ q_embed_dim = embed_dim
+
+ self.k_proj = quant_noise(
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
+ )
+ self.v_proj = quant_noise(
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ self.q_proj = quant_noise(
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
+ )
+
+ self.out_proj = quant_noise(
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.gru_rel_pos = gru_rel_pos
+ if self.gru_rel_pos:
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ # Empirically observed the convergence to be much better with
+ # the scaled initialization
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+ else:
+ nn.init.xavier_uniform_(self.k_proj.weight)
+ nn.init.xavier_uniform_(self.v_proj.weight)
+ nn.init.xavier_uniform_(self.q_proj.weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.out_proj.bias is not None:
+ nn.init.constant_(self.out_proj.bias, 0.0)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+ if self.has_relative_attention_bias:
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
+
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
+ num_buckets = self.num_buckets
+ max_distance = self.max_distance
+ relative_buckets = 0
+
+ if bidirectional:
+ num_buckets = num_buckets // 2
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
+ relative_positions = torch.abs(relative_positions)
+ else:
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
+
+ max_exact = num_buckets // 2
+ is_small = relative_positions < max_exact
+
+ relative_postion_if_large = max_exact + (
+ torch.log(relative_positions.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_postion_if_large = torch.min(
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
+ return relative_buckets
+
+ def compute_bias(self, query_length, key_length):
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+ relative_position = memory_position - context_position
+ relative_position_bucket = self._relative_positions_bucket(
+ relative_position,
+ bidirectional=True
+ )
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+ values = self.relative_attention_bias(relative_position_bucket)
+ values = values.permute([2, 0, 1]).contiguous()
+ return values
+
+ def forward(
+ self,
+ query,
+ key: Optional[Tensor],
+ value: Optional[Tensor],
+ key_padding_mask: Optional[Tensor] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ need_weights: bool = True,
+ static_kv: bool = False,
+ attn_mask: Optional[Tensor] = None,
+ before_softmax: bool = False,
+ need_head_weights: bool = False,
+ position_bias: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ is_tpu = query.device.type == "xla"
+
+ tgt_len, bsz, embed_dim = query.size()
+ src_len = tgt_len
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ if key is not None:
+ src_len, key_bsz, _ = key.size()
+ if not torch.jit.is_scripting():
+ assert key_bsz == bsz
+ assert value is not None
+ assert src_len, bsz == value.shape[:2]
+
+ if self.has_relative_attention_bias and position_bias is None:
+ position_bias = self.compute_bias(tgt_len, src_len)
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if saved_state is not None and "prev_key" in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ q = self.q_proj(query)
+ k = self.k_proj(query)
+ v = self.v_proj(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.q_proj(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.k_proj(key)
+ v = self.v_proj(key)
+
+ else:
+ assert key is not None and value is not None
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+ q *= self.scaling
+ alpha = 32
+ q *= 1 / alpha
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+ ],
+ dim=1,
+ )
+
+ q = (
+ q.contiguous()
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
+ .transpose(0, 1).contiguous()
+ )
+ if k is not None:
+ k = (
+ k.contiguous()
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
+ .transpose(0, 1).contiguous()
+ )
+ if v is not None:
+ v = (
+ v.contiguous()
+ .view(-1, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1).contiguous()
+ )
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if "prev_key" in saved_state:
+ _prev_key = saved_state["prev_key"]
+ assert _prev_key is not None
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ assert k is not None
+ k = torch.cat([prev_key, k], dim=1)
+ src_len = k.size(1)
+ if "prev_value" in saved_state:
+ _prev_value = saved_state["prev_value"]
+ assert _prev_value is not None
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ assert v is not None
+ v = torch.cat([prev_value, v], dim=1)
+ prev_key_padding_mask: Optional[Tensor] = None
+ if "prev_key_padding_mask" in saved_state:
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+ assert k is not None and v is not None
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+ key_padding_mask=key_padding_mask,
+ prev_key_padding_mask=prev_key_padding_mask,
+ batch_size=bsz,
+ src_len=k.size(1),
+ static_kv=static_kv,
+ )
+
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_key_padding_mask"] = key_padding_mask
+ # In this branch incremental_state is never None
+ assert incremental_state is not None
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
+ assert k is not None
+ assert k.size(1) == src_len
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ assert v is not None
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
+ key_padding_mask
+ ),
+ ],
+ dim=1,
+ )
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2).contiguous())
+ attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ attn_weights += attn_mask
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ if not is_tpu:
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
+ )
+ else:
+ attn_weights = attn_weights.transpose(0, 2).contiguous()
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+ attn_weights = attn_weights.transpose(0, 2).contiguous()
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v, position_bias
+
+ if position_bias is not None:
+ attn_mask_rel_pos = position_bias
+ if self.gru_rel_pos == 1:
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
+ _B, _H, _L, __ = query_layer.size()
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
+
+ attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
+
+ attn_weights = attn_weights + attn_mask_rel_pos
+
+ attn_weights_float = F.softmax(
+ attn_weights, dim=-1
+ )
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = self.dropout_module(attn_weights)
+
+ assert v is not None
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+ attn_weights: Optional[Tensor] = None
+ if need_weights:
+ attn_weights = attn_weights_float.view(
+ bsz, self.num_heads, tgt_len, src_len
+ ).transpose(1, 0).contiguous()
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+
+ return attn, attn_weights, position_bias
+
+ @staticmethod
+ def _append_prev_key_padding_mask(
+ key_padding_mask: Optional[Tensor],
+ prev_key_padding_mask: Optional[Tensor],
+ batch_size: int,
+ src_len: int,
+ static_kv: bool,
+ ) -> Optional[Tensor]:
+ # saved key padding masks have shape (bsz, seq_len)
+ if prev_key_padding_mask is not None and static_kv:
+ new_key_padding_mask = prev_key_padding_mask
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
+ new_key_padding_mask = torch.cat(
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
+ )
+ # During incremental decoding, as the padding token enters and
+ # leaves the frame, there will be a time when prev or current
+ # is None
+ elif prev_key_padding_mask is not None:
+ if src_len > prev_key_padding_mask.size(1):
+ filler = torch.zeros(
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
+ device=prev_key_padding_mask.device,
+ )
+ new_key_padding_mask = torch.cat(
+ [prev_key_padding_mask.float(), filler.float()], dim=1
+ )
+ else:
+ new_key_padding_mask = prev_key_padding_mask.float()
+ elif key_padding_mask is not None:
+ if src_len > key_padding_mask.size(1):
+ filler = torch.zeros(
+ (batch_size, src_len - key_padding_mask.size(1)),
+ device=key_padding_mask.device,
+ )
+ new_key_padding_mask = torch.cat(
+ [filler.float(), key_padding_mask.float()], dim=1
+ )
+ else:
+ new_key_padding_mask = key_padding_mask.float()
+ else:
+ new_key_padding_mask = prev_key_padding_mask
+ return new_key_padding_mask
+
+ def _get_input_buffer(
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+ ) -> Dict[str, Optional[Tensor]]:
+ result = self.get_incremental_state(incremental_state, "attn_state")
+ if result is not None:
+ return result
+ else:
+ empty_result: Dict[str, Optional[Tensor]] = {}
+ return empty_result
+
+ def _set_input_buffer(
+ self,
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+ buffer: Dict[str, Optional[Tensor]],
+ ):
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+ return attn_weights
+
+
+def init_bert_params(module):
+ """
+ Initialize the weights specific to the BERT Model.
+ This overrides the default initializations depending on the specified arguments.
+ 1. If normal_init_linear_weights is set then weights of linear
+ layer will be initialized using the normal distribution and
+ bais will be set to the specified value.
+ 2. If normal_init_embed_weights is set then weights of embedding
+ layer will be initialized using the normal distribution.
+ 3. If normal_init_proj_weights is set then weights of
+ in_project_weight for MultiHeadAttention initialized using
+ the normal distribution (to be validated).
+ """
+
+ def normal_(data):
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
+ # so that the RNG is consistent with and without FSDP
+ data.copy_(
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
+ )
+
+ if isinstance(module, nn.Linear):
+ normal_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ if isinstance(module, MultiheadAttention):
+ normal_(module.q_proj.weight.data)
+ normal_(module.k_proj.weight.data)
+ normal_(module.v_proj.weight.data)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/modules.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..7772b2d7448edca5ec2aa5fcd6278429b98e35a4
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/modules.py
@@ -0,0 +1,219 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import warnings
+import torch
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class GradMultiply(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, scale):
+ ctx.scale = scale
+ res = x.new(x)
+ return res
+
+ @staticmethod
+ def backward(ctx, grad):
+ return grad * ctx.scale, None
+
+
+class SamePad(nn.Module):
+ def __init__(self, kernel_size, causal=False):
+ super().__init__()
+ if causal:
+ self.remove = kernel_size - 1
+ else:
+ self.remove = 1 if kernel_size % 2 == 0 else 0
+
+ def forward(self, x):
+ if self.remove > 0:
+ x = x[:, :, : -self.remove]
+ return x
+
+
+class Swish(nn.Module):
+ def __init__(self):
+ super(Swish, self).__init__()
+ self.act = torch.nn.Sigmoid()
+
+ def forward(self, x):
+ return x * self.act(x)
+
+
+class GLU_Linear(nn.Module):
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
+ super(GLU_Linear, self).__init__()
+
+ self.glu_type = glu_type
+ self.output_dim = output_dim
+
+ if glu_type == "sigmoid":
+ self.glu_act = torch.nn.Sigmoid()
+ elif glu_type == "swish":
+ self.glu_act = Swish()
+ elif glu_type == "relu":
+ self.glu_act = torch.nn.ReLU()
+ elif glu_type == "gelu":
+ self.glu_act = torch.nn.GELU()
+
+ if bias_in_glu:
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
+ else:
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
+
+ def forward(self, x):
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
+ x = self.linear(x)
+
+ if self.glu_type == "bilinear":
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
+ else:
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
+
+ return x
+
+
+def gelu_accurate(x):
+ if not hasattr(gelu_accurate, "_a"):
+ gelu_accurate._a = math.sqrt(2 / math.pi)
+ return (
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+ )
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.gelu(x.float()).type_as(x)
+
+
+def get_activation_fn(activation: str):
+ """Returns the activation function corresponding to `activation`"""
+
+ if activation == "relu":
+ return F.relu
+ elif activation == "gelu":
+ return gelu
+ elif activation == "gelu_fast":
+ warnings.warn(
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
+ )
+ return gelu_accurate
+ elif activation == "gelu_accurate":
+ return gelu_accurate
+ elif activation == "tanh":
+ return torch.tanh
+ elif activation == "linear":
+ return lambda x: x
+ elif activation == "glu":
+ return lambda x: x
+ else:
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+def quant_noise(module, p, block_size):
+ """
+ Wraps modules and applies quantization noise to the weights for
+ subsequent quantization with Iterative Product Quantization as
+ described in "Training with Quantization Noise for Extreme Model Compression"
+
+ Args:
+ - module: nn.Module
+ - p: amount of Quantization Noise
+ - block_size: size of the blocks for subsequent quantization with iPQ
+
+ Remarks:
+ - Module weights must have the right sizes wrt the block size
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
+ - For more detail on how to quantize by blocks with convolutional weights,
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
+ - We implement the simplest form of noise here as stated in the paper
+ which consists in randomly dropping blocks
+ """
+
+ # if no quantization noise, don't register hook
+ if p <= 0:
+ return module
+
+ # supported modules
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+
+ # test whether module.weight has the right sizes wrt block_size
+ is_conv = module.weight.ndim == 4
+
+ # 2D matrix
+ if not is_conv:
+ assert (
+ module.weight.size(1) % block_size == 0
+ ), "Input features must be a multiple of block sizes"
+
+ # 4D matrix
+ else:
+ # 1x1 convolutions
+ if module.kernel_size == (1, 1):
+ assert (
+ module.in_channels % block_size == 0
+ ), "Input channels must be a multiple of block sizes"
+ # regular convolutions
+ else:
+ k = module.kernel_size[0] * module.kernel_size[1]
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
+
+ def _forward_pre_hook(mod, input):
+ # no noise for evaluation
+ if mod.training:
+ if not is_conv:
+ # gather weight and sizes
+ weight = mod.weight
+ in_features = weight.size(1)
+ out_features = weight.size(0)
+
+ # split weight matrix into blocks and randomly drop selected blocks
+ mask = torch.zeros(
+ in_features // block_size * out_features, device=weight.device
+ )
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+
+ else:
+ # gather weight and sizes
+ weight = mod.weight
+ in_channels = mod.in_channels
+ out_channels = mod.out_channels
+
+ # split weight matrix into blocks and randomly drop selected blocks
+ if mod.kernel_size == (1, 1):
+ mask = torch.zeros(
+ int(in_channels // block_size * out_channels),
+ device=weight.device,
+ )
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+ else:
+ mask = torch.zeros(
+ weight.size(0), weight.size(1), device=weight.device
+ )
+ mask.bernoulli_(p)
+ mask = (
+ mask.unsqueeze(2)
+ .unsqueeze(3)
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+ )
+
+ # scale weights and apply mask
+ mask = mask.to(
+ torch.bool
+ ) # x.bool() is not currently supported in TorchScript
+ s = 1 / (1 - p)
+ mod.weight.data = s * weight.masked_fill(mask, 0)
+
+ module.register_forward_pre_hook(_forward_pre_hook)
+ return module
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/quantizer.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5370d02e7f8f10723128b9bbc34afd3342cfcd86
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/beats/quantizer.py
@@ -0,0 +1,215 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on VQGAN code bases
+# https://github.com/CompVis/taming-transformers
+# --------------------------------------------------------'
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as distributed
+
+try:
+ from einops import rearrange, repeat
+except ImportError:
+ pass
+
+
+def l2norm(t):
+ return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg, new, decay):
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def sample_vectors(samples, num):
+ num_samples, device = samples.shape[0], samples.device
+
+ if num_samples >= num:
+ indices = torch.randperm(num_samples, device=device)[:num]
+ else:
+ indices = torch.randint(0, num_samples, (num,), device=device)
+
+ return samples[indices]
+
+
+def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ if use_cosine_sim:
+ dists = samples @ means.t()
+ else:
+ diffs = rearrange(samples, 'n d -> n () d') \
+ - rearrange(means, 'c d -> () c d')
+ dists = -(diffs ** 2).sum(dim=-1)
+
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+ new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
+ new_means = new_means / bins_min_clamped[..., None]
+
+ if use_cosine_sim:
+ new_means = l2norm(new_means)
+
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
+
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
+ super().__init__()
+ self.num_tokens = num_tokens
+ self.codebook_dim = codebook_dim
+ self.decay = decay
+ self.eps = eps
+ if codebook_init_path == '':
+ if not kmeans_init:
+ weight = torch.randn(num_tokens, codebook_dim)
+ weight = l2norm(weight)
+ else:
+ weight = torch.zeros(num_tokens, codebook_dim)
+ self.register_buffer('initted', torch.Tensor([not kmeans_init]))
+ else:
+ print(f"load init codebook weight from {codebook_init_path}")
+ codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
+ weight = codebook_ckpt_weight.clone()
+ self.register_buffer('initted', torch.Tensor([True]))
+
+ self.weight = nn.Parameter(weight, requires_grad=False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+ # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
+ self.update = True
+
+ @torch.jit.ignore
+ def init_embed_(self, data):
+ if self.initted:
+ return
+ print("Performing Kemans init for codebook")
+ embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
+ self.weight.data.copy_(embed)
+ self.cluster_size.data.copy_(cluster_size)
+ self.initted.data.copy_(torch.Tensor([True]))
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ )
+ # normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
+ self.weight.data.copy_(embed_normalized)
+
+
+def norm_ema_inplace(moving_avg, new, decay):
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+ moving_avg.data.copy_(l2norm(moving_avg.data))
+
+
+class NormEMAVectorQuantizer(nn.Module):
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+ statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
+ super().__init__()
+ self.codebook_dim = embedding_dim
+ self.num_tokens = n_embed
+ self.beta = beta
+ self.decay = decay
+
+ # learnable = True if orthogonal_reg_weight > 0 else False
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
+
+ self.statistic_code_usage = statistic_code_usage
+ if statistic_code_usage:
+ self.register_buffer('cluster_size', torch.zeros(n_embed))
+ if distributed.is_available() and distributed.is_initialized():
+ print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
+ self.all_reduce_fn = distributed.all_reduce
+ else:
+ self.all_reduce_fn = nn.Identity()
+
+ def reset_cluster_size(self, device):
+ if self.statistic_code_usage:
+ self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
+ self.cluster_size = self.cluster_size.to(device)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z, 'b c h w -> b h w c'
+ # z = rearrange(z, 'b c h w -> b h w c')
+ # z = z.transpose(1, 2)
+ z = l2norm(z)
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ self.embedding.init_embed_(z_flattened)
+
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+
+ if not self.training:
+ with torch.no_grad():
+ cluster_size = encodings.sum(0)
+ self.all_reduce_fn(cluster_size)
+ ema_inplace(self.cluster_size, cluster_size, self.decay)
+
+ if self.training and self.embedding.update:
+ # EMA cluster size
+
+ bins = encodings.sum(0)
+ self.all_reduce_fn(bins)
+
+ # self.embedding.cluster_size_ema_update(bins)
+ ema_inplace(self.cluster_size, bins, self.decay)
+
+ zero_mask = (bins == 0)
+ bins = bins.masked_fill(zero_mask, 1.)
+
+ embed_sum = z_flattened.t() @ encodings
+ self.all_reduce_fn(embed_sum)
+
+ embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
+ embed_normalized = l2norm(embed_normalized)
+
+ embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
+ embed_normalized)
+ norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ # z_q, 'b h w c -> b c h w'
+ # z_q = rearrange(z_q, 'b h w c -> b c h w')
+ # z_q = z_q.transpose(1, 2)
+ return z_q, loss, encoding_indices
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/__init__.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/builder.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..05031ae973b3079f463e24292862de2b411c0319
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/builder.py
@@ -0,0 +1,119 @@
+from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
+
+import logging
+logger = logging.getLogger(__name__)
+
+def build_bert(model_config, pretrain, checkpoint, encoder_width=None):
+ """build text encoder.
+
+ Args:
+ model_config (dict): model config.
+ pretrain (bool): Whether to do pretrain or finetuning.
+ checkpoint (bool): whether to do gradient_checkpointing.
+
+ Returns: TODO
+
+ """
+ bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
+ if encoder_width is None:
+ bert_config.encoder_width = model_config.vision_encoder.d_model
+ else:
+ bert_config.encoder_width = encoder_width
+
+ bert_config.gradient_checkpointing = checkpoint
+ bert_config.fusion_layer = model_config.text_encoder.fusion_layer
+
+ if not model_config.multimodal.enable:
+ bert_config.fusion_layer = bert_config.num_hidden_layers
+
+ if pretrain:
+ try:
+ text_encoder, loading_info = BertForMaskedLM.from_pretrained(
+ model_config.text_encoder.pretrained,
+ config=bert_config,
+ output_loading_info=True,
+ local_files_only=True
+ )
+ except:
+ text_encoder, loading_info = BertForMaskedLM.from_pretrained(
+ model_config.text_encoder.pretrained,
+ config=bert_config,
+ output_loading_info=True,
+ local_files_only=False
+ )
+ else:
+ try:
+ text_encoder, loading_info = BertModel.from_pretrained(
+ model_config.text_encoder.pretrained,
+ config=bert_config,
+ add_pooling_layer=False,
+ output_loading_info=True,
+ local_files_only=True
+ )
+ except:
+ text_encoder, loading_info = BertModel.from_pretrained(
+ model_config.text_encoder.pretrained,
+ config=bert_config,
+ add_pooling_layer=False,
+ output_loading_info=True,
+ local_files_only=False
+ )
+
+ return text_encoder
+
+
+def build_bert_decoder(model_config, checkpoint, only_fusion_layer=True):
+ """build text decoder the same as the multimodal encoder.
+
+ Args:
+ model_config (dict): model config.
+ pretrain (bool): Whether to do pretrain or finetuning.
+ checkpoint (bool): whether to do gradient_checkpointing.
+
+ Returns: TODO
+
+ """
+ bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
+ bert_config.encoder_width = model_config.vision_encoder.d_model
+ bert_config.gradient_checkpointing = checkpoint
+
+ bert_config.fusion_layer = 0
+
+ if only_fusion_layer:
+ bert_config.num_hidden_layers = (
+ bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
+ )
+
+ text_decoder, loading_info = BertLMHeadModel.from_pretrained(
+ model_config.text_encoder.pretrained,
+ config=bert_config,
+ output_loading_info=True,
+ local_files_only=True
+ )
+
+ return text_decoder
+
+def build_lm_bert_decoder(model_config, checkpoint):
+ """build text decoder the same as the multimodal encoder.
+
+ Args:
+ model_config (dict): model config.
+ pretrain (bool): Whether to do pretrain or finetuning.
+ checkpoint (bool): whether to do gradient_checkpointing.
+
+ Returns: TODO
+
+ """
+ bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
+ bert_config.encoder_width = model_config.vision_encoder.d_model
+ bert_config.gradient_checkpointing = checkpoint
+ bert_config.fusion_layer = model_config.text_encoder.fusion_layer
+
+ text_decoder, loading_info = BertLMHeadModel.from_pretrained(
+ model_config.text_encoder.pretrained,
+ config=bert_config,
+ output_loading_info=True,
+ local_files_only=True
+ )
+
+ return text_decoder
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/med.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/med.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b6a5f475e02dd75dabd012a016d99e5d1374f5c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/med.py
@@ -0,0 +1,1270 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+ Based on huggingface code base
+ https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+from .xbert import BertAttention, BertIntermediate, BertOutput, BertPooler, BertOnlyMLMHead
+
+logging.set_verbosity_error()
+logger = logging.get_logger(__name__)
+
+
+class BaseEncoder(nn.Module):
+ """
+ Base class for primitive encoders, such as ViT, TimeSformer, etc.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward_features(self, samples, **kwargs):
+ raise NotImplementedError
+
+ @property
+ def device(self):
+ return list(self.parameters())[0].device
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+ )
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size
+ )
+
+ if config.add_type_embeddings:
+ self.token_type_embeddings = nn.Embedding(
+ config.type_vocab_size, config.hidden_size
+ )
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
+ )
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+
+ self.config = config
+
+ def forward(
+ self,
+ input_ids=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ past_key_values_length=0,
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[
+ :, past_key_values_length : seq_length + past_key_values_length
+ ]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if token_type_ids is not None:
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ else:
+ embeddings = inputs_embeds
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+ config, "embedding_size"
+ ):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
+ )
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(-1, 1)
+ position_ids_r = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(
+ distance + self.max_position_embeddings - 1
+ )
+ positional_embedding = positional_embedding.to(
+ dtype=query_layer.dtype
+ ) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ relative_position_scores_key = torch.einsum(
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
+ )
+ attention_scores = (
+ attention_scores
+ + relative_position_scores_query
+ + relative_position_scores_key
+ )
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
+ )
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+
+ # compatibility for ALBEF and BLIP
+ if hasattr(self.config, 'cross_freq') and self.config.cross_freq > 0:
+ self.fusion_layer = self.config.num_hidden_layers
+ add_cross_attention = self.config.add_cross_attention and (layer_num % self.config.cross_freq == 0)
+ elif hasattr(self.config, 'fusion_layer'):
+ # ALBEF & ALPRO
+ self.fusion_layer = self.config.fusion_layer
+ add_cross_attention = (
+ self.fusion_layer <= layer_num and self.config.add_cross_attention
+ )
+
+ self.fusion_layer = self.fusion_layer
+ else:
+ # BLIP
+ self.fusion_layer = self.config.num_hidden_layers
+ add_cross_attention = self.config.add_cross_attention
+
+ if hasattr(self.config, 'must_fusion_layer') and layer_num in config.must_fusion_layer:
+ add_cross_attention = True
+
+
+ logger.info(f'layer_num: {layer_num} fusion_layer: {self.fusion_layer}, add_cross_attention: {add_cross_attention}')
+
+ # if self.config.add_cross_attention:
+ if add_cross_attention:
+ self.crossattention = BertAttention(
+ config, is_cross_attention=self.config.add_cross_attention
+ )
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ mode=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = (
+ past_key_value[:2] if past_key_value is not None else None
+ )
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
+ # TODO line 482 in albef/models/xbert.py
+ # compatibility for ALBEF and BLIP
+ if mode in ["multimodal", "fusion"] and hasattr(self, "crossattention"):
+ assert (
+ encoder_hidden_states is not None
+ ), "encoder_hidden_states must be given for cross-attention layers"
+
+ if isinstance(encoder_hidden_states, list):
+ raise NotImplementedError
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states[
+ (self.layer_num - self.fusion_layer)
+ % len(encoder_hidden_states)
+ ],
+ encoder_attention_mask[
+ (self.layer_num - self.fusion_layer)
+ % len(encoder_hidden_states)
+ ],
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1]
+
+ else:
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = (
+ outputs + cross_attention_outputs[1:-1]
+ ) # add cross attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output,
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
+ )
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode="multimodal",
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = (
+ () if output_attentions and self.config.add_cross_attention else None
+ )
+
+ next_decoder_cache = () if use_cache else None
+
+ # try:
+ # # ALBEF
+ # fusion_layer = self.config.fusion_layer
+ # except AttributeError:
+ # BLIP
+ fusion_layer = self.config.num_hidden_layers
+
+ if mode == "text":
+ start_layer = 0
+ # output_layer = self.config.fusion_layer
+ output_layer = fusion_layer
+
+ elif mode == "fusion":
+ raise NotImplementedError
+ # start_layer = self.config.fusion_layer
+ start_layer = fusion_layer
+ output_layer = self.config.num_hidden_layers
+
+ elif mode == "multimodal":
+ start_layer = 0
+ output_layer = self.config.num_hidden_layers
+
+ # compatibility for ALBEF and BLIP
+ # for i in range(self.config.num_hidden_layers):
+ for i in range(start_layer, output_layer):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ # TODO pay attention to this.
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warn(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ mode=mode,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ mode=mode,
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ # NOTE maybe wrong
+ # if hasattr(layer_module, "crossattention"):
+ # # all_cross_attentions = all_cross_attentions + (layer_outputs[3], )
+ # all_cross_attentions = all_cross_attentions + (layer_outputs[4 - offset],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_extended_attention_mask(
+ self,
+ attention_mask: Tensor,
+ input_shape: Tuple[int],
+ device: device,
+ is_decoder: bool,
+ ) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = (
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
+ <= seq_ids[None, :, None]
+ )
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones(
+ (batch_size, seq_length, prefix_seq_len),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = (
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ )
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=self.dtype
+ ) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode="multimodal",
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError(
+ "You have to specify either input_ids or inputs_embeds or encoder_embeds"
+ )
+
+ # past_key_values_length
+ past_key_values_length = (
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ )
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ ((batch_size, seq_length + past_key_values_length)), device=device
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, input_shape, device, is_decoder
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
+ 0
+ ].size()
+ else:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
+ ]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = (
+ self.pooler(sequence_output) if self.pooler is not None else None
+ )
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class BertForMaskedLM(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ # token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode="multimodal",
+ soft_labels=None,
+ alpha=0,
+ return_logits=False,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ # token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_embeds=encoder_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+ )
+
+ if soft_labels is not None:
+ loss_distill = -torch.sum(
+ F.log_softmax(prediction_scores, dim=-1) * soft_labels, dim=-1
+ )
+ loss_distill = loss_distill[labels != -100].mean()
+ masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return (
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+ )
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, attention_mask=None, **model_kwargs
+ ):
+ input_shape = input_ids.shape
+ effective_batch_size = input_shape[0]
+
+ # add a dummy token
+ assert (
+ self.config.pad_token_id is not None
+ ), "The PAD token should be defined for generation"
+ attention_mask = torch.cat(
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
+ dim=-1,
+ )
+ dummy_token = torch.full(
+ (effective_batch_size, 1),
+ self.config.pad_token_id,
+ dtype=torch.long,
+ device=input_ids.device,
+ )
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction="mean",
+ mode="multimodal",
+ soft_labels=None,
+ alpha=0,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
+ labels.view(-1),
+ )
+ if reduction == "none":
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+ if soft_labels is not None:
+ loss_distill = -torch.sum(
+ F.log_softmax(shifted_prediction_scores, dim=-1) * soft_labels, dim=-1
+ )
+ loss_distill = (loss_distill * (labels != -100)).sum(1)
+ lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past=None, attention_mask=None, **model_kwargs
+ ):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx) for past_state in layer_past
+ ),
+ )
+ return reordered_past
+
+
+class XBertLMHeadDecoder(BertLMHeadModel):
+ """
+ This class decouples the decoder forward logic from the VL model.
+ In this way, different VL models can share this decoder as long as
+ they feed encoder_embeds as required.
+ """
+
+ @classmethod
+ def from_config(cls, med_config_path, pretrained=None):
+
+ med_config = BertConfig.from_json_file(med_config_path)
+
+ if pretrained is not None:
+ return cls.from_pretrained(pretrained, config=med_config, local_files_only=True)
+ else:
+ return cls(config=med_config)
+
+ def generate_from_encoder(
+ self,
+ tokenized_prompt,
+ visual_embeds,
+ sep_token_id,
+ pad_token_id,
+ use_nucleus_sampling=False,
+ num_beams=3,
+ max_length=30,
+ min_length=10,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ **kwargs
+ ):
+
+ if not use_nucleus_sampling:
+ num_beams = num_beams
+ visual_embeds = visual_embeds.repeat_interleave(num_beams, dim=0)
+
+ image_atts = torch.ones(visual_embeds.size()[:-1], dtype=torch.long).to(
+ self.device
+ )
+
+ model_kwargs = {
+ "encoder_hidden_states": visual_embeds,
+ "encoder_attention_mask": image_atts,
+ }
+
+ if use_nucleus_sampling:
+ # nucleus sampling
+ outputs = self.generate(
+ input_ids=tokenized_prompt.input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=True,
+ top_p=top_p,
+ num_return_sequences=1,
+ eos_token_id=sep_token_id,
+ pad_token_id=pad_token_id,
+ repetition_penalty=1.1,
+ **model_kwargs
+ )
+ else:
+ # beam search
+ outputs = self.generate(
+ input_ids=tokenized_prompt.input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ num_beams=num_beams,
+ eos_token_id=sep_token_id,
+ pad_token_id=pad_token_id,
+ repetition_penalty=repetition_penalty,
+ **model_kwargs
+ )
+
+ return outputs
+
+
+class XBertEncoder(BertModel, BaseEncoder):
+ @classmethod
+ def from_config(cls, med_config_path, pretrained=None, fusion_layer=None):
+
+ med_config = BertConfig.from_json_file(med_config_path)
+ if fusion_layer is not None:
+ med_config.fusion_layer = fusion_layer
+
+ if pretrained is not None:
+ return cls.from_pretrained(
+ pretrained, config=med_config, add_pooling_layer=False, local_files_only=True)
+ else:
+ return cls(config=med_config, add_pooling_layer=False)
+
+ def forward_automask(self, tokenized_text, visual_embeds, **kwargs):
+ image_atts = torch.ones(visual_embeds.size()[:-1], dtype=torch.long).to(
+ self.device
+ )
+
+ text = tokenized_text
+ text_output = super().forward(
+ text.input_ids,
+ attention_mask=text.attention_mask,
+ encoder_hidden_states=visual_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ return text_output
+
+ def forward_text(self, tokenized_text, **kwargs):
+ text = tokenized_text
+ token_type_ids = kwargs.get("token_type_ids", None)
+
+ text_output = super().forward(
+ text.input_ids,
+ attention_mask=text.attention_mask,
+ token_type_ids=token_type_ids,
+ return_dict=True,
+ mode="text",
+ )
+
+ return text_output
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/tokenization_bert.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/tokenization_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..66e8d8e55b738766d045fe85dac73417b175ffd7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/tokenization_bert.py
@@ -0,0 +1,546 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Bert."""
+
+
+import collections
+import os
+import unicodedata
+from typing import List, Optional, Tuple
+
+from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
+ "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
+ "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
+ "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
+ "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
+ "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
+ "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
+ "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
+ "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
+ "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
+ "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
+ "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
+ "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
+ "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
+ }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "bert-base-uncased": 512,
+ "bert-large-uncased": 512,
+ "bert-base-cased": 512,
+ "bert-large-cased": 512,
+ "bert-base-multilingual-uncased": 512,
+ "bert-base-multilingual-cased": 512,
+ "bert-base-chinese": 512,
+ "bert-base-german-cased": 512,
+ "bert-large-uncased-whole-word-masking": 512,
+ "bert-large-cased-whole-word-masking": 512,
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
+ "bert-base-cased-finetuned-mrpc": 512,
+ "bert-base-german-dbmdz-cased": 512,
+ "bert-base-german-dbmdz-uncased": 512,
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
+ "wietsedv/bert-base-dutch-cased": 512,
+}
+
+PRETRAINED_INIT_CONFIGURATION = {
+ "bert-base-uncased": {"do_lower_case": True},
+ "bert-large-uncased": {"do_lower_case": True},
+ "bert-base-cased": {"do_lower_case": False},
+ "bert-large-cased": {"do_lower_case": False},
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
+ "bert-base-multilingual-cased": {"do_lower_case": False},
+ "bert-base-chinese": {"do_lower_case": False},
+ "bert-base-german-cased": {"do_lower_case": False},
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
+}
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class BertTokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a BERT tokenizer. Based on WordPiece.
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
+ Users should refer to this superclass for more information regarding those methods.
+ Args:
+ vocab_file (:obj:`str`):
+ File containing the vocabulary.
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to lowercase the input when tokenizing.
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to do basic tokenization before WordPiece.
+ never_split (:obj:`Iterable`, `optional`):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ :obj:`do_basic_tokenize=True`
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to tokenize Chinese characters.
+ This should likely be deactivated for Japanese (see this `issue
+ `__).
+ strip_accents: (:obj:`bool`, `optional`):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for :obj:`lowercase` (as in the original BERT).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=True,
+ do_basic_tokenize=True,
+ never_split=None,
+ unk_token="[UNK]",
+ sep_token="[SEP]",
+ pad_token="[PAD]",
+ cls_token="[CLS]",
+ mask_token="[MASK]",
+ tokenize_chinese_chars=True,
+ strip_accents=None,
+ **kwargs
+ ):
+ super().__init__(
+ do_lower_case=do_lower_case,
+ do_basic_tokenize=do_basic_tokenize,
+ never_split=never_split,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ pad_token=pad_token,
+ cls_token=cls_token,
+ mask_token=mask_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ **kwargs,
+ )
+
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+ vocab_file)
+ )
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict(
+ [(ids, tok) for tok, ids in self.vocab.items()])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(
+ do_lower_case=do_lower_case,
+ never_split=never_split,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ )
+ self.wordpiece_tokenizer = WordpieceTokenizer(
+ vocab=self.vocab, unk_token=self.unk_token)
+
+ @property
+ def do_lower_case(self):
+ return self.basic_tokenizer.do_lower_case
+
+ @property
+ def vocab_size(self):
+ return len(self.vocab)
+
+ def get_vocab(self):
+ return dict(self.vocab, **self.added_tokens_encoder)
+
+ def _tokenize(self, text):
+ split_tokens = []
+ if self.do_basic_tokenize:
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+
+ # If the token is part of the never_split set
+ if token in self.basic_tokenizer.never_split:
+ split_tokens.append(token)
+ else:
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
+ else:
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """ Converts a sequence of tokens (string) in a single string. """
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A BERT sequence has the following format:
+ - single sequence: ``[CLS] X ``
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer ``prepare_for_model`` method.
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+ Returns:
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formatted with special tokens for the model."
+ )
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
+ pair mask has the following format:
+ ::
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
+ Args:
+ token_ids_0 (:obj:`List[int]`):
+ List of IDs.
+ token_ids_1 (:obj:`List[int]`, `optional`):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
+ sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ index = 0
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") +
+ VOCAB_FILES_NAMES["vocab_file"]
+ )
+ else:
+ vocab_file = (filename_prefix +
+ "-" if filename_prefix else "") + save_directory
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!".format(
+ vocab_file)
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+
+class BasicTokenizer(object):
+ """
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+ Args:
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to lowercase the input when tokenizing.
+ never_split (:obj:`Iterable`, `optional`):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ :obj:`do_basic_tokenize=True`
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to tokenize Chinese characters.
+ This should likely be deactivated for Japanese (see this `issue
+ `__).
+ strip_accents: (:obj:`bool`, `optional`):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for :obj:`lowercase` (as in the original BERT).
+ """
+
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
+ if never_split is None:
+ never_split = []
+ self.do_lower_case = do_lower_case
+ self.never_split = set(never_split)
+ self.tokenize_chinese_chars = tokenize_chinese_chars
+ self.strip_accents = strip_accents
+
+ def tokenize(self, text, never_split=None):
+ """
+ Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
+ WordPieceTokenizer.
+ Args:
+ **never_split**: (`optional`) list of str
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+ :func:`PreTrainedTokenizer.tokenize`) List of token not to split.
+ """
+ # union() returns a new set by concatenating the two sets.
+ never_split = self.never_split.union(
+ set(never_split)) if never_split else self.never_split
+ text = self._clean_text(text)
+
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ if self.tokenize_chinese_chars:
+ text = self._tokenize_chinese_chars(text)
+ orig_tokens = whitespace_tokenize(text)
+ split_tokens = []
+ for token in orig_tokens:
+ if token not in never_split:
+ if self.do_lower_case:
+ token = token.lower()
+ if self.strip_accents is not False:
+ token = self._run_strip_accents(token)
+ elif self.strip_accents:
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text, never_split=None):
+ """Splits punctuation on a piece of text."""
+ if never_split is not None and text in never_split:
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like the all of the other languages.
+ if (
+ (cp >= 0x4E00 and cp <= 0x9FFF)
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
+ or (cp >= 0xF900 and cp <= 0xFAFF)
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
+ ): #
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+class WordpieceTokenizer(object):
+ """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+ tokenization using the given vocabulary.
+ For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through `BasicTokenizer`.
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ return output_tokens
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/xbert.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/xbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..a350ff40e23639da1a7ccd043811203de07b7ad4
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/bert/xbert.py
@@ -0,0 +1,2170 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BERT model. """
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from torch import Tensor, device, dtype, nn
+from torch.nn import CrossEntropyLoss, MSELoss
+from transformers.activations import ACT2FN
+# from transformers.models.bert.configuration_bert import BertConfig
+from transformers.configuration_utils import PretrainedConfig
+from transformers.file_utils import (ModelOutput, add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions, MaskedLMOutput,
+ MultipleChoiceModelOutput, NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput, SequenceClassifierOutput,
+ TokenClassifierOutput)
+from transformers.modeling_utils import (PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer)
+from transformers.utils import logging
+
+transformers.logging.set_verbosity_error()
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "BertConfig"
+_TOKENIZER_FOR_DOC = "BertTokenizer"
+
+BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "bert-base-uncased",
+ "bert-large-uncased",
+ "bert-base-cased",
+ "bert-large-cased",
+ "bert-base-multilingual-uncased",
+ "bert-base-multilingual-cased",
+ "bert-base-chinese",
+ "bert-base-german-cased",
+ "bert-large-uncased-whole-word-masking",
+ "bert-large-cased-whole-word-masking",
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
+ "bert-large-cased-whole-word-masking-finetuned-squad",
+ "bert-base-cased-finetuned-mrpc",
+ "bert-base-german-dbmdz-cased",
+ "bert-base-german-dbmdz-uncased",
+ "cl-tohoku/bert-base-japanese",
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
+ "cl-tohoku/bert-base-japanese-char",
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
+ "TurkuNLP/bert-base-finnish-cased-v1",
+ "TurkuNLP/bert-base-finnish-uncased-v1",
+ "wietsedv/bert-base-dutch-cased",
+ # See all BERT models at https://huggingface.co/models?filter=bert
+]
+
+
+class BertConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
+ instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the BERT
+ [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ classifier_dropout (`float`, *optional*):
+ The dropout ratio for the classification head.
+
+ Examples:
+
+ ```python
+ >>> from transformers import BertModel, BertConfig
+
+ >>> # Initializing a BERT bert-base-uncased style configuration
+ >>> configuration = BertConfig()
+
+ >>> # Initializing a model from the bert-base-uncased style configuration
+ >>> model = BertModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "bert"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ position_embedding_type="absolute",
+ use_cache=True,
+ classifier_dropout=None,
+ cross_module="ca",
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.position_embedding_type = position_embedding_type
+ self.use_cache = use_cache
+ self.classifier_dropout = classifier_dropout
+ self.cross_module = cross_module
+
+
+def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+ # which are not required for using pretrained model
+ if any(
+ n
+ in [
+ "adam_v",
+ "adam_m",
+ "AdamWeightDecayOptimizer",
+ "AdamWeightDecayOptimizer_1",
+ "global_step",
+ ]
+ for n in name
+ ):
+ logger.info("Skipping {}".format("/".join(name)))
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info("Skipping {}".format("/".join(name)))
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ assert (
+ pointer.shape == array.shape
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+ except AssertionError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info("Initialize PyTorch weight {}".format(name))
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+ )
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size
+ )
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
+ )
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self,
+ input_ids=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ past_key_values_length=0,
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[
+ :, past_key_values_length : seq_length + past_key_values_length
+ ]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(
+ input_shape, dtype=torch.long, device=self.position_ids.device
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+ config, "embedding_size"
+ ):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
+ )
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(-1, 1)
+ position_ids_r = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(
+ distance + self.max_position_embeddings - 1
+ )
+ positional_embedding = positional_embedding.to(
+ dtype=query_layer.dtype
+ ) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ relative_position_scores_key = torch.einsum(
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
+ )
+ attention_scores = (
+ attention_scores
+ + relative_position_scores_query
+ + relative_position_scores_key
+ )
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ # added `attention_scores` to return tuple
+ outputs = (
+ (context_layer, attention_probs, attention_scores)
+ if output_attentions
+ else (context_layer,)
+ )
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+
+ self.self = BertSelfAttention(config, is_cross_attention)
+
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads,
+ self.self.num_attention_heads,
+ self.self.attention_head_size,
+ self.pruned_heads,
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ # add attentions if we output them
+ outputs = (attention_output,) + self_outputs[1:]
+ return outputs # (context_layer, attention_probs, attention_scores, past_key_value,)
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+
+ self.has_cross_attention = layer_num >= config.fusion_layer
+ if self.has_cross_attention:
+ self.crossattention = BertAttention(config, is_cross_attention=True)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
+ if self.has_cross_attention:
+ assert (
+ encoder_hidden_states is not None
+ ), "encoder_hidden_states must be given for cross-attention layers"
+
+ if type(encoder_hidden_states) == list:
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states[
+ (self.layer_num - self.config.fusion_layer)
+ % len(encoder_hidden_states)
+ ],
+ encoder_attention_mask[
+ (self.layer_num - self.config.fusion_layer)
+ % len(encoder_hidden_states)
+ ],
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1]
+
+ else:
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
+ attention_output = cross_attention_outputs[0]
+ # add cross attentions if we output attention weights
+ outputs = outputs + cross_attention_outputs[1:-1]
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output,
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
+ )
+ logger.info(f"build bert with cross_module: {config.cross_module}")
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode="multi_modal",
+ normalize_attention=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ # all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_cross_attentions = () if output_attentions else None
+
+ next_decoder_cache = () if use_cache else None
+
+ if (
+ mode == "text" or mode == "temporal"
+ ): # temporal is added and used for temporal att module.
+ start_layer = 0
+ output_layer = self.config.fusion_layer
+
+ elif mode == "fusion":
+ start_layer = self.config.fusion_layer
+ output_layer = self.config.num_hidden_layers
+
+ elif mode == "multi_modal":
+ start_layer = 0
+ output_layer = self.config.num_hidden_layers
+
+ for i in range(start_layer, output_layer):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+ if use_cache:
+ logger.warn(
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+ "`use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ use_reentrant=False,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ # whether to output normalized attention,
+ # note for unnormalized attention, there is a mask added
+ offset = int(normalize_attention)
+ # all_self_attentions = all_self_attentions + (layer_outputs[1], )
+ all_self_attentions = all_self_attentions + (layer_outputs[2 - offset],)
+ if hasattr(layer_module, "crossattention"):
+ # all_cross_attentions = all_cross_attentions + (layer_outputs[3], )
+ all_cross_attentions = all_cross_attentions + (layer_outputs[4 - offset],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertOnlyNSPHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, pooled_output):
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return seq_relationship_score
+
+
+class BertPreTrainingHeads(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, sequence_output, pooled_output):
+ prediction_scores = self.predictions(sequence_output)
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return prediction_scores, seq_relationship_score
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ load_tf_weights = load_tf_weights_in_bert
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+@dataclass
+class BertForPreTrainingOutput(ModelOutput):
+ """
+ Output type of :class:`~transformers.BertForPreTraining`.
+ Args:
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
+ (classification) loss.
+ prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+ before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+ sequence_length, sequence_length)`.
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ prediction_logits: torch.FloatTensor = None
+ seq_relationship_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+BERT_START_DOCSTRING = r"""
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+ methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
+ pruning heads etc.)
+ This model is also a PyTorch `torch.nn.Module `__
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
+ general usage and behavior.
+ Parameters:
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+ weights.
+"""
+
+BERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+ Indices of input sequence tokens in the vocabulary.
+ Indices can be obtained using :class:`~transformers.BertTokenizer`. See
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+ details.
+ `What are input IDs? <../glossary.html#input-ids>`__
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ `What are attention masks? <../glossary.html#attention-mask>`__
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+ 1]``:
+ - 0 corresponds to a `sentence A` token,
+ - 1 corresponds to a `sentence B` token.
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
+ config.max_position_embeddings - 1]``.
+ `What are position IDs? <../glossary.html#position-ids>`_
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+ vectors than the model's internal embedding lookup matrix.
+ output_attentions (:obj:`bool`, `optional`):
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+ tensors for more detail.
+ output_hidden_states (:obj:`bool`, `optional`):
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+ more detail.
+ return_dict (:obj:`bool`, `optional`):
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
+ BERT_START_DOCSTRING,
+)
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_extended_attention_mask(
+ self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool
+ ) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = (
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
+ <= seq_ids[None, :, None]
+ )
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones(
+ (batch_size, seq_length, prefix_seq_len),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = (
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ )
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=self.dtype
+ ) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode="multi_modal",
+ normalize_attention=True,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError(
+ "You have to specify either input_ids or inputs_embeds or encoder_embeds"
+ )
+
+ # past_key_values_length
+ past_key_values_length = (
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ )
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ ((batch_size, seq_length + past_key_values_length)), device=device
+ )
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, input_shape, device, is_decoder
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
+ 0
+ ].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
+ ]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ normalize_attention=normalize_attention,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
+ sentence prediction (classification)` head.
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForPreTraining(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config)
+ self.cls = BertPreTrainingHeads(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(
+ BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ )
+ @replace_return_docstrings(
+ output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
+ )
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ next_sentence_label=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
+ (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
+ - 0 indicates sequence B is a continuation of sequence A,
+ - 1 indicates sequence B is a random sequence.
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
+ Used to hide legacy arguments that have been deprecated.
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertForPreTraining
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.prediction_logits
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output, pooled_output = outputs[:2]
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+ total_loss = None
+ if labels is not None and next_sentence_label is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+ )
+ next_sentence_loss = loss_fct(
+ seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)
+ )
+ total_loss = masked_lm_loss + next_sentence_loss
+
+ if not return_dict:
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return BertForPreTrainingOutput(
+ loss=total_loss,
+ prediction_logits=prediction_scores,
+ seq_relationship_logits=seq_relationship_score,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """,
+ BERT_START_DOCSTRING,
+)
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(
+ BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ )
+ @replace_return_docstrings(
+ output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC
+ )
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=True,
+ reduction="mean",
+ mode="multi_modal",
+ normalize_attention=True,
+ soft_labels=None,
+ alpha=0,
+ return_logits=False,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ # logger.info(f"before: {labels.min()}, {labels.max()}")
+ # logger.info(f"before input_ids: {input_ids.min()}, {input_ids.max()}")
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ normalize_attention=normalize_attention,
+ )
+
+
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+ # logger.info(f"new: {labels.min()}, {labels.max()} {prediction_scores.shape}")
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+
+ # logger.info(f"mid2 {labels.min()}, {labels.max()} {prediction_scores.shape}")
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+
+ # logger.info(f"after {self.config.vocab_size}, {labels.min()}, {labels.max()}")
+ loss_fct = CrossEntropyLoss(reduction=reduction)
+ lm_loss = loss_fct(
+ shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+ )
+ if reduction == "none":
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+ if soft_labels is not None:
+ loss_distill = -torch.sum(
+ F.log_softmax(shifted_prediction_scores, dim=1) * soft_labels, dim=-1
+ )
+ loss_distill = (loss_distill * (labels != -100)).sum(1)
+ lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past=None, attention_mask=None, **model_kwargs
+ ):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
+ )
+ return reordered_past
+
+
+@dataclass
+class MaskedLMOutputWithDistill(MaskedLMOutput):
+ loss_aux: Optional[torch.FloatTensor] = None
+ loss_distill: Optional[torch.FloatTensor] = None
+
+
+@add_start_docstrings(
+ """Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING
+)
+class BertForMaskedLM(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def tie_aux_decoder_weights(self, module, aux_modules):
+ """Tie decoder weights of all `aux_modules` to `module`, (not bias)"""
+ for m in aux_modules:
+ m.predictions.decoder.weight = module.predictions.decoder.weight
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode="multi_modal",
+ normalize_attention=True,
+ soft_labels=None,
+ alpha=0,
+ return_logits=False,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_embeds=encoder_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ normalize_attention=normalize_attention,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores
+
+ masked_lm_loss = None
+ masked_lm_loss_aux = 0.0
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+ )
+
+ if soft_labels is not None:
+ loss_distill = -torch.sum(
+ F.log_softmax(prediction_scores, dim=1) * soft_labels, dim=-1
+ )
+ loss_distill = loss_distill[labels != -100].mean()
+ masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ # changed from MaskedLMOutput to MaskedLMOutputWithDistill
+ return MaskedLMOutputWithDistill(
+ loss=masked_lm_loss,
+ loss_aux=masked_lm_loss_aux,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ effective_batch_size = input_shape[0]
+
+ # add a dummy token
+ assert (
+ self.config.pad_token_id is not None
+ ), "The PAD token should be defined for generation"
+ attention_mask = torch.cat(
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
+ )
+ dummy_token = torch.full(
+ (effective_batch_size, 1),
+ self.config.pad_token_id,
+ dtype=torch.long,
+ device=input_ids.device,
+ )
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+@add_start_docstrings(
+ """Bert Model with a `next sentence prediction (classification)` head on top. """,
+ BERT_START_DOCSTRING,
+)
+class BertForNextSentencePrediction(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config)
+ self.cls = BertOnlyNSPHead(config)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(
+ BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ )
+ @replace_return_docstrings(
+ output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
+ )
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ **kwargs,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
+ (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
+ - 0 indicates sequence B is a continuation of sequence A,
+ - 1 indicates sequence B is a random sequence.
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
+ >>> logits = outputs.logits
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
+ """
+
+ if "next_sentence_label" in kwargs:
+ warnings.warn(
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ FutureWarning,
+ )
+ labels = kwargs.pop("next_sentence_label")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ seq_relationship_scores = self.cls(pooled_output)
+
+ next_sentence_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
+
+ if not return_dict:
+ output = (seq_relationship_scores,) + outputs[2:]
+ return (
+ ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
+ )
+
+ return NextSentencePredictorOutput(
+ loss=next_sentence_loss,
+ logits=seq_relationship_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForSequenceClassification(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.num_labels == 1:
+ # We are doing regression
+ loss_fct = MSELoss()
+ loss = loss_fct(logits.view(-1), labels.view(-1))
+ else:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForMultipleChoice(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
+ :obj:`input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = (
+ attention_mask.view(-1, attention_mask.size(-1))
+ if attention_mask is not None
+ else None
+ )
+ token_type_ids = (
+ token_type_ids.view(-1, token_type_ids.size(-1))
+ if token_type_ids is not None
+ else None
+ )
+ position_ids = (
+ position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ )
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForTokenClassification(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
+ 1]``.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss,
+ labels.view(-1),
+ torch.tensor(loss_fct.ignore_index).type_as(labels),
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ BERT_START_DOCSTRING,
+)
+class BertForQuestionAnswering(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions.clamp_(0, ignored_index)
+ end_positions.clamp_(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/__init__.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d3bdba46e3dbc7aaf20bdbf492f6d120b90f50
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/__init__.py
@@ -0,0 +1,4 @@
+from .internvl_clip_vision import internvl_clip_6b
+from .internvideo2 import pretrain_internvideo2_1b_patch14_224, pretrain_internvideo2_6b_patch14_224
+from .internvideo2_clip_vision import InternVideo2
+from .internvideo2_clip_text import LLaMA, Tokenizer
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/flash_attention_class.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/flash_attention_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..04edd18ee4efcd0fd9f50ea38087a4417792c3fa
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/flash_attention_class.py
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+
+from einops import rearrange
+
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+class FlashAttention(nn.Module):
+ """Implement the scaled dot product attention with softmax.
+ Arguments
+ ---------
+ softmax_scale: The temperature to use for the softmax attention.
+ (default: 1/sqrt(d_keys) where d_keys is computed at
+ runtime)
+ attention_dropout: The dropout rate to apply to the attention
+ (default: 0.0)
+ """
+
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+ super().__init__()
+ self.softmax_scale = softmax_scale
+ self.dropout_p = attention_dropout
+
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+ max_s=None, need_weights=False):
+ """Implements the multihead softmax attention.
+ Arguments
+ ---------
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+ if unpadded: (nnz, 3, h, d)
+ key_padding_mask: a bool tensor of shape (B, S)
+ """
+ assert not need_weights
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
+ assert qkv.is_cuda
+
+ if cu_seqlens is None:
+ batch_size = qkv.shape[0]
+ seqlen = qkv.shape[1]
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ max_s = seqlen
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+ device=qkv.device)
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+ output_unpad = flash_attn_varlen_qkvpacked_func(
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
+ indices, batch_size, seqlen),
+ 'b s (h d) -> b s h d', h=nheads)
+ else:
+ assert max_s is not None
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+
+ return output, None
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvideo2.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvideo2.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa09ad023290b286e37adb3379877362cde354e0
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvideo2.py
@@ -0,0 +1,797 @@
+import math
+import logging
+import torch
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from torch import nn
+
+import torch.utils.checkpoint as checkpoint
+from functools import partial
+from einops import rearrange
+
+from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed, interpolate_pos_embed_internvideo2
+from .flash_attention_class import FlashAttention
+
+logger = logging.getLogger(__name__)
+
+try:
+ from flash_attn.modules.mlp import FusedMLP
+except:
+ logger.warn(f'FusedMLP of flash_attn is not installed!!!')
+
+try:
+ from flash_attn.ops.rms_norm import DropoutAddRMSNorm
+except:
+ logger.warn(f'DropoutAddRMSNorm of flash_attn is not installed!!!')
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None, out_dim=None):
+ super().__init__()
+ if out_dim is None:
+ out_dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ assert all_head_dim == dim
+
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.k_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, out_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, k=None, v=None):
+ B, N, C = x.shape
+ N_k = k.shape[1]
+ N_v = v.shape[1]
+
+ q_bias, k_bias, v_bias = None, None, None
+ if self.q_bias is not None:
+ q_bias = self.q_bias
+ k_bias = self.k_bias
+ v_bias = self.v_bias
+
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
+
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class AttentiveBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
+ super().__init__()
+
+ self.norm1_q = norm_layer(dim)
+ self.norm1_k = norm_layer(dim)
+ self.norm1_v = norm_layer(dim)
+ self.cross_attn = CrossAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
+ x_q = self.norm1_q(x_q + pos_q)
+ x_k = self.norm1_k(x_kv + pos_k)
+ x_v = self.norm1_v(x_kv)
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
+
+ return x
+
+
+class AttentionPoolingBlock(AttentiveBlock):
+
+ def forward(self, x):
+ x_q = x.mean(1, keepdim=True)
+ x_kv, pos_q, pos_k = x, 0, 0
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
+ x = x.squeeze(1)
+ return x
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+ self.force_fp32 = force_fp32
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, x):
+ if self.force_fp32:
+ output_type = x.dtype
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
+ return out.to(dtype=output_type)
+ else:
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.use_flash_attn = use_flash_attn
+ if use_flash_attn:
+ self.causal = causal
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
+
+ self.qk_normalization = qk_normalization
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ # print(x.shape, torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ # print(torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ if self.use_fused_rmsnorm:
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
+ else:
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+ )
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, x):
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
+ use_fused_rmsnorm=False):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
+ qk_normalization=qk_normalization,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ self.ls1 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if use_fused_mlp:
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.with_cp = with_cp
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def forward(self, x, residual=None):
+
+ def _inner_forward(x, residual=None):
+ if self.use_fused_rmsnorm:
+ x, residual = self.norm1(x, residual)
+ x = self.drop_path1(self.ls1(self.attn(x)))
+ x, residual = self.norm2(x, residual)
+ x = self.drop_path2(self.ls2(self.mlp(x)))
+ return x, residual
+ else:
+ assert residual is None
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+ if self.with_cp:
+ # print(f"\033[31m use_checkpoint [0m")
+ return checkpoint.checkpoint(_inner_forward, x, residual)
+ else:
+ return _inner_forward(x, residual=residual)
+
+
+class PatchEmbed(nn.Module):
+ """ 3D Image to Patch Embedding
+ """
+
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
+ num_frames=8, tubelet_size=1, norm_layer=None
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (
+ num_frames // tubelet_size,
+ img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1]
+ ) # (T, H, W)
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
+ self.num_img_patches = self.grid_size[1] * self.grid_size[2]
+
+ self.proj = nn.Conv3d(
+ in_channels=in_chans, out_channels=embed_dim,
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
+ stride=(tubelet_size, patch_size[0], patch_size[1])
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C
+ x = self.norm(x)
+ return x
+
+
+class Linear_Decoder(nn.Module):
+ def __init__(self, in_channels=1408, out_channels=3200,
+ norm_layer=nn.LayerNorm, clip_norm_type='l2'):
+ super().__init__()
+ self.clip_norm_type = clip_norm_type
+ logger.info(f'Normalization Type: {clip_norm_type}')
+
+ self.head = nn.Linear(in_channels, out_channels)
+ self.norm = norm_layer(out_channels)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(self, x):
+ x = self.norm(self.head(x))
+
+ if self.clip_norm_type == 'l2':
+ x = x / x.norm(dim=-1, keepdim=True)
+ elif self.clip_norm_type == 'none':
+ pass
+ else:
+ raise NotImplementedError
+
+ return x
+
+
+class PretrainInternVideo2(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ patch_size: int = 14,
+ img_size: int = 224,
+ qkv_bias: bool = False,
+ drop_path_rate: float = 0.25,
+ embed_dim: int = 1408,
+ num_heads: int = 16,
+ mlp_ratio: float = 48/11,
+ init_values: float = 1e-5,
+ qk_normalization: bool = True,
+ depth: int = 40,
+ use_flash_attn: bool = True,
+ use_fused_rmsnorm: bool = True,
+ use_fused_mlp: bool = True,
+ fused_mlp_heuristic: int = 1,
+ attn_pool_num_heads: int = 16,
+ clip_embed_dim: int = 768,
+ layerscale_no_force_fp32: bool = False,
+ num_frames: int = 8,
+ tubelet_size: int = 1,
+ sep_pos_embed: bool = False,
+ sep_image_video_pos_embed: bool = False,
+ use_checkpoint: bool = False,
+ checkpoint_num: int = 0,
+ # for unmasked teacher
+ clip_teacher_embed_dim: int = 3200,
+ clip_teacher_final_dim: int = 768, # if 0, not distill final features
+ clip_norm_type: str = 'l2',
+ clip_return_layer: int = 1,
+ clip_student_return_interval: int = 1,
+ ):
+ super().__init__()
+
+ self.num_frames = num_frames
+ self.tubelet_size = tubelet_size
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent'
+
+ self.use_flash_attn = use_flash_attn
+ self.embed_dim = embed_dim
+
+ self.depth = depth
+ self.clip_norm_type = clip_norm_type
+ self.return_index = []
+ for i in range(clip_return_layer):
+ self.return_index.append(depth - int(i * clip_student_return_interval) - 1)
+ logger.info(f'Normalization Type: {clip_norm_type}')
+ logger.info(f'Strudent Return Index: {self.return_index}')
+
+ if use_fused_rmsnorm:
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
+ else:
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
+ self.norm_layer_for_blocks = norm_layer_for_blocks
+ self.patch_embed = PatchEmbed(
+ img_size, patch_size, in_chans, embed_dim,
+ num_frames=num_frames, tubelet_size=tubelet_size,
+ )
+ num_patches = self.patch_embed.num_patches
+ num_img_patches = self.patch_embed.num_img_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
+ self.sep_pos_embed = sep_pos_embed
+ self.sep_image_video_pos_embed = sep_image_video_pos_embed
+ if sep_pos_embed:
+ raise NotImplementedError
+ else:
+ if sep_image_video_pos_embed:
+ logger.info("Use joint position embedding, for image and video we use different pos_embed.")
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
+ # for CLIP decoder
+ self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.clip_img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
+ else:
+ logger.info("Use joint position embedding, for image and video we use same pos_embed.")
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ # choose which layer to use checkpoint
+ with_cp_list = [False] * depth
+ if use_checkpoint:
+ for idx in range(depth):
+ if idx < checkpoint_num:
+ with_cp_list[idx] = True
+ logger.info(f"Droppath rate: {dpr}")
+ logger.info(f"Checkpoint list: {with_cp_list}")
+
+ self.blocks = nn.ModuleList([
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
+ norm_layer=norm_layer_for_blocks,
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
+ fused_mlp_heuristic=fused_mlp_heuristic,
+ with_cp=with_cp_list[i],
+ qk_normalization=qk_normalization,
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ for i in range(depth)])
+ self.clip_projector = AttentionPoolingBlock(
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
+
+ # CLIP decoder
+ self.clip_decoder = nn.ModuleList([
+ Linear_Decoder(
+ in_channels=embed_dim,
+ out_channels=clip_teacher_embed_dim,
+ norm_layer=partial(nn.LayerNorm, eps=1e-5),
+ clip_norm_type=clip_norm_type
+ ) for _ in range(clip_return_layer)
+ ])
+ self.final_clip_decoder = nn.Identity()
+ if clip_teacher_final_dim > 0:
+ self.final_clip_decoder = Linear_Decoder(
+ in_channels=clip_embed_dim,
+ out_channels=clip_teacher_final_dim,
+ norm_layer=partial(nn.LayerNorm, eps=1e-5),
+ clip_norm_type=clip_norm_type
+ )
+
+ self.init_pos_embed()
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ def init_pos_embed(self):
+ logger.info("Init pos_embed from sincos pos_embed")
+ if self.sep_pos_embed:
+ raise NotImplementedError
+ else:
+ # trunc_normal_(self.pos_embed, std=.02)
+ # trunc_normal_(self.clip_pos_embed, std=.02)
+ pos_embed = get_3d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ self.patch_embed.grid_size[0], # t_size
+ cls_token=True
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+ self.clip_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ if self.sep_image_video_pos_embed:
+ img_pos_embed = get_3d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ 1,
+ cls_token=True
+ )
+ self.img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))
+ self.clip_img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ @property
+ def dtype(self):
+ return self.patch_embed.proj.weight.dtype
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {
+ 'pos_embed',
+ 'pos_embed_spatial',
+ 'pos_embed_temporal',
+ 'pos_embed_cls',
+ 'img_pos_embed',
+ 'cls_token',
+ 'clip_pos_embed',
+ 'clip_pos_embed_spatial',
+ 'clip_pos_embed_temporal',
+ 'clip_pos_embed_cls',
+ 'clip_img_pos_embed'
+ }
+
+ # @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, x, mask=None, use_image=False, x_vis_return_idx=-1, x_vis_only=False):
+ x = self.patch_embed(x.type(self.dtype))
+ # print(f"x.shape: {x.shape} x.dtype: {x.dtype}, model.dtype: {self.dtype}")
+ B, T, L, C = x.shape # T: temporal; L: spatial
+ x = x.view([B, T * L, C])
+
+ # append cls token
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # add pos_embed
+ if self.sep_pos_embed:
+ raise NotImplementedError
+ else:
+ if use_image:
+ if self.sep_image_video_pos_embed:
+ pos_embed = self.img_pos_embed
+ else:
+ # (1, num_img_patches + 1, embed_dim)
+ # print('origin pos_embed.shape:', self.pos_embed.shape)
+ cls_pos_embed = self.pos_embed[:, 0:1, :]
+ # print('cls_pos_embed.shape:', cls_pos_embed.shape)
+
+ img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
+ # print('img_pos_embed.shape:', img_pos_embed.shape)
+
+ pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
+ # print('final img_pos_embed.shape:', pos_embed.shape)
+ else:
+ pos_embed = self.pos_embed
+ x = x + pos_embed
+
+ # mask tokens, ~mask means visible
+ if mask is not None:
+ x = x[~mask].reshape(B, -1, C)
+ else:
+ x = x.reshape(B, -1, C)
+
+ residual = None
+ x_clip = []
+ for idx, blk in enumerate(self.blocks):
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ # print(f"\033[31m这是{idx}, {x.shape}\033[0m")
+ x = blk(x, residual=residual)
+ # return intermediate features
+ if idx in self.return_index:
+ if isinstance(x, tuple) and len(x) == 2:
+ tmp_x, tmp_residual = x
+ if residual is not None:
+ x_clip.append(tmp_x + tmp_residual)
+ else:
+ x_clip.append(x)
+ if idx == (self.depth + x_vis_return_idx):
+ # print(f'idx = {idx} len(self.blocks)={len(self.blocks)}')
+ break
+
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ if residual is not None:
+ x = x + residual
+
+ x_vis = x
+ if x_vis_only:
+ return x_vis
+
+ x_pool_vis = self.clip_projector(x_vis)
+ x_align = self.final_clip_decoder(x_pool_vis)
+
+ # align CLIP
+ x_clip = torch.stack(x_clip)
+ K, B, _, C_CLIP = x_clip.shape
+ # add pos_embed
+ if self.sep_pos_embed:
+ raise NotImplementedError
+ else:
+ if use_image:
+ if self.sep_image_video_pos_embed:
+ clip_pos_embed = self.clip_img_pos_embed
+ else:
+ # (1, num_img_patches + 1, embed_dim)
+ # print('origin pos_embed.shape:', self.pos_embed.shape)
+ clip_cls_pos_embed = self.clip_pos_embed[:, 0:1, :]
+ # print('cls_pos_embed.shape:', cls_pos_embed.shape)
+
+ clip_img_pos_embed = self.clip_pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
+ # print('img_pos_embed.shape:', img_pos_embed.shape)
+
+ clip_pos_embed = torch.cat([clip_cls_pos_embed, clip_img_pos_embed], dim=1)
+ # print('final img_pos_embed.shape:', pos_embed.shape)
+
+ else:
+ clip_pos_embed = self.clip_pos_embed
+
+ clip_pos_embed = clip_pos_embed.repeat(B, 1, 1)
+ if mask is not None:
+ x_clip = x_clip + clip_pos_embed[~mask].view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1)
+ else:
+ x_clip = x_clip + clip_pos_embed.view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1)
+
+ # CLIP decoder
+ x_clip_align = []
+ for idx, clip_decoder in enumerate(self.clip_decoder):
+ x_clip_align.append(clip_decoder(x_clip[idx]))
+ x_clip_align = torch.stack(x_clip_align)
+
+
+ return x_vis, x_pool_vis, x_clip_align, x_align
+
+
+def pretrain_internvideo2_1b_patch14_224(config):
+ model = PretrainInternVideo2(
+ in_chans=3, img_size=224, patch_size=14,
+ embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+ clip_embed_dim=config.vision_encoder.clip_embed_dim,
+ attn_pool_num_heads=16, qkv_bias=False,
+ drop_path_rate=0.25,
+ init_values=0.00001,
+ qk_normalization=True,
+ use_flash_attn=config.vision_encoder.get('use_flash_attn', True),
+ use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', True),
+ use_fused_mlp=config.vision_encoder.get('use_fused_mlp', True),
+ fused_mlp_heuristic=1,
+ layerscale_no_force_fp32=False,
+ num_frames=config.vision_encoder.num_frames,
+ tubelet_size=config.vision_encoder.tubelet_size,
+ sep_pos_embed=False,
+ sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
+ use_checkpoint=config.vision_encoder.use_checkpoint,
+ checkpoint_num=config.vision_encoder.checkpoint_num,
+ clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim,
+ clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim,
+ clip_norm_type=config.vision_encoder.clip_norm_type,
+ clip_return_layer=config.vision_encoder.clip_return_layer,
+ clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
+ )
+
+ if config.vision_encoder.pretrained is not None:
+ logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
+ state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
+ interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
+ message = model.load_state_dict(state_dict, strict=False)
+ logger.info(message)
+ else:
+ logger.info("No pretrained weights!!!")
+ return model
+
+
+def pretrain_internvideo2_6b_patch14_224(config):
+ model = PretrainInternVideo2(
+ in_chans=3, img_size=224, patch_size=14,
+ embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4,
+ clip_embed_dim=config.vision_encoder.clip_embed_dim,
+ attn_pool_num_heads=16, qkv_bias=False,
+ drop_path_rate=0.3,
+ init_values=0.00001,
+ qk_normalization=True,
+ use_flash_attn=config.vision_encoder.get('use_flash_attn', True),
+ use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', True),
+ use_fused_mlp=config.vision_encoder.get('use_fused_mlp', True),
+ fused_mlp_heuristic=1,
+ layerscale_no_force_fp32=False,
+ num_frames=config.vision_encoder.num_frames,
+ tubelet_size=config.vision_encoder.tubelet_size,
+ sep_pos_embed=False,
+ sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
+ use_checkpoint=config.vision_encoder.use_checkpoint,
+ checkpoint_num=config.vision_encoder.checkpoint_num,
+ clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim,
+ clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim,
+ clip_norm_type=config.vision_encoder.clip_norm_type,
+ clip_return_layer=config.vision_encoder.clip_return_layer,
+ clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
+ )
+
+ if config.vision_encoder.pretrained is not None:
+ logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
+ state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
+ interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
+ msg = model.load_state_dict(state_dict, strict=False)
+ logger.info(msg)
+ else:
+ logger.info("No pretrained weights!!!")
+ return model
+
+
+
+if __name__ == '__main__':
+ import time
+ from fvcore.nn import FlopCountAnalysis
+ from fvcore.nn import flop_count_table
+ import numpy as np
+
+ seed = 4217
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ num_frames = 8
+ img_size = 224
+
+ model = pretrain_internvideo2_1b_patch14_224(clip_return_layer=6).cuda().half()
+ # print(model)
+
+ # flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, img_size, img_size).cuda().half())
+ # s = time.time()
+ # print(flop_count_table(flops, max_depth=1))
+ # print(time.time()-s)
+
+ mask = torch.cat([
+ torch.ones(1, 8 * int(16 * 16 * 0.75)),
+ torch.zeros(1, 8 * int(16 * 16 * 0.25)),
+ torch.zeros(1, 1),
+ ], dim=-1).to(torch.bool).cuda()
+
+ output = model(torch.rand(4, 3, num_frames, img_size, img_size).cuda().half(), mask.repeat(4, 1))
+ print(output[0].shape)
+ print(output[1].shape)
+ print(output[2].shape)
+ print(output[3].shape)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvideo2_clip_text.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvideo2_clip_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..23ef60d924c413cda798b86eb73bd1f2ada32232
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvideo2_clip_text.py
@@ -0,0 +1,69 @@
+import logging
+import numpy as np
+import torch
+import torch.nn.functional as F
+from peft import get_peft_model, LoraConfig, TaskType
+from torch import nn
+from transformers import LlamaForCausalLM, LlamaConfig
+
+from transformers import LlamaTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class LLaMA(nn.Module):
+ def __init__(
+ self,
+ use_flash_attn: bool = True,
+ transformer_width: int = 4096,
+ llama_path: str = None,
+ use_lora: bool = True,
+ clip_embed_dim: int = 768,
+ ):
+ super().__init__()
+
+ self.use_flash_attn = use_flash_attn
+ self.transformer_width = transformer_width
+
+ """ text encoder of InternVL """
+ llama_config = LlamaConfig.from_pretrained(llama_path, local_files_only=True)
+ llama_config.causal = True
+ llama_config.use_flash_attention = use_flash_attn
+ model = LlamaForCausalLM.from_pretrained( # LLAMA model
+ llama_path, torch_dtype=torch.float16, config=llama_config, local_files_only=True)
+ if not use_lora:
+ self.transformer = model.model
+ else:
+ peft_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM, inference_mode=False, r=16, lora_alpha=32, lora_dropout=0.1)
+ model = get_peft_model(model, peft_config)
+ self.transformer = model.base_model.model.model
+
+ self.transformer.gradient_checkpointing = True
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, clip_embed_dim))
+
+ def forward(self, text):
+ text_key_padding_mask = text > 0
+
+ x = self.transformer(input_ids=text, attention_mask=text_key_padding_mask).last_hidden_state
+ x = x[torch.arange(x.shape[0]), text_key_padding_mask.sum(1) - 1]
+ x = x @ self.text_projection
+
+ return x
+
+
+class Tokenizer(nn.Module):
+ def __init__(self, tokenizer_path="your_model_path/chinese_alpaca_lora_7b"):
+ super(Tokenizer, self).__init__()
+ self.tokenizer = LlamaTokenizer.from_pretrained(
+ tokenizer_path,
+ local_files_only=True,
+ legacy=False
+ )
+ self.tokenizer.pad_token = " " # allow padding
+ self.tokenizer.add_eos_token = True
+
+ def forward(self, text):
+ text = ["summarize:" + item for item in text]
+ text = self.tokenizer(text, return_tensors="pt", max_length=80, truncation=True, padding="max_length").input_ids
+ return text
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvideo2_clip_vision.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvideo2_clip_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..378da86ee3a214258c899a719b6e37f4b41ea98b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvideo2_clip_vision.py
@@ -0,0 +1,548 @@
+import logging
+import math
+import torch
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+from torch import nn
+
+import torch.utils.checkpoint as checkpoint
+from functools import partial
+from einops import rearrange
+
+from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed
+from .flash_attention_class import FlashAttention
+from flash_attn.modules.mlp import FusedMLP
+
+logger = logging.getLogger(__name__)
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None, out_dim=None):
+ super().__init__()
+ if out_dim is None:
+ out_dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ assert all_head_dim == dim
+
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.k_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, out_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, k=None, v=None):
+ B, N, C = x.shape
+ N_k = k.shape[1]
+ N_v = v.shape[1]
+
+ q_bias, k_bias, v_bias = None, None, None
+ if self.q_bias is not None:
+ q_bias = self.q_bias
+ k_bias = self.k_bias
+ v_bias = self.v_bias
+
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
+
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class AttentiveBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
+ super().__init__()
+
+ self.norm1_q = norm_layer(dim)
+ self.norm1_k = norm_layer(dim)
+ self.norm1_v = norm_layer(dim)
+ self.cross_attn = CrossAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
+
+ if drop_path > 0.:
+ logger.info(f"Use DropPath in projector: {drop_path}")
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
+ x_q = self.norm1_q(x_q + pos_q)
+ x_k = self.norm1_k(x_kv + pos_k)
+ x_v = self.norm1_v(x_kv)
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
+
+ return x
+
+
+class AttentionPoolingBlock(AttentiveBlock):
+
+ def forward(self, x):
+ x_q = x.mean(1, keepdim=True)
+ x_kv, pos_q, pos_k = x, 0, 0
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
+ x = x.squeeze(1)
+ return x
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+ self.force_fp32 = force_fp32
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, x):
+ if self.force_fp32:
+ output_type = x.dtype
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
+ return out.to(dtype=output_type)
+ else:
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.use_flash_attn = use_flash_attn
+ if use_flash_attn:
+ self.causal = causal
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
+
+ self.qk_normalization = qk_normalization
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ if self.use_fused_rmsnorm:
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
+ else:
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+ )
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, x):
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
+ use_fused_rmsnorm=False):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
+ qk_normalization=qk_normalization,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ self.ls1 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if use_fused_mlp:
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.with_cp = with_cp
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def forward(self, x, residual=None):
+
+ def _inner_forward(x, residual=None):
+ if self.use_fused_rmsnorm:
+ x, residual = self.norm1(x, residual)
+ x = self.drop_path1(self.ls1(self.attn(x)))
+ x, residual = self.norm2(x, residual)
+ x = self.drop_path2(self.ls2(self.mlp(x)))
+ return x, residual
+ else:
+ assert residual is None
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+ if self.with_cp:
+ return checkpoint.checkpoint(_inner_forward, x, residual)
+ else:
+ return _inner_forward(x, residual=residual)
+
+
+class PatchEmbed(nn.Module):
+ """ 3D Image to Patch Embedding
+ """
+
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
+ num_frames=8, tubelet_size=1, norm_layer=None
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.tubelet_size = tubelet_size
+ self.grid_size = (
+ num_frames // tubelet_size,
+ img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1]
+ ) # (T, H, W)
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
+
+ self.proj = nn.Conv3d(
+ in_channels=in_chans, out_channels=embed_dim,
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
+ stride=(tubelet_size, patch_size[0], patch_size[1])
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C
+ x = self.norm(x)
+ return x
+
+
+class InternVideo2(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ patch_size: int = 14,
+ img_size: int = 224,
+ qkv_bias: bool = False,
+ drop_path_rate: float = 0.25, # may need ablation
+ head_drop_path_rate: float = 0.,
+ embed_dim: int = 1408,
+ num_heads: int = 16,
+ mlp_ratio: float = 48/11,
+ init_values: float = 1e-5, # may need ablation
+ qk_normalization: bool = True,
+ depth: int = 40,
+ use_flash_attn: bool = True,
+ use_fused_rmsnorm: bool = True,
+ use_fused_mlp: bool = True,
+ fused_mlp_heuristic: int = 1,
+ attn_pool_num_heads: int = 16,
+ clip_embed_dim: int = 768,
+ layerscale_no_force_fp32: bool = False, # when True for training?
+ num_frames: int = 8,
+ tubelet_size: int = 1,
+ sep_pos_embed: bool = False,
+ use_checkpoint: bool = False,
+ checkpoint_num: int = 0,
+ ):
+ super().__init__()
+
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, logger.info(
+ 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent')
+ logger.info(mlp_ratio)
+
+ self.use_flash_attn = use_flash_attn
+ self.embed_dim = embed_dim
+ self.T = num_frames // tubelet_size
+
+ if use_fused_rmsnorm:
+ from flash_attn.ops.rms_norm import DropoutAddRMSNorm
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
+ else:
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
+ self.norm_layer_for_blocks = norm_layer_for_blocks
+ self.patch_embed = PatchEmbed(
+ img_size, patch_size, in_chans, embed_dim,
+ num_frames=num_frames, tubelet_size=tubelet_size,
+ )
+ num_patches = self.patch_embed.num_patches
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
+ self.sep_pos_embed = sep_pos_embed
+ if sep_pos_embed:
+ logger.info("Use seperable position embedding")
+ grid_size = self.patch_embed.grid_size
+ self.grid_size = grid_size
+ self.pos_embed_spatial = nn.Parameter(torch.zeros(1, grid_size[1] * grid_size[2], embed_dim))
+ self.pos_embed_temporal = nn.Parameter(torch.zeros(1, grid_size[0], embed_dim))
+ self.pos_embed_cls = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ else:
+ logger.info("Use joint position embedding")
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ # choose which layer to use checkpoint
+ with_cp_list = [False] * depth
+ if use_checkpoint:
+ for idx in range(depth):
+ if idx < checkpoint_num:
+ with_cp_list[idx] = True
+ logger.info(f"Droppath rate: {dpr}")
+ logger.info(f"Checkpoint list: {with_cp_list}")
+
+ self.blocks = nn.ModuleList([
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
+ norm_layer=norm_layer_for_blocks,
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
+ fused_mlp_heuristic=fused_mlp_heuristic,
+ with_cp=with_cp_list[i],
+ qk_normalization=qk_normalization,
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ for i in range(depth)])
+ self.clip_projector = AttentionPoolingBlock(
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
+ drop=0., attn_drop=0., drop_path=head_drop_path_rate,
+ norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim
+ )
+
+ self.fc_norm = nn.Identity()
+
+ self.init_pos_embed()
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ def init_pos_embed(self):
+ logger.info("Init pos_embed from sincos pos_embed")
+ if self.sep_pos_embed:
+ # trunc_normal_(self.pos_embed_spatial, std=.02)
+ # trunc_normal_(self.pos_embed_temporal, std=.02)
+ # trunc_normal_(self.pos_embed_cls, std=.02)
+ pos_embed_spatial = get_2d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ )
+ self.pos_embed_spatial.data.copy_(torch.from_numpy(pos_embed_spatial).float().unsqueeze(0))
+ pos_embed_temporal = get_1d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[0], # t_size
+ )
+ self.pos_embed_temporal.data.copy_(torch.from_numpy(pos_embed_temporal).float().unsqueeze(0))
+ else:
+ # trunc_normal_(self.pos_embed, std=.02)
+ pos_embed = get_3d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ self.patch_embed.grid_size[0], # t_size
+ cls_token=True
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ @property
+ def dtype(self):
+ return self.patch_embed.proj.weight.dtype
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {
+ 'pos_embed',
+ 'pos_embed_spatial',
+ 'pos_embed_temporal',
+ 'pos_embed_cls',
+ 'cls_token'
+ }
+
+ def forward(self, x, use_image=False):
+ x = self.patch_embed(x.type(self.dtype))
+ B, T, L, C = x.shape # T: temporal; L: spatial
+ x = x.view([B, T * L, C])
+
+ # append cls token
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # add pos_embed
+ if self.sep_pos_embed:
+ if use_image:
+ pos_embed = self.pos_embed_spatial
+ else:
+ pos_embed = self.pos_embed_spatial.repeat(
+ 1, self.grid_size[0], 1
+ ) + torch.repeat_interleave(
+ self.pos_embed_temporal,
+ self.grid_size[1] * self.grid_size[2],
+ dim=1,
+ )
+ pos_embed = torch.cat(
+ [
+ self.pos_embed_cls.expand(pos_embed.shape[0], -1, -1),
+ pos_embed,
+ ],
+ 1,
+ )
+ else:
+ if use_image:
+ cls_pos_embed = self.pos_embed[:, :1, :]
+ img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.T, L, C).mean(dim=1)
+ pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
+ else:
+ pos_embed = self.pos_embed
+
+ x = x + pos_embed
+
+ residual = None
+ for blk in self.blocks:
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ x = blk(x, residual=residual)
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ if residual is not None:
+ x = x + residual
+
+ x = self.clip_projector(x)
+
+ x = self.fc_norm(x)
+ return x
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvl_clip_vision.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvl_clip_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f773bfadb56294f0899fa63cc4007f44ed45306
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/internvl_clip_vision.py
@@ -0,0 +1,558 @@
+import os
+import torch
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple
+from torch import nn
+
+import torch.utils.checkpoint as checkpoint
+from functools import partial
+from einops import rearrange
+
+try:
+ from .flash_attention_class import FlashAttention
+except:
+ from flash_attention_class import FlashAttention
+from flash_attn.modules.mlp import FusedMLP
+
+
+MODEL_PATH = 'your_model_path/internvl'
+_MODELS = {
+ # see InternVL
+ "internvl_c_13b_224px": os.path.join(MODEL_PATH, "internvl_c_13b_224px.pth"),
+}
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None, out_dim=None):
+ super().__init__()
+ if out_dim is None:
+ out_dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ assert all_head_dim == dim
+
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.k_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, out_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, k=None, v=None, return_attn=False):
+ B, N, C = x.shape
+ N_k = k.shape[1]
+ N_v = v.shape[1]
+
+ q_bias, k_bias, v_bias = None, None, None
+ if self.q_bias is not None:
+ q_bias = self.q_bias
+ k_bias = self.k_bias
+ v_bias = self.v_bias
+
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
+
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ if return_attn:
+ return x, attn.mean(1) # (B, n_head, n_q, C) => (B, n_q, C)
+ else:
+ return x, None
+
+
+class AttentiveBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
+ super().__init__()
+
+ self.norm1_q = norm_layer(dim)
+ self.norm1_k = norm_layer(dim)
+ self.norm1_v = norm_layer(dim)
+ self.cross_attn = CrossAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None, return_attn=False):
+ x_q = self.norm1_q(x_q + pos_q)
+ x_k = self.norm1_k(x_kv + pos_k)
+ x_v = self.norm1_v(x_kv)
+ x, attn = self.cross_attn(x_q, k=x_k, v=x_v, return_attn=return_attn)
+ return x, attn
+
+
+class AttentionPoolingBlock(AttentiveBlock):
+
+ def forward(self, x, return_attn=False):
+ x_q = x.mean(1, keepdim=True)
+ x_kv, pos_q, pos_k = x, 0, 0
+ x, attn = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None, return_attn=return_attn)
+ x = x.squeeze(1)
+ return x, attn
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+ self.force_fp32 = force_fp32
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, x):
+ if self.force_fp32:
+ output_type = x.dtype
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
+ return out.to(dtype=output_type)
+ else:
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.use_flash_attn = use_flash_attn
+ if use_flash_attn:
+ self.causal = causal
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
+
+ self.qk_normalization = qk_normalization
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ if self.use_fused_rmsnorm:
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
+ else:
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+ )
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, x):
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
+ use_fused_rmsnorm=False):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
+ qk_normalization=qk_normalization,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ self.ls1 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if use_fused_mlp:
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.with_cp = with_cp
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def forward(self, x, residual=None):
+
+ def _inner_forward(x, residual=None):
+ if self.use_fused_rmsnorm:
+ x, residual = self.norm1(x, residual)
+ x = self.drop_path1(self.ls1(self.attn(x)))
+ x, residual = self.norm2(x, residual)
+ x = self.drop_path2(self.ls2(self.mlp(x)))
+ return x, residual
+ else:
+ assert residual is None
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+ if self.with_cp:
+ return checkpoint.checkpoint(_inner_forward, x, residual)
+ else:
+ return _inner_forward(x, residual=residual)
+
+
+class PatchEmbed(nn.Module):
+ """ 3D Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv3d(
+ in_chans, embed_dim,
+ kernel_size=(1, patch_size[0], patch_size[1]),
+ stride=(1, patch_size[0], patch_size[1]),
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(3).permute(0, 2, 3, 1) # (N, C, T, H, W) => (N, T, H * W, C)
+ x = self.norm(x)
+ return x
+
+
+class InternVL_CLIP(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ patch_size: int = 14,
+ img_size: int = 224,
+ qkv_bias: bool = False,
+ drop_path_rate: float = 0.2,
+ embed_dim: int = 3200,
+ num_heads: int = 25,
+ mlp_ratio: int = 4,
+ init_values: float = 0.1,
+ qk_normalization: bool = True,
+ depth: int = 48,
+ use_flash_attn: bool = True,
+ use_fused_rmsnorm: bool = True,
+ use_fused_mlp: bool = True,
+ fused_mlp_heuristic: int = 1,
+ with_cp: bool = False,
+ attn_pool_num_heads: int = 16,
+ clip_embed_dim: int = 768,
+ layerscale_no_force_fp32: bool = True,
+ # for unmasked teacher
+ clip_norm_type: str = 'l2',
+ return_attn: bool = True,
+ clip_return_layer: int = 1,
+ clip_return_interval: int = 1,
+ ):
+ super().__init__()
+
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, print(
+ 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent')
+
+ self.use_flash_attn = use_flash_attn
+ self.embed_dim = embed_dim
+
+ self.clip_norm_type = clip_norm_type
+ self.return_attn = return_attn
+ self.return_index = []
+ for i in range(clip_return_layer):
+ self.return_index.append(depth - int(i * clip_return_interval) - 1)
+ print(f'Normalization Type: {clip_norm_type}')
+ print(f'Return Attention: {return_attn}')
+ print(f'Teacher Return Interval: {self.return_index}')
+
+ """ only use image encoder of InternVL """
+ if use_fused_rmsnorm:
+ from flash_attn.ops.rms_norm import DropoutAddRMSNorm
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
+ else:
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
+ self.norm_layer_for_blocks = norm_layer_for_blocks
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
+ num_patches = self.patch_embed.num_patches
+ self.num_patches = num_patches
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+ self.blocks = nn.ModuleList([
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
+ norm_layer=norm_layer_for_blocks,
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
+ fused_mlp_heuristic=fused_mlp_heuristic,
+ with_cp=with_cp,
+ qk_normalization=qk_normalization,
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ for i in range(depth)])
+ self.clip_projector = AttentionPoolingBlock(
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
+
+ @property
+ def dtype(self):
+ return self.patch_embed.proj.weight.dtype
+
+ def forward(self, image):
+ x = self.patch_embed(image.type(self.dtype))
+ B, T, HW, C = x.size()
+ x = x.reshape(B * T, HW, C)
+
+ cls_tokens = self.cls_token.expand(B * T, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + self.pos_embed
+
+ residual = None
+ z = []
+ for idx, blk in enumerate(self.blocks):
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ x = blk(x, residual=residual)
+ # return intermediate features
+ if idx in self.return_index:
+ if isinstance(x, tuple) and len(x) == 2:
+ tmp_x, tmp_residual = x
+ if residual is not None:
+ z.append(tmp_x + tmp_residual)
+ else:
+ z.append(x)
+
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ if residual is not None:
+ x = x + residual
+
+ x, attn = self.clip_projector(x, return_attn=self.return_attn)
+
+ if self.clip_norm_type == 'l2':
+ # normalization of intermediate features
+ z = torch.stack(z) # (K, BT, HW+1, C)
+ K = z.shape[0]
+ cls_tokens, z = z[:, :, :1, :], z[:, :, 1:, :]
+ cls_tokens = cls_tokens.view(K, B, T, 1, C).mean(2) # (K, BT, 1, C) => (K, B, 1, C)
+ z = z.reshape(K, B, T * HW, C)
+ z = torch.cat((cls_tokens, z), dim=2) # (K, B, HWT+1, C)
+ z = z / z.norm(dim=-1, keepdim=True)
+ # normalization of final features
+ x = x.view(B, T, -1).mean(1) # (BT, C) => (B, C)
+ x = x / x.norm(dim=-1, keepdim=True)
+ elif self.clip_norm_type == 'none':
+ pass
+ else:
+ raise NotImplementedError
+
+ if self.return_attn:
+ return z, x, attn[:, 0, 1:] # (B * T, HW)
+ else:
+ return z, x
+
+
+def inflate_weight(weight_2d, time_dim, center=True):
+ print(f'Init center: {center}')
+ if center:
+ weight_3d = torch.zeros(*weight_2d.shape)
+ weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+ middle_idx = time_dim // 2
+ weight_3d[:, :, middle_idx, :, :] = weight_2d
+ else:
+ weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+ weight_3d = weight_3d / time_dim
+ return weight_3d
+
+
+def process_checkpoint(ckpt, model):
+ new_ckpt = {}
+ state_dict_3d = model.state_dict()
+ for k, v in ckpt['module'].items():
+ new_k = k
+ if 'patch_embed' in new_k and new_k in state_dict_3d.keys() and v.shape != state_dict_3d[new_k].shape:
+ print(new_k)
+ print(f'Inflate: {k}, {v.shape} => {state_dict_3d[new_k].shape}')
+ time_dim = state_dict_3d[new_k].shape[2]
+ v = inflate_weight(v, time_dim)
+ new_ckpt[new_k] = v
+
+ # interpolate position embedding
+ pos_embed_checkpoint = new_ckpt['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.num_patches
+ orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5)
+ new_size = int(num_patches ** 0.5)
+ if orig_size != new_size:
+ print(f'pos_embed from {orig_size} to {new_size}')
+ extra_tokens = pos_embed_checkpoint[:, :1]
+ pos_tokens = pos_embed_checkpoint[:, 1:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ new_ckpt['pos_embed'] = new_pos_embed
+
+ return new_ckpt
+
+
+def internvl_clip_6b(
+ img_size,
+ clip_norm_type='l2',
+ return_attn=True,
+ clip_return_layer=1,
+ clip_return_interval=1
+ ):
+ model = InternVL_CLIP(
+ img_size=img_size,
+ layerscale_no_force_fp32=False,
+ clip_norm_type=clip_norm_type,
+ return_attn=return_attn,
+ clip_return_layer=clip_return_layer,
+ clip_return_interval=clip_return_interval,
+ )
+
+ ckpt = torch.load(_MODELS["internvl_c_13b_224px"], map_location='cpu')
+ new_ckpt = process_checkpoint(ckpt, model)
+ message = model.load_state_dict(new_ckpt, strict=False)
+ print(message)
+ return model.eval()
+
+
+if __name__ == '__main__':
+ import time
+ from fvcore.nn import FlopCountAnalysis
+ from fvcore.nn import flop_count_table
+ import numpy as np
+
+ seed = 4217
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ num_frames = 8
+ img_size = 224
+ video = torch.rand(1, 3, num_frames, img_size, img_size).cuda().half()
+
+ model = internvl_clip_6b(img_size).cuda().half()
+ # flops = FlopCountAnalysis(model, video)
+ model(video)
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/pos_embed.py b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/pos_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a98c5cd4fa4fa4c0c6a6ecdff44188b8e80856
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/backbones/internvideo2/pos_embed.py
@@ -0,0 +1,299 @@
+import numpy as np
+import torch
+import logging
+
+logger = logging.getLogger(__name__)
+
+# --------------------------------------------------------
+# 3D sine-cosine position embedding
+# References:
+# MVD: https://github.com/ruiwang2021/mvd/blob/main/modeling_finetune.py
+# --------------------------------------------------------
+def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ t_size: int of the temporal size
+ return:
+ pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ assert embed_dim % 4 == 0
+ embed_dim_spatial = embed_dim // 4 * 3
+ embed_dim_temporal = embed_dim // 4
+
+ # spatial
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
+ embed_dim_spatial, grid
+ )
+
+ # temporal
+ grid_t = np.arange(t_size, dtype=np.float32)
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
+ embed_dim_temporal, grid_t
+ )
+
+ # concate: [T, H, W] order
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+ pos_embed_temporal = np.repeat(
+ pos_embed_temporal, grid_size**2, axis=1
+ ) # [T, H*W, D // 4]
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+ pos_embed_spatial = np.repeat(
+ pos_embed_spatial, t_size, axis=0
+ ) # [T, H*W, D // 4 * 3]
+
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
+ pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
+
+ if cls_token:
+ pos_embed = np.concatenate(
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate(
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
+ """
+ t_size: int of the temporal size
+ return:
+ pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_t = np.arange(t_size, dtype=np.float32)
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
+ if cls_token:
+ pos_embed = np.concatenate(
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(
+ embed_dim // 2, grid[0]
+ ) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(
+ embed_dim // 2, grid[1]
+ ) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'):
+ if pos_name in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model[pos_name]
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
+ num_patches = model.patch_embed.num_patches #
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
+
+ # we use 4 frames for pretraining
+ new_t_size = model.T
+ # height (== width) for the checkpoint position embedding
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int((num_patches // (new_t_size))** 0.5)
+
+ # class_token and dist_token are kept unchanged
+ if orig_t_size != new_t_size:
+ logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model[pos_name] = new_pos_embed
+ pos_embed_checkpoint = new_pos_embed
+
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model[pos_name] = new_pos_embed
+
+
+def interpolate_pos_embed_internvideo2(checkpoint_model, model, orig_t_size = 8):
+ # interpolate position embedding
+ for pos_name in ['pos_embed', 'clip_pos_embed']:
+ if pos_name in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model[pos_name]
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
+ num_patches = model.patch_embed.num_patches #
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
+
+ # we use 8 frames for pretraining
+ # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
+ new_t_size = model.num_frames // model.tubelet_size
+ # height (== width) for the checkpoint position embedding
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int((num_patches // (new_t_size))** 0.5)
+
+ # class_token and dist_token are kept unchanged
+ if orig_t_size != new_t_size:
+ logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model[pos_name] = new_pos_embed
+ pos_embed_checkpoint = new_pos_embed
+
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model[pos_name] = new_pos_embed
+
+ if 'pos_embed_spatial' in checkpoint_model or 'pos_embed_temporal' in checkpoint_model:
+ raise NotImplementedError
+
+
+def interpolate_pos_embed_internvideo2_new(checkpoint_model, model, orig_t_size = 8):
+ pos_names = []
+ for k in checkpoint_model.keys():
+ if ('pos_embed' in k or 'clip_pos_embed' in k) and 'img_pos_embed' not in k:
+ pos_names.append(k)
+
+ logger.info(f"pos names list for interpolating: {pos_names}")
+
+ assert len(pos_names) > 0, checkpoint_model.keys()
+
+ if 'pos_embed_spatial' in checkpoint_model.keys() or 'pos_embed_temporal' in checkpoint_model.keys():
+ raise NotImplementedError
+
+ # interpolate position embedding
+ for pos_name in pos_names:
+
+ pos_embed_checkpoint = checkpoint_model[pos_name]
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
+ num_patches = model.patch_embed.num_patches #
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
+
+ # we use 8 frames for pretraining
+ # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
+ new_t_size = model.num_frames // model.tubelet_size
+ # height (== width) for the checkpoint position embedding
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int((num_patches // (new_t_size))** 0.5)
+
+ # class_token and dist_token are kept unchanged
+ if orig_t_size != new_t_size:
+ logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model[pos_name] = new_pos_embed
+ pos_embed_checkpoint = new_pos_embed
+
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model[pos_name] = new_pos_embed
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/criterions.py b/third_party/InternVideo/InternVideo2/multi_modality/models/criterions.py
new file mode 100644
index 0000000000000000000000000000000000000000..4407583376dd2779916863a4be172f057111f0c7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/criterions.py
@@ -0,0 +1,486 @@
+import logging
+from functools import lru_cache
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from models.utils import allgather_wgrad
+
+logger = logging.getLogger(__name__)
+
+
+def get_sim(
+ vision_proj: torch.Tensor,
+ text_proj: torch.Tensor,
+ temp=1.0,
+ agg_method="mean",
+):
+ """calculate pair-wise video-text similarity.
+
+ Args:
+ vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
+ text_proj (torch.Tensor): The text representation. Shape: [B,C].
+ temp (torch.Tensor): The temperature. Shape: [].
+
+ Returns: The similarity between video and text. Shape: [B,B].
+
+ """
+ vision_proj = F.normalize(vision_proj, dim=-1)
+ text_proj = F.normalize(text_proj, dim=-1)
+ if vision_proj.ndim == 3:
+ sim_v2t = torch.einsum("mld,nd->mln", vision_proj, text_proj) / temp # [B, L, B]
+ sim_t2v = torch.einsum("nd,mld->nlm", text_proj, vision_proj) / temp # [B, L, B]
+ if agg_method == "mean":
+ sim_v2t = sim_v2t.mean(1)
+ sim_t2v = sim_t2v.mean(1)
+ elif agg_method == "max":
+ sim_v2t = sim_v2t.max(1)[0]
+ sim_t2v = sim_t2v.max(1)[0]
+ elif text_proj.ndim == 3:
+ sim_v2t = torch.einsum("nd,mld->nlm", vision_proj, text_proj) / temp # [B, L, B]
+ sim_t2v = torch.einsum("nld,md->nlm", text_proj, vision_proj) / temp # [B, L, B]
+ if agg_method == "mean":
+ sim_v2t = sim_v2t.mean(1)
+ sim_t2v = sim_t2v.mean(1)
+ elif agg_method == "max":
+ sim_v2t = sim_v2t.max(1)[0]
+ sim_t2v = sim_t2v.max(1)[0]
+ else:
+ sim_v2t = vision_proj @ text_proj.T / temp
+ sim_t2v = sim_v2t.T
+
+ return sim_v2t, sim_t2v
+
+
+class VTC_VTM_Loss(nn.Module):
+ """video-text contrastive and matching losses."""
+
+ def __init__(self, vtm_hard_neg):
+ super().__init__()
+ self.vtm_hard_neg = vtm_hard_neg
+
+ def vtc_loss(
+ self,
+ vision_proj: torch.Tensor,
+ text_proj: torch.Tensor,
+ idx: torch.Tensor,
+ temp=1.0,
+ all_gather=True,
+ agg_method="mean",
+ ):
+ """forward to calculate the loss
+
+ Args:
+ vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
+ text_proj (torch.Tensor): The text representation. Shape: [B,C].
+ idx (torch.Tensor): The index for each example. Shape: [B,].
+ temp (torch.Tensor): The temperature. Shape: [].
+ all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples.
+
+ Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: [].
+
+ """
+ if all_gather:
+ gather_args = self.get_gather_args()
+ vision_proj = allgather_wgrad(vision_proj, gather_args)
+ text_proj = allgather_wgrad(text_proj, gather_args)
+ if idx is not None:
+ idx = allgather_wgrad(idx, gather_args)
+
+ sim_v2t, sim_t2v = get_sim(vision_proj, text_proj, temp, agg_method=agg_method)
+
+ with torch.no_grad():
+ sim_v2t_targets = self.get_mask(sim_v2t, idx=idx, normalize=True)
+ sim_t2v_targets = sim_v2t_targets
+
+ loss_i2t = -torch.sum(F.log_softmax(sim_v2t, dim=1) * sim_v2t_targets, dim=1).mean()
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2v, dim=1) * sim_t2v_targets, dim=1).mean()
+
+ loss_vtc = (loss_i2t + loss_t2i) / 2
+ return loss_vtc
+
+ def vtm_loss(
+ self,
+ multimodal_encoder,
+ vtm_head: nn.Module,
+ temp,
+ vision_embeds: torch.Tensor,
+ text_embeds: torch.Tensor,
+ vision_proj: torch.Tensor,
+ text_proj: torch.Tensor,
+ text_atts: torch.Tensor,
+ idx: torch.Tensor,
+ ):
+ """video-text matching loss.
+
+ Args:
+ multinomial_encoder (nn.Module): The multimodal_encoder.
+ vtm_head (nn.Module): The head to produce the video-text matching score.
+ temp (torch.Tensor): temporature for similarity calculation.
+ vision_embeds (torch.Tensor): The features of all patches in the video. Shape: [B,T,L,C].
+ text_embeds (torch.Tensor): The features of all tokens in the text. Shape: [B,L,C].
+ vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
+ text_proj (torch.Tensor): The text representation. Shape: [B,C].
+ text_atts (torch.Tensor): The padded mask for text tokens. 0 is padded. Shape: [B,L].
+ idx (torch.Tensor): The index for each example. Shape: [B,].
+
+ Returns: TODO
+
+ """
+ with torch.no_grad():
+ sim_v2t, sim_t2v = get_sim(vision_proj, text_proj, temp)
+ vision_atts = torch.ones(
+ vision_embeds.size()[:-1], dtype=torch.long, device=vision_embeds.device
+ )
+ weights_v2t = F.softmax(sim_v2t + 1e-4, dim=1) # (N, N)
+ weights_t2v = F.softmax(sim_t2v + 1e-4, dim=1)
+
+ mask = self.get_mask(sim_v2t, idx=idx).bool()
+ weights_v2t.masked_fill_(mask, 0)
+ weights_t2v.masked_fill_(mask, 0)
+ weights_v2t = torch.nan_to_num_(weights_v2t, nan=1e-2, posinf=1e-2, neginf=1e-2)
+ weights_t2v = torch.nan_to_num_(weights_t2v, nan=1e-2, posinf=1e-2, neginf=1e-2)
+
+ # select a negative image for each text
+ if self.vtm_hard_neg:
+ vision_neg_indices = torch.multinomial(weights_t2v, 1).squeeze() # NOTE bs != 1
+ txt_neg_indices = torch.multinomial(weights_v2t, 1).squeeze()
+ else:
+ vision_neg_indices = self.get_rand_indices(mask, 1).squeeze()
+ txt_neg_indices = self.get_rand_indices(mask, 1).squeeze()
+
+ vision_embeds_neg = vision_embeds[vision_neg_indices] # [B, T*L, c]
+ text_embeds_neg = text_embeds[txt_neg_indices] # [B, L, d]
+ text_atts_neg = text_atts[txt_neg_indices]
+
+ # concat embeddings
+ vision_embeds_all = torch.cat([vision_embeds, vision_embeds_neg, vision_embeds], dim=0)
+ text_embeds_all = torch.cat([text_embeds, text_embeds, text_embeds_neg], dim=0)
+ vision_atts_all = torch.cat([vision_atts, vision_atts, vision_atts], dim=0)
+ text_atts_all = torch.cat([text_atts, text_atts, text_atts_neg], dim=0)
+
+ output = multimodal_encoder(
+ encoder_embeds=text_embeds_all,
+ attention_mask=text_atts_all,
+ encoder_hidden_states=vision_embeds_all,
+ encoder_attention_mask=vision_atts_all,
+ return_dict=True,
+ mode="fusion",
+ )
+
+ vtm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
+
+ vtm_logits = vtm_head(vtm_embeds) # [3*B, 2]
+
+ bs = vtm_logits.shape[0] // 3
+ vtm_labels = vtm_logits.new_ones(3 * bs, dtype=torch.long)
+ vtm_labels[bs:] = 0
+ loss_vtm = F.cross_entropy(vtm_logits, vtm_labels)
+ return loss_vtm
+
+ def get_rand_indices(self, mask, k):
+ """get rand indices according to mask.
+ Args:
+ mask (torch.Tensor): Shape: (N, L) 0 indicates the positions that we can sample, 1 otherwise
+ k (int): the number indices to sample at each row.
+ Returns:
+ The sampled indices. Shape: [N,k].
+ (N, k) indices
+ """
+ mask = mask.float()
+ mask = mask - 10000 * mask
+ mask += torch.randn_like(mask)
+ _, indices = torch.sort(mask, dim=1, descending=True)
+ indices = indices[:, :k].contiguous()
+ return indices
+
+ @torch.no_grad()
+ def get_mask(self, sim, idx=None, normalize=False):
+ """
+ Args:
+ sim (torch.Tensor): The similarity between videos and texts. shape: (B, B).
+ idx (torch.Tensor): The index for each video. Shape: [B].
+ normalize (bool): If true, make row sum equal to 1
+ """
+ if idx is not None:
+ idx = idx.view(-1, 1)
+ mask = torch.eq(idx, idx.T).to(sim.dtype)
+ if normalize:
+ mask = mask / mask.sum(1, keepdim=True)
+ else:
+ mask = torch.zeros_like(sim)
+ mask.fill_diagonal_(1)
+ return mask # `1` mark valid/matched location
+
+ @lru_cache(maxsize=16)
+ def get_gather_args(self):
+ """obtain the args for all_gather
+ Returns: dict.
+
+ """
+ from utils.distributed import get_rank, get_world_size
+ from utils.easydict import EasyDict
+ return EasyDict({"world_size": get_world_size(), "rank": get_rank()})
+
+
+class MLMLoss(nn.Module):
+ """masked language modeling loss."""
+
+ def __init__(self, masking_prob, tokenizer):
+ super(MLMLoss, self).__init__()
+ self.tokenizer = tokenizer
+ self.masking_prob = masking_prob
+
+ def mlm_loss(
+ self,
+ text_encoder,
+ text,
+ vision_embeds,
+ vision_atts,
+ ):
+ input_ids = text.input_ids.clone()
+ labels = input_ids.clone()
+ probability_matrix = torch.full(labels.shape, self.masking_prob)
+ input_ids, labels = self.mask(
+ input_ids,
+ text_encoder.config.vocab_size,
+ input_ids.device,
+ targets=labels,
+ probability_matrix=probability_matrix,
+ )
+
+ intermediate_mlm_output = text_encoder.bert(
+ input_ids,
+ attention_mask=text.attention_mask,
+ encoder_hidden_states=vision_embeds,
+ encoder_attention_mask=vision_atts,
+ return_dict=True,
+ mode="text",
+ )
+
+ text_embeds = intermediate_mlm_output.last_hidden_state
+
+ mlm_output = text_encoder(
+ encoder_embeds=text_embeds,
+ attention_mask=text.attention_mask,
+ encoder_hidden_states=vision_embeds,
+ encoder_attention_mask=vision_atts,
+ return_dict=True,
+ labels=labels,
+ soft_labels=None,
+ mode="fusion",
+ )
+ return mlm_output.loss
+
+ def simple_mlm_loss(
+ self,
+ text_encoder,
+ text,
+ text_embeds,
+ vision_embeds,
+ vision_atts,
+ labels
+ ):
+ mlm_output = text_encoder(
+ encoder_embeds=text_embeds,
+ attention_mask=text.attention_mask,
+ encoder_hidden_states=vision_embeds,
+ encoder_attention_mask=vision_atts,
+ return_dict=True,
+ labels=labels,
+ soft_labels=None,
+ mode="fusion",
+ )
+ return mlm_output.loss
+
+ def mask(
+ self,
+ input_ids,
+ vocab_size,
+ device,
+ targets=None,
+ masked_indices=None,
+ probability_matrix=None,
+ ):
+ if masked_indices is None:
+ masked_indices = torch.bernoulli(probability_matrix).bool()
+
+ masked_indices[input_ids == self.tokenizer.pad_token_id] = False
+ masked_indices[input_ids == self.tokenizer.cls_token_id] = False
+ """make deepspeed happy!"""
+ # _pad_mask = (input_ids == self.tokenizer.pad_token_id).to(masked_indices.device, non_blocking=True) # 0
+ # # print(_pad_mask.device)
+ # masked_indices[_pad_mask] = False
+ # _cls_mask = (input_ids == self.tokenizer.cls_token_id).to(masked_indices.device, non_blocking=True) # 101
+ # masked_indices[_cls_mask] = False
+
+
+ if targets is not None:
+ # We only compute loss on masked tokens
+ targets[~masked_indices] = -100
+
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+ indices_replaced = (
+ torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
+ )
+ input_ids[indices_replaced] = self.tokenizer.mask_token_id
+
+ # 10% of the time, we replace masked input tokens with random word
+ indices_random = (
+ torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()
+ & masked_indices
+ & ~indices_replaced
+ )
+ random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
+ input_ids[indices_random] = random_words[indices_random]
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+
+ if targets is not None:
+ return input_ids, targets
+ else:
+ return input_ids
+
+
+class UTA_Loss(nn.Module):
+ """mask align clip loss."""
+
+ def __init__(self, uta_norm_type='l2', uta_loss_type='l2'):
+ super().__init__()
+ self.norm_type = uta_norm_type
+ self.loss_type = uta_loss_type
+ logger.info(f'Norm type: {uta_norm_type}')
+ logger.info(f'Loss type: {uta_loss_type}')
+
+ if uta_loss_type == 'mse':
+ self.loss_func = nn.MSELoss()
+ elif uta_loss_type == 'smooth_l1':
+ self.loss_func = nn.SmoothL1Loss()
+
+ def uta_loss(self, student_output, clip_output):
+ """forward to calculate the loss
+
+ Args:
+ student_output (torch.Tensor): The student output. Shape: [K,B,N,C].
+ clip_output (torch.Tensor): The teacher representation. Shape: [K,B,N,C].
+
+ Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
+ """
+
+ if self.norm_type == 'l2':
+ student_output = student_output / student_output.norm(dim=-1, keepdim=True)
+ clip_output = clip_output / clip_output.norm(dim=-1, keepdim=True)
+ elif self.norm_type == 'none':
+ pass
+ else:
+ raise NotImplementedError
+
+ if self.loss_type == 'l2':
+ loss_uta = (2 - 2 * (student_output * clip_output).sum(dim=-1)).mean()
+ elif self.loss_type in ['mse', 'smooth_l1']:
+ loss_uta = self.loss_func(input=student_output, target=clip_output)
+ else:
+ raise NotImplementedError
+
+ return loss_uta
+
+ def uta_vision_loss(self, student_v_output, clip_v_output):
+ """forward to calculate the loss
+
+ Args:
+ student_v_output (torch.Tensor): The student output. Shape: [B,T,C].
+ clip_v_output (torch.Tensor): The teacher representation. Shape: [B,T,C].
+
+ Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
+ """
+
+ if student_v_output.shape[1] != clip_v_output.shape[1]:
+ student_v_output = student_v_output.mean(1, keepdim=True)
+ clip_v_output = clip_v_output.mean(1, keepdim=True)
+ if self.norm_type == 'l2':
+ student_v_output = student_v_output / student_v_output.norm(dim=-1, keepdim=True)
+ clip_v_output = clip_v_output / clip_v_output.norm(dim=-1, keepdim=True)
+ elif self.norm_type == 'none':
+ pass
+ else:
+ raise NotImplementedError
+
+ if self.loss_type == 'l2':
+ loss_uta = (2 - 2 * (student_v_output * clip_v_output).sum(dim=-1)).mean()
+ elif self.loss_type in ['mse', 'smooth_l1']:
+ loss_uta = self.loss_func(input=student_v_output, target=clip_v_output)
+ else:
+ raise NotImplementedError
+
+ return loss_uta
+
+ def uta_all_loss(
+ self,
+ student_v_output, clip_v_output,
+ student_t_output, clip_t_output,
+ ):
+ """forward to calculate the loss
+
+ Args:
+ student_v_output (torch.Tensor): The student output. Shape: [B,T,C].
+ clip_v_output (torch.Tensor): The teacher representation. Shape: [B,T,C].
+ student_t_output (torch.Tensor): The student output. Shape: [B,1,C].
+ clip_t_output (torch.Tensor): The teacher representation. Shape: [B,1,C].
+
+ Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
+ """
+
+ if student_v_output.shape[1] != clip_v_output.shape[1]:
+ student_v_output = student_v_output.mean(1, keepdim=True)
+ clip_v_output = clip_v_output.mean(1, keepdim=True)
+ if self.norm_type == 'l2':
+ student_v_output = student_v_output / student_v_output.norm(dim=-1, keepdim=True)
+ clip_v_output = clip_v_output / clip_v_output.norm(dim=-1, keepdim=True)
+ student_t_output = student_t_output / student_t_output.norm(dim=-1, keepdim=True)
+ clip_t_output = clip_t_output / clip_t_output.norm(dim=-1, keepdim=True)
+ elif self.norm_type == 'none':
+ pass
+ else:
+ raise NotImplementedError
+
+ if self.loss_type == 'l2':
+ loss_uta_v = (2 - 2 * (student_v_output * clip_v_output).sum(dim=-1)).mean()
+ loss_uta_t = (2 - 2 * (student_t_output * clip_t_output).sum(dim=-1)).mean()
+ elif self.loss_type in ['mse', 'smooth_l1']:
+ loss_uta_v = self.loss_func(input=student_v_output, target=clip_v_output)
+ loss_uta_t = self.loss_func(input=student_t_output, target=clip_t_output)
+ else:
+ raise NotImplementedError
+
+ return (loss_uta_v + loss_uta_t) / 2.
+
+
+class new_UTA_Loss(nn.Module):
+ """mask align clip loss."""
+
+ def __init__(self, distill_final_features=True, clip_loss_ratio=[1., 1.]):
+ super().__init__()
+ self.distill_final_features = distill_final_features
+ self.clip_loss_ratio = clip_loss_ratio
+
+ logger.info(f'distill_final_features: {distill_final_features}')
+ logger.info(f'clip_loss_ratio: {clip_loss_ratio}')
+
+
+ def uta_loss(self, student_output, student_output_final,
+ targets_clip_middle_vis, targets_clip_final_vis):
+ """forward to calculate the loss
+
+ Args:
+ student_output (torch.Tensor): The student output. Shape: [K,B,N,C].
+ clip_output (torch.Tensor): The teacher representation. Shape: [K,B,N,C].
+
+ Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
+ """
+ loss_clip_middle = (2 - 2 * (student_output * targets_clip_middle_vis).sum(dim=-1)).mean()
+ if self.distill_final_features and self.clip_loss_ratio[1] > 0:
+ loss_clip_final = (2 - 2 * (student_output_final * targets_clip_final_vis).sum(dim=-1)).mean()
+ else:
+ loss_clip_final = torch.zeros(1).type_as(loss_clip_middle).to(loss_clip_middle.device)
+ loss_uta = loss_clip_middle * self.clip_loss_ratio[0] + loss_clip_final * self.clip_loss_ratio[1]
+ return loss_uta
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/dist_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/models/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..625cff416e5a31b81f150d6ae674d9c2301a4a66
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/dist_utils.py
@@ -0,0 +1,191 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+distributed API using Horovod
+Modified from OpenNMT's native pytorch distributed utils
+(https://github.com/OpenNMT/OpenNMT-py)
+"""
+import math
+import pickle
+
+import torch
+import torch.distributed as dist
+from time import time
+from torch.autograd import Function
+from torch.utils.data.distributed import DistributedSampler
+
+
+class ddp_allgather_with_grads(Function):
+ @staticmethod
+ def forward(ctx, x):
+ tmp_input = x.cuda()
+ size = torch.tensor(tmp_input.shape[0]).cuda()
+ size_list = [torch.zeros_like(size) for i in range(dist.get_world_size())]
+ dist.all_gather(size_list, size)
+ max_size = max(size_list).item()
+ padding_size = max_size - size
+ if padding_size > 0 :
+ padding_tensor = torch.zeros(padding_size,*tmp_input.shape[1:]).to(tmp_input)
+ tmp_input = torch.cat((tmp_input, padding_tensor), dim = 0)
+ tmp_list = [torch.zeros_like(tmp_input) for i in range(dist.get_world_size())]
+ dist.all_gather(tmp_list, tmp_input)
+ ctx.size = size_list
+ output = []
+ for t, s in zip(tmp_list, size_list):
+ output.append(t[:s])
+ output = torch.cat(output,dim=0)
+ output.requires_grad = True
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_x = None
+
+ if grad_output is not None:
+ grad_output.detach()
+ #grad_x = grad_output.chunk(dist.get_world_size(),dim=0)[dist.get_rank()]
+ start = sum(ctx.size[:dist.get_rank()])
+ end = start + ctx.size[dist.get_rank()]
+ grad_x = grad_output[start:end]
+ return grad_x
+
+
+def ddp_allgather(input):
+ tmp_input = input.cuda()
+ size = torch.tensor(tmp_input.shape[0]).cuda()
+ size_list = [torch.zeros_like(size) for i in range(dist.get_world_size())]
+ dist.all_gather(size_list, size)
+ max_size = max(size_list).item()
+ padding_size = max_size - size
+ if padding_size > 0 :
+ padding_tensor = torch.zeros(padding_size,*tmp_input.shape[1:]).to(tmp_input)
+ tmp_input = torch.cat((tmp_input, padding_tensor), dim = 0)
+ tmp_list = [torch.zeros_like(tmp_input) for i in range(dist.get_world_size())]
+ dist.all_gather(tmp_list, tmp_input)
+ output = []
+ for t, s in zip(tmp_list, size_list):
+ output.append(t[:s])
+ output = torch.cat(output,dim=0)
+ return output
+
+
+def _encode(enc, max_size, use_max_size=False):
+ enc_size = len(enc)
+ enc_byte = max(math.floor(math.log(max_size, 256)+1), 1)
+ if use_max_size:
+ # this is used for broadcasting
+ buffer_ = torch.cuda.ByteTensor(max_size+enc_byte)
+ else:
+ buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte)
+ remainder = enc_size
+ for i in range(enc_byte):
+ base = 256 ** (enc_byte-i-1)
+ buffer_[i] = remainder // base
+ remainder %= base
+ buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc))
+ return buffer_, enc_byte
+
+
+def _decode(buffer_, enc_byte):
+ size = sum(256 ** (enc_byte-i-1) * buffer_[i].item()
+ for i in range(enc_byte))
+ bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist())
+ shift = size + enc_byte
+ return bytes_list, shift
+
+
+_BUFFER_SIZE = 4096
+
+
+def all_gather_list(data):
+ """Gathers arbitrary data from all nodes into a list."""
+ enc = pickle.dumps(data)
+
+ enc_size = len(enc)
+ max_size = ddp_allgather(torch.tensor([enc_size]).cuda()).max().item()
+ in_buffer, enc_byte = _encode(enc, max_size)
+
+ out_buffer = ddp_allgather(in_buffer[:enc_byte+enc_size])
+
+ results = []
+ for _ in range(dist.get_world_size()):
+ bytes_list, shift = _decode(out_buffer, enc_byte)
+ out_buffer = out_buffer[shift:]
+ result = pickle.loads(bytes_list)
+ results.append(result)
+ return results
+
+
+def any_broadcast(data, root_rank):
+ """broadcast arbitrary data from root_rank to all nodes."""
+ enc = pickle.dumps(data)
+
+ max_size = ddp_allgather(torch.tensor([len(enc)]).cuda()).max().item()
+ buffer_, enc_byte = _encode(enc, max_size, use_max_size=True)
+
+ dist.broadcast(buffer_, root_rank)
+
+ bytes_list, _ = _decode(buffer_, enc_byte)
+ result = pickle.loads(bytes_list)
+ return result
+
+
+class DistributedSampler_wopadding(DistributedSampler):
+
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch and seed
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
+ else:
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
+
+ if self.drop_last:
+ indices = indices[:self.total_size]
+ #assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank:len(indices):self.num_replicas]
+ # assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+
+class GatherLayer(torch.autograd.Function):
+ """
+ Gather tensors from all workers with support for backward propagation:
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = [
+ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
+ ]
+ torch.distributed.all_gather(output, x)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ all_gradients = torch.stack(grads)
+ torch.distributed.all_reduce(all_gradients)
+ return all_gradients[torch.distributed.get_rank()]
+
+
+def all_gather_with_grad(tensors):
+ """
+ Performs all_gather operation on the provided tensors.
+ Graph remains connected for backward grad computation.
+ """
+ # Queue the gathered tensors
+ world_size = torch.distributed.get_world_size()
+ # There is no need for reduction in the single-proc case
+ if world_size == 1:
+ return tensors
+
+ # tensor_all = GatherLayer.apply(tensors)
+ tensor_all = GatherLayer.apply(tensors)
+
+ return torch.cat(tensor_all, dim=0)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/internvideo2_clip.py b/third_party/InternVideo/InternVideo2/multi_modality/models/internvideo2_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a465c6099b76bb9edd1ee786a62923cf306feb
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/internvideo2_clip.py
@@ -0,0 +1,252 @@
+import logging
+
+import torch
+from torch import nn
+import numpy as np
+from PIL import Image
+import torchvision.transforms as transforms
+from torchvision.transforms import InterpolationMode
+
+from .backbones.internvideo2 import InternVideo2, LLaMA, Tokenizer
+from .criterions import VTC_VTM_Loss
+
+logger = logging.getLogger(__name__)
+
+
+class InternVideo2_CLIP(nn.Module):
+ def __init__(self, config, tokenizer=None, is_pretrain=True):
+ super().__init__()
+
+ self.config = config
+ self.tokenizer = tokenizer
+ self.is_pretrain = is_pretrain
+
+ # create modules.
+ if tokenizer is None:
+ self.tokenizer = Tokenizer(config.model.tokenizer_path)
+ self.vision_encoder = self.build_vision_encoder()
+ self.text_encoder = self.build_text_encoder()
+ # adopt 1 / 100. as in ViCLIP
+ self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
+ self.temp_min = config.model.temp_min
+
+ # freeze model
+ if self.config.model.freeze_vision:
+ for name, p in self.vision_encoder.named_parameters():
+ if self.config.model.open_vision_clip_projector and name.startswith('clip_projector'):
+ logger.info(f"Unfreeze {name}")
+ else:
+ logger.info(f"Freeze {name}")
+ p.requires_grad = False
+ if self.config.model.freeze_text:
+ for name, p in self.text_encoder.named_parameters():
+ if self.config.model.open_text_projection and name.startswith('text_projection'):
+ logger.info(f"Unfreeze {name}")
+ elif self.config.model.open_text_lora and 'lora' in name:
+ logger.info(f"Unfreeze {name}")
+ else:
+ logger.info(f"Freeze {name}")
+ p.requires_grad = False
+
+ img_size = self.config.model.vision_encoder.img_size
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(
+ (img_size, img_size),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.Lambda(lambda x: x.float().div(255.0)),
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ ]
+ )
+
+ # load pretrained models
+ self.load_checkpoint(
+ config.model.vision_ckpt_path, config.model.text_ckpt_path,
+ config.model.get("extra_ckpt_path", None)
+ )
+
+ # criterions
+ self.clip_loss = VTC_VTM_Loss(False)
+
+ def no_weight_decay(self):
+ ret = {"temp"}
+ ret.update(
+ {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
+ )
+ # no weight decay for LLM if training
+ ret.update(
+ {"text_encoder." + k for k, _ in self.text_encoder.named_parameters()}
+ )
+
+ return ret
+
+ @torch.no_grad()
+ def clip_contrastive_temperature(self):
+ """Seems only used during pre-training"""
+ self.temp.clamp_(min=self.temp_min)
+
+ def forward(self, image, text, idx):
+ """forward and calculate loss.
+
+ Args:
+ image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
+ text (dict): TODO
+ idx (torch.Tensor): TODO
+
+ Returns: TODO
+
+ """
+ self.clip_contrastive_temperature()
+ vision_embeds = self.encode_vision(image)
+ text_embeds = self.encode_text(text)
+
+ # VTC loss
+ loss_vtc = self.clip_loss.vtc_loss(
+ vision_embeds, text_embeds, idx, self.temp, all_gather=True
+ )
+
+ return dict(
+ loss_vtc=loss_vtc,
+ )
+
+ def encode_vision(self, image, test=False):
+ """encode image / videos as features.
+
+ Args:
+ image (torch.Tensor): The input images.
+ test (bool): Whether testing.
+
+ Returns: tuple.
+ - vision_embeds (torch.Tensor): The features of all patches. Shape: [B,C].
+
+ """
+ T = image.shape[1]
+ use_image = True if T == 1 else False
+ image = image.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W]
+
+ vision_embeds = self.vision_encoder(image, use_image=use_image)
+ return vision_embeds
+
+ def encode_text(self, text):
+ """encode text.
+ Args:
+ text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
+ - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
+ - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token.
+ - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".
+ Returns: tuple.
+ - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,C].
+
+ """
+ text_embeds = self.text_encoder(text)
+ return text_embeds
+
+ def build_vision_encoder(self):
+ """build vision encoder
+ Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
+
+ """
+ vision_encoder = InternVideo2(
+ in_chans=self.config.model.vision_encoder.in_chans,
+ patch_size=self.config.model.vision_encoder.patch_size,
+ img_size=self.config.model.vision_encoder.img_size,
+ qkv_bias=self.config.model.vision_encoder.qkv_bias,
+ drop_path_rate=self.config.model.vision_encoder.drop_path_rate,
+ head_drop_path_rate=self.config.model.vision_encoder.head_drop_path_rate,
+ embed_dim=self.config.model.vision_encoder.embed_dim,
+ num_heads=self.config.model.vision_encoder.num_heads,
+ mlp_ratio=self.config.model.vision_encoder.mlp_ratio,
+ init_values=self.config.model.vision_encoder.init_values,
+ qk_normalization=self.config.model.vision_encoder.qk_normalization,
+ depth=self.config.model.vision_encoder.depth,
+ use_flash_attn=self.config.model.vision_encoder.use_flash_attn,
+ use_fused_rmsnorm=self.config.model.vision_encoder.use_fused_rmsnorm,
+ use_fused_mlp=self.config.model.vision_encoder.use_fused_mlp,
+ fused_mlp_heuristic=self.config.model.vision_encoder.fused_mlp_heuristic,
+ attn_pool_num_heads=self.config.model.vision_encoder.attn_pool_num_heads,
+ clip_embed_dim=self.config.model.vision_encoder.clip_embed_dim,
+ layerscale_no_force_fp32=self.config.model.vision_encoder.layerscale_no_force_fp32,
+ num_frames=self.config.model.vision_encoder.num_frames,
+ tubelet_size=self.config.model.vision_encoder.tubelet_size,
+ sep_pos_embed=self.config.model.vision_encoder.sep_pos_embed,
+ use_checkpoint=self.config.model.vision_encoder.use_checkpoint,
+ checkpoint_num=self.config.model.vision_encoder.checkpoint_num,
+ )
+ return vision_encoder
+
+ def build_text_encoder(self):
+ """build text_encoder and possiblly video-to-text multimodal fusion encoder.
+ Returns: nn.Module. The text encoder
+
+ """
+ text_encoder = LLaMA(
+ use_flash_attn=self.config.model.text_encoder.use_flash_attn,
+ transformer_width=self.config.model.text_encoder.transformer_width,
+ llama_path=self.config.model.text_encoder.llama_path,
+ use_lora=self.config.model.text_encoder.use_lora,
+ )
+
+ return text_encoder
+
+ def load_checkpoint(self, vision_ckpt_path=None, text_ckpt_path=None, extra_ckpt_path=None):
+ assert vision_ckpt_path is not None, "No vision_encoder checkpoint"
+ assert text_ckpt_path is not None, "No text_encoder checkpoint"
+
+ new_ckpt = {}
+
+ # load vision_encoder
+ logger.info(f"Load vision_encoder checkpoint from {vision_ckpt_path}")
+ vision_ckpt = torch.load(vision_ckpt_path, map_location='cpu')
+ if 'module' in vision_ckpt.keys():
+ vision_ckpt = vision_ckpt['module']
+ elif 'model' in vision_ckpt.keys():
+ vision_ckpt = vision_ckpt['model']
+ if self.config.model.get('load_vision_ckpt_from_internvideo2_stage2', False):
+ from .backbones.internvideo2.pos_embed import interpolate_pos_embed
+ orig_t_size = self.config.model.get('vision_ckpt_t_size', 4)
+ interpolate_pos_embed(vision_ckpt, self.vision_encoder, orig_t_size=orig_t_size) # 4 for InternVideo2 stage2
+ for k, v in vision_ckpt.items():
+ if k.startswith('vision_encoder.'):
+ if 'clip_decoder' in k or 'final_clip_decoder' in k:
+ continue
+ elif 'clip_pos_embed' in k or 'clip_img_pos_embed' in k or 'img_pos_embed' in k :
+ continue
+ else:
+ new_ckpt[k] = v
+ else:
+ continue
+ else:
+ for k, v in vision_ckpt.items():
+ if k.startswith('clip_decoder.') or k.startswith('mae_decoder.') or k.startswith('final_clip_decoder.'):
+ continue
+ elif k in ['clip_pos_embed', 'mae_pos_embed']:
+ continue
+ else:
+ new_k = 'vision_encoder.' + k
+ new_ckpt[new_k] = v
+
+ # load text_encoder
+ logger.info(f"Load text_encoder checkpoint from {text_ckpt_path}")
+ test_ckpt = torch.load(text_ckpt_path, map_location='cpu')
+ if 'module' in test_ckpt.keys():
+ test_ckpt = test_ckpt['module']
+ for k, v in test_ckpt.items():
+ if k.startswith('transformer.') or k == 'text_projection':
+ new_k = "text_encoder." + k
+ else:
+ continue
+ new_ckpt[new_k] = v
+
+ # load extra checkpoint
+ # often when post-pretrain after previous pretraining, thus the keys are same
+ if extra_ckpt_path is not None:
+ logger.info(f"Load extra checkpoint from {extra_ckpt_path}")
+ extra_ckpt = torch.load(extra_ckpt_path, map_location='cpu')
+ if 'module' in extra_ckpt.keys():
+ extra_ckpt = extra_ckpt['module']
+ for k, v in extra_ckpt.items():
+ new_ckpt[k] = v
+
+ msg = self.load_state_dict(new_ckpt, strict=False)
+ logger.info(msg)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/internvideo2_stage2.py b/third_party/InternVideo/InternVideo2/multi_modality/models/internvideo2_stage2.py
new file mode 100644
index 0000000000000000000000000000000000000000..458732aff7f024c0aa896bee0295dd662f20e720
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/internvideo2_stage2.py
@@ -0,0 +1,361 @@
+import logging
+
+import torch
+from torch import nn
+
+from .backbones.internvideo2 import pretrain_internvideo2_1b_patch14_224, pretrain_internvideo2_6b_patch14_224, internvl_clip_6b
+from .backbones.bert.builder import build_bert
+from .criterions import MLMLoss, VTC_VTM_Loss, new_UTA_Loss
+from .mask import (
+ TubeMaskingGenerator,
+ RandomMaskingGenerator
+)
+
+logger = logging.getLogger(__name__)
+
+
+class InternVideo2_Stage2(nn.Module):
+ """docstring for InternVideo2_Stage2"""
+
+ def __init__(self, config, tokenizer, is_pretrain=True):
+ super(InternVideo2_Stage2, self).__init__()
+
+ self.config = config
+ self.tokenizer = tokenizer
+
+ self.is_pretrain = is_pretrain
+ self.vision_width = config.model.vision_encoder.clip_embed_dim
+ self.text_width = config.model.text_encoder.d_model
+ self.embed_dim = config.model.embed_dim
+
+ # create modules.
+ self.vision_encoder = self.build_vision_encoder()
+ if config.model.get("freeze_vision", False):
+ self.freeze_vision()
+
+ self.text_encoder = self.build_text_encoder()
+ if config.model.get("freeze_text", False):
+ self.freeze_text()
+
+ self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
+ self.text_proj = nn.Linear(self.text_width, self.embed_dim)
+
+ self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
+ self.itm_head = nn.Linear(self.text_width, 2)
+
+ # criterions
+ self.loss_weight = config.criterion.loss_weight
+ self.criterion_uta = new_UTA_Loss(
+ config.criterion.distill_final_features,
+ config.criterion.clip_loss_ratio,
+ )
+ self.criterion_vtc_vtm = VTC_VTM_Loss(config.criterion.vtm_hard_neg)
+ self.criterion_mlm = MLMLoss(config.criterion.mlm_masking_prob, tokenizer)
+ self.uta_image_only = config.criterion.get('uta_image_only', False)
+ logger.info(f"uta_image_only={self.uta_image_only}")
+
+ def freeze_vision(self):
+ """freeze vision encoder"""
+ for p in self.vision_encoder.parameters():
+ p.requires_grad = False
+
+ def freeze_text(self):
+ """freeze text encoder"""
+ for p in self.text_encoder.parameters():
+ p.requires_grad = False
+
+ def no_weight_decay(self):
+ ret = {"temp"}
+ ret.update(
+ {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
+ )
+ # ret.update(
+ # {"text_encoder." + k for k in self.text_encoder.no_weight_decay()}
+ # )
+
+ return ret
+
+ @property
+ def dtype(self):
+ return self.vision_encoder.patch_embed.proj.weight.dtype
+
+ def forward(self, image, text, idx, media_type='image'):
+ """forward and calculate loss.
+
+ Args:
+ image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
+ text (dict)
+ idx (torch.Tensor)
+ media_type: str
+ Returns:
+
+ """
+
+ self.clip_contrastive_temperature()
+ T = image.shape[1]
+ use_image = True if T == 1 else False
+
+ vision_embeds, pooled_vision_embeds, student_output, student_output_final, targets_clip_middle_vis, targets_clip_final_vis = self.encode_vision(image)
+
+ text_embeds, pooled_text_embeds = self.encode_text(text)
+
+ # obtain vision and text representations.
+ vision_proj = self.vision_proj(pooled_vision_embeds)
+ text_proj = self.text_proj(pooled_text_embeds)
+
+ # calculate loss
+ ## UTA loss
+ if self.loss_weight.uta != 0:
+ if self.uta_image_only and not use_image:
+ loss_uta = torch.tensor(0)
+ else:
+ loss_uta = self.criterion_uta.uta_loss(student_output, student_output_final, targets_clip_middle_vis, targets_clip_final_vis)
+ else:
+ loss_uta = torch.tensor(0)
+
+ ## VTC loss
+ if self.loss_weight.vtc != 0:
+ loss_vtc = self.criterion_vtc_vtm.vtc_loss(
+ vision_proj, text_proj, idx, self.temp, all_gather=True
+ )
+ else:
+ loss_vtc = torch.tensor(0)
+
+ ## VTM loss
+ if self.loss_weight.vtm != 0:
+ loss_vtm = self.criterion_vtc_vtm.vtm_loss(
+ self.get_text_encoder(),
+ self.itm_head,
+ self.temp,
+ vision_embeds,
+ text_embeds,
+ vision_proj,
+ text_proj,
+ text.attention_mask,
+ idx,
+ )
+ else:
+ loss_vtm = torch.tensor(0)
+
+ ## MLM loss
+ if self.is_pretrain and self.loss_weight.mlm != 0:
+ loss_mlm = self.criterion_mlm.mlm_loss(
+ self.text_encoder, text, vision_embeds, None
+ )
+ else:
+ loss_mlm = torch.tensor(0)
+
+ return dict(
+ loss_uta=loss_uta * self.loss_weight.uta,
+ loss_vtc=loss_vtc * self.loss_weight.vtc,
+ loss_vtm=loss_vtm * self.loss_weight.vtm,
+ loss_mlm=loss_mlm * self.loss_weight.mlm,
+ )
+
+ def encode_teacher(self, image):
+ """encode image / videos as features.
+
+ Args:
+ image (torch.Tensor): The input images.
+
+ Returns: tuple.
+ - mask (torch.Tensor): Mask. Shape: [B,N1].
+ - d_mask (torch.Tensor): Double Mask. Shape: [B,N2].
+ - clip_output (torch.Tensor): The features of clip. Shape: [K,B,N,C].
+
+ """
+ B, C, T, H, W = image.shape
+ mask_type = self.image_mask_type if T == 1 else self.video_mask_type
+ window_size = self.image_window_size if T == 1 else self.video_window_size
+ mask_ratio = self.image_mask_ratio if T == 1 else self.video_mask_ratio
+
+ if (self.uta_image_only and T != 1) or self.config.model.vision_encoder.get('only_mask', False):
+ if mask_type == 'tube':
+ mask = TubeMaskingGenerator(window_size, mask_ratio, B)
+ elif mask_type == 'random':
+ mask = RandomMaskingGenerator(window_size, mask_ratio, B)
+ elif mask_type == 'none':
+ return None, None, None
+ else:
+ raise NotImplementedError
+
+ mask = mask.view(B, -1).to(torch.bool)
+ mask = torch.cat((torch.zeros(B, 1).to(mask.device), mask), dim=1)
+ mask = mask.to(torch.bool)
+
+ return mask, None, None
+
+ if self.clip_teacher is None or self.loss_weight.uta == 0:
+ return None, None, None
+
+ if H != self.clip_img_size:
+ image = torch.nn.functional.interpolate(
+ image.reshape(B, C*T, H, W),
+ size=(self.clip_img_size, self.clip_img_size),
+ mode='bicubic', align_corners=False
+ )
+ image = image.view(B, C, T, self.clip_img_size, self.clip_img_size)
+
+ with torch.no_grad():
+ if mask_type == 'tube':
+ mask = TubeMaskingGenerator(window_size, mask_ratio, B)
+ norm_clip_middle, norm_clip_final, attn = self.clip_teacher(image)
+ elif mask_type == 'random':
+ mask = RandomMaskingGenerator(window_size, mask_ratio, B)
+ norm_clip_middle, norm_clip_final, attn = self.clip_teacher(image)
+ elif mask_type in 'attention':
+ norm_clip_middle, norm_clip_final, attn = self.clip_teacher(image)
+ BT, N = attn.shape
+ N_vis = N - int(N * mask_ratio)
+ importance = torch.multinomial(attn, N)
+ mask = torch.ones((BT, N))
+ pos1 = torch.arange(BT).view(-1, 1).repeat(1, N_vis)
+ pos2 = importance[:, :N_vis]
+ mask[pos1, pos2] = 0
+ else:
+ raise NotImplementedError
+
+ mask = mask.view(B, -1).to(torch.bool)
+ mask = torch.cat((torch.zeros(B, 1), mask), dim=1)
+ mask = mask.to(torch.bool)
+
+ # mask clip output
+ C_CLIP = norm_clip_middle.shape[-1]
+ if len(norm_clip_middle.shape) == 4:
+ K = norm_clip_middle.shape[0]
+ clip_mask = mask.unsqueeze(0).repeat(K, 1, 1)
+ targets_clip_middle_vis = norm_clip_middle[~clip_mask].reshape(K, B, -1, C_CLIP)
+ else:
+ clip_mask = mask
+ targets_clip_middle_vis = norm_clip_middle[~clip_mask].reshape(B, -1, C_CLIP)
+
+ targets_clip_final_vis = norm_clip_final # only one tokens
+
+ return mask, targets_clip_middle_vis, targets_clip_final_vis
+
+ def encode_vision(self, image, test=False):
+ """encode image / videos as features.
+
+ Args:
+ image (torch.Tensor): The input images.
+ test (bool): Whether testing.
+
+ Returns: tuple.
+ - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C].
+ - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C].
+ - student_output (torch.Tensor): The features of alignment. Shape: [K,B,N,C].
+ - clip_output (torch.Tensor): The features of clip. Shape: [K,B,N,C].
+
+ """
+
+ T = image.shape[1]
+ use_image = True if T == 1 else False
+ image = image.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W]
+ # whether save temporal dimension
+ # keep_temporal=self.config.model.vision_encoder.keep_temporal
+ if test:
+ vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
+ image, None, use_image)
+ return vision_embeds, pooled_vision_embeds
+ else:
+ mask, targets_clip_middle_vis, targets_clip_final_vis = self.encode_teacher(image)
+ # if mask is not None and (self.video_mask_type != 'tube' or self.image_mask_type != 'tube'):
+ # keep_temporal = False
+ # print(f"\033[31mmask is {type(mask)}\033[0m")
+ vision_embeds, pooled_vision_embeds, student_output, student_output_final = self.vision_encoder(
+ image, mask, use_image)
+ return vision_embeds, pooled_vision_embeds, student_output, student_output_final, targets_clip_middle_vis, targets_clip_final_vis
+
+ def encode_text(self, text):
+ """encode text.
+ Args:
+ text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
+ - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
+ - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token.
+ - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".
+ Returns: tuple.
+ - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C].
+ - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C].
+
+ """
+ text_output = self.get_text_encoder()(
+ text.input_ids,
+ attention_mask=text.attention_mask,
+ return_dict=True,
+ mode="text",
+ )
+ text_embeds = text_output.last_hidden_state
+ pooled_text_embeds = text_embeds[:, 0]
+ return text_embeds, pooled_text_embeds
+
+ @torch.no_grad()
+ def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
+ """Seems only used during pre-training"""
+ self.temp.clamp_(min_val, max_val)
+
+ def build_vision_encoder(self):
+ """build vision encoder
+ Returns: (vision_encoder, clip_teacher). Each is a `nn.Module`.
+
+ """
+ encoder_name = self.config.model.vision_encoder.name
+ logger.info(f"Build vision_encoder: {encoder_name}")
+ if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
+ vision_encoder = pretrain_internvideo2_1b_patch14_224(self.config.model)
+ elif encoder_name == 'pretrain_internvideo2_6b_patch14_224':
+ vision_encoder = pretrain_internvideo2_6b_patch14_224(self.config.model)
+ else:
+ raise ValueError(f"Not implemented: {encoder_name}")
+
+ teacher_name = self.config.model.vision_encoder.clip_teacher
+ self.clip_teacher = None
+ if teacher_name != None:
+ assert teacher_name == 'internvl_clip_6b'
+ self.clip_teacher = internvl_clip_6b(
+ img_size=self.config.model.vision_encoder.clip_input_resolution,
+ clip_norm_type=self.config.model.vision_encoder.clip_norm_type,
+ return_attn=True,
+ clip_return_layer=self.config.model.vision_encoder.clip_return_layer,
+ clip_return_interval=self.config.model.vision_encoder.clip_teacher_return_interval
+ )
+ for p in self.clip_teacher.parameters():
+ p.requires_grad = False
+
+ # parameters for mask
+ img_size = self.config.model.vision_encoder.img_size
+ num_frames = self.config.model.vision_encoder.num_frames
+ tublet_size = self.config.model.vision_encoder.tubelet_size
+ patch_size = self.config.model.vision_encoder.patch_size
+ self.clip_img_size = self.config.model.vision_encoder.clip_input_resolution
+ self.video_mask_type = self.config.model.vision_encoder.video_mask_type
+ self.video_window_size = (num_frames // tublet_size, img_size // patch_size, img_size // patch_size)
+ self.video_mask_ratio = self.config.model.vision_encoder.video_mask_ratio
+ self.image_mask_type = self.config.model.vision_encoder.image_mask_type
+ self.image_window_size = (1, img_size // patch_size, img_size // patch_size)
+ self.image_mask_ratio = self.config.model.vision_encoder.image_mask_ratio
+
+ return vision_encoder
+
+ def build_text_encoder(self):
+ """build text_encoder and possiblly video-to-text multimodal fusion encoder.
+ Returns: nn.Module. The text encoder
+
+ """
+ encoder_name = self.config.model.text_encoder.name
+ logger.info(f"Build text_encoder {encoder_name}")
+
+ if "bert" in encoder_name:
+ text_encoder = build_bert(
+ self.config.model,
+ self.is_pretrain,
+ self.config.gradient_checkpointing,
+ )
+ else:
+ raise ValueError(f"Not implemented: {encoder_name}")
+
+ return text_encoder
+
+ def get_text_encoder(self):
+ """get text encoder, used for text and cross-modal encoding"""
+ encoder = self.text_encoder
+ return encoder.bert if hasattr(encoder, "bert") else encoder
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/mask.py b/third_party/InternVideo/InternVideo2/multi_modality/models/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cc6095892570c202b66434640151a7af73da2d1
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/mask.py
@@ -0,0 +1,37 @@
+import torch
+import numpy as np
+
+
+def TubeMaskingGenerator(input_size, mask_ratio, batch, device='cuda'):
+ frames, height, width = input_size
+ num_patches_per_frame = height * width
+ num_masks_per_frame = int(mask_ratio * num_patches_per_frame)
+
+ mask_list = []
+ for _ in range(batch):
+ mask_per_frame = np.hstack([
+ np.zeros(num_patches_per_frame - num_masks_per_frame),
+ np.ones(num_masks_per_frame),
+ ])
+ np.random.shuffle(mask_per_frame)
+ mask_list.append(np.tile(mask_per_frame, (frames, 1)).flatten())
+ mask = torch.Tensor(mask_list).to(device, non_blocking=True).to(torch.bool)
+ return mask
+
+
+def RandomMaskingGenerator(input_size, mask_ratio, batch, device='cuda'):
+ frames, height, width = input_size
+
+ num_patches = frames * height * width # 8x14x14
+ num_mask = int(mask_ratio * num_patches)
+
+ mask_list = []
+ for _ in range(batch):
+ mask = np.hstack([
+ np.zeros(num_patches - num_mask),
+ np.ones(num_mask),
+ ])
+ np.random.shuffle(mask)
+ mask_list.append(mask)
+ mask = torch.Tensor(np.array(mask_list)).to(device, non_blocking=True).to(torch.bool)
+ return mask
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/models/utils.py b/third_party/InternVideo/InternVideo2/multi_modality/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7617fa228550b5bc5d8f0e5fc72a4ee7b8919e09
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/models/utils.py
@@ -0,0 +1,299 @@
+import logging
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy import interpolate
+from typing import List
+
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
+ """
+ Add/Remove extra temporal_embeddings as needed.
+ https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
+
+ temp_embed_old: (1, num_frames_old, 1, d)
+ temp_embed_new: (1, num_frames_new, 1, d)
+ add_zero: bool, if True, add zero, else, interpolate trained embeddings.
+ """
+ # TODO zero pad
+ num_frms_new = temp_embed_new.shape[1]
+ num_frms_old = temp_embed_old.shape[1]
+ logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
+ if num_frms_new > num_frms_old:
+ if add_zero:
+ temp_embed_new[
+ :, :num_frms_old
+ ] = temp_embed_old # untrained embeddings are zeros.
+ else:
+ temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
+ elif num_frms_new < num_frms_old:
+ temp_embed_new = temp_embed_old[:, :num_frms_new]
+ else: # =
+ temp_embed_new = temp_embed_old
+ return temp_embed_new
+
+
+def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
+ """
+ temp_embed_old: (1, num_frames_old, 1, d)
+ Returns:
+ temp_embed_new: (1, num_frames_new, 1, d)
+ """
+ temp_embed_old = temp_embed_old.squeeze(2).permute(
+ 0, 2, 1
+ ) # (1, d, num_frames_old)
+ temp_embed_new = F.interpolate(
+ temp_embed_old, num_frames_new, mode="linear"
+ ) # (1, d, num_frames_new)
+ temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze(
+ 2
+ ) # (1, num_frames_new, 1, d)
+ return temp_embed_new
+
+
+def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
+ """
+ Args:
+ pos_embed_old: (1, L_old, d), pre-trained
+ pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights
+ num_patches_new:
+ """
+ # interpolate position embedding
+ embedding_size = pos_embed_old.shape[-1]
+ num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches_new ** 0.5)
+
+ if orig_size != new_size:
+ # class_token and dist_token are kept unchanged
+ # the extra tokens seems always at the beginning of the position embedding
+ extra_tokens = pos_embed_old[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_old[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size, orig_size, embedding_size
+ ).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ logger.info(f"reshape position embedding from {orig_size}**2 to {new_size}**2")
+ return interpolated_pos_embed
+ else:
+ return pos_embed_old
+
+
+def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
+ """
+ Args:
+ state_dict_old: loaded state dict
+ state_dict_new: state dict for model with new image size
+ patch_shape_new: new model patch_shape
+ ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
+ """
+ all_keys = list(state_dict_old.keys())
+ for key in all_keys:
+ if "relative_position_index" in key:
+ state_dict_old.pop(key)
+
+ if "relative_position_bias_table" in key:
+ rel_pos_bias = state_dict_old[key]
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
+ dst_num_pos, _ = state_dict_new[key].size()
+ dst_patch_shape = patch_shape_new
+ if dst_patch_shape[0] != dst_patch_shape[1]:
+ raise NotImplementedError()
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
+ dst_patch_shape[1] * 2 - 1
+ )
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
+ if src_size != dst_size:
+ # logger.info("Position interpolate for %s from %dx%d to %dx%d" % (
+ # key, src_size, src_size, dst_size, dst_size))
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
+
+ def geometric_progression(a, r, n):
+ return a * (1.0 - r ** n) / (1.0 - r)
+
+ left, right = 1.01, 1.5
+ while right - left > 1e-6:
+ q = (left + right) / 2.0
+ gp = geometric_progression(1, q, src_size // 2)
+ if gp > dst_size // 2:
+ right = q
+ else:
+ left = q
+
+ # if q > 1.090307:
+ # q = 1.090307
+
+ dis = []
+ cur = 1
+ for i in range(src_size // 2):
+ dis.append(cur)
+ cur += q ** (i + 1)
+
+ r_ids = [-_ for _ in reversed(dis)]
+
+ x = r_ids + [0] + dis
+ y = r_ids + [0] + dis
+
+ t = dst_size // 2.0
+ dx = np.arange(-t, t + 0.1, 1.0)
+ dy = np.arange(-t, t + 0.1, 1.0)
+
+ # logger.info("Original positions = %s" % str(x))
+ # logger.info("Target positions = %s" % str(dx))
+
+ all_rel_pos_bias = []
+
+ for i in range(num_attn_heads):
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
+ f = interpolate.interp2d(x, y, z, kind="cubic")
+ all_rel_pos_bias.append(
+ torch.Tensor(f(dx, dy))
+ .contiguous()
+ .view(-1, 1)
+ .to(rel_pos_bias.device)
+ )
+
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
+
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
+ state_dict_old[key] = new_rel_pos_bias
+ return state_dict_old
+
+
+def tile(x, dim, n_tile):
+ init_dim = x.size(dim)
+ repeat_idx = [1] * x.dim()
+ repeat_idx[dim] = n_tile
+ x = x.repeat(*repeat_idx)
+ order_index = torch.LongTensor(
+ np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
+ )
+ return torch.index_select(x, dim, order_index.to(x.device))
+
+
+def mask_logits(target, mask):
+ return target * mask + (1 - mask) * (-1e10)
+
+
+class AllGather(torch.autograd.Function):
+ """An autograd function that performs allgather on a tensor."""
+
+ @staticmethod
+ def forward(ctx, tensor, args):
+ output = [torch.empty_like(tensor) for _ in range(args.world_size)]
+ torch.distributed.all_gather(output, tensor)
+ ctx.rank = args.rank
+ ctx.batch_size = tensor.shape[0]
+ return torch.cat(output, dim=0)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return (
+ grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
+ None,
+ )
+
+
+allgather_wgrad = AllGather.apply
+
+
+def tie_encoder_decoder_weights(
+ encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str
+):
+ uninitialized_encoder_weights: List[str] = []
+ if decoder.__class__ != encoder.__class__:
+ if issubclass(decoder.__class__, encoder.__class__):
+ logger.info(
+ f"decoder ({decoder.__class__}) and encoder ({encoder.__class__}) are not equal, encoder is decoder's father. In this case make sure that all encoder weights are correctly initialized."
+ )
+ elif issubclass(encoder.__class__, decoder.__class__):
+ logger.info(
+ f"decoder ({decoder.__class__}) and encoder ({encoder.__class__}) are not equal, decoder is encoder's father. In this case make sure that all encoder weights are correctly initialized."
+ )
+ else:
+ raise ValueError(f"decoder ({decoder.__class__}) and encoder ({encoder.__class__}) are not equal!!!")
+
+ def tie_encoder_to_decoder_recursively(
+ decoder_pointer: nn.Module,
+ encoder_pointer: nn.Module,
+ module_name: str,
+ uninitialized_encoder_weights: List[str],
+ skip_key: str,
+ depth=0,
+ ):
+ assert isinstance(decoder_pointer, nn.Module) and isinstance(
+ encoder_pointer, nn.Module
+ ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
+ if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
+ assert hasattr(encoder_pointer, "weight")
+ encoder_pointer.weight = decoder_pointer.weight
+ if hasattr(decoder_pointer, "bias"):
+ assert hasattr(encoder_pointer, "bias")
+ encoder_pointer.bias = decoder_pointer.bias
+ logger.info(module_name + " is tied")
+ return
+
+ encoder_modules = encoder_pointer._modules
+ decoder_modules = decoder_pointer._modules
+ if len(decoder_modules) > 0:
+ assert (
+ len(encoder_modules) > 0
+ ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
+
+ all_encoder_weights = set(
+ [module_name + "/" + sub_name for sub_name in encoder_modules.keys()]
+ )
+ encoder_layer_pos = 0
+ for name, module in decoder_modules.items():
+ if name.isdigit():
+ encoder_name = str(int(name) + encoder_layer_pos)
+ decoder_name = name
+ if not isinstance(
+ decoder_modules[decoder_name],
+ type(encoder_modules[encoder_name]),
+ ) and len(encoder_modules) != len(decoder_modules):
+ # this can happen if the name corresponds to the position in a list module list of layers
+ # in this case the decoder has added a cross-attention that the encoder does not have
+ # thus skip this step and subtract one layer pos from encoder
+ encoder_layer_pos -= 1
+ continue
+ elif name not in encoder_modules:
+ continue
+ elif depth > 500:
+ raise ValueError(
+ "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
+ )
+ else:
+ decoder_name = encoder_name = name
+ tie_encoder_to_decoder_recursively(
+ decoder_modules[decoder_name],
+ encoder_modules[encoder_name],
+ module_name + "/" + name,
+ uninitialized_encoder_weights,
+ skip_key,
+ depth=depth + 1,
+ )
+ all_encoder_weights.remove(module_name + "/" + encoder_name)
+
+ uninitialized_encoder_weights += list(all_encoder_weights)
+
+ # tie weights recursively
+ tie_encoder_to_decoder_recursively(
+ decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key
+ )
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/preprocess/compress.py b/third_party/InternVideo/InternVideo2/multi_modality/preprocess/compress.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fcbc1e46b980d73c633ae8045ea74e6aadf46d0
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/preprocess/compress.py
@@ -0,0 +1,143 @@
+"""
+Used to compress videos (FPS and dimensions) in the Singularity project.
+Modified from https://github.com/ArrowLuo/CLIP4Clip
+"""
+import os
+from os.path import exists, join
+import argparse
+import subprocess
+from multiprocessing import Pool
+import shutil
+try:
+ from psutil import cpu_count
+except:
+ from multiprocessing import cpu_count
+from functools import partial
+from tqdm import tqdm
+from PIL import Image
+
+
+def resize_image(input_path, output_path, size=224):
+ with Image.open(input_path) as img:
+ w, h = img.width, img.height
+ r = 1. * w / h
+ if w > h:
+ h = size
+ w = r * size
+ else:
+ h = size / r
+ w = size
+
+ img_resized = img.resize((int(w), int(h)))
+ img_resized.save(output_path)
+
+
+def _compress_images(input_output_pair, size=224):
+ """
+ Scale and downsample an input image to a given fps and size (shorter side size).
+ This also removes the audio from the image.
+ """
+ input_image_path, output_image_path = input_output_pair
+ try:
+ resize_image(input_image_path, output_image_path, size)
+ except Exception as e:
+ print(f"Caught Exception {e}")
+
+
+def _compress_videos(input_output_pair, size=224, fps=3):
+ """
+ Scale and downsample an input video to a given fps and size (shorter side size).
+ This also removes the audio from the video.
+ """
+ input_file_path, output_file_path = input_output_pair
+ try:
+ command = ['ffmpeg',
+ '-y', # (optional) overwrite output file if it exists
+ '-i', input_file_path,
+ '-filter:v', # no audio
+ f"scale='if(gt(a,1),trunc(oh*a/2)*2,{size})':'if(gt(a,1),{size},trunc(ow*a/2)*2)'",
+ '-map', '0:v', # no audio
+ '-r', str(fps), # frames per second
+ # '-g', str(16),
+ output_file_path,
+ ]
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+ except Exception as e:
+ raise e
+
+
+def _compress(input_output_pair, fps=3, size=224, file_type="image"):
+ if file_type == "image":
+ _compress_images(input_output_pair, size)
+ elif file_type == "video":
+ _compress_videos(input_output_pair, size, fps)
+
+
+def prepare_input_output_pairs(input_root, output_root, input_file_list_path=None):
+ # filename list in `input_file_list_path` can be created very fast using `ls -U . >> ../video_filenames.txt`
+ with open(input_file_list_path, "r") as f:
+ filenames = [s.strip() for s in f.readlines()]
+ print(f"There are {len(filenames)} video/images files loaded from list.")
+ input_file_path_list = []
+ output_file_path_list = []
+ for e in tqdm(filenames, desc="find un-processed videos/images"):
+ input_file_path = join(input_root, e)
+ output_file_path = join(output_root, e)
+ if not exists(output_file_path):
+ input_file_path_list.append(input_file_path)
+ output_file_path_list.append(output_file_path)
+ return input_file_path_list, output_file_path_list
+
+
+def run_compress():
+ parser = argparse.ArgumentParser(description="Compress videos/images for speed-up")
+ parser.add_argument("--input_root", type=str, help="input root", required=True)
+ parser.add_argument("--input_file_list_path", type=str, required=True, default=None,
+ help="list of video filenames under args.input_root, it can be "
+ "created efficiently with `ls -U /path/to/video >> /path/to/video_filenames.txt`")
+ parser.add_argument("--output_root", type=str, help="output root", required=True)
+ parser.add_argument("--size", type=int, default=224, help="shorter side size, aspect ratio is kept")
+ parser.add_argument("--num_workers", type=int, default=24, help="#workers")
+ parser.add_argument("--fps", type=int, default=3, help="fps for output video, ignored if file_type == image")
+ parser.add_argument("--file_type", type=str, choices=["image", "video"], help="input file type")
+ args = parser.parse_args()
+
+ # set paths
+ input_root = args.input_root
+ output_root = args.output_root
+ assert input_root != output_root
+ if not exists(output_root):
+ os.makedirs(output_root, exist_ok=True)
+
+ # prepare and find un-processed
+ input_file_path_list, output_file_path_list = prepare_input_output_pairs(
+ input_root, output_root, input_file_list_path=args.input_file_list_path,
+ )
+ print(f"input_file_path_list[:3] {input_file_path_list[:3]}")
+ print(f"output_file_path_list[:3] {output_file_path_list[:3]}")
+ print("Total videos/images need to process: {}".format(len(input_file_path_list)))
+
+ # start parallel jobs
+ num_cores = cpu_count()
+ num_workers = args.num_workers
+ print(f"Begin with {num_cores}-core logical processor, {num_workers} workers")
+ compress = partial(_compress, fps=args.fps, size=args.size, file_type=args.file_type)
+ input_pairs = list(zip(input_file_path_list, output_file_path_list))
+ with Pool(num_workers) as pool, tqdm(total=len(input_file_path_list), desc="re-encoding videos/images") as pbar:
+ for idx, _ in enumerate(pool.imap_unordered(compress, input_pairs, chunksize=32)):
+ pbar.update(1)
+
+ # copy-paste failed files
+ print("Compress finished, copy-paste failed files...")
+ copy_count = 0
+ for input_file_path, output_file_path in zip(input_file_path_list, output_file_path_list):
+ if exists(input_file_path):
+ if exists(output_file_path) is False or os.path.getsize(output_file_path) < 1.:
+ copy_count += 1
+ shutil.copyfile(input_file_path, output_file_path)
+ print("Copy and replace file: {}".format(output_file_path))
+ print(f"copy_count {copy_count}")
+
+
+if __name__ == "__main__":
+ run_compress()
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/preprocess/create_sqlite_db.py b/third_party/InternVideo/InternVideo2/multi_modality/preprocess/create_sqlite_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..57655f516cf31ef2e528666f6ccd7455625ff6aa
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/preprocess/create_sqlite_db.py
@@ -0,0 +1,96 @@
+import json
+import os
+import sqlite3
+import time
+
+import numpy as np
+
+
+def convert_to_sqlite_db(src_path: str, dst_path: str, media_type: str):
+ """TODO: Docstring for convert_to_sqlite_db.
+
+ Args:
+ src_path (str): The src json annotation file path.
+ dst_path (str): The saved sqlite db path.
+ media_type (str): The media type. Either "image" or "video".
+
+ """
+
+ # con = sqlite3.connect("file:"+dst_path+"?mode=ro",uri=True)
+ con = sqlite3.connect(dst_path)
+ cur = con.cursor()
+ print(f"creating table")
+ cur.execute("DROP TABLE IF EXISTS annos")
+ table_sql = f""" CREATE TABLE IF NOT EXISTS `annos` (
+ `id` integer PRIMARY KEY,
+ `{media_type}` text,
+ `caption` text
+ )"""
+ print(table_sql)
+ cur.execute(table_sql)
+
+ with open(src_path, "r") as f:
+ anno_list = json.load(f)
+ filenames = [anno[media_type] for anno in anno_list]
+ captions = [anno["caption"] for anno in anno_list]
+ ids = list(range(len(filenames)))
+ records = list(zip(ids, filenames, captions))
+
+ cur.executemany(f"INSERT INTO annos (id, {media_type}, caption) VALUES (?,?,?)", records)
+ con.commit()
+ con.close()
+
+
+def read_sqlite(db_path):
+ """TODO: Docstring for read_sqlite.
+
+ Args:
+ db_path (TODO): TODO
+
+ Returns: TODO
+
+ """
+ con = sqlite3.connect("file:" + db_path + "?mode=ro", uri=True)
+ cur = con.cursor()
+ ids = cur.execute("SELECT id FROM annos").fetchall()
+ ids = [id[0] for id in ids]
+ print("number medias:", len(ids))
+ np.random.shuffle(ids)
+ for id in ids[:100]:
+ t1 = time.time()
+ query = f"SELECT * FROM annos WHERE id = {id};"
+ res = cur.execute(query)
+ newid, filename, caption = res.fetchone()
+ t2 = time.time()
+ print(f"time: {t2-t1}s", id, newid, filename, caption)
+ con.close()
+
+
+def convert(json_filename, media_type):
+ """convert json annotations to sqlite.
+ Returns: TODO
+
+ """
+ print(f"\n--------converting {filename}----------------")
+ # src_path = os.path.join(os.environ["SL_DATA_DIR"], "anno_pretrain", json_filename)
+ path = 'your_data_path/anno/anno_pretrain'
+ src_path = os.path.join(path, json_filename)
+ dst_path = src_path.replace(".json", ".sqlite.db")
+ convert_to_sqlite_db(src_path, dst_path, media_type)
+ read_sqlite(dst_path)
+
+
+if __name__ == "__main__":
+ filenames = [
+ # ["cc12m.json", "image"],
+ ["cc3m_train.json", "image"],
+ # ["cc3m_val.json", "image"],
+ # ["coco.json", "image"],
+ # ["sbu.json", "image"],
+ # ["vg.json", "image"],
+ # ["webvid_10m_train.json", "video"],
+ # ["webvid_10m_val.json", "video"],
+ ["webvid_train.json", "video"],
+ ]
+ for filename, media_type in filenames:
+ convert(filename, media_type)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/preprocess/gen_webvid10m_label.py b/third_party/InternVideo/InternVideo2/multi_modality/preprocess/gen_webvid10m_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0992caa660390fe008cf0b7d0cce1a2b54e4451
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/preprocess/gen_webvid10m_label.py
@@ -0,0 +1,69 @@
+import json
+import os
+from multiprocessing import Pool
+
+import pandas
+import tqdm
+
+from utils import get_video_duration
+
+data_dir = os.path.join(os.environ["SL_DATA_DIR"], "videos_images/webvid_10m_2fps_224")
+downloaded_vidlist = data_dir.replace("webvid_10m_2fps_224", "webvid_10m_vidlist.txt")
+
+def gen_valid_vidlist():
+ """generate the valid video list.
+ Returns: set. The valid
+
+ """
+ with open(downloaded_vidlist, 'r') as f:
+ videos = f.read().splitlines()
+ return set(videos)
+
+
+def gen_labels(src_file, dst_file):
+ """TODO: Docstring for gen_labels.
+
+ Args:
+ src_file (str): The original csv file
+ dst_file (str): the output json file
+ data_dir (str): The data to store the videos.
+
+ """
+ df = pandas.read_csv(src_file)
+ vids = df["videoid"].values.tolist()
+ captions = df["name"].values.tolist()
+
+ valid_videos = gen_valid_vidlist()
+
+ labels = []
+ num_invalid = 0
+ for vid, caption in tqdm.tqdm(zip(vids, captions), total=len(vids)):
+ vid_name = f"{vid}.mp4"
+ if vid_name in valid_videos:
+ example = {"video": vid_name, "caption": caption, "duration": 0}
+ labels.append(example)
+ else:
+ num_invalid += 1
+
+ # pool = Pool(128)
+ # labels = []
+ # for example in tqdm.tqdm(pool.imap_unordered(gen_one_example, zip(vids,captions)), total=len(vids)):
+ # labels.append(example)
+ print(f"number of valid videos: {len(labels)}. invalid: {num_invalid}")
+ with open(dst_file, "w") as f:
+ json.dump(labels, f)
+
+
+def webvid10m(subset):
+ print(f"generate labels for subset: {subset}")
+ assert subset in ["train", "val"]
+ src_file = f"/data/shared/datasets/webvid-10M/raw_data/results_10M_{subset}.csv"
+ dst_file = os.path.join(
+ os.environ["SL_DATA_DIR"], "anno_pretrain", f"webvid_10m_{subset}.json"
+ )
+ gen_labels(src_file, dst_file)
+
+
+if __name__ == "__main__":
+ webvid10m("val")
+ webvid10m("train")
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/preprocess/utils.py b/third_party/InternVideo/InternVideo2/multi_modality/preprocess/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d9e72cf00f9a42c646ce767dfb26f5fe44cda74
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/preprocess/utils.py
@@ -0,0 +1,18 @@
+import json
+import subprocess
+
+
+def get_video_duration(filename):
+
+ result = subprocess.check_output(
+ f'ffprobe -v quiet -show_streams -select_streams v:0 -of json "{filename}"', shell=True
+ ).decode()
+ fields = json.loads(result)["streams"][0]
+
+ duration = float(fields["duration"])
+ return duration
+
+if __name__ == "__main__":
+ import os
+ fp = os.path.join(os.environ["SL_DATA_DIR"], "videos_images/webvid_10m_2fps_224/22920757.mp4")
+ print(get_video_duration(fp))
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/requirements.txt b/third_party/InternVideo/InternVideo2/multi_modality/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca05a0c7bc1549d9bd8c52c93c591bbd84fd7851
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/requirements.txt
@@ -0,0 +1,27 @@
+apex==0.9.10dev
+av==11.0.0
+decord==0.6.0
+deepspeed==0.10.1
+einops==0.7.0
+flash_attn==2.0.8
+fvcore==0.1.5.post20221221
+imageio==2.31.1
+librosa==0.10.1
+numpy==1.24.4
+opencv_python==4.8.0.76
+pandas==2.0.3
+petrel_oss_sdk==v2.2.1_2_g1505ef3_master
+Pillow==10.0.0
+psutil==5.9.5
+PyYAML==6.0.1
+scipy==1.13.0
+soundfile==0.12.1
+tensorflow==2.16.1
+termcolor==2.4.0
+timm==0.5.4
+torch==1.13.1+cu117
+torchaudio==0.13.1+cu117
+torchvision==0.14.1+cu117
+tqdm==4.66.1
+transformers==4.28.1
+wandb==0.16.1
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_anet.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_anet.py
new file mode 100644
index 0000000000000000000000000000000000000000..da54cbcd4e179d8edbc9716be548d810e6fe30af
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_anet.py
@@ -0,0 +1,141 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["anet_ret_val"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+is_paragraph_retrieval = True
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_charades_mc.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_charades_mc.py
new file mode 100644
index 0000000000000000000000000000000000000000..96ba5607a86757cefeefc55afec5ddef65848f32
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_charades_mc.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(mc_test=available_corpus["charades_mc_test"])
+test_types = ["mc_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 4
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_didemo.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_didemo.py
new file mode 100644
index 0000000000000000000000000000000000000000..633e52dc6bfaef74ecc63508c68bc68027ea5869
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_didemo.py
@@ -0,0 +1,142 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["didemo_ret_test"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+is_paragraph_retrieval = True
+trimmed30 = True
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_hmdb51.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_hmdb51.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a8072fc9dc403a308e76742f98289aaf57d4716
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_hmdb51.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["hmdb51_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_k400.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_k400.py
new file mode 100644
index 0000000000000000000000000000000000000000..a59f122f98bb2342f55b07b9f1d26dc53297912d
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_k400.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["k400_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_k600.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_k600.py
new file mode 100644
index 0000000000000000000000000000000000000000..63eece3d817dc616c963f6a9bd06986fcecf8dcc
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_k600.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["k600_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_k700.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_k700.py
new file mode 100644
index 0000000000000000000000000000000000000000..093a6c86f95b2be041eee8091008c1912ed01b2a
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_k700.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["k700_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_lsmdc.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_lsmdc.py
new file mode 100644
index 0000000000000000000000000000000000000000..c936aa53688688d4050195867b5d63057f92c93f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_lsmdc.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["lsmdc_ret_test_1000"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_mit.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_mit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b947af84829b71862b3436a153560e7df06499a
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_mit.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["mit_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_msrvtt.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_msrvtt.py
new file mode 100644
index 0000000000000000000000000000000000000000..849692b20a74f2ef804186455098ab91037cd021
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_msrvtt.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["msrvtt_1k_test"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_ssv2_mc.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_ssv2_mc.py
new file mode 100644
index 0000000000000000000000000000000000000000..484651d872941a9ccab3609ba3b295d9688e3cd7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_ssv2_mc.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(mc_test=available_corpus["ssv2_mc_val"])
+test_types = ["mc_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 4
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_ucf101.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_ucf101.py
new file mode 100644
index 0000000000000000000000000000000000000000..50efd7d003f7208f741f709c20d1d96cc1ede41c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_ucf101.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["ucf101_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_vatex_ch.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_vatex_ch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3221648056b673f7e52aee39a3c8b9d3b5e1747
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_vatex_ch.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["vatex_ch_ret_val"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_vatex_en.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_vatex_en.py
new file mode 100644
index 0000000000000000000000000000000000000000..95a574afe5c76474069ca3d935ae6edec1195856
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/config_vatex_en.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["vatex_en_ret_val"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_anet.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_anet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3f22b7b186925593b07ad7740a24a996752f3c61
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_anet.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_anet'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_anet.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_charades_mc.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_charades_mc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..43e036be57f1ffb80aa37dadfb05c719b1f92452
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_charades_mc.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_charades_mc'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval_mc2.py \
+ $(dirname $0)/config_charades_mc.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_hmdb51.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_hmdb51.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ea11bb13f72f8a448b0b55c4d49268451d896f9d
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_hmdb51.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_hmdb51'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_hmdb51.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_k400.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_k400.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ab35c257017be6efbc78882432a12f300e6a45bf
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_k400.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_k400'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_k400.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_k600.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_k600.sh
new file mode 100644
index 0000000000000000000000000000000000000000..69951939dd64c14e2a73b5a5bd55abd6d1987510
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_k600.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_k600'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_k600.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_k700.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_k700.sh
new file mode 100644
index 0000000000000000000000000000000000000000..abf8eca769e5731dd59f49d042d347772291d10f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_k700.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_k700'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_k700.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_lsmdc.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_lsmdc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..03e10d94729ddb3f9f5808798e274401a10dbb6d
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_lsmdc.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_lsmdc'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_lsmdc.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_mit.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_mit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1416e8f193b77aa095c6d21a18327ab6e031287b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_mit.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_mit'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_mit.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_msrvtt.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_msrvtt.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8d24932b94c3d6b5393ccaaee574165445e30c04
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_msrvtt.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_msrvtt'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_msrvtt.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_ssv2_mc.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_ssv2_mc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ce0d8981849b27c5b1c7a7edffdb6bebe076ab87
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_ssv2_mc.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_ssv2_mc'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval_mc.py \
+ $(dirname $0)/config_ssv2_mc.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_ucf101.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_ucf101.sh
new file mode 100644
index 0000000000000000000000000000000000000000..727cbd97e164e22ff3a75e2e3ff61b0cdeb4bb2e
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_ucf101.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_ucf101.'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_ucf101.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_vatex_ch.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_vatex_ch.sh
new file mode 100644
index 0000000000000000000000000000000000000000..93ad5d72350b0019da7808d1661372fbac706264
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_vatex_ch.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_vatex_ch'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_vatex_ch.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_vatex_en.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_vatex_en.sh
new file mode 100644
index 0000000000000000000000000000000000000000..74a38b41c57698cfcf4f43aecc485e124ee00350
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/1B/eval_vatex_en.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_vatex_en'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_vatex_en.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_1B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_anet.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_anet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee11091e2b72d0d0c498f578680f43bce4e5dab7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_anet.py
@@ -0,0 +1,141 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["anet_ret_val"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+is_paragraph_retrieval = True
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_charades_mc.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_charades_mc.py
new file mode 100644
index 0000000000000000000000000000000000000000..17bfeebfbdb69f7466ca9ba2660dc2e6f0fe788a
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_charades_mc.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(mc_test=available_corpus["charades_mc_test"])
+test_types = ["mc_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 4
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_didemo.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_didemo.py
new file mode 100644
index 0000000000000000000000000000000000000000..156d657a6748f366f6c2a3d85ed08bf006c7eaa5
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_didemo.py
@@ -0,0 +1,142 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["didemo_ret_test"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+is_paragraph_retrieval = True
+trimmed30 = True
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_hmdb51.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_hmdb51.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0d6bc84b2e6c3891a08910ed8a67b9f58ac6939
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_hmdb51.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["hmdb51_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_k400.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_k400.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c272936d0e61cc3aa55e23e6c25d7a90a596002
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_k400.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["k400_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_k600.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_k600.py
new file mode 100644
index 0000000000000000000000000000000000000000..978eb9596c911ec103af557f26b1cddc1667463f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_k600.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["k600_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_k700.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_k700.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fdfcb8fabc2337f190cbf485632e8b0c26b568c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_k700.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["k700_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_lsmdc.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_lsmdc.py
new file mode 100644
index 0000000000000000000000000000000000000000..61a158da8f32d10129dda8a78b240b3400958fbe
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_lsmdc.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["lsmdc_ret_test_1000"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_mit.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_mit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ac9500ec09c60417e63311093e9e609e6f79551
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_mit.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["mit_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_msrvtt.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_msrvtt.py
new file mode 100644
index 0000000000000000000000000000000000000000..232e422375edea6db2c58c2debc85a547c0aec28
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_msrvtt.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["msrvtt_1k_test"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_ssv2_mc.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_ssv2_mc.py
new file mode 100644
index 0000000000000000000000000000000000000000..32e3f81f22d0d5119fec8b2a68fae85346fa89a0
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_ssv2_mc.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(mc_test=available_corpus["ssv2_mc_val"])
+test_types = ["mc_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 4
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_ucf101.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_ucf101.py
new file mode 100644
index 0000000000000000000000000000000000000000..aec23758b9bf50375b92e733ea8827a4590b39f1
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_ucf101.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["ucf101_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_vatex_ch.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_vatex_ch.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7f5ba6eaa50ae1590d07a5cea56c7445944d3eb
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_vatex_ch.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["vatex_ch_ret_val"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_vatex_en.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_vatex_en.py
new file mode 100644
index 0000000000000000000000000000000000000000..db49d17956e712ac8c49c619b483c2475955c48a
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/config_vatex_en.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "webvid_debug"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(ret_test=available_corpus["vatex_en_ret_val"])
+test_types = ["ret_test"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=False,
+ stage=0,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_anet.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_anet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d1e5f76d929a26d6f8bde484692dc808663a16ac
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_anet.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_anet'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_anet.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_charades_mc.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_charades_mc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e79a1b078819c9521ea89c339a16700166c7d0e4
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_charades_mc.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_charades_mc'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval_mc2.py \
+ $(dirname $0)/config_charades_mc.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_hmdb51.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_hmdb51.sh
new file mode 100644
index 0000000000000000000000000000000000000000..95f44a3c4d1a97433d7b26b07732bf65ac6589f3
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_hmdb51.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_hmdb51'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_hmdb51.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_k400.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_k400.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0640cf0fae16758762297c62936b2a0578871b6a
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_k400.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_k400'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_k400.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_k600.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_k600.sh
new file mode 100644
index 0000000000000000000000000000000000000000..139640b4b5a7d73a8c9d2fd058196d3541e5b2e4
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_k600.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_k600'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_k600.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_k700.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_k700.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b52874271efee0338d7cbfe10e8e35705326576a
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_k700.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_k700'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_k700.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_lsmdc.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_lsmdc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7c938a12dc8179a23ac38b4ec6d9b4415b253eaf
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_lsmdc.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_lsmdc'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_lsmdc.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_mit.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_mit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..431e191b094415accf7ed5377fe7da297ed8c41b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_mit.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_mit'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_mit.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_msrvtt.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_msrvtt.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3a7099716d528dc26ea393f674d5c74e481ac56e
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_msrvtt.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_msrvtt'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_msrvtt.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_ssv2_mc.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_ssv2_mc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..39af7594fc046af52f2b00809382cd326d080742
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_ssv2_mc.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_ssv2_mc'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval_mc.py \
+ $(dirname $0)/config_ssv2_mc.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_ucf101.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_ucf101.sh
new file mode 100644
index 0000000000000000000000000000000000000000..78babb33898ffeb34720b6cc30ee4f93597a977a
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_ucf101.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_ucf101.'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_ucf101.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_vatex_ch.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_vatex_ch.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1b7fcc021a2c7b9144bca37c48d01fa35138a5a6
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_vatex_ch.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_vatex_ch'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_vatex_ch.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_vatex_en.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_vatex_en.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d6b3913af5f2737a0895ae9ce6b47fd4b77a8362
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/clip/zero_shot/6B/eval_vatex_en.sh
@@ -0,0 +1,30 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='zs_vatex_en'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/retrieval.py \
+ $(dirname $0)/config_vatex_en.py \
+ pretrained_path your_model_path/InternVideo2_CLIP_6B.pth \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_anet.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_anet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b632b8cafcb751b34ba79c5cf8d11715b8978df
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_anet.py
@@ -0,0 +1,149 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+
+train_file = available_corpus["anet_ret_val"]
+test_file = dict(anet_ret_val=available_corpus["anet_ret_val"])
+
+test_types = ["anet_ret_val"]
+
+num_workers = 6
+
+best_key = ["anet_ret_val_match", "t2v_r1"]
+
+# ========================= input ==========================
+origin_num_frames = 4
+num_frames = 8
+num_frames_test = 8
+batch_size = 8 # 8 * 32
+batch_size_test = 4
+max_txt_l = 40
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+text_enc = "bert_large"
+model = dict(
+ model_cls="InternVideo2_Stage2",
+ vision_encoder=dict(
+ # backbone
+ name="pretrain_internvideo2_1b_patch14_224",
+ img_size=224,
+ num_frames="${num_frames}",
+ tubelet_size=1,
+ patch_size=14,
+ d_model=1408,
+ clip_embed_dim=768,
+ clip_teacher_embed_dim=3200,
+ clip_teacher_final_dim=768,
+ clip_norm_type='l2',
+ clip_return_layer=6,
+ clip_student_return_interval=1,
+ pretrained='your_model_path/1B_pt.pth',
+ use_checkpoint=False,
+ checkpoint_num=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ # clip teacher
+ clip_teacher=None,
+ clip_input_resolution=224,
+ clip_teacher_return_interval=1,
+ # mask
+ video_mask_type="random",
+ video_mask_ratio=0.8,
+ image_mask_type="random",
+ image_mask_ratio=0.5,
+ only_mask=True,
+ sep_image_video_pos_embed=True
+ ),
+ text_encoder="${TextEncoders[${text_enc}]}",
+ multimodal=dict(enable=True),
+ embed_dim=512,
+ temp=0.07,
+ find_unused_parameters=False
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ mlm=1.0,
+ vtm=1.0,
+ mvm=0.0,
+ uta=0.0,
+ ), # 0: disabled.
+ vtm_hard_neg=True,
+ mlm_masking_prob=0.5,
+ distill_final_features=True,
+ clip_loss_ratio=[1., 1.]
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=1e-5,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.05,
+ max_grad_norm=3., # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-5),
+)
+
+scheduler = dict(sched="cosine", epochs=20, min_lr_multi=0.01, warmup_epochs=1)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+
+gradient_checkpointing = True # for text encoder
+use_flash_sdp = False
+use_mem_efficient_sdp = False and not use_flash_sdp
+compile_model = False
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="opengvlab", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2-ft", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "ret"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 100
+seed = 42
+zero_shot = True
+
+save_latest = False
+auto_resume = True
+jump_evaluate = False
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_didemo.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_didemo.py
new file mode 100644
index 0000000000000000000000000000000000000000..b90ee7c9c298449fabda412238c556a26aa98c6c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_didemo.py
@@ -0,0 +1,155 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+
+train_file = available_corpus["didemo_ret_test"]
+test_file = dict(didemo_ret_test=available_corpus["didemo_ret_test"])
+
+test_types = ["didemo_ret_test"]
+
+num_workers = 6
+
+best_key = ["didemo_ret_test_match", "t2v_r1"]
+
+# ========================= input ==========================
+origin_num_frames = 4
+num_frames = 8
+num_frames_test = 8
+batch_size = 8 # 8 * 32
+batch_size_test = 4
+max_txt_l = 40
+
+inputs = dict(
+ image_res=224,
+ audio_input=dict(
+ audio_sample_rate=16000,
+ has_multi_audio_gt=False,
+ audio_reader_type='torchaudio',
+ max_audio_length=10
+ ),
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", audio="${max_txt_l}", video="${max_txt_l}", audio_video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", audio="${batch_size}", video="${batch_size}", audio_video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", audio="${batch_size_test}", video="${batch_size_test}", audio_video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+text_enc = "bert_large"
+model = dict(
+ model_cls="InternVideo2_Stage2",
+ vision_encoder=dict(
+ # backbone
+ name="pretrain_internvideo2_1b_patch14_224",
+ img_size=224,
+ num_frames="${num_frames}",
+ tubelet_size=1,
+ patch_size=14,
+ d_model=1408,
+ clip_embed_dim=768,
+ clip_teacher_embed_dim=3200,
+ clip_teacher_final_dim=768,
+ clip_norm_type='l2',
+ clip_return_layer=6,
+ clip_student_return_interval=1,
+ pretrained='your_model_path/1B_pt.pth',
+ use_checkpoint=False,
+ checkpoint_num=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ # clip teacher
+ clip_teacher=None,
+ clip_input_resolution=224,
+ clip_teacher_return_interval=1,
+ # mask
+ video_mask_type="random",
+ video_mask_ratio=0.8,
+ image_mask_type="random",
+ image_mask_ratio=0.5,
+ only_mask=True,
+ sep_image_video_pos_embed=True
+ ),
+ text_encoder="${TextEncoders[${text_enc}]}",
+ multimodal=dict(enable=True),
+ embed_dim=512,
+ temp=0.07,
+ find_unused_parameters=False
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ mlm=1.0,
+ vtm=1.0,
+ mvm=0.0,
+ uta=0.0,
+ ), # 0: disabled.
+ vtm_hard_neg=True,
+ mlm_masking_prob=0.5,
+ distill_final_features=True,
+ clip_loss_ratio=[1., 1.]
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=1e-5,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.05,
+ max_grad_norm=3., # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-5),
+)
+
+scheduler = dict(sched="cosine", epochs=20, min_lr_multi=0.01, warmup_epochs=1)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+
+gradient_checkpointing = True # for text encoder
+use_flash_sdp = False
+use_mem_efficient_sdp = False and not use_flash_sdp
+compile_model = False
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="opengvlab", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2-ft", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "ret"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 100
+seed = 42
+zero_shot = True
+
+save_latest = False
+auto_resume = True
+jump_evaluate = False
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_lsmdc.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_lsmdc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e5e4a8dbc3d646bbbc92d4b301ed37e854c22c3
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_lsmdc.py
@@ -0,0 +1,149 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+
+train_file = available_corpus["lsmdc_ret_test_1000"]
+test_file = dict(lsmdc_ret_test_1000=available_corpus["lsmdc_ret_test_1000"])
+
+test_types = ["lsmdc_ret_test_1000"]
+
+num_workers = 6
+
+best_key = ["lsmdc_ret_test_1000_match", "t2v_r1"]
+
+# ========================= input ==========================
+origin_num_frames = 4
+num_frames = 8
+num_frames_test = 8
+batch_size = 8 # 8 * 32
+batch_size_test = 4
+max_txt_l = 40
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+text_enc = "bert_large"
+model = dict(
+ model_cls="InternVideo2_Stage2",
+ vision_encoder=dict(
+ # backbone
+ name="pretrain_internvideo2_1b_patch14_224",
+ img_size=224,
+ num_frames="${num_frames}",
+ tubelet_size=1,
+ patch_size=14,
+ d_model=1408,
+ clip_embed_dim=768,
+ clip_teacher_embed_dim=3200,
+ clip_teacher_final_dim=768,
+ clip_norm_type='l2',
+ clip_return_layer=6,
+ clip_student_return_interval=1,
+ pretrained='your_model_path/1B_pt.pth',
+ use_checkpoint=False,
+ checkpoint_num=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ # clip teacher
+ clip_teacher=None,
+ clip_input_resolution=224,
+ clip_teacher_return_interval=1,
+ # mask
+ video_mask_type="random",
+ video_mask_ratio=0.8,
+ image_mask_type="random",
+ image_mask_ratio=0.5,
+ only_mask=True,
+ sep_image_video_pos_embed=True
+ ),
+ text_encoder="${TextEncoders[${text_enc}]}",
+ multimodal=dict(enable=True),
+ embed_dim=512,
+ temp=0.07,
+ find_unused_parameters=False
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ mlm=1.0,
+ vtm=1.0,
+ mvm=0.0,
+ uta=0.0,
+ ), # 0: disabled.
+ vtm_hard_neg=True,
+ mlm_masking_prob=0.5,
+ distill_final_features=True,
+ clip_loss_ratio=[1., 1.]
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=1e-5,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.05,
+ max_grad_norm=3., # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-5),
+)
+
+scheduler = dict(sched="cosine", epochs=20, min_lr_multi=0.01, warmup_epochs=1)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+
+gradient_checkpointing = True # for text encoder
+use_flash_sdp = False
+use_mem_efficient_sdp = False and not use_flash_sdp
+compile_model = False
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="opengvlab", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2-ft", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "ret"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 100
+seed = 42
+zero_shot = True
+
+save_latest = False
+auto_resume = True
+jump_evaluate = False
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_msrvtt.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_msrvtt.py
new file mode 100644
index 0000000000000000000000000000000000000000..27002f017223e95aa714fafb2bd39d479a485c09
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_msrvtt.py
@@ -0,0 +1,148 @@
+from configs.data import *
+from configs.model import *
+# ========================= data ==========================
+# NOTE The train_file will not be used during the evaluation
+train_file = available_corpus["msrvtt_1k_test"]
+test_file = dict(msrvtt_1k_test=available_corpus["msrvtt_1k_test"])
+
+test_types = ["msrvtt_1k_test"]
+
+
+num_workers = 6
+
+best_key = ["msrvtt_1k_test_match", "t2v_r1"]
+
+# ========================= input ==========================
+num_frames = 4
+num_frames_test = 4
+batch_size = 8
+batch_size_test = 4
+max_txt_l = 40
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+text_enc = "bert_large"
+model = dict(
+ model_cls="InternVideo2_Stage2",
+ vision_encoder=dict(
+ # backbone
+ name="pretrain_internvideo2_1b_patch14_224",
+ img_size=224,
+ num_frames="${num_frames}",
+ tubelet_size=1,
+ patch_size=14,
+ d_model=1408,
+ clip_embed_dim=768,
+ clip_teacher_embed_dim=3200,
+ clip_teacher_final_dim=768,
+ clip_norm_type='l2',
+ clip_return_layer=6,
+ clip_student_return_interval=1,
+ pretrained='your_model_path/1B_pt.pth',
+ use_checkpoint=True,
+ checkpoint_num=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ # clip teacher
+ clip_teacher=None,
+ clip_input_resolution=224,
+ clip_teacher_return_interval=1,
+ # mask
+ video_mask_type="random",
+ video_mask_ratio=0.8,
+ image_mask_type="random",
+ image_mask_ratio=0.5,
+ sep_image_video_pos_embed=True,
+ keep_temporal=False,
+ only_mask=True
+ ),
+ text_encoder="${TextEncoders[${text_enc}]}",
+ multimodal=dict(enable=True),
+ embed_dim=512,
+ temp=0.07,
+ find_unused_parameters=False
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ mlm=1.0,
+ vtm=1.0,
+ mvm=0.0,
+ uta=0.0,
+ ), # 0: disabled.
+ vtm_hard_neg=True,
+ mlm_masking_prob=0.5,
+ distill_final_features=True,
+ clip_loss_ratio=[1., 1.]
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=1e-5,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.05,
+ max_grad_norm=3., # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=1, min_lr_multi=0.01, warmup_epochs=0.2)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+
+gradient_checkpointing = True # for text encoder
+use_flash_sdp = False
+use_mem_efficient_sdp = False and not use_flash_sdp
+compile_model = False
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="opengvlab", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2-Stage2", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 100
+seed = 42
+
+save_latest = False
+auto_resume = True
+jump_evaluate = False
+pretrained_path = ""
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_msvd.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_msvd.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab9aff0557793d0c75f8565f2ab972684cac7e3
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_msvd.py
@@ -0,0 +1,148 @@
+from configs.data import *
+from configs.model import *
+# ========================= data ==========================
+# NOTE The train_file will not be used during the evaluation
+train_file = available_corpus["msrvtt_1k_test"]
+test_file = dict(msvd_ret_test=available_corpus["msvd_ret_test"])
+
+test_types = ["msvd_ret_test"]
+
+
+num_workers = 6
+
+best_key = ["msvd_ret_test_match", "t2v_r1"]
+
+# ========================= input ==========================
+num_frames = 4
+num_frames_test = 4
+batch_size = 8
+batch_size_test = 4
+max_txt_l = 40
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+text_enc = "bert_large"
+model = dict(
+ model_cls="InternVideo2_Stage2",
+ vision_encoder=dict(
+ # backbone
+ name="pretrain_internvideo2_1b_patch14_224",
+ img_size=224,
+ num_frames="${num_frames}",
+ tubelet_size=1,
+ patch_size=14,
+ d_model=1408,
+ clip_embed_dim=768,
+ clip_teacher_embed_dim=3200,
+ clip_teacher_final_dim=768,
+ clip_norm_type='l2',
+ clip_return_layer=6,
+ clip_student_return_interval=1,
+ pretrained='your_model_path/1B_pt.pth',
+ use_checkpoint=True,
+ checkpoint_num=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ # clip teacher
+ clip_teacher=None,
+ clip_input_resolution=224,
+ clip_teacher_return_interval=1,
+ # mask
+ video_mask_type="random",
+ video_mask_ratio=0.8,
+ image_mask_type="random",
+ image_mask_ratio=0.5,
+ sep_image_video_pos_embed=True,
+ keep_temporal=False,
+ only_mask=True
+ ),
+ text_encoder="${TextEncoders[${text_enc}]}",
+ multimodal=dict(enable=True),
+ embed_dim=512,
+ temp=0.07,
+ find_unused_parameters=False
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ mlm=1.0,
+ vtm=1.0,
+ mvm=0.0,
+ uta=0.0,
+ ), # 0: disabled.
+ vtm_hard_neg=True,
+ mlm_masking_prob=0.5,
+ distill_final_features=True,
+ clip_loss_ratio=[1., 1.]
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=1e-5,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.05,
+ max_grad_norm=3., # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=1, min_lr_multi=0.01, warmup_epochs=0.2)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+
+gradient_checkpointing = True # for text encoder
+use_flash_sdp = False
+use_mem_efficient_sdp = False and not use_flash_sdp
+compile_model = False
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="opengvlab", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2-Stage2", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 100
+seed = 42
+
+save_latest = False
+auto_resume = True
+jump_evaluate = False
+pretrained_path = ""
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_vatex.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_vatex.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1de9a75604012cdaefce5f6368a3bb19112d097
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/config_vatex.py
@@ -0,0 +1,155 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+# NOTE The train_file will not be used during the evaluation
+train_file = available_corpus["vatex_en_ret_val"]
+test_file = dict(vatex_en_ret_val=available_corpus["vatex_en_ret_val"])
+
+test_types = ["vatex_en_ret_val"]
+
+num_workers = 6
+
+best_key = ["vatex_en_ret_val", "t2v_r1"]
+
+# ========================= input ==========================
+origin_num_frames = 4
+num_frames = 4
+num_frames_test = 4
+batch_size = 8 # 8 * 32
+batch_size_test = 4
+max_txt_l = 40
+
+inputs = dict(
+ image_res=224,
+ audio_input=dict(
+ audio_sample_rate=16000,
+ has_multi_audio_gt=False,
+ audio_reader_type='torchaudio',
+ max_audio_length=10
+ ),
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", audio="${max_txt_l}", video="${max_txt_l}", audio_video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", audio="${batch_size}", video="${batch_size}", audio_video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", audio="${batch_size_test}", video="${batch_size_test}", audio_video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+text_enc = "bert_large"
+model = dict(
+ model_cls="InternVideo2_Stage2",
+ vision_encoder=dict(
+ # backbone
+ name="pretrain_internvideo2_1b_patch14_224",
+ img_size=224,
+ num_frames="${num_frames}",
+ tubelet_size=1,
+ patch_size=14,
+ d_model=1408,
+ clip_embed_dim=768,
+ clip_teacher_embed_dim=3200,
+ clip_teacher_final_dim=768,
+ clip_norm_type='l2',
+ clip_return_layer=6,
+ clip_student_return_interval=1,
+ pretrained='your_model_path/1B_pt.pth',
+ use_checkpoint=False,
+ checkpoint_num=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ # clip teacher
+ clip_teacher=None,
+ clip_input_resolution=224,
+ clip_teacher_return_interval=1,
+ # mask
+ video_mask_type="random",
+ video_mask_ratio=0.8,
+ image_mask_type="random",
+ image_mask_ratio=0.5,
+ only_mask=True,
+ sep_image_video_pos_embed=True
+ ),
+ text_encoder="${TextEncoders[${text_enc}]}",
+ multimodal=dict(enable=True),
+ embed_dim=512,
+ temp=0.07,
+ find_unused_parameters=False
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ mlm=1.0,
+ vtm=1.0,
+ mvm=0.0,
+ uta=0.0,
+ ), # 0: disabled.
+ vtm_hard_neg=True,
+ mlm_masking_prob=0.5,
+ distill_final_features=True,
+ clip_loss_ratio=[1., 1.]
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=1e-5,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.05,
+ max_grad_norm=3., # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-5),
+)
+
+scheduler = dict(sched="cosine", epochs=20, min_lr_multi=0.01, warmup_epochs=1)
+
+evaluate = True
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+
+gradient_checkpointing = True # for text encoder
+use_flash_sdp = False
+use_mem_efficient_sdp = False and not use_flash_sdp
+compile_model = False
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="opengvlab", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2-ft", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "ret"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 100
+seed = 42
+zero_shot = True
+
+save_latest = False
+auto_resume = True
+jump_evaluate = False
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_anet.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_anet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..53435e6b0ad2a98c326ef36b81403f4f548a0d01
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_anet.sh
@@ -0,0 +1,35 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME=$(basename $0)_$(date +"%Y%m%d_%H%M%S")
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="$(dirname $0)/logs/${JOB_NAME}"
+PARTITION='video5'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --quotatype=auto \
+ --cpus-per-task=${NUM_CPU} \
+ --quotatype=auto --kill-on-bad-exit=1 \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks/pretrain.py \
+ $(dirname $0)/config_anet.py \
+ output_dir ${OUTPUT_DIR} \
+ evaluate True \
+ pretrained_path 'your_model_path/1B_stage2_pt.pth'
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_didemo.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_didemo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..14a4095c692463b8d6fbdc82a3ca8a0e502721db
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_didemo.sh
@@ -0,0 +1,35 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME=$(basename $0)_$(date +"%Y%m%d_%H%M%S")
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="$(dirname $0)/logs/${JOB_NAME}"
+PARTITION='video5'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --quotatype=auto \
+ --cpus-per-task=${NUM_CPU} \
+ --quotatype=auto --kill-on-bad-exit=1 \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks/pretrain.py \
+ $(dirname $0)/config_didemo.py \
+ output_dir ${OUTPUT_DIR} \
+ evaluate True \
+ pretrained_path 'your_model_path/1B_stage2_pt.pth'
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_lsmdc.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_lsmdc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0b1c59ccca7e8e4fe2a6654f90a79578ebe83208
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_lsmdc.sh
@@ -0,0 +1,35 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME=$(basename $0)_$(date +"%Y%m%d_%H%M%S")
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="$(dirname $0)/logs/${JOB_NAME}"
+PARTITION='video5'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --quotatype=auto \
+ --cpus-per-task=${NUM_CPU} \
+ --quotatype=auto --kill-on-bad-exit=1 \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks/pretrain.py \
+ $(dirname $0)/config_lsmdc.py \
+ output_dir ${OUTPUT_DIR} \
+ evaluate True \
+ pretrained_path 'your_model_path/1B_stage2_pt.pth'
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_msrvtt.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_msrvtt.sh
new file mode 100644
index 0000000000000000000000000000000000000000..086b253232512f27d9bc63f301bb3c941ba1c19f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_msrvtt.sh
@@ -0,0 +1,35 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME=$(basename $0)_$(date +"%Y%m%d_%H%M%S")
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="$(dirname $0)/logs/${JOB_NAME}"
+PARTITION='video5'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --quotatype=auto \
+ --cpus-per-task=${NUM_CPU} \
+ --quotatype=auto --kill-on-bad-exit=1 \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks/pretrain.py \
+ $(dirname $0)/config_msrvtt.py \
+ output_dir ${OUTPUT_DIR} \
+ evaluate True \
+ pretrained_path 'your_model_path/1B_stage2_pt.pth'
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_msvd.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_msvd.sh
new file mode 100644
index 0000000000000000000000000000000000000000..450260b2f13978aba8ca898bd8e373d771a00032
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_msvd.sh
@@ -0,0 +1,35 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME=$(basename $0)_$(date +"%Y%m%d_%H%M%S")
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="$(dirname $0)/logs/${JOB_NAME}"
+PARTITION='video5'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --quotatype=auto \
+ --cpus-per-task=${NUM_CPU} \
+ --quotatype=auto --kill-on-bad-exit=1 \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks/pretrain.py \
+ $(dirname $0)/config_msvd.py \
+ output_dir ${OUTPUT_DIR} \
+ evaluate True \
+ pretrained_path 'your_model_path/1B_stage2_pt.pth'
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_vatex.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_vatex.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8506369d51ffa467852c7f5446c3332f212a322b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/evaluation/stage2/zero_shot/1B/eval_vatex.sh
@@ -0,0 +1,35 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME=$(basename $0)_$(date +"%Y%m%d_%H%M%S")
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="$(dirname $0)/logs/${JOB_NAME}"
+PARTITION='video5'
+NNODE=1
+NUM_GPUS=1
+NUM_CPU=16
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --quotatype=auto \
+ --cpus-per-task=${NUM_CPU} \
+ --quotatype=auto --kill-on-bad-exit=1 \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks/pretrain.py \
+ $(dirname $0)/config_vatex.py \
+ output_dir ${OUTPUT_DIR} \
+ evaluate True \
+ pretrained_path 'your_model_path/1B_stage2_pt.pth'
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/1B/config.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/1B/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb9c7b358fa4800d9cb3d0e8b31d5913793c7158
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/1B/config.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "data_25m"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["k400_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 256
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_1B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.3,
+ head_drop_path_rate=0.,
+ embed_dim=1408,
+ num_heads=16,
+ mlp_ratio=48/11,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_1B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.01, warmup_epochs=0.6)
+
+evaluate = False
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 500
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/1B/run.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/1B/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..daef72d686b1e8c7b8fb1866cd940c05b3bf5c3f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/1B/run.sh
@@ -0,0 +1,29 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='1B'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=2
+NUM_GPUS=8
+NUM_CPU=128
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_clip/pretrain.py \
+ $(dirname $0)/config.py \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/6B/config.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/6B/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ccb72b6886e28313a70a1a03d4a74cbf67e4658
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/6B/config.py
@@ -0,0 +1,140 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_corpus = "data_25m"
+train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation
+test_file = dict(act_val=available_corpus["k400_act_val"])
+test_types = ["act_val"]
+num_workers = 12
+
+stop_key = None
+
+# ========================= input ==========================
+num_frames = 8
+num_frames_test = 8
+batch_size = 128
+batch_size_test = 64
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+model = dict(
+ model_cls="InternVideo2_CLIP",
+ vision_encoder=dict(
+ name="internvideo2_6B",
+ in_chans=3,
+ patch_size=14,
+ img_size=224,
+ qkv_bias=False,
+ drop_path_rate=0.35,
+ head_drop_path_rate=0.,
+ embed_dim=3200,
+ num_heads=25,
+ mlp_ratio=4,
+ init_values=0.1,
+ qk_normalization=True,
+ depth=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ fused_mlp_heuristic=1,
+ drop_cls_token=False,
+ attn_pool_num_heads=16,
+ clip_embed_dim=768,
+ layerscale_no_force_fp32=True,
+ num_frames=8,
+ tubelet_size=1,
+ sep_pos_embed=False,
+ use_checkpoint=False,
+ checkpoint_num=0,
+ ),
+ text_encoder=dict(
+ use_flash_attn=True,
+ transformer_width=4096,
+ llama_path="your_model_path/chinese_alpaca_lora_7b",
+ use_lora=True,
+ ),
+ temp=1 / 100.0,
+ temp_min=1 / 100.0,
+ freeze_vision=True,
+ open_vision_clip_projector=True,
+ freeze_text=True,
+ open_text_projection=False,
+ open_text_lora=False,
+ tokenizer_path="your_model_path/chinese_alpaca_lora_7b",
+ vision_ckpt_path="your_model_path/InternVideo2_Stage2_6B.pth",
+ load_vision_ckpt_from_internvideo2_stage2=True,
+ text_ckpt_path="your_model_path/internvl/internvl_c_13b_224px.pth",
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ ), # 0: disabled.
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=4e-4,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.2,
+ max_grad_norm=-1, # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=1, min_lr_multi=0.01, warmup_epochs=0.2)
+
+evaluate = False
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+gradient_checkpointing = True
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="likunchang", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2_CLIP", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 1
+seed = 42
+
+save_latest = False
+save_iter = 1000
+auto_resume = True
+pretrained_path = "" # path to pretrained model weights, for resume only?
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/6B/run.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/6B/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a6e0ab101a5c8bd4c978d55e5206523bed11198f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/clip/6B/run.sh
@@ -0,0 +1,29 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME='6B'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=2
+NUM_GPUS=8
+NUM_CPU=32
+
+srun -p ${PARTITION} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --cpus-per-task=${NUM_CPU} \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks_ds/pretrain.py \
+ $(dirname $0)/config.py \
+ output_dir ${OUTPUT_DIR}
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/1B/config.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/1B/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bccfcd05ed54bb87f90eadda838aa61e6163b5c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/1B/config.py
@@ -0,0 +1,151 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_file = available_corpus["pretrain_example_data_1B"]
+
+test_file = dict(msrvtt_1k_test=available_corpus["msrvtt_1k_test"],
+ didemo_ret_test=available_corpus["didemo_ret_test"])
+
+test_types = ["msrvtt_1k_test", "didemo_ret_test"]
+num_workers = 6
+
+best_key = ["msrvtt_1k_test_match", "t2v_r1"]
+
+# ========================= input ==========================
+num_frames = 4
+num_frames_test = 4
+batch_size = 64 # 64 * 64
+batch_size_test = 4
+max_txt_l = 32
+
+inputs = dict(
+ image_res=224,
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+text_enc = "bert_large"
+model = dict(
+ model_cls="InternVideo2_Stage2",
+ vision_encoder=dict(
+ # backbone
+ name="pretrain_internvideo2_1b_patch14_224",
+ img_size=224,
+ num_frames="${num_frames}",
+ tubelet_size=1,
+ patch_size=14,
+ d_model=1408,
+ clip_embed_dim=768,
+ clip_teacher_embed_dim=3200,
+ clip_teacher_final_dim=768,
+ clip_norm_type='l2',
+ clip_return_layer=6,
+ clip_student_return_interval=1,
+ pretrained='your_model_path/1B_pt.pth',
+ use_checkpoint=False,
+ checkpoint_num=40,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ # clip teacher
+ clip_teacher=None,
+ clip_input_resolution=224,
+ clip_teacher_return_interval=1,
+ # mask
+ video_mask_type="random",
+ video_mask_ratio=0.8,
+ image_mask_type="random",
+ image_mask_ratio=0.5,
+ sep_image_video_pos_embed=True,
+ keep_temporal=False,
+ only_mask=True
+ ),
+ text_encoder="${TextEncoders[${text_enc}]}",
+ multimodal=dict(enable=True),
+ embed_dim=512,
+ temp=0.07,
+ find_unused_parameters=False
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ mlm=1.0,
+ vtm=1.0,
+ mvm=0.0,
+ uta=0.0,
+ ), # 0: disabled.
+ vtm_hard_neg=True,
+ mlm_masking_prob=0.5,
+ distill_final_features=True,
+ clip_loss_ratio=[1., 1.]
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=5e-5,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.05,
+ max_grad_norm=3., # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=10, min_lr_multi=0.01, warmup_epochs=1)
+
+evaluate = False
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+
+gradient_checkpointing = True # for text encoder
+use_flash_sdp = False
+use_mem_efficient_sdp = False and not use_flash_sdp
+compile_model = False
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="opengvlab", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2-Stage2", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 100
+seed = 42
+
+save_latest = True
+auto_resume = False
+jump_evaluate = False
+pretrained_path = "" # path to pretrained model weights, for resume only?
+save_ckpt_iter = None
+delete_ds_optim_states = True
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/1B/run.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/1B/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0eeed8372dee64d0f04f57644e0e897f69850d6c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/1B/run.sh
@@ -0,0 +1,34 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME=$(basename $0)_$(date +"%Y%m%d_%H%M%S")
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="$(dirname $0)/logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=8
+NUM_GPUS=8
+NUM_CPU=128
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --quotatype=auto \
+ --cpus-per-task=${NUM_CPU} \
+ --quotatype=auto --kill-on-bad-exit=1 \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks/pretrain.py \
+ $(dirname $0)/config.py \
+ output_dir ${OUTPUT_DIR}
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/6B/config.py b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/6B/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca4d9b6852018875e06a028fc0def7aa14418f89
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/6B/config.py
@@ -0,0 +1,189 @@
+from configs.data import *
+from configs.model import *
+
+# ========================= data ==========================
+train_file = available_corpus["pretrain_example_data_1B"]
+
+
+test_file = dict(msrvtt_1k_test=available_corpus["msrvtt_1k_test"],
+ didemo_ret_test=available_corpus["didemo_ret_test"])
+
+test_types = ["msrvtt_1k_test", "didemo_ret_test"]
+num_workers = 6
+
+best_key = ["msrvtt_1k_test_match", "t2v_r1"]
+
+# ========================= input ==========================
+num_frames = 4
+num_frames_test = 4
+batch_size = 64 # 64 * 64
+batch_size_test = 4
+max_txt_l = 40
+
+inputs = dict(
+ image_res=224,
+ audio_input=dict(
+ audio_sample_rate=16000,
+ has_multi_audio_gt=False,
+ audio_reader_type='torchaudio',
+ max_audio_length=10
+ ),
+ video_input=dict(
+ num_frames="${num_frames}",
+ sample_type="rand",
+ num_frames_test="${num_frames_test}",
+ sample_type_test="middle",
+ random_aug=False,
+ ),
+ max_txt_l=dict(image="${max_txt_l}", audio="${max_txt_l}", video="${max_txt_l}", audio_video="${max_txt_l}"),
+ batch_size=dict(image="${batch_size}", audio="${batch_size}", video="${batch_size}", audio_video="${batch_size}"),
+ batch_size_test=dict(image="${batch_size_test}", audio="${batch_size_test}", video="${batch_size_test}", audio_video="${batch_size_test}"),
+)
+
+# ========================= model ==========================
+text_enc = "bert_large"
+model = dict(
+ model_cls="InternVideo2_Stage2_audio",
+ audio_encoder=dict(
+ name='beats',
+ d_model=768,
+ audio_model_path="your_model_path/beats.pth",
+ ),
+ vision_encoder=dict(
+ # backbone
+ name="pretrain_internvideo2_6b_patch14_224",
+ img_size=224,
+ num_frames="${num_frames}",
+ tubelet_size=1,
+ patch_size=14,
+ d_model=3200,
+ clip_embed_dim=768,
+ clip_teacher_embed_dim=3200,
+ clip_teacher_final_dim=768,
+ clip_norm_type='l2',
+ clip_return_layer=6,
+ clip_student_return_interval=1,
+ pretrained='your_model_path/6B_pt.pth',
+ use_checkpoint=True,
+ checkpoint_num=48,
+ use_flash_attn=True,
+ use_fused_rmsnorm=True,
+ use_fused_mlp=True,
+ # clip teacher
+ clip_teacher=None,
+ clip_input_resolution=224,
+ clip_teacher_return_interval=1,
+ # mask
+ video_mask_type="random",
+ video_mask_ratio=0.8,
+ image_mask_type="random",
+ image_mask_ratio=0.5,
+ sep_image_video_pos_embed=False,
+ keep_temporal=False,
+ only_mask=True
+ ),
+ text_encoder="${TextEncoders[${text_enc}]}",
+ multimodal=dict(enable=True),
+ contra_dim=768,
+ av_concat_dim=768,
+ temp=0.07,
+ find_unused_parameters=False,
+ freeze_vision=False,
+ freeze_audio=True
+)
+
+criterion = dict(
+ loss_weight=dict(
+ vtc=1.0,
+ mlm=1.0,
+ vtm=1.0,
+ uta=0.0,
+ # audio-related
+ atc=0.0,
+ avc=0.0,
+ avtc=1.0,
+ atm=0.0,
+ avtm=1.0,
+ amlm=0.0,
+ avmlm=1.0
+ ), # 0: disabled.
+ # ['video_name', 'selected_audio_caption', 'selected_video_caption', 'asr_captions', 'av_captions', 'video_fps', 'video_start_frame', 'video_end_frame', 'video']
+ loss_caption=dict(
+ # vision-related
+ vtc='avs_captions',
+ vtm='avs_captions',
+ mlm='avs_captions',
+ # audio-related
+ # atc='selected_audio_caption',
+ # atm='selected_audio_caption',
+ # amlm='selected_audio_caption',
+ # audio-vision-related
+ avtc='avs_captions',
+ avtm='avs_captions',
+ avmlm='avs_captions',
+ ),
+ vtm_hard_neg=True,
+ mlm_masking_prob=0.5,
+ distill_final_features=True,
+ clip_loss_ratio=[1., 1.],
+ uta_image_only=True
+)
+
+optimizer = dict(
+ opt="adamW",
+ lr=5e-5,
+ opt_betas=[0.9, 0.98], # default
+ weight_decay=0.05,
+ max_grad_norm=3., # requires a positive float, use -1 to disable
+ # use a different lr for some modules, e.g., larger lr for new modules
+ different_lr=dict(enable=False, module_names=[], lr=1e-3),
+)
+
+scheduler = dict(sched="cosine", epochs=5, min_lr_multi=0.01, warmup_epochs=1)
+
+evaluate = False
+deep_fusion = False
+evaluation = dict(
+ eval_frame_ensemble="concat", # [concat, max, mean, lse]
+ eval_x_only=False,
+ k_test=128,
+ eval_offload=True, # offload gpu tensors to cpu to save memory.
+)
+
+use_half_precision = True
+use_bf16 = True
+
+gradient_checkpointing = True # for text encoder
+use_flash_sdp = False
+use_mem_efficient_sdp = False and not use_flash_sdp
+compile_model = False
+
+# ========================= wandb ==========================
+wandb = dict(
+ enable=False,
+ entity="opengvlab", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init
+ project="InternVideo2-Stage2", # setup in your command line
+)
+dist_url = "env://"
+device = "cuda"
+mode = "pt"
+
+# ========================= others ==========================
+output_dir = None # output dir
+resume = False # if True, load optimizer and scheduler states as well
+debug = False
+log_freq = 100
+seed = 42
+
+save_latest = True
+auto_resume = False
+jump_evaluate = False
+pretrained_path = "" # path to pretrained model weights, for resume only?
+save_ckpt_iter = None
+delete_ds_optim_states = True
+
+deepspeed = dict(
+ enable=True,
+ stage=1,
+)
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/6B/run.sh b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/6B/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..614fb01ef86e433a2ac21d6248eb1495dc9bfc19
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/scripts/pretraining/stage2/6B/run.sh
@@ -0,0 +1,34 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+JOB_NAME=$(basename $0)_$(date +"%Y%m%d_%H%M%S")
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="$(dirname $0)/logs/${JOB_NAME}"
+PARTITION='video'
+NNODE=32
+NUM_GPUS=8
+NUM_CPU=128
+
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ -n${NNODE} \
+ --gres=gpu:${NUM_GPUS} \
+ --ntasks-per-node=1 \
+ --quotatype=auto \
+ --cpus-per-task=${NUM_CPU} \
+ --quotatype=auto --kill-on-bad-exit=1 \
+ bash torchrun.sh \
+ --nnodes=${NNODE} \
+ --nproc_per_node=${NUM_GPUS} \
+ --rdzv_backend=c10d \
+ tasks/pretrain.py \
+ $(dirname $0)/config.py \
+ output_dir ${OUTPUT_DIR}
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tasks/pretrain.py b/third_party/InternVideo/InternVideo2/multi_modality/tasks/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcbce81c722c86bc2e080124b6c87329fd54e0a3
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tasks/pretrain.py
@@ -0,0 +1,567 @@
+import datetime
+import logging
+import time
+from os.path import join
+
+import pandas as pd
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import wandb
+
+from dataset.serialize import local_broadcast_process_authkey
+from dataset import MetaLoader, MetaLoader_rs2, create_dataset, create_loader, create_sampler, create_stateful_sampler
+from models import *
+from tasks.retrieval_utils import evaluation_wrapper
+from tasks.shared_utils import get_media_types, setup_model
+from utils.basic_utils import (MetricLogger, SmoothedValue,
+ remove_files_if_exist, setup_seed)
+from utils.config_utils import setup_main
+from utils.distributed import get_rank, get_world_size, is_main_process
+from utils.logger import log_dict_to_wandb, setup_wandb
+try:
+ from petrel_client.client import Client
+except:
+ Client = None
+import io
+import os
+import shutil
+
+logger = logging.getLogger(__name__)
+
+ceph_ckpt_bucket = "shdd:s3://avp_ckpt"
+
+
+def train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ skip_num=0
+):
+
+ try:
+ ceph_ckpt_path = f"{ceph_ckpt_bucket}/{config.output_dir.split('/')[-3]}/{config.output_dir.split('/')[-2]}/{config.output_dir.split('/')[-1]}"
+ client_ckpt = Client(conf_path='~/petreloss.conf')
+ except Exception as e:
+ print(e)
+ logger.info("Ceph is not working!!!")
+
+
+ if config.use_half_precision:
+ if config.get('use_bf16', False):
+ cast_dtype = torch.bfloat16
+ else:
+ cast_dtype = torch.float16
+ else:
+ cast_dtype = None
+
+ model.train()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}"))
+ metric_logger.add_meter("temperature", SmoothedValue(window=100, fmt="{value:.4f}"))
+ loss_names = ["loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0]
+
+ if config.get("use_raw_text", False): # for cosa
+ loss_names = loss_names + ["c_loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0]
+ uta_all = config.criterion.get('uta_all', False)
+
+ media_types = get_media_types(train_loaders)
+
+ for name in loss_names:
+ for m in media_types:
+ metric_logger.add_meter(
+ f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}")
+ )
+
+ header = f"Train Epoch: [{epoch}]"
+ log_freq = config.log_freq
+
+ if config.distributed:
+ for d in train_loaders:
+ d.sampler.set_epoch(epoch)
+
+ if config.get('use_iter_train', False):
+ train_loader = MetaLoader_rs2(name2loader=dict(list(zip(media_types, train_loaders))), skip_num=skip_num)
+ else:
+ train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders))))
+
+ model_without_ddp = model.module if config.distributed else model
+ iterator = metric_logger.log_every(train_loader, log_freq, header)
+
+ begin_step = global_step % len(train_loader)
+ logger.info(f"Epoch={epoch}, begin_step={begin_step} save_ckpt_iter={config.get('save_ckpt_iter', None)}")
+
+ for local_step, (media_type, (media, text, idx)) in enumerate(iterator):
+ if local_step < begin_step:
+ logger.warn(f"Jump local_step: {local_step} (begin_step={begin_step})!!!")
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ continue
+
+ if config.get("save_ckpt_iter", None) is not None: # and not is_iter_resume:
+ if local_step != 0 and local_step % config.get("save_ckpt_iter") == 0:
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ tag = f"ckpt_e{epoch:02d}_local{local_step}_global{global_step}"
+ client_state = {'epoch': epoch, 'global_step': global_step, 'local_step': local_step}
+ model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, client_state=client_state)
+ logger.info(f"save ckpt file to local ({config.output_dir}/{tag})!!!")
+ elif is_main_process():
+ state_dict = model_without_ddp.state_dict()
+ for k in config.get("no_save_params_prefix", []):
+ kk = [x for x in state_dict.keys() if x.startswith(k)]
+ logger.info(f"Not saving {len(kk)} params with prefix {k}")
+ for kkk in kk:
+ state_dict.pop(kkk)
+
+ save_obj = {
+ "model": state_dict,
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "scaler": scaler.state_dict(),
+ "config": config,
+ "epoch": epoch,
+ "local_step": local_step,
+ "global_step": global_step,
+ }
+ try:
+ with io.BytesIO() as buffer:
+ torch.save(save_obj, buffer)
+ client_ckpt.put(f"{ceph_ckpt_path}/ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth", buffer.getvalue())
+ logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth)!!!")
+ except Exception as e:
+ print(e)
+ torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth"))
+ logger.warn(f"Ceph is not working, save to local ({join(config.output_dir, f'ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth')})!!!")
+
+ if media_type == 'audio_video':
+ if type(media[0]) is list:
+ assert len(media[0]) == 2
+ audio = [media[0][0].to(device, dtype=cast_dtype, non_blocking=True), media[0][1].to(device, non_blocking=True)]
+ else:
+ audio = media[0].to(device, dtype=cast_dtype, non_blocking=True)
+ video = media[1].to(device, dtype=cast_dtype, non_blocking=True)
+ media = [audio, video]
+ else:
+ media = media.to(device, dtype=cast_dtype, non_blocking=True)
+ idx = idx.to(device, non_blocking=True)
+ if config.get("use_raw_text", False) or config.get("use_cosa", False):
+ max_length = config.inputs.max_txt_l[media_type]
+ else:
+ if type(text) is dict:
+ text_input = {}
+ for k in text.keys():
+ text_input[k] = tokenizer(
+ text[k],
+ padding="max_length",
+ truncation=True,
+ max_length=config.inputs.max_txt_l[media_type],
+ return_tensors="pt",
+ ).to(
+ device) # change from "longest" to "max_length"
+ else:
+ text_input = tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=config.inputs.max_txt_l[media_type],
+ return_tensors="pt",
+ ).to(
+ device) # change from "longest" to "max_length"
+
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ loss_dict = model(media, text_input, idx=idx, media_type=media_type)
+ loss = sum(loss_dict.values())
+
+ model.backward(loss)
+ model.step()
+
+ else: # NOTE We shouldn't use scaler if we only involve bf16, check this!
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=cast_dtype):
+ loss_dict = model(media, text_input, idx=idx, media_type=media_type)
+ loss = sum(loss_dict.values())
+
+ if not config.use_half_precision or config.get('use_bf16', False):
+ optimizer.zero_grad()
+ loss.backward()
+ if config.optimizer.max_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
+
+ optimizer.step()
+ scheduler.step()
+ else:
+ optimizer.zero_grad()
+ scaler.scale(loss).backward()
+ if config.optimizer.max_grad_norm > 0:
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
+ scaler.step(optimizer)
+ scaler.update()
+ scheduler.step()
+
+ # logging
+ for name in loss_names:
+ if name in loss_dict.keys():
+ value = loss_dict[name]
+ value = value if isinstance(value, float) else value.item()
+ metric_logger.update(**{f"{media_type}-{name}": value})
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(temperature=model_without_ddp.temp.item())
+
+ if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
+ try:
+ logs = metric_logger.get_global_avg_dict()
+ log_dict_to_wandb(logs, step=global_step, prefix="train/")
+ except Exception as e:
+ logger.warn("Wandb is not working!!!")
+ print(e)
+
+ global_step += 1
+
+ if config.debug and global_step % 20 == 0:
+ logger.info("debug mode, break training loop")
+ break
+
+ if config.debug and global_step % (2 * log_freq + 3) == 0:
+ logger.info("debug mode, break training loop")
+ break
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ logger.info(f"Averaged stats: {metric_logger.global_avg()}")
+ return global_step
+
+
+def setup_dataloaders(config, mode="pt"):
+ # train datasets, create a list of data loaders
+ logger.info(f"Creating dataset for {mode} use_iter_train={config.get('use_iter_train', False)}")
+ train_datasets = create_dataset(f"{mode}_train", config)
+ media_types = get_media_types(train_datasets)
+
+ if config.get('use_iter_train', False):
+ if config.distributed:
+ batch_size = [config.inputs.batch_size[k] for k in media_types] # batch_size for each GPU
+ samplers = create_stateful_sampler(train_datasets, batch_size)
+ else:
+ raise NotImplementedError
+ else:
+ if config.distributed:
+ num_tasks = get_world_size()
+ global_rank = get_rank()
+ samplers = create_sampler(
+ train_datasets, [True] * len(media_types), num_tasks, global_rank
+ )
+ else:
+ samplers = [None] * len(media_types)
+
+ train_loaders = create_loader(
+ train_datasets,
+ samplers,
+ batch_size=[config.inputs.batch_size[k] for k in media_types],
+ num_workers=[config.num_workers] * len(media_types),
+ is_trains=[True] * len(media_types),
+ collate_fns=[None] * len(media_types),
+ ) # [0]
+
+ # test datasets, a mapping from dataset name to data loader
+ test_datasets, test_dataset_names = create_dataset(f"{mode}_eval", config)
+ test_loaders = create_loader(
+ test_datasets,
+ [None] * len(test_datasets),
+ batch_size=[config.inputs.batch_size_test[d.media_type] for d in test_datasets],
+ num_workers=[config.num_workers] * len(test_datasets),
+ is_trains=[False] * len(test_datasets),
+ collate_fns=[None] * len(test_datasets),
+ )
+ test_name2loaders = {k: v for k, v in zip(test_dataset_names, test_loaders)}
+ return train_loaders, test_name2loaders, media_types
+
+
+def main(config):
+ if config.get('use_flash_sdp', False):
+ torch.backends.cuda.enable_flash_sdp(enabled=True)
+ elif config.get('use_mem_efficient_sdp', False):
+ torch.backends.cuda.enable_mem_efficient_sdp(enabled=True)
+
+ try:
+ ceph_ckpt_path = f"{ceph_ckpt_bucket}/{config.output_dir.split('/')[-3]}/{config.output_dir.split('/')[-2]}/{config.output_dir.split('/')[-1]}"
+ client_ckpt = Client(conf_path='~/petreloss.conf')
+ except Exception as e:
+ print(e)
+ logger.info("Ceph is not working!!!")
+
+ if is_main_process() and config.wandb.enable:
+ try:
+ run = setup_wandb(config)
+ logger.info("Wandb is working!!!")
+ except Exception as e:
+ logger.warn("Wandb is not working!!!")
+ print(e)
+
+ is_pretrain = config.mode == "pt"
+
+ logger.info(f"train_file: {config.train_file}")
+
+ setup_seed(config.seed + get_rank())
+ device = torch.device(config.device)
+
+ train_loaders, test_name2loaders, train_media_types = setup_dataloaders(
+ config, mode=config.mode
+ )
+ num_steps_per_epoch = sum(len(d) for d in train_loaders)
+ config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
+ config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs
+ # set cudnn.benchmark=True only when input size is fixed
+ # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
+ cudnn.benchmark = len(train_media_types) == 1
+
+ print(f"\033[31m CURRENT NODE NAME: {os.environ['SLURMD_NODENAME']} dataloader is OK {datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S')}!!! \033[0m")
+
+ find_unused_parameters = config.model.get('find_unused_parameters', False)
+ logger.info(f"find_unused_parameters={find_unused_parameters}")
+
+ model_cls = eval(config.model.get('model_cls'))
+ (
+ model,
+ model_without_ddp,
+ optimizer,
+ scheduler,
+ scaler,
+ tokenizer,
+ start_epoch,
+ global_step,
+ ) = setup_model(
+ config,
+ model_cls=model_cls,
+ add_decoder=False,
+ pretrain=is_pretrain,
+ find_unused_parameters=find_unused_parameters,
+ )
+
+ if is_main_process() and config.wandb.enable:
+ try:
+ wandb.watch(model)
+ except Exception as e:
+ logger.warn("Wandb is not working!!!")
+ print(e)
+
+ best = 0
+ best_epoch = 0
+ if type(config.best_key) is str:
+ best_key = [config.best_key, "t2v_r1"]
+ elif type(config.best_key) is list and len(config.best_key) == 2:
+ best_key = config.best_key
+ else:
+ raise NotImplementedError(config.best_key)
+
+ best_ckpt_id = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+ logger.info(f"Start training, start_epoch={start_epoch}")
+ start_time = time.time()
+ start_step = start_epoch * num_steps_per_epoch
+ for epoch in range(start_epoch, config.scheduler.epochs):
+ if not config.evaluate:
+ global_step = train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ skip_num = global_step - start_step
+ )
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ if config.get("save_latest", False):
+ tag = "ckpt_latest"
+ else:
+ tag = f"ckpt_{epoch:02d}"
+
+ client_state = {'epoch': epoch, 'global_step': global_step}
+ model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, client_state=client_state)
+
+ logger.info(f"save ckpt file to local ({config.output_dir}/{tag})!!!")
+ if is_main_process() and config.get("delete_ds_optim_states", False):
+ if config.get("save_latest", False):
+ if epoch == (config.scheduler.epochs - 1): # last epoch
+ last_tag = "ckpt_latest"
+ last_ckpt_path = f"{config.output_dir}/{last_tag}"
+ if os.path.exists(last_ckpt_path):
+ logger.info(f"remove optim states in ({config.output_dir}/{last_tag})!!!")
+ for file in os.listdir(last_ckpt_path):
+ if file.endswith('optim_states.pt'):
+ os.remove(os.path.join(last_ckpt_path, file))
+ else:
+ last_tag = f"ckpt_{epoch-1:02d}"
+ last_ckpt_path = f"{config.output_dir}/{last_tag}"
+ if os.path.exists(last_ckpt_path):
+ logger.info(f"remove optim states in ({config.output_dir}/{last_tag})!!!")
+ for file in os.listdir(last_ckpt_path):
+ if file.endswith('optim_states.pt'):
+ os.remove(os.path.join(last_ckpt_path, file))
+
+ if epoch == (config.scheduler.epochs - 1): # last epoch
+ last_tag = f"ckpt_{epoch:02d}"
+ last_ckpt_path = f"{config.output_dir}/{last_tag}"
+ if os.path.exists(last_ckpt_path):
+ logger.info(f"remove optim states in ({config.output_dir}/{last_tag})!!!")
+ for file in os.listdir(last_ckpt_path):
+ if file.endswith('optim_states.pt'):
+ os.remove(os.path.join(last_ckpt_path, file))
+
+ if is_main_process():
+ if not (hasattr(config, "deepspeed") and config.deepspeed.enable):
+ state_dict = model_without_ddp.state_dict()
+ for k in config.get("no_save_params_prefix", []):
+ kk = [x for x in state_dict.keys() if x.startswith(k)]
+ logger.info(f"Not saving {len(kk)} params with prefix {k}")
+ for kkk in kk:
+ state_dict.pop(kkk)
+
+ save_obj = {
+ "model": state_dict,
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "scaler": scaler.state_dict(),
+ "config": config,
+ "epoch": epoch,
+ "global_step": global_step,
+ }
+ try:
+ with io.BytesIO() as buffer:
+ torch.save(save_obj, buffer)
+ if config.get("save_latest", False):
+ client_ckpt.put(f"{ceph_ckpt_path}/ckpt_latest.pth", buffer.getvalue())
+ logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_latest.pth)!!!")
+ else:
+ client_ckpt.put(f"{ceph_ckpt_path}/ckpt_{epoch:02d}.pth", buffer.getvalue())
+ logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_{epoch:02d}.pth)!!!")
+ except Exception as e:
+ print(e)
+ if config.get("save_latest", False):
+ torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
+ logger.warn(f"Ceph is not working, save to local ({join(config.output_dir, 'ckpt_latest.pth')})!!!")
+ else:
+ torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))
+ logger.warn(f"Ceph is not working, save to local ({join(config.output_dir, f'ckpt_{epoch:02d}.pth')})!!!")
+
+
+
+ if config.get("jump_evaluate", False) and not config.evaluate:
+ logger.warn(f"Jump the evaluation'))!!!")
+ else:
+ try:
+ eval_res = {}
+ for test_name, test_loader in test_name2loaders.items():
+ if test_name not in config.test_types:
+ logger.info(
+ f"Skip eval {test_name} split. All test_types {config.test_types}"
+ )
+ continue
+ res = evaluation_wrapper(
+ model_without_ddp, test_loader, tokenizer, device, config, prefix=test_name
+ )
+ eval_res.update(res)
+
+ if is_main_process():
+ # log to wandb
+ if config.wandb.enable:
+ try:
+ for p, v in eval_res.items():
+ log_dict_to_wandb(v, step=global_step, prefix=p)
+ except Exception as e:
+ logger.warn("Wandb is not working!!!")
+ print(e)
+
+ try:
+ cur_recall = eval_res[best_key[0]][best_key[1]]
+ except Exception as e:
+ logger.warn(e)
+ print(e)
+ # print(eval_res)
+ cur_recall = best - 1
+
+ eval_res = pd.DataFrame(eval_res)
+ logger.info(f"Epoch {epoch}")
+ logger.info(f"\n{eval_res.transpose().to_string(max_cols=30)}")
+
+ eval_res.to_json(join(config.output_dir, "eval_res_latest.json"))
+
+ state_dict = model_without_ddp.state_dict()
+
+ for k in config.get("no_save_params_prefix", []):
+ kk = [x for x in state_dict.keys() if x.startswith(k)]
+ logger.info(f"Not saving {len(kk)} params with prefix {k}")
+ for kkk in kk:
+ state_dict.pop(kkk)
+
+ if not config.evaluate and cur_recall > best:
+ if not (hasattr(config, "deepspeed") and config.deepspeed.enable):
+ try:
+ with io.BytesIO() as buffer:
+ torch.save(save_obj, buffer)
+ client_ckpt.put(f"{ceph_ckpt_path}/ckpt_best_{best_ckpt_id}.pth", buffer.getvalue())
+ logger.info(f"Save to ceph ({f'{ceph_ckpt_path}/ckpt_best_{best_ckpt_id}.pth'})!!!")
+ except Exception as e:
+ print(e)
+ torch.save(save_obj, join(config.output_dir, f"ckpt_best_{best_ckpt_id}.pth"))
+ logger.warn(f"Ceph is not working, save to local ({join(config.output_dir, f'ckpt_best_{best_ckpt_id}.pth')})!!!")
+ else:
+
+ now_ckpt_path = f"{config.output_dir}/{tag}/mp_rank_00_model_states.pt"
+ best_ckpt_path = f"{config.output_dir}/best_mp_rank_00_model_states.pt"
+
+ if os.path.exists(now_ckpt_path):
+ shutil.copy(now_ckpt_path, best_ckpt_path)
+ logger.info(f"Copy {now_ckpt_path} to {best_ckpt_path}!!!")
+ else:
+ logger.warn(f"Can't find {now_ckpt_path}, there's some wrong!!!")
+
+ eval_file = "eval_res_best.json"
+ eval_res.to_json(join(config.output_dir, eval_file))
+ best = cur_recall
+ best_epoch = epoch
+ except Exception as e:
+ logger.warn("Something wrong when eval or save!!!")
+ print(e)
+ if config.evaluate:
+ raise e
+
+
+ if config.evaluate:
+ break
+
+ start_step = global_step
+
+ dist.barrier()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Training time {total_time_str}")
+ logger.info(f"best epoch {best_epoch} [best_key {best_key}]")
+ logger.info(f"Checkpoints and Logs saved at {config.output_dir}")
+
+ if is_main_process() and config.wandb.enable:
+ try:
+ run.finish()
+ except Exception as e:
+ logger.warn("Wandb is not working!!!")
+ print(e)
+
+
+if __name__ == "__main__":
+ print(f"\033[31m NODE LIST: {os.environ['SLURM_NODELIST']} \033[0m")
+ logger.info(f"NODE LIST: {os.environ['SLURM_NODELIST']}")
+ cfg = setup_main()
+ local_broadcast_process_authkey()
+ main(cfg)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tasks/retrieval_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/tasks/retrieval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f677a3f1cc33cd1f4065a83bd5bfff7405130e6
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tasks/retrieval_utils.py
@@ -0,0 +1,1305 @@
+import datetime
+import logging
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from einops import rearrange
+
+from models.criterions import get_sim
+from utils.basic_utils import MetricLogger
+from utils.distributed import get_rank, get_world_size
+
+logger = logging.getLogger(__name__)
+
+
+def extract_text_feats(texts, max_txt_l, tokenizer, model, device, return_ids=False):
+ num_text = len(texts)
+ text_bs = 256
+ text_feats = []
+ text_atts = []
+
+ if return_ids:
+ text_ids = []
+
+ for i in range(0, num_text, text_bs):
+ text = texts[i : min(num_text, i + text_bs)]
+ text_input = tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=max_txt_l,
+ return_tensors="pt",
+ ).to(device) # NOTE not need to cast
+
+ text_feat = model.encode_text(text_input)[0]
+ text_feats.append(text_feat)
+ text_atts.append(text_input.attention_mask)
+ if return_ids:
+ text_ids.append(text_input.input_ids)
+
+ text_feats = torch.cat(text_feats, dim=0)
+ text_atts = torch.cat(text_atts, dim=0)
+ if return_ids:
+ text_ids = torch.cat(text_ids, dim=0)
+ return text_feats, text_atts, text_ids
+ else:
+ return text_feats, text_atts
+
+def extract_vision_feats(data_loader, model, device, config):
+ if config.use_half_precision:
+ if config.get('use_bf16', False):
+ cast_dtype = torch.bfloat16
+ else:
+ cast_dtype = torch.float16
+ else:
+ cast_dtype = None
+
+ image_feats_all = []
+ pooled_image_feats_all = []
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "extracting image feats"
+ iterator = metric_logger.log_every(data_loader, 100, header)
+ for image, img_id in iterator:
+ image = image.to(device, dtype=cast_dtype, non_blocking=True)
+ image_feat, pooled_image_feat = model.encode_vision(image, test=True)
+ if len(pooled_image_feat.shape) == 2:
+ pooled_image_feat = pooled_image_feat.unsqueeze(1) # make av_fusion happy
+ if config.evaluation.eval_frame_ensemble == "concat":
+ if len(image_feat.shape) == 4:
+ image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
+ image_feat = image_feat.unsqueeze(1) # (bsz, 1, #frm*L, d)
+ else:
+ assert config.video_input.num_frames == 1, "only support single-frame"
+ assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
+ if config.evaluation.eval_offload:
+ image_feats_all.append(image_feat.cpu())
+ pooled_image_feats_all.append(pooled_image_feat.cpu())
+ else:
+ image_feats_all.append(image_feat)
+ pooled_image_feats_all.append(pooled_image_feat)
+
+ image_feats_all = torch.cat(image_feats_all, dim=0)
+ pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
+
+ return image_feats_all, pooled_image_feats_all
+
+def extract_audio_feats(data_loader, model, device, config):
+ if config.use_half_precision:
+ if config.get('use_bf16', False):
+ cast_dtype = torch.bfloat16
+ else:
+ cast_dtype = torch.float16
+ else:
+ cast_dtype = None
+
+ audio_feats_all = []
+ pooled_audio_feats_all = []
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "extracting audio feats"
+ iterator = metric_logger.log_every(data_loader, 100, header)
+ for audio, _ in iterator:
+ audio = audio.to(device, dtype=cast_dtype, non_blocking=True)
+ audio_feat, pooled_audio_feat = model.encode_audio(audio, test=True)
+ audio_feat = audio_feat.unsqueeze(1) # make deep_fusion happy
+ pooled_audio_feat = pooled_audio_feat.unsqueeze(1)
+ if config.evaluation.eval_offload:
+ audio_feats_all.append(audio_feat.cpu())
+ pooled_audio_feats_all.append(pooled_audio_feat.cpu())
+ else:
+ audio_feats_all.append(audio_feat)
+ pooled_audio_feats_all.append(pooled_audio_feat)
+
+ audio_feats_all = torch.cat(audio_feats_all, dim=0)
+
+ pooled_audio_feats_all = torch.cat(pooled_audio_feats_all, dim=0)
+ return audio_feats_all, pooled_audio_feats_all
+
+def extract_audio_vision_feats(data_loader, model, device, config):
+ if config.use_half_precision:
+ if config.get('use_bf16', False):
+ cast_dtype = torch.bfloat16
+ else:
+ cast_dtype = torch.float16
+ else:
+ cast_dtype = None
+
+ audio_feats_all = []
+ pooled_audio_feats_all = []
+ image_feats_all = []
+ pooled_image_feats_all = []
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "extracting audio and vision feats"
+ iterator = metric_logger.log_every(data_loader, 100, header)
+ for media, _ in iterator:
+ audio = media[0]
+ image = media[1]
+ audio = audio.to(device, dtype=cast_dtype, non_blocking=True)
+ image = image.to(device, dtype=cast_dtype, non_blocking=True)
+ audio_feat, pooled_audio_feat = model.encode_audio(audio, test=True)
+ audio_feat = audio_feat.unsqueeze(1) # make deep_fusion happy
+ pooled_audio_feat = pooled_audio_feat.unsqueeze(1)
+ image_feat, pooled_image_feat = model.encode_vision(image, test=True)
+ if len(pooled_image_feat.shape) == 2:
+ pooled_image_feat = pooled_image_feat.unsqueeze(1) # make av_fusion happy
+ if config.evaluation.eval_frame_ensemble == "concat":
+ if len(image_feat.shape) == 4:
+ image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
+ image_feat = image_feat.unsqueeze(1) # (bsz, 1, #frm*L, d)
+ else:
+ assert config.video_input.num_frames == 1, "only support single-frame"
+ assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
+ if config.evaluation.eval_offload:
+ audio_feats_all.append(audio_feat.cpu())
+ pooled_audio_feats_all.append(pooled_audio_feat.cpu())
+ image_feats_all.append(image_feat.cpu())
+ pooled_image_feats_all.append(pooled_image_feat.cpu())
+ else:
+ audio_feats_all.append(audio_feat)
+ pooled_audio_feats_all.append(pooled_audio_feat)
+ image_feats_all.append(image_feat)
+ pooled_image_feats_all.append(pooled_image_feat)
+
+ audio_feats_all = torch.cat(audio_feats_all, dim=0)
+ pooled_audio_feats_all = torch.cat(pooled_audio_feats_all, dim=0)
+ image_feats_all = torch.cat(image_feats_all, dim=0)
+ pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
+
+ return audio_feats_all, pooled_audio_feats_all, image_feats_all, pooled_image_feats_all
+
+
+@torch.no_grad()
+def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
+ amp_eval_enabled = config.use_half_precision and not (hasattr(config, "deepspeed") and config.deepspeed.enable)
+ logger.info(f"Begin to eval, model_without_ddp.dtype={model.dtype if hasattr(model, 'dtype') else None}, amp_eval_enabled={amp_eval_enabled}, dtype={torch.bfloat16 if config.get('use_bf16', False) else torch.float16}")
+ with torch.cuda.amp.autocast(enabled=amp_eval_enabled, dtype=torch.bfloat16 if config.get('use_bf16', False) else torch.float16):
+ i2t_match, t2i_match = None, None
+ if "qformer" in config.model.model_cls.lower():
+ i2t_match, t2i_match, i2t_sim, t2i_sim, i2t_dsl, t2i_dsl = evaluation_qformer(
+ model, data_loader, tokenizer, device, config
+ )
+ elif "blip" in config.model.model_cls.lower():
+ raise NotImplementedError
+ elif "clip" in config.model.model_cls.lower() or 'coca' in config.model.model_cls.lower():
+ # raise NotImplementedError
+ i2t_sim, t2i_sim, i2t_dsl, t2i_dsl = evaluation_clip(
+ model, data_loader, tokenizer, device, config
+ )
+ else:
+ i2t_match, t2i_match, i2t_sim, t2i_sim, i2t_dsl, t2i_dsl = evaluation(
+ model, data_loader, tokenizer, device, config
+ )
+
+ if hasattr(data_loader.dataset, "num_prompts"):
+ np = data_loader.dataset.num_prompts
+ logger.info(f"Using {np} prompts, we need reshape and mean!!!")
+ nt = len(data_loader.dataset.text) // np
+ if i2t_match is not None:
+ i2t_match = i2t_match.reshape((i2t_match.shape[0], nt, np)).mean(axis=-1)
+ t2i_match = t2i_match.reshape((nt, np, t2i_match.shape[1])).mean(axis=1)
+ i2t_sim = i2t_sim.reshape((i2t_sim.shape[0], nt, np)).mean(axis=-1)
+ t2i_sim = t2i_sim.reshape((nt, np, t2i_sim.shape[1])).mean(axis=1)
+ i2t_dsl = i2t_dsl.reshape((i2t_dsl.shape[0], nt, np)).mean(axis=-1)
+ t2i_dsl = t2i_dsl.reshape((nt, np, t2i_dsl.shape[1])).mean(axis=1)
+
+ score_pairs = [
+ (prefix + "_sim", i2t_sim, t2i_sim),
+ (prefix + "_dsl", i2t_dsl, t2i_dsl),
+ ]
+ if i2t_match is not None:
+ if config.evaluation.get('use_dsl_for_match', False):
+ score_pairs.append((prefix + "_match (use_dsl)", i2t_match, t2i_match))
+ else:
+ score_pairs.append((prefix + "_match", i2t_match, t2i_match))
+
+ res = dict()
+ for name, i2t, t2i in score_pairs:
+ if i2t is not None:
+ txt2img_ids = data_loader.dataset.txt2img
+ img2txt_ids = data_loader.dataset.img2txt
+ res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)
+ return res
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, tokenizer, device, config):
+ model.eval()
+
+ use_dsl_for_match = config.evaluation.get('use_dsl_for_match', False)
+
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation:"
+ dtype = torch.half if config.use_half_precision else torch.float
+ media_type = data_loader.dataset.media_type
+ use_subtitle = hasattr(data_loader.dataset, "use_subtitle") and data_loader.dataset.use_subtitle
+ if use_subtitle:
+ assert media_type in ["video", "audio_video"], f"Not support media_type: {media_type}."
+ assert hasattr(data_loader.dataset, "subtitle") and data_loader.dataset.subtitle is not None, "You don't have subtitle to use."
+
+ logger.info(f"Start evaluation for media_type={media_type}")
+ assert media_type in ['audio', 'video', 'audio_video'], f"Not implement evaluation of {media_type}"
+
+ logger.info("Computing dual encoder features...")
+ start_time = time.time()
+
+ # this computes all features in each GPU
+ texts = data_loader.dataset.text
+ # max_txt_l of eval depends on data_cofig
+ max_txt_l = data_loader.dataset.max_txt_l
+
+ text_feats, text_atts = extract_text_feats(
+ texts, max_txt_l, tokenizer, model, device
+ ) # (bsz, Lt, d), (bsz, Lt)
+
+ if use_subtitle:
+ subtitle_feats, _ = extract_text_feats(
+ data_loader.dataset.subtitle, max_txt_l, tokenizer, model, device
+ ) # (bsz, Lt, d), (bsz, Lt)
+ subtitle_proj = model.text_proj(subtitle_feats[:, 0]).unsqueeze(1)
+ subtitle_feats = subtitle_feats.unsqueeze(1)
+
+ if media_type == 'video':
+ image_feats, pooled_image_feats = extract_vision_feats(
+ data_loader, model, device, config
+ ) # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
+ logger.info("Finished vision feature extraction")
+ logger.info("Computing ITC scores [dot-product]")
+ if config.evaluation.eval_offload:
+ # image_feats = image_feats.to(device, non_blocking=True) image_feats will cause OOM!!!
+ pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
+
+ if use_subtitle:
+ # print(subtitle_proj.shape, pooled_image_feats.shape)
+ i2t_scores, t2i_scores = get_sim(
+ model.vs_fusion(torch.concat([subtitle_proj, model.vision_proj(pooled_image_feats)], dim=-1)), model.text_proj(text_feats[:, 0])
+ )
+ else:
+ i2t_scores, t2i_scores = get_sim(
+ model.vision_proj(pooled_image_feats), model.text_proj(text_feats[:, 0])
+ )
+
+ if use_dsl_for_match:
+ logger.info("use_dsl_for_match!!!")
+ old_i2t_scores, old_t2i_scores = i2t_scores, t2i_scores
+ i2t_scores = old_i2t_scores * old_i2t_scores.softmax(dim=0)
+ t2i_scores = old_i2t_scores.T * old_i2t_scores.T.softmax(dim=0)
+
+ num_medias = len(data_loader.dataset.image)
+
+ # pooled_media_feats = pooled_image_feats
+ if use_subtitle:
+ media_feats = torch.concat([subtitle_feats, image_feats], dim=-2)
+ if hasattr(model, "vstm_head"):
+ match_head = model.vstm_head
+ else:
+ match_head = None
+ else:
+ media_feats = image_feats
+ if hasattr(model, "itm_head"):
+ match_head = model.itm_head
+ else:
+ match_head = None
+
+ elif media_type == 'audio':
+ audio_feats, pooled_audio_feats = extract_audio_feats(
+ data_loader, model, device, config
+ )
+ logger.info("Finished audio feature extraction")
+ logger.info("Computing ITC scores [dot-product]")
+ if config.evaluation.eval_offload:
+ pooled_audio_feats = pooled_audio_feats.to(device, non_blocking=True)
+
+ i2t_scores, t2i_scores = get_sim(
+ model.audio_proj(pooled_audio_feats), model.text_proj(text_feats[:, 0])
+ )
+
+ num_medias = len(data_loader.dataset.audio)
+ media_feats = audio_feats
+ # pooled_media_feats = pooled_audio_feats
+ if hasattr(model, "atm_head"):
+ match_head = model.atm_head
+ else:
+ match_head = None
+
+ elif media_type == 'audio_video':
+ audio_feats, pooled_audio_feats, image_feats, pooled_image_feats = extract_audio_vision_feats(
+ data_loader, model, device, config
+ )
+ logger.info("Finished audio and vision feature extraction")
+
+ logger.info("Computing ITC scores [dot-product]")
+ if config.evaluation.eval_offload:
+ pooled_audio_feats = pooled_audio_feats.to(device, non_blocking=True)
+ pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
+
+ if use_subtitle:
+ i2t_scores, t2i_scores = get_sim(
+ model.avs_fusion(torch.concat([model.audio_proj(pooled_audio_feats), subtitle_proj, model.vision_proj(pooled_image_feats)], dim=-1)), model.text_proj(text_feats[:, 0])
+ )
+ else:
+ i2t_scores, t2i_scores = get_sim(
+ model.av_fusion(torch.concat([model.audio_proj(pooled_audio_feats), model.vision_proj(pooled_image_feats)], dim=-1)), model.text_proj(text_feats[:, 0])
+ )
+
+ num_medias = len(data_loader.dataset.image)
+ if use_subtitle:
+ media_feats = torch.concat([audio_feats, subtitle_feats, image_feats], dim=-2)
+ # pooled_media_feats = pooled_audio_feats
+ if hasattr(model, "avstm_head"):
+ match_head = model.avstm_head
+ else:
+ match_head = None
+ else:
+ media_feats = torch.concat([audio_feats, image_feats], dim=-2)
+ # pooled_media_feats = pooled_audio_feats
+ if hasattr(model, "avtm_head"):
+ match_head = model.avtm_head
+ else:
+ match_head = None
+ else:
+ raise NotImplementedError(media_type)
+
+ logger.info("Computing ITC scores [dot-product], done!")
+
+ if match_head is not None:
+ i2t_scores_x = torch.full((num_medias, len(texts)), -100.0).to(
+ device, torch.float, non_blocking=True
+ )
+
+ # computes only part of the scores at each GPU, gather at the end
+ logger.info("Rerank dual-encoder results with cross-encoder...")
+ num_tasks = get_world_size()
+ rank = get_rank()
+ # only uses the part associated with the raw eval set
+ # compute media2text #
+ step = num_medias // num_tasks + 1
+ start = rank * step
+ end = min(num_medias, start + step)
+
+ text_encoder = model.get_text_encoder()
+ iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
+ logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")
+
+ # generate score for each clip, and aggregate all clip scores for a video
+ n_clip_per_video = (
+ media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
+ )
+
+ assert not config.deep_fusion and n_clip_per_video == 1, f"Not implemented for config.deep_fusion={config.deep_fusion} n_clip_per_video={n_clip_per_video}"
+
+ logger.info(
+ f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
+ )
+
+ for i, sims in enumerate(iterator):
+ k = min(len(sims), config.evaluation.k_test)
+ topk_sim, topk_idx = sims.topk(k=k, dim=0)
+
+ clip_scores = []
+ for clip_idx in range(n_clip_per_video):
+ if config.deep_fusion:
+ encoder_output = [
+ feat[start + i, clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else feat[start + i, clip_idx]
+ for feat in media_feats
+ ]
+
+ else:
+ encoder_output = (
+ media_feats[start + i, clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else media_feats[start + i, clip_idx]
+ ) # (#frm*Li, d)
+
+ # new
+ bs = 32
+ # bs = config.batch_size_test.video
+ itm_embeds = []
+
+ if config.deep_fusion:
+ if len(topk_idx) % bs != 0:
+ left = len(topk_idx) % bs
+ left_encoder_output = [feat.repeat(left, 1, 1) for feat in encoder_output]
+ left_encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in left_encoder_output
+ ]
+ encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
+ encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in encoder_output
+ ]
+ else:
+ if len(topk_idx) % bs != 0:
+ left = len(topk_idx) % bs
+ left_encoder_output = encoder_output.repeat(left, 1, 1) # (k=128, #frm*Li, d)
+ left_encoder_att = torch.ones(left_encoder_output.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ encoder_output = encoder_output.repeat(bs, 1, 1) # (k=128, #frm*Li, d)
+ encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+
+ for j in range(0, len(topk_idx), bs):
+ if j + bs > len(topk_idx):
+ output = text_encoder(
+ encoder_embeds=text_feats[topk_idx[j:]],
+ attention_mask=text_atts[topk_idx[j:]],
+ encoder_hidden_states=left_encoder_output,
+ encoder_attention_mask=left_encoder_att,
+ return_dict=True,
+ mode="fusion",
+ )
+ else:
+ output = text_encoder(
+ encoder_embeds=text_feats[topk_idx[j : j + bs]],
+ attention_mask=text_atts[topk_idx[j : j + bs]],
+ encoder_hidden_states=encoder_output,
+ encoder_attention_mask=encoder_att,
+ return_dict=True,
+ mode="fusion",
+ )
+ batch_itm_embeds = output.last_hidden_state[:, 0]
+ itm_embeds.append(batch_itm_embeds)
+ itm_embeds = torch.cat(itm_embeds, dim=0)
+ # end new
+
+ score = match_head(itm_embeds)[:, 1]
+ clip_scores.append(score)
+
+ if len(clip_scores) == 1:
+ score = clip_scores[0]
+ else:
+ raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
+
+ i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)
+
+ # compute text2media #
+ num_text = len(data_loader.dataset.text)
+ t2i_scores_x = torch.full((num_text, num_medias), -100.0).to(
+ device, torch.float, non_blocking=True
+ )
+
+ step = num_text // num_tasks + 1
+ start = rank * step
+ end = min(num_text, start + step)
+
+ iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
+ logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
+ # generate score for each clip, and aggregate all clip scores for a video
+ n_clip_per_video = (
+ media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
+ )
+ for i, sims in enumerate(iterator):
+ k = min(len(sims), config.evaluation.k_test)
+ topk_sim, topk_idx = sims.topk(k=k, dim=0)
+
+ clip_scores = []
+ for clip_idx in range(n_clip_per_video):
+ # new
+ bs = 32
+ # bs = config.batch_size_test.video
+ itm_embeds = []
+ for j in range(0, len(topk_idx), bs):
+
+ if config.deep_fusion:
+ encoder_output = [
+ feat[topk_idx[j : j + bs].cpu(), clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else feat[topk_idx[j : j + bs], clip_idx]
+ for feat in media_feats
+ ]
+ encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in encoder_output
+ ]
+ else:
+ encoder_output = (
+ media_feats[topk_idx[j : j + bs].cpu(), clip_idx].to(
+ device, non_blocking=True
+ )
+ if config.evaluation.eval_offload
+ else media_feats[topk_idx[j : j + bs], clip_idx]
+ )
+ encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+
+ repeat_n = (
+ encoder_output.shape[0]
+ if not config.deep_fusion
+ else encoder_output[0].shape[0]
+ )
+ output = text_encoder(
+ encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
+ attention_mask=text_atts[start + i].repeat(repeat_n, 1),
+ encoder_hidden_states=encoder_output,
+ encoder_attention_mask=encoder_att,
+ return_dict=True,
+ mode="fusion",
+ )
+
+ batch_itm_embeds = output.last_hidden_state[:, 0]
+ itm_embeds.append(batch_itm_embeds)
+
+ itm_embeds = torch.cat(itm_embeds, dim=0)
+ # end new
+
+ score = match_head(itm_embeds)[:, 1]
+ clip_scores.append(score)
+
+ if len(clip_scores) == 1:
+ score = clip_scores[0]
+ else:
+ raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
+
+ t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)
+
+ logger.info("Compute over!!!")
+ if config.distributed:
+ logger.info("Gather across GPUs!!!")
+ # gather across GPUs
+ dist.barrier()
+ logger.info("dist.barrier()!!!")
+ dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
+ logger.info("dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM) over!!!")
+ dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
+ logger.info("dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM) over!!!")
+
+ if use_dsl_for_match:
+ i2t_scores_dsl = i2t_scores
+ i2t_scores_dsl_T = t2i_scores
+ i2t_scores = old_i2t_scores
+ t2i_scores = old_t2i_scores
+ else:
+ i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
+ i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
+ else:
+ i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
+ i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Evaluation time {total_time_str}")
+
+
+ if match_head is not None:
+ return (
+ i2t_scores_x.softmax(dim=1).cpu().float().numpy(),
+ t2i_scores_x.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
+ )
+ else:
+ return (
+ None,
+ None,
+ i2t_scores.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
+ )
+
+
+@torch.no_grad()
+def evaluation_simple(model, data_loader, tokenizer, device, config):
+ model.eval()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation:"
+ media_type = data_loader.dataset.media_type
+
+ logger.info(f"Start evaluation for media_type={media_type}")
+ assert media_type in ['video'], f"Not implement evaluation of {media_type}"
+
+ logger.info("Computing dual encoder features...")
+ start_time = time.time()
+
+ # this computes all features in each GPU
+ texts = data_loader.dataset.text
+ # max_txt_l of eval depends on data_cofig
+ max_txt_l = data_loader.dataset.max_txt_l
+
+ text_feats, text_atts = extract_text_feats(
+ texts, max_txt_l, tokenizer, model, device
+ ) # (bsz, Lt, d), (bsz, Lt)
+
+
+ if media_type == 'video':
+ image_feats, pooled_image_feats = extract_vision_feats(
+ data_loader, model, device, config
+ ) # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
+ logger.info("Finished vision feature extraction")
+ logger.info("Computing ITC scores [dot-product]")
+ if config.evaluation.eval_offload:
+ # image_feats = image_feats.to(device, non_blocking=True) image_feats will cause OOM!!!
+ pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
+
+ i2t_scores, t2i_scores = get_sim(
+ model.vision_proj(pooled_image_feats), model.text_proj(text_feats[:, 0])
+ )
+
+ num_medias = len(data_loader.dataset.image)
+
+ media_feats = image_feats
+ if hasattr(model, "itm_head"):
+ match_head = model.itm_head
+ else:
+ match_head = None
+
+ else:
+ raise NotImplementedError(media_type)
+
+ logger.info("Computing ITC scores [dot-product], done!")
+
+ if match_head is not None:
+ i2t_scores_x = torch.full((num_medias, len(texts)), -100.0).to(
+ device, torch.float, non_blocking=True
+ )
+
+ # computes only part of the scores at each GPU, gather at the end
+ logger.info("Rerank dual-encoder results with cross-encoder...")
+ num_tasks = get_world_size()
+ rank = get_rank()
+ # only uses the part associated with the raw eval set
+ # compute media2text #
+ step = num_medias // num_tasks + 1
+ start = rank * step
+ end = min(num_medias, start + step)
+
+ text_encoder = model.get_text_encoder()
+ iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
+ logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")
+
+ # generate score for each clip, and aggregate all clip scores for a video
+ n_clip_per_video = (
+ media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
+ )
+
+ assert not config.deep_fusion and n_clip_per_video == 1, f"Not implemented for config.deep_fusion={config.deep_fusion} n_clip_per_video={n_clip_per_video}"
+
+ logger.info(
+ f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
+ )
+
+ for i, sims in enumerate(iterator):
+ k = min(len(sims), config.evaluation.k_test)
+ topk_sim, topk_idx = sims.topk(k=k, dim=0)
+
+ clip_scores = []
+ for clip_idx in range(n_clip_per_video):
+ if config.deep_fusion:
+ encoder_output = [
+ feat[start + i, clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else feat[start + i, clip_idx]
+ for feat in media_feats
+ ]
+
+ else:
+ encoder_output = (
+ media_feats[start + i, clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else media_feats[start + i, clip_idx]
+ ) # (#frm*Li, d)
+
+ # new
+ bs = 32
+ # bs = config.batch_size_test.video
+ itm_embeds = []
+
+ if config.deep_fusion:
+ if len(topk_idx) % bs != 0:
+ left = len(topk_idx) % bs
+ left_encoder_output = [feat.repeat(left, 1, 1) for feat in encoder_output]
+ left_encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in left_encoder_output
+ ]
+ encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
+ encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in encoder_output
+ ]
+ else:
+ if len(topk_idx) % bs != 0:
+ left = len(topk_idx) % bs
+ left_encoder_output = encoder_output.repeat(left, 1, 1) # (k=128, #frm*Li, d)
+ left_encoder_att = torch.ones(left_encoder_output.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ encoder_output = encoder_output.repeat(bs, 1, 1) # (k=128, #frm*Li, d)
+ encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+
+ for j in range(0, len(topk_idx), bs):
+ if j + bs > len(topk_idx):
+ output = text_encoder(
+ encoder_embeds=text_feats[topk_idx[j:]],
+ attention_mask=text_atts[topk_idx[j:]],
+ encoder_hidden_states=left_encoder_output,
+ encoder_attention_mask=left_encoder_att,
+ return_dict=True,
+ mode="fusion",
+ )
+ else:
+ output = text_encoder(
+ encoder_embeds=text_feats[topk_idx[j : j + bs]],
+ attention_mask=text_atts[topk_idx[j : j + bs]],
+ encoder_hidden_states=encoder_output,
+ encoder_attention_mask=encoder_att,
+ return_dict=True,
+ mode="fusion",
+ )
+ batch_itm_embeds = output.last_hidden_state[:, 0]
+ itm_embeds.append(batch_itm_embeds)
+ itm_embeds = torch.cat(itm_embeds, dim=0)
+ # end new
+
+ score = match_head(itm_embeds)[:, 1]
+ clip_scores.append(score)
+
+ if len(clip_scores) == 1:
+ score = clip_scores[0]
+ else:
+ raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
+
+ i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)
+
+ # compute text2media #
+ num_text = len(data_loader.dataset.text)
+ t2i_scores_x = torch.full((num_text, num_medias), -100.0).to(
+ device, torch.float, non_blocking=True
+ )
+
+ step = num_text // num_tasks + 1
+ start = rank * step
+ end = min(num_text, start + step)
+
+ iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
+ logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
+ # generate score for each clip, and aggregate all clip scores for a video
+ n_clip_per_video = (
+ media_feats.shape[1] if not config.deep_fusion else media_feats[0].shape[1]
+ )
+ for i, sims in enumerate(iterator):
+ k = min(len(sims), config.evaluation.k_test)
+ topk_sim, topk_idx = sims.topk(k=k, dim=0)
+
+ clip_scores = []
+ for clip_idx in range(n_clip_per_video):
+ # new
+ bs = 32
+ # bs = config.batch_size_test.video
+ itm_embeds = []
+ for j in range(0, len(topk_idx), bs):
+
+ if config.deep_fusion:
+ encoder_output = [
+ feat[topk_idx[j : j + bs].cpu(), clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else feat[topk_idx[j : j + bs], clip_idx]
+ for feat in media_feats
+ ]
+ encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in encoder_output
+ ]
+ else:
+ encoder_output = (
+ media_feats[topk_idx[j : j + bs].cpu(), clip_idx].to(
+ device, non_blocking=True
+ )
+ if config.evaluation.eval_offload
+ else media_feats[topk_idx[j : j + bs], clip_idx]
+ )
+ encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+
+ repeat_n = (
+ encoder_output.shape[0]
+ if not config.deep_fusion
+ else encoder_output[0].shape[0]
+ )
+ output = text_encoder(
+ encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
+ attention_mask=text_atts[start + i].repeat(repeat_n, 1),
+ encoder_hidden_states=encoder_output,
+ encoder_attention_mask=encoder_att,
+ return_dict=True,
+ mode="fusion",
+ )
+
+ batch_itm_embeds = output.last_hidden_state[:, 0]
+ itm_embeds.append(batch_itm_embeds)
+
+ itm_embeds = torch.cat(itm_embeds, dim=0)
+ # end new
+
+ score = match_head(itm_embeds)[:, 1]
+ clip_scores.append(score)
+
+ if len(clip_scores) == 1:
+ score = clip_scores[0]
+ else:
+ raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
+
+ t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)
+
+ logger.info("Compute over!!!")
+ if config.distributed:
+ logger.info("Gather across GPUs!!!")
+ # gather across GPUs
+ dist.barrier()
+ logger.info("dist.barrier()!!!")
+ dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
+ logger.info("dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM) over!!!")
+ dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
+ logger.info("dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM) over!!!")
+
+ i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
+ i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
+ else:
+ i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
+ i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Evaluation time {total_time_str}")
+
+
+ if match_head is not None:
+ return (
+ i2t_scores_x.softmax(dim=1).cpu().float().numpy(),
+ t2i_scores_x.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
+ )
+ else:
+ return (
+ None,
+ None,
+ i2t_scores.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
+ )
+
+
+@torch.no_grad()
+def evaluation_qformer(model, data_loader, tokenizer, device, config):
+ model.eval()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation:"
+ dtype = torch.half if config.use_half_precision else torch.float
+ media_type = data_loader.dataset.media_type
+ logger.info(f"Start evaluation_qformer for media_type={media_type}")
+ assert media_type == 'video', f"Not implement evaluation of {media_type}"
+ logger.info("Computing dual encoder features...")
+ start_time = time.time()
+
+ # this computes all features in each GPU
+ texts = data_loader.dataset.text
+ # max_txt_l of eval depends on data_cofig
+ max_txt_l = data_loader.dataset.max_txt_l
+
+ text_feats, text_atts, text_ids = extract_text_feats(
+ texts, max_txt_l, tokenizer, model, device, return_ids=True
+ ) # (bsz, Lt, d), (bsz, Lt)
+
+ if media_type == 'video':
+ image_feats, pooled_image_feats = extract_vision_feats(
+ data_loader, model, device, config
+ ) # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
+ logger.info("Finished vision feature extraction")
+ logger.info("Computing ITC scores [dot-product]")
+ if config.evaluation.eval_offload:
+ # image_feats = image_feats.to(device, non_blocking=True) image_feats will cause OOM!!!
+ pooled_image_feats = pooled_image_feats.to(device, non_blocking=True)
+
+ if hasattr(model, "q_vision_proj"):
+ i2t_scores, t2i_scores = get_sim(
+ model.q_vision_proj(pooled_image_feats), model.q_text_proj(text_feats[:, 0])
+ )
+ else:
+ i2t_scores, t2i_scores = get_sim(
+ model.vision_proj(pooled_image_feats), model.text_proj(text_feats[:, 0])
+ )
+
+ num_medias = len(data_loader.dataset.image)
+
+ media_feats = image_feats
+ if hasattr(model, "itm_head"):
+ match_head = model.itm_head
+ elif hasattr(model, "q_itm_head"):
+ match_head = model.q_itm_head
+ else:
+ raise NotImplementedError("you must have a match head in qformer!!!")
+
+ logger.info("Computing ITC scores [dot-product], done!")
+
+ if match_head is not None:
+ i2t_scores_x = torch.full((num_medias, len(texts)), -100.0).to(
+ device, torch.float, non_blocking=True
+ )
+
+ # computes only part of the scores at each GPU, gather at the end
+ logger.info("Rerank dual-encoder results with cross-encoder...")
+ num_tasks = get_world_size()
+ rank = get_rank()
+ # only uses the part associated with the raw eval set
+ # compute image2text #
+ step = num_medias // num_tasks + 1
+ start = rank * step
+ end = min(num_medias, start + step)
+
+ iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
+ logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")
+
+ # generate score for each clip, and aggregate all clip scores for a video
+ n_clip_per_video = (
+ image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
+ )
+
+ assert not config.deep_fusion and n_clip_per_video == 1, f"Not implemented for config.deep_fusion={config.deep_fusion} n_clip_per_video={n_clip_per_video}"
+
+ logger.info(
+ f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
+ )
+ for i, sims in enumerate(iterator):
+ k = min(len(sims), config.evaluation.k_test)
+ topk_sim, topk_idx = sims.topk(k=k, dim=0)
+
+ clip_scores = []
+ for clip_idx in range(n_clip_per_video):
+ if config.deep_fusion:
+ encoder_output = [
+ feat[start + i, clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else feat[start + i, clip_idx]
+ for feat in media_feats
+ ]
+
+ else:
+ encoder_output = (
+ image_feats[start + i, clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else image_feats[start + i, clip_idx]
+ ) # (#frm*Li, d)
+
+ # new
+ bs = 32
+ # bs = config.batch_size_test.video
+ itm_embeds = []
+
+ if not config.deep_fusion: # Create fake list
+ encoder_output = [encoder_output]
+ encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
+ encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
+ for feat in encoder_output
+ ]
+
+ for j in range(0, len(topk_idx), bs):
+ cur_bs = min(bs, len(topk_idx) - j)
+ encoder_output = [feat[:cur_bs] for feat in encoder_output]
+ encoder_att = [att[:cur_bs] for att in encoder_att]
+
+ batch_encoder_output = encoder_output if config.deep_fusion else encoder_output[0]
+ batch_encoder_att = encoder_att if config.deep_fusion else encoder_att[0]
+
+ output = model.vtm_embed(
+ text_ids=text_ids[topk_idx[j:j+bs]],
+ text_atts=text_atts[topk_idx[j:j+bs]],
+ vision_embeds=batch_encoder_output,
+ vision_atts=batch_encoder_att,
+ )
+
+
+ itm_embeds.append(output)
+
+ itm_embeds = torch.cat(itm_embeds, dim=0)
+
+ score = match_head(itm_embeds)[:, 1]
+ clip_scores.append(score)
+
+ if len(clip_scores) == 1:
+ score = clip_scores[0]
+ else:
+ raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
+
+ i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)
+
+ # compute text2image #
+ num_text = len(data_loader.dataset.text)
+ t2i_scores_x = torch.full((num_text, len(data_loader.dataset.image)), -100.0).to(
+ device, torch.float, non_blocking=True
+ )
+
+ step = num_text // num_tasks + 1
+ start = rank * step
+ end = min(num_text, start + step)
+
+ iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
+ logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
+ # generate score for each clip, and aggregate all clip scores for a video
+ n_clip_per_video = (
+ image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
+ )
+ k = config.evaluation.k_test
+ logger.info(f"Top-{k} matching")
+ for i, sims in enumerate(iterator):
+ k = min(len(sims), config.evaluation.k_test)
+ topk_sim, topk_idx = sims.topk(k=k, dim=0)
+
+ clip_scores = []
+ for clip_idx in range(n_clip_per_video):
+
+ # new
+ bs = 32
+ # bs = config.batch_size_test.video
+ itm_embeds = []
+ for j in range(0, len(topk_idx), bs):
+
+ fake_image_feats = [image_feats] if not config.deep_fusion else image_feats
+
+ encoder_output = [
+ feat[topk_idx[j : j + bs].cpu(), clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else feat[topk_idx[j : j + bs], clip_idx]
+ for feat in fake_image_feats
+ ]
+ encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in encoder_output
+ ]
+ cur_bs = min(bs, len(topk_idx) - j)
+
+ batch_encoder_output = encoder_output if config.deep_fusion else encoder_output[0]
+ batch_encoder_att = encoder_att if config.deep_fusion else encoder_att[0]
+
+
+ output = model.vtm_embed(
+ text_ids=text_ids[start + i].repeat(cur_bs, 1),
+ text_atts=text_atts[start + i].repeat(cur_bs, 1),
+ vision_embeds=batch_encoder_output,
+ vision_atts=batch_encoder_att,
+ )
+
+ itm_embeds.append(output)
+
+
+ itm_embeds = torch.cat(itm_embeds, dim=0)
+ # end new
+
+ score = match_head(itm_embeds)[:, 1]
+ clip_scores.append(score)
+
+ if len(clip_scores) == 1:
+ score = clip_scores[0]
+ else:
+ raise NotImplementedError(f"len(clip_scores) == {len(clip_scores)}")
+
+ t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)
+
+ logger.info("Compute over!!!")
+ if config.distributed:
+ logger.info("Gather across GPUs!!!")
+ # gather across GPUs
+ dist.barrier()
+ logger.info("dist.barrier()!!!")
+ dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
+ logger.info("dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM) over!!!")
+ dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
+ logger.info("dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM) over!!!")
+
+ i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
+ i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
+
+ else:
+ i2t_scores_dsl = i2t_scores.float() * i2t_scores.float().softmax(dim=0)
+ i2t_scores_dsl_T = i2t_scores.float().T * i2t_scores.float().T.softmax(dim=0)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Evaluation time {total_time_str}")
+
+ i2t_scores_dsl = i2t_scores * i2t_scores.softmax(dim=0)
+ i2t_scores_dsl_T = i2t_scores.T * i2t_scores.T.softmax(dim=0)
+
+
+ if match_head is not None:
+ return (
+ i2t_scores_x.softmax(dim=1).cpu().float().numpy(),
+ t2i_scores_x.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
+ )
+ else:
+ return (
+ None,
+ None,
+ i2t_scores.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores.T.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl.softmax(dim=1).cpu().float().numpy(),
+ i2t_scores_dsl_T.softmax(dim=1).cpu().float().numpy()
+ )
+
+
+@torch.no_grad()
+def evaluation_clip(model, data_loader, tokenizer, device, config):
+ model.eval()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation:"
+ dtype = torch.half if config.use_half_precision else torch.float
+ media_type = data_loader.dataset.media_type
+ logger.info(f"Start evaluation_clip for media_type={media_type}")
+
+ logger.info("Computing dual encoder features...")
+
+ # this computes all features in each GPU
+ texts = data_loader.dataset.text
+ num_text = len(texts)
+ text_bs = 256
+ text_feats = []
+ for i in range(0, num_text, text_bs):
+ text = texts[i : min(num_text, i + text_bs)]
+ if "internvideo2" in config.model.model_cls.lower():
+ text_feat = model.encode_text(tokenizer(text).to(device))
+ else:
+ raise NotImplementedError
+ text_feat = model.encode_text(text)
+ text_feats.append(text_feat.cpu())
+ text_feats = torch.cat(text_feats, dim=0)
+ logger.info("Finished computing text features")
+
+ media_feats = []
+ metric_logger = MetricLogger(delimiter=" ")
+ header = f"extracting {media_type} feats!!!"
+ iterator = metric_logger.log_every(data_loader, 100, header)
+ for media, _ in iterator:
+ if media_type in ['image', 'video']:
+ media = media.to(device, non_blocking=True)
+ media_feat = model.encode_vision(media, test=True)
+ elif media_type == 'audio':
+ media = media.to(device, non_blocking=True)
+ media_feat = model.encode_audio(media, test=True)
+ elif media_type == 'audio_video':
+ raise NotImplementedError(f"Not implement media_type:{media_type}")
+ else:
+ raise NotImplementedError(f"Not implement media_type:{media_type}")
+
+ media_feats.append(media_feat.cpu())
+
+ media_feats = torch.cat(media_feats, dim=0)
+ logger.info("Finished feature extraction")
+ logger.info("Computing ITC scores [dot-product]")
+ # print(media_feats.dtype, text_feats.dtype)
+ # print(media_feats.device, text_feats.device)
+ i2t_scores, t2i_scores = get_sim(media_feats.float(), text_feats.float())
+ del media_feats, text_feats
+ logger.info("Computing ITC scores [dot-product], done!")
+
+ i2t_scores_dsl = i2t_scores * i2t_scores.softmax(dim=0)
+ i2t_scores_dsl_T = i2t_scores.T * i2t_scores.T.softmax(dim=0)
+
+ return (
+ i2t_scores.cpu().float().numpy(),
+ i2t_scores.T.cpu().float().numpy(),
+ i2t_scores_dsl.cpu().float().numpy(),
+ i2t_scores_dsl_T.cpu().float().numpy(),
+ )
+
+
+@torch.no_grad()
+def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
+ # Images->Text
+ ranks = np.zeros(scores_i2t.shape[0])
+ for index, score in enumerate(scores_i2t):
+ inds = np.argsort(score)[::-1]
+ # Score
+ gt_txt_ids = img2txt[index]
+ if isinstance(gt_txt_ids, int):
+ ranks[index] = np.where(inds == gt_txt_ids)[0][0]
+ else:
+ rank = 1e20
+ for i in gt_txt_ids:
+ tmp = np.where(inds == i)[0][0]
+ if tmp < rank:
+ rank = tmp
+ ranks[index] = rank
+
+ # Compute metrics
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ # Text->Images
+ ranks = np.zeros(scores_t2i.shape[0])
+
+ for index, score in enumerate(scores_t2i):
+ inds = np.argsort(score)[::-1]
+ gt_img_ids = txt2img[index]
+ if isinstance(gt_img_ids, int):
+ ranks[index] = np.where(inds == gt_img_ids)[0][0]
+ else: # list, used in the case each caption has multiple GT images
+ # Score
+ rank = 1e20
+ for i in gt_img_ids:
+ tmp = np.where(inds == i)[0][0]
+ if tmp < rank:
+ rank = tmp
+ ranks[index] = rank
+
+ # Compute metrics
+ ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ tr_mean = (tr1 + tr5 + tr10) / 3
+ ir_mean = (ir1 + ir5 + ir10) / 3
+ r_mean = (tr_mean + ir_mean) / 2
+
+ eval_result = {
+ "v2t_r1": tr1,
+ "v2t_r5": tr5,
+ "v2t_r10": tr10,
+ "v2t_r_mean": tr_mean,
+ "t2v_r1": ir1,
+ "t2v_r5": ir5,
+ "t2v_r10": ir10,
+ "t2v_r_mean": ir_mean,
+ "r_mean": r_mean,
+ }
+ eval_result = {k: round(v, 2) for k, v in eval_result.items()}
+ return eval_result
+
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tasks/shared_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/tasks/shared_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e86658bb9595d6ec56abc5d68023a35d190d266b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tasks/shared_utils.py
@@ -0,0 +1,235 @@
+import copy
+import logging
+import os
+import os.path as osp
+import io
+
+try:
+ import deepspeed
+except Exception as e:
+ print(e)
+ print("deepspeed is not installed!!!")
+
+from os.path import join
+
+
+try:
+ from petrel_client.client import Client
+except:
+ Client = None
+
+import torch
+from torch.utils.data import ConcatDataset, DataLoader
+
+from dataset.resample_concat_dataset import ResampleConcatDataset
+from models.backbones.internvideo2.pos_embed import interpolate_pos_embed_internvideo2_new
+from models.backbones.bert.tokenization_bert import BertTokenizer
+from utils.optimizer import create_optimizer
+from utils.scheduler import create_scheduler
+from utils.distributed import get_rank
+
+logger = logging.getLogger(__name__)
+
+
+def get_media_types(datasources):
+ """get the media types for for all the dataloaders.
+
+ Args:
+ datasources (List): List of dataloaders or datasets.
+
+ Returns: List. The media_types.
+
+ """
+ if isinstance(datasources[0], DataLoader):
+ datasets = [dataloader.dataset for dataloader in datasources]
+ else:
+ datasets = datasources
+ media_types = [
+ dataset.datasets[0].media_type
+ if isinstance(dataset, ConcatDataset) or isinstance(dataset, ResampleConcatDataset)
+ else dataset.media_type
+ for dataset in datasets
+ ]
+
+ return media_types
+
+
+def setup_model(
+ config, model_cls, add_decoder=False, pretrain=False, find_unused_parameters=False
+):
+ logger.info("Creating model")
+ config = copy.deepcopy(config)
+
+ if "bert" in config.model.text_encoder.name:
+ logger.info(f"Using BertTokenizer: {config.model.text_encoder.pretrained}!")
+ tokenizer = BertTokenizer.from_pretrained(config.model.text_encoder.pretrained, local_files_only=True)
+ model = model_cls(config=config, tokenizer=tokenizer, is_pretrain=pretrain)
+ else:
+ model = model_cls(config=config, is_pretrain=pretrain)
+ tokenizer = model.tokenizer
+ logger.info(f"Using model.tokenizer: {tokenizer}!")
+
+
+ if config.get('compile_model', False):
+ torch.set_float32_matmul_precision('high')
+ model = torch.compile(model)
+
+ model = model.to(torch.device(config.device))
+ model_without_ddp = model
+
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ # We move this to the back
+ optimizer_params = create_optimizer(config.optimizer, model, return_group=True)
+ scheduler = None
+ scaler = None
+ else:
+ if config.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[config.gpu],
+ find_unused_parameters=find_unused_parameters, # `False` for image-only task
+ )
+
+ optimizer = create_optimizer(config.optimizer, model)
+ scaler = torch.cuda.amp.GradScaler(enabled=config.use_half_precision) # This is never used actually if we fixed bf16
+ scheduler = create_scheduler(config.scheduler, optimizer)
+
+
+ start_epoch = 0
+ global_step = 0
+
+ # auto resume the latest checkpoint
+ if config.get("auto_resume", False):
+ logger.info("Auto resuming")
+ model_latest = join(config.output_dir, "ckpt_latest.pth")
+ model_best = join(config.output_dir, "ckpt_best.pth")
+ large_num = -1
+ for p in os.listdir(config.output_dir):
+ if 'ckpt' in p:
+ num = p.split('_')[1].split('.')[0]
+ if str.isnumeric(num):
+ if int(num) > large_num:
+ large_num = int(num)
+ if large_num != -1:
+ model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
+ if osp.isfile(model_latest):
+ config.pretrained_path = model_latest
+ config.resume = True
+ elif osp.isfile(model_best):
+ config.pretrained_path = model_best
+ config.resume = True
+ else:
+ logger.info(f"Not found checkpoint in {config.output_dir}")
+
+
+ if (config.pretrained_path.strip() and (osp.isfile(config.pretrained_path)) or "s3://" in config.pretrained_path):
+ if Client is not None:
+ client = Client()
+ with io.BytesIO(client.get(config.pretrained_path)) as buffer:
+ checkpoint = torch.load(buffer, map_location="cpu")
+ else:
+ checkpoint = torch.load(config.pretrained_path, map_location="cpu")
+ logger.info(f"Loading checkpoint from {config.pretrained_path}")
+ try:
+ if "model" in checkpoint.keys():
+ state_dict = checkpoint["model"]
+ else:
+ state_dict = checkpoint["module"] # This is a deepspeed stage 1 model
+ except:
+ state_dict = checkpoint
+
+ if config.get('origin_num_frames', None) is not None:
+ logger.info(f"interpolate_pos_embed_internvideo2 (origin_num_frames={config.origin_num_frames})!!!")
+ a = len(state_dict)
+ interpolate_pos_embed_internvideo2_new(state_dict, model_without_ddp.vision_encoder, orig_t_size=config.origin_num_frames)
+ assert a == len(state_dict), state_dict.keys()
+
+ if config.resume:
+ assert not (hasattr(config, "deepspeed") and config.deepspeed.enable), "Deepspeed should run here!!!"
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ scheduler.load_state_dict(checkpoint["scheduler"])
+ scaler.load_state_dict(checkpoint["scaler"])
+ if 'local_step' in checkpoint.keys():
+ start_epoch = checkpoint['epoch']
+ else:
+ start_epoch = checkpoint["epoch"] + 1
+ global_step = checkpoint["global_step"]
+
+ elif not pretrain: # downstream init from pretrained ckpt
+
+
+ if not config.evaluate or config.get("zero_shot", False): # finetuning from a pretrained weights.
+ if add_decoder:
+ logger.info("Init new decoder with encoder!!!")
+ for key in list(state_dict.keys()):
+ if "text_encoder.bert" in key:
+ encoder_key = key.replace("bert.", "")
+ state_dict[encoder_key] = state_dict[key]
+ if not add_decoder:
+ del state_dict[key]
+
+ # init text decoder as multimodal encoder (last 6 layers of model.text_encoder)
+ # only for generation tasks like VQA
+ if add_decoder and "text_encoder.bert" in key:
+ if "layer" in key:
+ encoder_keys = key.split(".")
+ layer_num = int(encoder_keys[4])
+ if layer_num < config.model.text_encoder.fusion_layer:
+ del state_dict[key]
+ continue
+ else:
+ decoder_layer_num = layer_num - config.model.text_encoder.fusion_layer
+ encoder_keys[4] = str(decoder_layer_num)
+ encoder_key = ".".join(encoder_keys)
+ else:
+ encoder_key = key
+ decoder_key = encoder_key.replace("text_encoder", "text_decoder")
+ state_dict[decoder_key] = state_dict[key]
+ del state_dict[key]
+
+
+ msg = model_without_ddp.load_state_dict(state_dict, strict=False)
+ logger.info(msg)
+ logger.info(f"Loaded checkpoint from {config.pretrained_path}")
+ else:
+ if not config.resume:
+ assert not config.evaluate, "No available pretrained checkpoint provided!!!"
+ assert config.pretrained_path == "", config.pretrained_path
+ logger.warning("No available pretrained checkpoint provided, training from scratch.")
+
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ logger.info(f'Use deepspeed to initialize model (resume={config.resume}) !!!')
+ model = model_without_ddp
+
+ model, optimizer, _, _ = deepspeed.initialize(
+ args=config, model=model, model_parameters=optimizer_params, dist_init_required=not config.distributed,
+ lr_scheduler=lambda opt: create_scheduler(config.scheduler, opt)
+ )
+
+
+ if config.resume:
+ logger.info(f'Resume deepspeed ckpt from {config.output_dir}, tag={config.pretrained_path}, load_module_strict={config.get("load_module_strict", True)}, load_lr_scheduler_states={config.get("load_lr_scheduler_states", True)}!!!')
+ _, client_states = model.load_checkpoint(config.output_dir, tag=config.pretrained_path, load_module_strict=config.get("load_module_strict", True), load_lr_scheduler_states=config.get("load_lr_scheduler_states", True))
+ logger.info(client_states)
+ if 'local_step' in client_states.keys():
+ start_epoch = client_states['epoch']
+ else:
+ start_epoch = client_states['epoch'] + 1
+ global_step = client_states['global_step']
+
+
+ logger.info(f"Cuda memory after create model: {torch.cuda.memory_allocated() // 1024**2}M, Max mem: {torch.cuda.max_memory_allocated() // 1024**2}M start_epoch={start_epoch}, global_step={global_step}")
+ print(f"\033[31m Cuda memory after create model: {torch.cuda.memory_allocated() // 1024**2}M, Max mem: {torch.cuda.max_memory_allocated() // 1024**2}M start_epoch={start_epoch}, global_step={global_step}\033[0m")
+
+ return (
+ model,
+ model_without_ddp,
+ optimizer,
+ scheduler,
+ scaler,
+ tokenizer,
+ start_epoch,
+ global_step,
+ )
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/pretrain.py b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..16d42de2ab2deffeec4bcb2a292821d013f54809
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/pretrain.py
@@ -0,0 +1,359 @@
+import datetime
+import logging
+import time
+from os.path import join
+
+import pandas as pd
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import wandb
+from torch.utils.data import ConcatDataset
+
+from dataset.serialize import local_broadcast_process_authkey
+from dataset import MetaLoader_rs, create_dataset, create_loader, create_sampler, create_stateful_sampler
+from models import *
+from tasks_clip.retrieval_utils import evaluation_wrapper
+from tasks_clip.shared_utils import get_media_types, setup_model
+from utils.basic_utils import MetricLogger, SmoothedValue, setup_seed
+from utils.config_utils import setup_main
+from utils.distributed import get_rank, is_main_process
+from utils.logger import log_dict_to_wandb, setup_wandb
+
+logger = logging.getLogger(__name__)
+
+
+def train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ data_type,
+ skip_num=0
+):
+ model.train()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}"))
+ metric_logger.add_meter("temperature", SmoothedValue(window=100, fmt="{value:.4f}"))
+ loss_names = ["loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0]
+
+ media_types = get_media_types(train_loaders)
+
+ for name in loss_names:
+ for m in media_types:
+ metric_logger.add_meter(
+ f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}")
+ )
+
+ header = f"Train Epoch: [{epoch}]"
+ log_freq = config.log_freq
+
+ if config.distributed:
+ for d in train_loaders:
+ d.sampler.set_epoch(epoch)
+ train_loader = MetaLoader_rs(name2loader=dict(list(zip(media_types, train_loaders))), skip_num=skip_num)
+
+ model_without_ddp = model.module if config.distributed else model
+ iterator = metric_logger.log_every(train_loader, log_freq, header)
+ for i, (media_type, (image, text, idx)) in enumerate(iterator):
+ image = image.to(device, non_blocking=True)
+ idx = idx.to(device, non_blocking=True)
+ text_input = tokenizer(text).to(device)
+
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
+ loss_dict = model(image, text_input, idx=idx)
+ loss = sum(loss_dict.values())
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ model.backward(loss)
+ model.step()
+ else:
+ if not config.use_half_precision or config.get('use_bf16', True):
+ optimizer.zero_grad()
+ loss.backward()
+ if config.optimizer.max_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
+ optimizer.step()
+ scheduler.step()
+ else:
+ optimizer.zero_grad()
+ scaler.scale(loss).backward()
+ if config.optimizer.max_grad_norm > 0:
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
+ scaler.step(optimizer)
+ scaler.update()
+ scheduler.step()
+
+ # logging
+ for name in loss_names:
+ value = loss_dict[name]
+ value = value if isinstance(value, float) else value.item()
+ metric_logger.update(**{f"{media_type}-{name}": value})
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(temperature=model_without_ddp.temp.item())
+
+ if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
+ logs = metric_logger.get_global_avg_dict()
+ log_dict_to_wandb(logs, step=global_step, prefix="train/")
+
+ global_step += 1
+
+ if config.debug and global_step % 20 == 0:
+ logger.info("debug mode, break training loop")
+ break
+
+ if config.debug and global_step % (2 * log_freq + 3) == 0:
+ logger.info("debug mode, break training loop")
+ break
+
+ if config.get('save_iter', 0) and global_step % config.save_iter == 0:
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ tag = f"ckpt_iter{global_step:02d}.pth"
+ model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, exclude_frozen_parameters=True)
+ elif is_main_process():
+ state_dict = model_without_ddp.state_dict()
+ param_grad_dict = {
+ k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
+ }
+ for k in list(state_dict.keys()):
+ if k in param_grad_dict.keys() and not param_grad_dict[k]:
+ # delete parameters that do not require gradient
+ logger.info(f"Not saving {k}")
+ del state_dict[k]
+ save_obj = {
+ "model": model_without_ddp.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "scaler": scaler.state_dict(),
+ "config": config,
+ "epoch": epoch,
+ "global_step": global_step,
+ }
+ torch.save(save_obj, join(config.output_dir, f"ckpt_iter{global_step:02d}.pth"))
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ logger.info(f"Averaged stats: {metric_logger.global_avg()}")
+ return global_step
+
+
+def setup_dataloaders(config, mode="pt"):
+ # train datasets, create a list of data loaders
+ logger.info(f"Creating dataset for {mode}")
+ train_datasets = create_dataset(f"{mode}_train", config)
+ media_types = get_media_types(train_datasets)
+
+ if config.distributed:
+ batch_size = [config.inputs.batch_size[k] for k in media_types] # batch_size for each GPU
+ samplers = create_stateful_sampler(train_datasets, batch_size)
+ else:
+ raise NotImplementedError
+
+ train_loaders = create_loader(
+ train_datasets,
+ samplers,
+ batch_size=[config.inputs.batch_size[k] for k in media_types],
+ num_workers=[config.num_workers] * len(media_types),
+ is_trains=[True] * len(media_types),
+ collate_fns=[None] * len(media_types),
+ )
+
+ # test datasets, a mapping from dataset name to data loader
+ test_datasets, test_dataset_names = create_dataset(f"{mode}_eval", config)
+ test_loaders = create_loader(
+ test_datasets,
+ [None] * len(test_datasets),
+ batch_size=[config.inputs.batch_size_test[d.media_type] for d in test_datasets],
+ num_workers=[config.num_workers] * len(test_datasets),
+ is_trains=[False] * len(test_datasets),
+ collate_fns=[None] * len(test_datasets),
+ )
+ test_name2loaders = {k: v for k, v in zip(test_dataset_names, test_loaders)}
+ return train_loaders, test_name2loaders, media_types
+
+
+def main(config):
+ if is_main_process() and config.wandb.enable:
+ run = setup_wandb(config)
+
+ is_pretrain = config.mode == "pt"
+
+ logger.info(f"train_file: {config.train_file}")
+
+ setup_seed(config.seed + get_rank())
+ device = torch.device(config.device)
+
+ train_loaders, test_name2loaders, train_media_types = setup_dataloaders(
+ config, mode=config.mode
+ )
+ num_steps_per_epoch = sum(len(d) for d in train_loaders)
+
+ config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
+ config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs
+ # set cudnn.benchmark=True only when input size is fixed
+ # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
+ cudnn.benchmark = len(train_media_types) == 1
+
+ model_cls = eval(config.model.get('model_cls', 'InternVideo2_CLIP'))
+ (
+ model,
+ model_without_ddp,
+ optimizer,
+ scheduler,
+ scaler,
+ tokenizer,
+ start_epoch,
+ global_step,
+ ) = setup_model(
+ config,
+ model_cls=model_cls,
+ pretrain=is_pretrain,
+ find_unused_parameters=True,
+ num_steps_per_epoch=num_steps_per_epoch,
+ )
+ if is_main_process() and config.wandb.enable:
+ wandb.watch(model)
+
+ best = 0
+ best_epoch = 0
+
+ if config.get('use_bf16', True):
+ data_type = torch.bfloat16
+ else:
+ data_type = torch.float16
+
+ logger.info("Start training")
+ logger.info(f"Epoch: {start_epoch}")
+ start_time = time.time()
+ start_step = start_epoch * num_steps_per_epoch
+ for epoch in range(start_epoch, config.scheduler.epochs):
+ if not config.evaluate:
+ global_step = train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ data_type,
+ skip_num = global_step - start_step
+ )
+
+ # save checkpoint befor evaluation
+ # only save those with gradient
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ if config.get("save_latest", False):
+ tag = "ckpt_latest.pth"
+ else:
+ tag = f"ckpt_{epoch:02d}.pth"
+ model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, exclude_frozen_parameters=True)
+
+ elif is_main_process():
+ state_dict = model_without_ddp.state_dict()
+ param_grad_dict = {
+ k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
+ }
+ for k in list(state_dict.keys()):
+ if k in param_grad_dict.keys() and not param_grad_dict[k]:
+ # delete parameters that do not require gradient
+ logger.info(f"Not saving {k}")
+ del state_dict[k]
+
+ save_obj = {
+ "model": model_without_ddp.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "scaler": scaler.state_dict(),
+ "config": config,
+ "epoch": epoch,
+ "global_step": global_step,
+ }
+ if config.get("save_latest", False):
+ torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
+ else:
+ torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))
+
+ # evaluation
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
+ eval_res = {}
+ for test_name, test_loader in test_name2loaders.items():
+ if test_name not in config.test_types:
+ logger.info(
+ f"Skip eval {test_name} split. All test_types {config.test_types}"
+ )
+ continue
+ res = evaluation_wrapper(
+ model_without_ddp, test_loader, tokenizer, device, config, data_type=data_type, prefix=test_name
+ )
+ eval_res.update(res)
+
+ # save the best checkpoint
+ if is_main_process():
+ # log to wandb
+ if config.wandb.enable:
+ for p, v in eval_res.items():
+ log_dict_to_wandb(v, step=global_step, prefix=p)
+
+ if config.stop_key is not None and config.stop_key in eval_res:
+ cur_r_mean = eval_res[config.stop_key]["r_mean"]
+ else: # None
+ cur_r_mean = best + 1 # save the last as the best
+
+ eval_res = pd.DataFrame(eval_res)
+ logger.info(f"Epoch {epoch}")
+ logger.info(f"\n{eval_res.transpose().to_string(max_cols=30)}")
+
+ eval_res.to_json(join(config.output_dir, "eval_res_latest.json"))
+
+ if not config.evaluate and cur_r_mean > best:
+ if not hasattr(config, "deepspeed") or not config.deepspeed.enable:
+ torch.save(save_obj, join(config.output_dir, "ckpt_best.pth"))
+ eval_file = "eval_res_best.json"
+ eval_res.to_json(join(config.output_dir, eval_file))
+ best = cur_r_mean
+ best_epoch = epoch
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ r_mean_best = torch.tensor([0.0, 0.0]).to(device)
+ if is_main_process():
+ r_mean_best[0] = cur_r_mean
+ r_mean_best[1] = best
+ dist.broadcast(r_mean_best, 0)
+ cur_r_mean, best = r_mean_best[0].item(), r_mean_best[1].item()
+
+ if not config.evaluate and cur_r_mean > best:
+ model.save_checkpoint(config.output_dir, tag="ckpt_best.pth", save_latest=False, exclude_frozen_parameters=True)
+
+ if config.evaluate:
+ break
+
+ start_step = global_step
+
+ dist.barrier()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Training time {total_time_str}")
+ logger.info(f"best epoch {best_epoch} [config.stop_key {config.stop_key}]")
+ logger.info(f"Checkpoints and Logs saved at {config.output_dir}")
+
+ if is_main_process() and config.wandb.enable:
+ run.finish()
+
+
+if __name__ == "__main__":
+ cfg = setup_main()
+ local_broadcast_process_authkey()
+ main(cfg)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval.py b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a0a7ad6f0ca2e2ef7c755888b565dae796474ff
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval.py
@@ -0,0 +1,309 @@
+import copy
+import datetime
+import logging
+import os
+import time
+from os.path import join
+
+import pandas as pd
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import wandb
+
+from dataset import MetaLoader
+from models import *
+from tasks_clip.pretrain import setup_dataloaders
+from tasks_clip.retrieval_utils import evaluation_wrapper
+from tasks_clip.shared_utils import setup_model
+from utils.basic_utils import MetricLogger, SmoothedValue, setup_seed
+from utils.config import Config
+from utils.config_utils import setup_main
+from utils.distributed import get_rank, is_main_process
+from utils.logger import log_dict_to_wandb, setup_wandb
+
+logger = logging.getLogger(__name__)
+
+
+def train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ data_type
+):
+ model.train()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ metric_logger.add_meter("lr", SmoothedValue(window=1, fmt="{value:.6f}"))
+ metric_logger.add_meter("temperature", SmoothedValue(window=1, fmt="{value:.4f}"))
+ loss_names = ["loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0]
+
+ media_types = [loader.dataset.media_type for loader in train_loaders]
+ for name in loss_names:
+ for m in media_types:
+ metric_logger.add_meter(f"{m}-{name}", SmoothedValue(window=1, fmt="{value:.4f}"))
+
+ header = f"Train Epoch: [{epoch}]"
+ log_freq = config.log_freq
+
+ if config.distributed:
+ for d in train_loaders:
+ d.sampler.set_epoch(epoch)
+ train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders))))
+
+ model_without_ddp = model.module if config.distributed else model
+ iterator = metric_logger.log_every(train_loader, log_freq, header)
+ for i, (media_type, (image, text, idx)) in enumerate(iterator):
+ image = image.to(device, non_blocking=True)
+ idx = idx.to(device, non_blocking=True)
+ text_input = tokenizer(text).to(device)
+
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
+ loss_dict = model(image, text_input, idx=idx)
+ loss = sum(loss_dict.values())
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ model.backward(loss)
+ model.step()
+ else:
+ if not config.use_half_precision or config.get('use_bf16', True):
+ optimizer.zero_grad()
+ loss.backward()
+ if config.optimizer.max_grad_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
+ optimizer.step()
+ scheduler.step()
+ else:
+ optimizer.zero_grad()
+ scaler.scale(loss).backward()
+ if config.optimizer.max_grad_norm > 0:
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
+ scaler.step(optimizer)
+ scaler.update()
+ scheduler.step()
+
+ # logging
+ for name in loss_names:
+ value = loss_dict[name]
+ value = value if isinstance(value, float) else value.item()
+ metric_logger.update(**{f"{media_type}-{name}": value})
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(temperature=model_without_ddp.temp.item())
+
+ if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
+ logs = metric_logger.get_global_avg_dict()
+ log_dict_to_wandb(logs, step=global_step, prefix="train/")
+
+ global_step += 1
+
+ if config.debug and (i + 1) % 5 == 0:
+ break
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ logger.info(f"Averaged train stats: {metric_logger.global_avg()}")
+ return global_step
+
+
+def main(config):
+ if is_main_process() and config.wandb.enable:
+ run = setup_wandb(config)
+
+ logger.info(f"config: \n{config}")
+ logger.info(f"train_file: {config.train_file}")
+
+ setup_seed(config.seed + get_rank())
+ device = torch.device(config.device)
+ cudnn.benchmark = True
+
+ train_loaders, test_name2loaders, train_media_types = setup_dataloaders(config, mode="ret")
+ num_steps_per_epoch = sum(len(d) for d in train_loaders)
+ config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
+ config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs
+
+ model_cls = eval(config.model.get('model_cls', 'InternVideo2_CLIP'))
+ (
+ model,
+ model_without_ddp,
+ optimizer,
+ scheduler,
+ scaler,
+ tokenizer,
+ start_epoch,
+ global_step,
+ ) = setup_model(
+ config,
+ model_cls=model_cls,
+ pretrain=False,
+ # find_unused_parameters=True,
+ find_unused_parameters=False,
+ )
+ if is_main_process() and config.wandb.enable:
+ wandb.watch(model)
+
+ best = 0
+ best_epoch = 0
+
+ if config.get('use_bf16', True):
+ data_type = torch.bfloat16
+ else:
+ data_type = torch.float16
+
+ logger.info("Start " + "evaluation" if config.evaluate else "training")
+ start_time = time.time()
+ for epoch in range(start_epoch, config.scheduler.epochs):
+ if not config.evaluate:
+ global_step = train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ )
+
+ # save checkpoint befor evaluation
+ # only save those with gradient
+ if not config.evaluate:
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ if config.get("save_latest", False):
+ tag = "ckpt_latest.pth"
+ else:
+ tag = f"ckpt_{epoch:02d}.pth"
+ model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, exclude_frozen_parameters=True)
+
+ elif is_main_process():
+ state_dict = model_without_ddp.state_dict()
+ param_grad_dict = {
+ k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
+ }
+ for k in list(state_dict.keys()):
+ if k in param_grad_dict.keys() and not param_grad_dict[k]:
+ # delete parameters that do not require gradient
+ logger.info(f"Not saving {k}")
+ del state_dict[k]
+
+ save_obj = {
+ "model": model_without_ddp.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "scaler": scaler.state_dict(),
+ "config": config,
+ "epoch": epoch,
+ "global_step": global_step,
+ }
+ if config.get("save_latest", False):
+ torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
+ else:
+ torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))
+
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
+ eval_res = {}
+ for test_name, test_loader in test_name2loaders.items():
+ if test_name not in config.test_types:
+ logger.info(
+ f"Skip eval {test_name} split. All test_types {config.test_types}"
+ )
+ continue
+ res = evaluation_wrapper(
+ model_without_ddp, test_loader, tokenizer, device, config, data_type=data_type, prefix=test_name
+ )
+ eval_res.update(res)
+
+ # save the best checkpoint
+ if is_main_process():
+ # log to wandb
+ if config.wandb.enable:
+ for p, v in eval_res.items():
+ log_dict_to_wandb(v, step=global_step, prefix=p)
+
+ if config.stop_key is not None and config.stop_key in eval_res:
+ cur_r_mean = eval_res[config.stop_key]["r_mean"]
+ else: # None
+ cur_r_mean = best + 1 # save the last as the best
+
+ eval_res = pd.DataFrame(eval_res)
+ logger.info(f"Epoch {epoch}")
+ logger.info(f"\n{eval_res.transpose().to_string(max_cols=30)}")
+
+ eval_res.to_json(join(config.output_dir, "eval_res_latest.json"))
+
+ if not config.evaluate and cur_r_mean > best:
+ if not hasattr(config, "deepspeed") or not config.deepspeed.enable:
+ torch.save(save_obj, join(config.output_dir, "ckpt_best.pth"))
+ eval_file = "eval_res_best.json"
+ eval_res.to_json(join(config.output_dir, eval_file))
+ best = cur_r_mean
+ best_epoch = epoch
+ if config.evaluate:
+ eval_file = "eval_res.json"
+ eval_res.to_json(join(config.output_dir, eval_file))
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ r_mean_best = torch.tensor([0.0, 0.0]).to(device)
+ if is_main_process():
+ r_mean_best[0] = cur_r_mean
+ r_mean_best[1] = best
+ dist.broadcast(r_mean_best, 0)
+ cur_r_mean, best = r_mean_best[0].item(), r_mean_best[1].item()
+
+ if not config.evaluate and cur_r_mean > best:
+ model.save_checkpoint(config.output_dir, tag="ckpt_best.pth", save_latest=False, exclude_frozen_parameters=True)
+
+ if config.evaluate or config.debug:
+ break
+
+ dist.barrier()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Training time {total_time_str}")
+ logger.info(f"best epoch {best_epoch} [config.stop_key {config.stop_key}]")
+ logger.info(f"Checkpoints and Logs saved at {config.output_dir}")
+
+ if is_main_process() and config.wandb.enable:
+ run.finish()
+
+
+def eval_after_training(train_config):
+ # general config for all
+ train_config.wandb.enable = False
+ train_config.evaluate = True
+ train_config.pretrained_path = join(train_config.output_dir, "ckpt_best.pth")
+ train_config.num_frames_test = train_config.num_frames
+ train_config.inputs.video_input.num_frames_test = train_config.num_frames
+
+ if train_config.get('num_frames_test_final', False):
+ train_config.num_frames_test = train_config.num_frames_test_final
+ train_config.batch_size = train_config.batch_size_final
+ train_config.inputs.video_input.num_frames_test = train_config.num_frames_test_final
+ train_config.model.vision_encoder.num_frames = train_config.num_frames_test_final
+
+ eval_config = copy.deepcopy(train_config)
+ eval_config.test_types = list(eval_config.test_file.keys())
+ eval_config.output_dir = join(eval_config.output_dir, f"eval_after_training")
+ eval_config.result_dir = eval_config.output_dir
+ if is_main_process():
+ os.makedirs(eval_config.output_dir, exist_ok=True)
+ Config.dump(eval_config, os.path.join(eval_config.output_dir, "config.json"))
+ logger.info(f"===========> START eval_after_training [{eval_config.test_types}]")
+ main(eval_config)
+
+
+if __name__ == "__main__":
+ cfg = setup_main()
+ main(cfg)
+ if not cfg.evaluate:
+ eval_after_training(cfg)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval_mc.py b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval_mc.py
new file mode 100644
index 0000000000000000000000000000000000000000..34166e057d0c4c294290f7dce1fb27722e8c0a9b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval_mc.py
@@ -0,0 +1,401 @@
+import copy
+import datetime
+import logging
+import os
+import time
+import json
+from os.path import join
+
+import pandas as pd
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import wandb
+import torch.nn.functional as F
+from einops import rearrange
+
+from dataset import MetaLoader, create_dataset, create_loader, create_sampler
+from models.utils import tile
+from models import *
+from tasks_clip.shared_utils import get_media_types, setup_model
+from utils.basic_utils import MetricLogger, SmoothedValue, setup_seed, flat_list_of_lists, save_json
+from utils.config import Config
+from utils.config_utils import setup_main
+from utils.distributed import get_rank, get_world_size, is_main_process
+from utils.logger import log_dict_to_wandb, setup_wandb
+
+logger = logging.getLogger(__name__)
+
+
+def get_sim_for_each_question(model, pooled_image_feat, pooled_text_feat, model_cls):
+ """TODO: Docstring for get_sim_for_each_question.
+
+ Args:
+ model (TODO): TODO
+ pooled_image_feat (torch.Tensor): Shape: [b, c]
+ pooled_text_feat (torch.Tensor): Shape: [b, n, c]. n is the number of answer candidates.
+
+ Returns: TODO
+
+ """
+ image_feat = F.normalize(pooled_image_feat, dim=-1).to(torch.float32)
+ text_feat = F.normalize(pooled_text_feat, dim=-1).to(torch.float32)
+ sim = torch.matmul(image_feat.unsqueeze(1), rearrange(text_feat, "b n c -> b c n")) # [b, 1, n]
+ if "InternVL" in model_cls:
+ sim = sim[:, 0] * model.logit_scale # [b, n]
+ else: # for UMT
+ sim = sim[:, 0] / model.temp # [b, n]
+ sim = F.softmax(sim, dim=1) # [b, n]
+ return sim
+
+
+def main_with_ensemble(config, test_loader, model_without_ddp, tokenizer, data_type):
+ logger.info(f"test_file: {config.test_file}")
+
+ setup_seed(config.seed + get_rank())
+ device = torch.device(config.device)
+ cudnn.benchmark = True
+
+ config.scheduler.num_training_steps = 10
+ config.scheduler.num_warmup_steps = 10
+ model = model_without_ddp
+ model.eval()
+
+ logger.info("Start " + "evaluation" if config.evaluate else "training")
+ metric_logger = MetricLogger(delimiter=" ")
+ iterator = metric_logger.log_every(test_loader, 5, "Evaluation: ")
+ num_options_per_q = 174
+ all_gt_answers = []
+ all_pred_answers = []
+ predictions = []
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type), torch.no_grad():
+ for image, text, ans, ann in iterator:
+ image = image.to(device, non_blocking=True) # bsz
+ all_gt_answers.append(ans)
+ text = flat_list_of_lists(list(zip(*text))) # List(str), len=bsz*174
+ text_input = tokenizer(text).to(device) # bsz*174
+
+ # encode text
+ pooled_text_feat = model.encode_text(text_input) # [b*174, c]
+ # encode image
+ pooled_image_feat = model.encode_vision(image, test=True) # [b, c]
+
+ # contrastive score
+ pooled_text_feat = rearrange(pooled_text_feat, "(b n) c -> b n c", n=num_options_per_q)
+ score = get_sim_for_each_question(model, pooled_image_feat, pooled_text_feat, model_cls=config.model.model_cls).cpu() # [b, n]
+
+ pred_ans = score.max(1)[1].cpu()
+ all_pred_answers.append(pred_ans)
+
+ # assemble predictions
+ for q_idx in range(len(score)): # bsz
+ _pred = dict(
+ video=ann["video"][q_idx],
+ answer=ann["answer"][q_idx].item(),
+ pred_ans=pred_ans[q_idx].item(),
+ pred_scores=score[q_idx].numpy(), # (174, )
+ )
+ predictions.append(_pred)
+
+ all_gt_answers = torch.cat(all_gt_answers, 0)
+ all_pred_answers = torch.cat(all_pred_answers, 0)
+ acc = all_gt_answers == all_pred_answers
+ acc = float(torch.sum(acc) / len(acc))
+ eval_res = {"acc": round(100 * acc, 2)}
+ logger.info(f"\n{eval_res}")
+ save_json(eval_res, join(config.output_dir, "eval_res.json"))
+ torch.save(predictions, join(config.output_dir, "prediction_scores.pth"))
+ return eval_res
+
+
+def train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ data_type
+):
+ model.train()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ metric_logger.add_meter("lr", SmoothedValue(window=1, fmt="{value:.6f}"))
+ metric_logger.add_meter("temperature", SmoothedValue(window=1, fmt="{value:.4f}"))
+ loss_names = ["loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0]
+
+ media_types = [loader.dataset.media_type for loader in train_loaders]
+ for name in loss_names:
+ for m in media_types:
+ metric_logger.add_meter(f"{m}-{name}", SmoothedValue(window=1, fmt="{value:.4f}"))
+
+ header = f"Train Epoch: [{epoch}]"
+ log_freq = config.log_freq
+
+ if config.distributed:
+ for d in train_loaders:
+ d.sampler.set_epoch(epoch)
+ train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders))))
+
+ model_without_ddp = model.module if config.distributed else model
+ iterator = metric_logger.log_every(train_loader, log_freq, header)
+ for i, (media_type, (image, text, idx)) in enumerate(iterator):
+ image = image.to(device, non_blocking=True)
+ idx = idx.to(device, non_blocking=True)
+ text_input = tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=config.max_txt_l,
+ return_tensors="pt",
+ ).to(device)
+
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
+ loss_dict = model(image, text_input, idx=idx)
+ loss = sum(loss_dict.values())
+
+ optimizer.zero_grad()
+ scaler.scale(loss).backward()
+ if config.optimizer.max_grad_norm > 0:
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
+ scaler.step(optimizer)
+ scaler.update()
+ scheduler.step()
+
+ # logging
+ for name in loss_names:
+ value = loss_dict[name]
+ value = value if isinstance(value, float) else value.item()
+ metric_logger.update(**{f"{media_type}-{name}": value})
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(temperature=model_without_ddp.temp.item())
+
+ if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
+ logs = metric_logger.get_global_avg_dict()
+ log_dict_to_wandb(logs, step=global_step, prefix="train/")
+
+ global_step += 1
+
+ if config.debug and (i + 1) % 5 == 0:
+ break
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ logger.info(f"Averaged train stats: {metric_logger.global_avg()}")
+ return global_step
+
+
+def main(config):
+ if is_main_process() and config.wandb.enable:
+ run = setup_wandb(config)
+
+ logger.info(f"config: \n{config}")
+ logger.info(f"train_file: {config.train_file}")
+
+ setup_seed(config.seed + get_rank())
+ device = torch.device(config.device)
+ cudnn.benchmark = True
+
+ # create train loader
+ train_datasets = create_dataset("ret_train", config)
+ media_types = get_media_types(train_datasets)
+ if config.distributed:
+ num_tasks = get_world_size()
+ global_rank = get_rank()
+ samplers = create_sampler(
+ train_datasets, [True] * len(media_types), num_tasks, global_rank
+ )
+ else:
+ samplers = [None] * len(media_types)
+ train_loaders = create_loader(
+ train_datasets,
+ samplers,
+ batch_size=[config.inputs.batch_size[k] for k in media_types],
+ num_workers=[config.num_workers] * len(media_types),
+ is_trains=[True] * len(media_types),
+ collate_fns=[None] * len(media_types),
+ )
+
+ num_steps_per_epoch = sum(len(d) for d in train_loaders)
+ config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
+ config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs
+
+ model_cls = eval(config.model.get('model_cls', 'InternVideo2_CLIP'))
+ (
+ model,
+ model_without_ddp,
+ optimizer,
+ scheduler,
+ scaler,
+ tokenizer,
+ start_epoch,
+ global_step,
+ ) = setup_model(
+ config,
+ model_cls=model_cls,
+ pretrain=False,
+ # find_unused_parameters=True,
+ find_unused_parameters=False,
+ )
+ if is_main_process() and config.wandb.enable:
+ wandb.watch(model)
+
+ # create test dataloader
+ test_dataset = create_dataset("mc_new_test", config)
+ test_loader = create_loader(
+ [test_dataset],
+ [None],
+ batch_size=[config.inputs.batch_size_test.video],
+ num_workers=[config.num_workers],
+ is_trains=[False],
+ collate_fns=[None],
+ )[0]
+
+ best = 0
+ best_epoch = 0
+
+ if config.get('use_bf16', True):
+ data_type = torch.bfloat16
+ else:
+ data_type = torch.float16
+
+ logger.info("Start " + "evaluation" if config.evaluate else "training")
+ start_time = time.time()
+ for epoch in range(start_epoch, config.scheduler.epochs):
+ if not config.evaluate:
+ global_step = train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ data_type=data_type
+ )
+
+ # save checkpoint befor evaluation
+ # only save those with gradient
+ if not config.evaluate:
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ if config.get("save_latest", False):
+ tag = "ckpt_latest.pth"
+ else:
+ tag = f"ckpt_{epoch:02d}.pth"
+ model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, exclude_frozen_parameters=True)
+
+ elif is_main_process():
+ state_dict = model_without_ddp.state_dict()
+ param_grad_dict = {
+ k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
+ }
+ for k in list(state_dict.keys()):
+ if k in param_grad_dict.keys() and not param_grad_dict[k]:
+ # delete parameters that do not require gradient
+ logger.info(f"Not saving {k}")
+ del state_dict[k]
+
+ save_obj = {
+ "model": model_without_ddp.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "scaler": scaler.state_dict(),
+ "config": config,
+ "epoch": epoch,
+ "global_step": global_step,
+ }
+ if config.get("save_latest", False):
+ torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
+ else:
+ torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))
+
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
+ res = main_with_ensemble(config, test_loader, model_without_ddp, tokenizer, data_type=data_type)
+ eval_res = res
+
+ if is_main_process():
+ if config.wandb.enable:
+ log_dict_to_wandb(eval_res, step=global_step, prefix=config.test_types)
+
+ acc = eval_res["acc"]
+ logger.info(f"Epoch {epoch}")
+ logger.info(f"\n{eval_res}")
+
+ save_json(eval_res, join(config.output_dir, "eval_res_latest.json"))
+
+ if not config.evaluate and acc > best:
+ if not hasattr(config, "deepspeed") or not config.deepspeed.enable:
+ torch.save(save_obj, join(config.output_dir, "ckpt_best.pth"))
+ eval_file = "eval_res_best.json"
+ save_json(eval_res, join(config.output_dir, eval_file))
+ best = acc
+ best_epoch = epoch
+ if config.evaluate:
+ eval_file = "eval_res.json"
+ save_json(eval_res, join(config.output_dir, eval_file))
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ acc_best = torch.tensor([0.0, 0.0]).to(device)
+ if is_main_process():
+ acc_best[0] = acc
+ acc_best[1] = best
+ dist.broadcast(acc_best, 0)
+ acc, best = acc_best[0].item(), acc_best[1].item()
+
+ if not config.evaluate and acc > best:
+ model.save_checkpoint(config.output_dir, tag="ckpt_best.pth", save_latest=False, exclude_frozen_parameters=True)
+
+ if config.evaluate or config.debug:
+ break
+
+ dist.barrier()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Training time {total_time_str}")
+ logger.info(f"best epoch {best_epoch}")
+ logger.info(f"best {best}")
+ logger.info(f"Checkpoints and Logs saved at {config.output_dir}")
+
+ if is_main_process() and config.wandb.enable:
+ run.finish()
+
+
+def eval_after_training(train_config):
+ # general config for all
+ train_config.wandb.enable = False
+ train_config.evaluate = True
+ train_config.pretrained_path = join(train_config.output_dir, "ckpt_best.pth")
+ train_config.num_frames_test = train_config.num_frames
+ train_config.inputs.video_input.num_frames_test = train_config.num_frames
+
+ if train_config.get('num_frames_test_final', False):
+ train_config.num_frames_test = train_config.num_frames_test_final
+ train_config.batch_size = train_config.batch_size_final
+ train_config.inputs.video_input.num_frames_test = train_config.num_frames_test_final
+ train_config.model.vision_encoder.num_frames = train_config.num_frames_test_final
+
+ eval_config = copy.deepcopy(train_config)
+ eval_config.test_types = list(eval_config.test_file.keys())
+ eval_config.output_dir = join(eval_config.output_dir, f"eval_after_training")
+ eval_config.result_dir = eval_config.output_dir
+ if is_main_process():
+ os.makedirs(eval_config.output_dir, exist_ok=True)
+ Config.dump(eval_config, os.path.join(eval_config.output_dir, "config.json"))
+ logger.info(f"===========> START eval_after_training [{eval_config.test_types}]")
+ main(eval_config)
+
+
+if __name__ == "__main__":
+ cfg = setup_main()
+ main(cfg)
+ if not cfg.evaluate:
+ eval_after_training(cfg)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval_mc2.py b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval_mc2.py
new file mode 100644
index 0000000000000000000000000000000000000000..4af66ba2a5044ab24e9042603f6deb629c5bd8b4
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval_mc2.py
@@ -0,0 +1,405 @@
+import copy
+import datetime
+import logging
+import os
+import time
+import json
+from os.path import join
+from torchnet import meter
+
+import pandas as pd
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import wandb
+import torch.nn.functional as F
+from einops import rearrange
+
+from dataset import MetaLoader, create_dataset, create_loader, create_sampler
+from models.utils import tile
+from models import *
+from tasks_clip.shared_utils import get_media_types, setup_model
+from utils.basic_utils import MetricLogger, SmoothedValue, setup_seed, flat_list_of_lists, save_json
+from utils.config import Config
+from utils.config_utils import setup_main
+from utils.distributed import get_rank, get_world_size, is_main_process
+from utils.logger import log_dict_to_wandb, setup_wandb
+
+logger = logging.getLogger(__name__)
+
+
+def get_sim_for_each_question(model, pooled_image_feat, pooled_text_feat, model_cls):
+ """TODO: Docstring for get_sim_for_each_question.
+
+ Args:
+ model (TODO): TODO
+ pooled_image_feat (torch.Tensor): Shape: [b, c]
+ pooled_text_feat (torch.Tensor): Shape: [b, n, c]. n is the number of answer candidates.
+
+ Returns: TODO
+
+ """
+ image_feat = F.normalize(pooled_image_feat, dim=-1).to(torch.float32)
+ text_feat = F.normalize(pooled_text_feat, dim=-1).to(torch.float32)
+ sim = torch.matmul(image_feat.unsqueeze(1), rearrange(text_feat, "b n c -> b c n")) # [b, 1, n]
+ sim = sim[:, 0] / model.temp # [b, n]
+ sim = F.softmax(sim, dim=1) # [b, n]
+ return sim
+
+
+def main_with_ensemble(config, test_loader, model_without_ddp, tokenizer, data_type):
+ logger.info(f"test_file: {config.test_file}")
+
+ setup_seed(config.seed + get_rank())
+ device = torch.device(config.device)
+ cudnn.benchmark = True
+
+ config.scheduler.num_training_steps = 10
+ config.scheduler.num_warmup_steps = 10
+ model = model_without_ddp
+ model.eval()
+
+ logger.info("Start " + "evaluation" if config.evaluate else "training")
+ metric_logger = MetricLogger(delimiter=" ")
+ iterator = metric_logger.log_every(test_loader, 5, "Evaluation: ")
+ num_options_per_q = 157
+ all_gt_answers = None
+ all_preds = None
+ predictions = []
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type), torch.no_grad():
+ for image, text, ans, ann in iterator:
+ image = image.to(device, non_blocking=True) # bsz
+ if all_gt_answers is None:
+ all_gt_answers = ans
+ else:
+ all_gt_answers = torch.cat([all_gt_answers, ans], dim=0) # bsz*?
+ text = flat_list_of_lists(list(zip(*text))) # List(str), len=bsz*157
+ text_input = tokenizer(text).to(device) # bsz*157
+
+ # encode text
+ pooled_text_feat = model.encode_text(text_input) # [b*157, c]
+ # encode image
+ pooled_image_feat = model.encode_vision(image, test=True) # [bsz, c]
+
+ # contrastive score
+ pooled_text_feat = rearrange(pooled_text_feat, "(b n) c -> b n c", n=num_options_per_q)
+ score = get_sim_for_each_question(model, pooled_image_feat, pooled_text_feat, model_cls=config.model.model_cls) # [b, n]
+
+ # assemble predictions
+ if all_preds is None:
+ all_preds = score
+ else:
+ all_preds = torch.cat([all_preds, score], dim=0)
+ for q_idx in range(len(score)): # bsz
+ _pred = dict(
+ video=ann["video"][q_idx],
+ answer=ann["answer"][q_idx],
+ pred_scores=score[q_idx].cpu().numpy(), # (bsz, 157)
+ )
+ predictions.append(_pred)
+
+ # generate ft mactrix
+ gt = all_gt_answers.long()
+ map_meter = meter.mAPMeter()
+ map_meter.add(all_preds, gt)
+ ap = map_meter.value()
+ ap = float(ap) * 100
+
+ eval_res = {"map": round(ap, 2)}
+ logger.info(f"\n{eval_res}")
+ save_json(eval_res, join(config.output_dir, "eval_res.json"))
+ torch.save(predictions, join(config.output_dir, "prediction_scores.pth"))
+ return eval_res
+
+
+def train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ data_type
+):
+ model.train()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ metric_logger.add_meter("lr", SmoothedValue(window=1, fmt="{value:.6f}"))
+ metric_logger.add_meter("temperature", SmoothedValue(window=1, fmt="{value:.4f}"))
+ loss_names = ["loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0]
+
+ media_types = [loader.dataset.media_type for loader in train_loaders]
+ for name in loss_names:
+ for m in media_types:
+ metric_logger.add_meter(f"{m}-{name}", SmoothedValue(window=1, fmt="{value:.4f}"))
+
+ header = f"Train Epoch: [{epoch}]"
+ log_freq = config.log_freq
+
+ if config.distributed:
+ for d in train_loaders:
+ d.sampler.set_epoch(epoch)
+ train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders))))
+
+ model_without_ddp = model.module if config.distributed else model
+ iterator = metric_logger.log_every(train_loader, log_freq, header)
+ for i, (media_type, (image, text, idx)) in enumerate(iterator):
+ image = image.to(device, non_blocking=True)
+ idx = idx.to(device, non_blocking=True)
+ text_input = tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=config.max_txt_l,
+ return_tensors="pt",
+ ).to(device)
+
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
+ loss_dict = model(image, text_input, idx=idx)
+ loss = sum(loss_dict.values())
+
+ optimizer.zero_grad()
+ scaler.scale(loss).backward()
+ if config.optimizer.max_grad_norm > 0:
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
+ scaler.step(optimizer)
+ scaler.update()
+ scheduler.step()
+
+ # logging
+ for name in loss_names:
+ value = loss_dict[name]
+ value = value if isinstance(value, float) else value.item()
+ metric_logger.update(**{f"{media_type}-{name}": value})
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(temperature=model_without_ddp.temp.item())
+
+ if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
+ logs = metric_logger.get_global_avg_dict()
+ log_dict_to_wandb(logs, step=global_step, prefix="train/")
+
+ global_step += 1
+
+ if config.debug and (i + 1) % 5 == 0:
+ break
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ logger.info(f"Averaged train stats: {metric_logger.global_avg()}")
+ return global_step
+
+
+def main(config):
+ if is_main_process() and config.wandb.enable:
+ run = setup_wandb(config)
+
+ logger.info(f"config: \n{config}")
+ logger.info(f"train_file: {config.train_file}")
+
+ setup_seed(config.seed + get_rank())
+ device = torch.device(config.device)
+ cudnn.benchmark = True
+
+ # create train loader
+ train_datasets = create_dataset("ret_train", config)
+ media_types = get_media_types(train_datasets)
+ if config.distributed:
+ num_tasks = get_world_size()
+ global_rank = get_rank()
+ samplers = create_sampler(
+ train_datasets, [True] * len(media_types), num_tasks, global_rank
+ )
+ else:
+ samplers = [None] * len(media_types)
+ train_loaders = create_loader(
+ train_datasets,
+ samplers,
+ batch_size=[config.inputs.batch_size[k] for k in media_types],
+ num_workers=[config.num_workers] * len(media_types),
+ is_trains=[True] * len(media_types),
+ collate_fns=[None] * len(media_types),
+ )
+
+ num_steps_per_epoch = sum(len(d) for d in train_loaders)
+ config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
+ config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs
+
+ model_cls = eval(config.model.get('model_cls', 'InternVideo2_CLIP'))
+ (
+ model,
+ model_without_ddp,
+ optimizer,
+ scheduler,
+ scaler,
+ tokenizer,
+ start_epoch,
+ global_step,
+ ) = setup_model(
+ config,
+ model_cls=model_cls,
+ pretrain=False,
+ # find_unused_parameters=True,
+ find_unused_parameters=False,
+ )
+ if is_main_process() and config.wandb.enable:
+ wandb.watch(model)
+
+ # create test dataloader
+ test_dataset = create_dataset("mc_new_test", config)
+ test_loader = create_loader(
+ [test_dataset],
+ [None],
+ batch_size=[config.inputs.batch_size_test.video],
+ num_workers=[config.num_workers],
+ is_trains=[False],
+ collate_fns=[None],
+ )[0]
+
+ best = 0
+ best_epoch = 0
+
+ if config.get('use_bf16', True):
+ data_type = torch.bfloat16
+ else:
+ data_type = torch.float16
+
+ logger.info("Start " + "evaluation" if config.evaluate else "training")
+ start_time = time.time()
+ for epoch in range(start_epoch, config.scheduler.epochs):
+ if not config.evaluate:
+ global_step = train(
+ model,
+ train_loaders,
+ optimizer,
+ tokenizer,
+ epoch,
+ global_step,
+ device,
+ scheduler,
+ scaler,
+ config,
+ data_type=data_type
+ )
+
+ # save checkpoint befor evaluation
+ # only save those with gradient
+ if not config.evaluate:
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ if config.get("save_latest", False):
+ tag = "ckpt_latest.pth"
+ else:
+ tag = f"ckpt_{epoch:02d}.pth"
+ model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, exclude_frozen_parameters=True)
+
+ elif is_main_process():
+ state_dict = model_without_ddp.state_dict()
+ param_grad_dict = {
+ k: v.requires_grad for (k, v) in model_without_ddp.named_parameters()
+ }
+ for k in list(state_dict.keys()):
+ if k in param_grad_dict.keys() and not param_grad_dict[k]:
+ # delete parameters that do not require gradient
+ logger.info(f"Not saving {k}")
+ del state_dict[k]
+
+ save_obj = {
+ "model": model_without_ddp.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "scaler": scaler.state_dict(),
+ "config": config,
+ "epoch": epoch,
+ "global_step": global_step,
+ }
+ if config.get("save_latest", False):
+ torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
+ else:
+ torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))
+
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
+ res = main_with_ensemble(config, test_loader, model_without_ddp, tokenizer, data_type=data_type)
+ eval_res = res
+
+ if is_main_process():
+ if config.wandb.enable:
+ log_dict_to_wandb(eval_res, step=global_step, prefix=config.test_types)
+
+ map = eval_res["map"]
+ logger.info(f"Epoch {epoch}")
+ logger.info(f"\n{eval_res}")
+
+ save_json(eval_res, join(config.output_dir, "eval_res_latest.json"))
+
+ if not config.evaluate and map > best:
+ if not hasattr(config, "deepspeed") or not config.deepspeed.enable:
+ torch.save(save_obj, join(config.output_dir, "ckpt_best.pth"))
+ eval_file = "eval_res_best.json"
+ save_json(eval_res, join(config.output_dir, eval_file))
+ best = map
+ best_epoch = epoch
+ if config.evaluate:
+ eval_file = "eval_res.json"
+ save_json(eval_res, join(config.output_dir, eval_file))
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ map_best = torch.tensor([0.0, 0.0]).to(device)
+ if is_main_process():
+ map_best[0] = map
+ map_best[1] = best
+ dist.broadcast(map_best, 0)
+ map, best = map_best[0].item(), map_best[1].item()
+
+ if not config.evaluate and map > best:
+ model.save_checkpoint(config.output_dir, tag="ckpt_best.pth", save_latest=False, exclude_frozen_parameters=True)
+
+ if config.evaluate or config.debug:
+ break
+
+ dist.barrier()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Training time {total_time_str}")
+ logger.info(f"best epoch {best_epoch}")
+ logger.info(f"best {best}")
+ logger.info(f"Checkpoints and Logs saved at {config.output_dir}")
+
+ if is_main_process() and config.wandb.enable:
+ run.finish()
+
+
+def eval_after_training(train_config):
+ # general config for all
+ train_config.wandb.enable = False
+ train_config.evaluate = True
+ train_config.pretrained_path = join(train_config.output_dir, "ckpt_best.pth")
+ train_config.num_frames_test = train_config.num_frames
+ train_config.inputs.video_input.num_frames_test = train_config.num_frames
+
+ if train_config.get('num_frames_test_final', False):
+ train_config.num_frames_test = train_config.num_frames_test_final
+ train_config.batch_size = train_config.batch_size_final
+ train_config.inputs.video_input.num_frames_test = train_config.num_frames_test_final
+ train_config.model.vision_encoder.num_frames = train_config.num_frames_test_final
+
+ eval_config = copy.deepcopy(train_config)
+ eval_config.test_types = list(eval_config.test_file.keys())
+ eval_config.output_dir = join(eval_config.output_dir, f"eval_after_training")
+ eval_config.result_dir = eval_config.output_dir
+ if is_main_process():
+ os.makedirs(eval_config.output_dir, exist_ok=True)
+ Config.dump(eval_config, os.path.join(eval_config.output_dir, "config.json"))
+ logger.info(f"===========> START eval_after_training [{eval_config.test_types}]")
+ main(eval_config)
+
+
+if __name__ == "__main__":
+ cfg = setup_main()
+ main(cfg)
+ if not cfg.evaluate:
+ eval_after_training(cfg)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca05a5736882cca19378e190aa237dbd3b6b5e0f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/retrieval_utils.py
@@ -0,0 +1,522 @@
+import datetime
+import logging
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from einops import rearrange
+
+from models.criterions import get_sim
+from utils.basic_utils import MetricLogger
+from utils.distributed import get_rank, get_world_size
+
+logger = logging.getLogger(__name__)
+
+
+def extract_text_feats(texts, max_txt_l, tokenizer, model, device):
+ num_text = len(texts)
+ text_bs = 256
+ text_feats = []
+ text_atts = []
+
+ for i in range(0, num_text, text_bs):
+ text = texts[i : min(num_text, i + text_bs)]
+ text_input = tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=max_txt_l,
+ return_tensors="pt",
+ ).to(device)
+
+ text_feat = model.encode_text(text_input)[0]
+ text_feats.append(text_feat)
+ text_atts.append(text_input.attention_mask)
+
+ text_feats = torch.cat(text_feats, dim=0)
+ text_atts = torch.cat(text_atts, dim=0)
+ return text_feats, text_atts
+
+
+def extract_vision_feats(data_loader, model, device, config):
+ image_feats_all = []
+ pooled_image_feats_all = []
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "extracting image feats"
+ iterator = metric_logger.log_every(data_loader, 100, header)
+ for image, img_id in iterator:
+ image = image.to(device, non_blocking=True)
+ image_feat, pooled_image_feat = model.encode_vision(image, test=True)
+ if config.evaluation.eval_frame_ensemble == "concat": # default
+ if len(image_feat.shape) == 4:
+ image_feat = rearrange(image_feat, "b t l c -> b (t l) c").contiguous()
+ image_feat = image_feat.unsqueeze(1) # (bsz, 1, #frm*L, d)
+ else:
+ assert config.video_input.num_frames == 1, "only support single-frame"
+ assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
+ if config.evaluation.eval_offload:
+ image_feats_all.append(image_feat.cpu())
+ pooled_image_feats_all.append(pooled_image_feat.cpu())
+ else:
+ image_feats_all.append(image_feat)
+ pooled_image_feats_all.append(pooled_image_feat)
+
+ image_feats_all = torch.cat(image_feats_all, dim=0)
+
+ pooled_image_feats_all = torch.cat(pooled_image_feats_all, dim=0)
+ return image_feats_all, pooled_image_feats_all
+
+
+@torch.no_grad()
+def evaluation_wrapper(model, data_loader, tokenizer, device, config, data_type, prefix=""):
+ with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=data_type):
+ if "InternVideo2_CLIP" in config.model.model_cls:
+ i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation_video_clip(
+ model, data_loader, tokenizer, device, config)
+ else:
+ i2t_x, t2i_x, i2t_emb, t2i_emb = evaluation(
+ model, data_loader, tokenizer, device, config)
+ score_pairs = [
+ (prefix + "/", i2t_x, t2i_x),
+ (prefix + "_emb/", i2t_emb, t2i_emb),
+ ]
+ res = dict()
+ for name, i2t, t2i in score_pairs:
+ if i2t is not None:
+ txt2img_ids = data_loader.dataset.txt2img
+ img2txt_ids = data_loader.dataset.img2txt
+ res[name] = itm_eval(i2t, t2i, txt2img_ids, img2txt_ids)
+ return res
+
+
+@torch.no_grad()
+def evaluation_video_clip(model, data_loader, tokenizer, device, config):
+ model.eval()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation:"
+ dtype = torch.half if config.use_half_precision else torch.float
+ media_type = data_loader.dataset.media_type
+ logger.info(f"Start evaluation for media_type={media_type}")
+
+ logger.info("Computing dual encoder features...")
+
+ # this computes all features in each GPU
+ texts = data_loader.dataset.text
+ num_text = len(texts)
+ text_bs = 256
+ text_feats = []
+ for i in range(0, num_text, text_bs):
+ text = texts[i:min(num_text, i + text_bs)]
+ if "InternVideo2_CLIP" in config.model.model_cls:
+ text_feat = model.encode_text(tokenizer(text).to(device))
+ else:
+ text_feat = model.encode_text(text)
+ text_feats.append(text_feat.cpu())
+ text_feats = torch.cat(text_feats, dim=0)
+ logger.info("Finished computing text features")
+
+ if hasattr(data_loader.dataset, "num_prompts"):
+ np = data_loader.dataset.num_prompts
+ logger.info("Using {} prompts".format(np))
+ nt = len(data_loader.dataset.text) // np
+ text_feats = text_feats.view(nt, np, -1)
+
+ image_feats = []
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "extracting image feats"
+ iterator = metric_logger.log_every(data_loader, 100, header)
+ for image, _ in iterator:
+ image = image.to(device, non_blocking=True)
+ image_feat = model.encode_vision(image, test=True)
+ image_feats.append(image_feat.cpu())
+ image_feats = torch.cat(image_feats, dim=0)
+ logger.info("Finished feature extraction")
+ logger.info("Computing ITC scores [dot-product]")
+ i2t_scores, t2i_scores = get_sim(image_feats, text_feats)
+ del image_feats, text_feats
+ logger.info("Computing ITC scores [dot-product], done!")
+
+ i2t_scores_dsl = i2t_scores * i2t_scores.softmax(dim=0)
+ i2t_scores_dsl_T = i2t_scores.T * i2t_scores.T.softmax(dim=0)
+
+ return (
+ i2t_scores_dsl.cpu().float().numpy(),
+ i2t_scores_dsl_T.cpu().float().numpy(),
+ i2t_scores.cpu().float().numpy(),
+ i2t_scores.T.cpu().float().numpy(),
+ )
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, tokenizer, device, config):
+ model.eval()
+
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation:"
+ dtype = torch.half if config.use_half_precision else torch.float
+ media_type = data_loader.dataset.media_type
+ logger.info(f"Start evaluation for media_type={media_type}")
+
+ logger.info("Computing dual encoder features...")
+ start_time = time.time()
+
+ # this computes all features in each GPU
+ texts = data_loader.dataset.text
+ max_txt_l = config.inputs.max_txt_l
+ if not isinstance(max_txt_l, int):
+ max_txt_l = max_txt_l[media_type]
+ text_feats, text_atts = extract_text_feats(
+ texts, max_txt_l, tokenizer, model, device
+ ) # (bsz, Lt, d), (bsz, Lt)
+
+ image_feats, pooled_image_feats = extract_vision_feats(
+ data_loader, model, device, config
+ ) # (bsz, 1, #frm*Li, d) or (bsz, #frm, Li, d), (bsz, #frm, d)
+ logger.info("Finished feature extraction")
+ logger.info("Computing ITC scores [dot-product]")
+ _pooled_image_feats = (
+ pooled_image_feats.to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else pooled_image_feats
+ )
+ i2t_scores, t2i_scores = get_sim(
+ model.vision_proj(_pooled_image_feats), model.text_proj(text_feats[:, 0])
+ )
+ logger.info("Computing ITC scores [dot-product], done!")
+
+ num_images = len(data_loader.dataset.image)
+ i2t_scores_x = torch.full((num_images, len(texts)), -100.0).to(
+ device, torch.float, non_blocking=True
+ )
+
+ # computes only part of the scores at each GPU, gather at the end
+ logger.info("Rerank dual-encoder results with cross-encoder...")
+ num_tasks = get_world_size()
+ rank = get_rank()
+ # only uses the part associated with the raw eval set
+ # compute image2text #
+ step = num_images // num_tasks + 1
+ start = rank * step
+ end = min(num_images, start + step)
+
+ text_encoder = model.get_text_encoder()
+ iterator = metric_logger.log_every(i2t_scores[start:end], 100, header)
+ logger.info(f"i2t_scores.shape {i2t_scores[start:end].shape}")
+
+ # generate score for each clip, and aggregate all clip scores for a video
+ n_clip_per_video = (
+ image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
+ )
+
+ logger.info(
+ f"n_clip_per_video={n_clip_per_video}, with eval_frame_ensemble={config.evaluation.eval_frame_ensemble}"
+ )
+ for i, sims in enumerate(iterator):
+ k = min(len(sims), config.evaluation.k_test)
+ topk_sim, topk_idx = sims.topk(k=k, dim=0)
+
+ clip_scores = []
+ for clip_idx in range(n_clip_per_video):
+ if config.deep_fusion:
+ encoder_output = [
+ feat[start + i, clip_idx].to(device, non_blocking=True)
+ for feat in image_feats
+ ]
+
+ else:
+ encoder_output = (
+ image_feats[start + i, clip_idx].to(device, non_blocking=True)
+ if config.evaluation.eval_offload
+ else image_feats[start + i, clip_idx]
+ ) # (#frm*Li, d)
+
+ """ original
+ encoder_output = encoder_output.repeat(k, 1, 1) # (k=128, #frm*Li, d)
+ encoder_att = torch.ones(
+ encoder_output.size()[:-1], dtype=torch.long
+ ).to(device, non_blocking=True)
+ output = text_encoder(
+ encoder_embeds=text_feats[topk_idx],
+ attention_mask=text_atts[topk_idx],
+ encoder_hidden_states=encoder_output,
+ encoder_attention_mask=encoder_att,
+ return_dict=True,
+ mode="fusion"
+ )
+
+ itm_embeds = output.last_hidden_state[:, 0]
+ """
+
+ # new
+ bs = 32
+ # bs = config.batch_size_test.video
+ itm_embeds = []
+
+ if config.deep_fusion:
+ if len(topk_idx) % bs != 0:
+ left = len(topk_idx) % bs
+ left_encoder_output = [feat.repeat(left, 1, 1) for feat in encoder_output]
+ left_encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in left_encoder_output
+ ]
+ encoder_output = [feat.repeat(bs, 1, 1) for feat in encoder_output]
+ encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in encoder_output
+ ]
+ else:
+ if len(topk_idx) % bs != 0:
+ left = len(topk_idx) % bs
+ left_encoder_output = encoder_output.repeat(left, 1, 1) # (k=128, #frm*Li, d)
+ left_encoder_att = torch.ones(left_encoder_output.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ encoder_output = encoder_output.repeat(bs, 1, 1) # (k=128, #frm*Li, d)
+ encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+
+ for j in range(0, len(topk_idx), bs):
+ if j + bs > len(topk_idx):
+ output = text_encoder(
+ encoder_embeds=text_feats[topk_idx[j:]],
+ attention_mask=text_atts[topk_idx[j:]],
+ encoder_hidden_states=left_encoder_output,
+ encoder_attention_mask=left_encoder_att,
+ return_dict=True,
+ mode="fusion",
+ )
+ else:
+ output = text_encoder(
+ encoder_embeds=text_feats[topk_idx[j : j + bs]],
+ attention_mask=text_atts[topk_idx[j : j + bs]],
+ encoder_hidden_states=encoder_output,
+ encoder_attention_mask=encoder_att,
+ return_dict=True,
+ mode="fusion",
+ )
+ batch_itm_embeds = output.last_hidden_state[:, 0]
+ itm_embeds.append(batch_itm_embeds)
+ itm_embeds = torch.cat(itm_embeds, dim=0)
+ # end new
+
+ score = model.itm_head(itm_embeds)[:, 1]
+ clip_scores.append(score)
+
+ if len(clip_scores) == 1:
+ score = clip_scores[0]
+ else:
+ assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
+ clip_scores = torch.stack(clip_scores) # (#clips, k)
+ if config.evaluation.eval_frame_ensemble == "mean":
+ score = clip_scores.mean(0)
+ elif config.evaluation.eval_frame_ensemble == "max":
+ score = clip_scores.max(0)[0]
+ elif config.evaluation.eval_frame_ensemble == "lse": # LogSumExp
+ score = torch.logsumexp(clip_scores, dim=0)
+ else:
+ raise ValueError(
+ "config.evaluation.eval_frame_ensemble must in [mean, max, lse] when #clip > 1."
+ )
+
+ i2t_scores_x[start + i, topk_idx] = score.to(i2t_scores_x.dtype)
+
+ # compute text2image #
+ num_text = len(data_loader.dataset.text)
+ t2i_scores_x = torch.full((num_text, len(data_loader.dataset.image)), -100.0).to(
+ device, torch.float, non_blocking=True
+ )
+
+ step = num_text // num_tasks + 1
+ start = rank * step
+ end = min(num_text, start + step)
+
+ iterator = metric_logger.log_every(t2i_scores[start:end], 100, header)
+ logger.info(f"t2i_scores.shape {t2i_scores[start:end].shape}")
+ # generate score for each clip, and aggregate all clip scores for a video
+ n_clip_per_video = (
+ image_feats.shape[1] if not config.deep_fusion else image_feats[0].shape[1]
+ )
+ for i, sims in enumerate(iterator):
+ k = min(len(sims), config.evaluation.k_test)
+ topk_sim, topk_idx = sims.topk(k=k, dim=0)
+
+ clip_scores = []
+ for clip_idx in range(n_clip_per_video):
+
+ """old
+ encoder_output = image_feats[topk_idx, clip_idx].to(device, non_blocking=True) \
+ if config.evaluation.eval_offload else image_feats[topk_idx, clip_idx]
+ encoder_att = torch.ones(
+ encoder_output.size()[:-1], dtype=torch.long
+ ).to(device, non_blocking=True)
+ output = text_encoder(
+ encoder_embeds=text_feats[start+i].repeat(k, 1, 1),
+ attention_mask=text_atts[start+i].repeat(k, 1),
+ encoder_hidden_states=encoder_output,
+ encoder_attention_mask=encoder_att,
+ return_dict=True,
+ mode="fusion"
+ )
+
+ itm_embeds = output.last_hidden_state[:, 0]
+ """
+
+ # new
+ bs = 32
+ # bs = config.batch_size_test.video
+ itm_embeds = []
+ for j in range(0, len(topk_idx), bs):
+
+ if config.deep_fusion:
+ encoder_output = [
+ feat[topk_idx[j : j + bs], clip_idx].to(device, non_blocking=True)
+ for feat in image_feats
+ ]
+ encoder_att = [
+ torch.ones(feat.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+ for feat in encoder_output
+ ]
+ else:
+ encoder_output = (
+ image_feats[topk_idx[j : j + bs], clip_idx].to(
+ device, non_blocking=True
+ )
+ if config.evaluation.eval_offload
+ else image_feats[topk_idx[j : j + bs], clip_idx]
+ )
+ encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
+ device, non_blocking=True
+ )
+
+ repeat_n = (
+ encoder_output.shape[0]
+ if not config.deep_fusion
+ else encoder_output[0].shape[0]
+ )
+ output = text_encoder(
+ encoder_embeds=text_feats[start + i].repeat(repeat_n, 1, 1),
+ attention_mask=text_atts[start + i].repeat(repeat_n, 1),
+ encoder_hidden_states=encoder_output,
+ encoder_attention_mask=encoder_att,
+ return_dict=True,
+ mode="fusion",
+ )
+
+ batch_itm_embeds = output.last_hidden_state[:, 0]
+ itm_embeds.append(batch_itm_embeds)
+
+ itm_embeds = torch.cat(itm_embeds, dim=0)
+ # end new
+
+ score = model.itm_head(itm_embeds)[:, 1]
+ clip_scores.append(score)
+
+ if len(clip_scores) == 1:
+ score = clip_scores[0]
+ else:
+ assert config.evaluation.eval_frame_ensemble in ["mean", "max", "lse"]
+ clip_scores = torch.stack(clip_scores) # (#clips, k)
+ if config.evaluation.eval_frame_ensemble == "mean":
+ score = clip_scores.mean(0)
+ elif config.evaluation.eval_frame_ensemble == "max":
+ score = clip_scores.max(0)[0]
+ elif config.evaluation.eval_frame_ensemble == "lse": # LogSumExp
+ score = torch.logsumexp(clip_scores, dim=0)
+ else:
+ raise ValueError(
+ "config.evaluation.eval_frame_ensemble must in [mean, max, lse] when #clip > 1."
+ )
+
+ t2i_scores_x[start + i, topk_idx] = score.to(t2i_scores_x.dtype)
+
+ if config.distributed:
+ # gether across GPUs
+ dist.barrier()
+ dist.all_reduce(i2t_scores_x, op=dist.ReduceOp.SUM)
+ dist.all_reduce(t2i_scores_x, op=dist.ReduceOp.SUM)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info(f"Evaluation time {total_time_str}")
+
+ return (
+ i2t_scores_x.cpu().numpy(),
+ t2i_scores_x.cpu().numpy(),
+ i2t_scores.cpu().numpy(),
+ i2t_scores.T.cpu().numpy(),
+ )
+
+
+@torch.no_grad()
+def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
+ # Images->Text
+ ranks = np.zeros(scores_i2t.shape[0])
+ for index, score in enumerate(scores_i2t):
+ inds = np.argsort(score)[::-1]
+ # Score
+ gt_txt_ids = img2txt[index]
+ if isinstance(gt_txt_ids, int):
+ ranks[index] = np.where(inds == gt_txt_ids)[0][0]
+ else:
+ rank = 1e20
+ for i in gt_txt_ids:
+ tmp = np.where(inds == i)[0][0]
+ if tmp < rank:
+ rank = tmp
+ ranks[index] = rank
+
+ # Compute metrics
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ # Text->Images
+ ranks = np.zeros(scores_t2i.shape[0])
+
+ for index, score in enumerate(scores_t2i):
+ inds = np.argsort(score)[::-1]
+ gt_img_ids = txt2img[index]
+ if isinstance(gt_img_ids, int):
+ ranks[index] = np.where(inds == gt_img_ids)[0][0]
+ else: # list, used in the case each caption has multiple GT images
+ # Score
+ rank = 1e20
+ for i in gt_img_ids:
+ tmp = np.where(inds == i)[0][0]
+ if tmp < rank:
+ rank = tmp
+ ranks[index] = rank
+
+ # Compute metrics
+ ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ tr_mean = (tr1 + tr5 + tr10) / 3
+ ir_mean = (ir1 + ir5 + ir10) / 3
+ r_mean = (tr_mean + ir_mean) / 2
+
+ eval_result = {
+ "txt_r1": tr1,
+ "txt_r5": tr5,
+ "txt_r10": tr10,
+ "txt_r_mean": tr_mean,
+ "img_r1": ir1,
+ "img_r5": ir5,
+ "img_r10": ir10,
+ "img_r_mean": ir_mean,
+ "r_mean": r_mean,
+ }
+ eval_result = {k: round(v, 2) for k, v in eval_result.items()}
+ return eval_result
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/shared_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/shared_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..357439feb0bef466233033711e71c94326559d11
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tasks_clip/shared_utils.py
@@ -0,0 +1,180 @@
+import copy
+import logging
+import os
+import os.path as osp
+from os.path import join
+
+import torch
+import deepspeed
+from torch.utils.data import ConcatDataset, DataLoader
+
+from utils.optimizer import create_optimizer
+from utils.scheduler import create_scheduler
+
+logger = logging.getLogger(__name__)
+
+
+def get_media_types(datasources):
+ """get the media types for for all the dataloaders.
+
+ Args:
+ datasources (List): List of dataloaders or datasets.
+
+ Returns: List. The media_types.
+
+ """
+ if isinstance(datasources[0], DataLoader):
+ datasets = [dataloader.dataset for dataloader in datasources]
+ else:
+ datasets = datasources
+ media_types = [
+ dataset.datasets[0].media_type
+ if isinstance(dataset, ConcatDataset)
+ else dataset.media_type
+ for dataset in datasets
+ ]
+
+ return media_types
+
+
+def setup_model(
+ config, model_cls, pretrain=False, find_unused_parameters=False, num_steps_per_epoch=-1,
+):
+ logger.info("Creating model")
+ config = copy.deepcopy(config)
+
+ model = model_cls(config=config, is_pretrain=pretrain)
+
+ model = model.to(torch.device(config.device))
+ if config.use_half_precision:
+ if config.get('bf16', True):
+ logger.info("Change to bfloat16 for model")
+ model = model.to(torch.bfloat16)
+ else:
+ logger.info("Change to float16 for model")
+ model = model.half()
+ tokenizer = model.tokenizer
+ model_without_ddp = model
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ optimizer_params = create_optimizer(config.optimizer, model, return_group=True)
+ scheduler = None
+ scaler = None
+ else:
+ if config.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[config.gpu],
+ find_unused_parameters=find_unused_parameters, # `False` for image-only task
+ )
+
+ optimizer = create_optimizer(config.optimizer, model)
+ scheduler = create_scheduler(config.scheduler, optimizer)
+ scaler = torch.cuda.amp.GradScaler(enabled=config.use_half_precision) # This is never used actually if we fixed bf16
+
+ start_epoch = 0
+ global_step = 0
+
+ # auto resume the latest checkpoint
+ if config.get("auto_resume", False):
+ logger.info("Auto resuming")
+ model_latest = join(config.output_dir, "ckpt_latest.pth")
+ model_best = join(config.output_dir, "ckpt_best.pth")
+
+ large_step_num = -1
+ large_num = -1
+ for p in os.listdir(config.output_dir):
+ if 'ckpt_iter' in p:
+ num = p.split('_iter')[1].split('.')[0]
+ if str.isnumeric(num):
+ if int(num) > large_step_num:
+ large_step_num = int(num)
+ elif 'ckpt_' in p:
+ num = p.split('_')[1].split('.')[0]
+ if str.isnumeric(num):
+ if int(num) > large_num:
+ large_num = int(num)
+ if large_step_num != -1:
+ logger.info(f"Load the latest step: {large_step_num}")
+ model_latest = join(config.output_dir, f"ckpt_iter{large_step_num:02d}.pth")
+ if large_num != -1 and (large_num + 1) * num_steps_per_epoch > large_step_num:
+ logger.info(f"Load the latest epoch: {large_num}")
+ model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ if osp.isdir(model_latest):
+ config.pretrained_path = model_latest
+ config.resume = True
+ elif osp.isdir(model_best):
+ config.pretrained_path = model_best
+ config.resume = True
+ else:
+ logger.info(f"Not found checkpoint in {config.output_dir}")
+ else:
+ if osp.isfile(model_latest):
+ config.pretrained_path = model_latest
+ config.resume = True
+ elif osp.isfile(model_best):
+ config.pretrained_path = model_best
+ config.resume = True
+ else:
+ logger.info(f"Not found checkpoint in {config.output_dir}")
+
+ # load pretrained model
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ logger.info('Use deepspeed to initialize model!!!')
+ model = model_without_ddp
+ model, optimizer, _, _ = deepspeed.initialize(
+ args=config, model=model, model_parameters=optimizer_params, dist_init_required=not config.distributed,
+ lr_scheduler=lambda opt: create_scheduler(config.scheduler, opt)
+ )
+ if osp.isdir(config.pretrained_path):
+ logger.info(f"Load pretrained model from {config.pretrained_path}")
+ output_dir, tag = os.path.split(config.pretrained_path)
+ if config.resume:
+ _, client_state = model.load_checkpoint(output_dir, tag=tag, load_module_strict=False)
+ global_step = model.global_steps
+ assert num_steps_per_epoch > 0, "Please provide num_steps_per_epoch"
+ start_epoch = global_step // num_steps_per_epoch
+ else:
+ _, client_state = model.load_checkpoint(
+ output_dir, tag=tag, load_module_strict=False,
+ load_optimizer_states=False, load_lr_scheduler_states=False,
+ load_module_only=True
+ )
+ else:
+ if osp.isfile(config.pretrained_path):
+ checkpoint = torch.load(config.pretrained_path, map_location="cpu")
+ logger.info(f"Load pretrained model from {config.pretrained_path}")
+ if 'model' in checkpoint.keys():
+ state_dict = checkpoint["model"]
+ elif 'module' in checkpoint.keys():
+ state_dict = checkpoint["module"]
+ else:
+ state_dict = checkpoint
+ # resume optimizer
+ if config.resume:
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ scheduler.load_state_dict(checkpoint["scheduler"])
+ scaler.load_state_dict(checkpoint["scaler"])
+ start_epoch = checkpoint["epoch"] + 1
+ global_step = checkpoint["global_step"]
+
+ msg = model_without_ddp.load_state_dict(state_dict, strict=False)
+ logger.info(msg)
+ logger.info(f"Loaded checkpoint from {config.pretrained_path}")
+ else:
+ logger.warning("No pretrained checkpoint provided, training from scratch")
+
+ logger.info(f"Cuda memory after create model: {torch.cuda.memory_allocated() // 1024**2}M, Max mem: {torch.cuda.max_memory_allocated() // 1024**2}M")
+
+ return (
+ model,
+ model_without_ddp,
+ optimizer,
+ scheduler,
+ scaler,
+ tokenizer,
+ start_epoch,
+ global_step,
+ )
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tests/test_cfg.py b/third_party/InternVideo/InternVideo2/multi_modality/tests/test_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5988cf61d428da271cc1fdd332d4019c31a9976
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tests/test_cfg.py
@@ -0,0 +1,6 @@
+from utils.config import Config
+
+cfg = Config.get_config()
+
+cfg_text = Config.pretty_text(cfg)
+print(cfg_text)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tools/run.py b/third_party/InternVideo/InternVideo2/multi_modality/tools/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..834af73da0b12e14dddda1f66d045c646ffa10ac
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tools/run.py
@@ -0,0 +1,147 @@
+import argparse
+import os
+import socket
+
+from utils import has_slurm, random_port, runcmd
+
+EXP_DIR_ENV_NAME = "VL_EXP_DIR"
+
+# if key in hostname; apply the args in value to slurm.
+DEFAULT_SLURM_ARGS = dict(login="-p gpu --mem=240GB -c 64 -t 2-00:00:00")
+
+
+def get_default_slurm_args():
+ """get the slurm args for different cluster.
+ Returns: TODO
+
+ """
+ hostname = socket.gethostname()
+ for k, v in DEFAULT_SLURM_ARGS.items():
+ if k in hostname:
+ return v
+ return ""
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ # slurm
+ parser.add_argument("--slurm_args", type=str, default="", help="args for slurm.")
+ parser.add_argument(
+ "--no_slurm",
+ action="store_true",
+ help="If specified, will launch job without slurm",
+ )
+ parser.add_argument("--jobname", type=str, required=True, help="experiment name")
+ parser.add_argument(
+ "--dep_jobname", type=str, default="impossible_jobname", help="the dependent job name"
+ )
+ parser.add_argument("--nnodes", "-n", type=int, default=1, help="the number of nodes")
+ parser.add_argument(
+ "--ngpus", "-g", type=int, default=1, help="the number of gpus per nodes"
+ )
+
+ #
+ parser.add_argument(
+ "--task",
+ type=str,
+ required=True,
+ help="one of: pretrain, retrieval, retrieval_mc, vqa.",
+ )
+ parser.add_argument("--config", type=str, required=True, help="config file name.")
+ parser.add_argument("--model_args", type=str, default="", help="args for model")
+
+ args = parser.parse_args()
+ return args
+
+
+def get_output_dir(args):
+ """get the output_dir"""
+ return os.path.join(os.environ[EXP_DIR_ENV_NAME], args.jobname)
+
+
+def prepare(args: argparse.Namespace):
+ """prepare for job submission
+
+ Args:
+ args (dict): The arguments.
+
+ Returns: The path to the copied source code.
+
+ """
+
+ output_dir = get_output_dir(args)
+ code_dir = os.path.join(output_dir, "code")
+ project_dirname = os.path.basename(os.getcwd())
+
+ # check output_dir exist
+ if os.path.isdir(output_dir):
+ # if using slurm
+ if has_slurm() and not args.no_slurm:
+ raise ValueError(f"output_dir {output_dir} already exist. Exit.")
+ else:
+ os.mkdir(output_dir)
+ # copy code
+ cmd = f"cd ..; rsync -ar {project_dirname} {code_dir} --exclude='*.out'"
+ print(cmd)
+ runcmd(cmd)
+ return os.path.join(code_dir, project_dirname)
+
+
+def submit_job(args: argparse.Namespace):
+ """TODO: Docstring for build_job_script.
+
+ Args:
+ args (argparse.Namespace): The commandline args.
+
+ Returns: str. The script to run.
+
+ """
+ output_dir = get_output_dir(args)
+ # copy code
+ code_dir = prepare(args)
+
+ # enter in the backup code
+ master_port = os.environ.get("MASTER_PORT", random_port())
+ init_cmd = f" cd {code_dir}; export MASTER_PORT={master_port}; "
+
+ if has_slurm() and not args.no_slurm:
+ # prepare slurm args.
+ mode = "slurm"
+ default_slurm_args = get_default_slurm_args()
+ bin = (
+ f" sbatch --output {output_dir}/%j.out --error {output_dir}/%j.out"
+ f" {default_slurm_args}"
+ f" {args.slurm_args} --job-name={args.jobname} --nodes {args.nnodes} "
+ f" --ntasks {args.nnodes} "
+ f" --gpus-per-node={args.ngpus} "
+ f" --dependency=$(squeue --noheader --format %i --name {args.dep_jobname}) "
+ )
+ else:
+ mode = "local"
+ bin = "bash "
+
+ # build job cmd
+ job_cmd = (
+ f" tasks/{args.task}.py"
+ f" {args.config}"
+ f" output_dir {output_dir}"
+ f" {args.model_args}"
+ )
+
+ cmd = (
+ f" {init_cmd} {bin} "
+ f" tools/submit.sh "
+ f" {mode} {args.nnodes} {args.ngpus} {job_cmd} "
+ )
+
+ with open(os.path.join(output_dir, "cmd.txt"), "w") as f:
+ f.write(cmd)
+
+ print(cmd)
+ runcmd(cmd)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ submit_job(args)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tools/submit.sh b/third_party/InternVideo/InternVideo2/multi_modality/tools/submit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a8e53efb6e7df0a2567ac3497614cda22dec04e5
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tools/submit.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env sh
+
+mode=$1 # slurm or local
+nnodes=$2
+ngpus=$3
+cmd=${@:4} # the command to run. i.e. tasks/pretrain.py ...
+
+if [[ "$mode" == "slurm" ]]; then # slurm
+ master_node=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+ all_nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
+ echo "All nodes used: ${all_nodes}"
+ echo "Master node ${master_node}"
+
+ head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$master_node" hostname --ip-address | awk '{print $1}')
+ # head_node_ip=$master_node
+ rdzv_endpoint="${head_node_ip}:${MASTER_PORT:-40000}"
+ bin="srun"
+
+else # local
+ rdzv_endpoint="${MASTER_ADDR:-localhost}:${MASTER_PORT:-40000}"
+ bin=""
+fi
+
+echo "PYTHONPATH: ${PYTHONPATH}"
+which_python=$(which python)
+echo "which python: ${which_python}"
+export PYTHONPATH=${PYTHONPATH}:${which_python}
+export PYTHONPATH=${PYTHONPATH}:.
+echo "PYTHONPATH: ${PYTHONPATH}"
+
+#run command
+$bin torchrun --nnodes=$nnodes \
+ --nproc_per_node=$ngpus \
+ --rdzv_backend=c10d \
+ --rdzv_endpoint=${rdzv_endpoint} \
+ $cmd
+
+echo "Finish at dir: ${PWD}"
+############### ======> Your training scripts [END]
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/tools/utils.py b/third_party/InternVideo/InternVideo2/multi_modality/tools/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c00a7816145ffbfab79f82b82096cefe97dcbc4
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/tools/utils.py
@@ -0,0 +1,29 @@
+import os
+import shutil
+import socket
+
+
+def has_slurm():
+ """determine the system has slurm or not
+ Returns: True if has else False.
+
+ """
+ return shutil.which("sbatch") is not None
+
+def random_port():
+ """random a unused port
+ Returns: str
+
+ """
+ with socket.socket() as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+def runcmd(cmd):
+ """run command
+
+ Args:
+ cmd (str): The command to run
+
+ """
+ os.system(cmd)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/torchrun.sh b/third_party/InternVideo/InternVideo2/multi_modality/torchrun.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ffc8fa5071d108cb63cb65b0ecd6c67862df8abb
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/torchrun.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+MASTER_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+ALL_NODES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
+MASTER_PORT=$((10660 + $RANDOM % 10))
+
+echo "All nodes used:"
+echo ${ALL_NODES}
+echo "Master node:"
+echo ${MASTER_NODE}
+echo "Args:"
+echo $@
+
+torchrun --rdzv_endpoint=${MASTER_NODE}:10069 $@
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/utils/basic_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/utils/basic_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb453d35c852741bf1ad6dfe27e604d9fef6557e
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/utils/basic_utils.py
@@ -0,0 +1,286 @@
+import numpy as np
+import io
+import os
+import json
+import logging
+import random
+import time
+from collections import defaultdict, deque
+import datetime
+from pathlib import Path
+from typing import List, Union
+
+import torch
+import torch.distributed as dist
+from .distributed import is_dist_avail_and_initialized
+
+
+logger = logging.getLogger(__name__)
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total],
+ dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ if meter.count == 0: # skip empty meter
+ loss_str.append(
+ "{}: {}".format(name, "No data")
+ )
+ else:
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def global_avg(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ if meter.count == 0:
+ loss_str.append(
+ "{}: {}".format(name, "No data")
+ )
+ else:
+ loss_str.append(
+ "{}: {:.4f}".format(name, meter.global_avg)
+ )
+ return self.delimiter.join(loss_str)
+
+ def get_global_avg_dict(self, prefix=""):
+ """include a separator (e.g., `/`, or "_") at the end of `prefix`"""
+ d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. for k, m in self.meters.items()}
+ return d
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, log_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % log_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ logger.info(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ res_mem=torch.cuda.max_memory_reserved() / MB,
+ ))
+ else:
+ logger.info(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logger.info('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def compute_acc(logits, label, reduction='mean'):
+ ret = (torch.argmax(logits, dim=1) == label).float()
+ if reduction == 'none':
+ return ret.detach()
+ elif reduction == 'mean':
+ return ret.mean().item()
+
+
+def compute_n_params(model, return_str=True):
+ tot = 0
+ for p in model.parameters():
+ w = 1
+ for x in p.shape:
+ w *= x
+ tot += w
+ if return_str:
+ if tot >= 1e6:
+ return '{:.1f}M'.format(tot / 1e6)
+ else:
+ return '{:.1f}K'.format(tot / 1e3)
+ else:
+ return tot
+
+
+def setup_seed(seed):
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def remove_files_if_exist(file_paths):
+ for fp in file_paths:
+ if os.path.isfile(fp):
+ os.remove(fp)
+
+
+def save_json(data, filename, save_pretty=False, sort_keys=False):
+ with open(filename, "w") as f:
+ if save_pretty:
+ f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
+ else:
+ json.dump(data, f)
+
+
+def load_json(filename):
+ with open(filename, "r") as f:
+ return json.load(f)
+
+
+def flat_list_of_lists(l):
+ """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
+ return [item for sublist in l for item in sublist]
+
+
+def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]):
+ """
+ Args:
+ root: path to the directory to start search files
+ suffix: any str as suffix, or can match multiple such strings
+ when input is List[str].
+ Example 1, e.g., suffix: `.jpg` or [`.jpg`, `.png`]
+ Example 2, e.g., use a `*` in the `suffix`: `START*.jpg.`.
+ """
+ if isinstance(suffix, str):
+ suffix = [suffix, ]
+ filepaths = flat_list_of_lists(
+ [list(Path(root).rglob(f"*{e}")) for e in suffix])
+ return filepaths
+
+
+def match_key_and_shape(state_dict1, state_dict2):
+ keys1 = set(state_dict1.keys())
+ keys2 = set(state_dict2.keys())
+ print(f"keys1 - keys2: {keys1 - keys2}")
+ print(f"keys2 - keys1: {keys2 - keys1}")
+
+ mismatch = 0
+ for k in list(keys1):
+ if state_dict1[k].shape != state_dict2[k].shape:
+ print(
+ f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}")
+ mismatch += 1
+ print(f"mismatch {mismatch}")
+
+
+def merge_dicts(list_dicts):
+ merged_dict = list_dicts[0].copy()
+ for i in range(1, len(list_dicts)):
+ merged_dict.update(list_dicts[i])
+ return merged_dict
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/utils/config.py b/third_party/InternVideo/InternVideo2/multi_modality/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b911b64de6647eb2f448828959652c601112d9c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/utils/config.py
@@ -0,0 +1,274 @@
+from __future__ import annotations
+
+import argparse
+import ast
+import json
+import os
+import os.path as osp
+import re
+import shutil
+import sys
+import tempfile
+from copy import deepcopy
+from importlib import import_module
+
+import yaml
+
+from .easydict import EasyDict
+
+__all__ = ["Config", "pretty_text"]
+
+
+BASE_KEY = "_base_"
+# BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"}
+BASE_CONFIG = {}
+
+cfg = None
+
+
+class Config(object):
+ """config"""
+
+ @classmethod
+ def pretty_text(cls, cfg: dict, indent=2) -> str:
+ """format dict to a string
+
+ Args:
+ cfg (EasyDict): the params.
+
+ Returns: The string to display.
+
+ """
+ msg = "{\n"
+ for i, (k, v) in enumerate(cfg.items()):
+ if isinstance(v, dict):
+ v = cls.pretty_text(v, indent + 4)
+ spaces = " " * indent
+ msg += spaces + "{}: {}".format(k, v)
+ if i == len(cfg) - 1:
+ msg += " }"
+ else:
+ msg += "\n"
+ return msg
+
+ @classmethod
+ def dump(cls, cfg, savepath=None):
+ """dump cfg to `json` file.
+
+ Args:
+ cfg (dict): The dict to dump.
+ savepath (str): The filepath to save the dumped dict.
+
+ Returns: TODO
+
+ """
+ if savepath is None:
+ savepath = osp.join(cfg.WORKSPACE, "config.json")
+ json.dump(cfg, open(savepath, "w"), indent=2)
+
+ @classmethod
+ def get_config(cls, default_config: dict = None):
+ """get a `Config` instance.
+
+ Args:
+ default_config (dict): The default config. `default_config` will be overrided
+ by config file `--cfg`, `--cfg` will be overrided by commandline args.
+
+ Returns: an EasyDict.
+ """
+ global cfg
+ if cfg is not None:
+ return cfg
+
+ # define arg parser.
+ parser = argparse.ArgumentParser()
+ # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str)
+ parser.add_argument(
+ "config_file", help="the configuration file to load. support: .yaml, .json, .py"
+ )
+ parser.add_argument(
+ "opts",
+ default=None,
+ nargs="*",
+ help="overrided configs. List. Format: 'key1 name1 key2 name2'",
+ )
+ args = parser.parse_args()
+
+ cfg = EasyDict(BASE_CONFIG)
+ if osp.isfile(args.config_file):
+ cfg_from_file = cls.from_file(args.config_file)
+ cfg = merge_a_into_b(cfg_from_file, cfg)
+ cfg = cls.merge_list(cfg, args.opts)
+ cfg = eval_dict_leaf(cfg)
+
+ # update some keys to make them show at the last
+ for k in BASE_CONFIG:
+ cfg[k] = cfg.pop(k)
+ return cfg
+
+ @classmethod
+ def from_file(cls, filepath: str) -> EasyDict:
+ """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`.
+
+ Args:
+ filepath (str): The config file path.
+
+ Returns: TODO
+
+ """
+ filepath = osp.abspath(osp.expanduser(filepath))
+ if not osp.isfile(filepath):
+ raise IOError(f"File does not exist: {filepath}")
+ if filepath.endswith(".py"):
+ sys.path.insert(0, osp.dirname(filepath))
+ mod = import_module(osp.splitext(osp.basename(filepath))[0])
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith("__")
+ }
+
+ elif filepath.endswith((".yml", ".yaml")):
+ cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader)
+ elif filepath.endswith(".json"):
+ cfg_dict = json.load(open(filepath, "r"))
+ else:
+ raise IOError("Only py/yml/yaml/json type are supported now!")
+
+ cfg_text = filepath + "\n"
+ with open(filepath, "r") as f:
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict: # load configs in `BASE_KEY`
+ cfg_dir = osp.dirname(filepath)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = (
+ base_filename if isinstance(base_filename, list) else [base_filename]
+ )
+
+ cfg_dict_list = list()
+ for f in base_filename:
+ _cfg_dict = Config.from_file(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
+ raise KeyError("Duplicate key is not allowed among bases")
+ base_cfg_dict.update(c)
+
+ cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict)
+
+ return EasyDict(cfg_dict)
+
+ @classmethod
+ def merge_list(cls, cfg, opts: list):
+ """merge commandline opts.
+
+ Args:
+ cfg: (dict): The config to be merged.
+ opts (list): The list to merge. Format: [key1, name1, key2, name2,...].
+ The keys can be nested. For example, ["a.b", v] will be considered
+ as `dict(a=dict(b=v))`.
+
+ Returns: dict.
+
+ """
+ assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}"
+ for i in range(0, len(opts), 2):
+ full_k, v = opts[i], opts[i + 1]
+ keys = full_k.split(".")
+ sub_d = cfg
+ for i, k in enumerate(keys):
+ if not hasattr(sub_d, k):
+ raise ValueError(f"The key {k} not exist in the config. Full key:{full_k}")
+ if i != len(keys) - 1:
+ sub_d = sub_d[k]
+ else:
+ sub_d[k] = v
+ return cfg
+
+
+def merge_a_into_b(a, b, inplace=False):
+ """The values in a will override values in b.
+
+ Args:
+ a (dict): source dict.
+ b (dict): target dict.
+
+ Returns: dict. recursively merge dict a into dict b.
+
+ """
+ if not inplace:
+ b = deepcopy(b)
+ for key in a:
+ if key in b:
+ if isinstance(a[key], dict) and isinstance(b[key], dict):
+ b[key] = merge_a_into_b(a[key], b[key], inplace=True)
+ else:
+ b[key] = a[key]
+ else:
+ b[key] = a[key]
+ return b
+
+
+def eval_dict_leaf(d, orig_dict=None):
+ """eval values of dict leaf.
+
+ Args:
+ d (dict): The dict to eval.
+
+ Returns: dict.
+
+ """
+ if orig_dict is None:
+ orig_dict = d
+ for k, v in d.items():
+ if not isinstance(v, dict):
+ d[k] = eval_string(v, orig_dict)
+ else:
+ eval_dict_leaf(v, orig_dict)
+ return d
+
+
+def eval_string(string, d):
+ """automatically evaluate string to corresponding types.
+
+ For example:
+ not a string -> return the original input
+ '0' -> 0
+ '0.2' -> 0.2
+ '[0, 1, 2]' -> [0,1,2]
+ 'eval(1+2)' -> 3
+ 'eval(range(5))' -> [0,1,2,3,4]
+ '${a}' -> d.a
+
+
+
+ Args:
+ string (str): The value to evaluate.
+ d (dict): The
+
+ Returns: the corresponding type
+
+ """
+ if not isinstance(string, str):
+ return string
+ # if len(string) > 1 and string[0] == "[" and string[-1] == "]":
+ # return eval(string)
+ if string[0:5] == "eval(":
+ return eval(string[5:-1])
+
+ s0 = string
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
+ if s1 != s0:
+ while s1 != s0:
+ s0 = s1
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
+ return eval(s1)
+
+ try:
+ v = ast.literal_eval(string)
+ except:
+ v = string
+ return v
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/utils/config_utils.py b/third_party/InternVideo/InternVideo2/multi_modality/utils/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..71ea4de10d0dbada0d106995043dc0ec67abda25
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/utils/config_utils.py
@@ -0,0 +1,170 @@
+import logging
+import os
+import sys
+import json
+import torch.distributed as dist
+from os.path import dirname, join
+
+from utils.config import Config
+from utils.distributed import init_distributed_mode, is_main_process
+from utils.logger import setup_logger
+
+logger = logging.getLogger(__name__)
+
+
+def setup_config():
+ """Conbine yaml config and command line config with OmegaConf.
+ Also converts types, e.g., `'None'` (str) --> `None` (None)
+ """
+ config = Config.get_config()
+ if config.debug:
+ config.wandb.enable = False
+ return config
+
+
+def setup_evaluate_config(config):
+ """setup evaluation default settings, e.g., disable wandb"""
+ assert config.evaluate
+ config.wandb.enable = False
+ if config.output_dir is None:
+ config.output_dir = join(dirname(config.pretrained_path), "eval")
+ return config
+
+
+def setup_output_dir(output_dir, excludes=["code"]):
+ """ensure not overwritting an exisiting/non-empty output dir"""
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir, exist_ok=False)
+ else:
+ existing_dirs_files = os.listdir(output_dir) # list
+ remaining = set(existing_dirs_files) - set(excludes)
+ remaining = [e for e in remaining if "slurm" not in e]
+ remaining = [e for e in remaining if ".out" not in e]
+ # assert len(remaining) == 0, f"remaining dirs or files: {remaining}"
+ logger.warn(f"remaining dirs or files: {remaining}")
+
+
+def setup_deepspeed_zero_config(stage):
+ # We currently set ZeRO based on stage:
+ if stage == 1:
+ return {"stage": 1, "reduce_bucket_size": 5e8}
+ if stage == 2:
+ return {
+ "stage": 2,
+ "contiguous_gradients": False,
+ "overlap_comm": False,
+ "reduce_scatter": True,
+ "reduce_bucket_size": 5e8,
+ "allgather_bucket_size": 5e8,
+ "offload_optimizer": {
+ "device": "cpu"
+ },
+ }
+ if stage == 3:
+ return {
+ "stage": 3,
+ "contiguous_gradients": True,
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_prefetch_bucket_size": 1e7,
+ "stage3_param_persistence_threshold": 1e5,
+ "reduce_bucket_size": 1e7,
+ "sub_group_size": 1e9,
+ "offload_optimizer": {
+ "device": "cpu"
+ },
+ "offload_param": {
+ "device": "cpu"
+ }
+ }
+
+ raise ValueError("Wrong stage for deepspeed {}".format(stage.stage))
+
+def setup_deepspeed_config(config):
+ config.deepspeed_config = os.path.join(config.output_dir, "deepspeed_config.json")
+ opts = config.optimizer
+ logger.info(f'Write deepspeed config to {config.deepspeed_config}')
+ if not is_main_process():
+ return config
+
+ os.makedirs(config.output_dir, exist_ok=True)
+
+ with open(config.deepspeed_config, mode="w") as writer:
+ ds_config = {
+ "train_batch_size": config.batch_size * dist.get_world_size(),
+ "train_micro_batch_size_per_gpu": config.batch_size,
+ "steps_per_print": 100,
+ "optimizer": {
+ "type": "Adam",
+ "adam_w_mode": True,
+ "params": {
+ "lr": opts.lr,
+ "weight_decay": opts.weight_decay,
+ "bias_correction": True,
+ "betas": [
+ opts.opt_betas[0],
+ opts.opt_betas[1],
+ ],
+ "eps": 1e-8
+ }
+ }
+ }
+ if config.deepspeed.stage != 0:
+ ds_config["zero_optimization"] = setup_deepspeed_zero_config(config.deepspeed.stage)
+
+ if config.use_half_precision:
+ if config.get('use_bf16', False):
+ ds_config["bf16"] = {
+ "enabled": True
+ }
+ else:
+ ds_config["fp16"] = {
+ "enabled": True,
+ "auto_cast": False,
+ "loss_scale": 0,
+ "initial_scale_power": 16,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "consecutive_hysteresis": False,
+ "min_loss_scale": 1
+ }
+ else:
+ assert config.deepspeed.stage == 0, "You must use fp16 or bf16 when using ZERO!!!"
+
+ # if config.get("max_grad_norm", -1) > 0:
+ # ds_config.update({"gradient_clipping", config.max_grad_norm})
+ if opts.get("max_grad_norm", -1) > 0:
+ ds_config["gradient_clipping"] = opts.max_grad_norm
+
+ writer.write(json.dumps(ds_config, indent=2))
+
+ return config
+
+
+def setup_main():
+ """
+ Setup config, logger, output_dir, etc.
+ Shared for pretrain and all downstream tasks.
+ """
+ # try:
+ config = setup_config()
+ if hasattr(config, "evaluate") and config.evaluate:
+ config = setup_evaluate_config(config)
+ init_distributed_mode(config)
+
+ if hasattr(config, "deepspeed") and config.deepspeed.enable:
+ config = setup_deepspeed_config(config)
+ # except Exception as e:
+ # print(f"\033[31m NODE NAME: {os.environ['SLURMD_NODENAME']} is not OK \033[0m")
+ # logger.info(f"NODE NAME: {os.environ['SLURMD_NODENAME']} is not OK")
+ # raise ValueError
+
+ if is_main_process():
+ setup_output_dir(config.output_dir, excludes=["code"])
+ setup_logger(output=config.output_dir, color=True, name="vindlu")
+ logger.info(f"config: {Config.pretty_text(config)}")
+ Config.dump(config, os.path.join(config.output_dir, "config.json"))
+
+ dist.barrier()
+
+ return config
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/utils/distributed.py b/third_party/InternVideo/InternVideo2/multi_modality/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..03f29ecb5dc102f78ea3d55dac5a6f3c457c2deb
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/utils/distributed.py
@@ -0,0 +1,175 @@
+import os
+import torch
+import torch.distributed as dist
+import logging
+try:
+ import deepspeed
+except Exception as e:
+ print(e)
+ print("deepspeed is not installed!!!")
+import datetime
+from datetime import timedelta
+
+logger = logging.getLogger(__name__)
+
+
+def setup_for_distributed(is_master):
+ import warnings
+
+ builtin_warn = warnings.warn
+
+ def warn(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_warn(*args, **kwargs)
+
+ # Log warnings only once
+ warnings.warn = warn
+ warnings.simplefilter("once", UserWarning)
+
+ if not is_master:
+ logging.disable()
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def is_port_in_use(port):
+ import socket
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ return s.connect_ex(('localhost', port)) == 0
+
+
+def init_distributed_mode(args):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ # job started by torch.distributed.launch
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ # local rank on the current node / global rank
+ local_rank = int(os.environ['SLURM_LOCALID'])
+ global_rank = int(os.environ['SLURM_PROCID'])
+ # number of processes / GPUs per node
+ world_size = int(os.environ["SLURM_NNODES"]) * \
+ int(os.environ["SLURM_TASKS_PER_NODE"][0])
+
+ print(world_size)
+
+ args.rank = global_rank
+ args.gpu = local_rank
+ args.world_size = world_size
+ else:
+ logger.info('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+
+ if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node
+ dist_port = int(args.dist_url.split(":")[-1])
+ while is_port_in_use(dist_port):
+ dist_port += 10
+ args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)])
+
+ logger.info('| distributed init (rank {}): {}'.format(
+ args.rank, args.dist_url))
+ if "SLURM_JOB_ID" in os.environ:
+ logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}")
+
+ if hasattr(args, "deepspeed") and args.deepspeed.enable:
+ deepspeed.init_distributed(
+ dist_backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank, timeout=datetime.timedelta(seconds=7200)
+ )
+ else:
+ torch.distributed.init_process_group(
+ backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank, timeout=timedelta(minutes=60))
+
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
+class GatherLayer(torch.autograd.Function):
+ """
+ Gather tensors from all workers with support for backward propagation:
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
+ dist.all_gather(output, x)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ all_gradients = torch.stack(grads)
+ dist.all_reduce(all_gradients)
+ return all_gradients[dist.get_rank()]
+
+
+# copied from megavlt
+def gather_tensor_along_batch_with_backward(tensor, dim=0):
+ world_size = get_world_size()
+
+ if world_size < 2:
+ return tensor
+
+ tensor_list = GatherLayer.apply(tensor)
+ tensor_list = torch.cat(tensor_list, dim=dim)
+ return tensor_list
+
+
+@torch.no_grad()
+def gather_tensor_along_batch(tensor, dim=0):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ world_size = get_world_size()
+
+ if world_size < 2:
+ return tensor
+
+ with torch.no_grad():
+ tensor_list = []
+
+ for _ in range(world_size):
+ tensor_list.append(torch.zeros_like(tensor))
+
+ dist.all_gather(tensor_list, tensor)
+ tensor_list = torch.cat(tensor_list, dim=dim)
+ return tensor_list
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/utils/easydict.py b/third_party/InternVideo/InternVideo2/multi_modality/utils/easydict.py
new file mode 100644
index 0000000000000000000000000000000000000000..241aca41c9f1b0677be4bf6070c077fa24501816
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/utils/easydict.py
@@ -0,0 +1,149 @@
+class EasyDict(dict):
+ """
+ Get attributes
+
+ >>> d = EasyDict({'foo':3})
+ >>> d['foo']
+ 3
+ >>> d.foo
+ 3
+ >>> d.bar
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'bar'
+
+ Works recursively
+
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+ >>> isinstance(d.bar, dict)
+ True
+ >>> d.bar.x
+ 1
+
+ Bullet-proof
+
+ >>> EasyDict({})
+ {}
+ >>> EasyDict(d={})
+ {}
+ >>> EasyDict(None)
+ {}
+ >>> d = {'a': 1}
+ >>> EasyDict(**d)
+ {'a': 1}
+
+ Set attributes
+
+ >>> d = EasyDict()
+ >>> d.foo = 3
+ >>> d.foo
+ 3
+ >>> d.bar = {'prop': 'value'}
+ >>> d.bar.prop
+ 'value'
+ >>> d
+ {'foo': 3, 'bar': {'prop': 'value'}}
+ >>> d.bar.prop = 'newer'
+ >>> d.bar.prop
+ 'newer'
+
+
+ Values extraction
+
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
+ >>> isinstance(d.bar, list)
+ True
+ >>> from operator import attrgetter
+ >>> map(attrgetter('x'), d.bar)
+ [1, 3]
+ >>> map(attrgetter('y'), d.bar)
+ [2, 4]
+ >>> d = EasyDict()
+ >>> d.keys()
+ []
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
+ >>> d.foo
+ 3
+ >>> d.bar.x
+ 1
+
+ Still like a dict though
+
+ >>> o = EasyDict({'clean':True})
+ >>> o.items()
+ [('clean', True)]
+
+ And like a class
+
+ >>> class Flower(EasyDict):
+ ... power = 1
+ ...
+ >>> f = Flower()
+ >>> f.power
+ 1
+ >>> f = Flower({'height': 12})
+ >>> f.height
+ 12
+ >>> f['power']
+ 1
+ >>> sorted(f.keys())
+ ['height', 'power']
+
+ update and pop items
+ >>> d = EasyDict(a=1, b='2')
+ >>> e = EasyDict(c=3.0, a=9.0)
+ >>> d.update(e)
+ >>> d.c
+ 3.0
+ >>> d['c']
+ 3.0
+ >>> d.get('c')
+ 3.0
+ >>> d.update(a=4, b=4)
+ >>> d.b
+ 4
+ >>> d.pop('a')
+ 4
+ >>> d.a
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'a'
+ """
+
+ def __init__(self, d=None, **kwargs):
+ if d is None:
+ d = {}
+ if kwargs:
+ d.update(**kwargs)
+ for k, v in d.items():
+ setattr(self, k, v)
+ # Class attributes
+ for k in self.__class__.__dict__.keys():
+ if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
+ setattr(self, k, getattr(self, k))
+
+ def __setattr__(self, name, value):
+ if isinstance(value, (list, tuple)):
+ value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
+ value = self.__class__(value)
+ super(EasyDict, self).__setattr__(name, value)
+ super(EasyDict, self).__setitem__(name, value)
+
+ __setitem__ = __setattr__
+
+ def update(self, e=None, **f):
+ d = e or dict()
+ d.update(f)
+ for k in d:
+ setattr(self, k, d[k])
+
+ def pop(self, k, d=None):
+ if hasattr(self, k):
+ delattr(self, k)
+ return super(EasyDict, self).pop(k, d)
+
+
+if __name__ == "__main__":
+ import doctest
+
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/utils/logger.py b/third_party/InternVideo/InternVideo2/multi_modality/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3164ae7251e1f0006173c4f409c0901742048d6
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/utils/logger.py
@@ -0,0 +1,263 @@
+# from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import functools
+import logging
+import os
+import sys
+import time
+import wandb
+from typing import Any, Dict, Union
+
+import torch
+from .distributed import get_rank, is_main_process
+from termcolor import colored
+
+
+def log_dict_to_wandb(log_dict, step, prefix=""):
+ """include a separator `/` at the end of `prefix`"""
+ if not is_main_process():
+ return
+
+ log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()}
+ wandb.log(log_dict, step)
+
+
+def setup_wandb(config):
+ if not (config.wandb.enable and is_main_process()):
+ return
+
+ run = wandb.init(
+ config=config,
+ project=config.wandb.project,
+ entity=config.wandb.entity,
+ name=os.path.basename(config.output_dir),
+ reinit=True
+ )
+ return run
+
+
+def setup_output_folder(save_dir: str, folder_only: bool = False):
+ """Sets up and returns the output file where the logs will be placed
+ based on the configuration passed. Usually "save_dir/logs/log_.txt".
+ If env.log_dir is passed, logs will be directly saved in this folder.
+ Args:
+ folder_only (bool, optional): If folder should be returned and not the file.
+ Defaults to False.
+ Returns:
+ str: folder or file path depending on folder_only flag
+ """
+ log_filename = "train_"
+ log_filename += time.strftime("%Y_%m_%dT%H_%M_%S")
+ log_filename += ".log"
+
+ log_folder = os.path.join(save_dir, "logs")
+
+ if not os.path.exists(log_folder):
+ os.path.mkdirs(log_folder)
+
+ if folder_only:
+ return log_folder
+
+ log_filename = os.path.join(log_folder, log_filename)
+
+ return log_filename
+
+
+def setup_logger(
+ output: str = None,
+ color: bool = True,
+ name: str = "mmf",
+ disable: bool = False,
+ clear_handlers=True,
+ *args,
+ **kwargs,
+):
+ """
+ Initialize the MMF logger and set its verbosity level to "INFO".
+ Outside libraries shouldn't call this in case they have set there
+ own logging handlers and setup. If they do, and don't want to
+ clear handlers, pass clear_handlers options.
+ The initial version of this function was taken from D2 and adapted
+ for MMF.
+ Args:
+ output (str): a file name or a directory to save log.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Default: Saved to file
+ color (bool): If false, won't log colored logs. Default: true
+ name (str): the root module name of this logger. Defaults to "mmf".
+ disable: do not use
+ clear_handlers (bool): If false, won't clear existing handlers.
+ Returns:
+ logging.Logger: a logger
+ """
+ if disable:
+ return None
+ logger = logging.getLogger(name)
+ logger.propagate = False
+
+ logging.captureWarnings(True)
+ warnings_logger = logging.getLogger("py.warnings")
+
+ plain_formatter = logging.Formatter(
+ "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
+ datefmt="%Y-%m-%dT%H:%M:%S",
+ )
+
+ distributed_rank = get_rank()
+ handlers = []
+
+ logging_level = logging.INFO
+ # logging_level = logging.DEBUG
+
+ if distributed_rank == 0:
+ logger.setLevel(logging_level)
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging_level)
+ if color:
+ formatter = ColorfulFormatter(
+ colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
+ datefmt="%Y-%m-%dT%H:%M:%S",
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+ warnings_logger.addHandler(ch)
+ handlers.append(ch)
+
+ # file logging: all workers
+ if output is None:
+ output = setup_output_folder()
+
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "train.log")
+ if distributed_rank > 0:
+ filename = filename + f".rank{distributed_rank}"
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging_level)
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+ warnings_logger.addHandler(fh)
+ handlers.append(fh)
+
+ # Slurm/FB output, only log the main process
+ # save_dir = get_mmf_env(key="save_dir")
+ if "train.log" not in filename and distributed_rank == 0:
+ filename = os.path.join(output, "train.log")
+ sh = logging.StreamHandler(_cached_log_stream(filename))
+ sh.setLevel(logging_level)
+ sh.setFormatter(plain_formatter)
+ logger.addHandler(sh)
+ warnings_logger.addHandler(sh)
+ handlers.append(sh)
+
+ logger.info(f"Logging to: {filename}")
+
+ # Remove existing handlers to add MMF specific handlers
+ if clear_handlers:
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+ # Now, add our handlers.
+ logging.basicConfig(level=logging_level, handlers=handlers)
+
+ return logger
+
+
+def setup_very_basic_config(color=True):
+ plain_formatter = logging.Formatter(
+ "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
+ datefmt="%Y-%m-%dT%H:%M:%S",
+ )
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.INFO)
+ if color:
+ formatter = ColorfulFormatter(
+ colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
+ datefmt="%Y-%m-%dT%H:%M:%S",
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ # Setup a minimal configuration for logging in case something tries to
+ # log a message even before logging is setup by MMF.
+ logging.basicConfig(level=logging.INFO, handlers=[ch])
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+ return open(filename, "a")
+
+
+# ColorfulFormatter is adopted from Detectron2 and adapted for MMF
+class ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ log = super().formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+class TensorboardLogger:
+ def __init__(self, log_folder="./logs", iteration=0):
+ # This would handle warning of missing tensorboard
+ from torch.utils.tensorboard import SummaryWriter
+
+ self.summary_writer = None
+ self._is_master = is_main_process()
+ # self.timer = Timer()
+ self.log_folder = log_folder
+
+ if self._is_master:
+ # current_time = self.timer.get_time_hhmmss(None, format=self.time_format)
+ current_time = time.strftime("%Y-%m-%dT%H:%M:%S")
+ # self.timer.get_time_hhmmss(None, format=self.time_format)
+ tensorboard_folder = os.path.join(
+ self.log_folder, f"tensorboard_{current_time}"
+ )
+ self.summary_writer = SummaryWriter(tensorboard_folder)
+
+ def __del__(self):
+ if getattr(self, "summary_writer", None) is not None:
+ self.summary_writer.close()
+
+ def _should_log_tensorboard(self):
+ if self.summary_writer is None or not self._is_master:
+ return False
+ else:
+ return True
+
+ def add_scalar(self, key, value, iteration):
+ if not self._should_log_tensorboard():
+ return
+
+ self.summary_writer.add_scalar(key, value, iteration)
+
+ def add_scalars(self, scalar_dict, iteration):
+ if not self._should_log_tensorboard():
+ return
+
+ for key, val in scalar_dict.items():
+ self.summary_writer.add_scalar(key, val, iteration)
+
+ def add_histogram_for_model(self, model, iteration):
+ if not self._should_log_tensorboard():
+ return
+
+ for name, param in model.named_parameters():
+ np_param = param.clone().cpu().data.numpy()
+ self.summary_writer.add_histogram(name, np_param, iteration)
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/utils/optimizer.py b/third_party/InternVideo/InternVideo2/multi_modality/utils/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..264d31a1d032e424251a66b78c8220330de01d13
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/utils/optimizer.py
@@ -0,0 +1,142 @@
+""" Optimizer Factory w/ Custom Weight Decay
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import re
+import torch
+from torch import optim as optim
+from utils.distributed import is_main_process
+import logging
+logger = logging.getLogger(__name__)
+try:
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
+ has_apex = True
+except ImportError:
+ has_apex = False
+
+
+def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
+ named_param_tuples = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
+ named_param_tuples.append([name, param, 0])
+ elif name in no_decay_list:
+ named_param_tuples.append([name, param, 0])
+ else:
+ named_param_tuples.append([name, param, weight_decay])
+ return named_param_tuples
+
+
+def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
+ """use lr=diff_lr for modules named found in diff_lr_names,
+ otherwise use lr=default_lr
+
+ Args:
+ named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module
+ diff_lr_names: List(str)
+ diff_lr: float
+ default_lr: float
+ Returns:
+ named_param_tuples_with_lr: List([name, param, weight_decay, lr])
+ """
+ named_param_tuples_with_lr = []
+ logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
+ for name, p, wd in named_param_tuples_or_model:
+ use_diff_lr = False
+ for diff_name in diff_lr_names:
+ # if diff_name in name:
+ if re.search(diff_name, name) is not None:
+ logger.info(f"param {name} use different_lr: {diff_lr}")
+ use_diff_lr = True
+ break
+
+ named_param_tuples_with_lr.append(
+ [name, p, wd, diff_lr if use_diff_lr else default_lr]
+ )
+
+ if is_main_process():
+ for name, _, wd, diff_lr in named_param_tuples_with_lr:
+ logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}")
+
+ return named_param_tuples_with_lr
+
+
+def create_optimizer_params_group(named_param_tuples_with_lr):
+ """named_param_tuples_with_lr: List([name, param, weight_decay, lr])"""
+ group = {}
+ for name, p, wd, lr in named_param_tuples_with_lr:
+ if wd not in group:
+ group[wd] = {}
+ if lr not in group[wd]:
+ group[wd][lr] = []
+ group[wd][lr].append(p)
+
+ optimizer_params_group = []
+ for wd, lr_groups in group.items():
+ for lr, p in lr_groups.items():
+ optimizer_params_group.append(dict(
+ params=p,
+ weight_decay=wd,
+ lr=lr
+ ))
+ logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
+ return optimizer_params_group
+
+
+def create_optimizer(args, model, filter_bias_and_bn=True, return_group=False):
+ opt_lower = args.opt.lower()
+ weight_decay = args.weight_decay
+ # check for modules that requires different lr
+ if hasattr(args, "different_lr") and args.different_lr.enable:
+ diff_lr_module_names = args.different_lr.module_names
+ diff_lr = args.different_lr.lr
+ else:
+ diff_lr_module_names = []
+ diff_lr = None
+
+ no_decay = {}
+ if hasattr(model, 'no_weight_decay'):
+ no_decay = model.no_weight_decay()
+
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+ if hasattr(model.module, 'no_weight_decay'):
+ no_decay = model.module.no_weight_decay()
+ no_decay = {"module." + k for k in no_decay}
+
+ named_param_tuples = add_weight_decay(
+ model, weight_decay, no_decay, filter_bias_and_bn)
+ named_param_tuples = add_different_lr(
+ named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
+ parameters = create_optimizer_params_group(named_param_tuples)
+
+ if return_group:
+ return parameters
+
+ if 'fused' in opt_lower:
+ assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
+
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
+ if hasattr(args, 'opt_eps') and args.opt_eps is not None:
+ opt_args['eps'] = args.opt_eps
+ if hasattr(args, 'opt_betas') and args.opt_betas is not None:
+ opt_args['betas'] = args.opt_betas
+ if hasattr(args, 'opt_args') and args.opt_args is not None:
+ opt_args.update(args.opt_args)
+
+ opt_split = opt_lower.split('_')
+ opt_lower = opt_split[-1]
+ if opt_lower == 'sgd' or opt_lower == 'nesterov':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'momentum':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+ elif opt_lower == 'adam':
+ optimizer = optim.Adam(parameters, **opt_args)
+ elif opt_lower == 'adamw':
+ optimizer = optim.AdamW(parameters, **opt_args)
+ else:
+ assert False and "Invalid optimizer"
+ raise ValueError
+ return optimizer
diff --git a/third_party/InternVideo/InternVideo2/multi_modality/utils/scheduler.py b/third_party/InternVideo/InternVideo2/multi_modality/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a29f845c34f560a2525275ec91f79db8f71025f3
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/multi_modality/utils/scheduler.py
@@ -0,0 +1,60 @@
+""" Scheduler Factory
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch.optim import Optimizer
+import math
+from torch.optim.lr_scheduler import LambdaLR
+
+
+def create_scheduler(args, optimizer):
+ lr_scheduler = None
+ if args.sched == 'cosine':
+ lr_scheduler = get_cosine_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=args.num_warmup_steps,
+ num_training_steps=args.num_training_steps,
+ num_cycles=0.5,
+ min_lr_multi=args.min_lr_multi,
+ last_epoch=args.get('last_epoch', -1)
+ )
+ else:
+ raise NotImplementedError(args.sched)
+
+ return lr_scheduler
+
+
+def get_cosine_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
+ num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1
+):
+ """
+ Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py
+
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
+ following a half-cosine).
+ min_lr_multi (`float`, *optional*, defaults to 0):
+ The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps)))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/DATASET.md b/third_party/InternVideo/InternVideo2/single_modality/DATASET.md
new file mode 100644
index 0000000000000000000000000000000000000000..e12e01eba838b5f672a2b31782f0f7158309b8bf
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/DATASET.md
@@ -0,0 +1,11 @@
+# Dataset Preparation
+
+We follow [UniFormerV2](https://github.com/OpenGVLab/UniFormerV2/) to prepare the datasets. All the data files can be found [here](https://drive.google.com/drive/folders/17VB-XdF3Kfr9ORmnGyXCxTMs86n0L4QL?usp=sharing), including:
+- [Kinetics-400/600/700](https://www.deepmind.com/open-source/kinetics)
+- [Kinetics-710](https://github.com/OpenGVLab/UniFormerV2/blob/main/DATASET.md)
+- [Moments in Time V1](http://moments.csail.mit.edu/)
+- [Something-Something V1&V2](https://developer.qualcomm.com/software/ai-datasets/something-something)
+- [ANet](http://activity-net.org/)
+- [HACS](http://hacs.csail.mit.edu/)
+
+For K-Mash, we merge the above videos and other self-collected videos.
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/INSTALL.md b/third_party/InternVideo/InternVideo2/single_modality/INSTALL.md
new file mode 100644
index 0000000000000000000000000000000000000000..2055529af121c1980952a7bf2da5b09a8ad63664
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/INSTALL.md
@@ -0,0 +1,17 @@
+# Installation
+
+## Requirements
+
+We mainly follow [UMT](https://github.com/OpenGVLab/Unmasked_Teacher) to prepare the enviroment.
+
+```shell
+pip install -r requirements.txt
+```
+
+We follow UMT to set `--epochs 201` to avoid the potential interrupt in the last epoch.
+
+> We observed accidental interrupt in the last epoch when conducted the pre-training experiments on V100 GPUs (PyTorch 1.6.0). This interrupt is caused by the scheduler of learning rate. We naively set --epochs 801 to walk away from issue.
+
+## Note
+
+To run InternVideo2 pretraining, you have to prepare the weights of the **[InternVL-6B visual encoder](https://huggingface.co/OpenGVLab/InternVL/blob/main/internvl_c_13b_224px.pth)** and **[VideoMAEv2-g](https://github.com/OpenGVLab/VideoMAEv2/blob/master/docs/MODEL_ZOO.md)**, and set the `your_model_path` in [internvl_clip_vision.py](./models/internvl_clip_vision.py) and [videomae.py](./models/videomae.py).
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/MODEL_ZOO.md b/third_party/InternVideo/InternVideo2/single_modality/MODEL_ZOO.md
new file mode 100644
index 0000000000000000000000000000000000000000..f64f022d4a4754ee2926010ed5afc54eff17c1bb
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/MODEL_ZOO.md
@@ -0,0 +1,97 @@
+# Model Zoo
+
+## Note
+
+- For all the pretraining and finetuning, we adopt spaese/uniform sampling.
+- `#Frame` $=$ `#input_frame` $\times$ `#crop` $\times$ `#clip`
+- `#input_frame` means how many frames are input for model per inference
+- `#crop` means spatial crops (e.g., 3 for left/right/center)
+- `#clip` means temporal clips (e.g., 4 means repeted sampling four clips with different start indices)
+
+## Pretraining
+
+| Model | Setting | Model | Shell |
+| -------- | ----------- | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash-1.1M 300e | TBD | [run.sh](./scripts/pretraining/1B_pt.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash-2M 300e | TBD | [run.sh](./scripts/pretraining/6B_pt.sh) |
+
+
+## Finetuning
+
+### K710
+
+| Model | Setting | #Frame | Top-1 | Model | Shell |
+| -------- | ------- | -------- | ------ | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT | 8x3x4 | 87.6 | TBD | [run.sh](./scripts/finetuning/full_tuning/k710/1B_ft_k710_f8.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT | 8x3x4 | 88.1 | TBD | [run.sh](./scripts/finetuning/full_tuning/k710/6B_ft_k710_f8.sh) |
+
+
+### K400
+
+| Model | Setting | #Frame | Top-1 | Model | Shell |
+| -------- | ------------- | -------- | ------ | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 8x3x4 | 91.3 | TBD | [run.sh](./scripts/finetuning/full_tuning/k400/1B_ft_k710_ft_k400_f8.sh) |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 16x3x4 | 91.6 | TBD | [run.sh](./scripts/finetuning/full_tuning/k400/1B_ft_k710_ft_k400_f16.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 8x3x4 | 91.9 | TBD | [run.sh](./scripts/finetuning/full_tuning/k400/6B_ft_k710_ft_k400_f8.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 16x3x4 | 92.1 | TBD | [run.sh](./scripts/finetuning/full_tuning/k400/6B_ft_k710_ft_k400_f16.sh) |
+
+
+### K600
+
+| Model | Setting | #Frame | Top-1 | Model | Shell |
+| -------- | ------------- | -------- | ------ | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 8x3x4 | 91.4 | TBD | [run.sh](./scripts/finetuning/full_tuning/k600/1B_ft_k710_ft_k600_f8.sh) |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 16x3x4 | 91.6 | TBD | [run.sh](./scripts/finetuning/full_tuning/k600/1B_ft_k710_ft_k600_f16.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 8x3x4 | 91.7 | TBD | [run.sh](./scripts/finetuning/full_tuning/k600/6B_ft_k710_ft_k600_f8.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 16x3x4 | 91.9 | TBD | [run.sh](./scripts/finetuning/full_tuning/k600/6B_ft_k710_ft_k600_f16.sh) |
+
+
+
+### K700
+
+| Model | Setting | #Frame | Top-1 | Model | Shell |
+| -------- | ------------- | -------- | ------ | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 8x3x4 | 85.0 | TBD | [run.sh](./scripts/finetuning/full_tuning/k700/1B_ft_k710_ft_k700_f8.sh) |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 16x3x4 | 85.4 | TBD | [run.sh](./scripts/finetuning/full_tuning/k700/1B_ft_k710_ft_k700_f16.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 8x3x4 | 85.7 | TBD | [run.sh](./scripts/finetuning/full_tuning/k700/6B_ft_k710_ft_k700_f8.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 16x3x4 | 85.9 | TBD | [run.sh](./scripts/finetuning/full_tuning/k700/6B_ft_k710_ft_k700_f16.sh) |
+
+
+### MiT V1
+
+| Model | Setting | #Frame | Top-1 | Model | Shell |
+| ------------- | -------------------- | -------- | ------ | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT + K400 FT | 8x3x4 | 50.8 | TBD | [run.sh](./scripts/finetuning/full_tuning/mit/1B_ft_k710_ft_k400_ft_mit_f8.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT + K400 FT | 8x3x4 | 51.0 | TBD | [run.sh](./scripts/finetuning/full_tuning/mit/6B_ft_k710_ft_k400_ft_mit_f8.sh) |
+| $\text{InternVideo2}_{s1}$-6B 336↑ | K-Mash PT + K710 FT + K400 FT | 8x3x4 | 51.2 | TBD | [run.sh](./scripts/finetuning/full_tuning/mit/6B_ft_k710_ft_k400_ft_mit_f8_res224to336.sh) |
+
+
+### SthSth V1
+
+| Model | Setting | #Frame | Top-1 | Model | Shell |
+| -------- | ----------- | -------- | ------ | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT | 8x3x4 | 68.5 | TBD | [run.sh](./scripts/finetuning/full_tuning/ssv1/1B_ft_ssv1_f8.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT | 8x3x4 | 69.7 | TBD | [run.sh](./scripts/finetuning/full_tuning/ssv1/6B_ft_ssv1_f8.sh) |
+
+
+### SthSth V2
+
+| Model | Setting | #Frame | Top-1 | Model | Shell |
+| -------- | ----------- | -------- | ------ | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-1B | K-Mash PT | 8x3x4 | 77.1 | TBD | [run.sh](./scripts/finetuning/full_tuning/ssv1/1B_ft_ssv1_f8.sh) |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT | 8x3x4 | 77.5 | TBD | [run.sh](./scripts/finetuning/full_tuning/ssv1/6B_ft_ssv1_f8.sh) |
+
+
+
+### ANet
+
+| Model | Setting | #Frame | Top-1 | mAP | Model | Shell |
+| ------------- | -------------------- | -------- | ------ | ------ | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT + K400 FT | 8x3x4 | 95.9 | 98.2 | TBD | [run.sh](./scripts/finetuning/full_tuning/anet/6B_ft_k710_ft_k400_ap_anet_f8.sh) |
+
+
+### HACS
+
+| Model | Setting | #Frame | Top-1 | mAP | Model | Shell |
+| ------------- | -------------------- | -------- | ------ | ------ | ------ | ------ |
+| $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT + K400 FT | 8x3x4 | 97.0 | 98.8 | TBD | [run.sh](./scripts/finetuning/full_tuning/hacs/6B_ft_k710_ft_k400_ap_hacs_f8.sh) |
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/README.md b/third_party/InternVideo/InternVideo2/single_modality/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a3527b13aadd30decdb1dc3838a8a02e2b2d0d34
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/README.md
@@ -0,0 +1,44 @@
+# Single-modality
+
+## Installation
+
+Please follow the installation instructions in [INSTALL](./INSTALL.md).
+
+## Datasets
+
+You can find the dataset instructions in [DATASET](./DATASET.md).
+
+## Model ZOO
+
+You can find all the models and the scripts in [MODEL_ZOO](./MODEL_ZOO.md).
+
+## Pre-Training
+
+We use [InternVL](https://github.com/OpenGVLab/InternVL/) and [VideoMAEv2](https://github.com/OpenGVLab/VideoMAEv2) pretrained models as teachers by default
+
+For training, you can simply run the pretraining scripts in `scripts/pretraining` as follows:
+```shell
+bash ./scripts/pretraining/1B_pt.sh
+```
+
+:warning: **Notes:**
+1. Chage `DATA_PATH` to your data path before running the scripts.
+2. `--sampling_rate` is set to 1 for **sprase sampling**.
+3. The latest checkpoint will be automatically saved while training, thus we use a large `--save_ckpt_freq`.
+4. For InternVideo2-1B and 6B, we use InternVL-C-13B and VideoMAEv2-g.
+
+
+## Finetuning
+
+For finetuning, you can simply run the pretraining scripts in `scripts/finetuning` as follows:
+```shell
+bash ./scripts/finetuning/full_tuning/k400/1B_ft_k710_ft_k400_f8.sh
+```
+
+:warning: **Notes:**
+1. Chage `DATA_PATH` And `PREFIX` to your data path before running the scripts.
+2. Chage `MODEL_PATH` to your model path.
+3. Set `--use_checkpoint` and `--checkpoint_num` to save GPU memory.
+4. The best checkpoint will be automatically evaluated with `--test_best`.
+5. Set `--test_num_segment` and `--test_num_crop` for different evaluation strategies.
+6. To only run evaluation, just set `--eval`.
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/__init__.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d27e0d69f7b0a31fd847237db15580bea60245b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/__init__.py
@@ -0,0 +1 @@
+from .build import build_dataset, build_pretraining_dataset, build_multi_pretraining_dataset
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/anet.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/anet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7b21083b377a4bdb8cb3d111c5866a8d1fe052b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/anet.py
@@ -0,0 +1,414 @@
+import os
+import os
+import io
+import random
+import numpy as np
+from numpy.lib.function_base import disp
+import torch
+from torchvision import transforms
+import warnings
+from decord import VideoReader, cpu
+from torch.utils.data import Dataset
+from .random_erasing import RandomErasing
+from .video_transforms import (
+ Compose, Resize, CenterCrop, Normalize,
+ create_random_augment, random_short_side_scale_jitter,
+ random_crop, random_resized_crop_with_shift, random_resized_crop,
+ horizontal_flip, random_short_side_scale_jitter, uniform_crop,
+)
+from .volume_transforms import ClipToTensor
+
+try:
+ from petrel_client.client import Client
+ has_client = True
+except ImportError:
+ has_client = False
+
+class ANetDataset(Dataset):
+ """Load your own video classification dataset."""
+
+ def __init__(self, anno_path, prefix='', split=' ', mode='train', clip_len=8,
+ frame_sample_rate=2, crop_size=224, short_side_size=256,
+ new_height=256, new_width=340, keep_aspect_ratio=True,
+ num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3,
+ args=None):
+ self.anno_path = anno_path
+ self.prefix = prefix
+ self.split = split
+ self.mode = mode
+ self.clip_len = clip_len
+ self.frame_sample_rate = frame_sample_rate
+ self.crop_size = crop_size
+ self.short_side_size = short_side_size
+ self.new_height = new_height
+ self.new_width = new_width
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.num_segment = num_segment
+ self.test_num_segment = test_num_segment
+ self.num_crop = num_crop
+ self.test_num_crop = test_num_crop
+ self.args = args
+ self.aug = False
+ self.rand_erase = False
+ assert num_segment == 1
+ if self.mode in ['train']:
+ self.aug = True
+ if self.args.reprob > 0:
+ self.rand_erase = True
+ if VideoReader is None:
+ raise ImportError("Unable to import `decord` which is required to read videos.")
+
+ import pandas as pd
+ cleaned = pd.read_csv(self.anno_path, header=None, delimiter=self.split)
+ self.dataset_samples = list(cleaned.values[:, 0])
+ self.total_time = list(cleaned.values[:, 1])
+ self.start_time = list(cleaned.values[:, 2])
+ self.end_time = list(cleaned.values[:, 3])
+ self.label_array = list(cleaned.values[:, 4])
+
+ self.client = None
+ if has_client:
+ self.client = Client('~/petreloss.conf')
+
+ if (mode == 'train'):
+ pass
+
+ elif (mode == 'validation'):
+ self.data_transform = Compose([
+ Resize(self.short_side_size, interpolation='bilinear'),
+ CenterCrop(size=(self.crop_size, self.crop_size)),
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ elif mode == 'test':
+ self.data_resize = Compose([
+ Resize(size=(short_side_size), interpolation='bilinear')
+ ])
+ self.data_transform = Compose([
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.test_seg = []
+ self.test_dataset = []
+ self.test_total_time = []
+ self.test_start_time = []
+ self.test_end_time = []
+ self.test_label_array = []
+ for ck in range(self.test_num_segment):
+ for cp in range(self.test_num_crop):
+ for idx in range(len(self.label_array)):
+ self.test_total_time.append(self.total_time[idx])
+ self.test_start_time.append(self.start_time[idx])
+ self.test_end_time.append(self.end_time[idx])
+ sample_label = self.label_array[idx]
+ self.test_label_array.append(sample_label)
+ self.test_dataset.append(self.dataset_samples[idx])
+ self.test_seg.append((ck, cp))
+
+ def __getitem__(self, index):
+ if self.mode == 'train':
+ args = self.args
+ sample = self.dataset_samples[index]
+ total_time, start_time, end_time = self.total_time[index], self.start_time[index], self.end_time[index]
+ buffer = self.loadvideo_decord(sample, total_time, start_time, end_time, chunk_nb=-1) # T H W C
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during training".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ total_time, start_time, end_time = self.total_time[index], self.start_time[index], self.end_time[index]
+ buffer = self.loadvideo_decord(sample, total_time, start_time, end_time, chunk_nb=-1)
+
+ if args.num_sample > 1:
+ frame_list = []
+ label_list = []
+ index_list = []
+ for _ in range(args.num_sample):
+ new_frames = self._aug_frame(buffer, args)
+ label = self.label_array[index]
+ frame_list.append(new_frames)
+ label_list.append(label)
+ index_list.append(index)
+ return frame_list, label_list, index_list, {}
+ else:
+ buffer = self._aug_frame(buffer, args)
+
+ return buffer, self.label_array[index], index, {}
+
+ elif self.mode == 'validation':
+ sample = self.dataset_samples[index]
+ total_time, start_time, end_time = self.total_time[index], self.start_time[index], self.end_time[index]
+ buffer = self.loadvideo_decord(sample, total_time, start_time, end_time, chunk_nb=0)
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during validation".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=0)
+ buffer = self.data_transform(buffer)
+ return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
+
+ elif self.mode == 'test':
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ total_time, start_time, end_time = self.test_total_time[index], self.test_start_time[index], self.test_end_time[index]
+ buffer = self.loadvideo_decord(sample, total_time, start_time, end_time, chunk_nb=chunk_nb)
+
+ while len(buffer) == 0:
+ warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
+ str(self.test_dataset[index]), chunk_nb, split_nb))
+ index = np.random.randint(self.__len__())
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=chunk_nb)
+
+ buffer = self.data_resize(buffer)
+ if isinstance(buffer, list):
+ buffer = np.stack(buffer, 0)
+ if self.test_num_crop == 1:
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) / 2
+ spatial_start = int(spatial_step)
+ else:
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
+ / (self.test_num_crop - 1)
+ spatial_start = int(split_nb * spatial_step)
+ if buffer.shape[1] >= buffer.shape[2]:
+ buffer = buffer[:, spatial_start:spatial_start + self.short_side_size, :, :]
+ else:
+ buffer = buffer[:, :, spatial_start:spatial_start + self.short_side_size, :]
+
+ buffer = self.data_transform(buffer)
+ return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
+ chunk_nb, split_nb
+ else:
+ raise NameError('mode {} unkown'.format(self.mode))
+
+ def _aug_frame(
+ self,
+ buffer,
+ args,
+ ):
+
+ aug_transform = create_random_augment(
+ input_size=(self.crop_size, self.crop_size),
+ auto_augment=args.aa,
+ interpolation=args.train_interpolation,
+ )
+
+ buffer = [
+ transforms.ToPILImage()(frame) for frame in buffer
+ ]
+
+ buffer = aug_transform(buffer)
+
+ buffer = [transforms.ToTensor()(img) for img in buffer]
+ buffer = torch.stack(buffer) # T C H W
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
+
+ # T H W C
+ buffer = tensor_normalize(
+ buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+ )
+ # T H W C -> C T H W.
+ buffer = buffer.permute(3, 0, 1, 2)
+ # Perform data augmentation.
+ scl, asp = (
+ [0.08, 1.0],
+ [0.75, 1.3333],
+ )
+
+ buffer = spatial_sampling(
+ buffer,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=self.crop_size,
+ random_horizontal_flip=False if args.data_set == 'SSV2' else True ,
+ inverse_uniform_sampling=False,
+ aspect_ratio=asp,
+ scale=scl,
+ motion_shift=False
+ )
+
+ if self.rand_erase:
+ erase_transform = RandomErasing(
+ args.reprob,
+ mode=args.remode,
+ max_count=args.recount,
+ num_splits=args.recount,
+ device="cpu",
+ )
+ buffer = buffer.permute(1, 0, 2, 3)
+ buffer = erase_transform(buffer)
+ buffer = buffer.permute(1, 0, 2, 3)
+
+ return buffer
+
+ def _get_seq_frames(self, video_size, start_index, num_frames, clip_idx=-1):
+ seg_size = max(0., float(video_size - 1) / num_frames)
+ max_frame = int(video_size) - 1
+ seq = []
+ # index from 1, must add 1
+ if clip_idx == -1:
+ for i in range(num_frames):
+ start = int(np.round(seg_size * i))
+ end = int(np.round(seg_size * (i + 1)))
+ idx = min(random.randint(start, end), max_frame)
+ seq.append(idx)
+ else:
+ num_segment = 1
+ if self.mode == 'test':
+ num_segment = self.test_num_segment
+ duration = seg_size / (num_segment + 1)
+ for i in range(num_frames):
+ start = int(np.round(seg_size * i))
+ frame_index = start + int(duration * (clip_idx + 1))
+ idx = min(frame_index, max_frame)
+ seq.append(idx)
+ seq = np.array(seq)
+ return seq + start_index
+
+ def loadvideo_decord(self, sample, total_time, start_time, end_time, chunk_nb=0):
+ """Load video content using Decord"""
+ fname = sample
+ fname = os.path.join(self.prefix, fname)
+
+ try:
+ if self.keep_aspect_ratio:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
+ else:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ width=self.new_width,
+ height=self.new_height,
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, width=self.new_width, height=self.new_height,
+ num_threads=1, ctx=cpu(0))
+
+ duration = len(vr)
+ start_index = 0
+
+ if total_time!= -1 and start_time != -1 and end_time != -1:
+ fps = duration / total_time
+ duration = int(fps * (end_time - start_time))
+ start_index = int(fps * start_time)
+
+ all_index = self._get_seq_frames(duration, start_index, self.clip_len, clip_idx=chunk_nb)
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+ except:
+ print("video cannot be loaded by decord: ", fname)
+ return []
+
+ def __len__(self):
+ if self.mode != 'test':
+ return len(self.dataset_samples)
+ else:
+ return len(self.test_dataset)
+
+
+def spatial_sampling(
+ frames,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=224,
+ random_horizontal_flip=True,
+ inverse_uniform_sampling=False,
+ aspect_ratio=None,
+ scale=None,
+ motion_shift=False,
+):
+ """
+ Perform spatial sampling on the given video frames. If spatial_idx is
+ -1, perform random scale, random crop, and random flip on the given
+ frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
+ with the given spatial_idx.
+ Args:
+ frames (tensor): frames of images sampled from the video. The
+ dimension is `num frames` x `height` x `width` x `channel`.
+ spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
+ or 2, perform left, center, right crop if width is larger than
+ height, and perform top, center, buttom crop if height is larger
+ than width.
+ min_scale (int): the minimal size of scaling.
+ max_scale (int): the maximal size of scaling.
+ crop_size (int): the size of height and width used to crop the
+ frames.
+ inverse_uniform_sampling (bool): if True, sample uniformly in
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
+ scale. If False, take a uniform sample from [min_scale,
+ max_scale].
+ aspect_ratio (list): Aspect ratio range for resizing.
+ scale (list): Scale range for resizing.
+ motion_shift (bool): Whether to apply motion shift for resizing.
+ Returns:
+ frames (tensor): spatially sampled frames.
+ """
+ assert spatial_idx in [-1, 0, 1, 2]
+ if spatial_idx == -1:
+ if aspect_ratio is None and scale is None:
+ frames, _ = random_short_side_scale_jitter(
+ images=frames,
+ min_size=min_scale,
+ max_size=max_scale,
+ inverse_uniform_sampling=inverse_uniform_sampling,
+ )
+ frames, _ = random_crop(frames, crop_size)
+ else:
+ transform_func = (
+ random_resized_crop_with_shift
+ if motion_shift
+ else random_resized_crop
+ )
+ frames = transform_func(
+ images=frames,
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=scale,
+ ratio=aspect_ratio,
+ )
+ if random_horizontal_flip:
+ frames, _ = horizontal_flip(0.5, frames)
+ else:
+ # The testing is deterministic and no jitter should be performed.
+ # min_scale, max_scale, and crop_size are expect to be the same.
+ assert len({min_scale, max_scale, crop_size}) == 1
+ frames, _ = random_short_side_scale_jitter(
+ frames, min_scale, max_scale
+ )
+ frames, _ = uniform_crop(frames, crop_size, spatial_idx)
+ return frames
+
+
+def tensor_normalize(tensor, mean, std):
+ """
+ Normalize a given tensor by subtracting the mean and dividing the std.
+ Args:
+ tensor (tensor): tensor to normalize.
+ mean (tensor or list): mean value to subtract.
+ std (tensor or list): std to divide.
+ """
+ if tensor.dtype == torch.uint8:
+ tensor = tensor.float()
+ tensor = tensor / 255.0
+ if type(mean) == list:
+ mean = torch.tensor(mean)
+ if type(std) == list:
+ std = torch.tensor(std)
+ tensor = tensor - mean
+ tensor = tensor / std
+ return tensor
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/build.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e277bb42c54b540c555f6fe3eb0efef0f7c242a
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/build.py
@@ -0,0 +1,313 @@
+import os
+from torchvision import transforms
+from .transforms import *
+from .masking_generator import TubeMaskingGenerator, RandomMaskingGenerator
+from .mae import VideoMAE
+from .mae_multi import VideoMAE_multi
+from .kinetics import VideoClsDataset
+from .kinetics_sparse import VideoClsDataset_sparse
+from .anet import ANetDataset
+from .ssv2 import SSVideoClsDataset, SSRawFrameClsDataset
+from .hmdb import HMDBVideoClsDataset, HMDBRawFrameClsDataset
+
+
+class DataAugmentationForVideoMAE(object):
+ def __init__(self, args):
+ self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
+ self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
+ normalize = GroupNormalize(self.input_mean, self.input_std)
+ self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])
+ if args.color_jitter > 0:
+ self.transform = transforms.Compose([
+ self.train_augmentation,
+ GroupColorJitter(args.color_jitter),
+ GroupRandomHorizontalFlip(flip=args.flip),
+ Stack(roll=False),
+ ToTorchFormatTensor(div=True),
+ normalize,
+ ])
+ else:
+ self.transform = transforms.Compose([
+ self.train_augmentation,
+ GroupRandomHorizontalFlip(flip=args.flip),
+ Stack(roll=False),
+ ToTorchFormatTensor(div=True),
+ normalize,
+ ])
+ if args.mask_type == 'tube':
+ self.masked_position_generator = TubeMaskingGenerator(
+ args.window_size, args.mask_ratio
+ )
+ elif args.mask_type == 'random':
+ self.masked_position_generator = RandomMaskingGenerator(
+ args.window_size, args.mask_ratio
+ )
+ elif args.mask_type in 'attention':
+ self.masked_position_generator = None
+
+ def __call__(self, images):
+ process_data, _ = self.transform(images)
+ if self.masked_position_generator is None:
+ return process_data, -1
+ else:
+ return process_data, self.masked_position_generator()
+
+ def __repr__(self):
+ repr = "(DataAugmentationForVideoMAE,\n"
+ repr += " transform = %s,\n" % str(self.transform)
+ repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
+ repr += ")"
+ return repr
+
+
+def build_pretraining_dataset(args):
+ transform = DataAugmentationForVideoMAE(args)
+ dataset = VideoMAE(
+ root=None,
+ setting=args.data_path,
+ prefix=args.prefix,
+ split=args.split,
+ video_ext='mp4',
+ is_color=True,
+ modality='rgb',
+ num_segments=args.num_segments,
+ new_length=args.num_frames,
+ new_step=args.sampling_rate,
+ transform=transform,
+ temporal_jitter=False,
+ video_loader=True,
+ use_decord=args.use_decord,
+ lazy_init=False,
+ num_sample=args.num_sample)
+ print("Data Aug = %s" % str(transform))
+ return dataset
+
+
+def build_multi_pretraining_dataset(args):
+ origianl_flip = args.flip
+ transform = DataAugmentationForVideoMAE(args)
+ args.flip = False
+ transform_ssv2 = DataAugmentationForVideoMAE(args)
+ args.flip = origianl_flip
+
+ dataset = VideoMAE_multi(
+ root=None,
+ setting=args.data_path,
+ prefix=args.prefix,
+ split=args.split,
+ is_color=True,
+ modality='rgb',
+ num_segments=args.num_segments,
+ new_length=args.num_frames,
+ new_step=args.sampling_rate,
+ transform=transform,
+ transform_ssv2=transform_ssv2,
+ temporal_jitter=False,
+ video_loader=True,
+ use_decord=args.use_decord,
+ lazy_init=False,
+ num_sample=args.num_sample)
+ print("Data Aug = %s" % str(transform))
+ print("Data Aug for SSV2 = %s" % str(transform_ssv2))
+ return dataset
+
+
+def build_dataset(is_train, test_mode, args):
+ print(f'Use Dataset: {args.data_set}')
+ if args.data_set in [
+ 'Kinetics',
+ 'Kinetics_sparse',
+ 'mitv1_sparse'
+ ]:
+ mode = None
+ anno_path = None
+ if is_train is True:
+ mode = 'train'
+ anno_path = os.path.join(args.data_path, 'train.csv')
+ elif test_mode is True:
+ mode = 'test'
+ anno_path = os.path.join(args.data_path, 'test.csv')
+ else:
+ mode = 'validation'
+ anno_path = os.path.join(args.data_path, 'val.csv')
+
+ if 'sparse' in args.data_set:
+ func = VideoClsDataset_sparse
+ else:
+ func = VideoClsDataset
+
+ dataset = func(
+ anno_path=anno_path,
+ prefix=args.prefix,
+ split=args.split,
+ mode=mode,
+ clip_len=args.num_frames,
+ frame_sample_rate=args.sampling_rate,
+ num_segment=1,
+ test_num_segment=args.test_num_segment,
+ test_num_crop=args.test_num_crop,
+ num_crop=1 if not test_mode else 3,
+ keep_aspect_ratio=True,
+ crop_size=args.input_size,
+ short_side_size=args.short_side_size,
+ new_height=256,
+ new_width=320,
+ args=args)
+
+ nb_classes = args.nb_classes
+
+ elif args.data_set == 'SSV2':
+ mode = None
+ anno_path = None
+ if is_train is True:
+ mode = 'train'
+ anno_path = os.path.join(args.data_path, 'train.csv')
+ elif test_mode is True:
+ mode = 'test'
+ anno_path = os.path.join(args.data_path, 'test.csv')
+ else:
+ mode = 'validation'
+ anno_path = os.path.join(args.data_path, 'val.csv')
+
+ if args.use_decord:
+ func = SSVideoClsDataset
+ else:
+ func = SSRawFrameClsDataset
+
+ dataset = func(
+ anno_path=anno_path,
+ prefix=args.prefix,
+ split=args.split,
+ mode=mode,
+ clip_len=1,
+ num_segment=args.num_frames,
+ test_num_segment=args.test_num_segment,
+ test_num_crop=args.test_num_crop,
+ num_crop=1 if not test_mode else 3,
+ keep_aspect_ratio=True,
+ crop_size=args.input_size,
+ short_side_size=args.short_side_size,
+ new_height=256,
+ new_width=320,
+ filename_tmpl=args.filename_tmpl,
+ args=args)
+ nb_classes = 174
+
+ elif args.data_set == 'UCF101':
+ mode = None
+ anno_path = None
+ if is_train is True:
+ mode = 'train'
+ anno_path = os.path.join(args.data_path, 'train.csv')
+ elif test_mode is True:
+ mode = 'test'
+ anno_path = os.path.join(args.data_path, 'test.csv')
+ else:
+ mode = 'validation'
+ anno_path = os.path.join(args.data_path, 'val.csv')
+
+ dataset = VideoClsDataset(
+ anno_path=anno_path,
+ prefix=args.prefix,
+ split=args.split,
+ mode=mode,
+ clip_len=args.num_frames,
+ frame_sample_rate=args.sampling_rate,
+ num_segment=1,
+ test_num_segment=args.test_num_segment,
+ test_num_crop=args.test_num_crop,
+ num_crop=1 if not test_mode else 3,
+ keep_aspect_ratio=True,
+ crop_size=args.input_size,
+ short_side_size=args.short_side_size,
+ new_height=256,
+ new_width=320,
+ args=args)
+ nb_classes = 101
+
+ elif args.data_set == 'HMDB51':
+ mode = None
+ anno_path = None
+ if is_train is True:
+ mode = 'train'
+ anno_path = os.path.join(args.data_path, 'train.csv')
+ elif test_mode is True:
+ mode = 'test'
+ anno_path = os.path.join(args.data_path, 'test.csv')
+ else:
+ mode = 'validation'
+ anno_path = os.path.join(args.data_path, 'val.csv')
+
+ if args.use_decord:
+ func = HMDBVideoClsDataset
+ else:
+ func = HMDBRawFrameClsDataset
+
+ dataset = func(
+ anno_path=anno_path,
+ prefix=args.prefix,
+ split=args.split,
+ mode=mode,
+ clip_len=1,
+ num_segment=args.num_frames,
+ test_num_segment=args.test_num_segment,
+ test_num_crop=args.test_num_crop,
+ num_crop=1 if not test_mode else 3,
+ keep_aspect_ratio=True,
+ crop_size=args.input_size,
+ short_side_size=args.short_side_size,
+ new_height=256,
+ new_width=320,
+ filename_tmpl=args.filename_tmpl,
+ args=args)
+ nb_classes = 51
+
+ elif args.data_set in [
+ 'ANet',
+ 'HACS',
+ 'ANet_interval',
+ 'HACS_interval'
+ ]:
+ mode = None
+ anno_path = None
+ if is_train is True:
+ mode = 'train'
+ anno_path = os.path.join(args.data_path, 'train.csv')
+ elif test_mode is True:
+ mode = 'test'
+ anno_path = os.path.join(args.data_path, 'test.csv')
+ else:
+ mode = 'validation'
+ anno_path = os.path.join(args.data_path, 'val.csv')
+
+ if 'interval' in args.data_set:
+ func = ANetDataset
+ else:
+ func = VideoClsDataset_sparse
+
+ dataset = func(
+ anno_path=anno_path,
+ prefix=args.prefix,
+ split=args.split,
+ mode=mode,
+ clip_len=args.num_frames,
+ frame_sample_rate=args.sampling_rate,
+ num_segment=1,
+ test_num_segment=args.test_num_segment,
+ test_num_crop=args.test_num_crop,
+ num_crop=1 if not test_mode else 3,
+ keep_aspect_ratio=True,
+ crop_size=args.input_size,
+ short_side_size=args.short_side_size,
+ new_height=256,
+ new_width=320,
+ args=args)
+ nb_classes = args.nb_classes
+
+ else:
+ print(f'Wrong: {args.data_set}')
+ raise NotImplementedError()
+ assert nb_classes == args.nb_classes
+ print("Number of the class = %d" % args.nb_classes)
+
+ return dataset, nb_classes
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/hmdb.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/hmdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..d97317a4ffa888fee2fb5b71d96cbf17c8127215
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/hmdb.py
@@ -0,0 +1,704 @@
+import os
+import io
+import cv2
+import numpy as np
+import torch
+from torchvision import transforms
+import warnings
+from decord import VideoReader, cpu
+from torch.utils.data import Dataset
+from .random_erasing import RandomErasing
+from .video_transforms import (
+ Compose, Resize, CenterCrop, Normalize,
+ create_random_augment, random_short_side_scale_jitter,
+ random_crop, random_resized_crop_with_shift, random_resized_crop,
+ horizontal_flip, random_short_side_scale_jitter, uniform_crop,
+)
+from .volume_transforms import ClipToTensor
+
+try:
+ from petrel_client.client import Client
+ has_client = True
+except ImportError:
+ has_client = False
+
+
+class HMDBRawFrameClsDataset(Dataset):
+ """Load your own raw frame classification dataset."""
+
+ def __init__(self, anno_path, prefix='', split=' ', mode='train', clip_len=8,
+ crop_size=224, short_side_size=256, new_height=256, new_width=340,
+ keep_aspect_ratio=True, num_segment=1, num_crop=1, test_num_segment=10,
+ test_num_crop=3, filename_tmpl='img_{:05}.jpg', args=None):
+ self.anno_path = anno_path
+ self.prefix = prefix
+ self.split = split
+ self.mode = mode
+ self.clip_len = clip_len
+ self.crop_size = crop_size
+ self.short_side_size = short_side_size
+ self.new_height = new_height
+ self.new_width = new_width
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.num_segment = num_segment
+ self.test_num_segment = test_num_segment
+ self.num_crop = num_crop
+ self.test_num_crop = test_num_crop
+ self.filename_tmpl = filename_tmpl
+ self.args = args
+ self.aug = False
+ self.rand_erase = False
+
+ self.client = None
+ if has_client:
+ self.client = Client('~/petreloss.conf')
+
+ if self.mode in ['train']:
+ self.aug = True
+ if self.args.reprob > 0:
+ self.rand_erase = True
+ if VideoReader is None:
+ raise ImportError(
+ "Unable to import `decord` which is required to read videos.")
+
+ import pandas as pd
+ cleaned = pd.read_csv(self.anno_path, header=None, delimiter=self.split)
+ self.dataset_samples = list(cleaned.values[:, 0].astype('str'))
+ self.total_frames = list(cleaned.values[:, 1] - 1) # max - 1
+ self.label_array = list(cleaned.values[:, -1])
+
+ if (mode == 'train'):
+ pass
+
+ elif (mode == 'validation'):
+ self.data_transform = Compose([
+ Resize(self.short_side_size,
+ interpolation='bilinear'),
+ CenterCrop(size=(self.crop_size,
+ self.crop_size)),
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ elif mode == 'test':
+ self.data_resize = Compose([
+ Resize(size=(short_side_size),
+ interpolation='bilinear')
+ ])
+ self.data_transform = Compose([
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.test_seg = []
+ self.test_dataset = []
+ self.test_total_frames = []
+ self.test_label_array = []
+ for ck in range(self.test_num_segment):
+ for cp in range(self.test_num_crop):
+ for idx in range(len(self.label_array)):
+ self.test_seg.append((ck, cp))
+ self.test_dataset.append(self.dataset_samples[idx])
+ self.test_total_frames.append(self.total_frames[idx])
+ self.test_label_array.append(self.label_array[idx])
+
+ def __getitem__(self, index):
+ if self.mode == 'train':
+ args = self.args
+ scale_t = 1
+
+ sample = self.dataset_samples[index]
+ total_frame = self.total_frames[index]
+ buffer = self.load_frame(sample,
+ total_frame,
+ sample_rate_scale=scale_t) # T H W C
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn(
+ "video {} not correctly loaded during training".format(
+ sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ total_frame = self.total_frames[index]
+ buffer = self.load_frame(sample,
+ total_frame,
+ sample_rate_scale=scale_t)
+
+ if args.num_sample > 1:
+ frame_list = []
+ label_list = []
+ index_list = []
+ for _ in range(args.num_sample):
+ new_frames = self._aug_frame(buffer, args)
+ label = self.label_array[index]
+ frame_list.append(new_frames)
+ label_list.append(label)
+ index_list.append(index)
+ return frame_list, label_list, index_list, {}
+ else:
+ buffer = self._aug_frame(buffer, args)
+
+ return buffer, self.label_array[index], index, {}
+
+ elif self.mode == 'validation':
+ sample = self.dataset_samples[index]
+ total_frame = self.total_frames[index]
+ buffer = self.load_frame(sample, total_frame)
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn(
+ "video {} not correctly loaded during validation".
+ format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.load_frame(sample, total_frame)
+ buffer = self.data_transform(buffer)
+ return buffer, self.label_array[index], sample.split(
+ "/")[-1].split(".")[0]
+
+ elif self.mode == 'test':
+ sample = self.test_dataset[index]
+ total_frame = self.test_total_frames[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.load_frame(sample, total_frame)
+
+ while len(buffer) == 0:
+ warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
+ str(self.test_dataset[index]), chunk_nb, split_nb))
+ index = np.random.randint(self.__len__())
+ sample = self.test_dataset[index]
+ total_frame = self.test_total_frames[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.load_frame(sample, total_frame)
+
+ buffer = self.data_resize(buffer)
+ if isinstance(buffer, list):
+ buffer = np.stack(buffer, 0)
+
+ if self.test_num_crop == 1:
+ spatial_start = int(1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) / 2)
+ else:
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
+ / (self.test_num_crop - 1)
+ spatial_start = int(split_nb * spatial_step)
+ temporal_start = chunk_nb
+ if buffer.shape[1] >= buffer.shape[2]:
+ buffer = buffer[temporal_start::self.test_num_segment, \
+ spatial_start:spatial_start + self.short_side_size, :, :]
+ else:
+ buffer = buffer[temporal_start::self.test_num_segment, \
+ :, spatial_start:spatial_start + self.short_side_size, :]
+
+ buffer = self.data_transform(buffer)
+ return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
+ chunk_nb, split_nb
+ else:
+ raise NameError('mode {} unkown'.format(self.mode))
+
+ def _aug_frame(
+ self,
+ buffer,
+ args,
+ ):
+
+ aug_transform = create_random_augment(
+ input_size=(self.crop_size, self.crop_size),
+ auto_augment=args.aa,
+ interpolation=args.train_interpolation,
+ )
+
+ buffer = [transforms.ToPILImage()(frame) for frame in buffer]
+
+ buffer = aug_transform(buffer)
+
+ buffer = [transforms.ToTensor()(img) for img in buffer]
+ buffer = torch.stack(buffer) # T C H W
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
+
+ # T H W C
+ buffer = tensor_normalize(buffer, [0.485, 0.456, 0.406],
+ [0.229, 0.224, 0.225])
+ # T H W C -> C T H W.
+ buffer = buffer.permute(3, 0, 1, 2)
+ # Perform data augmentation.
+ scl, asp = (
+ [0.08, 1.0],
+ [0.75, 1.3333],
+ )
+
+ buffer = spatial_sampling(
+ buffer,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=self.crop_size,
+ random_horizontal_flip=False if args.data_set == 'SSV2' else True,
+ inverse_uniform_sampling=False,
+ aspect_ratio=asp,
+ scale=scl,
+ motion_shift=False)
+
+ if self.rand_erase:
+ erase_transform = RandomErasing(
+ args.reprob,
+ mode=args.remode,
+ max_count=args.recount,
+ num_splits=args.recount,
+ device="cpu",
+ )
+ buffer = buffer.permute(1, 0, 2, 3)
+ buffer = erase_transform(buffer)
+ buffer = buffer.permute(1, 0, 2, 3)
+
+ return buffer
+
+ def load_frame(self, sample, num_frames, sample_rate_scale=1):
+ """Load video content using Decord"""
+ fname = sample
+ fname = os.path.join(self.prefix, fname)
+
+ if self.mode == 'test':
+ tick = num_frames / float(self.num_segment)
+ all_index = []
+ for t_seg in range(self.test_num_segment):
+ tmp_index = [
+ int(t_seg * tick / self.test_num_segment + tick * x)
+ for x in range(self.num_segment)
+ ]
+ all_index.extend(tmp_index)
+ all_index = list(np.sort(np.array(all_index)))
+ imgs = []
+ for idx in all_index:
+ frame_fname = os.path.join(fname, self.filename_tmpl.format(idx + 1))
+ if "s3://" in fname:
+ img_bytes = self.client.get(frame_fname)
+ else:
+ with open(frame_fname, 'rb') as f:
+ img_bytes = f.read()
+ img_np = np.frombuffer(img_bytes, np.uint8)
+ img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ imgs.append(img)
+ buffer = np.array(imgs)
+ return buffer
+
+ # handle temporal segments
+ average_duration = num_frames // self.num_segment
+ all_index = []
+ if average_duration > 0:
+ if self.mode == 'validation':
+ all_index = list(
+ np.multiply(list(range(self.num_segment)),
+ average_duration) +
+ np.ones(self.num_segment, dtype=int) *
+ (average_duration // 2))
+ else:
+ all_index = list(
+ np.multiply(list(range(self.num_segment)),
+ average_duration) +
+ np.random.randint(average_duration, size=self.num_segment))
+ elif num_frames > self.num_segment:
+ if self.mode == 'validation':
+ all_index = list(range(self.num_segment))
+ else:
+ all_index = list(
+ np.sort(
+ np.random.randint(num_frames, size=self.num_segment)))
+ else:
+ all_index = [0] * (self.num_segment - num_frames) + list(
+ range(num_frames))
+ all_index = list(np.array(all_index))
+ imgs = []
+ for idx in all_index:
+ frame_fname = os.path.join(fname, self.filename_tmpl.format(idx + 1))
+ if "s3://" in fname:
+ img_bytes = self.client.get(frame_fname)
+ else:
+ with open(frame_fname, 'rb') as f:
+ img_bytes = f.read()
+ try:
+ img_np = np.frombuffer(img_bytes, np.uint8)
+ except Exception:
+ print(f"Error when reading {frame_fname}", flush=True)
+ return []
+ img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ imgs.append(img)
+ buffer = np.array(imgs)
+ return buffer
+
+ def __len__(self):
+ if self.mode != 'test':
+ return len(self.dataset_samples)
+ else:
+ return len(self.test_dataset)
+
+
+class HMDBVideoClsDataset(Dataset):
+ """Load your own video classification dataset."""
+
+ def __init__(self, anno_path, prefix='', split=' ', mode='train', clip_len=8,
+ crop_size=224, short_side_size=256, new_height=256,
+ new_width=340, keep_aspect_ratio=True, num_segment=1,
+ num_crop=1, test_num_segment=10, test_num_crop=3, filename_tmpl=None, args=None):
+ self.anno_path = anno_path
+ self.prefix = prefix
+ self.split = split
+ self.mode = mode
+ self.clip_len = clip_len
+ self.crop_size = crop_size
+ self.short_side_size = short_side_size
+ self.new_height = new_height
+ self.new_width = new_width
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.num_segment = num_segment
+ self.test_num_segment = test_num_segment
+ self.num_crop = num_crop
+ self.test_num_crop = test_num_crop
+ self.args = args
+ self.aug = False
+ self.rand_erase = False
+
+ self.client = None
+ if has_client:
+ self.client = Client('~/petreloss.conf')
+
+ if self.mode in ['train']:
+ self.aug = True
+ if self.args.reprob > 0:
+ self.rand_erase = True
+ if VideoReader is None:
+ raise ImportError("Unable to import `decord` which is required to read videos.")
+
+ import pandas as pd
+ cleaned = pd.read_csv(self.anno_path, header=None, delimiter=self.split)
+ self.dataset_samples = list(cleaned.values[:, 0])
+ self.label_array = list(cleaned.values[:, 1])
+
+ if (mode == 'train'):
+ pass
+
+ elif (mode == 'validation'):
+ self.data_transform = Compose([
+ Resize(self.short_side_size, interpolation='bilinear'),
+ CenterCrop(size=(self.crop_size, self.crop_size)),
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ elif mode == 'test':
+ self.data_resize = Compose([
+ Resize(size=(short_side_size), interpolation='bilinear')
+ ])
+ self.data_transform = Compose([
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.test_seg = []
+ self.test_dataset = []
+ self.test_label_array = []
+ for ck in range(self.test_num_segment):
+ for cp in range(self.test_num_crop):
+ for idx in range(len(self.label_array)):
+ sample_label = self.label_array[idx]
+ self.test_label_array.append(sample_label)
+ self.test_dataset.append(self.dataset_samples[idx])
+ self.test_seg.append((ck, cp))
+
+ def __getitem__(self, index):
+ if self.mode == 'train':
+ args = self.args
+ scale_t = 1
+
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during training".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
+
+ if args.num_sample > 1:
+ frame_list = []
+ label_list = []
+ index_list = []
+ for _ in range(args.num_sample):
+ new_frames = self._aug_frame(buffer, args)
+ label = self.label_array[index]
+ frame_list.append(new_frames)
+ label_list.append(label)
+ index_list.append(index)
+ return frame_list, label_list, index_list, {}
+ else:
+ buffer = self._aug_frame(buffer, args)
+
+ return buffer, self.label_array[index], index, {}
+
+ elif self.mode == 'validation':
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample)
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during validation".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample)
+ buffer = self.data_transform(buffer)
+ return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
+
+ elif self.mode == 'test':
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.loadvideo_decord(sample)
+
+ while len(buffer) == 0:
+ warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
+ str(self.test_dataset[index]), chunk_nb, split_nb))
+ index = np.random.randint(self.__len__())
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.loadvideo_decord(sample)
+
+ buffer = self.data_resize(buffer)
+ if isinstance(buffer, list):
+ buffer = np.stack(buffer, 0)
+
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
+ / (self.test_num_crop - 1)
+ temporal_start = chunk_nb # 0/1
+ spatial_start = int(split_nb * spatial_step)
+ if buffer.shape[1] >= buffer.shape[2]:
+ buffer = buffer[temporal_start::2, \
+ spatial_start:spatial_start + self.short_side_size, :, :]
+ else:
+ buffer = buffer[temporal_start::2, \
+ :, spatial_start:spatial_start + self.short_side_size, :]
+
+ buffer = self.data_transform(buffer)
+ return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
+ chunk_nb, split_nb
+ else:
+ raise NameError('mode {} unkown'.format(self.mode))
+
+ def _aug_frame(
+ self,
+ buffer,
+ args,
+ ):
+
+ aug_transform = create_random_augment(
+ input_size=(self.crop_size, self.crop_size),
+ auto_augment=args.aa,
+ interpolation=args.train_interpolation,
+ )
+
+ buffer = [
+ transforms.ToPILImage()(frame) for frame in buffer
+ ]
+
+ buffer = aug_transform(buffer)
+
+ buffer = [transforms.ToTensor()(img) for img in buffer]
+ buffer = torch.stack(buffer) # T C H W
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
+
+ # T H W C
+ buffer = tensor_normalize(
+ buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+ )
+ # T H W C -> C T H W.
+ buffer = buffer.permute(3, 0, 1, 2)
+ # Perform data augmentation.
+ scl, asp = (
+ [0.08, 1.0],
+ [0.75, 1.3333],
+ )
+
+ buffer = spatial_sampling(
+ buffer,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=self.crop_size,
+ random_horizontal_flip=False if args.data_set == 'SSV2' else True,
+ inverse_uniform_sampling=False,
+ aspect_ratio=asp,
+ scale=scl,
+ motion_shift=False
+ )
+
+ if self.rand_erase:
+ erase_transform = RandomErasing(
+ args.reprob,
+ mode=args.remode,
+ max_count=args.recount,
+ num_splits=args.recount,
+ device="cpu",
+ )
+ buffer = buffer.permute(1, 0, 2, 3)
+ buffer = erase_transform(buffer)
+ buffer = buffer.permute(1, 0, 2, 3)
+
+ return buffer
+
+
+ def loadvideo_decord(self, sample, sample_rate_scale=1):
+ """Load video content using Decord"""
+ fname = sample
+ fname = os.path.join(self.prefix, fname)
+
+ try:
+ if self.keep_aspect_ratio:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
+ else:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ width=self.new_width,
+ height=self.new_height,
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, width=self.new_width, height=self.new_height,
+ num_threads=1, ctx=cpu(0))
+ except:
+ print("video cannot be loaded by decord: ", fname)
+ return []
+
+ if self.mode == 'test':
+ tick = len(vr) / float(self.num_segment)
+ all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] +
+ [int(tick * x) for x in range(self.num_segment)]))
+ while len(all_index) < (self.num_segment * self.test_num_segment):
+ all_index.append(all_index[-1])
+ all_index = np.sort(np.array(all_index))
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+ elif self.mode == 'validation':
+ tick = len(vr) / float(self.num_segment)
+ all_index = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)])
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+
+ # handle temporal segments
+ average_duration = len(vr) // self.num_segment
+ if average_duration > 0:
+ all_index = list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration,
+ size=self.num_segment))
+ elif len(vr) > self.num_segment:
+ all_index = list(np.sort(np.random.randint(len(vr), size=self.num_segment)))
+ else:
+ all_index = list(np.zeros((self.num_segment,)))
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+
+ def __len__(self):
+ if self.mode != 'test':
+ return len(self.dataset_samples)
+ else:
+ return len(self.test_dataset)
+
+
+def spatial_sampling(
+ frames,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=224,
+ random_horizontal_flip=True,
+ inverse_uniform_sampling=False,
+ aspect_ratio=None,
+ scale=None,
+ motion_shift=False,
+):
+ """
+ Perform spatial sampling on the given video frames. If spatial_idx is
+ -1, perform random scale, random crop, and random flip on the given
+ frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
+ with the given spatial_idx.
+ Args:
+ frames (tensor): frames of images sampled from the video. The
+ dimension is `num frames` x `height` x `width` x `channel`.
+ spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
+ or 2, perform left, center, right crop if width is larger than
+ height, and perform top, center, buttom crop if height is larger
+ than width.
+ min_scale (int): the minimal size of scaling.
+ max_scale (int): the maximal size of scaling.
+ crop_size (int): the size of height and width used to crop the
+ frames.
+ inverse_uniform_sampling (bool): if True, sample uniformly in
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
+ scale. If False, take a uniform sample from [min_scale,
+ max_scale].
+ aspect_ratio (list): Aspect ratio range for resizing.
+ scale (list): Scale range for resizing.
+ motion_shift (bool): Whether to apply motion shift for resizing.
+ Returns:
+ frames (tensor): spatially sampled frames.
+ """
+ assert spatial_idx in [-1, 0, 1, 2]
+ if spatial_idx == -1:
+ if aspect_ratio is None and scale is None:
+ frames, _ = random_short_side_scale_jitter(
+ images=frames,
+ min_size=min_scale,
+ max_size=max_scale,
+ inverse_uniform_sampling=inverse_uniform_sampling,
+ )
+ frames, _ = random_crop(frames, crop_size)
+ else:
+ transform_func = (
+ random_resized_crop_with_shift
+ if motion_shift
+ else random_resized_crop
+ )
+ frames = transform_func(
+ images=frames,
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=scale,
+ ratio=aspect_ratio,
+ )
+ if random_horizontal_flip:
+ frames, _ = horizontal_flip(0.5, frames)
+ else:
+ # The testing is deterministic and no jitter should be performed.
+ # min_scale, max_scale, and crop_size are expect to be the same.
+ assert len({min_scale, max_scale, crop_size}) == 1
+ frames, _ = random_short_side_scale_jitter(
+ frames, min_scale, max_scale
+ )
+ frames, _ = uniform_crop(frames, crop_size, spatial_idx)
+ return frames
+
+
+def tensor_normalize(tensor, mean, std):
+ """
+ Normalize a given tensor by subtracting the mean and dividing the std.
+ Args:
+ tensor (tensor): tensor to normalize.
+ mean (tensor or list): mean value to subtract.
+ std (tensor or list): std to divide.
+ """
+ if tensor.dtype == torch.uint8:
+ tensor = tensor.float()
+ tensor = tensor / 255.0
+ if type(mean) == list:
+ mean = torch.tensor(mean)
+ if type(std) == list:
+ std = torch.tensor(std)
+ tensor = tensor - mean
+ tensor = tensor / std
+ return tensor
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/kinetics.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/kinetics.py
new file mode 100644
index 0000000000000000000000000000000000000000..927a10bf7b0724f415e5394342373b09774f37cf
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/kinetics.py
@@ -0,0 +1,405 @@
+import os
+import os
+import io
+import numpy as np
+from numpy.lib.function_base import disp
+import torch
+from torchvision import transforms
+import warnings
+from decord import VideoReader, cpu
+from torch.utils.data import Dataset
+from .random_erasing import RandomErasing
+from .video_transforms import (
+ Compose, Resize, CenterCrop, Normalize,
+ create_random_augment, random_short_side_scale_jitter,
+ random_crop, random_resized_crop_with_shift, random_resized_crop,
+ horizontal_flip, random_short_side_scale_jitter, uniform_crop,
+)
+from .volume_transforms import ClipToTensor
+
+try:
+ from petrel_client.client import Client
+ has_client = True
+except ImportError:
+ has_client = False
+
+class VideoClsDataset(Dataset):
+ """Load your own video classification dataset."""
+
+ def __init__(self, anno_path, prefix='', split=' ', mode='train', clip_len=8,
+ frame_sample_rate=2, crop_size=224, short_side_size=256,
+ new_height=256, new_width=340, keep_aspect_ratio=True,
+ num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3,
+ args=None):
+ self.anno_path = anno_path
+ self.prefix = prefix
+ self.split = split
+ self.mode = mode
+ self.clip_len = clip_len
+ self.frame_sample_rate = frame_sample_rate
+ self.crop_size = crop_size
+ self.short_side_size = short_side_size
+ self.new_height = new_height
+ self.new_width = new_width
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.num_segment = num_segment
+ self.test_num_segment = test_num_segment
+ self.num_crop = num_crop
+ self.test_num_crop = test_num_crop
+ self.args = args
+ self.aug = False
+ self.rand_erase = False
+ assert num_segment == 1
+ if self.mode in ['train']:
+ self.aug = True
+ if self.args.reprob > 0:
+ self.rand_erase = True
+ if VideoReader is None:
+ raise ImportError("Unable to import `decord` which is required to read videos.")
+
+ import pandas as pd
+ cleaned = pd.read_csv(self.anno_path, header=None, delimiter=self.split)
+ self.dataset_samples = list(cleaned.values[:, 0])
+ self.label_array = list(cleaned.values[:, 1])
+
+ self.client = None
+ if has_client:
+ self.client = Client('~/petreloss.conf')
+
+ if (mode == 'train'):
+ pass
+
+ elif (mode == 'validation'):
+ self.data_transform = Compose([
+ Resize(self.short_side_size, interpolation='bilinear'),
+ CenterCrop(size=(self.crop_size, self.crop_size)),
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ elif mode == 'test':
+ self.data_resize = Compose([
+ Resize(size=(short_side_size), interpolation='bilinear')
+ ])
+ self.data_transform = Compose([
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.test_seg = []
+ self.test_dataset = []
+ self.test_label_array = []
+ for ck in range(self.test_num_segment):
+ for cp in range(self.test_num_crop):
+ for idx in range(len(self.label_array)):
+ sample_label = self.label_array[idx]
+ self.test_label_array.append(sample_label)
+ self.test_dataset.append(self.dataset_samples[idx])
+ self.test_seg.append((ck, cp))
+
+ def __getitem__(self, index):
+ if self.mode == 'train':
+ args = self.args
+ scale_t = 1
+
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during training".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
+
+ if args.num_sample > 1:
+ frame_list = []
+ label_list = []
+ index_list = []
+ for _ in range(args.num_sample):
+ new_frames = self._aug_frame(buffer, args)
+ label = self.label_array[index]
+ frame_list.append(new_frames)
+ label_list.append(label)
+ index_list.append(index)
+ return frame_list, label_list, index_list, {}
+ else:
+ buffer = self._aug_frame(buffer, args)
+
+ return buffer, self.label_array[index], index, {}
+
+ elif self.mode == 'validation':
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample)
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during validation".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample)
+ buffer = self.data_transform(buffer)
+ return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
+
+ elif self.mode == 'test':
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=chunk_nb)
+
+ while len(buffer) == 0:
+ warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
+ str(self.test_dataset[index]), chunk_nb, split_nb))
+ index = np.random.randint(self.__len__())
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=chunk_nb)
+
+ buffer = self.data_resize(buffer)
+ if isinstance(buffer, list):
+ buffer = np.stack(buffer, 0)
+
+ if self.test_num_crop == 1:
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) / 2
+ spatial_start = int(spatial_step)
+ else:
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
+ / (self.test_num_crop - 1)
+ spatial_start = int(split_nb * spatial_step)
+ if buffer.shape[1] >= buffer.shape[2]:
+ buffer = buffer[:, spatial_start:spatial_start + self.short_side_size, :, :]
+ else:
+ buffer = buffer[:, :, spatial_start:spatial_start + self.short_side_size, :]
+
+ buffer = self.data_transform(buffer)
+ return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
+ chunk_nb, split_nb
+ else:
+ raise NameError('mode {} unkown'.format(self.mode))
+
+ def _aug_frame(
+ self,
+ buffer,
+ args,
+ ):
+
+ aug_transform = create_random_augment(
+ input_size=(self.crop_size, self.crop_size),
+ auto_augment=args.aa,
+ interpolation=args.train_interpolation,
+ )
+
+ buffer = [
+ transforms.ToPILImage()(frame) for frame in buffer
+ ]
+
+ buffer = aug_transform(buffer)
+
+ buffer = [transforms.ToTensor()(img) for img in buffer]
+ buffer = torch.stack(buffer) # T C H W
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
+
+ # T H W C
+ buffer = tensor_normalize(
+ buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+ )
+ # T H W C -> C T H W.
+ buffer = buffer.permute(3, 0, 1, 2)
+ # Perform data augmentation.
+ scl, asp = (
+ [0.08, 1.0],
+ [0.75, 1.3333],
+ )
+
+ buffer = spatial_sampling(
+ buffer,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=self.crop_size,
+ random_horizontal_flip=False if args.data_set == 'SSV2' else True ,
+ inverse_uniform_sampling=False,
+ aspect_ratio=asp,
+ scale=scl,
+ motion_shift=False
+ )
+
+ if self.rand_erase:
+ erase_transform = RandomErasing(
+ args.reprob,
+ mode=args.remode,
+ max_count=args.recount,
+ num_splits=args.recount,
+ device="cpu",
+ )
+ buffer = buffer.permute(1, 0, 2, 3)
+ buffer = erase_transform(buffer)
+ buffer = buffer.permute(1, 0, 2, 3)
+
+ return buffer
+
+
+ def loadvideo_decord(self, sample, sample_rate_scale=1, chunk_nb=0):
+ """Load video content using Decord"""
+ fname = sample
+ fname = os.path.join(self.prefix, fname)
+
+ try:
+ if self.keep_aspect_ratio:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
+ else:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ width=self.new_width,
+ height=self.new_height,
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, width=self.new_width, height=self.new_height,
+ num_threads=1, ctx=cpu(0))
+
+ # handle temporal segments
+ converted_len = int(self.clip_len * self.frame_sample_rate)
+ seg_len = len(vr) // self.num_segment
+
+ if self.mode == 'test':
+ temporal_step = max(1.0 * (len(vr) - converted_len) / (self.test_num_segment - 1), 0)
+ temporal_start = int(chunk_nb * temporal_step)
+
+ bound = min(temporal_start + converted_len, len(vr))
+ all_index = [x for x in range(temporal_start, bound, self.frame_sample_rate)]
+ while len(all_index) < self.clip_len:
+ all_index.append(all_index[-1])
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+
+ all_index = []
+ for i in range(self.num_segment):
+ if seg_len <= converted_len:
+ index = np.linspace(0, seg_len, num=seg_len // self.frame_sample_rate)
+ index = np.concatenate((index, np.ones(self.clip_len - seg_len // self.frame_sample_rate) * seg_len))
+ index = np.clip(index, 0, seg_len - 1).astype(np.int64)
+ else:
+ if self.mode == 'validation':
+ end_idx = (seg_len - converted_len) // 2
+ else:
+ end_idx = np.random.randint(converted_len, seg_len)
+ str_idx = end_idx - converted_len
+ index = np.linspace(str_idx, end_idx, num=self.clip_len)
+ index = np.clip(index, str_idx, end_idx - 1).astype(np.int64)
+ index = index + i*seg_len
+ all_index.extend(list(index))
+
+ all_index = all_index[::int(sample_rate_scale)]
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+ except:
+ print("video cannot be loaded by decord: ", fname)
+ return []
+
+ def __len__(self):
+ if self.mode != 'test':
+ return len(self.dataset_samples)
+ else:
+ return len(self.test_dataset)
+
+
+def spatial_sampling(
+ frames,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=224,
+ random_horizontal_flip=True,
+ inverse_uniform_sampling=False,
+ aspect_ratio=None,
+ scale=None,
+ motion_shift=False,
+):
+ """
+ Perform spatial sampling on the given video frames. If spatial_idx is
+ -1, perform random scale, random crop, and random flip on the given
+ frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
+ with the given spatial_idx.
+ Args:
+ frames (tensor): frames of images sampled from the video. The
+ dimension is `num frames` x `height` x `width` x `channel`.
+ spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
+ or 2, perform left, center, right crop if width is larger than
+ height, and perform top, center, buttom crop if height is larger
+ than width.
+ min_scale (int): the minimal size of scaling.
+ max_scale (int): the maximal size of scaling.
+ crop_size (int): the size of height and width used to crop the
+ frames.
+ inverse_uniform_sampling (bool): if True, sample uniformly in
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
+ scale. If False, take a uniform sample from [min_scale,
+ max_scale].
+ aspect_ratio (list): Aspect ratio range for resizing.
+ scale (list): Scale range for resizing.
+ motion_shift (bool): Whether to apply motion shift for resizing.
+ Returns:
+ frames (tensor): spatially sampled frames.
+ """
+ assert spatial_idx in [-1, 0, 1, 2]
+ if spatial_idx == -1:
+ if aspect_ratio is None and scale is None:
+ frames, _ = random_short_side_scale_jitter(
+ images=frames,
+ min_size=min_scale,
+ max_size=max_scale,
+ inverse_uniform_sampling=inverse_uniform_sampling,
+ )
+ frames, _ = random_crop(frames, crop_size)
+ else:
+ transform_func = (
+ random_resized_crop_with_shift
+ if motion_shift
+ else random_resized_crop
+ )
+ frames = transform_func(
+ images=frames,
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=scale,
+ ratio=aspect_ratio,
+ )
+ if random_horizontal_flip:
+ frames, _ = horizontal_flip(0.5, frames)
+ else:
+ # The testing is deterministic and no jitter should be performed.
+ # min_scale, max_scale, and crop_size are expect to be the same.
+ assert len({min_scale, max_scale, crop_size}) == 1
+ frames, _ = random_short_side_scale_jitter(
+ frames, min_scale, max_scale
+ )
+ frames, _ = uniform_crop(frames, crop_size, spatial_idx)
+ return frames
+
+
+def tensor_normalize(tensor, mean, std):
+ """
+ Normalize a given tensor by subtracting the mean and dividing the std.
+ Args:
+ tensor (tensor): tensor to normalize.
+ mean (tensor or list): mean value to subtract.
+ std (tensor or list): std to divide.
+ """
+ if tensor.dtype == torch.uint8:
+ tensor = tensor.float()
+ tensor = tensor / 255.0
+ if type(mean) == list:
+ mean = torch.tensor(mean)
+ if type(std) == list:
+ std = torch.tensor(std)
+ tensor = tensor - mean
+ tensor = tensor / std
+ return tensor
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/kinetics_sparse.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/kinetics_sparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..75468a433883418573332c603ed01d390175de1b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/kinetics_sparse.py
@@ -0,0 +1,393 @@
+import os
+import os
+import io
+import random
+import numpy as np
+from numpy.lib.function_base import disp
+import torch
+from torchvision import transforms
+import warnings
+from decord import VideoReader, cpu
+from torch.utils.data import Dataset
+from .random_erasing import RandomErasing
+from .video_transforms import (
+ Compose, Resize, CenterCrop, Normalize,
+ create_random_augment, random_short_side_scale_jitter,
+ random_crop, random_resized_crop_with_shift, random_resized_crop,
+ horizontal_flip, random_short_side_scale_jitter, uniform_crop,
+)
+from .volume_transforms import ClipToTensor
+
+try:
+ from petrel_client.client import Client
+ has_client = True
+except ImportError:
+ has_client = False
+
+class VideoClsDataset_sparse(Dataset):
+ """Load your own video classification dataset."""
+
+ def __init__(self, anno_path, prefix='', split=' ', mode='train', clip_len=8,
+ frame_sample_rate=2, crop_size=224, short_side_size=256,
+ new_height=256, new_width=340, keep_aspect_ratio=True,
+ num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3,
+ args=None):
+ self.anno_path = anno_path
+ self.prefix = prefix
+ self.split = split
+ self.mode = mode
+ self.clip_len = clip_len
+ self.frame_sample_rate = frame_sample_rate
+ self.crop_size = crop_size
+ self.short_side_size = short_side_size
+ self.new_height = new_height
+ self.new_width = new_width
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.num_segment = num_segment
+ self.test_num_segment = test_num_segment
+ self.num_crop = num_crop
+ self.test_num_crop = test_num_crop
+ self.args = args
+ self.aug = False
+ self.rand_erase = False
+ assert num_segment == 1
+ if self.mode in ['train']:
+ self.aug = True
+ if self.args.reprob > 0:
+ self.rand_erase = True
+ if VideoReader is None:
+ raise ImportError("Unable to import `decord` which is required to read videos.")
+
+ import pandas as pd
+ cleaned = pd.read_csv(self.anno_path, header=None, delimiter=self.split)
+ self.dataset_samples = list(cleaned.values[:, 0])
+ self.label_array = list(cleaned.values[:, 1])
+
+ self.client = None
+ if has_client:
+ self.client = Client('~/petreloss.conf')
+
+ if (mode == 'train'):
+ pass
+
+ elif (mode == 'validation'):
+ self.data_transform = Compose([
+ Resize(self.short_side_size, interpolation='bilinear'),
+ CenterCrop(size=(self.crop_size, self.crop_size)),
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ elif mode == 'test':
+ self.data_resize = Compose([
+ Resize(size=(short_side_size), interpolation='bilinear')
+ ])
+ self.data_transform = Compose([
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.test_seg = []
+ self.test_dataset = []
+ self.test_label_array = []
+ for ck in range(self.test_num_segment):
+ for cp in range(self.test_num_crop):
+ for idx in range(len(self.label_array)):
+ sample_label = self.label_array[idx]
+ self.test_label_array.append(sample_label)
+ self.test_dataset.append(self.dataset_samples[idx])
+ self.test_seg.append((ck, cp))
+
+ def __getitem__(self, index):
+ if self.mode == 'train':
+ args = self.args
+
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=-1) # T H W C
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during training".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=-1)
+
+ if args.num_sample > 1:
+ frame_list = []
+ label_list = []
+ index_list = []
+ for _ in range(args.num_sample):
+ new_frames = self._aug_frame(buffer, args)
+ label = self.label_array[index]
+ frame_list.append(new_frames)
+ label_list.append(label)
+ index_list.append(index)
+ return frame_list, label_list, index_list, {}
+ else:
+ buffer = self._aug_frame(buffer, args)
+
+ return buffer, self.label_array[index], index, {}
+
+ elif self.mode == 'validation':
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=0)
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during validation".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=0)
+ buffer = self.data_transform(buffer)
+ return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
+
+ elif self.mode == 'test':
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=chunk_nb)
+
+ while len(buffer) == 0:
+ warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
+ str(self.test_dataset[index]), chunk_nb, split_nb))
+ index = np.random.randint(self.__len__())
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.loadvideo_decord(sample, chunk_nb=chunk_nb)
+
+ buffer = self.data_resize(buffer)
+ if isinstance(buffer, list):
+ buffer = np.stack(buffer, 0)
+ if self.test_num_crop == 1:
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) / 2
+ spatial_start = int(spatial_step)
+ else:
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
+ / (self.test_num_crop - 1)
+ spatial_start = int(split_nb * spatial_step)
+ if buffer.shape[1] >= buffer.shape[2]:
+ buffer = buffer[:, spatial_start:spatial_start + self.short_side_size, :, :]
+ else:
+ buffer = buffer[:, :, spatial_start:spatial_start + self.short_side_size, :]
+
+ buffer = self.data_transform(buffer)
+ return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
+ chunk_nb, split_nb
+ else:
+ raise NameError('mode {} unkown'.format(self.mode))
+
+ def _aug_frame(
+ self,
+ buffer,
+ args,
+ ):
+
+ aug_transform = create_random_augment(
+ input_size=(self.crop_size, self.crop_size),
+ auto_augment=args.aa,
+ interpolation=args.train_interpolation,
+ )
+
+ buffer = [
+ transforms.ToPILImage()(frame) for frame in buffer
+ ]
+
+ buffer = aug_transform(buffer)
+
+ buffer = [transforms.ToTensor()(img) for img in buffer]
+ buffer = torch.stack(buffer) # T C H W
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
+
+ # T H W C
+ buffer = tensor_normalize(
+ buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+ )
+ # T H W C -> C T H W.
+ buffer = buffer.permute(3, 0, 1, 2)
+ # Perform data augmentation.
+ scl, asp = (
+ [0.08, 1.0],
+ [0.75, 1.3333],
+ )
+
+ buffer = spatial_sampling(
+ buffer,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=self.crop_size,
+ random_horizontal_flip=False if args.data_set == 'SSV2' else True ,
+ inverse_uniform_sampling=False,
+ aspect_ratio=asp,
+ scale=scl,
+ motion_shift=False
+ )
+
+ if self.rand_erase:
+ erase_transform = RandomErasing(
+ args.reprob,
+ mode=args.remode,
+ max_count=args.recount,
+ num_splits=args.recount,
+ device="cpu",
+ )
+ buffer = buffer.permute(1, 0, 2, 3)
+ buffer = erase_transform(buffer)
+ buffer = buffer.permute(1, 0, 2, 3)
+
+ return buffer
+
+ def _get_seq_frames(self, video_size, num_frames, clip_idx=-1):
+ seg_size = max(0., float(video_size - 1) / num_frames)
+ max_frame = int(video_size) - 1
+ seq = []
+ # index from 1, must add 1
+ if clip_idx == -1:
+ for i in range(num_frames):
+ start = int(np.round(seg_size * i))
+ end = int(np.round(seg_size * (i + 1)))
+ idx = min(random.randint(start, end), max_frame)
+ seq.append(idx)
+ else:
+ num_segment = 1
+ if self.mode == 'test':
+ num_segment = self.test_num_segment
+ duration = seg_size / (num_segment + 1)
+ for i in range(num_frames):
+ start = int(np.round(seg_size * i))
+ frame_index = start + int(duration * (clip_idx + 1))
+ idx = min(frame_index, max_frame)
+ seq.append(idx)
+ return seq
+
+ def loadvideo_decord(self, sample, chunk_nb=0):
+ """Load video content using Decord"""
+ fname = sample
+ fname = os.path.join(self.prefix, fname)
+
+ try:
+ if self.keep_aspect_ratio:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
+ else:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ width=self.new_width,
+ height=self.new_height,
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, width=self.new_width, height=self.new_height,
+ num_threads=1, ctx=cpu(0))
+
+ all_index = self._get_seq_frames(len(vr), self.clip_len, clip_idx=chunk_nb)
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+ except:
+ print("video cannot be loaded by decord: ", fname)
+ return []
+
+ def __len__(self):
+ if self.mode != 'test':
+ return len(self.dataset_samples)
+ else:
+ return len(self.test_dataset)
+
+
+def spatial_sampling(
+ frames,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=224,
+ random_horizontal_flip=True,
+ inverse_uniform_sampling=False,
+ aspect_ratio=None,
+ scale=None,
+ motion_shift=False,
+):
+ """
+ Perform spatial sampling on the given video frames. If spatial_idx is
+ -1, perform random scale, random crop, and random flip on the given
+ frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
+ with the given spatial_idx.
+ Args:
+ frames (tensor): frames of images sampled from the video. The
+ dimension is `num frames` x `height` x `width` x `channel`.
+ spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
+ or 2, perform left, center, right crop if width is larger than
+ height, and perform top, center, buttom crop if height is larger
+ than width.
+ min_scale (int): the minimal size of scaling.
+ max_scale (int): the maximal size of scaling.
+ crop_size (int): the size of height and width used to crop the
+ frames.
+ inverse_uniform_sampling (bool): if True, sample uniformly in
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
+ scale. If False, take a uniform sample from [min_scale,
+ max_scale].
+ aspect_ratio (list): Aspect ratio range for resizing.
+ scale (list): Scale range for resizing.
+ motion_shift (bool): Whether to apply motion shift for resizing.
+ Returns:
+ frames (tensor): spatially sampled frames.
+ """
+ assert spatial_idx in [-1, 0, 1, 2]
+ if spatial_idx == -1:
+ if aspect_ratio is None and scale is None:
+ frames, _ = random_short_side_scale_jitter(
+ images=frames,
+ min_size=min_scale,
+ max_size=max_scale,
+ inverse_uniform_sampling=inverse_uniform_sampling,
+ )
+ frames, _ = random_crop(frames, crop_size)
+ else:
+ transform_func = (
+ random_resized_crop_with_shift
+ if motion_shift
+ else random_resized_crop
+ )
+ frames = transform_func(
+ images=frames,
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=scale,
+ ratio=aspect_ratio,
+ )
+ if random_horizontal_flip:
+ frames, _ = horizontal_flip(0.5, frames)
+ else:
+ # The testing is deterministic and no jitter should be performed.
+ # min_scale, max_scale, and crop_size are expect to be the same.
+ assert len({min_scale, max_scale, crop_size}) == 1
+ frames, _ = random_short_side_scale_jitter(
+ frames, min_scale, max_scale
+ )
+ frames, _ = uniform_crop(frames, crop_size, spatial_idx)
+ return frames
+
+
+def tensor_normalize(tensor, mean, std):
+ """
+ Normalize a given tensor by subtracting the mean and dividing the std.
+ Args:
+ tensor (tensor): tensor to normalize.
+ mean (tensor or list): mean value to subtract.
+ std (tensor or list): std to divide.
+ """
+ if tensor.dtype == torch.uint8:
+ tensor = tensor.float()
+ tensor = tensor / 255.0
+ if type(mean) == list:
+ mean = torch.tensor(mean)
+ if type(std) == list:
+ std = torch.tensor(std)
+ tensor = tensor - mean
+ tensor = tensor / std
+ return tensor
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/mae.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/mae.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f2dc1072aab399ee0caa5575aa81d69a988fac0
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/mae.py
@@ -0,0 +1,280 @@
+import os
+import cv2
+import io
+import numpy as np
+import torch
+import decord
+from PIL import Image
+from decord import VideoReader, cpu
+import random
+
+try:
+ from petrel_client.client import Client
+ has_client = True
+except ImportError:
+ has_client = False
+
+
+class VideoMAE(torch.utils.data.Dataset):
+ """Load your own video classification dataset.
+ Parameters
+ ----------
+ root : str, required.
+ Path to the root folder storing the dataset.
+ setting : str, required.
+ A text file describing the dataset, each line per video sample.
+ There are three items in each line: (1) video path; (2) video length and (3) video label.
+ prefix : str, required.
+ The prefix for loading data.
+ split : str, required.
+ The split character for metadata.
+ train : bool, default True.
+ Whether to load the training or validation set.
+ test_mode : bool, default False.
+ Whether to perform evaluation on the test set.
+ Usually there is three-crop or ten-crop evaluation strategy involved.
+ name_pattern : str, default None.
+ The naming pattern of the decoded video frames.
+ For example, img_00012.jpg.
+ video_ext : str, default 'mp4'.
+ If video_loader is set to True, please specify the video format accordinly.
+ is_color : bool, default True.
+ Whether the loaded image is color or grayscale.
+ modality : str, default 'rgb'.
+ Input modalities, we support only rgb video frames for now.
+ Will add support for rgb difference image and optical flow image later.
+ num_segments : int, default 1.
+ Number of segments to evenly divide the video into clips.
+ A useful technique to obtain global video-level information.
+ Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
+ num_crop : int, default 1.
+ Number of crops for each image. default is 1.
+ Common choices are three crops and ten crops during evaluation.
+ new_length : int, default 1.
+ The length of input video clip. Default is a single image, but it can be multiple video frames.
+ For example, new_length=16 means we will extract a video clip of consecutive 16 frames.
+ new_step : int, default 1.
+ Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
+ new_step=2 means we will extract a video clip of every other frame.
+ temporal_jitter : bool, default False.
+ Whether to temporally jitter if new_step > 1.
+ video_loader : bool, default False.
+ Whether to use video loader to load data.
+ use_decord : bool, default True.
+ Whether to use Decord video loader to load data. Otherwise load image.
+ transform : function, default None.
+ A function that takes data and label and transforms them.
+ data_aug : str, default 'v1'.
+ Different types of data augmentation auto. Supports v1, v2, v3 and v4.
+ lazy_init : bool, default False.
+ If set to True, build a dataset instance without loading any dataset.
+ """
+ def __init__(self,
+ root,
+ setting,
+ prefix='',
+ split=' ',
+ train=True,
+ test_mode=False,
+ name_pattern='img_%05d.jpg',
+ video_ext='mp4',
+ is_color=True,
+ modality='rgb',
+ num_segments=1,
+ num_crop=1,
+ new_length=1,
+ new_step=1,
+ transform=None,
+ temporal_jitter=False,
+ video_loader=False,
+ use_decord=True,
+ lazy_init=False,
+ num_sample=1,
+ ):
+
+ super(VideoMAE, self).__init__()
+ self.root = root
+ self.setting = setting
+ self.prefix = prefix
+ self.split = split
+ self.train = train
+ self.test_mode = test_mode
+ self.is_color = is_color
+ self.modality = modality
+ self.num_segments = num_segments
+ self.num_crop = num_crop
+ self.new_length = new_length
+ self.new_step = new_step
+ self.skip_length = self.new_length * self.new_step
+ self.temporal_jitter = temporal_jitter
+ self.name_pattern = name_pattern
+ self.video_loader = video_loader
+ self.video_ext = video_ext
+ self.use_decord = use_decord
+ self.transform = transform
+ self.lazy_init = lazy_init
+ self.num_sample = num_sample
+
+ # sparse sampling, num_segments != 1
+ if self.num_segments != 1:
+ print('Use sparse sampling, change frame and stride')
+ self.new_length = self.num_segments
+ self.skip_length = 1
+
+ self.client = None
+ if has_client:
+ self.client = Client('~/petreloss.conf')
+
+ if not self.lazy_init:
+ self.clips = self._make_dataset(root, setting)
+ if len(self.clips) == 0:
+ raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
+ "Check your data directory (opt.data-dir)."))
+
+ def __getitem__(self, index):
+ while True:
+ try:
+ images = None
+ if self.use_decord:
+ directory, target = self.clips[index]
+ if self.video_loader:
+ if '.' in directory.split('/')[-1]:
+ # data in the "setting" file already have extension, e.g., demo.mp4
+ video_name = directory
+ else:
+ # data in the "setting" file do not have extension, e.g., demo
+ # So we need to provide extension (i.e., .mp4) to complete the file name.
+ video_name = '{}.{}'.format(directory, self.video_ext)
+
+ video_name = os.path.join(self.prefix, video_name)
+ if video_name.startswith('s3') or video_name.startswith('p2:s3'):
+ video_bytes = self.client.get(video_name)
+ decord_vr = VideoReader(io.BytesIO(video_bytes),
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0))
+ duration = len(decord_vr)
+
+ segment_indices, skip_offsets = self._sample_train_indices(duration)
+ images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets)
+
+ else:
+ video_name, total_frame, target = self.clips[index]
+ video_name = os.path.join(self.prefix, video_name)
+
+ segment_indices, skip_offsets = self._sample_train_indices(total_frame)
+ frame_id_list = self._get_frame_id_list(total_frame, segment_indices, skip_offsets)
+ images = []
+ for idx in frame_id_list:
+ frame_fname = os.path.join(video_name, self.name_pattern.format(idx))
+ img_bytes = self.client.get(frame_fname)
+ img_np = np.frombuffer(img_bytes, np.uint8)
+ img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ images.append(Image.fromarray(img))
+ if images is not None:
+ break
+ except Exception as e:
+ print("Failed to load video from {} with error {}".format(
+ video_name, e))
+ index = random.randint(0, len(self.clips) - 1)
+
+ if self.num_sample > 1:
+ process_data_list = []
+ mask_list = []
+ for _ in range(self.num_sample):
+ process_data, mask = self.transform((images, None))
+ process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)
+ process_data_list.append(process_data)
+ mask_list.append(mask)
+ return process_data_list, mask_list
+ else:
+ process_data, mask = self.transform((images, None)) # T*C,H,W
+ process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W
+ return (process_data, mask)
+
+ def __len__(self):
+ return len(self.clips)
+
+ def _make_dataset(self, directory, setting):
+ if not os.path.exists(setting):
+ raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
+ clips = []
+
+ print(f'Load dataset using decord: {self.use_decord}')
+ with open(setting) as split_f:
+ data = split_f.readlines()
+ for line in data:
+ line_info = line.split(self.split)
+ if len(line_info) < 2:
+ raise(RuntimeError('Video input format is not correct, missing one or more element. %s' % line))
+ if self.use_decord:
+ # line format: video_path, video_label
+ clip_path = os.path.join(line_info[0])
+ target = int(line_info[1])
+ item = (clip_path, target)
+ else:
+ # line format: video_path, video_duration, video_label
+ clip_path = os.path.join(line_info[0])
+ total_frame = int(line_info[1])
+ target = int(line_info[2])
+ item = (clip_path, total_frame, target)
+ clips.append(item)
+ return clips
+
+ def _sample_train_indices(self, num_frames):
+ average_duration = (num_frames - self.skip_length + 1) // self.num_segments
+ if average_duration > 0:
+ offsets = np.multiply(list(range(self.num_segments)),
+ average_duration)
+ offsets = offsets + np.random.randint(average_duration,
+ size=self.num_segments)
+ elif num_frames > max(self.num_segments, self.skip_length):
+ offsets = np.sort(np.random.randint(
+ num_frames - self.skip_length + 1,
+ size=self.num_segments))
+ else:
+ offsets = np.zeros((self.num_segments,))
+
+ if self.temporal_jitter:
+ skip_offsets = np.random.randint(
+ self.new_step, size=self.skip_length // self.new_step)
+ else:
+ skip_offsets = np.zeros(
+ self.skip_length // self.new_step, dtype=int)
+ return offsets + 1, skip_offsets
+
+ def _get_frame_id_list(self, duration, indices, skip_offsets):
+ frame_id_list = []
+ for seg_ind in indices:
+ offset = int(seg_ind)
+ for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
+ if offset + skip_offsets[i] <= duration:
+ frame_id = offset + skip_offsets[i] - 1
+ else:
+ frame_id = offset - 1
+ frame_id_list.append(frame_id)
+ if offset + self.new_step < duration:
+ offset += self.new_step
+ return frame_id_list
+
+ def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets):
+ sampled_list = []
+ frame_id_list = []
+ for seg_ind in indices:
+ offset = int(seg_ind)
+ for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
+ if offset + skip_offsets[i] <= duration:
+ frame_id = offset + skip_offsets[i] - 1
+ else:
+ frame_id = offset - 1
+ frame_id_list.append(frame_id)
+ if offset + self.new_step < duration:
+ offset += self.new_step
+ try:
+ video_data = video_reader.get_batch(frame_id_list).asnumpy()
+ sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
+ except:
+ raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration))
+ return sampled_list
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/mae_multi.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/mae_multi.py
new file mode 100644
index 0000000000000000000000000000000000000000..66f5b1ba784a531c4149f9cc9f8dacab6055795c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/mae_multi.py
@@ -0,0 +1,274 @@
+import os
+import cv2
+import io
+import numpy as np
+import torch
+import decord
+from PIL import Image
+from decord import VideoReader, cpu
+import random
+
+try:
+ from petrel_client.client import Client
+ has_client = True
+except ImportError:
+ has_client = False
+
+
+class VideoMAE_multi(torch.utils.data.Dataset):
+ """Load your own video classification dataset.
+ Parameters
+ ----------
+ root : str, required.
+ Path to the root folder storing the dataset.
+ setting : str, required.
+ A text file describing the dataset, each line per video sample.
+ There are three items in each line: (1) video path; (2) video length and (3) video label.
+ prefix : str, required.
+ The prefix for loading data.
+ split : str, required.
+ The split character for metadata.
+ train : bool, default True.
+ Whether to load the training or validation set.
+ test_mode : bool, default False.
+ Whether to perform evaluation on the test set.
+ Usually there is three-crop or ten-crop evaluation strategy involved.
+ name_pattern : str, default None.
+ The naming pattern of the decoded video frames.
+ For example, img_00012.jpg.
+ is_color : bool, default True.
+ Whether the loaded image is color or grayscale.
+ modality : str, default 'rgb'.
+ Input modalities, we support only rgb video frames for now.
+ Will add support for rgb difference image and optical flow image later.
+ num_segments : int, default 1.
+ Number of segments to evenly divide the video into clips.
+ A useful technique to obtain global video-level information.
+ Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
+ num_crop : int, default 1.
+ Number of crops for each image. default is 1.
+ Common choices are three crops and ten crops during evaluation.
+ new_length : int, default 1.
+ The length of input video clip. Default is a single image, but it can be multiple video frames.
+ For example, new_length=16 means we will extract a video clip of consecutive 16 frames.
+ new_step : int, default 1.
+ Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
+ new_step=2 means we will extract a video clip of every other frame.
+ temporal_jitter : bool, default False.
+ Whether to temporally jitter if new_step > 1.
+ video_loader : bool, default False.
+ Whether to use video loader to load data.
+ use_decord : bool, default True.
+ Whether to use Decord video loader to load data. Otherwise load image.
+ transform : function, default None.
+ A function that takes data and label and transforms them.
+ transform_ssv2 : function, default None.
+ A function that takes data and label and transforms them.
+ data_aug : str, default 'v1'.
+ Different types of data augmentation auto. Supports v1, v2, v3 and v4.
+ lazy_init : bool, default False.
+ If set to True, build a dataset instance without loading any dataset.
+ """
+ def __init__(self,
+ root,
+ setting,
+ prefix='',
+ split=' ',
+ train=True,
+ test_mode=False,
+ name_pattern='img_%05d.jpg',
+ is_color=True,
+ modality='rgb',
+ num_segments=1,
+ num_crop=1,
+ new_length=1,
+ new_step=1,
+ transform=None,
+ transform_ssv2=None,
+ temporal_jitter=False,
+ video_loader=False,
+ use_decord=True,
+ lazy_init=False,
+ num_sample=1,
+ ):
+
+ super(VideoMAE_multi, self).__init__()
+ self.root = root
+ self.setting = setting
+ self.prefix = prefix
+ self.split = split
+ self.train = train
+ self.test_mode = test_mode
+ self.is_color = is_color
+ self.modality = modality
+ self.num_segments = num_segments
+ self.num_crop = num_crop
+ self.new_length = new_length
+ self.new_step = new_step
+ self.skip_length = self.new_length * self.new_step
+ self.temporal_jitter = temporal_jitter
+ self.name_pattern = name_pattern
+ self.video_loader = video_loader
+ self.use_decord = use_decord
+ self.transform = transform
+ self.transform_ssv2 = transform_ssv2
+ self.lazy_init = lazy_init
+ self.num_sample = num_sample
+
+ assert use_decord == True, "Only support to read video now!"
+
+ # sparse sampling, num_segments != 1
+ if self.num_segments != 1:
+ print('Use sparse sampling, change frame and stride')
+ self.new_length = self.num_segments
+ self.skip_length = 1
+
+ self.client = None
+ if has_client:
+ self.client = Client('~/petreloss.conf')
+
+ if not self.lazy_init:
+ self.clips = self._make_dataset(root, setting)
+ if len(self.clips) == 0:
+ raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
+ "Check your data directory (opt.data-dir)."))
+
+ def __getitem__(self, index):
+ while True:
+ try:
+ images = None
+ if self.use_decord:
+ source, path, total_time, start_time, end_time, target = self.clips[index]
+ if self.video_loader:
+ video_name = os.path.join(self.prefix, path)
+ if "s3://" in fname:
+ video_bytes = self.client.get(video_name)
+ decord_vr = VideoReader(io.BytesIO(video_bytes),
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0))
+ duration = len(decord_vr)
+ start_index = 0
+
+ if total_time!= -1 and start_time != -1 and end_time != -1:
+ fps = duration / total_time
+ duration = int(fps * (end_time - start_time))
+ start_index = int(fps * start_time)
+ segment_indices, skip_offsets = self._sample_train_indices(duration, start_index)
+ images = self._video_TSN_decord_batch_loader(video_name, decord_vr, duration, segment_indices, skip_offsets)
+ else:
+ raise NotImplementedError
+
+ if images is not None:
+ break
+ except Exception as e:
+ print("Failed to load video from {} with error {}".format(
+ video_name, e))
+ index = random.randint(0, len(self.clips) - 1)
+
+ if self.num_sample > 1:
+ process_data_list = []
+ mask_list = []
+ for _ in range(self.num_sample):
+ if source == "ssv2":
+ process_data, mask = self.transform_ssv2((images, None))
+ else:
+ process_data, mask = self.transform((images, None))
+ process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)
+ process_data_list.append(process_data)
+ mask_list.append(mask)
+ return process_data_list, mask_list
+ else:
+ if source == "ssv2":
+ process_data, mask = self.transform_ssv2((images, None)) # T*C,H,W
+ else:
+ process_data, mask = self.transform((images, None)) # T*C,H,W
+ process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W
+ return (process_data, mask)
+
+ def __len__(self):
+ return len(self.clips)
+
+ def _make_dataset(self, directory, setting):
+ if not os.path.exists(setting):
+ raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
+ clips = []
+
+ print(f'Load dataset using decord: {self.use_decord}')
+ with open(setting) as split_f:
+ data = split_f.readlines()
+ for line in data:
+ line_info = line.split(self.split)
+ if len(line_info) < 2:
+ raise(RuntimeError('Video input format is not correct, missing one or more element. %s' % line))
+ if self.use_decord:
+ # line format: source, path, total_time, start_time, end_time, target
+ source = line_info[0]
+ path = line_info[1]
+ total_time = float(line_info[2])
+ start_time = float(line_info[3])
+ end_time = float(line_info[4])
+ target = int(line_info[5])
+ item = (source, path, total_time, start_time, end_time, target)
+ else:
+ raise NotImplementedError
+
+ clips.append(item)
+ return clips
+
+ def _sample_train_indices(self, num_frames, start_index=0):
+ average_duration = (num_frames - self.skip_length + 1) // self.num_segments
+ if average_duration > 0:
+ offsets = np.multiply(list(range(self.num_segments)),
+ average_duration)
+ offsets = offsets + np.random.randint(average_duration,
+ size=self.num_segments)
+ elif num_frames > max(self.num_segments, self.skip_length):
+ offsets = np.sort(np.random.randint(
+ num_frames - self.skip_length + 1,
+ size=self.num_segments))
+ else:
+ offsets = np.zeros((self.num_segments,))
+
+ if self.temporal_jitter:
+ skip_offsets = np.random.randint(
+ self.new_step, size=self.skip_length // self.new_step)
+ else:
+ skip_offsets = np.zeros(
+ self.skip_length // self.new_step, dtype=int)
+ return offsets + start_index, skip_offsets
+
+ def _get_frame_id_list(self, duration, indices, skip_offsets):
+ frame_id_list = []
+ for seg_ind in indices:
+ offset = int(seg_ind)
+ for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
+ if offset + skip_offsets[i] <= duration:
+ frame_id = offset + skip_offsets[i] - 1
+ else:
+ frame_id = offset - 1
+ frame_id_list.append(frame_id)
+ if offset + self.new_step < duration:
+ offset += self.new_step
+ return frame_id_list
+
+ def _video_TSN_decord_batch_loader(self, video_name, video_reader, duration, indices, skip_offsets):
+ sampled_list = []
+ frame_id_list = []
+ for seg_ind in indices:
+ offset = int(seg_ind)
+ for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
+ if offset + skip_offsets[i] <= duration:
+ frame_id = offset + skip_offsets[i] - 1
+ else:
+ frame_id = offset - 1
+ frame_id_list.append(frame_id)
+ if offset + self.new_step < duration:
+ offset += self.new_step
+ try:
+ video_data = video_reader.get_batch(frame_id_list).asnumpy()
+ sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
+ except:
+ raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, video_name, duration))
+ return sampled_list
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/masking_generator.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/masking_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac942d3f27eb5c04fb38191946ca49900719380
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/masking_generator.py
@@ -0,0 +1,49 @@
+import numpy as np
+
+
+class TubeMaskingGenerator:
+ def __init__(self, input_size, mask_ratio):
+ self.frames, self.height, self.width = input_size
+ self.num_patches_per_frame = self.height * self.width
+ self.total_patches = self.frames * self.num_patches_per_frame
+ self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
+ self.total_masks = self.frames * self.num_masks_per_frame
+
+ def __repr__(self):
+ repr_str = "Maks: total patches {}, mask patches {}".format(
+ self.total_patches, self.total_masks
+ )
+ return repr_str
+
+ def __call__(self):
+ mask_per_frame = np.hstack([
+ np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
+ np.ones(self.num_masks_per_frame),
+ ])
+ np.random.shuffle(mask_per_frame)
+ mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
+ return mask
+
+
+class RandomMaskingGenerator:
+ def __init__(self, input_size, mask_ratio):
+ if not isinstance(input_size, tuple):
+ input_size = (input_size, ) * 3
+
+ self.frames, self.height, self.width = input_size
+
+ self.num_patches = self.frames * self.height * self.width # 8x14x14
+ self.num_mask = int(mask_ratio * self.num_patches)
+
+ def __repr__(self):
+ repr_str = "Maks: total patches {}, mask patches {}".format(
+ self.num_patches, self.num_mask)
+ return repr_str
+
+ def __call__(self):
+ mask = np.hstack([
+ np.zeros(self.num_patches - self.num_mask),
+ np.ones(self.num_mask),
+ ])
+ np.random.shuffle(mask)
+ return mask # [196*8]
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/mixup.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/mixup.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fea7dae0644ad8c7ee6d3c50df5d59b10fd34b0
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/mixup.py
@@ -0,0 +1,316 @@
+""" Mixup and Cutmix
+
+Papers:
+mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+
+CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
+
+Code Reference:
+CutMix: https://github.com/clovaai/CutMix-PyTorch
+
+Hacked together by / Copyright 2019, Ross Wightman
+"""
+import numpy as np
+import torch
+
+
+def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+ x = x.long().view(-1, 1)
+ return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
+ off_value = smoothing / num_classes
+ on_value = 1. - smoothing + off_value
+ y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
+ y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
+ return y1 * lam + y2 * (1. - lam)
+
+
+def rand_bbox(img_shape, lam, margin=0., count=None):
+ """ Standard CutMix bounding-box
+ Generates a random square bbox based on lambda value. This impl includes
+ support for enforcing a border margin as percent of bbox dimensions.
+
+ Args:
+ img_shape (tuple): Image shape as tuple
+ lam (float): Cutmix lambda value
+ margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+ count (int): Number of bbox to generate
+ """
+ ratio = np.sqrt(1 - lam)
+ img_h, img_w = img_shape[-2:]
+ cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+ margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
+ cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+ cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+ yl = np.clip(cy - cut_h // 2, 0, img_h)
+ yh = np.clip(cy + cut_h // 2, 0, img_h)
+ xl = np.clip(cx - cut_w // 2, 0, img_w)
+ xh = np.clip(cx + cut_w // 2, 0, img_w)
+ return yl, yh, xl, xh
+
+
+def rand_bbox_minmax(img_shape, minmax, count=None):
+ """ Min-Max CutMix bounding-box
+ Inspired by Darknet cutmix impl, generates a random rectangular bbox
+ based on min/max percent values applied to each dimension of the input image.
+
+ Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
+
+ Args:
+ img_shape (tuple): Image shape as tuple
+ minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+ count (int): Number of bbox to generate
+ """
+ assert len(minmax) == 2
+ img_h, img_w = img_shape[-2:]
+ cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+ cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+ yl = np.random.randint(0, img_h - cut_h, size=count)
+ xl = np.random.randint(0, img_w - cut_w, size=count)
+ yu = yl + cut_h
+ xu = xl + cut_w
+ return yl, yu, xl, xu
+
+
+def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+ """ Generate bbox and apply lambda correction.
+ """
+ if ratio_minmax is not None:
+ yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+ else:
+ yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+ if correct_lam or ratio_minmax is not None:
+ bbox_area = (yu - yl) * (xu - xl)
+ lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
+ return (yl, yu, xl, xu), lam
+
+
+class Mixup:
+ """ Mixup/Cutmix that applies different params to each element or whole batch
+
+ Args:
+ mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+ cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+ cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+ prob (float): probability of applying mixup or cutmix per batch or element
+ switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+ mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
+ correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+ label_smoothing (float): apply label smoothing to the mixed target tensor
+ num_classes (int): number of classes for target
+ """
+ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+ mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
+ self.mixup_alpha = mixup_alpha
+ self.cutmix_alpha = cutmix_alpha
+ self.cutmix_minmax = cutmix_minmax
+ if self.cutmix_minmax is not None:
+ assert len(self.cutmix_minmax) == 2
+ # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+ self.cutmix_alpha = 1.0
+ self.mix_prob = prob
+ self.switch_prob = switch_prob
+ self.label_smoothing = label_smoothing
+ self.num_classes = num_classes
+ self.mode = mode
+ self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
+ self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
+
+ def _params_per_elem(self, batch_size):
+ lam = np.ones(batch_size, dtype=np.float32)
+ use_cutmix = np.zeros(batch_size, dtype=np.bool)
+ if self.mixup_enabled:
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+ use_cutmix = np.random.rand(batch_size) < self.switch_prob
+ lam_mix = np.where(
+ use_cutmix,
+ np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+ np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+ elif self.mixup_alpha > 0.:
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+ elif self.cutmix_alpha > 0.:
+ use_cutmix = np.ones(batch_size, dtype=np.bool)
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+ else:
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+ lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+ return lam, use_cutmix
+
+ def _params_per_batch(self):
+ lam = 1.
+ use_cutmix = False
+ if self.mixup_enabled and np.random.rand() < self.mix_prob:
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+ use_cutmix = np.random.rand() < self.switch_prob
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+ np.random.beta(self.mixup_alpha, self.mixup_alpha)
+ elif self.mixup_alpha > 0.:
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+ elif self.cutmix_alpha > 0.:
+ use_cutmix = True
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+ else:
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+ lam = float(lam_mix)
+ return lam, use_cutmix
+
+ def _mix_elem(self, x):
+ batch_size = len(x)
+ lam_batch, use_cutmix = self._params_per_elem(batch_size)
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
+ for i in range(batch_size):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ if lam != 1.:
+ if use_cutmix[i]:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh]
+ lam_batch[i] = lam
+ else:
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+ def _mix_pair(self, x):
+ batch_size = len(x)
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
+ for i in range(batch_size // 2):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ if lam != 1.:
+ if use_cutmix[i]:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+ x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+ lam_batch[i] = lam
+ else:
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+ x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+ def _mix_batch(self, x):
+ lam, use_cutmix = self._params_per_batch()
+ if lam == 1.:
+ return 1.
+ if use_cutmix:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
+ else:
+ x_flipped = x.flip(0).mul_(1. - lam)
+ x.mul_(lam).add_(x_flipped)
+ return lam
+
+ def __call__(self, x, target):
+ assert len(x) % 2 == 0, 'Batch size should be even when using this'
+ if self.mode == 'elem':
+ lam = self._mix_elem(x)
+ elif self.mode == 'pair':
+ lam = self._mix_pair(x)
+ else:
+ lam = self._mix_batch(x)
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
+ return x, target
+
+
+class FastCollateMixup(Mixup):
+ """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
+
+ A Mixup impl that's performed while collating the batches.
+ """
+
+ def _mix_elem_collate(self, output, batch, half=False):
+ batch_size = len(batch)
+ num_elem = batch_size // 2 if half else batch_size
+ assert len(output) == num_elem
+ lam_batch, use_cutmix = self._params_per_elem(num_elem)
+ for i in range(num_elem):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ mixed = batch[i][0]
+ if lam != 1.:
+ if use_cutmix[i]:
+ if not half:
+ mixed = mixed.copy()
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+ lam_batch[i] = lam
+ else:
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+ np.rint(mixed, out=mixed)
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
+ if half:
+ lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
+ return torch.tensor(lam_batch).unsqueeze(1)
+
+ def _mix_pair_collate(self, output, batch):
+ batch_size = len(batch)
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+ for i in range(batch_size // 2):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ mixed_i = batch[i][0]
+ mixed_j = batch[j][0]
+ assert 0 <= lam <= 1.0
+ if lam < 1.:
+ if use_cutmix[i]:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ patch_i = mixed_i[:, yl:yh, xl:xh].copy()
+ mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
+ mixed_j[:, yl:yh, xl:xh] = patch_i
+ lam_batch[i] = lam
+ else:
+ mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
+ mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
+ mixed_i = mixed_temp
+ np.rint(mixed_j, out=mixed_j)
+ np.rint(mixed_i, out=mixed_i)
+ output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
+ output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+ return torch.tensor(lam_batch).unsqueeze(1)
+
+ def _mix_batch_collate(self, output, batch):
+ batch_size = len(batch)
+ lam, use_cutmix = self._params_per_batch()
+ if use_cutmix:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ for i in range(batch_size):
+ j = batch_size - i - 1
+ mixed = batch[i][0]
+ if lam != 1.:
+ if use_cutmix:
+ mixed = mixed.copy() # don't want to modify the original while iterating
+ mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh]
+ else:
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+ np.rint(mixed, out=mixed)
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
+ return lam
+
+ def __call__(self, batch, _=None):
+ batch_size = len(batch)
+ assert batch_size % 2 == 0, 'Batch size should be even when using this'
+ half = 'half' in self.mode
+ if half:
+ batch_size //= 2
+ output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+ if self.mode == 'elem' or self.mode == 'half':
+ lam = self._mix_elem_collate(output, batch, half=half)
+ elif self.mode == 'pair':
+ lam = self._mix_pair_collate(output, batch)
+ else:
+ lam = self._mix_batch_collate(output, batch)
+ target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+ target = target[:batch_size]
+ return output, target
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/rand_augment.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/rand_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..37c57d10e3c1abcba046995b96b9d23378b77b41
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/rand_augment.py
@@ -0,0 +1,531 @@
+"""
+This implementation is based on
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
+pulished under an Apache License 2.0.
+
+COMMENT FROM ORIGINAL:
+AutoAugment, RandAugment, and AugMix for PyTorch
+This code implements the searched ImageNet policies with various tweaks and
+improvements and does not include any of the search code. AA and RA
+Implementation adapted from:
+ https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+AugMix adapted from:
+ https://github.com/google-research/augmix
+Papers:
+ AutoAugment: Learning Augmentation Policies from Data
+ https://arxiv.org/abs/1805.09501
+ Learning Data Augmentation Strategies for Object Detection
+ https://arxiv.org/abs/1906.11172
+ RandAugment: Practical automated data augmentation...
+ https://arxiv.org/abs/1909.13719
+ AugMix: A Simple Data Processing Method to Improve Robustness and
+ Uncertainty https://arxiv.org/abs/1912.02781
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import math
+import numpy as np
+import random
+import re
+import PIL
+from PIL import Image, ImageEnhance, ImageOps
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
+
+_FILL = (128, 128, 128)
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.0
+
+_HPARAMS_DEFAULT = {
+ "translate_const": 250,
+ "img_mean": _FILL,
+}
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+def _interpolation(kwargs):
+ interpolation = kwargs.pop("resample", Image.BILINEAR)
+ if isinstance(interpolation, (list, tuple)):
+ return random.choice(interpolation)
+ else:
+ return interpolation
+
+
+def _check_args_tf(kwargs):
+ if "fillcolor" in kwargs and _PIL_VER < (5, 0):
+ kwargs.pop("fillcolor")
+ kwargs["resample"] = _interpolation(kwargs)
+
+
+def shear_x(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(
+ img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
+ )
+
+
+def shear_y(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(
+ img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
+ )
+
+
+def translate_x_rel(img, pct, **kwargs):
+ pixels = pct * img.size[0]
+ _check_args_tf(kwargs)
+ return img.transform(
+ img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
+ )
+
+
+def translate_y_rel(img, pct, **kwargs):
+ pixels = pct * img.size[1]
+ _check_args_tf(kwargs)
+ return img.transform(
+ img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
+ )
+
+
+def translate_x_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(
+ img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
+ )
+
+
+def translate_y_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(
+ img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
+ )
+
+
+def rotate(img, degrees, **kwargs):
+ _check_args_tf(kwargs)
+ if _PIL_VER >= (5, 2):
+ return img.rotate(degrees, **kwargs)
+ elif _PIL_VER >= (5, 0):
+ w, h = img.size
+ post_trans = (0, 0)
+ rotn_center = (w / 2.0, h / 2.0)
+ angle = -math.radians(degrees)
+ matrix = [
+ round(math.cos(angle), 15),
+ round(math.sin(angle), 15),
+ 0.0,
+ round(-math.sin(angle), 15),
+ round(math.cos(angle), 15),
+ 0.0,
+ ]
+
+ def transform(x, y, matrix):
+ (a, b, c, d, e, f) = matrix
+ return a * x + b * y + c, d * x + e * y + f
+
+ matrix[2], matrix[5] = transform(
+ -rotn_center[0] - post_trans[0],
+ -rotn_center[1] - post_trans[1],
+ matrix,
+ )
+ matrix[2] += rotn_center[0]
+ matrix[5] += rotn_center[1]
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+ else:
+ return img.rotate(degrees, resample=kwargs["resample"])
+
+
+def auto_contrast(img, **__):
+ return ImageOps.autocontrast(img)
+
+
+def invert(img, **__):
+ return ImageOps.invert(img)
+
+
+def equalize(img, **__):
+ return ImageOps.equalize(img)
+
+
+def solarize(img, thresh, **__):
+ return ImageOps.solarize(img, thresh)
+
+
+def solarize_add(img, add, thresh=128, **__):
+ lut = []
+ for i in range(256):
+ if i < thresh:
+ lut.append(min(255, i + add))
+ else:
+ lut.append(i)
+ if img.mode in ("L", "RGB"):
+ if img.mode == "RGB" and len(lut) == 256:
+ lut = lut + lut + lut
+ return img.point(lut)
+ else:
+ return img
+
+
+def posterize(img, bits_to_keep, **__):
+ if bits_to_keep >= 8:
+ return img
+ return ImageOps.posterize(img, bits_to_keep)
+
+
+def contrast(img, factor, **__):
+ return ImageEnhance.Contrast(img).enhance(factor)
+
+
+def color(img, factor, **__):
+ return ImageEnhance.Color(img).enhance(factor)
+
+
+def brightness(img, factor, **__):
+ return ImageEnhance.Brightness(img).enhance(factor)
+
+
+def sharpness(img, factor, **__):
+ return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+def _randomly_negate(v):
+ """With 50% prob, negate the value"""
+ return -v if random.random() > 0.5 else v
+
+
+def _rotate_level_to_arg(level, _hparams):
+ # range [-30, 30]
+ level = (level / _MAX_LEVEL) * 30.0
+ level = _randomly_negate(level)
+ return (level,)
+
+
+def _enhance_level_to_arg(level, _hparams):
+ # range [0.1, 1.9]
+ return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
+
+
+def _enhance_increasing_level_to_arg(level, _hparams):
+ # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
+ # range [0.1, 1.9]
+ level = (level / _MAX_LEVEL) * 0.9
+ level = 1.0 + _randomly_negate(level)
+ return (level,)
+
+
+def _shear_level_to_arg(level, _hparams):
+ # range [-0.3, 0.3]
+ level = (level / _MAX_LEVEL) * 0.3
+ level = _randomly_negate(level)
+ return (level,)
+
+
+def _translate_abs_level_to_arg(level, hparams):
+ translate_const = hparams["translate_const"]
+ level = (level / _MAX_LEVEL) * float(translate_const)
+ level = _randomly_negate(level)
+ return (level,)
+
+
+def _translate_rel_level_to_arg(level, hparams):
+ # default range [-0.45, 0.45]
+ translate_pct = hparams.get("translate_pct", 0.45)
+ level = (level / _MAX_LEVEL) * translate_pct
+ level = _randomly_negate(level)
+ return (level,)
+
+
+def _posterize_level_to_arg(level, _hparams):
+ # As per Tensorflow TPU EfficientNet impl
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
+ # intensity/severity of augmentation decreases with level
+ return (int((level / _MAX_LEVEL) * 4),)
+
+
+def _posterize_increasing_level_to_arg(level, hparams):
+ # As per Tensorflow models research and UDA impl
+ # range [4, 0], 'keep 4 down to 0 MSB of original image',
+ # intensity/severity of augmentation increases with level
+ return (4 - _posterize_level_to_arg(level, hparams)[0],)
+
+
+def _posterize_original_level_to_arg(level, _hparams):
+ # As per original AutoAugment paper description
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
+ # intensity/severity of augmentation decreases with level
+ return (int((level / _MAX_LEVEL) * 4) + 4,)
+
+
+def _solarize_level_to_arg(level, _hparams):
+ # range [0, 256]
+ # intensity/severity of augmentation decreases with level
+ return (int((level / _MAX_LEVEL) * 256),)
+
+
+def _solarize_increasing_level_to_arg(level, _hparams):
+ # range [0, 256]
+ # intensity/severity of augmentation increases with level
+ return (256 - _solarize_level_to_arg(level, _hparams)[0],)
+
+
+def _solarize_add_level_to_arg(level, _hparams):
+ # range [0, 110]
+ return (int((level / _MAX_LEVEL) * 110),)
+
+
+LEVEL_TO_ARG = {
+ "AutoContrast": None,
+ "Equalize": None,
+ "Invert": None,
+ "Rotate": _rotate_level_to_arg,
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
+ "Posterize": _posterize_level_to_arg,
+ "PosterizeIncreasing": _posterize_increasing_level_to_arg,
+ "PosterizeOriginal": _posterize_original_level_to_arg,
+ "Solarize": _solarize_level_to_arg,
+ "SolarizeIncreasing": _solarize_increasing_level_to_arg,
+ "SolarizeAdd": _solarize_add_level_to_arg,
+ "Color": _enhance_level_to_arg,
+ "ColorIncreasing": _enhance_increasing_level_to_arg,
+ "Contrast": _enhance_level_to_arg,
+ "ContrastIncreasing": _enhance_increasing_level_to_arg,
+ "Brightness": _enhance_level_to_arg,
+ "BrightnessIncreasing": _enhance_increasing_level_to_arg,
+ "Sharpness": _enhance_level_to_arg,
+ "SharpnessIncreasing": _enhance_increasing_level_to_arg,
+ "ShearX": _shear_level_to_arg,
+ "ShearY": _shear_level_to_arg,
+ "TranslateX": _translate_abs_level_to_arg,
+ "TranslateY": _translate_abs_level_to_arg,
+ "TranslateXRel": _translate_rel_level_to_arg,
+ "TranslateYRel": _translate_rel_level_to_arg,
+}
+
+
+NAME_TO_OP = {
+ "AutoContrast": auto_contrast,
+ "Equalize": equalize,
+ "Invert": invert,
+ "Rotate": rotate,
+ "Posterize": posterize,
+ "PosterizeIncreasing": posterize,
+ "PosterizeOriginal": posterize,
+ "Solarize": solarize,
+ "SolarizeIncreasing": solarize,
+ "SolarizeAdd": solarize_add,
+ "Color": color,
+ "ColorIncreasing": color,
+ "Contrast": contrast,
+ "ContrastIncreasing": contrast,
+ "Brightness": brightness,
+ "BrightnessIncreasing": brightness,
+ "Sharpness": sharpness,
+ "SharpnessIncreasing": sharpness,
+ "ShearX": shear_x,
+ "ShearY": shear_y,
+ "TranslateX": translate_x_abs,
+ "TranslateY": translate_y_abs,
+ "TranslateXRel": translate_x_rel,
+ "TranslateYRel": translate_y_rel,
+}
+
+
+class AugmentOp:
+ """
+ Apply for video.
+ """
+
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ self.aug_fn = NAME_TO_OP[name]
+ self.level_fn = LEVEL_TO_ARG[name]
+ self.prob = prob
+ self.magnitude = magnitude
+ self.hparams = hparams.copy()
+ self.kwargs = {
+ "fillcolor": hparams["img_mean"]
+ if "img_mean" in hparams
+ else _FILL,
+ "resample": hparams["interpolation"]
+ if "interpolation" in hparams
+ else _RANDOM_INTERPOLATION,
+ }
+
+ # If magnitude_std is > 0, we introduce some randomness
+ # in the usually fixed policy and sample magnitude from a normal distribution
+ # with mean `magnitude` and std-dev of `magnitude_std`.
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
+ self.magnitude_std = self.hparams.get("magnitude_std", 0)
+
+ def __call__(self, img_list):
+ if self.prob < 1.0 and random.random() > self.prob:
+ return img_list
+ magnitude = self.magnitude
+ if self.magnitude_std and self.magnitude_std > 0:
+ magnitude = random.gauss(magnitude, self.magnitude_std)
+ magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
+ level_args = (
+ self.level_fn(magnitude, self.hparams)
+ if self.level_fn is not None
+ else ()
+ )
+
+ if isinstance(img_list, list):
+ return [
+ self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
+ ]
+ else:
+ return self.aug_fn(img_list, *level_args, **self.kwargs)
+
+
+_RAND_TRANSFORMS = [
+ "AutoContrast",
+ "Equalize",
+ "Invert",
+ "Rotate",
+ "Posterize",
+ "Solarize",
+ "SolarizeAdd",
+ "Color",
+ "Contrast",
+ "Brightness",
+ "Sharpness",
+ "ShearX",
+ "ShearY",
+ "TranslateXRel",
+ "TranslateYRel",
+]
+
+
+_RAND_INCREASING_TRANSFORMS = [
+ "AutoContrast",
+ "Equalize",
+ "Invert",
+ "Rotate",
+ "PosterizeIncreasing",
+ "SolarizeIncreasing",
+ "SolarizeAdd",
+ "ColorIncreasing",
+ "ContrastIncreasing",
+ "BrightnessIncreasing",
+ "SharpnessIncreasing",
+ "ShearX",
+ "ShearY",
+ "TranslateXRel",
+ "TranslateYRel",
+]
+
+
+# These experimental weights are based loosely on the relative improvements mentioned in paper.
+# They may not result in increased performance, but could likely be tuned to so.
+_RAND_CHOICE_WEIGHTS_0 = {
+ "Rotate": 0.3,
+ "ShearX": 0.2,
+ "ShearY": 0.2,
+ "TranslateXRel": 0.1,
+ "TranslateYRel": 0.1,
+ "Color": 0.025,
+ "Sharpness": 0.025,
+ "AutoContrast": 0.025,
+ "Solarize": 0.005,
+ "SolarizeAdd": 0.005,
+ "Contrast": 0.005,
+ "Brightness": 0.005,
+ "Equalize": 0.005,
+ "Posterize": 0,
+ "Invert": 0,
+}
+
+
+def _select_rand_weights(weight_idx=0, transforms=None):
+ transforms = transforms or _RAND_TRANSFORMS
+ assert weight_idx == 0 # only one set of weights currently
+ rand_weights = _RAND_CHOICE_WEIGHTS_0
+ probs = [rand_weights[k] for k in transforms]
+ probs /= np.sum(probs)
+ return probs
+
+
+def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ transforms = transforms or _RAND_TRANSFORMS
+ return [
+ AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
+ for name in transforms
+ ]
+
+
+class RandAugment:
+ def __init__(self, ops, num_layers=2, choice_weights=None):
+ self.ops = ops
+ self.num_layers = num_layers
+ self.choice_weights = choice_weights
+
+ def __call__(self, img):
+ # no replacement when using weighted choice
+ ops = np.random.choice(
+ self.ops,
+ self.num_layers,
+ replace=self.choice_weights is None,
+ p=self.choice_weights,
+ )
+ for op in ops:
+ img = op(img)
+ return img
+
+
+def rand_augment_transform(config_str, hparams):
+ """
+ RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
+
+ Create a RandAugment transform
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
+ sections, not order sepecific determine
+ 'm' - integer magnitude of rand augment
+ 'n' - integer num layers (number of transform ops selected per image)
+ 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
+ 'mstd' - float std deviation of magnitude noise applied
+ 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
+ :return: A PyTorch compatible Transform
+ """
+ magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
+ num_layers = 2 # default to 2 ops per image
+ weight_idx = None # default to no probability weights for op choice
+ transforms = _RAND_TRANSFORMS
+ config = config_str.split("-")
+ assert config[0] == "rand"
+ config = config[1:]
+ for c in config:
+ cs = re.split(r"(\d.*)", c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == "mstd":
+ # noise param injected via hparams for now
+ hparams.setdefault("magnitude_std", float(val))
+ elif key == "inc":
+ if bool(val):
+ transforms = _RAND_INCREASING_TRANSFORMS
+ elif key == "m":
+ magnitude = int(val)
+ elif key == "n":
+ num_layers = int(val)
+ elif key == "w":
+ weight_idx = int(val)
+ else:
+ assert NotImplementedError
+ ra_ops = rand_augment_ops(
+ magnitude=magnitude, hparams=hparams, transforms=transforms
+ )
+ choice_weights = (
+ None if weight_idx is None else _select_rand_weights(weight_idx)
+ )
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/random_erasing.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/random_erasing.py
new file mode 100644
index 0000000000000000000000000000000000000000..b46547b78b75f01b1c3968ecddaaba3739529a27
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/random_erasing.py
@@ -0,0 +1,173 @@
+"""
+This implementation is based on
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
+pulished under an Apache License 2.0.
+"""
+import math
+import random
+import torch
+
+
+def _get_pixels(
+ per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
+):
+ # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
+ # paths, flip the order so normal is run on CPU if this becomes a problem
+ # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
+ if per_pixel:
+ return torch.empty(patch_size, dtype=dtype, device=device).normal_()
+ elif rand_color:
+ return torch.empty(
+ (patch_size[0], 1, 1), dtype=dtype, device=device
+ ).normal_()
+ else:
+ return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
+
+
+class RandomErasing:
+ """Randomly selects a rectangle region in an image and erases its pixels.
+ 'Random Erasing Data Augmentation' by Zhong et al.
+ See https://arxiv.org/pdf/1708.04896.pdf
+ This variant of RandomErasing is intended to be applied to either a batch
+ or single image tensor after it has been normalized by dataset mean and std.
+ Args:
+ probability: Probability that the Random Erasing operation will be performed.
+ min_area: Minimum percentage of erased area wrt input image area.
+ max_area: Maximum percentage of erased area wrt input image area.
+ min_aspect: Minimum aspect ratio of erased area.
+ mode: pixel color mode, one of 'const', 'rand', or 'pixel'
+ 'const' - erase block is constant color of 0 for all channels
+ 'rand' - erase block is same per-channel random (normal) color
+ 'pixel' - erase block is per-pixel random (normal) color
+ max_count: maximum number of erasing blocks per image, area per box is scaled by count.
+ per-image count is randomly chosen between 1 and this value.
+ """
+
+ def __init__(
+ self,
+ probability=0.5,
+ min_area=0.02,
+ max_area=1 / 3,
+ min_aspect=0.3,
+ max_aspect=None,
+ mode="const",
+ min_count=1,
+ max_count=None,
+ num_splits=0,
+ device="cuda",
+ cube=True,
+ ):
+ self.probability = probability
+ self.min_area = min_area
+ self.max_area = max_area
+ max_aspect = max_aspect or 1 / min_aspect
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+ self.min_count = min_count
+ self.max_count = max_count or min_count
+ self.num_splits = num_splits
+ mode = mode.lower()
+ self.rand_color = False
+ self.per_pixel = False
+ self.cube = cube
+ if mode == "rand":
+ self.rand_color = True # per block random normal
+ elif mode == "pixel":
+ self.per_pixel = True # per pixel random normal
+ else:
+ assert not mode or mode == "const"
+ self.device = device
+
+ def _erase(self, img, chan, img_h, img_w, dtype):
+ if random.random() > self.probability:
+ return
+ area = img_h * img_w
+ count = (
+ self.min_count
+ if self.min_count == self.max_count
+ else random.randint(self.min_count, self.max_count)
+ )
+ for _ in range(count):
+ for _ in range(10):
+ target_area = (
+ random.uniform(self.min_area, self.max_area) * area / count
+ )
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < img_w and h < img_h:
+ top = random.randint(0, img_h - h)
+ left = random.randint(0, img_w - w)
+ img[:, top : top + h, left : left + w] = _get_pixels(
+ self.per_pixel,
+ self.rand_color,
+ (chan, h, w),
+ dtype=dtype,
+ device=self.device,
+ )
+ break
+
+ def _erase_cube(
+ self,
+ img,
+ batch_start,
+ batch_size,
+ chan,
+ img_h,
+ img_w,
+ dtype,
+ ):
+ if random.random() > self.probability:
+ return
+ area = img_h * img_w
+ count = (
+ self.min_count
+ if self.min_count == self.max_count
+ else random.randint(self.min_count, self.max_count)
+ )
+ for _ in range(count):
+ for _ in range(100):
+ target_area = (
+ random.uniform(self.min_area, self.max_area) * area / count
+ )
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < img_w and h < img_h:
+ top = random.randint(0, img_h - h)
+ left = random.randint(0, img_w - w)
+ for i in range(batch_start, batch_size):
+ img_instance = img[i]
+ img_instance[
+ :, top : top + h, left : left + w
+ ] = _get_pixels(
+ self.per_pixel,
+ self.rand_color,
+ (chan, h, w),
+ dtype=dtype,
+ device=self.device,
+ )
+ break
+
+ def __call__(self, input):
+ if len(input.size()) == 3:
+ self._erase(input, *input.size(), input.dtype)
+ else:
+ batch_size, chan, img_h, img_w = input.size()
+ # skip first slice of batch if num_splits is set (for clean portion of samples)
+ batch_start = (
+ batch_size // self.num_splits if self.num_splits > 1 else 0
+ )
+ if self.cube:
+ self._erase_cube(
+ input,
+ batch_start,
+ batch_size,
+ chan,
+ img_h,
+ img_w,
+ input.dtype,
+ )
+ else:
+ for i in range(batch_start, batch_size):
+ self._erase(input[i], chan, img_h, img_w, input.dtype)
+ return input
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/ssv2.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/ssv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..215c29c01bcad3a8ae7b057f6d2567d821b3dc8b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/ssv2.py
@@ -0,0 +1,697 @@
+import os
+import io
+import cv2
+import numpy as np
+import torch
+from torchvision import transforms
+import warnings
+from decord import VideoReader, cpu
+from torch.utils.data import Dataset
+from .random_erasing import RandomErasing
+from .video_transforms import (
+ Compose, Resize, CenterCrop, Normalize,
+ create_random_augment, random_short_side_scale_jitter,
+ random_crop, random_resized_crop_with_shift, random_resized_crop,
+ horizontal_flip, random_short_side_scale_jitter, uniform_crop,
+)
+from .volume_transforms import ClipToTensor
+
+try:
+ from petrel_client.client import Client
+ has_client = True
+except ImportError:
+ has_client = False
+
+
+class SSRawFrameClsDataset(Dataset):
+ """Load your own raw frame classification dataset."""
+
+ def __init__(self, anno_path, prefix='', split=' ', mode='train', clip_len=8,
+ crop_size=224, short_side_size=256, new_height=256, new_width=340,
+ keep_aspect_ratio=True, num_segment=1, num_crop=1, test_num_segment=10,
+ test_num_crop=3, filename_tmpl='img_{:05}.jpg', args=None):
+ self.anno_path = anno_path
+ self.prefix = prefix
+ self.split = split
+ self.mode = mode
+ self.clip_len = clip_len
+ self.crop_size = crop_size
+ self.short_side_size = short_side_size
+ self.new_height = new_height
+ self.new_width = new_width
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.num_segment = num_segment
+ self.test_num_segment = test_num_segment
+ self.num_crop = num_crop
+ self.test_num_crop = test_num_crop
+ self.filename_tmpl = filename_tmpl
+ self.args = args
+ self.aug = False
+ self.rand_erase = False
+
+ self.client = None
+ if has_client:
+ self.client = Client('~/petreloss.conf')
+
+ if self.mode in ['train']:
+ self.aug = True
+ if self.args.reprob > 0:
+ self.rand_erase = True
+ if VideoReader is None:
+ raise ImportError(
+ "Unable to import `decord` which is required to read videos.")
+
+ import pandas as pd
+ cleaned = pd.read_csv(self.anno_path, header=None, delimiter=self.split)
+ self.dataset_samples = list(cleaned.values[:, 0].astype('str'))
+ self.total_frames = list(cleaned.values[:, 1])
+ self.label_array = list(cleaned.values[:, -1])
+
+ if (mode == 'train'):
+ pass
+
+ elif (mode == 'validation'):
+ self.data_transform = Compose([
+ Resize(self.short_side_size,
+ interpolation='bilinear'),
+ CenterCrop(size=(self.crop_size,
+ self.crop_size)),
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ elif mode == 'test':
+ self.data_resize = Compose([
+ Resize(size=(short_side_size),
+ interpolation='bilinear')
+ ])
+ self.data_transform = Compose([
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.test_seg = []
+ self.test_dataset = []
+ self.test_total_frames = []
+ self.test_label_array = []
+ for ck in range(self.test_num_segment):
+ for cp in range(self.test_num_crop):
+ for idx in range(len(self.label_array)):
+ self.test_seg.append((ck, cp))
+ self.test_dataset.append(self.dataset_samples[idx])
+ self.test_total_frames.append(self.total_frames[idx])
+ self.test_label_array.append(self.label_array[idx])
+
+ def __getitem__(self, index):
+ if self.mode == 'train':
+ args = self.args
+ scale_t = 1
+
+ sample = self.dataset_samples[index]
+ total_frame = self.total_frames[index]
+ buffer = self.load_frame(sample,
+ total_frame,
+ sample_rate_scale=scale_t) # T H W C
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn(
+ "video {} not correctly loaded during training".format(
+ sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ total_frame = self.total_frames[index]
+ buffer = self.load_frame(sample,
+ total_frame,
+ sample_rate_scale=scale_t)
+
+ if args.num_sample > 1:
+ frame_list = []
+ label_list = []
+ index_list = []
+ for _ in range(args.num_sample):
+ new_frames = self._aug_frame(buffer, args)
+ label = self.label_array[index]
+ frame_list.append(new_frames)
+ label_list.append(label)
+ index_list.append(index)
+ return frame_list, label_list, index_list, {}
+ else:
+ buffer = self._aug_frame(buffer, args)
+
+ return buffer, self.label_array[index], index, {}
+
+ elif self.mode == 'validation':
+ sample = self.dataset_samples[index]
+ total_frame = self.total_frames[index]
+ buffer = self.load_frame(sample, total_frame)
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn(
+ "video {} not correctly loaded during validation".
+ format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.load_frame(sample, total_frame)
+ buffer = self.data_transform(buffer)
+ return buffer, self.label_array[index], sample.split(
+ "/")[-1].split(".")[0]
+
+ elif self.mode == 'test':
+ sample = self.test_dataset[index]
+ total_frame = self.test_total_frames[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.load_frame(sample, total_frame)
+
+ while len(buffer) == 0:
+ warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
+ str(self.test_dataset[index]), chunk_nb, split_nb))
+ index = np.random.randint(self.__len__())
+ sample = self.test_dataset[index]
+ total_frame = self.test_total_frames[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.load_frame(sample, total_frame)
+
+ buffer = self.data_resize(buffer)
+ if isinstance(buffer, list):
+ buffer = np.stack(buffer, 0)
+
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
+ / (self.test_num_crop - 1)
+ temporal_start = chunk_nb
+ spatial_start = int(split_nb * spatial_step)
+ if buffer.shape[1] >= buffer.shape[2]:
+ buffer = buffer[temporal_start::self.test_num_segment, \
+ spatial_start:spatial_start + self.short_side_size, :, :]
+ else:
+ buffer = buffer[temporal_start::self.test_num_segment, \
+ :, spatial_start:spatial_start + self.short_side_size, :]
+
+ buffer = self.data_transform(buffer)
+ return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
+ chunk_nb, split_nb
+ else:
+ raise NameError('mode {} unkown'.format(self.mode))
+
+ def _aug_frame(
+ self,
+ buffer,
+ args,
+ ):
+
+ aug_transform = create_random_augment(
+ input_size=(self.crop_size, self.crop_size),
+ auto_augment=args.aa,
+ interpolation=args.train_interpolation,
+ )
+
+ buffer = [transforms.ToPILImage()(frame) for frame in buffer]
+
+ buffer = aug_transform(buffer)
+
+ buffer = [transforms.ToTensor()(img) for img in buffer]
+ buffer = torch.stack(buffer) # T C H W
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
+
+ # T H W C
+ buffer = tensor_normalize(buffer, [0.485, 0.456, 0.406],
+ [0.229, 0.224, 0.225])
+ # T H W C -> C T H W.
+ buffer = buffer.permute(3, 0, 1, 2)
+ # Perform data augmentation.
+ scl, asp = (
+ [0.08, 1.0],
+ [0.75, 1.3333],
+ )
+
+ buffer = spatial_sampling(
+ buffer,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=self.crop_size,
+ random_horizontal_flip=False if args.data_set == 'SSV2' else True,
+ inverse_uniform_sampling=False,
+ aspect_ratio=asp,
+ scale=scl,
+ motion_shift=False)
+
+ if self.rand_erase:
+ erase_transform = RandomErasing(
+ args.reprob,
+ mode=args.remode,
+ max_count=args.recount,
+ num_splits=args.recount,
+ device="cpu",
+ )
+ buffer = buffer.permute(1, 0, 2, 3)
+ buffer = erase_transform(buffer)
+ buffer = buffer.permute(1, 0, 2, 3)
+
+ return buffer
+
+ def load_frame(self, sample, num_frames, sample_rate_scale=1):
+ """Load video content using Decord"""
+ fname = sample
+ fname = os.path.join(self.prefix, fname)
+
+ if self.mode == 'test':
+ tick = num_frames / float(self.num_segment)
+ all_index = []
+ for t_seg in range(self.test_num_segment):
+ tmp_index = [
+ int(t_seg * tick / self.test_num_segment + tick * x)
+ for x in range(self.num_segment)
+ ]
+ all_index.extend(tmp_index)
+ all_index = list(np.sort(np.array(all_index)))
+ imgs = []
+ for idx in all_index:
+ frame_fname = os.path.join(fname, self.filename_tmpl.format(idx + 1))
+ if "s3://" in fname:
+ img_bytes = self.client.get(frame_fname)
+ else:
+ with open(frame_fname, 'rb') as f:
+ img_bytes = f.read()
+ img_np = np.frombuffer(img_bytes, np.uint8)
+ img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ imgs.append(img)
+ buffer = np.array(imgs)
+ return buffer
+
+ # handle temporal segments
+ average_duration = num_frames // self.num_segment
+ all_index = []
+ if average_duration > 0:
+ if self.mode == 'validation':
+ all_index = list(
+ np.multiply(list(range(self.num_segment)),
+ average_duration) +
+ np.ones(self.num_segment, dtype=int) *
+ (average_duration // 2))
+ else:
+ all_index = list(
+ np.multiply(list(range(self.num_segment)),
+ average_duration) +
+ np.random.randint(average_duration, size=self.num_segment))
+ elif num_frames > self.num_segment:
+ if self.mode == 'validation':
+ all_index = list(range(self.num_segment))
+ else:
+ all_index = list(
+ np.sort(
+ np.random.randint(num_frames, size=self.num_segment)))
+ else:
+ all_index = [0] * (self.num_segment - num_frames) + list(
+ range(num_frames))
+ all_index = list(np.array(all_index))
+ imgs = []
+ for idx in all_index:
+ frame_fname = os.path.join(fname, self.filename_tmpl.format(idx + 1))
+ if "s3://" in fname:
+ img_bytes = self.client.get(frame_fname)
+ else:
+ with open(frame_fname, 'rb') as f:
+ img_bytes = f.read()
+ img_np = np.frombuffer(img_bytes, np.uint8)
+ img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ imgs.append(img)
+ buffer = np.array(imgs)
+ return buffer
+
+ def __len__(self):
+ if self.mode != 'test':
+ return len(self.dataset_samples)
+ else:
+ return len(self.test_dataset)
+
+
+class SSVideoClsDataset(Dataset):
+ """Load your own video classification dataset."""
+
+ def __init__(self, anno_path, prefix='', split=' ', mode='train', clip_len=8,
+ crop_size=224, short_side_size=256, new_height=256,
+ new_width=340, keep_aspect_ratio=True, num_segment=1,
+ num_crop=1, test_num_segment=10, test_num_crop=3, filename_tmpl=None, args=None):
+ self.anno_path = anno_path
+ self.prefix = prefix
+ self.split = split
+ self.mode = mode
+ self.clip_len = clip_len
+ self.crop_size = crop_size
+ self.short_side_size = short_side_size
+ self.new_height = new_height
+ self.new_width = new_width
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.num_segment = num_segment
+ self.test_num_segment = test_num_segment
+ self.num_crop = num_crop
+ self.test_num_crop = test_num_crop
+ self.args = args
+ self.aug = False
+ self.rand_erase = False
+
+ self.client = None
+ if has_client:
+ self.client = Client('~/petreloss.conf')
+
+ if self.mode in ['train']:
+ self.aug = True
+ if self.args.reprob > 0:
+ self.rand_erase = True
+ if VideoReader is None:
+ raise ImportError("Unable to import `decord` which is required to read videos.")
+
+ import pandas as pd
+ cleaned = pd.read_csv(self.anno_path, header=None, delimiter=self.split)
+ self.dataset_samples = list(cleaned.values[:, 0])
+ self.label_array = list(cleaned.values[:, 1])
+
+ if (mode == 'train'):
+ pass
+
+ elif (mode == 'validation'):
+ self.data_transform = Compose([
+ Resize(self.short_side_size, interpolation='bilinear'),
+ CenterCrop(size=(self.crop_size, self.crop_size)),
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ elif mode == 'test':
+ self.data_resize = Compose([
+ Resize(size=(short_side_size), interpolation='bilinear')
+ ])
+ self.data_transform = Compose([
+ ClipToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ ])
+ self.test_seg = []
+ self.test_dataset = []
+ self.test_label_array = []
+ for ck in range(self.test_num_segment):
+ for cp in range(self.test_num_crop):
+ for idx in range(len(self.label_array)):
+ sample_label = self.label_array[idx]
+ self.test_label_array.append(sample_label)
+ self.test_dataset.append(self.dataset_samples[idx])
+ self.test_seg.append((ck, cp))
+
+ def __getitem__(self, index):
+ if self.mode == 'train':
+ args = self.args
+ scale_t = 1
+
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during training".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
+
+ if args.num_sample > 1:
+ frame_list = []
+ label_list = []
+ index_list = []
+ for _ in range(args.num_sample):
+ new_frames = self._aug_frame(buffer, args)
+ label = self.label_array[index]
+ frame_list.append(new_frames)
+ label_list.append(label)
+ index_list.append(index)
+ return frame_list, label_list, index_list, {}
+ else:
+ buffer = self._aug_frame(buffer, args)
+
+ return buffer, self.label_array[index], index, {}
+
+ elif self.mode == 'validation':
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample)
+ if len(buffer) == 0:
+ while len(buffer) == 0:
+ warnings.warn("video {} not correctly loaded during validation".format(sample))
+ index = np.random.randint(self.__len__())
+ sample = self.dataset_samples[index]
+ buffer = self.loadvideo_decord(sample)
+ buffer = self.data_transform(buffer)
+ return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
+
+ elif self.mode == 'test':
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.loadvideo_decord(sample)
+
+ while len(buffer) == 0:
+ warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
+ str(self.test_dataset[index]), chunk_nb, split_nb))
+ index = np.random.randint(self.__len__())
+ sample = self.test_dataset[index]
+ chunk_nb, split_nb = self.test_seg[index]
+ buffer = self.loadvideo_decord(sample)
+
+ buffer = self.data_resize(buffer)
+ if isinstance(buffer, list):
+ buffer = np.stack(buffer, 0)
+
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
+ / (self.test_num_crop - 1)
+ temporal_start = chunk_nb # 0/1
+ spatial_start = int(split_nb * spatial_step)
+ if buffer.shape[1] >= buffer.shape[2]:
+ buffer = buffer[temporal_start::2, \
+ spatial_start:spatial_start + self.short_side_size, :, :]
+ else:
+ buffer = buffer[temporal_start::2, \
+ :, spatial_start:spatial_start + self.short_side_size, :]
+
+ buffer = self.data_transform(buffer)
+ return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
+ chunk_nb, split_nb
+ else:
+ raise NameError('mode {} unkown'.format(self.mode))
+
+ def _aug_frame(
+ self,
+ buffer,
+ args,
+ ):
+
+ aug_transform = create_random_augment(
+ input_size=(self.crop_size, self.crop_size),
+ auto_augment=args.aa,
+ interpolation=args.train_interpolation,
+ )
+
+ buffer = [
+ transforms.ToPILImage()(frame) for frame in buffer
+ ]
+
+ buffer = aug_transform(buffer)
+
+ buffer = [transforms.ToTensor()(img) for img in buffer]
+ buffer = torch.stack(buffer) # T C H W
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
+
+ # T H W C
+ buffer = tensor_normalize(
+ buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+ )
+ # T H W C -> C T H W.
+ buffer = buffer.permute(3, 0, 1, 2)
+ # Perform data augmentation.
+ scl, asp = (
+ [0.08, 1.0],
+ [0.75, 1.3333],
+ )
+
+ buffer = spatial_sampling(
+ buffer,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=self.crop_size,
+ random_horizontal_flip=False if args.data_set == 'SSV2' else True,
+ inverse_uniform_sampling=False,
+ aspect_ratio=asp,
+ scale=scl,
+ motion_shift=False
+ )
+
+ if self.rand_erase:
+ erase_transform = RandomErasing(
+ args.reprob,
+ mode=args.remode,
+ max_count=args.recount,
+ num_splits=args.recount,
+ device="cpu",
+ )
+ buffer = buffer.permute(1, 0, 2, 3)
+ buffer = erase_transform(buffer)
+ buffer = buffer.permute(1, 0, 2, 3)
+
+ return buffer
+
+
+ def loadvideo_decord(self, sample, sample_rate_scale=1):
+ """Load video content using Decord"""
+ fname = sample
+ fname = os.path.join(self.prefix, fname)
+
+ try:
+ if self.keep_aspect_ratio:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
+ else:
+ if "s3://" in fname:
+ video_bytes = self.client.get(fname)
+ vr = VideoReader(io.BytesIO(video_bytes),
+ width=self.new_width,
+ height=self.new_height,
+ num_threads=1,
+ ctx=cpu(0))
+ else:
+ vr = VideoReader(fname, width=self.new_width, height=self.new_height,
+ num_threads=1, ctx=cpu(0))
+ except:
+ print("video cannot be loaded by decord: ", fname)
+ return []
+
+ if self.mode == 'test':
+ tick = len(vr) / float(self.num_segment)
+ all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] +
+ [int(tick * x) for x in range(self.num_segment)]))
+ while len(all_index) < (self.num_segment * self.test_num_segment):
+ all_index.append(all_index[-1])
+ all_index = np.sort(np.array(all_index))
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+ elif self.mode == 'validation':
+ tick = len(vr) / float(self.num_segment)
+ all_index = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)])
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+
+ # handle temporal segments
+ average_duration = len(vr) // self.num_segment
+ if average_duration > 0:
+ all_index = list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration,
+ size=self.num_segment))
+ elif len(vr) > self.num_segment:
+ all_index = list(np.sort(np.random.randint(len(vr), size=self.num_segment)))
+ else:
+ all_index = list(np.zeros((self.num_segment,)))
+ vr.seek(0)
+ buffer = vr.get_batch(all_index).asnumpy()
+ return buffer
+
+ def __len__(self):
+ if self.mode != 'test':
+ return len(self.dataset_samples)
+ else:
+ return len(self.test_dataset)
+
+
+def spatial_sampling(
+ frames,
+ spatial_idx=-1,
+ min_scale=256,
+ max_scale=320,
+ crop_size=224,
+ random_horizontal_flip=True,
+ inverse_uniform_sampling=False,
+ aspect_ratio=None,
+ scale=None,
+ motion_shift=False,
+):
+ """
+ Perform spatial sampling on the given video frames. If spatial_idx is
+ -1, perform random scale, random crop, and random flip on the given
+ frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
+ with the given spatial_idx.
+ Args:
+ frames (tensor): frames of images sampled from the video. The
+ dimension is `num frames` x `height` x `width` x `channel`.
+ spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
+ or 2, perform left, center, right crop if width is larger than
+ height, and perform top, center, buttom crop if height is larger
+ than width.
+ min_scale (int): the minimal size of scaling.
+ max_scale (int): the maximal size of scaling.
+ crop_size (int): the size of height and width used to crop the
+ frames.
+ inverse_uniform_sampling (bool): if True, sample uniformly in
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
+ scale. If False, take a uniform sample from [min_scale,
+ max_scale].
+ aspect_ratio (list): Aspect ratio range for resizing.
+ scale (list): Scale range for resizing.
+ motion_shift (bool): Whether to apply motion shift for resizing.
+ Returns:
+ frames (tensor): spatially sampled frames.
+ """
+ assert spatial_idx in [-1, 0, 1, 2]
+ if spatial_idx == -1:
+ if aspect_ratio is None and scale is None:
+ frames, _ = random_short_side_scale_jitter(
+ images=frames,
+ min_size=min_scale,
+ max_size=max_scale,
+ inverse_uniform_sampling=inverse_uniform_sampling,
+ )
+ frames, _ = random_crop(frames, crop_size)
+ else:
+ transform_func = (
+ random_resized_crop_with_shift
+ if motion_shift
+ else random_resized_crop
+ )
+ frames = transform_func(
+ images=frames,
+ target_height=crop_size,
+ target_width=crop_size,
+ scale=scale,
+ ratio=aspect_ratio,
+ )
+ if random_horizontal_flip:
+ frames, _ = horizontal_flip(0.5, frames)
+ else:
+ # The testing is deterministic and no jitter should be performed.
+ # min_scale, max_scale, and crop_size are expect to be the same.
+ assert len({min_scale, max_scale, crop_size}) == 1
+ frames, _ = random_short_side_scale_jitter(
+ frames, min_scale, max_scale
+ )
+ frames, _ = uniform_crop(frames, crop_size, spatial_idx)
+ return frames
+
+
+def tensor_normalize(tensor, mean, std):
+ """
+ Normalize a given tensor by subtracting the mean and dividing the std.
+ Args:
+ tensor (tensor): tensor to normalize.
+ mean (tensor or list): mean value to subtract.
+ std (tensor or list): std to divide.
+ """
+ if tensor.dtype == torch.uint8:
+ tensor = tensor.float()
+ tensor = tensor / 255.0
+ if type(mean) == list:
+ mean = torch.tensor(mean)
+ if type(std) == list:
+ std = torch.tensor(std)
+ tensor = tensor - mean
+ tensor = tensor / std
+ return tensor
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/transforms.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d7fe0e280871793d69bd9dc6d1ea84c387cf0d9
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/transforms.py
@@ -0,0 +1,231 @@
+import torch
+import torchvision.transforms.functional as F
+import warnings
+import random
+import numpy as np
+import torchvision
+from PIL import Image, ImageOps
+import numbers
+
+
+class GroupRandomCrop(object):
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+
+ w, h = img_group[0].size
+ th, tw = self.size
+
+ out_images = list()
+
+ x1 = random.randint(0, w - tw)
+ y1 = random.randint(0, h - th)
+
+ for img in img_group:
+ assert(img.size[0] == w and img.size[1] == h)
+ if w == tw and h == th:
+ out_images.append(img)
+ else:
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+ return (out_images, label)
+
+
+class GroupCenterCrop(object):
+ def __init__(self, size):
+ self.worker = torchvision.transforms.CenterCrop(size)
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+ return ([self.worker(img) for img in img_group], label)
+
+
+class GroupRandomHorizontalFlip(object):
+ def __init__(self, flip=False):
+ self.flip = flip
+
+ def __call__(self, img_tuple):
+ v = random.random()
+ if self.flip and v < 0.5:
+ img_group, label = img_tuple
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
+ return (ret, label)
+ else:
+ return img_tuple
+
+
+class GroupNormalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, tensor_tuple):
+ tensor, label = tensor_tuple
+ rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
+ rep_std = self.std * (tensor.size()[0]//len(self.std))
+
+ # TODO: make efficient
+ for t, m, s in zip(tensor, rep_mean, rep_std):
+ t.sub_(m).div_(s)
+
+ return (tensor,label)
+
+
+class GroupGrayScale(object):
+ def __init__(self, size):
+ self.worker = torchvision.transforms.Grayscale(size)
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+ return ([self.worker(img) for img in img_group], label)
+
+
+class GroupColorJitter(object):
+ def __init__(self, size):
+ self.worker = torchvision.transforms.ColorJitter(
+ brightness=size, contrast=size, saturation=size
+ )
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+ return ([self.worker(img) for img in img_group], label)
+
+
+class GroupScale(object):
+ """ Rescales the input PIL.Image to the given 'size'.
+ 'size' will be the size of the smaller edge.
+ For example, if height > width, then image will be
+ rescaled to (size * height / width, size)
+ size: size of the smaller edge
+ interpolation: Default: PIL.Image.BILINEAR
+ """
+
+ def __init__(self, size, interpolation=Image.BILINEAR):
+ self.worker = torchvision.transforms.Resize(size, interpolation)
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+ return ([self.worker(img) for img in img_group], label)
+
+
+class GroupMultiScaleCrop(object):
+
+ def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
+ self.scales = scales if scales is not None else [1, 875, .75, .66]
+ self.max_distort = max_distort
+ self.fix_crop = fix_crop
+ self.more_fix_crop = more_fix_crop
+ self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
+ self.interpolation = Image.BILINEAR
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+
+ im_size = img_group[0].size
+
+ crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
+ crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
+ ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
+ return (ret_img_group, label)
+
+ def _sample_crop_size(self, im_size):
+ image_w, image_h = im_size[0], im_size[1]
+
+ # find a crop size
+ base_size = min(image_w, image_h)
+ crop_sizes = [int(base_size * x) for x in self.scales]
+ crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
+ crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
+
+ pairs = []
+ for i, h in enumerate(crop_h):
+ for j, w in enumerate(crop_w):
+ if abs(i - j) <= self.max_distort:
+ pairs.append((w, h))
+
+ crop_pair = random.choice(pairs)
+ if not self.fix_crop:
+ w_offset = random.randint(0, image_w - crop_pair[0])
+ h_offset = random.randint(0, image_h - crop_pair[1])
+ else:
+ w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
+
+ return crop_pair[0], crop_pair[1], w_offset, h_offset
+
+ def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
+ offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
+ return random.choice(offsets)
+
+ @staticmethod
+ def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
+ w_step = (image_w - crop_w) // 4
+ h_step = (image_h - crop_h) // 4
+
+ ret = list()
+ ret.append((0, 0)) # upper left
+ ret.append((4 * w_step, 0)) # upper right
+ ret.append((0, 4 * h_step)) # lower left
+ ret.append((4 * w_step, 4 * h_step)) # lower right
+ ret.append((2 * w_step, 2 * h_step)) # center
+
+ if more_fix_crop:
+ ret.append((0, 2 * h_step)) # center left
+ ret.append((4 * w_step, 2 * h_step)) # center right
+ ret.append((2 * w_step, 4 * h_step)) # lower center
+ ret.append((2 * w_step, 0 * h_step)) # upper center
+
+ ret.append((1 * w_step, 1 * h_step)) # upper left quarter
+ ret.append((3 * w_step, 1 * h_step)) # upper right quarter
+ ret.append((1 * w_step, 3 * h_step)) # lower left quarter
+ ret.append((3 * w_step, 3 * h_step)) # lower righ quarter
+ return ret
+
+
+class Stack(object):
+
+ def __init__(self, roll=False):
+ self.roll = roll
+
+ def __call__(self, img_tuple):
+ img_group, label = img_tuple
+
+ if img_group[0].mode == 'L':
+ return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
+ elif img_group[0].mode == 'RGB':
+ if self.roll:
+ return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
+ else:
+ return (np.concatenate(img_group, axis=2), label)
+
+
+class ToTorchFormatTensor(object):
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
+ def __init__(self, div=True):
+ self.div = div
+
+ def __call__(self, pic_tuple):
+ pic, label = pic_tuple
+
+ if isinstance(pic, np.ndarray):
+ # handle numpy array
+ img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
+ else:
+ # handle PIL Image
+ img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
+ # put it from HWC to CHW format
+ # yikes, this transpose takes 80% of the loading time/CPU
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
+ return (img.float().div(255.) if self.div else img.float(), label)
+
+
+class IdentityTransform(object):
+
+ def __call__(self, data):
+ return data
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/video_transforms.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/video_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..855b2395130f78782fd6fb4ea78106bd345548fb
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/video_transforms.py
@@ -0,0 +1,1285 @@
+#!/usr/bin/env python3
+import math
+import numpy as np
+import random
+import torch
+import torchvision.transforms.functional as F
+from PIL import Image
+from torchvision import transforms
+
+from .rand_augment import rand_augment_transform
+from .auto_augment import auto_augment_transform
+from .random_erasing import RandomErasing
+
+import numbers
+import PIL
+import torchvision
+
+import functional as FF
+
+_pil_interpolation_to_str = {
+ Image.NEAREST: "PIL.Image.NEAREST",
+ Image.BILINEAR: "PIL.Image.BILINEAR",
+ Image.BICUBIC: "PIL.Image.BICUBIC",
+ Image.LANCZOS: "PIL.Image.LANCZOS",
+ Image.HAMMING: "PIL.Image.HAMMING",
+ Image.BOX: "PIL.Image.BOX",
+}
+
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+def _pil_interp(method):
+ if method == "bicubic":
+ return Image.BICUBIC
+ elif method == "lanczos":
+ return Image.LANCZOS
+ elif method == "hamming":
+ return Image.HAMMING
+ else:
+ return Image.BILINEAR
+
+
+def random_short_side_scale_jitter(
+ images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
+):
+ """
+ Perform a spatial short scale jittering on the given images and
+ corresponding boxes.
+ Args:
+ images (tensor): images to perform scale jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ min_size (int): the minimal size to scale the frames.
+ max_size (int): the maximal size to scale the frames.
+ boxes (ndarray): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ inverse_uniform_sampling (bool): if True, sample uniformly in
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
+ scale. If False, take a uniform sample from [min_scale, max_scale].
+ Returns:
+ (tensor): the scaled images with dimension of
+ `num frames` x `channel` x `new height` x `new width`.
+ (ndarray or None): the scaled boxes with dimension of
+ `num boxes` x 4.
+ """
+ if inverse_uniform_sampling:
+ size = int(
+ round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
+ )
+ else:
+ size = int(round(np.random.uniform(min_size, max_size)))
+
+ height = images.shape[2]
+ width = images.shape[3]
+ if (width <= height and width == size) or (
+ height <= width and height == size
+ ):
+ return images, boxes
+ new_width = size
+ new_height = size
+ if width < height:
+ new_height = int(math.floor((float(height) / width) * size))
+ if boxes is not None:
+ boxes = boxes * float(new_height) / height
+ else:
+ new_width = int(math.floor((float(width) / height) * size))
+ if boxes is not None:
+ boxes = boxes * float(new_width) / width
+
+ return (
+ torch.nn.functional.interpolate(
+ images,
+ size=(new_height, new_width),
+ mode="bilinear",
+ align_corners=False,
+ ),
+ boxes,
+ )
+
+
+def crop_boxes(boxes, x_offset, y_offset):
+ """
+ Peform crop on the bounding boxes given the offsets.
+ Args:
+ boxes (ndarray or None): bounding boxes to peform crop. The dimension
+ is `num boxes` x 4.
+ x_offset (int): cropping offset in the x axis.
+ y_offset (int): cropping offset in the y axis.
+ Returns:
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ cropped_boxes = boxes.copy()
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
+
+ return cropped_boxes
+
+
+def random_crop(images, size, boxes=None):
+ """
+ Perform random spatial crop on the given images and corresponding boxes.
+ Args:
+ images (tensor): images to perform random crop. The dimension is
+ `num frames` x `channel` x `height` x `width`.
+ size (int): the size of height and width to crop on the image.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ Returns:
+ cropped (tensor): cropped images with dimension of
+ `num frames` x `channel` x `size` x `size`.
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ if images.shape[2] == size and images.shape[3] == size:
+ return images
+ height = images.shape[2]
+ width = images.shape[3]
+ y_offset = 0
+ if height > size:
+ y_offset = int(np.random.randint(0, height - size))
+ x_offset = 0
+ if width > size:
+ x_offset = int(np.random.randint(0, width - size))
+ cropped = images[
+ :, :, y_offset : y_offset + size, x_offset : x_offset + size
+ ]
+
+ cropped_boxes = (
+ crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+ )
+
+ return cropped, cropped_boxes
+
+
+def horizontal_flip(prob, images, boxes=None):
+ """
+ Perform horizontal flip on the given images and corresponding boxes.
+ Args:
+ prob (float): probility to flip the images.
+ images (tensor): images to perform horizontal flip, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ Returns:
+ images (tensor): images with dimension of
+ `num frames` x `channel` x `height` x `width`.
+ flipped_boxes (ndarray or None): the flipped boxes with dimension of
+ `num boxes` x 4.
+ """
+ if boxes is None:
+ flipped_boxes = None
+ else:
+ flipped_boxes = boxes.copy()
+
+ if np.random.uniform() < prob:
+ images = images.flip((-1))
+
+ if len(images.shape) == 3:
+ width = images.shape[2]
+ elif len(images.shape) == 4:
+ width = images.shape[3]
+ else:
+ raise NotImplementedError("Dimension does not supported")
+ if boxes is not None:
+ flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
+
+ return images, flipped_boxes
+
+
+def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
+ """
+ Perform uniform spatial sampling on the images and corresponding boxes.
+ Args:
+ images (tensor): images to perform uniform crop. The dimension is
+ `num frames` x `channel` x `height` x `width`.
+ size (int): size of height and weight to crop the images.
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
+ crop if height is larger than width.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ scale_size (int): optinal. If not None, resize the images to scale_size before
+ performing any crop.
+ Returns:
+ cropped (tensor): images with dimension of
+ `num frames` x `channel` x `size` x `size`.
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ assert spatial_idx in [0, 1, 2]
+ ndim = len(images.shape)
+ if ndim == 3:
+ images = images.unsqueeze(0)
+ height = images.shape[2]
+ width = images.shape[3]
+
+ if scale_size is not None:
+ if width <= height:
+ width, height = scale_size, int(height / width * scale_size)
+ else:
+ width, height = int(width / height * scale_size), scale_size
+ images = torch.nn.functional.interpolate(
+ images,
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ y_offset = int(math.ceil((height - size) / 2))
+ x_offset = int(math.ceil((width - size) / 2))
+
+ if height > width:
+ if spatial_idx == 0:
+ y_offset = 0
+ elif spatial_idx == 2:
+ y_offset = height - size
+ else:
+ if spatial_idx == 0:
+ x_offset = 0
+ elif spatial_idx == 2:
+ x_offset = width - size
+ cropped = images[
+ :, :, y_offset : y_offset + size, x_offset : x_offset + size
+ ]
+ cropped_boxes = (
+ crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+ )
+ if ndim == 3:
+ cropped = cropped.squeeze(0)
+ return cropped, cropped_boxes
+
+
+def clip_boxes_to_image(boxes, height, width):
+ """
+ Clip an array of boxes to an image with the given height and width.
+ Args:
+ boxes (ndarray): bounding boxes to perform clipping.
+ Dimension is `num boxes` x 4.
+ height (int): given image height.
+ width (int): given image width.
+ Returns:
+ clipped_boxes (ndarray): the clipped boxes with dimension of
+ `num boxes` x 4.
+ """
+ clipped_boxes = boxes.copy()
+ clipped_boxes[:, [0, 2]] = np.minimum(
+ width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
+ )
+ clipped_boxes[:, [1, 3]] = np.minimum(
+ height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
+ )
+ return clipped_boxes
+
+
+def blend(images1, images2, alpha):
+ """
+ Blend two images with a given weight alpha.
+ Args:
+ images1 (tensor): the first images to be blended, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ images2 (tensor): the second images to be blended, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ alpha (float): the blending weight.
+ Returns:
+ (tensor): blended images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ return images1 * alpha + images2 * (1 - alpha)
+
+
+def grayscale(images):
+ """
+ Get the grayscale for the input images. The channels of images should be
+ in order BGR.
+ Args:
+ images (tensor): the input images for getting grayscale. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ Returns:
+ img_gray (tensor): blended images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ # R -> 0.299, G -> 0.587, B -> 0.114.
+ img_gray = torch.tensor(images)
+ gray_channel = (
+ 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
+ )
+ img_gray[:, 0] = gray_channel
+ img_gray[:, 1] = gray_channel
+ img_gray[:, 2] = gray_channel
+ return img_gray
+
+
+def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
+ """
+ Perfrom a color jittering on the input images. The channels of images
+ should be in order BGR.
+ Args:
+ images (tensor): images to perform color jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ img_brightness (float): jitter ratio for brightness.
+ img_contrast (float): jitter ratio for contrast.
+ img_saturation (float): jitter ratio for saturation.
+ Returns:
+ images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+
+ jitter = []
+ if img_brightness != 0:
+ jitter.append("brightness")
+ if img_contrast != 0:
+ jitter.append("contrast")
+ if img_saturation != 0:
+ jitter.append("saturation")
+
+ if len(jitter) > 0:
+ order = np.random.permutation(np.arange(len(jitter)))
+ for idx in range(0, len(jitter)):
+ if jitter[order[idx]] == "brightness":
+ images = brightness_jitter(img_brightness, images)
+ elif jitter[order[idx]] == "contrast":
+ images = contrast_jitter(img_contrast, images)
+ elif jitter[order[idx]] == "saturation":
+ images = saturation_jitter(img_saturation, images)
+ return images
+
+
+def brightness_jitter(var, images):
+ """
+ Perfrom brightness jittering on the input images. The channels of images
+ should be in order BGR.
+ Args:
+ var (float): jitter ratio for brightness.
+ images (tensor): images to perform color jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ Returns:
+ images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ alpha = 1.0 + np.random.uniform(-var, var)
+
+ img_bright = torch.zeros(images.shape)
+ images = blend(images, img_bright, alpha)
+ return images
+
+
+def contrast_jitter(var, images):
+ """
+ Perfrom contrast jittering on the input images. The channels of images
+ should be in order BGR.
+ Args:
+ var (float): jitter ratio for contrast.
+ images (tensor): images to perform color jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ Returns:
+ images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ alpha = 1.0 + np.random.uniform(-var, var)
+
+ img_gray = grayscale(images)
+ img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
+ images = blend(images, img_gray, alpha)
+ return images
+
+
+def saturation_jitter(var, images):
+ """
+ Perfrom saturation jittering on the input images. The channels of images
+ should be in order BGR.
+ Args:
+ var (float): jitter ratio for saturation.
+ images (tensor): images to perform color jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ Returns:
+ images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ alpha = 1.0 + np.random.uniform(-var, var)
+ img_gray = grayscale(images)
+ images = blend(images, img_gray, alpha)
+
+ return images
+
+
+def lighting_jitter(images, alphastd, eigval, eigvec):
+ """
+ Perform AlexNet-style PCA jitter on the given images.
+ Args:
+ images (tensor): images to perform lighting jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ alphastd (float): jitter ratio for PCA jitter.
+ eigval (list): eigenvalues for PCA jitter.
+ eigvec (list[list]): eigenvectors for PCA jitter.
+ Returns:
+ out_images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ if alphastd == 0:
+ return images
+ # generate alpha1, alpha2, alpha3.
+ alpha = np.random.normal(0, alphastd, size=(1, 3))
+ eig_vec = np.array(eigvec)
+ eig_val = np.reshape(eigval, (1, 3))
+ rgb = np.sum(
+ eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
+ axis=1,
+ )
+ out_images = torch.zeros_like(images)
+ if len(images.shape) == 3:
+ # C H W
+ channel_dim = 0
+ elif len(images.shape) == 4:
+ # T C H W
+ channel_dim = 1
+ else:
+ raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
+
+ for idx in range(images.shape[channel_dim]):
+ # C H W
+ if len(images.shape) == 3:
+ out_images[idx] = images[idx] + rgb[2 - idx]
+ # T C H W
+ elif len(images.shape) == 4:
+ out_images[:, idx] = images[:, idx] + rgb[2 - idx]
+ else:
+ raise NotImplementedError(
+ f"Unsupported dimension {len(images.shape)}"
+ )
+
+ return out_images
+
+
+def color_normalization(images, mean, stddev):
+ """
+ Perform color nomration on the given images.
+ Args:
+ images (tensor): images to perform color normalization. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ mean (list): mean values for normalization.
+ stddev (list): standard deviations for normalization.
+
+ Returns:
+ out_images (tensor): the noramlized images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ if len(images.shape) == 3:
+ assert (
+ len(mean) == images.shape[0]
+ ), "channel mean not computed properly"
+ assert (
+ len(stddev) == images.shape[0]
+ ), "channel stddev not computed properly"
+ elif len(images.shape) == 4:
+ assert (
+ len(mean) == images.shape[1]
+ ), "channel mean not computed properly"
+ assert (
+ len(stddev) == images.shape[1]
+ ), "channel stddev not computed properly"
+ else:
+ raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
+
+ out_images = torch.zeros_like(images)
+ for idx in range(len(mean)):
+ # C H W
+ if len(images.shape) == 3:
+ out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
+ elif len(images.shape) == 4:
+ out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
+ else:
+ raise NotImplementedError(
+ f"Unsupported dimension {len(images.shape)}"
+ )
+ return out_images
+
+
+def _get_param_spatial_crop(
+ scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False
+):
+ """
+ Given scale, ratio, height and width, return sampled coordinates of the videos.
+ """
+ for _ in range(num_repeat):
+ area = height * width
+ target_area = random.uniform(*scale) * area
+ if log_scale:
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+ else:
+ aspect_ratio = random.uniform(*ratio)
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if np.random.uniform() < 0.5 and switch_hw:
+ w, h = h, w
+
+ if 0 < w <= width and 0 < h <= height:
+ i = random.randint(0, height - h)
+ j = random.randint(0, width - w)
+ return i, j, h, w
+
+ # Fallback to central crop
+ in_ratio = float(width) / float(height)
+ if in_ratio < min(ratio):
+ w = width
+ h = int(round(w / min(ratio)))
+ elif in_ratio > max(ratio):
+ h = height
+ w = int(round(h * max(ratio)))
+ else: # whole image
+ w = width
+ h = height
+ i = (height - h) // 2
+ j = (width - w) // 2
+ return i, j, h, w
+
+
+def random_resized_crop(
+ images,
+ target_height,
+ target_width,
+ scale=(0.8, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+):
+ """
+ Crop the given images to random size and aspect ratio. A crop of random
+ size (default: of 0.08 to 1.0) of the original size and a random aspect
+ ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This
+ crop is finally resized to given size. This is popularly used to train the
+ Inception networks.
+
+ Args:
+ images: Images to perform resizing and cropping.
+ target_height: Desired height after cropping.
+ target_width: Desired width after cropping.
+ scale: Scale range of Inception-style area based random resizing.
+ ratio: Aspect ratio range of Inception-style area based random resizing.
+ """
+
+ height = images.shape[2]
+ width = images.shape[3]
+
+ i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
+ cropped = images[:, :, i : i + h, j : j + w]
+ return torch.nn.functional.interpolate(
+ cropped,
+ size=(target_height, target_width),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+
+def random_resized_crop_with_shift(
+ images,
+ target_height,
+ target_width,
+ scale=(0.8, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+):
+ """
+ This is similar to random_resized_crop. However, it samples two different
+ boxes (for cropping) for the first and last frame. It then linearly
+ interpolates the two boxes for other frames.
+
+ Args:
+ images: Images to perform resizing and cropping.
+ target_height: Desired height after cropping.
+ target_width: Desired width after cropping.
+ scale: Scale range of Inception-style area based random resizing.
+ ratio: Aspect ratio range of Inception-style area based random resizing.
+ """
+ t = images.shape[1]
+ height = images.shape[2]
+ width = images.shape[3]
+
+ i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
+ i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
+ i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
+ j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
+ h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
+ w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
+ out = torch.zeros((3, t, target_height, target_width))
+ for ind in range(t):
+ out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
+ images[
+ :,
+ ind : ind + 1,
+ i_s[ind] : i_s[ind] + h_s[ind],
+ j_s[ind] : j_s[ind] + w_s[ind],
+ ],
+ size=(target_height, target_width),
+ mode="bilinear",
+ align_corners=False,
+ )
+ return out
+
+
+def create_random_augment(
+ input_size,
+ auto_augment=None,
+ interpolation="bilinear",
+):
+ """
+ Get video randaug transform.
+
+ Args:
+ input_size: The size of the input video in tuple.
+ auto_augment: Parameters for randaug. An example:
+ "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
+ of operations to apply).
+ interpolation: Interpolation method.
+ """
+ if isinstance(input_size, tuple):
+ img_size = input_size[-2:]
+ else:
+ img_size = input_size
+
+ if auto_augment:
+ assert isinstance(auto_augment, str)
+ if isinstance(img_size, tuple):
+ img_size_min = min(img_size)
+ else:
+ img_size_min = img_size
+ aa_params = {"translate_const": int(img_size_min * 0.45)}
+ if interpolation and interpolation != "random":
+ aa_params["interpolation"] = _pil_interp(interpolation)
+ if auto_augment.startswith("rand"):
+ return transforms.Compose(
+ [rand_augment_transform(auto_augment, aa_params)]
+ )
+ # else:
+ # return transforms.Compose(
+ # [auto_augment_transform(auto_augment, aa_params)]
+ # )
+ raise NotImplementedError
+
+
+def random_sized_crop_img(
+ im,
+ size,
+ jitter_scale=(0.08, 1.0),
+ jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
+ max_iter=10,
+):
+ """
+ Performs Inception-style cropping (used for training).
+ """
+ assert (
+ len(im.shape) == 3
+ ), "Currently only support image for random_sized_crop"
+ h, w = im.shape[1:3]
+ i, j, h, w = _get_param_spatial_crop(
+ scale=jitter_scale,
+ ratio=jitter_aspect,
+ height=h,
+ width=w,
+ num_repeat=max_iter,
+ log_scale=False,
+ switch_hw=True,
+ )
+ cropped = im[:, i : i + h, j : j + w]
+ return torch.nn.functional.interpolate(
+ cropped.unsqueeze(0),
+ size=(size, size),
+ mode="bilinear",
+ align_corners=False,
+ ).squeeze(0)
+
+
+# The following code are modified based on timm lib, we will replace the following
+# contents with dependency from PyTorchVideo.
+# https://github.com/facebookresearch/pytorchvideo
+class RandomResizedCropAndInterpolation:
+ """Crop the given PIL Image to random size and aspect ratio with random interpolation.
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+ is finally resized to given size.
+ This is popularly used to train the Inception networks.
+ Args:
+ size: expected output size of each edge
+ scale: range of size of the origin size cropped
+ ratio: range of aspect ratio of the origin aspect ratio cropped
+ interpolation: Default: PIL.Image.BILINEAR
+ """
+
+ def __init__(
+ self,
+ size,
+ scale=(0.08, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+ interpolation="bilinear",
+ ):
+ if isinstance(size, tuple):
+ self.size = size
+ else:
+ self.size = (size, size)
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+ print("range should be of kind (min, max)")
+
+ if interpolation == "random":
+ self.interpolation = _RANDOM_INTERPOLATION
+ else:
+ self.interpolation = _pil_interp(interpolation)
+ self.scale = scale
+ self.ratio = ratio
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ """Get parameters for ``crop`` for a random sized crop.
+ Args:
+ img (PIL Image): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ area = img.size[0] * img.size[1]
+
+ for _ in range(10):
+ target_area = random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if w <= img.size[0] and h <= img.size[1]:
+ i = random.randint(0, img.size[1] - h)
+ j = random.randint(0, img.size[0] - w)
+ return i, j, h, w
+
+ # Fallback to central crop
+ in_ratio = img.size[0] / img.size[1]
+ if in_ratio < min(ratio):
+ w = img.size[0]
+ h = int(round(w / min(ratio)))
+ elif in_ratio > max(ratio):
+ h = img.size[1]
+ w = int(round(h * max(ratio)))
+ else: # whole image
+ w = img.size[0]
+ h = img.size[1]
+ i = (img.size[1] - h) // 2
+ j = (img.size[0] - w) // 2
+ return i, j, h, w
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be cropped and resized.
+ Returns:
+ PIL Image: Randomly cropped and resized image.
+ """
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ if isinstance(self.interpolation, (tuple, list)):
+ interpolation = random.choice(self.interpolation)
+ else:
+ interpolation = self.interpolation
+ return F.resized_crop(img, i, j, h, w, self.size, interpolation)
+
+ def __repr__(self):
+ if isinstance(self.interpolation, (tuple, list)):
+ interpolate_str = " ".join(
+ [_pil_interpolation_to_str[x] for x in self.interpolation]
+ )
+ else:
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
+ format_string = self.__class__.__name__ + "(size={0}".format(self.size)
+ format_string += ", scale={0}".format(
+ tuple(round(s, 4) for s in self.scale)
+ )
+ format_string += ", ratio={0}".format(
+ tuple(round(r, 4) for r in self.ratio)
+ )
+ format_string += ", interpolation={0})".format(interpolate_str)
+ return format_string
+
+
+def transforms_imagenet_train(
+ img_size=224,
+ scale=None,
+ ratio=None,
+ hflip=0.5,
+ vflip=0.0,
+ color_jitter=0.4,
+ auto_augment=None,
+ interpolation="random",
+ use_prefetcher=False,
+ mean=(0.485, 0.456, 0.406),
+ std=(0.229, 0.224, 0.225),
+ re_prob=0.0,
+ re_mode="const",
+ re_count=1,
+ re_num_splits=0,
+ separate=False,
+):
+ """
+ If separate==True, the transforms are returned as a tuple of 3 separate transforms
+ for use in a mixing dataset that passes
+ * all data through the first (primary) transform, called the 'clean' data
+ * a portion of the data through the secondary transform
+ * normalizes and converts the branches above with the third, final transform
+ """
+ if isinstance(img_size, tuple):
+ img_size = img_size[-2:]
+ else:
+ img_size = img_size
+
+ scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
+ ratio = tuple(
+ ratio or (3.0 / 4.0, 4.0 / 3.0)
+ ) # default imagenet ratio range
+ primary_tfl = [
+ RandomResizedCropAndInterpolation(
+ img_size, scale=scale, ratio=ratio, interpolation=interpolation
+ )
+ ]
+ if hflip > 0.0:
+ primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
+ if vflip > 0.0:
+ primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
+
+ secondary_tfl = []
+ if auto_augment:
+ assert isinstance(auto_augment, str)
+ if isinstance(img_size, tuple):
+ img_size_min = min(img_size)
+ else:
+ img_size_min = img_size
+ aa_params = dict(
+ translate_const=int(img_size_min * 0.45),
+ img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+ )
+ if interpolation and interpolation != "random":
+ aa_params["interpolation"] = _pil_interp(interpolation)
+ if auto_augment.startswith("rand"):
+ secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
+ elif auto_augment.startswith("augmix"):
+ raise NotImplementedError("Augmix not implemented")
+ else:
+ raise NotImplementedError("Auto aug not implemented")
+ elif color_jitter is not None:
+ # color jitter is enabled when not using AA
+ if isinstance(color_jitter, (list, tuple)):
+ # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
+ # or 4 if also augmenting hue
+ assert len(color_jitter) in (3, 4)
+ else:
+ # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
+ color_jitter = (float(color_jitter),) * 3
+ secondary_tfl += [transforms.ColorJitter(*color_jitter)]
+
+ final_tfl = []
+ final_tfl += [
+ transforms.ToTensor(),
+ transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
+ ]
+ if re_prob > 0.0:
+ final_tfl.append(
+ RandomErasing(
+ re_prob,
+ mode=re_mode,
+ max_count=re_count,
+ num_splits=re_num_splits,
+ device="cpu",
+ cube=False,
+ )
+ )
+
+ if separate:
+ return (
+ transforms.Compose(primary_tfl),
+ transforms.Compose(secondary_tfl),
+ transforms.Compose(final_tfl),
+ )
+ else:
+ return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
+
+############################################################################################################
+############################################################################################################
+
+class Compose(object):
+ """Composes several transforms
+ Args:
+ transforms (list of ``Transform`` objects): list of transforms
+ to compose
+ """
+
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, clip):
+ for t in self.transforms:
+ clip = t(clip)
+ return clip
+
+
+class RandomHorizontalFlip(object):
+ """Horizontally flip the list of given images randomly
+ with a probability 0.5
+ """
+
+ def __call__(self, clip):
+ """
+ Args:
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
+ in format (h, w, c) in numpy.ndarray
+ Returns:
+ PIL.Image or numpy.ndarray: Randomly flipped clip
+ """
+ if random.random() < 0.5:
+ if isinstance(clip[0], np.ndarray):
+ return [np.fliplr(img) for img in clip]
+ elif isinstance(clip[0], PIL.Image.Image):
+ return [
+ img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip
+ ]
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ ' but got list of {0}'.format(type(clip[0])))
+ return clip
+
+
+class RandomResize(object):
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
+ The larger the original image is, the more times it takes to
+ interpolate
+ Args:
+ interpolation (str): Can be one of 'nearest', 'bilinear'
+ defaults to nearest
+ size (tuple): (widht, height)
+ """
+
+ def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
+ self.ratio = ratio
+ self.interpolation = interpolation
+
+ def __call__(self, clip):
+ scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
+
+ if isinstance(clip[0], np.ndarray):
+ im_h, im_w, im_c = clip[0].shape
+ elif isinstance(clip[0], PIL.Image.Image):
+ im_w, im_h = clip[0].size
+
+ new_w = int(im_w * scaling_factor)
+ new_h = int(im_h * scaling_factor)
+ new_size = (new_w, new_h)
+ resized = FF.resize_clip(
+ clip, new_size, interpolation=self.interpolation)
+ return resized
+
+
+class Resize(object):
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
+ The larger the original image is, the more times it takes to
+ interpolate
+ Args:
+ interpolation (str): Can be one of 'nearest', 'bilinear'
+ defaults to nearest
+ size (tuple): (widht, height)
+ """
+
+ def __init__(self, size, interpolation='nearest'):
+ self.size = size
+ self.interpolation = interpolation
+
+ def __call__(self, clip):
+ resized = FF.resize_clip(
+ clip, self.size, interpolation=self.interpolation)
+ return resized
+
+
+class RandomCrop(object):
+ """Extract random crop at the same location for a list of images
+ Args:
+ size (sequence or int): Desired output size for the
+ crop in format (h, w)
+ """
+
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ size = (size, size)
+
+ self.size = size
+
+ def __call__(self, clip):
+ """
+ Args:
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
+ in format (h, w, c) in numpy.ndarray
+ Returns:
+ PIL.Image or numpy.ndarray: Cropped list of images
+ """
+ h, w = self.size
+ if isinstance(clip[0], np.ndarray):
+ im_h, im_w, im_c = clip[0].shape
+ elif isinstance(clip[0], PIL.Image.Image):
+ im_w, im_h = clip[0].size
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+ if w > im_w or h > im_h:
+ error_msg = (
+ 'Initial image size should be larger then '
+ 'cropped size but got cropped sizes : ({w}, {h}) while '
+ 'initial image is ({im_w}, {im_h})'.format(
+ im_w=im_w, im_h=im_h, w=w, h=h))
+ raise ValueError(error_msg)
+
+ x1 = random.randint(0, im_w - w)
+ y1 = random.randint(0, im_h - h)
+ cropped = FF.crop_clip(clip, y1, x1, h, w)
+
+ return cropped
+
+
+class ThreeCrop(object):
+ """Extract random crop at the same location for a list of images
+ Args:
+ size (sequence or int): Desired output size for the
+ crop in format (h, w)
+ """
+
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ size = (size, size)
+
+ self.size = size
+
+ def __call__(self, clip):
+ """
+ Args:
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
+ in format (h, w, c) in numpy.ndarray
+ Returns:
+ PIL.Image or numpy.ndarray: Cropped list of images
+ """
+ h, w = self.size
+ if isinstance(clip[0], np.ndarray):
+ im_h, im_w, im_c = clip[0].shape
+ elif isinstance(clip[0], PIL.Image.Image):
+ im_w, im_h = clip[0].size
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+ if w != im_w and h != im_h:
+ clip = FF.resize_clip(clip, self.size, interpolation="bilinear")
+ im_h, im_w, im_c = clip[0].shape
+
+ step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0)
+ cropped = []
+ for i in range(3):
+ if (im_h > self.size[0]):
+ x1 = 0
+ y1 = i * step
+ cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
+ else:
+ x1 = i * step
+ y1 = 0
+ cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
+ return cropped
+
+
+class RandomRotation(object):
+ """Rotate entire clip randomly by a random angle within
+ given bounds
+ Args:
+ degrees (sequence or int): Range of degrees to select from
+ If degrees is a number instead of sequence like (min, max),
+ the range of degrees, will be (-degrees, +degrees).
+ """
+
+ def __init__(self, degrees):
+ if isinstance(degrees, numbers.Number):
+ if degrees < 0:
+ raise ValueError('If degrees is a single number,'
+ 'must be positive')
+ degrees = (-degrees, degrees)
+ else:
+ if len(degrees) != 2:
+ raise ValueError('If degrees is a sequence,'
+ 'it must be of len 2.')
+
+ self.degrees = degrees
+
+ def __call__(self, clip):
+ """
+ Args:
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
+ in format (h, w, c) in numpy.ndarray
+ Returns:
+ PIL.Image or numpy.ndarray: Cropped list of images
+ """
+ import skimage
+ angle = random.uniform(self.degrees[0], self.degrees[1])
+ if isinstance(clip[0], np.ndarray):
+ rotated = [skimage.transform.rotate(img, angle) for img in clip]
+ elif isinstance(clip[0], PIL.Image.Image):
+ rotated = [img.rotate(angle) for img in clip]
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+
+ return rotated
+
+
+class CenterCrop(object):
+ """Extract center crop at the same location for a list of images
+ Args:
+ size (sequence or int): Desired output size for the
+ crop in format (h, w)
+ """
+
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ size = (size, size)
+
+ self.size = size
+
+ def __call__(self, clip):
+ """
+ Args:
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
+ in format (h, w, c) in numpy.ndarray
+ Returns:
+ PIL.Image or numpy.ndarray: Cropped list of images
+ """
+ h, w = self.size
+ if isinstance(clip[0], np.ndarray):
+ im_h, im_w, im_c = clip[0].shape
+ elif isinstance(clip[0], PIL.Image.Image):
+ im_w, im_h = clip[0].size
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+ if w > im_w or h > im_h:
+ error_msg = (
+ 'Initial image size should be larger then '
+ 'cropped size but got cropped sizes : ({w}, {h}) while '
+ 'initial image is ({im_w}, {im_h})'.format(
+ im_w=im_w, im_h=im_h, w=w, h=h))
+ raise ValueError(error_msg)
+
+ x1 = int(round((im_w - w) / 2.))
+ y1 = int(round((im_h - h) / 2.))
+ cropped = FF.crop_clip(clip, y1, x1, h, w)
+
+ return cropped
+
+
+class ColorJitter(object):
+ """Randomly change the brightness, contrast and saturation and hue of the clip
+ Args:
+ brightness (float): How much to jitter brightness. brightness_factor
+ is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+ contrast (float): How much to jitter contrast. contrast_factor
+ is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+ saturation (float): How much to jitter saturation. saturation_factor
+ is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+ hue(float): How much to jitter hue. hue_factor is chosen uniformly from
+ [-hue, hue]. Should be >=0 and <= 0.5.
+ """
+
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+ self.brightness = brightness
+ self.contrast = contrast
+ self.saturation = saturation
+ self.hue = hue
+
+ def get_params(self, brightness, contrast, saturation, hue):
+ if brightness > 0:
+ brightness_factor = random.uniform(
+ max(0, 1 - brightness), 1 + brightness)
+ else:
+ brightness_factor = None
+
+ if contrast > 0:
+ contrast_factor = random.uniform(
+ max(0, 1 - contrast), 1 + contrast)
+ else:
+ contrast_factor = None
+
+ if saturation > 0:
+ saturation_factor = random.uniform(
+ max(0, 1 - saturation), 1 + saturation)
+ else:
+ saturation_factor = None
+
+ if hue > 0:
+ hue_factor = random.uniform(-hue, hue)
+ else:
+ hue_factor = None
+ return brightness_factor, contrast_factor, saturation_factor, hue_factor
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (list): list of PIL.Image
+ Returns:
+ list PIL.Image : list of transformed PIL.Image
+ """
+ if isinstance(clip[0], np.ndarray):
+ raise TypeError(
+ 'Color jitter not yet implemented for numpy arrays')
+ elif isinstance(clip[0], PIL.Image.Image):
+ brightness, contrast, saturation, hue = self.get_params(
+ self.brightness, self.contrast, self.saturation, self.hue)
+
+ # Create img transform function sequence
+ img_transforms = []
+ if brightness is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+ if saturation is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+ if hue is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+ if contrast is not None:
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+ random.shuffle(img_transforms)
+
+ # Apply to all images
+ jittered_clip = []
+ for img in clip:
+ for func in img_transforms:
+ jittered_img = func(img)
+ jittered_clip.append(jittered_img)
+
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+ return jittered_clip
+
+
+class Normalize(object):
+ """Normalize a clip with mean and standard deviation.
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
+ .. note::
+ This transform acts out of place, i.e., it does not mutates the input tensor.
+ Args:
+ mean (sequence): Sequence of means for each channel.
+ std (sequence): Sequence of standard deviations for each channel.
+ """
+
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized.
+ Returns:
+ Tensor: Normalized Tensor clip.
+ """
+ return FF.normalize(clip, self.mean, self.std)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/datasets/volume_transforms.py b/third_party/InternVideo/InternVideo2/single_modality/datasets/volume_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d33dadc9464fee731ae46cd14f20a04bc99a79b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/datasets/volume_transforms.py
@@ -0,0 +1,131 @@
+import numpy as np
+from PIL import Image
+import torch
+
+
+def convert_img(img):
+ """Converts (H, W, C) numpy.ndarray to (C, W, H) format
+ """
+ if len(img.shape) == 3:
+ img = img.transpose(2, 0, 1)
+ if len(img.shape) == 2:
+ img = np.expand_dims(img, 0)
+ return img
+
+
+class ClipToTensor(object):
+ """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
+ to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
+ """
+
+ def __init__(self, channel_nb=3, div_255=True, numpy=False):
+ self.channel_nb = channel_nb
+ self.div_255 = div_255
+ self.numpy = numpy
+
+ def __call__(self, clip):
+ """
+ Args: clip (list of numpy.ndarray): clip (list of images)
+ to be converted to tensor.
+ """
+ # Retrieve shape
+ if isinstance(clip[0], np.ndarray):
+ h, w, ch = clip[0].shape
+ assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
+ ch)
+ elif isinstance(clip[0], Image.Image):
+ w, h = clip[0].size
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image\
+ but got list of {0}'.format(type(clip[0])))
+
+ np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
+
+ # Convert
+ for img_idx, img in enumerate(clip):
+ if isinstance(img, np.ndarray):
+ pass
+ elif isinstance(img, Image.Image):
+ img = np.array(img, copy=False)
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image\
+ but got list of {0}'.format(type(clip[0])))
+ img = convert_img(img)
+ np_clip[:, img_idx, :, :] = img
+ if self.numpy:
+ if self.div_255:
+ np_clip = np_clip / 255.0
+ return np_clip
+
+ else:
+ tensor_clip = torch.from_numpy(np_clip)
+
+ if not isinstance(tensor_clip, torch.FloatTensor):
+ tensor_clip = tensor_clip.float()
+ if self.div_255:
+ tensor_clip = torch.div(tensor_clip, 255)
+ return tensor_clip
+
+
+# Note this norms data to -1/1
+class ClipToTensor_K(object):
+ """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
+ to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
+ """
+
+ def __init__(self, channel_nb=3, div_255=True, numpy=False):
+ self.channel_nb = channel_nb
+ self.div_255 = div_255
+ self.numpy = numpy
+
+ def __call__(self, clip):
+ """
+ Args: clip (list of numpy.ndarray): clip (list of images)
+ to be converted to tensor.
+ """
+ # Retrieve shape
+ if isinstance(clip[0], np.ndarray):
+ h, w, ch = clip[0].shape
+ assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
+ ch)
+ elif isinstance(clip[0], Image.Image):
+ w, h = clip[0].size
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image\
+ but got list of {0}'.format(type(clip[0])))
+
+ np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
+
+ # Convert
+ for img_idx, img in enumerate(clip):
+ if isinstance(img, np.ndarray):
+ pass
+ elif isinstance(img, Image.Image):
+ img = np.array(img, copy=False)
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image\
+ but got list of {0}'.format(type(clip[0])))
+ img = convert_img(img)
+ np_clip[:, img_idx, :, :] = img
+ if self.numpy:
+ if self.div_255:
+ np_clip = (np_clip - 127.5) / 127.5
+ return np_clip
+
+ else:
+ tensor_clip = torch.from_numpy(np_clip)
+
+ if not isinstance(tensor_clip, torch.FloatTensor):
+ tensor_clip = tensor_clip.float()
+ if self.div_255:
+ tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
+ return tensor_clip
+
+
+class ToTensor(object):
+ """Converts numpy array to tensor
+ """
+
+ def __call__(self, array):
+ tensor = torch.from_numpy(array)
+ return tensor
diff --git a/third_party/InternVideo/InternVideo2/single_modality/engines/__init__.py b/third_party/InternVideo/InternVideo2/single_modality/engines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/InternVideo/InternVideo2/single_modality/engines/engine_for_finetuning.py b/third_party/InternVideo/InternVideo2/single_modality/engines/engine_for_finetuning.py
new file mode 100755
index 0000000000000000000000000000000000000000..27cf243dc9192c26fa96fb41c56a888a1edece60
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/engines/engine_for_finetuning.py
@@ -0,0 +1,295 @@
+import os
+import time
+import numpy as np
+import math
+import sys
+from typing import Iterable, Optional
+import torch
+from datasets.mixup import Mixup
+from timm.utils import accuracy, ModelEma
+import utils
+from scipy.special import softmax
+
+
+def train_class_batch(model, samples, target, criterion):
+ outputs = model(samples)
+ loss = criterion(outputs, target)
+ return loss, outputs
+
+
+def get_loss_scale_for_deepspeed(model):
+ optimizer = model.optimizer
+ return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
+
+
+def train_one_epoch(
+ model: torch.nn.Module, criterion: torch.nn.Module,
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
+ model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
+ start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
+ num_training_steps_per_epoch=None, update_freq=None,
+ bf16=False,
+ ):
+ model.train(True)
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ print_freq = 1
+
+ if loss_scaler is None:
+ model.zero_grad()
+ model.micro_steps = 0
+ else:
+ optimizer.zero_grad()
+
+ for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ step = data_iter_step // update_freq
+ if step >= num_training_steps_per_epoch:
+ continue
+ it = start_steps + step # global training iteration
+ # Update LR & WD for the first acc
+ if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0:
+ for i, param_group in enumerate(optimizer.param_groups):
+ if lr_schedule_values is not None:
+ param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
+ if wd_schedule_values is not None and param_group["weight_decay"] > 0:
+ param_group["weight_decay"] = wd_schedule_values[it]
+
+ samples = samples.to(device, non_blocking=True)
+ targets = targets.to(device, non_blocking=True)
+
+ if mixup_fn is not None:
+ samples, targets = mixup_fn(samples, targets)
+
+ if loss_scaler is None:
+ samples = samples.bfloat16() if bf16 else samples.half()
+ loss, output = train_class_batch(
+ model, samples, targets, criterion)
+ else:
+ with torch.cuda.amp.autocast():
+ loss, output = train_class_batch(
+ model, samples, targets, criterion)
+
+ loss_value = loss.item()
+
+ if not math.isfinite(loss_value):
+ print("Loss is {}, stopping training".format(loss_value))
+ sys.exit(1)
+
+ if loss_scaler is None:
+ loss /= update_freq
+ model.backward(loss)
+ model.step()
+
+ if (data_iter_step + 1) % update_freq == 0:
+ # model.zero_grad()
+ # Deepspeed will call step() & model.zero_grad() automatic
+ if model_ema is not None:
+ model_ema.update(model)
+ grad_norm = None
+ loss_scale_value = get_loss_scale_for_deepspeed(model)
+ else:
+ # this attribute is added by timm on one optimizer (adahessian)
+ is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+ loss /= update_freq
+ grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
+ parameters=model.parameters(), create_graph=is_second_order,
+ update_grad=(data_iter_step + 1) % update_freq == 0)
+ if (data_iter_step + 1) % update_freq == 0:
+ optimizer.zero_grad()
+ if model_ema is not None:
+ model_ema.update(model)
+ loss_scale_value = loss_scaler.state_dict()["scale"]
+
+ torch.cuda.synchronize()
+
+ if mixup_fn is None:
+ class_acc = (output.max(-1)[-1] == targets).float().mean()
+ else:
+ class_acc = None
+ metric_logger.update(loss=loss_value)
+ metric_logger.update(class_acc=class_acc)
+ metric_logger.update(loss_scale=loss_scale_value)
+ min_lr = 10.
+ max_lr = 0.
+ for group in optimizer.param_groups:
+ min_lr = min(min_lr, group["lr"])
+ max_lr = max(max_lr, group["lr"])
+
+ metric_logger.update(lr=max_lr)
+ metric_logger.update(min_lr=min_lr)
+ weight_decay_value = None
+ for group in optimizer.param_groups:
+ if group["weight_decay"] > 0:
+ weight_decay_value = group["weight_decay"]
+ metric_logger.update(weight_decay=weight_decay_value)
+ metric_logger.update(grad_norm=grad_norm)
+
+ if log_writer is not None:
+ log_writer.update(loss=loss_value, head="loss")
+ log_writer.update(class_acc=class_acc, head="loss")
+ log_writer.update(loss_scale=loss_scale_value, head="opt")
+ log_writer.update(lr=max_lr, head="opt")
+ log_writer.update(min_lr=min_lr, head="opt")
+ log_writer.update(weight_decay=weight_decay_value, head="opt")
+ log_writer.update(grad_norm=grad_norm, head="opt")
+
+ log_writer.set_step()
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def validation_one_epoch(data_loader, model, device, ds=False, bf16=False):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Val:'
+
+ # switch to evaluation mode
+ model.eval()
+
+ for batch in metric_logger.log_every(data_loader, 10, header):
+ videos = batch[0]
+ target = batch[1]
+ videos = videos.to(device, non_blocking=True)
+ target = target.to(device, non_blocking=True)
+
+ # compute output
+ if ds:
+ videos = videos.bfloat16() if bf16 else videos.half()
+ output = model(videos)
+ loss = criterion(output, target)
+ else:
+ with torch.cuda.amp.autocast():
+ output = model(videos)
+ loss = criterion(output, target)
+
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+ batch_size = videos.shape[0]
+ metric_logger.update(loss=loss.item())
+ metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+ metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+ .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def final_test(data_loader, model, device, file, ds=False, bf16=False):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Test:'
+
+ # switch to evaluation mode
+ model.eval()
+ final_result = []
+
+ for batch in metric_logger.log_every(data_loader, 10, header):
+ videos = batch[0]
+ target = batch[1]
+ ids = batch[2]
+ chunk_nb = batch[3]
+ split_nb = batch[4]
+ videos = videos.to(device, non_blocking=True)
+ target = target.to(device, non_blocking=True)
+
+ # compute output
+ if ds:
+ videos = videos.bfloat16() if bf16 else videos.half()
+ output = model(videos)
+ loss = criterion(output, target)
+ else:
+ with torch.cuda.amp.autocast():
+ output = model(videos)
+ loss = criterion(output, target)
+
+ for i in range(output.size(0)):
+ string = "{} {} {} {} {}\n".format(ids[i], \
+ str(output.data[i].float().cpu().numpy().tolist()), \
+ str(int(target[i].cpu().numpy())), \
+ str(int(chunk_nb[i].cpu().numpy())), \
+ str(int(split_nb[i].cpu().numpy())))
+ final_result.append(string)
+
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+ batch_size = videos.shape[0]
+ metric_logger.update(loss=loss.item())
+ metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+ metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+
+ if not os.path.exists(file):
+ os.mknod(file)
+ with open(file, 'w') as f:
+ f.write("{}, {}\n".format(acc1, acc5))
+ for line in final_result:
+ f.write(line)
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+ .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def merge(eval_path, num_tasks):
+ dict_feats = {}
+ dict_label = {}
+ dict_pos = {}
+ print("Reading individual output files")
+
+ for x in range(num_tasks):
+ file = os.path.join(eval_path, str(x) + '.txt')
+ lines = open(file, 'r').readlines()[1:]
+ for line in lines:
+ line = line.strip()
+ name = line.split(' ')[0]
+ label = line.split(']')[-1].split(' ')[1]
+ chunk_nb = line.split(']')[-1].split(' ')[2]
+ split_nb = line.split(']')[-1].split(' ')[3]
+ data = np.fromstring(' '.join(line.split(' ')[1:]).split('[')[1].split(']')[0], dtype=np.float32, sep=',')
+ data = softmax(data)
+ if not name in dict_feats:
+ dict_feats[name] = []
+ dict_label[name] = 0
+ dict_pos[name] = []
+ if chunk_nb + split_nb in dict_pos[name]:
+ continue
+ dict_feats[name].append(data)
+ dict_pos[name].append(chunk_nb + split_nb)
+ dict_label[name] = label
+ print("Computing final results")
+
+ input_lst = []
+ print(len(dict_feats))
+ for i, item in enumerate(dict_feats):
+ input_lst.append([i, item, dict_feats[item], dict_label[item]])
+ from multiprocessing import Pool
+ p = Pool(64)
+ ans = p.map(compute_video, input_lst)
+ top1 = [x[1] for x in ans]
+ top5 = [x[2] for x in ans]
+ pred = [x[0] for x in ans]
+ label = [x[3] for x in ans]
+ final_top1 ,final_top5 = np.mean(top1), np.mean(top5)
+ return final_top1*100 ,final_top5*100
+
+def compute_video(lst):
+ i, video_id, data, label = lst
+ feat = [x for x in data]
+ feat = np.mean(feat, axis=0)
+ pred = np.argmax(feat)
+ top1 = (int(pred) == int(label)) * 1.0
+ top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
+ return [pred, top1, top5, int(label)]
diff --git a/third_party/InternVideo/InternVideo2/single_modality/engines/engine_for_pretraining.py b/third_party/InternVideo/InternVideo2/single_modality/engines/engine_for_pretraining.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f68c35c7fc2cf84b44c1fb6c497ab48f8db8de
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/engines/engine_for_pretraining.py
@@ -0,0 +1,216 @@
+import math
+import time
+import sys
+from typing import Iterable
+import torch
+import torch.distributed as dist
+import utils
+
+
+def get_loss_scale_for_deepspeed(model):
+ optimizer = model.optimizer
+ loss_scale = None
+ if hasattr(optimizer, 'loss_scale'):
+ loss_scale = optimizer.loss_scale
+ elif hasattr(optimizer, 'cur_scale'):
+ loss_scale = optimizer.cur_scale
+ return loss_scale, optimizer._global_grad_norm
+
+
+def train_one_epoch(
+ model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
+ log_writer=None, lr_scheduler=None, start_steps=None,
+ lr_schedule_values=None, wd_schedule_values=None,
+ clip_teacher_model=None, clip_input_resolution=224,
+ distill_final_features=True,
+ clip_loss_ratio=[1, 1],
+ mae_teacher_model=None, mae_input_resolution=224,
+ td_ratio=2,
+ mae_loss_ratio=1,
+ mask_type='tube', mask_ratio=0.,
+ bf16=False,
+ ):
+ model.train()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ print_freq = 1
+ print(f"Temporal downsampling ratio: {td_ratio}")
+
+ if loss_scaler is None:
+ model.zero_grad()
+ model.micro_steps = 0
+ else:
+ optimizer.zero_grad()
+
+ if bf16:
+ datatype = torch.bfloat16
+ else:
+ datatype = torch.float16
+
+ for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ # assign learning rate & weight decay for each step
+ it = start_steps + step # global training iteration
+ if lr_schedule_values is not None or wd_schedule_values is not None:
+ for i, param_group in enumerate(optimizer.param_groups):
+ if lr_schedule_values is not None:
+ param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
+ if wd_schedule_values is not None and param_group["weight_decay"] > 0:
+ param_group["weight_decay"] = wd_schedule_values[it]
+
+ videos, bool_masked_pos = batch
+ videos = videos.to(device, non_blocking=True)
+
+ if mask_type in ['attention']:
+ bool_masked_pos = None
+ else:
+ bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
+
+ with torch.no_grad():
+ B, C, T, H, W = videos.shape
+ if H != mae_input_resolution:
+ mae_videos = torch.nn.functional.interpolate(
+ videos.view(B, C*T, H, W),
+ size=(clip_input_resolution, clip_input_resolution),
+ mode='bicubic', align_corners=False
+ )
+ mae_videos = mae_videos.view(B, C, T, clip_input_resolution, clip_input_resolution)
+ else:
+ mae_videos = videos
+
+ # VideoMAE use tublet_size=2, while CLIP & UMT use tublet_size=1
+ videos = videos[:, :, ::td_ratio]
+ T = T // td_ratio
+
+ if H != clip_input_resolution:
+ clip_videos = torch.nn.functional.interpolate(
+ videos.view(B, C*T, H, W),
+ size=(clip_input_resolution, clip_input_resolution),
+ mode='bicubic', align_corners=False
+ )
+ clip_videos = clip_videos.view(B, C, T, clip_input_resolution, clip_input_resolution)
+ else:
+ clip_videos = videos
+
+ with torch.cuda.amp.autocast(dtype=datatype):
+ if bool_masked_pos is None:
+ norm_clip_middle, norm_clip_final, attn = clip_teacher_model(clip_videos)
+ else:
+ norm_clip_middle, norm_clip_final = clip_teacher_model(clip_videos)
+
+ norm_mae = mae_teacher_model(mae_videos)
+
+ # generate attention mask
+ BT, N = attn.shape
+ N_vis = N - int(N * mask_ratio)
+ if mask_type == 'attention':
+ importance = torch.multinomial(attn, N)
+ bool_masked_pos = torch.ones((BT, N))
+ pos1 = torch.arange(BT).view(-1, 1).repeat(1, N_vis)
+ pos2 = importance[:, :N_vis]
+ bool_masked_pos[pos1, pos2] = 0
+ bool_masked_pos = bool_masked_pos.view(B, -1)
+ bool_masked_pos = torch.cat((torch.zeros(B, 1), bool_masked_pos), dim=1)
+ bool_masked_pos = bool_masked_pos.to(torch.bool)
+
+ K, _, _, C_CLIP = norm_clip_middle.shape
+ clip_bool_masked_pos = bool_masked_pos.unsqueeze(0).repeat(K, 1, 1)
+ targets_clip_middle_vis = norm_clip_middle[~clip_bool_masked_pos].reshape(K, B, -1, C_CLIP)
+ targets_clip_final_vis = norm_clip_final
+
+ K, _, _, C_MAE = norm_mae.shape
+ mae_bool_masked_pos = bool_masked_pos[:, 1:].unsqueeze(0).repeat(K, 1, 1)
+ targets_mae_vis = norm_mae[~mae_bool_masked_pos].reshape(K, B, -1, C_MAE)
+
+ if loss_scaler is None:
+ videos = videos.bfloat16() if bf16 else videos.half()
+ outputs_clip_middle, outputs_clip_final, output_mae = model(videos, bool_masked_pos)
+ # align CLIP and MAE followed MILAN, note that the features are processing by l2_norm
+ loss_clip_middle = (2 - 2 * (outputs_clip_middle * targets_clip_middle_vis).sum(dim=-1)).mean()
+ if distill_final_features and clip_loss_ratio[1] > 0:
+ loss_clip_final = (2 - 2 * (outputs_clip_final * targets_clip_final_vis).sum(dim=-1)).mean()
+ else:
+ loss_clip_final = torch.zeros(1).type_as(loss_clip_middle).to(loss_clip_middle.device)
+ loss_mae = (2 - 2 * (output_mae * targets_mae_vis).sum(dim=-1)).mean()
+ else:
+ with torch.cuda.amp.autocast(dtype=datatype):
+ outputs_clip_middle, outputs_clip_final, output_mae = model(videos, bool_masked_pos)
+ # align CLIP followed MILAN, note that the features are processing by l2_norm
+ loss_clip_middle = (2 - 2 * (outputs_clip_middle * targets_clip_middle_vis).sum(dim=-1)).mean()
+ if distill_final_features and clip_loss_ratio[1] > 0:
+ loss_clip_final = (2 - 2 * (outputs_clip_final * targets_clip_final_vis).sum(dim=-1)).mean()
+ else:
+ loss_clip_final = torch.zeros(1).type_as(loss_clip_middle).to(loss_clip_middle.device)
+ loss_mae = (2 - 2 * (output_mae * targets_mae_vis).sum(dim=-1)).mean()
+
+ loss = loss_clip_middle * clip_loss_ratio[0] + loss_clip_final * clip_loss_ratio[1] + loss_mae * mae_loss_ratio
+ loss_value = loss.item()
+
+ loss_list = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
+ dist.all_gather(loss_list, loss)
+ loss_list = torch.tensor(loss_list)
+ all_loss_mean_value = loss_list.mean().item()
+ metric_logger.update(all_loss_mean=all_loss_mean_value)
+
+ loss_list_isnan = torch.isnan(loss_list).any()
+ loss_list_isinf = torch.isinf(loss_list).any()
+ if loss_list_isnan or loss_list_isinf:
+ print(" ========== loss_isnan = {}, loss_isinf = {} ========== ".format(loss_list_isnan, loss_list_isinf))
+ sys.exit(1)
+
+ if loss_scaler is None:
+ model.backward(loss)
+ model.step()
+ loss_scale_value, grad_norm = get_loss_scale_for_deepspeed(model)
+ else:
+ optimizer.zero_grad()
+ # this attribute is added by timm on one optimizer (adahessian)
+ is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+ grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
+ parameters=model.parameters(), create_graph=is_second_order)
+ loss_scale_value = loss_scaler.state_dict()["scale"]
+
+ torch.cuda.synchronize()
+
+ metric_logger.update(loss=loss_value)
+ metric_logger.update(loss_clip_middle=loss_clip_middle.item())
+ metric_logger.update(loss_clip_final=loss_clip_final.item())
+ metric_logger.update(loss_mae=loss_mae.item())
+ metric_logger.update(loss_scale=loss_scale_value)
+ min_lr = 10.
+ max_lr = 0.
+ for group in optimizer.param_groups:
+ min_lr = min(min_lr, group["lr"])
+ max_lr = max(max_lr, group["lr"])
+
+ metric_logger.update(lr=max_lr)
+ metric_logger.update(min_lr=min_lr)
+ weight_decay_value = None
+ for group in optimizer.param_groups:
+ if group["weight_decay"] > 0:
+ weight_decay_value = group["weight_decay"]
+ metric_logger.update(weight_decay=weight_decay_value)
+ metric_logger.update(grad_norm=grad_norm)
+
+ if log_writer is not None:
+ log_writer.update(all_rank_loss_mean=all_loss_mean_value, head="loss")
+ log_writer.update(loss=loss_value, head="loss")
+ log_writer.update(loss_clip_middle=loss_clip_middle, head="loss")
+ log_writer.update(loss_clip_final=loss_clip_final, head="loss")
+ log_writer.update(loss_mae=loss_mae, head="loss")
+ log_writer.update(loss_scale=loss_scale_value, head="opt")
+ log_writer.update(lr=max_lr, head="opt")
+ log_writer.update(min_lr=min_lr, head="opt")
+ log_writer.update(weight_decay=weight_decay_value, head="opt")
+ log_writer.update(grad_norm=grad_norm, head="opt")
+ log_writer.set_step()
+
+ if lr_scheduler is not None:
+ lr_scheduler.step_update(start_steps + step)
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ timestep = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+ print(f"[{timestep}] Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
diff --git a/third_party/InternVideo/InternVideo2/single_modality/functional.py b/third_party/InternVideo/InternVideo2/single_modality/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e12e288299a54eefe8553ab666d2a45fea29194
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/functional.py
@@ -0,0 +1,89 @@
+import numbers
+import cv2
+import numpy as np
+import PIL
+import torch
+
+
+def _is_tensor_clip(clip):
+ return torch.is_tensor(clip) and clip.ndimension() == 4
+
+
+def crop_clip(clip, min_h, min_w, h, w):
+ if isinstance(clip[0], np.ndarray):
+ cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
+
+ elif isinstance(clip[0], PIL.Image.Image):
+ cropped = [
+ img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
+ ]
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+ return cropped
+
+
+def resize_clip(clip, size, interpolation='bilinear'):
+ if isinstance(clip[0], np.ndarray):
+ if isinstance(size, numbers.Number):
+ im_h, im_w, im_c = clip[0].shape
+ # Min spatial dim already matches minimal size
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
+ and im_h == size):
+ return clip
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
+ size = (new_w, new_h)
+ else:
+ size = size[0], size[1]
+ if interpolation == 'bilinear':
+ np_inter = cv2.INTER_LINEAR
+ else:
+ np_inter = cv2.INTER_NEAREST
+ scaled = [
+ cv2.resize(img, size, interpolation=np_inter) for img in clip
+ ]
+ elif isinstance(clip[0], PIL.Image.Image):
+ if isinstance(size, numbers.Number):
+ im_w, im_h = clip[0].size
+ # Min spatial dim already matches minimal size
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
+ and im_h == size):
+ return clip
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
+ size = (new_w, new_h)
+ else:
+ size = size[1], size[0]
+ if interpolation == 'bilinear':
+ pil_inter = PIL.Image.BILINEAR
+ else:
+ pil_inter = PIL.Image.NEAREST
+ scaled = [img.resize(size, pil_inter) for img in clip]
+ else:
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
+ 'but got list of {0}'.format(type(clip[0])))
+ return scaled
+
+
+def get_resize_sizes(im_h, im_w, size):
+ if im_w < im_h:
+ ow = size
+ oh = int(size * im_h / im_w)
+ else:
+ oh = size
+ ow = int(size * im_w / im_h)
+ return oh, ow
+
+
+def normalize(clip, mean, std, inplace=False):
+ if not _is_tensor_clip(clip):
+ raise TypeError('tensor is not a torch clip.')
+
+ if not inplace:
+ clip = clip.clone()
+
+ dtype = clip.dtype
+ mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
+ std = torch.as_tensor(std, dtype=dtype, device=clip.device)
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
+
+ return clip
diff --git a/third_party/InternVideo/InternVideo2/single_modality/models/__init__.py b/third_party/InternVideo/InternVideo2/single_modality/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f4d7b5c4635d47c558dbf08606a85395a9f307c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/models/__init__.py
@@ -0,0 +1,6 @@
+from .internvl_clip_vision import internvl_clip_6b
+from .videomae import mae_g14_hybrid
+from .internvideo2 import internvideo2_1B_patch14_224, internvideo2_6B_patch14_224
+from .internvideo2_cat import internvideo2_cat_1B_patch14_224, internvideo2_cat_6B_patch14_224
+from .internvideo2_ap import internvideo2_ap_1B_patch14_224, internvideo2_ap_6B_patch14_224
+from .internvideo2_pretrain import pretrain_internvideo2_1B_patch14_224, pretrain_internvideo2_6B_patch14_224
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/models/flash_attention_class.py b/third_party/InternVideo/InternVideo2/single_modality/models/flash_attention_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..04edd18ee4efcd0fd9f50ea38087a4417792c3fa
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/models/flash_attention_class.py
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+
+from einops import rearrange
+
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+class FlashAttention(nn.Module):
+ """Implement the scaled dot product attention with softmax.
+ Arguments
+ ---------
+ softmax_scale: The temperature to use for the softmax attention.
+ (default: 1/sqrt(d_keys) where d_keys is computed at
+ runtime)
+ attention_dropout: The dropout rate to apply to the attention
+ (default: 0.0)
+ """
+
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+ super().__init__()
+ self.softmax_scale = softmax_scale
+ self.dropout_p = attention_dropout
+
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+ max_s=None, need_weights=False):
+ """Implements the multihead softmax attention.
+ Arguments
+ ---------
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+ if unpadded: (nnz, 3, h, d)
+ key_padding_mask: a bool tensor of shape (B, S)
+ """
+ assert not need_weights
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
+ assert qkv.is_cuda
+
+ if cu_seqlens is None:
+ batch_size = qkv.shape[0]
+ seqlen = qkv.shape[1]
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ max_s = seqlen
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+ device=qkv.device)
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+ output_unpad = flash_attn_varlen_qkvpacked_func(
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
+ indices, batch_size, seqlen),
+ 'b s (h d) -> b s h d', h=nheads)
+ else:
+ assert max_s is not None
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+
+ return output, None
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2.py b/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c69e6f7ec410262ad7fcdf519eee6029013e6413
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2.py
@@ -0,0 +1,589 @@
+import math
+import torch
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+from torch import nn
+
+import torch.utils.checkpoint as checkpoint
+from functools import partial
+from einops import rearrange
+
+from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed
+from .flash_attention_class import FlashAttention
+from flash_attn.modules.mlp import FusedMLP
+from flash_attn.ops.rms_norm import DropoutAddRMSNorm
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None, out_dim=None):
+ super().__init__()
+ if out_dim is None:
+ out_dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ assert all_head_dim == dim
+
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.k_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, out_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, k=None, v=None):
+ B, N, C = x.shape
+ N_k = k.shape[1]
+ N_v = v.shape[1]
+
+ q_bias, k_bias, v_bias = None, None, None
+ if self.q_bias is not None:
+ q_bias = self.q_bias
+ k_bias = self.k_bias
+ v_bias = self.v_bias
+
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
+
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class AttentiveBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
+ super().__init__()
+
+ self.norm1_q = norm_layer(dim)
+ self.norm1_k = norm_layer(dim)
+ self.norm1_v = norm_layer(dim)
+ self.cross_attn = CrossAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
+
+ if drop_path > 0.:
+ print(f"Use DropPath in projector: {drop_path}")
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
+ x_q = self.norm1_q(x_q + pos_q)
+ x_k = self.norm1_k(x_kv + pos_k)
+ x_v = self.norm1_v(x_kv)
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
+
+ return x
+
+
+class AttentionPoolingBlock(AttentiveBlock):
+
+ def forward(self, x):
+ x_q = x.mean(1, keepdim=True)
+ x_kv, pos_q, pos_k = x, 0, 0
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
+ x = x.squeeze(1)
+ return x
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+ self.force_fp32 = force_fp32
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, x):
+ if self.force_fp32:
+ output_type = x.dtype
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
+ return out.to(dtype=output_type)
+ else:
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.use_flash_attn = use_flash_attn
+ if use_flash_attn:
+ self.causal = causal
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
+
+ self.qk_normalization = qk_normalization
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ if self.use_fused_rmsnorm:
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
+ else:
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+ )
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, x):
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
+ use_fused_rmsnorm=False):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
+ qk_normalization=qk_normalization,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ self.ls1 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if use_fused_mlp:
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.with_cp = with_cp
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def forward(self, x, residual=None):
+
+ def _inner_forward(x, residual=None):
+ if self.use_fused_rmsnorm:
+ x, residual = self.norm1(x, residual)
+ x = self.drop_path1(self.ls1(self.attn(x)))
+ x, residual = self.norm2(x, residual)
+ x = self.drop_path2(self.ls2(self.mlp(x)))
+ return x, residual
+ else:
+ assert residual is None
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+ if self.with_cp:
+ return checkpoint.checkpoint(_inner_forward, x, residual)
+ else:
+ return _inner_forward(x, residual=residual)
+
+
+class PatchEmbed(nn.Module):
+ """ 3D Image to Patch Embedding
+ """
+
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
+ num_frames=8, tubelet_size=1, norm_layer=None
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.tubelet_size = tubelet_size
+ self.grid_size = (
+ num_frames // tubelet_size,
+ img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1]
+ ) # (T, H, W)
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
+
+ self.proj = nn.Conv3d(
+ in_channels=in_chans, out_channels=embed_dim,
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
+ stride=(tubelet_size, patch_size[0], patch_size[1])
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C
+ x = self.norm(x)
+ return x
+
+
+class InternVideo2(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ patch_size: int = 14,
+ img_size: int = 224,
+ qkv_bias: bool = False,
+ drop_path_rate: float = 0.25,
+ embed_dim: int = 1408,
+ head_drop_path_rate: float = 0.,
+ num_heads: int = 16,
+ mlp_ratio: float = 4.3637,
+ init_values: float = 1e-5,
+ qk_normalization: bool = True,
+ depth: int = 40,
+ use_flash_attn: bool = True,
+ use_fused_rmsnorm: bool = True,
+ use_fused_mlp: bool = True,
+ fused_mlp_heuristic: int = 1,
+ attn_pool_num_heads: int = 16,
+ clip_embed_dim: int = 768,
+ layerscale_no_force_fp32: bool = False,
+ num_frames: int = 8,
+ tubelet_size: int = 1,
+ sep_pos_embed: bool = False,
+ use_checkpoint: bool = False,
+ checkpoint_num: int = 0,
+ fc_drop_rate: float = 0.,
+ num_classes: int = 1000,
+ init_scale: float = 0.001,
+ ):
+ super().__init__()
+
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, print(
+ 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent')
+ print(mlp_ratio)
+
+ self.use_flash_attn = use_flash_attn
+ self.embed_dim = embed_dim
+
+ if use_fused_rmsnorm:
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
+ else:
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
+ self.norm_layer_for_blocks = norm_layer_for_blocks
+ self.patch_embed = PatchEmbed(
+ img_size, patch_size, in_chans, embed_dim,
+ num_frames=num_frames, tubelet_size=tubelet_size,
+ )
+ num_patches = self.patch_embed.num_patches
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
+ self.sep_pos_embed = sep_pos_embed
+ if sep_pos_embed:
+ print("Use seperable position embedding")
+ grid_size = self.patch_embed.grid_size
+ self.grid_size = grid_size
+ self.pos_embed_spatial = nn.Parameter(torch.zeros(1, grid_size[1] * grid_size[2], embed_dim))
+ self.pos_embed_temporal = nn.Parameter(torch.zeros(1, grid_size[0], embed_dim))
+ self.pos_embed_cls = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ else:
+ print("Use joint position embedding")
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ # choose which layer to use checkpoint
+ with_cp_list = [False] * depth
+ if use_checkpoint:
+ for idx in range(depth):
+ if idx < checkpoint_num:
+ with_cp_list[idx] = True
+ print(f"Droppath rate: {dpr}")
+ print(f"Checkpoint list: {with_cp_list}")
+
+ self.blocks = nn.ModuleList([
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
+ norm_layer=norm_layer_for_blocks,
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
+ fused_mlp_heuristic=fused_mlp_heuristic,
+ with_cp=with_cp_list[i],
+ qk_normalization=qk_normalization,
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ for i in range(depth)])
+ self.clip_projector = AttentionPoolingBlock(
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
+ drop=0., attn_drop=0., drop_path=head_drop_path_rate,
+ norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim
+ )
+
+ self.fc_norm = nn.LayerNorm(clip_embed_dim)
+ self.fc_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
+ self.head = nn.Linear(clip_embed_dim, num_classes)
+
+ self.init_pos_embed()
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+ self.head.weight.data.mul_(init_scale)
+ self.head.bias.data.mul_(init_scale)
+
+ def init_pos_embed(self):
+ print("Init pos_embed from sincos pos_embed")
+ if self.sep_pos_embed:
+ # trunc_normal_(self.pos_embed_spatial, std=.02)
+ # trunc_normal_(self.pos_embed_temporal, std=.02)
+ # trunc_normal_(self.pos_embed_cls, std=.02)
+ pos_embed_spatial = get_2d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ )
+ self.pos_embed_spatial.data.copy_(torch.from_numpy(pos_embed_spatial).float().unsqueeze(0))
+ pos_embed_temporal = get_1d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[0], # t_size
+ )
+ self.pos_embed_temporal.data.copy_(torch.from_numpy(pos_embed_temporal).float().unsqueeze(0))
+ else:
+ # trunc_normal_(self.pos_embed, std=.02)
+ pos_embed = get_3d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ self.patch_embed.grid_size[0], # t_size
+ cls_token=True
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ @property
+ def dtype(self):
+ return self.patch_embed.proj.weight.dtype
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {
+ 'pos_embed',
+ 'pos_embed_spatial',
+ 'pos_embed_temporal',
+ 'pos_embed_cls',
+ 'cls_token'
+ }
+
+ def forward(self, x):
+ x = self.patch_embed(x.type(self.dtype))
+ B, T, L, C = x.shape # T: temporal; L: spatial
+ x = x.view([B, T * L, C])
+
+ # append cls token
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # add pos_embed
+ if self.sep_pos_embed:
+ pos_embed = self.pos_embed_spatial.repeat(
+ 1, self.grid_size[0], 1
+ ) + torch.repeat_interleave(
+ self.pos_embed_temporal,
+ self.grid_size[1] * self.grid_size[2],
+ dim=1,
+ )
+ pos_embed = torch.cat(
+ [
+ self.pos_embed_cls.expand(pos_embed.shape[0], -1, -1),
+ pos_embed,
+ ],
+ 1,
+ )
+ else:
+ pos_embed = self.pos_embed
+ x = x + pos_embed
+
+ residual = None
+ for blk in self.blocks:
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ x = blk(x, residual=residual)
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ if residual is not None:
+ x = x + residual
+
+ x = self.clip_projector(x)
+
+ x = self.fc_norm(x)
+ x = self.head(self.fc_dropout(x))
+ return x
+
+
+@register_model
+def internvideo2_1B_patch14_224(pretrained=False, **kwargs):
+ model = InternVideo2(
+ img_size=224, patch_size=14, embed_dim=1408,
+ depth=40, num_heads=16, mlp_ratio=48/11,
+ attn_pool_num_heads=16, clip_embed_dim=768,
+ **kwargs
+ )
+ return model
+
+
+@register_model
+def internvideo2_6B_patch14_224(pretrained=False, **kwargs):
+ model = InternVideo2(
+ img_size=224, patch_size=14, embed_dim=3200,
+ depth=48, num_heads=25, mlp_ratio=4,
+ attn_pool_num_heads=16, clip_embed_dim=768,
+ **kwargs
+ )
+ return model
+
+
+if __name__ == '__main__':
+ import time
+ from fvcore.nn import FlopCountAnalysis
+ from fvcore.nn import flop_count_table
+ import numpy as np
+
+ seed = 4217
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ num_frames = 8
+ img_size = 224
+
+ # model = internvideo2_1B_patch14_224(num_classes=400).cuda().half()
+ model = internvideo2_6B_patch14_224(num_classes=400).cuda().half()
+ print(model)
+
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, img_size, img_size).cuda().half())
+ s = time.time()
+ print(flop_count_table(flops, max_depth=1))
+ print(time.time()-s)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2_ap.py b/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2_ap.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa5c00e4177be93d6b705eb4b449a3e8c0b47b05
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2_ap.py
@@ -0,0 +1,603 @@
+import math
+import torch
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+from torch import nn
+
+import torch.utils.checkpoint as checkpoint
+from functools import partial
+from einops import rearrange
+
+from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed
+from .flash_attention_class import FlashAttention
+from flash_attn.modules.mlp import FusedMLP
+from flash_attn.ops.rms_norm import DropoutAddRMSNorm
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None, out_dim=None):
+ super().__init__()
+ if out_dim is None:
+ out_dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ assert all_head_dim == dim
+
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.k_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, out_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, k=None, v=None):
+ B, N, C = x.shape
+ N_k = k.shape[1]
+ N_v = v.shape[1]
+
+ q_bias, k_bias, v_bias = None, None, None
+ if self.q_bias is not None:
+ q_bias = self.q_bias
+ k_bias = self.k_bias
+ v_bias = self.v_bias
+
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
+
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class AttentiveBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
+ super().__init__()
+
+ self.norm1_q = norm_layer(dim)
+ self.norm1_k = norm_layer(dim)
+ self.norm1_v = norm_layer(dim)
+ self.cross_attn = CrossAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
+
+ if drop_path > 0.:
+ print(f"Use DropPath in projector: {drop_path}")
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
+ x_q = self.norm1_q(x_q + pos_q)
+ x_k = self.norm1_k(x_kv + pos_k)
+ x_v = self.norm1_v(x_kv)
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
+
+ return x
+
+
+class AttentionPoolingBlock(AttentiveBlock):
+
+ def forward(self, x):
+ x_q = x.mean(1, keepdim=True)
+ x_kv, pos_q, pos_k = x, 0, 0
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
+ x = x.squeeze(1)
+ return x
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+ self.force_fp32 = force_fp32
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, x):
+ if self.force_fp32:
+ output_type = x.dtype
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
+ return out.to(dtype=output_type)
+ else:
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.use_flash_attn = use_flash_attn
+ if use_flash_attn:
+ self.causal = causal
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
+
+ self.qk_normalization = qk_normalization
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ if self.use_fused_rmsnorm:
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
+ else:
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+ )
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, x):
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
+ use_fused_rmsnorm=False):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
+ qk_normalization=qk_normalization,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ self.ls1 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if use_fused_mlp:
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.with_cp = with_cp
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def forward(self, x, residual=None):
+
+ def _inner_forward(x, residual=None):
+ if self.use_fused_rmsnorm:
+ x, residual = self.norm1(x, residual)
+ x = self.drop_path1(self.ls1(self.attn(x)))
+ x, residual = self.norm2(x, residual)
+ x = self.drop_path2(self.ls2(self.mlp(x)))
+ return x, residual
+ else:
+ assert residual is None
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+ if self.with_cp:
+ return checkpoint.checkpoint(_inner_forward, x, residual)
+ else:
+ return _inner_forward(x, residual=residual)
+
+
+class PatchEmbed(nn.Module):
+ """ 3D Image to Patch Embedding
+ """
+
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
+ num_frames=8, tubelet_size=1, norm_layer=None
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.tubelet_size = tubelet_size
+ self.grid_size = (
+ num_frames // tubelet_size,
+ img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1]
+ ) # (T, H, W)
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
+
+ self.proj = nn.Conv3d(
+ in_channels=in_chans, out_channels=embed_dim,
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
+ stride=(tubelet_size, patch_size[0], patch_size[1])
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C
+ x = self.norm(x)
+ return x
+
+
+class InternVideo2(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ patch_size: int = 14,
+ img_size: int = 224,
+ qkv_bias: bool = False,
+ drop_path_rate: float = 0.25,
+ embed_dim: int = 1408,
+ head_drop_path_rate: float = 0.,
+ num_heads: int = 16,
+ mlp_ratio: float = 4.3637,
+ init_values: float = 1e-5,
+ qk_normalization: bool = True,
+ depth: int = 40,
+ use_flash_attn: bool = True,
+ use_fused_rmsnorm: bool = True,
+ use_fused_mlp: bool = True,
+ fused_mlp_heuristic: int = 1,
+ attn_pool_num_heads: int = 16,
+ clip_embed_dim: int = 768,
+ layerscale_no_force_fp32: bool = False,
+ num_frames: int = 8,
+ tubelet_size: int = 1,
+ sep_pos_embed: bool = False,
+ use_checkpoint: bool = False,
+ checkpoint_num: int = 0,
+ fc_drop_rate: float = 0.,
+ num_classes: int = 1000,
+ init_scale: float = 0.001,
+ ):
+ super().__init__()
+
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, print(
+ 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent')
+ print(mlp_ratio)
+
+ self.use_flash_attn = use_flash_attn
+ self.embed_dim = embed_dim
+
+ if use_fused_rmsnorm:
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
+ else:
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
+ self.norm_layer_for_blocks = norm_layer_for_blocks
+ self.patch_embed = PatchEmbed(
+ img_size, patch_size, in_chans, embed_dim,
+ num_frames=num_frames, tubelet_size=tubelet_size,
+ )
+ num_patches = self.patch_embed.num_patches
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
+ self.sep_pos_embed = sep_pos_embed
+ if sep_pos_embed:
+ print("Use seperable position embedding")
+ grid_size = self.patch_embed.grid_size
+ self.grid_size = grid_size
+ self.pos_embed_spatial = nn.Parameter(torch.zeros(1, grid_size[1] * grid_size[2], embed_dim))
+ self.pos_embed_temporal = nn.Parameter(torch.zeros(1, grid_size[0], embed_dim))
+ self.pos_embed_cls = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ else:
+ print("Use joint position embedding")
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ # choose which layer to use checkpoint
+ with_cp_list = [False] * depth
+ if use_checkpoint:
+ for idx in range(depth):
+ if idx < checkpoint_num:
+ with_cp_list[idx] = True
+ print(f"Droppath rate: {dpr}")
+ print(f"Checkpoint list: {with_cp_list}")
+
+ self.blocks = nn.ModuleList([
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
+ norm_layer=norm_layer_for_blocks,
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
+ fused_mlp_heuristic=fused_mlp_heuristic,
+ with_cp=with_cp_list[i],
+ qk_normalization=qk_normalization,
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ for i in range(depth)])
+ self.clip_projector = AttentionPoolingBlock(
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
+ drop=0., attn_drop=0., drop_path=head_drop_path_rate,
+ norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim
+ )
+
+ self.extra_mlp = nn.Sequential(
+ nn.LayerNorm(clip_embed_dim),
+ Mlp(
+ in_features=clip_embed_dim,
+ hidden_features=int(clip_embed_dim*4.0),
+ act_layer=nn.GELU, drop=0.
+ ),
+ )
+
+ self.fc_norm = nn.LayerNorm(clip_embed_dim)
+ self.fc_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
+ self.head = nn.Linear(clip_embed_dim, num_classes)
+
+ self.init_pos_embed()
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+ self.head.weight.data.mul_(init_scale)
+ self.head.bias.data.mul_(init_scale)
+
+ print("Zero initialization for extra MLP")
+ nn.init.constant_(self.extra_mlp[1].fc2.weight, 0.)
+ nn.init.constant_(self.extra_mlp[1].fc2.bias, 0.)
+
+ def init_pos_embed(self):
+ print("Init pos_embed from sincos pos_embed")
+ if self.sep_pos_embed:
+ # trunc_normal_(self.pos_embed_spatial, std=.02)
+ # trunc_normal_(self.pos_embed_temporal, std=.02)
+ # trunc_normal_(self.pos_embed_cls, std=.02)
+ pos_embed_spatial = get_2d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ )
+ self.pos_embed_spatial.data.copy_(torch.from_numpy(pos_embed_spatial).float().unsqueeze(0))
+ pos_embed_temporal = get_1d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[0], # t_size
+ )
+ self.pos_embed_temporal.data.copy_(torch.from_numpy(pos_embed_temporal).float().unsqueeze(0))
+ else:
+ # trunc_normal_(self.pos_embed, std=.02)
+ pos_embed = get_3d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ self.patch_embed.grid_size[0], # t_size
+ cls_token=True
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ @property
+ def dtype(self):
+ return self.patch_embed.proj.weight.dtype
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {
+ 'pos_embed',
+ 'pos_embed_spatial',
+ 'pos_embed_temporal',
+ 'pos_embed_cls',
+ 'cls_token'
+ }
+
+ def forward(self, x):
+ x = self.patch_embed(x.type(self.dtype))
+ B, T, L, C = x.shape # T: temporal; L: spatial
+ x = x.view([B, T * L, C])
+
+ # append cls token
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # add pos_embed
+ if self.sep_pos_embed:
+ pos_embed = self.pos_embed_spatial.repeat(
+ 1, self.grid_size[0], 1
+ ) + torch.repeat_interleave(
+ self.pos_embed_temporal,
+ self.grid_size[1] * self.grid_size[2],
+ dim=1,
+ )
+ pos_embed = torch.cat(
+ [
+ self.pos_embed_cls.expand(pos_embed.shape[0], -1, -1),
+ pos_embed,
+ ],
+ 1,
+ )
+ else:
+ pos_embed = self.pos_embed
+ x = x + pos_embed
+
+ residual = None
+ for blk in self.blocks:
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ x = blk(x, residual=residual)
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ if residual is not None:
+ x = x + residual
+
+ x = self.clip_projector(x)
+ # add extra mlp as in VideoGLUE
+ x = x + self.extra_mlp(x)
+
+ x = self.fc_norm(x)
+ x = self.head(self.fc_dropout(x))
+ return x
+
+
+@register_model
+def internvideo2_ap_1B_patch14_224(pretrained=False, **kwargs):
+ model = InternVideo2(
+ img_size=224, patch_size=14, embed_dim=1408,
+ depth=40, num_heads=16, mlp_ratio=48/11,
+ attn_pool_num_heads=16, clip_embed_dim=768,
+ **kwargs
+ )
+ return model
+
+
+@register_model
+def internvideo2_ap_6B_patch14_224(pretrained=False, **kwargs):
+ model = InternVideo2(
+ img_size=224, patch_size=14, embed_dim=3200,
+ depth=48, num_heads=25, mlp_ratio=4,
+ attn_pool_num_heads=16, clip_embed_dim=768,
+ **kwargs
+ )
+ return model
+
+
+if __name__ == '__main__':
+ import time
+ from fvcore.nn import FlopCountAnalysis
+ from fvcore.nn import flop_count_table
+ import numpy as np
+
+ seed = 4217
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ num_frames = 8
+ img_size = 224
+
+ model = internvideo2_ap_6B_patch14_224(num_classes=400).cuda().half()
+ print(model)
+
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, img_size, img_size).cuda().half())
+ s = time.time()
+ print(flop_count_table(flops, max_depth=1))
+ print(time.time()-s)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2_cat.py b/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2_cat.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aef00287ecb005184e87d8a372b4b95c4ba4b43
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2_cat.py
@@ -0,0 +1,655 @@
+import math
+import torch
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+from torch import nn
+
+import torch.utils.checkpoint as checkpoint
+from functools import partial
+from einops import rearrange
+
+from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed
+from .flash_attention_class import FlashAttention
+from flash_attn.modules.mlp import FusedMLP
+from flash_attn.ops.rms_norm import DropoutAddRMSNorm
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None, out_dim=None):
+ super().__init__()
+ if out_dim is None:
+ out_dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ assert all_head_dim == dim
+
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.k_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, out_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, k=None, v=None):
+ B, N, C = x.shape
+ N_k = k.shape[1]
+ N_v = v.shape[1]
+
+ q_bias, k_bias, v_bias = None, None, None
+ if self.q_bias is not None:
+ q_bias = self.q_bias
+ k_bias = self.k_bias
+ v_bias = self.v_bias
+
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
+
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class AttentiveBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
+ super().__init__()
+
+ self.norm1_q = norm_layer(dim)
+ self.norm1_k = norm_layer(dim)
+ self.norm1_v = norm_layer(dim)
+ self.cross_attn = CrossAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
+
+ if drop_path > 0.:
+ print(f"Use DropPath in projector: {drop_path}")
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
+ x_q = self.norm1_q(x_q + pos_q)
+ x_k = self.norm1_k(x_kv + pos_k)
+ x_v = self.norm1_v(x_kv)
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
+
+ return x
+
+
+class AttentionPoolingBlock(AttentiveBlock):
+
+ def forward(self, x):
+ x_q = x.mean(1, keepdim=True)
+ x_kv, pos_q, pos_k = x, 0, 0
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
+ x = x.squeeze(1)
+ return x
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+ self.force_fp32 = force_fp32
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, x):
+ if self.force_fp32:
+ output_type = x.dtype
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
+ return out.to(dtype=output_type)
+ else:
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.use_flash_attn = use_flash_attn
+ if use_flash_attn:
+ self.causal = causal
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
+
+ self.qk_normalization = qk_normalization
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ if self.use_fused_rmsnorm:
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
+ else:
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+ )
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, x):
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
+ use_fused_rmsnorm=False):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
+ qk_normalization=qk_normalization,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ self.ls1 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if use_fused_mlp:
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.with_cp = with_cp
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def forward(self, x, residual=None):
+
+ def _inner_forward(x, residual=None):
+ if self.use_fused_rmsnorm:
+ x, residual = self.norm1(x, residual)
+ x = self.drop_path1(self.ls1(self.attn(x)))
+ x, residual = self.norm2(x, residual)
+ x = self.drop_path2(self.ls2(self.mlp(x)))
+ return x, residual
+ else:
+ assert residual is None
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+ if self.with_cp:
+ return checkpoint.checkpoint(_inner_forward, x, residual)
+ else:
+ return _inner_forward(x, residual=residual)
+
+
+class PatchEmbed(nn.Module):
+ """ 3D Image to Patch Embedding
+ """
+
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
+ num_frames=8, tubelet_size=1, norm_layer=None
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.tubelet_size = tubelet_size
+ self.grid_size = (
+ num_frames // tubelet_size,
+ img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1]
+ ) # (T, H, W)
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
+
+ self.proj = nn.Conv3d(
+ in_channels=in_chans, out_channels=embed_dim,
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
+ stride=(tubelet_size, patch_size[0], patch_size[1])
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C
+ x = self.norm(x)
+ return x
+
+
+class InternVideo2(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ patch_size: int = 14,
+ img_size: int = 224,
+ qkv_bias: bool = False,
+ drop_path_rate: float = 0.25,
+ embed_dim: int = 1408,
+ head_drop_path_rate: float = 0.,
+ num_heads: int = 16,
+ mlp_ratio: float = 4.3637,
+ init_values: float = 1e-5,
+ qk_normalization: bool = True,
+ depth: int = 40,
+ use_flash_attn: bool = True,
+ use_fused_rmsnorm: bool = True,
+ use_fused_mlp: bool = True,
+ fused_mlp_heuristic: int = 1,
+ attn_pool_num_heads: int = 16,
+ clip_embed_dim: int = 768,
+ layerscale_no_force_fp32: bool = False, # when True for training?
+ num_frames: int = 8,
+ tubelet_size: int = 1,
+ sep_pos_embed: bool = False,
+ use_checkpoint: bool = False,
+ checkpoint_num: int = 0,
+ fc_drop_rate: float = 0.,
+ num_classes: int = 1000,
+ init_scale: float = 0.001,
+ merge_method: str = "proj", # proj, cls_avg1, cls_avgN, cls_avg1_proj, cls_avgN_proj
+ merge_norm: str = 'kaiming_BN',
+ ):
+ super().__init__()
+
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, print(
+ 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent')
+ print(mlp_ratio)
+
+ self.merge_method = merge_method
+ self.merge_norm = merge_norm
+ print(f"Merge method: {merge_method}")
+ print(f"Merge Norm: {merge_norm}")
+
+ self.use_flash_attn = use_flash_attn
+ self.embed_dim = embed_dim
+
+ if use_fused_rmsnorm:
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
+ else:
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
+ self.norm_layer_for_blocks = norm_layer_for_blocks
+ self.patch_embed = PatchEmbed(
+ img_size, patch_size, in_chans, embed_dim,
+ num_frames=num_frames, tubelet_size=tubelet_size,
+ )
+ num_patches = self.patch_embed.num_patches
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
+ self.sep_pos_embed = sep_pos_embed
+ if sep_pos_embed:
+ print("Use seperable position embedding")
+ grid_size = self.patch_embed.grid_size
+ self.grid_size = grid_size
+ self.pos_embed_spatial = nn.Parameter(torch.zeros(1, grid_size[1] * grid_size[2], embed_dim))
+ self.pos_embed_temporal = nn.Parameter(torch.zeros(1, grid_size[0], embed_dim))
+ self.pos_embed_cls = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ else:
+ print("Use joint position embedding")
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ # choose which layer to use checkpoint
+ with_cp_list = [False] * depth
+ if use_checkpoint:
+ for idx in range(depth):
+ if idx < checkpoint_num:
+ with_cp_list[idx] = True
+ print(f"Droppath rate: {dpr}")
+ print(f"Checkpoint list: {with_cp_list}")
+
+ self.blocks = nn.ModuleList([
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
+ norm_layer=norm_layer_for_blocks,
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
+ fused_mlp_heuristic=fused_mlp_heuristic,
+ with_cp=with_cp_list[i],
+ qk_normalization=qk_normalization,
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ for i in range(depth)])
+ self.clip_projector = AttentionPoolingBlock(
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
+ drop=0., attn_drop=0., drop_path=head_drop_path_rate,
+ norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim
+ )
+
+ self.fc_norm = nn.LayerNorm(clip_embed_dim)
+ self.fc_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
+
+ if self.merge_method == 'proj':
+ self.head = nn.Linear(clip_embed_dim, num_classes)
+ else:
+ norm_dim = embed_dim if 'avg1' in merge_method else clip_embed_dim
+ if merge_norm == 'kaiming_BN':
+ self.down_norm = nn.BatchNorm1d(norm_dim, affine=False, eps=1e-6)
+ elif merge_norm == 'LN':
+ self.down_norm = nn.LayerNorm(norm_dim)
+ elif merge_norm == 'BN':
+ self.down_norm = nn.BatchNorm1d(norm_dim)
+ else:
+ print(f"Wrong Norm: {merge_norm}")
+ raise Exception
+ # add downsample for avgN
+ if self.merge_method == 'cls_avg1':
+ self.down = nn.Identity()
+ self.head = nn.Linear(embed_dim * 2, num_classes)
+ elif self.merge_method == 'cls_avgN':
+ self.down = nn.Sequential(
+ nn.Linear(embed_dim, clip_embed_dim),
+ nn.GELU()
+ )
+ self.head = nn.Linear(clip_embed_dim * (num_frames // tubelet_size + 1), num_classes)
+ elif self.merge_method == 'cls_avg1_proj':
+ self.down = nn.Identity()
+ self.head = nn.Linear(embed_dim * 2 + clip_embed_dim, num_classes)
+ elif self.merge_method == 'cls_avgN_proj':
+ self.down = nn.Sequential(
+ nn.Linear(embed_dim, clip_embed_dim),
+ nn.GELU(),
+ )
+ self.head = nn.Linear(clip_embed_dim * (num_frames // tubelet_size + 2), num_classes)
+ else:
+ print(f"Wrong method: {self.merge_method}")
+ raise Exception
+
+ self.init_pos_embed()
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+ self.head.weight.data.mul_(init_scale)
+ self.head.bias.data.mul_(init_scale)
+
+ def init_pos_embed(self):
+ print("Init pos_embed from sincos pos_embed")
+ if self.sep_pos_embed:
+ # trunc_normal_(self.pos_embed_spatial, std=.02)
+ # trunc_normal_(self.pos_embed_temporal, std=.02)
+ # trunc_normal_(self.pos_embed_cls, std=.02)
+ pos_embed_spatial = get_2d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ )
+ self.pos_embed_spatial.data.copy_(torch.from_numpy(pos_embed_spatial).float().unsqueeze(0))
+ pos_embed_temporal = get_1d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[0], # t_size
+ )
+ self.pos_embed_temporal.data.copy_(torch.from_numpy(pos_embed_temporal).float().unsqueeze(0))
+ else:
+ # trunc_normal_(self.pos_embed, std=.02)
+ pos_embed = get_3d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ self.patch_embed.grid_size[0], # t_size
+ cls_token=True
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ @property
+ def dtype(self):
+ return self.patch_embed.proj.weight.dtype
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {
+ 'pos_embed',
+ 'pos_embed_spatial',
+ 'pos_embed_temporal',
+ 'pos_embed_cls',
+ 'cls_token'
+ }
+
+ def forward(self, x):
+ x = self.patch_embed(x.type(self.dtype))
+ B, T, L, C = x.shape # T: temporal; L: spatial
+ x = x.view([B, T * L, C])
+
+ # append cls token
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # add pos_embed
+ if self.sep_pos_embed:
+ pos_embed = self.pos_embed_spatial.repeat(
+ 1, self.grid_size[0], 1
+ ) + torch.repeat_interleave(
+ self.pos_embed_temporal,
+ self.grid_size[1] * self.grid_size[2],
+ dim=1,
+ )
+ pos_embed = torch.cat(
+ [
+ self.pos_embed_cls.expand(pos_embed.shape[0], -1, -1),
+ pos_embed,
+ ],
+ 1,
+ )
+ else:
+ pos_embed = self.pos_embed
+ x = x + pos_embed
+
+ residual = None
+ for blk in self.blocks:
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ x = blk(x, residual=residual)
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ if residual is not None:
+ x = x + residual
+
+ if self.merge_method != 'proj':
+ # extra cls and avg
+ cls, avg = x[:, :1, :], x[:, 1:, :]
+ if 'avg1' in self.merge_method:
+ avg = avg.mean(1, keepdim=True) # (B, 1, C)
+ elif 'avgN' in self.merge_method:
+ avg = avg.view(B, T, L, C).mean(2) # (B, T, C)
+ final = self.down(torch.cat([cls, avg], dim=1)) # B, 1+T, C
+ if 'BN' in self.merge_norm:
+ final = self.down_norm(final.permute(0, 2, 1)).reshape(B, -1)
+ else:
+ final = self.down_norm(final).reshape(B, -1)
+
+ x = self.clip_projector(x)
+ x = self.fc_norm(x)
+
+ if self.merge_method == 'proj':
+ x = self.head(self.fc_dropout(x))
+ elif self.merge_method in ['cls_avg1', 'cls_avgN']:
+ x = self.head(self.fc_dropout(final))
+ elif self.merge_method in ['cls_avg1_proj', 'cls_avgN_proj']:
+ x = self.head(self.fc_dropout(torch.cat([final, x], dim=1)))
+ return x
+
+
+@register_model
+def internvideo2_cat_1B_patch14_224(pretrained=False, **kwargs):
+ model = InternVideo2(
+ img_size=224, patch_size=14, embed_dim=1408,
+ depth=40, num_heads=16, mlp_ratio=48/11,
+ attn_pool_num_heads=16, clip_embed_dim=768,
+ **kwargs
+ )
+ return model
+
+
+@register_model
+def internvideo2_cat_6B_patch14_224(pretrained=False, **kwargs):
+ model = InternVideo2(
+ img_size=224, patch_size=14, embed_dim=3200,
+ depth=48, num_heads=25, mlp_ratio=4,
+ attn_pool_num_heads=16, clip_embed_dim=768,
+ **kwargs
+ )
+ return model
+
+
+if __name__ == '__main__':
+ import time
+ from fvcore.nn import FlopCountAnalysis
+ from fvcore.nn import flop_count_table
+ import numpy as np
+
+ seed = 4217
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ num_frames = 8
+ img_size = 224
+
+ # model = internvideo2_cat_1B_patch14_224(num_classes=400).cuda().half()
+ model = internvideo2_cat_6B_patch14_224(
+ num_classes=400,
+ # merge_method='cls_avgN_proj',
+ merge_method='cls_avg1',
+ merge_norm='LN',
+ # merge_norm='kaiming_BN',
+ ).cuda().half()
+ print(model)
+
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, img_size, img_size).cuda().half())
+ s = time.time()
+ print(flop_count_table(flops, max_depth=1))
+ print(time.time()-s)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2_pretrain.py b/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..5de1e5bbd8cb506025b9435580fcc7dee79d8b53
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/models/internvideo2_pretrain.py
@@ -0,0 +1,800 @@
+import math
+import torch
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+from torch import nn
+
+import torch.utils.checkpoint as checkpoint
+from functools import partial
+from einops import rearrange
+
+from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed
+from .flash_attention_class import FlashAttention
+from flash_attn.modules.mlp import FusedMLP
+from flash_attn.ops.rms_norm import DropoutAddRMSNorm
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None, out_dim=None):
+ super().__init__()
+ if out_dim is None:
+ out_dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ assert all_head_dim == dim
+
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.k_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, out_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, k=None, v=None):
+ B, N, C = x.shape
+ N_k = k.shape[1]
+ N_v = v.shape[1]
+
+ q_bias, k_bias, v_bias = None, None, None
+ if self.q_bias is not None:
+ q_bias = self.q_bias
+ k_bias = self.k_bias
+ v_bias = self.v_bias
+
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
+
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class AttentiveBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
+ super().__init__()
+
+ self.norm1_q = norm_layer(dim)
+ self.norm1_k = norm_layer(dim)
+ self.norm1_v = norm_layer(dim)
+ self.cross_attn = CrossAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
+ x_q = self.norm1_q(x_q + pos_q)
+ x_k = self.norm1_k(x_kv + pos_k)
+ x_v = self.norm1_v(x_kv)
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
+
+ return x
+
+
+class AttentionPoolingBlock(AttentiveBlock):
+
+ def forward(self, x):
+ x_q = x.mean(1, keepdim=True)
+ x_kv, pos_q, pos_k = x, 0, 0
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
+ x = x.squeeze(1)
+ return x
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+ self.force_fp32 = force_fp32
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, x):
+ if self.force_fp32:
+ output_type = x.dtype
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
+ return out.to(dtype=output_type)
+ else:
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.use_flash_attn = use_flash_attn
+ if use_flash_attn:
+ self.causal = causal
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
+
+ self.qk_normalization = qk_normalization
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ if self.use_fused_rmsnorm:
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
+ else:
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+ )
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, x):
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
+ use_fused_rmsnorm=False):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
+ qk_normalization=qk_normalization,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ self.ls1 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if use_fused_mlp:
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.with_cp = with_cp
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def forward(self, x, residual=None):
+
+ def _inner_forward(x, residual=None):
+ if self.use_fused_rmsnorm:
+ x, residual = self.norm1(x, residual)
+ x = self.drop_path1(self.ls1(self.attn(x)))
+ x, residual = self.norm2(x, residual)
+ x = self.drop_path2(self.ls2(self.mlp(x)))
+ return x, residual
+ else:
+ assert residual is None
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+ if self.with_cp:
+ return checkpoint.checkpoint(_inner_forward, x, residual)
+ else:
+ return _inner_forward(x, residual=residual)
+
+
+class PatchEmbed(nn.Module):
+ """ 3D Image to Patch Embedding
+ """
+
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
+ num_frames=8, tubelet_size=1, norm_layer=None
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (
+ num_frames // tubelet_size,
+ img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1]
+ ) # (T, H, W)
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
+
+ self.proj = nn.Conv3d(
+ in_channels=in_chans, out_channels=embed_dim,
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
+ stride=(tubelet_size, patch_size[0], patch_size[1])
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C
+ x = self.norm(x)
+ return x
+
+
+class Linear_Decoder(nn.Module):
+ def __init__(self, in_channels=1408, out_channels=3200,
+ norm_layer=nn.LayerNorm, norm_type='l2'):
+ super().__init__()
+ self.norm_type = norm_type
+ print(f'Normalization Type: {norm_type}')
+
+ self.head = nn.Linear(in_channels, out_channels)
+ self.norm = norm_layer(out_channels)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(self, x):
+ x = self.norm(self.head(x))
+
+ if self.norm_type == 'l2':
+ x = x / x.norm(dim=-1, keepdim=True)
+ elif self.norm_type == 'none':
+ pass
+ else:
+ raise NotImplementedError
+
+ return x
+
+
+class MLP_Decoder(nn.Module):
+ def __init__(self, in_channels=768, out_channels=768,
+ norm_layer=nn.LayerNorm, norm_type='l2'):
+ super().__init__()
+ self.norm_type = norm_type
+ print(f'Normalization Type: {norm_type}')
+
+ self.head = nn.Sequential(
+ nn.Linear(in_channels, in_channels),
+ nn.GELU(),
+ nn.Linear(in_channels, out_channels)
+ )
+ self.norm = norm_layer(out_channels)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(self, x):
+ x = self.norm(self.head(x))
+
+ if self.norm_type == 'l2':
+ x = x / x.norm(dim=-1, keepdim=True)
+ elif self.norm_type == 'none':
+ pass
+ else:
+ raise NotImplementedError
+
+ return x
+
+
+class PretrainInternVideo2(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ patch_size: int = 14,
+ img_size: int = 224,
+ qkv_bias: bool = False,
+ drop_path_rate: float = 0.25,
+ embed_dim: int = 1408,
+ num_heads: int = 16,
+ mlp_ratio: float = 4.3637,
+ init_values: float = 1e-5,
+ qk_normalization: bool = True,
+ depth: int = 40,
+ use_flash_attn: bool = True,
+ use_fused_rmsnorm: bool = True,
+ use_fused_mlp: bool = True,
+ fused_mlp_heuristic: int = 1,
+ attn_pool_num_heads: int = 16,
+ clip_embed_dim: int = 768,
+ layerscale_no_force_fp32: bool = False,
+ num_frames: int = 8,
+ tubelet_size: int = 1,
+ sep_pos_embed: bool = False,
+ use_checkpoint: bool = False,
+ checkpoint_num: int = 0,
+ # for clip
+ clip_teacher_embed_dim: int = 3200,
+ clip_teacher_final_dim: int = 768, # if 0, not distill final features
+ clip_norm_type: str = 'l2',
+ clip_return_layer: int = 1,
+ clip_student_return_interval: int = 1,
+ # for mae
+ mae_teacher_embed_dim: int = 1408,
+ mae_norm_type: str = 'l2',
+ mae_return_layer: int = 1,
+ mae_student_return_interval: int = 1,
+ ):
+ super().__init__()
+
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, print(
+ 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent')
+
+ self.use_flash_attn = use_flash_attn
+ self.embed_dim = embed_dim
+
+ self.clip_norm_type = clip_norm_type
+ self.clip_return_index = []
+ for i in range(clip_return_layer):
+ self.clip_return_index.append(depth - int(i * clip_student_return_interval) - 1)
+ print(f'CLIP Normalization Type: {clip_norm_type}')
+ print(f'CLIP Strudent Return Index: {self.clip_return_index}')
+
+ self.mae_norm_type = mae_norm_type
+ self.mae_return_index = []
+ for i in range(mae_return_layer):
+ self.mae_return_index.append(depth - int(i * mae_student_return_interval) - 1)
+ print(f'MAE Normalization Type: {mae_norm_type}')
+ print(f'MAE Strudent Return Index: {self.mae_return_index}')
+
+ if use_fused_rmsnorm:
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
+ else:
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
+ self.norm_layer_for_blocks = norm_layer_for_blocks
+ self.patch_embed = PatchEmbed(
+ img_size, patch_size, in_chans, embed_dim,
+ num_frames=num_frames, tubelet_size=tubelet_size,
+ )
+ num_patches = self.patch_embed.num_patches
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
+ self.sep_pos_embed = sep_pos_embed
+ if sep_pos_embed:
+ print("Use seperable position embedding")
+ grid_size = self.patch_embed.grid_size
+ self.grid_size = grid_size
+ self.pos_embed_spatial = nn.Parameter(torch.zeros(1, grid_size[1] * grid_size[2], embed_dim))
+ self.pos_embed_temporal = nn.Parameter(torch.zeros(1, grid_size[0], embed_dim))
+ self.pos_embed_cls = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ # for CLIP decoder
+ self.clip_pos_embed_spatial = nn.Parameter(torch.zeros(1, grid_size[1] * grid_size[2], embed_dim))
+ self.clip_pos_embed_temporal = nn.Parameter(torch.zeros(1, grid_size[0], embed_dim))
+ self.clip_pos_embed_cls = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ # for MAE decoder
+ self.mae_pos_embed_spatial = nn.Parameter(torch.zeros(1, grid_size[1] * grid_size[2], embed_dim))
+ self.mae_pos_embed_temporal = nn.Parameter(torch.zeros(1, grid_size[0], embed_dim))
+ else:
+ print("Use joint position embedding")
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ # for CLIP decoder
+ self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ # for MAE decoder
+ self.mae_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ # choose which layer to use checkpoint
+ with_cp_list = [False] * depth
+ if use_checkpoint:
+ for idx in range(depth):
+ if idx < checkpoint_num:
+ with_cp_list[idx] = True
+ print(f"Droppath rate: {dpr}")
+ print(f"Checkpoint list: {with_cp_list}")
+
+ self.blocks = nn.ModuleList([
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
+ norm_layer=norm_layer_for_blocks,
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
+ fused_mlp_heuristic=fused_mlp_heuristic,
+ with_cp=with_cp_list[i],
+ qk_normalization=qk_normalization,
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ for i in range(depth)])
+ self.clip_projector = AttentionPoolingBlock(
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
+
+ # CLIP decoder
+ self.clip_decoder = nn.ModuleList([
+ Linear_Decoder(
+ in_channels=embed_dim,
+ out_channels=clip_teacher_embed_dim,
+ norm_layer=partial(nn.LayerNorm, eps=1e-5),
+ norm_type=clip_norm_type
+ ) for _ in range(clip_return_layer)
+ ])
+ self.final_clip_decoder = nn.Identity()
+ if clip_teacher_final_dim > 0:
+ self.final_clip_decoder = Linear_Decoder(
+ in_channels=clip_embed_dim,
+ out_channels=clip_teacher_final_dim,
+ norm_layer=partial(nn.LayerNorm, eps=1e-5),
+ norm_type=clip_norm_type
+ )
+
+ # MAE decoder
+ self.mae_decoder = nn.ModuleList([
+ MLP_Decoder(
+ in_channels=embed_dim,
+ out_channels=mae_teacher_embed_dim,
+ norm_layer=partial(nn.LayerNorm, eps=1e-5),
+ norm_type=mae_norm_type
+ ) for _ in range(mae_return_layer)
+ ])
+
+ self.init_pos_embed()
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ def init_pos_embed(self):
+ print("Init pos_embed from sincos pos_embed")
+ if self.sep_pos_embed:
+ pos_embed_spatial = get_2d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ )
+ self.pos_embed_spatial.data.copy_(torch.from_numpy(pos_embed_spatial).float().unsqueeze(0))
+ self.clip_pos_embed_spatial.data.copy_(torch.from_numpy(pos_embed_spatial).float().unsqueeze(0))
+ self.mae_pos_embed_spatial.data.copy_(torch.from_numpy(pos_embed_spatial).float().unsqueeze(0))
+ pos_embed_temporal = get_1d_sincos_pos_embed(
+ self.pos_embed_spatial.shape[-1],
+ self.patch_embed.grid_size[0], # t_size
+ )
+ self.pos_embed_temporal.data.copy_(torch.from_numpy(pos_embed_temporal).float().unsqueeze(0))
+ self.clip_pos_embed_temporal.data.copy_(torch.from_numpy(pos_embed_temporal).float().unsqueeze(0))
+ self.mae_pos_embed_temporal.data.copy_(torch.from_numpy(pos_embed_temporal).float().unsqueeze(0))
+ else:
+ pos_embed = get_3d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ self.patch_embed.grid_size[1], # height & weight
+ self.patch_embed.grid_size[0], # t_size
+ cls_token=True
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+ self.clip_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+ self.mae_pos_embed.data.copy_(torch.from_numpy(pos_embed[1:]).float().unsqueeze(0))
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ @property
+ def dtype(self):
+ return self.patch_embed.proj.weight.dtype
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {
+ 'pos_embed',
+ 'pos_embed_spatial',
+ 'pos_embed_temporal',
+ 'pos_embed_cls',
+ 'cls_token',
+ 'clip_pos_embed',
+ 'clip_pos_embed_spatial',
+ 'clip_pos_embed_temporal',
+ 'clip_pos_embed_cls',
+ 'mae_pos_embed',
+ 'mae_pos_embed_spatial',
+ 'mae_pos_embed_temporal',
+ }
+
+ def forward(self, x, mask):
+ x = self.patch_embed(x.type(self.dtype))
+ B, T, L, C = x.shape # T: temporal; L: spatial
+ x = x.view([B, T * L, C])
+
+ # append cls token
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # add pos_embed
+ if self.sep_pos_embed:
+ pos_embed = self.pos_embed_spatial.repeat(
+ 1, self.grid_size[0], 1
+ ) + torch.repeat_interleave(
+ self.pos_embed_temporal,
+ self.grid_size[1] * self.grid_size[2],
+ dim=1,
+ )
+ pos_embed = torch.cat(
+ [
+ self.pos_embed_cls.expand(pos_embed.shape[0], -1, -1),
+ pos_embed,
+ ],
+ 1,
+ )
+ else:
+ pos_embed = self.pos_embed
+ x = x + pos_embed
+
+ # mask tokens, ~mask means visible
+ x = x[~mask].reshape(B, -1, C)
+
+ residual = None
+ x_clip = []
+ x_mae = []
+ for idx, blk in enumerate(self.blocks):
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ x = blk(x, residual=residual)
+ # return intermediate features for CLIP
+ if idx in self.clip_return_index:
+ if isinstance(x, tuple) and len(x) == 2:
+ tmp_x, tmp_residual = x
+ if residual is not None:
+ x_clip.append(tmp_x + tmp_residual)
+ else:
+ x_clip.append(x)
+ # return intermediate features for MAE
+ if idx in self.mae_return_index:
+ if isinstance(x, tuple) and len(x) == 2:
+ tmp_x, tmp_residual = x
+ if residual is not None:
+ x_mae.append((tmp_x + tmp_residual)[:, 1:])
+ else:
+ x_mae.append(x[:, 1:])
+
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ if residual is not None:
+ x = x + residual
+
+ x = self.clip_projector(x)
+
+ # align CLIP
+ x_clip = torch.stack(x_clip)
+ K, B, _, C_CLIP = x_clip.shape
+ # add pos_embed
+ if self.sep_pos_embed:
+ clip_pos_embed = self.clip_pos_embed_spatial.repeat(
+ 1, self.grid_size[0], 1
+ ) + torch.repeat_interleave(
+ self.clip_pos_embed_temporal,
+ self.grid_size[1] * self.grid_size[2],
+ dim=1,
+ )
+ clip_pos_embed = torch.cat(
+ [
+ self.clip_pos_embed_cls.expand(clip_pos_embed.shape[0], -1, -1),
+ clip_pos_embed,
+ ],
+ 1,
+ )
+ else:
+ clip_pos_embed = self.clip_pos_embed
+ clip_pos_embed = clip_pos_embed.repeat(B, 1, 1)
+ x_clip = x_clip + clip_pos_embed[~mask].view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1)
+ # CLIP decoder
+ x_clip_align = []
+ for idx, clip_decoder in enumerate(self.clip_decoder):
+ x_clip_align.append(clip_decoder(x_clip[idx]))
+ x_clip_align = torch.stack(x_clip_align)
+ x_align = self.final_clip_decoder(x)
+
+ # align MAE
+ x_mae = torch.stack(x_mae)
+ K, B, _, C_MAE = x_mae.shape
+ # add pos_embed
+ if self.sep_pos_embed:
+ mae_pos_embed = self.mae_pos_embed_spatial.repeat(
+ 1, self.grid_size[0], 1
+ ) + torch.repeat_interleave(
+ self.mae_pos_embed_temporal,
+ self.grid_size[1] * self.grid_size[2],
+ dim=1,
+ )
+ else:
+ mae_pos_embed = self.mae_pos_embed
+ mae_pos_embed = mae_pos_embed.repeat(B, 1, 1)
+ x_mae = x_mae + mae_pos_embed[~mask[:, 1:]].view(B, -1, C_MAE).unsqueeze(0).repeat(K, 1, 1, 1)
+ # MAE decoder
+ x_mae_align = []
+ for idx, mae_decoder in enumerate(self.mae_decoder):
+ x_mae_align.append(mae_decoder(x_mae[idx]))
+ x_mae_align = torch.stack(x_mae_align)
+
+ return x_clip_align, x_align, x_mae_align
+
+
+@register_model
+def pretrain_internvideo2_1B_patch14_224(pretrained=False, **kwargs):
+ model = PretrainInternVideo2(
+ img_size=224, patch_size=14, embed_dim=1408,
+ depth=40, num_heads=16, mlp_ratio=48/11,
+ attn_pool_num_heads=16, clip_embed_dim=768,
+ **kwargs
+ )
+ return model
+
+
+@register_model
+def pretrain_internvideo2_6B_patch14_224(pretrained=False, **kwargs):
+ model = PretrainInternVideo2(
+ img_size=224, patch_size=14, embed_dim=3200,
+ depth=48, num_heads=25, mlp_ratio=4,
+ attn_pool_num_heads=16, clip_embed_dim=768,
+ **kwargs
+ )
+ return model
+
+
+if __name__ == '__main__':
+ import time
+ from fvcore.nn import FlopCountAnalysis
+ from fvcore.nn import flop_count_table
+ import numpy as np
+
+ seed = 4217
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ num_frames = 8
+ img_size = 224
+
+ # model = pretrain_internvideo2_1B_patch14_224(clip_return_layer=6).cuda().half()
+ model = pretrain_internvideo2_6B_patch14_224(clip_return_layer=1).cuda().half()
+ # print(model)
+
+ # flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, img_size, img_size).cuda().half())
+ # s = time.time()
+ # print(flop_count_table(flops, max_depth=1))
+ # print(time.time()-s)
+
+ mask = torch.cat([
+ torch.zeros(1, 1),
+ torch.ones(1, 8 * int(16 * 16 * 0.75)),
+ torch.zeros(1, 8 * int(16 * 16 * 0.25)),
+ ], dim=-1).to(torch.bool).cuda()
+
+ output = model(torch.rand(4, 3, num_frames, img_size, img_size).cuda().half(), mask.repeat(4, 1))
+ print(output[0].shape)
+ print(output[1].shape)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/models/internvl_clip_vision.py b/third_party/InternVideo/InternVideo2/single_modality/models/internvl_clip_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..d71bcc1508c30e09a5a0ffb9321375522804252f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/models/internvl_clip_vision.py
@@ -0,0 +1,555 @@
+import os
+import torch
+import torch.nn.functional as F
+from timm.models.layers import DropPath, to_2tuple
+from torch import nn
+
+import torch.utils.checkpoint as checkpoint
+from functools import partial
+from einops import rearrange
+
+from flash_attention_class import FlashAttention
+from flash_attn.modules.mlp import FusedMLP
+from flash_attn.ops.rms_norm import DropoutAddRMSNorm
+
+
+MODEL_PATH = 'your_model_path/internvl'
+_MODELS = {
+ # see InternVL
+ "internvl_c_13b_224px": os.path.join(MODEL_PATH, "internvl_c_13b_224px.pth"),
+}
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None, out_dim=None):
+ super().__init__()
+ if out_dim is None:
+ out_dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ assert all_head_dim == dim
+
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.k_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, out_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, k=None, v=None, return_attn=False):
+ B, N, C = x.shape
+ N_k = k.shape[1]
+ N_v = v.shape[1]
+
+ q_bias, k_bias, v_bias = None, None, None
+ if self.q_bias is not None:
+ q_bias = self.q_bias
+ k_bias = self.k_bias
+ v_bias = self.v_bias
+
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
+
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ if return_attn:
+ return x, attn.mean(1) # (B, n_head, n_q, C) => (B, n_q, C)
+ else:
+ return x, None
+
+
+class AttentiveBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
+ super().__init__()
+
+ self.norm1_q = norm_layer(dim)
+ self.norm1_k = norm_layer(dim)
+ self.norm1_v = norm_layer(dim)
+ self.cross_attn = CrossAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None, return_attn=False):
+ x_q = self.norm1_q(x_q + pos_q)
+ x_k = self.norm1_k(x_kv + pos_k)
+ x_v = self.norm1_v(x_kv)
+ x, attn = self.cross_attn(x_q, k=x_k, v=x_v, return_attn=return_attn)
+ return x, attn
+
+
+class AttentionPoolingBlock(AttentiveBlock):
+
+ def forward(self, x, return_attn=False):
+ x_q = x.mean(1, keepdim=True)
+ x_kv, pos_q, pos_k = x, 0, 0
+ x, attn = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None, return_attn=return_attn)
+ x = x.squeeze(1)
+ return x, attn
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+ self.force_fp32 = force_fp32
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def forward(self, x):
+ if self.force_fp32:
+ output_type = x.dtype
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
+ return out.to(dtype=output_type)
+ else:
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.use_flash_attn = use_flash_attn
+ if use_flash_attn:
+ self.causal = causal
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
+
+ self.qk_normalization = qk_normalization
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ if self.use_fused_rmsnorm:
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
+ else:
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+ )
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, x):
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
+ use_fused_rmsnorm=False):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
+ qk_normalization=qk_normalization,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ self.ls1 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ if use_fused_mlp:
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values,
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.with_cp = with_cp
+ self.use_fused_rmsnorm = use_fused_rmsnorm
+
+ def forward(self, x, residual=None):
+
+ def _inner_forward(x, residual=None):
+ if self.use_fused_rmsnorm:
+ x, residual = self.norm1(x, residual)
+ x = self.drop_path1(self.ls1(self.attn(x)))
+ x, residual = self.norm2(x, residual)
+ x = self.drop_path2(self.ls2(self.mlp(x)))
+ return x, residual
+ else:
+ assert residual is None
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+ if self.with_cp:
+ return checkpoint.checkpoint(_inner_forward, x, residual)
+ else:
+ return _inner_forward(x, residual=residual)
+
+
+class PatchEmbed(nn.Module):
+ """ 3D Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv3d(
+ in_chans, embed_dim,
+ kernel_size=(1, patch_size[0], patch_size[1]),
+ stride=(1, patch_size[0], patch_size[1]),
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(3).permute(0, 2, 3, 1) # (N, C, T, H, W) => (N, T, H * W, C)
+ x = self.norm(x)
+ return x
+
+
+class InternVL_CLIP(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ patch_size: int = 14,
+ img_size: int = 224,
+ qkv_bias: bool = False,
+ drop_path_rate: float = 0.2,
+ embed_dim: int = 3200,
+ num_heads: int = 25,
+ mlp_ratio: int = 4,
+ init_values: float = 0.1,
+ qk_normalization: bool = True,
+ depth: int = 48,
+ use_flash_attn: bool = True,
+ use_fused_rmsnorm: bool = True,
+ use_fused_mlp: bool = True,
+ fused_mlp_heuristic: int = 1,
+ with_cp: bool = False,
+ attn_pool_num_heads: int = 16,
+ clip_embed_dim: int = 768,
+ layerscale_no_force_fp32: bool = True,
+ # for unmasked teacher
+ clip_norm_type: str = 'l2',
+ return_attn: bool = True,
+ clip_return_layer: int = 1,
+ clip_return_interval: int = 1,
+ ):
+ super().__init__()
+
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, print(
+ 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent')
+
+ self.use_flash_attn = use_flash_attn
+ self.embed_dim = embed_dim
+
+ self.clip_norm_type = clip_norm_type
+ self.return_attn = return_attn
+ self.return_index = []
+ for i in range(clip_return_layer):
+ self.return_index.append(depth - int(i * clip_return_interval) - 1)
+ print(f'Normalization Type: {clip_norm_type}')
+ print(f'Return Attention: {return_attn}')
+ print(f'Teacher Return Interval: {self.return_index}')
+
+ """ only use image encoder of InternVL """
+ if use_fused_rmsnorm:
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
+ else:
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
+ self.norm_layer_for_blocks = norm_layer_for_blocks
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
+ num_patches = self.patch_embed.num_patches
+ self.num_patches = num_patches
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+ self.blocks = nn.ModuleList([
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
+ norm_layer=norm_layer_for_blocks,
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
+ fused_mlp_heuristic=fused_mlp_heuristic,
+ with_cp=with_cp,
+ qk_normalization=qk_normalization,
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
+ use_fused_rmsnorm=use_fused_rmsnorm)
+ for i in range(depth)])
+ self.clip_projector = AttentionPoolingBlock(
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
+
+ @property
+ def dtype(self):
+ return self.patch_embed.proj.weight.dtype
+
+ def forward(self, image):
+ x = self.patch_embed(image.type(self.dtype))
+ B, T, HW, C = x.size()
+ x = x.reshape(B * T, HW, C)
+
+ cls_tokens = self.cls_token.expand(B * T, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + self.pos_embed
+
+ residual = None
+ z = []
+ for idx, blk in enumerate(self.blocks):
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ x = blk(x, residual=residual)
+ # return intermediate features
+ if idx in self.return_index:
+ if isinstance(x, tuple) and len(x) == 2:
+ tmp_x, tmp_residual = x
+ if residual is not None:
+ z.append(tmp_x + tmp_residual)
+ else:
+ z.append(x)
+
+ if isinstance(x, tuple) and len(x) == 2:
+ x, residual = x
+ if residual is not None:
+ x = x + residual
+
+ x, attn = self.clip_projector(x, return_attn=self.return_attn)
+
+ if self.clip_norm_type == 'l2':
+ # normalization of intermediate features
+ z = torch.stack(z) # (K, BT, HW+1, C)
+ K = z.shape[0]
+ cls_tokens, z = z[:, :, :1, :], z[:, :, 1:, :]
+ cls_tokens = cls_tokens.view(K, B, T, 1, C).mean(2) # (K, BT, 1, C) => (K, B, 1, C)
+ z = z.reshape(K, B, T * HW, C)
+ z = torch.cat((cls_tokens, z), dim=2) # (K, B, HWT+1, C)
+ z = z / z.norm(dim=-1, keepdim=True)
+ # normalization of final features
+ x = x.view(B, T, -1).mean(1) # (BT, C) => (B, C)
+ x = x / x.norm(dim=-1, keepdim=True)
+ elif self.clip_norm_type == 'none':
+ pass
+ else:
+ raise NotImplementedError
+
+ if self.return_attn:
+ return z, x, attn[:, 0, 1:] # (B * T, HW)
+ else:
+ return z, x
+
+
+def inflate_weight(weight_2d, time_dim, center=True):
+ print(f'Init center: {center}')
+ if center:
+ weight_3d = torch.zeros(*weight_2d.shape)
+ weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+ middle_idx = time_dim // 2
+ weight_3d[:, :, middle_idx, :, :] = weight_2d
+ else:
+ weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+ weight_3d = weight_3d / time_dim
+ return weight_3d
+
+
+def process_checkpoint(ckpt, model):
+ new_ckpt = {}
+ state_dict_3d = model.state_dict()
+ for k, v in ckpt['module'].items():
+ new_k = k
+ if 'patch_embed' in new_k and new_k in state_dict_3d.keys() and v.shape != state_dict_3d[new_k].shape:
+ print(new_k)
+ print(f'Inflate: {k}, {v.shape} => {state_dict_3d[new_k].shape}')
+ time_dim = state_dict_3d[new_k].shape[2]
+ v = inflate_weight(v, time_dim)
+ new_ckpt[new_k] = v
+
+ # interpolate position embedding
+ pos_embed_checkpoint = new_ckpt['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.num_patches
+ orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5)
+ new_size = int(num_patches ** 0.5)
+ if orig_size != new_size:
+ print(f'pos_embed from {orig_size} to {new_size}')
+ extra_tokens = pos_embed_checkpoint[:, :1]
+ pos_tokens = pos_embed_checkpoint[:, 1:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ new_ckpt['pos_embed'] = new_pos_embed
+
+ return new_ckpt
+
+
+def internvl_clip_6b(
+ img_size,
+ clip_norm_type='l2',
+ return_attn=True,
+ clip_return_layer=1,
+ clip_return_interval=1
+ ):
+ model = InternVL_CLIP(
+ img_size=img_size,
+ layerscale_no_force_fp32=False,
+ clip_norm_type=clip_norm_type,
+ return_attn=return_attn,
+ clip_return_layer=clip_return_layer,
+ clip_return_interval=clip_return_interval,
+ )
+
+ ckpt = torch.load(_MODELS["internvl_c_13b_224px"], map_location='cpu')
+ new_ckpt = process_checkpoint(ckpt, model)
+ message = model.load_state_dict(new_ckpt, strict=False)
+ print(message)
+ return model.eval()
+
+
+if __name__ == '__main__':
+ import time
+ from fvcore.nn import FlopCountAnalysis
+ from fvcore.nn import flop_count_table
+ import numpy as np
+
+ seed = 4217
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ num_frames = 8
+ img_size = 224
+ video = torch.rand(1, 3, num_frames, img_size, img_size).cuda().half()
+
+ model = internvl_clip_6b(img_size).cuda().half()
+ # flops = FlopCountAnalysis(model, video)
+ model(video)
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/models/pos_embed.py b/third_party/InternVideo/InternVideo2/single_modality/models/pos_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..032c5ccffd22e01a3aedbc9279afb359896c4396
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/models/pos_embed.py
@@ -0,0 +1,131 @@
+import numpy as np
+
+
+# --------------------------------------------------------
+# 3D sine-cosine position embedding
+# References:
+# MVD: https://github.com/ruiwang2021/mvd/blob/main/modeling_finetune.py
+# --------------------------------------------------------
+def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ t_size: int of the temporal size
+ return:
+ pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ assert embed_dim % 4 == 0
+ embed_dim_spatial = embed_dim // 4 * 3
+ embed_dim_temporal = embed_dim // 4
+
+ # spatial
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
+ embed_dim_spatial, grid
+ )
+
+ # temporal
+ grid_t = np.arange(t_size, dtype=np.float32)
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
+ embed_dim_temporal, grid_t
+ )
+
+ # concate: [T, H, W] order
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+ pos_embed_temporal = np.repeat(
+ pos_embed_temporal, grid_size**2, axis=1
+ ) # [T, H*W, D // 4]
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+ pos_embed_spatial = np.repeat(
+ pos_embed_spatial, t_size, axis=0
+ ) # [T, H*W, D // 4 * 3]
+
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
+ pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
+
+ if cls_token:
+ pos_embed = np.concatenate(
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate(
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
+ """
+ t_size: int of the temporal size
+ return:
+ pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_t = np.arange(t_size, dtype=np.float32)
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
+ if cls_token:
+ pos_embed = np.concatenate(
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(
+ embed_dim // 2, grid[0]
+ ) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(
+ embed_dim // 2, grid[1]
+ ) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/models/videomae.py b/third_party/InternVideo/InternVideo2/single_modality/models/videomae.py
new file mode 100644
index 0000000000000000000000000000000000000000..3de10c015da058208a5fed46c3c0912698f527f3
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/models/videomae.py
@@ -0,0 +1,361 @@
+import os
+from functools import partial
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+from flash_attn import flash_attn_func
+
+
+MODEL_PATH = 'your_model_path/videomae'
+_MODELS = {
+ # see videomaev2
+ "vit_g14_hybrid": os.path.join(MODEL_PATH, "vit_g_hybrid_1200e_pre.pth"),
+}
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic',
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ **kwargs
+ }
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+ # commit this for the orignal BERT implement
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ x = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=self.scale, causal=False).reshape(B, N, -1)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ attn_head_dim=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if init_values > 0:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x):
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.tubelet_size = int(tubelet_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim,
+ kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
+ stride=(self.tubelet_size, patch_size[0], patch_size[1]))
+
+ def forward(self, x, **kwargs):
+ B, C, T, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+# sin-cos position encoding
+# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
+def get_sinusoid_encoding_table(n_position, d_hid, cur_frame=-1, pre_n_position=1568):
+ ''' Sinusoid position encoding table '''
+ # TODO: make it with torch instead of numpy
+ def get_position_angle_vec(position):
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+
+ # generate checkpoint position embedding
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)])
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+ sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
+ print(f"n_position: {n_position}")
+ print(f"pre_n_position: {pre_n_position}")
+ if n_position // cur_frame * 8 != pre_n_position and cur_frame != -1:
+ T = 8 # checkpoint frame
+ P = 14 # checkpoint size
+ C = d_hid
+ new_P = int((n_position // cur_frame) ** 0.5) # testing size
+ print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
+ print(f'Interpolate the position embedding')
+ sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
+ sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
+ sinusoid_table = torch.nn.functional.interpolate(
+ sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+ sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
+ sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
+ if cur_frame != -1 and cur_frame != 8:
+ print(f'Pretraining uses 8 frames, but current frame is {cur_frame}')
+ print(f'Interpolate the position embedding')
+ T = 8 # checkpoint frame
+ new_T = cur_frame # testing frame
+ # interpolate
+ P = int((n_position // cur_frame) ** 0.5) # testing size
+ C = d_hid
+ sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
+ sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T
+ sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
+ sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
+ sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
+ if n_position == pre_n_position:
+ return sinusoid_table
+ else:
+ print("Use learnable position embedding")
+ return nn.Parameter(sinusoid_table, requires_grad=True)
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=nn.LayerNorm,
+ init_values=0.,
+ all_frames=16,
+ tubelet_size=2,
+ mae_norm_type='l2',
+ mae_return_layer=1,
+ mae_return_interval=1,
+ ):
+ super().__init__()
+ self.mae_norm_type = mae_norm_type
+ self.return_index = []
+ for i in range(mae_return_layer):
+ self.return_index.append(depth - int(i * mae_return_interval) - 1)
+ print(f'Normalization Type: {mae_norm_type}')
+ print(f'MAE Teacher return index: : {self.return_index}')
+
+ self.tubelet_size = tubelet_size
+ self.depth = depth
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size)
+ num_patches = self.patch_embed.num_patches
+
+ # sine-cosine positional embeddings is on the way
+ if patch_size == 14:
+ pre_n_position = 2048
+ else:
+ pre_n_position = 1568
+ self.pos_embed = get_sinusoid_encoding_table(
+ num_patches, embed_dim, all_frames // tubelet_size,
+ pre_n_position=pre_n_position
+ )
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ init_values=init_values)
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward(self, x, mask=None):
+ x = self.patch_embed(x)
+ B, _, C = x.size()
+
+ if self.pos_embed is not None:
+ x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
+ x = self.pos_drop(x)
+
+ if mask is not None:
+ x = x[~mask].reshape(B, -1, C) # ~mask means visible
+
+ z = []
+ for idx, blk in enumerate(self.blocks):
+ x = blk(x)
+ if idx == self.depth - 1:
+ x = self.norm(x)
+ if idx in self.return_index:
+ z.append(x)
+ x = torch.stack(z)
+
+ if self.mae_norm_type == 'l2':
+ x = x / x.norm(dim=-1, keepdim=True)
+ elif self.mae_norm_type == 'none':
+ pass
+ else:
+ raise NotImplementedError
+
+ return x
+
+
+def load_state_dict(model, state_dict):
+ from collections import OrderedDict
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k.startswith('encoder.'):
+ new_k = k[8:]
+ if new_k == "patch_embed.proj.weight" and model.tubelet_size == 1:
+ print("Kernel pooling")
+ v = v.mean(dim=2, keepdim=True)
+ new_state_dict[new_k] = v
+ msg = model.load_state_dict(new_state_dict)
+ print(msg)
+
+
+def mae_g14_hybrid(pretrained=True, **kwargs):
+ model = VisionTransformer(
+ patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ if pretrained:
+ print('load MAE pretrained weights')
+ state_dict = torch.load(_MODELS["vit_g14_hybrid"], map_location='cpu')
+ load_state_dict(model, state_dict['model'])
+ return model
+
+
+if __name__ == '__main__':
+ import time
+ from fvcore.nn import FlopCountAnalysis
+ from fvcore.nn import flop_count_table
+ import numpy as np
+
+ seed = 4217
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ num_frames = 16
+
+ model = mae_g14_hybrid(all_frames=num_frames, tubelet_size=2).cuda().half()
+ # print(model)
+
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224).cuda().half())
+ s = time.time()
+ print(flop_count_table(flops, max_depth=1))
+ print(time.time()-s)
+ # print(model(torch.rand(1, 3, num_frames, 224, 224)).shape)
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/optim_factory.py b/third_party/InternVideo/InternVideo2/single_modality/optim_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb4f4a08544abcfbcf45d07ec5563faae4a817b7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/optim_factory.py
@@ -0,0 +1,190 @@
+import torch
+from torch import optim as optim
+
+from timm.optim.adafactor import Adafactor
+from timm.optim.adahessian import Adahessian
+from timm.optim.adamp import AdamP
+from timm.optim.lookahead import Lookahead
+from timm.optim.nadam import Nadam
+# from timm.optim.novograd import NovoGrad
+from timm.optim.nvnovograd import NvNovoGrad
+from timm.optim.radam import RAdam
+from timm.optim.rmsprop_tf import RMSpropTF
+from timm.optim.sgdp import SGDP
+
+import json
+
+try:
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
+ has_apex = True
+except ImportError:
+ has_apex = False
+
+
+def get_num_layer_for_vit(var_name, num_max_layer):
+ if var_name in ("cls_token", "mask_token", "pos_embed"):
+ return 0
+ elif var_name.startswith("patch_embed"):
+ return 0
+ elif var_name.startswith("rel_pos_bias"):
+ return num_max_layer - 1
+ elif var_name.startswith("blocks"):
+ layer_id = int(var_name.split('.')[1])
+ return layer_id + 1
+ elif var_name.startswith("transformer.resblocks"):
+ layer_id = int(var_name.split('.')[2])
+ return layer_id + 1
+ elif var_name in ("class_embedding", "positional_embedding", "temporal_positional_embedding"):
+ return 0
+ elif var_name.startswith("conv1"):
+ return 0
+ else:
+ return num_max_layer - 1
+
+
+class LayerDecayValueAssigner(object):
+ def __init__(self, values):
+ self.values = values
+
+ def get_scale(self, layer_id):
+ return self.values[layer_id]
+
+ def get_layer_id(self, var_name):
+ return get_num_layer_for_vit(var_name, len(self.values))
+
+
+def get_parameter_groups(
+ model, weight_decay=1e-5, skip_list=(), get_num_layer=None,
+ get_layer_scale=None,
+ ):
+ parameter_group_names = {}
+ parameter_group_vars = {}
+
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+ group_name = "no_decay"
+ this_weight_decay = 0.
+ else:
+ group_name = "decay"
+ this_weight_decay = weight_decay
+ if get_num_layer is not None:
+ layer_id = get_num_layer(name)
+ group_name = "layer_%d_%s" % (layer_id, group_name)
+ else:
+ layer_id = None
+
+ if group_name not in parameter_group_names:
+ if get_layer_scale is not None:
+ scale = get_layer_scale(layer_id)
+ else:
+ scale = 1.
+
+ parameter_group_names[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale
+ }
+ parameter_group_vars[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale
+ }
+
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+ return list(parameter_group_vars.values())
+
+
+def create_optimizer(
+ args, model, get_num_layer=None, get_layer_scale=None,
+ filter_bias_and_bn=True, skip_list=None
+ ):
+ opt_lower = args.opt.lower()
+ weight_decay = args.weight_decay
+ if weight_decay and filter_bias_and_bn:
+ skip = {}
+ if skip_list is not None:
+ skip = skip_list
+ elif hasattr(model, 'no_weight_decay'):
+ skip = model.no_weight_decay()
+ parameters = get_parameter_groups(
+ model, weight_decay, skip, get_num_layer, get_layer_scale,
+ )
+ weight_decay = 0.
+ else:
+ parameters = model.parameters()
+
+ if 'fused' in opt_lower:
+ assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
+
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
+ if hasattr(args, 'opt_eps') and args.opt_eps is not None:
+ opt_args['eps'] = args.opt_eps
+ if hasattr(args, 'opt_betas') and args.opt_betas is not None:
+ opt_args['betas'] = args.opt_betas
+
+ print("optimizer settings:", opt_args)
+
+ opt_split = opt_lower.split('_')
+ opt_lower = opt_split[-1]
+ if opt_lower == 'sgd' or opt_lower == 'nesterov':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'momentum':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+ elif opt_lower == 'adam':
+ optimizer = optim.Adam(parameters, **opt_args)
+ elif opt_lower == 'adamw':
+ optimizer = optim.AdamW(parameters, **opt_args)
+ elif opt_lower == 'nadam':
+ optimizer = Nadam(parameters, **opt_args)
+ elif opt_lower == 'radam':
+ optimizer = RAdam(parameters, **opt_args)
+ elif opt_lower == 'adamp':
+ optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
+ elif opt_lower == 'sgdp':
+ optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'adadelta':
+ optimizer = optim.Adadelta(parameters, **opt_args)
+ elif opt_lower == 'adafactor':
+ if not args.lr:
+ opt_args['lr'] = None
+ optimizer = Adafactor(parameters, **opt_args)
+ elif opt_lower == 'adahessian':
+ optimizer = Adahessian(parameters, **opt_args)
+ elif opt_lower == 'rmsprop':
+ optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
+ elif opt_lower == 'rmsproptf':
+ optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
+ # elif opt_lower == 'novograd':
+ # optimizer = NovoGrad(parameters, **opt_args)
+ elif opt_lower == 'nvnovograd':
+ optimizer = NvNovoGrad(parameters, **opt_args)
+ elif opt_lower == 'fusedsgd':
+ opt_args.pop('eps', None)
+ optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'fusedmomentum':
+ opt_args.pop('eps', None)
+ optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+ elif opt_lower == 'fusedadam':
+ optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
+ elif opt_lower == 'fusedadamw':
+ optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
+ elif opt_lower == 'fusedlamb':
+ optimizer = FusedLAMB(parameters, **opt_args)
+ elif opt_lower == 'fusednovograd':
+ opt_args.setdefault('betas', (0.95, 0.98))
+ optimizer = FusedNovoGrad(parameters, **opt_args)
+ else:
+ assert False and "Invalid optimizer"
+ raise ValueError
+
+ if len(opt_split) > 1:
+ if opt_split[0] == 'lookahead':
+ optimizer = Lookahead(optimizer)
+
+ return optimizer
diff --git a/third_party/InternVideo/InternVideo2/single_modality/requirements.txt b/third_party/InternVideo/InternVideo2/single_modality/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..49a3bbb7e4b460620db556bf3692366960d11242
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/requirements.txt
@@ -0,0 +1,17 @@
+apex==0.9.10dev
+auto_augment==1.0.0
+decord==0.6.0
+deepspeed==0.10.1
+einops==0.7.0
+flash_attn==2.0.8
+fvcore==0.1.5.post20221221
+numpy==1.24.4
+opencv_python==4.8.0.76
+pandas==2.0.3
+Pillow==10.0.0
+scipy==1.13.0
+skimage==0.0
+tensorboardX==2.6.2
+timm==0.5.4
+torch==1.13.1+cu117
+torchvision==0.14.1+cu117
diff --git a/third_party/InternVideo/InternVideo2/single_modality/run_finetuning.py b/third_party/InternVideo/InternVideo2/single_modality/run_finetuning.py
new file mode 100755
index 0000000000000000000000000000000000000000..3687206e3fbc2a28cc44077593ecae84334d70d2
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/run_finetuning.py
@@ -0,0 +1,718 @@
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.backends.cudnn as cudnn
+import json
+import os
+from functools import partial
+from pathlib import Path
+from collections import OrderedDict
+
+from datasets.mixup import Mixup
+from timm.models import create_model
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.utils import ModelEma
+from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner
+
+from datasets import build_dataset
+from single_modality.engines.engine_for_finetuning import train_one_epoch, validation_one_epoch, final_test, merge
+from utils import NativeScalerWithGradNormCount as NativeScaler
+from utils import multiple_samples_collate
+import utils
+from models import *
+from models.internvl_clip_vision import inflate_weight
+
+
+def get_args():
+ parser = argparse.ArgumentParser('VideoMAE fine-tuning and evaluation script for video classification', add_help=False)
+ parser.add_argument('--batch_size', default=64, type=int)
+ parser.add_argument('--epochs', default=30, type=int)
+ parser.add_argument('--update_freq', default=1, type=int)
+ parser.add_argument('--save_ckpt_freq', default=100, type=int)
+ parser.add_argument('--steps_per_print', default=1, type=int)
+ parser.add_argument('--use_ceph_checkpoint', action='store_true',
+ help="whether use ceph to save and load checkpoint, may be some bug now")
+ parser.set_defaults(use_ceph_checkpoint=False)
+ parser.add_argument('--ceph_checkpoint_prefix', default='', type=str,
+ help='prefix for checkpoint in ceph')
+ parser.add_argument('--ckpt_path_split', default='/exp/', type=str,
+ help='string for splitting the ckpt_path')
+
+ # Model parameters
+ parser.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',
+ help='Name of model to train')
+ parser.add_argument('--tubelet_size', type=int, default=2)
+ parser.add_argument('--input_size', default=224, type=int,
+ help='videos input size')
+ parser.add_argument('--layer_scale_init_value', default=1e-5, type=float,
+ help="0.1 for base, 1e-5 for large. set 0 to disable LayerScale")
+ parser.add_argument('--layerscale_no_force_fp32', action='store_true',
+ help="Not force fp32 for LayerScale")
+ parser.set_defaults(layerscale_no_force_fp32=False)
+ parser.add_argument('--sep_pos_embed', action='store_true',
+ help="whether use seperable position embedding")
+ parser.add_argument('--center_init', action='store_true',
+ help="center initlization for patch embedding")
+
+ parser.add_argument('--fc_drop_rate', type=float, default=0.0, metavar='PCT',
+ help='Dropout rate (default: 0.)')
+ parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+ help='Dropout rate (default: 0.)')
+ parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
+ help='Attention dropout rate (default: 0.)')
+ parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
+ help='Drop path rate (default: 0.1)')
+ parser.add_argument('--head_drop_path', type=float, default=0.0, metavar='PCT',
+ help='Head Drop path rate (default: 0.0)')
+
+ parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
+ parser.add_argument('--model_ema', action='store_true', default=False)
+ parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
+ parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
+
+ # Optimizer parameters
+ parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+ help='Optimizer (default: "adamw"')
+ parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
+ help='Optimizer Epsilon (default: 1e-8)')
+ parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
+ help='Optimizer Betas (default: None, use opt default)')
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+ help='Clip gradient norm (default: None, no clipping)')
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+ help='SGD momentum (default: 0.9)')
+ parser.add_argument('--weight_decay', type=float, default=0.05,
+ help='weight decay (default: 0.05)')
+ parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
+ weight decay. We use a cosine schedule for WD and using a larger decay by
+ the end of training improves performance for ViTs.""")
+
+ parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
+ help='learning rate (default: 1e-3)')
+ parser.add_argument('--layer_decay', type=float, default=0.75)
+
+ parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
+ help='warmup learning rate (default: 1e-6)')
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+ help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
+
+ parser.add_argument('--warmup_epochs', type=float, default=5, metavar='N',
+ help='epochs to warmup LR, if scheduler supports')
+ parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
+ help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
+
+ # Augmentation parameters
+ parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
+ help='Color jitter factor (default: 0.4)')
+ parser.add_argument('--num_sample', type=int, default=2,
+ help='Repeated_aug (default: 2)')
+ parser.add_argument('--aa', type=str, default='rand-m7-n4-mstd0.5-inc1', metavar='NAME',
+ help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m7-n4-mstd0.5-inc1)'),
+ parser.add_argument('--smoothing', type=float, default=0.1,
+ help='Label smoothing (default: 0.1)')
+ parser.add_argument('--train_interpolation', type=str, default='bicubic',
+ help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+
+ # Evaluation parameters
+ parser.add_argument('--crop_pct', type=float, default=None)
+ parser.add_argument('--short_side_size', type=int, default=224)
+ parser.add_argument('--test_num_segment', type=int, default=5)
+ parser.add_argument('--test_num_crop', type=int, default=3)
+
+ # Random Erase params
+ parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+ help='Random erase prob (default: 0.25)')
+ parser.add_argument('--remode', type=str, default='pixel',
+ help='Random erase mode (default: "pixel")')
+ parser.add_argument('--recount', type=int, default=1,
+ help='Random erase count (default: 1)')
+ parser.add_argument('--resplit', action='store_true', default=False,
+ help='Do not random erase first (clean) augmentation split')
+
+ # Mixup params
+ parser.add_argument('--mixup', type=float, default=0.8,
+ help='mixup alpha, mixup enabled if > 0.')
+ parser.add_argument('--cutmix', type=float, default=1.0,
+ help='cutmix alpha, cutmix enabled if > 0.')
+ parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
+ help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+ parser.add_argument('--mixup_prob', type=float, default=1.0,
+ help='Probability of performing mixup or cutmix when either/both is enabled')
+ parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
+ help='Probability of switching to cutmix when both mixup and cutmix enabled')
+ parser.add_argument('--mixup_mode', type=str, default='batch',
+ help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+ # Finetuning params
+ parser.add_argument('--finetune', default='', help='finetune from checkpoint')
+ parser.add_argument('--delete_head', action='store_true', help='whether delete head')
+ parser.add_argument('--model_key', default='model|module', type=str)
+ parser.add_argument('--model_prefix', default='', type=str)
+ parser.add_argument('--init_scale', default=0.001, type=float)
+ parser.add_argument('--use_checkpoint', action='store_true')
+ parser.set_defaults(use_checkpoint=False)
+ parser.add_argument('--checkpoint_num', default=0, type=int,
+ help='number of layers for using checkpoint')
+ parser.add_argument('--use_mean_pooling', action='store_true')
+ parser.set_defaults(use_mean_pooling=True)
+ parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
+
+ # Dataset parameters
+ parser.add_argument('--prefix', default='', type=str, help='prefix for data')
+ parser.add_argument('--split', default=' ', type=str, help='split for metadata')
+ parser.add_argument('--filename_tmpl', default='img_{:05}.jpg', type=str, help='file template')
+ parser.add_argument('--data_path', default='you_data_path', type=str,
+ help='dataset path')
+ parser.add_argument('--eval_data_path', default=None, type=str,
+ help='dataset path for evaluation')
+ parser.add_argument('--nb_classes', default=400, type=int,
+ help='number of the classification types')
+ parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
+ parser.add_argument('--use_decord', action='store_true',
+ help='whether use decord to load video, otherwise load image')
+ parser.add_argument('--no_use_decord', action='store_false', dest='use_decord')
+ parser.set_defaults(use_decord=True)
+ parser.add_argument('--num_segments', type=int, default=1)
+ parser.add_argument('--num_frames', type=int, default=16)
+ parser.add_argument('--sampling_rate', type=int, default=4)
+ parser.add_argument('--data_set', default='Kinetics', choices=[
+ 'Kinetics', 'Kinetics_sparse',
+ 'SSV2', 'UCF101', 'HMDB51', 'image_folder',
+ 'mitv1_sparse',
+ 'ANet', 'HACS', 'ANet_interval', 'HACS_interval',
+ ], type=str, help='dataset')
+ parser.add_argument('--output_dir', default='',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--log_dir', default=None,
+ help='path where to tensorboard log')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--seed', default=0, type=int)
+ parser.add_argument('--resume', default='',
+ help='resume from checkpoint')
+ parser.add_argument('--auto_resume', action='store_true')
+ parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
+ parser.set_defaults(auto_resume=True)
+
+ parser.add_argument('--save_ckpt', action='store_true')
+ parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
+ parser.set_defaults(save_ckpt=True)
+
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--test_best', action='store_true',
+ help='Whether test the best model')
+ parser.add_argument('--eval', action='store_true',
+ help='Perform evaluation only')
+ parser.add_argument('--dist_eval', action='store_true', default=False,
+ help='Enabling distributed evaluation')
+ parser.add_argument('--num_workers', default=10, type=int)
+ parser.add_argument('--pin_mem', action='store_true',
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+ parser.set_defaults(pin_mem=True)
+
+ # distributed training parameters
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', action='store_true')
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+
+ parser.add_argument('--enable_deepspeed', action='store_true', default=False)
+ parser.add_argument('--bf16', default=False, action='store_true')
+ parser.add_argument('--zero_stage', default=0, type=int,
+ help='ZeRO optimizer stage (default: 0)')
+
+ known_args, _ = parser.parse_known_args()
+
+ if known_args.enable_deepspeed:
+ try:
+ import deepspeed
+ from deepspeed import DeepSpeedConfig
+ parser = deepspeed.add_config_arguments(parser)
+ ds_init = deepspeed.initialize
+ except:
+ print("Please 'pip install deepspeed'")
+ exit(0)
+ else:
+ ds_init = None
+
+ return parser.parse_args(), ds_init
+
+
+def main(args, ds_init):
+ utils.init_distributed_mode(args)
+
+ if ds_init is not None:
+ utils.create_internvideo2_ds_config(args)
+
+ print(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ # random.seed(seed)
+
+ cudnn.benchmark = True
+
+ dataset_train, args.nb_classes = build_dataset(is_train=True, test_mode=False, args=args)
+ if args.disable_eval_during_finetuning:
+ dataset_val = None
+ else:
+ dataset_val, _ = build_dataset(is_train=False, test_mode=False, args=args)
+ dataset_test, _ = build_dataset(is_train=False, test_mode=True, args=args)
+
+
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ print("Sampler_train = %s" % str(sampler_train))
+ if args.dist_eval:
+ if len(dataset_val) % num_tasks != 0:
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
+ 'equal num of samples per-process.')
+ sampler_val = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
+ sampler_test = torch.utils.data.DistributedSampler(
+ dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False)
+ else:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ if global_rank == 0 and args.log_dir is not None:
+ os.makedirs(args.log_dir, exist_ok=True)
+ log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
+ else:
+ log_writer = None
+
+ if args.num_sample > 1:
+ collate_func = partial(multiple_samples_collate, fold=False)
+ else:
+ collate_func = None
+
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=True,
+ collate_fn=collate_func,
+ persistent_workers=True
+ )
+
+ if dataset_val is not None:
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=int(1.5 * args.batch_size),
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False,
+ persistent_workers=True
+ )
+ else:
+ data_loader_val = None
+
+ if dataset_test is not None:
+ data_loader_test = torch.utils.data.DataLoader(
+ dataset_test, sampler=sampler_test,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False,
+ persistent_workers=True
+ )
+ else:
+ data_loader_test = None
+
+ mixup_fn = None
+ mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+ if mixup_active:
+ print("Mixup is activated!")
+ mixup_fn = Mixup(
+ mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+ prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+ label_smoothing=args.smoothing, num_classes=args.nb_classes)
+
+ model = create_model(
+ args.model,
+ pretrained=False,
+ num_classes=args.nb_classes,
+ num_frames=args.num_frames * args.num_segments,
+ tubelet_size=args.tubelet_size,
+ sep_pos_embed=args.sep_pos_embed,
+ fc_drop_rate=args.fc_drop_rate,
+ drop_path_rate=args.drop_path,
+ head_drop_path_rate=args.head_drop_path,
+ use_checkpoint=args.use_checkpoint,
+ checkpoint_num=args.checkpoint_num,
+ init_scale=args.init_scale,
+ init_values=args.layer_scale_init_value,
+ layerscale_no_force_fp32=args.layerscale_no_force_fp32,
+ )
+
+ patch_size = model.patch_embed.patch_size
+ print("Patch size = %s" % str(patch_size))
+ args.window_size = (args.num_frames // args.tubelet_size, args.input_size // patch_size[0], args.input_size // patch_size[1])
+ args.patch_size = patch_size
+
+ if args.finetune:
+ if args.finetune.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.finetune, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(args.finetune, map_location='cpu')
+
+ print("Load ckpt from %s" % args.finetune)
+ checkpoint_model = None
+ for model_key in args.model_key.split('|'):
+ if model_key in checkpoint:
+ checkpoint_model = checkpoint[model_key]
+ print("Load state_dict by model_key = %s" % model_key)
+ break
+ if checkpoint_model is None:
+ checkpoint_model = checkpoint
+
+ if 'head.weight' in checkpoint_model.keys():
+ if args.delete_head:
+ print("Removing head from pretrained checkpoint")
+ del checkpoint_model['head.weight']
+ del checkpoint_model['head.bias']
+ elif checkpoint_model['head.weight'].shape[0] == 710:
+ if args.nb_classes == 400:
+ checkpoint_model['head.weight'] = checkpoint_model['head.weight'][:args.nb_classes]
+ checkpoint_model['head.bias'] = checkpoint_model['head.bias'][:args.nb_classes]
+ elif args.nb_classes in [600, 700]:
+ # download from https://drive.google.com/drive/folders/17cJd2qopv-pEG8NSghPFjZo1UUZ6NLVm
+ map_path = f'./k710/label_mixto{args.nb_classes}.json'
+ print(f'Load label map from {map_path}')
+ with open(map_path) as f:
+ label_map = json.load(f)
+ checkpoint_model['head.weight'] = checkpoint_model['head.weight'][label_map]
+ checkpoint_model['head.bias'] = checkpoint_model['head.bias'][label_map]
+
+ all_keys = list(checkpoint_model.keys())
+ new_dict = OrderedDict()
+ for key in all_keys:
+ if key.startswith('backbone.'):
+ new_dict[key[9:]] = checkpoint_model[key]
+ elif key.startswith('encoder.'):
+ new_dict[key[8:]] = checkpoint_model[key]
+ else:
+ new_dict[key] = checkpoint_model[key]
+ checkpoint_model = new_dict
+
+ if checkpoint_model['patch_embed.proj.weight'].shape[2] == 1 and model.patch_embed.tubelet_size > 1:
+ print("Inflate patch embedding")
+ print(f"Use center initilization: {args.center_init}")
+ checkpoint_model['patch_embed.proj.weight'] = inflate_weight(
+ checkpoint_model['patch_embed.proj.weight'][:, :, 0],
+ model.patch_embed.tubelet_size,
+ center=args.center_init
+ )
+
+ # interpolate position embedding
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
+ num_patches = model.patch_embed.num_patches #
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
+
+ # we use 8 frames for pretraining
+ orig_t_size = 8
+ new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
+ # height (== width) for the checkpoint position embedding
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int((num_patches // (new_t_size))** 0.5)
+
+ # class_token and dist_token are kept unchanged
+ if orig_t_size != new_t_size:
+ print(f"Temporal interpolate from {orig_t_size} to {new_t_size}")
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+ pos_embed_checkpoint = new_pos_embed
+
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+
+ elif 'pos_embed_spatial' in checkpoint_model and 'pos_embed_temporal' in checkpoint_model:
+ pos_embed_spatial_checkpoint = checkpoint_model['pos_embed_spatial']
+ pos_embed_temporal_checkpoint = checkpoint_model['pos_embed_temporal']
+
+ embedding_size = pos_embed_spatial_checkpoint.shape[-1] # channel dim
+ num_patches = model.patch_embed.num_patches #
+
+ orig_t_size = pos_embed_temporal_checkpoint.shape[-2]
+ new_t_size = args.num_frames // model.patch_embed.tubelet_size
+
+ # height (== width) for the checkpoint position embedding
+ orig_size = int(pos_embed_spatial_checkpoint.shape[-2] ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int((num_patches // new_t_size) ** 0.5)
+
+ if orig_t_size != new_t_size:
+ print(f"Temporal interpolate from {orig_t_size} to {new_t_size}")
+ tmp_pos_embed = pos_embed_temporal_checkpoint.view(1, orig_t_size, -1, embedding_size)
+ tmp_pos_embed = tmp_pos_embed.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+ tmp_pos_embed = torch.nn.functional.interpolate(tmp_pos_embed, size=new_t_size, mode='linear')
+ tmp_pos_embed = tmp_pos_embed.view(1, -1, embedding_size, new_t_size)
+ tmp_pos_embed = tmp_pos_embed.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+ checkpoint_model['pos_embed_temporal'] = tmp_pos_embed
+
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ pos_tokens = pos_embed_spatial_checkpoint
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
+ checkpoint_model['pos_embed_spatial'] = pos_tokens
+
+ utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
+
+ model.to(device)
+
+ model_ema = None
+ if args.model_ema:
+ model_ema = ModelEma(
+ model,
+ decay=args.model_ema_decay,
+ device='cpu' if args.model_ema_force_cpu else '',
+ resume='')
+ print("Using EMA with decay = %.8f" % args.model_ema_decay)
+
+ model_without_ddp = model
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ print("Model = %s" % str(model_without_ddp))
+ print('number of params:', n_parameters)
+
+ total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
+ num_training_steps_per_epoch = len(dataset_train) // total_batch_size
+ args.lr = args.lr * total_batch_size * args.num_sample / 256
+ args.min_lr = args.min_lr * total_batch_size * args.num_sample / 256
+ args.warmup_lr = args.warmup_lr * total_batch_size * args.num_sample / 256
+ print("LR = %.8f" % args.lr)
+ print("Batch size = %d" % total_batch_size)
+ print("Repeated sample = %d" % args.num_sample)
+ print("Update frequent = %d" % args.update_freq)
+ print("Number of training examples = %d" % len(dataset_train))
+ print("Number of training training per epoch = %d" % num_training_steps_per_epoch)
+
+ num_layers = model_without_ddp.get_num_layers()
+ if args.layer_decay < 1.0:
+ assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
+ else:
+ assigner = None
+
+ if assigner is not None:
+ print("Assigned values = %s" % str(assigner.values))
+
+ skip_weight_decay_list = model.no_weight_decay()
+ print("Skip weight decay list: ", skip_weight_decay_list)
+
+ if args.enable_deepspeed:
+ loss_scaler = None
+ optimizer_params = get_parameter_groups(
+ model, args.weight_decay, skip_weight_decay_list,
+ assigner.get_layer_id if assigner is not None else None,
+ assigner.get_scale if assigner is not None else None)
+ model, optimizer, _, _ = ds_init(
+ args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,
+ )
+
+ print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
+ assert model.gradient_accumulation_steps() == args.update_freq
+ else:
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+ model_without_ddp = model.module
+
+ optimizer = create_optimizer(
+ args, model_without_ddp, skip_list=skip_weight_decay_list,
+ get_num_layer=assigner.get_layer_id if assigner is not None else None,
+ get_layer_scale=assigner.get_scale if assigner is not None else None)
+ loss_scaler = NativeScaler()
+
+ print("Use step level LR scheduler!")
+ lr_schedule_values = utils.cosine_scheduler(
+ args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
+ warmup_epochs=args.warmup_epochs, start_warmup_value=args.warmup_lr, warmup_steps=args.warmup_steps,
+ )
+ if args.weight_decay_end is None:
+ args.weight_decay_end = args.weight_decay
+ wd_schedule_values = utils.cosine_scheduler(
+ args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
+ print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
+
+ if mixup_fn is not None:
+ # smoothing is handled with mixup label transform
+ criterion = SoftTargetCrossEntropy()
+ elif args.smoothing > 0.:
+ criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+ else:
+ criterion = torch.nn.CrossEntropyLoss()
+
+ print("criterion = %s" % str(criterion))
+ ceph_args = {
+ 'use_ceph_checkpoint': args.use_ceph_checkpoint,
+ 'ceph_checkpoint_prefix': args.ceph_checkpoint_prefix,
+ 'ckpt_path_split': args.ckpt_path_split,
+ 'local_rank': args.gpu,
+ }
+ if ceph_args['use_ceph_checkpoint']:
+ print("Will automatically upload model on ceph")
+ assert ceph_args['ceph_checkpoint_prefix'] != '', "Should set prefix for ceph checkpoint!"
+
+ utils.auto_load_model(
+ args=args, model=model, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema,
+ ceph_args=ceph_args,
+ )
+
+ print(f"Use bf16 {args.bf16}")
+
+ if args.eval:
+ preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
+ test_stats = final_test(data_loader_test, model, device, preds_file, ds=args.enable_deepspeed, bf16=args.bf16)
+ torch.distributed.barrier()
+ if global_rank == 0:
+ print("Start merging results...")
+ final_top1 ,final_top5 = merge(args.output_dir, num_tasks)
+ print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
+ log_stats = {'Final top-1': final_top1,
+ 'Final Top-5': final_top5}
+ if args.output_dir and utils.is_main_process():
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ exit(0)
+
+
+ print(f"Start training for {args.epochs} epochs")
+ start_time = time.time()
+ max_accuracy = 0.0
+ for epoch in range(args.start_epoch, args.epochs):
+ if args.distributed:
+ data_loader_train.sampler.set_epoch(epoch)
+ if log_writer is not None:
+ log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
+ train_stats = train_one_epoch(
+ model, criterion, data_loader_train, optimizer,
+ device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
+ log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
+ lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
+ num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
+ bf16=args.bf16
+ )
+ if args.output_dir and args.save_ckpt:
+ # if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
+ # utils.save_model(
+ # args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ # loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema,
+ # ceph_args=ceph_args,
+ # )
+ utils.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, model_name='latest', model_ema=model_ema,
+ ceph_args=ceph_args,
+ )
+ if data_loader_val is not None:
+ test_stats = validation_one_epoch(data_loader_val, model, device, ds=args.enable_deepspeed, bf16=args.bf16)
+ timestep = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+ print(f"[{timestep}] Accuracy of the network on the {len(dataset_val)} val videos: {test_stats['acc1']:.1f}%")
+ if max_accuracy < test_stats["acc1"]:
+ max_accuracy = test_stats["acc1"]
+ if args.output_dir and args.save_ckpt:
+ utils.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, model_name='best', model_ema=model_ema,
+ ceph_args=ceph_args,
+ )
+
+ print(f'Max accuracy: {max_accuracy:.2f}%')
+ if log_writer is not None:
+ log_writer.update(val_acc1=test_stats['acc1'], head="perf", step=epoch)
+ log_writer.update(val_acc5=test_stats['acc5'], head="perf", step=epoch)
+ log_writer.update(val_loss=test_stats['loss'], head="perf", step=epoch)
+
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'val_{k}': v for k, v in test_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+ else:
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+ if args.output_dir and utils.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
+ if args.test_best:
+ print("Auto testing the best model")
+ args.eval = True
+ utils.auto_load_model(
+ args=args, model=model, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema,
+ ceph_args=ceph_args,
+ )
+ test_stats = final_test(data_loader_test, model, device, preds_file, ds=args.enable_deepspeed, bf16=args.bf16)
+ torch.distributed.barrier()
+ if global_rank == 0:
+ print("Start merging results...")
+ final_top1 ,final_top5 = merge(args.output_dir, num_tasks)
+ print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
+ log_stats = {'Final top-1': final_top1,
+ 'Final Top-5': final_top5}
+ if args.output_dir and utils.is_main_process():
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ opts, ds_init = get_args()
+ if opts.output_dir:
+ Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
+ main(opts, ds_init)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/run_linear_probing.py b/third_party/InternVideo/InternVideo2/single_modality/run_linear_probing.py
new file mode 100755
index 0000000000000000000000000000000000000000..b55b4f48cbd7f90a69d50c8669cc4d20a89713a6
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/run_linear_probing.py
@@ -0,0 +1,783 @@
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.backends.cudnn as cudnn
+import json
+import os
+from functools import partial
+from pathlib import Path
+from collections import OrderedDict
+
+from datasets.mixup import Mixup
+from timm.models import create_model
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.utils import ModelEma
+from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner
+
+from datasets import build_dataset
+from single_modality.engines.engine_for_finetuning import train_one_epoch, validation_one_epoch, final_test, merge
+from utils import NativeScalerWithGradNormCount as NativeScaler
+from utils import multiple_samples_collate
+import utils
+from models import *
+
+
+def get_args():
+ parser = argparse.ArgumentParser('VideoMAE fine-tuning and evaluation script for video classification', add_help=False)
+ parser.add_argument('--batch_size', default=64, type=int)
+ parser.add_argument('--test_batch_size', default=64, type=int)
+ parser.add_argument('--epochs', default=30, type=int)
+ parser.add_argument('--update_freq', default=1, type=int)
+ parser.add_argument('--save_ckpt_freq', default=100, type=int)
+ parser.add_argument('--steps_per_print', default=1, type=int)
+ parser.add_argument('--use_ceph_checkpoint', action='store_true',
+ help="whether use ceph to save and load checkpoint, may be some bug now")
+ parser.set_defaults(use_ceph_checkpoint=False)
+ parser.add_argument('--ceph_checkpoint_prefix', default='', type=str,
+ help='prefix for checkpoint in ceph')
+ parser.add_argument('--ckpt_path_split', default='/exp/', type=str,
+ help='string for splitting the ckpt_path')
+
+ # Model parameters
+ parser.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',
+ help='Name of model to train')
+ parser.add_argument('--tubelet_size', type=int, default=2)
+ parser.add_argument('--input_size', default=224, type=int,
+ help='videos input size')
+ parser.add_argument('--layer_scale_init_value', default=1e-5, type=float,
+ help="0.1 for base, 1e-5 for large. set 0 to disable LayerScale")
+ parser.add_argument('--layerscale_no_force_fp32', action='store_true',
+ help="Not force fp32 for LayerScale")
+ parser.set_defaults(layerscale_no_force_fp32=False)
+ parser.add_argument('--sep_pos_embed', action='store_true',
+ help="whether use seperable position embedding")
+ parser.add_argument('--center_init', action='store_true',
+ help="center initlization for patch embedding")
+ parser.add_argument('--orig_t_size', type=int, default=8)
+
+ parser.add_argument('--fc_drop_rate', type=float, default=0.0, metavar='PCT',
+ help='Dropout rate (default: 0.)')
+ parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+ help='Dropout rate (default: 0.)')
+ parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
+ help='Attention dropout rate (default: 0.)')
+ parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
+ help='Drop path rate (default: 0.1)')
+ parser.add_argument('--head_drop_path', type=float, default=0.0, metavar='PCT',
+ help='Head Drop path rate (default: 0.0)')
+
+ parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
+ parser.add_argument('--model_ema', action='store_true', default=False)
+ parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
+ parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
+ parser.add_argument('--merge_method', type=str, default='proj', help='merge mthod for features')
+ parser.add_argument('--merge_norm', type=str, default='kaiming_BN', help='merge Norm for features')
+
+ # Optimizer parameters
+ parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+ help='Optimizer (default: "adamw"')
+ parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
+ help='Optimizer Epsilon (default: 1e-8)')
+ parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
+ help='Optimizer Betas (default: None, use opt default)')
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+ help='Clip gradient norm (default: None, no clipping)')
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+ help='SGD momentum (default: 0.9)')
+ parser.add_argument('--weight_decay', type=float, default=0.05,
+ help='weight decay (default: 0.05)')
+ parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
+ weight decay. We use a cosine schedule for WD and using a larger decay by
+ the end of training improves performance for ViTs.""")
+
+ parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
+ help='learning rate (default: 1e-3)')
+ parser.add_argument('--layer_decay', type=float, default=0.75)
+
+ parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
+ help='warmup learning rate (default: 1e-6)')
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+ help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
+
+ parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
+ help='epochs to warmup LR, if scheduler supports')
+ parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
+ help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
+ parser.add_argument('--open_clip_projector', action='store_true',
+ help="whether open clip projector for training")
+ parser.set_defaults(open_clip_projector=False)
+ parser.add_argument('--open_block_num', type=int, default=0,
+ help="whether open the last few blocks")
+
+ # Augmentation parameters
+ parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
+ help='Color jitter factor (default: 0.4)')
+ parser.add_argument('--num_sample', type=int, default=2,
+ help='Repeated_aug (default: 2)')
+ parser.add_argument('--aa', type=str, default='rand-m7-n4-mstd0.5-inc1', metavar='NAME',
+ help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m7-n4-mstd0.5-inc1)'),
+ parser.add_argument('--smoothing', type=float, default=0.1,
+ help='Label smoothing (default: 0.1)')
+ parser.add_argument('--train_interpolation', type=str, default='bicubic',
+ help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+
+ # Evaluation parameters
+ parser.add_argument('--crop_pct', type=float, default=None)
+ parser.add_argument('--short_side_size', type=int, default=224)
+ parser.add_argument('--test_num_segment', type=int, default=5)
+ parser.add_argument('--test_num_crop', type=int, default=3)
+
+ # Random Erase params
+ parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+ help='Random erase prob (default: 0.25)')
+ parser.add_argument('--remode', type=str, default='pixel',
+ help='Random erase mode (default: "pixel")')
+ parser.add_argument('--recount', type=int, default=1,
+ help='Random erase count (default: 1)')
+ parser.add_argument('--resplit', action='store_true', default=False,
+ help='Do not random erase first (clean) augmentation split')
+
+ # Mixup params
+ parser.add_argument('--mixup', type=float, default=0.8,
+ help='mixup alpha, mixup enabled if > 0.')
+ parser.add_argument('--cutmix', type=float, default=1.0,
+ help='cutmix alpha, cutmix enabled if > 0.')
+ parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
+ help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+ parser.add_argument('--mixup_prob', type=float, default=1.0,
+ help='Probability of performing mixup or cutmix when either/both is enabled')
+ parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
+ help='Probability of switching to cutmix when both mixup and cutmix enabled')
+ parser.add_argument('--mixup_mode', type=str, default='batch',
+ help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+ # Finetuning params
+ parser.add_argument('--finetune', default='', help='finetune from checkpoint')
+ parser.add_argument('--finetune_extra', default='', help='finetune from extra checkpoint')
+ parser.add_argument('--delete_head', action='store_true', help='whether delete head')
+ parser.add_argument('--model_key', default='model|module', type=str)
+ parser.add_argument('--model_prefix', default='', type=str)
+ parser.add_argument('--init_scale', default=0.001, type=float)
+ parser.add_argument('--use_checkpoint', action='store_true')
+ parser.set_defaults(use_checkpoint=False)
+ parser.add_argument('--checkpoint_num', default=0, type=int,
+ help='number of layers for using checkpoint')
+ parser.add_argument('--use_mean_pooling', action='store_true')
+ parser.set_defaults(use_mean_pooling=True)
+ parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
+
+ # Dataset parameters
+ parser.add_argument('--prefix', default='', type=str, help='prefix for data')
+ parser.add_argument('--split', default=' ', type=str, help='split for metadata')
+ parser.add_argument('--filename_tmpl', default='img_{:05}.jpg', type=str, help='file template')
+ parser.add_argument('--data_path', default='you_data_path', type=str,
+ help='dataset path')
+ parser.add_argument('--eval_data_path', default=None, type=str,
+ help='dataset path for evaluation')
+ parser.add_argument('--nb_classes', default=400, type=int,
+ help='number of the classification types')
+ parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
+ parser.add_argument('--use_decord', action='store_true',
+ help='whether use decord to load video, otherwise load image')
+ parser.add_argument('--no_use_decord', action='store_false', dest='use_decord')
+ parser.set_defaults(use_decord=True)
+ parser.add_argument('--num_segments', type=int, default=1)
+ parser.add_argument('--num_frames', type=int, default=16)
+ parser.add_argument('--sampling_rate', type=int, default=4)
+ parser.add_argument('--data_set', default='Kinetics', choices=[
+ 'Kinetics', 'Kinetics_sparse',
+ 'SSV2', 'UCF101', 'HMDB51', 'image_folder',
+ 'mitv1_sparse',
+ 'ANet', 'HACS', 'ANet_interval', 'HACS_interval',
+ ], type=str, help='dataset')
+ parser.add_argument('--output_dir', default='',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--log_dir', default=None,
+ help='path where to tensorboard log')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--seed', default=0, type=int)
+ parser.add_argument('--resume', default='',
+ help='resume from checkpoint')
+ parser.add_argument('--auto_resume', action='store_true')
+ parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
+ parser.set_defaults(auto_resume=True)
+
+ parser.add_argument('--save_ckpt', action='store_true')
+ parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
+ parser.set_defaults(save_ckpt=True)
+
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--test_best', action='store_true',
+ help='Whether test the best model')
+ parser.add_argument('--eval', action='store_true',
+ help='Perform evaluation only')
+ parser.add_argument('--dist_eval', action='store_true', default=False,
+ help='Enabling distributed evaluation')
+ parser.add_argument('--num_workers', default=10, type=int)
+ parser.add_argument('--pin_mem', action='store_true',
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+ parser.set_defaults(pin_mem=True)
+
+ # distributed training parameters
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', action='store_true')
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+
+ parser.add_argument('--enable_deepspeed', action='store_true', default=False)
+ parser.add_argument('--bf16', default=False, action='store_true')
+ parser.add_argument('--zero_stage', default=0, type=int,
+ help='ZeRO optimizer stage (default: 0)')
+
+ known_args, _ = parser.parse_known_args()
+
+ if known_args.enable_deepspeed:
+ try:
+ import deepspeed
+ from deepspeed import DeepSpeedConfig
+ parser = deepspeed.add_config_arguments(parser)
+ ds_init = deepspeed.initialize
+ except:
+ print("Please 'pip install deepspeed'")
+ exit(0)
+ else:
+ ds_init = None
+
+ return parser.parse_args(), ds_init
+
+
+def main(args, ds_init):
+ utils.init_distributed_mode(args)
+
+ if ds_init is not None:
+ utils.create_internvideo2_ds_config(args)
+
+ print(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ # random.seed(seed)
+
+ cudnn.benchmark = True
+
+ dataset_train, args.nb_classes = build_dataset(is_train=True, test_mode=False, args=args)
+ if args.disable_eval_during_finetuning:
+ dataset_val = None
+ else:
+ dataset_val, _ = build_dataset(is_train=False, test_mode=False, args=args)
+ dataset_test, _ = build_dataset(is_train=False, test_mode=True, args=args)
+
+
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ print("Sampler_train = %s" % str(sampler_train))
+ if args.dist_eval:
+ if len(dataset_val) % num_tasks != 0:
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
+ 'equal num of samples per-process.')
+ sampler_val = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
+ sampler_test = torch.utils.data.DistributedSampler(
+ dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False)
+ else:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ if global_rank == 0 and args.log_dir is not None:
+ os.makedirs(args.log_dir, exist_ok=True)
+ log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
+ else:
+ log_writer = None
+
+ if args.num_sample > 1:
+ collate_func = partial(multiple_samples_collate, fold=False)
+ else:
+ collate_func = None
+
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=True,
+ collate_fn=collate_func,
+ persistent_workers=True
+ )
+
+ if dataset_val is not None:
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=args.test_batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False,
+ persistent_workers=True
+ )
+ else:
+ data_loader_val = None
+
+ if dataset_test is not None:
+ data_loader_test = torch.utils.data.DataLoader(
+ dataset_test, sampler=sampler_test,
+ batch_size=args.test_batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False,
+ persistent_workers=True
+ )
+ else:
+ data_loader_test = None
+
+ mixup_fn = None
+ mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+ if mixup_active:
+ print("Mixup is activated!")
+ mixup_fn = Mixup(
+ mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+ prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+ label_smoothing=args.smoothing, num_classes=args.nb_classes)
+
+ if 'cat' in args.model:
+ model = create_model(
+ args.model,
+ pretrained=False,
+ num_classes=args.nb_classes,
+ num_frames=args.num_frames * args.num_segments,
+ tubelet_size=args.tubelet_size,
+ sep_pos_embed=args.sep_pos_embed,
+ fc_drop_rate=args.fc_drop_rate,
+ drop_path_rate=args.drop_path,
+ head_drop_path_rate=args.head_drop_path,
+ use_checkpoint=args.use_checkpoint,
+ checkpoint_num=args.checkpoint_num,
+ init_scale=args.init_scale,
+ init_values=args.layer_scale_init_value,
+ layerscale_no_force_fp32=args.layerscale_no_force_fp32,
+ merge_method=args.merge_method,
+ merge_norm=args.merge_norm,
+ )
+ else:
+ model = create_model(
+ args.model,
+ pretrained=False,
+ num_classes=args.nb_classes,
+ num_frames=args.num_frames * args.num_segments,
+ tubelet_size=args.tubelet_size,
+ sep_pos_embed=args.sep_pos_embed,
+ fc_drop_rate=args.fc_drop_rate,
+ drop_path_rate=args.drop_path,
+ head_drop_path_rate=args.head_drop_path,
+ use_checkpoint=args.use_checkpoint,
+ checkpoint_num=args.checkpoint_num,
+ init_scale=args.init_scale,
+ init_values=args.layer_scale_init_value,
+ layerscale_no_force_fp32=args.layerscale_no_force_fp32,
+ )
+
+ patch_size = model.patch_embed.patch_size
+ print("Patch size = %s" % str(patch_size))
+ args.window_size = (args.num_frames // args.tubelet_size, args.input_size // patch_size[0], args.input_size // patch_size[1])
+ args.patch_size = patch_size
+
+ if args.finetune:
+ if args.finetune.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.finetune, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(args.finetune, map_location='cpu')
+
+ print("Load ckpt from %s" % args.finetune)
+ checkpoint_model = None
+ for model_key in args.model_key.split('|'):
+ if model_key in checkpoint:
+ checkpoint_model = checkpoint[model_key]
+ print("Load state_dict by model_key = %s" % model_key)
+ break
+ if checkpoint_model is None:
+ checkpoint_model = checkpoint
+
+ if 'head.weight' in checkpoint_model.keys():
+ if args.delete_head:
+ print("Removing head from pretrained checkpoint")
+ del checkpoint_model['head.weight']
+ del checkpoint_model['head.bias']
+ elif checkpoint_model['head.weight'].shape[0] == 710:
+ if args.nb_classes == 400:
+ checkpoint_model['head.weight'] = checkpoint_model['head.weight'][:args.nb_classes]
+ checkpoint_model['head.bias'] = checkpoint_model['head.bias'][:args.nb_classes]
+ elif args.nb_classes in [600, 700]:
+ # download from https://drive.google.com/drive/folders/17cJd2qopv-pEG8NSghPFjZo1UUZ6NLVm
+ map_path = f'./k710/label_mixto{args.nb_classes}.json'
+ print(f'Load label map from {map_path}')
+ with open(map_path) as f:
+ label_map = json.load(f)
+ checkpoint_model['head.weight'] = checkpoint_model['head.weight'][label_map]
+ checkpoint_model['head.bias'] = checkpoint_model['head.bias'][label_map]
+
+ all_keys = list(checkpoint_model.keys())
+ new_dict = OrderedDict()
+ for key in all_keys:
+ if key.startswith('backbone.'):
+ new_dict[key[9:]] = checkpoint_model[key]
+ elif key.startswith('encoder.'):
+ new_dict[key[8:]] = checkpoint_model[key]
+ else:
+ new_dict[key] = checkpoint_model[key]
+ checkpoint_model = new_dict
+
+ if args.finetune_extra:
+ extra_checkpoint = torch.load(args.finetune_extra, map_location='cpu')
+ print("Load extra ckpt from %s" % args.finetune_extra)
+ extra_checkpoint_model = None
+ for model_key in args.model_key.split('|'):
+ if model_key in extra_checkpoint:
+ extra_checkpoint_model = extra_checkpoint[model_key]
+ print("Load state_dict by model_key = %s" % model_key)
+ break
+ for k, v in extra_checkpoint_model.items():
+ new_k = k
+ if k.startswith('vision_encoder.'):
+ new_k = k.replace('vision_encoder.', '')
+ else:
+ print(f"Ignore keys: {k}")
+ continue
+ checkpoint_model[new_k] = v
+
+ # interpolate position embedding
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
+ num_patches = model.patch_embed.num_patches #
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
+
+ # we use 8 frames for pretraining
+ orig_t_size = args.orig_t_size
+ new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
+ # height (== width) for the checkpoint position embedding
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int((num_patches // (new_t_size))** 0.5)
+
+ # class_token and dist_token are kept unchanged
+ if orig_t_size != new_t_size:
+ print(f"Temporal interpolate from {orig_t_size} to {new_t_size}")
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
+ pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
+ pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
+ pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+ pos_embed_checkpoint = new_pos_embed
+
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+
+ elif 'pos_embed_spatial' in checkpoint_model and 'pos_embed_temporal' in checkpoint_model:
+ pos_embed_spatial_checkpoint = checkpoint_model['pos_embed_spatial']
+ pos_embed_temporal_checkpoint = checkpoint_model['pos_embed_temporal']
+
+ embedding_size = pos_embed_spatial_checkpoint.shape[-1] # channel dim
+ num_patches = model.patch_embed.num_patches #
+
+ orig_t_size = pos_embed_temporal_checkpoint.shape[-2]
+ new_t_size = args.num_frames // model.patch_embed.tubelet_size
+
+ # height (== width) for the checkpoint position embedding
+ orig_size = int(pos_embed_spatial_checkpoint.shape[-2] ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int((num_patches // new_t_size) ** 0.5)
+
+ if orig_t_size != new_t_size:
+ print(f"Temporal interpolate from {orig_t_size} to {new_t_size}")
+ tmp_pos_embed = pos_embed_temporal_checkpoint.view(1, orig_t_size, -1, embedding_size)
+ tmp_pos_embed = tmp_pos_embed.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+ tmp_pos_embed = torch.nn.functional.interpolate(tmp_pos_embed, size=new_t_size, mode='linear')
+ tmp_pos_embed = tmp_pos_embed.view(1, -1, embedding_size, new_t_size)
+ tmp_pos_embed = tmp_pos_embed.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+ checkpoint_model['pos_embed_temporal'] = tmp_pos_embed
+
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ pos_tokens = pos_embed_spatial_checkpoint
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
+ pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
+ checkpoint_model['pos_embed_spatial'] = pos_tokens
+
+ utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
+
+ model.to(device)
+
+ print("Freeze backbone for linear probing")
+ if '6B' in args.model:
+ depth = 48
+ else:
+ depth = 40 # ViT-g
+ block_num_list = [(depth - i - 1) for i in range(args.open_block_num)]
+ for name, p in model.named_parameters():
+ if name.startswith('patch_embed') or name.startswith('pos_embed') or name.startswith('cls_token'):
+ print(f"Freeze {name}")
+ p.requires_grad = False
+ elif name.startswith('blocks'):
+ flag = True
+ for num in block_num_list:
+ if name.startswith(f'blocks.{num}'):
+ flag = False
+ break
+ if flag:
+ print(f"Freeze {name}")
+ p.requires_grad = False
+ else:
+ print(f"Unfreeze {name}")
+ elif name.startswith('clip_projector') and not args.open_clip_projector:
+ print(f"Freeze {name}")
+ p.requires_grad = False
+ else:
+ print(f"Unfreeze {name}")
+
+ model_ema = None
+ if args.model_ema:
+ model_ema = ModelEma(
+ model,
+ decay=args.model_ema_decay,
+ device='cpu' if args.model_ema_force_cpu else '',
+ resume='')
+ print("Using EMA with decay = %.8f" % args.model_ema_decay)
+
+ model_without_ddp = model
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ print("Model = %s" % str(model_without_ddp))
+ print('number of params:', n_parameters)
+
+ total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
+ num_training_steps_per_epoch = len(dataset_train) // total_batch_size
+ args.lr = args.lr * total_batch_size * args.num_sample / 256
+ args.min_lr = args.min_lr * total_batch_size * args.num_sample / 256
+ args.warmup_lr = args.warmup_lr * total_batch_size * args.num_sample / 256
+ print("LR = %.8f" % args.lr)
+ print("Batch size = %d" % total_batch_size)
+ print("Repeated sample = %d" % args.num_sample)
+ print("Update frequent = %d" % args.update_freq)
+ print("Number of training examples = %d" % len(dataset_train))
+ print("Number of training training per epoch = %d" % num_training_steps_per_epoch)
+
+ num_layers = model_without_ddp.get_num_layers()
+ if args.layer_decay < 1.0:
+ assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
+ else:
+ assigner = None
+
+ if assigner is not None:
+ print("Assigned values = %s" % str(assigner.values))
+
+ skip_weight_decay_list = model.no_weight_decay()
+ print("Skip weight decay list: ", skip_weight_decay_list)
+
+ if args.enable_deepspeed:
+ loss_scaler = None
+ optimizer_params = get_parameter_groups(
+ model, args.weight_decay, skip_weight_decay_list,
+ assigner.get_layer_id if assigner is not None else None,
+ assigner.get_scale if assigner is not None else None)
+ model, optimizer, _, _ = ds_init(
+ args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,
+ )
+
+ print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
+ assert model.gradient_accumulation_steps() == args.update_freq
+ else:
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+ model_without_ddp = model.module
+
+ optimizer = create_optimizer(
+ args, model_without_ddp, skip_list=skip_weight_decay_list,
+ get_num_layer=assigner.get_layer_id if assigner is not None else None,
+ get_layer_scale=assigner.get_scale if assigner is not None else None)
+ loss_scaler = NativeScaler()
+
+ print("Use step level LR scheduler!")
+ lr_schedule_values = utils.cosine_scheduler(
+ args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
+ warmup_epochs=args.warmup_epochs, start_warmup_value=args.warmup_lr, warmup_steps=args.warmup_steps,
+ )
+ if args.weight_decay_end is None:
+ args.weight_decay_end = args.weight_decay
+ wd_schedule_values = utils.cosine_scheduler(
+ args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
+ print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
+
+ if mixup_fn is not None:
+ # smoothing is handled with mixup label transform
+ criterion = SoftTargetCrossEntropy()
+ elif args.smoothing > 0.:
+ criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+ else:
+ criterion = torch.nn.CrossEntropyLoss()
+
+ print("criterion = %s" % str(criterion))
+ ceph_args = {
+ 'use_ceph_checkpoint': args.use_ceph_checkpoint,
+ 'ceph_checkpoint_prefix': args.ceph_checkpoint_prefix,
+ 'ckpt_path_split': args.ckpt_path_split,
+ 'local_rank': args.gpu,
+ }
+ if ceph_args['use_ceph_checkpoint']:
+ print("Will automatically upload model on ceph")
+ assert ceph_args['ceph_checkpoint_prefix'] != '', "Should set prefix for ceph checkpoint!"
+
+ utils.auto_load_model(
+ args=args, model=model, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema,
+ ceph_args=ceph_args,
+ )
+
+ print(f"Use bf16 {args.bf16}")
+
+ if args.eval:
+ preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
+ test_stats = final_test(data_loader_test, model, device, preds_file, ds=args.enable_deepspeed, bf16=args.bf16)
+ torch.distributed.barrier()
+ if global_rank == 0:
+ print("Start merging results...")
+ final_top1 ,final_top5 = merge(args.output_dir, num_tasks)
+ print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
+ log_stats = {'Final top-1': final_top1,
+ 'Final Top-5': final_top5}
+ if args.output_dir and utils.is_main_process():
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ exit(0)
+
+
+ print(f"Start training for {args.epochs} epochs")
+ start_time = time.time()
+ max_accuracy = 0.0
+ for epoch in range(args.start_epoch, args.epochs):
+ if args.distributed:
+ data_loader_train.sampler.set_epoch(epoch)
+ if log_writer is not None:
+ log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
+ train_stats = train_one_epoch(
+ model, criterion, data_loader_train, optimizer,
+ device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
+ log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
+ lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
+ num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
+ bf16=args.bf16
+ )
+ if args.output_dir and args.save_ckpt:
+ # if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
+ # utils.save_model(
+ # args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ # loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema,
+ # ceph_args=ceph_args,
+ # )
+ utils.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, model_name='latest', model_ema=model_ema,
+ ceph_args=ceph_args,
+ )
+ if data_loader_val is not None:
+ test_stats = validation_one_epoch(data_loader_val, model, device, ds=args.enable_deepspeed, bf16=args.bf16)
+ timestep = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+ print(f"[{timestep}] Accuracy of the network on the {len(dataset_val)} val videos: {test_stats['acc1']:.1f}%")
+ if max_accuracy < test_stats["acc1"]:
+ max_accuracy = test_stats["acc1"]
+ if args.output_dir and args.save_ckpt:
+ utils.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, model_name='best', model_ema=model_ema,
+ ceph_args=ceph_args,
+ )
+
+ print(f'Max accuracy: {max_accuracy:.2f}%')
+ if log_writer is not None:
+ log_writer.update(val_acc1=test_stats['acc1'], head="perf", step=epoch)
+ log_writer.update(val_acc5=test_stats['acc5'], head="perf", step=epoch)
+ log_writer.update(val_loss=test_stats['loss'], head="perf", step=epoch)
+
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'val_{k}': v for k, v in test_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+ else:
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+ if args.output_dir and utils.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
+ if args.test_best:
+ print("Auto testing the best model")
+ args.eval = True
+ utils.auto_load_model(
+ args=args, model=model, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema,
+ ceph_args=ceph_args,
+ )
+ test_stats = final_test(data_loader_test, model, device, preds_file, ds=args.enable_deepspeed, bf16=args.bf16)
+ torch.distributed.barrier()
+ if global_rank == 0:
+ print("Start merging results...")
+ final_top1 ,final_top5 = merge(args.output_dir, num_tasks)
+ print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
+ log_stats = {'Final top-1': final_top1,
+ 'Final Top-5': final_top5}
+ if args.output_dir and utils.is_main_process():
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ opts, ds_init = get_args()
+ if opts.output_dir:
+ Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
+ main(opts, ds_init)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/run_pretraining.py b/third_party/InternVideo/InternVideo2/single_modality/run_pretraining.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c88213e5f6c9652e71d82db80b242eaa5f06dd7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/run_pretraining.py
@@ -0,0 +1,451 @@
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.backends.cudnn as cudnn
+import json
+import os
+from functools import partial
+
+from pathlib import Path
+from timm.models import create_model
+from optim_factory import (
+ create_optimizer,
+ get_parameter_groups,
+)
+from datasets import build_multi_pretraining_dataset
+from single_modality.engines.engine_for_pretraining import train_one_epoch
+from utils import NativeScalerWithGradNormCount as NativeScaler
+from utils import multiple_pretrain_samples_collate
+import utils
+from models import *
+
+
+def get_args():
+ parser = argparse.ArgumentParser('VideoMAE pre-training script', add_help=False)
+ parser.add_argument('--batch_size', default=64, type=int)
+ parser.add_argument('--epochs', default=800, type=int)
+ parser.add_argument('--update_freq', default=1, type=int)
+ parser.add_argument('--save_ckpt_freq', default=50, type=int)
+ parser.add_argument('--steps_per_print', default=1, type=int)
+ parser.add_argument('--use_ceph_checkpoint', action='store_true',
+ help="whether use ceph to save and load checkpoint, may be some bug now")
+ parser.set_defaults(use_ceph_checkpoint=False)
+ parser.add_argument('--ceph_checkpoint_prefix', default='', type=str,
+ help='prefix for checkpoint in ceph')
+ parser.add_argument('--ckpt_path_split', default='/exp/', type=str,
+ help='string for splitting the ckpt_path')
+
+ # Model parameters
+ parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL',
+ help='Name of model to train')
+ parser.add_argument('--decoder_depth', default=4, type=int,
+ help='depth of decoder')
+ parser.add_argument('--mask_type', default='tube', choices=['random', 'tube', 'attention'],
+ type=str, help='masked strategy of video tokens/patches')
+ parser.add_argument('--mask_ratio', default=0.75, type=float,
+ help='ratio of the visual tokens/patches need be masked')
+ parser.add_argument('--input_size', default=224, type=int,
+ help='videos input size for backbone')
+ parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
+ help='Drop path rate (default: 0.0)')
+ parser.add_argument('--normlize_target', default=True, type=bool,
+ help='normalized the target patch pixels')
+ parser.add_argument('--tubelet_size', default=1, type=int,
+ help='temporal tube size for the patch embedding')
+ parser.add_argument('--layer_scale_init_value', default=1e-5, type=float,
+ help="0.1 for base, 1e-5 for large. set 0 to disable LayerScale")
+ parser.add_argument('--layerscale_no_force_fp32', action='store_true',
+ help="Not force fp32 for LayerScale")
+ parser.set_defaults(layerscale_no_force_fp32=False)
+ parser.add_argument('--sep_pos_embed', action='store_true',
+ help="whether use seperable position embedding")
+ parser.set_defaults(sep_pos_embed=False)
+
+ # CLIP decpder parameters
+ parser.add_argument('--clip_teacher', default='internvl_clip_6b', type=str,
+ help='Name of CLIP teacher')
+ parser.add_argument('--clip_input_resolution', default=224, type=int,
+ help='input resolution of CLIP decoder')
+ parser.add_argument('--clip_teacher_embed_dim', default=3200, type=int,
+ help='output dimension of CLIP decoder in the intermediate layers')
+ parser.add_argument('--clip_teacher_final_dim', default=768, type=int,
+ help='output dimension of CLIP decoder in the final layer, 0 means w/o alignment')
+ parser.add_argument('--clip_loss_ratio', default=[1, 1], type=float, nargs='+', metavar='BETA',
+ help='Loss ratio for middle features and final features (default: [1, 0.5])')
+ parser.add_argument('--clip_norm_type', default='l2', type=str,
+ help='type of feature normalization')
+ parser.add_argument('--clip_return_attn', action='store_true',
+ help="whether return CLIP attention")
+ parser.set_defaults(clip_return_attn=False)
+ parser.add_argument('--clip_return_layer', default=1, type=int,
+ help='number of CLIP return layers')
+ parser.add_argument('--clip_teacher_return_interval', default=1, type=float,
+ help='interval of CLIP teacher return layers')
+ parser.add_argument('--clip_student_return_interval', default=1, type=float,
+ help='interval of CLIP student return layers')
+
+ # MAE decoder parameters
+ parser.add_argument('--mae_teacher', default='clip_b16', type=str,
+ help='Name of MAE teacher')
+ parser.add_argument('--mae_input_resolution', default=224, type=int,
+ help='input resolution of MAE decoder')
+ parser.add_argument('--mae_tubelet_size', default=2, type=int,
+ help='tubelet size of MAE decoder')
+ parser.add_argument('--mae_teacher_embed_dim', default=1408, type=int,
+ help='output dimension of MAE decoder')
+ parser.add_argument('--mae_norm_type', default='l2', type=str,
+ help='type of feature normalization')
+ parser.add_argument('--mae_loss_ratio', default=1., type=float,
+ help='ratio for MAE loss')
+ parser.add_argument('--mae_return_layer', default=1, type=int,
+ help='number of MAE return layers')
+ parser.add_argument('--mae_teacher_return_interval', default=1, type=float,
+ help='interval of MAE teacher return layers')
+ parser.add_argument('--mae_student_return_interval', default=1, type=float,
+ help='interval of MAE student return layers')
+
+ # Optimizer parameters
+ parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+ help='Optimizer (default: "adamw"')
+ parser.add_argument('--opt_eps', default=1e-6, type=float, metavar='EPSILON',
+ help='Optimizer Epsilon (default: 1e-6)')
+ parser.add_argument('--opt_betas', default=[0.9, 0.98], type=float, nargs='+', metavar='BETA',
+ help='Optimizer Betas (default: [0.9, 0.98])')
+ parser.add_argument('--clip_grad', type=float, default=3.0, metavar='NORM',
+ help='Clip gradient norm (default: 3.0)')
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+ help='SGD momentum (default: 0.9)')
+ parser.add_argument('--weight_decay', type=float, default=0.05,
+ help='weight decay (default: 0.05)')
+ parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
+ weight decay. We use a cosine schedule for WD.
+ (Set the same value with args.weight_decay to keep weight decay no change)""")
+
+ parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR',
+ help='learning rate (default: 1.5e-4)')
+ parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
+ help='warmup learning rate (default: 1e-6)')
+ parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
+ help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+
+ parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
+ help='epochs to warmup LR, if scheduler supports')
+ parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
+ help='epochs to warmup LR, if scheduler supports')
+ parser.add_argument('--use_checkpoint', action='store_true')
+ parser.set_defaults(use_checkpoint=False)
+ parser.add_argument('--checkpoint_num', type=int, default=0)
+
+ # Augmentation parameters
+ parser.add_argument('--num_sample', type=int, default=1, help='Repeated_aug (default: 1)')
+ parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT',
+ help='Color jitter factor (default: 0.0)')
+ parser.add_argument('--train_interpolation', type=str, default='bicubic',
+ help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+ parser.add_argument('--flip', default=False, action='store_true',
+ help='whether flip the video in pretraining')
+
+ # Dataset parameters
+ parser.add_argument('--prefix', default='', type=str, help='prefix for data')
+ parser.add_argument('--split', default=' ', type=str, help='split for metadata')
+ parser.add_argument('--data_path', default='you_data_path', type=str,
+ help='dataset path')
+ parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
+ parser.add_argument('--use_decord', action='store_true',
+ help='whether use decord to load video, otherwise load image')
+ parser.add_argument('--no_use_decord', action='store_false', dest='use_decord')
+ parser.set_defaults(use_decord=True)
+ parser.add_argument('--num_segments', type=int, default=1)
+ parser.add_argument('--num_frames', type=int, default=16)
+ parser.add_argument('--sampling_rate', type=int, default=4)
+ parser.add_argument('--output_dir', default='',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--log_dir', default=None,
+ help='path where to tensorboard log')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--seed', default=0, type=int)
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
+ parser.add_argument('--auto_resume', action='store_true')
+ parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
+ parser.set_defaults(auto_resume=True)
+
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--test_best', action='store_true',
+ help='Whether test the best model')
+ parser.add_argument('--num_workers', default=10, type=int)
+ parser.add_argument('--pin_mem', action='store_true',
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
+ help='')
+ parser.set_defaults(pin_mem=True)
+
+ # distributed training parameters
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', action='store_true')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+
+ parser.add_argument('--enable_deepspeed',
+ action='store_true', default=False)
+ parser.add_argument('--bf16', default=False, action='store_true')
+ parser.add_argument('--zero_stage', default=0, type=int,
+ help='ZeRO optimizer stage (default: 0)')
+
+ known_args, _ = parser.parse_known_args()
+
+ if known_args.enable_deepspeed:
+ try:
+ import deepspeed
+ parser = deepspeed.add_config_arguments(parser)
+ ds_init = deepspeed.initialize
+ except:
+ print("Please install DeepSpeed")
+ exit(0)
+ else:
+ ds_init = None
+
+ return parser.parse_args(), ds_init
+
+
+def get_model(args):
+ print(f"Creating model: {args.model}")
+ model = create_model(
+ args.model,
+ pretrained=False,
+ drop_path_rate=args.drop_path,
+ num_frames=args.num_frames//(args.mae_tubelet_size//args.tubelet_size),
+ tubelet_size=args.tubelet_size,
+ sep_pos_embed=args.sep_pos_embed,
+ use_checkpoint=args.use_checkpoint,
+ checkpoint_num=args.checkpoint_num,
+ init_values=args.layer_scale_init_value,
+ layerscale_no_force_fp32=args.layerscale_no_force_fp32,
+ clip_teacher_embed_dim=args.clip_teacher_embed_dim,
+ clip_teacher_final_dim=args.clip_teacher_final_dim,
+ clip_norm_type=args.clip_norm_type,
+ clip_return_layer=args.clip_return_layer,
+ clip_student_return_interval=args.clip_student_return_interval,
+ mae_teacher_embed_dim=args.mae_teacher_embed_dim,
+ mae_norm_type=args.mae_norm_type,
+ mae_return_layer=args.mae_return_layer,
+ mae_student_return_interval=args.mae_student_return_interval,
+ )
+ return model
+
+
+def main(args, ds_init):
+ utils.init_distributed_mode(args)
+
+ if ds_init is not None:
+ utils.create_internvideo2_ds_config(args)
+
+ print(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = True
+
+ model = get_model(args)
+ patch_size = model.patch_embed.patch_size
+ print("Patch size = %s" % str(patch_size))
+ print("Tubelet size = %s" % str(args.tubelet_size))
+ args.window_size = (args.num_frames // args.tubelet_size, args.input_size // patch_size[0], args.input_size // patch_size[1])
+ args.patch_size = patch_size
+
+ # CLIP teacher model
+ print(f'CLIP Teacher model: {args.clip_teacher}')
+ clip_teacher_model = eval(args.clip_teacher)(
+ img_size=args.clip_input_resolution,
+ clip_norm_type=args.clip_norm_type,
+ return_attn=args.clip_return_attn,
+ clip_return_layer=args.clip_return_layer,
+ clip_return_interval=args.clip_teacher_return_interval
+ )
+
+ # MAE teacher model
+ print(f'MAE Teacher model: {args.mae_teacher}')
+ mae_teacher_model = eval(args.mae_teacher)(
+ img_size=args.mae_input_resolution,
+ tubelet_size=args.mae_tubelet_size,
+ mae_norm_type=args.mae_norm_type,
+ mae_return_layer=args.mae_return_layer,
+ mae_return_interval=args.mae_teacher_return_interval
+ )
+
+ # get dataset
+ dataset_train = build_multi_pretraining_dataset(args)
+
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ sampler_rank = global_rank
+ num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks
+
+ sampler_train = torch.utils.data.DistributedSampler(dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True)
+ print("Sampler_train = %s" % str(sampler_train))
+
+ if global_rank == 0 and args.log_dir is not None:
+ os.makedirs(args.log_dir, exist_ok=True)
+ log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
+ else:
+ log_writer = None
+
+ if args.num_sample > 1:
+ collate_func = partial(multiple_pretrain_samples_collate, fold=False)
+ else:
+ collate_func = None
+
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=True,
+ collate_fn=collate_func,
+ worker_init_fn=utils.seed_worker,
+ persistent_workers=True
+ )
+
+ model.to(device)
+ clip_teacher_model.to(device)
+ mae_teacher_model.to(device)
+ model_without_ddp = model
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ print("Model = %s" % str(model_without_ddp))
+ print('number of params: {} M'.format(n_parameters / 1e6))
+
+ total_batch_size = args.batch_size * utils.get_world_size()
+
+ args.lr = args.lr * total_batch_size * args.num_sample / 256
+ args.min_lr = args.min_lr * total_batch_size * args.num_sample / 256
+ args.warmup_lr = args.warmup_lr * total_batch_size * args.num_sample / 256
+ print("LR = %.8f" % args.lr)
+ print("Batch size = %d" % total_batch_size)
+ print("Repeated sample = %d" % args.num_sample)
+ print("Number of training steps = %d" % num_training_steps_per_epoch)
+ print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))
+
+ skip_weight_decay_list = model.no_weight_decay()
+ print("Skip weight decay list: ", skip_weight_decay_list)
+
+ if args.enable_deepspeed:
+ loss_scaler = None
+ optimizer_params = get_parameter_groups(
+ model, args.weight_decay, skip_weight_decay_list
+ )
+ model, optimizer, _, _ = ds_init(
+ args=args, model=model, model_parameters=optimizer_params,
+ dist_init_required=not args.distributed,
+ )
+
+ print("model.gradient_accumulation_steps() = %d" %
+ model.gradient_accumulation_steps())
+ assert model.gradient_accumulation_steps() == args.update_freq
+ else:
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
+ model_without_ddp = model.module
+
+ optimizer = create_optimizer(args, model_without_ddp)
+ loss_scaler = NativeScaler()
+
+ print("Use step level LR & WD scheduler!")
+ lr_schedule_values = utils.cosine_scheduler(
+ args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
+ warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
+ )
+ if args.weight_decay_end is None:
+ args.weight_decay_end = args.weight_decay
+ wd_schedule_values = utils.cosine_scheduler(args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
+ print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
+
+ ceph_args = {
+ 'use_ceph_checkpoint': args.use_ceph_checkpoint,
+ 'ceph_checkpoint_prefix': args.ceph_checkpoint_prefix,
+ 'ckpt_path_split': args.ckpt_path_split,
+ 'local_rank': args.gpu,
+ }
+ if ceph_args['use_ceph_checkpoint']:
+ print("Will automatically upload model on ceph")
+ assert ceph_args['ceph_checkpoint_prefix'] != '', "Should set prefix for ceph checkpoint!"
+
+ utils.auto_load_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler,
+ ceph_args=ceph_args,
+ )
+ torch.cuda.empty_cache()
+ print(f"Start training for {args.epochs} epochs")
+ print(f"Use bf16 {args.bf16}")
+ print(f"Mask ratio: {args.mask_ratio}")
+ print(f"Mask typr: {args.mask_type}")
+ distill_final_features = args.clip_teacher_final_dim > 0
+ print(f"Distill final (AttnPoll) features of teacher: {distill_final_features}")
+ print(f"Loss ratio: {args.clip_loss_ratio}")
+
+ start_time = time.time()
+ for epoch in range(args.start_epoch, args.epochs):
+ if args.distributed:
+ data_loader_train.sampler.set_epoch(epoch)
+ if log_writer is not None:
+ log_writer.set_step(epoch * num_training_steps_per_epoch)
+ train_stats = train_one_epoch(
+ model, data_loader_train,
+ optimizer, device, epoch, loss_scaler,
+ args.clip_grad, log_writer=log_writer,
+ start_steps=epoch * num_training_steps_per_epoch,
+ lr_schedule_values=lr_schedule_values,
+ wd_schedule_values=wd_schedule_values,
+ clip_teacher_model=clip_teacher_model,
+ clip_input_resolution=args.clip_input_resolution,
+ distill_final_features=distill_final_features,
+ clip_loss_ratio=args.clip_loss_ratio,
+ mae_teacher_model=mae_teacher_model,
+ mae_input_resolution=args.mae_input_resolution,
+ mae_loss_ratio=args.mae_loss_ratio,
+ td_ratio=args.mae_tubelet_size//args.tubelet_size,
+ mask_type=args.mask_type,
+ mask_ratio=args.mask_ratio,
+ bf16=args.bf16,
+ )
+ if args.output_dir:
+ if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
+ utils.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch,
+ ceph_args=ceph_args,
+ )
+ utils.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch,
+ model_name='latest', ceph_args=ceph_args,
+ )
+
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch, 'n_parameters': n_parameters}
+
+ if args.output_dir and utils.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ opts, ds_init = get_args()
+ if opts.output_dir:
+ Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
+ main(opts, ds_init)
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k400/1B_ap_k710_ap_k400_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k400/1B_ap_k710_ap_k400_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..36badb6a0c76c91282b955e609e775952433bc07
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k400/1B_ap_k710_ap_k400_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ap_k710_ap_k400_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k400'
+DATA_PATH='your_data_path/k400'
+MODEL_PATH='your_model_path/1B_ap_k710_ap_k400_f16.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 400 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 16 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k400/6B_ap_k710_ap_k400_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k400/6B_ap_k710_ap_k400_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..db11d89d271b2e76a5cd60d9544b12c15d932d35
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k400/6B_ap_k710_ap_k400_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ap_k710_ap_k400_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k400'
+DATA_PATH='your_data_path/k400'
+MODEL_PATH='your_model_path/6B_ap_k710_f16_loadStage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 400 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 16 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k600/1B_ap_k710_ap_k600_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k600/1B_ap_k710_ap_k600_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..878c38390b7ee976efc0da6b5a90723720d8b69b
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k600/1B_ap_k710_ap_k600_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ap_k710_ap_k600_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k600'
+DATA_PATH='your_data_path/k600'
+MODEL_PATH='your_model_path/1B_ap_k710_f16_loadStage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 600 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 16 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k600/6B_ap_k710_ap_k600_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k600/6B_ap_k710_ap_k600_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..85c5d9c8a500596a074893344a8d5c49104f1e7a
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k600/6B_ap_k710_ap_k600_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ap_k710_ap_k600_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k600'
+DATA_PATH='your_data_path/k600'
+MODEL_PATH='your_model_path/6B_ap_k710_f16_loadStage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 600 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 16 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k700/1B_ap_k710_ap_k700_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k700/1B_ap_k710_ap_k700_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2d09de01795eaba375e74899d47e18b90c10d2af
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k700/1B_ap_k710_ap_k700_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ap_k710_ap_k700_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k700'
+DATA_PATH='your_data_path/k700'
+MODEL_PATH='your_model_path/1B_ap_k710_f16_loadStage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 700 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 16 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k700/6B_ap_k710_ap_k700_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k700/6B_ap_k710_ap_k700_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cab4861e3ed4a87d59b7c04bb96c2ed5a306e3cc
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k700/6B_ap_k710_ap_k700_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ap_k710_ap_k700_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k700'
+DATA_PATH='your_data_path/k700'
+MODEL_PATH='your_model_path/6B_ap_k710_f16_loadStage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 700 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 16 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k710/1B_ap_k710_f16_loadStage2.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k710/1B_ap_k710_f16_loadStage2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cbde3a04eb4b29c453d4023a1b7b0f08e766edee
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k710/1B_ap_k710_f16_loadStage2.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ap_k710_f16_loadStage2'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+DATA_PATH='your_data_path/k710'
+MODEL_PATH='your_model_path/1B_pt.pth'
+EXTRA_MODEL_PATH='your_model_path/1B_pt_stage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 710 \
+ --finetune ${MODEL_PATH} \
+ --finetune_extra ${EXTRA_MODEL_PATH} \
+ --orig_t_size 4 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 5 \
+ --tubelet_size 1 \
+ --epochs 25 \
+ --lr 2e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k710/6B_ap_k710_f16_loadStage2.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k710/6B_ap_k710_f16_loadStage2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..41e07ad75587e42488f880c944b9d7c3a8ef49dd
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/k710/6B_ap_k710_f16_loadStage2.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ap_k710_f16_loadStage2'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+DATA_PATH='your_data_path/k710'
+MODEL_PATH='your_model_path/1B_pt.pth'
+EXTRA_MODEL_PATH='your_model_path/6B_pt_stage2.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 710 \
+ --finetune ${MODEL_PATH} \
+ --finetune_vclip ${VCLIP_MODEL_PATH} \
+ --orig_t_size 4 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 32 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 4 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 1e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/mit/1B_ap_k710_ap_k400_ap_mit_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/mit/1B_ap_k710_ap_k400_ap_mit_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e8986bf524fd4ec42139ca4ad88acc5ce6cf9d77
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/mit/1B_ap_k710_ap_k400_ap_mit_f16.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ap_k710_ap_k400_ap_mit_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/mit'
+DATA_PATH='your_data_path/mit'
+MODEL_PATH='your_model_path/1B_ap_k710_f16_loadStage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=4
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'mitv1_sparse' \
+ --nb_classes 339 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 16 \
+ --num_workers 12 \
+ --warmup_epochs 5 \
+ --tubelet_size 1 \
+ --epochs 35 \
+ --lr 2e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/mit/6B_ap_k710_ap_k400_ap_mit_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/mit/6B_ap_k710_ap_k400_ap_mit_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ebc9ad5666f1a7ce8c309636a985e193afb7bc7c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/mit/6B_ap_k710_ap_k400_ap_mit_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ap_k710_ap_k400_ap_mit_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/mit'
+DATA_PATH='your_data_path/mit'
+MODEL_PATH='your_model_path/6B_ap_k710_f16_loadStage2.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'mitv1_sparse' \
+ --nb_classes 339 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 16 \
+ --num_workers 12 \
+ --warmup_epochs 5 \
+ --tubelet_size 1 \
+ --epochs 30 \
+ --lr 1e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/ssv2/1B_ap_ssv2_f16_loadStage2.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/ssv2/1B_ap_ssv2_f16_loadStage2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7a7881e2a02b71b98d4a3ca5efc1ef489fd3d38d
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/ssv2/1B_ap_ssv2_f16_loadStage2.sh
@@ -0,0 +1,65 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ap_ssv2_f16_loadStage2'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ssv2_frame'
+DATA_PATH='your_data_path/ssv2_frame'
+MODEL_PATH='your_model_path/1B_pt.pth'
+EXTRA_MODEL_PATH='your_model_path/1B_pt_stage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_ap_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'SSV2' \
+ --no_use_decord \
+ --nb_classes 174 \
+ --finetune ${MODEL_PATH} \
+ --finetune_extra ${EXTRA_MODEL_PATH} \
+ --orig_t_size 4 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 5 \
+ --tubelet_size 1 \
+ --epochs 25 \
+ --lr 2e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/ssv2/6B_ap_ssv2_f16_loadStage2.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/ssv2/6B_ap_ssv2_f16_loadStage2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..749ebe79baac9c11c9e0b533d88ddfdb38557aa1
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/attentive_probing/ssv2/6B_ap_ssv2_f16_loadStage2.sh
@@ -0,0 +1,65 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ap_ssv2_f16_loadStage2'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ssv2_frame'
+DATA_PATH='your_data_path/ssv2_frame'
+MODEL_PATH='your_model_path/6B_pt.pth'
+EXTRA_MODEL_PATH='your_model_path/6B_pt_stage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_ap_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'SSV2' \
+ --no_use_decord \
+ --nb_classes 174 \
+ --finetune ${MODEL_PATH} \
+ --finetune_extra ${EXTRA_MODEL_PATH} \
+ --orig_t_size 4 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 4 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 1e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/anet/6B_ft_k710_ft_k400_ap_anet_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/anet/6B_ft_k710_ft_k400_ap_anet_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3f31c934ad6b51790ec4f2de929b64d49276590f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/anet/6B_ft_k710_ft_k400_ap_anet_f8.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k400_ap_anet_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/anet'
+DATA_PATH='your_data_path/anet'
+MODEL_PATH='your_model_path/1B_ft_k710_ft_k400_f8.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'ANet' \
+ --nb_classes 200 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 40 \
+ --lr 2e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --head_drop_path 0.5 \
+ --fc_drop_rate 0.5 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/hacs/6B_ft_k710_ft_k400_ap_hacs_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/hacs/6B_ft_k710_ft_k400_ap_hacs_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b430216bf74da8172311c07718aff53748a534c9
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/hacs/6B_ft_k710_ft_k400_ap_hacs_f8.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k400_ap_anet_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/anet'
+DATA_PATH='your_data_path/anet'
+MODEL_PATH='your_model_path/1B_ft_k710_ft_k400_f8.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --open_clip_projector \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'HACS' \
+ --nb_classes 200 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 40 \
+ --lr 2e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --head_drop_path 0.5 \
+ --fc_drop_rate 0.5 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/1B_ft_k710_ft_k400_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/1B_ft_k710_ft_k400_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3d4ed2a66a737950870dff9b48c1057aca36f611
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/1B_ft_k710_ft_k400_f16.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ft_k710_ft_k400_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k400'
+DATA_PATH='your_data_path/k400'
+MODEL_PATH='your_model_path/1B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 400 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 8 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --drop_path 0.3 \
+ --layer_decay 0.9 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/1B_ft_k710_ft_k400_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/1B_ft_k710_ft_k400_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f5c1617781f64b6de94c9c00525ded4babed3d7c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/1B_ft_k710_ft_k400_f8.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ft_k710_ft_k400_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k400'
+DATA_PATH='your_data_path/k400'
+MODEL_PATH='your_model_path/1B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 400 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 8 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --drop_path 0.3 \
+ --layer_decay 0.9 \
+ --use_checkpoint \
+ --checkpoint_num 6 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/6B_ft_k710_ft_k400_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/6B_ft_k710_ft_k400_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0b36e12b5fd310ac3542f3ef82df1874a41e9e42
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/6B_ft_k710_ft_k400_f16.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k400_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k400'
+DATA_PATH='your_data_path/k400'
+MODEL_PATH='your_model_path/6B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 400 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 2 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 1 \
+ --lr 1e-5 \
+ --min_lr 1e-6 \
+ --drop_path 0.35 \
+ --head_drop_path 0.35 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 30 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/6B_ft_k710_ft_k400_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/6B_ft_k710_ft_k400_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c534fde98f16585212599a1386f92e8b4a5fc5e8
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k400/6B_ft_k710_ft_k400_f8.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k400_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k400'
+DATA_PATH='your_data_path/k400'
+MODEL_PATH='your_model_path/6B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 400 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 4 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 1 \
+ --lr 1e-5 \
+ --min_lr 1e-6 \
+ --drop_path 0.35 \
+ --head_drop_path 0.35 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/1B_ft_k710_ft_k600_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/1B_ft_k710_ft_k600_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c83e3b9b8fb53378ad5b6c657a9dbbb2386417b4
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/1B_ft_k710_ft_k600_f16.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ft_k710_ft_k600_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k600'
+DATA_PATH='your_data_path/k600'
+MODEL_PATH='your_model_path/1B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 600 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 8 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --drop_path 0.3 \
+ --layer_decay 0.9 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/1B_ft_k710_ft_k600_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/1B_ft_k710_ft_k600_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bcf034cb90ffb5cdbd9b5a9a2af929d800fab686
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/1B_ft_k710_ft_k600_f8.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ft_k710_ft_k600_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k600'
+DATA_PATH='your_data_path/k600'
+MODEL_PATH='your_model_path/1B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 600 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 8 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --drop_path 0.3 \
+ --layer_decay 0.9 \
+ --use_checkpoint \
+ --checkpoint_num 6 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/6B_ft_k710_ft_k600_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/6B_ft_k710_ft_k600_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fb5d5de404cf18f2c1af87cfe4066ac0f051e382
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/6B_ft_k710_ft_k600_f16.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k600_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k600'
+DATA_PATH='your_data_path/k600'
+MODEL_PATH='your_model_path/6B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 600 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 2 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 1 \
+ --lr 1e-5 \
+ --min_lr 1e-6 \
+ --drop_path 0.35 \
+ --head_drop_path 0.35 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 30 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/6B_ft_k710_ft_k600_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/6B_ft_k710_ft_k600_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f13ddde323bd9c94aa47cb08e77e567961fa716
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k600/6B_ft_k710_ft_k600_f8.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k600_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k600'
+DATA_PATH='your_data_path/k600'
+MODEL_PATH='your_model_path/6B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 600 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 4 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 1 \
+ --lr 1e-5 \
+ --min_lr 1e-6 \
+ --drop_path 0.35 \
+ --head_drop_path 0.35 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/1B_ft_k710_ft_k700_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/1B_ft_k710_ft_k700_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e7719d4f7802d9fb1410ebcf3fc0ec3ae7cb89c7
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/1B_ft_k710_ft_k700_f16.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ft_k710_ft_k700_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k700'
+DATA_PATH='your_data_path/k700'
+MODEL_PATH='your_model_path/1B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 700 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 8 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --drop_path 0.3 \
+ --layer_decay 0.9 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/1B_ft_k710_ft_k700_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/1B_ft_k710_ft_k700_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8f786851e9d132ce52649bd8f65bbe1f483d5855
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/1B_ft_k710_ft_k700_f8.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ft_k710_ft_k700_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k700'
+DATA_PATH='your_data_path/k700'
+MODEL_PATH='your_model_path/1B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 700 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 8 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 3 \
+ --lr 1e-5 \
+ --drop_path 0.3 \
+ --layer_decay 0.9 \
+ --use_checkpoint \
+ --checkpoint_num 6 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/6B_ft_k710_ft_k700_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/6B_ft_k710_ft_k700_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c673f26a08936e63e5e87954c2d9b367ac23f7d5
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/6B_ft_k710_ft_k700_f16.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k700_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k700'
+DATA_PATH='your_data_path/k700'
+MODEL_PATH='your_model_path/6B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 700 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 2 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 1 \
+ --lr 1e-5 \
+ --min_lr 1e-6 \
+ --drop_path 0.35 \
+ --head_drop_path 0.35 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 30 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/6B_ft_k710_ft_k700_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/6B_ft_k710_ft_k700_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8b68247f41aabec14b7a2f9e1ec8d99e31b1ff0d
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k700/6B_ft_k710_ft_k700_f8.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k700_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k700'
+DATA_PATH='your_data_path/k700'
+MODEL_PATH='your_model_path/6B_ft_k710_f8.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 700 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 4 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 1 \
+ --lr 1e-5 \
+ --min_lr 1e-6 \
+ --drop_path 0.35 \
+ --head_drop_path 0.35 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k710/1B_ft_k710_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k710/1B_ft_k710_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..309945bb06555b99b6ab8d605ad5c54f6a3659d5
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k710/1B_ft_k710_f8.sh
@@ -0,0 +1,60 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+DATA_PATH='your_data_path/k710'
+MODEL_PATH='your_model_path/1B_pt.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 710 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 8 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 2 \
+ --tubelet_size 1 \
+ --epochs 9 \
+ --lr 5e-5 \
+ --drop_path 0.3 \
+ --layer_decay 0.9 \
+ --use_checkpoint \
+ --checkpoint_num 6 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k710/6B_ft_k710_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k710/6B_ft_k710_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..02633094a2d1fd873f3d69bcfc766c337200b701
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/k710/6B_ft_k710_f8.sh
@@ -0,0 +1,61 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+DATA_PATH='your_data_path/k710'
+MODEL_PATH='your_model_path/6B_pt.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 710 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --batch_size 4 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 1 \
+ --tubelet_size 1 \
+ --epochs 6 \
+ --lr 2.5e-5 \
+ --min_lr 0 \
+ --drop_path 0.35 \
+ --head_drop_path 0.35 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/mit/1B_ft_k710_ft_k400_ft_mit_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/mit/1B_ft_k710_ft_k400_ft_mit_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..55dba74ed9d095ffe61e74b72c6d3e7f6699c40f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/mit/1B_ft_k710_ft_k400_ft_mit_f8.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ft_k710_ft_k400_ft_mit_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/mit'
+DATA_PATH='your_data_path/mit'
+MODEL_PATH='your_model_path/1B_ft_k710_ft_k400_f8.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'mitv1_sparse' \
+ --nb_classes 339 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --batch_size 8 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 5 \
+ --tubelet_size 1 \
+ --epochs 15 \
+ --lr 5e-5 \
+ --drop_path 0.3 \
+ --layer_decay 0.9 \
+ --use_checkpoint \
+ --checkpoint_num 6 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best \
+ 2>&1 | tee "$(dirname $0)/log_$JOB_NAME.txt"
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/mit/6B_ft_k710_ft_k400_ft_mit_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/mit/6B_ft_k710_ft_k400_ft_mit_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f89ae41d6d20cae592233454702f27a183d33b16
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/mit/6B_ft_k710_ft_k400_ft_mit_f8.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k400_ft_mit_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/mit'
+DATA_PATH='your_data_path/mit'
+MODEL_PATH='your_model_path/6B_ft_k710_ft_k400_f8.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'mitv1_sparse' \
+ --nb_classes 339 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --batch_size 4 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 2 \
+ --tubelet_size 1 \
+ --epochs 8 \
+ --lr 2.5e-5 \
+ --min_lr 0 \
+ --drop_path 0.35 \
+ --head_drop_path 0.35 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/mit/6B_ft_k710_ft_k400_ft_mit_f8_res224to336.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/mit/6B_ft_k710_ft_k400_ft_mit_f8_res224to336.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7976101a097d3bf38e4e709d546e10f9034b7b46
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/mit/6B_ft_k710_ft_k400_ft_mit_f8_res224to336.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_k710_ft_k400_ft_mit_f8_res224to336'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/mit'
+DATA_PATH='your_data_path/mit'
+MODEL_PATH='6B_ft_k710_ft_k400_ft_mit_f8'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'mitv1_sparse' \
+ --nb_classes 339 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --batch_size 4 \
+ --num_sample 1 \
+ --input_size 336 \
+ --short_side_size 336 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --sampling_rate 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 2 \
+ --lr 1e-5 \
+ --min_lr 0 \
+ --drop_path 0.4 \
+ --head_drop_path 0.4 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 40 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.1 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv1/1B_ft_ssv1_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv1/1B_ft_ssv1_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c1c7fe2fb2f043120ac99817ec4407073741b19c
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv1/1B_ft_ssv1_f8.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_ft_ssv1_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ssv1_frame'
+DATA_PATH='your_data_path/ssv1_frame'
+MODEL_PATH='your_model_path/1B_pt.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'SSV2' \
+ --filename_tmpl '{:05}.jpg' \
+ --no_use_decord \
+ --nb_classes 174 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --batch_size 8 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --num_workers 12 \
+ --warmup_epochs 3 \
+ --tubelet_size 1 \
+ --epochs 9 \
+ --lr 1e-4 \
+ --drop_path 0.3 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 6 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 2 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best \
+ 2>&1 | tee "$(dirname $0)/log_$JOB_NAME.txt"
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv1/6B_ft_ssv1_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv1/6B_ft_ssv1_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ac86f068fee6caaa6e70d2fa193b7b1daeda9649
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv1/6B_ft_ssv1_f8.sh
@@ -0,0 +1,65 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_ssv1_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ssv1_frame'
+DATA_PATH='your_data_path/ssv1_frame'
+MODEL_PATH='your_model_path/6B_pt.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'SSV2' \
+ --filename_tmpl '{:05}.jpg' \
+ --no_use_decord \
+ --nb_classes 174 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 4 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --num_workers 12 \
+ --warmup_epochs 2 \
+ --tubelet_size 1 \
+ --epochs 7 \
+ --lr 7.5e-5 \
+ --min_lr 0 \
+ --drop_path 0.4 \
+ --head_drop_path 0.4 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 4 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --eval \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv2/1B_ft_ssv2_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv2/1B_ft_ssv2_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b8abaace6857c0e1a9c9707fc076fe1922945774
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv2/1B_ft_ssv2_f8.sh
@@ -0,0 +1,60 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_ssv2_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ssv2_frame'
+DATA_PATH='your_data_path/ssv2_frame'
+MODEL_PATH='your_model_path/1B_pt.pth'
+
+PARTITION='video'
+GPUS=32
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'SSV2' \
+ --no_use_decord \
+ --nb_classes 174 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --batch_size 8 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --num_workers 12 \
+ --warmup_epochs 3 \
+ --tubelet_size 1 \
+ --epochs 8 \
+ --lr 1e-4 \
+ --drop_path 0.3 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 6 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 2 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv2/6B_ft_ssv2_f8.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv2/6B_ft_ssv2_f8.sh
new file mode 100644
index 0000000000000000000000000000000000000000..70b468f56280c8236d740a6f63f80fbae12b5e5f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/full_tuning/ssv2/6B_ft_ssv2_f8.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_ft_ssv2_f8'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ssv2_frame'
+DATA_PATH='your_data_path/ssv2_frame'
+MODEL_PATH='your_model_path/6B_pt.pth'
+
+PARTITION='video'
+GPUS=64
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_finetuning.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'SSV2' \
+ --no_use_decord \
+ --nb_classes 174 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 4 \
+ --num_sample 2 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 8 \
+ --num_workers 12 \
+ --warmup_epochs 2 \
+ --tubelet_size 1 \
+ --epochs 6 \
+ --lr 7.5e-5 \
+ --min_lr 0 \
+ --drop_path 0.4 \
+ --head_drop_path 0.4 \
+ --layer_decay 0.915 \
+ --use_checkpoint \
+ --checkpoint_num 24 \
+ --layer_scale_init_value 1e-5 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0.05 \
+ --test_num_segment 5 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/hmdb51/1B_lp_hmdb51_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/hmdb51/1B_lp_hmdb51_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5d2cb8ea6777f0ddfa8019bbc643833811e610c5
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/hmdb51/1B_lp_hmdb51_f16.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_lp_hmdb51_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/hmdb51'
+DATA_PATH='your_data_path/hmdb51'
+MODEL_PATH='your_model_path/1B_pt.pth'
+
+PARTITION='video'
+GPUS=8
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'HMDB51' \
+ --no_use_decord \
+ --nb_classes 51 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 2e-3 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --head_drop_path 0.0 \
+ --fc_drop_rate 0.3 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 2 \
+ --test_num_crop 1 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/hmdb51/6B_lp_hmdb51_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/hmdb51/6B_lp_hmdb51_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7983cad7c2e056fdd52a32805457ff93d2eaaf45
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/hmdb51/6B_lp_hmdb51_f16.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_lp_hmdb51_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/hmdb51'
+DATA_PATH='your_data_path/hmdb51'
+MODEL_PATH='your_model_path/6B_pt.pth'
+
+PARTITION='video'
+GPUS=8
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'HMDB51' \
+ --no_use_decord \
+ --nb_classes 51 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 2e-3 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --head_drop_path 0.0 \
+ --fc_drop_rate 0.3 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 2 \
+ --test_num_crop 1 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/hmdb51/6B_lp_hmdb51_f16_loadStage2.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/hmdb51/6B_lp_hmdb51_f16_loadStage2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9941c76915a20d85de47b01abe6d6ff1e937cefd
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/hmdb51/6B_lp_hmdb51_f16_loadStage2.sh
@@ -0,0 +1,66 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_lp_hmdb51_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/hmdb51'
+DATA_PATH='your_data_path/hmdb51'
+MODEL_PATH='your_model_path/6B_pt.pth'
+EXTRA_MODEL_PATH='your_model_path/6B_pt_stage2.pth'
+
+PARTITION='video'
+GPUS=8
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'HMDB51' \
+ --no_use_decord \
+ --nb_classes 51 \
+ --finetune ${MODEL_PATH} \
+ --finetune_extra ${EXTRA_MODEL_PATH} \
+ --orig_t_size 4 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 2e-3 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --head_drop_path 0 \
+ --fc_drop_rate 0.3 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 2 \
+ --test_num_crop 1 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/k400/1B_lp_k400_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/k400/1B_lp_k400_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f4b5c587173ccea794fb27283a1fe17f286eb95
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/k400/1B_lp_k400_f16.sh
@@ -0,0 +1,61 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_lp_k400_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k400'
+DATA_PATH='your_data_path/k400'
+MODEL_PATH='your_model_path/1B_pt.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 400 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 3 \
+ --tubelet_size 1 \
+ --epochs 15 \
+ --lr 1e-3 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/k400/6B_lp_k400_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/k400/6B_lp_k400_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..493dbbe048752a04690fda3d16aba699b06e9952
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/k400/6B_lp_k400_f16.sh
@@ -0,0 +1,62 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_lp_k400_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k400'
+DATA_PATH='your_data_path/k400'
+MODEL_PATH='your_model_path/6B_pt.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model vit_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 400 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 2 \
+ --tubelet_size 1 \
+ --epochs 9 \
+ --lr 5e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/k400/6B_lp_k400_f16_loadStage2.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/k400/6B_lp_k400_f16_loadStage2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..98f184f07914f471e8f99f10fbb0928400ab7873
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/k400/6B_lp_k400_f16_loadStage2.sh
@@ -0,0 +1,64 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_lp_k400_f16_loadStage2'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/k400'
+DATA_PATH='your_data_path/k400'
+MODEL_PATH='your_model_path/6B_pt.pth'
+EXTRA_MODEL_PATH='your_model_path/6B_pt_stage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing4_2.py \
+ --model vit_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'Kinetics_sparse' \
+ --split ',' \
+ --nb_classes 400 \
+ --finetune ${MODEL_PATH} \
+ --finetune_extra ${EXTRA_MODEL_PATH} \
+ --orig_t_size 4 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 2 \
+ --tubelet_size 1 \
+ --epochs 9 \
+ --lr 5e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ssv2/1B_lp_ssv2_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ssv2/1B_lp_ssv2_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..57cc8463fdcd9f1bbc6b30989186e9d6fdd5be00
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ssv2/1B_lp_ssv2_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_lp_ssv2_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ssv2_frame'
+DATA_PATH='your_data_path/ssv2_frame'
+MODEL_PATH='your_model_path/1B_pt.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_cat_1B_patch14_224 \
+ --merge_method 'cls_avgN_proj' \
+ --merge_norm 'LN' \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'SSV2' \
+ --no_use_decord \
+ --nb_classes 174 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 5 \
+ --tubelet_size 1 \
+ --epochs 25 \
+ --lr 1e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ssv2/6B_lp_ssv2_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ssv2/6B_lp_ssv2_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b40e8145531d0a9710d371e03d8728796aa0bb60
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ssv2/6B_lp_ssv2_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_lp_ssv2_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ssv2_frame'
+DATA_PATH='your_data_path/ssv2_frame'
+MODEL_PATH='your_model_path/6B_pt.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_cat_6B_patch14_224 \
+ --merge_method 'cls_avgN_proj' \
+ --merge_norm 'LN' \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'SSV2' \
+ --no_use_decord \
+ --nb_classes 174 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 4 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 1e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ssv2/6B_lp_ssv2_f16_loadStage2.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ssv2/6B_lp_ssv2_f16_loadStage2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e59ef422e9754ec719830b4421e336a3f857018f
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ssv2/6B_lp_ssv2_f16_loadStage2.sh
@@ -0,0 +1,66 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_lp_ssv2_f16_loadStage2'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ssv2_frame'
+DATA_PATH='your_data_path/ssv2_frame'
+MODEL_PATH='your_model_path/6B_pt.pth'
+EXTRA_MODEL_PATH='your_model_path/6B_pt_stage2.pth'
+
+PARTITION='video'
+GPUS=16
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_cat_6B_patch14_224 \
+ --merge_method 'cls_avgN_proj' \
+ --merge_norm 'LN' \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'SSV2' \
+ --no_use_decord \
+ --nb_classes 174 \
+ --finetune ${MODEL_PATH} \
+ --finetune_extra ${EXTRA_MODEL_PATH} \
+ --orig_t_size 4 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 4 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 1e-4 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 1 \
+ --test_num_crop 3 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ucf101/1B_lp_ucf101_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ucf101/1B_lp_ucf101_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2380af5f1cea063aadf8d184b6e049df8f827950
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ucf101/1B_lp_ucf101_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_lp_ucf101_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ucf101'
+DATA_PATH='your_data_path/ucf101'
+MODEL_PATH='your_model_path/1B_pt.pth'
+
+PARTITION='video'
+GPUS=8
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_1B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'UCF101' \
+ --nb_classes 101 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 1e-3 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --head_drop_path 0 \
+ --fc_drop_rate 0.5 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 2 \
+ --test_num_crop 1 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ucf101/6B_lp_ssv2_f16_loadStage2.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ucf101/6B_lp_ssv2_f16_loadStage2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ea6188c355a64e49af8f38785d8c2e5bddec675d
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ucf101/6B_lp_ssv2_f16_loadStage2.sh
@@ -0,0 +1,65 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_lp_ucf101_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ucf101'
+DATA_PATH='your_data_path/ucf101'
+MODEL_PATH='your_model_path/6B_pt.pth'
+EXTRA_MODEL_PATH='your_model_path/6B_pt_stage2.pth'
+
+PARTITION='video'
+GPUS=8
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'UCF101' \
+ --nb_classes 101 \
+ --finetune ${MODEL_PATH} \
+ --finetune_extra ${EXTRA_MODEL_PATH} \
+ --orig_t_size 4 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 1e-3 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --head_drop_path 0.5 \
+ --fc_drop_rate 0.0 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 2 \
+ --test_num_crop 1 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ucf101/6B_lp_ucf101_f16.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ucf101/6B_lp_ucf101_f16.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f769b6641d38ffd4ae9219278a755a7ca59a2ace
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/finetuning/linear_probing/ucf101/6B_lp_ucf101_f16.sh
@@ -0,0 +1,63 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_lp_ucf101_f16'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+PREFIX='your_data_path/ucf101'
+DATA_PATH='your_data_path/ucf101'
+MODEL_PATH='your_model_path/6B_pt.pth'
+
+PARTITION='video'
+GPUS=8
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ python run_linear_probing.py \
+ --model internvideo2_6B_patch14_224 \
+ --data_path ${DATA_PATH} \
+ --prefix ${PREFIX} \
+ --data_set 'UCF101' \
+ --nb_classes 101 \
+ --finetune ${MODEL_PATH} \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --steps_per_print 10 \
+ --batch_size 64 \
+ --num_sample 1 \
+ --input_size 224 \
+ --short_side_size 224 \
+ --save_ckpt_freq 100 \
+ --num_frames 16 \
+ --orig_t_size 8 \
+ --num_workers 12 \
+ --warmup_epochs 0 \
+ --tubelet_size 1 \
+ --epochs 20 \
+ --lr 1e-3 \
+ --min_lr 0 \
+ --drop_path 0.0 \
+ --head_drop_path 0 \
+ --fc_drop_rate 0.5 \
+ --layer_decay 1.0 \
+ --layer_scale_init_value 1e-5 \
+ --aa rand-m5-n2-mstd0.25-inc1 \
+ --opt adamw \
+ --opt_betas 0.9 0.999 \
+ --weight_decay 0 \
+ --test_num_segment 2 \
+ --test_num_crop 1 \
+ --dist_eval \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR} \
+ --test_best
+
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/pretraining/1B_pt.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/pretraining/1B_pt.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f22501f8a27e0327c35fcce401b9e51d915faa89
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/pretraining/1B_pt.sh
@@ -0,0 +1,67 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='1B_pt'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+DATA_PATH='train_1.1M.csv'
+
+PARTITION='video'
+GPUS=128
+GPUS_PER_NODE=8
+CPUS_PER_TASK=16
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ python -u run_pretraining.py \
+ --data_path ${DATA_PATH} \
+ --num_sample 1 \
+ --flip \
+ --mask_type 'attention' \
+ --mask_ratio 0.8 \
+ --model 'pretrain_internvideo2_1B_patch14_224' \
+ --clip_teacher 'internvl_clip_6b' \
+ --clip_input_resolution 224 \
+ --clip_teacher_embed_dim 3200 \
+ --clip_teacher_final_dim 768 \
+ --clip_loss_ratio 1 1 \
+ --clip_norm_type 'l2' \
+ --clip_return_attn \
+ --clip_return_layer 6 \
+ --clip_teacher_return_interval 1 \
+ --clip_student_return_interval 1 \
+ --mae_teacher 'mae_g14_hybrid' \
+ --mae_tubelet_size 2 \
+ --mae_loss_ratio 1 \
+ --mae_norm_type 'l2' \
+ --mae_teacher_embed_dim 1408 \
+ --mae_return_layer 4 \
+ --mae_teacher_return_interval 1 \
+ --mae_student_return_interval 1 \
+ --tubelet_size 1 \
+ --lr 1.5e-4 \
+ --drop_path 0.25 \
+ --use_checkpoint \
+ --checkpoint_num 40 \
+ --layer_scale_init_value 1e-5 \
+ --batch_size 32 \
+ --num_segments 16 \
+ --num_frames 16 \
+ --sampling_rate 1 \
+ --num_workers 12 \
+ --opt adamw \
+ --opt_eps 1e-6 \
+ --opt_betas 0.9 0.98 \
+ --clip_grad 3.0 \
+ --weight_decay 0.05 \
+ --warmup_epochs 40 \
+ --save_ckpt_freq 50 \
+ --epochs 301 \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR}
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/scripts/pretraining/6B_pt.sh b/third_party/InternVideo/InternVideo2/single_modality/scripts/pretraining/6B_pt.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d4fe6bdbdbcb569adcd86e3a6fef1ededecd8cca
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/scripts/pretraining/6B_pt.sh
@@ -0,0 +1,67 @@
+export MASTER_PORT=$((12000 + $RANDOM % 20000))
+export OMP_NUM_THREADS=1
+
+JOB_NAME='6B_pt'
+OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
+LOG_DIR="./logs/${JOB_NAME}"
+DATA_PATH='train_2M.csv'
+
+PARTITION='video'
+GPUS=256
+GPUS_PER_NODE=8
+CPUS_PER_TASK=14
+
+srun -p $PARTITION \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ python -u run_pretraining.py \
+ --data_path ${DATA_PATH} \
+ --num_sample 1 \
+ --flip \
+ --mask_type 'attention' \
+ --mask_ratio 0.8 \
+ --model 'pretrain_internvideo2_6B_patch14_224' \
+ --clip_teacher 'internvl_clip_6b' \
+ --clip_input_resolution 224 \
+ --clip_teacher_embed_dim 3200 \
+ --clip_teacher_final_dim 768 \
+ --clip_loss_ratio 1 1 \
+ --clip_norm_type 'l2' \
+ --clip_return_attn \
+ --clip_return_layer 6 \
+ --clip_teacher_return_interval 1 \
+ --clip_student_return_interval 1 \
+ --mae_teacher 'mae_g14_hybrid' \
+ --mae_tubelet_size 2 \
+ --mae_loss_ratio 1 \
+ --mae_norm_type 'l2' \
+ --mae_teacher_embed_dim 1408 \
+ --mae_return_layer 4 \
+ --mae_teacher_return_interval 1 \
+ --mae_student_return_interval 1 \
+ --tubelet_size 1 \
+ --lr 1.5e-4 \
+ --drop_path 0.3 \
+ --use_checkpoint \
+ --checkpoint_num 48 \
+ --layer_scale_init_value 1e-5 \
+ --batch_size 8 \
+ --num_segments 16 \
+ --num_frames 16 \
+ --sampling_rate 1 \
+ --num_workers 12 \
+ --opt adamw \
+ --opt_eps 1e-6 \
+ --opt_betas 0.9 0.98 \
+ --clip_grad 3.0 \
+ --weight_decay 0.05 \
+ --warmup_epochs 40 \
+ --save_ckpt_freq 50 \
+ --epochs 301 \
+ --enable_deepspeed \
+ --bf16 \
+ --zero_stage 1 \
+ --log_dir ${OUTPUT_DIR} \
+ --output_dir ${OUTPUT_DIR}
\ No newline at end of file
diff --git a/third_party/InternVideo/InternVideo2/single_modality/utils.py b/third_party/InternVideo/InternVideo2/single_modality/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5d1479dc123f6886f5d3db42952301c8f0c5164
--- /dev/null
+++ b/third_party/InternVideo/InternVideo2/single_modality/utils.py
@@ -0,0 +1,955 @@
+import io
+import os
+import math
+import time
+import json
+from collections import defaultdict, deque
+import datetime
+import numpy as np
+from timm.utils import get_state_dict
+from torch.utils.data._utils.collate import default_collate
+from pathlib import Path
+import subprocess
+import torch
+import torch.distributed as dist
+from torch._six import inf
+import random
+
+from tensorboardX import SummaryWriter
+
+import fnmatch
+try:
+ from petrel_client.client import Client
+ has_client = True
+ client = Client('~/petreloss.conf')
+except ImportError:
+ has_client = False
+ client = None
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f} (max: {max:.4f})')
+ data_time = SmoothedValue(fmt='{avg:.4f} (max: {max:.4f})')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+ def log_every_joint(self, video_loader, image_loader, print_freq, header=None, image_num_ratio=1.0):
+ # prepare random squeue
+ total_len = int(len(video_loader) + len(image_loader) * image_num_ratio)
+ random_sequence = np.arange(total_len)
+ np.random.shuffle(random_sequence)
+ loader_list = [iter(video_loader), iter(image_loader)]
+ # prepare print template
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f} (max: {max:.4f})')
+ data_time = SmoothedValue(fmt='{avg:.4f} (max: {max:.4f})')
+ space_fmt = ':' + str(len(str(total_len))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+
+ for i, random_num in enumerate(random_sequence):
+ # randomly selct image or video
+ if random_num < len(video_loader):
+ loader_idx = 0
+ use_image = False
+ mark = '<>\t'
+ else:
+ loader_idx = 1
+ use_image = True
+ mark = '<>\t'
+ data_time.update(time.time() - end)
+ yield (next(loader_list[loader_idx]), use_image)
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == total_len - 1:
+ eta_seconds = iter_time.global_avg * (total_len - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(mark, log_msg.format(
+ i, total_len, eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(mark, log_msg.format(
+ i, total_len, eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / total_len))
+
+
+class TensorboardLogger(object):
+ def __init__(self, log_dir):
+ self.writer = SummaryWriter(logdir=log_dir)
+ self.step = 0
+
+ def set_step(self, step=None):
+ if step is not None:
+ self.step = step
+ else:
+ self.step += 1
+
+ def update(self, head='scalar', step=None, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
+
+ def flush(self):
+ self.writer.flush()
+
+
+def seed_worker(worker_id):
+ worker_seed = torch.initial_seed() % 2**32
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+
+
+def _load_checkpoint_for_ema(model_ema, checkpoint):
+ """
+ Workaround for ModelEma._load_checkpoint to accept an already-loaded object
+ """
+ mem_file = io.BytesIO()
+ torch.save(checkpoint, mem_file)
+ mem_file.seek(0)
+ model_ema._load_checkpoint(mem_file)
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def get_ceph_path(ckpt_path, ceph_args):
+ sub_path = str(ckpt_path).split(ceph_args['ckpt_path_split'])[-1]
+ ceph_ckpt_path = os.path.join(ceph_args['ceph_checkpoint_prefix'], sub_path)
+ return sub_path, ceph_ckpt_path
+
+def save_on_master(obj, ckpt_path, ceph_args):
+ if is_main_process():
+ if ceph_args['use_ceph_checkpoint']:
+ assert has_client == True, "petrel_client is not installed!!!"
+ _, ceph_ckpt_path = get_ceph_path(ckpt_path, ceph_args)
+ with io.BytesIO() as f:
+ torch.save(obj, f)
+ client.put(ceph_ckpt_path, f.getvalue())
+ else:
+ torch.save(obj, ckpt_path)
+
+
+def init_distributed_mode(args):
+ if args.dist_on_itp:
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = int(os.environ['SLURM_LOCALID'])
+ args.world_size = int(os.environ['SLURM_NTASKS'])
+ os.environ['RANK'] = str(args.rank)
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+
+ node_list = os.environ['SLURM_NODELIST']
+ addr = subprocess.getoutput(
+ f'scontrol show hostname {node_list} | head -n1')
+ if 'MASTER_ADDR' not in os.environ:
+ os.environ['MASTER_ADDR'] = addr
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}, gpu {}'.format(
+ args.rank, args.dist_url, args.gpu), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ # assert torch.distributed.is_initialized()
+ setup_for_distributed(args.rank == 0)
+
+
+def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ def load(module, prefix=''):
+ local_metadata = {} if metadata is None else metadata.get(
+ prefix[:-1], {})
+ module._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+
+ load(model, prefix=prefix)
+
+ warn_missing_keys = []
+ ignore_missing_keys = []
+ for key in missing_keys:
+ keep_flag = True
+ for ignore_key in ignore_missing.split('|'):
+ if ignore_key in key:
+ keep_flag = False
+ break
+ if keep_flag:
+ warn_missing_keys.append(key)
+ else:
+ ignore_missing_keys.append(key)
+
+ missing_keys = warn_missing_keys
+
+ if len(missing_keys) > 0:
+ print("Weights of {} not initialized from pretrained model: {}".format(
+ model.__class__.__name__, missing_keys))
+ if len(unexpected_keys) > 0:
+ print("Weights from pretrained model not used in {}: {}".format(
+ model.__class__.__name__, unexpected_keys))
+ if len(ignore_missing_keys) > 0:
+ print("Ignored weights of {} not initialized from pretrained model: {}".format(
+ model.__class__.__name__, ignore_missing_keys))
+ if len(error_msgs) > 0:
+ print('\n'.join(error_msgs))
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ return total_norm
+
+
+def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
+ start_warmup_value=0, warmup_steps=-1):
+ warmup_schedule = np.array([])
+ warmup_iters = int(warmup_epochs * niter_per_ep)
+ if warmup_steps > 0:
+ warmup_iters = warmup_steps
+ print("Set warmup steps = %d" % warmup_iters)
+ if warmup_epochs > 0:
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
+ schedule = np.array(
+ [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
+
+ schedule = np.concatenate((warmup_schedule, schedule))
+
+ assert len(schedule) == epochs * niter_per_ep
+ return schedule
+
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, model_name=None, ceph_args={'use_ceph_checkpoint': False}):
+ output_dir = Path(args.output_dir)
+ if model_name is None:
+ model_name = str(epoch)
+ if loss_scaler is not None:
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % model_name)]
+ for checkpoint_path in checkpoint_paths:
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }
+
+ if model_ema is not None:
+ to_save['model_ema'] = get_state_dict(model_ema)
+
+ save_on_master(to_save, checkpoint_path, ceph_args=ceph_args)
+ else:
+ client_state = {'epoch': epoch}
+ if model_ema is not None:
+ client_state['model_ema'] = get_state_dict(model_ema)
+
+ if ceph_args['use_ceph_checkpoint']:
+ sub_path, ceph_save_dir = get_ceph_path(output_dir, ceph_args)
+ local_save_dir = os.path.join('/dev/shm', sub_path)
+ Path(local_save_dir).mkdir(parents=True, exist_ok=True)
+ else:
+ local_save_dir = output_dir
+ tag_name = "checkpoint-%s" % model_name
+ model.save_checkpoint(save_dir=local_save_dir, tag=tag_name, client_state=client_state)
+
+ if ceph_args['use_ceph_checkpoint'] and ceph_args['local_rank'] == 0:
+ try:
+ assert has_client == True, "petrel_client is not installed!!!"
+ ckpt_shm_dir = os.path.join(local_save_dir, tag_name)
+ ckpt_petrel_dir = os.path.join(ceph_save_dir, tag_name)
+ for f_name in os.listdir(ckpt_shm_dir):
+ f_shm_path = os.path.join(ckpt_shm_dir, f_name)
+ f_petrel_path = os.path.join(ckpt_petrel_dir, f_name)
+ with open(f_shm_path, 'rb') as f:
+ print(f"Upload checkpoint at {f_petrel_path}", flush=True)
+ client.put(f_petrel_path, f)
+ print("Finish! Will remove the original files!", flush=True)
+ os.remove(f_shm_path)
+ except Exception as e:
+ print(f'Fail to upload or delete {f_shm_path} with error {e}')
+
+
+def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, ceph_args={'use_ceph_checkpoint': False}):
+ output_dir = Path(args.output_dir)
+
+ if ceph_args['use_ceph_checkpoint']:
+ assert has_client == True, "petrel_client is not installed!!!"
+ sub_path, ceph_save_dir = get_ceph_path(output_dir, ceph_args)
+ if loss_scaler is not None:
+ # torch.amp
+ if args.test_best and args.eval:
+ args.resume = os.path.join(ceph_save_dir, 'checkpoint-best.pth')
+ elif check_ceph_exists(os.path.join(ceph_save_dir, 'checkpoint-latest.pth')):
+ args.resume = os.path.join(ceph_save_dir, 'checkpoint-latest.pth')
+ elif args.auto_resume and len(args.resume) == 0:
+ all_checkpoints = fnmatch.filter(list(client.list(ceph_save_dir)), 'checkpoint-*')
+ all_checkpoints = [
+ os.path.join(ceph_save_dir, ckpt_path)
+ for ckpt_path in all_checkpoints
+ ]
+ latest_ckpt = -1
+ for ckpt in all_checkpoints:
+ t = ckpt.split('-')[-1].split('.')[0]
+ if t.isdigit():
+ latest_ckpt = max(int(t), latest_ckpt)
+ if latest_ckpt >= 0:
+ args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
+ print("Auto resume checkpoint: %s" % args.resume)
+
+ if args.resume:
+ with io.BytesIO(client.get(args.resume)) as buffer:
+ checkpoint = torch.load(buffer, map_location='cpu')
+ model_without_ddp.load_state_dict(checkpoint['model'])
+ print("Resume checkpoint %s" % args.resume)
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if hasattr(args, 'model_ema') and args.model_ema:
+ _load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
+ if 'scaler' in checkpoint:
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ print("With optim & sched!")
+ else:
+ # deepspeed, only support '--auto_resume'.
+ flag = False
+ if args.test_best and args.eval:
+ try:
+ load_specific_ceph_model(
+ model, model_ema, args, sub_path, ceph_save_dir,
+ model_name='best', ceph_args=ceph_args
+ )
+ flag = True
+ except Exception:
+ print('No best model')
+ if not flag:
+ try:
+ load_specific_ceph_model(
+ model, model_ema, args, sub_path, ceph_save_dir,
+ model_name='latest', ceph_args=ceph_args
+ )
+ flag = True
+ except Exception:
+ print('No latest model')
+ if not flag:
+ try:
+ load_specific_ceph_model(
+ model, model_ema, args, sub_path, ceph_save_dir,
+ model_name='best', ceph_args=ceph_args
+ )
+ flag = True
+ except Exception:
+ print('No best model')
+ if not flag:
+ all_checkpoints = fnmatch.filter(list(client.list(ceph_save_dir)), 'checkpoint-*')
+ all_checkpoints = [
+ os.path.join(ceph_save_dir, ckpt_path)
+ for ckpt_path in all_checkpoints
+ ]
+ latest_ckpt = -1
+ for ckpt in all_checkpoints:
+ t = ckpt.split('-')[-1].split('.')[0]
+ if t.isdigit():
+ latest_ckpt = max(int(t), latest_ckpt)
+ if latest_ckpt >= 0:
+ load_specific_ceph_model(
+ model, model_ema, args, sub_path, ceph_save_dir,
+ model_name=latest_ckpt, ceph_args=ceph_args
+ )
+ else:
+ print('No other models')
+ else:
+ if loss_scaler is not None:
+ # torch.amp
+ if args.test_best and args.eval:
+ args.resume = os.path.join(output_dir, 'checkpoint-best.pth')
+ elif os.path.exists(os.path.join(output_dir, 'checkpoint-latest.pth')):
+ args.resume = os.path.join(output_dir, 'checkpoint-latest.pth')
+ elif args.auto_resume and len(args.resume) == 0:
+ import glob
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
+ latest_ckpt = -1
+ for ckpt in all_checkpoints:
+ t = ckpt.split('-')[-1].split('.')[0]
+ if t.isdigit():
+ latest_ckpt = max(int(t), latest_ckpt)
+ if latest_ckpt >= 0:
+ args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
+ print("Auto resume checkpoint: %s" % args.resume)
+
+ if args.resume:
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model_without_ddp.load_state_dict(checkpoint['model'])
+ print("Resume checkpoint %s" % args.resume)
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if hasattr(args, 'model_ema') and args.model_ema:
+ _load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
+ if 'scaler' in checkpoint:
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ print("With optim & sched!")
+ else:
+ # deepspeed, only support '--auto_resume'.
+ flag = False
+ if args.test_best and args.eval:
+ try:
+ load_specific_model(model, model_ema, args, output_dir, model_name='best')
+ flag = True
+ except Exception:
+ print('No best model')
+ if not flag:
+ try:
+ load_specific_model(model, model_ema, args, output_dir, model_name='latest')
+ flag = True
+ except Exception:
+ print('No latest model')
+ if not flag:
+ try:
+ load_specific_model(model, model_ema, args, output_dir, model_name='best')
+ flag = True
+ except Exception:
+ print('No best model')
+ if not flag:
+ import glob
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
+ latest_ckpt = -1
+ for ckpt in all_checkpoints:
+ t = ckpt.split('-')[-1].split('.')[0]
+ if t.isdigit():
+ latest_ckpt = max(int(t), latest_ckpt)
+ if latest_ckpt >= 0:
+ load_specific_model(model, model_ema, args, output_dir, model_name=latest_ckpt)
+ else:
+ print('No other models')
+
+
+def load_specific_model(model, model_ema, args, output_dir, model_name):
+ args.resume = os.path.join(output_dir, f'checkpoint-{model_name}')
+ print(f"Auto resume the {model_name} checkpoint")
+ _, client_states = model.load_checkpoint(args.output_dir, tag=f'checkpoint-{model_name}')
+ args.start_epoch = client_states['epoch'] + 1
+ if model_ema is not None:
+ if args.model_ema:
+ _load_checkpoint_for_ema(model_ema, client_states['model_ema'])
+
+
+def check_ceph_exists(ceph_path):
+ return list(client.list(ceph_path)) > 0
+
+
+def load_specific_ceph_model(model, model_ema, args, sub_path, ceph_save_dir, model_name, ceph_args):
+ tag_name = f'checkpoint-{model_name}'
+ args.resume = os.path.join(ceph_save_dir, tag_name)
+ print(f"Auto resume checkpoint: {args.resume}", flush=True)
+ shm_resume_dir = os.path.join('/dev/shm', sub_path, tag_name)
+ Path(shm_resume_dir).mkdir(parents=True, exist_ok=True)
+
+ if ceph_args['local_rank'] == 0:
+ for f_name in client.list(args.resume):
+ ckpt_petrel_path = os.path.join(args.resume, f_name)
+ ckpt_shm_path = os.path.join(shm_resume_dir, f_name)
+ print(f"Download model from {ckpt_petrel_path}", flush=True)
+ with open(ckpt_shm_path, 'wb') as f:
+ f.write(memoryview(client.get(ckpt_petrel_path)))
+ print("Finish downloading!", flush=True)
+
+ torch.distributed.barrier()
+
+ _, client_states = model.load_checkpoint(os.path.join('/dev/shm', sub_path), tag=f'checkpoint-{model_name}')
+ args.start_epoch = client_states['epoch'] + 1
+ if model_ema is not None:
+ if args.model_ema:
+ _load_checkpoint_for_ema(model_ema, client_states['model_ema'])
+
+ if ceph_args['local_rank'] == 0:
+ try:
+ for root, dirs, files in os.walk(shm_resume_dir):
+ for name in files:
+ os.remove(os.path.join(root, name))
+ for name in dirs:
+ os.rmdir(os.path.join(root, name))
+ os.rmdir(root)
+ except Exception as e:
+ print(f'Fail to clean {shm_resume_dir} with error {e}')
+
+
+def create_ds_config(args):
+ args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
+ with open(args.deepspeed_config, mode="w") as writer:
+ ds_config = {
+ "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
+ "train_micro_batch_size_per_gpu": args.batch_size,
+ "steps_per_print": 1000,
+ "optimizer": {
+ "type": "Adam",
+ "adam_w_mode": True,
+ "params": {
+ "lr": args.lr,
+ "weight_decay": args.weight_decay,
+ "bias_correction": True,
+ "betas": [
+ 0.9,
+ 0.999
+ ],
+ "eps": 1e-8
+ }
+ },
+ "fp16": {
+ "enabled": True,
+ "loss_scale": 0,
+ "initial_scale_power": 7,
+ "loss_scale_window": 128
+ }
+ }
+
+ writer.write(json.dumps(ds_config, indent=2))
+
+
+def create_internvideo2_lp_ds_config(args):
+ args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
+ with open(args.deepspeed_config, mode="w") as writer:
+ ds_config = {
+ "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
+ "train_micro_batch_size_per_gpu": args.batch_size,
+ "steps_per_print": 1000,
+ "optimizer": {
+ "type": "Adam",
+ "adam_w_mode": True,
+ "params": {
+ "lr": args.lr,
+ "weight_decay": args.weight_decay,
+ "bias_correction": True,
+ "betas": [
+ args.opt_betas[0],
+ args.opt_betas[1]
+ ],
+ "eps": args.opt_eps
+ }
+ },
+ "fp16": {
+ "enabled": not args.bf16,
+ "loss_scale": 0,
+ "initial_scale_power": 16,
+ "loss_scale_window": 500,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": args.bf16
+ },
+ }
+ if args.clip_grad is not None:
+ ds_config.update({'gradient_clipping': args.clip_grad})
+
+ writer.write(json.dumps(ds_config, indent=2))
+
+
+# stolen from https://github.com/baaivision/EVA/blob/7389aeeec97c056fc8424fa6b78f35c6f1b07d0d/EVA-02/asuka/utils.py#L529C5-L599C54
+def create_internvideo_ds_config(args):
+ args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
+ with open(args.deepspeed_config, mode="w") as writer:
+ ds_config = {
+ "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
+ "train_micro_batch_size_per_gpu": args.batch_size,
+ "steps_per_print": args.steps_per_print,
+ "optimizer": {
+ "type": "Adam",
+ "adam_w_mode": True,
+ "params": {
+ "lr": args.lr,
+ "weight_decay": args.weight_decay,
+ "bias_correction": True,
+ "betas": [
+ args.opt_betas[0],
+ args.opt_betas[1]
+ ],
+ "eps": args.opt_eps
+ }
+ },
+ "fp16": {
+ "enabled": not args.bf16,
+ "loss_scale": 0,
+ "initial_scale_power": 16,
+ "loss_scale_window": 500,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": args.bf16
+ },
+ "amp": {
+ "enabled": False,
+ "opt_level": "O2"
+ },
+ "flops_profiler": {
+ "enabled": True,
+ "profile_step": -1,
+ "module_depth": -1,
+ "top_modules": 1,
+ "detailed": True,
+ },
+ "zero_allow_untested_optimizer": True
+ }
+
+ if args.clip_grad is not None:
+ ds_config.update({'gradient_clipping': args.clip_grad})
+
+ if args.zero_stage == 1:
+ ds_config.update(
+ {
+ "zero_optimization": {
+ "stage": 1,
+ "reduce_bucket_size": 5e8,
+ }
+ }
+ )
+ elif args.zero_stage == 2:
+ ds_config.update(
+ {
+ "zero_optimization": {
+ "stage": 2,
+ "contiguous_gradients": True,
+ "overlap_comm": True,
+ "reduce_scatter": True,
+ "reduce_bucket_size": 5e8,
+ "allgather_bucket_size": 5e8,
+ "cpu_offload": False,
+ }
+ }
+ )
+ elif args.zero_stage == 3:
+ ds_config.update(
+ {
+ "zero_optimization": {
+ "stage": 3,
+ "contiguous_gradients": True,
+ "overlap_comm": True,
+ "reduce_scatter": True,
+ "reduce_bucket_size": 5e4,
+ "allgather_bucket_size": 5e4,
+ "cpu_offload": False,
+ "stage3_max_live_parameters": 1e5,
+ "stage3_max_reuse_distance": 1e5,
+ },
+ }
+ )
+ elif args.zero_stage > 3:
+ raise NotImplementedError()
+
+ opt_lower = args.opt.lower()
+ if opt_lower != 'adamw': del ds_config['optimizer']
+
+ writer.write(json.dumps(ds_config, indent=2))
+
+
+def multiple_samples_collate(batch, fold=False):
+ """
+ Collate function for repeated augmentation. Each instance in the batch has
+ more than one sample.
+ Args:
+ batch (tuple or list): data batch to collate.
+ Returns:
+ (tuple): collated data batch.
+ """
+ inputs, labels, video_idx, extra_data = zip(*batch)
+ inputs = [item for sublist in inputs for item in sublist]
+ labels = [item for sublist in labels for item in sublist]
+ video_idx = [item for sublist in video_idx for item in sublist]
+ inputs, labels, video_idx, extra_data = (
+ default_collate(inputs),
+ default_collate(labels),
+ default_collate(video_idx),
+ default_collate(extra_data),
+ )
+ if fold:
+ return [inputs], labels, video_idx, extra_data
+ else:
+ return inputs, labels, video_idx, extra_data
+
+
+def multiple_pretrain_samples_collate(batch, fold=False):
+ """
+ Collate function for repeated augmentation. Each instance in the batch has
+ more than one sample.
+ Args:
+ batch (tuple or list): data batch to collate.
+ Returns:
+ (tuple): collated data batch.
+ """
+ process_data, mask = zip(*batch)
+ process_data = [item for sublist in process_data for item in sublist]
+ mask = [item for sublist in mask for item in sublist]
+ process_data, mask = (
+ default_collate(process_data),
+ default_collate(mask),
+ )
+ if fold:
+ return [process_data], mask
+ else:
+ return process_data, mask
\ No newline at end of file
diff --git a/third_party/InternVideo/LICENSE b/third_party/InternVideo/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..838ad8987b86157ff55f7251d268e8f17d2ed22c
--- /dev/null
+++ b/third_party/InternVideo/LICENSE
@@ -0,0 +1,203 @@
+Copyright (c) OpenGVLab. All rights reserved
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/third_party/InternVideo/README.md b/third_party/InternVideo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f572e68c0a81c9dc53629d34f904b1763885be06
--- /dev/null
+++ b/third_party/InternVideo/README.md
@@ -0,0 +1,33 @@
+# InternVideo: Video Foundation Models for Multimodal Understanding
+
+
+
+---
+
+
+
+
+
+This repo contains InternVideo series and related works in video foundation models.
+
+- [InternVideo](InternVideo1): general video foundation models via generative and discriminative learning
+- [InternVideo2](InternVideo2): scaling video foundation models for multimodal video understanding
+- [InternVid](Data/InternVid): a large-scale video-text dataset for multimodal understanding and generation
+
+## Updates
+- `2024.03`: The [technical report](https://arxiv.org/abs/2403.15377) of InternVideo2 is released.
+- `2024.01`: [InternVid](Data/InternVid) (a video-text dataset for video understanding and generation) has been accepted for spotlight presentation of ICLR 2024.
+- `2023.07`: A **video-text dataset InternVid** is released at [here](Data/InternVid) for facilitating multimodal understanding and generation.
+- `2023.05`: **Video instruction data** are released at [here](Data/instruction_data) for tuning end-to-end video-centric multimodal dialogue systems like [VideoChat](https://github.com/OpenGVLab/Ask-Anything).
+- `2023.01`: The [code & models](InternVideo1) of InternVideo are released.
+- `2022.12`: The [technical report](https://arxiv.org/pdf/2212.03191.pdf) of InternVideo is released.
+- `2022.09`: Press releases of InternVideo ([official](https://www.shlab.org.cn/news/5443279) | [163 news](https://www.163.com/dy/article/HG939TNR0530QRMB.html) | [qq news](https://new.qq.com/rain/a/20220902A053JP00)).
+
+## Contact
+- If you have any questions during the trial, running or deployment, feel free to join our WeChat group discussion! If you have any ideas or suggestions for the project, you are also welcome to join our WeChat group discussion!
+
+
+
+
+
+- We are hiring researchers, engineers and interns in General Vision Group, Shanghai AI Lab. If you are interested in working with us on video foundation models and related topics, please contact Yi Wang (wangyi@pjlab.org.cn).
\ No newline at end of file
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tools/genrl_utils.py b/tools/genrl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f734980bc55a6d8220b1c1a739accf0fc1cacd2d
--- /dev/null
+++ b/tools/genrl_utils.py
@@ -0,0 +1,312 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from pathlib import Path
+
+MODELS_ROOT_PATH = Path(__file__).parent.parent / 'models'
+INTERNVIDEO_PATH = Path(__file__).parent.parent / 'third_party' / 'InternVideo'
+
+DOMAIN2PREDICATES = {
+ 'walker' : ['taking a walk', 'standing up vertically on both feet', 'single-leg balancing', "standing upside down", 'high kick', 'walking', 'stepping forward', 'running fast',
+ 'standing on one bended knee', 'lying down on the back with one raised leg', 'sitting on the knees', 'dog yoga pose', 'lying down horizontally', ],
+ 'stickman' : ['taking a walk', 'standing up vertically', 'one leg balancing', 'high kick', 'walking', 'running fast',
+ 'praying', 'lying down with one raised leg', 'dog yoga pose', 'lying down horizontally', 'punching', 'raised hands' ],
+ 'cheetah' : ['jumping', 'crawling', 'running', 'flipping', 'standing up', 'hopping', 'lying down', 'falling',
+ 'standing on the knees'],
+ 'quadruped' : ['jumping', 'crawling', 'walking', 'standing up',
+ 'hopping', 'lying down', 'falling', 'standing on the knees'],
+ 'finger' : ['spin', 'touch', 'rotate', 'horizontal', 'vertical', "not moving", "is not touching", "staying far away", "staying still"],
+ 'pendulum' : ['horizontal', 'vertical', 'left', 'right',
+ 'swingup', 'balance'],
+ 'hopper' : ['jumping', 'crawling', 'walking', 'standing up',
+ 'hopping', 'lying down', 'falling', 'standing on the knees'],
+ 'reacher' : ['horizontal', 'vertical', 'ball on the left', 'ball on the right', 'touch the ball with the elbow', 'touch the ball with the tip', 'arm reaches the sphere', 'rotating', 'bending', 'keeping straight', "not moving", "is not touching"],
+ 'jaco' : ['horizontal', 'vertical', 'left', 'right', 'spin', 'touch', 'rotate', 'bend', 'straight', "is not touching"],
+ 'kitchen' : [ "touch", "pick up", "lift", "grasp", "hold", "pull", "open", "close",
+ "push", "sweep", "slide"] + ['switch light on', 'open the microwave', 'move the kettle', 'turn on the burner'],
+}
+
+TASK2PROMPT = {
+ "quadruped_run" : 'spider running fast',
+ "quadruped_walk" : 'spider walking fast',
+ "quadruped_stand" : 'spider standing',
+ "quadruped_jump" : 'spider jumping',
+
+ "quadruped_two_legs" : 'on two legs',
+ "quadruped_lie_down" : 'lying down',
+
+ "cheetah_run" : 'running like a quadruped',
+
+ "cheetah_flipping" : 'quadruped rotating flips',
+ "cheetah_standing" : 'standing like a human',
+ "cheetah_lying_down" : 'lying down',
+
+ 'stickman_walk' : 'robot walk fast clean',
+ 'stickman_run' : 'robot run fast clean',
+ 'stickman_stand' : 'standing',
+ 'stickman_urlb_flip' : 'doing flips',
+
+ 'stickman_flip' : 'doing flips',
+ 'stickman_flipping' : 'doing flips',
+ 'stickman_backflip' : 'doing backflips',
+ 'stickman_one_foot' : 'stand on one foot',
+ 'stickman_high_kick' : 'stand up and kick',
+ 'stickman_lying_down' : 'lying down horizontally',
+ 'stickman_legs_up' : 'lying down with feet up',
+ 'stickman_sit_knees' : 'praying',
+ 'stickman_lunge_pose' : 'lunge_pose',
+ 'stickman_headstand' : 'headstand',
+ 'stickman_boxing' : 'punch',
+ 'stickman_hands_up' : 'standing with the hands up',
+
+ 'walker_walk' : 'walk fast clean',
+ 'walker_run' : 'run fast clean',
+ 'walker_stand' : 'standing up straight',
+ 'walker_urlb_flip' : 'doing backflips',
+
+ 'walker_flip' : 'doing flips',
+ 'walker_flipping' : 'doing backflips',
+ 'walker_backflip' : 'doing backflips',
+ 'walker_one_foot' : 'stand on one foot',
+ 'walker_high_kick' : 'stand up and kick',
+ 'walker_lying_down' : 'lying down horizontally',
+ 'walker_arabesque' : 'arabesque position',
+ 'walker_legs_up' : 'lying down with feet up',
+ 'walker_sit_knees' : 'praying',
+ 'walker_lunge_pose' : 'lunge_pose',
+ 'walker_headstand' : 'headstand',
+
+ 'kitchen_microwave' : 'opening the microwave fully open',
+ 'kitchen_light' : 'activate the light',
+ 'kitchen_burner' : 'the burner becomes red',
+ 'kitchen_slide' : 'slide cabinet above the knobs',
+
+ 'kitchen_kettle' : 'pushing up the kettle',
+
+ 'jaco_reach_top_left' : 'robot grasp the red cube',
+ 'jaco_reach_bottom_left' : 'robot grasp the red cube',
+ 'jaco_reach_top_right' : 'robot grasp the red cube',
+ 'jaco_reach_bottom_right' : 'robot grasp the red cube',
+}
+
+class ViCLIPGlobalInstance:
+ def __init__(self, model='internvideo2'):
+ self._instantiated = False
+ self._model = model
+
+ def instantiate(self, device='cuda'):
+ from torchvision.transforms import transforms as vision_transf
+ import sys
+
+ self._instantiated = True
+
+ if self._model =='internvideo2':
+ sys.path.insert(0, str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality/demo/'))
+ sys.path.insert(0, str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality'))
+ import numpy as np
+ from small_config import (Config, eval_dict_leaf)
+ from small_utils import setup_internvideo2
+ config = Config.from_file(INTERNVIDEO_PATH / 'InternVideo2/multi_modality/demo/internvideo2_stage2_config.py')
+ config = eval_dict_leaf(config)
+ config.model.vision_encoder.num_frames = 8
+ config.num_frames = 8
+ config.num_frames_test = 8
+ # # >> can be configured in case the bert model doesn't load
+ # config.model.text_encoder.pretrained = str(MODELS_ROOT_PATH / 'bert-large-uncased')
+ config.model.text_encoder.config = str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality') + "/" + config.model.text_encoder.config
+ model_pth = str(MODELS_ROOT_PATH / 'InternVideo2-stage2_1b-224p-f4.pt')
+ config.pretrained_path = model_pth
+ config['model']['vision_encoder']['pretrained'] = model_pth
+ intern_model, tokenizer = setup_internvideo2(config)
+ self.viclip_tokenizer = tokenizer
+ self.viclip = intern_model
+ self.viclip.device = device
+ self.viclip.to(self.viclip.device)
+ self.viclip.eval()
+ self.viclip.n_frames = 8
+ self.viclip.preprocess_transf = vision_transf.Compose([
+ vision_transf.Resize(size=(224, 224), interpolation=vision_transf.InterpolationMode.BILINEAR),
+ vision_transf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+ sys.path.pop(0)
+ sys.path.pop(0)
+ else:
+ raise NotImplementedError(f"Model {self._model} not implemented")
+
+ vid_feat = self.viclip.get_vid_features(torch.zeros(1,self.viclip.n_frames,3,224,224, device=self.viclip.device))
+ self.viclip_emb_dim = vid_feat.shape[1]
+
+
+def report_text2video(agent, data,):
+ report = {}
+
+ domain = agent.cfg.task.split('_')[0]
+ labels_list = DOMAIN2PREDICATES[domain]
+
+ wm = world_model = agent.wm
+ decoder = world_model.heads['decoder'] # B, T, C, H, W
+ connector = agent.wm.connector
+ n_frames = connector.n_frames
+
+ if hasattr(world_model, 'viclip_model'):
+ clip = world_model.viclip_model
+ else:
+ # Get ViCLIP
+ viclip_global_instance = globals()['viclip_global_instance']
+ if not viclip_global_instance._instantiated:
+ viclip_global_instance.instantiate()
+ clip = viclip_global_instance.viclip
+
+ # Get text(video) embed
+ text_feat = []
+ for text in labels_list:
+ with torch.no_grad():
+ text_feat.append(clip.get_txt_feat(text,))
+ text_feat = torch.stack(text_feat, dim=0)
+ # Check device is right
+ video_embed = text_feat.to(agent.device)
+ B = video_embed.shape[0]
+
+ # Get actions
+ video_embed = video_embed.repeat(1,n_frames, 1)
+ # Imagine
+ prior = wm.connector.video_imagine(video_embed, dreamer_init=None, sample=False, reset_every_n_frames=False, denoise=True)
+ prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5
+ report[f'text_to_video'] = prior_recon
+ return report
+
+def max_cosine_similarity(u, v, dim=-1):
+ max_norm = torch.max(torch.norm(u, dim=dim), torch.norm(v, dim=dim)).unsqueeze(-1)
+ return torch.sum((u / max_norm) * (v / max_norm), dim=dim)
+
+def neg_mse_fn(a, b, dim=-1, scale=True):
+ dist = - torch.norm(a - b, dim=dim)
+ if scale:
+ dist = dist / np.sqrt(a.shape[-1]).item()
+ return dist
+
+def compute_reward(agent, agent_seq, target_seq, score_fn='cosine',):
+ if score_fn in ['cosine', 'max_cosine', 'neg_mse', 'exp_neg_mse']:
+ distance_fn = dict(cosine=F.cosine_similarity, max_cosine=max_cosine_similarity, neg_mse=neg_mse_fn, exp_neg_mse=neg_mse_fn)[score_fn]
+ target_stoch = agent.wm.connector.get_stoch( target_seq )
+ agent_stoch = agent.wm.rssm.get_stoch( agent_seq )
+ conv_target = agent.wm.heads['decoder']._conv_in[0](target_stoch)
+ conv_agent = agent.wm.heads['decoder']._conv_in[0](agent_stoch)
+ reward = distance_fn(conv_target, conv_agent, dim=-1)
+ if score_fn == 'exp_neg_mse':
+ reward = torch.exp(reward)
+ elif score_fn == 'neg_kl':
+ agent_dist = agent.wm.rssm.get_dist( agent_seq )
+ target_dist = agent.wm.connector.get_dist( target_seq )
+ reward = -torch.distributions.kl_divergence(agent_dist, target_dist,)
+ # scaling factor ( x log x w.r.t. to classes, or just x)
+ if 'logit' in target_seq:
+ reward = reward / ( np.log(target_seq['logit'].shape[-1]) * target_seq['logit'].shape[-2] )
+ else:
+ reward = reward / target_seq['mean'].shape[-1]
+ elif score_fn == 'max_like':
+ agent_dist = agent.wm.rssm.get_dist( agent_seq )
+ target_sample = target_seq['stoch']
+ reward = agent_dist.log_prob(target_sample)
+ elif score_fn == 'combo':
+ return compute_reward(agent, agent_seq, target_seq, 'cosine') + compute_reward(agent, agent_seq, target_seq, 'neg_kl')
+ else:
+ raise NotImplementedError(f"{score_fn} reward not implemented")
+ return reward
+
+def video_text_reward(agent, seq, score_fn='cosine',
+ sample_for_target=False, weighted_align=False, align_initial=False, align_sequence=False,
+ task_prompt='', skip_first_target=False, **kwargs):
+ wm = world_model = agent.wm
+ connector = agent.wm.connector
+ n_frames = connector.n_frames
+
+ T, B = seq['deter'].shape[:2]
+ imagined_steps = T
+
+ if not hasattr(agent, 'unconditional_target'):
+ if hasattr(world_model, 'viclip_model'):
+ clip = world_model.viclip_model
+ else:
+ # Get ViCLIP
+ viclip_global_instance = globals()['viclip_global_instance']
+ if not viclip_global_instance._instantiated:
+ viclip_global_instance.instantiate()
+ clip = viclip_global_instance.viclip
+
+ if task_prompt != '':
+ task = [task_prompt]
+ else:
+ task = [ TASK2PROMPT[agent.cfg.task] ]
+
+ # Get text(video) embed
+ with torch.no_grad():
+ text_feat = clip.get_txt_feat(task[0],)
+ # Check device is right
+ video_embed = text_feat.to(agent.device)
+
+ # Unconditional gen
+ if skip_first_target:
+ video_embed = video_embed.reshape(1, 1, -1).repeat(B, imagined_steps + 1, 1)
+ unconditional_stats = wm.connector.video_imagine(video_embed, dreamer_init=None, sample=sample_for_target, reset_every_n_frames=False, denoise=True)
+ unconditional_stats = { k: v[:,1:].permute([1,0] + list(range(2, len(v.shape)))) for k,v in unconditional_stats.items() }
+ else:
+ video_embed = video_embed.reshape(1, 1, -1).repeat(B, imagined_steps, 1)
+ unconditional_stats = wm.connector.video_imagine(video_embed, dreamer_init=None, sample=sample_for_target, reset_every_n_frames=False, denoise=True)
+ unconditional_stats = { k: v.permute([1,0] + list(range(2, len(v.shape)))) for k,v in unconditional_stats.items() }
+ agent.unconditional_target = unconditional_stats
+ else:
+ unconditional_stats = agent.unconditional_target
+
+ agent_seq = seq
+ target_seq = unconditional_stats
+ if align_initial:
+ assert not align_sequence, 'Cannot align initial and sequence at the same time'
+ init_seq = { k: v[0] for k,v in target_seq.items() }
+ init_score = compute_reward(agent, agent_seq, init_seq, score_fn=score_fn,)
+ if weighted_align:
+ w = 0.99 * torch.ones_like(init_score, device=init_score.device)
+ w = torch.cumprod(w, dim=1)
+ init_score = w * init_score
+ #
+ best_indexes_one_hot = F.one_hot(torch.argmax(init_score, dim=0), num_classes=target_seq['stoch'].shape[0])
+ ts_idx = torch.clip(torch.cumsum(torch.cumsum(best_indexes_one_hot, dim=1), dim=1) - 1, min=0).T
+ new_target_seq = {}
+ for k,v in target_seq.items():
+ if len(v.shape) == 4:
+ new_ts = ts_idx.unsqueeze(-1).unsqueeze(-1).repeat(1,1, v.shape[-2], v.shape[-1])
+ else:
+ new_ts = ts_idx.unsqueeze(-1).repeat(1,1, v.shape[-1])
+ new_target_seq[k] = torch.gather(v, 0, new_ts) # out[i][j][k] = input[index[i][j][k]][j][k]
+ return compute_reward(agent, agent_seq, new_target_seq, score_fn=score_fn,).unsqueeze(-1)
+ elif align_sequence:
+ align_score = []
+ get_prev_a_b = lambda d, a, b : { k : v[a:b] for k,v in d.items() }
+ shorter_target_seq = get_prev_a_b(unconditional_stats, 0, n_frames)
+ for t in range(T-n_frames):
+ cur_agent_seq = get_prev_a_b(seq, t, t+n_frames)
+ score = compute_reward(agent, cur_agent_seq, shorter_target_seq, score_fn=score_fn,).mean(dim=0) # 0 is time dimension
+ align_score.append(score)
+ align_score = torch.stack(align_score, dim=0)
+ if weighted_align:
+ w = 0.99 * torch.ones_like(align_score, device=align_score.device)
+ w = torch.cumprod(w, dim=1)
+ align_score = w * align_score
+ best_indexes_one_hot = F.one_hot(torch.argmax(align_score, dim=0), num_classes=target_seq['stoch'].shape[0])
+ ts_idx = torch.clip(torch.cumsum(torch.cumsum(best_indexes_one_hot, dim=1), dim=1) - 1, min=0).T
+ new_target_seq = {}
+ for k,v in target_seq.items():
+ if len(v.shape) == 4:
+ new_ts = ts_idx.unsqueeze(-1).unsqueeze(-1).repeat(1,1, v.shape[-2], v.shape[-1])
+ else:
+ new_ts = ts_idx.unsqueeze(-1).repeat(1,1, v.shape[-1])
+ new_target_seq[k] = torch.gather(v, 0, new_ts) # out[i][j][k] = input[index[i][j][k]][j][k]
+ return compute_reward(agent, agent_seq, new_target_seq, score_fn=score_fn,).unsqueeze(-1)
+ else:
+ neg_kl = compute_reward(agent, agent_seq, target_seq, score_fn=score_fn,)
+
+ return neg_kl.unsqueeze(-1)
+
+global viclip_global_instance
+viclip_global_instance = ViCLIPGlobalInstance()
\ No newline at end of file
diff --git a/tools/logger.py b/tools/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc1378f30ee143aa756bf06c0b64c2231bf15b69
--- /dev/null
+++ b/tools/logger.py
@@ -0,0 +1,236 @@
+import csv
+import datetime
+from collections import defaultdict
+
+import numpy as np
+import torch
+import torchvision
+import wandb
+from termcolor import colored
+from torch.utils.tensorboard import SummaryWriter
+
+COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
+ ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
+ ('episode_reward', 'R', 'float'),
+ ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')]
+
+COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
+ ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
+ ('episode_reward', 'R', 'float'),
+ ('total_time', 'T', 'time')]
+
+
+class AverageMeter(object):
+ def __init__(self):
+ self._sum = 0
+ self._count = 0
+
+ def update(self, value, n=1):
+ self._sum += value
+ self._count += n
+
+ def value(self):
+ return self._sum / max(1, self._count)
+
+
+class MetersGroup(object):
+ def __init__(self, csv_file_name, formating, use_wandb):
+ self._csv_file_name = csv_file_name
+ self._formating = formating
+ self._meters = defaultdict(AverageMeter)
+ self._csv_file = None
+ self._csv_writer = None
+ self.use_wandb = use_wandb
+
+ def log(self, key, value, n=1):
+ self._meters[key].update(value, n)
+
+ def _prime_meters(self):
+ data = dict()
+ for key, meter in self._meters.items():
+ if key.startswith('train'):
+ key = key[len('train') + 1:]
+ else:
+ key = key[len('eval') + 1:]
+ key = key.replace('/', '_')
+ data[key] = meter.value()
+ return data
+
+ def _remove_old_entries(self, data):
+ rows = []
+ with self._csv_file_name.open('r') as f:
+ reader = csv.DictReader(f)
+ for row in reader:
+ if 'episode' in row:
+ # BUGFIX: covers weird cases where CSV are badly written
+ if row['episode'] == '':
+ rows.append(row)
+ continue
+ if type(row['episode']) == type(None):
+ continue
+ if float(row['episode']) >= data['episode']:
+ break
+ rows.append(row)
+ with self._csv_file_name.open('w') as f:
+ # To handle CSV that have more keys than new data
+ keys = set(data.keys())
+ if len(rows) > 0: keys = keys | set(row.keys())
+ keys = sorted(list(keys))
+ #
+ writer = csv.DictWriter(f,
+ fieldnames=keys,
+ restval=0.0)
+ writer.writeheader()
+ for row in rows:
+ writer.writerow(row)
+
+ def _dump_to_csv(self, data):
+ if self._csv_writer is None:
+ should_write_header = True
+ if self._csv_file_name.exists():
+ self._remove_old_entries(data)
+ should_write_header = False
+
+ self._csv_file = self._csv_file_name.open('a')
+ self._csv_writer = csv.DictWriter(self._csv_file,
+ fieldnames=sorted(data.keys()),
+ restval=0.0)
+ if should_write_header:
+ self._csv_writer.writeheader()
+
+ # To handle components that start training later
+ # (restval covers only when data has less keys than the CSV)
+ if self._csv_writer.fieldnames != sorted(data.keys()) and \
+ len(self._csv_writer.fieldnames) < len(data.keys()):
+ self._csv_file.close()
+ self._csv_file = self._csv_file_name.open('r')
+ dict_reader = csv.DictReader(self._csv_file)
+ rows = [row for row in dict_reader]
+ self._csv_file.close()
+ self._csv_file = self._csv_file_name.open('w')
+ self._csv_writer = csv.DictWriter(self._csv_file,
+ fieldnames=sorted(data.keys()),
+ restval=0.0)
+ self._csv_writer.writeheader()
+ for row in rows:
+ self._csv_writer.writerow(row)
+
+ self._csv_writer.writerow(data)
+ self._csv_file.flush()
+
+ def _format(self, key, value, ty):
+ if ty == 'int':
+ value = int(value)
+ return f'{key}: {value}'
+ elif ty == 'float':
+ return f'{key}: {value:.04f}'
+ elif ty == 'time':
+ value = str(datetime.timedelta(seconds=int(value)))
+ return f'{key}: {value}'
+ else:
+ raise f'invalid format type: {ty}'
+
+ def _dump_to_console(self, data, prefix):
+ prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
+ pieces = [f'| {prefix: <14}']
+ for key, disp_key, ty in self._formating:
+ value = data.get(key, 0)
+ pieces.append(self._format(disp_key, value, ty))
+ print(' | '.join(pieces))
+
+ def _dump_to_wandb(self, data):
+ wandb.log(data)
+
+ def dump(self, step, prefix):
+ if len(self._meters) == 0:
+ return
+ data = self._prime_meters()
+ data['frame'] = step
+ if self.use_wandb:
+ wandb_data = {prefix + '/' + key: val for key, val in data.items()}
+ self._dump_to_wandb(data=wandb_data)
+ # self._dump_to_csv(data)
+ self._dump_to_console(data, prefix)
+ self._meters.clear()
+
+
+class Logger(object):
+ def __init__(self, log_dir, use_tb, use_wandb):
+ self._log_dir = log_dir
+ self._train_mg = MetersGroup(log_dir / 'train.csv',
+ formating=COMMON_TRAIN_FORMAT,
+ use_wandb=use_wandb)
+ self._eval_mg = MetersGroup(log_dir / 'eval.csv',
+ formating=COMMON_EVAL_FORMAT,
+ use_wandb=use_wandb)
+ if use_tb:
+ self._sw = SummaryWriter(str(log_dir / 'tb'))
+ else:
+ self._sw = None
+ self.use_wandb = use_wandb
+
+ def _try_sw_log(self, key, value, step):
+ if self._sw is not None:
+ self._sw.add_scalar(key, value, step)
+
+ def log(self, key, value, step):
+ assert key.startswith('train') or key.startswith('eval')
+ if type(value) == torch.Tensor:
+ value = value.item()
+ self._try_sw_log(key, value, step)
+ mg = self._train_mg if key.startswith('train') else self._eval_mg
+ mg.log(key, value)
+
+ def log_metrics(self, metrics, step, ty):
+ for key, value in metrics.items():
+ self.log(f'{ty}/{key}', value, step)
+
+ def dump(self, step, ty=None):
+ if ty is None or ty == 'eval':
+ self._eval_mg.dump(step, 'eval')
+ if ty is None or ty == 'train':
+ self._train_mg.dump(step, 'train')
+
+ def log_and_dump_ctx(self, step, ty):
+ return LogAndDumpCtx(self, step, ty)
+
+ def log_visual(self, data, step):
+ if self._sw is not None:
+ for k, v in data.items():
+ if len(v.shape) == 3:
+ self._sw.add_image(k, v)
+ else:
+ if len(v.shape) == 4:
+ v = np.expand_dims(v, axis=0)
+ self._sw.add_video(k, v, global_step=step, fps=15)
+ if self.use_wandb:
+ for k, v in data.items():
+ if type(v) is not np.ndarray:
+ v = v.cpu()
+ if v.dtype not in [np.uint8]:
+ v = v*255
+ v = np.uint8(v)
+ if len(v.shape) == 3:
+ if v.shape[0] == 3:
+ v = v.transpose(1,2,0)
+ # Note: defaulting to save only one image/video to save storage on wandb
+ wandb.log({k: wandb.Image(v)},)
+ else:
+ # Note: defaulting to save only one image/video to save storage on wandb
+ wandb.log({k: wandb.Video(v, fps=15, format="gif")},)
+
+
+class LogAndDumpCtx:
+ def __init__(self, logger, step, ty):
+ self._logger = logger
+ self._step = step
+ self._ty = ty
+
+ def __enter__(self):
+ return self
+
+ def __call__(self, key, value):
+ self._logger.log(f'{self._ty}/{key}', value, self._step)
+
+ def __exit__(self, *args):
+ self._logger.dump(self._step, self._ty)
diff --git a/tools/replay.py b/tools/replay.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9646fb6b63f11d0f0bfc0765d08ddefcbf6e42a
--- /dev/null
+++ b/tools/replay.py
@@ -0,0 +1,319 @@
+import collections
+import datetime
+import io
+import pathlib
+import uuid
+import os
+
+import numpy as np
+from gym.spaces import Dict
+import random
+from torch.utils.data import IterableDataset, DataLoader
+import torch
+import tools.utils as utils
+import traceback
+from pathlib import Path
+from tqdm import tqdm
+
+SIG_FAILURE = -1
+
+def get_length(filename):
+ if "-" in str(filename):
+ length = int(str(filename).split('-')[-1])
+ else:
+ length = int(str(filename).split('_')[-1])
+ return length
+
+def get_idx(filename):
+ if "-" in str(filename):
+ length = int(str(filename).split('-')[0])
+ else:
+ length = int(str(filename).split('_')[0])
+ return length
+
+def on_fn(): return collections.defaultdict(list) # this function is to avoid lambdas
+
+class ReplayBuffer(IterableDataset):
+
+ def __init__(
+ self, data_specs, meta_specs, directory, length=20, capacity=0, ongoing=False, minlen=1, maxlen=0,
+ prioritize_ends=False, device='cuda', load_first=False, save_episodes=True, ignore_extra_keys=False, load_recursive=False, min_t_sampling=0, **kwargs):
+ self._directory = pathlib.Path(directory).expanduser()
+ self._directory.mkdir(parents=True, exist_ok=True)
+ self._capacity = capacity
+ self._ongoing = ongoing
+ self._minlen = minlen
+ self._maxlen = maxlen
+ self._prioritize_ends = prioritize_ends
+ self._ignore_extra_keys = ignore_extra_keys
+ self._min_t_sampling = min_t_sampling
+ # self._random = np.random.RandomState()
+ # filename -> key -> value_sequence
+
+ self._save_episodes = save_episodes
+ self._last_added_idx = 0
+
+ self._episode_lens = np.array([])
+ self._complete_eps = {}
+ self._data_specs = data_specs
+ self._meta_specs = meta_specs
+ for spec_group in [data_specs, meta_specs]:
+ for spec in spec_group:
+ if type(spec) in [dict, Dict]:
+ for k,v in spec.items():
+ self._complete_eps[k] = []
+ else:
+ self._complete_eps[spec.name] = []
+
+ # load episodes
+ if type(directory) == str:
+ directory = Path(directory)
+ self._loaded_episodes = 0
+ self._loaded_steps = 0
+ for f in tqdm(load_filenames(self._directory, capacity, minlen, load_first=load_first, load_recursive=load_recursive)):
+ self.store_episode(filename=f)
+ try:
+ self._total_episodes, self._total_steps = count_episodes(directory)
+ except:
+ print("Couldn't count episodes")
+ print("Loaded episodes: ", self._loaded_episodes)
+ print("Loaded steps: ", self._loaded_steps)
+ self._total_episodes, self._total_steps = self._loaded_episodes, self._loaded_steps
+
+ # worker -> key -> value_sequence
+ self._length = length
+ self._ongoing_eps = collections.defaultdict(on_fn)
+ self.device = device
+ try:
+ assert self._minlen <= self._length <= self._maxlen
+ except:
+ print("Sampling sequences with fixed length ", length)
+ self._minlen = self._maxlen = self._length = length
+
+ def __len__(self):
+ return self._total_steps
+
+ def preallocate_memory(self, max_size):
+ self._preallocated_mem = collections.defaultdict(list)
+ for spec in self._data_specs:
+ if type(spec) in [dict, Dict]:
+ for k,v in spec.items():
+ for _ in range(max_size):
+ self._preallocated_mem[k].append(np.empty(list(v.shape), v.dtype))
+ self._preallocated_mem[k][-1].fill(0.)
+ else:
+ for _ in range(max_size):
+ self._preallocated_mem[spec.name].append(np.empty(list(v.shape), v.dtype))
+ self._preallocated_mem[spec.name][-1].fill(0.)
+
+ @property
+ def stats(self):
+ return {
+ 'total_steps': self._total_steps,
+ 'total_episodes': self._total_episodes,
+ 'loaded_steps': self._loaded_steps,
+ 'loaded_episodes': self._loaded_episodes,
+ }
+
+ def add(self, time_step, meta, idx=0):
+ ### Useful if there was any failure in the environment
+ if time_step == SIG_FAILURE:
+ episode = self._ongoing_eps[idx]
+ episode.clear()
+ print("Discarding episode from process", idx)
+ return
+ ####
+
+ episode = self._ongoing_eps[idx]
+
+ def add_to_episode(name, data, spec):
+ value = data[name]
+ if np.isscalar(value):
+ value = np.full(spec.shape, value, spec.dtype)
+ assert spec.shape == value.shape and spec.dtype == value.dtype, f"for ({name}) expected {spec.dtype, spec.shape, }), received ({value.dtype, value.shape, })"
+ ### Deallocate preallocated memory
+ if getattr(self, '_preallocated_mem', False):
+ if len(self._preallocated_mem[name]) > 0:
+ tmp = self._preallocated_mem[name].pop()
+ del tmp
+ else:
+ # Out of pre-allocated memory
+ del self._preallocated_mem
+ ###
+ episode[name].append(value)
+
+ for spec in self._data_specs:
+ if type(spec) in [dict, Dict]:
+ for k,v in spec.items():
+ add_to_episode(k, time_step, v)
+ else:
+ add_to_episode(spec.name, time_step, spec)
+ for spec in self._meta_specs:
+ if type(spec) in [dict, Dict]:
+ for k,v in spec.items():
+ add_to_episode(k, meta, v)
+ else:
+ add_to_episode(spec.name, meta, spec)
+ if type(time_step) in [dict, Dict]:
+ if time_step['is_last']:
+ self.add_episode(episode)
+ episode.clear()
+ else:
+ if time_step.last():
+ self.add_episode(episode)
+ episode.clear()
+
+ def add_episode(self, episode):
+ length = eplen(episode)
+ if length < self._minlen:
+ print(f'Skipping short episode of length {length}.')
+ return
+ self._total_steps += length
+ self._total_episodes += 1
+ episode = {key: convert(value) for key, value in episode.items()}
+ if self._save_episodes:
+ filename = self.save_episode(self._directory, episode)
+ self.store_episode(episode=episode)
+
+ def store_episode(self, filename=None, episode=None, run_checks=True):
+ if filename is not None:
+ episode = load_episode(filename)
+ if len(episode['reward'].shape) == 1:
+ episode['reward'] = episode['reward'].reshape(-1, 1)
+ if 'discount' not in episode:
+ episode['discount'] = (1 - episode['is_terminal']).reshape(-1, 1).astype(np.float32)
+ #
+ if run_checks:
+ for spec_set in [self._data_specs, self._meta_specs]:
+ for spec in spec_set:
+ if type(spec) in [dict, Dict]:
+ for k,v in spec.items():
+ value = episode[k][0]
+ assert v.shape == value.shape and v.dtype == value.dtype, f"for ({k}) expected {v.dtype, v.shape, }), received ({value.dtype, value.shape, })"
+ else:
+ value = episode[spec.name][0]
+ assert spec.shape == value.shape and spec.dtype == value.dtype, f"for ({spec.name}) expected {spec.dtype, spec.shape, }), received ({value.dtype, value.shape, })"
+ if not episode:
+ return False
+ length = eplen(episode)
+ if run_checks:
+ for k in episode:
+ assert len(episode[k]) == length, f'Found {episode[k].shape} VS eplen: {length}'
+
+ # Enforce limit
+ while self._loaded_steps + length > self._capacity:
+ for k in self._complete_eps:
+ self._complete_eps[k].pop(0)
+ removed_len, self._episode_lens = self._episode_lens[0], self._episode_lens[1:]
+ self._loaded_steps -= removed_len
+ self._loaded_episodes -= 1
+
+ # add episode
+ for k,v in episode.items():
+ if k not in self._complete_eps:
+ if self._ignore_extra_keys: continue
+ else: raise KeyError("Extra key ", k)
+ self._complete_eps[k].append(v)
+ self._episode_lens = np.append(self._episode_lens, length)
+ self._loaded_steps += length
+ self._loaded_episodes += 1
+
+ return True
+
+ def __iter__(self):
+ while True:
+ sequences, batch_size, batch_length = self._loaded_episodes, self.batch_size, self._length
+
+ b_indices = np.random.randint(0, sequences, size=batch_size)
+ t_indices = np.random.randint(np.zeros(batch_size) + self._min_t_sampling, self._episode_lens[b_indices]-batch_length+1, size=batch_size)
+ t_ranges = np.repeat( np.expand_dims(np.arange(0, batch_length,), 0), batch_size, axis=0) + np.expand_dims(t_indices, 1)
+
+ chunk = {}
+ for k in self._complete_eps:
+ chunk[k] = np.stack([self._complete_eps[k][b][t] for b,t in zip(b_indices, t_ranges)])
+ for k in chunk:
+ chunk[k] = torch.as_tensor(chunk[k], device=self.device)
+ yield chunk
+
+ @utils.retry
+ def save_episode(self, directory, episode):
+ idx = self._total_episodes
+ timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
+ identifier = str(uuid.uuid4().hex)
+ length = eplen(episode)
+ filename = directory / f'{idx}-{timestamp}-{identifier}-{length}.npz'
+ with io.BytesIO() as f1:
+ np.savez_compressed(f1, **episode)
+ f1.seek(0)
+ with filename.open('wb') as f2:
+ f2.write(f1.read())
+ return filename
+
+def load_episode(filename):
+ try:
+ with filename.open('rb') as f:
+ episode = np.load(f, allow_pickle=True)
+ episode = {k: episode[k] for k in episode.keys()}
+ except Exception as e:
+ print(f'Could not load episode {str(filename)}: {e}')
+ return False
+ return episode
+
+def count_episodes(directory):
+ filenames = list(directory.glob('*.npz'))
+ num_episodes = len(filenames)
+ if num_episodes == 0 : return 0, 0
+ if len(filenames) > 0 and "-" in str(filenames[0]):
+ num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames)
+ last_episode = sorted(list(int(n.stem.split('-')[0]) for n in filenames))[-1]
+ else:
+ num_steps = sum(int(str(n).split('_')[-1][:-4]) - 1 for n in filenames)
+ last_episode = sorted(list(int(n.stem.split('_')[0]) for n in filenames))[-1]
+ return last_episode, num_steps
+
+def load_filenames(directory, capacity=None, minlen=1, load_first=False, load_recursive=False):
+ # The returned directory from filenames to episodes is guaranteed to be in
+ # temporally sorted order.
+ if load_recursive:
+ filenames = sorted(directory.glob('**/*.npz'))
+ else:
+ filenames = sorted(directory.glob('*.npz'))
+ if capacity:
+ num_steps = 0
+ num_episodes = 0
+ ordered_filenames = filenames if load_first else reversed(filenames)
+ for filename in ordered_filenames:
+ if "-" in str(filename):
+ length = int(str(filename).split('-')[-1][:-4])
+ else:
+ length = int(str(filename).split('_')[-1][:-4])
+ num_steps += length
+ num_episodes += 1
+ if num_steps >= capacity:
+ break
+ if load_first:
+ filenames = filenames[:num_episodes]
+ else:
+ filenames = filenames[-num_episodes:]
+ return filenames
+
+def convert(value):
+ value = np.array(value)
+ if np.issubdtype(value.dtype, np.floating):
+ return value.astype(np.float32)
+ elif np.issubdtype(value.dtype, np.signedinteger):
+ return value.astype(np.int32)
+ elif np.issubdtype(value.dtype, np.uint8):
+ return value.astype(np.uint8)
+ return value
+
+def eplen(episode):
+ return len(episode['action'])
+
+def make_replay_loader(buffer, batch_size,):
+ buffer.batch_size = batch_size
+ return DataLoader(buffer,
+ batch_size=None,
+ # NOTE: do not use any workers,
+ # as they don't get copies of the replay buffer (requires different implementation)
+ )
\ No newline at end of file
diff --git a/tools/task_scores.py b/tools/task_scores.py
new file mode 100644
index 0000000000000000000000000000000000000000..b72c2ab611d1ca19684b3ca9a9d20b026e1bcad0
--- /dev/null
+++ b/tools/task_scores.py
@@ -0,0 +1,85 @@
+MAX = {
+ 'walker_run' : 770,
+ 'walker_walk' : 960,
+ 'walker_stand' : 970,
+ 'quadruped_run' : 930,
+ 'quadruped_walk' : 960,
+ 'quadruped_stand' : 990,
+ 'kitchen_microwave' : 1,
+ 'kitchen_light' : 1,
+ 'kitchen_burner' : 1,
+ 'kitchen_slide' : 1,
+ 'stickman_run' : 830,
+ 'stickman_walk' : 960,
+ 'stickman_stand' : 970,
+ 'cheetah_run' : 890,
+ 'walker_one_foot' : 955,
+ 'walker_high_kick' : 960,
+ 'walker_lying_down' : 975,
+ 'walker_sit_knees' : 945,
+ 'walker_lunge_pose' : 945,
+ 'walker_flipping' : 720,
+ 'walker_urlb_flip' : 720,
+ 'quadruped_jump' : 875,
+ 'quadruped_two_legs' : 875,
+ 'quadruped_lie_down' : 965,
+ 'stickman_flipping' : 790,
+ 'stickman_one_foot' : 865,
+ 'stickman_high_kick' : 920,
+ 'stickman_lying_down' : 965,
+ 'stickman_legs_up' : 935,
+ 'stickman_sit_knees' : 966,
+ 'stickman_lunge_pose' : 950,
+ 'stickman_headstand' : 955,
+ 'stickman_boxing' : 920,
+ 'stickman_hands_up' : 830,
+ 'cheetah_standing' : 930,
+ 'cheetah_lying_down' : 920,
+ 'jaco_reach_bottom_right' : 230,
+ 'jaco_reach_top_right' : 230,
+ 'jaco_reach_bottom_left' : 230,
+ 'jaco_reach_top_left' : 230,
+}
+
+MIN = {
+ 'walker_run' : 30,
+ 'walker_walk' : 45,
+ 'walker_stand' : 150,
+ 'quadruped_run' : 10,
+ 'quadruped_walk' : 10,
+ 'quadruped_stand' : 15,
+ 'kitchen_microwave' : 0,
+ 'kitchen_light' : 0,
+ 'kitchen_burner' : 0,
+ 'kitchen_slide' : 0,
+ 'stickman_run' : 25,
+ 'stickman_walk' : 35,
+ 'stickman_stand' : 70,
+ 'cheetah_run' : 9,
+ 'walker_one_foot' : 20,
+ 'walker_high_kick' : 25,
+ 'walker_lying_down' : 170,
+ 'walker_sit_knees' : 100,
+ 'walker_lunge_pose' : 150,
+ 'walker_flipping' : 20,
+ 'walker_urlb_flip' : 20,
+ 'quadruped_jump' : 15,
+ 'quadruped_two_legs' : 14,
+ 'quadruped_lie_down' : 750,
+ 'stickman_flipping' : 45,
+ 'stickman_one_foot' : 20,
+ 'stickman_high_kick' : 55,
+ 'stickman_lying_down' : 380,
+ 'stickman_legs_up' : 115,
+ 'stickman_sit_knees' : 40,
+ 'stickman_lunge_pose' : 100,
+ 'stickman_headstand' : 180,
+ 'stickman_boxing' : 80,
+ 'stickman_hands_up' : 5,
+ 'cheetah_standing' : 5,
+ 'cheetah_lying_down' : 430,
+ 'jaco_reach_bottom_right' : 0,
+ 'jaco_reach_top_right' : 0,
+ 'jaco_reach_bottom_left' : 0,
+ 'jaco_reach_top_left' : 0,
+}
diff --git a/tools/utils.py b/tools/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..45e3583b79a38163b08c96ec716c731ea704fdc9
--- /dev/null
+++ b/tools/utils.py
@@ -0,0 +1,253 @@
+import math
+import random
+import time
+from functools import wraps
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import distributions as pyd
+from torch.distributions.utils import _standard_normal
+from collections.abc import MutableMapping
+
+class eval_mode:
+ def __init__(self, *models):
+ self.models = models
+
+ def __enter__(self):
+ self.prev_states = []
+ for model in self.models:
+ self.prev_states.append(model.training)
+ model.train(False)
+
+ def __exit__(self, *args):
+ for model, state in zip(self.models, self.prev_states):
+ model.train(state)
+ return False
+
+
+def set_seed_everywhere(seed):
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def soft_update_params(net, target_net, tau):
+ for param, target_param in zip(net.parameters(), target_net.parameters()):
+ target_param.data.copy_(tau * param.data +
+ (1 - tau) * target_param.data)
+
+
+def hard_update_params(net, target_net):
+ for param, target_param in zip(net.parameters(), target_net.parameters()):
+ target_param.data.copy_(param.data)
+
+
+def weight_init(m):
+ """Custom weight init for Conv2D and Linear layers."""
+ if isinstance(m, nn.Linear):
+ nn.init.orthogonal_(m.weight.data)
+ if hasattr(m.bias, 'data'):
+ m.bias.data.fill_(0.0)
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ gain = nn.init.calculate_gain('relu')
+ nn.init.orthogonal_(m.weight.data, gain)
+ if hasattr(m.bias, 'data'):
+ m.bias.data.fill_(0.0)
+
+class Until:
+ def __init__(self, until, action_repeat=1):
+ self._until = until
+ self._action_repeat = action_repeat
+
+ def __call__(self, step):
+ if self._until is None:
+ return True
+ until = self._until // self._action_repeat
+ return step < until
+
+
+class Every:
+ def __init__(self, every, action_repeat=1):
+ self._every = every
+ self._action_repeat = action_repeat
+
+ def __call__(self, step):
+ if self._every is None:
+ return False
+ every = self._every // self._action_repeat
+ if step % every == 0:
+ return True
+ return False
+
+
+class Timer:
+ def __init__(self):
+ self._start_time = time.time()
+ self._last_time = time.time()
+
+ def reset(self):
+ elapsed_time = time.time() - self._last_time
+ self._last_time = time.time()
+ total_time = time.time() - self._start_time
+ return elapsed_time, total_time
+
+ def total_time(self):
+ return time.time() - self._start_time
+
+
+class TruncatedNormal(pyd.Normal):
+ def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
+ super().__init__(loc, scale, validate_args=False)
+ self.low = low
+ self.high = high
+ self.eps = eps
+
+ def _clamp(self, x):
+ clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
+ x = x - x.detach() + clamped_x.detach()
+ return x
+
+ def sample(self, sample_shape=torch.Size(), stddev_clip=None):
+ shape = self._extended_shape(sample_shape)
+ eps = _standard_normal(shape,
+ dtype=self.loc.dtype,
+ device=self.loc.device)
+ eps *= self.scale
+ if stddev_clip is not None:
+ eps = torch.clamp(eps, -stddev_clip, stddev_clip)
+ x = self.loc + eps
+ return self._clamp(x)
+
+
+class TanhTransform(pyd.transforms.Transform):
+ domain = pyd.constraints.real
+ codomain = pyd.constraints.interval(-1.0, 1.0)
+ bijective = True
+ sign = +1
+
+ def __init__(self, cache_size=1):
+ super().__init__(cache_size=cache_size)
+
+ @staticmethod
+ def atanh(x):
+ return 0.5 * (x.log1p() - (-x).log1p())
+
+ def __eq__(self, other):
+ return isinstance(other, TanhTransform)
+
+ def _call(self, x):
+ return x.tanh()
+
+ def _inverse(self, y):
+ # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
+ # one should use `cache_size=1` instead
+ return self.atanh(y)
+
+ def log_abs_det_jacobian(self, x, y):
+ # We use a formula that is more numerically stable, see details in the following link
+ # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
+ return 2. * (math.log(2.) - x - F.softplus(-2. * x))
+
+
+class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
+ def __init__(self, loc, scale):
+ self.loc = loc
+ self.scale = scale
+
+ self.base_dist = pyd.Normal(loc, scale)
+ transforms = [TanhTransform()]
+ super().__init__(self.base_dist, transforms)
+
+ @property
+ def mean(self):
+ mu = self.loc
+ for tr in self.transforms:
+ mu = tr(mu)
+ return mu
+
+def retry(func):
+ """
+ A Decorator to retry a function for a certain amount of attempts
+ """
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ attempts = 0
+ max_attempts = 1000
+ while attempts < max_attempts:
+ try:
+ return func(*args, **kwargs)
+ except (OSError, PermissionError):
+ attempts += 1
+ time.sleep(0.1)
+ raise OSError("Retry failed")
+
+ return wrapper
+
+def flatten_dict(dictionary, parent_key='', separator='_'):
+ items = []
+ for key in dictionary.keys():
+ try:
+ value = dictionary[key]
+ except:
+ value = '??? '
+ new_key = parent_key + separator + key if parent_key else key
+ if isinstance(value, MutableMapping):
+ items.extend(flatten_dict(value, new_key, separator=separator).items())
+ else:
+ items.append((new_key, value))
+ return dict(items)
+
+def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+ '''
+ Spherical linear interpolation
+ Args:
+ t (float/np.ndarray): Float value between 0.0 and 1.0
+ v0 (np.ndarray): Starting vector
+ v1 (np.ndarray): Final vector
+ DOT_THRESHOLD (float): Threshold for considering the two vectors as
+ colineal. Not recommended to alter this.
+ Returns:
+ v2 (np.ndarray): Interpolation vector between v0 and v1
+ '''
+ c = False
+ if not isinstance(v0,np.ndarray):
+ c = True
+ v0 = v0.detach().cpu().numpy()
+ if not isinstance(v1,np.ndarray):
+ c = True
+ v1 = v1.detach().cpu().numpy()
+ if len(v0.shape) == 1:
+ v0 = v0.reshape(1, -1)
+ if len(v1.shape) == 1:
+ v1 = v1.reshape(1, -1)
+ # Copy the vectors to reuse them later
+ v0_copy = np.copy(v0)
+ v1_copy = np.copy(v1)
+ # Normalize the vectors to get the directions and angles
+ v0 = v0 / np.linalg.norm(v0, axis=-1, keepdims=True)
+ v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
+ # Dot product with the normalized vectors (can't use np.dot in W)
+ dot = np.sum(v0 * v1, axis=-1)
+ # If absolute value of dot product is almost 1, vectors are ~colineal, so use lerp
+ if (np.abs(dot) > DOT_THRESHOLD).any():
+ raise NotImplementedError('lerp not implemented') # return lerp(t, v0_copy, v1_copy)
+ # Calculate initial angle between v0 and v1
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ # Angle at timestep t
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ # Finish the slerp algorithm
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0.reshape(-1, 1) * v0_copy + s1.reshape(-1, 1) * v1_copy
+ if c:
+ res = torch.from_numpy(v2).to("cuda")
+ else:
+ res = v2
+ return res
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c3c71332b50eff6d56e76d5b0f9d8bfdea3fca2
--- /dev/null
+++ b/train.py
@@ -0,0 +1,452 @@
+import warnings
+
+warnings.filterwarnings('ignore', category=DeprecationWarning)
+
+import os
+
+os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
+
+from pathlib import Path
+from collections import defaultdict
+
+import hydra
+import numpy as np
+import torch
+import wandb
+from dm_env import specs
+
+import tools.utils as utils
+from tools.logger import Logger
+from tools.replay import ReplayBuffer, make_replay_loader
+
+torch.backends.cudnn.benchmark = True
+
+def make_agent(obs_type, obs_spec, action_spec, num_expl_steps, cfg):
+ cfg.obs_type = obs_type
+ cfg.obs_shape = obs_spec.shape
+ cfg.action_shape = action_spec.shape
+ cfg.num_expl_steps = num_expl_steps
+ return hydra.utils.instantiate(cfg)
+
+
+def make_dreamer_agent(obs_space, action_spec, cur_config, cfg):
+ from copy import deepcopy
+ cur_config = deepcopy(cur_config)
+ if hasattr(cur_config, 'agent'):
+ del cur_config.agent
+ return hydra.utils.instantiate(cfg, cfg=cur_config, obs_space=obs_space, act_spec=action_spec)
+
+class Workspace:
+ def __init__(self, cfg, savedir=None, workdir=None,):
+ self.workdir = Path.cwd() if workdir is None else workdir
+ print(f'workspace: {self.workdir}')
+
+ self.cfg = cfg
+
+ utils.set_seed_everywhere(cfg.seed)
+ self.device = torch.device(cfg.device)
+
+ # create logger
+ self.logger = Logger(self.workdir,
+ use_tb=cfg.use_tb,
+ use_wandb=cfg.use_wandb)
+ # create envs
+ self.task = task = cfg.task
+ img_size = cfg.img_size
+
+ import envs.main as envs
+ self.train_env = envs.make(task, cfg.obs_type, cfg.action_repeat, cfg.seed, img_size=img_size, viclip_encode=cfg.viclip_encode, clip_hd_rendering=cfg.clip_hd_rendering)
+
+
+ # # create agent
+ sample_agent = make_dreamer_agent(self.train_env.obs_space, self.train_env.act_space['action'], cfg, cfg.agent)
+
+ # create replay buffer
+ data_specs = (self.train_env.obs_space,
+ self.train_env.act_space,
+ specs.Array((1,), np.float32, 'reward'),
+ specs.Array((1,), np.float32, 'discount'))
+
+ if cfg.train_from_data:
+ # Loading replay buffer
+ if cfg.replay_from_wandb_project is not None:
+ api = wandb.Api()
+ project_name = cfg.replay_from_wandb_project
+ params2search = {
+ "task" : cfg.task if cfg.task_snapshot is None else cfg.task_snapshot,
+ "seed" : cfg.seed if cfg.seed_snapshot is None else cfg.seed_snapshot,
+ }
+ runs = api.runs(f"PUT_YOUR_USER_HERE/{project_name}")
+ found = False
+ for run in runs:
+ if np.all([ v == run.config.get(k, None) for k,v in params2search.items()]):
+ found = True
+ found_path = Path(run.config['workdir'].replace('/code', ''))
+ break
+ if not found:
+ raise Exception("Replay from wandb buffer not found")
+
+ replay_dir = found_path / 'code' / 'buffer'
+ else:
+ replay_dir = Path(cfg.replay_load_dir)
+
+ # create data storage
+ self.replay_storage = ReplayBuffer(data_specs, [],
+ replay_dir,
+ length=cfg.batch_length, **cfg.replay,
+ device=cfg.device, ignore_extra_keys=True, load_recursive=True)
+ print('Loaded ', self.replay_storage._loaded_episodes, 'episodes from ', str(replay_dir))
+
+ # create replay buffer
+ self.replay_loader = make_replay_loader(self.replay_storage,
+ cfg.batch_size,)
+ self._replay_iter = None
+
+ # Loading snapshot
+ if cfg.snapshot_from_wandb_project is not None:
+ api = wandb.Api()
+ project_name = cfg.snapshot_from_wandb_project
+ params2search = {
+ "task" : cfg.task if cfg.task_snapshot is None else cfg.task_snapshot,
+ "agent_name" : cfg.agent.name if cfg.agent_name_snapshot is None else cfg.agent_name_snapshot,
+ "seed" : cfg.seed if cfg.seed_snapshot is None else cfg.seed_snapshot,
+ }
+ if cfg.agent.clip_lafite_noise > 0.:
+ params2search['clip_lafite_noise'] = cfg.agent.clip_lafite_noise
+ if cfg.agent.clip_add_noise > 0.:
+ params2search['clip_add_noise'] = cfg.agent.clip_add_noise
+ if cfg.reset_connector:
+ del params2search['clip_add_noise']
+ runs = api.runs(f"PUT_YOUR_USER_HERE/{project_name}")
+ found = False
+ for run in runs:
+ if np.all([ v == run.config.get(k, None) for k,v in params2search.items()]):
+ found = True
+ found_path = Path(run.config['workdir'].replace('/code', ''))
+ break
+ if not found:
+ raise Exception("Snapshot from wandb not found")
+
+ if cfg.snapshot_step is None:
+ snapshot_dir = found_path / 'code' / 'last_snapshot.pt'
+ else:
+ snapshot_dir = found_path / 'code' / f'snapshot_{cfg.snapshot_step}.pt'
+ elif cfg.snapshot_load_dir is not None:
+ snapshot_dir = Path(cfg.snapshot_load_dir)
+ else:
+ snapshot_dir = None
+
+ if snapshot_dir is not None:
+ self.load_snapshot(snapshot_dir, resume=False)
+ if self.cfg.reset_world_model:
+ self.agent.wm = sample_agent.wm
+ # To reset optimization
+ from agent import dreamer_utils as common
+ self.agent.wm.model_opt = common.Optimizer('model', self.agent.wm.parameters(), **self.agent.wm.cfg.model_opt, use_amp=self.agent.wm._use_amp)
+ if self.cfg.reset_connector:
+ self.agent.wm.connector = sample_agent.wm.connector
+ # To reset optimization
+ from agent import dreamer_utils as common
+ self.agent.wm.model_opt = common.Optimizer('model', self.agent.wm.parameters(), **self.agent.wm.cfg.model_opt, use_amp=self.agent.wm._use_amp)
+
+ # overwriting cfg
+ self.agent.cfg = sample_agent.cfg
+ self.agent.wm.cfg = sample_agent.wm.cfg
+
+ if self.cfg.reset_imag_behavior:
+ self.agent.instantiate_imag_behavior()
+ else:
+ self.agent = sample_agent
+
+ self.eval_env = envs.make(self.task, self.cfg.obs_type, self.cfg.action_repeat, self.cfg.seed, img_size=64, )
+ if hasattr(self.eval_env, 'eval_mode'):
+ self.eval_env.eval_mode()
+ eval_specs = (self.eval_env.obs_space,
+ self.eval_env.act_space,
+ specs.Array((1,), np.float32, 'reward'),
+ specs.Array((1,), np.float32, 'discount'))
+ self.eval_storage = ReplayBuffer(eval_specs, {},
+ self.workdir / 'eval_buffer',
+ length=cfg.batch_length, **cfg.replay,
+ device=cfg.device, ignore_extra_keys=True,)
+ self.eval_storage._minlen = 1
+
+ self.timer = utils.Timer()
+ self._global_step = 0
+ self._global_episode = 0
+
+ @property
+ def global_step(self):
+ return self._global_step
+
+ @property
+ def global_episode(self):
+ return self._global_episode
+
+ @property
+ def global_frame(self):
+ return self.global_step * self.cfg.action_repeat
+
+ @property
+ def replay_iter(self):
+ if self._replay_iter is None:
+ self._replay_iter = iter(self.replay_loader)
+ return self._replay_iter
+
+ def eval(self):
+ import envs.main as envs
+ eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
+ episode_reward = []
+ while eval_until_episode(len(episode_reward)):
+ if len(episode_reward) > 0 and self.global_step == 0:
+ return
+ episode_reward.append(0)
+ step, episode = 0, defaultdict(list)
+ meta = self.agent.init_meta()
+ time_step, dreamer_obs = self.eval_env.reset()
+ data = dreamer_obs
+ if 'clip_video' in data:
+ del data['clip_video']
+ self.eval_storage.add(data, meta)
+ agent_state = None
+ while not time_step.last():
+ with torch.no_grad(), utils.eval_mode(self.agent):
+ action, agent_state = self.agent.act(dreamer_obs,
+ meta,
+ self.global_step,
+ eval_mode=True,
+ state=agent_state)
+ time_step, dreamer_obs = self.eval_env.step(action)
+ for k in dreamer_obs:
+ episode[k].append(dreamer_obs[k])
+ episode_reward[-1] += time_step.reward
+ if time_step.last():
+ if episode_reward[-1] == np.max(episode_reward):
+ best_episode = {**episode}
+ if episode_reward[-1] == np.min(episode_reward):
+ worst_episode = {**episode}
+ data = dreamer_obs
+ if 'clip_video' in data:
+ del data['clip_video']
+ self.eval_storage.add(data, meta)
+ step += 1
+
+ if self.global_step > 0 and self.global_frame % self.cfg.log_episodes_every_frames == 0:
+ # B, T, C, H, W = video.shape
+ videos = {'best_episode' : np.stack(best_episode['observation'], axis=0),
+ 'worst_episode' : np.stack(worst_episode['observation'], axis=0),}
+ self.logger.log_visual(videos, self.global_frame)
+
+ with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
+ log('episode_reward', np.mean(episode_reward))
+ log('episode_length', step * self.cfg.action_repeat)
+ log('episode', self.global_episode)
+ log('step', self.global_step)
+
+ def eval_imag_behavior(self,):
+ self.agent._backup_acting_behavior = self.agent._acting_behavior
+ self.agent._acting_behavior = self.agent._imag_behavior
+ self.eval()
+ self.agent._acting_behavior = self.agent._backup_acting_behavior
+
+ def train(self):
+ # predicates
+ train_until_step = utils.Until(self.cfg.num_train_frames, 1)
+ eval_every_step = utils.Every(self.cfg.eval_every_frames, 1)
+ should_log_scalars = utils.Every(self.cfg.log_every_frames, 1)
+ should_save_model = utils.Every(self.cfg.save_every_frames, 1)
+ should_log_visual = utils.Every(self.cfg.visual_every_frames, 1)
+ metrics = None
+ while train_until_step(self.global_step):
+ # try to evaluate
+ if eval_every_step(self.global_step):
+ if self.cfg.eval_modality == 'task':
+ self.eval()
+ if self.cfg.eval_modality == 'task_imag':
+ self.eval_imag_behavior()
+ if self.cfg.eval_modality == 'from_text':
+ self.logger.log('eval_total_time', self.timer.total_time(), self.global_frame)
+ self.eval_from_text()
+
+ if self.cfg.train_from_data:
+ # Sampling data
+ batch_data = next(self.replay_iter)
+ if self.cfg.train_world_model:
+ state, outputs, metrics = self.agent.update_wm(batch_data, self.global_step)
+ else:
+ with torch.no_grad():
+ outputs, metrics = self.agent.wm.observe_data(batch_data,)
+ if self.cfg.train_connector:
+ _, metrics = self.agent.wm.update_additional_detached_modules(batch_data, outputs, metrics)
+ else:
+ imag_warmup_steps = self.cfg.imag_warmup_steps
+ metrics, batch_data = {}, None
+ with torch.no_grad():
+ # fake actions
+ mix = self.cfg.mix_random_actions
+ random = False
+ # num warmup steps
+
+ if mix:
+ init = self.agent.wm.rssm.initial(self.cfg.batch_size * (self.cfg.batch_length // 2))
+ else:
+ init = self.agent.wm.rssm.initial(self.cfg.batch_size * self.cfg.batch_length)
+
+
+ unif_dist = self.agent.wm.rssm.get_unif_dist(init)
+ if 'logit' in init:
+ init['logit'] = unif_dist.mean
+ else:
+ init['mean'] = unif_dist.mean
+ init['std'] = unif_dist.std
+ init['stoch'] = unif_dist.sample()
+
+ if self.cfg.start_from_video in [True, 'mix']:
+ T = self.agent.wm.connector.n_frames * 2 # should this be an hyperparam?
+ B = init['deter'].shape[0] // T
+ text_feat_dim = self.agent.wm.connector.viclip_emb_dim
+ video_embed = torch.randn((B, T, text_feat_dim), device=self.agent.device)
+ video_embed = torch.nn.functional.normalize(video_embed, dim=-1)
+ # Get initial state
+ video_init = self.agent.wm.connector.video_imagine(video_embed, dreamer_init=None, sample=True, reset_every_n_frames=False, denoise=True)
+ video_init = { k : v.reshape(B * T, *v.shape[2:]) for k, v in video_init.items()}
+ if self.cfg.start_from_video == 'mix':
+ probs = torch.rand((B * T, 1,1), device=init['stoch'].device) > 0.5 # should this be an hyperparam?
+ init['stoch'] = (probs * init['stoch']) + ( (~probs) * video_init['stoch'] )
+ else:
+ init['stoch'] = video_init['stoch']
+
+ if random:
+ fake_action = torch.rand(self.cfg.batch_size * self.cfg.batch_length, imag_warmup_steps, self.agent.act_dim, device=self.agent.device) * 2 - 1
+ post = self.agent.wm.rssm.imagine(fake_action, init, sample=True)
+ post = { k : v[:, -1].reshape([self.cfg.batch_size, self.cfg.batch_length, ] + list(v.shape[2:])) for k,v in post.items() }
+ elif mix:
+ fake_action = torch.rand(self.cfg.batch_size * self.cfg.batch_length // 2, imag_warmup_steps, self.agent.act_dim, device=self.agent.device) * 2 - 1
+ post1 = self.agent.wm.rssm.imagine(fake_action, init, sample=True)
+ post1 = { k : v[:, -1].reshape([self.cfg.batch_size, self.cfg.batch_length // 2, ] + list(v.shape[2:])) for k,v in post1.items() }
+
+ init2 = { k : v.reshape([self.cfg.batch_size, self.cfg.batch_length // 2, ] + list(v.shape[1:])) for k,v in init.items() }
+ post2 = self.agent.wm.imagine(self.agent._imag_behavior.actor, init2, None, imag_warmup_steps)
+ post2 = { k : v[-1, :].reshape([self.cfg.batch_size, self.cfg.batch_length // 2, ] + list(v.shape[2:])) for k,v in post2.items() }
+ post = { k: torch.cat([post1[k], post2[k]], dim=1) for k in post1 }
+ else:
+ init = { k : v.reshape([self.cfg.batch_size, self.cfg.batch_length, ] + list(v.shape[1:])) for k,v in init.items() }
+ post = self.agent.wm.imagine(self.agent._imag_behavior.actor, init, None, imag_warmup_steps)
+ post = { k : v[-1, :].reshape([self.cfg.batch_size, self.cfg.batch_length, ] + list(v.shape[2:])) for k,v in post.items() }
+
+ is_terminal = torch.zeros(self.cfg.batch_size, self.cfg.batch_length, device=self.agent.device)
+ outputs = dict(post=post, is_terminal=is_terminal)
+ if getattr(self.cfg.agent, 'imag_reward_fn', None) is not None:
+ metrics.update(self.agent.update_imag_behavior(state=None, outputs=outputs, metrics=metrics, seq_data=batch_data,)[1])
+
+ if self.global_step > 0:
+ if should_log_scalars(self.global_step):
+ if hasattr(self, 'replay_storage'):
+ metrics.update(self.replay_storage.stats)
+ self.logger.log_metrics(metrics, self.global_frame, ty='train')
+ if should_log_visual(self.global_step) and self.cfg.train_from_data and hasattr(self.agent, 'report'):
+ with torch.no_grad(), utils.eval_mode(self.agent):
+ videos = self.agent.report(next(self.replay_iter))
+ self.logger.log_visual(videos, self.global_frame)
+ if should_log_scalars(self.global_step):
+ elapsed_time, total_time = self.timer.reset()
+ with self.logger.log_and_dump_ctx(self.global_frame, ty='train') as log:
+ log('fps', self.cfg.log_every_frames / elapsed_time)
+ log('step', self.global_step)
+ if 'model_loss' in metrics:
+ log('episode_reward', metrics['model_loss'].item())
+
+ # save last model
+ if should_save_model(self.global_step):
+ self.save_last_model()
+
+ self._global_step += 1
+ # == 1000 is to make sure everything is going well since the start
+ if (self.global_frame == 1000) or (self.global_frame % self.cfg.snapshot_every_frames == 0):
+ self.save_snapshot()
+
+ @utils.retry
+ def save_snapshot(self):
+ snapshot = self.root_dir / f'snapshot_{self.global_frame}.pt'
+ keys_to_save = ['agent', '_global_step', '_global_episode']
+ payload = {k: self.__dict__[k] for k in keys_to_save}
+ with snapshot.open('wb') as f:
+ torch.save(payload, f)
+
+ def setup_wandb(self):
+ cfg = self.cfg
+ exp_name = '_'.join([
+ cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type,
+ str(cfg.seed)
+ ])
+ wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name)
+ flat_cfg = utils.flatten_dict(cfg)
+ wandb.config.update(flat_cfg)
+ self.wandb_run_id = wandb.run.id
+
+ @utils.retry
+ def save_last_model(self):
+ snapshot = self.root_dir / 'last_snapshot.pt'
+ if snapshot.is_file():
+ temp = Path(str(snapshot).replace("last_snapshot.pt", "second_last_snapshot.pt"))
+ os.replace(snapshot, temp)
+ keys_to_save = ['agent', '_global_step', '_global_episode']
+ if self.cfg.use_wandb:
+ keys_to_save.append('wandb_run_id')
+ payload = {k: self.__dict__[k] for k in keys_to_save}
+ with snapshot.open('wb') as f:
+ torch.save(payload, f)
+
+ @utils.retry
+ def load_snapshot(self, snapshot_dir, resume=True):
+ print('Loading snapshot from: ', str(snapshot_dir))
+ try:
+ snapshot = snapshot_dir / 'last_snapshot.pt' if resume else snapshot_dir
+ with snapshot.open('rb') as f:
+ payload = torch.load(f)
+ except:
+ snapshot = Path(str(snapshot_dir).replace('last_snapshot', 'second_last_snapshot'))
+ with snapshot.open('rb') as f:
+ payload = torch.load(f)
+ if type(payload) != dict:
+ self.agent = payload
+ self.agent.requires_grad_(requires_grad=False)
+ return
+ for k,v in payload.items():
+ setattr(self, k, v)
+ if k == 'wandb_run_id' and resume:
+ assert wandb.run is None
+ cfg = self.cfg
+ exp_name = '_'.join([
+ cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type,
+ str(cfg.seed)
+ ])
+ wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name, id=v, resume="must")
+
+ def get_snapshot_dir(self):
+ snap_dir = self.cfg.snapshot_dir
+ snapshot_dir = self.workdir / Path(snap_dir)
+ snapshot_dir.mkdir(exist_ok=True, parents=True)
+ return snapshot_dir
+
+def start_training(cfg, savedir, workdir):
+ from train import Workspace as W
+ root_dir = Path.cwd()
+ cfg.workdir = str(root_dir)
+ workspace = W(cfg, savedir, workdir)
+ workspace.root_dir = root_dir
+ snapshot = workspace.root_dir / 'last_snapshot.pt'
+ if snapshot.exists():
+ print(f'resuming: {snapshot}')
+ workspace.load_snapshot(workspace.root_dir)
+ if cfg.use_wandb and wandb.run is None:
+ # otherwise it was resumed
+ workspace.setup_wandb()
+ workspace.train()
+
+@hydra.main(config_path='.', config_name='train')
+def main(cfg):
+ start_training(cfg, None, None)
+
+if __name__ == '__main__':
+ main()
diff --git a/train.yaml b/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..38581ea9dddd9e3c81b414cf05c70faa907bd92f
--- /dev/null
+++ b/train.yaml
@@ -0,0 +1,86 @@
+defaults:
+ - agent: genrl
+ - conf/defaults: genrl
+ - conf/env: dmc_pixels
+ - override hydra/launcher: submitit_local
+
+# mode
+label: default
+# task settings
+task: stickman_walk
+# train settings
+num_train_frames: 200_010
+num_seed_frames: 400
+# eval
+eval_every_frames: 5_000
+eval_modality: null
+num_eval_episodes: 10
+# snapshot
+snapshot_dir: ../../../models/${obs_type}/${task}/${agent.name}/${seed}
+snapshot_every_frames: 50_000
+save_every_frames: 1000
+# misc
+seed: 1
+device: cuda:0
+use_tb: true
+use_wandb: true
+
+# Clip stuff
+viclip_encode: true
+viclip_model: internvideo2
+clip_add_noise: 0.0
+clip_hd_rendering: false
+
+# experiment
+experiment: train
+project_name: genrl
+
+# log settings
+log_every_frames: 1_000
+visual_every_frames: 100000000 # edit for debug
+log_episodes_every_frames: 100_000
+
+workdir: ???
+
+# training models
+train_connector: false
+train_world_model: false
+
+reset_connector: false
+reset_world_model: false
+reset_imag_behavior: true
+
+# loading options
+replay_from_wandb_project: null
+snapshot_from_wandb_project: null
+
+task_snapshot: null
+seed_snapshot: null
+agent_name_snapshot: null
+
+snapshot_load_dir: null
+replay_load_dir: null
+
+#
+snapshot_step: null
+
+# data-free
+train_from_data: true
+start_from_video: mix
+mix_random_actions: True
+imag_warmup_steps: 5
+
+hydra:
+ run:
+ dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${agent.name}
+ sweep:
+ dir: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}
+ subdir: ${hydra.job.num}
+ launcher:
+ timeout_min: 4300
+ cpus_per_task: 10
+ gpus_per_node: 1
+ tasks_per_node: 1
+ mem_gb: 160
+ nodes: 1
+ submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}/.slurm