thak123's picture
Create model.py
9251dbc
raw
history blame
967 Bytes
import transformers
from mtm import MultitaskModel
from config import model_name, drop_out
multitask_model = MultitaskModel.create(
model_name=model_name,
model_type_dict={
"document": transformers.AutoModelForSequenceClassification,
"paragraph": transformers.AutoModelForSequenceClassification,
"sentence": transformers.AutoModelForSequenceClassification,
},
model_config_dict={
"document": transformers.AutoConfig.from_pretrained(model_name, num_labels=3, hidden_dropout_prob=drop_out, attention_probs_dropout_prob=drop_out),
"paragraph": transformers.AutoConfig.from_pretrained(model_name, num_labels=3, hidden_dropout_prob=drop_out, attention_probs_dropout_prob=drop_out),
"sentence": transformers.AutoConfig.from_pretrained(model_name, num_labels=3, hidden_dropout_prob=drop_out, attention_probs_dropout_prob=drop_out),
},
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)