File size: 883 Bytes
e7e3b60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
from typing import Dict, List, Any
from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX
class PreTrainedPipeline:
    """Text-in / text-out wrapper around a ``punctuators`` ONNX model.

    The interface (construct with a model directory, call with a string, get
    back ``[{"generated_text": ...}]``) appears to follow the Hugging Face
    custom ``pipeline.py`` convention for the text-generation widget.
    NOTE(review): confirm against the hosting environment.
    """

    # Artifact filenames expected inside the model directory. They are the
    # historical hard-coded values and can be overridden per instance via
    # the keyword-only arguments of ``__init__``.
    DEFAULT_SPE_FILENAME: str = "spe_32k_lc_en.model"
    DEFAULT_MODEL_FILENAME: str = "punct_cap_seg_en.onnx"
    DEFAULT_CONFIG_FILENAME: str = "config.yaml"

    # Separator placed between the model's output strings. This is a literal
    # backslash followed by "n" (not a newline character). It is a workaround
    # because the text-generation widget could not be made to render multiple
    # lines.
    _LINE_SEPARATOR: str = " \\n "

    def __init__(
        self,
        path: str,
        *,
        spe_filename: str = DEFAULT_SPE_FILENAME,
        model_filename: str = DEFAULT_MODEL_FILENAME,
        config_filename: str = DEFAULT_CONFIG_FILENAME,
    ) -> None:
        """Load the ONNX model and its tokenizer/config from ``path``.

        Args:
            path: Directory that contains the model artifacts.
            spe_filename: Name of the SentencePiece model file inside ``path``.
            model_filename: Name of the ONNX model file inside ``path``.
            config_filename: Name of the YAML config file inside ``path``.
        """
        cfg: PunctCapSegConfigONNX = PunctCapSegConfigONNX(
            directory=path,
            spe_filename=spe_filename,
            model_filename=model_filename,
            config_filename=config_filename,
        )
        self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(cfg)

    def __call__(self, data: str) -> List[Dict[str, str]]:
        """Run the model on a single input text.

        Args:
            data: The raw input text.

        Returns:
            A one-element list holding ``{"generated_text": <str>}``. The value
            is the model's output strings for ``data``, joined with the
            literal ``" \\n "`` separator (backslash + "n", not a newline).
        """
        # ``infer`` takes a batch of texts, so wrap the input in a batch of
        # size 1 and read back that batch's only result.
        pred_texts: List[List[str]] = self._punctuator.infer([data])
        generated = self._LINE_SEPARATOR.join(pred_texts[0])
        return [{"generated_text": generated}]
|