Upload 30 files
Browse files- .gitattributes +4 -0
- README.md +75 -0
- algos/.DS_Store +0 -0
- algos/dreamer.py +489 -0
- bash/setup.sh +32 -0
- configs/.DS_Store +0 -0
- configs/dm_control/Cart-pole.yml +64 -0
- configs/dm_control/Quadruped.yml +64 -0
- configs/dm_control/Walker.yml +65 -0
- configs/gymnasium/Boxing-v5.yml +67 -0
- configs/gymnasium/Pacman-v5.yml +66 -0
- configs/gymnasium/ant_v4.yml +65 -0
- configs/gymnasium/car_racing_config.yml +64 -0
- gif/boxing.gif +0 -0
- gif/boxing_imagine.gif +0 -0
- gif/pacman.gif +3 -0
- gif/pacman_imagine.gif +3 -0
- gif/quadruped.gif +3 -0
- gif/quadruped_imagine.gif +0 -0
- gif/walker_imagine.gif +3 -0
- imagine.py +88 -0
- requirements.txt +143 -0
- utils/.DS_Store +0 -0
- utils/__pycache__/buffer.cpython-310.pyc +0 -0
- utils/__pycache__/models.cpython-310.pyc +0 -0
- utils/__pycache__/utils.cpython-310.pyc +0 -0
- utils/__pycache__/wrappers.cpython-310.pyc +0 -0
- utils/buffer.py +139 -0
- utils/models.py +440 -0
- utils/utils.py +41 -0
- utils/wrappers.py +265 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
gif/pacman_imagine.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
gif/pacman.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
gif/quadruped.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
gif/walker_imagine.gif filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PyDreamerV1: Clean pytorch implementation of Hafner et al Dreamer
|
2 |
+
<div align="center">
|
3 |
+
<img src="./gif/boxing.gif" alt="Actual run in " width="200px" height="200px"/>
|
4 |
+
<img src="./gif/quadruped.gif" alt="Actual run in " width="200px" height="200px"/>
|
5 |
+
<img src="./gif/walker.gif" alt="Actual run in " width="200px" height="200px"/>
|
6 |
+
</div>
|
7 |
+
<div align="center">
|
8 |
+
<img src="./gif/boxing_imagine.gif" alt="Imagination in " width="200px" height="200px"/>
|
9 |
+
<img src="./gif/quadruped_imagine.gif" alt="Imagination in " width="200px" height="200px"/>
|
10 |
+
<img src="./gif/walker_imagine.gif" alt="Imagination in " width="200px" height="200px"/>
|
11 |
+
</div>
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
This repository offers a comprehensive implementation of the Dreamer algorithm, as presented in the groundbreaking work by Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination." Our implementation is dedicated to faithfully reproducing the innovative approach of learning and planning within a learned latent space, enabling agents to efficiently master complex behaviors through imagination alone.
|
16 |
+
|
17 |
+
## Why Dreamer?
|
18 |
+
|
19 |
+
Dreamer stands at the forefront of model-based reinforcement learning by introducing an efficient method for learning behaviors directly from high-dimensional sensory inputs. It leverages a latent dynamics model to 'imagine' future states and rewards, enabling it to plan and execute actions that maximize long-term rewards purely from simulated experiences. This approach significantly improves sample efficiency over traditional model-free methods and opens new avenues for learning complex and nuanced behaviors in simulated environments. However, the official code was unfortunately regarded as complex and difficult to understand, and there are only a handful of Dreamer reimplementation that was able to reproduce the results.
|
20 |
+
|
21 |
+
## Implementation Highlights
|
22 |
+
|
23 |
+
- **Modular Design**: My implementation of the Recurrent State Space Model (RSSM) is broken down into cleanly separated modules for the transition, representation, and recurrent models. This not only facilitates a deeper understanding of the underlying mechanics but also allows for easy customization and extension.
|
24 |
+
|
25 |
+
- **True to the Source**: By closely adhering to the methodologies detailed in the original DreamerV1 paper, the code captures the essence of latent space learning and imagination-driven planning. From the incorporation of exploration noise to the td lambda calculation, every element is designed to replicate the paper's results as closely as possible. The sets of hyperparamenters are excactly indentical to the sets mentioned in the paper
|
26 |
+
|
27 |
+
- **Detailed Training Insights**: The training loop is separated and mirroring the paper's outline. Comprehensive comments of hidden implementation details thorough documentation accompany the code, serving as a valuable resource for both learning and further research.
|
28 |
+
|
29 |
+
## Getting Started
|
30 |
+
|
31 |
+
1. **Clone the Repository**: Get the code by cloning this repository to your local machine.
|
32 |
+
```
|
33 |
+
git clone https://github.com/minhphd/PyDreamerV1
|
34 |
+
```
|
35 |
+
|
36 |
+
2. **Install Dependencies**: Ensure you have all necessary dependencies by running:
|
37 |
+
```
|
38 |
+
pip3 install -r requirements.txt
|
39 |
+
```
|
40 |
+
|
41 |
+
3. **Run the Training**: Kickstart the training process with a simple command:
|
42 |
+
```
|
43 |
+
python main.py --config <Path to config file>
|
44 |
+
```
|
45 |
+
|
46 |
+
4. **Visualize Results**: Utilize TensorBoard to observe training progress and visualize the agent's performance in real-time. Wandb is also supported, simply set enable to True and replace with your account information in config files.
|
47 |
+
```
|
48 |
+
tensorboard --logdir=runs
|
49 |
+
```
|
50 |
+
**Optional: Visualize imagine sequences**: Using saved models to visualize agent's prediction of environment dynamic. You would need to create a config folder in the run logging directory and drag the training config file in
|
51 |
+
```
|
52 |
+
python imagine.py --runpath <Path to run file>
|
53 |
+
```
|
54 |
+
|
55 |
+
## Citation
|
56 |
+
This implementation was made possible thanks to these papers.
|
57 |
+
```bibtex
|
58 |
+
@article{hafner2019dream,
|
59 |
+
title={Dream to Control: Learning Behaviors by Latent Imagination},
|
60 |
+
author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy},
|
61 |
+
journal={arXiv preprint arXiv:1912.01603},
|
62 |
+
year={2019}
|
63 |
+
}
|
64 |
+
@misc{1801.00690,
|
65 |
+
title = {DeepMind Control Suite},
|
66 |
+
author = {Yuval Tassa and Yotam Doron and Alistair Muldal and Tom Erez and Yazhe Li and Diego de Las Casas and David Budden and Abbas Abdolmaleki and Josh Merel and Andrew Lefrancq and Timothy Lillicrap and Martin Riedmiller},
|
67 |
+
journal = {arXiv preprint arXiv:1801.00690},
|
68 |
+
year = {2018},
|
69 |
+
}
|
70 |
+
|
71 |
+
```
|
72 |
+
|
73 |
+
## Contributions
|
74 |
+
|
75 |
+
Contributions are welcome! Whether it's extending functionality, improving efficiency, or correcting bugs, your input helps make this project better for everyone.
|
algos/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
algos/dreamer.py
ADDED
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Minh Pham-Dinh
|
3 |
+
Created: Jan 27th, 2024
|
4 |
+
Last Modified: Feb 10th, 2024
|
5 |
+
Email: [email protected]
|
6 |
+
|
7 |
+
Description:
|
8 |
+
main Dreamer file.
|
9 |
+
|
10 |
+
The implementation is based on:
|
11 |
+
Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
|
12 |
+
[Online]. Available: https://arxiv.org/abs/1912.01603
|
13 |
+
"""
|
14 |
+
|
15 |
+
# Standard Library Imports
|
16 |
+
import os
|
17 |
+
import numpy as np
|
18 |
+
import yaml
|
19 |
+
from tqdm import tqdm
|
20 |
+
import wandb
|
21 |
+
|
22 |
+
# Machine Learning and Data Processing Imports
|
23 |
+
import torch
|
24 |
+
import torch.nn as nn
|
25 |
+
import torch.optim as optim
|
26 |
+
|
27 |
+
# Custom Utility Imports
|
28 |
+
import utils.models as models
|
29 |
+
from utils.buffer import ReplayBuffer
|
30 |
+
from utils.utils import td_lambda, log_metrics
|
31 |
+
|
32 |
+
|
33 |
+
class Dreamer:
|
34 |
+
def __init__(self, config, logpath, env, writer = None, wandb_writer=None):
|
35 |
+
self.config = config
|
36 |
+
self.device = torch.device(self.config.device)
|
37 |
+
self.env = env
|
38 |
+
self.obs_size = env.observation_space.shape
|
39 |
+
self.action_size = env.action_space.n if self.config.env.discrete else env.action_space.shape[0]
|
40 |
+
self.epsilon = self.config.main.epsilon_start
|
41 |
+
self.env_step = 0
|
42 |
+
self.logpath = logpath
|
43 |
+
|
44 |
+
# Set random seed for reproducibility
|
45 |
+
np.random.seed(self.config.seed)
|
46 |
+
torch.manual_seed(self.config.seed)
|
47 |
+
|
48 |
+
#dynamic networks initialized
|
49 |
+
self.rssm = models.RSSM(self.config.main.stochastic_size,
|
50 |
+
self.config.main.embedded_obs_size,
|
51 |
+
self.config.main.deterministic_size,
|
52 |
+
self.config.main.hidden_units,
|
53 |
+
self.action_size).to(self.device)
|
54 |
+
|
55 |
+
self.reward = models.RewardNet(self.config.main.stochastic_size + self.config.main.deterministic_size,
|
56 |
+
self.config.main.hidden_units).to(self.device)
|
57 |
+
|
58 |
+
if self.config.main.continue_loss:
|
59 |
+
self.cont_net = models.ContinuoNet(self.config.main.stochastic_size + self.config.main.deterministic_size,
|
60 |
+
self.config.main.hidden_units).to(self.device)
|
61 |
+
|
62 |
+
self.encoder = models.ConvEncoder(input_shape=self.obs_size).to(self.device)
|
63 |
+
self.decoder = models.ConvDecoder(self.config.main.stochastic_size,
|
64 |
+
self.config.main.deterministic_size,
|
65 |
+
out_shape=self.obs_size).to(self.device)
|
66 |
+
self.dyna_parameters = (
|
67 |
+
list(self.rssm.parameters())
|
68 |
+
+ list(self.reward.parameters())
|
69 |
+
+ list(self.encoder.parameters())
|
70 |
+
+ list(self.decoder.parameters())
|
71 |
+
)
|
72 |
+
|
73 |
+
if self.config.main.continue_loss:
|
74 |
+
self.dyna_parameters += list(self.cont_net.parameters())
|
75 |
+
|
76 |
+
#behavior networks initialized
|
77 |
+
self.actor = models.Actor(self.config.main.stochastic_size + self.config.main.deterministic_size,
|
78 |
+
self.config.main.hidden_units,
|
79 |
+
self.action_size,
|
80 |
+
self.config.env.discrete).to(self.device)
|
81 |
+
self.critic = models.Critic(self.config.main.stochastic_size + self.config.main.deterministic_size,
|
82 |
+
self.config.main.hidden_units).to(self.device)
|
83 |
+
|
84 |
+
#optimizers
|
85 |
+
self.dyna_optimizer = optim.Adam(self.dyna_parameters, lr=self.config.main.dyna_model_lr)
|
86 |
+
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.config.main.actor_lr)
|
87 |
+
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.config.main.critic_lr)
|
88 |
+
self.gradient_step = 0
|
89 |
+
|
90 |
+
#buffer
|
91 |
+
self.buffer = ReplayBuffer(self.config.main.buffer_capacity, self.obs_size, (self.action_size, ))
|
92 |
+
|
93 |
+
#tracking stuff
|
94 |
+
self.wandb_writer = wandb_writer
|
95 |
+
self.writer = writer
|
96 |
+
|
97 |
+
|
98 |
+
def update_epsilon(self):
|
99 |
+
"""In use for decaying epsilon in discrete env
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
_type_: _description_
|
103 |
+
"""
|
104 |
+
eps_start = self.config.main.epsilon_start
|
105 |
+
eps_end = self.config.main.epsilon_end
|
106 |
+
decay_steps = self.config.main.eps_decay_steps
|
107 |
+
decay_rate = (eps_start - eps_end) / (decay_steps)
|
108 |
+
self.epsilon = max(eps_end, eps_start - decay_rate*self.gradient_step)
|
109 |
+
|
110 |
+
|
111 |
+
def train(self):
|
112 |
+
"""main training loop, implementation follow closely with the loop from the official paper
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
_type_: _description_
|
116 |
+
"""
|
117 |
+
|
118 |
+
#prefill dataset
|
119 |
+
ep = 0
|
120 |
+
obs, _ = self.env.reset()
|
121 |
+
while ep < self.config.main.data_init_ep:
|
122 |
+
action = self.env.action_space.sample()
|
123 |
+
if self.config.env.discrete:
|
124 |
+
actions = np.zeros((self.action_size, ))
|
125 |
+
actions[action] = 1.0
|
126 |
+
else:
|
127 |
+
actions = action
|
128 |
+
|
129 |
+
next_obs, reward, termination, truncation, info = self.env.step(action)
|
130 |
+
|
131 |
+
self.buffer.add(obs, actions, reward, termination or truncation)
|
132 |
+
obs = next_obs
|
133 |
+
|
134 |
+
if "episode" in info:
|
135 |
+
obs, _ = self.env.reset()
|
136 |
+
ep += 1
|
137 |
+
print(ep)
|
138 |
+
if 'video_path' in info and self.wandb_writer:
|
139 |
+
self.wandb_writer.log({'performance/videos': wandb.Video(info['video_path'], format='webm')})
|
140 |
+
|
141 |
+
#main train loop
|
142 |
+
for _ in tqdm(range(self.config.main.total_iter)):
|
143 |
+
#save model if reached checkpoint
|
144 |
+
if _ % self.config.main.save_freq == 0:
|
145 |
+
|
146 |
+
#check if models folder exist
|
147 |
+
directory = self.logpath + 'models/'
|
148 |
+
os.makedirs(directory, exist_ok=True)
|
149 |
+
|
150 |
+
#save models
|
151 |
+
torch.save(self.rssm, self.logpath + 'models/rssm_model')
|
152 |
+
torch.save(self.encoder, self.logpath + 'models/encoder')
|
153 |
+
torch.save(self.decoder, self.logpath + 'models/decoder')
|
154 |
+
torch.save(self.actor, self.logpath + 'models/actor')
|
155 |
+
torch.save(self.critic, self.logpath + 'models/critic')
|
156 |
+
|
157 |
+
#run eval if reach eval checkpoint
|
158 |
+
if _ % self.config.main.eval_freq == 0:
|
159 |
+
eval_score = self.data_collection(self.config.main.eval_eps, eval=True)
|
160 |
+
metrics = {'performance/evaluation score': eval_score}
|
161 |
+
log_metrics(metrics, self.env_step, self.writer, self.wandb_writer)
|
162 |
+
|
163 |
+
#training step
|
164 |
+
for c in tqdm(range(self.config.main.collect_iter)):
|
165 |
+
#draw data
|
166 |
+
batch = self.buffer.sample(self.config.main.batch_size, self.config.main.seq_len, self.device)
|
167 |
+
|
168 |
+
#dynamic learning
|
169 |
+
post, deter = self.dynamic_learning(batch)
|
170 |
+
|
171 |
+
#behavioral learning
|
172 |
+
self.behavioral_learning(post, deter)
|
173 |
+
|
174 |
+
#update step
|
175 |
+
self.gradient_step += 1
|
176 |
+
self.update_epsilon()
|
177 |
+
|
178 |
+
# collect more data with exploration noise
|
179 |
+
self.data_collection(self.config.main.data_interact_ep)
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
def dynamic_learning(self, batch):
|
184 |
+
"""Learning the dynamic model. In this method, we sequentially pass data in the RSSM to
|
185 |
+
learn the model
|
186 |
+
|
187 |
+
Args:
|
188 |
+
batch (addict.Dict): batches of data
|
189 |
+
"""
|
190 |
+
|
191 |
+
'''
|
192 |
+
We unpack the batch. A batch contains:
|
193 |
+
- b_obs (batch_size, seq_len, *obs.shape): batches of observation
|
194 |
+
- b_a (batch_size, seq_len, 1): batches of action
|
195 |
+
- b_r (batch_size, seq_len, 1): batches of rewards
|
196 |
+
- b_d (batch_size, seq_len, 1): batches of termination signal
|
197 |
+
'''
|
198 |
+
b_obs = batch.obs
|
199 |
+
b_a = batch.actions
|
200 |
+
b_r = batch.rewards
|
201 |
+
b_d = batch.dones
|
202 |
+
|
203 |
+
batch_size, seq_len, _ = b_r.shape
|
204 |
+
eb_obs = self.encoder(b_obs)
|
205 |
+
|
206 |
+
#initialized stochastic states (posterior) and deterministic states to first pass into the recurrent model
|
207 |
+
posterior = torch.zeros((batch_size, self.config.main.stochastic_size)).to(self.device)
|
208 |
+
deterministic = torch.zeros((batch_size, self.config.main.deterministic_size)).to(self.device)
|
209 |
+
|
210 |
+
#initialized memory storing of sequential gradients data
|
211 |
+
posteriors = torch.zeros((batch_size, seq_len-1, self.config.main.stochastic_size)).to(self.device)
|
212 |
+
priors = torch.zeros((batch_size, seq_len-1, self.config.main.stochastic_size)).to(self.device)
|
213 |
+
deterministics = torch.zeros((batch_size, seq_len-1, self.config.main.deterministic_size)).to(self.device)
|
214 |
+
|
215 |
+
posterior_means = torch.zeros_like(posteriors).to(self.device)
|
216 |
+
posterior_stds = torch.zeros_like(posteriors).to(self.device)
|
217 |
+
prior_means = torch.zeros_like(priors).to(self.device)
|
218 |
+
prior_stds = torch.zeros_like(priors).to(self.device)
|
219 |
+
|
220 |
+
#start passing data through the dynamic model
|
221 |
+
for t in (range(1, seq_len)):
|
222 |
+
deterministic = self.rssm.recurrent(posterior, b_a[:, t-1, :], deterministic)
|
223 |
+
prior_dist, prior = self.rssm.transition(deterministic)
|
224 |
+
|
225 |
+
#detail observation is shifted 1 timestep ahead(action is associated with the next state)
|
226 |
+
posterior_dist, posterior = self.rssm.representation(eb_obs[:, t, :], deterministic)
|
227 |
+
|
228 |
+
'''
|
229 |
+
store recurrent data
|
230 |
+
data are shifted 1 timestep ahead. Start from the second timestep or t=1
|
231 |
+
'''
|
232 |
+
posteriors[:, t-1, :] = posterior
|
233 |
+
posterior_means[:, t-1, :] = posterior_dist.mean
|
234 |
+
posterior_stds[:, t-1, :] = posterior_dist.scale
|
235 |
+
|
236 |
+
priors[:, t-1, :] = prior
|
237 |
+
prior_means[:, t-1, :] = prior_dist.mean
|
238 |
+
prior_stds[:, t-1, :] = prior_dist.scale
|
239 |
+
|
240 |
+
deterministics[:, t-1, :] = deterministic
|
241 |
+
|
242 |
+
#we start optimizing model with the provided data
|
243 |
+
|
244 |
+
'''
|
245 |
+
Reconstruction loss. This loss helps the model learn to encode pixels observation.
|
246 |
+
'''
|
247 |
+
mps_flatten = False
|
248 |
+
if self.device == torch.device("mps"):
|
249 |
+
mps_flatten = True
|
250 |
+
|
251 |
+
reconstruct_dist = self.decoder(posteriors, deterministics, mps_flatten)
|
252 |
+
target = b_obs[:, 1:]
|
253 |
+
if mps_flatten:
|
254 |
+
target = target.reshape(-1, *self.obs_size)
|
255 |
+
reconstruct_loss = reconstruct_dist.log_prob(target).mean()
|
256 |
+
|
257 |
+
#reward loss
|
258 |
+
rewards = self.reward(posteriors, deterministics)
|
259 |
+
rewards_dist = torch.distributions.Normal(rewards, 1)
|
260 |
+
rewards_dist = torch.distributions.Independent(rewards_dist, 1)
|
261 |
+
rewards_loss = rewards_dist.log_prob(b_r[:, 1:]).mean()
|
262 |
+
|
263 |
+
'''
|
264 |
+
Continuity loss. This loss term helps predict the probability of an episode terminate at a particular state
|
265 |
+
'''
|
266 |
+
if self.config.main.continue_loss:
|
267 |
+
# calculate log prob manually as tensorflow doesn't support float value in logprob of Bernoulli
|
268 |
+
# follow closely to Hafner's official code for Dreamer
|
269 |
+
cont_logits, _ = self.cont_net(posteriors, deterministics)
|
270 |
+
cont_target = (1 - b_d[:, 1:]) * self.config.main.discount
|
271 |
+
continue_loss = torch.nn.functional.binary_cross_entropy_with_logits(cont_logits, cont_target)
|
272 |
+
else:
|
273 |
+
continue_loss = torch.zeros((1)).to(self.device)
|
274 |
+
|
275 |
+
'''
|
276 |
+
KL loss. Matching the distribution of transition and representation model. This is to ensure we have the accurate transition model for use in imagination process
|
277 |
+
'''
|
278 |
+
priors_dist = torch.distributions.Independent(
|
279 |
+
torch.distributions.Normal(prior_means, prior_stds), 1
|
280 |
+
)
|
281 |
+
posteriors_dist = torch.distributions.Independent(
|
282 |
+
torch.distributions.Normal(posterior_means, posterior_stds), 1
|
283 |
+
)
|
284 |
+
kl_loss = torch.max(
|
285 |
+
torch.mean(torch.distributions.kl.kl_divergence(posteriors_dist, priors_dist)),
|
286 |
+
torch.tensor(self.config.main.free_nats).to(self.device)
|
287 |
+
)
|
288 |
+
|
289 |
+
total_loss = self.config.main.kl_divergence_scale * kl_loss - reconstruct_loss - rewards_loss + continue_loss
|
290 |
+
|
291 |
+
self.dyna_optimizer.zero_grad()
|
292 |
+
total_loss.backward()
|
293 |
+
nn.utils.clip_grad_norm_(
|
294 |
+
self.dyna_parameters,
|
295 |
+
self.config.main.clip_grad,
|
296 |
+
norm_type=self.config.main.grad_norm_type,
|
297 |
+
)
|
298 |
+
self.dyna_optimizer.step()
|
299 |
+
|
300 |
+
#tensorboard logging
|
301 |
+
metrics = {
|
302 |
+
'Dynamic_model/KL': kl_loss.item(),
|
303 |
+
'Dynamic_model/Reconstruction': reconstruct_loss.item(),
|
304 |
+
'Dynamic_model/Reward': rewards_loss.item(),
|
305 |
+
'Dynamic_model/Continue': continue_loss.item(),
|
306 |
+
'Dynamic_model/Total': total_loss.item()
|
307 |
+
}
|
308 |
+
|
309 |
+
log_metrics(metrics, self.gradient_step, self.writer, self.wandb_writer)
|
310 |
+
|
311 |
+
return posteriors.detach(), deterministics.detach()
|
312 |
+
|
313 |
+
|
314 |
+
def behavioral_learning(self, state, deterministics):
|
315 |
+
"""Learning behavioral through latent imagination
|
316 |
+
|
317 |
+
Args:
|
318 |
+
self (_type_): _description_
|
319 |
+
state (batch_size, seq_len-1, stoch_state_size): starting point state
|
320 |
+
deterministics (batch_size, seq_len-1, stoch_state_size)
|
321 |
+
"""
|
322 |
+
|
323 |
+
#flatten the batches --> new size (batch_size * (seq_len-1), *)
|
324 |
+
state = state.reshape(-1, self.config.main.stochastic_size)
|
325 |
+
deterministics = deterministics.reshape(-1, self.config.main.deterministic_size)
|
326 |
+
|
327 |
+
batch_size, stochastic_size = state.shape
|
328 |
+
_, deterministics_size = deterministics.shape
|
329 |
+
|
330 |
+
#initialized trajectories
|
331 |
+
state_trajectories = torch.zeros((batch_size, self.config.main.horizon, stochastic_size)).to(self.device)
|
332 |
+
deterministics_trajectories = torch.zeros((batch_size, self.config.main.horizon, deterministics_size)).to(self.device)
|
333 |
+
|
334 |
+
#imagine trajectories
|
335 |
+
for t in range(self.config.main.horizon):
|
336 |
+
# do not include the starting state
|
337 |
+
action = self.actor(state, deterministics)
|
338 |
+
deterministics = self.rssm.recurrent(state, action, deterministics)
|
339 |
+
_, state = self.rssm.transition(deterministics)
|
340 |
+
state_trajectories[:, t, :] = state
|
341 |
+
deterministics_trajectories[:, t, :] = deterministics
|
342 |
+
|
343 |
+
'''
|
344 |
+
After imagining, we have both the state trajectories and deterministic trajectories, which can be used to create latent states.
|
345 |
+
- state_trajectories (N, HORIZON_LEN)
|
346 |
+
- deteerministic_trajectories (N, HORIZON_LEN)
|
347 |
+
'''
|
348 |
+
|
349 |
+
#actor update
|
350 |
+
|
351 |
+
#compute rewards for each trajectories
|
352 |
+
rewards = self.reward(state_trajectories, deterministics_trajectories)
|
353 |
+
rewards_dist = torch.distributions.Normal(rewards, 1)
|
354 |
+
rewards_dist = torch.distributions.Independent(rewards_dist, 1)
|
355 |
+
rewards = rewards_dist.mode
|
356 |
+
|
357 |
+
if self.config.main.continue_loss:
|
358 |
+
_, conts_dist = self.cont_net(state_trajectories, deterministics_trajectories)
|
359 |
+
continues = conts_dist.mean
|
360 |
+
else:
|
361 |
+
continues = self.config.main.discount * torch.ones_like(rewards)
|
362 |
+
|
363 |
+
values = self.critic(state_trajectories, deterministics_trajectories).mode
|
364 |
+
|
365 |
+
#calculate trajectories returns
|
366 |
+
#returns should have shape (N, HORIZON_LEN - 1, 1) (last values are ignored due to nature of bootstrapping)
|
367 |
+
returns = td_lambda(
|
368 |
+
rewards,
|
369 |
+
continues,
|
370 |
+
values,
|
371 |
+
self.config.main.lambda_,
|
372 |
+
self.device
|
373 |
+
)
|
374 |
+
|
375 |
+
#culm product for discount
|
376 |
+
discount = torch.cumprod(torch.cat((
|
377 |
+
torch.ones_like(continues[:, :1]).to(self.device),
|
378 |
+
continues[:, :-2]
|
379 |
+
), 1), 1).detach()
|
380 |
+
|
381 |
+
# actor optimizing
|
382 |
+
actor_loss = -(discount * returns).mean()
|
383 |
+
|
384 |
+
self.actor_optimizer.zero_grad()
|
385 |
+
actor_loss.backward()
|
386 |
+
nn.utils.clip_grad_norm_(
|
387 |
+
self.actor.parameters(),
|
388 |
+
self.config.main.clip_grad,
|
389 |
+
norm_type=self.config.main.grad_norm_type,
|
390 |
+
)
|
391 |
+
self.actor_optimizer.step()
|
392 |
+
|
393 |
+
|
394 |
+
# critic optimizing
|
395 |
+
values_dist = self.critic(state_trajectories[:, :-1].detach(), deterministics_trajectories[:, :-1].detach())
|
396 |
+
|
397 |
+
critic_loss = -(discount.squeeze() * values_dist.log_prob(returns.detach())).mean()
|
398 |
+
|
399 |
+
self.critic_optimizer.zero_grad()
|
400 |
+
critic_loss.backward()
|
401 |
+
nn.utils.clip_grad_norm_(
|
402 |
+
self.critic.parameters(),
|
403 |
+
self.config.main.clip_grad,
|
404 |
+
norm_type=self.config.main.grad_norm_type,
|
405 |
+
)
|
406 |
+
self.critic_optimizer.step()
|
407 |
+
|
408 |
+
metrics = {
|
409 |
+
'Behavorial_model/Actor': actor_loss.item(),
|
410 |
+
'Behavorial_model/Critic': critic_loss.item()
|
411 |
+
}
|
412 |
+
|
413 |
+
log_metrics(metrics, self.gradient_step, self.writer, self.wandb_writer)
|
414 |
+
|
415 |
+
|
416 |
+
@torch.no_grad()
|
417 |
+
def data_collection(self, num_episodes, eval=False):
|
418 |
+
"""data collection method. Roll out agent a number of episodes and collect data
|
419 |
+
If eval=True. The agent is set for evaluation mode with no exploration noise and data collection
|
420 |
+
|
421 |
+
Args:
|
422 |
+
num_episodes (int): number of episodes
|
423 |
+
eval (bool): Evaluation mode. Defaults to False.
|
424 |
+
random (bool): Random mode. Defaults to False.
|
425 |
+
|
426 |
+
Returns:
|
427 |
+
average_score: average score over number of rollout episodes
|
428 |
+
"""
|
429 |
+
score = 0
|
430 |
+
ep = 0
|
431 |
+
obs, _ = self.env.reset()
|
432 |
+
#initialized all zeros
|
433 |
+
posterior = torch.zeros((1, self.config.main.stochastic_size)).to(self.device)
|
434 |
+
deterministic = torch.zeros((1, self.config.main.deterministic_size)).to(self.device)
|
435 |
+
action = torch.zeros((1, self.action_size)).to(self.device)
|
436 |
+
|
437 |
+
while ep < num_episodes:
|
438 |
+
embed_obs = self.encoder(torch.from_numpy(obs).to(self.device, dtype=torch.float)) #(1, embed_obs_sz)
|
439 |
+
deterministic = self.rssm.recurrent(posterior, action, deterministic)
|
440 |
+
_, posterior = self.rssm.representation(embed_obs, deterministic)
|
441 |
+
actor_out = self.actor(posterior, deterministic)
|
442 |
+
|
443 |
+
#detail: add exploration noise if not in evaluation mode
|
444 |
+
if not eval:
|
445 |
+
actions = actor_out.cpu().numpy()
|
446 |
+
if self.config.env.discrete:
|
447 |
+
if np.random.rand() < self.epsilon:
|
448 |
+
action = self.env.action_space.sample()
|
449 |
+
else:
|
450 |
+
action = np.argmax(actions)
|
451 |
+
else:
|
452 |
+
mean_noise = self.config.main.mean_noise
|
453 |
+
std_noise = self.config.main.std_noise
|
454 |
+
|
455 |
+
normal_dist = torch.distributions.Normal(actor_out + mean_noise, std_noise)
|
456 |
+
sampled_action = normal_dist.sample().cpu().numpy()
|
457 |
+
actions = np.clip(sampled_action, a_min=-1, a_max=1)
|
458 |
+
action = actions[0]
|
459 |
+
else:
|
460 |
+
actions = actor_out.cpu().numpy()
|
461 |
+
if self.config.env.discrete:
|
462 |
+
action = np.argmax(actions)
|
463 |
+
else:
|
464 |
+
actions = np.clip(actions, a_min=-1, a_max=1)
|
465 |
+
action = actions[0]
|
466 |
+
|
467 |
+
next_obs, reward, termination, truncation, info = self.env.step(action)
|
468 |
+
|
469 |
+
if not eval:
|
470 |
+
self.buffer.add(obs, actions, reward, termination | truncation)
|
471 |
+
self.env_step += self.config.env.action_repeat
|
472 |
+
obs = next_obs
|
473 |
+
|
474 |
+
action = actor_out
|
475 |
+
if "episode" in info:
|
476 |
+
cur_score = info["episode"]["r"][0]
|
477 |
+
score += cur_score
|
478 |
+
obs, _ = self.env.reset()
|
479 |
+
ep += 1
|
480 |
+
|
481 |
+
if 'video_path' in info and self.wandb_writer:
|
482 |
+
self.wandb_writer.log({'performance/videos': wandb.Video(info['video_path'], format='webm')})
|
483 |
+
log_metrics({'performance/training score': cur_score}, self.env_step, self.writer, self.wandb_writer)
|
484 |
+
|
485 |
+
posterior = torch.zeros((1, self.config.main.stochastic_size)).to(self.device)
|
486 |
+
deterministic = torch.zeros((1, self.config.main.deterministic_size)).to(self.device)
|
487 |
+
action = torch.zeros((1, self.action_size)).to(self.device)
|
488 |
+
|
489 |
+
return score/num_episodes
|
bash/setup.sh
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Update and Upgrade the System
|
4 |
+
echo "Updating and upgrading the system..."
|
5 |
+
sudo apt-get update -y
|
6 |
+
sudo apt-get upgrade -y
|
7 |
+
|
8 |
+
# Install dependencies for Gymnasium
|
9 |
+
echo "Installing dependencies for Gymnasium..."
|
10 |
+
|
11 |
+
# Development tools
|
12 |
+
sudo apt-get install -y build-essential
|
13 |
+
|
14 |
+
# Python 3 and pip
|
15 |
+
sudo apt-get install -y python3 python3-pip
|
16 |
+
sudo apt-get install python3-opencv
|
17 |
+
|
18 |
+
# System libraries
|
19 |
+
sudo apt-get install -y libglew-dev libjpeg-dev libboost-all-dev libglu1-mesa-dev freeglut3-dev mesa-common-dev
|
20 |
+
|
21 |
+
# SWIG for interface generation
|
22 |
+
sudo apt-get install -y swig
|
23 |
+
|
24 |
+
# Gymnasium and additional dependencies via pip
|
25 |
+
echo "Installing requirements.txt"
|
26 |
+
pip3 install -r requirements.txt
|
27 |
+
sudo apt-get install xvfb
|
28 |
+
Xvfb :99 -screen 0 1024x768x24 &
|
29 |
+
|
30 |
+
export DISPLAY=:99
|
31 |
+
|
32 |
+
echo "Setup complete!"
|
configs/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
configs/dm_control/Cart-pole.yml
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
device: "cuda"
|
2 |
+
experiment_name: Cart-pole
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
env:
|
6 |
+
env_id: cartpole
|
7 |
+
task: balance
|
8 |
+
discrete: False
|
9 |
+
new_obs_size: [64, 64]
|
10 |
+
norm_obs: True
|
11 |
+
|
12 |
+
tensorboard:
|
13 |
+
enable: False
|
14 |
+
log_dir: "./runs/"
|
15 |
+
log_frequency: 1 # Log every 1000 steps
|
16 |
+
|
17 |
+
wandb:
|
18 |
+
enable: True
|
19 |
+
project: "dreamer"
|
20 |
+
entity: "phdminh01"
|
21 |
+
log_frequency: 1
|
22 |
+
|
23 |
+
video_recording:
|
24 |
+
enable: True
|
25 |
+
record_frequency: 50 #episodes
|
26 |
+
|
27 |
+
main:
|
28 |
+
continue_loss: False
|
29 |
+
continue_scale_factor: 10
|
30 |
+
total_iter: 2000
|
31 |
+
save_freq: 20
|
32 |
+
collect_iter: 100
|
33 |
+
data_interact_ep: 1
|
34 |
+
# data_init_ep: 1
|
35 |
+
data_init_ep: 5
|
36 |
+
horizon: 15
|
37 |
+
batch_size: 50
|
38 |
+
seq_len: 50
|
39 |
+
eval_eps: 3
|
40 |
+
eval_freq: 5
|
41 |
+
|
42 |
+
kl_divergence_scale : 1
|
43 |
+
free_nats : 3
|
44 |
+
discount : 0.99
|
45 |
+
lambda_ : 0.95
|
46 |
+
|
47 |
+
actor_lr : 8.0e-5
|
48 |
+
critic_lr : 8.0e-5
|
49 |
+
dyna_model_lr : 6.0e-4
|
50 |
+
grad_norm_type : 2
|
51 |
+
clip_grad : 100
|
52 |
+
|
53 |
+
hidden_units: 300
|
54 |
+
deterministic_size : 200
|
55 |
+
stochastic_size : 30
|
56 |
+
embedded_obs_size : 1024
|
57 |
+
buffer_capacity : 500000
|
58 |
+
|
59 |
+
epsilon_start: 0.4
|
60 |
+
epsilon_end: 0.1
|
61 |
+
eps_decay_steps: 200000
|
62 |
+
|
63 |
+
mean_noise: 0
|
64 |
+
std_noise: 0.3
|
configs/dm_control/Quadruped.yml
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
device: "cuda"
|
2 |
+
experiment_name: Quadruped
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
env:
|
6 |
+
env_id: quadruped
|
7 |
+
task: walk
|
8 |
+
discrete: False
|
9 |
+
new_obs_size: [64, 64]
|
10 |
+
norm_obs: True
|
11 |
+
|
12 |
+
tensorboard:
|
13 |
+
enable: False
|
14 |
+
log_dir: "./runs/"
|
15 |
+
log_frequency: 1 # Log every 1000 steps
|
16 |
+
|
17 |
+
wandb:
|
18 |
+
enable: True
|
19 |
+
project: "dreamer"
|
20 |
+
entity: "phdminh01"
|
21 |
+
log_frequency: 1
|
22 |
+
|
23 |
+
video_recording:
|
24 |
+
enable: True
|
25 |
+
record_frequency: 50 #episodes
|
26 |
+
|
27 |
+
main:
|
28 |
+
continue_loss: False
|
29 |
+
continue_scale_factor: 10
|
30 |
+
total_iter: 2000
|
31 |
+
save_freq: 20
|
32 |
+
collect_iter: 100
|
33 |
+
data_interact_ep: 1
|
34 |
+
# data_init_ep: 1
|
35 |
+
data_init_ep: 5
|
36 |
+
horizon: 15
|
37 |
+
batch_size: 50
|
38 |
+
seq_len: 50
|
39 |
+
eval_eps: 3
|
40 |
+
eval_freq: 5
|
41 |
+
|
42 |
+
kl_divergence_scale : 1
|
43 |
+
free_nats : 3
|
44 |
+
discount : 0.99
|
45 |
+
lambda_ : 0.95
|
46 |
+
|
47 |
+
actor_lr : 8.0e-5
|
48 |
+
critic_lr : 8.0e-5
|
49 |
+
dyna_model_lr : 6.0e-4
|
50 |
+
grad_norm_type : 2
|
51 |
+
clip_grad : 100
|
52 |
+
|
53 |
+
hidden_units: 300
|
54 |
+
deterministic_size : 200
|
55 |
+
stochastic_size : 30
|
56 |
+
embedded_obs_size : 1024
|
57 |
+
buffer_capacity : 500000
|
58 |
+
|
59 |
+
epsilon_start: 0.4
|
60 |
+
epsilon_end: 0.1
|
61 |
+
eps_decay_steps: 200000
|
62 |
+
|
63 |
+
mean_noise: 0
|
64 |
+
std_noise: 0.3
|
configs/dm_control/Walker.yml
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
device: "mps"
|
2 |
+
experiment_name: Walker
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
env:
|
6 |
+
env_id: walker
|
7 |
+
task: walk
|
8 |
+
new_obs_size: [64, 64]
|
9 |
+
action_repeat: 2
|
10 |
+
time_limit: 1000
|
11 |
+
|
12 |
+
tensorboard:
|
13 |
+
enable: False
|
14 |
+
log_dir: "./runs/"
|
15 |
+
log_frequency: 1 # Log every 1000 steps
|
16 |
+
|
17 |
+
wandb:
|
18 |
+
enable: False
|
19 |
+
project: "dreamer"
|
20 |
+
entity: "phdminh01"
|
21 |
+
log_frequency: 1
|
22 |
+
|
23 |
+
video_recording:
|
24 |
+
enable: False
|
25 |
+
record_frequency: 100 #episodes
|
26 |
+
|
27 |
+
main:
|
28 |
+
continue_loss: False
|
29 |
+
continue_scale_factor: 10
|
30 |
+
total_iter: 2000
|
31 |
+
save_freq: 20
|
32 |
+
collect_iter: 100
|
33 |
+
data_interact_ep: 1
|
34 |
+
# data_init_ep: 1
|
35 |
+
data_init_ep: 5
|
36 |
+
horizon: 15
|
37 |
+
batch_size: 50
|
38 |
+
seq_len: 50
|
39 |
+
eval_eps: 3
|
40 |
+
eval_freq: 5
|
41 |
+
|
42 |
+
kl_divergence_scale : 1
|
43 |
+
free_nats : 3
|
44 |
+
discount : 0.99
|
45 |
+
lambda_ : 0.95
|
46 |
+
|
47 |
+
use_continue_flag : True
|
48 |
+
actor_lr : 8.0e-5
|
49 |
+
critic_lr : 8.0e-5
|
50 |
+
dyna_model_lr : 6.0e-4
|
51 |
+
grad_norm_type : 2
|
52 |
+
clip_grad : 100
|
53 |
+
|
54 |
+
hidden_units: 300
|
55 |
+
deterministic_size : 200
|
56 |
+
stochastic_size : 30
|
57 |
+
embedded_obs_size : 1024
|
58 |
+
buffer_capacity : 500000
|
59 |
+
|
60 |
+
epsilon_start: 0.4
|
61 |
+
epsilon_end: 0.1
|
62 |
+
eps_decay_steps: 200000
|
63 |
+
|
64 |
+
mean_noise: 0
|
65 |
+
std_noise: 0.3
|
configs/gymnasium/Boxing-v5.yml
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
device: "mps"
|
2 |
+
experiment_name: Boxing-v5-new
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
env:
|
6 |
+
env_id: ALE/Boxing-v5
|
7 |
+
channel_first: True
|
8 |
+
discrete: True
|
9 |
+
resize_obs: True
|
10 |
+
new_obs_size: [64, 64]
|
11 |
+
norm_obs: True
|
12 |
+
|
13 |
+
tensorboard:
|
14 |
+
enable: True
|
15 |
+
log_dir: "./runs/"
|
16 |
+
log_frequency: 1 # Log every 1000 steps
|
17 |
+
|
18 |
+
wandb:
|
19 |
+
enable: False
|
20 |
+
project: "dreamer"
|
21 |
+
entity: "phdminh01"
|
22 |
+
log_frequency: 1
|
23 |
+
|
24 |
+
video_recording:
|
25 |
+
enable: True
|
26 |
+
record_frequency: 100 #episodes
|
27 |
+
save_path: "./runs/"
|
28 |
+
|
29 |
+
main:
|
30 |
+
continue_loss: False
|
31 |
+
continue_scale_factor: 10
|
32 |
+
total_iter: 2000
|
33 |
+
save_freq: 20
|
34 |
+
collect_iter: 100
|
35 |
+
data_interact_ep: 1
|
36 |
+
data_init_ep: 1
|
37 |
+
# data_init_ep: 5
|
38 |
+
horizon: 10
|
39 |
+
batch_size: 50
|
40 |
+
seq_len: 50
|
41 |
+
eval_eps: 3
|
42 |
+
eval_freq: 5
|
43 |
+
|
44 |
+
kl_divergence_scale : 1
|
45 |
+
free_nats : 3
|
46 |
+
discount : 0.99
|
47 |
+
lambda_ : 0.95
|
48 |
+
|
49 |
+
use_continue_flag : True
|
50 |
+
actor_lr : 8.0e-5
|
51 |
+
critic_lr : 8.0e-5
|
52 |
+
dyna_model_lr : 6.0e-4
|
53 |
+
grad_norm_type : 2
|
54 |
+
clip_grad : 100
|
55 |
+
|
56 |
+
hidden_units: 400
|
57 |
+
deterministic_size : 600
|
58 |
+
stochastic_size : 600
|
59 |
+
embedded_obs_size : 1024
|
60 |
+
buffer_capacity : 500000
|
61 |
+
|
62 |
+
epsilon_start: 0.4
|
63 |
+
epsilon_end: 0.1
|
64 |
+
eps_decay_steps: 200000
|
65 |
+
|
66 |
+
mean_noise: 0
|
67 |
+
std_noise: 0.3
|
configs/gymnasium/Pacman-v5.yml
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
device: "cuda"
|
2 |
+
experiment_name: Pacman-v5
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
env:
|
6 |
+
env_id: ALE/MsPacman-v5
|
7 |
+
channel_first: True
|
8 |
+
discrete: True
|
9 |
+
resize_obs: True
|
10 |
+
new_obs_size: [64, 64]
|
11 |
+
norm_obs: True
|
12 |
+
|
13 |
+
tensorboard:
|
14 |
+
enable: False
|
15 |
+
log_dir: "./runs/"
|
16 |
+
log_frequency: 1 # Log every 1000 steps
|
17 |
+
|
18 |
+
wandb:
|
19 |
+
enable: True
|
20 |
+
project: "dreamer"
|
21 |
+
entity: "phdminh01"
|
22 |
+
log_frequency: 1
|
23 |
+
|
24 |
+
video_recording:
|
25 |
+
enable: True
|
26 |
+
record_frequency: 100 #episodes
|
27 |
+
save_path: "./runs/"
|
28 |
+
|
29 |
+
main:
|
30 |
+
continue_loss: True
|
31 |
+
continue_scale_factor: 10
|
32 |
+
total_iter: 2000
|
33 |
+
save_freq: 20
|
34 |
+
collect_iter: 100
|
35 |
+
data_interact_ep: 1
|
36 |
+
# data_init_ep: 1
|
37 |
+
data_init_ep: 5
|
38 |
+
horizon: 15
|
39 |
+
batch_size: 50
|
40 |
+
seq_len: 50
|
41 |
+
eval_eps: 3
|
42 |
+
eval_freq: 5
|
43 |
+
|
44 |
+
kl_divergence_scale : 1
|
45 |
+
free_nats : 3
|
46 |
+
discount : 0.99
|
47 |
+
lambda_ : 0.95
|
48 |
+
|
49 |
+
actor_lr : 8.0e-5
|
50 |
+
critic_lr : 8.0e-5
|
51 |
+
dyna_model_lr : 6.0e-4
|
52 |
+
grad_norm_type : 2
|
53 |
+
clip_grad : 100
|
54 |
+
|
55 |
+
hidden_units: 400
|
56 |
+
deterministic_size : 600
|
57 |
+
stochastic_size : 600
|
58 |
+
embedded_obs_size : 1024
|
59 |
+
buffer_capacity : 500000
|
60 |
+
|
61 |
+
epsilon_start: 0.4
|
62 |
+
epsilon_end: 0.1
|
63 |
+
eps_decay_steps: 200000
|
64 |
+
|
65 |
+
mean_noise: 0
|
66 |
+
std_noise: 0.3
|
configs/gymnasium/ant_v4.yml
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
device: "mps"
|
2 |
+
experiment_name: Ant-v4
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
gymnasium:
|
6 |
+
env_id: Ant-v4
|
7 |
+
channel_first: True
|
8 |
+
pixels: True
|
9 |
+
discrete: False
|
10 |
+
resize_obs: True
|
11 |
+
new_obs_size: [64, 64]
|
12 |
+
norm_obs: True
|
13 |
+
|
14 |
+
tensorboard:
|
15 |
+
enable: False
|
16 |
+
log_dir: "./runs/"
|
17 |
+
log_frequency: 1 # Log every 1000 steps
|
18 |
+
|
19 |
+
wandb:
|
20 |
+
enable: False
|
21 |
+
project: "dreamer"
|
22 |
+
entity: "phdminh01"
|
23 |
+
log_frequency: 1
|
24 |
+
|
25 |
+
video_recording:
|
26 |
+
enable: False
|
27 |
+
record_frequency: 100 #episodes
|
28 |
+
|
29 |
+
main:
|
30 |
+
total_iter: 2000
|
31 |
+
save_freq: 100
|
32 |
+
collect_iter: 100
|
33 |
+
data_interact_ep: 1
|
34 |
+
# data_init_ep: 1
|
35 |
+
data_init_ep: 5
|
36 |
+
horizon: 15
|
37 |
+
batch_size: 50
|
38 |
+
seq_len: 50
|
39 |
+
eval_eps: 3
|
40 |
+
eval_freq: 5
|
41 |
+
|
42 |
+
kl_divergence_scale : 1
|
43 |
+
free_nats : 3
|
44 |
+
discount : 0.99
|
45 |
+
lambda_ : 0.95
|
46 |
+
|
47 |
+
use_continue_flag : True
|
48 |
+
actor_lr : 3.0e-4
|
49 |
+
critic_lr : 3.0e-4
|
50 |
+
dyna_model_lr : 6.0e-4
|
51 |
+
grad_norm_type : 2
|
52 |
+
clip_grad : 100
|
53 |
+
|
54 |
+
hidden_units: 400
|
55 |
+
deterministic_size : 600
|
56 |
+
stochastic_size : 600
|
57 |
+
embedded_obs_size : 1024
|
58 |
+
buffer_capacity : 500000
|
59 |
+
|
60 |
+
epsilon_start: 0.4
|
61 |
+
epsilon_end: 0.1
|
62 |
+
eps_decay_steps: 200000
|
63 |
+
|
64 |
+
mean_noise: 0
|
65 |
+
std_noise: 0.3
|
configs/gymnasium/car_racing_config.yml
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
device: "cuda"
|
2 |
+
experiment_name: CarRacing-v2
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
gymnasium:
|
6 |
+
env_id: CarRacing-v2
|
7 |
+
channel_first: True
|
8 |
+
discrete: False
|
9 |
+
resize_obs: True
|
10 |
+
new_obs_size: [64, 64]
|
11 |
+
norm_obs: True
|
12 |
+
|
13 |
+
tensorboard:
|
14 |
+
enable: True
|
15 |
+
log_dir: "./runs/"
|
16 |
+
log_frequency: 1 # Log every 1000 steps
|
17 |
+
|
18 |
+
wandb:
|
19 |
+
enable: True
|
20 |
+
project: "dreamer"
|
21 |
+
entity: "phdminh01"
|
22 |
+
log_frequency: 1
|
23 |
+
|
24 |
+
video_recording:
|
25 |
+
enable: True
|
26 |
+
record_frequency: 100 #episodes
|
27 |
+
|
28 |
+
main:
|
29 |
+
total_iter: 2000
|
30 |
+
save_freq: 100
|
31 |
+
collect_iter: 100
|
32 |
+
data_interact_ep: 1
|
33 |
+
# data_init_ep: 1
|
34 |
+
data_init_ep: 5
|
35 |
+
horizon: 15
|
36 |
+
batch_size: 50
|
37 |
+
seq_len: 50
|
38 |
+
eval_eps: 3
|
39 |
+
eval_freq: 5
|
40 |
+
|
41 |
+
kl_divergence_scale : 1
|
42 |
+
free_nats : 3
|
43 |
+
discount : 0.99
|
44 |
+
lambda_ : 0.95
|
45 |
+
|
46 |
+
use_continue_flag : True
|
47 |
+
actor_lr : 3.0e-4
|
48 |
+
critic_lr : 3.0e-4
|
49 |
+
dyna_model_lr : 6.0e-4
|
50 |
+
grad_norm_type : 2
|
51 |
+
clip_grad : 100
|
52 |
+
|
53 |
+
hidden_units: 400
|
54 |
+
deterministic_size : 600
|
55 |
+
stochastic_size : 600
|
56 |
+
embedded_obs_size : 1024
|
57 |
+
buffer_capacity : 500000
|
58 |
+
|
59 |
+
epsilon_start: 0.4
|
60 |
+
epsilon_end: 0.1
|
61 |
+
eps_decay_steps: 200000
|
62 |
+
|
63 |
+
mean_noise: 0
|
64 |
+
std_noise: 0.3
|
gif/boxing.gif
ADDED
![]() |
gif/boxing_imagine.gif
ADDED
![]() |
gif/pacman.gif
ADDED
![]() |
Git LFS Details
|
gif/pacman_imagine.gif
ADDED
![]() |
Git LFS Details
|
gif/quadruped.gif
ADDED
![]() |
Git LFS Details
|
gif/quadruped_imagine.gif
ADDED
![]() |
gif/walker_imagine.gif
ADDED
![]() |
Git LFS Details
|
imagine.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Minh Pham-Dinh
|
3 |
+
Created: Feb 4th, 2024
|
4 |
+
Last Modified: Feb 6th, 2024
|
5 |
+
Email: [email protected]
|
6 |
+
|
7 |
+
Description:
|
8 |
+
Imagination file. Run this file to generate dream sequences
|
9 |
+
"""
|
10 |
+
|
11 |
+
import sys
|
12 |
+
import argparse
|
13 |
+
from utils.wrappers import DMCtoGymWrapper, AtariPreprocess
|
14 |
+
from addict import Dict
|
15 |
+
import yaml
|
16 |
+
import gymnasium as gym
|
17 |
+
import torch
|
18 |
+
from tqdm import tqdm
|
19 |
+
import numpy as np
|
20 |
+
import glob
|
21 |
+
|
22 |
+
parser = argparse.ArgumentParser(description='Process configuration file path.')
|
23 |
+
parser.add_argument('--runpath', type=str, help='Path to the run file.', required=True)
|
24 |
+
parser.add_argument('--horizon', type=int, help='number of imagination steps.', default=15)
|
25 |
+
|
26 |
+
# Parse the arguments
|
27 |
+
args = parser.parse_args()
|
28 |
+
|
29 |
+
# Load the configuration file specified by the command line argument
|
30 |
+
run_path = args.runpath
|
31 |
+
HORIZON = args.horizon
|
32 |
+
|
33 |
+
config_files = glob.glob(run_path + '/config/*.yml')
|
34 |
+
|
35 |
+
if len(config_files) != 1:
|
36 |
+
print('there should only be 1 config file in config directory')
|
37 |
+
|
38 |
+
with open(config_files[0], 'r') as file:
|
39 |
+
config = Dict(yaml.load(file, Loader=yaml.FullLoader))
|
40 |
+
|
41 |
+
env_id = config.env.env_id
|
42 |
+
|
43 |
+
if 'ALE' in config.env.env_id:
|
44 |
+
env = gym.make(env_id, render_mode='rgb_array')
|
45 |
+
env = AtariPreprocess(env, config.env.new_obs_size,
|
46 |
+
False)
|
47 |
+
else:
|
48 |
+
task = config.env.task
|
49 |
+
env = DMCtoGymWrapper(env_id, task,
|
50 |
+
resize=config.env.new_obs_size,
|
51 |
+
record=False)
|
52 |
+
|
53 |
+
print("start imagining")
|
54 |
+
|
55 |
+
encode = torch.load(run_path + '/models/encoder', map_location=torch.device('cpu') )
|
56 |
+
decoder = torch.load(run_path + '/models/decoder', map_location=torch.device('cpu') )
|
57 |
+
rssm = torch.load(run_path + '/models/rssm_model', map_location=torch.device('cpu') )
|
58 |
+
actor = torch.load(run_path + '/models/actor', map_location=torch.device('cpu'))
|
59 |
+
|
60 |
+
obs, _ = env.reset()
|
61 |
+
|
62 |
+
for i in range(100):
|
63 |
+
obs, _, _, _, _ = env.step(env.action_space.sample())
|
64 |
+
|
65 |
+
posterior = torch.zeros((1, config.main.stochastic_size))
|
66 |
+
deterministic = torch.zeros((1, config.main.deterministic_size))
|
67 |
+
e_obs = encode(torch.from_numpy(obs).to(dtype=torch.float))
|
68 |
+
|
69 |
+
_, posterior = rssm.representation(e_obs, deterministic)
|
70 |
+
|
71 |
+
from PIL import Image
|
72 |
+
|
73 |
+
frames = []
|
74 |
+
|
75 |
+
for i in tqdm(range(200)):
|
76 |
+
actions = actor(posterior, deterministic)
|
77 |
+
deterministic = rssm.recurrent(posterior, actions, deterministic)
|
78 |
+
dist, posterior = rssm.transition(deterministic)
|
79 |
+
d_obs = decoder(posterior, deterministic)
|
80 |
+
d_obs = d_obs.mean.squeeze().detach().numpy()
|
81 |
+
obs = ((d_obs.transpose([1,2,0]) + 0.5) * 255).clip(0, 255).astype(np.uint8)
|
82 |
+
img = Image.fromarray(obs, "RGB")
|
83 |
+
frames.append(img)
|
84 |
+
|
85 |
+
print("saving gif")
|
86 |
+
frame_one = frames[0]
|
87 |
+
frame_one.save(run_path + "/imagine.gif", format="GIF", append_images=frames, save_all=True, duration=30, loop=0)
|
88 |
+
print("finished")
|
requirements.txt
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
addict==2.4.0
|
3 |
+
ale-py==0.8.1
|
4 |
+
anyio==4.2.0
|
5 |
+
appnope==0.1.3
|
6 |
+
argon2-cffi==23.1.0
|
7 |
+
argon2-cffi-bindings==21.2.0
|
8 |
+
arrow==1.3.0
|
9 |
+
async-lru==2.0.4
|
10 |
+
attrdict==2.0.1
|
11 |
+
attrs==23.2.0
|
12 |
+
AutoROM==0.4.2
|
13 |
+
AutoROM.accept-rom-license==0.6.1
|
14 |
+
Babel==2.14.0
|
15 |
+
beautifulsoup4==4.12.3
|
16 |
+
bleach==6.1.0
|
17 |
+
box2d-py==2.3.5
|
18 |
+
cachetools==5.3.2
|
19 |
+
certifi==2023.11.17
|
20 |
+
cffi==1.16.0
|
21 |
+
charset-normalizer==3.3.2
|
22 |
+
click==8.1.7
|
23 |
+
cloudpickle==3.0.0
|
24 |
+
comm==0.2.1
|
25 |
+
contourpy==1.2.0
|
26 |
+
cycler==0.12.1
|
27 |
+
debugpy==1.8.0
|
28 |
+
decorator==4.4.2
|
29 |
+
defusedxml==0.7.1
|
30 |
+
dm-tree==0.1.8
|
31 |
+
etils==1.6.0
|
32 |
+
Farama-Notifications==0.0.4
|
33 |
+
fastjsonschema==2.19.1
|
34 |
+
filelock==3.13.1
|
35 |
+
fonttools==4.47.2
|
36 |
+
fqdn==1.5.1
|
37 |
+
fsspec==2023.12.2
|
38 |
+
gast==0.5.4
|
39 |
+
glfw==2.6.5
|
40 |
+
google-auth==2.26.2
|
41 |
+
google-auth-oauthlib==1.2.0
|
42 |
+
grpcio==1.60.0
|
43 |
+
gymnasium==0.29.1
|
44 |
+
idna==3.6
|
45 |
+
imageio==2.33.1
|
46 |
+
imageio-ffmpeg==0.4.9
|
47 |
+
importlib-resources==6.1.1
|
48 |
+
ipykernel==6.29.0
|
49 |
+
ipywidgets==8.1.1
|
50 |
+
isoduration==20.11.0
|
51 |
+
Jinja2==3.1.3
|
52 |
+
json5==0.9.14
|
53 |
+
jsonpointer==2.4
|
54 |
+
jsonschema==4.21.0
|
55 |
+
jsonschema-specifications==2023.12.1
|
56 |
+
jupyter==1.0.0
|
57 |
+
jupyter-console==6.6.3
|
58 |
+
jupyter-events==0.9.0
|
59 |
+
jupyter-lsp==2.2.2
|
60 |
+
jupyter_client==8.6.0
|
61 |
+
jupyter_core==5.7.1
|
62 |
+
jupyter_server==2.12.5
|
63 |
+
jupyter_server_terminals==0.5.1
|
64 |
+
jupyterlab==4.0.10
|
65 |
+
jupyterlab-widgets==3.0.9
|
66 |
+
jupyterlab_pygments==0.3.0
|
67 |
+
jupyterlab_server==2.25.2
|
68 |
+
kiwisolver==1.4.5
|
69 |
+
lz4==4.3.3
|
70 |
+
Markdown==3.5.2
|
71 |
+
MarkupSafe==2.1.3
|
72 |
+
matplotlib==3.8.2
|
73 |
+
mistune==3.0.2
|
74 |
+
moviepy==1.0.3
|
75 |
+
mpmath==1.3.0
|
76 |
+
mujoco==3.1.1
|
77 |
+
nbclient==0.9.0
|
78 |
+
nbconvert==7.14.2
|
79 |
+
nbformat==5.9.2
|
80 |
+
nest-asyncio==1.5.9
|
81 |
+
networkx==3.2.1
|
82 |
+
notebook==7.0.6
|
83 |
+
notebook_shim==0.2.3
|
84 |
+
numpy==1.26.3
|
85 |
+
oauthlib==3.2.2
|
86 |
+
opencv-python==4.9.0.80
|
87 |
+
overrides==7.4.0
|
88 |
+
packaging==23.2
|
89 |
+
pandocfilters==1.5.1
|
90 |
+
pillow==10.2.0
|
91 |
+
platformdirs==4.1.0
|
92 |
+
proglog==0.1.10
|
93 |
+
prometheus-client==0.19.0
|
94 |
+
protobuf==4.23.4
|
95 |
+
psutil==5.9.7
|
96 |
+
pyasn1==0.5.1
|
97 |
+
pyasn1-modules==0.3.0
|
98 |
+
pycparser==2.21
|
99 |
+
pygame==2.5.2
|
100 |
+
PyOpenGL==3.1.7
|
101 |
+
pyparsing==3.1.1
|
102 |
+
python-dateutil==2.8.2
|
103 |
+
python-json-logger==2.0.7
|
104 |
+
PyYAML==6.0.1
|
105 |
+
pyzmq==25.1.2
|
106 |
+
qtconsole==5.5.1
|
107 |
+
QtPy==2.4.1
|
108 |
+
referencing==0.32.1
|
109 |
+
requests==2.31.0
|
110 |
+
requests-oauthlib==1.3.1
|
111 |
+
rfc3339-validator==0.1.4
|
112 |
+
rfc3986-validator==0.1.1
|
113 |
+
rpds-py==0.17.1
|
114 |
+
rsa==4.9
|
115 |
+
Send2Trash==1.8.2
|
116 |
+
Shimmy==0.2.1
|
117 |
+
sniffio==1.3.0
|
118 |
+
soupsieve==2.5
|
119 |
+
swig==4.1.1.post1
|
120 |
+
sympy==1.12
|
121 |
+
tensorboard==2.15.1
|
122 |
+
tensorboard-data-server==0.7.2
|
123 |
+
tensorflow-probability==0.23.0
|
124 |
+
terminado==0.18.0
|
125 |
+
tinycss2==1.2.1
|
126 |
+
tomli==2.0.1
|
127 |
+
torch==2.1.2
|
128 |
+
torchaudio==2.1.2
|
129 |
+
torchvision==0.16.2
|
130 |
+
wandb
|
131 |
+
tornado==6.4
|
132 |
+
tqdm==4.66.1
|
133 |
+
types-python-dateutil==2.8.19.20240106
|
134 |
+
typing_extensions==4.9.0
|
135 |
+
uri-template==1.3.0
|
136 |
+
urllib3==2.1.0
|
137 |
+
webcolors==1.13
|
138 |
+
webencodings==0.5.1
|
139 |
+
websocket-client==1.7.0
|
140 |
+
Werkzeug==3.0.1
|
141 |
+
widgetsnbextension==4.0.9
|
142 |
+
zipp==3.17.0
|
143 |
+
dm_control
|
utils/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
utils/__pycache__/buffer.cpython-310.pyc
ADDED
Binary file (5.29 kB). View file
|
|
utils/__pycache__/models.cpython-310.pyc
ADDED
Binary file (11.9 kB). View file
|
|
utils/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.53 kB). View file
|
|
utils/__pycache__/wrappers.cpython-310.pyc
ADDED
Binary file (10.1 kB). View file
|
|
utils/buffer.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Minh Pham-Dinh
|
3 |
+
Created: Jan 26th, 2024
|
4 |
+
Last Modified: Feb 5th, 2024
|
5 |
+
Email: [email protected]
|
6 |
+
|
7 |
+
Description:
|
8 |
+
File containing the ReplayBuffer that will be used in Dreamer.
|
9 |
+
|
10 |
+
The implementation is based on:
|
11 |
+
Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
|
12 |
+
[Online]. Available: https://arxiv.org/abs/1912.01603
|
13 |
+
"""
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
from gymnasium import Env
|
17 |
+
import torch
|
18 |
+
from addict import Dict
|
19 |
+
|
20 |
+
class ReplayBuffer:
|
21 |
+
def __init__(self, capacity, obs_size, action_size):
|
22 |
+
|
23 |
+
# check if the env is gymnasium or dmc
|
24 |
+
self.obs_size = obs_size
|
25 |
+
self.action_size = action_size
|
26 |
+
|
27 |
+
# from SimpleDreamer implementation, saving memory
|
28 |
+
state_type = np.uint8 if len(self.obs_size) < 3 else np.float32
|
29 |
+
|
30 |
+
self.observation = np.zeros((capacity, ) + self.obs_size, dtype=state_type)
|
31 |
+
|
32 |
+
self.actions = np.zeros((capacity, ) + self.action_size, dtype=np.float32)
|
33 |
+
self.rewards = np.zeros((capacity, 1), dtype=np.float32)
|
34 |
+
self.dones = np.zeros((capacity, 1), dtype=np.float32)
|
35 |
+
|
36 |
+
self.pointer = 0
|
37 |
+
self.full = False
|
38 |
+
|
39 |
+
print(f'''
|
40 |
+
-----------initialized memory----------
|
41 |
+
|
42 |
+
obs_buffer_shape: {self.observation.shape}
|
43 |
+
actions_buffer_shape: {self.actions.shape}
|
44 |
+
rewards_buffer_shape: {self.rewards.shape}
|
45 |
+
dones_buffer_shape: {self.dones.shape}
|
46 |
+
|
47 |
+
----------------------------------------
|
48 |
+
''')
|
49 |
+
|
50 |
+
def add(self, obs, action, reward, done):
|
51 |
+
"""Add method for buffer
|
52 |
+
|
53 |
+
Args:
|
54 |
+
obs (np.array): current observation
|
55 |
+
action (np.array): action taken
|
56 |
+
reward (float): reward received after action
|
57 |
+
next_obs (np.array): next observation
|
58 |
+
done (bool): boolean value of termination or truncation
|
59 |
+
"""
|
60 |
+
self.observation[self.pointer] = obs
|
61 |
+
self.actions[self.pointer] = action
|
62 |
+
self.rewards[self.pointer] = reward
|
63 |
+
self.dones[self.pointer] = done
|
64 |
+
self.pointer = (self.pointer + 1) % self.observation.shape[0]
|
65 |
+
if self.pointer == 0:
|
66 |
+
self.full = True
|
67 |
+
|
68 |
+
def sample(self, batch_size, seq_len, device):
|
69 |
+
"""
|
70 |
+
Samples batches of experiences of fixed sequence length from the replay buffer,
|
71 |
+
taking into account the circular nature of the buffer to avoid crossing the
|
72 |
+
"end" of the buffer when it is full.
|
73 |
+
|
74 |
+
This method ensures that sampled sequences are continuous and do not wrap around
|
75 |
+
the end of the buffer, maintaining the temporal integrity of experiences. This is
|
76 |
+
particularly important when the buffer is full, and the pointer marks the boundary
|
77 |
+
between the newest and oldest data in the buffer.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
batch_size (int): The number of sequences to sample.
|
81 |
+
seq_len (int): The length of each sequence to sample.
|
82 |
+
device (torch.device): The device on which the sampled data will be loaded.
|
83 |
+
|
84 |
+
Raises:
|
85 |
+
Exception: If there is not enough data in the buffer to sample a full sequence.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
Dict: A dictionary containing the sampled sequences of observations, actions,
|
89 |
+
rewards, and dones. Each item in the dictionary is a tensor of shape
|
90 |
+
(batch_size, seq_len, feature_dimension), except for 'dones' which is of shape
|
91 |
+
(batch_size, seq_len, 1).
|
92 |
+
|
93 |
+
Notes:
|
94 |
+
- The method handles different scenarios based on the buffer's state (full or not)
|
95 |
+
and the pointer's position to ensure valid sequence sampling without wrapping.
|
96 |
+
- When the buffer is not full, sequences can start from index 0 up to the
|
97 |
+
index where `seq_len` sequences can fit without surpassing the current pointer.
|
98 |
+
- When the buffer is full, the method ensures sequences do not start in a way
|
99 |
+
that would cause them to wrap around past the pointer, effectively crossing
|
100 |
+
the boundary between the newest and oldest data.
|
101 |
+
- This approach guarantees the sampled sequences respect the temporal order
|
102 |
+
and continuity necessary for algorithms that rely on sequences of experiences.
|
103 |
+
"""
|
104 |
+
|
105 |
+
# Ensure there's enough data to sample
|
106 |
+
if self.pointer < seq_len and not self.full:
|
107 |
+
raise Exception('not enough data to sample')
|
108 |
+
|
109 |
+
# detail: handling different cases for circular sampling
|
110 |
+
if self.full:
|
111 |
+
if self.pointer - seq_len < 0:
|
112 |
+
valid_range = np.arange(self.pointer, self.observation.shape[0] - (self.pointer - seq_len) + 1)
|
113 |
+
else:
|
114 |
+
range_1 = np.arange(0, self.pointer - seq_len + 1)
|
115 |
+
range_2 = np.arange(self.pointer, self.observation.shape[0])
|
116 |
+
valid_range = np.concatenate((range_1, range_2), -1)
|
117 |
+
else:
|
118 |
+
valid_range = np.arange(0, self.pointer-seq_len+1)
|
119 |
+
|
120 |
+
start_index = np.random.choice(valid_range, (batch_size, 1))
|
121 |
+
|
122 |
+
seq_len = np.arange(seq_len)
|
123 |
+
sample_idcs = (start_index + seq_len) % self.observation.shape[0]
|
124 |
+
|
125 |
+
batch = Dict()
|
126 |
+
|
127 |
+
batch.obs = torch.from_numpy(self.observation[sample_idcs]).to(device)
|
128 |
+
batch.actions = torch.from_numpy(self.actions[sample_idcs]).to(device)
|
129 |
+
batch.rewards = torch.from_numpy(self.rewards[sample_idcs]).to(device)
|
130 |
+
batch.dones = torch.from_numpy(self.dones[sample_idcs]).to(device)
|
131 |
+
|
132 |
+
return batch
|
133 |
+
|
134 |
+
def clear(self, ):
|
135 |
+
self.pointer = 0
|
136 |
+
self.full = False
|
137 |
+
|
138 |
+
def __len__(self, ):
|
139 |
+
return self.pointer
|
utils/models.py
ADDED
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Minh Pham-Dinh
|
3 |
+
Created: Jan 26th, 2024
|
4 |
+
Last Modified: Feb 10th, 2024
|
5 |
+
Email: [email protected]
|
6 |
+
|
7 |
+
Description:
|
8 |
+
File containing all models that will be used in Dreamer.
|
9 |
+
|
10 |
+
The implementation is based on:
|
11 |
+
Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
|
12 |
+
[Online]. Available: https://arxiv.org/abs/1912.01603
|
13 |
+
"""
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
def initialize_weights(m):
|
21 |
+
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
22 |
+
nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu")
|
23 |
+
nn.init.constant_(m.bias.data, 0)
|
24 |
+
elif isinstance(m, nn.Linear):
|
25 |
+
nn.init.kaiming_uniform_(m.weight.data)
|
26 |
+
nn.init.constant_(m.bias.data, 0)
|
27 |
+
|
28 |
+
|
29 |
+
class RSSM(nn.Module):
|
30 |
+
"""Reccurent State Space Model (RSSM)
|
31 |
+
The main model that we will use to learn the latent dynamic of the environment
|
32 |
+
"""
|
33 |
+
def __init__(self, stochastic_size, obs_embed_size, deterministic_size, hidden_size, action_size, activation=nn.ELU):
|
34 |
+
super().__init__()
|
35 |
+
self.stochastic_size = stochastic_size
|
36 |
+
self.action_size = action_size
|
37 |
+
self.deterministic_size = deterministic_size
|
38 |
+
self.obs_embed_size = obs_embed_size
|
39 |
+
self.action_size = action_size
|
40 |
+
|
41 |
+
# recurrent
|
42 |
+
self.recurrent_linear = nn.Sequential(
|
43 |
+
nn.Linear(stochastic_size + action_size, hidden_size),
|
44 |
+
activation(),
|
45 |
+
)
|
46 |
+
self.gru_cell = nn.GRUCell(hidden_size, deterministic_size)
|
47 |
+
|
48 |
+
# representation model, for calculating posterior
|
49 |
+
self.representatio_model = nn.Sequential(
|
50 |
+
nn.Linear(deterministic_size + obs_embed_size, hidden_size),
|
51 |
+
activation(),
|
52 |
+
nn.Linear(hidden_size, stochastic_size*2)
|
53 |
+
)
|
54 |
+
|
55 |
+
# transition model, for calculating prior, use for imagining trajectories
|
56 |
+
self.transition_model = nn.Sequential(
|
57 |
+
nn.Linear(deterministic_size, hidden_size),
|
58 |
+
activation(),
|
59 |
+
nn.Linear(hidden_size, stochastic_size*2)
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def recurrent(self, stoch_state, action, deterministic):
|
65 |
+
"""The recurrent model, calculate the deterministic state given the stochastic state
|
66 |
+
the action, and the prior deterministic
|
67 |
+
|
68 |
+
Args:
|
69 |
+
a_t-1 (batch_size, action_size): action at time step, cannot be None.
|
70 |
+
s_t-1 (batch_size, stoch_size): stochastic state at time step. Defaults to None.
|
71 |
+
h_t-1 (batch_size, deterministic_size): deterministic at timestep. Defaults to None.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
h_t: deterministic at next time step
|
75 |
+
"""
|
76 |
+
|
77 |
+
# initialize some sizes
|
78 |
+
x = torch.cat((action, stoch_state), -1)
|
79 |
+
out = self.recurrent_linear(x)
|
80 |
+
out = self.gru_cell(out, deterministic)
|
81 |
+
return out
|
82 |
+
|
83 |
+
|
84 |
+
def representation(self, embed_obs, deterministic):
|
85 |
+
"""Calculate the distribution p of the stochastic state.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
o_t (batch_size, embeded_obs_size): embedded observation (encoded)
|
89 |
+
h_t (batch_size, deterministic_size): determinstic size
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
s_t posterior_distribution: distribution of stochastic states
|
93 |
+
s_t posterior: sampled stochastic states
|
94 |
+
"""
|
95 |
+
x = torch.cat((embed_obs, deterministic), -1)
|
96 |
+
out = self.representatio_model(x)
|
97 |
+
mean, std = torch.chunk(out, 2, -1)
|
98 |
+
std = F.softplus(std) + 0.1
|
99 |
+
|
100 |
+
post_dist = torch.distributions.Normal(mean, std)
|
101 |
+
post = post_dist.rsample()
|
102 |
+
|
103 |
+
return post_dist, post
|
104 |
+
|
105 |
+
|
106 |
+
def transition(self, deterministic):
|
107 |
+
"""Calculate the distribution q of the stochastic state.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
h_t (batch_size, deterministic_size): determinstic size
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
s_t prior_distribution: distribution of stochastic states
|
114 |
+
s_t prior: sampled stochastic states
|
115 |
+
"""
|
116 |
+
out = self.transition_model(deterministic)
|
117 |
+
mean, std = torch.chunk(out, 2, -1)
|
118 |
+
std = F.softplus(std) + 0.1
|
119 |
+
|
120 |
+
prior_dist = torch.distributions.Normal(mean, std)
|
121 |
+
prior = prior_dist.rsample()
|
122 |
+
return prior_dist, prior
|
123 |
+
|
124 |
+
|
125 |
+
class ConvEncoder(nn.Module):
|
126 |
+
def __init__(self, depth=32, input_shape=(3,64,64), activation=nn.ReLU):
|
127 |
+
super().__init__()
|
128 |
+
self.depth = depth
|
129 |
+
self.input_shape = input_shape
|
130 |
+
self.conv_layer = nn.Sequential(
|
131 |
+
nn.Conv2d(
|
132 |
+
in_channels=input_shape[0],
|
133 |
+
out_channels=depth * 1,
|
134 |
+
kernel_size=4,
|
135 |
+
stride=2,
|
136 |
+
padding="valid"
|
137 |
+
),
|
138 |
+
activation(),
|
139 |
+
nn.Conv2d(
|
140 |
+
in_channels=depth * 1,
|
141 |
+
out_channels=depth * 2,
|
142 |
+
kernel_size=4,
|
143 |
+
stride=2,
|
144 |
+
padding="valid"
|
145 |
+
),
|
146 |
+
activation(),
|
147 |
+
nn.Conv2d(
|
148 |
+
in_channels=depth * 2,
|
149 |
+
out_channels=depth * 4,
|
150 |
+
kernel_size=4,
|
151 |
+
stride=2,
|
152 |
+
padding="valid"
|
153 |
+
),
|
154 |
+
activation(),
|
155 |
+
nn.Conv2d(
|
156 |
+
in_channels=depth * 4,
|
157 |
+
out_channels=depth * 8,
|
158 |
+
kernel_size=4,
|
159 |
+
stride=2,
|
160 |
+
padding="valid"
|
161 |
+
),
|
162 |
+
activation()
|
163 |
+
)
|
164 |
+
self.conv_layer.apply(initialize_weights)
|
165 |
+
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
batch_shape = x.shape[:-len(self.input_shape)]
|
169 |
+
if not batch_shape:
|
170 |
+
batch_shape = (1, )
|
171 |
+
|
172 |
+
x = x.reshape(-1, *self.input_shape)
|
173 |
+
|
174 |
+
out = self.conv_layer(x)
|
175 |
+
|
176 |
+
#flatten output
|
177 |
+
return out.reshape(*batch_shape, -1)
|
178 |
+
|
179 |
+
|
180 |
+
class ConvDecoder(nn.Module):
|
181 |
+
"""Decode latent dynamic
|
182 |
+
Also referred to as observation model by the official Dreamer paper
|
183 |
+
|
184 |
+
"""
|
185 |
+
def __init__(self, stochastic_size, deterministic_size, depth=32, out_shape=(3,64,64), activation=nn.ReLU):
|
186 |
+
super().__init__()
|
187 |
+
self.out_shape = out_shape
|
188 |
+
self.net = nn.Sequential(
|
189 |
+
nn.Linear(deterministic_size + stochastic_size, depth*32),
|
190 |
+
nn.Unflatten(1, (depth * 32, 1)),
|
191 |
+
nn.Unflatten(2, (1, 1)),
|
192 |
+
nn.ConvTranspose2d(
|
193 |
+
depth * 32,
|
194 |
+
depth * 4,
|
195 |
+
kernel_size=5,
|
196 |
+
stride=2,
|
197 |
+
),
|
198 |
+
activation(),
|
199 |
+
nn.ConvTranspose2d(
|
200 |
+
depth * 4,
|
201 |
+
depth * 2,
|
202 |
+
kernel_size=5,
|
203 |
+
stride=2,
|
204 |
+
),
|
205 |
+
activation(),
|
206 |
+
nn.ConvTranspose2d(
|
207 |
+
depth * 2,
|
208 |
+
depth * 1,
|
209 |
+
kernel_size=5 + 1,
|
210 |
+
stride=2,
|
211 |
+
),
|
212 |
+
activation(),
|
213 |
+
nn.ConvTranspose2d(
|
214 |
+
depth * 1,
|
215 |
+
out_shape[0],
|
216 |
+
kernel_size=5+1,
|
217 |
+
stride=2,
|
218 |
+
),
|
219 |
+
)
|
220 |
+
self.net.apply(initialize_weights)
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
def forward(self, posterior, deterministic, mps_flatten=False):
|
225 |
+
"""take in the stochastic state (posterior) and deterministic to construct the latent state then
|
226 |
+
output reconstructed pixel observation
|
227 |
+
|
228 |
+
Args:
|
229 |
+
s_t (batch_sz, stoch_size): stochastic state (or posterior)
|
230 |
+
h_t (batch_sz, deterministic_size): deterministic state
|
231 |
+
mps_flatten (boolean): whether to flattening the output for mps device or not. This is because M1 GPU can
|
232 |
+
only support max 4 dimension (stupid af)
|
233 |
+
Returns:
|
234 |
+
o'_t: reconstructed_obs
|
235 |
+
"""
|
236 |
+
x = torch.cat((posterior, deterministic), -1)
|
237 |
+
batch_shape = x.shape[:-1]
|
238 |
+
if not batch_shape:
|
239 |
+
batch_shape = (1, )
|
240 |
+
|
241 |
+
x = x.reshape(-1, x.shape[-1])
|
242 |
+
|
243 |
+
if mps_flatten:
|
244 |
+
batch_shape = (-1, )
|
245 |
+
|
246 |
+
mean = self.net(x).reshape(*batch_shape, *self.out_shape)
|
247 |
+
|
248 |
+
dist = torch.distributions.Normal(mean, 1)
|
249 |
+
|
250 |
+
# #flatten output
|
251 |
+
return torch.distributions.Independent(dist, len(self.out_shape))
|
252 |
+
|
253 |
+
|
254 |
+
class RewardNet(nn.Module):
|
255 |
+
"""reward prediction model. It take in the stochastic state and the deterministic to construct
|
256 |
+
latent state. It then output the reward prediciton
|
257 |
+
|
258 |
+
Args:
|
259 |
+
nn (_type_): _description_
|
260 |
+
"""
|
261 |
+
def __init__(self, input_size, hidden_size, activation=nn.ELU):
|
262 |
+
super().__init__()
|
263 |
+
|
264 |
+
self.net = nn.Sequential(
|
265 |
+
nn.Linear(input_size, hidden_size),
|
266 |
+
activation(),
|
267 |
+
nn.Linear(hidden_size, 1)
|
268 |
+
)
|
269 |
+
|
270 |
+
|
271 |
+
def forward(self, stoch_state, deterministic):
|
272 |
+
"""take in the stochastic state and deterministic to construct the latent state then
|
273 |
+
output reard prediction
|
274 |
+
|
275 |
+
Args:
|
276 |
+
s_t (batch_sz, stoch_size): stochastic state (or posterior)
|
277 |
+
h_t (batch_sz, deterministic_size): deterministic state
|
278 |
+
|
279 |
+
Returns:
|
280 |
+
r_t: rewards
|
281 |
+
"""
|
282 |
+
x = torch.cat((stoch_state, deterministic), -1)
|
283 |
+
batch_shape = x.shape[:-1]
|
284 |
+
if not batch_shape:
|
285 |
+
batch_shape = (1, )
|
286 |
+
|
287 |
+
x = x.reshape(-1, x.shape[-1])
|
288 |
+
|
289 |
+
return self.net(x).reshape(*batch_shape, 1)
|
290 |
+
|
291 |
+
|
292 |
+
class ContinuoNet(nn.Module):
|
293 |
+
"""continuity prediction model. It take in the stochastic state and the deterministic to construct
|
294 |
+
latent state. It then output the prediction of whether the termination state has been reached
|
295 |
+
|
296 |
+
Args:
|
297 |
+
nn (_type_): _description_
|
298 |
+
"""
|
299 |
+
def __init__(self, input_size, hidden_size, activation=nn.ELU):
|
300 |
+
super().__init__()
|
301 |
+
|
302 |
+
self.net = nn.Sequential(
|
303 |
+
nn.Linear(input_size, hidden_size),
|
304 |
+
activation(),
|
305 |
+
nn.Linear(hidden_size, hidden_size),
|
306 |
+
activation(),
|
307 |
+
nn.Linear(hidden_size, 1)
|
308 |
+
)
|
309 |
+
|
310 |
+
|
311 |
+
def forward(self, stoch_state, deterministic):
|
312 |
+
"""take in the stochastic state and deterministic to construct the latent state then
|
313 |
+
output reard prediction
|
314 |
+
|
315 |
+
Args:
|
316 |
+
s_t stoch_state (batch_sz, stoch_size): stochastic state (or posterior)
|
317 |
+
h_t deterministic (batch_sz, deterministic_size): deterministic state
|
318 |
+
|
319 |
+
Returns:
|
320 |
+
dist: Beurnoulli distribution of done
|
321 |
+
"""
|
322 |
+
x = torch.cat((stoch_state, deterministic), -1)
|
323 |
+
batch_shape = x.shape[:-1]
|
324 |
+
if not batch_shape:
|
325 |
+
batch_shape = (1, )
|
326 |
+
|
327 |
+
x = x.reshape(-1, x.shape[-1])
|
328 |
+
|
329 |
+
x = self.net(x).reshape(*batch_shape, 1)
|
330 |
+
return x, torch.distributions.Independent(torch.distributions.Bernoulli(logits=x), 1)
|
331 |
+
|
332 |
+
|
333 |
+
class Actor(nn.Module):
|
334 |
+
"""actor network
|
335 |
+
"""
|
336 |
+
def __init__(self,
|
337 |
+
latent_size,
|
338 |
+
hidden_size,
|
339 |
+
action_size,
|
340 |
+
discrete=True,
|
341 |
+
activation=nn.ELU,
|
342 |
+
min_std=1e-4,
|
343 |
+
init_std=5,
|
344 |
+
mean_scale=5):
|
345 |
+
|
346 |
+
super().__init__()
|
347 |
+
self.latent_size = latent_size
|
348 |
+
self.hidden_size = hidden_size
|
349 |
+
self.action_size = (action_size if discrete else action_size*2)
|
350 |
+
self.discrete = discrete
|
351 |
+
self.min_std=min_std
|
352 |
+
self.init_std = init_std
|
353 |
+
self.mean_scale = mean_scale
|
354 |
+
|
355 |
+
self.net = nn.Sequential(
|
356 |
+
nn.Linear(latent_size, hidden_size),
|
357 |
+
activation(),
|
358 |
+
nn.Linear(hidden_size, self.action_size)
|
359 |
+
)
|
360 |
+
|
361 |
+
|
362 |
+
def forward(self, stoch_state, deterministic):
|
363 |
+
"""actor network. get in stochastic state and deterministic state to construct latent state
|
364 |
+
and then use latent state to predict appropriate action
|
365 |
+
|
366 |
+
Args:
|
367 |
+
s_t stoch_state (batch_sz, stoch_size): stochastic state (or posterior)
|
368 |
+
h_t deterministic (batch_sz, deterministic_size): deterministic state
|
369 |
+
|
370 |
+
Returns:
|
371 |
+
action distribution. OneHot if discrete, else is tanhNormal
|
372 |
+
"""
|
373 |
+
latent_state = torch.cat((stoch_state, deterministic), -1)
|
374 |
+
x = self.net(latent_state)
|
375 |
+
|
376 |
+
if self.discrete:
|
377 |
+
# straight through gradient (mentioned in DreamerV2)
|
378 |
+
dist = torch.distributions.OneHotCategorical(logits=x)
|
379 |
+
action = dist.sample() + dist.probs - dist.probs.detach()
|
380 |
+
else:
|
381 |
+
#ensure that the softplut output proper init_std
|
382 |
+
raw_init_std = np.log(np.exp(self.init_std) - 1)
|
383 |
+
|
384 |
+
mean, std = torch.chunk(x, 2, -1)
|
385 |
+
mean = self.mean_scale * F.tanh(mean / self.mean_scale)
|
386 |
+
std = F.softplus(std + raw_init_std) + self.min_std
|
387 |
+
|
388 |
+
dist = torch.distributions.Normal(mean, std)
|
389 |
+
dist = torch.distributions.TransformedDistribution(dist, torch.distributions.TanhTransform())
|
390 |
+
action = torch.distributions.Independent(dist, 1).rsample()
|
391 |
+
|
392 |
+
return action
|
393 |
+
|
394 |
+
|
395 |
+
class Critic(nn.Module):
|
396 |
+
"""
|
397 |
+
critic network
|
398 |
+
"""
|
399 |
+
def __init__(self, latent_size, hidden_size, activation=nn.ELU):
|
400 |
+
super().__init__()
|
401 |
+
self.latent_size = latent_size
|
402 |
+
|
403 |
+
self.net = nn.Sequential(
|
404 |
+
nn.Linear(latent_size, hidden_size),
|
405 |
+
activation(),
|
406 |
+
nn.Linear(hidden_size, hidden_size),
|
407 |
+
activation(),
|
408 |
+
nn.Linear(hidden_size, 1)
|
409 |
+
)
|
410 |
+
|
411 |
+
|
412 |
+
|
413 |
+
def forward(self, stoch_state, deterministic):
|
414 |
+
"""critic network. get in stochastic state and deterministic state to construct latent state
|
415 |
+
and then use latent state to predict state value
|
416 |
+
|
417 |
+
Args:
|
418 |
+
s_t stoch_state (batch_sz, seq_len, stoch_size): stochastic state (or posterior)
|
419 |
+
h_t deterministic (batch_sz, seq_len, deterministic_size): deterministic state
|
420 |
+
|
421 |
+
Returns:
|
422 |
+
state value distribution.
|
423 |
+
"""
|
424 |
+
latent_state = torch.cat((stoch_state, deterministic), -1)
|
425 |
+
|
426 |
+
batch_shape = latent_state.shape[:-1]
|
427 |
+
if not batch_shape:
|
428 |
+
batch_shape = (1, )
|
429 |
+
|
430 |
+
latent_state = latent_state.reshape(-1, self.latent_size)
|
431 |
+
|
432 |
+
x = self.net(latent_state)
|
433 |
+
|
434 |
+
x = x.reshape(*batch_shape, 1)
|
435 |
+
|
436 |
+
dist = torch.distributions.Normal(x, 1)
|
437 |
+
dist = torch.distributions.Independent(dist, 1)
|
438 |
+
|
439 |
+
return dist
|
440 |
+
|
utils/utils.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def log_metrics(metrics, step, tb_writer, wandb_writer):
|
4 |
+
# Log metrics to TensorBoard
|
5 |
+
if tb_writer:
|
6 |
+
for key, value in metrics.items():
|
7 |
+
tb_writer.add_scalar(key, value, step)
|
8 |
+
|
9 |
+
# Log metrics to wandb
|
10 |
+
# if wandb_writer:
|
11 |
+
# wandb_writer.log(metrics, step=step)
|
12 |
+
|
13 |
+
|
14 |
+
def td_lambda(rewards, predicted_discount, values, lambda_, device):
|
15 |
+
"""
|
16 |
+
Compute the TD(λ) returns for value estimation.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
- rewards (Tensor): Tensor of rewards with shape [batch_size, horizon_len, 1].
|
20 |
+
- predicted_discount (Tensor): Tensor indicating probability of episode termination with shape [batch_size, horizon_len, 1].
|
21 |
+
- values (Tensor): Tensor of value estimates with shape [batch_size, horizon_len, 1].
|
22 |
+
- lambda_ (float): The λ parameter in TD(λ) controlling bias-variance tradeoff.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
- td_lambda (Tensor): The computed lambda returns with shape [batch_size, time_steps - 1].
|
26 |
+
"""
|
27 |
+
batch_size, _, _ = rewards.shape
|
28 |
+
last_lambda = torch.zeros((batch_size, 1)).to(device)
|
29 |
+
cur_rewards = rewards[:, :-1]
|
30 |
+
next_values = values[:, 1:]
|
31 |
+
predicted_discount = predicted_discount[:, :-1]
|
32 |
+
|
33 |
+
td_1 = cur_rewards + predicted_discount * next_values * (1 - lambda_)
|
34 |
+
returns = torch.zeros_like(cur_rewards).to(device)
|
35 |
+
|
36 |
+
|
37 |
+
for i in reversed(range(td_1.size(1))):
|
38 |
+
last_lambda = td_1[:, i] + predicted_discount[:, i] * lambda_ * last_lambda
|
39 |
+
returns[:, i] = last_lambda
|
40 |
+
|
41 |
+
return returns
|
utils/wrappers.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Minh Pham-Dinh
|
3 |
+
Created: Feb 4th, 2024
|
4 |
+
Last Modified: Feb 7th, 2024
|
5 |
+
Email: [email protected]
|
6 |
+
|
7 |
+
Description:
|
8 |
+
File containing wrappers for different environment types.
|
9 |
+
"""
|
10 |
+
|
11 |
+
import gymnasium as gym
|
12 |
+
from dm_control import suite
|
13 |
+
from dm_control.suite.wrappers import pixels
|
14 |
+
import numpy as np
|
15 |
+
import cv2
|
16 |
+
import os
|
17 |
+
from dm_control import suite
|
18 |
+
from dm_control.rl.control import Environment
|
19 |
+
|
20 |
+
#wrapper by Hafner et al
|
21 |
+
class ActionRepeat:
|
22 |
+
def __init__(self, env, repeats):
|
23 |
+
self.env = env
|
24 |
+
self.repeats = repeats
|
25 |
+
|
26 |
+
def __getattr__(self, name):
|
27 |
+
return getattr(self.env, name)
|
28 |
+
|
29 |
+
def step(self, action):
|
30 |
+
done = False
|
31 |
+
total_reward = 0
|
32 |
+
current_step = 0
|
33 |
+
while current_step < self.repeats and not done:
|
34 |
+
obs, reward, termination, truncation, info = self.env.step(action)
|
35 |
+
total_reward += reward
|
36 |
+
current_step += 1
|
37 |
+
done = termination or truncation
|
38 |
+
return obs, total_reward, termination, truncation, info
|
39 |
+
|
40 |
+
|
41 |
+
#wrapper by Hafner et al
|
42 |
+
class NormalizeActions:
|
43 |
+
"""
|
44 |
+
A wrapper class that normalizes the action space of an environment.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
env (gym.Env): The environment to be wrapped.
|
48 |
+
|
49 |
+
Attributes:
|
50 |
+
_env (gym.Env): The original environment.
|
51 |
+
_mask (numpy.ndarray): A boolean mask indicating which action dimensions are finite.
|
52 |
+
_low (numpy.ndarray): The lower bounds of the action space.
|
53 |
+
_high (numpy.ndarray): The upper bounds of the action space.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def __init__(self, env):
|
57 |
+
self._env = env
|
58 |
+
self._mask = np.logical_and(
|
59 |
+
np.isfinite(env.action_space.low),
|
60 |
+
np.isfinite(env.action_space.high))
|
61 |
+
self._low = np.where(self._mask, env.action_space.low, -1)
|
62 |
+
self._high = np.where(self._mask, env.action_space.high, 1)
|
63 |
+
|
64 |
+
def __getattr__(self, name):
|
65 |
+
"""
|
66 |
+
Delegate attribute access to the original environment.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
name (str): The name of the attribute.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
Any: The value of the attribute in the original environment.
|
73 |
+
"""
|
74 |
+
return getattr(self._env, name)
|
75 |
+
|
76 |
+
@property
|
77 |
+
def action_space(self):
|
78 |
+
"""
|
79 |
+
Get the normalized action space.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
gym.spaces.Box: The normalized action space.
|
83 |
+
"""
|
84 |
+
low = np.where(self._mask, -np.ones_like(self._low), self._low)
|
85 |
+
high = np.where(self._mask, np.ones_like(self._low), self._high)
|
86 |
+
return gym.spaces.Box(low, high, dtype=np.float32)
|
87 |
+
|
88 |
+
def step(self, action):
|
89 |
+
"""
|
90 |
+
Take a step in the environment with a normalized action.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
action (numpy.ndarray): The normalized action.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
Tuple: A tuple containing the next state, reward, done flag, and additional information.
|
97 |
+
"""
|
98 |
+
original = (action + 1) / 2 * (self._high - self._low) + self._low
|
99 |
+
original = np.where(self._mask, original, action)
|
100 |
+
return self._env.step(original)
|
101 |
+
|
102 |
+
|
103 |
+
class DMCtoGymWrapper(gym.Env):
|
104 |
+
"""
|
105 |
+
Wrapper to convert a DeepMind Control Suite environment to a Gymnasium environment with additional features like recording and episode truncation.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
domain_name (str): The name of the domain.
|
109 |
+
task_name (str): The name of the task.
|
110 |
+
task_kwargs (dict, optional): Additional kwargs for the task.
|
111 |
+
visualize_reward (bool, optional): Whether to visualize the reward. Defaults to False.
|
112 |
+
resize (list, optional): New size to resize observations. Defaults to [64, 64].
|
113 |
+
record (bool, optional): Whether to record episodes. Defaults to False.
|
114 |
+
record_freq (int, optional): Frequency (in episodes) to record. Defaults to 100.
|
115 |
+
record_path (str, optional): Path to save recorded videos. Defaults to '../'.
|
116 |
+
max_episode_steps (int, optional): Maximum steps per episode for truncation. Defaults to 1000.
|
117 |
+
"""
|
118 |
+
def __init__(self, domain_name, task_name, task_kwargs=None, visualize_reward=False, resize=[64,64], record=False, record_freq=100, record_path='../', max_episode_steps=1000, camera=None):
|
119 |
+
super().__init__()
|
120 |
+
self.env = suite.load(domain_name, task_name, task_kwargs=task_kwargs, visualize_reward=visualize_reward)
|
121 |
+
self.episode_count = -1
|
122 |
+
self.record = record
|
123 |
+
self.record_freq = record_freq
|
124 |
+
self.record_path = record_path
|
125 |
+
self.max_episode_steps = max_episode_steps
|
126 |
+
self.current_step = 0
|
127 |
+
self.total_reward = 0
|
128 |
+
self.recorder = None
|
129 |
+
|
130 |
+
# Define action and observation space based on the DMC environment
|
131 |
+
action_spec = self.env.action_spec()
|
132 |
+
self.action_space = gym.spaces.Box(low=action_spec.minimum, high=action_spec.maximum, dtype=np.float32)
|
133 |
+
|
134 |
+
# Initialize the pixels wrapper for observation space
|
135 |
+
self.env = pixels.Wrapper(self.env, pixels_only=True)
|
136 |
+
self.resize = resize # Assuming RGB images
|
137 |
+
self.observation_space = gym.spaces.Box(low=-0.5, high=+0.5, shape=(3, *resize), dtype=np.float32)
|
138 |
+
|
139 |
+
if camera is None:
|
140 |
+
camera = dict(quadruped=2).get(domain_name, 0)
|
141 |
+
self._camera = camera
|
142 |
+
|
143 |
+
def step(self, action):
|
144 |
+
time_step = self.env.step(action)
|
145 |
+
obs = self._get_obs(self.env)
|
146 |
+
|
147 |
+
reward = time_step.reward if time_step.reward is not None else 0
|
148 |
+
self.total_reward += (reward or 0)
|
149 |
+
self.current_step += 1
|
150 |
+
|
151 |
+
termination = time_step.last()
|
152 |
+
truncation = (self.current_step == self.max_episode_steps)
|
153 |
+
info = {}
|
154 |
+
if termination or truncation:
|
155 |
+
info = {
|
156 |
+
'episode': {
|
157 |
+
'r': [self.total_reward],
|
158 |
+
'l': self.current_step
|
159 |
+
}
|
160 |
+
}
|
161 |
+
|
162 |
+
if self.recorder:
|
163 |
+
frame = cv2.cvtColor(self.env.physics.render(camera_id=self._camera), cv2.COLOR_RGB2BGR)
|
164 |
+
self.recorder.write(frame)
|
165 |
+
video_file = os.path.join(self.record_path, f"episode_{self.episode_count}.webm")
|
166 |
+
if termination or truncation:
|
167 |
+
self._reset_recorder()
|
168 |
+
info['video_path'] = video_file
|
169 |
+
|
170 |
+
return obs, reward, termination, truncation, info
|
171 |
+
|
172 |
+
def reset(self):
|
173 |
+
self.current_step = 0
|
174 |
+
self.total_reward = 0
|
175 |
+
self.episode_count += 1
|
176 |
+
|
177 |
+
time_step = self.env.reset()
|
178 |
+
obs = self._get_obs(self.env)
|
179 |
+
|
180 |
+
if self.record and self.episode_count % self.record_freq == 0:
|
181 |
+
self._start_recording(self.env.physics.render(camera_id=self._camera))
|
182 |
+
|
183 |
+
return obs, {}
|
184 |
+
|
185 |
+
def _start_recording(self, frame):
|
186 |
+
if not os.path.exists(self.record_path):
|
187 |
+
os.makedirs(self.record_path)
|
188 |
+
video_file = os.path.join(self.record_path, f"episode_{self.episode_count}.webm")
|
189 |
+
height, width, _ = frame.shape
|
190 |
+
self.recorder = cv2.VideoWriter(video_file, cv2.VideoWriter_fourcc(*'vp80'), 30, (width, height))
|
191 |
+
self.recorder.write(frame)
|
192 |
+
|
193 |
+
def _reset_recorder(self):
|
194 |
+
if self.recorder:
|
195 |
+
self.recorder.release()
|
196 |
+
self.recorder = None
|
197 |
+
|
198 |
+
def _get_obs(self, env):
|
199 |
+
obs = self.render()
|
200 |
+
obs = obs/255 - 0.5
|
201 |
+
rearranged_obs = obs.transpose([2,0,1])
|
202 |
+
return rearranged_obs
|
203 |
+
|
204 |
+
def render(self, mode='rgb_array'):
|
205 |
+
return self.env.physics.render(*self.resize, camera_id=self._camera) # Adjust camera_id based on the environment
|
206 |
+
|
207 |
+
|
208 |
+
class AtariPreprocess(gym.Wrapper):
|
209 |
+
"""
|
210 |
+
A custom Gym wrapper that integrates multiple environment processing steps:
|
211 |
+
- Records episode statistics and videos.
|
212 |
+
- Resizes observations to a specified shape.
|
213 |
+
- Scales and reorders observation channels.
|
214 |
+
- Scales rewards using the tanh function.
|
215 |
+
|
216 |
+
Parameters:
|
217 |
+
- env (gym.Env): The original environment to wrap.
|
218 |
+
- new_obs_size (tuple): The target size for observation resizing (height, width).
|
219 |
+
- record (bool): If True, enable video recording.
|
220 |
+
- record_path (str): The directory path where videos will be saved.
|
221 |
+
- record_freq (int): Frequency (in episodes) at which to record videos.
|
222 |
+
"""
|
223 |
+
def __init__(self, env, new_obs_size, record=False, record_path='../videos/', record_freq=100):
|
224 |
+
super().__init__(env)
|
225 |
+
self.env = gym.wrappers.RecordEpisodeStatistics(env)
|
226 |
+
|
227 |
+
if record:
|
228 |
+
self.env = gym.wrappers.RecordVideo(self.env, record_path, episode_trigger=lambda episode_id: episode_id % record_freq == 0)
|
229 |
+
self.env = gym.wrappers.ResizeObservation(self.env, shape=new_obs_size)
|
230 |
+
|
231 |
+
self.new_obs_size = new_obs_size
|
232 |
+
self.observation_space = gym.spaces.Box(
|
233 |
+
low=-0.5, high=0.5,
|
234 |
+
shape=(3, new_obs_size[0], new_obs_size[1]),
|
235 |
+
dtype=np.float32
|
236 |
+
)
|
237 |
+
|
238 |
+
def step(self, action):
|
239 |
+
obs, reward, termination, truncation, info = super().step(action)
|
240 |
+
obs = self.process_observation(obs)
|
241 |
+
reward = np.tanh(reward) # Scale reward
|
242 |
+
return obs, reward, termination, truncation, info
|
243 |
+
|
244 |
+
def reset(self, **kwargs):
|
245 |
+
obs, info = super().reset(**kwargs)
|
246 |
+
obs = self.process_observation(obs)
|
247 |
+
return obs, info
|
248 |
+
|
249 |
+
def process_observation(self, observation):
|
250 |
+
"""
|
251 |
+
Process and return the observation from the environment.
|
252 |
+
- Scales pixel values to the range [-0.5, 0.5].
|
253 |
+
- Reorders channels to CHW format (channels, height, width).
|
254 |
+
|
255 |
+
Parameters:
|
256 |
+
- observation (np.ndarray): The original observation from the environment.
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
- np.ndarray: The processed observation.
|
260 |
+
"""
|
261 |
+
if 'pixels' in observation:
|
262 |
+
observation = observation['pixels']
|
263 |
+
observation = observation / 255.0 - 0.5
|
264 |
+
observation = np.transpose(observation, (2, 0, 1))
|
265 |
+
return observation
|