yuzaa HwwwH committed on
Commit
34a6e7f
·
verified ·
1 Parent(s): 77225ff

Avoid duplicate input kwargs in `_decode` (#28)

Browse files

- Avoid duplicate input kwargs in `_decode` (18005e74b8257c981bb97dd4f350b06cd28f7aa6)
- avoid duplicate generate args (5d0120037703b4b70ec932f62ddb81e07b8b85c4)
- update modeling_minicpmo.py (cac55956a6efb7456cf5bbcad4e3e4f14d2e7ea9)


Co-authored-by: Zhihui He <[email protected]>

Files changed (1) hide show
  1. modeling_minicpmo.py +7 -1
modeling_minicpmo.py CHANGED
@@ -636,6 +636,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
636
  return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
637
 
638
  def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
 
 
639
  terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
640
  outputs = self.llm.generate(
641
  inputs_embeds=inputs_embeds,
@@ -777,6 +779,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
777
  tokenizer=None,
778
  vision_hidden_states=None,
779
  stream=False,
 
780
  **kwargs,
781
  ):
782
  assert input_ids is not None
@@ -814,7 +817,10 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
814
  outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
815
 
816
  result = self._decode_text(outputs.sequences, tokenizer)
817
-
 
 
 
818
  return result, outputs
819
 
820
  def chat(
 
636
  return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
637
 
638
  def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
639
+ kwargs.pop("output_hidden_states", None)
640
+ kwargs.pop("return_dict_in_generate", None)
641
  terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
642
  outputs = self.llm.generate(
643
  inputs_embeds=inputs_embeds,
 
779
  tokenizer=None,
780
  vision_hidden_states=None,
781
  stream=False,
782
+ decode_text=True,
783
  **kwargs,
784
  ):
785
  assert input_ids is not None
 
817
  outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
818
 
819
  result = self._decode_text(outputs.sequences, tokenizer)
820
+
821
+ if decode_text is False:
822
+ return outputs
823
+
824
  return result, outputs
825
 
826
  def chat(