ML6-UniKP / utils.py
Topallaj Denis
copied the unikp model into this endpoint
c7272f2
import torch
import math
import torch.nn as nn
from rdkit import Chem
from rdkit import rdBase
rdBase.DisableLog('rdApp.*')
# Split SMILES into words
def split(sm):
    """Tokenize a SMILES string into space-separated words.

    Two-character symbols (Cl, Br, Si, Se, Na, ..., the lowercase forms
    si/se/te, and the charge tokens +2..+4 / -2..-4) are kept together.
    A '%' is grouped with the (up to) two characters that follow it.
    Every other character becomes a token of its own.

    input: A SMILES string
    output: A string with the tokens joined by single spaces
    """
    pair_tokens = {
        'Cl', 'Ca', 'Cu',
        'Br', 'Be', 'Ba', 'Bi',
        'Si', 'Se', 'Sr',
        'Na', 'Ni',
        'Rb', 'Ra',
        'Xe', 'Li',
        'Al', 'As', 'Ag', 'Au',
        'Mg', 'Mn',
        'Te', 'Zn',
        'si', 'se', 'te',
        'He',
        '+2', '+3', '+4',
        '-2', '-3', '-4',
        'Kr', 'Fe',
    }
    tokens = []
    pos = 0
    last = len(sm) - 1
    # The scan stops one short of the end so that a two-character
    # lookahead is always inside the string.
    while pos < last:
        if sm[pos] == '%':
            width = 3
        elif sm[pos:pos + 2] in pair_tokens:
            width = 2
        else:
            width = 1
        tokens.append(sm[pos:pos + width])
        pos += width
    # A final character that was not swallowed by a multi-character token
    # is emitted on its own.
    if pos == last:
        tokens.append(sm[pos])
    return ' '.join(tokens)
# Activation function
class GELU(nn.Module):
    """GELU activation computed with the tanh-based approximation."""

    def forward(self, x):
        """Apply the activation element-wise to tensor ``x``."""
        cubic = x + 0.044715 * torch.pow(x, 3)
        gate = 1 + torch.tanh(math.sqrt(2 / math.pi) * cubic)
        return 0.5 * x * gate
# Position-wise feed-forward network (FFN)
class PositionwiseFeedForward(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear.

    The first linear layer maps ``d_model`` features to ``d_ff`` and the
    second maps them back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        """Build the two linear layers, the dropout and the activation.

        d_model: size of the input and output feature dimension
        d_ff: size of the hidden feature dimension
        dropout: probability passed to ``nn.Dropout``
        """
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        """Expand, activate, apply dropout, then project back."""
        hidden = self.activation(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
# Normalization layer
class LayerNorm(nn.Module):
    """Normalize over the last dimension with a learnable scale and shift.

    The input is centred by its mean and divided by ``std + eps`` (eps is
    added to the standard deviation, not to the variance), then scaled by
    ``a_2`` and shifted by ``b_2``.
    """

    def __init__(self, features, eps=1e-6):
        """Create the scale (ones) and shift (zeros) parameters.

        features: length of the ``a_2`` and ``b_2`` parameter vectors
        eps: constant added to the standard deviation before dividing
        """
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return the normalized, scaled and shifted tensor."""
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        # Scale first, then divide: keeps the arithmetic order of
        # a_2 * (x - mean) / (std + eps) + b_2.
        scaled = self.a_2 * (x - mu)
        return scaled / (sigma + self.eps) + self.b_2
class SublayerConnection(nn.Module):
    """Residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, size, dropout):
        """Create the normalization layer and the dropout.

        size: feature size passed to ``LayerNorm``
        dropout: probability passed to ``nn.Dropout``
        """
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Normalize ``x``, run ``sublayer``, apply dropout, add ``x`` back.

        sublayer: callable applied to the normalized input
        """
        normed = self.norm(x)
        transformed = self.dropout(sublayer(normed))
        return x + transformed
# Sample SMILES from probabilistic distribution
def sample(msms):
    """Draw one multinomial sample per entry of ``msms`` and stack them.

    Each entry is exponentiated before being passed to
    ``torch.multinomial``, so the entries are presumably log-probabilities
    (confirm against the caller). The single drawn index per row is
    squeezed before stacking.

    input: An iterable of tensors (e.g. a batched tensor iterated over dim 0)
    output: A tensor stacking the sampled indices of every entry
    """
    draws = [torch.multinomial(msm.exp(), 1).squeeze() for msm in msms]
    return torch.stack(draws)
def validity(smiles):
    """Return the fraction of SMILES strings that RDKit can parse.

    A string counts as invalid when ``Chem.MolFromSmiles`` returns None.

    input: A sized collection of SMILES strings
    output: 1 - (number of invalid strings) / len(smiles)

    An empty collection raises ZeroDivisionError, because the invalid
    count is divided by ``len(smiles)``.
    """
    unparsable = sum(1 for sm in smiles if Chem.MolFromSmiles(sm) is None)
    return 1 - unparsable / len(smiles)