File size: 819 Bytes
b3d4005 4aee592 b3d4005 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging, json, os
import floret
# logger = logging.getLogger(__name__)
def get_info(label_map):
    """Count how many labels each task defines.

    Args:
        label_map: mapping from task name to that task's collection of labels.

    Returns:
        dict: a new ``{task: number_of_labels}`` mapping, in the same key
        order as ``label_map``.
    """
    counts = {}
    for task_name, task_labels in label_map.items():
        counts[task_name] = len(task_labels)
    return counts
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates inference to a floret model.

    The floret model is loaded from the path stored in ``config.filename``
    when the wrapper is constructed; predictions are obtained through
    :meth:`predict`, which forwards to the floret model's own ``predict``.
    """

    def __init__(self, config):
        super().__init__(config)
        # super().__init__ also sets self.config; re-binding keeps it pointing
        # at exactly the object the caller passed in.
        self.config = config
        # NOTE(review): loads eagerly from config.filename — assumes that path
        # exists and is readable at construction time.
        self.model = floret.load_model(self.config.filename)

    def predict(self, text, k=1, threshold: Optional[float] = None):
        """Return the floret model's top-``k`` predictions for ``text``.

        Args:
            text: input forwarded unchanged to the floret model's ``predict``.
            k: number of predictions requested from the floret model.
            threshold: optional minimum-probability cut-off. It is forwarded
                to floret only when provided, so calls that omit it behave
                exactly as before.

        Returns:
            Whatever the floret model's ``predict`` returns (presumably
            labels paired with probabilities, as in fastText — confirm
            against the installed floret version).
        """
        if threshold is None:
            return self.model.predict(text, k)
        return self.model.predict(text, k, threshold=threshold)