atatakun commited on
Commit
ba1e50f
·
1 Parent(s): 6611382

Upload 41 files

Browse files
annotator/normalbae/__init__.py CHANGED
@@ -28,7 +28,8 @@ class NormalBaeDetector:
28
  args.importance_ratio = 0.7
29
  model = NNET(args)
30
  model = utils.load_checkpoint(modelpath, model)
31
- model = model.cuda()
 
32
  model.eval()
33
  self.model = model
34
  self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
@@ -37,7 +38,8 @@ class NormalBaeDetector:
37
  assert input_image.ndim == 3
38
  image_normal = input_image
39
  with torch.no_grad():
40
- image_normal = torch.from_numpy(image_normal).float().cuda()
 
41
  image_normal = image_normal / 255.0
42
  image_normal = rearrange(image_normal, 'h w c -> 1 c h w')
43
  image_normal = self.norm(image_normal)
 
28
  args.importance_ratio = 0.7
29
  model = NNET(args)
30
  model = utils.load_checkpoint(modelpath, model)
31
+ # model = model.cuda()
32
+ model = model.cpu()
33
  model.eval()
34
  self.model = model
35
  self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
38
  assert input_image.ndim == 3
39
  image_normal = input_image
40
  with torch.no_grad():
41
+ # image_normal = torch.from_numpy(image_normal).float().cuda()
42
+ image_normal = torch.from_numpy(image_normal).float().cpu()
43
  image_normal = image_normal / 255.0
44
  image_normal = rearrange(image_normal, 'h w c -> 1 c h w')
45
  image_normal = self.norm(image_normal)
annotator/normalbae/models/submodules/efficientnet_repo/geffnet/helpers.py CHANGED
@@ -14,6 +14,7 @@ def load_checkpoint(model, checkpoint_path):
14
  if checkpoint_path and os.path.isfile(checkpoint_path):
15
  print("=> Loading checkpoint '{}'".format(checkpoint_path))
16
  checkpoint = torch.load(checkpoint_path)
 
17
  if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
18
  new_state_dict = OrderedDict()
19
  for k, v in checkpoint['state_dict'].items():
 
14
  if checkpoint_path and os.path.isfile(checkpoint_path):
15
  print("=> Loading checkpoint '{}'".format(checkpoint_path))
16
  checkpoint = torch.load(checkpoint_path)
17
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
18
  if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
19
  new_state_dict = OrderedDict()
20
  for k, v in checkpoint['state_dict'].items():