mayhug commited on
Commit
9b783d7
·
1 Parent(s): 639d058

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -13
app.py CHANGED
@@ -1,11 +1,13 @@
1
  import json
 
2
 
3
- import gradio as gr
4
  import tensorflow as tf
5
  import tensorflow.keras as keras
6
- from gradio import inputs, outputs
 
 
 
7
 
8
- SIZE = 256
9
  DEVICE = "/cpu:0"
10
 
11
 
@@ -14,40 +16,83 @@ with open("./tags.json", "rt", encoding="utf-8") as f:
14
 
15
 
16
  with tf.device(DEVICE):
17
- base_model = keras.applications.resnet.ResNet50(
18
- include_top=False, weights=None, input_shape=(SIZE, SIZE, 3)
 
 
 
 
 
 
 
 
 
 
19
  )
20
- model = keras.Sequential(
 
 
 
21
  [
22
- base_model,
 
 
 
 
23
  keras.layers.Conv2D(filters=len(tags), kernel_size=(1, 1), padding="same"),
24
  keras.layers.BatchNormalization(epsilon=1.001e-5),
25
  keras.layers.GlobalAveragePooling2D(name="avg_pool"),
26
  keras.layers.Activation("sigmoid"),
27
  ]
28
  )
29
- model.load_weights("tf_model.h5")
30
 
31
 
32
- def predict(img, hide: float):
33
  with tf.device(DEVICE):
34
- img = tf.image.resize_with_pad(img, SIZE, SIZE)
35
  img = tf.image.per_image_standardization(img)
36
  data = tf.expand_dims(img, 0)
37
- out, *_ = model(data)
38
- return {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  tag: confidence
40
  for i, tag in enumerate(tags)
41
  if (confidence := float(out[i].numpy())) >= hide
42
  }
 
 
43
 
44
 
45
  image = inputs.Image(label="Upload your image here!")
46
  hide_threshold = inputs.Slider(
47
  label="Hide confidence lower than", default=0.5, maximum=1, minimum=0
48
  )
 
 
 
49
 
50
  labels = outputs.Label(label="Tags", type="confidences")
51
 
52
- interface = gr.Interface(predict, inputs=[image, hide_threshold], outputs=[labels])
 
 
53
  interface.launch()
 
1
  import json
2
+ from pprint import pprint
3
 
 
4
  import tensorflow as tf
5
  import tensorflow.keras as keras
6
+ from gradio import Interface, inputs, outputs
7
+
8
+ RESNET50_SIZE = 256
9
+ RESNET101_SIZE = 360
10
 
 
11
  DEVICE = "/cpu:0"
12
 
13
 
 
16
 
17
 
18
  with tf.device(DEVICE):
19
+ model_resnet50 = keras.Sequential(
20
+ [
21
+ keras.applications.resnet.ResNet50(
22
+ include_top=False,
23
+ weights=None,
24
+ input_shape=(RESNET50_SIZE, RESNET50_SIZE, 3),
25
+ ),
26
+ keras.layers.Conv2D(filters=len(tags), kernel_size=(1, 1), padding="same"),
27
+ keras.layers.BatchNormalization(epsilon=1.001e-5),
28
+ keras.layers.GlobalAveragePooling2D(name="avg_pool"),
29
+ keras.layers.Activation("sigmoid"),
30
+ ]
31
  )
32
+ model_resnet50.load_weights("./tf_model_resnet50.h5")
33
+
34
+ with tf.device(DEVICE):
35
+ model_resnet101 = keras.Sequential(
36
  [
37
+ keras.applications.resnet.ResNet101(
38
+ include_top=False,
39
+ weights=None,
40
+ input_shape=(RESNET101_SIZE, RESNET101_SIZE, 3),
41
+ ),
42
  keras.layers.Conv2D(filters=len(tags), kernel_size=(1, 1), padding="same"),
43
  keras.layers.BatchNormalization(epsilon=1.001e-5),
44
  keras.layers.GlobalAveragePooling2D(name="avg_pool"),
45
  keras.layers.Activation("sigmoid"),
46
  ]
47
  )
48
+ model_resnet101.load_weights("./tf_model_resnet101.h5")
49
 
50
 
51
+ def predict_resnet50(img):
52
  with tf.device(DEVICE):
53
+ img = tf.image.resize_with_pad(img, RESNET50_SIZE, RESNET50_SIZE)
54
  img = tf.image.per_image_standardization(img)
55
  data = tf.expand_dims(img, 0)
56
+ out, *_ = model_resnet50(data)
57
+ return out
58
+
59
+
60
+ def predict_resnet101(img):
61
+ with tf.device(DEVICE):
62
+ img = tf.image.resize_with_pad(img, RESNET101_SIZE, RESNET101_SIZE)
63
+ img = tf.image.per_image_standardization(img)
64
+ data = tf.expand_dims(img, 0)
65
+ out, *_ = model_resnet101(data)
66
+ return out
67
+
68
+
69
+ def main(img, hide: float, model: str):
70
+ if model.endswith("50"):
71
+ out = predict_resnet50(img)
72
+ elif model.endswith("101"):
73
+ out = predict_resnet101(img)
74
+ else:
75
+ raise ValueError(f"Invalid model type: {model!r}")
76
+ result = {
77
  tag: confidence
78
  for i, tag in enumerate(tags)
79
  if (confidence := float(out[i].numpy())) >= hide
80
  }
81
+ pprint(result)
82
+ return result
83
 
84
 
85
  image = inputs.Image(label="Upload your image here!")
86
  hide_threshold = inputs.Slider(
87
  label="Hide confidence lower than", default=0.5, maximum=1, minimum=0
88
  )
89
+ select_model = inputs.Radio(
90
+ ["ResNet50", "ResNet101"], label="Select model", type="value"
91
+ )
92
 
93
  labels = outputs.Label(label="Tags", type="confidences")
94
 
95
+ interface = Interface(
96
+ main, inputs=[image, hide_threshold, select_model], outputs=[labels]
97
+ )
98
  interface.launch()