liuganghuggingface commited on
Commit
e387f21
·
verified ·
1 Parent(s): cacde71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -2
app.py CHANGED
@@ -18,7 +18,7 @@ from rdkit import Chem
18
  from rdkit.Chem import Draw
19
 
20
  from evaluator import Evaluator
21
- from loader import load_graph_decoder
22
 
23
  # Load the CSV data
24
  known_labels = pd.read_csv('data/known_labels.csv')
@@ -55,7 +55,19 @@ def random_properties():
55
 
56
  def load_model(model_choice):
57
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
- model = load_graph_decoder(path=model_choice)
 
 
 
 
 
 
 
 
 
 
 
 
59
  return (model, device)
60
 
61
  # Create a flagged folder if it doesn't exist
 
18
  from rdkit.Chem import Draw
19
 
20
  from evaluator import Evaluator
21
+ # from loader import load_graph_decoder
22
 
23
  # Load the CSV data
24
  known_labels = pd.read_csv('data/known_labels.csv')
 
55
 
56
  def load_model(model_choice):
57
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+ # model = load_graph_decoder(path=model_choice)
59
+ #### test
60
+ from graph_decoder.diffusion_model import GraphDiT
61
+
62
+ model_config_path = f"model_labeled/config.yaml"
63
+ data_info_path = f"model_labeled/data.meta.json"
64
+ model = GraphDiT(
65
+ model_config_path=model_config_path,
66
+ data_info_path=data_info_path,
67
+ # model_dtype=torch.float16,
68
+ model_dtype=torch.float32,
69
+ )
70
+ ### test
71
  return (model, device)
72
 
73
  # Create a flagged folder if it doesn't exist