Tonic commited on
Commit
d4c4b8e
Β·
unverified Β·
1 Parent(s): 8493b64

fix directory arguments

Browse files
Files changed (1) hide show
  1. main.py +17 -17
main.py CHANGED
@@ -333,19 +333,27 @@ def chords_string_to_list(chords: str):
333
  chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
334
  return [(x[0], float(x[1])) for x in chrd_times]
335
 
336
- # Add this before model loading
337
  def patch_jasco_cache():
338
  """Monkey patch JASCO cache initialization"""
339
  from audiocraft.modules import jasco_conditioners
340
 
341
- original_init = jasco_conditioners.DrumsConditioner.__init__
 
 
 
 
 
 
342
 
343
  def new_init(self, *args, **kwargs):
344
  if 'cache_path' in kwargs:
345
  kwargs['cache_path'] = os.path.join(CACHE_DIR, 'drum_cache')
346
  return original_init(self, *args, **kwargs)
347
 
348
- jasco_conditioners.DrumsConditioner.__init__ = new_init
 
 
 
349
 
350
  # Apply the patch
351
  patch_jasco_cache()
@@ -364,32 +372,24 @@ def load_model(version='facebook/jasco-chords-drums-melody-400M'):
364
  cache_path = os.path.join(CACHE_DIR, version.replace('/', '_'))
365
  os.makedirs(cache_path, exist_ok=True)
366
 
367
- # Set additional environment variables
368
  os.environ['AUDIOCRAFT_CACHE_DIR'] = cache_path
369
  os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_path, 'transformers')
370
 
371
- # Initialize model with custom cache configuration
372
- model_kwargs = {
373
- 'device': 'cuda',
374
- 'cache_dir': cache_path,
375
- 'model_cache_dir': cache_path
376
- }
377
-
378
  # Initialize chord mapping
379
  mapping_file = initialize_chord_mapping()
380
  os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file
381
 
382
- # Load the model with specific cache paths
383
  MODEL = JASCO.get_pretrained(
384
  version,
385
- device='cuda',
386
- cache_dir=cache_path,
387
- local_files_only=False
388
  )
389
  MODEL.name = version
390
 
391
- # Configure model paths
392
- MODEL._cache_dir = cache_path
 
393
 
394
  # Load the chord mapping
395
  with open(mapping_file, 'rb') as f:
 
333
  chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
334
  return [(x[0], float(x[1])) for x in chrd_times]
335
 
 
336
  def patch_jasco_cache():
337
  """Monkey patch JASCO cache initialization"""
338
  from audiocraft.modules import jasco_conditioners
339
 
340
+ if hasattr(jasco_conditioners, 'DrumConditioner'):
341
+ original_init = jasco_conditioners.DrumConditioner.__init__
342
+ elif hasattr(jasco_conditioners, 'DrumsConditioner'):
343
+ original_init = jasco_conditioners.DrumsConditioner.__init__
344
+ else:
345
+ print("Warning: Could not find DrumConditioner class")
346
+ return
347
 
348
  def new_init(self, *args, **kwargs):
349
  if 'cache_path' in kwargs:
350
  kwargs['cache_path'] = os.path.join(CACHE_DIR, 'drum_cache')
351
  return original_init(self, *args, **kwargs)
352
 
353
+ if hasattr(jasco_conditioners, 'DrumConditioner'):
354
+ jasco_conditioners.DrumConditioner.__init__ = new_init
355
+ elif hasattr(jasco_conditioners, 'DrumsConditioner'):
356
+ jasco_conditioners.DrumsConditioner.__init__ = new_init
357
 
358
  # Apply the patch
359
  patch_jasco_cache()
 
372
  cache_path = os.path.join(CACHE_DIR, version.replace('/', '_'))
373
  os.makedirs(cache_path, exist_ok=True)
374
 
375
+ # Set environment variables for caching
376
  os.environ['AUDIOCRAFT_CACHE_DIR'] = cache_path
377
  os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_path, 'transformers')
378
 
 
 
 
 
 
 
 
379
  # Initialize chord mapping
380
  mapping_file = initialize_chord_mapping()
381
  os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file
382
 
383
+ # Load the model with only supported parameters
384
  MODEL = JASCO.get_pretrained(
385
  version,
386
+ device='cuda'
 
 
387
  )
388
  MODEL.name = version
389
 
390
+ # Configure model paths after loading
391
+ if hasattr(MODEL, '_cache_dir'):
392
+ MODEL._cache_dir = cache_path
393
 
394
  # Load the chord mapping
395
  with open(mapping_file, 'rb') as f: