Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
import shutil | |
import os | |
import torch | |
import TwinLite as net | |
from PIL import Image | |
from dotenv import load_dotenv | |
load_dotenv()
# os.getenv() returns a *string* whenever the variable is set, so using the raw
# value as a flag would treat SHARE=false / SHARE=0 as enabled.  Parse it
# explicitly instead: unset, empty, or a conventional "off" spelling disables
# sharing; any other value enables it (as before).
share = os.getenv("SHARE", "").strip().lower() not in {"", "0", "false", "no", "off"}
# NOTE: `model` is built next to the weight-loading code further down in this
# file; the extra throwaway instance that used to be created here was rebound
# before any use, so it has been dropped.
import cv2 | |
def Run(model, img):
    """Run ``model`` on one image and paint its two predicted masks on a copy.

    Args:
        model: Callable invoked on a float tensor of shape (1, 3, 360, 640);
            its result is indexed at [0] and [1].  Presumably these are the
            drivable-area and lane-line outputs (going by the original local
            names) -- confirm against the TwinLite module.
        img: H x W x 3 array accepted by ``cv2.resize``.  The channel axis is
            reversed before inference, i.e. BGR -> RGB when the array comes
            from ``cv2.imread``.

    Returns:
        The image resized to 640x360 (width x height), in the input channel
        order, with pixels of the first mask set to [255, 0, 0] and pixels of
        the second mask set to [0, 255, 0].  The second mask is painted last
        and therefore wins where both overlap.
    """
    resized = cv2.resize(img, (640, 360))
    overlay = resized.copy()

    # Reverse the channel axis and go from HWC to CHW; the negative-stride
    # view must be made contiguous before torch.from_numpy can wrap it.
    chw = np.ascontiguousarray(resized[:, :, ::-1].transpose(2, 0, 1))
    batch = torch.from_numpy(chw)
    batch = batch.unsqueeze(0)  # add a batch dimension
    batch = batch.float() / 255.0

    with torch.no_grad():
        outputs = model(batch)

    # Per-pixel class index along dim 1 for each of the two heads, scaled so
    # that class 1 becomes 255 in the uint8 mask.
    masks = []
    for head in (outputs[0], outputs[1]):
        _, labels = torch.max(head, 1)
        masks.append(labels.byte().cpu().data.numpy()[0] * 255)
    first_mask, second_mask = masks

    overlay[first_mask > 100] = [255, 0, 0]
    overlay[second_mask > 100] = [0, 255, 0]
    return overlay
# Build the network wrapped in DataParallel and load the fine-tuned weights
# onto the CPU.  NOTE(review): the DataParallel wrapper is presumably needed
# because the checkpoint was saved from a wrapped model ("module."-prefixed
# keys) -- confirm before removing it.
model = torch.nn.DataParallel(net.TwinLiteNet())
model.load_state_dict(
    torch.load('fine-tuned-model.pth', map_location=torch.device('cpu'))
)
# Switch to inference mode.
model.eval()
def predict(image):
    """Segment one uploaded image and return the overlay as a PIL image.

    The previous implementation round-tripped through the fixed files
    ``input.png`` and ``sample.png`` in the working directory.  That let
    concurrent requests overwrite each other's data and required a writable
    working directory.  The same RGB <-> BGR conversion is now done in memory;
    PNG is lossless, so the pixel values handed to ``Run`` and returned to the
    caller are unchanged.

    Args:
        image: H x W x 3 RGB array (what the "image" input of the interface
            defined in this file is expected to deliver).

    Returns:
        A ``PIL.Image.Image`` in RGB mode holding the 640x360 overlay
        produced by ``Run``.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        raise ValueError("No input image provided")

    rgb = np.asarray(image).astype('uint8')
    # Run() expects OpenCV channel order (it used to receive cv2.imread
    # output), so flip RGB -> BGR.  Copy to a contiguous buffer because the
    # reversed view has a negative stride.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])

    overlay_bgr = Run(model, bgr)

    # Flip back to RGB for PIL, mirroring the old imwrite -> Image.open path.
    overlay_rgb = np.ascontiguousarray(overlay_bgr[:, :, ::-1])
    return Image.fromarray(overlay_rgb)
# Web UI: one image in, one image out, served by predict().
iface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="image",
    title="Image Segmentation",
)

if __name__ == "__main__":
    # Listen on all interfaces only when sharing is enabled; otherwise stay
    # on the loopback address.
    iface.launch(server_name="0.0.0.0" if share else "127.0.0.1")