File size: 4,334 Bytes
ec53722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Get/put submission results concerning attention from/on COS."""
import os
import json
import dill
import logging
import numpy as np
from typing import Iterable
from configuration import GENES
from cos import (
    RESULTS_PREFIX,
    bytes_from_key,
    string_from_key,
    bytes_to_key,
)
from utils import Drug
from plots import embed_barplot
from smiles import smiles_attention_to_svg

logger = logging.getLogger("openapi_server:attention")


def download_attention(workspace_id: str, task_id: str, sample_name: str) -> dict:
    """
    Fetch stored attention results from COS and turn them into figures.

    The drug description is read from the task-level "drug.json", while the
    pickled gene and SMILES attention are read from the sample-level objects.

    Args:
        workspace_id (str): workspace identifier.
        task_id (str): task identifier.
        sample_name (str): name of the sample.

    Returns:
        dict: the drug, the sample name, the two SVG outputs produced for
            the SMILES attention and the JS/HTML outputs produced for the
            gene attention barplot.
    """

    def _load_pickled(basename: str):
        # Sample-level objects live under <workspace>/<task>/<sample>/.
        relative_name = os.path.join(workspace_id, task_id, sample_name, basename)
        payload = bytes_from_key(os.path.join(RESULTS_PREFIX, relative_name))
        # NOTE: dill.loads can execute arbitrary code while unpickling, so
        # this is only safe while the stored objects come from a trusted
        # writer.
        return dill.loads(payload)

    # The drug description is stored once per task, not per sample.
    drug_key = os.path.join(
        RESULTS_PREFIX, os.path.join(workspace_id, task_id, "drug.json")
    )
    drug_fields = json.loads(string_from_key(drug_key))
    drug = Drug(**drug_fields)
    logger.debug(f"download attention results from COS for {drug.smiles}.")

    # Gene (omic) attention: rank genes by decreasing attention for the plot.
    logger.debug("gene attention.")
    gene_attention = _load_pickled("gene_attention.pkl")
    gene_names = np.array(GENES)
    ranking = gene_attention.argsort()[::-1]
    barplot = embed_barplot(gene_names[ranking], gene_attention[ranking])
    gene_attention_js, gene_attention_html = barplot
    logger.debug("gene attention plots created.")

    # SMILES attention: rendered on top of the drug's SMILES.
    logger.debug("SMILES attention.")
    smiles_attention = _load_pickled("smiles_attention.pkl")
    svgs = smiles_attention_to_svg(drug.smiles, smiles_attention)
    drug_attention_svg, drug_color_bar_svg = svgs
    logger.debug("SMILES attention plots created.")

    result = dict(
        drug=drug,
        sample_name=sample_name,
        sample_drug_attention_svg=drug_attention_svg,
        sample_drug_color_bar_svg=drug_color_bar_svg,
        sample_gene_attention_js=gene_attention_js,
        sample_gene_attention_html=gene_attention_html,
    )
    return result


def _upload_ndarray(sample_prefix: str, array: np.ndarray, filename: str) -> None:
    """
    Serialize an array with dill and store it under the sample prefix.

    Args:
        sample_prefix (str): prefix under which the object is stored.
        array (np.ndarray): array to serialize.
        filename (str): object name without extension; ".pkl" is appended.
    """
    serialized = dill.dumps(array)
    target_key = os.path.join(sample_prefix, f"{filename}.pkl")
    bytes_to_key(serialized, target_key)


def upload_attention(
    prefix: str,
    sample_names: Iterable[str],
    omic_attention: np.ndarray,
    smiles_attention: np.ndarray,
) -> dict:
    """
    Validate attention profiles and average them over the samples.

    Despite its name, this function does not write anything to COS: it
    returns the averaged profiles and leaves storing them to the caller.

    Args:
        prefix (str): base prefix used as a root. Currently unused; kept so
            existing callers keep working.
        sample_names (Iterable[str]): name of the samples. Any iterable is
            accepted, it is materialized once to count the samples.
        omic_attention (np.ndarray): attention values for genes, with one
            row per sample and one column per gene in GENES.
        smiles_attention (np.ndarray): attention values for SMILES, with
            one entry along the first axis per sample.

    Returns:
        dict: "gene_attention" and "smiles_attention", each mapped to the
            mean of the corresponding input over its first axis.

    Raises:
        ValueError: mismatch in sample names and gene attention.
        ValueError: mismatch in sample names and SMILES attention.
        ValueError: mismatch in number of genes and gene attention.
    """
    # len() is not defined for arbitrary iterables (e.g. generators), so
    # materialize the names before counting them.
    sample_names = list(sample_names)
    number_of_samples = len(sample_names)
    number_of_genes = len(GENES)
    # sanity checks
    if number_of_samples != omic_attention.shape[0]:
        raise ValueError(
            f"length of sample_names {number_of_samples} does not "
            f"match omic_attention {omic_attention.shape[0]}"
        )
    if number_of_samples != len(smiles_attention):
        raise ValueError(
            f"length of sample_names {number_of_samples} does not "
            f"match smiles_attention {len(smiles_attention)}"
        )
    if number_of_genes != omic_attention.shape[1]:
        raise ValueError(
            f"length of omic_entities {number_of_genes} "
            f"does not match omic_attention.shape[1] {omic_attention.shape[1]}"
        )
    # Only the "average" profile over all samples is produced; per-sample
    # profiles are not part of the result.
    return {
        "gene_attention": omic_attention.mean(axis=0),
        "smiles_attention": smiles_attention.mean(axis=0),
    }