ttphong68 commited on
Commit
b662551
1 Parent(s): 56ec216

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +63 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
2
+ import torch
3
+
4
+
5
+ chat_tkn = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
6
+ mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
7
+
8
+
9
+ #chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
10
+ #mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
11
+
12
+ def converse(user_input, chat_history=[]):
13
+
14
+ user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
15
+
16
+ # keep history in the tensor
17
+ bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
18
+
19
+ # get response
20
+ chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
21
+ print (chat_history)
22
+
23
+
24
+ response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
25
+
26
+ print("starting to print response")
27
+ print(response)
28
+
29
+ # html for display
30
+ html = "<div class='mybot'>"
31
+ for x, mesg in enumerate(response):
32
+ if x%2!=0 :
33
+ mesg="Alicia:"+mesg
34
+ clazz="alicia"
35
+ else :
36
+ clazz="user"
37
+
38
+
39
+ print("value of x")
40
+ print(x)
41
+ print("message")
42
+ print (mesg)
43
+
44
+ html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
45
+ html += "</div>"
46
+ print(html)
47
+ return html, chat_history
48
+
49
+ import gradio as grad
50
+
51
+ css = """
52
+ .mychat {display:flex;flex-direction:column}
53
+ .mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
54
+ .mesg.user {background-color:lightblue;color:white}
55
+ .mesg.alicia {background-color:orange;color:white,align-self:self-end}
56
+ .footer {display:none !important}
57
+ """
58
+ text=grad.inputs.Textbox(placeholder="Lets chat")
59
+ grad.Interface(fn=converse,
60
+ theme="default",
61
+ inputs=[text, "state"],
62
+ outputs=["html", "state"],
63
+ css=css).launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ transformers[sentencepiece]