File size: 883 Bytes
e7e3b60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from typing import Dict, List, Any

from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX


class PreTrainedPipeline:
    """Inference pipeline wrapping ``punctuators``' ``PunctCapSegModelONNX``.

    Each call runs the ONNX model on a single input string. The result is
    returned in the ``[{"generated_text": ...}]`` shape consumed by the
    text-generation widget.
    """

    # The text-generation widget cannot render multiple lines, so the model's
    # output segments are joined with a literal backslash followed by "n"
    # (surrounded by spaces) rather than with a real newline character.
    SENTENCE_SEPARATOR: str = " \\n "

    def __init__(
        self,
        path: str,
        *,
        spe_filename: str = "spe_32k_lc_en.model",
        model_filename: str = "punct_cap_seg_en.onnx",
        config_filename: str = "config.yaml",
    ):
        """Build the ONNX model from the artifacts stored in *path*.

        Args:
            path: Directory holding the model artifacts.
            spe_filename: SentencePiece model file name inside *path*.
            model_filename: ONNX model file name inside *path*.
            config_filename: Config file name inside *path*.

        The keyword-only defaults are the original hard-coded English model
        files. Overriding them lets the same pipeline load other exports
        without editing this class.
        """
        cfg: PunctCapSegConfigONNX = PunctCapSegConfigONNX(
            directory=path,
            spe_filename=spe_filename,
            model_filename=model_filename,
            config_filename=config_filename,
        )
        self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(cfg)

    def __call__(self, data: str) -> List[Dict]:
        """Run the model on *data* and format the result for the widget.

        Args:
            data: The raw input text.

        Returns:
            A one-element list ``[{"generated_text": text}]``. Here ``text`` is
            the model's output segments for *data* (presumably one per
            sentence; confirm against ``punctuators``) joined with
            ``SENTENCE_SEPARATOR``.
        """
        # infer() takes a batch, so wrap the single input in a list.
        pred_texts: List[List[str]] = self._punctuator.infer([data])
        # Only one input was submitted, so only the first batch entry is used.
        segments: List[str] = pred_texts[0]
        return [{"generated_text": self.SENTENCE_SEPARATOR.join(segments)}]