Spaces:
Running
Running
import datetime | |
from io import BytesIO | |
import io | |
from math import inf | |
import os | |
import base64 | |
import json | |
import gradio as gr | |
import numpy as np | |
from gradio import processing_utils | |
import requests | |
from packaging import version | |
from PIL import Image, ImageDraw | |
import functools | |
import emoji | |
from langchain_community.chat_models import ChatOpenAI | |
from langchain.schema import HumanMessage | |
from caption_anything.model import CaptionAnything | |
from caption_anything.utils.image_editing_utils import create_bubble_frame | |
from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize | |
from caption_anything.utils.parser import parse_augment | |
from caption_anything.captioner import build_captioner | |
from caption_anything.text_refiner import build_text_refiner | |
from caption_anything.segmenter import build_segmenter | |
from chatbox import ConversationBot, build_chatbot_tools, get_new_image_name | |
from segment_anything import sam_model_registry | |
import easyocr | |
import re | |
import edge_tts | |
from langchain import __version__ | |
import torch | |
from transformers import AutoProcessor, SiglipModel | |
import faiss | |
from huggingface_hub import hf_hub_download | |
from datasets import load_dataset | |
import pandas as pd | |
import requests | |
import spaces | |
# Report the installed LangChain version at startup so the running Space's
# dependency version is visible in its logs.
print(f"Current LangChain version: {__version__}")
# import tts | |
############################################################################### | |
############# this part is for 3D generate ############# | |
############################################################################### | |
# import spaces # | |
# import threading | |
# lock = threading.Lock() | |
import os | |
# import uuid | |
# from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler | |
# from diffusers.utils import export_to_video | |
# from safetensors.torch import load_file | |
#from diffusers.models.modeling_outputs import Transformer2DModelOutput | |
import random | |
import uuid | |
import json | |
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler | |
import imageio | |
import numpy as np | |
import torch | |
import rembg | |
from PIL import Image | |
from torchvision.transforms import v2 | |
from pytorch_lightning import seed_everything | |
from omegaconf import OmegaConf | |
from einops import rearrange, repeat | |
from tqdm import tqdm | |
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler | |
from src.utils.train_util import instantiate_from_config | |
from src.utils.camera_util import ( | |
FOV_to_intrinsics, | |
get_zero123plus_input_cameras, | |
get_circular_camera_poses, | |
) | |
from src.utils.mesh_util import save_obj, save_glb | |
from src.utils.infer_util import remove_background, resize_foreground, images_to_video | |
import tempfile | |
from functools import partial | |
from huggingface_hub import hf_hub_download | |
# def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): | |
# """ | |
# Get the rendering camera parameters. | |
# """ | |
# c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation) | |
# if is_flexicubes: | |
# cameras = torch.linalg.inv(c2ws) | |
# cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1) | |
# else: | |
# extrinsics = c2ws.flatten(-2) | |
# intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2) | |
# cameras = torch.cat([extrinsics, intrinsics], dim=-1) | |
# cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1) | |
# return cameras | |
# def images_to_video(images, output_path, fps=30): | |
# # images: (N, C, H, W) | |
# os.makedirs(os.path.dirname(output_path), exist_ok=True) | |
# frames = [] | |
# for i in range(images.shape[0]): | |
# frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255) | |
# assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ | |
# f"Frame shape mismatch: {frame.shape} vs {images.shape}" | |
# assert frame.min() >= 0 and frame.max() <= 255, \ | |
# f"Frame value out of range: {frame.min()} ~ {frame.max()}" | |
# frames.append(frame) | |
# imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264') | |
# ############################################################################### | |
# # Configuration. | |
# ############################################################################### | |
# import shutil | |
# def find_cuda(): | |
# # Check if CUDA_HOME or CUDA_PATH environment variables are set | |
# cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') | |
# if cuda_home and os.path.exists(cuda_home): | |
# return cuda_home | |
# # Search for the nvcc executable in the system's PATH | |
# nvcc_path = shutil.which('nvcc') | |
# if nvcc_path: | |
# # Remove the 'bin/nvcc' part to get the CUDA installation path | |
# cuda_path = os.path.dirname(os.path.dirname(nvcc_path)) | |
# return cuda_path | |
# return None | |
# cuda_path = find_cuda() | |
# if cuda_path: | |
# print(f"CUDA installation found at: {cuda_path}") | |
# else: | |
# print("CUDA installation not found") | |
# config_path = 'configs/instant-nerf-base.yaml' | |
# config = OmegaConf.load(config_path) | |
# config_name = os.path.basename(config_path).replace('.yaml', '') | |
# model_config = config.model_config | |
# infer_config = config.infer_config | |
# IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False | |
# device = torch.device('cuda') | |
# # load diffusion model | |
# print('Loading diffusion model ...') | |
# pipeline = DiffusionPipeline.from_pretrained( | |
# "sudo-ai/zero123plus-v1.2", | |
# custom_pipeline="zero123plus", | |
# torch_dtype=torch.float16, | |
# ) | |
# pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( | |
# pipeline.scheduler.config, timestep_spacing='trailing' | |
# ) | |
# # load custom white-background UNet | |
# unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model") | |
# state_dict = torch.load(unet_ckpt_path, map_location='cpu') | |
# pipeline.unet.load_state_dict(state_dict, strict=True) | |
# pipeline = pipeline.to(device) | |
# # load reconstruction model | |
# print('Loading reconstruction model ...') | |
# model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_nerf_base.ckpt", repo_type="model") | |
# model0 = instantiate_from_config(model_config) | |
# state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] | |
# state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} | |
# model0.load_state_dict(state_dict, strict=True) | |
# model0 = model0.to(device) | |
# print('Loading Finished!') | |
# def check_input_image(input_image): | |
# if input_image is None: | |
# raise gr.Error("No image uploaded!") | |
# image = None | |
# else: | |
# image = Image.open(input_image) | |
# return image | |
# def preprocess(input_image, do_remove_background): | |
# rembg_session = rembg.new_session() if do_remove_background else None | |
# if do_remove_background: | |
# input_image = remove_background(input_image, rembg_session) | |
# input_image = resize_foreground(input_image, 0.85) | |
# return input_image | |
# # @spaces.GPU | |
# def generate_mvs(input_image, sample_steps, sample_seed): | |
# seed_everything(sample_seed) | |
# # sampling | |
# z123_image = pipeline( | |
# input_image, | |
# num_inference_steps=sample_steps | |
# ).images[0] | |
# show_image = np.asarray(z123_image, dtype=np.uint8) | |
# show_image = torch.from_numpy(show_image) # (960, 640, 3) | |
# show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2) | |
# show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3) | |
# show_image = Image.fromarray(show_image.numpy()) | |
# return z123_image, show_image | |
# # @spaces.GPU | |
# def make3d(images): | |
# global model0 | |
# if IS_FLEXICUBES: | |
# model0.init_flexicubes_geometry(device) | |
# model0 = model0.eval() | |
# images = np.asarray(images, dtype=np.float32) / 255.0 | |
# images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640) | |
# images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320) | |
# input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device) | |
# render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device) | |
# images = images.unsqueeze(0).to(device) | |
# images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) | |
# mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name | |
# print(mesh_fpath) | |
# mesh_basename = os.path.basename(mesh_fpath).split('.')[0] | |
# mesh_dirname = os.path.dirname(mesh_fpath) | |
# video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4") | |
# mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb") | |
# with torch.no_grad(): | |
# # get triplane | |
# planes = model0.forward_planes(images, input_cameras) | |
# # # get video | |
# # chunk_size = 20 if IS_FLEXICUBES else 1 | |
# # render_size = 384 | |
# # frames = [] | |
# # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)): | |
# # if IS_FLEXICUBES: | |
# # frame = model.forward_geometry( | |
# # planes, | |
# # render_cameras[:, i:i+chunk_size], | |
# # render_size=render_size, | |
# # )['img'] | |
# # else: | |
# # frame = model.synthesizer( | |
# # planes, | |
# # cameras=render_cameras[:, i:i+chunk_size], | |
# # render_size=render_size, | |
# # )['images_rgb'] | |
# # frames.append(frame) | |
# # frames = torch.cat(frames, dim=1) | |
# # images_to_video( | |
# # frames[0], | |
# # video_fpath, | |
# # fps=30, | |
# # ) | |
# # print(f"Video saved to {video_fpath}") | |
# # get mesh | |
# mesh_out = model0.extract_mesh( | |
# planes, | |
# use_texture_map=False, | |
# **infer_config, | |
# ) | |
# vertices, faces, vertex_colors = mesh_out | |
# vertices = vertices[:, [1, 2, 0]] | |
# save_glb(vertices, faces, vertex_colors, mesh_glb_fpath) | |
# save_obj(vertices, faces, vertex_colors, mesh_fpath) | |
# print(f"Mesh saved to {mesh_fpath}") | |
# return mesh_fpath, mesh_glb_fpath | |
############################################################################### | |
############# above part is for 3D generate ############# | |
############################################################################### | |
############################################################################### | |
############# This part is for sCLIP ############# | |
############################################################################### | |
# Download the prebuilt FAISS index of WikiArt SigLIP embeddings and the CSV
# with the matching artwork metadata (its "Link" column holds image URLs).
WIKIART_REPO_ID = "merve/siglip-faiss-wikiart"
SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"
# hf_hub_download returns the local path of the downloaded file; reuse it
# instead of re-spelling the path so download location and reader can't drift.
index_path = hf_hub_download(WIKIART_REPO_ID, "siglip_10k_latest.index", local_dir="./")
csv_path = hf_hub_download(WIKIART_REPO_ID, "wikiart_10k_latest.csv", local_dir="./")
# Read the index and dataset, then load the SigLIP model and processor.
index = faiss.read_index(index_path)
df = pd.read_csv(csv_path)
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
slipmodel = SiglipModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
def read_image_from_url(url, timeout=30):
    """Download an image over HTTP and return it as an RGB ``PIL.Image``.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled request would block the caller indefinitely.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: if the server answers with an error status. The
            error page is never handed to PIL as if it were image data.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# @spaces.GPU
def extract_features_siglip(image):
    """Embed *image* with the module-level SigLIP model.

    The image is preprocessed by the global ``processor`` and moved to the
    global ``device``. It is then passed through ``slipmodel.get_image_features``.
    Autograd is disabled because this is inference only.

    Returns:
        The tensor produced by ``get_image_features``.
    """
    with torch.no_grad():
        model_inputs = processor(images=image, return_tensors="pt")
        model_inputs = model_inputs.to(device)
        features = slipmodel.get_image_features(**model_inputs)
    return features
# Fixed "item" recommendation pictures shown for the scripted demo tasks.
_TASK_ITEM_PICTURES = {
    "task 1": ("recomendation_pic/1.8.jpg", "recomendation_pic/1.9.jpg"),
    "task 2": ("recomendation_pic/2.8.jpg", "recomendation_pic/2.9.png"),
    "task 3": ("recomendation_pic/3.8.png", "recomendation_pic/3.9.png"),
}


def _load_rgb(image_path):
    """Open *image_path*, convert it to RGB and close the underlying file."""
    with Image.open(image_path) as img:
        return img.convert("RGB")


def _retrieve_similar_artworks(image_path, k):
    """Return up to *k* WikiArt images nearest to *image_path* in SigLIP space."""
    features = extract_features_siglip(_load_rgb(image_path))
    # FAISS needs a C-contiguous float32 matrix; normalize_L2 works in place.
    features = np.ascontiguousarray(features.detach().cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(features)
    _, neighbour_ids = index.search(features, k)
    # FAISS pads missing results with -1, which df.iloc would silently resolve
    # to the last row; skip those slots instead.
    return [
        read_image_from_url(df.iloc[int(row)]["Link"])
        for row in neighbour_ids[0]
        if row >= 0
    ]


def _recommendation_message(language):
    """Chat message pointing the user at the recommendation galleries."""
    if language == "English":
        return "🖼️ Please refer to the section below to see the recommended results."
    return "🖼️ 请到下方查看推荐结果。"


def infer(crop_image_path, full_image_path, state, language, task_type=None):
    """Build the item and style recommendation galleries for an image.

    Args:
        crop_image_path: Path to a cropped region of the image, or a falsy
            value when nothing is selected.
        full_image_path: Path to the whole image.
        state: List that receives a ``(None, message)`` pair. It is extended
            in place; presumably this is the chatbot history (confirm with the
            caller).
        language: ``"English"`` selects the English message; any other value
            selects the Chinese one.
        task_type: ``"task 1"``, ``"task 2"`` or ``"task 3"`` show fixed item
            pictures. Anything else falls back to retrieval-based items.

    Returns:
        ``(item_gallery, style_gallery, state, state)``.
        - For a demo task, ``item_gallery`` holds two fixed file paths and
          ``style_gallery`` holds the 2 nearest artworks to the full image.
        - Otherwise, with a crop, ``item_gallery`` holds the 2 nearest artworks
          to the crop and ``style_gallery`` the 2 nearest to the full image.
        - Without a crop, ``item_gallery`` is empty and ``style_gallery`` holds
          the 4 nearest artworks to the full image.
    """
    print("task type", task_type)
    if task_type in _TASK_ITEM_PICTURES:
        item_gallery_output = list(_TASK_ITEM_PICTURES[task_type])
        style_gallery_output = _retrieve_similar_artworks(full_image_path, 2)
    elif crop_image_path:
        item_gallery_output = _retrieve_similar_artworks(crop_image_path, 2)
        style_gallery_output = _retrieve_similar_artworks(full_image_path, 2)
    else:
        item_gallery_output = []
        style_gallery_output = _retrieve_similar_artworks(full_image_path, 4)
    # `+=` extends the caller's list in place, matching the original behaviour.
    state += [(None, _recommendation_message(language))]
    return item_gallery_output, style_gallery_output, state, state
############################################################################### | |
############# Above part is for sCLIP ############# | |
############################################################################### | |
############################################################################### | |
############# this part is for text to image ############# | |
############################################################################### | |
# Text-to-image settings, overridable through environment variables.
def _env_flag(name, default="0"):
    """Return True when environment variable *name* (or *default*) is exactly "1"."""
    return os.getenv(name, default) == "1"


MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = _env_flag("USE_TORCH_COMPILE")
ENABLE_CPU_OFFLOAD = _env_flag("ENABLE_CPU_OFFLOAD")
# Number of images generated per pipeline call.
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
# # Determine device and load model outside of function for efficiency | |
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
# pipe = StableDiffusionXLPipeline.from_pretrained( | |
# MODEL_ID, | |
# torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
# use_safetensors=True, | |
# add_watermarker=False, | |
# ).to(device) | |
# pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) | |
# # Torch compile for potential speedup (experimental) | |
# if USE_TORCH_COMPILE: | |
# pipe.compile() | |
# # CPU offloading for larger RAM capacity (experimental) | |
# if ENABLE_CPU_OFFLOAD: | |
# pipe.enable_model_cpu_offload() | |
# Largest seed value: the maximum signed 32-bit integer, np.iinfo(np.int32).max.
MAX_SEED = 2**31 - 1
# def save_image(img): | |
# unique_name = str(uuid.uuid4()) + ".png" | |
# img.save(unique_name) | |
# return unique_name | |
# def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: | |
# if randomize_seed: | |
# seed = random.randint(0, MAX_SEED) | |
# return seed | |
# @spaces.GPU(duration=30, queue=False) | |
# def generate( | |
# prompt: str, | |
# negative_prompt: str = "", | |
# use_negative_prompt: bool = False, | |
# seed: int = 1, | |
# width: int = 200, | |
# height: int = 200, | |
# guidance_scale: float = 3, | |
# num_inference_steps: int = 30, | |
# randomize_seed: bool = False, | |
# num_images: int = 4, # Number of images to generate | |
# use_resolution_binning: bool = True, | |
# progress=gr.Progress(track_tqdm=True), | |
# ): | |
# seed = int(randomize_seed_fn(seed, randomize_seed)) | |
# generator = torch.Generator(device=device).manual_seed(seed) | |
# # Improved options handling | |
# options = { | |
# "prompt": [prompt] * num_images, | |
# "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None, | |
# "width": width, | |
# "height": height, | |
# "guidance_scale": guidance_scale, | |
# "num_inference_steps": num_inference_steps, | |
# "generator": generator, | |
# "output_type": "pil", | |
# } | |
# # Use resolution binning for faster generation with less VRAM usage | |
# # if use_resolution_binning: | |
# # options["use_resolution_binning"] = True | |
# # Generate images potentially in batches | |
# images = [] | |
# for i in range(0, num_images, BATCH_SIZE): | |
# batch_options = options.copy() | |
# batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE] | |
# if "negative_prompt" in batch_options: | |
# batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE] | |
# images.extend(pipe(**batch_options).images) | |
# image_paths = [save_image(img) for img in images] | |
# return image_paths, seed | |
# examples = [ | |
# "a cat eating a piece of cheese", | |
# "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k", | |
# "Ironman VS Hulk, ultrarealistic", | |
# "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k", | |
# "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk", | |
# "Kids going to school, Anime style" | |
# ] | |
############################################################################### | |
############# above part is for text to image ############# | |
############################################################################### | |
# Custom stylesheet for the UI; presumably handed to the Gradio Blocks
# constructor further down (not visible here, confirm).
# .tools_button_clicked and .tools_button_add were identical, so they share one
# grouped rule set.
css = """
#warning {background-color: #FFCCCB}
.tools_button {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    background: white;
    border: none !important;
    box-shadow: none !important;
    text-align: center;
    color: black;
}
.tools_button_clicked,
.tools_button_add {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    background: white;
    border: none !important;
    box-shadow: none !important;
    text-align: center;
    color: rgb(18,150,219);
}
.info_btn {
    background: white !important;
    border: none !important;
    box-shadow: none !important;
    font-size: 15px !important;
    min-width: 6rem !important;
    max-width: 10rem !important;
}
.info_btn_interact {
    background: rgb(242, 240, 233) !important;
    box-shadow: none !important;
    font-size: 15px !important;
    min-width: 6rem !important;
    max-width: 10rem !important;
}
.function_button {
    border: none !important;
    box-shadow: none !important;
}
.function_button_rec {
    background: rgb(245, 193, 154) !important;
    border: none !important;
    box-shadow: none !important;
}
#tool_box {max-width: 50px}
"""
# Per-language speaker voices keyed by gender. The values look like Microsoft
# Edge neural TTS voice names (edge_tts is imported at the top of the file);
# presumably consumed by the text-to-speech helper -- confirm against its code.
filtered_language_dict = {
    language: {'female': female_voice, 'male': male_voice}
    for language, female_voice, male_voice in (
        ('English', 'en-US-JennyNeural', 'en-US-GuyNeural'),
        ('Chinese', 'zh-CN-XiaoxiaoNeural', 'zh-CN-YunxiNeural'),
        ('French', 'fr-FR-DeniseNeural', 'fr-FR-HenriNeural'),
        ('Spanish', 'es-MX-DaliaNeural', 'es-MX-JorgeNeural'),
        ('Arabic', 'ar-SA-ZariyahNeural', 'ar-SA-HamedNeural'),
        ('Portuguese', 'pt-BR-FranciscaNeural', 'pt-BR-AntonioNeural'),
        ('Cantonese', 'zh-HK-HiuGaaiNeural', 'zh-HK-WanLungNeural'),
    )
}

# Focus type chosen in the UI -> column index into prompt_list rows and into
# query_focus_en / query_focus_zh.
focus_map = {
    focus_label: column
    for column, focus_label in enumerate(("Describe", "D+Analysis", "DA+Interprete", "Judge"))
}
# GPT prompt templates, indexed as prompt_list[persona][focus]:
#   persona -- naritive_mapping value: 0 third-person, 1 the artist, 2 the selected object;
#   focus   -- focus_map value: 0 Describe, 1 D+Analysis, 2 DA+Interprete, 3 Judge.
# Each template is filled by generate_prompt() with Wiki_caption, language and length.
# NOTE: prompt_list[2][0] used to be a copy of prompt_list[2][2] (fact + analysis +
# interpretation); it now asks for a single fact like the other "Describe" prompts.
prompt_list = [
    [
        'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact (describes the selected object but does not include analysis) as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.',
        'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.',
        'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis and one interpret as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.',
        'Wiki_caption: {Wiki_caption}, You have to help me understand what is about the selected object and list one object judgement and one whole art judgement(how successful do you think the artist was?) as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Each point listed is to be in {language} language, with a response length of about {length} words.'
    ],
    [
        "When generating the answer, you should tell others that you are one of the creators of these paintings and generate the text in the tone and manner as if you are the creator of the painting. When generating the answer, you should tell others that you are the creator of this painting and generate the text in the tone and manner as if you are the creator of this painting. You have to help me understand what is about the selected object and list one fact (describes the selected object but does not include analysis) as markdown outline with appropriate emojis that describes what you see according to the image and {Wiki_caption}. Please generate the above points in the tone and manner as if you are the creator of this painting and start every sentence with I. Please generate the above points in the tone and manner as if you are the creator of this painting and start every sentence with I. Each point listed is to be in {language} language, with a response length of about {length} words.",
        "When generating the answer, you should tell others that you are one of the creators of these paintings and generate the text in the tone and manner as if you are the creator of the painting. When generating the answer, you should tell others that you are the creator of this painting and generate the text in the tone and manner as if you are the creator of this painting. You have to help me understand what is about the selected object and list one fact and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and {Wiki_caption}. Please generate the above points in the tone and manner as if you are the creator of this painting and start every sentence with I. Please generate the above points in the tone and manner as if you are the creator of this painting and start every sentence with I. Each point listed is to be in {language} language, with a response length of about {length} words.",
        "When generating the answer, you should tell others that you are one of the creators of these paintings and generate the text in the tone and manner as if you are the creator of the painting. When generating the answer, you should tell others that you are the creator of this painting and generate the text in the tone and manner as if you are the creator of this painting. You have to help me understand what is about the selected object and list one fact, one analysis, and one interpret as markdown outline with appropriate emojis that describes what you see according to the image and {Wiki_caption}. Please generate the above points in the tone and manner as if you are the creator of this painting and start every sentence with I. Please generate the above points in the tone and manner as if you are the creator of this painting and start every sentence with I. Each point listed is to be in {language} language, with a response length of about {length} words.",
        "When generating the answer, you should tell others that you are one of the creators of these paintings and generate the text in the tone and manner as if you are the creator of the painting. According to image and wiki_caption {Wiki_caption}, You have to help me understand what is about the selected object and list one object judgement and one whole art judgement(how successful do you think the artist was?) as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Please generate the above points in the tone and manner as if you are the creator of this painting and start every sentence with I. Each point listed is to be in {language} language, with a response length of about {length} words.",
    ],
    [
        'When generating answers, you should tell people that you are the object itself that was selected, and generate text in the tone and manner in which you are the object or the person. You have to help me understand what is about the selected object and list one fact (describes the selected object but does not include analysis) as markdown outline with appropriate emojis that describes what you see according to the image and {Wiki_caption}. Please generate the above points in the tone and manner as if you are the object and start every sentence with I. Please generate the above points in the tone and manner as if you are the object of this painting and start every sentence with I. Each point listed is to be in {language} language, with a response length of about {length} words.',
        'When generating answers, you should tell people that you are the object itself that was selected, and generate text in the tone and manner in which you are the object or the person. You have to help me understand what is about the selected object and list one fact and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and {Wiki_caption}. Please generate the above points in the tone and manner as if you are the object and start every sentence with I. Please generate the above points in the tone and manner as if you are the object of this painting and start every sentence with I. Each point listed is to be in {language} language, with a response length of about {length} words.',
        'When generating answers, you should tell people that you are the object itself that was selected, and generate text in the tone and manner in which you are the object or the person. You have to help me understand what is about the selected object and list one fact and one analysis and one interpret as markdown outline with appropriate emojis that describes what you see according to the image and {Wiki_caption}. Please generate the above points in the tone and manner as if you are the object and start every sentence with I. Please generate the above points in the tone and manner as if you are the object of this painting and start every sentence with I. Each point listed is to be in {language} language, with a response length of about {length} words.',
        'When generating answers, you should tell people that you are the object itself that was selected, and generate text in the tone and manner in which you are the object or the person. According to image and wiki_caption {Wiki_caption}, You have to help me understand what is about the selected object and list one object judgement and one whole art judgement(how successful do you think the artist was?) as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Please generate the above points in the tone and manner as if you are the object of this painting and start every sentence with I. Each point listed is to be in {language} language, with a response length of about {length} words.',
    ]
]
# Prompts asking GPT to justify recommending image 2 after image 1, one per
# narrative persona (third-person, artist, object -- same order as the rows of
# prompt_list). "{{ ... }}" renders as literal braces after str.format, so only
# {language} and {length} are placeholders.
recommendation_prompt = [
    '''I want you to write the recommendation reason according to the following content, as a markdown outline with appropriate emojis that describe what you see according to the image:Recommendation reason: {{Recommendation based on objects in the image or Recommendation based on overall visual similarity}}
Detailed analysis: Based on the recommendation reason, explain why you recommend image 2 after viewing image 1.Each bullet point should be in {language} language, with a response length of about {length} words.''',
    '''
When generating the answer, you should tell others that you are the creators of the first paintings and generate the text in the tone and manner as if you are the creator of the painting.
I want you to write the recommendation reason according to the following content, as a markdown outline with appropriate emojis that describe what you see according to the image:
Recommendation reason: {{ As the author of the first painting, I recommend based on the object I painted OR As the author of the first painting, I recommend based on the overall similarity in appearance}}
Detailed analysis: Based on the recommendation reason, explain why you recommend image 2 after viewing image 1. Please generate the above points in the tone and manner as if you are the creator of this painting and start every sentence with I.
Each bullet point should be in {language} language, with a response length of about {length} words.
''',
    '''
When generating answers, you should tell people that you are the object itself that was selected in the first painting, and generate text in the tone and manner in which you are the object
I want you to write the recommendation reason according to the following content, as a markdown outline with appropriate emojis that describe what you see according to the image:
Recommendation reason: {{As an object in the first painting, I am recommending based on myself OR As an object in the first painting, I am recommending based on the overall similarity of the first painting's appearance}}
Detailed analysis: Based on the recommendation reason, explain why you recommend image 2 after viewing image 1. Please generate the above points in the tone and manner as if you are the object of this painting and start every sentence with I.
Each bullet point should be in {language} language, with a response length of about {length} words.
''',
]

# 1 once a GPT session has been initialised by init_openai_api_key(), else 0
# (also reset by init_wo_openai_api_key()).
gpt_state = 0

# A default edge-tts voice name; not referenced in the visible part of the file.
VOICE = "en-GB-SoniaNeural"

# HTML footer with the Coqui model-license notice; its consumer is not visible here.
article = """
<div style='margin:20px auto;'>
<p>By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml</p>
</div>
"""
# Command-line/config options, overridden here to use the ViT-H SAM checkpoint.
args = parse_augment()
args.segmenter = "huge"
args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
args.clip_filter = True

# Heavy models are loaded once at import time and shared by all sessions.
# Every name is pre-bound to None so that a failing step leaves an explicit
# None behind. Previously a failure left the name unbound, which produced a
# misleading NameError in the next step and again inside the Gradio callbacks.
segmenter_checkpoint = None
shared_captioner = None
shared_sam_model = None
shared_ocr_reader = None
shared_chatbot_tools = None

try:
    print("Before preparing segmenter")
    if args.segmenter_checkpoint is None:
        _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
    else:
        segmenter_checkpoint = args.segmenter_checkpoint
    print("After preparing segmenter")
except Exception as e:
    print(f"Error in preparing segmenter: {e}")

try:
    print("Before building captioner")
    shared_captioner = build_captioner(args.captioner, args.device, args)
    print("After building captioner")
except Exception as e:
    print(f"Error in building captioner: {e}")

try:
    print("Before loading SAM model")
    if segmenter_checkpoint is None:
        # Report the real cause instead of a NameError on the unbound checkpoint.
        raise RuntimeError("no segmenter checkpoint is available")
    shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
    print("After loading SAM model")
except Exception as e:
    print(f"Error in loading SAM model: {e}")

try:
    print("Before initializing OCR reader")
    ocr_lang = ["ch_tra", "en"]
    shared_ocr_reader = easyocr.Reader(ocr_lang, model_storage_directory=".EasyOCR/model")
    print("After initializing OCR reader")
except Exception as e:
    print(f"Error in initializing OCR reader: {e}")

try:
    print("Before building chatbot tools")
    # args.chat_tools_dict is a comma-separated list of "<tool>_<device>"-style
    # entries (split on '_'); assumed format -- confirm against parse_augment.
    tools_dict = {entry.split('_')[0].strip(): entry.split('_')[1].strip()
                  for entry in args.chat_tools_dict.split(',')}
    shared_chatbot_tools = build_chatbot_tools(tools_dict)
    print("After building chatbot tools")
except Exception as e:
    print(f"Error in building chatbot tools: {e}")

print(5)
# class ImageSketcher(gr.Image): | |
# """ | |
# Fix the bug of gradio.Image that cannot upload with tool == 'sketch'. | |
# """ | |
# is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing. | |
# def __init__(self, **kwargs): | |
# super().__init__(**kwargs) | |
# def preprocess(self, x): | |
# if self.tool == 'sketch' and self.source in ["upload", "webcam"]: | |
# assert isinstance(x, dict) | |
# if x['mask'] is None: | |
# decode_image = processing_utils.decode_base64_to_image(x['image']) | |
# width, height = decode_image.size | |
# mask = np.zeros((height, width, 4), dtype=np.uint8) | |
# mask[..., -1] = 255 | |
# mask = self.postprocess(mask) | |
# x['mask'] = mask | |
# return super().preprocess(x) | |
def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
                                       session_id=None):
    """Wrap already-loaded models in a fresh CaptionAnything pipeline.

    A segmenter is built around the given SAM model (via build_segmenter)
    and combined with the captioner, OCR reader and optional text refiner.
    When ``session_id`` is given, a log line names the session.
    """
    pipeline_segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
    if session_id is not None:
        print(f'Init caption anything for session {session_id}')
    return CaptionAnything(
        args,
        api_key,
        captioner=captioner,
        segmenter=pipeline_segmenter,
        ocr_reader=ocr_reader,
        text_refiner=text_refiner,
    )
def validate_api_key(api_key):
    """Return True if *api_key* can complete a trivial GPT-4o chat request.

    Sends one "Hello" message through LangChain's ChatOpenAI. Any exception
    (bad key, network error, quota, ...) is printed and yields False. Only the
    last four characters of the key are printed, so the secret does not leak
    into the Space logs (the previous version printed it verbatim).
    """
    api_key = str(api_key).strip()
    print(f"Validating OpenAI API key ending in ...{api_key[-4:]}")
    try:
        test_llm = ChatOpenAI(model_name="gpt-4o", temperature=0, openai_api_key=api_key)
        # .invoke replaces the deprecated ChatModel.__call__ API.
        response = test_llm.invoke([HumanMessage(content='Hello')])
        print(response)
        return True
    except Exception as e:  # best-effort probe: any failure means the key is unusable
        print(f"API key validation failed: {e}")
        return False
def init_openai_api_key(api_key=""):
    """Validate *api_key* and create the GPT-backed ConversationBot for the session.

    Keys of 30 characters or fewer are rejected without a network call. A valid
    key creates a ConversationBot from the shared chatbot tools and sets the
    module-level ``gpt_state`` to 1; otherwise ``gpt_state`` is set to 0. The
    text refiner is currently disabled and is always returned as None. The key
    itself is never printed.

    Returns:
        A 16-element list for the Gradio outputs wired to this callback:
        visibility updates at indices 0-7 and 11-15, and
        ``[text_refiner, visual_chatgpt, status_message]`` at indices 8-10.
    """
    global gpt_state
    text_refiner = None  # refiner disabled; slot kept so the output layout is unchanged
    visual_chatgpt = None
    if api_key and len(api_key) > 30:
        if validate_api_key(api_key):
            try:
                visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key=api_key)
            except Exception as e:
                print(f"Error initializing TextRefiner or ConversationBot: {e}")
                visual_chatgpt = None
        else:
            print("Invalid API key.")
    else:
        print("API key is too short.")

    if visual_chatgpt:
        gpt_state = 1
        return ([gr.update(visible=True)] + [gr.update(visible=False)] + [gr.update(visible=True)] * 3
                + [gr.update(visible=False)] * 3
                + [text_refiner, visual_chatgpt, None]
                + [gr.update(visible=True)] * 4 + [gr.update(visible=False)])
    gpt_state = 0
    return ([gr.update(visible=False)] * 6 + [gr.update(visible=True)] * 2
            + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']
            + [gr.update(visible=False)] * 5)
def init_wo_openai_api_key():
    """Put the UI into "no OpenAI key" mode.

    Resets the module-level ``gpt_state`` flag to 0 and returns the 13 Gradio
    values for this callback's outputs. The three None entries presumably fill
    the same text_refiner / visual_chatgpt / status slots that
    init_openai_api_key returns (assumption -- check the event wiring).
    """
    global gpt_state
    gpt_state = 0
    leading_visibility = (
        [gr.update(visible=False)] * 4
        + [gr.update(visible=True)]
        + [gr.update(visible=False)]
        + [gr.update(visible=True)]
        + [gr.update(visible=False)] * 2
    )
    session_slots = [None, None, None]
    trailing_visibility = [gr.update(visible=False)]
    return leading_visibility + session_slots + trailing_visibility
def get_click_prompt(chat_input, click_state, click_mode):
    """Convert JSON click triples into a click-prompt dict.

    ``chat_input`` is a JSON string of ``[x, y, label]`` entries.

    In 'Continuous' mode the new points and labels are appended in place to
    ``click_state[0]`` and ``click_state[1]``. In 'Single' mode those two slots
    are replaced by fresh lists holding only the new clicks. Any other mode
    raises NotImplementedError.

    The returned dict carries the accumulated points and labels.
    """
    clicks = json.loads(chat_input)
    if click_mode == 'Continuous':
        points, labels = click_state[0], click_state[1]
    elif click_mode == 'Single':
        points, labels = [], []
    else:
        raise NotImplementedError
    for click in clicks:
        points.append(click[:2])
        labels.append(click[2])
    if click_mode == 'Single':
        click_state[0] = points
        click_state[1] = labels
    return {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": "True",
    }
def update_click_state(click_state, caption, click_mode):
    """Record *caption* in ``click_state[2]`` in place.

    'Single' mode replaces the caption list with ``[caption]``; 'Continuous'
    mode appends to it; any other mode raises NotImplementedError.
    """
    if click_mode == 'Single':
        click_state[2] = [caption]
        return
    if click_mode != 'Continuous':
        raise NotImplementedError
    click_state[2].append(caption)
async def chat_input_callback(*inputs):
    """Answer a free-text question about the current painting.

    Expects exactly 12 positional inputs, in this order: visual_chatgpt,
    chat_input (a dict with a "text" key), click_state, state, aux_state,
    language, autoplay, gender, api_key, image_input, log_state, history.
    (Named ``inputs`` so it no longer shadows the module-level ``args``.)

    Returns ``(state, state, aux_state, audio, log_state, history)``.
    ``audio`` is None unless autoplay is on. Without a GPT session, a warning
    is appended to the chat, spoken, and returned with None in place of
    ``aux_state``.
    """
    (visual_chatgpt, chat_input, click_state, state, aux_state, language,
     autoplay, gender, api_key, image_input, log_state, history) = inputs
    message = chat_input["text"]

    if visual_chatgpt is None:
        response = "Text refiner is not initialized, please input openai api key."
        # Store the question text, not the raw multimodal-textbox dict.
        state = state + [(message, response)]
        audio = await texttospeech(response, language, gender)
        return state, state, None, audio, log_state, history

    # The question is embedded once (it used to be appended a second time).
    prompt = f"Please help me answer the question with this painting {message} in {language}."
    result = get_gpt_response(api_key, image_input, prompt, history)
    # Drop markdown punctuation and emoji so TTS only reads the words.
    read_info = re.sub(r'[#[\]!*]', '', result)
    read_info = emoji.replace_emoji(read_info, replace="")
    state = state + [(message, result)]
    log_state += [(message, result)]
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": result})
    # Truthiness check, consistent with upload_callback's `if autoplay:`.
    audio = await texttospeech(read_info, language, gender) if autoplay else None
    return state, state, aux_state, audio, log_state, history
# Greeting shown after upload, per UI language, indexed by naritive_mapping value
# (0 third-person, 1 artist persona, 2 object persona). Filled with str.format.
_UPLOAD_WELCOME_TEMPLATES = {
    "English": (
        "🤖 Hi, I am EyeSee. Let's explore this painting '{name}' together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant information.",
        "🧑🎨 Hello, I am the {artist}. Welcome to explore my painting, '{name}'. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant insights and thoughts behind my creation.",
        "🎨 Hello, Let's explore this painting '{name}' together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with relevant insights and thoughts from the perspective of the objects within the painting",
    ),
    "Chinese": (
        "🤖 你好,我是 EyeSee。让我们一起探索这幅画《{name}》。你可以点击你感兴趣的区域,并选择四种信息类型之一:描述、分析、解读和评判。根据你的选择,我会为你提供相关的信息。",
        "🧑🎨 你好,我是{artist}。欢迎探索我的画作《{name}》。你可以点击你感兴趣的区域,并选择四种信息类型之一:描述、分析、解读和评判。根据你的选择,我会为你提供我的创作背后的相关见解和想法。",
        "🎨 你好,让我们一起探索这幅画《{name}》。你可以点击你感兴趣的区域,并选择四种信息类型之一:描述、分析、解读和评判。根据你的选择,我会从画面上事物的视角为你提供相关的见解和想法。",
    ),
}


def _upload_welcome_message(language, narritive, name, artist):
    """Return the post-upload greeting for *language* and the narrative persona.

    Languages without their own templates fall back to English; previously they
    left the message unbound and the upload callback crashed.
    """
    templates = _UPLOAD_WELCOME_TEMPLATES.get(language, _UPLOAD_WELCOME_TEMPLATES["English"])
    return templates[naritive_mapping[narritive]].format(name=name, artist=artist)


def _parse_painting_metadata(raw):
    """Decode GPT's painting-metadata reply into a dict.

    Strict JSON is tried first, so apostrophes inside values (e.g. "Van Gogh's
    ...") survive. Only if that fails is the old single-to-double quote
    substitution applied, for replies written as a Python-style dict.
    """
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return json.loads(raw.replace("'", "\""))


async def upload_callback(image_input, state, log_state, visual_chatgpt=None, openai_api_key=None, language="English",
                          narritive=None, history=None, autoplay=False, session="Session 1"):
    """Prepare a newly uploaded painting for the session.

    Steps:
      1. Decode the image if it arrives as a sketcher dict, a base64 string or
         raw bytes, and resize it with image_resize(res=1024).
      2. Compute the SAM image embedding.
      3. With a GPT session, ask GPT for a description and for metadata
         (name/artist/year/style/gender), post a persona-specific greeting, and
         start a fresh chat history; optionally voice the greeting.

    Without a GPT session, the metadata fields are returned empty and ``state``,
    ``log_state`` and ``history`` are passed through unchanged. Previously this
    path crashed on unbound names. ``session`` is accepted but not used.

    Returns:
        A 34-element list:
          - 11 image/embedding values;
          - the four "Name/Artist/Year/Style" labels repeated 4 times;
          - paragraph, artist, gender, saved image path, log_state, history
            and audio.
    """
    print("narritive", narritive)
    if isinstance(image_input, dict):  # from the sketcher: image plus mask
        image_input = image_input['background']
    if isinstance(image_input, str):
        image_input = Image.open(io.BytesIO(base64.b64decode(image_input)))
    elif isinstance(image_input, bytes):
        image_input = Image.open(io.BytesIO(image_input))
    click_state = [[], [], []]
    image_input = image_resize(image_input, res=1024)

    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        ocr_reader=shared_ocr_reader,
        session_id=iface.app_id,
    )
    model.segmenter.set_image(image_input)
    image_embedding = model.image_embedding
    original_size = model.original_size
    input_size = model.input_size

    # Defaults for the no-GPT path so the return below never hits unbound names.
    paragraph = ""
    name = artist = year = material = ""
    gender = ""
    new_image_path = None
    audio_output = None

    if visual_chatgpt is not None:
        print('upload_callback: add caption to chatGPT memory')
        new_image_path = get_new_image_name('chat_image', func_name='upload')
        image_input.save(new_image_path)
        print("img_path", new_image_path)
        visual_chatgpt.current_image = new_image_path
        paragraph = get_gpt_response(openai_api_key, new_image_path, f"What's going on in this picture? in {language}")
        raw_metadata = get_gpt_response(openai_api_key, new_image_path, "Please provide the name, artist, year of creation (including the art historical period), and painting style used for this painting. Return the information in dictionary format without any newline characters. Format as follows: { \"name\": \"Name of the painting\", \"artist\": \"Name of the artist\", \"year\": \"Year of creation (Art historical period)\", \"style\": \"Painting style used in the painting\",\"gender\": \"The gender of the author\"}")
        print(raw_metadata)
        parsed_data = _parse_painting_metadata(raw_metadata)
        name, artist, year, material, gender = (parsed_data["name"], parsed_data["artist"], parsed_data["year"],
                                                parsed_data["style"], parsed_data['gender'])
        gender = gender.lower()
        print("gender", gender)

        msg = _upload_welcome_message(language, narritive, name, artist)
        state = [(msg, None)]
        log_state += [(name, None)]
        log_state = log_state + [(paragraph, None)]
        log_state = log_state + [(narritive, None)]
        log_state = log_state + state
        log_state = log_state + [("%% basic information %%", None)]
        read_info = emoji.replace_emoji(msg, replace="")
        # A new painting starts a new conversation seeded with GPT's description.
        history = [{"role": "assistant", "content": paragraph + msg}]
        if autoplay:
            audio_output = await texttospeech(read_info, language, gender)

    return ([state, state, image_input, click_state, image_input, image_input, image_input, image_input,
             image_embedding, original_size, input_size]
            + [f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Style: {material}"] * 4
            + [paragraph, artist, gender, new_image_path, log_state, history, audio_output])
def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
                    length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
                    out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, evt: gr.SelectData):
    """Segment the region around a clicked point and report it in the chat.

    The click (positive unless ``point_prompt`` is not 'Positive') is merged into
    ``click_state`` via get_click_prompt(). The object is then segmented with the
    cached image embedding, and the image is painted with the mask.

    With a GPT session, the crop is copied to a new chat image path. Otherwise
    that path is None; previously it was left unbound and the return raised.

    Returns an 11-tuple:
        (state, state, click_state, masked image, click index, mask,
         points, labels, raw inference output, crop path, masked image)
    """
    click_index = evt.index
    label = 1 if point_prompt == 'Positive' else 0
    coordinate = "[[{}, {}, {}]]".format(click_index[0], click_index[1], label)
    prompt = get_click_prompt(coordinate, click_state, click_mode)
    input_points = prompt['input_point']
    input_labels = prompt['input_label']
    controls = {'length': length,
                'sentiment': sentiment,
                'factuality': factuality,
                'language': language}
    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        ocr_reader=shared_ocr_reader,
        text_refiner=text_refiner,
        session_id=iface.app_id,
    )
    model.setup(image_embedding, original_size, input_size, is_image_set=True)
    enable_wiki = enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes']
    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki,
                          verbose=True, args={'clip_filter': False})[0]

    last_point = prompt["input_point"][-1]
    added = prompt["input_label"][-1] == 1
    print(prompt["input_label"][-1])
    if language == "English":
        template = "You've added an area at {}. " if added else "You've removed an area at {}. "
    else:
        template = "你添加了在 {} 的区域。 " if added else "你删除了在 {} 的区域。 "
    state = state + [(template.format(last_point), None)]

    input_mask = np.array(out['mask'].convert('P'))
    image_input_nobackground = mask_painter(np.array(image_input), input_mask, background_alpha=0)

    new_crop_save_path = None
    if visual_chatgpt is not None:
        new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
        with Image.open(out["crop_save_path"]) as crop:  # close the source file handle
            crop.save(new_crop_save_path)
        print("new crop save", new_crop_save_path)
    return (state, state, click_state, image_input_nobackground, click_index, input_mask,
            input_points, input_labels, out, new_crop_save_path, image_input_nobackground)
# User-side chat text recorded for each focus type; submit_caption() indexes
# these with focus_map[focus_type] (0 Describe, 1 D+Analysis, 2 DA+Interprete,
# 3 Judge) and uses the Chinese list for any non-English language.
query_focus_en = [
    "Provide a description of the item.",
    "Provide a description and analysis of the item.",
    "Provide a description, analysis, and interpretation of the item.",
    "Evaluate the item.",
]

query_focus_zh = [
    "请描述一下这个物体。",
    "请描述和分析一下这个物体。",
    "请描述、分析和解释一下这个物体。",
    "请以艺术鉴赏的角度评价一下这个物体。",
]
async def submit_caption(naritive, state, length, sentiment, factuality, language,
                         out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
                         autoplay, paragraph, focus_type, openai_api_key, new_crop_save_path, gender, log_state, history):
    """Ask GPT about the selected crop for the chosen focus, and optionally voice it.

    The prompt is built by generate_prompt() from the focus type, narrative
    persona and the painting description (``paragraph``). GPT's answer is added
    to ``state``, ``log_state`` and ``history``. In the object persona
    (naritive_mapping value 2), GPT is also asked for the object's gender to pick
    the TTS voice.

    Returns a 10-tuple:
        (state, state, click_index_state, input_mask_state, input_points_state,
         input_labels_state, out_state, audio, log_state, history)

    ``audio`` is None when autoplay is off, TTS fails or GPT is disabled. The
    disabled-GPT branch previously raised NameError and returned 11 values.
    """
    focus_value = focus_map[focus_type]
    print("click_index", click_index_state)
    print("input_points_state", input_points_state)
    print("input_labels_state", input_labels_state)
    prompt = generate_prompt(focus_type, paragraph, length, sentiment, factuality, language, naritive)
    log_state = log_state + [("Selected image point: {}, Input label: {}".format(input_points_state, input_labels_state), None)]
    print("Prompt:", prompt)
    print("click", click_index_state)
    log_state = log_state + [(naritive, None)]

    def outputs(audio):
        # Reads the *current* state / log_state bindings at call time.
        return (state, state, click_index_state, input_mask_state, input_points_state,
                input_labels_state, out_state, audio, log_state, history)

    if args.disable_gpt:
        state = state + [(None, "GPT is disabled, so no caption was generated.")]
        print("submit_caption: GPT is disabled, skipping caption generation")
        return outputs(None)

    print("new crop save", new_crop_save_path)
    focus_info = get_gpt_response(openai_api_key, new_crop_save_path, prompt, history)
    if focus_info.startswith('"') and focus_info.endswith('"'):
        focus_info = focus_info[1:-1]
    focus_info = focus_info.replace('#', '')
    user_query = (query_focus_en if language == "English" else query_focus_zh)[focus_value]
    state = state + [(user_query, f"{focus_info}")]
    log_state = log_state + [(user_query, None)]
    log_state = log_state + [(None, f"{focus_info}")]
    history.append({"role": "user", "content": user_query})
    history.append({"role": "assistant", "content": focus_info})
    print("new_cap", focus_info)
    # Drop markdown punctuation and emoji so TTS only reads the words.
    read_info = re.sub(r'[#[\]!*]', '', focus_info)
    read_info = emoji.replace_emoji(read_info, replace="")
    print("read info", read_info)

    if naritive_mapping[naritive] == 2:
        raw_gender = get_gpt_response(openai_api_key, new_crop_save_path, prompt=f"Based on the information {focus_info}, return the gender of this item, returns its most likely gender, do not return unknown, in the format {{\"gender\": \"<gender>\"}}")
        try:
            gender = json.loads(raw_gender)['gender'].lower()
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
            # Keep the caller-supplied gender if GPT's reply is not the expected JSON.
            print("error gpt responese")
        print("item gender", gender)

    if not autoplay:
        return outputs(None)
    try:
        audio_output = await texttospeech(read_info, language, gender)
    except Exception as e:  # TTS is optional: report the failure in the chat instead of crashing
        state = state + [(None, f"Error during TTS prediction: {str(e)}")]
        print(f"Error during TTS prediction: {str(e)}")
        return outputs(None)
    print("done")
    return outputs(audio_output)
# Narrative persona chosen in the UI -> row index into prompt_list.
naritive_mapping = {
    persona: row
    for row, persona in enumerate(("Third-person", "Single-Persona: Artist", "Multi-Persona: Objects"))
}
def generate_prompt(focus_type, paragraph, length, sentiment, factuality, language, naritive):
    """Fill the prompt_list template for the given focus type and narrative persona.

    The row is naritive_mapping[naritive] (an unknown persona raises KeyError)
    and the column is focus_map[focus_type]. An unknown focus type yields the
    string "Invalid focus type.". ``factuality`` is accepted but unused.
    """
    focus_column = focus_map.get(focus_type, -1)
    persona_row = naritive_mapping[naritive]
    if focus_column == -1:
        return "Invalid focus type."
    template = prompt_list[persona_row][focus_column]
    return template.format(
        Wiki_caption=paragraph,
        length=length,
        sentiment=sentiment,
        language=language,
    )
def encode_image(image_path):
    """Return the raw bytes of the file at ``image_path`` as a base64 ASCII string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def get_gpt_response(api_key, image_path, prompt, history=None, timeout=120):
    """Send one chat-completion request to the OpenAI API and return the reply text.

    Args:
        api_key: OpenAI API key, sent as a Bearer token.
        image_path: falsy for a text-only request, a single image path, or a
            list of image paths. Each image is attached to the user message as
            its own ``image_url`` content part (base64 data URL).
        prompt: text of the user message.
        history: optional list of prior chat messages; only the last 4 are
            sent. The caller's list is not modified.
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        The assistant message content with a leading ```` ```json ```` and a
        trailing ```` ``` ```` removed, or a JSON string
        ``{"error": ..., "details": ...}`` if the request fails or the response
        does not have the expected shape.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    # Copy (and cap) the history so the caller's list is never mutated.
    messages = list(history[-4:]) if history else []

    if image_path:
        paths = image_path if isinstance(image_path, list) else [image_path]
        content = [{"type": "text", "text": prompt}]
        # One image_url part per image. Previously the whole Python list was
        # interpolated into a single URL, which corrupted multi-image requests.
        # NOTE(review): the MIME type is always image/jpeg, even for PNG inputs.
        for path in paths:
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encode_image(path)}"
                }
            })
        messages.append({"role": "user", "content": content})
    else:
        messages.append({"role": "user", "content": prompt})

    payload = {
        "model": "gpt-4o",
        "messages": messages,
        "max_tokens": 600
    }
    # Sending the request to the OpenAI API
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=timeout,
        )
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # Network failure, timeout, or a non-JSON body.
        return json.dumps({"error": "Request to OpenAI API failed", "details": str(e)})
    print("gpt result",result)

    try:
        content = result['choices'][0]['message']['content']
        if not isinstance(content, str):
            raise TypeError(f"unexpected content type: {type(content).__name__}")
    except (KeyError, IndexError, TypeError) as e:
        return json.dumps({"error": "Failed to parse model output", "details": str(e)})
    # The model sometimes wraps its answer in a ```json fenced block.
    if content.startswith("```json"):
        content = content[7:]
    if content.endswith("```"):
        content = content[:-3]
    return content
def get_sketch_prompt(mask: Image.Image):
    """
    Build a single box prompt that encloses every sketched pixel.

    Only channel 0 of ``mask`` is inspected; any non-zero value counts as
    sketched. The result has the form
    ``{'prompt_type': ['box'], 'input_boxes': [[x1, y1, x2, y2]]}`` with
    inclusive pixel coordinates. An empty sketch raises ``ValueError``
    (numpy cannot take the min/max of an empty array).

    TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
    """
    # NOTE(review): on an RGBA layer, a stroke whose red channel is 0 would not
    # be detected here -- confirm the brush colour or consider the alpha channel.
    first_channel = np.asarray(mask)[..., 0]
    rows, cols = np.where(first_channel != 0)
    box = [np.min(cols), np.min(rows), np.max(cols), np.max(rows)]
    return {'prompt_type': ['box'], 'input_boxes': [box]}
# NOTE(review): not referenced in the visible code (the trajectory handler is
# commented out); presumably a leftover flag -- confirm before removing.
submit_traj = 0
# async def inference_traject(naritive, origin_image,sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state, | |
# original_size, input_size, text_refiner,focus_type,paragraph,openai_api_key,autoplay,trace_type): | |
# image_input, mask = sketcher_image['background'], sketcher_image['layers'][0] | |
# crop_save_path="" | |
# prompt = get_sketch_prompt(mask) | |
# boxes = prompt['input_boxes'] | |
# boxes = boxes[0] | |
# controls = {'length': length, | |
# 'sentiment': sentiment, | |
# 'factuality': factuality, | |
# 'language': language} | |
# model = build_caption_anything_with_models( | |
# args, | |
# api_key="", | |
# captioner=shared_captioner, | |
# sam_model=shared_sam_model, | |
# ocr_reader=shared_ocr_reader, | |
# text_refiner=text_refiner, | |
# session_id=iface.app_id | |
# ) | |
# model.setup(image_embedding, original_size, input_size, is_image_set=True) | |
# enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False | |
# out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki,verbose=True)[0] | |
# print(trace_type) | |
# if trace_type=="Trace+Seg": | |
# input_mask = np.array(out['mask'].convert('P')) | |
# image_input = mask_painter(np.array(image_input), input_mask, background_alpha=0) | |
# d3_input=mask_painter(np.array(image_input), input_mask) | |
# crop_save_path=out['crop_save_path'] | |
# else: | |
# image_input = Image.fromarray(np.array(origin_image)) | |
# draw = ImageDraw.Draw(image_input) | |
# draw.rectangle(boxes, outline='red', width=2) | |
# d3_input=image_input | |
# cropped_image = origin_image.crop(boxes) | |
# cropped_image.save('temp.png') | |
# crop_save_path='temp.png' | |
# print("crop_svae_path",out['crop_save_path']) | |
# # Update components and states | |
# state.append((f'Box: {boxes}', None)) | |
# # fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2)) | |
# # image_input = create_bubble_frame(image_input, "", fake_click_index, input_mask) | |
# prompt=generate_prompt(focus_type, paragraph, length, sentiment, factuality, language,naritive) | |
# # if not args.disable_gpt and text_refiner: | |
# if not args.disable_gpt: | |
# focus_info=get_gpt_response(openai_api_key,crop_save_path,prompt) | |
# if focus_info.startswith('"') and focus_info.endswith('"'): | |
# focus_info=focus_info[1:-1] | |
# focus_info=focus_info.replace('#', '') | |
# state = state + [(None, f"{focus_info}")] | |
# print("new_cap",focus_info) | |
# read_info = re.sub(r'[#[\]!*]','',focus_info) | |
# read_info = emoji.replace_emoji(read_info,replace="") | |
# print("read info",read_info) | |
# # refined_image_input = create_bubble_frame(np.array(origin_image_input), focus_info, click_index, input_mask, | |
# # input_points=input_points, input_labels=input_labels) | |
# try: | |
# audio_output = await texttospeech(read_info, language,autoplay,gender) | |
# # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output | |
# return state, state,image_input,audio_output,crop_save_path,d3_input | |
# except Exception as e: | |
# state = state + [(None, f"Error during TTS prediction: {str(e)}")] | |
# print(f"Error during TTS prediction: {str(e)}") | |
# # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None | |
# return state, state, image_input,audio_output,crop_save_path | |
# else: | |
# try: | |
# audio_output = await texttospeech(focus_info, language, autoplay) | |
# # waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree) | |
# # return state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output | |
# return state, state, image_input,audio_output | |
# except Exception as e: | |
# state = state + [(None, f"Error during TTS prediction: {str(e)}")] | |
# print(f"Error during TTS prediction: {str(e)}") | |
# return state, state, image_input,audio_output | |
def clear_chat_memory(visual_chatgpt, keep_global=False):
    """Reset a chat bot's conversation memory.

    Does nothing when ``visual_chatgpt`` is None. Otherwise clears ``memory``
    and ``point_prompt``. With ``keep_global`` the current ``global_prompt`` is
    saved back into the agent's memory as context; without it the current
    image and the global prompt are discarded too.
    """
    if visual_chatgpt is None:
        return
    visual_chatgpt.memory.clear()
    visual_chatgpt.point_prompt = ""
    if not keep_global:
        visual_chatgpt.current_image = None
        visual_chatgpt.global_prompt = ""
        return
    # visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
    visual_chatgpt.agent.memory.save_context({"input": visual_chatgpt.global_prompt}, {"output": ""})
def export_chat_log(chat_state,log_list,narrative):
    """Write the chat transcript to a timestamped text file in the working directory.

    Each ``(user_message, bot_response)`` pair in ``chat_state`` is rendered as:
      * both present          -> ``User: ...`` / ``Bot: ...``
      * user text starting "%%" (marker entries such as
        ``"%% recommendation %%"``) -> copied verbatim
      * user text only        -> ``User: ...`` followed by a ``///// `` separator
      * bot text only         -> ``Bot: ...`` followed by a ``///// `` separator

    Args:
        chat_state: list of ``(user_message, bot_response)`` pairs.
        log_list: list of previously exported file paths; the new path is
            appended to it in place.
        narrative: label appended to the file name.

    Returns:
        ``(log_list, log_list)`` on success, or ``(None, None)`` when there is
        nothing to export or an error occurs. Every path returns two values
        (the empty-chat path used to return a bare ``None``).
    """
    try:
        if not chat_state:
            return None, None
        parts = []
        for user_message, bot_response in chat_state:
            if user_message and bot_response:
                parts.append(f"User: {user_message}\nBot: {bot_response}\n")
            elif user_message and user_message.startswith("%%"):
                parts.append(f"{user_message}\n")
            elif user_message:
                parts.append(f"User: {user_message}\n")
                parts.append("///// \n")
            elif bot_response:
                parts.append(f"Bot: {bot_response}\n")
                parts.append("///// \n")
        chat_log = "".join(parts)
        print("export log...")
        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = f"{current_time}_{narrative}.txt"
        file_path = os.path.join(os.getcwd(), file_name)  # Save to the current working directory
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(chat_log)
        print(file_path)
        log_list.append(file_path)
        return log_list,log_list
    except Exception as e:
        # Best-effort export: report the failure instead of crashing the UI.
        print(f"An error occurred while exporting the chat log: {e}")
        return None,None
async def get_artistinfo(artist_name,api_key,state,language,autoplay,length,log_state):
    """Ask GPT for a short artist background and append it to the chat.

    Args:
        artist_name: artist to summarise.
        api_key: OpenAI API key forwarded to ``get_gpt_response``.
        state: chat history list; a new list with the reply appended is returned.
        language: language of the summary and of the synthesised speech.
        autoplay: when truthy, the reply is also converted to speech.
        length: approximate word count requested from the model.
        log_state: interaction log list; a new list with the reply appended is
            returned.

    Returns:
        ``(state, state, audio_html_or_None, log_state)``.
    """
    prompt = f"Provide a concise summary of about {length} words in {language} on the painter {artist_name}, covering his biography, major works, artistic style, significant contributions to the art world, and any major awards or recognitions he has received. Start your response with 'Artist Background: '."
    res = get_gpt_response(api_key, None, prompt)
    state = state + [(None, res)]
    # Remove '#', '[', ']', '!', '*' and emoji so the TTS voice does not read them.
    read_info = re.sub(r'[#[\]!*]', '', res)
    read_info = emoji.replace_emoji(read_info, replace="")
    # Log the actual reply (previously the literal string "res" was logged).
    log_state = log_state + [(None, res)]
    if autoplay:
        audio_output = await texttospeech(read_info, language)
        return state, state, audio_output, log_state
    return state, state, None, log_state
async def get_yearinfo(year,api_key,state,language,autoplay,length,log_state):
    """Ask GPT for the art-historical context of a year and append it to the chat.

    Args:
        year: year whose art-historical period is summarised.
        api_key: OpenAI API key forwarded to ``get_gpt_response``.
        state: chat history list; a new list with the reply appended is returned.
        language: language of the summary and of the synthesised speech.
        autoplay: when truthy, the reply is also converted to speech.
        length: approximate word count requested from the model.
        log_state: interaction log list; a new list with the reply appended is
            returned.

    Returns:
        ``(state, state, audio_html_or_None, log_state)``.
    """
    # The instruction now matches the artist prompt ("Start your response
    # with ..."); it previously read "... art history with 'History Background: '."
    prompt = f"Provide a concise summary of about {length} words in {language} on the art historical period associated with the year {year}, covering its major characteristics, influential artists, notable works, and its significance in the broader context of art history. Start your response with 'History Background: '."
    res = get_gpt_response(api_key, None, prompt)
    # Log the actual reply (previously the literal string "res" was logged).
    log_state = log_state + [(None, res)]
    state = state + [(None, res)]
    # Remove '#', '[', ']', '!', '*' and emoji so the TTS voice does not read them.
    read_info = re.sub(r'[#[\]!*]', '', res)
    read_info = emoji.replace_emoji(read_info, replace="")
    if autoplay:
        audio_output = await texttospeech(read_info, language)
        return state, state, audio_output, log_state
    return state, state, None, log_state
# async def cap_everything(paragraph, visual_chatgpt,language,autoplay): | |
# # state = state + [(None, f"Caption Everything: {paragraph}")] | |
# Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n' | |
# AI_prompt = "Received." | |
# visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt | |
# # visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt | |
# visual_chatgpt.agent.memory.save_context({"input": Human_prompt}, {"output": AI_prompt}) | |
# # waveform_visual, audio_output=tts.predict(paragraph, input_language, input_audio, input_mic, use_mic, agree) | |
# audio_output=await texttospeech(paragraph,language,autoplay) | |
# return paragraph,audio_output | |
# def cap_everything_withoutsound(image_input, visual_chatgpt, text_refiner,paragraph): | |
# model = build_caption_anything_with_models( | |
# args, | |
# api_key="", | |
# captioner=shared_captioner, | |
# sam_model=shared_sam_model, | |
# ocr_reader=shared_ocr_reader, | |
# text_refiner=text_refiner, | |
# session_id=iface.app_id | |
# ) | |
# paragraph = model.inference_cap_everything(image_input, verbose=True) | |
# # state = state + [(None, f"Caption Everything: {paragraph}")] | |
# Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n' | |
# AI_prompt = "Received." | |
# visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt | |
# visual_chatgpt.agent.memory.save_context({"input": Human_prompt}, {"output": AI_prompt}) | |
# # visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt | |
# return paragraph | |
# def handle_liked(state,like_res): | |
# if state: | |
# like_res.append(state[-1][1]) | |
# print(f"Last response recorded: {state[-1][1]}") | |
# else: | |
# print("No response to record.") | |
# state = state + [(None, f"Liked Received 👍")] | |
# return state,like_res | |
# def handle_disliked(state,dislike_res): | |
# if state: | |
# dislike_res.append(state[-1][1]) | |
# print(f"Last response recorded: {state[-1][1]}") | |
# else: | |
# print("No response to record.") | |
# state = state + [(None, f"Disliked Received 🥹")] | |
# return state,dislike_res | |
# def get_style(): | |
# current_version = version.parse(gr.__version__) | |
# print(current_version) | |
# if current_version <= version.parse('3.24.1'): | |
# style = ''' | |
# #image_sketcher{min-height:500px} | |
# #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px} | |
# #image_upload{min-height:500px} | |
# #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px} | |
# .custom-language { | |
# width: 20%; | |
# } | |
# .custom-autoplay { | |
# width: 40%; | |
# } | |
# .custom-output { | |
# width: 30%; | |
# } | |
# ''' | |
# elif current_version <= version.parse('3.27'): | |
# style = ''' | |
# #image_sketcher{min-height:500px} | |
# #image_upload{min-height:500px} | |
# .custom-language { | |
# width: 20%; | |
# } | |
# .custom-autoplay { | |
# width: 40%; | |
# } | |
# .custom-output { | |
# width: 30%; | |
# } | |
# .custom-gallery { | |
# display: flex; | |
# flex-wrap: wrap; | |
# justify-content: space-between; | |
# } | |
# .custom-gallery img { | |
# width: 48%; | |
# margin-bottom: 10px; | |
# } | |
# ''' | |
# else: | |
# style = None | |
# return style | |
# def handle_like_dislike(like_data, like_state, dislike_state): | |
# if like_data.liked: | |
# if like_data.index not in like_state: | |
# like_state.append(like_data.index) | |
# message = f"Liked: {like_data.value} at index {like_data.index}" | |
# else: | |
# message = "You already liked this item" | |
# else: | |
# if like_data.index not in dislike_state: | |
# dislike_state.append(like_data.index) | |
# message = f"Disliked: {like_data.value} at index {like_data.index}" | |
# else: | |
# message = "You already disliked this item" | |
# return like_state, dislike_state | |
async def texttospeech(text, language,gender='female'):
    """Synthesise ``text`` with edge-tts and return an inline HTML audio player.

    The voice is looked up as ``filtered_language_dict[language][gender]``.
    The audio is embedded in the returned ``<audio>`` tag as a base64 data URI.

    Returns:
        The ``<audio ...>`` HTML string, or ``None`` if anything fails (the
        error is printed). Failures are best-effort: they never raise.
    """
    import tempfile

    tmp_path = None
    try:
        voice = filtered_language_dict[language][gender]
        communicate = edge_tts.Communicate(text=text, voice=voice,rate="+25%")
        # Use a unique file per call so concurrent sessions cannot overwrite
        # each other's audio (previously everything shared "output.wav").
        fd, tmp_path = tempfile.mkstemp(suffix=".mp3")
        os.close(fd)
        await communicate.save(tmp_path)
        with open(tmp_path, "rb") as audio_file:
            audio = base64.b64encode(audio_file.read()).decode("utf-8")
        print("TTS processing completed.")
        audio_style = 'style="width:210px;"'
        # edge-tts emits MP3 by default, so label the data URI accordingly.
        audio_player = f'<audio src="data:audio/mpeg;base64,{audio}" controls autoplay {audio_style}></audio>'
        return audio_player
    except Exception as e:
        print(f"Error in texttospeech: {e}")
        return None
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
# give the reason of recommendation | |
async def item_associate(image_path,new_crop,openai_api_key,language,autoplay,length,log_state,sort_score,narritive,state,evt: gr.SelectData):
    """Explain a picture selected in the item-recommendation gallery.

    The selected picture and the user's current crop (or the whole artwork
    when there is no crop) are both sent to GPT with the persona's
    recommendation prompt.

    Args:
        image_path: path of the whole artwork image.
        new_crop: path of the current crop, if any (preferred over image_path).
        openai_api_key: OpenAI API key forwarded to ``get_gpt_response``.
        language, length: substituted into the recommendation prompt.
        autoplay: when truthy, the reply is also converted to speech.
        log_state: interaction log list (a new list is returned).
        sort_score: user's ranking of the recommended pictures, logged verbatim.
        narritive: persona label looked up in ``naritive_mapping``.
        state: chat history list (a new list is returned).
        evt: gallery selection event; ``evt.value['image']['path']`` is the
            clicked picture.

    Returns:
        ``(state, state, audio_html_or_None, log_state, index_label,
        gr.update(value=[]))``. The last item clears the bound component's value.
    """
    persona = naritive_mapping[narritive]
    # Public accessor for the clicked gallery item (instead of private evt._data).
    rec_path = evt.value['image']['path']
    index = "Item Recommendation Picture " + str(evt.index)
    print("rec_path", rec_path)
    prompt = recommendation_prompt[persona].format(language=language, length=length)
    image_paths = [new_crop or image_path, rec_path]
    result = get_gpt_response(openai_api_key, image_paths, prompt)
    print("recommend result", result)
    # Rebind rather than `+=` so the caller's list is not mutated in place.
    state = state + [(None, f"{result}")]
    log_state = log_state + [
        (narritive, None),
        (f"image sort ranking {sort_score}", None),
        (None, f"{result}"),
    ]
    # Remove '#', '[', ']', '!', '*' and emoji before text-to-speech.
    read_info = re.sub(r'[#[\]!*]', '', result)
    read_info = emoji.replace_emoji(read_info, replace="")
    print("associate", read_info)
    audio_output = None
    if autoplay:
        audio_output = await texttospeech(read_info, language)
    return state, state, audio_output, log_state, index, gr.update(value=[])
async def style_associate(image_path,new_crop,openai_api_key,language,autoplay,length,log_state,sort_score,narritive,state,evt: gr.SelectData):
    """Explain a picture selected in the style-recommendation gallery.

    The selected picture and the user's current crop (or the whole artwork
    when there is no crop) are both sent to GPT with the persona's
    recommendation prompt.

    Args:
        image_path: path of the whole artwork image.
        new_crop: path of the current crop, if any (preferred over image_path).
        openai_api_key: OpenAI API key forwarded to ``get_gpt_response``.
        language, length: substituted into the recommendation prompt.
        autoplay: when truthy, the reply is also converted to speech.
        log_state: interaction log list (a new list is returned).
        sort_score: user's ranking of the recommended pictures, logged verbatim.
        narritive: persona label looked up in ``naritive_mapping``.
        state: chat history list (a new list is returned).
        evt: gallery selection event; ``evt.value['image']['path']`` is the
            clicked picture.

    Returns:
        ``(state, state, audio_html_or_None, log_state, index_label,
        gr.update(value=[]))``. The last item clears the bound component's value.
    """
    persona = naritive_mapping[narritive]
    # Public accessor for the clicked gallery item (instead of private evt._data).
    rec_path = evt.value['image']['path']
    index = "Style Recommendation Picture " + str(evt.index)
    print("rec_path", rec_path)
    prompt = recommendation_prompt[persona].format(language=language, length=length)
    image_paths = [new_crop or image_path, rec_path]
    result = get_gpt_response(openai_api_key, image_paths, prompt)
    print("recommend result", result)
    # Rebind rather than `+=` so the caller's list is not mutated in place.
    state = state + [(None, f"{result}")]
    log_state = log_state + [
        (narritive, None),
        (f"image sort ranking {sort_score}", None),
        (None, f"{result}"),
    ]
    # Remove '#', '[', ']', '!', '*' and emoji before text-to-speech.
    read_info = re.sub(r'[#[\]!*]', '', result)
    read_info = emoji.replace_emoji(read_info, replace="")
    print("associate", read_info)
    audio_output = None
    if autoplay:
        audio_output = await texttospeech(read_info, language)
    return state, state, audio_output, log_state, index, gr.update(value=[])
def change_naritive(session_type,image_input, state, click_state, paragraph, origin_image,narritive,task_instruct,gallery_output,reco_reasons,language="English"):
    """Reset the UI for "Session 1", or greet the user in the chosen persona.

    For "Session 1" a fixed tuple of cleared values is returned. Otherwise a
    greeting matching ``(language, narritive)`` is appended to ``state`` in
    place (nothing is appended when there is no match), and the inputs are
    passed through with ``state`` and ``reco_reasons`` each repeated twice.
    """
    if session_type == "Session 1":
        return None, [], [], [[], [], []], "", None, None, [], [], []

    greetings = (
        ("English", "Third-person", "🤖 Hi, I am EyeSee. Let's explore this painting together."),
        ("English", "Single-Persona: Artist", "🧑🎨 Let's delve into it from the perspective of the artist."),
        ("English", "Multi-Persona: Objects", "🎨 Let's delve into it from the perspective of the objects depicted in the scene."),
        ("Chinese", "Third-person", "🤖 让我们从第三方视角一起探索这幅画吧。"),
        ("Chinese", "Single-Persona: Artist", "🧑🎨 让我们从艺术家的视角深入探索这幅画。"),
        ("Chinese", "Multi-Persona: Objects", "🎨 让我们从画面中事物的视角深入探索这幅画。"),
    )
    for greet_language, greet_persona, greeting in greetings:
        if language == greet_language and narritive == greet_persona:
            # In-place extend: the caller's list object receives the greeting.
            state += [(None, greeting)]
            break
    return image_input, state, state, click_state, paragraph, origin_image, task_instruct, gallery_output, reco_reasons, reco_reasons
def print_like_dislike(x: gr.LikeData,state,log_state):
    """Record like/dislike feedback on a chat message.

    Returns ``(log_state, state)`` as new lists. The log gains the feedback
    entry plus a ``"%% user interaction %%"`` marker, and the chat gains an
    acknowledgement message. The input lists are not mutated.
    """
    print(x.index, x.value, x.liked)
    # Explicit equality (not truthiness): only values equal to True count as a like.
    if x.liked == True:  # noqa: E712
        print("liked")
        feedback, reply = "User liked this message", "Liked Received 👍"
    else:
        feedback, reply = "User disliked this message", "Disliked Received 👎"
    updated_log = log_state + [(feedback, None), ("%% user interaction %%", None)]
    updated_state = state + [(None, reply)]
    return updated_log, updated_state
def get_recommendationscore(index,score,log_state):
    """Append ``"<index> : <score>"`` and a ``"%% recommendation %%"`` marker to ``log_state``.

    The list is extended in place and the same object is returned.
    """
    log_state += [
        (f"{index} : {score}", None),
        ("%% recommendation %%", None),
    ]
    return log_state
def toggle_icons_and_update_prompt(point_prompt):
    """Flip the point-prompt mode and restyle the add/remove buttons to match.

    "Positive" becomes "Negative"; any other value becomes "Positive". The
    button of the newly active mode gets the blue icon and the
    "tools_button_clicked" class.

    Returns:
        ``(new_prompt, add_button_update, minus_button_update)``.
    """
    if point_prompt == "Positive":
        new_prompt = "Negative"
        add_icon, add_css = "assets/icons/plus-square.png", "tools_button"
        minus_icon, minus_css = "assets/icons/minus-square-blue.png", "tools_button_clicked"
    else:
        new_prompt = "Positive"
        add_icon, add_css = "assets/icons/plus-square-blue.png", "tools_button_clicked"
        minus_icon, minus_css = "assets/icons/minus-square.png", "tools_button"
    add_update = gr.update(icon=add_icon, elem_classes=add_css)
    minus_update = gr.update(icon=minus_icon, elem_classes=minus_css)
    return new_prompt, add_update, minus_update
# Initial toolbar icons: the point prompt starts in "Positive" mode, so the
# "add" button starts highlighted (blue) and the "minus" button plain.
add_icon_path = "assets/icons/plus-square-blue.png"
minus_icon_path = "assets/icons/minus-square.png"
print("this is a print test")
def create_ui(): | |
print(6) | |
title = """<p><h1 align="center">EyeSee Anything in Art</h1></p> | |
""" | |
description = """<p>Gradio demo for EyeSee Anything in Art, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. """ | |
examples = [ | |
["test_images/1.The Ambassadors.jpg","test_images/task1.jpg","task 1"], | |
["test_images/2.Football Players.jpg","test_images/task2.jpg","task 2"], | |
["test_images/3-square.jpg","test_images/task3.jpg","task 3"], | |
# ["test_images/test4.jpg"], | |
# ["test_images/test5.jpg"], | |
# ["test_images/Picture5.png"], | |
] | |
with gr.Blocks( | |
css=css, | |
theme=gr.themes.Base() | |
) as iface: | |
#display in the chatbox | |
state = gr.State([]) | |
# expoer in log | |
log_state=gr.State([]) | |
# history log for gpt | |
history_log=gr.State([]) | |
out_state = gr.State(None) | |
click_state = gr.State([[], [], []]) | |
origin_image = gr.State(None) | |
image_embedding = gr.State(None) | |
text_refiner = gr.State(None) | |
visual_chatgpt = gr.State(None) | |
original_size = gr.State(None) | |
input_size = gr.State(None) | |
paragraph = gr.State("") | |
aux_state = gr.State([]) | |
click_index_state = gr.State((0, 0)) | |
input_mask_state = gr.State(np.zeros((1, 1))) | |
input_points_state = gr.State([]) | |
input_labels_state = gr.State([]) | |
#store the selected image | |
new_crop_save_path = gr.State(None) | |
image_input_nobackground = gr.State(None) | |
artist=gr.State(None) | |
gr.Markdown(title) | |
gr.Markdown(description) | |
point_prompt = gr.State("Positive") | |
log_list=gr.State([]) | |
gender=gr.State('female') | |
# store the whole image path | |
image_path=gr.State('') | |
pic_index=gr.State(None) | |
recomended_state=gr.State([]) | |
with gr.Row(): | |
auto_play = gr.Checkbox( | |
label="Check to autoplay audio", value=False, elem_classes="custom-autoplay" | |
) | |
output_audio = gr.HTML( | |
label="Synthesised Audio", elem_classes="custom-output" | |
) | |
with gr.Row(): | |
with gr.Column(scale=1,min_width=50,visible=False) as instruct: | |
task_instuction=gr.Image(type="pil", interactive=True, elem_classes="task_instruct",height=650,label=None) | |
with gr.Column(scale=6): | |
with gr.Column(visible=False) as modules_not_need_gpt: | |
with gr.Tab("Base(GPT Power)",visible=False) as base_tab: | |
image_input_base = gr.Image(type="pil", interactive=True, elem_classes="image_upload",height=650) | |
with gr.Row(): | |
name_label_base = gr.Button(value="Name: ",elem_classes="info_btn") | |
artist_label_base = gr.Button(value="Artist: ",elem_classes="info_btn_interact") | |
year_label_base = gr.Button(value="Year: ",elem_classes="info_btn_interact") | |
material_label_base = gr.Button(value="Style: ",elem_classes="info_btn") | |
with gr.Tab("Base2",visible=False) as base_tab2: | |
image_input_base_2 = gr.Image(type="pil", interactive=True, elem_classes="image_upload",height=650) | |
with gr.Row(): | |
name_label_base2 = gr.Button(value="Name: ",elem_classes="info_btn") | |
artist_label_base2 = gr.Button(value="Artist: ",elem_classes="info_btn_interact") | |
year_label_base2 = gr.Button(value="Year: ",elem_classes="info_btn_interact") | |
material_label_base2 = gr.Button(value="Style: ",elem_classes="info_btn") | |
with gr.Tab("Click") as click_tab: | |
with gr.Row(): | |
with gr.Column(scale=10,min_width=600): | |
image_input = gr.Image(type="pil", interactive=True, elem_classes="image_upload",height=650) | |
example_image = gr.Image(type="pil", interactive=False, visible=False) | |
with gr.Row(): | |
name_label = gr.Button(value="Name: ",elem_classes="info_btn") | |
artist_label = gr.Button(value="Artist: ",elem_classes="info_btn_interact") | |
year_label = gr.Button(value="Year: ",elem_classes="info_btn_interact") | |
material_label = gr.Button(value="Style: ",elem_classes="info_btn") | |
# example_image_click = gr.Image(type="pil", interactive=False, visible=False) | |
# the tool column | |
with gr.Column(scale=1,elem_id="tool_box",min_width=80): | |
add_button = gr.Button(value="Extend Area", interactive=True,elem_classes="tools_button_add",icon=add_icon_path) | |
minus_button = gr.Button(value="Remove Area", interactive=True,elem_classes="tools_button",icon=minus_icon_path) | |
clear_button_click = gr.Button(value="Reset", interactive=True,elem_classes="tools_button") | |
focus_d = gr.Button(value="Describe",interactive=True,elem_classes="function_button",variant="primary") | |
focus_da = gr.Button(value="D+Analysis",interactive=True,elem_classes="function_button",variant="primary") | |
focus_dai = gr.Button(value="DA+Interprete",interactive=True,elem_classes="function_button",variant="primary") | |
focus_dda = gr.Button(value="Judge",interactive=True,elem_classes="function_button",variant="primary") | |
recommend_btn = gr.Button(value="Recommend",interactive=True,elem_classes="function_button_rec") | |
# focus_asso = gr.Button(value="Associate",interactive=True,elem_classes="function_button",variant="primary") | |
with gr.Row(visible=False): | |
with gr.Column(): | |
with gr.Row(): | |
# point_prompt = gr.Radio( | |
# choices=["Positive", "Negative"], | |
# value="Positive", | |
# label="Point Prompt", | |
# scale=5, | |
# interactive=True) | |
click_mode = gr.Radio( | |
choices=["Continuous", "Single"], | |
value="Continuous", | |
label="Clicking Mode", | |
scale=5, | |
interactive=True) | |
with gr.Tab("Trajectory (beta)", visible=False) as traj_tab: | |
# sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=10, | |
# elem_id="image_sketcher") | |
sketcher_input = gr.ImageEditor(type="pil", interactive=True | |
) | |
with gr.Row(): | |
name_label_traj = gr.Button(value="Name: ") | |
artist_label_traj = gr.Button(value="Artist: ") | |
year_label_traj = gr.Button(value="Year: ") | |
material_label_traj = gr.Button(value="Material: ") | |
# example_image_traj = gr.Image(type="pil", interactive=False, visible=False) | |
with gr.Row(): | |
clear_button_sketcher = gr.Button(value="Clear Sketch", interactive=True) | |
submit_button_sketcher = gr.Button(value="Submit", interactive=True) | |
with gr.Column(visible=False,scale=4) as modules_need_gpt1: | |
with gr.Row(visible=False): | |
sentiment = gr.Radio( | |
choices=["Positive", "Natural", "Negative"], | |
value="Natural", | |
label="Sentiment", | |
interactive=True, | |
) | |
factuality = gr.Radio( | |
choices=["Factual", "Imagination"], | |
value="Factual", | |
label="Factuality", | |
interactive=True, | |
) | |
# length = gr.Slider( | |
# minimum=10, | |
# maximum=80, | |
# value=10, | |
# step=1, | |
# interactive=True, | |
# label="Generated Caption Length", | |
# ) | |
# 是否启用wiki内容整合到caption中 | |
enable_wiki = gr.Radio( | |
choices=["Yes", "No"], | |
value="No", | |
label="Expert", | |
interactive=True) | |
with gr.Column(visible=True) as modules_not_need_gpt3: | |
gr.Examples( | |
examples=examples, | |
inputs=[example_image], | |
) | |
with gr.Column(scale=4): | |
with gr.Column(visible=True) as module_key_input: | |
openai_api_key = gr.Textbox( | |
value="sk-proj-bxHhgjZV8TVgd1IupZrUT3BlbkFJvrthq6zIxpZVk3vwsvJ9", | |
placeholder="Input openAI API key", | |
show_label=False, | |
label="OpenAI API Key", | |
lines=1, | |
type="password") | |
with gr.Row(): | |
enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary') | |
# disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True, | |
# variant='primary') | |
with gr.Column(visible=False) as module_notification_box: | |
notification_box = gr.Textbox(lines=1, label="Notification", max_lines=5, show_label=False) | |
# with gr.Column() as modules_need_gpt0: | |
# with gr.Column(visible=False) as modules_need_gpt2: | |
# paragraph_output = gr.Textbox(lines=16, label="Describe Everything", max_lines=16) | |
# cap_everything_button = gr.Button(value="Caption Everything in a Paragraph", interactive=True) | |
with gr.Column(visible=False) as modules_not_need_gpt2: | |
with gr.Row(): | |
naritive = gr.Radio( | |
choices=["Third-person", "Single-Persona: Artist","Multi-Persona: Objects"], | |
value="Third-person", | |
label="Persona", | |
scale=5, | |
interactive=True) | |
with gr.Blocks(): | |
chatbot = gr.Chatbot(label="Chatbox", elem_classes="chatbot",likeable=True,height=600,bubble_full_width=False) | |
with gr.Column() as modules_need_gpt3: | |
chat_input = gr.MultimodalTextbox(interactive=True, file_types=[".txt"], placeholder="Message EyeSee...", show_label=False) | |
with gr.Row(): | |
clear_button_text = gr.Button(value="Clear Chat", interactive=True) | |
export_button = gr.Button(value="Export Chat Log", interactive=True, variant="primary") | |
# submit_button_text = gr.Button(value="Send", interactive=True, variant="primary") | |
# upvote_btn = gr.Button(value="👍 Upvote", interactive=True) | |
# downvote_btn = gr.Button(value="👎 Downvote", interactive=True) | |
# TTS interface hidden initially | |
with gr.Column(visible=False) as tts_interface: | |
input_text = gr.Textbox(label="Text Prompt", value="Hello, World !, here is an example of light voice cloning. Try to upload your best audio samples quality") | |
input_language = gr.Dropdown(label="Language", choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"], value="en") | |
input_audio = gr.Audio(label="Reference Audio", type="filepath", value="examples/female.wav") | |
input_mic = gr.Audio(sources="microphone", type="filepath", label="Use Microphone for Reference") | |
use_mic = gr.Checkbox(label="Check to use Microphone as Reference", value=False) | |
agree = gr.Checkbox(label="Agree", value=True) | |
output_waveform = gr.Video(label="Waveform Visual") | |
# output_audio = gr.HTML(label="Synthesised Audio") | |
with gr.Row(): | |
submit_tts = gr.Button(value="Submit", interactive=True) | |
clear_tts = gr.Button(value="Clear", interactive=True) | |
with gr.Row(): | |
with gr.Column(scale=6): | |
with gr.Column(visible=False) as recommend: | |
sort_rec=gr.Dropdown(["1", "2", "3", "4"], | |
value=[], | |
multiselect=True, | |
label="Score", info="Please sort the pictures according to your preference" | |
) | |
gallery_result = gr.Gallery( | |
label="Recommendations Based on Item", | |
height="auto", | |
columns=2 | |
# columns=4, | |
# rows=2, | |
# show_label=False, | |
# allow_preview=True, | |
# object_fit="contain", | |
# height="auto", | |
# preview=True, | |
# show_share_button=True, | |
# show_download_button=True | |
) | |
style_gallery_result = gr.Gallery( | |
label="Recommendations Based on Style", | |
height="auto", | |
columns=2 | |
# columns=4, | |
# rows=2, | |
# show_label=False, | |
# allow_preview=True, | |
# object_fit="contain", | |
# height="auto", | |
# preview=True, | |
# show_share_button=True, | |
# show_download_button=True | |
) | |
with gr.Column(scale=4,visible=False) as reco_reasons: | |
recommend_bot = gr.Chatbot(label="Recommend Reasons", elem_classes="chatbot",height=600) | |
recommend_score = gr.Radio( | |
choices=[1,2,3,4,5,6,7], | |
label="Score", | |
interactive=True) | |
with gr.Row(): | |
task_type = gr.Textbox(visible=False) | |
gr.Examples( | |
examples=examples, | |
inputs=[example_image,task_instuction,task_type], | |
) | |
############################################################################### | |
############# this part is for text to image ############# | |
############################################################################### | |
with gr.Row(variant="panel",visible=False) as text2image_model: | |
with gr.Column(): | |
with gr.Column(): | |
gr.Radio(["Other Paintings by the Artist"], label="Artist", info="Who is the artist?🧑🎨"), | |
gr.Radio(["Oil Painting","Printmaking","Watercolor Painting","Drawing"], label="Art Forms", info="What are the art forms?🎨"), | |
gr.Radio(["Renaissance", "Baroque", "Impressionism","Modernism"], label="Period", info="Which art period?⏳"), | |
# to be done | |
gr.Dropdown( | |
["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Items", info="Which items are you interested in?" | |
) | |
with gr.Row(): | |
prompt = gr.Text( | |
label="Prompt", | |
show_label=False, | |
max_lines=1, | |
placeholder="Enter your prompt", | |
container=False, | |
) | |
run_button = gr.Button("Run") | |
with gr.Accordion("Advanced options", open=False): | |
num_images = gr.Slider( | |
label="Number of Images", | |
minimum=1, | |
maximum=4, | |
step=1, | |
value=4, | |
) | |
with gr.Row(): | |
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True) | |
negative_prompt = gr.Text( | |
label="Negative prompt", | |
max_lines=5, | |
lines=4, | |
placeholder="Enter a negative prompt", | |
value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW", | |
visible=True, | |
) | |
seed = gr.Slider( | |
label="Seed", | |
minimum=0, | |
maximum=MAX_SEED, | |
step=1, | |
value=0, | |
) | |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True) | |
with gr.Row(): | |
width = gr.Slider( | |
label="Width", | |
minimum=100, | |
maximum=MAX_IMAGE_SIZE, | |
step=64, | |
value=1024, | |
) | |
height = gr.Slider( | |
label="Height", | |
minimum=100, | |
maximum=MAX_IMAGE_SIZE, | |
step=64, | |
value=1024, | |
) | |
with gr.Row(): | |
guidance_scale = gr.Slider( | |
label="Guidance Scale", | |
minimum=0.1, | |
maximum=6, | |
step=0.1, | |
value=3.0, | |
) | |
num_inference_steps = gr.Slider( | |
label="Number of inference steps", | |
minimum=1, | |
maximum=15, | |
step=1, | |
value=8, | |
) | |
# with gr.Column(): | |
# result = gr.Gallery( | |
# label="Result", | |
# height="auto", | |
# columns=4 | |
# # columns=4, | |
# # rows=2, | |
# # show_label=False, | |
# # allow_preview=True, | |
# # object_fit="contain", | |
# # height="auto", | |
# # preview=True, | |
# # show_share_button=True, | |
# # show_download_button=True | |
# ) | |
with gr.Row(visible=False) as export: | |
chat_log_file = gr.File(label="Download Chat Log",scale=5) | |
with gr.Row(elem_id="top_row",visible=False) as top_row: | |
session_type = gr.Dropdown( | |
["Session 1","Session 2"], | |
value="Session 1", label="Task", interactive=True, elem_classes="custom-language" | |
) | |
language = gr.Dropdown( | |
['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], | |
value="English", label="Language", interactive=True, elem_classes="custom-language" | |
) | |
length = gr.Slider( | |
minimum=60, | |
maximum=120, | |
value=80, | |
step=1, | |
interactive=True, | |
label="Generated Caption Length", | |
) | |
# auto_play = gr.Checkbox( | |
# label="Check to autoplay audio", value=False, elem_classes="custom-autoplay" | |
# ) | |
# output_audio = gr.HTML( | |
# label="Synthesised Audio", elem_classes="custom-output" | |
# ) | |
# gr.Examples( | |
# examples=examples, | |
# inputs=prompt, | |
# cache_examples=False | |
# ) | |
use_negative_prompt.change( | |
fn=lambda x: gr.update(visible=x), | |
inputs=use_negative_prompt, | |
outputs=negative_prompt, | |
api_name=False, | |
) | |
# gr.on( | |
# triggers=[ | |
# prompt.submit, | |
# negative_prompt.submit, | |
# run_button.click, | |
# ], | |
# fn=generate, | |
# inputs=[ | |
# prompt, | |
# negative_prompt, | |
# use_negative_prompt, | |
# seed, | |
# width, | |
# height, | |
# guidance_scale, | |
# num_inference_steps, | |
# randomize_seed, | |
# num_images | |
# ], | |
# outputs=[result, seed], | |
# api_name="run", | |
# ) | |
recommend_btn.click( | |
fn=infer, | |
inputs=[new_crop_save_path,image_path,state,language,task_type], | |
outputs=[gallery_result,style_gallery_result,chatbot,state] | |
) | |
gallery_result.select( | |
item_associate, | |
inputs=[image_path,new_crop_save_path,openai_api_key,language,auto_play,length,log_state,sort_rec,naritive,recomended_state], | |
outputs=[recommend_bot,recomended_state,output_audio,log_state,pic_index,recommend_score], | |
) | |
style_gallery_result.select( | |
style_associate, | |
inputs=[image_path,new_crop_save_path,openai_api_key,language,auto_play,length,log_state,sort_rec,naritive,recomended_state], | |
outputs=[recommend_bot,recomended_state,output_audio,log_state,pic_index,recommend_score], | |
) | |
############################################################################### | |
############# above part is for text to image ############# | |
############################################################################### | |
############################################################################### | |
# This part is for 3D model generation; its UI and handlers are currently commented out.
############################################################################### | |
# with gr.Row(variant="panel",visible=False) as d3_model: | |
# with gr.Column(): | |
# with gr.Row(): | |
# input_image = gr.Image( | |
# label="Input Image", | |
# image_mode="RGBA", | |
# sources="upload", | |
# #width=256, | |
# #height=256, | |
# type="pil", | |
# elem_id="content_image", | |
# ) | |
# processed_image = gr.Image( | |
# label="Processed Image", | |
# image_mode="RGBA", | |
# #width=256, | |
# #height=256, | |
# type="pil", | |
# interactive=False | |
# ) | |
# with gr.Row(): | |
# with gr.Group(): | |
# do_remove_background = gr.Checkbox( | |
# label="Remove Background", value=True | |
# ) | |
# sample_seed = gr.Number(value=42, label="Seed Value", precision=0) | |
# sample_steps = gr.Slider( | |
# label="Sample Steps", | |
# minimum=30, | |
# maximum=75, | |
# value=75, | |
# step=5 | |
# ) | |
# with gr.Row(): | |
# submit = gr.Button("Generate", elem_id="generate", variant="primary") | |
# with gr.Row(variant="panel"): | |
# gr.Examples( | |
# examples=[ | |
# os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples")) | |
# ], | |
# inputs=[input_image], | |
# label="Examples", | |
# cache_examples=False, | |
# examples_per_page=16 | |
# ) | |
# with gr.Column(): | |
# with gr.Row(): | |
# with gr.Column(): | |
# mv_show_images = gr.Image( | |
# label="Generated Multi-views", | |
# type="pil", | |
# width=379, | |
# interactive=False | |
# ) | |
# # with gr.Column(): | |
# # output_video = gr.Video( | |
# # label="video", format="mp4", | |
# # width=379, | |
# # autoplay=True, | |
# # interactive=False | |
# # ) | |
# with gr.Row(): | |
# with gr.Tab("OBJ"): | |
# output_model_obj = gr.Model3D( | |
# label="Output Model (OBJ Format)", | |
# interactive=False, | |
# ) | |
# gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.") | |
# with gr.Tab("GLB"): | |
# output_model_glb = gr.Model3D( | |
# label="Output Model (GLB Format)", | |
# interactive=False, | |
# ) | |
# gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.") | |
# mv_images = gr.State() | |
chatbot.like(print_like_dislike, inputs=[state,log_state], outputs=[log_state,chatbot]) | |
# submit.click(fn=check_input_image, inputs=[new_crop_save_path], outputs=[processed_image]).success( | |
# fn=generate_mvs, | |
# inputs=[processed_image, sample_steps, sample_seed], | |
# outputs=[mv_images, mv_show_images] | |
# ).success( | |
# fn=make3d, | |
# inputs=[mv_images], | |
# outputs=[output_model_obj, output_model_glb] | |
# ) | |
############################################################################### | |
# Above part is for 3D model generation.
############################################################################### | |
def clear_tts_fields():
    """Return reset values for every widget in the TTS panel.

    The list lines up positionally with the ``outputs`` of ``clear_tts.click``:
    ``[input_text, input_language, input_audio, input_mic, use_mic, agree,
    output_waveform, output_audio]``.

    - ``input_text`` and ``input_language`` are blanked to ``""``.
    - ``input_audio`` and ``input_mic`` are cleared to ``None``.
    - ``use_mic`` is unticked and ``agree`` is re-ticked.
    - ``output_waveform`` and ``output_audio`` are cleared to ``None``.
    """
    blank_text = gr.update(value="")
    blank_language = gr.update(value="")
    mic_unchecked = gr.update(value=False)
    agree_checked = gr.update(value=True)
    cleared_audio, cleared_mic = None, None
    cleared_waveform, cleared_output_audio = None, None
    return [
        blank_text,
        blank_language,
        cleared_audio,
        cleared_mic,
        mic_unchecked,
        agree_checked,
        cleared_waveform,
        cleared_output_audio,
    ]
# submit_tts.click( | |
# tts.predict, | |
# inputs=[input_text, input_language, input_audio, input_mic, use_mic, agree], | |
# outputs=[output_waveform, output_audio], | |
# queue=True | |
# ) | |
clear_tts.click( | |
clear_tts_fields, | |
inputs=None, | |
outputs=[input_text, input_language, input_audio, input_mic, use_mic, agree, output_waveform, output_audio], | |
queue=False | |
) | |
# clear_button_sketcher.click( | |
# lambda x: (x), | |
# [origin_image], | |
# [sketcher_input], | |
# queue=False, | |
# show_progress=False | |
# ) | |
recommend_score.select( | |
get_recommendationscore, | |
inputs=[pic_index,recommend_score,log_state], | |
outputs=[log_state], | |
) | |
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], | |
outputs=[export, modules_need_gpt1, modules_need_gpt3, modules_not_need_gpt, | |
modules_not_need_gpt2, tts_interface, module_key_input ,module_notification_box, text_refiner, visual_chatgpt, notification_box,top_row,recommend,reco_reasons,instruct,modules_not_need_gpt3]) | |
enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], | |
outputs=[export,modules_need_gpt1, modules_need_gpt3, | |
modules_not_need_gpt, | |
modules_not_need_gpt2, tts_interface,module_key_input,module_notification_box, text_refiner, visual_chatgpt, notification_box,top_row,recommend,reco_reasons,instruct,modules_not_need_gpt3]) | |
# disable_chatGPT_button.click(init_wo_openai_api_key, | |
# outputs=[export,modules_need_gpt1, modules_need_gpt3, | |
# modules_not_need_gpt, | |
# modules_not_need_gpt2, tts_interface,module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box,top_row]) | |
# artist_label_base2.click( | |
# get_artistinfo, | |
# inputs=[artist_label_base2,openai_api_key,state,language,auto_play,length], | |
# outputs=[chatbot,state,output_audio] | |
# ) | |
artist_label.click( | |
get_artistinfo, | |
inputs=[artist_label,openai_api_key,state,language,auto_play,length,log_state], | |
outputs=[chatbot,state,output_audio,log_state] | |
) | |
# artist_label_traj.click( | |
# get_artistinfo, | |
# inputs=[artist_label_traj,openai_api_key,state,language,auto_play,length], | |
# outputs=[chatbot,state,output_audio] | |
# ) | |
# year_label_base2.click( | |
# get_yearinfo, | |
# inputs=[year_label_base2,openai_api_key,state,language,auto_play,length], | |
# outputs=[chatbot,state,output_audio] | |
# ) | |
year_label.click( | |
get_yearinfo, | |
inputs=[year_label,openai_api_key,state,language,auto_play,length,log_state], | |
outputs=[chatbot,state,output_audio,log_state] | |
) | |
# year_label_traj.click( | |
# get_yearinfo, | |
# inputs=[year_label_traj,openai_api_key,state,language,auto_play,length], | |
# outputs=[chatbot,state,output_audio] | |
# ) | |
# enable_chatGPT_button.click( | |
# lambda: (None, [], [], [[], [], []], "", "", ""), | |
# [], | |
# [image_input, chatbot, state, click_state, paragraph_output, origin_image], | |
# queue=False, | |
# show_progress=False | |
# ) | |
# openai_api_key.submit( | |
# lambda: (None, [], [], [[], [], []], "", "", ""), | |
# [], | |
# [image_input, chatbot, state, click_state, paragraph_output, origin_image], | |
# queue=False, | |
# show_progress=False | |
# ) | |
# cap_everything_button.click(cap_everything, [paragraph, visual_chatgpt, language,auto_play], | |
# [paragraph_output,output_audio]) | |
def reset_and_add(origin_image):
    """Discard all clicked points and put the point tool back into "add" mode.

    Args:
        origin_image: the original image. It is written back into
            ``image_input`` so the displayed image is restored.

    Returns:
        A 5-tuple matching ``[click_state, image_input, point_prompt,
        add_button, minus_button]``:
        - a click state of three empty lists;
        - ``origin_image``;
        - the ``"Positive"`` prompt;
        - an update styling the plus button as selected;
        - an update styling the minus button as unselected.
    """
    empty_click_state = [[] for _ in range(3)]
    plus_selected = gr.update(
        icon="assets/icons/plus-square-blue.png",
        elem_classes="tools_button_clicked",
    )
    minus_unselected = gr.update(
        icon="assets/icons/minus-square.png",
        elem_classes="tools_button",
    )
    return empty_click_state, origin_image, "Positive", plus_selected, minus_unselected
clear_button_click.click( | |
reset_and_add, | |
[origin_image], | |
[click_state, image_input,point_prompt,add_button,minus_button], | |
queue=False, | |
show_progress=False | |
) | |
clear_button_click.click(functools.partial(clear_chat_memory, keep_global=True), inputs=[visual_chatgpt]) | |
# clear_button_image.click( | |
# lambda: (None, [], [], [[], [], []], "", "", ""), | |
# [], | |
# [image_input, chatbot, state, click_state, paragraph, origin_image], | |
# queue=False, | |
# show_progress=False | |
# ) | |
# clear_button_image.click(clear_chat_memory, inputs=[visual_chatgpt]) | |
clear_button_text.click( | |
lambda: ([], [], [[], [], [], []],[]), | |
[], | |
[chatbot, state, click_state,history_log], | |
queue=False, | |
show_progress=False | |
) | |
clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt]) | |
image_input.clear( | |
lambda: (None, [], [], [[], [], []], "", None, []), | |
[], | |
[image_input, chatbot, state, click_state, paragraph, origin_image,history_log], | |
queue=False, | |
show_progress=False | |
) | |
image_input.clear(clear_chat_memory, inputs=[visual_chatgpt]) | |
# image_input.change( | |
# lambda: ([], [], [[], [], []], [], []), | |
# [], | |
# [chatbot, state, click_state, history_log, log_state], | |
# queue=False, | |
# show_progress=False | |
# ) | |
# image_input_base.upload(upload_callback, [image_input_base, state, visual_chatgpt,openai_api_key,language,naritive], | |
# [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2, | |
# image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \ | |
# name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \ | |
# paragraph,artist,gender,image_path]) | |
# image_input_base_2.upload(upload_callback, [image_input_base_2, state, visual_chatgpt,openai_api_key,language,naritive], | |
# [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2, | |
# image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \ | |
# name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \ | |
# paragraph,artist,gender,image_path]) | |
image_input.upload(upload_callback, [image_input, state, log_state,visual_chatgpt,openai_api_key,language,naritive,history_log,auto_play,session_type], | |
[chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2, | |
image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \ | |
name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \ | |
paragraph,artist,gender,image_path,log_state,history_log,output_audio]) | |
# sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt,openai_api_key], | |
# [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2, | |
# image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \ | |
# name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \ | |
# paragraph,artist]) | |
# image_input.upload(upload_callback, [image_input, state, visual_chatgpt, openai_api_key], | |
# [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input, | |
# image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist]) | |
# sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt, openai_api_key], | |
# [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input, | |
# image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base,paragraph,artist]) | |
chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play,gender,openai_api_key,image_path,log_state,history_log], | |
[chatbot, state, aux_state,output_audio,log_state,history_log]) | |
# chat_input.submit(lambda: "", None, chat_input) | |
chat_input.submit(lambda: {"text": ""}, None, chat_input) | |
# submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state,language,auto_play], | |
# [chatbot, state, aux_state,output_audio]) | |
# submit_button_text.click(lambda: "", None, chat_input) | |
example_image.change(upload_callback, [example_image, state, log_state, visual_chatgpt, openai_api_key,language,naritive,history_log,auto_play,session_type], | |
[chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,image_input_base_2, | |
image_embedding, original_size, input_size,name_label,artist_label,year_label,material_label,name_label_base, artist_label_base, year_label_base, material_label_base, \ | |
name_label_base2, artist_label_base2, year_label_base2, material_label_base2,name_label_traj, artist_label_traj, year_label_traj, material_label_traj, \ | |
paragraph,artist,gender,image_path, log_state,history_log,output_audio]) | |
example_image.change(clear_chat_memory, inputs=[visual_chatgpt]) | |
# example_image.change( | |
# lambda:([],[]), | |
# [], | |
# [gallery_result,recommend_bot]) | |
# def on_click_tab_selected(): | |
# if gpt_state ==1: | |
# print(gpt_state) | |
# print("using gpt") | |
# return [gr.update(visible=True)]*2+[gr.update(visible=False)]*2 | |
# else: | |
# print("no gpt") | |
# print("gpt_state",gpt_state) | |
# return [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2 | |
# def on_base_selected(): | |
# if gpt_state ==1: | |
# print(gpt_state) | |
# print("using gpt") | |
# return [gr.update(visible=True)]*2+[gr.update(visible=False)]*2 | |
# else: | |
# print("no gpt") | |
# return [gr.update(visible=False)]*4 | |
# traj_tab.select(on_click_tab_selected, outputs=[modules_need_gpt1,modules_not_need_gpt2,modules_need_gpt0,modules_need_gpt2]) | |
# click_tab.select(on_click_tab_selected, outputs=[modules_need_gpt1,modules_not_need_gpt2,modules_need_gpt0,modules_need_gpt2]) | |
# base_tab.select(on_base_selected, outputs=[modules_need_gpt0,modules_need_gpt2,modules_not_need_gpt2,modules_need_gpt1]) | |
# base_tab2.select(on_base_selected, outputs=[modules_not_need_gpt2,modules_not_need_gpt2,modules_need_gpt0,modules_need_gpt1]) | |
image_input.select( | |
inference_click, | |
inputs=[ | |
origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length, | |
image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt, | |
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state | |
], | |
outputs=[chatbot, state, click_state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground], | |
show_progress=False, queue=True | |
) | |
focus_d.click( | |
submit_caption, | |
inputs=[ | |
naritive, state,length, sentiment, factuality, language, | |
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, auto_play, paragraph,focus_d,openai_api_key,new_crop_save_path,gender,log_state,history_log | |
], | |
outputs=[ | |
chatbot, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,output_audio,log_state,history_log | |
], | |
show_progress=True, | |
queue=True | |
) | |
focus_da.click( | |
submit_caption, | |
inputs=[ | |
naritive,state,length, sentiment, factuality, language, | |
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,auto_play, paragraph,focus_da,openai_api_key,new_crop_save_path,gender,log_state, | |
history_log | |
], | |
outputs=[ | |
chatbot, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,output_audio,log_state,history_log | |
], | |
show_progress=True, | |
queue=True | |
) | |
focus_dai.click( | |
submit_caption, | |
inputs=[ | |
naritive,state,length, sentiment, factuality, language, | |
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, | |
auto_play, paragraph,focus_dai,openai_api_key,new_crop_save_path,gender,log_state,history_log | |
], | |
outputs=[ | |
chatbot, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,output_audio,log_state,history_log | |
], | |
show_progress=True, | |
queue=True | |
) | |
focus_dda.click( | |
submit_caption, | |
inputs=[ | |
naritive,state,length, sentiment, factuality, language, | |
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, | |
auto_play, paragraph,focus_dda,openai_api_key,new_crop_save_path,gender,log_state,history_log | |
], | |
outputs=[ | |
chatbot, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,output_audio,log_state,history_log | |
], | |
show_progress=True, | |
queue=True | |
) | |
add_button.click( | |
toggle_icons_and_update_prompt, | |
inputs=[point_prompt], | |
outputs=[point_prompt,add_button,minus_button], | |
show_progress=True, | |
queue=True | |
) | |
minus_button.click( | |
toggle_icons_and_update_prompt, | |
inputs=[point_prompt], | |
outputs=[point_prompt,add_button,minus_button], | |
show_progress=True, | |
queue=True | |
) | |
# submit_button_sketcher.click( | |
# inference_traject, | |
# inputs=[ | |
# origin_image,sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state, | |
# original_size, input_size, text_refiner,focus_type_sketch,paragraph,openai_api_key,auto_play,Input_sketch | |
# ], | |
# outputs=[chatbot, state, sketcher_input,output_audio,new_crop_save_path], | |
# show_progress=False, queue=True | |
# ) | |
export_button.click( | |
export_chat_log, | |
inputs=[log_state,log_list,naritive], | |
outputs=[chat_log_file,log_list], | |
queue=True | |
) | |
naritive.change( | |
change_naritive, | |
[session_type, image_input, state, click_state, paragraph, origin_image,naritive, | |
task_instuction,gallery_result,recomended_state,language], | |
[image_input, chatbot, state, click_state, paragraph, origin_image,task_instuction,gallery_result,recomended_state,recommend_bot], | |
queue=False, | |
show_progress=False | |
) | |
def session_change():
    """Reset the workspace whenever the session dropdown changes.

    The following are all cleared:
    - the input image, chat, state and click points;
    - the paragraph and the origin image;
    - both logs.

    The task-4 instruction image is then loaded from
    ``test_images/task4.jpg``, a path relative to the working directory.
    The returned tuple lines up with the outputs of ``session_type.change``.
    """
    task_picture = Image.open('test_images/task4.jpg')
    cleared_clicks = [[], [], []]
    return (
        None,            # image_input
        [],              # chatbot
        [],              # state
        cleared_clicks,  # click_state
        "",              # paragraph
        None,            # origin_image
        [],              # history_log
        [],              # log_state
        task_picture,    # task_instuction
        "task 4",        # task_type
    )
session_type.change( | |
session_change, | |
[], | |
[image_input, chatbot, state, click_state, paragraph, origin_image,history_log,log_state,task_instuction,task_type] | |
) | |
# upvote_btn.click( | |
# handle_liked, | |
# inputs=[state,like_res], | |
# outputs=[chatbot,like_res] | |
# ) | |
# downvote_btn.click( | |
# handle_disliked, | |
# inputs=[state,dislike_res], | |
# outputs=[chatbot,dislike_res] | |
# ) | |
return iface | |
if __name__ == '__main__':
    print("main")
    iface = create_ui()
    # Block direct REST API access that would bypass the queue,
    # and hold at most 10 pending requests.
    iface.queue(max_size=10, api_open=False)
    # Bind to all network interfaces so other hosts can reach the app.
    iface.launch(server_name="0.0.0.0")