Spaces:
Sleeping
Sleeping
GUI v0.1
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- SynTool/__init__.py +3 -0
- SynTool/chem/__init__.py +0 -0
- SynTool/chem/__pycache__/__init__.cpython-310.pyc +0 -0
- SynTool/chem/__pycache__/reaction.cpython-310.pyc +0 -0
- SynTool/chem/__pycache__/retron.cpython-310.pyc +0 -0
- SynTool/chem/__pycache__/utils.cpython-310.pyc +0 -0
- SynTool/chem/data/__init__.py +0 -0
- SynTool/chem/data/__pycache__/__init__.cpython-310.pyc +0 -0
- SynTool/chem/data/__pycache__/cleaning.cpython-310.pyc +0 -0
- SynTool/chem/data/__pycache__/filtering.cpython-310.pyc +0 -0
- SynTool/chem/data/__pycache__/mapping.cpython-310.pyc +0 -0
- SynTool/chem/data/__pycache__/standardizer.cpython-310.pyc +0 -0
- SynTool/chem/data/cleaning.py +124 -0
- SynTool/chem/data/filtering.py +917 -0
- SynTool/chem/data/mapping.py +96 -0
- SynTool/chem/data/mapping.py.bk +90 -0
- SynTool/chem/data/standardizer.py +604 -0
- SynTool/chem/reaction.py +107 -0
- SynTool/chem/reaction_rules/__init__.py +0 -0
- SynTool/chem/reaction_rules/__pycache__/__init__.cpython-310.pyc +0 -0
- SynTool/chem/reaction_rules/__pycache__/extraction.cpython-310.pyc +0 -0
- SynTool/chem/reaction_rules/extraction.py +679 -0
- SynTool/chem/reaction_rules/manual/__init__.py +6 -0
- SynTool/chem/reaction_rules/manual/decompositions.py +415 -0
- SynTool/chem/reaction_rules/manual/transformations.py +535 -0
- SynTool/chem/retron.py +132 -0
- SynTool/chem/utils.py +227 -0
- SynTool/interfaces/__init__.py +0 -0
- SynTool/interfaces/__pycache__/__init__.cpython-310.pyc +0 -0
- SynTool/interfaces/__pycache__/visualisation.cpython-310.pyc +0 -0
- SynTool/interfaces/cli.py +530 -0
- SynTool/interfaces/cli.py.bk +241 -0
- SynTool/interfaces/visualisation.py +346 -0
- SynTool/mcts/__init__.py +7 -0
- SynTool/mcts/__pycache__/__init__.cpython-310.pyc +0 -0
- SynTool/mcts/__pycache__/evaluation.cpython-310.pyc +0 -0
- SynTool/mcts/__pycache__/expansion.cpython-310.pyc +0 -0
- SynTool/mcts/__pycache__/node.cpython-310.pyc +0 -0
- SynTool/mcts/__pycache__/search.cpython-310.pyc +0 -0
- SynTool/mcts/__pycache__/tree.cpython-310.pyc +0 -0
- SynTool/mcts/evaluation.py +59 -0
- SynTool/mcts/expansion.py +83 -0
- SynTool/mcts/node.py +49 -0
- SynTool/mcts/search.py +135 -0
- SynTool/mcts/tree.py +659 -0
- SynTool/ml/__init__.py +0 -0
- SynTool/ml/__pycache__/__init__.cpython-310.pyc +0 -0
- SynTool/ml/networks/__init__.py +0 -0
- SynTool/ml/networks/__pycache__/__init__.cpython-310.pyc +0 -0
- SynTool/ml/networks/__pycache__/modules.cpython-310.pyc +0 -0
SynTool/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .mcts import *
|
2 |
+
|
3 |
+
__all__ = ["Tree"]
|
SynTool/chem/__init__.py
ADDED
File without changes
|
SynTool/chem/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (149 Bytes). View file
|
|
SynTool/chem/__pycache__/reaction.cpython-310.pyc
ADDED
Binary file (3.65 kB). View file
|
|
SynTool/chem/__pycache__/retron.cpython-310.pyc
ADDED
Binary file (4.88 kB). View file
|
|
SynTool/chem/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (8.25 kB). View file
|
|
SynTool/chem/data/__init__.py
ADDED
File without changes
|
SynTool/chem/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (150 Bytes). View file
|
|
SynTool/chem/data/__pycache__/cleaning.cpython-310.pyc
ADDED
Binary file (3.86 kB). View file
|
|
SynTool/chem/data/__pycache__/filtering.cpython-310.pyc
ADDED
Binary file (27.6 kB). View file
|
|
SynTool/chem/data/__pycache__/mapping.cpython-310.pyc
ADDED
Binary file (2.59 kB). View file
|
|
SynTool/chem/data/__pycache__/standardizer.cpython-310.pyc
ADDED
Binary file (18 kB). View file
|
|
SynTool/chem/data/cleaning.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from multiprocessing import Queue, Process, Manager, Value
|
3 |
+
from logging import getLogger, Logger
|
4 |
+
from tqdm import tqdm
|
5 |
+
from CGRtools.containers import ReactionContainer
|
6 |
+
|
7 |
+
from .standardizer import Standardizer
|
8 |
+
from SynTool.utils.files import ReactionReader, ReactionWriter
|
9 |
+
from SynTool.utils.config import ReactionStandardizationConfig
|
10 |
+
|
11 |
+
|
12 |
+
def cleaner(reaction: ReactionContainer, logger: Logger, config: ReactionStandardizationConfig):
|
13 |
+
"""
|
14 |
+
Standardize a reaction according to external script
|
15 |
+
|
16 |
+
:param reaction: ReactionContainer to clean/standardize
|
17 |
+
:param logger: Logger - to avoid writing log
|
18 |
+
:param config: ReactionStandardizationConfig
|
19 |
+
:return: ReactionContainer or empty list
|
20 |
+
"""
|
21 |
+
standardizer = Standardizer(id_tag='Reaction_ID',
|
22 |
+
action_on_isotopes=2,
|
23 |
+
skip_tautomerize=True,
|
24 |
+
skip_errors=config.skip_errors,
|
25 |
+
keep_unbalanced_ions=config.keep_unbalanced_ions,
|
26 |
+
keep_reagents=config.keep_reagents,
|
27 |
+
ignore_mapping=config.ignore_mapping,
|
28 |
+
logger=logger)
|
29 |
+
return standardizer.standardize(reaction)
|
30 |
+
|
31 |
+
|
32 |
+
def worker_cleaner(to_clean: Queue, to_write: Queue, config: ReactionStandardizationConfig):
|
33 |
+
"""
|
34 |
+
Launches standardizations using the Queue to_clean. Fills the to_write Queue with results
|
35 |
+
|
36 |
+
:param to_clean: Queue of reactions to clean/standardize
|
37 |
+
:param to_write: Standardized outputs to write
|
38 |
+
:param config: ReactionStandardizationConfig
|
39 |
+
:return: None
|
40 |
+
"""
|
41 |
+
logger = getLogger()
|
42 |
+
logger.disabled = True
|
43 |
+
while True:
|
44 |
+
raw_reaction = to_clean.get()
|
45 |
+
if raw_reaction == "Quit":
|
46 |
+
break
|
47 |
+
res = cleaner(raw_reaction, logger, config)
|
48 |
+
to_write.put(res)
|
49 |
+
logger.disabled = False
|
50 |
+
|
51 |
+
|
52 |
+
def cleaner_writer(output_file: str, to_write: Queue, cleaned_nb: Value, remove_old=True):
|
53 |
+
"""
|
54 |
+
Writes in output file the standardized reactions
|
55 |
+
|
56 |
+
:param output_file: output file path
|
57 |
+
:param to_write: Standardized ReactionContainer to write
|
58 |
+
:param cleaned_nb: number of final reactions
|
59 |
+
:param remove_old: whenever to remove or not an already existing file
|
60 |
+
"""
|
61 |
+
|
62 |
+
if remove_old and os.path.isfile(output_file):
|
63 |
+
os.remove(output_file)
|
64 |
+
|
65 |
+
counter = 0
|
66 |
+
seen_reactions = []
|
67 |
+
with ReactionWriter(output_file) as out:
|
68 |
+
while True:
|
69 |
+
res = to_write.get()
|
70 |
+
if res:
|
71 |
+
if res == "Quit":
|
72 |
+
cleaned_nb.set(counter)
|
73 |
+
break
|
74 |
+
elif isinstance(res, ReactionContainer):
|
75 |
+
smi = format(res, "m")
|
76 |
+
if smi not in seen_reactions:
|
77 |
+
out.write(res)
|
78 |
+
counter += 1
|
79 |
+
seen_reactions.append(smi)
|
80 |
+
|
81 |
+
|
82 |
+
def reactions_cleaner(config: ReactionStandardizationConfig,
|
83 |
+
input_file: str, output_file: str, num_cpus: int, batch_prep_size: int = 100):
|
84 |
+
"""
|
85 |
+
Writes in output file the standardized reactions
|
86 |
+
|
87 |
+
:param config:
|
88 |
+
:param input_file: input RDF file path
|
89 |
+
:param output_file: output RDF file path
|
90 |
+
:param num_cpus: number of CPU to be parallelized
|
91 |
+
:param batch_prep_size: size of each batch per CPU
|
92 |
+
"""
|
93 |
+
with Manager() as m:
|
94 |
+
to_clean = m.Queue(maxsize=num_cpus * batch_prep_size)
|
95 |
+
to_write = m.Queue(maxsize=batch_prep_size)
|
96 |
+
cleaned_nb = m.Value(int, 0)
|
97 |
+
|
98 |
+
writer = Process(target=cleaner_writer, args=(output_file, to_write, cleaned_nb))
|
99 |
+
writer.start()
|
100 |
+
|
101 |
+
workers = []
|
102 |
+
for _ in range(num_cpus - 2):
|
103 |
+
w = Process(target=worker_cleaner, args=(to_clean, to_write, config))
|
104 |
+
w.start()
|
105 |
+
workers.append(w)
|
106 |
+
|
107 |
+
n = 0
|
108 |
+
with ReactionReader(input_file) as reactions:
|
109 |
+
for raw_reaction in tqdm(reactions):
|
110 |
+
if 'Reaction_ID' not in raw_reaction.meta:
|
111 |
+
raw_reaction.meta['Reaction_ID'] = n
|
112 |
+
to_clean.put(raw_reaction)
|
113 |
+
n += 1
|
114 |
+
|
115 |
+
for _ in workers:
|
116 |
+
to_clean.put("Quit")
|
117 |
+
for w in workers:
|
118 |
+
w.join()
|
119 |
+
|
120 |
+
to_write.put("Quit")
|
121 |
+
writer.join()
|
122 |
+
|
123 |
+
print(f'Initial number of reactions: {n}'),
|
124 |
+
print(f'Removed number of reactions: {n - cleaned_nb.get()}')
|
SynTool/chem/data/filtering.py
ADDED
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Iterable, Tuple, Dict, Any, Optional
|
5 |
+
from tqdm.auto import tqdm
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import ray
|
9 |
+
import yaml
|
10 |
+
from CGRtools.containers import ReactionContainer, MoleculeContainer, CGRContainer
|
11 |
+
from StructureFingerprint import MorganFingerprint
|
12 |
+
|
13 |
+
from SynTool.utils.files import ReactionReader, ReactionWriter
|
14 |
+
from SynTool.chem.utils import remove_small_molecules, rebalance_reaction, remove_reagents
|
15 |
+
from SynTool.utils.config import ConfigABC, convert_config_to_dict
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class CompeteProductsConfig(ConfigABC):
|
20 |
+
fingerprint_tanimoto_threshold: float = 0.3
|
21 |
+
mcs_tanimoto_threshold: float = 0.6
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def from_dict(config_dict: Dict[str, Any]):
|
25 |
+
"""Create an instance of CompeteProductsConfig from a dictionary."""
|
26 |
+
return CompeteProductsConfig(**config_dict)
|
27 |
+
|
28 |
+
@staticmethod
|
29 |
+
def from_yaml(file_path: str):
|
30 |
+
"""Deserialize a YAML file into a CompeteProductsConfig object."""
|
31 |
+
with open(file_path, "r") as file:
|
32 |
+
config_dict = yaml.safe_load(file)
|
33 |
+
return CompeteProductsConfig.from_dict(config_dict)
|
34 |
+
|
35 |
+
def _validate_params(self, params: Dict[str, Any]):
|
36 |
+
"""Validate configuration parameters."""
|
37 |
+
if not isinstance(params.get("fingerprint_tanimoto_threshold"), float) \
|
38 |
+
or not (0 <= params["fingerprint_tanimoto_threshold"] <= 1):
|
39 |
+
raise ValueError("Invalid 'fingerprint_tanimoto_threshold'; expected a float between 0 and 1")
|
40 |
+
|
41 |
+
if not isinstance(params.get("mcs_tanimoto_threshold"), float) \
|
42 |
+
or not (0 <= params["mcs_tanimoto_threshold"] <= 1):
|
43 |
+
raise ValueError("Invalid 'mcs_tanimoto_threshold'; expected a float between 0 and 1")
|
44 |
+
|
45 |
+
|
46 |
+
class CompeteProductsChecker:
|
47 |
+
"""Checks if there are compete reactions."""
|
48 |
+
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
fingerprint_tanimoto_threshold: float = 0.3,
|
52 |
+
mcs_tanimoto_threshold: float = 0.6,
|
53 |
+
):
|
54 |
+
self.fingerprint_tanimoto_threshold = fingerprint_tanimoto_threshold
|
55 |
+
self.mcs_tanimoto_threshold = mcs_tanimoto_threshold
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def from_config(config: CompeteProductsConfig):
|
59 |
+
"""Creates an instance of CompeteProductsChecker from a configuration object."""
|
60 |
+
return CompeteProductsChecker(
|
61 |
+
config.fingerprint_tanimoto_threshold, config.mcs_tanimoto_threshold
|
62 |
+
)
|
63 |
+
|
64 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
65 |
+
"""
|
66 |
+
Returns True if the reaction has competing products, else False
|
67 |
+
|
68 |
+
:param reaction: input reaction
|
69 |
+
:return: True or False
|
70 |
+
"""
|
71 |
+
mf = MorganFingerprint()
|
72 |
+
is_compete = False
|
73 |
+
|
74 |
+
# Check for compete products using both fingerprint similarity and maximum common substructure (MCS) similarity
|
75 |
+
for mol in reaction.reagents:
|
76 |
+
for other_mol in reaction.products:
|
77 |
+
if len(mol) > 6 and len(other_mol) > 6:
|
78 |
+
# Compute fingerprint similarity
|
79 |
+
molf = mf.transform([mol])
|
80 |
+
other_molf = mf.transform([other_mol])
|
81 |
+
fingerprint_tanimoto = tanimoto_kernel(molf, other_molf)[0][0]
|
82 |
+
|
83 |
+
# If fingerprint similarity is high enough, check for MCS similarity
|
84 |
+
if fingerprint_tanimoto > self.fingerprint_tanimoto_threshold:
|
85 |
+
try:
|
86 |
+
# Find the maximum common substructure (MCS) and compute its size
|
87 |
+
clique_size = len(next(mol.get_mcs_mapping(other_mol, limit=100)))
|
88 |
+
|
89 |
+
# Calculate MCS similarity based on MCS size
|
90 |
+
mcs_tanimoto = clique_size / (len(mol) + len(other_mol) - clique_size)
|
91 |
+
|
92 |
+
# If MCS similarity is also high enough, mark the reaction as having compete products
|
93 |
+
if mcs_tanimoto > self.mcs_tanimoto_threshold:
|
94 |
+
is_compete = True
|
95 |
+
break
|
96 |
+
except StopIteration:
|
97 |
+
continue
|
98 |
+
|
99 |
+
return is_compete
|
100 |
+
|
101 |
+
|
102 |
+
@dataclass
|
103 |
+
class DynamicBondsConfig(ConfigABC):
|
104 |
+
min_bonds_number: int = 1
|
105 |
+
max_bonds_number: int = 6
|
106 |
+
|
107 |
+
@staticmethod
|
108 |
+
def from_dict(config_dict: Dict[str, Any]):
|
109 |
+
"""Create an instance of DynamicBondsConfig from a dictionary."""
|
110 |
+
return DynamicBondsConfig(**config_dict)
|
111 |
+
|
112 |
+
@staticmethod
|
113 |
+
def from_yaml(file_path: str):
|
114 |
+
"""Deserialize a YAML file into a DynamicBondsConfig object."""
|
115 |
+
with open(file_path, "r") as file:
|
116 |
+
config_dict = yaml.safe_load(file)
|
117 |
+
return DynamicBondsConfig.from_dict(config_dict)
|
118 |
+
|
119 |
+
def _validate_params(self, params: Dict[str, Any]):
|
120 |
+
"""Validate configuration parameters."""
|
121 |
+
if not isinstance(params.get("min_bonds_number"), int) \
|
122 |
+
or params["min_bonds_number"] < 0:
|
123 |
+
raise ValueError(
|
124 |
+
"Invalid 'min_bonds_number'; expected a non-negative integer")
|
125 |
+
|
126 |
+
if not isinstance(params.get("max_bonds_number"), int) \
|
127 |
+
or params["max_bonds_number"] < 0:
|
128 |
+
raise ValueError("Invalid 'max_bonds_number'; expected a non-negative integer")
|
129 |
+
|
130 |
+
if params["min_bonds_number"] > params["max_bonds_number"]:
|
131 |
+
raise ValueError("'min_bonds_number' cannot be greater than 'max_bonds_number'")
|
132 |
+
|
133 |
+
|
134 |
+
class DynamicBondsChecker:
|
135 |
+
"""Checks if there is an unacceptable number of dynamic bonds in CGR."""
|
136 |
+
|
137 |
+
def __init__(self, min_bonds_number: int = 1, max_bonds_number: int = 6):
|
138 |
+
self.min_bonds_number = min_bonds_number
|
139 |
+
self.max_bonds_number = max_bonds_number
|
140 |
+
|
141 |
+
@staticmethod
|
142 |
+
def from_config(config: DynamicBondsConfig):
|
143 |
+
"""Creates an instance of DynamicBondsChecker from a configuration object."""
|
144 |
+
return DynamicBondsChecker(config.min_bonds_number, config.max_bonds_number)
|
145 |
+
|
146 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
147 |
+
cgr = ~reaction
|
148 |
+
return not (self.min_bonds_number <= len(cgr.center_bonds) <= self.max_bonds_number)
|
149 |
+
|
150 |
+
|
151 |
+
@dataclass
|
152 |
+
class SmallMoleculesConfig(ConfigABC):
|
153 |
+
limit: int = 6
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def from_dict(config_dict: Dict[str, Any]):
|
157 |
+
"""Create an instance of SmallMoleculesConfig from a dictionary."""
|
158 |
+
return SmallMoleculesConfig(**config_dict)
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def from_yaml(file_path: str):
|
162 |
+
"""Deserialize a YAML file into a SmallMoleculesConfig object."""
|
163 |
+
with open(file_path, "r") as file:
|
164 |
+
config_dict = yaml.safe_load(file)
|
165 |
+
return SmallMoleculesConfig.from_dict(config_dict)
|
166 |
+
|
167 |
+
def _validate_params(self, params: Dict[str, Any]):
|
168 |
+
"""Validate configuration parameters."""
|
169 |
+
if not isinstance(params.get("limit"), int) or params["limit"] < 1:
|
170 |
+
raise ValueError("Invalid 'limit'; expected a positive integer")
|
171 |
+
|
172 |
+
|
173 |
+
class SmallMoleculesChecker:
|
174 |
+
"""Checks if there are only small molecules in the reaction or if there is only one small reactant or product."""
|
175 |
+
|
176 |
+
def __init__(self, limit: int = 6):
|
177 |
+
self.limit = limit
|
178 |
+
|
179 |
+
@staticmethod
|
180 |
+
def from_config(config: SmallMoleculesConfig):
|
181 |
+
"""Creates an instance of SmallMoleculesChecker from a configuration object."""
|
182 |
+
return SmallMoleculesChecker(config.limit)
|
183 |
+
|
184 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
185 |
+
if (len(reaction.reactants) == 1 and self.are_only_small_molecules(reaction.reactants)) \
|
186 |
+
or (len(reaction.products) == 1 and self.are_only_small_molecules(reaction.products)) \
|
187 |
+
or (self.are_only_small_molecules(reaction.reactants) and self.are_only_small_molecules(reaction.products)):
|
188 |
+
return True
|
189 |
+
return False
|
190 |
+
|
191 |
+
def are_only_small_molecules(self, molecules: Iterable[MoleculeContainer]) -> bool:
|
192 |
+
"""Checks if all molecules in the given iterable are small molecules."""
|
193 |
+
return all(len(molecule) <= self.limit for molecule in molecules)
|
194 |
+
|
195 |
+
|
196 |
+
@dataclass
|
197 |
+
class CGRConnectedComponentsConfig:
|
198 |
+
pass
|
199 |
+
|
200 |
+
|
201 |
+
class CGRConnectedComponentsChecker:
|
202 |
+
"""Allows to check if CGR contains unrelated components (without reagents)."""
|
203 |
+
|
204 |
+
@staticmethod
|
205 |
+
def from_config(config: CGRConnectedComponentsConfig): # TODO config class not used
|
206 |
+
"""Creates an instance of CGRConnectedComponentsChecker from a configuration object."""
|
207 |
+
return CGRConnectedComponentsChecker()
|
208 |
+
|
209 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
210 |
+
tmp_reaction = ReactionContainer(reaction.reactants, reaction.products)
|
211 |
+
cgr = ~tmp_reaction
|
212 |
+
return cgr.connected_components_count > 1
|
213 |
+
|
214 |
+
|
215 |
+
@dataclass
|
216 |
+
class RingsChangeConfig:
|
217 |
+
pass
|
218 |
+
|
219 |
+
|
220 |
+
class RingsChangeChecker:
|
221 |
+
"""Allows to check if there is changing rings number in the reaction."""
|
222 |
+
|
223 |
+
@staticmethod
|
224 |
+
def from_config(config: RingsChangeConfig): # TODO config class not used
|
225 |
+
"""Creates an instance of RingsChecker from a configuration object."""
|
226 |
+
return RingsChangeChecker()
|
227 |
+
|
228 |
+
def __call__(self, reaction: ReactionContainer):
|
229 |
+
"""
|
230 |
+
Returns True if there are valence mistakes in the reaction or there is a reaction with mismatch numbers of all
|
231 |
+
rings or aromatic rings in reactants and products (reaction in rings)
|
232 |
+
|
233 |
+
:param reaction: input reaction
|
234 |
+
:return: True or False
|
235 |
+
"""
|
236 |
+
|
237 |
+
reaction.kekule()
|
238 |
+
reaction.thiele()
|
239 |
+
r_rings, r_arom_rings = self._calc_rings(reaction.reactants)
|
240 |
+
p_rings, p_arom_rings = self._calc_rings(reaction.products)
|
241 |
+
if (r_arom_rings != p_arom_rings) or (r_rings != p_rings):
|
242 |
+
return True
|
243 |
+
else:
|
244 |
+
return False
|
245 |
+
|
246 |
+
@staticmethod
|
247 |
+
def _calc_rings(molecules: Iterable) -> Tuple[int, int]:
|
248 |
+
"""
|
249 |
+
Calculates number of all rings and number of aromatic rings in molecules
|
250 |
+
|
251 |
+
:param molecules: set of molecules
|
252 |
+
:return: number of all rings and number of aromatic rings in molecules
|
253 |
+
"""
|
254 |
+
rings, arom_rings = 0, 0
|
255 |
+
for mol in molecules:
|
256 |
+
rings += mol.rings_count
|
257 |
+
arom_rings += len(mol.aromatic_rings)
|
258 |
+
return rings, arom_rings
|
259 |
+
|
260 |
+
|
261 |
+
@dataclass
|
262 |
+
class StrangeCarbonsConfig:
|
263 |
+
# Currently empty, but can be extended in the future if needed
|
264 |
+
pass
|
265 |
+
|
266 |
+
|
267 |
+
class StrangeCarbonsChecker:
|
268 |
+
"""Checks if there are 'strange' carbons in the reaction."""
|
269 |
+
|
270 |
+
@staticmethod
|
271 |
+
def from_config(config: StrangeCarbonsConfig): # TODO config class not used
|
272 |
+
"""Creates an instance of StrangeCarbonsChecker from a configuration object."""
|
273 |
+
return StrangeCarbonsChecker()
|
274 |
+
|
275 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
276 |
+
for molecule in reaction.reactants + reaction.products:
|
277 |
+
atoms_types = {a.atomic_symbol for _, a in molecule.atoms()} # atoms types in molecule
|
278 |
+
if len(atoms_types) == 1 and atoms_types.pop() == "C":
|
279 |
+
if len(molecule) == 1: # methane
|
280 |
+
return True
|
281 |
+
bond_types = {int(b) for _, _, b in molecule.bonds()}
|
282 |
+
if len(bond_types) == 1 and bond_types.pop() != 4:
|
283 |
+
return True # C molecules with only one type of bond (not aromatic)
|
284 |
+
return False
|
285 |
+
|
286 |
+
|
287 |
+
@dataclass
|
288 |
+
class NoReactionConfig:
|
289 |
+
# Currently empty, but can be extended in the future if needed
|
290 |
+
pass
|
291 |
+
|
292 |
+
|
293 |
+
class NoReactionChecker:
|
294 |
+
"""Checks if there is no reaction in the provided reaction container."""
|
295 |
+
|
296 |
+
@staticmethod
|
297 |
+
def from_config(config: NoReactionConfig): # TODO config class not used
|
298 |
+
"""Creates an instance of NoReactionChecker from a configuration object."""
|
299 |
+
return NoReactionChecker()
|
300 |
+
|
301 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
302 |
+
cgr = ~reaction
|
303 |
+
return not cgr.center_atoms and not cgr.center_bonds
|
304 |
+
|
305 |
+
|
306 |
+
@dataclass
|
307 |
+
class MultiCenterConfig:
|
308 |
+
pass
|
309 |
+
|
310 |
+
|
311 |
+
class MultiCenterChecker:
|
312 |
+
"""Checks if there is a multicenter reaction."""
|
313 |
+
|
314 |
+
@staticmethod
|
315 |
+
def from_config(config: MultiCenterConfig): # TODO config class not used
|
316 |
+
return MultiCenterChecker()
|
317 |
+
|
318 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
319 |
+
cgr = ~reaction
|
320 |
+
return len(cgr.centers_list) > 1
|
321 |
+
|
322 |
+
|
323 |
+
@dataclass
|
324 |
+
class WrongCHBreakingConfig:
|
325 |
+
pass
|
326 |
+
|
327 |
+
|
328 |
+
class WrongCHBreakingChecker:
|
329 |
+
"""Checks for incorrect C-C bond formation from breaking a C-H bond."""
|
330 |
+
|
331 |
+
@staticmethod
|
332 |
+
def from_config(config: WrongCHBreakingConfig): # TODO config class not used
|
333 |
+
return WrongCHBreakingChecker()
|
334 |
+
|
335 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
336 |
+
"""
|
337 |
+
Determines if a reaction involves incorrect C-C bond formation from breaking a C-H bond.
|
338 |
+
|
339 |
+
:param reaction: The reaction to be checked.
|
340 |
+
:return: True if incorrect C-C bond formation is found, False otherwise.
|
341 |
+
"""
|
342 |
+
|
343 |
+
reaction.kekule()
|
344 |
+
if reaction.check_valence():
|
345 |
+
return False
|
346 |
+
reaction.thiele()
|
347 |
+
|
348 |
+
copy_reaction = reaction.copy()
|
349 |
+
copy_reaction.explicify_hydrogens()
|
350 |
+
cgr = ~copy_reaction
|
351 |
+
reduced_cgr = cgr.augmented_substructure(cgr.center_atoms, deep=1)
|
352 |
+
|
353 |
+
return self.is_wrong_c_h_breaking(reduced_cgr)
|
354 |
+
|
355 |
+
@staticmethod
|
356 |
+
def is_wrong_c_h_breaking(cgr: CGRContainer) -> bool:
|
357 |
+
"""
|
358 |
+
Checks for incorrect C-C bond formation from breaking a C-H bond in a CGR.
|
359 |
+
:param cgr: The CGR with explicified hydrogens.
|
360 |
+
:return: True if incorrect C-C bond formation is found, False otherwise.
|
361 |
+
"""
|
362 |
+
for atom_id in cgr.center_atoms:
|
363 |
+
if cgr.atom(atom_id).atomic_symbol == "C":
|
364 |
+
is_c_h_breaking, is_c_c_formation = False, False
|
365 |
+
c_with_h_id, another_c_id = None, None
|
366 |
+
|
367 |
+
for neighbour_id, bond in cgr._bonds[atom_id].items():
|
368 |
+
neighbour = cgr.atom(neighbour_id)
|
369 |
+
|
370 |
+
if (
|
371 |
+
bond.order
|
372 |
+
and not bond.p_order
|
373 |
+
and neighbour.atomic_symbol == "H"
|
374 |
+
):
|
375 |
+
is_c_h_breaking = True
|
376 |
+
c_with_h_id = atom_id
|
377 |
+
|
378 |
+
elif (
|
379 |
+
not bond.order
|
380 |
+
and bond.p_order
|
381 |
+
and neighbour.atomic_symbol == "C"
|
382 |
+
):
|
383 |
+
is_c_c_formation = True
|
384 |
+
another_c_id = neighbour_id
|
385 |
+
|
386 |
+
if is_c_h_breaking and is_c_c_formation:
|
387 |
+
# Check for presence of heteroatoms in the first environment of 2 bonding carbons
|
388 |
+
if any(
|
389 |
+
cgr.atom(neighbour_id).atomic_symbol not in ("C", "H")
|
390 |
+
for neighbour_id in cgr._bonds[c_with_h_id]
|
391 |
+
) or any(
|
392 |
+
cgr.atom(neighbour_id).atomic_symbol not in ("C", "H")
|
393 |
+
for neighbour_id in cgr._bonds[another_c_id]
|
394 |
+
):
|
395 |
+
return False
|
396 |
+
return True
|
397 |
+
|
398 |
+
return False
|
399 |
+
|
400 |
+
|
401 |
+
@dataclass
|
402 |
+
class CCsp3BreakingConfig:
|
403 |
+
pass
|
404 |
+
|
405 |
+
|
406 |
+
class CCsp3BreakingChecker:
|
407 |
+
"""Checks if there is C(sp3)-C bond breaking."""
|
408 |
+
|
409 |
+
@staticmethod
|
410 |
+
def from_config(config: CCsp3BreakingConfig): # TODO config class not used
|
411 |
+
return CCsp3BreakingChecker()
|
412 |
+
|
413 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
414 |
+
"""
|
415 |
+
Returns True if there is C(sp3)-C bonds breaking, else False
|
416 |
+
|
417 |
+
:param reaction: input reaction
|
418 |
+
:return: True or False
|
419 |
+
"""
|
420 |
+
cgr = ~reaction
|
421 |
+
reaction_center = cgr.augmented_substructure(cgr.center_atoms, deep=1)
|
422 |
+
for atom_id, neighbour_id, bond in reaction_center.bonds():
|
423 |
+
atom = reaction_center.atom(atom_id)
|
424 |
+
neighbour = reaction_center.atom(neighbour_id)
|
425 |
+
|
426 |
+
is_bond_broken = bond.order is not None and bond.p_order is None
|
427 |
+
are_atoms_carbons = (
|
428 |
+
atom.atomic_symbol == "C" and neighbour.atomic_symbol == "C"
|
429 |
+
)
|
430 |
+
is_atom_sp3 = atom.hybridization == 1 or neighbour.hybridization == 1
|
431 |
+
|
432 |
+
if is_bond_broken and are_atoms_carbons and is_atom_sp3:
|
433 |
+
return True
|
434 |
+
return False
|
435 |
+
|
436 |
+
|
437 |
+
@dataclass
|
438 |
+
class CCRingBreakingConfig:
|
439 |
+
pass
|
440 |
+
|
441 |
+
|
442 |
+
class CCRingBreakingChecker:
|
443 |
+
"""Checks if a reaction involves ring C-C bond breaking."""
|
444 |
+
|
445 |
+
@staticmethod
|
446 |
+
def from_config(config: CCRingBreakingConfig): # TODO config class not used
|
447 |
+
return CCRingBreakingChecker()
|
448 |
+
|
449 |
+
def __call__(self, reaction: ReactionContainer) -> bool:
|
450 |
+
"""
|
451 |
+
Returns True if the reaction involves ring C-C bond breaking, else False
|
452 |
+
|
453 |
+
:param reaction: input reaction
|
454 |
+
:return: True or False
|
455 |
+
"""
|
456 |
+
cgr = ~reaction
|
457 |
+
|
458 |
+
# Extract reactants' center atoms and their rings
|
459 |
+
reactants_center_atoms = {}
|
460 |
+
reactants_rings = set()
|
461 |
+
for reactant in reaction.reactants:
|
462 |
+
reactants_rings.update(reactant.sssr)
|
463 |
+
for n, atom in reactant.atoms():
|
464 |
+
if n in cgr.center_atoms:
|
465 |
+
reactants_center_atoms[n] = atom
|
466 |
+
|
467 |
+
# Identify reaction center based on center atoms
|
468 |
+
reaction_center = cgr.augmented_substructure(atoms=cgr.center_atoms, deep=0)
|
469 |
+
|
470 |
+
# Iterate over bonds in the reaction center and check for ring C-C bond breaking
|
471 |
+
for atom_id, neighbour_id, bond in reaction_center.bonds():
|
472 |
+
try:
|
473 |
+
# Retrieve corresponding atoms from reactants
|
474 |
+
atom = reactants_center_atoms[atom_id]
|
475 |
+
neighbour = reactants_center_atoms[neighbour_id]
|
476 |
+
except KeyError:
|
477 |
+
continue
|
478 |
+
else:
|
479 |
+
# Check if the bond is broken and both atoms are carbons in rings of size 5, 6, or 7
|
480 |
+
is_bond_broken = (bond.order is not None) and (bond.p_order is None)
|
481 |
+
are_atoms_carbons = (
|
482 |
+
atom.atomic_symbol == "C" and neighbour.atomic_symbol == "C"
|
483 |
+
)
|
484 |
+
are_atoms_in_ring = (
|
485 |
+
set(atom.ring_sizes).intersection({5, 6, 7})
|
486 |
+
and set(neighbour.ring_sizes).intersection({5, 6, 7})
|
487 |
+
and any(
|
488 |
+
atom_id in ring and neighbour_id in ring
|
489 |
+
for ring in reactants_rings
|
490 |
+
)
|
491 |
+
)
|
492 |
+
|
493 |
+
# If all conditions are met, indicate ring C-C bond breaking
|
494 |
+
if is_bond_broken and are_atoms_carbons and are_atoms_in_ring:
|
495 |
+
return True
|
496 |
+
|
497 |
+
return False
|
498 |
+
|
499 |
+
|
500 |
+
@dataclass
|
501 |
+
class ReactionCheckConfig(ConfigABC):
|
502 |
+
"""
|
503 |
+
Configuration class for reaction checks, inheriting from ConfigABC.
|
504 |
+
|
505 |
+
This class manages configuration settings for various reaction checkers, including paths, file formats,
|
506 |
+
and checker-specific parameters.
|
507 |
+
|
508 |
+
Attributes:
|
509 |
+
dynamic_bonds_config: Configuration for dynamic bonds checking.
|
510 |
+
small_molecules_config: Configuration for small molecules checking.
|
511 |
+
strange_carbons_config: Configuration for strange carbons checking.
|
512 |
+
compete_products_config: Configuration for competing products checking.
|
513 |
+
cgr_connected_components_config: Configuration for CGR connected components checking.
|
514 |
+
rings_change_config: Configuration for rings change checking.
|
515 |
+
no_reaction_config: Configuration for no reaction checking.
|
516 |
+
multi_center_config: Configuration for multi-center checking.
|
517 |
+
wrong_ch_breaking_config: Configuration for wrong C-H breaking checking.
|
518 |
+
cc_sp3_breaking_config: Configuration for CC sp3 breaking checking.
|
519 |
+
cc_ring_breaking_config: Configuration for CC ring breaking checking.
|
520 |
+
"""
|
521 |
+
|
522 |
+
# Configuration for reaction checkers
|
523 |
+
dynamic_bonds_config: Optional[DynamicBondsConfig] = None
|
524 |
+
small_molecules_config: Optional[SmallMoleculesConfig] = None
|
525 |
+
strange_carbons_config: Optional[StrangeCarbonsConfig] = None
|
526 |
+
compete_products_config: Optional[CompeteProductsConfig] = None
|
527 |
+
cgr_connected_components_config: Optional[CGRConnectedComponentsConfig] = None
|
528 |
+
rings_change_config: Optional[RingsChangeConfig] = None
|
529 |
+
no_reaction_config: Optional[NoReactionConfig] = None
|
530 |
+
multi_center_config: Optional[MultiCenterConfig] = None
|
531 |
+
wrong_ch_breaking_config: Optional[WrongCHBreakingConfig] = None
|
532 |
+
cc_sp3_breaking_config: Optional[CCsp3BreakingConfig] = None
|
533 |
+
cc_ring_breaking_config: Optional[CCRingBreakingConfig] = None
|
534 |
+
|
535 |
+
# Other configuration parameters
|
536 |
+
rebalance_reaction: bool = False
|
537 |
+
remove_reagents: bool = True
|
538 |
+
reagents_max_size: int = 7
|
539 |
+
remove_small_molecules: bool = False
|
540 |
+
small_molecules_max_size: int = 6
|
541 |
+
|
542 |
+
def to_dict(self):
|
543 |
+
"""
|
544 |
+
Converts the configuration into a dictionary.
|
545 |
+
"""
|
546 |
+
config_dict = {
|
547 |
+
"dynamic_bonds_config": convert_config_to_dict(
|
548 |
+
self.dynamic_bonds_config, DynamicBondsConfig
|
549 |
+
),
|
550 |
+
"small_molecules_config": convert_config_to_dict(
|
551 |
+
self.small_molecules_config, SmallMoleculesConfig
|
552 |
+
),
|
553 |
+
"compete_products_config": convert_config_to_dict(
|
554 |
+
self.compete_products_config, CompeteProductsConfig
|
555 |
+
),
|
556 |
+
"cgr_connected_components_config": {}
|
557 |
+
if self.cgr_connected_components_config is not None
|
558 |
+
else None,
|
559 |
+
"rings_change_config": {} if self.rings_change_config is not None else None,
|
560 |
+
"strange_carbons_config": {}
|
561 |
+
if self.strange_carbons_config is not None
|
562 |
+
else None,
|
563 |
+
"no_reaction_config": {} if self.no_reaction_config is not None else None,
|
564 |
+
"multi_center_config": {} if self.multi_center_config is not None else None,
|
565 |
+
"wrong_ch_breaking_config": {}
|
566 |
+
if self.wrong_ch_breaking_config is not None
|
567 |
+
else None,
|
568 |
+
"cc_sp3_breaking_config": {}
|
569 |
+
if self.cc_sp3_breaking_config is not None
|
570 |
+
else None,
|
571 |
+
"cc_ring_breaking_config": {}
|
572 |
+
if self.cc_ring_breaking_config is not None
|
573 |
+
else None,
|
574 |
+
"rebalance_reaction": self.rebalance_reaction,
|
575 |
+
"remove_reagents": self.remove_reagents,
|
576 |
+
"reagents_max_size": self.reagents_max_size,
|
577 |
+
"remove_small_molecules": self.remove_small_molecules,
|
578 |
+
"small_molecules_max_size": self.small_molecules_max_size,
|
579 |
+
}
|
580 |
+
|
581 |
+
filtered_config_dict = {k: v for k, v in config_dict.items() if v is not None}
|
582 |
+
|
583 |
+
return filtered_config_dict
|
584 |
+
|
585 |
+
@staticmethod
|
586 |
+
def from_dict(config_dict: Dict[str, Any]):
|
587 |
+
"""
|
588 |
+
Create an instance of ReactionCheckConfig from a dictionary.
|
589 |
+
"""
|
590 |
+
# Instantiate configuration objects if their corresponding dictionary is present
|
591 |
+
dynamic_bonds_config = (
|
592 |
+
DynamicBondsConfig(**config_dict["dynamic_bonds_config"])
|
593 |
+
if "dynamic_bonds_config" in config_dict
|
594 |
+
else None
|
595 |
+
)
|
596 |
+
small_molecules_config = (
|
597 |
+
SmallMoleculesConfig(**config_dict["small_molecules_config"])
|
598 |
+
if "small_molecules_config" in config_dict
|
599 |
+
else None
|
600 |
+
)
|
601 |
+
compete_products_config = (
|
602 |
+
CompeteProductsConfig(**config_dict["compete_products_config"])
|
603 |
+
if "compete_products_config" in config_dict
|
604 |
+
else None
|
605 |
+
)
|
606 |
+
cgr_connected_components_config = (
|
607 |
+
CGRConnectedComponentsConfig()
|
608 |
+
if "cgr_connected_components_config" in config_dict
|
609 |
+
else None
|
610 |
+
)
|
611 |
+
rings_change_config = (
|
612 |
+
RingsChangeConfig()
|
613 |
+
if "rings_change_config" in config_dict
|
614 |
+
else None
|
615 |
+
)
|
616 |
+
strange_carbons_config = (
|
617 |
+
StrangeCarbonsConfig()
|
618 |
+
if "strange_carbons_config" in config_dict
|
619 |
+
else None
|
620 |
+
)
|
621 |
+
no_reaction_config = (
|
622 |
+
NoReactionConfig()
|
623 |
+
if "no_reaction_config" in config_dict
|
624 |
+
else None
|
625 |
+
)
|
626 |
+
multi_center_config = (
|
627 |
+
MultiCenterConfig()
|
628 |
+
if "multi_center_config" in config_dict
|
629 |
+
else None
|
630 |
+
)
|
631 |
+
wrong_ch_breaking_config = (
|
632 |
+
WrongCHBreakingConfig()
|
633 |
+
if "wrong_ch_breaking_config" in config_dict
|
634 |
+
else None
|
635 |
+
)
|
636 |
+
cc_sp3_breaking_config = (
|
637 |
+
CCsp3BreakingConfig()
|
638 |
+
if "cc_sp3_breaking_config" in config_dict
|
639 |
+
else None
|
640 |
+
)
|
641 |
+
cc_ring_breaking_config = (
|
642 |
+
CCRingBreakingConfig()
|
643 |
+
if "cc_ring_breaking_config" in config_dict
|
644 |
+
else None
|
645 |
+
)
|
646 |
+
|
647 |
+
# Extract other simple configuration parameters
|
648 |
+
rebalance_reaction = config_dict.get("rebalance_reaction", False)
|
649 |
+
remove_reagents = config_dict.get("remove_reagents", True)
|
650 |
+
reagents_max_size = config_dict.get("reagents_max_size", 7)
|
651 |
+
remove_small_molecules = config_dict.get("remove_small_molecules", False)
|
652 |
+
small_molecules_max_size = config_dict.get("small_molecules_max_size", 6)
|
653 |
+
|
654 |
+
return ReactionCheckConfig(
|
655 |
+
dynamic_bonds_config=dynamic_bonds_config,
|
656 |
+
small_molecules_config=small_molecules_config,
|
657 |
+
compete_products_config=compete_products_config,
|
658 |
+
cgr_connected_components_config=cgr_connected_components_config,
|
659 |
+
rings_change_config=rings_change_config,
|
660 |
+
strange_carbons_config=strange_carbons_config,
|
661 |
+
no_reaction_config=no_reaction_config,
|
662 |
+
multi_center_config=multi_center_config,
|
663 |
+
wrong_ch_breaking_config=wrong_ch_breaking_config,
|
664 |
+
cc_sp3_breaking_config=cc_sp3_breaking_config,
|
665 |
+
cc_ring_breaking_config=cc_ring_breaking_config,
|
666 |
+
rebalance_reaction=rebalance_reaction,
|
667 |
+
remove_reagents=remove_reagents,
|
668 |
+
reagents_max_size=reagents_max_size,
|
669 |
+
remove_small_molecules=remove_small_molecules,
|
670 |
+
small_molecules_max_size=small_molecules_max_size,
|
671 |
+
)
|
672 |
+
|
673 |
+
@staticmethod
|
674 |
+
def from_yaml(file_path):
|
675 |
+
"""
|
676 |
+
Deserializes a YAML file into a ReactionCheckConfig object.
|
677 |
+
"""
|
678 |
+
with open(file_path, "r") as file:
|
679 |
+
config_dict = yaml.safe_load(file)
|
680 |
+
return ReactionCheckConfig.from_dict(config_dict)
|
681 |
+
|
682 |
+
def _validate_params(self, params: Dict[str, Any]):
|
683 |
+
if not isinstance(params["rebalance_reaction"], bool):
|
684 |
+
raise ValueError("rebalance_reaction must be a boolean.")
|
685 |
+
|
686 |
+
if not isinstance(params["remove_reagents"], bool):
|
687 |
+
raise ValueError("remove_reagents must be a boolean.")
|
688 |
+
|
689 |
+
if not isinstance(params["reagents_max_size"], int):
|
690 |
+
raise ValueError("reagents_max_size must be an int.")
|
691 |
+
|
692 |
+
if not isinstance(params["remove_small_molecules"], bool):
|
693 |
+
raise ValueError("remove_small_molecules must be a boolean.")
|
694 |
+
|
695 |
+
if not isinstance(params["small_molecules_max_size"], int):
|
696 |
+
raise ValueError("small_molecules_max_size must be an int.")
|
697 |
+
|
698 |
+
def create_checkers(self):
|
699 |
+
checker_instances = []
|
700 |
+
|
701 |
+
if self.dynamic_bonds_config is not None:
|
702 |
+
checker_instances.append(
|
703 |
+
DynamicBondsChecker.from_config(self.dynamic_bonds_config)
|
704 |
+
)
|
705 |
+
|
706 |
+
if self.small_molecules_config is not None:
|
707 |
+
checker_instances.append(
|
708 |
+
SmallMoleculesChecker.from_config(self.small_molecules_config)
|
709 |
+
)
|
710 |
+
|
711 |
+
if self.strange_carbons_config is not None:
|
712 |
+
checker_instances.append(
|
713 |
+
StrangeCarbonsChecker.from_config(self.strange_carbons_config)
|
714 |
+
)
|
715 |
+
|
716 |
+
if self.compete_products_config is not None:
|
717 |
+
checker_instances.append(
|
718 |
+
CompeteProductsChecker.from_config(self.compete_products_config)
|
719 |
+
)
|
720 |
+
|
721 |
+
if self.cgr_connected_components_config is not None:
|
722 |
+
checker_instances.append(
|
723 |
+
CGRConnectedComponentsChecker.from_config(
|
724 |
+
self.cgr_connected_components_config
|
725 |
+
)
|
726 |
+
)
|
727 |
+
|
728 |
+
if self.rings_change_config is not None:
|
729 |
+
checker_instances.append(
|
730 |
+
RingsChangeChecker.from_config(self.rings_change_config)
|
731 |
+
)
|
732 |
+
|
733 |
+
if self.no_reaction_config is not None:
|
734 |
+
checker_instances.append(
|
735 |
+
NoReactionChecker.from_config(self.no_reaction_config)
|
736 |
+
)
|
737 |
+
|
738 |
+
if self.multi_center_config is not None:
|
739 |
+
checker_instances.append(
|
740 |
+
MultiCenterChecker.from_config(self.multi_center_config)
|
741 |
+
)
|
742 |
+
|
743 |
+
if self.wrong_ch_breaking_config is not None:
|
744 |
+
checker_instances.append(
|
745 |
+
WrongCHBreakingChecker.from_config(self.wrong_ch_breaking_config)
|
746 |
+
)
|
747 |
+
|
748 |
+
if self.cc_sp3_breaking_config is not None:
|
749 |
+
checker_instances.append(
|
750 |
+
CCsp3BreakingChecker.from_config(self.cc_sp3_breaking_config)
|
751 |
+
)
|
752 |
+
|
753 |
+
if self.cc_ring_breaking_config is not None:
|
754 |
+
checker_instances.append(
|
755 |
+
CCRingBreakingChecker.from_config(self.cc_ring_breaking_config)
|
756 |
+
)
|
757 |
+
|
758 |
+
return checker_instances
|
759 |
+
|
760 |
+
|
761 |
+
def tanimoto_kernel(x, y):
|
762 |
+
"""
|
763 |
+
Calculate the Tanimoto coefficient between each element of arrays x and y.
|
764 |
+
"""
|
765 |
+
x = x.astype(np.float64)
|
766 |
+
y = y.astype(np.float64)
|
767 |
+
x_dot = np.dot(x, y.T)
|
768 |
+
x2 = np.sum(x**2, axis=1)
|
769 |
+
y2 = np.sum(y**2, axis=1)
|
770 |
+
|
771 |
+
denominator = np.array([x2] * len(y2)).T + np.array([y2] * len(x2)) - x_dot
|
772 |
+
result = np.divide(x_dot, denominator, out=np.zeros_like(x_dot), where=denominator != 0)
|
773 |
+
|
774 |
+
return result
|
775 |
+
|
776 |
+
|
777 |
+
def remove_file_if_exists(directory: Path, file_names): # TODO not used
|
778 |
+
for file_name in file_names:
|
779 |
+
file_path = directory / file_name
|
780 |
+
if file_path.is_file():
|
781 |
+
file_path.unlink()
|
782 |
+
logging.warning(f"Removed {file_path}")
|
783 |
+
|
784 |
+
|
785 |
+
def filter_reaction(reaction: ReactionContainer, config: ReactionCheckConfig, checkers: list):
|
786 |
+
|
787 |
+
is_filtered = False
|
788 |
+
if config.remove_small_molecules:
|
789 |
+
new_reaction = remove_small_molecules(reaction, number_of_atoms=config.small_molecules_max_size)
|
790 |
+
else:
|
791 |
+
new_reaction = reaction.copy()
|
792 |
+
|
793 |
+
if new_reaction is None:
|
794 |
+
is_filtered = True
|
795 |
+
|
796 |
+
if config.remove_reagents and not is_filtered:
|
797 |
+
new_reaction = remove_reagents(
|
798 |
+
new_reaction,
|
799 |
+
keep_reagents=True,
|
800 |
+
reagents_max_size=config.reagents_max_size,
|
801 |
+
)
|
802 |
+
|
803 |
+
if new_reaction is None:
|
804 |
+
is_filtered = True
|
805 |
+
new_reaction = reaction.copy()
|
806 |
+
# TODO you are specifying that if the reaction has only reagents, it is kept as it ?
|
807 |
+
|
808 |
+
if not is_filtered:
|
809 |
+
if config.rebalance_reaction:
|
810 |
+
new_reaction = rebalance_reaction(new_reaction)
|
811 |
+
for checker in checkers:
|
812 |
+
try: # TODO CGRTools: ValueError: mapping of graphs is not disjoint
|
813 |
+
if checker(new_reaction):
|
814 |
+
# If checker returns True it means the reaction doesn't pass the check
|
815 |
+
new_reaction.meta["filtration_log"] = checker.__class__.__name__
|
816 |
+
is_filtered = True
|
817 |
+
except:
|
818 |
+
is_filtered = True
|
819 |
+
|
820 |
+
|
821 |
+
|
822 |
+
return is_filtered, new_reaction
|
823 |
+
|
824 |
+
|
825 |
+
@ray.remote
|
826 |
+
def process_batch(batch, config: ReactionCheckConfig, checkers):
|
827 |
+
results = []
|
828 |
+
for index, reaction in batch:
|
829 |
+
try: # TODO CGRtools.exceptions.MappingError: atoms with number {52} not equal
|
830 |
+
is_filtered, processed_reaction = filter_reaction(reaction, config, checkers)
|
831 |
+
results.append((index, is_filtered, processed_reaction))
|
832 |
+
except:
|
833 |
+
results.append((index, True, reaction))
|
834 |
+
return results
|
835 |
+
|
836 |
+
|
837 |
+
def process_completed_batches(futures, result_file, pbar, treated: int = 0, passed_filters: int = 0):
|
838 |
+
done, _ = ray.wait(list(futures.keys()), num_returns=1)
|
839 |
+
completed_batch = ray.get(done[0])
|
840 |
+
|
841 |
+
# Write results of the completed batch to file
|
842 |
+
now_treated = 0
|
843 |
+
for index, is_filtered, reaction in completed_batch:
|
844 |
+
now_treated += 1
|
845 |
+
if not is_filtered:
|
846 |
+
result_file.write(reaction.meta['init_smiles'])
|
847 |
+
passed_filters += 1
|
848 |
+
|
849 |
+
# Remove completed future and update progress bar
|
850 |
+
del futures[done[0]]
|
851 |
+
pbar.update(now_treated)
|
852 |
+
treated += now_treated
|
853 |
+
|
854 |
+
return treated, passed_filters
|
855 |
+
|
856 |
+
|
857 |
+
def filter_reactions(
|
858 |
+
config: ReactionCheckConfig,
|
859 |
+
reaction_database_path: str,
|
860 |
+
result_reactions_file_name: str = "reaction_data_filtered.smi",
|
861 |
+
append_results: bool = False,
|
862 |
+
num_cpus: int = 1,
|
863 |
+
batch_size: int = 100,
|
864 |
+
) -> None:
|
865 |
+
"""
|
866 |
+
Processes a database of chemical reactions, applying checks based on the provided configuration,
|
867 |
+
and writes the results to specified files. All configurations are provided by the ReactionCheckConfig object.
|
868 |
+
|
869 |
+
:param config: ReactionCheckConfig object containing all configuration settings.
|
870 |
+
:param reaction_database_path: Path to the reaction database file.
|
871 |
+
:param result_reactions_file_name: Name for the file containing cleaned reactions.
|
872 |
+
:param append_results: Flag indicating whether to append results to existing files.
|
873 |
+
:param num_cpus: Number of CPUs to use for processing.
|
874 |
+
:param batch_size: Size of the batch for processing reactions.
|
875 |
+
:return: None. The function writes the processed reactions to specified RDF and pickle files.
|
876 |
+
Unique reactions are written if save_only_unique is True.
|
877 |
+
"""
|
878 |
+
|
879 |
+
checkers = config.create_checkers()
|
880 |
+
|
881 |
+
ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)
|
882 |
+
max_concurrent_batches = num_cpus # Limit the number of concurrent batches
|
883 |
+
|
884 |
+
with ReactionReader(reaction_database_path) as reactions, \
|
885 |
+
ReactionWriter(result_reactions_file_name, append_results) as result_file:
|
886 |
+
|
887 |
+
pbar = tqdm(reactions, leave=True) # TODO fix progress bars
|
888 |
+
|
889 |
+
futures = {}
|
890 |
+
batch = []
|
891 |
+
treated = filtered = 0
|
892 |
+
for index, reaction in enumerate(reactions):
|
893 |
+
reaction.meta["reaction_index"] = index
|
894 |
+
batch.append((index, reaction))
|
895 |
+
if len(batch) == batch_size:
|
896 |
+
future = process_batch.remote(batch, config, checkers)
|
897 |
+
futures[future] = None
|
898 |
+
batch = []
|
899 |
+
|
900 |
+
# Check and process completed tasks if we've reached the concurrency limit
|
901 |
+
while len(futures) >= max_concurrent_batches:
|
902 |
+
treated, filtered = process_completed_batches(futures, result_file, pbar, treated, filtered)
|
903 |
+
|
904 |
+
# Process the last batch if it's not empty
|
905 |
+
if batch:
|
906 |
+
future = process_batch.remote(batch, config, checkers)
|
907 |
+
futures[future] = None
|
908 |
+
|
909 |
+
# Process remaining batches
|
910 |
+
while futures:
|
911 |
+
treated, filtered = process_completed_batches(futures, result_file, pbar, treated, filtered)
|
912 |
+
|
913 |
+
pbar.close()
|
914 |
+
|
915 |
+
ray.shutdown()
|
916 |
+
print(f'Initial number of reactions: {treated}'),
|
917 |
+
print(f'Removed number of reactions: {treated - filtered}')
|
SynTool/chem/data/mapping.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from os.path import splitext
|
3 |
+
from typing import Union
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from chython import smiles, RDFRead, RDFWrite, ReactionContainer
|
7 |
+
from chython.exceptions import MappingError, IncorrectSmiles
|
8 |
+
|
9 |
+
from SynTool.utils import path_type
|
10 |
+
|
11 |
+
|
12 |
+
def remove_reagents_and_map(rea: ReactionContainer, keep_reagent: bool = False) -> Union[ReactionContainer, None]:
|
13 |
+
"""
|
14 |
+
Maps atoms of the reaction using chytorch.
|
15 |
+
|
16 |
+
:param rea: reaction to map
|
17 |
+
:type rea: ReactionContainer
|
18 |
+
:param keep_reagent: whenever to remove reagent or not
|
19 |
+
:type keep_reagent: bool
|
20 |
+
|
21 |
+
:return: ReactionContainer or None
|
22 |
+
"""
|
23 |
+
try:
|
24 |
+
rea.reset_mapping()
|
25 |
+
except MappingError:
|
26 |
+
rea.reset_mapping() # Successive reset_mapping works
|
27 |
+
if not keep_reagent:
|
28 |
+
try:
|
29 |
+
rea.remove_reagents()
|
30 |
+
except:
|
31 |
+
return None
|
32 |
+
return rea
|
33 |
+
|
34 |
+
|
35 |
+
def remove_reagents_and_map_from_file(input_file: path_type, output_file: path_type, keep_reagent: bool = False) -> None:
|
36 |
+
"""
|
37 |
+
Reads a file of reactions and maps atoms of the reactions using chytorch.
|
38 |
+
|
39 |
+
:param input_file: the path and name of the input file
|
40 |
+
:type input_file: path_type
|
41 |
+
:param output_file: the path and name of the output file
|
42 |
+
:type output_file: path_type
|
43 |
+
:param keep_reagent: whenever to remove reagent or not
|
44 |
+
:type keep_reagent: bool
|
45 |
+
|
46 |
+
:return: None
|
47 |
+
"""
|
48 |
+
input_file = str(Path(input_file).resolve(strict=True))
|
49 |
+
_, input_ext = splitext(input_file)
|
50 |
+
if input_ext == ".smi":
|
51 |
+
input_file = open(input_file, "r")
|
52 |
+
elif input_ext == ".rdf":
|
53 |
+
input_file = RDFRead(input_file, indexable=True)
|
54 |
+
else:
|
55 |
+
raise ValueError("File extension not recognized. File:", input_file,
|
56 |
+
"- Please use smi or rdf file")
|
57 |
+
enumerator = input_file if input_ext == ".rdf" else input_file.readlines()
|
58 |
+
|
59 |
+
_, out_ext = splitext(output_file)
|
60 |
+
if out_ext == ".smi":
|
61 |
+
output_file = open(output_file, "w")
|
62 |
+
elif out_ext == ".rdf":
|
63 |
+
output_file = RDFWrite(output_file)
|
64 |
+
else:
|
65 |
+
raise ValueError("File extension not recognized. File:", output_file,
|
66 |
+
"- Please use smi or rdf file")
|
67 |
+
|
68 |
+
mapping_errors = 0
|
69 |
+
parsing_errors = 0
|
70 |
+
for rea_raw in tqdm(enumerator):
|
71 |
+
try:
|
72 |
+
rea = smiles(rea_raw.strip('\n')) if input_ext == ".smi" else rea_raw
|
73 |
+
except IncorrectSmiles:
|
74 |
+
parsing_errors += 1
|
75 |
+
continue
|
76 |
+
try:
|
77 |
+
rea_mapped = remove_reagents_and_map(rea, keep_reagent)
|
78 |
+
except MappingError:
|
79 |
+
try:
|
80 |
+
rea_mapped = remove_reagents_and_map(smiles(str(rea)), keep_reagent)
|
81 |
+
except MappingError:
|
82 |
+
mapping_errors += 1
|
83 |
+
continue
|
84 |
+
if rea_mapped:
|
85 |
+
rea_output = format(rea, "m") + "\n" if out_ext == ".smi" else rea
|
86 |
+
output_file.write(rea_output)
|
87 |
+
else:
|
88 |
+
mapping_errors += 1
|
89 |
+
|
90 |
+
input_file.close()
|
91 |
+
output_file.close()
|
92 |
+
|
93 |
+
if parsing_errors:
|
94 |
+
print(parsing_errors, "reactions couldn't be parsed")
|
95 |
+
if mapping_errors:
|
96 |
+
print(mapping_errors, "reactions couldn't be mapped")
|
SynTool/chem/data/mapping.py.bk
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from os.path import splitext
|
3 |
+
from typing import Union
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from chython import smiles, RDFRead, RDFWrite, ReactionContainer
|
7 |
+
from chython.exceptions import MappingError
|
8 |
+
|
9 |
+
from Syntool.utils import path_type
|
10 |
+
|
11 |
+
|
12 |
+
def remove_reagents_and_map(rea: ReactionContainer) -> Union[ReactionContainer, None]:
|
13 |
+
"""
|
14 |
+
Maps atoms of the reaction using chytorch.
|
15 |
+
|
16 |
+
:param rea: reaction to map
|
17 |
+
:type rea: ReactionContainer
|
18 |
+
|
19 |
+
:return: ReactionContainer or None
|
20 |
+
"""
|
21 |
+
try:
|
22 |
+
rea.reset_mapping()
|
23 |
+
except MappingError:
|
24 |
+
rea.reset_mapping()
|
25 |
+
try:
|
26 |
+
rea.remove_reagents()
|
27 |
+
return rea
|
28 |
+
except:
|
29 |
+
# print("Error", str(rea))
|
30 |
+
return None
|
31 |
+
|
32 |
+
|
33 |
+
def remove_reagents_and_map_from_file(input_file: path_type, output_file: path_type) -> None:
|
34 |
+
"""
|
35 |
+
Reads a file of reactions and maps atoms of the reactions using chytorch.
|
36 |
+
|
37 |
+
:param input_file: the path and name of the input file
|
38 |
+
:type input_file: path_type
|
39 |
+
|
40 |
+
:param output_file: the path and name of the output file
|
41 |
+
:type output_file: path_type
|
42 |
+
|
43 |
+
:return: None
|
44 |
+
"""
|
45 |
+
input_file = str(Path(input_file).resolve(strict=True))
|
46 |
+
_, input_ext = splitext(input_file)
|
47 |
+
if input_ext == ".smi":
|
48 |
+
input_file = open(input_file, "r")
|
49 |
+
elif input_ext == ".rdf":
|
50 |
+
input_file = RDFRead(input_file, indexable=True)
|
51 |
+
else:
|
52 |
+
raise ValueError("File extension not recognized. File:", input_file,
|
53 |
+
"- Please use smi or rdf file")
|
54 |
+
enumerator = input_file if input_ext == ".rdf" else input_file.readlines()
|
55 |
+
|
56 |
+
_, out_ext = splitext(output_file)
|
57 |
+
if out_ext == ".smi":
|
58 |
+
output_file = open(output_file, "w")
|
59 |
+
elif out_ext == ".rdf":
|
60 |
+
output_file = RDFWrite(output_file)
|
61 |
+
else:
|
62 |
+
raise ValueError("File extension not recognized. File:", output_file,
|
63 |
+
"- Please use smi or rdf file")
|
64 |
+
|
65 |
+
mapping_errors = 0
|
66 |
+
parsing_errors = 0
|
67 |
+
for rea_raw in tqdm(enumerator):
|
68 |
+
try:
|
69 |
+
rea = smiles(rea_raw.strip('\n')) if input_ext == ".smi" else rea_raw
|
70 |
+
except:
|
71 |
+
parsing_errors += 1
|
72 |
+
print("Error", parsing_errors, rea_raw)
|
73 |
+
continue
|
74 |
+
try:
|
75 |
+
rea_mapped = remove_reagents_and_map(rea)
|
76 |
+
except:
|
77 |
+
parsing_errors += 1
|
78 |
+
print("Error for,", rea)
|
79 |
+
continue
|
80 |
+
if rea_mapped:
|
81 |
+
rea_output = format(rea, "m") + "\n" if out_ext == ".smi" else rea
|
82 |
+
output_file.write(rea_output)
|
83 |
+
else:
|
84 |
+
mapping_errors += 1
|
85 |
+
|
86 |
+
input_file.close()
|
87 |
+
output_file.close()
|
88 |
+
|
89 |
+
if mapping_errors:
|
90 |
+
print(mapping_errors, "reactions couldn't be mapped")
|
SynTool/chem/data/standardizer.py
ADDED
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#############################################################################
|
2 |
+
# Code issued from https://github.com/Laboratoire-de-Chemoinformatique/Reaction_Data_Cleaning
|
3 |
+
# Reaction_Data_Cleaning/scripts/standardizer.py
|
4 |
+
# version as it from commit 793475e54d8b2c7f714165a61e4eb439435d7d92
|
5 |
+
# DOI 10.1002/minf.202100119
|
6 |
+
#############################################################################
|
7 |
+
# Chemical reactions data curation best practices
|
8 |
+
# including optimized RDTool
|
9 |
+
#############################################################################
|
10 |
+
# GNU LGPL https://www.gnu.org/licenses/lgpl-3.0.en.html
|
11 |
+
#############################################################################
|
12 |
+
# Corresponding Authors: Timur Madzhidov and Alexandre Varnek
|
13 |
+
# Corresponding Authors' emails: [email protected] and [email protected]
|
14 |
+
# Main contributors: Arkadii Lin, Natalia Duybankova, Ramil Nugmanov, Rail Suleymanov and Timur Madzhidov
|
15 |
+
# Copyright: Copyright 2020,
|
16 |
+
# MaDeSmart, Machine Design of Small Molecules by AI
|
17 |
+
# VLAIO project HBC.2018.2287
|
18 |
+
# Credits: Kazan Federal University, Russia
|
19 |
+
# University of Strasbourg, France
|
20 |
+
# University of Linz, Austria
|
21 |
+
# University of Leuven, Belgium
|
22 |
+
# Janssen Pharmaceutica N.V., Beerse, Belgium
|
23 |
+
# Rail Suleymanov, Arcadia, St. Petersburg, Russia
|
24 |
+
# License: GNU LGPL https://www.gnu.org/licenses/lgpl-3.0.en.html
|
25 |
+
# Version: 00.02
|
26 |
+
#############################################################################
|
27 |
+
|
28 |
+
from CGRtools.files import RDFRead, RDFWrite, SDFWrite, SDFRead, SMILESRead
|
29 |
+
from CGRtools.containers import MoleculeContainer, ReactionContainer
|
30 |
+
import logging
|
31 |
+
from ordered_set import OrderedSet
|
32 |
+
import os
|
33 |
+
import io
|
34 |
+
import pathlib
|
35 |
+
from pathlib import PurePosixPath
|
36 |
+
|
37 |
+
|
38 |
+
class Standardizer:
|
39 |
+
def __init__(self, skip_errors=False, log_file=None, keep_unbalanced_ions=False, id_tag='Reaction_ID',
|
40 |
+
action_on_isotopes=0, keep_reagents=False, logger=None, ignore_mapping=False, jvm_path=None,
|
41 |
+
rdkit_dearomatization=False, remove_unchanged_parts=True, skip_tautomerize=True,
|
42 |
+
jchem_path=None, add_reagents_to_reactants=False) -> None:
|
43 |
+
if logger is None:
|
44 |
+
self.logger = self._config_log(log_file, logger_name='logger')
|
45 |
+
else:
|
46 |
+
self.logger = logger
|
47 |
+
self._skip_errors = skip_errors
|
48 |
+
self._keep_unbalanced_ions = keep_unbalanced_ions
|
49 |
+
self._id_tag = id_tag
|
50 |
+
self._action_on_isotopes = action_on_isotopes
|
51 |
+
self._keep_reagents = keep_reagents
|
52 |
+
self._ignore_mapping = ignore_mapping
|
53 |
+
self._remove_unchanged_parts_flag = remove_unchanged_parts
|
54 |
+
self._skip_tautomerize = skip_tautomerize
|
55 |
+
self._dearomatize_by_rdkit = rdkit_dearomatization
|
56 |
+
self._reagents_to_reactants = add_reagents_to_reactants
|
57 |
+
if not skip_tautomerize:
|
58 |
+
if jvm_path:
|
59 |
+
os.environ['JDK_HOME'] = jvm_path
|
60 |
+
os.environ['JAVA_HOME'] = jvm_path
|
61 |
+
os.environ['PATH'] += f';{PurePosixPath(jvm_path).joinpath("bin").joinpath("server")};' \
|
62 |
+
f'{PurePosixPath(jvm_path).joinpath("bin").joinpath("server")};'
|
63 |
+
if jchem_path:
|
64 |
+
import jnius_config
|
65 |
+
jnius_config.add_classpath(jchem_path)
|
66 |
+
from jnius import autoclass
|
67 |
+
Standardizer = autoclass('chemaxon.standardizer.Standardizer')
|
68 |
+
self._Molecule = autoclass('chemaxon.struc.Molecule')
|
69 |
+
self._MolHandler = autoclass('chemaxon.util.MolHandler')
|
70 |
+
self._standardizer = Standardizer('tautomerize')
|
71 |
+
|
72 |
+
def standardize_file(self, input_file=None) -> OrderedSet:
|
73 |
+
"""
|
74 |
+
Standardize a set of reactions in a file. Returns an ordered set of ReactionContainer objects passed the
|
75 |
+
standardization protocol.
|
76 |
+
:param input_file: str
|
77 |
+
:return: OrderedSet
|
78 |
+
"""
|
79 |
+
if pathlib.Path(input_file).suffix == '.rdf':
|
80 |
+
data = self._read_RDF(input_file)
|
81 |
+
elif pathlib.Path(input_file).suffix == '.smi' or pathlib.Path(input_file).suffix == '.smiles':
|
82 |
+
data = self._read_SMILES(input_file)
|
83 |
+
else:
|
84 |
+
raise ValueError('Data format is not recognized!')
|
85 |
+
|
86 |
+
print("{0} reactions passed..".format(len(data)))
|
87 |
+
return data
|
88 |
+
|
89 |
+
def _read_RDF(self, input_file) -> OrderedSet:
|
90 |
+
"""
|
91 |
+
Reads an RDF file. Returns an ordered set of ReactionContainer objects passed the standardization protocol.
|
92 |
+
:param input_file: str
|
93 |
+
:return: OrderedSet
|
94 |
+
"""
|
95 |
+
data = OrderedSet()
|
96 |
+
self.logger.info('Start..')
|
97 |
+
with RDFRead(input_file, ignore=self._ignore_mapping, store_log=True, remap=self._ignore_mapping) as ifile, \
|
98 |
+
open(input_file) as meta_searcher:
|
99 |
+
for reaction in ifile._data:
|
100 |
+
if isinstance(reaction, tuple):
|
101 |
+
meta_searcher.seek(reaction.position)
|
102 |
+
flag = False
|
103 |
+
for line in meta_searcher:
|
104 |
+
if flag and '$RFMT' in line:
|
105 |
+
self.logger.critical(f'Reaction id extraction problem rised for the reaction '
|
106 |
+
f'#{reaction.number + 1}: a reaction id was expected but $RFMT line '
|
107 |
+
f'was found!')
|
108 |
+
if flag:
|
109 |
+
self.logger.critical(f'Reaction {line.strip().split()[1]}: Parser has returned an error '
|
110 |
+
f'message\n{reaction.log}')
|
111 |
+
break
|
112 |
+
elif '$RFMT' in line:
|
113 |
+
self.logger.critical(f'Reaction #{reaction.number + 1} has no reaction id!')
|
114 |
+
elif f'$DTYPE {self._id_tag}' in line:
|
115 |
+
flag = True
|
116 |
+
continue
|
117 |
+
standardized_reaction = self.standardize(reaction)
|
118 |
+
if standardized_reaction:
|
119 |
+
if standardized_reaction not in data:
|
120 |
+
data.add(standardized_reaction)
|
121 |
+
else:
|
122 |
+
i = data.index(standardized_reaction)
|
123 |
+
if 'Extraction_IDs' not in data[i].meta:
|
124 |
+
data[i].meta['Extraction_IDs'] = ''
|
125 |
+
data[i].meta['Extraction_IDs'] = ','.join(data[i].meta['Extraction_IDs'].split(',') +
|
126 |
+
[reaction.meta[self._id_tag]])
|
127 |
+
self.logger.info('Reaction {0} is a duplicate of the reaction {1}..'
|
128 |
+
.format(reaction.meta[self._id_tag], data[i].meta[self._id_tag]))
|
129 |
+
return data
|
130 |
+
|
131 |
+
def _read_SMILES(self, input_file) -> OrderedSet:
|
132 |
+
"""
|
133 |
+
Reads a SMILES file. Returns an ordered set of ReactionContainer objects passed the standardization protocol.
|
134 |
+
:param input_file: str
|
135 |
+
:return: OrderedSet
|
136 |
+
"""
|
137 |
+
data = OrderedSet()
|
138 |
+
self.logger.info('Start..')
|
139 |
+
with SMILESRead(input_file, ignore=True, store_log=True, remap=self._ignore_mapping, header=True) as ifile, \
|
140 |
+
open(input_file) as meta_searcher:
|
141 |
+
id_tag_position = meta_searcher.readline().strip().split().index(self._id_tag)
|
142 |
+
if id_tag_position is None or id_tag_position == 0:
|
143 |
+
self.logger.critical(f'No reaction ID tag was found in the header!')
|
144 |
+
raise ValueError(f'No reaction ID tag was found in the header!')
|
145 |
+
for reaction in ifile._data:
|
146 |
+
if isinstance(reaction, tuple):
|
147 |
+
meta_searcher.seek(reaction.position)
|
148 |
+
line = meta_searcher.readline().strip().split()
|
149 |
+
if len(line) <= id_tag_position:
|
150 |
+
self.logger.critical(f'No reaction ID tag was found in line {reaction.number}!')
|
151 |
+
raise ValueError(f'No reaction ID tag was found in line {reaction.number}!')
|
152 |
+
r_id = line[id_tag_position]
|
153 |
+
self.logger.critical(f'Reaction {r_id}: Parser has returned an error message\n{reaction.log}')
|
154 |
+
continue
|
155 |
+
|
156 |
+
standardized_reaction = self.standardize(reaction)
|
157 |
+
if standardized_reaction:
|
158 |
+
if standardized_reaction not in data:
|
159 |
+
data.add(standardized_reaction)
|
160 |
+
else:
|
161 |
+
i = data.index(standardized_reaction)
|
162 |
+
if 'Extraction_IDs' not in data[i].meta:
|
163 |
+
data[i].meta['Extraction_IDs'] = ''
|
164 |
+
data[i].meta['Extraction_IDs'] = ','.join(data[i].meta['Extraction_IDs'].split(',') +
|
165 |
+
[reaction.meta[self._id_tag]])
|
166 |
+
self.logger.info('Reaction {0} is a duplicate of the reaction {1}..'
|
167 |
+
.format(reaction.meta[self._id_tag], data[i].meta[self._id_tag]))
|
168 |
+
return data
|
169 |
+
|
170 |
+
def standardize(self, reaction: ReactionContainer) -> ReactionContainer:
|
171 |
+
"""
|
172 |
+
Standardization protocol: transform functional groups, kekulize, remove explicit hydrogens,
|
173 |
+
check for radicals (remove if something was found), check for isotopes, regroup ions (if the total charge
|
174 |
+
of reactants and/or products is not zero, and the 'keep_unbalanced_ions' option is False which is by default,
|
175 |
+
such reactions are removed; if the 'keep_unbalanced_ions' option is set True, they are kept), check valences
|
176 |
+
(remove if something is wrong), aromatize (thiele method), fix mapping (for symmetric functional groups) if
|
177 |
+
such is in, remove unchanged parts.
|
178 |
+
:param reaction: ReactionContainer
|
179 |
+
:return: ReactionContainer
|
180 |
+
"""
|
181 |
+
self.logger.info('Reaction {0}..'.format(reaction.meta[self._id_tag]))
|
182 |
+
try:
|
183 |
+
reaction.standardize()
|
184 |
+
except:
|
185 |
+
self.logger.exception(
|
186 |
+
'Reaction {0}: Cannot standardize functional groups..'.format(reaction.meta[self._id_tag]))
|
187 |
+
if not self._skip_errors:
|
188 |
+
raise Exception(
|
189 |
+
'Reaction {0}: Cannot standardize functional groups..'.format(reaction.meta[self._id_tag]))
|
190 |
+
else:
|
191 |
+
return
|
192 |
+
try:
|
193 |
+
reaction.kekule()
|
194 |
+
except:
|
195 |
+
self.logger.exception('Reaction {0}: Cannot kekulize..'.format(reaction.meta[self._id_tag]))
|
196 |
+
if not self._skip_errors:
|
197 |
+
raise Exception('Reaction {0}: Cannot kekulize..'.format(reaction.meta[self._id_tag]))
|
198 |
+
else:
|
199 |
+
return
|
200 |
+
try:
|
201 |
+
if self._check_valence(reaction):
|
202 |
+
self.logger.info(
|
203 |
+
'Reaction {0}: Bad valence: {1}'.format(reaction.meta[self._id_tag], reaction.meta['mistake']))
|
204 |
+
return
|
205 |
+
except:
|
206 |
+
self.logger.exception('Reaction {0}: Cannot check valence..'.format(reaction.meta[self._id_tag]))
|
207 |
+
if not self._skip_errors:
|
208 |
+
self.logger.critical('Stop the algorithm!')
|
209 |
+
raise Exception('Reaction {0}: Cannot check valence..'.format(reaction.meta[self._id_tag]))
|
210 |
+
else:
|
211 |
+
return
|
212 |
+
try:
|
213 |
+
if not self._skip_tautomerize:
|
214 |
+
reaction = self._tautomerize(reaction)
|
215 |
+
except:
|
216 |
+
self.logger.exception('Reaction {0}: Cannot tautomerize..'.format(reaction.meta[self._id_tag]))
|
217 |
+
if not self._skip_errors:
|
218 |
+
raise Exception('Reaction {0}: Cannot tautomerize..'.format(reaction.meta[self._id_tag]))
|
219 |
+
else:
|
220 |
+
return
|
221 |
+
try:
|
222 |
+
reaction.implicify_hydrogens()
|
223 |
+
except:
|
224 |
+
self.logger.exception(
|
225 |
+
'Reaction {0}: Cannot remove explicit hydrogens..'.format(reaction.meta[self._id_tag]))
|
226 |
+
if not self._skip_errors:
|
227 |
+
raise Exception('Reaction {0}: Cannot remove explicit hydrogens..'.format(reaction.meta[self._id_tag]))
|
228 |
+
else:
|
229 |
+
return
|
230 |
+
try:
|
231 |
+
if self._check_radicals(reaction):
|
232 |
+
self.logger.info('Reaction {0}: Radicals were found..'.format(reaction.meta[self._id_tag]))
|
233 |
+
return
|
234 |
+
except:
|
235 |
+
self.logger.exception('Reaction {0}: Cannot check radicals..'.format(reaction.meta[self._id_tag]))
|
236 |
+
if not self._skip_errors:
|
237 |
+
raise Exception('Reaction {0}: Cannot check radicals..'.format(reaction.meta[self._id_tag]))
|
238 |
+
else:
|
239 |
+
return
|
240 |
+
try:
|
241 |
+
if self._action_on_isotopes == 1 and self._check_isotopes(reaction):
|
242 |
+
self.logger.info('Reaction {0}: Isotopes were found..'.format(reaction.meta[self._id_tag]))
|
243 |
+
return
|
244 |
+
elif self._action_on_isotopes == 2 and self._check_isotopes(reaction):
|
245 |
+
reaction.clean_isotopes()
|
246 |
+
self.logger.info('Reaction {0}: Isotopes were removed but the reaction was kept..'.format(
|
247 |
+
reaction.meta[self._id_tag]))
|
248 |
+
except:
|
249 |
+
self.logger.exception('Reaction {0}: Cannot check for isotopes..'.format(reaction.meta[self._id_tag]))
|
250 |
+
if not self._skip_errors:
|
251 |
+
raise Exception('Reaction {0}: Cannot check for isotopes..'.format(reaction.meta[self._id_tag]))
|
252 |
+
else:
|
253 |
+
return
|
254 |
+
try:
|
255 |
+
reaction, return_code = self._split_ions(reaction)
|
256 |
+
if return_code == 1:
|
257 |
+
self.logger.info('Reaction {0}: Ions were split..'.format(reaction.meta[self._id_tag]))
|
258 |
+
elif return_code == 2:
|
259 |
+
self.logger.info('Reaction {0}: Ions were split but the reaction is imbalanced..'.format(
|
260 |
+
reaction.meta[self._id_tag]))
|
261 |
+
if not self._keep_unbalanced_ions:
|
262 |
+
return
|
263 |
+
except:
|
264 |
+
self.logger.exception('Reaction {0}: Cannot group ions..'.format(reaction.meta[self._id_tag]))
|
265 |
+
if not self._skip_errors:
|
266 |
+
raise Exception('Reaction {0}: Cannot group ions..'.format(reaction.meta[self._id_tag]))
|
267 |
+
else:
|
268 |
+
return
|
269 |
+
try:
|
270 |
+
reaction.thiele()
|
271 |
+
except:
|
272 |
+
self.logger.exception('Reaction {0}: Cannot aromatize..'.format(reaction.meta[self._id_tag]))
|
273 |
+
if not self._skip_errors:
|
274 |
+
raise Exception('Reaction {0}: Cannot aromatize..'.format(reaction.meta[self._id_tag]))
|
275 |
+
else:
|
276 |
+
return
|
277 |
+
try:
|
278 |
+
reaction.fix_mapping()
|
279 |
+
except:
|
280 |
+
self.logger.exception('Reaction {0}: Cannot fix mapping..'.format(reaction.meta[self._id_tag]))
|
281 |
+
if not self._skip_errors:
|
282 |
+
raise Exception('Reaction {0}: Cannot fix mapping..'.format(reaction.meta[self._id_tag]))
|
283 |
+
else:
|
284 |
+
return
|
285 |
+
try:
|
286 |
+
if self._remove_unchanged_parts_flag:
|
287 |
+
reaction = self._remove_unchanged_parts(reaction)
|
288 |
+
if not reaction.reactants and reaction.products:
|
289 |
+
self.logger.info('Reaction {0}: Reactants are empty..'.format(reaction.meta[self._id_tag]))
|
290 |
+
return
|
291 |
+
if not reaction.products and reaction.reactants:
|
292 |
+
self.logger.info('Reaction {0}: Products are empty..'.format(reaction.meta[self._id_tag]))
|
293 |
+
return
|
294 |
+
if not reaction.reactants and not reaction.products:
|
295 |
+
self.logger.exception(
|
296 |
+
'Reaction {0}: Cannot remove unchanged parts or the reaction is empty..'.format(
|
297 |
+
reaction.meta[self._id_tag]))
|
298 |
+
return
|
299 |
+
except:
|
300 |
+
self.logger.exception('Reaction {0}: Cannot remove unchanged parts or the reaction is empty..'.format(
|
301 |
+
reaction.meta[self._id_tag]))
|
302 |
+
if not self._skip_errors:
|
303 |
+
raise Exception('Reaction {0}: Cannot remove unchanged parts or the reaction is empty..'.format(
|
304 |
+
reaction.meta[self._id_tag]))
|
305 |
+
else:
|
306 |
+
return
|
307 |
+
self.logger.debug('Reaction {0} is done..'.format(reaction.meta[self._id_tag]))
|
308 |
+
return reaction
|
309 |
+
|
310 |
+
def write(self, output_file: str, data: OrderedSet) -> None:
|
311 |
+
"""
|
312 |
+
Dump a set of reactions.
|
313 |
+
:param data: OrderedSet
|
314 |
+
:param output_file: str
|
315 |
+
:return: None
|
316 |
+
"""
|
317 |
+
with RDFWrite(output_file) as out:
|
318 |
+
for r in data:
|
319 |
+
out.write(r)
|
320 |
+
|
321 |
+
def _check_valence(self, reaction: ReactionContainer) -> bool:
|
322 |
+
"""
|
323 |
+
Checks valences.
|
324 |
+
:param reaction: ReactionContainer
|
325 |
+
:return: bool
|
326 |
+
"""
|
327 |
+
mistakes = []
|
328 |
+
for molecule in (reaction.reactants + reaction.products + reaction.reagents):
|
329 |
+
valence_mistakes = molecule.check_valence()
|
330 |
+
if valence_mistakes:
|
331 |
+
mistakes.append(("|".join([str(num) for num in valence_mistakes]),
|
332 |
+
"|".join([str(molecule.atom(n)) for n in valence_mistakes]), str(molecule)))
|
333 |
+
if mistakes:
|
334 |
+
message = ",".join([f'{atom_nums} at {atoms} in {smiles}' for atom_nums, atoms, smiles in mistakes])
|
335 |
+
reaction.meta['mistake'] = f'Valence mistake: {message}'
|
336 |
+
return True
|
337 |
+
return False
|
338 |
+
|
339 |
+
def _config_log(self, log_file: str, logger_name: str):
|
340 |
+
logger = logging.getLogger(logger_name)
|
341 |
+
logger.setLevel(logging.DEBUG)
|
342 |
+
formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
|
343 |
+
logger.handlers.clear()
|
344 |
+
fileHandler = logging.FileHandler(filename=log_file, mode='w')
|
345 |
+
fileHandler.setFormatter(formatter)
|
346 |
+
fileHandler.setLevel(logging.DEBUG)
|
347 |
+
logger.addHandler(fileHandler)
|
348 |
+
# logging.basicConfig(filename=log_file, level=logging.info, filemode='w', format='%(asctime)s: %(message)s',
|
349 |
+
# datefmt='%d/%m/%Y %H:%M:%S')
|
350 |
+
return logger
|
351 |
+
|
352 |
+
def _check_radicals(self, reaction: ReactionContainer) -> bool:
|
353 |
+
"""
|
354 |
+
Checks radicals.
|
355 |
+
:param reaction: ReactionContainer
|
356 |
+
:return: bool
|
357 |
+
"""
|
358 |
+
for molecule in (reaction.reactants + reaction.products + reaction.reagents):
|
359 |
+
for n, atom in molecule.atoms():
|
360 |
+
if atom.is_radical:
|
361 |
+
return True
|
362 |
+
return False
|
363 |
+
|
364 |
+
def _calc_charge(self, molecule: MoleculeContainer) -> int:
|
365 |
+
"""Computing charge of molecule.
|
366 |
+
:param: molecule: MoleculeContainer
|
367 |
+
:return: int
|
368 |
+
"""
|
369 |
+
return sum(molecule._charges.values())
|
370 |
+
|
371 |
+
def _group_ions(self, reaction: ReactionContainer):
|
372 |
+
"""
|
373 |
+
Ungroup molecules recorded as ions, regroup ions. Returns a tuple with the corresponding ReactionContainer and
|
374 |
+
return code as int (0 - nothing was changed, 1 - ions were regrouped, 2 - ions are unbalanced).
|
375 |
+
:param reaction: current reaction
|
376 |
+
:return: tuple[ReactionContainer, int]
|
377 |
+
"""
|
378 |
+
meta = reaction.meta
|
379 |
+
reaction_parts = []
|
380 |
+
return_codes = []
|
381 |
+
for molecules in (reaction.reactants, reaction.reagents, reaction.products):
|
382 |
+
divided_molecules = [x for m in molecules for x in m.split('.')]
|
383 |
+
|
384 |
+
if len(divided_molecules) == 0:
|
385 |
+
reaction_parts.append(())
|
386 |
+
continue
|
387 |
+
elif len(divided_molecules) == 1 and self._calc_charge(divided_molecules[0]) == 0:
|
388 |
+
return_codes.append(0)
|
389 |
+
reaction_parts.append(molecules)
|
390 |
+
continue
|
391 |
+
elif len(divided_molecules) == 1:
|
392 |
+
return_codes.append(2)
|
393 |
+
reaction_parts.append(molecules)
|
394 |
+
continue
|
395 |
+
|
396 |
+
new_molecules = []
|
397 |
+
cations, anions, ions = [], [], []
|
398 |
+
total_charge = 0
|
399 |
+
for molecule in divided_molecules:
|
400 |
+
mol_charge = self._calc_charge(molecule)
|
401 |
+
total_charge += mol_charge
|
402 |
+
if mol_charge == 0:
|
403 |
+
new_molecules.append(molecule)
|
404 |
+
elif mol_charge > 0:
|
405 |
+
cations.append((mol_charge, molecule))
|
406 |
+
ions.append((mol_charge, molecule))
|
407 |
+
else:
|
408 |
+
anions.append((mol_charge, molecule))
|
409 |
+
ions.append((mol_charge, molecule))
|
410 |
+
|
411 |
+
if len(cations) == 0 and len(anions) == 0:
|
412 |
+
return_codes.append(0)
|
413 |
+
reaction_parts.append(tuple(new_molecules))
|
414 |
+
continue
|
415 |
+
elif total_charge != 0:
|
416 |
+
return_codes.append(2)
|
417 |
+
reaction_parts.append(tuple(divided_molecules))
|
418 |
+
continue
|
419 |
+
else:
|
420 |
+
salt = MoleculeContainer()
|
421 |
+
for ion_charge, ion in ions:
|
422 |
+
salt = salt.union(ion)
|
423 |
+
total_charge += ion_charge
|
424 |
+
if total_charge == 0:
|
425 |
+
new_molecules.append(salt)
|
426 |
+
salt = MoleculeContainer()
|
427 |
+
if total_charge != 0:
|
428 |
+
new_molecules.append(salt)
|
429 |
+
return_codes.append(2)
|
430 |
+
reaction_parts.append(tuple(new_molecules))
|
431 |
+
else:
|
432 |
+
return_codes.append(1)
|
433 |
+
reaction_parts.append(tuple(new_molecules))
|
434 |
+
return ReactionContainer(reactants=reaction_parts[0], reagents=reaction_parts[1], products=reaction_parts[2],
|
435 |
+
meta=meta), max(return_codes)
|
436 |
+
|
437 |
+
def _split_ions(self, reaction: ReactionContainer):
|
438 |
+
"""
|
439 |
+
Split ions in a reaction. Returns a tuple with the corresponding ReactionContainer and
|
440 |
+
a return code as int (0 - nothing was changed, 1 - ions were split, 2 - ions were split but the reaction
|
441 |
+
is imbalanced).
|
442 |
+
:param reaction: current reaction
|
443 |
+
:return: tuple[ReactionContainer, int]
|
444 |
+
"""
|
445 |
+
meta = reaction.meta
|
446 |
+
reaction_parts = []
|
447 |
+
return_codes = []
|
448 |
+
for molecules in (reaction.reactants, reaction.reagents, reaction.products):
|
449 |
+
divided_molecules = [x for m in molecules for x in m.split('.')]
|
450 |
+
|
451 |
+
total_charge = 0
|
452 |
+
ions_present = False
|
453 |
+
for molecule in divided_molecules:
|
454 |
+
mol_charge = self._calc_charge(molecule)
|
455 |
+
total_charge += mol_charge
|
456 |
+
if mol_charge != 0:
|
457 |
+
ions_present = True
|
458 |
+
|
459 |
+
if ions_present and total_charge:
|
460 |
+
return_codes.append(2)
|
461 |
+
elif ions_present:
|
462 |
+
return_codes.append(1)
|
463 |
+
else:
|
464 |
+
return_codes.append(0)
|
465 |
+
|
466 |
+
reaction_parts.append(tuple(divided_molecules))
|
467 |
+
|
468 |
+
return ReactionContainer(reactants=reaction_parts[0], reagents=reaction_parts[1], products=reaction_parts[2],
|
469 |
+
meta=meta), max(return_codes)
|
470 |
+
|
471 |
+
def _remove_unchanged_parts(self, reaction: ReactionContainer) -> ReactionContainer:
|
472 |
+
"""
|
473 |
+
Ungroup molecules, remove unchanged parts from reactants and products.
|
474 |
+
:param reaction: current reaction
|
475 |
+
:return: ReactionContainer
|
476 |
+
"""
|
477 |
+
meta = reaction.meta
|
478 |
+
new_reactants = [m for m in reaction.reactants]
|
479 |
+
new_reagents = [m for m in reaction.reagents]
|
480 |
+
if self._reagents_to_reactants:
|
481 |
+
new_reactants.extend(new_reagents)
|
482 |
+
new_reagents = []
|
483 |
+
reactants = new_reactants.copy()
|
484 |
+
new_products = [m for m in reaction.products]
|
485 |
+
|
486 |
+
for reactant in reactants:
|
487 |
+
if reactant in new_products:
|
488 |
+
new_reagents.append(reactant)
|
489 |
+
new_reactants.remove(reactant)
|
490 |
+
new_products.remove(reactant)
|
491 |
+
if not self._keep_reagents:
|
492 |
+
new_reagents = []
|
493 |
+
return ReactionContainer(reactants=tuple(new_reactants), reagents=tuple(new_reagents),
|
494 |
+
products=tuple(new_products), meta=meta)
|
495 |
+
|
496 |
+
def _check_isotopes(self, reaction: ReactionContainer) -> bool:
|
497 |
+
for molecules in (reaction.reactants, reaction.products):
|
498 |
+
for molecule in molecules:
|
499 |
+
for _, atom in molecule.atoms():
|
500 |
+
if atom.isotope:
|
501 |
+
return True
|
502 |
+
return False
|
503 |
+
|
504 |
+
def _tautomerize(self, reaction: ReactionContainer) -> ReactionContainer:
|
505 |
+
"""
|
506 |
+
Perform ChemAxon tautomerization.
|
507 |
+
:param reaction: reaction that needs to be tautomerized
|
508 |
+
:return: ReactionContainer
|
509 |
+
"""
|
510 |
+
new_molecules = []
|
511 |
+
for part in [reaction.reactants, reaction.reagents, reaction.products]:
|
512 |
+
tmp = []
|
513 |
+
for mol in part:
|
514 |
+
with io.StringIO() as f, SDFWrite(f) as i:
|
515 |
+
i.write(mol)
|
516 |
+
sdf = f.getvalue()
|
517 |
+
mol_handler = self._MolHandler(sdf)
|
518 |
+
mol_handler.clean(True, '2')
|
519 |
+
molecule = mol_handler.getMolecule()
|
520 |
+
self._standardizer.standardize(molecule)
|
521 |
+
new_mol_handler = self._MolHandler(molecule)
|
522 |
+
new_sdf = new_mol_handler.toFormat('SDF')
|
523 |
+
with io.StringIO('\n ' + new_sdf.strip()) as f, SDFRead(f, remap=False) as i:
|
524 |
+
new_mol = next(i)
|
525 |
+
tmp.append(new_mol)
|
526 |
+
new_molecules.append(tmp)
|
527 |
+
return ReactionContainer(reactants=tuple(new_molecules[0]), reagents=tuple(new_molecules[1]),
|
528 |
+
products=tuple(new_molecules[2]), meta=reaction.meta)
|
529 |
+
|
530 |
+
# def _dearomatize_by_RDKit(self, reaction: ReactionContainer) -> ReactionContainer:
|
531 |
+
# """
|
532 |
+
# Dearomatizes by RDKit (needs in case of some mappers, such as RXNMapper).
|
533 |
+
# :param reaction: ReactionContainer
|
534 |
+
# :return: ReactionContainer
|
535 |
+
# """
|
536 |
+
# with io.StringIO() as f, RDFWrite(f) as i:
|
537 |
+
# i.write(reaction)
|
538 |
+
# s = '\n'.join(f.getvalue().split('\n')[3:])
|
539 |
+
# rxn = rdChemReactions.ReactionFromRxnBlock(s)
|
540 |
+
# reactants, reagents, products = [], [], []
|
541 |
+
# for mol in rxn.GetReactants():
|
542 |
+
# try:
|
543 |
+
# Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_KEKULIZE, catchErrors=True)
|
544 |
+
# except Chem.rdchem.KekulizeException:
|
545 |
+
# return reaction
|
546 |
+
# with io.StringIO(Chem.MolToMolBlock(mol)) as f2, SDFRead(f2, remap=False) as sdf_i:
|
547 |
+
# reactants.append(next(sdf_i))
|
548 |
+
# for mol in rxn.GetAgents():
|
549 |
+
# try:
|
550 |
+
# Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_KEKULIZE, catchErrors=True)
|
551 |
+
# except Chem.rdchem.KekulizeException:
|
552 |
+
# return reaction
|
553 |
+
# with io.StringIO(Chem.MolToMolBlock(mol)) as f2, SDFRead(f2, remap=False) as sdf_i:
|
554 |
+
# reagents.append(next(sdf_i))
|
555 |
+
# for mol in rxn.GetProducts():
|
556 |
+
# try:
|
557 |
+
# Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_KEKULIZE, catchErrors=True)
|
558 |
+
# except Chem.rdchem.KekulizeException:
|
559 |
+
# return reaction
|
560 |
+
# with io.StringIO(Chem.MolToMolBlock(mol)) as f2, SDFRead(f2, remap=False) as sdf_i:
|
561 |
+
# products.append(next(sdf_i))
|
562 |
+
#
|
563 |
+
# new_reaction = ReactionContainer(reactants=tuple(reactants), reagents=tuple(reagents), products=tuple(products),
|
564 |
+
# meta=reaction.meta)
|
565 |
+
#
|
566 |
+
# return new_reaction
|
567 |
+
|
568 |
+
|
569 |
+
if __name__ == '__main__':
|
570 |
+
import argparse
|
571 |
+
|
572 |
+
parser = argparse.ArgumentParser(description="This is a tool for reaction standardization.",
|
573 |
+
epilog="Arkadii Lin, Strasbourg/Kazan 2020", prog="Standardizer")
|
574 |
+
parser.add_argument("-i", "--input", type=str, help="Input RDF file.")
|
575 |
+
parser.add_argument("-o", "--output", type=str, help="Output RDF file.")
|
576 |
+
parser.add_argument("-id", "--idTag", default='Reaction_ID', type=str, help="ID tag in the RDF file.")
|
577 |
+
parser.add_argument("--skipErrors", action="store_true", help="Skip errors.")
|
578 |
+
parser.add_argument("--keep_unbalanced_ions", action="store_true", help="Will keep reactions with unbalanced ions.")
|
579 |
+
parser.add_argument("--action_on_isotopes", type=int, default=0, help="Action performed if an isotope is "
|
580 |
+
"found: 0 - to ignore isotopes; "
|
581 |
+
"1 - to remove reactions with isotopes; "
|
582 |
+
"2 - to clear isotopes' labels.")
|
583 |
+
parser.add_argument("--keep_reagents", action="store_true", help="Will keep reagents from the reaction.")
|
584 |
+
parser.add_argument("--add_reagents", action="store_true", help="Will add the given reagents to reactants.")
|
585 |
+
parser.add_argument("--ignore_mapping", action="store_true", help="Will ignore the initial mapping in the file.")
|
586 |
+
parser.add_argument("--keep_unchanged_parts", action="store_true", help="Will keep unchanged parts in a reaction.")
|
587 |
+
parser.add_argument("--logFile", type=str, default='logFile.txt', help="Log file name.")
|
588 |
+
parser.add_argument("--skip_tautomerize", action="store_true", help="Will skip generation of the major tautomer.")
|
589 |
+
parser.add_argument("--rdkit_dearomatization", action="store_true", help="Will kekulize the reaction using RDKit "
|
590 |
+
"facilities.")
|
591 |
+
parser.add_argument("--jvm_path", type=str,
|
592 |
+
help="JVM path (e.g. C:\\Program Files\\Java\\jdk-13.0.2).")
|
593 |
+
parser.add_argument("--jchem_path", type=str, help="JChem path (e.g. C:\\Users\\user\\JChemSuite\\lib\\jchem.jar).")
|
594 |
+
args = parser.parse_args()
|
595 |
+
|
596 |
+
standardizer = Standardizer(skip_errors=args.skipErrors, log_file=args.logFile,
|
597 |
+
keep_unbalanced_ions=args.keep_unbalanced_ions, id_tag=args.idTag,
|
598 |
+
action_on_isotopes=args.action_on_isotopes, keep_reagents=args.keep_reagents,
|
599 |
+
ignore_mapping=args.ignore_mapping, skip_tautomerize=args.skip_tautomerize,
|
600 |
+
remove_unchanged_parts=(not args.keep_unchanged_parts), jvm_path=args.jvm_path,
|
601 |
+
jchem_path=args.jchem_path, rdkit_dearomatization=args.rdkit_dearomatization,
|
602 |
+
add_reagents_to_reactants=args.add_reagents)
|
603 |
+
data = standardizer.standardize_file(input_file=args.input)
|
604 |
+
standardizer.write(output_file=args.output, data=data)
|
SynTool/chem/reaction.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing classes and functions for manipulating reactions and reaction rules
|
3 |
+
"""
|
4 |
+
|
5 |
+
from CGRtools.reactor import Reactor
|
6 |
+
from CGRtools.containers import MoleculeContainer, ReactionContainer
|
7 |
+
from CGRtools.exceptions import InvalidAromaticRing
|
8 |
+
|
9 |
+
|
10 |
+
class Reaction(ReactionContainer):
|
11 |
+
"""
|
12 |
+
Reaction class can be used for a general representation of reaction for different chemoinformatics Python packages
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, *args, **kwargs):
|
16 |
+
"""
|
17 |
+
Initializes the reaction object.
|
18 |
+
"""
|
19 |
+
super().__init__(*args, **kwargs)
|
20 |
+
|
21 |
+
|
22 |
+
def add_small_mols(big_mol, small_molecules=None):
|
23 |
+
"""
|
24 |
+
The function takes a molecule and returns a list of modified molecules where each small molecule has been added to
|
25 |
+
the big molecule.
|
26 |
+
|
27 |
+
:param big_mol: A molecule
|
28 |
+
:param small_molecules: A list of small molecules that need to be added to the molecule
|
29 |
+
:return: Returns a list of molecules.
|
30 |
+
"""
|
31 |
+
if small_molecules:
|
32 |
+
tmp_mol = big_mol.copy()
|
33 |
+
transition_mapping = {}
|
34 |
+
for small_mol in small_molecules:
|
35 |
+
|
36 |
+
for n, atom in small_mol.atoms():
|
37 |
+
new_number = tmp_mol.add_atom(atom.atomic_symbol)
|
38 |
+
transition_mapping[n] = new_number
|
39 |
+
|
40 |
+
for atom, neighbor, bond in small_mol.bonds():
|
41 |
+
tmp_mol.add_bond(transition_mapping[atom], transition_mapping[neighbor], bond)
|
42 |
+
|
43 |
+
transition_mapping = {}
|
44 |
+
return tmp_mol.split()
|
45 |
+
else:
|
46 |
+
return [big_mol]
|
47 |
+
|
48 |
+
|
49 |
+
def apply_reaction_rule(
|
50 |
+
molecule: MoleculeContainer,
|
51 |
+
reaction_rule: Reactor,
|
52 |
+
sort_reactions: bool = False,
|
53 |
+
top_reactions_num: int = 3,
|
54 |
+
validate_products: bool = True,
|
55 |
+
rebuild_with_cgr: bool = False,
|
56 |
+
) -> list[MoleculeContainer]:
|
57 |
+
"""
|
58 |
+
The function applies a reaction rule to a given molecule.
|
59 |
+
|
60 |
+
:param rebuild_with_cgr:
|
61 |
+
:param validate_products:
|
62 |
+
:param sort_reactions:
|
63 |
+
:param top_reactions_num:
|
64 |
+
:param molecule: A MoleculeContainer object representing the molecule on which the reaction rule will be applied
|
65 |
+
:type molecule: MoleculeContainer
|
66 |
+
:param reaction_rule: The reaction_rule is an instance of the Reactor class. It represents a reaction rule that
|
67 |
+
can be applied to a molecule
|
68 |
+
:type reaction_rule: Reactor
|
69 |
+
"""
|
70 |
+
|
71 |
+
reactants = add_small_mols(molecule, small_molecules=False)
|
72 |
+
|
73 |
+
try:
|
74 |
+
if sort_reactions:
|
75 |
+
unsorted_reactions = list(reaction_rule(reactants))
|
76 |
+
sorted_reactions = sorted(
|
77 |
+
unsorted_reactions,
|
78 |
+
key=lambda react: len(list(filter(lambda mol: len(mol) > 6, react.products))),
|
79 |
+
reverse=True
|
80 |
+
)
|
81 |
+
reactions = sorted_reactions[:top_reactions_num] # Take top-N reactions from reactor
|
82 |
+
else:
|
83 |
+
reactions = []
|
84 |
+
for reaction in reaction_rule(reactants):
|
85 |
+
reactions.append(reaction)
|
86 |
+
if len(reactions) == top_reactions_num:
|
87 |
+
break
|
88 |
+
except IndexError:
|
89 |
+
reactions = []
|
90 |
+
|
91 |
+
for reaction in reactions:
|
92 |
+
if rebuild_with_cgr:
|
93 |
+
cgr = reaction.compose()
|
94 |
+
products = cgr.decompose()[1].split()
|
95 |
+
else:
|
96 |
+
products = reaction.products
|
97 |
+
products = [mol for mol in products if len(mol) > 0]
|
98 |
+
if validate_products:
|
99 |
+
for molecule in products:
|
100 |
+
try:
|
101 |
+
molecule.kekule()
|
102 |
+
if molecule.check_valence():
|
103 |
+
yield None
|
104 |
+
molecule.thiele()
|
105 |
+
except InvalidAromaticRing:
|
106 |
+
yield None
|
107 |
+
yield products
|
SynTool/chem/reaction_rules/__init__.py
ADDED
File without changes
|
SynTool/chem/reaction_rules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (160 Bytes). View file
|
|
SynTool/chem/reaction_rules/__pycache__/extraction.cpython-310.pyc
ADDED
Binary file (25.7 kB). View file
|
|
SynTool/chem/reaction_rules/extraction.py
ADDED
@@ -0,0 +1,679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing functions with fixed protocol for reaction rules extraction
|
3 |
+
"""
|
4 |
+
import logging
|
5 |
+
import pickle
|
6 |
+
from collections import defaultdict
|
7 |
+
from itertools import islice
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import List, Union, Tuple, IO, Dict, Set, Iterable, Any
|
10 |
+
from os.path import splitext
|
11 |
+
|
12 |
+
|
13 |
+
import ray
|
14 |
+
from CGRtools.containers import MoleculeContainer, QueryContainer, ReactionContainer
|
15 |
+
from CGRtools.exceptions import InvalidAromaticRing
|
16 |
+
from CGRtools.reactor import Reactor
|
17 |
+
from tqdm.auto import tqdm
|
18 |
+
|
19 |
+
from SynTool.chem.utils import reverse_reaction
|
20 |
+
from SynTool.utils.config import RuleExtractionConfig
|
21 |
+
from SynTool.utils.files import ReactionReader
|
22 |
+
|
23 |
+
|
24 |
+
def extract_rules_from_reactions(
|
25 |
+
config: RuleExtractionConfig,
|
26 |
+
reaction_file: str,
|
27 |
+
rules_file_name: str = 'reaction_rules.pickle',
|
28 |
+
num_cpus: int = 1,
|
29 |
+
batch_size: int = 10,
|
30 |
+
) -> None:
|
31 |
+
"""
|
32 |
+
Extracts reaction rules from a set of reactions based on the given configuration.
|
33 |
+
|
34 |
+
This function initializes a Ray environment for distributed computing and processes each reaction
|
35 |
+
in the provided reaction database to extract reaction rules. It handles the reactions in batches,
|
36 |
+
parallelize the rule extraction process. Extracted rules are written to RDF files and their statistics
|
37 |
+
are recorded. The function also sorts the rules based on their popularity and saves the sorted rules.
|
38 |
+
|
39 |
+
:param config: Configuration settings for rule extraction, including file paths, batch size, and other parameters.
|
40 |
+
:param reaction_file: Path to the file containing reaction database.
|
41 |
+
:param rules_file_name: Name of the file to store the extracted rules.
|
42 |
+
:param num_cpus: Number of CPU cores to use for processing. Defaults to 1.
|
43 |
+
:param batch_size: Number of reactions to process in each batch. Defaults to 10.
|
44 |
+
|
45 |
+
:return: None
|
46 |
+
"""
|
47 |
+
|
48 |
+
# read files
|
49 |
+
reaction_file = Path(reaction_file).resolve(strict=True)
|
50 |
+
|
51 |
+
ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)
|
52 |
+
|
53 |
+
rules_file_name, _ = splitext(rules_file_name)
|
54 |
+
with ReactionReader(reaction_file) as reactions:
|
55 |
+
pbar = tqdm(reactions, disable=False)
|
56 |
+
|
57 |
+
futures = {}
|
58 |
+
batch = []
|
59 |
+
max_concurrent_batches = num_cpus
|
60 |
+
|
61 |
+
extracted_rules_and_statistics = defaultdict(list)
|
62 |
+
for index, reaction in enumerate(reactions):
|
63 |
+
batch.append((index, reaction))
|
64 |
+
if len(batch) == batch_size:
|
65 |
+
future = process_reaction_batch.remote(batch, config)
|
66 |
+
futures[future] = None
|
67 |
+
batch = []
|
68 |
+
|
69 |
+
while len(futures) >= max_concurrent_batches:
|
70 |
+
process_completed_batches(futures, extracted_rules_and_statistics, pbar, batch_size)
|
71 |
+
|
72 |
+
if batch:
|
73 |
+
remaining_size = len(batch)
|
74 |
+
future = process_reaction_batch.remote(batch, config)
|
75 |
+
futures[future] = None
|
76 |
+
|
77 |
+
while futures:
|
78 |
+
process_completed_batches(futures, extracted_rules_and_statistics, pbar, remaining_size)
|
79 |
+
|
80 |
+
pbar.close()
|
81 |
+
|
82 |
+
sorted_rules = sort_rules(
|
83 |
+
extracted_rules_and_statistics,
|
84 |
+
min_popularity=config.min_popularity,
|
85 |
+
single_reactant_only=config.single_reactant_only,
|
86 |
+
)
|
87 |
+
|
88 |
+
with open(f"{rules_file_name}.pickle", "wb") as statistics_file:
|
89 |
+
pickle.dump(sorted_rules, statistics_file)
|
90 |
+
print(f'Number of extracted reaction rules: {len(sorted_rules)}')
|
91 |
+
|
92 |
+
ray.shutdown()
|
93 |
+
|
94 |
+
|
95 |
+
@ray.remote
|
96 |
+
def process_reaction_batch(
|
97 |
+
batch: List[Tuple[int, ReactionContainer]], config: RuleExtractionConfig
|
98 |
+
) -> list[tuple[int, list[ReactionContainer]]]:
|
99 |
+
"""
|
100 |
+
Processes a batch of reactions to extract reaction rules based on the given configuration.
|
101 |
+
|
102 |
+
This function operates as a remote task in a distributed system using Ray. It takes a batch of reactions,
|
103 |
+
where each reaction is paired with an index. For each reaction in the batch, it extracts reaction rules
|
104 |
+
as specified by the configuration object. The extracted rules for each reaction are then returned along
|
105 |
+
with the corresponding index.
|
106 |
+
|
107 |
+
:param batch: A list where each element is a tuple containing an index (int) and a ReactionContainer object.
|
108 |
+
The index is typically used to keep track of the reaction's position in a larger dataset.
|
109 |
+
:type batch: List[Tuple[int, ReactionContainer]]
|
110 |
+
|
111 |
+
:param config: An instance of ExtractRuleConfig that provides settings and parameters for the rule extraction process.
|
112 |
+
:type config: RuleExtractionConfig
|
113 |
+
|
114 |
+
:return: A list where each element is a tuple. The first element of the tuple is an index (int), and the second
|
115 |
+
is a list of ReactionContainer objects representing the extracted rules for the corresponding reaction.
|
116 |
+
:rtype: list[tuple[int, list[ReactionContainer]]]
|
117 |
+
|
118 |
+
This function is intended to be used in a distributed manner with Ray to parallelize the rule extraction
|
119 |
+
process across multiple reactions.
|
120 |
+
"""
|
121 |
+
processed_batch = []
|
122 |
+
for index, reaction in batch:
|
123 |
+
try:
|
124 |
+
extracted_rules = extract_rules(config, reaction)
|
125 |
+
processed_batch.append((index, extracted_rules))
|
126 |
+
except:
|
127 |
+
continue
|
128 |
+
return processed_batch
|
129 |
+
|
130 |
+
|
131 |
+
def process_completed_batches(
|
132 |
+
futures: dict,
|
133 |
+
rules_statistics: Dict[ReactionContainer, List[int]],
|
134 |
+
pbar: tqdm,
|
135 |
+
batch_size: int,
|
136 |
+
) -> None:
|
137 |
+
"""
|
138 |
+
Processes completed batches of reactions, updating the rules statistics and writing rules to a file.
|
139 |
+
|
140 |
+
This function waits for the completion of a batch of reactions processed in parallel (using Ray),
|
141 |
+
updates the statistics for each extracted rule, and writes the rules to a result file if they are new.
|
142 |
+
It also updates the progress bar with the size of the processed batch.
|
143 |
+
|
144 |
+
:param futures: A dictionary of futures representing ongoing batch processing tasks.
|
145 |
+
:type futures: dict
|
146 |
+
|
147 |
+
:param rules_statistics: A dictionary to keep track of statistics for each rule.
|
148 |
+
:type rules_statistics: Dict[ReactionContainer, List[int]]
|
149 |
+
|
150 |
+
:param pbar: A tqdm progress bar instance for updating the progress of batch processing.
|
151 |
+
:type pbar: tqdm
|
152 |
+
|
153 |
+
:param batch_size: The number of reactions processed in each batch.
|
154 |
+
:type batch_size: int
|
155 |
+
|
156 |
+
:return: None
|
157 |
+
"""
|
158 |
+
done, _ = ray.wait(list(futures.keys()), num_returns=1)
|
159 |
+
completed_batch = ray.get(done[0])
|
160 |
+
|
161 |
+
for index, extracted_rules in completed_batch:
|
162 |
+
for rule in extracted_rules:
|
163 |
+
prev_stats_len = len(rules_statistics)
|
164 |
+
rules_statistics[rule].append(index)
|
165 |
+
if len(rules_statistics) != prev_stats_len:
|
166 |
+
rule.meta["first_reaction_index"] = index
|
167 |
+
|
168 |
+
del futures[done[0]]
|
169 |
+
pbar.update(batch_size)
|
170 |
+
|
171 |
+
|
172 |
+
def extract_rules(
|
173 |
+
config: RuleExtractionConfig, reaction: ReactionContainer
|
174 |
+
) -> list[ReactionContainer]:
|
175 |
+
"""
|
176 |
+
Extracts reaction rules from a given reaction based on the specified configuration.
|
177 |
+
|
178 |
+
:param config: An instance of ExtractRuleConfig, which contains various configuration settings
|
179 |
+
for rule extraction, such as whether to include multicenter rules, functional groups,
|
180 |
+
ring structures, leaving and incoming groups, etc.
|
181 |
+
:param reaction: The reaction object (ReactionContainer) from which to extract rules. The reaction
|
182 |
+
object represents a chemical reaction with specified reactants, products, and possibly reagents.
|
183 |
+
:return: A list of ReactionContainer objects, each representing a distinct reaction rule. If
|
184 |
+
config.multicenter_rules is True, a single rule encompassing all reaction centers is returned.
|
185 |
+
Otherwise, separate rules for each reaction center are extracted, up to a maximum of 15 distinct centers.
|
186 |
+
"""
|
187 |
+
if config.multicenter_rules:
|
188 |
+
# Extract a single rule encompassing all reaction centers
|
189 |
+
return [create_rule(config, reaction)]
|
190 |
+
else:
|
191 |
+
# Extract separate rules for each distinct reaction center
|
192 |
+
distinct_rules = set()
|
193 |
+
for center_reaction in islice(reaction.enumerate_centers(), 15):
|
194 |
+
single_rule = create_rule(config, center_reaction)
|
195 |
+
distinct_rules.add(single_rule)
|
196 |
+
return list(distinct_rules)
|
197 |
+
|
198 |
+
|
199 |
+
def create_rule(
|
200 |
+
config: RuleExtractionConfig, reaction: ReactionContainer
|
201 |
+
) -> ReactionContainer:
|
202 |
+
"""
|
203 |
+
Creates a reaction rule from a given reaction based on the specified configuration.
|
204 |
+
|
205 |
+
:param config: An instance of ExtractRuleConfig, containing various settings that determine how
|
206 |
+
the rule is created, such as environmental atom count, inclusion of functional groups,
|
207 |
+
rings, leaving and incoming groups, and other parameters.
|
208 |
+
:param reaction: The reaction object (ReactionContainer) from which to create the rule. This object
|
209 |
+
represents a chemical reaction with specified reactants, products, and possibly reagents.
|
210 |
+
:return: A ReactionContainer object representing the extracted reaction rule. This rule includes
|
211 |
+
various elements of the reaction as specified by the configuration, such as reaction centers,
|
212 |
+
environmental atoms, functional groups, and others.
|
213 |
+
|
214 |
+
The function processes the reaction to create a rule that matches the configuration settings. It handles
|
215 |
+
the inclusion of environmental atoms, functional groups, ring structures, and leaving and incoming groups.
|
216 |
+
It also constructs substructures for reactants, products, and reagents, and cleans molecule representations
|
217 |
+
if required. Optionally, it validates the rule using a reactor.
|
218 |
+
"""
|
219 |
+
cgr = ~reaction
|
220 |
+
center_atoms = set(cgr.center_atoms)
|
221 |
+
|
222 |
+
# Add atoms of reaction environment based on config settings
|
223 |
+
center_atoms = add_environment_atoms(
|
224 |
+
cgr, center_atoms, config.environment_atom_count
|
225 |
+
)
|
226 |
+
|
227 |
+
# Include functional groups in the rule if specified in config
|
228 |
+
if config.include_func_groups:
|
229 |
+
rule_atoms = add_functional_groups(
|
230 |
+
reaction, center_atoms, config.func_groups_list
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
rule_atoms = center_atoms.copy()
|
234 |
+
|
235 |
+
# Include ring structures in the rule if specified in config
|
236 |
+
if config.include_rings:
|
237 |
+
rule_atoms = add_ring_structures(
|
238 |
+
cgr,
|
239 |
+
rule_atoms,
|
240 |
+
)
|
241 |
+
|
242 |
+
# Add leaving and incoming groups to the rule based on config settings
|
243 |
+
rule_atoms, meta_debug = add_leaving_incoming_groups(
|
244 |
+
reaction, rule_atoms, config.keep_leaving_groups, config.keep_incoming_groups
|
245 |
+
)
|
246 |
+
|
247 |
+
# Create substructures for reactants, products, and reagents
|
248 |
+
(
|
249 |
+
reactant_substructures,
|
250 |
+
product_substructures,
|
251 |
+
reagents,
|
252 |
+
) = create_substructures_and_reagents(
|
253 |
+
reaction, rule_atoms, config.as_query_container, config.keep_reagents
|
254 |
+
)
|
255 |
+
|
256 |
+
# Clean atom marks in the molecules if they are being converted to query containers
|
257 |
+
if config.as_query_container:
|
258 |
+
reactant_substructures = clean_molecules(
|
259 |
+
reactant_substructures,
|
260 |
+
reaction.reactants,
|
261 |
+
center_atoms,
|
262 |
+
config.atom_info_retention,
|
263 |
+
)
|
264 |
+
product_substructures = clean_molecules(
|
265 |
+
product_substructures,
|
266 |
+
reaction.products,
|
267 |
+
center_atoms,
|
268 |
+
config.atom_info_retention,
|
269 |
+
)
|
270 |
+
|
271 |
+
# Assemble the final rule including metadata if specified
|
272 |
+
rule = assemble_final_rule(
|
273 |
+
reactant_substructures,
|
274 |
+
product_substructures,
|
275 |
+
reagents,
|
276 |
+
meta_debug,
|
277 |
+
config.keep_metadata,
|
278 |
+
reaction,
|
279 |
+
)
|
280 |
+
|
281 |
+
if config.reverse_rule:
|
282 |
+
rule = reverse_reaction(rule)
|
283 |
+
reaction = reverse_reaction(reaction)
|
284 |
+
|
285 |
+
# Validate the rule using a reactor if validation is enabled in config
|
286 |
+
if config.reactor_validation:
|
287 |
+
if validate_rule(rule, reaction):
|
288 |
+
rule.meta["reactor_validation"] = "passed"
|
289 |
+
else:
|
290 |
+
rule.meta["reactor_validation"] = "failed"
|
291 |
+
|
292 |
+
return rule
|
293 |
+
|
294 |
+
|
295 |
+
def add_environment_atoms(cgr, center_atoms, environment_atom_count):
|
296 |
+
"""
|
297 |
+
Adds environment atoms to the set of center atoms based on the specified depth.
|
298 |
+
|
299 |
+
:param cgr: A complete graph representation of a reaction (ReactionContainer object).
|
300 |
+
:param center_atoms: A set of atom identifiers representing the center atoms of the reaction.
|
301 |
+
:param environment_atom_count: An integer specifying the depth of the environment around
|
302 |
+
the reaction center to be included. If it's 0, only the
|
303 |
+
reaction center is included. If it's 1, the first layer of
|
304 |
+
surrounding atoms is included, and so on.
|
305 |
+
:return: A set of atom identifiers including the center atoms and their environment atoms
|
306 |
+
up to the specified depth. If environment_atom_count is 0, the original set of
|
307 |
+
center atoms is returned unchanged.
|
308 |
+
"""
|
309 |
+
if environment_atom_count:
|
310 |
+
env_cgr = cgr.augmented_substructure(center_atoms, deep=environment_atom_count)
|
311 |
+
# Combine the original center atoms with the new environment atoms
|
312 |
+
return center_atoms | set(env_cgr)
|
313 |
+
|
314 |
+
# If no environment is to be included, return the original center atoms
|
315 |
+
return center_atoms
|
316 |
+
|
317 |
+
|
318 |
+
def add_functional_groups(reaction, center_atoms, func_groups_list):
|
319 |
+
"""
|
320 |
+
Augments the set of rule atoms with functional groups if specified.
|
321 |
+
|
322 |
+
:param reaction: The reaction object (ReactionContainer) from which molecules are extracted.
|
323 |
+
:param center_atoms: A set of atom identifiers representing the center atoms of the reaction.
|
324 |
+
:param func_groups_list: A list of functional group objects (MoleculeContainer or QueryContainer)
|
325 |
+
to be considered when including functional groups. These objects define
|
326 |
+
the structure of the functional groups to be included.
|
327 |
+
:return: A set of atom identifiers representing the rule atoms, including atoms from the
|
328 |
+
specified functional groups if include_func_groups is True. If include_func_groups
|
329 |
+
is False, the original set of center atoms is returned.
|
330 |
+
"""
|
331 |
+
rule_atoms = center_atoms.copy()
|
332 |
+
# Iterate over each molecule in the reaction
|
333 |
+
for molecule in reaction.molecules():
|
334 |
+
# For each functional group specified in the list
|
335 |
+
for func_group in func_groups_list:
|
336 |
+
# Find mappings of the functional group in the molecule
|
337 |
+
for mapping in func_group.get_mapping(molecule):
|
338 |
+
# Remap the functional group based on the found mapping
|
339 |
+
func_group.remap(mapping)
|
340 |
+
# If the functional group intersects with center atoms, include it
|
341 |
+
if set(func_group.atoms_numbers) & center_atoms:
|
342 |
+
rule_atoms |= set(func_group.atoms_numbers)
|
343 |
+
# Reset the mapping to its original state for the next iteration
|
344 |
+
func_group.remap({v: k for k, v in mapping.items()})
|
345 |
+
return rule_atoms
|
346 |
+
|
347 |
+
|
348 |
+
def add_ring_structures(cgr, rule_atoms):
|
349 |
+
"""
|
350 |
+
Appends ring structures to the set of rule atoms if they intersect with the reaction center atoms.
|
351 |
+
|
352 |
+
:param cgr: A condensed graph representation of a reaction (CGRContainer object).
|
353 |
+
:param rule_atoms: A set of atom identifiers representing the center atoms of the reaction.
|
354 |
+
:return: A set of atom identifiers including the original rule atoms and the included ring structures.
|
355 |
+
"""
|
356 |
+
for ring in cgr.sssr:
|
357 |
+
# Check if the current ring intersects with the set of rule atoms
|
358 |
+
if set(ring) & rule_atoms:
|
359 |
+
# If the intersection exists, include all atoms in the ring to the rule atoms
|
360 |
+
rule_atoms |= set(ring)
|
361 |
+
return rule_atoms
|
362 |
+
|
363 |
+
|
364 |
+
def add_leaving_incoming_groups(
|
365 |
+
reaction, rule_atoms, keep_leaving_groups, keep_incoming_groups
|
366 |
+
):
|
367 |
+
"""
|
368 |
+
Identifies and includes leaving and incoming groups to the rule atoms based on specified flags.
|
369 |
+
|
370 |
+
:param reaction: The reaction object (ReactionContainer) from which leaving and incoming groups are extracted.
|
371 |
+
:param rule_atoms: A set of atom identifiers representing the center atoms of the reaction.
|
372 |
+
:param keep_leaving_groups: A boolean flag indicating whether to include leaving groups in the rule.
|
373 |
+
:param keep_incoming_groups: A boolean flag indicating whether to include incoming groups in the rule.
|
374 |
+
:return: Updated set of rule atoms including leaving and incoming groups if specified, and metadata about added groups.
|
375 |
+
"""
|
376 |
+
meta_debug = {"leaving": set(), "incoming": set()}
|
377 |
+
|
378 |
+
# Extract atoms from reactants and products
|
379 |
+
reactant_atoms = {atom for reactant in reaction.reactants for atom in reactant}
|
380 |
+
product_atoms = {atom for product in reaction.products for atom in product}
|
381 |
+
|
382 |
+
# Identify leaving groups (reactant atoms not in products)
|
383 |
+
if keep_leaving_groups:
|
384 |
+
leaving_atoms = reactant_atoms - product_atoms
|
385 |
+
new_leaving_atoms = leaving_atoms - rule_atoms
|
386 |
+
# Include leaving atoms in the rule atoms
|
387 |
+
rule_atoms |= leaving_atoms
|
388 |
+
# Add leaving atoms to metadata
|
389 |
+
meta_debug["leaving"] |= new_leaving_atoms
|
390 |
+
|
391 |
+
# Identify incoming groups (product atoms not in reactants)
|
392 |
+
if keep_incoming_groups:
|
393 |
+
incoming_atoms = product_atoms - reactant_atoms
|
394 |
+
new_incoming_atoms = incoming_atoms - rule_atoms
|
395 |
+
# Include incoming atoms in the rule atoms
|
396 |
+
rule_atoms |= incoming_atoms
|
397 |
+
# Add incoming atoms to metadata
|
398 |
+
meta_debug["incoming"] |= new_incoming_atoms
|
399 |
+
|
400 |
+
return rule_atoms, meta_debug
|
401 |
+
|
402 |
+
|
403 |
+
def clean_molecules(
|
404 |
+
rule_molecules: Iterable[QueryContainer],
|
405 |
+
reaction_molecules: Iterable[MoleculeContainer],
|
406 |
+
reaction_center_atoms: Set[int],
|
407 |
+
atom_retention_details: Dict[str, Dict[str, bool]],
|
408 |
+
) -> List[QueryContainer]:
|
409 |
+
"""
|
410 |
+
Cleans rule molecules by removing specified information about atoms based on retention details provided.
|
411 |
+
|
412 |
+
:param rule_molecules: A list of query container objects representing the rule molecules.
|
413 |
+
:param reaction_molecules: A list of molecule container objects involved in the reaction.
|
414 |
+
:param reaction_center_atoms: A set of integers representing atom numbers in the reaction center.
|
415 |
+
:param atom_retention_details: A dictionary specifying what atom information to retain or remove.
|
416 |
+
This dictionary should have two keys: "reaction_center" and "environment",
|
417 |
+
each mapping to another dictionary. The nested dictionaries should have
|
418 |
+
keys representing atom attributes (like "neighbors", "hybridization",
|
419 |
+
"implicit_hydrogens", "ring_sizes") and boolean values. A value of True
|
420 |
+
indicates that the corresponding attribute should be retained,
|
421 |
+
while False indicates it should be removed from the atom.
|
422 |
+
|
423 |
+
For example:
|
424 |
+
{
|
425 |
+
"reaction_center": {"neighbors": True, "hybridization": False, ...},
|
426 |
+
"environment": {"neighbors": True, "implicit_hydrogens": False, ...}
|
427 |
+
}
|
428 |
+
|
429 |
+
Returns:
|
430 |
+
A list of QueryContainer objects representing the cleaned rule molecules.
|
431 |
+
"""
|
432 |
+
cleaned_rule_molecules = []
|
433 |
+
|
434 |
+
for rule_molecule in rule_molecules:
|
435 |
+
for reaction_molecule in reaction_molecules:
|
436 |
+
if set(rule_molecule.atoms_numbers) <= set(reaction_molecule.atoms_numbers):
|
437 |
+
query_reaction_molecule = reaction_molecule.substructure(
|
438 |
+
reaction_molecule, as_query=True
|
439 |
+
)
|
440 |
+
query_rule_molecule = query_reaction_molecule.substructure(
|
441 |
+
rule_molecule
|
442 |
+
)
|
443 |
+
|
444 |
+
# Clean environment atoms
|
445 |
+
if not all(
|
446 |
+
atom_retention_details["environment"].values()
|
447 |
+
): # if everything True, we keep all marks
|
448 |
+
local_environment_atoms = (
|
449 |
+
set(rule_molecule.atoms_numbers) - reaction_center_atoms
|
450 |
+
)
|
451 |
+
for atom_number in local_environment_atoms:
|
452 |
+
query_rule_molecule = clean_atom(
|
453 |
+
query_rule_molecule,
|
454 |
+
atom_retention_details["environment"],
|
455 |
+
atom_number,
|
456 |
+
)
|
457 |
+
|
458 |
+
# Clean reaction center atoms
|
459 |
+
if not all(
|
460 |
+
atom_retention_details["reaction_center"].values()
|
461 |
+
): # if everything True, we keep all marks
|
462 |
+
local_reaction_center_atoms = (
|
463 |
+
set(rule_molecule.atoms_numbers) & reaction_center_atoms
|
464 |
+
)
|
465 |
+
for atom_number in local_reaction_center_atoms:
|
466 |
+
query_rule_molecule = clean_atom(
|
467 |
+
query_rule_molecule,
|
468 |
+
atom_retention_details["reaction_center"],
|
469 |
+
atom_number,
|
470 |
+
)
|
471 |
+
|
472 |
+
cleaned_rule_molecules.append(query_rule_molecule)
|
473 |
+
break
|
474 |
+
|
475 |
+
return cleaned_rule_molecules
|
476 |
+
|
477 |
+
|
478 |
+
def clean_atom(
|
479 |
+
query_molecule: QueryContainer,
|
480 |
+
attributes_to_keep: Dict[str, bool],
|
481 |
+
atom_number: int,
|
482 |
+
) -> QueryContainer:
|
483 |
+
"""
|
484 |
+
Removes specified information from a given atom in a query molecule.
|
485 |
+
|
486 |
+
:param query_molecule: The QueryContainer of molecule.
|
487 |
+
:param attributes_to_keep: Dictionary indicating which attributes to keep in the atom.
|
488 |
+
The keys should be strings representing the attribute names, and
|
489 |
+
the values should be booleans indicating whether to retain (True)
|
490 |
+
or remove (False) that attribute. Expected keys are:
|
491 |
+
- "neighbors": Indicates if neighbors of the atom should be removed.
|
492 |
+
- "hybridization": Indicates if hybridization information of the atom should be removed.
|
493 |
+
- "implicit_hydrogens": Indicates if implicit hydrogen information of the atom should be removed.
|
494 |
+
- "ring_sizes": Indicates if ring size information of the atom should be removed.
|
495 |
+
:param atom_number: The number of the atom to be modified in the query molecule.
|
496 |
+
"""
|
497 |
+
target_atom = query_molecule.atom(atom_number)
|
498 |
+
|
499 |
+
if not attributes_to_keep["neighbors"]:
|
500 |
+
target_atom.neighbors = None
|
501 |
+
if not attributes_to_keep["hybridization"]:
|
502 |
+
target_atom.hybridization = None
|
503 |
+
if not attributes_to_keep["implicit_hydrogens"]:
|
504 |
+
target_atom.implicit_hydrogens = None
|
505 |
+
if not attributes_to_keep["ring_sizes"]:
|
506 |
+
target_atom.ring_sizes = None
|
507 |
+
|
508 |
+
return query_molecule
|
509 |
+
|
510 |
+
|
511 |
+
def create_substructures_and_reagents(
|
512 |
+
reaction, rule_atoms, as_query_container, keep_reagents
|
513 |
+
):
|
514 |
+
"""
|
515 |
+
Creates substructures for reactants and products, and optionally includes reagents, based on specified parameters.
|
516 |
+
|
517 |
+
:param reaction: The reaction object (ReactionContainer) from which to extract substructures. This object
|
518 |
+
represents a chemical reaction with specified reactants, products, and possibly reagents.
|
519 |
+
:param rule_atoms: A set of atom identifiers that define the rule atoms. These are used to identify relevant
|
520 |
+
substructures in reactants and products.
|
521 |
+
:param as_query_container: A boolean flag indicating whether the substructures should be converted to query containers.
|
522 |
+
Query containers are used for pattern matching in chemical structures.
|
523 |
+
:param keep_reagents: A boolean flag indicating whether reagents should be included in the resulting structures.
|
524 |
+
Reagents are additional substances that are present in the reaction but are not reactants or products.
|
525 |
+
|
526 |
+
:return: A tuple containing three elements:
|
527 |
+
- A list of reactant substructures, each corresponding to a part of the reactants that matches the rule atoms.
|
528 |
+
- A list of product substructures, each corresponding to a part of the products that matches the rule atoms.
|
529 |
+
- A list of reagents, included as is or as substructures, depending on the as_query_container flag.
|
530 |
+
|
531 |
+
The function processes the reaction to create substructures for reactants and products based on the rule atoms.
|
532 |
+
It also handles the inclusion of reagents based on the keep_reagents flag and converts these structures to query
|
533 |
+
containers if required.
|
534 |
+
"""
|
535 |
+
reactant_substructures = [
|
536 |
+
reactant.substructure(rule_atoms.intersection(reactant.atoms_numbers))
|
537 |
+
for reactant in reaction.reactants
|
538 |
+
if rule_atoms.intersection(reactant.atoms_numbers)
|
539 |
+
]
|
540 |
+
|
541 |
+
product_substructures = [
|
542 |
+
product.substructure(rule_atoms.intersection(product.atoms_numbers))
|
543 |
+
for product in reaction.products
|
544 |
+
if rule_atoms.intersection(product.atoms_numbers)
|
545 |
+
]
|
546 |
+
|
547 |
+
reagents = []
|
548 |
+
if keep_reagents:
|
549 |
+
if as_query_container:
|
550 |
+
reagents = [
|
551 |
+
reagent.substructure(reagent, as_query=True)
|
552 |
+
for reagent in reaction.reagents
|
553 |
+
]
|
554 |
+
else:
|
555 |
+
reagents = reaction.reagents
|
556 |
+
|
557 |
+
return reactant_substructures, product_substructures, reagents
|
558 |
+
|
559 |
+
|
560 |
+
def assemble_final_rule(
|
561 |
+
reactant_substructures,
|
562 |
+
product_substructures,
|
563 |
+
reagents,
|
564 |
+
meta_debug,
|
565 |
+
keep_metadata,
|
566 |
+
reaction,
|
567 |
+
):
|
568 |
+
"""
|
569 |
+
Assembles the final reaction rule from the provided substructures and metadata.
|
570 |
+
|
571 |
+
:param reactant_substructures: A list of substructures derived from the reactants of the reaction.
|
572 |
+
These substructures represent parts of reactants that are relevant to the rule.
|
573 |
+
:param product_substructures: A list of substructures derived from the products of the reaction.
|
574 |
+
These substructures represent parts of products that are relevant to the rule.
|
575 |
+
:param reagents: A list of reagents involved in the reaction. These may be included as-is or as substructures,
|
576 |
+
depending on earlier processing steps.
|
577 |
+
:param meta_debug: A dictionary containing additional metadata about the reaction, such as leaving and incoming groups.
|
578 |
+
:param keep_metadata: A boolean flag indicating whether to retain the metadata associated with the reaction in the rule.
|
579 |
+
:param reaction: The original reaction object (ReactionContainer) from which the rule is being created.
|
580 |
+
|
581 |
+
:return: A ReactionContainer object representing the assembled reaction rule. This container includes
|
582 |
+
the reactant and product substructures, reagents, and any additional metadata if keep_metadata is True.
|
583 |
+
|
584 |
+
This function brings together the various components of a reaction rule, including reactant and product substructures,
|
585 |
+
reagents, and metadata. It creates a comprehensive representation of the reaction rule, which can be used for further
|
586 |
+
processing or analysis.
|
587 |
+
"""
|
588 |
+
rule_metadata = meta_debug if keep_metadata else {}
|
589 |
+
rule_metadata.update(reaction.meta if keep_metadata else {})
|
590 |
+
|
591 |
+
rule = ReactionContainer(
|
592 |
+
reactant_substructures, product_substructures, reagents, rule_metadata
|
593 |
+
)
|
594 |
+
|
595 |
+
if keep_metadata:
|
596 |
+
rule.name = reaction.name
|
597 |
+
|
598 |
+
rule.flush_cache()
|
599 |
+
return rule
|
600 |
+
|
601 |
+
|
602 |
+
def validate_rule(rule: ReactionContainer, reaction: ReactionContainer):
|
603 |
+
"""
|
604 |
+
Validates a reaction rule by ensuring it can correctly generate the products from the reactants.
|
605 |
+
|
606 |
+
:param rule: The reaction rule to be validated. This is a ReactionContainer object representing a chemical reaction rule,
|
607 |
+
which includes the necessary information to perform a reaction.
|
608 |
+
:param reaction: The original reaction object (ReactionContainer) against which the rule is to be validated. This object
|
609 |
+
contains the actual reactants and products of the reaction.
|
610 |
+
|
611 |
+
:return: The validated rule if the rule correctly generates the products from the reactants.
|
612 |
+
|
613 |
+
:raises ValueError: If the rule does not correctly generate the products from the reactants, indicating
|
614 |
+
an incorrect or incomplete rule.
|
615 |
+
|
616 |
+
The function uses a chemical reactor to simulate the reaction based on the provided rule. It then compares
|
617 |
+
the products generated by the simulation with the actual products of the reaction. If they match, the rule
|
618 |
+
is considered valid. If not, a ValueError is raised, indicating an issue with the rule.
|
619 |
+
"""
|
620 |
+
# Create a reactor with the given rule
|
621 |
+
reactor = Reactor(rule)
|
622 |
+
try:
|
623 |
+
for result_reaction in reactor(reaction.reactants):
|
624 |
+
result_products = []
|
625 |
+
for result_product in result_reaction.products:
|
626 |
+
tmp = result_product.copy()
|
627 |
+
try:
|
628 |
+
tmp.kekule()
|
629 |
+
if tmp.check_valence():
|
630 |
+
continue
|
631 |
+
except InvalidAromaticRing:
|
632 |
+
continue
|
633 |
+
result_products.append(result_product)
|
634 |
+
if set(reaction.products) == set(result_products) and len(
|
635 |
+
reaction.products
|
636 |
+
) == len(result_products):
|
637 |
+
return True
|
638 |
+
except (KeyError, IndexError):
|
639 |
+
# KeyError - iteration over reactor is finished and products are different from the original reaction
|
640 |
+
# IndexError - mistake in __contract_ions, possibly problems with charges in rule?
|
641 |
+
return False
|
642 |
+
|
643 |
+
|
644 |
+
def sort_rules(
|
645 |
+
rules_stats: Dict[ReactionContainer, List[int]],
|
646 |
+
min_popularity: int = 3,
|
647 |
+
single_reactant_only: bool = True,
|
648 |
+
) -> List[Tuple[ReactionContainer, List[int]]]:
|
649 |
+
"""
|
650 |
+
Sorts reaction rules based on their popularity and validation status.
|
651 |
+
|
652 |
+
This function sorts the given rules according to their popularity (i.e., the number of times they have been
|
653 |
+
applied) and filters out rules that haven't passed reactor validation or are less popular than the specified
|
654 |
+
minimum popularity threshold.
|
655 |
+
|
656 |
+
:param rules_stats: A dictionary where each key is a reaction rule and the value is a list of integers.
|
657 |
+
Each integer represents an index where the rule was applied.
|
658 |
+
:type rules_stats: Dict[ReactionContainer, List[int]]
|
659 |
+
|
660 |
+
:param min_popularity: The minimum number of times a rule must be applied to be considered. Default is 3.
|
661 |
+
:type min_popularity: int
|
662 |
+
|
663 |
+
:param single_reactant_only: Whether to keep only reaction rules with a single molecule on the right side
|
664 |
+
of reaction arrow. Default is True.
|
665 |
+
|
666 |
+
:return: A list of tuples, where each tuple contains a reaction rule and a list of indices representing
|
667 |
+
the rule's applications. The list is sorted in descending order of the rule's popularity.
|
668 |
+
:rtype: List[Tuple[ReactionContainer, List[int]]]
|
669 |
+
"""
|
670 |
+
return sorted(
|
671 |
+
(
|
672 |
+
(rule, indices)
|
673 |
+
for rule, indices in rules_stats.items()
|
674 |
+
if len(indices) >= min_popularity
|
675 |
+
and rule.meta["reactor_validation"] == "passed"
|
676 |
+
and (not single_reactant_only or len(rule.reactants) == 1)
|
677 |
+
),
|
678 |
+
key=lambda x: -len(x[1]),
|
679 |
+
)
|
SynTool/chem/reaction_rules/manual/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .decompositions import rules as d_rules
|
2 |
+
from .transformations import rules as t_rules
|
3 |
+
|
4 |
+
hardcoded_rules = t_rules + d_rules
|
5 |
+
|
6 |
+
__all__ = ["hardcoded_rules"]
|
SynTool/chem/reaction_rules/manual/decompositions.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing hardcoded decomposition reaction rules
|
3 |
+
"""
|
4 |
+
|
5 |
+
from CGRtools import QueryContainer, ReactionContainer
|
6 |
+
from CGRtools.periodictable import ListElement
|
7 |
+
|
8 |
+
rules = []
|
9 |
+
|
10 |
+
|
11 |
+
def prepare():
|
12 |
+
"""
|
13 |
+
Creates and returns three query containers and appends a reaction container to the "rules" list
|
14 |
+
"""
|
15 |
+
q_ = QueryContainer()
|
16 |
+
p1_ = QueryContainer()
|
17 |
+
p2_ = QueryContainer()
|
18 |
+
rules.append(ReactionContainer((q_,), (p1_, p2_)))
|
19 |
+
return q_, p1_, p2_
|
20 |
+
|
21 |
+
|
22 |
+
# R-amide/ester formation
|
23 |
+
# [C](-[N,O;D23;Zs])(-[C])=[O]>>[A].[C]-[C](-[O])=[O]
|
24 |
+
q, p1, p2 = prepare()
|
25 |
+
q.add_atom('C')
|
26 |
+
q.add_atom('C')
|
27 |
+
q.add_atom('O')
|
28 |
+
q.add_atom(ListElement(['N', 'O']), hybridization=1, neighbors=(2, 3))
|
29 |
+
q.add_bond(1, 2, 1)
|
30 |
+
q.add_bond(2, 3, 2)
|
31 |
+
q.add_bond(2, 4, 1)
|
32 |
+
|
33 |
+
p1.add_atom('C')
|
34 |
+
p1.add_atom('C')
|
35 |
+
p1.add_atom('O')
|
36 |
+
p1.add_atom('O', _map=5)
|
37 |
+
p1.add_bond(1, 2, 1)
|
38 |
+
p1.add_bond(2, 3, 2)
|
39 |
+
p1.add_bond(2, 5, 1)
|
40 |
+
|
41 |
+
p2.add_atom('A', _map=4)
|
42 |
+
|
43 |
+
# acyl group addition with aromatic carbon's case (Friedel-Crafts)
|
44 |
+
# [C;Za]-[C](-[C])=[O]>>[C].[C]-[C](-[Cl])=[O]
|
45 |
+
q, p1, p2 = prepare()
|
46 |
+
q.add_atom('C')
|
47 |
+
q.add_atom('C')
|
48 |
+
q.add_atom('O')
|
49 |
+
q.add_atom('C', hybridization=4)
|
50 |
+
q.add_bond(1, 2, 1)
|
51 |
+
q.add_bond(2, 3, 2)
|
52 |
+
q.add_bond(2, 4, 1)
|
53 |
+
|
54 |
+
p1.add_atom('C')
|
55 |
+
p1.add_atom('C')
|
56 |
+
p1.add_atom('O')
|
57 |
+
p1.add_atom('Cl', _map=5)
|
58 |
+
p1.add_bond(1, 2, 1)
|
59 |
+
p1.add_bond(2, 3, 2)
|
60 |
+
p1.add_bond(2, 5, 1)
|
61 |
+
|
62 |
+
p2.add_atom('C', _map=4)
|
63 |
+
|
64 |
+
# Williamson reaction
|
65 |
+
# [C;Za]-[O]-[C;Zs;W0]>>[C]-[Br].[C]-[O]
|
66 |
+
q, p1, p2 = prepare()
|
67 |
+
q.add_atom('C', hybridization=4)
|
68 |
+
q.add_atom('O')
|
69 |
+
q.add_atom('C', hybridization=1, heteroatoms=1)
|
70 |
+
q.add_bond(1, 2, 1)
|
71 |
+
q.add_bond(2, 3, 1)
|
72 |
+
|
73 |
+
p1.add_atom('C')
|
74 |
+
p1.add_atom('O')
|
75 |
+
p1.add_bond(1, 2, 1)
|
76 |
+
|
77 |
+
p2.add_atom('C', _map=3)
|
78 |
+
p2.add_atom('Br')
|
79 |
+
p2.add_bond(3, 4, 1)
|
80 |
+
|
81 |
+
# Buchwald-Hartwig amination
|
82 |
+
# [N;D23;Zs;W0]-[C;Za]>>[C]-[Br].[N]
|
83 |
+
q, p1, p2 = prepare()
|
84 |
+
q.add_atom('N', heteroatoms=0, hybridization=1, neighbors=(2, 3))
|
85 |
+
q.add_atom('C', hybridization=4)
|
86 |
+
q.add_bond(1, 2, 1)
|
87 |
+
|
88 |
+
p1.add_atom('C', _map=2)
|
89 |
+
p1.add_atom('Br')
|
90 |
+
p1.add_bond(2, 3, 1)
|
91 |
+
|
92 |
+
p2.add_atom('N')
|
93 |
+
|
94 |
+
# imidazole imine atom's alkylation
|
95 |
+
# [C;r5](:[N;r5]-[C;Zs;W1]):[N;D2;r5]>>[C]-[Br].[N]:[C]:[N]
|
96 |
+
q, p1, p2 = prepare()
|
97 |
+
q.add_atom('N', rings_sizes=5)
|
98 |
+
q.add_atom('C', rings_sizes=5)
|
99 |
+
q.add_atom('N', rings_sizes=5, neighbors=2)
|
100 |
+
q.add_atom('C', hybridization=1, heteroatoms=(1, 2))
|
101 |
+
q.add_bond(1, 2, 4)
|
102 |
+
q.add_bond(2, 3, 4)
|
103 |
+
q.add_bond(1, 4, 1)
|
104 |
+
|
105 |
+
p1.add_atom('N')
|
106 |
+
p1.add_atom('C')
|
107 |
+
p1.add_atom('N')
|
108 |
+
p1.add_bond(1, 2, 4)
|
109 |
+
p1.add_bond(2, 3, 4)
|
110 |
+
|
111 |
+
p2.add_atom('C', _map=4)
|
112 |
+
p2.add_atom('Br')
|
113 |
+
p2.add_bond(4, 5, 1)
|
114 |
+
|
115 |
+
# Knoevenagel condensation (nitryl and carboxyl case)
|
116 |
+
# [C]=[C](-[C]#[N])-[C](-[O])=[O]>>[C]=[O].[C](-[C]#[N])-[C](-[O])=[O]
|
117 |
+
q, p1, p2 = prepare()
|
118 |
+
q.add_atom('C')
|
119 |
+
q.add_atom('C')
|
120 |
+
q.add_atom('C')
|
121 |
+
q.add_atom('N')
|
122 |
+
q.add_atom('C')
|
123 |
+
q.add_atom('O')
|
124 |
+
q.add_atom('O')
|
125 |
+
q.add_bond(1, 2, 2)
|
126 |
+
q.add_bond(2, 3, 1)
|
127 |
+
q.add_bond(3, 4, 3)
|
128 |
+
q.add_bond(2, 5, 1)
|
129 |
+
q.add_bond(5, 6, 2)
|
130 |
+
q.add_bond(5, 7, 1)
|
131 |
+
|
132 |
+
p1.add_atom('C', _map=2)
|
133 |
+
p1.add_atom('C')
|
134 |
+
p1.add_atom('N')
|
135 |
+
p1.add_atom('C')
|
136 |
+
p1.add_atom('O')
|
137 |
+
p1.add_atom('O')
|
138 |
+
p1.add_bond(2, 3, 1)
|
139 |
+
p1.add_bond(3, 4, 3)
|
140 |
+
p1.add_bond(2, 5, 1)
|
141 |
+
p1.add_bond(5, 6, 2)
|
142 |
+
p1.add_bond(5, 7, 1)
|
143 |
+
|
144 |
+
p2.add_atom('C', _map=1)
|
145 |
+
p2.add_atom('O', _map=8)
|
146 |
+
p2.add_bond(1, 8, 2)
|
147 |
+
|
148 |
+
# Knoevenagel condensation (double nitryl case)
|
149 |
+
# [C]=[C](-[C]#[N])-[C]#[N]>>[C]=[O].[C](-[C]#[N])-[C]#[N]
|
150 |
+
q, p1, p2 = prepare()
|
151 |
+
q.add_atom('C')
|
152 |
+
q.add_atom('C')
|
153 |
+
q.add_atom('C')
|
154 |
+
q.add_atom('N')
|
155 |
+
q.add_atom('C')
|
156 |
+
q.add_atom('N')
|
157 |
+
q.add_bond(1, 2, 2)
|
158 |
+
q.add_bond(2, 3, 1)
|
159 |
+
q.add_bond(3, 4, 3)
|
160 |
+
q.add_bond(2, 5, 1)
|
161 |
+
q.add_bond(5, 6, 3)
|
162 |
+
|
163 |
+
p1.add_atom('C', _map=2)
|
164 |
+
p1.add_atom('C')
|
165 |
+
p1.add_atom('N')
|
166 |
+
p1.add_atom('C')
|
167 |
+
p1.add_atom('N')
|
168 |
+
p1.add_bond(2, 3, 1)
|
169 |
+
p1.add_bond(3, 4, 3)
|
170 |
+
p1.add_bond(2, 5, 1)
|
171 |
+
p1.add_bond(5, 6, 3)
|
172 |
+
|
173 |
+
p2.add_atom('C', _map=1)
|
174 |
+
p2.add_atom('O', _map=8)
|
175 |
+
p2.add_bond(1, 8, 2)
|
176 |
+
|
177 |
+
# Knoevenagel condensation (double carboxyl case)
|
178 |
+
# [C]=[C](-[C](-[O])=[O])-[C](-[O])=[O]>>[C]=[O].[C](-[C](-[O])=[O])-[C](-[O])=[O]
|
179 |
+
q, p1, p2 = prepare()
|
180 |
+
q.add_atom('C')
|
181 |
+
q.add_atom('C')
|
182 |
+
q.add_atom('C')
|
183 |
+
q.add_atom('O')
|
184 |
+
q.add_atom('O')
|
185 |
+
q.add_atom('C')
|
186 |
+
q.add_atom('O')
|
187 |
+
q.add_atom('O')
|
188 |
+
q.add_bond(1, 2, 2)
|
189 |
+
q.add_bond(2, 3, 1)
|
190 |
+
q.add_bond(3, 4, 2)
|
191 |
+
q.add_bond(3, 5, 1)
|
192 |
+
q.add_bond(2, 6, 1)
|
193 |
+
q.add_bond(6, 7, 2)
|
194 |
+
q.add_bond(6, 8, 1)
|
195 |
+
|
196 |
+
p1.add_atom('C', _map=2)
|
197 |
+
p1.add_atom('C')
|
198 |
+
p1.add_atom('O')
|
199 |
+
p1.add_atom('O')
|
200 |
+
p1.add_atom('C')
|
201 |
+
p1.add_atom('O')
|
202 |
+
p1.add_atom('O')
|
203 |
+
p1.add_bond(2, 3, 1)
|
204 |
+
p1.add_bond(3, 4, 2)
|
205 |
+
p1.add_bond(3, 5, 1)
|
206 |
+
p1.add_bond(2, 6, 1)
|
207 |
+
p1.add_bond(6, 7, 2)
|
208 |
+
p1.add_bond(6, 8, 1)
|
209 |
+
|
210 |
+
p2.add_atom('C', _map=1)
|
211 |
+
p2.add_atom('O', _map=9)
|
212 |
+
p2.add_bond(1, 9, 2)
|
213 |
+
|
214 |
+
# heterocyclization with guanidine
|
215 |
+
# [c]((-[N;W0;Zs])@[n]@[c](-[N;D1])@[c;W0])@[n]@[c]-[O; D1]>>[C](-[N])(=[N])-[N].[C](#[N])-[C]-[C](-[O])=[O]
|
216 |
+
q, p1, p2 = prepare()
|
217 |
+
q.add_atom('C')
|
218 |
+
q.add_atom('N', heteroatoms=0, hybridization=1)
|
219 |
+
q.add_atom('N')
|
220 |
+
q.add_atom('C')
|
221 |
+
q.add_atom('N', neighbors=1)
|
222 |
+
q.add_atom('C', heteroatoms=0)
|
223 |
+
q.add_atom('N')
|
224 |
+
q.add_atom('C')
|
225 |
+
q.add_atom('O', neighbors=1)
|
226 |
+
q.add_bond(1, 2, 1)
|
227 |
+
q.add_bond(1, 3, 4)
|
228 |
+
q.add_bond(3, 4, 4)
|
229 |
+
q.add_bond(4, 5, 1)
|
230 |
+
q.add_bond(4, 6, 4)
|
231 |
+
q.add_bond(1, 7, 4)
|
232 |
+
q.add_bond(7, 8, 4)
|
233 |
+
q.add_bond(8, 9, 1)
|
234 |
+
|
235 |
+
p1.add_atom('C')
|
236 |
+
p1.add_atom('N')
|
237 |
+
p1.add_atom('N')
|
238 |
+
p1.add_atom('N', _map=7)
|
239 |
+
p1.add_bond(1, 2, 1)
|
240 |
+
p1.add_bond(1, 3, 2)
|
241 |
+
p1.add_bond(1, 7, 1)
|
242 |
+
|
243 |
+
p2.add_atom('C', _map=4)
|
244 |
+
p2.add_atom('N')
|
245 |
+
p2.add_atom('C')
|
246 |
+
p2.add_atom('C', _map=8)
|
247 |
+
p2.add_atom('O', _map=9)
|
248 |
+
p2.add_atom('O')
|
249 |
+
p2.add_bond(4, 5, 3)
|
250 |
+
p2.add_bond(4, 6, 1)
|
251 |
+
p2.add_bond(6, 8, 1)
|
252 |
+
p2.add_bond(8, 9, 2)
|
253 |
+
p2.add_bond(8, 10, 1)
|
254 |
+
|
255 |
+
# alkylation of amine
|
256 |
+
# [C]-[N]-[C]>>[C]-[N].[C]-[Br]
|
257 |
+
q, p1, p2 = prepare()
|
258 |
+
q.add_atom('C')
|
259 |
+
q.add_atom('N')
|
260 |
+
q.add_atom('C')
|
261 |
+
q.add_atom('C')
|
262 |
+
q.add_bond(1, 2, 1)
|
263 |
+
q.add_bond(2, 3, 1)
|
264 |
+
q.add_bond(2, 4, 1)
|
265 |
+
|
266 |
+
p1.add_atom('C')
|
267 |
+
p1.add_atom('N')
|
268 |
+
p1.add_atom('C')
|
269 |
+
p1.add_bond(1, 2, 1)
|
270 |
+
p1.add_bond(2, 3, 1)
|
271 |
+
|
272 |
+
p2.add_atom('C', _map=4)
|
273 |
+
p2.add_atom('Cl')
|
274 |
+
p2.add_bond(4, 5, 1)
|
275 |
+
|
276 |
+
# Synthesis of guanidines
|
277 |
+
#
|
278 |
+
q, p1, p2 = prepare()
|
279 |
+
q.add_atom('N')
|
280 |
+
q.add_atom('C')
|
281 |
+
q.add_atom('N', hybridization=1)
|
282 |
+
q.add_atom('N', hybridization=1)
|
283 |
+
q.add_bond(1, 2, 2)
|
284 |
+
q.add_bond(2, 3, 1)
|
285 |
+
q.add_bond(2, 4, 1)
|
286 |
+
|
287 |
+
p1.add_atom('N')
|
288 |
+
p1.add_atom('C')
|
289 |
+
p1.add_atom('N')
|
290 |
+
p1.add_bond(1, 2, 3)
|
291 |
+
p1.add_bond(2, 3, 1)
|
292 |
+
|
293 |
+
p2.add_atom('N', _map=4)
|
294 |
+
|
295 |
+
# Grignard reaction with nitrile
|
296 |
+
#
|
297 |
+
q, p1, p2 = prepare()
|
298 |
+
q.add_atom('C')
|
299 |
+
q.add_atom('C')
|
300 |
+
q.add_atom('O')
|
301 |
+
q.add_atom('C')
|
302 |
+
q.add_bond(1, 2, 1)
|
303 |
+
q.add_bond(2, 3, 2)
|
304 |
+
q.add_bond(2, 4, 1)
|
305 |
+
|
306 |
+
p1.add_atom('C')
|
307 |
+
p1.add_atom('C')
|
308 |
+
p1.add_atom('N')
|
309 |
+
p1.add_bond(1, 2, 1)
|
310 |
+
p1.add_bond(2, 3, 3)
|
311 |
+
|
312 |
+
p2.add_atom('C', _map=4)
|
313 |
+
p2.add_atom('Br')
|
314 |
+
p2.add_bond(4, 5, 1)
|
315 |
+
|
316 |
+
# Alkylation of alpha-carbon atom of nitrile
|
317 |
+
#
|
318 |
+
q, p1, p2 = prepare()
|
319 |
+
q.add_atom('N')
|
320 |
+
q.add_atom('C')
|
321 |
+
q.add_atom('C', neighbors=(3, 4))
|
322 |
+
q.add_atom('C', hybridization=1)
|
323 |
+
q.add_bond(1, 2, 3)
|
324 |
+
q.add_bond(2, 3, 1)
|
325 |
+
q.add_bond(3, 4, 1)
|
326 |
+
|
327 |
+
p1.add_atom('N')
|
328 |
+
p1.add_atom('C')
|
329 |
+
p1.add_atom('C')
|
330 |
+
p1.add_bond(1, 2, 3)
|
331 |
+
p1.add_bond(2, 3, 1)
|
332 |
+
|
333 |
+
p2.add_atom('C', _map=4)
|
334 |
+
p2.add_atom('Cl')
|
335 |
+
p2.add_bond(4, 5, 1)
|
336 |
+
|
337 |
+
# Gomberg-Bachmann reaction
|
338 |
+
#
|
339 |
+
q, p1, p2 = prepare()
|
340 |
+
q.add_atom('C', hybridization=4, heteroatoms=0)
|
341 |
+
q.add_atom('C', hybridization=4, heteroatoms=0)
|
342 |
+
q.add_bond(1, 2, 1)
|
343 |
+
|
344 |
+
p1.add_atom('C')
|
345 |
+
p1.add_atom('N', _map=3)
|
346 |
+
p1.add_bond(1, 3, 1)
|
347 |
+
|
348 |
+
p2.add_atom('C', _map=2)
|
349 |
+
|
350 |
+
# Cyclocondensation
|
351 |
+
#
|
352 |
+
q, p1, p2 = prepare()
|
353 |
+
q.add_atom('N', neighbors=2)
|
354 |
+
q.add_atom('C')
|
355 |
+
q.add_atom('C')
|
356 |
+
q.add_atom('C')
|
357 |
+
q.add_atom('N')
|
358 |
+
q.add_atom('C')
|
359 |
+
q.add_atom('C')
|
360 |
+
q.add_atom('O', neighbors=1)
|
361 |
+
q.add_bond(1, 2, 1)
|
362 |
+
q.add_bond(2, 3, 1)
|
363 |
+
q.add_bond(3, 4, 1)
|
364 |
+
q.add_bond(4, 5, 2)
|
365 |
+
q.add_bond(5, 6, 1)
|
366 |
+
q.add_bond(6, 7, 1)
|
367 |
+
q.add_bond(7, 8, 2)
|
368 |
+
q.add_bond(1, 7, 1)
|
369 |
+
|
370 |
+
p1.add_atom('N')
|
371 |
+
p1.add_atom('C')
|
372 |
+
p1.add_atom('C')
|
373 |
+
p1.add_atom('C')
|
374 |
+
p1.add_atom('O', _map=9)
|
375 |
+
p1.add_bond(1, 2, 1)
|
376 |
+
p1.add_bond(2, 3, 1)
|
377 |
+
p1.add_bond(3, 4, 1)
|
378 |
+
p1.add_bond(4, 9, 2)
|
379 |
+
|
380 |
+
p2.add_atom('N', _map=5)
|
381 |
+
p2.add_atom('C')
|
382 |
+
p2.add_atom('C')
|
383 |
+
p2.add_atom('O')
|
384 |
+
p2.add_atom('O', _map=10)
|
385 |
+
p2.add_bond(5, 6, 1)
|
386 |
+
p2.add_bond(6, 7, 1)
|
387 |
+
p2.add_bond(7, 8, 2)
|
388 |
+
p2.add_bond(7, 10, 1)
|
389 |
+
|
390 |
+
# heterocyclization dicarboxylic acids
|
391 |
+
#
|
392 |
+
q, p1, p2 = prepare()
|
393 |
+
q.add_atom('C', rings_sizes=(5, 6))
|
394 |
+
q.add_atom('O')
|
395 |
+
q.add_atom(ListElement(['O', 'N']))
|
396 |
+
q.add_atom('C', rings_sizes=(5, 6))
|
397 |
+
q.add_atom('O')
|
398 |
+
q.add_bond(1, 2, 2)
|
399 |
+
q.add_bond(1, 3, 1)
|
400 |
+
q.add_bond(3, 4, 1)
|
401 |
+
q.add_bond(4, 5, 2)
|
402 |
+
|
403 |
+
p1.add_atom('C')
|
404 |
+
p1.add_atom('O')
|
405 |
+
p1.add_atom('O', _map=6)
|
406 |
+
p1.add_bond(1, 2, 2)
|
407 |
+
p1.add_bond(1, 6, 1)
|
408 |
+
|
409 |
+
p2.add_atom('C', _map=4)
|
410 |
+
p2.add_atom('O')
|
411 |
+
p2.add_atom('O', _map=7)
|
412 |
+
p2.add_bond(4, 5, 2)
|
413 |
+
p2.add_bond(4, 7, 1)
|
414 |
+
|
415 |
+
__all__ = ['rules']
|
SynTool/chem/reaction_rules/manual/transformations.py
ADDED
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing hardcoded transformation reaction rules
|
3 |
+
"""
|
4 |
+
|
5 |
+
from CGRtools import QueryContainer, ReactionContainer
|
6 |
+
from CGRtools.periodictable import ListElement
|
7 |
+
|
8 |
+
rules = []
|
9 |
+
|
10 |
+
|
11 |
+
def prepare():
|
12 |
+
"""
|
13 |
+
Creates and returns three query containers and appends a reaction container to the "rules" list
|
14 |
+
"""
|
15 |
+
q_ = QueryContainer()
|
16 |
+
p_ = QueryContainer()
|
17 |
+
rules.append(ReactionContainer((q_,), (p_,)))
|
18 |
+
return q_, p_
|
19 |
+
|
20 |
+
|
21 |
+
# aryl nitro reduction
|
22 |
+
# [C;Za;W1]-[N;D1]>>[O-]-[N+](-[C])=[O]
|
23 |
+
q, p = prepare()
|
24 |
+
q.add_atom('N', neighbors=1)
|
25 |
+
q.add_atom('C', hybridization=4, heteroatoms=1)
|
26 |
+
q.add_bond(1, 2, 1)
|
27 |
+
|
28 |
+
p.add_atom('N', charge=1)
|
29 |
+
p.add_atom('C')
|
30 |
+
p.add_atom('O', charge=-1)
|
31 |
+
p.add_atom('O')
|
32 |
+
p.add_bond(1, 2, 1)
|
33 |
+
p.add_bond(1, 3, 1)
|
34 |
+
p.add_bond(1, 4, 2)
|
35 |
+
|
36 |
+
# aryl nitration
|
37 |
+
# [O-]-[N+](=[O])-[C;Za;W12]>>[C]
|
38 |
+
q, p = prepare()
|
39 |
+
q.add_atom('N', charge=1)
|
40 |
+
q.add_atom('C', hybridization=4, heteroatoms=(1, 2))
|
41 |
+
q.add_atom('O', charge=-1)
|
42 |
+
q.add_atom('O')
|
43 |
+
q.add_bond(1, 2, 1)
|
44 |
+
q.add_bond(1, 3, 1)
|
45 |
+
q.add_bond(1, 4, 2)
|
46 |
+
|
47 |
+
p.add_atom('C', _map=2)
|
48 |
+
|
49 |
+
# Beckmann rearrangement (oxime -> amide)
|
50 |
+
# [C]-[N;D2]-[C]=[O]>>[O]-[N]=[C]-[C]
|
51 |
+
q, p = prepare()
|
52 |
+
q.add_atom('C')
|
53 |
+
q.add_atom('N', neighbors=2)
|
54 |
+
q.add_atom('O')
|
55 |
+
q.add_atom('C')
|
56 |
+
q.add_bond(1, 2, 1)
|
57 |
+
q.add_bond(1, 3, 2)
|
58 |
+
q.add_bond(2, 4, 1)
|
59 |
+
|
60 |
+
p.add_atom('C')
|
61 |
+
p.add_atom('N')
|
62 |
+
p.add_atom('O')
|
63 |
+
p.add_atom('C')
|
64 |
+
p.add_bond(1, 2, 2)
|
65 |
+
p.add_bond(2, 3, 1)
|
66 |
+
p.add_bond(1, 4, 1)
|
67 |
+
|
68 |
+
# aldehydes or ketones into oxime/imine reaction
|
69 |
+
# [C;Zd;W1]=[N]>>[C]=[O]
|
70 |
+
q, p = prepare()
|
71 |
+
q.add_atom('C', hybridization=2, heteroatoms=1)
|
72 |
+
q.add_atom('N')
|
73 |
+
q.add_bond(1, 2, 2)
|
74 |
+
|
75 |
+
p.add_atom('C')
|
76 |
+
p.add_atom('O', _map=3)
|
77 |
+
p.add_bond(1, 3, 2)
|
78 |
+
|
79 |
+
# addition of halogen atom into phenol ring (orto)
|
80 |
+
# [C](-[Cl,F,Br,I;D1]):[C]-[O,N;Zs]>>[C](-[A]):[C]
|
81 |
+
q, p = prepare()
|
82 |
+
q.add_atom(ListElement(['O', 'N']), hybridization=1)
|
83 |
+
q.add_atom('C')
|
84 |
+
q.add_atom('C')
|
85 |
+
q.add_atom(ListElement(['Cl', 'F', 'Br', 'I']), neighbors=1)
|
86 |
+
q.add_bond(1, 2, 1)
|
87 |
+
q.add_bond(2, 3, 4)
|
88 |
+
q.add_bond(3, 4, 1)
|
89 |
+
|
90 |
+
p.add_atom('A')
|
91 |
+
p.add_atom('C')
|
92 |
+
p.add_atom('C')
|
93 |
+
p.add_bond(1, 2, 1)
|
94 |
+
p.add_bond(2, 3, 4)
|
95 |
+
|
96 |
+
# addition of halogen atom into phenol ring (para)
|
97 |
+
# [C](:[C]:[C]:[C]-[O,N;Zs])-[Cl,F,Br,I;D1]>>[A]-[C]:[C]:[C]:[C]
|
98 |
+
q, p = prepare()
|
99 |
+
q.add_atom(ListElement(['O', 'N']), hybridization=1)
|
100 |
+
q.add_atom('C')
|
101 |
+
q.add_atom('C')
|
102 |
+
q.add_atom('C')
|
103 |
+
q.add_atom('C')
|
104 |
+
q.add_atom(ListElement(['Cl', 'F', 'Br', 'I']), neighbors=1)
|
105 |
+
q.add_bond(1, 2, 1)
|
106 |
+
q.add_bond(2, 3, 4)
|
107 |
+
q.add_bond(3, 4, 4)
|
108 |
+
q.add_bond(4, 5, 4)
|
109 |
+
q.add_bond(5, 6, 1)
|
110 |
+
|
111 |
+
p.add_atom('A')
|
112 |
+
p.add_atom('C')
|
113 |
+
p.add_atom('C')
|
114 |
+
p.add_atom('C')
|
115 |
+
p.add_atom('C')
|
116 |
+
p.add_bond(1, 2, 1)
|
117 |
+
p.add_bond(2, 3, 4)
|
118 |
+
p.add_bond(3, 4, 4)
|
119 |
+
p.add_bond(4, 5, 4)
|
120 |
+
|
121 |
+
# hard reduction of Ar-ketones
|
122 |
+
# [C;Za]-[C;D2;Zs;W0]>>[C]-[C]=[O]
|
123 |
+
q, p = prepare()
|
124 |
+
q.add_atom('C', hybridization=4)
|
125 |
+
q.add_atom('C', hybridization=1, neighbors=2, heteroatoms=0)
|
126 |
+
q.add_bond(1, 2, 1)
|
127 |
+
|
128 |
+
p.add_atom('C')
|
129 |
+
p.add_atom('C')
|
130 |
+
p.add_atom('O')
|
131 |
+
p.add_bond(1, 2, 1)
|
132 |
+
p.add_bond(2, 3, 2)
|
133 |
+
|
134 |
+
# reduction of alpha-hydroxy pyridine
|
135 |
+
# [C;W1]:[N;H0;r6]>>[C](:[N])-[O]
|
136 |
+
q, p = prepare()
|
137 |
+
q.add_atom('C', heteroatoms=1)
|
138 |
+
q.add_atom('N', rings_sizes=6, hydrogens=0)
|
139 |
+
q.add_bond(1, 2, 4)
|
140 |
+
|
141 |
+
p.add_atom('C')
|
142 |
+
p.add_atom('N')
|
143 |
+
p.add_atom('O')
|
144 |
+
p.add_bond(1, 2, 4)
|
145 |
+
p.add_bond(1, 3, 1)
|
146 |
+
|
147 |
+
# Reduction of alkene
|
148 |
+
# [C]-[C;D23;Zs;W0]-[C;D123;Zs;W0]>>[C](-[C])=[C]
|
149 |
+
q, p = prepare()
|
150 |
+
q.add_atom('C')
|
151 |
+
q.add_atom('C', heteroatoms=0, neighbors=(2, 3), hybridization=1)
|
152 |
+
q.add_atom('C', heteroatoms=0, neighbors=(1, 2, 3), hybridization=1)
|
153 |
+
q.add_bond(1, 2, 1)
|
154 |
+
q.add_bond(2, 3, 1)
|
155 |
+
|
156 |
+
p.add_atom('C')
|
157 |
+
p.add_atom('C')
|
158 |
+
p.add_atom('C')
|
159 |
+
p.add_bond(1, 2, 1)
|
160 |
+
p.add_bond(2, 3, 2)
|
161 |
+
|
162 |
+
# Kolbe-Schmitt reaction
|
163 |
+
# [C](:[C]-[O;D1])-[C](=[O])-[O;D1]>>[C](-[O]):[C]
|
164 |
+
q, p = prepare()
|
165 |
+
q.add_atom('O', neighbors=1)
|
166 |
+
q.add_atom('C')
|
167 |
+
q.add_atom('C')
|
168 |
+
q.add_atom('C')
|
169 |
+
q.add_atom('O', neighbors=1)
|
170 |
+
q.add_atom('O')
|
171 |
+
q.add_bond(1, 2, 1)
|
172 |
+
q.add_bond(2, 3, 4)
|
173 |
+
q.add_bond(3, 4, 1)
|
174 |
+
q.add_bond(4, 5, 1)
|
175 |
+
q.add_bond(4, 6, 2)
|
176 |
+
|
177 |
+
p.add_atom('O')
|
178 |
+
p.add_atom('C')
|
179 |
+
p.add_atom('C')
|
180 |
+
p.add_bond(1, 2, 1)
|
181 |
+
p.add_bond(2, 3, 4)
|
182 |
+
|
183 |
+
# reduction of carboxylic acid
|
184 |
+
# [O;D1]-[C;D2]-[C]>>[C]-[C](-[O])=[O]
|
185 |
+
q, p = prepare()
|
186 |
+
q.add_atom('C')
|
187 |
+
q.add_atom('C', neighbors=2)
|
188 |
+
q.add_atom('O', neighbors=1)
|
189 |
+
q.add_bond(1, 2, 1)
|
190 |
+
q.add_bond(2, 3, 1)
|
191 |
+
|
192 |
+
p.add_atom('C')
|
193 |
+
p.add_atom('C')
|
194 |
+
p.add_atom('O')
|
195 |
+
p.add_atom('O')
|
196 |
+
p.add_bond(1, 2, 1)
|
197 |
+
p.add_bond(2, 3, 1)
|
198 |
+
p.add_bond(2, 4, 2)
|
199 |
+
|
200 |
+
# halogenation of alcohols
|
201 |
+
# [C;Zs]-[Cl,Br;D1]>>[C]-[O]
|
202 |
+
q, p = prepare()
|
203 |
+
q.add_atom('C', hybridization=1, heteroatoms=1)
|
204 |
+
q.add_atom(ListElement(['Cl', 'Br']), neighbors=1)
|
205 |
+
q.add_bond(1, 2, 1)
|
206 |
+
|
207 |
+
p.add_atom('C')
|
208 |
+
p.add_atom('O', _map=3)
|
209 |
+
p.add_bond(1, 3, 1)
|
210 |
+
|
211 |
+
# Kolbe nitrilation
|
212 |
+
# [N]#[C]-[C;Zs;W0]>>[Br]-[C]
|
213 |
+
q, p = prepare()
|
214 |
+
q.add_atom('C', heteroatoms=0, hybridization=1)
|
215 |
+
q.add_atom('C')
|
216 |
+
q.add_atom('N')
|
217 |
+
q.add_bond(1, 2, 1)
|
218 |
+
q.add_bond(2, 3, 3)
|
219 |
+
|
220 |
+
p.add_atom('C')
|
221 |
+
p.add_atom('Br', _map=4)
|
222 |
+
p.add_bond(1, 4, 1)
|
223 |
+
|
224 |
+
# Nitrile hydrolysis
|
225 |
+
# [O;D1]-[C]=[O]>>[N]#[C]
|
226 |
+
q, p = prepare()
|
227 |
+
q.add_atom('C')
|
228 |
+
q.add_atom('O', neighbors=1)
|
229 |
+
q.add_atom('O')
|
230 |
+
q.add_bond(1, 2, 1)
|
231 |
+
q.add_bond(1, 3, 2)
|
232 |
+
|
233 |
+
p.add_atom('C')
|
234 |
+
p.add_atom('N', _map=4)
|
235 |
+
p.add_bond(1, 4, 3)
|
236 |
+
|
237 |
+
# sulfamidation
|
238 |
+
# [c]-[S](=[O])(=[O])-[N]>>[c]
|
239 |
+
q, p = prepare()
|
240 |
+
q.add_atom('C', hybridization=4)
|
241 |
+
q.add_atom('S')
|
242 |
+
q.add_atom('O')
|
243 |
+
q.add_atom('O')
|
244 |
+
q.add_atom('N', neighbors=1)
|
245 |
+
q.add_bond(1, 2, 1)
|
246 |
+
q.add_bond(2, 3, 2)
|
247 |
+
q.add_bond(2, 4, 2)
|
248 |
+
q.add_bond(2, 5, 1)
|
249 |
+
|
250 |
+
p.add_atom('C')
|
251 |
+
|
252 |
+
# Ring expansion rearrangement
|
253 |
+
#
|
254 |
+
q, p = prepare()
|
255 |
+
q.add_atom('C')
|
256 |
+
q.add_atom('N')
|
257 |
+
q.add_atom('C', rings_sizes=6)
|
258 |
+
q.add_atom('C')
|
259 |
+
q.add_atom('O')
|
260 |
+
q.add_atom('C')
|
261 |
+
q.add_atom('C')
|
262 |
+
q.add_bond(1, 2, 1)
|
263 |
+
q.add_bond(2, 3, 1)
|
264 |
+
q.add_bond(3, 4, 1)
|
265 |
+
q.add_bond(4, 5, 2)
|
266 |
+
q.add_bond(3, 6, 1)
|
267 |
+
q.add_bond(4, 7, 1)
|
268 |
+
|
269 |
+
p.add_atom('C')
|
270 |
+
p.add_atom('N')
|
271 |
+
p.add_atom('C')
|
272 |
+
p.add_atom('C')
|
273 |
+
p.add_atom('O')
|
274 |
+
p.add_atom('C')
|
275 |
+
p.add_atom('C')
|
276 |
+
p.add_bond(1, 2, 1)
|
277 |
+
p.add_bond(2, 3, 2)
|
278 |
+
p.add_bond(3, 4, 1)
|
279 |
+
p.add_bond(4, 5, 1)
|
280 |
+
p.add_bond(4, 6, 1)
|
281 |
+
p.add_bond(4, 7, 1)
|
282 |
+
|
283 |
+
# hydrolysis of bromide alkyl
|
284 |
+
#
|
285 |
+
q, p = prepare()
|
286 |
+
q.add_atom('C', hybridization=1)
|
287 |
+
q.add_atom('O', neighbors=1)
|
288 |
+
q.add_bond(1, 2, 1)
|
289 |
+
|
290 |
+
p.add_atom('C')
|
291 |
+
p.add_atom('Br')
|
292 |
+
p.add_bond(1, 2, 1)
|
293 |
+
|
294 |
+
# Condensation of ketones/aldehydes and amines into imines
|
295 |
+
#
|
296 |
+
q, p = prepare()
|
297 |
+
q.add_atom('N', neighbors=(1, 2))
|
298 |
+
q.add_atom('C', neighbors=(2, 3), heteroatoms=1)
|
299 |
+
q.add_bond(1, 2, 2)
|
300 |
+
|
301 |
+
p.add_atom('C', _map=2)
|
302 |
+
p.add_atom('O')
|
303 |
+
p.add_bond(2, 3, 2)
|
304 |
+
|
305 |
+
# Halogenation of alkanes
|
306 |
+
#
|
307 |
+
q, p = prepare()
|
308 |
+
q.add_atom('C', hybridization=1)
|
309 |
+
q.add_atom(ListElement(['F', 'Cl', 'Br']))
|
310 |
+
q.add_bond(1, 2, 1)
|
311 |
+
|
312 |
+
p.add_atom('C')
|
313 |
+
|
314 |
+
# heterocyclization
|
315 |
+
#
|
316 |
+
q, p = prepare()
|
317 |
+
q.add_atom('N', heteroatoms=0, hybridization=1, neighbors=(2, 3))
|
318 |
+
q.add_atom('C', heteroatoms=2)
|
319 |
+
q.add_atom('N', heteroatoms=0, neighbors=2)
|
320 |
+
q.add_bond(1, 2, 1)
|
321 |
+
q.add_bond(2, 3, 2)
|
322 |
+
|
323 |
+
p.add_atom('N')
|
324 |
+
p.add_atom('C')
|
325 |
+
p.add_atom('N')
|
326 |
+
p.add_atom('O')
|
327 |
+
p.add_bond(1, 2, 1)
|
328 |
+
p.add_bond(2, 4, 2)
|
329 |
+
|
330 |
+
# Reduction of nitrile
|
331 |
+
#
|
332 |
+
q, p = prepare()
|
333 |
+
q.add_atom('N', neighbors=1)
|
334 |
+
q.add_atom('C')
|
335 |
+
q.add_atom('C', hybridization=1)
|
336 |
+
q.add_bond(1, 2, 1)
|
337 |
+
q.add_bond(2, 3, 1)
|
338 |
+
|
339 |
+
p.add_atom('N')
|
340 |
+
p.add_atom('C')
|
341 |
+
p.add_atom('C')
|
342 |
+
p.add_bond(1, 2, 3)
|
343 |
+
p.add_bond(2, 3, 1)
|
344 |
+
|
345 |
+
# SPECIAL CASE
|
346 |
+
# Reduction of nitrile into methylamine
|
347 |
+
#
|
348 |
+
q, p = prepare()
|
349 |
+
q.add_atom('C', neighbors=1)
|
350 |
+
q.add_atom('N', neighbors=2)
|
351 |
+
q.add_atom('C')
|
352 |
+
q.add_atom('C', hybridization=1)
|
353 |
+
q.add_bond(1, 2, 1)
|
354 |
+
q.add_bond(2, 3, 1)
|
355 |
+
q.add_bond(3, 4, 1)
|
356 |
+
|
357 |
+
p.add_atom('N', _map=2)
|
358 |
+
p.add_atom('C')
|
359 |
+
p.add_atom('C')
|
360 |
+
p.add_bond(2, 3, 3)
|
361 |
+
p.add_bond(3, 4, 1)
|
362 |
+
|
363 |
+
# methylation of amides
|
364 |
+
#
|
365 |
+
q, p = prepare()
|
366 |
+
q.add_atom('O')
|
367 |
+
q.add_atom('C')
|
368 |
+
q.add_atom('N')
|
369 |
+
q.add_atom('C', neighbors=1)
|
370 |
+
q.add_bond(1, 2, 2)
|
371 |
+
q.add_bond(2, 3, 1)
|
372 |
+
q.add_bond(3, 4, 1)
|
373 |
+
|
374 |
+
p.add_atom('O')
|
375 |
+
p.add_atom('C')
|
376 |
+
p.add_atom('N')
|
377 |
+
p.add_bond(1, 2, 2)
|
378 |
+
p.add_bond(2, 3, 1)
|
379 |
+
|
380 |
+
# hydrocyanation of alkenes
|
381 |
+
#
|
382 |
+
q, p = prepare()
|
383 |
+
q.add_atom('C', hybridization=1)
|
384 |
+
q.add_atom('C')
|
385 |
+
q.add_atom('C')
|
386 |
+
q.add_atom('N')
|
387 |
+
q.add_bond(1, 2, 1)
|
388 |
+
q.add_bond(2, 3, 1)
|
389 |
+
q.add_bond(3, 4, 3)
|
390 |
+
|
391 |
+
p.add_atom('C')
|
392 |
+
p.add_atom('C')
|
393 |
+
p.add_bond(1, 2, 2)
|
394 |
+
|
395 |
+
# decarbocylation (alpha atom of nitrile)
|
396 |
+
#
|
397 |
+
q, p = prepare()
|
398 |
+
q.add_atom('N')
|
399 |
+
q.add_atom('C')
|
400 |
+
q.add_atom('C', neighbors=2)
|
401 |
+
q.add_bond(1, 2, 3)
|
402 |
+
q.add_bond(2, 3, 1)
|
403 |
+
|
404 |
+
p.add_atom('N')
|
405 |
+
p.add_atom('C')
|
406 |
+
p.add_atom('C')
|
407 |
+
p.add_atom('C')
|
408 |
+
p.add_atom('O')
|
409 |
+
p.add_atom('O')
|
410 |
+
p.add_bond(1, 2, 3)
|
411 |
+
p.add_bond(2, 3, 1)
|
412 |
+
p.add_bond(3, 4, 1)
|
413 |
+
p.add_bond(4, 5, 2)
|
414 |
+
p.add_bond(4, 6, 1)
|
415 |
+
|
416 |
+
# Bichler-Napieralski reaction
|
417 |
+
#
|
418 |
+
q, p = prepare()
|
419 |
+
q.add_atom('C', rings_sizes=(6,))
|
420 |
+
q.add_atom('C', rings_sizes=(6,))
|
421 |
+
q.add_atom('N', rings_sizes=(6,), neighbors=2)
|
422 |
+
q.add_atom('C')
|
423 |
+
q.add_atom('C')
|
424 |
+
q.add_atom('C')
|
425 |
+
q.add_atom('O')
|
426 |
+
q.add_atom('O')
|
427 |
+
q.add_atom('C')
|
428 |
+
q.add_atom('O', neighbors=1)
|
429 |
+
q.add_bond(1, 2, 4)
|
430 |
+
q.add_bond(2, 3, 1)
|
431 |
+
q.add_bond(3, 4, 1)
|
432 |
+
q.add_bond(4, 5, 2)
|
433 |
+
q.add_bond(5, 6, 1)
|
434 |
+
q.add_bond(6, 7, 2)
|
435 |
+
q.add_bond(6, 8, 1)
|
436 |
+
q.add_bond(5, 9, 4)
|
437 |
+
q.add_bond(9, 10, 1)
|
438 |
+
q.add_bond(1, 9, 1)
|
439 |
+
|
440 |
+
p.add_atom('C')
|
441 |
+
p.add_atom('C')
|
442 |
+
p.add_atom('N')
|
443 |
+
p.add_atom('C')
|
444 |
+
p.add_atom('C')
|
445 |
+
p.add_atom('C')
|
446 |
+
p.add_atom('O')
|
447 |
+
p.add_atom('O')
|
448 |
+
p.add_atom('C')
|
449 |
+
p.add_atom('O')
|
450 |
+
p.add_atom('O')
|
451 |
+
p.add_bond(1, 2, 4)
|
452 |
+
p.add_bond(2, 3, 1)
|
453 |
+
p.add_bond(3, 4, 1)
|
454 |
+
p.add_bond(4, 5, 2)
|
455 |
+
p.add_bond(5, 6, 1)
|
456 |
+
p.add_bond(6, 7, 2)
|
457 |
+
p.add_bond(6, 8, 1)
|
458 |
+
p.add_bond(5, 9, 1)
|
459 |
+
p.add_bond(9, 10, 2)
|
460 |
+
p.add_bond(9, 11, 1)
|
461 |
+
|
462 |
+
# heterocyclization in Prins reaction
|
463 |
+
#
|
464 |
+
q, p = prepare()
|
465 |
+
q.add_atom('C')
|
466 |
+
q.add_atom('O')
|
467 |
+
q.add_atom('C')
|
468 |
+
q.add_atom(ListElement(['N', 'O']), neighbors=2)
|
469 |
+
q.add_atom('C')
|
470 |
+
q.add_atom('C')
|
471 |
+
q.add_bond(1, 2, 1)
|
472 |
+
q.add_bond(2, 3, 1)
|
473 |
+
q.add_bond(3, 4, 1)
|
474 |
+
q.add_bond(4, 5, 1)
|
475 |
+
q.add_bond(5, 6, 1)
|
476 |
+
q.add_bond(1, 6, 1)
|
477 |
+
|
478 |
+
p.add_atom('C')
|
479 |
+
p.add_atom('C', _map=5)
|
480 |
+
p.add_bond(1, 5, 2)
|
481 |
+
|
482 |
+
# recyclization of tetrahydropyran through an opening the ring and dehydration
|
483 |
+
#
|
484 |
+
q, p = prepare()
|
485 |
+
q.add_atom('C')
|
486 |
+
q.add_atom('C')
|
487 |
+
q.add_atom('C')
|
488 |
+
q.add_atom(ListElement(['N', 'O']))
|
489 |
+
q.add_atom('C')
|
490 |
+
q.add_atom('C')
|
491 |
+
q.add_bond(1, 2, 1)
|
492 |
+
q.add_bond(2, 3, 1)
|
493 |
+
q.add_bond(3, 4, 1)
|
494 |
+
q.add_bond(4, 5, 1)
|
495 |
+
q.add_bond(5, 6, 1)
|
496 |
+
q.add_bond(1, 6, 2)
|
497 |
+
|
498 |
+
p.add_atom('C')
|
499 |
+
p.add_atom('C')
|
500 |
+
p.add_atom('C')
|
501 |
+
p.add_atom('A')
|
502 |
+
p.add_atom('C')
|
503 |
+
p.add_atom('C')
|
504 |
+
p.add_atom('O')
|
505 |
+
p.add_bond(1, 2, 1)
|
506 |
+
p.add_bond(1, 7, 1)
|
507 |
+
p.add_bond(3, 7, 1)
|
508 |
+
p.add_bond(3, 4, 1)
|
509 |
+
p.add_bond(4, 5, 1)
|
510 |
+
p.add_bond(5, 6, 1)
|
511 |
+
p.add_bond(1, 6, 1)
|
512 |
+
|
513 |
+
# alkenes + h2o/hHal
|
514 |
+
#
|
515 |
+
q, p = prepare()
|
516 |
+
q.add_atom('C', hybridization=1)
|
517 |
+
q.add_atom('C', hybridization=1)
|
518 |
+
q.add_atom(ListElement(['O', 'F', 'Cl', 'Br', 'I']), neighbors=1)
|
519 |
+
q.add_bond(1, 2, 1)
|
520 |
+
q.add_bond(2, 3, 1)
|
521 |
+
|
522 |
+
p.add_atom('C')
|
523 |
+
p.add_atom('C')
|
524 |
+
p.add_bond(1, 2, 2)
|
525 |
+
|
526 |
+
# methylation of dimethylamines
|
527 |
+
#
|
528 |
+
q, p = prepare()
|
529 |
+
q.add_atom('C', neighbors=1)
|
530 |
+
q.add_atom('N', neighbors=3)
|
531 |
+
q.add_bond(1, 2, 1)
|
532 |
+
|
533 |
+
p.add_atom('N', _map=2)
|
534 |
+
|
535 |
+
__all__ = ['rules']
|
SynTool/chem/retron.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing a class Retron that represents a retron (extend molecule object) in the search tree
|
3 |
+
"""
|
4 |
+
|
5 |
+
from CGRtools.containers import MoleculeContainer
|
6 |
+
from CGRtools.exceptions import InvalidAromaticRing
|
7 |
+
|
8 |
+
from SynTool.chem.utils import safe_canonicalization
|
9 |
+
|
10 |
+
|
11 |
+
class Retron:
|
12 |
+
"""
|
13 |
+
Retron class is used to extend the molecule behavior needed for interaction with a tree in MCTS
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, molecule: MoleculeContainer, canonicalize: bool = True):
|
17 |
+
"""
|
18 |
+
It initializes a Retron object with a molecule container as a parameter.
|
19 |
+
|
20 |
+
:param molecule: The `molecule` parameter is of type `MoleculeContainer`.
|
21 |
+
:type molecule: MoleculeContainer
|
22 |
+
"""
|
23 |
+
self._molecule = safe_canonicalization(molecule) if canonicalize else molecule
|
24 |
+
self._mapping = None
|
25 |
+
self.prev_retrons = []
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
"""
|
29 |
+
Return the number of atoms in Retron.
|
30 |
+
"""
|
31 |
+
return len(self._molecule)
|
32 |
+
|
33 |
+
def __hash__(self):
|
34 |
+
"""
|
35 |
+
Returns the hash value of Retron.
|
36 |
+
"""
|
37 |
+
return hash(self._molecule)
|
38 |
+
|
39 |
+
def __str__(self):
|
40 |
+
return str(self._molecule)
|
41 |
+
|
42 |
+
def __eq__(self, other: "Retron"):
|
43 |
+
"""
|
44 |
+
The function checks if the current Retron is equal to another Retron of the same class.
|
45 |
+
|
46 |
+
:param other: The "other" parameter is a reference to another object of the same class "Retron". It is used to
|
47 |
+
compare the current Retron with the other Retron to check if they are equal
|
48 |
+
:type other: "Retron"
|
49 |
+
"""
|
50 |
+
return self._molecule == other._molecule
|
51 |
+
|
52 |
+
def validate_molecule(self):
|
53 |
+
molecule = self._molecule.copy()
|
54 |
+
try:
|
55 |
+
molecule.kekule()
|
56 |
+
if molecule.check_valence():
|
57 |
+
return False
|
58 |
+
molecule.thiele()
|
59 |
+
except InvalidAromaticRing:
|
60 |
+
return False
|
61 |
+
return True
|
62 |
+
|
63 |
+
@property
|
64 |
+
def molecule(self) -> MoleculeContainer:
|
65 |
+
"""
|
66 |
+
Returns a remapped MoleculeContainer object if self._mapping=True.
|
67 |
+
"""
|
68 |
+
if self._mapping:
|
69 |
+
remapped = self._molecule.copy()
|
70 |
+
try:
|
71 |
+
remapped = self._molecule.remap(self._mapping, copy=True)
|
72 |
+
except ValueError:
|
73 |
+
pass
|
74 |
+
return remapped
|
75 |
+
return self._molecule.copy()
|
76 |
+
|
77 |
+
def __repr__(self):
|
78 |
+
"""
|
79 |
+
Returns a SMILES of the retron
|
80 |
+
"""
|
81 |
+
return str(self._molecule)
|
82 |
+
|
83 |
+
def is_building_block(self, stock, min_mol_size=6):
|
84 |
+
"""
|
85 |
+
The function checks if a Retron is a building block.
|
86 |
+
|
87 |
+
:param min_mol_size:
|
88 |
+
:param stock: The list of building blocks. Each building block is represented by a smiles.
|
89 |
+
"""
|
90 |
+
if len(self._molecule) <= min_mol_size:
|
91 |
+
return True
|
92 |
+
else:
|
93 |
+
return str(self._molecule) in stock
|
94 |
+
|
95 |
+
|
96 |
+
def compose_retrons(retrons: list = None, exclude_small=True, min_mol_size: int = 6
|
97 |
+
) -> MoleculeContainer:
|
98 |
+
"""
|
99 |
+
The function takes a list of retrons, excludes small retrons if specified, and composes them into a single molecule.
|
100 |
+
This molecule is used for the prediction of synthesisability of the characterizing the possible success of the path
|
101 |
+
including the nodes with the given retrons.
|
102 |
+
|
103 |
+
:param retrons: The list of retrons to be composed.
|
104 |
+
:type retrons: list
|
105 |
+
:param exclude_small: The parameter that determines whether small retrons should be
|
106 |
+
excluded from the composition process. If `exclude_small` is set to `True`, only retrons with a length greater than
|
107 |
+
min_mol_size will be considered for composition.
|
108 |
+
:param min_mol_size: parameter used with exclude_small
|
109 |
+
:return: A composed retrons as a MoleculeContainer object.
|
110 |
+
"""
|
111 |
+
|
112 |
+
if len(retrons) == 1:
|
113 |
+
return retrons[0].molecule
|
114 |
+
elif len(retrons) > 1:
|
115 |
+
if exclude_small:
|
116 |
+
big_retrons = [
|
117 |
+
retron for retron in retrons if len(retron.molecule) > min_mol_size
|
118 |
+
]
|
119 |
+
if big_retrons:
|
120 |
+
retrons = big_retrons
|
121 |
+
tmp_mol = retrons[0].molecule.copy()
|
122 |
+
transition_mapping = {}
|
123 |
+
for mol in retrons[1:]:
|
124 |
+
for n, atom in mol.molecule.atoms():
|
125 |
+
new_number = tmp_mol.add_atom(atom.atomic_symbol)
|
126 |
+
transition_mapping[n] = new_number
|
127 |
+
for atom, neighbor, bond in mol.molecule.bonds():
|
128 |
+
tmp_mol.add_bond(
|
129 |
+
transition_mapping[atom], transition_mapping[neighbor], bond
|
130 |
+
)
|
131 |
+
transition_mapping = {}
|
132 |
+
return tmp_mol
|
SynTool/chem/utils.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Iterable, Tuple, Union
|
2 |
+
|
3 |
+
from CGRtools.containers import MoleculeContainer, ReactionContainer, QueryContainer
|
4 |
+
from CGRtools.exceptions import InvalidAromaticRing
|
5 |
+
|
6 |
+
|
7 |
+
def query_to_mol(query: QueryContainer) -> MoleculeContainer:
|
8 |
+
"""
|
9 |
+
Converts a QueryContainer object into a MoleculeContainer object.
|
10 |
+
|
11 |
+
:param query: A QueryContainer object representing the query structure.
|
12 |
+
:return: A MoleculeContainer object that replicates the structure of the query.
|
13 |
+
"""
|
14 |
+
new_mol = MoleculeContainer()
|
15 |
+
for n, atom in query.atoms():
|
16 |
+
new_mol.add_atom(atom.atomic_symbol, n, charge=atom.charge, is_radical=atom.is_radical)
|
17 |
+
for i, j, bond in query.bonds():
|
18 |
+
new_mol.add_bond(i, j, int(bond))
|
19 |
+
return new_mol
|
20 |
+
|
21 |
+
|
22 |
+
def reaction_query_to_reaction(rule: ReactionContainer) -> ReactionContainer:
|
23 |
+
"""
|
24 |
+
Converts a ReactionContainer object with query structures into a ReactionContainer with molecular structures.
|
25 |
+
|
26 |
+
:param rule: A ReactionContainer object where reactants and products are QueryContainer objects.
|
27 |
+
:return: A new ReactionContainer
|
28 |
+
:return: A new ReactionContainer object where reactants and products are MoleculeContainer objects.
|
29 |
+
"""
|
30 |
+
reactants = [query_to_mol(q) for q in rule.reactants]
|
31 |
+
products = [query_to_mol(q) for q in rule.products]
|
32 |
+
reagents = [query_to_mol(q) for q in rule.reagents] # Assuming reagents are also part of the rule
|
33 |
+
reaction = ReactionContainer(reactants, products, reagents, rule.meta)
|
34 |
+
reaction.name = rule.name
|
35 |
+
return reaction
|
36 |
+
|
37 |
+
|
38 |
+
def unite_molecules(molecules: Iterable[MoleculeContainer]) -> MoleculeContainer:
|
39 |
+
"""
|
40 |
+
Unites a list of MoleculeContainer objects into a single MoleculeContainer.
|
41 |
+
|
42 |
+
This function takes multiple molecules and combines them into one larger molecule.
|
43 |
+
The first molecule in the list is taken as the base, and subsequent molecules are united with it sequentially.
|
44 |
+
|
45 |
+
:param molecules: A list of MoleculeContainer objects to be united.
|
46 |
+
:return: A single MoleculeContainer object representing the union of all input molecules.
|
47 |
+
"""
|
48 |
+
new_mol = MoleculeContainer()
|
49 |
+
for mol in molecules:
|
50 |
+
new_mol = new_mol.union(mol)
|
51 |
+
return new_mol
|
52 |
+
|
53 |
+
|
54 |
+
def safe_canonicalization(molecule: MoleculeContainer):
|
55 |
+
"""
|
56 |
+
Attempts to canonicalize a molecule, handling any exceptions.
|
57 |
+
|
58 |
+
This function tries to canonicalize the given molecule.
|
59 |
+
If the canonicalization process fails due to an InvalidAromaticRing exception,
|
60 |
+
it safely returns the original molecule.
|
61 |
+
|
62 |
+
:param molecule: The given molecule to be canonicalized.
|
63 |
+
:return: The canonicalized molecule if successful, otherwise the original molecule.
|
64 |
+
"""
|
65 |
+
molecule._atoms = dict(sorted(molecule._atoms.items()))
|
66 |
+
|
67 |
+
tmp = molecule.copy()
|
68 |
+
try:
|
69 |
+
tmp.canonicalize()
|
70 |
+
return tmp
|
71 |
+
except InvalidAromaticRing:
|
72 |
+
return molecule
|
73 |
+
|
74 |
+
|
75 |
+
def split_molecules(molecules: Iterable, number_of_atoms: int) -> Tuple[List, List]:
|
76 |
+
"""
|
77 |
+
Splits molecules according to the number of heavy atoms.
|
78 |
+
|
79 |
+
:param molecules: Iterable of molecules.
|
80 |
+
:param number_of_atoms: Threshold for splitting molecules.
|
81 |
+
:return: Tuple of lists containing "big" molecules and "small" molecules.
|
82 |
+
"""
|
83 |
+
big_molecules, small_molecules = [], []
|
84 |
+
for molecule in molecules:
|
85 |
+
if len(molecule) > number_of_atoms:
|
86 |
+
big_molecules.append(molecule)
|
87 |
+
else:
|
88 |
+
small_molecules.append(molecule)
|
89 |
+
|
90 |
+
return big_molecules, small_molecules
|
91 |
+
|
92 |
+
|
93 |
+
def remove_small_molecules(
|
94 |
+
reaction: ReactionContainer,
|
95 |
+
number_of_atoms: int = 6,
|
96 |
+
small_molecules_to_meta: bool = True
|
97 |
+
) -> Union[ReactionContainer, None]:
|
98 |
+
"""
|
99 |
+
Processes a reaction by removing small molecules.
|
100 |
+
|
101 |
+
:param reaction: ReactionContainer object.
|
102 |
+
:param number_of_atoms: Molecules with the number of atoms equal to or below this will be removed.
|
103 |
+
:param small_molecules_to_meta: If True, deleted molecules are saved to meta.
|
104 |
+
:return: Processed ReactionContainer without small molecules.
|
105 |
+
"""
|
106 |
+
new_reactants, small_reactants = split_molecules(reaction.reactants, number_of_atoms)
|
107 |
+
new_products, small_products = split_molecules(reaction.products, number_of_atoms)
|
108 |
+
|
109 |
+
if sum(len(mol) for mol in new_reactants) == 0 or sum(len(mol) for mol in new_reactants) == 0:
|
110 |
+
return None
|
111 |
+
|
112 |
+
new_reaction = ReactionContainer(new_reactants, new_products, reaction.reagents, reaction.meta)
|
113 |
+
new_reaction.name = reaction.name
|
114 |
+
|
115 |
+
if small_molecules_to_meta:
|
116 |
+
united_small_reactants = unite_molecules(small_reactants)
|
117 |
+
new_reaction.meta["small_reactants"] = str(united_small_reactants)
|
118 |
+
|
119 |
+
united_small_products = unite_molecules(small_products)
|
120 |
+
new_reaction.meta["small_products"] = str(united_small_products)
|
121 |
+
|
122 |
+
return new_reaction
|
123 |
+
|
124 |
+
|
125 |
+
def rebalance_reaction(reaction: ReactionContainer) -> ReactionContainer:
|
126 |
+
"""
|
127 |
+
Rebalances the reaction by assembling CGR and then decomposing it. Works for all reactions for which the correct
|
128 |
+
CGR can be assembled
|
129 |
+
|
130 |
+
:param reaction: a reaction object
|
131 |
+
:return: a rebalanced reaction
|
132 |
+
"""
|
133 |
+
tmp_reaction = ReactionContainer(reaction.reactants, reaction.products)
|
134 |
+
cgr = ~tmp_reaction
|
135 |
+
reactants, products = ~cgr
|
136 |
+
rebalanced_reaction = ReactionContainer(reactants.split(), products.split(), reaction.reagents, reaction.meta)
|
137 |
+
rebalanced_reaction.name = reaction.name
|
138 |
+
return rebalanced_reaction
|
139 |
+
|
140 |
+
|
141 |
+
def reverse_reaction(reaction: ReactionContainer) -> ReactionContainer:
|
142 |
+
"""
|
143 |
+
Reverses given reaction
|
144 |
+
|
145 |
+
:param reaction: a reaction object
|
146 |
+
:return: the reversed reaction
|
147 |
+
"""
|
148 |
+
reversed_reaction = ReactionContainer(reaction.products, reaction.reactants, reaction.reagents, reaction.meta)
|
149 |
+
reversed_reaction.name = reaction.name
|
150 |
+
|
151 |
+
return reversed_reaction
|
152 |
+
|
153 |
+
|
154 |
+
def remove_reagents(
|
155 |
+
reaction: ReactionContainer,
|
156 |
+
keep_reagents: bool = True,
|
157 |
+
reagents_max_size: int = 7
|
158 |
+
) -> Union[ReactionContainer, None]:
|
159 |
+
"""
|
160 |
+
Removes reagents (not changed molecules or molecules not involved in the reaction) from reactants and products
|
161 |
+
|
162 |
+
:param reaction: a reaction object
|
163 |
+
:param keep_reagents: if True, the reagents are written to ReactionContainer
|
164 |
+
:param reagents_max_size: max size of molecules that are called reagents, bigger are deleted
|
165 |
+
:return: cleaned reaction
|
166 |
+
"""
|
167 |
+
not_changed_molecules = set(reaction.reactants).intersection(reaction.products)
|
168 |
+
|
169 |
+
cgr = ~reaction
|
170 |
+
center_atoms = set(cgr.center_atoms)
|
171 |
+
|
172 |
+
new_reactants = []
|
173 |
+
new_products = []
|
174 |
+
new_reagents = []
|
175 |
+
|
176 |
+
for molecule in reaction.reactants:
|
177 |
+
if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules:
|
178 |
+
new_reagents.append(molecule)
|
179 |
+
else:
|
180 |
+
new_reactants.append(molecule)
|
181 |
+
|
182 |
+
for molecule in reaction.products:
|
183 |
+
if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules:
|
184 |
+
new_reagents.append(molecule)
|
185 |
+
else:
|
186 |
+
new_products.append(molecule)
|
187 |
+
|
188 |
+
if sum(len(mol) for mol in new_reactants) == 0 or sum(len(mol) for mol in new_reactants) == 0:
|
189 |
+
return None
|
190 |
+
|
191 |
+
if keep_reagents:
|
192 |
+
new_reagents = {molecule for molecule in new_reagents if len(molecule) <= reagents_max_size}
|
193 |
+
else:
|
194 |
+
new_reagents = []
|
195 |
+
|
196 |
+
new_reaction = ReactionContainer(new_reactants, new_products, new_reagents, reaction.meta)
|
197 |
+
new_reaction.name = reaction.name
|
198 |
+
|
199 |
+
return new_reaction
|
200 |
+
|
201 |
+
|
202 |
+
def to_reaction_smiles_record(reaction):
|
203 |
+
if isinstance(reaction, str):
|
204 |
+
return reaction
|
205 |
+
|
206 |
+
reaction_record = [format(reaction, "m")]
|
207 |
+
sorted_meta = sorted(reaction.meta.items(), key=lambda x: x[0])
|
208 |
+
for _, meta_info in sorted_meta:
|
209 |
+
# meta_info = str(meta_info)
|
210 |
+
meta_info = '' # TODO decide what to do with meta
|
211 |
+
meta_info = ";".join(meta_info.split("\n"))
|
212 |
+
reaction_record.append(str(meta_info))
|
213 |
+
# return "\t".join(reaction_record) + "\n"
|
214 |
+
return "".join(reaction_record)
|
215 |
+
|
216 |
+
|
217 |
+
def cgr_from_rule(rule: ReactionContainer):
|
218 |
+
reaction_rule = reaction_query_to_reaction(rule)
|
219 |
+
cgr_rule = ~reaction_rule
|
220 |
+
return cgr_rule
|
221 |
+
|
222 |
+
|
223 |
+
def hash_from_rule(reaction_rule: ReactionContainer):
|
224 |
+
reactants_hash = tuple(sorted(hash(r) for r in reaction_rule.reactants))
|
225 |
+
reagents_hash = tuple(sorted(hash(r) for r in reaction_rule.reagents))
|
226 |
+
products_hash = tuple(sorted(hash(r) for r in reaction_rule.products))
|
227 |
+
return hash((reactants_hash, reagents_hash, products_hash))
|
SynTool/interfaces/__init__.py
ADDED
File without changes
|
SynTool/interfaces/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (155 Bytes). View file
|
|
SynTool/interfaces/__pycache__/visualisation.cpython-310.pyc
ADDED
Binary file (11 kB). View file
|
|
SynTool/interfaces/cli.py
ADDED
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing commands line scripts for training and planning mode
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import shutil
|
7 |
+
import yaml
|
8 |
+
import warnings
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import click
|
12 |
+
import gdown
|
13 |
+
|
14 |
+
from SynTool.chem.data.cleaning import reactions_cleaner
|
15 |
+
from SynTool.chem.data.filtering import filter_reactions, ReactionCheckConfig
|
16 |
+
from SynTool.utils.loading import standardize_building_blocks
|
17 |
+
from SynTool.chem.reaction_rules.extraction import extract_rules_from_reactions
|
18 |
+
from SynTool.mcts.search import tree_search
|
19 |
+
from SynTool.ml.training.reinforcement import run_reinforcement_tuning
|
20 |
+
from SynTool.ml.training.supervised import create_policy_dataset, run_policy_training
|
21 |
+
from SynTool.utils.config import ReinforcementConfig, TreeConfig, PolicyNetworkConfig, ValueNetworkConfig
|
22 |
+
from SynTool.utils.config import ReactionStandardizationConfig, RuleExtractionConfig
|
23 |
+
from SynTool.chem.data.mapping import remove_reagents_and_map_from_file
|
24 |
+
|
25 |
+
warnings.filterwarnings("ignore")
|
26 |
+
|
27 |
+
|
28 |
+
@click.group(name="syntool")
|
29 |
+
def syntool():
|
30 |
+
pass
|
31 |
+
|
32 |
+
|
33 |
+
@syntool.command(name="download_planning_data")
|
34 |
+
@click.option(
|
35 |
+
"--root_dir",
|
36 |
+
required=True,
|
37 |
+
type=click.Path(exists=True),
|
38 |
+
help="Path to the reaction database file that will be mapped.",
|
39 |
+
)
|
40 |
+
def download_planning_data_cli(root_dir='.'):
|
41 |
+
"""
|
42 |
+
Downloads data for retrosythesis planning
|
43 |
+
"""
|
44 |
+
remote_id = "1ygq9BvQgH2Tq_rL72BvSOdASSSbPFTsL"
|
45 |
+
output = os.path.join(root_dir, "syntool_planning_data.zip")
|
46 |
+
#
|
47 |
+
gdown.download(output=output, id=remote_id, quiet=False)
|
48 |
+
shutil.unpack_archive(output, root_dir)
|
49 |
+
#
|
50 |
+
os.remove(output)
|
51 |
+
|
52 |
+
|
53 |
+
@syntool.command(name='download_training_data')
|
54 |
+
@click.option(
|
55 |
+
"--root_dir",
|
56 |
+
required=True,
|
57 |
+
type=click.Path(exists=True),
|
58 |
+
help="Path to the reaction database file that will be mapped.",
|
59 |
+
)
|
60 |
+
def download_training_data_cli(root_dir='.'):
|
61 |
+
"""
|
62 |
+
Downloads data for retrosythetic models training
|
63 |
+
"""
|
64 |
+
remote_id = "1ckhO1l6xud0_bnC0rCDMkIlKRUMG_xs8"
|
65 |
+
output = os.path.join(root_dir, "syntool_training_data.zip")
|
66 |
+
#
|
67 |
+
gdown.download(output=output, id=remote_id, quiet=False)
|
68 |
+
shutil.unpack_archive(output, root_dir)
|
69 |
+
#
|
70 |
+
os.remove(output)
|
71 |
+
|
72 |
+
|
73 |
+
@syntool.command(name="building_blocks")
|
74 |
+
@click.option(
|
75 |
+
"--input",
|
76 |
+
"input_file",
|
77 |
+
required=True,
|
78 |
+
type=click.Path(exists=True),
|
79 |
+
help="Path to the reaction database file that will be mapped.",
|
80 |
+
)
|
81 |
+
@click.option(
|
82 |
+
"--output",
|
83 |
+
"output_file",
|
84 |
+
required=True,
|
85 |
+
type=click.Path(),
|
86 |
+
help="File where the results will be stored.",
|
87 |
+
)
|
88 |
+
def building_blocks_cli(input_file, output_file):
|
89 |
+
"""
|
90 |
+
Standardizes building blocks
|
91 |
+
"""
|
92 |
+
|
93 |
+
standardize_building_blocks(input_file=input_file, output_file=output_file)
|
94 |
+
|
95 |
+
|
96 |
+
@syntool.command(name="reaction_mapping")
|
97 |
+
@click.option(
|
98 |
+
"--config",
|
99 |
+
"config_path",
|
100 |
+
required=True,
|
101 |
+
type=click.Path(exists=True),
|
102 |
+
help="Path to the configuration file. This file contains settings for mapping and standardizing reactions.",
|
103 |
+
)
|
104 |
+
@click.option(
|
105 |
+
"--input",
|
106 |
+
"input_file",
|
107 |
+
required=True,
|
108 |
+
type=click.Path(exists=True),
|
109 |
+
help="Path to the reaction database file that will be mapped.",
|
110 |
+
)
|
111 |
+
@click.option(
|
112 |
+
"--output",
|
113 |
+
"output_file",
|
114 |
+
default=Path("reaction_data_standardized.smi"),
|
115 |
+
type=click.Path(),
|
116 |
+
help="File where the results will be stored.",
|
117 |
+
)
|
118 |
+
def reaction_mapping_cli(config_path, input_file, output_file):
|
119 |
+
"""
|
120 |
+
Reaction data mapping
|
121 |
+
"""
|
122 |
+
|
123 |
+
stand_config = ReactionStandardizationConfig.from_yaml(config_path)
|
124 |
+
remove_reagents_and_map_from_file(input_file=input_file, output_file=output_file, keep_reagent=stand_config.keep_reagents)
|
125 |
+
|
126 |
+
|
127 |
+
@syntool.command(name="reaction_standardizing")
|
128 |
+
@click.option(
|
129 |
+
"--config",
|
130 |
+
"config_path",
|
131 |
+
required=True,
|
132 |
+
type=click.Path(exists=True),
|
133 |
+
help="Path to the configuration file. This file contains settings for mapping and standardizing reactions.",
|
134 |
+
)
|
135 |
+
@click.option(
|
136 |
+
"--input",
|
137 |
+
"input_file",
|
138 |
+
required=True,
|
139 |
+
type=click.Path(exists=True),
|
140 |
+
help="Path to the reaction database file that will be mapped.",
|
141 |
+
)
|
142 |
+
@click.option(
|
143 |
+
"--output",
|
144 |
+
"output_file",
|
145 |
+
type=click.Path(),
|
146 |
+
help="File where the results will be stored.",
|
147 |
+
)
|
148 |
+
@click.option(
|
149 |
+
"--num_cpus",
|
150 |
+
default=8,
|
151 |
+
type=int,
|
152 |
+
help="Number of CPUs to use for processing. Defaults to 1.",
|
153 |
+
)
|
154 |
+
def reaction_standardizing_cli(config_path, input_file, output_file, num_cpus):
|
155 |
+
"""
|
156 |
+
Standardizes reactions and remove duplicates
|
157 |
+
"""
|
158 |
+
|
159 |
+
stand_config = ReactionStandardizationConfig.from_yaml(config_path)
|
160 |
+
reactions_cleaner(config=stand_config,
|
161 |
+
input_file=input_file,
|
162 |
+
output_file=output_file,
|
163 |
+
num_cpus=num_cpus)
|
164 |
+
|
165 |
+
|
166 |
+
@syntool.command(name="reaction_filtering")
|
167 |
+
@click.option(
|
168 |
+
"--config",
|
169 |
+
"config_path",
|
170 |
+
required=True,
|
171 |
+
type=click.Path(exists=True),
|
172 |
+
help="Path to the configuration file. This file contains settings for filtering reactions.",
|
173 |
+
)
|
174 |
+
@click.option(
|
175 |
+
"--input",
|
176 |
+
"input_file",
|
177 |
+
required=True,
|
178 |
+
type=click.Path(exists=True),
|
179 |
+
help="Path to the reaction database file that will be mapped.",
|
180 |
+
)
|
181 |
+
@click.option(
|
182 |
+
"--output",
|
183 |
+
"output_file",
|
184 |
+
default=Path("./"),
|
185 |
+
type=click.Path(),
|
186 |
+
help="File where the results will be stored.",
|
187 |
+
)
|
188 |
+
@click.option(
|
189 |
+
"--append_results",
|
190 |
+
is_flag=True,
|
191 |
+
default=False,
|
192 |
+
help="If set, results will be appended to existing files. By default, new files are created.",
|
193 |
+
)
|
194 |
+
@click.option(
|
195 |
+
"--batch_size",
|
196 |
+
default=100,
|
197 |
+
type=int,
|
198 |
+
help="Size of the batch for processing reactions. Defaults to 10.",
|
199 |
+
)
|
200 |
+
@click.option(
|
201 |
+
"--num_cpus",
|
202 |
+
default=8,
|
203 |
+
type=int,
|
204 |
+
help="Number of CPUs to use for processing. Defaults to 1.",
|
205 |
+
)
|
206 |
+
def reaction_filtering_cli(config_path,
|
207 |
+
input_file,
|
208 |
+
output_file,
|
209 |
+
append_results,
|
210 |
+
batch_size,
|
211 |
+
num_cpus):
|
212 |
+
"""
|
213 |
+
Filters erroneous reactions
|
214 |
+
"""
|
215 |
+
reaction_check_config = ReactionCheckConfig().from_yaml(config_path)
|
216 |
+
filter_reactions(
|
217 |
+
config=reaction_check_config,
|
218 |
+
reaction_database_path=input_file,
|
219 |
+
result_reactions_file_name=output_file,
|
220 |
+
append_results=append_results,
|
221 |
+
num_cpus=num_cpus,
|
222 |
+
batch_size=batch_size,
|
223 |
+
)
|
224 |
+
|
225 |
+
|
226 |
+
@syntool.command(name="rule_extracting")
|
227 |
+
@click.option(
|
228 |
+
"--config",
|
229 |
+
"config_path",
|
230 |
+
required=True,
|
231 |
+
type=click.Path(exists=True),
|
232 |
+
help="Path to the configuration file. This file contains settings for reaction rules extraction.",
|
233 |
+
)
|
234 |
+
@click.option(
|
235 |
+
"--input",
|
236 |
+
"input_file",
|
237 |
+
required=True,
|
238 |
+
type=click.Path(exists=True),
|
239 |
+
help="Path to the reaction database file that will be mapped.",
|
240 |
+
)
|
241 |
+
@click.option(
|
242 |
+
"--output",
|
243 |
+
"output_file",
|
244 |
+
required=True,
|
245 |
+
type=click.Path(),
|
246 |
+
help="File where the results will be stored.",
|
247 |
+
)
|
248 |
+
@click.option(
|
249 |
+
"--batch_size",
|
250 |
+
default=100,
|
251 |
+
type=int,
|
252 |
+
help="Size of the batch for processing reactions. Defaults to 10.",
|
253 |
+
)
|
254 |
+
@click.option(
|
255 |
+
"--num_cpus",
|
256 |
+
default=4,
|
257 |
+
type=int,
|
258 |
+
help="Number of CPUs to use for processing. Defaults to 4.",
|
259 |
+
)
|
260 |
+
def rule_extracting_cli(
|
261 |
+
config_path,
|
262 |
+
input_file,
|
263 |
+
output_file,
|
264 |
+
num_cpus,
|
265 |
+
batch_size,
|
266 |
+
):
|
267 |
+
"""
|
268 |
+
Extracts reaction rules
|
269 |
+
"""
|
270 |
+
|
271 |
+
reaction_rule_config = RuleExtractionConfig.from_yaml(config_path)
|
272 |
+
extract_rules_from_reactions(config=reaction_rule_config,
|
273 |
+
reaction_file=input_file,
|
274 |
+
rules_file_name=output_file,
|
275 |
+
num_cpus=num_cpus,
|
276 |
+
batch_size=batch_size)
|
277 |
+
|
278 |
+
|
279 |
+
@syntool.command(name="supervised_ranking_policy_training")
|
280 |
+
@click.option(
|
281 |
+
"--config",
|
282 |
+
"config_path",
|
283 |
+
required=True,
|
284 |
+
type=click.Path(exists=True),
|
285 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
286 |
+
)
|
287 |
+
@click.option(
|
288 |
+
"--reaction_data",
|
289 |
+
required=True,
|
290 |
+
type=click.Path(exists=True),
|
291 |
+
help="Path to the reaction database file that will be mapped.",
|
292 |
+
)
|
293 |
+
@click.option(
|
294 |
+
"--reaction_rules",
|
295 |
+
required=True,
|
296 |
+
type=click.Path(exists=True),
|
297 |
+
help="Path to the reaction database file that will be mapped.",
|
298 |
+
)
|
299 |
+
@click.option(
|
300 |
+
"--results_dir",
|
301 |
+
default=Path("."),
|
302 |
+
type=click.Path(),
|
303 |
+
help="Root directory where the results will be stored.",
|
304 |
+
)
|
305 |
+
@click.option(
|
306 |
+
"--num_cpus",
|
307 |
+
default=4,
|
308 |
+
type=int,
|
309 |
+
help="Number of CPUs to use for processing. Defaults to 4.",
|
310 |
+
)
|
311 |
+
def supervised_ranking_policy_training_cli(config_path, reaction_data, reaction_rules, results_dir, num_cpus):
|
312 |
+
"""
|
313 |
+
Trains ranking policy network
|
314 |
+
"""
|
315 |
+
|
316 |
+
policy_config = PolicyNetworkConfig.from_yaml(config_path)
|
317 |
+
|
318 |
+
policy_dataset_file = os.path.join(results_dir, 'policy_dataset.dt')
|
319 |
+
|
320 |
+
datamodule = create_policy_dataset(reaction_rules_path=reaction_rules,
|
321 |
+
molecules_or_reactions_path=reaction_data,
|
322 |
+
output_path=policy_dataset_file,
|
323 |
+
dataset_type='ranking',
|
324 |
+
batch_size=policy_config.batch_size,
|
325 |
+
num_cpus=num_cpus)
|
326 |
+
|
327 |
+
run_policy_training(datamodule, config=policy_config, results_path=results_dir)
|
328 |
+
|
329 |
+
|
330 |
+
@syntool.command(name="supervised_filtering_policy_training")
|
331 |
+
@click.option(
|
332 |
+
"--config",
|
333 |
+
"config_path",
|
334 |
+
required=True,
|
335 |
+
type=click.Path(exists=True),
|
336 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
337 |
+
)
|
338 |
+
@click.option(
|
339 |
+
"--molecules_file",
|
340 |
+
required=True,
|
341 |
+
type=click.Path(exists=True),
|
342 |
+
help="Path to the molecules database file that will be mapped.",
|
343 |
+
)
|
344 |
+
@click.option(
|
345 |
+
"--reaction_rules",
|
346 |
+
required=True,
|
347 |
+
type=click.Path(exists=True),
|
348 |
+
help="Path to the reaction database file that will be mapped.",
|
349 |
+
)
|
350 |
+
@click.option(
|
351 |
+
"--results_dir",
|
352 |
+
default=Path("."),
|
353 |
+
type=click.Path(),
|
354 |
+
help="Root directory where the results will be stored.",
|
355 |
+
)
|
356 |
+
@click.option(
|
357 |
+
"--num_cpus",
|
358 |
+
default=8,
|
359 |
+
type=int,
|
360 |
+
help="Number of CPUs to use for processing. Defaults to 1.",
|
361 |
+
)
|
362 |
+
def supervised_filtering_policy_training_cli(config_path, molecules_file, reaction_rules, results_dir, num_cpus):
|
363 |
+
"""
|
364 |
+
Trains filtering policy network
|
365 |
+
"""
|
366 |
+
|
367 |
+
policy_config = PolicyNetworkConfig.from_yaml(config_path)
|
368 |
+
|
369 |
+
policy_dataset_file = os.path.join(results_dir, 'policy_dataset.ckpt')
|
370 |
+
datamodule = create_policy_dataset(reaction_rules_path=reaction_rules,
|
371 |
+
molecules_or_reactions_path=molecules_file,
|
372 |
+
output_path=policy_dataset_file,
|
373 |
+
dataset_type='filtering',
|
374 |
+
batch_size=policy_config.batch_size,
|
375 |
+
num_cpus=num_cpus)
|
376 |
+
|
377 |
+
run_policy_training(datamodule, config=policy_config, results_path=results_dir)
|
378 |
+
|
379 |
+
|
380 |
+
@syntool.command(name="reinforcement_value_network_training")
|
381 |
+
@click.option(
|
382 |
+
"--config",
|
383 |
+
required=True,
|
384 |
+
type=click.Path(exists=True),
|
385 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
386 |
+
)
|
387 |
+
@click.option(
|
388 |
+
"--targets",
|
389 |
+
required=True,
|
390 |
+
type=click.Path(exists=True),
|
391 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
392 |
+
)
|
393 |
+
@click.option(
|
394 |
+
"--reaction_rules",
|
395 |
+
required=True,
|
396 |
+
type=click.Path(exists=True),
|
397 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
398 |
+
)
|
399 |
+
@click.option(
|
400 |
+
"--building_blocks",
|
401 |
+
required=True,
|
402 |
+
type=click.Path(exists=True),
|
403 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
404 |
+
)
|
405 |
+
@click.option(
|
406 |
+
"--policy_network",
|
407 |
+
required=True,
|
408 |
+
type=click.Path(exists=True),
|
409 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
410 |
+
)
|
411 |
+
@click.option(
|
412 |
+
"--value_network",
|
413 |
+
default=None,
|
414 |
+
type=click.Path(exists=True),
|
415 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
416 |
+
)
|
417 |
+
@click.option(
|
418 |
+
"--results_dir",
|
419 |
+
default='.',
|
420 |
+
type=click.Path(exists=False),
|
421 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
422 |
+
)
|
423 |
+
def reinforcement_value_network_training_cli(config,
|
424 |
+
targets,
|
425 |
+
reaction_rules,
|
426 |
+
building_blocks,
|
427 |
+
policy_network,
|
428 |
+
value_network,
|
429 |
+
results_dir):
|
430 |
+
"""
|
431 |
+
Trains value network with reinforcement learning
|
432 |
+
"""
|
433 |
+
|
434 |
+
with open(config, "r") as file:
|
435 |
+
config = yaml.safe_load(file)
|
436 |
+
|
437 |
+
policy_config = PolicyNetworkConfig.from_dict(config['node_expansion'])
|
438 |
+
policy_config.weights_path = policy_network
|
439 |
+
|
440 |
+
value_config = ValueNetworkConfig.from_dict(config['value_network'])
|
441 |
+
if value_network is None:
|
442 |
+
value_config.weights_path = os.path.join(results_dir, 'weights', 'value_network.ckpt')
|
443 |
+
|
444 |
+
tree_config = TreeConfig.from_dict(config['tree'])
|
445 |
+
reinforce_config = ReinforcementConfig.from_dict(config['reinforcement'])
|
446 |
+
|
447 |
+
run_reinforcement_tuning(targets_path=targets,
|
448 |
+
tree_config=tree_config,
|
449 |
+
policy_config=policy_config,
|
450 |
+
value_config=value_config,
|
451 |
+
reinforce_config=reinforce_config,
|
452 |
+
reaction_rules_path=reaction_rules,
|
453 |
+
building_blocks_path=building_blocks,
|
454 |
+
results_root=results_dir)
|
455 |
+
|
456 |
+
|
457 |
+
@syntool.command(name="planning")
|
458 |
+
@click.option(
|
459 |
+
"--config",
|
460 |
+
"config_path",
|
461 |
+
required=True,
|
462 |
+
type=click.Path(exists=True),
|
463 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
464 |
+
)
|
465 |
+
@click.option(
|
466 |
+
"--targets",
|
467 |
+
required=True,
|
468 |
+
type=click.Path(exists=True),
|
469 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
470 |
+
)
|
471 |
+
@click.option(
|
472 |
+
"--reaction_rules",
|
473 |
+
required=True,
|
474 |
+
type=click.Path(exists=True),
|
475 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
476 |
+
)
|
477 |
+
@click.option(
|
478 |
+
"--building_blocks",
|
479 |
+
required=True,
|
480 |
+
type=click.Path(exists=True),
|
481 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
482 |
+
)
|
483 |
+
@click.option(
|
484 |
+
"--policy_network",
|
485 |
+
required=True,
|
486 |
+
type=click.Path(exists=True),
|
487 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
488 |
+
)
|
489 |
+
@click.option(
|
490 |
+
"--value_network",
|
491 |
+
default=None,
|
492 |
+
type=click.Path(exists=True),
|
493 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
494 |
+
)
|
495 |
+
@click.option(
|
496 |
+
"--results_dir",
|
497 |
+
default='.',
|
498 |
+
type=click.Path(exists=False),
|
499 |
+
help="Path to the configuration file. This file contains settings for policy training.",
|
500 |
+
)
|
501 |
+
def planning_cli(config_path,
|
502 |
+
targets,
|
503 |
+
reaction_rules,
|
504 |
+
building_blocks,
|
505 |
+
policy_network,
|
506 |
+
value_network,
|
507 |
+
results_dir):
|
508 |
+
"""
|
509 |
+
Runs retrosynthesis planning
|
510 |
+
"""
|
511 |
+
|
512 |
+
with open(config_path, "r") as file:
|
513 |
+
config = yaml.safe_load(file)
|
514 |
+
|
515 |
+
tree_config = TreeConfig.from_dict({**config['tree'], **config['node_evaluation']})
|
516 |
+
policy_config = PolicyNetworkConfig.from_dict({**config['node_expansion'], **{'weights_path': policy_network}})
|
517 |
+
|
518 |
+
tree_search(targets_path=targets,
|
519 |
+
tree_config=tree_config,
|
520 |
+
policy_config=policy_config,
|
521 |
+
reaction_rules_path=reaction_rules,
|
522 |
+
building_blocks_path=building_blocks,
|
523 |
+
value_weights_path=value_network,
|
524 |
+
results_root=results_dir)
|
525 |
+
|
526 |
+
|
527 |
+
if __name__ == '__main__':
|
528 |
+
syntool()
|
529 |
+
|
530 |
+
|
SynTool/interfaces/cli.py.bk
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing commands line scripts for training and planning mode
|
3 |
+
"""
|
4 |
+
|
5 |
+
import warnings
|
6 |
+
import os
|
7 |
+
import shutil
|
8 |
+
from pathlib import Path
|
9 |
+
import click
|
10 |
+
import gdown
|
11 |
+
from datetime import datetime
|
12 |
+
|
13 |
+
from Syntool.chem.reaction_rules.extraction import extract_rules_from_reactions
|
14 |
+
from Syntool.chem.data.cleaning import reactions_cleaner
|
15 |
+
from Syntool.chem.data.mapping import remove_reagents_and_map_from_file
|
16 |
+
from Syntool.chem.loading import standardize_building_blocks
|
17 |
+
from Syntool.ml.training import create_policy_dataset, run_policy_training
|
18 |
+
from Syntool.ml.training.reinforcement import run_self_tuning
|
19 |
+
from Syntool.ml.networks.policy import PolicyNetworkConfig
|
20 |
+
from Syntool.utils.config import read_planning_config, read_training_config, TreeConfig
|
21 |
+
from Syntool.mcts.search import tree_search
|
22 |
+
|
23 |
+
from Syntool.chem.data.filtering import (
|
24 |
+
filter_reactions,
|
25 |
+
ReactionCheckConfig,
|
26 |
+
CCRingBreakingConfig,
|
27 |
+
WrongCHBreakingConfig,
|
28 |
+
CCsp3BreakingConfig,
|
29 |
+
DynamicBondsConfig,
|
30 |
+
MultiCenterConfig,
|
31 |
+
NoReactionConfig,
|
32 |
+
SmallMoleculesConfig,
|
33 |
+
)
|
34 |
+
|
35 |
+
warnings.filterwarnings("ignore")
|
36 |
+
main = click.Group()
|
37 |
+
|
38 |
+
|
39 |
+
@main.command(name='planning_data')
|
40 |
+
def planning_data_cli():
|
41 |
+
"""
|
42 |
+
Downloads a file from Google Drive using its remote ID, saves it as a zip file, extracts the contents,
|
43 |
+
and then deletes the zip file
|
44 |
+
"""
|
45 |
+
remote_id = '1c5YJDT-rP1ZvFA-ELmPNTUj0b8an4yFf'
|
46 |
+
output = 'synto_planning_data.zip'
|
47 |
+
#
|
48 |
+
gdown.download(output=output, id=remote_id, quiet=True)
|
49 |
+
shutil.unpack_archive(output, './')
|
50 |
+
#
|
51 |
+
os.remove(output)
|
52 |
+
|
53 |
+
|
54 |
+
@main.command(name='training_data')
|
55 |
+
def training_data_cli():
|
56 |
+
"""
|
57 |
+
Downloads a file from Google Drive using its remote ID, saves it as a zip file, extracts the contents,
|
58 |
+
and then deletes the zip file
|
59 |
+
"""
|
60 |
+
remote_id = '1r4I7OskGvzg-zxYNJ7WVYpVR2HSYW10N'
|
61 |
+
output = 'synto_training_data.zip'
|
62 |
+
#
|
63 |
+
gdown.download(output=output, id=remote_id, quiet=True)
|
64 |
+
shutil.unpack_archive(output, './')
|
65 |
+
#
|
66 |
+
os.remove(output)
|
67 |
+
|
68 |
+
|
69 |
+
@main.command(name='syntool_planning')
|
70 |
+
@click.option("--config",
|
71 |
+
"config_path",
|
72 |
+
required=True,
|
73 |
+
help="Path to the config YAML molecules_path.",
|
74 |
+
type=click.Path(exists=True, path_type=Path),
|
75 |
+
)
|
76 |
+
def syntool_planning_cli(config_path):
|
77 |
+
"""
|
78 |
+
Launches tree search for the given target molecules and stores search statistics and found retrosynthetic paths
|
79 |
+
|
80 |
+
:param config_path: The path to the configuration file that contains the settings and parameters for the tree search
|
81 |
+
"""
|
82 |
+
config = read_planning_config(config_path)
|
83 |
+
config['Tree']['silent'] = True
|
84 |
+
|
85 |
+
# standardize building blocks
|
86 |
+
if config['InputData']['standardize_building_blocks']:
|
87 |
+
print('STANDARDIZE BUILDING BLOCKS ...')
|
88 |
+
standardize_building_blocks(config['InputData']['building_blocks_path'],
|
89 |
+
config['InputData']['building_blocks_path'])
|
90 |
+
# run planning
|
91 |
+
print('\nRUN PLANNING ...')
|
92 |
+
tree_config = TreeConfig.from_dict(config['Tree'])
|
93 |
+
tree_search(targets=config['General']['targets_path'],
|
94 |
+
tree_config=tree_config,
|
95 |
+
reaction_rules_path=config['InputData']['reaction_rules_path'],
|
96 |
+
building_blocks_path=config['InputData']['building_blocks_path'],
|
97 |
+
policy_weights_path=config['PolicyNetwork']['weights_path'],
|
98 |
+
value_weights_paths=config['ValueNetwork']['weights_path'],
|
99 |
+
results_root=config['General']['results_root'])
|
100 |
+
|
101 |
+
|
102 |
+
@main.command(name='syntool_training')
|
103 |
+
@click.option(
|
104 |
+
"--config",
|
105 |
+
"config_path",
|
106 |
+
required=True,
|
107 |
+
help="Path to the config YAML file.",
|
108 |
+
type=click.Path(exists=True, path_type=Path)
|
109 |
+
)
|
110 |
+
def syntool_training_cli(config_path):
|
111 |
+
|
112 |
+
# read training config
|
113 |
+
print('READ CONFIG ...')
|
114 |
+
config = read_training_config(config_path)
|
115 |
+
print('Config is read')
|
116 |
+
|
117 |
+
reaction_data_file = config['InputData']['reaction_data_path']
|
118 |
+
|
119 |
+
# reaction data mapping
|
120 |
+
startTime0 = datetime.now()
|
121 |
+
data_output_folder = os.path.join(config['General']['results_root'], 'reaction_data')
|
122 |
+
Path(data_output_folder).mkdir(parents=True, exist_ok=True)
|
123 |
+
mapped_data_file = os.path.join(data_output_folder, 'reaction_data_mapped.smi')
|
124 |
+
if config['DataCleaning']['map_reactions']:
|
125 |
+
print('\nMAP REACTION DATA ...')
|
126 |
+
|
127 |
+
remove_reagents_and_map_from_file(input_file=config['InputData']['reaction_data_path'],
|
128 |
+
output_file=mapped_data_file)
|
129 |
+
|
130 |
+
reaction_data_file = mapped_data_file
|
131 |
+
print("remove_reagents_and_map_from_file:", datetime.now() - startTime0)
|
132 |
+
|
133 |
+
# reaction data cleaning
|
134 |
+
startTime0 = datetime.now()
|
135 |
+
cleaned_data_file = os.path.join(data_output_folder, 'reaction_data_cleaned.rdf')
|
136 |
+
if config['DataCleaning']['clean_reactions']:
|
137 |
+
print('\nCLEAN REACTION DATA ...')
|
138 |
+
|
139 |
+
reactions_cleaner(input_file=reaction_data_file,
|
140 |
+
output_file=cleaned_data_file,
|
141 |
+
num_cpus=config['General']['num_cpus'])
|
142 |
+
|
143 |
+
reaction_data_file = cleaned_data_file
|
144 |
+
print("reactions_cleaner:", datetime.now() - startTime0)
|
145 |
+
|
146 |
+
# reactions data filtering
|
147 |
+
startTime0 = datetime.now()
|
148 |
+
if config['DataCleaning']['filter_reactions']:
|
149 |
+
print('\nFILTER REACTION DATA ...')
|
150 |
+
#
|
151 |
+
filtration_config = ReactionCheckConfig(
|
152 |
+
remove_small_molecules=False,
|
153 |
+
small_molecules_config=SmallMoleculesConfig(limit=6),
|
154 |
+
dynamic_bonds_config=DynamicBondsConfig(min_bonds_number=1, max_bonds_number=6),
|
155 |
+
no_reaction_config=NoReactionConfig(),
|
156 |
+
multi_center_config=MultiCenterConfig(),
|
157 |
+
wrong_ch_breaking_config=WrongCHBreakingConfig(),
|
158 |
+
cc_sp3_breaking_config=CCsp3BreakingConfig(),
|
159 |
+
cc_ring_breaking_config=CCRingBreakingConfig()
|
160 |
+
)
|
161 |
+
|
162 |
+
filtered_data_file = os.path.join(data_output_folder, 'reaction_data_filtered.rdf')
|
163 |
+
filter_reactions(config=filtration_config,
|
164 |
+
reaction_database_path=reaction_data_file,
|
165 |
+
result_directory_path=data_output_folder,
|
166 |
+
result_reactions_file_name='reaction_data_filtered',
|
167 |
+
num_cpus=config['General']['num_cpus'],
|
168 |
+
batch_size=100)
|
169 |
+
|
170 |
+
reaction_data_file = filtered_data_file
|
171 |
+
print("filter_reactions:", datetime.now() - startTime0)
|
172 |
+
|
173 |
+
# standardize building blocks
|
174 |
+
startTime0 = datetime.now()
|
175 |
+
if config['DataCleaning']['standardize_building_blocks']:
|
176 |
+
print('\nSTANDARDIZE BUILDING BLOCKS ...')
|
177 |
+
|
178 |
+
standardize_building_blocks(config['InputData']['building_blocks_path'],
|
179 |
+
config['InputData']['building_blocks_path'])
|
180 |
+
print("standardize_building_blocks:", datetime.now() - startTime0)
|
181 |
+
|
182 |
+
# reaction rules extraction
|
183 |
+
startTime0 = datetime.now()
|
184 |
+
print('\nEXTRACT REACTION RULES ...')
|
185 |
+
|
186 |
+
rules_output_folder = os.path.join(config['General']['results_root'], 'reaction_rules')
|
187 |
+
Path(rules_output_folder).mkdir(parents=True, exist_ok=True)
|
188 |
+
reaction_rules_path = os.path.join(rules_output_folder, 'reaction_rules_filtered.pickle')
|
189 |
+
config['InputData']['reaction_rules_path'] = reaction_rules_path
|
190 |
+
|
191 |
+
extract_rules_from_reactions(config=config,
|
192 |
+
reaction_file=reaction_data_file,
|
193 |
+
results_root=rules_output_folder,
|
194 |
+
num_cpus=config['General']['num_cpus'])
|
195 |
+
print("extract_rules_from_reactions:", datetime.now() - startTime0)
|
196 |
+
|
197 |
+
# create policy network dataset
|
198 |
+
startTime0 = datetime.now()
|
199 |
+
print('\nCREATE POLICY NETWORK DATASET ...')
|
200 |
+
policy_output_folder = os.path.join(config['General']['results_root'], 'policy_network')
|
201 |
+
Path(policy_output_folder).mkdir(parents=True, exist_ok=True)
|
202 |
+
policy_data_file = os.path.join(policy_output_folder, 'policy_dataset.pt')
|
203 |
+
|
204 |
+
if config['PolicyNetwork']['policy_type'] == 'ranking':
|
205 |
+
molecules_or_reactions_path = reaction_data_file
|
206 |
+
elif config['PolicyNetwork']['policy_type'] == 'filtering':
|
207 |
+
molecules_or_reactions_path = config['InputData']['policy_data_path']
|
208 |
+
else:
|
209 |
+
raise ValueError(
|
210 |
+
"Invalid policy_type. Allowed values are 'ranking', 'filtering'."
|
211 |
+
)
|
212 |
+
|
213 |
+
datamodule = create_policy_dataset(reaction_rules_path=reaction_rules_path,
|
214 |
+
molecules_or_reactions_path=molecules_or_reactions_path,
|
215 |
+
output_path=policy_data_file,
|
216 |
+
dataset_type=config['PolicyNetwork']['policy_type'],
|
217 |
+
batch_size=config['PolicyNetwork']['batch_size'],
|
218 |
+
num_cpus=config['General']['num_cpus'])
|
219 |
+
print("datamodule:", datetime.now() - startTime0)
|
220 |
+
|
221 |
+
# train policy network
|
222 |
+
startTime0 = datetime.now()
|
223 |
+
print('\nTRAIN POLICY NETWORK ...')
|
224 |
+
policy_config = PolicyNetworkConfig.from_dict(config['PolicyNetwork'])
|
225 |
+
run_policy_training(datamodule, config=policy_config, results_path=policy_output_folder)
|
226 |
+
config['PolicyNetwork']['weights_path'] = os.path.join(policy_output_folder, 'weights', 'policy_network.ckpt')
|
227 |
+
print("run_policy_training:", datetime.now() - startTime0)
|
228 |
+
|
229 |
+
# self-tuning value network training
|
230 |
+
startTime0 = datetime.now()
|
231 |
+
print('\nTRAIN VALUE NETWORK ...')
|
232 |
+
value_output_folder = os.path.join(config['General']['results_root'], 'value_network')
|
233 |
+
Path(value_output_folder).mkdir(parents=True, exist_ok=True)
|
234 |
+
config['ValueNetwork']['weights_path'] = os.path.join(value_output_folder, 'weights', 'value_network.ckpt')
|
235 |
+
|
236 |
+
run_self_tuning(config, results_root=value_output_folder)
|
237 |
+
print("run_self_tuning:", datetime.now() - startTime0)
|
238 |
+
|
239 |
+
|
240 |
+
if __name__ == '__main__':
|
241 |
+
main()
|
SynTool/interfaces/visualisation.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing functions for analysis and visualization of the built search tree
|
3 |
+
"""
|
4 |
+
|
5 |
+
from itertools import count, islice
|
6 |
+
|
7 |
+
from CGRtools.containers import MoleculeContainer
|
8 |
+
|
9 |
+
from SynTool import Tree
|
10 |
+
from SynTool.utils import path_type
|
11 |
+
|
12 |
+
|
13 |
+
def get_child_nodes(tree, molecule, graph):
|
14 |
+
nodes = []
|
15 |
+
try:
|
16 |
+
graph[molecule]
|
17 |
+
except KeyError:
|
18 |
+
return []
|
19 |
+
for retron in graph[molecule]:
|
20 |
+
temp_obj = {
|
21 |
+
"smiles": str(retron),
|
22 |
+
"type": "mol",
|
23 |
+
"in_stock": str(retron) in tree.building_blocks,
|
24 |
+
}
|
25 |
+
node = get_child_nodes(tree, retron, graph)
|
26 |
+
if node:
|
27 |
+
temp_obj["children"] = [node]
|
28 |
+
nodes.append(temp_obj)
|
29 |
+
return {"type": "reaction", "children": nodes}
|
30 |
+
|
31 |
+
|
32 |
+
def extract_routes(tree, extended=False):
|
33 |
+
"""
|
34 |
+
The function takes the target and the dictionary of
|
35 |
+
successors and predecessors and returns a list of dictionaries that contain the target
|
36 |
+
and the list of successors
|
37 |
+
:return: A list of dictionaries. Each dictionary contains a target, a list of children, and a
|
38 |
+
boolean indicating whether the target is in building_blocks.
|
39 |
+
"""
|
40 |
+
target = tree.nodes[1].retrons_to_expand[0].molecule
|
41 |
+
target_in_stock = tree.nodes[1].curr_retron.is_building_block(tree.building_blocks)
|
42 |
+
# Append encoded routes to list
|
43 |
+
paths_block = []
|
44 |
+
winning_nodes = []
|
45 |
+
if extended:
|
46 |
+
# Gather paths
|
47 |
+
for i, node in tree.nodes.items():
|
48 |
+
if node.is_solved():
|
49 |
+
winning_nodes.append(i)
|
50 |
+
else:
|
51 |
+
winning_nodes = tree.winning_nodes
|
52 |
+
if winning_nodes:
|
53 |
+
for winning_node in winning_nodes:
|
54 |
+
# Create graph for route
|
55 |
+
nodes = tree.path_to_node(winning_node)
|
56 |
+
graph, pred = {}, {}
|
57 |
+
for before, after in zip(nodes, nodes[1:]):
|
58 |
+
before = before.curr_retron.molecule
|
59 |
+
graph[before] = after = [x.molecule for x in after.new_retrons]
|
60 |
+
for x in after:
|
61 |
+
pred[x] = before
|
62 |
+
|
63 |
+
paths_block.append({"type": "mol", "smiles": str(target),
|
64 |
+
"in_stock": target_in_stock,
|
65 |
+
"children": [get_child_nodes(tree, target, graph)]})
|
66 |
+
else:
|
67 |
+
paths_block = [{"type": "mol", "smiles": str(target), "in_stock": target_in_stock, "children": []}]
|
68 |
+
return paths_block
|
69 |
+
|
70 |
+
|
71 |
+
def path_graph(tree, node: int) -> str:
|
72 |
+
"""
|
73 |
+
Visualizes reaction path
|
74 |
+
|
75 |
+
:param node: node id
|
76 |
+
:type node: int
|
77 |
+
:return: The SVG string.
|
78 |
+
"""
|
79 |
+
nodes = tree.path_to_node(node)
|
80 |
+
# Set up node_id types for different box colors
|
81 |
+
for node in nodes:
|
82 |
+
for retron in node.new_retrons:
|
83 |
+
retron._molecule.meta["status"] = "instock" if retron.is_building_block(
|
84 |
+
tree.building_blocks) else "mulecule"
|
85 |
+
nodes[0].curr_retron._molecule.meta["status"] = "target"
|
86 |
+
# Box colors
|
87 |
+
box_colors = {"target": "#98EEFF", # 152, 238, 255
|
88 |
+
"mulecule": "#F0AB90", # 240, 171, 144
|
89 |
+
"instock": "#9BFAB3", # 155, 250, 179
|
90 |
+
}
|
91 |
+
|
92 |
+
# first column is target
|
93 |
+
# second column are first new retrons_to_expand
|
94 |
+
columns = [[nodes[0].curr_retron.molecule], [x.molecule for x in nodes[1].new_retrons], ]
|
95 |
+
pred = {x: 0 for x in range(1, len(columns[1]) + 1)}
|
96 |
+
cx = [n for n, x in enumerate(nodes[1].new_retrons, 1) if not x.is_building_block(tree.building_blocks)]
|
97 |
+
size = len(cx)
|
98 |
+
nodes = iter(nodes[2:])
|
99 |
+
cy = count(len(columns[1]) + 1)
|
100 |
+
while size:
|
101 |
+
layer = []
|
102 |
+
for s in islice(nodes, size):
|
103 |
+
n = cx.pop(0)
|
104 |
+
for x in s.new_retrons:
|
105 |
+
layer.append(x)
|
106 |
+
m = next(cy)
|
107 |
+
if not x.is_building_block(tree.building_blocks):
|
108 |
+
cx.append(m)
|
109 |
+
pred[m] = n
|
110 |
+
size = len(cx)
|
111 |
+
columns.append([x.molecule for x in layer])
|
112 |
+
|
113 |
+
columns = [columns[::-1] for columns in columns[::-1]] # Reverse array to make retrosynthetic graph
|
114 |
+
pred = tuple( # Change dict to tuple to make multiple retrons_to_expand available
|
115 |
+
(abs(source - len(pred)), abs(target - len(pred))) for target, source in pred.items())
|
116 |
+
|
117 |
+
# now we have columns for visualizing
|
118 |
+
# lets start recalculate XY
|
119 |
+
x_shift = 0.0
|
120 |
+
c_max_x = 0.0
|
121 |
+
c_max_y = 0.0
|
122 |
+
render = []
|
123 |
+
cx = count()
|
124 |
+
cy = count()
|
125 |
+
arrow_points = {}
|
126 |
+
for ms in columns:
|
127 |
+
heights = []
|
128 |
+
for m in ms:
|
129 |
+
m.clean2d()
|
130 |
+
# X-shift for target
|
131 |
+
min_x = min(x for x, y in m._plane.values()) - x_shift
|
132 |
+
min_y = min(y for x, y in m._plane.values())
|
133 |
+
m._plane = {n: (x - min_x, y - min_y) for n, (x, y) in m._plane.items()}
|
134 |
+
max_x = max(x for x, y in m._plane.values())
|
135 |
+
if max_x > c_max_x:
|
136 |
+
c_max_x = max_x
|
137 |
+
arrow_points[next(cx)] = [x_shift, max_x]
|
138 |
+
heights.append(max(y for x, y in m._plane.values()))
|
139 |
+
|
140 |
+
x_shift = c_max_x + 5.0 # between columns gap
|
141 |
+
# calculate Y-shift
|
142 |
+
y_shift = sum(heights) + 3.0 * (len(heights) - 1)
|
143 |
+
if y_shift > c_max_y:
|
144 |
+
c_max_y = y_shift
|
145 |
+
y_shift /= 2.0
|
146 |
+
for m, h in zip(ms, heights):
|
147 |
+
m._plane = {n: (x, y - y_shift) for n, (x, y) in m._plane.items()}
|
148 |
+
|
149 |
+
# Calculate coordinates for boxes
|
150 |
+
max_x = max(x for x, y in m._plane.values()) + 0.9 # Max x
|
151 |
+
min_x = min(x for x, y in m._plane.values()) - 0.6 # Min x
|
152 |
+
max_y = -(max(y for x, y in m._plane.values()) + 0.45) # Max y
|
153 |
+
min_y = -(min(y for x, y in m._plane.values()) - 0.45) # Min y
|
154 |
+
x_delta = abs(max_x - min_x)
|
155 |
+
y_delta = abs(max_y - min_y)
|
156 |
+
box = (
|
157 |
+
f'<rect x="{min_x}" y="{max_y}" rx="{y_delta * 0.1}" ry="{y_delta * 0.1}" width="{x_delta}" height="{y_delta}"'
|
158 |
+
f' stroke="black" stroke-width=".0025" fill="{box_colors[m.meta["status"]]}" fill-opacity="0.30"/>')
|
159 |
+
arrow_points[next(cy)].append(y_shift - h / 2.0)
|
160 |
+
y_shift -= h + 3.0
|
161 |
+
depicted_molecule = list(m.depict(embedding=True))[:3]
|
162 |
+
depicted_molecule.append(box)
|
163 |
+
render.append(depicted_molecule)
|
164 |
+
|
165 |
+
# Calculate mid-X coordinate to draw square arrows
|
166 |
+
graph = {}
|
167 |
+
for s, p in pred:
|
168 |
+
try:
|
169 |
+
graph[s].append(p)
|
170 |
+
except KeyError:
|
171 |
+
graph[s] = [p]
|
172 |
+
for s, ps in graph.items():
|
173 |
+
mid_x = float("-inf")
|
174 |
+
for p in ps:
|
175 |
+
s_min_x, s_max, s_y = arrow_points[s][:3] # s
|
176 |
+
p_min_x, p_max, p_y = arrow_points[p][:3] # p
|
177 |
+
p_max += 1
|
178 |
+
mid = p_max + (s_min_x - p_max) / 3
|
179 |
+
if mid > mid_x:
|
180 |
+
mid_x = mid
|
181 |
+
for p in ps:
|
182 |
+
arrow_points[p].append(mid_x)
|
183 |
+
|
184 |
+
config = MoleculeContainer._render_config
|
185 |
+
font_size = config["font_size"]
|
186 |
+
font125 = 1.25 * font_size
|
187 |
+
width = c_max_x + 4.0 * font_size # 3.0 by default
|
188 |
+
height = c_max_y + 3.5 * font_size # 2.5 by default
|
189 |
+
box_y = height / 2.0
|
190 |
+
svg = [f'<svg width="{0.6 * width:.2f}cm" height="{0.6 * height:.2f}cm" '
|
191 |
+
f'viewBox="{-font125:.2f} {-box_y:.2f} {width:.2f} '
|
192 |
+
f'{height:.2f}" xmlns="http://www.w3.org/2000/svg" version="1.1">',
|
193 |
+
' <defs>\n <marker id="arrow" markerWidth="10" markerHeight="10" '
|
194 |
+
'refX="0" refY="3" orient="auto">\n <path d="M0,0 L0,6 L9,3"/>\n </marker>\n </defs>', ]
|
195 |
+
|
196 |
+
for s, p in pred:
|
197 |
+
"""
|
198 |
+
(x1, y1) = (p_max, p_y)
|
199 |
+
(x2, y2) = (s_min_x, s_y)
|
200 |
+
polyline: (x1 y1, x2 y2, x3 y3, ..., xN yN)
|
201 |
+
"""
|
202 |
+
s_min_x, s_max, s_y = arrow_points[s][:3]
|
203 |
+
p_min_x, p_max, p_y = arrow_points[p][:3]
|
204 |
+
p_max += 1
|
205 |
+
mid_x = arrow_points[p][-1] # p_max + (s_min_x - p_max) / 3
|
206 |
+
"""print(f"s_min_x: {s_min_x}, s_max: {s_max}, s_y: {s_y}")
|
207 |
+
print(f"p_min_x: {p_min_x}, p_max: {p_max}, p_y: {p_y}")
|
208 |
+
print(f"mid_x: {mid_x}\n")"""
|
209 |
+
|
210 |
+
arrow = f""" <polyline points="{p_max:.2f} {p_y:.2f}, {mid_x:.2f} {p_y:.2f}, {mid_x:.2f} {s_y:.2f}, {s_min_x - 1.:.2f} {s_y:.2f}"
|
211 |
+
fill="none" stroke="black" stroke-width=".04" marker-end="url(#arrow)"/>"""
|
212 |
+
if p_y != s_y:
|
213 |
+
arrow += f' <circle cx="{mid_x}" cy="{p_y}" r="0.1"/>'
|
214 |
+
svg.append(arrow)
|
215 |
+
for atoms, bonds, masks, box in render:
|
216 |
+
molecule_svg = MoleculeContainer._graph_svg(atoms, bonds, masks, -font125, -box_y, width, height)
|
217 |
+
molecule_svg.insert(1, box)
|
218 |
+
svg.extend(molecule_svg)
|
219 |
+
svg.append("</svg>")
|
220 |
+
return "\n".join(svg)
|
221 |
+
|
222 |
+
|
223 |
+
def to_table(tree: Tree, html_path: path_type, aam: bool = False, extended=False, integration: bool = False):
|
224 |
+
"""
|
225 |
+
Write an HTML page with the synthesis paths in SVG format and corresponding reactions in SMILES format
|
226 |
+
|
227 |
+
:param tree: # TODO
|
228 |
+
:param extended: # TODO
|
229 |
+
:param html_path: Path to save the HTML molecules_path, if None returns the html without saving it
|
230 |
+
:type html_path: str (optional)
|
231 |
+
:param aam: depict atom-to-atom mapping
|
232 |
+
:type aam: bool (optional)
|
233 |
+
:param integration: Whenever to output the full html file (False) or only the body (True)
|
234 |
+
:type integration: bool
|
235 |
+
"""
|
236 |
+
if aam:
|
237 |
+
MoleculeContainer.depict_settings(aam=True)
|
238 |
+
else:
|
239 |
+
MoleculeContainer.depict_settings(aam=False)
|
240 |
+
|
241 |
+
paths = []
|
242 |
+
if extended:
|
243 |
+
# Gather paths
|
244 |
+
for idx, node in tree.nodes.items():
|
245 |
+
if node.is_solved():
|
246 |
+
paths.append(idx)
|
247 |
+
else:
|
248 |
+
paths = tree.winning_nodes
|
249 |
+
# HTML Tags
|
250 |
+
th = '<th style="text-align: left; background-color:#978785; border: 1px solid black; border-spacing: 0">'
|
251 |
+
td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
|
252 |
+
font_red = "<font color='red' style='font-weight: bold'>"
|
253 |
+
font_green = "<font color='light-green' style='font-weight: bold'>"
|
254 |
+
font_head = "<font style='font-weight: bold; font-size: 18px'>"
|
255 |
+
font_normal = "<font style='font-weight: normal; font-size: 18px'>"
|
256 |
+
font_close = "</font>"
|
257 |
+
|
258 |
+
template_begin = """
|
259 |
+
<!doctype html>
|
260 |
+
<html lang="en">
|
261 |
+
<head>
|
262 |
+
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css"
|
263 |
+
rel="stylesheet"
|
264 |
+
integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
|
265 |
+
crossorigin="anonymous">
|
266 |
+
<script
|
267 |
+
src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"
|
268 |
+
integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
|
269 |
+
crossorigin="anonymous">
|
270 |
+
</script>
|
271 |
+
<meta charset="utf-8">
|
272 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
273 |
+
<title>Predicted Paths Report</title>
|
274 |
+
<meta name="description" content="A simple HTML5 Template for new projects.">
|
275 |
+
<meta name="author" content="SitePoint">
|
276 |
+
</head>
|
277 |
+
<body>
|
278 |
+
"""
|
279 |
+
template_end = """
|
280 |
+
</body>
|
281 |
+
</html>
|
282 |
+
"""
|
283 |
+
# SVG Template
|
284 |
+
# box_mark = """
|
285 |
+
# <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg">
|
286 |
+
# <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
|
287 |
+
# </svg>
|
288 |
+
# """
|
289 |
+
# table = f"<table><thead><{th}>Retrosynthetic Routes</th></thead><tbody>"
|
290 |
+
table = """<table class="table table-striped table-hover caption-top">"""
|
291 |
+
if not integration:
|
292 |
+
table += "<caption><h3>Retrosynthetic Routes Report</h3></caption><tbody>"
|
293 |
+
else:
|
294 |
+
table += "<tbody>"
|
295 |
+
|
296 |
+
# Gather path data
|
297 |
+
table += f"<tr>{td}{font_normal}Target Molecule: {str(tree.nodes[1].curr_retron)}{font_close}</td></tr>"
|
298 |
+
table += (f"<tr>{td}{font_normal}Tree Size: {len(tree)}{font_close} nodes</td></tr>")
|
299 |
+
table += f"<tr>{td}{font_normal}Number of visited nodes: {len(tree.visited_nodes)}{font_close}</td></tr>"
|
300 |
+
table += f"<tr>{td}{font_normal}Found paths: {len(paths)}{font_close}</td></tr>"
|
301 |
+
table += f"<tr>{td}{font_normal}Time: {round(tree.curr_time, 4)}{font_close} seconds</td></tr>"
|
302 |
+
table += f"""\
|
303 |
+
<tr>{td} \
|
304 |
+
<svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg"> \
|
305 |
+
<circle cx="0.5" cy="0.5" r="0.5" fill="rgb(152, 238, 255)" fill-opacity="0.35" /></svg> \
|
306 |
+
Target Molecule \
|
307 |
+
\
|
308 |
+
<svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg"> \
|
309 |
+
<circle cx="0.5" cy="0.5" r="0.5" fill="rgb(240, 171, 144)" fill-opacity="0.35" /></svg> \
|
310 |
+
Molecule Not In Stock \
|
311 |
+
\
|
312 |
+
<svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg"> \
|
313 |
+
<circle cx="0.5" cy="0.5" r="0.5" fill="rgb(155, 250, 179)" fill-opacity="0.35" /></svg> \
|
314 |
+
Molecule In Stock \
|
315 |
+
\
|
316 |
+
</td></tr> \
|
317 |
+
"""
|
318 |
+
|
319 |
+
for path in paths:
|
320 |
+
svg = path_graph(tree, path) # Get SVG
|
321 |
+
full_path = tree.synthesis_path(path) # Get Path
|
322 |
+
# Write SMILES of all reactions in synthesis path
|
323 |
+
step = 1
|
324 |
+
reactions = ""
|
325 |
+
for synth_step in full_path:
|
326 |
+
reactions += f"<b>Step {step}:</b> {str(synth_step)}<br>"
|
327 |
+
step += 1
|
328 |
+
# Concatenate all content of path
|
329 |
+
path_score = round(tree.path_score(path), 3)
|
330 |
+
table += (f'<tr style="line-height: 250%">{td}{font_head}Path {path}; '
|
331 |
+
f"Steps: {len(full_path)}; "
|
332 |
+
f"Cumulated nodes' value: {path_score}{font_close}</td></tr>")
|
333 |
+
# f"Cumulated nodes' value: {node._probabilities[path]}{font_close}</td></tr>"
|
334 |
+
table += f"<tr>{td}{svg}</td></tr>"
|
335 |
+
table += f"<tr>{td}{reactions}</td></tr>"
|
336 |
+
table += "</tbody>"
|
337 |
+
|
338 |
+
# Save or display output
|
339 |
+
if not html_path:
|
340 |
+
return table if integration else template_begin + table + template_end
|
341 |
+
|
342 |
+
output = html_path
|
343 |
+
with open(output, "w") as html_file:
|
344 |
+
html_file.write(template_begin)
|
345 |
+
html_file.write(table)
|
346 |
+
html_file.write(template_end)
|
SynTool/mcts/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .node import *
|
2 |
+
from .tree import *
|
3 |
+
from CGRtools.containers import MoleculeContainer
|
4 |
+
|
5 |
+
MoleculeContainer.depict_settings(aam=False)
|
6 |
+
|
7 |
+
__all__ = ["Tree", "Node"]
|
SynTool/mcts/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (322 Bytes). View file
|
|
SynTool/mcts/__pycache__/evaluation.cpython-310.pyc
ADDED
Binary file (2.21 kB). View file
|
|
SynTool/mcts/__pycache__/expansion.cpython-310.pyc
ADDED
Binary file (2.86 kB). View file
|
|
SynTool/mcts/__pycache__/node.cpython-310.pyc
ADDED
Binary file (2.15 kB). View file
|
|
SynTool/mcts/__pycache__/search.cpython-310.pyc
ADDED
Binary file (4.68 kB). View file
|
|
SynTool/mcts/__pycache__/tree.cpython-310.pyc
ADDED
Binary file (18.7 kB). View file
|
|
SynTool/mcts/evaluation.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing a class that represents a value function for prediction of synthesisablity
|
3 |
+
of new nodes in the search tree
|
4 |
+
"""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
from SynTool.chem.retron import compose_retrons
|
12 |
+
from SynTool.ml.networks.value import ValueNetwork
|
13 |
+
from SynTool.ml.training import mol_to_pyg
|
14 |
+
|
15 |
+
|
16 |
+
class ValueFunction:
|
17 |
+
"""
|
18 |
+
Value function based on value neural network for node evaluation (synthesisability prediction) in MCTS
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, weights_path: str) -> None:
|
22 |
+
"""
|
23 |
+
The value function predicts the probability to synthesize the target molecule with available building blocks
|
24 |
+
starting from a given retron.
|
25 |
+
|
26 |
+
:param weights_path: The value network weights location
|
27 |
+
:type weights_path: Path
|
28 |
+
"""
|
29 |
+
|
30 |
+
value_net = ValueNetwork.load_from_checkpoint(
|
31 |
+
weights_path,
|
32 |
+
map_location=torch.device("cpu")
|
33 |
+
)
|
34 |
+
|
35 |
+
self.value_network = value_net.eval()
|
36 |
+
|
37 |
+
def predict_value(self, retrons: list) -> float:
|
38 |
+
"""
|
39 |
+
The function predicts a value based on the given retrons. For prediction, retrons must be composed into a single
|
40 |
+
molecule (product)
|
41 |
+
|
42 |
+
:param retrons: The list of retrons
|
43 |
+
:type retrons: list
|
44 |
+
"""
|
45 |
+
|
46 |
+
molecule = compose_retrons(retrons=retrons, exclude_small=True)
|
47 |
+
pyg_graph = mol_to_pyg(molecule)
|
48 |
+
if pyg_graph:
|
49 |
+
with torch.no_grad():
|
50 |
+
value_pred = self.value_network.forward(pyg_graph)[0].item()
|
51 |
+
else:
|
52 |
+
try:
|
53 |
+
logging.debug(f"Molecule {str(molecule)} was not preprocessed. Giving value equal to -1e6.")
|
54 |
+
except:
|
55 |
+
logging.debug(f"There is a molecule for which SMILES cannot be generated")
|
56 |
+
|
57 |
+
value_pred = -1e6
|
58 |
+
|
59 |
+
return value_pred
|
SynTool/mcts/expansion.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing a class that represents a policy function for node expansion in the search tree
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch_geometric
|
7 |
+
from SynTool.chem.retron import Retron
|
8 |
+
from SynTool.ml.networks.policy import PolicyNetwork
|
9 |
+
from SynTool.ml.training import mol_to_pyg
|
10 |
+
from SynTool.utils.config import PolicyNetworkConfig
|
11 |
+
|
12 |
+
|
13 |
+
class PolicyFunction:
|
14 |
+
"""
|
15 |
+
Policy function based on policy neural network for node expansion in MCTS
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, policy_config: PolicyNetworkConfig, compile: bool = False):
|
19 |
+
"""
|
20 |
+
Initializes the expansion function (ranking or filter policy network).
|
21 |
+
|
22 |
+
:param policy_config: A configuration object settings for the expansion policy
|
23 |
+
:type policy_config: PolicyConfig
|
24 |
+
:param compile: XX # TODO what is compile # TODO2 compile is a bad variable name - is a builtin function name
|
25 |
+
:type compile: bool
|
26 |
+
"""
|
27 |
+
|
28 |
+
self.config = policy_config
|
29 |
+
|
30 |
+
policy_net = PolicyNetwork.load_from_checkpoint(
|
31 |
+
self.config.weights_path,
|
32 |
+
map_location=torch.device("cpu"),
|
33 |
+
batch_size=1,
|
34 |
+
dropout=0
|
35 |
+
)
|
36 |
+
|
37 |
+
policy_net = policy_net.eval()
|
38 |
+
if compile:
|
39 |
+
self.policy_net = torch_geometric.compile(policy_net, dynamic=True)
|
40 |
+
else:
|
41 |
+
self.policy_net = policy_net
|
42 |
+
|
43 |
+
def predict_reaction_rules(self, retron: Retron, reaction_rules: list): # TODO what is output - finish annotation
|
44 |
+
"""
|
45 |
+
The policy function predicts the list of reaction rules given a retron.
|
46 |
+
|
47 |
+
:param retron: The current retron for which the reaction rules are predicted
|
48 |
+
:type retron: Retron
|
49 |
+
:param reaction_rules: The list of reaction rules from which applicable reaction rules are predicted and selected.
|
50 |
+
:type reaction_rules: list
|
51 |
+
"""
|
52 |
+
|
53 |
+
pyg_graph = mol_to_pyg(retron.molecule, canonicalize=False)
|
54 |
+
if pyg_graph:
|
55 |
+
with torch.no_grad():
|
56 |
+
if self.policy_net.policy_type == "filtering":
|
57 |
+
probs, priority = self.policy_net.forward(pyg_graph)
|
58 |
+
if self.policy_net.policy_type == "ranking":
|
59 |
+
probs = self.policy_net.forward(pyg_graph)
|
60 |
+
del pyg_graph
|
61 |
+
else:
|
62 |
+
return []
|
63 |
+
|
64 |
+
probs = probs[0].double()
|
65 |
+
if self.policy_net.policy_type == "filtering":
|
66 |
+
priority = priority[0].double()
|
67 |
+
priority_coef = self.config.priority_rules_fraction
|
68 |
+
probs = (1 - priority_coef) * probs + priority_coef * priority
|
69 |
+
|
70 |
+
sorted_probs, sorted_rules = torch.sort(probs, descending=True)
|
71 |
+
sorted_probs, sorted_rules = (
|
72 |
+
sorted_probs[: self.config.top_rules],
|
73 |
+
sorted_rules[: self.config.top_rules],
|
74 |
+
)
|
75 |
+
|
76 |
+
if self.policy_net.policy_type == "filtering":
|
77 |
+
sorted_probs = torch.softmax(sorted_probs, -1)
|
78 |
+
|
79 |
+
sorted_probs, sorted_rules = sorted_probs.tolist(), sorted_rules.tolist()
|
80 |
+
|
81 |
+
for prob, rule_id in zip(sorted_probs, sorted_rules):
|
82 |
+
if prob > self.config.rule_prob_threshold: # TODO it will destroy all search if it is not correct (>0.5)
|
83 |
+
yield prob, reaction_rules[rule_id], rule_id
|
SynTool/mcts/node.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing a class Node that represents a node in the search tree
|
3 |
+
"""
|
4 |
+
|
5 |
+
|
6 |
+
class Node:
|
7 |
+
"""
|
8 |
+
Node class represents a node in the search tree
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(self, retrons_to_expand: tuple = None, new_retrons: tuple = None) -> None:
|
12 |
+
"""
|
13 |
+
The function initializes the new Node object.
|
14 |
+
|
15 |
+
:param retrons_to_expand: The tuple of retrons to be expanded. The first retron in the tuple is the current
|
16 |
+
retron which will be expanded (for which new retrons will be generated by applying the predicted reaction
|
17 |
+
rules). When the first retron has been successfully expanded, the second retron becomes the current retron
|
18 |
+
to be expanded.
|
19 |
+
:param new_retrons: The tuple of new retrons generated by applying the reaction rule. New retrons have already
|
20 |
+
been added to the retrons_to_expand (see Tree._expand_node). Here they are stored for information.
|
21 |
+
"""
|
22 |
+
|
23 |
+
self.retrons_to_expand = retrons_to_expand
|
24 |
+
self.new_retrons = new_retrons
|
25 |
+
|
26 |
+
if len(self.retrons_to_expand) == 0:
|
27 |
+
self.curr_retron = tuple()
|
28 |
+
else:
|
29 |
+
self.curr_retron = self.retrons_to_expand[0]
|
30 |
+
self.next_retrons = self.retrons_to_expand[1:]
|
31 |
+
|
32 |
+
def __len__(self) -> int:
|
33 |
+
"""
|
34 |
+
The number of retrons in this node to expand.
|
35 |
+
"""
|
36 |
+
return len(self.retrons_to_expand)
|
37 |
+
|
38 |
+
def __repr__(self) -> str:
|
39 |
+
"""
|
40 |
+
String representation of the node. Returns the smiles of retrons_to_expand and new_retrons.
|
41 |
+
"""
|
42 |
+
return f"retrons_to_expand: {self.retrons_to_expand}\nnew_retrons: {self.new_retrons}"
|
43 |
+
|
44 |
+
def is_solved(self) -> bool:
|
45 |
+
"""
|
46 |
+
Is terminal node. There are not retrons for expansion.
|
47 |
+
"""
|
48 |
+
|
49 |
+
return len(self.retrons_to_expand) == 0
|
SynTool/mcts/search.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing functions for running tree search for the set of target molecules
|
3 |
+
"""
|
4 |
+
|
5 |
+
import csv
|
6 |
+
import json
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from SynTool.interfaces.visualisation import to_table, extract_routes
|
12 |
+
from SynTool.mcts.tree import Tree, TreeConfig
|
13 |
+
from SynTool.mcts.evaluation import ValueFunction
|
14 |
+
from SynTool.mcts.expansion import PolicyFunction
|
15 |
+
from SynTool.utils import path_type
|
16 |
+
from SynTool.utils.files import MoleculeReader
|
17 |
+
from SynTool.utils.config import PolicyNetworkConfig
|
18 |
+
|
19 |
+
|
20 |
+
def extract_tree_stats(tree, target):
|
21 |
+
"""
|
22 |
+
Collects various statistics from a tree and returns them in a dictionary format
|
23 |
+
|
24 |
+
:param tree: The retro tree.
|
25 |
+
:param target: The target molecule or compound that you want to search for in the tree. It is
|
26 |
+
expected to be a string representing the SMILES notation of the target molecule
|
27 |
+
:return: A dictionary with the calculated statistics
|
28 |
+
"""
|
29 |
+
newick_tree, newick_meta = tree.newickify(visits_threshold=0)
|
30 |
+
newick_meta_line = ";".join([f"{nid},{v[0]},{v[1]},{v[2]}" for nid, v in newick_meta.items()])
|
31 |
+
return {
|
32 |
+
"target_smiles": str(target),
|
33 |
+
"tree_size": len(tree),
|
34 |
+
"search_time": round(tree.curr_time, 1),
|
35 |
+
"found_paths": len(tree.winning_nodes),
|
36 |
+
"newick_tree": newick_tree,
|
37 |
+
"newick_meta": newick_meta_line,
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
def tree_search(
|
42 |
+
targets_path: path_type,
|
43 |
+
tree_config: TreeConfig,
|
44 |
+
policy_config: PolicyNetworkConfig,
|
45 |
+
reaction_rules_path: path_type,
|
46 |
+
building_blocks_path: path_type,
|
47 |
+
policy_weights_path: path_type = None, # TODO not used
|
48 |
+
value_weights_path: path_type = None,
|
49 |
+
results_root: path_type = "search_results"
|
50 |
+
):
|
51 |
+
"""
|
52 |
+
Performs a tree search on a set of target molecules using specified configuration and rules,
|
53 |
+
logging the results and statistics.
|
54 |
+
|
55 |
+
:param tree_config: The config object containing the configuration for the tree search.
|
56 |
+
:param policy_config: The config object containing the configuration for the policy.
|
57 |
+
:param reaction_rules_path: The path to the file containing reaction rules.
|
58 |
+
:param building_blocks_path: The path to the file containing building blocks.
|
59 |
+
:param targets_path: The path to the file containing the target molecules (in SDF or SMILES format).
|
60 |
+
:param value_weights_path: The path to the file containing value weights (optional).
|
61 |
+
:param results_root: The path to the directory where the results of the tree search will be saved. Defaults to 'search_results/'.
|
62 |
+
:param retropaths_files_name: The base name for the files that will be generated to store the retro paths. Defaults to 'retropath'. #TODO arg dont exist
|
63 |
+
|
64 |
+
This function configures and executes a tree search algorithm, leveraging reaction rules and building blocks
|
65 |
+
to find synthetic pathways for given target molecules. The results, including paths and statistics, are
|
66 |
+
saved in the specified directory. Logging is used to record the process and any issues encountered.
|
67 |
+
"""
|
68 |
+
|
69 |
+
targets_file = Path(targets_path)
|
70 |
+
|
71 |
+
# results folder
|
72 |
+
results_root = Path(results_root)
|
73 |
+
if not results_root.exists():
|
74 |
+
results_root.mkdir()
|
75 |
+
|
76 |
+
# output files
|
77 |
+
stats_file = results_root.joinpath("tree_search_stats.csv")
|
78 |
+
paths_file = results_root.joinpath("extracted_paths.json")
|
79 |
+
retropaths_folder = results_root.joinpath("retropaths")
|
80 |
+
retropaths_folder.mkdir(exist_ok=True)
|
81 |
+
|
82 |
+
# stats header
|
83 |
+
stats_header = ["target_smiles", "tree_size", "search_time",
|
84 |
+
"found_paths", "newick_tree", "newick_meta"]
|
85 |
+
|
86 |
+
# config
|
87 |
+
policy_function = PolicyFunction(policy_config=policy_config)
|
88 |
+
if tree_config.evaluation_type == 'gcn':
|
89 |
+
value_function = ValueFunction(weights_path=value_weights_path)
|
90 |
+
else:
|
91 |
+
value_function = None
|
92 |
+
|
93 |
+
# run search
|
94 |
+
n_solved = 0
|
95 |
+
extracted_paths = []
|
96 |
+
with MoleculeReader(targets_file) as targets_path, open(stats_file, "w", newline="\n") as csvfile:
|
97 |
+
statswriter = csv.DictWriter(csvfile, delimiter=",", fieldnames=stats_header)
|
98 |
+
statswriter.writeheader()
|
99 |
+
|
100 |
+
for ti, target in tqdm(enumerate(targets_path), total=len(targets_path)):
|
101 |
+
|
102 |
+
try:
|
103 |
+
# run search
|
104 |
+
tree = Tree(
|
105 |
+
target=target,
|
106 |
+
tree_config=tree_config,
|
107 |
+
reaction_rules_path=reaction_rules_path,
|
108 |
+
building_blocks_path=building_blocks_path,
|
109 |
+
policy_function=policy_function,
|
110 |
+
value_function=value_function,
|
111 |
+
)
|
112 |
+
_ = list(tree)
|
113 |
+
|
114 |
+
except:
|
115 |
+
continue
|
116 |
+
|
117 |
+
n_solved += bool(tree.winning_nodes)
|
118 |
+
|
119 |
+
# extract routes
|
120 |
+
extracted_paths.append(extract_routes(tree))
|
121 |
+
|
122 |
+
# retropaths
|
123 |
+
retropaths_file = retropaths_folder.joinpath(f"retropaths_target_{ti}.html")
|
124 |
+
to_table(tree, retropaths_file, extended=True)
|
125 |
+
|
126 |
+
# stats
|
127 |
+
statswriter.writerow(extract_tree_stats(tree, target))
|
128 |
+
csvfile.flush()
|
129 |
+
|
130 |
+
#
|
131 |
+
with open(paths_file, 'w') as f:
|
132 |
+
json.dump(extracted_paths, f)
|
133 |
+
|
134 |
+
print(f"Solved number of target molecules: {n_solved}")
|
135 |
+
|
SynTool/mcts/tree.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module containing a class Tree that used for tree search of retrosynthetic paths
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
from collections import deque, defaultdict
|
7 |
+
from math import sqrt
|
8 |
+
from random import choice, uniform
|
9 |
+
from time import time
|
10 |
+
from typing import Dict, Set, List, Tuple
|
11 |
+
|
12 |
+
from CGRtools.containers import MoleculeContainer
|
13 |
+
from CGRtools import smiles
|
14 |
+
from numpy.random import uniform
|
15 |
+
from tqdm.auto import tqdm
|
16 |
+
from SynTool.utils.loading import load_building_blocks, load_reaction_rules
|
17 |
+
from SynTool.chem.reaction import Reaction, apply_reaction_rule
|
18 |
+
from SynTool.chem.retron import Retron
|
19 |
+
from SynTool.mcts.evaluation import ValueFunction
|
20 |
+
from SynTool.mcts.expansion import PolicyFunction
|
21 |
+
from SynTool.mcts.node import Node
|
22 |
+
from SynTool.utils.config import TreeConfig
|
23 |
+
|
24 |
+
|
25 |
+
class Tree:
|
26 |
+
"""
|
27 |
+
Tree class with attributes and methods for Monte-Carlo tree search
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
target: MoleculeContainer,
|
33 |
+
tree_config: TreeConfig,
|
34 |
+
reaction_rules_path: str,
|
35 |
+
building_blocks_path: str,
|
36 |
+
policy_function: PolicyFunction,
|
37 |
+
value_function: ValueFunction = None,
|
38 |
+
):
|
39 |
+
"""
|
40 |
+
The function initializes a tree object with optional parameters for tree search for target molecule.
|
41 |
+
|
42 |
+
:param target: a target molecule for retrosynthesis paths search
|
43 |
+
:type target: MoleculeContainer
|
44 |
+
:param tree_config: a tree configuration file for retrosynthesis paths search
|
45 |
+
:type tree_config: TreeConfig
|
46 |
+
:param reaction_rules_path: a path for reaction rules file
|
47 |
+
:type reaction_rules_path: str
|
48 |
+
:param building_blocks_path: a path for building blocks file
|
49 |
+
:type building_blocks_path: str
|
50 |
+
:param policy_function: a policy function object
|
51 |
+
:type policy_function: PolicyFunction
|
52 |
+
:param value_function: a value function object
|
53 |
+
:type value_function: ValueFunction
|
54 |
+
"""
|
55 |
+
|
56 |
+
# config parameters
|
57 |
+
self.config = tree_config
|
58 |
+
|
59 |
+
# check target
|
60 |
+
if isinstance(target, str):
|
61 |
+
target = smiles(target)
|
62 |
+
assert (bool(target)), "Target is not defined, is not a MoleculeContainer or have no atoms"
|
63 |
+
if target:
|
64 |
+
target.canonicalize()
|
65 |
+
|
66 |
+
target_retron = Retron(target, canonicalize=True)
|
67 |
+
target_retron.prev_retrons.append(Retron(target, canonicalize=True))
|
68 |
+
target_node = Node(retrons_to_expand=(target_retron,), new_retrons=(target_retron,))
|
69 |
+
|
70 |
+
# tree structure init
|
71 |
+
self.nodes: Dict[int, Node] = {1: target_node}
|
72 |
+
self.parents: Dict[int, int] = {1: 0}
|
73 |
+
self.children: Dict[int, Set[int]] = {1: set()}
|
74 |
+
self.winning_nodes: List[int] = list()
|
75 |
+
self.visited_nodes: Set[int] = set()
|
76 |
+
self.expanded_nodes: Set[int] = set()
|
77 |
+
self.nodes_visit: Dict[int, int] = {1: 0}
|
78 |
+
self.nodes_depth: Dict[int, int] = {1: 0}
|
79 |
+
self.nodes_prob: Dict[int, float] = {1: 0.0}
|
80 |
+
self.nodes_init_value: Dict[int, float] = {1: 0.0}
|
81 |
+
self.nodes_total_value: Dict[int, float] = {1: 0.0}
|
82 |
+
|
83 |
+
# tree building limits
|
84 |
+
self.curr_iteration: int = 0
|
85 |
+
self.curr_tree_size: int = 2
|
86 |
+
self.curr_time: float = 2
|
87 |
+
|
88 |
+
# utils
|
89 |
+
self._tqdm = None
|
90 |
+
|
91 |
+
# policy and value functions
|
92 |
+
self.policy_function = policy_function
|
93 |
+
if self.config.evaluation_type == "gcn":
|
94 |
+
if value_function is None:
|
95 |
+
raise ValueError(
|
96 |
+
"Value function not specified while evaluation mode is 'gcn'"
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
self.value_function = value_function
|
100 |
+
|
101 |
+
# building blocks and reaction reaction_rules
|
102 |
+
self.reaction_rules = load_reaction_rules(reaction_rules_path)
|
103 |
+
self.building_blocks = load_building_blocks(building_blocks_path)
|
104 |
+
|
105 |
+
def __len__(self) -> int:
|
106 |
+
"""
|
107 |
+
Returns the current size (number of nodes) of a Tree.
|
108 |
+
"""
|
109 |
+
|
110 |
+
return self.curr_tree_size - 1
|
111 |
+
|
112 |
+
def __iter__(self) -> "Tree": # TODO what is annotation "Tree" -> Tree ?
|
113 |
+
"""
|
114 |
+
The function is defining an iterator for a Tree object. Also needed for the bar progress display.
|
115 |
+
"""
|
116 |
+
|
117 |
+
if not self._tqdm:
|
118 |
+
self._start_time = time()
|
119 |
+
self._tqdm = tqdm(
|
120 |
+
total=self.config.max_iterations, disable=self.config.silent
|
121 |
+
)
|
122 |
+
return self
|
123 |
+
|
124 |
+
def __repr__(self) -> str:
|
125 |
+
"""
|
126 |
+
Returns a string representation of a Tree object (target smiles, tree size, and the number of found paths).
|
127 |
+
"""
|
128 |
+
return self.report()
|
129 |
+
|
130 |
+
def __next__(self): # TODO what is return - function annotation ? tuple (bool, [node id])
|
131 |
+
"""
|
132 |
+
The __next__ function is used to do one iteration of the tree building.
|
133 |
+
"""
|
134 |
+
|
135 |
+
if self.nodes[1].curr_retron.is_building_block(self.building_blocks, self.config.min_mol_size):
|
136 |
+
raise StopIteration("Target is building block \n")
|
137 |
+
|
138 |
+
if self.curr_iteration >= self.config.max_iterations:
|
139 |
+
self._tqdm.close()
|
140 |
+
raise StopIteration("Iterations limit exceeded. \n")
|
141 |
+
elif self.curr_tree_size >= self.config.max_tree_size:
|
142 |
+
self._tqdm.close()
|
143 |
+
raise StopIteration("Max tree size exceeded or all possible paths found")
|
144 |
+
elif self.curr_time >= self.config.max_time:
|
145 |
+
self._tqdm.close()
|
146 |
+
raise StopIteration("Time limit exceeded. \n")
|
147 |
+
|
148 |
+
# start new iteration
|
149 |
+
self.curr_iteration += 1
|
150 |
+
self.curr_time = time() - self._start_time
|
151 |
+
self._tqdm.update()
|
152 |
+
|
153 |
+
curr_depth, node_id = 0, 1 # start from the root node_id
|
154 |
+
|
155 |
+
explore_path = True
|
156 |
+
while explore_path:
|
157 |
+
self.visited_nodes.add(node_id)
|
158 |
+
|
159 |
+
if self.nodes_visit[node_id]: # already visited
|
160 |
+
if not self.children[node_id]: # dead node
|
161 |
+
logging.debug(
|
162 |
+
f"Tree search: bumped into node {node_id} which is dead"
|
163 |
+
)
|
164 |
+
self._update_visits(node_id)
|
165 |
+
explore_path = False
|
166 |
+
else:
|
167 |
+
node_id = self._select_node(node_id) # select the child node
|
168 |
+
curr_depth += 1
|
169 |
+
else:
|
170 |
+
if self.nodes[node_id].is_solved(): # found path!
|
171 |
+
self._update_visits(node_id) # this prevents expanding of bb node_id
|
172 |
+
self.winning_nodes.append(node_id)
|
173 |
+
return True, [node_id]
|
174 |
+
|
175 |
+
elif (
|
176 |
+
curr_depth < self.config.max_depth
|
177 |
+
): # expand node if depth limit is not reached
|
178 |
+
self._expand_node(node_id)
|
179 |
+
if not self.children[node_id]: # node was not expanded
|
180 |
+
logging.debug(f"Tree search: node {node_id} was not expanded")
|
181 |
+
value_to_backprop = -1.0
|
182 |
+
else:
|
183 |
+
self.expanded_nodes.add(node_id)
|
184 |
+
|
185 |
+
if self.config.search_strategy == "evaluation_first":
|
186 |
+
# recalculate node value based on children synthesisability and backpropagation
|
187 |
+
child_values = [
|
188 |
+
self.nodes_init_value[child_id]
|
189 |
+
for child_id in self.children[node_id]
|
190 |
+
]
|
191 |
+
|
192 |
+
if self.config.evaluation_agg == "max":
|
193 |
+
value_to_backprop = max(child_values)
|
194 |
+
|
195 |
+
elif self.config.evaluation_agg == "average":
|
196 |
+
value_to_backprop = sum(child_values) / len(
|
197 |
+
self.children[node_id]
|
198 |
+
)
|
199 |
+
|
200 |
+
else:
|
201 |
+
raise ValueError(
|
202 |
+
f"Invalid evaluation aggregation mode: {self.config.evaluation_agg} "
|
203 |
+
f"Allowed values are 'max', 'average'"
|
204 |
+
)
|
205 |
+
elif self.config.search_strategy == "expansion_first":
|
206 |
+
value_to_backprop = self._get_node_value(node_id)
|
207 |
+
|
208 |
+
else:
|
209 |
+
raise ValueError(
|
210 |
+
f"Invalid search_strategy: {self.config.search_strategy}: "
|
211 |
+
f"Allowed values are 'expansion_first', 'evaluation_first'"
|
212 |
+
)
|
213 |
+
|
214 |
+
# backpropagation
|
215 |
+
self._backpropagate(node_id, value_to_backprop)
|
216 |
+
self._update_visits(node_id)
|
217 |
+
explore_path = False
|
218 |
+
|
219 |
+
if self.children[node_id]:
|
220 |
+
# found after expansion
|
221 |
+
found_after_expansion = set()
|
222 |
+
for child_id in iter(self.children[node_id]):
|
223 |
+
if self.nodes[child_id].is_solved():
|
224 |
+
found_after_expansion.add(child_id)
|
225 |
+
self.winning_nodes.append(child_id)
|
226 |
+
|
227 |
+
if found_after_expansion:
|
228 |
+
return True, list(found_after_expansion)
|
229 |
+
|
230 |
+
else:
|
231 |
+
self._backpropagate(node_id, self.nodes_total_value[node_id])
|
232 |
+
self._update_visits(node_id)
|
233 |
+
explore_path = False
|
234 |
+
|
235 |
+
return False, [node_id]
|
236 |
+
|
237 |
+
def _ucb(self, node_id: int) -> float:
|
238 |
+
"""
|
239 |
+
The function calculates the Upper Confidence Bound (UCB) for a given node.
|
240 |
+
|
241 |
+
:param node_id: The `node_id` parameter is an integer that represents the ID of a node in a tree
|
242 |
+
:type node_id: int
|
243 |
+
"""
|
244 |
+
|
245 |
+
prob = self.nodes_prob[node_id] # Predicted by policy network score
|
246 |
+
visit = self.nodes_visit[node_id]
|
247 |
+
|
248 |
+
if self.config.ucb_type == "puct":
|
249 |
+
u = (
|
250 |
+
self.config.c_ucb * prob * sqrt(self.nodes_visit[self.parents[node_id]])
|
251 |
+
) / (visit + 1)
|
252 |
+
return self.nodes_total_value[node_id] + u
|
253 |
+
elif self.config.ucb_type == "uct":
|
254 |
+
u = (
|
255 |
+
self.config.c_ucb
|
256 |
+
* sqrt(self.nodes_visit[self.parents[node_id]])
|
257 |
+
/ (visit + 1)
|
258 |
+
)
|
259 |
+
return self.nodes_total_value[node_id] + u
|
260 |
+
elif self.config.ucb_type == "value":
|
261 |
+
return self.nodes_init_value[node_id] / (visit + 1)
|
262 |
+
else:
|
263 |
+
raise ValueError(f"I don't know this UCB type: {self.config.ucb_type}")
|
264 |
+
|
265 |
+
def _select_node(self, node_id: int) -> int:
|
266 |
+
"""
|
267 |
+
This function selects a node based on its UCB value and returns the ID of the node with the highest value of
|
268 |
+
the UCB function.
|
269 |
+
|
270 |
+
:param node_id: The `node_id` parameter is an integer that represents the ID of a node
|
271 |
+
:type node_id: int
|
272 |
+
"""
|
273 |
+
|
274 |
+
if self.config.epsilon > 0:
|
275 |
+
n = uniform(0, 1)
|
276 |
+
if n < self.config.epsilon:
|
277 |
+
return choice(list(self.children[node_id]))
|
278 |
+
|
279 |
+
best_score, best_children = None, []
|
280 |
+
for child_id in self.children[node_id]:
|
281 |
+
score = self._ucb(child_id)
|
282 |
+
if best_score is None or score > best_score:
|
283 |
+
best_score, best_children = score, [child_id]
|
284 |
+
elif score == best_score:
|
285 |
+
best_children.append(child_id)
|
286 |
+
return choice(best_children)
|
287 |
+
|
288 |
+
def _expand_node(self, node_id: int) -> None:
|
289 |
+
"""
|
290 |
+
The function expands a given node by generating new retrons with policy (expansion) policy.
|
291 |
+
|
292 |
+
:param node_id: The `node_id` parameter is an integer that represents the ID of the current node
|
293 |
+
:type node_id: int
|
294 |
+
"""
|
295 |
+
curr_node = self.nodes[node_id]
|
296 |
+
prev_retrons = curr_node.curr_retron.prev_retrons
|
297 |
+
|
298 |
+
tmp_retrons = set()
|
299 |
+
for prob, rule, rule_id in self.policy_function.predict_reaction_rules(
|
300 |
+
curr_node.curr_retron, self.reaction_rules
|
301 |
+
):
|
302 |
+
for products in apply_reaction_rule(curr_node.curr_retron.molecule, rule):
|
303 |
+
# check repeated products
|
304 |
+
if not products or not set(products) - tmp_retrons:
|
305 |
+
continue
|
306 |
+
tmp_retrons.update(products)
|
307 |
+
|
308 |
+
for molecule in products:
|
309 |
+
molecule.meta["reactor_id"] = rule_id
|
310 |
+
|
311 |
+
new_retrons = tuple(Retron(mol) for mol in products)
|
312 |
+
scaled_prob = prob * len(
|
313 |
+
list(filter(lambda x: len(x) > self.config.min_mol_size, products))
|
314 |
+
)
|
315 |
+
|
316 |
+
if set(prev_retrons).isdisjoint(new_retrons):
|
317 |
+
retrons_to_expand = (
|
318 |
+
*curr_node.next_retrons,
|
319 |
+
*(
|
320 |
+
x
|
321 |
+
for x in new_retrons
|
322 |
+
if not x.is_building_block(
|
323 |
+
self.building_blocks, self.config.min_mol_size
|
324 |
+
)
|
325 |
+
),
|
326 |
+
)
|
327 |
+
|
328 |
+
child_node = Node(
|
329 |
+
retrons_to_expand=retrons_to_expand, new_retrons=new_retrons
|
330 |
+
)
|
331 |
+
|
332 |
+
for new_retron in new_retrons:
|
333 |
+
new_retron.prev_retrons = [new_retron, *prev_retrons]
|
334 |
+
|
335 |
+
self._add_node(node_id, child_node, scaled_prob)
|
336 |
+
|
337 |
+
def _add_node(self, node_id: int, new_node: Node, policy_prob: float = None) -> None:
|
338 |
+
"""
|
339 |
+
This function adds a new node to a tree with its predicted policy probability.
|
340 |
+
|
341 |
+
:param node_id: ID of the parent node
|
342 |
+
:type node_id: int
|
343 |
+
:param new_node: The `new_node` is an instance of the`Node` class
|
344 |
+
:type new_node: Node
|
345 |
+
:param policy_prob: The `policy_prob` a float value that represents the probability associated with a new node.
|
346 |
+
:type policy_prob: float
|
347 |
+
"""
|
348 |
+
|
349 |
+
new_node_id = self.curr_tree_size
|
350 |
+
|
351 |
+
self.nodes[new_node_id] = new_node
|
352 |
+
self.parents[new_node_id] = node_id
|
353 |
+
self.children[node_id].add(new_node_id)
|
354 |
+
self.children[new_node_id] = set()
|
355 |
+
self.nodes_visit[new_node_id] = 0
|
356 |
+
self.nodes_prob[new_node_id] = policy_prob
|
357 |
+
self.nodes_depth[new_node_id] = self.nodes_depth[node_id] + 1
|
358 |
+
self.curr_tree_size += 1
|
359 |
+
|
360 |
+
if self.config.search_strategy == "evaluation_first":
|
361 |
+
node_value = self._get_node_value(new_node_id)
|
362 |
+
elif self.config.search_strategy == "expansion_first":
|
363 |
+
node_value = self.config.init_node_value
|
364 |
+
else:
|
365 |
+
raise ValueError(
|
366 |
+
f"Invalid search_strategy: {self.config.search_strategy}: "
|
367 |
+
f"Allowed values are 'expansion_first', 'evaluation_first'"
|
368 |
+
)
|
369 |
+
|
370 |
+
self.nodes_init_value[new_node_id] = node_value
|
371 |
+
self.nodes_total_value[new_node_id] = node_value
|
372 |
+
|
373 |
+
def _get_node_value(self, node_id: int) -> float:
|
374 |
+
"""
|
375 |
+
This function calculates the value for the given node.
|
376 |
+
|
377 |
+
:param node_id: ID of the given node
|
378 |
+
:type node_id: int
|
379 |
+
"""
|
380 |
+
|
381 |
+
node = self.nodes[node_id]
|
382 |
+
|
383 |
+
if self.config.evaluation_type == "random":
|
384 |
+
node_value = uniform()
|
385 |
+
|
386 |
+
elif self.config.evaluation_type == "rollout":
|
387 |
+
node_value = min(
|
388 |
+
(
|
389 |
+
self._rollout_node(retron, current_depth=self.nodes_depth[node_id])
|
390 |
+
for retron in node.retrons_to_expand
|
391 |
+
),
|
392 |
+
default=1.0,
|
393 |
+
)
|
394 |
+
|
395 |
+
elif self.config.evaluation_type == "gcn":
|
396 |
+
node_value = self.value_function.predict_value(node.new_retrons)
|
397 |
+
|
398 |
+
else:
|
399 |
+
raise ValueError(
|
400 |
+
f"I don't know this evaluation mode: {self.config.evaluation_type}"
|
401 |
+
)
|
402 |
+
|
403 |
+
return node_value
|
404 |
+
|
405 |
+
def _update_visits(self, node_id: int) -> None:
|
406 |
+
"""
|
407 |
+
The function updates the number of visits from a given node to a root node.
|
408 |
+
|
409 |
+
:param node_id: The ID of a current node
|
410 |
+
:type node_id: int
|
411 |
+
"""
|
412 |
+
|
413 |
+
while node_id:
|
414 |
+
self.nodes_visit[node_id] += 1
|
415 |
+
node_id = self.parents[node_id]
|
416 |
+
|
417 |
+
def _backpropagate(self, node_id: int, value: float) -> None:
|
418 |
+
"""
|
419 |
+
The function backpropagates a value through a tree of a given node specified by node_id.
|
420 |
+
|
421 |
+
:param node_id: The ID of a given node from which to backpropagate value
|
422 |
+
:type node_id: int
|
423 |
+
:param value: The value to backpropagate
|
424 |
+
:type value: float
|
425 |
+
"""
|
426 |
+
while node_id:
|
427 |
+
if self.config.backprop_type == "muzero":
|
428 |
+
self.nodes_total_value[node_id] = (
|
429 |
+
self.nodes_total_value[node_id] * self.nodes_visit[node_id] + value
|
430 |
+
) / (self.nodes_visit[node_id] + 1)
|
431 |
+
elif self.config.backprop_type == "cumulative":
|
432 |
+
self.nodes_total_value[node_id] += value
|
433 |
+
else:
|
434 |
+
raise ValueError(
|
435 |
+
f"I don't know this backpropagation type: {self.config.backprop_type}"
|
436 |
+
)
|
437 |
+
node_id = self.parents[node_id]
|
438 |
+
|
439 |
+
def _rollout_node(self, retron: Retron, current_depth: int = None) -> float:
|
440 |
+
"""
|
441 |
+
The function `_rollout_node` performs a rollout simulation from a given node in a tree.
|
442 |
+
Given the current retron, find the first successful reaction and return the new retrons.
|
443 |
+
|
444 |
+
If the retron is a building_block, return 1.0, else check the first successful reaction;
|
445 |
+
|
446 |
+
If the reaction is not successful, return -1.0;
|
447 |
+
|
448 |
+
If the reaction is successful, but the generated retrons are not the building_blocks and the retrons
|
449 |
+
cannot be generated without exceeding current_depth threshold, return -0.5;
|
450 |
+
|
451 |
+
If the reaction is successful, but the retrons are not the building_blocks and the retrons
|
452 |
+
cannot be generated, return -1.0;
|
453 |
+
|
454 |
+
:param retron: A Retron object
|
455 |
+
:type retron: Retron
|
456 |
+
:param current_depth: The current depth of the tree
|
457 |
+
:type current_depth: int
|
458 |
+
"""
|
459 |
+
|
460 |
+
max_depth = self.config.max_depth - current_depth
|
461 |
+
|
462 |
+
# retron checking
|
463 |
+
if retron.is_building_block(self.building_blocks, self.config.min_mol_size):
|
464 |
+
return 1.0
|
465 |
+
|
466 |
+
if max_depth == 0:
|
467 |
+
logging.debug("Rollout: tried to perform rollout on the leaf node")
|
468 |
+
return -0.5
|
469 |
+
|
470 |
+
# retron simulating
|
471 |
+
occurred_retrons = set()
|
472 |
+
retrons_to_expand = deque([retron])
|
473 |
+
history = defaultdict(dict)
|
474 |
+
rollout_depth = 0
|
475 |
+
while retrons_to_expand:
|
476 |
+
# Iterate through reactors and pick first successful reaction.
|
477 |
+
# Check products of the reaction if you can find them in in-building_blocks data
|
478 |
+
# If not, then add missed products to retrons_to_expand and try to decompose them
|
479 |
+
if len(history) >= max_depth:
|
480 |
+
logging.debug(
|
481 |
+
f"Rollout: max depth of rollout is reached with these "
|
482 |
+
f"retrons to expand: {retrons_to_expand} {history}",
|
483 |
+
)
|
484 |
+
reward = -0.5
|
485 |
+
return reward
|
486 |
+
|
487 |
+
current_retron = retrons_to_expand.popleft()
|
488 |
+
history[rollout_depth]["target"] = current_retron
|
489 |
+
occurred_retrons.add(current_retron)
|
490 |
+
|
491 |
+
# Pick the first successful reaction while iterating through reactors
|
492 |
+
reaction_rule_applied = False
|
493 |
+
for prob, rule, rule_id in self.policy_function.predict_reaction_rules(
|
494 |
+
current_retron, self.reaction_rules
|
495 |
+
):
|
496 |
+
for products in apply_reaction_rule(current_retron.molecule, rule):
|
497 |
+
if products:
|
498 |
+
reaction_rule_applied = True
|
499 |
+
break
|
500 |
+
|
501 |
+
if reaction_rule_applied:
|
502 |
+
history[rollout_depth]["rule_index"] = rule_id
|
503 |
+
break
|
504 |
+
|
505 |
+
if not reaction_rule_applied:
|
506 |
+
logging.debug(
|
507 |
+
f"Rollout: no reaction rule was applied for the "
|
508 |
+
f"molecule {current_retron} on rollout depth {rollout_depth}"
|
509 |
+
)
|
510 |
+
reward = -1.0
|
511 |
+
return reward
|
512 |
+
|
513 |
+
products = tuple(Retron(product) for product in products) # TODO /!\ Is it ok how products is defined above (line 496) ? Seems to
|
514 |
+
# TODO /!\ consider only last iterable of apply_reaction_rule
|
515 |
+
history[rollout_depth]["products"] = products
|
516 |
+
|
517 |
+
# check loops
|
518 |
+
if any(x in occurred_retrons for x in products) and products:
|
519 |
+
# Sometimes manual can create a loop, when
|
520 |
+
logging.debug("Rollout: rollout got in the loop: %s", history)
|
521 |
+
# print('occurred_retrons')
|
522 |
+
reward = -1.0
|
523 |
+
return reward
|
524 |
+
|
525 |
+
if occurred_retrons.isdisjoint(products):
|
526 |
+
# Added number of atoms check
|
527 |
+
retrons_to_expand.extend(
|
528 |
+
[
|
529 |
+
x
|
530 |
+
for x in products
|
531 |
+
if not x.is_building_block(
|
532 |
+
self.building_blocks, self.config.min_mol_size
|
533 |
+
)
|
534 |
+
]
|
535 |
+
)
|
536 |
+
rollout_depth += 1
|
537 |
+
|
538 |
+
reward = 1.0
|
539 |
+
return reward
|
540 |
+
|
541 |
+
def report(self) -> str:
|
542 |
+
"""
|
543 |
+
Returns the string representation of the tree.
|
544 |
+
"""
|
545 |
+
|
546 |
+
return (
|
547 |
+
f"Tree for: {str(self.nodes[1].retrons_to_expand[0])}\n"
|
548 |
+
f"Number of nodes: {len(self)}\nNumber of visited nodes: {len(self.visited_nodes)}\n"
|
549 |
+
f"Found paths: {len(self.winning_nodes)}\nTime: {round(self.curr_time, 1)} seconds"
|
550 |
+
)
|
551 |
+
|
552 |
+
def path_score(self, node_id: int) -> float:
|
553 |
+
"""
|
554 |
+
The function calculates the score of a given path from the node with node_id to the root node.
|
555 |
+
|
556 |
+
:param node_id: The ID of a given node
|
557 |
+
:type node_id: int
|
558 |
+
"""
|
559 |
+
|
560 |
+
cumulated_nodes_value, path_length = 0, 0
|
561 |
+
while node_id:
|
562 |
+
path_length += 1
|
563 |
+
|
564 |
+
cumulated_nodes_value += self.nodes_total_value[node_id]
|
565 |
+
node_id = self.parents[node_id]
|
566 |
+
|
567 |
+
return cumulated_nodes_value / (path_length ** 2)
|
568 |
+
|
569 |
+
def path_to_node(self, node_id: int) -> list:
|
570 |
+
"""
|
571 |
+
The function returns the path (list of IDs of nodes) to from a node specified by node_id to the root node.
|
572 |
+
|
573 |
+
:param node_id: The ID of a given node
|
574 |
+
:type node_id: int
|
575 |
+
"""
|
576 |
+
|
577 |
+
nodes = []
|
578 |
+
while node_id:
|
579 |
+
nodes.append(node_id)
|
580 |
+
node_id = self.parents[node_id]
|
581 |
+
return [self.nodes[node_id] for node_id in reversed(nodes)]
|
582 |
+
|
583 |
+
def synthesis_path(self, node_id: int) -> Tuple[Reaction, ...]:
|
584 |
+
"""
|
585 |
+
Given a node_id, return a tuple of Reactions that represent the synthesis path from the
|
586 |
+
node specified with node_id to the root node
|
587 |
+
|
588 |
+
:param node_id: The ID of a given node
|
589 |
+
:type node_id: int
|
590 |
+
"""
|
591 |
+
|
592 |
+
nodes = self.path_to_node(node_id)
|
593 |
+
|
594 |
+
tmp = [
|
595 |
+
Reaction(
|
596 |
+
[x.molecule for x in after.new_retrons],
|
597 |
+
[before.curr_retron.molecule],
|
598 |
+
)
|
599 |
+
for before, after in zip(nodes, nodes[1:])
|
600 |
+
]
|
601 |
+
|
602 |
+
for r in tmp:
|
603 |
+
r.clean2d()
|
604 |
+
return tuple(reversed(tmp))
|
605 |
+
|
606 |
+
def newickify(self, visits_threshold: int = 0, root_node_id: int = 1): # TODO what is return here ?
|
607 |
+
"""
|
608 |
+
Adopted from https://stackoverflow.com/questions/50003007/how-to-convert-python-dictionary-to-newick-form-format
|
609 |
+
:param visits_threshold: the minimum number of visits for the given node # TODO is this explanation correct ?
|
610 |
+
:type visits_threshold: int
|
611 |
+
:param root_node_id: The ID of a root node
|
612 |
+
:type root_node_id: int
|
613 |
+
"""
|
614 |
+
visited_nodes = set()
|
615 |
+
|
616 |
+
def newick_render_node(current_node_id: int) -> str:
|
617 |
+
"""
|
618 |
+
Recursively generates a Newick string representation of a tree
|
619 |
+
|
620 |
+
:param current_node_id: The identifier of the current node in the tree
|
621 |
+
:type current_node_id: The identifier of the current node in the tree
|
622 |
+
:return: A string representation of a node in a Newick format
|
623 |
+
"""
|
624 |
+
assert (
|
625 |
+
current_node_id not in visited_nodes
|
626 |
+
), "Error: The tree may not be circular!"
|
627 |
+
node_visit = self.nodes_visit[current_node_id]
|
628 |
+
|
629 |
+
visited_nodes.add(current_node_id)
|
630 |
+
if self.children[current_node_id]:
|
631 |
+
# Nodes
|
632 |
+
children = [
|
633 |
+
child
|
634 |
+
for child in list(self.children[current_node_id])
|
635 |
+
if self.nodes_visit[child] >= visits_threshold
|
636 |
+
]
|
637 |
+
children_strings = [newick_render_node(child) for child in children]
|
638 |
+
children_strings = ",".join(children_strings)
|
639 |
+
if children_strings:
|
640 |
+
return f"({children_strings}){current_node_id}:{node_visit}"
|
641 |
+
else:
|
642 |
+
# Leafs within threshold
|
643 |
+
return f"{current_node_id}:{node_visit}"
|
644 |
+
else:
|
645 |
+
# Leafs
|
646 |
+
return f"{current_node_id}:{node_visit}"
|
647 |
+
|
648 |
+
newick_string = newick_render_node(root_node_id) + ";"
|
649 |
+
|
650 |
+
meta = {}
|
651 |
+
for node_id in iter(visited_nodes):
|
652 |
+
node_value = round(self.nodes_total_value[node_id], 3)
|
653 |
+
|
654 |
+
node_synthesisability = round(self.nodes_init_value[node_id])
|
655 |
+
|
656 |
+
visit_in_node = self.nodes_visit[node_id]
|
657 |
+
meta[node_id] = (node_value, node_synthesisability, visit_in_node)
|
658 |
+
|
659 |
+
return newick_string, meta
|
SynTool/ml/__init__.py
ADDED
File without changes
|
SynTool/ml/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (147 Bytes). View file
|
|
SynTool/ml/networks/__init__.py
ADDED
File without changes
|
SynTool/ml/networks/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (156 Bytes). View file
|
|
SynTool/ml/networks/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (8.14 kB). View file
|
|