zRzRzRzRzRzRzR sixsixcoder commited on
Commit
2c22262
·
verified ·
1 Parent(s): 9ab2ccb

Update modeling_cogvlm.py (#6)

Browse files

- Update modeling_cogvlm.py (a65b36bac15240fbb712070071770a678cd0c082)


Co-authored-by: sixgod <[email protected]>

Files changed (1) hide show
  1. modeling_cogvlm.py +15 -4
modeling_cogvlm.py CHANGED
@@ -1,9 +1,11 @@
1
  """largely copy from llama and adapt for cogvlm"""
2
  import warnings
 
3
  from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
 
5
  import math
6
  import torch
 
7
  from torch import nn
8
  from torch.nn import CrossEntropyLoss
9
  from torchvision import transforms
@@ -26,7 +28,12 @@ logger = get_logger(__name__)
26
 
27
  LANGUAGE_TOKEN_TYPE = 0
28
  VISION_TOKEN_TYPE = 1
29
-
 
 
 
 
 
30
 
31
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
32
  def _make_causal_mask(
@@ -736,9 +743,13 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
736
  standardize_cache_format: bool = False,
737
  ) -> Dict[str, Any]:
738
  # update past_key_values
739
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
740
- outputs, standardize_cache_format=standardize_cache_format
741
- )
 
 
 
 
742
  if getattr(outputs, "state", None) is not None:
743
  model_kwargs["state"] = outputs.state
744
 
 
1
  """largely copy from llama and adapt for cogvlm"""
2
  import warnings
3
+ import packaging.version
4
  from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
5
 
6
  import math
7
  import torch
8
+ import transformers
9
  from torch import nn
10
  from torch.nn import CrossEntropyLoss
11
  from torchvision import transforms
 
28
 
29
  LANGUAGE_TOKEN_TYPE = 0
30
  VISION_TOKEN_TYPE = 1
31
+ TRANSFORMERS_ABOVE_441 = (
32
+ True
33
+ if packaging.version.parse(transformers.__version__)
34
+ >= packaging.version.parse("4.42.0")
35
+ else False
36
+ )
37
 
38
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
39
  def _make_causal_mask(
 
743
  standardize_cache_format: bool = False,
744
  ) -> Dict[str, Any]:
745
  # update past_key_values
746
+ if TRANSFORMERS_ABOVE_441:
747
+ cache_name, cache = self._extract_past_from_model_output(outputs)
748
+ model_kwargs[cache_name] = cache
749
+ else:
750
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
751
+ outputs, standardize_cache_format=standardize_cache_format
752
+ )
753
  if getattr(outputs, "state", None) is not None:
754
  model_kwargs["state"] = outputs.state
755