# cloome / app.py
# Author: Ana Sanchez -- commit 724c6a9 ("Add src"), 16.1 kB
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
import sys
import io
import os
import glob
import json
import zipfile
from tqdm import tqdm
from itertools import chain
import torch
from torch.utils.data import DataLoader
sys.path.insert(0, os.path.abspath("src/"))
from clip.clip import _transform
from training.datasets import CellPainting
from clip.model import convert_weights, CLIPGeneral
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs
# --- Locations of model weights and generated data (anchored at this file) ---
basepath = os.path.dirname(__file__)
MODEL_PATH = os.path.join(basepath, "epoch_55.pt")
# NOTE(review): developer-local path; not referenced elsewhere in this file.
CLOOME_PATH = "/home/ana/gitrepos/hti-cloob"
npzs = os.path.join(basepath, "npzs")  # uploaded images are written here as .npz
imgname = "I1"  # file stem / SAMPLE_KEY used for the uploaded image

# --- Precomputed artefacts (relative paths, resolved against the working dir) ---
molecule_features = "all_molecule_cellpainting_features.pkl"
image_features = "subset_image_cellpainting_features.pkl"
images_arr = "subset_npzs_dict_200.npz"

# --- Model settings ---
model_type = "RN50"
image_resolution = 520
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
######### CLOOME FUNCTIONS #########
def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` (and its gradient, if any) to float32 in place.

    Args:
        model: a ``torch.nn.Module``.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # ``if p.grad:`` would call Tensor.__bool__, which raises for any
        # gradient with more than one element; test for presence instead.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
def load(model_path, device, model, image_resolution):
    """Build a ``CLIPGeneral`` model and load checkpoint weights into it.

    Args:
        model_path: path to a checkpoint containing a ``"state_dict"`` entry.
        device: device the model is moved to (e.g. ``"cpu"`` / ``"cuda"``).
        model: model-type name (e.g. ``"RN50"``); the architecture config is
            read from ``<model>.json`` (with ``/`` replaced by ``-``), resolved
            relative to the current working directory.
        image_resolution: unused; kept for call compatibility.

    Returns:
        The loaded model, in eval mode, on ``device``.

    Raises:
        FileNotFoundError: if the model config JSON does not exist.
    """
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    model_config_file = f"{model.replace('/', '-')}.json"
    print('Loading model from', model_config_file)
    # Explicit check instead of ``assert`` (asserts are stripped under -O).
    if not os.path.exists(model_config_file):
        raise FileNotFoundError(f"Model config not found: {model_config_file}")
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)
    clip_model = CLIPGeneral(**model_info)
    # NOTE(review): convert_weights presumably casts to fp16 (as in CLIP); the
    # next call casts everything back to fp32 -- confirm this is intended.
    convert_weights(clip_model)
    convert_models_to_fp32(clip_model)
    if str(device) == "cpu":
        clip_model.float()
    print(device)
    # Checkpoints saved through DataParallel/DDP prefix keys with "module.";
    # strip it only where present so unwrapped checkpoints load unchanged.
    prefix = "module."
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    clip_model.load_state_dict(new_state_dict)
    clip_model.to(device)
    clip_model.eval()
    return clip_model
def get_features(dataset, model, device, batch_size=64, num_workers=1):
    """Encode every sample of ``dataset`` with the image and/or molecule encoder.

    Collated batches are dispatched by type: a ``dict`` is treated as an image
    batch (keys ``"input"`` and ``"ID"``), a bare ``torch.Tensor`` as a molecule
    batch, and anything else is unpacked as ``(images, molecules)``. Every
    embedding is L2-normalised along its last dimension.

    Args:
        dataset: dataset whose collated batches take one of the forms above.
        model: model exposing ``encode_image`` / ``encode_text``.
        device: device inputs are moved to before encoding.
        batch_size: DataLoader batch size (default 64).
        num_workers: DataLoader worker processes (default 1).

    Returns:
        ``(image_features, molecule_features, ids)`` when both modalities occur,
        ``(image_features, ids)`` for images only, ``(molecule_features, ids)``
        for molecules only, or ``None`` if the dataset yields no batches.
        ``ids`` are collected from image batches only (empty for molecule-only
        datasets).
    """
    image_chunks, text_chunks, id_chunks = [], [], []
    print(f"get_features {device}")
    print(len(dataset))
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size)
    with torch.no_grad():
        for batch in tqdm(loader):
            if isinstance(batch, dict):
                imgs, mols = batch, None
            elif isinstance(batch, torch.Tensor):
                imgs, mols = None, batch
            else:
                imgs, mols = batch
            if mols is not None:
                text_features = model.encode_text(mols.to(device))
                text_chunks.append(text_features / text_features.norm(dim=-1, keepdim=True))
            if imgs is not None:
                img_features = model.encode_image(imgs["input"].to(device))
                image_chunks.append(img_features / img_features.norm(dim=-1, keepdim=True))
                id_chunks.append(imgs["ID"])
    all_ids = list(chain.from_iterable(id_chunks))
    # Decide the return shape from what was actually collected, not from the
    # loop variables of the last batch (undefined for an empty dataset).
    if image_chunks and text_chunks:
        return torch.cat(image_chunks), torch.cat(text_chunks), all_ids
    if image_chunks:
        return torch.cat(image_chunks), all_ids
    if text_chunks:
        return torch.cat(text_chunks), all_ids
    return None
def read_array(file):
    """Load precomputed molecule embeddings saved with ``torch.save``.

    Returns:
        ``(mol_features, mol_ids)``, taken from the ``"mol_features"`` and
        ``"mol_ids"`` entries of the loaded mapping.
    """
    payload = torch.load(file)
    return payload["mol_features"], payload["mol_ids"]
def main(df, model_path, model, img_path=None, mol_path=None, image_resolution=None):
    """Load a CLOOME checkpoint and embed the samples listed in ``df``.

    Args:
        df: index passed to ``CellPainting`` (the callers pass a CSV path).
        model_path: checkpoint path, forwarded to ``load``.
        model: model-type name, forwarded to ``load``.
        img_path: directory of image ``.npz`` files, or ``None``.
        mol_path: molecule fingerprint file, or ``None``.
        image_resolution: side length used by the validation transform.

    Returns:
        ``(image_features, molecule_features, ids)`` when ``get_features``
        yields three values, otherwise ``(features, ids)``.
    """
    run_device = "cpu"
    if torch.cuda.is_available():
        run_device = "cuda"
    print(torch.cuda.device_count())
    clip_model = load(model_path, run_device, model, image_resolution)
    transform = _transform(
        image_resolution,
        image_resolution,
        is_train=False,
        normalize="dataset",
        preprocess="downsize",
    )
    dataset = CellPainting(df, img_path, mol_path, transforms=transform)
    print("getting_features")
    features = get_features(dataset, clip_model, run_device)
    if len(features) > 2:
        first, second, ids = features
        return first, second, ids
    first, ids = features
    return first, ids
def img_to_numpy(file):
    """Read an image (path or file-like object) into a NumPy array.

    The PIL image is closed once the pixels are copied, so no file handle is
    left open until garbage collection.
    """
    with Image.open(file) as img:
        # np.array forces PIL to load the pixel data before the file closes.
        arr = np.array(img)
    return arr
def illumination_threshold(arr, perc=0.0028):
    """Return threshold value to not display a percentage of highest pixels.

    The threshold is the value of the ``n``-th brightest element of ``arr``,
    where ``n`` is ``perc`` percent of the element count, rounded and clamped to
    ``[1, arr.size]``. For a 2-D image this is the same count as ``h * w``.

    Args:
        arr: non-empty image array.
        perc: percentage (not fraction) of brightest pixels to exclude.

    Returns:
        A scalar of ``arr``'s dtype.

    Raises:
        ValueError: if ``arr`` is empty.
    """
    if arr.size == 0:
        raise ValueError("illumination_threshold requires a non-empty array")
    fraction = perc / 100
    n_pixels = int(np.around(arr.size * fraction))
    # With n_pixels == 0 the slice ``[-0:]`` used to select the whole array, so
    # the threshold collapsed to the image minimum; keep at least one pixel.
    n_pixels = min(max(n_pixels, 1), arr.size)
    # The minimum of the top-n values is the n-th largest value.
    return np.partition(arr, -n_pixels, axis=None)[-n_pixels]
# process_image is defined after sixteen_to_eight_bit, the helper it relies on.
def sixteen_to_eight_bit(arr, display_max, display_min=0):
    """Linearly rescale an image into uint8 for display.

    Values at or below ``display_min`` map to 0. Values above it are scaled by
    ``256 / (display_max - display_min)``, truncated, and saturate at 255.

    Args:
        arr: input image array (any numeric dtype).
        display_max: upper end of the display window.
        display_min: lower end of the display window (default 0).

    Returns:
        A ``uint8`` array with the same shape as ``arr``.
    """
    above = arr > display_min
    shifted = (arr.astype(float) - display_min) * above
    # Compute the span in float so unsigned scalar inputs cannot wrap around.
    span = float(display_max) - float(display_min)
    if span <= 0:
        # Degenerate window: the old code divided by zero, producing NaN (and an
        # undefined uint8) for pixels at display_min. Saturate everything above.
        return np.where(above, 255, 0).astype(np.uint8)
    scaled = shifted * (256.0 / span)
    return np.minimum(scaled, 255).astype(np.uint8)
def process_image(arr):
    """Prepare a raw channel image for display and storage as uint8.

    The brightest pixels (see ``illumination_threshold``) set the upper end of
    the display window used by ``sixteen_to_eight_bit``.
    """
    upper = illumination_threshold(arr)
    return sixteen_to_eight_bit(arr, upper)
def process_sample(imglst, channels, filenames, outdir, outfile):
    """Stack per-channel images into one uint8 sample and save it as ``.npz``.

    Images are paired positionally with ``channels`` and ``filenames`` and
    written into slots 0, 1, ... of a zero-initialised (520, 696, 5) array.

    Args:
        imglst: image sources accepted by ``img_to_numpy``.
        channels: channel names, paired positionally with ``imglst``.
        filenames: original file names, paired positionally with ``imglst``.
        outdir: directory to write the ``.npz`` into.
        outfile: file stem; ``".npz"`` is appended.

    Returns:
        ``(sample_dict, outpath)``. ``sample_dict`` holds the stacked array,
        a slot -> channel mapping, and a channel -> filename mapping.
    """
    sample = np.zeros((520, 696, 5), dtype=np.uint8)
    channels_dict = {}
    filenames_dict = {}
    for slot, (img, channel, fname) in enumerate(zip(imglst, channels, filenames)):
        print(channel)
        sample[:, :, slot] = process_image(img_to_numpy(img))
        channels_dict[slot] = channel
        filenames_dict[channel] = fname
    outpath = os.path.join(outdir, outfile + ".npz")
    # NOTE(review): the archive stores the raw ``channels``/``filenames`` lists
    # (all entries, even unpaired ones), not the mappings returned below.
    np.savez(outpath, sample=sample, channels=channels, filenames=filenames)
    sample_dict = dict(sample=sample, channels=channels_dict, filenames=filenames_dict)
    return sample_dict, outpath
def display_cellpainting(sample):
    """Compose an RGB preview from a processed sample.

    Channels 0, 3 and 4 of ``sample["sample"]`` become red, green and blue.
    """
    stack = sample["sample"]
    rgb = np.stack([stack[:, :, ch].astype(np.float32) for ch in (0, 3, 4)], axis=-1)
    return Image.fromarray(rgb.astype("uint8")).convert("RGB")
def morgan_from_smiles(smiles, radius=3, nbits=1024, chiral=True):
    """Compute a Morgan (circular) bit-vector fingerprint as a NumPy array.

    Args:
        smiles: SMILES string of the molecule.
        radius: Morgan radius (previously ignored and always 3; default kept).
        nbits: fingerprint length.
        chiral: whether chirality is taken into account.

    Returns:
        Array of length ``nbits`` filled by ``DataStructs.ConvertToNumpyArray``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nbits, useChirality=chiral)
    arr = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr
def save_hdf(fps, index, outfile_hdf):
    """Write fingerprints to HDF5 as a DataFrame under key ``"df"``.

    Rows are labelled ``0..len(fps)-1`` and columns ``"0".."<nbits-1>"``.
    NOTE(review): ``index`` is accepted but not used -- confirm whether the
    reader expects positional row labels or the passed ids.

    Returns:
        ``outfile_hdf``.
    """
    row_labels = list(range(len(fps)))
    n_bits = fps[0].shape[0]
    frame = pd.DataFrame(
        fps,
        index=row_labels,
        columns=[str(bit) for bit in range(n_bits)],
    )
    frame.to_hdf(outfile_hdf, key="df", mode="w")
    return outfile_hdf
def create_index(outdir, ids, filename):
    """Write a one-column CSV index (column ``SAMPLE_KEY``) for sample ids.

    Args:
        outdir: directory to write into.
        ids: a single id string, or a list-like of ids.
        filename: CSV file name inside ``outdir``.

    Returns:
        Path of the written CSV.
    """
    filepath = os.path.join(outdir, filename)
    # A bare string is a single id, not a sequence of ids.
    values = [ids] if isinstance(ids, str) else ids
    data = {"SAMPLE_KEY": values}
    print(data)
    pd.DataFrame(data).to_csv(filepath)
    return filepath
def draw_molecules(smiles_lst):
    """Render each SMILES string with RDKit and return the list of images.

    All strings are parsed first, then all parsed molecules are drawn.
    """
    parsed = list(map(Chem.MolFromSmiles, smiles_lst))
    return list(map(Chem.Draw.MolToImage, parsed))
def reshape_image(arr):
    """Convert a channels-first ``(c, h, w)`` array to an RGB PIL image.

    Only the first three channels are used, as R, G, B. The old version
    allocated ``(h, w, c)`` with ``np.empty`` and filled only three channels,
    which left uninitialised data for ``c > 3``.

    Raises:
        ValueError: if ``arr`` is not 3-D with at least three channels.
    """
    if arr.ndim != 3 or arr.shape[0] < 3:
        raise ValueError(f"expected a (c>=3, h, w) array, got shape {arr.shape}")
    hwc = np.transpose(arr[:3], (1, 2, 0))
    return Image.fromarray(np.ascontiguousarray(hwc, dtype=np.uint8))
##### STREAMLIT FUNCTIONS ######
# Rendered on every script run, before the selected page is dispatched.
st.title(
    'CLOOME: Contrastive Learning for Molecule Representation with Microscopy Images and Chemical Structures'
)
def main_page():
    """Render the landing page: a short description of the CLOOME method."""
    st.markdown(
        """
Contrastive learning for self-supervised representation learning has brought a
strong improvement to many application areas, such as computer vision and natural
language processing. With the availability of large collections of unlabeled data in
vision and language, contrastive learning of language and image representations
has shown impressive results. The contrastive learning methods CLIP and CLOOB
have demonstrated that the learned representations are highly transferable to a
large set of diverse tasks when trained on multi-modal data from two different
domains. In drug discovery, similar large, multi-modal datasets comprising both
cell-based microscopy images and chemical structures of molecules are available.
However, contrastive learning has not yet been used for this type of multi-modal data,
although transferable representations could be a remedy for the
time-consuming and costly label acquisition in this domain. In this work,
we present a contrastive learning method for image-based and structure-based
representations of small molecules for drug discovery.
Our method, Contrastive Leave One Out Boost for Molecule Encoders (CLOOME), is based on CLOOB
and comprises an encoder for microscopy data, an encoder for chemical structures
and a contrastive learning objective. On the benchmark dataset “Cell Painting”,
we demonstrate the ability of our method to learn transferable representations by
performing linear probing for activity prediction tasks. Additionally, we show that
the representations could also be useful for bioisosteric replacement tasks.
"""
    )
def molecules_from_image():
    """Streamlit page: retrieve the molecules most similar to an uploaded image.

    The user uploads one TIF per Cell Painting channel. Once every channel is
    present, the image is preprocessed, previewed, embedded with the image
    encoder and scored against molecule embeddings (scaled dot product followed
    by softmax). Candidate molecules come from an optional uploaded CSV with a
    ``SMILES`` column; otherwise the precomputed features in
    ``molecule_features`` are used. Up to five top-ranked molecules are drawn.
    """
    ## TODO: Check if expander can be automatically collapsed
    exp = st.expander("Upload a microscopy image")
    with exp:
        channels = ['Mito', 'ERSyto', 'ERSytoBleed', 'Ph_golgi', 'Hoechst']
        imglst, filenames = [], []
        for c in channels:
            file_obj = st.file_uploader(f'Choose a TIF image for {c}:', ".tif")
            if file_obj is not None:
                imglst.append(file_obj)
                filenames.append(file_obj.name)
    # process_sample pairs images with channel names and slots by position, so
    # a partial upload would put images into the wrong channels (and feed a
    # zero-filled image to the model). Only proceed once every channel is in.
    ready = len(imglst) == len(channels)
    if imglst and not ready:
        st.info(f"Upload all {len(channels)} channels to run the retrieval ({len(imglst)} uploaded so far).")
    if ready:
        os.makedirs(npzs, exist_ok=True)
        sample_dict, _ = process_sample(imglst, channels, filenames, npzs, imgname)
        st.image(display_cellpainting(sample_dict))
    uploaded_file = st.file_uploader("Choose a molecule file to retrieve from (optional)")
    if not ready:
        return

    if uploaded_file is not None:
        # Embed the user's molecules on the fly.
        molecule_df = pd.read_csv(uploaded_file)
        smiles = molecule_df["SMILES"].tolist()
        morgan = [morgan_from_smiles(s) for s in smiles]
        molnames = [f"M{i}" for i in range(len(morgan))]
        mol_index = create_index(basepath, molnames, "mol_index.csv")
        molpath = os.path.join(basepath, "mols.hdf")
        save_hdf(morgan, molnames, molpath)
        mol_imgs = draw_molecules(smiles)
        mol_features, mol_ids = main(mol_index, MODEL_PATH, model_type, mol_path=molpath, image_resolution=image_resolution)
        predefined_features = False
    else:
        # Fall back to the precomputed molecule embeddings.
        mol_index = pd.read_csv("cellpainting-unique-molecule.csv")
        mol_features, mol_ids = read_array(molecule_features)
        predefined_features = True

    img_index = create_index(basepath, imgname, "img_index.csv")
    img_features, _ = main(img_index, MODEL_PATH, model_type, img_path=npzs, image_resolution=image_resolution)
    # Loaded features may sit on a different device/dtype than the fresh ones.
    mol_features = mol_features.to(img_features)
    logits = img_features @ mol_features.T
    mol_probs = (30.0 * logits).softmax(dim=-1)
    # topk raises when k exceeds the number of candidates (e.g. a small CSV).
    k = min(5, mol_probs.shape[-1])
    if k == 0:
        st.warning("No molecules to retrieve from.")
        return
    _, top_labels = mol_probs.cpu().topk(k, dim=-1)
    # Single query image: flatten the (1, k) result to k plain ints.
    top_labels = torch.flatten(top_labels).tolist()

    if predefined_features:
        # Draw only the retrieved molecules, looked up by SAMPLE_KEY.
        mol_index = mol_index.set_index("SAMPLE_KEY")
        top_ids = [mol_ids[label] for label in top_labels]
        ranked_imgs = draw_molecules(mol_index.loc[top_ids]["SMILES"].tolist())
    else:
        # All uploaded molecules were drawn up front; pick the retrieved ones.
        ranked_imgs = [mol_imgs[label] for label in top_labels]

    with st.container():
        columns = st.columns(len(ranked_imgs))
        for rank, (col, mol_img) in enumerate(zip(columns, ranked_imgs), start=1):
            col.image(mol_img, width=140, caption=rank)
def images_from_molecule():
    """Streamlit page: retrieve the microscopy images most similar to a molecule.

    The entered SMILES string is fingerprinted, embedded with the molecule
    encoder and scored against the precomputed image embeddings in
    ``image_features`` (scaled dot product followed by softmax). Up to five
    top-ranked images are read from ``images_arr`` and shown in a row.
    """
    smiles = st.text_input("Enter a SMILES string", value="CC(=O)OC1=CC=CC=C1C(=O)O", placeholder="CC(=O)OC1=CC=CC=C1C(=O)O")
    if not smiles:
        return
    # Validate up front so bad input yields a message instead of a traceback.
    if Chem.MolFromSmiles(smiles) is None:
        st.error(f"Could not parse SMILES string: {smiles}")
        return
    smiles = [smiles]
    morgan = [morgan_from_smiles(s) for s in smiles]
    molnames = [f"M{i}" for i in range(len(morgan))]
    mol_index = create_index(basepath, molnames, "mol_index.csv")
    molpath = os.path.join(basepath, "mols.hdf")
    save_hdf(morgan, molnames, molpath)
    mol_imgs = draw_molecules(smiles)
    mol_features, _ = main(mol_index, MODEL_PATH, model_type, mol_path=molpath, image_resolution=image_resolution)

    # Centre the query molecule between two empty columns.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.write("")
    with col2:
        st.image(mol_imgs, width=140)
    with col3:
        st.write("")

    img_features_torch = torch.load(image_features)
    # Loaded features may sit on a different device/dtype than the query.
    img_features = img_features_torch["img_features"].to(mol_features)
    img_ids = img_features_torch["img_ids"]
    logits = mol_features @ img_features.T
    img_probs = (30.0 * logits).softmax(dim=-1)
    k = min(5, img_probs.shape[-1])
    _, top_labels = img_probs.cpu().topk(k, dim=-1)
    top_ids = [img_ids[label] for label in torch.flatten(top_labels).tolist()]

    # NpzFile keeps the archive open; close it as soon as the images are read.
    with np.load(images_arr, allow_pickle=True) as images_dict:
        ranked_images = [reshape_image(images_dict[f"{sample_id}.npz"]) for sample_id in top_ids]

    with st.container():
        columns = st.columns(len(ranked_images))
        for rank, (col, im) in enumerate(zip(columns, ranked_images), start=1):
            col.image(im, caption=rank)
# Sidebar page selector: maps each menu label to the function rendering it.
page_names_to_funcs = dict([
    ("-", main_page),
    ("Molecules from a microscopy image", molecules_from_image),
    ("Microscopy images from a molecule", images_from_molecule),
])
selected_page = st.sidebar.selectbox(
    "What would you like to retrieve?",
    page_names_to_funcs.keys(),
)
page_names_to_funcs[selected_page]()