koukyo1994 commited on
Commit
21e5dd0
·
verified ·
1 Parent(s): e999941

Upload folder using huggingface_hub

Browse files
configuration_llama_action.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import LlamaConfig
2
+
3
+
4
+ class LlamaActionConfig(LlamaConfig):
5
+ model_type = "llama_action"
6
+
7
+ def __init__(self, **kwargs):
8
+ super().__init__(**kwargs)
9
+ self.num_spatio_embeddings = kwargs.get("num_spatio_embeddings", 582)
10
+ self.num_temporal_embeddings = kwargs.get("num_temporal_embeddings", 25)
11
+ self.num_action_embeddings = kwargs.get("num_action_tokens", 5)
12
+ self.num_image_patches = kwargs.get("num_image_patches", 576)
13
+ self.action_dim = kwargs.get("action_dim", 3)
modeling_llama_action.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import LlamaForCausalLM
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast
7
+
8
+ from .configuration_llama_action import LlamaActionConfig
9
+
10
+
11
+ class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
12
+ def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
13
+ super().__init__()
14
+ self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
15
+ self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
16
+ self.num_spatio_embeddings = num_spatio_embeddings
17
+ self.num_temporal_embeddings = num_temporal_embeddings
18
+
19
+ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int):
20
+ seq_length = attention_mask.size(1)
21
+ batch_size = attention_mask.size(0)
22
+
23
+ if past_key_values_length == 0:
24
+ # create a tensor of the form [0, 1, 2, ..., num_spatio_embeddings-1]
25
+ spatio_indices = torch.arange(
26
+ self.num_spatio_embeddings,
27
+ device=attention_mask.device
28
+ ).repeat(self.num_temporal_embeddings).unsqueeze(0).repeat((batch_size, 1))
29
+
30
+ # create a tensor of the form [0, 0, 0, ..., 1, 1, 1, ..., 2, 2, 2, ...]
31
+ temporal_indices = torch.arange(
32
+ self.num_temporal_embeddings,
33
+ device=attention_mask.device
34
+ ).repeat_interleave(self.num_spatio_embeddings).unsqueeze(0).repeat((batch_size, 1))
35
+
36
+ spatio_indices = spatio_indices[:, :seq_length]
37
+ temporal_indices = temporal_indices[:, :seq_length]
38
+
39
+ else:
40
+ temporal_index = past_key_values_length // self.num_spatio_embeddings
41
+ spatio_index = past_key_values_length % self.num_spatio_embeddings
42
+ spatio_indices = torch.tensor([[spatio_index]], device=attention_mask.device).repeat((batch_size, 1))
43
+ temporal_indices = torch.tensor([[temporal_index]], device=attention_mask.device).repeat((batch_size, 1))
44
+
45
+ return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)
46
+
47
+
48
+ class LlamaActionForCausalLM(LlamaForCausalLM):
49
+ config_class = LlamaActionConfig
50
+
51
+ def __init__(self, config: LlamaActionConfig):
52
+ super().__init__(config)
53
+
54
+ self.num_spatio_embeddings = config.num_spatio_embeddings
55
+ self.num_temporal_embeddings = config.num_temporal_embeddings
56
+ self.num_image_patches = config.num_image_patches
57
+ self.num_action_embeddings = config.num_action_embeddings
58
+
59
+ self.pos_embedding_spatio_temporal = LearnableFactorizedSpatioTemporalPositionalEmbedding(
60
+ config.num_spatio_embeddings, config.num_temporal_embeddings, config.hidden_size,
61
+ )
62
+
63
+ self.action_projection = nn.Linear(config.action_dim, config.hidden_size)
64
+
65
+ self.post_init()
66
+
67
+ def forward(
68
+ self,
69
+ input_ids: Optional[torch.Tensor] = None,
70
+ actions: Optional[torch.Tensor] = None,
71
+ attention_mask: Optional[torch.Tensor] = None,
72
+ position_ids: Optional[torch.Tensor] = None,
73
+ inputs_embeds: Optional[torch.Tensor] = None,
74
+ labels: Optional[torch.Tensor] = None,
75
+ past_key_values: Optional[List[torch.Tensor]] = None,
76
+ use_cache: Optional[bool] = None,
77
+ output_attentions: Optional[bool] = None,
78
+ output_hidden_states: Optional[bool] = None,
79
+ return_dict: Optional[bool] = None,
80
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
81
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
82
+ if labels is not None:
83
+ use_cache = False
84
+
85
+ if input_ids is not None and inputs_embeds is not None:
86
+ raise ValueError(
87
+ "You cannot specify both input_ids and inputs_embeds at the same time"
88
+ )
89
+ elif input_ids is not None:
90
+ pass
91
+ elif inputs_embeds is not None:
92
+ pass
93
+ else:
94
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
95
+
96
+ inputs_embeds = self.model.get_input_embeddings()(input_ids)
97
+ if past_key_values is None or len(past_key_values) == 0:
98
+ inputs_embeds_list = torch.split(
99
+ inputs_embeds,
100
+ split_size_or_sections=self.num_image_patches,
101
+ dim=1
102
+ )
103
+ actions_list = torch.split(
104
+ actions,
105
+ split_size_or_sections=self.num_action_embeddings,
106
+ dim=1
107
+ )
108
+
109
+ embeddings = []
110
+ if len(inputs_embeds_list) == len(actions_list):
111
+ # mostly used in training phase
112
+ for inputs_embeds, action_embeds in zip(inputs_embeds_list, actions_list):
113
+ action_features = self.action_projection(action_embeds)
114
+ embeddings.append(inputs_embeds)
115
+ embeddings.append(action_features)
116
+ elif len(inputs_embeds_list) < len(actions_list):
117
+ # used in inference phase (mostly)
118
+ for i, inputs_embeds in enumerate(inputs_embeds_list):
119
+ embeddings.append(inputs_embeds)
120
+ if i < len(inputs_embeds_list) - 1:
121
+ # the last frame might be generating image tokens, so we don't add action embedding
122
+ action_embeds = self.action_projection(actions_list[i])
123
+ embeddings.append(action_embeds)
124
+ if inputs_embeds_list[-1].size(1) == self.num_image_patches:
125
+ # if the last frame has generated all image tokens, we add action embedding
126
+ action_embeds = self.action_projection(actions_list[len(inputs_embeds_list) - 1])
127
+ embeddings.append(action_embeds)
128
+ else:
129
+ if isinstance(past_key_values, tuple):
130
+ past_key_values_length = past_key_values[0][0].size(2)
131
+ else:
132
+ past_key_values_length = past_key_values.get_seq_length()
133
+ embeddings = []
134
+ # create an interleaved sequence of image and action embeddings like image, image, ..., image, action, action, ..., action
135
+ # we only generate image tokens, so we add action tokens after generating one frame
136
+ if past_key_values_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
137
+ seq_index = past_key_values_length // self.num_spatio_embeddings + 1
138
+ actions_list = torch.split(
139
+ actions,
140
+ split_size_or_sections=self.num_action_embeddings,
141
+ dim=1
142
+ )
143
+ action_features = self.action_projection(actions_list[seq_index - 1])
144
+ embeddings.append(action_features)
145
+ embeddings.append(inputs_embeds)
146
+ else:
147
+ pass
148
+
149
+ if len(embeddings) > 0:
150
+ inputs_embeds = torch.cat(embeddings, dim=1)
151
+
152
+ # insert spatio-temporal positional embedding
153
+ if past_key_values is not None:
154
+ if isinstance(past_key_values, tuple):
155
+ past_key_values_length = past_key_values[0][0].size(2)
156
+ else:
157
+ past_key_values_length = past_key_values.get_seq_length()
158
+ else:
159
+ past_key_values_length = 0
160
+ inputs_embeds += self.pos_embedding_spatio_temporal(inputs_embeds, past_key_values_length)
161
+
162
+ outputs = self.model(
163
+ input_ids=None,
164
+ attention_mask=attention_mask,
165
+ position_ids=position_ids,
166
+ past_key_values=past_key_values,
167
+ inputs_embeds=inputs_embeds,
168
+ use_cache=use_cache,
169
+ output_attentions=output_attentions,
170
+ output_hidden_states=output_hidden_states,
171
+ return_dict=return_dict,
172
+ )
173
+
174
+ sequence_output = outputs[0]
175
+ logits = self.lm_head(sequence_output).contiguous()
176
+
177
+ loss = None
178
+ if labels is not None:
179
+ shift_logits = logits[..., :-1, :].contiguous()
180
+ shift_labels = labels[..., 1:].contiguous()
181
+ loss_fct = nn.CrossEntropyLoss()
182
+ loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
183
+
184
+ if not return_dict:
185
+ output = (logits,) + outputs[1:]
186
+ return ((loss,) + output) if loss is not None else output
187
+
188
+ return CausalLMOutputWithPast(
189
+ loss=loss,
190
+ logits=logits,
191
+ past_key_values=outputs.past_key_values,
192
+ hidden_states=outputs.hidden_states,
193
+ attentions=outputs.attentions,
194
+ )
195
+
196
+ def prepare_inputs_for_generation(
197
+ self,
198
+ input_ids,
199
+ past_key_values=None,
200
+ attention_mask=None,
201
+ use_cache=None,
202
+ **kwargs):
203
+ batch_size = input_ids.size(0)
204
+ seq_length = input_ids.size(1)
205
+ n_frames = seq_length // self.num_image_patches
206
+ attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
207
+ if seq_length % self.num_image_patches != 0:
208
+ n_last_frame_tokens = seq_length % self.num_image_patches
209
+ attention_mask_length += n_last_frame_tokens
210
+ else:
211
+ print(f"attempting to generate new frame - frame no: {n_frames + 1}")
212
+ attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)
213
+ # cut decoder_input_ids if past_key_values is used
214
+ if past_key_values is not None and len(past_key_values) > 0:
215
+ if isinstance(past_key_values, tuple):
216
+ past_length = past_key_values[0][0].size(2)
217
+ else:
218
+ past_length = past_key_values.get_seq_length()
219
+ if input_ids.size(1) > past_length:
220
+ remove_prefix_length = past_length
221
+ else:
222
+ remove_prefix_length = input_ids.size(1) - 1
223
+ input_ids = input_ids[:, remove_prefix_length:]
224
+ seq_length = input_ids.size(1)
225
+ past_key_values_length = past_length
226
+ mask_seq_length = seq_length + past_key_values_length
227
+ if past_key_values_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
228
+ mask_seq_length += self.num_action_embeddings
229
+ attention_mask = torch.ones((batch_size, mask_seq_length), device=input_ids.device, dtype=torch.long)
230
+
231
+ return {
232
+ "input_ids": input_ids,
233
+ "attention_mask": attention_mask,
234
+ "actions": kwargs.get("actions"),
235
+ "past_key_values": past_key_values,
236
+ "use_cache": use_cache,
237
+ }