import os
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.io import mmread
from torchvision.transforms import Compose
from torch.utils.data import Dataset


class CellPainting(Dataset):
    def __init__(self, sample_index_file: str, image_directory_path: str = None, molecule_file: str = None,
                 label_matrix_file: str = None, label_row_index_file: str = None, label_col_index_file: str = None,
                 auxiliary_labels=None, transforms=None, group_views: bool = False,
                 subset: float = 1., num_classes: int = None, verbose: bool = False):
""" Read samples from cellpainting dataset.""" | |
self.verbose = verbose | |
self.molecules = False | |
self.images = False | |
assert (os.path.exists(sample_index_file)) | |
print(image_directory_path) | |
print(molecule_file) | |
# Read sample index | |
sample_index = pd.read_csv(sample_index_file, sep=",", header=0) | |
sample_index.set_index(["SAMPLE_KEY"]) | |
        # Read auxiliary labels if provided
        if auxiliary_labels is not None:
            pddata = pd.read_csv(auxiliary_labels, sep=",", header=0)
            # assay activity columns start at column 2; binarize at 0.75 into {-1, +1}
            self.auxiliary_data = pddata.iloc[:, 2:].values.astype(np.float32)
            self.auxiliary_data[self.auxiliary_data < 0.75] = -1
            self.auxiliary_data[self.auxiliary_data >= 0.75] = 1
            self.auxiliary_assays = list(pddata.columns)[2:]
            self.n_auxiliary_classes = len(self.auxiliary_assays)
            self.auxiliary_smiles = pddata["SMILES"].tolist()
        else:
            self.n_auxiliary_classes = 0
        if image_directory_path:
            self.images = True
            assert os.path.exists(image_directory_path)
            if group_views:
                sample_groups = sample_index.groupby(['PLATE_ID', 'WELL_POSITION'])
                sample_keys = list(sample_groups.groups.keys())
                sample_index = sample_groups
                self.sample_to_smiles = None  # TODO
            else:
                sample_keys = sample_index['SAMPLE_KEY'].tolist()
                if auxiliary_labels is not None:
                    self.sample_to_smiles = dict(
                        zip(sample_index.SAMPLE_KEY, [self.auxiliary_smiles.index(s) for s in sample_index.SMILES]))
                else:
                    self.sample_to_smiles = None
        if molecule_file:
            self.molecules = True
            assert os.path.exists(molecule_file)
            molecule_df = pd.read_hdf(molecule_file, key="df")
            mol_keys = list(molecule_df.index.values)
        # Determine the usable sample keys depending on which modalities are available
        if self.images and self.molecules:
            keys = list(set(sample_keys) & set(mol_keys))
        elif self.images:
            keys = sample_keys
        elif self.molecules:
            keys = mol_keys
        else:
            raise Exception("Either an image directory or a molecule file must be provided!")
        if len(keys) == 0:
            raise Exception("Empty dataset!")
        else:
            self.log("Found {} samples".format(len(keys)))
        if subset != 1.:
            # apply the subset to the final key list so it actually restricts the dataset
            keys = keys[:int(len(keys) * subset)]
        # Read label matrix if specified
        if label_matrix_file is not None:
            if label_row_index_file is None or label_col_index_file is None:
                raise Exception("If a label matrix is specified, row and column index files must be passed as well!")
            assert os.path.exists(label_matrix_file)
            assert os.path.exists(label_row_index_file)
            assert os.path.exists(label_col_index_file)
            col_index = pd.read_csv(label_col_index_file, sep=",", header=0)
            row_index = pd.read_csv(label_row_index_file, sep=",", header=0)
            label_matrix = mmread(label_matrix_file).tocsr()
            self.label_matrix = label_matrix
            self.row_index = row_index
            self.col_index = col_index
            if group_views:
                # all views of a well share the label row of their first view
                self.label_dict = dict(
                    (key, sample_groups.get_group(key).iloc[0].ROW_NR_LABEL_MAT) for key in sample_keys)
            else:
                self.label_dict = dict(zip(sample_index.SAMPLE_KEY, sample_index.ROW_NR_LABEL_MAT))
            self.n_classes = label_matrix.shape[1]
        else:
            self.label_matrix = None
            self.row_index = None
            self.col_index = None
            self.label_dict = None
            self.n_classes = num_classes
            if auxiliary_labels is not None:
                self.n_classes += self.n_auxiliary_classes
        # Expose everything important
        self.data_directory = image_directory_path
        self.sample_index = sample_index
        if self.molecules:
            self.molecule_objs = molecule_df
        self.keys = keys
        self.n_samples = len(keys)
        self.sample_keys = list(keys)
        self.group_views = group_views
        self.transforms = transforms
        # Load the first sample to verify the data can be read and to expose its shape
        # (__getitem__ returns an (image, molecule) tuple when both modalities are present)
        sample = self[0][0] if (self.images and self.molecules) else self[0]
        if isinstance(sample, dict) and isinstance(sample.get("input"), np.ndarray):
            self.data_shape = sample["input"].shape
        elif isinstance(sample, np.ndarray):
            self.data_shape = sample.shape
        else:
            self.data_shape = "Unknown"
        self.log("Discovered {} samples (subset={}) with shape {}".format(self.n_samples, subset, self.data_shape))

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        sample_key = self.keys[idx]
        if self.molecules and self.images:
            img = self.read_img(sample_key)
            mol = self.molecule_objs.loc[sample_key].values
            return img, mol
        elif self.images:
            return self.read_img(sample_key)
        elif self.molecules:
            return self.molecule_objs.loc[sample_key].values
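    # Return-type contract of __getitem__ as implemented above:
    #   * images + molecules -> (image_dict, molecule_feature_vector)
    #   * images only        -> image dict with "input" and "ID", plus either "target"
    #                           (when a label matrix is loaded) or "row_id"
    #   * molecules only     -> 1-D numpy array of molecular features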

    def shape(self):
        return self.data_shape

    def num_classes(self):
        return self.n_classes

    def log(self, message):
        if self.verbose:
            print(message)

    def read_img(self, key):
        index = None
        if self.group_views:
            X = self.load_view_group(key)
        else:
            filepath = os.path.join(self.data_directory, "{}.npz".format(key))
            if os.path.exists(filepath):
                X = self.load_view(filepath=filepath)
                index = int(np.where(self.sample_index["SAMPLE_KEY"] == key)[0][0])
            else:
                # sample missing on disk; return a NaN placeholder so callers can skip it
                return dict(input=np.nan, ID=key)
        if self.transforms:
            X = self.transforms(X)
        # attach the label vector if a label matrix was provided
        if self.label_dict is not None:
            label_idx = self.label_dict[key]
            y = self.label_matrix[label_idx].toarray()[0].astype(np.float32)
            if self.sample_to_smiles is not None and key in self.sample_to_smiles:
                y = np.concatenate([y, self.auxiliary_data[self.sample_to_smiles[key], :]])
            return dict(input=X, target=y, ID=key)
        else:
            return dict(input=X, row_id=index, ID=key)

    def get_sample_keys(self):
        return self.sample_keys.copy()

    def load_view(self, filepath):
        """Load all channels for one sample from its .npz archive."""
        npz = np.load(filepath, allow_pickle=True)
        if "sample" in npz:
            image = npz["sample"].astype(np.float32)
            # (per-channel standardization was experimented with here but is left disabled)
            return image
        return None
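    # Note: load_view assumes each sample is stored as an .npz archive containing a
    # "sample" array (an H x W x C image; five channels are used in load_view_group).
    # How those archives are produced is outside this file; a file in the expected
    # layout could, for instance, be written with
    #     np.savez("<SAMPLE_KEY>.npz", sample=image)
    # where `image` is a hypothetical numpy array of the raw channels.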

    def load_view_group(self, groupkey):
        """Stitch the views (sites) of one well onto a 2 x 3 tile grid as a single image."""
        result = np.empty((1040, 2088 - 12, 5), dtype=np.uint8)
        viewgroup = self.sample_index.get_group(groupkey)
        for i, view in enumerate(viewgroup.sort_values("SITE", ascending=True).iterrows()):
            # tiles are 520 px high and 692 px wide; first three sites go in the top row
            corner = (0 if int(i / 3) == 0 else 520, i % 3 * 692)
            filepath = os.path.join(self.data_directory, "{}.npz".format(view[1].SAMPLE_KEY))
            # crop the 4 leftmost pixel columns of each view so three tiles fit the canvas width
            v = self.load_view(filepath=filepath)[:, 4:, :]
            result[corner[0]:corner[0] + 520, corner[1]:corner[1] + 692, :] = v
        return result
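

# Minimal usage sketch (not part of the original module): it only illustrates how the
# dataset class above might be instantiated. All paths below are placeholders/assumptions
# and must point to real index and image files in the formats described in __init__.
if __name__ == "__main__":
    dataset = CellPainting(
        sample_index_file="path/to/sample_index.csv",    # placeholder path
        image_directory_path="path/to/npz_images",       # placeholder path
        verbose=True,
    )
    print("number of samples:", len(dataset))
    first = dataset[0]  # for an image-only dataset this is a dict with "input" and "ID" keys
    print("first sample ID:", first["ID"])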