azhongai666666 commited on
Commit
bbf6b72
·
verified ·
1 Parent(s): c684106

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torchvision
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ import pandas as pd
8
+ import segmentation_models_pytorch as smp
9
+ import gradio as gr
10
+
11
+ num_classes = 2
12
+ model_unet_path = "unet_model.pth"
13
+ model_fpn_path = "fpn_model.pth"
14
+ model_deeplab_path = "deeplabv3_model.pth"
15
+ image_path = "leaf11.jpg"
16
+
17
+ # Get cpu or gpu device for training.
18
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
19
+ print(f"Using {device} device")
20
+
21
+ model_unet = smp.Unet(
22
+ encoder_name="resnet18", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
23
+ encoder_weights=None, # use `imagenet` pre-trained weights for encoder initialization
24
+ in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
25
+ classes=num_classes, # model output channels (number of classes in your dataset)
26
+ )
27
+
28
+ model_fpn = smp.FPN(
29
+ encoder_name="resnet18", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
30
+ encoder_weights=None, # use `imagenet` pre-trained weights for encoder initialization
31
+ in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
32
+ classes=num_classes, # model output channels (number of classes in your dataset)
33
+ )
34
+
35
+ model_deeplab = smp.DeepLabV3(
36
+ encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
37
+ encoder_weights=None, # use `imagenet` pre-trained weights for encoder initialization
38
+ in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
39
+ classes=num_classes, # model output channels (number of classes in your dataset)
40
+ )
41
+
42
+ def pred_one_image(inp,option):
43
+ one_image = np.array(inp.resize((256, 256)).convert("RGB"))
44
+ # convert to other format HWC -> CHW
45
+ one_image = np.moveaxis(one_image, -1, 0)
46
+ # mask = np.expand_dims(mask, 0)
47
+ one_image = torch.tensor(one_image).float()
48
+ one_image = one_image.unsqueeze(0)
49
+ one_image = one_image.to(device)
50
+ if option == "unet":
51
+ model_load = model_unet
52
+ elif option == "fpn":
53
+ model_load = model_fpn
54
+ elif option == "deeplab":
55
+ model_load = model_deeplab
56
+ model_load.eval()
57
+ with torch.no_grad():
58
+ output = model_load(one_image)
59
+ # print(output.shape)
60
+ predictions = torch.argmax(output, dim=1) # 获取预测的类别标签图像
61
+ pred_array = (predictions[0].cpu().numpy()/2*255).astype(np.uint8)
62
+ # print(pred_array.shape)
63
+ pred_img = Image.fromarray(pred_array)
64
+ # pred_img.save("pred.png")
65
+ # print(predictions.shape)
66
+ return pred_img
67
+
68
+
69
+
70
+ model_unet.load_state_dict(torch.load(model_unet_path,map_location=torch.device('cpu')))
71
+ model_fpn.load_state_dict(torch.load(model_fpn_path,map_location=torch.device('cpu')))
72
+ model_deeplab.load_state_dict(torch.load(model_deeplab_path,map_location=torch.device('cpu')))
73
+
74
+ dropdown = gr.Dropdown(["unet", "fpn","deeplab"])
75
+ interface = gr.Interface(fn=pred_one_image,
76
+ inputs=[gr.Image(type="pil"),dropdown],
77
+ outputs=gr.Image(type="pil"),
78
+ examples=[["leaf11.jpg",'unet']],)
79
+ interface.launch(debug=False)