# Spaces: Sleeping  (status banner text captured along with this source)
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
import torchvision
from PIL import Image
from torch import nn
# Number of output channels (segmentation classes) every model is built with.
num_classes = 2

# Checkpoint files holding the trained state dicts, one per architecture.
model_unet_path, model_fpn_path, model_deeplab_path = (
    "unet_model.pth",
    "fpn_model.pth",
    "deeplabv3_model.pth",
)

# Sample input image file.
image_path = "leaf11.jpg"

# Select the compute backend for inference: CUDA first, then Apple MPS,
# falling back to CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
# The three segmentation architectures. All of them take 3-channel (RGB)
# input and emit `num_classes` output channels. No pretrained encoder weights
# are requested (encoder_weights=None), so parameters must come from a
# checkpoint loaded later on.
_shared_kwargs = {"encoder_weights": None, "in_channels": 3, "classes": num_classes}
model_unet = smp.Unet(encoder_name="resnet18", **_shared_kwargs)
model_fpn = smp.FPN(encoder_name="resnet18", **_shared_kwargs)
model_deeplab = smp.DeepLabV3(encoder_name="resnet34", **_shared_kwargs)
del _shared_kwargs  # temporary helper; keep the module namespace unchanged
def pred_one_image(inp, option):
    """Segment a single image with the chosen model and return the class mask.

    Args:
        inp: PIL image from the ``gr.Image(type="pil")`` input.
        option: Model key, one of ``"unet"``, ``"fpn"`` or ``"deeplab"``.

    Returns:
        A grayscale (mode ``"L"``) PIL image of the predicted class map at
        the model's output resolution (the input is resized to 256x256).
        Class ids are spread over 0-255, so class 0 is black and the highest
        class is white.

    Raises:
        ValueError: if ``inp`` is None or ``option`` is not a known model key.
    """
    if inp is None:
        raise ValueError("No input image was provided.")

    models = {"unet": model_unet, "fpn": model_fpn, "deeplab": model_deeplab}
    model_load = models.get(option)
    if model_load is None:
        # An empty or unrecognised choice must not fall through to an
        # unbound model variable.
        raise ValueError(
            f"Unknown model option {option!r}; expected one of {sorted(models)}."
        )

    # Convert to RGB *before* resizing: PIL always resamples palette ("P")
    # images with nearest-neighbour, so converting first gives a smoother resize.
    one_image = np.array(inp.convert("RGB").resize((256, 256)))
    # HWC -> CHW, then add a batch axis. Pixel values stay in 0-255; no
    # normalisation is applied here.
    one_image = np.moveaxis(one_image, -1, 0)
    one_image = torch.tensor(one_image, dtype=torch.float32).unsqueeze(0)
    # Send the input to wherever the model's weights live so the two can
    # never end up on different devices (e.g. CUDA input vs. CPU model).
    model_device = next(model_load.parameters()).device
    one_image = one_image.to(model_device)

    model_load.eval()
    with torch.no_grad():
        output = model_load(one_image)

    # Dim 1 holds one score map per class; argmax gives the predicted class
    # id for every pixel.
    predictions = torch.argmax(output, dim=1)
    # Scale ids by 255 / (n_classes - 1) so the top class maps to 255. A fixed
    # "/2" would render class 1 of a 2-class model as mid-grey (127).
    n_classes = output.shape[1]
    scale = 255 / max(n_classes - 1, 1)
    pred_array = (predictions[0].cpu().numpy() * scale).astype(np.uint8)
    return Image.fromarray(pred_array)
# Load the trained state dicts. Checkpoints are mapped to CPU so they load on
# any host regardless of the device they were saved from.
model_unet.load_state_dict(torch.load(model_unet_path, map_location=torch.device("cpu")))
model_fpn.load_state_dict(torch.load(model_fpn_path, map_location=torch.device("cpu")))
model_deeplab.load_state_dict(torch.load(model_deeplab_path, map_location=torch.device("cpu")))

# Move each model to the selected `device` so inference actually runs there
# instead of always on CPU, and switch it to inference mode.
model_unet.to(device).eval()
model_fpn.to(device).eval()
model_deeplab.to(device).eval()
# Model selector. It defaults to "unet" so that a submission without an
# explicit choice still selects a model instead of sending an empty value.
dropdown = gr.Dropdown(["unet", "fpn", "deeplab"], value="unet", label="Model")

# Image in, predicted mask out. The bundled sample image is offered as a
# clickable example paired with the U-Net model.
interface = gr.Interface(
    fn=pred_one_image,
    inputs=[gr.Image(type="pil"), dropdown],
    outputs=gr.Image(type="pil"),
    examples=[[image_path, "unet"]],
)
interface.launch(debug=False)