---
license: mit
language:
- en
tags:
- gpu
---
# Text Summarization Model with Seq2Seq and LSTM
This model is a sequence-to-sequence (seq2seq) model for text summarization. It uses a bidirectional LSTM encoder and an LSTM decoder to generate summaries from input articles. The model was trained on a dataset with sequences of length up to 800 tokens.
## Dataset
The CNN-DailyMail News Text Summarization dataset from Kaggle.
## Model Architecture
### Encoder
- **Input Layer:** Takes article token sequences of length `max_len_article` (800 in the trained model).
- **Embedding Layer:** Maps each article token to a 100-dimensional dense vector (input vocabulary of 476,199 tokens, per the embedding parameter count).
- **Bidirectional LSTM Layer:** Processes the embedded input with 100 units per direction, capturing dependencies in both the forward and backward directions. It returns the final hidden and cell states of both directions.
- **State Concatenation:** Concatenates the forward and backward hidden states, and separately the forward and backward cell states, into the 200-dimensional encoder states.
### Decoder
- **Input Layer:** Takes target sequences of variable length.
- **Embedding Layer:** Maps each summary token to a 100-dimensional dense vector (target vocabulary of 155,158 tokens).
- **LSTM Layer:** Processes the embedded target sequence with a 200-unit LSTM whose initial hidden and cell states are set to the concatenated encoder states.
- **Dense Layer:** A Dense layer with softmax activation produces, at each decoder time step, a probability distribution over the 155,158-token target vocabulary.
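The architecture above can be sketched in `tf.keras` as follows. This is a reconstruction from the layer list and the model summary, not the author's training code: the embedding size (100), the LSTM sizes (100 units per encoder direction, 200 decoder units) and the vocabulary sizes (476,199 article tokens, 155,158 summary tokens) are inferred from the parameter counts. The inference models and `greedy_decode` are a hypothetical, commonly used way to run such a decoder one token at a time; `start_id` and `end_id` are assumed special-token IDs, since the card does not describe the tokenizer.

```python
import numpy as np
from tensorflow import keras

layers = keras.layers


def build_seq2seq(article_vocab=476_199, summary_vocab=155_158,
                  max_len_article=800, embed_dim=100, enc_units=100):
    """Rebuild the training model plus encoder/decoder models for inference.

    Default sizes are inferred from the Model Summary table (assumption).
    """
    # Encoder: fixed-length article token IDs -> embeddings -> BiLSTM states
    enc_in = keras.Input(shape=(max_len_article,), name="input_1")
    enc_emb = layers.Embedding(article_vocab, embed_dim, name="embedding")(enc_in)
    _, fwd_h, fwd_c, bwd_h, bwd_c = layers.Bidirectional(
        layers.LSTM(enc_units, return_state=True), name="bidirectional")(enc_emb)
    state_h = layers.Concatenate(name="concatenate")([fwd_h, bwd_h])
    state_c = layers.Concatenate(name="concatenate_1")([fwd_c, bwd_c])

    # Decoder: summary token IDs, LSTM initialised with the encoder states
    dec_in = keras.Input(shape=(None,), name="input_2")
    dec_embedding = layers.Embedding(summary_vocab, embed_dim, name="embedding_1")
    dec_lstm = layers.LSTM(2 * enc_units, return_sequences=True,
                           return_state=True, name="lstm")
    dec_dense = layers.Dense(summary_vocab, activation="softmax", name="dense")
    dec_seq, _, _ = dec_lstm(dec_embedding(dec_in), initial_state=[state_h, state_c])
    model = keras.Model([enc_in, dec_in], dec_dense(dec_seq))

    # Hypothetical inference models that share the layers defined above
    encoder_model = keras.Model(enc_in, [state_h, state_c])
    step_in = keras.Input(shape=(1,), name="dec_step")
    h_in = keras.Input(shape=(2 * enc_units,), name="dec_h")
    c_in = keras.Input(shape=(2 * enc_units,), name="dec_c")
    step_seq, h_out, c_out = dec_lstm(dec_embedding(step_in), initial_state=[h_in, c_in])
    decoder_model = keras.Model([step_in, h_in, c_in], [dec_dense(step_seq), h_out, c_out])
    return model, encoder_model, decoder_model


def greedy_decode(encoder_model, decoder_model, article_ids,
                  start_id, end_id, max_steps=100):
    """Greedy decoding; start_id and end_id are assumed special-token IDs."""
    h, c = encoder_model.predict(article_ids, verbose=0)
    token = np.array([[start_id]])
    summary_ids = []
    for _ in range(max_steps):
        probs, h, c = decoder_model.predict([token, h, c], verbose=0)
        next_id = int(np.argmax(probs[0, -1]))  # most likely next token
        if next_id == end_id:
            break
        summary_ids.append(next_id)
        token = np.array([[next_id]])  # feed the prediction back in
    return summary_ids
```

Calling `build_seq2seq()` with its defaults creates a freshly initialised model of the same shape (about 94.7 million parameters); producing real summaries would require the trained weights and the original tokenizer.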
### Model Summary
| Layer (type) | Output Shape | Param # | Connected to |
|-------------------------------|--------------------------------------------------------------------|------------|---------------------------------------------------------|
| input_1 (InputLayer) | [(None, 800)] | 0 | - |
| embedding (Embedding) | (None, 800, 100) | 47,619,900 | input_1[0][0] |
| bidirectional (Bidirectional) | [(None, 200), (None, 100), (None, 100), (None, 100), (None, 100)] | 160,800 | embedding[0][0] |
| input_2 (InputLayer) | [(None, None)] | 0 | - |
| embedding_1 (Embedding) | (None, None, 100) | 15,515,800 | input_2[0][0] |
| concatenate (Concatenate) | (None, 200) | 0 | bidirectional[0][1], bidirectional[0][3] |
| concatenate_1 (Concatenate) | (None, 200) | 0 | bidirectional[0][2], bidirectional[0][4] |
| lstm (LSTM) | [(None, None, 200), (None, 200), (None, 200)] | 240,800 | embedding_1[0][0], concatenate[0][0], concatenate_1[0][0] |
| dense (Dense) | (None, None, 155158) | 31,186,758 | lstm[0][0] |

- **Total params:** 94,724,058
- **Trainable params:** 94,724,058
- **Non-trainable params:** 0
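The parameter counts in the table can be checked by hand. The sketch below assumes Keras's standard LSTM layout (four gates, each with an input kernel, a recurrent kernel and one bias vector) and recovers the two vocabulary sizes by dividing the embedding parameter counts by the embedding size of 100.

```python
def lstm_params(input_dim, units):
    # 4 gates x (input kernel + recurrent kernel + bias)
    return 4 * (input_dim * units + units * units + units)

embed_dim = 100
enc_embedding = 47_619_900
dec_embedding = 15_515_800
article_vocab = enc_embedding // embed_dim      # 476,199 article tokens
summary_vocab = dec_embedding // embed_dim      # 155,158 summary tokens

bilstm = 2 * lstm_params(embed_dim, 100)        # two directions: 160,800
dec_lstm = lstm_params(embed_dim, 200)          # 240,800
dense = 200 * summary_vocab + summary_vocab     # weights + bias: 31,186,758

total = enc_embedding + bilstm + dec_embedding + dec_lstm + dense
print(f"{total:,}")  # 94,724,058
```

All parameters are trainable, so the trainable count equals this total.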
## Training
The model was trained on article sequences of up to 800 tokens with the following configuration:
- **Optimizer:** Adam
- **Loss Function:** Categorical Crossentropy
- **Metrics:** Accuracy
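With categorical crossentropy, Keras expects the decoder targets as one-hot vectors over the summary vocabulary. The card does not say how the targets were built; a common setup for an encoder-decoder like this one, assumed here and not confirmed by the card, is teacher forcing: the decoder input is the summary without its last token, and the target is the same summary shifted one step left. `make_decoder_arrays` is a hypothetical helper that assumes the summaries are already tokenized, include start/end tokens, and are right-padded with ID 0.

```python
import numpy as np


def make_decoder_arrays(summary_ids, vocab_size):
    """Hypothetical helper: teacher-forcing decoder inputs and one-hot targets.

    summary_ids: int array of shape (batch, T) that already includes the
    (assumed) start and end tokens and is right-padded with 0.
    """
    decoder_input = summary_ids[:, :-1]   # tokens fed to the decoder
    target_ids = summary_ids[:, 1:]       # token expected at each step
    # One-hot encode without building a vocab_size x vocab_size identity matrix
    decoder_target = np.zeros(target_ids.shape + (vocab_size,), dtype="float32")
    b, t = np.indices(target_ids.shape)
    decoder_target[b, t, target_ids] = 1.0  # padding (ID 0) becomes a one-hot at index 0
    return decoder_input, decoder_target


ids = np.array([[1, 5, 6, 2, 0]])  # start=1, end=2 are illustrative IDs
dec_in, dec_tgt = make_decoder_arrays(ids, vocab_size=8)
print(dec_in)                      # [[1 5 6 2]]
print(dec_tgt.argmax(axis=-1))     # [[5 6 2 0]]
```

At this model's vocabulary size of 155,158, each target position holds 155,158 float32 values (about 0.6 MB), so the target arrays grow quickly with batch size and summary length. Keras's `sparse_categorical_crossentropy` loss accepts the integer `target_ids` directly and avoids this, although the card reports using categorical crossentropy.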
### Training Loss and Validation Loss
| Epoch | Training Loss | Validation Loss | Time per Epoch (s) |
|-------|---------------|-----------------|--------------------|
| 1 | 3.9044 | 0.4543 | 3087 |
| 2 | 0.3429 | 0.0976 | 3091 |
| 3 | 0.1054 | 0.0427 | 3096 |
| 4 | 0.0490 | 0.0231 | 3099 |
| 5 | 0.0203 | 0.0148 | 3098 |
### Test Loss
| Test Loss |
|-----------|
| 0.0148 |
## Usage (work in progress)
This is a custom Keras encoder-decoder model, not a Transformers architecture, so it cannot be loaded with `TFAutoModel` or `TFAutoModelForSeq2SeqLM`, and it has no `generate` method. Assuming the model was uploaded with the Keras integration in `huggingface_hub`, it can be loaded as follows:
```python
# Assumption: the model was pushed with huggingface_hub's Keras integration
# (push_to_hub_keras). 'your-model-name' is a placeholder repo ID.
from huggingface_hub import from_pretrained_keras

model = from_pretrained_keras('your-model-name')
model.summary()  # should list the layers shown in the Model Summary table

# Summarizing an article also requires the tokenizer used during training
# (not published yet) to turn text into the 800-token input sequence, and a
# token-by-token decoding loop driven by the encoder states.
```