mT5-base for Czech+English Generative Question Answering
This is the mT5-base model with an LM head, fine-tuned to generate extractive answers. In contrast to our mt5-base-priming model, it is a traditional sequence-to-sequence model without priming, though it can also be used for other text extraction tasks, such as Named Entity Recognition in zero-shot settings (with a significant decrease in quality compared to priming).
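To use the model for zero-shot NER, each entity type can be phrased as a question over the same context, and every (question, context) pair is then passed to the model as in the Usage example. The sketch below only builds those queries; the templates in ENTITY_QUESTIONS and the helper build_ner_queries are hypothetical illustrations, not prompts the model was trained or evaluated with (for Czech text, phrasing the questions in Czech may work better).

```python
from typing import Dict, List, Tuple

# Hypothetical question templates for casting NER as extractive QA.
# They are illustrative assumptions, not prompts used in training or evaluation.
ENTITY_QUESTIONS: Dict[str, str] = {
    "PER": "Which person is mentioned in the text?",
    "LOC": "Which place is mentioned in the text?",
    "ORG": "Which organization is mentioned in the text?",
}


def build_ner_queries(context: str, entity_types: List[str]) -> List[Tuple[str, str, str]]:
    """Return one (entity_type, question, context) triple per requested entity type."""
    return [(etype, ENTITY_QUESTIONS[etype], context) for etype in entity_types]


if __name__ == "__main__":
    text = "Podle slovenského lidového podání byl Juro Jánošík obdařen magickými předměty."
    for etype, question, ctx in build_ner_queries(text, ["PER", "ORG"]):
        # Each (question, ctx) pair would be tokenized and passed to model.generate()
        # in the same way as the question/context pair in the Usage example.
        print(etype, "->", question)
```

Since the model generates a single span per question, this recovers at most one entity per type and call.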
Intended uses & limitations
This model is intended to generate the segment of a given context that contains the answer to a given question (Extractive Question Answering) in English and Czech. Given the fine-tuning on two languages and the good zero-shot cross-lingual transfer reported for other fine-tuned multilingual large language models, the model will likely work in other languages as well, although with some decrease in quality.
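Because the answer is generated rather than selected, it is not guaranteed to be a verbatim segment of the context. A minimal sketch for checking this and recovering character offsets, assuming an exact match and then a case-insensitive match are acceptable; the helper locate_answer is hypothetical.

```python
import re
from typing import Optional, Tuple


def locate_answer(context: str, answer: str) -> Optional[Tuple[int, int]]:
    """Return (start, end) character offsets of `answer` in `context`, or None.

    Tries an exact match first, then a case-insensitive match; re.search keeps
    the offsets aligned with the original (non-lowercased) context.
    """
    answer = answer.strip()
    if not answer:
        return None
    start = context.find(answer)
    if start != -1:
        return start, start + len(answer)
    match = re.search(re.escape(answer), context, flags=re.IGNORECASE)
    if match:
        return match.start(), match.end()
    return None


context = "které mu dodávaly nadpřirozené schopnosti. Okrádal především šlechtice,"
span = locate_answer(context, "nadpřirozené schopnosti")
if span is not None:
    print(span, context[span[0]:span[1]])
```

A result of None means the generated text was not copied verbatim from the context and should be treated with more caution.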
Note that, despite its size, English SQuAD has several reported biases related to the relative position or type of the answer in the context, which can affect the model's performance on new data (see, e.g., L. Mikula (2022), Chap. 4.1).
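One way to see whether new data differs from SQuAD in answer position is to look at where the gold answers start within their contexts. A minimal sketch, assuming records in the SQuAD layout of the Hugging Face `datasets` copy of SQuAD (`context`, `answers.text`, `answers.answer_start`); the helper names and the equal-width bucketing are illustrative choices.

```python
from collections import Counter
from typing import Iterable, List, Mapping


def relative_answer_positions(examples: Iterable[Mapping]) -> List[float]:
    """Relative start (0.0 = beginning, 1.0 = end of context) of each first gold answer."""
    positions = []
    for ex in examples:
        starts = ex["answers"]["answer_start"]
        if not starts:  # skip examples without a gold answer
            continue
        positions.append(starts[0] / max(len(ex["context"]), 1))
    return positions


def position_histogram(positions: List[float], bins: int = 4) -> Counter:
    """Count answers falling into `bins` equal-width slices of the context."""
    return Counter(min(int(p * bins), bins - 1) for p in positions)
```

Comparing this histogram for SQuAD with the one for the target data gives a rough indication of whether the position bias is likely to matter.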
Usage
Here is how to use this model to answer a question about a given context using 🤗 Transformers in PyTorch:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("gaussalgo/mt5-base-generative-QA_en-cs")
model = AutoModelForSeq2SeqLM.from_pretrained("gaussalgo/mt5-base-generative-QA_en-cs")
context = """
Podle slovenského lidového podání byl Juro Jánošík obdařen magickými předměty (kouzelná valaška, čarovný opasek),
které mu dodávaly nadpřirozené schopnosti. Okrádal především šlechtice,
trestal panské dráby a ze svého lupu vyděloval část pro chudé, tedy bohatým bral a chudým dával.
"""
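# "What abilities did the magical objects give Juro Jánošík?"
question = "Jaké schopnosti daly magické předměty Juro Jánošíkovi?"
# Encode the question and the context together as a single input sequence
inputs = tokenizer(question, context, return_tensors="pt")
# Generate the answer tokens
outputs = model.generate(**inputs)
print("Answer:")
# outputs has shape (batch_size, sequence_length); decode the first (only) sequence
print(tokenizer.decode(outputs[0], skip_special_tokens=True))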
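To compare generated answers with reference answers, SQuAD-style exact match (EM) and token-level F1 are a common choice. A minimal sketch mirroring the answer normalization of the official SQuAD v1.1 evaluation script (lowercasing, removing punctuation and English articles, collapsing whitespace); the article removal is English-specific and has no effect on Czech, and the helper names are hypothetical.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(reference)


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    num_same = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
```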
Training
The model was trained with the Adaptor library v0.1.5, in parallel on both Czech and English data, with the following parameters:
from adaptor.utils import AdaptationArguments, StoppingStrategy
training_arguments = AdaptationArguments(output_dir="train_dir",
                                         learning_rate=5e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         do_train=True,
                                         do_eval=True,
                                         warmup_steps=1000,
                                         max_steps=100000,
                                         gradient_accumulation_steps=4,
                                         eval_steps=100,
                                         logging_steps=10,
                                         save_steps=1000,
                                         num_train_epochs=50,
                                         evaluation_strategy="steps",
                                         remove_unused_columns=False)
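With warmup_steps=1000 and max_steps=100000, the learning rate first ramps up and then decays. The sketch below assumes that AdaptationArguments keeps the Hugging Face TrainingArguments default lr_scheduler_type="linear" (linear warmup, then linear decay to zero); the function linear_warmup_lr is hypothetical and only reproduces that formula.

```python
def linear_warmup_lr(step: int, base_lr: float = 5e-5, warmup_steps: int = 1000,
                     max_steps: int = 100_000) -> float:
    """Learning rate at optimizer `step` under linear warmup followed by linear decay.

    Assumes the default linear schedule of Hugging Face TrainingArguments.
    """
    if step < warmup_steps:
        # Warmup: grow linearly from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay: shrink linearly from base_lr at warmup_steps to 0 at max_steps
    return base_lr * max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))


for step in (0, 500, 1000, 50_500, 100_000):
    print(step, linear_warmup_lr(step))
```

Under the same assumption, a positive max_steps takes precedence over num_train_epochs=50 in TrainingArguments, and eval_steps=100 allows up to 1,000 evaluation rounds; the stopping_strategy may end training earlier once all objectives have converged.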
The full training script is in train_mt5_qa_en+cs.py; reproducing it requires first preprocessing the Czech SQAD data with parse_czech_squad.py.
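Training "in parallel" on Czech and English data means that batches from both languages are mixed during a single training run. As a simplified illustration only (not Adaptor's own schedule implementation, whose details may differ), one batch can be taken from each language in turn until all are exhausted; the helper interleave_batches and the source names "cs" and "en" are hypothetical.

```python
from typing import Dict, Iterable, Iterator, Tuple

_SENTINEL = object()


def interleave_batches(sources: Dict[str, Iterable]) -> Iterator[Tuple[str, object]]:
    """Yield (source_name, batch) pairs, taking one batch from each source in turn.

    Exhausted sources are dropped; iteration ends when every source is exhausted.
    """
    iterators = {name: iter(batches) for name, batches in sources.items()}
    while iterators:
        for name in list(iterators):
            batch = next(iterators[name], _SENTINEL)
            if batch is _SENTINEL:
                del iterators[name]
            else:
                yield name, batch


for lang, batch in interleave_batches({"cs": ["cs-batch-1", "cs-batch-2"],
                                        "en": ["en-batch-1", "en-batch-2", "en-batch-3"]}):
    print(lang, batch)
```

Alternating in this way keeps each language represented throughout training instead of fine-tuning on one language after the other.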