File size: 819 Bytes
b3d4005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aee592
b3d4005
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging, json, os
import floret


# logger = logging.getLogger(__name__)


def get_info(label_map):
    """Count how many labels each task defines.

    Args:
        label_map: Mapping from task name to that task's collection of
            labels (anything supporting ``len``).

    Returns:
        dict: Task name -> number of labels for that task, built in the
        same iteration order as ``label_map.items()``.
    """
    label_counts = {}
    for task_name, task_labels in label_map.items():
        label_counts[task_name] = len(task_labels)
    return label_counts


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates prediction to a floret model.

    No torch layers are defined here. The constructor loads a floret model
    from ``config.filename``, and ``predict`` forwards calls to that model.
    """

    def __init__(self, config):
        """Initialise the base model, then load the floret model.

        Args:
            config: Model configuration. It must expose a ``filename``
                attribute, which is handed to ``floret.load_model``.
                NOTE(review): presumably a path to a trained floret model
                file; confirm against the config shipped with the model.
        """
        super().__init__(config)
        # Re-bind the exact config object passed in. The base constructor
        # may already set ``self.config``; this makes the final value
        # explicit.
        self.config = config

        model_file = self.config.filename
        self.model = floret.load_model(model_file)

    def predict(self, text, k=1):
        """Run the loaded floret model on ``text``.

        Args:
            text: Input forwarded unchanged to the floret model's ``predict``.
            k: Forwarded as the second positional argument of the floret
                model's ``predict`` (the top-k label count in the
                fastText-style API).

        Returns:
            The floret model's ``predict`` result, returned unmodified.
        """
        return self.model.predict(text, k)