File size: 1,309 Bytes
2f6628d
 
fd48f4d
2f6628d
fd48f4d
2f6628d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
###########################################################################
# NLP demo software by HyperbeeAI.                                        #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected] #
###########################################################################
# License banner text; also echoed to stdout whenever this module is imported.
license_statement = (
    "NLP demo software by HyperbeeAI. "
    "Copyrights © 2023 Hyperbee.AI Inc. "
    "All rights reserved. [email protected]"
)
# Emit the import notice, the license banner and a trailing blank line.
print("imported dataloader.py", license_statement, "", sep="\n")

from torchtext.legacy.datasets import TranslationDataset
from torchtext.legacy.data import Field, BucketIterator
import os

class NewsDataset(TranslationDataset):
    """Translation dataset stored under a ``news-comm-v15`` folder.

    Subclass of ``TranslationDataset`` that provides default names for the
    train/validation/test splits and resolves the data location before
    delegating to the parent ``splits`` implementation.
    """

    # Folder name looked up beneath ``root`` when no explicit path is given.
    name = 'news-comm-v15'

    @staticmethod
    def sort_key(ex):
        """Return the length of the example's ``src`` attribute."""
        source_side = ex.src
        return len(source_side)

    @classmethod
    def splits(cls, exts, fields, root='./',
               train='news-comm-v15-all',
               validation='news-comm-v15-all-valid',
               test='news-comm-v15-all-test',
               **kwargs):
        """Build the dataset splits by delegating to the parent class.

        Args:
            exts: Forwarded unchanged to the parent ``splits``.
            fields: Forwarded unchanged to the parent ``splits``.
            root: Base directory; ``root/<cls.name>`` is probed when no
                ``path`` keyword is supplied.
            train: Name of the training split.
            validation: Name of the validation split.
            test: Name of the test split.
            **kwargs: Extra keyword arguments forwarded to the parent. A
                ``path`` entry, if present, is removed and used as the data
                location instead of probing ``root``.

        Returns:
            Whatever the parent class's ``splits`` returns.
        """
        if 'path' in kwargs:
            # Caller supplied the location explicitly; take it out of kwargs
            # so it is not passed twice to the parent.
            data_dir = kwargs.pop('path')
        else:
            # Use the conventional folder only if it already exists;
            # otherwise hand None to the parent and let it decide.
            # NOTE(review): parent presumably falls back to a download when
            # the path is None - confirm against torchtext.
            data_dir = os.path.join(root, cls.name)
            if not os.path.exists(data_dir):
                data_dir = None

        return super().splits(
            exts, fields, data_dir, root, train, validation, test, **kwargs)