# cvrp-model / attention_dynamic_model.py
import tensorflow as tf
import numpy as np
from attention_graph_encoder import GraphAttentionEncoder
from enviroment import AgentVRP
def set_decode_type(model, decode_type):
model.set_decode_type(decode_type)
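
# Note: "sampling" is typically used for training rollouts and "greedy" for
# evaluation; both are handled in AttentionDynamicModel._select_node below.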
class AttentionDynamicModel(tf.keras.Model):
def __init__(self,
embedding_dim,
n_encode_layers=2,
n_heads=8,
tanh_clipping=10.
):
super().__init__()
# attributes for MHA
self.embedding_dim = embedding_dim
self.n_encode_layers = n_encode_layers
self.decode_type = None
# attributes for VRP problem
self.problem = AgentVRP
self.n_heads = n_heads
# Encoder part
self.embedder = GraphAttentionEncoder(input_dim=self.embedding_dim,
num_heads=self.n_heads,
num_layers=self.n_encode_layers
)
# Decoder part
        self.output_dim = self.embedding_dim
        self.num_heads = n_heads
        if self.output_dim % self.num_heads != 0:
            raise ValueError("number of heads must divide d_model=output_dim")
        self.head_depth = self.output_dim // self.num_heads
        self.dk_mha_decoder = tf.cast(self.head_depth, tf.float32)  # scaling factor for the decoder MHA
        self.dk_get_loc_p = tf.cast(self.output_dim, tf.float32)  # scaling factor for the single-head attention in get_log_p
self.tanh_clipping = tanh_clipping
        # we split the projection matrix Wq into two matrices: Wq * [h_c, h_N, D] = Wq_context * h_c + Wq_step_context * [h_N, D]
self.wq_context = tf.keras.layers.Dense(self.output_dim, use_bias=False,
name='wq_context') # (d_q_context, output_dim)
self.wq_step_context = tf.keras.layers.Dense(self.output_dim, use_bias=False,
name='wq_step_context') # (d_q_step_context, output_dim)
# we need two Wk projections since there is MHA followed by 1-head attention - they have different keys K
self.wk = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='wk') # (d_k, output_dim)
self.wk_tanh = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='wk_tanh') # (d_k_tanh, output_dim)
        # we don't need a Wv projection for the 1-head attention (its attention weights are the output), so wv serves the MHA sublayer only
        self.wv = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='wv')  # (d_v, output_dim)
        # we don't need a Wq for the 1-head tanh attention, since it can be absorbed into w_out
self.w_out = tf.keras.layers.Dense(self.output_dim, use_bias=False, name='w_out') # (d_model, d_model)
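
    # Decoder overview (as implemented below): the context query is the graph
    # embedding (projected by wq_context) plus the last visited node and the
    # remaining capacity (projected by wq_step_context); decoder_mha computes a
    # glimpse over node embeddings, and get_log_p turns that glimpse into
    # log-probabilities over nodes via single-head attention with tanh clipping.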
def set_decode_type(self, decode_type):
self.decode_type = decode_type
    def split_heads(self, tensor, batch_size):
        """Computes attention on several heads simultaneously.
        Splits the last dimension of a tensor into (num_heads, head_depth),
        then transposes to (batch_size, num_heads, ..., head_depth) so that broadcasting works.
        """
tensor = tf.reshape(tensor, (batch_size, -1, self.num_heads, self.head_depth))
return tf.transpose(tensor, perm=[0, 2, 1, 3])
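
    # Shape walkthrough (illustrative numbers): with embedding_dim=128 and the
    # default n_heads=8, head_depth=16, so (batch_size, n_nodes, 128) is reshaped
    # to (batch_size, n_nodes, 8, 16) and transposed to (batch_size, 8, n_nodes, 16).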
def _select_node(self, logits):
"""Select next node based on decoding type.
"""
# assert tf.reduce_all(logits == logits), "Probs should not contain any nans"
if self.decode_type == "greedy":
selected = tf.math.argmax(logits, axis=-1) # (batch_size, 1)
        elif self.decode_type == "sampling":
            # logits has shape (batch_size, 1, n_nodes); we squeeze it to
            # (batch_size, n_nodes) since tf.random.categorical expects a matrix
            selected = tf.random.categorical(logits[:, 0, :], 1)  # (batch_size, 1)
        else:
            raise ValueError(f"Unknown decode type: {self.decode_type}")
        return tf.squeeze(selected, axis=-1)  # (batch_size,)
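
    # tf.random.categorical expects unnormalized log-probabilities, so the
    # log-softmax output of get_log_p can be fed in directly; masked nodes carry
    # -inf log-probability and are never sampled.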
def get_step_context(self, state, embeddings):
"""Takes a state and graph embeddings,
Returns a part [h_N, D] of context vector [h_c, h_N, D],
that is related to RL Agent last step.
"""
# index of previous node
prev_node = state.prev_a # (batch_size, 1)
# from embeddings=(batch_size, n_nodes, input_dim) select embeddings of previous nodes
cur_embedded_node = tf.gather(embeddings, tf.cast(prev_node, tf.int32), batch_dims=1) # (batch_size, 1, input_dim)
# add remaining capacity
step_context = tf.concat([cur_embedded_node, self.problem.VEHICLE_CAPACITY - state.used_capacity[:, :, None]], axis=-1)
return step_context # (batch_size, 1, input_dim + 1)
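
    # Example: used_capacity has shape (batch_size, 1), so the slice
    # state.used_capacity[:, :, None] is (batch_size, 1, 1) and the remaining
    # capacity is appended as a single extra feature to the embedding of the
    # previously visited node.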
def decoder_mha(self, Q, K, V, mask=None):
""" Computes Multi-Head Attention part of decoder
Basically, its a part of MHA sublayer, but we cant construct a layer since Q changes in a decoding loop.
Args:
mask: a mask for visited nodes,
has shape (batch_size, seq_len_q, seq_len_k), seq_len_q = 1 for context vector attention in decoder
Q: query (context vector for decoder)
has shape (..., seq_len_q, head_depth) with seq_len_q = 1 for context_vector attention in decoder
K, V: key, value (projections of nodes embeddings)
have shape (..., seq_len_k, head_depth), (..., seq_len_v, head_depth),
with seq_len_k = seq_len_v = n_nodes for decoder
"""
compatibility = tf.matmul(Q, K, transpose_b=True)/tf.math.sqrt(self.dk_mha_decoder) # (batch_size, num_heads, seq_len_q, seq_len_k)
        if mask is not None:
            # reshape the mask from (batch_size, seq_len_q, seq_len_k) to
            # (batch_size, 1, seq_len_q, seq_len_k) so that it broadcasts against
            # compatibility, which has shape (batch_size, num_heads, seq_len_q, seq_len_k)
            mask = mask[:, tf.newaxis, :, :]
            # we use tf.where since 0 * (-np.inf) yields nan rather than -np.inf
            compatibility = tf.where(mask,
                                     tf.ones_like(compatibility) * (-np.inf),
                                     compatibility
                                     )
compatibility = tf.nn.softmax(compatibility, axis=-1) # (batch_size, num_heads, seq_len_q, seq_len_k)
attention = tf.matmul(compatibility, V) # (batch_size, num_heads, seq_len_q, head_depth)
# transpose back to (batch_size, seq_len_q, num_heads, depth)
attention = tf.transpose(attention, perm=[0, 2, 1, 3])
# concatenate heads (last 2 dimensions)
attention = tf.reshape(attention, (self.batch_size, -1, self.output_dim)) # (batch_size, seq_len_q, output_dim)
output = self.w_out(attention) # (batch_size, seq_len_q, output_dim), seq_len_q = 1 for context att in decoder
return output
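
    # Caveat: a row whose nodes are all masked would make the softmax above
    # return nan; the environment is expected to always leave at least one
    # feasible node (e.g. the depot), so this should not occur in practice.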
def get_log_p(self, Q, K, mask=None):
"""Single-Head attention sublayer in decoder,
computes log-probabilities for node selection.
        Args:
            Q: query (output of the MHA sublayer),
                shape (batch_size, seq_len_q, output_dim), with seq_len_q = 1 for context attention in the decoder
            K: keys (projection of node embeddings),
                shape (batch_size, seq_len_k, output_dim), with seq_len_k = n_nodes for the decoder
            mask: mask for nodes
        """
compatibility = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(self.dk_get_loc_p)
compatibility = tf.math.tanh(compatibility) * self.tanh_clipping
        if mask is not None:
            # unlike the multi-head version, no mask reshape is needed: with a
            # single head, the mask of shape (batch_size, seq_len_q, seq_len_k)
            # already matches compatibility
            compatibility = tf.where(mask,
                                     tf.ones_like(compatibility) * (-np.inf),
                                     compatibility
                                     )
log_p = tf.nn.log_softmax(compatibility, axis=-1) # (batch_size, seq_len_q, seq_len_k)
return log_p
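
    # tanh clipping bounds the logits to [-tanh_clipping, tanh_clipping]
    # (C = 10 by default), a trick from attention-based routing models that
    # keeps the logits in a bounded range and helps exploration when sampling.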
    def get_log_likelihood(self, _log_p, a):
        # get the log_p entries corresponding to the selected actions
        log_p = tf.gather_nd(_log_p, tf.cast(tf.expand_dims(a, axis=-1), tf.int32), batch_dims=2)
        # calculate the log-likelihood
        return tf.reduce_sum(log_p, axis=1)
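
    # Shape note: _log_p is (batch_size, decode_steps, n_nodes) and a is
    # (batch_size, decode_steps); gather_nd with batch_dims=2 picks one
    # log-probability per step, and summing over axis 1 gives the
    # log-likelihood of the whole tour.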
def get_projections(self, embeddings, context_vectors):
# we compute some projections (common for each policy step) before decoding loop for efficiency
K = self.wk(embeddings) # (batch_size, n_nodes, output_dim)
K_tanh = self.wk_tanh(embeddings) # (batch_size, n_nodes, output_dim)
V = self.wv(embeddings) # (batch_size, n_nodes, output_dim)
Q_context = self.wq_context(context_vectors[:, tf.newaxis, :]) # (batch_size, 1, output_dim)
        # we don't need to split K_tanh since the final attention has a single head; Q will be split in the decoding loop
K = self.split_heads(K, self.batch_size) # (batch_size, num_heads, n_nodes, head_depth)
V = self.split_heads(V, self.batch_size) # (batch_size, num_heads, n_nodes, head_depth)
return K_tanh, Q_context, K, V
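
    # Precomputing K, V and K_tanh here means each decoding step only projects
    # the (changing) query, which keeps the inner loop in call() cheap compared
    # to re-projecting all node embeddings at every step.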
def call(self, inputs, return_pi=False):
embeddings, mean_graph_emb = self.embedder(inputs)
self.batch_size = tf.shape(embeddings)[0]
outputs = []
sequences = []
state = self.problem(inputs)
K_tanh, Q_context, K, V = self.get_projections(embeddings, mean_graph_emb)
# Perform decoding steps
        i = 0
while not state.all_finished():
if i > 0:
state.i = tf.zeros(1, dtype=tf.int64)
att_mask, cur_num_nodes = state.get_att_mask()
embeddings, context_vectors = self.embedder(inputs, att_mask, cur_num_nodes)
K_tanh, Q_context, K, V = self.get_projections(embeddings, context_vectors)
while not state.partial_finished():
                step_context = self.get_step_context(state, embeddings)  # (batch_size, 1, input_dim + 1)
Q_step_context = self.wq_step_context(step_context) # (batch_size, 1, output_dim)
Q = Q_context + Q_step_context
# split heads for Q
Q = self.split_heads(Q, self.batch_size) # (batch_size, num_heads, 1, head_depth)
# get current mask
mask = state.get_mask() # (batch_size, 1, n_nodes) with True/False indicating where agent can go
# compute MHA decoder vectors for current mask
mha = self.decoder_mha(Q, K, V, mask) # (batch_size, 1, output_dim)
# compute probabilities
log_p = self.get_log_p(mha, K_tanh, mask) # (batch_size, 1, n_nodes)
                # select the next node
selected = self._select_node(log_p)
state.step(selected)
outputs.append(log_p[:, 0, :])
sequences.append(selected)
i += 1
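
        # Each pass of the outer loop above re-encodes the graph with an updated
        # attention mask from the environment, so node embeddings reflect which
        # customers have already been served; this re-embedding is the "dynamic"
        # part of the model.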
_log_p, pi = tf.stack(outputs, 1), tf.cast(tf.stack(sequences, 1), tf.float32)
cost = self.problem.get_costs(inputs, pi)
ll = self.get_log_likelihood(_log_p, pi)
if return_pi:
return cost, ll, pi
return cost, ll
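
# Minimal usage sketch (kept commented out): the exact input format is defined
# by the repo's GraphAttentionEncoder and AgentVRP; the (depot, locations,
# demands) tuple and the demand range below are assumptions for illustration,
# not taken from this file.
#
# if __name__ == "__main__":
#     batch_size, n_nodes = 4, 20
#     depot = tf.random.uniform((batch_size, 2))
#     locations = tf.random.uniform((batch_size, n_nodes, 2))
#     demands = tf.random.uniform((batch_size, n_nodes), 0.05, 0.3)
#     model = AttentionDynamicModel(embedding_dim=128)
#     set_decode_type(model, "sampling")
#     cost, ll, pi = model((depot, locations, demands), return_pi=True)
#     print(cost.shape, ll.shape, pi.shape)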