Commit 719ecfd
alessandro trinca tornidor committed
Parent(s): 995f2bf

[refactor] fix missing parse_args() functions, fix inference()

Files changed:
- app.py (+54, -36)
- utils/constants.py (+47, -0)

app.py CHANGED
@@ -1,32 +1,20 @@
 import argparse
-import cv2
-import gradio as gr
 import json
 import logging
-import nh3
-import numpy as np
 import os
-import re
 import sys
-import
-
-
-
+from typing import Callable
+
+import gradio as gr
+import nh3
+from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
-from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
-from typing import Callable
 
-from
-from model.llava import conversation as conversation_lib
-from model.llava.mm_utils import tokenizer_image_token
-from model.segment_anything.utils.transforms import ResizeLongestSide
-from utils import session_logger
-from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
+from utils import constants, session_logger
 
-session_logger.change_logging(logging.DEBUG)
 
+session_logger.change_logging(logging.DEBUG)
 
 CUSTOM_GRADIO_PATH = "/"
 app = FastAPI(title="lisa_app", version="1.0")

@@ -48,6 +36,37 @@ def health() -> str:
     return json.dumps({"msg": "request failed"})
 
 
+@session_logger.set_uuid_logging
+def parse_args(args_to_parse):
+    parser = argparse.ArgumentParser(description="LISA chat")
+    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
+    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        type=str,
+        choices=["fp32", "bf16", "fp16"],
+        help="precision for inference",
+    )
+    parser.add_argument("--image_size", default=1024, type=int, help="image size")
+    parser.add_argument("--model_max_length", default=512, type=int)
+    parser.add_argument("--lora_r", default=8, type=int)
+    parser.add_argument(
+        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
+    )
+    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
+    parser.add_argument("--load_in_8bit", action="store_true", default=False)
+    parser.add_argument("--load_in_4bit", action="store_true", default=False)
+    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
+    parser.add_argument(
+        "--conv_type",
+        default="llava_v1",
+        type=str,
+        choices=["llava_v1", "llava_llama_2"],
+    )
+    return parser.parse_args(args_to_parse)
+
+
 @session_logger.set_uuid_logging
 def get_cleaned_input(input_str):
     logging.info(f"start cleaning of input_str: {input_str}.")
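Note: the restored parse_args() mirrors the CLI of the upstream LISA chat script. When reading the rest of app.py, remember that argparse normalizes dashes in option names to underscores, so "--vision-tower" and "--local-rank" come back as args.vision_tower and args.local_rank. A minimal standalone sketch of that behavior (not part of the commit), using two of the flags declared in the hunk above:

    import argparse

    # Cut-down parser with two of the flags declared in parse_args() above.
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--precision", default="fp16", choices=["fp32", "bf16", "fp16"])
    parser.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)

    args = parser.parse_args(["--precision", "bf16"])
    print(args.precision)     # "bf16" (overridden on the command line)
    print(args.vision_tower)  # dashes become underscores in the namespace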
@@ -85,12 +104,11 @@ def get_inference_model_by_args(args_to_parse):
 
     @session_logger.set_uuid_logging
     def inference(input_str, input_image):
-
-
-
-
-        logging.info(f"
-
+        logging.info(f"start cleaning input_str: {input_str}, type {type(input_str)}.")
+        output_str = get_cleaned_input(input_str)
+        logging.info(f"cleaned output_str: {output_str}, type {type(output_str)}.")
+        output_image = input_image
+        logging.info(f"output_image type: {type(output_image)}.")
         return output_image, output_str
 
     return inference

@@ -100,20 +118,20 @@ def get_inference_model_by_args(args_to_parse):
 def get_gradio_interface(fn_inference: Callable):
     return gr.Interface(
         fn_inference,
-
+        inputs=[
         gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
         gr.Image(type="filepath", label="Input Image")
-
-
+        ],
+        outputs=[
         gr.Image(type="pil", label="Segmentation Output"),
-
-
-        title=title,
-        description=description,
-        article=article,
-        examples=examples,
-        allow_flagging="auto"
-    )
+        gr.Textbox(lines=1, placeholder=None, label="Text Output")
+        ],
+        title=constants.title,
+        description=constants.description,
+        article=constants.article,
+        examples=constants.examples,
+        allow_flagging="auto"
+    )
 
 
 args = parse_args(sys.argv[1:])
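Note: the hunks end at args = parse_args(sys.argv[1:]); how the Gradio interface reaches the FastAPI instance is not visible in this diff. A plausible wiring at module scope, assuming Gradio's standard gr.mount_gradio_app helper and the names defined above (a sketch, not the committed code):

    # Sketch only: assumes gr.mount_gradio_app plus the app.py names above.
    args = parse_args(sys.argv[1:])
    fn_inference = get_inference_model_by_args(args)
    io = get_gradio_interface(fn_inference)
    app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)  # serve the UI at "/"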
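Note: get_cleaned_input() appears only as context here, its body largely collapsed by the viewer, but the retained import nh3 suggests the prompt is sanitized before inference. A hypothetical sketch of such a step using nh3's clean() entry point (the empty tag whitelist is an assumption, not the committed code):

    import nh3

    def get_cleaned_input_sketch(input_str: str) -> str:
        # Hypothetical: strip all HTML markup from the user prompt.
        return nh3.clean(input_str, tags=set())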
utils/constants.py ADDED
@@ -0,0 +1,47 @@
+# Gradio
+examples = [
+    [
+        "Where can the driver see the car speed in this image? Please output segmentation mask.",
+        "./resources/imgs/example1.jpg",
+    ],
+    [
+        "Can you segment the food that tastes spicy and hot?",
+        "./resources/imgs/example2.jpg",
+    ],
+    [
+        "Assuming you are an autonomous driving robot, what part of the diagram would you manipulate to control the direction of travel? Please output segmentation mask and explain why.",
+        "./resources/imgs/example1.jpg",
+    ],
+    [
+        "What can make the woman stand higher? Please output segmentation mask and explain why.",
+        "./resources/imgs/example3.jpg",
+    ],
+]
+output_labels = ["Segmentation Output"]
+
+title = "LISA: Reasoning Segmentation via Large Language Model"
+
+description = """
+<font size=4>
+This is the online demo of LISA. \n
+If multiple users are using it at the same time, they will enter a queue, which may delay some time. \n
+**Note**: **Different prompts can lead to significantly varied results**. \n
+**Note**: Please try to **standardize** your input text prompts to **avoid ambiguity**, and also pay attention to whether the **punctuations** of the input are correct. \n
+**Note**: Current model is **LISA-13B-llama2-v0-explanatory**, and 4-bit quantization may impair text-generation quality. \n
+**Usage**: <br>
+ (1) To let LISA **segment something**, input prompt like: "Can you segment xxx in this image?", "What is xxx in this image? Please output segmentation mask."; <br>
+ (2) To let LISA **output an explanation**, input prompt like: "What is xxx in this image? Please output segmentation mask and explain why."; <br>
+ (3) To obtain **solely language output**, you can input like what you should do in current multi-modal LLM (e.g., LLaVA). <br>
+Hope you can enjoy our work!
+</font>
+"""
+
+article = """
+<p style='text-align: center'>
+<a href='https://arxiv.org/abs/2308.00692' target='_blank'>
+Preprint Paper
+</a>
+\n
+<p style='text-align: center'>
+<a href='https://github.com/dvlab-research/LISA' target='_blank'> Github Repo </a></p>
+"""
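Note: the shapes in utils/constants.py mirror the gr.Interface call in app.py: each examples row is one [text_instruction, image_path] pair matched positionally to the two inputs components, and description/article are rendered (Markdown plus inline HTML) above and below the app. A minimal self-contained sketch with a stub in place of the LISA model:

    import gradio as gr

    from utils import constants

    def echo_inference(input_str, input_image):
        # Stub standing in for the real model: echo the inputs back.
        return input_image, input_str

    demo = gr.Interface(
        echo_inference,
        inputs=[
            gr.Textbox(lines=1, label="Text Instruction"),
            gr.Image(type="filepath", label="Input Image"),
        ],
        outputs=[
            gr.Image(type="pil", label="Segmentation Output"),
            gr.Textbox(lines=1, label="Text Output"),
        ],
        title=constants.title,
        description=constants.description,  # rendered above the components
        article=constants.article,          # rendered below the components
        examples=constants.examples,        # rows match inputs positionally
    )

    if __name__ == "__main__":
        demo.launch()

As committed, allow_flagging="auto" logs every submission automatically; more recent Gradio releases rename this parameter to flagging_mode.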