from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
import torch
import librosa
import numpy as np

# Load model and processor only once
model_id = "facebook/mms-lid-1024"
processor = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)

# Constants
LID_SAMPLING_RATE = 16_000  # MMS-LID models expect 16 kHz audio
LID_THRESHOLD = 0.33  # minimum top-1 probability before results are reported

# Load LID languages only once
LID_LANGUAGES = {}
with open("data/lid/all_langs.tsv") as f:
    for line in f:
        iso, name = line.split(" ", 1)
        LID_LANGUAGES[iso] = name.strip()

# Set device once (prefer CUDA, then Apple MPS, else CPU)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)
print(f"Using device: {device}")

def identify(audio_data=None):
    """Identify the language of the input audio.

    Accepts either a (sample_rate, samples) tuple (e.g. a microphone
    recording) or a path to an audio file; returns a {language: score}
    dict of the top predictions, or an error/low-confidence message string.
    """
    if not audio_data:
        return "<<ERROR: Empty Audio Input>>"
    
    # Preprocess audio data
    if isinstance(audio_data, tuple):
        # Microphone input
        sr, audio_samples = audio_data
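        # Raw mic samples are int16; scale to float32 in [-1, 1]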
        audio_samples = (audio_samples / 32768.0).astype(np.float32)
        if sr != LID_SAMPLING_RATE:
            audio_samples = librosa.resample(audio_samples, orig_sr=sr, target_sr=LID_SAMPLING_RATE)
    elif isinstance(audio_data, str):
        # File upload
        audio_samples = librosa.load(audio_data, sr=LID_SAMPLING_RATE, mono=True)[0]
    else:
        return f"<<ERROR: Invalid Audio Input Instance: {type(audio_data)}>>"

    # Extract features
    inputs = processor(audio_samples, sampling_rate=LID_SAMPLING_RATE, return_tensors="pt").to(device)

    # Perform inference (mixed precision only when running on CUDA)
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=device.type == "cuda"):
        logits = model(**inputs).logits

    # Take the top-5 predictions: log-softmax for numerical stability,
    # then exponentiate back to probabilities
    logit_lsm = torch.log_softmax(logits.squeeze(), dim=-1)
    scores, indices = torch.topk(logit_lsm, 5, dim=-1)
    scores, indices = torch.exp(scores).to("cpu").tolist(), indices.to("cpu").tolist()

    # Map scores to language labels
    iso2score = {model.config.id2label[int(i)]: s for s, i in zip(scores, indices)}
    if max(iso2score.values()) < LID_THRESHOLD:
        return "Low confidence in the language identification predictions. Output is not shown!"
    
    return {LID_LANGUAGES[iso]: score for iso, score in iso2score.items()}
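

# Example usage: a minimal sketch for local testing, not part of the original
# app ("samples/english.wav" is a hypothetical path; substitute any audio file
# librosa can read). The tuple branch of identify() matches the
# (sample_rate, samples) shape that audio UI frameworks typically pass.
if __name__ == "__main__":
    print(identify("samples/english.wav"))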