Runtime error
Runtime error
## LIBRARIES ###
from cProfile import label  # NOTE(review): unused in this file
from datetime import datetime
from tkinter import font  # NOTE(review): unused; tkinter needs a Tk install and may fail to import on headless hosts
from turtle import width  # NOTE(review): unused; turtle imports tkinter as well

import pandas as pd
import plotly.express as px
import streamlit as st
def select_plot_data(df, quantile_low, qunatile_high):
    """Return the weekly usage of models whose mean usage lies in a quantile band.

    ``df`` is expected to have a ``'Model'`` column plus one column per week whose
    header is a timestamp string understood by ``date_range``. Missing cells are
    treated as zero usage. The caller's frame is not modified.

    Args:
        df: Raw usage table (one row per model).
        quantile_low: Lower quantile (0-1) of the per-model mean usage.
        qunatile_high: Upper quantile (0-1) of the per-model mean usage.
            (The misspelt name is kept for backward compatibility.)

    Returns:
        DataFrame indexed by formatted week labels with one column per model
        whose mean usage lies in the inclusive band
        [quantile(quantile_low), quantile(qunatile_high)].
    """
    # fillna() without inplace returns a copy, so the caller's df stays intact.
    df_plot = df.fillna(0).set_index('Model').T
    df_plot.index = date_range(df_plot)

    # Mean weekly usage per model (one value per column).
    model_means = df_plot.mean()
    lower = model_means.quantile(quantile_low)
    upper = model_means.quantile(qunatile_high)

    # Inclusive bounds: selecting the full 0-1 range keeps every model,
    # including the least- and most-used ones.
    keep = (model_means >= lower) & (model_means <= upper)
    # Positional boolean mask also works when model names are duplicated.
    return df_plot.loc[:, keep.to_numpy()]
def read_file_to_df(file):
    """Parse the CSV given by ``file`` and hand back the resulting DataFrame.

    ``file`` is forwarded unchanged to ``pandas.read_csv``, so either a path
    string or an open file-like object is accepted.
    """
    usage_table = pd.read_csv(file)
    return usage_table
def date_range(df):
    """Format the timestamp labels of ``df.index`` as short ``M/D/YY`` strings.

    Each label must look like ``'2022-03-14T00:00:00.000Z'``; the same form
    without fractional seconds (``'2022-03-14T00:00:00Z'``) is accepted too.
    Month and day are not zero-padded and the year keeps its last two digits,
    e.g. ``'3/14/22'``.

    Args:
        df: Object whose ``index`` holds the timestamp strings.

    Returns:
        list[str]: One formatted label per index entry, in index order.

    Raises:
        ValueError: If a label matches neither timestamp format.
    """
    def _to_date(stamp):
        # Parse each label exactly once (the original parsed it three times).
        try:
            return datetime.strptime(stamp, '%Y-%m-%dT%H:%M:%S.%fZ').date()
        except ValueError:
            # Same timestamp shape without the fractional-seconds part.
            return datetime.strptime(stamp, '%Y-%m-%dT%H:%M:%SZ').date()

    dates = [_to_date(stamp) for stamp in df.index.to_list()]
    return [f"{d.month}/{d.day}/{str(d.year)[-2:]}" for d in dates]
if __name__ == "__main__":
    st.set_page_config(layout="wide", page_title="HF Hub Model Usage Visualization")
    st.header("Model Usage Visualization")
    with st.expander("How to read and interact with the plot:"):
        st.markdown("The plots below visualize weekly usage for HF models categorized by the model creation time.")
        st.markdown("Select the model creation time range you want to visualize using the dropdown menu below.")
        st.markdown("Choose the quantile range to filter out models with high or low usage.")
        st.markdown("The plots are interactive. Hover over the points to see the model name and the number of weekly mean usage. Click on the legend to hide/show the models.")

    # Widgets created with a `key` are registered in st.session_state by
    # Streamlit itself, so these return values already equal the session
    # values; copying them into session_state by hand is unnecessary.
    model_init_year = st.multiselect("Model creation year", ["before_2021", "2021", "2022"], key="model_init_year", default="2022")
    popularity_low = st.slider("Model popularity quantile (lower limit) ", min_value=0.0, max_value=1.0, step=0.01, value=0.90, key="popularity_low")
    popularity_high = st.slider("Model popularity quantile (upper limit) ", min_value=0.0, max_value=1.0, step=0.01, value=0.99, key="popularity_high")

    if popularity_low > popularity_high:
        # An inverted band can never select a model; explain instead of
        # rendering empty charts.
        st.warning("The lower popularity quantile must not be greater than the upper one.")
    else:
        with st.container():
            # One chart per selected creation year, each read from its own asset folder.
            for year in model_init_year:
                plotly_spot = st.empty()
                df = read_file_to_df(f"./assets/{year}/model_usage.csv")
                df_plot_data = select_plot_data(df, popularity_low, popularity_high)
                with plotly_spot:
                    if df_plot_data.empty:
                        st.info(f"No models created in {year} fall within the selected popularity range.")
                    else:
                        fig = px.line(df_plot_data, title="Models created in " + year, labels={"index": "Weeks", "value": "Usage", "variable": "Model"})
                        st.plotly_chart(fig, use_container_width=True)