text-chat-ui / extensions /api /blocking_api.py
minhxthanh's picture
Upload 340 files
71e47a3
raw
history blame contribute delete
No virus
7.61 kB
import json
import ssl
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import (
encode,
generate_reply,
stop_everything_event
)
from modules.utils import get_available_models
from modules.logging_colors import logger
def get_model_info():
    """Return a snapshot of the currently selected model and global state.

    The result carries the active model name, the active LoRA names, and a
    dump of ``shared.settings`` plus the attributes of ``shared.args``.
    """
    info = {}
    info['model_name'] = shared.model_name
    info['lora_names'] = shared.lora_names

    # dump of the global settings and the command-line arguments
    info['shared.settings'] = shared.settings
    info['shared.args'] = vars(shared.args)
    return info
class Handler(BaseHTTPRequestHandler):
    """Request handler for the blocking (non-streaming) API under ``/api/v1``.

    GET  ``/api/v1/model``        -> name of the current model
    POST ``/api/v1/generate``     -> full text completion for ``prompt``
    POST ``/api/v1/chat``         -> chat history after replying to ``user_input``
    POST ``/api/v1/stop-stream``  -> signal running generations to stop
    POST ``/api/v1/model``        -> model actions: info, load, list, unload
    POST ``/api/v1/token-count``  -> number of tokens in ``prompt``

    Every POST body must be a JSON object. Malformed requests (bad or missing
    ``Content-Length``, invalid JSON, missing required fields) are answered
    with HTTP 400 rather than raising inside the handler. Unknown paths are
    answered with HTTP 404.
    """

    # POST path -> name of the method that serves it.
    _POST_ROUTES = {
        '/api/v1/generate': '_post_generate',
        '/api/v1/chat': '_post_chat',
        '/api/v1/stop-stream': '_post_stop_stream',
        '/api/v1/model': '_post_model',
        '/api/v1/token-count': '_post_token_count',
    }

    def do_GET(self):
        """Serve ``GET /api/v1/model``; any other path gets a 404."""
        if self.path != '/api/v1/model':
            self.send_error(404)
            return

        self._send_json_headers()
        self._write_json({'result': shared.model_name})

    def do_POST(self):
        """Parse the JSON request body and dispatch on ``self.path``."""
        body = self._read_json_body()
        if body is None:
            # _read_json_body has already sent the 400 response.
            return

        route = self._POST_ROUTES.get(self.path)
        if route is None:
            self.send_error(404)
            return

        getattr(self, route)(body)

    def do_OPTIONS(self):
        """Answer CORS preflight requests with an empty 200 response."""
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        """Append permissive CORS and no-cache headers to every response."""
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
        super().end_headers()

    # ------------------------------------------------------------------
    # Request / response helpers
    # ------------------------------------------------------------------

    def _read_json_body(self):
        """Read and decode the request body as a JSON object.

        Returns the decoded dict, or ``None`` after sending a 400 response
        when the body cannot be read or is not a JSON object.
        """
        try:
            content_length = int(self.headers.get('Content-Length'))
        except (TypeError, ValueError):
            # TypeError: header absent (None); ValueError: not a number.
            self.send_error(400, 'Missing or invalid Content-Length header')
            return None

        if content_length < 0:
            # A negative size would make rfile.read() wait for EOF.
            self.send_error(400, 'Missing or invalid Content-Length header')
            return None

        try:
            # Both UnicodeDecodeError and json.JSONDecodeError are ValueErrors.
            body = json.loads(self.rfile.read(content_length).decode('utf-8'))
        except ValueError:
            self.send_error(400, 'Request body is not valid UTF-8 JSON')
            return None

        if not isinstance(body, dict):
            self.send_error(400, 'Request body must be a JSON object')
            return None

        return body

    def _send_json_headers(self):
        """Send a 200 status line and the JSON content-type header."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()

    def _write_json(self, payload):
        """Serialize ``payload`` as JSON and write it to the response body."""
        self.wfile.write(json.dumps(payload).encode('utf-8'))

    # ------------------------------------------------------------------
    # POST routes
    # ------------------------------------------------------------------

    def _post_generate(self, body):
        """Run a non-streaming text generation for ``body['prompt']``."""
        if 'prompt' not in body:
            self.send_error(400, "Missing required field 'prompt'")
            return

        self._send_json_headers()

        prompt = body['prompt']
        generate_params = build_parameters(body)
        stopping_strings = generate_params.pop('stopping_strings')
        generate_params['stream'] = False

        generator = generate_reply(
            prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

        # Only the last value yielded by the generator is returned.
        answer = ''
        for a in generator:
            answer = a

        self._write_json({'results': [{'text': answer}]})

    def _post_chat(self, body):
        """Run a non-streaming chat generation for ``body['user_input']``."""
        if 'user_input' not in body:
            self.send_error(400, "Missing required field 'user_input'")
            return

        self._send_json_headers()

        user_input = body['user_input']
        regenerate = body.get('regenerate', False)
        _continue = body.get('_continue', False)

        generate_params = build_parameters(body, chat=True)
        generate_params['stream'] = False

        generator = generate_chat_reply(
            user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

        # Falls back to the incoming history if the generator yields nothing.
        answer = generate_params['history']
        for a in generator:
            answer = a

        self._write_json({'results': [{'history': answer}]})

    def _post_stop_stream(self, body):
        """Trigger the global stop event and report success."""
        self._send_json_headers()
        stop_everything_event()
        self._write_json({'results': 'success'})

    def _post_model(self, body):
        """Handle model actions: ``info``, ``load``, ``list``, ``unload``.

        With no (or an unknown) action the response matches the GET
        interface and contains only the current model name.
        """
        # Actions: info, load, list, unload
        action = body.get('action', '')
        if action == 'load' and 'model_name' not in body:
            self.send_error(400, "Missing required field 'model_name'")
            return

        self._send_json_headers()

        # by default return the same as the GET interface
        result = shared.model_name

        if action == 'load':
            model_name = body['model_name']
            args = body.get('args', {})
            print('args', args)
            # NOTE(review): caller-supplied keys are applied to shared.args
            # without any allow-list, so an API client can overwrite any
            # attribute. Consider restricting the accepted keys.
            for k in args:
                setattr(shared.args, k, args[k])

            shared.model_name = model_name
            unload_model()

            model_settings = get_model_metadata(shared.model_name)
            shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
            update_model_parameters(model_settings, initial=True)

            if shared.settings['mode'] != 'instruct':
                shared.settings['instruction_template'] = None

            try:
                shared.model, shared.tokenizer = load_model(shared.model_name)
                if shared.args.lora:
                    add_lora_to_model(shared.args.lora)  # list
            except Exception as e:
                # The 200 status line is already sent, so report the failure
                # in the body before re-raising for the server log.
                self._write_json({'error': {'message': repr(e)}})
                raise

            shared.args.model = shared.model_name
            result = get_model_info()

        elif action == 'unload':
            unload_model()
            shared.model_name = None
            shared.args.model = None
            result = get_model_info()

        elif action == 'list':
            result = get_available_models()

        elif action == 'info':
            result = get_model_info()

        self._write_json({'result': result})

    def _post_token_count(self, body):
        """Return the number of tokens that ``body['prompt']`` encodes to."""
        if 'prompt' not in body:
            self.send_error(400, "Missing required field 'prompt'")
            return

        self._send_json_headers()
        tokens = encode(body['prompt'])[0]
        self._write_json({'results': [{'tokens': len(tokens)}]})
def _run_server(port: int, share: bool = False, tunnel_id=str):
    """Create the blocking API HTTP server and serve requests forever.

    Binds to ``0.0.0.0`` when ``shared.args.listen`` is set, otherwise to
    ``127.0.0.1``. The listening socket is wrapped in TLS when both
    ``shared.args.ssl_certfile`` and ``shared.args.ssl_keyfile`` are set.

    Args:
        port: TCP port to listen on.
        share: when true, try to expose the server through a cloudflared
            tunnel instead of logging the local URL.
        tunnel_id: forwarded unchanged to ``try_start_cloudflared``.

    This function does not return; it blocks in ``serve_forever()``.
    """
    # NOTE(review): the default for ``tunnel_id`` is the ``str`` type object,
    # which looks like a mistyped annotation (``tunnel_id: str``). It is kept
    # as-is so the signature stays compatible; confirm what
    # try_start_cloudflared expects before changing it.
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    server = ThreadingHTTPServer((address, port), Handler)

    ssl_certfile = shared.args.ssl_certfile
    ssl_keyfile = shared.args.ssl_keyfile
    ssl_verify = bool(ssl_keyfile and ssl_certfile)
    if ssl_verify:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ssl_certfile, ssl_keyfile)
        server.socket = context.wrap_socket(server.socket, server_side=True)

    def on_start(public_url: str):
        logger.info(f'Starting non-streaming server at public url {public_url}/api')

    if share:
        try:
            try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception as exc:
            # Tunnel setup stays best-effort (the local server still starts),
            # but the failure is logged so it is no longer silent.
            logger.error(f'Failed to start cloudflared tunnel for the API: {exc!r}')
    else:
        scheme = 'https' if ssl_verify else 'http'
        logger.info(f'Starting API at {scheme}://{address}:{port}/api')

    server.serve_forever()
def start_server(port: int, share: bool = False, tunnel_id=str):
    """Launch the blocking API server on a background daemon thread.

    The arguments are passed positionally to ``_run_server``. The thread is
    a daemon, so it does not keep the process alive on its own.
    """
    worker = Thread(
        target=_run_server,
        args=[port, share, tunnel_id],
        daemon=True,
    )
    worker.start()