---
license: apache-2.0
base_model: google/mt5-small
tags:
- generated_from_trainer
metrics:
- rouge
- bleu
- meteor
datasets:
- natural_questions
model-index:
- name: mt5-small
  results:
  - task:
      type: question-answering
      name: Question answering from context
    dataset:
      type: natural_questions
      name: Adapted Natural Questions
    metrics:
    - type: bleu
      value: 34.1596
      name: BLEU
      verified: true
    - type: rouge
      value: 44.4366
      name: ROUGE-1
      verified: true
    - type: rouge
      value: 38.8202
      name: ROUGE-2
      verified: true
    - type: rouge
      value: 43.113
      name: ROUGE-L
      verified: true
    - type: rouge
      value: 43.1423
      name: ROUGE-Lsum
      verified: true
    - type: meteor
      value: 0.4049
      name: METEOR
      verified: true

---

# mt5-small

This model is a fine-tuned version of [google/mt5-small](https://huggingface.co/google/mt5-small) on an enhanced version of the Natural Questions dataset.
It achieves the following results on the evaluation set (the step-1750 checkpoint in the training results table):
- Validation loss: 0.7291
- ROUGE-1: 44.4366
- ROUGE-2: 38.8202
- ROUGE-L: 43.113
- ROUGE-Lsum: 43.1423
- BLEU: 34.1596
- Gen Len (mean generated length, in tokens): 12.6724
- METEOR: 0.4049
- True negatives: 69.7281
- False negatives: 10.4037
- Cosine similarity: 0.763

## Model description

This model is fine-tuned for long-form, closed-domain question answering, that is, answering a question from a supplied context. It was trained on a heavily refined version of [Google's Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/).

Answers to the questions were rewritten using [OpenAI's GPT-3.5 Turbo model](https://platform.openai.com/docs/models).

Please see [the following repo](https://github.com/pointonjoel/MSc-Diss) for all code and adaptations.

## Intended uses & limitations

The model expects the context and question to be combined into a single input in the following format:

`[CONTEXT] </s> [QUESTION]`
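
A minimal sketch of assembling this input by hand, assuming the separator shown in the format above is mT5's `</s>` end-of-sequence token surrounded by single spaces. `format_input` and `SEPARATOR` are hypothetical names for illustration, not part of the model's code.

```python
# Hypothetical helper: builds the "[CONTEXT] </s> [QUESTION]" input string.
# Assumes the separator is mT5's end-of-sequence token "</s>" with single spaces around it.
SEPARATOR = "</s>"


def format_input(context: str, question: str) -> str:
    """Join a context passage and a question into one model input string."""
    context = context.strip()
    question = question.strip()
    if not context or not question:
        raise ValueError("both context and question must be non-empty")
    return f"{context} {SEPARATOR} {question}"


print(format_input("Paris is the capital of France.", "What is the capital of France?"))
# -> Paris is the capital of France. </s> What is the capital of France?
```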

It is trained to indicate when a question cannot be answered from the provided context.

It occasionally produces false negatives and false positives (see Training results), so all answers should be checked before use.

## Training and evaluation data

The model is trained on the Natural Questions dataset, with answers refined using GPT-3.5 Turbo. It is evaluated with BLEU, ROUGE, METEOR, and cosine similarity, alongside true and false negatives.
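
Cosine similarity compares two answers as vectors: their dot product divided by the product of their lengths, giving 1.0 for vectors pointing the same way and 0.0 for orthogonal ones. The card does not say which embedding model produced the vectors behind the cosine-similarity score, so this sketch only shows the calculation itself on hypothetical, hand-written vectors; in practice each answer would first be encoded by some sentence-embedding model.

```python
import math
from typing import Sequence


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Return u.v / (|u| * |v|) for two equal-length, non-zero vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_u * norm_v)


# Hypothetical embeddings of a generated answer and its reference answer.
generated = [0.2, 0.7, 0.1]
reference = [0.25, 0.6, 0.3]
print(round(cosine_similarity(generated, reference), 4))
```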

## Usage
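```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the fine-tuned mT5 model and its tokenizer from the Hugging Face Hub
model_name = "psxjp5/mt5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Encode the context and question as a text pair;
# the mT5 tokenizer appends </s> after each segment
context = "The Eiffel Tower was completed in 1889 for the World's Fair in Paris."
question = "When was the Eiffel Tower completed?"
input_ids = tokenizer(context, question, return_tensors="pt").input_ids

# Generate an answer and decode it without special tokens
outputs = model.generate(input_ids, max_new_tokens=150)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```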

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 0.001
- train_batch_size: 16
- eval_batch_size: 16
- seed: 9
- gradient_accumulation_steps: 8
- total_train_batch_size: 128
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 20
- weight_decay: 0.007
- dropout: 0.4
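
The listed total_train_batch_size follows from the per-device batch size and gradient accumulation: 16 × 8 = 128 examples per optimizer step, assuming training ran on a single device. The sketch below reproduces that arithmetic and the shape of a linear learning-rate schedule, using the same formula as `get_linear_schedule_with_warmup` in Transformers. The card lists no warmup, so zero warmup steps is an assumption, and the 3,500-step total is a hypothetical value for illustration (roughly 175 steps per epoch, as in the training results table, × 20 epochs).

```python
def effective_batch_size(per_device: int, grad_accum_steps: int, num_devices: int = 1) -> int:
    """Number of examples contributing to each optimizer step."""
    return per_device * grad_accum_steps * num_devices


def linear_lr(step: int, base_lr: float, total_steps: int, warmup_steps: int = 0) -> float:
    """Optional linear warmup, then linear decay from base_lr to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


# Values from the hyperparameter list; num_devices=1 is an assumption.
print(effective_batch_size(16, 8))  # -> 128

# Hypothetical schedule length: ~175 steps/epoch x 20 epochs, no warmup (assumed).
TOTAL_STEPS = 175 * 20
for step in (0, 875, 1750, TOTAL_STEPS):
    print(step, linear_lr(step, base_lr=1e-3, total_steps=TOTAL_STEPS))
```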

### Training results

| Training Loss | Epoch | Step | Validation Loss | ROUGE-1 | ROUGE-2 | ROUGE-L | ROUGE-Lsum | BLEU    | Gen Len | METEOR | True negatives | False negatives | Cosine Sim |
|:-------------:|:-----:|:----:|:---------------:|:-------:|:-------:|:-------:|:---------:|:-------:|:-------:|:------:|:--------------:|:---------------:|:----------:|
| 2.5724        | 1.0   | 175  | 0.9876          | 18.7781 | 15.6002 | 18.22   | 18.2686   | 7.6676  | 7.7661  | 0.1628 | 72.8701        | 56.677          | 0.4003     |
| 1.1469        | 1.99  | 350  | 0.8580          | 36.8209 | 31.2514 | 35.5008 | 35.5462   | 25.7137 | 12.0014 | 0.3311 | 62.8399        | 20.3934         | 0.66       |
| 0.9468        | 2.99  | 525  | 0.7997          | 40.4128 | 34.716  | 39.0867 | 39.0972   | 29.3028 | 12.4287 | 0.3656 | 63.4441        | 15.295          | 0.7114     |
| 0.8129        | 3.98  | 700  | 0.7733          | 42.6764 | 36.7266 | 41.2465 | 41.2833   | 32.0644 | 12.9002 | 0.3871 | 62.1752        | 11.413          | 0.7425     |
| 0.7228        | 4.98  | 875  | 0.7483          | 42.9082 | 36.957  | 41.482  | 41.5233   | 32.4942 | 12.8866 | 0.3906 | 63.3233        | 11.5166         | 0.747      |
| 0.6493        | 5.97  | 1050 | 0.7293          | 40.3205 | 34.9632 | 39.1111 | 39.1168   | 28.8249 | 11.6867 | 0.3674 | 73.8973        | 17.9865         | 0.7068     |
| 0.5883        | 6.97  | 1225 | 0.7172          | 42.7342 | 37.0855 | 41.4069 | 41.424    | 32.1296 | 12.48   | 0.3887 | 70.0302        | 12.7847         | 0.7392     |
| 0.5409        | 7.96  | 1400 | 0.7387          | 44.6657 | 38.8426 | 43.3276 | 43.3496   | 34.4773 | 12.9395 | 0.4084 | 66.3444        | 9.5238          | 0.7658     |
| 0.5035        | 8.96  | 1575 | 0.7330          | 43.4925 | 38.0013 | 42.2697 | 42.2372   | 32.6131 | 12.2789 | 0.3979 | 72.6284        | 12.8364         | 0.7```1     |
| 0.4652        | 9.95  | 1750 | 0.7291          | 44.4366 | 38.8202 | 43.113  | 43.1423   | 34.1596 | 12.6724 | 0.4049 | 69.7281        | 10.4037         | 0.763      |


### Framework versions

- Transformers 4.31.0
- Pytorch 2.0.1+cu118
- Datasets 2.13.1
- Tokenizers 0.13.3