koukyo1994 committed
update llama_action model

modeling_llama_action.py  CHANGED  (+14 -1)
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 from transformers import LlamaForCausalLM
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from tqdm import tqdm
 
 from .configuration_llama_action import LlamaActionConfig
 
@@ -204,11 +205,23 @@ class LlamaActionForCausalLM(LlamaForCausalLM):
         seq_length = input_ids.size(1)
         n_frames = seq_length // self.num_image_patches
         attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
+        if kwargs.pop("show_progress", False):
+            prefix = kwargs.pop("prefix", "")
+            max_length = kwargs.pop("max_length")
+            if past_key_values is None or len(past_key_values) == 0:
+                pbar = tqdm(total=max_length - len(input_ids[0]), desc=prefix, leave=False)
+                postfix = f"Frame [{n_frames + 1}/{max_length // self.num_image_patches}]"
+                pbar.set_postfix_str(postfix)
+            else:
+                pbar.update()
+
         if seq_length % self.num_image_patches != 0:
             n_last_frame_tokens = seq_length % self.num_image_patches
             attention_mask_length += n_last_frame_tokens
         else:
-
+            if kwargs.pop("show_progress", False):
+                postfix = f"Frame [{n_frames + 1}/{max_length // self.num_image_patches}]"
+                pbar.set_postfix(postfix)
         attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)
         # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None and len(past_key_values) > 0:
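The change threads an optional tqdm bar through `prepare_inputs_for_generation`, so a caller can pass `show_progress=True` (plus `prefix` and `max_length`) through `generate`'s extra keyword arguments. Below is a minimal, self-contained sketch of the same frame-progress pattern; the `FrameProgress` class and its names are illustrative, not part of this repo. Unlike the diff, it keeps the bar on a helper object so `update()` can find it on later decoding steps (in the committed version `pbar` is a plain local), and it uses `set_postfix_str`, which takes a plain string, where tqdm's `set_postfix` expects keyword arguments.

# Illustrative sketch, not part of this repo: one frame is assumed to be
# `num_image_patches` tokens, as in the model config above.
from tqdm import tqdm


class FrameProgress:
    """Token-level progress bar annotated with the current/total frame count."""

    def __init__(self, max_length: int, num_image_patches: int, prefix: str = ""):
        self.max_length = max_length
        self.num_image_patches = num_image_patches
        self.prefix = prefix
        self.pbar = None

    def step(self, seq_length: int) -> None:
        if self.pbar is None:
            # First decoding step: size the bar to the tokens still to generate.
            self.pbar = tqdm(total=self.max_length - seq_length,
                             desc=self.prefix, leave=False)
        else:
            self.pbar.update(1)
        n_frames = seq_length // self.num_image_patches
        total_frames = self.max_length // self.num_image_patches
        # set_postfix_str accepts a plain string; set_postfix expects kwargs.
        self.pbar.set_postfix_str(f"Frame [{n_frames + 1}/{total_frames}]")

    def close(self) -> None:
        if self.pbar is not None:
            self.pbar.close()

A hypothetical driver would call `progress.step(input_ids.size(1))` at the top of each `prepare_inputs_for_generation` call and `progress.close()` once `generate` returns.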