czczup commited on
Commit
3212bee
·
verified ·
1 Parent(s): c443cca

fix compatibility issue for transformers 4.46+

Browse files
configuration_intern_vit.py CHANGED
@@ -3,6 +3,7 @@
3
  # Copyright (c) 2024 OpenGVLab
4
  # Licensed under The MIT License [see LICENSE for details]
5
  # --------------------------------------------------------
 
6
  import os
7
  from typing import Union
8
 
 
3
  # Copyright (c) 2024 OpenGVLab
4
  # Licensed under The MIT License [see LICENSE for details]
5
  # --------------------------------------------------------
6
+
7
  import os
8
  from typing import Union
9
 
configuration_internvl_chat.py CHANGED
@@ -47,12 +47,12 @@ class InternVLChatConfig(PretrainedConfig):
47
  logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
48
 
49
  self.vision_config = InternVisionConfig(**vision_config)
50
- if llm_config['architectures'][0] == 'LlamaForCausalLM':
51
  self.llm_config = LlamaConfig(**llm_config)
52
- elif llm_config['architectures'][0] == 'InternLM2ForCausalLM':
53
  self.llm_config = InternLM2Config(**llm_config)
54
  else:
55
- raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))
56
  self.use_backbone_lora = use_backbone_lora
57
  self.use_llm_lora = use_llm_lora
58
  self.select_layer = select_layer
 
47
  logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
48
 
49
  self.vision_config = InternVisionConfig(**vision_config)
50
+ if llm_config.get(['architectures'])[0] == 'LlamaForCausalLM':
51
  self.llm_config = LlamaConfig(**llm_config)
52
+ elif llm_config.get(['architectures'])[0] == 'InternLM2ForCausalLM':
53
  self.llm_config = InternLM2Config(**llm_config)
54
  else:
55
+ raise ValueError('Unsupported architecture: {}'.format(llm_config.get(['architectures'])[0]))
56
  self.use_backbone_lora = use_backbone_lora
57
  self.use_llm_lora = use_llm_lora
58
  self.select_layer = select_layer
modeling_internvl_chat.py CHANGED
@@ -3,6 +3,7 @@
3
  # Copyright (c) 2024 OpenGVLab
4
  # Licensed under The MIT License [see LICENSE for details]
5
  # --------------------------------------------------------
 
6
  import warnings
7
  from typing import Any, List, Optional, Tuple, Union
8
 
@@ -236,7 +237,7 @@ class InternVLChatModel(PreTrainedModel):
236
  model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
237
  input_ids = model_inputs['input_ids'].to(self.device)
238
  attention_mask = model_inputs['attention_mask'].to(self.device)
239
- eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
240
  generation_config['eos_token_id'] = eos_token_id
241
  generation_output = self.generate(
242
  pixel_values=pixel_values,
@@ -245,7 +246,7 @@ class InternVLChatModel(PreTrainedModel):
245
  **generation_config
246
  )
247
  responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
248
- responses = [response.split(template.sep)[0].strip() for response in responses]
249
  return responses
250
 
251
  def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
@@ -264,7 +265,7 @@ class InternVLChatModel(PreTrainedModel):
264
 
265
  template = get_conv_template(self.template)
266
  template.system_message = self.system_message
267
- eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
268
 
269
  history = [] if history is None else history
270
  for (old_question, old_answer) in history:
@@ -293,7 +294,7 @@ class InternVLChatModel(PreTrainedModel):
293
  **generation_config
294
  )
295
  response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
296
- response = response.split(template.sep)[0].strip()
297
  history.append((question, response))
298
  if return_history:
299
  return response, history
@@ -313,7 +314,6 @@ class InternVLChatModel(PreTrainedModel):
313
  visual_features: Optional[torch.FloatTensor] = None,
314
  generation_config: Optional[GenerationConfig] = None,
315
  output_hidden_states: Optional[bool] = None,
316
- return_dict: Optional[bool] = None,
317
  **generate_kwargs,
318
  ) -> torch.LongTensor:
319
 
@@ -341,7 +341,6 @@ class InternVLChatModel(PreTrainedModel):
341
  attention_mask=attention_mask,
342
  generation_config=generation_config,
343
  output_hidden_states=output_hidden_states,
344
- return_dict=return_dict,
345
  use_cache=True,
346
  **generate_kwargs,
347
  )
 
3
  # Copyright (c) 2024 OpenGVLab
4
  # Licensed under The MIT License [see LICENSE for details]
5
  # --------------------------------------------------------
6
+
7
  import warnings
8
  from typing import Any, List, Optional, Tuple, Union
9
 
 
237
  model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
238
  input_ids = model_inputs['input_ids'].to(self.device)
239
  attention_mask = model_inputs['attention_mask'].to(self.device)
240
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
241
  generation_config['eos_token_id'] = eos_token_id
242
  generation_output = self.generate(
243
  pixel_values=pixel_values,
 
246
  **generation_config
247
  )
248
  responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
249
+ responses = [response.split(template.sep.strip())[0].strip() for response in responses]
250
  return responses
251
 
252
  def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
 
265
 
266
  template = get_conv_template(self.template)
267
  template.system_message = self.system_message
268
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
269
 
270
  history = [] if history is None else history
271
  for (old_question, old_answer) in history:
 
294
  **generation_config
295
  )
296
  response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
297
+ response = response.split(template.sep.strip())[0].strip()
298
  history.append((question, response))
299
  if return_history:
300
  return response, history
 
314
  visual_features: Optional[torch.FloatTensor] = None,
315
  generation_config: Optional[GenerationConfig] = None,
316
  output_hidden_states: Optional[bool] = None,
 
317
  **generate_kwargs,
318
  ) -> torch.LongTensor:
319
 
 
341
  attention_mask=attention_mask,
342
  generation_config=generation_config,
343
  output_hidden_states=output_hidden_states,
 
344
  use_cache=True,
345
  **generate_kwargs,
346
  )