"""SMILES utilities."""
import os
import regex
import logging
import matplotlib
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from io import StringIO
from operator import itemgetter
from typing import Callable, Dict, Iterable, List, Tuple
from matplotlib.ticker import FormatStrFormatter, ScalarFormatter
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from configuration import SMILES_LANGUAGE, SMILES_TOKENIZE_FN

logger = logging.getLogger(__name__)
# NOTE: avoid segfaults in matplotlib
matplotlib.use("Agg")

# every token known to the SMILES language
MOLECULE_TOKENS = set(SMILES_LANGUAGE.token_to_index.keys())
# tokens made only of a digit, a %NN ring label, punctuation or math symbols
NON_ATOM_REGEX = regex.compile(r"^(\d|\%\d+|\p{P}+|\p{Math}+)$")
NON_ATOM_TOKENS = {token for token in MOLECULE_TOKENS if NON_ATOM_REGEX.match(token)}
CMAP = cm.Oranges
COLOR_NORMALIZERS = {"linear": mpl.colors.Normalize, "logarithmic": mpl.colors.LogNorm}
# rendering settings, overridable through environment variables
ATOM_RADII = float(os.environ.get("PACCMANN_ATOM_RADII", 0.5))
SVG_WIDTH = int(os.environ.get("PACCMANN_SVG_WIDTH", 400))
SVG_HEIGHT = int(os.environ.get("PACCMANN_SVG_HEIGHT", 200))
COLOR_NORMALIZATION = os.environ.get("PACCMANN_COLOR_NORMALIZATION", "logarithmic")


def validate_smiles(smiles: str) -> bool:
    """
    Validate a SMILES.

    Args:
        smiles (str): a SMILES string.

    Returns:
        bool: flag indicating whether the SMILES is a valid molecule,
            i.e. whether Chem.MolFromSmiles returned a molecule.
    """
    molecule = Chem.MolFromSmiles(smiles)
    return molecule is not None


def canonicalize_smiles(smiles: str) -> str:
    """
    Canonicalize a SMILES.

    Args:
        smiles (str): a SMILES string.

    Returns:
        str: the canonicalized SMILES.
    """
    # NOTE(review): when the SMILES cannot be parsed, molecule is None (see
    # validate_smiles) and is handed to Chem.MolToSmiles unchanged;
    # presumably RDKit raises in that case - confirm and validate upstream.
    molecule = Chem.MolFromSmiles(smiles)
    return Chem.MolToSmiles(molecule)


def remove_housekeeping_from_tokens_and_smiles_attention(
    tokens: Iterable[str], smiles_attention: Iterable[float]
) -> Tuple[Iterable[str], Iterable[float]]:
    """
    Remove housekeeping tokens and corresponding attention weights.

    A token is kept when it belongs to MOLECULE_TOKENS; the attention value
    found at the same position is kept alongside it.

    Args:
        tokens (Iterable[str]): tokens obtained from the SMILES. Indexed by
            position, so it has to support indexing.
        smiles_attention (Iterable[float]): SMILES attention. Indexed by
            position, so it has to support indexing.

    Returns:
        Tuple[Iterable[str], Iterable[float]]: a tuple containing the filtered
            tokens and attention values, both as lists. Both lists are empty
            when no token is kept.
    """
    to_keep = [index for index, token in enumerate(tokens) if token in MOLECULE_TOKENS]
    # Index explicitly instead of using itemgetter(*to_keep): itemgetter
    # rejects zero indexes and returns a bare item (not a 1-tuple) for a
    # single index, which broke the conversion to list.
    return (
        [tokens[index] for index in to_keep],
        [smiles_attention[index] for index in to_keep],
    )


def _get_index_and_colors(
    values: Iterable[float],
    tokens: Iterable[str],
    predicate: Callable[[tuple], bool],
    color_mapper: cm.ScalarMappable,
) -> Tuple[List[int], Dict[int, tuple]]:
    """
    Get index and RGB colors from a color map using a rule.

    Args:
        values (Iterable[float]): values associated to tokens.
        tokens (Iterable[str]): tokens.
        predicate (Callable[[tuple], bool]): a predicate that acts on a tuple
            of (value, token).
        color_mapper (cm.ScalarMappable): a color mapper.

    Returns:
        Tuple[List[int], Dict[int, tuple]]: the indexes and a mapping from
            each index to the RGBA color of the corresponding value.
    """
    selected_values = (pair[0] for pair in filter(predicate, zip(values, tokens)))
    # Indexes count the selected pairs only (0, 1, 2, ...), not positions in
    # the original token sequence.
    colors = {
        index: color_mapper.to_rgba(value)
        for index, value in enumerate(selected_values)
    }
    return list(colors), colors


def smiles_attention_to_svg(
    smiles: str, smiles_attention: Iterable[float]
) -> Tuple[str, str]:
    """
    Generate an svg of the molecule highlighting SMILES attention.

    Args:
        smiles (str): SMILES representing a molecule.
        smiles_attention (Iterable[float]): SMILES attention.

    Returns:
        Tuple[str, str]: drawing, colorbar
            the svg of the molecule highlighting SMILES attention
            and the svg displaying the colorbar. Newlines in both svg strings
            are replaced by spaces.
    """
    logger.debug("SMILES attention:\n{}.".format(smiles_attention))
    logger.debug(
        "SMILES attention range: [{},{}].".format(
            min(smiles_attention), max(smiles_attention)
        )
    )
    # get the molecule
    molecule = Chem.MolFromSmiles(smiles)
    tokens = [
        SMILES_LANGUAGE.index_to_token[token_index]
        for token_index in SMILES_TOKENIZE_FN(smiles)
    ]
    logger.debug("SMILES tokens:{}.".format(tokens))
    # remove padding and other housekeeping tokens
    tokens, smiles_attention = remove_housekeeping_from_tokens_and_smiles_attention(
        tokens, smiles_attention
    )  # yapf:disable
    logger.debug(
        "tokens and SMILES attention after removal:\n{}\n{}.".format(
            tokens, smiles_attention
        )
    )
    logger.debug(
        "SMILES attention range after padding removal: [{},{}].".format(
            min(smiles_attention), max(smiles_attention)
        )
    )
    # define a color map: the upper bound is twice the largest attention
    # value, capped at 1.0
    normalize = COLOR_NORMALIZERS.get(COLOR_NORMALIZATION, mpl.colors.LogNorm)(
        vmin=min(smiles_attention), vmax=min(1.0, 2 * max(smiles_attention))
    )
    color_mapper = cm.ScalarMappable(norm=normalize, cmap=CMAP)
    # get atom colors
    # NOTE(review): the returned indexes count the selected tokens in order
    # and are passed to RDKit as atom (respectively bond) indexes - confirm
    # that this ordering matches RDKit's numbering for the inputs in use.
    highlight_atoms, highlight_atom_colors = _get_index_and_colors(
        smiles_attention, tokens, lambda t: t[1] not in NON_ATOM_TOKENS, color_mapper
    )
    logger.debug("atom colors:\n{}.".format(highlight_atom_colors))
    # get bond colors
    highlight_bonds, highlight_bond_colors = _get_index_and_colors(
        smiles_attention, tokens, lambda t: t[1] in NON_ATOM_TOKENS, color_mapper
    )
    logger.debug("bond colors:\n{}.".format(highlight_bond_colors))
    # add coordinates
    logger.debug("compute 2D coordinates")
    Chem.rdDepictor.Compute2DCoords(molecule)
    # draw the molecule
    logger.debug("get a drawer")
    drawer = rdMolDraw2D.MolDraw2DSVG(SVG_WIDTH, SVG_HEIGHT)
    logger.debug("draw the molecule")
    drawer.DrawMolecule(
        molecule,
        highlightAtoms=highlight_atoms,
        highlightAtomColors=highlight_atom_colors,
        highlightBonds=highlight_bonds,
        highlightBondColors=highlight_bond_colors,
        highlightAtomRadii={index: ATOM_RADII for index in highlight_atoms},
    )
    logger.debug("finish drawing")
    drawer.FinishDrawing()
    # the drawn molecule as str
    logger.debug("drawing to string")
    drawing = drawer.GetDrawingText().replace("\n", " ")
    # the respective colorbar
    logger.debug("draw the colorbar")
    fig, ax = plt.subplots(figsize=(0.5, 6))
    try:
        mpl.colorbar.ColorbarBase(
            ax,
            cmap=CMAP,
            norm=normalize,
            orientation="vertical",
            extend="both",
            extendrect=True,
        )
        # instead of LogFormatterSciNotation
        logger.debug("format the colorbar")
        ax.yaxis.set_minor_formatter(ScalarFormatter())
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))  # fixes 0.1, 0.20
        # the colorbar svg as str
        logger.debug("colorbar to string")
        file_like = StringIO()
        # save this figure explicitly rather than whichever one is current
        fig.savefig(file_like, format="svg", bbox_inches="tight")
        colorbar = file_like.getvalue().replace("\n", " ")
    finally:
        # always release the figure, even when drawing or saving fails
        plt.close(fig)
    return drawing, colorbar