"""
Author: Minh Pham-Dinh
Created: Jan 27th, 2024
Last Modified: Feb 10th, 2024
Email: [email protected]
Description:
    Main Dreamer agent file.
    The implementation is based on:
    Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
    [Online]. Available: https://arxiv.org/abs/1912.01603
"""
import os |
import numpy as np |
import yaml |
from tqdm import tqdm |
import wandb |
import torch |
import torch.nn as nn |
import torch.optim as optim |
import utils.models as models |
from utils.buffer import ReplayBuffer |
from utils.utils import td_lambda, log_metrics |
class Dreamer:
    """Dreamer agent: world model, actor-critic and replay buffer.

    Follows Hafner et al., "Dream to Control: Learning Behaviors by Latent
    Imagination" (2019). The world model is made of:
    - the RSSM,
    - the convolutional encoder and decoder,
    - the reward head,
    - an optional continue head.
    It is trained on replayed sequences. The actor and critic are trained on
    latent rollouts imagined with the RSSM prior.
    """

    def __init__(self, config, logpath, env, writer=None, wandb_writer=None):
        """Build all networks, optimizers and the replay buffer.

        Args:
            config: nested attribute-style config. This class reads
                ``device``, ``seed``, ``env.*`` and ``main.*``.
            logpath (str): prefix for checkpoint paths. ``'models/'`` is
                appended by plain string concatenation, so the prefix
                presumably ends with a path separator.
            env: environment with ``observation_space``, ``action_space``, and
                5-tuple ``step`` / 2-tuple ``reset`` results (as used below).
            writer: optional writer forwarded to ``log_metrics``.
            wandb_writer: optional wandb run, used for metrics and videos.
        """
        self.config = config
        self.device = torch.device(self.config.device)
        self.env = env
        self.obs_size = env.observation_space.shape
        if self.config.env.discrete:
            self.action_size = env.action_space.n
        else:
            self.action_size = env.action_space.shape[0]
        self.epsilon = self.config.main.epsilon_start
        self.env_step = 0
        self.logpath = logpath

        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

        main = self.config.main
        latent_size = main.stochastic_size + main.deterministic_size

        # Keep this construction order: modules are built right after seeding,
        # so reordering would presumably change the seeded initial weights.
        self.rssm = models.RSSM(main.stochastic_size,
                                main.embedded_obs_size,
                                main.deterministic_size,
                                main.hidden_units,
                                self.action_size).to(self.device)
        self.reward = models.RewardNet(latent_size, main.hidden_units).to(self.device)
        if main.continue_loss:
            self.cont_net = models.ContinuoNet(latent_size, main.hidden_units).to(self.device)
        self.encoder = models.ConvEncoder(input_shape=self.obs_size).to(self.device)
        self.decoder = models.ConvDecoder(main.stochastic_size,
                                          main.deterministic_size,
                                          out_shape=self.obs_size).to(self.device)

        # All world-model parameters share a single optimizer.
        self.dyna_parameters = (
            list(self.rssm.parameters())
            + list(self.reward.parameters())
            + list(self.encoder.parameters())
            + list(self.decoder.parameters())
        )
        if main.continue_loss:
            self.dyna_parameters += list(self.cont_net.parameters())

        self.actor = models.Actor(latent_size,
                                  main.hidden_units,
                                  self.action_size,
                                  self.config.env.discrete).to(self.device)
        self.critic = models.Critic(latent_size, main.hidden_units).to(self.device)

        self.dyna_optimizer = optim.Adam(self.dyna_parameters, lr=main.dyna_model_lr)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=main.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=main.critic_lr)
        self.gradient_step = 0

        self.buffer = ReplayBuffer(main.buffer_capacity, self.obs_size, (self.action_size,))
        self.wandb_writer = wandb_writer
        self.writer = writer

    def update_epsilon(self):
        """Linearly decay the exploration epsilon used in discrete envs.

        Epsilon moves from ``main.epsilon_start`` to ``main.epsilon_end`` over
        ``main.eps_decay_steps`` gradient steps, then stays at the end value.
        A non-positive ``eps_decay_steps`` jumps straight to ``epsilon_end``
        instead of dividing by zero.
        """
        eps_start = self.config.main.epsilon_start
        eps_end = self.config.main.epsilon_end
        decay_steps = self.config.main.eps_decay_steps
        if decay_steps <= 0:
            self.epsilon = eps_end
            return
        decay_rate = (eps_start - eps_end) / decay_steps
        self.epsilon = max(eps_end, eps_start - decay_rate * self.gradient_step)

    def train(self):
        """Run the main training loop, following the loop in the Dreamer paper.

        1. Seed the replay buffer with ``main.data_init_ep`` episodes of
           uniformly random actions.
        2. Repeat for ``main.total_iter`` iterations:
           - checkpoint every ``main.save_freq`` iterations;
           - evaluate every ``main.eval_freq`` iterations;
           - run ``main.collect_iter`` world-model and behaviour updates;
           - collect ``main.data_interact_ep`` episodes with the current policy.
        """
        self._collect_random_episodes(self.config.main.data_init_ep)
        for iteration in tqdm(range(self.config.main.total_iter)):
            if iteration % self.config.main.save_freq == 0:
                self._save_models()
            if iteration % self.config.main.eval_freq == 0:
                eval_score = self.data_collection(self.config.main.eval_eps, eval=True)
                metrics = {'performance/evaluation score': eval_score}
                log_metrics(metrics, self.env_step, self.writer, self.wandb_writer)
            for _ in tqdm(range(self.config.main.collect_iter)):
                batch = self.buffer.sample(self.config.main.batch_size, self.config.main.seq_len, self.device)
                posteriors, deterministics = self.dynamic_learning(batch)
                self.behavioral_learning(posteriors, deterministics)
                self.gradient_step += 1
                self.update_epsilon()
            self.data_collection(self.config.main.data_interact_ep)

    def _collect_random_episodes(self, num_episodes):
        """Fill the replay buffer with ``num_episodes`` episodes of random actions."""
        ep = 0
        obs, _ = self.env.reset()
        while ep < num_episodes:
            action = self.env.action_space.sample()
            if self.config.env.discrete:
                # Discrete actions are stored one-hot (length action_size).
                actions = np.zeros((self.action_size,))
                actions[action] = 1.0
            else:
                actions = action
            next_obs, reward, termination, truncation, info = self.env.step(action)
            self.buffer.add(obs, actions, reward, termination or truncation)
            obs = next_obs
            # Episode boundaries are detected through the "episode" info key.
            if "episode" in info:
                obs, _ = self.env.reset()
                ep += 1
                print(ep)
                if 'video_path' in info and self.wandb_writer:
                    self.wandb_writer.log({'performance/videos': wandb.Video(info['video_path'], format='webm')})

    def _save_models(self):
        """Save every network with ``torch.save`` under ``logpath + 'models/'``."""
        directory = self.logpath + 'models/'
        os.makedirs(directory, exist_ok=True)
        torch.save(self.rssm, directory + 'rssm_model')
        torch.save(self.encoder, directory + 'encoder')
        torch.save(self.decoder, directory + 'decoder')
        torch.save(self.actor, directory + 'actor')
        torch.save(self.critic, directory + 'critic')
        # behavioral_learning uses the reward (and continue) heads, so they are
        # needed to resume imagination from a checkpoint.
        torch.save(self.reward, directory + 'reward')
        if self.config.main.continue_loss:
            torch.save(self.cont_net, directory + 'cont_net')

    def dynamic_learning(self, batch):
        """Train the world model on one batch of replayed sequences.

        The RSSM is unrolled over the sequence. At every step ``t >= 1``:
        - the deterministic state is advanced from the previous posterior and
          previous action;
        - a prior is predicted from that deterministic state;
        - a posterior is inferred from the encoded observation.

        The loss is
        ``kl_divergence_scale * max(KL, free_nats) - recon_logp - reward_logp
        + continue_bce``, where the continue term is zero unless
        ``main.continue_loss`` is set.

        Args:
            batch (addict.Dict): replayed sequences with fields
                - obs     (batch_size, seq_len, *obs_shape)
                - actions (batch_size, seq_len, action_size)
                - rewards (batch_size, seq_len, 1)
                - dones   (batch_size, seq_len, 1)

        Returns:
            tuple: detached posterior samples of shape
            (batch_size, seq_len-1, stochastic_size) and deterministic states of
            shape (batch_size, seq_len-1, deterministic_size). These are the
            start states for imagination.
        """
        main = self.config.main
        b_obs = batch.obs
        b_a = batch.actions
        b_r = batch.rewards
        b_d = batch.dones
        batch_size, seq_len, _ = b_r.shape

        eb_obs = self.encoder(b_obs)

        posterior = torch.zeros((batch_size, main.stochastic_size), device=self.device)
        deterministic = torch.zeros((batch_size, main.deterministic_size), device=self.device)
        posteriors = torch.zeros((batch_size, seq_len - 1, main.stochastic_size), device=self.device)
        deterministics = torch.zeros((batch_size, seq_len - 1, main.deterministic_size), device=self.device)
        posterior_means = torch.zeros_like(posteriors)
        posterior_stds = torch.zeros_like(posteriors)
        prior_means = torch.zeros_like(posteriors)
        prior_stds = torch.zeros_like(posteriors)

        # Stored outputs are shifted one step: slot t-1 holds the states of step t.
        for t in range(1, seq_len):
            deterministic = self.rssm.recurrent(posterior, b_a[:, t - 1, :], deterministic)
            prior_dist, _ = self.rssm.transition(deterministic)
            posterior_dist, posterior = self.rssm.representation(eb_obs[:, t, :], deterministic)
            posteriors[:, t - 1, :] = posterior
            posterior_means[:, t - 1, :] = posterior_dist.mean
            posterior_stds[:, t - 1, :] = posterior_dist.scale
            prior_means[:, t - 1, :] = prior_dist.mean
            prior_stds[:, t - 1, :] = prior_dist.scale
            deterministics[:, t - 1, :] = deterministic

        # Reconstruction loss: teaches the latent state to encode the pixels.
        # On MPS devices the decoder is given mps_flatten=True, and the target
        # is reshaped to (-1, *obs_shape) to match.
        mps_flatten = self.device.type == "mps"
        reconstruct_dist = self.decoder(posteriors, deterministics, mps_flatten)
        target = b_obs[:, 1:]
        if mps_flatten:
            target = target.reshape(-1, *self.obs_size)
        reconstruct_loss = reconstruct_dist.log_prob(target).mean()

        # Reward loss: unit-variance Gaussian likelihood.
        # NOTE(review): state t is paired with b_r[:, t]; confirm this matches
        # how ReplayBuffer aligns rewards with observations.
        rewards = self.reward(posteriors, deterministics)
        rewards_dist = torch.distributions.Independent(torch.distributions.Normal(rewards, 1), 1)
        rewards_loss = rewards_dist.log_prob(b_r[:, 1:]).mean()

        # Continue loss: predict whether the episode continues from a state.
        # The target is the discount at non-terminal steps and 0 at terminal steps.
        if main.continue_loss:
            cont_logits, _ = self.cont_net(posteriors, deterministics)
            cont_target = (1 - b_d[:, 1:]) * main.discount
            continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(cont_logits, cont_target)
        else:
            continue_loss = torch.zeros(1, device=self.device)

        # KL loss: keeps the transition (prior) model close to the
        # representation (posterior) model, so imagination stays accurate.
        priors_dist = torch.distributions.Independent(
            torch.distributions.Normal(prior_means, prior_stds), 1
        )
        posteriors_dist = torch.distributions.Independent(
            torch.distributions.Normal(posterior_means, posterior_stds), 1
        )
        kl_loss = torch.mean(torch.distributions.kl.kl_divergence(posteriors_dist, priors_dist))
        # Free nats: the KL term gives no gradient once it is below the threshold.
        kl_loss = kl_loss.clamp(min=main.free_nats)

        total_loss = main.kl_divergence_scale * kl_loss - reconstruct_loss - rewards_loss + continue_loss

        self.dyna_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(
            self.dyna_parameters,
            main.clip_grad,
            norm_type=main.grad_norm_type,
        )
        self.dyna_optimizer.step()

        metrics = {
            'Dynamic_model/KL': kl_loss.item(),
            'Dynamic_model/Reconstruction': reconstruct_loss.item(),
            'Dynamic_model/Reward': rewards_loss.item(),
            'Dynamic_model/Continue': continue_loss.item(),
            'Dynamic_model/Total': total_loss.item()
        }
        log_metrics(metrics, self.gradient_step, self.writer, self.wandb_writer)
        return posteriors.detach(), deterministics.detach()

    def behavioral_learning(self, state, deterministics):
        """Update the actor and critic on trajectories imagined by the world model.

        Every (posterior, deterministic) pair is used as a start state. The
        actor is rolled out for ``main.horizon`` steps through the RSSM prior.
        The actor maximises discounted lambda-returns by backpropagating
        through the imagined dynamics. The critic fits the detached returns.

        Args:
            state: (batch_size, seq_len-1, stochastic_size) start posteriors.
            deterministics: (batch_size, seq_len-1, deterministic_size) start
                deterministic states.
        """
        main = self.config.main
        state = state.reshape(-1, main.stochastic_size)
        deterministic = deterministics.reshape(-1, main.deterministic_size)
        num_starts = state.shape[0]

        state_trajectories = torch.zeros((num_starts, main.horizon, main.stochastic_size), device=self.device)
        deterministics_trajectories = torch.zeros((num_starts, main.horizon, main.deterministic_size),
                                                  device=self.device)
        for t in range(main.horizon):
            action = self.actor(state, deterministic)
            deterministic = self.rssm.recurrent(state, action, deterministic)
            _, state = self.rssm.transition(deterministic)
            state_trajectories[:, t, :] = state
            deterministics_trajectories[:, t, :] = deterministic

        # The reward head is a unit-variance Gaussian, so its mode is the raw output.
        rewards = self.reward(state_trajectories, deterministics_trajectories)
        if main.continue_loss:
            _, conts_dist = self.cont_net(state_trajectories, deterministics_trajectories)
            continues = conts_dist.mean
        else:
            continues = main.discount * torch.ones_like(rewards)
        values = self.critic(state_trajectories, deterministics_trajectories).mode
        returns = td_lambda(
            rewards,
            continues,
            values,
            main.lambda_,
            self.device
        )

        # Cumulative discount over horizon-1 steps, with weight 1 at the first
        # step. This matches the horizon-1 steps the critic is fitted on below.
        discount = torch.cumprod(torch.cat((
            torch.ones_like(continues[:, :1]),
            continues[:, :-2]
        ), 1), 1).detach()

        actor_loss = -(discount * returns).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(
            self.actor.parameters(),
            main.clip_grad,
            norm_type=main.grad_norm_type,
        )
        self.actor_optimizer.step()

        values_dist = self.critic(state_trajectories[:, :-1].detach(), deterministics_trajectories[:, :-1].detach())
        # Squeeze only the trailing singleton dim. A bare squeeze() would also
        # drop the batch or horizon dim when either has size 1.
        critic_loss = -(discount.squeeze(-1) * values_dist.log_prob(returns.detach())).mean()
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(
            self.critic.parameters(),
            main.clip_grad,
            norm_type=main.grad_norm_type,
        )
        self.critic_optimizer.step()

        metrics = {
            'Behavorial_model/Actor': actor_loss.item(),
            'Behavorial_model/Critic': critic_loss.item()
        }
        log_metrics(metrics, self.gradient_step, self.writer, self.wandb_writer)

    def _initial_filter_state(self):
        """Return zeroed (posterior, deterministic, previous action) for a new episode."""
        return (
            torch.zeros((1, self.config.main.stochastic_size), device=self.device),
            torch.zeros((1, self.config.main.deterministic_size), device=self.device),
            torch.zeros((1, self.action_size), device=self.device),
        )

    def _select_action(self, actor_out, evaluation=False):
        """Turn the actor output into the action actually executed.

        Discrete envs:
        - training: with probability ``self.epsilon`` a random env action is
          taken; otherwise the argmax of the actor output;
        - evaluation: always the argmax.
        Continuous envs:
        - training: Gaussian noise ``N(mean_noise, std_noise)`` is added to the
          actor output;
        - both modes: the result is clipped to [-1, 1].

        Returns:
            env_action: value passed to ``env.step``.
            buffer_action: the executed action as stored in the replay buffer.
                This is a one-hot vector for discrete envs, matching the random
                seeding episodes.
            model_action: the executed action as a (1, action_size) float
                tensor. It is fed to the RSSM at the next step, the same way
                buffer actions are fed during training.
        """
        if self.config.env.discrete:
            if not evaluation and np.random.rand() < self.epsilon:
                env_action = self.env.action_space.sample()
            else:
                env_action = int(np.argmax(actor_out.cpu().numpy()))
            buffer_action = np.zeros((self.action_size,))
            buffer_action[env_action] = 1.0
        else:
            if evaluation:
                raw_actions = actor_out.cpu().numpy()
            else:
                normal_dist = torch.distributions.Normal(actor_out + self.config.main.mean_noise,
                                                         self.config.main.std_noise)
                raw_actions = normal_dist.sample().cpu().numpy()
            buffer_action = np.clip(raw_actions, a_min=-1, a_max=1)[0]
            env_action = buffer_action
        model_action = torch.tensor(buffer_action, dtype=torch.float, device=self.device).reshape(1, -1)
        return env_action, buffer_action, model_action

    @torch.no_grad()
    def data_collection(self, num_episodes, eval=False):
        """Roll out the current policy for ``num_episodes`` episodes.

        Training mode (``eval=False``):
        - exploration is applied (epsilon-random for discrete envs, Gaussian
          noise for continuous envs);
        - every transition is added to the replay buffer;
        - ``env_step`` is advanced by ``env.action_repeat`` per step;
        - each episode's return is logged as the training score.
        Evaluation mode (``eval=True``):
        - the actor output is used directly (argmax, or clipped to [-1, 1]);
        - nothing is stored and no training score is logged.

        Args:
            num_episodes (int): number of episodes to roll out.
            eval (bool): evaluation mode. Defaults to False.

        Returns:
            float: average episode return over the rollouts. Returns 0.0 when
            ``num_episodes`` is not positive.
        """
        score = 0
        ep = 0
        obs, _ = self.env.reset()
        posterior, deterministic, action = self._initial_filter_state()
        while ep < num_episodes:
            embed_obs = self.encoder(torch.from_numpy(obs).to(self.device, dtype=torch.float))
            deterministic = self.rssm.recurrent(posterior, action, deterministic)
            _, posterior = self.rssm.representation(embed_obs, deterministic)
            actor_out = self.actor(posterior, deterministic)
            env_action, buffer_action, action = self._select_action(actor_out, evaluation=eval)

            next_obs, reward, termination, truncation, info = self.env.step(env_action)
            if not eval:
                self.buffer.add(obs, buffer_action, reward, termination or truncation)
                self.env_step += self.config.env.action_repeat
            obs = next_obs

            if "episode" in info:
                # np.ravel accepts both a scalar and a length-1 array for "r".
                cur_score = np.ravel(info["episode"]["r"])[0]
                score += cur_score
                obs, _ = self.env.reset()
                ep += 1
                if 'video_path' in info and self.wandb_writer:
                    self.wandb_writer.log({'performance/videos': wandb.Video(info['video_path'], format='webm')})
                if not eval:
                    log_metrics({'performance/training score': cur_score}, self.env_step, self.writer, self.wandb_writer)
                posterior, deterministic, action = self._initial_filter_state()
        if num_episodes <= 0:
            return 0.0
        return score / num_episodes