import torch
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from qwen_vl_utils import process_vision_info
from PIL import Image
import cv2
import numpy as np
import gradio as gr
import spaces
from huggingface_hub import login
import os

# Quota management constants (currently unused in this file; see the illustrative sketch below)
MAX_GPU_TIME_PER_REQUEST = 59  # seconds
COOLDOWN_PERIOD = 300  # 5 minutes in seconds
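
# The two constants above are declared but not referenced anywhere else in this
# file. Below is a minimal sketch of how a per-session cooldown could be enforced,
# assuming requests are keyed by some session id; the helper and the
# `last_request_times` dict are illustrative additions and are not wired into the app.
import time

last_request_times = {}  # hypothetical: session id -> timestamp of its last request

def cooldown_remaining(session_id):
    """Return the seconds a session must still wait before its next request (0 if none)."""
    last = last_request_times.get(session_id)
    if last is None:
        return 0.0
    return max(0.0, COOLDOWN_PERIOD - (time.time() - last))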

# Authenticate with the Hugging Face Hub before any models are loaded
def init_huggingface_auth():
    # Get the token from the environment
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        login(token=token)
        print("Successfully authenticated with Hugging Face")
    else:
        raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")

# Load both models and their processors/tokenizers
def load_models():
    try:
        # Initialize HF auth before loading models
        init_huggingface_auth()

        # Vision model
        vision_model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-2B-Instruct",
            torch_dtype=torch.float16,
            device_map="auto",
            use_auth_token=True,  # use the token provided via login()
        )
        vision_processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen2-VL-2B-Instruct",
            use_auth_token=True,
        )

        # Code model
        code_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen2.5-Coder-1.5B-Instruct",
            torch_dtype=torch.float16,
            device_map="auto",
            use_auth_token=True,
        )
        code_tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2.5-Coder-1.5B-Instruct",
            use_auth_token=True,
        )

        # Free up CUDA memory after loading
        torch.cuda.empty_cache()
        return vision_model, vision_processor, code_model, code_tokenizer
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise
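
# Models and processors are loaded once at import time; the Gradio handlers below
# reuse these module-level objects, so weights are not reloaded on every request.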
vision_model, vision_processor, code_model, code_tokenizer = load_models()

VISION_SYSTEM_PROMPT = """Extract code from images/videos:
1. Output exact code snippets only
2. Keep original formatting/indentation
3. Focus on code-relevant frames only
[code]
If multiple code sections are visible, separate them with ---
Note: Videos may contain irrelevant frames (e.g., other window tabs, the eterniq website, etc.). Focus only on the frames that show code, since that is the content to extract.
"""
CODE_SYSTEM_PROMPT = """Debug code as an expert:
- Analyze OCR-extracted code + user's issue
- Find bugs/issues
- Provide fixes
- Explain corrections
Output:
Fixed Code:
[corrected code]
Original Issue:
[brief analysis]
Note: Please provide the output in a well-structured Markdown format. Remove all unnecessary information and exclude any additional code formatting such as triple backticks or language identifiers.
"""

def process_video_for_code(video_path, transcribed_text, max_frames=16, frame_interval=30):
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while len(frames) < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count % frame_interval == 0:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        frame_count += 1
    cap.release()

    if not frames:
        return "No frames could be extracted from the video.", "No code could be analyzed."

    # Process all frames
    vision_descriptions = []
    for frame in frames:
        vision_description = process_image_for_vision(frame, transcribed_text)
        vision_descriptions.append(vision_description)

    # Combine all vision descriptions
    combined_vision_description = "\n\n".join(vision_descriptions)

    # Use code model to fix the code based on combined description
    fixed_code_response = process_for_code(combined_vision_description)
    return combined_vision_description, fixed_code_response

def process_image_for_vision(image, transcribed_text):
    vision_messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": f"{VISION_SYSTEM_PROMPT}\n\nDescribe the code and any errors you see in this image. User's description: {transcribed_text}"},
            ],
        }
    ]

    vision_text = vision_processor.apply_chat_template(
        vision_messages,
        tokenize=False,
        add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(vision_messages)
    vision_inputs = vision_processor(
        text=[vision_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(vision_model.device)

    with torch.no_grad():
        vision_output_ids = vision_model.generate(**vision_inputs, max_new_tokens=512)
    vision_output_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(vision_inputs.input_ids, vision_output_ids)
    ]
    return vision_processor.batch_decode(
        vision_output_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]

def process_for_code(vision_description):
    code_messages = [
        {"role": "system", "content": CODE_SYSTEM_PROMPT},
        {"role": "user", "content": f"Here's a description of code with errors:\n\n{vision_description}\n\nPlease analyze and fix the code."}
    ]

    code_text = code_tokenizer.apply_chat_template(
        code_messages,
        tokenize=False,
        add_generation_prompt=True
    )
    code_inputs = code_tokenizer([code_text], return_tensors="pt").to(code_model.device)

    with torch.no_grad():
        code_output_ids = code_model.generate(
            **code_inputs,
            max_new_tokens=1024,
            do_sample=True,  # required for temperature/top_p to take effect
            temperature=0.7,
            top_p=0.95,
        )
    code_output_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(code_inputs.input_ids, code_output_ids)
    ]
    return code_tokenizer.batch_decode(
        code_output_trimmed,
        skip_special_tokens=True
    )[0]

@spaces.GPU
def process_content(video, transcribed_text):
    try:
        if video is None:
            return "Please upload a video file of code with errors.", ""

        # GPU memory management
        torch.cuda.empty_cache()
        # Check free GPU memory before running the models
        if torch.cuda.is_available():
            free_memory, _ = torch.cuda.mem_get_info()
            if free_memory < 1e9:  # less than 1 GB free
                raise RuntimeError("Insufficient GPU memory available")

        vision_output, code_output = process_video_for_code(
            video.name,
            transcribed_text,
            max_frames=8  # reduced from 16 to lower GPU usage
        )
        return vision_output, code_output
    except spaces.zero.gradio.HTMLError as e:
        if "exceeded your GPU quota" in str(e):
            return (
                "GPU quota exceeded. Please try again later or consider upgrading to a paid plan.",
                ""
            )
        return f"Error processing content: {str(e)}", ""
    except Exception as e:
        return f"Error processing content: {str(e)}", ""
    finally:
        # Clean up GPU memory
        torch.cuda.empty_cache()

# Gradio interface
iface = gr.Interface(
    fn=process_content,
    inputs=[
        gr.File(label="Upload Video of Code with Errors"),
        gr.Textbox(label="Transcribed Audio")
    ],
    outputs=[
        gr.Textbox(label="Vision Model Output (Code Description)"),
        gr.Code(label="Fixed Code", language="python")
    ],
    title="Vision Code Debugger",
    description="Upload a video of code with errors and provide the transcribed audio; the AI will analyze and fix the issues.",
    allow_flagging="never",  # disable flagging to reduce overhead
    cache_examples=True  # enable example caching to reduce GPU usage
)

if __name__ == "__main__":
    iface.launch(show_error=True)