Spaces:
Running
Running
File size: 4,334 Bytes
ec53722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
"""Get/put submission results concerning attention from/on COS."""
import os
import json
import dill
import logging
import numpy as np
from typing import Iterable
from configuration import GENES
from cos import (
RESULTS_PREFIX,
bytes_from_key,
string_from_key,
bytes_to_key,
)
from utils import Drug
from plots import embed_barplot
from smiles import smiles_attention_to_svg
# Module-level logger shared by the functions below.
# NOTE(review): the name uses ":" rather than ".", so in the standard logging
# hierarchy it is a direct child of the root logger, not of an
# "openapi_server" logger — confirm this is intended.
logger = logging.getLogger("openapi_server:attention")
def download_attention(workspace_id: str, task_id: str, sample_name: str) -> dict:
    """
    Fetch the attention data of one sample and render it as figures.

    Args:
        workspace_id (str): workspace identifier.
        task_id (str): task identifier.
        sample_name (str): name of the sample.

    Returns:
        dict: the drug, the sample name, the SVGs produced for the SMILES
            attention and the JS/HTML produced for the gene attention barplot.
    """

    def _fetch_sample_file(basename: str) -> bytes:
        # Sample files are keyed by workspace, task and sample name,
        # all below RESULTS_PREFIX.
        return bytes_from_key(
            os.path.join(
                RESULTS_PREFIX,
                os.path.join(workspace_id, task_id, sample_name, basename),
            )
        )

    # The drug description is keyed by workspace and task only.
    drug_key = os.path.join(
        RESULTS_PREFIX, os.path.join(workspace_id, task_id, "drug.json")
    )
    drug_fields = json.loads(string_from_key(drug_key))
    drug = Drug(**drug_fields)
    logger.debug(f"download attention results from COS for {drug.smiles}.")

    # Gene (omic) attention: barplot with genes sorted by decreasing attention.
    # NOTE(review): dill.loads can execute arbitrary code embedded in the
    # payload; the stored objects must come from a trusted writer.
    logger.debug("gene attention.")
    gene_scores = dill.loads(_fetch_sample_file("gene_attention.pkl"))
    gene_names = np.array(GENES)
    descending = gene_scores.argsort()[::-1]
    barplot_js, barplot_html = embed_barplot(
        gene_names[descending], gene_scores[descending]
    )
    logger.debug("gene attention plots created.")

    # SMILES attention: rendered on the drug together with a color bar.
    logger.debug("SMILES attention.")
    smiles_scores = dill.loads(_fetch_sample_file("smiles_attention.pkl"))
    molecule_svg, color_bar_svg = smiles_attention_to_svg(drug.smiles, smiles_scores)
    logger.debug("SMILES attention plots created.")

    return {
        "drug": drug,
        "sample_name": sample_name,
        "sample_drug_attention_svg": molecule_svg,
        "sample_drug_color_bar_svg": color_bar_svg,
        "sample_gene_attention_js": barplot_js,
        "sample_gene_attention_html": barplot_html,
    }
def _upload_ndarray(sample_prefix: str, array: np.ndarray, filename: str) -> None:
    """
    Serialize an array with dill and pass it to ``bytes_to_key``.

    Args:
        sample_prefix (str): prefix the target key is built under.
        array (np.ndarray): array to serialize.
        filename (str): key basename, without the ".pkl" extension.
    """
    payload = dill.dumps(array)
    target_key = os.path.join(sample_prefix, f"{filename}.pkl")
    bytes_to_key(payload, target_key)
def upload_attention(
    prefix: str,
    sample_names: Iterable[str],
    omic_attention: np.ndarray,
    smiles_attention: np.ndarray,
) -> dict:
    """
    Validate attention profiles and average them over the samples.

    Despite its name, this function performs no upload itself: it returns the
    averaged profiles and leaves persisting them to the caller.

    Args:
        prefix (str): base prefix used as a root. Currently unused; kept so
            existing callers keep working.
        sample_names (Iterable[str]): name of the samples. Any iterable is
            accepted; it is consumed once to count the samples.
        omic_attention (np.ndarray): attention values for genes, indexed as
            [sample, gene].
        smiles_attention (np.ndarray): attention values for SMILES, with the
            samples along the first axis.

    Returns:
        dict: "gene_attention" and "smiles_attention", each the mean of the
            corresponding input over its first axis.

    Raises:
        ValueError: mismatch in sample names and gene attention.
        ValueError: mismatch in sample names and SMILES attention.
        ValueError: mismatch in number of genes and gene attention.
    """
    # Materialize first: the annotation allows generators and other unsized
    # iterables, on which len() would raise a TypeError.
    number_of_samples = len(list(sample_names))

    # sanity checks
    if number_of_samples != omic_attention.shape[0]:
        raise ValueError(
            f"length of sample_names {number_of_samples} does not "
            f"match omic_attention {omic_attention.shape[0]}"
        )
    if number_of_samples != len(smiles_attention):
        raise ValueError(
            f"length of sample_names {number_of_samples} does not "
            f"match smiles_attention {len(smiles_attention)}"
        )
    number_of_genes = len(GENES)
    if number_of_genes != omic_attention.shape[1]:
        raise ValueError(
            f"length of omic_entities {number_of_genes} "
            f"does not match omic_attention.shape[1] {omic_attention.shape[1]}"
        )

    # Only the "average" profile is produced; per-sample profiles are not
    # part of the result.
    return {
        "gene_attention": omic_attention.mean(axis=0),
        "smiles_attention": smiles_attention.mean(axis=0),
    }
|