# pipeline1 / impresso_langident_wrapper.py
# Author: Gleb Vinarskis
# Commit message: debug
# Commit: 65de753
from transformers import Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
import floret
from huggingface_hub import hf_hub_download
class Pipeline_One(Pipeline):
    """Language-identification pipeline that delegates to ``self.model``.

    The class implements the stage hooks of ``transformers.Pipeline``
    (``_sanitize_parameters``, ``preprocess``, ``_forward``,
    ``postprocess``). The only requirement it places on ``self.model`` is
    a ``predict(text, k=...)`` method.

    NOTE: earlier revisions loaded a floret model in a custom ``__init__``
    (``hf_hub_download`` followed by ``floret.load_model``). That code was
    disabled, so this class does not assign ``self.model`` itself; it
    relies on whatever the base ``Pipeline`` constructor is given.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments between the pipeline stages.

        Args:
            **kwargs: Keyword arguments supplied when the pipeline is
                constructed or called.

        Returns:
            A 3-tuple ``(preprocess_kwargs, forward_kwargs,
            postprocess_kwargs)``. Only ``"text"`` is forwarded, and only
            to ``preprocess``; the other two dictionaries are always empty.
        """
        # NOTE(review): the base Pipeline appears to pass the input to
        # ``preprocess`` positionally already, so forwarding ``text`` as a
        # keyword may collide with it -- confirm before relying on it.
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        """Return the input unchanged; the model consumes raw text.

        Args:
            text: The input handed to the pipeline.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``text`` exactly as received.
        """
        return text

    def _forward(self, inputs):
        """Run the model on the preprocessed input, keeping the top label.

        Args:
            inputs: The value returned by ``preprocess``. Normally the raw
                text; a ``dict`` of keyword arguments for
                ``self.model.predict`` is also accepted.

        Returns:
            Whatever ``self.model.predict(..., k=1)`` returns (presumably a
            ``(labels, probabilities)`` pair, as in fastText -- confirm).
        """
        # ``preprocess`` returns the text itself, so it cannot be unpacked
        # with ``**`` (that raises TypeError for a str). Only mapping-style
        # payloads are expanded into keyword arguments.
        if isinstance(inputs, dict):
            return self.model.predict(**inputs, k=1)
        return self.model.predict(inputs, k=1)

    def postprocess(self, outputs, **kwargs):
        """Return the model output as-is.

        Args:
            outputs: The value returned by ``_forward``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``outputs`` unchanged.
        """
        return outputs
# PIPELINE_REGISTRY.register_pipeline(
# task="language-detection",
# pipeline_class=Pipeline_One,
# default={"model": None},
# )