# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp
import warnings
from collections import OrderedDict

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
                          CLIPVisionModel, LlamaForCausalLM,
                          LlamaTokenizerFast, LlavaConfig,
                          LlavaForConditionalGeneration, LlavaProcessor)
from transformers.integrations import is_deepspeed_zero3_enabled

from xtuner.registry import BUILDER
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
                    get_peft_model_state_dict, guess_load_checkpoint,
                    make_inputs_require_grad,
                    prepare_inputs_labels_for_multimodal, traverse_dict)
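

# Remap checkpoint keys for export: each occurrence of a `mapping` key inside a
# parameter name is replaced by its value, and rotary `.inv_freq` buffers are
# skipped since they are not needed in the exported checkpoint.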
def convert_state_dict_to_hf(state_dict, mapping):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith('.inv_freq'):
            continue
        for key_to_modify, new_key in mapping.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)
        new_state_dict[key] = value
    return new_state_dict
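

# LLaVA-style multimodal model: a CLIP-like visual encoder feeds image features
# through an MLP projector into the LLM's embedding space. The LLM and visual
# encoder can be frozen and/or wrapped with LoRA adapters; the projector is
# always trained.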
class LLaVAModel(BaseModel):

    def __init__(self,
                 llm,
                 visual_encoder,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 visual_select_layer=-2,
                 pretrained_pth=None,
                 projector_depth=2,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True,
                 max_position_embeddings=None):
        super().__init__()
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        with LoadWoInit():
            if isinstance(llm, dict):
                llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)

            self.llm = self._build_from_cfg_or_module(llm)
            self.visual_encoder = self._build_from_cfg_or_module(
                visual_encoder)
        self.llm.config.use_cache = False
        dispatch_modules(self.llm)

        self.projector_depth = projector_depth
        projector_config = ProjectorConfig(
            visual_hidden_size=self.visual_encoder.config.hidden_size,
            llm_hidden_size=self.llm.config.hidden_size,
            depth=self.projector_depth)
        self.projector = ProjectorModel(projector_config).to(
            self.visual_encoder.dtype)
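
        # Freeze whole towers first; LoRA adapters requested below stay
        # trainable on top of the frozen weights.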
        if self.freeze_llm:
            self.llm.requires_grad_(False)
        if self.freeze_visual_encoder:
            self.visual_encoder.requires_grad_(False)

        if use_activation_checkpointing:
            # For backward compatibility
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            if hasattr(self.visual_encoder, 'enable_input_require_grads'):
                self.visual_encoder.enable_input_require_grads()
            else:
                self.visual_encoder.get_input_embeddings(
                ).register_forward_hook(make_inputs_require_grad)
            self.projector.enable_input_require_grads()

            # enable gradient (activation) checkpointing for memory efficiency
            self.gradient_checkpointing_enable()

        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(
                visual_encoder_lora, use_activation_checkpointing)
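
        # Optionally load weights from an earlier training stage (e.g. the
        # projector pretraining checkpoint); `strict=False` tolerates keys
        # that are missing because they were frozen and never saved.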
        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print_log(f'Load pretrained weight from {pretrained_pth}',
                      'current')

        self.visual_select_layer = visual_select_layer

        self._is_init = True
        self.is_first_iter = True

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, dict) or isinstance(
                lora_config, Config) or isinstance(lora_config, ConfigDict):
            lora_config = BUILDER.build(lora_config)
        return lora_config
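
    # `prepare_model_for_kbit_training` readies the (possibly quantized) LLM
    # for PEFT training; if the LoRA config names no target modules, every
    # linear layer found in the model is adapted.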
    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.visual_encoder)
            lora_config.target_modules = modules
        self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()
        self.visual_encoder.gradient_checkpointing_enable()
        self.projector.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        self.visual_encoder.gradient_checkpointing_disable()
        self.projector.gradient_checkpointing_disable()

    def init_weights(self):
        pass
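
    # Only trainable weights are checkpointed: LoRA adapters and/or unfrozen
    # towers for the visual encoder and LLM, plus the projector.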
    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()
        # Step 1. visual_encoder
        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.visual_encoder, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'visual_encoder.' in k
            })
        # Step 2. LLM
        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update(
                {k: v
                 for k, v in state_dict.items() if 'llm.' in k})
        # Step 3. Projector
        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'projector.' in k})
        return to_return
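
    # Extend the context window with linear RoPE scaling when the requested
    # `max_position_embeddings` exceeds the model's native context length.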
    @staticmethod
    def _prepare_for_long_context_training(cfg, llm_cfg,
                                           max_position_embeddings):
        orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
        if orig_rope_scaling is None:
            orig_rope_scaling = {'factor': 1}

        orig_rope_scaling_factor = orig_rope_scaling[
            'factor'] if 'factor' in orig_rope_scaling.keys() else 1
        orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
        if orig_ctx_len:
            orig_ctx_len *= orig_rope_scaling_factor
            if max_position_embeddings > orig_ctx_len:
                scaling_factor = float(
                    math.ceil(max_position_embeddings / orig_ctx_len))
                llm_cfg.rope_scaling = {
                    'type': 'linear',
                    'factor': scaling_factor
                }

        # hardcode for internlm2
        llm_cfg.attn_implementation = 'flash_attention_2'
        cfg.config = llm_cfg
        return cfg, llm_cfg
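
    # Pick an attention backend for the detected config class: FlashAttention 2
    # if available and supported, otherwise PyTorch SDPA; otherwise the config
    # is left unchanged.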
    @staticmethod
    def _prepare_for_flash_attn(cfg, llm_cfg):
        cls_name = type(llm_cfg).__name__
        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
                             'Starcoder2Config', 'Phi3Config')
        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
                               'Qwen2MoeConfig', 'Starcoder2Config',
                               'Phi3Config')

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        if getattr(cfg, 'attn_implementation', None) is not None:
            # Flash Attention 2.0 only supports torch.float16 and
            # torch.bfloat16 dtypes
            if cfg.attn_implementation == 'flash_attention_2':
                cfg.torch_dtype = torch_dtype
        elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
            cfg.torch_dtype = torch_dtype
            cfg.attn_implementation = 'flash_attention_2'
        elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
            cfg.attn_implementation = 'sdpa'

        return cfg, llm_cfg
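
    # QLoRA under DeepSpeed ZeRO-3: align the model dtype and the bitsandbytes
    # 4-bit compute/storage dtypes so the quantized weights can be partitioned.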
    @staticmethod
    def _prepare_for_qlora_zero3(cfg):
        if (not is_deepspeed_zero3_enabled()) or (not hasattr(
                cfg, 'quantization_config')):
            return cfg

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        cfg.torch_dtype = torch_dtype
        quantization_config = cfg.quantization_config
        quantization_config.bnb_4bit_compute_dtype = torch_dtype
        quantization_config.bnb_4bit_quant_storage = torch_dtype
        return cfg
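
    # Normalize an LLM build config before instantiation: handle QLoRA under
    # ZeRO-3, choose an attention implementation and, if needed, set up RoPE
    # scaling for long-context training.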
    def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
        cfg = self._prepare_for_qlora_zero3(cfg)
        pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
        llm_cfg = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True)
        cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
        if max_position_embeddings is not None:
            cfg, llm_cfg = self._prepare_for_long_context_training(
                cfg, llm_cfg, max_position_embeddings)
        return cfg

    def _build_from_cfg_or_module(self, cfg_or_mod):
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError
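
    # mmengine entry point: `mode` selects between computing the loss ('loss'),
    # returning per-sample logits ('predict') or raw LLM outputs ('tensor').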
    def forward(self, data, data_samples=None, mode='loss'):
        if self.is_first_iter:
            # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState
            # to device
            # Only required in `LLaVAModel`.
            # We do not need this in `SupervisedFinetune`.
            self.to(data['input_ids'].device)
            self.is_first_iter = False

        if 'pixel_values' in data:
            visual_outputs = self.visual_encoder(
                data['pixel_values'].to(self.visual_encoder.dtype),
                output_hidden_states=True)
            pixel_values = self.projector(
                visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
            data['pixel_values'] = pixel_values
            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None):
        outputs = self.llm(**data)
        loss_dict = {'loss': outputs.loss}
        return loss_dict

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)
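
    # Export helpers. 'xtuner' keeps the modular layout used during training,
    # 'huggingface' produces a LlavaForConditionalGeneration checkpoint, and
    # 'official' targets the original LLaVA (haotian-liu) code base.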
    def to_hf(self,
              cfg,
              save_dir,
              fp32=False,
              save_pretrained_kwargs={},
              save_format='xtuner',
              **kwargs):
        if save_format == 'xtuner':
            self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
        elif save_format == 'huggingface':
            self.to_huggingface_llava(cfg, save_dir, fp32,
                                      save_pretrained_kwargs)
        elif save_format == 'official':
            self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
        else:
            raise NotImplementedError
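
    # xtuner layout: LoRA adapters are saved to `llm_adapter/` and
    # `visual_encoder_adapter/`, an unfrozen LLM (plus tokenizer) to the root
    # of `save_dir`, an unfrozen visual encoder to `visual_encoder/`, and the
    # projector always to `projector/`.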
    def to_xtuner_llava(self,
                        cfg,
                        save_dir,
                        fp32=False,
                        save_pretrained_kwargs={}):
        # LLM
        self.llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            self.llm.half()
        if self.use_llm_lora:
            llm_path = osp.join(save_dir, 'llm_adapter')
            print_log(f'Saving LLM adapter to {llm_path}', 'current')
            self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
        elif not self.freeze_llm:
            llm_path = save_dir
            print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
            tokenizer = BUILDER.build(cfg.tokenizer)
            tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
            print_log(f'Saving LLM to {llm_path}', 'current')
            self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
        self.llm.config.use_cache = False

        # Visual Encoder
        if self.use_visual_encoder_lora:
            visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
            print_log(
                f'Saving visual_encoder adapter to {visual_encoder_path}',
                'current')
            self.visual_encoder.save_pretrained(visual_encoder_path,
                                                **save_pretrained_kwargs)
        elif not self.freeze_visual_encoder:
            visual_encoder_path = osp.join(save_dir, 'visual_encoder')
            print_log(
                'Saving visual_encoder image_processor to '
                f'{visual_encoder_path}', 'current')
            image_processor = BUILDER.build(cfg.image_processor)
            image_processor.save_pretrained(visual_encoder_path,
                                            **save_pretrained_kwargs)
            print_log(f'Saving visual_encoder to {visual_encoder_path}',
                      'current')
            self.visual_encoder.save_pretrained(visual_encoder_path,
                                                **save_pretrained_kwargs)

        # Projector
        projector_path = osp.join(save_dir, 'projector')
        print_log(f'Saving projector to {projector_path}', 'current')
        self.projector.save_pretrained(projector_path,
                                       **save_pretrained_kwargs)
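
    # Convert to a native `LlavaForConditionalGeneration` checkpoint: merge any
    # LoRA weights, remap parameter names, add the `<image>` and `<pad>` tokens
    # and resize the embeddings accordingly.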
    def to_huggingface_llava(self,
                             cfg,
                             save_dir,
                             fp32=False,
                             save_pretrained_kwargs={}):

        LLM_MAPPING = {
            'model': 'language_model.model',
            'lm_head': 'language_model.lm_head',
        }
        VIT_MAPPING = {
            'vision_model': 'vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'multi_modal_projector.linear_1',
            'model.2': 'multi_modal_projector.linear_2',
        }

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        # get state_dict
        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()
        llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)

        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel),\
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict
        }

        # init model
        text_config = llm.config
        vision_config = visual_encoder.config
        config = LlavaConfig(
            text_config=text_config,
            vision_config=vision_config,
            attn_implementation='eager')

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaForConditionalGeneration(config)
        model.load_state_dict(state_dict, strict=True, assign=True)

        # processor
        cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
        tokenizer = BUILDER.build(cfg.tokenizer)

        tokenizer.add_tokens(
            AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
            special_tokens=True)
        tokenizer.add_special_tokens({'pad_token': '<pad>'})

        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor),\
            'This conversion format only supports CLIPImageProcessor.'

        processor = LlavaProcessor(
            tokenizer=tokenizer, image_processor=image_processor)

        # Pad to 64 for performance reasons
        pad_shape = 64
        pre_expansion_embeddings = \
            model.language_model.model.embed_tokens.weight.data
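        # Fit a multivariate normal to the existing token embeddings; the
        # rows added for the new tokens below are sampled from it.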
        mu = torch.mean(pre_expansion_embeddings, dim=0).float()
        n = pre_expansion_embeddings.size()[0]
        sigma = ((pre_expansion_embeddings - mu).T
                 @ (pre_expansion_embeddings - mu)) / n
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            mu, covariance_matrix=1e-5 * sigma)

        # We add an image token so we need to resize the model
        ori_vocab_size = config.text_config.vocab_size
        tokenizer_vocab_size = tokenizer.encode('<pad>')[-1]
        added_token = tokenizer_vocab_size - ori_vocab_size

        if added_token > 0:
            model.resize_token_embeddings(ori_vocab_size + added_token,
                                          pad_shape)
            model.language_model.model.embed_tokens.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(
                        dist.sample()
                        for _ in range(model.language_model.model.embed_tokens.
                                       weight.data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
            model.language_model.lm_head.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(dist.sample()
                          for _ in range(model.language_model.lm_head.weight.
                                         data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
        model.config.image_token_index = tokenizer.encode(
            DEFAULT_IMAGE_TOKEN)[-1]
        model.config.pad_token_id = tokenizer.encode('<pad>')[-1]

        # save
        print_log(f'Saving to {save_dir}', 'current')
        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        processor.save_pretrained(save_dir, **save_pretrained_kwargs)
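
    # Export a checkpoint loadable by the official LLaVA code base
    # (LlavaLlamaForCausalLM); requires the `llava` package to be installed.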
    def to_official_llava(self,
                          cfg,
                          save_dir,
                          fp32=False,
                          save_pretrained_kwargs={}):

        VIT_MAPPING = {
            'vision_model': 'model.vision_tower.vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'model.mm_projector.0',
            'model.2': 'model.mm_projector.2',
        }

        try:
            from llava.model import LlavaConfig, LlavaLlamaForCausalLM
        except ImportError:
            raise ImportError(
                'Please install llava with '
                '`pip install git+https://github.com/haotian-liu/LLaVA.git '
                '--no-deps`.')

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        # get state_dict
        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()

        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel),\
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict
        }

        # init model
        tokenizer = BUILDER.build(cfg.tokenizer)
        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor),\
            'This conversion format only supports CLIPImageProcessor.'
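
        # Assemble the official LlavaConfig: start from the LLM config and add
        # the mm_* fields (projector type, vision tower path, selected feature
        # layer, ...) that the LLaVA code base expects.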
        llava_config_dict = llm.config.__dict__.copy()
        llava_config_dict.update(
            dict(
                image_aspect_ratio='pad',
                mm_hidden_size=visual_encoder.config.hidden_size,
                mm_projector_type=f'mlp{self.projector_depth}x_gelu',
                mm_use_im_patch_token=False,
                mm_use_im_start_end=False,
                mm_vision_select_feature='patch',
                mm_vision_select_layer=self.visual_select_layer,
                mm_vision_tower=visual_encoder.config.name_or_path,
                unfreeze_mm_vision_tower=need_visual_encoder,
                model_type='llava',
                use_cache=True,
                use_mm_proj=True))

        llava_config = LlavaConfig(**llava_config_dict)

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaLlamaForCausalLM(llava_config)

        model.load_state_dict(state_dict, strict=True, assign=True)

        # save
        print_log(f'Saving to {save_dir}', 'current')
        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
        tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)