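"""Spoken language identification with facebook/mms-lid-1024.

Loads the MMS LID model and feature extractor once at import time and
exposes identify(), which returns the top predicted languages (with
confidence scores) for the first 30 seconds of an audio input.
"""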
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from torch.amp import autocast
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

model_id = "facebook/mms-lid-1024"

# Load model and processor only once
processor = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)

# Constants
LID_SAMPLING_RATE = 16_000
LID_TOPK = 10
LID_THRESHOLD = 0.33
MAX_DURATION = 30  # Maximum audio duration in seconds

# Load language mappings (each line holds an ISO code and a language name,
# separated by the first space)
LID_LANGUAGES = {}
with open("data/lid/all_langs.tsv") as f:
    for line in f:
        iso, name = line.split(" ", 1)
        LID_LANGUAGES[iso] = name.strip()

# Set up device and optimization settings
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_FP16 = torch.cuda.is_available()  # Only use FP16 if CUDA is available

# Move model to device and optimize for inference
model = model.to(DEVICE)
if USE_FP16:
    model = model.half()  # Convert model weights to FP16
model.eval()  # Set to evaluation mode

def process_audio_chunk(audio_chunk, sampling_rate):
    """Prepare audio samples for the model, limited to the first 30 seconds."""
    # Calculate the maximum number of samples for 30 seconds
    max_samples = MAX_DURATION * sampling_rate

    # Trim to the first 30 seconds if longer
    if len(audio_chunk) > max_samples:
        audio_chunk = audio_chunk[:max_samples]

    # Ensure float32 samples normalized to [-1.0, 1.0]; librosa.resample
    # expects floating-point input, so convert before resampling
    if audio_chunk.dtype != np.float32:
        audio_chunk = audio_chunk.astype(np.float32)
    if np.abs(audio_chunk).max() > 1.0:
        audio_chunk = audio_chunk / 32768.0  # assume int16 PCM scale

    # Resample to the model's expected rate if necessary
    if sampling_rate != LID_SAMPLING_RATE:
        audio_chunk = librosa.resample(
            audio_chunk, orig_sr=sampling_rate, target_sr=LID_SAMPLING_RATE
        )

    return audio_chunk

def identify(audio_data):
    """
    Identify the language from the first 30 seconds of audio data.

    Args:
        audio_data: Can be either:
            - numpy array of audio samples (assumed to be at 16 kHz)
            - tuple of (sample_rate, audio_samples)
            - path to an audio file

    Returns:
        dict mapping language names to confidence scores, or a str
        message on empty input, low-confidence predictions, or error.
    """
    if audio_data is None or (isinstance(audio_data, np.ndarray) and audio_data.size == 0):
        return "<<ERROR: Empty Audio Input>>"

    try:
        # Process input based on type
        if isinstance(audio_data, tuple):
            sr, audio_samples = audio_data
            audio_samples = process_audio_chunk(audio_samples, sr)
        elif isinstance(audio_data, str):
            # Load only the first 30 seconds using the duration parameter
            audio_samples, sr = librosa.load(
                audio_data,
                sr=LID_SAMPLING_RATE,
                mono=True,
                duration=MAX_DURATION,
            )
            audio_samples = process_audio_chunk(audio_samples, sr)
        elif isinstance(audio_data, np.ndarray):
            audio_samples = process_audio_chunk(audio_data, LID_SAMPLING_RATE)
        else:
            raise ValueError(f"Unsupported audio_data type: {type(audio_data)}")

        # Prepare input for the model
        inputs = processor(
            audio_samples, sampling_rate=LID_SAMPLING_RATE, return_tensors="pt"
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        # Run inference, with mixed precision on GPU
        with torch.no_grad():
            if USE_FP16:
                with autocast("cuda"):
                    logits = model(**inputs).logits
            else:
                logits = model(**inputs).logits

        # Convert logits to probabilities and take the top-k candidates
        probs = F.softmax(logits.float(), dim=-1)
        scores, indices = torch.topk(probs[0], LID_TOPK)

        # Move to CPU and convert to standard Python types
        scores = scores.cpu().tolist()
        indices = indices.cpu().tolist()

        # Map label indices to ISO codes
        iso2score = {model.config.id2label[i]: s for i, s in zip(indices, scores)}
        if max(iso2score.values()) < LID_THRESHOLD:
            return "Low confidence in the language identification predictions. Output is not shown!"

        # Keep only scores at or above the threshold, mapped to language names
        results = {
            LID_LANGUAGES[iso]: score
            for iso, score in iso2score.items()
            if score >= LID_THRESHOLD and iso in LID_LANGUAGES
        }
        return results if results else "No language detected above threshold."
    except Exception as e:
        return f"Error processing audio: {str(e)}"

# Example inputs
LID_EXAMPLES = [
    ["upload/english.mp3"],
    ["upload/tamil.mp3"],
    ["upload/burmese.mp3"],
]
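
# Minimal smoke test, assuming the example files in LID_EXAMPLES exist
# locally; identify() also accepts a raw numpy array (at 16 kHz) or a
# (sample_rate, samples) tuple, exercised below with synthetic noise.
if __name__ == "__main__":
    for (path,) in LID_EXAMPLES:
        print(path, "->", identify(path))

    # One second of quiet noise; expect the low-confidence message
    noise = np.random.uniform(-0.1, 0.1, LID_SAMPLING_RATE).astype(np.float32)
    print("noise ->", identify((LID_SAMPLING_RATE, noise)))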