from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
import torch
import librosa
import numpy as np
from torch.cuda.amp import autocast
import torch.nn.functional as F

model_id = "facebook/mms-lid-1024"

processor = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)

LID_SAMPLING_RATE = 16_000
LID_TOPK = 10
LID_THRESHOLD = 0.33
MAX_DURATION = 30  # seconds

# Map ISO codes to language names, one "<code> <name>" entry per line.
LID_LANGUAGES = {}
with open("data/lid/all_langs.tsv") as f:
    for line in f:
        # Split on the first whitespace run so both tab- and space-delimited
        # rows in all_langs.tsv parse correctly.
        iso, name = line.split(maxsplit=1)
        LID_LANGUAGES[iso] = name.strip()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_FP16 = torch.cuda.is_available()

model = model.to(DEVICE)
if USE_FP16:
    model = model.half()
model.eval()


def process_audio_chunk(audio_chunk, sampling_rate):
    """Trim, normalize, and resample audio data, limiting it to 30 seconds."""
    max_samples = MAX_DURATION * sampling_rate
    if len(audio_chunk) > max_samples:
        audio_chunk = audio_chunk[:max_samples]

    # Convert to float32 before resampling: librosa expects floating-point audio.
    if audio_chunk.dtype != np.float32:
        audio_chunk = audio_chunk.astype(np.float32)
    if np.abs(audio_chunk).max() > 1.0:
        # Values outside [-1, 1] are assumed to be 16-bit PCM; rescale to [-1, 1].
        audio_chunk = audio_chunk / 32768.0

    if sampling_rate != LID_SAMPLING_RATE:
        audio_chunk = librosa.resample(
            audio_chunk, orig_sr=sampling_rate, target_sr=LID_SAMPLING_RATE
        )

    return audio_chunk
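

# Illustrative sketch (not part of the original module, values assumed for the
# example): a 1-second int16 chunk recorded at 44.1 kHz comes back as float32
# audio at 16 kHz in [-1, 1]:
#
#     raw = np.zeros(44_100, dtype=np.int16)
#     prepared = process_audio_chunk(raw, 44_100)
#     # prepared.dtype == np.float32; roughly 16_000 samples long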


def identify(audio_data):
    """
    Identify the language from the first 30 seconds of audio data.

    Args:
        audio_data: Can be either:
            - numpy array of audio samples
            - tuple of (sample_rate, audio_samples)
            - path to an audio file

    Returns:
        dict: Mapping of language names to confidence scores,
        or a status string if identification fails or confidence is low.
    """
    if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
        return "<<ERROR: Empty Audio Input>>"

    try:
        if isinstance(audio_data, tuple):
            sr, audio_samples = audio_data
            audio_samples = process_audio_chunk(audio_samples, sr)
        elif isinstance(audio_data, str):
            audio_samples, sr = librosa.load(
                audio_data,
                sr=LID_SAMPLING_RATE,
                mono=True,
                duration=MAX_DURATION,
            )
            audio_samples = process_audio_chunk(audio_samples, sr)
        elif isinstance(audio_data, np.ndarray):
            audio_samples = process_audio_chunk(audio_data, LID_SAMPLING_RATE)
        else:
            raise ValueError(f"Unsupported audio_data type: {type(audio_data)}")

        inputs = processor(
            audio_samples, sampling_rate=LID_SAMPLING_RATE, return_tensors="pt"
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        with torch.no_grad():
            if USE_FP16:
                with autocast():
                    logits = model(**inputs).logits
            else:
                logits = model(**inputs).logits

        # Softmax in fp32 for numerical stability, then keep the top-k languages.
        probs = F.softmax(logits.float(), dim=-1)
        scores, indices = torch.topk(probs[0], LID_TOPK)

        scores = scores.cpu().tolist()
        indices = indices.cpu().tolist()

        iso2score = {model.config.id2label[i]: s for i, s in zip(indices, scores)}

        if max(iso2score.values()) < LID_THRESHOLD:
            return "Low confidence in the language identification predictions. Output is not shown!"

        results = {
            LID_LANGUAGES[iso]: score
            for iso, score in iso2score.items()
            if score >= LID_THRESHOLD and iso in LID_LANGUAGES
        }

        return results if results else "No language detected above threshold."

    except Exception as e:
        return f"Error processing audio: {str(e)}"


LID_EXAMPLES = [
    ["upload/english.mp3"],
    ["upload/tamil.mp3"],
    ["upload/burmese.mp3"],
]
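

# Minimal usage sketch (assumptions: run as a script, and "upload/english.mp3"
# from LID_EXAMPLES exists on disk; swap in any local audio path).
if __name__ == "__main__":
    predictions = identify("upload/english.mp3")
    if isinstance(predictions, dict):
        for language, score in predictions.items():
            print(f"{language}: {score:.3f}")
    else:
        # Error and low-confidence cases come back as plain strings.
        print(predictions)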