File size: 2,822 Bytes
13e8440 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from util import get_client_id, get_trained_models, train_client_model, download_dataset_locally, predict_vendor_category
from typing import Optional
# Runs once at import time, before the app object is created, so every process
# that imports this module (e.g. each worker, if the server spawns several)
# performs it. If it raises, the import fails and no app is built.
# NOTE(review): what is downloaded and where it is stored is defined in `util` — confirm there.
download_dataset_locally()
# FastAPI application; the endpoints below are registered on it via decorators.
app = FastAPI()
# Models
class TrainInput(BaseModel):
    # Request body for POST /train.
    # Client whose model is trained; /train rejects ids that contain a space.
    client_id: str
    # Training rows; /train requires every row to hold exactly 4 strings.
    # NOTE(review): column meaning/order is defined by util.train_client_model — confirm there.
    data: list[list[str]]
    # Forwarded unchanged to train_client_model as `ignore_value`; may be null.
    # presumably rows labelled with this value are skipped during training — verify in util.
    ignore_value: Optional[str] = 'Need help from accountant'
class PredictInput(BaseModel):
    # Request body for POST /predict.
    # Client whose trained model is used; /predict rejects ids that contain a space.
    client_id: str
    # Rows to classify; /predict requires every row to hold exactly 2 strings.
    # NOTE(review): column meaning is defined by util.predict_vendor_category — confirm there.
    data: list[list[str]]
class UserInput(BaseModel):
    # Request body for POST /create-client.
    # Human-readable client name; /create-client maps it to an id via util.get_client_id.
    client_name: str
# Endpoints
@app.get("/models")
def get_models():
    """List all trained models together with a short status message.

    The ``models`` value is exactly what ``get_trained_models()`` returned;
    the ``message`` tells an empty registry apart from a populated one.
    """
    models = get_trained_models()
    status = "No models trained yet." if len(models) == 0 else "List of trained models."
    return {"models": models, "message": status}
@app.post("/create-client")
def create_username(user_input: UserInput):
    """Derive a client id from the submitted name and report it.

    Nothing is persisted here: the id is produced by ``get_client_id`` and
    checked against the ids of already-trained models.

    Raises:
        HTTPException(400): if a trained model already exists for that id.
    """
    name = user_input.client_name
    # Collect ids of existing models first (same call order as before).
    existing_ids = [model['client_id'] for model in get_trained_models()]
    new_id = get_client_id(name)
    if new_id not in existing_ids:
        return {"client_id": new_id, "message": "client created successfully."}
    raise HTTPException(status_code=400, detail=f"Model for {name}, {new_id} already exists.")
@app.post("/train")
def train_model(train_input: TrainInput):
    """Validate a training request and train the model for ``client_id``.

    Returns:
        A dict with a confirmation ``message`` and the ``result`` returned by
        ``train_client_model``.

    Raises:
        HTTPException(400): if ``client_id`` is empty or contains a space, if
            ``data`` has no rows, or if any row does not hold exactly 4 items.
    """
    client_id = train_input.client_id
    # An empty id would slip past the space check below.
    if not client_id:
        raise HTTPException(status_code=400, detail="client_id cannot be empty.")
    # client_id must not contain a space.
    if ' ' in client_id:
        raise HTTPException(status_code=400, detail="client_id cannot contain space.")
    # Nothing to train on: reject up front instead of handing an empty list to
    # train_client_model (its behaviour on [] is not visible from here).
    if not train_input.data:
        raise HTTPException(status_code=400, detail="data must contain at least one row.")
    # Every row must contain exactly 4 items.
    if any(len(row) != 4 for row in train_input.data):
        raise HTTPException(status_code=400, detail="Each row must contain exactly 4 items.")
    training_result = train_client_model(client_id=client_id,
                                         rows=train_input.data,
                                         ignore_value=train_input.ignore_value)
    return {"message": f"Model '{client_id}' trained successfully.",
            "result": training_result}
@app.post("/predict")
def predict(predict_input: PredictInput):
    """Validate a prediction request and classify its rows with the client's model.

    Returns:
        A dict with the ``result`` of ``predict_vendor_category`` and a
        confirmation ``message``.

    Raises:
        HTTPException(400): if ``client_id`` is empty or contains a space, or
            if any row does not hold exactly 2 items.
    """
    client_id = predict_input.client_id
    # An empty id would slip past the space check below.
    if not client_id:
        raise HTTPException(status_code=400, detail="client_id cannot be empty.")
    # client_id must not contain a space.
    if ' ' in client_id:
        raise HTTPException(status_code=400, detail="client_id cannot contain space.")
    # Every row must contain exactly 2 items.
    # NOTE(review): an empty data list is passed through; confirm
    # predict_vendor_category handles [] as intended.
    if any(len(row) != 2 for row in predict_input.data):
        raise HTTPException(status_code=400, detail="Each row must contain exactly 2 items.")
    predictions = predict_vendor_category(client_id=client_id,
                                          data=predict_input.data)
    return {"result": predictions,
            'message': 'Predictions generated successfully.'
            }
|