jordyvl commited on
Commit
212a8e9
·
1 Parent(s): efa98d8

adding local app - to be integrated with public app

Browse files
Files changed (2) hide show
  1. app.py +3 -146
  2. local_app.py +150 -0
app.py CHANGED
@@ -1,150 +1,7 @@
1
  import evaluate
2
  import numpy as np
3
- import pandas as pd
4
- import ast
5
- import json
6
- import gradio as gr
7
- from evaluate.utils import launch_gradio_widget
8
- from ece import ECE
9
-
10
- import matplotlib.pyplot as plt
11
- import seaborn as sns
12
- sns.set_style('white')
13
- sns.set_context("paper", font_scale=1) # 2
14
- # plt.rcParams['figure.figsize'] = [10, 7]
15
- plt.rcParams['figure.dpi'] = 300
16
- plt.switch_backend('agg') #; https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
17
-
18
- sliders = [
19
- gr.Slider(0, 100, value=10, label="n_bins"),
20
- gr.Slider(0, 100, value=None, label="bin_range", visible=False), #DEV: need to have a double slider
21
- gr.Dropdown(choices=["equal-range", "equal-mass"], value="equal-range", label="scheme"),
22
- gr.Dropdown(choices=["upper-edge", "center"], value="upper-edge", label="proxy"),
23
- gr.Dropdown(choices=[1, 2, np.inf], value=1, label="p"),
24
- ]
25
-
26
- slider_defaults = [slider.value for slider in sliders]
27
-
28
- # example data
29
- df = dict()
30
- df["predictions"] = [[0.6, 0.2, 0.2], [0, 0.95, 0.05], [0.7, 0.1, 0.2]]
31
- df["references"] = [0, 1, 2]
32
-
33
- component = gr.inputs.Dataframe(
34
- headers=["predictions", "references"], col_count=2, datatype="number", type="pandas"
35
- )
36
-
37
- component.value = [
38
- [[0.6, 0.2, 0.2], 0],
39
- [[0.7, 0.1, 0.2], 2],
40
- [[0, 0.95, 0.05], 1],
41
- ]
42
- sample_data = [[component] + slider_defaults] ##json.dumps(df)
43
-
44
-
45
- metric = ECE()
46
- # module = evaluate.load("jordyvl/ece")
47
- # launch_gradio_widget(module)
48
-
49
- """
50
- Switch inputs and compute_fn
51
- """
52
-
53
- def reliability_plot(results):
54
- fig = plt.figure()
55
- ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
56
- ax2 = plt.subplot2grid((3, 1), (2, 0))
57
-
58
- n_bins = len(results["y_bar"])
59
- bin_range = [
60
- results["y_bar"][0] - results["y_bar"][0],
61
- results["y_bar"][-1],
62
- ] # np.linspace(0, 1, n_bins)
63
- # if upper edge then minus binsize; same for center [but half]
64
-
65
- ranged = np.linspace(bin_range[0], bin_range[1], n_bins)
66
- ax1.plot(
67
- ranged,
68
- ranged,
69
- color="darkgreen",
70
- ls="dotted",
71
- label="Perfect",
72
- )
73
- # ax1.plot(results["y_bar"], results["y_bar"], color="darkblue", label="Perfect")
74
-
75
- anindices = np.where(~np.isnan(results["p_bar"][:-1]))[0]
76
- bin_freqs = np.zeros(n_bins)
77
- bin_freqs[anindices] = results["bin_freq"]
78
- ax2.hist(results["y_bar"], results["y_bar"], weights=bin_freqs)
79
 
80
- #widths = np.diff(results["y_bar"])
81
- for j, bin in enumerate(results["y_bar"]):
82
- perfect = results["y_bar"][j]
83
- empirical = results["p_bar"][j]
84
 
85
- if np.isnan(empirical):
86
- continue
87
-
88
- ax1.bar([perfect], height=[empirical], width=-ranged[j], align="edge", color="lightblue")
89
-
90
- if perfect == empirical:
91
- continue
92
-
93
- acc_plt = ax2.axvline(
94
- x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy"
95
- )
96
- conf_plt = ax2.axvline(
97
- x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence"
98
- )
99
- ax2.legend(handles=[acc_plt, conf_plt])
100
-
101
- #Bin differences
102
- ax1.set_ylabel("Conditional Expectation")
103
- ax1.set_ylim([-0.05, 1.05]) #respective to bin range
104
- ax1.legend(loc="lower right")
105
- ax1.set_title("Reliability Diagram")
106
-
107
- #Bin frequencies
108
- ax2.set_xlabel("Confidence")
109
- ax2.set_ylabel("Count")
110
- ax2.legend(loc="upper left")#, ncol=2
111
- plt.tight_layout()
112
- return fig
113
-
114
- def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
115
- # DEV: check on invalid datatypes with better warnings
116
-
117
- if isinstance(data, pd.DataFrame):
118
- data.dropna(inplace=True)
119
-
120
- predictions = [
121
- ast.literal_eval(prediction) if not isinstance(prediction, list) else prediction
122
- for prediction in data["predictions"]
123
- ]
124
- references = [reference for reference in data["references"]]
125
-
126
- results = metric._compute(
127
- predictions,
128
- references,
129
- n_bins=n_bins,
130
- # bin_range=None,#not needed
131
- scheme=scheme,
132
- proxy=proxy,
133
- p=p,
134
- detail=True,
135
- )
136
-
137
- plot = reliability_plot(results)
138
- return results["ECE"], plot #plt.gcf()
139
-
140
-
141
- outputs = [gr.outputs.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]
142
-
143
- iface = gr.Interface(
144
- fn=compute_and_plot,
145
- inputs=[component] + sliders,
146
- outputs=outputs,
147
- description=metric.info.description,
148
- article=metric.info.citation,
149
- # examples=sample_data; # ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
150
- ).launch()
 
1
  import evaluate
2
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
 
 
 
 
4
 
5
+ from evaluate.utils import launch_gradio_widget
6
+ module = evaluate.load("jordyvl/ece")
7
+ launch_gradio_widget(module)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
local_app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import numpy as np
3
+ import pandas as pd
4
+ import ast
5
+ import json
6
+ import gradio as gr
7
+ from evaluate.utils import launch_gradio_widget
8
+ from ece import ECE
9
+
10
+ import matplotlib.pyplot as plt
11
+ import seaborn as sns
12
+ sns.set_style('white')
13
+ sns.set_context("paper", font_scale=1) # 2
14
+ # plt.rcParams['figure.figsize'] = [10, 7]
15
+ plt.rcParams['figure.dpi'] = 300
16
+ plt.switch_backend('agg') #; https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
17
+
18
+ sliders = [
19
+ gr.Slider(0, 100, value=10, label="n_bins"),
20
+ gr.Slider(0, 100, value=None, label="bin_range", visible=False), #DEV: need to have a double slider
21
+ gr.Dropdown(choices=["equal-range", "equal-mass"], value="equal-range", label="scheme"),
22
+ gr.Dropdown(choices=["upper-edge", "center"], value="upper-edge", label="proxy"),
23
+ gr.Dropdown(choices=[1, 2, np.inf], value=1, label="p"),
24
+ ]
25
+
26
+ slider_defaults = [slider.value for slider in sliders]
27
+
28
+ # example data
29
+ df = dict()
30
+ df["predictions"] = [[0.6, 0.2, 0.2], [0, 0.95, 0.05], [0.7, 0.1, 0.2]]
31
+ df["references"] = [0, 1, 2]
32
+
33
+ component = gr.inputs.Dataframe(
34
+ headers=["predictions", "references"], col_count=2, datatype="number", type="pandas"
35
+ )
36
+
37
+ component.value = [
38
+ [[0.6, 0.2, 0.2], 0],
39
+ [[0.7, 0.1, 0.2], 2],
40
+ [[0, 0.95, 0.05], 1],
41
+ ]
42
+ sample_data = [[component] + slider_defaults] ##json.dumps(df)
43
+
44
+
45
+ metric = ECE()
46
+ # module = evaluate.load("jordyvl/ece")
47
+ # launch_gradio_widget(module)
48
+
49
+ """
50
+ Switch inputs and compute_fn
51
+ """
52
+
53
+ def reliability_plot(results):
54
+ fig = plt.figure()
55
+ ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
56
+ ax2 = plt.subplot2grid((3, 1), (2, 0))
57
+
58
+ n_bins = len(results["y_bar"])
59
+ bin_range = [
60
+ results["y_bar"][0] - results["y_bar"][0],
61
+ results["y_bar"][-1],
62
+ ] # np.linspace(0, 1, n_bins)
63
+ # if upper edge then minus binsize; same for center [but half]
64
+
65
+ ranged = np.linspace(bin_range[0], bin_range[1], n_bins)
66
+ ax1.plot(
67
+ ranged,
68
+ ranged,
69
+ color="darkgreen",
70
+ ls="dotted",
71
+ label="Perfect",
72
+ )
73
+ # ax1.plot(results["y_bar"], results["y_bar"], color="darkblue", label="Perfect")
74
+
75
+ anindices = np.where(~np.isnan(results["p_bar"][:-1]))[0]
76
+ bin_freqs = np.zeros(n_bins)
77
+ bin_freqs[anindices] = results["bin_freq"]
78
+ ax2.hist(results["y_bar"], results["y_bar"], weights=bin_freqs)
79
+
80
+ #widths = np.diff(results["y_bar"])
81
+ for j, bin in enumerate(results["y_bar"]):
82
+ perfect = results["y_bar"][j]
83
+ empirical = results["p_bar"][j]
84
+
85
+ if np.isnan(empirical):
86
+ continue
87
+
88
+ ax1.bar([perfect], height=[empirical], width=-ranged[j], align="edge", color="lightblue")
89
+
90
+ if perfect == empirical:
91
+ continue
92
+
93
+ acc_plt = ax2.axvline(
94
+ x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy"
95
+ )
96
+ conf_plt = ax2.axvline(
97
+ x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence"
98
+ )
99
+ ax2.legend(handles=[acc_plt, conf_plt])
100
+
101
+ #Bin differences
102
+ ax1.set_ylabel("Conditional Expectation")
103
+ ax1.set_ylim([-0.05, 1.05]) #respective to bin range
104
+ ax1.legend(loc="lower right")
105
+ ax1.set_title("Reliability Diagram")
106
+
107
+ #Bin frequencies
108
+ ax2.set_xlabel("Confidence")
109
+ ax2.set_ylabel("Count")
110
+ ax2.legend(loc="upper left")#, ncol=2
111
+ plt.tight_layout()
112
+ return fig
113
+
114
+ def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
115
+ # DEV: check on invalid datatypes with better warnings
116
+
117
+ if isinstance(data, pd.DataFrame):
118
+ data.dropna(inplace=True)
119
+
120
+ predictions = [
121
+ ast.literal_eval(prediction) if not isinstance(prediction, list) else prediction
122
+ for prediction in data["predictions"]
123
+ ]
124
+ references = [reference for reference in data["references"]]
125
+
126
+ results = metric._compute(
127
+ predictions,
128
+ references,
129
+ n_bins=n_bins,
130
+ # bin_range=None,#not needed
131
+ scheme=scheme,
132
+ proxy=proxy,
133
+ p=p,
134
+ detail=True,
135
+ )
136
+
137
+ plot = reliability_plot(results)
138
+ return results["ECE"], plot #plt.gcf()
139
+
140
+
141
+ outputs = [gr.outputs.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]
142
+
143
+ iface = gr.Interface(
144
+ fn=compute_and_plot,
145
+ inputs=[component] + sliders,
146
+ outputs=outputs,
147
+ description=metric.info.description,
148
+ article=metric.info.citation,
149
+ # examples=sample_data; # ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
150
+ ).launch()