liuganghuggingface commited on
Commit
6cc3c63
·
verified ·
1 Parent(s): 4fd362d

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +110 -24
app.py CHANGED
@@ -1,41 +1,57 @@
1
  import spaces
2
  import gradio as gr
3
  import torch
 
 
4
  from rdkit import Chem
5
  from rdkit.Chem import Draw
6
- from graph_decoder.diffusion_model import GraphDiT
7
-
8
- # Load the model
9
- def load_graph_decoder(path='model_labeled'):
10
- model = GraphDiT(
11
- model_config_path=f"{path}/config.yaml",
12
- data_info_path=f"{path}/data.meta.json",
13
- model_dtype=torch.float32,
14
- )
15
- model.init_model(path)
16
- model.disable_grads()
 
 
 
 
 
 
17
  return model
18
 
19
- # model = load_graph_decoder()
20
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
21
 
22
  @spaces.GPU
23
  def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
24
- properties = [CH4, CO2, H2, N2, O2]
25
 
 
26
  try:
27
  model = load_graph_decoder()
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  model.to(device)
30
- print('enter function')
31
- generated_molecule, _ = model.generate(properties, device=device, guide_scale=guidance_scale)
32
 
33
- if generated_molecule is not None:
34
- mol = Chem.MolFromSmiles(generated_molecule)
35
- if mol is not None:
36
- standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
37
- img = Draw.MolToImage(mol)
38
- return standardized_smiles, img
 
 
 
 
 
 
39
  except Exception as e:
40
  print(f"Error in generation: {e}")
41
 
@@ -43,7 +59,7 @@ def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
43
 
44
  # Create the Gradio interface
45
  with gr.Blocks(title="Simplified Polymer Design") as iface:
46
- gr.Markdown("## Polymer Design with GraphDiT")
47
 
48
  with gr.Row():
49
  CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrier)")
@@ -66,4 +82,74 @@ with gr.Blocks(title="Simplified Polymer Design") as iface:
66
  )
67
 
68
  if __name__ == "__main__":
69
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import spaces
2
  import gradio as gr
3
  import torch
4
+ import torch.nn as nn
5
+ import random
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
+
9
+ class RandomPolymerGenerator(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+ self.fc1 = nn.Linear(5, 64)
13
+ self.fc2 = nn.Linear(64, 128)
14
+ self.fc3 = nn.Linear(128, 256)
15
+ self.fc4 = nn.Linear(256, 100) # Output size set to 100 for simplicity
16
+
17
+ def forward(self, x):
18
+ x = torch.relu(self.fc1(x))
19
+ x = torch.relu(self.fc2(x))
20
+ x = torch.relu(self.fc3(x))
21
+ return torch.sigmoid(self.fc4(x))
22
+
23
+ def load_graph_decoder():
24
+ model = RandomPolymerGenerator()
25
  return model
26
 
27
+ ATOM_SYMBOLS = ['C', 'N', 'O', 'H']
28
+
29
+ def generate_random_smiles(length=10):
30
+ return ''.join(random.choices(ATOM_SYMBOLS, k=length))
31
 
32
  @spaces.GPU
33
  def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
34
+ properties = torch.tensor([CH4, CO2, H2, N2, O2], dtype=torch.float32).unsqueeze(0)
35
 
36
+ print('in generate_polymer')
37
  try:
38
  model = load_graph_decoder()
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  model.to(device)
41
+ properties = properties.to(device)
 
42
 
43
+ with torch.no_grad():
44
+ output = model(properties)
45
+ print('output', output)
46
+
47
+ # Generate a random SMILES string (this is a placeholder)
48
+ generated_molecule = generate_random_smiles()
49
+
50
+ mol = Chem.MolFromSmiles(generated_molecule)
51
+ if mol is not None:
52
+ standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
53
+ img = Draw.MolToImage(mol)
54
+ return standardized_smiles, img
55
  except Exception as e:
56
  print(f"Error in generation: {e}")
57
 
 
59
 
60
  # Create the Gradio interface
61
  with gr.Blocks(title="Simplified Polymer Design") as iface:
62
+ gr.Markdown("## Polymer Design with Random Neural Network")
63
 
64
  with gr.Row():
65
  CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrier)")
 
82
  )
83
 
84
  if __name__ == "__main__":
85
+ iface.launch()
86
+
87
+ # import spaces
88
+ # import gradio as gr
89
+ # import torch
90
+ # from rdkit import Chem
91
+ # from rdkit.Chem import Draw
92
+ # # from graph_decoder.diffusion_model import GraphDiT
93
+
94
+ # # Load the model
95
+ # def load_graph_decoder(path='model_labeled'):
96
+ # model = GraphDiT(
97
+ # model_config_path=f"{path}/config.yaml",
98
+ # data_info_path=f"{path}/data.meta.json",
99
+ # model_dtype=torch.float32,
100
+ # )
101
+ # model.init_model(path)
102
+ # model.disable_grads()
103
+ # return model
104
+
105
+ # # model = load_graph_decoder()
106
+ # # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107
+
108
+ # @spaces.GPU
109
+ # def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
110
+ # properties = [CH4, CO2, H2, N2, O2]
111
+
112
+ # try:
113
+ # model = load_graph_decoder()
114
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
115
+ # model.to(device)
116
+ # print('enter function')
117
+ # generated_molecule, _ = model.generate(properties, device=device, guide_scale=guidance_scale)
118
+
119
+ # if generated_molecule is not None:
120
+ # mol = Chem.MolFromSmiles(generated_molecule)
121
+ # if mol is not None:
122
+ # standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
123
+ # img = Draw.MolToImage(mol)
124
+ # return standardized_smiles, img
125
+ # except Exception as e:
126
+ # print(f"Error in generation: {e}")
127
+
128
+ # return "Generation failed", None
129
+
130
+ # # Create the Gradio interface
131
+ # with gr.Blocks(title="Simplified Polymer Design") as iface:
132
+ # gr.Markdown("## Polymer Design with GraphDiT")
133
+
134
+ # with gr.Row():
135
+ # CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrier)")
136
+ # CO2_input = gr.Slider(0, 100, value=15.4, label="CO₂ (Barrier)")
137
+ # H2_input = gr.Slider(0, 100, value=21.0, label="H₂ (Barrier)")
138
+ # N2_input = gr.Slider(0, 100, value=1.5, label="N₂ (Barrier)")
139
+ # O2_input = gr.Slider(0, 100, value=2.8, label="O₂ (Barrier)")
140
+ # guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")
141
+
142
+ # generate_btn = gr.Button("Generate Polymer")
143
+
144
+ # with gr.Row():
145
+ # result_smiles = gr.Textbox(label="Generated SMILES")
146
+ # result_image = gr.Image(label="Molecule Visualization", type="pil")
147
+
148
+ # generate_btn.click(
149
+ # generate_polymer,
150
+ # inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
151
+ # outputs=[result_smiles, result_image]
152
+ # )
153
+
154
+ # if __name__ == "__main__":
155
+ # iface.launch()