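"""Benchmark tagged checkpoints of the distributed-training GPT-2 repos.

For every epoch tag pushed to the Hub repo, load that revision, measure its
average language-modelling loss on a random slice of the dataset, persist the
scores to results.json, and prune stale revisions from the local HF cache.
"""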

import json
import os
import random
import time

import torch
from distributed_training.data.dataset import DataLoader
from huggingface_hub import list_repo_refs, scan_cache_dir
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
test_indices_length = 1000  # number of dataset rows scored per evaluation run
AUTOMATE = True

models = [
    "distributed/optimized-gpt2-2b",
    "distributed/optimized-gpt2-1b",
    "distributed/optimized-gpt2-500m",
    "distributed/optimized-gpt2-250m",
    "distributed/optimized-gpt2-250m-v0.1.3",
    "distributed/optimized-gpt2-250m-v0.1.1",
    "distributed/gpt2-94m",
]
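
# NOTE: only models[0] is evaluated in the loop below; widen the slice to
# benchmark the other checkpoints.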

if os.path.exists("results.json"):
    with open("results.json", "r") as file:
        results = json.load(file)
else:
    results = {}
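
# Poll indefinitely: whenever a new epoch tag appears on the Hub, evaluate it.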
while True:
    for model_name in [models[0]]:
        results.setdefault(model_name, {}).setdefault("main-net", {})

        tokenizer = AutoTokenizer.from_pretrained(
            "distributed/optimized-gpt2-250m", trust_remote_code=True
        )

        refs = list_repo_refs(model_name, repo_type="model")
        global_epoch = max(int(tag.name) for tag in refs.tags) if refs.tags else None

        if global_epoch is None or str(global_epoch) in results[model_name]["main-net"]:
            print(f"Results for epoch {global_epoch} already calculated")
            time.sleep(30 * 60)
            continue

        # Evaluate every tagged epoch up to and including the newest one
        for epoch in range(global_epoch + 1):
            if str(epoch) in results[model_name]["main-net"]:
                continue
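
            # Pull the checkpoint tagged with this epoch from the Hub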
            model = AutoModelForCausalLM.from_pretrained(
                model_name, revision=str(epoch), trust_remote_code=True
            )
            model = model.to(device)
            model.eval()  # inference only; disable dropout

            # Score on a random contiguous block of dataset rows
            search_start = random.choice(
                range(DataLoader.max_pages - test_indices_length + 1)
            )
            group = list(range(search_start, search_start + test_indices_length))
            dataloader = DataLoader(
                batch_size=1,
                sequence_length=1024,
                rows=group,
            )
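
            # Evaluate one pass over the sampled rows. The "optimized"
            # checkpoints return the loss at index 1 of the output tuple;
            # stock checkpoints expose it as outputs.loss.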
            total_loss = 0
            index = 0
            for index, batch in enumerate(dataloader):
                inputs = batch[0].to(device)
                labels = batch[1].to(device)

                # Guard against misaligned input/label pairs
                if len(inputs[0]) != len(labels[0]):
                    breakpoint()

                # Forward pass only; no gradients are needed for evaluation
                with torch.no_grad():
                    if "optimized" in model_name:
                        outputs = model(input_ids=inputs, labels=labels)
                        loss = outputs[1]
                    else:
                        outputs = model(input_ids=inputs, labels=inputs)
                        loss = outputs.loss

                # Accumulate total loss
                total_loss += loss.item()

            average_loss = total_loss / (index + 1)
            results[model_name]["main-net"][str(epoch)] = [average_loss]
            print(f"Epoch: {epoch} Average Loss: {average_loss:.2f}")

            with open("results.json", "w") as outfile:
                json.dump(results, outfile, indent=4)
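
            # Since the loss is mean token cross-entropy, per-epoch perplexity
            # can be recovered later as math.exp(average_loss) if needed.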

            # Prune older revisions of this repo from the local HF cache
            current_revision = model.config._commit_hash
            keep_recent = 1
            try:
                cache_info = scan_cache_dir()
                for repo in cache_info.repos:
                    if repo.repo_id == model_name:
                        revisions = sorted(
                            repo.revisions, key=lambda r: r.last_modified, reverse=True
                        )
                        current_index = next(
                            (
                                i
                                for i, r in enumerate(revisions)
                                if r.commit_hash == current_revision
                            ),
                            None,
                        )
                        if current_index is not None:
                            # Delete everything older than the revisions we keep
                            for revision in revisions[max(current_index + 1, keep_recent):]:
                                cache_info.delete_revisions(revision.commit_hash).execute()
                        break
            except Exception:
                print(
                    "Failed to delete previous model version from cache. This might "
                    "lead to 100% disk space utilisation in the future."
                )
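
            # Optional cleanup (assumption: releasing this revision's weights
            # avoids GPU memory build-up when many revisions are evaluated
            # back to back)
            del model
            torch.cuda.empty_cache()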