yerang commited on
Commit
a297d32
·
verified ·
1 Parent(s): b37595d

Update stf/stf-api-alternative/src/stf_alternative/model.py

Browse files
stf/stf-api-alternative/src/stf_alternative/model.py CHANGED
@@ -130,12 +130,15 @@ def create_model(
130
 
131
  if args.model_type == "stf_v3":
132
  g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
133
- g_audio_encoder.to(device)
 
134
  g_audio_encoder.eval()
135
  else:
136
- model.to(device).eval()
 
137
  if args.model_type == "stf_v3":
138
- g_audio_encoder.to(device).eval()
 
139
 
140
  model_data = ModelInfo(
141
  model=model,
 
130
 
131
  if args.model_type == "stf_v3":
132
  g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
133
+ #g_audio_encoder.to(device)
134
+ g_audio_encoder.cuda(device)
135
  g_audio_encoder.eval()
136
  else:
137
+ #model.to(device).eval()
138
+ model.cuda(device).eval()
139
  if args.model_type == "stf_v3":
140
+ #g_audio_encoder.to(device).eval()
141
+ g_audio_encoder.cuda(device).eval()
142
 
143
  model_data = ModelInfo(
144
  model=model,