Spaces:
Sleeping
Sleeping
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from PIL import Image | |
import sys | |
import io | |
import os | |
import glob | |
import json | |
import zipfile | |
from tqdm import tqdm | |
from itertools import chain | |
import torch | |
from torch.utils.data import DataLoader | |
sys.path.insert(0, os.path.abspath("src/")) | |
from clip.clip import _transform | |
from training.datasets import CellPainting | |
from clip.model import convert_weights, CLIPGeneral | |
from rdkit import Chem | |
from rdkit.Chem import Draw | |
from rdkit.Chem import AllChem | |
from rdkit.Chem import DataStructs | |
# Configure the Streamlit page (wide layout) before any other st.* call in this script.
st.set_page_config(layout="wide")

# ---------------------------------------------------------------------------
# Filesystem locations: data files are resolved relative to this script's dir.
# ---------------------------------------------------------------------------
basepath = os.path.dirname(__file__)
datapath = os.path.join(basepath, "data")
# NOTE(review): absolute developer-machine path; not referenced by the code visible here.
CLOOME_PATH = "/home/ana/gitrepos/hti-cloob"
MODEL_PATH = os.path.join(datapath, "epoch_55.pt")
# Directory where processed query samples (.npz) are written.
npzs = os.path.join(datapath, "npzs")
# Precomputed embeddings (read with torch.load) and their CSV indexes.
molecule_features = os.path.join(datapath, "all_molecule_cellpainting_features.pkl")
mol_index_file = os.path.join(datapath, "cellpainting-unique-molecule.csv")
image_features = os.path.join(datapath, "subset_image_cellpainting_features.pkl")
images_arr = os.path.join(datapath, "subset_npzs_dict_.npz")
img_index_file = os.path.join(datapath, "cellpainting-all-imgpermol.csv")

# ---------------------------------------------------------------------------
# Runtime settings
# ---------------------------------------------------------------------------
imgname = "I1"  # file stem used for the uploaded query sample
device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "RN50"  # architecture name; load() reads "<model_type>.json" as config
######### CLOOME FUNCTIONS ######### | |
def convert_models_to_fp32(model):
    """Cast every parameter of *model*, and its gradient if present, to float32 in place.

    Args:
        model: A ``torch.nn.Module``; its parameters are modified in place.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # ``if p.grad:`` would call Tensor.__bool__, which raises for any
        # gradient with more than one element; test for presence instead.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
def load(model_path, device, model, image_resolution):
    """Build a CLOOME model and load weights from a training checkpoint.

    Args:
        model_path: Checkpoint file readable by ``torch.load``; its
            ``"state_dict"`` entry holds the weights.
        device: Target device, e.g. ``"cpu"`` or ``"cuda"``.
        model: Architecture name (e.g. ``"RN50"``). ``"/"`` is replaced by
            ``"-"`` to form the ``<name>.json`` config path, which is resolved
            relative to the current working directory.
        image_resolution: Unused; kept for call-site compatibility.

    Returns:
        The ``CLIPGeneral`` model in eval mode, moved to *device*.

    Raises:
        FileNotFoundError: If the model config JSON does not exist.
    """
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint["state_dict"]

    model_config_file = f"{model.replace('/', '-')}.json"
    print('Loading model from', model_config_file)
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if not os.path.exists(model_config_file):
        raise FileNotFoundError(f"Model config file not found: {model_config_file}")
    with open(model_config_file, 'r', encoding='utf-8') as f:
        model_info = json.load(f)

    model = CLIPGeneral(**model_info)
    convert_weights(model)
    convert_models_to_fp32(model)
    if str(device) == "cpu":
        model.float()
    print(device)

    # Keys carry a "module." prefix when the checkpoint was saved from a wrapped
    # model (presumably DataParallel/DDP). Strip it only where it is present;
    # slicing unconditionally would mangle unprefixed keys.
    prefix = "module."
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()
    return model
def get_features(dataset, model, device):
    """Embed every item of *dataset* with *model* and L2-normalise the embeddings.

    Items are batched by a ``DataLoader`` (batch_size=64). A batch may be:
      * a ``dict`` -- an image batch with keys ``"input"`` and ``"ID"``;
      * a ``torch.Tensor`` -- a molecule batch;
      * anything else -- unpacked as an ``(imgs, mols)`` pair.

    Returns:
        ``(image_features, text_features, ids)`` if both modalities were seen,
        ``(image_features, ids)`` for images only, ``(text_features, ids)`` for
        molecules only (``ids`` is then empty, since only image batches carry
        IDs), or ``None`` if the dataset yielded no batches.
    """
    all_image_features = []
    all_text_features = []
    all_ids = []
    print(f"get_features {device}")
    print(len(dataset))
    with torch.no_grad():
        for batch in tqdm(DataLoader(dataset, num_workers=1, batch_size=64)):
            if isinstance(batch, dict):
                imgs, mols = batch, None
            elif isinstance(batch, torch.Tensor):
                imgs, mols = None, batch
            else:
                imgs, mols = batch

            if mols is not None:
                text_features = model.encode_text(mols.to(device))
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                all_text_features.append(text_features)

            if imgs is not None:
                img_features = model.encode_image(imgs["input"].to(device))
                img_features = img_features / img_features.norm(dim=-1, keepdim=True)
                all_image_features.append(img_features)
                all_ids.append(imgs["ID"])

    all_ids = list(chain.from_iterable(all_ids))
    # Decide the return shape from what was actually collected, not from the
    # last batch's loop variables (which are unbound for an empty dataset).
    if all_image_features and all_text_features:
        return torch.cat(all_image_features), torch.cat(all_text_features), all_ids
    if all_image_features:
        return torch.cat(all_image_features), all_ids
    if all_text_features:
        return torch.cat(all_text_features), all_ids
    return None
def read_array(file):
    """Load a ``torch.save``-d mapping and return its molecule features and ids.

    Args:
        file: Path (or file-like object) accepted by ``torch.load``.

    Returns:
        The tuple ``(payload["mol_features"], payload["mol_ids"])``.
    """
    payload = torch.load(file)
    return payload["mol_features"], payload["mol_ids"]
def main(df, model_path, model, img_path=None, mol_path=None, image_resolution=None):
    """Load a CLOOME checkpoint and embed a CellPainting dataset with it.

    Args:
        df: Index handed to ``CellPainting`` (this file's callers pass the CSV
            path returned by ``create_index``).
        model_path: Checkpoint file for ``load``.
        model: Architecture name for ``load`` (e.g. ``"RN50"``).
        img_path: Directory of image samples, if images are to be embedded.
        mol_path: Molecule fingerprint file, if molecules are to be embedded.
        image_resolution: Size passed to the validation transform.

    Returns:
        ``(img_features, text_features, ids)`` when ``get_features`` produced
        three values, otherwise ``(features, ids)``.
    """
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(torch.cuda.device_count())
    net = load(model_path, run_device, model, image_resolution)
    transform = _transform(image_resolution, image_resolution, is_train=False, normalize="dataset", preprocess="downsize")

    dataset = CellPainting(df, img_path, mol_path, transforms=transform)

    print("getting_features")
    outcome = get_features(dataset, net, run_device)
    if len(outcome) <= 2:
        features, ids = outcome
        return features, ids
    img_feats, txt_feats, ids = outcome
    return img_feats, txt_feats, ids
def img_to_numpy(file):
    """Read an image (path or file-like object) into a NumPy array.

    The PIL image is used as a context manager, so the file handle PIL opened
    itself (path input) is released once the pixels are copied into the array.
    File-like objects supplied by the caller are left open.
    """
    with Image.open(file) as img:
        return np.array(img)
def illumination_threshold(arr, perc=0.0028):
    """Return a threshold value so that a percentage of the highest pixels is not displayed.

    The threshold is the value of the n-th brightest element, where
    ``n = round(arr.size * perc / 100)``. Using it as the display maximum
    clips roughly ``perc`` percent of the pixels.

    Args:
        arr: Image array of any shape; all elements are considered.
        perc: Percentage (not fraction) of the highest pixels to clip.

    Returns:
        A NumPy scalar of ``arr``'s dtype. At least one pixel is always
        selected, so for small images (or ``perc == 0``) this is the maximum.

    Raises:
        ValueError: If ``arr`` is empty.
    """
    if arr.size == 0:
        raise ValueError("illumination_threshold() requires a non-empty array")
    fraction = perc / 100
    n_pixels = int(np.around(arr.size * fraction))
    # With n_pixels == 0 the old ``[-0:]`` slice picked *every* pixel and
    # returned the image minimum; clamp to [1, size] so we fall back to the max.
    n_pixels = min(max(n_pixels, 1), arr.size)
    # The n-th largest element equals the minimum of the n largest elements.
    partitioned = np.partition(arr, -n_pixels, axis=None)
    return partitioned[-n_pixels]
def process_image(arr):
    """Convert a raw (e.g. 16-bit) channel image into an 8-bit array.

    The display maximum comes from ``illumination_threshold``, so the small
    fraction of brightest pixels is clipped before rescaling.

    NOTE(review): an identical ``process_image`` is defined again later in
    this module; that later definition is the one in effect at call time.
    """
    return sixteen_to_eight_bit(arr, illumination_threshold(arr))
def sixteen_to_eight_bit(arr, display_max, display_min=0):
    """Linearly rescale ``arr`` from ``[display_min, display_max]`` to uint8.

    Pixels at or below ``display_min`` become 0; pixels at or above
    ``display_max`` saturate at 255.

    Args:
        arr: Input image array (any numeric dtype).
        display_max: Intensity mapped to the top of the 8-bit range.
        display_min: Intensity mapped to 0.

    Returns:
        A ``np.uint8`` array with the same shape as ``arr``.
    """
    # Compute the window width in Python floats: with unsigned NumPy scalars
    # (e.g. a uint16 threshold) ``display_max - display_min`` could wrap around.
    span = float(display_max) - float(display_min)
    above_min = arr > display_min
    if span <= 0:
        # Degenerate window: the old code divided by zero (inf/nan -> uint8 is
        # undefined). Take the limit instead: anything above the window saturates.
        return (above_min * 255).astype(np.uint8)
    threshold_image = (arr.astype(float) - display_min) * above_min
    # The scale uses 256 rather than 255, so display_max maps to 256 and is
    # clipped below; kept for compatibility with previously processed data.
    scaled_image = threshold_image * (256. / span)
    scaled_image[scaled_image > 255] = 255
    return scaled_image.astype(np.uint8)
def process_image(arr):
    """Return ``arr`` rescaled to 8 bits, clipping its brightest pixels.

    The display maximum is taken from ``illumination_threshold(arr)`` and the
    rescaling is delegated to ``sixteen_to_eight_bit``.

    NOTE(review): this re-defines an identical ``process_image`` declared
    earlier in the module; consider deleting one of the two copies.
    """
    display_max = illumination_threshold(arr)
    return sixteen_to_eight_bit(arr, display_max=display_max)
def process_sample(imglst, channels, filenames, outdir, outfile):
    """Stack per-channel images into one 8-bit sample and save it as ``.npz``.

    Args:
        imglst: Image files (paths or file-like objects); the i-th image is
            written into slot i of the sample.
        channels: Channel names, paired positionally with ``imglst``.
        filenames: Original file names, paired positionally with ``imglst``.
        outdir: Directory the ``.npz`` file is written to.
        outfile: Output file stem; ``".npz"`` is appended.

    Returns:
        ``(sample_dict, outpath)``. ``sample_dict["sample"]`` is a uint8 array
        of shape (520, 696, 5); ``"channels"`` maps slot index -> channel name;
        ``"filenames"`` maps channel name -> file name.

    NOTE(review): the saved ``.npz`` stores the raw ``channels``/``filenames``
    lists, not the dicts returned in ``sample_dict``. Only as many images as the
    shortest of the three input lists are used (``zip``).
    """
    sample = np.zeros((520, 696, 5), dtype=np.uint8)
    channel_by_slot = {}
    file_by_channel = {}
    for slot, (img, channel, fname) in enumerate(zip(imglst, channels, filenames)):
        print(channel)
        sample[:, :, slot] = process_image(img_to_numpy(img))
        channel_by_slot[slot] = channel
        file_by_channel[channel] = fname

    outpath = os.path.join(outdir, outfile + ".npz")
    np.savez(outpath, sample=sample, channels=channels, filenames=filenames)

    sample_dict = {
        "sample": sample,
        "channels": channel_by_slot,
        "filenames": file_by_channel,
    }
    return sample_dict, outpath
def display_cellpainting(sample):
    """Build an RGB PIL preview from a Cell Painting sample dict.

    Channels 0, 3 and 4 of ``sample["sample"]`` (indexed as ``[:, :, k]``)
    become the red, green and blue planes respectively.
    """
    rgb = sample["sample"][:, :, [0, 3, 4]].astype(np.float32)
    return Image.fromarray(rgb.astype("uint8")).convert("RGB")
def morgan_from_smiles(smiles, radius=3, nbits=1024, chiral=True):
    """Return the Morgan bit-vector fingerprint of a SMILES string as a NumPy array.

    Args:
        smiles: Molecule in SMILES notation.
        radius: Morgan radius.
        nbits: Fingerprint length in bits.
        chiral: Passed to RDKit as ``useChirality``.

    Returns:
        A 1-D ``np.int8`` array holding the fingerprint bits.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    # ``radius`` used to be ignored (hard-coded to 3); honour the argument.
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nbits, useChirality=chiral)
    arr = np.zeros((0,), dtype=np.int8)
    # RDKit fills (and resizes) ``arr`` in place.
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr
def save_hdf(fps, index, outfile_hdf):
    """Write fingerprint vectors to an HDF5 file as a DataFrame under key ``"df"``.

    Rows are labelled 0..len(fps)-1; columns are named "0".."D-1", where D is
    ``fps[0].shape[0]``. The file is overwritten (``mode="w"``).

    NOTE(review): ``index`` is accepted but not used -- rows get positional
    integer labels, not the passed names.

    Returns:
        ``outfile_hdf``.
    """
    n_bits = fps[0].shape[0]
    frame = pd.DataFrame(
        fps,
        index=list(range(len(fps))),
        columns=[str(bit) for bit in range(n_bits)],
    )
    frame.to_hdf(outfile_hdf, key="df", mode="w")
    return outfile_hdf
def create_index(outdir, ids, filename):
    """Write a one-column CSV index (column ``SAMPLE_KEY``) and return its path.

    Args:
        outdir: Directory for the CSV.
        ids: A single id (exact type ``str``) or a list-like of ids.
        filename: CSV file name inside ``outdir``.

    Returns:
        ``os.path.join(outdir, filename)``.
    """
    filepath = os.path.join(outdir, filename)
    values = [ids] if type(ids) is str else ids
    data = {"SAMPLE_KEY": values}
    print(data)
    pd.DataFrame(data).to_csv(filepath)
    return filepath
def draw_molecules(smiles_lst):
    """Render each SMILES string to a PIL image with RDKit.

    All strings are parsed first, then each parsed molecule is drawn.
    """
    molecules = list(map(Chem.MolFromSmiles, smiles_lst))
    return list(map(Chem.Draw.MolToImage, molecules))
def reshape_image(arr):
    """Convert a channels-first image array into an RGB PIL image.

    Args:
        arr: Array of shape (C, H, W) with C >= 3; channels 0, 1 and 2 are used
            as R, G and B.

    Returns:
        A ``PIL.Image`` built from the uint8-cast pixels.

    Raises:
        ValueError: If ``arr`` has fewer than three channels.
    """
    c, _h, _w = arr.shape
    if c < 3:
        raise ValueError(f"reshape_image expects at least 3 channels, got {c}")
    # Only the first three channels are shown. The former np.empty((h, w, c))
    # buffer left channels 3.. uninitialised (garbage, or an array shape PIL
    # cannot handle) whenever C > 3.
    rgb = np.stack([arr[0], arr[1], arr[2]], axis=-1).astype(float)
    return Image.fromarray(rgb.astype("uint8"))
# Missing functions: save Morgan fingerprints to HDF, create index, load features, calculate similarities
##### STREAMLIT FUNCTIONS ###### | |
# Page title shown above every sub-page.
st.title('CLOOME. Bioimage database retrieval from chemical structures (and vice versa)')
def main_page(top_n, model_path):
    """Render the "About CLOOME" page (the method's abstract as Markdown).

    ``top_n`` and ``model_path`` are unused; they are accepted so that all page
    functions share the ``(top_n, model_path)`` signature used by the page
    dispatcher at the bottom of the script.
    """
    # The Markdown lines start at column 0 on purpose: leading indentation
    # inside the string would be rendered as a code block.
    st.markdown(
        """
Contrastive learning for self-supervised representation learning has brought a strong improvement to many application areas, such as computer vision and natural language processing.
With the availability of large collections of unlabeled data in vision and language, contrastive learning of language and image representations has shown impressive results.
The contrastive learning methods CLIP and CLOOB have demonstrated that the learned representations are highly transferable to a large set of diverse tasks when trained on multi-modal data from two different domains.
In life sciences, similar large, multi-modal datasets comprising both cell-based microscopy images and chemical structures of molecules are available.
However, contrastive learning has not yet been used for this type of multi-modal data, although this would allow the design of cross-modal retrieval systems for bioimaging and chemical databases.
In this work, we present such a contrastive learning method, the retrieval systems, and the transferability of the learned representations. Our method, Contrastive Learning and leave-One-Out-boost for Molecule Encoders (CLOOME), is based on both CLOOB and CLIP and comprises an encoder for microscopy data, an encoder for chemical structures and a contrastive learning objective, which produce rich embeddings of bioimages and chemical structures.
On the benchmark dataset “Cell Painting”, we demonstrate that the embeddings can be used to form a retrieval system for bioimaging and chemical databases. We also show that CLOOME learns transferable representations by performing linear probing for activity prediction tasks. Furthermore, the image embeddings can identify new cell phenotypes, as we show in a zero-shot classification task.
"""
    )
def molecules_from_image(top_n, model_path):
    """Page: retrieve the molecules whose embeddings best match an uploaded image.

    The user uploads one TIF per Cell Painting channel. The images are turned
    into an 8-bit sample (saved in ``npzs`` as ``<imgname>.npz``) and embedded
    with the selected model. The embedding is compared against molecule
    embeddings that are either computed on the fly from an optional uploaded
    CSV (column ``SMILES``) or loaded from the precomputed Cell Painting features.

    Args:
        top_n: Number of molecules to show (string or int).
        model_path: Checkpoint of the selected CLOOME model.
    """
    ## TODO: Check if expander can be automatically collapsed
    exp = st.expander("Upload a microscopy image")
    with exp:
        channels = ['Mito', 'ERSyto', 'ERSytoBleed', 'Ph_golgi', 'Hoechst']
        imglst, filenames = [], []
        for c in channels:
            file_obj = st.file_uploader(f'Choose a TIF image for {c}:', ".tif")
            if file_obj is not None:
                imglst.append(file_obj)
                filenames.append(file_obj.name)
        if imglst:
            os.makedirs(npzs, exist_ok=True)
            # NOTE(review): process_sample pairs images with ``channels`` by
            # position, so if a channel is skipped, later uploads land in the
            # wrong slots. Consider requiring all five channels.
            sample_dict, _imgpath = process_sample(imglst, channels, filenames, npzs, imgname)
            st.image(display_cellpainting(sample_dict))

    uploaded_file = st.file_uploader("Choose a molecule file to retrieve from (optional)")
    if not imglst:
        return

    # --- Candidate molecule embeddings ---
    if uploaded_file is not None:
        molecule_df = pd.read_csv(uploaded_file)
        smiles = molecule_df["SMILES"].tolist()
        morgan = [morgan_from_smiles(s) for s in smiles]
        molnames = [f"M{i}" for i in range(len(morgan))]
        mol_index = create_index(datapath, molnames, "mol_index.csv")
        molpath = os.path.join(datapath, "mols.hdf")
        save_hdf(morgan, molnames, molpath)
        mol_imgs = draw_molecules(smiles)
        mol_features, mol_ids = main(mol_index, model_path, model_type, mol_path=molpath, image_resolution=image_resolution)
        predefined_features = False
    else:
        # NOTE(review): these precomputed features must come from the same
        # checkpoint as ``model_path`` for the similarities to be meaningful.
        mol_index = pd.read_csv(mol_index_file)
        mol_features_torch = torch.load(molecule_features, map_location=device)
        mol_features = mol_features_torch["mol_features"]
        mol_ids = mol_features_torch["mol_ids"]
        predefined_features = True

    # --- Query image embedding ---
    img_index = create_index(datapath, imgname, "img_index.csv")
    # Use the user-selected checkpoint. This previously hard-coded MODEL_PATH,
    # so image and molecule embeddings could come from different models.
    img_features, _img_ids = main(img_index, model_path, model_type, img_path=npzs, image_resolution=image_resolution)

    # --- Ranking ---
    # topk raises if k exceeds the number of candidates (e.g. a small CSV).
    top_n = min(int(top_n), mol_features.shape[0])
    logits = img_features @ mol_features.T
    mol_probs = (30.0 * logits).softmax(dim=-1)
    top_probs, top_labels = mol_probs.cpu().topk(top_n, dim=-1)
    # Only a single query image is supported: flatten the (1, k) results.
    top_probs = torch.flatten(top_probs)
    top_labels = torch.flatten(top_labels)

    if predefined_features:
        mol_index.set_index(["SAMPLE_KEY"], inplace=True)
        top_ids = [mol_ids[i] for i in top_labels]
        smiles = mol_index.loc[top_ids]["SMILES"].tolist()
        # Drawn in rank order, so column i shows mol_imgs[i].
        mol_imgs = draw_molecules(smiles)

    # --- Display ---
    with st.container():
        if uploaded_file is not None:
            st.write("Retrieved molecules")
        else:
            st.write("Retrieved molecules from the Cell Painting database")
        columns = st.columns(len(top_probs))
        for rank, col in enumerate(columns):
            # Uploaded molecules were drawn in file order, so index by label.
            image_id = rank if predefined_features else top_labels[rank]
            col.image(mol_imgs[image_id], width=140, caption=rank + 1)
def images_from_molecule(top_n, model_path):
    """Page: retrieve the Cell Painting images whose embeddings best match a query molecule.

    The query SMILES is fingerprinted and embedded with the selected model,
    then compared against the precomputed image embeddings. The top matches
    are shown in a five-column grid.

    Args:
        top_n: Number of images to show (string or int).
        model_path: Checkpoint of the selected CLOOME model.
    """
    smiles = st.text_input("Enter a query molecule in SMILES format", value="CC(=O)OC1=CC=CC=C1C(=O)O", placeholder="CC(=O)OC1=CC=CC=C1C(=O)O")
    # Without a (valid) query there is nothing to embed; previously an empty
    # or unparsable input crashed further down.
    if not smiles:
        return
    if Chem.MolFromSmiles(smiles) is None:
        st.error(f"Could not parse SMILES: {smiles}")
        return

    # --- Query molecule embedding ---
    smiles = [smiles]
    morgan = [morgan_from_smiles(s) for s in smiles]
    molnames = [f"M{i}" for i in range(len(morgan))]
    mol_index = create_index(datapath, molnames, "mol_index.csv")
    molpath = os.path.join(datapath, "mols.hdf")
    save_hdf(morgan, molnames, molpath)
    mol_imgs = draw_molecules(smiles)
    mol_features, _mol_ids = main(mol_index, model_path, model_type, mol_path=molpath, image_resolution=image_resolution)

    # Centre the query structure by flanking it with two empty columns.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.write("")
    with col2:
        st.image(mol_imgs, width=140)
    with col3:
        st.write("")
    st.markdown('##')

    # --- Ranking against precomputed image embeddings ---
    img_features_torch = torch.load(image_features, map_location=device)
    img_features = img_features_torch["img_features"]
    img_ids = img_features_torch["img_ids"]
    # topk raises if k exceeds the number of candidate images.
    top_n = min(int(top_n), img_features.shape[0])
    logits = mol_features @ img_features.T
    img_probs = (30.0 * logits).softmax(dim=-1)
    top_probs, top_labels = img_probs.cpu().topk(top_n, dim=-1)
    top_probs = torch.flatten(top_probs)
    top_labels = torch.flatten(top_labels)
    top_ids = [img_ids[i] for i in top_labels]

    # --- Display ---
    st.write("Retrieved images from the Cell Painting database")
    n_columns = 5
    # NpzFile keeps the archive open; the context manager closes it.
    # allow_pickle is needed for this local, trusted data file only.
    with np.load(images_arr, allow_pickle=True) as images_dict, st.container():
        columns = st.columns(n_columns)
        for col_idx, col in enumerate(columns):
            # Round-robin: rank r goes to column r % n_columns.
            for rank in range(col_idx, len(top_ids), n_columns):
                ## TODO: generalize and functionalize
                im = reshape_image(images_dict[f"{top_ids[rank]}.npz"])
                score = float(top_probs[rank]) * 100
                col.image(im, caption=f"Top {rank + 1}. Score: {round(score, 2)}", width=200)
# ---------------------------------------------------------------------------
# Sidebar controls and page dispatch
# ---------------------------------------------------------------------------
page_names_to_funcs = {
    "Microscopy images from a molecule": images_from_molecule,
    "Molecules from a microscopy image": molecules_from_image,
    "About CLOOME": main_page,
}
selected_page = st.sidebar.selectbox("What would you like to retrieve?", page_names_to_funcs.keys())
st.sidebar.markdown('')
n_objects = st.sidebar.selectbox(
    "How many objects would you like to retrieve?",
    ("5", "10", "20"))

# Display name -> checkpoint file inside ``datapath``. Insertion order sets the
# order of the options in the selectbox below.
model_dict = {
    "CLOOME (default)": "cloome_default.pt",
    "CLOOME (CLIP imgres 320)": "cloome_cli_imres320.pt",
    "CLOOME (fullres)": "cloome_fullres.pt",
    "CLOOME (imgres 320)": "cloome_imres320.pt",
}
selected_model = st.sidebar.selectbox(
    "Select a CLOOME model to load",
    tuple(model_dict))
model_file = model_dict[selected_model]
model_path = os.path.join(datapath, model_file)

# Module-level global read by the page functions. The 320-px checkpoints are
# the files ending in "320.pt". The old test, endswith("320).pt"), never
# matched any file name, so every model was run at 520.
image_resolution = 320 if model_path.endswith("320.pt") else 520

page_names_to_funcs[selected_page](n_objects, model_path)