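Important: the input image must have three RGB channels. It will not work with
four-channel RGBA images, so convert such images to RGB first
(e.g. `Image.open(path).convert("RGB")`).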
Usage
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
# Run on CUDA when a GPU is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class CNN(nn.Module):
    """Small two-stage convolutional classifier that outputs 2 class logits.

    Architecture: conv(3->16) -> ReLU -> 2x2 max-pool -> conv(16->32) -> ReLU
    -> 2x2 max-pool -> flatten -> linear(hidden_size) -> ReLU -> linear(2).

    Args:
        hidden_size: Width of the hidden fully connected layer.
        input_size: Side length, in pixels, of the square input images. The
            3x3 convolutions use padding=1 and keep the spatial size, while each
            2x2 max-pool halves it. The first linear layer therefore expects
            ``32 * (input_size // 4) ** 2`` features. The default of 768 gives
            the original ``32 * 192 * 192`` layout, so existing checkpoints
            still load.
    """

    def __init__(self, hidden_size=512, input_size=768):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # Two floor-halving poolings == floor division by 4 for each side.
        pooled = input_size // 4
        self.fc1 = nn.Linear(32 * pooled * pooled, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 2)

    def forward(self, x):
        """Return logits of shape (batch, 2) for images x of shape (batch, 3, H, W).

        H and W must both equal ``input_size``; otherwise fc1 raises a
        RuntimeError because the flattened feature count does not match.
        """
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        # Flatten per sample. Unlike view(-1, N), this never reinterprets a
        # wrongly sized image as several "batch" rows and silently returns
        # extra predictions.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Instantiate the classifier, move it to the selected device, then cast its
# floating-point parameters and buffers to float16.
# NOTE(review): float16 kernels on CPU are not supported by every PyTorch
# version/operator -- confirm inference works when no GPU is available.
model = CNN()
model = model.to(device)
model = model.half()

# Training utilities: cross-entropy loss and an Adam optimizer over the model's
# parameters.
# NOTE(review): the inference code in this file does not appear to use
# `criterion`; verify whether it is needed.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=2.5e-5)
# Preprocessing pipeline:
#   1. resize to a fixed 768x768, the resolution CNN's fc1 layer is sized for;
#   2. convert the PIL image to a float tensor in [0, 1];
#   3. normalize each of the three channels with mean 0.5 / std 0.5, mapping
#      values to [-1, 1]. The three-element mean/std is why inputs must have
#      exactly three (RGB) channels.
transform = transforms.Compose(
    [
        transforms.Resize(size=(768, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
def infer(model, image_path):
    """Classify a single image file with ``model``.

    The image is converted to 3-channel RGB, so RGBA, greyscale and palette
    images are accepted. It is then preprocessed with the module-level
    ``transform`` and moved to the same device and dtype as the model's
    parameters.

    Args:
        model: Classifier whose output for a single image has shape
            (1, num_classes).
        image_path: Path to the image file.

    Returns:
        int: Index of the highest-scoring class.
    """
    model.eval()
    # Match the model's placement and precision (e.g. float16 on the GPU)
    # instead of hard-coding the global device and .half().
    param = next(model.parameters())
    # The context manager closes the file handle. convert("RGB") drops an
    # alpha channel and expands single-channel images; the 3-channel
    # Normalize and conv1 cannot handle anything else.
    with Image.open(image_path) as img:
        tensor = transform(img.convert("RGB"))
    batch = tensor.unsqueeze(0).to(device=param.device, dtype=param.dtype)
    with torch.no_grad():
        logits = model(batch)
    # argmax over the class dimension of the single-sample batch.
    return logits.argmax(dim=1).item()
# Restore the trained weights. map_location remaps tensors that were saved on a
# GPU onto the selected device, so the checkpoint also loads on CPU-only
# machines. Without it, torch.load tries to restore them onto their original
# CUDA device.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source (consider weights_only=True on PyTorch
# versions that support it).
checkpoint = torch.load('half_precision_model_checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Optimizer state, epoch and loss are only needed to resume training; they are
# restored here but not used by the inference below.
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Classify one example image and report the result.
image_path = 'good.jpg'
predicted_class = infer(model, image_path)
if int(predicted_class) == 0:
    print('Predicted class: Bad Image')
elif int(predicted_class) == 1:
    print('Predicted class: Good Image')
else:
    # CNN.fc2 has 2 outputs, so this is reached only if the model changes.
    print(f'Predicted class: unknown ({predicted_class})')