# CHM-Corr / app.py
# taesiri's picture
# fix URLs
# 7019d33
# raw
# history blame
# 6.17 kB
import io
import csv
import sys
import pickle
from collections import Counter
import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from PIL import Image
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_corrmap
# Allow arbitrarily large CSV fields. sys.maxsize overflows the C long used by
# the csv module where long is 32-bit (e.g. Windows), so fall back to the
# largest 32-bit value there instead of crashing at import time.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)


def concat(x):
    """Concatenate a sequence of arrays along the first axis (``axis=0``)."""
    return np.concatenate(x, axis=0)
# Fetch the model assets from Google Drive. `gdown.cached_download` reuses an
# existing file at `path` instead of downloading it again (and checks `md5`
# when one is given).

# Precomputed training-set embeddings
gdown.cached_download(
    url="https://drive.google.com/uc?id=116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    path="./embeddings.pickle",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)
# Training-set labels. Saved to the explicit path that is read back later
# (./labels.pickle) rather than relying on the remote file name, and cached
# like the other assets instead of being re-downloaded on every start-up.
# NOTE(review): no md5 is pinned for this file.
gdown.cached_download(
    url="https://drive.google.com/uc?id=1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e",
    path="./labels.pickle",
    quiet=False,
)
# CUB training set
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)
# Extract the training set into data/ (the archive itself is kept).
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)
# CHM weights
gdown.cached_download(
    url="https://drive.google.com/uc?id=1yM1zA0Ews2I8d9-BGc6Q0hIAl7LzYqr0",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)
# Load the precomputed training-set embeddings and labels, then build the
# similarity-search index that `search` uses by default.
# NOTE(review): pickle.load can execute arbitrary code; these files must come
# from a trusted source (they are the Google Drive downloads made at start-up).
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)
searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()
# Class-index -> display-name lookup for the training set. The names are the
# ImageFolder class (sub-folder) names with "." replaced by spaces.
# `class_to_idx` avoids splitting every image path on "/". That splitting breaks
# with Windows path separators and with images nested below their class
# folder. It also needs only one entry per class rather than one per image.
training_folder = ImageFolder(root="./data/train/")
id_to_bird_name = {
    class_idx: class_name.replace(".", " ")
    for class_name, class_idx in training_folder.class_to_idx.items()
}
def _render_and_crop(fig, margins=(310, 60, 248, 80)):
    """Render ``fig`` to an in-memory JPEG and trim fixed pixel margins.

    Args:
        fig: Figure object exposing ``savefig`` (as returned by
            ``plot_from_reranker_corrmap``).
        margins: ``(left, top, right, bottom)`` pixels to cut off. The
            defaults are presumably tuned to strip this figure's padding.

    Returns:
        The cropped image as a PIL ``Image``.
    """
    left, top, right, bottom = margins
    with io.BytesIO() as img_buf:
        fig.savefig(img_buf, format="jpg")
        img_buf.seek(0)
        with Image.open(img_buf) as image:
            width, height = image.size
            # crop() loads the pixel data and returns a new image, so the
            # result stays valid after the buffer and source image close.
            return image.crop((left, top, width - right, height - bottom))


def search(query_image, searcher=searcher):
    """Classify a query image by kNN retrieval followed by CHM-Corr re-ranking.

    Args:
        query_image: The query image as delivered by the Gradio input (a file
            path, since the input component uses ``type="filepath"``).
        searcher: Index exposing ``search(embedding, k)`` that returns
            ``(scores, indices, labels)``. Defaults to the module-level index
            (bound when this function is defined).

    Returns:
        ``(viz_image, label_scores)``:
        ``viz_image`` is the cropped CHM-Corr visualization as a PIL image.
        ``label_scores`` maps each bird name to the fraction of the top-20
        CHM-Corr neighbours that carry it.
    """
    n_support = 50  # kNN shortlist size handed to CHM-Corr (N in the UI)
    k_vote = 20  # neighbours that vote on the label (k in the UI)

    query_embedding = QueryToEmbedding(query_image)
    _, indices, labels = searcher.search(query_embedding, k=n_support)

    # kNN prediction: majority vote over the k nearest neighbours.
    top_labels = labels[0][:k_vote]
    top_indices = indices[0][:k_vote]
    result_ctr = Counter(top_labels).most_common(5)
    top1_label, top1_votes = result_ctr[0]

    # Up to five training images of the winning class, in retrieval order.
    winner_indices = [
        idx for lbl, idx in zip(top_labels, top_indices) if lbl == top1_label
    ]
    gallery_images = [
        training_folder.imgs[int(idx)][0] for idx in winner_indices[:5]
    ]
    kNN_results = (top1_label, top1_votes, gallery_images)

    # The whole N-neighbour shortlist (paths and class ids) is the support set.
    support_files = [training_folder.imgs[int(idx)][0] for idx in indices[0]]
    support_labels = [training_folder.imgs[int(idx)][1] for idx in indices[0]]
    support = [support_files, support_labels]

    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )
    fig, _ = plot_from_reranker_corrmap(chm_output)
    # NOTE(review): `fig` is never closed. If it is a pyplot-managed
    # matplotlib figure, each request leaks one; consider plt.close(fig).
    viz_image = _render_and_crop(fig)

    # Final prediction: label distribution over the top-k CHM-Corr
    # neighbours. The name is the second-to-last "/" component of each path.
    chm_output_labels = Counter(
        path.split("/")[-2].replace(".", " ").replace("_", " ")
        for path in chm_output["chm-nearest-neighbors-all"][:k_vote]
    )
    return viz_image, {
        name: votes / float(k_vote) for name, votes in chm_output_labels.items()
    }
# Summary text rendered in the demo's "Description" section.
tldr = """
We propose two architectures of interpretable image classifiers
that first explain, and then predict by harnessing
the visual correspondences between a query image and exemplars.
Our models improve on several out-of-distribution (OOD) ImageNet
datasets while achieving competitive performance on ImageNet
than the black-box baselines (e.g. ImageNet-pretrained ResNet-50).
On a large-scale human study (∼60 users per method per dataset)
on ImageNet and CUB, our correspondence-based explanations led
to human-alone image classification accuracy and human-AI team
accuracy that are consistently better than that of kNN.
We show that it is possible to achieve complementary human-AI
team accuracy (i.e., that is higher than either AI-alone or
human-alone), on ImageNet and CUB.
<div align="center">
<a href="https://github.com/anguyen8/visual-correspondence-XAI">Github Page</a>
</div>
"""

# Top-level Gradio container that the demo layout is built into.
blocks = gr.Blocks()
# Demo layout: description, an input image with the fixed retrieval
# parameters, the CHM-Corr visualization, the label prediction and
# clickable examples. Both the button and the examples run `search`.
with blocks:
    gr.Markdown(""" # CHM-Corr DEMO""")
    gr.Markdown(f""" ## Description: \n {tldr}""")
    with gr.Row():
        # `search` receives the uploaded image as a file path.
        input_image = gr.Image(type="filepath")
        with gr.Column():
            gr.Markdown("### Parameters:")
            gr.Markdown(
                "`N=50`\n `k=20` \nUsing `ImageNet Pretrained ResNet50` features"
            )
    run_btn = gr.Button("Classify")
    gr.Markdown(""" ### CHM-Corr Output Visualization """)
    viz_plot = gr.Image(type="pil", label="Visualization")
    with gr.Row():
        with gr.Column():
            gr.Markdown(""" ### CHM-Corr Prediction """)
            labels = gr.Label(label="Prediction")
        with gr.Column():
            gr.Markdown(""" ### Examples """)
            # Examples are computed on click, not precomputed at start-up.
            examples = gr.Examples(
                examples=[
                    ["./examples/bird.jpg"],
                    ["./examples/Red_Winged_Blackbird_0012_6015.jpg"],
                    ["./examples/Red_Winged_Blackbird_0025_5342.jpg"],
                    ["./examples/sample1.jpeg"],
                    ["./examples/sample2.jpeg"],
                    ["./examples/Yellow_Headed_Blackbird_0020_8549.jpg"],
                    ["./examples/Yellow_Headed_Blackbird_0026_8545.jpg"],
                ],
                inputs=[input_image],
                outputs=[viz_plot, labels],
                fn=search,
                cache_examples=False,
            )
    run_btn.click(
        search,
        inputs=[input_image],
        outputs=[viz_plot, labels],
    )
if __name__ == "__main__":
    # launch(enable_queue=...) is deprecated (and removed in Gradio 4).
    # Blocks.queue() enables the request queue and returns the Blocks.
    blocks.queue().launch(debug=True)