from io import BytesIO
from math import inf
import os
import base64
import json
import gradio as gr
import numpy as np
from gradio import processing_utils
import requests
from packaging import version
from PIL import Image, ImageDraw
import functools
import emoji
from langchain.llms.openai import OpenAI
from caption_anything.model import CaptionAnything
from caption_anything.utils.image_editing_utils import create_bubble_frame
from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize
from caption_anything.utils.parser import parse_augment
from caption_anything.captioner import build_captioner
from caption_anything.text_refiner import build_text_refiner
from caption_anything.segmenter import build_segmenter
from chatbox import ConversationBot, build_chatbot_tools, get_new_image_name
from segment_anything import sam_model_registry
import easyocr
import re
import edge_tts
# import tts
###############################################################################
############# this part is for 3D generate #############
###############################################################################
# import spaces #
# import uuid
# from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
# from diffusers.utils import export_to_video
# from safetensors.torch import load_file
#from diffusers.models.modeling_outputs import Transformer2DModelOutput
import random
import uuid
from diffusers import StableDiffusionXLPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
import imageio
import torch
import rembg
from torchvision.transforms import v2
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from einops import rearrange, repeat
from tqdm import tqdm
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
FOV_to_intrinsics,
get_zero123plus_input_cameras,
get_circular_camera_poses,
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground, images_to_video
import tempfile
from functools import partial
from huggingface_hub import hf_hub_download
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
"""
Get the rendering camera parameters.
"""
c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
if is_flexicubes:
cameras = torch.linalg.inv(c2ws)
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
else:
extrinsics = c2ws.flatten(-2)
intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
cameras = torch.cat([extrinsics, intrinsics], dim=-1)
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
return cameras
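# Minimal usage sketch (mirrors the call in make3d() further below): with
# is_flexicubes=True the cameras are inverted camera-to-world matrices; otherwise
# each pose is its flattened extrinsics concatenated with its flattened intrinsics.
#   render_cams = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES)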
def images_to_video(images, output_path, fps=30):
# images: (N, C, H, W)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
frames = []
for i in range(images.shape[0]):
        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
f"Frame shape mismatch: {frame.shape} vs {images.shape}"
assert frame.min() >= 0 and frame.max() <= 255, \
f"Frame value out of range: {frame.min()} ~ {frame.max()}"
frames.append(frame)
imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
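# Note: this local helper shadows the images_to_video imported from src.utils.infer_util
# above; only the (currently commented-out) video branch of make3d() would call it.
# Hedged usage sketch, assuming `frames` is a float tensor in [0, 1] of shape (N, C, H, W):
#   images_to_video(frames, "outputs/turntable.mp4", fps=30)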
###############################################################################
# Configuration.
###############################################################################
import shutil
def find_cuda():
# Check if CUDA_HOME or CUDA_PATH environment variables are set
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
if cuda_home and os.path.exists(cuda_home):
return cuda_home
# Search for the nvcc executable in the system's PATH
nvcc_path = shutil.which('nvcc')
if nvcc_path:
# Remove the 'bin/nvcc' part to get the CUDA installation path
cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
return cuda_path
return None
cuda_path = find_cuda()
if cuda_path:
print(f"CUDA installation found at: {cuda_path}")
else:
print("CUDA installation not found")
config_path = 'configs/instant-nerf-base.yaml'
config = OmegaConf.load(config_path)
config_name = os.path.basename(config_path).replace('.yaml', '')
model_config = config.model_config
infer_config = config.infer_config
IS_FLEXICUBES = config_name.startswith('instant-mesh')
device = torch.device('cuda')
# load diffusion model
print('Loading diffusion model ...')
pipeline = DiffusionPipeline.from_pretrained(
"sudo-ai/zero123plus-v1.2",
custom_pipeline="zero123plus",
torch_dtype=torch.float16,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config, timestep_spacing='trailing'
)
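# 'trailing' timestep spacing follows the upstream InstantMesh / Zero123++ demo configuration.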
# load custom white-background UNet
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True)
pipeline = pipeline.to(device)
# load reconstruction model
print('Loading reconstruction model ...')
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_nerf_base.ckpt", repo_type="model")
model0 = instantiate_from_config(model_config)
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
model0.load_state_dict(state_dict, strict=True)
model0 = model0.to(device)
print('Loading Finished!')
def check_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image uploaded!")
    return Image.open(input_image)
def preprocess(input_image, do_remove_background):
rembg_session = rembg.new_session() if do_remove_background else None
if do_remove_background:
input_image = remove_background(input_image, rembg_session)
input_image = resize_foreground(input_image, 0.85)
return input_image
# @spaces.GPU
def generate_mvs(input_image, sample_steps, sample_seed):
seed_everything(sample_seed)
# sampling
z123_image = pipeline(
input_image,
num_inference_steps=sample_steps
).images[0]
show_image = np.asarray(z123_image, dtype=np.uint8)
show_image = torch.from_numpy(show_image) # (960, 640, 3)
show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
show_image = Image.fromarray(show_image.numpy())
return z123_image, show_image
# @spaces.GPU
def make3d(images):
global model0
if IS_FLEXICUBES:
model0.init_flexicubes_geometry(device)
model0 = model0.eval()
images = np.asarray(images, dtype=np.float32) / 255.0
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
images = images.unsqueeze(0).to(device)
images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
print(mesh_fpath)
mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
mesh_dirname = os.path.dirname(mesh_fpath)
video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
with torch.no_grad():
# get triplane
planes = model0.forward_planes(images, input_cameras)
# # get video
# chunk_size = 20 if IS_FLEXICUBES else 1
# render_size = 384
# frames = []
# for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
# if IS_FLEXICUBES:
# frame = model.forward_geometry(
# planes,
# render_cameras[:, i:i+chunk_size],
# render_size=render_size,
# )['img']
# else:
# frame = model.synthesizer(
# planes,
# cameras=render_cameras[:, i:i+chunk_size],
# render_size=render_size,
# )['images_rgb']
# frames.append(frame)
# frames = torch.cat(frames, dim=1)
# images_to_video(
# frames[0],
# video_fpath,
# fps=30,
# )
# print(f"Video saved to {video_fpath}")
# get mesh
mesh_out = model0.extract_mesh(
planes,
use_texture_map=False,
**infer_config,
)
vertices, faces, vertex_colors = mesh_out
vertices = vertices[:, [1, 2, 0]]
save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
save_obj(vertices, faces, vertex_colors, mesh_fpath)
print(f"Mesh saved to {mesh_fpath}")
return mesh_fpath, mesh_glb_fpath
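# Hedged end-to-end sketch of the 3D path defined above (preprocess -> Zero123++
# multi-view synthesis -> InstantMesh reconstruction). It is not wired into the UI
# and is never called; the image path argument is purely illustrative.
def _example_image_to_mesh(image_path, remove_bg=True, steps=75, seed=42):
    img = Image.open(image_path).convert("RGBA")
    img = preprocess(img, do_remove_background=remove_bg)
    z123_image, _mv_preview = generate_mvs(img, sample_steps=steps, sample_seed=seed)
    # make3d() returns paths to the saved .obj and .glb meshes
    return make3d(z123_image)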
###############################################################################
############# above part is for 3D generate #############
###############################################################################
###############################################################################
############# this part is for text to image #############
###############################################################################
# Use environment variables for flexibility
MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # Allow generating multiple images at once
# Determine device and load model outside of function for efficiency
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = StableDiffusionXLPipeline.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
use_safetensors=True,
add_watermarker=False,
).to(device)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
# Torch compile for potential speedup (experimental)
if USE_TORCH_COMPILE:
pipe.compile()
# CPU offloading for larger RAM capacity (experimental)
if ENABLE_CPU_OFFLOAD:
pipe.enable_model_cpu_offload()
MAX_SEED = np.iinfo(np.int32).max
def save_image(img):
unique_name = str(uuid.uuid4()) + ".png"
img.save(unique_name)
return unique_name
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
# @spaces.GPU(duration=30, queue=False)
def generate(
prompt: str,
negative_prompt: str = "",
use_negative_prompt: bool = False,
seed: int = 1,
width: int = 200,
height: int = 200,
guidance_scale: float = 3,
num_inference_steps: int = 30,
randomize_seed: bool = False,
num_images: int = 4, # Number of images to generate
use_resolution_binning: bool = True,
progress=gr.Progress(track_tqdm=True),
):
seed = int(randomize_seed_fn(seed, randomize_seed))
generator = torch.Generator(device=device).manual_seed(seed)
# Improved options handling
options = {
"prompt": [prompt] * num_images,
"negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
"width": width,
"height": height,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"generator": generator,
"output_type": "pil",
}
# Use resolution binning for faster generation with less VRAM usage
# if use_resolution_binning:
# options["use_resolution_binning"] = True
# Generate images potentially in batches
images = []
for i in range(0, num_images, BATCH_SIZE):
batch_options = options.copy()
batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
if "negative_prompt" in batch_options:
batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
images.extend(pipe(**batch_options).images,f"label {i}")
image_paths = [save_image(img) for img in images]
return image_paths, seed
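# Hedged sketch of calling generate() directly, outside the Gradio UI. The prompt
# below is illustrative; generate() saves each image to disk via save_image() and
# returns the list of file paths together with the seed that was actually used.
def _example_text_to_image(prompt_text="a watercolor painting of a lighthouse"):
    paths, used_seed = generate(
        prompt=prompt_text,
        use_negative_prompt=False,
        randomize_seed=True,
        num_images=1,
        width=1024,
        height=1024,
        num_inference_steps=8,
    )
    return paths, used_seed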
examples = [
"a cat eating a piece of cheese",
"a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
"Ironman VS Hulk, ultrarealistic",
"Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
"An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
"Kids going to school, Anime style"
]
###############################################################################
############# above part is for text to image #############
###############################################################################
css = """
#warning {background-color: #FFCCCB}
.chatbot {
padding: 0 !important;
margin: 0 !important;
}
"""
filtered_language_dict = {
'English': 'en-US-JennyNeural',
'Chinese': 'zh-CN-XiaoxiaoNeural',
'French': 'fr-FR-DeniseNeural',
'Spanish': 'es-MX-DaliaNeural',
'Arabic': 'ar-SA-ZariyahNeural',
'Portuguese': 'pt-BR-FranciscaNeural',
'Cantonese': 'zh-HK-HiuGaaiNeural'
}
focus_map = {
"CFV-D":0,
"CFV-DA":1,
"CFV-DAI":2,
"PFV-DDA":3
}
'''
prompt_list = [
'Wiki_caption: {Wiki_caption}, you have to generate a caption according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
'Wiki_caption: {Wiki_caption}, you have to select sentences from wiki caption that describe the surrounding objects that may be associated with the picture object. Around {length} words of {sentiment} sentiment in {language}.',
'Wiki_caption: {Wiki_caption}. You have to choose sentences from the wiki caption that describe unrelated objects to the image. Around {length} words of {sentiment} sentiment in {language}.',
'Wiki_caption: {Wiki_caption}. You have to choose sentences from the wiki caption that describe unrelated objects to the image. Around {length} words of {sentiment} sentiment in {language}.'
]
prompt_list = [
'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact (describes the object but does not include analysis)as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and list one fact and one analysis and one interpret as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.',
'Wiki_caption: {Wiki_caption}, you have to help me understand what is about the selected object and the objects that may be related to the selected object and list one fact of selected object, one fact of related object and one analysis as markdown outline with appropriate emojis that describes what you see according to the image and wiki caption. Around {length} words of {sentiment} sentiment in {language}.'
]
'''
prompt_list = [
    'Wiki_caption: {Wiki_caption}. Help me understand the selected object: list one fact (a description of the selected object, with no analysis) as a markdown outline with appropriate emojis, based on the image and the wiki caption. Each point listed must be in {language}, with a response length of about {length} words.',
    'Wiki_caption: {Wiki_caption}. Help me understand the selected object: list one fact and one analysis as a markdown outline with appropriate emojis, based on the image and the wiki caption. Each point listed must be in {language}, with a response length of about {length} words.',
    'Wiki_caption: {Wiki_caption}. Help me understand the selected object: list one fact, one analysis, and one interpretation as a markdown outline with appropriate emojis, based on the image and the wiki caption. Each point listed must be in {language}, with a response length of about {length} words.',
    'Wiki_caption: {Wiki_caption}. Help me understand the selected object and the objects that may be related to it: list one fact about the selected object, one fact about a related object, and one analysis as a markdown outline with appropriate emojis, based on the image and the wiki caption. Each point listed must be in {language}, with a response length of about {length} words.'
]
gpt_state = 0
VOICE = "en-GB-SoniaNeural"
article = """
<div style='margin:20px auto;'>
<p>By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml</p>
</div>
"""
args = parse_augment()
args.segmenter = "huge"
args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
args.clip_filter = True
if args.segmenter_checkpoint is None:
_, segmenter_checkpoint = prepare_segmenter(args.segmenter)
else:
segmenter_checkpoint = args.segmenter_checkpoint
shared_captioner = build_captioner(args.captioner, args.device, args)
shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
ocr_lang = ["ch_tra", "en"]
shared_ocr_reader = easyocr.Reader(ocr_lang)
tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
shared_chatbot_tools = build_chatbot_tools(tools_dict)
class ImageSketcher(gr.Image):
"""
Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
"""
is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
def __init__(self, **kwargs):
super().__init__(tool="sketch", **kwargs)
def preprocess(self, x):
if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
assert isinstance(x, dict)
if x['mask'] is None:
decode_image = processing_utils.decode_base64_to_image(x['image'])
width, height = decode_image.size
mask = np.zeros((height, width, 4), dtype=np.uint8)
mask[..., -1] = 255
mask = self.postprocess(mask)
x['mask'] = mask
return super().preprocess(x)
def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
session_id=None):
segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
captioner = captioner
if session_id is not None:
print('Init caption anything for session {}'.format(session_id))
return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
def validate_api_key(api_key):
api_key = str(api_key).strip()
print(api_key)
try:
test_llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
response = test_llm("Test API call")
print(response)
return True
except Exception as e:
print(f"API key validation failed: {e}")
return False
def init_openai_api_key(api_key=""):
text_refiner = None
visual_chatgpt = None
if api_key and len(api_key) > 30:
print(api_key)
if validate_api_key(api_key):
try:
text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
assert len(text_refiner.llm('hi')) > 0 # test
visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
except Exception as e:
print(f"Error initializing TextRefiner or ConversationBot: {e}")
text_refiner = None
visual_chatgpt = None
else:
print("Invalid API key.")
else:
print("API key is too short.")
print(text_refiner)
openai_available = text_refiner is not None
if openai_available:
global gpt_state
gpt_state=1
# return [gr.update(visible=True)]+[gr.update(visible=False)]+[gr.update(visible=True)]*3+[gr.update(visible=False)]+ [gr.update(visible=False)]*3 + [text_refiner, visual_chatgpt, None]+[gr.update(visible=True)]*3
return [gr.update(visible=True)]+[gr.update(visible=False)]+[gr.update(visible=True)]*3+[gr.update(visible=False)]+ [gr.update(visible=False)]*3 + [text_refiner, visual_chatgpt, None]+[gr.update(visible=True)]*2
else:
gpt_state=0
# return [gr.update(visible=False)]*7 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']+[gr.update(visible=False)]*3
return [gr.update(visible=False)]*7 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']+[gr.update(visible=False)]*2
def init_wo_openai_api_key():
global gpt_state
gpt_state=0
# return [gr.update(visible=False)]*4 + [gr.update(visible=True)]+ [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2 + [None, None, None]+[gr.update(visible=False)]*3
return [gr.update(visible=False)]*4 + [gr.update(visible=True)]+ [gr.update(visible=False)]+[gr.update(visible=True)]+[gr.update(visible=False)]*2 + [None, None, None]+[gr.update(visible=False)]*2
def get_click_prompt(chat_input, click_state, click_mode):
inputs = json.loads(chat_input)
if click_mode == 'Continuous':
points = click_state[0]
labels = click_state[1]
for input in inputs:
points.append(input[:2])
labels.append(input[2])
elif click_mode == 'Single':
points = []
labels = []
for input in inputs:
points.append(input[:2])
labels.append(input[2])
click_state[0] = points
click_state[1] = labels
else:
raise NotImplementedError
prompt = {
"prompt_type": ["click"],
"input_point": click_state[0],
"input_label": click_state[1],
"multimask_output": "True",
}
return prompt
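# Example of the prompt produced above for a single positive click at (x=120, y=240)
# in 'Single' mode (coordinates illustrative):
#   {"prompt_type": ["click"], "input_point": [[120, 240]], "input_label": [1],
#    "multimask_output": "True"}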
def update_click_state(click_state, caption, click_mode):
if click_mode == 'Continuous':
click_state[2].append(caption)
elif click_mode == 'Single':
click_state[2] = [caption]
else:
raise NotImplementedError
async def chat_input_callback(*args):
visual_chatgpt, chat_input, click_state, state, aux_state ,language , autoplay = args
if visual_chatgpt is not None:
state, _, aux_state, _ = visual_chatgpt.run_text(chat_input, state, aux_state)
last_text, last_response = state[-1]
print("last response",last_response)
if autoplay:
audio = await texttospeech(last_response,language,autoplay)
else:
audio=None
return state, state, aux_state, audio
else:
response = "Text refiner is not initilzed, please input openai api key."
state = state + [(chat_input, response)]
audio = await texttospeech(response,language,autoplay)
return state, state, None, audio
def upload_callback(image_input, state, visual_chatgpt=None, openai_api_key=None,language="English"):
if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
image_input, mask = image_input['image'], image_input['mask']
click_state = [[], [], []]
image_input = image_resize(image_input, res=1024)
model = build_caption_anything_with_models(
args,
api_key="",
captioner=shared_captioner,
sam_model=shared_sam_model,
ocr_reader=shared_ocr_reader,
session_id=iface.app_id
)
model.segmenter.set_image(image_input)
image_embedding = model.image_embedding
original_size = model.original_size
input_size = model.input_size
if visual_chatgpt is not None:
print('upload_callback: add caption to chatGPT memory')
new_image_path = get_new_image_name('chat_image', func_name='upload')
image_input.save(new_image_path)
visual_chatgpt.current_image = new_image_path
img_caption = model.captioner.inference(image_input, filter=False, args={'text_prompt':''})['caption']
Human_prompt = f'\nHuman: The description of the image with path {new_image_path} is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
AI_prompt = "Received."
visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
parsed_data = get_image_gpt(openai_api_key, new_image_path,"Please provide the name, artist, year of creation, and material used for this painting. Return the information in dictionary format without any newline characters. If any information is unavailable, return \"None\" for that field. Format as follows: { \"name\": \"Name of the painting\",\"artist\": \"Name of the artist\", \"year\": \"Year of creation\", \"material\": \"Material used in the painting\" }.")
parsed_data = json.loads(parsed_data.replace("'", "\""))
name, artist, year, material= parsed_data["name"],parsed_data["artist"],parsed_data["year"], parsed_data["material"]
# artwork_info = f"<div>Painting: {name}<br>Artist name: {artist}<br>Year: {year}<br>Material: {material}</div>"
paragraph = get_image_gpt(openai_api_key, new_image_path,f"What's going on in this picture? in {language}")
state = [
(
None,
f"🤖 Hi, I am EyeSee. Let's explore this painting {name} together. You can click on the area you're interested in and choose from four types of information: Description, Analysis, Interpretation, and Judgment. Based on your selection, I will provide you with the relevant information."
)
]
return state, state, image_input, click_state, image_input, image_input, image_input, image_embedding, \
original_size, input_size, f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Material: {material}",f"Name: {name}", f"Artist: {artist}", f"Year: {year}", f"Material: {material}",paragraph,artist
def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state, evt: gr.SelectData):
click_index = evt.index
if point_prompt == 'Positive':
coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
else:
coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
prompt = get_click_prompt(coordinate, click_state, click_mode)
input_points = prompt['input_point']
input_labels = prompt['input_label']
controls = {'length': length,
'sentiment': sentiment,
'factuality': factuality,
'language': language}
model = build_caption_anything_with_models(
args,
api_key="",
captioner=shared_captioner,
sam_model=shared_sam_model,
ocr_reader=shared_ocr_reader,
text_refiner=text_refiner,
session_id=iface.app_id
)
model.setup(image_embedding, original_size, input_size, is_image_set=True)
enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
text = out['generated_captions']['raw_caption']
input_mask = np.array(out['mask'].convert('P'))
image_input_nobackground = mask_painter(np.array(image_input), input_mask,background_alpha=0)
image_input_withbackground=mask_painter(np.array(image_input), input_mask)
click_index_state = click_index
input_mask_state = input_mask
input_points_state = input_points
input_labels_state = input_labels
out_state = out
if visual_chatgpt is not None:
print('inference_click: add caption to chatGPT memory')
new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
Image.open(out["crop_save_path"]).save(new_crop_save_path)
        point_prompt = f'You should primarily use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If the human mentions objects not in the selected region, you can use tools on the whole image.'
visual_chatgpt.point_prompt = point_prompt
print("new crop save",new_crop_save_path)
yield state, state, click_state, image_input_nobackground, image_input_withbackground, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,new_crop_save_path,image_input_nobackground
async def submit_caption(state, text_refiner, length, sentiment, factuality, language,
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
autoplay,paragraph,focus_type,openai_api_key,new_crop_save_path):
print("state",state)
click_index = click_index_state
# if pre_click_index==click_index:
# click_index = (click_index[0] - 1, click_index[1] - 1)
# pre_click_index = click_index
# else:
# pre_click_index = click_index
print("click_index",click_index)
print("input_points_state",input_points_state)
print("input_labels_state",input_labels_state)
prompt=generate_prompt(focus_type,paragraph,length,sentiment,factuality,language)
print("Prompt:", prompt)
print("click",click_index)
# image_input = create_bubble_frame(np.array(image_input), generated_caption, click_index, input_mask,
# input_points=input_points, input_labels=input_labels)
if not args.disable_gpt and text_refiner:
print("new crop save",new_crop_save_path)
focus_info=get_image_gpt(openai_api_key,new_crop_save_path,prompt)
if focus_info.startswith('"') and focus_info.endswith('"'):
focus_info=focus_info[1:-1]
focus_info=focus_info.replace('#', '')
# state = state + [(None, f"Wiki: {paragraph}")]
state = state + [(None, f"{focus_info}")]
print("new_cap",focus_info)
read_info = re.sub(r'[#[\]!*]','',focus_info)
read_info = emoji.replace_emoji(read_info,replace="")
print("read info",read_info)
# refined_image_input = create_bubble_frame(np.array(origin_image_input), focus_info, click_index, input_mask,
# input_points=input_points, input_labels=input_labels)
try:
audio_output = await texttospeech(read_info, language, autoplay)
print("done")
# return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, audio_output
        except Exception as e:
            state = state + [(None, f"Error during TTS prediction: {str(e)}")]
            print(f"Error during TTS prediction: {str(e)}")
            # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
            # audio_output is unbound when texttospeech() raises, so return None instead
            return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None
else:
try:
audio_output = await texttospeech(focus_info, language, autoplay)
# waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree)
# return state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, audio_output
except Exception as e:
state = state + [(None, f"Error during TTS prediction: {str(e)}")]
print(f"Error during TTS prediction: {str(e)}")
            return state, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None
def generate_prompt(focus_type, paragraph,length, sentiment, factuality, language):
mapped_value = focus_map.get(focus_type, -1)
controls = {
'length': length,
'sentiment': sentiment,
'factuality': factuality,
'language': language
}
if mapped_value != -1:
prompt = prompt_list[mapped_value].format(
Wiki_caption=paragraph,
length=controls['length'],
sentiment=controls['sentiment'],
language=controls['language']
)
else:
prompt = "Invalid focus type."
if controls['factuality'] == "Imagination":
prompt += " Assuming that I am someone who has viewed a lot of art and has a lot of experience viewing art. Explain artistic features (composition, color, style, or use of light) and discuss the symbolism of the content and its influence on later artistic movements."
return prompt
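# Illustrative call (arguments hypothetical): generate_prompt("CFV-DA", wiki_text, 40,
# "Natural", "Factual", "English") selects prompt_list[1], fills in {Wiki_caption},
# {length} and {language}, and appends the extra instruction only when
# factuality == "Imagination".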
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
def get_image_gpt(api_key, image_path,prompt,enable_wiki=None):
# Getting the base64 string
base64_image = encode_image(image_path)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
prompt_text = prompt
payload = {
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt_text
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
"max_tokens": 300
}
# Sending the request to the OpenAI API
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
result = response.json()
print(result)
content = result['choices'][0]['message']['content']
    # The reply text is returned as-is; callers that expect JSON (e.g. upload_callback) parse it themselves.
    return content
def get_sketch_prompt(mask: Image.Image):
"""
Get the prompt for the sketcher.
TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
"""
mask = np.asarray(mask)[..., 0]
# Get the bounding box of the sketch
y, x = np.where(mask != 0)
x1, y1 = np.min(x), np.min(y)
x2, y2 = np.max(x), np.max(y)
prompt = {
'prompt_type': ['box'],
'input_boxes': [
[x1, y1, x2, y2]
]
}
return prompt
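# Example (hedged): for a sketch whose non-zero pixels span columns 3..6 and rows 3..6,
# get_sketch_prompt() returns {'prompt_type': ['box'], 'input_boxes': [[3, 3, 6, 6]]}.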
submit_traj=0
async def inference_traject(origin_image,sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
original_size, input_size, text_refiner,focus_type,paragraph,openai_api_key,autoplay,trace_type):
image_input, mask = sketcher_image['image'], sketcher_image['mask']
crop_save_path=""
prompt = get_sketch_prompt(mask)
boxes = prompt['input_boxes']
boxes = boxes[0]
global submit_traj
submit_traj=1
controls = {'length': length,
'sentiment': sentiment,
'factuality': factuality,
'language': language}
model = build_caption_anything_with_models(
args,
api_key="",
captioner=shared_captioner,
sam_model=shared_sam_model,
ocr_reader=shared_ocr_reader,
text_refiner=text_refiner,
session_id=iface.app_id
)
model.setup(image_embedding, original_size, input_size, is_image_set=True)
enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki,verbose=True)[0]
print(trace_type)
if trace_type=="Trace+Seg":
input_mask = np.array(out['mask'].convert('P'))
image_input = mask_painter(np.array(image_input), input_mask, background_alpha=0 )
crop_save_path=out['crop_save_path']
else:
image_input = Image.fromarray(np.array(origin_image))
draw = ImageDraw.Draw(image_input)
draw.rectangle(boxes, outline='red', width=2)
cropped_image = origin_image.crop(boxes)
cropped_image.save('temp.png')
crop_save_path='temp.png'
print("crop_svae_path",out['crop_save_path'])
# Update components and states
state.append((f'Box: {boxes}', None))
# fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
# image_input = create_bubble_frame(image_input, "", fake_click_index, input_mask)
prompt=generate_prompt(focus_type, paragraph, length, sentiment, factuality, language)
width, height = sketcher_image['image'].size
sketcher_image['mask'] = np.zeros((height, width, 4), dtype=np.uint8)
sketcher_image['mask'][..., -1] = 255
sketcher_image['image']=image_input
if not args.disable_gpt and text_refiner:
focus_info=get_image_gpt(openai_api_key,crop_save_path,prompt)
if focus_info.startswith('"') and focus_info.endswith('"'):
focus_info=focus_info[1:-1]
focus_info=focus_info.replace('#', '')
state = state + [(None, f"{focus_info}")]
print("new_cap",focus_info)
read_info = re.sub(r'[#[\]!*]','',focus_info)
read_info = emoji.replace_emoji(read_info,replace="")
print("read info",read_info)
# refined_image_input = create_bubble_frame(np.array(origin_image_input), focus_info, click_index, input_mask,
# input_points=input_points, input_labels=input_labels)
try:
audio_output = await texttospeech(read_info, language,autoplay)
# return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
return state, state,image_input,audio_output
        except Exception as e:
            state = state + [(None, f"Error during TTS prediction: {str(e)}")]
            print(f"Error during TTS prediction: {str(e)}")
            # return state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
            # audio_output is unbound when texttospeech() raises, so return None instead
            return state, state, image_input, None
else:
try:
audio_output = await texttospeech(focus_info, language, autoplay)
# waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree)
# return state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
return state, state, image_input,audio_output
        except Exception as e:
            state = state + [(None, f"Error during TTS prediction: {str(e)}")]
            print(f"Error during TTS prediction: {str(e)}")
            return state, state, image_input, None
def clear_chat_memory(visual_chatgpt, keep_global=False):
if visual_chatgpt is not None:
visual_chatgpt.memory.clear()
visual_chatgpt.point_prompt = ""
if keep_global:
visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
else:
visual_chatgpt.current_image = None
visual_chatgpt.global_prompt = ""
def export_chat_log(chat_state, paragraph, liked, disliked):
try:
if not chat_state:
return None
chat_log = f"Image Description: {paragraph}\n\n"
for entry in chat_state:
user_message, bot_response = entry
if user_message and bot_response:
chat_log += f"User: {user_message}\nBot: {bot_response}\n"
elif user_message:
chat_log += f"User: {user_message}\n"
elif bot_response:
chat_log += f"Bot: {bot_response}\n"
        # Append the liked and disliked responses to the log
chat_log += "\nLiked Responses:\n"
for response in liked:
chat_log += f"{response}\n"
chat_log += "\nDisliked Responses:\n"
for response in disliked:
chat_log += f"{response}\n"
print("export log...")
print("chat_log", chat_log)
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
temp_file.write(chat_log.encode('utf-8'))
temp_file_path = temp_file.name
print(temp_file_path)
return temp_file_path
except Exception as e:
print(f"An error occurred while exporting the chat log: {e}")
return None
async def cap_everything(paragraph, visual_chatgpt,language,autoplay):
# state = state + [(None, f"Caption Everything: {paragraph}")]
Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
AI_prompt = "Received."
visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
# waveform_visual, audio_output=tts.predict(paragraph, input_language, input_audio, input_mic, use_mic, agree)
audio_output=await texttospeech(paragraph,language,autoplay)
return paragraph,audio_output
def cap_everything_withoutsound(image_input, visual_chatgpt, text_refiner,paragraph):
model = build_caption_anything_with_models(
args,
api_key="",
captioner=shared_captioner,
sam_model=shared_sam_model,
ocr_reader=shared_ocr_reader,
text_refiner=text_refiner,
session_id=iface.app_id
)
paragraph = model.inference_cap_everything(image_input, verbose=True)
# state = state + [(None, f"Caption Everything: {paragraph}")]
Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
AI_prompt = "Received."
visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
return paragraph
def handle_liked(state,like_res):
if state:
like_res.append(state[-1][1])
print(f"Last response recorded: {state[-1][1]}")
else:
print("No response to record.")
state = state + [(None, f"Liked Received 👍")]
return state,like_res
def handle_disliked(state,dislike_res):
if state:
dislike_res.append(state[-1][1])
print(f"Last response recorded: {state[-1][1]}")
else:
print("No response to record.")
state = state + [(None, f"Disliked Received 🥹")]
return state,dislike_res
def get_style():
current_version = version.parse(gr.__version__)
print(current_version)
if current_version <= version.parse('3.24.1'):
style = '''
#image_sketcher{min-height:500px}
#image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
#image_upload{min-height:500px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
.custom-language {
width: 20%;
}
.custom-autoplay {
width: 40%;
}
.custom-output {
width: 30%;
}
'''
elif current_version <= version.parse('3.27'):
style = '''
#image_sketcher{min-height:500px}
#image_upload{min-height:500px}
.custom-language {
width: 20%;
}
.custom-autoplay {
width: 40%;
}
.custom-output {
width: 30%;
}
.custom-gallery {
display: flex;
flex-wrap: wrap;
justify-content: space-between;
}
.custom-gallery img {
width: 48%;
margin-bottom: 10px;
}
'''
else:
style = None
return style
# def handle_like_dislike(like_data, like_state, dislike_state):
# if like_data.liked:
# if like_data.index not in like_state:
# like_state.append(like_data.index)
# message = f"Liked: {like_data.value} at index {like_data.index}"
# else:
# message = "You already liked this item"
# else:
# if like_data.index not in dislike_state:
# dislike_state.append(like_data.index)
# message = f"Disliked: {like_data.value} at index {like_data.index}"
# else:
# message = "You already disliked this item"
# return like_state, dislike_state
async def texttospeech(text, language, autoplay):
try:
if autoplay:
voice = filtered_language_dict[language]
communicate = edge_tts.Communicate(text, voice)
file_path = "output.wav"
await communicate.save(file_path)
with open(file_path, "rb") as audio_file:
audio_bytes = BytesIO(audio_file.read())
audio = base64.b64encode(audio_bytes.read()).decode("utf-8")
print("TTS processing completed.")
audio_style = 'style="width:210px;"'
audio_player = f'<audio src="data:audio/wav;base64,{audio}" controls autoplay {audio_style}></audio>'
else:
audio_player = None
print("Autoplay is disabled.")
return audio_player
except Exception as e:
print(f"Error in texttospeech: {e}")
return None
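# Hedged standalone usage (outside Gradio's event loop); returns an HTML <audio> tag
# with base64-encoded audio when autoplay is True, otherwise None:
#   import asyncio
#   player_html = asyncio.run(texttospeech("Hello from EyeSee", "English", True))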
def create_ui():
title = """<p><h1 align="center">EyeSee Anything in Art</h1></p>
"""
description = """<p>Gradio demo for EyeSee Anything in Art, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. """
examples = [
["test_images/ambass.jpg"],
["test_images/pearl.jpg"],
["test_images/Picture0.png"],
["test_images/Picture1.png"],
["test_images/Picture2.png"],
["test_images/Picture3.png"],
["test_images/Picture4.png"],
["test_images/Picture5.png"],
]
with gr.Blocks(
css=get_style(),
theme=gr.themes.Base()
) as iface:
state = gr.State([])
out_state = gr.State(None)
click_state = gr.State([[], [], []])
origin_image = gr.State(None)
image_embedding = gr.State(None)
text_refiner = gr.State(None)
visual_chatgpt = gr.State(None)
original_size = gr.State(None)
input_size = gr.State(None)
paragraph = gr.State("")
aux_state = gr.State([])
click_index_state = gr.State((0, 0))
input_mask_state = gr.State(np.zeros((1, 1)))
input_points_state = gr.State([])
input_labels_state = gr.State([])
new_crop_save_path = gr.State(None)
image_input_nobackground = gr.State(None)
artist=gr.State(None)
like_res=gr.State([])
dislike_res=gr.State([])
gr.Markdown(title)
gr.Markdown(description)
# with gr.Row(align="right", visible=False, elem_id="top_row") as top_row:
# with gr.Column(scale=0.5):
# # gr.Markdown("Left side content")
# with gr.Column(scale=0.5):
# with gr.Row(align="right",visible=False) as language_select:
# language = gr.Dropdown(
# ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
# value="English", label="Language", interactive=True)
# with gr.Row(align="right",visible=False) as autoplay:
# auto_play = gr.Checkbox(label="Check to autoplay audio", value=False,scale=0.4)
# output_audio = gr.HTML(label="Synthesised Audio",scale=0.6)
with gr.Row(align="right", visible=False, elem_id="top_row") as top_row:
language = gr.Dropdown(
['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
value="English", label="Language", interactive=True, scale=0.2, elem_classes="custom-language"
)
auto_play = gr.Checkbox(
label="Check to autoplay audio", value=False, scale=0.4, elem_classes="custom-autoplay"
)
output_audio = gr.HTML(
label="Synthesised Audio", scale=0.3, elem_classes="custom-output"
)
# with gr.Row(align="right",visible=False) as language_select:
# language = gr.Dropdown(
# ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
# value="English", label="Language", interactive=True)
# with gr.Row(align="right",visible=False) as autoplay:
# auto_play = gr.Checkbox(label="Check to autoplay audio", value=False,scale=0.4)
# output_audio = gr.HTML(label="Synthesised Audio",scale=0.6)
with gr.Row():
with gr.Column(scale=1.0):
with gr.Column(visible=False) as modules_not_need_gpt:
with gr.Tab("Base(GPT Power)",visible=False) as base_tab:
image_input_base = gr.Image(type="pil", interactive=True, elem_id="image_upload")
example_image = gr.Image(type="pil", interactive=False, visible=False)
with gr.Row():
name_label_base = gr.Button(value="Name: ")
artist_label_base = gr.Button(value="Artist: ")
year_label_base = gr.Button(value="Year: ")
material_label_base = gr.Button(value="Material: ")
with gr.Tab("Click") as click_tab:
image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
example_image = gr.Image(type="pil", interactive=False, visible=False)
with gr.Row():
name_label = gr.Button(value="Name: ")
artist_label = gr.Button(value="Artist: ")
year_label = gr.Button(value="Year: ")
material_label = gr.Button(value="Material: ")
with gr.Row(scale=1.0):
with gr.Row(scale=0.8):
focus_type = gr.Radio(
choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
value="CFV-D",
label="Information Type",
interactive=True)
with gr.Row(scale=0.2):
submit_button_click=gr.Button(value="Submit", interactive=True,variant='primary',size="sm")
with gr.Row(scale=1.0):
with gr.Row(scale=0.4):
point_prompt = gr.Radio(
choices=["Positive", "Negative"],
value="Positive",
label="Point Prompt",
interactive=True)
click_mode = gr.Radio(
choices=["Continuous", "Single"],
value="Continuous",
label="Clicking Mode",
interactive=True)
with gr.Row(scale=0.4):
clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
clear_button_image = gr.Button(value="Clear Image", interactive=True)
with gr.Tab("Trajectory (beta)") as traj_tab:
sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=10,
elem_id="image_sketcher")
example_image = gr.Image(type="pil", interactive=False, visible=False)
with gr.Row():
submit_button_sketcher = gr.Button(value="Submit", interactive=True)
clear_button_sketcher = gr.Button(value="Clear Sketch", interactive=True)
with gr.Row(scale=1.0):
with gr.Row(scale=0.8):
focus_type_sketch = gr.Radio(
choices=["CFV-D", "CFV-DA", "CFV-DAI","PFV-DDA"],
value="CFV-D",
label="Information Type",
interactive=True)
Input_sketch = gr.Radio(
choices=["Trace+Seg", "Trace"],
value="Trace+Seg",
label="Trace Type",
interactive=True)
with gr.Column(visible=False) as modules_need_gpt1:
with gr.Row(scale=1.0):
sentiment = gr.Radio(
choices=["Positive", "Natural", "Negative"],
value="Natural",
label="Sentiment",
interactive=True,
)
with gr.Row(scale=1.0):
factuality = gr.Radio(
choices=["Factual", "Imagination"],
value="Factual",
label="Factuality",
interactive=True,
)
length = gr.Slider(
minimum=10,
maximum=80,
value=10,
step=1,
interactive=True,
label="Generated Caption Length",
)
                        # Whether to merge wiki content into the caption
enable_wiki = gr.Radio(
choices=["Yes", "No"],
value="No",
label="Expert",
interactive=True)
with gr.Column(visible=True) as modules_not_need_gpt3:
gr.Examples(
examples=examples,
inputs=[example_image],
)
with gr.Column(scale=0.5):
with gr.Column(visible=True) as module_key_input:
openai_api_key = gr.Textbox(
placeholder="Input openAI API key",
show_label=False,
label="OpenAI API Key",
lines=1,
type="password")
with gr.Row(scale=0.5):
enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
variant='primary')
with gr.Column(visible=False) as module_notification_box:
notification_box = gr.Textbox(lines=1, label="Notification", max_lines=5, show_label=False)
with gr.Column() as modules_need_gpt0:
with gr.Column(visible=False,scale=1.0) as modules_need_gpt2:
paragraph_output = gr.Textbox(lines=16, label="Describe Everything", max_lines=16)
cap_everything_button = gr.Button(value="Caption Everything in a Paragraph", interactive=True)
with gr.Column(visible=False) as modules_not_need_gpt2:
with gr.Blocks():
chatbot = gr.Chatbot(label="Chatbox", elem_classes="chatbot",likeable=True).style(height=600, scale=0.5)
with gr.Column(visible=False) as modules_need_gpt3:
chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
container=False)
with gr.Row():
clear_button_text = gr.Button(value="Clear Text", interactive=True)
submit_button_text = gr.Button(value="Send", interactive=True, variant="primary")
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
with gr.Row():
export_button = gr.Button(value="Export Chat Log", interactive=True, variant="primary")
with gr.Row():
chat_log_file = gr.File(label="Download Chat Log")
# TTS interface hidden initially
with gr.Column(visible=False) as tts_interface:
input_text = gr.Textbox(label="Text Prompt", value="Hello, World !, here is an example of light voice cloning. Try to upload your best audio samples quality")
input_language = gr.Dropdown(label="Language", choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"], value="en")
input_audio = gr.Audio(label="Reference Audio", type="filepath", value="examples/female.wav")
input_mic = gr.Audio(source="microphone", type="filepath", label="Use Microphone for Reference")
use_mic = gr.Checkbox(label="Check to use Microphone as Reference", value=False)
agree = gr.Checkbox(label="Agree", value=True)
output_waveform = gr.Video(label="Waveform Visual")
# output_audio = gr.HTML(label="Synthesised Audio")
with gr.Row():
submit_tts = gr.Button(value="Submit", interactive=True)
clear_tts = gr.Button(value="Clear", interactive=True)
###############################################################################
############# this part is for text to image #############
###############################################################################
with gr.Row(variant="panel") as text2image_model:
with gr.Column(scale=0.4):
with gr.Column():
gr.Radio([artist], label="Artist", info="Who is the artist?🧑‍🎨"),
gr.Radio(["Oil Painting","Printmaking","Watercolor Painting","Drawing"], label="Art Forms", info="What are the art forms?🎨"),
gr.Radio(["Renaissance", "Baroque", "Impressionism","Modernism"], label="Period", info="Which art period?⏳"),
# to be done
gr.Dropdown(
["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Items", info="Which items are you interested in?"
)
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0)
with gr.Accordion("Advanced options", open=False):
num_images = gr.Slider(
label="Number of Images",
minimum=1,
maximum=4,
step=1,
value=4,
)
with gr.Row():
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=5,
lines=4,
placeholder="Enter a negative prompt",
value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
visible=True,
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row(visible=True):
width = gr.Slider(
label="Width",
minimum=100,
maximum=MAX_IMAGE_SIZE,
step=64,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=100,
maximum=MAX_IMAGE_SIZE,
step=64,
value=1024,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=0.1,
maximum=6,
step=0.1,
value=3.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=15,
step=1,
value=8,
)
with gr.Column(scale=0.6):
result = gr.Gallery(
label="Result"
# columns=4,
# rows=2,
# show_label=False,
# allow_preview=True,
# object_fit="contain",
# height="auto",
# preview=True,
# show_share_button=True,
# show_download_button=True
).style(height='auto',columns=4)
# gr.Examples(
# examples=examples,
# inputs=prompt,
# cache_examples=False
# )
use_negative_prompt.change(
fn=lambda x: gr.update(visible=x),
inputs=use_negative_prompt,
outputs=negative_prompt,
api_name=False,
)
# gr.on(
# triggers=[
# prompt.submit,
# negative_prompt.submit,
# run_button.click,
# ],
# fn=generate,
# inputs=[
# prompt,
# negative_prompt,
# use_negative_prompt,
# seed,
# width,
# height,
# guidance_scale,
# num_inference_steps,
# randomize_seed,
# num_images
# ],
# outputs=[result, seed],
# api_name="run",
# )
run_button.click(
fn=generate,
inputs=[
prompt,
negative_prompt,
use_negative_prompt,
seed,
width,
height,
guidance_scale,
num_inference_steps,
randomize_seed,
num_images
],
outputs=[result, seed]
)
###############################################################################
############# above part is for text to image #############
###############################################################################
###############################################################################
# this part is for 3d generate.
###############################################################################
with gr.Row(variant="panel",visible=False) as d3_model:
with gr.Column():
with gr.Row():
input_image = gr.Image(
label="Input Image",
image_mode="RGBA",
sources="upload",
#width=256,
#height=256,
type="pil",
elem_id="content_image",
)
processed_image = gr.Image(
label="Processed Image",
image_mode="RGBA",
#width=256,
#height=256,
type="pil",
interactive=False
)
with gr.Row():
with gr.Group():
do_remove_background = gr.Checkbox(
label="Remove Background", value=True
)
sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
sample_steps = gr.Slider(
label="Sample Steps",
minimum=30,
maximum=75,
value=75,
step=5
)
with gr.Row():
submit = gr.Button("Generate", elem_id="generate", variant="primary")
with gr.Row(variant="panel"):
gr.Examples(
examples=[
os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
],
inputs=[input_image],
label="Examples",
cache_examples=False,
examples_per_page=16
)
with gr.Column():
with gr.Row():
with gr.Column():
mv_show_images = gr.Image(
label="Generated Multi-views",
type="pil",
width=379,
interactive=False
)
# with gr.Column():
# output_video = gr.Video(
# label="video", format="mp4",
# width=379,
# autoplay=True,
# interactive=False
# )
with gr.Row():
with gr.Tab("OBJ"):
output_model_obj = gr.Model3D(
label="Output Model (OBJ Format)",
interactive=False,
)
gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
with gr.Tab("GLB"):
output_model_glb = gr.Model3D(
label="Output Model (GLB Format)",
interactive=False,
)
gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
mv_images = gr.State()
# chatbot.like(handle_like_dislike, inputs=[like_state, dislike_state], outputs=[like_state, dislike_state])
submit.click(fn=check_input_image, inputs=[new_crop_save_path], outputs=[processed_image]).success(
fn=generate_mvs,
inputs=[processed_image, sample_steps, sample_seed],
outputs=[mv_images, mv_show_images]
).success(
fn=make3d,
inputs=[mv_images],
outputs=[output_model_obj, output_model_glb]
)
###############################################################################
############# above part is for 3D generation #############
###############################################################################
def clear_tts_fields():
    # Reset the TTS widgets: clear text and language, drop uploaded/recorded audio, untick use_mic, re-tick agree, and clear both outputs.
    return [gr.update(value=""), gr.update(value=""), None, None, gr.update(value=False), gr.update(value=True), None, None]
# submit_tts.click(
# tts.predict,
# inputs=[input_text, input_language, input_audio, input_mic, use_mic, agree],
# outputs=[output_waveform, output_audio],
# queue=True
# )
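# The submit_tts handler above is left disabled; only the clear button is wired, resetting every TTS field via clear_tts_fields.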
clear_tts.click(
clear_tts_fields,
inputs=None,
outputs=[input_text, input_language, input_audio, input_mic, use_mic, agree, output_waveform, output_audio],
queue=False
)
clear_button_sketcher.click(
    lambda x: x,  # restore the sketch pad to the original, unmarked image
    [origin_image],
    [sketcher_input],
    queue=False,
    show_progress=False
)
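# API-key handling: submitting the key (or clicking enable) runs init_openai_api_key, the disable button runs init_wo_openai_api_key;
# all three toggle the GPT-dependent panels and refresh the text_refiner / visual_chatgpt states.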
openai_api_key.submit(
    init_openai_api_key, inputs=[openai_api_key],
    outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
             modules_not_need_gpt, modules_not_need_gpt2, tts_interface, module_key_input,
             module_notification_box, text_refiner, visual_chatgpt, notification_box, d3_model, top_row])
enable_chatGPT_button.click(
    init_openai_api_key, inputs=[openai_api_key],
    outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
             modules_not_need_gpt, modules_not_need_gpt2, tts_interface, module_key_input,
             module_notification_box, text_refiner, visual_chatgpt, notification_box, d3_model, top_row])
disable_chatGPT_button.click(
    init_wo_openai_api_key,
    outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
             modules_not_need_gpt, modules_not_need_gpt2, tts_interface, module_key_input,
             module_notification_box, text_refiner, visual_chatgpt, notification_box, d3_model, top_row])
enable_chatGPT_button.click(
    # one value per output: image_input, chatbot, state, click_state, paragraph_output, origin_image
    lambda: (None, [], [], [[], [], []], "", ""),
    [],
    [image_input, chatbot, state, click_state, paragraph_output, origin_image],
    queue=False,
    show_progress=False
)
openai_api_key.submit(
    # one value per output: image_input, chatbot, state, click_state, paragraph_output, origin_image
    lambda: (None, [], [], [[], [], []], "", ""),
    [],
    [image_input, chatbot, state, click_state, paragraph_output, origin_image],
    queue=False,
    show_progress=False
)
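# "Caption everything": produce a paragraph for the whole image and, when auto_play is on, synthesize it to output_audio.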
cap_everything_button.click(cap_everything, [paragraph, visual_chatgpt, language, auto_play],
                            [paragraph_output, output_audio])
clear_button_click.click(
lambda x: ([[], [], []], x),
[origin_image],
[click_state, image_input],
queue=False,
show_progress=False
)
clear_button_click.click(functools.partial(clear_chat_memory, keep_global=True), inputs=[visual_chatgpt])
clear_button_image.click(
    # one value per output: image_input, chatbot, state, click_state, paragraph_output, origin_image
    lambda: (None, [], [], [[], [], []], "", ""),
    [],
    [image_input, chatbot, state, click_state, paragraph_output, origin_image],
    queue=False,
    show_progress=False
)
clear_button_image.click(clear_chat_memory, inputs=[visual_chatgpt])
clear_button_text.click(
    # click_state is reset to the same 3-list structure used everywhere else
    lambda: ([], [], [[], [], []]),
    [],
    [chatbot, state, click_state],
    queue=False,
    show_progress=False
)
clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
image_input.clear(
    # one value per output: image_input, chatbot, state, click_state, paragraph_output, origin_image
    lambda: (None, [], [], [[], [], []], "", ""),
    [],
    [image_input, chatbot, state, click_state, paragraph_output, origin_image],
    queue=False,
    show_progress=False
)
image_input.clear(clear_chat_memory, inputs=[visual_chatgpt])
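# Every image source (base tab, click tab, sketcher, examples) goes through the same upload_callback, which resets the chat,
# computes the image embedding, and fills in the artwork metadata labels (name/artist/year/material) plus the description paragraph.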
image_input_base.upload(
    upload_callback, [image_input_base, state, visual_chatgpt, openai_api_key],
    [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
     image_embedding, original_size, input_size, name_label, artist_label, year_label, material_label,
     name_label_base, artist_label_base, year_label_base, material_label_base, paragraph, artist])
image_input.upload(
    upload_callback, [image_input, state, visual_chatgpt, openai_api_key],
    [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
     image_embedding, original_size, input_size, name_label, artist_label, year_label, material_label,
     name_label_base, artist_label_base, year_label_base, material_label_base, paragraph, artist])
sketcher_input.upload(
    upload_callback, [sketcher_input, state, visual_chatgpt, openai_api_key],
    [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
     image_embedding, original_size, input_size, name_label, artist_label, year_label, material_label,
     name_label_base, artist_label_base, year_label_base, material_label_base, paragraph, artist])
chat_input.submit(chat_input_callback,
                  [visual_chatgpt, chat_input, click_state, state, aux_state, language, auto_play],
                  [chatbot, state, aux_state, output_audio])
chat_input.submit(lambda: "", None, chat_input)
submit_button_text.click(chat_input_callback,
                         [visual_chatgpt, chat_input, click_state, state, aux_state, language, auto_play],
                         [chatbot, state, aux_state, output_audio])
submit_button_text.click(lambda: "", None, chat_input)
example_image.change(
    upload_callback, [example_image, state, visual_chatgpt, openai_api_key],
    [chatbot, state, origin_image, click_state, image_input, image_input_base, sketcher_input,
     image_embedding, original_size, input_size, name_label, artist_label, year_label, material_label,
     name_label_base, artist_label_base, year_label_base, material_label_base, paragraph, artist])
example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
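# gpt_state is a module-level flag (set elsewhere when the OpenAI key is initialized); these helpers decide which
# GPT-dependent panels are visible when a tab is selected. Each returned list matches the outputs= order below.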
def on_click_tab_selected():
    # outputs order: [modules_need_gpt1, modules_not_need_gpt2, modules_need_gpt0, modules_need_gpt2]
    if gpt_state == 1:
        print("using gpt, gpt_state =", gpt_state)
        return [gr.update(visible=True)] * 2 + [gr.update(visible=False)] * 2
    else:
        print("no gpt, gpt_state =", gpt_state)
        return [gr.update(visible=False)] + [gr.update(visible=True)] + [gr.update(visible=False)] * 2
def on_base_selected():
    # outputs order: [modules_need_gpt0, modules_need_gpt2, modules_not_need_gpt2, modules_need_gpt1]
    if gpt_state == 1:
        print("using gpt, gpt_state =", gpt_state)
        return [gr.update(visible=True)] * 2 + [gr.update(visible=False)] * 2
    else:
        print("no gpt")
        return [gr.update(visible=False)] * 4
traj_tab.select(on_click_tab_selected, outputs=[modules_need_gpt1, modules_not_need_gpt2, modules_need_gpt0, modules_need_gpt2])
click_tab.select(on_click_tab_selected, outputs=[modules_need_gpt1, modules_not_need_gpt2, modules_need_gpt0, modules_need_gpt2])
base_tab.select(on_base_selected, outputs=[modules_need_gpt0, modules_need_gpt2, modules_not_need_gpt2, modules_need_gpt1])
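# Click interaction: a click on the image runs point-prompted segmentation and captioning of the selected region
# (inference_click), updating the chat, the click state, and the cropped region later reused by the 3D tab.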
image_input.select(
inference_click,
inputs=[
origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
],
outputs=[chatbot, state, click_state, image_input, input_image, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, new_crop_save_path, image_input_nobackground],
show_progress=False, queue=True
)
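# Submit the caption for the currently selected region; when auto_play is on, the response is also synthesized to output_audio.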
submit_button_click.click(
submit_caption,
inputs=[
state, text_refiner,length, sentiment, factuality, language,
out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
auto_play, paragraph, focus_type, openai_api_key, new_crop_save_path
],
outputs=[
chatbot, state, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,
output_audio
],
show_progress=True,
queue=True
)
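# Sketch/trajectory interaction: the drawn region is sent to inference_traject, which describes the sketched area in chat.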
submit_button_sketcher.click(
inference_traject,
inputs=[
origin_image, sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
original_size, input_size, text_refiner, focus_type_sketch, paragraph, openai_api_key, auto_play, Input_sketch
],
outputs=[chatbot, state, sketcher_input,output_audio],
show_progress=False, queue=True
)
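# Feedback and export: upvote/downvote append to like_res/dislike_res, and the export button writes the chat log (with feedback) to a downloadable file.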
export_button.click(
export_chat_log,
inputs=[state,paragraph,like_res,dislike_res],
outputs=[chat_log_file],
queue=True
)
upvote_btn.click(
handle_liked,
inputs=[state,like_res],
outputs=[chatbot,like_res]
)
downvote_btn.click(
handle_disliked,
inputs=[state,dislike_res],
outputs=[chatbot,dislike_res]
)
return iface
if __name__ == '__main__':
    iface = create_ui()
    iface.queue(concurrency_count=5, api_open=False, max_size=10)
    iface.launch(server_name="0.0.0.0", enable_queue=True)