jens commited on
Commit
1d564b4
·
1 Parent(s): 2ba4f1e
Files changed (1) hide show
  1. inference.py +2 -2
inference.py CHANGED
@@ -10,8 +10,8 @@ import requests
10
  class DepthPredictor:
11
  def __init__(self):
12
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- self.processor = DPTImageProcessor.from_pretrained("Intel/dpt-large").to(self.device)
14
- self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(self.device)
15
  self.model.eval()
16
 
17
  def predict(self, image):
 
10
  class DepthPredictor:
11
  def __init__(self):
12
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ self.processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
14
+ self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
15
  self.model.eval()
16
 
17
  def predict(self, image):