|
import gradio as gr |
|
import torch |
|
from transformers import GPT2TokenizerFast, GPT2LMHeadModel |
|
|
|
from gpt2_knn_attention import GPT2KNNAttention |
|
from knn_memory import KNNLayer, ClearMemoryLayer |
|
|
|
|
|
def inject_knn_in_gpt2(model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8): |
|
layer = model.transformer.h[layer_ind].attn |
|
state = layer.state_dict() |
|
knn_layer = GPT2KNNAttention( |
|
config, knn_memory, device, is_cross_attention=False, layer_idx=layer.layer_idx) |
|
knn_state = knn_layer.state_dict() |
|
|
|
for k, v in state.items(): |
|
knn_state[k] = v |
|
|
|
knn_layer.load_state_dict(knn_state) |
|
|
|
model.transformer.h[8].attn = knn_layer |
|
model.transformer = ClearMemoryLayer( |
|
knn_memory, bos_token_id, eos_token_id, model.transformer) |
|
model.eval() |
|
|
|
|
|
model_name = "gpt2" |
|
tokenizer = GPT2TokenizerFast.from_pretrained(model_name) |
|
model = GPT2LMHeadModel.from_pretrained(model_name) |
|
config = model.config |
|
model.eval() |
|
|
|
knn_memory = KNNLayer(config, share_memory=False, batch_size=1) |
|
bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id |
|
bos_token, eos_token = tokenizer.bos_token, tokenizer.eos_token |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
inject_knn_in_gpt2( |
|
model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8) |
|
model.load_state_dict(torch.load('gpt2_knn_attention.pt')) |
|
|
|
|
|
def generate(text, temperature, max_new_tokens, top_p): |
|
encoded_input = tokenizer(text, return_tensors='pt') |
|
output = model.generate(**encoded_input, do_sample=True, |
|
max_new_tokens=int(max_new_tokens), temperature=temperature, top_p=top_p) |
|
return tokenizer.decode(output[0]) |
|
|
|
|
|
desc = "Попытка повторить статью от Google [Memorizing Transformers](https://arxiv.org/abs/2203.08913). "\ |
|
"В ней вводиться новый слой **KNNAttention**, который использует approximate kNN в базе с (key, value), чтобы делать attention по большому контексту. Это позволяет расширить контекст трансформера до размера книг и статей, несильно замедляя его.\n\n"\ |
|
"Я написал свои **KNNAttention**, переписал слой **GPT2Attention**, чтобы он использовал KNNAttention, а также написал несколько вспомогательный классов для всего этого.\n\n"\ |
|
"Я сел писать это за **3 недели** до дедлайна, но все равно не довел до результата, которого изначально хотел. Но я доволен проделанной работой :)" |
|
|
|
|
|
demo = gr.Interface( |
|
fn=generate, |
|
inputs=[gr.inputs.Textbox(lines=5, label="Input Text"), |
|
gr.Slider(0.001, 2.0, step=0.05, value=0.8, label='temperature'), |
|
gr.Slider(1, 512, step=1, value=32, label='max_new_tokens'), |
|
gr.Slider(0.1, 1.0, step=0.02, value=0.92, label='top_p')], |
|
outputs=gr.outputs.Textbox(label="Generated Text"), |
|
description=desc, |
|
title="Memorizing Transformers", |
|
examples=[ |
|
["My name is Lewis and I like to", 0.8, 32, 0.92] |
|
] |
|
) |
|
|
|
demo.launch() |
|
|