Tonic commited on
Commit
91d712c
Β·
unverified Β·
1 Parent(s): c1b6b5f

add chord mapping

Browse files
Files changed (4) hide show
  1. build_chord_maps.py +92 -0
  2. extract_chords.py +73 -0
  3. main.py +83 -1
  4. requirements.txt +1 -0
build_chord_maps.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import os
7
+ import pickle
8
+ from tqdm import tqdm
9
+ import argparse
10
+
11
+
12
+ def parse_args():
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument('--chords_folder', type=str, required=True,
15
+ help='path to directory containing parsed chords files')
16
+ parser.add_argument('--output_directory', type=str, required=False,
17
+ help='path to output directory to generate code maps to, \
18
+ if not given - chords_folder would be used', default='')
19
+ parser.add_argument('--path_to_pre_defined_map', type=str, required=False,
20
+ help='for evaluation purpose, use pre-defined chord-to-index map', default='')
21
+ args = parser.parse_args()
22
+ return args
23
+
24
+
25
+ def get_chord_dict(chord_folder: str):
26
+ chord_dict = {}
27
+ distinct_chords = set()
28
+
29
+ chord_to_index = {} # Mapping between chord and index
30
+ index_counter = 0
31
+
32
+ for filename in tqdm(os.listdir(chord_folder)):
33
+ if filename.endswith(".chords"):
34
+ idx = filename.split(".")[0]
35
+
36
+ with open(os.path.join(chord_folder, filename), "rb") as file:
37
+ chord_data = pickle.load(file)
38
+
39
+ for chord, _ in chord_data:
40
+ distinct_chords.add(chord)
41
+ if chord not in chord_to_index:
42
+ chord_to_index[chord] = index_counter
43
+ index_counter += 1
44
+
45
+ chord_dict[idx] = chord_data
46
+ chord_to_index["UNK"] = index_counter
47
+ return chord_dict, distinct_chords, chord_to_index
48
+
49
+
50
+ def get_predefined_chord_to_index_map(path_to_chords_to_index_map: str):
51
+ def inner(chord_folder: str):
52
+ chords_to_index = pickle.load(open(path_to_chords_to_index_map, "rb"))
53
+ distinct_chords = set(chords_to_index.keys())
54
+ chord_dict = {}
55
+ for filename in tqdm(os.listdir(chord_folder), desc=f'iterating: {chord_folder}'):
56
+ if filename.endswith(".chords"):
57
+ idx = filename.split(".")[0]
58
+
59
+ with open(os.path.join(chord_folder, filename), "rb") as file:
60
+ chord_data = pickle.load(file)
61
+
62
+ chord_dict[idx] = chord_data
63
+ return chord_dict, distinct_chords, chords_to_index
64
+ return inner
65
+
66
+
67
+ if __name__ == "__main__":
68
+ '''This script processes and maps chord data from a directory of parsed chords files,
69
+ generating two output files: a combined chord dictionary and a chord-to-index mapping.'''
70
+ args = parse_args()
71
+ chord_folder = args.chords_folder
72
+ output_dir = args.output_directory
73
+ if output_dir == '':
74
+ output_dir = chord_folder
75
+ func = get_chord_dict
76
+ if args.path_to_pre_defined_map != "":
77
+ func = get_predefined_chord_to_index_map(args.path_to_pre_defined_map)
78
+
79
+ chord_dict, distinct_chords, chord_to_index = func(chord_folder)
80
+
81
+ # Save the combined chord dictionary as a pickle file
82
+ combined_filename = os.path.join(output_dir, "combined_chord_dict.pkl")
83
+ with open(combined_filename, "wb") as file:
84
+ pickle.dump(chord_dict, file)
85
+
86
+ # Save the chord-to-index mapping as a pickle file
87
+ mapping_filename = os.path.join(output_dir, "chord_to_index_mapping.pkl")
88
+ with open(mapping_filename, "wb") as file:
89
+ pickle.dump(chord_to_index, file)
90
+
91
+ print("Number of distinct chords:", len(distinct_chords))
92
+ print("Chord dictionary:", chord_to_index)
extract_chords.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Env - chords_extraction on devfair
2
+
3
+ import pickle
4
+ import argparse
5
+ from chord_extractor.extractors import Chordino # type: ignore
6
+ from chord_extractor import clear_conversion_cache, LabelledChordSequence # type: ignore
7
+ import os
8
+ from tqdm import tqdm
9
+
10
+
11
+ def parse_args():
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--src_jsonl_file', type=str, required=True,
14
+ help='abs path to .jsonl file containing list of absolute file paths seperated by new line')
15
+ parser.add_argument('--target_output_dir', type=str, required=True,
16
+ help='target directory to save parsed chord files to, individual files will be saved inside')
17
+ parser.add_argument("--override", action="store_true")
18
+ args = parser.parse_args()
19
+ return args
20
+
21
+
22
+ def save_to_db_cb(tgt_dir: str):
23
+ # Every time one of the files has had chords extracted, receive the chords here
24
+ # along with the name of the original file and then run some logic here, e.g. to
25
+ # save the latest data to DB
26
+ def inner(results: LabelledChordSequence):
27
+ path = results.id.split(".wav")
28
+
29
+ sequence = [(item.chord, item.timestamp) for item in results.sequence]
30
+
31
+ if len(path) != 2:
32
+ print("Something")
33
+ print(path)
34
+ else:
35
+ file_idx = path[0].split("/")[-1]
36
+ with open(f"{tgt_dir}/{file_idx}.chords", "wb") as f:
37
+ # dump the object to the file
38
+ pickle.dump(sequence, f)
39
+ return inner
40
+
41
+
42
+ if __name__ == "__main__":
43
+ '''This script extracts chord data from a list of audio files using the Chordino extractor,
44
+ and saves the extracted chords to individual files in a target directory.'''
45
+ print("parsed args")
46
+ args = parse_args()
47
+ files_to_extract_from = list()
48
+ with open(args.src_jsonl_file, "r") as json_file:
49
+ for line in tqdm(json_file.readlines()):
50
+ # fpath = json.loads(line.replace("\n", ""))['path']
51
+ fpath = line.replace("\n", "")
52
+ if not args.override:
53
+ fname = fpath.split("/")[-1].replace(".wav", ".chords")
54
+ if os.path.exists(f"{args.target_output_dir}/{fname}"):
55
+ continue
56
+ files_to_extract_from.append(line.replace("\n", ""))
57
+
58
+ print(f"num files to parse: {len(files_to_extract_from)}")
59
+
60
+ chordino = Chordino()
61
+
62
+ # Optionally clear cache of file conversions (e.g. wav files that have been converted from midi)
63
+ clear_conversion_cache()
64
+
65
+ # Run bulk extraction
66
+ res = chordino.extract_many(
67
+ files_to_extract_from,
68
+ callback=save_to_db_cb(args.target_output_dir),
69
+ num_extractors=80,
70
+ num_preprocessors=80,
71
+ max_files_in_cache=400,
72
+ stop_on_error=False,
73
+ )
main.py CHANGED
@@ -1,6 +1,7 @@
1
  import spaces
2
  import logging
3
  import os
 
4
  from concurrent.futures import ProcessPoolExecutor
5
  from pathlib import Path
6
  from tempfile import NamedTemporaryFile
@@ -23,7 +24,72 @@ MODEL = None
23
  MAX_BATCH_SIZE = 12
24
  INTERRUPTING = False
25
 
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Wrap subprocess call to clean logs
29
  _old_call = sp.call
@@ -80,14 +146,30 @@ def load_model(version='facebook/jasco-chords-drums-400M'):
80
  print("Loading model", version)
81
  if MODEL is None or MODEL.name != version:
82
  MODEL = None
 
 
 
 
 
 
 
 
 
 
83
  try:
84
- MODEL = JASCO.get_pretrained(version, device='cuda')
 
 
 
 
 
85
  MODEL.name = version
86
  except Exception as e:
87
  raise gr.Error(f"Error loading model: {str(e)}")
88
 
89
  if MODEL is None:
90
  raise gr.Error("Failed to load model")
 
91
  return MODEL
92
 
93
  @spaces.GPU
 
1
  import spaces
2
  import logging
3
  import os
4
+ import pickle
5
  from concurrent.futures import ProcessPoolExecutor
6
  from pathlib import Path
7
  from tempfile import NamedTemporaryFile
 
24
  MAX_BATCH_SIZE = 12
25
  INTERRUPTING = False
26
 
27
+ os.makedirs(os.path.join(os.path.dirname(__file__), "models"), exist_ok=True)
28
 
29
+ def generate_chord_mappings():
30
+ # Define basic chord mappings
31
+ basic_chords = ['N', 'C', 'Dm7', 'Am', 'F', 'D', 'Ab', 'Bb'] + ['UNK']
32
+ chord_to_index = {chord: idx for idx, chord in enumerate(basic_chords)}
33
+
34
+ # Save the mapping
35
+ mapping_path = os.path.join(os.path.dirname(__file__), "models", "chord_to_index_mapping.pkl")
36
+ os.makedirs(os.path.dirname(mapping_path), exist_ok=True)
37
+
38
+ with open(mapping_path, "wb") as f:
39
+ pickle.dump(chord_to_index, f)
40
+
41
+ return mapping_path
42
+
43
+ def create_default_chord_mapping():
44
+ """Create a basic chord-to-index mapping with common chords"""
45
+ basic_chords = [
46
+ 'N', 'C', 'Cm', 'C7', 'Cmaj7', 'Cm7',
47
+ 'D', 'Dm', 'D7', 'Dmaj7', 'Dm7',
48
+ 'E', 'Em', 'E7', 'Emaj7', 'Em7',
49
+ 'F', 'Fm', 'F7', 'Fmaj7', 'Fm7',
50
+ 'G', 'Gm', 'G7', 'Gmaj7', 'Gm7',
51
+ 'A', 'Am', 'A7', 'Amaj7', 'Am7',
52
+ 'B', 'Bm', 'B7', 'Bmaj7', 'Bm7',
53
+ 'Ab', 'Abm', 'Ab7', 'Abmaj7', 'Abm7',
54
+ 'Bb', 'Bbm', 'Bb7', 'Bbmaj7', 'Bbm7',
55
+ 'UNK'
56
+ ]
57
+ return {chord: idx for idx, chord in enumerate(basic_chords)}
58
+
59
+ def initialize_chord_mapping():
60
+ """Initialize chord mapping file if it doesn't exist"""
61
+ mapping_dir = os.path.join(os.path.dirname(__file__), "models")
62
+ os.makedirs(mapping_dir, exist_ok=True)
63
+
64
+ mapping_file = os.path.join(mapping_dir, "chord_to_index_mapping.pkl")
65
+
66
+ if not os.path.exists(mapping_file):
67
+ chord_to_index = create_default_chord_mapping()
68
+ with open(mapping_file, "wb") as f:
69
+ pickle.dump(chord_to_index, f)
70
+
71
+ return mapping_file
72
+
73
+ def validate_chord(chord, chord_mapping):
74
+ if chord not in chord_mapping:
75
+ return 'UNK'
76
+ return chord
77
+
78
+ mapping_file = initialize_chord_mapping()
79
+ os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file
80
+
81
+ def chords_string_to_list(chords: str):
82
+ if chords == '':
83
+ return []
84
+ chords = chords.replace('[', '').replace(']', '').replace(' ', '')
85
+ chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
86
+
87
+ # Load chord mapping
88
+ mapping_path = os.path.join(os.path.dirname(__file__), "models", "chord_to_index_mapping.pkl")
89
+ with open(mapping_path, 'rb') as f:
90
+ chord_mapping = pickle.load(f)
91
+
92
+ return [(validate_chord(x[0], chord_mapping), float(x[1])) for x in chrd_times]
93
 
94
  # Wrap subprocess call to clean logs
95
  _old_call = sp.call
 
146
  print("Loading model", version)
147
  if MODEL is None or MODEL.name != version:
148
  MODEL = None
149
+
150
+ # Setup model directory
151
+ model_dir = os.path.join(os.path.dirname(__file__), "models")
152
+ os.makedirs(model_dir, exist_ok=True)
153
+
154
+ # Generate and save chord mappings
155
+ chord_mapping_path = os.path.join(model_dir, "chord_to_index_mapping.pkl")
156
+ if not os.path.exists(chord_mapping_path):
157
+ chord_mapping_path = generate_chord_mappings()
158
+
159
  try:
160
+ # Initialize JASCO with the chord mapping path
161
+ MODEL = JASCO.get_pretrained(
162
+ version,
163
+ device='cuda',
164
+ chords_mapping_path=chord_mapping_path
165
+ )
166
  MODEL.name = version
167
  except Exception as e:
168
  raise gr.Error(f"Error loading model: {str(e)}")
169
 
170
  if MODEL is None:
171
  raise gr.Error("Failed to load model")
172
+
173
  return MODEL
174
 
175
  @spaces.GPU
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  numpy<2.0.0
2
  torch>=2.0.0
 
3
  transformers
4
  accelerate
5
  git+https://github.com/facebookresearch/audiocraft.git
 
1
  numpy<2.0.0
2
  torch>=2.0.0
3
+ torchaudio
4
  transformers
5
  accelerate
6
  git+https://github.com/facebookresearch/audiocraft.git