import io
import csv
import sys
import pickle
from collections import Counter

import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from PIL import Image

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_corrmap

# Lift the csv module's per-field size cap. Nothing in this file reads CSV
# directly; presumably one of the project helpers does — confirm before removing.
csv.field_size_limit(sys.maxsize)


def concat(x):
    """Concatenate a sequence of arrays along axis 0."""
    return np.concatenate(x, axis=0)


# Retrieval settings (also shown in the UI): N_NEIGHBORS training images are
# retrieved per query, and the first K_VOTE of them vote on the kNN label.
N_NEIGHBORS = 50
K_VOTE = 20
# Maximum number of winning-class neighbours handed to CHM-Corr as the kNN gallery.
N_GALLERY = 5

# Pixel margins cropped off the rendered CHM-Corr figure (left, top, right,
# bottom). Hard-coded; presumably tuned to the layout produced by
# plot_from_reranker_corrmap — revisit if that layout changes.
_CROP_LEFT = 310
_CROP_TOP = 60
_CROP_RIGHT = 248
_CROP_BOTTOM = 80

# ---------------------------------------------------------------------------
# Assets, fetched from Google Drive at import time.
# ---------------------------------------------------------------------------

# Precomputed training-set embeddings (cached on disk, MD5-verified).
gdown.cached_download(
    url="https://drive.google.com/uc?id=116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    path="./embeddings.pickle",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)

# Training-set labels. An explicit `output` guarantees the file lands at the
# path opened below instead of whatever name Drive reports for the file.
# NOTE(review): unlike the other assets this one is re-downloaded on every
# start and has no MD5 check, yet it is unpickled below — consider pinning one.
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e", output="./labels.pickle")

# CUB training images (archive cached on disk, MD5-verified).
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# Extract the training set into ./data/ (the archive is kept).
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)

# CHM weights.
gdown.cached_download(
    url="https://drive.google.com/uc?id=1yM1zA0Ews2I8d9-BGc6Q0hIAl7LzYqr0",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)

# Load the precomputed embeddings/labels and build the similarity index.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load files whose origin is trusted.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# ImageFolder.imgs is a list of (image path, class index) pairs; the class
# display name is derived from each image's parent folder name.
training_folder = ImageFolder(root="./data/train/")
id_to_bird_name = {
    x[1]: x[0].split("/")[-2].replace(".", " ") for x in training_folder.imgs
}


def _figure_to_cropped_image(fig):
    """Render *fig* to an in-memory JPEG and crop off the fixed outer margins.

    Returns a PIL image of the cropped plot.
    """
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format="jpg")
    # NOTE(review): the figure is never closed here; if plot_from_reranker_corrmap
    # creates it through pyplot, every request keeps one figure alive.
    img_buf.seek(0)
    image = Image.open(img_buf)
    width, height = image.size
    return image.crop(
        (_CROP_LEFT, _CROP_TOP, width - _CROP_RIGHT, height - _CROP_BOTTOM)
    )


def _chm_label_scores(chm_output):
    """Return ``{class name: vote fraction}`` over CHM-Corr's first K_VOTE neighbours.

    Fractions are always divided by K_VOTE, so they sum to less than 1 if
    fewer than K_VOTE neighbours are returned.
    """
    neighbor_paths = chm_output["chm-nearest-neighbors-all"][:K_VOTE]
    # The class name comes from the parent folder name, with "." and "_"
    # turned into spaces for display (presumably CUB-style "NNN.Bird_Name").
    class_counts = Counter(
        path.split("/")[-2].replace(".", " ").replace("_", " ")
        for path in neighbor_paths
    )
    return {name: count / K_VOTE for name, count in class_counts.items()}


def search(query_image, searcher=searcher):
    """Classify a query image by kNN retrieval followed by CHM-Corr re-ranking.

    Args:
        query_image: Path of the uploaded image (the input component uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        searcher: Index over the training-set embeddings; defaults to the
            module-level index built at import time.

    Returns:
        ``(visualization, label_scores)``: a cropped PIL image of the CHM-Corr
        correspondence plot, and a dict mapping class names to the fraction
        of CHM-Corr's first K_VOTE neighbours in that class. Both are ``None``
        when no image was provided.
    """
    if query_image is None:
        # "Classify" pressed with an empty input: clear the outputs instead of
        # passing None into the embedding extractor.
        return None, None

    query_embedding = QueryToEmbedding(query_image)
    _, indices, labels = searcher.search(query_embedding, k=N_NEIGHBORS)
    # Only row 0 (the single query) of the search results is used.
    vote_labels = labels[0][:K_VOTE]
    vote_indices = indices[0][:K_VOTE]

    # Majority vote among the first K_VOTE neighbours.
    top1_label, top1_votes = Counter(vote_labels).most_common(1)[0]

    # Winning-class neighbours, in retrieval order, capped at N_GALLERY.
    gallery_images = [
        training_folder.imgs[int(idx)][0]
        for lbl, idx in zip(vote_labels, vote_indices)
        if lbl == top1_label
    ][:N_GALLERY]
    kNN_results = (top1_label, top1_votes, gallery_images)

    # Every retrieved neighbour (path, class index) forms the CHM-Corr support set.
    support_files = [training_folder.imgs[int(idx)][0] for idx in indices[0]]
    support_labels = [training_folder.imgs[int(idx)][1] for idx in indices[0]]
    support = [support_files, support_labels]

    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )
    fig, _ = plot_from_reranker_corrmap(chm_output)

    return _figure_to_cropped_image(fig), _chm_label_scores(chm_output)


blocks = gr.Blocks()

tldr = """ We propose two architectures of interpretable image classifiers that first explain, and then predict by harnessing the visual correspondences between a query image and exemplars. Our models improve on several out-of-distribution (OOD) ImageNet datasets while achieving competitive performance on ImageNet than the black-box baselines (e.g. ImageNet-pretrained ResNet-50). On a large-scale human study (∼60 users per method per dataset) on ImageNet and CUB, our correspondence-based explanations led to human-alone image classification accuracy and human-AI team accuracy that are consistently better than that of kNN. 
We show that it is possible to achieve complementary human-AI team accuracy (i.e., that is higher than either AI-alone or human-alone), on ImageNet and CUB.
Github Page
"""

with blocks:
    gr.Markdown(""" # CHM-Corr DEMO""")
    gr.Markdown(f""" ## Description: \n {tldr}""")

    with gr.Row():
        input_image = gr.Image(type="filepath")
        with gr.Column():
            gr.Markdown("### Parameters:")
            # Rendered text is unchanged while the constants keep their values.
            gr.Markdown(
                f"`N={N_NEIGHBORS}`\n `k={K_VOTE}` \nUsing `ImageNet Pretrained ResNet50` features"
            )
            run_btn = gr.Button("Classify")

    gr.Markdown(""" ### CHM-Corr Output Visualization """)
    viz_plot = gr.Image(type="pil", label="Visualization")

    with gr.Row():
        with gr.Column():
            gr.Markdown(""" ### CHM-Corr Prediction """)
            labels = gr.Label(label="Prediction")
        with gr.Column():
            gr.Markdown(""" ### Examples """)
            examples = gr.Examples(
                examples=[
                    ["./examples/bird.jpg"],
                    ["./examples/Red_Winged_Blackbird_0012_6015.jpg"],
                    ["./examples/Red_Winged_Blackbird_0025_5342.jpg"],
                    ["./examples/sample1.jpeg"],
                    ["./examples/sample2.jpeg"],
                    ["./examples/Yellow_Headed_Blackbird_0020_8549.jpg"],
                    ["./examples/Yellow_Headed_Blackbird_0026_8545.jpg"],
                ],
                inputs=[input_image],
                outputs=[viz_plot, labels],
                fn=search,
                cache_examples=False,
            )

    run_btn.click(
        search,
        inputs=[input_image],
        outputs=[viz_plot, labels],
    )

if __name__ == "__main__":
    blocks.launch(
        debug=True,
        enable_queue=True,
    )