# Spaces status: Runtime error
import json | |
from pprint import pprint | |
import tensorflow as tf | |
import tensorflow.keras as keras | |
from gradio import Interface, inputs, outputs | |
# Square input resolutions, in pixels, fed to each backbone.
RESNET101_SIZE = 360
RESNET50_SIZE = 256
# TensorFlow device string used when building and running the models.
DEVICE = "/cpu:0"

# Tag vocabulary. Its length sets the number of model output channels; its
# order presumably matches those channels as trained — TODO confirm.
with open("./tags.json", encoding="utf-8") as f:
    tags = json.loads(f.read())
def _build_tagger(backbone, size, weights_path):
    """Build a ResNet-backbone multi-label tagger and load its weights.

    The head maps backbone features to one channel per entry in ``tags``
    with a 1x1 convolution, batch-normalizes, global-average-pools, and
    applies a sigmoid so each tag gets its own score in (0, 1).

    Args:
        backbone: A ``keras.applications`` ResNet constructor (e.g.
            ``keras.applications.resnet.ResNet50``); it is built without its
            classification top and without pretrained weights.
        size: Square input resolution in pixels.
        weights_path: Path to the HDF5 weight file for the whole model.

    Returns:
        The ``keras.Sequential`` model with weights loaded.
    """
    with tf.device(DEVICE):
        model = keras.Sequential(
            [
                backbone(
                    include_top=False,
                    weights=None,
                    input_shape=(size, size, 3),
                ),
                keras.layers.Conv2D(
                    filters=len(tags), kernel_size=(1, 1), padding="same"
                ),
                keras.layers.BatchNormalization(epsilon=1.001e-5),
                keras.layers.GlobalAveragePooling2D(name="avg_pool"),
                keras.layers.Activation("sigmoid"),
            ]
        )
    # Keras loads HDF5 weights by layer topology by default, so the layer
    # list above must stay in the same order the .h5 files were saved from.
    model.load_weights(weights_path)
    return model


model_resnet50 = _build_tagger(
    keras.applications.resnet.ResNet50,
    RESNET50_SIZE,
    "./tf_model_resnet50.h5",
)
model_resnet101 = _build_tagger(
    keras.applications.resnet.ResNet101,
    RESNET101_SIZE,
    "./tf_model_resnet101.h5",
)
def _predict(model, size, img):
    """Run ``model`` on one image and return its per-tag scores.

    The image is resized to ``size`` x ``size`` with
    ``tf.image.resize_with_pad`` (aspect ratio kept, remainder padded),
    standardized per image, and batched as a single example.

    Args:
        model: A tagger built for ``size`` x ``size`` x 3 inputs.
        size: Square input resolution expected by ``model``.
        img: Image array as supplied by the Gradio ``Image`` input
            (presumably H x W x 3 — TODO confirm).

    Returns:
        1-D tensor of sigmoid scores, one per output channel of ``model``.
    """
    with tf.device(DEVICE):
        img = tf.image.resize_with_pad(img, size, size)
        img = tf.image.per_image_standardization(img)
        data = tf.expand_dims(img, 0)
        # Explicit inference mode so BatchNormalization uses its moving
        # statistics; [0] drops the size-1 batch dimension.
        return model(data, training=False)[0]


def predict_resnet50(img):
    """Return per-tag scores for ``img`` from the ResNet50 tagger."""
    return _predict(model_resnet50, RESNET50_SIZE, img)


def predict_resnet101(img):
    """Return per-tag scores for ``img`` from the ResNet101 tagger."""
    return _predict(model_resnet101, RESNET101_SIZE, img)
def main(img, hide: float, model: str):
    """Tag ``img`` with the selected model and keep the confident tags.

    Args:
        img: Image from the Gradio ``Image`` input; ``None`` when nothing
            was uploaded.
        hide: Minimum confidence; tags scoring below it are dropped.
        model: Model selector. A string ending in ``"50"`` picks ResNet50;
            one ending in ``"101"`` picks ResNet101.

    Returns:
        Dict mapping each tag whose confidence is ``>= hide`` to that
        confidence as a Python ``float``, in ``tags`` order. The dict is
        also pretty-printed to stdout.

    Raises:
        ValueError: If ``model`` is not a recognized selector (including
            ``None``, i.e. nothing selected) or no image was provided.
    """
    # isinstance guard: a missing selection (None) should hit the
    # ValueError below, not an AttributeError on `.endswith`.
    if isinstance(model, str) and model.endswith("50"):
        predict = predict_resnet50
    elif isinstance(model, str) and model.endswith("101"):
        predict = predict_resnet101
    else:
        raise ValueError(f"Invalid model type: {model!r}")
    if img is None:
        raise ValueError("No image provided")

    # Copy all scores to host memory in one call. Indexing the tensor once
    # per tag would dispatch a separate eager op for every tag.
    scores = predict(img).numpy()
    # NOTE(review): if no tag reaches `hide` this returns {}; confirm the
    # Label output component accepts an empty dict.
    result = {
        tag: confidence
        for tag, score in zip(tags, scores)
        if (confidence := float(score)) >= hide
    }
    pprint(result)
    return result
# Gradio UI: image upload, confidence-threshold slider and model selector,
# passed to `main` in that order; results shown as a confidence label.
# NOTE(review): `gradio.inputs` / `gradio.outputs` are the legacy Gradio
# 2.x/3.x API (deprecated in 3, removed in 4). If the Space fails at import
# time, pin an older gradio or migrate to gr.Image / gr.Slider / gr.Radio /
# gr.Label.
image = inputs.Image(label="Upload your image here!")
hide_threshold = inputs.Slider(
    label="Hide confidence lower than", default=0.5, maximum=1, minimum=0
)
# Preselect a model so `main` never receives an empty (None) selection.
select_model = inputs.Radio(
    ["ResNet50", "ResNet101"],
    label="Select model",
    type="value",
    default="ResNet50",
)
labels = outputs.Label(label="Tags", type="confidences")
interface = Interface(
    main, inputs=[image, hide_threshold, select_model], outputs=[labels]
)
interface.launch()