koukyo1994 committed on
Commit
d84275c
·
verified ·
1 Parent(s): 46f772f

update llama_action model

Browse files
Files changed (1) hide show
  1. modeling_llama_action.py +14 -1
modeling_llama_action.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import torch.nn as nn
5
  from transformers import LlamaForCausalLM
6
  from transformers.modeling_outputs import CausalLMOutputWithPast
 
7
 
8
  from .configuration_llama_action import LlamaActionConfig
9
 
@@ -204,11 +205,23 @@ class LlamaActionForCausalLM(LlamaForCausalLM):
204
  seq_length = input_ids.size(1)
205
  n_frames = seq_length // self.num_image_patches
206
  attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
 
 
 
 
 
 
 
 
 
 
207
  if seq_length % self.num_image_patches != 0:
208
  n_last_frame_tokens = seq_length % self.num_image_patches
209
  attention_mask_length += n_last_frame_tokens
210
  else:
211
- print(f"attempting to generate new frame - frame no: {n_frames + 1}")
 
 
212
  attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)
213
  # cut decoder_input_ids if past_key_values is used
214
  if past_key_values is not None and len(past_key_values) > 0:
 
4
  import torch.nn as nn
5
  from transformers import LlamaForCausalLM
6
  from transformers.modeling_outputs import CausalLMOutputWithPast
7
+ from tqdm import tqdm
8
 
9
  from .configuration_llama_action import LlamaActionConfig
10
 
 
205
  seq_length = input_ids.size(1)
206
  n_frames = seq_length // self.num_image_patches
207
  attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
208
+ if kwargs.pop("show_progress", False):
209
+ prefix = kwargs.pop("prefix", "")
210
+ max_length = kwargs.pop("max_length")
211
+ if past_key_values is None or len(past_key_values) == 0:
212
+ pbar = tqdm(total=max_length - len(input_ids[0]), desc=prefix, leave=False)
213
+ postfix = f"Frame [{n_frames + 1}/{max_length // self.num_image_patches}]"
214
+ pbar.set_postfix_str(postfix)
215
+ else:
216
+ pbar.update()
217
+
218
  if seq_length % self.num_image_patches != 0:
219
  n_last_frame_tokens = seq_length % self.num_image_patches
220
  attention_mask_length += n_last_frame_tokens
221
  else:
222
+ if kwargs.pop("show_progress", False):
223
+ postfix = f"Frame [{n_frames + 1}/{max_length // self.num_image_patches}]"
224
+ pbar.set_postfix(postfix)
225
  attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)
226
  # cut decoder_input_ids if past_key_values is used
227
  if past_key_values is not None and len(past_key_values) > 0: