import unittest

from transformers import (
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BlenderbotSmallForConditionalGeneration,
    BlenderbotSmallTokenizer,
    Conversation,
    ConversationalPipeline,
    TFAutoModelForCausalLM,
    pipeline,
)
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow, torch_device

from .test_pipelines_common import ANY


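# device=-1 runs the pipeline on CPU; device=0 selects the first GPU when one is available.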
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0


@is_pipeline_test
class ConversationalPipelineTests(unittest.TestCase):
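    # A conversational pipeline can back onto either a causal LM or an encoder-decoder
    # (seq2seq) LM, so merge whichever mappings are available. The parentheses matter:
    # without them the chained conditional silently drops the causal LM mapping.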
    model_mapping = dict(
        (list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()) if MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING else [])
        + (list(MODEL_FOR_CAUSAL_LM_MAPPING.items()) if MODEL_FOR_CAUSAL_LM_MAPPING else [])
    )
    tf_model_mapping = dict(
        (list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()) if TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING else [])
        + (list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.items()) if TF_MODEL_FOR_CAUSAL_LM_MAPPING else [])
    )

    def get_test_pipeline(self, model, tokenizer, processor):
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
        return conversation_agent, [Conversation("Hi there!")]

    def run_pipeline_test(self, conversation_agent, _):
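        # Smoke test: a single Conversation goes in and comes back mutated in place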
        outputs = conversation_agent(Conversation("Hi there!"))
        self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
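
        # A one-element list is unwrapped and returned as a single Conversation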
        outputs = conversation_agent([Conversation("Hi there!")])
        self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
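
        # Batch of two fresh conversations processed together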
        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")
        self.assertEqual(len(conversation_1.past_user_inputs), 0)
        self.assertEqual(len(conversation_2.past_user_inputs), 0)

        outputs = conversation_agent([conversation_1, conversation_2])
        self.assertEqual(outputs, [conversation_1, conversation_2])
        self.assertEqual(
            outputs,
            [
                Conversation(
                    past_user_inputs=["Going to the movies tonight - any suggestions?"],
                    generated_responses=[ANY(str)],
                ),
                Conversation(past_user_inputs=["What's the last book you have read?"], generated_responses=[ANY(str)]),
            ],
        )
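
        # One conversation that already has history gets a second response appended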
        conversation_2.add_user_input("Why do you recommend it?")
        outputs = conversation_agent(conversation_2)
        self.assertEqual(outputs, conversation_2)
        self.assertEqual(
            outputs,
            Conversation(
                past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
                generated_responses=[ANY(str), ANY(str)],
            ),
        )
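
        # Raw strings and conversations without any user input are rejected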
        with self.assertRaises(ValueError):
            conversation_agent("Hi there!")
        with self.assertRaises(ValueError):
            conversation_agent(Conversation())
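
        # conversation_2 was already consumed above; without a fresh user input it is invalid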
        with self.assertRaises(ValueError):
            conversation_agent(conversation_2)

    @require_torch
    @slow
    def test_integration_torch_conversation(self):
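        # Greedy decoding (do_sample=False) keeps the expected responses deterministic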
        conversation_agent = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM)
        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")

        self.assertEqual(len(conversation_1.past_user_inputs), 0)
        self.assertEqual(len(conversation_2.past_user_inputs), 0)

        result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)

        self.assertEqual(result, [conversation_1, conversation_2])
        self.assertEqual(len(result[0].past_user_inputs), 1)
        self.assertEqual(len(result[1].past_user_inputs), 1)
        self.assertEqual(len(result[0].generated_responses), 1)
        self.assertEqual(len(result[1].generated_responses), 1)
        self.assertEqual(result[0].past_user_inputs[0], "Going to the movies tonight - any suggestions?")
        self.assertEqual(result[0].generated_responses[0], "The Big Lebowski")
        self.assertEqual(result[1].past_user_inputs[0], "What's the last book you have read?")
        self.assertEqual(result[1].generated_responses[0], "The Last Question")

        conversation_2.add_user_input("Why do you recommend it?")
        result = conversation_agent(conversation_2, do_sample=False, max_length=1000)

        self.assertEqual(result, conversation_2)
        self.assertEqual(len(result.past_user_inputs), 2)
        self.assertEqual(len(result.generated_responses), 2)
        self.assertEqual(result.past_user_inputs[1], "Why do you recommend it?")
        self.assertEqual(result.generated_responses[1], "It's a good book.")

    @require_torch
    @slow
    def test_integration_torch_conversation_truncated_history(self):
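        # min_length_for_response reserves room for the reply; the small max_length=36
        # forces the conversation history to be truncated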
        conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")

        self.assertEqual(len(conversation_1.past_user_inputs), 0)

        result = conversation_agent(conversation_1, do_sample=False, max_length=36)

        self.assertEqual(result, conversation_1)
        self.assertEqual(len(result.past_user_inputs), 1)
        self.assertEqual(len(result.generated_responses), 1)
        self.assertEqual(result.past_user_inputs[0], "Going to the movies tonight - any suggestions?")
        self.assertEqual(result.generated_responses[0], "The Big Lebowski")

        conversation_1.add_user_input("Is it an action movie?")
        result = conversation_agent(conversation_1, do_sample=False, max_length=36)

        self.assertEqual(result, conversation_1)
        self.assertEqual(len(result.past_user_inputs), 2)
        self.assertEqual(len(result.generated_responses), 2)
        self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
        self.assertEqual(result.generated_responses[1], "It's a comedy.")

    @require_torch
    def test_small_model_pt(self):
        tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
        model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
        conversation = Conversation("hello")
        output = conversation_agent(conversation)
        self.assertEqual(output, Conversation(past_user_inputs=["hello"], generated_responses=["Hi"]))

    @require_tf
    def test_small_model_tf(self):
        tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
        model = TFAutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
        conversation = Conversation("hello")
        output = conversation_agent(conversation)
        self.assertEqual(output, Conversation(past_user_inputs=["hello"], generated_responses=["Hi"]))

    @require_torch
    @slow
    def test_integration_torch_conversation_dialogpt_input_ids(self):
        tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
        model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
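
        # DialoGPT marks the end of every turn with its EOS token (50256)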
        conversation_1 = Conversation("hello")
        inputs = conversation_agent.preprocess(conversation_1)
        self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]])

        conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"])
        inputs = conversation_agent.preprocess(conversation_2)
        self.assertEqual(
            inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
        )

    @require_torch
    @slow
    def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)

        conversation_1 = Conversation("hello")
        inputs = conversation_agent.preprocess(conversation_1)
        self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]])
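
        # Same check with one completed round of history plus a new user input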
        conversation_1 = Conversation(
            "I like lasagne.",
            past_user_inputs=["hello"],
            generated_responses=[
                " Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
            ],
        )
        inputs = conversation_agent.preprocess(conversation_1)
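
        # The history and the new input are flattened into one sequence ending in Blenderbot's EOS token (2)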
        self.assertEqual(
            inputs["input_ids"].tolist(),
            # fmt: off
            [
                [
                    1710, 86, 228, 228, 946, 304, 398, 6881, 558, 964, 38, 452, 315, 265, 6252, 452, 322, 968,
                    6884, 3146, 278, 306, 265, 617, 87, 388, 75, 341, 286, 521, 21, 228, 228, 281, 398, 6881,
                    558, 964, 21, 2,
                ],
            ],
            # fmt: on
        )

    @require_torch
    @slow
    def test_integration_torch_conversation_blenderbot_400M(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)

        conversation_1 = Conversation("hello")
        result = conversation_agent(conversation_1)
        self.assertEqual(
            result.generated_responses[0],
            " Hello! How are you doing today? I just got back from a walk with my dog.",
        )
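
        # encoder_no_repeat_ngram_size=3 prevents the decoder from copying 3-grams out of the input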
        conversation_1 = Conversation("Lasagne hello")
        result = conversation_agent(conversation_1, encoder_no_repeat_ngram_size=3)
        self.assertEqual(
            result.generated_responses[0],
            " Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie.",
        )

        conversation_1 = Conversation(
            "Lasagne hello Lasagne is my favorite Italian dish. Do you like lasagne? I like lasagne."
        )
        result = conversation_agent(conversation_1, encoder_no_repeat_ngram_size=3)
        self.assertEqual(
            result.generated_responses[0],
            " Me too. I like how it can be topped with vegetables, meats, and condiments.",
        )

    @require_torch
    @slow
    def test_integration_torch_conversation_encoder_decoder(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M")
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot_small-90M")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer, device=DEFAULT_DEVICE_NUM)

        conversation_1 = Conversation("My name is Sarah and I live in London")
        conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")

        self.assertEqual(len(conversation_1.past_user_inputs), 0)
        self.assertEqual(len(conversation_2.past_user_inputs), 0)

        result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)

        self.assertEqual(result, [conversation_1, conversation_2])
        self.assertEqual(len(result[0].past_user_inputs), 1)
        self.assertEqual(len(result[1].past_user_inputs), 1)
        self.assertEqual(len(result[0].generated_responses), 1)
        self.assertEqual(len(result[1].generated_responses), 1)
        self.assertEqual(result[0].past_user_inputs[0], "My name is Sarah and I live in London")
        self.assertEqual(
            result[0].generated_responses[0],
            "hi sarah, i live in london as well. do you have any plans for the weekend?",
        )
        self.assertEqual(
            result[1].past_user_inputs[0], "Going to the movies tonight, What movie would you recommend? "
        )
        self.assertEqual(
            result[1].generated_responses[0], "i don't know... i'm not really sure. what movie are you going to see?"
        )
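
        # Second round: each conversation gets a follow-up user input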
        conversation_1.add_user_input("Not yet, what about you?")
        conversation_2.add_user_input("What's your name?")
        result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)

        self.assertEqual(result, [conversation_1, conversation_2])
        self.assertEqual(len(result[0].past_user_inputs), 2)
        self.assertEqual(len(result[1].past_user_inputs), 2)
        self.assertEqual(len(result[0].generated_responses), 2)
        self.assertEqual(len(result[1].generated_responses), 2)
        self.assertEqual(result[0].past_user_inputs[1], "Not yet, what about you?")
        self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
        self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
        self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")

    @require_torch
    @slow
    def test_from_pipeline_conversation(self):
        model_id = "facebook/blenderbot_small-90M"
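
        # Build one pipeline from the model id and another from preloaded model/tokenizer objects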
        conversation_agent_from_model_id = pipeline("conversational", model=model_id, tokenizer=model_id)

        model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_id)
        tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_id)
        conversation_agent_from_model = pipeline("conversational", model=model, tokenizer=tokenizer)

        conversation = Conversation("My name is Sarah and I live in London")
        conversation_copy = Conversation("My name is Sarah and I live in London")

        result_model_id = conversation_agent_from_model_id([conversation])
        result_model = conversation_agent_from_model([conversation_copy])
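
        # Both construction paths must produce identical responses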
        self.assertEqual(
            result_model_id.generated_responses[0],
            "hi sarah, i live in london as well. do you have any plans for the weekend?",
        )
        self.assertEqual(
            result_model_id.generated_responses[0],
            result_model.generated_responses[0],
        )