yerang commited on
Commit
4d43820
·
verified ·
1 Parent(s): a17e76c

Update stf/stf-api-alternative/src/stf_alternative/model.py

Browse files
stf/stf-api-alternative/src/stf_alternative/model.py CHANGED
@@ -125,7 +125,7 @@ def create_model(
125
  gpus = list(range(torch.cuda.device_count()))
126
  print("Multi GPU activate, gpus : ", gpus)
127
  model = torch.nn.DataParallel(model, device_ids=gpus)
128
- model.to(device)
129
  model.eval()
130
 
131
  if args.model_type == "stf_v3":
 
125
  gpus = list(range(torch.cuda.device_count()))
126
  print("Multi GPU activate, gpus : ", gpus)
127
  model = torch.nn.DataParallel(model, device_ids=gpus)
128
+ model.cuda(0) # to(device)
129
  model.eval()
130
 
131
  if args.model_type == "stf_v3":