from transformers import Pipeline | |
from transformers.pipelines import PIPELINE_REGISTRY | |
import floret | |
from huggingface_hub import hf_hub_download | |
class Pipeline_One(Pipeline):
    """Custom pipeline that classifies raw text with ``self.model.predict``.

    The pipeline passes each input string to the wrapped model and returns
    the model's top-1 prediction unchanged.

    NOTE(review): earlier drafts of this class loaded the model in
    ``__init__`` via ``hf_hub_download(repo_id="Maslionok/pipeline1",
    filename="floret_model.bin", revision="main")`` followed by
    ``floret.load_model(model_path)``. That constructor is currently
    disabled, so ``self.model`` must be supplied by whoever builds the
    pipeline and must expose ``predict(text, k=...)`` -- presumably a
    Floret model; confirm against the caller.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments into per-stage parameters.

        No stage of this pipeline takes tunable parameters, so all extra
        keyword arguments are ignored.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple of three empty dicts.
        """
        # The input text already reaches ``preprocess`` positionally;
        # forwarding a ``text`` keyword as well would make that call fail
        # with "got multiple values for argument 'text'".
        return {}, {}, {}

    def preprocess(self, text, **kwargs):
        """Prepare one input for the model.

        Args:
            text: The raw input to classify.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``text`` with newlines replaced by spaces when it is a string,
            otherwise ``text`` unchanged.
        """
        if isinstance(text, str):
            # fastText-style ``predict`` handles one line at a time and
            # raises ValueError on embedded newlines.
            return text.replace("\n", " ")
        return text

    def _forward(self, inputs):
        """Run the model on the preprocessed input.

        Args:
            inputs: The value returned by ``preprocess``.

        Returns:
            Whatever ``self.model.predict(inputs, k=1)`` returns, i.e. the
            single best prediction.
        """
        # ``inputs`` is passed positionally: it is a plain string, not a
        # mapping, so it must not be unpacked with ``**``.
        return self.model.predict(inputs, k=1)

    def postprocess(self, outputs, **kwargs):
        """Return the model output unchanged.

        Args:
            outputs: The value returned by ``_forward``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``outputs`` as-is.
        """
        return outputs
# PIPELINE_REGISTRY.register_pipeline(
#     task="language-detection",
#     pipeline_class=Pipeline_One,
#     default={"model": None},
# )