Abdulrahman1989 commited on
Commit
02554a3
·
1 Parent(s): e43b248

Code now works

Browse files
Files changed (2) hide show
  1. SDXLImageGenerator.py +1 -1
  2. app.py +51 -24
SDXLImageGenerator.py CHANGED
@@ -22,7 +22,7 @@ class SDXLImageGenerator:
22
  )
23
  self.pipe.to(self.device)
24
 
25
- def generate_images(self, prompts):
26
  start_time = time.time()
27
 
28
  # Generate images in a batch
 
22
  )
23
  self.pipe.to(self.device)
24
 
25
+ def generate_image(self, prompts):
26
  start_time = time.time()
27
 
28
  # Generate images in a batch
app.py CHANGED
@@ -6,6 +6,8 @@ import sys
6
  from Image3DProcessor import Image3DProcessor # Import your 3D processing class
7
  from PIL import Image
8
  import io
 
 
9
 
10
  class VideoGenerator:
11
  def __init__(self, model_cfg_path, model_repo_id, model_filename):
@@ -13,15 +15,21 @@ class VideoGenerator:
13
  self.processor = Image3DProcessor(model_cfg_path, model_repo_id, model_filename)
14
 
15
  def generate_3d_video(self, image):
 
 
 
16
  # Preprocess the image first
17
  processed_image = self.processor.preprocess(image)
18
  # Then pass it to reconstruct_and_export
19
  video_data = self.processor.reconstruct_and_export(processed_image)
20
- return video_data
 
 
 
21
 
22
  class GradioApp:
23
  def __init__(self):
24
- self.sdxl_generator = SDXLImageGenerator() # Use your existing class
25
  # Initialize VideoGenerator with required paths and details
26
  self.video_generator = VideoGenerator(
27
  model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
@@ -29,34 +37,53 @@ class GradioApp:
29
  model_filename="model_latest.pth"
30
  )
31
 
32
- def full_pipeline(self, prompt):
33
- # Generate the initial image using SDXLImageGenerator
34
- initial_image = self.sdxl_generator.generate_images([prompt])[0]
35
-
36
- # Generate a 3D video using the image
37
- video_data = self.video_generator.generate_3d_video(initial_image)
38
-
39
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as video_file:
40
- video_file.write(video_data)
41
- video_path = video_file.name
42
-
43
- # Convert bytes to a PIL Image for further processing and display
44
- initial_image = Image.open(io.BytesIO(initial_image))
45
-
46
- return initial_image, video_path
47
-
48
  def launch(self):
49
  with gr.Blocks() as interface:
 
50
  prompt_input = gr.Textbox(label="Input Prompt", elem_id="input_textbox")
51
- generate_button = gr.Button("Generate")
 
 
 
 
52
  with gr.Row():
53
- image_output = gr.Image(label="Generated Image", elem_id="generated_image")
54
- video_output = gr.Video(label="3D Model Video", elem_id="model_video")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- generate_button.click(fn=self.full_pipeline, inputs=prompt_input, outputs=[image_output, video_output])
57
-
58
  interface.launch(share=True)
59
 
60
  if __name__ == "__main__":
61
  app = GradioApp()
62
- app.launch()
 
6
  from Image3DProcessor import Image3DProcessor # Import your 3D processing class
7
  from PIL import Image
8
  import io
9
+ from io import BytesIO
10
+ import numpy as np
11
 
12
  class VideoGenerator:
13
  def __init__(self, model_cfg_path, model_repo_id, model_filename):
 
15
  self.processor = Image3DProcessor(model_cfg_path, model_repo_id, model_filename)
16
 
17
  def generate_3d_video(self, image):
18
+ # Ensure the image is a PIL Image object
19
+ if isinstance(image, np.ndarray):
20
+ image = Image.fromarray(image)
21
  # Preprocess the image first
22
  processed_image = self.processor.preprocess(image)
23
  # Then pass it to reconstruct_and_export
24
  video_data = self.processor.reconstruct_and_export(processed_image)
25
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as video_file:
26
+ video_file.write(video_data)
27
+ video_path = video_file.name
28
+ return video_path
29
 
30
  class GradioApp:
31
  def __init__(self):
32
+ self.sdxl_generator = SDXLImageGenerator()
33
  # Initialize VideoGenerator with required paths and details
34
  self.video_generator = VideoGenerator(
35
  model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
 
37
  model_filename="model_latest.pth"
38
  )
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def launch(self):
41
  with gr.Blocks() as interface:
42
+ # Input for the prompt at the top
43
  prompt_input = gr.Textbox(label="Input Prompt", elem_id="input_textbox")
44
+
45
+ # Button for generating the 3D object
46
+ generate_3d_object = gr.Button("Generate 3D object")
47
+
48
+ # Outputs: image on the bottom left, video on the bottom right
49
  with gr.Row():
50
+ with gr.Column():
51
+ image_output = gr.Image(label="Generated Image", elem_id="generated_image")
52
+ with gr.Column():
53
+ video_output = gr.Video(label="3D Model Video", elem_id="model_video")
54
+
55
+ # Generate the image first
56
+ def generate_image_and_display(prompt):
57
+ # Generate the image from the prompt
58
+ image_data = self.sdxl_generator.generate_image([prompt])[0]
59
+ return Image.open(BytesIO(image_data))
60
+
61
+ # Generate the 3D after the image is ready
62
+ def generate_3D_from_image(image):
63
+ # Ensure the image is a PIL Image object
64
+ if isinstance(image, np.ndarray):
65
+ image = Image.fromarray(image)
66
+ # Generate the 3D from the generated image
67
+ return self.video_generator.generate_3d_video(image)
68
+
69
+ # First click generates the image
70
+ generate_3d_object.click(
71
+ fn=generate_image_and_display,
72
+ inputs=prompt_input,
73
+ outputs=image_output,
74
+ queue=True
75
+ )
76
+
77
+ # Once the image is ready, generate the video
78
+ image_output.change(
79
+ fn=generate_3D_from_image,
80
+ inputs=image_output,
81
+ outputs=video_output,
82
+ queue=True
83
+ )
84
 
 
 
85
  interface.launch(share=True)
86
 
87
  if __name__ == "__main__":
88
  app = GradioApp()
89
+ app.launch()