Niki Zhang committed
Update app.py
app.py CHANGED
@@ -37,8 +37,9 @@ import requests
 import spaces
 # Print the current version of LangChain
 print(f"Current LangChain version: {__version__}")
-
 print("testing testing")
+
+
 # import tts
 
 ###############################################################################
@@ -94,21 +95,6 @@ from huggingface_hub import hf_hub_download
 
 
 
-# import logging
-
-# logging.basicConfig(level=logging.DEBUG)
-
-# logger = logging.getLogger(__name__)
-
-# def my_function(input_text):
-#     logger.info(f'Received input: {input_text}')
-#     return "Output: " + input_text
-
-import sys
-
-# Set unbuffered output
-sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', buffering=1)
-sys.stderr = os.fdopen(sys.stderr.fileno(), 'w', buffering=1)
 
 # def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
 #     """
@@ -385,97 +371,97 @@ def infer(image_path):
 ############# this part is for text to image #############
 ###############################################################################
 
-# Use environment variables for flexibility
+# # Use environment variables for flexibility
 MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # Allow generating multiple images at once
 
-# Determine device and load model outside of function for efficiency
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-pipe = StableDiffusionXLPipeline.from_pretrained(
-    MODEL_ID,
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    use_safetensors=True,
-    add_watermarker=False,
-).to(device)
-pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+# # Determine device and load model outside of function for efficiency
+# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+# pipe = StableDiffusionXLPipeline.from_pretrained(
+#     MODEL_ID,
+#     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+#     use_safetensors=True,
+#     add_watermarker=False,
+# ).to(device)
+# pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 
-# Torch compile for potential speedup (experimental)
-if USE_TORCH_COMPILE:
-    pipe.compile()
+# # Torch compile for potential speedup (experimental)
+# if USE_TORCH_COMPILE:
+#     pipe.compile()
 
-# CPU offloading for larger RAM capacity (experimental)
-if ENABLE_CPU_OFFLOAD:
-    pipe.enable_model_cpu_offload()
+# # CPU offloading for larger RAM capacity (experimental)
+# if ENABLE_CPU_OFFLOAD:
+#     pipe.enable_model_cpu_offload()
 
 MAX_SEED = np.iinfo(np.int32).max
 
-def save_image(img):
-    unique_name = str(uuid.uuid4()) + ".png"
-    img.save(unique_name)
-    return unique_name
+# def save_image(img):
+#     unique_name = str(uuid.uuid4()) + ".png"
+#     img.save(unique_name)
+#     return unique_name
 
-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    return seed
+# def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+#     if randomize_seed:
+#         seed = random.randint(0, MAX_SEED)
+#     return seed
 
 # @spaces.GPU(duration=30, queue=False)
-def generate(
-    prompt: str,
-    negative_prompt: str = "",
-    use_negative_prompt: bool = False,
-    seed: int = 1,
-    width: int = 200,
-    height: int = 200,
-    guidance_scale: float = 3,
-    num_inference_steps: int = 30,
-    randomize_seed: bool = False,
-    num_images: int = 4, # Number of images to generate
-    use_resolution_binning: bool = True,
-    progress=gr.Progress(track_tqdm=True),
-):
-    seed = int(randomize_seed_fn(seed, randomize_seed))
-    generator = torch.Generator(device=device).manual_seed(seed)
-
-    # Improved options handling
-    options = {
-        "prompt": [prompt] * num_images,
-        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
-        "width": width,
-        "height": height,
-        "guidance_scale": guidance_scale,
-        "num_inference_steps": num_inference_steps,
-        "generator": generator,
-        "output_type": "pil",
-    }
-
-    # Use resolution binning for faster generation with less VRAM usage
-    # if use_resolution_binning:
-    #     options["use_resolution_binning"] = True
-
-    # Generate images potentially in batches
-    images = []
-    for i in range(0, num_images, BATCH_SIZE):
-        batch_options = options.copy()
-        batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
-        if "negative_prompt" in batch_options:
-            batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-        images.extend(pipe(**batch_options).images)
-
-    image_paths = [save_image(img) for img in images]
-    return image_paths, seed
-
-examples = [
-    "a cat eating a piece of cheese",
-    "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
-    "Ironman VS Hulk, ultrarealistic",
-    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
-    "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
-    "Kids going to school, Anime style"
-]
+# def generate(
+#     prompt: str,
+#     negative_prompt: str = "",
+#     use_negative_prompt: bool = False,
+#     seed: int = 1,
+#     width: int = 200,
+#     height: int = 200,
+#     guidance_scale: float = 3,
+#     num_inference_steps: int = 30,
+#     randomize_seed: bool = False,
+#     num_images: int = 4, # Number of images to generate
+#     use_resolution_binning: bool = True,
+#     progress=gr.Progress(track_tqdm=True),
+# ):
+#     seed = int(randomize_seed_fn(seed, randomize_seed))
+#     generator = torch.Generator(device=device).manual_seed(seed)
+
+#     # Improved options handling
+#     options = {
+#         "prompt": [prompt] * num_images,
+#         "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+#         "width": width,
+#         "height": height,
+#         "guidance_scale": guidance_scale,
+#         "num_inference_steps": num_inference_steps,
+#         "generator": generator,
+#         "output_type": "pil",
+#     }
+
+#     # Use resolution binning for faster generation with less VRAM usage
+#     # if use_resolution_binning:
+#     #     options["use_resolution_binning"] = True
+
+#     # Generate images potentially in batches
+#     images = []
+#     for i in range(0, num_images, BATCH_SIZE):
+#         batch_options = options.copy()
+#         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
+#         if "negative_prompt" in batch_options:
+#             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+#         images.extend(pipe(**batch_options).images)
+
+#     image_paths = [save_image(img) for img in images]
+#     return image_paths, seed
+
+# examples = [
+#     "a cat eating a piece of cheese",
+#     "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
+#     "Ironman VS Hulk, ultrarealistic",
+#     "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
+#     "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
+#     "Kids going to school, Anime style"
+# ]
 
 
 
@@ -485,6 +471,8 @@ examples = [
 ###############################################################################
 
 
+print("4")
+
 css = """
 #warning {background-color: #FFCCCB}
 .tools_button {
@@ -492,6 +480,18 @@ css = """
     border: none !important;
     box-shadow: none !important;
 }
+
+.info_btn {
+    background: white;
+    border: none !important;
+    box-shadow: none !important;
+}
+
+.function_button {
+    border: none !important;
+    box-shadow: none !important;
+}
+
 #tool_box {max-width: 50px}
 
 """
@@ -547,18 +547,48 @@ args = parse_augment()
 args.segmenter = "huge"
 args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
 args.clip_filter = True
-if args.segmenter_checkpoint is None:
-    _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
-else:
-    segmenter_checkpoint = args.segmenter_checkpoint
-
-shared_captioner = build_captioner(args.captioner, args.device, args)
-shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
-ocr_lang = ["ch_tra", "en"]
-shared_ocr_reader = easyocr.Reader(ocr_lang)
-tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
-shared_chatbot_tools = build_chatbot_tools(tools_dict)
 
+try:
+    print("Before preparing segmenter")
+    if args.segmenter_checkpoint is None:
+        _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
+    else:
+        segmenter_checkpoint = args.segmenter_checkpoint
+    print("After preparing segmenter")
+except Exception as e:
+    print(f"Error in preparing segmenter: {e}")
+
+try:
+    print("Before building captioner")
+    shared_captioner = build_captioner(args.captioner, args.device, args)
+    print("After building captioner")
+except Exception as e:
+    print(f"Error in building captioner: {e}")
+
+try:
+    print("Before loading SAM model")
+    shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
+    print("After loading SAM model")
+except Exception as e:
+    print(f"Error in loading SAM model: {e}")
+
+try:
+    print("Before initializing OCR reader")
+    ocr_lang = ["ch_tra", "en"]
+    shared_ocr_reader = easyocr.Reader(ocr_lang,model_storage_directory=".EasyOCR/model")
+    print("After initializing OCR reader")
+except Exception as e:
+    print(f"Error in initializing OCR reader: {e}")
+
+try:
+    print("Before building chatbot tools")
+    tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
+    shared_chatbot_tools = build_chatbot_tools(tools_dict)
+    print("After building chatbot tools")
+except Exception as e:
+    print(f"Error in building chatbot tools: {e}")
+
+print(5)
 
 # class ImageSketcher(gr.Image):
 # """
@@ -595,15 +625,15 @@ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_mod
 
 def validate_api_key(api_key):
     api_key = str(api_key).strip()
-    print(api_key
+    print(api_key)
     try:
         test_llm = ChatOpenAI(model_name="gpt-4o", temperature=0, openai_api_key=api_key)
-        print("test_llm"
+        print("test_llm")
         response = test_llm([HumanMessage(content='Hello')])
-        print(response
+        print(response)
         return True
     except Exception as e:
-        print(f"API key validation failed: {e}"
+        print(f"API key validation failed: {e}")
         return False
 
 
@@ -612,23 +642,23 @@ def init_openai_api_key(api_key=""):
     text_refiner = None
     visual_chatgpt = None
     if api_key and len(api_key) > 30:
-        print(api_key
+        print(api_key)
         if validate_api_key(api_key):
             try:
                 # text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
                 # assert len(text_refiner.llm('hi')) > 0 # test
                 text_refiner = None
-                print("text refiner"
+                print("text refiner")
                 visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key=api_key)
             except Exception as e:
-                print(f"Error initializing TextRefiner or ConversationBot: {e}"
+                print(f"Error initializing TextRefiner or ConversationBot: {e}")
                 text_refiner = None
                 visual_chatgpt = None
         else:
-            print("Invalid API key."
+            print("Invalid API key.")
     else:
-        print("API key is too short."
-    print(text_refiner
+        print("API key is too short.")
+    print(text_refiner)
     openai_available = text_refiner is not None
     if visual_chatgpt:
 
@@ -704,7 +734,8 @@ async def chat_input_callback(*args):
 
 
 
-def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None,language="English"):
+def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None,language="English",narritive=None):
+    print("narritive", narritive)
     if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
         image_input = image_input['background']
 
@@ -748,13 +779,29 @@ def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None
     name, artist, year, material= parsed_data["name"],parsed_data["artist"],parsed_data["year"], parsed_data["style"]
     # artwork_info = f"<div>Painting: {name}<br>Artist name: {artist}<br>Year: {year}<br>Material: {material}</div>"
 
-
-
-
-
-
-
-    ]
+    if narritive==None or narritive=="Third":
+        state = [
+            (
+                None,
+                f"🤖 Hi, I am EyeSee. Let's explore this painting '{name}' together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant information."
+            )
+        ]
+    elif narritive=="Artist":
+        state = [
+            (
+                None,
+                f"🧑🎨 Hello, I am the {artist}. Welcome to explore my painting, '{name}'. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant insights and thoughts behind my creation."
+            )
+        ]
+    elif narritive=="Item":
+        state = [
+            (
+                None,
+                f"🎨 Hello, I am the Item. Let's explore this painting '{name}' together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant insights and thoughts behind my creation."
+            )
+        ]
+
+
 
     return [state, state, image_input, click_state, image_input, image_input, image_input, image_input, image_embedding, \
             original_size, input_size] + [f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Style: {material}"]*4 + [paragraph,artist]
@@ -795,9 +842,9 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
 
     enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
     out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
-
+    state = state + [("You've selected image point at {}, ".format(prompt["input_point"]), None)]
 
-    state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
+    # state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
     update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
     text = out['generated_captions']['raw_caption']
     input_mask = np.array(out['mask'].convert('P'))
@@ -823,11 +870,20 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     yield state, state, click_state, image_input_nobackground, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground
 
 
+query_focus = {
+    "D": "Provide a description of the item.",
+    "DA": "Provide a description and analysis of the item.",
+    "DAI": "Provide a description, analysis, and interpretation of the item.",
+    "DDA": "Evaluate the item."
+}
+
 
 async def submit_caption(state,length, sentiment, factuality, language,
                          out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
                          autoplay,paragraph,focus_type,openai_api_key,new_crop_save_path):
-
+
+
+    state = state + [(query_focus[focus_type], None)]
 
     click_index = click_index_state
 
@@ -1110,8 +1166,7 @@ def clear_chat_memory(visual_chatgpt, keep_global=False):
         visual_chatgpt.point_prompt = ""
         if keep_global:
             # visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
-            visual_chatgpt.agent.memory.save_context({"input": visual_chatgpt.global_prompt}, {"output":
-            print("test")
+            visual_chatgpt.agent.memory.save_context({"input": visual_chatgpt.global_prompt}, {"output": ""})
         else:
             visual_chatgpt.current_image = None
             visual_chatgpt.global_prompt = ""
@@ -1345,10 +1400,10 @@ def print_like_dislike(x: gr.LikeData,like_res,dislike_res,state):
 
 def toggle_icons_and_update_prompt(point_prompt):
     new_prompt = "Negative" if point_prompt == "Positive" else "Positive"
-    new_add_icon = "assets/icons/plus-square-blue.png" if
-    new_minus_icon = "assets/icons/minus-square.png" if
-    print(point_prompt)
-    print(new_prompt)
+    new_add_icon = "assets/icons/plus-square-blue.png" if new_prompt == "Positive" else "assets/icons/plus-square.png"
+    new_minus_icon = "assets/icons/minus-square.png" if new_prompt == "Positive" else "assets/icons/minus-square-blue.png"
+    print(point_prompt,flush=True)
+    print(new_prompt,flush=True)
 
     return new_prompt, gr.update(icon=new_add_icon), gr.update(icon=new_minus_icon)
 
@@ -1358,6 +1413,7 @@ minus_icon_path="assets/icons/minus-square.png"
 print("this is a print test")
 
 def create_ui():
+    print(6)
     title = """<p><h1 align="center">EyeSee Anything in Art</h1></p>
     """
     description = """<p>Gradio demo for EyeSee Anything in Art, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. """
@@ -1453,38 +1509,38 @@ def create_ui():
         with gr.Tab("Base(GPT Power)") as base_tab:
             image_input_base = gr.Image(type="pil", interactive=True, elem_id="image_upload")
             with gr.Row():
-                name_label_base = gr.Button(value="Name: ")
-                artist_label_base = gr.Button(value="Artist: ")
-                year_label_base = gr.Button(value="Year: ")
-                material_label_base = gr.Button(value="Style: ")
+                name_label_base = gr.Button(value="Name: ",elem_classes="info_btn")
+                artist_label_base = gr.Button(value="Artist: ",elem_classes="info_btn")
+                year_label_base = gr.Button(value="Year: ",elem_classes="info_btn")
+                material_label_base = gr.Button(value="Style: ",elem_classes="info_btn")
 
         with gr.Tab("Base2") as base_tab2:
             image_input_base_2 = gr.Image(type="pil", interactive=True, elem_id="image_upload")
             with gr.Row():
-                name_label_base2 = gr.Button(value="Name: ")
-                artist_label_base2 = gr.Button(value="Artist: ")
-                year_label_base2 = gr.Button(value="Year: ")
-                material_label_base2 = gr.Button(value="Style: ")
+                name_label_base2 = gr.Button(value="Name: ",elem_classes="info_btn")
+                artist_label_base2 = gr.Button(value="Artist: ",elem_classes="info_btn")
+                year_label_base2 = gr.Button(value="Year: ",elem_classes="info_btn")
+                material_label_base2 = gr.Button(value="Style: ",elem_classes="info_btn")
 
         with gr.Tab("Click") as click_tab:
             with gr.Row():
-                with gr.Column(scale=10,min_width=
+                with gr.Column(scale=10,min_width=600):
                     image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
                     example_image = gr.Image(type="pil", interactive=False, visible=False)
                     with gr.Row():
-                        name_label = gr.Button(value="Name: ")
-                        artist_label = gr.Button(value="Artist: ")
-                        year_label = gr.Button(value="Year: ")
-                        material_label = gr.Button(value="Style: ")
+                        name_label = gr.Button(value="Name: ",elem_classes="info_btn")
+                        artist_label = gr.Button(value="Artist: ",elem_classes="info_btn")
+                        year_label = gr.Button(value="Year: ",elem_classes="info_btn")
+                        material_label = gr.Button(value="Style: ",elem_classes="info_btn")
 
 
                 # example_image_click = gr.Image(type="pil", interactive=False, visible=False)
                 # the tool column
-                with gr.Column(scale=1,elem_id="tool_box",min_width=
+                with gr.Column(scale=1,elem_id="tool_box",min_width=80):
                     add_button = gr.Button(value="", interactive=True,elem_classes="tools_button",icon=add_icon_path)
                     minus_button = gr.Button(value="", interactive=True,elem_classes="tools_button",icon=minus_icon_path)
                     clear_button_click = gr.Button(value="Reset", interactive=True,elem_classes="tools_button")
-                    clear_button_image = gr.Button(value="Change
+                    clear_button_image = gr.Button(value="Change", interactive=True,elem_classes="tools_button")
                    focus_d = gr.Button(value="D",interactive=True,elem_classes="function_button")
                    focus_da = gr.Button(value="DA",interactive=True,elem_classes="function_button")
                    focus_dai = gr.Button(value="DAI",interactive=True,elem_classes="function_button")
@@ -2017,7 +2073,7 @@ def create_ui():
 
 
 
-        image_input_base.upload(upload_callback, [image_input_base, state, visual_chatgpt,openai_api_key],
+        image_input_base.upload(upload_callback, [image_input_base, state, visual_chatgpt,openai_api_key,language,naritive],
                                 [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2,
                                 image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \
                                 name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \
@@ -2049,11 +2105,12 @@ def create_ui():
         # image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
         chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
                           [chatbot, state, aux_state,output_audio])
-        chat_input.submit(lambda: "", None, chat_input)
+        # chat_input.submit(lambda: "", None, chat_input)
+        chat_input.submit(lambda: {"text": ""}, None, chat_input)
         # submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
         # [chatbot, state, aux_state,output_audio])
         # submit_button_text.click(lambda: "", None, chat_input)
-        example_image.change(upload_callback, [example_image, state, visual_chatgpt, openai_api_key],
+        example_image.change(upload_callback, [example_image, state, visual_chatgpt, openai_api_key,language,naritive],
                              [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2,
                              image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \
                             name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \
@@ -2220,7 +2277,7 @@ def create_ui():
 
 
 if __name__ == '__main__':
-
+    print("main")
     iface = create_ui()
     iface.queue(api_open=False, max_size=10)
     # iface.queue(concurrency_count=5, api_open=False, max_size=10)