from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging, json, os

import floret

# logger = logging.getLogger(__name__)


def get_info(label_map):
    """Return the number of labels per task from a {task: labels} mapping."""
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Load the floret (fastText-style) binary referenced by the config.
        self.model = floret.load_model(self.config.filename)
        print("We loaded the model")

    def predict(self, text, k=1):
        # Return the top-k floret predictions (labels and probabilities) for `text`.
        predictions = self.model.predict(text, k)
        return predictions
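

# Below is a minimal usage sketch, not part of the original module. It assumes a
# config class that exposes a `filename` attribute pointing to a floret binary;
# the `FloretConfig` name and the model path are illustrative placeholders only,
# and the repository may define its own config class instead.
if __name__ == "__main__":
    from transformers import PretrainedConfig

    class FloretConfig(PretrainedConfig):
        # Hypothetical config for illustration; holds the path to the floret binary.
        model_type = "floret"

        def __init__(self, filename="model.bin", **kwargs):
            self.filename = filename
            super().__init__(**kwargs)

    config = FloretConfig(filename="path/to/floret_model.bin")  # placeholder path
    model = ExtendedMultitaskModelForTokenClassification(config)
    # floret processes one line of text at a time and returns (labels, probabilities).
    print(model.predict("Example sentence to classify.", k=3))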