import json
import random
import os

from .dataset import DataSample, TrainSample, Dataset

from accelerate.logging import get_logger

logger = get_logger(__name__, log_level="INFO")
E5_EMBEDDING_PROMPTS = {
    "allnli": [
        "Given a premise, retrieve a hypothesis that is entailed by the premise",
        "Retrieve semantically similar text",
    ],
    "dureader": "Given a Chinese search query, retrieve web passages that answer the question",
    "eli5_question_answer": "Provided a user question, retrieve the highest voted answers on Reddit ELI5 forum",
    "fever": "Given a claim, retrieve documents that support or refute the claim",
    "hotpot_qa": "Given a multi-hop question, retrieve documents that can help answer the question",
    "miracl": "Given a question, retrieve Wikipedia passages that answer the question",
    "mrtydi": "Given a question, retrieve Wikipedia passages that answer the question",
    "msmarco_passage": "Given a web search query, retrieve relevant passages that answer the query",
    "msmarco_document": "Given a web search query, retrieve relevant documents that answer the query",
    "nq": "Given a question, retrieve Wikipedia passages that answer the question",
    "quora_duplicates": [
        "Given a question, retrieve questions that are semantically equivalent to the given question",
        "Find questions that have the same meaning as the input question",
    ],
    "squad": "Retrieve Wikipedia passages that answer the question",
    "t2ranking": "Given a Chinese search query, retrieve web passages that answer the question",
    "trivia_qa": "Retrieve Wikipedia passages that answer the question",
}
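
# A minimal sketch of the on-disk layout this loader assumes, inferred only from
# the field accesses in load_data below (the exact schema of the released E5
# files is not restated here):
#   {file_path}/{dataset}.jsonl  --  one JSON object per line, e.g.
#   {"query": "...", "positive": "...", "negative": "..."}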


class E5Data(Dataset):
    def __init__(
        self,
        dataset_name: str = "E5",
        split: str = "validation",
        file_path: str = "cache/echo-data",
        effective_batch_size: int = 32,
        shuffle_individual_datasets: bool = True,
        separator: str = "!@#$%^&*()",
    ):
        self.dataset_name = dataset_name
        self.split = split
        self.effective_batch_size = effective_batch_size
        self.shuffle_individual_datasets = shuffle_individual_datasets
        self.separator = separator

        self.data = []
        self.load_data(file_path)

    def __len__(self):
        return len(self.data)

    def load_data(self, file_path: str = None):
        logger.info(f"Loading E5 data from {file_path}...")
        # file_path is actually a directory containing one JSONL file per dataset

        data_map = {}
        all_samples = []
        id_ = 0
        for dataset in E5_EMBEDDING_PROMPTS:
            logger.info(f"Loading dataset {dataset}...")
            if dataset not in data_map:
                data_map[dataset] = []
            with open(os.path.join(file_path, f"{dataset}.jsonl"), "r") as f:
                dataset_samples = f.readlines()

            dataset_samples = [json.loads(d) for d in dataset_samples]

            for i, sample in enumerate(dataset_samples):
                # Tasks with two instruction variants alternate between them per sample.
                instruction = (
                    E5_EMBEDDING_PROMPTS[dataset]
                    if isinstance(E5_EMBEDDING_PROMPTS[dataset], str)
                    else E5_EMBEDDING_PROMPTS[dataset][i % 2]
                )
                query = f"{instruction}; " + self.separator + sample["query"]
                if dataset in [
                    "allnli_split2",
                    "quora_duplicates_split1",
                    "quora_duplicates_split2",
                ]:
                    # For these split datasets the instruction is prepended to the
                    # positive and negative texts as well.
                    pos = (
                        f"{E5_EMBEDDING_PROMPTS[dataset]}; "
                        + self.separator
                        + sample["positive"]
                    )
                    neg = (
                        f"{E5_EMBEDDING_PROMPTS[dataset]}; "
                        + self.separator
                        + sample["negative"]
                    )
                else:
                    pos = self.separator + sample["positive"]
                    neg = self.separator + sample["negative"]

                data_map[dataset].append(id_)
                all_samples.append(
                    DataSample(
                        id_=id_,
                        query=query,
                        positive=pos,
                        negative=neg,
                        task_name=dataset,
                    )
                )
                id_ += 1

        # combine split1 and split2 back into their parent datasets
        new_data_map = {}
        for dataset in data_map:
            new_dataset = dataset.replace("_split1", "").replace("_split2", "")
            if new_dataset not in new_data_map:
                new_data_map[new_dataset] = []
            new_data_map[new_dataset] += data_map[dataset]
        data_map = new_data_map

        if self.shuffle_individual_datasets:
            for task, samples in data_map.items():
                random.shuffle(samples)

        datasets = list(data_map.keys())

        logger.info(
            f"Batching Echo data properly for effective batch size of {self.effective_batch_size}..."
        )

        # Build task-homogeneous batches and drop the final partial batch of each
        # dataset, so every batch holds exactly effective_batch_size samples from
        # a single task.
        all_batches = []
        for dataset in datasets:
            dataset_samples = data_map[dataset]
            for i in range(0, len(dataset_samples), self.effective_batch_size):
                batch = dataset_samples[i : i + self.effective_batch_size]
                if len(batch) == self.effective_batch_size:
                    all_batches.append(batch)
                else:
                    logger.info(f"Skip 1 batch for dataset {dataset}.")
        random.shuffle(all_batches)

        # Flatten the shuffled batches into a single index order over all_samples.
        final_idx_order = []
        for batch in all_batches:
            for idx in batch:
                final_idx_order.append(idx)

        self.data = [all_samples[idx] for idx in final_idx_order]
        logger.info(f"Loaded {len(self.data)} samples.")

    def __getitem__(self, index):
        sample = self.data[index]
        if self.split == "train":
            return TrainSample(
                texts=[sample.query, sample.positive, sample.negative], label=1.0
            )
        elif self.split == "validation":
            assert False, "E5Data does not have a validation split."
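
# Usage sketch (illustrative, not part of the original module). It assumes the
# JSONL files named in E5_EMBEDDING_PROMPTS exist under cache/echo-data and that
# TrainSample keeps the `texts` list it is constructed with above. Because this
# module uses a relative import, instantiate it from the parent package rather
# than running the file as a standalone script.
#
#     data = E5Data(split="train", file_path="cache/echo-data",
#                   effective_batch_size=32)
#     print(len(data))        # total number of samples after batching
#     first = data[0]
#     print(first.texts[0])   # instruction-prefixed query with the separator token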