lavawolfiee's picture
Changed md link format
e008050
import gradio as gr
import torch
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from gpt2_knn_attention import GPT2KNNAttention
from knn_memory import KNNLayer, ClearMemoryLayer
def inject_knn_in_gpt2(model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8):
layer = model.transformer.h[layer_ind].attn
state = layer.state_dict()
knn_layer = GPT2KNNAttention(
config, knn_memory, device, is_cross_attention=False, layer_idx=layer.layer_idx)
knn_state = knn_layer.state_dict()
for k, v in state.items():
knn_state[k] = v
knn_layer.load_state_dict(knn_state)
model.transformer.h[8].attn = knn_layer
model.transformer = ClearMemoryLayer(
knn_memory, bos_token_id, eos_token_id, model.transformer)
model.eval()
model_name = "gpt2"
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
config = model.config
model.eval()
knn_memory = KNNLayer(config, share_memory=False, batch_size=1)
bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id
bos_token, eos_token = tokenizer.bos_token, tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inject_knn_in_gpt2(
model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8)
model.load_state_dict(torch.load('gpt2_knn_attention.pt'))
def generate(text, temperature, max_new_tokens, top_p):
encoded_input = tokenizer(text, return_tensors='pt')
output = model.generate(**encoded_input, do_sample=True,
max_new_tokens=int(max_new_tokens), temperature=temperature, top_p=top_p)
return tokenizer.decode(output[0])
desc = "Попытка повторить статью от Google [Memorizing Transformers](https://arxiv.org/abs/2203.08913). "\
"В ней вводиться новый слой **KNNAttention**, который использует approximate kNN в базе с (key, value), чтобы делать attention по большому контексту. Это позволяет расширить контекст трансформера до размера книг и статей, несильно замедляя его.\n\n"\
"Я написал свои **KNNAttention**, переписал слой **GPT2Attention**, чтобы он использовал KNNAttention, а также написал несколько вспомогательный классов для всего этого.\n\n"\
"Я сел писать это за **3 недели** до дедлайна, но все равно не довел до результата, которого изначально хотел. Но я доволен проделанной работой :)"
demo = gr.Interface(
fn=generate,
inputs=[gr.inputs.Textbox(lines=5, label="Input Text"),
gr.Slider(0.001, 2.0, step=0.05, value=0.8, label='temperature'),
gr.Slider(1, 512, step=1, value=32, label='max_new_tokens'),
gr.Slider(0.1, 1.0, step=0.02, value=0.92, label='top_p')],
outputs=gr.outputs.Textbox(label="Generated Text"),
description=desc,
title="Memorizing Transformers",
examples=[
["My name is Lewis and I like to", 0.8, 32, 0.92]
]
)
demo.launch()