jhj0517 commited on
Commit
da960ac
·
1 Parent(s): b14146c

Add `offload()`

Browse files
Files changed (1) hide show
  1. modules/uvr/music_separator.py +23 -2
modules/uvr/music_separator.py CHANGED
@@ -4,6 +4,7 @@ import torchaudio
4
  import soundfile as sf
5
  import os
6
  import torch
 
7
 
8
  from uvr.models import MDX, Demucs, VrNetwork, MDXC
9
 
@@ -30,6 +31,14 @@ class MusicSeparator:
30
  model_name: str = "UVR-MDX-NET-Inst_1",
31
  device: Optional[str] = None,
32
  segment_size: int = 256):
 
 
 
 
 
 
 
 
33
  if device is None:
34
  device = self.device
35
 
@@ -61,7 +70,10 @@ class MusicSeparator:
61
  "split": True
62
  }
63
 
64
- if self.model is None or self.current_model_size != model_name or self.model_config != model_config:
 
 
 
65
  self.update_model(
66
  model_name=model_name,
67
  device=device,
@@ -84,4 +96,13 @@ class MusicSeparator:
84
 
85
  @staticmethod
86
  def get_device():
87
- return "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
4
  import soundfile as sf
5
  import os
6
  import torch
7
+ import gc
8
 
9
  from uvr.models import MDX, Demucs, VrNetwork, MDXC
10
 
 
31
  model_name: str = "UVR-MDX-NET-Inst_1",
32
  device: Optional[str] = None,
33
  segment_size: int = 256):
34
+ """
35
+ Update model with the given model name
36
+
37
+ Args:
38
+ model_name (str): Model name.
39
+ device (str): Device to use for the model.
40
+ segment_size (int): Segment size for the prediction.
41
+ """
42
  if device is None:
43
  device = self.device
44
 
 
70
  "split": True
71
  }
72
 
73
+ if (self.model is None or
74
+ self.current_model_size != model_name or
75
+ self.model_config != model_config or
76
+ self.audio_info.sample_rate != sample_rate):
77
  self.update_model(
78
  model_name=model_name,
79
  device=device,
 
96
 
97
  @staticmethod
98
  def get_device():
99
+ return "cuda" if torch.cuda.is_available() else "cpu"
100
+
101
+ def offload(self):
102
+ if self.model is not None:
103
+ del self.model
104
+ self.model = None
105
+ if self.device == "cuda":
106
+ torch.cuda.empty_cache()
107
+ gc.collect()
108
+ self.audio_info = None