bofenghuang commited on
Commit
f1dd1bd
1 Parent(s): 16000ae

updt example

Browse files
Files changed (1) hide show
  1. README.md +16 -10
README.md CHANGED
@@ -130,23 +130,26 @@ import torchaudio
130
 
131
  from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM
132
 
133
- model = AutoModelForCTC.from_pretrained("bhuang/asr-wav2vec2-french").cuda()
 
 
134
  processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained("bhuang/asr-wav2vec2-french")
 
135
 
136
  wav_path = "example.wav" # path to your audio file
137
  waveform, sample_rate = torchaudio.load(wav_path)
138
  waveform = waveform.squeeze(axis=0) # mono
139
 
140
  # resample
141
- if sample_rate != 16_000:
142
- resampler = torchaudio.transforms.Resample(sample_rate, 16_000)
143
  waveform = resampler(waveform)
144
 
145
  # normalize
146
- input_dict = processor_with_lm(waveform, sampling_rate=16_000, return_tensors="pt")
147
 
148
  with torch.inference_mode():
149
- logits = model(input_dict.input_values.to("cuda")).logits
150
 
151
  predicted_sentence = processor_with_lm.batch_decode(logits.cpu().numpy()).text[0]
152
  ```
@@ -159,23 +162,26 @@ import torchaudio
159
 
160
  from transformers import AutoModelForCTC, Wav2Vec2Processor
161
 
162
- model = AutoModelForCTC.from_pretrained("bhuang/asr-wav2vec2-french").cuda()
 
 
163
  processor = Wav2Vec2Processor.from_pretrained("bhuang/asr-wav2vec2-french")
 
164
 
165
  wav_path = "example.wav" # path to your audio file
166
  waveform, sample_rate = torchaudio.load(wav_path)
167
  waveform = waveform.squeeze(axis=0) # mono
168
 
169
  # resample
170
- if sample_rate != 16_000:
171
- resampler = torchaudio.transforms.Resample(sample_rate, 16_000)
172
  waveform = resampler(waveform)
173
 
174
  # normalize
175
- input_dict = processor(waveform, sampling_rate=16_000, return_tensors="pt")
176
 
177
  with torch.inference_mode():
178
- logits = model(input_dict.input_values.to("cuda")).logits
179
 
180
  # decode
181
  predicted_ids = torch.argmax(logits, dim=-1)
 
130
 
131
  from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM
132
 
133
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134
+
135
+ model = AutoModelForCTC.from_pretrained("bhuang/asr-wav2vec2-french").to(device)
136
  processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained("bhuang/asr-wav2vec2-french")
137
+ model_sample_rate = processor_with_lm.feature_extractor.sampling_rate
138
 
139
  wav_path = "example.wav" # path to your audio file
140
  waveform, sample_rate = torchaudio.load(wav_path)
141
  waveform = waveform.squeeze(axis=0) # mono
142
 
143
  # resample
144
+ if sample_rate != model_sample_rate:
145
+ resampler = torchaudio.transforms.Resample(sample_rate, model_sample_rate)
146
  waveform = resampler(waveform)
147
 
148
  # normalize
149
+ input_dict = processor_with_lm(waveform, sampling_rate=model_sample_rate, return_tensors="pt")
150
 
151
  with torch.inference_mode():
152
+ logits = model(input_dict.input_values.to(device)).logits
153
 
154
  predicted_sentence = processor_with_lm.batch_decode(logits.cpu().numpy()).text[0]
155
  ```
 
162
 
163
  from transformers import AutoModelForCTC, Wav2Vec2Processor
164
 
165
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
166
+
167
+ model = AutoModelForCTC.from_pretrained("bhuang/asr-wav2vec2-french").to(device)
168
  processor = Wav2Vec2Processor.from_pretrained("bhuang/asr-wav2vec2-french")
169
+ model_sample_rate = processor.feature_extractor.sampling_rate
170
 
171
  wav_path = "example.wav" # path to your audio file
172
  waveform, sample_rate = torchaudio.load(wav_path)
173
  waveform = waveform.squeeze(axis=0) # mono
174
 
175
  # resample
176
+ if sample_rate != model_sample_rate:
177
+ resampler = torchaudio.transforms.Resample(sample_rate, model_sample_rate)
178
  waveform = resampler(waveform)
179
 
180
  # normalize
181
+ input_dict = processor(waveform, sampling_rate=model_sample_rate, return_tensors="pt")
182
 
183
  with torch.inference_mode():
184
+ logits = model(input_dict.input_values.to(device)).logits
185
 
186
  # decode
187
  predicted_ids = torch.argmax(logits, dim=-1)