TextTo3D / app.py
Abdulrahman1989's picture
Fix the model path
28f286f
import os
import gradio as gr
import tempfile
from SDXLImageGenerator import SDXLImageGenerator # Import your existing class
import sys
from Image3DProcessor import Image3DProcessor # Import your 3D processing class
from PIL import Image
import io
from io import BytesIO
import numpy as np
class VideoGenerator:
def __init__(self, model_cfg_path, model_filename):
# Initialize the Image3DProcessor
self.processor = Image3DProcessor(model_cfg_path, model_filename)
def generate_3d_video(self, image):
# Ensure the image is a PIL Image object
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
# Preprocess the image first
processed_image = self.processor.preprocess(image)
# Then pass it to reconstruct_and_export
video_data = self.processor.reconstruct_and_export(processed_image)
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as video_file:
video_file.write(video_data)
video_path = video_file.name
return video_path
class GradioApp:
def __init__(self):
self.sdxl_generator = SDXLImageGenerator()
# Initialize VideoGenerator with required paths and details
self.video_generator = VideoGenerator(
model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
model_filename="/home/user/app/fused_text_to_3D.pth"
)
def launch(self):
with gr.Blocks(css="#small_video { width: 400px !important; height: 300px !important; }") as interface:
# Input for the prompt at the top
prompt_input = gr.Textbox(label="Input Prompt", elem_id="input_textbox")
# Button for generating the 3D object
generate_3d_object = gr.Button("Generate 3D object")
# Outputs: image on the bottom left, video on the bottom right
with gr.Row():
with gr.Column():
image_output = gr.Image(label="Generated Image", elem_id="generated_image")
with gr.Column():
video_output = gr.Video(label="3D Model Video", elem_id="small_video")
# Generate the image first
def generate_image_and_display(prompt):
modified_prompt = prompt + ", isolated, on a plain background, minimal, no extra objects"
print(modified_prompt)
# Generate the image from the prompt
image_data = self.sdxl_generator.generate_image([modified_prompt])[0]
return Image.open(BytesIO(image_data))
# Generate the 3D after the image is ready
def generate_3D_from_image(image):
# Ensure the image is a PIL Image object
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
# Generate the 3D from the generated image
return self.video_generator.generate_3d_video(image)
# First click generates the image
generate_3d_object.click(
fn=generate_image_and_display,
inputs=prompt_input,
outputs=image_output,
queue=True
)
# Once the image is ready, generate the video
image_output.change(
fn=generate_3D_from_image,
inputs=image_output,
outputs=video_output,
queue=True
)
interface.launch(share=True)
if __name__ == "__main__":
app = GradioApp()
app.launch()