Spaces:
Running
Running
File size: 7,089 Bytes
ec53722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
"""SMILES utilities."""
import logging
import os
from io import StringIO
from operator import itemgetter
from typing import Callable, Container, Dict, Iterable, List, Optional, Tuple

import matplotlib
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import regex
from matplotlib.ticker import FormatStrFormatter, ScalarFormatter
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D

from configuration import SMILES_LANGUAGE, SMILES_TOKENIZE_FN
logger = logging.getLogger(__name__)

# NOTE: avoid segfaults in matplotlib by forcing the non-interactive Agg backend.
matplotlib.use("Agg")

# Every token known to the configured SMILES language.
MOLECULE_TOKENS = set(SMILES_LANGUAGE.token_to_index.keys())

# A token counts as "non-atom" when it consists entirely of a digit, a "%NN"
# ring label, punctuation characters or math symbols (e.g. "(", "#", "=", "1").
NON_ATOM_REGEX = regex.compile(r"^(\d|\%\d+|\p{P}+|\p{Math}+)$")
NON_ATOM_TOKENS = {
    token for token in MOLECULE_TOKENS if NON_ATOM_REGEX.match(token)
}

# Colormap used for highlighting, and the normalizers that
# PACCMANN_COLOR_NORMALIZATION can select between.
CMAP = cm.Oranges
COLOR_NORMALIZERS = {
    "linear": mpl.colors.Normalize,
    "logarithmic": mpl.colors.LogNorm,
}

# Drawing settings, overridable through environment variables.
ATOM_RADII = float(os.getenv("PACCMANN_ATOM_RADII", 0.5))
SVG_WIDTH = int(os.getenv("PACCMANN_SVG_WIDTH", 400))
SVG_HEIGHT = int(os.getenv("PACCMANN_SVG_HEIGHT", 200))
COLOR_NORMALIZATION = os.getenv("PACCMANN_COLOR_NORMALIZATION", "logarithmic")
def validate_smiles(smiles: str) -> bool:
    """
    Check whether RDKit can turn a SMILES string into a molecule.
    Args:
        smiles (str): a SMILES string.
    Returns:
        bool: True when ``Chem.MolFromSmiles`` yields a molecule, False when
            it returns ``None`` (RDKit's failure signal).
    """
    return Chem.MolFromSmiles(smiles) is not None
def canonicalize_smiles(smiles: str) -> str:
    """
    Canonicalize a SMILES.
    Args:
        smiles (str): a SMILES string.
    Returns:
        str: the canonical SMILES RDKit writes for the parsed molecule.
    Raises:
        ValueError: if RDKit cannot parse ``smiles`` into a molecule.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # MolFromSmiles signals failure with None; handing None on to
        # MolToSmiles would only fail later with an opaque Boost.Python error.
        raise ValueError("invalid SMILES: {!r}".format(smiles))
    return Chem.MolToSmiles(molecule)
def remove_housekeeping_from_tokens_and_smiles_attention(
    tokens: Iterable[str],
    smiles_attention: Iterable[float],
    *,
    valid_tokens: Optional[Container[str]] = None,
) -> Tuple[Iterable[str], Iterable[float]]:
    """
    Remove housekeeping tokens and corresponding attention weights.

    A token is kept when it belongs to ``valid_tokens``. Its attention weight
    is the value found at the same position in ``smiles_attention``.

    Args:
        tokens (Iterable[str]): tokens obtained from the SMILES.
        smiles_attention (Iterable[float]): SMILES attention, aligned
            position-by-position with ``tokens``. It is indexed by position,
            so it must be a sequence (list, tuple, array, ...).
        valid_tokens (Optional[Container[str]]): tokens to keep. Defaults to
            ``MOLECULE_TOKENS`` (the keys of ``SMILES_LANGUAGE.token_to_index``).
    Returns:
        Tuple[Iterable[str], Iterable[float]]: a tuple containing the filtered
            tokens and attention values, both as lists (empty lists when no
            token is kept).
    Raises:
        IndexError: if a kept token's position is out of range for
            ``smiles_attention``.
    """
    if valid_tokens is None:
        valid_tokens = MOLECULE_TOKENS
    # Materialise once so that any iterable of tokens can be indexed below.
    tokens = list(tokens)
    to_keep = [index for index, token in enumerate(tokens) if token in valid_tokens]
    # Index explicitly rather than via itemgetter(*to_keep). itemgetter returns
    # a bare item (not a tuple) for a single index, which list() would then
    # split or reject. It also raises TypeError when no index is given.
    return (
        [tokens[index] for index in to_keep],
        [smiles_attention[index] for index in to_keep],
    )
def _get_index_and_colors(
    values: Iterable[float],
    tokens: Iterable[str],
    predicate: Callable[[tuple], bool],
    color_mapper: cm.ScalarMappable,
) -> Tuple[List[int], Dict[int, tuple]]:
    """
    Select (value, token) pairs with a rule and map their values to colors.

    The returned indices are ranks among the *selected* pairs (0, 1, 2, ...),
    not positions in ``tokens``: the k-th pair satisfying ``predicate`` gets
    index k. Pairs are built with ``zip``, so items beyond the shorter of
    ``values`` and ``tokens`` are ignored.

    Args:
        values (Iterable[float]): values associated to tokens.
        tokens (Iterable[str]): tokens, aligned with ``values``.
        predicate (Callable[[tuple], bool]): a predicate that acts on a tuple
            of (value, token).
        color_mapper (cm.ScalarMappable): a color mapper.
    Returns:
        Tuple[List[int], Dict[int, tuple]]: the selected indices and a dict
            mapping each of them to ``color_mapper.to_rgba(value)`` (RGBA).
    """
    selected_values = [pair[0] for pair in zip(values, tokens) if predicate(pair)]
    indices = list(range(len(selected_values)))
    colors = {
        index: color_mapper.to_rgba(value)
        for index, value in enumerate(selected_values)
    }
    return indices, colors
def smiles_attention_to_svg(
    smiles: str, smiles_attention: Iterable[float]
) -> Tuple[str, str]:
    """
    Generate an svg of the molecule highlighting SMILES attention.

    Attention on atom tokens colors atoms. Attention on non-atom tokens
    (``NON_ATOM_TOKENS``) colors bonds.

    Args:
        smiles (str): SMILES representing a molecule.
        smiles_attention (Iterable[float]): SMILES attention, aligned
            position-by-position with the tokens of
            ``SMILES_TOKENIZE_FN(smiles)``. It is traversed several times
            (``min``/``max``), so pass a sequence rather than an iterator.
    Returns:
        Tuple[str, str]: drawing, colorbar
            the svg of the molecule highlighting SMILES attention
            and the svg displaying the colorbar, both with newlines
            replaced by spaces.
    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    logger.debug("SMILES attention:\n{}.".format(smiles_attention))
    logger.debug(
        "SMILES attention range: [{},{}].".format(
            min(smiles_attention), max(smiles_attention)
        )
    )
    # get the molecule; fail fast instead of inside RDKit drawing code
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError("invalid SMILES: {!r}".format(smiles))
    tokens = [
        SMILES_LANGUAGE.index_to_token[token_index]
        for token_index in SMILES_TOKENIZE_FN(smiles)
    ]
    logger.debug("SMILES tokens:{}.".format(tokens))
    # remove padding / housekeeping tokens together with their attention
    tokens, smiles_attention = remove_housekeeping_from_tokens_and_smiles_attention(
        tokens, smiles_attention
    )
    logger.debug(
        "tokens and SMILES attention after removal:\n{}\n{}.".format(
            tokens, smiles_attention
        )
    )
    logger.debug(
        "SMILES attention range after padding removal: [{},{}].".format(
            min(smiles_attention), max(smiles_attention)
        )
    )
    # define a color map (unknown normalization names fall back to LogNorm)
    normalize = COLOR_NORMALIZERS.get(COLOR_NORMALIZATION, mpl.colors.LogNorm)(
        vmin=min(smiles_attention), vmax=min(1.0, 2 * max(smiles_attention))
    )
    color_mapper = cm.ScalarMappable(norm=normalize, cmap=CMAP)
    # NOTE(review): highlight indices are ranks, not token positions. The
    # k-th atom token colors atom k, which relies on RDKit numbering atoms
    # in SMILES order. The k-th non-atom token colors bond k, a positional
    # heuristic that may not match RDKit's bond ordering. Confirm.
    highlight_atoms, highlight_atom_colors = _get_index_and_colors(
        smiles_attention, tokens, lambda t: t[1] not in NON_ATOM_TOKENS, color_mapper
    )
    logger.debug("atom colors:\n{}.".format(highlight_atom_colors))
    highlight_bonds, highlight_bond_colors = _get_index_and_colors(
        smiles_attention, tokens, lambda t: t[1] in NON_ATOM_TOKENS, color_mapper
    )
    logger.debug("bond colors:\n{}.".format(highlight_bond_colors))
    drawing = _molecule_to_svg(
        molecule,
        highlight_atoms,
        highlight_atom_colors,
        highlight_bonds,
        highlight_bond_colors,
    )
    colorbar = _colorbar_to_svg(normalize)
    return drawing, colorbar


def _molecule_to_svg(
    molecule,
    highlight_atoms: Iterable[int],
    highlight_atom_colors: dict,
    highlight_bonds: Iterable[int],
    highlight_bond_colors: dict,
) -> str:
    """Draw ``molecule`` with highlighted atoms/bonds; return the SVG on one line."""
    # Compute2DCoords stores the 2D coordinates on ``molecule`` itself.
    logger.debug("compute 2D coordinates")
    Chem.rdDepictor.Compute2DCoords(molecule)
    logger.debug("get a drawer")
    drawer = rdMolDraw2D.MolDraw2DSVG(SVG_WIDTH, SVG_HEIGHT)
    logger.debug("draw the molecule")
    drawer.DrawMolecule(
        molecule,
        highlightAtoms=highlight_atoms,
        highlightAtomColors=highlight_atom_colors,
        highlightBonds=highlight_bonds,
        highlightBondColors=highlight_bond_colors,
        highlightAtomRadii={index: ATOM_RADII for index in highlight_atoms},
    )
    logger.debug("finish drawing")
    drawer.FinishDrawing()
    logger.debug("drawing to string")
    return drawer.GetDrawingText().replace("\n", " ")


def _colorbar_to_svg(normalize: mpl.colors.Normalize) -> str:
    """Render a vertical ``CMAP`` colorbar for ``normalize``; return the SVG on one line."""
    logger.debug("draw the colorbar")
    fig, ax = plt.subplots(figsize=(0.5, 6))
    try:
        mpl.colorbar.ColorbarBase(
            ax,
            cmap=CMAP,
            norm=normalize,
            orientation="vertical",
            extend="both",
            extendrect=True,
        )
        # instead of LogFormatterSciNotation
        logger.debug("format the colorbar")
        ax.yaxis.set_minor_formatter(ScalarFormatter())
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))  # fixes 0.1, 0.20
        logger.debug("colorbar to string")
        file_like = StringIO()
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(file_like, format="svg", bbox_inches="tight")
    finally:
        # pyplot keeps every open figure alive; release it even on errors.
        plt.close(fig)
    return file_like.getvalue().replace("\n", " ")
|