File size: 2,492 Bytes
fc352b3
a15c0ab
30f0bac
ad5cb25
583955f
66a8046
2f8a737
 
66a8046
30f0bac
66a8046
 
 
 
 
673639e
 
 
9304db7
 
30f0bac
 
 
ad5cb25
66a8046
 
27c0c4f
66a8046
 
 
30f0bac
 
5f35729
ad5cb25
65ec2c8
5f35729
65ec2c8
 
9d98595
 
 
157398f
 
60367c0
9d98595
65ec2c8
30f0bac
 
724d8ef
e43b248
 
724d8ef
e43b248
 
 
 
5f35729
 
a15c0ab
 
30f0bac
e43b248
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import gradio as gr
import tempfile
from SDXLImageGenerator import SDXLImageGenerator  # Import your existing class
import sys
from Image3DProcessor import Image3DProcessor  # Import your 3D processing class
from PIL import Image
import io

class VideoGenerator:
    """Thin adapter that turns a single image into a 3D video via Image3DProcessor."""

    def __init__(self, model_cfg_path, model_repo_id, model_filename):
        """Create the underlying ``Image3DProcessor``.

        Args:
            model_cfg_path: Forwarded unchanged to ``Image3DProcessor``.
            model_repo_id: Forwarded unchanged to ``Image3DProcessor``.
            model_filename: Forwarded unchanged to ``Image3DProcessor``.
        """
        self.processor = Image3DProcessor(
            model_cfg_path,
            model_repo_id,
            model_filename,
        )

    def generate_3d_video(self, image):
        """Preprocess ``image``, then reconstruct and export it.

        Args:
            image: Whatever ``Image3DProcessor.preprocess`` accepts.

        Returns:
            The value produced by ``Image3DProcessor.reconstruct_and_export``.
        """
        # The two processor stages must run in this order: export consumes
        # the preprocessed form, not the raw input.
        prepared = self.processor.preprocess(image)
        return self.processor.reconstruct_and_export(prepared)

class GradioApp:
    """Gradio UI chaining prompt -> SDXL image -> 3D video.

    A text prompt is sent to ``SDXLImageGenerator``. Its first result is passed
    to ``VideoGenerator``, and the returned video data is written to a
    temporary ``.mp4`` file whose path is handed to Gradio.
    """

    def __init__(
        self,
        model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
        model_repo_id="szymanowiczs/splatter-image-multi-category-v1",
        model_filename="model_latest.pth",
    ):
        """Instantiate the image generator and the 3D video generator.

        The defaults are the values previously hard-coded here, so existing
        ``GradioApp()`` calls behave as before.

        Args:
            model_cfg_path: Forwarded to ``VideoGenerator``.
            model_repo_id: Forwarded to ``VideoGenerator``.
            model_filename: Forwarded to ``VideoGenerator``.
        """
        self.sdxl_generator = SDXLImageGenerator()
        self.video_generator = VideoGenerator(
            model_cfg_path=model_cfg_path,
            model_repo_id=model_repo_id,
            model_filename=model_filename,
        )

    def full_pipeline(self, prompt):
        """Generate an image from ``prompt`` and a 3D video from that image.

        Args:
            prompt: Text prompt for the SDXL generator.

        Returns:
            ``(image, video_path)``. ``image`` is a PIL image when the
            generator returned encoded bytes; otherwise it is the generator's
            output unchanged. ``video_path`` is a temporary ``.mp4`` file
            containing the video data.
        """
        # generate_images takes a list of prompts; only one is submitted here.
        initial_image = self.sdxl_generator.generate_images([prompt])[0]

        # The 3D stage receives the generator's raw output, before the
        # display conversion below.
        video_data = self.video_generator.generate_3d_video(initial_image)

        # delete=False is required: Gradio reads the file after this function
        # returns. NOTE(review): these temp files are never removed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as video_file:
            video_file.write(video_data)
            video_path = video_file.name

        # Only encoded bytes need decoding for display. Wrapping an
        # already-decoded image in BytesIO would raise instead of displaying it.
        if isinstance(initial_image, (bytes, bytearray, memoryview)):
            initial_image = Image.open(io.BytesIO(initial_image))

        return initial_image, video_path

    def launch(self, share=True):
        """Build the Blocks interface and start serving it.

        Args:
            share: Forwarded to ``Blocks.launch``. It defaults to ``True``,
                the value previously hard-coded here.
        """
        with gr.Blocks() as interface:
            prompt_input = gr.Textbox(label="Input Prompt", elem_id="input_textbox")
            generate_button = gr.Button("Generate")
            with gr.Row():
                image_output = gr.Image(label="Generated Image", elem_id="generated_image")
                video_output = gr.Video(label="3D Model Video", elem_id="model_video")

            generate_button.click(
                fn=self.full_pipeline,
                inputs=prompt_input,
                outputs=[image_output, video_output],
            )

        interface.launch(share=share)

if __name__ == "__main__":
    # Constructing GradioApp instantiates the SDXL generator and the 3D
    # processor before the UI starts serving.
    GradioApp().launch()