yerang commited on
Commit
cf4e51a
·
verified ·
1 Parent(s): a297d32

Update stf/stf-api-alternative/src/stf_alternative/model.py

Browse files
stf/stf-api-alternative/src/stf_alternative/model.py CHANGED
@@ -131,14 +131,14 @@ def create_model(
131
  if args.model_type == "stf_v3":
132
  g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
133
  #g_audio_encoder.to(device)
134
- g_audio_encoder.cuda(device)
135
  g_audio_encoder.eval()
136
  else:
137
  #model.to(device).eval()
138
- model.cuda(device).eval()
139
  if args.model_type == "stf_v3":
140
  #g_audio_encoder.to(device).eval()
141
- g_audio_encoder.cuda(device).eval()
142
 
143
  model_data = ModelInfo(
144
  model=model,
 
131
  if args.model_type == "stf_v3":
132
  g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
133
  #g_audio_encoder.to(device)
134
+ g_audio_encoder.cuda(0)
135
  g_audio_encoder.eval()
136
  else:
137
  #model.to(device).eval()
138
+ model.cuda(0).eval()
139
  if args.model_type == "stf_v3":
140
  #g_audio_encoder.to(device).eval()
141
+ g_audio_encoder.cuda(0).eval()
142
 
143
  model_data = ModelInfo(
144
  model=model,