Gleb Vinarskis committed on
Commit
90328f3
·
1 Parent(s): fe1a4ae
Files changed (2) hide show
  1. config.json +1 -1
  2. impresso_langident_wrapper.py +20 -33
config.json CHANGED
@@ -18,7 +18,7 @@
18
  "custom_pipelines": {
19
  "language-detection": {
20
  "impl": "impresso_langident_wrapper.Pipeline_One",
21
- "pt": ["AutoModel"],
22
  "tf": []
23
  }
24
  }
 
18
  "custom_pipelines": {
19
  "language-detection": {
20
  "impl": "impresso_langident_wrapper.Pipeline_One",
21
+ "pt": ["DummyModel"],
22
  "tf": []
23
  }
24
  }
impresso_langident_wrapper.py CHANGED
@@ -5,48 +5,35 @@ from huggingface_hub import hf_hub_download
5
 
6
 
7
 
8
- class Pipeline_One(Pipeline):
9
- # def __init__(self, model_path: str):
10
- # """
11
- # Initialize the Floret language detection pipeline
12
-
13
- # Args:
14
- # model_path (str): Path to the .bin model file
15
- # """
16
- # super().__init__()
17
- # self.model = floret.FastText.load_model(model_path)
18
-
19
- def __init__(self, model_name="floret_model.bin", repo_id="Maslionok/pipeline1", revision="main", **kwargs):
20
  """
21
- Initialize the Floret language detection pipeline.
22
- Args:
23
- model_name (str): The name of the Floret model file.
24
- repo_id (str): The Hugging Face repository ID.
25
- revision (str): The branch/revision to download from.
26
  """
27
- super().__init__(**kwargs) # ✅ Call parent constructor
28
 
29
- # Manually download the Floret model
30
- model_path = hf_hub_download(repo_id=repo_id, filename=model_name, revision=revision)
31
-
32
- # ✅ Load the Floret model
33
- self.model = floret.load_model(model_path)
34
-
35
- def _sanitize_parameters(self, **kwargs):
36
- # Add any additional parameter handling if necessary
37
- return {}, {}, {}
38
 
39
  def preprocess(self, text, **kwargs):
40
- return text
41
 
42
  def _forward(self, inputs):
43
- model_output = self.model.predict(**inputs, k=1)
44
-
45
- return model_output
46
 
47
  def postprocess(self, outputs, **kwargs):
48
- return outputs
49
-
50
 
51
 
52
  PIPELINE_REGISTRY.register_pipeline(
 
5
 
6
 
7
 
8
+ from transformers import Pipeline
9
+ import floret
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ class Pipeline_One(Pipeline):
13
+ def __init__(self, model=None, model_name="floret_model.bin", repo_id="Maslionok/pipeline1", revision="main", **kwargs):
 
 
 
 
 
 
14
  """
15
+ Custom pipeline to manually load a Floret model.
 
 
 
 
16
  """
17
+ super().__init__(**kwargs) # ✅ Ensures Hugging Face registers this correctly
18
 
19
+ if model is None:
20
+ # Manually download the Floret model
21
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_name, revision=revision)
22
+ self.model = floret.load_model(model_path) # ✅ Load Floret manually
23
+ else:
24
+ self.model = model # If manually passed, use existing model
 
 
 
25
 
26
  def preprocess(self, text, **kwargs):
27
+ return {"text": text} # ✅ Prepare text for Floret
28
 
29
  def _forward(self, inputs):
30
+ text = inputs["text"]
31
+ predictions = self.model.predict(text, k=1) # ✅ Get prediction from Floret
32
+ return predictions
33
 
34
  def postprocess(self, outputs, **kwargs):
35
+ return outputs # ✅ Return Floret’s output as-is
36
+
37
 
38
 
39
  PIPELINE_REGISTRY.register_pipeline(