import gradio as gr
from transformers import pipeline, AutoProcessor, LlavaForConditionalGeneration
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
import cv2
import soundfile as sf
import time
from PIL import Image
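# Assumed stack (not pinned or tested here): gradio 4.x, a transformers release
# with LLaVA support, diffusers, CUDA-enabled torch, opencv-python, soundfile,
# and pillow.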
# --- Set up Models ---

# Stable Diffusion for image generation
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to("cuda")
# LLaVA for vision-based language understanding. LLaVA checkpoints need the
# multimodal processor and LlavaForConditionalGeneration, not a plain text
# tokenizer with a causal LM head.
processor = AutoProcessor.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
model = LlavaForConditionalGeneration.from_pretrained(
    "xtuner/llava-llama-3-8b-v1_1-transformers",
    torch_dtype=torch.float16,
).to("cuda")
# Open-source language model for text generation (e.g., GPT-Neo), on GPU
gpt_neo_pipe = pipeline("text-generation", model="EleutherAI/gpt-neo-1.3B", device=0)
# Text-to-Speech. The original espnet/fastspeech2_en_ljspeech checkpoint is an
# ESPnet model that the transformers pipeline cannot load, so a
# transformers-native TTS model is substituted here.
text_to_speech = pipeline("text-to-speech", model="facebook/mms-tts-eng")
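# The text-to-speech pipeline returns a dict with an "audio" waveform and its
# "sampling_rate", which the handlers below pass to soundfile.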
# --- Functions ---

def process_image(image, chat_history):
    """Describes an uploaded image with LLaVA and speaks the response."""
    if image is None:
        return chat_history, None
    # LLaVA takes the raw image alongside a prompt containing the <image>
    # placeholder; the processor aligns the two.
    prompt = "<image>\nWhat do you see in this image?"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", torch.float16)
    # Generate a response with LLaVA, decoding only the newly generated tokens
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=200)
    response = processor.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    # Generate speech from the response
    speech = text_to_speech(response)
    audio_path = "generated_audio.wav"
    sf.write(audio_path, speech["audio"].squeeze(), samplerate=speech["sampling_rate"])
    # Update chat history
    chat_history += "You: Image\n"
    chat_history += "Model: " + response + "\n"
    return chat_history, audio_path
def generate_image(prompt, chat_history):
    """Generates an image with Stable Diffusion from a text prompt."""
    # guidance_scale trades prompt adherence against diversity; 7.5 is the
    # commonly used default for Stable Diffusion 2.x
    image = pipe(
        prompt=prompt,
        guidance_scale=7.5,
        num_inference_steps=50,
    ).images[0]
    # Update chat history
    chat_history += "You: " + prompt + "\n"
    chat_history += "Model: Image\n"
    return chat_history, image
def process_text(text, chat_history):
    """Generates a response with GPT-Neo and speaks it."""
    # Note: the pipeline's generated_text includes the prompt itself
    response = gpt_neo_pipe(
        text,
        max_length=100,
        num_return_sequences=1,
    )[0]["generated_text"]
    # Generate speech from the response
    speech = text_to_speech(response)
    audio_path = "generated_audio.wav"
    sf.write(audio_path, speech["audio"].squeeze(), samplerate=speech["sampling_rate"])
    # Update chat history
    chat_history += "You: " + text + "\n"
    chat_history += "Model: " + response + "\n"
    return chat_history, audio_path
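# Both handlers above write to the same "generated_audio.wav"; with concurrent
# users, a per-request tempfile.NamedTemporaryFile would be safer.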
# --- Webcam Capture ---

def capture_image():
    """Captures a single frame from the default webcam."""
    cap = cv2.VideoCapture(0)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    # OpenCV delivers BGR frames; convert to RGB before handing off to PIL
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# --- Gradio Interface ---

with gr.Blocks() as demo:
    gr.Markdown("## Llama-LLaVA Vision Speech Assistant")
    chat_history = gr.Textbox(label="Chat History", lines=10, interactive=False)
    image_input = gr.Image(label="Uploaded Image", type="pil")
    text_input = gr.Textbox(label="Enter Text")
    audio_output = gr.Audio(label="Audio Response")

    # Screenshot button
    screenshot_button = gr.Button("Capture Screenshot")
    screenshot_button.click(fn=capture_image, outputs=image_input)

    # Image processing (LLaVA)
    image_input.change(fn=process_image, inputs=[image_input, chat_history], outputs=[chat_history, audio_output])

    # Text processing (GPT-Neo)
    text_input.submit(fn=process_text, inputs=[text_input, chat_history], outputs=[chat_history, audio_output])
    # Image generation (Stable Diffusion)
    with gr.Tab("Image Generation"):
        image_prompt = gr.Textbox(label="Enter image prompt:")
        image_generation_output = gr.Image(label="Generated Image")
        generate_image_button = gr.Button("Generate Image")
        generate_image_button.click(
            fn=generate_image, inputs=[image_prompt, chat_history], outputs=[chat_history, image_generation_output]
        )
    # Webcam stream
    with gr.Tab("Webcam"):
        webcam_output = gr.Image(label="Webcam Feed", interactive=False)

        # Gradio treats a generator event handler as a stream of updates, so
        # yielding a frame roughly once per second gives a live feed.
        def update_webcam():
            cap = cv2.VideoCapture(0)
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    time.sleep(1)  # update every second
            finally:
                cap.release()

        demo.load(fn=update_webcam, outputs=webcam_output)

demo.launch(share=True)
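# Quick smoke test without the UI (hypothetical prompts; assumes the models
# above loaded successfully):
#   history, img = generate_image("a watercolor fox", "")
#   history, wav = process_text("Hello there", history)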