rita443 committed on
Commit
44158b3
1 Parent(s): a5131f2

Delete srl_pipeline.py

Files changed (1)
  1. srl_pipeline.py +0 -242
srl_pipeline.py DELETED
@@ -1,242 +0,0 @@
- import logging
- from typing import Any, Dict, List, Optional, Tuple
-
- import spacy
- import torch
- from transformers import Pipeline
-
- from decoder import Decoder
-
- logger = logging.getLogger(__name__)
-
-
- class SrlPipeline(Pipeline):
-     """
-     A pipeline for Semantic Role Labeling (SRL) using transformers and spaCy models.
-
-     This pipeline tokenizes input sentences, finds verbs using POS tagging, and postprocesses
-     the model outputs using Viterbi decoding to provide human-readable results.
-
-     Attributes:
-         model ``str``: The name or identifier of the underlying transformer model.
-         tokenizer ``str``: The name or identifier of the tokenizer associated with the model.
-         framework ``str``: The framework used for the pipeline (e.g., PyTorch, TensorFlow).
-         task ``str``: The specific task of the pipeline.
-         verb_predictor: The spaCy model instance used to predict verbs in the input sentences.
-
-     Usage:
-         # Register the SrlPipeline in the pipeline registry
-         PIPELINE_REGISTRY.register_pipeline(
-             "srl",
-             pipeline_class=SrlPipeline,
-             model=SRLModel,  # Assuming SRLModel is the model class used
-             default={"lang": "en"},
-             type="text",
-         )
-
-         # Load the model and tokenizer
-         model = AutoModel.from_pretrained("liaad/srl-en_roberta-large_hf", trust_remote_code=True)
-         tokenizer = AutoTokenizer.from_pretrained("liaad/srl-en_roberta-large_hf", trust_remote_code=True)
-
-         # Load the SRL pipeline
-         srl_pipeline = pipeline(
-             "srl",
-             model=model,
-             tokenizer=tokenizer,
-             framework="PyTorch",  # Replace with the actual framework used
-             task="semantic_role_labeling",  # Replace with the actual task name
-             lang="en",  # Language specification
-         )
-
-         # Example text input
-         text = ["The cat jumps over the fence.", "She quickly eats the delicious cake."]
-
-         # Perform semantic role labeling
-         results = srl_pipeline(text)
-     """
-
-     def __init__(self, model: str, tokenizer: str, framework: str, task: str, **kwargs):
-         """
-         Initializes the Semantic Role Labeling pipeline.
-
-         Parameters:
-         - model ``str``: The model name or identifier.
-         - tokenizer ``str``: The tokenizer name or identifier.
-         - framework ``str``: The framework used.
-         - task ``str``: The specific task of the pipeline.
-         - **kwargs: Additional keyword arguments.
-             - lang ``str``, optional: Language specification ('en' for English; any other value falls back to the Portuguese default).
-         """
-         super().__init__(model, tokenizer=tokenizer)
-         if "lang" in kwargs and kwargs["lang"] == "en":
-             logger.info("Loading English verb predictor model...")
-             self.verb_predictor = spacy.load("en_core_web_trf")
-         else:
-             logger.info("Loading Portuguese verb predictor model...")
-             self.verb_predictor = spacy.load("pt_core_news_lg")
-         logger.info("Got verb prediction model\n")
-
-     def _sanitize_parameters(
-         self, **kwargs: Dict[str, Any]
-     ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
-         """
-         Sanitizes and organizes additional parameters.
-
-         Parameters:
-         - **kwargs: Additional keyword arguments.
-
-         Returns:
-         - ``Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]``: Three dictionaries of sanitized parameters for preprocess, _forward, and postprocess.
-         """
-         return {}, {}, {}
-
-     def preprocess(self, sentence: str) -> List[Dict[str, Any]]:
-         """
-         Preprocesses a sentence for semantic role labeling.
-
-         Parameters:
-         - sentence ``str``: The input sentence to be processed.
-
-         Returns:
-         - ``List[Dict[str, Any]]``: A list of dictionaries containing model inputs, one for each verb in the sentence.
-         """
-         # Extract sentence verbs
-         doc = self.verb_predictor(sentence)
-
-         verbs = {token.text for token in doc if token.pos_ == "VERB"}
-         # If the sentence only contains auxiliary verbs, consider those as the
-         # main verbs
-         if not verbs:
-             verbs = {token.text for token in doc if token.pos_ == "AUX"}
-
-         # Tokenize sentence
-         tokens = self.tokenizer.encode_plus(
-             sentence,
-             truncation=True,
-             return_token_type_ids=False,
-             return_offsets_mapping=True,
-         )
-         tokens_lst = tokens.tokens()
-         offsets = tokens["offset_mapping"]
-
-         input_ids = torch.tensor([tokens["input_ids"]], dtype=torch.long)
-         attention_mask = torch.tensor([tokens["attention_mask"]], dtype=torch.long)
-
-         model_input = {
-             "input_ids": input_ids,
-             "attention_mask": attention_mask,
-             "tokens": tokens_lst,
-             "verb": "",
-         }
-
-         # Create a new dictionary for each verb, each with its own
-         # token_type_ids list (a shared list would leak verb indicators
-         # across verbs)
-         model_inputs = [
-             {**model_input, "token_type_ids": []} for _ in verbs
-         ]
-
-         for i, verb in enumerate(verbs):
-             model_inputs[i]["verb"] = verb
-             token_type_ids = model_inputs[i]["token_type_ids"]
-             token_type_ids.append([])
-             curr_word_offsets: Optional[Tuple[int, int]] = None
-
-             for j in range(len(tokens_lst)):
-                 curr_offsets = offsets[j]
-                 curr_slice = sentence[curr_offsets[0] : curr_offsets[1]]
-                 if not curr_slice:
-                     # Special tokens map to empty slices and never mark the verb
-                     token_type_ids[-1].append(0)
-                 # Check if the new token still belongs to the same word
-                 elif (
-                     curr_word_offsets
-                     and curr_offsets[0] >= curr_word_offsets[0]
-                     and curr_offsets[1] <= curr_word_offsets[1]
-                 ):
-                     # Extend the previous token type
-                     token_type_ids[-1].append(token_type_ids[-1][-1])
-                 else:
-                     curr_word_offsets = self._find_word(sentence, start=curr_offsets[0])
-                     curr_word = sentence[curr_word_offsets[0] : curr_word_offsets[1]]
-
-                     token_type_ids[-1].append(
-                         int(curr_word != "" and curr_word == verb)
-                     )
-
-             model_inputs[i]["token_type_ids"] = torch.tensor(
-                 token_type_ids, dtype=torch.long
-             )
-
-         return model_inputs
-
-     def _forward(self, model_inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-         """
-         Internal method to forward model inputs for prediction.
-
-         Parameters:
-         - model_inputs ``List[Dict[str, Any]]``: List of dictionaries containing model inputs.
-
-         Returns:
-         - ``List[Dict[str, Any]]``: List of dictionaries containing model outputs.
-         """
-         outputs = []
-         for model_input in model_inputs:
-             output = self.model(
-                 input_ids=model_input["input_ids"],
-                 attention_mask=model_input["attention_mask"],
-                 token_type_ids=model_input["token_type_ids"],
-             )
-             output["verb"] = model_input["verb"]
-             output["tokens"] = model_input["tokens"]
-             outputs.append(output)
-         return outputs
-
-     def postprocess(self, model_outputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-         """
-         Postprocesses model outputs to human-readable format.
-
-         Parameters:
-         - model_outputs ``List[Dict[str, Any]]``: List of dictionaries containing model outputs.
-
-         Returns:
-         - ``List[Dict[str, Any]]``: List of dictionaries containing processed results.
-           Each dictionary entry represents a verb with its associated labels and token-label pairs.
-           Example format: {verb: (labels, List[(token, label)])}
-         """
-         result = []
-         id2label = {int(k): str(v) for k, v in self.model.config.id2label.items()}
-         evaluator = Decoder(id2label)
-
-         for model_output in model_outputs:
-             class_probabilities = model_output["class_probabilities"]
-             attention_mask = model_output["attention_mask"]
-             output_dict = evaluator.make_output_human_readable(
-                 class_probabilities, attention_mask
-             )
-             # Here we always fetch the first list because in a pipeline every
-             # sentence is processed one at a time
-             wordpiece_label_ids = output_dict["wordpiece_label_ids"][0]
-             labels = list(map(lambda idx: id2label[idx], wordpiece_label_ids))
-             result.append(
-                 {
-                     model_output["verb"]: (
-                         labels,
-                         list(zip(model_output["tokens"], labels)),
-                     )
-                 }
-             )
-         return result
-
-     def _find_word(self, s: str, start: int = 0) -> Tuple[int, int]:
-         """
-         Helper method to find the boundaries of a word in a string.
-         Assumes a non-alphabetic character marks the end of a word.
-
-         Parameters:
-         - s ``str``: The input string.
-         - start ``int``, optional: Index at which to start looking for the word. Defaults to 0.
-
-         Returns:
-         - ``Tuple[int, int]``: Start and end indices of the word.
-         """
-         for i, char in enumerate(s[start:], start):
-             if not char.isalpha():
-                 return start, i
-         return start, len(s)
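
For readers skimming the deleted file, the heart of preprocess() is a verb-indicator mask: each wordpiece whose enclosing word matches the target verb gets token_type_id 1, everything else 0. Below is a minimal, self-contained sketch of that idea, not part of the deleted file: the offsets list is a hand-written stand-in for a tokenizer's offset_mapping, and the sub-word continuation handling is omitted for brevity.

from typing import List, Tuple


def find_word(s: str, start: int = 0) -> Tuple[int, int]:
    # Same rule as SrlPipeline._find_word: a non-alphabetic character ends the word.
    for i, char in enumerate(s[start:], start):
        if not char.isalpha():
            return start, i
    return start, len(s)


def verb_indicator(sentence: str, verb: str, offsets: List[Tuple[int, int]]) -> List[int]:
    # 1 where the enclosing word equals the target verb, 0 elsewhere.
    mask = []
    for start, end in offsets:
        if not sentence[start:end]:  # special tokens map to empty slices
            mask.append(0)
            continue
        word_start, word_end = find_word(sentence, start)
        word = sentence[word_start:word_end]
        mask.append(int(word != "" and word == verb))
    return mask


sentence = "The cat jumps over the fence."
# Hand-written offsets standing in for the tokenizer's offset_mapping;
# (0, 0) entries represent special tokens such as <s> and </s>.
offsets = [(0, 0), (0, 3), (4, 7), (8, 13), (14, 18), (19, 22), (23, 28), (28, 29), (0, 0)]
print(verb_indicator(sentence, "jumps", offsets))  # -> [0, 0, 0, 1, 0, 0, 0, 0, 0]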