Husnain commited on
Commit
bfd6e9c
·
unverified ·
1 Parent(s): 687103e

💎 [Feature] Enable gpt-3.5 in chat_api

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +31 -25
apis/chat_api.py CHANGED
@@ -12,21 +12,24 @@ from fastapi.responses import HTMLResponse
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
  from pydantic import BaseModel, Field
14
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
 
 
 
 
15
 
16
  from messagers.message_composer import MessageComposer
17
  from mocks.stream_chat_mocker import stream_chat_mock
18
- from networks.message_streamer import MessageStreamer
19
- from utils.logger import logger
20
- from constants.models import AVAILABLE_MODELS_DICTS
21
 
22
 
23
  class ChatAPIApp:
24
  def __init__(self):
25
  self.app = FastAPI(
26
  docs_url="/",
27
- title="HuggingFace LLM API",
28
  swagger_ui_parameters={"defaultModelsExpandDepth": -1},
29
- version="1.0",
30
  )
31
  self.setup_routes()
32
 
@@ -86,19 +89,22 @@ class ChatAPIApp:
86
  def chat_completions(
87
  self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
88
  ):
89
- streamer = MessageStreamer(model=item.model)
90
- composer = MessageComposer(model=item.model)
91
- composer.merge(messages=item.messages)
92
- # streamer.chat = stream_chat_mock
93
-
94
- stream_response = streamer.chat_response(
95
- prompt=composer.merged_str,
96
- temperature=item.temperature,
97
- top_p=item.top_p,
98
- max_new_tokens=item.max_tokens,
99
- api_key=api_key,
100
- use_cache=item.use_cache,
101
- )
 
 
 
102
  if item.stream:
103
  event_source_response = EventSourceResponse(
104
  streamer.chat_return_generator(stream_response),
@@ -152,17 +158,17 @@ class ArgParser(argparse.ArgumentParser):
152
 
153
  self.add_argument(
154
  "-s",
155
- "--server",
156
  type=str,
157
- default="0.0.0.0",
158
- help="Server IP for HF LLM Chat API",
159
  )
160
  self.add_argument(
161
  "-p",
162
  "--port",
163
  type=int,
164
- default=23333,
165
- help="Server Port for HF LLM Chat API",
166
  )
167
 
168
  self.add_argument(
@@ -181,9 +187,9 @@ app = ChatAPIApp().app
181
  if __name__ == "__main__":
182
  args = ArgParser().args
183
  if args.dev:
184
- uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
185
  else:
186
- uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
187
 
188
  # python -m apis.chat_api # [Docker] on product mode
189
  # python -m apis.chat_api -d # [Dev] on develop mode
 
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
  from pydantic import BaseModel, Field
14
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
15
+ from tclogger import logger
16
+
17
+ from constants.models import AVAILABLE_MODELS_DICTS
18
+ from constants.envs import CONFIG
19
 
20
  from messagers.message_composer import MessageComposer
21
  from mocks.stream_chat_mocker import stream_chat_mock
22
+ from networks.huggingface_streamer import HuggingfaceStreamer
23
+ from networks.openai_streamer import OpenaiStreamer
 
24
 
25
 
26
  class ChatAPIApp:
27
  def __init__(self):
28
  self.app = FastAPI(
29
  docs_url="/",
30
+ title=CONFIG["app_name"],
31
  swagger_ui_parameters={"defaultModelsExpandDepth": -1},
32
+ version=CONFIG["version"],
33
  )
34
  self.setup_routes()
35
 
 
89
  def chat_completions(
90
  self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
91
  ):
92
+ if item.model == "gpt-3.5":
93
+ streamer = OpenaiStreamer()
94
+ stream_response = streamer.chat_response(messages=item.messages)
95
+ else:
96
+ streamer = HuggingfaceStreamer(model=item.model)
97
+ composer = MessageComposer(model=item.model)
98
+ composer.merge(messages=item.messages)
99
+ stream_response = streamer.chat_response(
100
+ prompt=composer.merged_str,
101
+ temperature=item.temperature,
102
+ top_p=item.top_p,
103
+ max_new_tokens=item.max_tokens,
104
+ api_key=api_key,
105
+ use_cache=item.use_cache,
106
+ )
107
+
108
  if item.stream:
109
  event_source_response = EventSourceResponse(
110
  streamer.chat_return_generator(stream_response),
 
158
 
159
  self.add_argument(
160
  "-s",
161
+ "--host",
162
  type=str,
163
+ default=CONFIG["host"],
164
+ help=f"Host for {CONFIG['app_name']}",
165
  )
166
  self.add_argument(
167
  "-p",
168
  "--port",
169
  type=int,
170
+ default=CONFIG["port"],
171
+ help=f"Port for {CONFIG['app_name']}",
172
  )
173
 
174
  self.add_argument(
 
187
  if __name__ == "__main__":
188
  args = ArgParser().args
189
  if args.dev:
190
+ uvicorn.run("__main__:app", host=args.host, port=args.port, reload=True)
191
  else:
192
+ uvicorn.run("__main__:app", host=args.host, port=args.port, reload=False)
193
 
194
  # python -m apis.chat_api # [Docker] on product mode
195
  # python -m apis.chat_api -d # [Dev] on develop mode