# PyDreamerV1/algos/dreamer.py
"""
Author: Minh Pham-Dinh
Created: Jan 27th, 2024
Last Modified: Feb 10th, 2024
Email: [email protected]
Description:
Main Dreamer agent file: world-model (RSSM) learning, behavior learning by latent imagination, and data collection.
The implementation is based on:
Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
[Online]. Available: https://arxiv.org/abs/1912.01603
"""
# Standard Library Imports
import os
import numpy as np
import yaml
from tqdm import tqdm
import wandb
# Machine Learning and Data Processing Imports
import torch
import torch.nn as nn
import torch.optim as optim
# Custom Utility Imports
import utils.models as models
from utils.buffer import ReplayBuffer
from utils.utils import td_lambda, log_metrics
class Dreamer:
def __init__(self, config, logpath, env, writer = None, wandb_writer=None):
self.config = config
self.device = torch.device(self.config.device)
self.env = env
self.obs_size = env.observation_space.shape
self.action_size = env.action_space.n if self.config.env.discrete else env.action_space.shape[0]
self.epsilon = self.config.main.epsilon_start
self.env_step = 0
self.logpath = logpath
# Set random seed for reproducibility
np.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)
# world model (dynamics) networks
self.rssm = models.RSSM(self.config.main.stochastic_size,
self.config.main.embedded_obs_size,
self.config.main.deterministic_size,
self.config.main.hidden_units,
self.action_size).to(self.device)
self.reward = models.RewardNet(self.config.main.stochastic_size + self.config.main.deterministic_size,
self.config.main.hidden_units).to(self.device)
if self.config.main.continue_loss:
self.cont_net = models.ContinuoNet(self.config.main.stochastic_size + self.config.main.deterministic_size,
self.config.main.hidden_units).to(self.device)
self.encoder = models.ConvEncoder(input_shape=self.obs_size).to(self.device)
self.decoder = models.ConvDecoder(self.config.main.stochastic_size,
self.config.main.deterministic_size,
out_shape=self.obs_size).to(self.device)
self.dyna_parameters = (
list(self.rssm.parameters())
+ list(self.reward.parameters())
+ list(self.encoder.parameters())
+ list(self.decoder.parameters())
)
if self.config.main.continue_loss:
self.dyna_parameters += list(self.cont_net.parameters())
# behavior networks (actor and critic)
self.actor = models.Actor(self.config.main.stochastic_size + self.config.main.deterministic_size,
self.config.main.hidden_units,
self.action_size,
self.config.env.discrete).to(self.device)
self.critic = models.Critic(self.config.main.stochastic_size + self.config.main.deterministic_size,
self.config.main.hidden_units).to(self.device)
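# The decoder and all prediction heads above (reward, continuation, actor, critic)
# condition on the full latent state, i.e. the concatenation of the stochastic and
# deterministic features, hence the input width stochastic_size + deterministic_size.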
#optimizers
self.dyna_optimizer = optim.Adam(self.dyna_parameters, lr=self.config.main.dyna_model_lr)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.config.main.actor_lr)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.config.main.critic_lr)
self.gradient_step = 0
#buffer
self.buffer = ReplayBuffer(self.config.main.buffer_capacity, self.obs_size, (self.action_size, ))
#tracking stuff
self.wandb_writer = wandb_writer
self.writer = writer
def update_epsilon(self):
"""In use for decaying epsilon in discrete env
Returns:
_type_: _description_
"""
eps_start = self.config.main.epsilon_start
eps_end = self.config.main.epsilon_end
decay_steps = self.config.main.eps_decay_steps
decay_rate = (eps_start - eps_end) / (decay_steps)
self.epsilon = max(eps_end, eps_start - decay_rate*self.gradient_step)
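# Worked example with illustrative values (not taken from the repository config):
# epsilon_start=1.0, epsilon_end=0.1, eps_decay_steps=1000 gives a decay rate of
# (1.0 - 0.1) / 1000 = 9e-4 per gradient step, so after 500 steps
# epsilon = max(0.1, 1.0 - 0.45) = 0.55.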
def train(self):
"""main training loop, implementation follow closely with the loop from the official paper
Returns:
_type_: _description_
"""
#prefill dataset
ep = 0
obs, _ = self.env.reset()
while ep < self.config.main.data_init_ep:
action = self.env.action_space.sample()
if self.config.env.discrete:
actions = np.zeros((self.action_size, ))
actions[action] = 1.0
else:
actions = action
next_obs, reward, termination, truncation, info = self.env.step(action)
self.buffer.add(obs, actions, reward, termination or truncation)
obs = next_obs
if "episode" in info:
obs, _ = self.env.reset()
ep += 1
print(f"prefill episode {ep}/{self.config.main.data_init_ep} collected")
if 'video_path' in info and self.wandb_writer:
self.wandb_writer.log({'performance/videos': wandb.Video(info['video_path'], format='webm')})
#main train loop
for it in tqdm(range(self.config.main.total_iter)):
# save model checkpoints at the configured frequency
if it % self.config.main.save_freq == 0:
# make sure the models folder exists
directory = self.logpath + 'models/'
os.makedirs(directory, exist_ok=True)
# save models
torch.save(self.rssm, self.logpath + 'models/rssm_model')
torch.save(self.encoder, self.logpath + 'models/encoder')
torch.save(self.decoder, self.logpath + 'models/decoder')
torch.save(self.actor, self.logpath + 'models/actor')
torch.save(self.critic, self.logpath + 'models/critic')
# run evaluation at the evaluation checkpoint frequency
if it % self.config.main.eval_freq == 0:
eval_score = self.data_collection(self.config.main.eval_eps, eval=True)
metrics = {'performance/evaluation score': eval_score}
log_metrics(metrics, self.env_step, self.writer, self.wandb_writer)
#training step
for c in tqdm(range(self.config.main.collect_iter)):
#draw data
batch = self.buffer.sample(self.config.main.batch_size, self.config.main.seq_len, self.device)
#dynamic learning
post, deter = self.dynamic_learning(batch)
#behavioral learning
self.behavioral_learning(post, deter)
#update step
self.gradient_step += 1
self.update_epsilon()
# collect more data with exploration noise
self.data_collection(self.config.main.data_interact_ep)
def dynamic_learning(self, batch):
"""Learning the dynamic model. In this method, we sequentially pass data in the RSSM to
learn the model
Args:
batch (addict.Dict): batches of data
"""
'''
Unpack the batch. A batch contains:
- b_obs (batch_size, seq_len, *obs_shape): observations
- b_a (batch_size, seq_len, action_size): actions (one-hot for discrete envs)
- b_r (batch_size, seq_len, 1): rewards
- b_d (batch_size, seq_len, 1): termination signals
'''
b_obs = batch.obs
b_a = batch.actions
b_r = batch.rewards
b_d = batch.dones
batch_size, seq_len, _ = b_r.shape
eb_obs = self.encoder(b_obs)
# initialize the stochastic (posterior) and deterministic states for the first pass through the recurrent model
posterior = torch.zeros((batch_size, self.config.main.stochastic_size)).to(self.device)
deterministic = torch.zeros((batch_size, self.config.main.deterministic_size)).to(self.device)
# allocate tensors that store the sequence of latent states (and their distribution parameters)
posteriors = torch.zeros((batch_size, seq_len-1, self.config.main.stochastic_size)).to(self.device)
priors = torch.zeros((batch_size, seq_len-1, self.config.main.stochastic_size)).to(self.device)
deterministics = torch.zeros((batch_size, seq_len-1, self.config.main.deterministic_size)).to(self.device)
posterior_means = torch.zeros_like(posteriors).to(self.device)
posterior_stds = torch.zeros_like(posteriors).to(self.device)
prior_means = torch.zeros_like(priors).to(self.device)
prior_stds = torch.zeros_like(priors).to(self.device)
#start passing data through the dynamic model
for t in (range(1, seq_len)):
deterministic = self.rssm.recurrent(posterior, b_a[:, t-1, :], deterministic)
prior_dist, prior = self.rssm.transition(deterministic)
# note: observations are shifted one timestep ahead (the action at t-1 leads to the state observed at t)
posterior_dist, posterior = self.rssm.representation(eb_obs[:, t, :], deterministic)
'''
Store the recurrent outputs. All stored data are shifted one timestep ahead, starting from the second timestep (t=1).
'''
posteriors[:, t-1, :] = posterior
posterior_means[:, t-1, :] = posterior_dist.mean
posterior_stds[:, t-1, :] = posterior_dist.scale
priors[:, t-1, :] = prior
prior_means[:, t-1, :] = prior_dist.mean
prior_stds[:, t-1, :] = prior_dist.scale
deterministics[:, t-1, :] = deterministic
# optimize the world model on the processed sequence
'''
Reconstruction loss. This loss trains the encoder and decoder to reconstruct pixel observations from the latent state.
'''
mps_flatten = False
if self.device == torch.device("mps"):
mps_flatten = True
reconstruct_dist = self.decoder(posteriors, deterministics, mps_flatten)
target = b_obs[:, 1:]
if mps_flatten:
target = target.reshape(-1, *self.obs_size)
reconstruct_loss = reconstruct_dist.log_prob(target).mean()
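# reconstruct_dist is the decoder's predictive distribution over frames, so maximizing
# its log-probability of the (time-shifted) observations acts as a pixel
# reconstruction objective.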
#reward loss
rewards = self.reward(posteriors, deterministics)
rewards_dist = torch.distributions.Normal(rewards, 1)
rewards_dist = torch.distributions.Independent(rewards_dist, 1)
rewards_loss = rewards_dist.log_prob(b_r[:, 1:]).mean()
'''
Continuation loss. This term trains the model to predict the (discounted) probability that an episode continues from a particular state.
'''
if self.config.main.continue_loss:
# use binary cross-entropy with logits so the soft (discount-scaled) targets are handled directly,
# following Hafner's official Dreamer code
cont_logits, _ = self.cont_net(posteriors, deterministics)
cont_target = (1 - b_d[:, 1:]) * self.config.main.discount
continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(cont_logits, cont_target)
else:
continue_loss = torch.zeros((1)).to(self.device)
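# When the continuation head is enabled, the targets above are soft labels: with an
# illustrative discount of 0.99, a non-terminal step gets target (1 - 0) * 0.99 = 0.99
# and a terminal step gets (1 - 1) * 0.99 = 0.0, so the head learns a discounted
# continuation probability rather than a hard done flag.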
'''
KL loss. Matches the prior from the transition model to the posterior from the representation model,
so that the transition model alone is accurate enough to be used during imagination.
'''
priors_dist = torch.distributions.Independent(
torch.distributions.Normal(prior_means, prior_stds), 1
)
posteriors_dist = torch.distributions.Independent(
torch.distributions.Normal(posterior_means, posterior_stds), 1
)
kl_loss = torch.max(
torch.mean(torch.distributions.kl.kl_divergence(posteriors_dist, priors_dist)),
torch.tensor(self.config.main.free_nats).to(self.device)
)
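# Free nats: when the average KL drops below free_nats, the max() selects the constant,
# so the KL term stops producing gradients and the posterior is not pulled any closer
# to the prior than necessary.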
total_loss = self.config.main.kl_divergence_scale * kl_loss - reconstruct_loss - rewards_loss + continue_loss
self.dyna_optimizer.zero_grad()
total_loss.backward()
nn.utils.clip_grad_norm_(
self.dyna_parameters,
self.config.main.clip_grad,
norm_type=self.config.main.grad_norm_type,
)
self.dyna_optimizer.step()
# log world-model metrics (TensorBoard / wandb)
metrics = {
'Dynamic_model/KL': kl_loss.item(),
'Dynamic_model/Reconstruction': reconstruct_loss.item(),
'Dynamic_model/Reward': rewards_loss.item(),
'Dynamic_model/Continue': continue_loss.item(),
'Dynamic_model/Total': total_loss.item()
}
log_metrics(metrics, self.gradient_step, self.writer, self.wandb_writer)
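# The returned latents are detached so behavior learning cannot backpropagate into the
# world model; they only serve as starting states for imagination.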
return posteriors.detach(), deterministics.detach()
def behavioral_learning(self, state, deterministics):
"""Learning behavioral through latent imagination
Args:
self (_type_): _description_
state (batch_size, seq_len-1, stoch_state_size): starting point state
deterministics (batch_size, seq_len-1, stoch_state_size)
"""
#flatten the batches --> new size (batch_size * (seq_len-1), *)
state = state.reshape(-1, self.config.main.stochastic_size)
deterministics = deterministics.reshape(-1, self.config.main.deterministic_size)
batch_size, stochastic_size = state.shape
_, deterministics_size = deterministics.shape
# allocate buffers for the imagined trajectories
state_trajectories = torch.zeros((batch_size, self.config.main.horizon, stochastic_size)).to(self.device)
deterministics_trajectories = torch.zeros((batch_size, self.config.main.horizon, deterministics_size)).to(self.device)
#imagine trajectories
for t in range(self.config.main.horizon):
# do not include the starting state
action = self.actor(state, deterministics)
deterministics = self.rssm.recurrent(state, action, deterministics)
_, state = self.rssm.transition(deterministics)
state_trajectories[:, t, :] = state
deterministics_trajectories[:, t, :] = deterministics
'''
After imagining, we have both the stochastic and deterministic trajectories, which together form the latent states:
- state_trajectories (N, HORIZON_LEN, stochastic_size)
- deterministics_trajectories (N, HORIZON_LEN, deterministic_size)
'''
#actor update
# compute rewards for each imagined trajectory
rewards = self.reward(state_trajectories, deterministics_trajectories)
rewards_dist = torch.distributions.Normal(rewards, 1)
rewards_dist = torch.distributions.Independent(rewards_dist, 1)
rewards = rewards_dist.mode
if self.config.main.continue_loss:
_, conts_dist = self.cont_net(state_trajectories, deterministics_trajectories)
continues = conts_dist.mean
else:
continues = self.config.main.discount * torch.ones_like(rewards)
values = self.critic(state_trajectories, deterministics_trajectories).mode
# calculate lambda-returns for the imagined trajectories
# returns have shape (N, HORIZON_LEN - 1, 1); the final imagined value is only used for bootstrapping
returns = td_lambda(
rewards,
continues,
values,
self.config.main.lambda_,
self.device
)
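# Assuming td_lambda implements the standard lambda-return recursion from the paper,
# the returns are computed backwards over the imagined horizon as
#   V_lambda(s_t) = r_t + c_t * ((1 - lambda) * v(s_{t+1}) + lambda * V_lambda(s_{t+1})),
# bootstrapped with the critic value at the last imagined step, where c_t is the
# predicted continuation/discount at step t.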
# cumulative product of the predicted continuation values gives the per-step discount weights
discount = torch.cumprod(torch.cat((
torch.ones_like(continues[:, :1]).to(self.device),
continues[:, :-2]
), 1), 1).detach()
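# discount_t = prod_{k<t} continues_k (with discount_0 = 1): each imagined step is
# weighted by the predicted discounted probability of actually reaching it, and the
# product is detached so it scales the actor loss without receiving gradients.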
# actor optimizing
actor_loss = -(discount * returns).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(
self.actor.parameters(),
self.config.main.clip_grad,
norm_type=self.config.main.grad_norm_type,
)
self.actor_optimizer.step()
# critic optimizing
values_dist = self.critic(state_trajectories[:, :-1].detach(), deterministics_trajectories[:, :-1].detach())
critic_loss = -(discount.squeeze() * values_dist.log_prob(returns.detach())).mean()
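# The critic is trained to maximize the log-likelihood of the (detached) lambda-returns
# under its predicted value distribution, weighted by the same discount factors.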
self.critic_optimizer.zero_grad()
critic_loss.backward()
nn.utils.clip_grad_norm_(
self.critic.parameters(),
self.config.main.clip_grad,
norm_type=self.config.main.grad_norm_type,
)
self.critic_optimizer.step()
metrics = {
'Behavioral_model/Actor': actor_loss.item(),
'Behavioral_model/Critic': critic_loss.item()
}
log_metrics(metrics, self.gradient_step, self.writer, self.wandb_writer)
@torch.no_grad()
def data_collection(self, num_episodes, eval=False):
"""data collection method. Roll out agent a number of episodes and collect data
If eval=True. The agent is set for evaluation mode with no exploration noise and data collection
Args:
num_episodes (int): number of episodes
eval (bool): Evaluation mode. Defaults to False.
random (bool): Random mode. Defaults to False.
Returns:
average_score: average score over number of rollout episodes
"""
score = 0
ep = 0
obs, _ = self.env.reset()
# initialize the latent state and previous action with zeros
posterior = torch.zeros((1, self.config.main.stochastic_size)).to(self.device)
deterministic = torch.zeros((1, self.config.main.deterministic_size)).to(self.device)
action = torch.zeros((1, self.action_size)).to(self.device)
while ep < num_episodes:
embed_obs = self.encoder(torch.from_numpy(obs).to(self.device, dtype=torch.float)) #(1, embed_obs_sz)
deterministic = self.rssm.recurrent(posterior, action, deterministic)
_, posterior = self.rssm.representation(embed_obs, deterministic)
actor_out = self.actor(posterior, deterministic)
# add exploration noise when not in evaluation mode
if not eval:
actions = actor_out.cpu().numpy()
if self.config.env.discrete:
if np.random.rand() < self.epsilon:
action = self.env.action_space.sample()
else:
action = np.argmax(actions)
else:
mean_noise = self.config.main.mean_noise
std_noise = self.config.main.std_noise
normal_dist = torch.distributions.Normal(actor_out + mean_noise, std_noise)
sampled_action = normal_dist.sample().cpu().numpy()
actions = np.clip(sampled_action, a_min=-1, a_max=1)
action = actions[0]
else:
actions = actor_out.cpu().numpy()
if self.config.env.discrete:
action = np.argmax(actions)
else:
actions = np.clip(actions, a_min=-1, a_max=1)
action = actions[0]
next_obs, reward, termination, truncation, info = self.env.step(action)
if not eval:
self.buffer.add(obs, actions, reward, termination or truncation)
self.env_step += self.config.env.action_repeat
obs = next_obs
action = actor_out
if "episode" in info:
cur_score = info["episode"]["r"][0]
score += cur_score
obs, _ = self.env.reset()
ep += 1
if 'video_path' in info and self.wandb_writer:
self.wandb_writer.log({'performance/videos': wandb.Video(info['video_path'], format='webm')})
log_metrics({'performance/training score': cur_score}, self.env_step, self.writer, self.wandb_writer)
posterior = torch.zeros((1, self.config.main.stochastic_size)).to(self.device)
deterministic = torch.zeros((1, self.config.main.deterministic_size)).to(self.device)
action = torch.zeros((1, self.action_size)).to(self.device)
return score/num_episodes
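
# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative, not part of the original training code).
# It assumes the YAML config exposes the keys referenced above (device, seed,
# main.*, env.*) wrapped in an attribute-style dict such as addict.Dict, that
# the config provides an `env.name` entry, and that the environment yields
# pixel observations and reports `info["episode"]` at episode end (e.g. via
# gymnasium.wrappers.RecordEpisodeStatistics). File names are assumptions.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    import gymnasium as gym
    from addict import Dict

    # load the experiment configuration (hypothetical path)
    with open("config.yml") as f:
        config = Dict(yaml.safe_load(f))
    # build the environment; the wrapper populates info["episode"] on termination
    env = gym.wrappers.RecordEpisodeStatistics(gym.make(config.env.name))
    agent = Dreamer(config, logpath="./runs/", env=env)
    agent.train()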