nanotranslator-hf / dataloader.py
sonebu
update email
fd48f4d
###########################################################################
# NLP demo software by HyperbeeAI. #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected] #
###########################################################################
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected]"
print("imported dataloader.py")
print(license_statement)
print("")
from torchtext.legacy.datasets import TranslationDataset
from torchtext.legacy.data import Field, BucketIterator
import os
class NewsDataset(TranslationDataset):
name = 'news-comm-v15'
@staticmethod
def sort_key(ex):
return len(ex.src)
@classmethod
def splits(cls, exts, fields, root='./',
train='news-comm-v15-all', validation='news-comm-v15-all-valid', test='news-comm-v15-all-test', **kwargs):
if 'path' not in kwargs:
expected_folder = os.path.join(root, cls.name)
path = expected_folder if os.path.exists(expected_folder) else None
else:
path = kwargs['path']
del kwargs['path']
return super(NewsDataset, cls).splits(exts, fields, path, root, train, validation, test, **kwargs)