minhphd committed on
Commit
ce3feed
·
verified ·
1 Parent(s): 46c27fd

Upload 30 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ gif/pacman_imagine.gif filter=lfs diff=lfs merge=lfs -text
37
+ gif/pacman.gif filter=lfs diff=lfs merge=lfs -text
38
+ gif/quadruped.gif filter=lfs diff=lfs merge=lfs -text
39
+ gif/walker_imagine.gif filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,75 @@
1
+ # PyDreamerV1: A clean PyTorch implementation of Hafner et al.'s Dreamer
2
+ <div align="center">
3
+ <img src="./gif/boxing.gif" alt="Actual run in " width="200px" height="200px"/>
4
+ <img src="./gif/quadruped.gif" alt="Actual run in " width="200px" height="200px"/>
5
+ <img src="./gif/walker.gif" alt="Actual run in " width="200px" height="200px"/>
6
+ </div>
7
+ <div align="center">
8
+ <img src="./gif/boxing_imagine.gif" alt="Imagination in " width="200px" height="200px"/>
9
+ <img src="./gif/quadruped_imagine.gif" alt="Imagination in " width="200px" height="200px"/>
10
+ <img src="./gif/walker_imagine.gif" alt="Imagination in " width="200px" height="200px"/>
11
+ </div>
12
+
13
+
14
+
15
+ This repository offers a comprehensive implementation of the Dreamer algorithm, as presented in the groundbreaking work by Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination." Our implementation is dedicated to faithfully reproducing the innovative approach of learning and planning within a learned latent space, enabling agents to efficiently master complex behaviors through imagination alone.
16
+
17
+ ## Why Dreamer?
18
+
19
+ Dreamer stands at the forefront of model-based reinforcement learning by introducing an efficient method for learning behaviors directly from high-dimensional sensory inputs. It leverages a latent dynamics model to 'imagine' future states and rewards, enabling it to plan and execute actions that maximize long-term rewards purely from simulated experience. This approach significantly improves sample efficiency over traditional model-free methods and opens new avenues for learning complex and nuanced behaviors in simulated environments. However, the official code is widely regarded as complex and difficult to understand, and only a handful of Dreamer reimplementations have been able to reproduce its results.
20
+
21
+ ## Implementation Highlights
22
+
23
+ - **Modular Design**: My implementation of the Recurrent State Space Model (RSSM) is broken down into cleanly separated modules for the transition, representation, and recurrent models (see the minimal sketch after this list). This not only facilitates a deeper understanding of the underlying mechanics but also allows for easy customization and extension.
24
+
25
+ - **True to the Source**: By closely adhering to the methodology detailed in the original DreamerV1 paper, the code captures the essence of latent-space learning and imagination-driven planning. From the incorporation of exploration noise to the TD(λ) return calculation, every element is designed to replicate the paper's results as closely as possible. The hyperparameters are identical to those reported in the paper.
26
+
27
+ - **Detailed Training Insights**: The training loop is organized to mirror the paper's outline, and comprehensive comments documenting hidden implementation details accompany the code, serving as a valuable resource for both learning and further research.
28
+
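+ The snippet below is a minimal sketch (not part of the training loop) of how the RSSM modules in `utils/models.py` compose during a single latent step. The sizes follow the DM Control configs in `configs/dm_control/`; the action size of 6 is only an illustrative placeholder.
+
+ ```python
+ import torch
+ import utils.models as models
+
+ stoch, deter, hidden, embed, action_size = 30, 200, 300, 1024, 6
+
+ rssm = models.RSSM(stoch, embed, deter, hidden, action_size)
+ encoder = models.ConvEncoder(input_shape=(3, 64, 64))
+
+ obs = torch.zeros((1, 3, 64, 64))        # one 64x64 RGB observation
+ posterior = torch.zeros((1, stoch))      # stochastic state s_{t-1}
+ deterministic = torch.zeros((1, deter))  # deterministic state h_{t-1}
+ action = torch.zeros((1, action_size))   # previous action a_{t-1}
+
+ deterministic = rssm.recurrent(posterior, action, deterministic)         # recurrent model: h_t
+ prior_dist, prior = rssm.transition(deterministic)                       # transition model: imagined s_t
+ post_dist, posterior = rssm.representation(encoder(obs), deterministic)  # representation model: observed s_t
+ ```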
29
+ ## Getting Started
30
+
31
+ 1. **Clone the Repository**: Get the code by cloning this repository to your local machine.
32
+ ```
33
+ git clone https://github.com/minhphd/PyDreamerV1
34
+ ```
35
+
36
+ 2. **Install Dependencies**: Ensure you have all necessary dependencies by running:
37
+ ```
38
+ pip3 install -r requirements.txt
39
+ ```
40
+
41
+ 3. **Run the Training**: Kickstart the training process with a simple command:
42
+ ```
43
+ python main.py --config <Path to config file>
44
+ ```
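+
+ For example, using one of the configs included in this repository:
+ ```
+ python main.py --config configs/dm_control/Walker.yml
+ ```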
45
+
46
+ 4. **Visualize Results**: Utilize TensorBoard to observe training progress and visualize the agent's performance in real time. Weights & Biases (wandb) is also supported: set `enable` to `True` under the `wandb` section of the config file and fill in your account information.
47
+ ```
48
+ tensorboard --logdir=runs
49
+ ```
50
+ **Optional: Visualize imagined sequences**: Use the saved models to visualize the agent's predictions of the environment dynamics. Create a `config` folder inside the run's logging directory and copy the training config file into it (see the expected layout below).
51
+ ```
52
+ python imagine.py --runpath <Path to run directory>
53
+ ```
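+
+ For reference, a run's logging directory is expected to look roughly like this (the models are saved automatically at checkpoints; the `config` folder is the one you create by hand):
+ ```
+ runs/<experiment_run>/
+ ├── config/
+ │   └── <training config>.yml
+ └── models/
+     ├── rssm_model
+     ├── encoder
+     ├── decoder
+     ├── actor
+     └── critic
+ ```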
54
+
55
+ ## Citation
56
+ This implementation was made possible thanks to these papers.
57
+ ```bibtex
58
+ @article{hafner2019dream,
59
+ title={Dream to Control: Learning Behaviors by Latent Imagination},
60
+ author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy},
61
+ journal={arXiv preprint arXiv:1912.01603},
62
+ year={2019}
63
+ }
64
+ @misc{1801.00690,
65
+ title = {DeepMind Control Suite},
66
+ author = {Yuval Tassa and Yotam Doron and Alistair Muldal and Tom Erez and Yazhe Li and Diego de Las Casas and David Budden and Abbas Abdolmaleki and Josh Merel and Andrew Lefrancq and Timothy Lillicrap and Martin Riedmiller},
67
+ journal = {arXiv preprint arXiv:1801.00690},
68
+ year = {2018},
69
+ }
70
+
71
+ ```
72
+
73
+ ## Contributions
74
+
75
+ Contributions are welcome! Whether it's extending functionality, improving efficiency, or correcting bugs, your input helps make this project better for everyone.
algos/.DS_Store ADDED
Binary file (6.15 kB). View file
 
algos/dreamer.py ADDED
@@ -0,0 +1,489 @@
1
+ """
2
+ Author: Minh Pham-Dinh
3
+ Created: Jan 27th, 2024
4
+ Last Modified: Feb 10th, 2024
5
6
+
7
+ Description:
8
+ main Dreamer file.
9
+
10
+ The implementation is based on:
11
+ Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
12
+ [Online]. Available: https://arxiv.org/abs/1912.01603
13
+ """
14
+
15
+ # Standard Library Imports
16
+ import os
17
+ import numpy as np
18
+ import yaml
19
+ from tqdm import tqdm
20
+ import wandb
21
+
22
+ # Machine Learning and Data Processing Imports
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.optim as optim
26
+
27
+ # Custom Utility Imports
28
+ import utils.models as models
29
+ from utils.buffer import ReplayBuffer
30
+ from utils.utils import td_lambda, log_metrics
31
+
32
+
33
+ class Dreamer:
34
+ def __init__(self, config, logpath, env, writer = None, wandb_writer=None):
35
+ self.config = config
36
+ self.device = torch.device(self.config.device)
37
+ self.env = env
38
+ self.obs_size = env.observation_space.shape
39
+ self.action_size = env.action_space.n if self.config.env.discrete else env.action_space.shape[0]
40
+ self.epsilon = self.config.main.epsilon_start
41
+ self.env_step = 0
42
+ self.logpath = logpath
43
+
44
+ # Set random seed for reproducibility
45
+ np.random.seed(self.config.seed)
46
+ torch.manual_seed(self.config.seed)
47
+
48
+ #dynamic networks initialized
49
+ self.rssm = models.RSSM(self.config.main.stochastic_size,
50
+ self.config.main.embedded_obs_size,
51
+ self.config.main.deterministic_size,
52
+ self.config.main.hidden_units,
53
+ self.action_size).to(self.device)
54
+
55
+ self.reward = models.RewardNet(self.config.main.stochastic_size + self.config.main.deterministic_size,
56
+ self.config.main.hidden_units).to(self.device)
57
+
58
+ if self.config.main.continue_loss:
59
+ self.cont_net = models.ContinuoNet(self.config.main.stochastic_size + self.config.main.deterministic_size,
60
+ self.config.main.hidden_units).to(self.device)
61
+
62
+ self.encoder = models.ConvEncoder(input_shape=self.obs_size).to(self.device)
63
+ self.decoder = models.ConvDecoder(self.config.main.stochastic_size,
64
+ self.config.main.deterministic_size,
65
+ out_shape=self.obs_size).to(self.device)
66
+ self.dyna_parameters = (
67
+ list(self.rssm.parameters())
68
+ + list(self.reward.parameters())
69
+ + list(self.encoder.parameters())
70
+ + list(self.decoder.parameters())
71
+ )
72
+
73
+ if self.config.main.continue_loss:
74
+ self.dyna_parameters += list(self.cont_net.parameters())
75
+
76
+ #behavior networks initialized
77
+ self.actor = models.Actor(self.config.main.stochastic_size + self.config.main.deterministic_size,
78
+ self.config.main.hidden_units,
79
+ self.action_size,
80
+ self.config.env.discrete).to(self.device)
81
+ self.critic = models.Critic(self.config.main.stochastic_size + self.config.main.deterministic_size,
82
+ self.config.main.hidden_units).to(self.device)
83
+
84
+ #optimizers
85
+ self.dyna_optimizer = optim.Adam(self.dyna_parameters, lr=self.config.main.dyna_model_lr)
86
+ self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.config.main.actor_lr)
87
+ self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.config.main.critic_lr)
88
+ self.gradient_step = 0
89
+
90
+ #buffer
91
+ self.buffer = ReplayBuffer(self.config.main.buffer_capacity, self.obs_size, (self.action_size, ))
92
+
93
+ #tracking stuff
94
+ self.wandb_writer = wandb_writer
95
+ self.writer = writer
96
+
97
+
98
+ def update_epsilon(self):
99
+ """In use for decaying epsilon in discrete env
100
+
101
+ Returns:
102
+ _type_: _description_
103
+ """
104
+ eps_start = self.config.main.epsilon_start
105
+ eps_end = self.config.main.epsilon_end
106
+ decay_steps = self.config.main.eps_decay_steps
107
+ decay_rate = (eps_start - eps_end) / (decay_steps)
108
+ self.epsilon = max(eps_end, eps_start - decay_rate*self.gradient_step)
109
+
110
+
111
+ def train(self):
112
+ """main training loop, implementation follow closely with the loop from the official paper
113
+
114
+ Returns:
115
+ _type_: _description_
116
+ """
117
+
118
+ #prefill dataset
119
+ ep = 0
120
+ obs, _ = self.env.reset()
121
+ while ep < self.config.main.data_init_ep:
122
+ action = self.env.action_space.sample()
123
+ if self.config.env.discrete:
124
+ actions = np.zeros((self.action_size, ))
125
+ actions[action] = 1.0
126
+ else:
127
+ actions = action
128
+
129
+ next_obs, reward, termination, truncation, info = self.env.step(action)
130
+
131
+ self.buffer.add(obs, actions, reward, termination or truncation)
132
+ obs = next_obs
133
+
134
+ if "episode" in info:
135
+ obs, _ = self.env.reset()
136
+ ep += 1
137
+ print(ep)
138
+ if 'video_path' in info and self.wandb_writer:
139
+ self.wandb_writer.log({'performance/videos': wandb.Video(info['video_path'], format='webm')})
140
+
141
+ #main train loop
142
+ for _ in tqdm(range(self.config.main.total_iter)):
143
+ #save model if reached checkpoint
144
+ if _ % self.config.main.save_freq == 0:
145
+
146
+ #check if models folder exist
147
+ directory = self.logpath + 'models/'
148
+ os.makedirs(directory, exist_ok=True)
149
+
150
+ #save models
151
+ torch.save(self.rssm, self.logpath + 'models/rssm_model')
152
+ torch.save(self.encoder, self.logpath + 'models/encoder')
153
+ torch.save(self.decoder, self.logpath + 'models/decoder')
154
+ torch.save(self.actor, self.logpath + 'models/actor')
155
+ torch.save(self.critic, self.logpath + 'models/critic')
156
+
157
+ #run eval if reach eval checkpoint
158
+ if _ % self.config.main.eval_freq == 0:
159
+ eval_score = self.data_collection(self.config.main.eval_eps, eval=True)
160
+ metrics = {'performance/evaluation score': eval_score}
161
+ log_metrics(metrics, self.env_step, self.writer, self.wandb_writer)
162
+
163
+ #training step
164
+ for c in tqdm(range(self.config.main.collect_iter)):
165
+ #draw data
166
+ batch = self.buffer.sample(self.config.main.batch_size, self.config.main.seq_len, self.device)
167
+
168
+ #dynamic learning
169
+ post, deter = self.dynamic_learning(batch)
170
+
171
+ #behavioral learning
172
+ self.behavioral_learning(post, deter)
173
+
174
+ #update step
175
+ self.gradient_step += 1
176
+ self.update_epsilon()
177
+
178
+ # collect more data with exploration noise
179
+ self.data_collection(self.config.main.data_interact_ep)
180
+
181
+
182
+
183
+ def dynamic_learning(self, batch):
184
+ """Learning the dynamic model. In this method, we sequentially pass data in the RSSM to
185
+ learn the model
186
+
187
+ Args:
188
+ batch (addict.Dict): batches of data
189
+ """
190
+
191
+ '''
192
+ We unpack the batch. A batch contains:
193
+ - b_obs (batch_size, seq_len, *obs.shape): batches of observation
194
+ - b_a (batch_size, seq_len, 1): batches of action
195
+ - b_r (batch_size, seq_len, 1): batches of rewards
196
+ - b_d (batch_size, seq_len, 1): batches of termination signal
197
+ '''
198
+ b_obs = batch.obs
199
+ b_a = batch.actions
200
+ b_r = batch.rewards
201
+ b_d = batch.dones
202
+
203
+ batch_size, seq_len, _ = b_r.shape
204
+ eb_obs = self.encoder(b_obs)
205
+
206
+ #initialized stochastic states (posterior) and deterministic states to first pass into the recurrent model
207
+ posterior = torch.zeros((batch_size, self.config.main.stochastic_size)).to(self.device)
208
+ deterministic = torch.zeros((batch_size, self.config.main.deterministic_size)).to(self.device)
209
+
210
+ #initialized memory storing of sequential gradients data
211
+ posteriors = torch.zeros((batch_size, seq_len-1, self.config.main.stochastic_size)).to(self.device)
212
+ priors = torch.zeros((batch_size, seq_len-1, self.config.main.stochastic_size)).to(self.device)
213
+ deterministics = torch.zeros((batch_size, seq_len-1, self.config.main.deterministic_size)).to(self.device)
214
+
215
+ posterior_means = torch.zeros_like(posteriors).to(self.device)
216
+ posterior_stds = torch.zeros_like(posteriors).to(self.device)
217
+ prior_means = torch.zeros_like(priors).to(self.device)
218
+ prior_stds = torch.zeros_like(priors).to(self.device)
219
+
220
+ #start passing data through the dynamic model
221
+ for t in (range(1, seq_len)):
222
+ deterministic = self.rssm.recurrent(posterior, b_a[:, t-1, :], deterministic)
223
+ prior_dist, prior = self.rssm.transition(deterministic)
224
+
225
+ #detail: observation is shifted 1 timestep ahead (the action is associated with the next state)
226
+ posterior_dist, posterior = self.rssm.representation(eb_obs[:, t, :], deterministic)
227
+
228
+ '''
229
+ store recurrent data
230
+ data are shifted 1 timestep ahead. Start from the second timestep or t=1
231
+ '''
232
+ posteriors[:, t-1, :] = posterior
233
+ posterior_means[:, t-1, :] = posterior_dist.mean
234
+ posterior_stds[:, t-1, :] = posterior_dist.scale
235
+
236
+ priors[:, t-1, :] = prior
237
+ prior_means[:, t-1, :] = prior_dist.mean
238
+ prior_stds[:, t-1, :] = prior_dist.scale
239
+
240
+ deterministics[:, t-1, :] = deterministic
241
+
242
+ #we start optimizing model with the provided data
243
+
244
+ '''
245
+ Reconstruction loss. This loss helps the model learn to encode pixel observations.
246
+ '''
247
+ mps_flatten = False
248
+ if self.device == torch.device("mps"):
249
+ mps_flatten = True
250
+
251
+ reconstruct_dist = self.decoder(posteriors, deterministics, mps_flatten)
252
+ target = b_obs[:, 1:]
253
+ if mps_flatten:
254
+ target = target.reshape(-1, *self.obs_size)
255
+ reconstruct_loss = reconstruct_dist.log_prob(target).mean()
256
+
257
+ #reward loss
258
+ rewards = self.reward(posteriors, deterministics)
259
+ rewards_dist = torch.distributions.Normal(rewards, 1)
260
+ rewards_dist = torch.distributions.Independent(rewards_dist, 1)
261
+ rewards_loss = rewards_dist.log_prob(b_r[:, 1:]).mean()
262
+
263
+ '''
264
+ Continuity loss. This loss term helps predict the probability of an episode terminating at a particular state
265
+ '''
266
+ if self.config.main.continue_loss:
267
+ # calculate the log prob manually via binary cross-entropy with logits, since Bernoulli log_prob does not handle non-binary float targets
268
+ # this follows Hafner's official Dreamer code closely
269
+ cont_logits, _ = self.cont_net(posteriors, deterministics)
270
+ cont_target = (1 - b_d[:, 1:]) * self.config.main.discount
271
+ continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(cont_logits, cont_target)
272
+ else:
273
+ continue_loss = torch.zeros((1)).to(self.device)
274
+
275
+ '''
276
+ KL loss. Matches the distributions of the transition and representation models. This ensures an accurate transition model for use in the imagination process
277
+ '''
278
+ priors_dist = torch.distributions.Independent(
279
+ torch.distributions.Normal(prior_means, prior_stds), 1
280
+ )
281
+ posteriors_dist = torch.distributions.Independent(
282
+ torch.distributions.Normal(posterior_means, posterior_stds), 1
283
+ )
284
+ kl_loss = torch.max(
285
+ torch.mean(torch.distributions.kl.kl_divergence(posteriors_dist, priors_dist)),
286
+ torch.tensor(self.config.main.free_nats).to(self.device)
287
+ )
288
+
289
+ total_loss = self.config.main.kl_divergence_scale * kl_loss - reconstruct_loss - rewards_loss + continue_loss
290
+
291
+ self.dyna_optimizer.zero_grad()
292
+ total_loss.backward()
293
+ nn.utils.clip_grad_norm_(
294
+ self.dyna_parameters,
295
+ self.config.main.clip_grad,
296
+ norm_type=self.config.main.grad_norm_type,
297
+ )
298
+ self.dyna_optimizer.step()
299
+
300
+ #tensorboard logging
301
+ metrics = {
302
+ 'Dynamic_model/KL': kl_loss.item(),
303
+ 'Dynamic_model/Reconstruction': reconstruct_loss.item(),
304
+ 'Dynamic_model/Reward': rewards_loss.item(),
305
+ 'Dynamic_model/Continue': continue_loss.item(),
306
+ 'Dynamic_model/Total': total_loss.item()
307
+ }
308
+
309
+ log_metrics(metrics, self.gradient_step, self.writer, self.wandb_writer)
310
+
311
+ return posteriors.detach(), deterministics.detach()
312
+
313
+
314
+ def behavioral_learning(self, state, deterministics):
315
+ """Learning behavioral through latent imagination
316
+
317
+ Args:
318
+ self (_type_): _description_
319
+ state (batch_size, seq_len-1, stoch_state_size): starting point state
320
+ deterministics (batch_size, seq_len-1, stoch_state_size)
321
+ """
322
+
323
+ #flatten the batches --> new size (batch_size * (seq_len-1), *)
324
+ state = state.reshape(-1, self.config.main.stochastic_size)
325
+ deterministics = deterministics.reshape(-1, self.config.main.deterministic_size)
326
+
327
+ batch_size, stochastic_size = state.shape
328
+ _, deterministics_size = deterministics.shape
329
+
330
+ #initialized trajectories
331
+ state_trajectories = torch.zeros((batch_size, self.config.main.horizon, stochastic_size)).to(self.device)
332
+ deterministics_trajectories = torch.zeros((batch_size, self.config.main.horizon, deterministics_size)).to(self.device)
333
+
334
+ #imagine trajectories
335
+ for t in range(self.config.main.horizon):
336
+ # do not include the starting state
337
+ action = self.actor(state, deterministics)
338
+ deterministics = self.rssm.recurrent(state, action, deterministics)
339
+ _, state = self.rssm.transition(deterministics)
340
+ state_trajectories[:, t, :] = state
341
+ deterministics_trajectories[:, t, :] = deterministics
342
+
343
+ '''
344
+ After imagining, we have both the state trajectories and deterministic trajectories, which can be used to create latent states.
345
+ - state_trajectories (N, HORIZON_LEN)
346
+ - deterministic_trajectories (N, HORIZON_LEN)
347
+ '''
348
+
349
+ #actor update
350
+
351
+ #compute rewards for each trajectory
352
+ rewards = self.reward(state_trajectories, deterministics_trajectories)
353
+ rewards_dist = torch.distributions.Normal(rewards, 1)
354
+ rewards_dist = torch.distributions.Independent(rewards_dist, 1)
355
+ rewards = rewards_dist.mode
356
+
357
+ if self.config.main.continue_loss:
358
+ _, conts_dist = self.cont_net(state_trajectories, deterministics_trajectories)
359
+ continues = conts_dist.mean
360
+ else:
361
+ continues = self.config.main.discount * torch.ones_like(rewards)
362
+
363
+ values = self.critic(state_trajectories, deterministics_trajectories).mode
364
+
365
+ #calculate trajectories returns
366
+ #returns should have shape (N, HORIZON_LEN - 1, 1) (last values are ignored due to nature of bootstrapping)
367
+ returns = td_lambda(
368
+ rewards,
369
+ continues,
370
+ values,
371
+ self.config.main.lambda_,
372
+ self.device
373
+ )
374
+
375
+ #cumulative product for discount
376
+ discount = torch.cumprod(torch.cat((
377
+ torch.ones_like(continues[:, :1]).to(self.device),
378
+ continues[:, :-2]
379
+ ), 1), 1).detach()
380
+
381
+ # actor optimizing
382
+ actor_loss = -(discount * returns).mean()
383
+
384
+ self.actor_optimizer.zero_grad()
385
+ actor_loss.backward()
386
+ nn.utils.clip_grad_norm_(
387
+ self.actor.parameters(),
388
+ self.config.main.clip_grad,
389
+ norm_type=self.config.main.grad_norm_type,
390
+ )
391
+ self.actor_optimizer.step()
392
+
393
+
394
+ # critic optimizing
395
+ values_dist = self.critic(state_trajectories[:, :-1].detach(), deterministics_trajectories[:, :-1].detach())
396
+
397
+ critic_loss = -(discount.squeeze() * values_dist.log_prob(returns.detach())).mean()
398
+
399
+ self.critic_optimizer.zero_grad()
400
+ critic_loss.backward()
401
+ nn.utils.clip_grad_norm_(
402
+ self.critic.parameters(),
403
+ self.config.main.clip_grad,
404
+ norm_type=self.config.main.grad_norm_type,
405
+ )
406
+ self.critic_optimizer.step()
407
+
408
+ metrics = {
409
+ 'Behavioral_model/Actor': actor_loss.item(),
410
+ 'Behavioral_model/Critic': critic_loss.item()
411
+ }
412
+
413
+ log_metrics(metrics, self.gradient_step, self.writer, self.wandb_writer)
414
+
415
+
416
+ @torch.no_grad()
417
+ def data_collection(self, num_episodes, eval=False):
418
+ """data collection method. Roll out agent a number of episodes and collect data
419
+ If eval=True. The agent is set for evaluation mode with no exploration noise and data collection
420
+
421
+ Args:
422
+ num_episodes (int): number of episodes
423
+ eval (bool): Evaluation mode. Defaults to False.
424
+ random (bool): Random mode. Defaults to False.
425
+
426
+ Returns:
427
+ average_score: average score over number of rollout episodes
428
+ """
429
+ score = 0
430
+ ep = 0
431
+ obs, _ = self.env.reset()
432
+ #initialized all zeros
433
+ posterior = torch.zeros((1, self.config.main.stochastic_size)).to(self.device)
434
+ deterministic = torch.zeros((1, self.config.main.deterministic_size)).to(self.device)
435
+ action = torch.zeros((1, self.action_size)).to(self.device)
436
+
437
+ while ep < num_episodes:
438
+ embed_obs = self.encoder(torch.from_numpy(obs).to(self.device, dtype=torch.float)) #(1, embed_obs_sz)
439
+ deterministic = self.rssm.recurrent(posterior, action, deterministic)
440
+ _, posterior = self.rssm.representation(embed_obs, deterministic)
441
+ actor_out = self.actor(posterior, deterministic)
442
+
443
+ #detail: add exploration noise if not in evaluation mode
444
+ if not eval:
445
+ actions = actor_out.cpu().numpy()
446
+ if self.config.env.discrete:
447
+ if np.random.rand() < self.epsilon:
448
+ action = self.env.action_space.sample()
449
+ else:
450
+ action = np.argmax(actions)
451
+ else:
452
+ mean_noise = self.config.main.mean_noise
453
+ std_noise = self.config.main.std_noise
454
+
455
+ normal_dist = torch.distributions.Normal(actor_out + mean_noise, std_noise)
456
+ sampled_action = normal_dist.sample().cpu().numpy()
457
+ actions = np.clip(sampled_action, a_min=-1, a_max=1)
458
+ action = actions[0]
459
+ else:
460
+ actions = actor_out.cpu().numpy()
461
+ if self.config.env.discrete:
462
+ action = np.argmax(actions)
463
+ else:
464
+ actions = np.clip(actions, a_min=-1, a_max=1)
465
+ action = actions[0]
466
+
467
+ next_obs, reward, termination, truncation, info = self.env.step(action)
468
+
469
+ if not eval:
470
+ self.buffer.add(obs, actions, reward, termination | truncation)
471
+ self.env_step += self.config.env.action_repeat
472
+ obs = next_obs
473
+
474
+ action = actor_out
475
+ if "episode" in info:
476
+ cur_score = info["episode"]["r"][0]
477
+ score += cur_score
478
+ obs, _ = self.env.reset()
479
+ ep += 1
480
+
481
+ if 'video_path' in info and self.wandb_writer:
482
+ self.wandb_writer.log({'performance/videos': wandb.Video(info['video_path'], format='webm')})
483
+ log_metrics({'performance/training score': cur_score}, self.env_step, self.writer, self.wandb_writer)
484
+
485
+ posterior = torch.zeros((1, self.config.main.stochastic_size)).to(self.device)
486
+ deterministic = torch.zeros((1, self.config.main.deterministic_size)).to(self.device)
487
+ action = torch.zeros((1, self.action_size)).to(self.device)
488
+
489
+ return score/num_episodes
bash/setup.sh ADDED
@@ -0,0 +1,32 @@
1
+ #!/bin/bash
2
+
3
+ # Update and Upgrade the System
4
+ echo "Updating and upgrading the system..."
5
+ sudo apt-get update -y
6
+ sudo apt-get upgrade -y
7
+
8
+ # Install dependencies for Gymnasium
9
+ echo "Installing dependencies for Gymnasium..."
10
+
11
+ # Development tools
12
+ sudo apt-get install -y build-essential
13
+
14
+ # Python 3 and pip
15
+ sudo apt-get install -y python3 python3-pip
16
+ sudo apt-get install -y python3-opencv
17
+
18
+ # System libraries
19
+ sudo apt-get install -y libglew-dev libjpeg-dev libboost-all-dev libglu1-mesa-dev freeglut3-dev mesa-common-dev
20
+
21
+ # SWIG for interface generation
22
+ sudo apt-get install -y swig
23
+
24
+ # Gymnasium and additional dependencies via pip
25
+ echo "Installing requirements.txt"
26
+ pip3 install -r requirements.txt
27
+ sudo apt-get install -y xvfb
28
+ Xvfb :99 -screen 0 1024x768x24 &
29
+
30
+ export DISPLAY=:99
31
+
32
+ echo "Setup complete!"
configs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
configs/dm_control/Cart-pole.yml ADDED
@@ -0,0 +1,64 @@
1
+ device: "cuda"
2
+ experiment_name: Cart-pole
3
+ seed: 0
4
+
5
+ env:
6
+ env_id: cartpole
7
+ task: balance
8
+ discrete: False
9
+ new_obs_size: [64, 64]
10
+ norm_obs: True
11
+
12
+ tensorboard:
13
+ enable: False
14
+ log_dir: "./runs/"
15
+ log_frequency: 1 # Log every 1000 steps
16
+
17
+ wandb:
18
+ enable: True
19
+ project: "dreamer"
20
+ entity: "phdminh01"
21
+ log_frequency: 1
22
+
23
+ video_recording:
24
+ enable: True
25
+ record_frequency: 50 #episodes
26
+
27
+ main:
28
+ continue_loss: False
29
+ continue_scale_factor: 10
30
+ total_iter: 2000
31
+ save_freq: 20
32
+ collect_iter: 100
33
+ data_interact_ep: 1
34
+ # data_init_ep: 1
35
+ data_init_ep: 5
36
+ horizon: 15
37
+ batch_size: 50
38
+ seq_len: 50
39
+ eval_eps: 3
40
+ eval_freq: 5
41
+
42
+ kl_divergence_scale : 1
43
+ free_nats : 3
44
+ discount : 0.99
45
+ lambda_ : 0.95
46
+
47
+ actor_lr : 8.0e-5
48
+ critic_lr : 8.0e-5
49
+ dyna_model_lr : 6.0e-4
50
+ grad_norm_type : 2
51
+ clip_grad : 100
52
+
53
+ hidden_units: 300
54
+ deterministic_size : 200
55
+ stochastic_size : 30
56
+ embedded_obs_size : 1024
57
+ buffer_capacity : 500000
58
+
59
+ epsilon_start: 0.4
60
+ epsilon_end: 0.1
61
+ eps_decay_steps: 200000
62
+
63
+ mean_noise: 0
64
+ std_noise: 0.3
configs/dm_control/Quadruped.yml ADDED
@@ -0,0 +1,64 @@
1
+ device: "cuda"
2
+ experiment_name: Quadruped
3
+ seed: 0
4
+
5
+ env:
6
+ env_id: quadruped
7
+ task: walk
8
+ discrete: False
9
+ new_obs_size: [64, 64]
10
+ norm_obs: True
11
+
12
+ tensorboard:
13
+ enable: False
14
+ log_dir: "./runs/"
15
+ log_frequency: 1 # Log every 1000 steps
16
+
17
+ wandb:
18
+ enable: True
19
+ project: "dreamer"
20
+ entity: "phdminh01"
21
+ log_frequency: 1
22
+
23
+ video_recording:
24
+ enable: True
25
+ record_frequency: 50 #episodes
26
+
27
+ main:
28
+ continue_loss: False
29
+ continue_scale_factor: 10
30
+ total_iter: 2000
31
+ save_freq: 20
32
+ collect_iter: 100
33
+ data_interact_ep: 1
34
+ # data_init_ep: 1
35
+ data_init_ep: 5
36
+ horizon: 15
37
+ batch_size: 50
38
+ seq_len: 50
39
+ eval_eps: 3
40
+ eval_freq: 5
41
+
42
+ kl_divergence_scale : 1
43
+ free_nats : 3
44
+ discount : 0.99
45
+ lambda_ : 0.95
46
+
47
+ actor_lr : 8.0e-5
48
+ critic_lr : 8.0e-5
49
+ dyna_model_lr : 6.0e-4
50
+ grad_norm_type : 2
51
+ clip_grad : 100
52
+
53
+ hidden_units: 300
54
+ deterministic_size : 200
55
+ stochastic_size : 30
56
+ embedded_obs_size : 1024
57
+ buffer_capacity : 500000
58
+
59
+ epsilon_start: 0.4
60
+ epsilon_end: 0.1
61
+ eps_decay_steps: 200000
62
+
63
+ mean_noise: 0
64
+ std_noise: 0.3
configs/dm_control/Walker.yml ADDED
@@ -0,0 +1,65 @@
1
+ device: "mps"
2
+ experiment_name: Walker
3
+ seed: 0
4
+
5
+ env:
6
+ env_id: walker
7
+ task: walk
8
+ new_obs_size: [64, 64]
9
+ action_repeat: 2
10
+ time_limit: 1000
11
+
12
+ tensorboard:
13
+ enable: False
14
+ log_dir: "./runs/"
15
+ log_frequency: 1 # Log every 1000 steps
16
+
17
+ wandb:
18
+ enable: False
19
+ project: "dreamer"
20
+ entity: "phdminh01"
21
+ log_frequency: 1
22
+
23
+ video_recording:
24
+ enable: False
25
+ record_frequency: 100 #episodes
26
+
27
+ main:
28
+ continue_loss: False
29
+ continue_scale_factor: 10
30
+ total_iter: 2000
31
+ save_freq: 20
32
+ collect_iter: 100
33
+ data_interact_ep: 1
34
+ # data_init_ep: 1
35
+ data_init_ep: 5
36
+ horizon: 15
37
+ batch_size: 50
38
+ seq_len: 50
39
+ eval_eps: 3
40
+ eval_freq: 5
41
+
42
+ kl_divergence_scale : 1
43
+ free_nats : 3
44
+ discount : 0.99
45
+ lambda_ : 0.95
46
+
47
+ use_continue_flag : True
48
+ actor_lr : 8.0e-5
49
+ critic_lr : 8.0e-5
50
+ dyna_model_lr : 6.0e-4
51
+ grad_norm_type : 2
52
+ clip_grad : 100
53
+
54
+ hidden_units: 300
55
+ deterministic_size : 200
56
+ stochastic_size : 30
57
+ embedded_obs_size : 1024
58
+ buffer_capacity : 500000
59
+
60
+ epsilon_start: 0.4
61
+ epsilon_end: 0.1
62
+ eps_decay_steps: 200000
63
+
64
+ mean_noise: 0
65
+ std_noise: 0.3
configs/gymnasium/Boxing-v5.yml ADDED
@@ -0,0 +1,67 @@
1
+ device: "mps"
2
+ experiment_name: Boxing-v5-new
3
+ seed: 0
4
+
5
+ env:
6
+ env_id: ALE/Boxing-v5
7
+ channel_first: True
8
+ discrete: True
9
+ resize_obs: True
10
+ new_obs_size: [64, 64]
11
+ norm_obs: True
12
+
13
+ tensorboard:
14
+ enable: True
15
+ log_dir: "./runs/"
16
+ log_frequency: 1 # Log every 1000 steps
17
+
18
+ wandb:
19
+ enable: False
20
+ project: "dreamer"
21
+ entity: "phdminh01"
22
+ log_frequency: 1
23
+
24
+ video_recording:
25
+ enable: True
26
+ record_frequency: 100 #episodes
27
+ save_path: "./runs/"
28
+
29
+ main:
30
+ continue_loss: False
31
+ continue_scale_factor: 10
32
+ total_iter: 2000
33
+ save_freq: 20
34
+ collect_iter: 100
35
+ data_interact_ep: 1
36
+ data_init_ep: 1
37
+ # data_init_ep: 5
38
+ horizon: 10
39
+ batch_size: 50
40
+ seq_len: 50
41
+ eval_eps: 3
42
+ eval_freq: 5
43
+
44
+ kl_divergence_scale : 1
45
+ free_nats : 3
46
+ discount : 0.99
47
+ lambda_ : 0.95
48
+
49
+ use_continue_flag : True
50
+ actor_lr : 8.0e-5
51
+ critic_lr : 8.0e-5
52
+ dyna_model_lr : 6.0e-4
53
+ grad_norm_type : 2
54
+ clip_grad : 100
55
+
56
+ hidden_units: 400
57
+ deterministic_size : 600
58
+ stochastic_size : 600
59
+ embedded_obs_size : 1024
60
+ buffer_capacity : 500000
61
+
62
+ epsilon_start: 0.4
63
+ epsilon_end: 0.1
64
+ eps_decay_steps: 200000
65
+
66
+ mean_noise: 0
67
+ std_noise: 0.3
configs/gymnasium/Pacman-v5.yml ADDED
@@ -0,0 +1,66 @@
1
+ device: "cuda"
2
+ experiment_name: Pacman-v5
3
+ seed: 0
4
+
5
+ env:
6
+ env_id: ALE/MsPacman-v5
7
+ channel_first: True
8
+ discrete: True
9
+ resize_obs: True
10
+ new_obs_size: [64, 64]
11
+ norm_obs: True
12
+
13
+ tensorboard:
14
+ enable: False
15
+ log_dir: "./runs/"
16
+ log_frequency: 1 # Log every 1000 steps
17
+
18
+ wandb:
19
+ enable: True
20
+ project: "dreamer"
21
+ entity: "phdminh01"
22
+ log_frequency: 1
23
+
24
+ video_recording:
25
+ enable: True
26
+ record_frequency: 100 #episodes
27
+ save_path: "./runs/"
28
+
29
+ main:
30
+ continue_loss: True
31
+ continue_scale_factor: 10
32
+ total_iter: 2000
33
+ save_freq: 20
34
+ collect_iter: 100
35
+ data_interact_ep: 1
36
+ # data_init_ep: 1
37
+ data_init_ep: 5
38
+ horizon: 15
39
+ batch_size: 50
40
+ seq_len: 50
41
+ eval_eps: 3
42
+ eval_freq: 5
43
+
44
+ kl_divergence_scale : 1
45
+ free_nats : 3
46
+ discount : 0.99
47
+ lambda_ : 0.95
48
+
49
+ actor_lr : 8.0e-5
50
+ critic_lr : 8.0e-5
51
+ dyna_model_lr : 6.0e-4
52
+ grad_norm_type : 2
53
+ clip_grad : 100
54
+
55
+ hidden_units: 400
56
+ deterministic_size : 600
57
+ stochastic_size : 600
58
+ embedded_obs_size : 1024
59
+ buffer_capacity : 500000
60
+
61
+ epsilon_start: 0.4
62
+ epsilon_end: 0.1
63
+ eps_decay_steps: 200000
64
+
65
+ mean_noise: 0
66
+ std_noise: 0.3
configs/gymnasium/ant_v4.yml ADDED
@@ -0,0 +1,65 @@
1
+ device: "mps"
2
+ experiment_name: Ant-v4
3
+ seed: 0
4
+
5
+ gymnasium:
6
+ env_id: Ant-v4
7
+ channel_first: True
8
+ pixels: True
9
+ discrete: False
10
+ resize_obs: True
11
+ new_obs_size: [64, 64]
12
+ norm_obs: True
13
+
14
+ tensorboard:
15
+ enable: False
16
+ log_dir: "./runs/"
17
+ log_frequency: 1 # Log every 1000 steps
18
+
19
+ wandb:
20
+ enable: False
21
+ project: "dreamer"
22
+ entity: "phdminh01"
23
+ log_frequency: 1
24
+
25
+ video_recording:
26
+ enable: False
27
+ record_frequency: 100 #episodes
28
+
29
+ main:
30
+ total_iter: 2000
31
+ save_freq: 100
32
+ collect_iter: 100
33
+ data_interact_ep: 1
34
+ # data_init_ep: 1
35
+ data_init_ep: 5
36
+ horizon: 15
37
+ batch_size: 50
38
+ seq_len: 50
39
+ eval_eps: 3
40
+ eval_freq: 5
41
+
42
+ kl_divergence_scale : 1
43
+ free_nats : 3
44
+ discount : 0.99
45
+ lambda_ : 0.95
46
+
47
+ use_continue_flag : True
48
+ actor_lr : 3.0e-4
49
+ critic_lr : 3.0e-4
50
+ dyna_model_lr : 6.0e-4
51
+ grad_norm_type : 2
52
+ clip_grad : 100
53
+
54
+ hidden_units: 400
55
+ deterministic_size : 600
56
+ stochastic_size : 600
57
+ embedded_obs_size : 1024
58
+ buffer_capacity : 500000
59
+
60
+ epsilon_start: 0.4
61
+ epsilon_end: 0.1
62
+ eps_decay_steps: 200000
63
+
64
+ mean_noise: 0
65
+ std_noise: 0.3
configs/gymnasium/car_racing_config.yml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ device: "cuda"
2
+ experiment_name: CarRacing-v2
3
+ seed: 0
4
+
5
+ gymnasium:
6
+ env_id: CarRacing-v2
7
+ channel_first: True
8
+ discrete: False
9
+ resize_obs: True
10
+ new_obs_size: [64, 64]
11
+ norm_obs: True
12
+
13
+ tensorboard:
14
+ enable: True
15
+ log_dir: "./runs/"
16
+ log_frequency: 1 # Log every 1000 steps
17
+
18
+ wandb:
19
+ enable: True
20
+ project: "dreamer"
21
+ entity: "phdminh01"
22
+ log_frequency: 1
23
+
24
+ video_recording:
25
+ enable: True
26
+ record_frequency: 100 #episodes
27
+
28
+ main:
29
+ total_iter: 2000
30
+ save_freq: 100
31
+ collect_iter: 100
32
+ data_interact_ep: 1
33
+ # data_init_ep: 1
34
+ data_init_ep: 5
35
+ horizon: 15
36
+ batch_size: 50
37
+ seq_len: 50
38
+ eval_eps: 3
39
+ eval_freq: 5
40
+
41
+ kl_divergence_scale : 1
42
+ free_nats : 3
43
+ discount : 0.99
44
+ lambda_ : 0.95
45
+
46
+ use_continue_flag : True
47
+ actor_lr : 3.0e-4
48
+ critic_lr : 3.0e-4
49
+ dyna_model_lr : 6.0e-4
50
+ grad_norm_type : 2
51
+ clip_grad : 100
52
+
53
+ hidden_units: 400
54
+ deterministic_size : 600
55
+ stochastic_size : 600
56
+ embedded_obs_size : 1024
57
+ buffer_capacity : 500000
58
+
59
+ epsilon_start: 0.4
60
+ epsilon_end: 0.1
61
+ eps_decay_steps: 200000
62
+
63
+ mean_noise: 0
64
+ std_noise: 0.3
gif/boxing.gif ADDED
gif/boxing_imagine.gif ADDED
gif/pacman.gif ADDED

Git LFS Details

  • SHA256: 0d138c415095804b0e565c041226613c3d6c9a82c06d00a7a5349ca5e94d557a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.6 MB
gif/pacman_imagine.gif ADDED

Git LFS Details

  • SHA256: 98462a09ec13b588685cde9f0bd69bfac4ab6d295e82d24c55e013a9e760f7a4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
gif/quadruped.gif ADDED

Git LFS Details

  • SHA256: 9d9fed8476008b30fc517c8fbbf43c795a719be6b5df439bc62f7146efb44bb9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.79 MB
gif/quadruped_imagine.gif ADDED
gif/walker_imagine.gif ADDED

Git LFS Details

  • SHA256: 91c7d7e0faad5d120829e4fced2baa65d7d37367e5b05b24c77caf9d1db48c9a
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
imagine.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ Author: Minh Pham-Dinh
3
+ Created: Feb 4th, 2024
4
+ Last Modified: Feb 6th, 2024
5
6
+
7
+ Description:
8
+ Imagination file. Run this file to generate dream sequences
9
+ """
10
+
11
+ import sys
12
+ import argparse
13
+ from utils.wrappers import DMCtoGymWrapper, AtariPreprocess
14
+ from addict import Dict
15
+ import yaml
16
+ import gymnasium as gym
17
+ import torch
18
+ from tqdm import tqdm
19
+ import numpy as np
20
+ import glob
21
+
22
+ parser = argparse.ArgumentParser(description='Process configuration file path.')
23
+ parser.add_argument('--runpath', type=str, help='Path to the run file.', required=True)
24
+ parser.add_argument('--horizon', type=int, help='number of imagination steps.', default=15)
25
+
26
+ # Parse the arguments
27
+ args = parser.parse_args()
28
+
29
+ # Load the configuration file specified by the command line argument
30
+ run_path = args.runpath
31
+ HORIZON = args.horizon
32
+
33
+ config_files = glob.glob(run_path + '/config/*.yml')
34
+
35
+ if len(config_files) != 1:
36
+ print('there should only be 1 config file in config directory')
37
+
38
+ with open(config_files[0], 'r') as file:
39
+ config = Dict(yaml.load(file, Loader=yaml.FullLoader))
40
+
41
+ env_id = config.env.env_id
42
+
43
+ if 'ALE' in config.env.env_id:
44
+ env = gym.make(env_id, render_mode='rgb_array')
45
+ env = AtariPreprocess(env, config.env.new_obs_size,
46
+ False)
47
+ else:
48
+ task = config.env.task
49
+ env = DMCtoGymWrapper(env_id, task,
50
+ resize=config.env.new_obs_size,
51
+ record=False)
52
+
53
+ print("start imagining")
54
+
55
+ encode = torch.load(run_path + '/models/encoder', map_location=torch.device('cpu') )
56
+ decoder = torch.load(run_path + '/models/decoder', map_location=torch.device('cpu') )
57
+ rssm = torch.load(run_path + '/models/rssm_model', map_location=torch.device('cpu') )
58
+ actor = torch.load(run_path + '/models/actor', map_location=torch.device('cpu'))
59
+
60
+ obs, _ = env.reset()
61
+
62
+ for i in range(100):
63
+ obs, _, _, _, _ = env.step(env.action_space.sample())
64
+
65
+ posterior = torch.zeros((1, config.main.stochastic_size))
66
+ deterministic = torch.zeros((1, config.main.deterministic_size))
67
+ e_obs = encode(torch.from_numpy(obs).to(dtype=torch.float))
68
+
69
+ _, posterior = rssm.representation(e_obs, deterministic)
70
+
71
+ from PIL import Image
72
+
73
+ frames = []
74
+
75
+ for i in tqdm(range(200)):
76
+ actions = actor(posterior, deterministic)
77
+ deterministic = rssm.recurrent(posterior, actions, deterministic)
78
+ dist, posterior = rssm.transition(deterministic)
79
+ d_obs = decoder(posterior, deterministic)
80
+ d_obs = d_obs.mean.squeeze().detach().numpy()
81
+ obs = ((d_obs.transpose([1,2,0]) + 0.5) * 255).clip(0, 255).astype(np.uint8)
82
+ img = Image.fromarray(obs, "RGB")
83
+ frames.append(img)
84
+
85
+ print("saving gif")
86
+ frame_one = frames[0]
87
+ frame_one.save(run_path + "/imagine.gif", format="GIF", append_images=frames, save_all=True, duration=30, loop=0)
88
+ print("finished")
requirements.txt ADDED
@@ -0,0 +1,143 @@
1
+ absl-py==2.1.0
2
+ addict==2.4.0
3
+ ale-py==0.8.1
4
+ anyio==4.2.0
5
+ appnope==0.1.3
6
+ argon2-cffi==23.1.0
7
+ argon2-cffi-bindings==21.2.0
8
+ arrow==1.3.0
9
+ async-lru==2.0.4
10
+ attrdict==2.0.1
11
+ attrs==23.2.0
12
+ AutoROM==0.4.2
13
+ AutoROM.accept-rom-license==0.6.1
14
+ Babel==2.14.0
15
+ beautifulsoup4==4.12.3
16
+ bleach==6.1.0
17
+ box2d-py==2.3.5
18
+ cachetools==5.3.2
19
+ certifi==2023.11.17
20
+ cffi==1.16.0
21
+ charset-normalizer==3.3.2
22
+ click==8.1.7
23
+ cloudpickle==3.0.0
24
+ comm==0.2.1
25
+ contourpy==1.2.0
26
+ cycler==0.12.1
27
+ debugpy==1.8.0
28
+ decorator==4.4.2
29
+ defusedxml==0.7.1
30
+ dm-tree==0.1.8
31
+ etils==1.6.0
32
+ Farama-Notifications==0.0.4
33
+ fastjsonschema==2.19.1
34
+ filelock==3.13.1
35
+ fonttools==4.47.2
36
+ fqdn==1.5.1
37
+ fsspec==2023.12.2
38
+ gast==0.5.4
39
+ glfw==2.6.5
40
+ google-auth==2.26.2
41
+ google-auth-oauthlib==1.2.0
42
+ grpcio==1.60.0
43
+ gymnasium==0.29.1
44
+ idna==3.6
45
+ imageio==2.33.1
46
+ imageio-ffmpeg==0.4.9
47
+ importlib-resources==6.1.1
48
+ ipykernel==6.29.0
49
+ ipywidgets==8.1.1
50
+ isoduration==20.11.0
51
+ Jinja2==3.1.3
52
+ json5==0.9.14
53
+ jsonpointer==2.4
54
+ jsonschema==4.21.0
55
+ jsonschema-specifications==2023.12.1
56
+ jupyter==1.0.0
57
+ jupyter-console==6.6.3
58
+ jupyter-events==0.9.0
59
+ jupyter-lsp==2.2.2
60
+ jupyter_client==8.6.0
61
+ jupyter_core==5.7.1
62
+ jupyter_server==2.12.5
63
+ jupyter_server_terminals==0.5.1
64
+ jupyterlab==4.0.10
65
+ jupyterlab-widgets==3.0.9
66
+ jupyterlab_pygments==0.3.0
67
+ jupyterlab_server==2.25.2
68
+ kiwisolver==1.4.5
69
+ lz4==4.3.3
70
+ Markdown==3.5.2
71
+ MarkupSafe==2.1.3
72
+ matplotlib==3.8.2
73
+ mistune==3.0.2
74
+ moviepy==1.0.3
75
+ mpmath==1.3.0
76
+ mujoco==3.1.1
77
+ nbclient==0.9.0
78
+ nbconvert==7.14.2
79
+ nbformat==5.9.2
80
+ nest-asyncio==1.5.9
81
+ networkx==3.2.1
82
+ notebook==7.0.6
83
+ notebook_shim==0.2.3
84
+ numpy==1.26.3
85
+ oauthlib==3.2.2
86
+ opencv-python==4.9.0.80
87
+ overrides==7.4.0
88
+ packaging==23.2
89
+ pandocfilters==1.5.1
90
+ pillow==10.2.0
91
+ platformdirs==4.1.0
92
+ proglog==0.1.10
93
+ prometheus-client==0.19.0
94
+ protobuf==4.23.4
95
+ psutil==5.9.7
96
+ pyasn1==0.5.1
97
+ pyasn1-modules==0.3.0
98
+ pycparser==2.21
99
+ pygame==2.5.2
100
+ PyOpenGL==3.1.7
101
+ pyparsing==3.1.1
102
+ python-dateutil==2.8.2
103
+ python-json-logger==2.0.7
104
+ PyYAML==6.0.1
105
+ pyzmq==25.1.2
106
+ qtconsole==5.5.1
107
+ QtPy==2.4.1
108
+ referencing==0.32.1
109
+ requests==2.31.0
110
+ requests-oauthlib==1.3.1
111
+ rfc3339-validator==0.1.4
112
+ rfc3986-validator==0.1.1
113
+ rpds-py==0.17.1
114
+ rsa==4.9
115
+ Send2Trash==1.8.2
116
+ Shimmy==0.2.1
117
+ sniffio==1.3.0
118
+ soupsieve==2.5
119
+ swig==4.1.1.post1
120
+ sympy==1.12
121
+ tensorboard==2.15.1
122
+ tensorboard-data-server==0.7.2
123
+ tensorflow-probability==0.23.0
124
+ terminado==0.18.0
125
+ tinycss2==1.2.1
126
+ tomli==2.0.1
127
+ torch==2.1.2
128
+ torchaudio==2.1.2
129
+ torchvision==0.16.2
130
+ wandb
131
+ tornado==6.4
132
+ tqdm==4.66.1
133
+ types-python-dateutil==2.8.19.20240106
134
+ typing_extensions==4.9.0
135
+ uri-template==1.3.0
136
+ urllib3==2.1.0
137
+ webcolors==1.13
138
+ webencodings==0.5.1
139
+ websocket-client==1.7.0
140
+ Werkzeug==3.0.1
141
+ widgetsnbextension==4.0.9
142
+ zipp==3.17.0
143
+ dm_control
utils/.DS_Store ADDED
Binary file (6.15 kB). View file
 
utils/__pycache__/buffer.cpython-310.pyc ADDED
Binary file (5.29 kB). View file
 
utils/__pycache__/models.cpython-310.pyc ADDED
Binary file (11.9 kB). View file
 
utils/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.53 kB). View file
 
utils/__pycache__/wrappers.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
utils/buffer.py ADDED
@@ -0,0 +1,139 @@
1
+ """
2
+ Author: Minh Pham-Dinh
3
+ Created: Jan 26th, 2024
4
+ Last Modified: Feb 5th, 2024
5
6
+
7
+ Description:
8
+ File containing the ReplayBuffer that will be used in Dreamer.
9
+
10
+ The implementation is based on:
11
+ Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
12
+ [Online]. Available: https://arxiv.org/abs/1912.01603
13
+ """
14
+
15
+ import numpy as np
16
+ from gymnasium import Env
17
+ import torch
18
+ from addict import Dict
19
+
20
+ class ReplayBuffer:
21
+ def __init__(self, capacity, obs_size, action_size):
22
+
23
+ # store the observation and action shapes
24
+ self.obs_size = obs_size
25
+ self.action_size = action_size
26
+
27
+ # from SimpleDreamer implementation, saving memory
28
+ state_type = np.uint8 if len(self.obs_size) < 3 else np.float32
29
+
30
+ self.observation = np.zeros((capacity, ) + self.obs_size, dtype=state_type)
31
+
32
+ self.actions = np.zeros((capacity, ) + self.action_size, dtype=np.float32)
33
+ self.rewards = np.zeros((capacity, 1), dtype=np.float32)
34
+ self.dones = np.zeros((capacity, 1), dtype=np.float32)
35
+
36
+ self.pointer = 0
37
+ self.full = False
38
+
39
+ print(f'''
40
+ -----------initialized memory----------
41
+
42
+ obs_buffer_shape: {self.observation.shape}
43
+ actions_buffer_shape: {self.actions.shape}
44
+ rewards_buffer_shape: {self.rewards.shape}
45
+ dones_buffer_shape: {self.dones.shape}
46
+
47
+ ----------------------------------------
48
+ ''')
49
+
50
+ def add(self, obs, action, reward, done):
51
+ """Add method for buffer
52
+
53
+ Args:
54
+ obs (np.array): current observation
55
+ action (np.array): action taken
56
+ reward (float): reward received after action
58
+ done (bool): boolean value of termination or truncation
59
+ """
60
+ self.observation[self.pointer] = obs
61
+ self.actions[self.pointer] = action
62
+ self.rewards[self.pointer] = reward
63
+ self.dones[self.pointer] = done
64
+ self.pointer = (self.pointer + 1) % self.observation.shape[0]
65
+ if self.pointer == 0:
66
+ self.full = True
67
+
68
+ def sample(self, batch_size, seq_len, device):
69
+ """
70
+ Samples batches of experiences of fixed sequence length from the replay buffer,
71
+ taking into account the circular nature of the buffer to avoid crossing the
72
+ "end" of the buffer when it is full.
73
+
74
+ This method ensures that sampled sequences are continuous and do not wrap around
75
+ the end of the buffer, maintaining the temporal integrity of experiences. This is
76
+ particularly important when the buffer is full, and the pointer marks the boundary
77
+ between the newest and oldest data in the buffer.
78
+
79
+ Args:
80
+ batch_size (int): The number of sequences to sample.
81
+ seq_len (int): The length of each sequence to sample.
82
+ device (torch.device): The device on which the sampled data will be loaded.
83
+
84
+ Raises:
85
+ Exception: If there is not enough data in the buffer to sample a full sequence.
86
+
87
+ Returns:
88
+ Dict: A dictionary containing the sampled sequences of observations, actions,
89
+ rewards, and dones. Each item in the dictionary is a tensor of shape
90
+ (batch_size, seq_len, feature_dimension), except for 'dones' which is of shape
91
+ (batch_size, seq_len, 1).
92
+
93
+ Notes:
94
+ - The method handles different scenarios based on the buffer's state (full or not)
95
+ and the pointer's position to ensure valid sequence sampling without wrapping.
96
+ - When the buffer is not full, sequences can start from index 0 up to the
97
+ index where `seq_len` sequences can fit without surpassing the current pointer.
98
+ - When the buffer is full, the method ensures sequences do not start in a way
99
+ that would cause them to wrap around past the pointer, effectively crossing
100
+ the boundary between the newest and oldest data.
101
+ - This approach guarantees the sampled sequences respect the temporal order
102
+ and continuity necessary for algorithms that rely on sequences of experiences.
103
+ """
104
+
105
+ # Ensure there's enough data to sample
106
+ if self.pointer < seq_len and not self.full:
107
+ raise Exception('not enough data to sample')
108
+
109
+ # detail: handling different cases for circular sampling
110
+ if self.full:
111
+ if self.pointer - seq_len < 0:
112
+ valid_range = np.arange(self.pointer, self.observation.shape[0] - (self.pointer - seq_len) + 1)
113
+ else:
114
+ range_1 = np.arange(0, self.pointer - seq_len + 1)
115
+ range_2 = np.arange(self.pointer, self.observation.shape[0])
116
+ valid_range = np.concatenate((range_1, range_2), -1)
117
+ else:
118
+ valid_range = np.arange(0, self.pointer-seq_len+1)
119
+
120
+ start_index = np.random.choice(valid_range, (batch_size, 1))
121
+
122
+ seq_len = np.arange(seq_len)
123
+ sample_idcs = (start_index + seq_len) % self.observation.shape[0]
124
+
125
+ batch = Dict()
126
+
127
+ batch.obs = torch.from_numpy(self.observation[sample_idcs]).to(device)
128
+ batch.actions = torch.from_numpy(self.actions[sample_idcs]).to(device)
129
+ batch.rewards = torch.from_numpy(self.rewards[sample_idcs]).to(device)
130
+ batch.dones = torch.from_numpy(self.dones[sample_idcs]).to(device)
131
+
132
+ return batch
133
+
134
+ def clear(self, ):
135
+ self.pointer = 0
136
+ self.full = False
137
+
138
+ def __len__(self, ):
139
+ return self.pointer
utils/models.py ADDED
@@ -0,0 +1,440 @@
1
+ """
2
+ Author: Minh Pham-Dinh
3
+ Created: Jan 26th, 2024
4
+ Last Modified: Feb 10th, 2024
5
6
+
7
+ Description:
8
+ File containing all models that will be used in Dreamer.
9
+
10
+ The implementation is based on:
11
+ Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
12
+ [Online]. Available: https://arxiv.org/abs/1912.01603
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import numpy as np
19
+
20
+ def initialize_weights(m):
21
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
22
+ nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu")
23
+ nn.init.constant_(m.bias.data, 0)
24
+ elif isinstance(m, nn.Linear):
25
+ nn.init.kaiming_uniform_(m.weight.data)
26
+ nn.init.constant_(m.bias.data, 0)
27
+
28
+
29
+ class RSSM(nn.Module):
30
+ """Reccurent State Space Model (RSSM)
31
+ The main model that we will use to learn the latent dynamic of the environment
32
+ """
33
+ def __init__(self, stochastic_size, obs_embed_size, deterministic_size, hidden_size, action_size, activation=nn.ELU):
34
+ super().__init__()
35
+ self.stochastic_size = stochastic_size
36
+ self.action_size = action_size
37
+ self.deterministic_size = deterministic_size
38
+ self.obs_embed_size = obs_embed_size
39
+ self.action_size = action_size
40
+
41
+ # recurrent
42
+ self.recurrent_linear = nn.Sequential(
43
+ nn.Linear(stochastic_size + action_size, hidden_size),
44
+ activation(),
45
+ )
46
+ self.gru_cell = nn.GRUCell(hidden_size, deterministic_size)
47
+
48
+ # representation model, for calculating posterior
49
+ self.representatio_model = nn.Sequential(
50
+ nn.Linear(deterministic_size + obs_embed_size, hidden_size),
51
+ activation(),
52
+ nn.Linear(hidden_size, stochastic_size*2)
53
+ )
54
+
55
+ # transition model, for calculating the prior; used for imagining trajectories
56
+ self.transition_model = nn.Sequential(
57
+ nn.Linear(deterministic_size, hidden_size),
58
+ activation(),
59
+ nn.Linear(hidden_size, stochastic_size*2)
60
+ )
61
+
62
+
63
+
64
+ def recurrent(self, stoch_state, action, deterministic):
65
+ """The recurrent model, calculate the deterministic state given the stochastic state
66
+ the action, and the prior deterministic
67
+
68
+ Args:
69
+ a_t-1 (batch_size, action_size): action at time step, cannot be None.
70
+ s_t-1 (batch_size, stoch_size): stochastic state at time step. Defaults to None.
71
+ h_t-1 (batch_size, deterministic_size): deterministic at timestep. Defaults to None.
72
+
73
+ Returns:
74
+ h_t: deterministic at next time step
75
+ """
76
+
77
+ # initialize some sizes
78
+ x = torch.cat((action, stoch_state), -1)
79
+ out = self.recurrent_linear(x)
80
+ out = self.gru_cell(out, deterministic)
81
+ return out
82
+
83
+
84
+ def representation(self, embed_obs, deterministic):
85
+ """Calculate the distribution p of the stochastic state.
86
+
87
+ Args:
88
+ o_t (batch_size, embeded_obs_size): embedded observation (encoded)
89
+ h_t (batch_size, deterministic_size): deterministic state
90
+
91
+ Returns:
92
+ s_t posterior_distribution: distribution of stochastic states
93
+ s_t posterior: sampled stochastic states
94
+ """
95
+ x = torch.cat((embed_obs, deterministic), -1)
96
+ out = self.representatio_model(x)
97
+ mean, std = torch.chunk(out, 2, -1)
98
+ std = F.softplus(std) + 0.1
99
+
100
+ post_dist = torch.distributions.Normal(mean, std)
101
+ post = post_dist.rsample()
102
+
103
+ return post_dist, post
104
+
105
+
106
+ def transition(self, deterministic):
107
+ """Calculate the distribution q of the stochastic state.
108
+
109
+ Args:
110
+ h_t (batch_size, deterministic_size): deterministic state
111
+
112
+ Returns:
113
+ s_t prior_distribution: distribution of stochastic states
114
+ s_t prior: sampled stochastic states
115
+ """
116
+ out = self.transition_model(deterministic)
117
+ mean, std = torch.chunk(out, 2, -1)
118
+ std = F.softplus(std) + 0.1
119
+
120
+ prior_dist = torch.distributions.Normal(mean, std)
121
+ prior = prior_dist.rsample()
122
+ return prior_dist, prior
123
+
124
+
125
+ class ConvEncoder(nn.Module):
+     """Convolutional encoder that maps pixel observations to a flat embedding."""
+     def __init__(self, depth=32, input_shape=(3, 64, 64), activation=nn.ReLU):
+         super().__init__()
+         self.depth = depth
+         self.input_shape = input_shape
+         self.conv_layer = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=input_shape[0],
+                 out_channels=depth * 1,
+                 kernel_size=4,
+                 stride=2,
+                 padding="valid"
+             ),
+             activation(),
+             nn.Conv2d(
+                 in_channels=depth * 1,
+                 out_channels=depth * 2,
+                 kernel_size=4,
+                 stride=2,
+                 padding="valid"
+             ),
+             activation(),
+             nn.Conv2d(
+                 in_channels=depth * 2,
+                 out_channels=depth * 4,
+                 kernel_size=4,
+                 stride=2,
+                 padding="valid"
+             ),
+             activation(),
+             nn.Conv2d(
+                 in_channels=depth * 4,
+                 out_channels=depth * 8,
+                 kernel_size=4,
+                 stride=2,
+                 padding="valid"
+             ),
+             activation()
+         )
+         self.conv_layer.apply(initialize_weights)
+
+     def forward(self, x):
+         # separate the (possibly multi-dimensional) batch shape from the image shape
+         batch_shape = x.shape[:-len(self.input_shape)]
+         if not batch_shape:
+             batch_shape = (1, )
+
+         x = x.reshape(-1, *self.input_shape)
+
+         out = self.conv_layer(x)
+
+         # flatten the feature maps, restoring the original batch shape
+         return out.reshape(*batch_shape, -1)
+
+
+ class ConvDecoder(nn.Module):
+     """Decode the latent state back into a pixel observation.
+     Referred to as the observation model in the Dreamer paper.
+     """
+     def __init__(self, stochastic_size, deterministic_size, depth=32, out_shape=(3, 64, 64), activation=nn.ReLU):
+         super().__init__()
+         self.out_shape = out_shape
+         self.net = nn.Sequential(
+             nn.Linear(deterministic_size + stochastic_size, depth * 32),
+             # reshape the latent vector into a 1x1 feature map for the transposed convolutions
+             nn.Unflatten(1, (depth * 32, 1)),
+             nn.Unflatten(2, (1, 1)),
+             nn.ConvTranspose2d(
+                 depth * 32,
+                 depth * 4,
+                 kernel_size=5,
+                 stride=2,
+             ),
+             activation(),
+             nn.ConvTranspose2d(
+                 depth * 4,
+                 depth * 2,
+                 kernel_size=5,
+                 stride=2,
+             ),
+             activation(),
+             nn.ConvTranspose2d(
+                 depth * 2,
+                 depth * 1,
+                 kernel_size=5 + 1,
+                 stride=2,
+             ),
+             activation(),
+             nn.ConvTranspose2d(
+                 depth * 1,
+                 out_shape[0],
+                 kernel_size=5 + 1,
+                 stride=2,
+             ),
+         )
+         self.net.apply(initialize_weights)
+
+     def forward(self, posterior, deterministic, mps_flatten=False):
+         """Concatenate the stochastic state (posterior) and the deterministic state into the latent
+         state, then output a distribution over the reconstructed pixel observation.
+
+         Args:
+             posterior (batch_sz, stochastic_size): stochastic state s_t (the posterior)
+             deterministic (batch_sz, deterministic_size): deterministic state h_t
+             mps_flatten (bool): whether to flatten the batch dimensions of the output. Needed on
+                 Apple MPS devices, which only support tensors with up to 4 dimensions.
+
+         Returns:
+             o'_t: distribution over the reconstructed observation
+         """
+         x = torch.cat((posterior, deterministic), -1)
+         batch_shape = x.shape[:-1]
+         if not batch_shape:
+             batch_shape = (1, )
+
+         x = x.reshape(-1, x.shape[-1])
+
+         if mps_flatten:
+             batch_shape = (-1, )
+
+         mean = self.net(x).reshape(*batch_shape, *self.out_shape)
+
+         # unit-variance Normal per pixel; the image dimensions form a single event
+         dist = torch.distributions.Normal(mean, 1)
+         return torch.distributions.Independent(dist, len(self.out_shape))
+
+
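Because the decoder returns a distribution rather than raw pixels, the image-reconstruction term of the world-model loss reduces to a log-likelihood. A minimal sketch with made-up sizes, assuming `decoder` is an instance of the class above:

```python
import torch

# Made-up sizes; `decoder` is assumed to be a ConvDecoder with matching
# stochastic_size=30 and deterministic_size=200.
batch, stoch, deter = 4, 30, 200
posterior = torch.zeros(batch, stoch)
deterministic = torch.zeros(batch, deter)
obs = torch.zeros(batch, 3, 64, 64)

obs_dist = decoder(posterior, deterministic)          # Independent Normal over (3, 64, 64)
reconstruction_loss = -obs_dist.log_prob(obs).mean()  # one scalar log-prob per batch element
```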
+ class RewardNet(nn.Module):
+     """Reward prediction model. It concatenates the stochastic state and the deterministic state
+     into the latent state and outputs a reward prediction.
+     """
+     def __init__(self, input_size, hidden_size, activation=nn.ELU):
+         super().__init__()
+
+         self.net = nn.Sequential(
+             nn.Linear(input_size, hidden_size),
+             activation(),
+             nn.Linear(hidden_size, 1)
+         )
+
+     def forward(self, stoch_state, deterministic):
+         """Concatenate the stochastic state and the deterministic state into the latent state and
+         output the predicted reward.
+
+         Args:
+             stoch_state (batch_sz, stochastic_size): stochastic state s_t (the posterior)
+             deterministic (batch_sz, deterministic_size): deterministic state h_t
+
+         Returns:
+             r_t: predicted rewards
+         """
+         x = torch.cat((stoch_state, deterministic), -1)
+         batch_shape = x.shape[:-1]
+         if not batch_shape:
+             batch_shape = (1, )
+
+         x = x.reshape(-1, x.shape[-1])
+
+         return self.net(x).reshape(*batch_shape, 1)
+
+
+ class ContinuoNet(nn.Module):
+     """Continuation prediction model. It concatenates the stochastic state and the deterministic
+     state into the latent state and predicts whether the termination state has been reached.
+     """
+     def __init__(self, input_size, hidden_size, activation=nn.ELU):
+         super().__init__()
+
+         self.net = nn.Sequential(
+             nn.Linear(input_size, hidden_size),
+             activation(),
+             nn.Linear(hidden_size, hidden_size),
+             activation(),
+             nn.Linear(hidden_size, 1)
+         )
+
+     def forward(self, stoch_state, deterministic):
+         """Concatenate the stochastic state and the deterministic state into the latent state and
+         output the termination prediction.
+
+         Args:
+             stoch_state (batch_sz, stochastic_size): stochastic state s_t (the posterior)
+             deterministic (batch_sz, deterministic_size): deterministic state h_t
+
+         Returns:
+             logits and the Bernoulli distribution over episode termination
+         """
+         x = torch.cat((stoch_state, deterministic), -1)
+         batch_shape = x.shape[:-1]
+         if not batch_shape:
+             batch_shape = (1, )
+
+         x = x.reshape(-1, x.shape[-1])
+
+         x = self.net(x).reshape(*batch_shape, 1)
+         return x, torch.distributions.Independent(torch.distributions.Bernoulli(logits=x), 1)
+
+
+ class Actor(nn.Module):
+     """Actor network."""
+     def __init__(self,
+                  latent_size,
+                  hidden_size,
+                  action_size,
+                  discrete=True,
+                  activation=nn.ELU,
+                  min_std=1e-4,
+                  init_std=5,
+                  mean_scale=5):
+
+         super().__init__()
+         self.latent_size = latent_size
+         self.hidden_size = hidden_size
+         self.action_size = (action_size if discrete else action_size * 2)
+         self.discrete = discrete
+         self.min_std = min_std
+         self.init_std = init_std
+         self.mean_scale = mean_scale
+
+         self.net = nn.Sequential(
+             nn.Linear(latent_size, hidden_size),
+             activation(),
+             nn.Linear(hidden_size, self.action_size)
+         )
+
+     def forward(self, stoch_state, deterministic):
+         """Concatenate the stochastic state and the deterministic state into the latent state and
+         use it to predict an action.
+
+         Args:
+             stoch_state (batch_sz, stochastic_size): stochastic state s_t (the posterior)
+             deterministic (batch_sz, deterministic_size): deterministic state h_t
+
+         Returns:
+             action sample: a one-hot action with straight-through gradients if discrete,
+             otherwise a sample from a tanh-transformed Normal
+         """
+         latent_state = torch.cat((stoch_state, deterministic), -1)
+         x = self.net(latent_state)
+
+         if self.discrete:
+             # straight-through gradients (as introduced in DreamerV2)
+             dist = torch.distributions.OneHotCategorical(logits=x)
+             action = dist.sample() + dist.probs - dist.probs.detach()
+         else:
+             # shift the std input so that softplus yields init_std at initialization
+             raw_init_std = np.log(np.exp(self.init_std) - 1)
+
+             mean, std = torch.chunk(x, 2, -1)
+             mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
+             std = F.softplus(std + raw_init_std) + self.min_std
+
+             dist = torch.distributions.Normal(mean, std)
+             dist = torch.distributions.TransformedDistribution(dist, torch.distributions.TanhTransform())
+             action = torch.distributions.Independent(dist, 1).rsample()
+
+         return action
+
+
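The `dist.sample() + dist.probs - dist.probs.detach()` expression above is the straight-through estimator: the forward value is exactly the one-hot sample, while gradients flow back through the differentiable probabilities. A small self-contained illustration (independent of this repository):

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
dist = torch.distributions.OneHotCategorical(logits=logits)

# forward value is exactly the one-hot sample ...
action = dist.sample() + dist.probs - dist.probs.detach()

# ... yet a downstream objective still produces gradients w.r.t. the logits,
# because the backward pass goes through dist.probs rather than sample().
pseudo_value = (action * torch.tensor([[1.0, 2.0, 3.0]])).sum()
pseudo_value.backward()
print(action)       # e.g. tensor([[1., 0., 0.]])
print(logits.grad)  # non-zero gradient supplied by the straight-through estimator
```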
+ class Critic(nn.Module):
+     """Critic network."""
+     def __init__(self, latent_size, hidden_size, activation=nn.ELU):
+         super().__init__()
+         self.latent_size = latent_size
+
+         self.net = nn.Sequential(
+             nn.Linear(latent_size, hidden_size),
+             activation(),
+             nn.Linear(hidden_size, hidden_size),
+             activation(),
+             nn.Linear(hidden_size, 1)
+         )
+
+     def forward(self, stoch_state, deterministic):
+         """Concatenate the stochastic state and the deterministic state into the latent state and
+         use it to predict the state value.
+
+         Args:
+             stoch_state (batch_sz, seq_len, stochastic_size): stochastic state s_t (the posterior)
+             deterministic (batch_sz, seq_len, deterministic_size): deterministic state h_t
+
+         Returns:
+             state-value distribution
+         """
+         latent_state = torch.cat((stoch_state, deterministic), -1)
+
+         batch_shape = latent_state.shape[:-1]
+         if not batch_shape:
+             batch_shape = (1, )
+
+         latent_state = latent_state.reshape(-1, self.latent_size)
+
+         x = self.net(latent_state)
+
+         x = x.reshape(*batch_shape, 1)
+
+         # unit-variance Normal so the value loss becomes a log-likelihood
+         dist = torch.distributions.Normal(x, 1)
+         dist = torch.distributions.Independent(dist, 1)
+
+         return dist
+
utils/utils.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+
+ def log_metrics(metrics, step, tb_writer, wandb_writer):
+     # Log metrics to TensorBoard
+     if tb_writer:
+         for key, value in metrics.items():
+             tb_writer.add_scalar(key, value, step)
+
+     # Log metrics to wandb
+     # if wandb_writer:
+     #     wandb_writer.log(metrics, step=step)
+
+
+ def td_lambda(rewards, predicted_discount, values, lambda_, device):
+     """
+     Compute the TD(λ) returns for value estimation.
+
+     Args:
+     - rewards (Tensor): rewards with shape [batch_size, horizon_len, 1].
+     - predicted_discount (Tensor): per-step discounts (e.g. the predicted continuation
+       probability scaled by gamma) with shape [batch_size, horizon_len, 1].
+     - values (Tensor): value estimates with shape [batch_size, horizon_len, 1].
+     - lambda_ (float): the λ parameter of TD(λ), controlling the bias-variance tradeoff.
+     - device (torch.device or str): device on which the returns are computed.
+
+     Returns:
+     - returns (Tensor): the computed λ-returns with shape [batch_size, horizon_len - 1, 1].
+     """
+     batch_size, _, _ = rewards.shape
+     last_lambda = torch.zeros((batch_size, 1)).to(device)
+     cur_rewards = rewards[:, :-1]
+     next_values = values[:, 1:]
+     predicted_discount = predicted_discount[:, :-1]
+
+     # one-step targets; the λ-weighted bootstrap is added in the backward recursion below
+     td_1 = cur_rewards + predicted_discount * next_values * (1 - lambda_)
+     returns = torch.zeros_like(cur_rewards).to(device)
+
+     for i in reversed(range(td_1.size(1))):
+         last_lambda = td_1[:, i] + predicted_discount[:, i] * lambda_ * last_lambda
+         returns[:, i] = last_lambda
+
+     return returns
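For reference, the loop above evaluates G_t = r_t + d_t [(1 − λ) V_{t+1} + λ G_{t+1}] backwards over the horizon, with the running return initialised to zero beyond the last step. A toy call, with placeholder shapes chosen only to show the expected tensor layout:

```python
import torch

# Toy shapes only: a batch of 2 imagined trajectories with horizon 5.
B, H = 2, 5
rewards = torch.rand(B, H, 1)
discounts = torch.full((B, H, 1), 0.99)   # predicted continuation probability * gamma
values = torch.rand(B, H, 1)

returns = td_lambda(rewards, discounts, values, lambda_=0.95, device="cpu")
print(returns.shape)  # torch.Size([2, 4, 1]) — one λ-return per step except the last
```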
utils/wrappers.py ADDED
@@ -0,0 +1,265 @@
+ """
+ Author: Minh Pham-Dinh
+ Created: Feb 4th, 2024
+ Last Modified: Feb 7th, 2024
+
+ Description:
+     File containing wrappers for different environment types.
+ """
+
+ import gymnasium as gym
+ from dm_control import suite
+ from dm_control.suite.wrappers import pixels
+ import numpy as np
+ import cv2
+ import os
+ from dm_control.rl.control import Environment
+
+
+ # wrapper by Hafner et al.
+ class ActionRepeat:
+     """Repeat each action for a fixed number of steps, accumulating the reward."""
+     def __init__(self, env, repeats):
+         self.env = env
+         self.repeats = repeats
+
+     def __getattr__(self, name):
+         return getattr(self.env, name)
+
+     def step(self, action):
+         done = False
+         total_reward = 0
+         current_step = 0
+         while current_step < self.repeats and not done:
+             obs, reward, termination, truncation, info = self.env.step(action)
+             total_reward += reward
+             current_step += 1
+             done = termination or truncation
+         return obs, total_reward, termination, truncation, info
+
+
+ # wrapper by Hafner et al.
+ class NormalizeActions:
+     """
+     A wrapper class that normalizes the action space of an environment.
+
+     Args:
+         env (gym.Env): The environment to be wrapped.
+
+     Attributes:
+         _env (gym.Env): The original environment.
+         _mask (numpy.ndarray): A boolean mask indicating which action dimensions are finite.
+         _low (numpy.ndarray): The lower bounds of the action space.
+         _high (numpy.ndarray): The upper bounds of the action space.
+     """
+
+     def __init__(self, env):
+         self._env = env
+         self._mask = np.logical_and(
+             np.isfinite(env.action_space.low),
+             np.isfinite(env.action_space.high))
+         self._low = np.where(self._mask, env.action_space.low, -1)
+         self._high = np.where(self._mask, env.action_space.high, 1)
+
+     def __getattr__(self, name):
+         """
+         Delegate attribute access to the original environment.
+
+         Args:
+             name (str): The name of the attribute.
+
+         Returns:
+             Any: The value of the attribute in the original environment.
+         """
+         return getattr(self._env, name)
+
+     @property
+     def action_space(self):
+         """
+         Get the normalized action space.
+
+         Returns:
+             gym.spaces.Box: The normalized action space.
+         """
+         low = np.where(self._mask, -np.ones_like(self._low), self._low)
+         high = np.where(self._mask, np.ones_like(self._low), self._high)
+         return gym.spaces.Box(low, high, dtype=np.float32)
+
+     def step(self, action):
+         """
+         Take a step in the environment with a normalized action.
+
+         Args:
+             action (numpy.ndarray): The normalized action, with finite dimensions in [-1, 1].
+
+         Returns:
+             Tuple: the next observation, reward, termination and truncation flags, and info dict.
+         """
+         # rescale from [-1, 1] back to the environment's original bounds
+         original = (action + 1) / 2 * (self._high - self._low) + self._low
+         original = np.where(self._mask, original, action)
+         return self._env.step(original)
+
+
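To make the affine rescaling in `NormalizeActions.step` concrete, here is a small standalone check with a made-up action range (not tied to any particular environment):

```python
import numpy as np

low, high = np.array([-2.0, 0.0]), np.array([2.0, 10.0])
action = np.array([0.0, 1.0])                      # normalized action in [-1, 1]

original = (action + 1) / 2 * (high - low) + low
print(original)                                    # [ 0. 10.] — midpoint of the first range, top of the second
```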
+ class DMCtoGymWrapper(gym.Env):
+     """
+     Wrapper that converts a DeepMind Control Suite environment into a Gymnasium environment, with optional episode recording and truncation.
+
+     Args:
+         domain_name (str): The name of the domain.
+         task_name (str): The name of the task.
+         task_kwargs (dict, optional): Additional kwargs for the task.
+         visualize_reward (bool, optional): Whether to visualize the reward. Defaults to False.
+         resize (list, optional): Size to which observations are rendered. Defaults to [64, 64].
+         record (bool, optional): Whether to record episodes. Defaults to False.
+         record_freq (int, optional): Frequency (in episodes) at which to record. Defaults to 100.
+         record_path (str, optional): Path where recorded videos are saved. Defaults to '../'.
+         max_episode_steps (int, optional): Maximum steps per episode before truncation. Defaults to 1000.
+         camera (int, optional): Camera id used for rendering. Defaults to a per-domain choice.
+     """
+     def __init__(self, domain_name, task_name, task_kwargs=None, visualize_reward=False, resize=[64, 64], record=False, record_freq=100, record_path='../', max_episode_steps=1000, camera=None):
+         super().__init__()
+         self.env = suite.load(domain_name, task_name, task_kwargs=task_kwargs, visualize_reward=visualize_reward)
+         self.episode_count = -1
+         self.record = record
+         self.record_freq = record_freq
+         self.record_path = record_path
+         self.max_episode_steps = max_episode_steps
+         self.current_step = 0
+         self.total_reward = 0
+         self.recorder = None
+
+         # Define the action space based on the DMC environment
+         action_spec = self.env.action_spec()
+         self.action_space = gym.spaces.Box(low=action_spec.minimum, high=action_spec.maximum, dtype=np.float32)
+
+         # Render pixel observations and expose them through the observation space
+         self.env = pixels.Wrapper(self.env, pixels_only=True)
+         self.resize = resize  # assuming RGB images
+         self.observation_space = gym.spaces.Box(low=-0.5, high=0.5, shape=(3, *resize), dtype=np.float32)
+
+         if camera is None:
+             camera = dict(quadruped=2).get(domain_name, 0)
+         self._camera = camera
+
+     def step(self, action):
+         time_step = self.env.step(action)
+         obs = self._get_obs(self.env)
+
+         reward = time_step.reward if time_step.reward is not None else 0
+         self.total_reward += reward
+         self.current_step += 1
+
+         termination = time_step.last()
+         truncation = (self.current_step == self.max_episode_steps)
+         info = {}
+         if termination or truncation:
+             info = {
+                 'episode': {
+                     'r': [self.total_reward],
+                     'l': self.current_step
+                 }
+             }
+
+         if self.recorder:
+             frame = cv2.cvtColor(self.env.physics.render(camera_id=self._camera), cv2.COLOR_RGB2BGR)
+             self.recorder.write(frame)
+             video_file = os.path.join(self.record_path, f"episode_{self.episode_count}.webm")
+             if termination or truncation:
+                 self._reset_recorder()
+                 info['video_path'] = video_file
+
+         return obs, reward, termination, truncation, info
+
+     def reset(self):
+         self.current_step = 0
+         self.total_reward = 0
+         self.episode_count += 1
+
+         time_step = self.env.reset()
+         obs = self._get_obs(self.env)
+
+         if self.record and self.episode_count % self.record_freq == 0:
+             self._start_recording(self.env.physics.render(camera_id=self._camera))
+
+         return obs, {}
+
+     def _start_recording(self, frame):
+         if not os.path.exists(self.record_path):
+             os.makedirs(self.record_path)
+         video_file = os.path.join(self.record_path, f"episode_{self.episode_count}.webm")
+         height, width, _ = frame.shape
+         self.recorder = cv2.VideoWriter(video_file, cv2.VideoWriter_fourcc(*'vp80'), 30, (width, height))
+         self.recorder.write(frame)
+
+     def _reset_recorder(self):
+         if self.recorder:
+             self.recorder.release()
+             self.recorder = None
+
+     def _get_obs(self, env):
+         # render, scale pixels to [-0.5, 0.5], and move channels first (CHW)
+         obs = self.render()
+         obs = obs / 255 - 0.5
+         rearranged_obs = obs.transpose([2, 0, 1])
+         return rearranged_obs
+
+     def render(self, mode='rgb_array'):
+         # adjust camera_id based on the environment
+         return self.env.physics.render(*self.resize, camera_id=self._camera)
+
+
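A minimal sketch of how these wrappers might be stacked for a control-suite task; the domain and task names, action-repeat value, and other arguments are illustrative placeholders rather than the values used by the training configs:

```python
# Illustrative only: wrapper stack for a DeepMind Control task.
env = DMCtoGymWrapper(
    domain_name="walker",
    task_name="walk",
    resize=[64, 64],
    record=False,
    max_episode_steps=1000,
)
env = ActionRepeat(env, repeats=2)   # Dreamer uses an action repeat of 2 on DMC
env = NormalizeActions(env)

obs, info = env.reset()              # obs: float array of shape (3, 64, 64) in [-0.5, 0.5]
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```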
+ class AtariPreprocess(gym.Wrapper):
+     """
+     A custom Gym wrapper that integrates multiple environment processing steps:
+     - Records episode statistics and videos.
+     - Resizes observations to a specified shape.
+     - Scales and reorders observation channels.
+     - Scales rewards using the tanh function.
+
+     Parameters:
+     - env (gym.Env): The original environment to wrap.
+     - new_obs_size (tuple): The target size for observation resizing (height, width).
+     - record (bool): If True, enable video recording.
+     - record_path (str): The directory path where videos will be saved.
+     - record_freq (int): Frequency (in episodes) at which to record videos.
+     """
+     def __init__(self, env, new_obs_size, record=False, record_path='../videos/', record_freq=100):
+         super().__init__(env)
+         self.env = gym.wrappers.RecordEpisodeStatistics(env)
+
+         if record:
+             self.env = gym.wrappers.RecordVideo(self.env, record_path, episode_trigger=lambda episode_id: episode_id % record_freq == 0)
+         self.env = gym.wrappers.ResizeObservation(self.env, shape=new_obs_size)
+
+         self.new_obs_size = new_obs_size
+         self.observation_space = gym.spaces.Box(
+             low=-0.5, high=0.5,
+             shape=(3, new_obs_size[0], new_obs_size[1]),
+             dtype=np.float32
+         )
+
+     def step(self, action):
+         obs, reward, termination, truncation, info = super().step(action)
+         obs = self.process_observation(obs)
+         reward = np.tanh(reward)  # scale reward
+         return obs, reward, termination, truncation, info
+
+     def reset(self, **kwargs):
+         obs, info = super().reset(**kwargs)
+         obs = self.process_observation(obs)
+         return obs, info
+
+     def process_observation(self, observation):
+         """
+         Process and return the observation from the environment.
+         - Scales pixel values to the range [-0.5, 0.5].
+         - Reorders channels to CHW format (channels, height, width).
+
+         Parameters:
+         - observation (np.ndarray): The original observation from the environment.
+
+         Returns:
+         - np.ndarray: The processed observation.
+         """
+         if isinstance(observation, dict) and 'pixels' in observation:
+             observation = observation['pixels']
+         observation = observation / 255.0 - 0.5
+         observation = np.transpose(observation, (2, 0, 1))
+         return observation
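And a matching sketch for the Atari path, assuming the Arcade Learning Environment is installed; the environment id and arguments below are placeholders for illustration, not the training scripts' settings:

```python
import gymnasium as gym
import numpy as np

# Illustrative only: a pixel-based Atari env fed through AtariPreprocess.
env = gym.make("ALE/Boxing-v5", render_mode="rgb_array")
env = AtariPreprocess(env, new_obs_size=(64, 64))

obs, info = env.reset()
print(obs.shape, obs.min() >= -0.5, obs.max() <= 0.5)   # (3, 64, 64) True True
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```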