Spaces:
Build error
Build error
File size: 11,092 Bytes
546a9ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 |
from abc import abstractmethod
from pprint import pformat
from time import sleep
from typing import List, Tuple, Optional, Union, Generator
from datasets import (
Dataset,
DatasetDict,
DatasetInfo,
concatenate_datasets,
load_dataset,
)
# Default values for retrying dataset download
# Total number of `load_dataset` attempts made by `SummDataset._load_dataset_safe`
DEFAULT_NUMBER_OF_RETRIES_ALLOWED = 5
# Seconds slept after a failed attempt before trying again
DEFAULT_WAIT_SECONDS_BEFORE_RETRY = 5
# Default value for creating missing val/test splits
# Passed as `test_size` to `train_test_split` each time a split is carved out of 'train'
TEST_OR_VAL_SPLIT_RATIO = 0.1
class SummInstance:
    """
    Basic instance for summarization tasks

    Holds one example: the source text, its reference summary and, optionally,
    a query.
    """

    def __init__(
        self, source: Union[List[str], str], summary: str, query: Optional[str] = None
    ):
        """
        Create a summarization instance

        :param source: either `List[str]` or `str`, depending on the dataset itself, string joining may be needed
            to fit into specific models. For example, for the same document, it could be simply `str` or
            `List[str]` for a list of sentences in the same document
        :param summary: a string summary that serves as ground truth
        :param query: Optional, applies when a string query is present
        """
        self.source = source
        self.summary = summary
        self.query = query

    def _as_dict(self) -> dict:
        """
        Build the dictionary shared by `__repr__` and `__str__`

        The "query" key is only included when the query is truthy, so both a
        missing query (`None`) and an empty string are left out.

        :rtype: dict with keys "source", "summary" and optionally "query"
        """
        instance_dict = {"source": self.source, "summary": self.summary}
        if self.query:
            instance_dict["query"] = self.query
        return instance_dict

    def __repr__(self) -> str:
        """
        Compact, single-line rendering of the instance dictionary
        """
        return str(self._as_dict())

    def __str__(self) -> str:
        """
        Pretty-printed rendering of the instance dictionary (via `pprint.pformat`)
        """
        return pformat(self._as_dict(), indent=1)
class SummDataset:
    """
    Dataset class for summarization, which takes into account of the following tasks:
        * Single document summarization
        * Multi-document/Dialogue summarization
        * Query-based summarization

    The class reads several attributes that it does not define itself:
    `huggingface_dataset`, `builder_script_path`, `dataset_name`,
    `is_query_based`, `is_dialogue_based` and `is_multi_document`.
    NOTE(review): presumably these are class attributes supplied by each
    concrete dataset subclass -- confirm against the subclasses.
    """

    def __init__(
        self,
        dataset_args: Optional[Tuple[str, ...]] = None,
        splitseed: Optional[int] = None,
    ):
        """Create dataset information from the huggingface Dataset class

        :param dataset_args: a tuple containing arguments to be passed on to the '_load_dataset_safe' method.
            Only required for datasets loaded from the Huggingface library.
            The arguments for each dataset are different and comprise of a string or multiple strings
        :param splitseed: a number to instantiate the random generator used to generate val/test splits
            for the datasets without them
        """
        # Load dataset from huggingface, use default huggingface arguments
        if self.huggingface_dataset:
            dataset = self._load_dataset_safe(*dataset_args)
        # Load non-huggingface dataset, use custom dataset builder
        else:
            dataset = self._load_dataset_safe(path=self.builder_script_path)

        info_set = self._get_dataset_info(dataset)

        # Ensure any dataset with a val or dev split is standardised to a validation split.
        # The loaded object is used as a dict here, and dicts have no `remove()`,
        # so the old key is dropped with `pop()`.
        if "val" in dataset:
            dataset["validation"] = dataset.pop("val")
        elif "dev" in dataset:
            dataset["validation"] = dataset.pop("dev")

        # If no splits other than training, generate them
        assert (
            "train" in dataset or "validation" in dataset or "test" in dataset
        ), "At least one of train/validation test needs to be not empty!"

        if not ("validation" in dataset or "test" in dataset):
            dataset = self._generate_missing_val_test_splits(dataset, splitseed)

        self.description = info_set.description
        self.citation = info_set.citation
        self.homepage = info_set.homepage

        # Process every split that is actually present. A missing (or None) split is
        # stored as None so the matching property reports it and returns an empty list.
        self._train_set = self._process_split(dataset, "train")
        self._validation_set = self._process_split(dataset, "validation")
        self._test_set = self._process_split(dataset, "test")

    def _process_split(
        self, dataset: DatasetDict, split_name: str
    ) -> Optional[Generator[SummInstance, None, None]]:
        """
        Run `_process_data` on one split if the split is available

        :param dataset: mapping from split name to split data
        :param split_name: name of the split to process, e.g. "train"
        :rtype: the result of `_process_data`, or None when the split is missing or None
        """
        split = dataset.get(split_name)
        if split is None:
            return None
        return self._process_data(split)

    def _split_or_empty(
        self, split_name: str, split_data
    ) -> Union[Generator[SummInstance, None, None], List]:
        """
        Return the processed split, or an empty list (with a printed notice) if it is None

        :param split_name: human readable split name used in the notice
        :param split_data: the stored processed split, may be None
        """
        if split_data is not None:
            return split_data
        print(
            f"{self.dataset_name} does not contain a {split_name} set, empty list returned"
        )
        return list()

    @property
    def train_set(self) -> Union[Generator[SummInstance, None, None], List]:
        """Processed train split, or an empty list when there is none"""
        return self._split_or_empty("train", self._train_set)

    @property
    def validation_set(self) -> Union[Generator[SummInstance, None, None], List]:
        """Processed validation split, or an empty list when there is none"""
        return self._split_or_empty("validation", self._validation_set)

    @property
    def test_set(self) -> Union[Generator[SummInstance, None, None], List]:
        """Processed test split, or an empty list when there is none"""
        return self._split_or_empty("test", self._test_set)

    def _load_dataset_safe(self, *args, **kwargs) -> DatasetDict:
        """
        This method creates a wrapper around the huggingface 'load_dataset()' function for a more robust
        download function, the original 'load_dataset()' function occasionally fails when it cannot reach
        a server especially after multiple requests.
        This method tackles this problem by attempting the download multiple times with a wait time
        before each retry.
        The wrapper method passes all arguments and keyword arguments to the 'load_dataset' function
        with no alteration.

        :param args: non-keyword arguments to be passed on to the 'load_dataset' function
        :param kwargs: keyword arguments to be passed on to the 'load_dataset' function
        :rtype: whatever 'load_dataset' returns; the rest of this class indexes it by split name
        :raises RuntimeError: when every attempt fails with a ConnectionError
        """
        tries = DEFAULT_NUMBER_OF_RETRIES_ALLOWED
        wait_time = DEFAULT_WAIT_SECONDS_BEFORE_RETRY
        give_up_message = (
            "Wait for a minute and attempt downloading the dataset again. "
            "The server hosting the dataset occassionally times out."
        )

        for attempt in range(tries):
            try:
                return load_dataset(*args, **kwargs)
            except ConnectionError as err:
                # Only connection problems are retried; any other error propagates at once
                if attempt == tries - 1:  # attempt is zero indexed
                    raise RuntimeError(give_up_message) from err
                sleep(wait_time)

        # Only reachable when the configured number of tries is not positive
        raise RuntimeError(give_up_message)

    def _get_dataset_info(self, data_dict: DatasetDict) -> DatasetInfo:
        """
        Get the information set from the dataset

        The `info` attribute of the 'train' split is used when that split exists;
        otherwise the first available split is used, so datasets that only ship
        validation/test data still work.

        :param data_dict: DatasetDict
        :rtype: DatasetInfo
        :raises KeyError: when no split is available at all
        """
        if "train" not in data_dict:
            for split in data_dict.values():
                if split is not None:
                    return split.info
        return data_dict["train"].info

    @abstractmethod
    def _process_data(self, dataset: Dataset) -> Generator[SummInstance, None, None]:
        """
        Abstract class method to process the data contained within each dataset.
        Each dataset class processes its own information differently due to the diversity in domains
        This method processes the data contained in the dataset
            and puts each data instance into a SummInstance object,
            the SummInstance has the following properties [source, summary, query[optional]]

        NOTE: this class does not derive from `abc.ABC`, so the `@abstractmethod`
        marker is not enforced at instantiation time; this base version returns None.

        :param dataset: a train/validation/test dataset
        :rtype: a generator yielding SummInstance objects
        """
        return

    def _generate_missing_val_test_splits(
        self, dataset_dict: DatasetDict, seed: Optional[int]
    ) -> DatasetDict:
        """
        Creating the missing val and test splits from the train split of a dataset

        Each missing split takes `TEST_OR_VAL_SPLIT_RATIO` of the *current* train split.
        When both are generated, 'test' is taken first and 'validation' is then taken from
        what is left, so the proportions are roughly 'train: ~.81', 'validation: ~.09'
        and 'test: ~.10'.
        The same seed is passed to both `train_test_split` calls.

        :param dataset_dict: DatasetDict, usually only containing the train set
        :param seed: seed for the random generator to shuffle the dataset
        :rtype: the same DatasetDict, updated in place, containing the three splits
            (missing splits are set to None when there is no train split to draw from)
        """
        # Return dataset if no train set available for splitting
        if "train" not in dataset_dict:
            if "validation" not in dataset_dict:
                dataset_dict["validation"] = None
            if "test" not in dataset_dict:
                dataset_dict["test"] = None
            return dataset_dict

        # Create a 'test' split from 'train' if no 'test' set is available
        if "test" not in dataset_dict:
            dataset_traintest_split = dataset_dict["train"].train_test_split(
                test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
            )
            dataset_dict["train"] = dataset_traintest_split["train"]
            dataset_dict["test"] = dataset_traintest_split["test"]

        # Create a 'validation' split from the remaining 'train' set if no 'validation' set is available
        if "validation" not in dataset_dict:
            dataset_trainval_split = dataset_dict["train"].train_test_split(
                test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
            )
            dataset_dict["train"] = dataset_trainval_split["train"]
            dataset_dict["validation"] = dataset_trainval_split["test"]

        return dataset_dict

    def _concatenate_dataset_dicts(
        self, dataset_dicts: List[DatasetDict]
    ) -> DatasetDict:
        """
        Concatenate dataset dicts with the same splits and columns into one

        Splits are compared as sets, so the key order of the individual dicts does
        not matter; the result follows the split order of the first dict.

        :param dataset_dicts: A list of DatasetDicts
        :rtype: DatasetDict containing the combined data
        :raises ValueError: when the list is empty or the splits differ between dicts
        """
        if not dataset_dicts:
            raise ValueError("At least one dataset dict is required for concatenation")

        # Ensure all dataset dicts have the same splits
        split_names = list(dataset_dicts[0].keys())
        expected_splits = set(split_names)
        for dataset_dict in dataset_dicts[1:]:
            if set(dataset_dict.keys()) != expected_splits:
                raise ValueError("Splits must match for all datasets")

        # Concatenate all datasets into one according to the splits
        temp_dict = {}
        for split in split_names:
            split_set = [dataset_dict[split] for dataset_dict in dataset_dicts]
            temp_dict[split] = concatenate_datasets(split_set)
        return DatasetDict(temp_dict)

    @classmethod
    def generate_basic_description(cls) -> str:
        """
        Automatically generate the basic description string based on the attributes

        Reads `dataset_name`, `is_query_based`, `is_dialogue_based` and
        `is_multi_document` from the class. The returned text starts with ": ".

        :rtype: string containing the description
        """
        basic_description = (
            f": {cls.dataset_name} is a "
            f"{'query-based ' if cls.is_query_based else ''}"
            f"{'dialogue ' if cls.is_dialogue_based else ''}"
            f"{'multi-document' if cls.is_multi_document else 'single-document'} "
            f"summarization dataset."
        )
        return basic_description

    def show_description(self):
        """
        Print the description of the dataset.
        """
        print(self.dataset_name, ":\n", self.description)
|