KoichiYasuoka commited on
Commit
a36c945
1 Parent(s): 9a8f505

algorithm improved

Browse files
Files changed (1) hide show
  1. ud.py +22 -3
ud.py CHANGED
@@ -7,6 +7,8 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
7
  with torch.no_grad():
8
  e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)],device=self.device))
9
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
 
 
10
  def postprocess(self,model_outputs,**kwargs):
11
  import numpy
12
  if "logits" not in model_outputs:
@@ -29,17 +31,34 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
29
  h=self.chu_liu_edmonds(m)
30
  v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
31
  q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
32
- g="aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none"
33
- if g:
34
  for i,j in reversed(list(enumerate(q[1:],1))):
35
  if j[-1]=="goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"goeswith"}:
36
  h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
37
  v[i-1]=(v[i-1][0],v.pop(i)[1])
38
  q.pop(i)
 
 
 
 
39
  t=model_outputs["sentence"].replace("\n"," ")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  u="# text = "+t+"\n"
41
  for i,(s,e) in enumerate(v):
42
- u+="\t".join([str(i+1),t[s:e],t[s:e] if g else "_",q[i][0],"_","|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
43
  return u+"\n"
44
  def chu_liu_edmonds(self,matrix):
45
  import numpy
 
7
  with torch.no_grad():
8
  e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)],device=self.device))
9
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
10
+ def check_model_type(self,supported_models):
11
+ pass
12
  def postprocess(self,model_outputs,**kwargs):
13
  import numpy
14
  if "logits" not in model_outputs:
 
31
  h=self.chu_liu_edmonds(m)
32
  v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
33
  q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
34
+ if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
 
35
  for i,j in reversed(list(enumerate(q[1:],1))):
36
  if j[-1]=="goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"goeswith"}:
37
  h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
38
  v[i-1]=(v[i-1][0],v.pop(i)[1])
39
  q.pop(i)
40
+ elif v[i-1][1]>v[i][0]:
41
+ h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
42
+ v[i-1]=(v[i-1][0],v.pop(i)[1])
43
+ q.pop(i)
44
  t=model_outputs["sentence"].replace("\n"," ")
45
+ for i,(s,e) in reversed(list(enumerate(v))):
46
+ w=t[s:e]
47
+ if w.startswith(" "):
48
+ j=len(w)-len(w.lstrip())
49
+ w=w.lstrip()
50
+ v[i]=(v[i][0]+j,v[i][1])
51
+ if w.endswith(" "):
52
+ j=len(w)-len(w.rstrip())
53
+ w=w.rstrip()
54
+ v[i]=(v[i][0],v[i][1]-j)
55
+ if w.strip()=="":
56
+ h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
57
+ v.pop(i)
58
+ q.pop(i)
59
  u="# text = "+t+"\n"
60
  for i,(s,e) in enumerate(v):
61
+ u+="\t".join([str(i+1),t[s:e],"_",q[i][0],"_","|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
62
  return u+"\n"
63
  def chu_liu_edmonds(self,matrix):
64
  import numpy