#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
================================================
@author: Jaron
@time: 2024/08/21 17:41:52
@email: fjjth98@163.com
@description: Video-CCAM
================================================
"""
import os.path as osp

import torch
from PIL import Image
from peft import PeftModel
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    SiglipImageProcessor,
    SiglipVisionModel,
)

from .configuration_videoccam import VideoCCAMConfig


class VideoCCAM(PreTrainedModel):
    config_class = VideoCCAMConfig
    _auto_class = 'AutoModel'
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: VideoCCAMConfig, device_map: str = 'auto'):
        super().__init__(config)
        self.image_token = config.image_token
        self.video_token = config.video_token
        self.vision_select_layer = config.vision_select_layer
        self.vision_max_chunk_size = config.vision_max_chunk_size
        self.gradient_checkpointing = False

        self.projector = AutoModel.from_pretrained(
            config.projector_name_or_path,
            device_map=device_map,
            trust_remote_code=True,
            torch_dtype=config.torch_dtype,
            # CCAM does not support flash_attention_2
            attn_implementation='sdpa' if config._attn_implementation == 'flash_attention_2' else config._attn_implementation
        )
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_name_or_path,
            device_map=device_map,
            torch_dtype=config.torch_dtype,
            attn_implementation=config._attn_implementation
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.llm_name_or_path,
            additional_special_tokens=[self.image_token, self.video_token]
        )
        self.generation_config = GenerationConfig.from_pretrained(config.llm_name_or_path)
        self.image_token_id, self.video_token_id = self.tokenizer.convert_tokens_to_ids([self.image_token, self.video_token])
        self.vision_encoder = SiglipVisionModel.from_pretrained(
            config.vision_encoder_name_or_path,
            device_map=device_map,
            torch_dtype=config.torch_dtype,
            attn_implementation=config._attn_implementation
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(
            config.vision_encoder_name_or_path
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = dict(use_reentrant=False)
        self.llm.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
        self.vision_encoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

    def forward_visual_embeds(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Use the last hidden state unless an intermediate layer is selected.
        if self.vision_select_layer in {-1, self.vision_encoder.config.num_hidden_layers}:
            visual_embeds = self.vision_encoder(pixel_values, output_hidden_states=False).last_hidden_state
        else:
            visual_embeds = self.vision_encoder(pixel_values, output_hidden_states=True).hidden_states[self.vision_select_layer]
        return visual_embeds

    @torch.inference_mode()
    def chat(
        self,
        messages: list[list[dict]],
        images: list[Image.Image | list[Image.Image]] = None,
        generation_config: GenerationConfig = None,
        batch_generate: bool = False,
        visual_embeds: torch.Tensor = None,
        return_visual_embeds: bool = False,
        **kwargs
    ):
        """Generate a response for each conversation in `messages`.

        `images` holds one entry per visual placeholder: a single `Image` for an
        image or a list of `Image` frames for a video. Precomputed features can
        be passed via `visual_embeds` to skip the vision encoder. Returns a list
        of decoded responses (plus the visual features if `return_visual_embeds`).
        """
        if generation_config is None:
            generation_config = self.generation_config
        # compute visual embeds
        if visual_embeds is None:
            _images, split_size = [], []
            for i in images:
                if isinstance(i, Image.Image):
                    _images.append(i)
                    split_size.append(1)
                else:
                    _images += i
                    split_size.append(len(i))
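            # Preprocess all images/frames in one batch; split_size records how many frames belong to each visual input.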
            pixel_values = self.image_processor(
                _images,
                return_tensors='pt'
            )['pixel_values'].to(
                dtype=self.vision_encoder.get_input_embeddings().weight.dtype,
                device=self.vision_encoder.get_input_embeddings().weight.device
            )
            if 0 < self.vision_max_chunk_size < len(pixel_values):
                # encode in chunks of at most vision_max_chunk_size frames to bound peak memory
                split_idx = list(range(0, len(pixel_values), self.vision_max_chunk_size)) + [len(pixel_values)]
                visual_embeds = torch.cat([
                    self.forward_visual_embeds(pixel_values[le:ri])
                    for le, ri in zip(split_idx[:-1], split_idx[1:])
                ], dim=0)
            else:
                visual_embeds = self.forward_visual_embeds(pixel_values)
            visual_embeds = self.projector(visual_embeds.split(split_size, dim=0))

        # compute textual embeds
        device = self.llm.get_input_embeddings().weight.device
        input_ids = self.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)   # list[list[int]]
        _input_ids, split_idx = [], [0]
        for i in input_ids:
            _input_ids += i
            split_idx.append(split_idx[-1] + len(i))
        _input_ids = torch.tensor(_input_ids, dtype=torch.long, device=device)
        visual_idx = torch.where((_input_ids == self.image_token_id) | (_input_ids == self.video_token_id))[0].tolist()
        assert len(visual_idx) == len(visual_embeds), f'The number of visual tokens ({len(visual_idx)}) should be equal to the number of visual features ({len(visual_embeds)}).'
        _input_ids[visual_idx] = 0   # avoid index overflow
        _inputs_embeds = self.llm.get_input_embeddings()(_input_ids)

        # splice the projected visual features into the text embeddings at every placeholder position
        inputs_embeds, cur_visual_pointer = [], 0
        for start_idx, end_idx in zip(split_idx[:-1], split_idx[1:]):
            if cur_visual_pointer < len(visual_idx) and visual_idx[cur_visual_pointer] < end_idx:
                mid_idx = visual_idx[cur_visual_pointer]
                embeds = [_inputs_embeds[start_idx:mid_idx], visual_embeds[cur_visual_pointer]]
                cur_visual_pointer += 1
                while cur_visual_pointer < len(visual_idx) and visual_idx[cur_visual_pointer] < end_idx:
                    embeds += [_inputs_embeds[mid_idx+1:visual_idx[cur_visual_pointer]], visual_embeds[cur_visual_pointer]]
                    mid_idx = visual_idx[cur_visual_pointer]
                    cur_visual_pointer += 1
                embeds.append(_inputs_embeds[mid_idx+1:end_idx])
                inputs_embeds.append(torch.cat(embeds, dim=0))
            # pure text
            else:
                inputs_embeds.append(_inputs_embeds[start_idx:end_idx])

        if batch_generate:
            # left-pad all samples to the same length and run a single generate call
            B, L = len(inputs_embeds), max(i.size(0) for i in inputs_embeds)
            pad_embeds = self.llm.get_input_embeddings()(
                torch.tensor([self.tokenizer.pad_token_id], dtype=torch.long, device=device)
            )   # (1, C)
            inputs_embeds_list = []
            attention_mask = torch.zeros(B, L, dtype=torch.long, device=device)
            for i, embeds in enumerate(inputs_embeds):
                l = embeds.size(0)
                inputs_embeds_list += [pad_embeds.expand(L - l, -1), embeds]
                attention_mask[i, -l:] = 1
            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).view(B, L, -1)
            output_ids = self.llm.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                generation_config=generation_config,
                **kwargs
            )
        else:
            output_ids = []
            for embeds in inputs_embeds:
                output_ids.append(self.llm.generate(
                    inputs_embeds=embeds[None],
                    attention_mask=torch.ones(1, embeds.size(0), dtype=torch.long, device=device),
                    generation_config=generation_config,
                    **kwargs
                )[0])
        prediction = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        if return_visual_embeds:
            return prediction, visual_embeds
        else:
            return prediction

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *args,
        config: VideoCCAMConfig = None,
        torch_dtype: torch.dtype = torch.bfloat16,
        device_map: str = 'auto',
        **kwargs
    ) -> PreTrainedModel:
        merge_pretrained_lora = kwargs.pop('merge_pretrained_lora', True)
        if config is None:
            # fall back to the config stored alongside the checkpoint
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        config.torch_dtype = torch_dtype
        config.projector_name_or_path = osp.join(pretrained_model_name_or_path, 'projector')
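        # prefer the component sub-directories bundled with the checkpoint over the paths stored in the config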
        if osp.isdir(cur_path := osp.join(pretrained_model_name_or_path, 'llm')):
            config.llm_name_or_path = cur_path
        if osp.isdir(cur_path := osp.join(pretrained_model_name_or_path, 'vision_encoder')):
            config.vision_encoder_name_or_path = cur_path
        model = cls(config, device_map)

        # load LoRA if exists
        if osp.exists(cur_path := osp.join(pretrained_model_name_or_path, 'llm_adapter')):
            model.llm = PeftModel.from_pretrained(model.llm, cur_path, device_map=device_map)
            print(f'Load LLM adapter from {cur_path}.')
            if merge_pretrained_lora:
                model.llm = model.llm.merge_and_unload()
        if osp.exists(cur_path := osp.join(pretrained_model_name_or_path, 'vision_encoder_adapter')):
            model.vision_encoder = PeftModel.from_pretrained(model.vision_encoder, cur_path, device_map=device_map)
            print(f'Load vision encoder adapter from {cur_path}.')
            if merge_pretrained_lora:
                model.vision_encoder = model.vision_encoder.merge_and_unload()

        return model
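

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, normally run from a separate script): load
# a Video-CCAM checkpoint through AutoModel with trust_remote_code and ask a
# question about a video. The checkpoint path, the frame file names, and the
# choice of 32 frames are placeholder assumptions for this example.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    model = AutoModel.from_pretrained(
        '/path/to/Video-CCAM-checkpoint',   # placeholder checkpoint directory
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map='auto'
    )

    # One conversation whose user turn carries the video placeholder token;
    # `chat` splices the projected frame features in at that position.
    frames = [Image.open(f'frame_{i:02d}.jpg') for i in range(32)]   # placeholder frames
    messages = [[
        {'role': 'user', 'content': f'{model.video_token}\nDescribe the video in detail.'}
    ]]

    response = model.chat(messages, images=[frames], max_new_tokens=512)
    print(response[0])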