AgentReview / agentreview /backends /hf_transformers.py
Yiqiao Jin
Initial Commit
bdafe83
import os
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from typing import List, Optional

from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import SYSTEM_NAME as SYSTEM
from ..message import Message
from .base import IntelligenceBackend
@contextmanager
def suppress_stdout_stderr():
    """Silence ``sys.stdout`` and ``sys.stderr`` while the block runs.

    Both streams are redirected to a single handle opened on ``os.devnull``.
    When the block exits the redirections are undone and the handle is closed.

    Yields:
        tuple: ``(err, out)`` -- the values produced by entering
        ``redirect_stderr`` and ``redirect_stdout`` respectively.
    """
    with open(os.devnull, "w") as sink, \
            redirect_stderr(sink) as muted_err, \
            redirect_stdout(sink) as muted_out:
        yield muted_err, muted_out
# Optional dependency probe: try to import the transformers conversational
# pipeline. The imports run with stdout/stderr silenced so that anything
# printed while the package loads does not reach the console.
is_transformers_available = True
with suppress_stdout_stderr():
    try:
        import transformers
        from transformers import pipeline
        from transformers.pipelines.conversational import Conversation
        from transformers.pipelines.conversational import ConversationalPipeline
    except ImportError:
        # Either the package itself or its conversational submodule is missing.
        is_transformers_available = False
class TransformersConversational(IntelligenceBackend):
    """Interface to the Transformers ConversationalPipeline.

    Every call to :meth:`query` flattens the system prompts and the message
    history into alternating user / bot turns, wraps them in a
    ``Conversation`` object and asks the pipeline for the next reply.
    """

    # query() rebuilds the whole conversation from the supplied history on
    # every call; nothing is carried over between calls.
    stateful = False
    # Identifier string for this backend (presumably used to look the backend
    # up by name -- confirm against IntelligenceBackend).
    type_name = "transformers:conversational"

    def __init__(self, model: str, device: int = -1, **kwargs):
        """Create the conversational pipeline.

        Args:
            model: Model identifier forwarded to ``transformers.pipeline``.
            device: Device ordinal forwarded to ``transformers.pipeline``
                (``-1`` selects the CPU according to the pipeline docs).
            **kwargs: Extra options forwarded to ``IntelligenceBackend``.

        Raises:
            AssertionError: If the transformers imports performed at module
                load time failed.
        """
        super().__init__(model=model, device=device, **kwargs)
        self.model = model
        self.device = device

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; AssertionError is kept so existing
        # callers see the same exception type. The flag is also False when
        # only the conversational submodule failed to import.
        if not is_transformers_available:
            raise AssertionError("Transformers package is not installed")

        self.chatbot = pipeline(
            task="conversational", model=self.model, device=self.device
        )

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, conversation):
        """Run the pipeline on *conversation* and return the newest reply.

        The call is wrapped by tenacity's ``retry``: at most 6 attempts with
        a randomized exponential back-off (``min=1``, ``max=60``).
        """
        conversation = self.chatbot(conversation)
        return conversation.generated_responses[-1]

    @staticmethod
    def _msg_template(agent_name, content):
        """Prefix *content* with its speaker, e.g. ``"[Alice]: hello"``."""
        return f"[{agent_name}]: {content}"

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: Optional[str] = None,
        request_msg: Optional[Message] = None,
        *args,
        **kwargs,
    ) -> str:
        """Generate the next reply of *agent_name*.

        Messages whose speaker is not *agent_name* (including the system
        prompts) become "user" turns, rendered with :meth:`_msg_template`;
        messages from *agent_name* become "bot" turns. Consecutive messages
        on the same side are merged into a single turn, joined by newlines.

        Args:
            agent_name: Name of the agent the pipeline speaks for.
            role_desc: Role description, sent as a system message.
            history_messages: Conversation so far, oldest first.
            global_prompt: Optional prompt placed before the role description.
            request_msg: Optional final request, sent as a system message.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The response text produced by the pipeline.

        Raises:
            AssertionError: If *agent_name* equals the system name, or if the
                conversation ends with a message from *agent_name* (there
                would be no pending user input to answer).
        """
        # The conversation always opens with a system message, which has to
        # land on the "user" side. That only fails when the agent itself is
        # named like the system, so check for exactly that.
        if agent_name == SYSTEM:
            raise AssertionError(
                "The first message should be from the system, not from the agent"
            )

        # System prompts first: the optional global prompt, then the role.
        all_messages = [(SYSTEM, global_prompt)] if global_prompt else []
        all_messages.append((SYSTEM, role_desc))
        all_messages.extend(
            (msg.agent_name, msg.content) for msg in history_messages
        )
        if request_msg:
            all_messages.append((SYSTEM, request_msg.content))

        user_inputs, generated_responses = [], []
        prev_is_user = False  # Whether the previous message is from the user
        for speaker, content in all_messages:
            if speaker != agent_name:
                rendered = self._msg_template(speaker, content)
                if prev_is_user:
                    # Same side as the previous message: extend that turn.
                    user_inputs[-1] += "\n" + rendered
                else:
                    user_inputs.append(rendered)
                prev_is_user = True
            else:
                if prev_is_user:
                    generated_responses.append(content)
                else:
                    generated_responses[-1] += "\n" + content
                prev_is_user = False

        # Turns alternate starting with a user turn, so exactly one more user
        # turn than bot turns means the last turn is an unanswered user input.
        # Explicit raise so the check survives ``python -O``.
        if len(user_inputs) != len(generated_responses) + 1:
            raise AssertionError(
                "The conversation must end with a message that is not from "
                "the agent being queried"
            )

        # Recreate a conversation object from the history messages
        conversation = Conversation(
            text=user_inputs[-1],
            past_user_inputs=user_inputs[:-1],
            generated_responses=generated_responses,
        )

        # Get the response
        return self._get_response(conversation)
# conversation = Conversation("Going to the movies tonight - any suggestions?")
#
# # Steps usually performed by the model when generating a response:
# # 1. Mark the user input as processed (moved to the history)
# conversation.mark_processed()
# # 2. Append a model response
# conversation.append_response("The Big lebowski.")
#
# conversation.add_user_input("Is it good?")