Model Card for resnet_mnist_digits
This model is a Residual Neural Network (ResNet) for classifying handwritten digits from the MNIST dataset. It has 1.35 M parameters and achieves 99.04% accuracy on the MNIST test dataset (i.e., on digits not seen during training).
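To put the headline numbers in concrete terms, the short calculation below converts them into counts. It is only an illustrative sketch: it relies on the standard MNIST test set size of 10,000 images, and the weight-memory figure assumes 32-bit float parameters, which the card does not state.

```python
# Convert the reported accuracy and parameter count into concrete quantities.
n_test = 10_000          # standard MNIST test set size
accuracy = 0.9904        # reported test accuracy
n_params = 1_350_000     # reported parameter count (approximate)

n_correct = round(n_test * accuracy)
n_wrong = n_test - n_correct

# ASSUMPTION: weights stored as float32 (4 bytes each); not stated in the card.
weight_mb = n_params * 4 / 1e6

print(f"Correct: {n_correct}, misclassified: {n_wrong}")
print(f"Approximate weight memory: {weight_mb:.1f} MB")
```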
Model Details
Model Description
This model takes as input MNIST digits resized to 224x224, with values normalized to [0, 1]. It is intended for comparison with 224x224 vision transformers. The model was trained using Keras on an Nvidia Ampere A100.
- Developed by: Phillip Allen Lane
- Model type: ResNet
- License: afl-3.0
How to Get Started with the Model
Use the code below to get started with the model.
import numpy as np
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import get_file, to_categorical
# load the MNIST dataset test images and labels
(_, _), (test_images, test_labels) = mnist.load_data()
# normalize the pixel values to [0, 1]
test_images = test_images.astype('float32') / 255
# add a channel axis, repeat it to 3 channels, and resize to 224x224
test_images = np.expand_dims(test_images, axis=-1)
test_images = np.repeat(test_images, 3, axis=-1)
test_images = tf.image.resize(test_images, [224, 224]).numpy()
# create one-hot labels
test_labels_onehot = to_categorical(test_labels)
# download the model
model_path = get_file('/path/to/large_resnet_mnist.hdf5', 'https://huggingface.co/lane99/resnet_mnist_digits_highres/resolve/main/large-resnet-mnist.hdf5')
# import the model
resnet = models.load_model(model_path)
# evaluate the model on a single channel, since it takes 224x224 inputs
evaluation_conv = resnet.evaluate(test_images[..., 0], test_labels_onehot)
print("Accuracy:", evaluation_conv[1])
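Resizing the whole test set to 224x224 at once uses a noticeable amount of memory. The sketch below estimates the size of the resized arrays, assuming 10,000 test images stored as float32 (4 bytes per value), as in the code above. The helper `array_bytes` is a hypothetical name introduced only for this illustration.

```python
# Rough memory estimate for the resized MNIST test set (float32 = 4 bytes).
BYTES_PER_FLOAT32 = 4

def array_bytes(n_images, height, width, channels=1):
    """Bytes needed for a dense float32 array of the given shape."""
    return n_images * height * width * channels * BYTES_PER_FLOAT32

n_test = 10_000  # standard MNIST test set size
three_channel = array_bytes(n_test, 224, 224, channels=3)
one_channel = array_bytes(n_test, 224, 224, channels=1)

print(f"3-channel resized array: {three_channel / 1e9:.2f} GB")
print(f"1-channel slice:         {one_channel / 1e9:.2f} GB")
```

If memory is tight, evaluating on a subset of the test images or resizing them in smaller batches may be a workable alternative.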
Training Details
Training Data
This model was trained on the 60,000 entries in the MNIST training dataset.
Training Procedure
This model was trained with a 0.1 validation split for 10 epochs using a batch size of 128.
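The training setup above implies the following sample and step counts. This is a sketch that assumes the split was made in the style of the Keras `validation_split` argument (a fraction of the training data held out for validation) and that the final partial batch of each epoch is kept. The card does not confirm either detail.

```python
import math

n_total = 60_000   # MNIST training set size
val_split = 0.1    # fraction held out for validation
batch_size = 128
epochs = 10

n_val = round(n_total * val_split)    # samples used for validation
n_train = n_total - n_val             # samples used for weight updates

# ASSUMPTION: the last, smaller batch of each epoch is not dropped.
steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * epochs

print(f"Training samples:   {n_train}")
print(f"Validation samples: {n_val}")
print(f"Steps per epoch:    {steps_per_epoch}")
print(f"Total steps:        {total_steps}")
```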