Spaces:
Running
on
Zero
Running
on
Zero
fix directory arguments
Browse files
main.py
CHANGED
@@ -333,19 +333,27 @@ def chords_string_to_list(chords: str):
|
|
333 |
chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
|
334 |
return [(x[0], float(x[1])) for x in chrd_times]
|
335 |
|
336 |
-
# Add this before model loading
|
337 |
def patch_jasco_cache():
|
338 |
"""Monkey patch JASCO cache initialization"""
|
339 |
from audiocraft.modules import jasco_conditioners
|
340 |
|
341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
|
343 |
def new_init(self, *args, **kwargs):
|
344 |
if 'cache_path' in kwargs:
|
345 |
kwargs['cache_path'] = os.path.join(CACHE_DIR, 'drum_cache')
|
346 |
return original_init(self, *args, **kwargs)
|
347 |
|
348 |
-
jasco_conditioners
|
|
|
|
|
|
|
349 |
|
350 |
# Apply the patch
|
351 |
patch_jasco_cache()
|
@@ -364,32 +372,24 @@ def load_model(version='facebook/jasco-chords-drums-melody-400M'):
|
|
364 |
cache_path = os.path.join(CACHE_DIR, version.replace('/', '_'))
|
365 |
os.makedirs(cache_path, exist_ok=True)
|
366 |
|
367 |
-
# Set
|
368 |
os.environ['AUDIOCRAFT_CACHE_DIR'] = cache_path
|
369 |
os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_path, 'transformers')
|
370 |
|
371 |
-
# Initialize model with custom cache configuration
|
372 |
-
model_kwargs = {
|
373 |
-
'device': 'cuda',
|
374 |
-
'cache_dir': cache_path,
|
375 |
-
'model_cache_dir': cache_path
|
376 |
-
}
|
377 |
-
|
378 |
# Initialize chord mapping
|
379 |
mapping_file = initialize_chord_mapping()
|
380 |
os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file
|
381 |
|
382 |
-
# Load the model with
|
383 |
MODEL = JASCO.get_pretrained(
|
384 |
version,
|
385 |
-
device='cuda'
|
386 |
-
cache_dir=cache_path,
|
387 |
-
local_files_only=False
|
388 |
)
|
389 |
MODEL.name = version
|
390 |
|
391 |
-
# Configure model paths
|
392 |
-
MODEL
|
|
|
393 |
|
394 |
# Load the chord mapping
|
395 |
with open(mapping_file, 'rb') as f:
|
|
|
333 |
chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
|
334 |
return [(x[0], float(x[1])) for x in chrd_times]
|
335 |
|
|
|
336 |
def patch_jasco_cache():
|
337 |
"""Monkey patch JASCO cache initialization"""
|
338 |
from audiocraft.modules import jasco_conditioners
|
339 |
|
340 |
+
if hasattr(jasco_conditioners, 'DrumConditioner'):
|
341 |
+
original_init = jasco_conditioners.DrumConditioner.__init__
|
342 |
+
elif hasattr(jasco_conditioners, 'DrumsConditioner'):
|
343 |
+
original_init = jasco_conditioners.DrumsConditioner.__init__
|
344 |
+
else:
|
345 |
+
print("Warning: Could not find DrumConditioner class")
|
346 |
+
return
|
347 |
|
348 |
def new_init(self, *args, **kwargs):
|
349 |
if 'cache_path' in kwargs:
|
350 |
kwargs['cache_path'] = os.path.join(CACHE_DIR, 'drum_cache')
|
351 |
return original_init(self, *args, **kwargs)
|
352 |
|
353 |
+
if hasattr(jasco_conditioners, 'DrumConditioner'):
|
354 |
+
jasco_conditioners.DrumConditioner.__init__ = new_init
|
355 |
+
elif hasattr(jasco_conditioners, 'DrumsConditioner'):
|
356 |
+
jasco_conditioners.DrumsConditioner.__init__ = new_init
|
357 |
|
358 |
# Apply the patch
|
359 |
patch_jasco_cache()
|
|
|
372 |
cache_path = os.path.join(CACHE_DIR, version.replace('/', '_'))
|
373 |
os.makedirs(cache_path, exist_ok=True)
|
374 |
|
375 |
+
# Set environment variables for caching
|
376 |
os.environ['AUDIOCRAFT_CACHE_DIR'] = cache_path
|
377 |
os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_path, 'transformers')
|
378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
# Initialize chord mapping
|
380 |
mapping_file = initialize_chord_mapping()
|
381 |
os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file
|
382 |
|
383 |
+
# Load the model with only supported parameters
|
384 |
MODEL = JASCO.get_pretrained(
|
385 |
version,
|
386 |
+
device='cuda'
|
|
|
|
|
387 |
)
|
388 |
MODEL.name = version
|
389 |
|
390 |
+
# Configure model paths after loading
|
391 |
+
if hasattr(MODEL, '_cache_dir'):
|
392 |
+
MODEL._cache_dir = cache_path
|
393 |
|
394 |
# Load the chord mapping
|
395 |
with open(mapping_file, 'rb') as f:
|