Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -257,26 +257,22 @@ def save_midi(tokens, batch_number=None):
|
|
257 |
|
258 |
@spaces.GPU
|
259 |
def generate_music(prime,
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
|
|
266 |
):
|
267 |
|
268 |
-
model.cuda()
|
269 |
-
model.eval()
|
270 |
-
|
271 |
-
print('Generating...')
|
272 |
-
|
273 |
if not prime:
|
274 |
inputs = [19461]
|
275 |
|
276 |
else:
|
277 |
-
inputs = prime
|
278 |
|
279 |
-
if gen_outro:
|
280 |
inputs.extend([18945])
|
281 |
|
282 |
if gen_drums:
|
@@ -301,11 +297,17 @@ def generate_music(prime,
|
|
301 |
verbose=False)
|
302 |
|
303 |
output = out.tolist()
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
#==================================================================================
|
311 |
|
@@ -316,12 +318,13 @@ block_lines = []
|
|
316 |
#==================================================================================
|
317 |
|
318 |
def generate_callback(input_midi,
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
|
|
325 |
):
|
326 |
|
327 |
global generated_batches
|
@@ -333,7 +336,8 @@ def generate_callback(input_midi,
|
|
333 |
block_lines.append(midi_score[-1][1] / 1000)
|
334 |
|
335 |
batched_gen_tokens = generate_music(final_composition,
|
336 |
-
num_gen_tokens,
|
|
|
337 |
NUM_OUT_BATCHES,
|
338 |
gen_outro,
|
339 |
gen_drums,
|
@@ -385,18 +389,15 @@ def generate_callback(input_midi,
|
|
385 |
#==================================================================================
|
386 |
|
387 |
def generate_callback_wrapper(input_midi,
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
|
|
394 |
):
|
395 |
|
396 |
-
print('=' * 70)
|
397 |
-
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
398 |
-
start_time = reqtime.time()
|
399 |
-
|
400 |
print('=' * 70)
|
401 |
if input_midi is not None:
|
402 |
fn = os.path.basename(input_midi.name)
|
@@ -413,6 +414,7 @@ def generate_callback_wrapper(input_midi,
|
|
413 |
result = generate_callback(input_midi,
|
414 |
num_prime_tokens,
|
415 |
num_gen_tokens,
|
|
|
416 |
gen_outro,
|
417 |
gen_drums,
|
418 |
model_temperature,
|
@@ -420,12 +422,6 @@ def generate_callback_wrapper(input_midi,
|
|
420 |
)
|
421 |
|
422 |
generated_batches.extend([sublist[2] for sublist in result])
|
423 |
-
|
424 |
-
print('=' * 70)
|
425 |
-
print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
426 |
-
print('=' * 70)
|
427 |
-
print('Req execution time:', (reqtime.time() - start_time), 'sec')
|
428 |
-
print('*' * 70)
|
429 |
|
430 |
return tuple(item for sublist in result for item in sublist[:2])
|
431 |
|
@@ -499,7 +495,7 @@ def reset():
|
|
499 |
final_composition = []
|
500 |
generated_batches = []
|
501 |
block_lines = []
|
502 |
-
|
503 |
#==================================================================================
|
504 |
|
505 |
PDT = timezone('US/Pacific')
|
@@ -529,17 +525,18 @@ with gr.Blocks() as demo:
|
|
529 |
for faster execution and endless generation!
|
530 |
""")
|
531 |
|
532 |
-
gr.Markdown("## Upload
|
533 |
|
534 |
input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
|
535 |
input_midi.upload(reset)
|
536 |
|
537 |
gr.Markdown("## Generate")
|
538 |
|
539 |
-
num_prime_tokens = gr.Slider(15,
|
540 |
num_gen_tokens = gr.Slider(15, 1200, value=600, step=3, label="Number of tokens to generate")
|
541 |
-
|
542 |
-
gen_drums = gr.Checkbox(value=False, label="
|
|
|
543 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
544 |
model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
|
545 |
|
@@ -561,6 +558,7 @@ with gr.Blocks() as demo:
|
|
561 |
[input_midi,
|
562 |
num_prime_tokens,
|
563 |
num_gen_tokens,
|
|
|
564 |
gen_outro,
|
565 |
gen_drums,
|
566 |
model_temperature,
|
|
|
257 |
|
258 |
@spaces.GPU
|
259 |
def generate_music(prime,
|
260 |
+
num_gen_tokens,
|
261 |
+
num_mem_tokens,
|
262 |
+
num_gen_batches,
|
263 |
+
gen_outro,
|
264 |
+
gen_drums,
|
265 |
+
model_temperature,
|
266 |
+
model_sampling_top_p
|
267 |
):
|
268 |
|
|
|
|
|
|
|
|
|
|
|
269 |
if not prime:
|
270 |
inputs = [19461]
|
271 |
|
272 |
else:
|
273 |
+
inputs = prime[-num_mem_tokens:]
|
274 |
|
275 |
+
if gen_outro == 'Force':
|
276 |
inputs.extend([18945])
|
277 |
|
278 |
if gen_drums:
|
|
|
297 |
verbose=False)
|
298 |
|
299 |
output = out.tolist()
|
300 |
+
|
301 |
+
output_batches = []
|
302 |
+
|
303 |
+
if gen_outro == 'Disable':
|
304 |
+
for o in output:
|
305 |
+
output_batches.append([t for t in o if not 18944 < t < 19330])
|
306 |
+
|
307 |
+
else:
|
308 |
+
output_batches = output
|
309 |
+
|
310 |
+
return output_batches
|
311 |
|
312 |
#==================================================================================
|
313 |
|
|
|
318 |
#==================================================================================
|
319 |
|
320 |
def generate_callback(input_midi,
|
321 |
+
num_prime_tokens,
|
322 |
+
num_gen_tokens,
|
323 |
+
num_mem_tokens,
|
324 |
+
gen_outro,
|
325 |
+
gen_drums,
|
326 |
+
model_temperature,
|
327 |
+
model_sampling_top_p
|
328 |
):
|
329 |
|
330 |
global generated_batches
|
|
|
336 |
block_lines.append(midi_score[-1][1] / 1000)
|
337 |
|
338 |
batched_gen_tokens = generate_music(final_composition,
|
339 |
+
num_gen_tokens,
|
340 |
+
num_mem_tokens,
|
341 |
NUM_OUT_BATCHES,
|
342 |
gen_outro,
|
343 |
gen_drums,
|
|
|
389 |
#==================================================================================
|
390 |
|
391 |
def generate_callback_wrapper(input_midi,
|
392 |
+
num_prime_tokens,
|
393 |
+
num_gen_tokens,
|
394 |
+
num_mem_tokens,
|
395 |
+
gen_outro,
|
396 |
+
gen_drums,
|
397 |
+
model_temperature,
|
398 |
+
model_sampling_top_p
|
399 |
):
|
400 |
|
|
|
|
|
|
|
|
|
401 |
print('=' * 70)
|
402 |
if input_midi is not None:
|
403 |
fn = os.path.basename(input_midi.name)
|
|
|
414 |
result = generate_callback(input_midi,
|
415 |
num_prime_tokens,
|
416 |
num_gen_tokens,
|
417 |
+
num_mem_tokens,
|
418 |
gen_outro,
|
419 |
gen_drums,
|
420 |
model_temperature,
|
|
|
422 |
)
|
423 |
|
424 |
generated_batches.extend([sublist[2] for sublist in result])
|
|
|
|
|
|
|
|
|
|
|
|
|
425 |
|
426 |
return tuple(item for sublist in result for item in sublist[:2])
|
427 |
|
|
|
495 |
final_composition = []
|
496 |
generated_batches = []
|
497 |
block_lines = []
|
498 |
+
|
499 |
#==================================================================================
|
500 |
|
501 |
PDT = timezone('US/Pacific')
|
|
|
525 |
for faster execution and endless generation!
|
526 |
""")
|
527 |
|
528 |
+
gr.Markdown("## Upload your MIDI or select a sample example MIDI")
|
529 |
|
530 |
input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
|
531 |
input_midi.upload(reset)
|
532 |
|
533 |
gr.Markdown("## Generate")
|
534 |
|
535 |
+
num_prime_tokens = gr.Slider(15, 6990, value=600, step=3, label="Number of prime tokens")
|
536 |
num_gen_tokens = gr.Slider(15, 1200, value=600, step=3, label="Number of tokens to generate")
|
537 |
+
num_mem_tokens = gr.Slider(15, 6990, value=6990, step=3, label="Number of memory tokens")
|
538 |
+
gen_drums = gr.Checkbox(value=False, label="Introduce drums")
|
539 |
+
gen_outro = gr.Radio(["Auto", "Disable", "Force"], value="Auto", label="Outro options")
|
540 |
model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
|
541 |
model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
|
542 |
|
|
|
558 |
[input_midi,
|
559 |
num_prime_tokens,
|
560 |
num_gen_tokens,
|
561 |
+
num_mem_tokens,
|
562 |
gen_outro,
|
563 |
gen_drums,
|
564 |
model_temperature,
|