Abdulrahman1989 commited on
Commit
28f286f
·
1 Parent(s): 803a138

Fix the model path

Browse files
Files changed (2) hide show
  1. Image3DProcessor.py +1 -1
  2. app.py +3 -4
Image3DProcessor.py CHANGED
@@ -34,7 +34,7 @@ class Image3DProcessor:
34
  self.model_cfg = OmegaConf.load(model_cfg_path)
35
 
36
  # Load pre-trained model weights
37
- model_path = hf_hub_download(repo_id=model_repo_id, filename=model_filename)
38
  self.model = GaussianSplatPredictor(self.model_cfg)
39
  ckpt_loaded = torch.load(model_path, map_location=self.device)
40
  self.model.load_state_dict(ckpt_loaded["model_state_dict"])
 
34
  self.model_cfg = OmegaConf.load(model_cfg_path)
35
 
36
  # Load pre-trained model weights
37
+ model_path = model_filename
38
  self.model = GaussianSplatPredictor(self.model_cfg)
39
  ckpt_loaded = torch.load(model_path, map_location=self.device)
40
  self.model.load_state_dict(ckpt_loaded["model_state_dict"])
app.py CHANGED
@@ -10,9 +10,9 @@ from io import BytesIO
10
  import numpy as np
11
 
12
  class VideoGenerator:
13
- def __init__(self, model_cfg_path, model_repo_id, model_filename):
14
  # Initialize the Image3DProcessor
15
- self.processor = Image3DProcessor(model_cfg_path, model_repo_id, model_filename)
16
 
17
  def generate_3d_video(self, image):
18
  # Ensure the image is a PIL Image object
@@ -33,8 +33,7 @@ class GradioApp:
33
  # Initialize VideoGenerator with required paths and details
34
  self.video_generator = VideoGenerator(
35
  model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
36
- model_repo_id="szymanowiczs/splatter-image-multi-category-v1",
37
- model_filename="model_latest.pth"
38
  )
39
 
40
  def launch(self):
 
10
  import numpy as np
11
 
12
  class VideoGenerator:
13
+ def __init__(self, model_cfg_path, model_filename):
14
  # Initialize the Image3DProcessor
15
+ self.processor = Image3DProcessor(model_cfg_path, model_filename)
16
 
17
  def generate_3d_video(self, image):
18
  # Ensure the image is a PIL Image object
 
33
  # Initialize VideoGenerator with required paths and details
34
  self.video_generator = VideoGenerator(
35
  model_cfg_path="/home/user/app/splatter-image/gradio_config.yaml",
36
+ model_filename="/home/user/app/fused_text_to_3D.pth"
 
37
  )
38
 
39
  def launch(self):