"""Interactive command-line demo for LISA reasoning segmentation.

Loads a LISA model (optionally 4-/8-bit quantized via bitsandbytes, or
DeepSpeed-accelerated in fp16), then repeatedly asks for a text prompt and an
image path, runs ``model.evaluate`` and writes every predicted mask plus a
red-blended overlay image into ``args.vis_save_path``.
"""

import logging
import os
import sys

import cv2
import numpy as np
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from model.LISA import LISAForCausalLM
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils import app_helpers, utils


def main(args):
    """Run the interactive segmentation chat loop.

    Args:
        args: Raw command-line arguments (e.g. ``sys.argv[1:]``); they are
            parsed with ``app_helpers.parse_args``.
    """
    args = app_helpers.parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create tokenizer; "[SEG]" must resolve to a single token id that the
    # model uses to emit segmentation masks.
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    model = _load_model(args, tokenizer)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()

    while True:
        prompt = _build_prompt(args, input("Please input your prompt: "))

        image_path = input("Please input the image path: ")
        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue

        image_np = cv2.imread(image_path)
        if image_np is None:
            # cv2.imread signals failure by returning None (directories,
            # unreadable or non-image files) rather than raising; skip the
            # input instead of crashing inside cv2.cvtColor.
            print("Failed to read image from {}".format(image_path))
            continue
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        # Input for the CLIP vision tower.
        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
            .unsqueeze(0)
            .cuda()
        )
        logging.info(f"image_clip type: {type(image_clip)}.")
        image_clip = app_helpers.set_image_precision_by_args(image_clip, args.precision)

        # Input for the segmentation branch: resized so the longest side is
        # args.image_size, then HWC -> CHW and batched.
        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]
        image = (
            app_helpers.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        logging.info(f"image type: {type(image)}.")
        image = app_helpers.set_image_precision_by_args(image, args.precision)

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        # Remove utils.IMAGE_TOKEN_INDEX entries, which the tokenizer cannot decode.
        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        # Single-line output: drop newlines and collapse double spaces.
        text_output = text_output.replace("\n", "").replace("  ", " ")
        # NOTE(review): unless logging is configured elsewhere, the root
        # logger's default WARNING level hides info() messages, so this answer
        # may never be shown to the user — confirm.
        logging.info(f"text_output: {text_output}.")

        _save_masks(args.vis_save_path, image_path, image_np, pred_masks)


def _load_model(args, tokenizer):
    """Load LISA with the requested precision/quantization and place it on the GPU.

    Args:
        args: Parsed arguments; must already carry ``seg_token_idx``.
        tokenizer: Tokenizer whose eos/bos/pad ids are copied into the config.

    Returns:
        The ready-to-use model (possibly the DeepSpeed-wrapped module).
    """
    torch_dtype = change_torch_dtype_by_precision(args.precision)
    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_skip_modules=["visual_model"],
                ),
            }
        )
    elif args.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )

    model = LISAForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        **kwargs,
    )

    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and not args.load_in_4bit and not args.load_in_8bit:
        # The vision tower is detached before DeepSpeed inference init and
        # re-attached in half precision afterwards — presumably to keep it out
        # of DeepSpeed's kernel injection.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed  # Imported lazily: only this code path needs it.

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    # Re-fetch: the branches above may have replaced `model`.
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)
    return model


def _build_prompt(args, user_text):
    """Wrap the user's text in the conversation template, prefixed by the image token.

    Args:
        args: Parsed arguments (``conv_type``, ``use_mm_start_end``).
        user_text: The raw prompt typed by the user.

    Returns:
        The full prompt string produced by the conversation template.
    """
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []

    prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + user_text
    if args.use_mm_start_end:
        replace_token = (
            utils.DEFAULT_IM_START_TOKEN
            + utils.DEFAULT_IMAGE_TOKEN
            + utils.DEFAULT_IM_END_TOKEN
        )
        prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)

    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], "")
    return conv.get_prompt()


def _save_masks(vis_save_path, image_path, image_np, pred_masks):
    """Write each non-empty predicted mask and a red-blended overlay image to disk.

    Args:
        vis_save_path: Output directory.
        image_path: Path of the source image; its file stem names the outputs.
        image_np: The source image as an RGB array (converted from cv2's BGR).
        pred_masks: Mask tensors from ``model.evaluate``; entries whose first
            dimension is empty are skipped, otherwise only index 0 is saved.
    """
    # splitext/basename keep inner dots ("a.b.jpg" -> "a.b") and respect the
    # OS path separator, unlike split("/")[-1].split(".")[0].
    stem = os.path.splitext(os.path.basename(image_path))[0]
    red = np.array([255, 0, 0])

    for i, pred_mask in enumerate(pred_masks):
        if pred_mask.shape[0] == 0:
            continue
        # Binarize: values > 0 are foreground.
        pred_mask = pred_mask.detach().cpu().numpy()[0] > 0

        save_path = "{}/{}_mask_{}.jpg".format(vis_save_path, stem, i)
        # Explicit uint8 (values 0/100) instead of the platform-default integer
        # array that `bool * 100` would produce.
        cv2.imwrite(save_path, pred_mask.astype(np.uint8) * 100)
        print("{} has been saved.".format(save_path))

        save_path = "{}/{}_masked_img_{}.jpg".format(vis_save_path, stem, i)
        save_img = image_np.copy()
        # 50/50 blend with red, computed only on masked pixels.
        save_img[pred_mask] = image_np[pred_mask] * 0.5 + red * 0.5
        save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(save_path, save_img)
        print("{} has been saved.".format(save_path))


def change_torch_dtype_by_precision(precision):
    """Map a precision flag to a torch dtype.

    Args:
        precision: "bf16", "fp16", or anything else.

    Returns:
        ``torch.bfloat16`` for "bf16", ``torch.half`` for "fp16", and
        ``torch.float32`` for any other value.
    """
    if precision == "bf16":
        return torch.bfloat16
    if precision == "fp16":
        return torch.half
    return torch.float32


if __name__ == "__main__":
    main(sys.argv[1:])