Spaces:
Sleeping
Sleeping
File size: 4,915 Bytes
dc65e63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
# fastapi_crud/app/main.py
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Depends
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from pydub import AudioSegment
import io
import spacy
import speech_recognition as sr
from app.database import engine, Base, get_db
from app.routers import user, device
from app import crud, schemas, auth, models
# Ensure every table declared on the ORM ``Base`` exists before any request is served.
Base.metadata.create_all(bind=engine)

app = FastAPI()

# Mount the user and device routers on the application.
for _router in (user.router, device.router):
    app.include_router(_router)

# spaCy pipelines, loaded once at import time:
#   nlp  - entity extraction; its ``device`` / ``location`` entities drive updates.
#   nlp2 - text categorizer; its top category is used as the device status.
nlp = spacy.load("custom_nlp_model")
nlp2 = spacy.load("text_categorizer_model")
def convert_audio_to_text(audio_file: UploadFile):
    """Transcribe an uploaded audio file to text.

    The upload is decoded with pydub according to its filename extension,
    down-mixed to mono at 16 kHz, re-encoded as in-memory WAV and passed to
    ``speech_recognition``'s ``Recognizer.recognize_google``.

    Args:
        audio_file: The uploaded file. Its extension must be wav, mp3, ogg
            or flac (case-insensitive).

    Returns:
        The recognized transcript.

    Raises:
        HTTPException: 400 for an unsupported/missing format or speech that
            could not be understood; 500 if the recognition service fails or
            the audio cannot be decoded/converted.
    """
    # Validate the format *before* entering the try block. Otherwise the
    # generic ``except Exception`` below would catch this 400 and re-raise it
    # as a 500 "Audio processing error".
    # A missing filename, or one without a dot, yields "" and is rejected.
    filename = audio_file.filename or ""
    _, dot, extension = filename.rpartition(".")
    audio_format = extension.lower() if dot else ""
    if audio_format not in ("wav", "mp3", "ogg", "flac"):
        raise HTTPException(status_code=400, detail="Unsupported audio format. Please upload a wav, mp3, ogg, or flac file.")

    try:
        audio = AudioSegment.from_file(io.BytesIO(audio_file.file.read()), format=audio_format)
        # Mono, 16 kHz before handing the audio to the recognizer.
        audio = audio.set_channels(1).set_frame_rate(16000)

        wav_io = io.BytesIO()
        audio.export(wav_io, format="wav")
        wav_io.seek(0)  # rewind so AudioFile reads from the start

        recognizer = sr.Recognizer()
        with sr.AudioFile(wav_io) as source:
            audio_data = recognizer.record(source)
        return recognizer.recognize_google(audio_data)
    except sr.UnknownValueError as e:
        raise HTTPException(status_code=400, detail="Speech recognition could not understand the audio.") from e
    except sr.RequestError as e:
        raise HTTPException(status_code=500, detail=f"Speech recognition service error: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Audio processing error: {e}") from e
def update_device_status(db: Session, device_name: str, location: str, status: str):
    """Set the active flag of the device matching ``device_name`` at ``location``.

    Looks up the first ``models.Device`` whose ``name`` and ``location`` both
    equal the given values. If one exists, it is marked active when ``status``
    is "on" (case-insensitive) and inactive otherwise, via
    ``crud.set_device_active``.

    Args:
        db: SQLAlchemy session used for the lookup and the update.
        device_name: Value compared against ``Device.name``.
        location: Value compared against ``Device.location``.
        status: Requested status text (the endpoints pass the predicted
            text category).

    Returns:
        A dict with ``device``, ``location`` and ``status`` keys. ``status`` is
        ``"turned <status>"`` when a device was updated, or ``"not found"``
        when no device matched.
    """
    db_device = (
        db.query(models.Device)
        .filter(models.Device.name == device_name, models.Device.location == location)
        .first()
    )
    active_status = status.lower() == "on"

    if not db_device:
        return {"device": device_name, "location": location, "status": "not found"}

    crud.set_device_active(db, db_device.id, active_status)
    return {"device": db_device.name, "location": db_device.location, "status": "turned " + status}
def find_location(db: Session, text: str):
    """Return the first word of ``text`` that is a stored device location.

    Whitespace-separated words are tried left to right. Each word is looked
    up with its own equality query against ``Device.location``, and the
    stored location of the first matching device is returned.

    Returns:
        The matched location, or ``''`` when no word matches.
    """
    # The generator issues queries lazily, so lookups stop at the first hit.
    matches = (
        db.query(models.Device).filter(models.Device.location == word).first()
        for word in text.split()
    )
    return next((device.location for device in matches if device), '')
def process_entities_and_update(db: Session, doc, text, status):
    """Apply ``status`` to every device mentioned in a parsed command.

    The location is the text of the first entity labelled "location" (any
    case). If there is none, ``find_location`` scans ``text`` for a word that
    matches a stored device location. Each entity labelled "device" (any
    case) is then updated at that location. If ``doc`` has no device
    entities, every whitespace-separated word of ``text`` is tried as a
    device name instead.

    Args:
        db: SQLAlchemy session forwarded to the lookup/update helpers.
        doc: Parsed document whose ``ents`` carry ``text`` and ``label_``
            (the endpoints pass the output of ``nlp(text)``).
        text: The raw command text ``doc`` was produced from.
        status: Status string forwarded to ``update_device_status``.

    Returns:
        One ``update_device_status`` result dict per device candidate tried,
        in order.
    """
    location_entity = next((ent for ent in doc.ents if ent.label_.lower() == 'location'), None)
    location = location_entity.text if location_entity else find_location(db, text)

    # Compare labels case-insensitively, consistent with the location check above.
    device_names = [ent.text for ent in doc.ents if ent.label_.lower() == 'device']
    if not device_names:
        # No device entities found: treat every word as a potential device name.
        device_names = text.split()

    return [update_device_status(db, name, location, status) for name in device_names]
def _analyze_text(db: Session, text: str):
    """Run both spaCy pipelines on ``text`` and apply the resulting device updates.

    Returns the JSON payload shared by ``/predict/`` and ``/predict_text/``:
    the input text, the top text category, the recognized entities and the
    per-device update results.
    """
    doc = nlp(text)
    doc2 = nlp2(text)
    # The highest-scoring category is used as the status applied to devices.
    predictions = {"category": max(doc2.cats, key=doc2.cats.get)}
    entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
    updates = process_entities_and_update(db, doc, text, predictions['category'])
    return {"text": text, "predictions": predictions, "entities": entities, "updates": updates}


# Both handlers are declared with plain ``def`` rather than ``async def``:
# audio decoding, the speech-recognition network call, spaCy inference and the
# synchronous SQLAlchemy session all block. FastAPI runs ``def`` handlers in
# its threadpool, so they no longer stall the event loop for other requests.
@app.post("/predict/")
def predict(audio_file: UploadFile = File(...), db: Session = Depends(get_db), current_user: schemas.User = Depends(auth.get_current_user)):
    """Transcribe an uploaded audio command, classify it and update matching devices.

    Requires an authenticated user. An ``HTTPException`` raised while
    processing is returned as a JSON ``{"detail": ...}`` response with its
    status code.
    """
    try:
        text = convert_audio_to_text(audio_file)
        return JSONResponse(content=_analyze_text(db, text))
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})


@app.post("/predict_text/")
def predict_text(text: str = Form(...), db: Session = Depends(get_db), current_user: schemas.User = Depends(auth.get_current_user)):
    """Classify a text command (form field ``text``) and update matching devices.

    Requires an authenticated user. An ``HTTPException`` raised while
    processing is returned as a JSON ``{"detail": ...}`` response with its
    status code.
    """
    try:
        return JSONResponse(content=_analyze_text(db, text))
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
|