# -*- encoding: utf-8 -*-
'''
@File : pretrain_cogvideo.py
@Time : 2021/10/06 00:58:32
@Author : Wenyi Hong
@Contact : [email protected]
'''
# imports
import os
import sys
import math
import random
import torch
import argparse
import numpy as np
from icetk import icetk as tokenizer
tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
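# Note: the three special tokens registered above are later looked up by name,
# e.g. tokenizer['<start_of_image>'] in get_batch and process_fn below.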
from models.cogvideo_model import CogVideoModel
from SwissArmyTransformer import mpu, get_args
from SwissArmyTransformer.training.deepspeed_training import training_main
from SwissArmyTransformer.data_utils import BinaryDataset

def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
    # Extract batch size and sequence length.
    batch_size, seq_length = data.size()
    assert attention_mask_totxt is not None
    layout = args.layout
    assert seq_length == layout[-1]
    n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long()
    frame_len = layout[1] - layout[0]  # per-frame token count (unused here)
    position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long,
                               device=data.device)
    for i in range(batch_size):
        torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]],
                     dtype=torch.long, device=data.device)
        torch.arange(512, 512 + layout[2] - layout[0],
                     out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device)
    return position_ids
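
# Hedged illustration (a sketch under the default layout [64, 464, 2064]; the
# helper name is hypothetical and it is never called during training): this
# reproduces the position-id pattern above for a single sample. Pads stay at
# position 0, text positions count up from 0 after the pads, and frame tokens
# count up from a fixed offset of 512. Relies on this module's torch import.
def _sketch_position_ids(n_pads=2, layout=(64, 464, 2064)):
    ids = torch.zeros(layout[2], dtype=torch.long)
    # Text region: positions 0 .. layout[0]-n_pads-1, written after the pads.
    torch.arange(layout[0] - n_pads, out=ids[n_pads:layout[0]], dtype=torch.long)
    # Frame region: positions 512 .. 512 + layout[2] - layout[0] - 1.
    torch.arange(512, 512 + layout[2] - layout[0], out=ids[layout[0]:], dtype=torch.long)
    return ids
# e.g. _sketch_position_ids(2)[:6]   -> tensor([0, 0, 0, 1, 2, 3])
#      _sketch_position_ids(2)[64:67] -> tensor([512, 513, 514])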

def get_batch(data_iterator, args, timers):
    # Items and their type.
    keys = ['text', 'loss_mask', 'attention_mask_totxt']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    loss_mask = data_b['loss_mask'].float()
    attention_mask_totxt = data_b['attention_mask_totxt'].float()

    labels = tokens_[:, 1:].clone().contiguous()
    loss_mask = loss_mask[:, 1:].contiguous()
    tokens = tokens_[:, :-1].clone().contiguous()

    # Overwrite the input token at every 400-token frame stride with the
    # <start_of_image> separator.
    for idx in range(args.layout[0], args.layout[2], 400):
        tokens[:, idx] = tokenizer['<start_of_image>']

    # Get the masks and position ids.
    position_ids = get_masks_and_position_ids_video(
        tokens,
        attention_mask_totxt=attention_mask_totxt,
        args=args
    )
    # Add singleton head/query dims: [batch, text_len] -> [batch, 1, 1, text_len].
    attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1)

    # Convert the attention mask to fp16 to match the model's compute dtype.
    if args.fp16:
        attention_mask_totxt = attention_mask_totxt.half()
    return tokens, labels, loss_mask, attention_mask_totxt, position_ids
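
# Hedged note (assumes the default layout [64, 464, 2064]): the separator loop
# above overwrites `tokens` at
#   list(range(64, 2064, 400)) == [64, 464, 864, 1264, 1664]
# with <start_of_image>, i.e. one separator per 400-token frame stride, while
# `labels` is the unmodified sequence shifted left by one, so position j is
# trained to predict the original token at j + 1.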

def forward_step(data_iterator, model, args, timers):
    """Forward step."""
    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch(
        data_iterator, args, timers)
    timers('batch generator').stop()

    # Forward model.
    logits, *mems = model(tokens, position_ids, attention_mask_totxt)

    # ======= hyper params =======#
    perframe_len = 400
    text_len = 64
    frame_num = 5

    # Cross-entropy over image tokens only; the text prefix is not supervised.
    logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
    losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:])

    # Apply the loss mask and average over the supervised positions.
    loss_mask = loss_mask[:, text_len:].reshape(-1)
    losses_1d = losses.reshape(-1) * loss_mask
    loss = torch.sum(losses_1d) / loss_mask.sum()

    # ===================== Log partial losses ======================== #
    # Note: per-frame losses are logged before the loss mask is applied.
    log_loss_dict = {}
    bs = losses.shape[0]
    if args.cogvideo_stage == 1:
        for i in range(frame_num):
            log_loss_dict[f'AR_f{i}_loss'] = (
                losses[:, i*perframe_len:(i+1)*perframe_len]
                .contiguous().reshape(-1).detach().sum() / max(perframe_len * bs, 1)
            )
    else:
        for i in range(1, frame_num-1):
            log_loss_dict[f'ITP_f{i}_loss'] = (
                losses[:, i*perframe_len:(i+1)*perframe_len]
                .contiguous().reshape(-1).detach().sum() / max(perframe_len * bs, 1)
            )
    # ===================== END OF BLOCK ======================= #
    return loss, log_loss_dict
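
# Hedged arithmetic check (assumes layout [64, 464, 2064]): losses[:, k] scores
# the prediction of tokens_[text_len + k + 1], so the slice
#   losses[:, i*perframe_len:(i+1)*perframe_len]
# covers exactly the 400 tokens of frame i (frame 0 starts at tokens_[65], right
# after the <start_of_image> token at index 64).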

def create_dataset_function(path, args):
    dataset_layout = [64, 464, 2064]
    input_layout = [64, 464, 2064]
    # frame_num = 6
    # frame_interval = 2 # DEBUG!!!

    def process_fn(row):
        row = row.astype(np.int64)
        text = row[:dataset_layout[0]]
        frames = row[dataset_layout[0]:]

        if text[0] == tokenizer['<pad>']:
            text = text[1:]  # due to our way of data processing
        if args.cogvideo_stage == 1:
            text, loss_mask, frames = make_text_video_generation(text, frames)
        else:
            text, loss_mask, frames = mask_video_frame_interpolation(text, frames)

        n_pad = input_layout[0] - len(text)
        parts = [
            np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64),
            text,
            np.array([tokenizer['<start_of_image>']], dtype=np.int64),
            frames,
        ]
        ret = np.concatenate(parts, axis=0)
        attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0] - n_pad))
        return {'text': ret,
                'loss_mask': loss_mask,
                'attention_mask_totxt': attention_mask_totxt,
                }
    return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])
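
# Hedged length check (a sketch of process_fn's output shape under
# dataset_layout == [64, 464, 2064]; the helper name is hypothetical and is not
# wired into the training loop): the 'text' field is
#   n_pad + len(text) + 1 + 2000 == 64 + 1 + 2000 == 2065
# tokens, one longer than layout[-1], because get_batch drops one token when it
# splits the sequence into `tokens` (= ret[:, :-1]) and `labels` (= ret[:, 1:]).
def _sketch_check_sample(sample, layout=(64, 464, 2064)):
    assert len(sample['text']) == layout[2] + 1
    assert len(sample['loss_mask']) == layout[2] + 1
    assert len(sample['attention_mask_totxt']) == layout[0]
    return True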

def make_text_video_generation(text, frames):
    input_layout = [64, 464, 2064]
    text = text[text != tokenizer['<pad>']][:input_layout[0]]  # dataset text format: 1.0秒<n>{text}<pad><pad> ... ("秒" = seconds)
    # Laid out against the input sequence; get_batch shifts the loss mask left
    # by one position afterwards.
    loss_mask = np.array([0] * (input_layout[1] + 1) + [1] * (input_layout[2] - input_layout[1]))
    return text, loss_mask, frames

def mask_video_frame_interpolation(text, frames):
    input_layout = [64, 464, 2064]
    frame_len = input_layout[1] - input_layout[0]
    # text format: <pad> 1.0秒 <n> {text} <pad> <pad>   ("秒" = seconds)
    text = text[text != tokenizer['<pad>']][:input_layout[0]]
    # Laid out against the input sequence; get_batch shifts the loss mask left
    # by one position afterwards.
    loss_mask = np.array([0] * (input_layout[1] + 1)
                         + [1] * frame_len
                         + [0] * frame_len
                         + [1] * frame_len
                         + [0] * frame_len)
    return text, loss_mask, frames
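
# Hedged summary of the two masking schemes above (assumes five 400-token
# frames and the one-position left shift applied in get_batch):
#   stage 1 (make_text_video_generation):    frames 1-4 are supervised; the
#       text and frame 0 serve as conditioning only.
#   stage 2 (mask_video_frame_interpolation): frames 1 and 3 are supervised;
#       frames 0, 2 and 4 are the given key frames.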

if __name__ == '__main__':
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument('--txt-loss-scale', type=float, default=1)
    CogVideoModel.add_model_specific_args(py_parser)

    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    args.layout = [int(x) for x in args.layout.split(',')]

    training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step,
                  create_dataset_function=create_dataset_function)
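    # Usage note (hedged): this script is driven by SwissArmyTransformer's
    # training_main / DeepSpeed entry point and is normally launched through
    # the repository's pretraining shell scripts rather than run directly;
    # --layout is expected as a comma-separated string (e.g. "64,464,2064")
    # and is converted above into the integer list used throughout this file.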