Spaces:
Runtime error
Runtime error
Alexander Seifert
commited on
Commit
·
554bac5
1
Parent(s):
8778b89
improve docs
Browse files- src/data.py +47 -14
- src/load.py +2 -2
- src/subpages/attention.py +3 -14
- src/subpages/hidden_states.py +37 -1
- src/utils.py +41 -20
src/data.py
CHANGED
@@ -46,7 +46,16 @@ def get_collator(tokenizer) -> DataCollatorForTokenClassification:
|
|
46 |
return DataCollatorForTokenClassification(tokenizer)
|
47 |
|
48 |
|
49 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
word_ids = []
|
51 |
wid = -1
|
52 |
tokens = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]
|
@@ -65,16 +74,27 @@ def create_word_ids_from_tokens(tokenizer, input_ids: list[int]):
|
|
65 |
return word_ids
|
66 |
|
67 |
|
68 |
-
def
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
labels = []
|
71 |
wids = []
|
72 |
|
73 |
-
for idx, label in enumerate(
|
74 |
try:
|
75 |
word_ids = tokenized_inputs.word_ids(batch_index=idx)
|
76 |
except ValueError:
|
77 |
-
word_ids =
|
|
|
|
|
78 |
previous_word_idx = None
|
79 |
label_ids = []
|
80 |
for word_idx in word_ids:
|
@@ -119,7 +139,7 @@ def encode_dataset(split: Dataset, tokenizer):
|
|
119 |
remove_columns = split.column_names
|
120 |
ids = split["id"]
|
121 |
split = split.map(
|
122 |
-
partial(
|
123 |
batched=True,
|
124 |
remove_columns=remove_columns,
|
125 |
)
|
@@ -128,6 +148,18 @@ def encode_dataset(split: Dataset, tokenizer):
|
|
128 |
|
129 |
|
130 |
def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
# Convert dict of lists to list of dicts suitable for data collator
|
132 |
features = [dict(zip(batch, t)) for t in zip(*batch.values())]
|
133 |
|
@@ -159,19 +191,20 @@ def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
|
|
159 |
return {"losses": loss, "preds": preds, "hidden_states": hidden_states}
|
160 |
|
161 |
|
162 |
-
def
|
163 |
-
"""
|
164 |
|
165 |
Args:
|
166 |
-
split_encoded (Dataset):
|
167 |
-
model
|
168 |
-
tokenizer
|
169 |
-
collator
|
170 |
-
tags
|
171 |
|
172 |
Returns:
|
173 |
-
pd.DataFrame:
|
174 |
"""
|
|
|
175 |
split_encoded = split_encoded.map(
|
176 |
partial(
|
177 |
forward_pass_with_label,
|
|
|
46 |
return DataCollatorForTokenClassification(tokenizer)
|
47 |
|
48 |
|
49 |
+
def create_word_ids_from_input_ids(tokenizer, input_ids: list[int]) -> list[int]:
|
50 |
+
"""Takes a list of input_ids and return corresponding word_ids
|
51 |
+
|
52 |
+
Args:
|
53 |
+
tokenizer: The tokenizer that was used to obtain the input ids.
|
54 |
+
input_ids (list[int]): List of token ids.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
list[int]: Word ids corresponding to the input ids.
|
58 |
+
"""
|
59 |
word_ids = []
|
60 |
wid = -1
|
61 |
tokens = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]
|
|
|
74 |
return word_ids
|
75 |
|
76 |
|
77 |
+
def tokenize(batch, tokenizer) -> dict:
|
78 |
+
"""Tokenizes a batch of examples.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
batch: The examples to tokenize
|
82 |
+
tokenizer: The tokenizer to use
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
dict: The tokenized batch
|
86 |
+
"""
|
87 |
+
tokenized_inputs = tokenizer(batch["tokens"], truncation=True, is_split_into_words=True)
|
88 |
labels = []
|
89 |
wids = []
|
90 |
|
91 |
+
for idx, label in enumerate(batch["ner_tags"]):
|
92 |
try:
|
93 |
word_ids = tokenized_inputs.word_ids(batch_index=idx)
|
94 |
except ValueError:
|
95 |
+
word_ids = create_word_ids_from_input_ids(
|
96 |
+
tokenizer, tokenized_inputs["input_ids"][idx]
|
97 |
+
)
|
98 |
previous_word_idx = None
|
99 |
label_ids = []
|
100 |
for word_idx in word_ids:
|
|
|
139 |
remove_columns = split.column_names
|
140 |
ids = split["id"]
|
141 |
split = split.map(
|
142 |
+
partial(tokenize, tokenizer=tokenizer),
|
143 |
batched=True,
|
144 |
remove_columns=remove_columns,
|
145 |
)
|
|
|
148 |
|
149 |
|
150 |
def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
|
151 |
+
"""Runs the forward pass for a batch of examples.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
batch: The batch to process
|
155 |
+
model: The model to process the batch with
|
156 |
+
collator: A data collator
|
157 |
+
num_classes (int): Number of classes
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
dict: a dictionary containing `losses`, `preds` and `hidden_states`
|
161 |
+
"""
|
162 |
+
|
163 |
# Convert dict of lists to list of dicts suitable for data collator
|
164 |
features = [dict(zip(batch, t)) for t in zip(*batch.values())]
|
165 |
|
|
|
191 |
return {"losses": loss, "preds": preds, "hidden_states": hidden_states}
|
192 |
|
193 |
|
194 |
+
def predict(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
|
195 |
+
"""Generates predictions for a given dataset split and returns the results as a dataframe.
|
196 |
|
197 |
Args:
|
198 |
+
split_encoded (Dataset): The dataset to process
|
199 |
+
model: The model to process the dataset with
|
200 |
+
tokenizer: The tokenizer to process the dataset with
|
201 |
+
collator: The data collator to use
|
202 |
+
tags: The tags used in the dataset
|
203 |
|
204 |
Returns:
|
205 |
+
pd.DataFrame: A dataframe containing token-level predictions.
|
206 |
"""
|
207 |
+
|
208 |
split_encoded = split_encoded.map(
|
209 |
partial(
|
210 |
forward_pass_with_label,
|
src/load.py
CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
|
|
4 |
import streamlit as st
|
5 |
from datasets import Dataset # type: ignore
|
6 |
|
7 |
-
from src.data import encode_dataset, get_collator, get_data,
|
8 |
from src.model import get_encoder, get_model, get_tokenizer
|
9 |
from src.subpages import Context
|
10 |
from src.utils import align_sample, device, explode_df
|
@@ -68,7 +68,7 @@ def load_context(
|
|
68 |
split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
|
69 |
|
70 |
# transform into dataframe
|
71 |
-
df =
|
72 |
df["word_ids"] = word_ids
|
73 |
df["ids"] = ids
|
74 |
|
|
|
4 |
import streamlit as st
|
5 |
from datasets import Dataset # type: ignore
|
6 |
|
7 |
+
from src.data import encode_dataset, get_collator, get_data, predict
|
8 |
from src.model import get_encoder, get_model, get_tokenizer
|
9 |
from src.subpages import Context
|
10 |
from src.utils import align_sample, device, explode_df
|
|
|
68 |
split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
|
69 |
|
70 |
# transform into dataframe
|
71 |
+
df = predict(split_encoded, model, tokenizer, collator, tags)
|
72 |
df["word_ids"] = word_ids
|
73 |
df["ids"] = ids
|
74 |
|
src/subpages/attention.py
CHANGED
@@ -7,7 +7,7 @@ from streamlit.components.v1 import html
|
|
7 |
|
8 |
from src.subpages.page import Context, Page # type: ignore
|
9 |
|
10 |
-
|
11 |
<script src="https://requirejs.org/docs/release/2.3.6/minified/require.js"></script>
|
12 |
<script>
|
13 |
var ecco_url = 'https://storage.googleapis.com/ml-intro/ecco/'
|
@@ -70,17 +70,6 @@ SETUP_HTML = """
|
|
70 |
<div id="basic"></div>
|
71 |
"""
|
72 |
|
73 |
-
JS_TEMPLATE = """requirejs(['basic', 'ecco'], function(basic, ecco){{
|
74 |
-
const viz_id = basic.init()
|
75 |
-
|
76 |
-
ecco.interactiveTokensAndFactorSparklines(viz_id, {}, {{
|
77 |
-
'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}}
|
78 |
-
}}
|
79 |
-
}})
|
80 |
-
}}, function (err) {{
|
81 |
-
console.log(err);
|
82 |
-
}})"""
|
83 |
-
|
84 |
|
85 |
@st.cache(allow_output_mutation=True)
|
86 |
def _load_ecco_model():
|
@@ -160,10 +149,10 @@ class AttentionPage(Page):
|
|
160 |
output = lm(inputs)
|
161 |
nmf = output.run_nmf(n_components=n_components, from_layer=from_layer, to_layer=to_layer)
|
162 |
data = nmf.explore(returnData=True)
|
163 |
-
|
164 |
const viz_id = basic.init()
|
165 |
ecco.interactiveTokensAndFactorSparklines(viz_id, {data}, {{ 'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}} }} }})
|
166 |
}}, function (err) {{
|
167 |
console.log(err);
|
168 |
}})</script>"""
|
169 |
-
html(
|
|
|
7 |
|
8 |
from src.subpages.page import Context, Page # type: ignore
|
9 |
|
10 |
+
_SETUP_HTML = """
|
11 |
<script src="https://requirejs.org/docs/release/2.3.6/minified/require.js"></script>
|
12 |
<script>
|
13 |
var ecco_url = 'https://storage.googleapis.com/ml-intro/ecco/'
|
|
|
70 |
<div id="basic"></div>
|
71 |
"""
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
@st.cache(allow_output_mutation=True)
|
75 |
def _load_ecco_model():
|
|
|
149 |
output = lm(inputs)
|
150 |
nmf = output.run_nmf(n_components=n_components, from_layer=from_layer, to_layer=to_layer)
|
151 |
data = nmf.explore(returnData=True)
|
152 |
+
_JS_TEMPLATE = f"""<script>requirejs(['basic', 'ecco'], function(basic, ecco){{
|
153 |
const viz_id = basic.init()
|
154 |
ecco.interactiveTokensAndFactorSparklines(viz_id, {data}, {{ 'hltrCFG': {{'tokenization_config': {{'token_prefix': '', 'partial_token_prefix': '##'}} }} }})
|
155 |
}}, function (err) {{
|
156 |
console.log(err);
|
157 |
}})</script>"""
|
158 |
+
html(_SETUP_HTML + _JS_TEMPLATE, height=800, scrolling=True)
|
src/subpages/hidden_states.py
CHANGED
@@ -10,7 +10,19 @@ from src.subpages.page import Context, Page
|
|
10 |
|
11 |
|
12 |
@st.cache
|
13 |
-
def reduce_dim_svd(X, n_iter, random_state=42):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from sklearn.decomposition import TruncatedSVD
|
15 |
|
16 |
svd = TruncatedSVD(n_components=2, n_iter=n_iter, random_state=random_state)
|
@@ -19,6 +31,17 @@ def reduce_dim_svd(X, n_iter, random_state=42):
|
|
19 |
|
20 |
@st.cache
|
21 |
def reduce_dim_pca(X, random_state=42):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
from sklearn.decomposition import PCA
|
23 |
|
24 |
return PCA(n_components=2, random_state=random_state).fit_transform(X)
|
@@ -26,6 +49,19 @@ def reduce_dim_pca(X, random_state=42):
|
|
26 |
|
27 |
@st.cache
|
28 |
def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
from umap import UMAP
|
30 |
|
31 |
return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
|
|
|
10 |
|
11 |
|
12 |
@st.cache
|
13 |
+
def reduce_dim_svd(X, n_iter: int, random_state=42):
|
14 |
+
"""Dimensionality reduction using truncated SVD (aka LSA).
|
15 |
+
|
16 |
+
This transformer performs linear dimensionality reduction by means of truncated singular value decomposition (SVD). Contrary to PCA, this estimator does not center the data before computing the singular value decomposition. This means it can work with sparse matrices efficiently.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
X: Training data
|
20 |
+
n_iter (int): Desired dimensionality of output data. Must be strictly less than the number of features.
|
21 |
+
random_state (int, optional): Used during randomized svd. Pass an int for reproducible results across multiple function calls. Defaults to 42.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
|
25 |
+
"""
|
26 |
from sklearn.decomposition import TruncatedSVD
|
27 |
|
28 |
svd = TruncatedSVD(n_components=2, n_iter=n_iter, random_state=random_state)
|
|
|
31 |
|
32 |
@st.cache
|
33 |
def reduce_dim_pca(X, random_state=42):
|
34 |
+
"""Principal component analysis (PCA).
|
35 |
+
|
36 |
+
Linear dimensionality reduction using Singular Value Decomposition of the data to project it to a lower dimensional space. The input data is centered but not scaled for each feature before applying the SVD.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
X: Training data
|
40 |
+
random_state (int, optional): Used when the 'arpack' or 'randomized' solvers are used. Pass an int for reproducible results across multiple function calls.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
|
44 |
+
"""
|
45 |
from sklearn.decomposition import PCA
|
46 |
|
47 |
return PCA(n_components=2, random_state=random_state).fit_transform(X)
|
|
|
49 |
|
50 |
@st.cache
|
51 |
def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
|
52 |
+
"""Uniform Manifold Approximation and Projection
|
53 |
+
|
54 |
+
Finds a low dimensional embedding of the data that approximates an underlying manifold.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
X: Training data
|
58 |
+
n_neighbors (int, optional): The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation. Larger values result in more global views of the manifold, while smaller values result in more local data being preserved. In general values should be in the range 2 to 100. Defaults to 5.
|
59 |
+
min_dist (float, optional): The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where nearby points on the manifold are drawn closer together, while larger values will result on a more even dispersal of points. The value should be set relative to the `spread` value, which determines the scale at which embedded points will be spread out. Defaults to 0.1.
|
60 |
+
metric (str, optional): The metric to use to compute distances in high dimensional space (see UMAP docs for options). Defaults to "euclidean".
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
|
64 |
+
"""
|
65 |
from umap import UMAP
|
66 |
|
67 |
return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
|
src/utils.py
CHANGED
@@ -34,6 +34,7 @@ classmap = {
|
|
34 |
|
35 |
def aggrid_interactive_table(df: pd.DataFrame) -> dict:
|
36 |
"""Creates an st-aggrid interactive table based on a dataframe.
|
|
|
37 |
Args:
|
38 |
df (pd.DataFrame]): Source dataframe
|
39 |
Returns:
|
@@ -60,6 +61,8 @@ def aggrid_interactive_table(df: pd.DataFrame) -> dict:
|
|
60 |
|
61 |
|
62 |
def explode_df(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
|
|
63 |
df_tokens = df.apply(pd.Series.explode)
|
64 |
if "losses" in df.columns:
|
65 |
df_tokens["losses"] = df_tokens["losses"].astype(float)
|
@@ -67,7 +70,7 @@ def explode_df(df: pd.DataFrame) -> pd.DataFrame:
|
|
67 |
|
68 |
|
69 |
def align_sample(row: pd.Series):
|
70 |
-
"""
|
71 |
|
72 |
columns = row.axes[0].to_list()
|
73 |
indices = [i for i, id in enumerate(row.word_ids) if id >= 0 and id != row.word_ids[i - 1]]
|
@@ -113,7 +116,17 @@ def align_sample(row: pd.Series):
|
|
113 |
hash_funcs=tokenizer_hash_funcs,
|
114 |
)
|
115 |
def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
|
116 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
tokens = tokenizer(text).tokens()
|
119 |
tokenized = tokenizer(text, return_tensors="pt")
|
@@ -137,21 +150,31 @@ def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
|
|
137 |
return explode_df(merged_df).reset_index().drop(columns=["index"])
|
138 |
|
139 |
|
140 |
-
def get_bg_color(label):
|
|
|
141 |
return st.session_state[f"color_{label}"]
|
142 |
|
143 |
|
144 |
-
def get_fg_color(
|
145 |
-
"""
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
|
150 |
return "black" if (yiq >= 128) else "white"
|
151 |
|
152 |
|
153 |
def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
|
154 |
-
"""
|
155 |
|
156 |
def colorize_row(row):
|
157 |
return [
|
@@ -175,6 +198,14 @@ def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
|
|
175 |
|
176 |
|
177 |
def htmlify_labeled_example(example: pd.DataFrame) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
html = []
|
179 |
|
180 |
for _, row in example.iterrows():
|
@@ -215,18 +246,8 @@ def htmlify_labeled_example(example: pd.DataFrame) -> str:
|
|
215 |
return " ".join(html)
|
216 |
|
217 |
|
218 |
-
def htmlify_example(example: pd.DataFrame) -> str:
|
219 |
-
corr_html = " ".join(
|
220 |
-
[
|
221 |
-
f", {row.tokens}" if row.labels == "B-COMMA" else row.tokens
|
222 |
-
for _, row in example.iterrows()
|
223 |
-
]
|
224 |
-
).strip()
|
225 |
-
return f"<em>{corr_html}</em>"
|
226 |
-
|
227 |
-
|
228 |
def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
|
229 |
-
"""
|
230 |
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
|
231 |
cmap = cm.get_cmap(cmap_name) # PiYG
|
232 |
rgba = cmap(norm(abs(value)))
|
|
|
34 |
|
35 |
def aggrid_interactive_table(df: pd.DataFrame) -> dict:
|
36 |
"""Creates an st-aggrid interactive table based on a dataframe.
|
37 |
+
|
38 |
Args:
|
39 |
df (pd.DataFrame]): Source dataframe
|
40 |
Returns:
|
|
|
61 |
|
62 |
|
63 |
def explode_df(df: pd.DataFrame) -> pd.DataFrame:
|
64 |
+
"""Takes a dataframe and explodes all the fields."""
|
65 |
+
|
66 |
df_tokens = df.apply(pd.Series.explode)
|
67 |
if "losses" in df.columns:
|
68 |
df_tokens["losses"] = df_tokens["losses"].astype(float)
|
|
|
70 |
|
71 |
|
72 |
def align_sample(row: pd.Series):
|
73 |
+
"""Uses word_ids to align all lists in a sample."""
|
74 |
|
75 |
columns = row.axes[0].to_list()
|
76 |
indices = [i for i, id in enumerate(row.word_ids) if id >= 0 and id != row.word_ids[i - 1]]
|
|
|
116 |
hash_funcs=tokenizer_hash_funcs,
|
117 |
)
|
118 |
def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
|
119 |
+
"""Tags a given text and creates an (exploded) DataFrame with the predicted labels and probabilities.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
text (str): The text to be processed
|
123 |
+
tokenizer: Tokenizer to use
|
124 |
+
model (_type_): Model to use
|
125 |
+
device (torch.device): The device we want pytorch to use for its calcultaions.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
pd.DataFrame: A data frame holding the tagged text.
|
129 |
+
"""
|
130 |
|
131 |
tokens = tokenizer(text).tokens()
|
132 |
tokenized = tokenizer(text, return_tensors="pt")
|
|
|
150 |
return explode_df(merged_df).reset_index().drop(columns=["index"])
|
151 |
|
152 |
|
153 |
+
def get_bg_color(label: str):
|
154 |
+
"""Retrieves a label's color from the session state."""
|
155 |
return st.session_state[f"color_{label}"]
|
156 |
|
157 |
|
158 |
+
def get_fg_color(bg_color_hex: str) -> str:
|
159 |
+
"""Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.
|
160 |
+
|
161 |
+
Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/
|
162 |
+
|
163 |
+
Args:
|
164 |
+
bg_color_hex (str): The background color given as a HEX stirng.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
str: Either "black" or "white".
|
168 |
+
"""
|
169 |
+
r = int(bg_color_hex[1:3], 16)
|
170 |
+
g = int(bg_color_hex[3:5], 16)
|
171 |
+
b = int(bg_color_hex[5:7], 16)
|
172 |
yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
|
173 |
return "black" if (yiq >= 128) else "white"
|
174 |
|
175 |
|
176 |
def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
|
177 |
+
"""Colorizes the errors in the dataframe."""
|
178 |
|
179 |
def colorize_row(row):
|
180 |
return [
|
|
|
198 |
|
199 |
|
200 |
def htmlify_labeled_example(example: pd.DataFrame) -> str:
|
201 |
+
"""Builds an HTML (string) representation of a single example.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
example (pd.DataFrame): The example to process.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
str: An HTML string representation of a single example.
|
208 |
+
"""
|
209 |
html = []
|
210 |
|
211 |
for _, row in example.iterrows():
|
|
|
246 |
return " ".join(html)
|
247 |
|
248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
|
250 |
+
"""Turns a value into a color using a color map."""
|
251 |
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
|
252 |
cmap = cm.get_cmap(cmap_name) # PiYG
|
253 |
rgba = cmap(norm(abs(value)))
|