File size: 3,602 Bytes
fc352b3
a15c0ab
30f0bac
ad5cb25
583955f
66a8046
2f8a737
 
02554a3
 
66a8046
30f0bac
28f286f
66a8046
28f286f
66a8046
 
02554a3
 
 
673639e
 
 
9304db7
02554a3
 
 
 
30f0bac
 
 
02554a3
66a8046
 
27c0c4f
28f286f
66a8046
30f0bac
 
8e410e5
02554a3
e43b248
02554a3
 
 
 
 
724d8ef
02554a3
 
 
8e410e5
02554a3
 
 
8e410e5
 
02554a3
8e410e5
02554a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e43b248
5f35729
a15c0ab
 
30f0bac
02554a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import gradio as gr
import tempfile
from SDXLImageGenerator import SDXLImageGenerator  # Import your existing class
import sys
from Image3DProcessor import Image3DProcessor  # Import your 3D processing class
from PIL import Image
import io
from io import BytesIO
import numpy as np

class VideoGenerator:
    """Turn a single image into a rendered video file via ``Image3DProcessor``."""

    def __init__(self, model_cfg_path, model_filename):
        """Build the underlying ``Image3DProcessor``.

        Args:
            model_cfg_path: Forwarded unchanged as the processor's first argument.
            model_filename: Forwarded unchanged as the processor's second argument.
        """
        self.processor = Image3DProcessor(model_cfg_path, model_filename)

    def generate_3d_video(self, image):
        """Run the processor on ``image`` and store the result in a temp ``.mp4``.

        Args:
            image: A NumPy array (converted with ``Image.fromarray``) or any
                other object, which is handed to the processor as-is.

        Returns:
            Path of the temporary file holding whatever
            ``reconstruct_and_export`` returned (presumably encoded video
            bytes -- confirm against ``Image3DProcessor``).
        """
        # Arrays are converted to PIL images; everything else passes through.
        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image

        # Two-stage pipeline: preprocess first, then reconstruct and export.
        prepared = self.processor.preprocess(pil_image)
        payload = self.processor.reconstruct_and_export(prepared)

        # delete=False keeps the file on disk after the handle is closed so the
        # returned path remains usable by the caller.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as sink:
            sink.write(payload)
        return sink.name

class GradioApp:
    """Gradio UI that turns a text prompt into an image and then a 3D video.

    Clicking the button generates an image from the prompt; whenever the image
    component changes, the image is passed on to ``VideoGenerator``.
    """

    # Appended to every user prompt before it is sent to the image generator.
    PROMPT_SUFFIX = ", isolated, on a plain background, minimal, no extra objects"

    def __init__(self):
        """Create the image generator and the video generator."""
        self.sdxl_generator = SDXLImageGenerator()
        # Initialize VideoGenerator with required paths and details
        self.video_generator = VideoGenerator(
            model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
            model_filename="/home/user/app/fused_text_to_3D.pth",
        )

    # NOTE: the two callbacks below keep the function names of the closures
    # they replace; Gradio can use a callback's function name as the default
    # API endpoint name, so renaming them could change the exposed API.

    def generate_image_and_display(self, prompt):
        """Generate an image for ``prompt`` and return it as a PIL image.

        Args:
            prompt: Text typed by the user; ``PROMPT_SUFFIX`` is appended.

        Returns:
            The first generated image, opened with ``PIL.Image.open``.
        """
        modified_prompt = prompt + self.PROMPT_SUFFIX
        print(modified_prompt)
        # The generator takes a list of prompts; only the first result is used.
        image_data = self.sdxl_generator.generate_image([modified_prompt])[0]
        return Image.open(BytesIO(image_data))

    def generate_3D_from_image(self, image):
        """Create the 3D video for ``image``.

        Args:
            image: Value of the image component: a NumPy array, a PIL image,
                or ``None`` when the component is empty or has been cleared.

        Returns:
            Path of the generated video file, or ``None`` when there is no
            image, which leaves the video output empty.
        """
        # The change event also fires when the image is cleared; there is
        # nothing to reconstruct then, so skip the 3D pipeline instead of
        # crashing inside it.
        if image is None:
            return None
        # Ensure the image is a PIL Image object
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        return self.video_generator.generate_3d_video(image)

    def launch(self):
        """Build the Blocks interface, wire the events, and start the server."""
        with gr.Blocks(css="#small_video { width: 400px !important; height: 300px !important; }") as interface:
            # Input for the prompt at the top
            prompt_input = gr.Textbox(label="Input Prompt", elem_id="input_textbox")

            # Button for generating the 3D object
            generate_3d_object = gr.Button("Generate 3D object")

            # Outputs: image on the bottom left, video on the bottom right
            with gr.Row():
                with gr.Column():
                    image_output = gr.Image(label="Generated Image", elem_id="generated_image")
                with gr.Column():
                    video_output = gr.Video(label="3D Model Video", elem_id="small_video")

            # First click generates the image
            generate_3d_object.click(
                fn=self.generate_image_and_display,
                inputs=prompt_input,
                outputs=image_output,
                queue=True,
            )

            # Once the image is ready, generate the video
            image_output.change(
                fn=self.generate_3D_from_image,
                inputs=image_output,
                outputs=video_output,
                queue=True,
            )

        interface.launch(share=True)

# Script entry point: build the app and start serving when run directly.
if __name__ == "__main__":
    GradioApp().launch()