KoichiYasuoka committed
Commit b344b7c
1 Parent(s): a746cb4

TransformersUD added

Files changed (1)
  1. README.md +47 -0
README.md CHANGED
@@ -38,3 +38,50 @@ start,end=torch.argmax(outputs.start_logits),torch.argmax(outputs.end_logits)
 print(context[offsets[start][0]:offsets[end][-1]])
 ```

+or (with [ufal.chu-liu-edmonds](https://pypi.org/project/ufal.chu-liu-edmonds/))
+
+```py
+class TransformersUD(object):
+  def __init__(self,bert):
+    import os
+    from transformers import (AutoTokenizer,AutoModelForQuestionAnswering,
+      AutoModelForTokenClassification,AutoConfig,TokenClassificationPipeline)
+    self.tokenizer=AutoTokenizer.from_pretrained(bert)
+    self.model=AutoModelForQuestionAnswering.from_pretrained(bert)
+    d=os.path.join(bert,"tagger")
+    if os.path.isdir(d):
+      m=AutoModelForTokenClassification.from_pretrained(d)
+    else:
+      from transformers.file_utils import hf_bucket_url
+      c=AutoConfig.from_pretrained(hf_bucket_url(bert,"tagger/config.json"))
+      m=AutoModelForTokenClassification.from_pretrained(
+        hf_bucket_url(bert,"tagger/pytorch_model.bin"),config=c)
+    self.tagger=TokenClassificationPipeline(model=m,tokenizer=self.tokenizer,
+      aggregation_strategy="simple")
+  def __call__(self,text):
+    import numpy,torch,ufal.chu_liu_edmonds
+    y=self.tagger(text)
+    w=[(t["start"],t["end"],t["entity_group"].split("|")) for t in y]
+    r=[text[s:e] for s,e,p in w]
+    v=self.tokenizer(r,add_special_tokens=False)["input_ids"]
+    m=numpy.full((len(v)+1,len(v)+1),numpy.nan)
+    for i,t in enumerate(v):
+      a=[[self.tokenizer.cls_token_id]+t+[self.tokenizer.sep_token_id]]
+      a+=v[0:i]+[[self.tokenizer.mask_token_id]]+v[i+1:]+[[a[0][-1]]]
+      b,c=[len(sum(a[0:j],[])) for j in range(1,len(a))],sum(a,[])
+      d=self.model(input_ids=torch.tensor([c]),
+        token_type_ids=torch.tensor([[0]*len(a[0])+[1]*(len(c)-len(a[0]))]))
+      s,e=d.start_logits.tolist()[0],d.end_logits.tolist()[0]
+      for j in range(len(b)-1):
+        m[i+1,0 if i==j else j+1]=s[b[j]]+e[b[j+1]-1]
+    h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0]
+    u="# text = "+text.replace("\n"," ")+"\n"
+    for i,(s,e,p) in enumerate(w,1):
+      u+="\t".join([str(i),r[i-1],"_",p[0],"_","|".join(p[1:-1]),str(h[i]),
+        p[-1],"_","_" if i<len(w) and w[i][0]>e else "SpaceAfter=No"])+"\n"
+    return u
+
+nlp=TransformersUD("KoichiYasuoka/deberta-base-japanese-aozora-ud-head")
+print(nlp("全学年にわたって小学校の国語の教科書に挿し絵が用いられている"))
+```
+
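
The added `TransformersUD` class chains three steps: the token-classification model under `tagger/` splits the text into words and labels each with its UPOS tag and dependency relation; the question-answering model is then queried once per word, with the word as the "question" and the sentence with that word replaced by `[MASK]` as the "context", so that the start/end logits score every other word (and the masked position, which stands for the root) as its candidate head; finally the scores are decoded into a single well-formed tree and written out in CoNLL-U format. The decoding step relies on `ufal.chu_liu_edmonds` finding a maximum spanning tree over the matrix `m`, whose rows are dependents, whose columns are candidate heads (column 0 being the artificial root), and where NaN marks a forbidden edge. Below is a minimal sketch of that convention on a three-word sentence, with invented scores:

```py
import numpy
import ufal.chu_liu_edmonds

# m[dependent,head]: index 0 is the artificial root, NaN means "edge not
# allowed" (the same convention TransformersUD uses); scores are made up
m=numpy.full((4,4),numpy.nan)
m[1,2],m[1,3]=9.0,1.0   # word 1 may attach to word 2 or word 3
m[2,0],m[2,3]=8.0,2.0   # word 2 may attach to the root or word 3
m[3,1],m[3,2]=2.0,7.0   # word 3 may attach to word 1 or word 2
h,score=ufal.chu_liu_edmonds.chu_liu_edmonds(m)
print(h[1:])   # [2, 0, 2]: word 2 attaches to the root and heads words 1 and 3
```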
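
Since `__call__` returns plain CoNLL-U, the output can also be handed to any CoNLL-U consumer; for instance, assuming the [deplacy](https://pypi.org/project/deplacy/) package, which renders CoNLL-U input as a text tree:

```py
import deplacy
deplacy.render(nlp("全学年にわたって小学校の国語の教科書に挿し絵が用いられている"))
```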