Spaces:
Running
on
Zero
Running
on
Zero
Update stf/stf-api-alternative/src/stf_alternative/model.py
Browse files
stf/stf-api-alternative/src/stf_alternative/model.py
CHANGED
@@ -131,14 +131,14 @@ def create_model(
|
|
131 |
if args.model_type == "stf_v3":
|
132 |
g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
|
133 |
#g_audio_encoder.to(device)
|
134 |
-
g_audio_encoder.cuda(
|
135 |
g_audio_encoder.eval()
|
136 |
else:
|
137 |
#model.to(device).eval()
|
138 |
-
model.cuda(
|
139 |
if args.model_type == "stf_v3":
|
140 |
#g_audio_encoder.to(device).eval()
|
141 |
-
g_audio_encoder.cuda(
|
142 |
|
143 |
model_data = ModelInfo(
|
144 |
model=model,
|
|
|
131 |
if args.model_type == "stf_v3":
|
132 |
g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
|
133 |
#g_audio_encoder.to(device)
|
134 |
+
g_audio_encoder.cuda(0)
|
135 |
g_audio_encoder.eval()
|
136 |
else:
|
137 |
#model.to(device).eval()
|
138 |
+
model.cuda(0).eval()
|
139 |
if args.model_type == "stf_v3":
|
140 |
#g_audio_encoder.to(device).eval()
|
141 |
+
g_audio_encoder.cuda(0).eval()
|
142 |
|
143 |
model_data = ModelInfo(
|
144 |
model=model,
|