alessandro trinca tornidor committed on
Commit 719ecfd · 1 Parent(s): 995f2bf

[refactor] fix missing parse_args() functions, fix inference()

Files changed (2)
  1. app.py +54 -36
  2. utils/constants.py +47 -0
app.py CHANGED
@@ -1,32 +1,20 @@
 import argparse
-import cv2
-import gradio as gr
 import json
 import logging
-import nh3
-import numpy as np
 import os
-import re
 import sys
-import torch
-import torch.nn.functional as F
-from fastapi import FastAPI, File, UploadFile, Request
-from fastapi.responses import HTMLResponse, RedirectResponse
+from typing import Callable
+
+import gradio as gr
+import nh3
+from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
-from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
-from typing import Callable
 
-from model.LISA import LISAForCausalLM
-from model.llava import conversation as conversation_lib
-from model.llava.mm_utils import tokenizer_image_token
-from model.segment_anything.utils.transforms import ResizeLongestSide
-from utils import session_logger
-from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
+from utils import constants, session_logger
 
-session_logger.change_logging(logging.DEBUG)
 
+session_logger.change_logging(logging.DEBUG)
 
 CUSTOM_GRADIO_PATH = "/"
 app = FastAPI(title="lisa_app", version="1.0")
@@ -48,6 +36,37 @@ def health() -> str:
     return json.dumps({"msg": "request failed"})
 
 
+@session_logger.set_uuid_logging
+def parse_args(args_to_parse):
+    parser = argparse.ArgumentParser(description="LISA chat")
+    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
+    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        type=str,
+        choices=["fp32", "bf16", "fp16"],
+        help="precision for inference",
+    )
+    parser.add_argument("--image_size", default=1024, type=int, help="image size")
+    parser.add_argument("--model_max_length", default=512, type=int)
+    parser.add_argument("--lora_r", default=8, type=int)
+    parser.add_argument(
+        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
+    )
+    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
+    parser.add_argument("--load_in_8bit", action="store_true", default=False)
+    parser.add_argument("--load_in_4bit", action="store_true", default=False)
+    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
+    parser.add_argument(
+        "--conv_type",
+        default="llava_v1",
+        type=str,
+        choices=["llava_v1", "llava_llama_2"],
+    )
+    return parser.parse_args(args_to_parse)
+
+
 @session_logger.set_uuid_logging
 def get_cleaned_input(input_str):
     logging.info(f"start cleaning of input_str: {input_str}.")
@@ -85,12 +104,11 @@ def get_inference_model_by_args(args_to_parse):
 
     @session_logger.set_uuid_logging
     def inference(input_str, input_image):
-        ## filter out special chars
-
-        input_str = get_cleaned_input(input_str)
-        logging.info(f"input_str type: {type(input_str)}, input_image type: {type(input_image)}.")
-        logging.info(f"input_str: {input_str}.")
-
+        logging.info(f"start cleaning input_str: {input_str}, type {type(input_str)}.")
+        output_str = get_cleaned_input(input_str)
+        logging.info(f"cleaned output_str: {output_str}, type {type(output_str)}.")
+        output_image = input_image
+        logging.info(f"output_image type: {type(output_image)}.")
         return output_image, output_str
 
     return inference
@@ -100,20 +118,20 @@
 def get_gradio_interface(fn_inference: Callable):
     return gr.Interface(
         fn_inference,
-        inputs=[
+        inputs=[
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
            gr.Image(type="filepath", label="Input Image")
-        ],
-        outputs=[
+        ],
+        outputs=[
            gr.Image(type="pil", label="Segmentation Output"),
-            gr.Textbox(lines=1, placeholder=None, label="Text Output"),
-        ],
-        title=title,
-        description=description,
-        article=article,
-        examples=examples,
-        allow_flagging="auto",
-    )
+            gr.Textbox(lines=1, placeholder=None, label="Text Output")
+        ],
+        title=constants.title,
+        description=constants.description,
+        article=constants.article,
+        examples=constants.examples,
+        allow_flagging="auto"
+    )
 
 
 args = parse_args(sys.argv[1:])
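
For context, a minimal sketch of how the pieces above presumably fit together at the bottom of app.py. Only `args = parse_args(sys.argv[1:])` is visible in this diff; the rest (building the inference closure, creating the interface, and mounting it on the FastAPI app at CUSTOM_GRADIO_PATH) is an assumption based on the function names in the hunks above and Gradio's standard FastAPI integration, gr.mount_gradio_app:

# Hypothetical wiring, not part of this commit: assumes app.py ends by
# mounting the Gradio interface onto the FastAPI app at CUSTOM_GRADIO_PATH.
args = parse_args(sys.argv[1:])                   # visible in the diff above
fn_inference = get_inference_model_by_args(args)  # returns the nested inference()
io = get_gradio_interface(fn_inference)
app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)

Mounting the interface, rather than calling io.launch(), keeps the /health endpoint and any other FastAPI routes available alongside the Gradio demo.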
utils/constants.py ADDED
@@ -0,0 +1,47 @@
+# Gradio
+examples = [
+    [
+        "Where can the driver see the car speed in this image? Please output segmentation mask.",
+        "./resources/imgs/example1.jpg",
+    ],
+    [
+        "Can you segment the food that tastes spicy and hot?",
+        "./resources/imgs/example2.jpg",
+    ],
+    [
+        "Assuming you are an autonomous driving robot, what part of the diagram would you manipulate to control the direction of travel? Please output segmentation mask and explain why.",
+        "./resources/imgs/example1.jpg",
+    ],
+    [
+        "What can make the woman stand higher? Please output segmentation mask and explain why.",
+        "./resources/imgs/example3.jpg",
+    ],
+]
+output_labels = ["Segmentation Output"]
+
+title = "LISA: Reasoning Segmentation via Large Language Model"
+
+description = """
+<font size=4>
+This is the online demo of LISA. \n
+If multiple users are using it at the same time, they will enter a queue, which may delay some time. \n
+**Note**: **Different prompts can lead to significantly varied results**. \n
+**Note**: Please try to **standardize** your input text prompts to **avoid ambiguity**, and also pay attention to whether the **punctuations** of the input are correct. \n
+**Note**: Current model is **LISA-13B-llama2-v0-explanatory**, and 4-bit quantization may impair text-generation quality. \n
+**Usage**: <br>
+&ensp;(1) To let LISA **segment something**, input prompt like: "Can you segment xxx in this image?", "What is xxx in this image? Please output segmentation mask."; <br>
+&ensp;(2) To let LISA **output an explanation**, input prompt like: "What is xxx in this image? Please output segmentation mask and explain why."; <br>
+&ensp;(3) To obtain **solely language output**, you can input like what you should do in current multi-modal LLM (e.g., LLaVA). <br>
+Hope you can enjoy our work!
+</font>
+"""
+
+article = """
+<p style='text-align: center'>
+<a href='https://arxiv.org/abs/2308.00692' target='_blank'>
+Preprint Paper
+</a>
+\n
+<p style='text-align: center'>
+<a href='https://github.com/dvlab-research/LISA' target='_blank'> Github Repo </a></p>
+"""