Spaces:
Running
Running
import os | |
from contextlib import contextmanager, redirect_stderr, redirect_stdout | |
from typing import List | |
from tenacity import retry, stop_after_attempt, wait_random_exponential | |
from ..message import SYSTEM_NAME as SYSTEM | |
from ..message import Message | |
from .base import IntelligenceBackend | |
@contextmanager
def suppress_stdout_stderr():
    """Silence stdout and stderr for the duration of a ``with`` block.

    Both ``sys.stdout`` and ``sys.stderr`` are redirected to ``os.devnull``.
    The devnull handle is closed, and the original streams are restored, when
    the block exits, even if it raises.

    Yields:
        tuple: ``(err, out)`` -- the values returned on entry by
        ``redirect_stderr`` and ``redirect_stdout`` (the devnull file object).
    """
    # Without @contextmanager this would be a plain generator, and
    # ``with suppress_stdout_stderr():`` would fail at module import time.
    with open(os.devnull, "w") as fnull:
        with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
            yield (err, out)
# Probe for the optional ``transformers`` dependency. Importing it may print
# banners or warnings, so the probe runs with stdout and stderr silenced.
# ``is_transformers_available`` records whether every import below succeeded.
with suppress_stdout_stderr():
    try:
        import transformers
        from transformers import pipeline
        from transformers.pipelines.conversational import (
            Conversation,
            ConversationalPipeline,
        )

        is_transformers_available = True
    except ImportError:
        # Raised when the package is missing or one of the imported
        # submodules/names cannot be found.
        is_transformers_available = False
class TransformersConversational(IntelligenceBackend):
    """Interface to the Transformers ConversationalPipeline.

    Each ``query`` flattens the arena history into the strictly alternating
    user-input / generated-response structure that a transformers
    ``Conversation`` expects, then asks the pipeline for the next reply.
    """

    stateful = False
    type_name = "transformers:conversational"

    def __init__(self, model: str, device: int = -1, **kwargs):
        """Create the backend and its ``conversational`` pipeline.

        Args:
            model: model identifier passed to ``transformers.pipeline``.
            device: device index forwarded to ``pipeline`` (in transformers'
                convention, ``-1`` selects the CPU).
            **kwargs: forwarded to ``IntelligenceBackend.__init__``.

        Raises:
            AssertionError: if the transformers imports at module load failed.
        """
        super().__init__(model=model, device=device, **kwargs)
        self.model = model
        self.device = device

        assert is_transformers_available, "Transformers package is not installed"
        self.chatbot = pipeline(
            task="conversational", model=self.model, device=self.device
        )

    # Retry transient pipeline failures. reraise=True surfaces the original
    # exception (not tenacity.RetryError) once the attempts are exhausted.
    @retry(
        stop=stop_after_attempt(6),
        wait=wait_random_exponential(min=1, max=60),
        reraise=True,
    )
    def _get_response(self, conversation):
        """Run the pipeline on ``conversation`` and return its newest reply.

        Args:
            conversation: a transformers ``Conversation`` with a pending user input.

        Returns:
            The last element of ``generated_responses`` after the pipeline call.
        """
        conversation = self.chatbot(conversation)
        return conversation.generated_responses[-1]

    @staticmethod
    def _msg_template(agent_name, content):
        """Render one message as ``"[agent_name]: content"``."""
        # Must be a staticmethod: it is called as self._msg_template(name, text),
        # which would otherwise pass three positional arguments.
        return f"[{agent_name}]: {content}"

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """Ask the conversational pipeline for ``agent_name``'s next message.

        Messages not written by ``agent_name`` form the "user" side of the
        conversation. This includes the system prompts, each rendered as
        ``"[sender]: content"``. ``agent_name``'s own messages form the
        pipeline's ``generated_responses``. Consecutive messages on the same
        side are merged with newlines, because a ``Conversation`` strictly
        alternates user inputs and responses.

        Args:
            agent_name: name of the agent whose reply is being generated.
            role_desc: role description, sent as a system message.
            history_messages: prior messages; their order is preserved.
            global_prompt: optional system message placed before ``role_desc``.
            request_msg: optional message appended last, attributed to the system.
            *args: accepted for interface compatibility; unused.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The text produced by the pipeline.

        Raises:
            AssertionError: if the last message comes from ``agent_name``
                itself, leaving no pending user input to answer.
        """
        # Flatten everything into (sender, content) pairs, system prompts first.
        all_messages = [(SYSTEM, global_prompt)] if global_prompt else []
        all_messages.append((SYSTEM, role_desc))
        all_messages.extend(
            (msg.agent_name, msg.content) for msg in history_messages
        )
        if request_msg:
            all_messages.append((SYSTEM, request_msg.content))

        user_inputs, generated_responses = [], []
        prev_is_user = False  # Whether the previous message was on the user side
        # The list always starts with a SYSTEM message, so the first iteration
        # opens a user turn (assuming agent_name is not SYSTEM itself).
        for sender, content in all_messages:
            if sender != agent_name:
                rendered = self._msg_template(sender, content)
                if prev_is_user:
                    user_inputs[-1] += "\n" + rendered
                else:
                    user_inputs.append(rendered)
                prev_is_user = True
            else:
                if prev_is_user:
                    generated_responses.append(content)
                else:
                    generated_responses[-1] += "\n" + content
                prev_is_user = False

        assert len(user_inputs) == len(generated_responses) + 1, (
            f"Cannot query {agent_name!r}: the conversation must end with a "
            "message from another participant (or a request_msg)"
        )

        # Recreate a Conversation whose pending input is the latest user turn.
        conversation = Conversation(
            text=user_inputs[-1],
            past_user_inputs=user_inputs[:-1],
            generated_responses=generated_responses,
        )
        return self._get_response(conversation)
# conversation = Conversation("Going to the movies tonight - any suggestions?") | |
# | |
# # Steps usually performed by the model when generating a response: | |
# # 1. Mark the user input as processed (moved to the history) | |
# conversation.mark_processed() | |
# # 2. Append a model response
# conversation.append_response("The Big lebowski.") | |
# | |
# conversation.add_user_input("Is it good?") | |