Spaces:
Running
on
Zero
Running
on
Zero
load models cuda
Browse files
main.py
CHANGED
@@ -71,12 +71,24 @@ def chords_string_to_list(chords: str):
|
|
71 |
chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
|
72 |
return [(x[0], float(x[1])) for x in chrd_times]
|
73 |
|
|
|
|
|
|
|
|
|
74 |
def load_model(version='facebook/jasco-chords-drums-400M'):
|
75 |
global MODEL
|
76 |
print("Loading model", version)
|
77 |
if MODEL is None or MODEL.name != version:
|
78 |
MODEL = None
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
@spaces.GPU
|
82 |
def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
|
|
|
71 |
chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
|
72 |
return [(x[0], float(x[1])) for x in chrd_times]
|
73 |
|
74 |
+
# Create necessary directories
|
75 |
+
os.makedirs("models", exist_ok=True)
|
76 |
+
|
77 |
+
@spaces.GPU
|
78 |
def load_model(version='facebook/jasco-chords-drums-400M'):
|
79 |
global MODEL
|
80 |
print("Loading model", version)
|
81 |
if MODEL is None or MODEL.name != version:
|
82 |
MODEL = None
|
83 |
+
try:
|
84 |
+
MODEL = JASCO.get_pretrained(version, device='cuda')
|
85 |
+
MODEL.name = version
|
86 |
+
except Exception as e:
|
87 |
+
raise gr.Error(f"Error loading model: {str(e)}")
|
88 |
+
|
89 |
+
if MODEL is None:
|
90 |
+
raise gr.Error("Failed to load model")
|
91 |
+
return MODEL
|
92 |
|
93 |
@spaces.GPU
|
94 |
def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
|