KoichiYasuoka
commited on
Commit
•
a36c945
1
Parent(s):
9a8f505
algorithm improved
Browse files
ud.py
CHANGED
@@ -7,6 +7,8 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
|
7 |
with torch.no_grad():
|
8 |
e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)],device=self.device))
|
9 |
return {"logits":e.logits[:,1:-2,:],**model_inputs}
|
|
|
|
|
10 |
def postprocess(self,model_outputs,**kwargs):
|
11 |
import numpy
|
12 |
if "logits" not in model_outputs:
|
@@ -29,17 +31,34 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
|
29 |
h=self.chu_liu_edmonds(m)
|
30 |
v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
|
31 |
q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
|
32 |
-
|
33 |
-
if g:
|
34 |
for i,j in reversed(list(enumerate(q[1:],1))):
|
35 |
if j[-1]=="goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"goeswith"}:
|
36 |
h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
|
37 |
v[i-1]=(v[i-1][0],v.pop(i)[1])
|
38 |
q.pop(i)
|
|
|
|
|
|
|
|
|
39 |
t=model_outputs["sentence"].replace("\n"," ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
u="# text = "+t+"\n"
|
41 |
for i,(s,e) in enumerate(v):
|
42 |
-
u+="\t".join([str(i+1),t[s:e],
|
43 |
return u+"\n"
|
44 |
def chu_liu_edmonds(self,matrix):
|
45 |
import numpy
|
|
|
7 |
with torch.no_grad():
|
8 |
e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)],device=self.device))
|
9 |
return {"logits":e.logits[:,1:-2,:],**model_inputs}
|
10 |
+
def check_model_type(self,supported_models):
|
11 |
+
pass
|
12 |
def postprocess(self,model_outputs,**kwargs):
|
13 |
import numpy
|
14 |
if "logits" not in model_outputs:
|
|
|
31 |
h=self.chu_liu_edmonds(m)
|
32 |
v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
|
33 |
q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
|
34 |
+
if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
|
|
|
35 |
for i,j in reversed(list(enumerate(q[1:],1))):
|
36 |
if j[-1]=="goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"goeswith"}:
|
37 |
h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
|
38 |
v[i-1]=(v[i-1][0],v.pop(i)[1])
|
39 |
q.pop(i)
|
40 |
+
elif v[i-1][1]>v[i][0]:
|
41 |
+
h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
|
42 |
+
v[i-1]=(v[i-1][0],v.pop(i)[1])
|
43 |
+
q.pop(i)
|
44 |
t=model_outputs["sentence"].replace("\n"," ")
|
45 |
+
for i,(s,e) in reversed(list(enumerate(v))):
|
46 |
+
w=t[s:e]
|
47 |
+
if w.startswith(" "):
|
48 |
+
j=len(w)-len(w.lstrip())
|
49 |
+
w=w.lstrip()
|
50 |
+
v[i]=(v[i][0]+j,v[i][1])
|
51 |
+
if w.endswith(" "):
|
52 |
+
j=len(w)-len(w.rstrip())
|
53 |
+
w=w.rstrip()
|
54 |
+
v[i]=(v[i][0],v[i][1]-j)
|
55 |
+
if w.strip()=="":
|
56 |
+
h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
|
57 |
+
v.pop(i)
|
58 |
+
q.pop(i)
|
59 |
u="# text = "+t+"\n"
|
60 |
for i,(s,e) in enumerate(v):
|
61 |
+
u+="\t".join([str(i+1),t[s:e],"_",q[i][0],"_","|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
|
62 |
return u+"\n"
|
63 |
def chu_liu_edmonds(self,matrix):
|
64 |
import numpy
|