boazchung committed on
Commit
f4807cc
1 Parent(s): b1e927a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ def get_pipe():
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ model_name = "heegyu/koalpaca-355m"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ tokenizer.truncation_side = "right"
8
+ model = AutoModelForCausalLM.from_pretrained(model_name)
9
+ return model, tokenizer
10
+
11
+ def get_response(tokenizer, model, context):
12
+ context = f"<usr>{context}\n<sys>"
13
+ inputs = tokenizer(
14
+ context,
15
+ truncation=True,
16
+ max_length=512,
17
+ return_tensors="pt")
18
+
19
+ generation_args = dict(
20
+ max_length=256,
21
+ min_length=64,
22
+ eos_token_id=2,
23
+ do_sample=True,
24
+ top_p=1.0,
25
+ early_stopping=True
26
+ )
27
+
28
+ outputs = model.generate(**inputs, **generation_args)
29
+ response = tokenizer.decode(outputs[0])
30
+ print(context)
31
+ print(response)
32
+ response = response[len(context):].replace("</s>", "")
33
+
34
+ return response
35
+
36
# Load the model and tokenizer once at import time so every request reuses them.
model, tokenizer = get_pipe()
37
+
38
+ def ask_question(input_):
39
+ response = get_response(tokenizer, model, input_)
40
+ return response
41
+
42
+ gr.Interface(fn=ask_question, inputs="text", outputs="text", title="KoAlpaca-355M", description="한국어로 질문하세요.").launch()