"""Inference utilities."""
import logging

import torch
import numpy as np
from paccmann_predictor.models.paccmann import MCA
from pytoda.transforms import Compose
from pytoda.smiles.transforms import ToTensor

from configuration import (
    MODEL_WEIGHTS_URI,
    MODEL_PARAMS,
    SMILES_LANGUAGE,
    SMILES_TRANSFORMS,
)

logger = logging.getLogger("openapi_server:inference")
# NOTE: to avoid segfaults
torch.set_num_threads(1)


def predict(
    smiles: str, gene_expression: np.ndarray, estimate_confidence: bool = False
) -> dict:
    """
    Run PaccMann prediction.

    A fresh MCA model is built from MODEL_PARAMS and its weights are loaded
    from MODEL_WEIGHTS_URI on every call. The SMILES string is transformed
    into a single-row tensor and repeated once per entry along the first
    axis of `gene_expression`, so the same compound is paired with every
    gene expression row.

    Args:
        smiles (str): SMILES representing a compound.
        gene_expression (np.ndarray): gene expression data. Its first
            dimension determines how many times the SMILES tensor is
            repeated (presumably one row per sample; confirm with callers).
        estimate_confidence (bool, optional): estimate confidence of the
            prediction. When True the SMILES language is associated with
            the model before running it. Defaults to False.

    Returns:
        dict: the prediction dictionary from the model (the second element
            returned by the model call).
    """
    logger.debug("running predict.")
    logger.debug("gene expression shape: {}.".format(gene_expression.shape))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.debug("device selected: {}.".format(device))

    logger.debug("loading model for prediction.")
    model = MCA(MODEL_PARAMS)
    model.load_state_dict(torch.load(MODEL_WEIGHTS_URI, map_location=device))
    # The SMILES tensor below is created on `device`; keep the model
    # parameters on the same device so the forward pass does not mix
    # CPU and CUDA tensors. This is a no-op on CPU-only hosts.
    model.to(device)
    model.eval()
    if estimate_confidence:
        logger.debug("associating SMILES language for confidence estimates.")
        model._associate_language(SMILES_LANGUAGE)
    logger.debug("model loaded.")

    logger.debug("set up the transformation.")
    smiles_transform_fn = Compose(SMILES_TRANSFORMS + [ToTensor(device=device)])

    logger.debug("starting the prediction.")
    with torch.no_grad():
        number_of_samples = gene_expression.shape[0]
        # one copy of the compound per gene expression row
        smiles_batch = (
            smiles_transform_fn(smiles).view(1, -1).repeat(number_of_samples, 1)
        )
        # same device as the SMILES batch and the model
        gene_expression_batch = torch.tensor(gene_expression).float().to(device)
        _, prediction_dict = model(
            smiles_batch,
            gene_expression_batch,
            confidence=estimate_confidence,
        )
    logger.debug("successful prediction.")
    return prediction_dict