pgantzer committed on
Commit 1760662 · 1 Parent(s): c08b1f5
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. SynTool/__init__.py +3 -0
  2. SynTool/chem/__init__.py +0 -0
  3. SynTool/chem/__pycache__/__init__.cpython-310.pyc +0 -0
  4. SynTool/chem/__pycache__/reaction.cpython-310.pyc +0 -0
  5. SynTool/chem/__pycache__/retron.cpython-310.pyc +0 -0
  6. SynTool/chem/__pycache__/utils.cpython-310.pyc +0 -0
  7. SynTool/chem/data/__init__.py +0 -0
  8. SynTool/chem/data/__pycache__/__init__.cpython-310.pyc +0 -0
  9. SynTool/chem/data/__pycache__/cleaning.cpython-310.pyc +0 -0
  10. SynTool/chem/data/__pycache__/filtering.cpython-310.pyc +0 -0
  11. SynTool/chem/data/__pycache__/mapping.cpython-310.pyc +0 -0
  12. SynTool/chem/data/__pycache__/standardizer.cpython-310.pyc +0 -0
  13. SynTool/chem/data/cleaning.py +124 -0
  14. SynTool/chem/data/filtering.py +917 -0
  15. SynTool/chem/data/mapping.py +96 -0
  16. SynTool/chem/data/mapping.py.bk +90 -0
  17. SynTool/chem/data/standardizer.py +604 -0
  18. SynTool/chem/reaction.py +107 -0
  19. SynTool/chem/reaction_rules/__init__.py +0 -0
  20. SynTool/chem/reaction_rules/__pycache__/__init__.cpython-310.pyc +0 -0
  21. SynTool/chem/reaction_rules/__pycache__/extraction.cpython-310.pyc +0 -0
  22. SynTool/chem/reaction_rules/extraction.py +679 -0
  23. SynTool/chem/reaction_rules/manual/__init__.py +6 -0
  24. SynTool/chem/reaction_rules/manual/decompositions.py +415 -0
  25. SynTool/chem/reaction_rules/manual/transformations.py +535 -0
  26. SynTool/chem/retron.py +132 -0
  27. SynTool/chem/utils.py +227 -0
  28. SynTool/interfaces/__init__.py +0 -0
  29. SynTool/interfaces/__pycache__/__init__.cpython-310.pyc +0 -0
  30. SynTool/interfaces/__pycache__/visualisation.cpython-310.pyc +0 -0
  31. SynTool/interfaces/cli.py +530 -0
  32. SynTool/interfaces/cli.py.bk +241 -0
  33. SynTool/interfaces/visualisation.py +346 -0
  34. SynTool/mcts/__init__.py +7 -0
  35. SynTool/mcts/__pycache__/__init__.cpython-310.pyc +0 -0
  36. SynTool/mcts/__pycache__/evaluation.cpython-310.pyc +0 -0
  37. SynTool/mcts/__pycache__/expansion.cpython-310.pyc +0 -0
  38. SynTool/mcts/__pycache__/node.cpython-310.pyc +0 -0
  39. SynTool/mcts/__pycache__/search.cpython-310.pyc +0 -0
  40. SynTool/mcts/__pycache__/tree.cpython-310.pyc +0 -0
  41. SynTool/mcts/evaluation.py +59 -0
  42. SynTool/mcts/expansion.py +83 -0
  43. SynTool/mcts/node.py +49 -0
  44. SynTool/mcts/search.py +135 -0
  45. SynTool/mcts/tree.py +659 -0
  46. SynTool/ml/__init__.py +0 -0
  47. SynTool/ml/__pycache__/__init__.cpython-310.pyc +0 -0
  48. SynTool/ml/networks/__init__.py +0 -0
  49. SynTool/ml/networks/__pycache__/__init__.cpython-310.pyc +0 -0
  50. SynTool/ml/networks/__pycache__/modules.cpython-310.pyc +0 -0
SynTool/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .mcts import *
2
+
3
+ __all__ = ["Tree"]
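The package root re-exports only the MCTS search tree, so downstream code imports it from the top level. A minimal sketch, assuming the SynTool package from this commit is installed:

    from SynTool import Tree  # Tree comes from SynTool.mcts and is the only name in __all__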
SynTool/chem/__init__.py ADDED
File without changes
SynTool/chem/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes).
SynTool/chem/__pycache__/reaction.cpython-310.pyc ADDED
Binary file (3.65 kB).
SynTool/chem/__pycache__/retron.cpython-310.pyc ADDED
Binary file (4.88 kB).
SynTool/chem/__pycache__/utils.cpython-310.pyc ADDED
Binary file (8.25 kB).
SynTool/chem/data/__init__.py ADDED
File without changes
SynTool/chem/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes).
SynTool/chem/data/__pycache__/cleaning.cpython-310.pyc ADDED
Binary file (3.86 kB).
SynTool/chem/data/__pycache__/filtering.cpython-310.pyc ADDED
Binary file (27.6 kB).
SynTool/chem/data/__pycache__/mapping.cpython-310.pyc ADDED
Binary file (2.59 kB).
SynTool/chem/data/__pycache__/standardizer.cpython-310.pyc ADDED
Binary file (18 kB).
SynTool/chem/data/cleaning.py ADDED
@@ -0,0 +1,124 @@
1
+ import os
2
+ from multiprocessing import Queue, Process, Manager, Value
3
+ from logging import getLogger, Logger
4
+ from tqdm import tqdm
5
+ from CGRtools.containers import ReactionContainer
6
+
7
+ from .standardizer import Standardizer
8
+ from SynTool.utils.files import ReactionReader, ReactionWriter
9
+ from SynTool.utils.config import ReactionStandardizationConfig
10
+
11
+
12
+ def cleaner(reaction: ReactionContainer, logger: Logger, config: ReactionStandardizationConfig):
13
+ """
14
+ Standardize a reaction using the external standardization script (see standardizer.py)
15
+
16
+ :param reaction: ReactionContainer to clean/standardize
17
+ :param logger: logger instance (a disabled logger can be passed to suppress log output)
18
+ :param config: ReactionStandardizationConfig
19
+ :return: ReactionContainer or empty list
20
+ """
21
+ standardizer = Standardizer(id_tag='Reaction_ID',
22
+ action_on_isotopes=2,
23
+ skip_tautomerize=True,
24
+ skip_errors=config.skip_errors,
25
+ keep_unbalanced_ions=config.keep_unbalanced_ions,
26
+ keep_reagents=config.keep_reagents,
27
+ ignore_mapping=config.ignore_mapping,
28
+ logger=logger)
29
+ return standardizer.standardize(reaction)
30
+
31
+
32
+ def worker_cleaner(to_clean: Queue, to_write: Queue, config: ReactionStandardizationConfig):
33
+ """
34
+ Runs standardization on reactions taken from the to_clean queue and puts the results into the to_write queue
35
+
36
+ :param to_clean: Queue of reactions to clean/standardize
37
+ :param to_write: Standardized outputs to write
38
+ :param config: ReactionStandardizationConfig
39
+ :return: None
40
+ """
41
+ logger = getLogger()
42
+ logger.disabled = True
43
+ while True:
44
+ raw_reaction = to_clean.get()
45
+ if raw_reaction == "Quit":
46
+ break
47
+ res = cleaner(raw_reaction, logger, config)
48
+ to_write.put(res)
49
+ logger.disabled = False
50
+
51
+
52
+ def cleaner_writer(output_file: str, to_write: Queue, cleaned_nb: Value, remove_old=True):
53
+ """
54
+ Writes the standardized reactions to the output file
55
+
56
+ :param output_file: output file path
57
+ :param to_write: Standardized ReactionContainer to write
58
+ :param cleaned_nb: number of final reactions
59
+ :param remove_old: whether to remove an already existing output file
60
+ """
61
+
62
+ if remove_old and os.path.isfile(output_file):
63
+ os.remove(output_file)
64
+
65
+ counter = 0
66
+ seen_reactions = []
67
+ with ReactionWriter(output_file) as out:
68
+ while True:
69
+ res = to_write.get()
70
+ if res:
71
+ if res == "Quit":
72
+ cleaned_nb.set(counter)
73
+ break
74
+ elif isinstance(res, ReactionContainer):
75
+ smi = format(res, "m")
76
+ if smi not in seen_reactions:
77
+ out.write(res)
78
+ counter += 1
79
+ seen_reactions.append(smi)
80
+
81
+
82
+ def reactions_cleaner(config: ReactionStandardizationConfig,
83
+ input_file: str, output_file: str, num_cpus: int, batch_prep_size: int = 100):
84
+ """
85
+ Standardizes the reactions from the input file in parallel and writes the results to the output file
86
+
87
+ :param config: reaction standardization configuration
88
+ :param input_file: input RDF file path
89
+ :param output_file: output RDF file path
90
+ :param num_cpus: number of CPUs to use (num_cpus - 2 worker processes are spawned; reading and writing use the rest)
91
+ :param batch_prep_size: queue size multiplier (to_clean holds up to num_cpus * batch_prep_size reactions)
92
+ """
93
+ with Manager() as m:
94
+ to_clean = m.Queue(maxsize=num_cpus * batch_prep_size)
95
+ to_write = m.Queue(maxsize=batch_prep_size)
96
+ cleaned_nb = m.Value(int, 0)
97
+
98
+ writer = Process(target=cleaner_writer, args=(output_file, to_write, cleaned_nb))
99
+ writer.start()
100
+
101
+ workers = []
102
+ for _ in range(num_cpus - 2):
103
+ w = Process(target=worker_cleaner, args=(to_clean, to_write, config))
104
+ w.start()
105
+ workers.append(w)
106
+
107
+ n = 0
108
+ with ReactionReader(input_file) as reactions:
109
+ for raw_reaction in tqdm(reactions):
110
+ if 'Reaction_ID' not in raw_reaction.meta:
111
+ raw_reaction.meta['Reaction_ID'] = n
112
+ to_clean.put(raw_reaction)
113
+ n += 1
114
+
115
+ for _ in workers:
116
+ to_clean.put("Quit")
117
+ for w in workers:
118
+ w.join()
119
+
120
+ to_write.put("Quit")
121
+ writer.join()
122
+
123
+ print(f'Initial number of reactions: {n}')
124
+ print(f'Removed number of reactions: {n - cleaned_nb.get()}')
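A minimal usage sketch for the parallel cleaning entry point defined above; the file paths are hypothetical, and it is assumed here that ReactionStandardizationConfig (from SynTool.utils.config, not shown in this diff) can be constructed with default values:

    from SynTool.utils.config import ReactionStandardizationConfig
    from SynTool.chem.data.cleaning import reactions_cleaner

    config = ReactionStandardizationConfig()  # assumed default construction
    reactions_cleaner(config=config,
                      input_file="reactions_raw.rdf",     # hypothetical input file
                      output_file="reactions_clean.smi",  # hypothetical output file
                      num_cpus=8)                         # num_cpus - 2 worker processes are spawned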
SynTool/chem/data/filtering.py ADDED
@@ -0,0 +1,917 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Iterable, Tuple, Dict, Any, Optional
5
+ from tqdm.auto import tqdm
6
+
7
+ import numpy as np
8
+ import ray
9
+ import yaml
10
+ from CGRtools.containers import ReactionContainer, MoleculeContainer, CGRContainer
11
+ from StructureFingerprint import MorganFingerprint
12
+
13
+ from SynTool.utils.files import ReactionReader, ReactionWriter
14
+ from SynTool.chem.utils import remove_small_molecules, rebalance_reaction, remove_reagents
15
+ from SynTool.utils.config import ConfigABC, convert_config_to_dict
16
+
17
+
18
+ @dataclass
19
+ class CompeteProductsConfig(ConfigABC):
20
+ fingerprint_tanimoto_threshold: float = 0.3
21
+ mcs_tanimoto_threshold: float = 0.6
22
+
23
+ @staticmethod
24
+ def from_dict(config_dict: Dict[str, Any]):
25
+ """Create an instance of CompeteProductsConfig from a dictionary."""
26
+ return CompeteProductsConfig(**config_dict)
27
+
28
+ @staticmethod
29
+ def from_yaml(file_path: str):
30
+ """Deserialize a YAML file into a CompeteProductsConfig object."""
31
+ with open(file_path, "r") as file:
32
+ config_dict = yaml.safe_load(file)
33
+ return CompeteProductsConfig.from_dict(config_dict)
34
+
35
+ def _validate_params(self, params: Dict[str, Any]):
36
+ """Validate configuration parameters."""
37
+ if not isinstance(params.get("fingerprint_tanimoto_threshold"), float) \
38
+ or not (0 <= params["fingerprint_tanimoto_threshold"] <= 1):
39
+ raise ValueError("Invalid 'fingerprint_tanimoto_threshold'; expected a float between 0 and 1")
40
+
41
+ if not isinstance(params.get("mcs_tanimoto_threshold"), float) \
42
+ or not (0 <= params["mcs_tanimoto_threshold"] <= 1):
43
+ raise ValueError("Invalid 'mcs_tanimoto_threshold'; expected a float between 0 and 1")
44
+
45
+
46
+ class CompeteProductsChecker:
47
+ """Checks whether the reaction has competing products."""
48
+
49
+ def __init__(
50
+ self,
51
+ fingerprint_tanimoto_threshold: float = 0.3,
52
+ mcs_tanimoto_threshold: float = 0.6,
53
+ ):
54
+ self.fingerprint_tanimoto_threshold = fingerprint_tanimoto_threshold
55
+ self.mcs_tanimoto_threshold = mcs_tanimoto_threshold
56
+
57
+ @staticmethod
58
+ def from_config(config: CompeteProductsConfig):
59
+ """Creates an instance of CompeteProductsChecker from a configuration object."""
60
+ return CompeteProductsChecker(
61
+ config.fingerprint_tanimoto_threshold, config.mcs_tanimoto_threshold
62
+ )
63
+
64
+ def __call__(self, reaction: ReactionContainer) -> bool:
65
+ """
66
+ Returns True if the reaction has competing products, else False
67
+
68
+ :param reaction: input reaction
69
+ :return: True or False
70
+ """
71
+ mf = MorganFingerprint()
72
+ is_compete = False
73
+
74
+ # Check for competing products using both fingerprint similarity and maximum common substructure (MCS) similarity
75
+ for mol in reaction.reagents:
76
+ for other_mol in reaction.products:
77
+ if len(mol) > 6 and len(other_mol) > 6:
78
+ # Compute fingerprint similarity
79
+ molf = mf.transform([mol])
80
+ other_molf = mf.transform([other_mol])
81
+ fingerprint_tanimoto = tanimoto_kernel(molf, other_molf)[0][0]
82
+
83
+ # If fingerprint similarity is high enough, check for MCS similarity
84
+ if fingerprint_tanimoto > self.fingerprint_tanimoto_threshold:
85
+ try:
86
+ # Find the maximum common substructure (MCS) and compute its size
87
+ clique_size = len(next(mol.get_mcs_mapping(other_mol, limit=100)))
88
+
89
+ # Calculate MCS similarity based on MCS size
90
+ mcs_tanimoto = clique_size / (len(mol) + len(other_mol) - clique_size)
91
+
92
+ # If MCS similarity is also high enough, mark the reaction as having competing products
93
+ if mcs_tanimoto > self.mcs_tanimoto_threshold:
94
+ is_compete = True
95
+ break
96
+ except StopIteration:
97
+ continue
98
+
99
+ return is_compete
100
+
101
+
102
+ @dataclass
103
+ class DynamicBondsConfig(ConfigABC):
104
+ min_bonds_number: int = 1
105
+ max_bonds_number: int = 6
106
+
107
+ @staticmethod
108
+ def from_dict(config_dict: Dict[str, Any]):
109
+ """Create an instance of DynamicBondsConfig from a dictionary."""
110
+ return DynamicBondsConfig(**config_dict)
111
+
112
+ @staticmethod
113
+ def from_yaml(file_path: str):
114
+ """Deserialize a YAML file into a DynamicBondsConfig object."""
115
+ with open(file_path, "r") as file:
116
+ config_dict = yaml.safe_load(file)
117
+ return DynamicBondsConfig.from_dict(config_dict)
118
+
119
+ def _validate_params(self, params: Dict[str, Any]):
120
+ """Validate configuration parameters."""
121
+ if not isinstance(params.get("min_bonds_number"), int) \
122
+ or params["min_bonds_number"] < 0:
123
+ raise ValueError(
124
+ "Invalid 'min_bonds_number'; expected a non-negative integer")
125
+
126
+ if not isinstance(params.get("max_bonds_number"), int) \
127
+ or params["max_bonds_number"] < 0:
128
+ raise ValueError("Invalid 'max_bonds_number'; expected a non-negative integer")
129
+
130
+ if params["min_bonds_number"] > params["max_bonds_number"]:
131
+ raise ValueError("'min_bonds_number' cannot be greater than 'max_bonds_number'")
132
+
133
+
134
+ class DynamicBondsChecker:
135
+ """Checks if there is an unacceptable number of dynamic bonds in CGR."""
136
+
137
+ def __init__(self, min_bonds_number: int = 1, max_bonds_number: int = 6):
138
+ self.min_bonds_number = min_bonds_number
139
+ self.max_bonds_number = max_bonds_number
140
+
141
+ @staticmethod
142
+ def from_config(config: DynamicBondsConfig):
143
+ """Creates an instance of DynamicBondsChecker from a configuration object."""
144
+ return DynamicBondsChecker(config.min_bonds_number, config.max_bonds_number)
145
+
146
+ def __call__(self, reaction: ReactionContainer) -> bool:
147
+ cgr = ~reaction
148
+ return not (self.min_bonds_number <= len(cgr.center_bonds) <= self.max_bonds_number)
149
+
150
+
151
+ @dataclass
152
+ class SmallMoleculesConfig(ConfigABC):
153
+ limit: int = 6
154
+
155
+ @staticmethod
156
+ def from_dict(config_dict: Dict[str, Any]):
157
+ """Create an instance of SmallMoleculesConfig from a dictionary."""
158
+ return SmallMoleculesConfig(**config_dict)
159
+
160
+ @staticmethod
161
+ def from_yaml(file_path: str):
162
+ """Deserialize a YAML file into a SmallMoleculesConfig object."""
163
+ with open(file_path, "r") as file:
164
+ config_dict = yaml.safe_load(file)
165
+ return SmallMoleculesConfig.from_dict(config_dict)
166
+
167
+ def _validate_params(self, params: Dict[str, Any]):
168
+ """Validate configuration parameters."""
169
+ if not isinstance(params.get("limit"), int) or params["limit"] < 1:
170
+ raise ValueError("Invalid 'limit'; expected a positive integer")
171
+
172
+
173
+ class SmallMoleculesChecker:
174
+ """Checks if there are only small molecules in the reaction or if there is only one small reactant or product."""
175
+
176
+ def __init__(self, limit: int = 6):
177
+ self.limit = limit
178
+
179
+ @staticmethod
180
+ def from_config(config: SmallMoleculesConfig):
181
+ """Creates an instance of SmallMoleculesChecker from a configuration object."""
182
+ return SmallMoleculesChecker(config.limit)
183
+
184
+ def __call__(self, reaction: ReactionContainer) -> bool:
185
+ if (len(reaction.reactants) == 1 and self.are_only_small_molecules(reaction.reactants)) \
186
+ or (len(reaction.products) == 1 and self.are_only_small_molecules(reaction.products)) \
187
+ or (self.are_only_small_molecules(reaction.reactants) and self.are_only_small_molecules(reaction.products)):
188
+ return True
189
+ return False
190
+
191
+ def are_only_small_molecules(self, molecules: Iterable[MoleculeContainer]) -> bool:
192
+ """Checks if all molecules in the given iterable are small molecules."""
193
+ return all(len(molecule) <= self.limit for molecule in molecules)
194
+
195
+
196
+ @dataclass
197
+ class CGRConnectedComponentsConfig:
198
+ pass
199
+
200
+
201
+ class CGRConnectedComponentsChecker:
202
+ """Checks whether the CGR (built without reagents) contains disconnected components."""
203
+
204
+ @staticmethod
205
+ def from_config(config: CGRConnectedComponentsConfig): # TODO config class not used
206
+ """Creates an instance of CGRConnectedComponentsChecker from a configuration object."""
207
+ return CGRConnectedComponentsChecker()
208
+
209
+ def __call__(self, reaction: ReactionContainer) -> bool:
210
+ tmp_reaction = ReactionContainer(reaction.reactants, reaction.products)
211
+ cgr = ~tmp_reaction
212
+ return cgr.connected_components_count > 1
213
+
214
+
215
+ @dataclass
216
+ class RingsChangeConfig:
217
+ pass
218
+
219
+
220
+ class RingsChangeChecker:
221
+ """Checks whether the number of rings changes in the reaction."""
222
+
223
+ @staticmethod
224
+ def from_config(config: RingsChangeConfig): # TODO config class not used
225
+ """Creates an instance of RingsChangeChecker from a configuration object."""
226
+ return RingsChangeChecker()
227
+
228
+ def __call__(self, reaction: ReactionContainer):
229
+ """
230
+ Returns True if there are valence mistakes in the reaction or if the number of all rings or
231
+ aromatic rings differs between reactants and products (ring change in the reaction)
232
+
233
+ :param reaction: input reaction
234
+ :return: True or False
235
+ """
236
+
237
+ reaction.kekule()
238
+ reaction.thiele()
239
+ r_rings, r_arom_rings = self._calc_rings(reaction.reactants)
240
+ p_rings, p_arom_rings = self._calc_rings(reaction.products)
241
+ if (r_arom_rings != p_arom_rings) or (r_rings != p_rings):
242
+ return True
243
+ else:
244
+ return False
245
+
246
+ @staticmethod
247
+ def _calc_rings(molecules: Iterable) -> Tuple[int, int]:
248
+ """
249
+ Calculates number of all rings and number of aromatic rings in molecules
250
+
251
+ :param molecules: set of molecules
252
+ :return: number of all rings and number of aromatic rings in molecules
253
+ """
254
+ rings, arom_rings = 0, 0
255
+ for mol in molecules:
256
+ rings += mol.rings_count
257
+ arom_rings += len(mol.aromatic_rings)
258
+ return rings, arom_rings
259
+
260
+
261
+ @dataclass
262
+ class StrangeCarbonsConfig:
263
+ # Currently empty, but can be extended in the future if needed
264
+ pass
265
+
266
+
267
+ class StrangeCarbonsChecker:
268
+ """Checks if there are 'strange' carbons in the reaction."""
269
+
270
+ @staticmethod
271
+ def from_config(config: StrangeCarbonsConfig): # TODO config class not used
272
+ """Creates an instance of StrangeCarbonsChecker from a configuration object."""
273
+ return StrangeCarbonsChecker()
274
+
275
+ def __call__(self, reaction: ReactionContainer) -> bool:
276
+ for molecule in reaction.reactants + reaction.products:
277
+ atoms_types = {a.atomic_symbol for _, a in molecule.atoms()} # atoms types in molecule
278
+ if len(atoms_types) == 1 and atoms_types.pop() == "C":
279
+ if len(molecule) == 1: # methane
280
+ return True
281
+ bond_types = {int(b) for _, _, b in molecule.bonds()}
282
+ if len(bond_types) == 1 and bond_types.pop() != 4:
283
+ return True # C molecules with only one type of bond (not aromatic)
284
+ return False
285
+
286
+
287
+ @dataclass
288
+ class NoReactionConfig:
289
+ # Currently empty, but can be extended in the future if needed
290
+ pass
291
+
292
+
293
+ class NoReactionChecker:
294
+ """Checks if there is no reaction in the provided reaction container."""
295
+
296
+ @staticmethod
297
+ def from_config(config: NoReactionConfig): # TODO config class not used
298
+ """Creates an instance of NoReactionChecker from a configuration object."""
299
+ return NoReactionChecker()
300
+
301
+ def __call__(self, reaction: ReactionContainer) -> bool:
302
+ cgr = ~reaction
303
+ return not cgr.center_atoms and not cgr.center_bonds
304
+
305
+
306
+ @dataclass
307
+ class MultiCenterConfig:
308
+ pass
309
+
310
+
311
+ class MultiCenterChecker:
312
+ """Checks if there is a multicenter reaction."""
313
+
314
+ @staticmethod
315
+ def from_config(config: MultiCenterConfig): # TODO config class not used
316
+ return MultiCenterChecker()
317
+
318
+ def __call__(self, reaction: ReactionContainer) -> bool:
319
+ cgr = ~reaction
320
+ return len(cgr.centers_list) > 1
321
+
322
+
323
+ @dataclass
324
+ class WrongCHBreakingConfig:
325
+ pass
326
+
327
+
328
+ class WrongCHBreakingChecker:
329
+ """Checks for incorrect C-C bond formation from breaking a C-H bond."""
330
+
331
+ @staticmethod
332
+ def from_config(config: WrongCHBreakingConfig): # TODO config class not used
333
+ return WrongCHBreakingChecker()
334
+
335
+ def __call__(self, reaction: ReactionContainer) -> bool:
336
+ """
337
+ Determines if a reaction involves incorrect C-C bond formation from breaking a C-H bond.
338
+
339
+ :param reaction: The reaction to be checked.
340
+ :return: True if incorrect C-C bond formation is found, False otherwise.
341
+ """
342
+
343
+ reaction.kekule()
344
+ if reaction.check_valence():
345
+ return False
346
+ reaction.thiele()
347
+
348
+ copy_reaction = reaction.copy()
349
+ copy_reaction.explicify_hydrogens()
350
+ cgr = ~copy_reaction
351
+ reduced_cgr = cgr.augmented_substructure(cgr.center_atoms, deep=1)
352
+
353
+ return self.is_wrong_c_h_breaking(reduced_cgr)
354
+
355
+ @staticmethod
356
+ def is_wrong_c_h_breaking(cgr: CGRContainer) -> bool:
357
+ """
358
+ Checks for incorrect C-C bond formation from breaking a C-H bond in a CGR.
359
+ :param cgr: The CGR with explicified hydrogens.
360
+ :return: True if incorrect C-C bond formation is found, False otherwise.
361
+ """
362
+ for atom_id in cgr.center_atoms:
363
+ if cgr.atom(atom_id).atomic_symbol == "C":
364
+ is_c_h_breaking, is_c_c_formation = False, False
365
+ c_with_h_id, another_c_id = None, None
366
+
367
+ for neighbour_id, bond in cgr._bonds[atom_id].items():
368
+ neighbour = cgr.atom(neighbour_id)
369
+
370
+ if (
371
+ bond.order
372
+ and not bond.p_order
373
+ and neighbour.atomic_symbol == "H"
374
+ ):
375
+ is_c_h_breaking = True
376
+ c_with_h_id = atom_id
377
+
378
+ elif (
379
+ not bond.order
380
+ and bond.p_order
381
+ and neighbour.atomic_symbol == "C"
382
+ ):
383
+ is_c_c_formation = True
384
+ another_c_id = neighbour_id
385
+
386
+ if is_c_h_breaking and is_c_c_formation:
387
+ # Check for presence of heteroatoms in the first environment of 2 bonding carbons
388
+ if any(
389
+ cgr.atom(neighbour_id).atomic_symbol not in ("C", "H")
390
+ for neighbour_id in cgr._bonds[c_with_h_id]
391
+ ) or any(
392
+ cgr.atom(neighbour_id).atomic_symbol not in ("C", "H")
393
+ for neighbour_id in cgr._bonds[another_c_id]
394
+ ):
395
+ return False
396
+ return True
397
+
398
+ return False
399
+
400
+
401
+ @dataclass
402
+ class CCsp3BreakingConfig:
403
+ pass
404
+
405
+
406
+ class CCsp3BreakingChecker:
407
+ """Checks if there is C(sp3)-C bond breaking."""
408
+
409
+ @staticmethod
410
+ def from_config(config: CCsp3BreakingConfig): # TODO config class not used
411
+ return CCsp3BreakingChecker()
412
+
413
+ def __call__(self, reaction: ReactionContainer) -> bool:
414
+ """
415
+ Returns True if there is C(sp3)-C bond breaking, else False
416
+
417
+ :param reaction: input reaction
418
+ :return: True or False
419
+ """
420
+ cgr = ~reaction
421
+ reaction_center = cgr.augmented_substructure(cgr.center_atoms, deep=1)
422
+ for atom_id, neighbour_id, bond in reaction_center.bonds():
423
+ atom = reaction_center.atom(atom_id)
424
+ neighbour = reaction_center.atom(neighbour_id)
425
+
426
+ is_bond_broken = bond.order is not None and bond.p_order is None
427
+ are_atoms_carbons = (
428
+ atom.atomic_symbol == "C" and neighbour.atomic_symbol == "C"
429
+ )
430
+ is_atom_sp3 = atom.hybridization == 1 or neighbour.hybridization == 1
431
+
432
+ if is_bond_broken and are_atoms_carbons and is_atom_sp3:
433
+ return True
434
+ return False
435
+
436
+
437
+ @dataclass
438
+ class CCRingBreakingConfig:
439
+ pass
440
+
441
+
442
+ class CCRingBreakingChecker:
443
+ """Checks if a reaction involves ring C-C bond breaking."""
444
+
445
+ @staticmethod
446
+ def from_config(config: CCRingBreakingConfig): # TODO config class not used
447
+ return CCRingBreakingChecker()
448
+
449
+ def __call__(self, reaction: ReactionContainer) -> bool:
450
+ """
451
+ Returns True if the reaction involves ring C-C bond breaking, else False
452
+
453
+ :param reaction: input reaction
454
+ :return: True or False
455
+ """
456
+ cgr = ~reaction
457
+
458
+ # Extract reactants' center atoms and their rings
459
+ reactants_center_atoms = {}
460
+ reactants_rings = set()
461
+ for reactant in reaction.reactants:
462
+ reactants_rings.update(reactant.sssr)
463
+ for n, atom in reactant.atoms():
464
+ if n in cgr.center_atoms:
465
+ reactants_center_atoms[n] = atom
466
+
467
+ # Identify reaction center based on center atoms
468
+ reaction_center = cgr.augmented_substructure(atoms=cgr.center_atoms, deep=0)
469
+
470
+ # Iterate over bonds in the reaction center and check for ring C-C bond breaking
471
+ for atom_id, neighbour_id, bond in reaction_center.bonds():
472
+ try:
473
+ # Retrieve corresponding atoms from reactants
474
+ atom = reactants_center_atoms[atom_id]
475
+ neighbour = reactants_center_atoms[neighbour_id]
476
+ except KeyError:
477
+ continue
478
+ else:
479
+ # Check if the bond is broken and both atoms are carbons in rings of size 5, 6, or 7
480
+ is_bond_broken = (bond.order is not None) and (bond.p_order is None)
481
+ are_atoms_carbons = (
482
+ atom.atomic_symbol == "C" and neighbour.atomic_symbol == "C"
483
+ )
484
+ are_atoms_in_ring = (
485
+ set(atom.ring_sizes).intersection({5, 6, 7})
486
+ and set(neighbour.ring_sizes).intersection({5, 6, 7})
487
+ and any(
488
+ atom_id in ring and neighbour_id in ring
489
+ for ring in reactants_rings
490
+ )
491
+ )
492
+
493
+ # If all conditions are met, indicate ring C-C bond breaking
494
+ if is_bond_broken and are_atoms_carbons and are_atoms_in_ring:
495
+ return True
496
+
497
+ return False
498
+
499
+
500
+ @dataclass
501
+ class ReactionCheckConfig(ConfigABC):
502
+ """
503
+ Configuration class for reaction checks, inheriting from ConfigABC.
504
+
505
+ This class manages configuration settings for various reaction checkers, including paths, file formats,
506
+ and checker-specific parameters.
507
+
508
+ Attributes:
509
+ dynamic_bonds_config: Configuration for dynamic bonds checking.
510
+ small_molecules_config: Configuration for small molecules checking.
511
+ strange_carbons_config: Configuration for strange carbons checking.
512
+ compete_products_config: Configuration for competing products checking.
513
+ cgr_connected_components_config: Configuration for CGR connected components checking.
514
+ rings_change_config: Configuration for rings change checking.
515
+ no_reaction_config: Configuration for no reaction checking.
516
+ multi_center_config: Configuration for multi-center checking.
517
+ wrong_ch_breaking_config: Configuration for wrong C-H breaking checking.
518
+ cc_sp3_breaking_config: Configuration for CC sp3 breaking checking.
519
+ cc_ring_breaking_config: Configuration for CC ring breaking checking.
520
+ """
521
+
522
+ # Configuration for reaction checkers
523
+ dynamic_bonds_config: Optional[DynamicBondsConfig] = None
524
+ small_molecules_config: Optional[SmallMoleculesConfig] = None
525
+ strange_carbons_config: Optional[StrangeCarbonsConfig] = None
526
+ compete_products_config: Optional[CompeteProductsConfig] = None
527
+ cgr_connected_components_config: Optional[CGRConnectedComponentsConfig] = None
528
+ rings_change_config: Optional[RingsChangeConfig] = None
529
+ no_reaction_config: Optional[NoReactionConfig] = None
530
+ multi_center_config: Optional[MultiCenterConfig] = None
531
+ wrong_ch_breaking_config: Optional[WrongCHBreakingConfig] = None
532
+ cc_sp3_breaking_config: Optional[CCsp3BreakingConfig] = None
533
+ cc_ring_breaking_config: Optional[CCRingBreakingConfig] = None
534
+
535
+ # Other configuration parameters
536
+ rebalance_reaction: bool = False
537
+ remove_reagents: bool = True
538
+ reagents_max_size: int = 7
539
+ remove_small_molecules: bool = False
540
+ small_molecules_max_size: int = 6
541
+
542
+ def to_dict(self):
543
+ """
544
+ Converts the configuration into a dictionary.
545
+ """
546
+ config_dict = {
547
+ "dynamic_bonds_config": convert_config_to_dict(
548
+ self.dynamic_bonds_config, DynamicBondsConfig
549
+ ),
550
+ "small_molecules_config": convert_config_to_dict(
551
+ self.small_molecules_config, SmallMoleculesConfig
552
+ ),
553
+ "compete_products_config": convert_config_to_dict(
554
+ self.compete_products_config, CompeteProductsConfig
555
+ ),
556
+ "cgr_connected_components_config": {}
557
+ if self.cgr_connected_components_config is not None
558
+ else None,
559
+ "rings_change_config": {} if self.rings_change_config is not None else None,
560
+ "strange_carbons_config": {}
561
+ if self.strange_carbons_config is not None
562
+ else None,
563
+ "no_reaction_config": {} if self.no_reaction_config is not None else None,
564
+ "multi_center_config": {} if self.multi_center_config is not None else None,
565
+ "wrong_ch_breaking_config": {}
566
+ if self.wrong_ch_breaking_config is not None
567
+ else None,
568
+ "cc_sp3_breaking_config": {}
569
+ if self.cc_sp3_breaking_config is not None
570
+ else None,
571
+ "cc_ring_breaking_config": {}
572
+ if self.cc_ring_breaking_config is not None
573
+ else None,
574
+ "rebalance_reaction": self.rebalance_reaction,
575
+ "remove_reagents": self.remove_reagents,
576
+ "reagents_max_size": self.reagents_max_size,
577
+ "remove_small_molecules": self.remove_small_molecules,
578
+ "small_molecules_max_size": self.small_molecules_max_size,
579
+ }
580
+
581
+ filtered_config_dict = {k: v for k, v in config_dict.items() if v is not None}
582
+
583
+ return filtered_config_dict
584
+
585
+ @staticmethod
586
+ def from_dict(config_dict: Dict[str, Any]):
587
+ """
588
+ Create an instance of ReactionCheckConfig from a dictionary.
589
+ """
590
+ # Instantiate configuration objects if their corresponding dictionary is present
591
+ dynamic_bonds_config = (
592
+ DynamicBondsConfig(**config_dict["dynamic_bonds_config"])
593
+ if "dynamic_bonds_config" in config_dict
594
+ else None
595
+ )
596
+ small_molecules_config = (
597
+ SmallMoleculesConfig(**config_dict["small_molecules_config"])
598
+ if "small_molecules_config" in config_dict
599
+ else None
600
+ )
601
+ compete_products_config = (
602
+ CompeteProductsConfig(**config_dict["compete_products_config"])
603
+ if "compete_products_config" in config_dict
604
+ else None
605
+ )
606
+ cgr_connected_components_config = (
607
+ CGRConnectedComponentsConfig()
608
+ if "cgr_connected_components_config" in config_dict
609
+ else None
610
+ )
611
+ rings_change_config = (
612
+ RingsChangeConfig()
613
+ if "rings_change_config" in config_dict
614
+ else None
615
+ )
616
+ strange_carbons_config = (
617
+ StrangeCarbonsConfig()
618
+ if "strange_carbons_config" in config_dict
619
+ else None
620
+ )
621
+ no_reaction_config = (
622
+ NoReactionConfig()
623
+ if "no_reaction_config" in config_dict
624
+ else None
625
+ )
626
+ multi_center_config = (
627
+ MultiCenterConfig()
628
+ if "multi_center_config" in config_dict
629
+ else None
630
+ )
631
+ wrong_ch_breaking_config = (
632
+ WrongCHBreakingConfig()
633
+ if "wrong_ch_breaking_config" in config_dict
634
+ else None
635
+ )
636
+ cc_sp3_breaking_config = (
637
+ CCsp3BreakingConfig()
638
+ if "cc_sp3_breaking_config" in config_dict
639
+ else None
640
+ )
641
+ cc_ring_breaking_config = (
642
+ CCRingBreakingConfig()
643
+ if "cc_ring_breaking_config" in config_dict
644
+ else None
645
+ )
646
+
647
+ # Extract other simple configuration parameters
648
+ rebalance_reaction = config_dict.get("rebalance_reaction", False)
649
+ remove_reagents = config_dict.get("remove_reagents", True)
650
+ reagents_max_size = config_dict.get("reagents_max_size", 7)
651
+ remove_small_molecules = config_dict.get("remove_small_molecules", False)
652
+ small_molecules_max_size = config_dict.get("small_molecules_max_size", 6)
653
+
654
+ return ReactionCheckConfig(
655
+ dynamic_bonds_config=dynamic_bonds_config,
656
+ small_molecules_config=small_molecules_config,
657
+ compete_products_config=compete_products_config,
658
+ cgr_connected_components_config=cgr_connected_components_config,
659
+ rings_change_config=rings_change_config,
660
+ strange_carbons_config=strange_carbons_config,
661
+ no_reaction_config=no_reaction_config,
662
+ multi_center_config=multi_center_config,
663
+ wrong_ch_breaking_config=wrong_ch_breaking_config,
664
+ cc_sp3_breaking_config=cc_sp3_breaking_config,
665
+ cc_ring_breaking_config=cc_ring_breaking_config,
666
+ rebalance_reaction=rebalance_reaction,
667
+ remove_reagents=remove_reagents,
668
+ reagents_max_size=reagents_max_size,
669
+ remove_small_molecules=remove_small_molecules,
670
+ small_molecules_max_size=small_molecules_max_size,
671
+ )
672
+
673
+ @staticmethod
674
+ def from_yaml(file_path):
675
+ """
676
+ Deserializes a YAML file into a ReactionCheckConfig object.
677
+ """
678
+ with open(file_path, "r") as file:
679
+ config_dict = yaml.safe_load(file)
680
+ return ReactionCheckConfig.from_dict(config_dict)
681
+
682
+ def _validate_params(self, params: Dict[str, Any]):
683
+ if not isinstance(params["rebalance_reaction"], bool):
684
+ raise ValueError("rebalance_reaction must be a boolean.")
685
+
686
+ if not isinstance(params["remove_reagents"], bool):
687
+ raise ValueError("remove_reagents must be a boolean.")
688
+
689
+ if not isinstance(params["reagents_max_size"], int):
690
+ raise ValueError("reagents_max_size must be an int.")
691
+
692
+ if not isinstance(params["remove_small_molecules"], bool):
693
+ raise ValueError("remove_small_molecules must be a boolean.")
694
+
695
+ if not isinstance(params["small_molecules_max_size"], int):
696
+ raise ValueError("small_molecules_max_size must be an int.")
697
+
698
+ def create_checkers(self):
699
+ checker_instances = []
700
+
701
+ if self.dynamic_bonds_config is not None:
702
+ checker_instances.append(
703
+ DynamicBondsChecker.from_config(self.dynamic_bonds_config)
704
+ )
705
+
706
+ if self.small_molecules_config is not None:
707
+ checker_instances.append(
708
+ SmallMoleculesChecker.from_config(self.small_molecules_config)
709
+ )
710
+
711
+ if self.strange_carbons_config is not None:
712
+ checker_instances.append(
713
+ StrangeCarbonsChecker.from_config(self.strange_carbons_config)
714
+ )
715
+
716
+ if self.compete_products_config is not None:
717
+ checker_instances.append(
718
+ CompeteProductsChecker.from_config(self.compete_products_config)
719
+ )
720
+
721
+ if self.cgr_connected_components_config is not None:
722
+ checker_instances.append(
723
+ CGRConnectedComponentsChecker.from_config(
724
+ self.cgr_connected_components_config
725
+ )
726
+ )
727
+
728
+ if self.rings_change_config is not None:
729
+ checker_instances.append(
730
+ RingsChangeChecker.from_config(self.rings_change_config)
731
+ )
732
+
733
+ if self.no_reaction_config is not None:
734
+ checker_instances.append(
735
+ NoReactionChecker.from_config(self.no_reaction_config)
736
+ )
737
+
738
+ if self.multi_center_config is not None:
739
+ checker_instances.append(
740
+ MultiCenterChecker.from_config(self.multi_center_config)
741
+ )
742
+
743
+ if self.wrong_ch_breaking_config is not None:
744
+ checker_instances.append(
745
+ WrongCHBreakingChecker.from_config(self.wrong_ch_breaking_config)
746
+ )
747
+
748
+ if self.cc_sp3_breaking_config is not None:
749
+ checker_instances.append(
750
+ CCsp3BreakingChecker.from_config(self.cc_sp3_breaking_config)
751
+ )
752
+
753
+ if self.cc_ring_breaking_config is not None:
754
+ checker_instances.append(
755
+ CCRingBreakingChecker.from_config(self.cc_ring_breaking_config)
756
+ )
757
+
758
+ return checker_instances
759
+
760
+
761
+ def tanimoto_kernel(x, y):
762
+ """
763
+ Calculate the Tanimoto coefficient between each element of arrays x and y.
764
+ """
765
+ x = x.astype(np.float64)
766
+ y = y.astype(np.float64)
767
+ x_dot = np.dot(x, y.T)
768
+ x2 = np.sum(x**2, axis=1)
769
+ y2 = np.sum(y**2, axis=1)
770
+
771
+ denominator = np.array([x2] * len(y2)).T + np.array([y2] * len(x2)) - x_dot
772
+ result = np.divide(x_dot, denominator, out=np.zeros_like(x_dot), where=denominator != 0)
773
+
774
+ return result
775
+
776
+
777
+ def remove_file_if_exists(directory: Path, file_names): # TODO not used
778
+ for file_name in file_names:
779
+ file_path = directory / file_name
780
+ if file_path.is_file():
781
+ file_path.unlink()
782
+ logging.warning(f"Removed {file_path}")
783
+
784
+
785
+ def filter_reaction(reaction: ReactionContainer, config: ReactionCheckConfig, checkers: list):
786
+
787
+ is_filtered = False
788
+ if config.remove_small_molecules:
789
+ new_reaction = remove_small_molecules(reaction, number_of_atoms=config.small_molecules_max_size)
790
+ else:
791
+ new_reaction = reaction.copy()
792
+
793
+ if new_reaction is None:
794
+ is_filtered = True
795
+
796
+ if config.remove_reagents and not is_filtered:
797
+ new_reaction = remove_reagents(
798
+ new_reaction,
799
+ keep_reagents=True,
800
+ reagents_max_size=config.reagents_max_size,
801
+ )
802
+
803
+ if new_reaction is None:
804
+ is_filtered = True
805
+ new_reaction = reaction.copy()
806
+ # TODO: if the reaction contains only reagents, should it be kept as is?
807
+
808
+ if not is_filtered:
809
+ if config.rebalance_reaction:
810
+ new_reaction = rebalance_reaction(new_reaction)
811
+ for checker in checkers:
812
+ try: # TODO CGRTools: ValueError: mapping of graphs is not disjoint
813
+ if checker(new_reaction):
814
+ # If checker returns True it means the reaction doesn't pass the check
815
+ new_reaction.meta["filtration_log"] = checker.__class__.__name__
816
+ is_filtered = True
817
+ except:
818
+ is_filtered = True
819
+
820
+
821
+
822
+ return is_filtered, new_reaction
823
+
824
+
825
+ @ray.remote
826
+ def process_batch(batch, config: ReactionCheckConfig, checkers):
827
+ results = []
828
+ for index, reaction in batch:
829
+ try: # TODO CGRtools.exceptions.MappingError: atoms with number {52} not equal
830
+ is_filtered, processed_reaction = filter_reaction(reaction, config, checkers)
831
+ results.append((index, is_filtered, processed_reaction))
832
+ except:
833
+ results.append((index, True, reaction))
834
+ return results
835
+
836
+
837
+ def process_completed_batches(futures, result_file, pbar, treated: int = 0, passed_filters: int = 0):
838
+ done, _ = ray.wait(list(futures.keys()), num_returns=1)
839
+ completed_batch = ray.get(done[0])
840
+
841
+ # Write results of the completed batch to file
842
+ now_treated = 0
843
+ for index, is_filtered, reaction in completed_batch:
844
+ now_treated += 1
845
+ if not is_filtered:
846
+ result_file.write(reaction.meta['init_smiles'])
847
+ passed_filters += 1
848
+
849
+ # Remove completed future and update progress bar
850
+ del futures[done[0]]
851
+ pbar.update(now_treated)
852
+ treated += now_treated
853
+
854
+ return treated, passed_filters
855
+
856
+
857
+ def filter_reactions(
858
+ config: ReactionCheckConfig,
859
+ reaction_database_path: str,
860
+ result_reactions_file_name: str = "reaction_data_filtered.smi",
861
+ append_results: bool = False,
862
+ num_cpus: int = 1,
863
+ batch_size: int = 100,
864
+ ) -> None:
865
+ """
866
+ Processes a database of chemical reactions, applying checks based on the provided configuration,
867
+ and writes the results to specified files. All configurations are provided by the ReactionCheckConfig object.
868
+
869
+ :param config: ReactionCheckConfig object containing all configuration settings.
870
+ :param reaction_database_path: Path to the reaction database file.
871
+ :param result_reactions_file_name: Name for the file containing cleaned reactions.
872
+ :param append_results: Flag indicating whether to append results to existing files.
873
+ :param num_cpus: Number of CPUs to use for processing.
874
+ :param batch_size: Size of the batch for processing reactions.
875
+ :return: None. The function writes the reactions that passed all the filters to the
877
+ result file.
877
+ """
878
+
879
+ checkers = config.create_checkers()
880
+
881
+ ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)
882
+ max_concurrent_batches = num_cpus # Limit the number of concurrent batches
883
+
884
+ with ReactionReader(reaction_database_path) as reactions, \
885
+ ReactionWriter(result_reactions_file_name, append_results) as result_file:
886
+
887
+ pbar = tqdm(reactions, leave=True) # TODO fix progress bars
888
+
889
+ futures = {}
890
+ batch = []
891
+ treated = filtered = 0
892
+ for index, reaction in enumerate(reactions):
893
+ reaction.meta["reaction_index"] = index
894
+ batch.append((index, reaction))
895
+ if len(batch) == batch_size:
896
+ future = process_batch.remote(batch, config, checkers)
897
+ futures[future] = None
898
+ batch = []
899
+
900
+ # Check and process completed tasks if we've reached the concurrency limit
901
+ while len(futures) >= max_concurrent_batches:
902
+ treated, filtered = process_completed_batches(futures, result_file, pbar, treated, filtered)
903
+
904
+ # Process the last batch if it's not empty
905
+ if batch:
906
+ future = process_batch.remote(batch, config, checkers)
907
+ futures[future] = None
908
+
909
+ # Process remaining batches
910
+ while futures:
911
+ treated, filtered = process_completed_batches(futures, result_file, pbar, treated, filtered)
912
+
913
+ pbar.close()
914
+
915
+ ray.shutdown()
916
+ print(f'Initial number of reactions: {treated}')
917
+ print(f'Removed number of reactions: {treated - filtered}')
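A usage sketch for the filtering pipeline defined above. A checker is enabled by passing its config object (create_checkers instantiates only the checkers whose config is not None); the file paths below are hypothetical:

    from SynTool.chem.data.filtering import (ReactionCheckConfig, DynamicBondsConfig,
                                             SmallMoleculesConfig, NoReactionConfig,
                                             filter_reactions)

    config = ReactionCheckConfig(
        dynamic_bonds_config=DynamicBondsConfig(min_bonds_number=1, max_bonds_number=6),
        small_molecules_config=SmallMoleculesConfig(limit=6),
        no_reaction_config=NoReactionConfig(),
    )
    filter_reactions(config,
                     reaction_database_path="reactions_clean.smi",         # hypothetical input
                     result_reactions_file_name="reactions_filtered.smi",  # reactions that passed all checks
                     num_cpus=8,
                     batch_size=100)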
SynTool/chem/data/mapping.py ADDED
@@ -0,0 +1,96 @@
1
+ from pathlib import Path
2
+ from os.path import splitext
3
+ from typing import Union
4
+ from tqdm import tqdm
5
+
6
+ from chython import smiles, RDFRead, RDFWrite, ReactionContainer
7
+ from chython.exceptions import MappingError, IncorrectSmiles
8
+
9
+ from SynTool.utils import path_type
10
+
11
+
12
+ def remove_reagents_and_map(rea: ReactionContainer, keep_reagent: bool = False) -> Union[ReactionContainer, None]:
13
+ """
14
+ Maps atoms of the reaction using chytorch.
15
+
16
+ :param rea: reaction to map
17
+ :type rea: ReactionContainer
18
+ :param keep_reagent: whether to keep the reagents (if False, they are removed)
19
+ :type keep_reagent: bool
20
+
21
+ :return: ReactionContainer or None
22
+ """
23
+ try:
24
+ rea.reset_mapping()
25
+ except MappingError:
26
+ rea.reset_mapping() # Successive reset_mapping works
27
+ if not keep_reagent:
28
+ try:
29
+ rea.remove_reagents()
30
+ except:
31
+ return None
32
+ return rea
33
+
34
+
35
+ def remove_reagents_and_map_from_file(input_file: path_type, output_file: path_type, keep_reagent: bool = False) -> None:
36
+ """
37
+ Reads a file of reactions and maps atoms of the reactions using chytorch.
38
+
39
+ :param input_file: the path and name of the input file
40
+ :type input_file: path_type
41
+ :param output_file: the path and name of the output file
42
+ :type output_file: path_type
43
+ :param keep_reagent: whether to keep the reagents (if False, they are removed)
44
+ :type keep_reagent: bool
45
+
46
+ :return: None
47
+ """
48
+ input_file = str(Path(input_file).resolve(strict=True))
49
+ _, input_ext = splitext(input_file)
50
+ if input_ext == ".smi":
51
+ input_file = open(input_file, "r")
52
+ elif input_ext == ".rdf":
53
+ input_file = RDFRead(input_file, indexable=True)
54
+ else:
55
+ raise ValueError("File extension not recognized. File:", input_file,
56
+ "- Please use smi or rdf file")
57
+ enumerator = input_file if input_ext == ".rdf" else input_file.readlines()
58
+
59
+ _, out_ext = splitext(output_file)
60
+ if out_ext == ".smi":
61
+ output_file = open(output_file, "w")
62
+ elif out_ext == ".rdf":
63
+ output_file = RDFWrite(output_file)
64
+ else:
65
+ raise ValueError("File extension not recognized. File:", output_file,
66
+ "- Please use smi or rdf file")
67
+
68
+ mapping_errors = 0
69
+ parsing_errors = 0
70
+ for rea_raw in tqdm(enumerator):
71
+ try:
72
+ rea = smiles(rea_raw.strip('\n')) if input_ext == ".smi" else rea_raw
73
+ except IncorrectSmiles:
74
+ parsing_errors += 1
75
+ continue
76
+ try:
77
+ rea_mapped = remove_reagents_and_map(rea, keep_reagent)
78
+ except MappingError:
79
+ try:
80
+ rea_mapped = remove_reagents_and_map(smiles(str(rea)), keep_reagent)
81
+ except MappingError:
82
+ mapping_errors += 1
83
+ continue
84
+ if rea_mapped:
85
+ rea_output = format(rea, "m") + "\n" if out_ext == ".smi" else rea
86
+ output_file.write(rea_output)
87
+ else:
88
+ mapping_errors += 1
89
+
90
+ input_file.close()
91
+ output_file.close()
92
+
93
+ if parsing_errors:
94
+ print(parsing_errors, "reactions couldn't be parsed")
95
+ if mapping_errors:
96
+ print(mapping_errors, "reactions couldn't be mapped")
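A usage sketch for the atom-mapping step defined above; remove_reagents_and_map_from_file accepts .smi or .rdf for both input and output, and the file names below are hypothetical:

    from SynTool.chem.data.mapping import remove_reagents_and_map_from_file

    # Re-maps atoms with chython's reset_mapping() and removes reagents unless keep_reagent=True.
    remove_reagents_and_map_from_file(input_file="reactions_raw.smi",      # hypothetical
                                      output_file="reactions_mapped.smi",  # hypothetical
                                      keep_reagent=False)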
SynTool/chem/data/mapping.py.bk ADDED
@@ -0,0 +1,90 @@
1
+ from pathlib import Path
2
+ from os.path import splitext
3
+ from typing import Union
4
+ from tqdm import tqdm
5
+
6
+ from chython import smiles, RDFRead, RDFWrite, ReactionContainer
7
+ from chython.exceptions import MappingError
8
+
9
+ from Syntool.utils import path_type
10
+
11
+
12
+ def remove_reagents_and_map(rea: ReactionContainer) -> Union[ReactionContainer, None]:
13
+ """
14
+ Maps atoms of the reaction using chytorch.
15
+
16
+ :param rea: reaction to map
17
+ :type rea: ReactionContainer
18
+
19
+ :return: ReactionContainer or None
20
+ """
21
+ try:
22
+ rea.reset_mapping()
23
+ except MappingError:
24
+ rea.reset_mapping()
25
+ try:
26
+ rea.remove_reagents()
27
+ return rea
28
+ except:
29
+ # print("Error", str(rea))
30
+ return None
31
+
32
+
33
+ def remove_reagents_and_map_from_file(input_file: path_type, output_file: path_type) -> None:
34
+ """
35
+ Reads a file of reactions and maps atoms of the reactions using chytorch.
36
+
37
+ :param input_file: the path and name of the input file
38
+ :type input_file: path_type
39
+
40
+ :param output_file: the path and name of the output file
41
+ :type output_file: path_type
42
+
43
+ :return: None
44
+ """
45
+ input_file = str(Path(input_file).resolve(strict=True))
46
+ _, input_ext = splitext(input_file)
47
+ if input_ext == ".smi":
48
+ input_file = open(input_file, "r")
49
+ elif input_ext == ".rdf":
50
+ input_file = RDFRead(input_file, indexable=True)
51
+ else:
52
+ raise ValueError("File extension not recognized. File:", input_file,
53
+ "- Please use smi or rdf file")
54
+ enumerator = input_file if input_ext == ".rdf" else input_file.readlines()
55
+
56
+ _, out_ext = splitext(output_file)
57
+ if out_ext == ".smi":
58
+ output_file = open(output_file, "w")
59
+ elif out_ext == ".rdf":
60
+ output_file = RDFWrite(output_file)
61
+ else:
62
+ raise ValueError("File extension not recognized. File:", output_file,
63
+ "- Please use smi or rdf file")
64
+
65
+ mapping_errors = 0
66
+ parsing_errors = 0
67
+ for rea_raw in tqdm(enumerator):
68
+ try:
69
+ rea = smiles(rea_raw.strip('\n')) if input_ext == ".smi" else rea_raw
70
+ except:
71
+ parsing_errors += 1
72
+ print("Error", parsing_errors, rea_raw)
73
+ continue
74
+ try:
75
+ rea_mapped = remove_reagents_and_map(rea)
76
+ except:
77
+ parsing_errors += 1
78
+ print("Error for,", rea)
79
+ continue
80
+ if rea_mapped:
81
+ rea_output = format(rea, "m") + "\n" if out_ext == ".smi" else rea
82
+ output_file.write(rea_output)
83
+ else:
84
+ mapping_errors += 1
85
+
86
+ input_file.close()
87
+ output_file.close()
88
+
89
+ if mapping_errors:
90
+ print(mapping_errors, "reactions couldn't be mapped")
SynTool/chem/data/standardizer.py ADDED
@@ -0,0 +1,604 @@
1
+ #############################################################################
2
+ # Code taken from https://github.com/Laboratoire-de-Chemoinformatique/Reaction_Data_Cleaning
3
+ # Reaction_Data_Cleaning/scripts/standardizer.py
4
+ # version as is from commit 793475e54d8b2c7f714165a61e4eb439435d7d92
5
+ # DOI 10.1002/minf.202100119
6
+ #############################################################################
7
+ # Chemical reactions data curation best practices
8
+ # including optimized RDTool
9
+ #############################################################################
10
+ # GNU LGPL https://www.gnu.org/licenses/lgpl-3.0.en.html
11
+ #############################################################################
12
+ # Corresponding Authors: Timur Madzhidov and Alexandre Varnek
13
+ # Corresponding Authors' emails: [email protected] and [email protected]
14
+ # Main contributors: Arkadii Lin, Natalia Duybankova, Ramil Nugmanov, Rail Suleymanov and Timur Madzhidov
15
+ # Copyright: Copyright 2020,
16
+ # MaDeSmart, Machine Design of Small Molecules by AI
17
+ # VLAIO project HBC.2018.2287
18
+ # Credits: Kazan Federal University, Russia
19
+ # University of Strasbourg, France
20
+ # University of Linz, Austria
21
+ # University of Leuven, Belgium
22
+ # Janssen Pharmaceutica N.V., Beerse, Belgium
23
+ # Rail Suleymanov, Arcadia, St. Petersburg, Russia
24
+ # License: GNU LGPL https://www.gnu.org/licenses/lgpl-3.0.en.html
25
+ # Version: 00.02
26
+ #############################################################################
27
+
28
+ from CGRtools.files import RDFRead, RDFWrite, SDFWrite, SDFRead, SMILESRead
29
+ from CGRtools.containers import MoleculeContainer, ReactionContainer
30
+ import logging
31
+ from ordered_set import OrderedSet
32
+ import os
33
+ import io
34
+ import pathlib
35
+ from pathlib import PurePosixPath
36
+
37
+
38
+ class Standardizer:
39
+ def __init__(self, skip_errors=False, log_file=None, keep_unbalanced_ions=False, id_tag='Reaction_ID',
40
+ action_on_isotopes=0, keep_reagents=False, logger=None, ignore_mapping=False, jvm_path=None,
41
+ rdkit_dearomatization=False, remove_unchanged_parts=True, skip_tautomerize=True,
42
+ jchem_path=None, add_reagents_to_reactants=False) -> None:
43
+ if logger is None:
44
+ self.logger = self._config_log(log_file, logger_name='logger')
45
+ else:
46
+ self.logger = logger
47
+ self._skip_errors = skip_errors
48
+ self._keep_unbalanced_ions = keep_unbalanced_ions
49
+ self._id_tag = id_tag
50
+ self._action_on_isotopes = action_on_isotopes
51
+ self._keep_reagents = keep_reagents
52
+ self._ignore_mapping = ignore_mapping
53
+ self._remove_unchanged_parts_flag = remove_unchanged_parts
54
+ self._skip_tautomerize = skip_tautomerize
55
+ self._dearomatize_by_rdkit = rdkit_dearomatization
56
+ self._reagents_to_reactants = add_reagents_to_reactants
57
+ if not skip_tautomerize:
58
+ if jvm_path:
59
+ os.environ['JDK_HOME'] = jvm_path
60
+ os.environ['JAVA_HOME'] = jvm_path
61
+ os.environ['PATH'] += f';{PurePosixPath(jvm_path).joinpath("bin").joinpath("server")};' \
62
+ f'{PurePosixPath(jvm_path).joinpath("bin").joinpath("server")};'
63
+ if jchem_path:
64
+ import jnius_config
65
+ jnius_config.add_classpath(jchem_path)
66
+ from jnius import autoclass
67
+ Standardizer = autoclass('chemaxon.standardizer.Standardizer')
68
+ self._Molecule = autoclass('chemaxon.struc.Molecule')
69
+ self._MolHandler = autoclass('chemaxon.util.MolHandler')
70
+ self._standardizer = Standardizer('tautomerize')
71
+
72
+ def standardize_file(self, input_file=None) -> OrderedSet:
73
+ """
74
+ Standardize a set of reactions in a file. Returns an ordered set of ReactionContainer objects passed the
75
+ standardization protocol.
76
+ :param input_file: str
77
+ :return: OrderedSet
78
+ """
79
+ if pathlib.Path(input_file).suffix == '.rdf':
80
+ data = self._read_RDF(input_file)
81
+ elif pathlib.Path(input_file).suffix == '.smi' or pathlib.Path(input_file).suffix == '.smiles':
82
+ data = self._read_SMILES(input_file)
83
+ else:
84
+ raise ValueError('Data format is not recognized!')
85
+
86
+ print("{0} reactions passed..".format(len(data)))
87
+ return data
88
+
89
+ def _read_RDF(self, input_file) -> OrderedSet:
90
+ """
91
+ Reads an RDF file. Returns an ordered set of ReactionContainer objects passed the standardization protocol.
92
+ :param input_file: str
93
+ :return: OrderedSet
94
+ """
95
+ data = OrderedSet()
96
+ self.logger.info('Start..')
97
+ with RDFRead(input_file, ignore=self._ignore_mapping, store_log=True, remap=self._ignore_mapping) as ifile, \
98
+ open(input_file) as meta_searcher:
99
+ for reaction in ifile._data:
100
+ if isinstance(reaction, tuple):
101
+ meta_searcher.seek(reaction.position)
102
+ flag = False
103
+ for line in meta_searcher:
104
+ if flag and '$RFMT' in line:
105
+ self.logger.critical(f'Reaction id extraction problem rised for the reaction '
106
+ f'#{reaction.number + 1}: a reaction id was expected but $RFMT line '
107
+ f'was found!')
108
+ if flag:
109
+ self.logger.critical(f'Reaction {line.strip().split()[1]}: Parser has returned an error '
110
+ f'message\n{reaction.log}')
111
+ break
112
+ elif '$RFMT' in line:
113
+ self.logger.critical(f'Reaction #{reaction.number + 1} has no reaction id!')
114
+ elif f'$DTYPE {self._id_tag}' in line:
115
+ flag = True
116
+ continue
117
+ standardized_reaction = self.standardize(reaction)
118
+ if standardized_reaction:
119
+ if standardized_reaction not in data:
120
+ data.add(standardized_reaction)
121
+ else:
122
+ i = data.index(standardized_reaction)
123
+ if 'Extraction_IDs' not in data[i].meta:
124
+ data[i].meta['Extraction_IDs'] = ''
125
+ data[i].meta['Extraction_IDs'] = ','.join(data[i].meta['Extraction_IDs'].split(',') +
126
+ [reaction.meta[self._id_tag]])
127
+ self.logger.info('Reaction {0} is a duplicate of the reaction {1}..'
128
+ .format(reaction.meta[self._id_tag], data[i].meta[self._id_tag]))
129
+ return data
130
+
131
+ def _read_SMILES(self, input_file) -> OrderedSet:
132
+ """
133
+ Reads a SMILES file. Returns an ordered set of ReactionContainer objects that passed the standardization protocol.
134
+ :param input_file: str
135
+ :return: OrderedSet
136
+ """
137
+ data = OrderedSet()
138
+ self.logger.info('Start..')
139
+ with SMILESRead(input_file, ignore=True, store_log=True, remap=self._ignore_mapping, header=True) as ifile, \
140
+ open(input_file) as meta_searcher:
141
+ id_tag_position = meta_searcher.readline().strip().split().index(self._id_tag)
142
+ if id_tag_position is None or id_tag_position == 0:
143
+ self.logger.critical(f'No reaction ID tag was found in the header!')
144
+ raise ValueError(f'No reaction ID tag was found in the header!')
145
+ for reaction in ifile._data:
146
+ if isinstance(reaction, tuple):
147
+ meta_searcher.seek(reaction.position)
148
+ line = meta_searcher.readline().strip().split()
149
+ if len(line) <= id_tag_position:
150
+ self.logger.critical(f'No reaction ID tag was found in line {reaction.number}!')
151
+ raise ValueError(f'No reaction ID tag was found in line {reaction.number}!')
152
+ r_id = line[id_tag_position]
153
+ self.logger.critical(f'Reaction {r_id}: Parser has returned an error message\n{reaction.log}')
154
+ continue
155
+
156
+ standardized_reaction = self.standardize(reaction)
157
+ if standardized_reaction:
158
+ if standardized_reaction not in data:
159
+ data.add(standardized_reaction)
160
+ else:
161
+ i = data.index(standardized_reaction)
162
+ if 'Extraction_IDs' not in data[i].meta:
163
+ data[i].meta['Extraction_IDs'] = ''
164
+ data[i].meta['Extraction_IDs'] = ','.join(data[i].meta['Extraction_IDs'].split(',') +
165
+ [reaction.meta[self._id_tag]])
166
+ self.logger.info('Reaction {0} is a duplicate of the reaction {1}..'
167
+ .format(reaction.meta[self._id_tag], data[i].meta[self._id_tag]))
168
+ return data
169
+
170
+ def standardize(self, reaction: ReactionContainer) -> ReactionContainer:
171
+ """
172
+ Standardization protocol: transform functional groups, kekulize, remove explicit hydrogens,
173
+ check for radicals (the reaction is discarded if any are found), check for isotopes, regroup ions (if the
174
+ total charge of reactants and/or products is not zero and the 'keep_unbalanced_ions' option is False,
175
+ which is the default, such reactions are discarded; if it is set to True, they are kept), check valences
176
+ (the reaction is discarded on valence errors), aromatize (Thiele method), fix the mapping of symmetric
177
+ functional groups if present, and remove unchanged parts.
178
+ :param reaction: ReactionContainer
179
+ :return: ReactionContainer
180
+ """
181
+ self.logger.info('Reaction {0}..'.format(reaction.meta[self._id_tag]))
182
+ try:
183
+ reaction.standardize()
184
+ except:
185
+ self.logger.exception(
186
+ 'Reaction {0}: Cannot standardize functional groups..'.format(reaction.meta[self._id_tag]))
187
+ if not self._skip_errors:
188
+ raise Exception(
189
+ 'Reaction {0}: Cannot standardize functional groups..'.format(reaction.meta[self._id_tag]))
190
+ else:
191
+ return
192
+ try:
193
+ reaction.kekule()
194
+ except:
195
+ self.logger.exception('Reaction {0}: Cannot kekulize..'.format(reaction.meta[self._id_tag]))
196
+ if not self._skip_errors:
197
+ raise Exception('Reaction {0}: Cannot kekulize..'.format(reaction.meta[self._id_tag]))
198
+ else:
199
+ return
200
+ try:
201
+ if self._check_valence(reaction):
202
+ self.logger.info(
203
+ 'Reaction {0}: Bad valence: {1}'.format(reaction.meta[self._id_tag], reaction.meta['mistake']))
204
+ return
205
+ except:
206
+ self.logger.exception('Reaction {0}: Cannot check valence..'.format(reaction.meta[self._id_tag]))
207
+ if not self._skip_errors:
208
+ self.logger.critical('Stop the algorithm!')
209
+ raise Exception('Reaction {0}: Cannot check valence..'.format(reaction.meta[self._id_tag]))
210
+ else:
211
+ return
212
+ try:
213
+ if not self._skip_tautomerize:
214
+ reaction = self._tautomerize(reaction)
215
+ except:
216
+ self.logger.exception('Reaction {0}: Cannot tautomerize..'.format(reaction.meta[self._id_tag]))
217
+ if not self._skip_errors:
218
+ raise Exception('Reaction {0}: Cannot tautomerize..'.format(reaction.meta[self._id_tag]))
219
+ else:
220
+ return
221
+ try:
222
+ reaction.implicify_hydrogens()
223
+ except:
224
+ self.logger.exception(
225
+ 'Reaction {0}: Cannot remove explicit hydrogens..'.format(reaction.meta[self._id_tag]))
226
+ if not self._skip_errors:
227
+ raise Exception('Reaction {0}: Cannot remove explicit hydrogens..'.format(reaction.meta[self._id_tag]))
228
+ else:
229
+ return
230
+ try:
231
+ if self._check_radicals(reaction):
232
+ self.logger.info('Reaction {0}: Radicals were found..'.format(reaction.meta[self._id_tag]))
233
+ return
234
+ except:
235
+ self.logger.exception('Reaction {0}: Cannot check radicals..'.format(reaction.meta[self._id_tag]))
236
+ if not self._skip_errors:
237
+ raise Exception('Reaction {0}: Cannot check radicals..'.format(reaction.meta[self._id_tag]))
238
+ else:
239
+ return
240
+ try:
241
+ if self._action_on_isotopes == 1 and self._check_isotopes(reaction):
242
+ self.logger.info('Reaction {0}: Isotopes were found..'.format(reaction.meta[self._id_tag]))
243
+ return
244
+ elif self._action_on_isotopes == 2 and self._check_isotopes(reaction):
245
+ reaction.clean_isotopes()
246
+ self.logger.info('Reaction {0}: Isotopes were removed but the reaction was kept..'.format(
247
+ reaction.meta[self._id_tag]))
248
+ except:
249
+ self.logger.exception('Reaction {0}: Cannot check for isotopes..'.format(reaction.meta[self._id_tag]))
250
+ if not self._skip_errors:
251
+ raise Exception('Reaction {0}: Cannot check for isotopes..'.format(reaction.meta[self._id_tag]))
252
+ else:
253
+ return
254
+ try:
255
+ reaction, return_code = self._split_ions(reaction)
256
+ if return_code == 1:
257
+ self.logger.info('Reaction {0}: Ions were split..'.format(reaction.meta[self._id_tag]))
258
+ elif return_code == 2:
259
+ self.logger.info('Reaction {0}: Ions were split but the reaction is imbalanced..'.format(
260
+ reaction.meta[self._id_tag]))
261
+ if not self._keep_unbalanced_ions:
262
+ return
263
+ except:
264
+ self.logger.exception('Reaction {0}: Cannot group ions..'.format(reaction.meta[self._id_tag]))
265
+ if not self._skip_errors:
266
+ raise Exception('Reaction {0}: Cannot group ions..'.format(reaction.meta[self._id_tag]))
267
+ else:
268
+ return
269
+ try:
270
+ reaction.thiele()
271
+ except:
272
+ self.logger.exception('Reaction {0}: Cannot aromatize..'.format(reaction.meta[self._id_tag]))
273
+ if not self._skip_errors:
274
+ raise Exception('Reaction {0}: Cannot aromatize..'.format(reaction.meta[self._id_tag]))
275
+ else:
276
+ return
277
+ try:
278
+ reaction.fix_mapping()
279
+ except:
280
+ self.logger.exception('Reaction {0}: Cannot fix mapping..'.format(reaction.meta[self._id_tag]))
281
+ if not self._skip_errors:
282
+ raise Exception('Reaction {0}: Cannot fix mapping..'.format(reaction.meta[self._id_tag]))
283
+ else:
284
+ return
285
+ try:
286
+ if self._remove_unchanged_parts_flag:
287
+ reaction = self._remove_unchanged_parts(reaction)
288
+ if not reaction.reactants and reaction.products:
289
+ self.logger.info('Reaction {0}: Reactants are empty..'.format(reaction.meta[self._id_tag]))
290
+ return
291
+ if not reaction.products and reaction.reactants:
292
+ self.logger.info('Reaction {0}: Products are empty..'.format(reaction.meta[self._id_tag]))
293
+ return
294
+ if not reaction.reactants and not reaction.products:
295
+ self.logger.exception(
296
+ 'Reaction {0}: Cannot remove unchanged parts or the reaction is empty..'.format(
297
+ reaction.meta[self._id_tag]))
298
+ return
299
+ except:
300
+ self.logger.exception('Reaction {0}: Cannot remove unchanged parts or the reaction is empty..'.format(
301
+ reaction.meta[self._id_tag]))
302
+ if not self._skip_errors:
303
+ raise Exception('Reaction {0}: Cannot remove unchanged parts or the reaction is empty..'.format(
304
+ reaction.meta[self._id_tag]))
305
+ else:
306
+ return
307
+ self.logger.debug('Reaction {0} is done..'.format(reaction.meta[self._id_tag]))
308
+ return reaction
309
+
310
+ def write(self, output_file: str, data: OrderedSet) -> None:
311
+ """
312
+ Dump a set of reactions.
313
+ :param data: OrderedSet
314
+ :param output_file: str
315
+ :return: None
316
+ """
317
+ with RDFWrite(output_file) as out:
318
+ for r in data:
319
+ out.write(r)
320
+
321
+ def _check_valence(self, reaction: ReactionContainer) -> bool:
322
+ """
323
+ Checks valences.
324
+ :param reaction: ReactionContainer
325
+ :return: bool
326
+ """
327
+ mistakes = []
328
+ for molecule in (reaction.reactants + reaction.products + reaction.reagents):
329
+ valence_mistakes = molecule.check_valence()
330
+ if valence_mistakes:
331
+ mistakes.append(("|".join([str(num) for num in valence_mistakes]),
332
+ "|".join([str(molecule.atom(n)) for n in valence_mistakes]), str(molecule)))
333
+ if mistakes:
334
+ message = ",".join([f'{atom_nums} at {atoms} in {smiles}' for atom_nums, atoms, smiles in mistakes])
335
+ reaction.meta['mistake'] = f'Valence mistake: {message}'
336
+ return True
337
+ return False
338
+
339
+ def _config_log(self, log_file: str, logger_name: str):
340
+ logger = logging.getLogger(logger_name)
341
+ logger.setLevel(logging.DEBUG)
342
+ formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
343
+ logger.handlers.clear()
344
+ fileHandler = logging.FileHandler(filename=log_file, mode='w')
345
+ fileHandler.setFormatter(formatter)
346
+ fileHandler.setLevel(logging.DEBUG)
347
+ logger.addHandler(fileHandler)
348
+ # logging.basicConfig(filename=log_file, level=logging.info, filemode='w', format='%(asctime)s: %(message)s',
349
+ # datefmt='%d/%m/%Y %H:%M:%S')
350
+ return logger
351
+
352
+ def _check_radicals(self, reaction: ReactionContainer) -> bool:
353
+ """
354
+ Checks radicals.
355
+ :param reaction: ReactionContainer
356
+ :return: bool
357
+ """
358
+ for molecule in (reaction.reactants + reaction.products + reaction.reagents):
359
+ for n, atom in molecule.atoms():
360
+ if atom.is_radical:
361
+ return True
362
+ return False
363
+
364
+ def _calc_charge(self, molecule: MoleculeContainer) -> int:
365
+ """Computes the total formal charge of a molecule.
366
+ :param molecule: MoleculeContainer
367
+ :return: int
368
+ """
369
+ return sum(molecule._charges.values())
370
+
371
+ def _group_ions(self, reaction: ReactionContainer):
372
+ """
373
+ Ungroup molecules recorded as ions, regroup ions. Returns a tuple with the corresponding ReactionContainer and
374
+ return code as int (0 - nothing was changed, 1 - ions were regrouped, 2 - ions are unbalanced).
375
+ :param reaction: current reaction
376
+ :return: tuple[ReactionContainer, int]
377
+ """
378
+ meta = reaction.meta
379
+ reaction_parts = []
380
+ return_codes = []
381
+ for molecules in (reaction.reactants, reaction.reagents, reaction.products):
382
+ divided_molecules = [x for m in molecules for x in m.split('.')]
383
+
384
+ if len(divided_molecules) == 0:
385
+ reaction_parts.append(())
386
+ continue
387
+ elif len(divided_molecules) == 1 and self._calc_charge(divided_molecules[0]) == 0:
388
+ return_codes.append(0)
389
+ reaction_parts.append(molecules)
390
+ continue
391
+ elif len(divided_molecules) == 1:
392
+ return_codes.append(2)
393
+ reaction_parts.append(molecules)
394
+ continue
395
+
396
+ new_molecules = []
397
+ cations, anions, ions = [], [], []
398
+ total_charge = 0
399
+ for molecule in divided_molecules:
400
+ mol_charge = self._calc_charge(molecule)
401
+ total_charge += mol_charge
402
+ if mol_charge == 0:
403
+ new_molecules.append(molecule)
404
+ elif mol_charge > 0:
405
+ cations.append((mol_charge, molecule))
406
+ ions.append((mol_charge, molecule))
407
+ else:
408
+ anions.append((mol_charge, molecule))
409
+ ions.append((mol_charge, molecule))
410
+
411
+ if len(cations) == 0 and len(anions) == 0:
412
+ return_codes.append(0)
413
+ reaction_parts.append(tuple(new_molecules))
414
+ continue
415
+ elif total_charge != 0:
416
+ return_codes.append(2)
417
+ reaction_parts.append(tuple(divided_molecules))
418
+ continue
419
+ else:
420
+ salt = MoleculeContainer()
421
+ for ion_charge, ion in ions:
422
+ salt = salt.union(ion)
423
+ total_charge += ion_charge
424
+ if total_charge == 0:
425
+ new_molecules.append(salt)
426
+ salt = MoleculeContainer()
427
+ if total_charge != 0:
428
+ new_molecules.append(salt)
429
+ return_codes.append(2)
430
+ reaction_parts.append(tuple(new_molecules))
431
+ else:
432
+ return_codes.append(1)
433
+ reaction_parts.append(tuple(new_molecules))
434
+ return ReactionContainer(reactants=reaction_parts[0], reagents=reaction_parts[1], products=reaction_parts[2],
435
+ meta=meta), max(return_codes)
436
+
437
+ def _split_ions(self, reaction: ReactionContainer):
438
+ """
439
+ Split ions in a reaction. Returns a tuple with the corresponding ReactionContainer and
440
+ a return code as int (0 - nothing was changed, 1 - ions were split, 2 - ions were split but the reaction
441
+ is imbalanced).
442
+ :param reaction: current reaction
443
+ :return: tuple[ReactionContainer, int]
444
+ """
445
+ meta = reaction.meta
446
+ reaction_parts = []
447
+ return_codes = []
448
+ for molecules in (reaction.reactants, reaction.reagents, reaction.products):
449
+ divided_molecules = [x for m in molecules for x in m.split('.')]
450
+
451
+ total_charge = 0
452
+ ions_present = False
453
+ for molecule in divided_molecules:
454
+ mol_charge = self._calc_charge(molecule)
455
+ total_charge += mol_charge
456
+ if mol_charge != 0:
457
+ ions_present = True
458
+
459
+ if ions_present and total_charge:
460
+ return_codes.append(2)
461
+ elif ions_present:
462
+ return_codes.append(1)
463
+ else:
464
+ return_codes.append(0)
465
+
466
+ reaction_parts.append(tuple(divided_molecules))
467
+
468
+ return ReactionContainer(reactants=reaction_parts[0], reagents=reaction_parts[1], products=reaction_parts[2],
469
+ meta=meta), max(return_codes)
470
+
471
+ def _remove_unchanged_parts(self, reaction: ReactionContainer) -> ReactionContainer:
472
+ """
473
+ Ungroup molecules, remove unchanged parts from reactants and products.
474
+ :param reaction: current reaction
475
+ :return: ReactionContainer
476
+ """
477
+ meta = reaction.meta
478
+ new_reactants = [m for m in reaction.reactants]
479
+ new_reagents = [m for m in reaction.reagents]
480
+ if self._reagents_to_reactants:
481
+ new_reactants.extend(new_reagents)
482
+ new_reagents = []
483
+ reactants = new_reactants.copy()
484
+ new_products = [m for m in reaction.products]
485
+
486
+ for reactant in reactants:
487
+ if reactant in new_products:
488
+ new_reagents.append(reactant)
489
+ new_reactants.remove(reactant)
490
+ new_products.remove(reactant)
491
+ if not self._keep_reagents:
492
+ new_reagents = []
493
+ return ReactionContainer(reactants=tuple(new_reactants), reagents=tuple(new_reagents),
494
+ products=tuple(new_products), meta=meta)
495
+
496
+ def _check_isotopes(self, reaction: ReactionContainer) -> bool:
497
+ for molecules in (reaction.reactants, reaction.products):
498
+ for molecule in molecules:
499
+ for _, atom in molecule.atoms():
500
+ if atom.isotope:
501
+ return True
502
+ return False
503
+
504
+ def _tautomerize(self, reaction: ReactionContainer) -> ReactionContainer:
505
+ """
506
+ Perform ChemAxon tautomerization.
507
+ :param reaction: reaction that needs to be tautomerized
508
+ :return: ReactionContainer
509
+ """
510
+ new_molecules = []
511
+ for part in [reaction.reactants, reaction.reagents, reaction.products]:
512
+ tmp = []
513
+ for mol in part:
514
+ with io.StringIO() as f, SDFWrite(f) as i:
515
+ i.write(mol)
516
+ sdf = f.getvalue()
517
+ mol_handler = self._MolHandler(sdf)
518
+ mol_handler.clean(True, '2')
519
+ molecule = mol_handler.getMolecule()
520
+ self._standardizer.standardize(molecule)
521
+ new_mol_handler = self._MolHandler(molecule)
522
+ new_sdf = new_mol_handler.toFormat('SDF')
523
+ with io.StringIO('\n ' + new_sdf.strip()) as f, SDFRead(f, remap=False) as i:
524
+ new_mol = next(i)
525
+ tmp.append(new_mol)
526
+ new_molecules.append(tmp)
527
+ return ReactionContainer(reactants=tuple(new_molecules[0]), reagents=tuple(new_molecules[1]),
528
+ products=tuple(new_molecules[2]), meta=reaction.meta)
529
+
530
+ # def _dearomatize_by_RDKit(self, reaction: ReactionContainer) -> ReactionContainer:
531
+ # """
532
+ # Dearomatizes by RDKit (needs in case of some mappers, such as RXNMapper).
533
+ # :param reaction: ReactionContainer
534
+ # :return: ReactionContainer
535
+ # """
536
+ # with io.StringIO() as f, RDFWrite(f) as i:
537
+ # i.write(reaction)
538
+ # s = '\n'.join(f.getvalue().split('\n')[3:])
539
+ # rxn = rdChemReactions.ReactionFromRxnBlock(s)
540
+ # reactants, reagents, products = [], [], []
541
+ # for mol in rxn.GetReactants():
542
+ # try:
543
+ # Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_KEKULIZE, catchErrors=True)
544
+ # except Chem.rdchem.KekulizeException:
545
+ # return reaction
546
+ # with io.StringIO(Chem.MolToMolBlock(mol)) as f2, SDFRead(f2, remap=False) as sdf_i:
547
+ # reactants.append(next(sdf_i))
548
+ # for mol in rxn.GetAgents():
549
+ # try:
550
+ # Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_KEKULIZE, catchErrors=True)
551
+ # except Chem.rdchem.KekulizeException:
552
+ # return reaction
553
+ # with io.StringIO(Chem.MolToMolBlock(mol)) as f2, SDFRead(f2, remap=False) as sdf_i:
554
+ # reagents.append(next(sdf_i))
555
+ # for mol in rxn.GetProducts():
556
+ # try:
557
+ # Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_KEKULIZE, catchErrors=True)
558
+ # except Chem.rdchem.KekulizeException:
559
+ # return reaction
560
+ # with io.StringIO(Chem.MolToMolBlock(mol)) as f2, SDFRead(f2, remap=False) as sdf_i:
561
+ # products.append(next(sdf_i))
562
+ #
563
+ # new_reaction = ReactionContainer(reactants=tuple(reactants), reagents=tuple(reagents), products=tuple(products),
564
+ # meta=reaction.meta)
565
+ #
566
+ # return new_reaction
567
+
568
+
569
+ if __name__ == '__main__':
570
+ import argparse
571
+
572
+ parser = argparse.ArgumentParser(description="This is a tool for reaction standardization.",
573
+ epilog="Arkadii Lin, Strasbourg/Kazan 2020", prog="Standardizer")
574
+ parser.add_argument("-i", "--input", type=str, help="Input RDF file.")
575
+ parser.add_argument("-o", "--output", type=str, help="Output RDF file.")
576
+ parser.add_argument("-id", "--idTag", default='Reaction_ID', type=str, help="ID tag in the RDF file.")
577
+ parser.add_argument("--skipErrors", action="store_true", help="Skip errors.")
578
+ parser.add_argument("--keep_unbalanced_ions", action="store_true", help="Will keep reactions with unbalanced ions.")
579
+ parser.add_argument("--action_on_isotopes", type=int, default=0, help="Action performed if an isotope is "
580
+ "found: 0 - to ignore isotopes; "
581
+ "1 - to remove reactions with isotopes; "
582
+ "2 - to clear isotopes' labels.")
583
+ parser.add_argument("--keep_reagents", action="store_true", help="Will keep reagents from the reaction.")
584
+ parser.add_argument("--add_reagents", action="store_true", help="Will add the given reagents to reactants.")
585
+ parser.add_argument("--ignore_mapping", action="store_true", help="Will ignore the initial mapping in the file.")
586
+ parser.add_argument("--keep_unchanged_parts", action="store_true", help="Will keep unchanged parts in a reaction.")
587
+ parser.add_argument("--logFile", type=str, default='logFile.txt', help="Log file name.")
588
+ parser.add_argument("--skip_tautomerize", action="store_true", help="Will skip generation of the major tautomer.")
589
+ parser.add_argument("--rdkit_dearomatization", action="store_true", help="Will kekulize the reaction using RDKit "
590
+ "facilities.")
591
+ parser.add_argument("--jvm_path", type=str,
592
+ help="JVM path (e.g. C:\\Program Files\\Java\\jdk-13.0.2).")
593
+ parser.add_argument("--jchem_path", type=str, help="JChem path (e.g. C:\\Users\\user\\JChemSuite\\lib\\jchem.jar).")
594
+ args = parser.parse_args()
595
+
596
+ standardizer = Standardizer(skip_errors=args.skipErrors, log_file=args.logFile,
597
+ keep_unbalanced_ions=args.keep_unbalanced_ions, id_tag=args.idTag,
598
+ action_on_isotopes=args.action_on_isotopes, keep_reagents=args.keep_reagents,
599
+ ignore_mapping=args.ignore_mapping, skip_tautomerize=args.skip_tautomerize,
600
+ remove_unchanged_parts=(not args.keep_unchanged_parts), jvm_path=args.jvm_path,
601
+ jchem_path=args.jchem_path, rdkit_dearomatization=args.rdkit_dearomatization,
602
+ add_reagents_to_reactants=args.add_reagents)
603
+ data = standardizer.standardize_file(input_file=args.input)
604
+ standardizer.write(output_file=args.output, data=data)
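
For quick reference, a minimal sketch of driving the same standardization from Python instead of the CLI above. The file names are placeholders, the import path is assumed from the package layout, and all keyword arguments correspond to constructor parameters defined in this file.

from SynTool.chem.data.standardizer import Standardizer  # import path assumed from the package layout

# Placeholder file names; the input RDF is expected to carry a 'Reaction_ID' data field (the default id_tag).
standardizer = Standardizer(skip_errors=True, log_file='standardization.log',
                            keep_reagents=True, skip_tautomerize=True)
reactions = standardizer.standardize_file(input_file='reactions_raw.rdf')
standardizer.write(output_file='reactions_standardized.rdf', data=reactions)
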
SynTool/chem/reaction.py ADDED
@@ -0,0 +1,107 @@
1
+ """
2
+ Module containing classes and functions for manipulating reactions and reaction rules
3
+ """
4
+
5
+ from CGRtools.reactor import Reactor
6
+ from CGRtools.containers import MoleculeContainer, ReactionContainer
7
+ from CGRtools.exceptions import InvalidAromaticRing
8
+
9
+
10
+ class Reaction(ReactionContainer):
11
+ """
12
+ Reaction class can be used as a general representation of a reaction for different chemoinformatics Python packages
13
+ """
14
+
15
+ def __init__(self, *args, **kwargs):
16
+ """
17
+ Initializes the reaction object.
18
+ """
19
+ super().__init__(*args, **kwargs)
20
+
21
+
22
+ def add_small_mols(big_mol, small_molecules=None):
23
+ """
24
+ The function adds each given small molecule to a copy of the big molecule and returns the list of
25
+ disconnected components of the combined structure.
26
+
27
+ :param big_mol: A molecule
28
+ :param small_molecules: A list of small molecules that need to be added to the molecule
29
+ :return: Returns a list of molecules.
30
+ """
31
+ if small_molecules:
32
+ tmp_mol = big_mol.copy()
33
+ transition_mapping = {}
34
+ for small_mol in small_molecules:
35
+
36
+ for n, atom in small_mol.atoms():
37
+ new_number = tmp_mol.add_atom(atom.atomic_symbol)
38
+ transition_mapping[n] = new_number
39
+
40
+ for atom, neighbor, bond in small_mol.bonds():
41
+ tmp_mol.add_bond(transition_mapping[atom], transition_mapping[neighbor], bond)
42
+
43
+ transition_mapping = {}
44
+ return tmp_mol.split()
45
+ else:
46
+ return [big_mol]
47
+
48
+
49
+ def apply_reaction_rule(
50
+ molecule: MoleculeContainer,
51
+ reaction_rule: Reactor,
52
+ sort_reactions: bool = False,
53
+ top_reactions_num: int = 3,
54
+ validate_products: bool = True,
55
+ rebuild_with_cgr: bool = False,
56
+ ) -> list[MoleculeContainer]:
57
+ """
58
+ The function applies a reaction rule to a given molecule.
59
+
60
+ :param rebuild_with_cgr: if True, products are rebuilt by composing the reaction into a CGR and decomposing it back
61
+ :param validate_products: if True, each product is kekulized and checked for valence errors
62
+ :param sort_reactions: if True, generated reactions are sorted by the number of large (more than 6 atoms) products
63
+ :param top_reactions_num: maximum number of reactions taken from the reactor
64
+ :param molecule: A MoleculeContainer object representing the molecule on which the reaction rule will be applied
65
+ :type molecule: MoleculeContainer
66
+ :param reaction_rule: The reaction_rule is an instance of the Reactor class. It represents a reaction rule that
67
+ can be applied to a molecule
68
+ :type reaction_rule: Reactor
69
+ """
70
+
71
+ reactants = add_small_mols(molecule, small_molecules=False)
72
+
73
+ try:
74
+ if sort_reactions:
75
+ unsorted_reactions = list(reaction_rule(reactants))
76
+ sorted_reactions = sorted(
77
+ unsorted_reactions,
78
+ key=lambda react: len(list(filter(lambda mol: len(mol) > 6, react.products))),
79
+ reverse=True
80
+ )
81
+ reactions = sorted_reactions[:top_reactions_num] # Take top-N reactions from reactor
82
+ else:
83
+ reactions = []
84
+ for reaction in reaction_rule(reactants):
85
+ reactions.append(reaction)
86
+ if len(reactions) == top_reactions_num:
87
+ break
88
+ except IndexError:
89
+ reactions = []
90
+
91
+ for reaction in reactions:
92
+ if rebuild_with_cgr:
93
+ cgr = reaction.compose()
94
+ products = cgr.decompose()[1].split()
95
+ else:
96
+ products = reaction.products
97
+ products = [mol for mol in products if len(mol) > 0]
98
+ if validate_products:
99
+ for molecule in products:
100
+ try:
101
+ molecule.kekule()
102
+ if molecule.check_valence():
103
+ yield None
104
+ molecule.thiele()
105
+ except InvalidAromaticRing:
106
+ yield None
107
+ yield products
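
A short usage sketch for apply_reaction_rule defined above. The pickle path and the SMILES string are placeholders, and the (rule, indices) layout of the pickle is assumed to follow sort_rules from the extraction module later in this commit.

import pickle
from CGRtools import smiles
from CGRtools.reactor import Reactor
from SynTool.chem.reaction import apply_reaction_rule

# Assumed: a pickle of (rule, indices) pairs as produced by sort_rules in extraction.py.
with open('reaction_rules.pickle', 'rb') as f:
    rule, _indices = pickle.load(f)[0]

reactor = Reactor(rule)
target = smiles('CC(=O)Nc1ccccc1')  # illustrative target molecule
for products in apply_reaction_rule(target, reactor, top_reactions_num=3):
    if products:  # None is yielded when a generated product fails validation
        print('.'.join(str(mol) for mol in products))
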
SynTool/chem/reaction_rules/__init__.py ADDED
File without changes
SynTool/chem/reaction_rules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (160 Bytes). View file
 
SynTool/chem/reaction_rules/__pycache__/extraction.cpython-310.pyc ADDED
Binary file (25.7 kB). View file
 
SynTool/chem/reaction_rules/extraction.py ADDED
@@ -0,0 +1,679 @@
1
+ """
2
+ Module containing functions implementing the fixed protocol for reaction rule extraction
3
+ """
4
+ import logging
5
+ import pickle
6
+ from collections import defaultdict
7
+ from itertools import islice
8
+ from pathlib import Path
9
+ from typing import List, Union, Tuple, IO, Dict, Set, Iterable, Any
10
+ from os.path import splitext
11
+
12
+
13
+ import ray
14
+ from CGRtools.containers import MoleculeContainer, QueryContainer, ReactionContainer
15
+ from CGRtools.exceptions import InvalidAromaticRing
16
+ from CGRtools.reactor import Reactor
17
+ from tqdm.auto import tqdm
18
+
19
+ from SynTool.chem.utils import reverse_reaction
20
+ from SynTool.utils.config import RuleExtractionConfig
21
+ from SynTool.utils.files import ReactionReader
22
+
23
+
24
+ def extract_rules_from_reactions(
25
+ config: RuleExtractionConfig,
26
+ reaction_file: str,
27
+ rules_file_name: str = 'reaction_rules.pickle',
28
+ num_cpus: int = 1,
29
+ batch_size: int = 10,
30
+ ) -> None:
31
+ """
32
+ Extracts reaction rules from a set of reactions based on the given configuration.
33
+
34
+ This function initializes a Ray environment for distributed computing and processes each reaction
35
+ in the provided reaction database to extract reaction rules. It handles the reactions in batches,
36
+ parallelizing the rule extraction process. Occurrence statistics for the extracted rules
37
+ are recorded. The function then sorts the rules by popularity and saves them to a pickle file.
38
+
39
+ :param config: Configuration settings for rule extraction, including file paths, batch size, and other parameters.
40
+ :param reaction_file: Path to the file containing the reaction database.
41
+ :param rules_file_name: Name of the file to store the extracted rules.
42
+ :param num_cpus: Number of CPU cores to use for processing. Defaults to 1.
43
+ :param batch_size: Number of reactions to process in each batch. Defaults to 10.
44
+
45
+ :return: None
46
+ """
47
+
48
+ # read files
49
+ reaction_file = Path(reaction_file).resolve(strict=True)
50
+
51
+ ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)
52
+
53
+ rules_file_name, _ = splitext(rules_file_name)
54
+ with ReactionReader(reaction_file) as reactions:
55
+ pbar = tqdm(reactions, disable=False)
56
+
57
+ futures = {}
58
+ batch = []
59
+ max_concurrent_batches = num_cpus
60
+
61
+ extracted_rules_and_statistics = defaultdict(list)
62
+ for index, reaction in enumerate(reactions):
63
+ batch.append((index, reaction))
64
+ if len(batch) == batch_size:
65
+ future = process_reaction_batch.remote(batch, config)
66
+ futures[future] = None
67
+ batch = []
68
+
69
+ while len(futures) >= max_concurrent_batches:
70
+ process_completed_batches(futures, extracted_rules_and_statistics, pbar, batch_size)
71
+
72
+ if batch:
73
+ remaining_size = len(batch)
74
+ future = process_reaction_batch.remote(batch, config)
75
+ futures[future] = None
76
+
77
+ while futures:
78
+ process_completed_batches(futures, extracted_rules_and_statistics, pbar, remaining_size)
79
+
80
+ pbar.close()
81
+
82
+ sorted_rules = sort_rules(
83
+ extracted_rules_and_statistics,
84
+ min_popularity=config.min_popularity,
85
+ single_reactant_only=config.single_reactant_only,
86
+ )
87
+
88
+ with open(f"{rules_file_name}.pickle", "wb") as statistics_file:
89
+ pickle.dump(sorted_rules, statistics_file)
90
+ print(f'Number of extracted reaction rules: {len(sorted_rules)}')
91
+
92
+ ray.shutdown()
93
+
94
+
95
+ @ray.remote
96
+ def process_reaction_batch(
97
+ batch: List[Tuple[int, ReactionContainer]], config: RuleExtractionConfig
98
+ ) -> list[tuple[int, list[ReactionContainer]]]:
99
+ """
100
+ Processes a batch of reactions to extract reaction rules based on the given configuration.
101
+
102
+ This function operates as a remote task in a distributed system using Ray. It takes a batch of reactions,
103
+ where each reaction is paired with an index. For each reaction in the batch, it extracts reaction rules
104
+ as specified by the configuration object. The extracted rules for each reaction are then returned along
105
+ with the corresponding index.
106
+
107
+ :param batch: A list where each element is a tuple containing an index (int) and a ReactionContainer object.
108
+ The index is typically used to keep track of the reaction's position in a larger dataset.
109
+ :type batch: List[Tuple[int, ReactionContainer]]
110
+
111
+ :param config: An instance of RuleExtractionConfig that provides settings and parameters for the rule extraction process.
112
+ :type config: RuleExtractionConfig
113
+
114
+ :return: A list where each element is a tuple. The first element of the tuple is an index (int), and the second
115
+ is a list of ReactionContainer objects representing the extracted rules for the corresponding reaction.
116
+ :rtype: list[tuple[int, list[ReactionContainer]]]
117
+
118
+ This function is intended to be used in a distributed manner with Ray to parallelize the rule extraction
119
+ process across multiple reactions.
120
+ """
121
+ processed_batch = []
122
+ for index, reaction in batch:
123
+ try:
124
+ extracted_rules = extract_rules(config, reaction)
125
+ processed_batch.append((index, extracted_rules))
126
+ except:
127
+ continue
128
+ return processed_batch
129
+
130
+
131
+ def process_completed_batches(
132
+ futures: dict,
133
+ rules_statistics: Dict[ReactionContainer, List[int]],
134
+ pbar: tqdm,
135
+ batch_size: int,
136
+ ) -> None:
137
+ """
138
+ Processes a completed batch of reactions, updating the rule statistics and the progress bar.
139
+
140
+ This function waits for the completion of a batch of reactions processed in parallel (using Ray),
141
+ updates the statistics for each extracted rule, and records the index of the first reaction that produced a new rule.
142
+ It also updates the progress bar with the size of the processed batch.
143
+
144
+ :param futures: A dictionary of futures representing ongoing batch processing tasks.
145
+ :type futures: dict
146
+
147
+ :param rules_statistics: A dictionary to keep track of statistics for each rule.
148
+ :type rules_statistics: Dict[ReactionContainer, List[int]]
149
+
150
+ :param pbar: A tqdm progress bar instance for updating the progress of batch processing.
151
+ :type pbar: tqdm
152
+
153
+ :param batch_size: The number of reactions processed in each batch.
154
+ :type batch_size: int
155
+
156
+ :return: None
157
+ """
158
+ done, _ = ray.wait(list(futures.keys()), num_returns=1)
159
+ completed_batch = ray.get(done[0])
160
+
161
+ for index, extracted_rules in completed_batch:
162
+ for rule in extracted_rules:
163
+ prev_stats_len = len(rules_statistics)
164
+ rules_statistics[rule].append(index)
165
+ if len(rules_statistics) != prev_stats_len:
166
+ rule.meta["first_reaction_index"] = index
167
+
168
+ del futures[done[0]]
169
+ pbar.update(batch_size)
170
+
171
+
172
+ def extract_rules(
173
+ config: RuleExtractionConfig, reaction: ReactionContainer
174
+ ) -> list[ReactionContainer]:
175
+ """
176
+ Extracts reaction rules from a given reaction based on the specified configuration.
177
+
178
+ :param config: An instance of RuleExtractionConfig, which contains various configuration settings
179
+ for rule extraction, such as whether to include multicenter rules, functional groups,
180
+ ring structures, leaving and incoming groups, etc.
181
+ :param reaction: The reaction object (ReactionContainer) from which to extract rules. The reaction
182
+ object represents a chemical reaction with specified reactants, products, and possibly reagents.
183
+ :return: A list of ReactionContainer objects, each representing a distinct reaction rule. If
184
+ config.multicenter_rules is True, a single rule encompassing all reaction centers is returned.
185
+ Otherwise, separate rules for each reaction center are extracted, up to a maximum of 15 distinct centers.
186
+ """
187
+ if config.multicenter_rules:
188
+ # Extract a single rule encompassing all reaction centers
189
+ return [create_rule(config, reaction)]
190
+ else:
191
+ # Extract separate rules for each distinct reaction center
192
+ distinct_rules = set()
193
+ for center_reaction in islice(reaction.enumerate_centers(), 15):
194
+ single_rule = create_rule(config, center_reaction)
195
+ distinct_rules.add(single_rule)
196
+ return list(distinct_rules)
197
+
198
+
199
+ def create_rule(
200
+ config: RuleExtractionConfig, reaction: ReactionContainer
201
+ ) -> ReactionContainer:
202
+ """
203
+ Creates a reaction rule from a given reaction based on the specified configuration.
204
+
205
+ :param config: An instance of RuleExtractionConfig, containing various settings that determine how
206
+ the rule is created, such as environmental atom count, inclusion of functional groups,
207
+ rings, leaving and incoming groups, and other parameters.
208
+ :param reaction: The reaction object (ReactionContainer) from which to create the rule. This object
209
+ represents a chemical reaction with specified reactants, products, and possibly reagents.
210
+ :return: A ReactionContainer object representing the extracted reaction rule. This rule includes
211
+ various elements of the reaction as specified by the configuration, such as reaction centers,
212
+ environmental atoms, functional groups, and others.
213
+
214
+ The function processes the reaction to create a rule that matches the configuration settings. It handles
215
+ the inclusion of environmental atoms, functional groups, ring structures, and leaving and incoming groups.
216
+ It also constructs substructures for reactants, products, and reagents, and cleans molecule representations
217
+ if required. Optionally, it validates the rule using a reactor.
218
+ """
219
+ cgr = ~reaction
220
+ center_atoms = set(cgr.center_atoms)
221
+
222
+ # Add atoms of reaction environment based on config settings
223
+ center_atoms = add_environment_atoms(
224
+ cgr, center_atoms, config.environment_atom_count
225
+ )
226
+
227
+ # Include functional groups in the rule if specified in config
228
+ if config.include_func_groups:
229
+ rule_atoms = add_functional_groups(
230
+ reaction, center_atoms, config.func_groups_list
231
+ )
232
+ else:
233
+ rule_atoms = center_atoms.copy()
234
+
235
+ # Include ring structures in the rule if specified in config
236
+ if config.include_rings:
237
+ rule_atoms = add_ring_structures(
238
+ cgr,
239
+ rule_atoms,
240
+ )
241
+
242
+ # Add leaving and incoming groups to the rule based on config settings
243
+ rule_atoms, meta_debug = add_leaving_incoming_groups(
244
+ reaction, rule_atoms, config.keep_leaving_groups, config.keep_incoming_groups
245
+ )
246
+
247
+ # Create substructures for reactants, products, and reagents
248
+ (
249
+ reactant_substructures,
250
+ product_substructures,
251
+ reagents,
252
+ ) = create_substructures_and_reagents(
253
+ reaction, rule_atoms, config.as_query_container, config.keep_reagents
254
+ )
255
+
256
+ # Clean atom marks in the molecules if they are being converted to query containers
257
+ if config.as_query_container:
258
+ reactant_substructures = clean_molecules(
259
+ reactant_substructures,
260
+ reaction.reactants,
261
+ center_atoms,
262
+ config.atom_info_retention,
263
+ )
264
+ product_substructures = clean_molecules(
265
+ product_substructures,
266
+ reaction.products,
267
+ center_atoms,
268
+ config.atom_info_retention,
269
+ )
270
+
271
+ # Assemble the final rule including metadata if specified
272
+ rule = assemble_final_rule(
273
+ reactant_substructures,
274
+ product_substructures,
275
+ reagents,
276
+ meta_debug,
277
+ config.keep_metadata,
278
+ reaction,
279
+ )
280
+
281
+ if config.reverse_rule:
282
+ rule = reverse_reaction(rule)
283
+ reaction = reverse_reaction(reaction)
284
+
285
+ # Validate the rule using a reactor if validation is enabled in config
286
+ if config.reactor_validation:
287
+ if validate_rule(rule, reaction):
288
+ rule.meta["reactor_validation"] = "passed"
289
+ else:
290
+ rule.meta["reactor_validation"] = "failed"
291
+
292
+ return rule
293
+
294
+
295
+ def add_environment_atoms(cgr, center_atoms, environment_atom_count):
296
+ """
297
+ Adds environment atoms to the set of center atoms based on the specified depth.
298
+
299
+ :param cgr: A complete graph representation of a reaction (ReactionContainer object).
300
+ :param center_atoms: A set of atom identifiers representing the center atoms of the reaction.
301
+ :param environment_atom_count: An integer specifying the depth of the environment around
302
+ the reaction center to be included. If it's 0, only the
303
+ reaction center is included. If it's 1, the first layer of
304
+ surrounding atoms is included, and so on.
305
+ :return: A set of atom identifiers including the center atoms and their environment atoms
306
+ up to the specified depth. If environment_atom_count is 0, the original set of
307
+ center atoms is returned unchanged.
308
+ """
309
+ if environment_atom_count:
310
+ env_cgr = cgr.augmented_substructure(center_atoms, deep=environment_atom_count)
311
+ # Combine the original center atoms with the new environment atoms
312
+ return center_atoms | set(env_cgr)
313
+
314
+ # If no environment is to be included, return the original center atoms
315
+ return center_atoms
316
+
317
+
318
+ def add_functional_groups(reaction, center_atoms, func_groups_list):
319
+ """
320
+ Augments the set of rule atoms with functional groups if specified.
321
+
322
+ :param reaction: The reaction object (ReactionContainer) from which molecules are extracted.
323
+ :param center_atoms: A set of atom identifiers representing the center atoms of the reaction.
324
+ :param func_groups_list: A list of functional group objects (MoleculeContainer or QueryContainer)
325
+ to be considered when including functional groups. These objects define
326
+ the structure of the functional groups to be included.
327
+ :return: A set of atom identifiers representing the rule atoms, including atoms from the
328
+ specified functional groups if include_func_groups is True. If include_func_groups
329
+ is False, the original set of center atoms is returned.
330
+ """
331
+ rule_atoms = center_atoms.copy()
332
+ # Iterate over each molecule in the reaction
333
+ for molecule in reaction.molecules():
334
+ # For each functional group specified in the list
335
+ for func_group in func_groups_list:
336
+ # Find mappings of the functional group in the molecule
337
+ for mapping in func_group.get_mapping(molecule):
338
+ # Remap the functional group based on the found mapping
339
+ func_group.remap(mapping)
340
+ # If the functional group intersects with center atoms, include it
341
+ if set(func_group.atoms_numbers) & center_atoms:
342
+ rule_atoms |= set(func_group.atoms_numbers)
343
+ # Reset the mapping to its original state for the next iteration
344
+ func_group.remap({v: k for k, v in mapping.items()})
345
+ return rule_atoms
346
+
347
+
348
+ def add_ring_structures(cgr, rule_atoms):
349
+ """
350
+ Appends ring structures to the set of rule atoms if they intersect with the reaction center atoms.
351
+
352
+ :param cgr: A condensed graph representation of a reaction (CGRContainer object).
353
+ :param rule_atoms: A set of atom identifiers representing the center atoms of the reaction.
354
+ :return: A set of atom identifiers including the original rule atoms and the included ring structures.
355
+ """
356
+ for ring in cgr.sssr:
357
+ # Check if the current ring intersects with the set of rule atoms
358
+ if set(ring) & rule_atoms:
359
+ # If the intersection exists, include all atoms in the ring to the rule atoms
360
+ rule_atoms |= set(ring)
361
+ return rule_atoms
362
+
363
+
364
+ def add_leaving_incoming_groups(
365
+ reaction, rule_atoms, keep_leaving_groups, keep_incoming_groups
366
+ ):
367
+ """
368
+ Identifies and includes leaving and incoming groups to the rule atoms based on specified flags.
369
+
370
+ :param reaction: The reaction object (ReactionContainer) from which leaving and incoming groups are extracted.
371
+ :param rule_atoms: A set of atom identifiers representing the center atoms of the reaction.
372
+ :param keep_leaving_groups: A boolean flag indicating whether to include leaving groups in the rule.
373
+ :param keep_incoming_groups: A boolean flag indicating whether to include incoming groups in the rule.
374
+ :return: Updated set of rule atoms including leaving and incoming groups if specified, and metadata about added groups.
375
+ """
376
+ meta_debug = {"leaving": set(), "incoming": set()}
377
+
378
+ # Extract atoms from reactants and products
379
+ reactant_atoms = {atom for reactant in reaction.reactants for atom in reactant}
380
+ product_atoms = {atom for product in reaction.products for atom in product}
381
+
382
+ # Identify leaving groups (reactant atoms not in products)
383
+ if keep_leaving_groups:
384
+ leaving_atoms = reactant_atoms - product_atoms
385
+ new_leaving_atoms = leaving_atoms - rule_atoms
386
+ # Include leaving atoms in the rule atoms
387
+ rule_atoms |= leaving_atoms
388
+ # Add leaving atoms to metadata
389
+ meta_debug["leaving"] |= new_leaving_atoms
390
+
391
+ # Identify incoming groups (product atoms not in reactants)
392
+ if keep_incoming_groups:
393
+ incoming_atoms = product_atoms - reactant_atoms
394
+ new_incoming_atoms = incoming_atoms - rule_atoms
395
+ # Include incoming atoms in the rule atoms
396
+ rule_atoms |= incoming_atoms
397
+ # Add incoming atoms to metadata
398
+ meta_debug["incoming"] |= new_incoming_atoms
399
+
400
+ return rule_atoms, meta_debug
401
+
402
+
403
+ def clean_molecules(
404
+ rule_molecules: Iterable[QueryContainer],
405
+ reaction_molecules: Iterable[MoleculeContainer],
406
+ reaction_center_atoms: Set[int],
407
+ atom_retention_details: Dict[str, Dict[str, bool]],
408
+ ) -> List[QueryContainer]:
409
+ """
410
+ Cleans rule molecules by removing specified information about atoms based on retention details provided.
411
+
412
+ :param rule_molecules: A list of query container objects representing the rule molecules.
413
+ :param reaction_molecules: A list of molecule container objects involved in the reaction.
414
+ :param reaction_center_atoms: A set of integers representing atom numbers in the reaction center.
415
+ :param atom_retention_details: A dictionary specifying what atom information to retain or remove.
416
+ This dictionary should have two keys: "reaction_center" and "environment",
417
+ each mapping to another dictionary. The nested dictionaries should have
418
+ keys representing atom attributes (like "neighbors", "hybridization",
419
+ "implicit_hydrogens", "ring_sizes") and boolean values. A value of True
420
+ indicates that the corresponding attribute should be retained,
421
+ while False indicates it should be removed from the atom.
422
+
423
+ For example:
424
+ {
425
+ "reaction_center": {"neighbors": True, "hybridization": False, ...},
426
+ "environment": {"neighbors": True, "implicit_hydrogens": False, ...}
427
+ }
428
+
429
+ Returns:
430
+ A list of QueryContainer objects representing the cleaned rule molecules.
431
+ """
432
+ cleaned_rule_molecules = []
433
+
434
+ for rule_molecule in rule_molecules:
435
+ for reaction_molecule in reaction_molecules:
436
+ if set(rule_molecule.atoms_numbers) <= set(reaction_molecule.atoms_numbers):
437
+ query_reaction_molecule = reaction_molecule.substructure(
438
+ reaction_molecule, as_query=True
439
+ )
440
+ query_rule_molecule = query_reaction_molecule.substructure(
441
+ rule_molecule
442
+ )
443
+
444
+ # Clean environment atoms
445
+ if not all(
446
+ atom_retention_details["environment"].values()
447
+ ): # if everything True, we keep all marks
448
+ local_environment_atoms = (
449
+ set(rule_molecule.atoms_numbers) - reaction_center_atoms
450
+ )
451
+ for atom_number in local_environment_atoms:
452
+ query_rule_molecule = clean_atom(
453
+ query_rule_molecule,
454
+ atom_retention_details["environment"],
455
+ atom_number,
456
+ )
457
+
458
+ # Clean reaction center atoms
459
+ if not all(
460
+ atom_retention_details["reaction_center"].values()
461
+ ): # if everything True, we keep all marks
462
+ local_reaction_center_atoms = (
463
+ set(rule_molecule.atoms_numbers) & reaction_center_atoms
464
+ )
465
+ for atom_number in local_reaction_center_atoms:
466
+ query_rule_molecule = clean_atom(
467
+ query_rule_molecule,
468
+ atom_retention_details["reaction_center"],
469
+ atom_number,
470
+ )
471
+
472
+ cleaned_rule_molecules.append(query_rule_molecule)
473
+ break
474
+
475
+ return cleaned_rule_molecules
476
+
477
+
478
+ def clean_atom(
479
+ query_molecule: QueryContainer,
480
+ attributes_to_keep: Dict[str, bool],
481
+ atom_number: int,
482
+ ) -> QueryContainer:
483
+ """
484
+ Removes specified information from a given atom in a query molecule.
485
+
486
+ :param query_molecule: The QueryContainer of molecule.
487
+ :param attributes_to_keep: Dictionary indicating which attributes to keep in the atom.
488
+ The keys should be strings representing the attribute names, and
489
+ the values should be booleans indicating whether to retain (True)
490
+ or remove (False) that attribute. Expected keys are:
491
+ - "neighbors": Indicates if the neighbors count of the atom should be retained.
492
+ - "hybridization": Indicates if hybridization information of the atom should be retained.
493
+ - "implicit_hydrogens": Indicates if implicit hydrogen information of the atom should be retained.
494
+ - "ring_sizes": Indicates if ring size information of the atom should be retained.
495
+ :param atom_number: The number of the atom to be modified in the query molecule.
496
+ """
497
+ target_atom = query_molecule.atom(atom_number)
498
+
499
+ if not attributes_to_keep["neighbors"]:
500
+ target_atom.neighbors = None
501
+ if not attributes_to_keep["hybridization"]:
502
+ target_atom.hybridization = None
503
+ if not attributes_to_keep["implicit_hydrogens"]:
504
+ target_atom.implicit_hydrogens = None
505
+ if not attributes_to_keep["ring_sizes"]:
506
+ target_atom.ring_sizes = None
507
+
508
+ return query_molecule
509
+
510
+
511
+ def create_substructures_and_reagents(
512
+ reaction, rule_atoms, as_query_container, keep_reagents
513
+ ):
514
+ """
515
+ Creates substructures for reactants and products, and optionally includes reagents, based on specified parameters.
516
+
517
+ :param reaction: The reaction object (ReactionContainer) from which to extract substructures. This object
518
+ represents a chemical reaction with specified reactants, products, and possibly reagents.
519
+ :param rule_atoms: A set of atom identifiers that define the rule atoms. These are used to identify relevant
520
+ substructures in reactants and products.
521
+ :param as_query_container: A boolean flag indicating whether the substructures should be converted to query containers.
522
+ Query containers are used for pattern matching in chemical structures.
523
+ :param keep_reagents: A boolean flag indicating whether reagents should be included in the resulting structures.
524
+ Reagents are additional substances that are present in the reaction but are not reactants or products.
525
+
526
+ :return: A tuple containing three elements:
527
+ - A list of reactant substructures, each corresponding to a part of the reactants that matches the rule atoms.
528
+ - A list of product substructures, each corresponding to a part of the products that matches the rule atoms.
529
+ - A list of reagents, included as is or as substructures, depending on the as_query_container flag.
530
+
531
+ The function processes the reaction to create substructures for reactants and products based on the rule atoms.
532
+ It also handles the inclusion of reagents based on the keep_reagents flag and converts these structures to query
533
+ containers if required.
534
+ """
535
+ reactant_substructures = [
536
+ reactant.substructure(rule_atoms.intersection(reactant.atoms_numbers))
537
+ for reactant in reaction.reactants
538
+ if rule_atoms.intersection(reactant.atoms_numbers)
539
+ ]
540
+
541
+ product_substructures = [
542
+ product.substructure(rule_atoms.intersection(product.atoms_numbers))
543
+ for product in reaction.products
544
+ if rule_atoms.intersection(product.atoms_numbers)
545
+ ]
546
+
547
+ reagents = []
548
+ if keep_reagents:
549
+ if as_query_container:
550
+ reagents = [
551
+ reagent.substructure(reagent, as_query=True)
552
+ for reagent in reaction.reagents
553
+ ]
554
+ else:
555
+ reagents = reaction.reagents
556
+
557
+ return reactant_substructures, product_substructures, reagents
558
+
559
+
560
+ def assemble_final_rule(
561
+ reactant_substructures,
562
+ product_substructures,
563
+ reagents,
564
+ meta_debug,
565
+ keep_metadata,
566
+ reaction,
567
+ ):
568
+ """
569
+ Assembles the final reaction rule from the provided substructures and metadata.
570
+
571
+ :param reactant_substructures: A list of substructures derived from the reactants of the reaction.
572
+ These substructures represent parts of reactants that are relevant to the rule.
573
+ :param product_substructures: A list of substructures derived from the products of the reaction.
574
+ These substructures represent parts of products that are relevant to the rule.
575
+ :param reagents: A list of reagents involved in the reaction. These may be included as-is or as substructures,
576
+ depending on earlier processing steps.
577
+ :param meta_debug: A dictionary containing additional metadata about the reaction, such as leaving and incoming groups.
578
+ :param keep_metadata: A boolean flag indicating whether to retain the metadata associated with the reaction in the rule.
579
+ :param reaction: The original reaction object (ReactionContainer) from which the rule is being created.
580
+
581
+ :return: A ReactionContainer object representing the assembled reaction rule. This container includes
582
+ the reactant and product substructures, reagents, and any additional metadata if keep_metadata is True.
583
+
584
+ This function brings together the various components of a reaction rule, including reactant and product substructures,
585
+ reagents, and metadata. It creates a comprehensive representation of the reaction rule, which can be used for further
586
+ processing or analysis.
587
+ """
588
+ rule_metadata = meta_debug if keep_metadata else {}
589
+ rule_metadata.update(reaction.meta if keep_metadata else {})
590
+
591
+ rule = ReactionContainer(
592
+ reactant_substructures, product_substructures, reagents, rule_metadata
593
+ )
594
+
595
+ if keep_metadata:
596
+ rule.name = reaction.name
597
+
598
+ rule.flush_cache()
599
+ return rule
600
+
601
+
602
+ def validate_rule(rule: ReactionContainer, reaction: ReactionContainer):
603
+ """
604
+ Validates a reaction rule by ensuring it can correctly generate the products from the reactants.
605
+
606
+ :param rule: The reaction rule to be validated. This is a ReactionContainer object representing a chemical reaction rule,
607
+ which includes the necessary information to perform a reaction.
608
+ :param reaction: The original reaction object (ReactionContainer) against which the rule is to be validated. This object
609
+ contains the actual reactants and products of the reaction.
610
+
611
+ :return: True if the rule correctly regenerates the original products from the reactants, False otherwise.
612
+
613
+ Note that no exception is raised on a mismatch: KeyError and IndexError raised by the reactor are caught
614
+ and treated as a failed validation.
615
+
616
+ The function uses a chemical reactor to simulate the reaction based on the provided rule. It then compares
617
+ the products generated by the simulation with the actual products of the reaction. If they match, the rule
618
+ is considered valid; otherwise the function returns False.
619
+ """
620
+ # Create a reactor with the given rule
621
+ reactor = Reactor(rule)
622
+ try:
623
+ for result_reaction in reactor(reaction.reactants):
624
+ result_products = []
625
+ for result_product in result_reaction.products:
626
+ tmp = result_product.copy()
627
+ try:
628
+ tmp.kekule()
629
+ if tmp.check_valence():
630
+ continue
631
+ except InvalidAromaticRing:
632
+ continue
633
+ result_products.append(result_product)
634
+ if set(reaction.products) == set(result_products) and len(
635
+ reaction.products
636
+ ) == len(result_products):
637
+ return True
638
+ except (KeyError, IndexError):
639
+ # KeyError - iteration over reactor is finished and products are different from the original reaction
640
+ # IndexError - mistake in __contract_ions, possibly problems with charges in rule?
641
+ return False
642
+
643
+
644
+ def sort_rules(
645
+ rules_stats: Dict[ReactionContainer, List[int]],
646
+ min_popularity: int = 3,
647
+ single_reactant_only: bool = True,
648
+ ) -> List[Tuple[ReactionContainer, List[int]]]:
649
+ """
650
+ Sorts reaction rules based on their popularity and validation status.
651
+
652
+ This function sorts the given rules according to their popularity (i.e., the number of times they have been
653
+ applied) and filters out rules that haven't passed reactor validation or are less popular than the specified
654
+ minimum popularity threshold.
655
+
656
+ :param rules_stats: A dictionary where each key is a reaction rule and the value is a list of integers.
657
+ Each integer represents an index where the rule was applied.
658
+ :type rules_stats: Dict[ReactionContainer, List[int]]
659
+
660
+ :param min_popularity: The minimum number of times a rule must be applied to be considered. Default is 3.
661
+ :type min_popularity: int
662
+
663
+ :param single_reactant_only: Whether to keep only reaction rules with a single reactant pattern
664
+ (len(rule.reactants) == 1). Default is True.
665
+
666
+ :return: A list of tuples, where each tuple contains a reaction rule and a list of indices representing
667
+ the rule's applications. The list is sorted in descending order of the rule's popularity.
668
+ :rtype: List[Tuple[ReactionContainer, List[int]]]
669
+ """
670
+ return sorted(
671
+ (
672
+ (rule, indices)
673
+ for rule, indices in rules_stats.items()
674
+ if len(indices) >= min_popularity
675
+ and rule.meta["reactor_validation"] == "passed"
676
+ and (not single_reactant_only or len(rule.reactants) == 1)
677
+ ),
678
+ key=lambda x: -len(x[1]),
679
+ )
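For orientation, a hedged usage sketch (not part of this commit) of how validate_rule and sort_rules could be chained when curating extracted rules; `reactions` and `extract_rule` are hypothetical stand-ins for the extraction pipeline defined earlier in this module.

# Hypothetical curation sketch: `reactions` is an iterable of mapped ReactionContainer
# objects and `extract_rule` stands in for the rule extraction routine of this module.
from collections import defaultdict

rules_stats = defaultdict(list)
for reaction_index, reaction in enumerate(reactions):
    rule = extract_rule(reaction)  # assumed single-rule extraction call
    # record the validation status that sort_rules later filters on
    rule.meta["reactor_validation"] = "passed" if validate_rule(rule, reaction) else "failed"
    rules_stats[rule].append(reaction_index)

# keep rules seen at least 3 times, validated by the reactor, with a single reactant query
popular_rules = sort_rules(rules_stats, min_popularity=3, single_reactant_only=True)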
SynTool/chem/reaction_rules/manual/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .decompositions import rules as d_rules
2
+ from .transformations import rules as t_rules
3
+
4
+ hardcoded_rules = t_rules + d_rules
5
+
6
+ __all__ = ["hardcoded_rules"]
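A brief, hedged illustration (not part of the commit) of how the aggregated hardcoded_rules might be merged with automatically extracted rules before planning; the pickle file name is a placeholder.

# Illustrative only: combining manual rules with extracted ones.
import pickle

from SynTool.chem.reaction_rules.manual import hardcoded_rules

with open("reaction_rules.pickle", "rb") as f:   # placeholder path to extracted rules
    extracted_rules = pickle.load(f)

all_rules = list(extracted_rules) + list(hardcoded_rules)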
SynTool/chem/reaction_rules/manual/decompositions.py ADDED
@@ -0,0 +1,415 @@
1
+ """
2
+ Module containing hardcoded decomposition reaction rules
3
+ """
4
+
5
+ from CGRtools import QueryContainer, ReactionContainer
6
+ from CGRtools.periodictable import ListElement
7
+
8
+ rules = []
9
+
10
+
11
+ def prepare():
12
+ """
13
+ Creates and returns three query containers and appends a reaction container to the "rules" list
14
+ """
15
+ q_ = QueryContainer()
16
+ p1_ = QueryContainer()
17
+ p2_ = QueryContainer()
18
+ rules.append(ReactionContainer((q_,), (p1_, p2_)))
19
+ return q_, p1_, p2_
20
+
21
+
22
+ # R-amide/ester formation
23
+ # [C](-[N,O;D23;Zs])(-[C])=[O]>>[A].[C]-[C](-[O])=[O]
24
+ q, p1, p2 = prepare()
25
+ q.add_atom('C')
26
+ q.add_atom('C')
27
+ q.add_atom('O')
28
+ q.add_atom(ListElement(['N', 'O']), hybridization=1, neighbors=(2, 3))
29
+ q.add_bond(1, 2, 1)
30
+ q.add_bond(2, 3, 2)
31
+ q.add_bond(2, 4, 1)
32
+
33
+ p1.add_atom('C')
34
+ p1.add_atom('C')
35
+ p1.add_atom('O')
36
+ p1.add_atom('O', _map=5)
37
+ p1.add_bond(1, 2, 1)
38
+ p1.add_bond(2, 3, 2)
39
+ p1.add_bond(2, 5, 1)
40
+
41
+ p2.add_atom('A', _map=4)
42
+
43
+ # acyl group addition to an aromatic carbon (Friedel-Crafts acylation)
44
+ # [C;Za]-[C](-[C])=[O]>>[C].[C]-[C](-[Cl])=[O]
45
+ q, p1, p2 = prepare()
46
+ q.add_atom('C')
47
+ q.add_atom('C')
48
+ q.add_atom('O')
49
+ q.add_atom('C', hybridization=4)
50
+ q.add_bond(1, 2, 1)
51
+ q.add_bond(2, 3, 2)
52
+ q.add_bond(2, 4, 1)
53
+
54
+ p1.add_atom('C')
55
+ p1.add_atom('C')
56
+ p1.add_atom('O')
57
+ p1.add_atom('Cl', _map=5)
58
+ p1.add_bond(1, 2, 1)
59
+ p1.add_bond(2, 3, 2)
60
+ p1.add_bond(2, 5, 1)
61
+
62
+ p2.add_atom('C', _map=4)
63
+
64
+ # Williamson reaction
65
+ # [C;Za]-[O]-[C;Zs;W0]>>[C]-[Br].[C]-[O]
66
+ q, p1, p2 = prepare()
67
+ q.add_atom('C', hybridization=4)
68
+ q.add_atom('O')
69
+ q.add_atom('C', hybridization=1, heteroatoms=1)
70
+ q.add_bond(1, 2, 1)
71
+ q.add_bond(2, 3, 1)
72
+
73
+ p1.add_atom('C')
74
+ p1.add_atom('O')
75
+ p1.add_bond(1, 2, 1)
76
+
77
+ p2.add_atom('C', _map=3)
78
+ p2.add_atom('Br')
79
+ p2.add_bond(3, 4, 1)
80
+
81
+ # Buchwald-Hartwig amination
82
+ # [N;D23;Zs;W0]-[C;Za]>>[C]-[Br].[N]
83
+ q, p1, p2 = prepare()
84
+ q.add_atom('N', heteroatoms=0, hybridization=1, neighbors=(2, 3))
85
+ q.add_atom('C', hybridization=4)
86
+ q.add_bond(1, 2, 1)
87
+
88
+ p1.add_atom('C', _map=2)
89
+ p1.add_atom('Br')
90
+ p1.add_bond(2, 3, 1)
91
+
92
+ p2.add_atom('N')
93
+
94
+ # alkylation of the imidazole imine nitrogen
95
+ # [C;r5](:[N;r5]-[C;Zs;W1]):[N;D2;r5]>>[C]-[Br].[N]:[C]:[N]
96
+ q, p1, p2 = prepare()
97
+ q.add_atom('N', rings_sizes=5)
98
+ q.add_atom('C', rings_sizes=5)
99
+ q.add_atom('N', rings_sizes=5, neighbors=2)
100
+ q.add_atom('C', hybridization=1, heteroatoms=(1, 2))
101
+ q.add_bond(1, 2, 4)
102
+ q.add_bond(2, 3, 4)
103
+ q.add_bond(1, 4, 1)
104
+
105
+ p1.add_atom('N')
106
+ p1.add_atom('C')
107
+ p1.add_atom('N')
108
+ p1.add_bond(1, 2, 4)
109
+ p1.add_bond(2, 3, 4)
110
+
111
+ p2.add_atom('C', _map=4)
112
+ p2.add_atom('Br')
113
+ p2.add_bond(4, 5, 1)
114
+
115
+ # Knoevenagel condensation (nitrile and carboxyl case)
116
+ # [C]=[C](-[C]#[N])-[C](-[O])=[O]>>[C]=[O].[C](-[C]#[N])-[C](-[O])=[O]
117
+ q, p1, p2 = prepare()
118
+ q.add_atom('C')
119
+ q.add_atom('C')
120
+ q.add_atom('C')
121
+ q.add_atom('N')
122
+ q.add_atom('C')
123
+ q.add_atom('O')
124
+ q.add_atom('O')
125
+ q.add_bond(1, 2, 2)
126
+ q.add_bond(2, 3, 1)
127
+ q.add_bond(3, 4, 3)
128
+ q.add_bond(2, 5, 1)
129
+ q.add_bond(5, 6, 2)
130
+ q.add_bond(5, 7, 1)
131
+
132
+ p1.add_atom('C', _map=2)
133
+ p1.add_atom('C')
134
+ p1.add_atom('N')
135
+ p1.add_atom('C')
136
+ p1.add_atom('O')
137
+ p1.add_atom('O')
138
+ p1.add_bond(2, 3, 1)
139
+ p1.add_bond(3, 4, 3)
140
+ p1.add_bond(2, 5, 1)
141
+ p1.add_bond(5, 6, 2)
142
+ p1.add_bond(5, 7, 1)
143
+
144
+ p2.add_atom('C', _map=1)
145
+ p2.add_atom('O', _map=8)
146
+ p2.add_bond(1, 8, 2)
147
+
148
+ # Knoevenagel condensation (double nitrile case)
149
+ # [C]=[C](-[C]#[N])-[C]#[N]>>[C]=[O].[C](-[C]#[N])-[C]#[N]
150
+ q, p1, p2 = prepare()
151
+ q.add_atom('C')
152
+ q.add_atom('C')
153
+ q.add_atom('C')
154
+ q.add_atom('N')
155
+ q.add_atom('C')
156
+ q.add_atom('N')
157
+ q.add_bond(1, 2, 2)
158
+ q.add_bond(2, 3, 1)
159
+ q.add_bond(3, 4, 3)
160
+ q.add_bond(2, 5, 1)
161
+ q.add_bond(5, 6, 3)
162
+
163
+ p1.add_atom('C', _map=2)
164
+ p1.add_atom('C')
165
+ p1.add_atom('N')
166
+ p1.add_atom('C')
167
+ p1.add_atom('N')
168
+ p1.add_bond(2, 3, 1)
169
+ p1.add_bond(3, 4, 3)
170
+ p1.add_bond(2, 5, 1)
171
+ p1.add_bond(5, 6, 3)
172
+
173
+ p2.add_atom('C', _map=1)
174
+ p2.add_atom('O', _map=8)
175
+ p2.add_bond(1, 8, 2)
176
+
177
+ # Knoevenagel condensation (double carboxyl case)
178
+ # [C]=[C](-[C](-[O])=[O])-[C](-[O])=[O]>>[C]=[O].[C](-[C](-[O])=[O])-[C](-[O])=[O]
179
+ q, p1, p2 = prepare()
180
+ q.add_atom('C')
181
+ q.add_atom('C')
182
+ q.add_atom('C')
183
+ q.add_atom('O')
184
+ q.add_atom('O')
185
+ q.add_atom('C')
186
+ q.add_atom('O')
187
+ q.add_atom('O')
188
+ q.add_bond(1, 2, 2)
189
+ q.add_bond(2, 3, 1)
190
+ q.add_bond(3, 4, 2)
191
+ q.add_bond(3, 5, 1)
192
+ q.add_bond(2, 6, 1)
193
+ q.add_bond(6, 7, 2)
194
+ q.add_bond(6, 8, 1)
195
+
196
+ p1.add_atom('C', _map=2)
197
+ p1.add_atom('C')
198
+ p1.add_atom('O')
199
+ p1.add_atom('O')
200
+ p1.add_atom('C')
201
+ p1.add_atom('O')
202
+ p1.add_atom('O')
203
+ p1.add_bond(2, 3, 1)
204
+ p1.add_bond(3, 4, 2)
205
+ p1.add_bond(3, 5, 1)
206
+ p1.add_bond(2, 6, 1)
207
+ p1.add_bond(6, 7, 2)
208
+ p1.add_bond(6, 8, 1)
209
+
210
+ p2.add_atom('C', _map=1)
211
+ p2.add_atom('O', _map=9)
212
+ p2.add_bond(1, 9, 2)
213
+
214
+ # heterocyclization with guanidine
215
+ # [c]((-[N;W0;Zs])@[n]@[c](-[N;D1])@[c;W0])@[n]@[c]-[O; D1]>>[C](-[N])(=[N])-[N].[C](#[N])-[C]-[C](-[O])=[O]
216
+ q, p1, p2 = prepare()
217
+ q.add_atom('C')
218
+ q.add_atom('N', heteroatoms=0, hybridization=1)
219
+ q.add_atom('N')
220
+ q.add_atom('C')
221
+ q.add_atom('N', neighbors=1)
222
+ q.add_atom('C', heteroatoms=0)
223
+ q.add_atom('N')
224
+ q.add_atom('C')
225
+ q.add_atom('O', neighbors=1)
226
+ q.add_bond(1, 2, 1)
227
+ q.add_bond(1, 3, 4)
228
+ q.add_bond(3, 4, 4)
229
+ q.add_bond(4, 5, 1)
230
+ q.add_bond(4, 6, 4)
231
+ q.add_bond(1, 7, 4)
232
+ q.add_bond(7, 8, 4)
233
+ q.add_bond(8, 9, 1)
234
+
235
+ p1.add_atom('C')
236
+ p1.add_atom('N')
237
+ p1.add_atom('N')
238
+ p1.add_atom('N', _map=7)
239
+ p1.add_bond(1, 2, 1)
240
+ p1.add_bond(1, 3, 2)
241
+ p1.add_bond(1, 7, 1)
242
+
243
+ p2.add_atom('C', _map=4)
244
+ p2.add_atom('N')
245
+ p2.add_atom('C')
246
+ p2.add_atom('C', _map=8)
247
+ p2.add_atom('O', _map=9)
248
+ p2.add_atom('O')
249
+ p2.add_bond(4, 5, 3)
250
+ p2.add_bond(4, 6, 1)
251
+ p2.add_bond(6, 8, 1)
252
+ p2.add_bond(8, 9, 2)
253
+ p2.add_bond(8, 10, 1)
254
+
255
+ # alkylation of amine
256
+ # [C]-[N]-[C]>>[C]-[N].[C]-[Br]
257
+ q, p1, p2 = prepare()
258
+ q.add_atom('C')
259
+ q.add_atom('N')
260
+ q.add_atom('C')
261
+ q.add_atom('C')
262
+ q.add_bond(1, 2, 1)
263
+ q.add_bond(2, 3, 1)
264
+ q.add_bond(2, 4, 1)
265
+
266
+ p1.add_atom('C')
267
+ p1.add_atom('N')
268
+ p1.add_atom('C')
269
+ p1.add_bond(1, 2, 1)
270
+ p1.add_bond(2, 3, 1)
271
+
272
+ p2.add_atom('C', _map=4)
273
+ p2.add_atom('Cl')
274
+ p2.add_bond(4, 5, 1)
275
+
276
+ # Synthesis of guanidines
277
+ #
278
+ q, p1, p2 = prepare()
279
+ q.add_atom('N')
280
+ q.add_atom('C')
281
+ q.add_atom('N', hybridization=1)
282
+ q.add_atom('N', hybridization=1)
283
+ q.add_bond(1, 2, 2)
284
+ q.add_bond(2, 3, 1)
285
+ q.add_bond(2, 4, 1)
286
+
287
+ p1.add_atom('N')
288
+ p1.add_atom('C')
289
+ p1.add_atom('N')
290
+ p1.add_bond(1, 2, 3)
291
+ p1.add_bond(2, 3, 1)
292
+
293
+ p2.add_atom('N', _map=4)
294
+
295
+ # Grignard reaction with nitrile
296
+ #
297
+ q, p1, p2 = prepare()
298
+ q.add_atom('C')
299
+ q.add_atom('C')
300
+ q.add_atom('O')
301
+ q.add_atom('C')
302
+ q.add_bond(1, 2, 1)
303
+ q.add_bond(2, 3, 2)
304
+ q.add_bond(2, 4, 1)
305
+
306
+ p1.add_atom('C')
307
+ p1.add_atom('C')
308
+ p1.add_atom('N')
309
+ p1.add_bond(1, 2, 1)
310
+ p1.add_bond(2, 3, 3)
311
+
312
+ p2.add_atom('C', _map=4)
313
+ p2.add_atom('Br')
314
+ p2.add_bond(4, 5, 1)
315
+
316
+ # Alkylation of alpha-carbon atom of nitrile
317
+ #
318
+ q, p1, p2 = prepare()
319
+ q.add_atom('N')
320
+ q.add_atom('C')
321
+ q.add_atom('C', neighbors=(3, 4))
322
+ q.add_atom('C', hybridization=1)
323
+ q.add_bond(1, 2, 3)
324
+ q.add_bond(2, 3, 1)
325
+ q.add_bond(3, 4, 1)
326
+
327
+ p1.add_atom('N')
328
+ p1.add_atom('C')
329
+ p1.add_atom('C')
330
+ p1.add_bond(1, 2, 3)
331
+ p1.add_bond(2, 3, 1)
332
+
333
+ p2.add_atom('C', _map=4)
334
+ p2.add_atom('Cl')
335
+ p2.add_bond(4, 5, 1)
336
+
337
+ # Gomberg-Bachmann reaction
338
+ #
339
+ q, p1, p2 = prepare()
340
+ q.add_atom('C', hybridization=4, heteroatoms=0)
341
+ q.add_atom('C', hybridization=4, heteroatoms=0)
342
+ q.add_bond(1, 2, 1)
343
+
344
+ p1.add_atom('C')
345
+ p1.add_atom('N', _map=3)
346
+ p1.add_bond(1, 3, 1)
347
+
348
+ p2.add_atom('C', _map=2)
349
+
350
+ # Cyclocondensation
351
+ #
352
+ q, p1, p2 = prepare()
353
+ q.add_atom('N', neighbors=2)
354
+ q.add_atom('C')
355
+ q.add_atom('C')
356
+ q.add_atom('C')
357
+ q.add_atom('N')
358
+ q.add_atom('C')
359
+ q.add_atom('C')
360
+ q.add_atom('O', neighbors=1)
361
+ q.add_bond(1, 2, 1)
362
+ q.add_bond(2, 3, 1)
363
+ q.add_bond(3, 4, 1)
364
+ q.add_bond(4, 5, 2)
365
+ q.add_bond(5, 6, 1)
366
+ q.add_bond(6, 7, 1)
367
+ q.add_bond(7, 8, 2)
368
+ q.add_bond(1, 7, 1)
369
+
370
+ p1.add_atom('N')
371
+ p1.add_atom('C')
372
+ p1.add_atom('C')
373
+ p1.add_atom('C')
374
+ p1.add_atom('O', _map=9)
375
+ p1.add_bond(1, 2, 1)
376
+ p1.add_bond(2, 3, 1)
377
+ p1.add_bond(3, 4, 1)
378
+ p1.add_bond(4, 9, 2)
379
+
380
+ p2.add_atom('N', _map=5)
381
+ p2.add_atom('C')
382
+ p2.add_atom('C')
383
+ p2.add_atom('O')
384
+ p2.add_atom('O', _map=10)
385
+ p2.add_bond(5, 6, 1)
386
+ p2.add_bond(6, 7, 1)
387
+ p2.add_bond(7, 8, 2)
388
+ p2.add_bond(7, 10, 1)
389
+
390
+ # heterocyclization of dicarboxylic acids
391
+ #
392
+ q, p1, p2 = prepare()
393
+ q.add_atom('C', rings_sizes=(5, 6))
394
+ q.add_atom('O')
395
+ q.add_atom(ListElement(['O', 'N']))
396
+ q.add_atom('C', rings_sizes=(5, 6))
397
+ q.add_atom('O')
398
+ q.add_bond(1, 2, 2)
399
+ q.add_bond(1, 3, 1)
400
+ q.add_bond(3, 4, 1)
401
+ q.add_bond(4, 5, 2)
402
+
403
+ p1.add_atom('C')
404
+ p1.add_atom('O')
405
+ p1.add_atom('O', _map=6)
406
+ p1.add_bond(1, 2, 2)
407
+ p1.add_bond(1, 6, 1)
408
+
409
+ p2.add_atom('C', _map=4)
410
+ p2.add_atom('O')
411
+ p2.add_atom('O', _map=7)
412
+ p2.add_bond(4, 5, 2)
413
+ p2.add_bond(4, 7, 1)
414
+
415
+ __all__ = ['rules']
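Each block above registers one retro decomposition: the single query is the product-side pattern and the two product queries are the expected precursors. Below is a minimal, hedged sketch (not part of the commit) of applying the first rule (R-amide/ester cleavage) with the CGRtools Reactor; the target SMILES is illustrative and assumes CGRtools 4.x.

# Minimal sketch (assumption: rules[0] is the R-amide/ester decomposition defined above).
from CGRtools import smiles
from CGRtools.reactor import Reactor

from SynTool.chem.reaction_rules.manual.decompositions import rules

acetanilide = smiles("CC(=O)Nc1ccccc1")
acetanilide.canonicalize()

reactor = Reactor(rules[0])                    # rule applied in the retro direction
for retro_reaction in reactor([acetanilide]):
    print(format(retro_reaction, "m"))         # reaction SMILES: target >> generated precursors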
SynTool/chem/reaction_rules/manual/transformations.py ADDED
@@ -0,0 +1,535 @@
1
+ """
2
+ Module containing hardcoded transformation reaction rules
3
+ """
4
+
5
+ from CGRtools import QueryContainer, ReactionContainer
6
+ from CGRtools.periodictable import ListElement
7
+
8
+ rules = []
9
+
10
+
11
+ def prepare():
12
+ """
13
+ Creates and returns two query containers and appends a reaction container to the "rules" list
14
+ """
15
+ q_ = QueryContainer()
16
+ p_ = QueryContainer()
17
+ rules.append(ReactionContainer((q_,), (p_,)))
18
+ return q_, p_
19
+
20
+
21
+ # aryl nitro reduction
22
+ # [C;Za;W1]-[N;D1]>>[O-]-[N+](-[C])=[O]
23
+ q, p = prepare()
24
+ q.add_atom('N', neighbors=1)
25
+ q.add_atom('C', hybridization=4, heteroatoms=1)
26
+ q.add_bond(1, 2, 1)
27
+
28
+ p.add_atom('N', charge=1)
29
+ p.add_atom('C')
30
+ p.add_atom('O', charge=-1)
31
+ p.add_atom('O')
32
+ p.add_bond(1, 2, 1)
33
+ p.add_bond(1, 3, 1)
34
+ p.add_bond(1, 4, 2)
35
+
36
+ # aryl nitration
37
+ # [O-]-[N+](=[O])-[C;Za;W12]>>[C]
38
+ q, p = prepare()
39
+ q.add_atom('N', charge=1)
40
+ q.add_atom('C', hybridization=4, heteroatoms=(1, 2))
41
+ q.add_atom('O', charge=-1)
42
+ q.add_atom('O')
43
+ q.add_bond(1, 2, 1)
44
+ q.add_bond(1, 3, 1)
45
+ q.add_bond(1, 4, 2)
46
+
47
+ p.add_atom('C', _map=2)
48
+
49
+ # Beckmann rearrangement (oxime -> amide)
50
+ # [C]-[N;D2]-[C]=[O]>>[O]-[N]=[C]-[C]
51
+ q, p = prepare()
52
+ q.add_atom('C')
53
+ q.add_atom('N', neighbors=2)
54
+ q.add_atom('O')
55
+ q.add_atom('C')
56
+ q.add_bond(1, 2, 1)
57
+ q.add_bond(1, 3, 2)
58
+ q.add_bond(2, 4, 1)
59
+
60
+ p.add_atom('C')
61
+ p.add_atom('N')
62
+ p.add_atom('O')
63
+ p.add_atom('C')
64
+ p.add_bond(1, 2, 2)
65
+ p.add_bond(2, 3, 1)
66
+ p.add_bond(1, 4, 1)
67
+
68
+ # aldehydes or ketones into oxime/imine reaction
69
+ # [C;Zd;W1]=[N]>>[C]=[O]
70
+ q, p = prepare()
71
+ q.add_atom('C', hybridization=2, heteroatoms=1)
72
+ q.add_atom('N')
73
+ q.add_bond(1, 2, 2)
74
+
75
+ p.add_atom('C')
76
+ p.add_atom('O', _map=3)
77
+ p.add_bond(1, 3, 2)
78
+
79
+ # addition of halogen atom into phenol ring (ortho)
80
+ # [C](-[Cl,F,Br,I;D1]):[C]-[O,N;Zs]>>[C](-[A]):[C]
81
+ q, p = prepare()
82
+ q.add_atom(ListElement(['O', 'N']), hybridization=1)
83
+ q.add_atom('C')
84
+ q.add_atom('C')
85
+ q.add_atom(ListElement(['Cl', 'F', 'Br', 'I']), neighbors=1)
86
+ q.add_bond(1, 2, 1)
87
+ q.add_bond(2, 3, 4)
88
+ q.add_bond(3, 4, 1)
89
+
90
+ p.add_atom('A')
91
+ p.add_atom('C')
92
+ p.add_atom('C')
93
+ p.add_bond(1, 2, 1)
94
+ p.add_bond(2, 3, 4)
95
+
96
+ # addition of halogen atom into phenol ring (para)
97
+ # [C](:[C]:[C]:[C]-[O,N;Zs])-[Cl,F,Br,I;D1]>>[A]-[C]:[C]:[C]:[C]
98
+ q, p = prepare()
99
+ q.add_atom(ListElement(['O', 'N']), hybridization=1)
100
+ q.add_atom('C')
101
+ q.add_atom('C')
102
+ q.add_atom('C')
103
+ q.add_atom('C')
104
+ q.add_atom(ListElement(['Cl', 'F', 'Br', 'I']), neighbors=1)
105
+ q.add_bond(1, 2, 1)
106
+ q.add_bond(2, 3, 4)
107
+ q.add_bond(3, 4, 4)
108
+ q.add_bond(4, 5, 4)
109
+ q.add_bond(5, 6, 1)
110
+
111
+ p.add_atom('A')
112
+ p.add_atom('C')
113
+ p.add_atom('C')
114
+ p.add_atom('C')
115
+ p.add_atom('C')
116
+ p.add_bond(1, 2, 1)
117
+ p.add_bond(2, 3, 4)
118
+ p.add_bond(3, 4, 4)
119
+ p.add_bond(4, 5, 4)
120
+
121
+ # hard reduction of Ar-ketones
122
+ # [C;Za]-[C;D2;Zs;W0]>>[C]-[C]=[O]
123
+ q, p = prepare()
124
+ q.add_atom('C', hybridization=4)
125
+ q.add_atom('C', hybridization=1, neighbors=2, heteroatoms=0)
126
+ q.add_bond(1, 2, 1)
127
+
128
+ p.add_atom('C')
129
+ p.add_atom('C')
130
+ p.add_atom('O')
131
+ p.add_bond(1, 2, 1)
132
+ p.add_bond(2, 3, 2)
133
+
134
+ # reduction of alpha-hydroxy pyridine
135
+ # [C;W1]:[N;H0;r6]>>[C](:[N])-[O]
136
+ q, p = prepare()
137
+ q.add_atom('C', heteroatoms=1)
138
+ q.add_atom('N', rings_sizes=6, hydrogens=0)
139
+ q.add_bond(1, 2, 4)
140
+
141
+ p.add_atom('C')
142
+ p.add_atom('N')
143
+ p.add_atom('O')
144
+ p.add_bond(1, 2, 4)
145
+ p.add_bond(1, 3, 1)
146
+
147
+ # Reduction of alkene
148
+ # [C]-[C;D23;Zs;W0]-[C;D123;Zs;W0]>>[C](-[C])=[C]
149
+ q, p = prepare()
150
+ q.add_atom('C')
151
+ q.add_atom('C', heteroatoms=0, neighbors=(2, 3), hybridization=1)
152
+ q.add_atom('C', heteroatoms=0, neighbors=(1, 2, 3), hybridization=1)
153
+ q.add_bond(1, 2, 1)
154
+ q.add_bond(2, 3, 1)
155
+
156
+ p.add_atom('C')
157
+ p.add_atom('C')
158
+ p.add_atom('C')
159
+ p.add_bond(1, 2, 1)
160
+ p.add_bond(2, 3, 2)
161
+
162
+ # Kolbe-Schmitt reaction
163
+ # [C](:[C]-[O;D1])-[C](=[O])-[O;D1]>>[C](-[O]):[C]
164
+ q, p = prepare()
165
+ q.add_atom('O', neighbors=1)
166
+ q.add_atom('C')
167
+ q.add_atom('C')
168
+ q.add_atom('C')
169
+ q.add_atom('O', neighbors=1)
170
+ q.add_atom('O')
171
+ q.add_bond(1, 2, 1)
172
+ q.add_bond(2, 3, 4)
173
+ q.add_bond(3, 4, 1)
174
+ q.add_bond(4, 5, 1)
175
+ q.add_bond(4, 6, 2)
176
+
177
+ p.add_atom('O')
178
+ p.add_atom('C')
179
+ p.add_atom('C')
180
+ p.add_bond(1, 2, 1)
181
+ p.add_bond(2, 3, 4)
182
+
183
+ # reduction of carboxylic acid
184
+ # [O;D1]-[C;D2]-[C]>>[C]-[C](-[O])=[O]
185
+ q, p = prepare()
186
+ q.add_atom('C')
187
+ q.add_atom('C', neighbors=2)
188
+ q.add_atom('O', neighbors=1)
189
+ q.add_bond(1, 2, 1)
190
+ q.add_bond(2, 3, 1)
191
+
192
+ p.add_atom('C')
193
+ p.add_atom('C')
194
+ p.add_atom('O')
195
+ p.add_atom('O')
196
+ p.add_bond(1, 2, 1)
197
+ p.add_bond(2, 3, 1)
198
+ p.add_bond(2, 4, 2)
199
+
200
+ # halogenation of alcohols
201
+ # [C;Zs]-[Cl,Br;D1]>>[C]-[O]
202
+ q, p = prepare()
203
+ q.add_atom('C', hybridization=1, heteroatoms=1)
204
+ q.add_atom(ListElement(['Cl', 'Br']), neighbors=1)
205
+ q.add_bond(1, 2, 1)
206
+
207
+ p.add_atom('C')
208
+ p.add_atom('O', _map=3)
209
+ p.add_bond(1, 3, 1)
210
+
211
+ # Kolbe nitrilation
212
+ # [N]#[C]-[C;Zs;W0]>>[Br]-[C]
213
+ q, p = prepare()
214
+ q.add_atom('C', heteroatoms=0, hybridization=1)
215
+ q.add_atom('C')
216
+ q.add_atom('N')
217
+ q.add_bond(1, 2, 1)
218
+ q.add_bond(2, 3, 3)
219
+
220
+ p.add_atom('C')
221
+ p.add_atom('Br', _map=4)
222
+ p.add_bond(1, 4, 1)
223
+
224
+ # Nitrile hydrolysis
225
+ # [O;D1]-[C]=[O]>>[N]#[C]
226
+ q, p = prepare()
227
+ q.add_atom('C')
228
+ q.add_atom('O', neighbors=1)
229
+ q.add_atom('O')
230
+ q.add_bond(1, 2, 1)
231
+ q.add_bond(1, 3, 2)
232
+
233
+ p.add_atom('C')
234
+ p.add_atom('N', _map=4)
235
+ p.add_bond(1, 4, 3)
236
+
237
+ # sulfamidation
238
+ # [c]-[S](=[O])(=[O])-[N]>>[c]
239
+ q, p = prepare()
240
+ q.add_atom('C', hybridization=4)
241
+ q.add_atom('S')
242
+ q.add_atom('O')
243
+ q.add_atom('O')
244
+ q.add_atom('N', neighbors=1)
245
+ q.add_bond(1, 2, 1)
246
+ q.add_bond(2, 3, 2)
247
+ q.add_bond(2, 4, 2)
248
+ q.add_bond(2, 5, 1)
249
+
250
+ p.add_atom('C')
251
+
252
+ # Ring expansion rearrangement
253
+ #
254
+ q, p = prepare()
255
+ q.add_atom('C')
256
+ q.add_atom('N')
257
+ q.add_atom('C', rings_sizes=6)
258
+ q.add_atom('C')
259
+ q.add_atom('O')
260
+ q.add_atom('C')
261
+ q.add_atom('C')
262
+ q.add_bond(1, 2, 1)
263
+ q.add_bond(2, 3, 1)
264
+ q.add_bond(3, 4, 1)
265
+ q.add_bond(4, 5, 2)
266
+ q.add_bond(3, 6, 1)
267
+ q.add_bond(4, 7, 1)
268
+
269
+ p.add_atom('C')
270
+ p.add_atom('N')
271
+ p.add_atom('C')
272
+ p.add_atom('C')
273
+ p.add_atom('O')
274
+ p.add_atom('C')
275
+ p.add_atom('C')
276
+ p.add_bond(1, 2, 1)
277
+ p.add_bond(2, 3, 2)
278
+ p.add_bond(3, 4, 1)
279
+ p.add_bond(4, 5, 1)
280
+ p.add_bond(4, 6, 1)
281
+ p.add_bond(4, 7, 1)
282
+
283
+ # hydrolysis of bromide alkyl
284
+ #
285
+ q, p = prepare()
286
+ q.add_atom('C', hybridization=1)
287
+ q.add_atom('O', neighbors=1)
288
+ q.add_bond(1, 2, 1)
289
+
290
+ p.add_atom('C')
291
+ p.add_atom('Br')
292
+ p.add_bond(1, 2, 1)
293
+
294
+ # Condensation of ketones/aldehydes and amines into imines
295
+ #
296
+ q, p = prepare()
297
+ q.add_atom('N', neighbors=(1, 2))
298
+ q.add_atom('C', neighbors=(2, 3), heteroatoms=1)
299
+ q.add_bond(1, 2, 2)
300
+
301
+ p.add_atom('C', _map=2)
302
+ p.add_atom('O')
303
+ p.add_bond(2, 3, 2)
304
+
305
+ # Halogenation of alkanes
306
+ #
307
+ q, p = prepare()
308
+ q.add_atom('C', hybridization=1)
309
+ q.add_atom(ListElement(['F', 'Cl', 'Br']))
310
+ q.add_bond(1, 2, 1)
311
+
312
+ p.add_atom('C')
313
+
314
+ # heterocyclization
315
+ #
316
+ q, p = prepare()
317
+ q.add_atom('N', heteroatoms=0, hybridization=1, neighbors=(2, 3))
318
+ q.add_atom('C', heteroatoms=2)
319
+ q.add_atom('N', heteroatoms=0, neighbors=2)
320
+ q.add_bond(1, 2, 1)
321
+ q.add_bond(2, 3, 2)
322
+
323
+ p.add_atom('N')
324
+ p.add_atom('C')
325
+ p.add_atom('N')
326
+ p.add_atom('O')
327
+ p.add_bond(1, 2, 1)
328
+ p.add_bond(2, 4, 2)
329
+
330
+ # Reduction of nitrile
331
+ #
332
+ q, p = prepare()
333
+ q.add_atom('N', neighbors=1)
334
+ q.add_atom('C')
335
+ q.add_atom('C', hybridization=1)
336
+ q.add_bond(1, 2, 1)
337
+ q.add_bond(2, 3, 1)
338
+
339
+ p.add_atom('N')
340
+ p.add_atom('C')
341
+ p.add_atom('C')
342
+ p.add_bond(1, 2, 3)
343
+ p.add_bond(2, 3, 1)
344
+
345
+ # SPECIAL CASE
346
+ # Reduction of nitrile into methylamine
347
+ #
348
+ q, p = prepare()
349
+ q.add_atom('C', neighbors=1)
350
+ q.add_atom('N', neighbors=2)
351
+ q.add_atom('C')
352
+ q.add_atom('C', hybridization=1)
353
+ q.add_bond(1, 2, 1)
354
+ q.add_bond(2, 3, 1)
355
+ q.add_bond(3, 4, 1)
356
+
357
+ p.add_atom('N', _map=2)
358
+ p.add_atom('C')
359
+ p.add_atom('C')
360
+ p.add_bond(2, 3, 3)
361
+ p.add_bond(3, 4, 1)
362
+
363
+ # methylation of amides
364
+ #
365
+ q, p = prepare()
366
+ q.add_atom('O')
367
+ q.add_atom('C')
368
+ q.add_atom('N')
369
+ q.add_atom('C', neighbors=1)
370
+ q.add_bond(1, 2, 2)
371
+ q.add_bond(2, 3, 1)
372
+ q.add_bond(3, 4, 1)
373
+
374
+ p.add_atom('O')
375
+ p.add_atom('C')
376
+ p.add_atom('N')
377
+ p.add_bond(1, 2, 2)
378
+ p.add_bond(2, 3, 1)
379
+
380
+ # hydrocyanation of alkenes
381
+ #
382
+ q, p = prepare()
383
+ q.add_atom('C', hybridization=1)
384
+ q.add_atom('C')
385
+ q.add_atom('C')
386
+ q.add_atom('N')
387
+ q.add_bond(1, 2, 1)
388
+ q.add_bond(2, 3, 1)
389
+ q.add_bond(3, 4, 3)
390
+
391
+ p.add_atom('C')
392
+ p.add_atom('C')
393
+ p.add_bond(1, 2, 2)
394
+
395
+ # decarboxylation (alpha atom of nitrile)
396
+ #
397
+ q, p = prepare()
398
+ q.add_atom('N')
399
+ q.add_atom('C')
400
+ q.add_atom('C', neighbors=2)
401
+ q.add_bond(1, 2, 3)
402
+ q.add_bond(2, 3, 1)
403
+
404
+ p.add_atom('N')
405
+ p.add_atom('C')
406
+ p.add_atom('C')
407
+ p.add_atom('C')
408
+ p.add_atom('O')
409
+ p.add_atom('O')
410
+ p.add_bond(1, 2, 3)
411
+ p.add_bond(2, 3, 1)
412
+ p.add_bond(3, 4, 1)
413
+ p.add_bond(4, 5, 2)
414
+ p.add_bond(4, 6, 1)
415
+
416
+ # Bischler-Napieralski reaction
417
+ #
418
+ q, p = prepare()
419
+ q.add_atom('C', rings_sizes=(6,))
420
+ q.add_atom('C', rings_sizes=(6,))
421
+ q.add_atom('N', rings_sizes=(6,), neighbors=2)
422
+ q.add_atom('C')
423
+ q.add_atom('C')
424
+ q.add_atom('C')
425
+ q.add_atom('O')
426
+ q.add_atom('O')
427
+ q.add_atom('C')
428
+ q.add_atom('O', neighbors=1)
429
+ q.add_bond(1, 2, 4)
430
+ q.add_bond(2, 3, 1)
431
+ q.add_bond(3, 4, 1)
432
+ q.add_bond(4, 5, 2)
433
+ q.add_bond(5, 6, 1)
434
+ q.add_bond(6, 7, 2)
435
+ q.add_bond(6, 8, 1)
436
+ q.add_bond(5, 9, 4)
437
+ q.add_bond(9, 10, 1)
438
+ q.add_bond(1, 9, 1)
439
+
440
+ p.add_atom('C')
441
+ p.add_atom('C')
442
+ p.add_atom('N')
443
+ p.add_atom('C')
444
+ p.add_atom('C')
445
+ p.add_atom('C')
446
+ p.add_atom('O')
447
+ p.add_atom('O')
448
+ p.add_atom('C')
449
+ p.add_atom('O')
450
+ p.add_atom('O')
451
+ p.add_bond(1, 2, 4)
452
+ p.add_bond(2, 3, 1)
453
+ p.add_bond(3, 4, 1)
454
+ p.add_bond(4, 5, 2)
455
+ p.add_bond(5, 6, 1)
456
+ p.add_bond(6, 7, 2)
457
+ p.add_bond(6, 8, 1)
458
+ p.add_bond(5, 9, 1)
459
+ p.add_bond(9, 10, 2)
460
+ p.add_bond(9, 11, 1)
461
+
462
+ # heterocyclization in Prins reaction
463
+ #
464
+ q, p = prepare()
465
+ q.add_atom('C')
466
+ q.add_atom('O')
467
+ q.add_atom('C')
468
+ q.add_atom(ListElement(['N', 'O']), neighbors=2)
469
+ q.add_atom('C')
470
+ q.add_atom('C')
471
+ q.add_bond(1, 2, 1)
472
+ q.add_bond(2, 3, 1)
473
+ q.add_bond(3, 4, 1)
474
+ q.add_bond(4, 5, 1)
475
+ q.add_bond(5, 6, 1)
476
+ q.add_bond(1, 6, 1)
477
+
478
+ p.add_atom('C')
479
+ p.add_atom('C', _map=5)
480
+ p.add_bond(1, 5, 2)
481
+
482
+ # recyclization of tetrahydropyran through ring opening and dehydration
483
+ #
484
+ q, p = prepare()
485
+ q.add_atom('C')
486
+ q.add_atom('C')
487
+ q.add_atom('C')
488
+ q.add_atom(ListElement(['N', 'O']))
489
+ q.add_atom('C')
490
+ q.add_atom('C')
491
+ q.add_bond(1, 2, 1)
492
+ q.add_bond(2, 3, 1)
493
+ q.add_bond(3, 4, 1)
494
+ q.add_bond(4, 5, 1)
495
+ q.add_bond(5, 6, 1)
496
+ q.add_bond(1, 6, 2)
497
+
498
+ p.add_atom('C')
499
+ p.add_atom('C')
500
+ p.add_atom('C')
501
+ p.add_atom('A')
502
+ p.add_atom('C')
503
+ p.add_atom('C')
504
+ p.add_atom('O')
505
+ p.add_bond(1, 2, 1)
506
+ p.add_bond(1, 7, 1)
507
+ p.add_bond(3, 7, 1)
508
+ p.add_bond(3, 4, 1)
509
+ p.add_bond(4, 5, 1)
510
+ p.add_bond(5, 6, 1)
511
+ p.add_bond(1, 6, 1)
512
+
513
+ # alkenes + H2O/HHal
514
+ #
515
+ q, p = prepare()
516
+ q.add_atom('C', hybridization=1)
517
+ q.add_atom('C', hybridization=1)
518
+ q.add_atom(ListElement(['O', 'F', 'Cl', 'Br', 'I']), neighbors=1)
519
+ q.add_bond(1, 2, 1)
520
+ q.add_bond(2, 3, 1)
521
+
522
+ p.add_atom('C')
523
+ p.add_atom('C')
524
+ p.add_bond(1, 2, 2)
525
+
526
+ # methylation of dimethylamines
527
+ #
528
+ q, p = prepare()
529
+ q.add_atom('C', neighbors=1)
530
+ q.add_atom('N', neighbors=3)
531
+ q.add_bond(1, 2, 1)
532
+
533
+ p.add_atom('N', _map=2)
534
+
535
+ __all__ = ['rules']
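The transformations follow the same prepare() pattern but keep a single precursor query. As a hedged illustration only (not part of the commit), one more retro transformation could be registered at the end of this module as sketched below; the chemistry shown (retro chlorination of an aromatic carbon) is an arbitrary example.

# Illustrative only: registering an additional retro transformation with prepare().
# Query: aryl chloride; precursor pattern: the bare aryl carbon (retro chlorination).
q, p = prepare()
q.add_atom('C', hybridization=4)
q.add_atom('Cl', neighbors=1)
q.add_bond(1, 2, 1)

p.add_atom('C')   # the halogen atom is dropped in the precursor pattern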
SynTool/chem/retron.py ADDED
@@ -0,0 +1,132 @@
1
+ """
2
+ Module containing a class Retron that represents a retron (extended molecule object) in the search tree
3
+ """
4
+
5
+ from CGRtools.containers import MoleculeContainer
6
+ from CGRtools.exceptions import InvalidAromaticRing
7
+
8
+ from SynTool.chem.utils import safe_canonicalization
9
+
10
+
11
+ class Retron:
12
+ """
13
+ Retron class is used to extend the molecule behavior needed for interaction with a tree in MCTS
14
+ """
15
+
16
+ def __init__(self, molecule: MoleculeContainer, canonicalize: bool = True):
17
+ """
18
+ It initializes a Retron object with a molecule container as a parameter.
19
+
20
+ :param molecule: The `molecule` parameter is of type `MoleculeContainer`.
21
+ :type molecule: MoleculeContainer
22
+ """
23
+ self._molecule = safe_canonicalization(molecule) if canonicalize else molecule
24
+ self._mapping = None
25
+ self.prev_retrons = []
26
+
27
+ def __len__(self):
28
+ """
29
+ Return the number of atoms in Retron.
30
+ """
31
+ return len(self._molecule)
32
+
33
+ def __hash__(self):
34
+ """
35
+ Returns the hash value of Retron.
36
+ """
37
+ return hash(self._molecule)
38
+
39
+ def __str__(self):
40
+ return str(self._molecule)
41
+
42
+ def __eq__(self, other: "Retron"):
43
+ """
44
+ The function checks if the current Retron is equal to another Retron of the same class.
45
+
46
+ :param other: The "other" parameter is a reference to another object of the same class "Retron". It is used to
47
+ compare the current Retron with the other Retron to check if they are equal
48
+ :type other: "Retron"
49
+ """
50
+ return self._molecule == other._molecule
51
+
52
+ def validate_molecule(self):
53
+ molecule = self._molecule.copy()
54
+ try:
55
+ molecule.kekule()
56
+ if molecule.check_valence():
57
+ return False
58
+ molecule.thiele()
59
+ except InvalidAromaticRing:
60
+ return False
61
+ return True
62
+
63
+ @property
64
+ def molecule(self) -> MoleculeContainer:
65
+ """
66
+ Returns a remapped copy of the molecule if a mapping is set, otherwise a plain copy.
67
+ """
68
+ if self._mapping:
69
+ remapped = self._molecule.copy()
70
+ try:
71
+ remapped = self._molecule.remap(self._mapping, copy=True)
72
+ except ValueError:
73
+ pass
74
+ return remapped
75
+ return self._molecule.copy()
76
+
77
+ def __repr__(self):
78
+ """
79
+ Returns a SMILES of the retron
80
+ """
81
+ return str(self._molecule)
82
+
83
+ def is_building_block(self, stock, min_mol_size=6):
84
+ """
85
+ The function checks if a Retron is a building block.
86
+
87
+ :param min_mol_size: molecules with this number of atoms or fewer are automatically treated as building blocks.
88
+ :param stock: The collection of building blocks. Each building block is represented by a SMILES string.
89
+ """
90
+ if len(self._molecule) <= min_mol_size:
91
+ return True
92
+ else:
93
+ return str(self._molecule) in stock
94
+
95
+
96
+ def compose_retrons(retrons: list = None, exclude_small=True, min_mol_size: int = 6
97
+ ) -> MoleculeContainer:
98
+ """
99
+ The function takes a list of retrons, excludes small retrons if specified, and composes them into a single molecule.
100
+ The composed molecule is used to predict the synthesizability score that characterizes the possible success of the
101
+ synthesis path containing the nodes with the given retrons.
102
+
103
+ :param retrons: The list of retrons to be composed.
104
+ :type retrons: list
105
+ :param exclude_small: The parameter that determines whether small retrons should be
106
+ excluded from the composition process. If `exclude_small` is set to `True`, only retrons with a length greater than
107
+ min_mol_size will be considered for composition.
108
+ :param min_mol_size: the minimum number of atoms a retron must have to be kept when exclude_small is True
109
+ :return: The composed retrons as a single MoleculeContainer object.
110
+ """
111
+
112
+ if len(retrons) == 1:
113
+ return retrons[0].molecule
114
+ elif len(retrons) > 1:
115
+ if exclude_small:
116
+ big_retrons = [
117
+ retron for retron in retrons if len(retron.molecule) > min_mol_size
118
+ ]
119
+ if big_retrons:
120
+ retrons = big_retrons
121
+ tmp_mol = retrons[0].molecule.copy()
122
+ transition_mapping = {}
123
+ for mol in retrons[1:]:
124
+ for n, atom in mol.molecule.atoms():
125
+ new_number = tmp_mol.add_atom(atom.atomic_symbol)
126
+ transition_mapping[n] = new_number
127
+ for atom, neighbor, bond in mol.molecule.bonds():
128
+ tmp_mol.add_bond(
129
+ transition_mapping[atom], transition_mapping[neighbor], bond
130
+ )
131
+ transition_mapping = {}
132
+ return tmp_mol
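A hedged usage sketch (not part of the commit) of the Retron API above; the SMILES strings and the in-memory stock are illustrative only.

# Illustrative only: wrapping molecules as retrons and composing them for evaluation.
from CGRtools import smiles

from SynTool.chem.retron import Retron, compose_retrons

stock = {"OB(O)c1ccccc1"}                       # toy building-block stock of SMILES strings

retrons = [Retron(smiles("O=C(O)c1ccccc1CN1CCOCC1")),
           Retron(smiles("BrCC(=O)OC"))]

print([r.is_building_block(stock, min_mol_size=6) for r in retrons])

# small retrons are dropped before composition when exclude_small is True
composed = compose_retrons(retrons, exclude_small=True, min_mol_size=6)
print(composed)                                 # single MoleculeContainer used for scoring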
SynTool/chem/utils.py ADDED
@@ -0,0 +1,227 @@
1
+ from typing import List, Iterable, Tuple, Union
2
+
3
+ from CGRtools.containers import MoleculeContainer, ReactionContainer, QueryContainer
4
+ from CGRtools.exceptions import InvalidAromaticRing
5
+
6
+
7
+ def query_to_mol(query: QueryContainer) -> MoleculeContainer:
8
+ """
9
+ Converts a QueryContainer object into a MoleculeContainer object.
10
+
11
+ :param query: A QueryContainer object representing the query structure.
12
+ :return: A MoleculeContainer object that replicates the structure of the query.
13
+ """
14
+ new_mol = MoleculeContainer()
15
+ for n, atom in query.atoms():
16
+ new_mol.add_atom(atom.atomic_symbol, n, charge=atom.charge, is_radical=atom.is_radical)
17
+ for i, j, bond in query.bonds():
18
+ new_mol.add_bond(i, j, int(bond))
19
+ return new_mol
20
+
21
+
22
+ def reaction_query_to_reaction(rule: ReactionContainer) -> ReactionContainer:
23
+ """
24
+ Converts a ReactionContainer object with query structures into a ReactionContainer with molecular structures.
25
+
26
+ :param rule: A ReactionContainer object where reactants and products are QueryContainer objects.
27
+ :return: A new ReactionContainer object where reactants and products are MoleculeContainer objects.
29
+ """
30
+ reactants = [query_to_mol(q) for q in rule.reactants]
31
+ products = [query_to_mol(q) for q in rule.products]
32
+ reagents = [query_to_mol(q) for q in rule.reagents] # Assuming reagents are also part of the rule
33
+ reaction = ReactionContainer(reactants, products, reagents, rule.meta)
34
+ reaction.name = rule.name
35
+ return reaction
36
+
37
+
38
+ def unite_molecules(molecules: Iterable[MoleculeContainer]) -> MoleculeContainer:
39
+ """
40
+ Unites a list of MoleculeContainer objects into a single MoleculeContainer.
41
+
42
+ This function takes multiple molecules and combines them into one larger molecule.
43
+ The first molecule in the list is taken as the base, and subsequent molecules are united with it sequentially.
44
+
45
+ :param molecules: A list of MoleculeContainer objects to be united.
46
+ :return: A single MoleculeContainer object representing the union of all input molecules.
47
+ """
48
+ new_mol = MoleculeContainer()
49
+ for mol in molecules:
50
+ new_mol = new_mol.union(mol)
51
+ return new_mol
52
+
53
+
54
+ def safe_canonicalization(molecule: MoleculeContainer):
55
+ """
56
+ Attempts to canonicalize a molecule, handling any exceptions.
57
+
58
+ This function tries to canonicalize the given molecule.
59
+ If the canonicalization process fails due to an InvalidAromaticRing exception,
60
+ it safely returns the original molecule.
61
+
62
+ :param molecule: The given molecule to be canonicalized.
63
+ :return: The canonicalized molecule if successful, otherwise the original molecule.
64
+ """
65
+ molecule._atoms = dict(sorted(molecule._atoms.items()))
66
+
67
+ tmp = molecule.copy()
68
+ try:
69
+ tmp.canonicalize()
70
+ return tmp
71
+ except InvalidAromaticRing:
72
+ return molecule
73
+
74
+
75
+ def split_molecules(molecules: Iterable, number_of_atoms: int) -> Tuple[List, List]:
76
+ """
77
+ Splits molecules according to the number of heavy atoms.
78
+
79
+ :param molecules: Iterable of molecules.
80
+ :param number_of_atoms: Threshold for splitting molecules.
81
+ :return: Tuple of lists containing "big" molecules and "small" molecules.
82
+ """
83
+ big_molecules, small_molecules = [], []
84
+ for molecule in molecules:
85
+ if len(molecule) > number_of_atoms:
86
+ big_molecules.append(molecule)
87
+ else:
88
+ small_molecules.append(molecule)
89
+
90
+ return big_molecules, small_molecules
91
+
92
+
93
+ def remove_small_molecules(
94
+ reaction: ReactionContainer,
95
+ number_of_atoms: int = 6,
96
+ small_molecules_to_meta: bool = True
97
+ ) -> Union[ReactionContainer, None]:
98
+ """
99
+ Processes a reaction by removing small molecules.
100
+
101
+ :param reaction: ReactionContainer object.
102
+ :param number_of_atoms: Molecules with the number of atoms equal to or below this will be removed.
103
+ :param small_molecules_to_meta: If True, deleted molecules are saved to meta.
104
+ :return: Processed ReactionContainer without small molecules.
105
+ """
106
+ new_reactants, small_reactants = split_molecules(reaction.reactants, number_of_atoms)
107
+ new_products, small_products = split_molecules(reaction.products, number_of_atoms)
108
+
109
+ if sum(len(mol) for mol in new_reactants) == 0 or sum(len(mol) for mol in new_products) == 0:
110
+ return None
111
+
112
+ new_reaction = ReactionContainer(new_reactants, new_products, reaction.reagents, reaction.meta)
113
+ new_reaction.name = reaction.name
114
+
115
+ if small_molecules_to_meta:
116
+ united_small_reactants = unite_molecules(small_reactants)
117
+ new_reaction.meta["small_reactants"] = str(united_small_reactants)
118
+
119
+ united_small_products = unite_molecules(small_products)
120
+ new_reaction.meta["small_products"] = str(united_small_products)
121
+
122
+ return new_reaction
123
+
124
+
125
+ def rebalance_reaction(reaction: ReactionContainer) -> ReactionContainer:
126
+ """
127
+ Rebalances the reaction by assembling CGR and then decomposing it. Works for all reactions for which the correct
128
+ CGR can be assembled
129
+
130
+ :param reaction: a reaction object
131
+ :return: a rebalanced reaction
132
+ """
133
+ tmp_reaction = ReactionContainer(reaction.reactants, reaction.products)
134
+ cgr = ~tmp_reaction
135
+ reactants, products = ~cgr
136
+ rebalanced_reaction = ReactionContainer(reactants.split(), products.split(), reaction.reagents, reaction.meta)
137
+ rebalanced_reaction.name = reaction.name
138
+ return rebalanced_reaction
139
+
140
+
141
+ def reverse_reaction(reaction: ReactionContainer) -> ReactionContainer:
142
+ """
143
+ Reverses given reaction
144
+
145
+ :param reaction: a reaction object
146
+ :return: the reversed reaction
147
+ """
148
+ reversed_reaction = ReactionContainer(reaction.products, reaction.reactants, reaction.reagents, reaction.meta)
149
+ reversed_reaction.name = reaction.name
150
+
151
+ return reversed_reaction
152
+
153
+
154
+ def remove_reagents(
155
+ reaction: ReactionContainer,
156
+ keep_reagents: bool = True,
157
+ reagents_max_size: int = 7
158
+ ) -> Union[ReactionContainer, None]:
159
+ """
160
+ Removes reagents (not changed molecules or molecules not involved in the reaction) from reactants and products
161
+
162
+ :param reaction: a reaction object
163
+ :param keep_reagents: if True, the reagents are written to ReactionContainer
164
+ :param reagents_max_size: max size of molecules that are called reagents, bigger are deleted
165
+ :return: cleaned reaction
166
+ """
167
+ not_changed_molecules = set(reaction.reactants).intersection(reaction.products)
168
+
169
+ cgr = ~reaction
170
+ center_atoms = set(cgr.center_atoms)
171
+
172
+ new_reactants = []
173
+ new_products = []
174
+ new_reagents = []
175
+
176
+ for molecule in reaction.reactants:
177
+ if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules:
178
+ new_reagents.append(molecule)
179
+ else:
180
+ new_reactants.append(molecule)
181
+
182
+ for molecule in reaction.products:
183
+ if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules:
184
+ new_reagents.append(molecule)
185
+ else:
186
+ new_products.append(molecule)
187
+
188
+ if sum(len(mol) for mol in new_reactants) == 0 or sum(len(mol) for mol in new_products) == 0:
189
+ return None
190
+
191
+ if keep_reagents:
192
+ new_reagents = {molecule for molecule in new_reagents if len(molecule) <= reagents_max_size}
193
+ else:
194
+ new_reagents = []
195
+
196
+ new_reaction = ReactionContainer(new_reactants, new_products, new_reagents, reaction.meta)
197
+ new_reaction.name = reaction.name
198
+
199
+ return new_reaction
200
+
201
+
202
+ def to_reaction_smiles_record(reaction):
203
+ if isinstance(reaction, str):
204
+ return reaction
205
+
206
+ reaction_record = [format(reaction, "m")]
207
+ sorted_meta = sorted(reaction.meta.items(), key=lambda x: x[0])
208
+ for _, meta_info in sorted_meta:
209
+ # meta_info = str(meta_info)
210
+ meta_info = '' # TODO decide what to do with meta
211
+ meta_info = ";".join(meta_info.split("\n"))
212
+ reaction_record.append(str(meta_info))
213
+ # return "\t".join(reaction_record) + "\n"
214
+ return "".join(reaction_record)
215
+
216
+
217
+ def cgr_from_rule(rule: ReactionContainer):
218
+ reaction_rule = reaction_query_to_reaction(rule)
219
+ cgr_rule = ~reaction_rule
220
+ return cgr_rule
221
+
222
+
223
+ def hash_from_rule(reaction_rule: ReactionContainer):
224
+ reactants_hash = tuple(sorted(hash(r) for r in reaction_rule.reactants))
225
+ reagents_hash = tuple(sorted(hash(r) for r in reaction_rule.reagents))
226
+ products_hash = tuple(sorted(hash(r) for r in reaction_rule.products))
227
+ return hash((reactants_hash, reagents_hash, products_hash))
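A hedged sketch (not part of the commit) chaining several of the helpers above on one atom-mapped reaction; the reaction SMILES and its mapping are illustrative.

# Illustrative only: a typical cleanup chain for a single atom-mapped reaction.
from CGRtools import smiles

from SynTool.chem.utils import remove_reagents, rebalance_reaction, reverse_reaction

# toy mapped Fischer esterification (mapping numbers are illustrative)
rxn = smiles("[CH3:1][C:2](=[O:3])[OH:4].[OH:5][CH3:6]>>"
             "[CH3:1][C:2](=[O:3])[O:5][CH3:6].[OH2:4]")

cleaned = remove_reagents(rxn, keep_reagents=True, reagents_max_size=7)
if cleaned is not None:
    balanced = rebalance_reaction(cleaned)   # CGR round-trip rebalances both sides
    retro = reverse_reaction(balanced)       # products become reactants for retro rules
    print(format(retro, "m"))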
SynTool/interfaces/__init__.py ADDED
File without changes
SynTool/interfaces/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (155 Bytes). View file
 
SynTool/interfaces/__pycache__/visualisation.cpython-310.pyc ADDED
Binary file (11 kB). View file
 
SynTool/interfaces/cli.py ADDED
@@ -0,0 +1,530 @@
1
+ """
2
+ Module containing command line scripts for training and planning modes
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ import yaml
8
+ import warnings
9
+ from pathlib import Path
10
+
11
+ import click
12
+ import gdown
13
+
14
+ from SynTool.chem.data.cleaning import reactions_cleaner
15
+ from SynTool.chem.data.filtering import filter_reactions, ReactionCheckConfig
16
+ from SynTool.utils.loading import standardize_building_blocks
17
+ from SynTool.chem.reaction_rules.extraction import extract_rules_from_reactions
18
+ from SynTool.mcts.search import tree_search
19
+ from SynTool.ml.training.reinforcement import run_reinforcement_tuning
20
+ from SynTool.ml.training.supervised import create_policy_dataset, run_policy_training
21
+ from SynTool.utils.config import ReinforcementConfig, TreeConfig, PolicyNetworkConfig, ValueNetworkConfig
22
+ from SynTool.utils.config import ReactionStandardizationConfig, RuleExtractionConfig
23
+ from SynTool.chem.data.mapping import remove_reagents_and_map_from_file
24
+
25
+ warnings.filterwarnings("ignore")
26
+
27
+
28
+ @click.group(name="syntool")
29
+ def syntool():
30
+ pass
31
+
32
+
33
+ @syntool.command(name="download_planning_data")
34
+ @click.option(
35
+ "--root_dir",
36
+ required=True,
37
+ type=click.Path(exists=True),
38
+ help="Path to the reaction database file that will be mapped.",
39
+ )
40
+ def download_planning_data_cli(root_dir='.'):
41
+ """
42
+ Downloads data for retrosynthesis planning
43
+ """
44
+ remote_id = "1ygq9BvQgH2Tq_rL72BvSOdASSSbPFTsL"
45
+ output = os.path.join(root_dir, "syntool_planning_data.zip")
46
+ #
47
+ gdown.download(output=output, id=remote_id, quiet=False)
48
+ shutil.unpack_archive(output, root_dir)
49
+ #
50
+ os.remove(output)
51
+
52
+
53
+ @syntool.command(name='download_training_data')
54
+ @click.option(
55
+ "--root_dir",
56
+ required=True,
57
+ type=click.Path(exists=True),
58
+ help="Path to the reaction database file that will be mapped.",
59
+ )
60
+ def download_training_data_cli(root_dir='.'):
61
+ """
62
+ Downloads data for retrosynthetic model training
63
+ """
64
+ remote_id = "1ckhO1l6xud0_bnC0rCDMkIlKRUMG_xs8"
65
+ output = os.path.join(root_dir, "syntool_training_data.zip")
66
+ #
67
+ gdown.download(output=output, id=remote_id, quiet=False)
68
+ shutil.unpack_archive(output, root_dir)
69
+ #
70
+ os.remove(output)
71
+
72
+
73
+ @syntool.command(name="building_blocks")
74
+ @click.option(
75
+ "--input",
76
+ "input_file",
77
+ required=True,
78
+ type=click.Path(exists=True),
79
+ help="Path to the reaction database file that will be mapped.",
80
+ )
81
+ @click.option(
82
+ "--output",
83
+ "output_file",
84
+ required=True,
85
+ type=click.Path(),
86
+ help="File where the results will be stored.",
87
+ )
88
+ def building_blocks_cli(input_file, output_file):
89
+ """
90
+ Standardizes building blocks
91
+ """
92
+
93
+ standardize_building_blocks(input_file=input_file, output_file=output_file)
94
+
95
+
96
+ @syntool.command(name="reaction_mapping")
97
+ @click.option(
98
+ "--config",
99
+ "config_path",
100
+ required=True,
101
+ type=click.Path(exists=True),
102
+ help="Path to the configuration file. This file contains settings for mapping and standardizing reactions.",
103
+ )
104
+ @click.option(
105
+ "--input",
106
+ "input_file",
107
+ required=True,
108
+ type=click.Path(exists=True),
109
+ help="Path to the reaction database file that will be mapped.",
110
+ )
111
+ @click.option(
112
+ "--output",
113
+ "output_file",
114
+ default=Path("reaction_data_standardized.smi"),
115
+ type=click.Path(),
116
+ help="File where the results will be stored.",
117
+ )
118
+ def reaction_mapping_cli(config_path, input_file, output_file):
119
+ """
120
+ Reaction data mapping
121
+ """
122
+
123
+ stand_config = ReactionStandardizationConfig.from_yaml(config_path)
124
+ remove_reagents_and_map_from_file(input_file=input_file, output_file=output_file, keep_reagent=stand_config.keep_reagents)
125
+
126
+
127
+ @syntool.command(name="reaction_standardizing")
128
+ @click.option(
129
+ "--config",
130
+ "config_path",
131
+ required=True,
132
+ type=click.Path(exists=True),
133
+ help="Path to the configuration file. This file contains settings for mapping and standardizing reactions.",
134
+ )
135
+ @click.option(
136
+ "--input",
137
+ "input_file",
138
+ required=True,
139
+ type=click.Path(exists=True),
140
+ help="Path to the reaction database file that will be mapped.",
141
+ )
142
+ @click.option(
143
+ "--output",
144
+ "output_file",
145
+ type=click.Path(),
146
+ help="File where the results will be stored.",
147
+ )
148
+ @click.option(
149
+ "--num_cpus",
150
+ default=8,
151
+ type=int,
152
+ help="Number of CPUs to use for processing. Defaults to 1.",
153
+ )
154
+ def reaction_standardizing_cli(config_path, input_file, output_file, num_cpus):
155
+ """
156
+ Standardizes reactions and remove duplicates
157
+ """
158
+
159
+ stand_config = ReactionStandardizationConfig.from_yaml(config_path)
160
+ reactions_cleaner(config=stand_config,
161
+ input_file=input_file,
162
+ output_file=output_file,
163
+ num_cpus=num_cpus)
164
+
165
+
166
+ @syntool.command(name="reaction_filtering")
167
+ @click.option(
168
+ "--config",
169
+ "config_path",
170
+ required=True,
171
+ type=click.Path(exists=True),
172
+ help="Path to the configuration file. This file contains settings for filtering reactions.",
173
+ )
174
+ @click.option(
175
+ "--input",
176
+ "input_file",
177
+ required=True,
178
+ type=click.Path(exists=True),
179
+ help="Path to the reaction database file that will be mapped.",
180
+ )
181
+ @click.option(
182
+ "--output",
183
+ "output_file",
184
+ default=Path("./"),
185
+ type=click.Path(),
186
+ help="File where the results will be stored.",
187
+ )
188
+ @click.option(
189
+ "--append_results",
190
+ is_flag=True,
191
+ default=False,
192
+ help="If set, results will be appended to existing files. By default, new files are created.",
193
+ )
194
+ @click.option(
195
+ "--batch_size",
196
+ default=100,
197
+ type=int,
198
+ help="Size of the batch for processing reactions. Defaults to 10.",
199
+ )
200
+ @click.option(
201
+ "--num_cpus",
202
+ default=8,
203
+ type=int,
204
+ help="Number of CPUs to use for processing. Defaults to 1.",
205
+ )
206
+ def reaction_filtering_cli(config_path,
207
+ input_file,
208
+ output_file,
209
+ append_results,
210
+ batch_size,
211
+ num_cpus):
212
+ """
213
+ Filters erroneous reactions
214
+ """
215
+ reaction_check_config = ReactionCheckConfig().from_yaml(config_path)
216
+ filter_reactions(
217
+ config=reaction_check_config,
218
+ reaction_database_path=input_file,
219
+ result_reactions_file_name=output_file,
220
+ append_results=append_results,
221
+ num_cpus=num_cpus,
222
+ batch_size=batch_size,
223
+ )
224
+
225
+
226
+ @syntool.command(name="rule_extracting")
227
+ @click.option(
228
+ "--config",
229
+ "config_path",
230
+ required=True,
231
+ type=click.Path(exists=True),
232
+ help="Path to the configuration file. This file contains settings for reaction rules extraction.",
233
+ )
234
+ @click.option(
235
+ "--input",
236
+ "input_file",
237
+ required=True,
238
+ type=click.Path(exists=True),
239
+ help="Path to the reaction database file that will be mapped.",
240
+ )
241
+ @click.option(
242
+ "--output",
243
+ "output_file",
244
+ required=True,
245
+ type=click.Path(),
246
+ help="File where the results will be stored.",
247
+ )
248
+ @click.option(
249
+ "--batch_size",
250
+ default=100,
251
+ type=int,
252
+ help="Size of the batch for processing reactions. Defaults to 10.",
253
+ )
254
+ @click.option(
255
+ "--num_cpus",
256
+ default=4,
257
+ type=int,
258
+ help="Number of CPUs to use for processing. Defaults to 4.",
259
+ )
260
+ def rule_extracting_cli(
261
+ config_path,
262
+ input_file,
263
+ output_file,
264
+ num_cpus,
265
+ batch_size,
266
+ ):
267
+ """
268
+ Extracts reaction rules
269
+ """
270
+
271
+ reaction_rule_config = RuleExtractionConfig.from_yaml(config_path)
272
+ extract_rules_from_reactions(config=reaction_rule_config,
273
+ reaction_file=input_file,
274
+ rules_file_name=output_file,
275
+ num_cpus=num_cpus,
276
+ batch_size=batch_size)
277
+
278
+
279
+ @syntool.command(name="supervised_ranking_policy_training")
280
+ @click.option(
281
+ "--config",
282
+ "config_path",
283
+ required=True,
284
+ type=click.Path(exists=True),
285
+ help="Path to the configuration file. This file contains settings for policy training.",
286
+ )
287
+ @click.option(
288
+ "--reaction_data",
289
+ required=True,
290
+ type=click.Path(exists=True),
291
+ help="Path to the reaction database file that will be mapped.",
292
+ )
293
+ @click.option(
294
+ "--reaction_rules",
295
+ required=True,
296
+ type=click.Path(exists=True),
297
+ help="Path to the reaction database file that will be mapped.",
298
+ )
299
+ @click.option(
300
+ "--results_dir",
301
+ default=Path("."),
302
+ type=click.Path(),
303
+ help="Root directory where the results will be stored.",
304
+ )
305
+ @click.option(
306
+ "--num_cpus",
307
+ default=4,
308
+ type=int,
309
+ help="Number of CPUs to use for processing. Defaults to 4.",
310
+ )
311
+ def supervised_ranking_policy_training_cli(config_path, reaction_data, reaction_rules, results_dir, num_cpus):
312
+ """
313
+ Trains ranking policy network
314
+ """
315
+
316
+ policy_config = PolicyNetworkConfig.from_yaml(config_path)
317
+
318
+ policy_dataset_file = os.path.join(results_dir, 'policy_dataset.dt')
319
+
320
+ datamodule = create_policy_dataset(reaction_rules_path=reaction_rules,
321
+ molecules_or_reactions_path=reaction_data,
322
+ output_path=policy_dataset_file,
323
+ dataset_type='ranking',
324
+ batch_size=policy_config.batch_size,
325
+ num_cpus=num_cpus)
326
+
327
+ run_policy_training(datamodule, config=policy_config, results_path=results_dir)
328
+
329
+
330
+ @syntool.command(name="supervised_filtering_policy_training")
331
+ @click.option(
332
+ "--config",
333
+ "config_path",
334
+ required=True,
335
+ type=click.Path(exists=True),
336
+ help="Path to the configuration file. This file contains settings for policy training.",
337
+ )
338
+ @click.option(
339
+ "--molecules_file",
340
+ required=True,
341
+ type=click.Path(exists=True),
342
+ help="Path to the molecules database file that will be mapped.",
343
+ )
344
+ @click.option(
345
+ "--reaction_rules",
346
+ required=True,
347
+ type=click.Path(exists=True),
348
+ help="Path to the reaction database file that will be mapped.",
349
+ )
350
+ @click.option(
351
+ "--results_dir",
352
+ default=Path("."),
353
+ type=click.Path(),
354
+ help="Root directory where the results will be stored.",
355
+ )
356
+ @click.option(
357
+ "--num_cpus",
358
+ default=8,
359
+ type=int,
360
+ help="Number of CPUs to use for processing. Defaults to 1.",
361
+ )
362
+ def supervised_filtering_policy_training_cli(config_path, molecules_file, reaction_rules, results_dir, num_cpus):
363
+ """
364
+ Trains filtering policy network
365
+ """
366
+
367
+ policy_config = PolicyNetworkConfig.from_yaml(config_path)
368
+
369
+ policy_dataset_file = os.path.join(results_dir, 'policy_dataset.ckpt')
370
+ datamodule = create_policy_dataset(reaction_rules_path=reaction_rules,
371
+ molecules_or_reactions_path=molecules_file,
372
+ output_path=policy_dataset_file,
373
+ dataset_type='filtering',
374
+ batch_size=policy_config.batch_size,
375
+ num_cpus=num_cpus)
376
+
377
+ run_policy_training(datamodule, config=policy_config, results_path=results_dir)
378
+
379
+
380
+ @syntool.command(name="reinforcement_value_network_training")
381
+ @click.option(
382
+ "--config",
383
+ required=True,
384
+ type=click.Path(exists=True),
385
+ help="Path to the configuration file. This file contains settings for policy training.",
386
+ )
387
+ @click.option(
388
+ "--targets",
389
+ required=True,
390
+ type=click.Path(exists=True),
391
+ help="Path to the configuration file. This file contains settings for policy training.",
392
+ )
393
+ @click.option(
394
+ "--reaction_rules",
395
+ required=True,
396
+ type=click.Path(exists=True),
397
+ help="Path to the configuration file. This file contains settings for policy training.",
398
+ )
399
+ @click.option(
400
+ "--building_blocks",
401
+ required=True,
402
+ type=click.Path(exists=True),
403
+ help="Path to the configuration file. This file contains settings for policy training.",
404
+ )
405
+ @click.option(
406
+ "--policy_network",
407
+ required=True,
408
+ type=click.Path(exists=True),
409
+ help="Path to the configuration file. This file contains settings for policy training.",
410
+ )
411
+ @click.option(
412
+ "--value_network",
413
+ default=None,
414
+ type=click.Path(exists=True),
415
+ help="Path to the configuration file. This file contains settings for policy training.",
416
+ )
417
+ @click.option(
418
+ "--results_dir",
419
+ default='.',
420
+ type=click.Path(exists=False),
421
+ help="Path to the configuration file. This file contains settings for policy training.",
422
+ )
423
+ def reinforcement_value_network_training_cli(config,
424
+ targets,
425
+ reaction_rules,
426
+ building_blocks,
427
+ policy_network,
428
+ value_network,
429
+ results_dir):
430
+ """
431
+ Trains value network with reinforcement learning
432
+ """
433
+
434
+ with open(config, "r") as file:
435
+ config = yaml.safe_load(file)
436
+
437
+ policy_config = PolicyNetworkConfig.from_dict(config['node_expansion'])
438
+ policy_config.weights_path = policy_network
439
+
440
+ value_config = ValueNetworkConfig.from_dict(config['value_network'])
441
+ if value_network is None:
442
+ value_config.weights_path = os.path.join(results_dir, 'weights', 'value_network.ckpt')
443
+
444
+ tree_config = TreeConfig.from_dict(config['tree'])
445
+ reinforce_config = ReinforcementConfig.from_dict(config['reinforcement'])
446
+
447
+ run_reinforcement_tuning(targets_path=targets,
448
+ tree_config=tree_config,
449
+ policy_config=policy_config,
450
+ value_config=value_config,
451
+ reinforce_config=reinforce_config,
452
+ reaction_rules_path=reaction_rules,
453
+ building_blocks_path=building_blocks,
454
+ results_root=results_dir)
455
+
456
+
457
+ @syntool.command(name="planning")
458
+ @click.option(
459
+ "--config",
460
+ "config_path",
461
+ required=True,
462
+ type=click.Path(exists=True),
463
+ help="Path to the configuration file. This file contains settings for policy training.",
464
+ )
465
+ @click.option(
466
+ "--targets",
467
+ required=True,
468
+ type=click.Path(exists=True),
469
+ help="Path to the configuration file. This file contains settings for policy training.",
470
+ )
471
+ @click.option(
472
+ "--reaction_rules",
473
+ required=True,
474
+ type=click.Path(exists=True),
475
+ help="Path to the configuration file. This file contains settings for policy training.",
476
+ )
477
+ @click.option(
478
+ "--building_blocks",
479
+ required=True,
480
+ type=click.Path(exists=True),
481
+ help="Path to the configuration file. This file contains settings for policy training.",
482
+ )
483
+ @click.option(
484
+ "--policy_network",
485
+ required=True,
486
+ type=click.Path(exists=True),
487
+ help="Path to the configuration file. This file contains settings for policy training.",
488
+ )
489
+ @click.option(
490
+ "--value_network",
491
+ default=None,
492
+ type=click.Path(exists=True),
493
+ help="Path to the configuration file. This file contains settings for policy training.",
494
+ )
495
+ @click.option(
496
+ "--results_dir",
497
+ default='.',
498
+ type=click.Path(exists=False),
499
+ help="Path to the configuration file. This file contains settings for policy training.",
500
+ )
501
+ def planning_cli(config_path,
502
+ targets,
503
+ reaction_rules,
504
+ building_blocks,
505
+ policy_network,
506
+ value_network,
507
+ results_dir):
508
+ """
509
+ Runs retrosynthesis planning
510
+ """
511
+
512
+ with open(config_path, "r") as file:
513
+ config = yaml.safe_load(file)
514
+
515
+ tree_config = TreeConfig.from_dict({**config['tree'], **config['node_evaluation']})
516
+ policy_config = PolicyNetworkConfig.from_dict({**config['node_expansion'], **{'weights_path': policy_network}})
517
+
518
+ tree_search(targets_path=targets,
519
+ tree_config=tree_config,
520
+ policy_config=policy_config,
521
+ reaction_rules_path=reaction_rules,
522
+ building_blocks_path=building_blocks,
523
+ value_weights_path=value_network,
524
+ results_root=results_dir)
525
+
526
+
527
+ if __name__ == '__main__':
528
+ syntool()
529
+
530
+
SynTool/interfaces/cli.py.bk ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ Module containing commands line scripts for training and planning mode
3
+ """
4
+
5
+ import warnings
6
+ import os
7
+ import shutil
8
+ from pathlib import Path
9
+ import click
10
+ import gdown
11
+ from datetime import datetime
12
+
13
+ from Syntool.chem.reaction_rules.extraction import extract_rules_from_reactions
14
+ from Syntool.chem.data.cleaning import reactions_cleaner
15
+ from Syntool.chem.data.mapping import remove_reagents_and_map_from_file
16
+ from Syntool.chem.loading import standardize_building_blocks
17
+ from Syntool.ml.training import create_policy_dataset, run_policy_training
18
+ from Syntool.ml.training.reinforcement import run_self_tuning
19
+ from Syntool.ml.networks.policy import PolicyNetworkConfig
20
+ from Syntool.utils.config import read_planning_config, read_training_config, TreeConfig
21
+ from Syntool.mcts.search import tree_search
22
+
23
+ from Syntool.chem.data.filtering import (
24
+ filter_reactions,
25
+ ReactionCheckConfig,
26
+ CCRingBreakingConfig,
27
+ WrongCHBreakingConfig,
28
+ CCsp3BreakingConfig,
29
+ DynamicBondsConfig,
30
+ MultiCenterConfig,
31
+ NoReactionConfig,
32
+ SmallMoleculesConfig,
33
+ )
34
+
35
+ warnings.filterwarnings("ignore")
36
+ main = click.Group()
37
+
38
+
39
+ @main.command(name='planning_data')
40
+ def planning_data_cli():
41
+ """
42
+ Downloads a file from Google Drive using its remote ID, saves it as a zip file, extracts the contents,
43
+ and then deletes the zip file
44
+ """
45
+ remote_id = '1c5YJDT-rP1ZvFA-ELmPNTUj0b8an4yFf'
46
+ output = 'synto_planning_data.zip'
47
+ #
48
+ gdown.download(output=output, id=remote_id, quiet=True)
49
+ shutil.unpack_archive(output, './')
50
+ #
51
+ os.remove(output)
52
+
53
+
54
+ @main.command(name='training_data')
55
+ def training_data_cli():
56
+ """
57
+ Downloads a file from Google Drive using its remote ID, saves it as a zip file, extracts the contents,
58
+ and then deletes the zip file
59
+ """
60
+ remote_id = '1r4I7OskGvzg-zxYNJ7WVYpVR2HSYW10N'
61
+ output = 'synto_training_data.zip'
62
+ #
63
+ gdown.download(output=output, id=remote_id, quiet=True)
64
+ shutil.unpack_archive(output, './')
65
+ #
66
+ os.remove(output)
67
+
68
+
69
+ @main.command(name='syntool_planning')
70
+ @click.option("--config",
71
+ "config_path",
72
+ required=True,
73
+ help="Path to the config YAML molecules_path.",
74
+ type=click.Path(exists=True, path_type=Path),
75
+ )
76
+ def syntool_planning_cli(config_path):
77
+ """
78
+ Launches tree search for the given target molecules and stores search statistics and found retrosynthetic paths
79
+
80
+ :param config_path: The path to the configuration file that contains the settings and parameters for the tree search
81
+ """
82
+ config = read_planning_config(config_path)
83
+ config['Tree']['silent'] = True
84
+
85
+ # standardize building blocks
86
+ if config['InputData']['standardize_building_blocks']:
87
+ print('STANDARDIZE BUILDING BLOCKS ...')
88
+ standardize_building_blocks(config['InputData']['building_blocks_path'],
89
+ config['InputData']['building_blocks_path'])
90
+ # run planning
91
+ print('\nRUN PLANNING ...')
92
+ tree_config = TreeConfig.from_dict(config['Tree'])
93
+ tree_search(targets=config['General']['targets_path'],
94
+ tree_config=tree_config,
95
+ reaction_rules_path=config['InputData']['reaction_rules_path'],
96
+ building_blocks_path=config['InputData']['building_blocks_path'],
97
+ policy_weights_path=config['PolicyNetwork']['weights_path'],
98
+ value_weights_paths=config['ValueNetwork']['weights_path'],
99
+ results_root=config['General']['results_root'])
100
+
101
+
102
+ @main.command(name='syntool_training')
103
+ @click.option(
104
+ "--config",
105
+ "config_path",
106
+ required=True,
107
+ help="Path to the config YAML file.",
108
+ type=click.Path(exists=True, path_type=Path)
109
+ )
110
+ def syntool_training_cli(config_path):
111
+
112
+ # read training config
113
+ print('READ CONFIG ...')
114
+ config = read_training_config(config_path)
115
+ print('Config is read')
116
+
117
+ reaction_data_file = config['InputData']['reaction_data_path']
118
+
119
+ # reaction data mapping
120
+ startTime0 = datetime.now()
121
+ data_output_folder = os.path.join(config['General']['results_root'], 'reaction_data')
122
+ Path(data_output_folder).mkdir(parents=True, exist_ok=True)
123
+ mapped_data_file = os.path.join(data_output_folder, 'reaction_data_mapped.smi')
124
+ if config['DataCleaning']['map_reactions']:
125
+ print('\nMAP REACTION DATA ...')
126
+
127
+ remove_reagents_and_map_from_file(input_file=config['InputData']['reaction_data_path'],
128
+ output_file=mapped_data_file)
129
+
130
+ reaction_data_file = mapped_data_file
131
+ print("remove_reagents_and_map_from_file:", datetime.now() - startTime0)
132
+
133
+ # reaction data cleaning
134
+ startTime0 = datetime.now()
135
+ cleaned_data_file = os.path.join(data_output_folder, 'reaction_data_cleaned.rdf')
136
+ if config['DataCleaning']['clean_reactions']:
137
+ print('\nCLEAN REACTION DATA ...')
138
+
139
+ reactions_cleaner(input_file=reaction_data_file,
140
+ output_file=cleaned_data_file,
141
+ num_cpus=config['General']['num_cpus'])
142
+
143
+ reaction_data_file = cleaned_data_file
144
+ print("reactions_cleaner:", datetime.now() - startTime0)
145
+
146
+ # reactions data filtering
147
+ startTime0 = datetime.now()
148
+ if config['DataCleaning']['filter_reactions']:
149
+ print('\nFILTER REACTION DATA ...')
150
+ #
151
+ filtration_config = ReactionCheckConfig(
152
+ remove_small_molecules=False,
153
+ small_molecules_config=SmallMoleculesConfig(limit=6),
154
+ dynamic_bonds_config=DynamicBondsConfig(min_bonds_number=1, max_bonds_number=6),
155
+ no_reaction_config=NoReactionConfig(),
156
+ multi_center_config=MultiCenterConfig(),
157
+ wrong_ch_breaking_config=WrongCHBreakingConfig(),
158
+ cc_sp3_breaking_config=CCsp3BreakingConfig(),
159
+ cc_ring_breaking_config=CCRingBreakingConfig()
160
+ )
161
+
162
+ filtered_data_file = os.path.join(data_output_folder, 'reaction_data_filtered.rdf')
163
+ filter_reactions(config=filtration_config,
164
+ reaction_database_path=reaction_data_file,
165
+ result_directory_path=data_output_folder,
166
+ result_reactions_file_name='reaction_data_filtered',
167
+ num_cpus=config['General']['num_cpus'],
168
+ batch_size=100)
169
+
170
+ reaction_data_file = filtered_data_file
171
+ print("filter_reactions:", datetime.now() - startTime0)
172
+
173
+ # standardize building blocks
174
+ startTime0 = datetime.now()
175
+ if config['DataCleaning']['standardize_building_blocks']:
176
+ print('\nSTANDARDIZE BUILDING BLOCKS ...')
177
+
178
+ standardize_building_blocks(config['InputData']['building_blocks_path'],
179
+ config['InputData']['building_blocks_path'])
180
+ print("standardize_building_blocks:", datetime.now() - startTime0)
181
+
182
+ # reaction rules extraction
183
+ startTime0 = datetime.now()
184
+ print('\nEXTRACT REACTION RULES ...')
185
+
186
+ rules_output_folder = os.path.join(config['General']['results_root'], 'reaction_rules')
187
+ Path(rules_output_folder).mkdir(parents=True, exist_ok=True)
188
+ reaction_rules_path = os.path.join(rules_output_folder, 'reaction_rules_filtered.pickle')
189
+ config['InputData']['reaction_rules_path'] = reaction_rules_path
190
+
191
+ extract_rules_from_reactions(config=config,
192
+ reaction_file=reaction_data_file,
193
+ results_root=rules_output_folder,
194
+ num_cpus=config['General']['num_cpus'])
195
+ print("extract_rules_from_reactions:", datetime.now() - startTime0)
196
+
197
+ # create policy network dataset
198
+ startTime0 = datetime.now()
199
+ print('\nCREATE POLICY NETWORK DATASET ...')
200
+ policy_output_folder = os.path.join(config['General']['results_root'], 'policy_network')
201
+ Path(policy_output_folder).mkdir(parents=True, exist_ok=True)
202
+ policy_data_file = os.path.join(policy_output_folder, 'policy_dataset.pt')
203
+
204
+ if config['PolicyNetwork']['policy_type'] == 'ranking':
205
+ molecules_or_reactions_path = reaction_data_file
206
+ elif config['PolicyNetwork']['policy_type'] == 'filtering':
207
+ molecules_or_reactions_path = config['InputData']['policy_data_path']
208
+ else:
209
+ raise ValueError(
210
+ "Invalid policy_type. Allowed values are 'ranking', 'filtering'."
211
+ )
212
+
213
+ datamodule = create_policy_dataset(reaction_rules_path=reaction_rules_path,
214
+ molecules_or_reactions_path=molecules_or_reactions_path,
215
+ output_path=policy_data_file,
216
+ dataset_type=config['PolicyNetwork']['policy_type'],
217
+ batch_size=config['PolicyNetwork']['batch_size'],
218
+ num_cpus=config['General']['num_cpus'])
219
+ print("datamodule:", datetime.now() - startTime0)
220
+
221
+ # train policy network
222
+ startTime0 = datetime.now()
223
+ print('\nTRAIN POLICY NETWORK ...')
224
+ policy_config = PolicyNetworkConfig.from_dict(config['PolicyNetwork'])
225
+ run_policy_training(datamodule, config=policy_config, results_path=policy_output_folder)
226
+ config['PolicyNetwork']['weights_path'] = os.path.join(policy_output_folder, 'weights', 'policy_network.ckpt')
227
+ print("run_policy_training:", datetime.now() - startTime0)
228
+
229
+ # self-tuning value network training
230
+ startTime0 = datetime.now()
231
+ print('\nTRAIN VALUE NETWORK ...')
232
+ value_output_folder = os.path.join(config['General']['results_root'], 'value_network')
233
+ Path(value_output_folder).mkdir(parents=True, exist_ok=True)
234
+ config['ValueNetwork']['weights_path'] = os.path.join(value_output_folder, 'weights', 'value_network.ckpt')
235
+
236
+ run_self_tuning(config, results_root=value_output_folder)
237
+ print("run_self_tuning:", datetime.now() - startTime0)
238
+
239
+
240
+ if __name__ == '__main__':
241
+ main()
SynTool/interfaces/visualisation.py ADDED
@@ -0,0 +1,346 @@
1
+ """
2
+ Module containing functions for analysis and visualization of the built search tree
3
+ """
4
+
5
+ from itertools import count, islice
6
+
7
+ from CGRtools.containers import MoleculeContainer
8
+
9
+ from SynTool import Tree
10
+ from SynTool.utils import path_type
11
+
12
+
13
+ def get_child_nodes(tree, molecule, graph):
14
+ nodes = []
15
+ try:
16
+ graph[molecule]
17
+ except KeyError:
18
+ return []
19
+ for retron in graph[molecule]:
20
+ temp_obj = {
21
+ "smiles": str(retron),
22
+ "type": "mol",
23
+ "in_stock": str(retron) in tree.building_blocks,
24
+ }
25
+ node = get_child_nodes(tree, retron, graph)
26
+ if node:
27
+ temp_obj["children"] = [node]
28
+ nodes.append(temp_obj)
29
+ return {"type": "reaction", "children": nodes}
30
+
31
+
32
+ def extract_routes(tree, extended=False):
33
+ """
34
+ Extracts the retrosynthetic routes found in the search tree and returns them as a list of
+ nested dictionaries (one dictionary per found route).
+
+ :param tree: the built search tree
+ :param extended: if True, routes are extracted from all solved nodes instead of only the winning nodes
+ :return: A list of dictionaries. Each dictionary contains the target SMILES, a list of children, and a
+ boolean indicating whether the target is in building_blocks.
39
+ """
40
+ target = tree.nodes[1].retrons_to_expand[0].molecule
41
+ target_in_stock = tree.nodes[1].curr_retron.is_building_block(tree.building_blocks)
42
+ # Append encoded routes to list
43
+ paths_block = []
44
+ winning_nodes = []
45
+ if extended:
46
+ # Gather paths
47
+ for i, node in tree.nodes.items():
48
+ if node.is_solved():
49
+ winning_nodes.append(i)
50
+ else:
51
+ winning_nodes = tree.winning_nodes
52
+ if winning_nodes:
53
+ for winning_node in winning_nodes:
54
+ # Create graph for route
55
+ nodes = tree.path_to_node(winning_node)
56
+ graph, pred = {}, {}
57
+ for before, after in zip(nodes, nodes[1:]):
58
+ before = before.curr_retron.molecule
59
+ graph[before] = after = [x.molecule for x in after.new_retrons]
60
+ for x in after:
61
+ pred[x] = before
62
+
63
+ paths_block.append({"type": "mol", "smiles": str(target),
64
+ "in_stock": target_in_stock,
65
+ "children": [get_child_nodes(tree, target, graph)]})
66
+ else:
67
+ paths_block = [{"type": "mol", "smiles": str(target), "in_stock": target_in_stock, "children": []}]
68
+ return paths_block
69
+
70
+
71
+ def path_graph(tree, node: int) -> str:
72
+ """
73
+ Visualizes reaction path
74
+
75
+ :param tree: the built search tree
+ :param node: node id
76
+ :type node: int
77
+ :return: The SVG string.
78
+ """
79
+ nodes = tree.path_to_node(node)
80
+ # Set up node_id types for different box colors
81
+ for node in nodes:
82
+ for retron in node.new_retrons:
83
+ retron._molecule.meta["status"] = "instock" if retron.is_building_block(
84
+ tree.building_blocks) else "molecule"
85
+ nodes[0].curr_retron._molecule.meta["status"] = "target"
86
+ # Box colors
87
+ box_colors = {"target": "#98EEFF", # 152, 238, 255
88
+ "mulecule": "#F0AB90", # 240, 171, 144
89
+ "instock": "#9BFAB3", # 155, 250, 179
90
+ }
91
+
92
+ # first column is target
93
+ # second column are first new retrons_to_expand
94
+ columns = [[nodes[0].curr_retron.molecule], [x.molecule for x in nodes[1].new_retrons], ]
95
+ pred = {x: 0 for x in range(1, len(columns[1]) + 1)}
96
+ cx = [n for n, x in enumerate(nodes[1].new_retrons, 1) if not x.is_building_block(tree.building_blocks)]
97
+ size = len(cx)
98
+ nodes = iter(nodes[2:])
99
+ cy = count(len(columns[1]) + 1)
100
+ while size:
101
+ layer = []
102
+ for s in islice(nodes, size):
103
+ n = cx.pop(0)
104
+ for x in s.new_retrons:
105
+ layer.append(x)
106
+ m = next(cy)
107
+ if not x.is_building_block(tree.building_blocks):
108
+ cx.append(m)
109
+ pred[m] = n
110
+ size = len(cx)
111
+ columns.append([x.molecule for x in layer])
112
+
113
+ columns = [columns[::-1] for columns in columns[::-1]] # Reverse array to make retrosynthetic graph
114
+ pred = tuple( # Change dict to tuple to make multiple retrons_to_expand available
115
+ (abs(source - len(pred)), abs(target - len(pred))) for target, source in pred.items())
116
+
117
+ # now we have columns for visualizing
118
+ # lets start recalculate XY
119
+ x_shift = 0.0
120
+ c_max_x = 0.0
121
+ c_max_y = 0.0
122
+ render = []
123
+ cx = count()
124
+ cy = count()
125
+ arrow_points = {}
126
+ for ms in columns:
127
+ heights = []
128
+ for m in ms:
129
+ m.clean2d()
130
+ # X-shift for target
131
+ min_x = min(x for x, y in m._plane.values()) - x_shift
132
+ min_y = min(y for x, y in m._plane.values())
133
+ m._plane = {n: (x - min_x, y - min_y) for n, (x, y) in m._plane.items()}
134
+ max_x = max(x for x, y in m._plane.values())
135
+ if max_x > c_max_x:
136
+ c_max_x = max_x
137
+ arrow_points[next(cx)] = [x_shift, max_x]
138
+ heights.append(max(y for x, y in m._plane.values()))
139
+
140
+ x_shift = c_max_x + 5.0 # between columns gap
141
+ # calculate Y-shift
142
+ y_shift = sum(heights) + 3.0 * (len(heights) - 1)
143
+ if y_shift > c_max_y:
144
+ c_max_y = y_shift
145
+ y_shift /= 2.0
146
+ for m, h in zip(ms, heights):
147
+ m._plane = {n: (x, y - y_shift) for n, (x, y) in m._plane.items()}
148
+
149
+ # Calculate coordinates for boxes
150
+ max_x = max(x for x, y in m._plane.values()) + 0.9 # Max x
151
+ min_x = min(x for x, y in m._plane.values()) - 0.6 # Min x
152
+ max_y = -(max(y for x, y in m._plane.values()) + 0.45) # Max y
153
+ min_y = -(min(y for x, y in m._plane.values()) - 0.45) # Min y
154
+ x_delta = abs(max_x - min_x)
155
+ y_delta = abs(max_y - min_y)
156
+ box = (
157
+ f'<rect x="{min_x}" y="{max_y}" rx="{y_delta * 0.1}" ry="{y_delta * 0.1}" width="{x_delta}" height="{y_delta}"'
158
+ f' stroke="black" stroke-width=".0025" fill="{box_colors[m.meta["status"]]}" fill-opacity="0.30"/>')
159
+ arrow_points[next(cy)].append(y_shift - h / 2.0)
160
+ y_shift -= h + 3.0
161
+ depicted_molecule = list(m.depict(embedding=True))[:3]
162
+ depicted_molecule.append(box)
163
+ render.append(depicted_molecule)
164
+
165
+ # Calculate mid-X coordinate to draw square arrows
166
+ graph = {}
167
+ for s, p in pred:
168
+ try:
169
+ graph[s].append(p)
170
+ except KeyError:
171
+ graph[s] = [p]
172
+ for s, ps in graph.items():
173
+ mid_x = float("-inf")
174
+ for p in ps:
175
+ s_min_x, s_max, s_y = arrow_points[s][:3] # s
176
+ p_min_x, p_max, p_y = arrow_points[p][:3] # p
177
+ p_max += 1
178
+ mid = p_max + (s_min_x - p_max) / 3
179
+ if mid > mid_x:
180
+ mid_x = mid
181
+ for p in ps:
182
+ arrow_points[p].append(mid_x)
183
+
184
+ config = MoleculeContainer._render_config
185
+ font_size = config["font_size"]
186
+ font125 = 1.25 * font_size
187
+ width = c_max_x + 4.0 * font_size # 3.0 by default
188
+ height = c_max_y + 3.5 * font_size # 2.5 by default
189
+ box_y = height / 2.0
190
+ svg = [f'<svg width="{0.6 * width:.2f}cm" height="{0.6 * height:.2f}cm" '
191
+ f'viewBox="{-font125:.2f} {-box_y:.2f} {width:.2f} '
192
+ f'{height:.2f}" xmlns="http://www.w3.org/2000/svg" version="1.1">',
193
+ ' <defs>\n <marker id="arrow" markerWidth="10" markerHeight="10" '
194
+ 'refX="0" refY="3" orient="auto">\n <path d="M0,0 L0,6 L9,3"/>\n </marker>\n </defs>', ]
195
+
196
+ for s, p in pred:
197
+ """
198
+ (x1, y1) = (p_max, p_y)
199
+ (x2, y2) = (s_min_x, s_y)
200
+ polyline: (x1 y1, x2 y2, x3 y3, ..., xN yN)
201
+ """
202
+ s_min_x, s_max, s_y = arrow_points[s][:3]
203
+ p_min_x, p_max, p_y = arrow_points[p][:3]
204
+ p_max += 1
205
+ mid_x = arrow_points[p][-1] # p_max + (s_min_x - p_max) / 3
206
+ """print(f"s_min_x: {s_min_x}, s_max: {s_max}, s_y: {s_y}")
207
+ print(f"p_min_x: {p_min_x}, p_max: {p_max}, p_y: {p_y}")
208
+ print(f"mid_x: {mid_x}\n")"""
209
+
210
+ arrow = f""" <polyline points="{p_max:.2f} {p_y:.2f}, {mid_x:.2f} {p_y:.2f}, {mid_x:.2f} {s_y:.2f}, {s_min_x - 1.:.2f} {s_y:.2f}"
211
+ fill="none" stroke="black" stroke-width=".04" marker-end="url(#arrow)"/>"""
212
+ if p_y != s_y:
213
+ arrow += f' <circle cx="{mid_x}" cy="{p_y}" r="0.1"/>'
214
+ svg.append(arrow)
215
+ for atoms, bonds, masks, box in render:
216
+ molecule_svg = MoleculeContainer._graph_svg(atoms, bonds, masks, -font125, -box_y, width, height)
217
+ molecule_svg.insert(1, box)
218
+ svg.extend(molecule_svg)
219
+ svg.append("</svg>")
220
+ return "\n".join(svg)
221
+
222
+
223
+ def to_table(tree: Tree, html_path: path_type, aam: bool = False, extended=False, integration: bool = False):
224
+ """
225
+ Write an HTML page with the synthesis paths in SVG format and corresponding reactions in SMILES format
226
+
227
+ :param tree: # TODO
228
+ :param extended: # TODO
229
+ :param html_path: Path to save the HTML molecules_path, if None returns the html without saving it
230
+ :type html_path: str (optional)
231
+ :param aam: depict atom-to-atom mapping
232
+ :type aam: bool (optional)
233
+ :param integration: Whenever to output the full html file (False) or only the body (True)
234
+ :type integration: bool
235
+ """
236
+ if aam:
237
+ MoleculeContainer.depict_settings(aam=True)
238
+ else:
239
+ MoleculeContainer.depict_settings(aam=False)
240
+
241
+ paths = []
242
+ if extended:
243
+ # Gather paths
244
+ for idx, node in tree.nodes.items():
245
+ if node.is_solved():
246
+ paths.append(idx)
247
+ else:
248
+ paths = tree.winning_nodes
249
+ # HTML Tags
250
+ th = '<th style="text-align: left; background-color:#978785; border: 1px solid black; border-spacing: 0">'
251
+ td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
252
+ font_red = "<font color='red' style='font-weight: bold'>"
253
+ font_green = "<font color='light-green' style='font-weight: bold'>"
254
+ font_head = "<font style='font-weight: bold; font-size: 18px'>"
255
+ font_normal = "<font style='font-weight: normal; font-size: 18px'>"
256
+ font_close = "</font>"
257
+
258
+ template_begin = """
259
+ <!doctype html>
260
+ <html lang="en">
261
+ <head>
262
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css"
263
+ rel="stylesheet"
264
+ integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
265
+ crossorigin="anonymous">
266
+ <script
267
+ src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"
268
+ integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
269
+ crossorigin="anonymous">
270
+ </script>
271
+ <meta charset="utf-8">
272
+ <meta name="viewport" content="width=device-width, initial-scale=1">
273
+ <title>Predicted Paths Report</title>
274
+ <meta name="description" content="A simple HTML5 Template for new projects.">
275
+ <meta name="author" content="SitePoint">
276
+ </head>
277
+ <body>
278
+ """
279
+ template_end = """
280
+ </body>
281
+ </html>
282
+ """
283
+ # SVG Template
284
+ # box_mark = """
285
+ # <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg">
286
+ # <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
287
+ # </svg>
288
+ # """
289
+ # table = f"<table><thead><{th}>Retrosynthetic Routes</th></thead><tbody>"
290
+ table = """<table class="table table-striped table-hover caption-top">"""
291
+ if not integration:
292
+ table += "<caption><h3>Retrosynthetic Routes Report</h3></caption><tbody>"
293
+ else:
294
+ table += "<tbody>"
295
+
296
+ # Gather path data
297
+ table += f"<tr>{td}{font_normal}Target Molecule: {str(tree.nodes[1].curr_retron)}{font_close}</td></tr>"
298
+ table += (f"<tr>{td}{font_normal}Tree Size: {len(tree)}{font_close} nodes</td></tr>")
299
+ table += f"<tr>{td}{font_normal}Number of visited nodes: {len(tree.visited_nodes)}{font_close}</td></tr>"
300
+ table += f"<tr>{td}{font_normal}Found paths: {len(paths)}{font_close}</td></tr>"
301
+ table += f"<tr>{td}{font_normal}Time: {round(tree.curr_time, 4)}{font_close} seconds</td></tr>"
302
+ table += f"""\
303
+ <tr>{td} \
304
+ <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg"> \
305
+ <circle cx="0.5" cy="0.5" r="0.5" fill="rgb(152, 238, 255)" fill-opacity="0.35" /></svg> \
306
+ Target Molecule \
307
+ \
308
+ <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg"> \
309
+ <circle cx="0.5" cy="0.5" r="0.5" fill="rgb(240, 171, 144)" fill-opacity="0.35" /></svg> \
310
+ Molecule Not In Stock \
311
+ \
312
+ <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg"> \
313
+ <circle cx="0.5" cy="0.5" r="0.5" fill="rgb(155, 250, 179)" fill-opacity="0.35" /></svg> \
314
+ Molecule In Stock \
315
+ \
316
+ </td></tr> \
317
+ """
318
+
319
+ for path in paths:
320
+ svg = path_graph(tree, path) # Get SVG
321
+ full_path = tree.synthesis_path(path) # Get Path
322
+ # Write SMILES of all reactions in synthesis path
323
+ step = 1
324
+ reactions = ""
325
+ for synth_step in full_path:
326
+ reactions += f"<b>Step {step}:</b> {str(synth_step)}<br>"
327
+ step += 1
328
+ # Concatenate all content of path
329
+ path_score = round(tree.path_score(path), 3)
330
+ table += (f'<tr style="line-height: 250%">{td}{font_head}Path {path}; '
331
+ f"Steps: {len(full_path)}; "
332
+ f"Cumulated nodes' value: {path_score}{font_close}</td></tr>")
333
+ # f"Cumulated nodes' value: {node._probabilities[path]}{font_close}</td></tr>"
334
+ table += f"<tr>{td}{svg}</td></tr>"
335
+ table += f"<tr>{td}{reactions}</td></tr>"
336
+ table += "</tbody>"
337
+
338
+ # Save or display output
339
+ if not html_path:
340
+ return table if integration else template_begin + table + template_end
341
+
342
+ output = html_path
343
+ with open(output, "w") as html_file:
344
+ html_file.write(template_begin)
345
+ html_file.write(table)
346
+ html_file.write(template_end)
SynTool/mcts/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .node import *
2
+ from .tree import *
3
+ from CGRtools.containers import MoleculeContainer
4
+
5
+ MoleculeContainer.depict_settings(aam=False)
6
+
7
+ __all__ = ["Tree", "Node"]
SynTool/mcts/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (322 Bytes). View file
 
SynTool/mcts/__pycache__/evaluation.cpython-310.pyc ADDED
Binary file (2.21 kB). View file
 
SynTool/mcts/__pycache__/expansion.cpython-310.pyc ADDED
Binary file (2.86 kB). View file
 
SynTool/mcts/__pycache__/node.cpython-310.pyc ADDED
Binary file (2.15 kB). View file
 
SynTool/mcts/__pycache__/search.cpython-310.pyc ADDED
Binary file (4.68 kB). View file
 
SynTool/mcts/__pycache__/tree.cpython-310.pyc ADDED
Binary file (18.7 kB). View file
 
SynTool/mcts/evaluation.py ADDED
@@ -0,0 +1,59 @@
1
+ """
2
+ Module containing a class that represents a value function for prediction of synthesisability
3
+ of new nodes in the search tree
4
+ """
5
+
6
+ import logging
7
+ import torch
8
+
9
+ from pathlib import Path
10
+
11
+ from SynTool.chem.retron import compose_retrons
12
+ from SynTool.ml.networks.value import ValueNetwork
13
+ from SynTool.ml.training import mol_to_pyg
14
+
15
+
16
+ class ValueFunction:
17
+ """
18
+ Value function based on value neural network for node evaluation (synthesisability prediction) in MCTS
19
+ """
20
+
21
+ def __init__(self, weights_path: str) -> None:
22
+ """
23
+ The value function predicts the probability to synthesize the target molecule with available building blocks
24
+ starting from a given retron.
25
+
26
+ :param weights_path: The value network weights location
27
+ :type weights_path: str
28
+ """
29
+
30
+ value_net = ValueNetwork.load_from_checkpoint(
31
+ weights_path,
32
+ map_location=torch.device("cpu")
33
+ )
34
+
35
+ self.value_network = value_net.eval()
36
+
37
+ def predict_value(self, retrons: list) -> float:
38
+ """
39
+ The function predicts a value based on the given retrons. For prediction, retrons must be composed into a single
40
+ molecule (product)
41
+
42
+ :param retrons: The list of retrons
43
+ :type retrons: list
44
+ """
45
+
46
+ molecule = compose_retrons(retrons=retrons, exclude_small=True)
47
+ pyg_graph = mol_to_pyg(molecule)
48
+ if pyg_graph:
49
+ with torch.no_grad():
50
+ value_pred = self.value_network.forward(pyg_graph)[0].item()
51
+ else:
52
+ try:
53
+ logging.debug(f"Molecule {str(molecule)} was not preprocessed. Giving value equal to -1e6.")
54
+ except Exception:
55
+ logging.debug(f"There is a molecule for which SMILES cannot be generated")
56
+
57
+ value_pred = -1e6
58
+
59
+ return value_pred
SynTool/mcts/expansion.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ Module containing a class that represents a policy function for node expansion in the search tree
3
+ """
4
+
5
+ import torch
6
+ import torch_geometric
7
+ from SynTool.chem.retron import Retron
8
+ from SynTool.ml.networks.policy import PolicyNetwork
9
+ from SynTool.ml.training import mol_to_pyg
10
+ from SynTool.utils.config import PolicyNetworkConfig
11
+
12
+
13
+ class PolicyFunction:
14
+ """
15
+ Policy function based on policy neural network for node expansion in MCTS
16
+ """
17
+
18
+ def __init__(self, policy_config: PolicyNetworkConfig, compile: bool = False):
19
+ """
20
+ Initializes the expansion function (ranking or filter policy network).
21
+
22
+ :param policy_config: A configuration object settings for the expansion policy
23
+ :type policy_config: PolicyConfig
24
+ :param compile: if True, the loaded policy network is compiled with torch_geometric.compile for faster inference (note: the argument name shadows the built-in compile)
25
+ :type compile: bool
26
+ """
27
+
28
+ self.config = policy_config
29
+
30
+ policy_net = PolicyNetwork.load_from_checkpoint(
31
+ self.config.weights_path,
32
+ map_location=torch.device("cpu"),
33
+ batch_size=1,
34
+ dropout=0
35
+ )
36
+
37
+ policy_net = policy_net.eval()
38
+ if compile:
39
+ self.policy_net = torch_geometric.compile(policy_net, dynamic=True)
40
+ else:
41
+ self.policy_net = policy_net
42
+
43
+ def predict_reaction_rules(self, retron: Retron, reaction_rules: list):  # yields (probability, reaction_rule, rule_id) tuples
44
+ """
45
+ The policy function predicts the list of reaction rules given a retron.
46
+
47
+ :param retron: The current retron for which the reaction rules are predicted
48
+ :type retron: Retron
49
+ :param reaction_rules: The list of reaction rules from which applicable reaction rules are predicted and selected.
50
+ :type reaction_rules: list
51
+ """
52
+
53
+ pyg_graph = mol_to_pyg(retron.molecule, canonicalize=False)
54
+ if pyg_graph:
55
+ with torch.no_grad():
56
+ if self.policy_net.policy_type == "filtering":
57
+ probs, priority = self.policy_net.forward(pyg_graph)
58
+ if self.policy_net.policy_type == "ranking":
59
+ probs = self.policy_net.forward(pyg_graph)
60
+ del pyg_graph
61
+ else:
62
+ return []
63
+
64
+ probs = probs[0].double()
65
+ if self.policy_net.policy_type == "filtering":
66
+ priority = priority[0].double()
67
+ priority_coef = self.config.priority_rules_fraction
68
+ probs = (1 - priority_coef) * probs + priority_coef * priority
69
+
70
+ sorted_probs, sorted_rules = torch.sort(probs, descending=True)
71
+ sorted_probs, sorted_rules = (
72
+ sorted_probs[: self.config.top_rules],
73
+ sorted_rules[: self.config.top_rules],
74
+ )
75
+
76
+ if self.policy_net.policy_type == "filtering":
77
+ sorted_probs = torch.softmax(sorted_probs, -1)
78
+
79
+ sorted_probs, sorted_rules = sorted_probs.tolist(), sorted_rules.tolist()
80
+
81
+ for prob, rule_id in zip(sorted_probs, sorted_rules):
82
+ if prob > self.config.rule_prob_threshold: # TODO it will destroy all search if it is not correct (>0.5)
83
+ yield prob, reaction_rules[rule_id], rule_id
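+ # This generator is consumed during node expansion and rollout in SynTool.mcts.tree.Tree, e.g.:
+ #   for prob, rule, rule_id in policy_function.predict_reaction_rules(retron, reaction_rules):
+ #       for products in apply_reaction_rule(retron.molecule, rule):
+ #           ...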
SynTool/mcts/node.py ADDED
@@ -0,0 +1,49 @@
1
+ """
2
+ Module containing a class Node that represents a node in the search tree
3
+ """
4
+
5
+
6
+ class Node:
7
+ """
8
+ Node class represents a node in the search tree
9
+ """
10
+
11
+ def __init__(self, retrons_to_expand: tuple = None, new_retrons: tuple = None) -> None:
12
+ """
13
+ The function initializes the new Node object.
14
+
15
+ :param retrons_to_expand: The tuple of retrons to be expanded. The first retron in the tuple is the current
16
+ retron which will be expanded (for which new retrons will be generated by applying the predicted reaction
17
+ rules). When the first retron has been successfully expanded, the second retron becomes the current retron
18
+ to be expanded.
19
+ :param new_retrons: The tuple of new retrons generated by applying the reaction rule. New retrons have already
20
+ been added to the retrons_to_expand (see Tree._expand_node). Here they are stored for information.
21
+ """
22
+
23
+ self.retrons_to_expand = retrons_to_expand
24
+ self.new_retrons = new_retrons
25
+
26
+ if len(self.retrons_to_expand) == 0:
27
+ self.curr_retron = tuple()
28
+ else:
29
+ self.curr_retron = self.retrons_to_expand[0]
30
+ self.next_retrons = self.retrons_to_expand[1:]
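+ # For example, the root node of the search tree is created in SynTool.mcts.tree.Tree.__init__ as:
+ #   Node(retrons_to_expand=(target_retron,), new_retrons=(target_retron,))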
31
+
32
+ def __len__(self) -> int:
33
+ """
34
+ The number of retrons in this node to expand.
35
+ """
36
+ return len(self.retrons_to_expand)
37
+
38
+ def __repr__(self) -> str:
39
+ """
40
+ String representation of the node. Returns the smiles of retrons_to_expand and new_retrons.
41
+ """
42
+ return f"retrons_to_expand: {self.retrons_to_expand}\nnew_retrons: {self.new_retrons}"
43
+
44
+ def is_solved(self) -> bool:
45
+ """
46
+ Returns True if this is a terminal (solved) node, i.e. there are no retrons left to expand.
47
+ """
48
+
49
+ return len(self.retrons_to_expand) == 0
SynTool/mcts/search.py ADDED
@@ -0,0 +1,135 @@
1
+ """
2
+ Module containing functions for running tree search for the set of target molecules
3
+ """
4
+
5
+ import csv
6
+ import json
7
+ from pathlib import Path
8
+
9
+ from tqdm import tqdm
10
+
11
+ from SynTool.interfaces.visualisation import to_table, extract_routes
12
+ from SynTool.mcts.tree import Tree, TreeConfig
13
+ from SynTool.mcts.evaluation import ValueFunction
14
+ from SynTool.mcts.expansion import PolicyFunction
15
+ from SynTool.utils import path_type
16
+ from SynTool.utils.files import MoleculeReader
17
+ from SynTool.utils.config import PolicyNetworkConfig
18
+
19
+
20
+ def extract_tree_stats(tree, target):
21
+ """
22
+ Collects various statistics from a tree and returns them in a dictionary format
23
+
24
+ :param tree: The retro tree.
25
+ :param target: The target molecule or compound that you want to search for in the tree. It is
26
+ expected to be a string representing the SMILES notation of the target molecule
27
+ :return: A dictionary with the calculated statistics
28
+ """
29
+ newick_tree, newick_meta = tree.newickify(visits_threshold=0)
30
+ newick_meta_line = ";".join([f"{nid},{v[0]},{v[1]},{v[2]}" for nid, v in newick_meta.items()])
31
+ return {
32
+ "target_smiles": str(target),
33
+ "tree_size": len(tree),
34
+ "search_time": round(tree.curr_time, 1),
35
+ "found_paths": len(tree.winning_nodes),
36
+ "newick_tree": newick_tree,
37
+ "newick_meta": newick_meta_line,
38
+ }
39
+
40
+
41
+ def tree_search(
42
+ targets_path: path_type,
43
+ tree_config: TreeConfig,
44
+ policy_config: PolicyNetworkConfig,
45
+ reaction_rules_path: path_type,
46
+ building_blocks_path: path_type,
47
+ policy_weights_path: path_type = None, # TODO not used
48
+ value_weights_path: path_type = None,
49
+ results_root: path_type = "search_results"
50
+ ):
51
+ """
52
+ Performs a tree search on a set of target molecules using specified configuration and rules,
53
+ logging the results and statistics.
54
+
55
+ :param tree_config: The config object containing the configuration for the tree search.
56
+ :param policy_config: The config object containing the configuration for the policy.
57
+ :param reaction_rules_path: The path to the file containing reaction rules.
58
+ :param building_blocks_path: The path to the file containing building blocks.
59
+ :param targets_path: The path to the file containing the target molecules (in SDF or SMILES format).
60
+ :param value_weights_path: The path to the file containing value weights (optional).
61
+ :param results_root: The path to the directory where the results of the tree search will be saved. Defaults to 'search_results/'.
62
+ :param retropaths_files_name: The base name for the files that will be generated to store the retro paths. Defaults to 'retropath'. #TODO arg dont exist
63
+
64
+ This function configures and executes a tree search algorithm, leveraging reaction rules and building blocks
65
+ to find synthetic pathways for given target molecules. The results, including paths and statistics, are
66
+ saved in the specified directory. Logging is used to record the process and any issues encountered.
67
+ """
68
+
69
+ targets_file = Path(targets_path)
70
+
71
+ # results folder
72
+ results_root = Path(results_root)
73
+ if not results_root.exists():
74
+ results_root.mkdir()
75
+
76
+ # output files
77
+ stats_file = results_root.joinpath("tree_search_stats.csv")
78
+ paths_file = results_root.joinpath("extracted_paths.json")
79
+ retropaths_folder = results_root.joinpath("retropaths")
80
+ retropaths_folder.mkdir(exist_ok=True)
81
+
82
+ # stats header
83
+ stats_header = ["target_smiles", "tree_size", "search_time",
84
+ "found_paths", "newick_tree", "newick_meta"]
85
+
86
+ # config
87
+ policy_function = PolicyFunction(policy_config=policy_config)
88
+ if tree_config.evaluation_type == 'gcn':
89
+ value_function = ValueFunction(weights_path=value_weights_path)
90
+ else:
91
+ value_function = None
92
+
93
+ # run search
94
+ n_solved = 0
95
+ extracted_paths = []
96
+ with MoleculeReader(targets_file) as targets_path, open(stats_file, "w", newline="\n") as csvfile:
97
+ statswriter = csv.DictWriter(csvfile, delimiter=",", fieldnames=stats_header)
98
+ statswriter.writeheader()
99
+
100
+ for ti, target in tqdm(enumerate(targets_path), total=len(targets_path)):
101
+
102
+ try:
103
+ # run search
104
+ tree = Tree(
105
+ target=target,
106
+ tree_config=tree_config,
107
+ reaction_rules_path=reaction_rules_path,
108
+ building_blocks_path=building_blocks_path,
109
+ policy_function=policy_function,
110
+ value_function=value_function,
111
+ )
112
+ _ = list(tree)
113
+
114
+ except:
115
+ continue
116
+
117
+ n_solved += bool(tree.winning_nodes)
118
+
119
+ # extract routes
120
+ extracted_paths.append(extract_routes(tree))
121
+
122
+ # retropaths
123
+ retropaths_file = retropaths_folder.joinpath(f"retropaths_target_{ti}.html")
124
+ to_table(tree, retropaths_file, extended=True)
125
+
126
+ # stats
127
+ statswriter.writerow(extract_tree_stats(tree, target))
128
+ csvfile.flush()
129
+
130
+ #
131
+ with open(paths_file, 'w') as f:
132
+ json.dump(extracted_paths, f)
133
+
134
+ print(f"Solved number of target molecules: {n_solved}")
135
+
SynTool/mcts/tree.py ADDED
@@ -0,0 +1,659 @@
1
+ """
2
+ Module containing a class Tree that is used for the tree search of retrosynthetic paths
3
+ """
4
+
5
+ import logging
6
+ from collections import deque, defaultdict
7
+ from math import sqrt
8
+ from random import choice, uniform
9
+ from time import time
10
+ from typing import Dict, Set, List, Tuple
11
+
12
+ from CGRtools.containers import MoleculeContainer
13
+ from CGRtools import smiles
14
+ from numpy.random import uniform
15
+ from tqdm.auto import tqdm
16
+ from SynTool.utils.loading import load_building_blocks, load_reaction_rules
17
+ from SynTool.chem.reaction import Reaction, apply_reaction_rule
18
+ from SynTool.chem.retron import Retron
19
+ from SynTool.mcts.evaluation import ValueFunction
20
+ from SynTool.mcts.expansion import PolicyFunction
21
+ from SynTool.mcts.node import Node
22
+ from SynTool.utils.config import TreeConfig
23
+
24
+
25
+ class Tree:
26
+ """
27
+ Tree class with attributes and methods for Monte-Carlo tree search
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ target: MoleculeContainer,
33
+ tree_config: TreeConfig,
34
+ reaction_rules_path: str,
35
+ building_blocks_path: str,
36
+ policy_function: PolicyFunction,
37
+ value_function: ValueFunction = None,
38
+ ):
39
+ """
40
+ The function initializes a tree object with optional parameters for tree search for target molecule.
41
+
42
+ :param target: a target molecule for retrosynthesis paths search
43
+ :type target: MoleculeContainer
44
+ :param tree_config: a tree configuration file for retrosynthesis paths search
45
+ :type tree_config: TreeConfig
46
+ :param reaction_rules_path: a path for reaction rules file
47
+ :type reaction_rules_path: str
48
+ :param building_blocks_path: a path for building blocks file
49
+ :type building_blocks_path: str
50
+ :param policy_function: a policy function object
51
+ :type policy_function: PolicyFunction
52
+ :param value_function: a value function object
53
+ :type value_function: ValueFunction
54
+ """
55
+
56
+ # config parameters
57
+ self.config = tree_config
58
+
59
+ # check target
60
+ if isinstance(target, str):
61
+ target = smiles(target)
62
+ assert bool(target), "Target is not defined, is not a MoleculeContainer or has no atoms"
63
+ if target:
64
+ target.canonicalize()
65
+
66
+ target_retron = Retron(target, canonicalize=True)
67
+ target_retron.prev_retrons.append(Retron(target, canonicalize=True))
68
+ target_node = Node(retrons_to_expand=(target_retron,), new_retrons=(target_retron,))
69
+
70
+ # tree structure init
71
+ self.nodes: Dict[int, Node] = {1: target_node}
72
+ self.parents: Dict[int, int] = {1: 0}
73
+ self.children: Dict[int, Set[int]] = {1: set()}
74
+ self.winning_nodes: List[int] = list()
75
+ self.visited_nodes: Set[int] = set()
76
+ self.expanded_nodes: Set[int] = set()
77
+ self.nodes_visit: Dict[int, int] = {1: 0}
78
+ self.nodes_depth: Dict[int, int] = {1: 0}
79
+ self.nodes_prob: Dict[int, float] = {1: 0.0}
80
+ self.nodes_init_value: Dict[int, float] = {1: 0.0}
81
+ self.nodes_total_value: Dict[int, float] = {1: 0.0}
82
+
83
+ # tree building limits
84
+ self.curr_iteration: int = 0
85
+ self.curr_tree_size: int = 2
86
+ self.curr_time: float = 2
87
+
88
+ # utils
89
+ self._tqdm = None
90
+
91
+ # policy and value functions
92
+ self.policy_function = policy_function
93
+ if self.config.evaluation_type == "gcn":
94
+ if value_function is None:
95
+ raise ValueError(
96
+ "Value function not specified while evaluation mode is 'gcn'"
97
+ )
98
+ else:
99
+ self.value_function = value_function
100
+
101
+ # building blocks and reaction reaction_rules
102
+ self.reaction_rules = load_reaction_rules(reaction_rules_path)
103
+ self.building_blocks = load_building_blocks(building_blocks_path)
104
+
105
+ def __len__(self) -> int:
106
+ """
107
+ Returns the current size (number of nodes) of a Tree.
108
+ """
109
+
110
+ return self.curr_tree_size - 1
111
+
112
+ def __iter__(self) -> "Tree": # TODO what is annotation "Tree" -> Tree ?
113
+ """
114
+ Returns the iterator for the Tree object and initializes the timer and the progress bar display.
115
+ """
116
+
117
+ if not self._tqdm:
118
+ self._start_time = time()
119
+ self._tqdm = tqdm(
120
+ total=self.config.max_iterations, disable=self.config.silent
121
+ )
122
+ return self
123
+
124
+ def __repr__(self) -> str:
125
+ """
126
+ Returns a string representation of a Tree object (target smiles, tree size, and the number of found paths).
127
+ """
128
+ return self.report()
129
+
130
+ def __next__(self) -> Tuple[bool, List[int]]:
131
+ """
132
+ The __next__ function is used to do one iteration of the tree building.
133
+ """
134
+
135
+ if self.nodes[1].curr_retron.is_building_block(self.building_blocks, self.config.min_mol_size):
136
+ raise StopIteration("Target is building block \n")
137
+
138
+ if self.curr_iteration >= self.config.max_iterations:
139
+ self._tqdm.close()
140
+ raise StopIteration("Iterations limit exceeded. \n")
141
+ elif self.curr_tree_size >= self.config.max_tree_size:
142
+ self._tqdm.close()
143
+ raise StopIteration("Max tree size exceeded or all possible paths found")
144
+ elif self.curr_time >= self.config.max_time:
145
+ self._tqdm.close()
146
+ raise StopIteration("Time limit exceeded. \n")
147
+
148
+ # start new iteration
149
+ self.curr_iteration += 1
150
+ self.curr_time = time() - self._start_time
151
+ self._tqdm.update()
152
+
153
+ curr_depth, node_id = 0, 1 # start from the root node_id
154
+
155
+ explore_path = True
156
+ while explore_path:
157
+ self.visited_nodes.add(node_id)
158
+
159
+ if self.nodes_visit[node_id]: # already visited
160
+ if not self.children[node_id]: # dead node
161
+ logging.debug(
162
+ f"Tree search: bumped into node {node_id} which is dead"
163
+ )
164
+ self._update_visits(node_id)
165
+ explore_path = False
166
+ else:
167
+ node_id = self._select_node(node_id) # select the child node
168
+ curr_depth += 1
169
+ else:
170
+ if self.nodes[node_id].is_solved(): # found path!
171
+ self._update_visits(node_id) # this prevents expanding of bb node_id
172
+ self.winning_nodes.append(node_id)
173
+ return True, [node_id]
174
+
175
+ elif (
176
+ curr_depth < self.config.max_depth
177
+ ): # expand node if depth limit is not reached
178
+ self._expand_node(node_id)
179
+ if not self.children[node_id]: # node was not expanded
180
+ logging.debug(f"Tree search: node {node_id} was not expanded")
181
+ value_to_backprop = -1.0
182
+ else:
183
+ self.expanded_nodes.add(node_id)
184
+
185
+ if self.config.search_strategy == "evaluation_first":
186
+ # recalculate node value based on children synthesisability and backpropagation
187
+ child_values = [
188
+ self.nodes_init_value[child_id]
189
+ for child_id in self.children[node_id]
190
+ ]
191
+
192
+ if self.config.evaluation_agg == "max":
193
+ value_to_backprop = max(child_values)
194
+
195
+ elif self.config.evaluation_agg == "average":
196
+ value_to_backprop = sum(child_values) / len(
197
+ self.children[node_id]
198
+ )
199
+
200
+ else:
201
+ raise ValueError(
202
+ f"Invalid evaluation aggregation mode: {self.config.evaluation_agg} "
203
+ f"Allowed values are 'max', 'average'"
204
+ )
205
+ elif self.config.search_strategy == "expansion_first":
206
+ value_to_backprop = self._get_node_value(node_id)
207
+
208
+ else:
209
+ raise ValueError(
210
+ f"Invalid search_strategy: {self.config.search_strategy}: "
211
+ f"Allowed values are 'expansion_first', 'evaluation_first'"
212
+ )
213
+
214
+ # backpropagation
215
+ self._backpropagate(node_id, value_to_backprop)
216
+ self._update_visits(node_id)
217
+ explore_path = False
218
+
219
+ if self.children[node_id]:
220
+ # found after expansion
221
+ found_after_expansion = set()
222
+ for child_id in iter(self.children[node_id]):
223
+ if self.nodes[child_id].is_solved():
224
+ found_after_expansion.add(child_id)
225
+ self.winning_nodes.append(child_id)
226
+
227
+ if found_after_expansion:
228
+ return True, list(found_after_expansion)
229
+
230
+ else:
231
+ self._backpropagate(node_id, self.nodes_total_value[node_id])
232
+ self._update_visits(node_id)
233
+ explore_path = False
234
+
235
+ return False, [node_id]
236
+
237
+ def _ucb(self, node_id: int) -> float:
238
+ """
239
+ The function calculates the Upper Confidence Bound (UCB) for a given node.
240
+
241
+ :param node_id: The `node_id` parameter is an integer that represents the ID of a node in a tree
242
+ :type node_id: int
243
+ """
244
+
245
+ prob = self.nodes_prob[node_id] # Predicted by policy network score
246
+ visit = self.nodes_visit[node_id]
247
+
248
+ if self.config.ucb_type == "puct":
249
+ u = (
250
+ self.config.c_ucb * prob * sqrt(self.nodes_visit[self.parents[node_id]])
251
+ ) / (visit + 1)
252
+ return self.nodes_total_value[node_id] + u
253
+ elif self.config.ucb_type == "uct":
254
+ u = (
255
+ self.config.c_ucb
256
+ * sqrt(self.nodes_visit[self.parents[node_id]])
257
+ / (visit + 1)
258
+ )
259
+ return self.nodes_total_value[node_id] + u
260
+ elif self.config.ucb_type == "value":
261
+ return self.nodes_init_value[node_id] / (visit + 1)
262
+ else:
263
+ raise ValueError(f"I don't know this UCB type: {self.config.ucb_type}")
264
+
265
+ def _select_node(self, node_id: int) -> int:
266
+ """
267
+ This function selects a node based on its UCB value and returns the ID of the node with the highest value of
268
+ the UCB function.
269
+
270
+ :param node_id: The `node_id` parameter is an integer that represents the ID of a node
271
+ :type node_id: int
272
+ """
273
+
274
+ if self.config.epsilon > 0:
275
+ n = uniform(0, 1)
276
+ if n < self.config.epsilon:
277
+ return choice(list(self.children[node_id]))
278
+
279
+ best_score, best_children = None, []
280
+ for child_id in self.children[node_id]:
281
+ score = self._ucb(child_id)
282
+ if best_score is None or score > best_score:
283
+ best_score, best_children = score, [child_id]
284
+ elif score == best_score:
285
+ best_children.append(child_id)
286
+ return choice(best_children)
287
+
288
+ def _expand_node(self, node_id: int) -> None:
289
+ """
290
+ The function expands a given node by generating new retrons with policy (expansion) policy.
291
+
292
+ :param node_id: The `node_id` parameter is an integer that represents the ID of the current node
293
+ :type node_id: int
294
+ """
295
+ curr_node = self.nodes[node_id]
296
+ prev_retrons = curr_node.curr_retron.prev_retrons
297
+
298
+ tmp_retrons = set()
299
+ for prob, rule, rule_id in self.policy_function.predict_reaction_rules(
300
+ curr_node.curr_retron, self.reaction_rules
301
+ ):
302
+ for products in apply_reaction_rule(curr_node.curr_retron.molecule, rule):
303
+ # check repeated products
304
+ if not products or not set(products) - tmp_retrons:
305
+ continue
306
+ tmp_retrons.update(products)
307
+
308
+ for molecule in products:
309
+ molecule.meta["reactor_id"] = rule_id
310
+
311
+ new_retrons = tuple(Retron(mol) for mol in products)
312
+ scaled_prob = prob * len(
313
+ list(filter(lambda x: len(x) > self.config.min_mol_size, products))
314
+ )
315
+
316
+ if set(prev_retrons).isdisjoint(new_retrons):
317
+ retrons_to_expand = (
318
+ *curr_node.next_retrons,
319
+ *(
320
+ x
321
+ for x in new_retrons
322
+ if not x.is_building_block(
323
+ self.building_blocks, self.config.min_mol_size
324
+ )
325
+ ),
326
+ )
327
+
328
+ child_node = Node(
329
+ retrons_to_expand=retrons_to_expand, new_retrons=new_retrons
330
+ )
331
+
332
+ for new_retron in new_retrons:
333
+ new_retron.prev_retrons = [new_retron, *prev_retrons]
334
+
335
+ self._add_node(node_id, child_node, scaled_prob)
336
+
337
+ def _add_node(self, node_id: int, new_node: Node, policy_prob: float = None) -> None:
338
+ """
339
+ This function adds a new node to a tree with its predicted policy probability.
340
+
341
+ :param node_id: ID of the parent node
342
+ :type node_id: int
343
+ :param new_node: The `new_node` is an instance of the`Node` class
344
+ :type new_node: Node
345
+ :param policy_prob: The `policy_prob` a float value that represents the probability associated with a new node.
346
+ :type policy_prob: float
347
+ """
348
+
349
+ new_node_id = self.curr_tree_size
350
+
351
+ self.nodes[new_node_id] = new_node
352
+ self.parents[new_node_id] = node_id
353
+ self.children[node_id].add(new_node_id)
354
+ self.children[new_node_id] = set()
355
+ self.nodes_visit[new_node_id] = 0
356
+ self.nodes_prob[new_node_id] = policy_prob
357
+ self.nodes_depth[new_node_id] = self.nodes_depth[node_id] + 1
358
+ self.curr_tree_size += 1
359
+
360
+ if self.config.search_strategy == "evaluation_first":
361
+ node_value = self._get_node_value(new_node_id)
362
+ elif self.config.search_strategy == "expansion_first":
363
+ node_value = self.config.init_node_value
364
+ else:
365
+ raise ValueError(
366
+ f"Invalid search_strategy: {self.config.search_strategy}: "
367
+ f"Allowed values are 'expansion_first', 'evaluation_first'"
368
+ )
369
+
370
+ self.nodes_init_value[new_node_id] = node_value
371
+ self.nodes_total_value[new_node_id] = node_value
372
+
373
+ def _get_node_value(self, node_id: int) -> float:
374
+ """
375
+ This function calculates the value for the given node.
376
+
377
+ :param node_id: ID of the given node
378
+ :type node_id: int
379
+ """
380
+
381
+ node = self.nodes[node_id]
382
+
383
+ if self.config.evaluation_type == "random":
384
+ node_value = uniform()
385
+
386
+ elif self.config.evaluation_type == "rollout":
387
+ node_value = min(
388
+ (
389
+ self._rollout_node(retron, current_depth=self.nodes_depth[node_id])
390
+ for retron in node.retrons_to_expand
391
+ ),
392
+ default=1.0,
393
+ )
394
+
395
+ elif self.config.evaluation_type == "gcn":
396
+ node_value = self.value_function.predict_value(node.new_retrons)
397
+
398
+ else:
399
+ raise ValueError(
400
+ f"I don't know this evaluation mode: {self.config.evaluation_type}"
401
+ )
402
+
403
+ return node_value
404
+
405
+ def _update_visits(self, node_id: int) -> None:
406
+ """
407
+ The function updates the number of visits from a given node to a root node.
408
+
409
+ :param node_id: The ID of a current node
410
+ :type node_id: int
411
+ """
412
+
413
+ while node_id:
414
+ self.nodes_visit[node_id] += 1
415
+ node_id = self.parents[node_id]
416
+
417
+ def _backpropagate(self, node_id: int, value: float) -> None:
418
+ """
419
+ The function backpropagates a value from the given node up to the root node.
420
+
421
+ :param node_id: The ID of a given node from which to backpropagate value
422
+ :type node_id: int
423
+ :param value: The value to backpropagate
424
+ :type value: float
425
+ """
426
+ while node_id:
427
+ if self.config.backprop_type == "muzero":
428
+ self.nodes_total_value[node_id] = (
429
+ self.nodes_total_value[node_id] * self.nodes_visit[node_id] + value
430
+ ) / (self.nodes_visit[node_id] + 1)
431
+ elif self.config.backprop_type == "cumulative":
432
+ self.nodes_total_value[node_id] += value
433
+ else:
434
+ raise ValueError(
435
+ f"I don't know this backpropagation type: {self.config.backprop_type}"
436
+ )
437
+ node_id = self.parents[node_id]
438
+
439
+ def _rollout_node(self, retron: Retron, current_depth: int = None) -> float:
+ """
+ Performs a rollout simulation starting from the given retron: at each step the first
+ successful reaction rule is applied and the resulting retrons are expanded further
+ until they are all reduced to building blocks or the simulation fails.
+
+ Returns 1.0 if the retron itself is a building block or if all generated retrons
+ are reduced to building blocks;
+
+ Returns -0.5 if the maximum rollout depth is reached before all retrons are reduced
+ to building blocks;
+
+ Returns -1.0 if no reaction rule can be applied to a retron or if the rollout runs
+ into a loop (a previously encountered retron is generated again).
+
+ :param retron: The Retron object from which the rollout starts
+ :type retron: Retron
+ :param current_depth: The depth of the evaluated node in the search tree
+ :type current_depth: int
+ """
+
+ max_depth = self.config.max_depth - current_depth
+
+ # retron checking
+ if retron.is_building_block(self.building_blocks, self.config.min_mol_size):
+ return 1.0
+
+ if max_depth == 0:
+ logging.debug("Rollout: tried to perform rollout on the leaf node")
+ return -0.5
+
+ # retron simulation
+ occurred_retrons = set()
+ retrons_to_expand = deque([retron])
+ history = defaultdict(dict)
+ rollout_depth = 0
+ while retrons_to_expand:
+ # Iterate through the predicted reaction rules and pick the first successful one.
+ # Check whether the reaction products can be found in the building blocks data;
+ # if not, add the missing products to retrons_to_expand and try to decompose them further.
+ if len(history) >= max_depth:
+ logging.debug(
+ f"Rollout: max depth of rollout is reached with these "
+ f"retrons to expand: {retrons_to_expand} {history}",
+ )
+ reward = -0.5
+ return reward
+
+ current_retron = retrons_to_expand.popleft()
+ history[rollout_depth]["target"] = current_retron
+ occurred_retrons.add(current_retron)
+
+ # Pick the first successful reaction while iterating through the predicted reaction rules
+ reaction_rule_applied = False
+ for prob, rule, rule_id in self.policy_function.predict_reaction_rules(
+ current_retron, self.reaction_rules
+ ):
+ for products in apply_reaction_rule(current_retron.molecule, rule):
+ if products:
+ reaction_rule_applied = True
+ break
+
+ if reaction_rule_applied:
+ history[rollout_depth]["rule_index"] = rule_id
+ break
+
+ if not reaction_rule_applied:
+ logging.debug(
+ f"Rollout: no reaction rule was applied for the "
+ f"molecule {current_retron} on rollout depth {rollout_depth}"
+ )
+ reward = -1.0
+ return reward
+
+ # `products` still holds the first non-empty product set yielded by apply_reaction_rule above
+ products = tuple(Retron(product) for product in products)
+ history[rollout_depth]["products"] = products
+
+ # check loops
+ if any(x in occurred_retrons for x in products) and products:
+ # Applying manually curated reaction rules can sometimes create a loop
+ logging.debug("Rollout: rollout got in the loop: %s", history)
+ reward = -1.0
+ return reward
+
+ if occurred_retrons.isdisjoint(products):
+ # Keep only the products that are not building blocks
+ # (the min_mol_size check is handled inside is_building_block)
+ retrons_to_expand.extend(
+ [
+ x
+ for x in products
+ if not x.is_building_block(
+ self.building_blocks, self.config.min_mol_size
+ )
+ ]
+ )
+ rollout_depth += 1
+
+ reward = 1.0
+ return reward
+
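When the 'rollout' evaluation type is used, the per-retron rewards returned by _rollout_node are aggregated in _get_node_value by taking the minimum over all retrons to expand, so a node is only as good as its worst retron. A toy illustration (the retron names and numbers are made up):

    # Toy illustration of how rollout rewards combine into a node value.
    rollout_rewards = {"retron_A": 1.0, "retron_B": -0.5}    # hypothetical per-retron results
    node_value = min(rollout_rewards.values(), default=1.0)  # -0.5, mirroring _get_node_value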
+ def report(self) -> str:
+ """
+ Returns a short text summary of the tree: the target molecule, the number of nodes,
+ the number of visited nodes, the number of found paths and the elapsed search time.
+ """
+
+ return (
+ f"Tree for: {str(self.nodes[1].retrons_to_expand[0])}\n"
+ f"Number of nodes: {len(self)}\nNumber of visited nodes: {len(self.visited_nodes)}\n"
+ f"Found paths: {len(self.winning_nodes)}\nTime: {round(self.curr_time, 1)} seconds"
+ )
+
+ def path_score(self, node_id: int) -> float:
+ """
+ Calculates the score of the path from the node with the given node_id to the root node:
+ the sum of the total values of the nodes on the path divided by the squared path length,
+ so that shorter paths with high node values score higher.
+
+ :param node_id: The ID of a given node
+ :type node_id: int
+ """
+
+ cumulated_nodes_value, path_length = 0, 0
+ while node_id:
+ path_length += 1
+
+ cumulated_nodes_value += self.nodes_total_value[node_id]
+ node_id = self.parents[node_id]
+
+ return cumulated_nodes_value / (path_length ** 2)
+
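A worked toy example of the score above; the values are made up and only illustrate why dividing by the squared path length favours short routes:

    # Path of three nodes with (made-up) total values along the root-to-node path.
    values = [1.0, 0.5, 0.5]
    score = sum(values) / len(values) ** 2   # 2.0 / 9, about 0.22
    # A two-node path with the same cumulative value would score 2.0 / 4 = 0.5,
    # so shorter synthesis routes are preferred.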
+ def path_to_node(self, node_id: int) -> list:
+ """
+ Returns the path (a list of Node objects) from the root node to the node specified by node_id.
+
+ :param node_id: The ID of a given node
+ :type node_id: int
+ """
+
+ nodes = []
+ while node_id:
+ nodes.append(node_id)
+ node_id = self.parents[node_id]
+ return [self.nodes[node_id] for node_id in reversed(nodes)]
+
+ def synthesis_path(self, node_id: int) -> Tuple[Reaction, ...]:
+ """
+ Given a node_id, returns a tuple of Reaction objects representing the synthesis path
+ between the root node and the node specified by node_id, ordered in the forward
+ (synthetic) direction, i.e. ending with the reaction that forms the target molecule.
+
+ :param node_id: The ID of a given node
+ :type node_id: int
+ """
+
+ nodes = self.path_to_node(node_id)
+
+ tmp = [
+ Reaction(
+ [x.molecule for x in after.new_retrons],
+ [before.curr_retron.molecule],
+ )
+ for before, after in zip(nodes, nodes[1:])
+ ]
+
+ for r in tmp:
+ r.clean2d()
+ return tuple(reversed(tmp))
+
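A hypothetical usage sketch; the `tree` and `node_id` names are placeholders for a built search tree and the ID of a winning node:

    for step, reaction in enumerate(tree.synthesis_path(node_id), start=1):
        # Reactions come out in synthetic order, ending with the one that forms the target.
        print(f"Step {step}: {reaction}")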
+ def newickify(self, visits_threshold: int = 0, root_node_id: int = 1) -> Tuple[str, dict]:
+ """
+ Renders the tree in the Newick format.
+ Adapted from https://stackoverflow.com/questions/50003007/how-to-convert-python-dictionary-to-newick-form-format
+
+ :param visits_threshold: The minimum number of visits a node must have to be included in the output
+ :type visits_threshold: int
+ :param root_node_id: The ID of a root node
+ :type root_node_id: int
+ :return: A tuple of the Newick string and a dict mapping node IDs to
+ (total value, rounded initial value, number of visits)
+ """
+ visited_nodes = set()
+
+ def newick_render_node(current_node_id: int) -> str:
+ """
+ Recursively generates a Newick string representation of a tree
+
+ :param current_node_id: The identifier of the current node in the tree
+ :type current_node_id: int
+ :return: A string representation of a node in the Newick format
+ """
+ assert (
+ current_node_id not in visited_nodes
+ ), "Error: the tree must not contain cycles!"
+ node_visit = self.nodes_visit[current_node_id]
+
+ visited_nodes.add(current_node_id)
+ if self.children[current_node_id]:
+ # Internal nodes
+ children = [
+ child
+ for child in list(self.children[current_node_id])
+ if self.nodes_visit[child] >= visits_threshold
+ ]
+ children_strings = [newick_render_node(child) for child in children]
+ children_strings = ",".join(children_strings)
+ if children_strings:
+ return f"({children_strings}){current_node_id}:{node_visit}"
+ else:
+ # All children are below the visits threshold
+ return f"{current_node_id}:{node_visit}"
+ else:
+ # Leaf nodes
+ return f"{current_node_id}:{node_visit}"
+
+ newick_string = newick_render_node(root_node_id) + ";"
+
+ meta = {}
+ for node_id in visited_nodes:
+ node_value = round(self.nodes_total_value[node_id], 3)
+
+ node_synthesisability = round(self.nodes_init_value[node_id])
+
+ visit_in_node = self.nodes_visit[node_id]
+ meta[node_id] = (node_value, node_synthesisability, visit_in_node)
+
+ return newick_string, meta
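A hypothetical usage sketch of newickify; the `tree` variable and the example output are illustrative. The returned string labels nodes by their IDs and uses visit counts as branch lengths, so it should be readable by standard Newick parsers (for example ete3 or Bio.Phylo, if available), while the meta dict carries the per-node statistics that do not fit into the Newick format:

    newick_string, meta = tree.newickify(visits_threshold=1)
    print(newick_string)  # e.g. "(3:4,(5:2)2:6)1:10;"  -> (children)node_id:visit_count
    print(meta[1])        # (total_value, init_value, visit_count) for the root node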
SynTool/ml/__init__.py ADDED
File without changes
SynTool/ml/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (147 Bytes). View file
 
SynTool/ml/networks/__init__.py ADDED
File without changes
SynTool/ml/networks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (156 Bytes). View file
 
SynTool/ml/networks/__pycache__/modules.cpython-310.pyc ADDED
Binary file (8.14 kB). View file