Spaces:
Sleeping
Sleeping
jhj0517
commited on
Commit
·
da960ac
1
Parent(s):
b14146c
Add `offload()`
Browse files
modules/uvr/music_separator.py
CHANGED
@@ -4,6 +4,7 @@ import torchaudio
|
|
4 |
import soundfile as sf
|
5 |
import os
|
6 |
import torch
|
|
|
7 |
|
8 |
from uvr.models import MDX, Demucs, VrNetwork, MDXC
|
9 |
|
@@ -30,6 +31,14 @@ class MusicSeparator:
|
|
30 |
model_name: str = "UVR-MDX-NET-Inst_1",
|
31 |
device: Optional[str] = None,
|
32 |
segment_size: int = 256):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
if device is None:
|
34 |
device = self.device
|
35 |
|
@@ -61,7 +70,10 @@ class MusicSeparator:
|
|
61 |
"split": True
|
62 |
}
|
63 |
|
64 |
-
if self.model is None or
|
|
|
|
|
|
|
65 |
self.update_model(
|
66 |
model_name=model_name,
|
67 |
device=device,
|
@@ -84,4 +96,13 @@ class MusicSeparator:
|
|
84 |
|
85 |
@staticmethod
|
86 |
def get_device():
|
87 |
-
return "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import soundfile as sf
|
5 |
import os
|
6 |
import torch
|
7 |
+
import gc
|
8 |
|
9 |
from uvr.models import MDX, Demucs, VrNetwork, MDXC
|
10 |
|
|
|
31 |
model_name: str = "UVR-MDX-NET-Inst_1",
|
32 |
device: Optional[str] = None,
|
33 |
segment_size: int = 256):
|
34 |
+
"""
|
35 |
+
Update model with the given model name
|
36 |
+
|
37 |
+
Args:
|
38 |
+
model_name (str): Model name.
|
39 |
+
device (str): Device to use for the model.
|
40 |
+
segment_size (int): Segment size for the prediction.
|
41 |
+
"""
|
42 |
if device is None:
|
43 |
device = self.device
|
44 |
|
|
|
70 |
"split": True
|
71 |
}
|
72 |
|
73 |
+
if (self.model is None or
|
74 |
+
self.current_model_size != model_name or
|
75 |
+
self.model_config != model_config or
|
76 |
+
self.audio_info.sample_rate != sample_rate):
|
77 |
self.update_model(
|
78 |
model_name=model_name,
|
79 |
device=device,
|
|
|
96 |
|
97 |
@staticmethod
|
98 |
def get_device():
|
99 |
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
100 |
+
|
101 |
+
def offload(self):
|
102 |
+
if self.model is not None:
|
103 |
+
del self.model
|
104 |
+
self.model = None
|
105 |
+
if self.device == "cuda":
|
106 |
+
torch.cuda.empty_cache()
|
107 |
+
gc.collect()
|
108 |
+
self.audio_info = None
|