# MoMA_demo/model_lib/modules.py
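"""MoMA demo: main multi-modal module.

Couples a LLaVA-based multi-modal LLM with the MoMA diffusion generator:
the LLM turns a reference image plus a target prompt into a subject
embedding, which conditions the generator's adapted attention layers.
"""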
import os
import torch
import torch.nn as nn
from typing import List, Optional
from torchvision.transforms import ToPILImage
from model_lib.moMA_generator import MoMA_generator
from transformers.activations import ACT2FN
from huggingface_hub import hf_hub_download
from dataset_lib.dataset_eval_MoMA import Dataset_evaluate_MoMA
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX


def add_function(model):
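    """Attach a custom multimodal forward pass to a LLaVA model.

    The patched forward runs LLaVA's multimodal preprocessing (which replaces
    the <image> token with visual features) and returns the decoder's last
    hidden states rather than logits.
    """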
def my_llava_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
):
        (_, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(
            input_ids, position_ids, attention_mask, None, None, images
        )
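        # prepare_inputs_labels_for_multimodal has already spliced the image
        # features into inputs_embeds; run the bare decoder on them. With
        # return_dict=True, outputs[0] is last_hidden_state.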
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=inputs_embeds,
use_cache=True,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
)
return outputs[0]
    model.my_llava_forward = my_llava_forward  # plain attribute, so call sites pass the model explicitly as the first argument


class LlamaMLP_mapping(nn.Module):
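    """LLaMA-style gated MLP that maps LLM hidden states into the
    image-embedding space consumed by the diffusion generator
    (instantiated below as 4096 -> 1024).
    """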
    def __init__(self, hidden_size, hidden_size_out):
        super().__init__()
        self.hidden_size, self.hidden_size_out = hidden_size, hidden_size_out
        self.gate_proj = nn.Linear(self.hidden_size, self.hidden_size_out, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.hidden_size_out, bias=False)
        self.down_proj = nn.Linear(self.hidden_size_out, self.hidden_size_out, bias=False)
        self.act_fn = ACT2FN["silu"]
        self.act_fn_output = ACT2FN["tanh"]  # defined but never used in forward
        self.init_linear()

def forward(self, x):
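        """Gated projection: down_proj(silu(gate_proj(x)) * up_proj(x))."""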
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj

    def init_linear(self):
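        """Xavier-initialize all three projections, then scale the weights
        down by a factor of 4 so the initial mapping output stays small."""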
        torch.nn.init.xavier_normal_(self.gate_proj.weight)
        self.gate_proj.weight.data = self.gate_proj.weight.data / 4.0
        torch.nn.init.xavier_normal_(self.up_proj.weight)
        self.up_proj.weight.data = self.up_proj.weight.data / 4.0
        torch.nn.init.xavier_normal_(self.down_proj.weight)
        self.down_proj.weight.data = self.down_proj.weight.data / 4.0


class MoMA_main_modal(nn.Module):
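    """End-to-end MoMA model.

    Wraps the MoMA_generator (a diffusion pipeline with adapted attention
    processors) and a LLaVA multi-modal LLM, plus the MLP that maps LLM
    hidden states to the generator's conditioning embeddings.
    """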
def __init__(self,args):
super().__init__()
self.args = args
self.device = args.device
self.moMA_generator = MoMA_generator(self.device,args)
self.unet = self.moMA_generator.pipe.unet
self.vae = self.moMA_generator.pipe.vae
        print("Loading MoMA's multi-modal LLM (LLaVA)...")
model_name = get_model_name_from_path(args.model_path)
        self.tokenizer_llava, self.model_llava, self.image_processor_llava, self.context_len_llava = load_pretrained_model(
            args.model_path, None, model_name,
            load_8bit=self.args.load_8bit, load_4bit=self.args.load_4bit, device=args.device,
        )
add_function(self.model_llava)
        self.mapping = LlamaMLP_mapping(4096, 1024).to(self.device, dtype=torch.float16)  # LLM hidden size -> image-embedding dim
self.load_saved_components()
self.freeze_modules()

    def load_saved_components(self):
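        """Download `attn_adapters_projectors.th` if it is not present locally,
        then load the image projectors, self/cross attention weights, and the
        LLM mapping from it."""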
        if not os.path.exists(self.args.load_attn_adapters):
            print('Downloading attention adapters and LLM mappings...')
            hf_hub_download(repo_id=self.args.model_path, filename="attn_adapters_projectors.th", local_dir=os.path.dirname(self.args.load_attn_adapters))

        # load attention adapters and self/cross attention processors
        state_dict = torch.load(self.args.load_attn_adapters, map_location="cpu")
        self.moMA_generator.image_proj_model.load_state_dict(state_dict["projectors"])
        attn_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
        attn_layers.load_state_dict(state_dict["self_cross_attentions"], strict=False)

        # load the LLM-to-image mapping weights
        self.load_state_dict(state_dict['llm_mapping'], strict=False)

def freeze_modules(self):
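        """Switch all sub-modules to eval mode and disable gradients; the demo
        runs inference only."""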
        all_modules = [self.moMA_generator.pipe.vae, self.moMA_generator.pipe.text_encoder, self.unet, self.model_llava, self.mapping]
        for module in all_modules:
            module.eval()  # assigning `module.train = False` would shadow the method, not switch modes
            module.requires_grad_(False)

def forward_MLLM(self,batch):
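        """Produce one subject embedding per batch item.

        Builds a LLaVA prompt ending in a "*" placeholder for each
        (subject, prompt) pair, left-pads the batch, runs the patched LLaVA
        forward, and returns the mapped hidden state at the final position.
        """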
        llava_processeds, subjects, prompts = batch['llava_processed'].half().to(self.device), batch['label'], batch['text']

        input_ids, attention_masks, position_ids = [], [], []
        for subject, prompt in zip(subjects, prompts):
            prompt_construct = f"USER: <image>\n A photo of a {subject}. Describe a new image of the same {subject} in: {prompt}. ASSISTANT: *"
            input_id = tokenizer_image_token(prompt_construct, self.tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
            attention_mask = torch.ones(input_id.shape, dtype=torch.long, device=self.device)
            position_id = torch.arange(input_id.shape[-1], device=self.device)
            position_ids += [position_id]
            attention_masks += [attention_mask[0]]
            input_ids += [input_id[0]]

        # Left-pad each sequence (flip, right-pad, flip back) so the trailing
        # "*" token stays aligned at the last position across the whole batch.
        input_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in input_ids], batch_first=True, padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
        position_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in position_ids], batch_first=True, padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
        attention_masks = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in attention_masks], batch_first=True, padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
        output = self.model_llava.my_llava_forward(self.model_llava, input_ids=input_ids, attention_mask=attention_masks, position_ids=position_ids, images=llava_processeds)
        output = self.mapping(output)  # map LLM hidden states into the generator's image-embedding space
        return output[:, -1, :]  # hidden state at the final position (the "*" placeholder)

    def reset(self):
        self.moMA_generator.reset_all()  # clear the generator's per-run state

    def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
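        """Generate an image of `subject` from the reference photo at
        `rgb_path`, following `prompt`. `strength` scales the adapted
        self-attention; `num` is accepted but currently unused. Returns a
        PIL image."""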
        batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject, self)
        self.moMA_generator.set_selfAttn_strength(strength)

        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
            with torch.no_grad():
                # key steps: extract the subject embedding from the MLLM, then generate
                llava_emb = self.forward_MLLM(batch).clone().detach()
                img, mask = self.moMA_generator.generate_with_MoMA(batch, llava_emb=llava_emb, seed=seed, device=self.args.device)
                self.reset()

        result = ToPILImage()(img[0])
        return result
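

# Usage sketch (illustrative, not part of the original demo): assumes `args`
# carries at least the attributes referenced in this file; the paths below are
# placeholders, not the repo's actual defaults.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(device='cuda', model_path='path/to/MoMA_llava_repo',
#                          load_8bit=False, load_4bit=False,
#                          load_attn_adapters='checkpoints/attn_adapters_projectors.th')
#   model = MoMA_main_modal(args)
#   image = model.generate_images('example.jpg', 'car',
#                                 'A car in autumn with falling leaves')
#   image.save('output.jpg')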