Spaces:
Running
Running
File size: 4,687 Bytes
101c142 1d32376 101c142 1d32376 101c142 62dd38d 101c142 1d32376 101c142 1d32376 101c142 1d32376 f7d283c 1d32376 101c142 1d32376 101c142 1d32376 101c142 1d32376 101c142 1d32376 101c142 1d32376 101c142 1d32376 101c142 1d32376 101c142 1d32376 bd0c4d1 1d32376 f92272f 1d32376 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import streamlit as st
import pandas as pd
import numpy as np
from streamlit_echarts import st_echarts
from streamlit.components.v1 import html
# from PIL import Image
from app.show_examples import *
from app.content import *
import pandas as pd
from typing import List
from model_information import get_dataframe
# Model metadata table, loaded once when this module is imported.
# sum_table_mulit_metrix reads its 'Original Name', 'Proper Display Name' and
# 'Link' columns to map raw model names to display names and links.
info_df = get_dataframe()
# CSS injected before the model selector: styles the selected-item chips.
_MULTISELECT_CSS = """
<style>
.stMultiSelect [data-baseweb=select] span {
    max-width: 800px;
    font-size: 0.9rem;
    background-color: #3C6478 !important; /* Background color for selected items */
    color: white; /* Change text color */
}
</style>
"""


def _merge_metric_results(task_name, metrics_lists):
    """Read and outer-merge on ``Model`` the result CSV of every metric for ``task_name``.

    Each file is ``./results_organized/<metric>/<task_name.lower()>.csv`` and its
    values are rounded to 3 decimals. ``metrics_lists`` must be non-empty.
    """
    merged = None
    for metric in metrics_lists:
        data_path = f'./results_organized/{metric}/{task_name.lower()}.csv'
        one_chart_data = pd.read_csv(data_path).round(3)
        # A None sentinel (rather than "is empty") so that an empty CSV is merged
        # consistently no matter which position it occupies in metrics_lists.
        if merged is None:
            merged = one_chart_data
        else:
            merged = pd.merge(merged, one_chart_data, on='Model', how='outer')
    return merged


def _highlight_top_average(frame):
    """Styler callback (axis=None): shade cell (row 0, column 1) of ``frame``.

    After sorting, row 0 is the best model and column 1 is its ``Average``.
    """
    styles = pd.DataFrame('', index=frame.index, columns=frame.columns)
    styles.iloc[0, 1] = 'background-color: #b0c1d7'
    return styles


def sum_table_mulit_metrix(task_name: str, metrics_lists: List[str]) -> None:
    """Render a Streamlit leaderboard table that combines several metric result sets.

    For every metric name in ``metrics_lists`` the CSV
    ``./results_organized/<metric>/<task_name.lower()>.csv`` is read and
    outer-merged with the others on ``Model``. A per-model ``Average`` over all
    result columns is added and dataset columns are renamed via
    ``datasetname2diaplayname``. Model names are mapped to display names and
    links using ``info_df``. The user chooses models in a multiselect. Models
    with any missing value are dropped, and the remainder is shown sorted by
    ``Average`` with the top row's average highlighted.

    Args:
        task_name: Task identifier; lower-cased to build the CSV file name.
        metrics_lists: Metric folder names under ``./results_organized``. The
            LAST entry decides the sort direction (ascending only for ``"wer"``)
            and the metric caption shown under the table.

    Returns:
        None. Renders nothing if ``metrics_lists`` is empty, and only the
        model selector if no models remain after filtering.
    """
    if not metrics_lists:
        return
    # Only the last metric drives sorting and the caption; bind it explicitly
    # instead of relying on a for-loop variable leaking out of its loop.
    last_metric = metrics_lists[-1]

    chart_data = _merge_metric_results(task_name, metrics_lists)
    result_columns = [col for col in chart_data.columns if col != 'Model']
    chart_data['Average'] = chart_data[result_columns].mean(axis=1)
    # Update dataset names shown in the table.
    chart_data = chart_data.rename(columns=datasetname2diaplayname)

    st.markdown(_MULTISELECT_CSS, unsafe_allow_html=True)

    # Remap raw model names to their display names (falling back to the raw name).
    display_model_names = {
        key.strip(): val.strip()
        for key, val in zip(info_df['Original Name'], info_df['Proper Display Name'])
    }
    chart_data['model_show'] = chart_data['Model'].map(lambda x: display_model_names.get(x, x))
    model_options = sorted(chart_data['model_show'].tolist())
    models = st.multiselect(
        "Please choose the model",
        model_options,
        default=model_options,
    )
    # Keep selected models that have a value for every merged column.
    chart_data = chart_data[chart_data['model_show'].isin(models)].dropna(axis=0).copy()
    if chart_data.empty:
        return

    # Show table
    with st.container():
        st.markdown('##### TABLE')
        model_link = {
            key.strip(): val
            for key, val in zip(info_df['Proper Display Name'], info_df['Link'])
        }
        chart_data['model_link'] = chart_data['model_show'].map(model_link)

        # Column order: model name, Average, per-dataset scores, link.
        value_columns = [col for col in chart_data.columns if col not in ('Model', 'model_show')]
        new_order = ['Average'] + [col for col in value_columns if col != 'Average']
        chart_data_table = chart_data[['model_show'] + new_order].copy()
        # Round Average to 3 decimals before sorting (merged CSV values already are).
        chart_data_table['Average'] = chart_data_table['Average'].map(lambda value: round(float(value), 3))

        # "wer" is lower-is-better, so it sorts ascending; every other metric descends.
        ascending = last_metric in ['wer']
        chart_data_table = chart_data_table.sort_values(
            by=['Average'],
            ascending=ascending,
        ).reset_index(drop=True)

        # Every column between the model name (first) and the link (last) is numeric.
        numeric_columns = chart_data_table.columns[1:-1]
        styled_df = (
            chart_data_table.style
            .format({col: "{:.3f}" for col in numeric_columns})
            .apply(_highlight_top_average, axis=None)
        )
        st.dataframe(
            styled_df,
            column_config={
                'model_show': 'Model',
                'Average': {'alignment': 'left'},
                "model_link": st.column_config.LinkColumn(
                    "Model Link",
                ),
            },
            hide_index=True,
            use_container_width=True,
        )
        # Only the last metric is reported.
        st.markdown(f'###### Metric: {metrics_info[last_metric]}')
|