hanAlex committed on
Commit
2251701
·
verified ·
1 Parent(s): 6768b55

Upload model_server.py

Browse files
Files changed (1) hide show
  1. model_server.py +144 -0
model_server.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker that executes the model using the transformers library.
3
+
4
+ Run BF16 inference with:
5
+
6
+ python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype bfloat16 --device cuda:0
7
+
8
+ Run Int4 inference with:
9
+
10
+ python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype int4 --device cuda:0
11
+
12
+ """
13
+ import argparse
14
+ import json
15
+
16
+ from fastapi import FastAPI, Request
17
+ from fastapi.responses import StreamingResponse
18
+ from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
19
+ from transformers.generation.streamers import BaseStreamer
20
+ import torch
21
+ import uvicorn
22
+
23
+ from threading import Thread
24
+ from queue import Queue
25
+
26
+
27
class TokenStreamer(BaseStreamer):
    """Streamer that hands generated token ids to a consumer through a queue.

    ``put`` and ``end`` push onto an internal ``Queue``; iterating over the
    instance pops from it, so a producer (typically ``model.generate`` running
    in another thread) and a consumer can run concurrently.

    Args:
        skip_prompt: When true, the first ``put`` call is discarded
            (presumably it carries the prompt tokens -- confirm against the
            generate implementation in use).
        timeout: Forwarded to ``Queue.get`` on every ``__next__`` call;
            ``None`` blocks until a value is available.
    """

    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt

        # variables used in the streaming process
        self.token_queue = Queue()
        self.stop_signal = None  # sentinel enqueued by end() to stop iteration
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        """Enqueue every token id contained in ``value``.

        Args:
            value: Tensor-like object exposing ``shape`` and ``tolist``; either
                one-dimensional or two-dimensional with a leading size of 1.

        Raises:
            ValueError: If ``value`` holds more than one batch row.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # Drop only the first chunk when skip_prompt is requested.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        """Enqueue the stop sentinel so the consuming iterator terminates."""
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next queued token id, or stop at the sentinel.

        Raises:
            StopIteration: When the stop sentinel is dequeued.
            queue.Empty: When ``timeout`` elapses with nothing queued.
        """
        value = self.token_queue.get(timeout=self.timeout)
        # Identity check: the sentinel is None and must never be confused
        # with a value that merely compares equal to it.
        if value is self.stop_signal:
            raise StopIteration()
        return value
62
+
63
+
64
class ModelWorker:
    """Loads a model and tokenizer and streams generated token ids.

    Args:
        model_path: Identifier or path forwarded to ``from_pretrained``.
        dtype: ``"int4"`` enables 4-bit bitsandbytes quantization;
            ``"bfloat16"``, ``"float16"`` and ``"float32"`` select the torch
            dtype the weights are loaded in. Any other value leaves the
            library default untouched.
        device: Device the model is placed on and the inputs are moved to.
    """

    # Maps the accepted dtype strings to torch dtypes.
    _TORCH_DTYPES = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }

    def __init__(self, model_path, dtype="bfloat16", device='cuda'):
        self.device = device
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        ) if dtype == "int4" else None

        load_kwargs = dict(
            trust_remote_code=True,
            quantization_config=self.bnb_config,
            # Place the model on the requested device; a hard-coded GPU 0
            # would not match the device the inputs are moved to below.
            device_map={"": device},
        )
        torch_dtype = self._TORCH_DTYPES.get(dtype)
        if torch_dtype is not None:
            # Honour the requested precision instead of silently ignoring it.
            load_kwargs["torch_dtype"] = torch_dtype

        self.glm_model = AutoModel.from_pretrained(model_path, **load_kwargs).eval()
        self.glm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    @torch.inference_mode()
    def generate_stream(self, params):
        """Yield one newline-terminated JSON record per generated token id.

        Args:
            params: Mapping with a required ``"prompt"`` entry and optional
                ``"temperature"`` (default 1.0), ``"top_p"`` (default 1.0)
                and ``"max_new_tokens"`` (default 256).

        Yields:
            UTF-8 encoded bytes of ``{"token_id": ..., "error_code": 0}``
            followed by a newline.

        Raises:
            KeyError: If ``params`` has no ``"prompt"`` entry.
            Exception: Whatever ``model.generate`` raised in the background
                thread, re-raised after the stream has drained.
        """
        tokenizer, model = self.glm_tokenizer, self.glm_model

        prompt = params["prompt"]

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_new_tokens = int(params.get("max_new_tokens", 256))

        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        streamer = TokenStreamer(skip_prompt=True)
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            streamer=streamer
        )
        failures = []

        def _run_generation():
            # Without this guard a failure inside generate() would leave the
            # consumer below blocked on the queue forever.
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                failures.append(exc)
                streamer.end()

        thread = Thread(target=_run_generation)
        thread.start()
        for token_id in streamer:
            yield (json.dumps({"token_id": token_id, "error_code": 0}) + "\n").encode()

        thread.join()
        if failures:
            # Surface the background failure to the caller.
            raise failures[0]

    def generate_stream_gate(self, params):
        """Wrap ``generate_stream`` and convert any exception into an error record.

        Yields:
            Everything ``generate_stream`` yields; if it raises, one final
            record ``{"text": "Server Error", "error_code": 1}``.
        """
        try:
            yield from self.generate_stream(params)
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": "Server Error",
                "error_code": 1,
            }
            yield (json.dumps(ret) + "\n").encode()
120
+
121
+
122
# ASGI application object; it is handed to uvicorn by the script entry point.
app = FastAPI()


@app.post("/generate_stream")
async def generate_stream(request: Request):
    """Stream generation output for the JSON parameters in the request body.

    The decoded body is passed unchanged to ``worker.generate_stream_gate``
    and the resulting generator is returned as a streaming response.
    ``worker`` is a module-level name that is only assigned when this file is
    run as a script.
    """
    payload = await request.json()
    return StreamingResponse(worker.generate_stream_gate(payload))
131
+
132
+
133
if __name__ == "__main__":
    # Script entry point: read the command-line options, load the model once,
    # then serve the FastAPI app with uvicorn.
    cli = argparse.ArgumentParser()
    for flag, kind, fallback in (
        ("--host", str, "localhost"),
        ("--dtype", str, "bfloat16"),
        ("--device", str, "cuda:0"),
        ("--port", int, 10000),
        ("--model-path", str, "THUDM/glm-4-voice-9b"),
    ):
        cli.add_argument(flag, type=kind, default=fallback)
    args = cli.parse_args()

    # The request handler looks up this module-level name.
    worker = ModelWorker(args.model_path, args.dtype, args.device)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")