cloome / app.py
Sanchez fernandez
Include similarity score
79d82ee
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
import sys
import io
import os
import glob
import json
import zipfile
from tqdm import tqdm
from itertools import chain
import torch
from torch.utils.data import DataLoader
sys.path.insert(0, os.path.abspath("src/"))
from clip.clip import _transform
from training.datasets import CellPainting
from clip.model import convert_weights, CLIPGeneral
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs
# Page layout setup; placed before any other st.* call (some Streamlit
# versions require set_page_config to be the first Streamlit command).
st.set_page_config(layout="wide")
# All data and checkpoint files are resolved relative to this script's directory.
basepath = os.path.dirname(__file__)
datapath = os.path.join(basepath, "data")
# NOTE(review): absolute developer-machine path; not referenced anywhere else in
# this file — confirm whether it is still needed.
CLOOME_PATH = "/home/ana/gitrepos/hti-cloob"
# Path to a fixed checkpoint file (epoch_55.pt) inside the data directory.
MODEL_PATH = os.path.join(datapath, "epoch_55.pt")
# Directory where the uploaded microscopy sample is written as "<imgname>.npz"
# and from which it is read back for encoding.
npzs = os.path.join(datapath, "npzs")
# Precomputed molecule embeddings; loaded with torch.load and read through the
# "mol_features" / "mol_ids" keys.
molecule_features = os.path.join(datapath, "all_molecule_cellpainting_features.pkl")
# Molecule table with SAMPLE_KEY and SMILES columns, used to look up the SMILES
# of retrieved database molecules.
mol_index_file = os.path.join(datapath, "cellpainting-unique-molecule.csv")
# Precomputed image embeddings; read through the "img_features" / "img_ids" keys.
image_features = os.path.join(datapath, "subset_image_cellpainting_features.pkl")
# Archive of image arrays keyed "<img_id>.npz", used to display retrieved images.
images_arr = os.path.join(datapath, "subset_npzs_dict_.npz")
# Image index CSV with a SAMPLE_KEY column.
img_index_file = os.path.join(datapath, "cellpainting-all-imgpermol.csv")
# Base name of the uploaded-image sample file and its SAMPLE_KEY in the index.
imgname = "I1"
# "cuda" when a GPU is visible, else "cpu"; used as map_location for the
# precomputed feature files.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Model config name: load() reads "<model_type>.json" (here RN50.json) from the
# current working directory.
model_type = "RN50"
######### CLOOME FUNCTIONS #########
def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` (and its gradient, if present) to float32 in place.

    Args:
        model: a ``torch.nn.Module``; its parameters are modified in place.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # ``if p.grad:`` would call Tensor.__bool__, which raises for any
        # gradient with more than one element; test for presence explicitly.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
def load(model_path, device, model, image_resolution):
    """Build a CLOOME model from its JSON config and load checkpoint weights.

    Args:
        model_path: checkpoint file whose top-level dict holds the weights under
            the ``"state_dict"`` key.
        device: device the checkpoint is mapped to and the model is moved to.
        model: model type name (e.g. ``"RN50"``); the config is read from
            ``"<model with '/' replaced by '-'>.json"`` relative to the current
            working directory.
        image_resolution: unused here; accepted for call compatibility.

    Returns:
        The ``CLIPGeneral`` model with loaded weights, on ``device``, in eval mode.

    Raises:
        FileNotFoundError: if the model config JSON does not exist.
    """
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint["state_dict"]

    model_config_file = f"{model.replace('/', '-')}.json"
    print('Loading model from', model_config_file)
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if not os.path.exists(model_config_file):
        raise FileNotFoundError(f"Model config file not found: {model_config_file}")
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)

    model = CLIPGeneral(**model_info)
    convert_weights(model)
    convert_models_to_fp32(model)
    if str(device) == "cpu":
        model.float()
    print(device)

    # Keys are expected to carry a "module." prefix (presumably from a
    # DataParallel/DDP-wrapped training model). Strip it only where present so
    # that checkpoints saved without the wrapper keep their full key names.
    prefix = "module."
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()
    return model
def get_features(dataset, model, device, batch_size=64, num_workers=1):
    """Encode a dataset with the CLOOME image and/or molecule encoders.

    Each batch produced by the DataLoader is interpreted by its type:
      * ``dict``         -> image-only batch with ``"input"`` and ``"ID"`` entries;
      * ``torch.Tensor`` -> molecule-only batch;
      * anything else    -> an ``(images_dict, molecules)`` pair.
    Every feature vector is divided by its L2 norm along the last dimension.

    Args:
        dataset: dataset consumed through a ``torch.utils.data.DataLoader``.
        model: object exposing ``encode_image`` and ``encode_text``.
        device: device the inputs are moved to before encoding.
        batch_size: DataLoader batch size (default 64).
        num_workers: DataLoader worker processes (default 1).

    Returns:
        ``(image_features, text_features, ids)`` when both modalities were seen,
        ``(image_features, ids)`` for image-only data, or
        ``(text_features, ids)`` for molecule-only data (``ids`` is then empty,
        because molecule-only batches carry no IDs).

    Raises:
        ValueError: if the dataset yields no batches.
    """
    all_image_features = []
    all_text_features = []
    all_ids = []
    print(f"get_features {device}")
    print(len(dataset))
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size)
    with torch.no_grad():
        for batch in tqdm(loader):
            if isinstance(batch, dict):
                imgs, mols = batch, None
            elif isinstance(batch, torch.Tensor):
                imgs, mols = None, batch
            else:
                imgs, mols = batch

            if mols is not None:
                text_features = model.encode_text(mols.to(device))
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                all_text_features.append(text_features)

            if imgs is not None:
                img_features = model.encode_image(imgs["input"].to(device))
                img_features = img_features / img_features.norm(dim=-1, keepdim=True)
                all_image_features.append(img_features)
                all_ids.append(imgs["ID"])

    all_ids = list(chain.from_iterable(all_ids))
    # Decide the return shape from what was actually collected, not from the
    # last batch's loop variables (undefined when the dataset is empty).
    if all_image_features and all_text_features:
        return torch.cat(all_image_features), torch.cat(all_text_features), all_ids
    if all_image_features:
        return torch.cat(all_image_features), all_ids
    if all_text_features:
        return torch.cat(all_text_features), all_ids
    raise ValueError("get_features: the dataset produced no batches to encode")
def read_array(file):
    """Load a ``torch.save``-d dict and return its molecule features and IDs.

    Args:
        file: path or file-like object accepted by ``torch.load``.

    Returns:
        ``(payload["mol_features"], payload["mol_ids"])``.
    """
    payload = torch.load(file)
    return payload["mol_features"], payload["mol_ids"]
def main(df, model_path, model, img_path=None, mol_path=None, image_resolution=None):
    """Load a CLOOME checkpoint and encode the samples listed in an index.

    Args:
        df: first argument of ``CellPainting``; callers in this file pass the
            CSV path returned by ``create_index``.
        model_path: checkpoint file handed to ``load``.
        model: model type name ``load`` uses to locate its JSON config.
        img_path: directory holding image ``.npz`` samples, or None.
        mol_path: HDF file holding molecule fingerprints, or None.
        image_resolution: size used for the validation transform.

    Returns:
        The features from ``get_features`` as a tuple: either
        ``(img_features, text_features, ids)`` or ``(features, ids)``.
    """
    # Chosen locally; this shadows the module-level ``device`` on purpose.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(torch.cuda.device_count())
    encoder = load(model_path, run_device, model, image_resolution)

    val_transform = _transform(
        image_resolution,
        image_resolution,
        is_train=False,
        normalize="dataset",
        preprocess="downsize",
    )
    dataset = CellPainting(df, img_path, mol_path, transforms=val_transform)

    print("getting_features")
    features = get_features(dataset, encoder, run_device)
    if len(features) <= 2:
        feats, ids = features
        return feats, ids
    img_feats, text_feats, ids = features
    return img_feats, text_feats, ids
def img_to_numpy(file):
    """Read an image (path or file-like object) into a NumPy array.

    The image is opened in a ``with`` block. When Pillow opened the file itself
    from a path, this releases the handle once the pixels have been copied into
    the array; multi-frame formats such as TIFF otherwise keep it open.

    Args:
        file: anything accepted by ``PIL.Image.open``.

    Returns:
        ``np.ndarray`` with the image's pixel data.
    """
    with Image.open(file) as img:
        arr = np.array(img)
    return arr
def illumination_threshold(arr, perc=0.0028):
    """Return threshold value to not display a percentage of highest pixels.

    The threshold is the smallest of the ``n`` largest values in ``arr``, where
    ``n = round(arr.size * perc / 100)``, clamped to at least 1.
    ``process_image`` uses it as the display maximum, so values at or above it
    saturate and bright outliers do not dominate the 8-bit display.

    Args:
        arr: image array. For a 2-D image, ``arr.size`` is ``h * w``; for other
            shapes every element counts as a pixel.
        perc: percentage (not fraction) of pixels to clip; default 0.0028 %.

    Returns:
        The threshold as a NumPy scalar of ``arr``'s dtype.

    Raises:
        ValueError: if ``arr`` is empty.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        raise ValueError("illumination_threshold: empty array")
    n_pixels = int(np.around(arr.size * (perc / 100)))
    # Clip at least one pixel. With n_pixels == 0, a ``[-0:]`` slice would select
    # the whole image and yield its minimum instead of its maximum.
    n_pixels = min(max(n_pixels, 1), arr.size)
    # After partitioning, index -n_pixels holds the n-th largest value, which
    # is exactly the minimum of the n largest values.
    return np.partition(arr, -n_pixels, axis=None)[-n_pixels]
def sixteen_to_eight_bit(arr, display_max, display_min=0):
    """Linearly rescale an image (typically 16-bit) to uint8 for display.

    Values at or below ``display_min`` map to 0. Other values are shifted by
    ``display_min``, multiplied by ``256 / (display_max - display_min)`` and
    saturated at 255, so everything at or above ``display_max`` becomes 255.

    Args:
        arr: input image array.
        display_max: value mapped to the top of the 8-bit range.
        display_min: value mapped to 0 (default 0).

    Returns:
        ``uint8`` array with the same shape as ``arr``.
    """
    # Compute the span in float: with unsigned NumPy scalars (e.g. a uint16
    # threshold) a plain subtraction could wrap around.
    span = float(display_max) - float(display_min)
    above_min = arr > display_min
    if span <= 0:
        # Degenerate window (e.g. an all-dark image whose threshold is 0): the
        # scale factor would divide by zero, so treat it as a step function.
        return np.where(above_min, 255, 0).astype(np.uint8)
    threshold_image = (arr.astype(float) - display_min) * above_min
    scaled_image = threshold_image * (256. / span)
    scaled_image[scaled_image > 255] = 255
    return scaled_image.astype(np.uint8)
def process_image(arr):
    """Rescale a raw image to 8 bits, saturating its brightest outliers.

    The display maximum comes from ``illumination_threshold`` and the rescaling
    from ``sixteen_to_eight_bit``.
    """
    return sixteen_to_eight_bit(arr, display_max=illumination_threshold(arr))
def process_sample(imglst, channels, filenames, outdir, outfile):
    """Stack per-channel images into one 8-bit sample and save it as ``.npz``.

    Each image is read with ``img_to_numpy``, rescaled with ``process_image`` and
    written into its slot of a zero-initialised ``(520, 696, 5)`` uint8 array.

    NOTE(review): images, channel names and filenames are paired positionally
    (``zip``), so ``imglst`` must be in the same order as ``channels``. Channel
    names beyond ``len(imglst)`` are ignored, and their slots stay zero.

    Args:
        imglst: image sources (paths or file-like objects), one per channel.
        channels: channel names in slot order.
        filenames: original file names, parallel to ``imglst``.
        outdir: directory the ``.npz`` file is written to.
        outfile: file name without the ``.npz`` extension.

    Returns:
        ``(sample_dict, outpath)``. ``sample_dict`` holds ``sample`` (the stacked
        array), ``channels`` (slot -> name) and ``filenames`` (name -> file name).
    """
    sample = np.zeros((520, 696, 5), dtype=np.uint8)
    channels_dict = {}
    filenames_dict = {}
    for slot, entry in enumerate(zip(imglst, channels, filenames)):
        img, channel, fname = entry
        print(channel)
        sample[:, :, slot] = process_image(img_to_numpy(img))
        channels_dict[slot] = channel
        filenames_dict[channel] = fname

    outpath = os.path.join(outdir, outfile + ".npz")
    # The saved file stores the raw ``channels``/``filenames`` arguments, not
    # the dictionaries returned to the caller.
    np.savez(outpath, sample=sample, channels=channels, filenames=filenames)
    sample_dict = dict(sample=sample, channels=channels_dict, filenames=filenames_dict)
    return sample_dict, outpath
def display_cellpainting(sample):
    """Build an RGB preview image from channel slots 0, 3 and 4 of a sample.

    ``sample["sample"]`` is the ``(h, w, channels)`` stack built by
    ``process_sample``. Slot 0 becomes red, slot 3 green and slot 4 blue. In the
    channel order used by ``molecules_from_image`` these are Mito, Ph_golgi and
    Hoechst.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    stack = sample["sample"]
    planes = [stack[:, :, slot].astype(np.float32) for slot in (0, 3, 4)]
    composite = np.dstack(planes)
    return Image.fromarray(composite.astype("uint8")).convert("RGB")
def morgan_from_smiles(smiles, radius=3, nbits=1024, chiral=True):
    """Compute a Morgan fingerprint bit vector for a SMILES string.

    Args:
        smiles: SMILES string of the molecule.
        radius: Morgan radius (default 3).
        nbits: fingerprint length (default 1024).
        chiral: whether chirality is included in the fingerprint (default True).

    Returns:
        ``np.ndarray`` of length ``nbits`` holding the fingerprint bits.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail here with
        # a clear message instead of an opaque RDKit argument error below.
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nbits, useChirality=chiral)
    # ConvertToNumpyArray resizes the array to the fingerprint length.
    arr = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr
def save_hdf(fps, index, outfile_hdf):
    """Write fingerprint vectors to an HDF5 file as a DataFrame under key ``"df"``.

    Rows are indexed ``0 .. len(fps) - 1``. Columns are the string bit
    positions ``"0" .. str(n_bits - 1)``, where ``n_bits`` is the length of
    ``fps[0]``.

    Args:
        fps: sequence of equal-length 1-D fingerprint arrays.
        index: currently unused (the row index is positional).
        outfile_hdf: destination path; an existing file is overwritten.

    Returns:
        ``outfile_hdf``.
    """
    row_ids = list(range(len(fps)))
    n_bits = fps[0].shape[0]
    frame = pd.DataFrame(
        fps,
        index=row_ids,
        columns=[str(bit) for bit in range(n_bits)],
    )
    frame.to_hdf(outfile_hdf, key="df", mode="w")
    return outfile_hdf
def create_index(outdir, ids, filename):
    """Write a one-column CSV index (column ``SAMPLE_KEY``) and return its path.

    Args:
        outdir: directory the CSV is written into.
        ids: a single ID string or a sequence of IDs.
        filename: CSV file name inside ``outdir``.

    Returns:
        The full path of the written CSV.
    """
    # A lone string is wrapped in a list so it becomes a single row.
    values = [ids] if type(ids) is str else ids
    data = {"SAMPLE_KEY": values}
    print(data)
    filepath = os.path.join(outdir, filename)
    pd.DataFrame(data).to_csv(filepath)
    return filepath
def draw_molecules(smiles_lst):
    """Render each SMILES string as a 2-D depiction image.

    All strings are parsed first and drawn afterwards.

    Args:
        smiles_lst: iterable of SMILES strings.

    Returns:
        List of images, in the same order as ``smiles_lst``.
    """
    parsed = list(map(Chem.MolFromSmiles, smiles_lst))
    return [Chem.Draw.MolToImage(mol) for mol in parsed]
def reshape_image(arr):
    """Convert a channels-first array into a PIL image of its first three channels.

    Args:
        arr: array of shape ``(c, h, w)`` with ``c >= 3``. Channels 0, 1 and 2
            become the image's R, G and B planes. Values are cast to ``uint8``
            without rescaling.

    Returns:
        ``PIL.Image.Image`` built from the ``(h, w, 3)`` uint8 array.

    Raises:
        ValueError: if ``arr`` is not 3-D or has fewer than three channels.
    """
    arr = np.asarray(arr)
    if arr.ndim != 3 or arr.shape[0] < 3:
        raise ValueError(f"expected a (c, h, w) array with c >= 3, got shape {arr.shape}")
    # Use only the first three channels. Allocating (h, w, c) with np.empty and
    # filling three of them would leave extra channels uninitialised when c > 3.
    rgb = arr[:3].transpose(1, 2, 0).astype("uint8")
    return Image.fromarray(np.ascontiguousarray(rgb))
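# TODO: loading features and calculating similarities are still done inline in the page functions below; save_hdf and create_index above already cover saving Morgan fingerprints to HDF and creating the index.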
##### STREAMLIT FUNCTIONS ######
# Page title shown above every page.
st.title('CLOOME. Bioimage database retrieval from chemical structures (and vice versa)')
def main_page(top_n, model_path):
    """Render the "About CLOOME" page, a short description of the method.

    Args:
        top_n: unused; accepted so all pages share the call signature
            ``func(n_objects, model_path)`` used by the sidebar dispatch.
        model_path: unused, for the same reason.
    """
    st.markdown(
        """
Contrastive learning for self-supervised representation learning has brought a strong improvement to many application areas, such as computer vision and natural language processing.
With the availability of large collections of unlabeled data in vision and language, contrastive learning of language and image representations has shown impressive results.
The contrastive learning methods CLIP and CLOOB have demonstrated that the learned representations are highly transferable to a large set of diverse tasks when trained on multi-modal data from two different domains.
In life sciences, similar large, multi-modal datasets comprising both cell-based microscopy images and chemical structures of molecules are available.
However, contrastive learning has not yet been used for this type of multi-modal data, although this would allow the design of cross-modal retrieval systems for bioimaging and chemical databases.
In this work, we present such a contrastive learning method, the retrieval systems, and the transferability of the learned representations. Our method, Contrastive Learning and leave-One-Out-boost for Molecule Encoders (CLOOME), is based on both CLOOB and CLIP and comprises an encoder for microscopy data, an encoder for chemical structures and a contrastive learning objective, which produce rich embeddings of bioimages and chemical structures.
On the benchmark dataset “Cell Painting”, we demonstrate that the embeddings can be used to form a retrieval system for bioimaging and chemical databases. We also show that CLOOME learns transferable representations by performing linear probing for activity prediction tasks. Furthermore, the image embeddings can identify new cell phenotypes, as we show in a zero-shot classification task.
"""
    )
def molecules_from_image(top_n, model_path):
    """Page: retrieve the molecules whose embeddings best match an uploaded image.

    The user uploads one TIF per Cell Painting channel. Once every channel is
    present, the images are stacked into a sample (``process_sample``) and
    encoded with the checkpoint at ``model_path``. The image is compared against
    one of two molecule sets:
      * an uploaded CSV (``SMILES`` column), encoded with the same checkpoint;
      * otherwise, the precomputed Cell Painting molecule features.

    Args:
        top_n: number of molecules to show (string from the sidebar select box).
        model_path: checkpoint used for every on-the-fly encoding on this page.
    """
    ## TODO: Check if expander can be automatically collapsed
    exp = st.expander("Upload a microscopy image")
    with exp:
        channels = ['Mito', 'ERSyto', 'ERSytoBleed', 'Ph_golgi', 'Hoechst']
        imglst, filenames = [], []
        for c in channels:
            file_obj = st.file_uploader(f'Choose a TIF image for {c}:', ".tif")
            if file_obj is not None:
                imglst.append(file_obj)
                filenames.append(file_obj.name)
        # process_sample pairs images with channel names positionally, so a
        # partial upload (e.g. skipping ERSyto) would put images into the wrong
        # channel slots. Build the sample only once every channel is present.
        all_channels_uploaded = len(imglst) == len(channels)
        if imglst and not all_channels_uploaded:
            st.info("Upload an image for every channel to run the retrieval.")
        if all_channels_uploaded:
            os.makedirs(npzs, exist_ok=True)
            sample_dict, _ = process_sample(imglst, channels, filenames, npzs, imgname)
            st.image(display_cellpainting(sample_dict))
        uploaded_file = st.file_uploader("Choose a molecule file to retrieve from (optional)")

    if not all_channels_uploaded:
        return

    if uploaded_file is not None:
        molecule_df = pd.read_csv(uploaded_file)
        if "SMILES" not in molecule_df.columns:
            st.error("The molecule file must contain a 'SMILES' column.")
            return
        smiles = molecule_df["SMILES"].astype(str).tolist()
        if not smiles:
            st.error("The molecule file contains no molecules.")
            return
        invalid = [s for s in smiles if Chem.MolFromSmiles(s) is None]
        if invalid:
            st.error(f"Could not parse {len(invalid)} SMILES string(s), e.g. {invalid[0]!r}.")
            return
        morgan = [morgan_from_smiles(s) for s in smiles]
        molnames = [f"M{i}" for i in range(len(morgan))]
        mol_index = create_index(datapath, molnames, "mol_index.csv")
        molpath = os.path.join(datapath, "mols.hdf")
        save_hdf(morgan, molnames, molpath)
        # Drawn in upload order; mapped to ranks through top_labels below.
        mol_imgs = draw_molecules(smiles)
        mol_features, _ = main(mol_index, model_path, model_type, mol_path=molpath, image_resolution=image_resolution)
        predefined_features = False
    else:
        mol_index = pd.read_csv(mol_index_file)
        mol_features_torch = torch.load(molecule_features, map_location=device)
        mol_features = mol_features_torch["mol_features"]
        mol_ids = mol_features_torch["mol_ids"]
        predefined_features = True

    img_index = create_index(datapath, imgname, "img_index.csv")
    # Encode the image with the selected checkpoint, the same one used for the
    # uploaded molecules above, so both live in one embedding space.
    img_features, _ = main(img_index, model_path, model_type, img_path=npzs, image_resolution=image_resolution)

    # Similarity = dot product of the embeddings, scaled by 30 before the softmax.
    logits = img_features @ mol_features.T
    mol_probs = (30.0 * logits).softmax(dim=-1)
    # topk raises when k exceeds the number of candidates (e.g. a small upload).
    k = min(int(top_n), mol_probs.shape[-1])
    top_probs, top_labels = mol_probs.cpu().topk(k, dim=-1)
    # Delete this if want to allow retrieval for multiple images
    top_probs = torch.flatten(top_probs)
    top_labels = torch.flatten(top_labels)

    if predefined_features:
        mol_index.set_index(["SAMPLE_KEY"], inplace=True)
        top_ids = [mol_ids[i] for i in top_labels]
        smiles = mol_index.loc[top_ids]["SMILES"].tolist()
        # Drawn directly in rank order.
        mol_imgs = draw_molecules(smiles)

    with st.container():
        if uploaded_file is not None:
            st.write("Retrieved molecules")
        else:
            st.write("Retrieved molecules from the Cell Painting database")
        columns = st.columns(len(top_probs))
        for rank, col in enumerate(columns):
            image_id = rank if predefined_features else int(top_labels[rank])
            score = float(top_probs[rank]) * 100
            col.image(mol_imgs[image_id], width=140, caption=f"Top {rank + 1}. Score: {round(score, 2)}")
def images_from_molecule(top_n, model_path):
    """Page: retrieve Cell Painting images whose embeddings best match a query molecule.

    The query SMILES is converted to a Morgan fingerprint and encoded with the
    checkpoint at ``model_path``. It is then compared against the precomputed
    image features, and the best matches are shown in a 5-column grid together
    with their scores.

    Args:
        top_n: number of images to show (string from the sidebar select box).
        model_path: checkpoint used to encode the query molecule.
    """
    smiles = st.text_input("Enter a query molecule in SMILES format", value="CC(=O)OC1=CC=CC=C1C(=O)O", placeholder="CC(=O)OC1=CC=CC=C1C(=O)O")
    if not smiles:
        return
    # User input: make sure it parses before fingerprinting and drawing it.
    if Chem.MolFromSmiles(smiles) is None:
        st.error(f"Could not parse the SMILES string {smiles!r}.")
        return

    smiles = [smiles]
    morgan = [morgan_from_smiles(s) for s in smiles]
    molnames = [f"M{i}" for i in range(len(morgan))]
    mol_index = create_index(datapath, molnames, "mol_index.csv")
    molpath = os.path.join(datapath, "mols.hdf")
    save_hdf(morgan, molnames, molpath)
    mol_imgs = draw_molecules(smiles)
    mol_features, _ = main(mol_index, model_path, model_type, mol_path=molpath, image_resolution=image_resolution)

    # Show the query molecule in the middle of three columns.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.write("")
    with col2:
        st.image(mol_imgs, width=140)
    with col3:
        st.write("")
    st.markdown('##')

    img_features_torch = torch.load(image_features, map_location=device)
    img_features = img_features_torch["img_features"]
    img_ids = img_features_torch["img_ids"]
    # Similarity = dot product of the embeddings, scaled by 30 before the softmax.
    logits = mol_features @ img_features.T
    img_probs = (30.0 * logits).softmax(dim=-1)
    # topk raises when k exceeds the number of candidate images.
    k = min(int(top_n), img_probs.shape[-1])
    top_probs, top_labels = img_probs.cpu().topk(k, dim=-1)
    top_probs = torch.flatten(top_probs)
    top_labels = torch.flatten(top_labels)
    top_ids = [img_ids[i] for i in top_labels]

    st.write("Retrieved images from the Cell Painting database")
    n_columns = 5
    # The archive stores one array per image under the key "<id>.npz".
    # allow_pickle=True is acceptable only because this is a trusted, bundled
    # data file. The ``with`` closes the NpzFile once the grid is drawn.
    with np.load(images_arr, allow_pickle=True) as images_dict, st.container():
        columns = st.columns(n_columns)
        for col_idx, col in enumerate(columns):
            # Ranks are dealt round-robin across the columns (row-major grid).
            for rank in range(col_idx, len(top_ids), n_columns):
                image = images_dict[f"{top_ids[rank]}.npz"]
                ## TODO: generalize and functionalize
                im = reshape_image(image)
                score = float(top_probs[rank]) * 100
                col.image(im, caption=f"Top {rank + 1}. Score: {round(score, 2)}", width=200)
# Sidebar navigation: every page function is called as func(n_objects, model_path).
page_names_to_funcs = {
    "Microscopy images from a molecule": images_from_molecule,
    "Molecules from a microscopy image": molecules_from_image,
    "About CLOOME": main_page,
}
selected_page = st.sidebar.selectbox("What would you like to retrieve?", page_names_to_funcs.keys())
st.sidebar.markdown('')
n_objects = st.sidebar.selectbox(
    "How many objects would you like to retrieve?",
    ("5", "10", "20"))
selected_model = st.sidebar.selectbox(
    "Select a CLOOME model to load",
    ("CLOOME (default)", "CLOOME (CLIP imgres 320)", "CLOOME (fullres)", "CLOOME (imgres 320)"))
model_dict = {
    "CLOOME (default)" : "cloome_default.pt",
    "CLOOME (CLIP imgres 320)" : "cloome_cli_imres320.pt",
    "CLOOME (fullres)" : "cloome_fullres.pt",
    "CLOOME (imgres 320)" : "cloome_imres320.pt"
}
model_file = model_dict[selected_model]
model_path = os.path.join(datapath, model_file)
# The "imgres 320" checkpoints are the ones whose file names end in "320.pt";
# all other checkpoints use 520. The page functions read this module-level
# ``image_resolution`` when encoding.
if model_path.endswith("320.pt"):
    image_resolution = 320
else:
    image_resolution = 520
page_names_to_funcs[selected_page](n_objects, model_path)