# Spaces: Sleeping
import os | |
import datetime | |
import pandas as pd | |
import gradio as gr | |
import argilla as rg | |
import plotly.graph_objects as go | |
import plotly.colors as colors | |
# Module-wide Argilla client; the server URL and API key are read from the
# ARGILLA_API_URL and ARGILLA_API_KEY environment variables.
client = rg.Argilla(
    api_url=os.environ.get("ARGILLA_API_URL"),
    api_key=os.environ.get("ARGILLA_API_KEY"),
)
def fetch_data(dataset_name: str, workspace: str):
    """Look up a dataset through the module-level Argilla client.

    Args:
        dataset_name: Name of the dataset to retrieve.
        workspace: Workspace the dataset is looked up in.

    Returns:
        Whatever ``client.datasets`` returns for that name and workspace.
    """
    dataset = client.datasets(dataset_name, workspace=workspace)
    return dataset
def get_progress(dataset) -> dict:
    """Summarize how much of the dataset has been annotated.

    Records are counted in a single pass over ``dataset.records`` so that no
    intermediate list of records has to be held in memory.

    Args:
        dataset: Object exposing an iterable ``records`` attribute whose items
            have a ``status`` attribute.

    Returns:
        A dict with the keys ``"total"`` (number of records), ``"annotated"``
        (records whose status equals ``"completed"``) and ``"progress"``
        (annotated share as a percentage; ``0.0`` for an empty dataset).
    """
    total_records = 0
    annotated_records = 0
    for record in dataset.records:
        total_records += 1
        if record.status == "completed":
            annotated_records += 1

    # Guard the division so an empty dataset reports 0% instead of raising.
    if total_records:
        progress = (annotated_records / total_records) * 100
    else:
        progress = 0.0

    return {
        "total": total_records,
        "annotated": annotated_records,
        "progress": progress,
    }
def get_leaderboard(dataset) -> dict:
    """Count responses per annotator across all records of the dataset.

    Each distinct ``user_id`` is resolved to a username through
    ``client.users`` only once; the result is cached for the rest of the call
    so the lookup is not repeated for every response of the same user.

    Args:
        dataset: Object exposing an iterable ``records`` attribute; each record
            has an iterable ``responses`` whose items carry a ``user_id``.

    Returns:
        A dict mapping username to the number of responses by that user.
    """
    usernames = {}  # user_id -> username, filled lazily
    user_annotations = {}
    for record in dataset.records:
        for response in record.responses:
            user_id = response.user_id
            if user_id not in usernames:
                retrieved_user = client.users(id=user_id)
                # NOTE(review): presumably the lookup can come back empty
                # (e.g. for a removed user) - confirm against the Argilla SDK.
                # Fall back to the raw id so those responses are still counted
                # instead of crashing on ``None.username``.
                if retrieved_user is not None:
                    usernames[user_id] = retrieved_user.username
                else:
                    usernames[user_id] = str(user_id)
            user = usernames[user_id]
            user_annotations[user] = user_annotations.get(user, 0) + 1
    # Kept from the original: echoes the tally to stdout.
    print(user_annotations)
    return user_annotations
def create_gauge_chart(progress):
    """Build the Plotly gauge figure showing overall annotation progress.

    Args:
        progress: Mapping with the keys ``"total"``, ``"annotated"`` and
            ``"progress"`` (the completed percentage).

    Returns:
        A ``go.Figure`` holding one gauge indicator and two text annotations.
    """
    percent = progress["progress"]
    total = progress["total"]
    annotated = progress["annotated"]

    # Gauge appearance: the filled step runs up to the current percentage,
    # the remainder is light gray, and a red threshold line sits at 100.
    gauge_config = {
        "axis": {"range": [None, 100], "tickwidth": 1, "tickcolor": "darkblue"},
        "bar": {"color": "deepskyblue"},
        "bgcolor": "white",
        "borderwidth": 2,
        "bordercolor": "gray",
        "steps": [
            {"range": [0, percent], "color": "royalblue"},
            {"range": [percent, 100], "color": "lightgray"},
        ],
        "threshold": {
            "line": {"color": "red", "width": 4},
            "thickness": 0.75,
            "value": 100,
        },
    }
    indicator = go.Indicator(
        mode="gauge+number+delta",
        value=percent,
        title={"text": "Dataset Annotation Progress", "font": {"size": 24}},
        delta={"reference": 100, "increasing": {"color": "RebeccaPurple"}},
        number={"font": {"size": 40}, "valueformat": ".1f", "suffix": "%"},
        gauge=gauge_config,
    )
    fig = go.Figure(indicator)

    # Record-count summary; no explicit x/y position is set for this one.
    summary_text = (
        f"Total records: {total}<br>"
        f"Annotated: {annotated} ({percent:.1f}%)<br>"
        f"Remaining: {total - annotated} ({100 - percent:.1f}%)"
    )
    summary_annotation = {
        "text": summary_text,
        "showarrow": False,
        "xref": "paper",
        "yref": "paper",
        "font": {"size": 16},
    }
    fig.update_layout(annotations=[summary_annotation])

    # Headline placed above the plotting area (y=1.1 in paper coordinates).
    headline_text = (
        f"Current Progress: {percent:.1f}% complete<br>"
        f"({annotated} out of {total} records annotated)"
    )
    fig.add_annotation(
        text=headline_text,
        xref="paper",
        yref="paper",
        x=0.5,
        y=1.1,
        showarrow=False,
        font={"size": 18},
        align="center",
    )
    return fig
def create_treemap(user_annotations, total_records):
    """Build a Plotly treemap of per-user contribution counts.

    Users are listed from most to fewest annotations, each as a child of a
    single root node labelled ``"Annotations"`` whose value is
    ``total_records``.

    Args:
        user_annotations: Mapping of username to annotation count.
        total_records: Denominator for each user's percentage and value of the
            root node. A value of 0 yields 0.00% for every user instead of
            raising ``ZeroDivisionError``.

    Returns:
        A ``go.Figure`` containing the treemap.
    """
    sorted_users = sorted(user_annotations.items(), key=lambda x: x[1], reverse=True)
    color_scale = colors.qualitative.Pastel + colors.qualitative.Set3

    labels, parents, values, text, user_colors = [], [], [], [], []
    for i, (user, contribution) in enumerate(sorted_users):
        # Same zero guard as the progress computation: avoid dividing by an
        # empty total.
        if total_records:
            percentage = (contribution / total_records) * 100
        else:
            percentage = 0.0
        labels.append(user)
        parents.append("Annotations")
        values.append(contribution)
        text.append(f"{contribution} annotations<br>{percentage:.2f}%")
        # Cycle through the palette when there are more users than colors.
        user_colors.append(color_scale[i % len(color_scale)])

    # Root node that every user tile hangs off.
    labels.append("Annotations")
    parents.append("")
    values.append(total_records)
    text.append(f"Total: {total_records} annotations")
    user_colors.append("#FFFFFF")

    fig = go.Figure(
        go.Treemap(
            labels=labels,
            parents=parents,
            values=values,
            text=text,
            textinfo="label+text",
            hoverinfo="label+text+value",
            marker=dict(colors=user_colors, line=dict(width=2)),
        )
    )
    fig.update_layout(
        title_text="User contributions to the total end dataset",
        height=500,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor="#F0F0F0",  # Light gray background
        plot_bgcolor="#F0F0F0",  # Light gray background
    )
    return fig
def update_dashboard():
    """Fetch the dataset and rebuild every dashboard component.

    The dataset name and workspace come from the ``DATASET_NAME`` and
    ``WORKSPACE`` environment variables.

    Returns:
        A tuple ``(gauge_chart, treemap, leaderboard_df)`` where the last item
        is a DataFrame with ``User`` and ``Annotations`` columns sorted by
        annotation count, highest first.
    """
    dataset_name = os.getenv("DATASET_NAME")
    workspace = os.getenv("WORKSPACE")
    dataset = fetch_data(dataset_name, workspace)

    progress = get_progress(dataset)
    user_annotations = get_leaderboard(dataset)

    gauge_chart = create_gauge_chart(progress)
    treemap = create_treemap(user_annotations, progress["total"])

    # One row per user, ranked by number of annotations.
    rows = list(user_annotations.items())
    unsorted_df = pd.DataFrame(rows, columns=["User", "Annotations"])
    ranked_df = unsorted_df.sort_values("Annotations", ascending=False)
    leaderboard_df = ranked_df.reset_index(drop=True)

    return gauge_chart, treemap, leaderboard_df
# Dashboard layout: a title, one row with the two plots, one row with the
# leaderboard table, and a button that re-runs the refresh.
with gr.Blocks() as demo:
    gr.Markdown("# Argilla Dataset Dashboard")

    with gr.Row():
        gauge_output = gr.Plot(label="Overall Progress")
        treemap_output = gr.Plot(label="User contributions")

    with gr.Row():
        leaderboard_output = gr.Dataframe(
            label="Leaderboard",
            headers=["User", "Annotations"],
        )

    # Fill all three components when the page loads.
    demo.load(
        update_dashboard,
        inputs=None,
        outputs=[gauge_output, treemap_output, leaderboard_output],
    )

    # Clicking the button runs the same update.
    refresh_button = gr.Button("Refresh")
    refresh_button.click(
        update_dashboard,
        inputs=None,
        outputs=[gauge_output, treemap_output, leaderboard_output],
    )


if __name__ == "__main__":
    demo.launch()