Rcarvalo committed · Commit b1183e7 · verified · 1 Parent(s): 9685f7b

modernBert

Files changed (1): tasks/text.py (+23, -1)
tasks/text.py CHANGED
@@ -6,10 +6,11 @@ import random
 
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
+from transformers import AutoModelForSequenceClassification,AutoTokenizer
 
 router = APIRouter()
 
-DESCRIPTION = "Random Baseline"
+DESCRIPTION = "ModernBert Baseline"
 ROUTE = "/text"
 
 @router.post(ROUTE, tags=["Text Task"],
@@ -63,7 +64,28 @@ async def evaluate_text(request: TextEvaluationRequest):
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
     #--------------------------------------------------------------------------------------------
+    ## Model loading
+    model = AutoModelForSequenceClassification.from_pretrained("Rcarvalo/test_modernbert_finetuned_v2")
+    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
 
+    ## Data prep
+    def preprocess_function(df):
+        return tokenizer(df["quote"], truncation=True)
+    tokenized_test = test_dataset.map(preprocess_function, batched=True)
+
+    ## Modify inference model
+    training_args = torch.load("./tasks/utils/training_args.bin")
+    training_args.eval_strategy='no'
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer
+    )
+
+    ## prediction
+    preds = trainer.predict(tokenized_test)
+    predictions = np.array([np.argmax(x) for x in preds[0]])
 
     # Stop tracking emissions
     emissions_data = tracker.stop_task()
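
Note: the added block relies on torch, numpy (as np), and Trainer being available, and on test_dataset and ./tasks/utils/training_args.bin existing at inference time; none of those imports are visible in these hunks and may or may not appear elsewhere in tasks/text.py. Below is a minimal standalone sketch of the same inference flow with the imports spelled out. It is not the committed code: a TrainingArguments object is built inline in place of the pickled training_args.bin, and a toy test_dataset with a "quote" column stands in for the request's real test split.

    # Minimal standalone sketch of the inference flow added in this commit.
    # Assumptions: inline TrainingArguments instead of torch.load("./tasks/utils/training_args.bin"),
    # and a toy test_dataset standing in for the request's test split.
    import numpy as np
    from datasets import Dataset
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
    )

    # Model loading: fine-tuned checkpoint from the commit, tokenizer from the ModernBERT base model.
    model = AutoModelForSequenceClassification.from_pretrained("Rcarvalo/test_modernbert_finetuned_v2")
    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

    # Toy stand-in for the request's test split (illustrative quotes only).
    test_dataset = Dataset.from_dict({"quote": [
        "Climate change is just a natural cycle.",
        "Rising CO2 levels are warming the planet.",
    ]})

    # Data prep: tokenize the "quote" column; the Trainer's default collator pads per batch.
    def preprocess_function(df):
        return tokenizer(df["quote"], truncation=True)

    tokenized_test = test_dataset.map(preprocess_function, batched=True)

    # Inference-only arguments (no evaluation loop or checkpointing needed).
    training_args = TrainingArguments(output_dir="tmp_preds", per_device_eval_batch_size=16)

    trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)

    # Prediction: preds.predictions holds the logits; argmax over the class axis gives label ids,
    # equivalent to the per-row np.argmax loop in the diff.
    preds = trainer.predict(tokenized_test)
    predictions = np.argmax(preds.predictions, axis=1)
    print(predictions)

The commit itself instead restores the training-time configuration with torch.load on training_args.bin and sets eval_strategy='no', so the reloaded arguments can be reused purely for trainer.predict without triggering any evaluation schedule.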