Gleb Vinarskis commited on
Commit
c746e15
·
1 Parent(s): ff3ca74

loading model

Browse files
Files changed (1) hide show
  1. impresso_langident_wrapper.py +9 -1
impresso_langident_wrapper.py CHANGED
@@ -1,9 +1,17 @@
1
  from transformers import Pipeline
2
  from transformers.pipelines import PIPELINE_REGISTRY
 
3
 
4
 
5
 
6
  class Pipeline_One(Pipeline):
 
 
 
 
 
 
 
7
  def _sanitize_parameters(self, **kwargs):
8
  # Add any additional parameter handling if necessary
9
  return {}, {}, {}
@@ -12,7 +20,7 @@ class Pipeline_One(Pipeline):
12
  return text
13
 
14
  def _forward(self, inputs):
15
- model_output = self.model(**inputs, k=1)
16
 
17
  return model_output
18
 
 
1
  from transformers import Pipeline
2
  from transformers.pipelines import PIPELINE_REGISTRY
3
+ import floret
4
 
5
 
6
 
7
  class Pipeline_One(Pipeline):
8
+ def __init__(self, model_path, **kwargs):
9
+ super().__init__(**kwargs) # Call the base class constructor
10
+
11
+ # Load the Floret model
12
+ self.model = floret.load_model(model_path)
13
+
14
+
15
  def _sanitize_parameters(self, **kwargs):
16
  # Add any additional parameter handling if necessary
17
  return {}, {}, {}
 
20
  return text
21
 
22
  def _forward(self, inputs):
23
+ model_output = self.model.predict(**inputs, k=1)
24
 
25
  return model_output
26