# counterfactuals / datasets.py
# fabio-deep: "added links" (commit 146a6ea)
import os
import gzip
import struct
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as TF
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset
from typing import Tuple
from PIL import Image
from skimage.io import imread
def log_standardize(x):
    """Log-transform ``x`` and standardize the result to mean 0, std 1.

    ``x`` is clamped to at least 1e-12 before the log so zeros do not yield
    ``-inf``, and the std is clamped the same way so constant input does not
    divide by zero.
    """
    logged = x.clamp(min=1e-12).log()
    centred = logged - logged.mean()
    return centred / logged.std().clamp(min=1e-12)  # mean=0, std=1
def normalize(x, x_min=None, x_max=None, zero_one=False):
    """Min-max scale ``x`` into [0, 1] or, by default, into [-1, 1].

    Args:
        x: tensor to scale.
        x_min: lower bound; defaults to ``x.min()``.
        x_max: upper bound; defaults to ``x.max()``.
        zero_one: return values in [0, 1] if True, otherwise in [-1, 1].

    The bounds that are actually used get printed. There is no guard for
    ``x_max == x_min``, which divides by zero.
    """
    lo = x.min() if x_min is None else x_min
    hi = x.max() if x_max is None else x_max
    print(f"max: {hi}, min: {lo}")
    unit = (x - lo) / (hi - lo)  # [0,1]
    if zero_one:
        return unit
    return 2 * unit - 1  # [-1,1]
class UKBBDataset(Dataset):
    """UK Biobank brain-MRI dataset built from a CSV of per-scan attributes.

    Every selected CSV column is stored as a float tensor. ``__getitem__``
    returns a dict with one scalar per column. When ``eid`` is among the
    columns, it also loads the PNG
    ``<root>/thumbs_192x192/<eid>_<T1|T2_FLAIR>_unbiased_brain_rigid_to_mni.png``
    into ``"x"``. ``mri_seq == 0`` selects T1; any other value selects T2_FLAIR.

    Args:
        root: directory containing the ``thumbs_192x192`` folder.
        csv_file: CSV file with the per-scan attributes.
        transform: optional callable applied to the loaded PIL image. Without
            one, the PIL image itself is returned under ``"x"``.
        columns: columns to load. Defaults to every CSV column except the
            first one, which is dropped as a redundant index column.
        norm: normalization applied to the ``age``, ``brain_volume`` and
            ``ventricle_volume`` columns. One of ``"[-1,1]"``, ``"[0,1]"``,
            ``"log_standard"`` or ``None`` (no normalization).
        concat_pa: if True, also return every non-``eid`` column concatenated
            into a 1-D tensor under ``"pa"``.

    Raises:
        NotImplementedError: if ``norm`` is not one of the supported values.
    """

    def __init__(
        self, root, csv_file, transform=None, columns=None, norm=None, concat_pa=True
    ):
        super().__init__()
        # Fail fast on an unsupported option instead of silently skipping
        # normalization.
        if norm not in ("[-1,1]", "[0,1]", "log_standard", None):
            raise NotImplementedError(f"{norm} not implemented.")
        self.root = root
        self.transform = transform
        self.concat_pa = concat_pa  # return concatenated parents
        print(f"\nLoading csv data: {csv_file}")
        self.df = pd.read_csv(csv_file)
        self.columns = columns
        if self.columns is None:
            # e.g. ['eid', 'sex', 'age', 'brain_volume', 'ventricle_volume', 'mri_seq']
            self.columns = list(self.df.columns)  # return all
            self.columns.pop(0)  # remove redundant 'index' column
        print(f"columns: {self.columns}")
        self.samples = {
            c: torch.as_tensor(self.df[c].to_numpy()).float() for c in self.columns
        }
        for k in ["age", "brain_volume", "ventricle_volume"]:
            if k not in self.columns:
                continue
            print(f"{k} normalization: {norm}")
            if norm == "[-1,1]":
                self.samples[k] = normalize(self.samples[k])
            elif norm == "[0,1]":
                self.samples[k] = normalize(self.samples[k], zero_one=True)
            elif norm == "log_standard":
                self.samples[k] = log_standardize(self.samples[k])
            # norm is None: keep the raw values.
        print(f"#samples: {len(self.df)}")
        # Images can only be located when the subject id column is loaded.
        self.return_x = "eid" in self.columns

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sample = {k: v[idx] for k, v in self.samples.items()}
        if self.return_x:
            # NOTE(review): this needs 'mri_seq' among the columns; otherwise it
            # raises KeyError. Confirm that callers always select it with 'eid'.
            mri_seq = "T1" if sample["mri_seq"] == 0.0 else "T2_FLAIR"
            filename = f'{int(sample["eid"])}_{mri_seq}_unbiased_brain_rigid_to_mni.png'
            path = os.path.join(self.root, "thumbs_192x192", filename)
            # Image.open keeps the file open for lazy loading. copy() loads the
            # pixels into an independent image, so the handle can be closed.
            with Image.open(path) as img:
                x = img.copy()
            sample["x"] = x if self.transform is None else self.transform(x)
        sample.pop("eid", None)
        if self.concat_pa:
            # Each attribute is a 0-d float tensor; make it 1-element to concat.
            sample["pa"] = torch.cat(
                [sample[k].reshape(1) for k in self.columns if k != "eid"], dim=0
            )
        return sample
def get_attr_max_min(attr):
    """Return hard-coded ``(max, min)`` statistics for a UK Biobank attribute.

    Args:
        attr: ``"age"``, ``"brain_volume"`` or ``"ventricle_volume"``.

    Returns:
        Tuple ``(max, min)`` for ``attr``.

    Raises:
        NotImplementedError: for any other attribute name.
    """
    # some ukbb dataset (max, min) stats
    stats = {
        "age": (73, 44),
        "brain_volume": (1629520, 841919),
        "ventricle_volume": (157075, 7613.27001953125),
    }
    try:
        return stats[attr]
    except KeyError:
        raise NotImplementedError(f"No (max, min) stats for attribute {attr!r}.") from None
def ukbb(args):
    """Build UK Biobank dataset splits from parsed command-line ``args``.

    Only the ``"test"`` split is built at present. It reads ``args.data_dir``
    (used both as the CSV directory and as the image root),
    ``args.input_res``, ``args.pad``, ``args.hflip``, ``args.parents_x``
    and, when it exists, ``args.context_norm``.

    Returns:
        Dict mapping split name to a ``UKBBDataset``.
    """
    csv_dir = args.data_dir
    res = (args.input_res, args.input_res)
    train_tf = TF.Compose(
        [
            TF.Resize(res, antialias=None),
            TF.RandomCrop(size=res, padding=[2 * args.pad, args.pad]),
            TF.RandomHorizontalFlip(p=args.hflip),
            TF.PILToTensor(),
        ]
    )
    eval_tf = TF.Compose([TF.Resize(res, antialias=None), TF.PILToTensor()])
    augmentation = {"train": train_tf, "eval": eval_tf}

    datasets = {}
    for split in ("test",):  # 'train' and 'valid' are currently disabled
        tf_key = split if split == "train" else "eval"
        columns = ["eid"] + args.parents_x if args.parents_x else None
        datasets[split] = UKBBDataset(
            root=args.data_dir,
            csv_file=os.path.join(csv_dir, split + ".csv"),
            transform=augmentation[tf_key],
            columns=columns,
            norm=getattr(args, "context_norm", None),
            concat_pa=False,
        )
    return datasets
def _load_uint8(f):
idx_dtype, ndim = struct.unpack("BBBB", f.read(4))[2:]
shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
buffer_length = int(np.prod(shape))
data = np.frombuffer(f.read(buffer_length), dtype=np.uint8).reshape(shape)
return data
def load_idx(path: str) -> np.ndarray:
    """Load an array stored in IDX format.

    Parameters
    ----------
    path : str
        Location of the IDX file. It is decompressed with `gzip` when the
        name ends in '.gz'.

    Returns
    -------
    np.ndarray
        The stored array, of dtype ``uint8``.

    References
    ----------
    http://yann.lecun.com/exdb/mnist/
    """
    if path.endswith(".gz"):
        opener = gzip.open
    else:
        opener = open
    with opener(path, "rb") as stream:
        data = _load_uint8(stream)
    return data
def _get_paths(root_dir, train):
prefix = "train" if train else "t10k"
images_filename = prefix + "-images-idx3-ubyte.gz"
labels_filename = prefix + "-labels-idx1-ubyte.gz"
metrics_filename = prefix + "-morpho.csv"
images_path = os.path.join(root_dir, images_filename)
labels_path = os.path.join(root_dir, labels_filename)
metrics_path = os.path.join(root_dir, metrics_filename)
return images_path, labels_path, metrics_path
def load_morphomnist_like(
    root_dir, train: bool = True, columns=None
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
    """Load the images, labels and morphometrics of one Morpho-MNIST-style split.

    Args:
        root_dir: directory containing the data files.
        train: load the training subset (``'train-*'`` files) if True,
            otherwise the test subset (``'t10k-*'`` files).
        columns: morphometrics to read. ``None`` reads every column in the
            metrics CSV. The ``index`` column is always read, because it
            becomes the DataFrame index.

    Returns:
        images, labels, metrics
    """
    images_path, labels_path, metrics_path = _get_paths(root_dir, train)
    images = load_idx(images_path)
    labels = load_idx(labels_path)
    if columns is None or "index" in columns:
        usecols = columns
    else:
        usecols = ["index", *columns]
    metrics = pd.read_csv(metrics_path, usecols=usecols, index_col="index")
    return images, labels, metrics
class MorphoMNIST(Dataset):
    """Morpho-MNIST images with their morphometrics and one-hot digit labels.

    Args:
        root_dir: directory holding the ``train-*``/``t10k-*`` IDX files and
            the morpho CSVs.
        train: load the training files if True, otherwise the ``t10k`` files.
        transform: optional callable applied to each image tensor. A channel
            axis is inserted at dim 1 of the loaded array.
        columns: attributes to return. ``"digit"`` refers to the labels, and
            ``None`` loads every metric in the CSV. The one-hot ``digit``
            labels are always appended to the attributes.
        norm: ``"[-1,1]"``, ``"[0,1]"`` or ``None``. Scaling uses the fixed
            bounds in ``self.min_max``, which only cover ``thickness`` and
            ``intensity``. Any other metric raises ``KeyError`` when ``norm``
            is set.
        concat_pa: if True, return all attributes concatenated under ``"pa"``.
            Otherwise each attribute is returned under its own key.

    Raises:
        NotImplementedError: if ``norm`` is not one of the supported values.
    """

    def __init__(
        self,
        root_dir,
        train=True,
        transform=None,
        columns=None,
        norm=None,
        concat_pa=True,
    ):
        super().__init__()
        if norm not in ("[-1,1]", "[0,1]", None):
            raise NotImplementedError(f"{norm} not implemented.")
        self.train = train
        self.transform = transform
        self.columns = columns
        self.concat_pa = concat_pa
        self.norm = norm
        # Passing None makes load_morphomnist_like read every metric column.
        cols_not_digit = (
            None if columns is None else [c for c in columns if c != "digit"]
        )
        images, labels, metrics_df = load_morphomnist_like(
            root_dir, train, cols_not_digit
        )
        self.images = torch.from_numpy(np.array(images)).unsqueeze(1)
        self.labels = F.one_hot(
            torch.from_numpy(np.array(labels)).long(), num_classes=10
        )
        if self.columns is None:
            self.columns = list(metrics_df.columns)
            cols_not_digit = list(self.columns)
        self.samples = {
            k: torch.tensor(metrics_df[k].to_numpy()) for k in cols_not_digit
        }
        self.min_max = {
            "thickness": [0.87598526, 6.255515],
            "intensity": [66.601204, 254.90317],
        }
        for k, v in self.samples.items():  # optional preprocessing
            print(f"{k} normalization: {norm}")
            if norm is None:
                continue
            lo, hi = self.min_max[k]  # KeyError for metrics without fixed bounds
            self.samples[k] = normalize(
                v, x_min=lo, x_max=hi, zero_one=(norm == "[0,1]")
            )
        print(f"#samples: {len(metrics_df)}\n")
        self.samples.update({"digit": self.labels})

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = {"x": self.images[idx]}
        if self.transform is not None:
            sample["x"] = self.transform(sample["x"])
        if self.concat_pa:
            # Digit rows are already 1-D one-hot vectors; scalar metrics are
            # wrapped into 1-element tensors before concatenation.
            sample["pa"] = torch.cat(
                [
                    v[idx] if k == "digit" else torch.tensor([v[idx]])
                    for k, v in self.samples.items()
                ],
                dim=0,
            )
        else:
            sample.update({k: v[idx] for k, v in self.samples.items()})
        return sample
def morphomnist(args):
    """Build Morpho-MNIST dataset splits from parsed command-line ``args``.

    Only the ``"test"`` split is built. Because ``train=(split == "train")``,
    it is read from the ``t10k-*`` files, which also serve as the validation
    set.

    Returns:
        Dict mapping split name to a ``MorphoMNIST`` dataset.
    """
    augmentation = {
        "train": TF.Compose(
            [TF.RandomCrop((args.input_res, args.input_res), padding=args.pad)]
        ),
        "eval": TF.Compose([TF.Pad(padding=2)]),  # (32, 32)
    }
    return {
        split: MorphoMNIST(
            root_dir=args.data_dir,
            train=(split == "train"),
            transform=augmentation["train" if split == "train" else "eval"],
            columns=args.parents_x,
            norm=args.context_norm,
            concat_pa=False,
        )
        for split in ("test",)  # 'train' and 'valid' are currently disabled
    }
def preproc_mimic(batch):
    """Convert a MIMIC sample/batch dict into float tensors, in place.

    - ``x``: pixel values mapped from [0, 255] to [-1, 1].
    - ``age``: gets a trailing unit dim, then ``age / 100`` is mapped to
      ``2 * (age / 100) - 1``. This lies in [-1, 1] only if ages are in
      [0, 100]; that range is assumed, not checked.
    - ``race``: one-hot over 3 classes, then ``squeeze()``d.
    - ``finding`` and every other key: cast to float with a trailing unit dim.

    Returns the same (mutated) dict.
    """
    for key in list(batch):
        value = batch[key]
        if key == "x":
            batch[key] = (value.float() - 127.5) / 127.5  # [-1,1]
        elif key == "age":
            scaled = value.float().unsqueeze(-1) / 100.0
            batch[key] = scaled * 2 - 1  # [-1,1]
        elif key == "race":
            batch[key] = F.one_hot(value, num_classes=3).squeeze().float()
        else:
            # "finding" and every remaining attribute
            batch[key] = value.float().unsqueeze(-1)
    return batch
class MIMICDataset(Dataset):
    """MIMIC chest X-ray dataset described by a CSV of image paths and attributes.

    The CSV is read eagerly. Each row's image path (``root`` joined with
    ``path_preproc``), a binary ``finding`` (0 for ``"No Finding"``, else 1),
    ``age``, ``race_label`` and ``sex_label`` are collected into
    ``self.samples``. Images are read lazily in ``__getitem__``.

    Args:
        root: directory that the ``path_preproc`` entries are relative to.
        csv_file: CSV with at least ``disease``, ``path_preproc``, ``age``,
            ``race_label`` and ``sex_label`` columns.
        transform: optional callable applied to the image tensor.
        columns: keys concatenated into ``"pa"`` when ``concat_pa`` is True.
        concat_pa: if True, add ``"pa"`` = concatenation of ``columns``.
        only_pleural_eff: if True, skip rows whose ``disease`` is ``"Other"``.
    """

    def __init__(
        self,
        root,
        csv_file,
        transform=None,
        columns=None,
        concat_pa=True,
        only_pleural_eff=True,
    ):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.disease_labels = ["No Finding", "Other", "Pleural Effusion"]
        self.samples = {"age": [], "sex": [], "finding": [], "x": [], "race": []}
        frame = self.data
        for row in tqdm(range(len(frame)), desc="Loading MIMIC Data"):
            disease = frame.loc[row, "disease"]
            if only_pleural_eff and disease == "Other":
                continue
            self.samples["x"].append(os.path.join(root, frame.loc[row, "path_preproc"]))
            self.samples["finding"].append(0 if disease == "No Finding" else 1)
            self.samples["age"].append(frame.loc[row, "age"])
            self.samples["race"].append(frame.loc[row, "race_label"])
            self.samples["sex"].append(frame.loc[row, "sex_label"])
        if columns is None:
            # NOTE(review): these are CSV headers (e.g. 'path_preproc'), not keys
            # of self.samples, so concat_pa=True with columns=None would likely
            # raise KeyError in __getitem__. Confirm the intended default.
            columns = list(self.data.columns)  # return all
            columns.pop(0)  # remove redundant 'index' column
        self.columns = columns
        self.concat_pa = concat_pa

    def __len__(self):
        return len(self.samples["x"])

    def __getitem__(self, idx):
        sample = {}
        for key, values in self.samples.items():
            value = values[idx]
            if key == "x":
                # Image path -> float32 array with a leading singleton axis.
                value = imread(value).astype(np.float32)[None, ...]
            sample[key] = torch.tensor(value)
        if self.transform:
            sample["x"] = self.transform(sample["x"])
        sample = preproc_mimic(sample)
        if self.concat_pa:
            sample["pa"] = torch.cat([sample[k] for k in self.columns], dim=0)
        return sample
def mimic(args):
    """Build the MIMIC test split from parsed command-line ``args``.

    Side effect: sets ``args.csv_dir = args.data_dir``. The CSV is read from
    ``<data_dir>/mimic.sample.test.csv``, and images are resized to
    ``(args.input_res, args.input_res)``.

    Returns:
        Dict with a single ``"test"`` entry holding a ``MIMICDataset``.
    """
    args.csv_dir = args.data_dir
    test_csv = os.path.join(args.csv_dir, "mimic.sample.test.csv")
    test_set = MIMICDataset(
        root=args.data_dir,
        csv_file=test_csv,
        columns=args.parents_x,
        transform=TF.Compose(
            [TF.Resize((args.input_res, args.input_res), antialias=None)]
        ),
        concat_pa=False,
    )
    return {"test": test_set}