okeowo1014's picture
Update main.py
3b7916d verified
raw
history blame
2.51 kB
import io

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
# Number of categories the classifier distinguishes. It sizes the final linear
# layer, so it must agree with the checkpoint whose state_dict is loaded.
num_classes: int = 10
# Class definition for the model (same as in your code)
class FingerprintRecognitionModel(nn.Module):
def __init__(self, num_classes):
super(FingerprintRecognitionModel, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.fc1 = nn.Linear(128 * 28 * 28, 256)
self.fc2 = nn.Linear(256, num_classes)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.pool(F.relu(self.conv3(x)))
x = x.view(-1, 128 * 28 * 28)
x = F.relu(self.fc1(x))
x = F.softmax(self.fc2(x), dim=1)
return x
app = FastAPI()
# Load the model
model_path = 'fingerprint_recognition_model.pt'
model = FingerprintRecognitionModel(num_classes)
model.load_state_dict(torch.load(model_path))
model.eval()
def preprocess_image(image_bytes):
# Convert bytes to PIL Image
image = Image.open(io.BytesIO(image_bytes)).convert('L') # Convert to grayscale
# Resize to 224x224
img_resized = image.resize((224, 224))
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# Apply transforms and add batch dimension
img_tensor = transform(img_resized).unsqueeze(0)
return img_tensor
def predict_class(image_bytes):
img_tensor = preprocess_image(image_bytes)
with torch.no_grad():
outputs = model(img_tensor)
_, predicted = torch.max(outputs.data, 1)
predicted_class = int(predicted.item())
return predicted_class
@app.post("/predict/")
async def predict_endpoint(file: UploadFile = File(...)):
contents = await file.read()
predicted_class = predict_class(contents)
class_labels = {0: 'Left_ring_fingers', 1: 'Left_thumb_fingers', 2: 'Right_index_fingers', 3: 'Right_little_fingers', 4: 'Right_middle_fingers', 5: 'Right_ring_fingers', 6: 'Right_thumb_fingers', 7: 'left_index_fingers', 8: 'left_little_fingers', 9: 'left_middle_fingers'}
return {"predicted_class": predicted_class, "class_label": class_labels[predicted_class]}