import tensorflow as tf
import numpy as np

from attention_graph_encoder import GraphAttentionEncoder
from enviroment import AgentVRP


def set_decode_type(model, decode_type):
    model.set_decode_type(decode_type)
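
# Decoding-mode switch, a usage sketch (assumes `model` is an AttentionDynamicModel built elsewhere):
#   set_decode_type(model, "sampling")  # stochastic rollouts, e.g. during training
#   set_decode_type(model, "greedy")    # deterministic rollouts, e.g. during evaluation
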
class AttentionDynamicModel(tf.keras.Model):

    def __init__(self,
                 embedding_dim,
                 n_encode_layers=2,
                 n_heads=8,
                 tanh_clipping=10.
                 ):

        super().__init__()

        # attributes for the encoder
        self.embedding_dim = embedding_dim
        self.n_encode_layers = n_encode_layers
        self.decode_type = None

        self.problem = AgentVRP
        self.n_heads = n_heads

        self.embedder = GraphAttentionEncoder(input_dim=self.embedding_dim,
                                              num_heads=self.n_heads,
                                              num_layers=self.n_encode_layers
                                              )

        # attributes for the decoder
        self.output_dim = self.embedding_dim
        self.num_heads = n_heads

        self.head_depth = self.output_dim // self.num_heads
        self.dk_mha_decoder = tf.cast(self.head_depth, tf.float32)  # scaling factor for the decoder MHA
        self.dk_get_loc_p = tf.cast(self.output_dim, tf.float32)  # scaling factor for the single-head attention

        if self.output_dim % self.num_heads != 0:
            raise ValueError("number of heads must divide d_model=output_dim")

        self.tanh_clipping = tanh_clipping

        # projections for the context vector
        self.wq_context = tf.keras.layers.Dense(self.output_dim, use_bias=False,
                                                name='wq_context')
        self.wq_step_context = tf.keras.layers.Dense(self.output_dim, use_bias=False,
                                                     name='wq_step_context')

        # projections of the node embeddings
        self.wk = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='wk')
        self.wk_tanh = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='wk_tanh')
        self.wv = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='wv')

        # output projection of the decoder MHA
        self.w_out = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='w_out')
    def set_decode_type(self, decode_type):
        self.decode_type = decode_type

    def split_heads(self, tensor, batch_size):
        """Computes attention on several heads simultaneously.
        Splits the last dimension of a tensor into (num_heads, head_depth),
        then transposes the result to (batch_size, num_heads, ..., head_depth) so that broadcasting works.
        """
        tensor = tf.reshape(tensor, (batch_size, -1, self.num_heads, self.head_depth))
        return tf.transpose(tensor, perm=[0, 2, 1, 3])
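    # Shape illustration (assuming embedding_dim=128 and the default n_heads=8, so head_depth=16):
    #   (batch_size, n_nodes, 128) --reshape-->   (batch_size, n_nodes, 8, 16)
    #                              --transpose--> (batch_size, 8, n_nodes, 16)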
    def _select_node(self, logits):
        """Selects the next node according to the decoding type."""

        # logits (log-probabilities) have shape (batch_size, 1, n_nodes)
        if self.decode_type == "greedy":
            selected = tf.math.argmax(logits, axis=-1)  # (batch_size, 1)

        elif self.decode_type == "sampling":
            # tf.random.categorical expects 2D logits of shape (batch_size, n_nodes)
            selected = tf.random.categorical(logits[:, 0, :], 1)  # (batch_size, 1)
        else:
            assert False, "Unknown decode type"

        return tf.squeeze(selected, axis=-1)  # (batch_size,)
    def get_step_context(self, state, embeddings):
        """Takes a state and graph embeddings and
        returns the part [h_N, D] of the context vector [h_c, h_N, D]
        that is related to the RL agent's last step.
        """
        # index of the previously visited node
        prev_node = state.prev_a  # (batch_size, 1)

        # embedding of the previously visited node
        cur_embedded_node = tf.gather(embeddings, tf.cast(prev_node, tf.int32), batch_dims=1)  # (batch_size, 1, embedding_dim)

        # concatenate the previous-node embedding with the remaining vehicle capacity
        step_context = tf.concat([cur_embedded_node, self.problem.VEHICLE_CAPACITY - state.used_capacity[:, :, None]], axis=-1)

        return step_context  # (batch_size, 1, embedding_dim + 1)
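    # Example: with embedding_dim=128 (and assuming state.used_capacity has shape (batch_size, 1)),
    # the step context is a (batch_size, 1, 129) tensor: the 128-dim embedding of the last visited
    # node concatenated with a single scalar for the remaining vehicle capacity.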
    def decoder_mha(self, Q, K, V, mask=None):
        """Computes the Multi-Head Attention part of the decoder.
        Basically it is a part of an MHA sublayer, but we can't construct it as a layer since Q changes in the decoding loop.

        Args:
            mask: a mask for visited nodes,
                has shape (batch_size, seq_len_q, seq_len_k); seq_len_q = 1 for context vector attention in decoder
            Q: query (context vector for decoder),
                has shape (..., seq_len_q, head_depth) with seq_len_q = 1 for context vector attention in decoder
            K, V: key, value (projections of node embeddings),
                have shapes (..., seq_len_k, head_depth) and (..., seq_len_v, head_depth),
                with seq_len_k = seq_len_v = n_nodes for decoder
        """

        # scaled dot-product compatibility: (batch_size, num_heads, seq_len_q, seq_len_k)
        compatibility = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(self.dk_mha_decoder)

        if mask is not None:

            # broadcast the mask over the heads dimension:
            # (batch_size, seq_len_q, seq_len_k) --> (batch_size, 1, seq_len_q, seq_len_k)
            mask = mask[:, tf.newaxis, :, :]

            # masked (visited / infeasible) nodes get -inf so that softmax assigns them zero weight
            compatibility = tf.where(mask,
                                     tf.ones_like(compatibility) * (-np.inf),
                                     compatibility
                                     )

        compatibility = tf.nn.softmax(compatibility, axis=-1)  # (batch_size, num_heads, seq_len_q, seq_len_k)
        attention = tf.matmul(compatibility, V)  # (batch_size, num_heads, seq_len_q, head_depth)

        # transpose back to (batch_size, seq_len_q, num_heads, head_depth)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])

        # concatenate the heads: (batch_size, seq_len_q, output_dim)
        attention = tf.reshape(attention, (self.batch_size, -1, self.output_dim))

        output = self.w_out(attention)  # (batch_size, seq_len_q, output_dim), seq_len_q = 1 for decoder

        return output
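    # Shape walk-through (assuming embedding_dim=128 and the default n_heads=8, so head_depth=16):
    #   Q: (batch_size, 8, 1, 16), K and V: (batch_size, 8, n_nodes, 16)
    #   attention weights: (batch_size, 8, 1, n_nodes)
    #   attention after transpose + reshape: (batch_size, 1, 128), which w_out maps to (batch_size, 1, 128)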
    def get_log_p(self, Q, K, mask=None):
        """Single-head attention sublayer in the decoder;
        computes log-probabilities for node selection.

        Args:
            mask: mask for nodes
            Q: query (output of the MHA layer),
                has shape (batch_size, seq_len_q, output_dim); seq_len_q = 1 for context attention in decoder
            K: key (projection of node embeddings),
                has shape (batch_size, seq_len_k, output_dim); seq_len_k = n_nodes for decoder
        """

        # scaled dot-product compatibility with tanh clipping: C * tanh(q^T k / sqrt(d))
        compatibility = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(self.dk_get_loc_p)
        compatibility = tf.math.tanh(compatibility) * self.tanh_clipping

        if mask is not None:

            # mask visited / infeasible nodes with -inf before the log-softmax
            compatibility = tf.where(mask,
                                     tf.ones_like(compatibility) * (-np.inf),
                                     compatibility
                                     )

        log_p = tf.nn.log_softmax(compatibility, axis=-1)  # (batch_size, seq_len_q, seq_len_k)

        return log_p
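    # Note on clipping: tanh keeps every unmasked logit in (-tanh_clipping, tanh_clipping), so with the
    # default tanh_clipping=10 the ratio between any two node probabilities is at most exp(20);
    # this bounds how peaked the resulting distribution can become.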
    def get_log_likelihood(self, _log_p, a):

        # pick the log-probability of each selected action:
        # _log_p has shape (batch_size, n_steps, n_nodes), a has shape (batch_size, n_steps)
        log_p = tf.gather_nd(_log_p, tf.cast(tf.expand_dims(a, axis=-1), tf.int32), batch_dims=2)

        # sum over the decoding steps to get the log-likelihood of the whole tour
        return tf.reduce_sum(log_p, 1)
    def get_projections(self, embeddings, context_vectors):

        # projections that stay fixed during the decoding loop, computed once for efficiency
        K = self.wk(embeddings)  # (batch_size, n_nodes, output_dim)
        K_tanh = self.wk_tanh(embeddings)  # (batch_size, n_nodes, output_dim), used by the final single-head attention
        V = self.wv(embeddings)  # (batch_size, n_nodes, output_dim)
        Q_context = self.wq_context(context_vectors[:, tf.newaxis, :])  # (batch_size, 1, output_dim)

        # split K and V into heads for the decoder MHA
        K = self.split_heads(K, self.batch_size)  # (batch_size, num_heads, n_nodes, head_depth)
        V = self.split_heads(V, self.batch_size)  # (batch_size, num_heads, n_nodes, head_depth)

        return K_tanh, Q_context, K, V
    def call(self, inputs, return_pi=False):

        # encode the graph once to get node embeddings and the mean graph embedding
        embeddings, mean_graph_emb = self.embedder(inputs)

        self.batch_size = tf.shape(embeddings)[0]

        outputs = []
        sequences = []

        state = self.problem(inputs)

        K_tanh, Q_context, K, V = self.get_projections(embeddings, mean_graph_emb)

        # perform decoding steps
        i = 0
        inner_i = 0

        while not state.all_finished():

            if i > 0:
                # after each partial tour, re-embed the graph with an updated attention mask
                state.i = tf.zeros(1, dtype=tf.int64)
                att_mask, cur_num_nodes = state.get_att_mask()
                embeddings, context_vectors = self.embedder(inputs, att_mask, cur_num_nodes)
                K_tanh, Q_context, K, V = self.get_projections(embeddings, context_vectors)

            inner_i = 0
            while not state.partial_finished():

                step_context = self.get_step_context(state, embeddings)  # (batch_size, 1, embedding_dim + 1)
                Q_step_context = self.wq_step_context(step_context)  # (batch_size, 1, output_dim)
                Q = Q_context + Q_step_context

                # split the query into heads for the decoder MHA
                Q = self.split_heads(Q, self.batch_size)

                # mask of visited / infeasible nodes for the current state
                mask = state.get_mask()

                # multi-head attention between the context query and the node embeddings
                mha = self.decoder_mha(Q, K, V, mask)

                # log-probabilities over the nodes from the single-head attention sublayer
                log_p = self.get_log_p(mha, K_tanh, mask)

                # choose the next node (greedy or sampling, depending on decode_type)
                selected = self._select_node(log_p)

                state.step(selected)

                outputs.append(log_p[:, 0, :])
                sequences.append(selected)

                inner_i += 1

            i += 1

        _log_p, pi = tf.stack(outputs, 1), tf.cast(tf.stack(sequences, 1), tf.float32)

        cost = self.problem.get_costs(inputs, pi)

        ll = self.get_log_likelihood(_log_p, pi)

        if return_pi:
            return cost, ll, pi

        return cost, ll
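
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the model). `inputs` must follow
# the format expected by GraphAttentionEncoder and AgentVRP, which is defined
# in their own modules.
# ---------------------------------------------------------------------------
# model = AttentionDynamicModel(embedding_dim=128)
# set_decode_type(model, "sampling")                       # or "greedy" for evaluation
# cost, log_likelihood, pi = model(inputs, return_pi=True)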