pos-french / train.py
qanastek's picture
Update README.md
854b8b0
raw
history blame
1.57 kB
import os
import argparse
from datetime import datetime
from flair.data import Corpus
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from flair.datasets import UniversalDependenciesCorpus
from flair.embeddings import WordEmbeddings, StackedEmbeddings
def parse_args(argv=None):
    """Parse the command-line options of this training script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``output`` (str, parent directory for
        training runs) and ``epochs`` (int, number of training epochs).

    Raises:
        SystemExit: via ``parser.error`` when ``-epochs`` is not positive,
            or on any other usage error detected by argparse.
    """
    parser = argparse.ArgumentParser(description='Flair Training Part-of-speech tagging')
    parser.add_argument('-output', type=str, default="models/", help='The output directory')
    parser.add_argument('-epochs', type=int, default=1, help='Number of Epochs')
    args = parser.parse_args(argv)
    # Training for zero or a negative number of epochs is meaningless; report
    # it as a normal usage error instead of starting an empty run.
    if args.epochs < 1:
        parser.error('-epochs must be a positive integer')
    return args


def build_output_path(output_dir, epochs, now=None):
    """Return the directory in which this training run stores its results.

    The name encodes the corpus, the epoch count and the start time, e.g.
    ``models/UPOS_UD_FRENCH_PLUS_10_2024-01-31-13-45-07``.

    Args:
        output_dir: Parent directory for all runs.
        epochs: Number of training epochs; embedded in the directory name.
        now: Timestamp to embed; defaults to the current local time.

    Returns:
        The joined path as a string.
    """
    if now is None:
        now = datetime.today()
    # Separate the time fields with '-' rather than ':'.
    # ':' is a reserved character in Windows file names.
    stamp = now.strftime('%Y-%m-%d-%H-%M-%S')
    return os.path.join(output_dir, f"UPOS_UD_FRENCH_PLUS_{epochs}_{stamp}")


def main(argv=None):
    """Train a Flair UPOS sequence tagger on the local UD_FRENCH_PLUS corpus.

    Args:
        argv: Optional argument list forwarded to ``parse_args``; ``None``
            means the process command line.
    """
    args = parse_args(argv)
    output = build_output_path(args.output, args.epochs)
    print(output)

    # Load the French GSD .conllu splits from a local folder instead of
    # using the UD_FRENCH() dataset loader.
    corpus: Corpus = UniversalDependenciesCorpus(
        data_folder='UD_FRENCH_PLUS',
        train_file="fr_gsd-ud-train.conllu",
        test_file="fr_gsd-ud-test.conllu",
        dev_file="fr_gsd-ud-dev.conllu",
    )

    # Universal part-of-speech tags are the prediction target.
    tag_type = 'upos'
    tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)

    embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=[WordEmbeddings('fr')])

    tagger: SequenceTagger = SequenceTagger(
        hidden_size=256,
        embeddings=embeddings,
        tag_dictionary=tag_dictionary,
        tag_type=tag_type,
        use_crf=True,
    )

    trainer: ModelTrainer = ModelTrainer(tagger, corpus)
    trainer.train(
        output,
        learning_rate=0.1,
        mini_batch_size=128,
        max_epochs=args.epochs,
    )


if __name__ == "__main__":
    main()