# maya_demo / app.py
# Last updated by kkr5155: "Update app.py" (commit aba6f6c, verified)
import gradio as gr
import spaces
import time
import os
import torch
from PIL import Image
from threading import Thread
from transformers import TextIteratorStreamer, AutoConfig, AutoModelForCausalLM
from constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
)
from conversation import conv_templates
from eval_utils import load_maya_model
from utils import disable_torch_init
from mm_utils import tokenizer_image_token, process_images
from huggingface_hub._login import _login
# Import LLaVA modules to register model types
from model import *
from model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig
# Register model type and config so the transformers Auto* factories can
# resolve the "llava_cohere" model type to the Maya (LLaVA-Cohere) classes.
AutoConfig.register("llava_cohere", LlavaCohereConfig)
AutoModelForCausalLM.register(LlavaCohereConfig, LlavaCohereForCausalLM)

# Log in to the Hugging Face Hub with the token stored in the "hf_token"
# environment variable (e.g. a Space secret). Presumably needed to download
# gated/private checkpoints below — TODO confirm. When the variable is unset,
# skip login instead of handing None to _login (it calls string methods on the
# token and would crash at import time).
hf_token = os.getenv("hf_token")
if hf_token:
    _login(token=hf_token, add_to_git_credential=False)
else:
    print("hf_token is not set; skipping Hugging Face Hub login.")

# Global Variables: the base LLM, the Maya checkpoint applied on top of it, and
# the load mode forwarded to load_maya_model.
MODEL_BASE = "CohereForAI/aya-23-8B"
MODEL_PATH = "maya-multimodal/maya"
MODE = "finetuned"
def load_model():
    """Build the Maya model and its preprocessing components.

    Calls ``load_maya_model`` with ``MODEL_BASE``, ``MODEL_PATH`` and ``MODE``,
    moves the resulting model onto CUDA and puts it in eval mode.

    Returns:
        tuple: ``(model, tokenizer, image_processor)``. The fourth value returned
        by ``load_maya_model`` is discarded.
    """
    loaded = load_maya_model(MODEL_BASE, MODEL_PATH, None, MODE)
    maya_model, maya_tokenizer, maya_image_processor, _unused = loaded
    maya_model = maya_model.cuda()
    maya_model.eval()  # inference only: switch the module to eval mode
    return maya_model, maya_tokenizer, maya_image_processor


# Load the model once at import time; every request handler below reuses these
# module-level objects.
print("Loading model...")
model, tokenizer, image_processor = load_model()
print("Model loaded successfully!")
def validate_image_file(image_path):
    """Check that ``image_path`` names an existing, decodable image file.

    Args:
        image_path: Filesystem path of the uploaded image.

    Returns:
        True when the file exists and Pillow's ``verify()`` accepts it.

    Raises:
        gr.Error: If the file does not exist, or Pillow cannot open or verify
            it (missing decoder, truncated/corrupt data, oversized image).
    """
    if not os.path.isfile(image_path):
        raise gr.Error(f"Error: File {image_path} does not exist.")
    try:
        with Image.open(image_path) as img:
            # verify() checks integrity without decoding pixel data; the image
            # object is unusable afterwards, so callers must re-open the file.
            img.verify()
    except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as e:
        # Pillow signals bad files through several exception types (OSError /
        # UnidentifiedImageError, SyntaxError from plugin parsers, ValueError
        # e.g. for oversized PNG chunks, DecompressionBombError); report them
        # all as a user-facing error and keep the original cause attached.
        raise gr.Error(f"Error: {image_path} is not a valid image file. {e}") from e
    return True
def _file_path(entry):
    """Return the filesystem path of a Gradio file entry.

    Entries are either dicts carrying a ``"path"`` key or plain path values.
    """
    return entry["path"] if isinstance(entry, dict) else entry


def _locate_image(message, history):
    """Find the image the model should condition on.

    The last file attached to the current ``message`` wins; otherwise
    ``history`` is scanned newest-first for an entry that carries a file.

    Returns:
        ``(path, history_index)`` where ``history_index`` is the position in
        ``history`` of the entry the image came from, or ``None`` when it came
        from the current message. ``(None, None)`` means no image was found.
    """
    current_files = message.get("files") or []
    if current_files:
        return _file_path(current_files[-1]), None
    for index in reversed(range(len(history))):
        hist = history[index]
        print("Processing history item:", hist)
        content = hist["content"]
        # Uploaded files show up in the history either as tuples whose first
        # element is the path, or as dicts with a "files" list.
        if isinstance(content, tuple) and content:
            return content[0], index
        if isinstance(content, dict) and content.get("files"):
            return _file_path(content["files"][0]), index
    return None, None


def _image_placeholder():
    """Return the prompt text marking where the image features are inserted."""
    if model.config.mm_use_im_start_end:
        return f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}"
    return DEFAULT_IMAGE_TOKEN


def _user_text(content):
    """Return the text portion of a user history entry's ``content``.

    ``str`` is returned as-is; a ``tuple`` yields its second element, or ``""``
    for file-only entries; a ``dict`` yields its ``"text"`` value (``""`` when
    absent) so the raw dict is never handed to the conversation template; any
    other value is returned unchanged.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, tuple):
        return content[1] if len(content) > 1 else ""
    if isinstance(content, dict):
        return content.get("text") or ""
    return content


def _build_prompt(message, history, image_history_index):
    """Render the "aya" conversation template for this request.

    Replays ``history`` as user/assistant turns, appends the current user text
    and an open assistant turn, and inserts the image placeholder exactly once:
    on the user history entry the image came from, otherwise on the current
    message.

    Returns:
        The prompt string produced by ``conv.get_prompt()``.
    """
    conv = conv_templates["aya"].copy()
    image_prefix = f"{_image_placeholder()}\n"
    # LLaVA-style models splice image features in only where the image token
    # appears in input_ids; without it the image tensor is presumably ignored
    # (assumed from the LLaVA design — confirm in model/). So the placeholder
    # must be present on every turn of the conversation, not just the first.
    placeholder_used = False
    for index, hist in enumerate(history):
        role = hist["role"]
        if role == "user":
            human_text = _user_text(hist["content"])
            if index == image_history_index:
                human_text = f"{image_prefix}{human_text}"
                placeholder_used = True
            conv.append_message(conv.roles[0], human_text)
        elif role == "assistant":
            conv.append_message(conv.roles[1], hist["content"])

    current_message = message["text"]
    if not placeholder_used:
        current_message = f"{image_prefix}{current_message}"
    conv.append_message(conv.roles[0], current_message)
    conv.append_message(conv.roles[1], None)  # open turn for the model's reply
    return conv.get_prompt()


@spaces.GPU
def process_chat_stream(message, history):
    """Stream Maya's reply to a multimodal chat message.

    Args:
        message: The MultimodalTextbox value: a dict with ``"text"`` and an
            optional ``"files"`` list.
        history: Prior turns as dicts with ``"role"`` and ``"content"`` keys
            (the chatbot is configured with ``type="messages"``); ``None`` is
            treated as an empty history.

    Yields:
        ``{"role": "assistant", "content": <reply so far>}`` after each text
        chunk produced by the streamer.

    Raises:
        gr.Error: If no image is available, the image is invalid or cannot be
            processed, generation setup fails, or generation itself fails.
    """
    import traceback

    print(message)
    print("History:", history)
    history = history or []

    image_path, image_history_index = _locate_image(message, history)
    if image_path is None:
        raise gr.Error("Please upload an image to start the conversation.")

    validate_image_file(image_path)
    # Decode inside a context manager so the underlying file handle is closed;
    # convert() returns a new, fully loaded image.
    with Image.open(image_path) as opened:
        pil_image = opened.convert("RGB")

    image_tensor = process_images([pil_image], image_processor, model.config)
    if image_tensor is None:
        raise gr.Error("Failed to process image")
    image_tensor = image_tensor.cuda()

    prompt = _build_prompt(message, history, image_history_index)

    # Filled by the worker thread. An exception raised inside that thread would
    # otherwise vanish and leave the streamer loop below waiting forever.
    generation_errors = []
    try:
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        if input_ids is None:
            raise ValueError("Tokenization returned None")
        # Add a batch dimension when the tokenizer returned a 1-D tensor.
        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.cuda()

        if not hasattr(model, "get_vision_tower") or model.get_vision_tower() is None:
            raise ValueError("Model's vision tower is not properly initialized")

        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
        generation_kwargs = {
            "inputs": input_ids,
            "images": image_tensor,
            "image_sizes": [pil_image.size],
            "streamer": streamer,
            "temperature": 0.3,
            "do_sample": True,
            "top_p": 0.9,
            "num_beams": 1,
            "max_new_tokens": 4096,
            "use_cache": True,
        }

        def generate_with_error_handling():
            """Run generation, recording any failure for the consumer."""
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                generation_errors.append(
                    f"Generation error: {e}\nTraceback:\n{traceback.format_exc()}"
                )
                # end() pushes the streamer's stop signal so the consumer loop
                # terminates instead of blocking on an empty queue.
                streamer.end()

        thread = Thread(target=generate_with_error_handling)
        thread.start()
    except Exception as e:
        raise gr.Error(f"Setup error: {e}\nTraceback:\n{traceback.format_exc()}") from e

    # Forward chunks as soon as they arrive (no artificial delay), so the
    # @spaces.GPU call does not outlive generation.
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield {"role": "assistant", "content": partial_message}

    thread.join()
    if generation_errors:
        raise gr.Error(generation_errors[0])
# Create Gradio interface: a messages-format chatbot plus a multimodal input box
# (text + image uploads), both wired into a ChatInterface that calls the
# streaming handler process_chat_stream.
chatbot = gr.Chatbot(
    show_label=False,
    height=450,
    show_share_button=False,
    show_copy_button=False,
    avatar_images=None,
    container=True,
    render_markdown=True,
    scale=1,
    type="messages",
)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=process_chat_stream,
        title="Maya: Multilingual Multimodal Model",
        examples=[
            {"text": "Describe this photo in detail.", "files": ["./asian_food.jpg"]},
            {"text": "What is the name of this famous sight in the photo?", "files": ["./hawaii.jpg"]},
        ],
        # The sign-off emoji is written as a named escape so it cannot be
        # garbled by source re-encoding (it had turned into mojibake).
        description=(
            "Upload an image and start chatting about it, or simply try one of the examples below. "
            "If you don't upload an image, you will receive an error. "
            "[Read the research paper](https://huggingface.co/papers/2412.07112)\n\n"
            "Team \N{GREEN HEART} Maya"
        ),
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )
if __name__ == "__main__":
    # Configure the request queue with its REST API endpoints closed, then
    # serve the UI locally: no API docs page, no public share link.
    # (Blocks.queue() returns the Blocks instance, so the calls chain.)
    demo.queue(api_open=False).launch(show_api=False, share=False)