Tonic commited on
Commit
c1b6b5f
Β·
unverified Β·
1 Parent(s): db6cbf2

load models cuda

Browse files
Files changed (1) hide show
  1. main.py +13 -1
main.py CHANGED
@@ -71,12 +71,24 @@ def chords_string_to_list(chords: str):
71
  chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
72
  return [(x[0], float(x[1])) for x in chrd_times]
73
 
 
 
 
 
74
  def load_model(version='facebook/jasco-chords-drums-400M'):
75
  global MODEL
76
  print("Loading model", version)
77
  if MODEL is None or MODEL.name != version:
78
  MODEL = None
79
- MODEL = JASCO.get_pretrained(version)
 
 
 
 
 
 
 
 
80
 
81
  @spaces.GPU
82
  def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
 
71
  chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
72
  return [(x[0], float(x[1])) for x in chrd_times]
73
 
74
+ # Create necessary directories
75
+ os.makedirs("models", exist_ok=True)
76
+
77
+ @spaces.GPU
78
  def load_model(version='facebook/jasco-chords-drums-400M'):
79
  global MODEL
80
  print("Loading model", version)
81
  if MODEL is None or MODEL.name != version:
82
  MODEL = None
83
+ try:
84
+ MODEL = JASCO.get_pretrained(version, device='cuda')
85
+ MODEL.name = version
86
+ except Exception as e:
87
+ raise gr.Error(f"Error loading model: {str(e)}")
88
+
89
+ if MODEL is None:
90
+ raise gr.Error("Failed to load model")
91
+ return MODEL
92
 
93
  @spaces.GPU
94
  def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):