File size: 4,687 Bytes
101c142
 
 
 
 
 
 
1d32376
 
101c142
 
 
 
 
 
 
 
 
1d32376
 
 
101c142
62dd38d
101c142
1d32376
 
 
 
 
 
101c142
1d32376
 
101c142
1d32376
f7d283c
1d32376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101c142
1d32376
 
 
 
 
 
101c142
1d32376
101c142
1d32376
 
 
 
 
 
101c142
1d32376
101c142
1d32376
101c142
1d32376
 
 
 
 
 
101c142
1d32376
 
101c142
1d32376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd0c4d1
 
 
1d32376
 
f92272f
 
 
 
 
 
1d32376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
import pandas as pd
import numpy as np
from streamlit_echarts import st_echarts
from streamlit.components.v1 import html
# from PIL import Image 
from app.show_examples import *
from app.content import *

import pandas as pd
from typing import List

from model_information import get_dataframe

# Model metadata from model_information.get_dataframe(), loaded once at import
# time. sum_table_mulit_metrix reads its 'Original Name', 'Proper Display Name'
# and 'Link' columns to map raw model names to display names and links.
info_df = get_dataframe()


def _merge_metric_tables(task_name, metrics_lists):
    """Read each metric's CSV for *task_name* and outer-join them on 'Model'.

    A ``None`` sentinel marks "no table yet", so a CSV that happens to have
    zero rows is still joined (keeping its columns) instead of being replaced.
    """
    merged = None
    for metric in metrics_lists:
        data_path = f"./results_organized/{metric}/{task_name.lower()}.csv"
        table = pd.read_csv(data_path).round(3)
        if merged is None:
            merged = table
        else:
            merged = pd.merge(merged, table, on='Model', how='outer')
    return merged


def _highlight_top_average(frame):
    """Styler callback: shade the first row's second column (the Average cell)."""
    css = pd.DataFrame('', index=frame.index, columns=frame.columns)
    css.iloc[0, 1] = 'background-color: #b0c1d7'
    return css


def sum_table_mulit_metrix(task_name, metrics_lists: List[str]):
    """Render a Streamlit table combining several metrics' results for one task.

    For every name in *metrics_lists*, reads
    ``./results_organized/<metric>/<task_name lower-cased>.csv`` (values rounded
    to 3 decimals) and outer-joins the tables on the ``Model`` column. An
    ``Average`` column is added over all score columns, dataset columns are
    renamed for display, and the user picks models through a multiselect.
    Models with any missing score are dropped. The table is sorted by
    ``Average`` (ascending for ``'wer'``, descending otherwise) and the top
    row's Average cell is highlighted.

    Args:
        task_name: Task identifier; lower-cased to build the CSV file name.
        metrics_lists: Metric folder names. The LAST entry decides the sort
            direction and the metric caption rendered under the table.

    Raises:
        ValueError: If *metrics_lists* is empty.
    """
    if not metrics_lists:
        raise ValueError("metrics_lists must contain at least one metric name")
    # Sort direction and caption follow the last metric, taken explicitly
    # rather than from a loop variable that leaks out of the merge loop.
    last_metric = metrics_lists[-1]

    chart_data = _merge_metric_tables(task_name, metrics_lists)

    score_columns = [col for col in chart_data.columns if col != 'Model']
    chart_data['Average'] = chart_data[score_columns].mean(axis=1)

    # Update dataset names shown in the table header.
    chart_data = chart_data.rename(columns=datasetname2diaplayname)

    st.markdown("""
                <style>
                .stMultiSelect [data-baseweb=select] span {
                    max-width: 800px;
                    font-size: 0.9rem;
                    background-color: #3C6478 !important; /* Background color for selected items */
                    color: white; /* Change text color */
                }
                </style>
                """, unsafe_allow_html=True)

    # Remap raw model names to their display names (unknown names pass through).
    display_model_names = {
        key.strip(): val.strip()
        for key, val in zip(info_df['Original Name'], info_df['Proper Display Name'])
    }
    chart_data['model_show'] = chart_data['Model'].map(lambda x: display_model_names.get(x, x))

    all_models = sorted(chart_data['model_show'].tolist())
    models = st.multiselect("Please choose the model",
                            all_models,
                            default=list(all_models),
                            )

    # Keep the selected models and drop rows with any missing score (possible
    # after the outer join). .copy() so later column writes hit a real frame.
    chart_data = chart_data[chart_data['model_show'].isin(models)].dropna(axis=0).copy()
    if chart_data.empty:
        return

    # = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
    # Show Table
    with st.container():
        st.markdown('##### TABLE')

        model_link = {key.strip(): val for key, val in zip(info_df['Proper Display Name'], info_df['Link'])}
        chart_data['model_link'] = chart_data['model_show'].map(model_link)

        # Column order: display name, Average, per-dataset scores, link (last,
        # because it was added last).
        remaining_columns = [col for col in chart_data.columns
                             if col not in ('Model', 'model_show', 'Average')]
        chart_data_table = chart_data[['model_show', 'Average'] + remaining_columns].copy()

        # Round the Average column to 3 decimal places.
        chart_data_table['Average'] = chart_data_table['Average'].astype(float).round(3)

        # 'wer' sorts ascending (lower is better); every other metric descending.
        ascending = last_metric in ['wer']
        chart_data_table = chart_data_table.sort_values(
            by=['Average'],
            ascending=ascending,
        ).reset_index(drop=True)

        # Format every numeric column (all but the name and link columns).
        number_formats = {
            col: "{:.3f}"
            for col in chart_data_table.columns
            if col not in ('model_show', 'model_link')
        }
        styled_df = chart_data_table.style.format(number_formats).apply(
            _highlight_top_average, axis=None
        )

        st.dataframe(
            styled_df,
            column_config={
                'model_show': 'Model',
                'Average': {'alignment': 'left'},
                "model_link": st.column_config.LinkColumn(
                    "Model Link",
                ),
            },
            hide_index=True,
            use_container_width=True,
        )

    # Only report the last metric.
    st.markdown(f'###### Metric: {metrics_info[last_metric]}')