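"""Dataset wrappers that pre-generate motions with a trained text-to-motion model (Comp-V6 or MDM) and serve them for evaluation."""
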
import numpy as np
import torch
import torch.nn as nn
from data_loaders.humanml.networks.modules import *
from data_loaders.humanml.networks.trainers import CompTrainerV6
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
from utils import dist_util
def build_models(opt):
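    """Build the Comp-V6 text-to-motion sub-networks from opt and load the pretrained movement-length estimator."""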
if opt.text_enc_mod == 'bigru':
text_encoder = TextEncoderBiGRU(word_size=opt.dim_word,
pos_size=opt.dim_pos_ohot,
hidden_size=opt.dim_text_hidden,
device=opt.device)
text_size = opt.dim_text_hidden * 2
else:
        raise ValueError('Text encoder mode not recognized: %s' % opt.text_enc_mod)
seq_prior = TextDecoder(text_size=text_size,
input_size=opt.dim_att_vec + opt.dim_movement_latent,
output_size=opt.dim_z,
hidden_size=opt.dim_pri_hidden,
n_layers=opt.n_layers_pri)
seq_decoder = TextVAEDecoder(text_size=text_size,
input_size=opt.dim_att_vec + opt.dim_z + opt.dim_movement_latent,
output_size=opt.dim_movement_latent,
hidden_size=opt.dim_dec_hidden,
n_layers=opt.n_layers_dec)
att_layer = AttLayer(query_dim=opt.dim_pos_hidden,
key_dim=text_size,
value_dim=opt.dim_att_vec)
movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
movement_dec = MovementConvDecoder(opt.dim_movement_latent, opt.dim_movement_dec_hidden, opt.dim_pose)
len_estimator = MotionLenEstimatorBiGRU(opt.dim_word, opt.dim_pos_ohot, 512, opt.num_classes)
# latent_dis = LatentDis(input_size=opt.dim_z * 2)
checkpoints = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_est_bigru', 'model', 'latest.tar'), map_location=opt.device)
len_estimator.load_state_dict(checkpoints['estimator'])
len_estimator.to(opt.device)
len_estimator.eval()
# return text_encoder, text_decoder, att_layer, vae_pri, vae_dec, vae_pos, motion_dis, movement_dis, latent_dis
return text_encoder, seq_prior, seq_decoder, att_layer, movement_enc, movement_dec, len_estimator
class CompV6GeneratedDataset(Dataset):
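    """Motions pre-generated by a trained Comp-V6 text-to-motion model, served in the format expected by the T2M evaluators."""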
def __init__(self, opt, dataset, w_vectorizer, mm_num_samples, mm_num_repeats):
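        """Generate one motion per dataset sample; mm_num_samples of them are generated mm_num_repeats times for the multimodality evaluation."""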
assert mm_num_samples < len(dataset)
print(opt.model_dir)
dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
text_enc, seq_pri, seq_dec, att_layer, mov_enc, mov_dec, len_estimator = build_models(opt)
trainer = CompTrainerV6(opt, text_enc, seq_pri, seq_dec, att_layer, mov_dec, mov_enc=mov_enc)
epoch, it, sub_ep, schedule_len = trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))
generated_motion = []
mm_generated_motions = []
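        # Randomly pick which dataset samples will be generated repeatedly (the multimodality / 'mm' samples)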
mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
mm_idxs = np.sort(mm_idxs)
min_mov_length = 10 if opt.dataset_name == 't2m' else 6
# print(mm_idxs)
print('Loading model: Epoch %03d Schedule_len %03d' % (epoch, schedule_len))
trainer.eval_mode()
trainer.to(opt.device)
with torch.no_grad():
for i, data in tqdm(enumerate(dataloader)):
word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
tokens = tokens[0].split('_')
word_emb = word_emb.detach().to(opt.device).float()
pos_ohot = pos_ohot.detach().to(opt.device).float()
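                # Predict a categorical distribution over movement lengths (in units of opt.unit_length frames)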
pred_dis = len_estimator(word_emb, pos_ohot, cap_lens)
pred_dis = nn.Softmax(-1)(pred_dis).squeeze()
mm_num_now = len(mm_generated_motions)
                is_mm = (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
repeat_times = mm_num_repeats if is_mm else 1
mm_motions = []
for t in range(repeat_times):
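                    # Sample a movement length from the predicted distribution; resample (at most twice) if it falls below the minimum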
mov_length = torch.multinomial(pred_dis, 1, replacement=True)
if mov_length < min_mov_length:
mov_length = torch.multinomial(pred_dis, 1, replacement=True)
if mov_length < min_mov_length:
mov_length = torch.multinomial(pred_dis, 1, replacement=True)
m_lens = mov_length * opt.unit_length
pred_motions, _, _ = trainer.generate(word_emb, pos_ohot, cap_lens, m_lens,
m_lens[0]//opt.unit_length, opt.dim_pose)
if t == 0:
# print(m_lens)
# print(text_data)
sub_dict = {'motion': pred_motions[0].cpu().numpy(),
'length': m_lens[0].item(),
'cap_len': cap_lens[0].item(),
'caption': caption[0],
'tokens': tokens}
generated_motion.append(sub_dict)
if is_mm:
mm_motions.append({
'motion': pred_motions[0].cpu().numpy(),
'length': m_lens[0].item()
})
if is_mm:
mm_generated_motions.append({'caption': caption[0],
'tokens': tokens,
'cap_len': cap_lens[0].item(),
'mm_motions': mm_motions})
self.generated_motion = generated_motion
self.mm_generated_motion = mm_generated_motions
self.opt = opt
self.w_vectorizer = w_vectorizer
def __len__(self):
return len(self.generated_motion)
def __getitem__(self, item):
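        """Return one generated sample in the tuple format expected by the T2M evaluators."""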
data = self.generated_motion[item]
motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
sent_len = data['cap_len']
pos_one_hots = []
word_embeddings = []
for token in tokens:
word_emb, pos_oh = self.w_vectorizer[token]
pos_one_hots.append(pos_oh[None, :])
word_embeddings.append(word_emb[None, :])
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
word_embeddings = np.concatenate(word_embeddings, axis=0)
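        # Zero-pad the motion so every item has a fixed length of max_motion_length frames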
if m_length < self.opt.max_motion_length:
motion = np.concatenate([motion,
np.zeros((self.opt.max_motion_length - m_length, motion.shape[1]))
], axis=0)
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
class CompMDMGeneratedDataset(Dataset):
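    """Motions sampled from a trained MDM diffusion model, served in the same format as CompV6GeneratedDataset."""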
def __init__(self, model, diffusion, dataloader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale=1.):
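        """Pre-generate motions for the samples in dataloader (up to num_samples_limit); scale is the classifier-free guidance scale."""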
self.dataloader = dataloader
self.dataset = dataloader.dataset
assert mm_num_samples < len(dataloader.dataset)
use_ddim = False # FIXME - hardcoded
clip_denoised = False # FIXME - hardcoded
self.max_motion_length = max_motion_length
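        # Choose the diffusion sampling loop: the standard p_sample_loop by default, or DDIM when use_ddim is set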
sample_fn = (
diffusion.p_sample_loop if not use_ddim else diffusion.ddim_sample_loop
)
real_num_batches = len(dataloader)
if num_samples_limit is not None:
real_num_batches = num_samples_limit // dataloader.batch_size + 1
print('real_num_batches', real_num_batches)
generated_motion = []
mm_generated_motions = []
if mm_num_samples > 0:
            mm_idxs = np.random.choice(real_num_batches, mm_num_samples // dataloader.batch_size + 1, replace=False)
mm_idxs = np.sort(mm_idxs)
else:
mm_idxs = []
print('mm_idxs', mm_idxs)
model.eval()
with torch.no_grad():
for i, (motion, model_kwargs) in tqdm(enumerate(dataloader)):
if num_samples_limit is not None and len(generated_motion) >= num_samples_limit:
break
tokens = [t.split('_') for t in model_kwargs['y']['tokens']]
                # add the classifier-free guidance (CFG) scale to the batch
if scale != 1.:
model_kwargs['y']['scale'] = torch.ones(motion.shape[0],
device=dist_util.dev()) * scale
mm_num_now = len(mm_generated_motions) // dataloader.batch_size
is_mm = i in mm_idxs
repeat_times = mm_num_repeats if is_mm else 1
mm_motions = []
for t in range(repeat_times):
sample = sample_fn(
model,
motion.shape,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
skip_timesteps=0, # 0 is the default value - i.e. don't skip any step
init_image=None,
progress=False,
dump_steps=None,
noise=None,
const_noise=False,
                        # when experimenting with guidance_scale we want to neutralize the effect of noise on generation
)
if t == 0:
sub_dicts = [{'motion': sample[bs_i].squeeze().permute(1,0).cpu().numpy(),
'length': model_kwargs['y']['lengths'][bs_i].cpu().numpy(),
'caption': model_kwargs['y']['text'][bs_i],
'tokens': tokens[bs_i],
'cap_len': len(tokens[bs_i]),
} for bs_i in range(dataloader.batch_size)]
generated_motion += sub_dicts
if is_mm:
mm_motions += [{'motion': sample[bs_i].squeeze().permute(1, 0).cpu().numpy(),
'length': model_kwargs['y']['lengths'][bs_i].cpu().numpy(),
} for bs_i in range(dataloader.batch_size)]
if is_mm:
mm_generated_motions += [{
'caption': model_kwargs['y']['text'][bs_i],
'tokens': tokens[bs_i],
'cap_len': len(tokens[bs_i]),
                        'mm_motions': mm_motions[bs_i::dataloader.batch_size],  # collect the mm_num_repeats generations of sample bs_i from the (batch_size * mm_num_repeats) motions
} for bs_i in range(dataloader.batch_size)]
self.generated_motion = generated_motion
self.mm_generated_motion = mm_generated_motions
self.w_vectorizer = dataloader.dataset.w_vectorizer
def __len__(self):
return len(self.generated_motion)
def __getitem__(self, item):
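        """Return one generated sample; in 'eval' mode the motion is re-normalized to the T2M evaluator statistics."""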
data = self.generated_motion[item]
motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
sent_len = data['cap_len']
if self.dataset.mode == 'eval':
normed_motion = motion
denormed_motion = self.dataset.t2m_dataset.inv_transform(normed_motion)
renormed_motion = (denormed_motion - self.dataset.mean_for_eval) / self.dataset.std_for_eval # according to T2M norms
motion = renormed_motion
# This step is needed because T2M evaluators expect their norm convention
pos_one_hots = []
word_embeddings = []
for token in tokens:
word_emb, pos_oh = self.w_vectorizer[token]
pos_one_hots.append(pos_oh[None, :])
word_embeddings.append(word_emb[None, :])
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
word_embeddings = np.concatenate(word_embeddings, axis=0)
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)