File size: 4,503 Bytes
9fe1f6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates

import numpy as np
from transformers import pipeline
import torch
from transformers.pipelines.audio_utils import ffmpeg_microphone_live

# Run both models on GPU when available, otherwise CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Keyword-spotting model over the Speech Commands label set; launch_fn streams
# microphone audio through it to detect the wake word (default "marvin").
classifier = pipeline(
    "audio-classification", model="MIT/ast-finetuned-speech-commands-v2", device=device
)
# Intent-classification model (MINDS-14 fine-tune); listen() uses it to label
# the spoken query, and is_silence() relies on its "silence" label.
intent_class_pipe = pipeline(
    "audio-classification", model="anton-l/xtreme_s_xlsr_minds14", device=device
)


async def launch_fn(
    wake_word="marvin",
    prob_threshold=0.5,
    chunk_length_s=2.0,
    stream_chunk_s=0.25,
    debug=False,
):
    """Block until *wake_word* is detected on the live microphone.

    Streams microphone audio through the keyword-spotting ``classifier`` and
    returns ``True`` as soon as the top prediction equals *wake_word* with a
    score above *prob_threshold*. If the wake word never occurs, this loops
    forever over the mic stream.

    Args:
        wake_word: Label from the classifier's label set to listen for.
        prob_threshold: Minimum score required to accept a detection.
        chunk_length_s: Length in seconds of each classified audio window.
        stream_chunk_s: Stride in seconds between successive windows.
        debug: If True, print every top prediction while listening.

    Returns:
        True once the wake word is confidently detected.

    Raises:
        ValueError: If *wake_word* is not a label the classifier knows.

    NOTE(review): the body is fully synchronous (a blocking generator over the
    microphone), so despite being ``async`` it blocks the event loop while
    listening — consider offloading to a thread executor.
    """
    # Membership test directly on the mapping; .keys() was redundant.
    if wake_word not in classifier.model.config.label2id:
        raise ValueError(
            f"Wake word {wake_word} not in set of valid class labels, pick a wake word in the set {classifier.model.config.label2id.keys()}."
        )

    sampling_rate = classifier.feature_extractor.sampling_rate

    mic = ffmpeg_microphone_live(
        sampling_rate=sampling_rate,
        chunk_length_s=chunk_length_s,
        stream_chunk_s=stream_chunk_s,
    )

    print("Listening for wake word...")
    for prediction in classifier(mic):
        prediction = prediction[0]
        if debug:
            print(prediction)
        # Collapsed the nested ifs: fire only on a confident wake-word hit.
        if prediction["label"] == wake_word and prediction["score"] > prob_threshold:
            return True


async def listen(websocket, chunk_length_s=2.0, stream_chunk_s=2.0):
    """Record up to four microphone chunks, stopping early on silence, then
    send the intent classification of the combined audio over *websocket*.

    Each chunk's top label is streamed to the client as progress feedback;
    after recording, the concatenated audio is classified and the top three
    intent predictions are sent as the final message.

    Args:
        websocket: Connected WebSocket used for progress and result messages.
        chunk_length_s: Seconds of audio per classified chunk.
        stream_chunk_s: Stride in seconds between chunks from the mic stream.

    NOTE(review): like launch_fn, this pulls from a blocking microphone
    generator inside an async function, blocking the event loop while
    recording.
    """
    sampling_rate = intent_class_pipe.feature_extractor.sampling_rate

    mic = ffmpeg_microphone_live(
        sampling_rate=sampling_rate,
        chunk_length_s=chunk_length_s,
        stream_chunk_s=stream_chunk_s,
    )
    audio_buffer = []

    print("Listening")
    for i in range(4):
        audio_chunk = next(mic)
        audio_buffer.append(audio_chunk["raw"])

        # Per-chunk feedback so the client sees progress while recording.
        prediction = intent_class_pipe(audio_chunk["raw"])
        await websocket.send_text(f"chunk: {prediction[0]['label']} | {i+1} / 4")

        # Reuse the prediction just computed instead of calling is_silence(),
        # which re-ran the same pipeline on the same chunk (double inference).
        if prediction[0]["label"] == "silence" and prediction[0]["score"] > 0.7:
            print("Silence detected, processing audio.")
            break

    # Classify the full utterance and report the top three intents.
    combined_audio = np.concatenate(audio_buffer)
    prediction = intent_class_pipe(combined_audio)
    top_3_predictions = prediction[:3]
    formatted_predictions = "\n".join(
        f"{pred['label']}: {pred['score'] * 100:.2f}%" for pred in top_3_predictions
    )
    await websocket.send_text(f"classes: \n{formatted_predictions}")


async def is_silence(audio_chunk, threshold):
    """Return True if the intent pipeline labels *audio_chunk* as confident silence.

    Args:
        audio_chunk: Raw audio array to classify.
        threshold: Minimum score for the top "silence" label to count.

    Returns:
        bool: True when the top prediction is "silence" with score above
        *threshold*, else False.
    """
    top = intent_class_pipe(audio_chunk)[0]
    # Collapsed the if/else return-True/return-False anti-pattern.
    return top["label"] == "silence" and top["score"] > threshold


# Initialize FastAPI app
app = FastAPI()

# Serve the browser client's static assets (JS/CSS) from ./static.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Jinja2 templates directory; index.html is rendered by get_home().
templates = Jinja2Templates(directory="templates")


@app.get("/", response_class=HTMLResponse)
async def get_home(request: Request):
    """Render the landing page (templates/index.html)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket control channel for the voice assistant.

    The client sends "start" to run one wake-word-then-query cycle, and
    "stop" (only honored while a cycle is flagged active) to acknowledge and
    close the loop. Disconnects are logged and end the handler.
    """
    await websocket.accept()
    try:
        busy = False  # True while a listen cycle is considered in progress

        while True:
            command = await websocket.receive_text()

            if command == "start":
                if busy:
                    continue  # a cycle is already flagged; ignore duplicate start
                busy = True
                await websocket.send_text("Listening for wake word...")
                if await launch_fn(debug=True):
                    await websocket.send_text("Wake word detected. Listening for your query...")
                    await listen(websocket)
                    busy = False  # cycle finished; ready for the next start

            elif command == "stop" and busy:
                busy = False
                await websocket.send_text("Process stopped. Ready to restart.")
                break

    except WebSocketDisconnect:
        print("Client disconnected.")