jayksharma commited on
Commit
da9ac04
·
verified ·
1 Parent(s): 2d8a6d5

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +45 -0
train.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from torch.utils.data import DataLoader, Dataset
7
+ from super_large_language_model import TransformerModel
8
+
9
+ class TextDataset(Dataset):
10
+ def __init__(self, texts, vocab):
11
+ self.texts = texts
12
+ self.vocab = vocab
13
+
14
+ def __len__(self):
15
+ return len(self.texts)
16
+
17
+ def __getitem__(self, idx):
18
+ text = self.texts[idx]
19
+ text_indices = [self.vocab[char] for char in text]
20
+ return torch.tensor(text_indices)
21
+
22
+ def train_model(model, dataset, num_epochs=10, batch_size=32, learning_rate=0.001):
23
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
24
+ criterion = nn.CrossEntropyLoss()
25
+ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
26
+
27
+ for epoch in range(num_epochs):
28
+ model.train()
29
+ for batch in dataloader:
30
+ optimizer.zero_grad()
31
+ output = model(batch[:-1], batch[1:])
32
+ loss = criterion(output.view(-1, output.size(-1)), batch[1:].view(-1))
33
+ loss.backward()
34
+ optimizer.step()
35
+ print(f'Epoch {epoch+1}, Loss: {loss.item()}')
36
+
37
+ if __name__ == "__main__":
38
+ # Example texts and vocabulary
39
+ texts = ["hello world", "pytorch is great"]
40
+ vocab = {char: idx for idx, char in enumerate(set("".join(texts)))}
41
+
42
+ dataset = TextDataset(texts, vocab)
43
+ model = TransformerModel(vocab_size=len(vocab), d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048)
44
+
45
+ train_model(model, dataset)