import asyncio
import json
import os
import time
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
from typing import AsyncGenerator, Dict, List, Optional, Union

import aiohttp
import requests

from ..schema import ModelStatusCode
from ..utils import filter_suffix
from .base_api import AsyncBaseAPILLM, BaseAPILLM

warnings.simplefilter('default')

OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'


class GPTAPI(BaseAPILLM):
"""Model wrapper around OpenAI's models. |
|
|
|
Args: |
|
model_type (str): The name of OpenAI's model. |
|
retry (int): Number of retires if the API call fails. Defaults to 2. |
|
key (str or List[str]): OpenAI key(s). In particular, when it |
|
is set to "ENV", the key will be fetched from the environment |
|
variable $OPENAI_API_KEY, as how openai defaults to be. If it's a |
|
list, the keys will be used in round-robin manner. Defaults to |
|
'ENV'. |
|
org (str or List[str], optional): OpenAI organization(s). If not |
|
specified, OpenAI uses the default organization bound to each API |
|
key. If specified, the orgs will be posted with each request in |
|
round-robin manner. Defaults to None. |
|
meta_template (Dict, optional): The model's meta prompt |
|
template if needed, in case the requirement of injecting or |
|
wrapping of any meta instructions. |
|
api_base (str): The base url of OpenAI's API. Defaults to |
|
'https://api.openai.com/v1/chat/completions'. |
|
gen_params: Default generation configuration which could be overridden |
|
on the fly of generation. |
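
    Example:
        A minimal usage sketch (assuming a valid key is available in the
        ``$OPENAI_API_KEY`` environment variable; the message content is
        illustrative only)::

            llm = GPTAPI(model_type='gpt-3.5-turbo', key='ENV')
            reply = llm.chat([dict(role='user', content='Hi, who are you?')])
            print(reply)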
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str = 'gpt-3.5-turbo',
                 retry: int = 2,
                 json_mode: bool = False,
                 key: Union[str, List[str]] = 'ENV',
                 org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[List[Dict]] = [
                     dict(role='system', api_role='system'),
                     dict(role='user', api_role='user'),
                     dict(role='assistant', api_role='assistant'),
                     dict(role='environment', api_role='system')
                 ],
                 api_base: str = OPENAI_API_BASE,
                 proxies: Optional[Dict] = None,
                 **gen_params):
        if 'top_k' in gen_params:
            warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
                          DeprecationWarning)
            gen_params.pop('top_k')
        super().__init__(
            model_type=model_type,
            meta_template=meta_template,
            retry=retry,
            **gen_params)
        self.gen_params.pop('top_k')
        self.logger = getLogger(__name__)

        if isinstance(key, str):
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
        else:
            self.keys = key

        # a shared lock to guard key/org rotation across worker threads
        self._lock = Lock()
        # keys found to be invalid (e.g. with insufficient quota) are
        # recorded here and skipped on subsequent requests
        self.invalid_keys = set()

        self.key_ctr = 0
        if isinstance(org, str):
            self.orgs = [org]
        else:
            self.orgs = org
        self.org_ctr = 0
        self.url = api_base
        self.model_type = model_type
        self.proxies = proxies
        self.json_mode = json_mode

    def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        **gen_params,
    ) -> Union[str, List[str]]:
"""Generate responses given the contexts. |
|
|
|
Args: |
|
inputs (Union[List[dict], List[List[dict]]]): a list of messages |
|
or list of lists of messages |
|
gen_params: additional generation configuration |
|
|
|
Returns: |
|
Union[str, List[str]]: generated string(s) |
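
        Example:
            A sketch of single vs. batched calls (assuming ``llm`` is a
            configured GPTAPI instance; message contents are illustrative)::

                # single conversation -> returns a str
                llm.chat([dict(role='user', content='hello')])

                # batch of conversations -> returns a List[str]
                llm.chat([[dict(role='user', content='hello')],
                          [dict(role='user', content='hi')]])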
        """
        assert isinstance(inputs, list)
        if 'max_tokens' in gen_params:
            raise NotImplementedError('unsupported parameter: max_tokens')
        gen_params = {**self.gen_params, **gen_params}
        with ThreadPoolExecutor(max_workers=20) as executor:
            tasks = [
                executor.submit(self._chat,
                                self.template_parser._prompt2api(messages),
                                **gen_params)
                for messages in (
                    [inputs] if isinstance(inputs[0], dict) else inputs)
            ]
        ret = [task.result() for task in tasks]
        return ret[0] if isinstance(inputs[0], dict) else ret

    def stream_chat(
        self,
        inputs: List[dict],
        **gen_params,
    ):
"""Generate responses given the contexts. |
|
|
|
Args: |
|
inputs (List[dict]): a list of messages |
|
gen_params: additional generation configuration |
|
|
|
Returns: |
|
str: generated string |
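
        Example:
            A minimal streaming sketch (assuming ``llm`` is a configured
            GPTAPI instance; the message content is illustrative)::

                for status, text, _ in llm.stream_chat(
                        [dict(role='user', content='hello')]):
                    print(status, text)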
        """
        assert isinstance(inputs, list)
        if 'max_tokens' in gen_params:
            raise NotImplementedError('unsupported parameter: max_tokens')
        gen_params = self.update_gen_params(**gen_params)
        gen_params['stream'] = True

        resp = ''
        finished = False
        stop_words = gen_params.get('stop_words')
        if stop_words is None:
            stop_words = []
        # map the prompt roles to the roles the API supports
        messages = self.template_parser._prompt2api(inputs)
        for text in self._stream_chat(messages, **gen_params):
            if self.model_type.lower().startswith('qwen'):
                resp = text
            else:
                resp += text
            if not resp:
                continue
            # truncate the response at the first stop word, if any
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None

    def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate completion from a list of messages.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: The generated string.
        """
        assert isinstance(messages, list)

        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode)

        max_num_retries, errmsg = 0, ''
        while max_num_retries < self.retry:
            with self._lock:
                if len(self.invalid_keys) == len(self.keys):
                    raise RuntimeError('All keys have insufficient quota.')

                # find the next valid key in round-robin order
                while True:
                    self.key_ctr += 1
                    if self.key_ctr == len(self.keys):
                        self.key_ctr = 0

                    if self.keys[self.key_ctr] not in self.invalid_keys:
                        break

                key = self.keys[self.key_ctr]

            header['Authorization'] = f'Bearer {key}'

            if self.orgs:
                with self._lock:
                    self.org_ctr += 1
                    if self.org_ctr == len(self.orgs):
                        self.org_ctr = 0
                header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            response = dict()
            try:
                raw_response = requests.post(
                    self.url,
                    headers=header,
                    data=json.dumps(data),
                    proxies=self.proxies)
                response = raw_response.json()
                return response['choices'][0]['message']['content'].strip()
            except requests.ConnectionError:
                errmsg = 'Got connection error ' + str(traceback.format_exc())
                self.logger.error(errmsg)
                continue
            except requests.JSONDecodeError:
                errmsg = 'JsonDecode error, got ' + str(raw_response.content)
                self.logger.error(errmsg)
                continue
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(1)
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warning(f'insufficient_quota key: {key}')
                        continue

                    errmsg = 'Found error message in response: ' + str(
                        response['error'])
                    self.logger.error(errmsg)
            except Exception as error:
                errmsg = str(error) + '\n' + str(traceback.format_exc())
                self.logger.error(errmsg)
            max_num_retries += 1

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           f'details. errmsg: {errmsg}')

    def _stream_chat(self, messages: List[dict], **gen_params) -> str:
        """Generate completion from a list of messages.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            A generator that yields the generated text chunk by chunk.
        """

        def streaming(raw_response):
            for chunk in raw_response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
                if chunk:
                    decoded = chunk.decode('utf-8')
                    if decoded.startswith('data: [DONE]'):
                        return
                    if decoded[:5] == 'data:':
                        decoded = decoded[5:]
                        if decoded.startswith(' '):
                            decoded = decoded[1:]
                    else:
                        print(decoded)
                        continue
                    try:
                        response = json.loads(decoded)
                        if 'code' in response and response['code'] == -20003:
                            # the context exceeds the maximum length
                            yield ''
                            return
                        if self.model_type.lower().startswith('qwen'):
                            choice = response['output']['choices'][0]
                            yield choice['message']['content']
                            if choice['finish_reason'] == 'stop':
                                return
                        else:
                            choice = response['choices'][0]
                            if choice['finish_reason'] == 'stop':
                                return
                            yield choice['delta'].get('content', '')
                    except Exception as exc:
                        msg = (f'response {decoded} led to exception '
                               f'of {str(exc)}')
                        self.logger.error(msg)
                        raise Exception(msg) from exc

        assert isinstance(messages, list)

        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode)

        max_num_retries, errmsg = 0, ''
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')

            # find the next valid key in round-robin order
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0

                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break

            key = self.keys[self.key_ctr]
            header['Authorization'] = f'Bearer {key}'

            if self.orgs:
                self.org_ctr += 1
                if self.org_ctr == len(self.orgs):
                    self.org_ctr = 0
                header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            response = dict()
            try:
                raw_response = requests.post(
                    self.url,
                    headers=header,
                    data=json.dumps(data),
                    proxies=self.proxies)
                return streaming(raw_response)
            except requests.ConnectionError:
                errmsg = 'Got connection error ' + str(traceback.format_exc())
                self.logger.error(errmsg)
                continue
            except requests.JSONDecodeError:
                errmsg = 'JsonDecode error, got ' + str(raw_response.content)
                self.logger.error(errmsg)
                continue
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(1)
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warning(f'insufficient_quota key: {key}')
                        continue

                    errmsg = 'Found error message in response: ' + str(
                        response['error'])
                    self.logger.error(errmsg)
            except Exception as error:
                errmsg = str(error) + '\n' + str(traceback.format_exc())
                self.logger.error(errmsg)
            max_num_retries += 1

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           f'details. errmsg: {errmsg}')

    def generate_request_data(self,
                              model_type,
                              messages,
                              gen_params,
                              json_mode=False):
        """
        Generates the request data for different model types.

        Args:
            model_type (str): The type of the model (e.g., 'gpt', 'internlm',
                'qwen').
            messages (list): The list of messages to be sent to the model.
            gen_params (dict): The generation parameters.
            json_mode (bool): Flag to determine if the response format
                should be JSON.

        Returns:
            tuple: A tuple containing the header and the request data.
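
        Example:
            A rough sketch of the returned values for a 'gpt' model (assuming
            ``llm`` is a configured GPTAPI instance; the exact generation
            keys depend on the configured ``gen_params``)::

                header, data = llm.generate_request_data(
                    'gpt-3.5-turbo',
                    [dict(role='user', content='hi')],
                    dict(max_new_tokens=512, temperature=1.0))
                # header -> {'content-type': 'application/json'}
                # data   -> {'model': 'gpt-3.5-turbo', 'messages': [...],
                #            'n': 1, 'max_tokens': 512, 'temperature': 1.0}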
        """
        # work on a copy so the caller's gen_params are left untouched
        gen_params = gen_params.copy()

        # cap the completion length at the 4096 tokens accepted here
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''

        # common headers
        header = {
            'content-type': 'application/json',
        }

        # common parameters
        gen_params['max_tokens'] = max_tokens
        if 'stop_words' in gen_params:
            gen_params['stop'] = gen_params.pop('stop_words')
        if 'repetition_penalty' in gen_params:
            gen_params['frequency_penalty'] = gen_params.pop(
                'repetition_penalty')

        # model-specific parameters
        data = {}
        if model_type.lower().startswith('gpt'):
            if 'top_k' in gen_params:
                warnings.warn(
                    '`top_k` parameter is deprecated in OpenAI APIs.',
                    DeprecationWarning)
                gen_params.pop('top_k')
            gen_params.pop('skip_special_tokens', None)
            gen_params.pop('session_id', None)
            data = {
                'model': model_type,
                'messages': messages,
                'n': 1,
                **gen_params
            }
            if json_mode:
                data['response_format'] = {'type': 'json_object'}
        elif model_type.lower().startswith('internlm'):
            data = {
                'model': model_type,
                'messages': messages,
                'n': 1,
                **gen_params
            }
            if json_mode:
                data['response_format'] = {'type': 'json_object'}
        elif model_type.lower().startswith('qwen'):
            header['X-DashScope-SSE'] = 'enable'
            gen_params.pop('skip_special_tokens', None)
            gen_params.pop('session_id', None)
            if 'frequency_penalty' in gen_params:
                gen_params['repetition_penalty'] = gen_params.pop(
                    'frequency_penalty')
            gen_params['result_format'] = 'message'
            data = {
                'model': model_type,
                'input': {
                    'messages': messages
                },
                'parameters': {
                    **gen_params
                }
            }
        else:
            raise NotImplementedError(
                f'Model type {model_type} is not supported')

        return header, data

    def tokenize(self, prompt: str) -> list:
        """Tokenize the input prompt.

        Args:
            prompt (str): Input string.

        Returns:
            list: token ids
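
        Example:
            A small sketch (requires the ``tiktoken`` package and a model
            name that tiktoken recognizes)::

                llm = GPTAPI(model_type='gpt-3.5-turbo')
                ids = llm.tokenize('hello world')
                # ids is a list of integer token ids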
        """
        import tiktoken
        self.tiktoken = tiktoken
        enc = self.tiktoken.encoding_for_model(self.model_type)
        return enc.encode(prompt)


class AsyncGPTAPI(AsyncBaseAPILLM):
"""Model wrapper around OpenAI's models. |
|
|
|
Args: |
|
model_type (str): The name of OpenAI's model. |
|
retry (int): Number of retires if the API call fails. Defaults to 2. |
|
key (str or List[str]): OpenAI key(s). In particular, when it |
|
is set to "ENV", the key will be fetched from the environment |
|
variable $OPENAI_API_KEY, as how openai defaults to be. If it's a |
|
list, the keys will be used in round-robin manner. Defaults to |
|
'ENV'. |
|
org (str or List[str], optional): OpenAI organization(s). If not |
|
specified, OpenAI uses the default organization bound to each API |
|
key. If specified, the orgs will be posted with each request in |
|
round-robin manner. Defaults to None. |
|
meta_template (Dict, optional): The model's meta prompt |
|
template if needed, in case the requirement of injecting or |
|
wrapping of any meta instructions. |
|
api_base (str): The base url of OpenAI's API. Defaults to |
|
'https://api.openai.com/v1/chat/completions'. |
|
gen_params: Default generation configuration which could be overridden |
|
on the fly of generation. |
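
    Example:
        A minimal asynchronous sketch (assuming a valid key in
        ``$OPENAI_API_KEY``; the message content is illustrative only)::

            import asyncio

            llm = AsyncGPTAPI(model_type='gpt-3.5-turbo', key='ENV')
            reply = asyncio.run(
                llm.chat([dict(role='user', content='Hi, who are you?')]))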
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str = 'gpt-3.5-turbo',
                 retry: int = 2,
                 json_mode: bool = False,
                 key: Union[str, List[str]] = 'ENV',
                 org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[List[Dict]] = [
                     dict(role='system', api_role='system'),
                     dict(role='user', api_role='user'),
                     dict(role='assistant', api_role='assistant')
                 ],
                 api_base: str = OPENAI_API_BASE,
                 proxies: Optional[Dict] = None,
                 **gen_params):
        if 'top_k' in gen_params:
            warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
                          DeprecationWarning)
            gen_params.pop('top_k')
        super().__init__(
            model_type=model_type,
            meta_template=meta_template,
            retry=retry,
            **gen_params)
        self.gen_params.pop('top_k')
        self.logger = getLogger(__name__)

        if isinstance(key, str):
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
        else:
            self.keys = key

        # keys found to be invalid (e.g. with insufficient quota) are
        # recorded here and skipped on subsequent requests
        self.invalid_keys = set()

        self.key_ctr = 0
        if isinstance(org, str):
            self.orgs = [org]
        else:
            self.orgs = org
        self.org_ctr = 0
        self.url = api_base
        self.model_type = model_type
        self.proxies = proxies or {}
        self.json_mode = json_mode

    async def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                or list of lists of messages
            session_ids (Union[int, List[int]], optional): session id(s);
                currently unused by this backend. Defaults to None.
            gen_params: additional generation configuration

        Returns:
            Union[str, List[str]]: generated string(s)
        """
        assert isinstance(inputs, list)
        if 'max_tokens' in gen_params:
            raise NotImplementedError('unsupported parameter: max_tokens')
        gen_params = {**self.gen_params, **gen_params}
        tasks = [
            self._chat(messages, **gen_params) for messages in (
                [inputs] if isinstance(inputs[0], dict) else inputs)
        ]
        ret = await asyncio.gather(*tasks)
        return ret[0] if isinstance(inputs[0], dict) else ret

    async def stream_chat(
        self,
        inputs: List[dict],
        **gen_params,
    ):
"""Generate responses given the contexts. |
|
|
|
Args: |
|
inputs (List[dict]): a list of messages |
|
gen_params: additional generation configuration |
|
|
|
Returns: |
|
str: generated string |
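
        Example:
            A minimal asynchronous streaming sketch, to be run inside a
            coroutine (assuming ``llm`` is a configured AsyncGPTAPI instance;
            the message content is illustrative)::

                async for status, text, _ in llm.stream_chat(
                        [dict(role='user', content='hello')]):
                    print(status, text)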
        """
        assert isinstance(inputs, list)
        if 'max_tokens' in gen_params:
            raise NotImplementedError('unsupported parameter: max_tokens')
        gen_params = self.update_gen_params(**gen_params)
        gen_params['stream'] = True

        resp = ''
        finished = False
        stop_words = gen_params.get('stop_words')
        if stop_words is None:
            stop_words = []
        # map the prompt roles to the roles the API supports
        messages = self.template_parser._prompt2api(inputs)
        async for text in self._stream_chat(messages, **gen_params):
            if self.model_type.lower().startswith('qwen'):
                resp = text
            else:
                resp += text
            if not resp:
                continue
            # truncate the response at the first stop word, if any
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate completion from a list of messages.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: The generated string.
        """
        assert isinstance(messages, list)

        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode)

        max_num_retries, errmsg = 0, ''
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')

            # find the next valid key in round-robin order
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0

                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break

            key = self.keys[self.key_ctr]
            header['Authorization'] = f'Bearer {key}'

            if self.orgs:
                self.org_ctr += 1
                if self.org_ctr == len(self.orgs):
                    self.org_ctr = 0
                header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            response = dict()
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                            self.url,
                            headers=header,
                            json=data,
                            proxy=self.proxies.get(
                                'https', self.proxies.get('http'))) as resp:
                        response = await resp.json()
                        return response['choices'][0]['message'][
                            'content'].strip()
            except aiohttp.ClientConnectionError:
                errmsg = 'Got connection error ' + str(traceback.format_exc())
                self.logger.error(errmsg)
                continue
            except aiohttp.ClientResponseError as e:
                errmsg = 'Response error, got ' + str(e)
                self.logger.error(errmsg)
                continue
            except json.JSONDecodeError:
                errmsg = 'JsonDecode error, got ' + (await resp.text(
                    errors='replace'))
                self.logger.error(errmsg)
                continue
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        await asyncio.sleep(1)
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warning(f'insufficient_quota key: {key}')
                        continue

                    errmsg = 'Found error message in response: ' + str(
                        response['error'])
                    self.logger.error(errmsg)
            except Exception as error:
                errmsg = str(error) + '\n' + str(traceback.format_exc())
                self.logger.error(errmsg)
            max_num_retries += 1

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           f'details. errmsg: {errmsg}')

    async def _stream_chat(self, messages: List[dict],
                           **gen_params) -> AsyncGenerator[str, None]:
        """Generate completion from a list of messages.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Yields:
            str: chunks of the generated text.
        """

        async def streaming(raw_response):
            async for chunk in raw_response.content:
                if chunk:
                    decoded = chunk.decode('utf-8')
                    if decoded.startswith('data: [DONE]'):
                        return
                    if decoded[:5] == 'data:':
                        decoded = decoded[5:]
                        if decoded.startswith(' '):
                            decoded = decoded[1:]
                    else:
                        print(decoded)
                        continue
                    try:
                        response = json.loads(decoded)
                        if 'code' in response and response['code'] == -20003:
                            # the context exceeds the maximum length
                            yield ''
                            return
                        if self.model_type.lower().startswith('qwen'):
                            choice = response['output']['choices'][0]
                            yield choice['message']['content']
                            if choice['finish_reason'] == 'stop':
                                return
                        else:
                            choice = response['choices'][0]
                            if choice['finish_reason'] == 'stop':
                                return
                            yield choice['delta'].get('content', '')
                    except Exception as exc:
                        msg = (f'response {decoded} led to exception '
                               f'of {str(exc)}')
                        self.logger.error(msg)
                        raise Exception(msg) from exc

        assert isinstance(messages, list)

        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode)

        max_num_retries, errmsg = 0, ''
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')

            # find the next valid key in round-robin order
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0

                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break

            key = self.keys[self.key_ctr]
            header['Authorization'] = f'Bearer {key}'

            if self.orgs:
                self.org_ctr += 1
                if self.org_ctr == len(self.orgs):
                    self.org_ctr = 0
                header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            response = dict()
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                            self.url,
                            headers=header,
                            json=data,
                            proxy=self.proxies.get(
                                'https',
                                self.proxies.get('http'))) as raw_response:
                        async for msg in streaming(raw_response):
                            yield msg
                        return
            except aiohttp.ClientConnectionError:
                errmsg = 'Got connection error ' + str(traceback.format_exc())
                self.logger.error(errmsg)
                continue
            except aiohttp.ClientResponseError as e:
                errmsg = 'Response error, got ' + str(e)
                self.logger.error(errmsg)
                continue
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        await asyncio.sleep(1)
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warning(f'insufficient_quota key: {key}')
                        continue

                    errmsg = 'Found error message in response: ' + str(
                        response['error'])
                    self.logger.error(errmsg)
            except Exception as error:
                errmsg = str(error) + '\n' + str(traceback.format_exc())
                self.logger.error(errmsg)
            max_num_retries += 1

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           f'details. errmsg: {errmsg}')

    def generate_request_data(self,
                              model_type,
                              messages,
                              gen_params,
                              json_mode=False):
        """
        Generates the request data for different model types.

        Args:
            model_type (str): The type of the model (e.g., 'gpt', 'internlm',
                'qwen').
            messages (list): The list of messages to be sent to the model.
            gen_params (dict): The generation parameters.
            json_mode (bool): Flag to determine if the response format
                should be JSON.

        Returns:
            tuple: A tuple containing the header and the request data.
        """
        # work on a copy so the caller's gen_params are left untouched
        gen_params = gen_params.copy()

        # cap the completion length at the 4096 tokens accepted here
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''

        # common headers
        header = {
            'content-type': 'application/json',
        }

        # common parameters
        gen_params['max_tokens'] = max_tokens
        if 'stop_words' in gen_params:
            gen_params['stop'] = gen_params.pop('stop_words')
        if 'repetition_penalty' in gen_params:
            gen_params['frequency_penalty'] = gen_params.pop(
                'repetition_penalty')

        # model-specific parameters
        data = {}
        if model_type.lower().startswith('gpt'):
            if 'top_k' in gen_params:
                warnings.warn(
                    '`top_k` parameter is deprecated in OpenAI APIs.',
                    DeprecationWarning)
                gen_params.pop('top_k')
            gen_params.pop('skip_special_tokens', None)
            gen_params.pop('session_id', None)
            data = {
                'model': model_type,
                'messages': messages,
                'n': 1,
                **gen_params
            }
            if json_mode:
                data['response_format'] = {'type': 'json_object'}
        elif model_type.lower().startswith('internlm'):
            data = {
                'model': model_type,
                'messages': messages,
                'n': 1,
                **gen_params
            }
            if json_mode:
                data['response_format'] = {'type': 'json_object'}
        elif model_type.lower().startswith('qwen'):
            header['X-DashScope-SSE'] = 'enable'
            gen_params.pop('skip_special_tokens', None)
            gen_params.pop('session_id', None)
            if 'frequency_penalty' in gen_params:
                gen_params['repetition_penalty'] = gen_params.pop(
                    'frequency_penalty')
            gen_params['result_format'] = 'message'
            data = {
                'model': model_type,
                'input': {
                    'messages': messages
                },
                'parameters': {
                    **gen_params
                }
            }
        else:
            raise NotImplementedError(
                f'Model type {model_type} is not supported')

        return header, data

    def tokenize(self, prompt: str) -> list:
        """Tokenize the input prompt.

        Args:
            prompt (str): Input string.

        Returns:
            list: token ids
        """
        import tiktoken
        self.tiktoken = tiktoken
        enc = self.tiktoken.encoding_for_model(self.model_type)
        return enc.encode(prompt)