# TextTo3D / app.py
# Author: Abdulrahman1989
# Last commit: 60367c0 "Fix initial_image"
# (Scraped file-viewer header: raw | history | blame; file size 2.65 kB)
import os
import gradio as gr
import tempfile
from SDXLImageGenerator import SDXLImageGenerator # Import your existing class
import sys
from Image3DProcessor import Image3DProcessor # Import your 3D processing class
from PIL import Image
import io
class ControlNetProcessor:
    """Stand-in for a ControlNet refinement stage.

    No ControlNet model is applied yet: every image is handed back untouched
    so the rest of the pipeline can be exercised end to end.
    """

    def controlnet_image(self, image):
        """Return *image* unchanged (placeholder for real ControlNet processing)."""
        refined = image
        return refined
class VideoGenerator:
    """Thin wrapper that turns one image into 3D output via ``Image3DProcessor``."""

    def __init__(self, model_cfg_path, model_repo_id, model_filename):
        """Build the underlying ``Image3DProcessor`` from the given model settings.

        Args:
            model_cfg_path: Config path forwarded to ``Image3DProcessor``.
            model_repo_id: Model repository id forwarded to ``Image3DProcessor``.
            model_filename: Checkpoint file name forwarded to ``Image3DProcessor``.
        """
        processor_args = (model_cfg_path, model_repo_id, model_filename)
        self.processor = Image3DProcessor(*processor_args)

    def generate_3d_video(self, image):
        """Preprocess *image*, then reconstruct and export it.

        Returns:
            Whatever ``reconstruct_and_export`` produces (the caller in this file
            writes it to a ``.mp4`` file, so presumably encoded video bytes —
            TODO confirm against ``Image3DProcessor``).
        """
        proc = self.processor
        prepared = proc.preprocess(image)
        return proc.reconstruct_and_export(prepared)
class GradioApp:
    """Gradio front end for the text -> image -> ControlNet -> 3D video pipeline."""

    def __init__(
        self,
        model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
        model_repo_id="szymanowiczs/splatter-image-multi-category-v1",
        model_filename="model_latest.pth",
    ):
        """Construct the SDXL generator, ControlNet stage and 3D video generator.

        Args:
            model_cfg_path: Config path for the 3D model. The default is the
                deployment path that used to be hard-coded here.
            model_repo_id: Model repository id passed to ``VideoGenerator``.
            model_filename: Checkpoint file name passed to ``VideoGenerator``.
        """
        self.sdxl_generator = SDXLImageGenerator()
        self.controlnet_processor = ControlNetProcessor()
        self.video_generator = VideoGenerator(
            model_cfg_path=model_cfg_path,
            model_repo_id=model_repo_id,
            model_filename=model_filename,
        )

    def full_pipeline(self, prompt):
        """Run the whole pipeline for a single prompt.

        Args:
            prompt: Text prompt forwarded to the SDXL generator.

        Returns:
            ``(image, video_path)``: the generated image, decoded with PIL when
            the generator returned encoded bytes, and the path of a temporary
            ``.mp4`` file containing the generated video data.
        """
        generated = self.sdxl_generator.generate_images([prompt])[0]
        # ControlNet stage promised by the UI title/description. It is
        # currently a pass-through, so the 3D input is unchanged.
        refined = self.controlnet_processor.controlnet_image(generated)
        video_data = self.video_generator.generate_3d_video(refined)
        # delete=False: the path is returned to the gr.Video output, so the
        # file has to outlive this function call.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as video_file:
            video_file.write(video_data)
            video_path = video_file.name
        return self._to_display_image(generated), video_path

    @staticmethod
    def _to_display_image(image):
        """Decode encoded image bytes with PIL; return anything else unchanged.

        Calling ``io.BytesIO`` on an already-decoded image raises ``TypeError``.
        Only bytes-like values are therefore decoded.
        """
        if isinstance(image, (bytes, bytearray, memoryview)):
            return Image.open(io.BytesIO(image))
        return image

    def launch(self, share=True):
        """Build the Gradio interface and start serving it.

        Args:
            share: Forwarded to ``Interface.launch``. ``True`` (the previous
                hard-coded value) requests a public share link.
        """
        interface = gr.Interface(
            fn=self.full_pipeline,
            inputs=gr.Textbox(label="Input Prompt"),
            outputs=[
                gr.Image(label="Generated Image"),
                gr.Video(label="3D Model Video"),
            ],
            title="SDXL to ControlNet to 3D Pipeline",
            description="Generate an image using SDXL, refine it with ControlNet, and generate a 3D video output.",
        )
        interface.launch(share=share)
if __name__ == "__main__":
    # Script entry point: build the app (loads all models) and start serving.
    app = GradioApp()
    app.launch()