"""LISA model loading, inference function factory and gradio interface helpers."""
import argparse
import logging
import os
import re
from base64 import b64encode
from hashlib import sha256
from typing import Callable

import cv2
import gradio as gr
import nh3
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from lisa_on_cuda import app_logger
from lisa_on_cuda.LISA import LISAForCausalLM
from lisa_on_cuda.llava import conversation as conversation_lib
from lisa_on_cuda.llava.mm_utils import tokenizer_image_token
from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide

from . import constants, utils


placeholders = utils.create_placeholder_variables()

# HTML tags / attributes / URL schemes that survive input sanitization.
_ALLOWED_TAGS = {
    "a", "abbr", "acronym", "b", "blockquote", "code",
    "em", "i", "li", "ol", "strong", "ul",
}
_ALLOWED_ATTRIBUTES = {
    "a": {"href", "title"},
    "abbr": {"title"},
    "acronym": {"title"},
}
_ALLOWED_URL_SCHEMES = {"http", "https", "mailto"}

# Only ASCII letters, spaces and a little punctuation are accepted as a prompt.
# Used with ``fullmatch`` so a trailing newline cannot slip through.
_VALID_INPUT_RE = re.compile(r"[A-Za-z ,.!?\'\"]+")


def get_device_map_kwargs(device_map="auto", device="cuda"):
    """Build the ``device_map`` keyword argument used when loading the model.

    Args:
        device_map: value forwarded as-is when ``device`` is ``"cuda"``.
        device: target device; any value other than ``"cuda"`` replaces
            ``device_map`` with ``{"": device}``.

    Returns:
        dict with the single key ``"device_map"``.
    """
    if device != "cuda":
        return {"device_map": {"": device}}
    return {"device_map": device_map}


def parse_args(args_to_parse, internal_logger=None):
    """Parse the command line style arguments of the LISA chat.

    Args:
        args_to_parse: list of argument strings (e.g. ``sys.argv[1:]`` or ``[]``).
        internal_logger: logger to use; defaults to ``app_logger``.

    Returns:
        argparse.Namespace with the parsed arguments.
    """
    if internal_logger is None:
        internal_logger = app_logger
    internal_logger.info(f"ROOT_PROJECT:{utils.PROJECT_ROOT_FOLDER}, default vis_output:{utils.VIS_OUTPUT}.")
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1-explanatory")
    parser.add_argument("--vis_save_path", default=str(utils.VIS_OUTPUT), type=str)
    parser.add_argument(
        "--precision",
        default="fp16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    # These two default to True: a plain "store_true" action could never switch
    # them off. BooleanOptionalAction keeps the existing flags working and adds
    # the matching "--no-..." variants.
    parser.add_argument("--load_in_4bit", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--use_mm_start_end", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args_to_parse)


def get_cleaned_input(input_str, internal_logger=None):
    """Sanitize the user prompt with ``nh3.clean`` using a small tag allow-list.

    Args:
        input_str: raw user input.
        internal_logger: logger to use; defaults to ``app_logger``.

    Returns:
        the sanitized string.
    """
    if internal_logger is None:
        internal_logger = app_logger
    internal_logger.info(f"start cleaning of input_str: {input_str}.")
    input_str = nh3.clean(
        input_str,
        tags=_ALLOWED_TAGS,
        attributes=_ALLOWED_ATTRIBUTES,
        url_schemes=_ALLOWED_URL_SCHEMES,
        link_rel=None,
    )
    internal_logger.info(f"cleaned input_str: {input_str}.")
    return input_str


def set_image_precision_by_args(input_image, precision):
    """Cast ``input_image`` to the dtype matching ``precision``.

    Args:
        input_image: tensor to cast.
        precision: ``"bf16"`` -> bfloat16, ``"fp16"`` -> half, anything else -> float.

    Returns:
        the cast tensor.
    """
    if precision == "bf16":
        return input_image.bfloat16()
    if precision == "fp16":
        return input_image.half()
    return input_image.float()


def preprocess(
        x,
        pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
        pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
        img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input.

    Args:
        x: image tensor; its last two dimensions are treated as height and width.
        pixel_mean: per-channel mean subtracted from ``x``.
        pixel_std: per-channel standard deviation ``x`` is divided by.
        img_size: side of the padded output.

    Returns:
        the normalized tensor, padded on the bottom/right up to ``img_size``.
    """
    logging.info("preprocess started")
    # Normalize colors
    x = (x - pixel_mean) / pixel_std
    # Pad
    height, width = x.shape[-2:]
    pad_bottom = img_size - height
    pad_right = img_size - width
    x = F.pad(x, (0, pad_right, 0, pad_bottom))
    logging.info("preprocess ended")
    return x


def load_model_for_causal_llm_pretrained(
        version,
        torch_dtype,
        load_in_8bit,
        load_in_4bit,
        seg_token_idx,
        vision_tower,
        internal_logger: logging.Logger | None = None,
        device_map="auto",
        device="cuda"
):
    """Load the pretrained ``LISAForCausalLM`` model, optionally quantized.

    Args:
        version: model name or path given to ``from_pretrained``.
        torch_dtype: dtype used when no quantization is requested.
        load_in_8bit: use 8 bit quantization (ignored when ``load_in_4bit`` is set).
        load_in_4bit: use 4 bit (nf4, double quant) quantization; takes precedence.
        seg_token_idx: id of the ``[SEG]`` token, forwarded to the model.
        vision_tower: vision tower name, forwarded to the model.
        internal_logger: logger to use; defaults to ``app_logger``.
        device_map: see ``get_device_map_kwargs``.
        device: see ``get_device_map_kwargs``.

    Returns:
        the loaded model.
    """
    if internal_logger is None:
        internal_logger = app_logger
    internal_logger.debug(f"prepare kwargs, 4bit:{load_in_4bit}, 8bit:{load_in_8bit}.")
    kwargs = {
        "torch_dtype": torch_dtype,
        **get_device_map_kwargs(device_map=device_map, device=device),
    }
    # Both quantized modes force half precision and leave "visual_model" unquantized.
    if load_in_4bit:
        kwargs["torch_dtype"] = torch.half
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            llm_int8_skip_modules=["visual_model"],
        )
    elif load_in_8bit:
        kwargs["torch_dtype"] = torch.half
        kwargs["quantization_config"] = BitsAndBytesConfig(
            llm_int8_skip_modules=["visual_model"],
            load_in_8bit=True,
        )
    internal_logger.debug(f"start loading model:{version}.")
    _model = LISAForCausalLM.from_pretrained(
        version,
        low_cpu_mem_usage=True,
        vision_tower=vision_tower,
        seg_token_idx=seg_token_idx,
        # try to avoid CUDA init RuntimeError on ZeroGPU huggingface hardware (injected into kwargs)
        **kwargs
    )
    internal_logger.debug("model loaded!")
    return _model


def get_model(
        args_to_parse,
        internal_logger: logging.Logger | None = None,
        inference_decorator: Callable | None = None,
        device_map="auto",
        device="cpu",
        device2="cuda"
):
    """Load the LISA model together with its tokenizer and image helpers.

    Compatible with ZeroGPU (spaces 0.30.2).

    Args:
        args_to_parse: parsed arguments (see ``parse_args``); ``seg_token_idx``
            is set on it as a side effect.
        internal_logger: logger to use; defaults to ``app_logger``.
        inference_decorator: optional decorator (supported and tested: the
            ZeroGPU ``spaces.GPU`` decorator). When given it is applied to the
            *results* of the loading steps, not to the loading functions.
        device_map: forwarded to ``load_model_for_causal_llm_pretrained``.
        device: forwarded to ``load_model_for_causal_llm_pretrained`` only when
            ``inference_decorator`` is given; defaults to ``"cpu"``. Without a
            decorator the loader's own default device is used.
        device2: device the vision tower is moved to; defaults to ``"cuda"``.

    Returns:
        tuple ``(model, clip_image_processor, tokenizer, transform)``.
    """
    if internal_logger is None:
        internal_logger = app_logger
    internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
    try:
        vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
        internal_logger.info(f"vis_save_path_exists:{vis_save_path_exists}.")
        os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
    except PermissionError as pex:
        # Best effort: a missing output folder must not prevent model loading.
        internal_logger.info(f"PermissionError: {pex}, folder:{args_to_parse.vis_save_path}.")

    # Create model
    internal_logger.info(f"creating tokenizer: {args_to_parse.version}, max_length:{args_to_parse.model_max_length}.")
    _tokenizer = AutoTokenizer.from_pretrained(
        args_to_parse.version,
        cache_dir=None,
        model_max_length=args_to_parse.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    _tokenizer.pad_token = _tokenizer.unk_token
    internal_logger.info("tokenizer ok")
    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    dtype_by_precision = {"bf16": torch.bfloat16, "fp16": torch.half}
    torch_dtype = dtype_by_precision.get(args_to_parse.precision, torch.float32)

    internal_logger.debug(f"start loading causal llm:{args_to_parse.version}...")
    load_kwargs = {
        "torch_dtype": torch_dtype,
        "load_in_8bit": args_to_parse.load_in_8bit,
        "load_in_4bit": args_to_parse.load_in_4bit,
        "seg_token_idx": args_to_parse.seg_token_idx,
        "vision_tower": args_to_parse.vision_tower,
        "internal_logger": internal_logger,
        "device_map": device_map,
    }
    if inference_decorator:
        _model = inference_decorator(
            load_model_for_causal_llm_pretrained(
                args_to_parse.version,
                # try to avoid CUDA init RuntimeError on ZeroGPU huggingface hardware
                device=device,
                **load_kwargs
            )
        )
    else:
        # NOTE(review): ``device`` is intentionally not forwarded here, so the
        # loader default applies; confirm before unifying the two branches.
        _model = load_model_for_causal_llm_pretrained(args_to_parse.version, **load_kwargs)
    internal_logger.debug("causal llm loaded!")

    _model.config.eos_token_id = _tokenizer.eos_token_id
    _model.config.bos_token_id = _tokenizer.bos_token_id
    _model.config.pad_token_id = _tokenizer.pad_token_id
    _model.get_model().initialize_vision_modules(_model.get_model().config)

    internal_logger.debug(f"start vision tower:{args_to_parse.vision_tower}...")
    prepared = prepare_model_vision_tower(
        _model, args_to_parse, torch_dtype, internal_logger=internal_logger
    )
    if inference_decorator:
        prepared = inference_decorator(prepared)
    _model, vision_tower = prepared
    internal_logger.debug(f"_model type:{type(_model)}, vision_tower type:{type(vision_tower)}.")
    # set device to "cuda" try to avoid CUDA init RuntimeError on ZeroGPU huggingface hardware
    vision_tower.to(device=device2)
    internal_logger.debug("vision tower loaded, prepare clip image processor...")
    _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
    internal_logger.debug("clip image processor done.")
    _transform = ResizeLongestSide(args_to_parse.image_size)
    internal_logger.debug("start model evaluation...")
    if inference_decorator:
        inference_decorator(_model.eval())
    else:
        _model.eval()
    internal_logger.info("model preparation ok!")
    return _model, _clip_image_processor, _tokenizer, _transform


def prepare_model_vision_tower(
        _model,
        args_to_parse,
        torch_dtype,
        internal_logger: logging.Logger | None = None
):
    """Cast the model and its vision tower according to the requested precision.

    Args:
        _model: model returned by ``load_model_for_causal_llm_pretrained``.
        args_to_parse: parsed arguments; ``precision``, ``load_in_4bit`` and
            ``load_in_8bit`` are read.
        torch_dtype: dtype the vision tower is cast to first.
        internal_logger: logger to use; defaults to ``app_logger``.

    Returns:
        tuple ``(model, vision_tower)``.
    """
    if internal_logger is None:
        internal_logger = app_logger
    internal_logger.debug(f"start vision tower preparation, torch dtype:{torch_dtype}, args_to_parse:{args_to_parse}.")
    vision_tower = _model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    precision = args_to_parse.precision
    is_quantized = args_to_parse.load_in_4bit or args_to_parse.load_in_8bit
    if precision == "bf16":
        internal_logger.debug(f"vision tower precision bf16? {args_to_parse.precision}, 1.")
        _model = _model.bfloat16().cuda()
    elif precision == "fp16" and not is_quantized:
        internal_logger.debug(f"vision tower precision fp16? {args_to_parse.precision}, 2.")
        # Detach the vision tower while deepspeed wraps the model, then re-attach it.
        vision_tower = _model.get_model().get_vision_tower()
        _model.model.vision_tower = None
        # Imported lazily: only this branch needs deepspeed.
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=_model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        _model = model_engine.module
        _model.model.vision_tower = vision_tower.half().cuda()
    elif precision == "fp32":
        internal_logger.debug(f"vision tower precision fp32? {args_to_parse.precision}, 3.")
        _model = _model.float().cuda()
    vision_tower = _model.get_model().get_vision_tower()
    internal_logger.debug("vision tower ok!")
    return _model, vision_tower


def get_inference_model_by_args(
        args_to_parse,
        internal_logger0: logging.Logger | None = None,
        inference_decorator: Callable | None = None,
        device_map="auto",
        device="cuda"
):
    """Load the model and build the inference function bound to it.

    Compatible with ZeroGPU (spaces 0.30.2).

    Args:
        args_to_parse: parsed arguments (see ``parse_args``).
        internal_logger0: logger used while preparing; defaults to ``app_logger``.
        inference_decorator: optional decorator applied to the returned
            inference function (supported and tested: ZeroGPU ``spaces.GPU``).
            It is not forwarded to ``get_model``.
        device_map: forwarded to ``get_model``.
        device: forwarded to ``get_model``.

    Returns:
        the inference function (decorated when ``inference_decorator`` is given).
    """
    if internal_logger0 is None:
        internal_logger0 = app_logger
    internal_logger0.info(f"args_to_parse:{args_to_parse}, creating model...")
    model, clip_image_processor, tokenizer, transform = get_model(
        args_to_parse, internal_logger=internal_logger0, device_map=device_map, device=device
    )
    internal_logger0.info("created model, preparing inference function")
    no_seg_out = placeholders["no_seg_out"]

    def inference(
            input_str: str,
            input_image: str | np.ndarray,
            internal_logger: logging.Logger | None = None,
            embedding_key: str | None = None
    ):
        """Run LISA on a text instruction and an image.

        Args:
            input_str: text instruction; sanitized, then validated against a
                restricted character set.
            input_image: image file path or image array.
            internal_logger: logger to use; defaults to ``app_logger``.
            embedding_key: optional key forwarded to ``model.evaluate``; when
                ``None`` it is derived from the preprocessed image.

        Returns:
            ``(output_image, output_mask, output_str)``, or a fastapi
            ``JSONResponse`` with status 422 when the input is rejected.
        """
        if internal_logger is None:
            internal_logger = app_logger

        # filter out special chars
        input_str = get_cleaned_input(input_str, internal_logger=internal_logger)
        internal_logger.info(f" input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
        internal_logger.info(f"input_str: {input_str}, input_image: {type(input_image)}.")

        # input valid check (fullmatch also rejects the empty string)
        if not _VALID_INPUT_RE.fullmatch(input_str):
            output_str = f"[Error] Unprocessable Entity input: {input_str}."
            internal_logger.error(output_str)

            from fastapi import status
            from fastapi.responses import JSONResponse

            return JSONResponse(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                content={"msg": "Error - Unprocessable Entity"}
            )

        # Model Inference
        conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
        conv.messages = []

        prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + input_str
        if args_to_parse.use_mm_start_end:
            replace_token = (
                utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        internal_logger.info("read and preprocess image.")
        image_np = input_image
        if isinstance(input_image, str):
            image_np = cv2.imread(input_image)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        internal_logger.debug("start clip_image_processor.preprocess")
        pixel_values = clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"]
        image_clip = pixel_values[0].unsqueeze(0).cuda()
        internal_logger.debug("done clip_image_processor.preprocess")
        internal_logger.info(f"image_clip type: {type(image_clip)}.")
        image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)

        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]

        internal_logger.debug(f"starting preprocess image: {type(image_clip)}.")
        channels_first = torch.from_numpy(image).permute(2, 0, 1).contiguous()
        image = preprocess(channels_first).unsqueeze(0).cuda()
        internal_logger.info(f"done preprocess image:{type(image)}, image_clip type: {type(image_clip)}.")
        image = set_image_precision_by_args(image, args_to_parse.precision)

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        embedding_key = get_hash_array(embedding_key, image, internal_logger)
        internal_logger.info(f"start model evaluation with embedding_key {embedding_key}.")
        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
            model_logger=internal_logger,
            embedding_key=embedding_key
        )
        internal_logger.info("model evaluation done, start token decoding...")
        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        # Drop newlines and collapse double spaces left by token decoding.
        text_output = text_output.replace("\n", "").replace("  ", " ")
        text_output = text_output.split("ASSISTANT: ")[-1]
        internal_logger.info(
            f"token decoding ended,found n {len(pred_masks)} prediction masks, "
            f"text_output type: {type(text_output)}, text_output: {text_output}."
        )

        # Placeholders are returned when no usable mask is predicted; otherwise
        # the last non-empty mask wins.
        output_image = no_seg_out
        output_mask = no_seg_out
        for pred_mask in pred_masks:
            if pred_mask.shape[0] == 0 or pred_mask.shape[1] == 0:
                continue
            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask_bool = pred_mask > 0
            output_mask = pred_mask_bool.astype(np.uint8) * 255

            # Blend the masked area half-and-half with [255, 0, 0].
            output_image = image_np.copy()
            output_image[pred_mask_bool] = (
                image_np * 0.5
                + pred_mask_bool[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask_bool]

        output_str = f"ASSISTANT: {text_output} ..."
        internal_logger.info(
            f"output_image type: {type(output_image)}, output_mask type: {type(output_mask)}."
        )
        return output_image, output_mask, output_str

    internal_logger0.info("prepared inference function.")
    internal_logger0.info(f"inference decorator none? {type(inference_decorator)}.")
    if inference_decorator:
        return inference_decorator(inference)
    return inference


def get_gradio_interface(
        fn_inference: Callable,
        args: argparse.Namespace | None = None
):
    """Build the gradio interface around the inference function.

    Args:
        fn_inference: function taking ``(text instruction, image file path)``
            and returning ``(segmentation image, mask image, text output)``.
        args: optional parsed arguments; when given they are listed in the
            article shown with the demo, ahead of ``constants.article``.

    Returns:
        the configured ``gr.Interface``.
    """
    article_and_demo_parameters = constants.article
    if args is not None:
        args_dict = dict(vars(args))
        for arg_k, arg_v in args_dict.items():
            app_logger.debug(f"arg_k:{arg_k}, arg_v:{arg_v}.")
        parameters_list = "".join(
            f" * {arg_k}: {arg_v};\n" for arg_k, arg_v in args_dict.items()
        )
        article_and_demo_parameters = constants.demo_parameters + parameters_list
        app_logger.debug(f"args_dict:{args_dict}.")
        app_logger.debug(f"description_and_demo_parameters:{article_and_demo_parameters}.")
        article_and_demo_parameters += "\n\n" + constants.article
    return gr.Interface(
        fn_inference,
        inputs=[
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
            gr.Image(type="filepath", label="Input Image")
        ],
        outputs=[
            gr.Image(type="pil", label="segmentation Output"),
            gr.Image(type="pil", label="mask Output"),
            gr.Textbox(lines=1, placeholder=None, label="Text Output")
        ],
        title=constants.title,
        description=constants.description,
        article=article_and_demo_parameters,
        examples=constants.examples,
        allow_flagging="auto"
    )


def get_hash_array(
        embedding_key: str | None,
        arr: np.ndarray | torch.Tensor,
        model_logger: logging.Logger
):
    """Return ``embedding_key``, deriving it from ``arr`` when it is ``None``.

    The derived key is the base64 encoded sha256 digest of the array bytes.

    Args:
        embedding_key: existing key; returned unchanged when not ``None``.
        arr: image data to hash when no key is given.
        model_logger: logger used for debug messages.

    Returns:
        the embedding key as a string.
    """
    model_logger.debug(f"embedding_key {embedding_key} is None? {embedding_key is None}.")
    if embedding_key is not None:
        return embedding_key

    img2hash = arr
    if isinstance(arr, torch.Tensor):
        model_logger.debug("images variable is a Tensor, start converting back to numpy")
        if arr.dtype == torch.bfloat16:
            # numpy has no bfloat16 dtype: Tensor.numpy() raises on it, so upcast first.
            arr = arr.float()
        img2hash = arr.numpy(force=True)
        model_logger.debug("done Tensor converted back to numpy")
    model_logger.debug("start image hashing")
    # hashlib needs a C-contiguous buffer; this is a no-op for arrays that already are.
    digest = sha256(np.ascontiguousarray(img2hash)).digest()
    embedding_key = b64encode(digest).decode("utf-8")
    model_logger.debug(f"done image hashing, now embedding_key is {embedding_key}.")
    return embedding_key


if __name__ == '__main__':
    parsed_args = parse_args([])
    print("arrrrg:", parsed_args)