Spaces:
Running
Running
"""Submission-related utilities.""" | |
import os | |
import json | |
import logging | |
import numpy as np | |
import pandas as pd | |
from io import StringIO | |
from typing import Optional | |
from sklearn.preprocessing import StandardScaler | |
from configuration import ( | |
GENE_EXPRESSION_DATA, | |
GENE_EXPRESSION_METADATA, | |
GENES, | |
GENE_STANDARDIZATION_PARAMETERS, | |
) | |
from cos import RESULTS_PREFIX, string_to_key | |
from forward import predict | |
# from attention import upload_attention | |
logger = logging.getLogger("openapi_server:submission") | |
def submission(
    drug: dict,
    workspace_id: str,
    task_id: str,
    estimate_confidence: bool = False,
    omics_file: Optional[str] = None,
) -> dict:
    """
    Submit PaccMann prediction.

    Prepares the gene expression input (either the default data from the
    configuration or an uploaded table), runs ``predict`` on the drug SMILES
    and returns the predictions as arrays.

    Args:
        drug (dict): drug to analyse in dictionary format. Only the
            "smiles" entry is read here.
        workspace_id (str): workspace identifier for the submission.
        task_id (str): task identifier.
        estimate_confidence (bool, optional): estimate confidence of the
            prediction. Defaults to False.
        omics_file (Optional[str], optional): expression data in CSV
            format, handed as-is to ``pd.read_csv``. Defaults to None, in
            which case the default expression data and metadata from the
            configuration are used.

    Returns:
        dict: every entry returned by ``predict``, converted with
            ``.numpy()``. When a single uploaded sample was duplicated to
            build the input, only the first row of each array is kept.
    """
    # NOTE: only referenced by the disabled attention upload further down.
    prefix = os.path.join(RESULTS_PREFIX, workspace_id, task_id)
    logger.debug("processing omic data.")
    # NOTE: this trick is used in case a single example is passed
    single_example = False

    if omics_file is None:
        gene_expression = GENE_EXPRESSION_DATA
        # Work on a copy: prediction columns are added to the metadata below
        # and must not leak into the shared module-level default (otherwise
        # e.g. confidence columns of one call survive into later calls).
        gene_expression_metadata = GENE_EXPRESSION_METADATA.copy()
    else:
        logger.debug("parsing uploaded omic data.")
        logger.debug(omics_file)
        gene_expression_df = pd.read_csv(omics_file, low_memory=False)
        logger.debug(gene_expression_df.columns)
        # Columns that are known genes are expression values; everything
        # else in the uploaded table is treated as sample metadata.
        to_drop = list(set(GENES) & set(gene_expression_df.columns))
        # Align the columns to the order of GENES; genes missing from the
        # upload are filled with 0.0.
        gene_expression_data = gene_expression_df.T.reindex(GENES).fillna(0.0).T
        gene_expression_metadata = gene_expression_df.drop(to_drop, axis=1)
        logger.debug("peek parsed expression and metadata.")
        logger.debug("gene_expression_data:\n{}".format(gene_expression_data.head()))
        logger.debug(
            "gene_expression_metadata:\n{}".format(gene_expression_metadata.head())
        )
        if gene_expression_data.shape[0] < 2:
            logger.debug(
                "single example, standardizing with default parameters:\n{}".format(
                    GENE_STANDARDIZATION_PARAMETERS
                )
            )
            single_example = True
            # Statistics cannot be estimated from one sample, so the
            # precomputed parameters are used instead (presumably
            # [0] = mean and [1] = scale -- confirm in configuration).
            gene_expression = (
                gene_expression_data.values - GENE_STANDARDIZATION_PARAMETERS[0]
            ) / GENE_STANDARDIZATION_PARAMETERS[1]
            # Duplicate the sample so the model receives two rows; the
            # extra row is removed again from the predictions below.
            gene_expression = np.vstack(2 * [gene_expression])
            logger.debug(gene_expression.shape)
        else:
            gene_expression = StandardScaler().fit_transform(
                gene_expression_data.values
            )
        logger.debug("gene_expression:\n{}".format(gene_expression[:10]))
    logger.debug("omic data prepared if present.")

    raw_predictions = predict(
        smiles=drug["smiles"],
        gene_expression=gene_expression,
        estimate_confidence=estimate_confidence,
    )
    # from tensors; for a duplicated single sample keep only the first row.
    prediction_dict = {
        key: value.numpy()[:1] if single_example else value.numpy()
        for key, value in raw_predictions.items()
    }
    result = dict(prediction_dict)

    # merge for single table, index is unique identifier for samples.
    gene_expression_metadata["IC50 (min/max scaled)"] = prediction_dict["IC50"]
    gene_expression_metadata["IC50 (log(μmol))"] = prediction_dict[
        "log_micromolar_IC50"
    ]
    if estimate_confidence:
        gene_expression_metadata["epistemic_confidence"] = prediction_dict[
            "epistemic_confidence"
        ]
        gene_expression_metadata["aleatoric_confidence"] = prediction_dict[
            "aleatoric_confidence"
        ]
    # NOTE: the log messages below still mention uploads, but no upload is
    # performed in this function; the merged table is not part of the
    # returned value.
    logger.debug("uploaded predicted sensitivity table including metadata.")
    # attention
    # result.update(
    #     upload_attention(
    #         prefix,
    #         sample_names=list(map(str, gene_expression_metadata.index)),
    #         omic_attention=prediction_dict["gene_attention"],
    #         smiles_attention=prediction_dict["smiles_attention"],
    #     )
    # )
    logger.debug("uploaded attention for each sample.")
    logger.debug("uploading drug information and sensitivity.")
    # prediction (is sensitivity_json in API)
    logger.debug("uploaded drug information and sensitivity.")
    # NOTE: Ordering corresponds to IDs in GEP metadata!
    return result