# paccmann / forward.py
# jannisborn: update (ec53722, unverified)
"""Inference utilities."""
import logging
import torch
import numpy as np
from paccmann_predictor.models.paccmann import MCA
from pytoda.transforms import Compose
from pytoda.smiles.transforms import ToTensor
from configuration import (
MODEL_WEIGHTS_URI,
MODEL_PARAMS,
SMILES_LANGUAGE,
SMILES_TRANSFORMS,
)
# NOTE: torch is pinned to a single thread to avoid segfaults.
torch.set_num_threads(1)

# Module-wide logger used by the inference helpers below.
logger = logging.getLogger("openapi_server:inference")
def predict(
    smiles: str, gene_expression: np.ndarray, estimate_confidence: bool = False
) -> dict:
    """
    Run PaccMann prediction.

    The model is instantiated and its weights are loaded on every call. The
    transformed SMILES is flattened to a single row and repeated once per
    entry along the first dimension of ``gene_expression`` so both model
    inputs share the same leading dimension.

    Args:
        smiles (str): SMILES representing a compound.
        gene_expression (np.ndarray): gene expression data. Its first
            dimension determines how many times the SMILES is repeated.
        estimate_confidence (bool, optional): estimate confidence of the
            prediction. Defaults to False.

    Returns:
        dict: the prediction dictionary from the model.
    """
    logger.debug("running predict.")
    logger.debug("gene expression shape: {}.".format(gene_expression.shape))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.debug("device selected: {}.".format(device))

    logger.debug("loading model for prediction.")
    model = MCA(MODEL_PARAMS)
    model.load_state_dict(torch.load(MODEL_WEIGHTS_URI, map_location=device))
    # The SMILES tensor is created on `device`; the model parameters must
    # live on the same device or the forward pass fails on CUDA hosts.
    model.to(device)
    model.eval()
    if estimate_confidence:
        logger.debug("associating SMILES language for confidence estimates.")
        model._associate_language(SMILES_LANGUAGE)
    logger.debug("model loaded.")

    logger.debug("set up the transformation.")
    smiles_transform_fn = Compose(SMILES_TRANSFORMS + [ToTensor(device=device)])

    logger.debug("starting the prediction.")
    with torch.no_grad():
        # One copy of the tokenized SMILES per gene expression entry.
        smiles_batch = (
            smiles_transform_fn(smiles)
            .view(1, -1)
            .repeat(gene_expression.shape[0], 1)
        )
        # Keep the second input on the same device as the SMILES batch.
        gene_expression_batch = torch.tensor(gene_expression).float().to(device)
        _, prediction_dict = model(
            smiles_batch,
            gene_expression_batch,
            confidence=estimate_confidence,
        )
    logger.debug("successful prediction.")
    return prediction_dict