|
import argparse |
|
import json |
|
import random |
|
from pathlib import Path |
|
|
|
import imageio |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from transformers import AutoModel |
|
from tqdm import tqdm |
|
|
|
|
|
|
|
# Size every frame is resized to, as (height, width).
IMAGE_SIZE = (288, 512)
# Maximum number of frames (context + newly generated) handled by one generate() call.
N_FRAMES_PER_ROUND = 25
# Upper bound accepted for --num_frames.
MAX_NUM_FRAMES = 50
# Number of tokens per frame. NOTE(review): presumably (288 // 16) * (512 // 16) = 18 * 32;
# confirm against the tokenizer configuration.
N_TOKENS_PER_FRAME = 576
# JSON file mapping each command to its "instruction_trajs" (one trajectory per frame).
TRAJ_TEMPLATE_PATH = Path("./assets/template_trajectory.json")
# Trajectory subsampling: start at point PATH_START_ID, then take every PATH_POINT_INTERVAL-th
# point, keeping at most N_ACTION_TOKENS points per frame.
PATH_START_ID, PATH_POINT_INTERVAL, N_ACTION_TOKENS = 9, 10, 6
# Model subfolder of "turing-motors/Terra" -> tokenizer subfolder it is paired with.
WM_TOKENIZER_COMBINATION = {"world_model": "lfq_tokenizer_B_256", "world_model_v2": "lfq_tokenizer_B_256_ema"}

# Frames that condition the first generation round, in load order.
CONDITIONING_FRAMES_DIR = Path("./assets/conditioning_frames")
CONDITIONING_FRAMES_PATH_LIST = [CONDITIONING_FRAMES_DIR / name for name in ("001.png", "002.png", "003.png")]
|
|
|
|
|
def set_random_seed(seed: int = 0):
    """Seed every random number generator the script uses.

    Seeds Python's ``random``, NumPy's global generator, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed`` with the same value (in that order), then asks cuDNN
    for deterministic algorithms.

    Args:
        seed: Value passed to every seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
|
|
|
|
|
def preprocess_image(image: Image.Image, size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    """Turn a PIL image into a normalized ``(1, C, H, W)`` float32 tensor.

    Args:
        image: Source image; converted to RGB before resizing.
        size: Target ``(height, width)``. PIL's ``resize`` expects ``(width, height)``,
            hence the swap below.

    Returns:
        Tensor with a leading batch axis of 1, channels first, and pixel values
        mapped from ``[0, 255]`` to ``[-1, 1]`` via ``p / 127.5 - 1``.
    """
    height, width = size
    rgb = image.convert("RGB").resize((width, height))
    normalized = (np.array(rgb) / 127.5 - 1.0).astype(np.float32)
    chw = torch.from_numpy(normalized).permute(2, 0, 1)
    return chw[None].float()
|
|
|
|
|
def to_np_images(images: torch.Tensor) -> np.ndarray:
    """Convert a batch of ``[-1, 1]`` image tensors into ``uint8`` channels-last frames.

    This is the inverse of ``preprocess_image``'s ``p / 127.5 - 1`` mapping.

    Args:
        images: Rank-4 tensor laid out as ``(N, C, H, W)`` (the permute below moves
            channels last). Values outside ``[-1, 1]`` are clipped. Any floating
            dtype is accepted; half/bfloat16 outputs from autocast are upcast first.

    Returns:
        C-contiguous ``uint8`` array of shape ``(N, H, W, C)`` with values in ``[0, 255]``.
    """
    # Upcast before leaving torch: NumPy has no bfloat16, and float32 keeps the
    # rescale precise enough for exact rounding.
    images = images.detach().float().cpu()
    images = (torch.clamp(images, -1.0, 1.0) + 1.0) / 2.0
    # Round instead of truncating: truncation maps e.g. 254.9999 to 254, which
    # darkens the output and breaks the exact round trip with preprocess_image.
    images = torch.round(images * 255.0).to(torch.uint8)
    return images.permute(0, 2, 3, 1).contiguous().numpy()
|
|
|
|
|
def load_images(file_path_list: list[Path], size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    """Load image files and concatenate them into one batch tensor.

    Each file is opened, preprocessed with ``preprocess_image`` (RGB, resized to
    ``size``, scaled to ``[-1, 1]``) and concatenated along the batch axis, in the
    order given.

    Args:
        file_path_list: Paths of the images to load.
        size: Target ``(height, width)`` forwarded to ``preprocess_image``.

    Returns:
        Tensor of shape ``(len(file_path_list), 3, height, width)``.
    """
    images = []
    for file_path in file_path_list:
        # Close the file handle deterministically instead of leaving it to the GC;
        # preprocess_image fully materializes the pixels before the block exits.
        with Image.open(file_path) as image:
            images.append(preprocess_image(image, size))
    return torch.cat(images, dim=0)
|
|
|
|
|
def save_images_to_mp4(images: np.ndarray, output_path: Path, fps: int = 10):
    """Write a sequence of frames to a video file.

    Args:
        images: Iterable of frames (e.g. the ``(N, H, W, C)`` uint8 array from
            ``to_np_images``); each element is passed to ``append_data`` as-is.
        output_path: Destination file; imageio picks the writer from its extension.
        fps: Frames per second of the output video.
    """
    # The context manager closes/finalizes the file even if a frame fails to append.
    with imageio.get_writer(output_path, fps=fps) as writer:
        for frame in images:
            writer.append_data(frame)
|
|
|
|
|
def determine_num_rounds(
    num_frames: int,
    num_overlapping_frames: int,
    n_initial_frames: int,
    n_frames_per_round: "int | None" = None,
) -> int:
    """Return how many generation rounds are needed to cover ``num_frames`` frames.

    Round 0 covers frames ``[0, n_frames_per_round)``; its first ``n_initial_frames``
    are the conditioning frames. Every later round starts
    ``n_frames_per_round - num_overlapping_frames`` frames after the previous one
    and re-uses the previous round's last ``num_overlapping_frames`` frames as
    context. So after ``R >= 1`` rounds,
    ``n_frames_per_round + (R - 1) * stride`` frames exist. The result is the
    smallest ``R`` reaching ``num_frames``, and the last round always adds at
    least one new frame.

    Args:
        num_frames: Total frames wanted, conditioning frames included.
        num_overlapping_frames: Context frames carried from one round to the next.
        n_initial_frames: Number of conditioning frames fed to round 0.
        n_frames_per_round: Frames per round; defaults to ``N_FRAMES_PER_ROUND``.

    Returns:
        Number of rounds, or 0 when ``num_frames <= n_initial_frames`` (nothing to generate).

    Raises:
        ValueError: If ``num_overlapping_frames >= n_frames_per_round``, since no
            round would then add new frames.
    """
    if n_frames_per_round is None:
        n_frames_per_round = N_FRAMES_PER_ROUND
    stride = n_frames_per_round - num_overlapping_frames
    if stride <= 0:
        raise ValueError(
            f"num_overlapping_frames ({num_overlapping_frames}) must be smaller than "
            f"n_frames_per_round ({n_frames_per_round})"
        )
    if num_frames <= n_initial_frames:
        return 0
    if num_frames <= n_frames_per_round:
        return 1
    # Round 0 plus ceil(remaining / stride) further rounds (ceil via negated floor division).
    return 1 + -(-(num_frames - n_frames_per_round) // stride)
|
|
|
|
|
def prepare_action(
    traj_template: dict,
    cmd: str,
    path_start_id: int,
    path_point_interval: int,
    n_action_tokens: int = 5,
    start_index: int = 0,
    n_frames: int = 25
) -> torch.Tensor:
    """Build the flattened action tensor for frames ``start_index .. start_index + n_frames - 1``.

    For each frame the trajectory ``traj_template[cmd]["instruction_trajs"][frame]`` is
    subsampled as ``[path_start_id::path_point_interval]`` and cut to
    ``n_action_tokens`` points. The first two coordinates of each point are emitted
    in swapped order (columns ``[1, 0]``). Each point is paired with the value at
    the same subsampled position of ``np.arange(0.0, 3.0, 0.05)``.

    Returns:
        Tensor of shape ``(n_frames * k, 3)``, where ``k`` is the number of points
        kept per frame (at most ``n_action_tokens``). The trajectory points and the
        time grid must yield the same ``k``, or the concatenation raises. The dtype
        follows NumPy promotion with the float64 time grid; the caller casts with
        ``.float()``.
    """
    instruction_trajs = traj_template[cmd]["instruction_trajs"]
    time_grid = np.arange(0.0, 3.0, 0.05)
    subsample = slice(path_start_id, None, path_point_interval)

    def frame_action(frame_idx: int) -> torch.Tensor:
        # Points and time stamps for one frame, side by side: [coord1, coord0, t].
        points = np.array(instruction_trajs[frame_idx][subsample][:n_action_tokens])
        stamps = time_grid[subsample][:n_action_tokens]
        return torch.tensor(np.concatenate([points[:, [1, 0]], stamps.reshape(-1, 1)], axis=1))

    per_frame = [frame_action(idx) for idx in range(start_index, start_index + n_frames)]
    return torch.cat(per_frame, dim=0)
|
|
|
|
|
def main() -> None:
    """Generate a driving video with a Terra world model and save it as MP4.

    Steps:
    1. Parse and validate the CLI.
    2. Tokenize the conditioning frames.
    3. Generate tokens in rounds of at most ``N_FRAMES_PER_ROUND`` frames, each
       round driven by the actions built from the trajectory template for
       ``--cmd`` and, after the first, conditioned on the last
       ``--num_overlapping_frames`` frames of the previous round.
    4. Decode all tokens back to images and write ``generated.mp4`` into the
       output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--cmd", type=str, default="curving_to_left/curving_to_left_moderate")
    parser.add_argument("--num_frames", type=int, default=25)
    parser.add_argument("--num_overlapping_frames", type=int, default=3)
    parser.add_argument(
        "--model_name", type=str, default="world_model_v2", choices=sorted(WM_TOKENIZER_COMBINATION)
    )
    args = parser.parse_args()

    n_initial_frames = len(CONDITIONING_FRAMES_PATH_LIST)
    n_overlap = args.num_overlapping_frames

    # Validate with parser.error rather than assert: asserts are stripped under `python -O`.
    if not n_initial_frames < args.num_frames <= MAX_NUM_FRAMES:
        parser.error(
            f"`num_frames` should be greater than {n_initial_frames} "
            f"and less than or equal to {MAX_NUM_FRAMES}"
        )
    if not 0 <= n_overlap < N_FRAMES_PER_ROUND:
        parser.error(f"`num_overlapping_frames` should be non-negative and less than {N_FRAMES_PER_ROUND}")
    if n_overlap == 0 and args.num_frames > N_FRAMES_PER_ROUND:
        # With zero overlap, `tokens[:, -0:]` selects the whole round instead of no context.
        parser.error("`num_overlapping_frames` must be at least 1 when more than one round is needed")

    # Load the template before the (slow) model download so a bad --cmd fails fast.
    with open(TRAJ_TEMPLATE_PATH, encoding="utf-8") as f:
        traj_template = json.load(f)
    if args.cmd not in traj_template:
        parser.error(f"unknown `cmd` {args.cmd!r}; available: {', '.join(sorted(traj_template))}")

    set_random_seed(args.seed)
    output_dir = args.output_dir if args.output_dir is not None else Path(f"./outputs/{args.cmd}")
    output_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Mixed precision only on CUDA; on a CPU-only host autocast(device_type="cuda")
    # just warns and disables itself, so make that explicit.
    amp = {"device_type": device.type, "enabled": device.type == "cuda"}

    tokenizer_name = WM_TOKENIZER_COMBINATION[args.model_name]
    tokenizer = AutoModel.from_pretrained(
        "turing-motors/Terra", subfolder=tokenizer_name, trust_remote_code=True
    ).to(device).eval()
    model = AutoModel.from_pretrained(
        "turing-motors/Terra", subfolder=args.model_name, trust_remote_code=True
    ).to(device).eval()

    conditioning_frames = load_images(CONDITIONING_FRAMES_PATH_LIST, IMAGE_SIZE).to(device)
    with torch.inference_mode(), torch.autocast(**amp):
        input_ids = tokenizer.tokenize(conditioning_frames).detach().unsqueeze(0)

    num_rounds = determine_num_rounds(args.num_frames, n_overlap, n_initial_frames)
    print(f"Number of generation rounds: {num_rounds}")

    stride = N_FRAMES_PER_ROUND - n_overlap
    overlap_tokens = n_overlap * N_TOKENS_PER_FRAME
    all_outputs = []
    for round_idx in range(num_rounds):  # not `round`: that would shadow the builtin
        start_index = round_idx * stride
        num_frames_for_round = min(N_FRAMES_PER_ROUND, args.num_frames - start_index)
        actions = prepare_action(
            traj_template, args.cmd, PATH_START_ID, PATH_POINT_INTERVAL, N_ACTION_TOKENS,
            start_index, num_frames_for_round,
        ).unsqueeze(0).to(device).float()

        # Round 0 is conditioned on the input frames, later rounds on the overlap.
        n_context_frames = n_initial_frames if round_idx == 0 else n_overlap
        num_generated_tokens = N_TOKENS_PER_FRAME * (num_frames_for_round - n_context_frames)
        with tqdm(total=num_generated_tokens, desc=f"Round {round_idx + 1}") as progress_bar:
            with torch.inference_mode(), torch.autocast(**amp):
                output_tokens = model.generate(
                    input_ids=input_ids,
                    actions=actions,
                    do_sample=True,
                    max_length=N_TOKENS_PER_FRAME * num_frames_for_round,
                    temperature=1.0,
                    top_p=1.0,
                    use_cache=True,
                    pad_token_id=None,
                    eos_token_id=None,
                    progress_bar=progress_bar,
                )

        # Round 0 keeps its conditioning frames; later rounds drop the context
        # frames that the previous round already stored.
        keep_from = 0 if round_idx == 0 else overlap_tokens
        all_outputs.append(output_tokens[0, keep_from:])
        input_ids = output_tokens[:, -overlap_tokens:]

    output_ids = torch.cat(all_outputs)

    # NOTE(review): the product of `ch_mult` is used as the spatial downsampling
    # factor (unchanged from the original); confirm against the tokenizer config.
    downsample_ratio = 1
    for coef in tokenizer.config.encoder_decoder_config["ch_mult"]:
        downsample_ratio *= coef
    h = IMAGE_SIZE[0] // downsample_ratio
    w = IMAGE_SIZE[1] // downsample_ratio
    c = tokenizer.config.encoder_decoder_config["z_channels"]
    latent_shape = (len(output_ids) // N_TOKENS_PER_FRAME, h, w, c)

    with torch.inference_mode(), torch.autocast(**amp):
        reconstructed = tokenizer.decode_tokens(output_ids, latent_shape)
    reconstructed_images = to_np_images(reconstructed)
    save_images_to_mp4(reconstructed_images, output_dir / "generated.mp4", fps=10)


if __name__ == "__main__":
    main()
|
|