import json
import random
import os
from .dataset import DataSample, TrainSample, Dataset
from accelerate.logging import get_logger
logger = get_logger(__name__, log_level="INFO")
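# Instruction prefix for each E5 sub-dataset; entries given as a two-item list
# alternate between the two prompts from sample to sample.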
E5_EMBEDDING_PROMPTS = {
"allnli": [
"Given a premise, retrieve a hypothesis that is entailed by the premise",
"Retrieve semantically similar text",
],
"dureader": "Given a Chinese search query, retrieve web passages that answer the question",
"eli5_question_answer": "Provided a user question, retrieve the highest voted answers on Reddit ELI5 forum",
"fever": "Given a claim, retrieve documents that support or refute the claim",
"hotpot_qa": "Given a multi-hop question, retrieve documents that can help answer the question",
"miracl": "Given a question, retrieve Wikipedia passages that answer the question",
"mrtydi": "Given a question, retrieve Wikipedia passages that answer the question",
"msmarco_passage": "Given a web search query, retrieve relevant passages that answer the query",
"msmarco_document": "Given a web search query, retrieve relevant documents that answer the query",
"nq": "Given a question, retrieve Wikipedia passages that answer the question",
"quora_duplicates": [
"Given a question, retrieve questions that are semantically equivalent to the given question",
"Find questions that have the same meaning as the input question",
],
"squad": "Retrieve Wikipedia passages that answer the question",
"t2ranking": "Given a Chinese search query, retrieve web passages that answer the question",
"trivia_qa": "Retrieve Wikipedia passages that answer the question",
}


class E5Data(Dataset):
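    """E5 multi-task training data: prepends a task instruction to each query and
    batches samples so that every effective batch comes from a single sub-dataset."""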
def __init__(
self,
dataset_name: str = "E5",
split: str = "validation",
file_path: str = "cache/echo-data",
effective_batch_size: int = 32,
shuffle_individual_datasets: bool = True,
separator: str = "!@#$%^&*()",
):
self.dataset_name = dataset_name
self.split = split
self.effective_batch_size = effective_batch_size
self.shuffle_individual_datasets = shuffle_individual_datasets
self.separator = separator
self.data = []
        self.load_data(file_path)

    def __len__(self):
        return len(self.data)

    def load_data(self, file_path: str = None):
logger.info(f"Loading E5 data from {file_path}...")
        # file_path is a directory containing one .jsonl file per sub-dataset
data_map = {}
all_samples = []
id_ = 0
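        # Read each sub-dataset's .jsonl file and prepend its instruction to every query.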
for dataset in E5_EMBEDDING_PROMPTS:
logger.info(f"Loading dataset {dataset}...")
if dataset not in data_map:
data_map[dataset] = []
with open(os.path.join(file_path, f"{dataset}.jsonl"), "r") as f:
dataset_samples = f.readlines()
dataset_samples = [json.loads(d) for d in dataset_samples]
for i, sample in enumerate(dataset_samples):
instruction = (
E5_EMBEDDING_PROMPTS[dataset]
if isinstance(E5_EMBEDDING_PROMPTS[dataset], str)
else E5_EMBEDDING_PROMPTS[dataset][i % 2]
)
query = f"{instruction}; " + self.separator + sample["query"]
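                # Split-specific dataset names also prepend the instruction to the
                # positive/negative; none of these names are keys of
                # E5_EMBEDDING_PROMPTS above, so this branch is not taken with the
                # current prompt table.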
if dataset in [
"allnli_split2",
"quora_duplicates_split1",
"quora_duplicates_split2",
]:
pos = (
f"{E5_EMBEDDING_PROMPTS[dataset]}; "
+ self.separator
+ sample["positive"]
)
neg = (
f"{E5_EMBEDDING_PROMPTS[dataset]}; "
+ self.separator
+ sample["negative"]
)
else:
pos = self.separator + sample["positive"]
neg = self.separator + sample["negative"]
data_map[dataset].append(id_)
all_samples.append(
DataSample(
id_=id_,
query=query,
positive=pos,
negative=neg,
task_name=dataset,
)
)
id_ += 1
# combine split1 and split2
new_data_map = {}
for dataset in data_map:
new_dataset = dataset.replace("_split1", "").replace("_split2", "")
if new_dataset not in new_data_map:
new_data_map[new_dataset] = []
new_data_map[new_dataset] += data_map[dataset]
data_map = new_data_map
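        # Optionally shuffle sample order within each task before batching.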
if self.shuffle_individual_datasets:
for task, samples in data_map.items():
random.shuffle(samples)
datasets = list(data_map.keys())
logger.info(
f"Batching Echo data properly for effective batch size of {self.effective_batch_size}..."
)
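        # Group each task's samples into fixed-size batches; a trailing batch smaller
        # than the effective batch size is dropped.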
all_batches = []
for dataset in datasets:
dataset_samples = data_map[dataset]
for i in range(0, len(dataset_samples), self.effective_batch_size):
batch = dataset_samples[i : i + self.effective_batch_size]
if len(batch) == self.effective_batch_size:
all_batches.append(batch)
else:
                    logger.info(f"Skipping incomplete batch for dataset {dataset}.")
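        # Shuffle batch order so tasks are interleaved while each batch stays single-task.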
random.shuffle(all_batches)
final_idx_order = []
for batch in all_batches:
for idx in batch:
final_idx_order.append(idx)
self.data = [all_samples[idx] for idx in final_idx_order]
        logger.info(f"Loaded {len(self.data)} samples.")

    def __getitem__(self, index):
sample = self.data[index]
if self.split == "train":
return TrainSample(
texts=[sample.query, sample.positive, sample.negative], label=1.0
)
        elif self.split == "validation":
            raise ValueError("E5Data does not have a validation split.")
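

# Example usage (a minimal sketch, assuming the per-dataset .jsonl files have
# already been placed under cache/echo-data and that this module is imported as
# part of its package, since it relies on the relative import above; it also
# assumes TrainSample exposes the `texts` it was constructed with):
#
#   dataset = E5Data(split="train", file_path="cache/echo-data",
#                    effective_batch_size=32)
#   sample = dataset[0]  # TrainSample with texts=[query, positive, negative]
#   print(len(dataset), sample.texts[0][:80])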