LeTruongVu2k1 committed
Commit 2720879
Parent(s): ddfda7d

adding JointBERT IDSF checkpoint folder, load_model.py and utils.py from IDSF; modified app.py and requirements.txt
Files changed:
- .gitattributes +1 -0
- JointBERT-CRF_PhoBERTencoder/config.json +26 -0
- JointBERT-CRF_PhoBERTencoder/eval_dev_results.txt +8 -0
- JointBERT-CRF_PhoBERTencoder/eval_test_results.txt +8 -0
- JointBERT-CRF_PhoBERTencoder/events.out.tfevents.1617863943.d86fb58144ae.20305.0 +3 -0
- JointBERT-CRF_PhoBERTencoder/pytorch_model.bin +3 -0
- JointBERT-CRF_PhoBERTencoder/training_args.bin +3 -0
- app.py +69 -9
- load_model.py +250 -0
- requirements.txt +3 -1
- utils.py +115 -0
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 jdk-8u361-linux-aarch64.rpm filter=lfs diff=lfs merge=lfs -text
 VnCoreNLP-1.2.jar filter=lfs diff=lfs merge=lfs -text
 models/postagger/vi-tagger filter=lfs diff=lfs merge=lfs -text
+JointBERT-CRF_PhoBERTencoder/ filter=lfs diff=lfs merge=lfs -text
JointBERT-CRF_PhoBERTencoder/config.json ADDED
@@ -0,0 +1,26 @@
+{
+  "_name_or_path": "vinai/phobert-base",
+  "architectures": [
+    "JointPhoBERT"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "bos_token_id": 0,
+  "eos_token_id": 2,
+  "finetuning_task": "word-level",
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 258,
+  "model_type": "roberta",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 1,
+  "position_embedding_type": "absolute",
+  "tokenizer_class": "PhobertTokenizer",
+  "type_vocab_size": 1,
+  "vocab_size": 64001
+}
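Since the checkpoint folder follows the standard Hugging Face layout (config.json plus pytorch_model.bin), the config can be inspected directly with transformers. A minimal sketch, not part of the commit; AutoConfig only needs the "model_type": "roberta" field, so it works even though the JointPhoBERT architecture class itself lives in this repo:

# Minimal sketch (not part of the commit): inspect the added checkpoint config.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("./JointBERT-CRF_PhoBERTencoder")
print(config.hidden_size, config.max_position_embeddings, config.vocab_size)  # 768 258 64001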
JointBERT-CRF_PhoBERTencoder/eval_dev_results.txt ADDED
@@ -0,0 +1,8 @@
+***** Eval results *****
+intent_acc = 0.984
+loss = 0.7875628210604191
+mean_intent_slot = 0.9723264354415622
+semantic_frame_acc = 0.874
+slot_f1 = 0.9606528708831245
+slot_precision = 0.959254947613504
+slot_recall = 0.9620548744892002
JointBERT-CRF_PhoBERTencoder/eval_test_results.txt ADDED
@@ -0,0 +1,8 @@
+***** Eval results *****
+intent_acc = 0.973124300111982
+loss = 0.9325816652604512
+mean_intent_slot = 0.9598045425009019
+semantic_frame_acc = 0.8533034714445689
+slot_f1 = 0.9464847848898217
+slot_precision = 0.9445026178010472
+slot_recall = 0.9484752891692955
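For reference, mean_intent_slot in both result files is the plain average of intent_acc and slot_f1, matching compute_metrics() in utils.py below:

# mean_intent_slot = (intent_acc + slot_f1) / 2, as in utils.compute_metrics()
(0.984 + 0.9606528708831245) / 2              # ≈ 0.9723264354415622 (dev)
(0.973124300111982 + 0.9464847848898217) / 2  # ≈ 0.9598045425009019 (test)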
JointBERT-CRF_PhoBERTencoder/events.out.tfevents.1617863943.d86fb58144ae.20305.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbe6f738dce99ab0769bc25063ad5cd2017725eb5789e7dbf61081166cf81c32
+size 17078
JointBERT-CRF_PhoBERTencoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e6947fac4a325becd62e5932ca6d2f7d15014d486ac30308e56bfd9b0e7d451
+size 540968940
JointBERT-CRF_PhoBERTencoder/training_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a16e0479297d9c9fe69f4fa4041819e22092ff6145f8e97d5eb56c74180e07b
+size 1583
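These three additions are Git LFS pointer files: the repository tracks only the SHA-256 oid and byte size, while the actual binaries live in LFS storage. A minimal sketch, not part of the commit, for verifying a downloaded file against its pointer:

# Minimal sketch (not part of the commit): verify a downloaded LFS file
# against the sha256 oid recorded in its pointer file.
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "9e6947fac4a325becd62e5932ca6d2f7d15014d486ac30308e56bfd9b0e7d451"
assert sha256_of("JointBERT-CRF_PhoBERTencoder/pytorch_model.bin") == expected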
app.py CHANGED
@@ -51,6 +51,48 @@ my_classifier = pipeline(
     "token-classification", model=model_checkpoint, aggregation_strategy="simple", pipeline_class=MyPipeline)
 
 
+#################### IDSF #######################
+from utils import get_intent_labels, get_slot_labels, load_tokenizer
+import argparse
+import load_model as lm
+
+parser = argparse.ArgumentParser()
+
+# parser.add_argument("--input_file", default="sample_pred_in.txt", type=str, help="Input file for prediction")
+# parser.add_argument("--output_file", default="sample_pred_out.txt", type=str, help="Output file for prediction")
+parser.add_argument("--model_dir", default="./JointBERT-CRF_PhoBERTencoder", type=str, help="Path to save, load model")
+
+parser.add_argument("--batch_size", default=32, type=int, help="Batch size for prediction")
+parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
+
+pred_config = parser.parse_args()
+
+# load model and args
+args = lm.get_args(pred_config)
+device = lm.get_device(pred_config)
+model = lm.load_model(pred_config, args, device)
+
+intent_label_lst = get_intent_labels(args)
+slot_label_lst = get_slot_labels(args)
+
+# Convert input file to TensorDataset
+pad_token_label_id = args.ignore_index
+tokenizer = load_tokenizer(args)
+
+
+
+
+#################### END IDSF #######################
+
+def remove_disfluency(example, prediction):
+    characters = list(example)
+
+    for entity in reversed(prediction):
+        characters[entity['start']:entity['end']] = ''
+
+    return " ".join("".join(characters).split())
+
+
 import gradio as gr
 
 def ner(text):
@@ -64,23 +106,40 @@ def ner(text):
         entity['entity'] = entity.pop('entity_group')
 
     # Remove Disfluency-entities to return a sentence with "Fluency" version
-    list_str = list(text)
-
-    for entity in output[::-1]: # if we use default order of output list, we will shorten the length of the sentence, so the words later are not in the correct start and end index
-        start = max(0, entity['start'] - 1)
-        end = min(len(list_str), entity['end'] + 1)
-
-        list_str[start:end] = ' '
-
-    fluency_sentence = "".join(list_str).strip() # use strip() in case we need to remove entity at the beginning or the end of sentence
-    # (without strip(): "Giá vé khứ hồi à nhầm giá vé một chiều ..." -> " giá vé một chiều ...")
+    fluency_sentence = remove_disfluency(text, output)
+
+    #################### IDSF #######################
+    words, slot_preds, intent_pred = lm.predict(fluency_sentence)[0][0], lm.predict(fluency_sentence)[1][0], lm.predict(fluency_sentence)[2][0]
+
+    slot_tokens = []
+
+    for word, pred in zip(words, slot_preds):
+        if pred == 'O':
+            slot_tokens.extend([(word, None), (" ", None)])
+        elif pred[0] == 'I':
+            added_tokens = list(slot_tokens[-2])
+            added_tokens[0] += f' {word}'
+            slot_tokens[-2] = tuple(added_tokens)
+        else:
+            slot_tokens.extend([(word, pred[2:]), (" ", None)])
+
+    intent_label = intent_label_lst[intent_pred]
+
+    #################### END IDSF #######################
+
     fluency_sentence = fluency_sentence[0].upper() + fluency_sentence[1:] # since capitalize() just lowercase whole sentence first then uppercase the first letter
 
     # Replace words like "Đà_Nẵng" to "Đà Nẵng"
     text = text.replace("_", " ")
     fluency_sentence = fluency_sentence.replace("_", " ")
 
-    return {'text': text, 'entities': output}, fluency_sentence
+    return {'text': text, 'entities': output}, fluency_sentence, slot_tokens, intent_label
+
+
+################################### Gradio Demo ####################################
 
 examples = ['Tôi cần thuê à tôi muốn bay một chuyến khứ hồi từ Đà Nẵng đến Đà Lạt',
             'Giá vé một chiều à không khứ hồi từ Đà Nẵng đến Vinh dưới 2 triệu đồng giá vé khứ hồi từ Quy Nhơn đến Vinh dưới 3 triệu đồng giá vé khứ hồi từ Buôn Ma Thuột đến Quy Nhơn à đến Vinh dưới 4 triệu rưỡi',
@@ -91,7 +150,8 @@ examples = ['Tôi cần thuê à tôi muốn bay một chuyến khứ hồi từ
 
 demo = gr.Interface(ner,
                     gr.Textbox(label='Sentence', placeholder="Enter your sentence here..."),
-                    outputs=[gr.HighlightedText(label='Highlighted
+                    outputs=[gr.HighlightedText(label='Disfluency Highlighted'), gr.Textbox(label='"Fluency" version'),
+                             gr.HighlightedText(label='Slot Filling Highlighted'), gr.Textbox(label='Intent Label')],
                     examples=examples,
                     title="Disfluency Detection",
                     description="This is an easy-to-use built in Gradio for desmontrating a NER System that identifies disfluency-entities in \
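The new remove_disfluency() helper replaces the old in-place character surgery in ner(): it deletes predicted disfluency spans from right to left, so earlier start/end offsets stay valid, then re-splits and re-joins to normalize whitespace. A small illustration, with hypothetical span offsets in place of real pipeline output:

# Illustration of remove_disfluency() with hypothetical entity spans.
def remove_disfluency(example, prediction):
    characters = list(example)
    for entity in reversed(prediction):           # right-to-left keeps earlier offsets valid
        characters[entity['start']:entity['end']] = ''
    return " ".join("".join(characters).split())  # collapse leftover whitespace

sentence = "Tôi cần thuê à tôi muốn bay một chuyến khứ hồi"
spans = [{'start': 0, 'end': 15}]                 # hypothetical: covers "Tôi cần thuê à "
print(remove_disfluency(sentence, spans))         # tôi muốn bay một chuyến khứ hồi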
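The slot_tokens loop in the new ner() converts per-word BIO slot predictions into the (text, label) tuples that gr.HighlightedText expects: a B- tag opens a labeled span, an I- tag appends its word to the previous span (slot_tokens[-2], skipping the (" ", None) spacer), and O words stay unlabeled. A minimal sketch with hypothetical predictions:

# Minimal sketch of the BIO merge in ner(), with hypothetical slot predictions.
words = ["từ", "đà_nẵng", "đến", "đà_lạt", "giá", "rẻ"]
slot_preds = ["O", "B-fromloc.city_name", "O", "B-toloc.city_name", "B-cost", "I-cost"]

slot_tokens = []
for word, pred in zip(words, slot_preds):
    if pred == 'O':
        slot_tokens.extend([(word, None), (" ", None)])
    elif pred[0] == 'I':                       # continue the previous labeled span
        added_tokens = list(slot_tokens[-2])   # [-2] skips the (" ", None) spacer
        added_tokens[0] += f' {word}'
        slot_tokens[-2] = tuple(added_tokens)
    else:                                      # B-XXX opens a new span labeled XXX
        slot_tokens.extend([(word, pred[2:]), (" ", None)])

# slot_tokens now contains ("giá rẻ", "cost"): the I-cost word merged into the B-cost span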
load_model.py ADDED
@@ -0,0 +1,250 @@
+import gradio as gr
+
+import argparse
+import logging
+import os
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
+from tqdm import tqdm
+from utils import MODEL_CLASSES, get_intent_labels, get_slot_labels, init_logger, load_tokenizer
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_device(pred_config):
+    return "cuda" if torch.cuda.is_available() and not pred_config.no_cuda else "cpu"
+
+
+def get_args(pred_config):
+    args = torch.load(os.path.join(pred_config.model_dir, "training_args.bin"))
+
+    args.model_dir = pred_config.model_dir
+    args.data_dir = 'PhoATIS'
+
+    return args
+
+
+def load_model(pred_config, args, device):
+    # Check whether model exists
+    if not os.path.exists(pred_config.model_dir):
+        raise Exception("Model doesn't exists! Train first!")
+
+    try:
+        model = MODEL_CLASSES[args.model_type][1].from_pretrained(
+            args.model_dir, args=args, intent_label_lst=get_intent_labels(args), slot_label_lst=get_slot_labels(args)
+        )
+        model.to(device)
+        model.eval()
+        logger.info("***** Model Loaded *****")
+    except Exception:
+        raise Exception("Some model files might be missing...")
+
+    return model
+
+def convert_input_file_to_tensor_dataset(
+    lines,
+    pred_config,
+    args,
+    tokenizer,
+    pad_token_label_id,
+    cls_token_segment_id=0,
+    pad_token_segment_id=0,
+    sequence_a_segment_id=0,
+    mask_padding_with_zero=True,
+):
+    # Setting based on the current model type
+    cls_token = tokenizer.cls_token
+    sep_token = tokenizer.sep_token
+    unk_token = tokenizer.unk_token
+    pad_token_id = tokenizer.pad_token_id
+
+    all_input_ids = []
+    all_attention_mask = []
+    all_token_type_ids = []
+    all_slot_label_mask = []
+
+    for words in lines:
+        tokens = []
+        slot_label_mask = []
+        for word in words:
+            word_tokens = tokenizer.tokenize(word)
+            if not word_tokens:
+                word_tokens = [unk_token]  # For handling the bad-encoded word
+            tokens.extend(word_tokens)
+            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+            slot_label_mask.extend([pad_token_label_id + 1] + [pad_token_label_id] * (len(word_tokens) - 1))
+
+        # Account for [CLS] and [SEP]
+        special_tokens_count = 2
+        if len(tokens) > args.max_seq_len - special_tokens_count:
+            tokens = tokens[: (args.max_seq_len - special_tokens_count)]
+            slot_label_mask = slot_label_mask[: (args.max_seq_len - special_tokens_count)]
+
+        # Add [SEP] token
+        tokens += [sep_token]
+        token_type_ids = [sequence_a_segment_id] * len(tokens)
+        slot_label_mask += [pad_token_label_id]
+
+        # Add [CLS] token
+        tokens = [cls_token] + tokens
+        token_type_ids = [cls_token_segment_id] + token_type_ids
+        slot_label_mask = [pad_token_label_id] + slot_label_mask
+
+        input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
+        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+        # Zero-pad up to the sequence length.
+        padding_length = args.max_seq_len - len(input_ids)
+        input_ids = input_ids + ([pad_token_id] * padding_length)
+        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
+        slot_label_mask = slot_label_mask + ([pad_token_label_id] * padding_length)
+
+        all_input_ids.append(input_ids)
+        all_attention_mask.append(attention_mask)
+        all_token_type_ids.append(token_type_ids)
+        all_slot_label_mask.append(slot_label_mask)
+
+    # Change to Tensor
+    all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
+    all_attention_mask = torch.tensor(all_attention_mask, dtype=torch.long)
+    all_token_type_ids = torch.tensor(all_token_type_ids, dtype=torch.long)
+    all_slot_label_mask = torch.tensor(all_slot_label_mask, dtype=torch.long)
+
+    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_slot_label_mask)
+
+    return dataset
+
+def predict(text):
+
+    lines = text
+    dataset = convert_input_file_to_tensor_dataset(lines, pred_config, args, tokenizer, pad_token_label_id)
+
+    # Predict
+    sampler = SequentialSampler(dataset)
+    data_loader = DataLoader(dataset, sampler=sampler, batch_size=pred_config.batch_size)
+
+    all_slot_label_mask = None
+    intent_preds = None
+    slot_preds = None
+
+    for batch in tqdm(data_loader, desc="Predicting"):
+        batch = tuple(t.to(device) for t in batch)
+        with torch.no_grad():
+            inputs = {
+                "input_ids": batch[0],
+                "attention_mask": batch[1],
+                "intent_label_ids": None,
+                "slot_labels_ids": None,
+            }
+            if args.model_type != "distilbert":
+                inputs["token_type_ids"] = batch[2]
+            outputs = model(**inputs)
+            _, (intent_logits, slot_logits) = outputs[:2]
+
+            # Intent Prediction
+            if intent_preds is None:
+                intent_preds = intent_logits.detach().cpu().numpy()
+            else:
+                intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
+
+            # Slot prediction
+            if slot_preds is None:
+                if args.use_crf:
+                    # decode() in `torchcrf` returns list with best index directly
+                    slot_preds = np.array(model.crf.decode(slot_logits))
+                else:
+                    slot_preds = slot_logits.detach().cpu().numpy()
+                all_slot_label_mask = batch[3].detach().cpu().numpy()
+            else:
+                if args.use_crf:
+                    slot_preds = np.append(slot_preds, np.array(model.crf.decode(slot_logits)), axis=0)
+                else:
+                    slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)
+                all_slot_label_mask = np.append(all_slot_label_mask, batch[3].detach().cpu().numpy(), axis=0)
+
+    intent_preds = np.argmax(intent_preds, axis=1)
+
+    if not args.use_crf:
+        slot_preds = np.argmax(slot_preds, axis=2)
+
+    slot_label_map = {i: label for i, label in enumerate(slot_label_lst)}
+    slot_preds_list = [[] for _ in range(slot_preds.shape[0])]
+
+    for i in range(slot_preds.shape[0]):
+        for j in range(slot_preds.shape[1]):
+            if all_slot_label_mask[i, j] != pad_token_label_id:
+                slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
+
+    return (lines, slot_preds_list, intent_preds)
+
+
+def text_analysis(text):
+    text = [text.strip().split()]
+
+    words, slot_preds, intent_pred = predict(text)[0][0], predict(text)[1][0], predict(text)[2][0]
+
+    slot_tokens = []
+
+    for word, pred in zip(words, slot_preds):
+        if pred == 'O':
+            slot_tokens.extend([(word, None), (" ", None)])
+        elif pred[0] == 'I':
+            added_tokens = list(slot_tokens[-2])
+            added_tokens[0] += f' {word}'
+            slot_tokens[-2] = tuple(added_tokens)
+        else:
+            slot_tokens.extend([(word, pred[2:]), (" ", None)])
+
+    intent_label = intent_label_lst[intent_pred]
+
+    return slot_tokens, intent_label
+
+
+
+if __name__ == "__main__":
+    init_logger()
+    parser = argparse.ArgumentParser()
+
+    # parser.add_argument("--input_file", default="sample_pred_in.txt", type=str, help="Input file for prediction")
+    # parser.add_argument("--output_file", default="sample_pred_out.txt", type=str, help="Output file for prediction")
+    parser.add_argument("--model_dir", default="./JointBERT-CRF_PhoBERTencoder", type=str, help="Path to save, load model")
+
+    parser.add_argument("--batch_size", default=32, type=int, help="Batch size for prediction")
+    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
+
+    pred_config = parser.parse_args()
+
+    # load model and args
+    args = get_args(pred_config)
+    device = get_device(pred_config)
+    model = load_model(pred_config, args, device)
+    logger.info(args)
+
+    intent_label_lst = get_intent_labels(args)
+    slot_label_lst = get_slot_labels(args)
+
+    # Convert input file to TensorDataset
+    pad_token_label_id = args.ignore_index
+    tokenizer = load_tokenizer(args)
+
+
+    examples = ["tôi muốn bay một chuyến khứ_hồi từ đà_nẵng đến đà_lạt",
+                ("giá vé khứ_hồi từ đà_nẵng đến vinh dưới 2 triệu đồng giá vé khứ_hồi từ quy nhơn đến vinh dưới 3 triệu đồng giá vé khứ_hồi từ"
+                 " buôn_ma_thuột đến vinh dưới 4 triệu rưỡi"),
+                "cho tôi biết các chuyến bay đến đà_nẵng vào ngày 14 tháng sáu",
+                "những chuyến bay nào khởi_hành từ thành_phố hồ_chí_minh bay đến frankfurt mà nối chuyến ở singapore và hạ_cánh trước 9 giờ tối"]
+
+    demo = gr.Interface(
+        text_analysis,
+        gr.Textbox(placeholder="Enter sentence here...", label="Input"),
+        [gr.HighlightedText(label='Highlighted Output'), gr.Textbox(label='Intent Label')],
+        examples=examples,
+    )
+
+    demo.launch()
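Note that predict() expects its argument to be a list of word-segmented sentences (lists of words), which is what text_analysis() prepares via text.strip().split(); it also reads model, args, tokenizer, pad_token_label_id, and slot_label_lst from load_model's module globals, which are only populated in the __main__ block, not on a plain import. A minimal sketch of the intended contract, under those assumptions:

# Minimal sketch of predict()'s contract, assuming load_model's module globals
# (model, args, tokenizer, pad_token_label_id, slot_label_lst, ...) are populated
# as in the __main__ block above; they are not set on a plain import.
lines = [["tôi", "muốn", "bay", "từ", "đà_nẵng", "đến", "đà_lạt"]]  # pre-segmented words
words, slot_preds_list, intent_preds = predict(lines)
print(slot_preds_list[0])                 # one BIO slot tag per input word (hypothetical labels)
print(intent_label_lst[intent_preds[0]])  # decoded intent label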
requirements.txt CHANGED
@@ -1,3 +1,5 @@
 transformers
 torch
-py_vncorenlp
+py_vncorenlp
+numpy
+tqdm
utils.py ADDED
@@ -0,0 +1,115 @@
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+from model import JointPhoBERT, JointXLMR
+from seqeval.metrics import f1_score, precision_score, recall_score
+from transformers import (
+    AutoTokenizer,
+    RobertaConfig,
+    XLMRobertaConfig,
+    XLMRobertaTokenizer,
+)
+
+
+MODEL_CLASSES = {
+    "xlmr": (XLMRobertaConfig, JointXLMR, XLMRobertaTokenizer),
+    "phobert": (RobertaConfig, JointPhoBERT, AutoTokenizer),
+}
+
+MODEL_PATH_MAP = {
+    "xlmr": "xlm-roberta-base",
+    "phobert": "vinai/phobert-base",
+}
+
+
+def get_intent_labels(args):
+    return [
+        label.strip()
+        for label in open(os.path.join(args.data_dir, args.token_level, args.intent_label_file), "r", encoding="utf-8")
+    ]
+
+
+def get_slot_labels(args):
+    return [
+        label.strip()
+        for label in open(os.path.join(args.data_dir, args.token_level, args.slot_label_file), "r", encoding="utf-8")
+    ]
+
+
+def load_tokenizer(args):
+    return MODEL_CLASSES[args.model_type][2].from_pretrained(args.model_name_or_path)
+
+
+def init_logger():
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+
+
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if not args.no_cuda and torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+
+
+def compute_metrics(intent_preds, intent_labels, slot_preds, slot_labels):
+    assert len(intent_preds) == len(intent_labels) == len(slot_preds) == len(slot_labels)
+    results = {}
+    intent_result = get_intent_acc(intent_preds, intent_labels)
+    slot_result = get_slot_metrics(slot_preds, slot_labels)
+    sementic_result = get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels)
+
+    mean_intent_slot = (intent_result["intent_acc"] + slot_result["slot_f1"]) / 2
+
+    results.update(intent_result)
+    results.update(slot_result)
+    results.update(sementic_result)
+    results["mean_intent_slot"] = mean_intent_slot
+
+    return results
+
+
+def get_slot_metrics(preds, labels):
+    assert len(preds) == len(labels)
+    return {
+        "slot_precision": precision_score(labels, preds),
+        "slot_recall": recall_score(labels, preds),
+        "slot_f1": f1_score(labels, preds),
+    }
+
+
+def get_intent_acc(preds, labels):
+    acc = (preds == labels).mean()
+    return {"intent_acc": acc}
+
+
+def read_prediction_text(args):
+    return [text.strip() for text in open(os.path.join(args.pred_dir, args.pred_input_file), "r", encoding="utf-8")]
+
+
+def get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels):
+    """For the cases that intent and all the slots are correct (in one sentence)"""
+    # Get the intent comparison result
+    intent_result = intent_preds == intent_labels
+
+    # Get the slot comparison result
+    slot_result = []
+    for preds, labels in zip(slot_preds, slot_labels):
+        assert len(preds) == len(labels)
+        one_sent_result = True
+        for p, l in zip(preds, labels):
+            if p != l:
+                one_sent_result = False
+                break
+        slot_result.append(one_sent_result)
+    slot_result = np.array(slot_result)
+
+    semantic_acc = np.multiply(intent_result, slot_result).mean()
+    return {"semantic_frame_acc": semantic_acc}