File size: 3,337 Bytes
6bc49a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e06d26
6bc49a9
4e06d26
6bc49a9
 
 
 
 
 
 
 
 
e008050
6bc49a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
import torch
from transformers import GPT2TokenizerFast, GPT2LMHeadModel

from gpt2_knn_attention import GPT2KNNAttention
from knn_memory import KNNLayer, ClearMemoryLayer


def inject_knn_in_gpt2(model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8):
    layer = model.transformer.h[layer_ind].attn
    state = layer.state_dict()
    knn_layer = GPT2KNNAttention(
        config, knn_memory, device, is_cross_attention=False, layer_idx=layer.layer_idx)
    knn_state = knn_layer.state_dict()

    for k, v in state.items():
        knn_state[k] = v

    knn_layer.load_state_dict(knn_state)

    model.transformer.h[8].attn = knn_layer
    model.transformer = ClearMemoryLayer(
        knn_memory, bos_token_id, eos_token_id, model.transformer)
    model.eval()


model_name = "gpt2"
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
config = model.config
model.eval()

knn_memory = KNNLayer(config, share_memory=False, batch_size=1)
bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id
bos_token, eos_token = tokenizer.bos_token, tokenizer.eos_token

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inject_knn_in_gpt2(
    model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8)
model.load_state_dict(torch.load('gpt2_knn_attention.pt'))


def generate(text, temperature, max_new_tokens, top_p):
    encoded_input = tokenizer(text, return_tensors='pt')
    output = model.generate(**encoded_input, do_sample=True,
                            max_new_tokens=int(max_new_tokens), temperature=temperature, top_p=top_p)
    return tokenizer.decode(output[0])


desc = "Попытка повторить статью от Google [Memorizing Transformers](https://arxiv.org/abs/2203.08913). "\
       "В ней вводиться новый слой **KNNAttention**, который использует approximate kNN в базе с (key, value), чтобы делать attention по большому контексту. Это позволяет расширить контекст трансформера до размера книг и статей, несильно замедляя его.\n\n"\
       "Я написал свои **KNNAttention**, переписал слой **GPT2Attention**, чтобы он использовал KNNAttention, а также написал несколько вспомогательный классов для всего этого.\n\n"\
       "Я сел писать это за **3 недели** до дедлайна, но все равно не довел до результата, которого изначально хотел. Но я доволен проделанной работой :)"


demo = gr.Interface(
    fn=generate,
    inputs=[gr.inputs.Textbox(lines=5, label="Input Text"),
            gr.Slider(0.001, 2.0, step=0.05, value=0.8, label='temperature'),
            gr.Slider(1, 512, step=1, value=32, label='max_new_tokens'),
            gr.Slider(0.1, 1.0, step=0.02, value=0.92, label='top_p')],
    outputs=gr.outputs.Textbox(label="Generated Text"),
    description=desc,
    title="Memorizing Transformers",
    examples=[
        ["My name is Lewis and I like to", 0.8, 32, 0.92]
    ]
)

demo.launch()