File size: 8,731 Bytes
7def60a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
#!/usr/bin/env python3
import asyncio
from concurrent import futures
import argparse
import signal
import sys
import os
import backend_pb2
import backend_pb2_grpc
import grpc
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.transformers_utils.tokenizer import get_tokenizer
# Number of seconds in one day.
_ONE_DAY_IN_SECONDS = 24 * 60 * 60

# Worker count read from the PYTHON_GRPC_MAX_WORKERS environment variable;
# falls back to 1 when the variable is not set.
MAX_WORKERS = int(os.getenv('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.

    ``LoadModel`` stores the vLLM engine on ``self.llm`` and its tokenizer on
    ``self.tokenizer``; ``Predict`` and ``PredictStream`` read those attributes,
    so a model has to be loaded before either of them is called.
    """

    def generate(self, prompt, max_new_tokens):
        """
        Generates text based on the given prompt and maximum number of new tokens.

        NOTE(review): this method reads ``self.generator``, which nothing in
        this class assigns, so calling it raises AttributeError unless a caller
        sets that attribute first. It looks like a leftover from a different
        backend -- confirm before relying on it.

        Args:
            prompt (str): The prompt to generate text from.
            max_new_tokens (int): The maximum number of new tokens to generate.

        Returns:
            str: The generated text.
        """
        self.generator.end_beam_search()

        # Tokenizing the input
        ids = self.generator.tokenizer.encode(prompt)

        self.generator.gen_begin_reuse(ids)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False
        decoded_text = ''
        for i in range(max_new_tokens):
            token = self.generator.gen_single_token()
            # Remember whether the first generated piece starts with the '▁'
            # marker so a leading space can be put back in front of the text.
            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                has_leading_space = True

            decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
            if has_leading_space:
                decoded_text = ' ' + decoded_text

            if token.item() == self.generator.tokenizer.eos_token_id:
                break
        return decoded_text

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply, carrying the bytes ``OK``.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Loads a language model.

        Optional request fields left at their empty/zero value are not applied,
        so the corresponding ``AsyncEngineArgs`` defaults stay in effect.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result; ``success`` is False and
            ``message`` describes the error when the engine or the tokenizer
            could not be created.
        """
        engine_args = AsyncEngineArgs(
            model=request.Model,
        )

        if request.Quantization != "":
            engine_args.quantization = request.Quantization
        if request.GPUMemoryUtilization != 0:
            engine_args.gpu_memory_utilization = request.GPUMemoryUtilization
        if request.TrustRemoteCode:
            engine_args.trust_remote_code = request.TrustRemoteCode
        if request.EnforceEager:
            engine_args.enforce_eager = request.EnforceEager
        if request.TensorParallelSize:
            engine_args.tensor_parallel_size = request.TensorParallelSize
        if request.SwapSpace != 0:
            engine_args.swap_space = request.SwapSpace
        if request.MaxModelLen != 0:
            engine_args.max_model_len = request.MaxModelLen

        # Errors are reported through the Result message rather than raised,
        # so the caller always receives a structured reply.
        try:
            self.llm = AsyncLLMEngine.from_engine_args(engine_args)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        try:
            engine_model_config = await self.llm.get_model_config()
            self.tokenizer = get_tokenizer(
                engine_model_config.tokenizer,
                tokenizer_mode=engine_model_config.tokenizer_mode,
                trust_remote_code=engine_model_config.trust_remote_code,
                truncation_side="left",
            )
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    async def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result holding the full generated text.
        """
        gen = self._predict(request, context, streaming=False)
        try:
            # In non-streaming mode _predict yields a single reply: the full text.
            return await gen.__anext__()
        finally:
            # The generator is still suspended at its final yield; close it
            # here instead of leaving its finalization to garbage collection.
            await gen.aclose()

    async def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Yields:
            backend_pb2.Reply: One reply per engine iteration, each carrying
            only the text produced since the previous reply.
        """
        iterations = self._predict(request, context, streaming=True)
        try:
            async for iteration in iterations:
                yield iteration
        finally:
            # Make sure the inner generator is closed even if the consumer
            # stops iterating early.
            await iterations.aclose()

    def _build_sampling_params(self, request):
        """Translate the request's sampling fields into vLLM SamplingParams.

        Starts from ``top_p=0.9`` and ``max_tokens=200`` and overrides a
        setting only when the request carries a non-zero / non-empty value.
        """
        sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
        if request.TopP != 0:
            sampling_params.top_p = request.TopP
        if request.Tokens > 0:
            sampling_params.max_tokens = request.Tokens
        if request.Temperature != 0:
            sampling_params.temperature = request.Temperature
        if request.TopK != 0:
            sampling_params.top_k = request.TopK
        if request.PresencePenalty != 0:
            sampling_params.presence_penalty = request.PresencePenalty
        if request.FrequencyPenalty != 0:
            sampling_params.frequency_penalty = request.FrequencyPenalty
        if request.StopPrompts:
            # Copy into a plain list so the sampling params do not hold on to
            # the request's own container object.
            sampling_params.stop = list(request.StopPrompts)
        if request.IgnoreEOS:
            sampling_params.ignore_eos = request.IgnoreEOS
        if request.Seed != 0:
            sampling_params.seed = request.Seed
        return sampling_params

    async def _predict(self, request, context, streaming=False):
        """
        Shared implementation behind Predict and PredictStream.

        Args:
            request: The predict request.
            context: The gRPC context (not used here).
            streaming (bool): When True, yield one reply per engine iteration
                with only the newly generated text; when False, yield a single
                reply with the complete text once generation has finished.

        Yields:
            backend_pb2.Reply: UTF-8 encoded generated text.
        """
        # Build sampling parameters
        sampling_params = self._build_sampling_params(request)

        prompt = request.Prompt

        # If tokenizer template is enabled and messages are provided instead of prompt apply the tokenizer template
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)

        # Generate text
        request_id = random_uuid()
        outputs = self.llm.generate(prompt, sampling_params, request_id)

        # Stream the results
        generated_text = ""
        try:
            async for request_output in outputs:
                iteration_text = request_output.outputs[0].text

                if streaming:
                    # Remove text already sent as vllm concatenates the text from previous yields
                    delta_iteration_text = iteration_text.removeprefix(generated_text)
                    # Send the partial result
                    yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))

                # Keep track of text generated
                generated_text = iteration_text
        finally:
            await outputs.aclose()

        # If streaming, we already sent everything
        if streaming:
            return

        # Sending the final generated text
        yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
async def serve(address):
    """
    Run the asyncio gRPC server on the given address until it terminates.

    Registers a BackendServicer, binds an insecure port, and installs
    SIGINT/SIGTERM handlers that stop the server with a 5 second grace period.

    Args:
        address (str): The address to bind to, e.g. "localhost:50051".
    """
    # Start asyncio gRPC server
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    )
    # Add the servicer to the server
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address
    server.add_insecure_port(address)

    # Gracefully shutdown the server on SIGTERM or SIGINT.
    # We are inside a coroutine, so ask for the running loop directly.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    # Wait for the server to be terminated
    await server.wait_for_termination()
# Script entry point: parse the bind address and run the server until it stops.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    asyncio.run(serve(args.addr))