import pickle
from io import BytesIO

import streamlit as st
import torch
import torchaudio as ta
import whisper
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Set up device and dtype
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda:0" else torch.float32
SAMPLING_RATE = 16000

# Load the Whisper model and processor once at startup
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small", torch_dtype=torch_dtype
).to(device)

# Title of the app
st.title("Audio Player with Live Transcription")

# Sidebar for file uploader and submit button
st.sidebar.header("Upload Audio Files")
uploaded_files = st.sidebar.file_uploader(
    "Choose audio files", type=["mp3", "wav"], accept_multiple_files=True
)
submit_button = st.sidebar.button("Submit")

# Session state to hold data across reruns
if "audio_files" not in st.session_state:
    st.session_state.audio_files = []
    st.session_state.transcriptions = {}
    st.session_state.translations = {}
    st.session_state.detected_languages = []
    st.session_state.waveforms = []


@st.cache_resource
def load_detection_model():
    # Cache the language-detection model so it is loaded once,
    # not reloaded for every uploaded file
    return whisper.load_model("small")


def detect_language(waveform):
    whisper_model = load_detection_model()
    # Whisper expects a mono 1-D signal, padded/trimmed to 30 seconds
    trimmed_audio = whisper.pad_or_trim(waveform[0])
    mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
    _, probs = whisper_model.detect_language(mel)
    detected_lang = max(probs, key=probs.get)
    print(f"Detected language: {detected_lang}")
    return detected_lang


# Process uploaded files
if submit_button and uploaded_files:
    st.session_state.audio_files = uploaded_files
    st.session_state.detected_languages = []
    st.session_state.waveforms = []  # reset so indices stay aligned on re-submit
    for uploaded_file in uploaded_files:
        waveform, sampling_rate = ta.load(BytesIO(uploaded_file.getvalue()))
        if sampling_rate != SAMPLING_RATE:
            waveform = ta.functional.resample(
                waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE
            )
        st.session_state.waveforms.append(waveform)
        st.session_state.detected_languages.append(detect_language(waveform))

# Display uploaded files and options
if st.session_state.get("audio_files"):
    for i, uploaded_file in enumerate(st.session_state.audio_files):
        col1, col2 = st.columns([1, 3])
        with col1:
            st.write(f"**File name**: {uploaded_file.name}")
            # getvalue() leaves the buffer intact, unlike read(), which
            # exhausts it after the first call
            st.audio(uploaded_file.getvalue(), format=uploaded_file.type)
            st.write(f"**Detected Language**: {st.session_state.detected_languages[i]}")
        with col2:
            # Use the first channel; the feature extractor expects a 1-D array
            input_features = processor(
                st.session_state.waveforms[i][0].numpy(),
                sampling_rate=SAMPLING_RATE,
                return_tensors="pt",
            ).input_features.to(device, torch_dtype)

            if st.button(f"Transcribe {uploaded_file.name}"):
                predicted_ids = model.generate(input_features)
                transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
                st.session_state.transcriptions[i] = transcription
            if st.session_state.transcriptions.get(i):
                st.write("**Transcription**:")
                for line in st.session_state.transcriptions[i]:
                    st.write(line)

            if st.button(f"Translate {uploaded_file.name}"):
                with open("languages.pkl", "rb") as f:
                    lang_dict = pickle.load(f)
                detected_language_name = lang_dict[st.session_state.detected_languages[i]]
                # Force the decoder to translate from the detected language to English
                forced_decoder_ids = processor.get_decoder_prompt_ids(
                    language=detected_language_name, task="translate"
                )
                predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
                translation = processor.batch_decode(predicted_ids, skip_special_tokens=True)
                st.session_state.translations[i] = translation
            if st.session_state.translations.get(i):
                st.write("**Translation**:")
                for line in st.session_state.translations[i]:
                    st.write(line)
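
# ----------------------------------------------------------------------
# Note: the Translate button expects a `languages.pkl` file mapping the
# ISO codes Whisper detects (e.g. "fr") to language names that
# `get_decoder_prompt_ids` accepts (e.g. "french"). The original file is
# not shown here; a minimal sketch for generating one, assuming the
# openai-whisper package's built-in `whisper.tokenizer.LANGUAGES` dict
# is an acceptable source:
#
#     import pickle
#     from whisper.tokenizer import LANGUAGES  # e.g. {"en": "english", ...}
#
#     with open("languages.pkl", "wb") as f:
#         pickle.dump(dict(LANGUAGES), f)
# ----------------------------------------------------------------------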