import os
src="KoichiYasuoka/deberta-base-japanese-aozora"
tgt="KoichiYasuoka/deberta-base-japanese-aozora-ud-head"
url="https://github.com/UniversalDependencies/UD_Japanese-GSDLUW"

# Fetch the UD_Japanese-GSDLUW treebank and copy its train/dev/test CoNLL-U files
d=os.path.basename(url)
os.system("test -d {} || git clone --depth=1 {}".format(d,url))
os.system("for F in train dev test ; do cp "+d+"/*-$F*.conllu $F.conllu ; done")

from transformers import (AutoTokenizer,AutoModelForQuestionAnswering,
  AutoModelForTokenClassification,AutoConfig,DefaultDataCollator,
  DataCollatorForTokenClassification,TrainingArguments,Trainer)

# Dependency-head dataset in SQuAD style: for each word the "question" is
# [CLS] word [SEP], the "context" is the sentence with that word replaced by
# [MASK], and the answer span is the word's UD head.
class HEADDataset(object):
  def __init__(self,conllu,tokenizer,augment=False,length=384):
    self.qa,self.pad,self.length=[],tokenizer.pad_token_id,length
    with open(conllu,"r",encoding="utf-8") as r:
      form,head=[],[]
      for t in r:
        w=t.split("\t")
        if len(w)==10 and w[0].isdecimal():
          form.append(w[1])
          # the root (HEAD=0) points at itself; otherwise store the 0-based head index
          head.append(len(head) if w[6]=="0" else int(w[6])-1)
        elif t.strip()=="" and form!=[]:
          v=tokenizer(form,add_special_tokens=False)["input_ids"]
          for i,t in enumerate(v):
            q=[tokenizer.cls_token_id]+t+[tokenizer.sep_token_id]
            c=[q]+v[0:i]+[[tokenizer.mask_token_id]]+v[i+1:]+[[q[-1]]]
            b=[len(sum(c[0:j+1],[])) for j in range(len(c))]  # cumulative word offsets
            if b[-1]<length:
              self.qa.append((sum(c,[]),head[i],b))
              # augmentation: if the target word's token sequence is unique in
              # the sentence, also keep an unmasked copy of the context
              if augment and [1 for x in v if t==x]==[1]:
                c[i+1]=t
                b=[len(sum(c[0:j+1],[])) for j in range(len(c))]
                if b[-1]<length:
                  self.qa.append((sum(c,[]),head[i],b))
          form,head=[],[]
  __len__=lambda self:len(self.qa)
  def __getitem__(self,i):
    (v,h,b),k=self.qa[i],self.length-self.qa[i][2][-1]
    return {"input_ids":v+[self.pad]*k,"attention_mask":[1]*b[-1]+[0]*k,
      "token_type_ids":[0]*b[0]+[1]*(b[-1]-b[0])+[0]*k,
      "start_positions":b[h],"end_positions":b[h+1]-1}

# Token-classification dataset: each word's label (selected CoNLL-U columns,
# joined by "|") is spread over its subword tokens as B-/I- tags.
class UPOSDataset(object):
  def __init__(self,conllu,tokenizer,fields=[3]):
    self.ids,self.upos=[],[]
    label,cls,sep=set(),tokenizer.cls_token_id,tokenizer.sep_token_id
    with open(conllu,"r",encoding="utf-8") as r:
      form,upos=[],[]
      for t in r:
        w=t.split("\t")
        if len(w)==10 and w[0].isdecimal():
          form.append(w[1])
          upos.append("|".join(w[i] for i in fields))
        elif t.strip()=="" and form!=[]:
          v,u=tokenizer(form,add_special_tokens=False)["input_ids"],[]
          for x,y in zip(v,upos):
            u.extend(["B-"+y]*min(len(x),1)+["I-"+y]*(len(x)-1))
          if len(u)>tokenizer.model_max_length-4:
            # overlong sentence: truncate and drop the special tokens
            self.ids.append(sum(v,[])[0:tokenizer.model_max_length-2])
            self.upos.append(u[0:tokenizer.model_max_length-2])
          elif len(u)>0:
            self.ids.append([cls]+sum(v,[])+[sep])
            self.upos.append([u[0]]+u+[u[0]])
          label=set(sum([self.upos[-1],list(label)],[]))
          form,upos=[],[]
    self.label2id={l:i for i,l in enumerate(sorted(label))}
  def __call__(*args):
    # merge the label sets of several datasets into one shared label2id
    label=set(sum([list(t.label2id) for t in args],[]))
    lid={l:i for i,l in enumerate(sorted(label))}
    for t in args:
      t.label2id=lid
    return lid
  __len__=lambda self:len(self.ids)
  __getitem__=lambda self,i:{"input_ids":self.ids[i],
    "labels":[self.label2id[t] for t in self.upos[i]]}

# Fine-tune the question-answering model that predicts dependency heads
tkz=AutoTokenizer.from_pretrained(src)
trainDS=HEADDataset("train.conllu",tkz,True)
devDS=HEADDataset("dev.conllu",tkz)
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=8,
  output_dir="/tmp",overwrite_output_dir=True,save_total_limit=2,
  evaluation_strategy="epoch",learning_rate=5e-05,warmup_ratio=0.1)
trn=Trainer(args=arg,data_collator=DefaultDataCollator(),
  model=AutoModelForQuestionAnswering.from_pretrained(src),
  train_dataset=trainDS,eval_dataset=devDS)
trn.train()
trn.save_model(tgt)
tkz.save_pretrained(tgt)

# Fine-tune the DEPREL (column 7) token classifier
trainDS=UPOSDataset("train.conllu",tkz,[7])
devDS=UPOSDataset("dev.conllu",tkz,[7])
testDS=UPOSDataset("test.conllu",tkz,[7])
lid=trainDS(devDS,testDS)
cfg=AutoConfig.from_pretrained(src,num_labels=len(lid),label2id=lid,
  id2label={i:l for l,i in lid.items()})
trn=Trainer(args=arg,data_collator=DataCollatorForTokenClassification(tkz),
  model=AutoModelForTokenClassification.from_pretrained(src,config=cfg),
  train_dataset=trainDS,eval_dataset=devDS)
trn.train()
trn.save_model(tgt+"/deprel")
tkz.save_pretrained(tgt+"/deprel")

# Fine-tune the UPOS+FEATS (columns 3 and 5) token classifier
trainDS=UPOSDataset("train.conllu",tkz,[3,5])
devDS=UPOSDataset("dev.conllu",tkz,[3,5])
testDS=UPOSDataset("test.conllu",tkz,[3,5])
lid=trainDS(devDS,testDS)
cfg=AutoConfig.from_pretrained(src,num_labels=len(lid),label2id=lid,
  id2label={i:l for l,i in lid.items()})
trn=Trainer(args=arg,data_collator=DataCollatorForTokenClassification(tkz),
  model=AutoModelForTokenClassification.from_pretrained(src,config=cfg),
  train_dataset=trainDS,eval_dataset=devDS)
trn.train()
trn.save_model(tgt+"/tagger")
tkz.save_pretrained(tgt+"/tagger")
|
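# ---------------------------------------------------------------------------
# Rough usage sketch (not part of the training run above; the sentence and the
# word index below are made-up examples): the model saved in tgt is queried in
# the same SQuAD-style layout that HEADDataset builds, i.e. the target word as
# the "question" and the sentence with that word masked as the "context"; the
# predicted answer span is the head word.
import torch
qatkz=AutoTokenizer.from_pretrained(tgt)
qamdl=AutoModelForQuestionAnswering.from_pretrained(tgt)
sentence=["これ","は","テスト","です"]  # hypothetical pre-tokenized sentence
i=0                                      # ask for the head of sentence[i]
v=qatkz(sentence,add_special_tokens=False)["input_ids"]
q=[qatkz.cls_token_id]+v[i]+[qatkz.sep_token_id]
c=v[0:i]+[[qatkz.mask_token_id]]+v[i+1:]+[[qatkz.sep_token_id]]
ids=q+sum(c,[])
with torch.no_grad():
  out=qamdl(input_ids=torch.tensor([ids]),
    token_type_ids=torch.tensor([[0]*len(q)+[1]*(len(ids)-len(q))]))
start=out.start_logits.argmax().item()
end=out.end_logits.argmax().item()
print("head of",sentence[i],"=",qatkz.decode(ids[start:end+1]))
# The models saved under tgt+"/deprel" and tgt+"/tagger" are ordinary token
# classifiers and can be loaded with AutoModelForTokenClassification (or a
# "token-classification" pipeline) in the usual way.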