import logging
import os
import sys

import cv2
import numpy as np
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from model.LISA import LISAForCausalLM
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils import app_helpers, utils
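
# Interactive LISA demo: build the tokenizer and model once, then loop reading
# a text prompt and an image path from stdin, run reasoning segmentation, and
# save the predicted masks (plus red-overlay visualizations) to args.vis_save_path.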


def main(args):
    args = app_helpers.parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
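    # NOTE: the [SEG] token id computed above is handed to LISAForCausalLM so the
    # hidden state of each generated [SEG] token can be projected into the SAM
    # mask decoder as a prompt embedding (per the LISA design).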

    torch_dtype = change_torch_dtype_by_precision(args.precision)

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_skip_modules=["visual_model"],
                ),
            }
        )
    elif args.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )
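    # Both quantized paths force fp16 compute and exclude the SAM backbone
    # ("visual_model") from quantization via llm_int8_skip_modules, so mask
    # prediction runs on unquantized half-precision weights.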

    model = LISAForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        **kwargs,
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
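    # The CLIP vision tower appears to be instantiated by initialize_vision_modules
    # rather than by from_pretrained, which is presumably why it is cast to the
    # requested dtype separately here.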

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif (
        args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
    ):
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)
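    # Precision summary: bf16/fp32 simply move the whole model to the GPU; fp16
    # (when not quantized) wraps the LLM with DeepSpeed kernel injection, detaching
    # the vision tower first so DeepSpeed does not try to inject into it, then
    # reattaching it in fp16. The 4/8-bit paths skip .cuda() entirely, presumably
    # because bitsandbytes handles device placement at load time (assumption).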

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()
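
    # Interactive loop: each iteration reads one text prompt and one image path
    # from stdin, then writes the predicted masks to disk.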
    while True:
        conv = conversation_lib.conv_templates[args.conv_type].copy()
        conv.messages = []

        prompt = input("Please input your prompt: ")
        prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
        if args.use_mm_start_end:
            replace_token = (
                utils.DEFAULT_IM_START_TOKEN
                + utils.DEFAULT_IMAGE_TOKEN
                + utils.DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()
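        # `prompt` now holds the full conversation template: the image placeholder
        # token (optionally wrapped in image-start/end tokens) prepended to the
        # user's question, plus an empty assistant turn for generation.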

        image_path = input("Please input the image path: ")
        if not os.path.exists(image_path):
            print("File not found: {}".format(image_path))
            continue

        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        logging.info(f"image_clip type: {type(image_clip)}.")
        image_clip = app_helpers.set_image_precision_by_args(image_clip, args.precision)

        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]
        image = (
            app_helpers.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        logging.info(f"image type: {type(image)}.")
        image = app_helpers.set_image_precision_by_args(image, args.precision)
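        # Two parallel image tensors now exist: `image_clip` (CLIP-preprocessed,
        # fed to the LLaVA vision tower) and `image` (longest-side resize plus
        # pixel normalization, fed to the SAM image encoder).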

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
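        # evaluate() generates the text answer and decodes one mask per emitted
        # [SEG] token; resize_list and original_size_list let the masks be mapped
        # back to the input image's resolution (inferred from the call signature).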
        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]
        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        logging.info(f"text_output: {text_output}.")

        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue

            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask = pred_mask > 0

            # Save the raw binary mask, scaled so it is visible as a grayscale image.
            save_path = "{}/{}_mask_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            cv2.imwrite(save_path, pred_mask * 100)
            print("{} has been saved.".format(save_path))

            # Save a visualization: blend a 50% red overlay into the masked region.
            save_path = "{}/{}_masked_img_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            save_img = image_np.copy()
            save_img[pred_mask] = (
                image_np * 0.5
                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask]
            save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(save_path, save_img)
            print("{} has been saved.".format(save_path))


def change_torch_dtype_by_precision(precision):
    """Map a precision flag ("fp32" | "bf16" | "fp16") to the matching torch dtype."""
    torch_dtype = torch.float32
    if precision == "bf16":
        torch_dtype = torch.bfloat16
    elif precision == "fp16":
        torch_dtype = torch.half
    return torch_dtype


if __name__ == "__main__":
    main(sys.argv[1:])