import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime


class MyDataset(Dataset):
    def __init__(self, args):
        self.args = args

        if args.data_type == "binidx":
            self.vocab_size = args.vocab_size
            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")

            if args.my_pile_version == 1:
                self.data = MMapIndexedDataset(args.data_file)
                self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
                rank_zero_info(f"Data has {self.data_size} tokens.")
            else:
                with open(args.data_file, "r", encoding='utf-8') as f:
                    data_list = f.read().strip().split('\n')
                data_list = [i.strip().split(' ') for i in data_list]
                self.data = []
                self.data_size = int(data_list[-1][-1])
                rank_zero_info(f"Data has {self.data_size} chunks.")
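                # Each line of this data_file appears to describe one binidx dataset: the path prefix
                # first, a per-file size/chunk count second, and a cumulative cutoff last (inferred
                # from the assert below and from the cutoff-based routing in __getitem__).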
                for d in data_list:
                    data = MMapIndexedDataset(d[0])
                    data_size = len(data._bin_buffer) // data._index._dtype_size
                    assert (data_size - args.ctx_len) == int(d[1])
                    self.data += [[int(d[-1]), int(d[1]), data]]

            if args.my_qa_mask > 0:
                self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
                self.data_pile_size = len(self.data_pile._bin_buffer) // self.data_pile._index._dtype_size
            else:
                self.data_pile = None
                self.data_pile_size = 0

            if args.my_pile_stage > 0:
                self.samples_per_epoch = args.epoch_steps * args.real_bsz
                assert self.samples_per_epoch == 40320
                rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
                dataset_slot = self.data_size // args.ctx_len
                if args.my_pile_stage != 4:
                    # magic_prime must be a prime p with p % 3 == 2, chosen just below data_size // ctx_len
                    assert MaybeIsPrime(args.magic_prime)
                    assert args.magic_prime % 3 == 2
                    assert 0.99 < args.magic_prime / dataset_slot <= 1

        elif args.data_type == "numpy":
            self.data = np.load(args.data_file).astype("int")
            self.vocab_size = args.vocab_size
            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
            self.data_size = len(self.data)
            rank_zero_info(f"Data has {self.data_size} tokens.")
        elif args.data_type == "uint16":
            self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
            self.vocab_size = args.vocab_size
            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
            self.data_size = self.data.shape[0]
            rank_zero_info(f"Data has {self.data_size} samples.")
        elif args.data_type == "wds_img":
            self.vocab_size = -1
            self.data_size = -1
            self.data = None
            self.error_count = 0
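            # the actual WebDataset pipeline is built lazily, per rank/epoch, inside __getitem__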
        else:
            if args.data_type == "dummy":
                rank_zero_info("Building dummy data...")
                self.data = ""
                for i in range(100000):
                    aa = i % 10000
                    bb = (i * i) % 10000
                    cc = aa + bb
                    self.data += f".{aa}+{bb}={cc}."
            else:
                # for plain text data, data_type doubles as the file encoding (e.g. "utf-8")
                with open(args.data_file, "r", encoding=args.data_type) as f:
                    self.data = f.read()
            rank_zero_info("Building token list...")
            unique = sorted(list(set(self.data)))
            self.vocab_size = len(unique)

            xxObj = {i: u for i, u in enumerate(unique)}
            with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
                vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
            self.data_size = len(self.data)
            rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = {i: ch for i, ch in enumerate(unique)}

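    # NOTE: an "epoch" here is a fixed budget of optimizer steps (args.epoch_steps), not one pass
    # over the data, so __len__ reports epoch_steps * micro_bsz samples per process.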
    def __len__(self):
        return self.args.epoch_steps * self.args.micro_bsz

    def __getitem__(self, idx):
        args = self.args
        rank = self.global_rank
        epoch = self.real_epoch
        world_size = self.world_size

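        # global_rank / real_epoch / world_size are not set by __init__; the training loop is
        # expected to assign them on the dataset object (e.g. once per epoch) before sampling.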
        if args.data_type == "wds_img":
            def init_wds(self, bias=0):
                def identity(x):
                    return x
                import webdataset as wds
                import torchvision.transforms as transforms
                img_transform = transforms.Compose([
                    transforms.CenterCrop(512),
                    transforms.Resize(args.my_img_size)
                ])
                self.data_raw = (
                    wds.WebDataset(args.data_file, resampled=True)
                    .shuffle(10000, initial=1000, rng=random.Random(epoch * 100000 + rank + bias * 1e9))
                    .decode("torchrgb")
                    .to_tuple("jpg", "json", "txt")
                    .map_tuple(img_transform, identity, identity)
                )
                # pin the resampling RNG so each rank/epoch gets a reproducible, distinct shuffle
                for pp in self.data_raw.pipeline:
                    if 'Resampled' in str(pp):
                        pp.deterministic = True
                        def worker_seed():
                            return rank * 100000 + epoch + bias * 1e9
                        pp.worker_seed = worker_seed
                self.data = iter(self.data_raw)

            if self.data is None:
                init_wds(self)
            trial = 0
            while trial < 10:
                try:
                    dd = next(self.data)  # (jpg, json, txt)
                    break
                except Exception:
                    print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
                    self.error_count += 1
                    init_wds(self, self.error_count)
                    trial += 1
            return dd[0], dd[2]
        else:
            if args.data_type == "uint16":
                i = np.random.randint(0, self.data_size - 1)
                dix = self.data[i]
                x = torch.tensor(dix[:-1], dtype=torch.long)
                y = torch.tensor(dix[1:], dtype=torch.long)
            else:
                ctx_len = args.ctx_len
                req_len = ctx_len + 1
                magic_prime = args.magic_prime
                data = self.data

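                # pile-style training: derive a unique global step index from (epoch, batch idx, rank),
                # then map it to a position in the data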
                if args.my_pile_stage > 0:
                    ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank

                    if args.my_qa_mask > 0:
                        ii_orig = ii
                        if ii % 2 == 0:
                            # even global sample indices are drawn from the pile data instead
                            # (these get an all-ones mask below)
                            ii = -1
                            data = self.data_pile
                        else:
                            ii = ii // 2
                    if data is self.data_pile:
                        i = np.random.randint(0, self.data_pile_size - req_len)
                    else:
                        if args.my_pile_stage == 4 or ii < args.my_random_steps:
                            # stage 4, or the first my_random_steps samples: pick a random spot
                            if args.my_pile_version == 1:
                                i = np.random.randint(0, self.data_size - req_len)
                            else:
                                i = np.random.randint(0, self.data_size)
                        else:
                            # deterministic walk: map ii to a unique chunk via a cubic permutation mod magic_prime
                            ii = ii - args.my_random_steps
                            factor = (math.sqrt(5) - 1) / 2
                            factor = int(magic_prime * factor)
                            i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
                            i = i + args.my_pile_shift

                else:
                    i = np.random.randint(0, self.data_size - req_len)

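                # Note on the deterministic branch above: magic_prime p satisfies p % 3 == 2, so
                # gcd(3, p - 1) == 1 and x -> x**3 is a bijection mod p; multiplying by the fixed
                # golden-ratio factor keeps it a bijection, so successive step indices ii land on
                # distinct ctx_len-sized chunks and cover (almost) the whole dataset before repeating.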
                if args.data_type == "binidx":
                    if args.my_pile_version == 1:
                        dix = data.get(idx=0, offset=i, length=req_len).astype(int)
                    else:
                        # data is a list of [cumulative_cutoff, chunk_count, dataset]:
                        # route i to the first entry whose cutoff exceeds it, then subtract the
                        # previous cutoff and wrap by that file's chunk count
                        for j in range(len(data)):
                            if i < data[j][0]:
                                ii = i
                                i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
                                dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
                                break
                elif args.data_type == "numpy":
                    dix = data[i : i + req_len]
                else:
                    dix = [self.stoi[s] for s in data[i : i + req_len]]

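                # Optional QA mask: z is a per-token mask (intended for weighting the loss). Pile
                # samples get an all-ones mask; QA-formatted samples get 1s only from the point
                # where the delimiter token pattern (187, 187, 34, 27) is seen until the next 0
                # token (the exact ids depend on the tokenizer used to build the data).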
                if args.my_qa_mask == 1:
                    if data is self.data_pile:
                        z = [1] * ctx_len
                    else:
                        z = [0] * ctx_len
                        z_sum = 0
                        isGood = False
                        for i in range(3, ctx_len):
                            if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
                                isGood = True
                            if dix[i] == 0:
                                isGood = False
                            if isGood:
                                z[i] = 1
                                z_sum += 1
                        if z_sum == 0:
                            # no answer span found: fall back to a random pile sample with an all-ones mask
                            z = [1] * ctx_len
                            i = np.random.randint(0, self.data_pile_size - req_len)
                            dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
                    z = torch.tensor(z, dtype=torch.bfloat16)

                x = torch.tensor(dix[:-1], dtype=torch.long)
                y = torch.tensor(dix[1:], dtype=torch.long)

                if args.my_qa_mask == 1:
                    return x, y, z

                return x, y
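
# Rough usage sketch (for illustration only; in the real training setup, args is the training
# script's argparse namespace and must carry every field touched above for the chosen data_type):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(data_type="dummy", data_file="", proj_dir=".", ctx_len=128,
#                          epoch_steps=100, micro_bsz=2, my_pile_stage=0, my_qa_mask=0,
#                          magic_prime=0)
#   ds = MyDataset(args)                                      # for "dummy", also writes vocab.json
#   ds.global_rank, ds.real_epoch, ds.world_size = 0, 0, 1    # normally set by the training loop
#   x, y = ds[0]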