aditii09 committed on
Commit 4e21035
1 Parent(s): 2c3919b

Update app.py

Files changed (1)
  1. app.py +36 -22
app.py CHANGED
@@ -1,32 +1,46 @@
- import librosa
  import gradio as gr
- import numpy as np
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
- import soundfile as sf
  import torch

- # load model and tokenizer
- processor = Wav2Vec2Processor.from_pretrained("aditii09/facebook_english_asr")
- model = Wav2Vec2ForCTC.from_pretrained("aditii09/facebook_english_asr")

- def speech2text(audio):
-     sr, data = audio

-     # resample to 16 kHz
-     data_16hz = librosa.resample(data[:, 0].astype(np.float32), sr, 16000)

-     # tokenize
-     input_values = processor([data_16hz], return_tensors="pt", padding="longest").input_values  # batch size 1

-     # retrieve logits
-     logits = model(input_values).logits

-     # take argmax and decode
-     predicted_ids = torch.argmax(logits, dim=-1)
-     transcription = processor.batch_decode(predicted_ids)

-     return transcription[0].lower()  # batch size 1

- iface = gr.Interface(speech2text, "microphone", "text")
-
- iface.launch()
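For reference, here is the removed pipeline tidied into a standalone sketch. It assumes a recent librosa where resample takes keyword-only rate arguments (the removed code used the older positional call), and it keeps the aditii09/facebook_english_asr checkpoint named in the diff; treat it as an illustration of the old Wav2Vec2 CTC flow, not the exact code that ran.

import librosa
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# load the checkpoint named in the removed code
processor = Wav2Vec2Processor.from_pretrained("aditii09/facebook_english_asr")
model = Wav2Vec2ForCTC.from_pretrained("aditii09/facebook_english_asr")

def speech2text(audio):
    # gradio's microphone component hands over (sample_rate, numpy_array)
    sr, data = audio

    # take one channel and resample to the 16 kHz the model expects;
    # keyword arguments are assumed here, as required by librosa >= 0.10
    if data.ndim > 1:
        data = data[:, 0]
    data_16k = librosa.resample(data.astype(np.float32), orig_sr=sr, target_sr=16000)

    # feature-extract, run the CTC head, greedy-decode
    input_values = processor(
        [data_16k], sampling_rate=16000, return_tensors="pt", padding="longest"
    ).input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0].lower()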
+ import os
  import gradio as gr
+ import whisper
+ import librosa
  import torch
+ from transformers import Wav2Vec2Processor, Wav2Vec2Tokenizer

+ device = "cuda" if torch.cuda.is_available() else "cpu"

+ def audio_to_text(audio):
+     model = whisper.load_model("base")

+     audio = whisper.load_audio(audio)
+     result = model.transcribe(audio)

+     return result["text"]
+     # tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
+     # logits = preprocess(audio)
+     # predicted_ids = torch.argmax(logits, dim=-1)
+     # transcriptions = tokenizer.decode(predicted_ids[0])
+     # return transcriptions

+ def preprocess(audio):
+     model_save_path = "model_save"
+     model_name = "wav2vec2_osr_version_1"
+     speech, rate = librosa.load(audio, sr=16000)
+     model_path = os.path.join(model_save_path, model_name + ".pt")
+     pipeline_path = os.path.join(model_save_path, model_name + "_vocab")

+     access_token = "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"
+     processor = Wav2Vec2Processor.from_pretrained(pipeline_path, use_auth_token=access_token)
+     model = torch.load(model_path)
+     model.eval()
+     input_values = processor(speech, sampling_rate=rate, return_tensors="pt").input_values.to(device)
+     logits = model(input_values).logits
+     return logits
+
+ demo = gr.Interface(
+     fn=audio_to_text,
+     inputs=gr.Audio(source="upload", type="filepath"),
+     examples=[["example.flac"]],
+     outputs="text"
+ )
+ demo.launch()
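The new audio_to_text returns before its commented-out Wav2Vec2 branch, so preprocess() is currently dead code. If that branch were revived, the decode step it sketches is ordinary greedy CTC decoding; a minimal version, assuming the facebook/wav2vec2-base-960h tokenizer named in the comment:

import torch
from transformers import Wav2Vec2Tokenizer

tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")

def logits_to_text(logits: torch.Tensor) -> str:
    # argmax over the vocabulary at every frame, then let the tokenizer
    # collapse repeated tokens and CTC blanks into a transcript
    predicted_ids = torch.argmax(logits, dim=-1)  # shape: (batch, frames)
    return tokenizer.decode(predicted_ids[0])

audio_to_text would then end with return logits_to_text(preprocess(audio)) instead of the Whisper call. Two smaller notes on the committed version: whisper.load_model("base") runs on every request, so hoisting it to module scope avoids reloading the weights per call, and gr.Audio(source="upload") matches Gradio 3.x (Gradio 4 renamed the parameter to sources).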