Model Card for T5-LM-Large_Canard-HotpotQA-rephrase
This model is trained on three objectives:
- Generating answers for the Canard dataset based on Wikipedia search results,
- Generating answers for HotpotQA,
- Rephrasing questions based on the conversation context.
Training
The model was trained using the following script, which can be copy-pasted and run as-is (with the packages from requirements.txt installed).
All details, including the request format, can be inferred unambiguously from the code.
The best checkpoint was selected by the maximum ROUGE score on the Canard conversational QA validation set.
import datasets
# Load the Wikipedia-augmented Canard splits and convert them to pandas.
canard_train_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="train")
canard_test_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="test")
canard_df = canard_train_augm.to_pandas()
# The evaluation frame must come from the *test* split. It was previously
# built from the training split, so validation scores (and the ROUGE-based
# checkpoint selection) were measured on training data.
canard_test_df = canard_test_augm.to_pandas()
### Curation of seq2seq input contexts and labels
import random
def input_context_from_sample(row: dict, max_length=5, rng=None) -> str:
    """Build the seq2seq prompt for the Canard conversational-QA objective.

    The prompt contains the previous conversation, the current question and a
    shuffled, numbered list of search results. It ends with the
    ``"Current Answer: "`` cue.

    :param row: mapping with ``"History"``, ``"Question"``,
        ``"retrieved_contexts"`` and ``"true_contexts"`` entries. In this
        script it is a pandas row of the Canard_Wiki-augmented dataset.
    :param max_length: maximum number of search results. The prompt holds up
        to ``max_length - 1`` retrieved passages plus the true passage, which
        is always included, so values below 1 yield only the true passage.
    :param rng: optional ``random.Random`` instance used for shuffling, e.g.
        to get reproducible prompts. By default the module-level ``random``
        state is used, as before.
    :return: the prompt string.
    """
    history = row["History"]
    # The first three history entries are joined into the opening "question";
    # after that the history alternates answer / question.
    parts = ["Previous conversation:", "\nQuestion: ", ", ".join(history[:3])]
    for i in range(3, len(history), 2):
        parts += ["\nAnswer: ", history[i]]
        if i + 1 < len(history):
            parts += ["\nQuestion: ", history[i + 1]]
    parts += ["\n\nCurrent Question: ", row["Question"], "\nSearch results:"]

    retrieved = row["retrieved_contexts"]
    # Accept array-likes exposing ``.tolist()`` (e.g. numpy arrays inside
    # pandas rows) as well as plain sequences.
    retrieved = retrieved.tolist() if hasattr(retrieved, "tolist") else list(retrieved)
    # Clamp at 0: with max_length == 0 the slice ``[:max_length - 1]`` would be
    # ``[:-1]``, i.e. almost all retrieved passages instead of none.
    all_contexts = retrieved[:max(max_length - 1, 0)] + [row["true_contexts"]]
    # Shuffle so the true passage does not always appear in the same slot.
    (random.shuffle if rng is None else rng.shuffle)(all_contexts)
    for i, search_result in enumerate(all_contexts, start=1):
        parts += ["\n[%s]: " % i, search_result.replace("CANNOTANSWER", "")]
    parts.append("\nCurrent Answer: ")
    return "".join(parts)
def rephrasing_context_from_sample(row: dict) -> str:
    """Build the seq2seq prompt for the question-rephrasing objective.

    The prompt shows the previous conversation and the current question, and
    ends with the ``"More specific question: "`` cue. The training script
    pairs it with Canard's ``Rewrite`` column as the target.

    :param row: mapping with ``"History"`` and ``"Question"`` entries.
    :return: the prompt string.
    """
    turns = row["History"]
    # The first three history entries form the opening "question"; the rest
    # alternates answer / question.
    lines = ["Previous conversation:", "Question: " + ", ".join(turns[:3])]
    for idx in range(3, len(turns), 2):
        lines.append("Answer: " + turns[idx])
        if idx + 1 < len(turns):
            lines.append("Question: " + turns[idx + 1])
    # The empty entry produces the blank line before the current question.
    lines.extend(["", "Current Question: " + row["Question"], "More specific question: "])
    return "\n".join(lines)
def hotpotqa_context(row: dict) -> str:
    """Build the seq2seq prompt for the HotpotQA objective.

    Each paragraph of ``row["context"]["sentences"]`` (a list of sentence lists)
    is joined with spaces and listed as a numbered search result, followed by
    the ``"Current Answer: "`` cue.

    :param row: HotpotQA sample with ``"question"`` and ``"context"`` entries.
    :return: the prompt string.
    """
    # Use a separate name per paragraph rather than shadowing the prompt variable.
    passages = (" ".join(sentences) for sentences in row["context"]["sentences"])
    numbered = "".join(
        "\n[%s]: %s" % (n, passage.replace("CANNOTANSWER", ""))
        for n, passage in enumerate(passages, start=1)
    )
    return "Current Question: " + row["question"] + "\nSearch results:" + numbered + "\nCurrent Answer: "
# Conversational QA sequences
input_texts = canard_df.apply(input_context_from_sample, axis=1).values
# Only the first 200 test samples are used for validation.
input_val_texts = canard_test_df.iloc[:200].apply(input_context_from_sample, axis=1).values
# Drop over-long training prompts (> 20,000 characters) together with their labels.
too_long_index = [len(t) > 20000 for t in input_texts]
input_texts = [t for t, too_long in zip(input_texts, too_long_index) if not too_long]
print("training on %s samples" % len(input_texts))
labels = canard_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values
labels = [label for label, too_long in zip(labels, too_long_index) if not too_long]
# Validation labels must cover the same 200 samples as input_val_texts.
# Previously they were taken from the whole test split, so the validation
# inputs and labels had different lengths.
val_labels = canard_test_df.iloc[:200].answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values
# Rephrasing sequences: prompts come from rephrasing_context_from_sample,
# targets from the ``Rewrite`` column (train split first, then test split).
rephrasing_inputs, rephrasing_val_inputs = (
    frame.apply(rephrasing_context_from_sample, axis=1).values for frame in (canard_df, canard_test_df)
)
rephrasing_labels, rephrasing_val_labels = (frame.Rewrite.values for frame in (canard_df, canard_test_df))
# HotpotQA sequences (distractor setting). Load the dataset once and reuse
# both splits instead of calling load_dataset twice for the same config.
hotpot_qa = datasets.load_dataset("hotpot_qa", "distractor")
hotpot_train = hotpot_qa["train"]
hotpot_val = hotpot_qa["validation"]
hotpot_inputs = hotpot_train.to_pandas().apply(hotpotqa_context, axis=1)
hotpot_val_inputs = hotpot_val.to_pandas().apply(hotpotqa_context, axis=1)
# Drop over-long training prompts (> 20,000 characters) together with their answers.
too_long_index = [len(t) > 20000 for t in hotpot_inputs]
hotpot_inputs = [t for t, too_long in zip(hotpot_inputs, too_long_index) if not too_long]
hotpot_answers = [a for a, too_long in zip(hotpot_train["answer"], too_long_index) if not too_long]
# Training routine -- see Adaptor's homepage for details:
# https://github.com/gaussalgo/adaptor
from adaptor.lang_module import LangModule
from adaptor.evaluators.generative import ROUGE, BLEU

# Base model, passed to every objective below.
lang_module = LangModule("google/t5-large-lm-adapt")

# Validation metrics; ROUGE is flagged with decides_convergence=True.
evaluators = [BLEU(), ROUGE(decides_convergence=True)]
# Objectives: one Sequence2Sequence objective per task. The HotpotQA and
# rephrasing objectives pass share_other_objective_head=seq_qa.
from adaptor.objectives.seq2seq import Sequence2Sequence

# Arguments identical across all three objectives.
_shared_objective_kwargs = dict(batch_size=4, val_evaluators=evaluators)

seq_qa = Sequence2Sequence(
    lang_module,
    texts_or_path=input_texts,
    labels_or_path=labels,
    val_texts_or_path=input_val_texts,
    val_labels_or_path=val_labels,
    objective_id="Canard",
    **_shared_objective_kwargs,
)
seq_additional_qa = Sequence2Sequence(
    lang_module,
    texts_or_path=hotpot_inputs,
    labels_or_path=hotpot_answers,
    val_texts_or_path=hotpot_val_inputs[:200],
    val_labels_or_path=hotpot_val["answer"][:200],
    objective_id="HotpotQA",
    share_other_objective_head=seq_qa,
    **_shared_objective_kwargs,
)
seq_rephrasing = Sequence2Sequence(
    lang_module,
    texts_or_path=rephrasing_inputs,
    labels_or_path=rephrasing_labels,
    val_texts_or_path=rephrasing_val_inputs[:200],
    val_labels_or_path=rephrasing_val_labels[:200],
    objective_id="rephrasing",
    share_other_objective_head=seq_qa,
    **_shared_objective_kwargs,
)
# Training schedule & arguments
from adaptor.utils import AdaptationArguments, StoppingStrategy
from adaptor.schedules import ParallelSchedule
from adaptor.adapter import Adapter

training_arguments = AdaptationArguments(
    output_dir="checkpoints-chatbot",
    learning_rate=5e-5,
    # Stop only when every objective has converged.
    stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
    stopping_patience=8,
    save_total_limit=8,
    do_train=True,
    do_eval=True,
    bf16=True,
    warmup_steps=1000,
    gradient_accumulation_steps=8,
    logging_steps=10,
    eval_steps=200,
    save_steps=1000,
    num_train_epochs=10,
    evaluation_strategy="steps",
)

# Combine the three objectives in a ParallelSchedule and run training.
all_objectives = [seq_qa, seq_additional_qa, seq_rephrasing]
schedule = ParallelSchedule(objectives=all_objectives, args=training_arguments)
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()  # Training for 63k updates
Usage
See the prompting templates used in training to infer the optimal prompting format.
Contact
Feel free to ask questions here, or at stefanik{at}gaussalgo.com.
- Downloads last month: 14
This model does not have enough activity to be deployed to the Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.
Datasets used to train gaussalgo/T5-LM-Large_Canard-Fullwiki-HotpotQA-rephrase
Spaces using gaussalgo/T5-LM-Large_Canard-Fullwiki-HotpotQA-rephrase: 1
Evaluation results
- ROUGE on HotpotQA validation set (self-reported): 0.477
- BLEU on HotpotQA validation set (self-reported): 29.110
- ROUGE on Wikipedia-augmented Conversational QA (Canard) validation set (self-reported): 0.438
- BLEU on Wikipedia-augmented Conversational QA (Canard) validation set (self-reported): 19.340