Spaces:
Running
Running
"""Inference utilities.""" | |
import logging | |
import torch | |
import numpy as np | |
from paccmann_predictor.models.paccmann import MCA | |
from pytoda.transforms import Compose | |
from pytoda.smiles.transforms import ToTensor | |
from configuration import ( | |
MODEL_WEIGHTS_URI, | |
MODEL_PARAMS, | |
SMILES_LANGUAGE, | |
SMILES_TRANSFORMS, | |
) | |
logger = logging.getLogger("openapi_server:inference")
# NOTE: limit torch to a single thread to avoid segfaults
torch.set_num_threads(1)
def predict(
    smiles: str, gene_expression: np.ndarray, estimate_confidence: bool = False
) -> dict:
    """
    Run PaccMann prediction.

    The model is instantiated from MODEL_PARAMS and its weights are loaded
    from MODEL_WEIGHTS_URI on every call. The single SMILES is transformed
    once and repeated so that it is paired with every row of
    ``gene_expression``.

    Args:
        smiles (str): SMILES representing a compound.
        gene_expression (np.ndarray): gene expression data. Its first
            dimension is used as the batch size.
        estimate_confidence (bool, optional): estimate confidence of the
            prediction. Defaults to False.

    Returns:
        dict: the prediction dictionary from the model.
    """
    logger.debug("running predict.")
    logger.debug("gene expression shape: %s.", gene_expression.shape)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.debug("device selected: %s.", device)

    logger.debug("loading model for prediction.")
    model = MCA(MODEL_PARAMS)
    model.load_state_dict(torch.load(MODEL_WEIGHTS_URI, map_location=device))
    # The SMILES tensor below is created on `device`, so the model (and the
    # gene expression tensor) must live there too; otherwise a CUDA host ends
    # up with mixed-device inputs in the forward pass.
    model.to(device)
    model.eval()
    if estimate_confidence:
        logger.debug("associating SMILES language for confidence estimates.")
        model._associate_language(SMILES_LANGUAGE)
    logger.debug("model loaded.")

    logger.debug("set up the transformation.")
    smiles_transform_fn = Compose(SMILES_TRANSFORMS + [ToTensor(device=device)])

    logger.debug("starting the prediction.")
    number_of_samples = gene_expression.shape[0]
    # One row per gene expression profile, all holding the same SMILES tokens.
    smiles_batch = (
        smiles_transform_fn(smiles).view(1, -1).repeat(number_of_samples, 1)
    )
    gene_expression_batch = torch.tensor(
        gene_expression, dtype=torch.float32, device=device
    )
    with torch.no_grad():
        # The first element returned by the model is not needed here.
        _, prediction_dict = model(
            smiles_batch,
            gene_expression_batch,
            confidence=estimate_confidence,
        )
    logger.debug("successful prediction.")
    return prediction_dict