koukyo1994 commited on
Commit
5e41d6c
·
verified ·
1 Parent(s): 21e5dd0

Delete files world_model/modeling_llama_action.py world_model/configuration_llama_action.py with huggingface_hub

Browse files
world_model/configuration_llama_action.py DELETED
@@ -1,13 +0,0 @@
1
- from transformers import LlamaConfig
2
-
3
-
4
- class LlamaActionConfig(LlamaConfig):
5
- model_type = "llama_action"
6
-
7
- def __init__(self, **kwargs):
8
- super().__init__(**kwargs)
9
- self.num_spatio_embeddings = kwargs.get("num_spatio_embeddings", 582)
10
- self.num_temporal_embeddings = kwargs.get("num_temporal_embeddings", 25)
11
- self.num_action_embeddings = kwargs.get("num_action_tokens", 5)
12
- self.num_image_patches = kwargs.get("num_image_patches", 576)
13
- self.action_dim = kwargs.get("action_dim", 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
world_model/modeling_llama_action.py DELETED
@@ -1,237 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- from transformers import LlamaForCausalLM
6
- from transformers.modeling_outputs import CausalLMOutputWithPast
7
-
8
- from .configuration_llama_action import LlamaActionConfig
9
-
10
-
11
- class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
12
- def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
13
- super().__init__()
14
- self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
15
- self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
16
- self.num_spatio_embeddings = num_spatio_embeddings
17
- self.num_temporal_embeddings = num_temporal_embeddings
18
-
19
- def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int):
20
- seq_length = attention_mask.size(1)
21
- batch_size = attention_mask.size(0)
22
-
23
- if past_key_values_length == 0:
24
- # create a tensor of the form [0, 1, 2, ..., num_spatio_embeddings-1]
25
- spatio_indices = torch.arange(
26
- self.num_spatio_embeddings,
27
- device=attention_mask.device
28
- ).repeat(self.num_temporal_embeddings).unsqueeze(0).repeat((batch_size, 1))
29
-
30
- # create a tensor of the form [0, 0, 0, ..., 1, 1, 1, ..., 2, 2, 2, ...]
31
- temporal_indices = torch.arange(
32
- self.num_temporal_embeddings,
33
- device=attention_mask.device
34
- ).repeat_interleave(self.num_spatio_embeddings).unsqueeze(0).repeat((batch_size, 1))
35
-
36
- spatio_indices = spatio_indices[:, :seq_length]
37
- temporal_indices = temporal_indices[:, :seq_length]
38
-
39
- else:
40
- temporal_index = past_key_values_length // self.num_spatio_embeddings
41
- spatio_index = past_key_values_length % self.num_spatio_embeddings
42
- spatio_indices = torch.tensor([[spatio_index]], device=attention_mask.device).repeat((batch_size, 1))
43
- temporal_indices = torch.tensor([[temporal_index]], device=attention_mask.device).repeat((batch_size, 1))
44
-
45
- return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)
46
-
47
-
48
- class LlamaActionForCausalLM(LlamaForCausalLM):
49
- config_class = LlamaActionConfig
50
-
51
- def __init__(self, config: LlamaActionConfig):
52
- super().__init__(config)
53
-
54
- self.num_spatio_embeddings = config.num_spatio_embeddings
55
- self.num_temporal_embeddings = config.num_temporal_embeddings
56
- self.num_image_patches = config.num_image_patches
57
- self.num_action_embeddings = config.num_action_embeddings
58
-
59
- self.pos_embedding_spatio_temporal = LearnableFactorizedSpatioTemporalPositionalEmbedding(
60
- config.num_spatio_embeddings, config.num_temporal_embeddings, config.hidden_size,
61
- )
62
-
63
- self.action_projection = nn.Linear(config.action_dim, config.hidden_size)
64
-
65
- self.post_init()
66
-
67
- def forward(
68
- self,
69
- input_ids: Optional[torch.Tensor] = None,
70
- actions: Optional[torch.Tensor] = None,
71
- attention_mask: Optional[torch.Tensor] = None,
72
- position_ids: Optional[torch.Tensor] = None,
73
- inputs_embeds: Optional[torch.Tensor] = None,
74
- labels: Optional[torch.Tensor] = None,
75
- past_key_values: Optional[List[torch.Tensor]] = None,
76
- use_cache: Optional[bool] = None,
77
- output_attentions: Optional[bool] = None,
78
- output_hidden_states: Optional[bool] = None,
79
- return_dict: Optional[bool] = None,
80
- ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
81
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
82
- if labels is not None:
83
- use_cache = False
84
-
85
- if input_ids is not None and inputs_embeds is not None:
86
- raise ValueError(
87
- "You cannot specify both input_ids and inputs_embeds at the same time"
88
- )
89
- elif input_ids is not None:
90
- pass
91
- elif inputs_embeds is not None:
92
- pass
93
- else:
94
- raise ValueError("You have to specify either input_ids or inputs_embeds")
95
-
96
- inputs_embeds = self.model.get_input_embeddings()(input_ids)
97
- if past_key_values is None or len(past_key_values) == 0:
98
- inputs_embeds_list = torch.split(
99
- inputs_embeds,
100
- split_size_or_sections=self.num_image_patches,
101
- dim=1
102
- )
103
- actions_list = torch.split(
104
- actions,
105
- split_size_or_sections=self.num_action_embeddings,
106
- dim=1
107
- )
108
-
109
- embeddings = []
110
- if len(inputs_embeds_list) == len(actions_list):
111
- # mostly used in training phase
112
- for inputs_embeds, action_embeds in zip(inputs_embeds_list, actions_list):
113
- action_features = self.action_projection(action_embeds)
114
- embeddings.append(inputs_embeds)
115
- embeddings.append(action_features)
116
- elif len(inputs_embeds_list) < len(actions_list):
117
- # used in inference phase (mostly)
118
- for i, inputs_embeds in enumerate(inputs_embeds_list):
119
- embeddings.append(inputs_embeds)
120
- if i < len(inputs_embeds_list) - 1:
121
- # the last frame might be generating image tokens, so we don't add action embedding
122
- action_embeds = self.action_projection(actions_list[i])
123
- embeddings.append(action_embeds)
124
- if inputs_embeds_list[-1].size(1) == self.num_image_patches:
125
- # if the last frame has generated all image tokens, we add action embedding
126
- action_embeds = self.action_projection(actions_list[len(inputs_embeds_list) - 1])
127
- embeddings.append(action_embeds)
128
- else:
129
- if isinstance(past_key_values, tuple):
130
- past_key_values_length = past_key_values[0][0].size(2)
131
- else:
132
- past_key_values_length = past_key_values.get_seq_length()
133
- embeddings = []
134
- # create an interleaved sequence of image and action embeddings like image, image, ..., image, action, action, ..., action
135
- # we only generate image tokens, so we add action tokens after generating one frame
136
- if past_key_values_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
137
- seq_index = past_key_values_length // self.num_spatio_embeddings + 1
138
- actions_list = torch.split(
139
- actions,
140
- split_size_or_sections=self.num_action_embeddings,
141
- dim=1
142
- )
143
- action_features = self.action_projection(actions_list[seq_index - 1])
144
- embeddings.append(action_features)
145
- embeddings.append(inputs_embeds)
146
- else:
147
- pass
148
-
149
- if len(embeddings) > 0:
150
- inputs_embeds = torch.cat(embeddings, dim=1)
151
-
152
- # insert spatio-temporal positional embedding
153
- if past_key_values is not None:
154
- if isinstance(past_key_values, tuple):
155
- past_key_values_length = past_key_values[0][0].size(2)
156
- else:
157
- past_key_values_length = past_key_values.get_seq_length()
158
- else:
159
- past_key_values_length = 0
160
- inputs_embeds += self.pos_embedding_spatio_temporal(inputs_embeds, past_key_values_length)
161
-
162
- outputs = self.model(
163
- input_ids=None,
164
- attention_mask=attention_mask,
165
- position_ids=position_ids,
166
- past_key_values=past_key_values,
167
- inputs_embeds=inputs_embeds,
168
- use_cache=use_cache,
169
- output_attentions=output_attentions,
170
- output_hidden_states=output_hidden_states,
171
- return_dict=return_dict,
172
- )
173
-
174
- sequence_output = outputs[0]
175
- logits = self.lm_head(sequence_output).contiguous()
176
-
177
- loss = None
178
- if labels is not None:
179
- shift_logits = logits[..., :-1, :].contiguous()
180
- shift_labels = labels[..., 1:].contiguous()
181
- loss_fct = nn.CrossEntropyLoss()
182
- loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
183
-
184
- if not return_dict:
185
- output = (logits,) + outputs[1:]
186
- return ((loss,) + output) if loss is not None else output
187
-
188
- return CausalLMOutputWithPast(
189
- loss=loss,
190
- logits=logits,
191
- past_key_values=outputs.past_key_values,
192
- hidden_states=outputs.hidden_states,
193
- attentions=outputs.attentions,
194
- )
195
-
196
- def prepare_inputs_for_generation(
197
- self,
198
- input_ids,
199
- past_key_values=None,
200
- attention_mask=None,
201
- use_cache=None,
202
- **kwargs):
203
- batch_size = input_ids.size(0)
204
- seq_length = input_ids.size(1)
205
- n_frames = seq_length // self.num_image_patches
206
- attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
207
- if seq_length % self.num_image_patches != 0:
208
- n_last_frame_tokens = seq_length % self.num_image_patches
209
- attention_mask_length += n_last_frame_tokens
210
- else:
211
- print(f"attempting to generate new frame - frame no: {n_frames + 1}")
212
- attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)
213
- # cut decoder_input_ids if past_key_values is used
214
- if past_key_values is not None and len(past_key_values) > 0:
215
- if isinstance(past_key_values, tuple):
216
- past_length = past_key_values[0][0].size(2)
217
- else:
218
- past_length = past_key_values.get_seq_length()
219
- if input_ids.size(1) > past_length:
220
- remove_prefix_length = past_length
221
- else:
222
- remove_prefix_length = input_ids.size(1) - 1
223
- input_ids = input_ids[:, remove_prefix_length:]
224
- seq_length = input_ids.size(1)
225
- past_key_values_length = past_length
226
- mask_seq_length = seq_length + past_key_values_length
227
- if past_key_values_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
228
- mask_seq_length += self.num_action_embeddings
229
- attention_mask = torch.ones((batch_size, mask_seq_length), device=input_ids.device, dtype=torch.long)
230
-
231
- return {
232
- "input_ids": input_ids,
233
- "attention_mask": attention_mask,
234
- "actions": kwargs.get("actions"),
235
- "past_key_values": past_key_values,
236
- "use_cache": use_cache,
237
- }