"""Streamlit front end for CLOOME cross-modal retrieval.

Pages (selected in the sidebar):

* "Microscopy images from a molecule": embed a query SMILES and retrieve the
  closest Cell Painting images from precomputed image features.
* "Molecules from a microscopy image": embed an uploaded five-channel image
  and retrieve the closest molecules, either from an uploaded CSV of SMILES
  or from precomputed Cell Painting molecule features.
* "About CLOOME": project description.
"""
import glob
import io
import json
import os
import sys
import zipfile
from itertools import chain

import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs
from rdkit.Chem import Draw
from torch.utils.data import DataLoader
from tqdm import tqdm

# The project packages below live in ./src (relative to the working directory).
sys.path.insert(0, os.path.abspath("src/"))

from clip.clip import _transform  # noqa: E402
from training.datasets import CellPainting  # noqa: E402
from clip.model import convert_weights, CLIPGeneral  # noqa: E402

st.set_page_config(layout="wide")

basepath = os.path.dirname(__file__)
datapath = os.path.join(basepath, "data")

CLOOME_PATH = "/home/ana/gitrepos/hti-cloob"  # NOTE(review): not used in this file
MODEL_PATH = os.path.join(datapath, "epoch_55.pt")
npzs = os.path.join(datapath, "npzs")
molecule_features = os.path.join(datapath, "all_molecule_cellpainting_features.pkl")
mol_index_file = os.path.join(datapath, "cellpainting-unique-molecule.csv")
image_features = os.path.join(datapath, "subset_image_cellpainting_features.pkl")
images_arr = os.path.join(datapath, "subset_npzs_dict_.npz")
img_index_file = os.path.join(datapath, "cellpainting-all-imgpermol.csv")
imgname = "I1"

device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "RN50"
image_resolution = 520


######### CLOOME FUNCTIONS #########

def convert_models_to_fp32(model):
    """Cast every parameter of *model* (and its gradient, if any) to float32 in place."""
    for p in model.parameters():
        p.data = p.data.float()
        # `if p.grad:` would raise for multi-element tensors; test for presence instead.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def load(model_path, device, model, image_resolution):
    """Build a ``CLIPGeneral`` model from its JSON config and load a checkpoint into it.

    Args:
        model_path: checkpoint written by ``torch.save``; weights are taken from
            its ``"state_dict"`` entry when present, otherwise the checkpoint
            itself is treated as the state dict.
        device: device the checkpoint is mapped to and the model is moved to.
        model: model *name* (e.g. ``"RN50"``). The config is read from
            ``<name with '/' replaced by '-'>.json`` relative to the working directory.
        image_resolution: unused; kept for interface compatibility.

    Returns:
        The loaded model in eval mode on *device*.

    Raises:
        FileNotFoundError: if the JSON model config does not exist.
    """
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint.get("state_dict", checkpoint)

    model_config_file = f"{model.replace('/', '-')}.json"
    print('Loading model from', model_config_file)
    # An explicit raise (unlike `assert`) survives `python -O`.
    if not os.path.exists(model_config_file):
        raise FileNotFoundError(f"Model config not found: {model_config_file}")
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)

    clip_model = CLIPGeneral(**model_info)
    convert_weights(clip_model)
    convert_models_to_fp32(clip_model)
    if str(device) == "cpu":
        clip_model.float()
    print(device)

    # Checkpoints saved from a wrapped (DataParallel-style) model prefix every key
    # with "module."; strip it only where present so unwrapped checkpoints load too.
    prefix = "module."
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    clip_model.load_state_dict(new_state_dict)
    clip_model.to(device)
    clip_model.eval()
    return clip_model


def get_features(dataset, model, device):
    """Encode every item of *dataset* with *model* and L2-normalise the embeddings.

    Batches are dispatched on their type:
      * ``dict``: image batch with ``"input"`` (tensor) and ``"ID"`` entries;
      * ``torch.Tensor``: molecule batch;
      * anything else: an ``(images, molecules)`` pair.

    Returns:
        ``(image_features, text_features, ids)`` when both modalities were seen,
        ``(image_features, ids)`` for images only, ``(text_features, ids)`` for
        molecules only (``ids`` is then empty, since only image batches carry
        IDs), or ``None`` when the dataset yielded no batches.
    """
    all_image_features = []
    all_text_features = []
    all_ids = []
    print(f"get_features {device}")
    print(len(dataset))
    with torch.no_grad():
        for batch in tqdm(DataLoader(dataset, num_workers=1, batch_size=64)):
            if isinstance(batch, dict):
                imgs, mols = batch, None
            elif isinstance(batch, torch.Tensor):
                imgs, mols = None, batch
            else:
                imgs, mols = batch

            if mols is not None:
                text_features = model.encode_text(mols.to(device))
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                all_text_features.append(text_features)

            if imgs is not None:
                img_features = model.encode_image(imgs["input"].to(device))
                img_features = img_features / img_features.norm(dim=-1, keepdim=True)
                all_image_features.append(img_features)
                all_ids.append(imgs["ID"])

    all_ids = list(chain.from_iterable(all_ids))
    # Decide from what was actually collected rather than from the last batch;
    # this also avoids a NameError when the dataset is empty.
    if all_image_features and all_text_features:
        return torch.cat(all_image_features), torch.cat(all_text_features), all_ids
    if all_image_features:
        return torch.cat(all_image_features), all_ids
    if all_text_features:
        return torch.cat(all_text_features), all_ids
    return None


def read_array(file):
    """Load a ``torch.save``-d dict and return its ``(mol_features, mol_ids)`` entries."""
    t = torch.load(file)
    return t["mol_features"], t["mol_ids"]


def main(df, model_path, model, img_path=None, mol_path=None, image_resolution=None):
    """Load a CLOOME model and embed the samples listed in *df*.

    Args:
        df: sample index passed to ``CellPainting`` (the callers in this file
            pass the path of a CSV written by ``create_index``).
        model_path: checkpoint path, see ``load``.
        model: model name, see ``load``.
        img_path: directory of image ``.npz`` files, or ``None``.
        mol_path: molecule fingerprint file, or ``None``.
        image_resolution: side length used by the validation transform.

    Returns:
        The tuple produced by ``get_features``: ``(img, text, ids)``,
        ``(img, ids)`` or ``(text, ids)``.

    Raises:
        ValueError: if the dataset produced no samples to embed.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(torch.cuda.device_count())
    clip_model = load(model_path, device, model, image_resolution)
    preprocess_val = _transform(image_resolution, image_resolution, is_train=False,
                                normalize="dataset", preprocess="downsize")

    val = CellPainting(df, img_path, mol_path, transforms=preprocess_val)

    print("getting_features")
    result = get_features(val, clip_model, device)
    if result is None:
        raise ValueError("No samples found to embed.")
    return result


def img_to_numpy(file):
    """Read an image (path or file-like object) into a NumPy array."""
    with Image.open(file) as img:
        return np.array(img)


def illumination_threshold(arr, perc=0.0028):
    """Return threshold value to not display a percentage of highest pixels.

    Args:
        arr: image array; the pixel count is taken from its first two dimensions.
        perc: percentage (not fraction) of brightest pixels to clip;
            the default 0.0028 means 0.0028 %.

    Returns:
        The smallest value among the ``round(h * w * perc / 100)`` largest
        pixels. At least one pixel is always used.
    """
    h, w = arr.shape[0], arr.shape[1]
    n_pixels = int(np.around(h * w * (perc / 100)))
    flat = np.asarray(arr).ravel()
    # With n_pixels == 0, the slice [-0:] would select every pixel, returning the
    # image minimum (and causing a division by zero later). Clamp to [1, size].
    n_pixels = min(max(n_pixels, 1), flat.size)
    # After partitioning, index -n holds the n-th largest value and every later
    # position is >= it, so it is the minimum of the top-n pixels.
    return np.partition(flat, -n_pixels)[-n_pixels]


def sixteen_to_eight_bit(arr, display_max, display_min=0):
    """Linearly map ``[display_min, display_max]`` of *arr* onto uint8 ``[0, 255]``.

    Values at or below *display_min* become 0; values near or above
    *display_max* saturate at 255.
    """
    if display_max <= display_min:
        # Degenerate window (e.g. a constant image): avoid dividing by zero and
        # saturate every pixel above display_min.
        return np.where(arr > display_min, 255, 0).astype(np.uint8)
    threshold_image = (arr.astype(float) - display_min) * (arr > display_min)
    scaled_image = threshold_image * (256. / (display_max - display_min))
    scaled_image[scaled_image > 255] = 255
    return scaled_image.astype(np.uint8)


def process_image(arr):
    """Clip the brightest pixels of a raw image and rescale it to 8 bit."""
    threshold = illumination_threshold(arr)
    return sixteen_to_eight_bit(arr, threshold)


def process_sample(imglst, channels, filenames, outdir, outfile):
    """Stack per-channel images into one (520, 696, 5) uint8 sample and save it.

    ``imglst[i]`` is written to slot ``i`` and labelled ``channels[i]`` /
    ``filenames[i]``, so the three sequences must be aligned.

    Returns:
        ``(sample_dict, outpath)`` where ``sample_dict`` has keys ``"sample"``,
        ``"channels"`` (slot -> channel) and ``"filenames"`` (channel -> file name),
        and ``outpath`` is ``<outdir>/<outfile>.npz``.
    """
    sample = np.zeros((520, 696, 5), dtype=np.uint8)
    filenames_dict, channels_dict = {}, {}
    for i, (img, channel, fname) in enumerate(zip(imglst, channels, filenames)):
        print(channel)
        sample[:, :, i] = process_image(img_to_numpy(img))
        channels_dict[i] = channel
        filenames_dict[channel] = fname

    sample_dict = dict(sample=sample, channels=channels_dict, filenames=filenames_dict)
    outpath = os.path.join(outdir, outfile + ".npz")
    # NOTE(review): the .npz stores the plain `channels`/`filenames` lists, not the
    # dicts in sample_dict; presumably what the CellPainting loader expects - confirm.
    np.savez(outpath, sample=sample, channels=channels, filenames=filenames)
    return sample_dict, outpath


def display_cellpainting(sample):
    """Render a processed sample as an RGB preview image.

    Slots 0, 3 and 4 of ``sample["sample"]`` become red, green and blue. With
    the upload order used in ``molecules_from_image``, these are Mito, Ph_golgi
    and Hoechst.
    """
    arr = sample["sample"]
    rgb_arr = np.dstack((arr[:, :, 0], arr[:, :, 3], arr[:, :, 4]))
    return Image.fromarray(rgb_arr.astype("uint8")).convert("RGB")


def morgan_from_smiles(smiles, radius=3, nbits=1024, chiral=True):
    """Return the Morgan fingerprint of *smiles* as an int8 NumPy bit vector.

    Raises:
        ValueError: if RDKit cannot parse *smiles*.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nbits,
                                               useChirality=chiral)
    arr = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr


def save_hdf(fps, index, outfile_hdf):
    """Write fingerprints as rows of a DataFrame to HDF5 (key ``"df"``) and return the path.

    Rows are indexed 0..n-1 and columns are ``"0"``..``"<nbits-1>"``.
    NOTE(review): *index* is currently unused; confirm the loader expects
    integer row keys rather than the molecule names.
    """
    ids = list(range(len(fps)))
    columns = [str(i) for i in range(fps[0].shape[0])]
    df = pd.DataFrame(fps, index=ids, columns=columns)
    df.to_hdf(outfile_hdf, key="df", mode="w")
    return outfile_hdf


def create_index(outdir, ids, filename):
    """Write a CSV with a ``SAMPLE_KEY`` column listing *ids* and return its path.

    A single string id is accepted as well as a list of ids.
    """
    filepath = os.path.join(outdir, filename)
    values = [ids] if isinstance(ids, str) else ids
    data = {"SAMPLE_KEY": values}
    print(data)
    pd.DataFrame(data).to_csv(filepath)
    return filepath


def draw_molecules(smiles_lst):
    """Return one PIL depiction per SMILES string."""
    return [Draw.MolToImage(Chem.MolFromSmiles(s)) for s in smiles_lst]


def reshape_image(arr):
    """Convert a channel-first array to an RGB PIL image using its first three channels."""
    # Stack only the channels that are shown; any extra channels are ignored
    # instead of being left as uninitialised memory.
    rgb = np.stack([arr[0], arr[1], arr[2]], axis=-1)
    return Image.fromarray(rgb.astype("uint8"))


# TODO: factor feature loading and similarity/top-k computation out of the page
# functions below.


##### STREAMLIT FUNCTIONS ######

st.title('CLOOME. Bioimage database retrieval from chemical structures (and vice versa)')


def main_page(top_n):
    """Page: project description (*top_n* is ignored)."""
    st.markdown(
        """
        Contrastive learning for self-supervised representation learning has brought a
        strong improvement to many application areas, such as computer vision and natural
        language processing. With the availability of large collections of unlabeled data
        in vision and language, contrastive learning of language and image representations
        has shown impressive results. The contrastive learning methods CLIP and CLOOB have
        demonstrated that the learned representations are highly transferable to a large
        set of diverse tasks when trained on multi-modal data from two different domains.
        In the life sciences, similar large, multi-modal datasets comprising both cell-based
        microscopy images and chemical structures of molecules are available. However,
        contrastive learning has not yet been used for this type of multi-modal data,
        although it would enable the design of cross-modal retrieval systems for
        bioimaging and chemical databases. In this work, we present such a contrastive
        learning method, the retrieval systems, and the transferability of the learned
        representations. Our method, Contrastive Learning and leave-One-Out-boost for
        Molecule Encoders (CLOOME), is based on both CLOOB and CLIP and comprises an
        encoder for microscopy data, an encoder for chemical structures and a contrastive
        learning objective, which produce rich embeddings of bioimages and chemical
        structures. On the benchmark dataset “Cell Painting”, we demonstrate that the
        embeddings can be used to form a retrieval system for bioimaging and chemical
        databases. We also show that CLOOME learns transferable representations by
        performing linear probing for activity prediction tasks. Furthermore, the image
        embeddings can identify new cell phenotypes, as we show in a zero-shot
        classification task.
        """
    )


def molecules_from_image(top_n):
    """Page: retrieve the molecules that best match an uploaded Cell Painting image.

    Args:
        top_n: number of molecules to show (the sidebar passes it as a string).
    """
    ## TODO: Check if expander can be automatically collapsed
    exp = st.expander("Upload a microscopy image")
    with exp:
        channels = ['Mito', 'ERSyto', 'ERSytoBleed', 'Ph_golgi', 'Hoechst']
        imglst, filenames, uploaded_channels = [], [], []
        for c in channels:
            file_obj = st.file_uploader(f'Choose a TIF image for {c}:', ".tif")
            if file_obj is not None:
                imglst.append(file_obj)
                filenames.append(file_obj.name)
                uploaded_channels.append(c)

    # process_sample assigns images to channel slots by position, so a partial
    # upload would put e.g. the Hoechst image into the Mito slot. Wait for all five.
    missing = [c for c in channels if c not in uploaded_channels]
    if imglst and missing:
        st.info("Upload an image for every channel to run retrieval. "
                f"Missing: {', '.join(missing)}")
    all_uploaded = not missing

    if all_uploaded:
        os.makedirs(npzs, exist_ok=True)
        sample_dict, _ = process_sample(imglst, channels, filenames, npzs, imgname)
        print(imglst)
        st.image(display_cellpainting(sample_dict))

    uploaded_file = st.file_uploader("Choose a molecule file to retrieve from (optional)")

    if not all_uploaded:
        return

    if uploaded_file is not None:
        molecule_df = pd.read_csv(uploaded_file)
        smiles = molecule_df["SMILES"].tolist()
        if not smiles:
            st.warning("The uploaded molecule file contains no SMILES.")
            return
        try:
            morgan = [morgan_from_smiles(s) for s in smiles]
        except ValueError as err:
            st.error(str(err))
            return
        molnames = [f"M{i}" for i in range(len(morgan))]
        mol_index = create_index(datapath, molnames, "mol_index.csv")
        molpath = os.path.join(datapath, "mols.hdf")
        save_hdf(morgan, molnames, molpath)
        mol_imgs = draw_molecules(smiles)
        mol_features, mol_ids = main(mol_index, MODEL_PATH, model_type,
                                     mol_path=molpath, image_resolution=image_resolution)
        predefined_features = False
    else:
        mol_index = pd.read_csv(mol_index_file)
        mol_features_torch = torch.load(molecule_features, map_location=device)
        mol_features = mol_features_torch["mol_features"]
        mol_ids = mol_features_torch["mol_ids"]
        print(len(mol_ids))
        predefined_features = True

    img_index = create_index(datapath, imgname, "img_index.csv")
    img_features, img_ids = main(img_index, MODEL_PATH, model_type,
                                 img_path=npzs, image_resolution=image_resolution)
    print(img_features.shape)
    print(mol_features.shape)

    # Softmax over molecules of the image-molecule similarities scaled by 30.
    logits = img_features @ mol_features.T
    mol_probs = (30.0 * logits).softmax(dim=-1)
    # Honour the sidebar choice, but never ask topk for more molecules than exist.
    k = min(int(top_n), mol_features.shape[0])
    top_probs, top_labels = mol_probs.cpu().topk(k, dim=-1)

    # Delete this if want to allow retrieval for multiple images
    top_probs = torch.flatten(top_probs)
    top_labels = torch.flatten(top_labels)
    print(top_probs.shape)
    print(top_labels.shape)

    if predefined_features:
        mol_index.set_index(["SAMPLE_KEY"], inplace=True)
        top_ids = [mol_ids[int(i)] for i in top_labels]
        smiles = mol_index.loc[top_ids]["SMILES"].tolist()
        mol_imgs = draw_molecules(smiles)

    with st.container():
        if uploaded_file is not None:
            st.write("Retrieved molecules")
        else:
            st.write("Retrieved molecules from the Cell Painting database")

        columns = st.columns(len(top_probs))
        for rank, col in enumerate(columns):
            # Predefined: mol_imgs holds only the hits, already in rank order.
            # Uploaded: mol_imgs holds every uploaded molecule, so index by label.
            image_id = rank if predefined_features else int(top_labels[rank])
            col.image(mol_imgs[image_id], width=140, caption=rank + 1)

    print(mol_probs.sum(dim=-1))
    print((top_probs, top_labels))


def images_from_molecule(top_n):
    """Page: retrieve the Cell Painting images that best match a query SMILES.

    Args:
        top_n: number of images to show (the sidebar passes it as a string).
    """
    smiles = st.text_input("Enter a query molecule in SMILES format",
                           value="CC(=O)OC1=CC=CC=C1C(=O)O",
                           placeholder="CC(=O)OC1=CC=CC=C1C(=O)O")
    if not smiles:
        return

    smiles = [smiles]
    try:
        morgan = [morgan_from_smiles(s) for s in smiles]
    except ValueError as err:
        st.error(str(err))
        return
    molnames = [f"M{i}" for i in range(len(morgan))]
    mol_index = create_index(datapath, molnames, "mol_index.csv")
    molpath = os.path.join(datapath, "mols.hdf")
    save_hdf(morgan, molnames, molpath)
    mol_imgs = draw_molecules(smiles)

    mol_features, _ = main(mol_index, MODEL_PATH, model_type,
                           mol_path=molpath, image_resolution=image_resolution)

    # Show the query structure in the middle of three columns.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.write("")
    with col2:
        st.image(mol_imgs, width=140)
    with col3:
        st.write("")
    st.markdown('##')

    img_features_torch = torch.load(image_features, map_location=device)
    img_features = img_features_torch["img_features"]
    img_ids = img_features_torch["img_ids"]

    # Softmax over images of the molecule-image similarities scaled by 30.
    logits = mol_features @ img_features.T
    img_probs = (30.0 * logits).softmax(dim=-1)
    k = min(int(top_n), img_features.shape[0])
    _, top_labels = img_probs.cpu().topk(k, dim=-1)
    top_labels = torch.flatten(top_labels)
    top_ids = [img_ids[int(i)] for i in top_labels]

    st.write("Retrieved images from the Cell Painting database")
    n_columns = 5
    # allow_pickle is only acceptable because images_arr is a local data file,
    # not user input. The context manager closes the NpzFile afterwards.
    with np.load(images_arr, allow_pickle=True) as images_dict:
        with st.container():
            columns = st.columns(n_columns)
            for col_idx, col in enumerate(columns):
                # Round-robin layout: rank n goes to column n % n_columns.
                for n in range(col_idx, len(top_ids), n_columns):
                    image = images_dict[f"{top_ids[n]}.npz"]
                    ## TODO: generalize and functionalize
                    col.image(reshape_image(image), caption=n + 1)


page_names_to_funcs = {
    "Microscopy images from a molecule": images_from_molecule,
    "Molecules from a microscopy image": molecules_from_image,
    "About CLOOME": main_page,
}

selected_page = st.sidebar.selectbox("What would you like to retrieve?",
                                     page_names_to_funcs.keys())
st.sidebar.markdown('')
n_objects = st.sidebar.selectbox(
    "How many objects would you like to retrieve?",
    ("5", "10", "20"))
page_names_to_funcs[selected_page](n_objects)