IliaLarchenko commited on
Commit
6f108e3
·
verified ·
1 Parent(s): f596fc3

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +38 -0
  2. config.json +48 -0
  3. config.yaml +225 -0
  4. model.safetensors +3 -0
  5. replay.mp4 +0 -0
README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: lerobot
3
+ tags:
4
+ - model_hub_mixin
5
+ - pytorch_model_hub_mixin
6
+ - robotics
7
+ - dot
8
+ license: apache-2.0
9
+ datasets:
10
+ - lerobot/pusht_keypoints
11
+ pipeline_tag: robotics
12
+ ---
13
+
14
+ # Model Card for "Decoder Only Transformer (DOT) Policy" for PushT keypoints dataset
15
+
16
+ Read more about the model and implementation details in the [DOT Policy repository](https://github.com/IliaLarchenko/dot_policy).
17
+
18
+ This model is trained using the [LeRobot library](https://huggingface.co/lerobot) and achieves state-of-the-art results on behavior cloning on the PushT keypoints dataset. It achieves an 84.5% success rate (and 0.964 average max reward) vs. ~78% for the previous state-of-the-art model, or 69% that I managed to reproduce using the VQ-BET implementation in LeRobot.
19
+
20
+ This result is achieved without checkpoint selection. If you are interested in an even better model with a success rate of ~94% (but harder to reproduce, as it requires some parameter tuning and checkpoint selection), please refer to [this model](https://huggingface.co/IliaLarchenko/dot_pusht_keypoints_best).
21
+
22
+ You can use this model by installing LeRobot from [this branch](https://github.com/IliaLarchenko/lerobot/tree/dot)
23
+
24
+ To train the model:
25
+
26
+ ```bash
27
+ python lerobot/scripts/train.py policy=dot_pusht_keypoints env=pusht env.obs_type=environment_state_agent_pos
28
+ ```
29
+
30
+ To evaluate the model:
31
+
32
+ ```bash
33
+ python lerobot/scripts/eval.py -p IliaLarchenko/dot_pusht_keypoints eval.n_episodes=1000 eval.batch_size=100 seed=1000000
34
+ ```
35
+
36
+ Model size:
37
+ - Total parameters: 2.1m
38
+ - Trainable parameters: 2.1m
config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha": 0.75,
3
+ "crop_scale": 1.0,
4
+ "dim_feedforward": 512,
5
+ "dim_model": 128,
6
+ "dropout": 0.1,
7
+ "inference_horizon": 20,
8
+ "input_normalization_modes": {
9
+ "observation.environment_state": "min_max",
10
+ "observation.state": "min_max"
11
+ },
12
+ "input_shapes": {
13
+ "observation.environment_state": [
14
+ 16
15
+ ],
16
+ "observation.state": [
17
+ 2
18
+ ]
19
+ },
20
+ "lookback_aug": 5,
21
+ "lookback_obs_steps": 10,
22
+ "lora_rank": 20,
23
+ "merge_lora": true,
24
+ "n_decoder_layers": 8,
25
+ "n_heads": 8,
26
+ "n_obs_steps": 3,
27
+ "noise_decay": 0.999995,
28
+ "output_normalization_modes": {
29
+ "action": "min_max"
30
+ },
31
+ "output_shapes": {
32
+ "action": [
33
+ 2
34
+ ]
35
+ },
36
+ "pre_norm": true,
37
+ "predict_every_n": 1,
38
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
39
+ "rescale_shape": [
40
+ 96,
41
+ 96
42
+ ],
43
+ "return_every_n": 2,
44
+ "state_noise": 0.01,
45
+ "train_alpha": 0.9,
46
+ "train_horizon": 20,
47
+ "vision_backbone": "resnet18"
48
+ }
config.yaml ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ resume: false
2
+ device: cuda
3
+ use_amp: true
4
+ seed: 100000
5
+ dataset_repo_id: lerobot/pusht_keypoints
6
+ video_backend: pyav
7
+ training:
8
+ offline_steps: 1000000
9
+ num_workers: 24
10
+ batch_size: 24
11
+ eval_freq: 10000
12
+ log_freq: 1000
13
+ save_checkpoint: true
14
+ save_freq: 50000
15
+ online_steps: 0
16
+ online_rollout_n_episodes: 1
17
+ online_rollout_batch_size: 1
18
+ online_steps_between_rollouts: 1
19
+ online_sampling_ratio: 0.5
20
+ online_env_seed: null
21
+ online_buffer_capacity: null
22
+ online_buffer_seed_size: 0
23
+ do_online_rollout_async: false
24
+ image_transforms:
25
+ enable: false
26
+ max_num_transforms: 3
27
+ random_order: false
28
+ brightness:
29
+ weight: 1
30
+ min_max:
31
+ - 0.8
32
+ - 1.2
33
+ contrast:
34
+ weight: 1
35
+ min_max:
36
+ - 0.8
37
+ - 1.2
38
+ saturation:
39
+ weight: 1
40
+ min_max:
41
+ - 0.5
42
+ - 1.5
43
+ hue:
44
+ weight: 1
45
+ min_max:
46
+ - -0.05
47
+ - 0.05
48
+ sharpness:
49
+ weight: 1
50
+ min_max:
51
+ - 0.8
52
+ - 1.2
53
+ save_model: true
54
+ grad_clip_norm: 50
55
+ lr: 0.0001
56
+ min_lr: 0.0001
57
+ lr_cycle_steps: 300000
58
+ weight_decay: 1.0e-05
59
+ delta_timestamps:
60
+ observation.environment_state:
61
+ - -1.5
62
+ - -1.4
63
+ - -1.3
64
+ - -1.2
65
+ - -1.1
66
+ - -1.0
67
+ - -0.9
68
+ - -0.8
69
+ - -0.7
70
+ - -0.6
71
+ - -0.5
72
+ - -0.1
73
+ - 0.0
74
+ observation.state:
75
+ - -1.5
76
+ - -1.4
77
+ - -1.3
78
+ - -1.2
79
+ - -1.1
80
+ - -1.0
81
+ - -0.9
82
+ - -0.8
83
+ - -0.7
84
+ - -0.6
85
+ - -0.5
86
+ - -0.1
87
+ - 0.0
88
+ action:
89
+ - -1.5
90
+ - -1.4
91
+ - -1.3
92
+ - -1.2
93
+ - -1.1
94
+ - -1.0
95
+ - -0.9
96
+ - -0.8
97
+ - -0.7
98
+ - -0.6
99
+ - -0.5
100
+ - -0.1
101
+ - 0.0
102
+ - 0.1
103
+ - 0.2
104
+ - 0.3
105
+ - 0.4
106
+ - 0.5
107
+ - 0.6
108
+ - 0.7
109
+ - 0.8
110
+ - 0.9
111
+ - 1.0
112
+ - 1.1
113
+ - 1.2
114
+ - 1.3
115
+ - 1.4
116
+ - 1.5
117
+ - 1.6
118
+ - 1.7
119
+ - 1.8
120
+ - 1.9
121
+ eval:
122
+ n_episodes: 100
123
+ batch_size: 100
124
+ use_async_envs: false
125
+ wandb:
126
+ enable: true
127
+ disable_artifact: false
128
+ project: pusht
129
+ notes: ''
130
+ fps: 10
131
+ env:
132
+ name: pusht
133
+ task: PushT-v0
134
+ image_size: 96
135
+ state_dim: 2
136
+ action_dim: 2
137
+ fps: ${fps}
138
+ episode_length: 300
139
+ gym:
140
+ obs_type: environment_state_agent_pos
141
+ render_mode: rgb_array
142
+ visualization_width: 384
143
+ visualization_height: 384
144
+ override_dataset_stats:
145
+ observation.environment_state:
146
+ min:
147
+ - 0.0
148
+ - 0.0
149
+ - 0.0
150
+ - 0.0
151
+ - 0.0
152
+ - 0.0
153
+ - 0.0
154
+ - 0.0
155
+ - 0.0
156
+ - 0.0
157
+ - 0.0
158
+ - 0.0
159
+ - 0.0
160
+ - 0.0
161
+ - 0.0
162
+ - 0.0
163
+ max:
164
+ - 512.0
165
+ - 512.0
166
+ - 512.0
167
+ - 512.0
168
+ - 512.0
169
+ - 512.0
170
+ - 512.0
171
+ - 512.0
172
+ - 512.0
173
+ - 512.0
174
+ - 512.0
175
+ - 512.0
176
+ - 512.0
177
+ - 512.0
178
+ - 512.0
179
+ - 512.0
180
+ observation.state:
181
+ min:
182
+ - 0.0
183
+ - 0.0
184
+ max:
185
+ - 512.0
186
+ - 512.0
187
+ action:
188
+ min:
189
+ - 0.0
190
+ - 0.0
191
+ max:
192
+ - 512.0
193
+ - 512.0
194
+ policy:
195
+ name: dot
196
+ n_obs_steps: 3
197
+ train_horizon: 20
198
+ inference_horizon: 20
199
+ lookback_obs_steps: 10
200
+ lookback_aug: 5
201
+ input_shapes:
202
+ observation.environment_state:
203
+ - 16
204
+ observation.state:
205
+ - ${env.state_dim}
206
+ output_shapes:
207
+ action:
208
+ - ${env.action_dim}
209
+ input_normalization_modes:
210
+ observation.environment_state: min_max
211
+ observation.state: min_max
212
+ output_normalization_modes:
213
+ action: min_max
214
+ state_noise: 0.01
215
+ noise_decay: 0.999995
216
+ pre_norm: true
217
+ dim_model: 128
218
+ n_heads: 8
219
+ dim_feedforward: 512
220
+ n_decoder_layers: 8
221
+ dropout: 0.1
222
+ alpha: 0.75
223
+ train_alpha: 0.9
224
+ predict_every_n: 1
225
+ return_every_n: 2
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b45aaac6d363fb26f405462dd901b45d1b436b686c163c9b4d2b71085bdc1aa5
3
+ size 8523444
replay.mp4 ADDED
Binary file (58.5 kB). View file