Update PrateritumGPT.py
PrateritumGPT.py  CHANGED  (+4 -2)
@@ -142,7 +142,7 @@ train_loader = DataLoader(MyDataset, batch_size=32, shuffle=True, collate_fn=col
 #Dropout: 0
 #Forward Dim: 1024
 
-model = TransformerModel(vocab_size=len(tokens)+2, emb_dim=128, nhead=32, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=
+model = TransformerModel(vocab_size=len(tokens)+2, emb_dim=128, nhead=32, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=512,dropout=0)
 loss_fn = nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 
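The TransformerModel class itself is not part of this diff, so the following is only a rough sketch of a wrapper with the same constructor signature, assuming it embeds token ids, runs them through nn.Transformer, and projects back to the vocabulary; the real class in PrateritumGPT.py may differ in detail. Note that emb_dim must be divisible by nhead (128 / 32 = 4 here).

import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    # Assumed structure (not from this diff): embedding -> nn.Transformer -> linear head over the vocabulary.
    def __init__(self, vocab_size, emb_dim, nhead, num_encoder_layers,
                 num_decoder_layers, dim_feedforward, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.transformer = nn.Transformer(d_model=emb_dim, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout, batch_first=True)
        self.out_proj = nn.Linear(emb_dim, vocab_size)

    def forward(self, src, tgt):
        # src, tgt: (batch, seq_len) token ids; returns (batch, tgt_len, vocab_size) logits.
        hidden = self.transformer(self.embedding(src.long()), self.embedding(tgt.long()))
        return self.out_proj(hidden)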
@@ -166,8 +166,10 @@ def Prompt():
     tgt_=torch.Tensor(tgt)
     out=model(torch.Tensor(src).to(device),tgt_.to(device)).tolist()[0]
     Best=0
-
+    warn=tokens.index(" ")
     for k,f in enumerate(out):
+        if k==len(tokens):
+            f*=2
         if f>Best:
             Best=f
             Best_=k
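The loop after Best=0 is a hand-rolled argmax over the model's output scores. The lines added in this commit double the score at index len(tokens) before the comparison (presumably a special token outside the base vocabulary, since the model was built with vocab_size=len(tokens)+2), while warn=tokens.index(" ") is assigned but not used within the shown lines. Purely as an illustration (the helper name pick_next_token is not from the source, and it matches the original loop only when at least one score is positive, since Best starts at 0), the same selection could be written as:

import torch

def pick_next_token(out, tokens):
    # Hypothetical helper mirroring the loop in Prompt():
    # boost the score at index len(tokens), then take the best-scoring index.
    scores = torch.tensor(out, dtype=torch.float)
    scores[len(tokens)] *= 2          # the boost added in this commit
    return int(torch.argmax(scores))  # greedy choice of the next token id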