juancopi81 commited on
Commit
d8fad2b
·
1 Parent(s): 8f37e44

Initial commit with files

Browse files
Files changed (9) hide show
  1. .gitignore +1 -0
  2. README.md +1 -1
  3. constants.py +133 -0
  4. main.py +133 -5
  5. model.py +31 -0
  6. pyproject.toml +6 -0
  7. requirements.txt +1 -0
  8. string_to_notes.py +137 -0
  9. utils.py +245 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ env/
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Multitrack Midi Music Generator
3
- emoji: 📚
4
  colorFrom: indigo
5
  colorTo: gray
6
  sdk: docker
 
1
  ---
2
  title: Multitrack Midi Music Generator
3
+ emoji: 🎵
4
  colorFrom: indigo
5
  colorTo: gray
6
  sdk: docker
constants.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SAMPLE_RATE = 44100
2
+
3
+
4
+ GM_INSTRUMENTS = [
5
+ "Acoustic Grand Piano",
6
+ "Bright Acoustic Piano",
7
+ "Electric Grand Piano",
8
+ "Honky-tonk Piano",
9
+ "Electric Piano 1",
10
+ "Electric Piano 2",
11
+ "Harpsichord",
12
+ "Clavi",
13
+ "Celesta",
14
+ "Glockenspiel",
15
+ "Music Box",
16
+ "Vibraphone",
17
+ "Marimba",
18
+ "Xylophone",
19
+ "Tubular Bells",
20
+ "Dulcimer",
21
+ "Drawbar Organ",
22
+ "Percussive Organ",
23
+ "Rock Organ",
24
+ "Church Organ",
25
+ "Reed Organ",
26
+ "Accordion",
27
+ "Harmonica",
28
+ "Tango Accordion",
29
+ "Acoustic Guitar (nylon)",
30
+ "Acoustic Guitar (steel)",
31
+ "Electric Guitar (jazz)",
32
+ "Electric Guitar (clean)",
33
+ "Electric Guitar (muted)",
34
+ "Overdriven Guitar",
35
+ "Distortion Guitar",
36
+ "Guitar Harmonics",
37
+ "Acoustic Bass",
38
+ "Electric Bass (finger)",
39
+ "Electric Bass (pick)",
40
+ "Fretless Bass",
41
+ "Slap Bass 1",
42
+ "Slap Bass 2",
43
+ "Synth Bass 1",
44
+ "Synth Bass 2",
45
+ "Violin",
46
+ "Viola",
47
+ "Cello",
48
+ "Contrabass",
49
+ "Tremolo Strings",
50
+ "Pizzicato Strings",
51
+ "Orchestral Harp",
52
+ "Timpani",
53
+ "String Ensemble 1",
54
+ "String Ensemble 2",
55
+ "Synth Strings 1",
56
+ "Synth Strings 2",
57
+ "Choir Aahs",
58
+ "Voice Oohs",
59
+ "Synth Choir",
60
+ "Orchestra Hit",
61
+ "Trumpet",
62
+ "Trombone",
63
+ "Tuba",
64
+ "Muted Trumpet",
65
+ "French Horn",
66
+ "Brass Section",
67
+ "Synth Brass 1",
68
+ "Synth Brass 2",
69
+ "Soprano Sax",
70
+ "Alto Sax",
71
+ "Tenor Sax",
72
+ "Baritone Sax",
73
+ "Oboe",
74
+ "English Horn",
75
+ "Bassoon",
76
+ "Clarinet",
77
+ "Piccolo",
78
+ "Flute",
79
+ "Recorder",
80
+ "Pan Flute",
81
+ "Blown Bottle",
82
+ "Shakuhachi",
83
+ "Whistle",
84
+ "Ocarina",
85
+ "Lead 1 (square)",
86
+ "Lead 2 (sawtooth)",
87
+ "Lead 3 (calliope)",
88
+ "Lead 4 (chiff)",
89
+ "Lead 5 (charang)",
90
+ "Lead 6 (voice)",
91
+ "Lead 7 (fifths)",
92
+ "Lead 8 (bass + lead)",
93
+ "Pad 1 (new age)",
94
+ "Pad 2 (warm)",
95
+ "Pad 3 (polysynth)",
96
+ "Pad 4 (choir)",
97
+ "Pad 5 (bowed)",
98
+ "Pad 6 (metallic)",
99
+ "Pad 7 (halo)",
100
+ "Pad 8 (sweep)",
101
+ "FX 1 (rain)",
102
+ "FX 2 (soundtrack)",
103
+ "FX 3 (crystal)",
104
+ "FX 4 (atmosphere)",
105
+ "FX 5 (brightness)",
106
+ "FX 6 (goblins)",
107
+ "FX 7 (echoes)",
108
+ "FX 8 (sci-fi)",
109
+ "Sitar",
110
+ "Banjo",
111
+ "Shamisen",
112
+ "Koto",
113
+ "Kalimba",
114
+ "Bagpipe",
115
+ "Fiddle",
116
+ "Shanai",
117
+ "Tinkle Bell",
118
+ "Agogo",
119
+ "Steel Drums",
120
+ "Woodblock",
121
+ "Taiko Drum",
122
+ "Melodic Tom",
123
+ "Synth Drum",
124
+ "Reverse Cymbal",
125
+ "Guitar Fret Noise",
126
+ "Breath Noise",
127
+ "Seashore",
128
+ "Bird Tweet",
129
+ "Telephone Ring",
130
+ "Helicopter",
131
+ "Applause",
132
+ "Gunshot",
133
+ ]
main.py CHANGED
@@ -1,13 +1,141 @@
 
 
1
  import gradio as gr
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  def run():
4
- demo = gr.Interface(
5
- inputs=gr.inputs.Image(type="pil"),
6
- outputs=gr.outputs.Label(num_top_classes=3),
7
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  demo.launch(server_name="0.0.0.0", server_port=7860)
10
 
11
 
12
  if __name__ == "__main__":
13
- run()
 
1
+ import os
2
+
3
  import gradio as gr
4
 
5
+ from utils import (
6
+ generate_song,
7
+ remove_last_instrument,
8
+ regenerate_last_instrument,
9
+ change_tempo,
10
+ )
11
+
12
+
13
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
14
+
15
+ DESCRIPTION = """
16
+
17
+ # 🎵 Multitrack Midi Generator 🎶
18
+ This interactive application uses an AI model to generate music sequences based on a chosen genre and various user inputs.
19
+
20
+ Features:
21
+ 🎼 Select the genre for the music.
22
+ 🌡️ Use the "Temperature" slider to adjust the randomness of the music generated (higher values will produce more random outputs).
23
+ ⏱️ Adjust the "Tempo" slider to change the speed of the music.
24
+ 🎹 Use the buttons to generate a new song from scratch, continue generation with the current settings, remove the last added instrument, regenerate the last added instrument with a new one, or change the tempo of the current song.
25
+ Outputs:
26
+ The app outputs the following:
27
+
28
+ 🎧 The audio of the generated song.
29
+ 📁 A MIDI file of the song.
30
+ 📊 A plot of the song's sequence.
31
+ 🎸 A list of the generated instruments.
32
+ 📝 The text sequence of the song.
33
+ Enjoy creating your own AI-generated music! 🎵
34
+ """
35
+
36
+ genres = ["ROCK", "POP", "OTHER", "R&B/SOUL", "JAZZ", "ELECTRONIC", "RANDOM"]
37
+
38
+ demo = gr.Blocks()
39
+
40
+
41
def run():
    """
    Builds the Gradio Blocks UI, wires every button to its handler and
    launches the web server on 0.0.0.0:7860.
    """

    def _generate_song_from_ui(genre_value, temp_value, sequence_value, qpm_value):
        # Gradio hands the values of `inputs` to the callback positionally.
        # generate_song's first two parameters are `model` and `tokenizer`,
        # so passing it directly would bind the genre to `model` and the
        # temperature to `tokenizer`. Forward the UI values by keyword.
        return generate_song(
            genre=genre_value,
            temp=temp_value,
            text_sequence=sequence_value,
            qpm=qpm_value,
        )

    with demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column():
                temp = gr.Slider(
                    minimum=0, maximum=1, step=0.05, value=0.75, label="Temperature"
                )
                genre = gr.Dropdown(
                    choices=genres, value="POP", label="Select the genre"
                )
                with gr.Row():
                    btn_from_scratch = gr.Button("Start from scratch")
                    btn_continue = gr.Button("Continue Generation")
                    btn_remove_last = gr.Button("Remove last instrument")
                    btn_regenerate_last = gr.Button("Regenerate last instrument")
            with gr.Column():
                with gr.Box():
                    audio_output = gr.Video()
                    midi_file = gr.File()
                    with gr.Row():
                        qpm = gr.Slider(
                            minimum=60, maximum=140, step=10, value=120, label="Tempo"
                        )
                        btn_qpm = gr.Button("Change Tempo")
        with gr.Row():
            with gr.Column():
                plot_output = gr.Plot()
            with gr.Column():
                instruments_output = gr.Markdown("# List of generated instruments")
        with gr.Row():
            text_sequence = gr.Text()
            # Hidden, always-empty textbox: used as the "text sequence" input
            # of the "Start from scratch" button.
            empty_sequence = gr.Text(visible=False)
        with gr.Row():
            num_tokens = gr.Text()

        # Every handler feeds the same six components, in this order.
        song_outputs = [
            audio_output,
            midi_file,
            plot_output,
            instruments_output,
            text_sequence,
            num_tokens,
        ]

        btn_from_scratch.click(
            fn=_generate_song_from_ui,
            inputs=[genre, temp, empty_sequence, qpm],
            outputs=song_outputs,
        )
        btn_continue.click(
            fn=_generate_song_from_ui,
            inputs=[genre, temp, text_sequence, qpm],
            outputs=song_outputs,
        )
        btn_remove_last.click(
            fn=remove_last_instrument,
            inputs=[text_sequence, qpm],
            outputs=song_outputs,
        )
        btn_regenerate_last.click(
            fn=regenerate_last_instrument,
            inputs=[text_sequence, qpm],
            outputs=song_outputs,
        )
        btn_qpm.click(
            fn=change_tempo,
            inputs=[text_sequence, qpm],
            outputs=song_outputs,
        )

    demo.launch(server_name="0.0.0.0", server_port=7860)
138
 
139
 
140
  if __name__ == "__main__":
141
+ run()
model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ # Initialize the model and tokenizer variables as None
6
+ tokenizer = None
7
+ model = None
8
+
9
+
10
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Lazily loads the model/tokenizer pair and caches it in module globals.

    The first call loads both from their pretrained checkpoints and moves
    the model to CUDA when available (CPU otherwise). Later calls return the
    cached objects.

    Returns:
        tuple: The model and the tokenizer, in that order.
    """
    global model, tokenizer

    # Fast path: both objects were already loaded by a previous call.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")
    model = AutoModelForCausalLM.from_pretrained(
        "juancopi81/lmd-8bars-2048-epochs20_v3"
    )
    model = model.to(device)

    return model, tokenizer
pyproject.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ exclude = '''
3
+ (
4
+ /env
5
+ )
6
+ '''
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  note-seq
2
  matplotlib
3
  transformers
 
1
+ gradio
2
  note-seq
3
  matplotlib
4
  transformers
string_to_notes.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from note_seq.protobuf.music_pb2 import NoteSequence
4
+ from note_seq.constants import STANDARD_PPQ
5
+
6
+
7
def token_sequence_to_note_sequence(
    token_sequence: str,
    qpm: float = 120.0,
    use_program: bool = True,
    use_drums: bool = True,
    instrument_mapper: Optional[dict] = None,
    only_piano: bool = False,
) -> NoteSequence:
    """
    Converts a sequence of tokens into a sequence of notes.

    Args:
        token_sequence (str): The tokens to convert. A string is split on
            whitespace; any other iterable of tokens is used as is.
        qpm (float, optional): The quarter notes per minute. Defaults to 120.0.
        use_program (bool, optional): Whether non-drum INST tokens set the
            program. Defaults to True.
        use_drums (bool, optional): Whether INST=DRUMS switches to drums.
            Defaults to True.
        instrument_mapper (Optional[dict], optional): Optional remapping of
            instrument token values before they are parsed. Defaults to None.
        only_piano (bool, optional): Whether to force every non-drum note to
            program 0 / instrument 0. Defaults to False.

    Returns:
        NoteSequence: The resulting sequence of notes.
    """
    if isinstance(token_sequence, str):
        token_sequence = token_sequence.split()

    note_sequence = empty_note_sequence(qpm)

    # Compute note and bar lengths (in seconds) based on the provided QPM.
    note_length_16th = 0.25 * 60 / qpm
    bar_length = 4.0 * 60 / qpm

    current_program = 1
    current_is_drum = False
    current_instrument = 0
    track_count = 0
    # Timing state is initialised up front so that a sequence whose first
    # NOTE_ON / TIME_DELTA / BAR_START precedes TRACK_START or BAR_START does
    # not fail with UnboundLocalError.
    current_bar_index = 0
    current_time = 0.0
    current_notes = {}

    # Render all notes. Tokens that carry no note information (PIECE_START,
    # TRACK_END, KEYS_*, KEY=, DENSITY=, [PAD], anything unknown) are skipped.
    for token in token_sequence:
        if token == "PIECE_END":
            break
        elif token == "TRACK_START":
            current_bar_index = 0
            track_count += 1
        elif token.startswith("INST"):
            instrument = token.split("=")[-1]
            if instrument != "DRUMS" and use_program:
                if instrument_mapper is not None and instrument in instrument_mapper:
                    instrument = instrument_mapper[instrument]
                current_program = int(instrument)
                current_instrument = track_count
                current_is_drum = False
            if instrument == "DRUMS" and use_drums:
                current_instrument = 0
                current_program = 0
                current_is_drum = True
        elif token == "BAR_START":
            current_time = current_bar_index * bar_length
            current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
        elif token.startswith("NOTE_ON"):
            pitch = int(token.split("=")[-1])
            note = note_sequence.notes.add()
            note.start_time = current_time
            # Provisional end (a quarter note); overwritten by a later NOTE_OFF.
            note.end_time = current_time + 4 * note_length_16th
            note.pitch = pitch
            note.instrument = current_instrument
            note.program = current_program
            note.velocity = 80
            note.is_drum = current_is_drum
            current_notes[pitch] = note
        elif token.startswith("NOTE_OFF"):
            pitch = int(token.split("=")[-1])
            if pitch in current_notes:
                current_notes[pitch].end_time = current_time
        elif token.startswith("TIME_DELTA"):
            current_time += float(token.split("=")[-1]) * note_length_16th

    # Renumber instruments: one index per distinct (program, is_drum) pair,
    # assigned in order of first appearance.
    instrument_ids = {}
    for note in note_sequence.notes:
        key = (note.program, note.is_drum)
        note.instrument = instrument_ids.setdefault(key, len(instrument_ids))

    if only_piano:
        for note in note_sequence.notes:
            if not note.is_drum:
                note.instrument = 0
                note.program = 0

    return note_sequence
120
+
121
+
122
def empty_note_sequence(qpm: float = 120.0, total_time: float = 0.0) -> NoteSequence:
    """
    Builds a NoteSequence that has a tempo but no notes.

    Args:
        qpm (float, optional): Tempo, in quarter notes per minute, stored in
            the sequence's single tempo entry. Defaults to 120.0.
        total_time (float, optional): Value stored in `total_time`.
            Defaults to 0.0.

    Returns:
        NoteSequence: The note-less sequence, with `ticks_per_quarter` set to
            STANDARD_PPQ.
    """
    sequence = NoteSequence()
    tempo = sequence.tempos.add()
    tempo.qpm = qpm
    sequence.ticks_per_quarter = STANDARD_PPQ
    sequence.total_time = total_time
    return sequence
utils.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import note_seq
6
+ from matplotlib.figure import Figure
7
+ from numpy import ndarray
8
+
9
+ from constants import GM_INSTRUMENTS, SAMPLE_RATE
10
+ from string_to_notes import token_sequence_to_note_sequence
11
+ from model import get_model_and_tokenizer
12
+
13
+
14
+ model, tokenizer = get_model_and_tokenizer()
15
+
16
+
17
def create_seed_string(genre: str = "OTHER") -> str:
    """
    Builds the prompt that starts a brand-new piece.

    Args:
        genre (str, optional): Genre inserted into the GENRE= token.
            Defaults to "OTHER".

    Returns:
        str: "PIECE_START GENRE=<genre> TRACK_START".
    """
    genre_token = f"GENRE={genre}"
    return f"PIECE_START {genre_token} TRACK_START"
29
+
30
+
31
def get_instruments(text_sequence: str) -> List[str]:
    """
    Extracts the list of instruments from a text sequence.

    Every whitespace-separated token starting with "INST=" yields one entry:
    "Drums" for INST=DRUMS, the GM_INSTRUMENTS name for a valid program
    number, or "Unknown instrument (<value>)" when the value is not a number
    or falls outside GM_INSTRUMENTS. The sequence may be hand-edited or
    model-generated, so malformed tokens are labelled instead of raising.

    Args:
        text_sequence (str): The text sequence.

    Returns:
        List[str]: The instrument names, in order of appearance.
    """
    prefix = "INST="
    instruments = []
    for part in text_sequence.split():
        if not part.startswith(prefix):
            continue
        value = part[len(prefix):]
        if value == "DRUMS":
            instruments.append("Drums")
            continue
        try:
            index = int(value)
        except ValueError:
            instruments.append(f"Unknown instrument ({value})")
            continue
        # Explicit bounds check: a negative index would otherwise silently
        # wrap around to the end of the table.
        if 0 <= index < len(GM_INSTRUMENTS):
            instruments.append(GM_INSTRUMENTS[index])
        else:
            instruments.append(f"Unknown instrument ({value})")
    return instruments
51
+
52
+
53
def generate_new_instrument(
    seed: str, tokenizer: AutoTokenizer, model: AutoModelForCausalLM, temp: float = 0.75
) -> str:
    """
    Generates a new instrument sequence from a given seed and temperature.

    Sampling is repeated until the newly generated part (everything after
    the seed) contains at least one "NOTE_ON" token; there is no upper bound
    on the number of attempts.

    Args:
        seed (str): The seed string for the generation.
        tokenizer (PreTrainedTokenizer): The tokenizer used to encode and decode the sequences.
        model (PreTrainedModel): The pretrained model used for generating the sequences.
        temp (float, optional): The sampling temperature, which controls the
            randomness. Non-positive values are replaced by 0.05.
            Defaults to 0.75.

    Returns:
        str: The decoded full sequence (seed plus generated track).
    """
    # Sampling needs a strictly positive float temperature, but the UI slider
    # can deliver 0. Clamp only the invalid values; valid ones pass unchanged.
    min_temperature = 0.05
    temperature = float(temp) if temp > 0 else min_temperature

    # Loop-invariant work: encode the seed once, on the model's device, and
    # look up the token id that ends a track.
    input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
    seed_length = input_ids.shape[-1]
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    while True:
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temperature,
            eos_token_id=eos_token_id,
        )

        # Accept the sample only if the part beyond the seed contains a note.
        new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return tokenizer.decode(generated_ids[0])
92
+
93
+
94
def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Renders a token sequence into everything the UI displays.

    Args:
        generated_sequence (str): The generated sequence of tokens.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str]: The result of
            `gr.make_waveform` for the synthesized audio, the MIDI file name
            ("midi_ouput.mid"), the plot figure, the instruments as a
            bulleted string, and the number of tokens as a string.
    """
    # One "- name" bullet per instrument found in the sequence.
    bullet_lines = [f"- {name}" for name in get_instruments(generated_sequence)]
    instruments_str = "\n".join(bullet_lines)

    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    # Synthesize float samples, then convert to int16 for make_waveform.
    float_samples = note_seq.fluidsynth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_samples = note_seq.audio_io.float_samples_to_int16(float_samples)

    fig = note_seq.plot_sequence(note_sequence, show_figure=False)
    token_count = len(generated_sequence.split())

    audio = gr.make_waveform((SAMPLE_RATE, int16_samples))

    # NOTE(review): fixed file name in the working directory; each call
    # overwrites the previous file.
    midi_path = "midi_ouput.mid"
    note_seq.note_sequence_to_midi_file(note_sequence, midi_path)

    return audio, midi_path, fig, instruments_str, str(token_count)
120
+
121
+
122
def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Removes the last instrument from a song string and returns the various output formats.

    If the song has at most one track, nothing playable would remain, so a
    new track is generated instead, seeded with whatever precedes the first
    TRACK_START (an empty seed starts from scratch).

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
            instruments string, new song string, and number of tokens string.
    """
    # Split on 'TRACK_START', drop the last track, and join the rest back
    # together with the separator that split removed.
    tracks = text_sequence.split("TRACK_START")
    new_song = "TRACK_START".join(tracks[:-1])

    if len(tracks) <= 2:
        # Zero or one instrument: regenerate. With no TRACK_START at all,
        # `new_song` is already "". The requested tempo is forwarded so the
        # regenerated song is rendered at `qpm` rather than the default.
        return generate_song(text_sequence=new_song, qpm=qpm)

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        new_song, qpm
    )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens
159
+
160
+
161
def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last instrument in a song string and returns the various output formats.

    The song is cut right after its last "INST=" token and generation
    continues from there, so the instrument is kept but its notes are new.
    Without any "INST=" token, a song is generated from an empty sequence.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
            instruments string, new song string, and number of tokens string.
    """
    last_inst_index = text_sequence.rfind("INST=")
    if last_inst_index == -1:
        # No instrument so start from empty sequence
        new_seed = ""
    else:
        next_space_index = text_sequence.find(" ", last_inst_index)
        if next_space_index == -1:
            # The INST token is the last token: keep the whole string. Slicing
            # with -1 would cut off the token's final character.
            new_seed = text_sequence
        else:
            new_seed = text_sequence[:next_space_index]

    return generate_song(text_sequence=new_seed, qpm=qpm)
189
+
190
+
191
def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Re-renders an existing song string at a different tempo.

    The token sequence itself is not modified; only the outputs derived
    from it are recomputed with the new tempo.

    Args:
        text_sequence (str): The song string.
        qpm (int): The new quarter notes per minute.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
            instruments string, the unchanged text sequence, and number of tokens string.
    """
    rendered = get_outputs_from_string(text_sequence, qpm=qpm)
    audio, midi_file, fig, instruments_str, num_tokens = rendered
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
209
+
210
+
211
def generate_song(
    model: AutoModelForCausalLM = model,
    tokenizer: AutoTokenizer = tokenizer,
    genre: str = "OTHER",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates one more track and renders the resulting song.

    Note that `model` and `tokenizer` are the first two positional
    parameters; callers supplying only genre/temp/text_sequence/qpm should
    pass them by keyword.

    Args:
        model (AutoModelForCausalLM): The pretrained model used for generating the sequences.
            Defaults to the module-level model.
        tokenizer (AutoTokenizer): The tokenizer used to encode and decode the sequences.
            Defaults to the module-level tokenizer.
        genre (str, optional): Genre for a fresh seed; only used when
            `text_sequence` is empty. Defaults to "OTHER".
        temp (float, optional): The sampling temperature. Defaults to 0.75.
        text_sequence (str, optional): Existing song to continue; an empty
            string starts a new piece. Defaults to "".
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
            instruments string, generated song string, and number of tokens string.
    """
    # Continue the given song, or build a fresh seed when there is none.
    seed_string = text_sequence if text_sequence != "" else create_seed_string(genre)

    generated_sequence = generate_new_instrument(
        seed=seed_string, tokenizer=tokenizer, model=model, temp=temp
    )

    rendered = get_outputs_from_string(generated_sequence, qpm)
    audio, midi_file, fig, instruments_str, num_tokens = rendered
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens