from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, field_validator
from typing import Optional, List, Union, Dict, Any
import torch
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig
)
from qwen_vl_utils import process_vision_info
import uvicorn
import json
from datetime import datetime
import logging
import time
import psutil
import GPUtil
import base64
from PIL import Image
import io
import os
import threading

# Keep CUDA kernel launches asynchronous and restrict compiled kernel targets to
# Ampere (compute capability 8.0 binaries also run on the A5000, which is 8.6)
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"  # Compatible with A5000

# Model configuration
MODELS = {
    "Qwen2.5-VL-7B-Instruct": {
        "path": "Qwen/Qwen2.5-VL-7B-Instruct",
        "model_class": Qwen2_5_VLForConditionalGeneration,
    },
    "Qwen2-VL-7B-Instruct": {
        "path": "Qwen/Qwen2-VL-7B-Instruct",
        "model_class": Qwen2VLForConditionalGeneration,
    },
    "Qwen2-VL-2B-Instruct": {
        "path": "Qwen/Qwen2-VL-2B-Instruct",
        "model_class": Qwen2VLForConditionalGeneration,
    }
}

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global variables
models = {}
processors = {}
model_locks = {}  # Thread locks for model loading
last_used = {}  # Record last use time of models

# Set default CUDA device
if torch.cuda.is_available():
    # Get GPU information and select the device with the most total memory
    gpus = GPUtil.getGPUs()
    if gpus:
        max_memory_gpu = max(gpus, key=lambda g: g.memoryTotal)
        selected_device = max_memory_gpu.id
        torch.cuda.set_device(selected_device)
        device = torch.device(f"cuda:{selected_device}")
        logger.info(f"Selected GPU {selected_device} ({max_memory_gpu.name}) with {max_memory_gpu.memoryTotal}MB memory")
    else:
        device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

class ImageURL(BaseModel):
    url: str

class MessageContent(BaseModel):
    type: str
    text: Optional[str] = None
    image_url: Optional[Dict[str, str]] = None

    @field_validator('type')
    @classmethod
    def validate_type(cls, v: str) -> str:
        if v not in ['text', 'image_url']:
            raise ValueError(f"Invalid content type: {v}")
        return v

class ChatMessage(BaseModel):
    role: str
    content: Union[str, List[MessageContent]]

    @field_validator('role')
    @classmethod
    def validate_role(cls, v: str) -> str:
        if v not in ['system', 'user', 'assistant']:
            raise ValueError(f"Invalid role: {v}")
        return v

    @field_validator('content')
    @classmethod
    def validate_content(cls, v: Union[str, List[Any]]) -> Union[str, List[MessageContent]]:
        if isinstance(v, str):
            return v
        if isinstance(v, list):
            return [MessageContent(**item) if isinstance(item, dict) else item for item in v]
        raise ValueError("Content must be either a string or a list of content items")

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = 2048
    stream: Optional[bool] = False  # accepted for OpenAI API compatibility; streaming is not implemented below
    response_format: Optional[Dict[str, str]] = None

class ChatCompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]

class ModelCard(BaseModel):
    id: str
    created: int
    owned_by: str
    permission: List[Dict[str, Any]] = []
    root: Optional[str] = None
    parent: Optional[str] = None
    capabilities: Optional[Dict[str, bool]] = None
    context_window: Optional[int] = None
    max_tokens: Optional[int] = None

class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard]

def process_base64_image(base64_string: str) -> Image.Image:
    """Process base64 image data and return a PIL Image"""
    try:
        # Remove data URL prefix if present
        if 'base64,' in base64_string:
            base64_string = base64_string.split('base64,')[1]
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        # Convert to RGB if necessary
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
        return image
    except Exception as e:
        logger.error(f"Error processing base64 image: {str(e)}")
        raise ValueError(f"Invalid base64 image data: {str(e)}")

def log_system_info():
    """Log system resource information"""
    try:
        cpu_percent = psutil.cpu_percent(interval=1)
        memory = psutil.virtual_memory()
        gpu_info = []
        if torch.cuda.is_available():
            for gpu in GPUtil.getGPUs():
                gpu_info.append({
                    'id': gpu.id,
                    'name': gpu.name,
                    'load': f"{gpu.load*100}%",
                    'memory_used': f"{gpu.memoryUsed}MB/{gpu.memoryTotal}MB",
                    'temperature': f"{gpu.temperature}°C"
                })
        logger.info(f"System Info - CPU: {cpu_percent}%, RAM: {memory.percent}%, "
                    f"Available RAM: {memory.available/1024/1024/1024:.1f}GB")
        if gpu_info:
            logger.info(f"GPU Info: {gpu_info}")
    except Exception as e:
        logger.warning(f"Failed to log system info: {str(e)}")

def get_or_initialize_model(model_name: str):
    """Get or initialize a model if not already loaded"""
    global models, processors, model_locks, last_used
    if model_name not in MODELS:
        available_models = list(MODELS.keys())
        raise ValueError(f"Unsupported model: {model_name}\nAvailable models: {available_models}")
    # Initialize lock for the model (if not already done)
    if model_name not in model_locks:
        model_locks[model_name] = threading.Lock()
    with model_locks[model_name]:
        if model_name not in models or model_name not in processors:
            try:
                start_time = time.time()
                logger.info(f"Starting {model_name} initialization...")
                log_system_info()
                model_config = MODELS[model_name]
                # Configure 8-bit quantization (bnb_4bit_* options only apply to
                # 4-bit loading, so none are set here)
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                logger.info(f"Loading {model_name} with 8-bit quantization...")
                model = model_config["model_class"].from_pretrained(
                    model_config["path"],
                    quantization_config=quantization_config,
                    device_map={"": device.index if device.type == "cuda" else "cpu"},
                    local_files_only=False
                ).eval()
                processor = AutoProcessor.from_pretrained(
                    model_config["path"],
                    local_files_only=False
                )
                models[model_name] = model
                processors[model_name] = processor
                end_time = time.time()
                logger.info(f"Model {model_name} initialized in {end_time - start_time:.2f} seconds")
                log_system_info()
            except Exception as e:
                logger.error(f"Model initialization error for {model_name}: {str(e)}", exc_info=True)
                raise RuntimeError(f"Failed to initialize model {model_name}: {str(e)}")
        # Update last use time
        last_used[model_name] = time.time()
    return models[model_name], processors[model_name]
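
# `last_used` is recorded above but nothing in this file ever unloads an idle
# model. Below is a minimal eviction sketch (illustrative only: it is defined
# but never scheduled, and the 30-minute threshold is an assumption, not from
# the original code). It could be run periodically from a background thread.
def evict_idle_models(max_idle_seconds: float = 1800.0):
    """Unload models that have not been used for `max_idle_seconds`."""
    now = time.time()
    for model_name in list(models.keys()):
        if now - last_used.get(model_name, now) > max_idle_seconds:
            with model_locks[model_name]:
                models.pop(model_name, None)
                processors.pop(model_name, None)
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                logger.info(f"Evicted idle model {model_name}")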

@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Starting application initialization...")
    try:
        yield
    finally:
        logger.info("Shutting down application...")
        global models, processors
        for model_name, model in models.items():
            try:
                del model
                logger.info(f"Model {model_name} unloaded")
            except Exception as e:
                logger.error(f"Error during cleanup of {model_name}: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("CUDA cache cleared")
        models = {}
        processors = {}
        logger.info("Shutdown complete")

app = FastAPI(
    title="Qwen2.5-VL API",
    description="OpenAI-compatible API for Qwen2.5-VL vision-language model",
    version="1.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/v1/models")
async def list_models():
    """List available models"""
    model_cards = []
    for model_name in MODELS.keys():
        model_cards.append(
            ModelCard(
                id=model_name,
                created=1709251200,
                owned_by="Qwen",
                permission=[{
                    "id": f"modelperm-{model_name}",
                    "created": 1709251200,
                    "allow_create_engine": False,
                    "allow_sampling": True,
                    "allow_logprobs": True,
                    "allow_search_indices": False,
                    "allow_view": True,
                    "allow_fine_tuning": False,
                    "organization": "*",
                    "group": None,
                    "is_blocking": False
                }],
                capabilities={
                    "vision": True,
                    "chat": True,
                    "embeddings": False,
                    "text_completion": True
                },
                context_window=4096,
                max_tokens=2048
            )
        )
    return ModelList(data=model_cards)

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completion requests with vision support"""
    try:
        # Get or initialize the requested model
        model, processor = get_or_initialize_model(request.model)
        request_start_time = time.time()
        logger.info(f"Received chat completion request for model: {request.model}")
        logger.info(f"Request content: {request.model_dump_json()}")
        messages = []
        for msg in request.messages:
            if isinstance(msg.content, str):
                messages.append({"role": msg.role, "content": msg.content})
            else:
                processed_content = []
                for content_item in msg.content:
                    if content_item.type == "text":
                        processed_content.append({
                            "type": "text",
                            "text": content_item.text
                        })
                    elif content_item.type == "image_url":
                        # Note: only base64 data URLs are handled; other URL schemes are ignored
                        if "url" in content_item.image_url:
                            if content_item.image_url["url"].startswith("data:image"):
                                processed_content.append({
                                    "type": "image",
                                    "image": process_base64_image(content_item.image_url["url"])
                                })
                messages.append({"role": msg.role, "content": processed_content})
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        # Move all tensors to the selected device
        input_tensors = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
        with torch.inference_mode():
            generated_ids = model.generate(
                **input_tensors,
                max_new_tokens=request.max_tokens,
                do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
                temperature=request.temperature,
                top_p=request.top_p,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id
            )
        # Trim the prompt tokens from the generated sequence
        input_length = input_tensors['input_ids'].shape[1]
        generated_ids_trimmed = generated_ids[:, input_length:]
        response = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        if request.response_format and request.response_format.get("type") == "json_object":
            try:
                # Strip a Markdown code fence (and optional "json" tag) before parsing
                if response.startswith('```'):
                    response = '\n'.join(response.split('\n')[1:-1])
                if response.startswith('json'):
                    response = response[4:].lstrip()
                content = json.loads(response)
                response = json.dumps(content)
            except json.JSONDecodeError as e:
                logger.error(f"JSON parsing error: {str(e)}")
                raise HTTPException(status_code=400, detail=f"Invalid JSON response: {str(e)}")
        total_time = time.time() - request_start_time
        logger.info(f"Request completed in {total_time:.2f} seconds")
        return ChatCompletionResponse(
            id=f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S')}",
            object="chat.completion",
            created=int(datetime.now().timestamp()),
            model=request.model,
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response
                },
                "finish_reason": "stop"
            }],
            usage={
                "prompt_tokens": input_length,
                "completion_tokens": len(generated_ids_trimmed[0]),
                "total_tokens": input_length + len(generated_ids_trimmed[0])
            }
        )
    except Exception as e:
        logger.error(f"Request error: {str(e)}", exc_info=True)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    log_system_info()
    return {
        "status": "healthy",
        "loaded_models": list(models.keys()),
        "device": str(device),
        "cuda_available": torch.cuda.is_available(),
        "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
        "timestamp": datetime.now().isoformat()
    }

@app.get("/status")
async def model_status():
    """Get the status of all models"""
    status = {}
    for model_name in MODELS:
        status[model_name] = {
            "loaded": model_name in models,
            "last_used": last_used.get(model_name, None),
            "available": model_name in MODELS
        }
    return status
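
# Example client call (a minimal sketch, assuming the server above is running
# locally on port 9192; the image path and prompt are illustrative, and the
# `requests` package is assumed to be installed -- it is not needed by the
# server itself). Wrapped in a function so importing this module sends nothing.
def example_client_request(image_path: str = "example.jpg") -> dict:
    """Send a sample vision chat request to /v1/chat/completions."""
    import requests
    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    payload = {
        "model": "Qwen2.5-VL-7B-Instruct",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}
            ]
        }],
        "max_tokens": 512
    }
    return requests.post("http://localhost:9192/v1/chat/completions", json=payload).json()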

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=9192)