# Page header captured with this file: Hugging Face Space, status "Running on Zero";
# file size: 8,088 bytes; commits: e968589, aba6f6c.
import gradio as gr
import spaces
import time
import os
import torch
from PIL import Image
from threading import Thread
from transformers import TextIteratorStreamer, AutoConfig, AutoModelForCausalLM
from constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
)
from conversation import conv_templates
from eval_utils import load_maya_model
from utils import disable_torch_init
from mm_utils import tokenizer_image_token, process_images
from huggingface_hub._login import _login
# Import LLaVA modules to register model types
from model import *
from model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig
# Register the custom config/model classes with the transformers Auto*
# factories under the "llava_cohere" model type.
AutoConfig.register("llava_cohere", LlavaCohereConfig)
AutoModelForCausalLM.register(LlavaCohereConfig, LlavaCohereForCausalLM)

# Log in to the Hugging Face Hub with the token from the "hf_token"
# environment variable. The login is skipped when the variable is unset so a
# missing secret does not pass None into the login helper at import time.
# NOTE(review): `_login` is a private huggingface_hub helper; presumably it
# expects a string token - confirm, and consider the public login API.
hf_token = os.getenv("hf_token")
if hf_token:
    _login(token=hf_token, add_to_git_credential=False)
else:
    print("Warning: hf_token is not set; skipping Hugging Face Hub login.")

# Global Variables
MODEL_BASE = "CohereForAI/aya-23-8B"
MODEL_PATH = "maya-multimodal/maya"
MODE = "finetuned"
def load_model():
    """Load the Maya model and return it with its tokenizer and image processor.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple. The model has been
        moved to CUDA and switched to eval mode.
    """
    # load_maya_model returns four values; the last one is not needed here.
    net, text_tokenizer, img_processor, _unused = load_maya_model(
        MODEL_BASE,
        MODEL_PATH,
        None,
        MODE,
    )
    net = net.cuda()
    net.eval()
    return net, text_tokenizer, img_processor
# Load model globally: this runs once at import time and binds the
# module-level `model`, `tokenizer` and `image_processor` names that
# process_chat_stream reads on every request.
print("Loading model...")
model, tokenizer, image_processor = load_model()
print("Model loaded successfully!")
def validate_image_file(image_path):
    """Check that *image_path* is an existing file that PIL can verify.

    Args:
        image_path: Path of the image file to check.

    Returns:
        True when the file exists and passes ``Image.verify()``.

    Raises:
        gr.Error: If the path is not a file, or opening/verifying it raises
            ``IOError`` or ``SyntaxError``.
    """
    file_exists = os.path.isfile(image_path)
    if not file_exists:
        raise gr.Error(f"Error: File {image_path} does not exist.")

    try:
        with Image.open(image_path) as candidate:
            candidate.verify()
    except (IOError, SyntaxError) as e:
        raise gr.Error(f"Error: {image_path} is not a valid image file. {e}")
    else:
        return True
def _extract_image_path(message, history):
    """Return the path of the image to chat about, or None if there is none.

    The last file attached to the current message wins. Otherwise the history
    is scanned from newest to oldest for an entry that carries a file, either
    as a tuple whose first item is the path or as a dict with a "files" list.
    """
    current_files = message.get("files") or []
    if current_files:
        last_file = current_files[-1]
        return last_file["path"] if isinstance(last_file, dict) else last_file

    for hist in reversed(history or []):
        print("Processing history item:", hist)
        content = hist["content"]
        if isinstance(content, tuple) and content:
            return content[0]
        if isinstance(content, dict) and content.get("files"):
            first_file = content["files"][0]
            return first_file["path"] if isinstance(first_file, dict) else first_file
    return None


def _history_user_text(content):
    """Return the text of a user history entry's content, or "" if it has none."""
    if isinstance(content, str):
        return content
    if isinstance(content, tuple):
        text = content[1] if len(content) > 1 else ""
    elif isinstance(content, dict):
        text = content.get("text")
    else:
        text = None
    return text if isinstance(text, str) else ""


def _with_image_token(text):
    """Prefix *text* with the image placeholder token(s) selected by the model config."""
    if model.config.mm_use_im_start_end:
        return f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{text}"
    return f"{DEFAULT_IMAGE_TOKEN}\n{text}"


def _build_prompt(message, history):
    """Rebuild the "aya" conversation from history plus the current message.

    The image placeholder is attached to the first user text turn of the
    rebuilt conversation (from history if there is one, else the current
    message), so the prompt contains it on follow-up turns as well as on the
    first turn. User history entries without text (file-only entries) are
    skipped rather than replayed as empty turns.
    """
    conv = conv_templates["aya"].copy()
    image_token_pending = True

    for hist in history or []:
        role = hist["role"]
        if role == "user":
            human_text = _history_user_text(hist["content"])
            if not human_text:
                continue
            if image_token_pending:
                human_text = _with_image_token(human_text)
                image_token_pending = False
            conv.append_message(conv.roles[0], human_text)
        elif role == "assistant":
            conv.append_message(conv.roles[1], hist["content"])

    current_message = message.get("text") or ""
    if image_token_pending:
        current_message = _with_image_token(current_message)

    conv.append_message(conv.roles[0], current_message)
    # A None assistant message leaves the assistant turn open for generation.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


@spaces.GPU
def process_chat_stream(message, history):
    """Stream the model's answer to a multimodal chat message.

    Args:
        message: Dict with the user's "text" and an optional "files" list
            (each file either a path string or a dict with a "path" key).
        history: Earlier chat messages as dicts with "role" and "content".

    Yields:
        ``{"role": "assistant", "content": <text so far>}`` after each chunk
        produced by the streamer.

    Raises:
        gr.Error: If no image is found in the message or history, the image
            is invalid or cannot be processed, setup fails, or generation
            fails in the worker thread.
    """
    import traceback

    print(message)
    print("History:", history)

    image_path = _extract_image_path(message, history)
    if image_path is None:
        raise gr.Error("Please upload an image to start the conversation.")

    # Validate and process image
    validate_image_file(image_path)
    image = Image.open(image_path).convert("RGB")

    image_tensor = process_images([image], image_processor, model.config)
    if image_tensor is None:
        raise gr.Error("Failed to process image")
    image_tensor = image_tensor.cuda()

    prompt = _build_prompt(message, history)

    # Filled by the worker thread if generation fails; an exception raised
    # inside the thread would otherwise never reach this generator.
    generation_failure = []

    try:
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        if input_ids is None:
            raise ValueError("Tokenization returned None")
        # Ensure input_ids is a 2D tensor
        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.cuda()

        # Validate the vision tower before starting generation
        if not hasattr(model, 'get_vision_tower') or model.get_vision_tower() is None:
            raise ValueError("Model's vision tower is not properly initialized")

        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
        generation_kwargs = {
            "inputs": input_ids,
            "images": image_tensor,
            "image_sizes": [image.size],
            "streamer": streamer,
            "temperature": 0.3,
            "do_sample": True,
            "top_p": 0.9,
            "num_beams": 1,
            "max_new_tokens": 4096,
            "use_cache": True,
        }

        def generate_with_error_handling():
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                generation_failure.append(
                    f"Generation error: {str(e)}\nTraceback:\n{traceback.format_exc()}"
                )
                # End the stream so the consumer loop below stops waiting.
                streamer.end()

        thread = Thread(target=generate_with_error_handling)
        thread.start()
    except Exception as e:
        error_msg = f"Setup error: {str(e)}"
        error_msg += f"\nTraceback:\n{traceback.format_exc()}"
        raise gr.Error(error_msg) from e

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        # NOTE(review): fixed per-chunk delay kept from the original; presumably UI pacing - confirm.
        time.sleep(0.1)
        yield {"role": "assistant", "content": partial_message}

    thread.join()
    if generation_failure:
        raise gr.Error(generation_failure[0])
# Create Gradio interface: a chatbot display and a multimodal textbox wired
# into a ChatInterface inside a Blocks app.
_EXAMPLE_PROMPTS = [
    {"text": "Describe this photo in detail.", "files": ["./asian_food.jpg"]},
    {"text": "What is the name of this famous sight in the photo?", "files": ["./hawaii.jpg"]},
]
_DESCRIPTION = (
    "Upload an image and start chatting about it, or simply try one of the examples below. "
    "If you don't upload an image, you will receive an error. "
    "[Read the research paper](https://huggingface.co/papers/2412.07112)\n\nTeam π Maya"
)

chatbot = gr.Chatbot(
    type="messages",
    height=450,
    scale=1,
    container=True,
    render_markdown=True,
    avatar_images=None,
    show_label=False,
    show_share_button=False,
    show_copy_button=False,
)

chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=process_chat_stream,
        multimodal=True,
        chatbot=chatbot,
        textbox=chat_input,
        title="Maya: Multilingual Multimodal Model",
        description=_DESCRIPTION,
        examples=_EXAMPLE_PROMPTS,
        stop_btn="Stop Generation",
    )
if __name__ == "__main__":
    # Start the app only when run as a script: set up the queue with
    # api_open=False, then launch with show_api=False and share=False.
    demo.queue(api_open=False)
    demo.launch(show_api=False, share=False)