mrfakename commited on
Commit
aa189b2
1 Parent(s): fb85ce0

Create midi_util.py

Browse files
Files changed (1) hide show
  1. midi_util.py +576 -0
midi_util.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from dataclasses import dataclass
4
+ from functools import lru_cache
5
+ from math import ceil, floor, log
6
+ from typing import Dict, Iterator, List, Optional, Tuple
7
+
8
+ import mido
9
+
10
+
11
+ @dataclass
12
+ class VocabConfig:
13
+ # Number of note events. Should be 128.
14
+ note_events: int
15
+ # Number of wait events. Configurable, must evenly divide max_wait_time.
16
+ wait_events: int
17
+ # Max wait time in milliseconds to be represented by a single token.
18
+ max_wait_time: int
19
+ # Number of velocity events. Should be 128 (or 100? need to check midi standard)
20
+ velocity_events: int
21
+ # Number of bins to quantize velocity into. Should evenly divide velocity_events.
22
+ velocity_bins: int
23
+ # Exponential scaling factor for velocity bin sizes. 1.0 = linear scaling.
24
+ velocity_exp: float
25
+ # Whether to sort tokens by instrument, note. This should improve data reducibility.
26
+ do_token_sorting: bool
27
+ # Whether tokens should be represented as combined instrument/note/velocity tokens, or separate tokens for each.
28
+ unrolled_tokens: bool
29
+ # If non-zero, notes held for this many seconds will be automatically released during str->midi decoding.
30
+ decode_end_held_note_delay: float
31
+ # If true, repeated notes will be automatically released before playing again during str->midi decoding.
32
+ decode_fix_repeated_notes: bool
33
+ # List of instrument names to use for binning. Must have at most 16 values.
34
+ bin_instrument_names: List[str]
35
+ # Indicates which bin name represents percussion instruments on MIDI channel 10.
36
+ ch10_instrument_bin_name: str
37
+ # Mapping from instrument name to bin name.
38
+ program_name_to_bin_name: Dict[str, str]
39
+ # Mapping from bin name to program name.
40
+ bin_name_to_program_name: Dict[str, str]
41
+ # Mapping from program number to instrument name.
42
+ instrument_names: Dict[str, str]
43
+ # Manual override for velocity bins. Each element is the max velocity value for that bin by index.
44
+ velocity_bins_override: Optional[List[int]] = None
45
+
46
+ def __post_init__(self):
47
+ self.validate()
48
+
49
+ self._instrument_names_str_to_int = {name: int(i) for i, name in self.instrument_names.items()}
50
+ self._instrument_names_int_to_str = {int(i): name for i, name in self.instrument_names.items()}
51
+
52
+ self._bin_str_to_int = {name: int(i) for i, name in enumerate(self.bin_instrument_names)}
53
+
54
+ self._bin_int_to_instrument_int = [self._instrument_names_str_to_int[self.bin_name_to_program_name[name]] if name != self.ch10_instrument_bin_name else 0 for name in self.bin_instrument_names]
55
+ self._instrument_int_to_bin_int = [self._bin_str_to_int[self.program_name_to_bin_name[instr]] if self.program_name_to_bin_name[instr] != "" else -1 for instr in self.program_name_to_bin_name.keys()]
56
+
57
+ self._ch10_bin_int = self._bin_str_to_int[self.ch10_instrument_bin_name] if self.ch10_instrument_bin_name else -1
58
+
59
+ self.short_instr_bin_names = []
60
+ for instr in self.bin_instrument_names:
61
+ i = min(1, len(instr))
62
+ while instr[:i] in self.short_instr_bin_names:
63
+ i += 1
64
+ self.short_instr_bin_names.append(instr[:i])
65
+ self._short_instrument_names_str_to_int = {name: int(i) for i, name in enumerate(self.short_instr_bin_names)}
66
+
67
+ range_excluding_ch10 = [(i if i < 9 else i+1) for i in range(len(self.bin_instrument_names))]
68
+ bins_excluding_ch10 = [n for n in self.bin_instrument_names if n != self.ch10_instrument_bin_name]
69
+ self.bin_channel_map = {bin: channel for channel, bin in zip(range_excluding_ch10, bins_excluding_ch10)}
70
+ if self.ch10_instrument_bin_name:
71
+ self.bin_channel_map[self.ch10_instrument_bin_name] = 9
72
+
73
+ def validate(self):
74
+ if self.max_wait_time % self.wait_events != 0:
75
+ raise ValueError("max_wait_time must be exactly divisible by wait_events")
76
+ if self.velocity_bins < 2:
77
+ raise ValueError("velocity_bins must be at least 2")
78
+ if len(self.bin_instrument_names) > 16:
79
+ raise ValueError("bin_instruments must have at most 16 values")
80
+ if self.velocity_bins_override:
81
+ print("VocabConfig is using velocity_bins_override. Ignoring velocity_exp.")
82
+ if len(self.velocity_bins_override) != self.velocity_bins:
83
+ raise ValueError("velocity_bins_override must have same length as velocity_bins")
84
+ if self.ch10_instrument_bin_name and self.ch10_instrument_bin_name not in self.bin_instrument_names:
85
+ raise ValueError("ch10_instrument_bin_name must be in bin_instruments")
86
+ if self.velocity_exp <= 0:
87
+ raise ValueError("velocity_exp must be greater than 0")
88
+
89
+ @classmethod
90
+ def from_json(cls, path: str):
91
+ with open(path, "r") as f:
92
+ config = json.load(f)
93
+ return cls(**config)
94
+
95
+
96
+ class VocabUtils:
97
+ def __init__(self, cfg: VocabConfig) -> None:
98
+ self.cfg = cfg
99
+
100
+ @lru_cache(maxsize=128)
101
+ def format_wait_token(self, wait: int) -> str:
102
+ return f"t{wait}"
103
+
104
+ @lru_cache(maxsize=128)
105
+ def format_note_token(self, instrument_bin: int, note: int, velocity_bin: int) -> str:
106
+ return f"{self.cfg.short_instr_bin_names[instrument_bin]}:{note:x}:{velocity_bin:x}"
107
+
108
+ def format_unrolled_note(self, note: int) -> str:
109
+ return f"n{note:x}"
110
+
111
+ def format_unrolled_velocity(self, velocity_bin: int) -> str:
112
+ return f"v{velocity_bin:x}"
113
+
114
+ def format_unrolled_instrument_bin(self, instrument_bin: int) -> str:
115
+ return f"i{self.cfg.short_instr_bin_names[instrument_bin]}"
116
+
117
+ def velocity_to_bin(self, velocity: float) -> int:
118
+ velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
119
+ if self.cfg.velocity_bins_override:
120
+ for i, v in enumerate(self.cfg.velocity_bins_override):
121
+ if velocity <= v:
122
+ return i
123
+ return 0
124
+ binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
125
+ if self.cfg.velocity_exp == 1.0:
126
+ return ceil(velocity / binsize)
127
+ else:
128
+ return ceil((self.cfg.velocity_events*((self.cfg.velocity_exp**(velocity/self.cfg.velocity_events)-1.0) / (self.cfg.velocity_exp-1.0))) / binsize)
129
+
130
+ def bin_to_velocity(self, bin: int) -> int:
131
+ if self.cfg.velocity_bins_override:
132
+ return self.cfg.velocity_bins_override[bin]
133
+ binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
134
+ if self.cfg.velocity_exp == 1.0:
135
+ return max(0, ceil(bin * binsize - 1))
136
+ else:
137
+ return max(0, ceil(self.cfg.velocity_events*log(((self.cfg.velocity_exp-1)*binsize*bin)/self.cfg.velocity_events+1, self.cfg.velocity_exp) - 1))
138
+
139
+ def delta_to_wait_ids(self, delta_ms: float) -> Iterator[int]:
140
+ def roundi(f: float):
141
+ return ceil(f - 0.5)
142
+
143
+ max_wait_ms = self.cfg.max_wait_time
144
+ div = max_wait_ms / self.cfg.wait_events
145
+
146
+ #if delta_ms // max_wait_ms > 512: # arbitrary limit to avoid excessive time_shifts
147
+ # raise ValueError("delta_time is too large")
148
+ if delta_ms > max_wait_ms * 10:
149
+ delta_ms = max_wait_ms * 10 # truncate time
150
+
151
+ for _ in range(floor(delta_ms / max_wait_ms)):
152
+ yield roundi(max_wait_ms / div)
153
+ leftover_time_shift = roundi((delta_ms % max_wait_ms) / div)
154
+ if leftover_time_shift > 0:
155
+ yield leftover_time_shift
156
+
157
+ def prog_data_to_token_data(self, program: int, channel: int, note: int, velocity: float) -> Optional[Tuple[int, int, int]]:
158
+ if channel == 9:
159
+ if self.cfg._ch10_bin_int == -1:
160
+ return None
161
+ return self.cfg._ch10_bin_int, note, self.velocity_to_bin(velocity)
162
+
163
+ instrument_bin = self.cfg._instrument_int_to_bin_int[program]
164
+ if instrument_bin != -1:
165
+ return instrument_bin, note, self.velocity_to_bin(velocity)
166
+ return None
167
+
168
+ def prog_data_list_to_token_data_list(self, data: List[Tuple[int, int, int, float]]) -> Iterator[Tuple[int, int, int]]:
169
+ for d in data:
170
+ token_data = self.prog_data_to_token_data(*d)
171
+ if token_data is not None:
172
+ yield token_data
173
+
174
+ def sort_token_data(self, data: List[Tuple[int, int, int]]) -> List[Tuple[int, int, int]]:
175
+ # ensure order is preserved for tokens with the same instrument, note
176
+ data = [(i, n, v, x) for x, (i, n, v) in enumerate(data)]
177
+ data.sort(key=lambda x: (x[0]!=self.cfg._ch10_bin_int, x[0], x[1], x[3]))
178
+ return [(i, n, v) for i, n, v, _ in data]
179
+
180
+ def data_to_wait_tokens(self, delta_ms: float) -> List[str]:
181
+ if delta_ms == 0.0:
182
+ return []
183
+ return [self.format_wait_token(i) for i in self.delta_to_wait_ids(delta_ms)]
184
+
185
+ def wait_token_to_delta(self, token: str) -> float:
186
+ return self.cfg.max_wait_time / self.cfg.wait_events * int(token[1:])
187
+
188
+ def note_token_to_data(self, token: str) -> Tuple[int, int, int]:
189
+ instr_str, note_str, velocity_str = token.strip().split(":")
190
+ instr_bin = self.cfg._short_instrument_names_str_to_int[instr_str]
191
+ note = int(note_str, base=16)
192
+ velocity = self.bin_to_velocity(int(velocity_str, base=16))
193
+ return instr_bin, note, velocity
194
+
195
+
196
+ @dataclass
197
+ class AugmentValues:
198
+ instrument_bin_remap: Dict[int, int]
199
+ velocity_mod_factor: float
200
+ transpose_semitones: int
201
+ time_stretch_factor: float
202
+
203
+ @classmethod
204
+ def default(cls) -> "AugmentValues":
205
+ return cls(
206
+ instrument_bin_remap={},
207
+ velocity_mod_factor=1.0,
208
+ transpose_semitones=0,
209
+ time_stretch_factor=1.0,
210
+ )
211
+
212
+
213
+ @dataclass
214
+ class AugmentConfig:
215
+ # The number of times to augment each MIDI file. The dataset size will be multiplied by this number.
216
+ augment_data_factor: int
217
+ # A list of instrument names to randomly swap with each other.
218
+ instrument_mixups: List[List[str]]
219
+ # A list of percentages to change the note velocity by. 0.0 = no change. 0 is included by default.
220
+ velocity_mod_pct: List[float]
221
+ # A list of semitones to transpose by. 0 is included by default.
222
+ transpose_semitones: List[int]
223
+ # A list of percentages to stretch the tempo by. 0.0 = no stretch. 0 is included by default.
224
+ time_stretch_pct: List[float]
225
+ # Random seed to use for reproducibility.
226
+ seed: int
227
+
228
+ cfg: VocabConfig
229
+
230
+ def __post_init__(self):
231
+ self.validate()
232
+ if len(self.velocity_mod_pct) == 0:
233
+ self.velocity_mod_pct = [0.0]
234
+ if len(self.transpose_semitones) == 0:
235
+ self.transpose_semitones = [0]
236
+ if len(self.time_stretch_pct) == 0:
237
+ self.time_stretch_pct = [0.0]
238
+
239
+ self._instrument_mixups_int = [[self.cfg._bin_str_to_int[i] for i in l if i in self.cfg._bin_str_to_int] for l in self.instrument_mixups]
240
+ self._instrument_mixups_int = [l for l in self._instrument_mixups_int if len(l) > 0] # remove empty lists
241
+ self._instrument_pool_assignments = {}
242
+ self._mixup_pools = []
243
+ for pool_i, mixup_list in enumerate(self._instrument_mixups_int):
244
+ pool = set()
245
+ for i in mixup_list:
246
+ pool.add(i)
247
+ self._instrument_pool_assignments[i] = pool_i
248
+ self._mixup_pools.append(pool)
249
+
250
+
251
+ def validate(self):
252
+ if self.augment_data_factor < 1:
253
+ raise ValueError("augment_data_factor must be at least 1")
254
+ used_instruments = set()
255
+ for mixup_list in self.instrument_mixups:
256
+ for n in mixup_list:
257
+ if n in used_instruments:
258
+ raise ValueError(f"Duplicate instrument name: {n}")
259
+ used_instruments.add(n)
260
+
261
+ @classmethod
262
+ def from_json(cls, path: str, cfg: VocabConfig):
263
+ with open(path, "r") as f:
264
+ config = json.load(f)
265
+ config["cfg"] = cfg
266
+ if "seed" not in config:
267
+ config["seed"] = random.randint(0, 2**32 - 1)
268
+ return cls(**config)
269
+
270
+ def get_augment_values(self, filename: str) -> Iterator[AugmentValues]:
271
+ # first yield default values
272
+ yield AugmentValues.default()
273
+
274
+ rng = random.Random(self.seed + hash(filename))
275
+ for _ in range(int(self.augment_data_factor - 1)):
276
+ # randomize order for each pool
277
+ randomized_pools = [list(pool) for pool in self._mixup_pools]
278
+ for pool in randomized_pools:
279
+ rng.shuffle(pool)
280
+ # distribute reassignments
281
+ instrument_bin_remap = {}
282
+ for i, pool in enumerate(randomized_pools):
283
+ for j, instrument in enumerate(pool):
284
+ instrument_bin_remap[instrument] = randomized_pools[i - 1][j]
285
+ yield AugmentValues(
286
+ instrument_bin_remap=instrument_bin_remap,
287
+ velocity_mod_factor=1.0 + rng.choice(self.velocity_mod_pct),
288
+ transpose_semitones=rng.choice(self.transpose_semitones),
289
+ time_stretch_factor=1.0 + rng.choice(self.time_stretch_pct),
290
+ )
291
+
292
+
293
+ @dataclass
294
+ class FilterConfig:
295
+ # Whether to filter out MIDI files with duplicate MD5 hashes.
296
+ deduplicate_md5: bool
297
+ # Minimum time delay between notes in a file before splitting into multiple documents.
298
+ piece_split_delay: float
299
+ # Minimum length of a piece in milliseconds.
300
+ min_piece_length: float
301
+
302
+ @classmethod
303
+ def from_json(cls, path: str):
304
+ with open(path, "r") as f:
305
+ config = json.load(f)
306
+ return cls(**config)
307
+
308
+
309
+ def mix_volume(velocity: int, volume: int, expression: int) -> float:
310
+ return velocity * (volume / 127.0) * (expression / 127.0)
311
+
312
+
313
+ def convert_midi_to_str(cfg: VocabConfig, filter_cfg: FilterConfig, mid: mido.MidiFile, augment: AugmentValues = None) -> List[str]:
314
+ utils = VocabUtils(cfg)
315
+ if augment is None:
316
+ augment = AugmentValues.default()
317
+
318
+ # filter out unknown meta messages before merge (https://github.com/mido/mido/pull/286)
319
+ for i in range(len(mid.tracks)):
320
+ mid.tracks[i] = [msg for msg in mid.tracks[i] if msg.type != "unknown_meta"]
321
+
322
+ if len(mid.tracks) > 1:
323
+ mid.tracks = [mido.merge_tracks(mid.tracks)]
324
+
325
+ delta_time_ms = 0.0
326
+ tempo = 500000
327
+ channel_program = {i: 0 for i in range(16)}
328
+ channel_volume = {i: 127 for i in range(16)}
329
+ channel_expression = {i: 127 for i in range(16)} # unlikely to be useful. expression usually modifies an already played note.
330
+ channel_notes = {i: {} for i in range(16)}
331
+ channel_pedal_on = {i: False for i in range(16)}
332
+ channel_pedal_events = {i: {} for i in range(16)} # {channel: {(note, program) -> True}}
333
+ started_flag = False
334
+
335
+ output_list = []
336
+ output = ["<start>"]
337
+ output_length_ms = 0.0
338
+ token_data_buffer: List[Tuple[int, int, int, float]] = [] # need to sort notes between wait tokens
339
+
340
+ def flush_token_data_buffer():
341
+ nonlocal token_data_buffer, output, cfg, utils, augment
342
+ token_data = [x for x in utils.prog_data_list_to_token_data_list(token_data_buffer)]
343
+ if augment.instrument_bin_remap or augment.transpose_semitones:
344
+ # TODO put transpose in a real function
345
+ raw_transpose = lambda bin, n: n + augment.transpose_semitones if bin != cfg._ch10_bin_int else n
346
+ octave_shift_if_oob = lambda n: n + 12 if n < 0 else n - 12 if n >= cfg.note_events else n
347
+ # TODO handle ranges beyond 12
348
+ #octave_shift_if_oob = lambda n: 0 if n < 0 else (n - cfg.note_events) % 12 + cfg.note_events if n >= cfg.note_events else n
349
+ transpose = lambda bin, n: octave_shift_if_oob(raw_transpose(bin, n))
350
+
351
+ token_data = [(augment.instrument_bin_remap.get(i, i), transpose(i, n), v) for i, n, v in token_data]
352
+ if cfg.do_token_sorting:
353
+ token_data = utils.sort_token_data(token_data)
354
+ if cfg.unrolled_tokens:
355
+ for t in token_data:
356
+ output += [utils.format_unrolled_instrument_bin(t[0]), utils.format_unrolled_note(t[1]), utils.format_unrolled_velocity(t[2])]
357
+ else:
358
+ output += [utils.format_note_token(*t) for t in token_data]
359
+ token_data_buffer = []
360
+
361
+ def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
362
+ nonlocal output, output_length_ms, started_flag, delta_time_ms, cfg, utils, token_data_buffer
363
+ is_token_valid = utils.prog_data_to_token_data(prog, chan, note, vel) is not None
364
+ if not is_token_valid:
365
+ return
366
+
367
+ if delta_time_ms > filter_cfg.piece_split_delay * 1000.0:
368
+ # check if any notes are still held
369
+ silent = True
370
+ for channel in channel_notes.keys():
371
+ if len(channel_notes[channel]) > 0:
372
+ silent = False
373
+ break
374
+ if silent:
375
+ flush_token_data_buffer()
376
+ output.append("<end>")
377
+ if output_length_ms > filter_cfg.min_piece_length * 1000.0:
378
+ output_list.append(" ".join(output))
379
+ output = ["<start>"]
380
+ output_length_ms = 0.0
381
+ started_flag = False
382
+ if started_flag:
383
+ wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
384
+ if len(wait_tokens) > 0:
385
+ flush_token_data_buffer()
386
+ output_length_ms += delta_time_ms
387
+ output += wait_tokens
388
+ delta_time_ms = 0.0
389
+ token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
390
+ started_flag = True
391
+
392
+ for msg in mid.tracks[0]:
393
+ time_ms = mido.tick2second(msg.time, mid.ticks_per_beat, tempo) * 1000.0
394
+ delta_time_ms += time_ms
395
+ t = msg.type
396
+
397
+ if msg.is_meta:
398
+ if t == "set_tempo":
399
+ tempo = msg.tempo * augment.time_stretch_factor
400
+ continue
401
+
402
+ def handle_note_off(ch, prog, n):
403
+ if channel_pedal_on[ch]:
404
+ channel_pedal_events[ch][(n, prog)] = True
405
+ else:
406
+ consume_note_program_data(prog, ch, n, 0)
407
+ if n in channel_notes[ch]:
408
+ del channel_notes[ch][n]
409
+
410
+ if t == "program_change":
411
+ channel_program[msg.channel] = msg.program
412
+ elif t == "note_on":
413
+ if msg.velocity == 0:
414
+ handle_note_off(msg.channel, channel_program[msg.channel], msg.note)
415
+ else:
416
+ if (msg.note, channel_program[msg.channel]) in channel_pedal_events[msg.channel]:
417
+ del channel_pedal_events[msg.channel][(msg.note, channel_program[msg.channel])]
418
+ consume_note_program_data(
419
+ channel_program[msg.channel],
420
+ msg.channel,
421
+ msg.note,
422
+ mix_volume(msg.velocity, channel_volume[msg.channel], channel_expression[msg.channel]),
423
+ )
424
+ channel_notes[msg.channel][msg.note] = True
425
+ elif t == "note_off":
426
+ handle_note_off(msg.channel, channel_program[msg.channel], msg.note)
427
+ elif t == "control_change":
428
+ if msg.control == 7 or msg.control == 39: # volume
429
+ channel_volume[msg.channel] = msg.value
430
+ elif msg.control == 11: # expression
431
+ channel_expression[msg.channel] = msg.value
432
+ elif msg.control == 64: # sustain pedal
433
+ channel_pedal_on[msg.channel] = msg.value >= 64
434
+ if not channel_pedal_on[msg.channel]:
435
+ for (note, program) in channel_pedal_events[msg.channel]:
436
+ handle_note_off(msg.channel, program, note)
437
+ channel_pedal_events[msg.channel] = {}
438
+ elif msg.control == 123: # all notes off
439
+ for channel in channel_notes.keys():
440
+ for note in list(channel_notes[channel]).copy():
441
+ handle_note_off(channel, channel_program[channel], note)
442
+ else:
443
+ pass
444
+
445
+ flush_token_data_buffer()
446
+ output.append("<end>")
447
+ if output_length_ms > filter_cfg.min_piece_length * 1000.0:
448
+ output_list.append(" ".join(output))
449
+ return output_list
450
+
451
+
452
+ def generate_program_change_messages(cfg: VocabConfig):
453
+ for bin_name, channel in cfg.bin_channel_map.items():
454
+ if channel == 9:
455
+ continue
456
+ program = cfg._instrument_names_str_to_int[cfg.bin_name_to_program_name[bin_name]]
457
+ yield mido.Message("program_change", program=program, time=0, channel=channel)
458
+ yield mido.Message("program_change", program=0, time=0, channel=9)
459
+
460
+
461
+ @dataclass
462
+ class DecodeState:
463
+ total_time: float # milliseconds
464
+ delta_accum: float # milliseconds
465
+ current_bin: int
466
+ current_note: int
467
+ active_notes: Dict[Tuple[int, int], float] # { (channel, note): time started, ... }
468
+
469
+
470
+ def token_to_midi_message(utils: VocabUtils, token: str, state: DecodeState, end_token_pause: float = 3.0) -> Iterator[Tuple[Optional[mido.Message], DecodeState]]:
471
+ if state is None:
472
+ state = DecodeState(total_time=0.0, delta_accum=0.0, current_bin=utils.cfg._short_instrument_names_str_to_int[utils.cfg.short_instr_bin_names[0]], current_note=0, active_notes={})
473
+ token = token.strip()
474
+ if not token:
475
+ yield None, state
476
+ return
477
+ if token == "<end>":
478
+ d = end_token_pause * 1000.0
479
+ state.delta_accum += d
480
+ state.total_time += d
481
+ if utils.cfg.decode_end_held_note_delay != 0.0:
482
+ # end held notes
483
+ for (channel, note), start_time in list(state.active_notes.items()).copy():
484
+ ticks = int(mido.second2tick(state.delta_accum / 1000.0, 480, 500000))
485
+ state.delta_accum = 0.0
486
+ del state.active_notes[(channel, note)]
487
+ yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
488
+ yield None, state
489
+ return
490
+ if token.startswith("<"):
491
+ yield None, state
492
+ return
493
+
494
+ if utils.cfg.unrolled_tokens:
495
+ if token[0] == "t":
496
+ d = utils.wait_token_to_delta(token)
497
+ state.delta_accum += d
498
+ state.total_time += d
499
+ elif token[0] == "n":
500
+ state.current_note = int(token[1:], base=16)
501
+ elif token[0] == "i":
502
+ state.current_bin = utils.cfg._short_instrument_names_str_to_int[token[1:]]
503
+ elif token[0] == "v":
504
+ current_velocity = utils.bin_to_velocity(int(token[1:], base=16))
505
+ channel = utils.cfg.bin_channel_map[utils.cfg.bin_instrument_names[state.current_bin]]
506
+ ticks = int(mido.second2tick(state.delta_accum / 1000.0, 480, 500000))
507
+ state.delta_accum = 0.0
508
+ if current_velocity > 0:
509
+ yield mido.Message("note_on", note=state.current_note, velocity=current_velocity, time=ticks, channel=channel), state
510
+ else:
511
+ yield mido.Message("note_off", note=state.current_note, velocity=0, time=ticks, channel=channel), state
512
+ else:
513
+ if token[0] == "t" and token[1].isdigit(): # wait token
514
+ d = utils.wait_token_to_delta(token)
515
+ state.delta_accum += d
516
+ state.total_time += d
517
+ if utils.cfg.decode_end_held_note_delay != 0.0:
518
+ # remove notes that have been held for too long
519
+ for (channel, note), start_time in list(state.active_notes.items()).copy():
520
+ if state.total_time - start_time > utils.cfg.decode_end_held_note_delay * 1000.0:
521
+ ticks = int(mido.second2tick(state.delta_accum / 1000.0, 480, 500000))
522
+ state.delta_accum = 0.0
523
+ del state.active_notes[(channel, note)]
524
+ yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
525
+ return
526
+ else: # note token
527
+ bin, note, velocity = utils.note_token_to_data(token)
528
+ channel = utils.cfg.bin_channel_map[utils.cfg.bin_instrument_names[bin]]
529
+ ticks = int(mido.second2tick(state.delta_accum / 1000.0, 480, 500000))
530
+ state.delta_accum = 0.0
531
+ if velocity > 0:
532
+ if utils.cfg.decode_fix_repeated_notes:
533
+ if (channel, note) in state.active_notes:
534
+ del state.active_notes[(channel, note)]
535
+ yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
536
+ ticks = 0
537
+ state.active_notes[(channel, note)] = state.total_time
538
+ yield mido.Message("note_on", note=note, velocity=velocity, time=ticks, channel=channel), state
539
+ return
540
+ else:
541
+ if (channel, note) in state.active_notes:
542
+ del state.active_notes[(channel, note)]
543
+ yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
544
+ return
545
+ yield None, state
546
+
547
+
548
+ def str_to_midi_messages(utils: VocabUtils, data: str) -> Iterator[mido.Message]:
549
+ state = None
550
+ for token in data.split(" "):
551
+ for msg, new_state in token_to_midi_message(utils, token, state):
552
+ state = new_state
553
+ if msg is not None:
554
+ yield msg
555
+
556
+
557
+ def convert_str_to_midi(cfg: VocabConfig, data: str, meta_text: str = "Generated by MIDI-LLM-tokenizer") -> mido.MidiFile:
558
+ utils = VocabUtils(cfg)
559
+ mid = mido.MidiFile()
560
+ track = mido.MidiTrack()
561
+ mid.tracks.append(track)
562
+
563
+ tempo = 500000
564
+ if meta_text:
565
+ track.append(mido.MetaMessage("text", text=meta_text, time=0))
566
+ track.append(mido.MetaMessage("set_tempo", tempo=tempo, time=0))
567
+ for msg in generate_program_change_messages(cfg):
568
+ track.append(msg)
569
+
570
+ #data = data.replace("<start>", "").replace("<end>", "").replace("<pad>", "").strip()
571
+ for msg in str_to_midi_messages(utils, data):
572
+ track.append(msg)
573
+
574
+ track.append(mido.MetaMessage("end_of_track", time=0))
575
+
576
+ return mid