|
import json |
|
import gradio as gr |
|
import numpy as np |
|
import time |
|
import csv |
|
import json |
|
import os |
|
import random |
|
import string |
|
import sys |
|
import time |
|
import gradio as gr |
|
import numpy as np |
|
import pandas as pd |
|
from huggingface_hub import ( |
|
CommitScheduler, |
|
HfApi, |
|
InferenceClient, |
|
login, |
|
snapshot_download, |
|
) |
|
from PIL import Image |
|
from utils import string_to_image |
|
import matplotlib.backends.backend_agg as agg |
|
import math |
|
from pathlib import Path |
|
import zipfile |
|
|
|
|
|
# Seed NumPy's global RNG from the wall clock so every process start gets a
# different stream. One call is enough; an identical second call added nothing.
# NOTE: this does not seed the stdlib ``random`` module.
np.random.seed(int(time.time()))

# Raise the csv module's per-field size cap to the largest value the
# interpreter reports, so very large fields are not rejected by the parser.
csv.field_size_limit(sys.maxsize)
|
|
|
|
|
|
|
# Authenticate against the Hugging Face Hub using the token provided through
# the ``SessionToken`` environment variable (``None`` when it is not set).
session_token = os.getenv("SessionToken")
login(add_to_git_credential=True, token=session_token)
|
|
|
|
|
# Pull the dataset repository into the working directory, keeping the Hub
# cache in a local ``hf_cache`` folder.
snapshot_download(
    cache_dir="./hf_cache",
    local_dir="./",
    repo_type="dataset",
    repo_id="XAI/PEEB-Data",
)

# Unpack ``data.zip`` (presumably delivered by the snapshot above) next to
# the script so the image and metadata folders become available.
with zipfile.ZipFile("./data.zip", mode="r") as archive:
    archive.extractall(path="./")
|
|
|
|
|
# Number of images sampled for (and labelled in) one experiment session.
NUMBER_OF_IMAGES = 30

intro_screen = Image.open("./images/intro.jpg")


def _load_metadata(path, nn_type):
    """Load one metadata JSON file and tag every entry with its ``type``.

    Args:
        path: Location of the ``metadata.json`` file.
        nn_type: Value stored under each entry's ``"type"`` key.

    Returns:
        The parsed mapping, with ``entry["type"] = nn_type`` set on every
        value.
    """
    # The context manager closes the handle; ``json.load(open(...))`` left it
    # open until garbage collection.
    with open(path) as handle:
        metadata = json.load(handle)
    for entry in metadata.values():
        entry["type"] = nn_type
    return metadata


meta_top1 = _load_metadata("./dogs/top1/metadata.json", "top1")
meta_topK = _load_metadata("./dogs/topK/metadata.json", "topK")

# Metadata for both experiment variants, keyed by variant name.
all_data = {"top1": meta_top1, "topK": meta_topK}
|
|
|
|
|
|
|
# Hub dataset repository that receives the participants' responses, and the
# local folder whose contents are uploaded to it.
REPO_URL = "taesiri/AdvisingNetworksReviewDataExtension"
JSON_DATASET_DIR = Path("responses")

# Create the folder up front. ``exist_ok`` replaces the separate
# exists()/mkdir() pair, which could race with another process and only ran
# after the folder had already been handed to the scheduler.
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# Background job that periodically commits ``JSON_DATASET_DIR`` to the Hub.
scheduler = CommitScheduler(
    repo_id=REPO_URL,
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="./data",
    every=1,  # interval between commits, in minutes
    private=True,
)
|
|
|
|
|
def generate_data(type_of_nns):
    """Draw the random set of datapoints used for one experiment session.

    Args:
        type_of_nns: Key into ``all_data`` selecting the metadata set
            (the UI offers ``"top1"`` and ``"topK"``).

    Returns:
        A list of ``NUMBER_OF_IMAGES`` dicts. Each one is a shallow copy of a
        metadata entry with an extra ``"image-path"`` key pointing at
        ``./dogs/<type_of_nns>/<key>.jpeg``.

    Raises:
        KeyError: If ``type_of_nns`` is not a key of ``all_data``.
        ValueError: If the metadata set holds fewer than
            ``NUMBER_OF_IMAGES`` entries (raised by ``random.sample``).
    """
    metadata = all_data[type_of_nns]
    sampled_keys = random.sample(list(metadata), NUMBER_OF_IMAGES)

    # Build new dicts instead of writing "image-path" into the shared
    # module-level metadata, so sessions never modify global state.
    return [
        {**metadata[key], "image-path": f"./dogs/{type_of_nns}/{key}.jpeg"}
        for key in sampled_keys
    ]
|
|
|
|
|
def load_sample(data, current_index):
    """Load one datapoint and build the question shown to the participant.

    Args:
        data: Sequence of datapoint dicts providing the keys
            ``"image-path"``, ``"top1-label"``, ``"top1-score"`` and
            ``"Accept/Reject"``.
        current_index: Position in ``data`` of the datapoint to load.

    Returns:
        A tuple ``(image, top_1, rounded_up_score, question, accept_reject)``:
        the opened PIL image, the predicted label, the confidence as a whole
        percentage (rounded up), the HTML question string, and the stored
        ``"Accept/Reject"`` value.
    """
    current_datapoint = data[current_index]

    image = Image.open(current_datapoint["image-path"])
    top_1 = current_datapoint["top1-label"]

    # The score is scaled by 100 and shown with a "%" sign. Rounding to two
    # decimals before ceil() keeps float noise (e.g. 0.07 * 100 ==
    # 7.000000000000001) from bumping the value up by a whole percent.
    percentage = round(current_datapoint["top1-score"] * 100, 2)
    rounded_up_score = int(math.ceil(percentage))

    # Single template; an earlier variant without the <br> was assigned and
    # immediately overwritten, so it has been removed.
    q_template = (
        "<div style='font-size: 24px;'>Sam guessed the Input image is "
        "<span style='font-weight: bold;'>{}</span> "
        "with <span style='font-weight: bold;'>{}%</span> "
        "confidence.<br>Is this bird a <span style='font-weight: bold;'>{}</span>?"
        "</div>"
    )
    question = q_template.format(top_1, str(rounded_up_score), top_1)

    accept_reject = current_datapoint["Accept/Reject"]

    return image, top_1, rounded_up_score, question, accept_reject
|
|
|
|
|
def preprocessing(data, type_of_nns, current_index, history, username):
    """Start a session: sample images, build a user id, load the first image.

    The incoming ``data`` and ``current_index`` values are not used; both are
    replaced by freshly generated ones. ``history`` is passed through
    untouched.

    Args:
        data: Previous session data (ignored).
        type_of_nns: Metadata set to sample from, forwarded to
            ``generate_data``.
        current_index: Previous index (ignored).
        history: List of recorded decisions, returned as-is.
        username: Name typed by the participant; an empty string falls back
            to ``"username"``.

    Returns:
        A 9-tuple: the five values produced by ``load_sample`` for index 0,
        then the index ``0``, ``history``, the sampled data, and the username
        extended with ``-`` plus eight random lowercase/digit characters.
    """
    print("preprocessing")
    session_data = generate_data(type_of_nns)
    print("data generated")

    # Random suffix so two participants typing the same name stay distinct.
    alphabet = string.ascii_lowercase + string.digits
    suffix = "".join(random.choice(alphabet) for _ in range(8))

    base_name = "username" if username == "" else username
    session_user = f"{base_name}-{suffix}"

    first_index = 0
    print("loading sample ....")
    first_image, label, score, question_html, ground_truth = load_sample(
        session_data, first_index
    )

    return (
        first_image,
        label,
        score,
        question_html,
        ground_truth,
        first_index,
        history,
        session_data,
        session_user,
    )
|
|
|
|
|
def _message_image(message):
    """Render ``message`` via ``string_to_image`` and return it as a PIL image."""
    # string_to_image is a project helper; its result is handed straight to
    # FigureCanvasAgg, so it is presumably a Matplotlib figure.
    canvas = agg.FigureCanvasAgg(string_to_image(message))
    canvas.draw()
    # buffer_rgba() exposes the pixels in RGBA byte order, which is what the
    # "RGBA" mode below expects. tostring_argb() is ARGB, so every channel
    # was being read one slot off.
    return Image.frombytes(
        "RGBA", canvas.get_width_height(), bytes(canvas.buffer_rgba())
    )


def _is_user_correct(ground_truth, decision):
    """Tell whether a "YES"/"NO" decision matches the stored ground truth.

    Handles both a textual ground truth ("Accept"/"Reject") and a boolean
    one, so every image of a session is scored by the same rule.
    """
    user_accepts = decision == "YES"
    if isinstance(ground_truth, str):
        expected = "accept" if user_accepts else "reject"
        return ground_truth.strip().lower() == expected
    return bool(ground_truth == user_accepts)


def _record_decision(datapoint, decision, username):
    """Return a copy of ``datapoint`` annotated with the user's decision."""
    record = datapoint.copy()
    record["user_decision"] = decision
    record["user_id"] = username
    record["is_user_correct"] = _is_user_correct(
        record["Accept/Reject"], decision
    )
    return record


def update_app(decision, data, current_index, history, username):
    """Record one YES/NO decision and advance the experiment.

    Args:
        decision: Value of the clicked button (``"YES"`` or ``"NO"``).
        data: Session datapoints as produced by ``generate_data``.
        current_index: Index of the image being answered; ``-1`` means the
            session has not been started.
        history: List of already recorded decisions; the new record is
            appended to it in place.
        username: Session user id, stored with each record and used to name
            the results file.

    Returns:
        A 12-tuple matching the click handlers' outputs: image, top-1 label,
        top-1 score, question HTML, accept/reject value, current index,
        history, data, labelled-image counter, two button updates, and the
        final-results text. After the last image the results are written to
        ``results_<username>.json`` in ``JSON_DATASET_DIR`` and the accuracy
        is reported. Returns ``None`` implicitly for an index outside
        ``-1 .. NUMBER_OF_IMAGES - 1``.
    """
    if current_index == -1:
        # NOTE: the error is constructed but not raised, so it has no visible
        # effect; raising would skip the placeholder outputs returned below.
        gr.Error("Please Enter your username and load samples")

        return (
            _message_image("Please Enter your username and load samples"),
            "",
            "",
            "",
            "",
            current_index,
            history,
            data,
            0,
            gr.update(interactive=False),
            gr.update(interactive=False),
            "",
        )

    last_index = NUMBER_OF_IMAGES - 1

    if current_index == last_index:
        time_stamp = int(time.time())
        history.append(_record_decision(data[current_index], decision, username))

        final_decision_data = {
            "user_id": username,
            "time": time_stamp,
            "history": history,
        }

        # The username comes from a free-text box: keep only characters that
        # cannot change the target directory (no path separators or "..").
        safe_name = "".join(
            ch if ch.isalnum() or ch in "-_" else "_" for ch in username
        )
        results_path = JSON_DATASET_DIR / f"results_{safe_name}.json"

        # Hold the scheduler's lock so a background commit never uploads a
        # half-written file.
        with scheduler.lock:
            with open(results_path, "w") as f:
                json.dump(final_decision_data, f)

        thank_you_image = _message_image("Thank you for your time!")

        all_is_user_correct = [d["is_user_correct"] for d in history]
        accuracy = round(np.mean(all_is_user_correct) * 100, 2)

        return (
            thank_you_image,
            "",
            "",
            "",
            "",
            current_index,
            history,
            data,
            current_index + 1,
            gr.update(interactive=False),
            gr.update(interactive=False),
            f"User Accuracy: {accuracy}",
        )

    if 0 <= current_index < last_index:
        record = _record_decision(data[current_index], decision, username)

        print(f" accept/reject : {record['Accept/Reject'] }")
        print(f" accept/reject status: {record['is_user_correct']}")

        history.append(record)

        current_index += 1
        qimage, top_1, top_1_score, question, accept_reject = load_sample(
            data, current_index
        )

        return (
            qimage,
            top_1,
            top_1_score,
            question,
            accept_reject,
            current_index,
            history,
            data,
            current_index,
            gr.update(interactive=True),
            gr.update(interactive=True),
            "",
        )
|
|
|
|
|
def disable_component():
    """Return a Gradio update that makes the target component non-interactive."""
    update = gr.update(interactive=False)
    return update
|
|
|
|
|
def enable_component():
    """Return a Gradio update that makes the target component interactive."""
    update = gr.update(interactive=True)
    return update
|
|
|
|
|
def hide_component():
    """Return a Gradio update that hides the target component."""
    update = gr.update(visible=False)
    return update
|
|
|
|
|
# Build the UI and wire the event handlers.
# NOTE(review): the layout nesting below was reconstructed from the statement
# order; confirm it against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session state: sampled datapoints, position in them (-1 = not
    # started), and the list of recorded decisions.
    data_state = gr.State({})
    current_index = gr.State(-1)
    history = gr.State([])

    gr.Markdown("# Advising Networks")
    gr.Markdown("## Accept/Reject AI predicted label using Explanations")

    with gr.Column():
        with gr.Row():
            username_textbox = gr.Textbox(label="Username", value="username")
            labeled_images_textbox = gr.Textbox(label="Labeled Images", value="0")
            total_images_textbox = gr.Textbox(
                label="Total Images", value=NUMBER_OF_IMAGES
            )
            type_of_nns_dropdown = gr.Dropdown(
                label="Type of NNs",
                choices=["top1", "topK"],
                value="top1",
            )

        prepare_btn = gr.Button(value="Start The Experiment")

    with gr.Column():
        with gr.Row():
            question_textbox = gr.HTML("")

        # NOTE(review): this column and the next one share the same elem_id.
        with gr.Column(elem_id="parent_row"):
            query_image = gr.Image(
                type="pil", label="Query", show_label=False, value="./images/intro.jpg"
            )

            with gr.Row():
                # Disabled until a session has been started.
                accept_btn = gr.Button(value="YES", interactive=False)
                reject_btn = gr.Button(value="NO", interactive=False)

        # Hidden fields that receive the current sample's details.
        with gr.Column(elem_id="parent_row"):
            top_1_textbox = gr.Textbox(label="Top 1", value="", visible=False)
            top_1_score_textbox = gr.Textbox(
                label="Top 1 Score", value="", visible=False
            )
            accept_reject_textbox = gr.Textbox(
                label="Accept/Reject", value="", visible=False
            )

    with gr.Column():
        with gr.Row():
            final_results = gr.HTML("")

    # Components refreshed with every newly loaded sample.
    sample_outputs = [
        query_image,
        top_1_textbox,
        top_1_score_textbox,
        question_textbox,
        accept_reject_textbox,
        current_index,
        history,
        data_state,
    ]

    # Start: load the first sample, lock the setup controls, enable the
    # answer buttons and finally hide the start button. (prepare_btn used to
    # be disabled twice in this chain; once is enough.)
    prepare_btn.click(
        preprocessing,
        inputs=[
            data_state,
            type_of_nns_dropdown,
            current_index,
            history,
            username_textbox,
        ],
        outputs=sample_outputs + [username_textbox],
    ).then(fn=disable_component, outputs=[prepare_btn]).then(
        fn=disable_component, outputs=[type_of_nns_dropdown]
    ).then(
        fn=disable_component, outputs=[username_textbox]
    ).then(
        fn=enable_component, outputs=[accept_btn]
    ).then(
        fn=enable_component, outputs=[reject_btn]
    ).then(
        fn=hide_component, outputs=[prepare_btn]
    )

    # Both answer buttons share one handler and one set of outputs; the
    # clicked button itself is the first input, so the handler receives its
    # value ("YES" or "NO").
    decision_outputs = sample_outputs + [
        labeled_images_textbox,
        accept_btn,
        reject_btn,
        final_results,
    ]

    accept_btn.click(
        update_app,
        inputs=[accept_btn, data_state, current_index, history, username_textbox],
        outputs=decision_outputs,
    )

    reject_btn.click(
        update_app,
        inputs=[reject_btn, data_state, current_index, history, username_textbox],
        outputs=decision_outputs,
    )
|
|
|
|
|
# Start the web server, listening on all network interfaces with debug mode
# switched off.
demo.launch(server_name="0.0.0.0", debug=False)
|
|
|
|