Spaces:
Paused
Paused
import os | |
print("Current working directory:", os.getcwd()) | |
import gradio as gr | |
import tempfile | |
from SDXLImageGenerator import SDXLImageGenerator # Import your existing class | |
import sys | |
from Image3DProcessor import Image3DProcessor # Import your 3D processing class | |
class ControlNetProcessor:
    """Stand-in for a ControlNet refinement stage.

    No ControlNet processing happens yet: the single method hands back the
    object it is given, so the stage can sit in a pipeline without effect.
    """

    def controlnet_image(self, image):
        """Return *image* unchanged (placeholder for real ControlNet processing)."""
        passthrough = image
        return passthrough
class VideoGenerator:
    """Turn a 2D image into 3D outputs using an ``Image3DProcessor``."""

    def __init__(self, model_cfg_path, model_repo_id, model_filename):
        """Create the wrapped ``Image3DProcessor``.

        Args:
            model_cfg_path: Forwarded positionally to ``Image3DProcessor``.
            model_repo_id: Forwarded positionally to ``Image3DProcessor``.
            model_filename: Forwarded positionally to ``Image3DProcessor``.
        """
        self.processor = Image3DProcessor(
            model_cfg_path,
            model_repo_id,
            model_filename,
        )

    def generate_3d_video(self, image):
        """Preprocess *image*, then reconstruct and export it.

        Returns:
            The two items produced by ``reconstruct_and_export``, as a
            ``(mesh_data, video_data)`` tuple.
        """
        proc = self.processor
        prepared = proc.preprocess(image)
        # Unpack explicitly so a result that is not exactly two items fails here.
        mesh_data, video_data = proc.reconstruct_and_export(prepared)
        return mesh_data, video_data
class GradioApp:
    """Gradio front end: text prompt -> SDXL image -> 3D mesh and video.

    The prompt is rendered by ``SDXLImageGenerator``. The resulting image is
    passed to ``VideoGenerator``, whose mesh/video outputs are exposed to the
    UI as file paths.
    """

    # Defaults reproduce the configuration this Space originally hard-coded.
    DEFAULT_MODEL_CFG_PATH = "/home/user/app/splatter-image/gradio_config.yaml"
    DEFAULT_MODEL_REPO_ID = "szymanowiczs/splatter-image-multi-category-v1"
    DEFAULT_MODEL_FILENAME = "model_latest.pth"

    def __init__(
        self,
        model_cfg_path=DEFAULT_MODEL_CFG_PATH,
        model_repo_id=DEFAULT_MODEL_REPO_ID,
        model_filename=DEFAULT_MODEL_FILENAME,
    ):
        """Construct the image generator, ControlNet stage and 3D generator.

        Args:
            model_cfg_path: Config path forwarded to ``VideoGenerator``.
            model_repo_id: Model repository id forwarded to ``VideoGenerator``.
            model_filename: Checkpoint file name forwarded to ``VideoGenerator``.
        """
        self.sdxl_generator = SDXLImageGenerator()
        self.controlnet_processor = ControlNetProcessor()
        self.video_generator = VideoGenerator(
            model_cfg_path=model_cfg_path,
            model_repo_id=model_repo_id,
            model_filename=model_filename,
        )

    @staticmethod
    def _as_servable_path(data, suffix):
        """Return a filesystem path for *data* that Gradio can serve.

        An existing file path (``str`` or ``os.PathLike``) is returned as-is.
        Anything else is written to a new named temporary file ending in
        *suffix*. That file is opened in binary mode, so *data* must be
        bytes-like.
        """
        # NOTE(review): the processor's return types are not visible here.
        # Accept a ready-made file path as well as raw bytes.
        if isinstance(data, (str, os.PathLike)) and os.path.isfile(data):
            return os.fspath(data)
        # delete=False keeps the file after close so Gradio can read it once
        # full_pipeline returns. NOTE(review): these files are never removed
        # and accumulate across requests.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(data)
            return tmp.name

    def full_pipeline(self, prompt):
        """Run the prompt through SDXL and then 3D reconstruction.

        Args:
            prompt: Text prompt from the UI textbox.

        Returns:
            ``(image, mesh_path, video_path)``, matching the three output
            components declared in ``launch``.
        """
        initial_image = self.sdxl_generator.generate_images([prompt])[0]
        # NOTE: the ControlNet refinement step is currently disabled; the SDXL
        # image goes straight to 3D reconstruction.
        mesh_data, video_data = self.video_generator.generate_3d_video(initial_image)
        mesh_path = self._as_servable_path(mesh_data, ".ply")
        video_path = self._as_servable_path(video_data, ".mp4")
        return initial_image, mesh_path, video_path

    def launch(self, share=True, **launch_kwargs):
        """Build the Gradio interface and start serving it.

        Args:
            share: Forwarded to ``Interface.launch``. The default ``True``
                (the original behaviour) requests a public share link.
                NOTE(review): Gradio does not create share links when running
                on Hugging Face Spaces.
            **launch_kwargs: Extra keyword arguments forwarded to
                ``Interface.launch`` (e.g. ``server_name``, ``server_port``).
        """
        interface = gr.Interface(
            fn=self.full_pipeline,
            inputs=gr.Textbox(label="Input Prompt"),
            outputs=[
                gr.Image(label="Generated Image"),
                gr.File(label="3D Mesh (.ply)"),
                gr.Video(label="3D Model Video"),
            ],
            # NOTE(review): the title and description mention ControlNet
            # refinement, but full_pipeline does not apply it.
            title="SDXL to ControlNet to 3D Pipeline",
            description="Generate an image using SDXL, refine it with ControlNet, and generate a 3D video output.",
        )
        interface.launch(share=share, **launch_kwargs)
if __name__ == "__main__":
    # Construct the app and start the Gradio server.
    GradioApp().launch()