Spaces:
Runtime error
Runtime error
Alexander Seifert
commited on
Commit
·
bb162b6
1
Parent(s):
408486e
add randomize_sample option
Browse files- README.md +1 -1
- src/data.py +6 -2
- src/load.py +4 -1
- src/subpages/home.py +9 -1
README.md
CHANGED
@@ -19,7 +19,7 @@ Error Analysis is an important but often overlooked part of the data science pro
|
|
19 |
|
20 |
### Activations
|
21 |
|
22 |
-
A group of neurons
|
23 |
|
24 |
|
25 |
### Embeddings
|
|
|
19 |
|
20 |
### Activations
|
21 |
|
22 |
+
A group of neurons tends to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model.
|
23 |
|
24 |
|
25 |
### Embeddings
|
src/data.py
CHANGED
@@ -11,7 +11,9 @@ from src.utils import device, tokenizer_hash_funcs
|
|
11 |
|
12 |
|
13 |
@st.cache(allow_output_mutation=True)
|
14 |
-
def get_data(
|
|
|
|
|
15 |
"""Loads a Dataset from the HuggingFace hub (if not already loaded).
|
16 |
|
17 |
Uses `datasets.load_dataset` to load the dataset (see its documentation for additional details).
|
@@ -25,7 +27,9 @@ def get_data(ds_name: str, config_name: str, split_name: str, split_sample_size:
|
|
25 |
Returns:
|
26 |
Dataset: A Dataset object.
|
27 |
"""
|
28 |
-
ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(
|
|
|
|
|
29 |
split = ds[split_name].select(range(split_sample_size))
|
30 |
return split
|
31 |
|
|
|
11 |
|
12 |
|
13 |
@st.cache(allow_output_mutation=True)
|
14 |
+
def get_data(
|
15 |
+
ds_name: str, config_name: str, split_name: str, split_sample_size: int, randomize_sample: bool
|
16 |
+
) -> Dataset:
|
17 |
"""Loads a Dataset from the HuggingFace hub (if not already loaded).
|
18 |
|
19 |
Uses `datasets.load_dataset` to load the dataset (see its documentation for additional details).
|
|
|
27 |
Returns:
|
28 |
Dataset: A Dataset object.
|
29 |
"""
|
30 |
+
ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(
|
31 |
+
seed=0 if randomize_sample else None
|
32 |
+
) # type: ignore
|
33 |
split = ds[split_name].select(range(split_sample_size))
|
34 |
return split
|
35 |
|
src/load.py
CHANGED
@@ -37,6 +37,7 @@ def load_context(
|
|
37 |
ds_config_name: str,
|
38 |
ds_split_name: str,
|
39 |
split_sample_size: int,
|
|
|
40 |
**kw_args,
|
41 |
) -> Context:
|
42 |
"""Utility method loading (almost) everything we need for the application.
|
@@ -63,7 +64,9 @@ def load_context(
|
|
63 |
collator = get_collator(tokenizer)
|
64 |
|
65 |
# load data related stuff
|
66 |
-
split: Dataset = get_data(
|
|
|
|
|
67 |
tags = split.features["ner_tags"].feature
|
68 |
split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
|
69 |
|
|
|
37 |
ds_config_name: str,
|
38 |
ds_split_name: str,
|
39 |
split_sample_size: int,
|
40 |
+
randomize_sample: bool,
|
41 |
**kw_args,
|
42 |
) -> Context:
|
43 |
"""Utility method loading (almost) everything we need for the application.
|
|
|
64 |
collator = get_collator(tokenizer)
|
65 |
|
66 |
# load data related stuff
|
67 |
+
split: Dataset = get_data(
|
68 |
+
ds_name, ds_config_name, ds_split_name, split_sample_size, randomize_sample
|
69 |
+
)
|
70 |
tags = split.features["ner_tags"].feature
|
71 |
split_encoded, word_ids, ids = encode_dataset(split, tokenizer)
|
72 |
|
src/subpages/home.py
CHANGED
@@ -45,6 +45,7 @@ class HomePage(Page):
|
|
45 |
"ds_split_name": "validation",
|
46 |
"ds_config_name": _CONFIG_NAME,
|
47 |
"split_sample_size": 512,
|
|
|
48 |
}
|
49 |
|
50 |
def render(self, context: Optional[Context] = None):
|
@@ -118,11 +119,18 @@ class HomePage(Page):
|
|
118 |
key="split_sample_size",
|
119 |
help="Sample size for the split, speeds up processing inside streamlit",
|
120 |
)
|
|
|
|
|
|
|
|
|
|
|
121 |
# breakpoint()
|
122 |
# st.form_submit_button("Submit")
|
123 |
st.form_submit_button("Load Model & Data")
|
124 |
|
125 |
-
split = get_data(
|
|
|
|
|
126 |
labels = list(
|
127 |
set([n.split("-")[1] for n in split.features["ner_tags"].feature.names if n != "O"])
|
128 |
)
|
|
|
45 |
"ds_split_name": "validation",
|
46 |
"ds_config_name": _CONFIG_NAME,
|
47 |
"split_sample_size": 512,
|
48 |
+
"randomize_sample": True,
|
49 |
}
|
50 |
|
51 |
def render(self, context: Optional[Context] = None):
|
|
|
119 |
key="split_sample_size",
|
120 |
help="Sample size for the split, speeds up processing inside streamlit",
|
121 |
)
|
122 |
+
randomize_sample = st.checkbox(
|
123 |
+
"Randomize sample",
|
124 |
+
key="randomize_sample",
|
125 |
+
help="Whether to randomize the sample",
|
126 |
+
)
|
127 |
# breakpoint()
|
128 |
# st.form_submit_button("Submit")
|
129 |
st.form_submit_button("Load Model & Data")
|
130 |
|
131 |
+
split = get_data(
|
132 |
+
ds_name, ds_config_name, ds_split_name, split_sample_size, randomize_sample # type: ignore
|
133 |
+
)
|
134 |
labels = list(
|
135 |
set([n.split("-")[1] for n in split.features["ner_tags"].feature.names if n != "O"])
|
136 |
)
|