Upload folder using huggingface_hub
- .DS_Store +0 -0
- .gitignore +163 -0
- README.md +148 -0
- az_tokenizer.json +0 -0
- az_wiki_data.json +0 -0
- collect_data.py +127 -0
- generate.py +68 -0
- prepare_data.py +124 -0
- push_to_hf.py +17 -0
- requirements.txt +42 -0
- train.py +274 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitignore
ADDED
@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.vscode
/wandb
# C extensions
*.so
best_model.pt
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
README.md
ADDED
@@ -0,0 +1,148 @@
# Azerbaijani Language GPT Model

This repository contains an implementation of a GPT (Generative Pre-trained Transformer) model trained on Azerbaijani Wikipedia data. The model is designed to understand and generate Azerbaijani text.

## Project Structure
```
.
├── README.md
├── az_tokenizer.json   # Trained tokenizer for Azerbaijani text
├── az_wiki_data.json   # Collected Wikipedia data
├── best_model.pt       # Saved state of the best trained model
├── collect_data.py     # Script for collecting Wikipedia articles
├── generate.py         # Text generation script using the trained model
├── prepare_data.py     # Data preprocessing and tokenizer training
├── requirements.txt    # Project dependencies
└── train.py            # GPT model training script
```

## Setup

1. Create and activate a virtual environment:
```bash
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate
```

2. Install dependencies based on your system.

For Mac with Apple Silicon (M1/M2):
```bash
# Install PyTorch for Apple Silicon
pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

# Install other required packages
pip install transformers wikipedia-api beautifulsoup4 requests
```

For other systems:
```bash
pip install -r requirements.txt
```

## Platform-Specific Notes

### Apple Silicon (M1/M2) Macs
- Uses MPS (Metal Performance Shaders) for acceleration
- Optimized memory management for Apple Silicon
- May require specific PyTorch nightly builds

### CUDA-enabled GPUs
- Automatically utilizes CUDA if available
- Implements mixed precision training
- Memory optimization through gradient accumulation
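The committed `train.py` moves the model to CUDA directly. For a portable setup, a small device-selection helper (a hypothetical sketch, not part of this upload) could fall back to MPS or CPU:

```python
import torch

def pick_device() -> torch.device:
    """Prefer CUDA, then Apple MPS, then CPU (hypothetical helper, not in train.py)."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
print(f"Using device: {device}")
```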
## Data Collection

1. Collect Azerbaijani Wikipedia articles:
```bash
python collect_data.py
```
This will save the articles to `az_wiki_data.json`.

2. Prepare the data and train the tokenizer:
```bash
python prepare_data.py
```
This will create `az_tokenizer.json`.
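To sanity-check the collected corpus before training, you can load the JSON directly; a small sketch assuming the layout written by `collect_data.py` (a dict of article title → `{title, text, url, length}`):

```python
import json

with open("az_wiki_data.json", encoding="utf-8") as f:
    pages = json.load(f)

print(f"{len(pages)} articles, {sum(p['length'] for p in pages.values())} characters total")
for title, page in list(pages.items())[:3]:
    print(f"- {title}: {page['length']} chars, {page['url']}")
```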
## Training

Train the GPT model:
```bash
python train.py
```

The training script:
- Uses mixed precision training
- Can be extended with gradient accumulation on memory-constrained GPUs (see the sketch below)
- Saves model checkpoints every 5 epochs
- Saves the best model based on validation loss
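The committed `train.py` steps the optimizer on every mini-batch. A rough sketch of how the inner loop could accumulate gradients over several mini-batches instead (the `accum_steps` setting is an assumption and is not part of `GPTConfig`):

```python
import torch

def train_one_epoch(model, loader, optimizer, scaler, config, accum_steps=4, device="cuda"):
    """Hypothetical variant of train.py's training loop with gradient accumulation."""
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for it, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        with torch.amp.autocast(device_type="cuda"):
            _, loss = model(x, y)
        # Divide the loss so the accumulated gradients match one large batch
        scaler.scale(loss / accum_steps).backward()
        if (it + 1) % accum_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
```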
## Model Architecture

- Transformer-based architecture
- Configuration adjustable via `GPTConfig` in `train.py`; the committed defaults are:
  - Embedding dimension: 768
  - Attention heads: 12
  - Layers: 8
  - Block size: 256
  - Batch size: 8
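To train a smaller model on limited memory, the same fields can be overridden when the config is constructed; a sketch using the `GPTConfig` parameters defined in `train.py` (the values below are illustrative):

```python
from train import GPT, GPTConfig

# Smaller configuration for memory-constrained hardware (illustrative values)
config = GPTConfig(n_embd=512, n_head=8, n_layer=6, block_size=128, batch_size=4)
model = GPT(config)
print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
```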
## Text Generation

Generate text using the trained model:
```bash
python generate.py
```

The `generate.py` script:
- Loads the trained model and tokenizer
- Generates text based on a user-provided prompt
- Implements sampling strategies such as nucleus sampling and temperature scaling
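The same functions can also be used programmatically; a minimal sketch calling `generate_text` as defined in `generate.py`:

```python
from generate import load_model_and_tokenizer, generate_text

model, tokenizer = load_model_and_tokenizer()
text = generate_text(model, tokenizer, "Azərbaycanın tarixi", max_new_tokens=80, p=0.9)
print(text)
```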
## Files Description

- `collect_data.py`: Collects articles from Azerbaijani Wikipedia using categories such as history, culture, literature, and geography
- `prepare_data.py`: Preprocesses the text and trains a BPE tokenizer
- `train.py`: Contains the GPT model implementation and training loop
- `generate.py`: Generates text using the trained model and sampling strategies
- `az_wiki_data.json`: Collected and preprocessed Wikipedia articles
- `az_tokenizer.json`: Trained BPE tokenizer for Azerbaijani text
- `best_model.pt`: Saved state of the best model during training

## Training Output

The training script saves:
- The best model state as `best_model.pt`
- Regular checkpoints as `checkpoint_epoch_N.pt`
- The interrupted training state as `interrupt_checkpoint.pt`

## Memory Requirements

- Recommended: GPU with at least 8 GB of memory
- For larger models: use gradient accumulation steps
- Batch size and model size are adjustable based on available memory

## Troubleshooting

Common issues:

1. Memory errors:
   - Reduce the batch size
   - Enable gradient accumulation
   - Reduce the model size
   - Clear the GPU cache regularly (see the snippet after this list)

2. PyTorch installation:
   - For Apple Silicon: use the nightly build command above
   - For CUDA: install the appropriate CUDA version

3. Data loading:
   - Reduce the number of workers if you get process errors
   - Enable pinned memory for faster host-to-device transfer
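Clearing the cache is what `train.py` already does at the start of training and on each iteration; the standalone equivalent is:

```python
import gc
import torch

# Release Python garbage and cached CUDA allocations
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```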
## Future Improvements

- [ ] Implement model evaluation metrics
- [ ] Add data augmentation techniques
- [ ] Implement distributed training
- [ ] Add model compression techniques
az_tokenizer.json
ADDED
The diff for this file is too large to render.
az_wiki_data.json
ADDED
The diff for this file is too large to render.
collect_data.py
ADDED
@@ -0,0 +1,127 @@
import wikipediaapi
import json
from tqdm import tqdm
import time

def get_wiki_pages(categories=["Azərbaycan tarixi", "Azərbaycan mədəniyyəti",
                               "Azərbaycan ədəbiyyatı", "Azərbaycan coğrafiyası"],
                   min_length=500, max_pages=1000):
    """
    Recursively collect substantial Azerbaijani Wikipedia pages from multiple categories
    """
    wiki = wikipediaapi.Wikipedia(
        language='az',
        extract_format=wikipediaapi.ExtractFormat.WIKI,
        user_agent='AzGPTDataCollector/1.0'
    )

    collected_pages = {}
    visited_pages = set()

    def collect_pages(category_title):
        if len(collected_pages) >= max_pages:
            return

        category = wiki.page(f"Kateqoriya:{category_title}")
        if not category.exists():
            print(f"Category not found: {category_title}")
            return

        # First, process all articles in this category
        for member in category.categorymembers.values():
            if len(collected_pages) >= max_pages:
                return

            if member.title in visited_pages:
                continue

            visited_pages.add(member.title)

            # Skip if it's a category or template page
            if member.title.startswith('Kateqoriya:') or member.title.startswith('Şablon:'):
                continue

            # Skip if content is too short
            if len(member.text) < min_length:
                continue

            collected_pages[member.title] = {
                'title': member.title,
                'text': member.text,
                'url': member.fullurl,
                'length': len(member.text)
            }
            print(f"Collected: {member.title} ({len(member.text)} chars)")

            # Delay to avoid hitting API limits
            time.sleep(0.1)

        # Then process subcategories
        for subcategory in category.categorymembers.values():
            if subcategory.title.startswith('Kateqoriya:'):
                collect_pages(subcategory.title.replace('Kateqoriya:', ''))

    # Start collection from each category
    for category in categories:
        print(f"\nStarting collection from category: {category}")
        collect_pages(category)

    return collected_pages

def preprocess_text(text):
    """
    Enhanced text preprocessing for Azerbaijani text
    """
    # Remove extra whitespace
    text = ' '.join(text.split())

    # Add space after punctuation if missing
    for punct in '.!?،؛:()[]{}«»':
        text = text.replace(punct, punct + ' ')

    # Fix common OCR errors in Azerbaijani text
    replacements = {
        'i': 'ı',  # Replace dotted i with dotless ı (note: applied to every occurrence)
        'І': 'I',
        '...': '…',
    }
    for old, new in replacements.items():
        text = text.replace(old, new)

    return text

def save_dataset(pages, output_file='az_wiki_data.json'):
    """
    Save collected pages to a JSON file
    """
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(pages, f, ensure_ascii=False, indent=2)
    print(f"Saved {len(pages)} pages to {output_file}")

def main():
    # Collect pages with minimum length requirement
    print("Starting data collection...")
    pages = get_wiki_pages(min_length=500, max_pages=100)  # 500 chars minimum length

    # Preprocess and save
    print("\nPreprocessing and saving data...")
    for title in pages:
        pages[title]['text'] = preprocess_text(pages[title]['text'])

    save_dataset(pages)

    # Print statistics
    total_chars = sum(page['length'] for page in pages.values())
    if pages:
        print("\nCollection complete!")
        print(f"Total pages: {len(pages)}")
        print(f"Total characters: {total_chars}")
        print(f"Average page length: {total_chars / len(pages):.2f} characters")

        # Print some titles as examples
        print("\nSample of collected articles:")
        for title in list(pages.keys())[:5]:
            print(f"- {title} ({pages[title]['length']} chars)")

if __name__ == "__main__":
    main()
generate.py
ADDED
@@ -0,0 +1,68 @@
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from train import GPT, GPTConfig  # The model definition lives in train.py

def nucleus_sampling(logits, p=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token above the threshold is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    logits[sorted_indices[sorted_indices_to_remove]] = -float('Inf')
    probabilities = F.softmax(logits, dim=-1)
    next_token_id = torch.multinomial(probabilities, num_samples=1).item()
    return next_token_id

def load_model_and_tokenizer():
    # Load the model configuration, weights, and tokenizer
    config = GPTConfig()
    model = GPT(config)
    model.load_state_dict(torch.load('best_model.pt', map_location=torch.device('cpu')))
    model.eval()  # Set model to evaluation mode
    tokenizer = Tokenizer.from_file("az_tokenizer.json")  # Load tokenizer
    return model, tokenizer

def apply_repetition_penalty(logits, input_ids, penalty=1.2):
    # Penalize the logits of tokens that have already been generated
    for token_id in set(input_ids):
        logits[0, token_id] /= penalty
    return logits

def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=0.001, p=0.95, repetition_penalty=1.5, device='cpu'):
    model = model.to(device)
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)

    for _ in range(max_new_tokens):
        with torch.no_grad():
            output_logits, _ = model(input_tensor)

        # Apply temperature scaling
        logits = output_logits[:, -1, :] / temperature

        # Apply repetition penalty
        logits = apply_repetition_penalty(logits.clone(), input_ids, penalty=repetition_penalty)

        # Use nucleus sampling
        next_token_id = nucleus_sampling(logits[0], p=p)

        input_ids.append(next_token_id)
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)

        if next_token_id == tokenizer.token_to_id('[END]'):  # Replace with the actual end token if applicable
            break

    generated_text = tokenizer.decode(input_ids)
    return generated_text.replace(' i ', ' ')  # Minor post-processing to strip stray single-letter tokens


def main():
    model, tokenizer = load_model_and_tokenizer()
    prompt = "Azərbaycanın tarixi"  # Your input prompt
    generated_text = generate_text(model, tokenizer, prompt, p=0.9)  # Adjust p as needed
    print("Generated Text:", generated_text)

if __name__ == '__main__':
    main()
prepare_data.py
ADDED
@@ -0,0 +1,124 @@
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, trainers, processors
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
import numpy as np
from tqdm import tqdm

class AzerbaijaniTokenizer:
    def __init__(self, vocab_size=50000):
        self.tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        self.tokenizer.normalizer = normalizers.Sequence([
            normalizers.NFD(),
            normalizers.Lowercase(),
            normalizers.StripAccents(),
        ])
        self.tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.WhitespaceSplit(),
            pre_tokenizers.Punctuation(),
        ])

        self.trainer = BpeTrainer(
            vocab_size=vocab_size,
            special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
            min_frequency=2
        )

    def train(self, texts):
        """Train the tokenizer on the given texts"""
        print("Training tokenizer...")
        self.tokenizer.train_from_iterator(texts, trainer=self.trainer)

    def save(self, path):
        """Save the tokenizer to a file"""
        self.tokenizer.save(path)

    def load(self, path):
        """Load the tokenizer from a file"""
        self.tokenizer = Tokenizer.from_file(path)

    def get_vocab_size(self):
        return self.tokenizer.get_vocab_size()

class WikiTextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        print("Tokenizing texts...")
        self.examples = []

        for text in tqdm(texts):
            # Tokenize the text
            tokens = self.tokenizer.encode(text).ids

            # Create sequences of max_length tokens
            for i in range(0, len(tokens) - max_length, max_length // 2):
                chunk = tokens[i:i + max_length]
                if len(chunk) < max_length:
                    # Pad if necessary
                    chunk = chunk + [0] * (max_length - len(chunk))
                self.examples.append(chunk)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        # Return input and target sequences (for next-token prediction)
        tokens = self.examples[idx]
        return torch.tensor(tokens[:-1]), torch.tensor(tokens[1:])

def prepare_data_and_tokenizer():
    # Load the collected Wikipedia data
    print("Loading Wikipedia data...")
    with open('az_wiki_data.json', 'r', encoding='utf-8') as f:
        wiki_data = json.load(f)

    # Extract texts
    texts = [page['text'] for page in wiki_data.values()]

    # Create and train tokenizer
    tokenizer = AzerbaijaniTokenizer(vocab_size=50000)
    tokenizer.train(texts)

    # Save the tokenizer
    tokenizer.save("az_tokenizer.json")
    print(f"Tokenizer vocabulary size: {tokenizer.get_vocab_size()}")

    # Create dataset
    dataset = WikiTextDataset(texts, tokenizer.tokenizer)

    # Create data loaders
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=4
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=4
    )

    print(f"Total sequences: {len(dataset)}")
    print(f"Training sequences: {len(train_dataset)}")
    print(f"Validation sequences: {len(val_dataset)}")

    return tokenizer, train_loader, val_loader

if __name__ == "__main__":
    tokenizer, train_loader, val_loader = prepare_data_and_tokenizer()
push_to_hf.py
ADDED
@@ -0,0 +1,17 @@
import os
from dotenv import load_dotenv
from huggingface_hub import login, HfApi

# Load the Hugging Face token from .env
load_dotenv()
hf_token = os.getenv("HUGGINGFACE_TOKEN")

# Log in to Hugging Face
login(token=hf_token)

# Define your repository ID
repo_id = "IsmatS/gpt-wiki-az"

# Initialize HfApi and upload the model folder
api = HfApi()
api.upload_folder(folder_path="./", path_in_repo="", repo_id=repo_id)
requirements.txt
ADDED
@@ -0,0 +1,42 @@
beautifulsoup4==4.12.3
certifi==2024.8.30
charset-normalizer==3.4.0
click==8.1.7
docker-pycreds==0.4.0
filelock==3.16.1
fsspec==2024.10.0
gitdb==4.0.11
GitPython==3.1.43
huggingface-hub==0.26.2
idna==3.10
Jinja2==3.1.4
MarkupSafe==3.0.2
mpmath==1.3.0
networkx==3.4.2
numpy==2.1.3
packaging==24.2
pillow==11.0.0
platformdirs==4.3.6
protobuf==5.28.3
psutil==6.1.0
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
safetensors==0.4.5
sentry-sdk==2.18.0
setproctitle==1.3.3
setuptools==75.5.0
six==1.16.0
smmap==5.0.1
soupsieve==2.6
sympy==1.13.1
tokenizers==0.20.3
torch==2.6.0.dev20241113
torchaudio==2.5.0.dev20241113
torchvision==0.20.0.dev20241113
tqdm==4.67.0
transformers==4.46.2
typing_extensions==4.12.2
urllib3==2.2.3
wandb==0.18.6
Wikipedia-API==0.7.1
train.py
ADDED
@@ -0,0 +1,274 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import math
from tqdm import tqdm
import json
from tokenizers import Tokenizer
from datetime import datetime
import gc

class GPTConfig:
    def __init__(
        self,
        vocab_size=22588,
        n_embd=768,         # Reduced from 2048
        n_head=12,          # Reduced from 16
        n_layer=8,          # Reduced from 12
        dropout=0.1,
        block_size=256,     # Reduced from 512
        learning_rate=3e-4,
        max_epochs=50,
        batch_size=8,       # Reduced from 64
        grad_clip=1.0,
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.dropout = dropout
        self.block_size = block_size
        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.grad_clip = grad_clip

# Model Architecture
class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.w_k = nn.Linear(config.n_embd, config.n_embd)
        self.w_q = nn.Linear(config.n_embd, config.n_embd)
        self.w_v = nn.Linear(config.n_embd, config.n_embd)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()
        k = self.w_k(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = self.w_q(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = self.w_v(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Causal mask so each position only attends to earlier positions
        causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        att = att.masked_fill(~causal_mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return y

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = SelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, targets=None):
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"

        token_embeddings = self.tok_emb(idx)
        position_embeddings = self.pos_emb[:, :t, :]
        x = self.drop(token_embeddings + position_embeddings)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss


class WikiTextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=256):  # Reduced max_length
        self.tokenizer = tokenizer
        self.max_length = max_length

        print("Tokenizing texts...")
        self.examples = []

        for text in tqdm(texts):
            tokens = self.tokenizer.encode(text).ids
            for i in range(0, len(tokens) - max_length, max_length // 2):
                chunk = tokens[i:i + max_length]
                if len(chunk) < max_length:
                    chunk = chunk + [0] * (max_length - len(chunk))
                self.examples.append(chunk)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        tokens = self.examples[idx]
        return torch.tensor(tokens[:-1]), torch.tensor(tokens[1:])

def train():
    # Clear GPU memory
    torch.cuda.empty_cache()
    gc.collect()

    print("Loading Wikipedia data...")
    with open('az_wiki_data.json', 'r', encoding='utf-8') as f:
        wiki_data = json.load(f)

    texts = [page['text'] for page in wiki_data.values()]
    tokenizer = Tokenizer.from_file("az_tokenizer.json")

    dataset = WikiTextDataset(texts, tokenizer)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    config = GPTConfig()

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=2,  # Reduced from 4
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=2,  # Reduced from 4
        pin_memory=True
    )

    model = GPT(config)
    model = model.to('cuda')
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    scheduler = CosineAnnealingLR(optimizer, T_max=config.max_epochs)
    scaler = torch.amp.GradScaler()  # torch.amp API (replaces the deprecated torch.cuda.amp)

    def run_epoch(split, epoch_num=0):
        is_train = split == 'train'
        model.train(is_train)
        if not is_train:
            model.eval()

        loader = train_loader if is_train else val_loader
        losses = []

        pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)

        for it, (x, y) in pbar:
            # Clear memory
            torch.cuda.empty_cache()

            x = x.to('cuda', non_blocking=True)
            y = y.to('cuda', non_blocking=True)

            with torch.amp.autocast(device_type='cuda'):  # torch.amp API (replaces the deprecated torch.cuda.amp)
                logits, loss = model(x, y)

            losses.append(loss.item())

            if is_train:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

                pbar.set_description(f"epoch {epoch_num+1} iter {it}: train loss {loss.item():.5f}")

            # Delete unnecessary tensors
            del x, y, logits
            if is_train:
                del loss

        mean_loss = torch.tensor(losses).mean().item()
        return mean_loss

    best_val_loss = float('inf')

    try:
        for epoch in range(config.max_epochs):
            print(f"\nEpoch {epoch+1}/{config.max_epochs}")

            train_loss = run_epoch('train', epoch_num=epoch)

            with torch.no_grad():
                val_loss = run_epoch('val')

            scheduler.step()

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                print(f"Saving best model with val_loss: {val_loss:.4f}")
                torch.save(model.state_dict(), 'best_model.pt')

            print(f"Epoch {epoch+1}: train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f}")

            if (epoch + 1) % 5 == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                }, f'checkpoint_epoch_{epoch+1}.pt')

    except KeyboardInterrupt:
        print('Training interrupted, saving checkpoint...')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }, 'interrupt_checkpoint.pt')

if __name__ == '__main__':
    train()