Mol-MoE: Training Preference-Guided Routers for Molecule Generation

Diego Calanzone (1, 2), Pierluca D'Oro (2), Pierre-Luc Bacon (1, 2)
(1) Université de Montréal, (2) Mila - Quebec AI Institute

Abstract: Recent advances in language models have enabled framing molecule generation as sequence modeling. However, existing approaches often rely on single-objective reinforcement learning, limiting their applicability to real-world drug design, where multiple competing properties must be optimized. Traditional multi-objective reinforcement learning (MORL) methods require costly retraining for each new objective combination, making rapid exploration of trade-offs impractical. To overcome these limitations, we introduce Mol-MoE, a mixture-of-experts (MoE) architecture that enables efficient test-time steering of molecule generation without retraining. Central to our approach is a preference-based router training objective that incentivizes the router to combine experts in a way that aligns with user-specified trade-offs. This provides improved flexibility in exploring the chemical property space at test time, facilitating rapid trade-off exploration. Benchmarking against state-of-the-art methods, we show that Mol-MoE achieves superior sample quality and steerability.

How to use this model

This LM is fine-tuned to generate molecules in the SMILES format conditioned on desired property values. For unconditional SMILES generation, prompt with only the BOS token <s>.
For conditional generation, you can target the following properties: JNK3, DRD2, GSK3B, CYP2D6, CYP2C19.

prompt: <JNK3=0.3><DRD2=0.7><GSK3B=0.2><CYP2D6=0.8><CYP2C19=0.8><s> 

An example of the generation pipeline:

from transformers import AutoTokenizer, AutoModelForCausalLM
import re

# Setup
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("ddidacus/RiC-mol-llama-1b")
model = AutoModelForCausalLM.from_pretrained("ddidacus/RiC-mol-llama-1b").to(device)
generation_kwargs = {
    "max_new_tokens": 128,
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 0.9,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "temperature": 1.0
}

# Inference
query = "<JNK3=0.3><DRD2=0.7><GSK3B=0.2><CYP2D6=0.8><CYP2C19=0.8><s>"
toks = tokenizer([query], return_tensors="pt")["input_ids"].to(device)
output = model.generate(toks, **generation_kwargs)
output = tokenizer.batch_decode(output)

# Parsing
pattern = r'<s>(.*?)</s>'
molecules = re.findall(pattern, output[0], re.DOTALL)
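
The generated sequences wrap each molecule between <s> and </s>, so the parsing step above returns a list of candidate SMILES strings. As a follow-up, the sketch below (reusing the tokenizer, model, and generation_kwargs defined above) shows unconditional generation from the bare BOS token and an optional validity check; the RDKit dependency is our assumption and is not required by the model itself.

# Unconditional generation: prompt with only the BOS token
query = "<s>"
toks = tokenizer([query], return_tensors="pt")["input_ids"].to(device)
output = tokenizer.batch_decode(model.generate(toks, **generation_kwargs))

# Optional: keep only syntactically valid SMILES (requires `pip install rdkit`)
from rdkit import Chem
candidates = re.findall(r'<s>(.*?)</s>', output[0], re.DOTALL)
valid = [s.strip() for s in candidates if Chem.MolFromSmiles(s.strip()) is not None]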

Model Description

This model is a fine-tuned version of Llama 3.2 1B, trained in two stages:

  1. Fine-tuning on ~3.5M molecules extracted from ZINC 250K, MOSES, and ChEMBL.
  2. Reward-conditioned fine-tuning (Rewards-in-Context) on 5 distinct reward signals embedded in the prompt (see the illustrative sketch below).

The detailed pipeline we followed is reported in the original paper:
"Rewards-in-Context: Multi-objective Alignment of Foundation Models with Dynamic Preference Adjustment", Yang et al., 2024 [1].

  • Developed by: Diego Calanzone ([email protected])
  • Model type: Decoder-only Transformer
  • Fine-tuned from model: Llama 3.2 1B

Read the paper for further details.

Sources

[1] Yang et al., "Rewards-in-Context: Multi-objective Alignment of Foundation Models with Dynamic Preference Adjustment", 2024. https://arxiv.org/abs/2402.10207
