File size: 2,002 Bytes
ec53722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Inference utilities."""
import logging
import torch
import numpy as np
from paccmann_predictor.models.paccmann import MCA
from pytoda.transforms import Compose
from pytoda.smiles.transforms import ToTensor
from configuration import (
    MODEL_WEIGHTS_URI,
    MODEL_PARAMS,
    SMILES_LANGUAGE,
    SMILES_TRANSFORMS,
)

# Module-level logger, registered under a fixed name.
logger = logging.getLogger("openapi_server:inference")
# NOTE: restrict torch to a single thread to avoid segfaults
# (root cause is not visible from this file -- TODO confirm before removing).
torch.set_num_threads(1)


def predict(
    smiles: str, gene_expression: np.ndarray, estimate_confidence: bool = False
) -> dict:
    """
    Run PaccMann prediction.

    The model is instantiated from MODEL_PARAMS and its weights are loaded
    from MODEL_WEIGHTS_URI on every call. The single SMILES is transformed
    to a tensor and repeated once per row of ``gene_expression`` so that it
    is paired with every gene expression profile.

    Args:
        smiles (str): SMILES representing a compound.
        gene_expression (np.ndarray): gene expression data. Its first
            dimension is used as the number of samples to predict for.
        estimate_confidence (bool, optional): estimate confidence of the
            prediction. Defaults to False.
    Returns:
        dict: the prediction dictionary from the model.
    """
    logger.debug("running predict.")
    logger.debug("gene expression shape: %s.", gene_expression.shape)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.debug("device selected: %s.", device)
    logger.debug("loading model for prediction.")
    model = MCA(MODEL_PARAMS)
    model.load_state_dict(torch.load(MODEL_WEIGHTS_URI, map_location=device))
    # Keep the model parameters on the same device as the input tensors;
    # otherwise the forward pass fails with a device mismatch on CUDA hosts.
    model.to(device)
    model.eval()
    if estimate_confidence:
        logger.debug("associating SMILES language for confidence estimates.")
        model._associate_language(SMILES_LANGUAGE)
    logger.debug("model loaded.")
    logger.debug("set up the transformation.")
    smiles_transform_fn = Compose(SMILES_TRANSFORMS + [ToTensor(device=device)])
    logger.debug("starting the prediction.")
    with torch.no_grad():
        # One copy of the tokenized SMILES per gene expression row.
        smiles_batch = (
            smiles_transform_fn(smiles)
            .view(1, -1)
            .repeat(gene_expression.shape[0], 1)
        )
        # Create the gene expression tensor directly on the selected device
        # (the original left it on CPU while the SMILES tensor was on device).
        gene_expression_batch = torch.tensor(
            gene_expression, dtype=torch.float, device=device
        )
        _, prediction_dict = model(
            smiles_batch,
            gene_expression_batch,
            confidence=estimate_confidence,
        )
    logger.debug("successful prediction.")
    return prediction_dict