File size: 4,609 Bytes
c5cd944
 
 
6b22e1f
c5cd944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06cfca8
 
 
 
 
 
 
 
c5cd944
 
d5b56e4
 
c5cd944
 
a84cd54
 
c5cd944
d5b56e4
 
 
c5cd944
 
 
6b22e1f
 
c5cd944
 
 
6b22e1f
 
 
 
 
 
 
 
 
 
 
 
 
c5cd944
 
 
 
 
6b22e1f
c5cd944
f807815
c5cd944
f807815
 
 
c5cd944
6b22e1f
c5cd944
 
 
 
2eebed8
c5cd944
6b22e1f
c5cd944
 
2eebed8
 
 
 
c5cd944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b22e1f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import logging
from typing import Any, Dict, List, Optional

import numpy as np
import transformers

# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .ultravox_model import UltravoxModel
from .ultravox_processing import UltravoxProcessor


class UltravoxPipeline(transformers.Pipeline):
    """Pipeline that feeds chat turns plus optional audio to an Ultravox model.

    Inputs are dictionaries that may contain the keys ``"turns"`` (chat
    messages), ``"audio"`` (a waveform), ``"prompt"`` (text for the user turn
    that carries the audio) and ``"sampling_rate"``. The pipeline returns the
    decoded text generated by the model.
    """

    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Build the pipeline, loading any component that was not supplied.

        Args:
            model: The Ultravox model to run generation with.
            tokenizer: Tokenizer to use. When omitted it is loaded from the
                model's own path, falling back to the text model's path.
            audio_processor: Audio feature processor. When omitted it is
                loaded from the audio model id / audio config path.
            **kwargs: Forwarded unchanged to ``transformers.Pipeline``.
        """
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config._name_or_path
                )
            except Exception:
                # Best-effort fallback: the model path may not ship a
                # tokenizer, so use the underlying text model's instead.
                # ``Exception`` (not a bare ``except``) so KeyboardInterrupt
                # and SystemExit still propagate.
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config.text_model_id or model.config.text_config._name_or_path
                )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        self.processor = UltravoxProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess / forward / postprocess dicts.

        Only the generation options ``temperature``, ``max_new_tokens`` and
        ``repetition_penalty`` are kept, and they are routed to the second
        (forward) dict; every other keyword is dropped.
        """
        generation_keys = ("temperature", "max_new_tokens", "repetition_penalty")
        generation_kwargs = {k: v for k, v in kwargs.items() if k in generation_keys}
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Turn a raw input dict into model-ready tensors.

        Args:
            inputs: Dict with optional keys ``"turns"``, ``"audio"``,
                ``"prompt"`` and ``"sampling_rate"``.

        Returns:
            The output of ``UltravoxProcessor``; when it contains
            ``"audio_values"`` those are cast to the model's dtype.
        """
        # Work on a shallow copy so the caller's conversation list is never
        # mutated by the user turn appended below. A missing or ``None``
        # value is treated as an empty conversation.
        turns: list = list(inputs.get("turns") or [])

        audio = inputs.get("audio", None)
        # Convert to float32 if needed.
        if isinstance(audio, np.ndarray):
            if audio.dtype == np.float64:
                audio = audio.astype(np.float32)
            elif audio.dtype == np.int16:
                # Scale 16-bit PCM by 2**15.
                audio = audio.astype(np.float32) / np.float32(32768.0)
            elif audio.dtype == np.int32:
                # Scale 32-bit PCM by 2**31.
                audio = audio.astype(np.float32) / np.float32(2147483648.0)

        # Audio needs a user turn to attach to; create one unless the
        # conversation already ends with a user message.
        if audio is not None and (not turns or turns[-1]["role"] != "user"):
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )

                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        if "sampling_rate" not in inputs and audio is not None:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=audio,
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Run generation and return only the newly generated token ids.

        Args:
            model_inputs: Output of ``preprocess``; must contain
                ``"input_ids"``.
            temperature: Sampling temperature. ``None`` or any falsy value
                (e.g. ``0``) disables sampling.
            max_new_tokens: Passed through to ``generate``.
            repetition_penalty: Passed through to ``generate``.

        Returns:
            The first generated sequence with the prompt tokens sliced off.
        """
        # Normalize 0 / 0.0 to None so that sampling is switched off.
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on EOS, and also on "<|eot_id|>" when the tokenizer has it
        # as an added token.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators,
        )
        # Only the first sequence is returned, minus the prompt prefix.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode generated token ids to text, skipping special tokens."""
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text


# Import-time side effect: register UltravoxPipeline with the transformers
# pipeline registry under the task name "ultravox-pipeline", with AutoModel
# as the PyTorch model class and "multimodal" as the pipeline type.
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    pipeline_class=UltravoxPipeline,
    pt_model=transformers.AutoModel,
    type="multimodal",
)