NMT-LaVi / utils /misc.py
hieungo1410's picture
'add'
8cb4f3b
raw
history blame contribute delete
979 Bytes
import numpy as np
import torch
from torch.autograd import Variable
def no_peeking_mask(size, device):
    """Create the look-ahead ("no peeking") mask used by the decoder.

    Position ``i`` may only attend to positions ``<= i`` so that future
    words cannot be seen at prediction time during training.

    Args:
        size: Side length of the square mask (the target sequence length).
        device: Device on which the mask tensor is allocated.

    Returns:
        A boolean tensor of shape ``(1, size, size)`` that is ``True`` on and
        below the main diagonal and ``False`` strictly above it.
    """
    # Ones strictly above the diagonal mark the "future" positions. Building
    # the tensor directly on ``device`` avoids the NumPy round trip and the
    # deprecated ``Variable`` wrapper.
    future_positions = torch.triu(
        torch.ones((1, size, size), dtype=torch.uint8, device=device),
        diagonal=1,
    )
    # Invert: visible positions are those that are NOT in the future.
    return future_positions == 0
def create_masks(src, trg, src_pad, trg_pad, device):
    """Create the attention masks for the source and (optionally) the target.

    The source mask is ``True`` wherever ``src`` differs from ``src_pad``.
    The target mask is ``True`` wherever ``trg`` differs from ``trg_pad``
    AND the position is allowed by the look-ahead mask returned by
    ``no_peeking_mask``.

    Args:
        src: Source token-id tensor.
        trg: Target token-id tensor, or ``None`` when there is no target.
            ``trg.size(1)`` is used as the sequence length
            (assumes a ``(batch, seq_len)`` layout -- TODO confirm with callers).
        src_pad: Padding token id used in ``src``.
        trg_pad: Padding token id used in ``trg``.
        device: Device passed to ``no_peeking_mask`` for the look-ahead mask.

    Returns:
        A ``(src_mask, trg_mask)`` tuple. ``trg_mask`` is ``None`` when
        ``trg`` is ``None``.
    """
    # Insert a singleton axis before the last one so the mask can broadcast.
    src_mask = (src != src_pad).unsqueeze(-2)
    if trg is None:
        return src_mask, None

    trg_mask = (trg != trg_pad).unsqueeze(-2)
    seq_len = trg.size(1)  # get seq_len for matrix
    np_mask = no_peeking_mask(seq_len, device)
    # The original called ``np_mask.cuda()`` and discarded the result, which
    # is a no-op because ``Tensor.cuda()`` returns a copy. Explicitly align
    # the look-ahead mask with the padding mask's device so ``&`` is valid.
    np_mask = np_mask.to(trg_mask.device)
    return src_mask, trg_mask & np_mask