from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

from einops import repeat
from mmcv import Config
import pytorch_lightning as pl
import torch

from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.biased_cvae_model import (
    cvae_factory,
)
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
from risk_biased.utils.risk import get_risk_estimator
from risk_biased.utils.risk import get_risk_level_sampler


@dataclass
class LitTrajectoryPredictorParams:
    """
    cvae_params: CVAEParams class defining the necessary parameters for the CVAE model
    risk_distribution: dict of strings and values defining the risk distribution to use
    risk_estimator: dict of strings and values defining the risk estimator to use
    kl_weight: float defining the weight of the KL term in the loss function
    kl_threshold: float defining the threshold to apply when computing the KL divergence (avoids posterior collapse)
    risk_weight: float defining the weight of the risk term in the loss function
    n_mc_samples_risk: int defining the number of Monte Carlo samples to use when estimating the risk
    n_mc_samples_biased: int defining the number of Monte Carlo samples to use when estimating the expected biased cost
    dt: float defining the duration between two consecutive time steps
    learning_rate: float defining the learning rate for the optimizer
    use_risk_constraint: bool defining whether to use the risk-constrained optimization procedure
    risk_constraint_update_every_n_epoch: int defining the number of epochs between two risk weight updates
    risk_constraint_weight_update_factor: float defining the factor by which the risk weight ratio is multiplied at each update
    risk_constraint_weight_maximum: float defining the maximum value of the risk weight ratio
    num_samples_min_fde: int defining the number of samples to use when estimating the minimum FDE
    condition_on_ego_future: bool defining whether to condition the biasing on the ego future trajectory (else on the ego past)
    """

    cvae_params: CVAEParams
    risk_distribution: dict
    risk_estimator: dict
    kl_weight: float
    kl_threshold: float
    risk_weight: float
    n_mc_samples_risk: int
    n_mc_samples_biased: int
    dt: float
    learning_rate: float
    use_risk_constraint: bool
    risk_constraint_update_every_n_epoch: int
    risk_constraint_weight_update_factor: float
    risk_constraint_weight_maximum: float
    num_samples_min_fde: int
    condition_on_ego_future: bool

    @staticmethod
    def from_config(cfg: Config):
        cvae_params = CVAEParams.from_config(cfg)
        return LitTrajectoryPredictorParams(
            risk_distribution=cfg.risk_distribution,
            risk_estimator=cfg.risk_estimator,
            kl_weight=cfg.kl_weight,
            kl_threshold=cfg.kl_threshold,
            risk_weight=cfg.risk_weight,
            n_mc_samples_risk=cfg.n_mc_samples_risk,
            n_mc_samples_biased=cfg.n_mc_samples_biased,
            dt=cfg.dt,
            learning_rate=cfg.learning_rate,
            cvae_params=cvae_params,
            use_risk_constraint=cfg.use_risk_constraint,
            risk_constraint_update_every_n_epoch=cfg.risk_constraint_update_every_n_epoch,
            risk_constraint_weight_update_factor=cfg.risk_constraint_weight_update_factor,
            risk_constraint_weight_maximum=cfg.risk_constraint_weight_maximum,
            num_samples_min_fde=cfg.num_samples_min_fde,
            condition_on_ego_future=cfg.condition_on_ego_future,
        )
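
# Example usage (a minimal sketch; the config file name is hypothetical, but
# `Config.fromfile` is the standard mmcv loader):
#
#     cfg = Config.fromfile("learning_config.py")
#     params = LitTrajectoryPredictorParams.from_config(cfg)
#
# Every field read in from_config must be present in the loaded config.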


class LitTrajectoryPredictor(pl.LightningModule):
    """PyTorch Lightning module for trajectory prediction with the biased CVAE model.

    Args:
        params: dataclass object containing the necessary parameters
        cost_params: dataclass object defining the TTC cost function
        unnormalizer: function that takes in a trajectory and an offset and that outputs the
            unnormalized trajectory
    """

    def __init__(
        self,
        params: LitTrajectoryPredictorParams,
        cost_params: TTCCostParams,
        unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    ) -> None:
        super().__init__()
        model = cvae_factory(
            params.cvae_params,
            cost_function=TTCCostTorch(cost_params),
            risk_estimator=get_risk_estimator(params.risk_estimator),
            training_mode="cvae",
        )
        self.model = model
        self.params = params
        self._unnormalize_trajectory = unnormalizer
        self.set_training_mode("cvae")
        self.learning_rate = params.learning_rate
        self.num_samples_min_fde = params.num_samples_min_fde
        self.dynamic_state_dim = params.cvae_params.dynamic_state_dim
        self.dt = params.cvae_params.dt
        self.use_risk_constraint = params.use_risk_constraint
        self.risk_weight = params.risk_weight
        self.risk_weight_ratio = params.risk_weight / params.kl_weight
        self.kl_weight = params.kl_weight
        if self.use_risk_constraint:
            self.risk_constraint_update_every_n_epoch = (
                params.risk_constraint_update_every_n_epoch
            )
            self.risk_constraint_weight_update_factor = (
                params.risk_constraint_weight_update_factor
            )
            self.risk_constraint_weight_maximum = params.risk_constraint_weight_maximum
        self._risk_sampler = get_risk_level_sampler(params.risk_distribution)
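
    # Construction sketch (hypothetical: `TTCCostParams.from_config` and the
    # datamodule's unnormalizer are assumptions about the surrounding code):
    #
    #     predictor = LitTrajectoryPredictor(
    #         params=LitTrajectoryPredictorParams.from_config(cfg),
    #         cost_params=TTCCostParams.from_config(cfg),
    #         unnormalizer=datamodule.unnormalize_trajectory,
    #     )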

    def set_training_mode(self, training_mode: str):
        self.model.set_training_mode(training_mode)
        self.partial_get_loss = partial(
            self.model.get_loss,
            kl_threshold=self.params.kl_threshold,
            n_samples_risk=self.params.n_mc_samples_risk,
            n_samples_biased=self.params.n_mc_samples_biased,
            dt=self.params.dt,
            unnormalizer=self._unnormalize_trajectory,
        )
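
    # Note: the model starts in "cvae" training mode (set in __init__) and can
    # be switched later, e.g. self.set_training_mode("bias"), which is the mode
    # checked in training_epoch_end below. The partial loss closure is rebuilt
    # with the fixed hyperparameters each time the mode changes.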

    def _get_loss(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        y: torch.Tensor,
        mask_y: torch.Tensor,
        mask_loss: torch.Tensor,
        x_ego: torch.Tensor,
        y_ego: torch.Tensor,
        offset: Optional[torch.Tensor] = None,
        risk_level: Optional[torch.Tensor] = None,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]], dict]:
        """Compute loss based on trajectory history x and future y.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor, True where map features are valid, False where they are padding
            y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory
            mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask
            mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool mask set to True where the loss
                should be computed and to False where it shouldn't
            x_ego: (batch_size, 1, num_steps, state_dim) ego past trajectory
            y_ego: (batch_size, 1, num_steps_future, state_dim) ego future trajectory
            offset: (batch_size, num_agents, state_dim) offset position from ego
            risk_level: (batch_size, num_agents) tensor of risk levels desired for future trajectories

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, ...]]: (1,) loss tensor or tuple of
                loss tensors
            dict: dict that contains values to be logged
        """
        return self.partial_get_loss(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_y=mask_y,
            mask_loss=mask_loss,
            offset=offset,
            risk_level=risk_level,
            x_ego=x_ego,
            y_ego=y_ego,
            risk_weight=self.risk_weight,
            kl_weight=self.kl_weight,
        )

    def log_with_prefix(
        self,
        log_dict: dict,
        prefix: Optional[str] = None,
        on_step: Optional[bool] = None,
        on_epoch: Optional[bool] = None,
    ) -> None:
        """Log entries in log_dict while optionally adding "<prefix>/" to its keys.

        Args:
            log_dict: dict that contains values to be logged
            prefix: prefix to be added to keys
            on_step: if True logs at this step. None auto-logs at the training_step but not
                validation/test_step
            on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test
                step but not training_step
        """
        if prefix is None:
            prefix = ""
        else:
            prefix += "/"
        for metric, value in log_dict.items():
            metric = prefix + metric
            self.log(metric, value, on_step=on_step, on_epoch=on_epoch)
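
    # For example, log_with_prefix({"loss": loss}, prefix="train") logs the
    # value under the key "train/loss"; with prefix=None the key stays "loss".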

    def configure_optimizers(
        self,
    ) -> Union[torch.optim.Optimizer, List[torch.optim.Optimizer]]:
        """Configure optimizer(s) for PyTorch Lightning.

        Returns:
            List[torch.optim.Optimizer]: one Adam optimizer per parameter group
                returned by the model (a single-element list when the model
                returns one parameter group)
        """
        if isinstance(self.model.get_parameters(), list):
            self._optimizers = [
                torch.optim.Adam(params, lr=self.learning_rate)
                for params in self.model.get_parameters()
            ]
        else:
            self._optimizers = [
                torch.optim.Adam(self.model.get_parameters(), lr=self.learning_rate)
            ]
        return self._optimizers

    def training_step(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int,
    ) -> dict:
        """Training step definition for PyTorch Lightning.

        Args:
            batch: [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
                (batch_size, num_agents, num_steps), # mask past, False where past trajectories are padding data
                (batch_size, num_agents, num_steps_future, state_dim), # future trajectory
                (batch_size, num_agents, num_steps_future), # mask future, False where future trajectories are padding data
                (batch_size, num_agents, num_steps_future), # mask loss, False where future trajectories are not to be predicted
                (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
                (batch_size, num_objects, object_seq_len), # mask map, False where map objects are padding data
                (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
                (batch_size, 1, num_steps, state_dim), # ego past trajectory
                (batch_size, 1, num_steps_future, state_dim)] # ego future trajectory
            batch_idx: batch index to be used by PyTorch Lightning

        Returns:
            dict: dict of outputs containing loss
        """
        x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
        risk_level = repeat(
            self._risk_sampler.sample(x.shape[0], x.device),
            "b -> b num_agents",
            num_agents=x.shape[1],
        )
        loss, log_dict = self._get_loss(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_y=mask_y,
            mask_loss=mask_loss,
            offset=offset,
            risk_level=risk_level,
            x_ego=x_ego,
            y_ego=y_ego,
        )
        if isinstance(loss, tuple):
            loss = sum(loss)
        self.log_with_prefix(log_dict, prefix="train", on_step=True, on_epoch=False)
        return {"loss": loss}

    def training_epoch_end(self, outputs: List[dict]) -> None:
        """Called at the end of the training epoch with the outputs of all training steps.

        Args:
            outputs: list of outputs of all training steps in the current epoch
        """
        if self.use_risk_constraint:
            if (
                self.model.training_mode == "bias"
                and (self.trainer.current_epoch + 1)
                % self.risk_constraint_update_every_n_epoch
                == 0
            ):
                self.risk_weight_ratio *= self.risk_constraint_weight_update_factor
                if self.risk_weight_ratio < self.risk_constraint_weight_maximum:
                    sum_weight = self.risk_weight + self.kl_weight
                    self.risk_weight = (
                        sum_weight
                        * self.risk_weight_ratio
                        / (1 + self.risk_weight_ratio)
                    )
                    self.kl_weight = sum_weight / (1 + self.risk_weight_ratio)
                # self.risk_weight *= self.risk_constraint_weight_update_factor
                # if self.risk_weight > self.risk_constraint_weight_maximum:
                #     self.risk_weight = self.risk_constraint_weight_maximum
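
    # The constrained update keeps risk_weight + kl_weight constant while
    # growing their ratio. Worked example: with risk_weight = kl_weight = 0.5
    # (ratio 1) and an update factor of 3, the new ratio is 3, so
    #     risk_weight = 1.0 * 3 / (1 + 3) = 0.75
    #     kl_weight   = 1.0 / (1 + 3)     = 0.25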

    def _get_risk_tensor(
        self,
        batch_size: int,
        num_agents: int,
        device: torch.device,
        risk_level: Optional[torch.Tensor] = None,
    ):
        """Reformat the different accepted risk_level inputs into a tensor of shape
        (batch_size, num_agents).

        If given a tensor, the same tensor is returned on the requested device.
        If given a float value, a constant tensor filled with this value is returned.
        If given None, None is returned unchanged.

        Args:
            batch_size: desired batch size
            num_agents: desired number of agents
            device: device on which to store the risk tensor
            risk_level: the risk level as a tensor, a float value, or None

        Returns:
            Optional[torch.Tensor]: (batch_size, num_agents) tensor of risk levels, or None
        """
        if risk_level is not None:
            if isinstance(risk_level, float):
                risk_level = (
                    torch.ones(batch_size, num_agents, device=device) * risk_level
                )
            else:
                risk_level = risk_level.to(device)
        return risk_level
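
    # The three accepted input forms, for example:
    #
    #     self._get_risk_tensor(8, 4, device, risk_level=0.9)     # constant (8, 4) tensor
    #     self._get_risk_tensor(8, 4, device, risk_level=levels)  # `levels` moved to `device`
    #     self._get_risk_tensor(8, 4, device, risk_level=None)    # passed through as None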

    def validation_step(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int,
        risk_level: float = 1.0,
    ) -> dict:
        """Validation step definition for PyTorch Lightning.

        Args:
            batch: [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
                (batch_size, num_agents, num_steps), # mask past, False where past trajectories are padding data
                (batch_size, num_agents, num_steps_future, state_dim), # future trajectory
                (batch_size, num_agents, num_steps_future), # mask future, False where future trajectories are padding data
                (batch_size, num_agents, num_steps_future), # mask loss, False where future trajectories are not to be predicted
                (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
                (batch_size, num_objects, object_seq_len), # mask map, False where map objects are padding data
                (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
                (batch_size, 1, num_steps, state_dim), # ego past trajectory
                (batch_size, 1, num_steps_future, state_dim)] # ego future trajectory
            batch_idx: batch index to be used by PyTorch Lightning
            risk_level: optional desired risk level

        Returns:
            dict: dict of outputs containing loss
        """
        x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
        risk_level = self._get_risk_tensor(
            x.shape[0], x.shape[1], x.device, risk_level=risk_level
        )
        self.model.eval()
        log_dict_accuracy = self.model.get_prediction_accuracy(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_loss=mask_loss,
            offset=offset,
            x_ego=x_ego,
            y_ego=y_ego,
            unnormalizer=self._unnormalize_trajectory,
            risk_level=risk_level,
            num_samples_min_fde=self.num_samples_min_fde,
        )
        loss, log_dict_loss = self._get_loss(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_y=mask_y,
            mask_loss=mask_loss,
            offset=offset,
            risk_level=risk_level,
            x_ego=x_ego,
            y_ego=y_ego,
        )
        if isinstance(loss, tuple):
            loss = sum(loss)
        self.log_with_prefix(
            dict(log_dict_accuracy, **log_dict_loss),
            prefix="val",
            on_step=False,
            on_epoch=True,
        )
        self.model.train()
        return {"loss": loss}

    def test_step(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int,
        risk_level: Optional[torch.Tensor] = None,
    ) -> dict:
        """Test step definition for PyTorch Lightning.

        Args:
            batch: [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
                (batch_size, num_agents, num_steps), # mask past, False where past trajectories are padding data
                (batch_size, num_agents, num_steps_future, state_dim), # future trajectory
                (batch_size, num_agents, num_steps_future), # mask future, False where future trajectories are padding data
                (batch_size, num_agents, num_steps_future), # mask loss, False where future trajectories are not to be predicted
                (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
                (batch_size, num_objects, object_seq_len), # mask map, False where map objects are padding data
                (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
                (batch_size, 1, num_steps, state_dim), # ego past trajectory
                (batch_size, 1, num_steps_future, state_dim)] # ego future trajectory
            batch_idx: batch index to be used by PyTorch Lightning
            risk_level: optional desired risk level

        Returns:
            dict: dict of outputs containing loss
        """
        x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
        risk_level = self._get_risk_tensor(
            x.shape[0], x.shape[1], x.device, risk_level=risk_level
        )
        loss, log_dict = self._get_loss(
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            y=y,
            mask_y=mask_y,
            mask_loss=mask_loss,
            offset=offset,
            risk_level=risk_level,
            x_ego=x_ego,
            y_ego=y_ego,
        )
        if isinstance(loss, tuple):
            loss = sum(loss)
        self.log_with_prefix(log_dict, prefix="test", on_step=False, on_epoch=True)
        return {"loss": loss}

    def predict_step(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int = 0,
        risk_level: Optional[torch.Tensor] = None,
        n_samples: int = 0,
        return_weights: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Predict step definition for PyTorch Lightning.

        Args:
            batch: [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
                (batch_size, num_agents, num_steps), # mask past, False where past trajectories are padding data
                (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
                (batch_size, num_objects, object_seq_len), # mask map, False where map objects are padding data
                (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
                (batch_size, 1, num_steps, state_dim), # past trajectory of the ego agent in the scene
                (batch_size, 1, num_steps_future, state_dim)] # future trajectory of the ego agent in the scene
            batch_idx: batch index to be used by PyTorch Lightning (unused here)
            risk_level: optional desired risk level
            n_samples: number of samples to predict per agent.
                With a value of 0, the `n_samples` dim is not included in the output.
            return_weights: if True, also returns the sample weights

        Returns:
            (batch_size, (n_samples), num_steps_future, state_dim) tensor of predictions,
                and the sample weights if return_weights is True
        """
        x, mask_x, map, mask_map, offset, x_ego, y_ego = batch
        risk_level = self._get_risk_tensor(
            batch_size=x.shape[0],
            num_agents=x.shape[1],
            device=x.device,
            risk_level=risk_level,
        )
        y_sampled, weights, _ = self.model(
            x,
            mask_x,
            map,
            mask_map,
            offset=offset,
            x_ego=x_ego,
            y_ego=y_ego,
            risk_level=risk_level,
            n_samples=n_samples,
        )
        predict_sampled = self._unnormalize_trajectory(y_sampled, offset)
        if return_weights:
            return predict_sampled, weights
        else:
            return predict_sampled
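
    # Usage sketch (the batch layout is documented above; values are hypothetical):
    #
    #     pred, weights = predictor.predict_step(
    #         batch, risk_level=0.95, n_samples=16, return_weights=True
    #     )
    #
    # risk_level may also be given as a tensor or left as None
    # (see _get_risk_tensor above).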

    def predict_loop_once(
        self,
        batch: Tuple[torch.Tensor, ...],
        batch_idx: int = 0,
        risk_level: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Predict with refinement:
        A first prediction is made as in predict_step, but instead of being unnormalized
        and returned, it is fed to the encoder that was trained to encode past and
        ground-truth future. The decoder is then used again, with its latent input sample
        biased by the encoder instead of being a sample of the prior distribution.
        Finally, as in predict_step, the result is unnormalized and returned.

        Args:
            batch: [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
                (batch_size, num_agents, num_steps), # mask past, False where past trajectories are padding data
                (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
                (batch_size, num_objects, object_seq_len), # mask map, False where map objects are padding data
                (batch_size, num_agents, state_dim)] # position offset of all agents relative to ego at present time
            batch_idx: batch index to be used by PyTorch Lightning (unused here). Defaults to 0.
            risk_level: optional desired risk level

        Returns:
            torch.Tensor: (batch_size, num_steps_future, state_dim) tensor
        """
        x, mask_x, map, mask_map, offset = batch
        risk_level = self._get_risk_tensor(
            x.shape[0], x.shape[1], x.device, risk_level=risk_level
        )
        y_sampled, _ = self.model(
            x,
            mask_x,
            map,
            mask_map,
            offset=offset,
            risk_level=risk_level,
        )
        mask_y = repeat(mask_x.any(-1), "b a -> b a f", f=y_sampled.shape[-2])
        y_sampled, _ = self.model(
            x,
            mask_x,
            map,
            mask_map,
            y_sampled,
            mask_y,
            offset=offset,
            risk_level=risk_level,
        )
        predict_sampled = self._unnormalize_trajectory(y_sampled, offset=offset)
        return predict_sampled
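
    # predict_loop_once refines a prior sample: the first model call draws
    # y_sampled from the prior, the second feeds that sample back as a pseudo
    # ground-truth future so the encoder can bias the latent before decoding
    # again. Minimal call sketch (tensors as documented above):
    #
    #     refined = predictor.predict_loop_once((x, mask_x, map, mask_map, offset))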