Gleb Vinarskis
commited on
Commit
·
c746e15
1
Parent(s):
ff3ca74
loading model
Browse files
impresso_langident_wrapper.py
CHANGED
@@ -1,9 +1,17 @@
|
|
1 |
from transformers import Pipeline
|
2 |
from transformers.pipelines import PIPELINE_REGISTRY
|
|
|
3 |
|
4 |
|
5 |
|
6 |
class Pipeline_One(Pipeline):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
def _sanitize_parameters(self, **kwargs):
|
8 |
# Add any additional parameter handling if necessary
|
9 |
return {}, {}, {}
|
@@ -12,7 +20,7 @@ class Pipeline_One(Pipeline):
|
|
12 |
return text
|
13 |
|
14 |
def _forward(self, inputs):
|
15 |
-
model_output = self.model(**inputs, k=1)
|
16 |
|
17 |
return model_output
|
18 |
|
|
|
1 |
from transformers import Pipeline
|
2 |
from transformers.pipelines import PIPELINE_REGISTRY
|
3 |
+
import floret
|
4 |
|
5 |
|
6 |
|
7 |
class Pipeline_One(Pipeline):
|
8 |
+
def __init__(self, model_path, **kwargs):
|
9 |
+
super().__init__(**kwargs) # Call the base class constructor
|
10 |
+
|
11 |
+
# Load the Floret model
|
12 |
+
self.model = floret.load_model(model_path)
|
13 |
+
|
14 |
+
|
15 |
def _sanitize_parameters(self, **kwargs):
|
16 |
# Add any additional parameter handling if necessary
|
17 |
return {}, {}, {}
|
|
|
20 |
return text
|
21 |
|
22 |
def _forward(self, inputs):
|
23 |
+
model_output = self.model.predict(**inputs, k=1)
|
24 |
|
25 |
return model_output
|
26 |
|