farzadab committed on
Commit
f807815
·
verified ·
1 Parent(s): 909cf24

Update ultravox_pipeline.py

Browse files
Files changed (1) hide show
  1. ultravox_pipeline.py +9 -4
ultravox_pipeline.py CHANGED
@@ -19,7 +19,7 @@ class UltravoxPipeline(transformers.Pipeline):
19
  ):
20
  if tokenizer is None:
21
  tokenizer = transformers.AutoTokenizer.from_pretrained(
22
- model.config._name_or_path
23
  )
24
 
25
  if audio_processor is None:
@@ -49,15 +49,20 @@ class UltravoxPipeline(transformers.Pipeline):
49
  if "turns" in inputs:
50
  turns = inputs["turns"]
51
  else:
 
 
 
52
  prompt = inputs.get("prompt", "<|audio|>")
53
  if "<|audio|>" not in prompt:
54
  logging.warning(
55
  "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
56
  )
57
  prompt += " <|audio|>"
58
- turns = [{"role": "user", "content": prompt}]
59
 
60
- text = self.processor.tokenizer.apply_chat_template(turns, tokenize=False)
 
 
61
 
62
  # TODO: allow text-only mode?
63
  assert "audio" in inputs, "Audio input is required"
@@ -113,4 +118,4 @@ transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
113
  pipeline_class=UltravoxPipeline,
114
  pt_model=transformers.AutoModel,
115
  type="multimodal",
116
- )
 
19
  ):
20
  if tokenizer is None:
21
  tokenizer = transformers.AutoTokenizer.from_pretrained(
22
+ model.config.text_config._name_or_path
23
  )
24
 
25
  if audio_processor is None:
 
49
  if "turns" in inputs:
50
  turns = inputs["turns"]
51
  else:
52
+ turns = []
53
+
54
+ if not turns or turns[-1]["role"] != "user":
55
  prompt = inputs.get("prompt", "<|audio|>")
56
  if "<|audio|>" not in prompt:
57
  logging.warning(
58
  "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
59
  )
60
  prompt += " <|audio|>"
61
+ turns.append({"role": "user", "content": prompt})
62
 
63
+ text = self.processor.tokenizer.apply_chat_template(
64
+ turns, add_generation_prompt=True, tokenize=False
65
+ )
66
 
67
  # TODO: allow text-only mode?
68
  assert "audio" in inputs, "Audio input is required"
 
118
  pipeline_class=UltravoxPipeline,
119
  pt_model=transformers.AutoModel,
120
  type="multimodal",
121
+ )