spoken-norm-taggen / infer.py
HuyenNguyen's picture
Duplicate from nguyenvulebinh/spoken-norm-taggen
394811b
#!/usr/bin/env python
# coding: utf-8
import torch
import model_handling
from data_handling import DataCollatorForNormSeq2Seq
from model_handling import EncoderDecoderSpokenNorm
import os
import random
import data_handling
from transformers.generation_logits_process import LogitsProcessorList
from transformers.generation_stopping_criteria import StoppingCriteriaList
from transformers.generation_beam_search import BeamSearchScorer
from dataclasses import dataclass
from transformers.file_utils import ModelOutput
import utils
# os.environ["CUDA_VISIBLE_DEVICES"] = "4"
use_gpu = False
if use_gpu:
if not torch.cuda.is_available():
use_gpu = False
tokenizer = model_handling.init_tokenizer()
model = EncoderDecoderSpokenNorm.from_pretrained('nguyenvulebinh/spoken-norm-taggen-v2').eval()
data_collator = DataCollatorForNormSeq2Seq(tokenizer)
if use_gpu:
model = model.cuda()
def make_batch_input(text_input_list):
batch_src_ids, batch_src_lengths = [], []
for text_input in text_input_list:
src_ids, src_lengths = [], []
for src in text_input.split():
src_tokenized = tokenizer(src)
ids = src_tokenized["input_ids"][1:-1]
src_ids.extend(ids)
src_lengths.append(len(ids))
src_ids = torch.tensor([0] + src_ids + [2])
src_lengths = torch.tensor([1] + src_lengths + [1]) + 1
batch_src_ids.append(src_ids)
batch_src_lengths.append(src_lengths)
assert sum(src_lengths - 1) == len(src_ids), "{} vs {}".format(sum(src_lengths), len(src_ids))
input_tokenized = tokenizer.pad({"input_ids": batch_src_ids}, padding=True)
input_word_length = tokenizer.pad({"input_ids": batch_src_lengths}, padding=True)["input_ids"] - 1
return input_tokenized['input_ids'], input_tokenized['attention_mask'], input_word_length
def make_batch_bias_list(bias_list):
if len(bias_list) > 0:
bias = data_collator.encode_list_string(bias_list)
bias_input_ids = bias['input_ids']
bias_attention_mask = bias['attention_mask']
else:
bias_input_ids = None
bias_attention_mask = None
return bias_input_ids, bias_attention_mask
def build_spoken_pronounce_mapping(bias_list):
list_pronounce = []
mapping = dict({})
for item in bias_list:
pronounces = item.split(' | ')[1:]
pronounces = [tokenizer(item)['input_ids'][1:-1] for item in pronounces]
list_pronounce.extend(pronounces)
subword_ids = list(set([item for sublist in list_pronounce for item in sublist]))
mapping = {item: [] for item in subword_ids}
for item in list_pronounce:
for wid in subword_ids:
if wid in item:
mapping[wid].append(item)
return mapping
def find_pivot(seq, subseq):
n = len(seq)
m = len(subseq)
result = []
for i in range(n - m + 1):
if seq[i] == subseq[0] and seq[i:i + m] == subseq:
result.append(i)
return result
def revise_spoken_tagging(list_tags, list_words, pronounce_mapping):
if len(pronounce_mapping) == 0:
return list_tags
result = []
for tags_tensor, sen in zip(list_tags, list_words):
tags = tags_tensor.detach().numpy().tolist()
sen = sen.detach().numpy().tolist()
candidate_pronounce = dict({})
for idx in range(len(tags)):
if tags[idx] != 0 and sen[idx] in pronounce_mapping:
for pronounce in pronounce_mapping[sen[idx]]:
pronounce_word = str(pronounce)
start_find_idx = max(0, idx - len(pronounce))
end_find_idx = idx + len(pronounce)
find_idx = find_pivot(sen[start_find_idx: end_find_idx], pronounce)
if len(find_idx) > 0:
find_idx = [item + start_find_idx for item in find_idx]
for map_idx in find_idx:
if candidate_pronounce.get(map_idx, None) is None:
candidate_pronounce[map_idx] = len(pronounce)
else:
candidate_pronounce[map_idx] = max(candidate_pronounce[map_idx], len(pronounce))
for idx, len_word in candidate_pronounce.items():
tags_tensor[idx] = 1
for i in range(1, len_word):
tags_tensor[idx + i] = 2
result.append(tags_tensor)
return result
def make_spoken_feature(input_features, text_input_list, pronounce_mapping=dict({})):
features = {
"input_ids": input_features[0],
"word_src_lengths": input_features[2],
"attention_mask": input_features[1],
# "bias_input_ids": bias_features[0],
# "bias_attention_mask": bias_features[1],
"bias_input_ids": None,
"bias_attention_mask": None,
}
if use_gpu:
for key in features.keys():
if features[key] is not None:
features[key] = features[key].cuda()
encoder_output = model.get_encoder()(**features)
spoken_tagging_output = torch.argmax(encoder_output[0].spoken_tagging_output, dim=-1)
spoken_tagging_output = revise_spoken_tagging(spoken_tagging_output, features['input_ids'], pronounce_mapping)
# print(spoken_tagging_output)
# print(features['input_ids'])
word_src_lengths = features['word_src_lengths']
encoder_features = encoder_output[0][0]
list_spoken_features = []
list_pre_norm = []
for tagging_sample, sample_word_length, text_input_features, sample_text in zip(spoken_tagging_output, word_src_lengths, encoder_features, text_input_list):
spoken_feature_idx = []
sample_words = ['<s>'] + sample_text.split() + ['</s>']
norm_words = []
spoken_phrase = []
spoken_features = []
if tagging_sample.sum() == 0:
list_pre_norm.append(sample_words)
continue
for idx, word_length in enumerate(sample_word_length):
if word_length > 0:
start = sample_word_length[:idx].sum()
end = start + word_length
if tagging_sample[start: end].sum() > 0 and sample_words[idx] not in ['<s>', '</s>']:
# Word has start tag
if (tagging_sample[start: end] == 1).sum():
if len(spoken_phrase) > 0:
norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
spoken_phrase = []
list_spoken_features.append(torch.cat(spoken_features))
spoken_features = []
spoken_phrase.append(sample_words[idx])
spoken_features.append(text_input_features[start: end])
else:
if len(spoken_phrase) > 0:
norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
spoken_phrase = []
list_spoken_features.append(torch.cat(spoken_features))
spoken_features = []
norm_words.append(sample_words[idx])
if len(spoken_phrase) > 0:
norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
spoken_phrase = []
list_spoken_features.append(torch.cat(spoken_features))
spoken_features = []
list_pre_norm.append(norm_words)
list_features_mask = []
if len(list_spoken_features) > 0:
feature_pad = torch.zeros_like(list_spoken_features[0][:1, :])
max_length = max([len(item) for item in list_spoken_features])
for i in range(len(list_spoken_features)):
spoken_length = len(list_spoken_features[i])
remain_length = max_length - spoken_length
device = list_spoken_features[i].device
list_spoken_features[i] = torch.cat([list_spoken_features[i],
feature_pad.expand(remain_length, feature_pad.size(-1))]).unsqueeze(0)
list_features_mask.append(torch.cat([torch.ones(spoken_length, device=device, dtype=torch.int64),
torch.zeros(remain_length, device=device, dtype=torch.int64)]).unsqueeze(0))
if len(list_spoken_features) > 0:
list_spoken_features = torch.cat(list_spoken_features)
list_features_mask = torch.cat(list_features_mask)
return list_spoken_features, list_features_mask, list_pre_norm
def make_bias_feature(bias_raw_features):
features = {
"bias_input_ids": bias_raw_features[0],
"bias_attention_mask": bias_raw_features[1]
}
if use_gpu:
for key in features.keys():
if features[key] is not None:
features[key] = features[key].cuda()
return model.forward_bias(**features)
def decode_plain_output(decoder_output):
plain_output = [item.split()[1:] for item in tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=False)]
scores = torch.stack(list(decoder_output['scores'])).transpose(1, 0)
logit_output = torch.gather(scores, -1, decoder_output['sequences'][:, 1:].unsqueeze(-1)).squeeze(-1)
special_tokens = list(tokenizer.special_tokens_map.values())
generated_output = []
generated_scores = []
# filter special tokens
for out_text, out_score in zip(plain_output, logit_output):
temp_str, tmp_score = [], []
for piece, score in zip(out_text, out_score):
if piece not in special_tokens:
temp_str.append(piece)
tmp_score.append(score)
if len(temp_str) > 0:
generated_output.append(' '.join(temp_str).replace('▁', '|').replace(' ', '').replace('|', ' ').strip())
generated_scores.append((sum(tmp_score)/len(tmp_score)).cpu().detach().numpy().tolist())
else:
generated_output.append("")
generated_scores.append(0)
return generated_output, generated_scores
def generate_spoken_norm(list_spoken_features, list_features_mask, bias_features):
@dataclass
class EncoderOutputs(ModelOutput):
last_hidden_state: torch.FloatTensor = None
hidden_states: torch.FloatTensor = None
attentions: torch.FloatTensor = None
batch_size = list_spoken_features.size(0)
max_length = 50
device = list_spoken_features.device
decoder_input_ids = torch.zeros((batch_size, 1), device=device, dtype=torch.int64)
stopping_criteria = model._get_stopping_criteria(max_length=max_length, max_time=None,
stopping_criteria=StoppingCriteriaList())
model_kwargs = {
"encoder_outputs": EncoderOutputs(last_hidden_state=list_spoken_features),
"encoder_bias_outputs": bias_features,
"attention_mask": list_features_mask
}
decoder_output = model.greedy_search(
decoder_input_ids,
logits_processor=LogitsProcessorList(),
stopping_criteria=stopping_criteria,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
output_scores=True,
return_dict_in_generate=True,
**model_kwargs,
)
plain_output, plain_score = decode_plain_output(decoder_output)
# plain_output = tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=True)
# # print(decoder_output)
# plain_output = [word.replace('▁', '|').replace(' ', '').replace('|', ' ').strip() for word in plain_output]
return plain_output, plain_score
def generate_beam_spoken_norm(list_spoken_features, list_features_mask, bias_features, num_beams=3):
@dataclass
class EncoderOutputs(ModelOutput):
last_hidden_state: torch.FloatTensor = None
batch_size = list_spoken_features.size(0)
max_length = 50
num_return_sequences = 1
device = list_spoken_features.device
decoder_input_ids = torch.zeros((batch_size, 1), device=device, dtype=torch.int64)
stopping_criteria = model._get_stopping_criteria(max_length=max_length, max_time=None,
stopping_criteria=StoppingCriteriaList())
model_kwargs = {
"encoder_outputs": EncoderOutputs(last_hidden_state=list_spoken_features),
"encoder_bias_outputs": bias_features,
"attention_mask": list_features_mask
}
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=device,
do_early_stopping=True,
num_beam_hyps_to_keep=num_return_sequences,
)
decoder_input_ids, model_kwargs = model._expand_inputs_for_generation(
decoder_input_ids, expand_size=num_beams, is_encoder_decoder=True, **model_kwargs
)
decoder_output = model.beam_search(
decoder_input_ids,
beam_scorer,
logits_processor=LogitsProcessorList(),
stopping_criteria=stopping_criteria,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
output_scores=None,
return_dict_in_generate=True,
**model_kwargs,
)
plain_output = tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=True)
plain_output = [word.replace('▁', '|').replace(' ', '').replace('|', ' ').strip() for word in plain_output]
return plain_output, None
def reformat_normed_term(list_pre_norm, spoken_norm_output, spoken_norm_output_score=None, threshold=None, debug=False):
output = []
for pre_norm in list_pre_norm:
normed_words = []
# words = pre_norm.split()
for w in pre_norm:
if w.startswith('<mask>'):
term = w[7:].split('](')
# print(w)
# print(term)
term_idx = int(term[0])
norm_val = spoken_norm_output[term_idx]
norm_val_score = None if (spoken_norm_output_score is None or threshold is None) else spoken_norm_output_score[term_idx]
pre_norm_val = term[1][:-1]
if debug:
if norm_val_score is not None:
normed_words.append("({})({:.2f})[{}]".format(norm_val, norm_val_score, pre_norm_val))
else:
normed_words.append("({})[{}]".format(norm_val, pre_norm_val))
else:
if threshold is not None and norm_val_score is not None:
if norm_val_score > threshold:
normed_words.append(norm_val)
else:
normed_words.append(pre_norm_val)
else:
normed_words.append(norm_val)
else:
normed_words.append(w)
output.append(" ".join(normed_words))
return output
def infer(text_input_list, bias_list):
# extract bias feature
bias_raw_features = make_batch_bias_list(bias_list)
bias_features = make_bias_feature(bias_raw_features)
pronounce_mapping = build_spoken_pronounce_mapping(bias_list)
# Chunk split input and create feature
text_input_chunk_list = [utils.split_chunk_input(item, chunk_size=60, overlap=20) for item in text_input_list]
num_chunks = [len(i) for i in text_input_chunk_list]
flatten_list = [y for x in text_input_chunk_list for y in x]
input_raw_features = make_batch_input(flatten_list)
# Extract norm term and spoken feature
list_spoken_features, list_features_mask, list_pre_norm = make_spoken_feature(input_raw_features, flatten_list, pronounce_mapping)
# Merge overlap chunks
list_pre_norm_by_input = []
for idx, input_num in enumerate(num_chunks):
start = sum(num_chunks[:idx])
end = start + num_chunks[idx]
list_pre_norm_by_input.append(list_pre_norm[start:end])
text_input_list_pre_norm = [utils.merge_chunk_pre_norm(list_chunks, overlap=20, debug=False) for list_chunks in list_pre_norm_by_input]
if len(list_spoken_features) > 0:
spoken_norm_output, spoken_norm_score = generate_spoken_norm(list_spoken_features, list_features_mask, bias_features)
else:
spoken_norm_output, spoken_norm_score = [], None
return reformat_normed_term(text_input_list_pre_norm, spoken_norm_output, spoken_norm_score, threshold=15, debug=False)