|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.nn.utils.rnn import pad_sequence |
|
import nltk |
|
from nltk.tokenize import word_tokenize |
|
from collections import Counter |
|
import json |
|
|
|
|
|
|
|
|
|
|
|
def load_text_data(file_path):
    """Read a text file and return its lines with surrounding whitespace stripped.

    Args:
        file_path: Path to a plain-text file, one sentence per line.

    Returns:
        List of stripped lines (blank lines become empty strings).
    """
    # Explicit encoding: without it, open() falls back to the platform
    # default, which breaks on non-ASCII corpora on some systems.
    with open(file_path, 'r', encoding='utf-8') as file:
        return [line.strip() for line in file]
|
|
|
|
|
# Load the raw corpus and show a small preview for sanity checking.
file_path = 'data.txt'
sentences = load_text_data(file_path)
# Bug fix: the slice was `sentences[:0]`, which is always the empty list,
# so the preview printed nothing useful. Show the first few lines instead.
print(f"Loaded sentences: {sentences[:5]}")
|
|
|
|
|
def tokenize(text):
    """Lowercase *text* and split it into word tokens via NLTK.

    On a fresh NLTK installation `word_tokenize` raises LookupError because
    the 'punkt' tokenizer models are not yet downloaded; fetch them once and
    retry so the script works out of the box.
    """
    try:
        return word_tokenize(text.lower())
    except LookupError:
        # Missing tokenizer data — download quietly and retry once.
        nltk.download('punkt', quiet=True)
        return word_tokenize(text.lower())
|
|
|
|
|
def build_vocab(sentences):
    """Build a token -> index mapping over all tokens in *sentences*.

    Indices follow first-occurrence order; the special tokens '<unk>' and
    '<pad>' are appended last, so they always get the two highest ids.
    """
    counts = Counter()
    for sentence in sentences:
        counts.update(tokenize(sentence))

    vocab = {}
    for word in counts:
        vocab[word] = len(vocab)

    # Special tokens go at the end of the index space.
    vocab['<unk>'] = len(vocab)
    vocab['<pad>'] = len(vocab)
    return vocab
|
|
|
# Build the vocabulary from the corpus and report its size.
vocab = build_vocab(sentences)
vocab_size = len(vocab)
print(f"Vocabulary size: {vocab_size}")

# Persist the token-to-index mapping so it can be reused at inference time.
vocab_file = 'vocab.json'
with open(vocab_file, 'w') as handle:
    json.dump(vocab, handle)
print('Vocabulary saved to vocab.json')
|
|
|
|
|
class TextDataset(Dataset):
    """Sliding-window language-modeling dataset.

    Each sample is a pair (input window, target window) of token-id lists of
    length *seq_length*, where the target is the input shifted right by one
    position. Sentences shorter than seq_length + 1 tokens are skipped.
    """

    def __init__(self, sentences, vocab, seq_length=10):
        self.vocab = vocab
        self.seq_length = seq_length
        self.data = []
        for sentence in sentences:
            # Map tokens to ids, falling back to '<unk>' for OOV tokens.
            ids = [vocab.get(tok, vocab['<unk>']) for tok in tokenize(sentence)]
            if len(ids) < seq_length + 1:
                continue
            for start in range(len(ids) - seq_length):
                window = ids[start:start + seq_length]
                shifted = ids[start + 1:start + seq_length + 1]
                self.data.append((window, shifted))
        print(f"Dataset size: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        return (
            torch.tensor(src, dtype=torch.long),
            torch.tensor(tgt, dtype=torch.long),
        )
|
|
|
dataset = TextDataset(sentences, vocab, seq_length=10) |
|
print(f"Number of samples in dataset: {len(dataset)}") |
|
|
|
def collate_fn(batch):
    """Pad a batch of (input, target) tensor pairs to a common length.

    Uses the module-level vocab's '<pad>' id as the padding value for both
    inputs and targets, returning two (batch, max_len) LongTensors.
    """
    pad_id = vocab['<pad>']
    srcs = [pair[0] for pair in batch]
    tgts = [pair[1] for pair in batch]
    return (
        pad_sequence(srcs, batch_first=True, padding_value=pad_id),
        pad_sequence(tgts, batch_first=True, padding_value=pad_id),
    )


dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
|
|
|
|
|
class Transformer(nn.Module):
    """Seq2seq model: shared token embedding -> nn.Transformer -> vocab logits.

    The same embedding table is used for both encoder and decoder inputs;
    batch_first=True means all tensors are (batch, seq, feature).
    """

    def __init__(self, vocab_size, embed_size, num_heads, hidden_size, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Return per-position vocabulary logits for the target sequence."""
        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)
        hidden = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask)
        return self.fc(hidden)
|
|
|
|
|
# Model hyperparameters. Note embed_size must be divisible by num_heads
# (10 / 5 = 2 per-head dimensions here).
embed_size, num_heads = 10, 5
hidden_size, num_layers = 100, 2

model = Transformer(
    vocab_size,
    embed_size,
    num_heads,
    hidden_size,
    num_layers,
)
|
|
|
|
|
# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

# Padding positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])
# NOTE(review): lr=0.01 is on the high side for transformer training — confirm.
optimizer = optim.AdamW(model.parameters(), lr=0.01)
|
|
|
|
|
# Training loop. Bug fix: the original passed tgt_mask=None, so the decoder
# could attend to FUTURE target positions and trivially learn to copy the
# target sequence. A causal (square subsequent) mask restores proper
# next-token training. The squeeze(0) calls were also dropped: they only
# stripped the batch dimension (batch_size=1), which nn.Transformer with
# batch_first=True handles fine without.
num_epochs = 5
for epoch in range(num_epochs):
    model.train()  # explicit, in case a later eval pass is added
    total_loss = 0.0
    for src_batch, tgt_batch in dataloader:
        src_batch = src_batch.to(device)  # (batch, seq)
        tgt_batch = tgt_batch.to(device)  # (batch, seq)

        # Causal mask so position i can only attend to positions <= i.
        seq_len = tgt_batch.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device)

        optimizer.zero_grad()
        outputs = model(src_batch, tgt_batch, None, tgt_mask)
        # Flatten to (batch*seq, vocab) vs (batch*seq,) for cross-entropy.
        loss = criterion(outputs.reshape(-1, vocab_size), tgt_batch.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
|
|
|
|
|
# Persist the trained weights for later reloading via load_state_dict.
model_path = 'transformer_model.pth'
state = model.state_dict()
torch.save(state, model_path)
print(f'Model saved to {model_path}')
|
|