Spaces:
Paused
Paused
Abdulrahman1989
commited on
Commit
·
28f286f
1
Parent(s):
803a138
Fix the model path
Browse files- Image3DProcessor.py +1 -1
- app.py +3 -4
Image3DProcessor.py
CHANGED
@@ -34,7 +34,7 @@ class Image3DProcessor:
|
|
34 |
self.model_cfg = OmegaConf.load(model_cfg_path)
|
35 |
|
36 |
# Load pre-trained model weights
|
37 |
-
model_path =
|
38 |
self.model = GaussianSplatPredictor(self.model_cfg)
|
39 |
ckpt_loaded = torch.load(model_path, map_location=self.device)
|
40 |
self.model.load_state_dict(ckpt_loaded["model_state_dict"])
|
|
|
34 |
self.model_cfg = OmegaConf.load(model_cfg_path)
|
35 |
|
36 |
# Load pre-trained model weights
|
37 |
+
model_path = model_filename
|
38 |
self.model = GaussianSplatPredictor(self.model_cfg)
|
39 |
ckpt_loaded = torch.load(model_path, map_location=self.device)
|
40 |
self.model.load_state_dict(ckpt_loaded["model_state_dict"])
|
app.py
CHANGED
@@ -10,9 +10,9 @@ from io import BytesIO
|
|
10 |
import numpy as np
|
11 |
|
12 |
class VideoGenerator:
|
13 |
-
def __init__(self, model_cfg_path,
|
14 |
# Initialize the Image3DProcessor
|
15 |
-
self.processor = Image3DProcessor(model_cfg_path,
|
16 |
|
17 |
def generate_3d_video(self, image):
|
18 |
# Ensure the image is a PIL Image object
|
@@ -33,8 +33,7 @@ class GradioApp:
|
|
33 |
# Initialize VideoGenerator with required paths and details
|
34 |
self.video_generator = VideoGenerator(
|
35 |
model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
|
36 |
-
|
37 |
-
model_filename="model_latest.pth"
|
38 |
)
|
39 |
|
40 |
def launch(self):
|
|
|
10 |
import numpy as np
|
11 |
|
12 |
class VideoGenerator:
|
13 |
+
def __init__(self, model_cfg_path, model_filename):
|
14 |
# Initialize the Image3DProcessor
|
15 |
+
self.processor = Image3DProcessor(model_cfg_path, model_filename)
|
16 |
|
17 |
def generate_3d_video(self, image):
|
18 |
# Ensure the image is a PIL Image object
|
|
|
33 |
# Initialize VideoGenerator with required paths and details
|
34 |
self.video_generator = VideoGenerator(
|
35 |
model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
|
36 |
+
model_filename="/home/user/app/fused_text_to_3D.pth"
|
|
|
37 |
)
|
38 |
|
39 |
def launch(self):
|