"""Video preprocessing for LanguageBind: decode a clip, sample frames, and
normalize them into tensors, paired with a tokenizer for text."""
import torch
import cv2
import decord
from decord import VideoReader, cpu
decord.bridge.set_bridge('torch')
import numpy as np
from PIL import Image
from torchvision import transforms
from transformers import ProcessorMixin, BatchEncoding
from transformers.image_processing_utils import BatchFeature
from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision.transforms import Compose, Lambda, ToTensor
from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample

# Per-channel mean/std (the OpenAI dataset statistics) applied by NormalizeVideo.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)


def make_list_of_images(x):
    """Wrap ``x`` in a list unless it already is one."""
    if not isinstance(x, list):
        return [x]
    return x


def _build_clip_transforms(random_flip=True):
    """Return the transform list shared by every decode backend."""
    ops = [
        # Decoded frames arrive as 0-255 values; scale to [0, 1] before normalizing.
        Lambda(lambda x: x / 255.0),
        NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
        ShortSideScale(size=224),
        CenterCropVideo(224),
    ]
    if random_flip:
        ops.append(RandomHorizontalFlipVideo(p=0.5))
    return ops


def get_video_transform(config, *, random_flip=True):
    """Build the clip transform for ``config.vision_config.video_decode_backend``.

    Args:
        config: Object whose ``vision_config`` provides ``video_decode_backend``
            and (for ``'pytorchvideo'``) ``num_frames``.
        random_flip: When True (the default, matching the original behaviour),
            append ``RandomHorizontalFlipVideo(p=0.5)``. This makes outputs
            non-deterministic; pass False for reproducible inference.

    Returns:
        For ``'pytorchvideo'``, an ``ApplyTransformToKey`` that transforms the
        ``"video"`` entry of a clip dict, after temporally subsampling it to
        ``num_frames``. For ``'decord'`` and ``'opencv'``, a ``Compose`` that
        operates directly on a frame tensor (those loaders sample frames
        themselves).

    Raises:
        NameError: If the backend is not one of the three supported names.
    """
    vision_config = config.vision_config
    backend = vision_config.video_decode_backend
    if backend == 'pytorchvideo':
        return ApplyTransformToKey(
            key="video",
            transform=Compose(
                [UniformTemporalSubsample(vision_config.num_frames)]
                + _build_clip_transforms(random_flip)
            ),
        )
    if backend in ('decord', 'opencv'):
        return Compose(_build_clip_transforms(random_flip))
    raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv)')


def _read_frames_opencv(video_path, num_frames):
    """Seek to ``num_frames`` evenly spaced frames with OpenCV; return them stacked as (C, T, H, W)."""
    cv2_vr = cv2.VideoCapture(video_path)
    try:
        duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep a 5-frame margin before the reported end (presumably because the
        # container's frame count can overshoot what actually decodes). Clamp at
        # 0 so clips shorter than 5 frames never seek to a negative index.
        last_frame = max(duration - 5, 0)
        frames = []
        for frame_idx in np.linspace(0, last_frame, num_frames, dtype=int):
            cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cv2_vr.read()
            if not ret:
                raise ValueError(f'video error at {video_path}')
            # OpenCV decodes to BGR; convert to RGB, then (H, W, C) -> (C, H, W).
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(torch.from_numpy(frame).permute(2, 0, 1))
    finally:
        # Release the capture even when a frame read fails.
        cv2_vr.release()
    # Stack along a new time axis: (C, H, W) x T -> (C, T, H, W).
    return torch.stack(frames, dim=1)


def load_and_transform_video(
        video_path,
        transform,
        video_decode_backend='opencv',
        clip_start_sec=0.0,
        clip_end_sec=None,
        num_frames=8,
):
    """Decode ``video_path`` with the chosen backend and apply ``transform``.

    Args:
        video_path: Path to the video file.
        transform: Transform from ``get_video_transform`` for the same backend.
        video_decode_backend: ``'pytorchvideo'``, ``'decord'`` or ``'opencv'``.
        clip_start_sec: Clip start in seconds (used by ``'pytorchvideo'`` only).
        clip_end_sec: Clip end in seconds, or None for the whole video
            (used by ``'pytorchvideo'`` only).
        num_frames: Number of evenly spaced frames to sample (``'decord'`` and
            ``'opencv'``; the ``'pytorchvideo'`` transform subsamples itself).

    Returns:
        For ``'decord'`` and ``'opencv'``, the transformed (C, T, H, W) tensor.
        For ``'pytorchvideo'``, the clip dict produced by ``ApplyTransformToKey``,
        with the transformed frames under ``"video"``.

    Raises:
        ValueError: If OpenCV cannot read one of the sampled frames.
        NameError: If ``video_decode_backend`` is not recognised.
    """
    if video_decode_backend == 'pytorchvideo':
        # The decoder could also be 'pyav'.
        video = EncodedVideo.from_path(video_path, decoder="decord", decode_audio=False)
        end_sec = clip_end_sec if clip_end_sec is not None else video.duration  # secs
        video_data = video.get_clip(start_sec=clip_start_sec, end_sec=end_sec)
        video_outputs = transform(video_data)
    elif video_decode_backend == 'decord':
        # The bridge is process-global state; re-assert it so get_batch yields
        # torch tensors (needed for .permute below).
        decord.bridge.set_bridge('torch')
        decord_vr = VideoReader(video_path, ctx=cpu(0))
        duration = len(decord_vr)
        frame_id_list = np.linspace(0, duration - 1, num_frames, dtype=int)
        video_data = decord_vr.get_batch(frame_id_list)
        video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
        video_outputs = transform(video_data)
    elif video_decode_backend == 'opencv':
        video_outputs = transform(_read_frames_opencv(video_path, num_frames))
    else:
        raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv)')
    return video_outputs


class LanguageBindVideoProcessor(ProcessorMixin):
    """Processor that pairs a tokenizer with LanguageBind video preprocessing.

    Videos are loaded by ``load_and_transform_video``, using the backend and
    frame count from ``config.vision_config``. Text is tokenized and padded or
    truncated to a fixed ``context_length``.
    """
    attributes = []
    tokenizer_class = ("LanguageBindVideoTokenizer")

    def __init__(self, config, tokenizer=None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.transform = get_video_transform(config)
        # Presumably named ``image_processor`` to fit the processor interface;
        # it is actually the video loading function.
        self.image_processor = load_and_transform_video
        self.tokenizer = tokenizer

    def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs):
        """Tokenize ``text`` and/or load and transform the videos in ``images``.

        Args:
            images: A video path or a list of video paths.
            text: Text passed to the tokenizer.
            context_length: Token length; text is padded/truncated to it.
            return_tensors: Forwarded to the tokenizer.
            **kwargs: Extra keyword arguments for the tokenizer.

        Returns:
            The tokenizer output, with ``"pixel_values"`` added when videos are
            given; or ``{"pixel_values": ...}`` when only videos are given.
            ``pixel_values`` stacks one transformed clip per video along dim 0.

        Raises:
            ValueError: If both ``text`` and ``images`` are None.
        """
        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")

        if text is not None:
            encoding = self.tokenizer(text, max_length=context_length, padding='max_length',
                                      truncation=True, return_tensors=return_tensors, **kwargs)

        if images is not None:
            vision_config = self.config.vision_config
            image_features = []
            for image in make_list_of_images(images):
                clip = self.image_processor(image, self.transform,
                                            video_decode_backend=vision_config.video_decode_backend,
                                            num_frames=vision_config.num_frames)
                # The 'pytorchvideo' backend returns the whole clip dict from
                # ApplyTransformToKey; only its "video" tensor can be stacked.
                if isinstance(clip, dict):
                    clip = clip["video"]
                image_features.append(clip)
            image_features = torch.stack(image_features)

        if text is not None and images is not None:
            encoding["pixel_values"] = image_features
            return encoding
        elif text is not None:
            return encoding
        else:
            return {"pixel_values": image_features}

    def preprocess(self, images, return_tensors):
        """Process videos only; shorthand for ``self(images=..., return_tensors=...)``."""
        return self.__call__(images=images, return_tensors=return_tensors)

    def batch_decode(self, *args, skip_special_tokens=True, **kwargs):
        """Forward to ``self.tokenizer.batch_decode``, skipping special tokens by default.

        ``skip_special_tokens`` is keyword-only so that ``batch_decode(ids)``
        forwards ``ids`` to the tokenizer instead of binding it to
        ``skip_special_tokens``.
        """
        return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)

    def decode(self, *args, skip_special_tokens=True, **kwargs):
        """Forward to ``self.tokenizer.decode``, skipping special tokens by default.

        ``skip_special_tokens`` is keyword-only so that ``decode(ids)``
        forwards ``ids`` to the tokenizer instead of binding it to
        ``skip_special_tokens``.
        """
        return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)