File size: 8,941 Bytes
b0e6781
 
 
 
fb2bc19
b0e6781
 
1d32376
fb2bc19
b0e6781
4c1d731
 
 
 
 
 
b0e6781
29fc06d
b0e6781
62dd38d
b0e6781
1d32376
b0e6781
2c6e148
b0e6781
 
1d32376
 
 
 
 
5a03d31
 
fcedbb9
 
 
 
 
 
 
5a03d31
 
fb2bc19
 
4c1d731
 
 
fb2bc19
 
f3cadf1
2e7bc8b
 
fb2bc19
f3cadf1
29fc06d
b0e6781
4c1d731
2e7bc8b
b0e6781
 
4c1d731
b0e6781
4c1d731
b0e6781
fb2bc19
4c1d731
 
fb2bc19
 
 
f3cadf1
 
29fc06d
 
4c1d731
f92272f
2e7bc8b
 
101c142
 
 
 
 
bd0c4d1
 
101c142
 
 
2e7bc8b
5792938
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e7bc8b
 
101c142
 
 
 
2e7bc8b
 
 
 
 
101c142
 
f92272f
 
 
101c142
 
 
c751340
f3cadf1
2e7bc8b
f3cadf1
29fc06d
2e7bc8b
f3cadf1
 
 
 
 
 
 
2e7bc8b
 
4c1d731
 
 
 
f3cadf1
4c1d731
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3cadf1
b0e6781
4c1d731
b0e6781
4c1d731
 
 
 
 
 
 
 
 
 
 
1d32376
 
4c1d731
1d32376
 
 
 
 
4c1d731
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import streamlit as st
import pandas as pd
import numpy as np
from streamlit_echarts import st_echarts
from streamlit.components.v1 import html
# from PIL import Image 
from app.show_examples import *
from app.content import *
import pandas as pd

from model_information import get_dataframe



# Model metadata table, loaded once at import time. `draw` reads its
# 'Original Name', 'Proper Display Name' and 'Link' columns to map raw model
# names to display names and links.
info_df = get_dataframe()


# Display names of datasets that the results table ranks in ascending order
# (presumably error-rate metrics such as WER, where lower is better — confirm).
# Every other dataset is ranked in descending order.
_LOWER_IS_BETTER_DATASETS = frozenset({
    'LibriSpeech-Clean',
    'LibriSpeech-Other',
    'CommonVoice-15-EN',
    'Peoples-Speech',
    'GigaSpeech-1',
    'Earnings-21',
    'Earnings-22',
    'TED-LIUM-3',
    'TED-LIUM-3-Long',
    'Aishell-ASR-ZH',
    'IMDA-Part1-ASR',
    'IMDA-Part2-ASR',
    'IMDA-Part3-30s-ASR',
    'IMDA-Part4-30s-ASR',
    'IMDA-Part5-30s-ASR',
    'IMDA-Part6-30s-ASR',
})


def _highlight_top_score(frame):
    """Return a style frame that shades the score cell of the first row.

    Used with ``Styler.apply(axis=None)``: the returned frame has the same
    shape as *frame*, with a CSS string only at position (0, 1).
    """
    styles = pd.DataFrame('', index=frame.index, columns=frame.columns)
    styles.iloc[0, 1] = 'background-color: #b0c1d7'
    return styles


def _chart_axis_bounds(values):
    """Return ``(min, max)`` y-axis limits for the bar chart.

    Values outside ``[Q1 - 1.5*IQR, Q3 + 1.5*IQR]`` are ignored, which keeps a
    single extreme score from stretching the axis. The remaining range is then
    padded by 10% of each bound's magnitude. Using the magnitude means negative
    bounds are padded outward too; the original ``x - 0.1*x`` pulled them inward.

    Args:
        values: Non-empty numeric Series without NaNs.

    Returns:
        Tuple of the lower and upper axis limits, rounded to 3 decimals.
    """
    q1 = values.quantile(0.25)
    q3 = values.quantile(0.75)
    iqr = q3 - q1
    inliers = values[(values >= q1 - 1.5 * iqr) & (values <= q3 + 1.5 * iqr)]
    low, high = inliers.min(), inliers.max()
    return round(low - 0.1 * abs(low), 3), round(high + 0.1 * abs(high), 3)


def _render_table(chart_data, score_column):
    """Render the ranked score table, highlighting the top-ranked score.

    Args:
        chart_data: Non-empty frame with 'model_show' and *score_column*.
            It is not modified.
        score_column: Display name of the dataset's score column.
    """
    with st.container():
        st.markdown('##### TABLE')

        # Display name -> model link; models missing from info_df get NaN.
        model_link = {key.strip(): val for key, val in zip(info_df['Proper Display Name'], info_df['Link'])}

        # Build a fresh frame rather than adding a column to a filtered slice.
        table = chart_data[['model_show', score_column]].assign(
            model_link=chart_data['model_show'].map(model_link)
        )

        ascending = score_column in _LOWER_IS_BETTER_DATASETS
        table = table.sort_values(by=score_column, ascending=ascending).reset_index(drop=True)

        styled_df = table.style.format(
            {score_column: "{:.3f}"}
        ).apply(
            _highlight_top_score, axis=None
        )

        st.dataframe(
            styled_df,
            column_config={
                'model_show': 'Model',
                score_column: {'alignment': 'left'},
                "model_link": st.column_config.LinkColumn(
                    "Model Link",
                ),
            },
            hide_index=True,
            use_container_width=True,
        )


def _render_chart(chart_data, score_column, dataset_name):
    """Render a 'Show Chart' toggle button and, when toggled on, a bar chart.

    Args:
        chart_data: Non-empty frame, already sorted, with 'model_show' and
            *score_column*.
        score_column: Display name of the dataset's score column.
        dataset_name: Name used for the chart series and its legend entry.
    """
    # The toggle state lives in session_state so it survives Streamlit reruns.
    if "show_chart" not in st.session_state:
        st.session_state.show_chart = False

    if st.button("Show Chart"):
        st.session_state.show_chart = not st.session_state.show_chart

    if not st.session_state.show_chart:
        return

    with st.container():
        st.markdown('##### CHART')

        min_value, max_value = _chart_axis_bounds(chart_data[score_column])
        series_name = f"{dataset_name}"

        options = {
            "tooltip": {
                "trigger": "axis",
                "axisPointer": {"type": "cross", "label": {"backgroundColor": "#6a7985"}},
                "triggerOn": 'mousemove',
            },
            # Legend entries are matched to series by name, so this must equal
            # the series name below (it previously listed a non-existent series).
            "legend": {"data": [series_name]},
            "toolbox": {"feature": {"saveAsImage": {}}},
            "grid": {"left": "3%", "right": "4%", "bottom": "3%", "containLabel": True},
            "xAxis": [
                {
                    "type": "category",
                    "boundaryGap": True,
                    "triggerEvent": True,
                    "data": chart_data['model_show'].tolist(),
                }
            ],
            "yAxis": [{
                "type": "value",
                "min": min_value,
                "max": max_value,
                "boundaryGap": True,
            }],
            "series": [{
                "name": series_name,
                "type": "bar",
                "data": chart_data[score_column].tolist(),
            }],
        }

        events = {
            "click": "function(params) { return params.value }"
        }

        st_echarts(options=options, events=events, height="500px")


def _render_examples():
    """Render a 'Show Examples' toggle button (the examples view is a placeholder)."""
    if "show_examples" not in st.session_state:
        st.session_state.show_examples = False

    if st.button("Show Examples"):
        st.session_state.show_examples = not st.session_state.show_examples

    if st.session_state.show_examples:
        # TODO: render real examples, e.g. via show_examples(category_name,
        # dataset_name, model_names, display_model_names) from app.show_examples.
        st.markdown('To be implemented')


def draw(folder_name, category_name, dataset_name, metrics, cus_sort=True):
    """Render the leaderboard view (table, chart toggle, examples toggle) for one dataset.

    Reads ``./results_organized/{metrics}/{category_name.lower()}.csv``. It
    keeps the 'Model' column and the dataset's score column, then lets the
    user pick which models to show. If no selected model has a score, only
    the model picker is rendered.

    Args:
        folder_name: Unused; kept for call-site compatibility.
        category_name: Category whose CSV file is loaded.
        dataset_name: Dataset to display. It is matched to a CSV column after
            replacing '-' with '_' and lower-casing.
        metrics: Sub-folder of ``./results_organized`` that holds the CSVs.
        cus_sort: Sort order of the chart bars (True = ascending). The table
            ordering is chosen by ``_LOWER_IS_BETTER_DATASETS`` instead.
    """
    folder = f"./results_organized/{metrics}/"

    # Load the results and keep only the model column and this dataset's scores.
    chart_data = pd.read_csv(f'{folder}/{category_name.lower()}.csv').round(3)
    raw_column = dataset_name.replace('-', '_').lower()
    chart_data = chart_data[['Model', raw_column]]

    # Rename the score column to its display name.
    score_column = dataname_column_rename_in_table[raw_column]
    chart_data = chart_data.rename(columns=dataname_column_rename_in_table)

    st.markdown("""
                <style>
                .stMultiSelect [data-baseweb=select] span {
                    max-width: 800px;
                    font-size: 0.9rem;
                    background-color: #3C6478 !important; /* Background color for selected items */
                    color: white; /* Change text color */
                }
                </style>
                """, unsafe_allow_html=True)

    # Map raw model names to display names, falling back to the raw name.
    display_model_names = {
        key.strip(): val.strip()
        for key, val in zip(info_df['Original Name'], info_df['Proper Display Name'])
    }
    chart_data['model_show'] = chart_data['Model'].map(lambda x: display_model_names.get(x, x))

    all_models = sorted(chart_data['model_show'].tolist())
    models = st.multiselect("Please choose the model",
                            all_models,
                            default=all_models,
                            )

    chart_data = chart_data[chart_data['model_show'].isin(models)]
    chart_data = chart_data.sort_values(by=[score_column], ascending=cus_sort).dropna(axis=0)

    if chart_data.empty:
        return

    _render_table(chart_data, score_column)
    _render_chart(chart_data, score_column, dataset_name)
    _render_examples()