wybxc commited on
Commit
3cd2d35
1 Parent(s): 1f2f4d0

feat: init

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. README.md +6 -3
  3. app.py +82 -0
  4. requirements.txt +2 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .vscode
2
+ __pycache__
README.md CHANGED
@@ -1,12 +1,15 @@
1
  ---
2
  title: NeoYiri
3
- emoji:
4
  colorFrom: pink
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.24.1
 
8
  app_file: app.py
9
- pinned: false
 
 
10
  license: apache-2.0
11
  ---
12
 
 
1
  ---
2
  title: NeoYiri
3
+ emoji: 🥳
4
  colorFrom: pink
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.24.1
8
+ python_version: 3.10.9
9
  app_file: app.py
10
+ models:
11
+ - wybxc/new-yiri
12
+ pinned: true
13
  license: apache-2.0
14
  ---
15
 
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import cast
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import BertTokenizerFast, ErnieForCausalLM
6
+
7
+
8
+ def load_model():
9
+ tokenizer = BertTokenizerFast.from_pretrained("wybxc/new-yiri")
10
+ assert isinstance(tokenizer, BertTokenizerFast)
11
+ model = ErnieForCausalLM.from_pretrained("wybxc/new-yiri")
12
+ assert isinstance(model, ErnieForCausalLM)
13
+
14
+ return tokenizer, model
15
+
16
+
17
+ def generate(
18
+ tokenizer: BertTokenizerFast,
19
+ model: ErnieForCausalLM,
20
+ input_str: str,
21
+ alpha: float,
22
+ topk: int,
23
+ ):
24
+ input_ids = tokenizer.encode(input_str, return_tensors="pt")
25
+ input_ids = cast(torch.Tensor, input_ids)
26
+ outputs = model.generate(
27
+ input_ids,
28
+ max_new_tokens=100,
29
+ penalty_alpha=alpha,
30
+ top_k=topk,
31
+ early_stopping=True,
32
+ decoder_start_token_id=tokenizer.sep_token_id,
33
+ eos_token_id=tokenizer.sep_token_id,
34
+ )
35
+ i, *_ = torch.nonzero(outputs[0] == tokenizer.sep_token_id)
36
+ output = tokenizer.decode(
37
+ outputs[0, i:],
38
+ skip_special_tokens=True,
39
+ ).replace(" ", "")
40
+ return output
41
+
42
+
43
+ with gr.Blocks() as demo:
44
+ with gr.Row():
45
+ with gr.Column(scale=3):
46
+ chatbot = gr.Chatbot().style(height=500)
47
+ with gr.Row():
48
+ with gr.Column(scale=4):
49
+ msg = gr.Textbox(
50
+ show_label=False, placeholder="Enter text and press enter"
51
+ ).style(container=False)
52
+ msg = cast(gr.Textbox, msg)
53
+ with gr.Column(scale=1):
54
+ button = gr.Button("Generate")
55
+ with gr.Column(scale=1):
56
+ clear = gr.Button("Clear")
57
+ with gr.Column(scale=1):
58
+ alpha = gr.Slider(0, 1, 0.5, step=0.01, label="Penalty Alpha")
59
+ topk = gr.Slider(1, 50, 5, step=1, label="Top K")
60
+
61
+ tokenizer, model = load_model()
62
+
63
+ def on_message(
64
+ user_message: str, history: list[list[str]], alpha: float, topk: int
65
+ ):
66
+ bot_message = generate(
67
+ tokenizer,
68
+ model,
69
+ user_message,
70
+ alpha=alpha,
71
+ topk=topk,
72
+ )
73
+ return "", [*history, [user_message, bot_message]]
74
+
75
+ msg.submit(on_message, inputs=[msg, chatbot, alpha, topk], outputs=[msg, chatbot])
76
+ button.click(on_message, inputs=[msg, chatbot, alpha, topk], outputs=[msg, chatbot])
77
+
78
+ clear.click(lambda: None, None, chatbot)
79
+
80
+ if __name__ == "__main__":
81
+ demo.queue(concurrency_count=3)
82
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ torch