from torchtext.data.metrics import bleu_score

def bleu(valid_src_data, valid_trg_data, model, device, k, max_strlen):
    """Translate every source sentence with beam width k, then score against the references."""
    pred_sents = []
    for sentence in valid_src_data:
        pred_trg = model.translate_sentence(sentence, device, k, max_strlen)
        pred_sents.append(pred_trg)

    # re-tokenize the generated strings with the target field; references are
    # wrapped in an extra list because bleu_score allows several per candidate
    pred_sents = [model.TRG.preprocess(sent) for sent in pred_sents]
    trg_sents = [[sent.split()] for sent in valid_trg_data]

    return bleu_score(pred_sents, trg_sents)

def bleu_single(model, valid_dataset, debug=False):
    """Translate one sentence at a time, then calculate the BLEU score. To be replaced once batched beam search is online."""
    # the dataset samples are already split into tokens, so each reference only
    # needs re-wrapping per sample: bleu_score takes 2D candidates (a list of
    # token lists) but 3D references (a list of lists of token lists), hence
    # the extra [pair.trg] nesting
    translate_pair = (([pair.trg], model.translate_sentence(pair.src, debug=debug)) for pair in valid_dataset)
    labels, predictions = [list(l) for l in zip(*translate_pair)]
    return bleu_score(predictions, labels)
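
# Shape contract of torchtext's bleu_score, illustrated as a minimal,
# self-contained doctest-style check: candidates are 2D (token lists),
# references are 3D so each candidate may carry several references; an
# identical 4-token pair scores 1.0 since all 1- to 4-gram precisions are 1.
#
#     >>> bleu_score([["the", "cat", "sat", "down"]],
#     ...            [[["the", "cat", "sat", "down"]]])
#     1.0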

def bleu_batch(model, valid_dataset, batch_size, debug=False):
    """Perform batch sentence translation in the same vein."""
    predictions = model.translate_batch_sentence([s.src for s in valid_dataset], output_tokens=True, batch_size=batch_size)
    labels = [[s.trg] for s in valid_dataset]
    return bleu_score(predictions, labels)


def _revert_trg(sent, eos):
    """Undo the batching additions on a sentence: drop the leading <sos> token and
    cut at the first eos token (which also removes any trailing padding)."""
    try:
        endloc = sent.index(eos)
        return sent[1:endloc]
    except ValueError:
        return sent[1:]
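
# Expected behavior on a typical padded target row, doctest-style:
#
#     >>> _revert_trg(["<sos>", "a", "b", "<eos>", "<pad>"], "<eos>")
#     ['a', 'b']
#     >>> _revert_trg(["<sos>", "a", "b"], "<eos>")  # no eos present
#     ['a', 'b']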

def bleu_batch_iter(model, valid_iter, src_lang=None, trg_lang=None, eos_token="<eos>", debug=False):
    """Perform batched translation and score it like the functions above. The iterator's
    inputs/outputs have already been field-processed, but arrive in [length, batch_size]
    shape as per BucketIterator, hence the transposes below."""
    translated_batched_pair = (
        (
            pair.trg.transpose(0, 1), # transpose due to timestep-first batches
            model.decode_strategy.translate_batch_sentence(
                pair.src.transpose(0, 1),
                src_lang=src_lang,
                trg_lang=trg_lang,
                output_tokens=True, 
                field_processed=True, 
                replace_unk=False, # do not replace in this version
                debug=debug
            )
        ) 
        for pair in valid_iter 
    ) 
    # flatten batches into (reference tokens, predicted tokens) pairs, converting
    # reference indices back into token strings via the target vocab
    flattened_pair = (
        ([model.TRG.vocab.itos[i] for i in trg], pred)
        for batch_trg, batch_pred in translated_batched_pair
        for trg, pred in zip(batch_trg, batch_pred)
    )
    flat_labels, predictions = [list(l) for l in zip(*flattened_pair)]
    # strip <sos>/<eos> from every reference and add the extra nesting required
    # by bleu_score's 3D reference format
    labels = [[_revert_trg(l, eos_token)] for l in flat_labels]
    return bleu_score(predictions, labels)
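
# A minimal usage sketch for the batched variant. The legacy torchtext
# BucketIterator call is real; the model object (with SRC/TRG fields and a
# decode_strategy exposing translate_batch_sentence) is an assumption about
# this codebase, not something torchtext provides:
#
#     from torchtext.data import BucketIterator
#
#     valid_iter = BucketIterator(valid_dataset, batch_size=32,
#                                 device=device, train=False)
#     print("validation BLEU: {:.4f}".format(
#         bleu_batch_iter(model, valid_iter, eos_token="<eos>")))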