"""Gradio front end: text prompt -> SDXL image -> 3D model video.

The prompt is turned into an image by ``SDXLImageGenerator``. That image is
passed to an ``Image3DProcessor`` (configured for the splatter-image checkpoint
below), and the exported video is shown next to the generated image.
"""

import io
import os
import sys
import tempfile

import gradio as gr
from PIL import Image

from Image3DProcessor import Image3DProcessor  # Import your 3D processing class
from SDXLImageGenerator import SDXLImageGenerator  # Import your existing class

# Default splatter-image settings; can be overridden via GradioApp(...).
DEFAULT_MODEL_CFG_PATH = "/home/user/app/splatter-image/gradio_config.yaml"
DEFAULT_MODEL_REPO_ID = "szymanowiczs/splatter-image-multi-category-v1"
DEFAULT_MODEL_FILENAME = "model_latest.pth"


class VideoGenerator:
    """Thin wrapper that turns a single image into 3D video data."""

    def __init__(self, model_cfg_path, model_repo_id, model_filename):
        """Build the underlying Image3DProcessor.

        Args:
            model_cfg_path: Path to the processor's config file.
            model_repo_id: Model repository id passed to Image3DProcessor.
            model_filename: Checkpoint filename within that repository.
        """
        self.processor = Image3DProcessor(model_cfg_path, model_repo_id, model_filename)

    def generate_3d_video(self, image):
        """Preprocess *image*, then return ``reconstruct_and_export``'s result.

        The return value is whatever Image3DProcessor.reconstruct_and_export
        produces. GradioApp handles either encoded video bytes or a file path
        (NOTE(review): confirm the actual return type in Image3DProcessor).
        """
        processed_image = self.processor.preprocess(image)
        return self.processor.reconstruct_and_export(processed_image)


class GradioApp:
    """Wire the SDXL image generator and the 3D video generator into a Gradio UI."""

    def __init__(
        self,
        model_cfg_path=DEFAULT_MODEL_CFG_PATH,
        model_repo_id=DEFAULT_MODEL_REPO_ID,
        model_filename=DEFAULT_MODEL_FILENAME,
    ):
        """Create both generators; model arguments default to the original settings."""
        self.sdxl_generator = SDXLImageGenerator()  # Use your existing class
        self.video_generator = VideoGenerator(
            model_cfg_path=model_cfg_path,
            model_repo_id=model_repo_id,
            model_filename=model_filename,
        )

    @staticmethod
    def _to_display_image(image):
        """Return a value gr.Image can render from the generator's output.

        Encoded image bytes are decoded to a PIL image, as before. Any other
        value (PIL image, array, file path) is passed through unchanged
        instead of failing inside ``io.BytesIO``.
        """
        if isinstance(image, (bytes, bytearray, memoryview)):
            decoded = Image.open(io.BytesIO(bytes(image)))
            # Image.open is lazy; decode now so corrupt data fails here, not in Gradio.
            decoded.load()
            return decoded
        return image

    @staticmethod
    def _video_to_path(video_data):
        """Return a filesystem path gr.Video can serve for *video_data*.

        A str/PathLike result is used directly. Otherwise the data is treated
        as encoded MP4 bytes and written to a temporary file.
        """
        if isinstance(video_data, (str, os.PathLike)):
            return os.fspath(video_data)
        # delete=False: Gradio reads the file after this function returns.
        # These temp files are not cleaned up here.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as video_file:
            video_file.write(video_data)
            return video_file.name

    def full_pipeline(self, prompt):
        """Generate an image from *prompt* and a 3D video from that image.

        Returns:
            (image, video_path) for the gr.Image and gr.Video outputs.

        Raises:
            gr.Error: if the prompt is empty or only whitespace.
        """
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a prompt.")

        initial_image = self.sdxl_generator.generate_images([prompt])[0]
        # The raw generator output goes to the 3D stage, exactly as before.
        video_data = self.video_generator.generate_3d_video(initial_image)
        video_path = self._video_to_path(video_data)
        return self._to_display_image(initial_image), video_path

    def launch(self, share=True, **launch_kwargs):
        """Build the Blocks UI and start the server.

        Args:
            share: Forwarded to ``Blocks.launch`` (default True, as before).
            **launch_kwargs: Any further ``Blocks.launch`` keyword arguments.
        """
        with gr.Blocks() as interface:
            prompt_input = gr.Textbox(label="Input Prompt", elem_id="input_textbox")
            generate_button = gr.Button("Generate")
            with gr.Row():
                image_output = gr.Image(label="Generated Image", elem_id="generated_image")
                video_output = gr.Video(label="3D Model Video", elem_id="model_video")
            generate_button.click(
                fn=self.full_pipeline,
                inputs=prompt_input,
                outputs=[image_output, video_output],
            )
        interface.launch(share=share, **launch_kwargs)


if __name__ == "__main__":
    app = GradioApp()
    app.launch()