|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from util import get_client_id, get_trained_models, train_client_model, download_dataset_locally, predict_vendor_category
|
|
from typing import Optional
|
|
|
|
# Import-time side effect: this call runs once in every process that imports
# this module (e.g. each uvicorn worker), before the FastAPI app is created.
# If it raises, the import fails and the server cannot start.
# NOTE(review): presumably fetches the dataset that util's training/prediction
# helpers read from local disk -- confirm in util.
download_dataset_locally()

# The ASGI application; the endpoints below register on it via @app decorators.
app = FastAPI()
|
|
|
|
|
|
class TrainInput(BaseModel):
    """Request body for the /train endpoint."""

    # Identifier of the client whose model is trained; /train rejects ids
    # containing a space character.
    client_id: str

    # Training rows of string cells; /train requires exactly four per row.
    data: list[list[str]]

    # Forwarded verbatim to util.train_client_model as `ignore_value`.
    # NOTE(review): presumably a label whose rows are skipped during
    # training -- confirm in util. Explicit null is accepted (Optional).
    ignore_value: Optional[str] = "Need help from accountant"
|
|
|
|
class PredictInput(BaseModel):
    """Request body for the /predict endpoint."""

    # Identifier of the client whose model is used; /predict rejects ids
    # containing a space character.
    client_id: str

    # Rows to predict on, as string cells; /predict requires exactly two
    # per row.
    data: list[list[str]]
|
|
|
|
class UserInput(BaseModel):
    """Request body for the /create-client endpoint."""

    # Human-readable client name; util.get_client_id derives the id from it.
    client_name: str
|
|
|
|
|
|
@app.get("/models")
def get_models():
    """List the trained models reported by ``util.get_trained_models``.

    The models are returned unchanged under ``models``, together with a
    ``message`` that says whether any exist yet.
    """
    models = get_trained_models()
    # len() (rather than truthiness) keeps the emptiness test identical for
    # whatever sized collection get_trained_models returns.
    message = "List of trained models." if len(models) else "No models trained yet."
    return {"models": models, "message": message}
|
|
|
|
@app.post("/create-client")
def create_username(user_input: UserInput):
    """Derive a client id from ``client_name`` and reject it if already taken.

    The id comes from ``util.get_client_id``. It counts as taken when any
    entry returned by ``util.get_trained_models`` carries the same
    ``client_id``. This handler itself stores nothing; whether
    ``get_client_id`` has side effects is not visible here.

    Raises:
        HTTPException: 400 if a trained model already exists for the id.
    """
    name = user_input.client_name
    # Collect the existing ids before deriving the new one (original call order).
    existing_ids = [model['client_id'] for model in get_trained_models()]
    new_id = get_client_id(name)

    if new_id not in existing_ids:
        return {"client_id": new_id, "message": "client created successfully."}

    raise HTTPException(
        status_code=400,
        detail=f"Model for {name}, {new_id} already exists.",
    )
|
|
|
|
@app.post("/train")
def train_model(train_input: TrainInput):
    """Train a model for ``client_id`` from the supplied rows.

    Each row of ``data`` must contain exactly four strings. The rows and
    ``ignore_value`` are passed unchanged to ``util.train_client_model``,
    and its return value is echoed back under ``result``.

    Raises:
        HTTPException: 400 if ``client_id`` contains a space character, or
            if any row does not have exactly four items.
    """
    cid = train_input.client_id
    rows = train_input.data

    # Only the literal space character is rejected; other whitespace passes.
    if ' ' in cid:
        raise HTTPException(status_code=400, detail="client_id cannot contain space.")

    # NOTE(review): presumably four columns is the layout train_client_model
    # expects -- confirm in util.
    if any(len(row) != 4 for row in rows):
        raise HTTPException(status_code=400, detail="Each row must contain exactly 4 items.")

    outcome = train_client_model(
        client_id=cid,
        rows=rows,
        ignore_value=train_input.ignore_value,
    )
    return {
        "message": f"Model '{cid}' trained successfully.",
        "result": outcome,
    }
|
|
|
|
@app.post("/predict")
def predict(predict_input: PredictInput):
    """Run ``util.predict_vendor_category`` for ``client_id`` on the rows.

    Each row of ``data`` must contain exactly two strings. The rows are passed
    unchanged to the predictor, and its output is returned under ``result``.

    Raises:
        HTTPException: 400 if ``client_id`` contains a space character, or
            if any row does not have exactly two items.
    """
    cid = predict_input.client_id
    rows = predict_input.data

    # Only the literal space character is rejected; other whitespace passes.
    if ' ' in cid:
        raise HTTPException(status_code=400, detail="client_id cannot contain space.")

    if any(len(row) != 2 for row in rows):
        raise HTTPException(status_code=400, detail="Each row must contain exactly 2 items.")

    # NOTE(review): nothing here checks that a model exists for cid; the
    # outcome for an unknown id is up to predict_vendor_category.
    predictions = predict_vendor_category(client_id=cid, data=rows)
    return {
        "result": predictions,
        "message": "Predictions generated successfully.",
    }
|
|
|
|
|