|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
from transformers import pipeline |
|
import gradio as gr |
|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
from rdkit.Chem.Draw import SimilarityMaps |
|
import io |
|
from PIL import Image |
|
import numpy as np |
|
import rdkit |
|
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
from transformers_interpret import SequenceClassificationExplainer |
|
|
|
model_name = "FartLabs/FART_Augmented"

# Tokenizer and sequence-classification model come from the same Hugging Face
# checkpoint; the attribution explainer below needs both of them.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Token-attribution helper; save_high_quality_png calls it to score the
# tokens of a SMILES string.
cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
|
|
|
def save_high_quality_png(smiles, title, bw=True, padding=0.05):
    """
    Render per-atom attribution scores for a molecule and save them as a PNG.

    Token attributions are computed with the module-level ``cls_explainer``
    on ``smiles``. Tokens made only of letters are taken, in order, as the
    per-atom weights. This is a heuristic: a multi-letter token (e.g. ``Cl``)
    or an ``H`` inside brackets does not necessarily correspond to exactly
    one atom. The weight list is therefore truncated or zero-padded to the
    atom count, so the drawing call always receives one weight per atom.

    Parameters:
    - smiles (str): The SMILES string of the molecule to visualize.
    - title (str): Output path without extension; the image is written to
      ``f"{title}.png"``.
    - bw (bool): If True, draws atoms with RDKit's black-and-white palette.
    - padding (float): Value for the drawer's ``padding`` option.

    Returns:
    - None

    Raises:
    - ValueError: if ``smiles`` cannot be parsed, or if the molecule has
      fewer than two atoms (the similarity-map drawer rejects those).
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")

    num_atoms = molecule.GetNumAtoms()
    # Reject up front, before running the (slow) explainer:
    # GetSimilarityMapFromWeights refuses molecules with fewer than two atoms.
    if num_atoms < 2:
        raise ValueError(
            f"Cannot draw an attribution map for fewer than two atoms: {smiles!r}"
        )

    Chem.rdDepictor.Compute2DCoords(molecule)

    # The explainer yields (token, score) pairs; alphabetic tokens approximate atoms.
    token_importance = cls_explainer(smiles)
    atom_importance = [pair[1] for pair in token_importance if pair[0].isalpha()]
    # Exactly one weight per atom is required: drop extras and pad any
    # missing entries with a neutral 0.0 instead of failing in the drawer.
    atom_importance = atom_importance[:num_atoms]
    atom_importance.extend([0.0] * (num_atoms - len(atom_importance)))

    # Large canvas and thick bonds for a high-resolution image.
    drawer = Draw.MolDraw2DCairo(1500, 1500)
    dopts = drawer.drawOptions()
    dopts.padding = padding
    dopts.maxFontSize = 2000
    dopts.bondLineWidth = 5
    if bw:
        dopts.useBWAtomPalette()

    SimilarityMaps.GetSimilarityMapFromWeights(
        molecule, atom_importance, draw2d=drawer
    )
    drawer.FinishDrawing()

    with open(f"{title}.png", "wb") as png_file:
        png_file.write(drawer.GetDrawingText())

    return None
|
|
|
# Same checkpoint as model_name; kept as a module-level name for reference.
model_checkpoint = "FartLabs/FART_Augmented"

# Reuse the model and tokenizer already loaded for the explainer instead of
# letting the pipeline instantiate a second in-memory copy of the same
# checkpoint. top_k=None returns a score for every label.
# NOTE(review): the explainer and this pipeline now share one model object;
# this assumes requests are processed one at a time (Gradio's default queue)
# — confirm before raising concurrency.
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
)
|
|
|
def process_smiles(smiles, compute_explanation):
    """
    Classify a SMILES string and render an image of the molecule.

    Parameters:
    - smiles (str): SMILES entered by the user; surrounding whitespace is
      ignored.
    - compute_explanation (bool): If True (and the molecule has at least two
      atoms), draw a per-atom attribution map via ``save_high_quality_png``;
      otherwise draw a plain structure image.

    Returns:
    - tuple: ``(prediction_dict, filepath, canonical_smiles)``, where
      ``prediction_dict`` maps each label to its score and ``filepath`` is
      the PNG that was written. For unparseable or empty input, returns
      ``("Invalid SMILES", None, "Invalid SMILES")``.
    """
    smiles = (smiles or "").strip()
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles("") returns an empty molecule rather than None, so the
    # zero-atom case must be rejected explicitly.
    if mol is None or mol.GetNumAtoms() == 0:
        return "Invalid SMILES", None, "Invalid SMILES"

    canonical_smiles = Chem.MolToSmiles(mol)
    predictions = classifier(canonical_smiles)

    filepath = "molecule.png"
    # The attribution map needs at least two atoms; fall back to a plain
    # drawing for single-atom inputs instead of raising.
    if compute_explanation and mol.GetNumAtoms() >= 2:
        # Explain the same string that was classified, so the highlighted
        # atoms correspond to the displayed prediction.
        save_high_quality_png(canonical_smiles, "molecule")
    else:
        Draw.MolToImage(mol).save(filepath)

    # With top_k=None the pipeline returns a list of {"label", "score"} dicts
    # per input; depending on the transformers version, a single input may
    # or may not be wrapped in an extra outer list.
    if predictions and isinstance(predictions[0], list):
        scores = predictions[0]
    else:
        scores = predictions
    prediction_dict = {pred["label"]: pred["score"] for pred in scores}

    return prediction_dict, filepath, canonical_smiles
|
|
|
# HTML/markdown blurb shown above the Gradio interface.
_DESCRIPTION = """
<section id="molecular-taste-description">
<h2>Discover Molecular Taste with FART</h2>
<p>
At Kvant AI Labs, we just revolutionized taste chemistry with FART (Flavor Analysis and Recognition Transformer), an AI-powered tool designed to predict molecular taste from chemical structure alone. FART delivers predictions for <strong>sweet</strong>, <strong>bitter</strong>, <strong>sour</strong>, and <strong>umami</strong> with over 91% accuracy.
</p>
<p>
Beyond predictions, FART identifies the molecular features driving taste characteristics, enabling actionable insights for flavor innovation. Powered by the ChemBERTa foundation model and trained on the largest molecular taste dataset to date, FART sets a new standard in food science.
</p>
<p>
Learn more about the science behind FART in our <a href="https://chemrxiv.org/engage/chemrxiv/article-details/673a2a3af9980725cf80503c" target="_blank">Pre-print</a>. To generate SMILES, one possible option is this <a href="https://www.cheminfo.org/flavor/malaria/Utilities/SMILES_generator___checker/index.html" target="_blank">tool</a>.
</p>
</section>
"""


def _build_interface():
    """Assemble the Gradio interface that feeds user input to process_smiles."""
    smiles_box = gr.Textbox(
        label="Input SMILES",
        value="O1[C@H](CO)[C@@H](O)[C@H](O)[C@@H](O)[C@H]1O[C@@]2(O[C@@H]([C@@H](O)[C@@H]2O)CO)CO",
    )
    explain_box = gr.Checkbox(
        label="Display explanation (can take some time)",
        value=False,
    )

    # Output order matches the tuple returned by process_smiles.
    probabilities = gr.Label(num_top_classes=3, label="Classification Probabilities")
    picture = gr.Image(type="filepath", label="Molecule Image")
    canonical_box = gr.Textbox(label="Canonical SMILES")

    return gr.Interface(
        fn=process_smiles,
        inputs=[smiles_box, explain_box],
        outputs=[probabilities, picture, canonical_box],
        description=_DESCRIPTION,
    )


iface = _build_interface()

iface.launch()
|
|