# fastapi_crud/app/main.py
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Depends
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from pydub import AudioSegment
import io
import spacy
import speech_recognition as sr

from app.database import engine, Base, get_db
from app.routers import user, device
from app import crud, schemas, auth, models

# Create the database tables
Base.metadata.create_all(bind=engine)

app = FastAPI()
app.include_router(user.router)
app.include_router(device.router)

# Load spaCy models
nlp = spacy.load("custom_nlp_model")
nlp2 = spacy.load("text_categorizer_model")

def convert_audio_to_text(audio_file: UploadFile):
    """Normalize an uploaded audio file to 16 kHz mono WAV and transcribe it."""
    try:
        audio_format = audio_file.filename.split(".")[-1].lower()
        if audio_format not in ["wav", "mp3", "ogg", "flac"]:
            raise HTTPException(status_code=400, detail="Unsupported audio format. Please upload a wav, mp3, ogg, or flac file.")
        audio = AudioSegment.from_file(io.BytesIO(audio_file.file.read()), format=audio_format)
        audio = audio.set_channels(1).set_frame_rate(16000)
        wav_io = io.BytesIO()
        audio.export(wav_io, format="wav")
        wav_io.seek(0)
        recognizer = sr.Recognizer()
        with sr.AudioFile(wav_io) as source:
            audio_data = recognizer.record(source)
        text = recognizer.recognize_google(audio_data)
        return text
    except HTTPException:
        # Re-raise the 400 validation error as-is so it is not swallowed by the generic handler below.
        raise
    except sr.UnknownValueError:
        raise HTTPException(status_code=400, detail="Speech recognition could not understand the audio.")
    except sr.RequestError as e:
        raise HTTPException(status_code=500, detail=f"Speech recognition service error: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Audio processing error: {e}")

def update_device_status(db: Session, device_name: str, location: str, status: str):
    """Mark the matching device as active/inactive and report the outcome."""
    db_device = db.query(models.Device).filter(models.Device.name == device_name, models.Device.location == location).first()
    active_status = status.lower() == "on"
    if db_device:
        crud.set_device_active(db, db_device.id, active_status)
        return {"device": db_device.name, "location": db_device.location, "status": "turned " + status}
    return {"device": device_name, "location": location, "status": "not found"}


def find_location(db: Session, text: str):
    """Scan the text for any word that matches a known device location."""
    words = text.split()
    for word in words:
        db_location = db.query(models.Device).filter(models.Device.location == word).first()
        if db_location:
            return db_location.location
    return ''

def process_entities_and_update(db: Session, doc, text, status):
    """Update every device mentioned in the recognized entities (or, failing that, in the raw text)."""
    updates = []
    location_entity = next((ent for ent in doc.ents if ent.label_.lower() == 'location'), None)
    location = location_entity.text if location_entity else find_location(db, text)
    for ent in doc.ents:
        if ent.label_.lower() == 'device':  # match label casing the same way as the location check above
            update = update_device_status(db, ent.text, location, status)
            updates.append(update)
    if not updates:  # No device entities found, process all words as potential device names
        words = text.split()
        for word in words:
            update = update_device_status(db, word, location, status)
            updates.append(update)
    return updates
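
# Illustrative example (entity labels and values are assumptions about the custom models,
# not taken from the application):
#   text     = "turn on the kitchen lamp"
#   doc.ents = [("kitchen", "location"), ("lamp", "device")]
#   status   = "on"
#   returns  -> [{"device": "lamp", "location": "kitchen", "status": "turned on"}]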

# Route decorators: the paths "/predict" and "/predict-text" are assumed here;
# adjust them to the application's actual routes.
@app.post("/predict")
async def predict(audio_file: UploadFile = File(...), db: Session = Depends(get_db), current_user: schemas.User = Depends(auth.get_current_user)):
    try:
        text = convert_audio_to_text(audio_file)
        doc = nlp(text)
        doc2 = nlp2(text)
        predictions = {"category": max(doc2.cats, key=doc2.cats.get)}
        entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
        updates = process_entities_and_update(db, doc, text, predictions['category'])
        return JSONResponse(content={"text": text, "predictions": predictions, "entities": entities, "updates": updates})
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})

@app.post("/predict-text")  # path assumed, see the note above
async def predict_text(text: str = Form(...), db: Session = Depends(get_db), current_user: schemas.User = Depends(auth.get_current_user)):
    try:
        doc = nlp(text)
        doc2 = nlp2(text)
        predictions = {"category": max(doc2.cats, key=doc2.cats.get)}
        entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
        updates = process_entities_and_update(db, doc, text, predictions['category'])
        return JSONResponse(content={"text": text, "predictions": predictions, "entities": entities, "updates": updates})
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
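

# Minimal local sketch of how the text endpoint could be exercised, assuming the
# "/predict-text" path above and that auth.get_current_user accepts the supplied
# bearer token; the token value and example sentence are placeholders.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    response = client.post(
        "/predict-text",
        data={"text": "turn on the kitchen lamp"},
        headers={"Authorization": "Bearer <access-token>"},  # placeholder credentials
    )
    print(response.status_code, response.json())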