|
from transformers.modeling_outputs import TokenClassifierOutput |
|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig |
|
from torch.nn import CrossEntropyLoss |
|
from typing import Optional, Tuple, Union |
|
import logging, json, os |
|
import floret |
|
|
|
|
|
|
|
|
|
|
|
def get_info(label_map):
    """Count the labels available for each task.

    Args:
        label_map: Mapping from a task name to the collection of labels
            defined for that task.

    Returns:
        A dict with the same keys as ``label_map`` (in the same iteration
        order), where each value is ``len()`` of the corresponding labels.
    """
    num_token_labels_dict = {}
    for task_name, task_labels in label_map.items():
        num_token_labels_dict[task_name] = len(task_labels)
    return num_token_labels_dict
|
|
|
|
|
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """``PreTrainedModel`` subclass that wraps a floret model.

    The floret model is loaded at construction time from the path stored in
    ``config.filename`` and is exposed through :meth:`predict`.
    """

    def __init__(self, config):
        """Initialise the wrapper and load the underlying floret model.

        Args:
            config: Model configuration. It must provide a ``filename``
                attribute holding the path passed to ``floret.load_model``.
        """
        super().__init__(config)
        self.config = config

        # Progress is reported through logging rather than bare print()
        # markers, so callers control verbosity and stdout stays clean.
        logger = logging.getLogger(__name__)
        logger.debug("Loading floret model from %s", self.config.filename)
        self.model = floret.load_model(self.config.filename)
        logger.debug("Floret model loaded from %s", self.config.filename)

    def predict(self, text, k=1):
        """Run the wrapped floret model on ``text``.

        Args:
            text: Input forwarded unchanged to the floret model's ``predict``.
            k: Forwarded as the second positional argument of the floret
                ``predict`` call (presumably the number of top labels to
                return -- confirm against the floret API). Defaults to 1.

        Returns:
            Whatever the floret model's ``predict`` returns, unmodified.
        """
        predictions = self.model.predict(text, k)
        return predictions