File size: 7,089 Bytes
ec53722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""SMILES utilities."""
import os
import regex
import logging
import matplotlib
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from io import StringIO
from operator import itemgetter
from typing import Callable, Iterable, Tuple
from matplotlib.ticker import FormatStrFormatter, ScalarFormatter
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from configuration import SMILES_LANGUAGE, SMILES_TOKENIZE_FN

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE: avoid segfaults in matplotlib
matplotlib.use("Agg")

# All tokens known to the configured SMILES language.
MOLECULE_TOKENS = set(SMILES_LANGUAGE.token_to_index.keys())
# Matches a token made only of: a single digit, "%" followed by digits,
# a run of punctuation characters, or a run of math symbols.
NON_ATOM_REGEX = regex.compile(r"^(\d|\%\d+|\p{P}+|\p{Math}+)$")
# Subset of the vocabulary that matches NON_ATOM_REGEX.
NON_ATOM_TOKENS = {
    token for token in MOLECULE_TOKENS if NON_ATOM_REGEX.match(token)
}
# Color map used for highlights and for the colorbar.
CMAP = cm.Oranges
# Supported normalization strategies, keyed by their configuration name.
COLOR_NORMALIZERS = {
    "linear": mpl.colors.Normalize,
    "logarithmic": mpl.colors.LogNorm,
}
# Rendering settings, overridable through PACCMANN_* environment variables.
ATOM_RADII = float(os.getenv("PACCMANN_ATOM_RADII", 0.5))
SVG_WIDTH = int(os.getenv("PACCMANN_SVG_WIDTH", 400))
SVG_HEIGHT = int(os.getenv("PACCMANN_SVG_HEIGHT", 200))
COLOR_NORMALIZATION = os.getenv("PACCMANN_COLOR_NORMALIZATION", "logarithmic")


def validate_smiles(smiles: str) -> bool:
    """
    Check whether a SMILES can be parsed into a molecule.

    Args:
        smiles (str): the SMILES string to check.

    Returns:
        bool: True when ``Chem.MolFromSmiles`` returns a molecule, False
            when it returns None.
    """
    return Chem.MolFromSmiles(smiles) is not None


def canonicalize_smiles(smiles: str) -> str:
    """
    Return the canonical form of a SMILES.

    The string is parsed with ``Chem.MolFromSmiles`` and written back with
    ``Chem.MolToSmiles``. The parse result is not checked before being
    written back.

    Args:
        smiles (str): the SMILES string to canonicalize.

    Returns:
        str: the canonicalized SMILES.
    """
    return Chem.MolToSmiles(Chem.MolFromSmiles(smiles))


def remove_housekeeping_from_tokens_and_smiles_attention(
    tokens: Iterable[str], smiles_attention: Iterable[float]
) -> Tuple[Iterable[str], Iterable[float]]:
    """
    Remove housekeeping tokens and corresponding attention weights.

    A token is kept only when it belongs to MOLECULE_TOKENS; the attention
    value at the same position is kept alongside it.

    Args:
        tokens (Iterable[str]): tokens obtained from the SMILES. Must support
            indexing (e.g. a list).
        smiles_attention (Iterable[float]): SMILES attention. Must support
            indexing and be at least as long as the last kept token position.

    Returns:
        Tuple[Iterable[str], Iterable[float]]: a tuple containing the filtered
            tokens and attention values, both as lists (empty lists when no
            token is kept).
    """
    to_keep = [index for index, token in enumerate(tokens) if token in MOLECULE_TOKENS]
    # Index explicitly instead of using operator.itemgetter: itemgetter
    # rejects zero indices and returns a bare item (not a tuple) for a single
    # index, which broke the zero- and one-kept-token cases.
    return (
        [tokens[index] for index in to_keep],
        [smiles_attention[index] for index in to_keep],
    )


def _get_index_and_colors(
    values: Iterable[float],
    tokens: Iterable[str],
    predicate: Callable[[tuple], bool],
    color_mapper: cm.ScalarMappable,
) -> Tuple[Iterable[int], Iterable[tuple]]:
    """
    Select values with a predicate and map them to RGBA colors.

    The returned indexes count positions within the selected values only
    (0, 1, 2, ...), not positions in the original ``values``/``tokens``.

    Args:
        values (Iterable[float]): values associated to tokens.
        tokens (Iterable[str]): tokens.
        predicate (Callable[[tuple], bool]): a predicate applied to each
            (value, token) tuple; pairs for which it is falsy are skipped.
        color_mapper (cm.ScalarMappable): a color mapper; its ``to_rgba``
            is called on every selected value.

    Returns:
        Tuple[Iterable[int], Iterable[tuple]]: a list of indexes and a dict
            mapping each index to the color of the corresponding value.
    """
    # lazily yield the value of every (value, token) pair that is selected
    selected_values = (pair[0] for pair in zip(values, tokens) if predicate(pair))
    colors = {
        position: color_mapper.to_rgba(selected)
        for position, selected in enumerate(selected_values)
    }
    # dict keys are inserted in increasing order, so they are the index list
    indices = list(colors)
    return indices, colors


def smiles_attention_to_svg(
    smiles: str, smiles_attention: Iterable[float]
) -> Tuple[str, str]:
    """
    Generate an svg of the molecule highlighting SMILES attention.

    The SMILES is tokenized with SMILES_TOKENIZE_FN, housekeeping tokens and
    their attention values are dropped, and the remaining attention values are
    mapped to colors using the normalization selected by COLOR_NORMALIZATION
    (falling back to a logarithmic one for unknown names).

    Args:
        smiles (str): SMILES representing a molecule.
        smiles_attention (Iterable[float]): SMILES attention.

    Returns:
        Tuple[str, str]: drawing, colorbar
            the svg of the molecule highlighting SMILES attention
            and the svg displaying the colorbar. Newlines in both svg
            strings are replaced by spaces.
    """
    logger.debug("SMILES attention:\n{}.".format(smiles_attention))
    logger.debug(
        "SMILES attention range: [{},{}].".format(
            min(smiles_attention), max(smiles_attention)
        )
    )
    # get the molecule
    molecule = Chem.MolFromSmiles(smiles)
    tokens = [
        SMILES_LANGUAGE.index_to_token[token_index]
        for token_index in SMILES_TOKENIZE_FN(smiles)
    ]
    logger.debug("SMILES tokens:{}.".format(tokens))
    # remove padding and other housekeeping tokens
    tokens, smiles_attention = remove_housekeeping_from_tokens_and_smiles_attention(
        tokens, smiles_attention
    )  # yapf:disable
    logger.debug(
        "tokens and SMILES attention after removal:\n{}\n{}.".format(
            tokens, smiles_attention
        )
    )
    logger.debug(
        "SMILES attention range after padding removal: [{},{}].".format(
            min(smiles_attention), max(smiles_attention)
        )
    )
    # define a color map; the upper bound is twice the maximum attention,
    # capped at 1.0
    normalize = COLOR_NORMALIZERS.get(COLOR_NORMALIZATION, mpl.colors.LogNorm)(
        vmin=min(smiles_attention), vmax=min(1.0, 2 * max(smiles_attention))
    )
    color_mapper = cm.ScalarMappable(norm=normalize, cmap=CMAP)
    # get atom colors
    highlight_atoms, highlight_atom_colors = _get_index_and_colors(
        smiles_attention, tokens, lambda t: t[1] not in NON_ATOM_TOKENS, color_mapper
    )
    logger.debug("atom colors:\n{}.".format(highlight_atom_colors))
    # get bond colors
    highlight_bonds, highlight_bond_colors = _get_index_and_colors(
        smiles_attention, tokens, lambda t: t[1] in NON_ATOM_TOKENS, color_mapper
    )
    logger.debug("bond colors:\n{}.".format(highlight_bond_colors))
    # add coordinates
    logger.debug("compute 2D coordinates")
    Chem.rdDepictor.Compute2DCoords(molecule)
    # draw the molecule
    logger.debug("get a drawer")
    drawer = rdMolDraw2D.MolDraw2DSVG(SVG_WIDTH, SVG_HEIGHT)
    logger.debug("draw the molecule")
    drawer.DrawMolecule(
        molecule,
        highlightAtoms=highlight_atoms,
        highlightAtomColors=highlight_atom_colors,
        highlightBonds=highlight_bonds,
        highlightBondColors=highlight_bond_colors,
        highlightAtomRadii={index: ATOM_RADII for index in highlight_atoms},
    )
    logger.debug("finish drawing")
    drawer.FinishDrawing()
    # the drawn molecule as str
    logger.debug("drawing to string")
    drawing = drawer.GetDrawingText().replace("\n", " ")
    # the respective colorbar
    logger.debug("draw the colorbar")
    fig, ax = plt.subplots(figsize=(0.5, 6))
    # make sure the figure is always released, even if rendering fails:
    # pyplot keeps every open figure alive until it is explicitly closed
    try:
        mpl.colorbar.ColorbarBase(
            ax,
            cmap=CMAP,
            norm=normalize,
            orientation="vertical",
            extend="both",
            extendrect=True,
        )
        # instead of LogFormatterSciNotation
        logger.debug("format the colorbar")
        ax.yaxis.set_minor_formatter(ScalarFormatter())
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))  # fixes 0.1, 0.20
        # the colorbar svg as str
        logger.debug("colorbar to string")
        file_like = StringIO()
        # save this specific figure rather than pyplot's global current figure
        fig.savefig(file_like, format="svg", bbox_inches="tight")
        colorbar = file_like.getvalue().replace("\n", " ")
    finally:
        plt.close(fig)
    return drawing, colorbar