Niki Zhang committed on
Commit 8708def · verified · 1 Parent(s): c9593df

Update app.py

Files changed (1)
  1. app.py +207 -150
app.py CHANGED
@@ -37,8 +37,9 @@ import requests
37
  import spaces
38
  # Print the current version of LangChain
39
  print(f"Current LangChain version: {__version__}")
40
-
41
  print("testing testing")
 
 
42
  # import tts
43
 
44
  ###############################################################################
@@ -94,21 +95,6 @@ from huggingface_hub import hf_hub_download
94
 
95
 
96
 
97
- # import logging
98
-
99
- # logging.basicConfig(level=logging.DEBUG)
100
-
101
- # logger = logging.getLogger(__name__)
102
-
103
- # def my_function(input_text):
104
- # logger.info(f'Received input: {input_text}')
105
- # return "Output: " + input_text
106
-
107
- import sys
108
-
109
- # Set unbuffered output
110
- sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', buffering=1)
111
- sys.stderr = os.fdopen(sys.stderr.fileno(), 'w', buffering=1)
112
 
113
  # def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
114
  # """
@@ -385,97 +371,97 @@ def infer(image_path):
385
  ############# this part is for text to image #############
386
  ###############################################################################
387
 
388
- # Use environment variables for flexibility
389
  MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
390
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
391
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
392
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
393
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # Allow generating multiple images at once
394
 
395
- # Determine device and load model outside of function for efficiency
396
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
397
- pipe = StableDiffusionXLPipeline.from_pretrained(
398
- MODEL_ID,
399
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
400
- use_safetensors=True,
401
- add_watermarker=False,
402
- ).to(device)
403
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
404
 
405
- # Torch compile for potential speedup (experimental)
406
- if USE_TORCH_COMPILE:
407
- pipe.compile()
408
 
409
- # CPU offloading for larger RAM capacity (experimental)
410
- if ENABLE_CPU_OFFLOAD:
411
- pipe.enable_model_cpu_offload()
412
 
413
  MAX_SEED = np.iinfo(np.int32).max
414
 
415
- def save_image(img):
416
- unique_name = str(uuid.uuid4()) + ".png"
417
- img.save(unique_name)
418
- return unique_name
419
 
420
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
421
- if randomize_seed:
422
- seed = random.randint(0, MAX_SEED)
423
- return seed
424
 
425
  # @spaces.GPU(duration=30, queue=False)
426
- def generate(
427
- prompt: str,
428
- negative_prompt: str = "",
429
- use_negative_prompt: bool = False,
430
- seed: int = 1,
431
- width: int = 200,
432
- height: int = 200,
433
- guidance_scale: float = 3,
434
- num_inference_steps: int = 30,
435
- randomize_seed: bool = False,
436
- num_images: int = 4, # Number of images to generate
437
- use_resolution_binning: bool = True,
438
- progress=gr.Progress(track_tqdm=True),
439
- ):
440
- seed = int(randomize_seed_fn(seed, randomize_seed))
441
- generator = torch.Generator(device=device).manual_seed(seed)
442
-
443
- # Improved options handling
444
- options = {
445
- "prompt": [prompt] * num_images,
446
- "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
447
- "width": width,
448
- "height": height,
449
- "guidance_scale": guidance_scale,
450
- "num_inference_steps": num_inference_steps,
451
- "generator": generator,
452
- "output_type": "pil",
453
- }
454
-
455
- # Use resolution binning for faster generation with less VRAM usage
456
- # if use_resolution_binning:
457
- # options["use_resolution_binning"] = True
458
-
459
- # Generate images potentially in batches
460
- images = []
461
- for i in range(0, num_images, BATCH_SIZE):
462
- batch_options = options.copy()
463
- batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
464
- if "negative_prompt" in batch_options:
465
- batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
466
- images.extend(pipe(**batch_options).images)
467
-
468
- image_paths = [save_image(img) for img in images]
469
- return image_paths, seed
470
-
471
- examples = [
472
- "a cat eating a piece of cheese",
473
- "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
474
- "Ironman VS Hulk, ultrarealistic",
475
- "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
476
- "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
477
- "Kids going to school, Anime style"
478
- ]
479
 
480
 
481
 
@@ -485,6 +471,8 @@ examples = [
485
  ###############################################################################
486
 
487
 
 
 
488
  css = """
489
  #warning {background-color: #FFCCCB}
490
  .tools_button {
@@ -492,6 +480,18 @@ css = """
492
  border: none !important;
493
  box-shadow: none !important;
494
  }
495
  #tool_box {max-width: 50px}
496
 
497
  """
@@ -547,18 +547,48 @@ args = parse_augment()
547
  args.segmenter = "huge"
548
  args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
549
  args.clip_filter = True
550
- if args.segmenter_checkpoint is None:
551
- _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
552
- else:
553
- segmenter_checkpoint = args.segmenter_checkpoint
554
-
555
- shared_captioner = build_captioner(args.captioner, args.device, args)
556
- shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
557
- ocr_lang = ["ch_tra", "en"]
558
- shared_ocr_reader = easyocr.Reader(ocr_lang)
559
- tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
560
- shared_chatbot_tools = build_chatbot_tools(tools_dict)
561
 
562
 
563
  # class ImageSketcher(gr.Image):
564
  # """
@@ -595,15 +625,15 @@ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_mod
595
 
596
  def validate_api_key(api_key):
597
  api_key = str(api_key).strip()
598
- print(api_key, flush=True)
599
  try:
600
  test_llm = ChatOpenAI(model_name="gpt-4o", temperature=0, openai_api_key=api_key)
601
- print("test_llm", flush=True)
602
  response = test_llm([HumanMessage(content='Hello')])
603
- print(response, flush=True)
604
  return True
605
  except Exception as e:
606
- print(f"API key validation failed: {e}", flush=True)
607
  return False
608
 
609
 
@@ -612,23 +642,23 @@ def init_openai_api_key(api_key=""):
612
  text_refiner = None
613
  visual_chatgpt = None
614
  if api_key and len(api_key) > 30:
615
- print(api_key, flush=True)
616
  if validate_api_key(api_key):
617
  try:
618
  # text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
619
  # assert len(text_refiner.llm('hi')) > 0 # test
620
  text_refiner = None
621
- print("text refiner", flush=True)
622
  visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key=api_key)
623
  except Exception as e:
624
- print(f"Error initializing TextRefiner or ConversationBot: {e}", flush=True)
625
  text_refiner = None
626
  visual_chatgpt = None
627
  else:
628
- print("Invalid API key.", flush=True)
629
  else:
630
- print("API key is too short.", flush=True)
631
- print(text_refiner, flush=True)
632
  openai_available = text_refiner is not None
633
  if visual_chatgpt:
634
 
@@ -704,7 +734,8 @@ async def chat_input_callback(*args):
704
 
705
 
706
 
707
- def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None,language="English"):
 
708
  if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
709
  image_input = image_input['background']
710
 
@@ -748,13 +779,29 @@ def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None
748
  name, artist, year, material= parsed_data["name"],parsed_data["artist"],parsed_data["year"], parsed_data["style"]
749
  # artwork_info = f"<div>Painting: {name}<br>Artist name: {artist}<br>Year: {year}<br>Material: {material}</div>"
750
 
751
-
752
- state = [
753
- (
754
- None,
755
- f"🤖 Hi, I am EyeSee. Let's explore this painting {name} together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant information."
756
- )
757
- ]
758
 
759
  return [state, state, image_input, click_state, image_input, image_input, image_input, image_input, image_embedding, \
760
  original_size, input_size] + [f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Style: {material}"]*4 + [paragraph,artist]
@@ -795,9 +842,9 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
795
 
796
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
797
  out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
798
- # state = state + [("You've selected image point at {}, ".format(prompt["input_point"]), None)]
799
 
800
- state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
801
  update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
802
  text = out['generated_captions']['raw_caption']
803
  input_mask = np.array(out['mask'].convert('P'))
@@ -823,11 +870,20 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
823
  yield state, state, click_state, image_input_nobackground, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground
824
 
825
 
826
 
827
  async def submit_caption(state,length, sentiment, factuality, language,
828
  out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
829
  autoplay,paragraph,focus_type,openai_api_key,new_crop_save_path):
830
- print("state",state)
 
 
831
 
832
  click_index = click_index_state
833
 
@@ -1110,8 +1166,7 @@ def clear_chat_memory(visual_chatgpt, keep_global=False):
1110
  visual_chatgpt.point_prompt = ""
1111
  if keep_global:
1112
  # visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
1113
- visual_chatgpt.agent.memory.save_context({"input": visual_chatgpt.global_prompt}, {"output": None})
1114
- print("test")
1115
  else:
1116
  visual_chatgpt.current_image = None
1117
  visual_chatgpt.global_prompt = ""
@@ -1345,10 +1400,10 @@ def print_like_dislike(x: gr.LikeData,like_res,dislike_res,state):
1345
 
1346
  def toggle_icons_and_update_prompt(point_prompt):
1347
  new_prompt = "Negative" if point_prompt == "Positive" else "Positive"
1348
- new_add_icon = "assets/icons/plus-square-blue.png" if point_prompt == "Positive" else "assets/icons/plus-square.png"
1349
- new_minus_icon = "assets/icons/minus-square.png" if point_prompt == "Positive" else "assets/icons/minus-square-blue.png"
1350
- print(point_prompt)
1351
- print(new_prompt)
1352
 
1353
  return new_prompt, gr.update(icon=new_add_icon), gr.update(icon=new_minus_icon)
1354
 
@@ -1358,6 +1413,7 @@ minus_icon_path="assets/icons/minus-square.png"
1358
  print("this is a print test")
1359
 
1360
  def create_ui():
 
1361
  title = """<p><h1 align="center">EyeSee Anything in Art</h1></p>
1362
  """
1363
  description = """<p>Gradio demo for EyeSee Anything in Art, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. """
@@ -1453,38 +1509,38 @@ def create_ui():
1453
  with gr.Tab("Base(GPT Power)") as base_tab:
1454
  image_input_base = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1455
  with gr.Row():
1456
- name_label_base = gr.Button(value="Name: ")
1457
- artist_label_base = gr.Button(value="Artist: ")
1458
- year_label_base = gr.Button(value="Year: ")
1459
- material_label_base = gr.Button(value="Style: ")
1460
 
1461
  with gr.Tab("Base2") as base_tab2:
1462
  image_input_base_2 = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1463
  with gr.Row():
1464
- name_label_base2 = gr.Button(value="Name: ")
1465
- artist_label_base2 = gr.Button(value="Artist: ")
1466
- year_label_base2 = gr.Button(value="Year: ")
1467
- material_label_base2 = gr.Button(value="Style: ")
1468
 
1469
  with gr.Tab("Click") as click_tab:
1470
  with gr.Row():
1471
- with gr.Column(scale=10,min_width=450):
1472
  image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1473
  example_image = gr.Image(type="pil", interactive=False, visible=False)
1474
  with gr.Row():
1475
- name_label = gr.Button(value="Name: ")
1476
- artist_label = gr.Button(value="Artist: ")
1477
- year_label = gr.Button(value="Year: ")
1478
- material_label = gr.Button(value="Style: ")
1479
 
1480
 
1481
  # example_image_click = gr.Image(type="pil", interactive=False, visible=False)
1482
  # the tool column
1483
- with gr.Column(scale=1,elem_id="tool_box",min_width=100):
1484
  add_button = gr.Button(value="", interactive=True,elem_classes="tools_button",icon=add_icon_path)
1485
  minus_button = gr.Button(value="", interactive=True,elem_classes="tools_button",icon=minus_icon_path)
1486
  clear_button_click = gr.Button(value="Reset", interactive=True,elem_classes="tools_button")
1487
- clear_button_image = gr.Button(value="Change Image", interactive=True,elem_classes="tools_button")
1488
  focus_d = gr.Button(value="D",interactive=True,elem_classes="function_button")
1489
  focus_da = gr.Button(value="DA",interactive=True,elem_classes="function_button")
1490
  focus_dai = gr.Button(value="DAI",interactive=True,elem_classes="function_button")
@@ -2017,7 +2073,7 @@ def create_ui():
2017
 
2018
 
2019
 
2020
- image_input_base.upload(upload_callback, [image_input_base, state, visual_chatgpt,openai_api_key],
2021
  [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2,
2022
  image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \
2023
  name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \
@@ -2049,11 +2105,12 @@ def create_ui():
2049
  # image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
2050
  chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
2051
  [chatbot, state, aux_state,output_audio])
2052
- chat_input.submit(lambda: "", None, chat_input)
 
2053
  # submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
2054
  # [chatbot, state, aux_state,output_audio])
2055
  # submit_button_text.click(lambda: "", None, chat_input)
2056
- example_image.change(upload_callback, [example_image, state, visual_chatgpt, openai_api_key],
2057
  [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2,
2058
  image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \
2059
  name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \
@@ -2220,7 +2277,7 @@ def create_ui():
2220
 
2221
 
2222
  if __name__ == '__main__':
2223
- # logger.info("Starting Gradio app")
2224
  iface = create_ui()
2225
  iface.queue(api_open=False, max_size=10)
2226
  # iface.queue(concurrency_count=5, api_open=False, max_size=10)
 
37
  import spaces
38
  # Print the current version of LangChain
39
  print(f"Current LangChain version: {__version__}")
 
40
  print("testing testing")
41
+
42
+
43
  # import tts
44
 
45
  ###############################################################################
 
95
 
96
 
97
 
98
 
99
  # def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
100
  # """
 
371
  ############# this part is for text to image #############
372
  ###############################################################################
373
 
374
+ # # Use environment variables for flexibility
375
  MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
376
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
377
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
378
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
379
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # Allow generating multiple images at once
380
 
381
+ # # Determine device and load model outside of function for efficiency
382
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
383
+ # pipe = StableDiffusionXLPipeline.from_pretrained(
384
+ # MODEL_ID,
385
+ # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
386
+ # use_safetensors=True,
387
+ # add_watermarker=False,
388
+ # ).to(device)
389
+ # pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
390
 
391
+ # # Torch compile for potential speedup (experimental)
392
+ # if USE_TORCH_COMPILE:
393
+ # pipe.compile()
394
 
395
+ # # CPU offloading for larger RAM capacity (experimental)
396
+ # if ENABLE_CPU_OFFLOAD:
397
+ # pipe.enable_model_cpu_offload()
398
 
399
  MAX_SEED = np.iinfo(np.int32).max
400
 
401
+ # def save_image(img):
402
+ # unique_name = str(uuid.uuid4()) + ".png"
403
+ # img.save(unique_name)
404
+ # return unique_name
405
 
406
+ # def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
407
+ # if randomize_seed:
408
+ # seed = random.randint(0, MAX_SEED)
409
+ # return seed
410
 
411
  # @spaces.GPU(duration=30, queue=False)
412
+ # def generate(
413
+ # prompt: str,
414
+ # negative_prompt: str = "",
415
+ # use_negative_prompt: bool = False,
416
+ # seed: int = 1,
417
+ # width: int = 200,
418
+ # height: int = 200,
419
+ # guidance_scale: float = 3,
420
+ # num_inference_steps: int = 30,
421
+ # randomize_seed: bool = False,
422
+ # num_images: int = 4, # Number of images to generate
423
+ # use_resolution_binning: bool = True,
424
+ # progress=gr.Progress(track_tqdm=True),
425
+ # ):
426
+ # seed = int(randomize_seed_fn(seed, randomize_seed))
427
+ # generator = torch.Generator(device=device).manual_seed(seed)
428
+
429
+ # # Improved options handling
430
+ # options = {
431
+ # "prompt": [prompt] * num_images,
432
+ # "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
433
+ # "width": width,
434
+ # "height": height,
435
+ # "guidance_scale": guidance_scale,
436
+ # "num_inference_steps": num_inference_steps,
437
+ # "generator": generator,
438
+ # "output_type": "pil",
439
+ # }
440
+
441
+ # # Use resolution binning for faster generation with less VRAM usage
442
+ # # if use_resolution_binning:
443
+ # # options["use_resolution_binning"] = True
444
+
445
+ # # Generate images potentially in batches
446
+ # images = []
447
+ # for i in range(0, num_images, BATCH_SIZE):
448
+ # batch_options = options.copy()
449
+ # batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
450
+ # if "negative_prompt" in batch_options:
451
+ # batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
452
+ # images.extend(pipe(**batch_options).images)
453
+
454
+ # image_paths = [save_image(img) for img in images]
455
+ # return image_paths, seed
456
+
457
+ # examples = [
458
+ # "a cat eating a piece of cheese",
459
+ # "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
460
+ # "Ironman VS Hulk, ultrarealistic",
461
+ # "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
462
+ # "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
463
+ # "Kids going to school, Anime style"
464
+ # ]
465
 
466
 
467
 
 
471
  ###############################################################################
472
 
473
 
474
+ print("4")
475
+
476
  css = """
477
  #warning {background-color: #FFCCCB}
478
  .tools_button {
 
480
  border: none !important;
481
  box-shadow: none !important;
482
  }
483
+
484
+ .info_btn {
485
+ background: white;
486
+ border: none !important;
487
+ box-shadow: none !important;
488
+ }
489
+
490
+ .function_button {
491
+ border: none !important;
492
+ box-shadow: none !important;
493
+ }
494
+
495
  #tool_box {max-width: 50px}
496
 
497
  """
 
547
  args.segmenter = "huge"
548
  args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
549
  args.clip_filter = True
550
 
551
+ try:
552
+ print("Before preparing segmenter")
553
+ if args.segmenter_checkpoint is None:
554
+ _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
555
+ else:
556
+ segmenter_checkpoint = args.segmenter_checkpoint
557
+ print("After preparing segmenter")
558
+ except Exception as e:
559
+ print(f"Error in preparing segmenter: {e}")
560
+
561
+ try:
562
+ print("Before building captioner")
563
+ shared_captioner = build_captioner(args.captioner, args.device, args)
564
+ print("After building captioner")
565
+ except Exception as e:
566
+ print(f"Error in building captioner: {e}")
567
+
568
+ try:
569
+ print("Before loading SAM model")
570
+ shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
571
+ print("After loading SAM model")
572
+ except Exception as e:
573
+ print(f"Error in loading SAM model: {e}")
574
+
575
+ try:
576
+ print("Before initializing OCR reader")
577
+ ocr_lang = ["ch_tra", "en"]
578
+ shared_ocr_reader = easyocr.Reader(ocr_lang,model_storage_directory=".EasyOCR/model")
579
+ print("After initializing OCR reader")
580
+ except Exception as e:
581
+ print(f"Error in initializing OCR reader: {e}")
582
+
583
+ try:
584
+ print("Before building chatbot tools")
585
+ tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
586
+ shared_chatbot_tools = build_chatbot_tools(tools_dict)
587
+ print("After building chatbot tools")
588
+ except Exception as e:
589
+ print(f"Error in building chatbot tools: {e}")
590
+
591
+ print(5)
592
 
593
  # class ImageSketcher(gr.Image):
594
  # """
 
625
 
626
  def validate_api_key(api_key):
627
  api_key = str(api_key).strip()
628
+ print(api_key)
629
  try:
630
  test_llm = ChatOpenAI(model_name="gpt-4o", temperature=0, openai_api_key=api_key)
631
+ print("test_llm")
632
  response = test_llm([HumanMessage(content='Hello')])
633
+ print(response)
634
  return True
635
  except Exception as e:
636
+ print(f"API key validation failed: {e}")
637
  return False
638
 
639
 
 
642
  text_refiner = None
643
  visual_chatgpt = None
644
  if api_key and len(api_key) > 30:
645
+ print(api_key)
646
  if validate_api_key(api_key):
647
  try:
648
  # text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
649
  # assert len(text_refiner.llm('hi')) > 0 # test
650
  text_refiner = None
651
+ print("text refiner")
652
  visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key=api_key)
653
  except Exception as e:
654
+ print(f"Error initializing TextRefiner or ConversationBot: {e}")
655
  text_refiner = None
656
  visual_chatgpt = None
657
  else:
658
+ print("Invalid API key.")
659
  else:
660
+ print("API key is too short.")
661
+ print(text_refiner)
662
  openai_available = text_refiner is not None
663
  if visual_chatgpt:
664
 
 
734
 
735
 
736
 
737
+ def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None,language="English",narritive=None):
738
+ print("narritive", narritive)
739
  if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
740
  image_input = image_input['background']
741
 
 
779
  name, artist, year, material= parsed_data["name"],parsed_data["artist"],parsed_data["year"], parsed_data["style"]
780
  # artwork_info = f"<div>Painting: {name}<br>Artist name: {artist}<br>Year: {year}<br>Material: {material}</div>"
781
 
782
+ if narritive==None or narritive=="Third":
783
+ state = [
784
+ (
785
+ None,
786
+ f"🤖 Hi, I am EyeSee. Let's explore this painting '{name}' together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant information."
787
+ )
788
+ ]
789
+ elif narritive=="Artist":
790
+ state = [
791
+ (
792
+ None,
793
+ f"🧑‍🎨 Hello, I am the {artist}. Welcome to explore my painting, '{name}'. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant insights and thoughts behind my creation."
794
+ )
795
+ ]
796
+ elif narritive=="Item":
797
+ state = [
798
+ (
799
+ None,
800
+ f"🎨 Hello, I am the Item. Let's explore this painting '{name}' together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant insights and thoughts behind my creation."
801
+ )
802
+ ]
803
+
804
+
805
 
806
  return [state, state, image_input, click_state, image_input, image_input, image_input, image_input, image_embedding, \
807
  original_size, input_size] + [f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Style: {material}"]*4 + [paragraph,artist]
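The narrative-specific openings above can also live in a lookup table, which keeps upload_callback short as more personas are added. A sketch under the assumption that narritive is one of None, "Third", "Artist", or "Item" (GREETINGS and opening_state are hypothetical names; templates shortened):

GREETINGS = {
    "Third":  "🤖 Hi, I am EyeSee. Let's explore this painting '{name}' together.",
    "Artist": "🧑‍🎨 Hello, I am the {artist}. Welcome to explore my painting, '{name}'.",
    "Item":   "🎨 Hello, I am the Item. Let's explore this painting '{name}' together.",
}

def opening_state(narritive, name, artist):
    # None falls back to the third-person narrator, mirroring the if/elif chain above.
    template = GREETINGS.get(narritive or "Third", GREETINGS["Third"])
    return [(None, template.format(name=name, artist=artist))]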
 
842
 
843
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
844
  out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
845
+ state = state + [("You've selected image point at {}, ".format(prompt["input_point"]), None)]
846
 
847
+ # state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
848
  update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
849
  text = out['generated_captions']['raw_caption']
850
  input_mask = np.array(out['mask'].convert('P'))
 
870
  yield state, state, click_state, image_input_nobackground, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground
871
 
872
 
873
+ query_focus = {
874
+ "D": "Provide a description of the item.",
875
+ "DA": "Provide a description and analysis of the item.",
876
+ "DAI": "Provide a description, analysis, and interpretation of the item.",
877
+ "DDA": "Evaluate the item."
878
+ }
879
+
880
 
881
  async def submit_caption(state,length, sentiment, factuality, language,
882
  out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
883
  autoplay,paragraph,focus_type,openai_api_key,new_crop_save_path):
884
+
885
+
886
+ state = state + [(query_focus[focus_type], None)]
887
 
888
  click_index = click_index_state
889
 
 
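submit_caption indexes query_focus directly, so an unexpected focus_type would raise a KeyError. A defensive variant (a sketch, not the committed code) that falls back to the plain description query:

query_focus = {
    "D": "Provide a description of the item.",
    "DA": "Provide a description and analysis of the item.",
    "DAI": "Provide a description, analysis, and interpretation of the item.",
    "DDA": "Evaluate the item.",
}

def focus_query(focus_type):
    # Unknown focus types fall back to "D" instead of raising a KeyError.
    return query_focus.get(focus_type, query_focus["D"])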
1166
  visual_chatgpt.point_prompt = ""
1167
  if keep_global:
1168
  # visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
1169
+ visual_chatgpt.agent.memory.save_context({"input": visual_chatgpt.global_prompt}, {"output": ""})
 
1170
  else:
1171
  visual_chatgpt.current_image = None
1172
  visual_chatgpt.global_prompt = ""
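save_context stores both sides of the exchange as message text, so the output value should be a string; the empty string above records the global prompt without fabricating a model reply. A minimal, self-contained illustration with ConversationBufferMemory (the agent's actual memory class may differ):

from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory()
# The output must be text; an empty string keeps the prompt in memory with no AI reply attached.
memory.save_context({"input": "global prompt goes here"}, {"output": ""})
print(memory.buffer)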
 
1400
 
1401
  def toggle_icons_and_update_prompt(point_prompt):
1402
  new_prompt = "Negative" if point_prompt == "Positive" else "Positive"
1403
+ new_add_icon = "assets/icons/plus-square-blue.png" if new_prompt == "Positive" else "assets/icons/plus-square.png"
1404
+ new_minus_icon = "assets/icons/minus-square.png" if new_prompt == "Positive" else "assets/icons/minus-square-blue.png"
1405
+ print(point_prompt,flush=True)
1406
+ print(new_prompt,flush=True)
1407
 
1408
  return new_prompt, gr.update(icon=new_add_icon), gr.update(icon=new_minus_icon)
1409
 
 
1413
  print("this is a print test")
1414
 
1415
  def create_ui():
1416
+ print(6)
1417
  title = """<p><h1 align="center">EyeSee Anything in Art</h1></p>
1418
  """
1419
  description = """<p>Gradio demo for EyeSee Anything in Art, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. """
 
1509
  with gr.Tab("Base(GPT Power)") as base_tab:
1510
  image_input_base = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1511
  with gr.Row():
1512
+ name_label_base = gr.Button(value="Name: ",elem_classes="info_btn")
1513
+ artist_label_base = gr.Button(value="Artist: ",elem_classes="info_btn")
1514
+ year_label_base = gr.Button(value="Year: ",elem_classes="info_btn")
1515
+ material_label_base = gr.Button(value="Style: ",elem_classes="info_btn")
1516
 
1517
  with gr.Tab("Base2") as base_tab2:
1518
  image_input_base_2 = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1519
  with gr.Row():
1520
+ name_label_base2 = gr.Button(value="Name: ",elem_classes="info_btn")
1521
+ artist_label_base2 = gr.Button(value="Artist: ",elem_classes="info_btn")
1522
+ year_label_base2 = gr.Button(value="Year: ",elem_classes="info_btn")
1523
+ material_label_base2 = gr.Button(value="Style: ",elem_classes="info_btn")
1524
 
1525
  with gr.Tab("Click") as click_tab:
1526
  with gr.Row():
1527
+ with gr.Column(scale=10,min_width=600):
1528
  image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
1529
  example_image = gr.Image(type="pil", interactive=False, visible=False)
1530
  with gr.Row():
1531
+ name_label = gr.Button(value="Name: ",elem_classes="info_btn")
1532
+ artist_label = gr.Button(value="Artist: ",elem_classes="info_btn")
1533
+ year_label = gr.Button(value="Year: ",elem_classes="info_btn")
1534
+ material_label = gr.Button(value="Style: ",elem_classes="info_btn")
1535
 
1536
 
1537
  # example_image_click = gr.Image(type="pil", interactive=False, visible=False)
1538
  # the tool column
1539
+ with gr.Column(scale=1,elem_id="tool_box",min_width=80):
1540
  add_button = gr.Button(value="", interactive=True,elem_classes="tools_button",icon=add_icon_path)
1541
  minus_button = gr.Button(value="", interactive=True,elem_classes="tools_button",icon=minus_icon_path)
1542
  clear_button_click = gr.Button(value="Reset", interactive=True,elem_classes="tools_button")
1543
+ clear_button_image = gr.Button(value="Change", interactive=True,elem_classes="tools_button")
1544
  focus_d = gr.Button(value="D",interactive=True,elem_classes="function_button")
1545
  focus_da = gr.Button(value="DA",interactive=True,elem_classes="function_button")
1546
  focus_dai = gr.Button(value="DAI",interactive=True,elem_classes="function_button")
 
2073
 
2074
 
2075
 
2076
+ image_input_base.upload(upload_callback, [image_input_base, state, visual_chatgpt,openai_api_key,language,naritive],
2077
  [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2,
2078
  image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \
2079
  name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \
 
2105
  # image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist])
2106
  chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
2107
  [chatbot, state, aux_state,output_audio])
2108
+ # chat_input.submit(lambda: "", None, chat_input)
2109
+ chat_input.submit(lambda: {"text": ""}, None, chat_input)
2110
  # submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play],
2111
  # [chatbot, state, aux_state,output_audio])
2112
  # submit_button_text.click(lambda: "", None, chat_input)
2113
+ example_image.change(upload_callback, [example_image, state, visual_chatgpt, openai_api_key,language,naritive],
2114
  [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2,
2115
  image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \
2116
  name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \
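Clearing chat_input with a bare empty string only works for a plain Textbox; if chat_input is a gr.MultimodalTextbox its value is a dict, which is presumably why the clear callback above returns {"text": ""}. A small self-contained sketch of that pattern (component names here are illustrative, and Gradio 4.x is assumed):

import gradio as gr

with gr.Blocks() as demo:
    chat_box = gr.MultimodalTextbox(placeholder="Ask about the painting...")
    # MultimodalTextbox values are dicts ({"text": ..., "files": [...]}),
    # so clearing the box means returning a dict with an empty text field.
    chat_box.submit(lambda: {"text": "", "files": []}, None, chat_box)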
 
2277
 
2278
 
2279
  if __name__ == '__main__':
2280
+ print("main")
2281
  iface = create_ui()
2282
  iface.queue(api_open=False, max_size=10)
2283
  # iface.queue(concurrency_count=5, api_open=False, max_size=10)