import random

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from enumerator import SmilesEnumerator
from utils import split

PAD = 0
MAX_LEN = 220


class Randomizer(object):

    def __init__(self):
        self.sme = SmilesEnumerator()

    def __call__(self, sm):
        sm_r = self.sme.randomize_smiles(sm)  # Random transform
        if sm_r is None:
            sm_spaced = split(sm)  # Spacing
        else:
            sm_spaced = split(sm_r)  # Spacing
        sm_split = sm_spaced.split()
        if len(sm_split) <= MAX_LEN - 2:
            return sm_split  # List of tokens
        else:
            # Randomized form is too long; fall back to the canonical SMILES
            return split(sm).split()

    def random_transform(self, sm):
        '''
        function: Random transformation for SMILES. It may take some time.
        input: A SMILES
        output: A randomized SMILES
        '''
        return self.sme.randomize_smiles(sm)


class Seq2seqDataset(Dataset):

    def __init__(self, smiles, vocab, seq_len=220, transform=Randomizer()):
        self.smiles = smiles
        self.vocab = vocab
        self.seq_len = seq_len
        self.transform = transform

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, item):
        sm = self.smiles[item]
        sm = self.transform(sm)  # List of tokens
        content = [self.vocab.stoi.get(token, self.vocab.unk_index) for token in sm]
        X = [self.vocab.sos_index] + content + [self.vocab.eos_index]
        # Pad the sequence up to seq_len with the vocabulary's pad index
        padding = [self.vocab.pad_index] * (self.seq_len - len(X))
        X.extend(padding)
        return torch.tensor(X)
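

if __name__ == "__main__":
    # Minimal usage sketch, for illustration only. The real project supplies its
    # own vocabulary class (assumed here to expose `stoi`, `pad_index`,
    # `unk_index`, `sos_index`, and `eos_index`, matching what __getitem__
    # reads); `ToyVocab` below is a hypothetical stand-in, not the project's API.
    class ToyVocab:
        def __init__(self, tokens):
            specials = ['<pad>', '<unk>', '<eos>', '<sos>']
            self.itos = specials + sorted(set(tokens))
            self.stoi = {t: i for i, t in enumerate(self.itos)}
            self.pad_index, self.unk_index, self.eos_index, self.sos_index = 0, 1, 2, 3

    # A few toy SMILES strings; `split` (from utils) is assumed to return a
    # space-separated token string, as it is used above.
    smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']
    tokens = [t for sm in smiles for t in split(sm).split()]
    vocab = ToyVocab(tokens)

    # Each item comes back as a tensor of length seq_len: <sos> + tokens + <eos>,
    # padded with pad_index. The default transform randomizes the SMILES first.
    dataset = Seq2seqDataset(smiles, vocab, seq_len=MAX_LEN)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    for batch in loader:
        print(batch.shape)  # (batch_size, MAX_LEN)
        break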