run480 commited on
Commit
415835e
1 Parent(s): 926adcf

Update app.py

Browse files

Microsoft DialoGPT model

Files changed (1) hide show
  1. app.py +73 -15
app.py CHANGED
@@ -371,23 +371,81 @@
371
  #-----------------------------------------------------------------------------------
372
  # 16. Text-to-Text Generation using the T5 model - Task 7 check whether a statement deduced from a text is correct or not.
373
 
374
- from transformers import T5ForConditionalGeneration, T5Tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  import gradio as grad
376
 
377
- text2text_tkn= T5Tokenizer.from_pretrained("t5-small")
378
- mdl = T5ForConditionalGeneration.from_pretrained("t5-small")
 
 
 
 
 
379
 
380
- def text2text_deductible(sentence1,sentence2):
381
- inp1 = "rte sentence1: "+sentence1
382
- inp2 = "sentence2: "+sentence2
383
- combined_inp=inp1+" "+inp2
384
- enc = text2text_tkn(combined_inp, return_tensors="pt")
385
- tokens = mdl.generate(**enc)
386
- response=text2text_tkn.batch_decode(tokens)
387
- return response
388
 
389
- sent1=grad.Textbox(lines=1, label="Sentence1", placeholder="Text in English")
390
- sent2=grad.Textbox(lines=1, label="Sentence2", placeholder="Text in English")
391
- out=grad.Textbox(lines=1, label="Whether sentence2 is deductible from sentence1")
 
 
392
 
393
- grad.Interface(text2text_deductible, inputs=[sent1,sent2], outputs=out).launch()
 
371
  #-----------------------------------------------------------------------------------
372
  # 16. Text-to-Text Generation using the T5 model - Task 7 check whether a statement deduced from a text is correct or not.
373
 
374
+ # from transformers import T5ForConditionalGeneration, T5Tokenizer
375
+ # import gradio as grad
376
+
377
+ # text2text_tkn= T5Tokenizer.from_pretrained("t5-small")
378
+ # mdl = T5ForConditionalGeneration.from_pretrained("t5-small")
379
+
380
+ # def text2text_deductible(sentence1,sentence2):
381
+ # inp1 = "rte sentence1: "+sentence1
382
+ # inp2 = "sentence2: "+sentence2
383
+ # combined_inp=inp1+" "+inp2
384
+ # enc = text2text_tkn(combined_inp, return_tensors="pt")
385
+ # tokens = mdl.generate(**enc)
386
+ # response=text2text_tkn.batch_decode(tokens)
387
+ # return response
388
+
389
+ # sent1=grad.Textbox(lines=1, label="Sentence1", placeholder="Text in English")
390
+ # sent2=grad.Textbox(lines=1, label="Sentence2", placeholder="Text in English")
391
+ # out=grad.Textbox(lines=1, label="Whether sentence2 is deductible from sentence1")
392
+
393
+ # grad.Interface(text2text_deductible, inputs=[sent1,sent2], outputs=out).launch()
394
+
395
+ #-----------------------------------------------------------------------------------
396
+ # 17. Chatbot/Dialog Bot: a simple bot named Alicia that is based on the Microsoft DialoGPT model .
397
+
398
+ from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
399
+ import torch
400
+
401
+ chat_tkn = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
402
+ mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
403
+
404
+ #chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
405
+ #mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
406
+
407
+ def converse(user_input, chat_history=[]):
408
+ user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
409
+     # keep history in the tensor
410
+     bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
411
+     # get response
412
+     chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
413
+     print (chat_history)
414
+     response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
415
+     print("starting to print response")
416
+     print(response)
417
+     # html for display
418
+     html = "<div class='mybot'>"
419
+     for x, mesg in enumerate(response):
420
+         if x%2!=0 :
421
+            mesg="Alicia:"+mesg
422
+            clazz="alicia"
423
+         else :
424
+            clazz="user"
425
+         print("value of x")
426
+         print(x)
427
+         print("message")
428
+         print (mesg)
429
+         html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
430
+     html += "</div>"
431
+     print(html)
432
+     return html, chat_history
433
+
434
  import gradio as grad
435
 
436
+ css = """
437
+ .mychat {display:flex;flex-direction:column}
438
+ .mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
439
+ .mesg.user {background-color:lightblue;color:white}
440
+ .mesg.alicia {background-color:orange;color:white,align-self:self-end}
441
+ .footer {display:none !important}
442
+ """
443
 
444
+ text=grad.inputs.Textbox(placeholder="Lets chat")
 
 
 
 
 
 
 
445
 
446
+ grad.Interface(fn=converse,
447
+              theme="default",
448
+              inputs=[text, "state"],
449
+              outputs=["html", "state"],
450
+              css=css).launch()
451