|
import spaces |
|
import gradio as gr |
|
import torch |
|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
|
|
|
|
|
|
def load_graph_decoder(path='model_labeled'):
    """Return the graph decoder model used for polymer generation.

    NOTE(review): this is currently a placeholder. ``path`` is accepted but
    never read, and the function always returns ``None``. Callers must
    therefore cope with a missing model. TODO: restore the real loading logic.

    Args:
        path: Intended location of the saved model (unused at present).

    Returns:
        Always ``None`` in this version.
    """
    return None
|
|
|
# Build the module-global decoder once, at import time.
model = load_graph_decoder()

# Prefer the GPU whenever PyTorch reports one is available; otherwise use CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
@spaces.GPU
def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
    """Generate one polymer SMILES conditioned on five gas-property targets.

    The slider values are packed, in the order CH4, CO2, H2, N2, O2, into the
    conditioning list passed to the module-level ``model.generate``. The
    returned SMILES is validated and canonicalised with RDKit, then drawn.

    Args:
        CH4, CO2, H2, N2, O2: Target property values from the UI sliders.
        guidance_scale: Forwarded to ``model.generate`` as ``guide_scale``.

    Returns:
        ``(canonical_smiles, image)`` on success, where ``image`` is produced
        by ``Draw.MolToImage``. Otherwise ``("Generation failed", None)``:
        when no model is loaded, when generation/rendering raises, or when
        the output is empty or not a parsable SMILES string.
    """
    failure = ("Generation failed", None)
    properties = [CH4, CO2, H2, N2, O2]

    # load_graph_decoder() may yield None; report that directly instead of
    # letting it surface as an AttributeError from ``None.to``.
    if model is None:
        print("Error in generation: no model loaded")
        return failure

    try:
        print('enter generate polymer')
        # Moved on every call; presumably because @spaces.GPU only provides
        # the GPU inside the decorated call -- confirm against Spaces docs.
        model.to(device)
        generated_molecule, _ = model.generate(properties, device=device, guide_scale=guidance_scale)

        # An empty string would parse to a zero-atom Mol (not None) and be
        # shown as a blank "success"; treat it as a failed generation.
        if not generated_molecule:
            return failure

        mol = Chem.MolFromSmiles(generated_molecule)
        # MolFromSmiles returns None for SMILES it cannot parse.
        if mol is None or mol.GetNumAtoms() == 0:
            return failure

        standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
        img = Draw.MolToImage(mol)
        return standardized_smiles, img
    except Exception as e:
        # UI boundary: log and fall back to the failure result rather than
        # propagating the exception into Gradio.
        print(f"Error in generation: {e}")
        return failure
|
|
|
|
|
with gr.Blocks(title="Simplified Polymer Design") as iface:
    gr.Markdown("## Polymer Design with GraphDiT")

    # Conditioning targets, wired to generate_polymer in the order it expects.
    # step=0.1: without an explicit step Gradio derives an integer step for a
    # 0-100 range, which makes the fractional defaults (2.5, 15.4, ...)
    # unselectable.
    # Units: the Barrer is the conventional gas-permeability unit, and the old
    # "Barrier" label reads as a misspelling of it. Assumes the model's targets
    # are permeabilities in Barrer -- confirm against the training data.
    with gr.Row():
        CH4_input = gr.Slider(0, 100, value=2.5, step=0.1, label="CH₄ (Barrer)")
        CO2_input = gr.Slider(0, 100, value=15.4, step=0.1, label="CO₂ (Barrer)")
        H2_input = gr.Slider(0, 100, value=21.0, step=0.1, label="H₂ (Barrer)")
        N2_input = gr.Slider(0, 100, value=1.5, step=0.1, label="N₂ (Barrer)")
        O2_input = gr.Slider(0, 100, value=2.8, step=0.1, label="O₂ (Barrer)")
        guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")

    generate_btn = gr.Button("Generate Polymer")

    with gr.Row():
        result_smiles = gr.Textbox(label="Generated SMILES")
        result_image = gr.Image(label="Molecule Visualization", type="pil")

    # Event listeners must be registered inside the Blocks context.
    generate_btn.click(
        generate_polymer,
        inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
        outputs=[result_smiles, result_image],
    )


if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()