import random
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from .model_overrides import get_forward


# Pool one sentence embedding per layer: used together with the overridden
# forward, which returns the hidden states of every layer in a single pass
def encode_custom(forward, encoder, sentence_feature):
    # Temporarily strip the embed mask so it is not passed to the model's forward
    embed_mask = None
    if "embed_mask" in sentence_feature:
        embed_mask = sentence_feature.pop("embed_mask")
    _, reps = forward(encoder.model, **sentence_feature)
    sentence_feature["embed_mask"] = embed_mask
    # Pool each layer's hidden states into a sentence embedding
    return [encoder.get_pooling(sentence_feature, emb) for emb in reps]


def l3prune(encoder, dataset, loss_fn, batch_size=64, num_batches=100):
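    """Find the best small and large layer-pruning points for `encoder`.

    Samples `batch_size * num_batches` (query, positive, negative) triplets
    from `dataset`, measures `loss_fn` at every layer depth, and returns the
    depths with the lowest mean loss in the first and second halves of the
    network.
    """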
    # Materialize the dataset and sample a fixed number of triplets
    dataset = [t for t in dataset]
    subset = random.sample(dataset, batch_size * num_batches)
    subset = [[encoder.prepare_for_tokenization(t) for t in s.texts] for s in subset]
    subset = [subset[i:i + batch_size] for i in range(0, len(subset), batch_size)]
    num_layers = encoder.model.config.num_hidden_layers
    # Collect the losses observed at every candidate pruning depth (1..num_layers)
    loss = {i: [] for i in range(1, num_layers + 1)}
    forward = get_forward(encoder.model)
    with torch.no_grad():
        # With an override available, a single forward pass exposes the
        # intermediate representations of every layer at once
        if forward:
            encode = partial(encode_custom, forward)
            for batch in tqdm(subset):
                features = []
                # Encode queries, positive documents, and negative documents
                for j in range(3):
                    embs = [t[j] for t in batch]
                    embs = encoder.tokenize(embs).to(encoder.model.device)
                    embs = encode(encoder, embs)
                    features += [embs]
                q, d, d_neg = features
                # Score the triplet loss at every layer depth
                for i in range(num_layers):
                    loss[i + 1] += [loss_fn(q[i], d[i], d_neg[i])]
        else:
            # Without the override, rerun the full forward pass once per
            # candidate depth, pruning the model to that depth each time
            for layer in range(num_layers, 0, -1):
                encoder.prune(layer_prune=layer)
                for batch in tqdm(subset):
                    features = []
                    for j in range(3):
                        embs = [t[j] for t in batch]
                        embs = encoder.tokenize(embs).to(encoder.model.device)
                        embs = encoder.forward(embs)
                        features += [embs]
                    q, d, d_neg = features
                    loss[layer] += [loss_fn(q, d, d_neg)]
    # Average the recorded losses per depth
    loss = [torch.tensor(loss[i]).mean().float().detach() for i in range(1, num_layers + 1)]
    # Take the lowest-loss depth in each half of the network: a small pruned
    # configuration from the first half and a large one from the second
    midpoint = num_layers // 2
    small_p = np.argmin(loss[:midpoint]) + 1
    large_p = np.argmin(loss[midpoint:]) + midpoint + 1
    return small_p, large_p
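

# A minimal usage sketch, not part of the module: `encoder`, `dataset`, and
# `loss_fn` are hypothetical stand-ins for the interfaces used above, i.e. an
# encoder exposing tokenize/forward/prune/get_pooling, a dataset of items with
# a 3-element `.texts` (query, positive, negative), and a triplet loss.
#
#   small_p, large_p = l3prune(encoder, dataset, loss_fn,
#                              batch_size=64, num_batches=100)
#   encoder.prune(layer_prune=small_p)  # keep only the first small_p layers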