import json
import os
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
from allennlp.data import Vocabulary
from tqdm import tqdm

from sftp import SpanPredictor, Span
from sftp.utils import VIRTUAL_ROOT


def read_framenet(path: str):
    """Read a JSONL file into (tokens, virtual-root span tree) pairs."""
    ret = list()
    with open(path) as fp:
        for line in map(json.loads, fp):
            ret.append((line['tokens'], Span.from_json(line['annotations'])))
    return ret
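
# Input format sketch (assumed from the calls above; the exact annotation
# schema is whatever Span.from_json expects), one JSON object per line:
# {"tokens": ["He", "bought", "a", "car"], "annotations": { ...span tree... }}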


def co_occur(
        predictor: SpanPredictor,
        sentences: List[Tuple[List[str], Span]],
        event_list: List[str],
        arg_list: List[str],
):
    """
    Force-decode every target sentence with the (source-ontology) predictor
    and accumulate, for each target event/arg label, the predicted
    distribution over source labels. Returns two count tables of shape
    [len(event_list), vocab_size] and [len(arg_list), vocab_size].
    """
    idx2label = predictor.vocab.get_index_to_token_vocabulary('span_label')
    event_count = np.zeros([len(event_list), len(idx2label)], np.float64)
    arg_count = np.zeros([len(arg_list), len(idx2label)], np.float64)
    for sent, vr in tqdm(sentences):
        # Events: force-decode all event spans under the virtual root.
        _, _, event_dist = predictor.force_decode(sent, child_spans=[event.boundary for event in vr])
        for event, dist in zip(vr, event_dist):
            event_count[event_list.index(event.label)] += dist
        # Args: decode each event's children, conditioning on the source
        # label the model finds most probable for the event itself.
        for event, one_event_dist in zip(vr, event_dist):
            parent_label = idx2label[int(one_event_dist.argmax())]
            arg_spans = [child.boundary for child in event]
            _, _, arg_dist = predictor.force_decode(
                sent, event.boundary, parent_label, arg_spans
            )
            for arg, dist in zip(event, arg_dist):
                arg_count[arg_list.index(arg.label)] += dist
    return event_count, arg_count
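
# Shape sketch (illustrative numbers, not from any dataset): with 3 target
# event labels and a 100-label source vocabulary, event_count is (3, 100);
# row i sums the predicted source-label distributions of every gold span
# whose target label is event_list[i].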


def create_vocab(events, args):
    """Build the target-side label vocabulary: virtual root first, then events, then arguments."""
    vocab = Vocabulary()
    vocab.add_token_to_namespace(VIRTUAL_ROOT, 'span_label')
    for event in events:
        vocab.add_token_to_namespace(event, 'span_label')
    for arg in args:
        vocab.add_token_to_namespace(arg, 'span_label')
    return vocab


def count_data(annotations: Iterable[Span]):
    """Count event and argument label frequencies over the annotated sentences."""
    event_cnt, arg_cnt = defaultdict(int), defaultdict(int)
    for sent in annotations:
        for event in sent:
            event_cnt[event.label] += 1
            for arg in event:
                arg_cnt[arg.label] += 1
    return dict(event_cnt), dict(arg_cnt)


def gen_mapping(
        src_label: List[str], src_count: Dict[str, int],
        tgt_onto: List[str], tgt_label: List[str],
        cooccur_count: np.ndarray
):
    """
    :param src_label: Source label list in vocabulary order, covering both events and args.
    :param src_count: Source label frequencies (events or args).
    :param tgt_onto: Target ontology label list, events or args only.
    :param tgt_label: Target label list in vocabulary order.
    :param cooccur_count: Co-occurrence table of shape [len(tgt_onto), len(src_label)].
    :return: Dict mapping each source label to a score vector over target labels.
    """
    # One-hot matrix that scatters ontology rows into target-vocabulary columns.
    # (np.float was removed in recent NumPy releases; use np.float64.)
    onto2label = np.zeros([len(tgt_onto), len(tgt_label)], dtype=np.float64)
    for onto_idx, onto_tag in enumerate(tgt_onto):
        onto2label[onto_idx, tgt_label.index(onto_tag)] = 1.0
    ret = dict()
    for src_tag, src_freq in src_count.items():
        if src_tag in src_label:
            src_idx = src_label.index(src_tag)
            # Scale the co-occurrence column by the source-side frequency,
            # then project it into target-vocabulary order.
            ret[src_tag] = list((cooccur_count[:, src_idx] / src_freq) @ onto2label)
    return ret
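
# Worked example of the arithmetic above (made-up names and numbers): suppose
# tgt_onto = ['Attack'], tgt_label = ['<root>', 'Attack'], and source tag
# 'Hostile_encounter' has src_freq = 4 with co-occurrence column [3.0] over
# tgt_onto. Then onto2label = [[0., 1.]] and the mapped row is
# ([3.0] / 4) @ [[0., 1.]] = [0.0, 0.75], a score vector in target-vocabulary order.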


def ontology_map(
        model_path,
        src_data: List[Tuple[List[str], Span]],
        tgt_data: List[Tuple[List[str], Span]],
        device: int,
        dst_path: str,
        meta: Optional[dict] = None,
) -> None:
    """
    Induce a mapping from the source ontology (the trained predictor's label
    vocabulary) to the target ontology, then write `ontology_mapping.json`,
    `ontology.tsv`, and the target vocabulary under `dst_path`.
    """
    ret = {'meta': meta or {}}
    data = {'src': {}, 'tgt': {}}
    # Collect label inventories and frequencies for both sides.
    for name, datasets in [['src', src_data], ['tgt', tgt_data]]:
        d = data[name]
        d['sentences'], d['annotations'] = zip(*datasets)
        d['event_cnt'], d['arg_cnt'] = count_data(d['annotations'])
        d['event'], d['arg'] = list(d['event_cnt']), list(d['arg_cnt'])

    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    tgt_vocab = create_vocab(data['tgt']['event'], data['tgt']['arg'])
    for name, vocab in [['src', predictor.vocab], ['tgt', tgt_vocab]]:
        data[name]['label'] = [
            vocab.get_index_to_token_vocabulary('span_label')[i] for i in range(vocab.get_vocab_size('span_label'))
        ]

    # Count co-occurrences between target labels and predicted source labels.
    data['event'], data['arg'] = co_occur(
        predictor, tgt_data, data['tgt']['event'], data['tgt']['arg']
    )
    mapping = {}
    for layer in ['event', 'arg']:
        mapping[layer] = gen_mapping(
            data['src']['label'], data['src'][layer + '_cnt'],
            data['tgt'][layer], data['tgt']['label'], data[layer]
        )

    for key, name in [['source', 'src'], ['target', 'tgt']]:
        ret[key] = {
            'label': data[name]['label'],
            'event': data[name]['event'],
            'argument': data[name]['arg']
        }
    ret['mapping'] = {
        'event': mapping['event'],
        'argument': mapping['arg']
    }

    os.makedirs(dst_path, exist_ok=True)
    with open(os.path.join(dst_path, 'ontology_mapping.json'), 'w') as fp:
        json.dump(ret, fp)
    with open(os.path.join(dst_path, 'ontology.tsv'), 'w') as fp:
        # First row: virtual root followed by all events; then one row per
        # event listing every argument as a permissible child.
        to_dump = list()
        to_dump.append('\t'.join([VIRTUAL_ROOT] + ret['target']['event']))
        for event in ret['target']['event']:
            to_dump.append('\t'.join([event] + ret['target']['argument']))
        fp.write('\n'.join(to_dump))
    tgt_vocab.save_to_files(os.path.join(dst_path, 'vocabulary'))
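

# Minimal usage sketch (file paths, device id, and output directory are
# placeholders, not part of this module):
if __name__ == '__main__':
    src_sentences = read_framenet('data/src.jsonl')  # hypothetical path
    tgt_sentences = read_framenet('data/tgt.jsonl')  # hypothetical path
    ontology_map(
        'model.tar.gz',  # archive of a trained SpanPredictor (assumed)
        src_sentences, tgt_sentences,
        device=0, dst_path='mapping_out',
        meta={'note': 'example run'},
    )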