# This-and-That: data_loader/video_this_that_dataset.py
import os, sys
import json
import cv2
import math
import shutil
import numpy as np
import random
from PIL import Image
import torch.nn.functional as F
import torch
import os.path as osp
import time
from moviepy.editor import VideoFileClip
from torch.utils.data import Dataset
# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from utils.img_utils import resize_with_antialiasing, numpy_to_pt
from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
from data_loader.video_dataset import tokenize_captions
# For the 2D dilation
blur_kernel = bivariate_Gaussian(99, 10, 10, 0, grid = None, isotropic = True)
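# (bivariate_Gaussian(99, 10, 10, 0, ...) is assumed to return a 99x99 isotropic Gaussian kernel with sigma 10;
#  it is used below to blur the hard square point markers into soft blobs.)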
def get_thisthat_sam(config, input_dir, store_dir = None, flip = False, verbose=False):
    '''
        Build the this/that point condition for one video folder.
        Args:
            config (dict): Dataset config (video_seq_length, conditioning_channels, height, width, dilate, motion_bucket_id)
            input_dir (str): Path to the video folder containing data.txt and the extracted frames im_<idx>.jpg
            store_dir (str): Optional folder to store condition visualizations (used when verbose is True)
            flip (bool): Whether to horizontally flip the condition images (augmentation)
            verbose (bool): Whether to store the visualization images
    '''
# Read file
    file_path = os.path.join(input_dir, "data.txt")
    with open(file_path, 'r') as file1:
        Lines = file1.readlines()
    # Initialize the condition tensor; frames without an annotated point stay all-zero
    thisthat_condition = np.zeros((config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]), dtype=np.float32)
# Init the image
    sample_img = cv2.imread(os.path.join(input_dir, "im_0.jpg"))
org_height, org_width, _ = sample_img.shape
# Prepare masking
controlnet_image_index = []
coordinate_values = []
# Iterate all points in the txt file
for idx in range(len(Lines)):
# Read points
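        # Each data.txt line is assumed to be "frame_idx horizontal vertical" (e.g. "0 320 240"),
        # i.e. an annotated point at x=320, y=240 in original-image coordinates on frame 0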
frame_idx, horizontal, vertical = Lines[idx].split(' ')
frame_idx, vertical, horizontal = int(frame_idx), int(float(vertical)), int(float(horizontal))
# Read the mask frame idx
controlnet_image_index.append(frame_idx)
coordinate_values.append((vertical, horizontal))
# Init the base image
base_img = np.zeros((org_height, org_width, 3)).astype(np.float32) # Use the original image size
base_img.fill(255)
        # Draw a square marker around the target position
        dot_range = 10 # Half-width: the marker covers a (2*dot_range+1) x (2*dot_range+1) pixel square
for i in range(-1*dot_range, dot_range+1):
for j in range(-1*dot_range, dot_range+1):
dil_vertical, dil_horizontal = vertical + i, horizontal + j
if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
                    if idx == 0:
                        base_img[dil_vertical][dil_horizontal] = [0, 0, 255] # The first ("this") point is red (BGR)
                    else:
                        base_img[dil_vertical][dil_horizontal] = [0, 255, 0] # Later ("that") points are green to distinguish them from the first
        # Dilate: blur the hard-edged square marker with the Gaussian kernel
if config["dilate"]:
base_img = cv2.filter2D(base_img, -1, blur_kernel)
##############################################################################################################################
### The core pipeline of processing is: Dilate -> Resize -> Range Shift -> Transpose Shape -> Store
        # Resize while the values are still in [0, 255]: cubic interpolation can overshoot, so don't resize after mapping to [0, 1] or with negative values present
base_img = cv2.resize(base_img, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC)
# Flip the image for aug if needed
if flip:
base_img = np.fliplr(base_img)
# Channel Transform and Range Shift
if config["conditioning_channels"] == 3:
# Map to [0, 1] range
            if store_dir is not None and verbose: # Store the condition image of this point for visualization
cv2.imwrite(os.path.join(store_dir, "condition_TT"+str(idx)+".png"), base_img)
base_img = base_img / 255.0
else:
raise NotImplementedError()
# ReOrganize shape
base_img = base_img.transpose(2, 0, 1) # hwc -> chw
# Check the min max value range
# if verbose:
        #     print("{} min, max range value is {} - {}".format(input_dir, np.min(base_img), np.max(base_img)))
        # Write the condition image into the slot of its frame_idx; frames without a point stay zero-initialized
        thisthat_condition[frame_idx] = base_img
##############################################################################################################################
if config["motion_bucket_id"] is None:
        # Fall back to a default motion bucket id (200) chosen from previously collected stats
        reflected_motion_bucket_id = 200
else:
reflected_motion_bucket_id = config["motion_bucket_id"]
# print("Motion Bucket ID is ", reflected_motion_bucket_id)
return (thisthat_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values)
class Video_ThisThat_Dataset(Dataset):
    '''
        Video dataset that loads sequential frames for training, with the required pre-processing and the
        this/that point conditioning.
    '''
def __init__(self, config, device, normalize=True, tokenizer=None):
# Attribute variables
self.config = config
self.device = device
self.normalize = normalize
self.tokenizer = tokenizer
# Obtain values
self.video_seq_length = config["video_seq_length"]
self.height = config["height"]
self.width = config["width"]
# Process data
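        # Each entry of config["dataset_path"] is a root folder whose sub-folders are individual videos; a valid
        # video folder is expected to contain data.txt, frames im_<idx>.jpg, lang.txt (when use_text is on) and,
        # optionally, processed_text.txt for the ambiguous-prompt mix-up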
self.video_lists = []
for dataset_path in config["dataset_path"]:
for video_name in sorted(os.listdir(dataset_path)):
if not os.path.exists(os.path.join(dataset_path, video_name, "data.txt")):
continue
self.video_lists.append(os.path.join(dataset_path, video_name))
print("length of the dataset is ", len(self.video_lists))
def __len__(self):
return len(self.video_lists)
def _extract_frame_bridge(self, idx, flip=False):
        ''' Extract the frames of one video from the images already dumped on disk
            Args:
                idx (int): The index of the video folder in the dataset list
                flip (bool): Whether to horizontally flip the frames (augmentation)
            Returns:
                video_frames (numpy): Extracted video frames of shape (video_seq_length, height, width, 3)
        '''
        # Frames of the Bridge dataset are already extracted on disk and named with the pattern im_<idx>.jpg,
        # so we only need to build the ordered list of frame paths to read
video_frame_path = self.video_lists[idx]
# Find needed file
needed_img_path = []
        for frame_idx in range(self.video_seq_length): # Use a separate name so we don't shadow the dataset index idx
            img_path = os.path.join(video_frame_path, "im_" + str(frame_idx) + ".jpg")
            needed_img_path.append(img_path)
# Read all img_path based on the order
video_frames = []
for img_path in needed_img_path:
if not os.path.exists(img_path):
print("We don't have ", img_path)
            frame = cv2.imread(img_path)
            try:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # cv2.imread returns BGR; convert to RGB here
            except Exception:
                print("The exception place is ", img_path)
# Resize frames
frame = cv2.resize(frame, (self.width, self.height), interpolation = cv2.INTER_CUBIC)
# Flip aug
if flip:
frame = np.fliplr(frame)
            # Collect frames (the frame is already RGB at this point)
            video_frames.append(np.expand_dims(frame, axis=0)) # Add a leading axis so the frames can be concatenated along axis 0
# Concatenate
video_frames = np.concatenate(video_frames, axis=0)
assert(len(video_frames) == self.video_seq_length)
# Returns
return video_frames
def __getitem__(self, idx):
        ''' Get one item by idx; frames are resized and normalized to [-1, 1]
            Args:
                idx (int): The index of the video folder in the dataset list
            Returns:
                return_dict (dict): video_frames (torch.float32, in [-1, 1]), controlnet_condition (torch.float32, in [0, 1]) and the prompt / motion-bucket / index tensors
        '''
# Prepare the text if needed:
if self.config["use_text"]:
            # Read the language instruction; only the first line is used
            file_path = os.path.join(self.video_lists[idx], "lang.txt")
            with open(file_path, 'r') as file:
                prompt = file.readlines()[0]
if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")):
# If we don't have this txt file, we skip
######################################################## Mix up prompt ########################################################
                # Read all lines of the processed prompt (action / this / there)
                file_path = os.path.join(self.video_lists[idx], "processed_text.txt")
                with open(file_path, 'r') as file:
                    prompts = file.readlines()
                # Get the components and strip the trailing newline from each line
                action = prompts[0][:-1]
                this = prompts[1][:-1]
                there = prompts[2][:-1]
random_value = random.random()
                # With probability 0.4 (random_value < 0.4) we keep the original, most concrete prompt
                if random_value >= 0.4 and random_value < 0.6:
                    # Replace the picked object with "this"
                    prompt = action + " this to " + there
                elif random_value >= 0.6 and random_value < 0.8:
                    # Replace the placement position with "there"
                    prompt = action + " " + this + " to there"
                elif random_value >= 0.8 and random_value < 1.0:
                    # Fully ambiguous form: "this to there"
                    prompt = action + " this to there"
# print("New prompt is ", prompt)
###################################################################################################################################################
# else:
# print("We don't have llama processed prompt at ", self.video_lists[idx])
else:
prompt = ""
# Tokenize text prompt
tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config)
        # Flip augmentation by chance (only allowed when the prompt contains no positional words such as "left" / "right")
flip = False
if random.random() < self.config["flip_aug_prob"]:
if self.config["use_text"]:
                if prompt.find("left") == -1 and prompt.find("right") == -1: # The prompt must not mention left/right (up/down is fine since the flip is horizontal)
flip = True
else:
flip = True
        # Read frames depending on the dataset; currently only Bridge is supported by this loader
if self.config["dataset_name"] == "Bridge":
video_frames_raw = self._extract_frame_bridge(idx, flip=flip)
else:
raise NotImplementedError("We don't support this dataset loader")
        # Scale [0, 255] -> [-1, 1] if needed
        if self.normalize:
            video_frames = video_frames_raw.astype(np.float32) / 127.5 - 1 # Be careful to cast to float32
        else:
            video_frames = video_frames_raw.astype(np.float32) # Keep the raw [0, 255] range
# Transform to Pytorch Tensor in the range [-1, 1]
video_frames = numpy_to_pt(video_frames)
# Generate the pairs we need
        input_dir = self.video_lists[idx]
        # Get the this/that point information
        controlnet_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(self.config, input_dir, flip=flip)
controlnet_condition = torch.from_numpy(controlnet_condition)
# Cast other value to tensor
reflected_motion_bucket_id = torch.tensor(reflected_motion_bucket_id, dtype=torch.float32)
controlnet_image_index = torch.tensor(controlnet_image_index, dtype=torch.int32)
coordinate_values = torch.tensor(coordinate_values, dtype=torch.int32)
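        # Expected shapes: controlnet_condition (T, C, H, W); controlnet_image_index (num_points,); coordinate_values (num_points, 2);
        # video_frames is assumed to come out of numpy_to_pt as (T, 3, H, W)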
# The tensor we returned is torch float32. We won't cast here for mixed precision training!
return {"video_frames" : video_frames,
"controlnet_condition" : controlnet_condition,
"reflected_motion_bucket_id" : reflected_motion_bucket_id,
"controlnet_image_index": controlnet_image_index,
"prompt": tokenized_prompt,
"coordinate_values": coordinate_values, # Useless now, but I still passed back
}
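
# -------------------------------------------------------------------------------------------------------------------
# Minimal smoke-test sketch for get_thisthat_sam (illustration only, not part of the training pipeline).
# It writes one dummy frame and a two-point data.txt into a temporary folder and builds the condition tensor.
# The config values below (sequence length 14, 256x384 resolution) are assumptions, not values mandated by the repo.
if __name__ == "__main__":
    import tempfile

    demo_config = {
        "video_seq_length": 14,        # assumed value
        "conditioning_channels": 3,
        "height": 256,                 # assumed value
        "width": 384,                  # assumed value
        "dilate": True,
        "motion_bucket_id": None,      # falls back to the default bucket id 200
    }
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Dummy 480x640 first frame; get_thisthat_sam only reads it to get the original resolution
        cv2.imwrite(os.path.join(tmp_dir, "im_0.jpg"), np.zeros((480, 640, 3), dtype=np.uint8))
        # Two annotated points: "this" on frame 0 at (x=320, y=240) and "that" on frame 13 at (x=100, y=200)
        with open(os.path.join(tmp_dir, "data.txt"), "w") as f:
            f.write("0 320 240\n13 100 200\n")
        condition, bucket_id, frame_ids, points = get_thisthat_sam(demo_config, tmp_dir)
        # Expected: condition shape (14, 3, 256, 384) with values mapped to [0, 1], bucket_id 200, frame_ids [0, 13]
        print(condition.shape, bucket_id, frame_ids, points)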