Tonic commited on
Commit
cfd88ed
·
unverified ·
1 Parent(s): 68fc76e

workaround for gradio file permission -attempt

Browse files
Files changed (1) hide show
  1. main.py +99 -19
main.py CHANGED
@@ -14,7 +14,10 @@ from audiocraft.data.audio_utils import f32_pcm, normalize_audio
14
  from audiocraft.data.audio import audio_write
15
  from audiocraft.models import JASCO
16
  import os
 
17
  from huggingface_hub import login
 
 
18
 
19
  title = """# 🙋🏻‍♂️Welcome to 🌟Tonic's 🎼Jasco🎶AudioCraft Demo"""
20
  description = """Facebook presents JASCO, a temporally controlled text-to-music generation model utilizing both symbolic and audio-based conditions. JASCO can generate high-quality music samples conditioned on global text descriptions along with fine-grained local controls. JASCO is based on the Flow Matching modeling paradigm together with a novel conditioning method, allowing for music generation controlled both locally (e.g., chords) and globally (text description). [run this demo locally](https://huggingface.co/spaces/Tonic/audiocraft?docker=true) or [embed this space](https://huggingface.co/spaces/Tonic/audiocraft?embed=true) or [duplicate this space](https://huggingface.co/spaces/Tonic/audiocraft?duplicate=true) to run it privately . you can also use this demo via API by clicking the link at the bottom of the page."""
@@ -187,15 +190,37 @@ Model: facebook/jasco-chords-drums-1B
187
  - Melody-enabled models may be slower
188
  """
189
 
 
 
 
190
  hf_token = os.environ.get('HFTOKEN')
191
  if hf_token:
192
  login(token=hf_token)
193
 
194
- MODEL = None
195
- MAX_BATCH_SIZE = 12
196
- INTERRUPTING = False
 
 
 
 
 
 
 
 
 
 
 
197
 
198
- os.makedirs(os.path.join(os.path.dirname(__file__), "models"), exist_ok=True)
 
 
 
 
 
 
 
 
199
 
200
  def generate_chord_mappings():
201
  # Define basic chord mappings
@@ -308,41 +333,87 @@ def chords_string_to_list(chords: str):
308
  chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
309
  return [(x[0], float(x[1])) for x in chrd_times]
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  # Create necessary directories
312
  os.makedirs("models", exist_ok=True)
313
 
314
  @spaces.GPU
315
- def load_model(version='facebook/jasco-chords-drums-400M'):
316
  global MODEL
317
  print("Loading model", version)
318
  if MODEL is None or MODEL.name != version:
319
  MODEL = None
320
-
321
- # Setup model directory
322
- model_dir = os.path.join(os.path.dirname(__file__), "models")
323
- os.makedirs(model_dir, exist_ok=True)
324
-
325
- # Generate and save chord mappings
326
- chord_mapping_path = os.path.join(model_dir, "chord_to_index_mapping.pkl")
327
- if not os.path.exists(chord_mapping_path):
328
- chord_mapping_path = generate_chord_mappings()
329
-
330
  try:
331
- # Initialize JASCO with the chord mapping path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  MODEL = JASCO.get_pretrained(
333
  version,
334
  device='cuda',
335
- chords_mapping_path=chord_mapping_path
 
336
  )
337
  MODEL.name = version
 
 
 
 
 
 
 
 
338
  except Exception as e:
339
  raise gr.Error(f"Error loading model: {str(e)}")
340
 
341
  if MODEL is None:
342
  raise gr.Error("Failed to load model")
343
-
344
  return MODEL
345
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  @spaces.GPU
347
  def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
348
  MODEL.set_generation_params(**gen_kwargs)
@@ -516,4 +587,13 @@ with gr.Blocks() as demo:
516
  outputs=[audio_output_0, audio_output_1]
517
  )
518
 
519
- demo.queue().launch(ssr_mode=False)
 
 
 
 
 
 
 
 
 
 
14
  from audiocraft.data.audio import audio_write
15
  from audiocraft.models import JASCO
16
  import os
17
+ import tempfile
18
  from huggingface_hub import login
19
+ from pathlib import Path
20
+
21
 
22
  title = """# 🙋🏻‍♂️Welcome to 🌟Tonic's 🎼Jasco🎶AudioCraft Demo"""
23
  description = """Facebook presents JASCO, a temporally controlled text-to-music generation model utilizing both symbolic and audio-based conditions. JASCO can generate high-quality music samples conditioned on global text descriptions along with fine-grained local controls. JASCO is based on the Flow Matching modeling paradigm together with a novel conditioning method, allowing for music generation controlled both locally (e.g., chords) and globally (text description). [run this demo locally](https://huggingface.co/spaces/Tonic/audiocraft?docker=true) or [embed this space](https://huggingface.co/spaces/Tonic/audiocraft?embed=true) or [duplicate this space](https://huggingface.co/spaces/Tonic/audiocraft?duplicate=true) to run it privately . you can also use this demo via API by clicking the link at the bottom of the page."""
 
190
  - Melody-enabled models may be slower
191
  """
192
 
193
+ MODEL = None
194
+ INTERRUPTING = False
195
+
196
  hf_token = os.environ.get('HFTOKEN')
197
  if hf_token:
198
  login(token=hf_token)
199
 
200
+ # Set up cache directory
201
+ CACHE_DIR = os.path.join(tempfile.gettempdir(), 'audiocraft_cache')
202
+ os.makedirs(CACHE_DIR, exist_ok=True)
203
+
204
+ # Set environment variables
205
+ os.environ['AUDIOCRAFT_CACHE_DIR'] = CACHE_DIR
206
+ os.environ['TORCH_HOME'] = CACHE_DIR
207
+ os.environ['HF_HOME'] = os.path.join(CACHE_DIR, 'huggingface')
208
+ os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')
209
+ os.environ['XDG_CACHE_HOME'] = CACHE_DIR
210
+
211
+ # Create necessary subdirectories
212
+ for subdir in ['models', 'cache', 'huggingface', 'drum_cache', 'transformers']:
213
+ os.makedirs(os.path.join(CACHE_DIR, subdir), exist_ok=True)
214
 
215
+ def cleanup_cache():
216
+ """Clean up temporary cache files"""
217
+ try:
218
+ import shutil
219
+ if os.path.exists(CACHE_DIR):
220
+ shutil.rmtree(CACHE_DIR)
221
+ os.makedirs(CACHE_DIR, exist_ok=True)
222
+ except Exception as e:
223
+ print(f"Error cleaning cache: {e}")
224
 
225
  def generate_chord_mappings():
226
  # Define basic chord mappings
 
333
  chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
334
  return [(x[0], float(x[1])) for x in chrd_times]
335
 
336
+ # Add this before model loading
337
+ def patch_jasco_cache():
338
+ """Monkey patch JASCO cache initialization"""
339
+ from audiocraft.modules import jasco_conditioners
340
+
341
+ original_init = jasco_conditioners.DrumConditioner.__init__
342
+
343
+ def new_init(self, *args, **kwargs):
344
+ if 'cache_path' in kwargs:
345
+ kwargs['cache_path'] = os.path.join(CACHE_DIR, 'drum_cache')
346
+ return original_init(self, *args, **kwargs)
347
+
348
+ jasco_conditioners.DrumConditioner.__init__ = new_init
349
+
350
+ # Apply the patch
351
+ patch_jasco_cache()
352
+
353
  # Create necessary directories
354
  os.makedirs("models", exist_ok=True)
355
 
356
  @spaces.GPU
357
+ def load_model(version='facebook/jasco-chords-drums-melody-400M'):
358
  global MODEL
359
  print("Loading model", version)
360
  if MODEL is None or MODEL.name != version:
361
  MODEL = None
 
 
 
 
 
 
 
 
 
 
362
  try:
363
+ # Set up custom cache paths
364
+ cache_path = os.path.join(CACHE_DIR, version.replace('/', '_'))
365
+ os.makedirs(cache_path, exist_ok=True)
366
+
367
+ # Set additional environment variables
368
+ os.environ['AUDIOCRAFT_CACHE_DIR'] = cache_path
369
+ os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_path, 'transformers')
370
+
371
+ # Initialize model with custom cache configuration
372
+ model_kwargs = {
373
+ 'device': 'cuda',
374
+ 'cache_dir': cache_path,
375
+ 'model_cache_dir': cache_path
376
+ }
377
+
378
+ # Initialize chord mapping
379
+ mapping_file = initialize_chord_mapping()
380
+ os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file
381
+
382
+ # Load the model with specific cache paths
383
  MODEL = JASCO.get_pretrained(
384
  version,
385
  device='cuda',
386
+ cache_dir=cache_path,
387
+ local_files_only=False
388
  )
389
  MODEL.name = version
390
+
391
+ # Configure model paths
392
+ MODEL._cache_dir = cache_path
393
+
394
+ # Load the chord mapping
395
+ with open(mapping_file, 'rb') as f:
396
+ MODEL.chord_to_index = pickle.load(f)
397
+
398
  except Exception as e:
399
  raise gr.Error(f"Error loading model: {str(e)}")
400
 
401
  if MODEL is None:
402
  raise gr.Error("Failed to load model")
 
403
  return MODEL
404
 
405
+ class ModelLoadingError(Exception):
406
+ pass
407
+
408
+ def handle_model_error(func):
409
+ def wrapper(*args, **kwargs):
410
+ try:
411
+ return func(*args, **kwargs)
412
+ except Exception as e:
413
+ print(f"Error in {func.__name__}: {str(e)}")
414
+ raise ModelLoadingError(f"Failed to load model: {str(e)}")
415
+ return wrapper
416
+
417
  @spaces.GPU
418
  def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
419
  MODEL.set_generation_params(**gen_kwargs)
 
587
  outputs=[audio_output_0, audio_output_1]
588
  )
589
 
590
+ # Add cleanup on close
591
+ demo.load(lambda: None, [], [], _js="() => { window.addEventListener('beforeunload', () => { cleanup_cache(); }); }")
592
+ # Launch with cleanup and error handling
593
+ try:
594
+ demo.queue().launch(ssr_mode=False
595
+ )
596
+ except Exception as e:
597
+ print(f"Error launching demo: {str(e)}")
598
+ finally:
599
+ cleanup_cache()