# main api file — author: salek877, commit 8e22f7f (1.43 kB)
import io
import logging
import os

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from starlette.concurrency import run_in_threadpool
from starlette.middleware.cors import CORSMiddleware
app = FastAPI()

# Enable CORS so browser front-ends served from other origins can call the API.
# NOTE(review): "*" is fully open; restrict allow_origins if the API is not
# meant to be publicly callable from any site.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the Coffee Land Classifier model once, at import time, so every request
# reuses the same instance. The location can be overridden with the MODEL_PATH
# environment variable; the default is resolved against the current working
# directory, exactly as before.
model_path = os.environ.get("MODEL_PATH", "model/model.h5")
# predict_class reports output index i as class_labels[i], so this order must
# match the order of the model's output units.
class_labels = ["Coffee Land", "Not Coffee Land"]
# compile=False: the service only runs inference, so the training
# configuration (optimizer, loss, metrics) is not restored.
model = tf.keras.models.load_model(model_path, compile=False)
def preprocess_image(image, target_size=(64, 64), mode="RGB"):
    """Turn a PIL image into a normalised single-image batch for the model.

    The image is first converted to ``mode`` (RGB by default). Without this,
    RGBA, palette ("P") and greyscale uploads give arrays with the wrong
    number of channels or dimensions. The image is then resized to
    ``target_size`` and its 0..255 pixel values are scaled to 0.0..1.0.

    Args:
        image: A PIL ``Image`` (anything with ``mode``, ``convert`` and
            ``resize`` that ``np.array`` can consume).
        target_size: ``(width, height)`` passed to ``resize``; defaults to the
            64x64 input size this API has always used.
        mode: PIL mode to convert to before resizing.

    Returns:
        A ``float32`` array of shape ``(1, height, width, channels)``, which is
        ``(1, 64, 64, 3)`` for the defaults.
    """
    # NOTE(review): assumes the model was trained on 3-channel RGB input — confirm.
    if image.mode != mode:
        image = image.convert(mode)
    img = np.array(image.resize(target_size)).astype("float32") / 255.0
    # Prepend the batch axis Keras expects, even for a single image.
    return np.expand_dims(img, axis=0)
def predict_class(image):
    """Classify a PIL image with the module-level Keras model.

    Args:
        image: A PIL ``Image``; it is prepared with ``preprocess_image``.

    Returns:
        A ``(label, scores)`` tuple. ``label`` is the entry of
        ``class_labels`` at the index of the highest model output, and
        ``scores`` is the model's output row for this image as a plain Python
        list (JSON-serialisable).
    """
    batch = preprocess_image(image)
    # verbose=0: otherwise Model.predict prints a progress bar to stdout on
    # every request.
    scores = model.predict(batch, verbose=0)[0]
    # NOTE(review): argmax assumes one output unit per entry in class_labels
    # (e.g. a 2-unit softmax). A single sigmoid unit would always map to
    # index 0; confirm against the model's output layer.
    predicted_class = class_labels[int(np.argmax(scores))]
    return predicted_class, scores.tolist()
_logger = logging.getLogger(__name__)


@app.post("/predict/")
async def predict(upload_file: UploadFile = File(...)):
    """Classify an uploaded image.

    Expects a multipart/form-data request carrying the image in the
    ``upload_file`` field.

    Returns:
        A JSON object with ``predicted_class`` (the label chosen by
        ``predict_class``) and ``class_probabilities`` (the model's output
        scores for the image, as a list).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    file_contents = await upload_file.read()
    _logger.info("Received file with size: %d bytes", len(file_contents))
    try:
        image = Image.open(io.BytesIO(file_contents))
        # Image.open only reads the header. Force a full decode here so that
        # truncated or corrupt data is rejected as a client error instead of
        # failing later as a 500.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (unrecognised format) is a subclass of OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc
    # Preprocessing and model.predict are blocking CPU work. Run them in a
    # worker thread so the event loop keeps serving other requests.
    predicted_class, class_probabilities = await run_in_threadpool(
        predict_class, image
    )
    return {"predicted_class": predicted_class, "class_probabilities": class_probabilities}