Spaces:
Runtime error
Runtime error
Alexander Seifert
committed on
Commit
·
9556889
1
Parent(s):
2918df9
improve docs
Browse files
- src/data.py +53 -2
- src/load.py +14 -0
- src/subpages/attention.py +2 -2
- src/subpages/page.py +10 -0
src/data.py
CHANGED
@@ -11,7 +11,19 @@ from utils import device, tokenizer_hash_funcs
|
|
11 |
|
12 |
|
13 |
@st.cache(allow_output_mutation=True)
|
14 |
-
def get_data(ds_name, config_name, split_name, split_sample_size) -> Dataset:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(seed=0) # type: ignore
|
16 |
split = ds[split_name].select(range(split_sample_size))
|
17 |
return split
|
@@ -22,6 +34,14 @@ def get_data(ds_name, config_name, split_name, split_sample_size) -> Dataset:
|
|
22 |
hash_funcs=tokenizer_hash_funcs,
|
23 |
)
|
24 |
def get_collator(tokenizer) -> DataCollatorForTokenClassification:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
return DataCollatorForTokenClassification(tokenizer)
|
26 |
|
27 |
|
@@ -70,10 +90,29 @@ def tokenize_and_align_labels(examples, tokenizer):
|
|
70 |
|
71 |
|
72 |
def stringify_ner_tags(batch, tags):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
return {"ner_tags_str": [tags.int2str(idx) for idx in batch["ner_tags"]]}
|
74 |
|
75 |
|
76 |
-
def encode_dataset(split, tokenizer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
tags = split.features["ner_tags"].feature
|
78 |
split = split.map(partial(stringify_ner_tags, tags=tags), batched=True)
|
79 |
remove_columns = split.column_names
|
@@ -120,6 +159,18 @@ def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
|
|
120 |
|
121 |
|
122 |
def get_split_df(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
split_encoded = split_encoded.map(
|
124 |
partial(
|
125 |
forward_pass_with_label,
|
|
|
11 |
|
12 |
|
13 |
@st.cache(allow_output_mutation=True)
|
14 |
+
def get_data(ds_name: str, config_name: str, split_name: str, split_sample_size: int) -> Dataset:
|
15 |
+
"""Loads dataset from the HF hub (if not already loaded) and returns a Dataset object.
|
16 |
+
Uses datasets.load_dataset to load the dataset (see its documentation for additional details).
|
17 |
+
|
18 |
+
Args:
|
19 |
+
ds_name (str): Path or name of the dataset.
|
20 |
+
config_name (str): Name of the dataset configuration.
|
21 |
+
split_name (str): Which split of the data to load.
|
22 |
+
split_sample_size (int): The number of examples to load from the split.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
Dataset: A Dataset object.
|
26 |
+
"""
|
27 |
ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(seed=0) # type: ignore
|
28 |
split = ds[split_name].select(range(split_sample_size))
|
29 |
return split
|
|
|
34 |
hash_funcs=tokenizer_hash_funcs,
|
35 |
)
|
36 |
def get_collator(tokenizer) -> DataCollatorForTokenClassification:
|
37 |
+
"""Data collator that will dynamically pad the inputs received, as well as the labels.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
tokenizer ([PreTrainedTokenizer] or [PreTrainedTokenizerFast]): The tokenizer used for encoding the data.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
DataCollatorForTokenClassification: The DataCollatorForTokenClassification object.
|
44 |
+
"""
|
45 |
return DataCollatorForTokenClassification(tokenizer)
|
46 |
|
47 |
|
|
|
90 |
|
91 |
|
92 |
def stringify_ner_tags(batch, tags):
|
93 |
+
"""Stringifies a dataset batch's NER tags.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
batch: A batch of examples; must provide a "ner_tags" column.
|
97 |
+
tags: Object whose int2str method converts the integer tags to strings.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
dict: A dictionary with the single key "ner_tags_str" holding the stringified tags.
|
101 |
+
"""
|
102 |
return {"ner_tags_str": [tags.int2str(idx) for idx in batch["ner_tags"]]}
|
103 |
|
104 |
|
105 |
+
def encode_dataset(split: Dataset, tokenizer):
|
106 |
+
"""Encodes a dataset split.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
split (Dataset): A Dataset object.
|
110 |
+
tokenizer: A PreTrainedTokenizer object.
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Dataset: A Dataset object with the encoded inputs.
|
114 |
+
"""
|
115 |
+
|
116 |
tags = split.features["ner_tags"].feature
|
117 |
split = split.map(partial(stringify_ner_tags, tags=tags), batched=True)
|
118 |
remove_columns = split.column_names
|
|
|
159 |
|
160 |
|
161 |
def get_split_df(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
|
162 |
+
"""Turns a Dataset into a pandas dataframe.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
split_encoded (Dataset): _description_
|
166 |
+
model (_type_): _description_
|
167 |
+
tokenizer (_type_): _description_
|
168 |
+
collator (_type_): _description_
|
169 |
+
tags (_type_): _description_
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
pd.DataFrame: _description_
|
173 |
+
"""
|
174 |
split_encoded = split_encoded.map(
|
175 |
partial(
|
176 |
forward_pass_with_label,
|
src/load.py
CHANGED
@@ -39,6 +39,20 @@ def load_context(
|
|
39 |
split_sample_size: int,
|
40 |
**kw_args,
|
41 |
) -> Context:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
|
44 |
encoder_model_name=encoder_model_name,
|
|
|
39 |
split_sample_size: int,
|
40 |
**kw_args,
|
41 |
) -> Context:
|
42 |
+
"""Utility method loading (almost) everything we need for the application.
|
43 |
+
This exists just because we want to cache the results of this function.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
encoder_model_name (str): Name of the sentence encoder to load.
|
47 |
+
model_name (str): Name of the NER model to load.
|
48 |
+
ds_name (str): Dataset name or path.
|
49 |
+
ds_config_name (str): Dataset config name.
|
50 |
+
ds_split_name (str): Dataset split name.
|
51 |
+
split_sample_size (int): Number of examples to load from the split.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
Context: An object containing everything we need for the application.
|
55 |
+
"""
|
56 |
|
57 |
sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
|
58 |
encoder_model_name=encoder_model_name,
|
src/subpages/attention.py
CHANGED
@@ -80,7 +80,7 @@ JS_TEMPLATE = """requirejs(['basic', 'ecco'], function(basic, ecco){{
|
|
80 |
|
81 |
|
82 |
@st.cache(allow_output_mutation=True)
|
83 |
-
def
|
84 |
model_config = {
|
85 |
"embedding": "embeddings.word_embeddings",
|
86 |
"type": "mlm",
|
@@ -115,7 +115,7 @@ class AttentionPage(Page):
|
|
115 |
"A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model."
|
116 |
)
|
117 |
|
118 |
-
lm =
|
119 |
|
120 |
col1, _, col2 = st.columns([1.5, 0.5, 4])
|
121 |
with col1:
|
|
|
80 |
|
81 |
|
82 |
@st.cache(allow_output_mutation=True)
|
83 |
+
def _load_ecco_model():
|
84 |
model_config = {
|
85 |
"embedding": "embeddings.word_embeddings",
|
86 |
"type": "mlm",
|
|
|
115 |
"A group of neurons tend to fire in response to commas and other punctuation. Other groups of neurons tend to fire in response to pronouns. Use this visualization to factorize neuron activity in individual FFNN layers or in the entire model."
|
116 |
)
|
117 |
|
118 |
+
lm = _load_ecco_model()
|
119 |
|
120 |
col1, _, col2 = st.columns([1.5, 0.5, 4])
|
121 |
with col1:
|
src/subpages/page.py
CHANGED
@@ -10,6 +10,8 @@ from transformers import AutoTokenizer # type: ignore
|
|
10 |
|
11 |
@dataclass
|
12 |
class Context:
|
|
|
|
|
13 |
model: AutoModelForSequenceClassification
|
14 |
tokenizer: AutoTokenizer
|
15 |
sentence_encoder: SentenceTransformer
|
@@ -27,11 +29,19 @@ class Context:
|
|
27 |
|
28 |
|
29 |
class Page:
|
|
|
|
|
30 |
name: str
|
31 |
icon: str
|
32 |
|
33 |
def get_widget_defaults(self):
|
|
|
|
|
|
|
|
|
|
|
34 |
return {}
|
35 |
|
36 |
def render(self, context):
|
|
|
37 |
...
|
|
|
10 |
|
11 |
@dataclass
|
12 |
class Context:
|
13 |
+
"""This object facilitates passing around the application's state between different pages."""
|
14 |
+
|
15 |
model: AutoModelForSequenceClassification
|
16 |
tokenizer: AutoTokenizer
|
17 |
sentence_encoder: SentenceTransformer
|
|
|
29 |
|
30 |
|
31 |
class Page:
|
32 |
+
"""Base class for all pages."""
|
33 |
+
|
34 |
name: str
|
35 |
icon: str
|
36 |
|
37 |
def get_widget_defaults(self):
|
38 |
+
"""This function holds the default settings for all the page's widgets.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
dict: A dictionary of widget defaults, where the keys are the widget names and the values are the defaults.
|
42 |
+
"""
|
43 |
return {}
|
44 |
|
45 |
def render(self, context):
|
46 |
+
"""This function renders the page."""
|
47 |
...
|