import tensorflow as tf


class AgentVRP():

    # Vehicle capacity against which (normalized) node demands are compared
    VEHICLE_CAPACITY = 1.0

    def __init__(self, input):

        depot = input[0]  # (batch_size, 2)
        loc = input[1]    # (batch_size, n_loc, 2)

        self.batch_size, self.n_loc, _ = loc.shape

        # Coordinates of the depot followed by the customer nodes: (batch_size, n_loc + 1, 2)
        self.coords = tf.concat((depot[:, None, :], loc), -2)
        self.demand = tf.cast(input[2], tf.float32)  # (batch_size, n_loc)

        # Graph indices within the batch, used for gather/scatter updates
        self.ids = tf.range(self.batch_size, dtype=tf.int64)[:, None]

        # State: previously selected node (starts at the depot), depot flag
        # and capacity used on the current route
        self.prev_a = tf.zeros((self.batch_size, 1), dtype=tf.float32)
        self.from_depot = self.prev_a == 0
        self.used_capacity = tf.zeros((self.batch_size, 1), dtype=tf.float32)

        # Visited nodes are marked with 1; index 0 is the depot: (batch_size, 1, n_loc + 1)
        self.visited = tf.zeros((self.batch_size, 1, self.n_loc + 1), dtype=tf.uint8)

        # Step counter
        self.i = tf.zeros(1, dtype=tf.int64)

        # Constant tensors for the scatter update in the step method
        self.step_updates = tf.ones((self.batch_size, 1), dtype=tf.uint8)
        self.scatter_zeros = tf.zeros((self.batch_size, 1), dtype=tf.int64)

    @staticmethod
    def outer_pr(a, b):
        """Batched outer product of two matrices.

        For a of shape (k, i) and b of shape (k, j), returns a tensor of shape
        (k, i, j) whose (k, i, j) entry is a[k, i] * b[k, j].
        """
        return tf.einsum('ki,kj->kij', a, b)

    def get_att_mask(self):
        """Mask of shape (batch_size, n_nodes, n_nodes) for the attention encoder.

        Already visited nodes are masked, except for the depot.
        """

        # Visited flags for the customer nodes only (the depot is never masked): (batch_size, n_loc)
        att_mask = tf.squeeze(tf.cast(self.visited, tf.float32), axis=-2)[:, 1:]

        # Number of nodes remaining in each graph after masking: (batch_size, 1)
        cur_num_nodes = self.n_loc + 1 - tf.reshape(tf.reduce_sum(att_mask, -1), (-1, 1))

        # Prepend a zero column so the depot stays unmasked
        att_mask = tf.concat((tf.zeros(shape=(att_mask.shape[0], 1), dtype=tf.float32), att_mask), axis=-1)

        ones_mask = tf.ones_like(att_mask)

        # Build the square mask from the row mask: entry (i, j) equals
        # a_i + a_j - a_i * a_j, i.e. node i OR node j has already been visited
        att_mask = AgentVRP.outer_pr(att_mask, ones_mask) \
                   + AgentVRP.outer_pr(ones_mask, att_mask) \
                   - AgentVRP.outer_pr(att_mask, att_mask)

        return tf.cast(att_mask, dtype=tf.bool), cur_num_nodes
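
    # Worked example for the mask construction above (illustrative values only,
    # not part of the original module): for a single graph with two customers
    # where only the first customer has been visited, the row mask after adding
    # the depot column is a = [0., 1., 0.], and the three outer products combine
    # into the element-wise OR of the row and column masks:
    #
    #     [[0, 1, 0],
    #      [1, 1, 1],
    #      [0, 1, 0]]
    #
    # so every attention score involving the visited node (index 1) is masked.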

    def all_finished(self):
        """Checks whether every node (including the depot) has been visited in all graphs.
        """
        return tf.reduce_all(tf.cast(self.visited, tf.bool))

    def partial_finished(self):
        """Checks whether a partial solution has been built for all graphs,
        i.e. every agent has returned to the depot.
        """
        return tf.reduce_all(self.from_depot) and self.i != 0

    def get_mask(self):
        """Returns a mask of shape (batch_size, 1, n_nodes) with available actions.

        Infeasible nodes are masked (marked True).
        """

        # Visited flags for the customer nodes (the depot is handled separately)
        visited_loc = self.visited[:, :, 1:]

        # Customers whose demand would exceed the remaining vehicle capacity
        exceeds_cap = self.demand + self.used_capacity > self.VEHICLE_CAPACITY

        # Mask customers that are already visited or do not fit into the remaining
        # capacity; after a return to the depot (i > 0) all customers are masked,
        # which ends the current partial route
        mask_loc = tf.cast(visited_loc, tf.bool) | exceeds_cap[:, None, :] | ((self.i > 0) & self.from_depot[:, None, :])

        # The depot is masked only while the vehicle is at the depot and there are
        # still customers it could serve
        mask_depot = self.from_depot & (tf.reduce_sum(tf.cast(mask_loc == False, tf.int32), axis=-1) > 0)

        return tf.concat([mask_depot[:, :, None], mask_loc], axis=-1)
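
    # Illustrative example (values assumed, not taken from the original module):
    # with n_loc = 3, demands [0.2, 0.6, 0.8], used_capacity = 0.3, customer 1
    # already visited and the vehicle away from the depot, get_mask() returns
    # [[False, True, False, True]] for that graph: customer 1 is masked because
    # it was visited, customer 3 because its demand would exceed the remaining
    # capacity, while customer 2 and the depot remain available.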

    def step(self, action):

        # Update the current position of every agent
        selected = action[:, None]

        self.prev_a = selected
        self.from_depot = self.prev_a == 0

        # Shift indices by 1 since demand does not include the depot
        # (index 0 of demand corresponds to the first customer node)
        selected_demand = tf.gather_nd(self.demand,
                                       tf.concat([self.ids, tf.clip_by_value(self.prev_a - 1, 0, self.n_loc - 1)], axis=1)
                                       )[:, None]

        # Add the selected demand to the used capacity; agents that returned to
        # the depot have their used capacity reset to zero
        self.used_capacity = (self.used_capacity + selected_demand) * (1.0 - tf.cast(self.from_depot, tf.float32))

        # Mark the selected nodes as visited: one (graph_id, 0, node_id) index per graph
        idx = tf.cast(tf.concat((self.ids, self.scatter_zeros, self.prev_a), axis=-1), tf.int32)[:, None, :]
        self.visited = tf.tensor_scatter_nd_update(self.visited, idx, self.step_updates)

        self.i = self.i + 1
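
    # Illustrative trace (toy numbers, not from the original module): for a graph
    # with n_loc = 3, demands [0.1, 0.2, 0.3] and used_capacity = 0.0, calling
    # step(action) with action == 2 gathers demand[1] == 0.2, raises used_capacity
    # to 0.2 and scatters a 1 into visited[:, 0, 2]; a later step with action == 0
    # (the depot) resets used_capacity to 0.0.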

    @staticmethod
    def get_costs(dataset, pi):
        """Total route length for each graph, given the sequence of visited nodes pi."""

        # Depot + customer coordinates: (batch_size, n_loc + 1, 2)
        loc_with_depot = tf.concat([dataset[0][:, None, :], dataset[1]], axis=1)
        # Coordinates in the order they are visited: (batch_size, len(pi), 2)
        d = tf.gather(loc_with_depot, tf.cast(pi, tf.int32), batch_dims=1)

        # Sum of consecutive leg lengths, plus the legs from the depot to the
        # first selected node and from the last selected node back to the depot
        return (tf.reduce_sum(tf.norm(d[:, 1:] - d[:, :-1], ord=2, axis=2), axis=1)
                + tf.norm(d[:, 0] - dataset[0], ord=2, axis=1)
                + tf.norm(d[:, -1] - dataset[0], ord=2, axis=1))
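

# Minimal usage sketch (illustrative only, not part of the original module). It builds
# a random batch of instances in the (depot, loc, demand) format the class consumes and
# rolls out a naive "first feasible node" policy in place of the trained decoder. The
# reset of state.i after every completed route is an assumption about how the surrounding
# decoding loop is meant to drive the (i > 0) masking in get_mask.
if __name__ == '__main__':
    batch_size, n_loc = 4, 10

    depot = tf.random.uniform((batch_size, 2))        # depot coordinates
    loc = tf.random.uniform((batch_size, n_loc, 2))   # customer coordinates
    demand = tf.random.uniform((batch_size, n_loc), minval=0.05, maxval=0.3)

    state = AgentVRP((depot, loc, demand))
    pi = []  # one column of selected nodes per decoding step

    while not state.all_finished():
        mask = state.get_mask()  # (batch_size, 1, n_loc + 1); True marks infeasible nodes
        # Stand-in policy: take the first feasible node of every graph
        action = tf.argmax(tf.cast(~mask[:, 0], tf.int32), axis=-1)
        state.step(action)
        pi.append(action[:, None])
        # Once every agent is back at the depot, reset the step counter so the
        # remaining customers become available again for the next route
        if state.partial_finished():
            state.i = tf.zeros(1, dtype=tf.int64)

    pi = tf.concat(pi, axis=-1)
    print('tour costs:', AgentVRP.get_costs((depot, loc, demand), pi).numpy())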