"""Gradio front end: text prompt -> SDXL image -> 3D reconstruction video."""

import io
import os
import sys
import tempfile
from io import BytesIO

import gradio as gr
import numpy as np
from PIL import Image

from SDXLImageGenerator import SDXLImageGenerator  # Import your existing class
from Image3DProcessor import Image3DProcessor  # Import your 3D processing class

# Default locations of the 3D reconstruction model config and checkpoint.
DEFAULT_MODEL_CFG_PATH = "/home/user/app/splatter-image/gradio_config.yaml"
DEFAULT_MODEL_FILENAME = "/home/user/app/fused_text_to_3D.pth"

# Appended to every user prompt to steer the image generator towards a single
# isolated object on a plain background.
PROMPT_SUFFIX = ", isolated, on a plain background, minimal, no extra objects"


def _to_pil(image):
    """Return *image* as a PIL Image, converting from a NumPy array if needed.

    Values that are not ``np.ndarray`` are returned unchanged.
    """
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    return image


class VideoGenerator:
    """Turn a single image into a video file using ``Image3DProcessor``."""

    def __init__(self, model_cfg_path, model_filename):
        """Create the 3D processor from a config path and a model file path."""
        self.processor = Image3DProcessor(model_cfg_path, model_filename)

    def generate_3d_video(self, image):
        """Reconstruct *image* in 3D and write the exported video to disk.

        Args:
            image: A PIL Image, or a NumPy array that is converted to PIL first.

        Returns:
            Path to a temporary ``.mp4`` file containing the data returned by
            ``Image3DProcessor.reconstruct_and_export``. The file is created
            with ``delete=False`` so it still exists after this method returns
            (Gradio serves it from that path); nothing here deletes it.
        """
        image = _to_pil(image)
        # Preprocess first, then reconstruct from the preprocessed image.
        processed_image = self.processor.preprocess(image)
        # NOTE(review): assumes reconstruct_and_export returns the raw video
        # bytes (it is written verbatim to the file below) — confirm.
        video_data = self.processor.reconstruct_and_export(processed_image)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as video_file:
            video_file.write(video_data)
        return video_file.name


class GradioApp:
    """Two-step UI: a prompt produces an image, and that image produces a 3D video."""

    def __init__(
        self,
        model_cfg_path=DEFAULT_MODEL_CFG_PATH,
        model_filename=DEFAULT_MODEL_FILENAME,
    ):
        """Create the image generator and the 3D video generator.

        Args:
            model_cfg_path: Config file passed to ``Image3DProcessor``.
            model_filename: Model file passed to ``Image3DProcessor``.
        """
        self.sdxl_generator = SDXLImageGenerator()
        self.video_generator = VideoGenerator(
            model_cfg_path=model_cfg_path,
            model_filename=model_filename,
        )

    def _generate_image(self, prompt):
        """Generate an image for *prompt* (with ``PROMPT_SUFFIX`` appended).

        Raises:
            gr.Error: If the prompt is empty or whitespace only, so no
                generation runs on the suffix text alone.
        """
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a prompt.")
        modified_prompt = prompt + PROMPT_SUFFIX
        print(modified_prompt)
        # NOTE(review): assumes generate_image takes a list of prompts and
        # returns a list of encoded image bytes — confirm.
        image_data = self.sdxl_generator.generate_image([modified_prompt])[0]
        return Image.open(BytesIO(image_data))

    def _generate_3d(self, image):
        """Build the 3D video for the image currently shown, or clear it.

        Returns:
            Path to the generated video, or ``None`` when there is no image.
        """
        # The image component's change event also fires when it is cleared;
        # there is nothing to reconstruct then, so clear the video instead of
        # passing None into the 3D processor.
        if image is None:
            return None
        return self.video_generator.generate_3d_video(_to_pil(image))

    def launch(self, share=True):
        """Build the Blocks UI and start the Gradio server.

        Args:
            share: Forwarded to ``Blocks.launch(share=...)``.
        """
        with gr.Blocks(
            css="#small_video { width: 400px !important; height: 300px !important; }"
        ) as interface:
            # Input for the prompt at the top.
            prompt_input = gr.Textbox(label="Input Prompt", elem_id="input_textbox")
            # Button for generating the 3D object.
            generate_3d_object = gr.Button("Generate 3D object")

            # Outputs: image on the bottom left, video on the bottom right.
            with gr.Row():
                with gr.Column():
                    image_output = gr.Image(
                        label="Generated Image", elem_id="generated_image"
                    )
                with gr.Column():
                    video_output = gr.Video(
                        label="3D Model Video", elem_id="small_video"
                    )

            # The click generates the image...
            generate_3d_object.click(
                fn=self._generate_image,
                inputs=prompt_input,
                outputs=image_output,
                queue=True,
            )
            # ...and any change of the image then triggers the 3D video.
            image_output.change(
                fn=self._generate_3d,
                inputs=image_output,
                outputs=video_output,
                queue=True,
            )

        interface.launch(share=share)


if __name__ == "__main__":
    app = GradioApp()
    app.launch()