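# Gradio demo app for OMG-LLaVA: grounded conversation over an input image
# with optional point/box visual prompts and segmentation-mask outputs,
# packaged as a Hugging Face Space.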
import spaces
import gradio as gr
import numpy as np
import sys
from omg_llava.tools.app_utils import process_markdown, show_mask_pred, parse_visual_prompts
import torch
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, CLIPImageProcessor,
                          CLIPVisionModel, GenerationConfig)
from transformers.generation.streamers import TextStreamer
from xtuner.dataset.utils import expand2square, load_image
from omg_llava.dataset.utils import expand2square_bbox, expand2square_mask, expand2square_points
import argparse
import os.path as osp
from gradio_image_prompter import ImagePrompter

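# The xtuner/mmengine imports below are wrapped in a function and promoted to
# module-level globals, presumably so they can be deferred until after
# `spaces` is imported (the `@spaces.GPU` decorator is left commented out);
# `import_func()` is invoked once right after its definition.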
# @spaces.GPU
def import_func():
    global prepare_inputs_labels_for_multimodal_with_visual_prompts
    global get_stop_criteria
    global DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, PROMPT_TEMPLATE, SYSTEM_TEMPLATE
    global Config, DictAction
    global PetrelBackend, get_file_backend
    global cfgs_name_path
    global guess_load_checkpoint
    global BUILDER

    from omg_llava.model.utils import prepare_inputs_labels_for_multimodal_with_visual_prompts
    from xtuner.tools.utils import get_stop_criteria
    from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                              PROMPT_TEMPLATE, SYSTEM_TEMPLATE)
    from mmengine.config import Config, DictAction
    from mmengine.fileio import PetrelBackend, get_file_backend
    from xtuner.configs import cfgs_name_path
    from xtuner.model.utils import guess_load_checkpoint
    from xtuner.registry import BUILDER
    return


import_func()

TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')

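# Command-line options for the demo. The defaults point at the bundled
# hf_app config and a local copy of the OMG-LLaVA finetune checkpoint.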
def parse_args(args):
    parser = argparse.ArgumentParser(description="OMG-LLaVA Demo")
    parser.add_argument('--config', help='config file name or path.',
                        default='./omg_llava/configs/finetune/hf_app.py')
    parser.add_argument('--pth_model', help='pth model file',
                        default='./pretrained/omg_llava/omg_llava_fintune_8gpus.pth')
    parser.add_argument('--image', default=None, help='image file path')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default="internlm2_chat",
        help='Specify a prompt template')
    system_group = parser.add_mutually_exclusive_group()
    system_group.add_argument(
        '--system', default=None, help='Specify the system text')
    system_group.add_argument(
        '--system-template',
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help='Specify a system template')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--with-plugins',
        nargs='+',
        choices=['calculate', 'solve', 'search'],
        help='Specify plugins to use')
    parser.add_argument(
        '--no-streamer', action='store_true',
        help='Whether to disable the text streamer')
    parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=2048,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='The value used to modulate the next token probabilities.')
    parser.add_argument(
        '--top-k',
        type=int,
        default=40,
        help='The number of highest probability vocabulary tokens to '
        'keep for top-k-filtering.')
    parser.add_argument(
        '--top-p',
        type=float,
        default=0.75,
        help='If set to float < 1, only the smallest set of most probable '
        'tokens with probabilities that add up to top_p or higher are '
        'kept for generation.')
    parser.add_argument(
        '--repetition-penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    return parser.parse_args(args)

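# Encode point/box visual prompts with the visual encoder's SAM-style prompt
# branches and project them into the LLM embedding space; one embedding is
# produced per <mark>/<region> token found in `input_ids`.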
def get_points_embeddings(points, input_ids, width, height,
                          mark_token_idx, mode='point'):
    if points is None or len(points) == 0:
        return []

    mark_token_mask = input_ids == mark_token_idx
    batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to(
        input_ids.device)
    batch_idxs = batch_idxs[mark_token_mask]  # (N, ) batch index of each prompt token

    points = points.to(torch.float32)
    # print(points.dtype, batch_idxs.dtype)
    if mode == 'point':
        marks_embeddings = visual_encoder.forward_point_sam(
            points, batch_idxs, width=width, height=height
        )[:, 0]  # (N, C)
    elif mode == 'box':
        marks_embeddings = visual_encoder.forward_box_sam(
            points, batch_idxs, width=width, height=height
        )[:, 0]  # (N, C)
    else:
        raise NotImplementedError

    marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype)
    marks_embeddings = projector.model.query_proj(marks_embeddings)
    marks_embeddings = projector.model.model(marks_embeddings)
    print('marks_embeddings shape: ', marks_embeddings.shape)
    return marks_embeddings  # (N, C)

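# Map the accumulated point/box prompts from image coordinates onto the
# padded 1024x1024 square canvas used by the visual encoder, then embed them.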
def get_visual_prompts_embeddings(
        height, width, input_ids,
):
    points_prompts = global_infos.point_prompts
    boxes_prompts = global_infos.box_prompts

    if len(points_prompts) == 0:
        points_mark_embedding = []
    else:
        points = np.array(points_prompts)
        points = expand2square_points(points, height=height, width=width)
        points[:, 0] = points[:, 0] / max(height, width) * 1024
        points[:, 1] = points[:, 1] / max(height, width) * 1024
        points = torch.from_numpy(points)
        points = points.cuda()
        mark_token_id = omg_llava.mark_token_idx
        points_mark_embedding = get_points_embeddings(
            points, input_ids,
            1024, 1024,
            mark_token_id)

    if len(boxes_prompts) == 0:
        boxes_mark_embedding = []
    else:
        boxes_prompts = np.array(boxes_prompts)
        boxes_prompts = expand2square_bbox(boxes_prompts, height=height, width=width)
        boxes_prompts[:, [0, 2]] = boxes_prompts[:, [0, 2]] / max(height, width) * 1024
        boxes_prompts[:, [1, 3]] = boxes_prompts[:, [1, 3]] / max(height, width) * 1024
        boxes_prompts = torch.from_numpy(boxes_prompts)
        boxes_prompts = boxes_prompts.cuda()
        # using <region> token; boxes go through the box branch of the encoder
        region_token_id = omg_llava.region_token_idx
        boxes_mark_embedding = get_points_embeddings(
            boxes_prompts, input_ids,
            1024, 1024,
            region_token_id, mode='box')
    return points_mark_embedding, boxes_mark_embedding

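# Main Gradio callback: prepares the image and prompt on the first turn,
# runs generation, and turns any segmentation-token hidden states into masks
# overlaid on the (unpadded) input image.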
def inference(input_str, all_inputs, follow_up):
    input_str = input_str.replace('<point>', '<mark>')\
        .replace('<box>', '<region>')
    print("Received inputs!")
    prompts = all_inputs['points']
    visual_prompts = parse_visual_prompts(prompts)
    input_image = all_inputs['image']

    print("follow_up: ", follow_up)
    print(prompts)
    print("input_str: ", input_str, "input_image: ", input_image)

    if not follow_up:
        # reset conversation state
        print('Log: History responses have been removed!')
        global_infos.n_turn = 0
        global_infos.inputs = ''
        # reset prompts
        global_infos.point_prompts = []
        global_infos.box_prompts = []
        global_infos.mask_prompts = []

        # first conversation turn, add image tokens
        text = DEFAULT_IMAGE_TOKEN + '\n' + input_str

        # prepare image
        image = load_image(input_image)
        width, height = image.size
        global_infos.image_width = width
        global_infos.image_height = height
        image = expand2square(
            image, tuple(int(x * 255) for x in image_processor.image_mean))
        global_infos.image_for_show = image
        image = image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
        visual_outputs = visual_encoder(image, output_hidden_states=True)
        pixel_values = projector(visual_outputs)
        global_infos.panoptic_masks = omg_llava.visual_encoder.vis_binary_masks
        global_infos.pixel_values = pixel_values

        # remember the crop range that removes the square padding later
        if width == height:
            sx, ex, sy, ey = 0, width, 0, height
        elif width > height:
            sy = int((width - height) / 2.0)
            ey = width - sy
            sx, ex = 0, width
        else:
            sx = int((height - width) / 2.0)
            ex = height - sx
            sy, ey = 0, height
        global_infos.sx = sx
        global_infos.sy = sy
        global_infos.ex = ex
        global_infos.ey = ey
    else:
        text = input_str
        pixel_values = global_infos.pixel_values

    # add current visual prompts into the accumulated prompts
    global_infos.point_prompts += visual_prompts['points']
    global_infos.box_prompts += visual_prompts['boxes']

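    # Wrap the user text in the chat prompt template (system text only on the
    # first turn) and append it to the running conversation history.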
    if args.prompt_template:
        prompt_text = ''
        template = PROMPT_TEMPLATE[args.prompt_template]
        if 'SYSTEM' in template and global_infos.n_turn == 0:
            system_text = None
            if args.system_template is not None:
                system_text = SYSTEM_TEMPLATE[
                    args.system_template].format(
                        round=global_infos.n_turn + 1, bot_name=args.bot_name)
            elif args.system is not None:
                system_text = args.system
            if system_text is not None:
                prompt_text += template['SYSTEM'].format(
                    system=system_text,
                    round=global_infos.n_turn + 1,
                    bot_name=args.bot_name)
        prompt_text += template['INSTRUCTION'].format(
            input=text, round=global_infos.n_turn + 1, bot_name=args.bot_name)
    else:
        prompt_text = text
    print("prompt_text: ", prompt_text)
    global_infos.inputs += prompt_text

    # encode prompt text
    chunk_encode = []
    for idx, chunk in enumerate(global_infos.inputs.split(DEFAULT_IMAGE_TOKEN)):
        if idx == 0 and global_infos.n_turn == 0:
            cur_encode = tokenizer.encode(chunk)
        else:
            cur_encode = tokenizer.encode(
                chunk, add_special_tokens=False)
        chunk_encode.append(cur_encode)
    assert len(chunk_encode) == 2
    ids = []
    for idx, cur_chunk_encode in enumerate(chunk_encode):
        ids.extend(cur_chunk_encode)
        if idx != len(chunk_encode) - 1:
            ids.append(IMAGE_TOKEN_INDEX)
    ids = torch.tensor(ids).cuda().unsqueeze(0)

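    # Embed the current point/box prompts and splice image + prompt features
    # into the token embeddings that will drive generation.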
    points_mark_embeddings, boxes_mark_embeddings = get_visual_prompts_embeddings(
        height=global_infos.image_height,
        width=global_infos.image_width, input_ids=ids
    )
    mark_embeddings = points_mark_embeddings
    mark_token_id = omg_llava.mark_token_idx

    mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts(
        llm=llm, input_ids=ids, pixel_values=pixel_values,
        mark_id=mark_token_id,
        mark_feats=mark_embeddings, region_id=-9999)
    # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16)

    generate_output = llm.generate(
        **mm_inputs,
        generation_config=gen_config,
        streamer=streamer,
        bos_token_id=tokenizer.bos_token_id,
        stopping_criteria=stop_criteria,
        output_hidden_states=True,
        return_dict_in_generate=True
    )

    predict = tokenizer.decode(
        generate_output.sequences[0])
    global_infos.inputs += predict
    predict = predict.strip()
    global_infos.n_turn += 1
    global_infos.inputs += sep
    if len(generate_output.sequences[0]) >= args.max_new_tokens:
        print(
            'Remove the memory of history responses, since '
            f'it exceeds the length limitation {args.max_new_tokens}.')
        global_infos.n_turn = 0
        global_infos.inputs = ''

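    # Gather the hidden states emitted at segmentation tokens and decode them
    # via projector_text2vision and forward_llm_seg into predicted masks.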
    hidden_states = generate_output.hidden_states
    last_hidden_states = [item[-1][0] for item in hidden_states]
    last_hidden_states = torch.cat(last_hidden_states, dim=0)
    seg_hidden_states = get_seg_hidden_states(
        last_hidden_states, generate_output.sequences[0][:-1],
        seg_id=omg_llava.seg_token_idx
    )
    # seg_hidden_states = seg_hidden_states.to(torch.float32)

    if len(seg_hidden_states) != 0:
        seg_hidden_states = projector_text2vision(seg_hidden_states)
        batch_idxs = torch.zeros((seg_hidden_states.shape[0],),
                                 dtype=torch.int64).to(seg_hidden_states.device)
        pred_masks_list = omg_llava.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs)
        print((pred_masks_list[-1].flatten(2) > 0).sum(-1))
        print(pred_masks_list[-1].shape)
        image_mask_show, selected_colors = show_mask_pred(
            global_infos.image_for_show, pred_masks_list[-1],
            crop_range=(global_infos.sx, global_infos.ex, global_infos.sy, global_infos.ey)
        )
    else:
        image_mask_show = global_infos.image_for_show.crop(
            (global_infos.sx, global_infos.sy, global_infos.ex, global_infos.ey))
        selected_colors = []

    panoptic_show, _ = show_mask_pred(
        global_infos.image_for_show, global_infos.panoptic_masks,
        crop_range=(global_infos.sx, global_infos.ex, global_infos.sy, global_infos.ey)
    )

    predict = process_markdown(predict, selected_colors)
    # return panoptic_show, image_mask_show, predict
    return image_mask_show, predict

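# Build the OMG-LLaVA model from the mmengine config, load the .pth
# checkpoint, and return the sub-modules (LLM, tokenizer, image processor,
# visual encoder, projectors) used by the demo.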
def init_models(args):
    torch.manual_seed(args.seed)

    # parse config
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    # load config
    cfg = Config.fromfile(args.config)
    model_name = cfg.model.type if isinstance(cfg.model.type,
                                              str) else cfg.model.type.__name__
    if 'LLaVAModel' in model_name or 'OMG' in model_name:
        cfg.model.pretrained_pth = None

    model = BUILDER.build(cfg.model)

    backend = get_file_backend(args.pth_model)
    if isinstance(backend, PetrelBackend):
        from xtuner.utils.fileio import patch_fileio
        with patch_fileio():
            state_dict = guess_load_checkpoint(args.pth_model)
    else:
        state_dict = guess_load_checkpoint(args.pth_model)
    model.load_state_dict(state_dict, strict=False)
    print(f'Load PTH model from {args.pth_model}')

    image_processor = cfg.image_processor
    image_processor_type = image_processor['type']
    del image_processor['type']
    image_processor = image_processor_type(**image_processor)

    # build llm
    quantization_config = None
    load_in_8bit = False
    if args.bits == 4:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')
    elif args.bits == 8:
        load_in_8bit = True
    model_kwargs = {
        'quantization_config': quantization_config,
        'load_in_8bit': load_in_8bit,
        'device_map': 'auto',
        'offload_folder': args.offload_folder,
        'trust_remote_code': True,
        'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
    }
    inner_thoughts_open = False
    calculate_open = False
    solve_open = False
    search_open = False

    llm = model.llm
    tokenizer = model.tokenizer

    model.cuda()
    model.eval()
    llm.eval()
    visual_encoder = model.visual_encoder
    projector = model.projector
    projector_text2vision = model.projector_text2vision

    visual_encoder.eval()
    projector.eval()
    projector_text2vision.eval()
    return model, llm, tokenizer, image_processor, visual_encoder, projector, projector_text2vision

def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
    print(output_ids)
    return hidden_states[-n_out:][seg_mask]

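# Module-level conversation state shared across Gradio callbacks: chat
# history, cached image features, the padding crop range, and the visual
# prompts accumulated so far.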
class global_infos:
    inputs = ''
    n_turn = 0
    image_width = 0
    image_height = 0

    image_for_show = None
    pixel_values = None
    panoptic_masks = None

    sx, sy, ex, ey = 0, 0, 1024, 1024

    point_prompts = []
    box_prompts = []
    mask_prompts = []

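# Entry point: parse arguments, build the models and generation config, then
# launch the Gradio UI. A typical (assumed) local launch would be, e.g.,
#   python app.py --config ./omg_llava/configs/finetune/hf_app.py \
#       --pth_model ./pretrained/omg_llava/omg_llava_fintune_8gpus.pth
# which simply restates the defaults above.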
if __name__ == "__main__": | |
# get parse args and set models | |
args = parse_args(sys.argv[1:]) | |
omg_llava, llm, tokenizer, image_processor, visual_encoder, projector, projector_text2vision = \ | |
init_models(args) | |
stop_words = args.stop_words | |
sep = '' | |
if args.prompt_template: | |
template = PROMPT_TEMPLATE[args.prompt_template] | |
stop_words += template.get('STOP_WORDS', []) | |
sep = template.get('SEP', '') | |
stop_criteria = get_stop_criteria( | |
tokenizer=tokenizer, stop_words=stop_words) | |
if args.no_streamer: | |
streamer = None | |
else: | |
streamer = TextStreamer(tokenizer, skip_prompt=True) | |
gen_config = GenerationConfig( | |
max_new_tokens=args.max_new_tokens, | |
do_sample=args.temperature > 0, | |
temperature=args.temperature, | |
top_p=args.top_p, | |
top_k=args.top_k, | |
repetition_penalty=args.repetition_penalty, | |
eos_token_id=tokenizer.eos_token_id, | |
pad_token_id=tokenizer.pad_token_id | |
if tokenizer.pad_token_id is not None else tokenizer.eos_token_id, | |
) | |
demo = gr.Interface( | |
inference, inputs=[gr.Textbox(lines=1, placeholder=None, label="Text Instruction"), ImagePrompter( | |
type='filepath', label='Input Image (Please click points or draw bboxes)', interactive=True, | |
elem_id='image_upload', height=360, visible=True, render=True | |
), | |
gr.Checkbox(label="Follow up Question")], | |
outputs=[ | |
# gr.Image(type="pil", label="Panoptic Segmentation", height=360), | |
gr.Image(type="pil", label="Output Image"), | |
gr.Markdown()], | |
theme=gr.themes.Soft(), allow_flagging="auto", ) | |
demo.queue() | |
demo.launch(share=True) | |