# File size: 2,822 Bytes
# 13e8440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from util import  get_client_id, get_trained_models, train_client_model, download_dataset_locally, predict_vendor_category
from typing import Optional

# Executed at import time, before the app object is created: any exception
# raised here aborts the module import, so the server never starts without it.
# NOTE(review): presumably fetches the training dataset into local storage
# (implemented in util) — confirm what it downloads and whether repeated
# imports/reloads re-download it.
download_dataset_locally()
# ASGI application object; the route decorators below register on it.
app = FastAPI()

# Models
class TrainInput(BaseModel):
    """Request body for ``POST /train``."""
    # Client whose model is trained; the /train endpoint rejects ids containing a space.
    client_id: str
    # Training rows of strings; the /train endpoint requires exactly 4 items per row.
    data: list[list[str]]
    # Passed through to train_client_model as ``ignore_value``. Declared Optional,
    # so an explicit JSON null is accepted and forwarded as None.
    # NOTE(review): presumably a label value to skip during training — confirm in util.
    ignore_value: Optional[str] = 'Need help from accountant'

class PredictInput(BaseModel):
    """Request body for ``POST /predict``."""
    # Client whose trained model is used; the /predict endpoint rejects ids containing a space.
    client_id: str
    # Rows of strings to classify; the /predict endpoint requires exactly 2 items per row.
    data: list[list[str]]

class UserInput(BaseModel):
    """Request body for ``POST /create-client``."""
    # Human-readable client name; the endpoint derives the client_id from it via get_client_id.
    client_name: str

# Endpoints
@app.get("/models")
def get_models():
    """List the trained models known to the service.

    Returns:
        dict: ``{"models": ..., "message": ...}`` where ``models`` is exactly
        what get_trained_models() returned and ``message`` says whether any
        models have been trained yet.
    """
    models = get_trained_models()
    # An empty collection gets a distinct message; the payload shape is the same either way.
    message = "List of trained models." if len(models) else "No models trained yet."
    return {"models": models, "message": message}

@app.post("/create-client")
def create_username(user_input: UserInput):
    """Derive a client_id for ``user_input.client_name`` and report it.

    Args:
        user_input: request body carrying ``client_name``.

    Returns:
        dict with the derived ``client_id`` and a confirmation message.

    Raises:
        HTTPException: 400 when a trained model already exists for the
            derived client_id.
    """
    name = user_input.client_name
    # Gather the ids of already-trained models before deriving the new id.
    known_ids = [model['client_id'] for model in get_trained_models()]
    new_id = get_client_id(name)
    if new_id not in known_ids:
        return {"client_id": new_id, "message": "client created successfully."}
    raise HTTPException(status_code=400, detail=f"Model for {name}, {new_id} already exists.")

@app.post("/train")
def train_model(train_input: TrainInput):
    """Validate a training request and train the client's model.

    Args:
        train_input: request body with ``client_id``, ``data`` rows and
            ``ignore_value``.

    Returns:
        dict with a confirmation message and, under ``"result"``, whatever
        train_client_model returned.

    Raises:
        HTTPException: 400 if ``client_id`` contains a space or any row does
            not hold exactly 4 items.
    """
    cid = train_input.client_id
    rows = train_input.data
    # Client ids may not contain a space character.
    if ' ' in cid:
        raise HTTPException(status_code=400, detail="client_id cannot contain space.")
    # Every training row must carry exactly four items.
    if any(len(r) != 4 for r in rows):
        raise HTTPException(status_code=400, detail="Each row must contain exactly 4 items.")
    outcome = train_client_model(client_id=cid, rows=rows, ignore_value=train_input.ignore_value)
    return {"message": f"Model '{cid}' trained successfully.", "result": outcome}

@app.post("/predict")
def predict(predict_input: PredictInput):
    """Validate a prediction request and return vendor-category predictions.

    Args:
        predict_input: request body with ``client_id`` and ``data`` rows.

    Returns:
        dict with whatever predict_vendor_category returned under
        ``"result"`` and a confirmation message.

    Raises:
        HTTPException: 400 if ``client_id`` contains a space or any row does
            not hold exactly 2 items.
    """
    # client_id must not contain a space (same rule as the /train endpoint).
    if ' ' in predict_input.client_id:
        raise HTTPException(status_code=400, detail="client_id cannot contain space.")
    # Every prediction row must contain exactly 2 items (training rows take 4).
    if any(len(row) != 2 for row in predict_input.data):
        raise HTTPException(status_code=400, detail="Each row must contain exactly 2 items.")
    predictions = predict_vendor_category(client_id=predict_input.client_id,
                                          data=predict_input.data)
    return {"result": predictions,
            "message": "Predictions generated successfully."}