qgallouedec committed
Commit c6481b5
Parent: a788d9c

Upload model
config.json ADDED
@@ -0,0 +1,61 @@
+ {
+   "_name_or_path": "checkpoints/jat_small_v100/checkpoint-250000",
+   "action_loss_coef": 0.995,
+   "activation_function": "gelu_new",
+   "architectures": [
+     "JatModel"
+   ],
+   "attention_dropout": 0.0,
+   "attention_layers": [
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local"
+   ],
+   "attention_types": [
+     [
+       [
+         "global",
+         "local"
+       ],
+       6
+     ]
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_jat.JatConfig",
+     "AutoModelForCausalLM": "modeling_jat.JatModel"
+   },
+   "bos_token_id": 50256,
+   "classifier_dropout": 0.1,
+   "embed_dropout": 0.0,
+   "eos_token_id": 50256,
+   "hidden_size": 768,
+   "image_size": 224,
+   "initializer_range": 0.02,
+   "intermediate_size": null,
+   "layer_norm_epsilon": 1e-05,
+   "max_continuous_size": 377,
+   "max_discrete_value": 212,
+   "max_position_embeddings": 512,
+   "model_type": "jat",
+   "num_channels": 3,
+   "num_heads": 12,
+   "num_layers": 12,
+   "observation_loss_coef": 0.005,
+   "patch_size": 16,
+   "resid_dropout": 0.0,
+   "tokenizer_class": "GPT2TokenizerFast",
+   "torch_dtype": "float32",
+   "transformers_version": "4.36.1",
+   "use_cache": true,
+   "vocab_size": 50257,
+   "window_size": 256
+ }
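Because `auto_map` points the auto classes at the custom `configuration_jat.JatConfig` and `modeling_jat.JatModel` files below, loading this checkpoint goes through `trust_remote_code`. A minimal loading sketch; the repository id is a placeholder, not something stated in this commit:

# Minimal loading sketch (repo_id is a placeholder for wherever these files are hosted).
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "<user>/<jat-repo>"
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # resolves configuration_jat.JatConfig
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)  # resolves modeling_jat.JatModel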
configuration_jat.py ADDED
@@ -0,0 +1,134 @@
+ from transformers import GPTNeoConfig
+ 
+ 
+ class JatConfig(GPTNeoConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`JatModel`]. It is used to instantiate a Jat
+     model according to the specified arguments, defining the model architecture. Instantiating a configuration with
+     the defaults will yield a similar configuration to that of the ... (TODO)
+ 
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+ 
+ 
+     Args:
+         vocab_size (`int`, *optional*, defaults to 50257):
+             Vocabulary size of the model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`JatModel`].
+         max_position_embeddings (`int`, *optional*, defaults to 2048):
+             The maximum sequence length that this model might ever be used with. Typically set this to something large
+             just in case (e.g., 512 or 1024 or 2048).
+         hidden_size (`int`, *optional*, defaults to 2048):
+             Dimensionality of the encoder layers and the pooler layer.
+         num_layers (`int`, *optional*, defaults to 24):
+             Number of hidden layers in the Transformer encoder.
+         attention_types (`List`, *optional*, defaults to `[[["global", "local"], 12]]`):
+             The type of attention for each layer in a `List` of the following format `[[["attention_type"],
+             num_layers]]`, e.g. for a 24-layer model `[[["global"], 24]]` or `[[["global", "local"], 12]]`. Choose the
+             value of `attention_type` from `["global", "local"]`.
+         num_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         intermediate_size (`int`, *optional*, defaults to `None`):
+             Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+         window_size (`int`, *optional*, defaults to 256):
+             The size of the sliding window for local attention.
+         activation_function (`str` or `function`, *optional*, defaults to `"gelu_new"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+             `"relu"`, `"selu"` and `"gelu_new"` are supported.
+         resid_dropout (`float`, *optional*, defaults to 0.0):
+             Residual dropout used in the attention pattern.
+         embed_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         classifier_dropout (`float`, *optional*, defaults to 0.1):
+             Argument used when doing token classification, used in the model [`GPTNeoForTokenClassification`]. The
+             dropout ratio for the hidden layer.
+         layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+             The epsilon used by the layer normalization layers.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         bos_token_id (`int`, *optional*, defaults to 50256):
+             The id of the beginning of sentence token in the vocabulary.
+         eos_token_id (`int`, *optional*, defaults to 50256):
+             The id of the end of sentence token in the vocabulary.
+         max_continuous_size (`int`, *optional*, defaults to 377):
+             The maximum size of the continuous values.
+         max_discrete_value (`int`, *optional*, defaults to 18):
+             The maximum value of the discrete values.
+         image_size (`int`, *optional*, defaults to 224):
+             The size (resolution) of each image.
+         num_channels (`int`, *optional*, defaults to 3):
+             The number of channels in each image.
+         patch_size (`int`, *optional*, defaults to 16):
+             The size (resolution) of each patch.
+         observation_loss_coef (`float`, *optional*, defaults to 0.005):
+             The coefficient for the observation loss. When set to 0.0, observations are not predicted at all.
+         action_loss_coef (`float`, *optional*, defaults to 0.995):
+             The coefficient for the action loss.
+     """
+ 
+     model_type = "jat"
+ 
+     def __init__(
+         self,
+         vocab_size=50257,
+         max_position_embeddings=2048,
+         hidden_size=2048,
+         num_layers=24,
+         attention_types=[[["global", "local"], 12]],
+         num_heads=16,
+         intermediate_size=None,
+         window_size=256,
+         activation_function="gelu_new",
+         resid_dropout=0.0,
+         embed_dropout=0.0,
+         attention_dropout=0.0,
+         classifier_dropout=0.1,
+         layer_norm_epsilon=1e-5,
+         initializer_range=0.02,
+         use_cache=True,
+         bos_token_id=50256,
+         eos_token_id=50256,
+         max_continuous_size=377,
+         max_discrete_value=18,
+         image_size=224,
+         num_channels=3,
+         patch_size=16,
+         observation_loss_coef=0.005,
+         action_loss_coef=0.995,
+         **kwargs,
+     ):
+         super().__init__(
+             vocab_size,
+             max_position_embeddings,
+             hidden_size,
+             num_layers,
+             attention_types,
+             num_heads,
+             intermediate_size,
+             window_size,
+             activation_function,
+             resid_dropout,
+             embed_dropout,
+             attention_dropout,
+             classifier_dropout,
+             layer_norm_epsilon,
+             initializer_range,
+             use_cache,
+             bos_token_id,
+             eos_token_id,
+             **kwargs,
+         )
+         self.max_continuous_size = max_continuous_size
+         self.max_discrete_value = max_discrete_value
+         self.image_size = image_size
+         self.num_channels = num_channels
+         self.patch_size = patch_size
+         self.observation_loss_coef = observation_loss_coef
+         self.action_loss_coef = action_loss_coef
+ 
+ 
+ JatConfig.register_for_auto_class()
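Since `JatConfig` only adds the multimodal and RL fields on top of `GPTNeoConfig`, an equivalent configuration can also be built in code. A sketch mirroring a subset of the values in the uploaded config.json:

# Sketch: rebuild the uploaded configuration programmatically (values taken from config.json above).
config = JatConfig(
    hidden_size=768,
    num_layers=12,
    num_heads=12,
    attention_types=[[["global", "local"], 6]],
    max_position_embeddings=512,
    max_continuous_size=377,
    max_discrete_value=212,
    image_size=224,
    num_channels=3,
    patch_size=16,
)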
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 50256,
+   "eos_token_id": 50256,
+   "transformers_version": "4.36.1"
+ }
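The generation config only pins the GPT-2 style `bos_token_id`/`eos_token_id`; together with `prepare_inputs_for_generation` in modeling_jat.py it enables plain `generate()` calls for the textual side of the model. A sketch, assuming the repo ships a tokenizer matching `tokenizer_class: GPT2TokenizerFast` from config.json:

# Text-only generation sketch (repo_id and model come from the loading sketch above).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(repo_id)
inputs = tokenizer("The quick brown fox", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))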
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b66502d0c5687998593d89582dc18697d3d144871b0905c345126e632c22508
+ size 770828444
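The entry above is only a Git LFS pointer; the actual ~770 MB safetensors file is fetched automatically by `from_pretrained`, or explicitly with `huggingface_hub` (repo id again a placeholder):

# Sketch: download the real weights behind the LFS pointer.
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(repo_id="<user>/<jat-repo>", filename="model.safetensors")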
modeling_jat.py ADDED
@@ -0,0 +1,836 @@
1
+ import warnings
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from gymnasium import spaces
9
+ from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
10
+ from transformers import GPTNeoModel, GPTNeoPreTrainedModel
11
+ from transformers.modeling_outputs import ModelOutput
12
+ from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
13
+
14
+ from .configuration_jat import JatConfig
15
+ from .processing_jat import JatProcessor
16
+
17
+
18
+ def compute_mse_loss(
19
+ predicted: FloatTensor, true: FloatTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
20
+ ) -> FloatTensor:
21
+ """
22
+ Compute the Mean Squared Error (MSE) loss between predicted and true observations, considering valid timesteps.
23
+
24
+ Args:
25
+ predicted (`FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
26
+ Predicted observations at the output of the model.
27
+ true (`FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
28
+ Ground truth observations.
29
+ mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
30
+ Boolean mask indicating valid timesteps.
31
+ weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
32
+ Weights to be applied to the loss.
33
+
34
+ Returns:
35
+ loss (`FloatTensor` of shape `(,)`):
36
+ MSE loss between predicted and true observations.
37
+ """
38
+ # Compute element-wise MSE loss
39
+ loss = F.mse_loss(predicted, true, reduction="none")
40
+
41
+ # Average the loss over all dimensions after the second one
42
+ for dim in reversed(range(2, loss.dim())):
43
+ loss = loss.mean(dim=dim)
44
+
45
+ # Use the mask to zero out invalid entries
46
+ if mask is not None:
47
+ loss = loss * mask
48
+
49
+ # Apply weights if provided
50
+ if weights is not None:
51
+ loss = loss * weights
52
+
53
+ # Sum the loss and normalize by the number of valid elements
54
+ loss = loss.sum() / mask.sum() if mask is not None else loss.mean()
55
+
56
+ return loss
57
+
58
+
59
+ def compute_ce_loss(
60
+ logits: FloatTensor, labels: torch.LongTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
61
+ ) -> FloatTensor:
62
+ """
63
+ Compute the Cross Entropy (CE) loss between predicted logits and true class labels, considering valid timesteps.
64
+
65
+ Args:
66
+ logits (`FloatTensor` of shape `(batch_size, max_seq_len, [inner_size,] num_classes)`):
67
+ Predicted logits at the output of the model.
68
+ labels (`torch.LongTensor` of shape `(batch_size, max_seq_len, [inner_size,])`):
69
+ Ground truth class labels.
70
+ mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
71
+ Boolean mask indicating valid timesteps.
72
+ weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
73
+ Weights to be applied to the loss.
74
+
75
+ Returns:
76
+ loss (`FloatTensor` of shape `(,)`):
77
+ CE loss between predicted logits and true class labels.
78
+ """
79
+ if mask is not None:
80
+ logits = logits[mask.bool()] # (Y, X, C)
81
+ labels = labels[mask.bool()] # (Y, X)
82
+ if weights is not None:
83
+ weights = weights[mask.bool()] # (Y,)
84
+ else:
85
+ logits = logits.flatten(end_dim=2) # (B, L, X, C) -> (B*L, X, C)
86
+ labels = labels.flatten(end_dim=1) # (B, L, X) -> (B*L, X)
87
+ if weights is not None:
88
+ weights = weights.flatten(end_dim=1) # (B, L) -> (B*L,)
89
+
90
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), reduction="none") # (Y*X,)
91
+ loss = loss.view(labels.size()) # (Y, X)
92
+ loss = loss.mean(-1) # (Y,)
93
+
94
+ # Multiply the loss by the weights
95
+ if weights is not None:
96
+ loss = loss * weights # (Y,)
97
+
98
+ # Average the loss
99
+ loss = loss.mean()
100
+
101
+ return loss
102
+
103
+
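# Illustrative sketch for the masked losses above (shapes and values are invented for the
# example, not taken from training data): with zero predictions, unit targets and three
# valid timesteps, compute_mse_loss averages the error over the unmasked entries only.
#
# >>> predicted = torch.zeros(2, 3, 4)
# >>> true = torch.ones(2, 3, 4)
# >>> mask = torch.tensor([[1, 1, 0], [1, 0, 0]], dtype=torch.bool)
# >>> compute_mse_loss(predicted, true, mask)
# tensor(1.)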
104
+ def cyclic_expand_dim(tensor: Tensor, expanded_dim_size: int) -> Tensor:
105
+ """
106
+ Expands the last dimension of a tensor cyclically to a specified size.
107
+
108
+ Args:
109
+ tensor (`torch.Tensor` of shape `(batch_size, seq_len, ...)`):
110
+ Input tensor whose last dimension is to be expanded cyclically.
111
+ expanded_dim_size (`int`):
112
+ The desired size of the last dimension after expansion.
113
+
114
+ Returns:
115
+ `torch.Tensor` of shape `(batch_size, seq_len, expanded_dim_size)`:
116
+ A tensor with its last dimension expanded cyclically to the specified size.
117
+
118
+ Examples:
119
+ >>> tensor = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
120
+ >>> cyclic_expand_dim(tensor, 5)
121
+ tensor([[[1, 2, 1, 2, 1], [3, 4, 3, 4, 3]], [[5, 6, 5, 6, 5], [7, 8, 7, 8, 7]]])
122
+ """
123
+ B, L, X = tensor.shape
124
+ if expanded_dim_size < X:
125
+ raise ValueError(
126
+ f"Expanded dimension size ({expanded_dim_size}) must be greater than the original dimension size ({X})."
127
+ )
128
+ indices = torch.arange(expanded_dim_size) % X
129
+ return tensor[..., indices]
130
+
131
+
132
+ class ResidualBlock(nn.Module):
133
+ """
134
+ A residual block module that consists of two convolutional layers with a residual connection.
135
+
136
+ Args:
137
+ in_shape (`Tuple[int, int, int]`):
138
+ Shape of the input tensor.
139
+ out_channels (`int`):
140
+ Number of output channels.
141
+
142
+ Returns:
143
+ `torch.Tensor` of shape `(batch_size, out_channels, in_shape[1], in_shape[2])`:
144
+ Output tensor.
145
+ """
146
+
147
+ def __init__(self, in_shape: Tuple[int, int, int], out_channels: int) -> None:
148
+ super().__init__()
149
+ out_shape = (out_channels, in_shape[1], in_shape[2])
150
+
151
+ self.conv1 = nn.Conv2d(in_shape[0], out_channels, kernel_size=3, stride=1, padding=1)
152
+ self.norm1 = nn.LayerNorm(out_shape)
153
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
154
+ self.norm2 = nn.LayerNorm(out_shape)
155
+
156
+ # Handling the change in dimensions with a 1x1 convolution
157
+ self.shortcut = nn.Sequential(
158
+ nn.Conv2d(in_shape[0], out_channels, kernel_size=1, stride=1), nn.LayerNorm(out_shape)
159
+ )
160
+
161
+ def forward(self, x: FloatTensor) -> FloatTensor:
162
+ out = F.leaky_relu(self.norm1(self.conv1(x)))
163
+ out = self.norm2(self.conv2(out))
164
+ out += self.shortcut(x)
165
+ return F.leaky_relu(out, inplace=True)
166
+
167
+
168
+ class AttentionLayer(nn.Module):
169
+ """
170
+ Attention layer that applies an attention mechanism to the input tensor.
171
+
172
+ Args:
173
+ num_channels (`int`):
174
+ Number of channels.
175
+
176
+ Returns:
177
+ `torch.Tensor`:
178
+ Output tensor of the same shape as the input tensor.
179
+ """
180
+
181
+ def __init__(self, num_channels: int) -> None:
182
+ super().__init__()
183
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
184
+ self.fc = nn.Sequential(
185
+ nn.Linear(num_channels, num_channels // 8, bias=False),
186
+ nn.ReLU(inplace=True),
187
+ nn.Linear(num_channels // 8, num_channels, bias=False),
188
+ nn.Sigmoid(),
189
+ )
190
+
191
+ def forward(self, x: FloatTensor) -> FloatTensor:
192
+ b, c, _, _ = x.size()
193
+ y = self.avg_pool(x).view(b, c)
194
+ y = self.fc(y).view(b, c, 1, 1)
195
+ return x * y.expand_as(x)
196
+
197
+
198
+ class ImageEncoder(nn.Module):
199
+ """
200
+ Image encoder that encodes a batch of images.
201
+
202
+ Args:
203
+ hidden_size (`int`):
204
+ Size of the output hidden state.
205
+
206
+ Returns:
207
+ `torch.Tensor` of shape `(batch_size, hidden_size)`:
208
+ Output tensor.
209
+ """
210
+
211
+ def __init__(self, hidden_size: int) -> None:
212
+ super().__init__()
213
+ self.conv1 = nn.Conv2d(4, 32, kernel_size=3, stride=2, padding=1) # 42x42
214
+ self.norm1 = nn.InstanceNorm2d(32)
215
+ self.att1 = AttentionLayer(32)
216
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) # 21x21
217
+ self.norm2 = nn.InstanceNorm2d(64)
218
+ self.att2 = AttentionLayer(64)
219
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) # 11x11
220
+ self.norm3 = nn.InstanceNorm2d(128)
221
+ self.att3 = AttentionLayer(128)
222
+ self.fc = nn.Linear(128 * 11 * 11, hidden_size) # Adjusted to the new spatial dimension
223
+
224
+ def forward(self, x: FloatTensor) -> FloatTensor:
225
+ x = F.leaky_relu(self.norm1(self.conv1(x)), inplace=True)
226
+ x = self.att1(x)
227
+ x = F.leaky_relu(self.norm2(self.conv2(x)), inplace=True)
228
+ x = self.att2(x)
229
+ x = F.leaky_relu(self.norm3(self.conv3(x)), inplace=True)
230
+ x = self.att3(x)
231
+ x = x.view(x.size(0), -1) # Flatten the tensor
232
+ x = self.fc(x)
233
+ return x
234
+
235
+
236
+ class ImageDecoder(nn.Module):
237
+ """
238
+ Image decoder that decodes a batch of encoded representations.
239
+
240
+ Args:
241
+ hidden_size (`int`):
242
+ Size of the input hidden state.
243
+
244
+ Returns:
245
+ `torch.Tensor` of shape `(batch_size, 4, 84, 84)`:
246
+ Output tensor representing the reconstructed images.
247
+ """
248
+
249
+ def __init__(self, hidden_size: int) -> None:
250
+ super().__init__()
251
+ self.fc = nn.Linear(hidden_size, 128 * 11 * 11)
252
+ self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1) # 21x21
253
+ self.norm1 = nn.InstanceNorm2d(64)
254
+ self.att1 = AttentionLayer(64)
255
+ self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1) # 42x42
256
+ self.norm2 = nn.InstanceNorm2d(32)
257
+ self.att2 = AttentionLayer(32)
258
+ self.deconv3 = nn.ConvTranspose2d(32, 4, kernel_size=3, stride=2, padding=1, output_padding=1) # 84x84
259
+
260
+ def forward(self, x: FloatTensor) -> FloatTensor:
261
+ x = self.fc(x)
262
+ x = x.view(x.size(0), 128, 11, 11) # Reshape to the spatial dimension of encoder's last conv layer
263
+ x = F.leaky_relu(self.norm1(self.deconv1(x)), inplace=True) # 22x22
264
+ x = F.interpolate(x, size=(21, 21)) # 21x21
265
+ x = self.att1(x)
266
+ x = F.leaky_relu(self.norm2(self.deconv2(x)), inplace=True)
267
+ x = self.att2(x)
268
+ x = torch.tanh(self.deconv3(x))
269
+ return x
270
+
271
+
272
+ class DualBatchReshapeWrapper(nn.Module):
273
+ """
274
+ Wrapper to make a module designed for a single batch work with a dual batch.
275
+
276
+ Args:
277
+ module (`nn.Module`):
278
+ Module to be wrapped.
279
+ """
280
+
281
+ def __init__(self, module: nn.Module) -> None:
282
+ super().__init__()
283
+ self.module = module
284
+
285
+ def forward(self, x: FloatTensor) -> FloatTensor:
286
+ n1, n2 = x.shape[:2]
287
+ x = x.view(n1 * n2, *x.shape[2:])
288
+ x = self.module(x)
289
+ x = x.view(n1, n2, *x.shape[1:])
290
+ return x
291
+
292
+
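# Illustrative sketch for DualBatchReshapeWrapper (shapes are arbitrary): the leading
# (batch, sequence) dimensions are flattened before calling the wrapped module and
# restored afterwards.
#
# >>> wrapper = DualBatchReshapeWrapper(nn.Flatten(start_dim=1))
# >>> wrapper(torch.zeros(2, 3, 4, 5)).shape
# torch.Size([2, 3, 20])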
293
+ @dataclass
294
+ class JatOutput(ModelOutput):
295
+ """
296
+ Output of the Jat model.
297
+
298
+ The model can be used for both RL and NLP tasks. For RL tasks, the model takes in observations and actions
299
+ (`continuous_observations`, `discrete_actions`, etc.). For textual tasks, the model takes in a sequence of tokens
300
+ and/or images (`input_ids`, `image`). The output depends on the type of input.
301
+
302
+ Args:
303
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
304
+ For RL input, the loss is the weighted sum of the observation loss and the action loss, using
+ `observation_loss_coef` and `action_loss_coef`.
305
+ For textual input, the causal language modeling loss.
306
+ observation_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
307
+ Only returned when RL input is provided. The MSE loss between predicted and true observations for
308
+ continuous observations and the cross-entropy loss for discrete observations.
309
+ action_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
310
+ Only returned when RL input is provided. The MSE loss between predicted and true actions for
311
+ continuous actions and the cross-entropy loss for discrete actions.
312
+ pred_observations (`torch.FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
313
+ Only returned when RL input is provided. Predicted observations from t=1 to t=max_seq_len+1.
314
+ pred_actions (`torch.FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
315
+ Only returned when RL input is provided. Predicted actions from t=0 to t=max_seq_len. When input actions
316
+ are discrete, the predicted actions are logits.
317
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
318
+ Sequence of hidden-states at the output of the last layer of the model.
319
+
320
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
321
+ hidden_size)` is output.
322
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
323
+ when `config.use_cache=True`):
324
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
325
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
326
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
327
+ encoder_sequence_length, embed_size_per_head)`.
328
+
329
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
330
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
331
+ input) to speed up sequential decoding.
332
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
333
+ when `config.output_hidden_states=True`):
334
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
335
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
336
+
337
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
338
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
339
+ `config.output_attentions=True`):
340
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
341
+ sequence_length)`.
342
+
343
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
344
+ heads.
345
+ """
346
+
347
+ loss: Optional[FloatTensor] = None
348
+ observation_loss: Optional[FloatTensor] = None
349
+ action_loss: Optional[FloatTensor] = None
350
+ pred_observations: Optional[FloatTensor] = None
351
+ pred_actions: Optional[FloatTensor] = None
352
+ logits: Optional[FloatTensor] = None
353
+ past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None
354
+ hidden_states: Optional[Tuple[FloatTensor]] = None
355
+ attentions: Optional[Tuple[FloatTensor]] = None
356
+
357
+
358
+ class JatModel(GPTNeoPreTrainedModel):
359
+ """
360
+ Jat model.
361
+ """
362
+
363
+ config_class = JatConfig
364
+
365
+ def __init__(self, config: JatConfig) -> None:
366
+ super().__init__(config)
367
+
368
+ vocab_size = config.vocab_size
369
+ hidden_size = config.hidden_size
370
+ max_discrete_value = config.max_discrete_value
371
+ max_continuous_size = config.max_continuous_size
372
+ self.observation_loss_coef = config.observation_loss_coef
373
+ self.action_loss_coef = config.action_loss_coef
374
+
375
+ # Transformer
376
+ self.transformer = GPTNeoModel(config)
377
+
378
+ # Encoders
379
+ self.vit_encoder = ViTPatchEmbeddings(config)
380
+ self.single_discrete_encoder = self.transformer.wte
381
+ self.continuous_encoder = nn.Linear(max_continuous_size, hidden_size)
382
+ self.multi_discrete_encoder = nn.Sequential(
383
+ self.single_discrete_encoder, # (B, L, X, H)
384
+ nn.Linear(hidden_size, hidden_size // 50), # (B, L, X, H // 50)
385
+ nn.ReLU(),
386
+ nn.Flatten(start_dim=2), # (B, L, X * (H // 50))
387
+ nn.Linear(max_discrete_value * (hidden_size // 50), hidden_size - 1), # (B, L, H)
388
+ ) # -1 to account for the reward
389
+ self.image_encoder = DualBatchReshapeWrapper(ImageEncoder(hidden_size))
390
+
391
+ # Decoders
392
+ self.single_discrete_decoder = nn.Linear(hidden_size, vocab_size, bias=False)
393
+ self.continuous_decoder = nn.Linear(hidden_size, max_continuous_size)
394
+ self.multi_discrete_decoder = nn.Sequential(
395
+ nn.Linear(hidden_size, max_discrete_value * (hidden_size // 50)), # (B, L, X * (H // 50))
396
+ nn.Unflatten(dim=2, unflattened_size=(max_discrete_value, hidden_size // 50)), # (B, L, X, H // 50)
397
+ nn.ReLU(),
398
+ nn.Linear(hidden_size // 50, hidden_size), # (B, L, X, H)
399
+ nn.ReLU(),
400
+ nn.Linear(hidden_size, 8, bias=False), # (B, L, X, 8) - the max possible value in the dataset is 8
401
+ )
402
+ self.image_decoder = DualBatchReshapeWrapper(ImageDecoder(hidden_size))
403
+
404
+ # Initialize weights and apply final processing
405
+ self.post_init()
406
+
407
+ def embed_textual(
408
+ self,
409
+ input_ids: Optional[LongTensor],
410
+ pixel_values: Optional[FloatTensor] = None,
411
+ attention_mask: Optional[BoolTensor] = None,
412
+ ) -> Tensor:
413
+ text_inputs_embeds = self.single_discrete_encoder(input_ids) if input_ids is not None else None
414
+ image_inputs_embeds = self.vit_encoder(pixel_values) if pixel_values is not None else None
415
+ # Concatenate text and image inputs
416
+ if image_inputs_embeds is not None and text_inputs_embeds is not None:
417
+ inputs_embeds = torch.cat((image_inputs_embeds, text_inputs_embeds), dim=1)
418
+ # Add attention mask for image inputs
419
+ image_mask = torch.ones(image_inputs_embeds.shape[:2], dtype=torch.bool, device=self.device)
420
+ if attention_mask is None:
421
+ attention_mask = torch.ones(text_inputs_embeds.shape[:2], dtype=torch.bool, device=self.device)
422
+ attention_mask = torch.cat((image_mask, attention_mask), dim=1)
423
+ elif image_inputs_embeds is not None:
424
+ inputs_embeds = image_inputs_embeds
425
+ elif text_inputs_embeds is not None:
426
+ inputs_embeds = text_inputs_embeds
427
+ attention_mask = attention_mask
428
+ else:
429
+ raise ValueError("At least one of `input_ids` or `pixel_values` must be provided.")
430
+ return inputs_embeds, attention_mask
431
+
432
+ def embed_rl(
433
+ self,
434
+ continuous_observations: Optional[FloatTensor] = None,
435
+ discrete_observations: Optional[LongTensor] = None,
436
+ image_observations: Optional[FloatTensor] = None,
437
+ continuous_actions: Optional[FloatTensor] = None,
438
+ discrete_actions: Optional[LongTensor] = None,
439
+ rewards: Optional[FloatTensor] = None,
440
+ attention_mask: Optional[BoolTensor] = None,
441
+ ):
442
+ # Prepare RL inputs (pad and cat rewards to observations)
443
+ assert rewards is not None
444
+ if continuous_observations is not None:
445
+ continuous_observations = torch.cat((continuous_observations, rewards.unsqueeze(-1)), dim=-1)
446
+ continuous_observations = cyclic_expand_dim(continuous_observations, self.config.max_continuous_size)
447
+ if continuous_actions is not None:
448
+ continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)
449
+
450
+ # Encode
451
+ if continuous_observations is not None:
452
+ batch_size, seq_len = continuous_observations.shape[:2]
453
+ inputs_embeds_observations = self.continuous_encoder(continuous_observations)
454
+ elif discrete_observations is not None:
455
+ batch_size, seq_len = discrete_observations.shape[:2]
456
+ inputs_embeds_observations = self.multi_discrete_encoder(discrete_observations)
457
+ inputs_embeds_observations = torch.cat((inputs_embeds_observations, rewards.unsqueeze(-1)), dim=-1)
458
+ elif image_observations is not None:
459
+ batch_size, seq_len = image_observations.shape[:2]
460
+ inputs_embeds_observations = self.image_encoder(image_observations)
461
+ else:
462
+ raise ValueError("Missing observations.")
463
+ if continuous_actions is not None:
464
+ inputs_embeds_actions = self.continuous_encoder(continuous_actions)
465
+ elif discrete_actions is not None:
466
+ inputs_embeds_actions = self.single_discrete_encoder(discrete_actions)
467
+ else:
468
+ raise ValueError("Missing actions.")
469
+
470
+ # Concatenate observations and actions
471
+ inputs_embeds = torch.cat((inputs_embeds_observations, inputs_embeds_actions), dim=2)
472
+ inputs_embeds = inputs_embeds.view(batch_size, 2 * seq_len, self.config.hidden_size)
473
+ if attention_mask is not None:
474
+ attention_mask = torch.repeat_interleave(attention_mask, repeats=2, dim=1)
475
+ return inputs_embeds, attention_mask
476
+
477
+ def output_textual(
478
+ self,
479
+ transformer_outputs,
480
+ input_ids: Optional[LongTensor] = None,
481
+ attention_mask: Optional[BoolTensor] = None,
482
+ return_loss: bool = True,
483
+ return_dict: Optional[bool] = None,
484
+ ):
485
+ hidden_states = transformer_outputs[0]
486
+ loss = None
487
+ # Get only textual hidden states
488
+ lm_logits = self.single_discrete_decoder(hidden_states)
489
+ if return_loss:
490
+ if input_ids is None:
491
+ raise ValueError("Input IDs must be provided when `return_loss=True`.")
492
+
493
+ # Shift so that tokens < n predict n
494
+ num_text_tokens = input_ids.shape[1]
495
+ shift_logits = lm_logits[:, -num_text_tokens:-1, :].contiguous()
496
+ shift_labels = input_ids[:, 1:].contiguous()
497
+ if attention_mask is not None:
498
+ shift_attention_mask = attention_mask[:, -num_text_tokens:]
499
+ shift_attention_mask = shift_attention_mask[:, 1:]
500
+ else:
501
+ shift_attention_mask = torch.ones(shift_labels.shape, dtype=bool, device=self.device)
502
+ shift_logits = shift_logits[shift_attention_mask.bool()]
503
+ shift_labels = shift_labels[shift_attention_mask.bool()]
504
+ loss_fct = nn.CrossEntropyLoss()
505
+ loss = loss_fct(shift_logits, shift_labels)
506
+
507
+ if not return_dict:
508
+ output = (lm_logits,) + transformer_outputs[1:]
509
+ return ((loss,) + output) if loss is not None else output
510
+
511
+ return JatOutput(
512
+ loss=loss,
513
+ logits=lm_logits,
514
+ past_key_values=transformer_outputs.past_key_values,
515
+ hidden_states=transformer_outputs.hidden_states,
516
+ attentions=transformer_outputs.attentions,
517
+ )
518
+
519
+ def output_rl(
520
+ self,
521
+ transformer_outputs,
522
+ continuous_observations: Optional[FloatTensor] = None,
523
+ discrete_observations: Optional[LongTensor] = None,
524
+ image_observations: Optional[FloatTensor] = None,
525
+ continuous_actions: Optional[FloatTensor] = None,
526
+ discrete_actions: Optional[LongTensor] = None,
527
+ rewards: Optional[FloatTensor] = None,
528
+ attention_mask: Optional[BoolTensor] = None,
529
+ return_loss: bool = True,
530
+ return_dict: Optional[bool] = None,
531
+ loss_weight: Optional[FloatTensor] = None,
532
+ ):
533
+ hidden_states = transformer_outputs.last_hidden_state
534
+ loss, observation_loss, action_loss = None, None, None
535
+ # Observations
536
+ assert rewards is not None
537
+ observations_mask = attention_mask[:, 1::2] if attention_mask is not None else None
538
+ if continuous_observations is not None:
539
+ if self.observation_loss_coef == 0.0:
540
+ warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
541
+ pred_observations = None
542
+ observation_loss = 0.0
543
+ else:
544
+ obs_size = continuous_observations.shape[-1]
545
+ continuous_observations = torch.cat((continuous_observations, rewards.unsqueeze(-1)), dim=-1)
546
+ continuous_observations = cyclic_expand_dim(continuous_observations, self.config.max_continuous_size)
547
+ pred_observations = self.continuous_decoder(hidden_states[:, 1::2])
548
+ if return_loss:
549
+ observation_loss = compute_mse_loss(
550
+ pred_observations[:, :-1],
551
+ continuous_observations[:, 1:],
552
+ observations_mask[:, 1:] if observations_mask is not None else None,
553
+ weights=loss_weight[:, 1:] if loss_weight is not None else None,
554
+ )
555
+ pred_observations = pred_observations[..., :obs_size]
556
+ elif discrete_observations is not None: # Note: reward is not predicted
557
+ if self.observation_loss_coef == 0.0:
558
+ warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
559
+ pred_observations = None
560
+ observation_loss = 0.0
561
+ else:
562
+ warnings.warn("Discrete observations prediction are not supported yet.") # way too expensive
563
+ pred_observations = None
564
+ observation_loss = 0.0
565
+ # pred_observations = self.multi_discrete_decoder(hidden_states[:, 1::2])
566
+ # if return_loss:
567
+ # observation_loss = compute_ce_loss(
568
+ # pred_observations[:, :-1],
569
+ # discrete_observations[:, 1:],
570
+ # observations_mask[:, 1:] if observations_mask is not None else None,
571
+ # weights=loss_weight[:, 1:] if loss_weight is not None else None,
572
+ # )
573
+ elif image_observations is not None:
574
+ if self.observation_loss_coef == 0.0:
575
+ warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
576
+ pred_observations = None
577
+ observation_loss = 0.0
578
+ else:
579
+ pred_observations = self.image_decoder(hidden_states[:, 1::2])
580
+ if return_loss:
581
+ observation_loss = compute_mse_loss(
582
+ pred_observations[:, :-1],
583
+ image_observations[:, 1:],
584
+ observations_mask[:, 1:] if observations_mask is not None else None,
585
+ weights=loss_weight[:, 1:] if loss_weight is not None else None,
586
+ )
587
+
588
+ # Actions
589
+ actions_mask = attention_mask[:, ::2] if attention_mask is not None else None
590
+ if continuous_actions is not None:
591
+ act_size = continuous_actions.shape[-1]
592
+ continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)
593
+ pred_actions = self.continuous_decoder(hidden_states[:, ::2])
594
+ if return_loss:
595
+ action_loss = compute_mse_loss(pred_actions, continuous_actions, actions_mask, weights=loss_weight)
596
+ pred_actions = pred_actions[..., :act_size]
597
+ elif discrete_actions is not None:
598
+ pred_actions = self.single_discrete_decoder(hidden_states[:, ::2])
599
+ if return_loss:
600
+ action_loss = compute_ce_loss(pred_actions, discrete_actions, actions_mask, weights=loss_weight)
601
+
602
+ # Return output
603
+ if return_loss:
604
+ loss = self.observation_loss_coef * observation_loss + self.action_loss_coef * action_loss
605
+
606
+ if not return_dict:
607
+ output = (pred_observations, pred_actions) + transformer_outputs[1:]
608
+ return ((loss, observation_loss, action_loss) + output) if loss is not None else output
609
+
610
+ return JatOutput(
611
+ loss=loss,
612
+ observation_loss=observation_loss,
613
+ action_loss=action_loss,
614
+ pred_observations=pred_observations,
615
+ pred_actions=pred_actions,
616
+ past_key_values=transformer_outputs.past_key_values,
617
+ hidden_states=transformer_outputs.hidden_states,
618
+ attentions=transformer_outputs.attentions,
619
+ )
620
+
621
+ def forward(
622
+ self,
623
+ input_ids: Optional[LongTensor] = None,
624
+ pixel_values: Optional[FloatTensor] = None,
625
+ continuous_observations: Optional[FloatTensor] = None,
626
+ discrete_observations: Optional[LongTensor] = None,
627
+ image_observations: Optional[FloatTensor] = None,
628
+ continuous_actions: Optional[FloatTensor] = None,
629
+ discrete_actions: Optional[LongTensor] = None,
630
+ rewards: Optional[FloatTensor] = None,
631
+ past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None,
632
+ attention_mask: Optional[BoolTensor] = None,
633
+ token_type_ids: Optional[LongTensor] = None,
634
+ position_ids: Optional[LongTensor] = None,
635
+ return_loss: bool = True,
636
+ use_cache: Optional[bool] = None,
637
+ output_attentions: Optional[bool] = None,
638
+ output_hidden_states: Optional[bool] = None,
639
+ return_dict: Optional[bool] = None,
640
+ loss_weight: Optional[FloatTensor] = None,
641
+ ) -> JatOutput:
642
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
643
+
644
+ # Textual tasks
645
+ if input_ids is not None or pixel_values is not None:
646
+ inputs_embeds, attention_mask = self.embed_textual(input_ids, pixel_values, attention_mask)
647
+ # RL tasks
648
+ elif (
649
+ continuous_observations is not None or discrete_observations is not None or image_observations is not None
650
+ ):
651
+ inputs_embeds, attention_mask = self.embed_rl(
652
+ continuous_observations,
653
+ discrete_observations,
654
+ image_observations,
655
+ continuous_actions,
656
+ discrete_actions,
657
+ rewards,
658
+ attention_mask,
659
+ )
660
+ else:
661
+ raise ValueError("Input not provided.")
662
+
663
+ # Pass through transformer
664
+ transformer_outputs = self.transformer(
665
+ past_key_values=past_key_values,
666
+ attention_mask=attention_mask,
667
+ token_type_ids=token_type_ids,
668
+ position_ids=position_ids,
669
+ inputs_embeds=inputs_embeds,
670
+ use_cache=use_cache,
671
+ output_attentions=output_attentions,
672
+ output_hidden_states=output_hidden_states,
673
+ return_dict=return_dict,
674
+ )
675
+
676
+ if input_ids is not None or pixel_values is not None:
677
+ return self.output_textual(transformer_outputs, input_ids, attention_mask, return_loss, return_dict)
678
+ else:
679
+ return self.output_rl(
680
+ transformer_outputs,
681
+ continuous_observations,
682
+ discrete_observations,
683
+ image_observations,
684
+ continuous_actions,
685
+ discrete_actions,
686
+ rewards,
687
+ attention_mask,
688
+ return_loss,
689
+ return_dict,
690
+ loss_weight,
691
+ )
692
+
693
+ def reset_rl(self):
694
+ self._last_key_values = None
695
+ self.last_discrete_observation = None
696
+ self.last_continuous_observation = None
697
+ self.last_text_observation = None
698
+ self.last_image_observation = None
699
+ self.last_discrete_action = None
700
+ self.last_continuous_action = None
701
+ self.last_reward = None
702
+
703
+ @torch.no_grad()
704
+ def get_next_action(
705
+ self,
706
+ processor: JatProcessor,
707
+ continuous_observation: Optional[List[float]] = None,
708
+ discrete_observation: Optional[List[int]] = None,
709
+ text_observation: Optional[str] = None,
710
+ image_observation: Optional[np.ndarray] = None,
711
+ action_space: Union[spaces.Box, spaces.Discrete] = None,
712
+ reward: Optional[float] = None,
713
+ deterministic: bool = False,
714
+ ):
715
+ # Get the maximum sequence length
716
+ max_length = self.config.max_position_embeddings // 2
717
+
718
+ # Convert everything to lists
719
+ def to_list(x):
720
+ return x.tolist() if isinstance(x, np.ndarray) else x
721
+
722
+ continuous_observation = to_list(continuous_observation)
723
+ discrete_observation = to_list(discrete_observation)
724
+
725
+ # Add a fake action to the end of the sequence
726
+ if isinstance(action_space, spaces.Box):
727
+ fake_continuous_action = [0.0 for _ in range(action_space.shape[0])]
728
+ fake_discrete_action = None
729
+ elif isinstance(action_space, spaces.Discrete):
730
+ fake_continuous_action = None
731
+ fake_discrete_action = 0
732
+
733
+ continuous_observations = [continuous_observation] if continuous_observation is not None else None
734
+ discrete_observations = [discrete_observation] if discrete_observation is not None else None
735
+ text_observations = [text_observation] if text_observation is not None else None
736
+ image_observations = [image_observation] if image_observation is not None else None
737
+ continuous_actions = [fake_continuous_action] if fake_continuous_action is not None else None
738
+ discrete_actions = [fake_discrete_action] if fake_discrete_action is not None else None
739
+ rewards = [reward] if reward is not None else [0.0]
740
+
741
+ if self._last_key_values is not None:
742
+ # We concatenate the last observation with the current one
743
+ continuous_observations = (
744
+ [self.last_continuous_observation] + continuous_observations
745
+ if continuous_observations is not None
746
+ else None
747
+ )
748
+ discrete_observations = (
749
+ [self.last_discrete_observation] + discrete_observations if discrete_observations is not None else None
750
+ )
751
+ text_observations = (
752
+ [self.last_text_observation] + text_observations if text_observations is not None else None
753
+ )
754
+ image_observations = (
755
+ [self.last_image_observation] + image_observations if image_observations is not None else None
756
+ )
757
+ continuous_actions = (
758
+ [self.last_continuous_action] + continuous_actions if continuous_actions is not None else None
759
+ )
760
+ discrete_actions = [self.last_discrete_action] + discrete_actions if discrete_actions is not None else None
761
+ rewards = [self.last_reward] + rewards
762
+
763
+ # Store the last observation
764
+ self.last_continuous_observation = continuous_observations[-1] if continuous_observations is not None else None
765
+ self.last_discrete_observation = discrete_observations[-1] if discrete_observations is not None else None
766
+ self.last_text_observation = text_observations[-1] if text_observations is not None else None
767
+ self.last_image_observation = image_observations[-1] if image_observations is not None else None
768
+ self.last_reward = rewards[-1]
769
+
770
+ # Add the batch dimension
771
+ continuous_observations = [continuous_observations] if continuous_observations is not None else None
772
+ discrete_observations = [discrete_observations] if discrete_observations is not None else None
773
+ text_observations = [text_observations] if text_observations is not None else None
774
+ image_observations = [image_observations] if image_observations is not None else None
775
+ continuous_actions = [continuous_actions] if continuous_actions is not None else None
776
+ discrete_actions = [discrete_actions] if discrete_actions is not None else None
777
+ rewards = [rewards]
778
+
779
+ # Process the inputs
780
+ processed = processor(
781
+ continuous_observations=continuous_observations,
782
+ discrete_observations=discrete_observations,
783
+ text_observations=text_observations,
784
+ image_observations=image_observations,
785
+ continuous_actions=continuous_actions,
786
+ discrete_actions=discrete_actions,
787
+ rewards=rewards,
788
+ truncation=True,
789
+ truncation_side="left",
790
+ max_length=max_length,
791
+ return_tensors="pt",
792
+ )
793
+ processed.to(self.device)
794
+
795
+ # Forward pass
796
+ outputs = self(**processed, past_key_values=self._last_key_values, return_loss=False)
797
+
798
+ # Truncate the past key-values
799
+ self._last_key_values = tuple(
800
+ tuple(pkv[:, :, -self.config.max_position_embeddings + 2 :] for pkv in pkvs)
801
+ for pkvs in outputs.past_key_values
802
+ )
803
+ # Store the last key values
804
+ # We remove the last two values, as the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ...
805
+ self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)
806
+
807
+ # Return the predicted action
808
+ if continuous_actions is not None:
809
+ self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()
810
+ return self.last_continuous_action
811
+ elif discrete_actions is not None:
812
+ logits = outputs.pred_actions[0, -1, : action_space.n]
813
+ if deterministic:
814
+ self.last_discrete_action = logits.argmax().cpu().item()
815
+ else: # sample
816
+ self.last_discrete_action = torch.multinomial(logits.softmax(dim=-1), num_samples=1)[0].item()
817
+ return self.last_discrete_action
818
+
819
+ # Allows to use .generate()
820
+ def prepare_inputs_for_generation(self, input_ids, pixel_values=None, past_key_values=None, **kwargs):
821
+ # only last token for inputs_ids if past is defined in kwargs
822
+ if past_key_values is not None:
823
+ pixel_values = None
824
+ input_ids = input_ids[:, -1].unsqueeze(-1)
825
+
826
+ model_inputs = {
827
+ "input_ids": input_ids,
828
+ "pixel_values": pixel_values,
829
+ "past_key_values": past_key_values,
830
+ "use_cache": kwargs.get("use_cache"),
831
+ }
832
+
833
+ return model_inputs
834
+
835
+
836
+ JatModel.register_for_auto_class("AutoModelForCausalLM")
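`reset_rl` and `get_next_action` above are the intended entry points for agent rollouts: the model keeps truncated `past_key_values` between calls and only needs the latest observation and reward each step. A rollout sketch under the assumption that the same repo also ships the JatProcessor files so `AutoProcessor` can resolve it; the environment id is arbitrary:

# Rollout sketch (assumptions: repo_id as in the loading sketch, AutoProcessor resolves
# JatProcessor via trust_remote_code, and any gymnasium env with a Box or Discrete
# action space works -- CartPole-v1 is just an example).
import gymnasium as gym
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
env = gym.make("CartPole-v1")

observation, _ = env.reset()
reward, done = None, False
model.reset_rl()  # clear cached key/values and the stored last observation/action
while not done:
    action = model.get_next_action(
        processor,
        continuous_observation=observation.tolist(),
        action_space=env.action_space,
        reward=reward,
    )
    observation, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated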
processing_jat.py ADDED
@@ -0,0 +1,403 @@
1
+ import copy
2
+ import warnings
3
+ from typing import Any, Dict, List, Optional, Union
4
+
5
+ import torch
6
+ import torchvision.transforms.functional as F
7
+ from transformers import BatchEncoding
8
+ from transformers.processing_utils import ProcessorMixin
9
+
10
+
11
+ def to_tensor(x):
12
+ """
13
+ Convert a nested structure of numpy arrays or tensors (including lists and tuples of them)
14
+ into a tensor. Assumes that all nested structures can be converted into a tensor directly.
15
+
16
+ :param x: Nested structure containing numpy arrays, tensors, lists, or tuples
17
+ :return: torch.Tensor
18
+ """
19
+ with warnings.catch_warnings():
20
+ # Convert specific warning to an error
21
+ warnings.filterwarnings(
22
+ "error",
23
+ category=UserWarning,
24
+ message=".*Creating a tensor from a list of numpy.ndarrays is extremely slow.*",
25
+ )
26
+ try:
27
+ return torch.Tensor(x)
28
+ except Exception:
29
+ if isinstance(x, list):
30
+ return torch.stack([to_tensor(item) for item in x])
31
+ else:
32
+ raise TypeError("Unsupported type for conversion to tensor")
33
+
34
+
35
+ def truncate(
36
+ encoding: Dict[str, List[List[Any]]], max_length: int, truncation_side: str = "right", preserve: bool = False
37
+ ) -> Dict[str, List[List[Any]]]:
38
+ """
39
+ Truncate the sequences in the encoding to the specified maximum length.
40
+
41
+ This function is designed to process a batch of sequences represented in the encoding dictionary.
42
+ Depending on the chosen strategy, sequences are either truncated with loss of residual data or with preservation
43
+ and incorporation of residual data into the batch.
44
+
45
+ Args:
46
+ encoding (`Mapping`):
47
+ A dictionary where each key-value pair consists of a feature name and its corresponding batch of sequences.
48
+ The sequences are expected to be lists.
49
+ max_length (`int`):
50
+ The maximum allowable length for the sequences.
51
+ truncation_side (`str`, **optional**):
52
+ The strategy to use for truncation. Can be `"left"` or `"right"`. Defaults to `"right"`.
53
+ preserve (`bool`, **optional**):
54
+ Whether to preserve the residual data by adding them as new sequences in the batch. Defaults to `False`.
55
+
56
+ Returns:
57
+ `Dict[str, List[List[Any]]]`:
58
+ A dictionary with the same keys as the input `encoding`, containing the truncated batch of sequences.
59
+ If `preserve` is set to `True`, the batch size may increase due to the addition of new sequences formed
60
+ from the residual data.
61
+
62
+ Example:
63
+
64
+ >>> encoding = {'feature1': [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}
65
+ >>> truncate(encoding, 3, preserve=False)
66
+ {'feature1': [[1, 2, 3], [6, 7, 8]]}
67
+
68
+ >>> truncate(encoding, 3, preserve=True)
69
+ {'feature1': [[1, 2, 3], [4, 5], [6, 7, 8], [9, 10]]}
70
+ """
71
+ truncated_encoding = {}
72
+
73
+ for key, sequences in encoding.items():
74
+ if not all(isinstance(seq, list) for seq in sequences):
75
+ raise TypeError(f"All sequences under key {key} should be of type list.")
76
+
77
+ truncated_sequences = []
78
+
79
+ for seq in sequences:
80
+ if len(seq) <= max_length:
81
+ truncated_sequences.append(seq)
82
+ continue
83
+
84
+ if preserve: # truncate and append the residual as new sequences
85
+ if truncation_side == "right":
86
+ truncated_sequences.extend([seq[i : i + max_length] for i in range(0, len(seq), max_length)])
87
+ elif truncation_side == "left":
88
+ n = len(seq) // max_length + int(len(seq) % max_length > 0)
89
+ low, high = len(seq) - n * max_length, len(seq)
90
+ truncated_sequences.extend(
91
+ [seq[max(0, i - max_length) : i] for i in range(high, low, -max_length)]
92
+ )
93
+ else:
94
+ raise ValueError(f"Invalid truncation_side: {truncation_side}")
95
+ else: # simply truncate the sequence
96
+ if truncation_side == "right":
97
+ truncated_sequences.append(seq[:max_length])
98
+ elif truncation_side == "left":
99
+ truncated_sequences.append(seq[-max_length:])
100
+
101
+ truncated_encoding[key] = truncated_sequences
102
+
103
+ return truncated_encoding
104
+
105
+
106
+ def pad(encoding: Dict[str, List[List[Any]]], target_length: int) -> Dict[str, List[List[Any]]]:
107
+ """
108
+ Pad the sequences in the encoding to the specified maximum length.
109
+
110
+ This function is designed to process a batch of sequences represented in the encoding dictionary.
111
+ The padding value is set to be the first element in the sequence.
112
+
113
+ Args:
114
+ encoding (`Mapping`):
115
+ A dictionary where each key-value pair consists of a feature name and its corresponding batch of sequences.
116
+ The sequences are expected to be lists.
117
+ target_length (`int`):
118
+ The desired length for the sequences.
119
+
120
+ Returns:
121
+ `Dict[str, List[List[Any]]]`:
122
+ A dictionary with the same keys as the input `encoding`, containing the padded batch of sequences.
123
+ An additional key `attention_mask` is added to the dictionary to indicate the positions of the non-padding
124
+ elements with 1s and the padding elements with 0s. If the input `encoding` already contains an
125
+ `attention_mask` key, the corresponding mask will be updated such that the original masking is preserved,
126
+ and the newly added padding elements will be masked with 0s. In other words, the resulting
127
+ `attention_mask` is a logical "AND" between the provided mask and the mask created due to padding, ensuring
128
+ that any element masked originally remains masked.
129
+
130
+ Example:
131
+
132
+ >>> encoding = {'feature1': [[1, 2], [3, 4, 5]]}
133
+ >>> pad(encoding, 4)
134
+ {'feature1': [[1, 2, 1, 1], [3, 4, 5, 3]], 'attention_mask': [[1, 1, 0, 0], [1, 1, 1, 0]]}
135
+
136
+ >>> encoding = {'feature1': [[1, 2], [3, 4, 5]], "attention_mask": [[1, 0], [0, 1, 1]]}
137
+ >>> pad(encoding, 4)
138
+ {'feature1': [[1, 2, 1, 1], [3, 4, 5, 3]], 'attention_mask': [[1, 0, 0, 0], [0, 1, 1, 0]]}
139
+ """
140
+ padded_encoding = {}
141
+
142
+ for key, sequences in encoding.items():
143
+ if not all(isinstance(seq, (list, torch.Tensor)) for seq in sequences):
144
+ raise TypeError(f"All sequences under key {key} should be of type list or tensor.")
145
+ if key == "attention_mask": # attention_mask is handled separately
146
+ continue
147
+
148
+ padded_sequences = []
149
+ pad_mask = []
150
+
151
+ for seq in sequences:
152
+ pad_len = target_length - len(seq)
153
+ padded_seq = list(seq) + [seq[0]] * max(0, pad_len)
154
+ mask = [1] * len(seq) + [0] * max(0, pad_len)
155
+
156
+ padded_sequences.append(padded_seq)
157
+ pad_mask.append(mask)
158
+
159
+ padded_encoding[key] = padded_sequences
160
+
161
+ if "attention_mask" in encoding:
162
+ padded_encoding["attention_mask"] = [
163
+ [a * (b[i] if i < len(b) else 0) for i, a in enumerate(row)]
164
+ for row, b in zip(pad_mask, encoding["attention_mask"])
165
+ ]
166
+ else:
167
+ padded_encoding["attention_mask"] = pad_mask
168
+
169
+ return padded_encoding
170
+
171
+
172
+ class JatProcessor(ProcessorMixin):
173
+ r"""
174
+ JAT processor which wraps an image processor and a tokenizer into a single processor.
175
+
176
+ [`JatProcessor`] offers all the functionalities of the underlying image processor and tokenizer. See the
177
+ [`~JatProcessor.__call__`] and [`~JatProcessor.decode`] for more information.
178
+
179
+ Args:
180
+ image_processor ([`AutoImageProcessor`]):
181
+ The image processor is a required input.
182
+ tokenizer ([`AutoTokenizer`]):
183
+ The tokenizer is a required input.
184
+ """
185
+ attributes = ["image_processor", "tokenizer"]
186
+ image_processor_class = "AutoImageProcessor"
187
+ tokenizer_class = "AutoTokenizer"
188
+
189
+ DONT_TRUNCATE_OR_PAD = {"pixel_values"} # keys that are neither truncated nor padded
190
+
191
+ def __init__(self, image_processor, tokenizer):
192
+ super().__init__(image_processor, tokenizer)
193
+ self.current_processor = self.image_processor
194
+
195
+ def _truncate_and_pad(
196
+ self,
197
+ encoding: dict,
198
+ padding: Union[bool, str],
199
+ truncation: Union[bool, str],
200
+ truncation_side: str = "right",
201
+ max_length: Optional[int] = None,
202
+ ) -> dict:
203
+ # If max_length is not provided, use the maximum length accepted by the model.
204
+ if max_length is None:
205
+ max_length = self.tokenizer.model_max_length
206
+
207
+ # Exclude keys that we don't want to truncate or pad.
208
+ excluded = {key: value for key, value in encoding.items() if key in self.DONT_TRUNCATE_OR_PAD}
209
+ encoding = {key: value for key, value in encoding.items() if key not in self.DONT_TRUNCATE_OR_PAD}
210
+
211
+ # Apply Truncation
212
+ if truncation in [True, "lossy"]:
213
+ encoding = truncate(encoding, max_length, truncation_side, preserve=False)
214
+ elif truncation == "preserve":
215
+ encoding = truncate(encoding, max_length, truncation_side, preserve=True)
216
+ elif truncation in [False, "do_not_truncate"]:
217
+ pass
218
+ else:
219
+ raise ValueError("Invalid truncation strategy:" + str(truncation))
220
+
221
+ # Apply Padding
222
+ if padding in [True, "longest"]:
223
+ target_length = max(len(seq) for sequences in encoding.values() for seq in sequences)
224
+ encoding = pad(encoding, target_length)
225
+ elif padding == "max_length":
226
+ encoding = pad(encoding, max_length)
227
+ elif padding in [False, "do_not_pad"]:
228
+ pass
229
+ else:
230
+ raise ValueError("Invalid padding strategy:" + str(padding))
231
+
232
+ # Add back the excluded keys.
233
+ encoding.update(excluded)
234
+
235
+ # Particular case, we handle the conversion to tensor of image_observations, as the format used
236
+ # (list of tensors) is not properly handled by the BatchEncoding class:
237
+ if "image_observations" in encoding:
238
+ encoding["image_observations"] = to_tensor(encoding["image_observations"])
239
+
240
+ return encoding
241
+
242
+ def __call__(
243
+ self,
244
+ text=None,
245
+ images=None,
246
+ continuous_observations=None,
247
+ discrete_observations=None,
248
+ text_observations=None,
249
+ image_observations=None,
250
+ continuous_actions=None,
251
+ discrete_actions=None,
252
+ rewards=None,
253
+ return_tensors=None,
254
+ **kwargs,
255
+ ):
256
+ """
257
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
258
+ and `kwargs` arguments to GPT2TokenizerFast's [`~GPT2TokenizerFast.__call__`] if `text` is not `None` to encode
259
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
260
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
261
+ of the above two methods for more information.
262
+
263
+ Args:
264
+ text (`str`, `List[str]`, `List[List[str]]`):
265
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
266
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
267
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
268
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`,
269
+ `List[np.ndarray]`, `List[torch.Tensor]`):
270
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
271
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is the
272
+ number of channels, H and W are image height and width.
273
+ continuous_observations (`List[List[List[float]]]`):
274
+ The continuous observations or batch of continuous observations to be encoded.
275
+ discrete_observations (`List[List[List[int]]]`):
276
+ The discrete observations or batch of discrete observations to be encoded.
277
+ text_observations (`List[List[str]]`):
278
+ The text observations or batch of text observations to be encoded.
279
+ image_observations (`List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`, `List[List[torch.Tensor]]`):
280
+ The image observations or batch of image observations to be encoded.
281
+ continuous_actions (`List[List[List[float]]]`):
282
+ The continuous actions or batch of continuous actions to be encoded.
283
+ discrete_actions (`List[List[int]]`):
284
+ The discrete actions or batch of discrete actions to be encoded.
285
+ rewards (`List[List[float]]`):
286
+ The rewards or batch of rewards to be encoded.
287
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
288
+ If set, will return tensors of a particular framework. Acceptable values are:
289
+
290
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
291
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
292
+ - `'np'`: Return NumPy `np.ndarray` objects.
293
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
294
+
295
+ Returns:
296
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
297
+
298
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
299
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
300
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
301
+ `None`).
302
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
303
+ """
304
+ # Truncation and padding are handled by this processor, so they are popped here and not forwarded to the tokenizer
305
+ padding = kwargs.pop("padding", False)
306
+ truncation = kwargs.pop("truncation", False)
307
+ truncation_side = kwargs.pop("truncation_side", "right")
308
+ max_length = kwargs.pop("max_length", None)
309
+
310
+ # Ensure that the input is batched
311
+ if text is not None and not isinstance(text, list):
312
+ text = [text]
313
+
314
+ encoding = {}
315
+ if text is not None:
316
+ encoding["input_ids"] = self.tokenizer(text, **kwargs)["input_ids"]
317
+ if images is not None:
318
+ encoding["pixel_values"] = self.image_processor(images, **kwargs).pixel_values
319
+ if continuous_observations is not None:
320
+ encoding["continuous_observations"] = copy.deepcopy(continuous_observations)
321
+ if discrete_observations is not None:
322
+ encoding["discrete_observations"] = copy.deepcopy(discrete_observations)
323
+ if text_observations is not None:
324
+ if "discrete_observations" not in encoding:
325
+ raise ValueError("discrete_observations must be provided if text_observations is provided")
326
+ for batch_idx, sequence in enumerate(text_observations):
327
+ encoded_text = self.tokenizer(sequence, max_length=64, padding="max_length")["input_ids"]
328
+ for timestep, text_tokens in enumerate(encoded_text):
329
+ encoding["discrete_observations"][batch_idx][timestep].extend(text_tokens)
330
+ if image_observations is not None:
331
+ image_observations = [[(F.to_tensor(im) - 0.5) / 0.5 for im in ep] for ep in image_observations]
332
+ encoding["image_observations"] = image_observations
333
+ if continuous_actions is not None:
334
+ encoding["continuous_actions"] = copy.deepcopy(continuous_actions)
335
+ if discrete_actions is not None:
336
+ encoding["discrete_actions"] = copy.deepcopy(discrete_actions)
337
+
338
+ if rewards is not None:
339
+ encoding["rewards"] = [[float(r) for r in ep] for ep in rewards]
340
+
341
+ # Handle the image+text case: reduce max_length, since image and text tokens will be concatenated
342
+ if text is not None and images is not None:
343
+ if max_length is None:
344
+ max_length = self.tokenizer.model_max_length
345
+ max_length -= (224 // 16) ** 2 # subtract the number of image patch tokens
346
+ elif (
347
+ continuous_observations is not None
348
+ or discrete_observations is not None
349
+ or text_observations is not None
350
+ or image_observations is not None
351
+ ):
352
+ if max_length is None:
353
+ max_length = self.tokenizer.model_max_length
354
+ max_length //= 2 # observations and actions are interleaved
355
+
356
+ encoding = self._truncate_and_pad(encoding, padding, truncation, truncation_side, max_length)
357
+
358
+ return BatchEncoding(encoding, tensor_type=return_tensors)
359
+
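
A sketch of a sequential (RL-style) call, reusing the `processor` instance from the earlier sketch; the episode values are made up. Note that for sequential inputs the effective `max_length` is halved, since observation and action embeddings are interleaved.

# One episode of two timesteps: 4-dim continuous observations, 2-dim continuous actions.
episode_obs = [[[0.0, 0.1, 0.2, 0.3], [0.4, 0.5, 0.6, 0.7]]]
episode_act = [[[0.5, -0.5], [0.1, 0.9]]]
episode_rew = [[1.0, 0.0]]

batch = processor(
    continuous_observations=episode_obs,
    continuous_actions=episode_act,
    rewards=episode_rew,
    padding="max_length",
    max_length=8,          # halved to 4 timesteps internally
    return_tensors="pt",
)
# batch is a BatchEncoding containing "continuous_observations", "continuous_actions",
# "rewards" and an "attention_mask" that marks padded timesteps with 0.
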
360
+ def batch_decode(self, *args, **kwargs):
361
+ """
362
+ This method forwards all its arguments to GPT2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
363
+ refer to the docstring of this method for more information.
364
+ """
365
+ return self.tokenizer.batch_decode(*args, **kwargs)
366
+
367
+ def decode(self, *args, **kwargs):
368
+ """
369
+ This method forwards all its arguments to GPT2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
370
+ the docstring of this method for more information.
371
+ """
372
+ return self.tokenizer.decode(*args, **kwargs)
373
+
374
+ def pad(self, *args, **kwargs):
375
+ inputs = args[0]
376
+ keys = [key for key in inputs[0].keys() if inputs[0][key] is not None]
377
+ inputs = {key: [arg[key] for arg in inputs] for key in keys}
378
+ elmt = next(iter(inputs.values()))
379
+ if isinstance(elmt[0], torch.Tensor) and not isinstance(elmt, torch.Tensor):
380
+ encoding = {key: torch.stack(inputs[key]) for key in inputs.keys()}
381
+ else:
382
+ encoding = self._truncate_and_pad(
383
+ inputs, padding=kwargs.get("padding", False), truncation=False, max_length=kwargs.get("max_length")
384
+ )
385
+
386
+ return BatchEncoding(encoding, tensor_type=kwargs.get("return_tensors"))
387
+
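
The `pad` method above is shaped like a collate function: given a list of feature dicts it either stacks pre-built tensors or re-pads raw sequences. A sketch of plugging it into a PyTorch DataLoader (the feature values are illustrative, and `processor` is assumed to be a JatProcessor instance as in the earlier sketch):

from functools import partial
from torch.utils.data import DataLoader

features = [
    {"continuous_observations": [[0.1, 0.2]], "continuous_actions": [[0.3]], "rewards": [1.0]},
    {"continuous_observations": [[0.4, 0.5], [0.6, 0.7]], "continuous_actions": [[0.8], [0.9]], "rewards": [0.0, 1.0]},
]
collate_fn = partial(processor.pad, padding="longest", return_tensors="pt")
loader = DataLoader(features, batch_size=2, collate_fn=collate_fn)
batch = next(iter(loader))  # sequences padded to the longest episode in the batch
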
388
+ @property
389
+ def model_input_names(self):
390
+ return [
391
+ "input_ids",
392
+ "attention_mask",
393
+ "pixel_values",
394
+ "continuous_observations",
395
+ "discrete_observations",
396
+ "image_observations",
397
+ "continuous_actions",
398
+ "discrete_actions",
399
+ "rewards",
400
+ ]
401
+
402
+
403
+ JatProcessor.register_for_auto_class("AutoProcessor")
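
With the auto-class registration above, the processor can be loaded straight from the Hub once it has been pushed together with its custom code. A sketch (the repository id is illustrative, and `trust_remote_code=True` is required because the processor class lives in the repository rather than in `transformers`):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("jat-project/jat", trust_remote_code=True)  # illustrative repo id
inputs = processor(text="Move the cart to the left", return_tensors="pt")
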