asigalov61 commited on
Commit
890ef0c
·
verified ·
1 Parent(s): 8d97094

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -46
app.py CHANGED
@@ -257,26 +257,22 @@ def save_midi(tokens, batch_number=None):
257
 
258
  @spaces.GPU
259
  def generate_music(prime,
260
- num_gen_tokens,
261
- num_gen_batches,
262
- gen_outro,
263
- gen_drums,
264
- model_temperature,
265
- model_sampling_top_p
 
266
  ):
267
 
268
- model.cuda()
269
- model.eval()
270
-
271
- print('Generating...')
272
-
273
  if not prime:
274
  inputs = [19461]
275
 
276
  else:
277
- inputs = prime
278
 
279
- if gen_outro:
280
  inputs.extend([18945])
281
 
282
  if gen_drums:
@@ -301,11 +297,17 @@ def generate_music(prime,
301
  verbose=False)
302
 
303
  output = out.tolist()
304
-
305
- print('Done!')
306
- print('=' * 70)
307
-
308
- return output
 
 
 
 
 
 
309
 
310
  #==================================================================================
311
 
@@ -316,12 +318,13 @@ block_lines = []
316
  #==================================================================================
317
 
318
  def generate_callback(input_midi,
319
- num_prime_tokens,
320
- num_gen_tokens,
321
- gen_outro,
322
- gen_drums,
323
- model_temperature,
324
- model_sampling_top_p
 
325
  ):
326
 
327
  global generated_batches
@@ -333,7 +336,8 @@ def generate_callback(input_midi,
333
  block_lines.append(midi_score[-1][1] / 1000)
334
 
335
  batched_gen_tokens = generate_music(final_composition,
336
- num_gen_tokens,
 
337
  NUM_OUT_BATCHES,
338
  gen_outro,
339
  gen_drums,
@@ -385,18 +389,15 @@ def generate_callback(input_midi,
385
  #==================================================================================
386
 
387
  def generate_callback_wrapper(input_midi,
388
- num_prime_tokens,
389
- num_gen_tokens,
390
- gen_outro,
391
- gen_drums,
392
- model_temperature,
393
- model_sampling_top_p
 
394
  ):
395
 
396
- print('=' * 70)
397
- print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
398
- start_time = reqtime.time()
399
-
400
  print('=' * 70)
401
  if input_midi is not None:
402
  fn = os.path.basename(input_midi.name)
@@ -413,6 +414,7 @@ def generate_callback_wrapper(input_midi,
413
  result = generate_callback(input_midi,
414
  num_prime_tokens,
415
  num_gen_tokens,
 
416
  gen_outro,
417
  gen_drums,
418
  model_temperature,
@@ -420,12 +422,6 @@ def generate_callback_wrapper(input_midi,
420
  )
421
 
422
  generated_batches.extend([sublist[2] for sublist in result])
423
-
424
- print('=' * 70)
425
- print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
426
- print('=' * 70)
427
- print('Req execution time:', (reqtime.time() - start_time), 'sec')
428
- print('*' * 70)
429
 
430
  return tuple(item for sublist in result for item in sublist[:2])
431
 
@@ -499,7 +495,7 @@ def reset():
499
  final_composition = []
500
  generated_batches = []
501
  block_lines = []
502
-
503
  #==================================================================================
504
 
505
  PDT = timezone('US/Pacific')
@@ -529,17 +525,18 @@ with gr.Blocks() as demo:
529
  for faster execution and endless generation!
530
  """)
531
 
532
- gr.Markdown("## Upload seed MIDI or click 'Generate' button for random output")
533
 
534
  input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
535
  input_midi.upload(reset)
536
 
537
  gr.Markdown("## Generate")
538
 
539
- num_prime_tokens = gr.Slider(15, 6999, value=600, step=3, label="Number of prime tokens")
540
  num_gen_tokens = gr.Slider(15, 1200, value=600, step=3, label="Number of tokens to generate")
541
- gen_outro = gr.Checkbox(value=False, label="Try to generate an outro")
542
- gen_drums = gr.Checkbox(value=False, label="Try to introduce drums")
 
543
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
544
  model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
545
 
@@ -561,6 +558,7 @@ with gr.Blocks() as demo:
561
  [input_midi,
562
  num_prime_tokens,
563
  num_gen_tokens,
 
564
  gen_outro,
565
  gen_drums,
566
  model_temperature,
 
257
 
258
  @spaces.GPU
259
  def generate_music(prime,
260
+ num_gen_tokens,
261
+ num_mem_tokens,
262
+ num_gen_batches,
263
+ gen_outro,
264
+ gen_drums,
265
+ model_temperature,
266
+ model_sampling_top_p
267
  ):
268
 
 
 
 
 
 
269
  if not prime:
270
  inputs = [19461]
271
 
272
  else:
273
+ inputs = prime[-num_mem_tokens:]
274
 
275
+ if gen_outro == 'Force':
276
  inputs.extend([18945])
277
 
278
  if gen_drums:
 
297
  verbose=False)
298
 
299
  output = out.tolist()
300
+
301
+ output_batches = []
302
+
303
+ if gen_outro == 'Disable':
304
+ for o in output:
305
+ output_batches.append([t for t in o if not 18944 < t < 19330])
306
+
307
+ else:
308
+ output_batches = output
309
+
310
+ return output_batches
311
 
312
  #==================================================================================
313
 
 
318
  #==================================================================================
319
 
320
  def generate_callback(input_midi,
321
+ num_prime_tokens,
322
+ num_gen_tokens,
323
+ num_mem_tokens,
324
+ gen_outro,
325
+ gen_drums,
326
+ model_temperature,
327
+ model_sampling_top_p
328
  ):
329
 
330
  global generated_batches
 
336
  block_lines.append(midi_score[-1][1] / 1000)
337
 
338
  batched_gen_tokens = generate_music(final_composition,
339
+ num_gen_tokens,
340
+ num_mem_tokens,
341
  NUM_OUT_BATCHES,
342
  gen_outro,
343
  gen_drums,
 
389
  #==================================================================================
390
 
391
  def generate_callback_wrapper(input_midi,
392
+ num_prime_tokens,
393
+ num_gen_tokens,
394
+ num_mem_tokens,
395
+ gen_outro,
396
+ gen_drums,
397
+ model_temperature,
398
+ model_sampling_top_p
399
  ):
400
 
 
 
 
 
401
  print('=' * 70)
402
  if input_midi is not None:
403
  fn = os.path.basename(input_midi.name)
 
414
  result = generate_callback(input_midi,
415
  num_prime_tokens,
416
  num_gen_tokens,
417
+ num_mem_tokens,
418
  gen_outro,
419
  gen_drums,
420
  model_temperature,
 
422
  )
423
 
424
  generated_batches.extend([sublist[2] for sublist in result])
 
 
 
 
 
 
425
 
426
  return tuple(item for sublist in result for item in sublist[:2])
427
 
 
495
  final_composition = []
496
  generated_batches = []
497
  block_lines = []
498
+
499
  #==================================================================================
500
 
501
  PDT = timezone('US/Pacific')
 
525
  for faster execution and endless generation!
526
  """)
527
 
528
+ gr.Markdown("## Upload your MIDI or select a sample example MIDI")
529
 
530
  input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
531
  input_midi.upload(reset)
532
 
533
  gr.Markdown("## Generate")
534
 
535
+ num_prime_tokens = gr.Slider(15, 6990, value=600, step=3, label="Number of prime tokens")
536
  num_gen_tokens = gr.Slider(15, 1200, value=600, step=3, label="Number of tokens to generate")
537
+ num_mem_tokens = gr.Slider(15, 6990, value=6990, step=3, label="Number of memory tokens")
538
+ gen_drums = gr.Checkbox(value=False, label="Introduce drums")
539
+ gen_outro = gr.Radio(["Auto", "Disable", "Force"], value="Auto", label="Outro options")
540
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
541
  model_sampling_top_p = gr.Slider(0.1, 1, value=0.96, step=0.01, label="Model sampling top p value")
542
 
 
558
  [input_midi,
559
  num_prime_tokens,
560
  num_gen_tokens,
561
+ num_mem_tokens,
562
  gen_outro,
563
  gen_drums,
564
  model_temperature,