# pipeline1/modeling_stacked.py
# Author: Gleb Vinarskis
# NOTE(review): the original lines here were residue scraped from the
# Hugging Face file viewer (commit 0464a8a, "debug", 921 Bytes);
# converted to this comment header so the module parses as Python.
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging, json, os
import floret
# logger = logging.getLogger(__name__)
def get_info(label_map):
    """Return a mapping from each task name to its number of token labels.

    Args:
        label_map: dict mapping a task name to its sequence of labels.

    Returns:
        dict mapping each task name to ``len(labels)`` for that task.
    """
    counts = {}
    for task_name, task_labels in label_map.items():
        counts[task_name] = len(task_labels)
    return counts
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """Expose a floret model behind the Hugging Face ``PreTrainedModel`` API.

    NOTE(review): despite the class name, this wrapper does not run a
    transformer token-classification head — it delegates entirely to a
    floret (fastText-style) model loaded from ``config.filename``.
    Confirm the name matches the intended use.
    """

    def __init__(self, config):
        """Load the floret model referenced by ``config.filename``.

        Args:
            config: a transformers config object; must carry a
                ``filename`` attribute pointing at the floret model file.
        """
        super().__init__(config)
        self.config = config
        # Removed leftover debug prints ("111…", "222…") that wrote to
        # stdout on every construction.
        self.model = floret.load_model(self.config.filename)
        # self.post_init() is intentionally not called: there are no
        # torch weights to initialize — floret owns the parameters.

    def predict(self, text, k=1):
        """Return the floret predictions for *text*.

        Args:
            text: input passed straight through to ``floret.predict``
                (presumably a string or list of strings — TODO confirm
                against floret's API).
            k: number of labels to return per input (default 1).

        Returns:
            Whatever ``floret``'s ``predict`` returns for ``(text, k)``.
        """
        return self.model.predict(text, k)