Upload 7 files
Browse files- attention_dynamic_model.py +283 -0
- attention_graph_encoder.py +93 -0
- enviroment.py +128 -0
- layers.py +110 -0
- reinforce_baseline.py +223 -0
- utils.py +221 -0
- utils_demo.py +73 -0
attention_dynamic_model.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from attention_graph_encoder import GraphAttentionEncoder
|
5 |
+
from enviroment import AgentVRP
|
6 |
+
|
7 |
+
|
8 |
+
def set_decode_type(model, decode_type):
    """Switch ``model`` to the given decoding strategy.

    Module-level convenience wrapper: it simply forwards ``decode_type`` to
    ``model.set_decode_type`` and returns nothing.
    """
    model.set_decode_type(decode_type)
|
10 |
+
|
11 |
+
class AttentionDynamicModel(tf.keras.Model):
    """Attention-based policy for the VRP that re-encodes the graph dynamically.

    The graph is embedded with ``GraphAttentionEncoder``; the decoder then
    picks one node per step using a multi-head attention "glimpse" followed by
    a single-head tanh-clipped attention that yields log-probabilities.
    Each time ``state.partial_finished()`` becomes true, the graph is
    re-encoded with an attention mask obtained from ``state.get_att_mask()``.

    Args:
        embedding_dim: size of node embeddings (used as d_model everywhere).
        n_encode_layers: number of attention layers in the encoder.
        n_heads: number of attention heads (encoder and decoder).
        tanh_clipping: multiplier applied to ``tanh`` of the final logits.

    Raises:
        ValueError: if ``n_heads`` does not divide ``embedding_dim``.
    """

    def __init__(self,
                 embedding_dim,
                 n_encode_layers=2,
                 n_heads=8,
                 tanh_clipping=10.
                 ):

        super().__init__()

        # Fail fast, before any sublayer or derived quantity is created.
        if embedding_dim % n_heads != 0:
            raise ValueError("number of heads must divide d_model=output_dim")

        # attributes for MHA
        self.embedding_dim = embedding_dim
        self.n_encode_layers = n_encode_layers
        self.decode_type = None

        # attributes for VRP problem
        self.problem = AgentVRP
        self.n_heads = n_heads

        # Encoder part
        self.embedder = GraphAttentionEncoder(input_dim=self.embedding_dim,
                                              num_heads=self.n_heads,
                                              num_layers=self.n_encode_layers
                                              )

        # Decoder part
        self.output_dim = self.embedding_dim
        self.num_heads = n_heads

        self.head_depth = self.output_dim // self.num_heads
        # scaling constants: per-head depth for the MHA glimpse,
        # full model width for the final single-head attention
        self.dk_mha_decoder = tf.cast(self.head_depth, tf.float32)
        self.dk_get_loc_p = tf.cast(self.output_dim, tf.float32)

        self.tanh_clipping = tanh_clipping

        # The query projection Wq is split in two so that the graph-level part
        # can be computed once per encoding:
        # Wq*[h_c, h_N, D] = Wq_context*h_c + Wq_step_context*[h_N, D]
        self.wq_context = tf.keras.layers.Dense(self.output_dim, use_bias=False,
                                                name='wq_context')  # (d_q_context, output_dim)
        self.wq_step_context = tf.keras.layers.Dense(self.output_dim, use_bias=False,
                                                     name='wq_step_context')  # (d_q_step_context, output_dim)

        # Two key projections: one for the MHA glimpse, one for the final
        # single-head tanh attention (they use different keys K).
        self.wk = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='wk')  # (d_k, output_dim)
        self.wk_tanh = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='wk_tanh')  # (d_k_tanh, output_dim)

        # Value projection is only needed by the MHA glimpse; the single-head
        # attention only produces attention weights.
        self.wv = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='wv')  # (d_v, output_dim)

        # No separate wq for the single-head tanh attention: it is absorbed into w_out.
        self.w_out = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='w_out')  # (d_model, d_model)

    def set_decode_type(self, decode_type):
        """Store the decoding strategy; ``_select_node`` accepts "greedy" or "sampling"."""
        self.decode_type = decode_type

    def split_heads(self, tensor, batch_size):
        """Split the last dimension into ``(num_heads, head_depth)``.

        The result is transposed to ``(batch_size, num_heads, ..., head_depth)``
        so that attention can be computed for all heads at once.
        """
        tensor = tf.reshape(tensor, (batch_size, -1, self.num_heads, self.head_depth))
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def _select_node(self, logits):
        """Pick the next node for every batch element.

        Args:
            logits: log-probabilities of shape (batch_size, 1, n_nodes).

        Returns:
            Tensor of shape (batch_size,) with the selected node indices.

        Raises:
            ValueError: if ``self.decode_type`` is neither "greedy" nor "sampling".
        """
        if self.decode_type == "greedy":
            selected = tf.math.argmax(logits, axis=-1)  # (batch_size, 1)

        elif self.decode_type == "sampling":
            # tf.random.categorical needs a matrix, so drop the length-1 axis:
            # (batch_size, 1, n_nodes) --> (batch_size, n_nodes)
            selected = tf.random.categorical(logits[:, 0, :], 1)  # (batch_size, 1)
        else:
            # An explicit exception (not `assert`) so the check survives `python -O`.
            raise ValueError(f"Unknown decode type: {self.decode_type!r}")

        return tf.squeeze(selected, axis=-1)  # (batch_size,)

    def get_step_context(self, state, embeddings):
        """Build the step-dependent part ``[h_N, D]`` of the context vector.

        Gathers the embedding of the previously selected node
        (``state.prev_a``) and appends the remaining vehicle capacity.

        Returns:
            Tensor of shape (batch_size, 1, input_dim + 1).
        """
        # index of previous node
        prev_node = state.prev_a  # (batch_size, 1)

        # from embeddings=(batch_size, n_nodes, input_dim) select embeddings of previous nodes
        cur_embedded_node = tf.gather(embeddings, tf.cast(prev_node, tf.int32),
                                      batch_dims=1)  # (batch_size, 1, input_dim)

        # add remaining capacity
        remaining_capacity = self.problem.VEHICLE_CAPACITY - state.used_capacity[:, :, None]
        step_context = tf.concat([cur_embedded_node, remaining_capacity], axis=-1)

        return step_context  # (batch_size, 1, input_dim + 1)

    def decoder_mha(self, Q, K, V, mask=None):
        """Multi-head attention glimpse of the decoder.

        This is the core of an MHA sublayer, written as a method rather than a
        layer because Q changes at every decoding step.

        Args:
            Q: query (context vector), shape (..., seq_len_q, head_depth),
                seq_len_q = 1 in the decoder.
            K, V: key / value projections of node embeddings,
                shape (..., seq_len_k, head_depth) with seq_len_k = n_nodes.
            mask: optional boolean mask of shape (batch_size, seq_len_q, seq_len_k);
                True entries are excluded from attention.

        Returns:
            Tensor of shape (batch_size, seq_len_q, output_dim).
        """
        compatibility = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(self.dk_mha_decoder)  # (batch_size, num_heads, seq_len_q, seq_len_k)

        if mask is not None:
            # (batch_size, seq_len_q, seq_len_k) --> (batch_size, 1, seq_len_q, seq_len_k)
            # so the mask broadcasts over the heads axis.
            mask = mask[:, tf.newaxis, :, :]

            # tf.where is used because 0 * -inf would give nan instead of -inf
            compatibility = tf.where(mask,
                                     tf.ones_like(compatibility) * (-np.inf),
                                     compatibility
                                     )

        compatibility = tf.nn.softmax(compatibility, axis=-1)  # (batch_size, num_heads, seq_len_q, seq_len_k)
        attention = tf.matmul(compatibility, V)  # (batch_size, num_heads, seq_len_q, head_depth)

        # transpose back to (batch_size, seq_len_q, num_heads, depth)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])

        # concatenate heads (last 2 dimensions); the batch size is read from the
        # tensor itself so this method does not rely on a previous call()
        batch_size = tf.shape(attention)[0]
        attention = tf.reshape(attention, (batch_size, -1, self.output_dim))  # (batch_size, seq_len_q, output_dim)

        output = self.w_out(attention)  # (batch_size, seq_len_q, output_dim)

        return output

    def get_log_p(self, Q, K, mask=None):
        """Single-head attention that produces log-probabilities over nodes.

        Args:
            Q: query (output of ``decoder_mha``), shape (batch_size, seq_len_q, output_dim).
            K: key (projection of node embeddings), shape (batch_size, seq_len_k, output_dim).
            mask: optional boolean mask, shape (batch_size, seq_len_q, seq_len_k);
                True entries get log-probability -inf.

        Returns:
            Tensor of shape (batch_size, seq_len_q, seq_len_k).
        """
        compatibility = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(self.dk_get_loc_p)
        compatibility = tf.math.tanh(compatibility) * self.tanh_clipping

        if mask is not None:
            # No extra heads axis here, so the mask is used as is.
            compatibility = tf.where(mask,
                                     tf.ones_like(compatibility) * (-np.inf),
                                     compatibility
                                     )

        log_p = tf.nn.log_softmax(compatibility, axis=-1)  # (batch_size, seq_len_q, seq_len_k)

        return log_p

    def get_log_likelihood(self, _log_p, a):
        """Sum, over decoding steps, the log-probabilities of the chosen actions.

        Args:
            _log_p: per-step log-probabilities, indexed as [batch, step, node].
            a: chosen node per step, indexed as [batch, step] (cast to int32 here).
        """
        # Get log_p corresponding to selected actions
        chosen = tf.cast(tf.expand_dims(a, axis=-1), tf.int32)
        log_p = tf.gather_nd(_log_p, chosen, batch_dims=2)

        # Calculate log_likelihood
        return tf.reduce_sum(log_p, 1)

    def get_projections(self, embeddings, context_vectors):
        """Pre-compute the projections that stay fixed between re-encodings.

        Returns:
            Tuple ``(K_tanh, Q_context, K, V)`` where K and V are already split
            into heads and K_tanh is left unsplit (single-head attention).
        """
        batch_size = tf.shape(embeddings)[0]

        K = self.wk(embeddings)  # (batch_size, n_nodes, output_dim)
        K_tanh = self.wk_tanh(embeddings)  # (batch_size, n_nodes, output_dim)
        V = self.wv(embeddings)  # (batch_size, n_nodes, output_dim)
        Q_context = self.wq_context(context_vectors[:, tf.newaxis, :])  # (batch_size, 1, output_dim)

        # Q is split inside the decoding loop, K_tanh is never split
        K = self.split_heads(K, batch_size)  # (batch_size, num_heads, n_nodes, head_depth)
        V = self.split_heads(V, batch_size)  # (batch_size, num_heads, n_nodes, head_depth)

        return K_tanh, Q_context, K, V

    def call(self, inputs, return_pi=False):
        """Decode a full solution for every instance in ``inputs``.

        Args:
            inputs: problem data passed both to the encoder and to ``self.problem``.
            return_pi: when True, also return the decoded node sequences.

        Returns:
            ``(cost, ll)`` or ``(cost, ll, pi)``: tour costs from
            ``self.problem.get_costs``, summed log-likelihood of the chosen
            actions, and (optionally) the selected nodes as a float32 tensor.
        """
        embeddings, mean_graph_emb = self.embedder(inputs)

        self.batch_size = tf.shape(embeddings)[0]

        outputs = []
        sequences = []

        state = self.problem(inputs)

        K_tanh, Q_context, K, V = self.get_projections(embeddings, mean_graph_emb)

        # Perform decoding steps
        i = 0

        while not state.all_finished():

            if i > 0:
                # Start a new partial solution: reset the step counter and
                # re-encode the graph with the current attention mask.
                state.i = tf.zeros(1, dtype=tf.int64)
                att_mask, cur_num_nodes = state.get_att_mask()
                embeddings, context_vectors = self.embedder(inputs, att_mask, cur_num_nodes)
                K_tanh, Q_context, K, V = self.get_projections(embeddings, context_vectors)

            while not state.partial_finished():

                step_context = self.get_step_context(state, embeddings)  # (batch_size, 1, input_dim + 1)
                Q_step_context = self.wq_step_context(step_context)  # (batch_size, 1, output_dim)
                Q = Q_context + Q_step_context

                # split heads for Q
                Q = self.split_heads(Q, self.batch_size)  # (batch_size, num_heads, 1, head_depth)

                # current mask, (batch_size, 1, n_nodes); True marks nodes that are
                # excluded (they receive -inf in decoder_mha / get_log_p)
                mask = state.get_mask()

                # compute MHA decoder vectors for current mask
                mha = self.decoder_mha(Q, K, V, mask)  # (batch_size, 1, output_dim)

                # compute log-probabilities
                log_p = self.get_log_p(mha, K_tanh, mask)  # (batch_size, 1, n_nodes)

                # next step is to select node
                selected = self._select_node(log_p)

                state.step(selected)

                outputs.append(log_p[:, 0, :])
                sequences.append(selected)

            i += 1

        _log_p, pi = tf.stack(outputs, 1), tf.cast(tf.stack(sequences, 1), tf.float32)

        cost = self.problem.get_costs(inputs, pi)

        ll = self.get_log_likelihood(_log_p, pi)

        if return_pi:
            return cost, ll, pi

        return cost, ll
|
attention_graph_encoder.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from layers import MultiHeadAttention
|
3 |
+
|
4 |
+
|
5 |
+
class MultiHeadAttentionLayer(tf.keras.layers.Layer):
    """One encoder block: self-attention plus a feed-forward network.

    The input goes through ``MultiHeadAttention`` (self-attention), a skip
    connection and ``tanh``; the result goes through a two-layer feed-forward
    network (ReLU in between), a second skip connection and ``tanh``.

    Args:
        input_dim: embedding size, used as d_model in the MHA layer and as the
            width of the second feed-forward layer.
        num_heads: number of attention heads in the MHA layer.
        feed_forward_hidden: number of units in the first feed-forward layer.

    Call arguments:
        x: batch of shape (batch_size, n_nodes, node_embedding_size).
        mask: mask forwarded to the MHA layer.

    Returns:
        outputs of shape (batch_size, n_nodes, input_dim)
    """

    def __init__(self, input_dim, num_heads, feed_forward_hidden=512, **kwargs):
        super().__init__(**kwargs)
        self.mha = MultiHeadAttention(n_heads=num_heads, d_model=input_dim, name='MHA')
        self.ff1 = tf.keras.layers.Dense(feed_forward_hidden, name='ff1')
        self.ff2 = tf.keras.layers.Dense(input_dim, name='ff2')

    def call(self, x, mask=None):
        tanh = tf.keras.activations.tanh

        # attention sublayer: skip connection around self-attention, then tanh
        attended = tanh(tf.keras.layers.Add()([x, self.mha(x, x, x, mask)]))

        # feed-forward sublayer: Dense -> ReLU -> Dense, skip connection, tanh
        hidden = tf.keras.activations.relu(self.ff1(attended))
        projected = self.ff2(hidden)

        return tanh(tf.keras.layers.Add()([attended, projected]))
|
41 |
+
|
42 |
+
class GraphAttentionEncoder(tf.keras.layers.Layer):
    """Graph encoder built from a stack of ``MultiHeadAttentionLayer`` blocks.

    Args:
        input_dim: embedding size that will be used as d_model in MHA layers.
        num_heads: number of attention heads in MHA layers.
        num_layers: number of attention layers that will be used in encoder.
        feed_forward_hidden: number of neuron units in each FF layer.

    Call arguments:
        x: tuple of 3 tensors: (batch_size, 2), (batch_size, n_nodes-1, 2), (batch_size, n_nodes-1).
            The first holds the depot coordinates, the second the coordinates
            of the other nodes, the last the demands of the non-depot nodes.
        mask: optional mask forwarded to every MHA layer.
        cur_num_nodes: divisor used for the graph embedding when ``mask`` is
            given; required in that case.

    Returns:
        Embedding for all nodes + averaged embedding for the graph.
        Tuple ((batch_size, n_nodes, input_dim), (batch_size, input_dim))

    Raises:
        ValueError: if ``mask`` is given without ``cur_num_nodes``.
    """

    def __init__(self, input_dim, num_heads, num_layers, feed_forward_hidden=512):
        super().__init__()

        self.input_dim = input_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.feed_forward_hidden = feed_forward_hidden

        # initial linear embeddings, separate for the depot (coordinates only)
        # and for the other nodes (coordinates + demand)
        self.init_embed_depot = tf.keras.layers.Dense(self.input_dim, name='init_embed_depot')
        self.init_embed = tf.keras.layers.Dense(self.input_dim, name='init_embed')

        self.mha_layers = [MultiHeadAttentionLayer(self.input_dim, self.num_heads, self.feed_forward_hidden)
                           for _ in range(self.num_layers)]

    def call(self, x, mask=None, cur_num_nodes=None):

        if mask is not None and cur_num_nodes is None:
            # Without this check the failure surfaces later as an opaque
            # tensor-conversion error in the division below.
            raise ValueError("cur_num_nodes must be provided together with mask")

        depot_xy, loc_xy, demand = x[0], x[1], x[2]

        # (batch_size, 2) --> (batch_size, 1, input_dim)
        depot_embedding = self.init_embed_depot(depot_xy)[:, None, :]

        # (batch_size, n_nodes-1, 2) + (batch_size, n_nodes-1) --> (batch_size, n_nodes-1, input_dim)
        node_features = tf.concat((loc_xy, demand[:, :, None]), axis=-1)
        node_embeddings = self.init_embed(node_features)

        x = tf.concat((depot_embedding, node_embeddings), axis=1)  # (batch_size, n_nodes, input_dim)

        # stack attention layers
        for layer in self.mha_layers:
            x = layer(x, mask)

        if mask is not None:
            # with a mask the caller supplies the divisor for averaging
            graph_embedding = tf.reduce_sum(x, axis=1) / cur_num_nodes
        else:
            graph_embedding = tf.reduce_mean(x, axis=1)

        # (embeds of nodes, avg graph embed) = ((batch_size, n_nodes, input_dim), (batch_size, input_dim))
        return (x, graph_embedding)
|
enviroment.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
class AgentVRP:
    """Batched state of a vehicle-routing episode.

    Tracks, for every graph in the batch, the previously selected node, the
    capacity used so far and which nodes have been visited, and exposes the
    masks needed by the encoder and decoder.
    """

    VEHICLE_CAPACITY = 1.0

    def __init__(self, input):
        # input = (depot coordinates, node coordinates, node demands)
        depot, loc = input[0], input[1]

        self.batch_size, self.n_loc, _ = loc.shape  # loc: (batch_size, n_loc, 2)

        # Coordinates of depot + other nodes
        self.coords = tf.concat((depot[:, None, :], loc), -2)
        self.demand = tf.cast(input[2], tf.float32)

        # Indices of graphs in batch
        self.ids = tf.range(self.batch_size, dtype=tf.int64)[:, None]

        # State: previous action, "currently at depot" flag, used capacity
        self.prev_a = tf.zeros((self.batch_size, 1), dtype=tf.float32)
        self.from_depot = self.prev_a == 0
        self.used_capacity = tf.zeros((self.batch_size, 1), dtype=tf.float32)

        # Visited nodes are marked with 1 (index 0 is the depot)
        self.visited = tf.zeros((self.batch_size, 1, self.n_loc + 1), dtype=tf.uint8)

        # Step counter
        self.i = tf.zeros(1, dtype=tf.int64)

        # Constant tensors reused by the scatter update in step()
        self.step_updates = tf.ones((self.batch_size, 1), dtype=tf.uint8)  # (batch_size, 1)
        self.scatter_zeros = tf.zeros((self.batch_size, 1), dtype=tf.int64)  # (batch_size, 1)

    @staticmethod
    def outer_pr(a, b):
        """Batched outer product: (k, i) x (k, j) -> (k, i, j)."""
        return tf.einsum('ki,kj->kij', a, b)

    def get_att_mask(self):
        """Build the (batch_size, n_nodes, n_nodes) mask for the attention encoder.

        Already visited nodes are masked, except the depot.

        Returns:
            Tuple of the boolean mask and ``cur_num_nodes`` of shape
            (batch_size, 1), the number of nodes left after masking.
        """
        # [batch_size, 1, n_nodes] --> [batch_size, n_nodes] --> drop depot column
        visited_rows = tf.squeeze(tf.cast(self.visited, tf.float32), axis=-2)
        visited_customers = visited_rows[:, 1:]  # [batch_size, n_nodes-1]

        # Number of nodes in new instance after masking
        n_masked = tf.reshape(tf.reduce_sum(visited_customers, -1), (-1, 1))
        cur_num_nodes = self.n_loc + 1 - n_masked  # [batch_size, 1]

        # Depot is never masked: prepend a zero column for it
        depot_column = tf.zeros(shape=(visited_customers.shape[0], 1), dtype=tf.float32)
        row_mask = tf.concat((depot_column, visited_customers), axis=-1)

        all_ones = tf.ones_like(row_mask)

        # Square mask from the row mask: entry (i, j) is set when i or j is masked
        square_mask = AgentVRP.outer_pr(row_mask, all_ones) \
            + AgentVRP.outer_pr(all_ones, row_mask) \
            - AgentVRP.outer_pr(row_mask, row_mask)

        return tf.cast(square_mask, dtype=tf.bool), cur_num_nodes

    def all_finished(self):
        """True when every node of every graph is marked as visited."""
        return tf.reduce_all(tf.cast(self.visited, tf.bool))

    def partial_finished(self):
        """True when all agents are at the depot and at least one step was taken."""
        return tf.reduce_all(self.from_depot) and self.i != 0

    def get_mask(self):
        """Return the (batch_size, 1, n_nodes) mask of unavailable actions.

        True marks a node that cannot be selected.
        """
        # Exclude depot
        visited_loc = self.visited[:, :, 1:]

        # Nodes whose demand would exceed the vehicle capacity
        exceeds_cap = self.demand + self.used_capacity > self.VEHICLE_CAPACITY

        # After the first step, an agent that is back at the depot stays there
        # (this ends its partial solution)
        stopped_at_depot = (self.i > 0) & self.from_depot[:, None, :]

        # Mask nodes that are visited, too demanding, or blocked by the rule above
        mask_loc = tf.cast(visited_loc, tf.bool) | exceeds_cap[:, None, :] | stopped_at_depot

        # Depot is masked while the agent is at the depot and some node is still selectable
        n_selectable = tf.reduce_sum(tf.cast(tf.logical_not(mask_loc), tf.int32), axis=-1)
        mask_depot = self.from_depot & (n_selectable > 0)

        return tf.concat([mask_depot[:, :, None], mask_loc], axis=-1)

    def step(self, action):
        """Apply the selected ``action`` (one node index per graph) to the state."""
        # Update current state
        self.prev_a = action[:, None]
        self.from_depot = self.prev_a == 0

        # demand has no depot entry, so node k lives at index k-1;
        # the clip keeps the index valid when the depot (0) was selected
        demand_index = tf.clip_by_value(self.prev_a - 1, 0, self.n_loc - 1)
        selected_demand = tf.gather_nd(
            self.demand,
            tf.concat([self.ids, demand_index], axis=1)
        )[:, None]  # (batch_size, 1)

        # Accumulate used capacity; reset it to zero on return to the depot
        not_at_depot = 1.0 - tf.cast(self.from_depot, tf.float32)
        self.used_capacity = (self.used_capacity + selected_demand) * not_at_depot

        # Mark the selected nodes as visited
        scatter_idx = tf.cast(
            tf.concat((self.ids, self.scatter_zeros, self.prev_a), axis=-1), tf.int32
        )[:, None, :]  # (batch_size, 1, 3)
        self.visited = tf.tensor_scatter_nd_update(self.visited, scatter_idx, self.step_updates)  # (batch_size, 1, n_nodes)

        self.i = self.i + 1

    @staticmethod
    def get_costs(dataset, pi):
        """Total tour length for the node sequences ``pi``.

        Adds the leg from the depot to the first node and from the last node
        back to the depot to the length of the path through ``pi``.
        """
        depot = dataset[0]

        # Coordinates ordered as in the decoded tour
        loc_with_depot = tf.concat([depot[:, None, :], dataset[1]], axis=1)  # (batch_size, n_nodes, 2)
        d = tf.gather(loc_with_depot, tf.cast(pi, tf.int32), batch_dims=1)

        # Note: first element of pi is not depot, but the first selected node in the path
        path_length = tf.reduce_sum(tf.norm(d[:, 1:] - d[:, :-1], ord=2, axis=2), axis=1)
        depot_to_first = tf.norm(d[:, 0] - depot, ord=2, axis=1)
        last_to_depot = tf.norm(d[:, -1] - depot, ord=2, axis=1)

        return (path_length
                + depot_to_first
                + last_to_depot)
|
layers.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import tensorflow as tf
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention (for encoder and decoder).

    Args:
        n_heads: number of attention heads which will be computed in parallel
        d_model: embedding size of output features

    Call arguments:
        q: query, shape (..., seq_len_q, depth_q)
        k: key, shape == (..., seq_len_k, depth_k)
        v: value, shape == (..., seq_len_v, depth_v)
        mask: tensor of shape (batch_size, seq_len_q, seq_len_k) or None; it is
            used as the condition of ``tf.where``, and positions where it is
            True are excluded from attention.

    Since we use scaled-product attention, we assume seq_len_k = seq_len_v

    Returns:
        attention outputs of shape (batch_size, seq_len_q, d_model)
    """

    def __init__(self, n_heads, d_model, **kwargs):
        super().__init__(**kwargs)
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_depth = self.d_model // self.n_heads

        if self.d_model % self.n_heads != 0:
            raise ValueError("number of heads must divide d_model")

        # bias-free projections for query, key, value ...
        self.wq = tf.keras.layers.Dense(self.d_model, use_bias=False)  # (d_q, d_model)
        self.wk = tf.keras.layers.Dense(self.d_model, use_bias=False)  # (d_k, d_model)
        self.wv = tf.keras.layers.Dense(self.d_model, use_bias=False)  # (d_v, d_model)

        # ... and for the concatenated heads
        self.w_out = tf.keras.layers.Dense(self.d_model, use_bias=False)  # (d_model, d_model)

    def split_heads(self, tensor, batch_size):
        """Reshape (batch, seq, d_model) into (batch, n_heads, seq, head_depth)."""
        per_head = tf.reshape(tensor, (batch_size, -1, self.n_heads, self.head_depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    # Keras treats the first argument q as the layer input, so input_shape=q.shape
    def call(self, q, k, v, mask=None):
        batch_size = tf.shape(q)[0]  # q: (batch_size, seq_len_q, d_q)

        # project to d_model, then split into heads:
        # (batch_size, seq_len, d_model) --> (batch_size, n_heads, seq_len, head_depth)
        Q = self.split_heads(self.wq(q), batch_size)
        K = self.split_heads(self.wk(k), batch_size)
        V = self.split_heads(self.wv(v), batch_size)

        # scaled similarity between queries and keys
        # (batch_size, n_heads, seq_len_q, seq_len_k)
        scale = tf.math.sqrt(tf.cast(tf.shape(K)[-1], tf.float32))
        scores = tf.matmul(Q, K, transpose_b=True) / scale

        if mask is not None:
            # add a heads axis so the mask broadcasts:
            # (batch_size, seq_len_q, seq_len_k) --> (batch_size, 1, seq_len_q, seq_len_k)
            head_mask = mask[:, tf.newaxis, :, :]

            # tf.where instead of multiplication: 0 * -inf would give nan
            scores = tf.where(head_mask,
                              tf.ones_like(scores) * (-np.inf),
                              scores)

        weights = tf.nn.softmax(scores, axis=-1)  # (batch_size, n_heads, seq_len_q, seq_len_k)

        # fully masked rows come out of softmax as NaN; zero them
        weights = tf.where(tf.math.is_nan(weights), tf.zeros_like(weights), weights)

        # weighted sum of values, (batch_size, n_heads, seq_len_q, head_depth)
        per_head = tf.matmul(weights, V)

        # (batch_size, seq_len_q, n_heads, head_depth) --> (batch_size, seq_len_q, d_model)
        per_head = tf.transpose(per_head, perm=[0, 2, 1, 3])
        merged = tf.reshape(per_head, (batch_size, -1, self.d_model))

        # One projection of the concatenated heads equals projecting each head
        # and summing (block-matrix multiplication).
        return self.w_out(merged)  # (batch_size, seq_len_q, d_model)
|
reinforce_baseline.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from scipy.stats import ttest_rel
|
3 |
+
from tqdm import tqdm
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from attention_dynamic_model import AttentionDynamicModel
|
7 |
+
from attention_dynamic_model import set_decode_type
|
8 |
+
from utils import generate_data_onfly
|
9 |
+
|
10 |
+
|
11 |
+
def copy_of_tf_model(model, embedding_dim=128, graph_size=20, n_encode_layers=2):
    """Create a new AttentionDynamicModel holding a copy of ``model``'s variables.

    The new model is called once on a tiny random batch before copying, so
    that its variables exist when the values are assigned.

    Args:
        model: source model whose variables are copied.
        embedding_dim: embedding size passed to the new model's constructor.
        graph_size: number of customer nodes in the dummy batch; must be one
            of the keys of the capacity table below (10, 20, 50, 100).
        n_encode_layers: number of encoder layers of the new model. It has to
            match the source model, otherwise the variable lists do not line
            up. Defaults to 2, the constructor default.

    Returns:
        The new model, set to "sampling" decode type, with copied variables.

    Raises:
        KeyError: if ``graph_size`` is not in the capacity table.
    """
    # https://stackoverflow.com/questions/56841736/how-to-copy-a-network-in-tensorflow-2-0
    CAPACITIES = {10: 20.,
                  20: 30.,
                  50: 40.,
                  100: 50.
                  }

    # Dummy batch of two instances: (depot coords, node coords, normalised demands).
    data_random = [tf.random.uniform((2, 2,), minval=0, maxval=1, dtype=tf.dtypes.float32),
                   tf.random.uniform((2, graph_size, 2), minval=0, maxval=1, dtype=tf.dtypes.float32),
                   tf.cast(tf.random.uniform(minval=1, maxval=10, shape=(2, graph_size),
                                             dtype=tf.int32), tf.float32) / tf.cast(CAPACITIES[graph_size], tf.float32)]

    new_model = AttentionDynamicModel(embedding_dim, n_encode_layers=n_encode_layers)
    set_decode_type(new_model, "sampling")
    # Forward pass only to build the variables; the outputs are discarded.
    _, _ = new_model(data_random)

    for target_var, source_var in zip(new_model.variables, model.variables):
        target_var.assign(source_var)

    return new_model
|
34 |
+
|
35 |
+
def rollout(model, dataset, batch_size = 1000, disable_tqdm = False):
    """Run ``model`` greedily over ``dataset`` and return all costs concatenated.

    The model is switched to "greedy" decode type and is left in that mode.

    Args:
        model: model called on each batch; the first element of its output is
            taken as the batch cost.
        dataset: object exposing ``.batch(batch_size)``.
        batch_size: number of samples per forward pass.
        disable_tqdm: when True the progress bar is suppressed.

    Returns:
        Costs of all batches concatenated along axis 0.
    """
    set_decode_type(model, "greedy")

    progress = tqdm(dataset.batch(batch_size), disable=disable_tqdm, desc="Rollout greedy execution")
    per_batch_costs = [model(batch)[0] for batch in progress]

    return tf.concat(per_batch_costs, axis=0)
|
45 |
+
|
46 |
+
|
47 |
+
def validate(dataset, model, batch_size=1000):
    """Compute and print the mean greedy cost of ``model`` on ``dataset``.

    The model is set back to "sampling" decode type afterwards, because the
    rollout leaves it in greedy mode.

    Returns:
        The mean cost over the whole dataset.
    """
    greedy_costs = rollout(model, dataset, batch_size=batch_size)

    # Restore the decode type changed by rollout().
    set_decode_type(model, "sampling")

    mean_cost = tf.reduce_mean(greedy_costs)
    print(f"Validation score: {np.round(mean_cost, 4)}")

    return mean_cost
|
55 |
+
|
56 |
+
|
57 |
+
class RolloutBaseline:
    """Greedy-rollout baseline for REINFORCE with an exponential-moving-average warm-up.

    While ``alpha == 0`` the baseline value is a running average of the batch
    costs. As ``alpha`` grows towards 1 the value becomes a mix of that
    running average and the cost of a frozen baseline model; at ``alpha == 1``
    only the baseline model is used. After every epoch the training model is
    compared with the baseline model on a freshly generated dataset and
    replaces it when it is significantly better (one-sided paired t-test).
    """

    def __init__(self, model, filename,
                 from_checkpoint=False,
                 path_to_checkpoint=None,
                 wp_n_epochs=1,
                 epoch=0,
                 num_samples=10000,
                 warmup_exp_beta=0.8,
                 embedding_dim=128,
                 graph_size=20
                 ):
        """
        Args:
            model: current model
            filename: suffix for baseline checkpoint filename
            from_checkpoint: start from checkpoint flag
            path_to_checkpoint: path to baseline model weights
            wp_n_epochs: number of warm-up epochs
            epoch: current epoch number
            num_samples: number of samples to be generated for baseline dataset
            warmup_exp_beta: warmup mixing parameter (exp. moving average parameter)
            embedding_dim: embedding size used when (re)building the baseline model
            graph_size: number of nodes per generated problem instance
        """

        self.num_samples = num_samples
        self.cur_epoch = epoch
        self.wp_n_epochs = wp_n_epochs
        self.beta = warmup_exp_beta

        # controls the amount of warmup (0 = EMA only, 1 = baseline model only)
        self.alpha = 0.0

        self.running_average_cost = None

        # Checkpoint params
        self.filename = filename
        self.from_checkpoint = from_checkpoint
        self.path_to_checkpoint = path_to_checkpoint

        # Problem params
        self.embedding_dim = embedding_dim
        self.graph_size = graph_size

        # create and evaluate initial baseline
        self._update_baseline(model, epoch)

    def _update_baseline(self, model, epoch):
        """Replace the baseline model, save its weights and re-evaluate it.

        The baseline is loaded from ``path_to_checkpoint`` only on the very
        first update of a run started from a checkpoint (``alpha == 0``);
        in every other case it is a copy of ``model``.
        """

        # Load or copy baseline model based on self.from_checkpoint condition
        if self.from_checkpoint and self.alpha == 0:
            print('Baseline model loaded')
            self.model = load_tf_model(self.path_to_checkpoint,
                                       embedding_dim=self.embedding_dim,
                                       graph_size=self.graph_size)
        else:
            self.model = copy_of_tf_model(model,
                                          embedding_dim=self.embedding_dim,
                                          graph_size=self.graph_size)

            # For checkpoint
            self.model.save_weights('baseline_checkpoint_epoch_{}_{}.h5'.format(epoch, self.filename), save_format='h5')

        # We generate a new dataset for baseline model on each baseline update to prevent possible overfitting
        self.dataset = generate_data_onfly(num_samples=self.num_samples, graph_size=self.graph_size)

        print(f"Evaluating baseline model on baseline dataset (epoch = {epoch})")
        self.bl_vals = rollout(self.model, self.dataset)
        self.mean = tf.reduce_mean(self.bl_vals)
        self.cur_epoch = epoch
        # Remember when the baseline was last refreshed so that logs can
        # report it separately from the current training epoch.
        self.baseline_epoch = epoch

    def _warmup_alpha(self, epoch):
        """Return the warm-up coefficient for ``epoch``, always within [.., 1.0].

        Without the upper bound alpha exceeds 1 whenever ``epoch + 1`` is
        larger than ``wp_n_epochs`` (e.g. when training resumes at a late
        epoch), and ``eval`` would then over-scale the baseline cost.
        A non-positive ``wp_n_epochs`` means "no warm-up" and yields 1.0
        instead of dividing by zero.
        """
        if self.wp_n_epochs <= 0:
            return 1.0
        return min(1.0, (epoch + 1) / float(self.wp_n_epochs))

    def ema_eval(self, cost):
        """This is running average of cost through previous batches (only for warm-up epochs)

        The first call initialises the average with the mean of ``cost``;
        later calls blend the previous average (weight ``beta``) with the new
        batch mean (weight ``1 - beta``). The updated average is returned.
        """

        if self.running_average_cost is None:
            self.running_average_cost = tf.reduce_mean(cost)
        else:
            self.running_average_cost = self.beta * self.running_average_cost + (1. - self.beta) * tf.reduce_mean(cost)

        return self.running_average_cost

    def eval(self, batch, cost):
        """Evaluates current baseline model on single training batch

        Returns the running-average cost while ``alpha == 0``; otherwise a
        convex combination of the baseline model cost on ``batch`` and the
        running average, weighted by ``alpha``. Gradients are stopped on both
        terms.
        """

        if self.alpha == 0:
            return self.ema_eval(cost)

        if self.alpha < 1:
            v_ema = self.ema_eval(cost)
        else:
            v_ema = 0.0

        v_b, _ = self.model(batch)

        v_b = tf.stop_gradient(v_b)
        v_ema = tf.stop_gradient(v_ema)

        # Combination of baseline cost and exp. moving average cost
        return self.alpha * v_b + (1 - self.alpha) * v_ema

    def eval_all(self, dataset):
        """Evaluates current baseline model on the whole dataset only for non warm-up epochs

        Returns None while warm-up is still active (``alpha < 1``).
        """

        if self.alpha < 1:
            return None

        val_costs = rollout(self.model, dataset, batch_size=2048)

        return val_costs

    def epoch_callback(self, model, epoch):
        """Compares current baseline model with the training model and updates baseline if it is improved

        The baseline is replaced when the candidate's mean cost on the
        baseline dataset is lower and the one-sided paired t-test p-value is
        below 0.05. Afterwards the warm-up coefficient ``alpha`` is advanced
        (and capped at 1.0).
        """

        self.cur_epoch = epoch

        print(f"Evaluating candidate model on baseline dataset (callback epoch = {self.cur_epoch})")
        candidate_vals = rollout(model, self.dataset)  # costs for training model on baseline dataset
        candidate_mean = tf.reduce_mean(candidate_vals)

        diff = candidate_mean - self.mean

        print(f"Epoch {self.cur_epoch} candidate mean {candidate_mean}, baseline epoch {self.baseline_epoch} mean {self.mean}, difference {diff}")

        if diff < 0:
            # statistic + p-value
            t_stat, p = ttest_rel(candidate_vals, self.bl_vals)

            # ttest_rel is two-sided; halve it for the one-sided test
            # (the direction is already established by diff < 0).
            p_val = p / 2
            print(f"p-value: {p_val}")

            if p_val < 0.05:
                print('Update baseline')
                self._update_baseline(model, self.cur_epoch)

        # alpha controls the amount of warmup
        if self.alpha < 1.0:
            self.alpha = self._warmup_alpha(self.cur_epoch)
            print(f"alpha was updated to {self.alpha}")
|
200 |
+
|
201 |
+
|
202 |
+
def load_tf_model(path, embedding_dim=128, graph_size=20, n_encode_layers=2):
    """Build an AttentionDynamicModel and load its weights from ``path``.

    The model is called once on a small random batch (in "greedy" decode
    type) before ``load_weights`` so that its variables exist.

    Args:
        path: location of the saved weights.
        embedding_dim: embedding size passed to the model constructor.
        graph_size: number of nodes in the dummy batch; must be a key of the
            capacity table below (10, 20, 50, 100).
        n_encode_layers: number of encoder layers passed to the constructor.

    Returns:
        The model with loaded weights.
    """
    # https://stackoverflow.com/questions/51806852/cant-save-custom-subclassed-model
    CAPACITIES = {10: 20.,
                  20: 30.,
                  50: 40.,
                  100: 50.
                  }

    # Dummy inputs for two instances: depot, node coordinates, scaled demands.
    dummy_depot = tf.random.uniform((2, 2,), minval=0, maxval=1, dtype=tf.dtypes.float32)
    dummy_nodes = tf.random.uniform((2, graph_size, 2), minval=0, maxval=1, dtype=tf.dtypes.float32)
    raw_demand = tf.random.uniform(minval=1, maxval=10, shape=(2, graph_size), dtype=tf.int32)
    dummy_demand = tf.cast(raw_demand, tf.float32) / tf.cast(CAPACITIES[graph_size], tf.float32)
    data_random = [dummy_depot, dummy_nodes, dummy_demand]

    model_loaded = AttentionDynamicModel(embedding_dim, n_encode_layers=n_encode_layers)
    set_decode_type(model_loaded, "greedy")
    # Forward pass only to create the variables; outputs are ignored.
    _, _ = model_loaded(data_random)

    model_loaded.load_weights(path)

    return model_loaded
|
utils.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import tensorflow as tf
|
3 |
+
import pandas as pd
|
4 |
+
import seaborn as sns
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import plotly.graph_objects as go
|
7 |
+
import numpy as np
|
8 |
+
from datetime import datetime
|
9 |
+
import time
|
10 |
+
|
11 |
+
|
12 |
+
def create_data_on_disk(graph_size, num_samples, is_save=True, filename=None, is_return=False, seed=1234):
    """Generate a seeded validation dataset, optionally save and/or return it.

    Args:
        graph_size: number of customer nodes per instance; must be a key of
            the capacity table below (10, 20, 50, 100).
        num_samples: number of instances to generate.
        is_save: when True the raw tensors are pickled to
            ``Validation_dataset_<filename>.pkl``.
        filename: suffix used in the pickle file name.
        is_return: when True a ``tf.data.Dataset`` of
            (depot, graph, demand) triples is returned.
        seed: seed passed to each random op.

    Returns:
        The dataset when ``is_return`` is True, otherwise None.

    Raises:
        KeyError: if ``graph_size`` is not in the capacity table.
    """

    CAPACITIES = {
        10: 20.,
        20: 30.,
        50: 40.,
        100: 50.
    }
    depo = tf.random.uniform(minval=0, maxval=1, shape=(num_samples, 2), seed=seed)
    graphs = tf.random.uniform(minval=0, maxval=1, shape=(num_samples, graph_size, 2), seed=seed)
    # Integer demands in [1, 9], expressed as a fraction of vehicle capacity.
    demand = tf.cast(tf.random.uniform(minval=1, maxval=10, shape=(num_samples, graph_size),
                                       dtype=tf.int32, seed=seed), tf.float32) / tf.cast(CAPACITIES[graph_size], tf.float32)

    if is_save:
        save_to_pickle('Validation_dataset_{}.pkl'.format(filename), (depo, graphs, demand))

    if is_return:
        # from_tensor_slices already slices along the first axis; unpacking
        # each tensor into a Python list first only makes TensorFlow stack
        # the pieces back together.
        return tf.data.Dataset.from_tensor_slices((depo, graphs, demand))

    return None
|
32 |
+
|
33 |
+
|
34 |
+
def save_to_pickle(filename, item):
    """Serialise ``item`` to ``filename`` with the highest pickle protocol.

    An existing file at ``filename`` is overwritten.
    """
    with open(filename, 'wb') as sink:
        pickle.Pickler(sink, protocol=pickle.HIGHEST_PROTOCOL).dump(item)
|
39 |
+
|
40 |
+
|
41 |
+
def read_from_pickle(path, return_tf_data_set=True, num_samples=None):
    """Load the first pickled object stored in ``path``.

    Args:
        path: pickle file to read.
        return_tf_data_set: when True the object is unpacked as
            (depot, graphs, demand) and wrapped in a ``tf.data.Dataset``;
            when False the raw object is returned.
        num_samples: if given, the dataset is limited to its first
            ``num_samples`` elements.

    Returns:
        A ``tf.data.Dataset`` or the raw unpickled object.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only
    # use this on files from a trusted source.
    loaded = []
    with open(path, "rb") as source:
        try:
            # Keep reading until the stream is exhausted.
            while True:
                loaded.append(pickle.load(source))
        except EOFError:
            pass

    payload = loaded[0]
    if not return_tf_data_set:
        return payload

    depo, graphs, demand = payload
    dataset = tf.data.Dataset.from_tensor_slices((list(depo), list(graphs), list(demand)))
    if num_samples is None:
        return dataset
    return dataset.take(num_samples)
|
61 |
+
|
62 |
+
|
63 |
+
def generate_data_onfly(num_samples=10000, graph_size=20):
    """Generate a temporary (unseeded) dataset in memory.

    Args:
        num_samples: number of problem instances to generate.
        graph_size: number of customer nodes per instance; must be a key of
            the capacity table below (10, 20, 50, 100).

    Returns:
        A ``tf.data.Dataset`` of (depot, graph, demand) triples, where
        coordinates are uniform in [0, 1) and demands are integers in [1, 9]
        divided by the vehicle capacity for ``graph_size``.

    Raises:
        KeyError: if ``graph_size`` is not in the capacity table.
    """

    CAPACITIES = {
        10: 20.,
        20: 30.,
        50: 40.,
        100: 50.
    }
    depo = tf.random.uniform(minval=0, maxval=1, shape=(num_samples, 2))
    graphs = tf.random.uniform(minval=0, maxval=1, shape=(num_samples, graph_size, 2))
    demand = tf.cast(tf.random.uniform(minval=1, maxval=10, shape=(num_samples, graph_size),
                                       dtype=tf.int32), tf.float32) / tf.cast(CAPACITIES[graph_size], tf.float32)

    # The tensors are passed directly: from_tensor_slices slices along the
    # first axis itself, so converting each tensor to a Python list of
    # per-sample tensors first is wasted work.
    return tf.data.Dataset.from_tensor_slices((depo, graphs, demand))
|
80 |
+
|
81 |
+
|
82 |
+
def get_results(train_loss_results, train_cost_results, val_cost, save_results=True, filename=None, plots=True):
    """Tabulate per-epoch training results and optionally save and plot them.

    Args:
        train_loss_results: per-epoch training loss values.
        train_cost_results: per-epoch training cost values.
        val_cost: per-epoch validation cost values; must have one entry per
            training epoch.
        save_results: when True, write ``train_results_<filename>.xlsx`` and
            ``test_results_<filename>.xlsx``.
        filename: suffix used in all output file names.
        plots: when True, draw the learning curves, save them to
            ``learning_curve_plot_<filename>.jpg`` and show the figure.

    Returns:
        None.
    """

    epochs = list(range(len(train_loss_results)))

    df_train = pd.DataFrame(data={'epochs': epochs,
                                  'loss': train_loss_results,
                                  'cost': train_cost_results,
                                  })
    # The column name is plain ASCII 'val_cost'; the previous key contained a
    # look-alike non-Latin character, which made the column impossible to
    # address by typing its name.
    df_test = pd.DataFrame(data={'epochs': epochs,
                                 'val_cost': val_cost})
    if save_results:
        df_train.to_excel('train_results_{}.xlsx'.format(filename), index=False)
        df_test.to_excel('test_results_{}.xlsx'.format(filename), index=False)

    if plots:
        plt.figure(figsize=(15, 9))
        # Loss on the left axis, both cost curves on the right (twin) axis.
        ax = sns.lineplot(x='epochs', y='loss', data=df_train, color='salmon', label='train loss')
        ax2 = ax.twinx()
        sns.lineplot(x='epochs', y='cost', data=df_train, color='cornflowerblue', label='train cost', ax=ax2)
        # `color` (not `palette`) sets the colour of a single line; the target
        # axis is given explicitly instead of relying on the current axes.
        sns.lineplot(x='epochs', y='val_cost', data=df_test, color='darkblue', label='val cost',
                     ax=ax2).set(ylabel='cost')
        ax.legend(loc=(0.75, 0.90), ncol=1)
        ax2.legend(loc=(0.75, 0.95), ncol=2)
        ax.grid(axis='x')
        ax2.grid(True)
        plt.savefig('learning_curve_plot_{}.jpg'.format(filename))
        plt.show()
|
108 |
+
|
109 |
+
|
110 |
+
def get_journey(batch, pi, title, ind_in_batch=0):
    """Plot the route of the agent for one graph of the batch.

    Args:
        batch: dataset of graphs, indexed as (depot, node coordinates, demands)
        pi: paths of agent obtained from model
        title: text shown in the figure title
        ind_in_batch: index of graph in batch to be plotted
    """

    # Path with repeated consecutive nodes collapsed.
    pi_ = get_clean_path(pi[ind_in_batch].numpy())

    # Data of the selected graph.
    depo_coord = batch[0][ind_in_batch].numpy()
    points_coords = batch[1][ind_in_batch].numpy()
    demands = batch[2][ind_in_batch].numpy()
    node_labels = [f"({position}, {demand_text})"
                   for position, demand_text in enumerate(demands.round(2).astype(str))]

    # Row 0 is the depot, the remaining rows are the customer nodes.
    full_coords = np.concatenate((depo_coord.reshape(1, 2), points_coords))

    # Split the path into loops; each loop is closed by a visit to node 0.
    routes = []
    current = []
    for position, node in enumerate(pi_):
        current.append(node)
        if position == 0 or node != 0:
            continue
        if current[0] != 0:
            # Make every loop start at the depot as well.
            current.insert(0, 0)
        routes.append(current)
        current = []

    def _route_trace(route_number, route):
        # One line trace per loop, labelled with its Euclidean length.
        coords = full_coords[[int(node) for node in route]]
        segment_lengths = np.sqrt(np.sum(np.diff(coords, axis=0) ** 2, axis=1))
        total_length = np.sum(segment_lengths)
        return go.Scatter(x=coords[:, 0],
                          y=coords[:, 1],
                          mode="markers+lines",
                          name=f"path_{route_number}, length={total_length:.2f}",
                          opacity=1.0)

    route_traces = [_route_trace(number, route) for number, route in enumerate(routes)]

    trace_points = go.Scatter(x=points_coords[:, 0],
                              y=points_coords[:, 1],
                              mode='markers+text',
                              name='destinations',
                              text=node_labels,
                              textposition='top center',
                              marker=dict(size=7),
                              opacity=1.0
                              )

    trace_depo = go.Scatter(x=[depo_coord[0]],
                            y=[depo_coord[1]],
                            text=['1.0'], textposition='bottom center',
                            mode='markers+text',
                            marker=dict(size=15),
                            name='depot'
                            )

    layout = go.Layout(title='<b>Example: {}</b>'.format(title),
                       xaxis=dict(title='X coordinate'),
                       yaxis=dict(title='Y coordinate'),
                       showlegend=True,
                       width=1000,
                       height=1000,
                       template="plotly_white"
                       )

    print('Current path: ', pi_)
    fig = go.Figure(data=[trace_points, trace_depo] + route_traces, layout=layout)
    fig.show()
|
189 |
+
|
190 |
+
def get_cur_time():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
195 |
+
|
196 |
+
|
197 |
+
def get_clean_path(arr):
    """Remove extra zeros from a path.

    The dynamical model generates duplicated zeros for several graphs when
    obtaining partial solutions. Every run of equal consecutive entries is
    collapsed to a single entry, and the result is made to start and end at
    the depot (node 0).

    Args:
        arr: one-dimensional sequence of node indices.

    Returns:
        A list with consecutive duplicates removed; ``0.0`` is prepended and/or
        appended when the path does not already start/end with 0. An empty
        input yields an empty list.
    """

    # Guard: the depot checks below index into the result.
    if len(arr) == 0:
        return []

    # Keep an element only when it differs from its predecessor.
    output = [node for idx, node in enumerate(arr) if idx == 0 or node != arr[idx - 1]]

    if output[0] != 0:
        output.insert(0, 0.0)
    if output[-1] != 0:
        output.append(0.0)

    return output
|
221 |
+
|
utils_demo.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import seaborn as sns
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import plotly.graph_objects as go
|
4 |
+
from plotly.subplots import make_subplots
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def f_get_results_plot_seaborn(data, title, graph_size=20):
    """Draw the learning curves with matplotlib on twin y-axes.

    Args:
        data: table with 'epochs', 'train_loss', 'train_cost' and 'val_cost'
            columns.
        title: text appended to the figure title.
        graph_size: selects the reference score drawn as a dashed line
            (6.4 for 20, 10.98 for any other value).
    """
    fig = plt.figure(figsize=(15, 9))
    ax = fig.add_subplot()
    epochs = data['epochs']

    # Loss on the left axis, both costs on the right axis.
    ax.plot(epochs, data['train_loss'], color='salmon', label='train loss')
    ax2 = ax.twinx()
    ax2.plot(epochs, data['train_cost'], color='cornflowerblue', label='train cost')
    ax2.plot(epochs, data['val_cost'], color='darkblue', label='val cost')

    am_val = 6.4 if graph_size == 20 else 10.98

    plt.axhline(y=am_val, color='black', linestyle='--', linewidth=1.5, label='AM article best score')

    fig.legend(loc="upper right", bbox_to_anchor=(1,1), bbox_transform=ax.transAxes)

    ax.set_ylabel('Loss')
    ax2.set_ylabel('Cost')
    ax.set_xlabel('Epochs')
    ax.grid(False)
    ax2.grid(False)

    # Cost ticks every 0.1, from slightly below the lowest cost to just above the highest.
    lowest_cost = min(data['val_cost'].min(), data['train_cost'].min())
    highest_cost = max(data['val_cost'].max(), data['train_cost'].max())
    ax2.set_yticks(np.arange(lowest_cost-0.2, highest_cost+0.1, 0.1).round(2))

    plt.title('Learning Curve: ' + title)
    plt.show()
|
35 |
+
|
36 |
+
|
37 |
+
def f_get_results_plot_plotly(data, title, graph_size=20):
    """Draw the learning curves with plotly on a figure with two y-axes.

    Args:
        data: table with 'epochs', 'train_loss', 'train_cost' and 'val_cost'
            columns.
        title: text appended to the figure title.
        graph_size: not used by this function.
    """
    # Figure with a secondary y-axis
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # (column, legend name, colour, drawn on secondary axis?)
    series = (
        ('train_loss', "train loss", 'salmon', False),
        ('train_cost', "train cost", 'cornflowerblue', True),
        ('val_cost', "val cost", 'darkblue', True),
    )
    for column, legend_name, colour, on_secondary in series:
        fig.add_trace(
            go.Scatter(x=data['epochs'], y=data[column], name=legend_name, marker_color=colour),
            secondary_y=on_secondary,
        )

    # Title, size and theme
    fig.update_layout(
        title_text="Learning Curve: " + title,
        width=950,
        height=650,
        template="plotly_white"
    )

    # Axis titles
    fig.update_xaxes(title_text="Number of epoch")
    fig.update_yaxes(title_text="<b>Loss", secondary_y=False, showgrid=False, zeroline=False)
    fig.update_yaxes(title_text="<b>Cost", secondary_y=True, dtick=0.1)

    fig.show()
|