---
license: mit
language:
- en
tags:
- gpu
---

# Text Summarization Model with Seq2Seq and LSTM

This model is a sequence-to-sequence (seq2seq) model for text summarization. It uses a bidirectional LSTM encoder and an LSTM decoder to generate summaries from input articles. The model was trained on a dataset with sequences of length up to 800 tokens.

## Model Architecture

### Encoder
- **Input Layer:** Takes input sequences of length `max_len_article`.
- **Embedding Layer:** Converts input sequences into dense vectors of size 100.
- **Bidirectional LSTM Layer:** Processes the embedded input, capturing dependencies in both the forward and backward directions, and outputs the hidden and cell states of both directions.
- **State Concatenation:** Concatenates the forward and backward hidden and cell states to form the final encoder states, as shown in the sketch below.
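The following is a minimal Keras sketch of this encoder, assuming the layer sizes implied by the model summary below (embedding dimension 100, 100 LSTM units per direction, and an article vocabulary of 476,199 tokens inferred from the embedding parameter count); the variable names are illustrative, not taken from the original training code.

```python
from tensorflow.keras.layers import Input, Embedding, LSTM, Bidirectional, Concatenate

max_len_article = 800   # maximum article length used during training
x_voc = 476199          # article vocabulary size inferred from the embedding parameter count
embedding_dim = 100
latent_dim = 100        # LSTM units per direction (2 x 100 = 200 after concatenation)

# Embed the article tokens and run a bidirectional LSTM over them,
# returning the final hidden and cell states of both directions.
encoder_inputs = Input(shape=(max_len_article,))
enc_emb = Embedding(x_voc, embedding_dim)(encoder_inputs)
encoder_outputs, forward_h, forward_c, backward_h, backward_c = Bidirectional(
    LSTM(latent_dim, return_state=True)
)(enc_emb)

# Combine the forward and backward states into the final encoder states.
state_h = Concatenate()([forward_h, backward_h])
state_c = Concatenate()([forward_c, backward_c])
```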
### Decoder

- **Input Layer:** Takes target sequences of variable length.
- **Embedding Layer:** Converts target sequences into dense vectors of size 100.
- **LSTM Layer:** Processes the embedded target sequences with an LSTM whose initial states are set to the encoder states.
- **Dense Layer:** Applies a Dense layer with softmax activation to produce a probability distribution over the vocabulary at each step, as shown in the sketch below.
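Continuing the encoder sketch above, a minimal sketch of the decoder, assuming a summary vocabulary of 155,158 tokens (inferred from the dense output layer in the model summary below); it reuses `embedding_dim`, `latent_dim`, `state_h`, and `state_c` from the encoder sketch.

```python
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense

y_voc = 155158          # summary vocabulary size inferred from the dense output layer

# Embed the target (summary) tokens and decode them with an LSTM whose
# initial state is the concatenated encoder state computed above.
decoder_inputs = Input(shape=(None,))
dec_emb = Embedding(y_voc, embedding_dim)(decoder_inputs)
decoder_lstm = LSTM(2 * latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(dec_emb, initial_state=[state_h, state_c])

# Project every decoder step onto the summary vocabulary with a softmax.
decoder_dense = Dense(y_voc, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
```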
### Model Summary

| Layer (type) | Output Shape | Param # | Connected to |
|--------------|--------------|---------|--------------|
| input_1 (InputLayer) | [(None, 800)] | 0 | - |
| embedding (Embedding) | (None, 800, 100) | 47,619,900 | input_1[0][0] |
| bidirectional (Bidirectional) | [(None, 200), (None, 100), (None, 100), (None, 100), (None, 100)] | 160,800 | embedding[0][0] |
| input_2 (InputLayer) | [(None, None)] | 0 | - |
| embedding_1 (Embedding) | (None, None, 100) | 15,515,800 | input_2[0][0] |
| concatenate (Concatenate) | (None, 200) | 0 | bidirectional[0][1], bidirectional[0][3] |
| concatenate_1 (Concatenate) | (None, 200) | 0 | bidirectional[0][2], bidirectional[0][4] |
| lstm (LSTM) | [(None, None, 200), (None, 200), (None, 200)] | 240,800 | embedding_1[0][0], concatenate[0][0], concatenate_1[0][0] |
| dense (Dense) | (None, None, 155158) | 31,186,758 | lstm[0][0] |

Total params: 94,724,058

Trainable params: 94,724,058

Non-trainable params: 0
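The table above is the kind of layer listing that Keras's `model.summary()` prints. A minimal sketch of assembling the encoder and decoder sketched earlier into that combined training model:

```python
from tensorflow.keras.models import Model

# Combine the encoder and decoder sketched above into one training model;
# printing its summary yields a layer table like the one shown here.
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.summary()
```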
## Training

The model was trained on sequences of up to 800 tokens using the following configuration (a compile-and-fit sketch follows the list):

- **Optimizer:** Adam
- **Loss Function:** Categorical Crossentropy
- **Metrics:** Accuracy
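A minimal sketch of compiling and fitting the model assembled above with this configuration; `train_articles`, `train_summaries_in`, `train_targets`, and the validation arrays are hypothetical placeholders, since the card does not describe the data pipeline, and categorical crossentropy assumes one-hot encoded summary targets.

```python
# Compile with the configuration listed above.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Placeholder arrays: padded article ids, decoder input ids, and one-hot summary targets.
model.fit(
    [train_articles, train_summaries_in], train_targets,
    validation_data=([val_articles, val_summaries_in], val_targets),
    epochs=5,
)
```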
### Training Loss and Validation Loss

| Epoch | Training Loss | Validation Loss | Time per Epoch (s) |
|-------|---------------|-----------------|--------------------|
| 1 | 3.9044 | 0.4543 | 3087 |
| 2 | 0.3429 | 0.0976 | 3091 |
| 3 | 0.1054 | 0.0427 | 3096 |
| 4 | 0.0490 | 0.0231 | 3099 |
| 5 | 0.0203 | 0.0148 | 3098 |
### Test Loss

| Test Loss |
|----------------------|
| 0.014802712015807629 |
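For reference, a sketch of how such a test loss could be computed, assuming hypothetical held-out arrays shaped like the training data:

```python
# test_articles, test_summaries_in, and test_targets are hypothetical held-out arrays.
test_loss, test_accuracy = model.evaluate([test_articles, test_summaries_in], test_targets)
print(test_loss)
```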
## Usage

*(I will update this section soon.)*

To use this model, you can load it with the Hugging Face Transformers library:

```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# Replace 'your-model-name' with this model's repository id.
tokenizer = AutoTokenizer.from_pretrained('your-model-name')
model = TFAutoModelForSeq2SeqLM.from_pretrained('your-model-name')

# Summarize an article; inputs are truncated to the 800-token training length.
article = "Your input text here."
inputs = tokenizer.encode("summarize: " + article, return_tensors="tf", max_length=800, truncation=True)
summary_ids = model.generate(inputs, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

print(summary)
```
|