# Source snapshot: commit d2d0403 ("hotfix") by bienom.
import gradio as gr
from model import SixDRepNet
import os
import numpy as np
import torch
from torchvision import transforms
import utils
import time
# Preprocessing pipeline applied to every input image before inference:
# resize so the shorter side is 224 px, keep the central 224x224 crop,
# convert to a float tensor, then normalise each channel with the standard
# ImageNet mean/std values.
transformations = transforms.Compose(
    [
        transforms.Resize(size=224),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Build the head-pose network with a RepVGG-A0 backbone. deploy=True and
# pretrained=False are passed through to SixDRepNet; all weights are then
# loaded from the local checkpoint below.
# NOTE(review): deploy=True presumably selects the re-parameterised
# inference form of RepVGG — confirm against model.py.
model = SixDRepNet(backbone_name='RepVGG-A0',
                   backbone_file='',
                   deploy=True,
                   pretrained=False)

# The checkpoint path is relative to the current working directory. The
# previous single-argument os.path.join(...) call was a no-op.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (consider weights_only=True on torch versions that support it).
saved_state_dict = torch.load("weights_ALFW_A0.pth", map_location='cpu')

# Accept either a training checkpoint that wraps the weights under
# 'model_state_dict' or a bare state dict.
model.load_state_dict(saved_state_dict.get('model_state_dict', saved_state_dict))

# Inference only: eval() makes BatchNorm use its running mean/var.
model.eval()

# Pose threshold in degrees. predict() compares pitch/yaw (converted to
# degrees) against +/- this value to emit UP/DOWN/LEFT/RIGHT labels.
th = 15
def predict(img):
    """Estimate the head pose in a single image and describe it as text.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        A multi-line string with yaw and pitch in degrees, a coarse direction
        label (UP/DOWN combined with LEFT/RIGHT, empty when both angles are
        within +/- ``th`` degrees), and the model forward-pass time in ms.
        If ``img`` is ``None``, a short prompt asking for an image instead.
    """
    # Gradio passes None for an empty image input; without this guard
    # img.convert raises AttributeError and the UI only shows a generic error.
    if img is None:
        return "Please upload an image."

    img = img.convert('RGB')
    img = transformations(img).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        # perf_counter is monotonic and high-resolution. time.time() follows
        # the wall clock, which can jump and gives coarse intervals.
        start = time.perf_counter()
        R_pred = model(img)
        end = time.perf_counter()
    elapsed_ms = (end - start) * 1000

    # Convert the predicted rotation matrices to Euler angles in degrees.
    euler = utils.compute_euler_angles_from_rotation_matrices(
        R_pred, use_gpu=False) * 180 / np.pi
    # Column 0 is read as pitch and column 1 as yaw (column 2 is unused here).
    p_pred_deg = euler[:, 0].cpu().item()
    y_pred_deg = euler[:, 1].cpu().item()

    # Join labels with single spaces. The old "UP " prefix produced a
    # trailing/double space in the output.
    labels = []
    if p_pred_deg > th:
        labels.append("UP")
    elif p_pred_deg < -th:
        labels.append("DOWN")
    if y_pred_deg > th:
        labels.append("LEFT")
    elif y_pred_deg < -th:
        labels.append("RIGHT")
    direction_str = " ".join(labels)

    return f"Yaw: {y_pred_deg:0.1f} \n Pitch: {p_pred_deg:0.1f}\n Direction: {direction_str} \n Time: {elapsed_ms:0.2f}ms"
# Web UI: one PIL image in, one text box out, with bundled sample images
# offered as clickable examples. Bound to a module-level name so the
# interface can be reused or imported without being launched.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    examples=["face_left.jpg", "face_right.jpg", "face_up.jpg", "face_down.jpg"],
)

if __name__ == "__main__":
    # share=True additionally requests a public Gradio share link.
    demo.launch(share=True)