# Author: mayhug
# Commit: 9b783d7 — "Update app.py"
import json
from pprint import pprint
import tensorflow as tf
import tensorflow.keras as keras
from gradio import Interface, inputs, outputs
# Side length (pixels) of the square RGB input each backbone is built for.
RESNET50_SIZE, RESNET101_SIZE = 256, 360

# Device string used for both model construction and inference (CPU only).
DEVICE = "/cpu:0"
# Tag vocabulary: its length fixes the number of output channels of both
# models, and its order maps score positions to tag names in `main`.
with open("./tags.json", encoding="utf-8") as tags_file:
    tags = json.load(tags_file)
def _build_tagger(backbone, size):
    """Build an (untrained) multi-label tagger on top of *backbone*.

    The head is a 1x1 convolution with one output channel per entry in
    ``tags``, followed by batch normalisation, global average pooling over
    the spatial axes and a sigmoid, so every output is an independent score
    in (0, 1). All layers are created under ``DEVICE``.

    Args:
        backbone: a ``keras.applications`` model constructor (e.g.
            ``ResNet50``) accepting ``include_top``, ``weights`` and
            ``input_shape``.
        size: side length of the square RGB input image.

    Returns:
        A ``keras.Sequential`` model; the caller is expected to load weights.
    """
    with tf.device(DEVICE):
        return keras.Sequential(
            [
                # weights=None: no pretrained ImageNet weights; the trained
                # weights are loaded from an .h5 file after construction.
                backbone(
                    include_top=False,
                    weights=None,
                    input_shape=(size, size, 3),
                ),
                keras.layers.Conv2D(
                    filters=len(tags), kernel_size=(1, 1), padding="same"
                ),
                keras.layers.BatchNormalization(epsilon=1.001e-5),
                keras.layers.GlobalAveragePooling2D(name="avg_pool"),
                keras.layers.Activation("sigmoid"),
            ]
        )


# Built and loaded in the same order as before (ResNet50 first), so the
# auto-generated layer names and the h5 topology matching are unchanged.
model_resnet50 = _build_tagger(keras.applications.resnet.ResNet50, RESNET50_SIZE)
model_resnet50.load_weights("./tf_model_resnet50.h5")

model_resnet101 = _build_tagger(keras.applications.resnet.ResNet101, RESNET101_SIZE)
model_resnet101.load_weights("./tf_model_resnet101.h5")
def _predict(model, img, size):
    """Run *model* on a single image and return its per-tag score tensor.

    The image is resized with padding to ``size`` x ``size``, standardised
    per image, and given a leading batch axis of 1 before inference.
    NOTE(review): assumes *img* is an (H, W, C) image array as delivered by
    the Gradio ``Image`` input — confirm.
    """
    with tf.device(DEVICE):
        img = tf.image.resize_with_pad(img, size, size)
        img = tf.image.per_image_standardization(img)
        batch = tf.expand_dims(img, 0)
        # training=False makes the inference mode of BatchNormalization
        # explicit; [0] drops the batch axis of size 1.
        return model(batch, training=False)[0]


def predict_resnet50(img):
    """Return the ResNet50 tagger's score tensor (one entry per tag) for *img*."""
    return _predict(model_resnet50, img, RESNET50_SIZE)


def predict_resnet101(img):
    """Return the ResNet101 tagger's score tensor (one entry per tag) for *img*."""
    return _predict(model_resnet101, img, RESNET101_SIZE)
def main(img, hide: float, model: str):
    """Tag *img* with the selected model and return the confident tags.

    Args:
        img: image passed in by the Gradio ``Image`` input.
        hide: threshold; tags whose score is below it are omitted.
        model: model name; a value ending in "50" selects ResNet50 and one
            ending in "101" selects ResNet101.

    Returns:
        dict mapping each tag to its confidence (Python float) for every tag
        whose score is >= ``hide``.

    Raises:
        ValueError: if ``model`` is empty, None, or names no known model.
    """
    # NOTE(review): presumably the Radio input can deliver None when no
    # option is chosen; map that to the same ValueError as any unknown
    # choice instead of failing with AttributeError on None.endswith.
    name = model or ""
    if name.endswith("50"):
        out = predict_resnet50(img)
    elif name.endswith("101"):
        out = predict_resnet101(img)
    else:
        raise ValueError(f"Invalid model type: {model!r}")
    # Convert the whole score vector to Python floats in one call instead of
    # materialising one tensor element per tag.
    scores = out.numpy().tolist()
    result = {
        tag: confidence
        for tag, confidence in zip(tags, scores)
        if confidence >= hide
    }
    pprint(result)
    return result
# --- Gradio UI ---------------------------------------------------------------
# The three inputs are passed positionally to main(img, hide, model).
image = inputs.Image(label="Upload your image here!")
hide_threshold = inputs.Slider(
    minimum=0,
    maximum=1,
    default=0.5,
    label="Hide confidence lower than",
)
select_model = inputs.Radio(
    ["ResNet50", "ResNet101"],
    type="value",
    label="Select model",
)
# Displays the {tag: confidence} dict returned by main().
labels = outputs.Label(type="confidences", label="Tags")

interface = Interface(
    main,
    inputs=[image, hide_threshold, select_model],
    outputs=[labels],
)
# Starts the web app as soon as this module is executed (no __main__ guard).
interface.launch()