from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.utils import (
    logging,
)
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from transformers.models.blip_2.modeling_blip_2 import Blip2ForConditionalGenerationModelOutput
from transformers import (
    Blip2PreTrainedModel,
    Blip2VisionModel,
    Blip2QFormerModel,
    PreTrainedTokenizer,
    PreTrainedModel,
)

logger = logging.get_logger(__name__)


class ZiyaBlip2ForCausalLM(Blip2PreTrainedModel):
    """BLIP-2 vision encoder + Q-Former bridged to an externally supplied language model.

    Images are encoded by the vision model, summarized by ``config.num_query_tokens``
    learned query tokens in the Q-Former, projected to the language model's hidden
    size, and spliced between a text segment placed before the image and a text
    segment placed after it.
    """

    config_class = Blip2Config
    main_input_name = "pixel_values"
    # The language model is passed to __init__ instead of being built from the config,
    # presumably so BLIP-2 checkpoints without LM weights load without warnings.
    _keys_to_ignore_on_load_missing = [
        r"language_model",
    ]

    def __init__(self, config: Blip2Config, language_model: PreTrainedModel = None):
        """Build vision encoder, Q-Former, query tokens and projection.

        Args:
            config: BLIP-2 configuration (vision, Q-Former and text sub-configs).
            language_model: the language model the projected query embeddings are fed to.
        """
        super().__init__(config)

        self.vision_model = Blip2VisionModel(config.vision_config)

        # learned queries, expanded over the batch in _encode_images
        self.query_tokens = nn.Parameter(torch.zeros(
            1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = Blip2QFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = language_model

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        # only encoder-decoder language models share embeddings via `shared`
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
            # All pieces are one concatenated message string; a stray comma here would turn the
            # tail into a %-format argument and break the log call.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models."
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    def _encode_images(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ):
        """Run images through vision encoder, Q-Former and language projection.

        Returns:
            ``(vision_outputs, query_outputs, language_model_inputs)``; the last has one
            embedding per query token, sized ``config.text_config.hidden_size``.
        """
        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings
        # for cross-attention (every image position is attended to)
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 2.5: project the query outputs into the language model's embedding space
        language_model_inputs = self.language_projection(query_output)
        return vision_outputs, query_outputs, language_model_inputs

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids_before_image: torch.LongTensor,
        input_ids_after_image: torch.LongTensor,
        labels_after_image: torch.LongTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        """Training forward pass: [text before image] + [image queries] + [text after image].

        Args:
            pixel_values: images for the vision encoder.
            input_ids_before_image: token ids placed before the image embeddings.
                Labels never occur before the image, so no ``labels_before_image`` is
                needed; those positions are filled with -100 internally.
            input_ids_after_image: token ids placed after the image embeddings; its
                batch size must match the image batch.
            labels_after_image: labels for the after-image positions, expected to have
                the same shape as ``input_ids_after_image``.
            output_attentions / output_hidden_states: forwarded to every sub-model.
            return_dict: defaults to ``config.use_return_dict``.

        Returns:
            ``Blip2ForConditionalGenerationModelOutput`` or, when ``return_dict`` is false,
            ``(loss, logits, vision_outputs, query_outputs, language_model_outputs)``
            (``loss`` omitted when the language model returns none).

        Raises:
            NotImplementedError: if ``config.use_decoder_only_language_model`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # fail fast, before the expensive vision pass
        if not self.config.use_decoder_only_language_model:
            raise NotImplementedError("not impl: only decoder-only language models are supported")

        vision_outputs, query_outputs, language_model_inputs = self._encode_images(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        device = language_model_inputs.device
        image_positions_shape = language_model_inputs.shape[:-1]  # (batch, num_query_tokens)
        language_model_attention_mask = torch.ones(
            image_positions_shape, dtype=torch.long, device=device)

        # the image batch must line up with the text batch
        assert language_model_inputs.shape[0] == input_ids_after_image.shape[0]

        embed_tokens = self.language_model.get_input_embeddings()
        inputs_embeds = torch.cat(
            [
                embed_tokens(input_ids_before_image).to(device),
                language_model_inputs,
                embed_tokens(input_ids_after_image).to(device),
            ],
            dim=1,
        )
        attention_mask = torch.cat(
            [
                torch.ones_like(input_ids_before_image).to(device),
                language_model_attention_mask,
                torch.ones_like(input_ids_after_image).to(device),
            ],
            dim=1,
        )

        # labels must align with inputs_embeds: fill the prefix and image positions with
        # -100 (the conventional ignore index for the loss)
        labels = torch.cat(
            [
                torch.full(input_ids_before_image.shape, -100, dtype=torch.long, device=device),
                torch.full(image_positions_shape, -100, dtype=torch.long, device=device),
                labels_after_image.to(device),
            ],
            dim=1,
        )

        # step 3: use the language model
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
        )
        loss = outputs.loss if return_dict else outputs[0]
        logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    def prepare_inputs_for_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        pixel_values: torch.Tensor,
        previous_querys: List[str],
        previous_outputs: List[str],
        max_length: int,
    ):
        """Build ``(input_ids, inputs_embeds)`` for one chat turn.

        Layout: ``config.prompt_prefix`` tokens, the image query embeddings, then each
        history turn as ``"{human_name}: {q}\\n"`` + ``"{assistant_name}: {o}\\n"``, the
        current query, and finally the assistant cue that generation continues from.

        Args:
            tokenizer: tokenizer of the language model.
            query: current user query.
            pixel_values: processed image tensor.
            previous_querys / previous_outputs: chat history, aligned pairwise.
            max_length: keep at most this many of the most recent positions.

        Returns:
            ``(input_ids, inputs_embeds)`` with identical first two dims; image positions
            in ``input_ids`` hold ``tokenizer.eos_token_id`` as a placeholder. Both are
            truncated from the left in the same way.
        """
        assert len(previous_querys) == len(previous_outputs)

        device = self.device
        prefix = self.config.prompt_prefix
        human_name = self.config.human_name
        assistant_name = self.config.assistant_name

        # 1. tokenize the text before and after the image
        input_ids_before_image = tokenizer(
            prefix, return_tensors="pt").input_ids.to(device)

        ids_after_image: List[int] = []
        for past_query, past_output in zip(previous_querys, previous_outputs):
            ids_after_image += tokenizer(
                f"{human_name}: {past_query}\n", add_special_tokens=False).input_ids
            ids_after_image += tokenizer(
                f"{assistant_name}: {past_output}\n", add_special_tokens=False).input_ids
        ids_after_image += tokenizer(
            f"{human_name}: {query}\n", add_special_tokens=False).input_ids
        # NOTE(review): the generation cue has a space before the colon while history
        # turns do not; kept unchanged — confirm against the training format.
        ids_after_image += tokenizer(
            f"{assistant_name} :", add_special_tokens=False).input_ids
        input_ids_after_image = torch.tensor([ids_after_image], dtype=torch.long, device=device)

        # 2. prepare embeddings (Tensor.to is not in-place, so rebind)
        pixel_values = pixel_values.to(device)
        _, _, language_model_inputs = self._encode_images(pixel_values, return_dict=True)
        embeds_device = language_model_inputs.device

        # concatenate query embeddings with prompt embeddings
        embed_tokens = self.get_input_embeddings()
        inputs_embeds = torch.cat(
            [
                embed_tokens(input_ids_before_image).to(embeds_device),
                language_model_inputs,
                embed_tokens(input_ids_after_image).to(embeds_device),
            ],
            dim=1,
        )
        # eos ids stand in for the image positions so input_ids stays aligned with inputs_embeds
        image_placeholder_ids = torch.full(
            language_model_inputs.shape[:-1], tokenizer.eos_token_id,
            dtype=torch.long, device=embeds_device)
        input_ids = torch.cat(
            [
                input_ids_before_image.to(embeds_device),
                image_placeholder_ids,
                input_ids_after_image.to(embeds_device),
            ],
            dim=1,
        )

        # keep the most recent positions; truncate both tensors identically so masks
        # derived from input_ids match inputs_embeds
        if inputs_embeds.shape[1] > max_length:
            inputs_embeds = inputs_embeds[:, -max_length:, :]
            input_ids = input_ids[:, -max_length:]
        return input_ids, inputs_embeds

    def chat(self, tokenizer, query: str, pixel_values: torch.Tensor,
             previous_querys: List[str], previous_outputs: List[str],
             *, max_input_length: int = 2048, **generate_kwargs,):
        """
        use for generate text by chat-style

        Args:
            tokenizer (PreTrainedTokenizer): tokenizer of the language model (llama)
            query (str): current input query
            pixel_values (torch.Tensor): image after image_processor
            previous_querys (List[str]): earlier user queries of the conversation
            previous_outputs (List[str]): earlier model replies, aligned with previous_querys
            max_input_length (int): keep at most this many most-recent input positions
            **generate_kwargs: forwarded to ``self.language_model.generate``

        Returns:
            str: decoded first generated sequence, special tokens skipped
        """
        _, inputs_embeds = self.prepare_inputs_for_chat(
            tokenizer, query, pixel_values, previous_querys, previous_outputs, max_input_length
        )
        # derive the mask from the embeddings themselves so the lengths always agree
        attention_mask = torch.ones(
            inputs_embeds.shape[:-1], dtype=torch.long, device=inputs_embeds.device)

        response = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )
        response = tokenizer.decode(response[0], skip_special_tokens=True)
        return response