File size: 2,954 Bytes
fc352b3
 
a15c0ab
30f0bac
ad5cb25
583955f
66a8046
 
30f0bac
 
17397c2
66a8046
30f0bac
 
66a8046
 
 
 
 
673639e
 
 
 
9d98595
30f0bac
 
 
ad5cb25
30f0bac
66a8046
 
27c0c4f
66a8046
 
 
30f0bac
 
ad5cb25
b2ebb9a
9d98595
 
531cd4b
b2ebb9a
 
 
9d98595
 
 
 
 
b2ebb9a
30f0bac
 
 
 
17397c2
30f0bac
ad5cb25
b2ebb9a
30f0bac
 
 
 
 
b2ebb9a
a15c0ab
 
30f0bac
673639e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
print("Current working directory:", os.getcwd())
import gradio as gr
import tempfile
from SDXLImageGenerator import SDXLImageGenerator  # Import your existing class
import sys
from Image3DProcessor import Image3DProcessor  # Import your 3D processing class

class ControlNetProcessor:
    """Stand-in for a ControlNet refinement stage.

    No ControlNet model is involved yet: the processor is an identity
    transform that hands back whatever image it receives.
    """

    def controlnet_image(self, image):
        """Return *image* unchanged (placeholder for real ControlNet output)."""
        refined = image  # identity until a ControlNet model is wired in
        return refined

class VideoGenerator:
    """Thin wrapper that turns a single image into 3-D artifacts.

    All work is delegated to an ``Image3DProcessor``: the image is first
    preprocessed, then reconstructed and exported.
    """

    def __init__(self, model_cfg_path, model_repo_id, model_filename):
        # One processor per generator; every call below reuses it.
        self.processor = Image3DProcessor(
            model_cfg_path,
            model_repo_id,
            model_filename,
        )

    def generate_3d_video(self, image):
        """Preprocess *image*, reconstruct it, and return the exported pair.

        Returns:
            ``(mesh_data, video_data)`` exactly as produced by
            ``Image3DProcessor.reconstruct_and_export``. Whether these are raw
            bytes or file paths is defined by that class (not visible here).
        """
        prepared = self.processor.preprocess(image)
        mesh, video = self.processor.reconstruct_and_export(prepared)
        return mesh, video

class GradioApp:
    """Gradio front end: text prompt -> SDXL image -> 3-D mesh and video.

    A ``ControlNetProcessor`` is constructed, but its stage is currently
    bypassed in ``full_pipeline``.
    """

    # Defaults reproduce the original hard-coded deployment settings.
    DEFAULT_MODEL_CFG_PATH = "/home/user/app/splatter-image/gradio_config.yaml"
    DEFAULT_MODEL_REPO_ID = "szymanowiczs/splatter-image-multi-category-v1"
    DEFAULT_MODEL_FILENAME = "model_latest.pth"

    def __init__(
        self,
        model_cfg_path=DEFAULT_MODEL_CFG_PATH,
        model_repo_id=DEFAULT_MODEL_REPO_ID,
        model_filename=DEFAULT_MODEL_FILENAME,
    ):
        """Create the SDXL, ControlNet and 3-D video stages.

        Args:
            model_cfg_path: Config file passed to ``VideoGenerator``.
            model_repo_id: Model repository id passed to ``VideoGenerator``.
            model_filename: Checkpoint filename passed to ``VideoGenerator``.
        """
        self.sdxl_generator = SDXLImageGenerator()
        self.controlnet_processor = ControlNetProcessor()
        self.video_generator = VideoGenerator(
            model_cfg_path=model_cfg_path,
            model_repo_id=model_repo_id,
            model_filename=model_filename,
        )

    @staticmethod
    def _to_served_file(data, suffix):
        """Return a filesystem path Gradio can serve for *data*.

        ``str`` / ``os.PathLike`` values are treated as existing file paths and
        passed through. Writing them to a binary temp file, as before, raised
        ``TypeError``. Anything else is written as bytes to a new temporary
        file with *suffix*. The file is kept on disk (``delete=False``) so
        Gradio can read it after this returns. If the write fails, the file is
        removed and the error is re-raised.
        """
        if isinstance(data, (str, os.PathLike)):
            return os.fspath(data)
        handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            with handle:
                handle.write(data)
        except BaseException:
            # Don't leak an empty/partial temp file when the write fails.
            os.remove(handle.name)
            raise
        return handle.name

    def full_pipeline(self, prompt):
        """Run the full prompt -> image -> 3-D pipeline for the UI.

        Args:
            prompt: Text prompt for SDXL; only the first generated image is used.

        Returns:
            ``(image, mesh_path, video_path)``, matching the Interface outputs
            (Image, File, Video).
        """
        initial_image = self.sdxl_generator.generate_images([prompt])[0]
        # NOTE: ControlNet refinement (self.controlnet_processor) is currently
        # bypassed; the raw SDXL image goes straight to 3-D reconstruction.
        mesh_data, video_data = self.video_generator.generate_3d_video(initial_image)
        mesh_path = self._to_served_file(mesh_data, ".ply")
        video_path = self._to_served_file(video_data, ".mp4")
        return initial_image, mesh_path, video_path

    def launch(self, share=True, **launch_kwargs):
        """Build the Gradio interface and start serving it.

        Args:
            share: Forwarded to ``Interface.launch``. ``True`` (the original
                behavior) requests a public share link.
            **launch_kwargs: Further options forwarded to ``Interface.launch``.
        """
        interface = gr.Interface(
            fn=self.full_pipeline,
            inputs=gr.Textbox(label="Input Prompt"),
            outputs=[
                gr.Image(label="Generated Image"),
                gr.File(label="3D Mesh (.ply)"),
                gr.Video(label="3D Model Video")
            ],
            title="SDXL to ControlNet to 3D Pipeline",
            description="Generate an image using SDXL, refine it with ControlNet, and generate a 3D video output."
        )
        interface.launch(share=share, **launch_kwargs)

if __name__ == "__main__":
    # Script entry point: construct the app and serve the Gradio UI.
    GradioApp().launch()