# Repository metadata (not code): Xu Xuenan, "Initial commit", a121edc
import os
from typing import Callable, Optional

from dashscope import Generation
class QwenAgent(object):
    """Small conversational wrapper around ``dashscope.Generation.call``.

    The agent keeps a list of chat messages (``self.history``) in the
    ``{"role": ..., "content": ...}`` format and sends the whole list on
    every call to :meth:`run`.

    Attributes:
        system_prompt: Optional system message placed first in the history.
        track_history: When true, successful exchanges accumulate in
            ``history``; when false, ``history`` is reset to the system
            prompt (or to an empty list) after each :meth:`run`.
        model: Model name forwarded to ``Generation.call``.
        history: The message list sent to the model.
    """

    # Default model name, used when the caller does not pick one.
    DEFAULT_MODEL = "qwen2-72b-instruct"

    def __init__(self,
                 system_prompt: Optional[str] = None,
                 track_history: bool = True,
                 model: str = DEFAULT_MODEL):
        """Create an agent.

        Args:
            system_prompt: Content of the initial ``system`` message, or
                ``None`` to start with an empty history.
            track_history: Whether exchanges persist across :meth:`run` calls.
            model: Model name passed to ``Generation.call``.
        """
        self.system_prompt = system_prompt
        self.track_history = track_history
        self.model = model
        if system_prompt is None:
            self.history = []
        else:
            self.history = [{"role": "system", "content": system_prompt}]

    def basic_success_check(self, response):
        """Return ``True`` when ``response`` carries non-empty output text.

        A falsy response, a falsy ``response.output`` or a falsy
        ``response.output.text`` counts as failure; the response is then
        printed so the caller can see what came back.
        """
        if not response or not response.output or not response.output.text:
            print(response)
            return False
        return True

    def run(self,
            prompt: str,
            top_p: float = 0.95,
            temperature: float = 1.0,
            seed: int = 1,
            max_length: int = 1024,
            max_try: int = 5,
            success_check_fn: Callable = None
            ):
        """Send ``prompt`` to the model, retrying until a reply is accepted.

        Args:
            prompt: Content of the new ``user`` message.
            top_p: Forwarded to ``Generation.call``.
            temperature: Forwarded to ``Generation.call``.
            seed: Forwarded to ``Generation.call``.
            max_length: Forwarded to ``Generation.call``.
            max_try: Maximum number of calls to attempt.
            success_check_fn: Optional predicate applied to the reply text;
                a falsy result triggers another attempt. When ``None``,
                every reply that passes :meth:`basic_success_check` is
                accepted.

        Returns:
            A ``(response, success)`` tuple. On success ``response`` is the
            reply text. On failure it is the last object returned by
            ``Generation.call``, or ``None`` if no call was made
            (``max_try <= 0``).
        """
        # Resolve the optional predicate once; it does not change per attempt.
        if success_check_fn is None:
            accept = lambda _text: True
        else:
            accept = success_check_fn

        self.history.append({"role": "user", "content": prompt})

        # Bound up front so the return below is valid even with zero attempts.
        response = None
        success = False
        for _ in range(max_try):
            response = Generation.call(
                model=self.model,
                messages=self.history,
                top_p=top_p,
                temperature=temperature,
                api_key=os.environ.get('DASHSCOPE_API_KEY'),
                seed=seed,
                max_length=max_length
            )
            if self.basic_success_check(response) and accept(response.output.text):
                response = response.output.text
                self.history.append({"role": "assistant", "content": response})
                success = True
                break

        if not self.track_history:
            # Stateless mode: keep only the system prompt (if any).
            self.history = self.history[:1] if self.system_prompt is not None else []
        elif not success:
            # Drop the unanswered user prompt so the next call does not send
            # two consecutive user turns.
            self.history.pop()
        return response, success