File size: 4,609 Bytes
c5cd944 6b22e1f c5cd944 06cfca8 c5cd944 d5b56e4 c5cd944 a84cd54 c5cd944 d5b56e4 c5cd944 6b22e1f c5cd944 6b22e1f c5cd944 6b22e1f c5cd944 f807815 c5cd944 f807815 c5cd944 6b22e1f c5cd944 2eebed8 c5cd944 6b22e1f c5cd944 2eebed8 c5cd944 6b22e1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import logging
from typing import Any, Dict, List, Optional
import numpy as np
import transformers
# We must use relative imports in this directory to allow uploading to the HF Hub.
# Even the "from . import X" pattern doesn't work (undocumented, and unclear why).
from .ultravox_model import UltravoxModel
from .ultravox_processing import UltravoxProcessor
class UltravoxPipeline(transformers.Pipeline):
    """Pipeline that runs an ``UltravoxModel`` on a chat history plus optional audio.

    ``preprocess`` renders the conversation turns through the tokenizer's chat
    template and feeds the text and audio to an ``UltravoxProcessor``;
    ``_forward`` calls ``model.generate`` and keeps only the newly generated
    tokens; ``postprocess`` decodes those tokens to a string.
    """

    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Build the pipeline, loading a tokenizer/audio processor when omitted.

        Args:
            model: The Ultravox model to run. Its ``config`` supplies the
                fallback tokenizer / audio-processor identifiers and the
                ``stack_factor`` handed to ``UltravoxProcessor``.
            tokenizer: Tokenizer to use. When ``None``, it is loaded from
                ``model.config._name_or_path``; if that fails, from
                ``model.config.text_model_id`` or, when that is falsy,
                ``model.config.text_config._name_or_path``.
            audio_processor: Audio processor to use. When ``None``, it is
                loaded from ``model.config.audio_model_id`` or, when that is
                falsy, ``model.config.audio_config._name_or_path``.
            **kwargs: Forwarded unchanged to ``transformers.Pipeline.__init__``.
        """
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config._name_or_path
                )
            except Exception:
                # Best-effort fallback to the text model's tokenizer. Only
                # ``Exception`` is caught so KeyboardInterrupt / SystemExit
                # still propagate instead of triggering a second load.
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config.text_model_id or model.config.text_config._name_or_path
                )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        self.processor = UltravoxProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only ``temperature``, ``max_new_tokens`` and ``repetition_penalty`` are
        kept, and all of them are routed to ``_forward``; every other keyword
        is dropped.
        """
        generation_keys = ["temperature", "max_new_tokens", "repetition_penalty"]
        generation_kwargs = {k: kwargs[k] for k in kwargs if k in generation_keys}
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Turn a request dict into processor outputs ready for ``_forward``.

        Keys read from ``inputs``:
            ``turns``: optional list of chat messages (dicts with ``role`` and
                ``content``). The caller's list is not modified.
            ``audio``: optional audio. NumPy arrays of dtype float64, int16 or
                int32 are converted to float32 (integers are scaled by their
                full-scale value); anything else is passed through unchanged.
            ``prompt``: text for the user turn that is synthesized when audio
                is given and the last turn is not a user turn. Defaults to
                ``"<|audio|>"``.
            ``sampling_rate``: sampling rate of ``audio``; defaults to 16000.

        Returns:
            The output of ``self.processor``; if it contains ``audio_values``,
            that entry is cast to ``self.model.dtype``.
        """
        # Work on a shallow copy so appending the synthesized user turn below
        # does not leak into the caller's conversation history.
        turns: list = list(inputs.get("turns") or [])

        audio = inputs.get("audio", None)
        # Convert to float32 if needed.
        if isinstance(audio, np.ndarray):
            if audio.dtype == np.float64:
                audio = audio.astype(np.float32)
            elif audio.dtype == np.int16:
                # Scale by 2**15 so int16 samples land in [-1, 1).
                audio = audio.astype(np.float32) / np.float32(32768.0)
            elif audio.dtype == np.int32:
                # Scale by 2**31 so int32 samples land in [-1, 1).
                audio = audio.astype(np.float32) / np.float32(2147483648.0)

        # Audio needs a user turn to attach to; add one unless the
        # conversation already ends with a user message.
        if audio is not None and (len(turns) == 0 or turns[-1]["role"] != "user"):
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        if "sampling_rate" not in inputs and audio is not None:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=audio,
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Generate a continuation and return only the newly produced tokens.

        Args:
            model_inputs: Output of ``preprocess``; must contain ``input_ids``.
            temperature: Sampling temperature. ``None`` or any falsy value
                (e.g. ``0``) disables sampling.
            max_new_tokens: Forwarded to ``model.generate``.
            repetition_penalty: Forwarded to ``model.generate``.

        Returns:
            The first generated sequence with the prompt tokens sliced off.
        """
        # A temperature of 0 is treated the same as "not set": no sampling.
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on EOS, and also on "<|eot_id|>" when the tokenizer defines it
        # as an added token.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators,
        )
        # Drop the prompt so only the generated tokens are decoded later.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode the generated tokens to text, skipping special tokens."""
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text
# Module-level side effect: when this module is imported, register
# UltravoxPipeline in transformers' pipeline registry under the task name
# "ultravox-pipeline", with AutoModel as the PyTorch model class and the
# pipeline type set to "multimodal".
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
"ultravox-pipeline",
pipeline_class=UltravoxPipeline,
pt_model=transformers.AutoModel,
type="multimodal",
)
|