# asigalov61's picture
# Update app.py
# 869a44f verified
import os.path
import time as reqtime
import datetime
from pytz import timezone
import torch
import spaces
import gradio as gr
import random
from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX
import matplotlib.pyplot as plt
from inference import PianoTranscription
from config import sample_rate
from utilities import load_audio
# =================================================================================================
@spaces.GPU
def TranscribePianoAudio(input_file):
    """Transcribe a solo-piano audio file to MIDI and render a preview.

    Loads the audio and transcribes it with ``PianoTranscription`` on CUDA.
    The result is written to ``<input stem>.mid`` in the current working
    directory. That MIDI is then parsed with TMIDIX to build a short summary,
    a score plot and a rendered audio preview.

    Args:
        input_file: Path of the uploaded WAV/MP3 file, as delivered by the
            ``gr.File`` input component.

    Returns:
        A 5-tuple ``(title, summary, midi_path, (16000, audio), plot)``. Its
        order matches the outputs wired to this function in the Gradio UI.

    Raises:
        gr.Error: If no input file was provided.
    """
    if input_file is None:
        # Without this guard, clicking the button with no upload fails inside
        # os.path.basename with an opaque TypeError.
        raise gr.Error('Please upload a WAV or MP3 audio file first.')

    # Built locally rather than read from the global ``PDT``, which is only
    # defined when this script runs as __main__.
    tz = timezone('US/Pacific')

    # Sample rate used both for rendering the MIDI preview and for the
    # (rate, samples) tuple handed to gr.Audio; they must agree.
    render_sr = 16000

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(tz)))
    start_time = reqtime.time()
    print('=' * 70)

    fn = os.path.basename(input_file)
    # splitext strips only the final extension, so "my.song.mp3" -> "my.song".
    # The previous split('.')[0] produced "my".
    fn1 = os.path.splitext(fn)[0]
    out_mid = fn1 + '.mid'

    print('-' * 70)
    print('Input file name:', fn)
    print('-' * 70)

    print('Loading audio...')
    # Resampled to the transcription model's expected rate, downmixed to mono.
    (audio, _) = load_audio(input_file, sr=sample_rate, mono=True)
    print('Done!')
    print('-' * 70)

    print('Loading transcriptor..')
    transcriptor = PianoTranscription(device='cuda')  # 'cuda' | 'cpu'
    print('Done!')
    print('-' * 70)

    print('Transcribing...')
    # Called for its side effect of writing the MIDI file to ``out_mid``;
    # the returned dict is not used here.
    transcriptor.transcribe(audio, out_mid)
    print('Done!')
    print('-' * 70)

    # ===============================================================================
    raw_score = TMIDIX.midi2single_track_ms_score(out_mid)

    # Enhanced score notes
    escore = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

    print('=' * 70)
    print('Number of transcribed notes:', len(escore))
    print('Sample trascribed MIDI events', escore[:5])
    print('=' * 70)
    print('Done!')
    print('=' * 70)

    # ===============================================================================
    print('Rendering results...')
    print('=' * 70)
    # NOTE: ``soundfont`` is a module-level global set by the script's
    # startup code, not by this function.
    preview_audio = midi_to_colab_audio(out_mid,
                                        soundfont_path=soundfont,
                                        sample_rate=render_sr,
                                        volume_scale=10,
                                        output_for_gradio=True
                                        )
    print('Done!')
    print('=' * 70)

    # ========================================================
    output_midi_title = str(fn1)
    output_midi_summary = str(escore[:3])
    output_midi = str(out_mid)
    output_audio = (render_sr, preview_audio)
    output_plot = TMIDIX.plot_ms_SONG(escore, plot_title=output_midi_title, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    # ========================================================
    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(tz)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
# =================================================================================================
# Module-level state read by TranscribePianoAudio. It is defined outside the
# __main__ guard so the handler can still resolve these names when this file
# is imported rather than executed as a script.
PDT = timezone('US/Pacific')
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

if __name__ == "__main__":
    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    with gr.Blocks() as app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>ByteDance Solo Piano Audio to MIDI Transcription</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Transcribe any Solo Piano WAV or MP3 audio to MIDI</h1>")
        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.ByteDance-Solo-Piano-Adio-to-MIDI-Transcription&style=flat)\n\n"
            "This is a ByteDance Solo Piano Audio to MIDI Transcription Model\n\n"
            "Check out [ByteDance Solo Piano Audio to MIDI Transcription](https://github.com/asigalov61/piano_transcription_inference) on GitHub!\n\n"
            "[Open In Colab]"
            "(https://colab.research.google.com/github/asigalov61/tegridy-tools/blob/main/tegridy-tools/notebooks/ByteDance_Piano_Transcription.ipynb)"
            " for faster execution and endless transcription"
        )

        gr.Markdown("## Upload your Solo Piano WAV or MP3 audio or select a sample example audio file")
        input_audio = gr.File(label="Input Solo Piano WAV or MP3 Audio File", file_types=[".wav", ".mp3"])
        run_btn = gr.Button("transcribe", variant="primary")

        gr.Markdown("## Generation results")
        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        # Order must match the tuple returned by TranscribePianoAudio.
        outputs = [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot]

        run_btn.click(TranscribePianoAudio, [input_audio], outputs)

        # cache_examples=True runs TranscribePianoAudio on the example at
        # startup so that selecting it later is instant.
        gr.Examples(
            [["cut_liszt.mp3"]],
            [input_audio],
            outputs,
            TranscribePianoAudio,
            cache_examples=True,
        )

    app.queue().launch()