Gleb Vinarskis commited on
Commit
26d45f0
·
1 Parent(s): d58562c

changed pipeline

Browse files
Files changed (1) hide show
  1. impresso_langident_wrapper.py +10 -32
impresso_langident_wrapper.py CHANGED
@@ -1,37 +1,15 @@
1
- import floret # Assuming Floret is already installed
2
-
3
-
4
- class FloretLangIdentifier:
5
- def __init__(self, model_path):
6
- self.model = floret.load_model(model_path)
7
-
8
- def predict(self, text):
9
- predictions = self.model.predict(text)
10
- return predictions
11
-
12
-
13
-
14
-
15
-
16
-
17
  from transformers import Pipeline
18
-
19
-
20
- class MyPipeline(Pipeline):
21
  def _sanitize_parameters(self, **kwargs):
22
- preprocess_kwargs = {}
23
- if "maybe_arg" in kwargs:
24
- preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
25
- return preprocess_kwargs, {}, {}
26
 
27
- def preprocess(self, inputs, maybe_arg=2):
28
- return inputs
29
 
30
- def _forward(self, model_inputs):
31
- # model_inputs == {"model_input": model_input}
32
- outputs = self.model.predict_language(**model_inputs)
33
- # Maybe {"logits": Tensor(...)}
34
- return outputs
35
 
36
- def postprocess(self, model_outputs):
37
- return model_outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import Pipeline
2
+ class Pipeline_One(Pipeline):
 
 
3
  def _sanitize_parameters(self, **kwargs):
4
+ # Add any additional parameter handling if necessary
5
+ return kwargs, {}, {}
 
 
6
 
7
+ def preprocess(self, text, **kwargs):
8
+ return text
9
 
10
+ def _forward(self, inputs):
11
+ model_output = self.model.predict(inputs, k=1)
12
+ return model_output
 
 
13
 
14
+ def postprocess(self, outputs, **kwargs):
15
+ return outputs