File size: 5,535 Bytes
4a2c956
7c790c0
 
 
61b9ff7
 
 
f19951d
0fe023b
d7e18bd
db3b05f
 
 
7c790c0
4a2c956
7c790c0
 
 
 
0fe023b
de61ce6
7c790c0
 
 
 
 
 
 
f943f56
7c790c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61b9ff7
 
 
7c790c0
 
 
 
 
 
ea2eccb
7c790c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0083156
 
7c790c0
 
 
d9ddd86
 
d9f0585
 
 
 
 
 
a8cdd01
d9f0585
 
 
 
 
 
 
 
 
a8cdd01
d9ddd86
 
 
c87c557
d9ddd86
 
 
 
7c790c0
 
 
 
 
2b73ce3
7c790c0
 
2b73ce3
c87c557
7c790c0
d9ddd86
 
 
7c790c0
d7e18bd
7c790c0
 
 
 
 
 
 
 
 
c87c557
7c790c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import os
from datetime import datetime
from huggingface_hub import hf_hub_download
from pynvml import *
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)
ctx_limit = 2048
title = "RWKV-4-Pile-14B-20230313-ctx8192-test1050"
desc = f'''Links:
<a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a>
<a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a>
<a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a>
'''

os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)

from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-14b", filename=f"{title}.pth")
model = RWKV(model=model_path, strategy='cuda fp16i8 *20 -> cuda fp16')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "20B_tokenizer.json")

def infer(
        ctx,
        token_count=10,
        temperature=1.0,
        top_p=0.8,
        presencePenalty = 0.1,
        countPenalty = 0.1,
):
    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                     alpha_frequency = countPenalty,
                     alpha_presence = presencePenalty,
                     token_ban = [0], # ban the generation of some tokens
                     token_stop = []) # stop generation whenever you see any token here

    ctx = ctx.strip(' ')
    if ctx.endswith('\n'):
        ctx = f'\n{ctx.strip()}\n'
    else:
        ctx = f'\n{ctx.strip()}'

    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
    
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    state = None
    for i in range(int(token_count)):
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        for n in args.token_ban:
            out[n] = -float('inf')
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens += [token]
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1
        
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:
            out_str += tmp
            yield out_str.strip()
            out_last = i + 1
    gc.collect()
    torch.cuda.empty_cache()
    yield out_str.strip()

examples = [
    ["Expert Questions & Helpful Answers\nAsk Research Experts\nQuestion:\nHow can we eliminate poverty?\n\nFull Answer:\n", 150, 1.0, 0.7, 0.2, 0.2],
    ["Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n", 150, 1.0, 0.7, 0.2, 0.2],
    ['''Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Generate a list of adjectives that describe a person as brave.

### Response:
''', 150, 1.0, 0.2, 0.5, 0.5],
    ['''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Arrange the given numbers in ascending order.

### Input:
2, 4, 0, 8, 3

### Response:
''', 150, 1.0, 0.2, 0.5, 0.5],
    ["Ask Expert\n\nQuestion:\nWhat are some good plans for world peace?\n\nExpert Full Answer:\n", 150, 1.0, 0.7, 0.2, 0.2],
    ["Q & A\n\nQuestion:\nWhy is the sky blue?\n\nDetailed Expert Answer:\n", 150, 1.0, 0.7, 0.2, 0.2],
    ["Dear sir,\nI would like to express my boundless apologies for the recent nuclear war.", 150, 1.0, 0.7, 0.2, 0.2],
    ["Here is a shell script to find all .hpp files in /home/workspace and delete the 3th row string of these files:", 150, 1.0, 0.7, 0.1, 0.1],
    ["Building a website can be done in 10 simple steps:\n1.", 150, 1.0, 0.7, 0.2, 0.2],
    ["A Chinese phrase is provided: 百闻不如一见。\nThe masterful Chinese translator flawlessly translates the phrase into English:", 150, 1.0, 0.5, 0.2, 0.2],
    ["I believe the meaning of life is", 150, 1.0, 0.7, 0.2, 0.2],
    ["Simply put, the theory of relativity states that", 150, 1.0, 0.5, 0.2, 0.2],
]


iface = gr.Interface(
    fn=infer,
    description=f'''{desc} *** <b>Please try examples first (bottom of page)</b> *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
    allow_flagging="never",
    inputs=[
        gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"),  # prompt
        gr.Slider(10, 200, step=10, value=150),  # token_count
        gr.Slider(0.2, 2.0, step=0.1, value=1.0),  # temperature
        gr.Slider(0.0, 1.0, step=0.05, value=0.7),  # top_p
        gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # presencePenalty
        gr.Slider(0.0, 1.0, step=0.1, value=0.2),  # countPenalty
    ],
    outputs=gr.Textbox(label="Generated Output", lines=28),
    examples=examples,
    cache_examples=False,
).queue()

demo = gr.TabbedInterface(
    [iface], ["Generative"],
    title=title,
)

demo.queue(max_size=10)
demo.launch(share=False)