Tonic commited on
Commit
77cc5e6
Β·
unverified Β·
1 Parent(s): 402a6d6

Add HFTOKEN

Browse files
Files changed (2) hide show
  1. main.py +8 -1
  2. requirements.txt +2 -1
main.py CHANGED
@@ -12,11 +12,17 @@ import gradio as gr
12
  from audiocraft.data.audio_utils import f32_pcm, normalize_audio
13
  from audiocraft.data.audio import audio_write
14
  from audiocraft.models import JASCO
 
 
15
 
16
  MODEL = None
17
  MAX_BATCH_SIZE = 12
18
  INTERRUPTING = False
19
 
 
 
 
 
20
  # Wrap subprocess call to clean logs
21
  _old_call = sp.call
22
 
@@ -63,12 +69,13 @@ def chords_string_to_list(chords: str):
63
  chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
64
  return [(x[0], float(x[1])) for x in chrd_times]
65
 
 
66
  def load_model(version='facebook/jasco-chords-drums-400M'):
67
  global MODEL
68
  print("Loading model", version)
69
  if MODEL is None or MODEL.name != version:
70
  MODEL = None
71
- MODEL = JASCO.get_pretrained(version)
72
 
73
  @spaces.GPU
74
  def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
 
12
  from audiocraft.data.audio_utils import f32_pcm, normalize_audio
13
  from audiocraft.data.audio import audio_write
14
  from audiocraft.models import JASCO
15
+ import os
16
+ from huggingface_hub import login
17
 
18
  MODEL = None
19
  MAX_BATCH_SIZE = 12
20
  INTERRUPTING = False
21
 
22
+ hf_token = os.environ.get('HFTOKEN')
23
+ if hf_token:
24
+ login(token=hf_token)
25
+
26
  # Wrap subprocess call to clean logs
27
  _old_call = sp.call
28
 
 
69
  chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
70
  return [(x[0], float(x[1])) for x in chrd_times]
71
 
72
+
73
  def load_model(version='facebook/jasco-chords-drums-400M'):
74
  global MODEL
75
  print("Loading model", version)
76
  if MODEL is None or MODEL.name != version:
77
  MODEL = None
78
+ MODEL = JASCO.get_pretrained(version, token=hf_token)
79
 
80
  @spaces.GPU
81
  def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
requirements.txt CHANGED
@@ -8,4 +8,5 @@ scipy
8
  einops
9
  rotary_embedding_torch
10
  xformers
11
- demucs
 
 
8
  einops
9
  rotary_embedding_torch
10
  xformers
11
+ demucs
12
+ huggingface_hub