baohuynhbk14 committed on
Commit b41e98c · 1 Parent(s): 2304772

Remove deprecated API functions and update model initialization in app.py

Files changed (5)
  1. api.py +0 -33
  2. app.py +38 -96
  3. controller.py +0 -291
  4. gradio_web_server.py +0 -761
  5. model_worker.py +0 -541
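
In short, this commit removes the controller/model-worker HTTP stack (controller.py, model_worker.py, gradio_web_server.py, plus the api.py helpers) and has app.py load the Vintern model in-process with transformers. A minimal sketch of the new initialization, mirroring the lines added in the app.py diff below (model id and flags are taken from that diff):

import torch
from transformers import AutoModel, AutoTokenizer

# Load the model once at module import, as the new app.py does.
model_name = "5CD-AI/Vintern-1B-v3_5"
model = AutoModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,   # bfloat16 weights
    low_cpu_mem_usage=True,
    trust_remote_code=True,       # Vintern ships custom modeling code
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
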
api.py DELETED
@@ -1,33 +0,0 @@
1
- # --------------------------------------------------------
2
- # InternVL
3
- # Copyright (c) 2024 OpenGVLab
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # --------------------------------------------------------
6
-
7
- import base64
8
- import json
9
- from io import BytesIO
10
-
11
- import requests
12
- from PIL import Image
13
-
14
-
15
- def get_model_list(controller_url):
16
- ret = requests.post(controller_url + '/refresh_all_workers')
17
- assert ret.status_code == 200
18
- ret = requests.post(controller_url + '/list_models')
19
- models = ret.json()['models']
20
- return models
21
-
22
-
23
- def get_selected_worker_ip(controller_url, selected_model):
24
- ret = requests.post(controller_url + '/get_worker_address',
25
- json={'model': selected_model})
26
- worker_addr = ret.json()['address']
27
- return worker_addr
28
-
29
-
30
- def pil_image_to_base64(image):
31
- buffered = BytesIO()
32
- image.save(buffered, format='PNG')
33
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
app.py CHANGED
@@ -24,11 +24,14 @@ from utils import (
24
  load_image_from_base64,
25
  get_log_filename,
26
  )
 
 
27
  from conversation import Conversation
 
28
 
29
  logger = build_logger("gradio_web_server", "gradio_web_server.log")
30
 
31
- headers = {"User-Agent": "InternVL-Chat Client"}
32
 
33
  no_change_btn = gr.Button()
34
  enable_btn = gr.Button(interactive=True)
@@ -62,64 +65,6 @@ def init_state(state=None):
62
  del state
63
  return Conversation()
64
 
65
-
66
- def find_bounding_boxes(state, response):
67
- pattern = re.compile(r"<ref>\s*(.*?)\s*</ref>\s*<box>\s*(\[\[.*?\]\])\s*</box>")
68
- matches = pattern.findall(response)
69
- results = []
70
- for match in matches:
71
- results.append((match[0], eval(match[1])))
72
- returned_image = None
73
- latest_image = state.get_images(source=state.USER)[-1]
74
- returned_image = latest_image.copy()
75
- width, height = returned_image.size
76
- draw = ImageDraw.Draw(returned_image)
77
- for result in results:
78
- line_width = max(1, int(min(width, height) / 200))
79
- random_color = (
80
- random.randint(0, 128),
81
- random.randint(0, 128),
82
- random.randint(0, 128),
83
- )
84
- category_name, coordinates = result
85
- coordinates = [
86
- (
87
- float(x[0]) / 1000,
88
- float(x[1]) / 1000,
89
- float(x[2]) / 1000,
90
- float(x[3]) / 1000,
91
- )
92
- for x in coordinates
93
- ]
94
- coordinates = [
95
- (
96
- int(x[0] * width),
97
- int(x[1] * height),
98
- int(x[2] * width),
99
- int(x[3] * height),
100
- )
101
- for x in coordinates
102
- ]
103
- for box in coordinates:
104
- draw.rectangle(box, outline=random_color, width=line_width)
105
- font = ImageFont.truetype("assets/SimHei.ttf", int(20 * line_width / 2))
106
- text_size = font.getbbox(category_name)
107
- text_width, text_height = (
108
- text_size[2] - text_size[0],
109
- text_size[3] - text_size[1],
110
- )
111
- text_position = (box[0], max(0, box[1] - text_height))
112
- draw.rectangle(
113
- [
114
- text_position,
115
- (text_position[0] + text_width, text_position[1] + text_height),
116
- ],
117
- fill=random_color,
118
- )
119
- draw.text(text_position, category_name, fill="white", font=font)
120
- return returned_image if len(matches) > 0 else None
121
-
122
-
123
  def vote_last_response(state, liked, request: gr.Request):
124
  conv_data = {
125
  "tstamp": round(time.time(), 4),
@@ -220,6 +165,15 @@ def add_text(state, message, system_prompt, request: gr.Request):
220
  disable_btn,
221
  ) * 5
222
 
 
 
 
 
 
 
 
 
 
223
 
224
  def http_bot(
225
  state,
@@ -230,7 +184,7 @@ def http_bot(
230
  max_input_tiles,
231
  request: gr.Request,
232
  ):
233
- model_name = 'Vintern-1B-v3'
234
  logger.info(f"http_bot. ip: {request.client.host}")
235
  start_tstamp = time.time()
236
  if hasattr(state, "skip_next") and state.skip_next:
@@ -242,12 +196,8 @@ def http_bot(
242
  ) + (no_change_btn,) * 5
243
  return
244
 
245
- worker_addr = os.environ.get("WORKER_ADDR", "")
246
- api_token = os.environ.get("API_TOKEN", "")
247
- headers = {"Authorization": f"{api_token}", "Content-Type": "application/json"}
248
-
249
  # No available worker
250
- if worker_addr == "":
251
  # state.messages[-1][-1] = server_error_msg
252
  state.update_message(Conversation.ASSISTANT, server_error_msg)
253
  yield (
@@ -265,17 +215,6 @@ def http_bot(
265
  all_images = state.get_images(source=state.USER)
266
  all_image_paths = [state.save_image(image) for image in all_images]
267
 
268
- # Make requests
269
- pload = {
270
- "model": model_name,
271
- "messages": state.get_prompt_v2(inlude_image=True, max_dynamic_patch=max_input_tiles),
272
- "temperature": float(temperature),
273
- "top_p": float(top_p),
274
- "max_tokens": max_new_tokens,
275
- "repetition_penalty": repetition_penalty,
276
- "stream": True
277
- }
278
- logger.info(f"==== request ====\n{pload}")
279
  state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
280
  yield (
281
  state,
@@ -285,26 +224,29 @@ def http_bot(
285
 
286
  try:
287
  # Stream output
288
- response = requests.post(worker_addr, json=pload, headers=headers, stream=True, timeout=300)
289
- finnal_output = ''
290
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\n"):
291
- if chunk:
292
- chunk = chunk.decode()
293
- if chunk == 'data: [DONE]':
294
- break
295
- if chunk.startswith("data:"):
296
- chunk = chunk[5:]
297
- chunk = json.loads(chunk)
298
- output = chunk['choices'][0]['delta']['content']
299
- finnal_output += output
 
 
300
 
301
- state.update_message(Conversation.ASSISTANT, finnal_output + state.streaming_placeholder, None)
302
- yield (
303
- state,
304
- state.to_gradio_chatbot(),
305
- gr.MultimodalTextbox(interactive=False),
306
- ) + (disable_btn,) * 5
307
- except requests.exceptions.RequestException as e:
 
308
  state.update_message(Conversation.ASSISTANT, server_error_msg, None)
309
  yield (
310
  state,
@@ -332,7 +274,7 @@ def http_bot(
332
  ) + (enable_btn,) * 5
333
 
334
  finish_tstamp = time.time()
335
- logger.info(f"{finnal_output}")
336
  data = {
337
  "tstamp": round(finish_tstamp, 4),
338
  "like": None,
 
24
  load_image_from_base64,
25
  get_log_filename,
26
  )
27
+ from threading import Thread
28
+ import torch
29
  from conversation import Conversation
30
+ from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
31
 
32
  logger = build_logger("gradio_web_server", "gradio_web_server.log")
33
 
34
+ headers = {"User-Agent": "Vintern-Chat Client"}
35
 
36
  no_change_btn = gr.Button()
37
  enable_btn = gr.Button(interactive=True)
 
65
  del state
66
  return Conversation()
67
68
  def vote_last_response(state, liked, request: gr.Request):
69
  conv_data = {
70
  "tstamp": round(time.time(), 4),
 
165
  disable_btn,
166
  ) * 5
167
 
168
+ model_name = "5CD-AI/Vintern-1B-v3_5"
169
+ model = AutoModel.from_pretrained(
170
+ model_name,
171
+ torch_dtype=torch.bfloat16,
172
+ low_cpu_mem_usage=True,
173
+ trust_remote_code=True,
174
+ ).eval().cuda()
175
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
176
+
177
 
178
  def http_bot(
179
  state,
 
184
  max_input_tiles,
185
  request: gr.Request,
186
  ):
187
+
188
  logger.info(f"http_bot. ip: {request.client.host}")
189
  start_tstamp = time.time()
190
  if hasattr(state, "skip_next") and state.skip_next:
 
196
  ) + (no_change_btn,) * 5
197
  return
198
 
 
 
 
 
199
  # No available worker
200
+ if model is None:
201
  # state.messages[-1][-1] = server_error_msg
202
  state.update_message(Conversation.ASSISTANT, server_error_msg)
203
  yield (
 
215
  all_images = state.get_images(source=state.USER)
216
  all_image_paths = [state.save_image(image) for image in all_images]
217
 
 
 
 
 
 
 
 
 
 
 
 
218
  state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
219
  yield (
220
  state,
 
224
 
225
  try:
226
  # Stream output
227
+ # response = requests.post(worker_addr, json=pload, headers=headers, stream=True, timeout=300)
228
+ streamer = TextIteratorStreamer(
229
+ tokenizer, skip_prompt=True, skip_special_tokens=True
230
+ )
231
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
232
+
233
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
234
+ thread.start()
235
+
236
+ buffer = ""
237
+ for new_text in streamer:
238
+ buffer += new_text
239
+ # Remove <|im_end|> or similar tokens from the output
240
+ buffer = buffer.replace("<|im_end|>", "")
241
 
242
+ state.update_message(Conversation.ASSISTANT, buffer + state.streaming_placeholder, None)
243
+ yield (
244
+ state,
245
+ state.to_gradio_chatbot(),
246
+ gr.MultimodalTextbox(interactive=False),
247
+ ) + (disable_btn,) * 5
248
+
249
+ except Exception as e:
250
  state.update_message(Conversation.ASSISTANT, server_error_msg, None)
251
  yield (
252
  state,
 
274
  ) + (enable_btn,) * 5
275
 
276
  finish_tstamp = time.time()
277
+ logger.info(f"{buffer}")
278
  data = {
279
  "tstamp": round(finish_tstamp, 4),
280
  "like": None,
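
For context on the new streaming path in app.py above: generation now runs on a background thread and partial text is read from a TextIteratorStreamer, instead of iterating over a worker's HTTP response. A minimal standalone sketch of that pattern, assuming `inputs` is the tokenized prompt dict prepared elsewhere in app.py (not shown in this hunk):

from threading import Thread
from transformers import TextIteratorStreamer

# `model`, `tokenizer`, and `inputs` (the tokenized prompt) are assumed to exist as in the diff above.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

# model.generate blocks until finished, so it runs on a worker thread
# while the UI loop consumes tokens from the streamer.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    # Drop end-of-turn markers such as <|im_end|> before rendering the partial answer.
    buffer = buffer.replace("<|im_end|>", "")
thread.join()
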
controller.py DELETED
@@ -1,291 +0,0 @@
1
- """
2
- A controller manages distributed workers.
3
- It sends worker addresses to clients.
4
- """
5
- import argparse
6
- import dataclasses
7
- import json
8
- import re
9
- import threading
10
- import time
11
- from enum import Enum, auto
12
- from typing import List
13
-
14
- import numpy as np
15
- import requests
16
- import uvicorn
17
- from fastapi import FastAPI, Request
18
- from starlette.responses import StreamingResponse
19
- from utils import build_logger, server_error_msg
20
-
21
- CONTROLLER_HEART_BEAT_EXPIRATION = 30
22
- logger = build_logger('controller', 'controller.log')
23
-
24
-
25
- class DispatchMethod(Enum):
26
- LOTTERY = auto()
27
- SHORTEST_QUEUE = auto()
28
-
29
- @classmethod
30
- def from_str(cls, name):
31
- if name == 'lottery':
32
- return cls.LOTTERY
33
- elif name == 'shortest_queue':
34
- return cls.SHORTEST_QUEUE
35
- else:
36
- raise ValueError(f'Invalid dispatch method')
37
-
38
-
39
- @dataclasses.dataclass
40
- class WorkerInfo:
41
- model_names: List[str]
42
- speed: int
43
- queue_length: int
44
- check_heart_beat: bool
45
- last_heart_beat: str
46
-
47
-
48
- def heart_beat_controller(controller):
49
- while True:
50
- time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
51
- controller.remove_stable_workers_by_expiration()
52
-
53
-
54
- class Controller:
55
- def __init__(self, dispatch_method: str):
56
- # Dict[str -> WorkerInfo]
57
- self.worker_info = {}
58
- self.dispatch_method = DispatchMethod.from_str(dispatch_method)
59
-
60
- self.heart_beat_thread = threading.Thread(
61
- target=heart_beat_controller, args=(self,))
62
- self.heart_beat_thread.start()
63
-
64
- logger.info('Init controller')
65
-
66
- def register_worker(self, worker_name: str, check_heart_beat: bool,
67
- worker_status: dict):
68
- if worker_name not in self.worker_info:
69
- logger.info(f'Register a new worker: {worker_name}')
70
- else:
71
- logger.info(f'Register an existing worker: {worker_name}')
72
-
73
- if not worker_status:
74
- worker_status = self.get_worker_status(worker_name)
75
- if not worker_status:
76
- return False
77
-
78
- self.worker_info[worker_name] = WorkerInfo(
79
- worker_status['model_names'], worker_status['speed'], worker_status['queue_length'],
80
- check_heart_beat, time.time())
81
-
82
- logger.info(f'Register done: {worker_name}, {worker_status}')
83
- return True
84
-
85
- def get_worker_status(self, worker_name: str):
86
- try:
87
- r = requests.post(worker_name + '/worker_get_status', timeout=5)
88
- except requests.exceptions.RequestException as e:
89
- logger.error(f'Get status fails: {worker_name}, {e}')
90
- return None
91
-
92
- if r.status_code != 200:
93
- logger.error(f'Get status fails: {worker_name}, {r}')
94
- return None
95
-
96
- return r.json()
97
-
98
- def remove_worker(self, worker_name: str):
99
- del self.worker_info[worker_name]
100
-
101
- def refresh_all_workers(self):
102
- old_info = dict(self.worker_info)
103
- self.worker_info = {}
104
-
105
- for w_name, w_info in old_info.items():
106
- if not self.register_worker(w_name, w_info.check_heart_beat, None):
107
- logger.info(f'Remove stale worker: {w_name}')
108
-
109
- def list_models(self):
110
- model_names = set()
111
-
112
- for w_name, w_info in self.worker_info.items():
113
- model_names.update(w_info.model_names)
114
-
115
- def extract_key(s):
116
- if 'Pro' in s:
117
- return 999
118
- match = re.match(r'Vintern-(\d+)B', s)
119
- if match:
120
- return int(match.group(1))
121
- return -1
122
-
123
- def custom_sort_key(s):
124
- key = extract_key(s)
125
- # Return a tuple where -1 will ensure that non-matching items come last
126
- return (0 if key != -1 else 1, -key if key != -1 else s)
127
-
128
- sorted_list = sorted(list(model_names), key=custom_sort_key)
129
- return sorted_list
130
-
131
- def get_worker_address(self, model_name: str):
132
- if self.dispatch_method == DispatchMethod.LOTTERY:
133
- worker_names = []
134
- worker_speeds = []
135
- for w_name, w_info in self.worker_info.items():
136
- if model_name in w_info.model_names:
137
- worker_names.append(w_name)
138
- worker_speeds.append(w_info.speed)
139
- worker_speeds = np.array(worker_speeds, dtype=np.float32)
140
- norm = np.sum(worker_speeds)
141
- if norm < 1e-4:
142
- return ''
143
- worker_speeds = worker_speeds / norm
144
- if True: # Directly return address
145
- pt = np.random.choice(np.arange(len(worker_names)),
146
- p=worker_speeds)
147
- worker_name = worker_names[pt]
148
- return worker_name
149
-
150
- elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
151
- worker_names = []
152
- worker_qlen = []
153
- for w_name, w_info in self.worker_info.items():
154
- if model_name in w_info.model_names:
155
- worker_names.append(w_name)
156
- worker_qlen.append(w_info.queue_length / w_info.speed)
157
- if len(worker_names) == 0:
158
- return ''
159
- min_index = np.argmin(worker_qlen)
160
- w_name = worker_names[min_index]
161
- self.worker_info[w_name].queue_length += 1
162
- logger.info(f'names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}')
163
- return w_name
164
- else:
165
- raise ValueError(f'Invalid dispatch method: {self.dispatch_method}')
166
-
167
- def receive_heart_beat(self, worker_name: str, queue_length: int):
168
- if worker_name not in self.worker_info:
169
- logger.info(f'Receive unknown heart beat. {worker_name}')
170
- return False
171
-
172
- self.worker_info[worker_name].queue_length = queue_length
173
- self.worker_info[worker_name].last_heart_beat = time.time()
174
- logger.info(f'Receive heart beat. {worker_name}')
175
- return True
176
-
177
- def remove_stable_workers_by_expiration(self):
178
- expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
179
- to_delete = []
180
- for worker_name, w_info in self.worker_info.items():
181
- if w_info.check_heart_beat and w_info.last_heart_beat < expire:
182
- to_delete.append(worker_name)
183
-
184
- for worker_name in to_delete:
185
- self.remove_worker(worker_name)
186
-
187
- def worker_api_generate_stream(self, params):
188
- worker_addr = self.get_worker_address(params['model'])
189
- if not worker_addr:
190
- logger.info(f"no worker: {params['model']}")
191
- ret = {
192
- 'text': server_error_msg,
193
- 'error_code': 2,
194
- }
195
- yield json.dumps(ret).encode() + b'\0'
196
-
197
- try:
198
- response = requests.post(worker_addr + '/worker_generate_stream',
199
- json=params, stream=True, timeout=5)
200
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\0'):
201
- if chunk:
202
- yield chunk + b'\0'
203
- except requests.exceptions.RequestException as e:
204
- logger.info(f'worker timeout: {worker_addr}')
205
- ret = {
206
- 'text': server_error_msg,
207
- 'error_code': 3,
208
- }
209
- yield json.dumps(ret).encode() + b'\0'
210
-
211
- # Let the controller act as a worker to achieve hierarchical
212
- # management. This can be used to connect isolated sub networks.
213
- def worker_api_get_status(self):
214
- model_names = set()
215
- speed = 0
216
- queue_length = 0
217
-
218
- for w_name in self.worker_info:
219
- worker_status = self.get_worker_status(w_name)
220
- if worker_status is not None:
221
- model_names.update(worker_status['model_names'])
222
- speed += worker_status['speed']
223
- queue_length += worker_status['queue_length']
224
-
225
- return {
226
- 'model_names': list(model_names),
227
- 'speed': speed,
228
- 'queue_length': queue_length,
229
- }
230
-
231
-
232
- app = FastAPI()
233
-
234
-
235
- @app.post('/register_worker')
236
- async def register_worker(request: Request):
237
- data = await request.json()
238
- controller.register_worker(
239
- data['worker_name'], data['check_heart_beat'],
240
- data.get('worker_status', None))
241
-
242
-
243
- @app.post('/refresh_all_workers')
244
- async def refresh_all_workers():
245
- models = controller.refresh_all_workers()
246
-
247
-
248
- @app.post('/list_models')
249
- async def list_models():
250
- models = controller.list_models()
251
- return {'models': models}
252
-
253
-
254
- @app.post('/get_worker_address')
255
- async def get_worker_address(request: Request):
256
- data = await request.json()
257
- addr = controller.get_worker_address(data['model'])
258
- return {'address': addr}
259
-
260
-
261
- @app.post('/receive_heart_beat')
262
- async def receive_heart_beat(request: Request):
263
- data = await request.json()
264
- exist = controller.receive_heart_beat(
265
- data['worker_name'], data['queue_length'])
266
- return {'exist': exist}
267
-
268
-
269
- @app.post('/worker_generate_stream')
270
- async def worker_api_generate_stream(request: Request):
271
- params = await request.json()
272
- generator = controller.worker_api_generate_stream(params)
273
- return StreamingResponse(generator)
274
-
275
-
276
- @app.post('/worker_get_status')
277
- async def worker_api_get_status(request: Request):
278
- return controller.worker_api_get_status()
279
-
280
-
281
- if __name__ == '__main__':
282
- parser = argparse.ArgumentParser()
283
- parser.add_argument('--host', type=str, default='0.0.0.0')
284
- parser.add_argument('--port', type=int, default=10075)
285
- parser.add_argument('--dispatch-method', type=str, choices=[
286
- 'lottery', 'shortest_queue'], default='shortest_queue')
287
- args = parser.parse_args()
288
- logger.info(f'args: {args}')
289
-
290
- controller = Controller(args.dispatch_method)
291
- uvicorn.run(app, host=args.host, port=args.port, log_level='info')
gradio_web_server.py DELETED
@@ -1,761 +0,0 @@
1
- import argparse
2
- from ast import parse
3
- import datetime
4
- import json
5
- import os
6
- import time
7
- import hashlib
8
- import re
9
-
10
- import gradio as gr
11
- import requests
12
- import random
13
- from filelock import FileLock
14
- from io import BytesIO
15
- from PIL import Image, ImageDraw, ImageFont
16
-
17
- from constants import LOGDIR
18
- from utils import (
19
- build_logger,
20
- server_error_msg,
21
- violates_moderation,
22
- moderation_msg,
23
- load_image_from_base64,
24
- get_log_filename,
25
- )
26
- from conversation import Conversation
27
-
28
- logger = build_logger("gradio_web_server", "gradio_web_server.log")
29
-
30
- headers = {"User-Agent": "InternVL-Chat Client"}
31
-
32
- no_change_btn = gr.Button()
33
- enable_btn = gr.Button(interactive=True)
34
- disable_btn = gr.Button(interactive=False)
35
-
36
-
37
- def write2file(path, content):
38
- lock = FileLock(f"{path}.lock")
39
- with lock:
40
- with open(path, "a") as fout:
41
- fout.write(content)
42
-
43
-
44
- def sort_models(models):
45
- def custom_sort_key(model_name):
46
- # InternVL-Chat-V1-5 should be the first item
47
- if model_name == "Vintern-1B-v3":
48
- return (1, model_name) # 1 indicates highest precedence
49
- elif model_name.startswith("Vintern-1B-v3"):
50
- return (1, model_name) # 1 indicates highest precedence
51
- else:
52
- return (0, model_name) # 0 indicates normal order
53
-
54
- models.sort(key=custom_sort_key, reverse=True)
55
- try: # We have five InternVL-Chat-V1-5 models, randomly choose one to be the first
56
- first_three = models[:4]
57
- random.shuffle(first_three)
58
- models[:4] = first_three
59
- except:
60
- pass
61
- return models
62
-
63
-
64
- def get_model_list():
65
- logger.info(f"Call `get_model_list`")
66
- ret = requests.post(args.controller_url + "/refresh_all_workers")
67
- logger.info(f"status_code from `get_model_list`: {ret.status_code}")
68
- assert ret.status_code == 200
69
- ret = requests.post(args.controller_url + "/list_models")
70
- logger.info(f"status_code from `list_models`: {ret.status_code}")
71
- models = ret.json()["models"]
72
- models = sort_models(models)
73
-
74
- logger.info(f"Models (from {args.controller_url}): {models}")
75
- return models
76
-
77
-
78
- get_window_url_params = """
79
- function() {
80
- const params = new URLSearchParams(window.location.search);
81
- url_params = Object.fromEntries(params);
82
- console.log(url_params);
83
- return url_params;
84
- }
85
- """
86
-
87
-
88
- def init_state(state=None):
89
- if state is not None:
90
- del state
91
- return Conversation()
92
-
93
- def load_demo(url_params, request: gr.Request = None):
94
- if not request:
95
- logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
96
-
97
- dropdown_update = gr.Dropdown(visible=True)
98
- if "model" in url_params:
99
- model = url_params["model"]
100
- if model in models:
101
- dropdown_update = gr.Dropdown(value=model, visible=True)
102
-
103
- state = init_state()
104
- return state, dropdown_update
105
-
106
-
107
- def load_demo_refresh_model_list(request: gr.Request = None):
108
- if not request:
109
- logger.info(f"load_demo. ip: {request.client.host}")
110
- models = get_model_list()
111
- state = init_state()
112
- dropdown_update = gr.Dropdown(
113
- choices=models, value=models[0] if len(models) > 0 else ""
114
- )
115
- return state, dropdown_update
116
-
117
-
118
- def vote_last_response(state, liked, model_selector, request: gr.Request):
119
- conv_data = {
120
- "tstamp": round(time.time(), 4),
121
- "like": liked,
122
- "model": model_selector,
123
- "state": state.dict(),
124
- "ip": request.client.host,
125
- }
126
- write2file(get_log_filename(), json.dumps(conv_data) + "\n")
127
-
128
-
129
- def upvote_last_response(state, model_selector, request: gr.Request):
130
- logger.info(f"upvote. ip: {request.client.host}")
131
- vote_last_response(state, True, model_selector, request)
132
- textbox = gr.MultimodalTextbox(value=None, interactive=True)
133
- return (textbox,) + (disable_btn,) * 3
134
-
135
-
136
- def downvote_last_response(state, model_selector, request: gr.Request):
137
- logger.info(f"downvote. ip: {request.client.host}")
138
- vote_last_response(state, False, model_selector, request)
139
- textbox = gr.MultimodalTextbox(value=None, interactive=True)
140
- return (textbox,) + (disable_btn,) * 3
141
-
142
-
143
- def vote_selected_response(
144
- state, model_selector, request: gr.Request, data: gr.LikeData
145
- ):
146
- logger.info(
147
- f"Vote: {data.liked}, index: {data.index}, value: {data.value} , ip: {request.client.host}"
148
- )
149
- conv_data = {
150
- "tstamp": round(time.time(), 4),
151
- "like": data.liked,
152
- "index": data.index,
153
- "model": model_selector,
154
- "state": state.dict(),
155
- "ip": request.client.host,
156
- }
157
- write2file(get_log_filename(), json.dumps(conv_data) + "\n")
158
- return
159
-
160
-
161
- def flag_last_response(state, model_selector, request: gr.Request):
162
- logger.info(f"flag. ip: {request.client.host}")
163
- vote_last_response(state, "flag", model_selector, request)
164
- textbox = gr.MultimodalTextbox(value=None, interactive=True)
165
- return (textbox,) + (disable_btn,) * 3
166
-
167
-
168
- def regenerate(state, image_process_mode, request: gr.Request):
169
- logger.info(f"regenerate. ip: {request.client.host}")
170
- # state.messages[-1][-1] = None
171
- state.update_message(Conversation.ASSISTANT, None, -1)
172
- prev_human_msg = state.messages[-2]
173
- if type(prev_human_msg[1]) in (tuple, list):
174
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
175
- state.skip_next = False
176
- textbox = gr.MultimodalTextbox(value=None, interactive=True)
177
- return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
178
-
179
-
180
- def clear_history(request: gr.Request):
181
- logger.info(f"clear_history. ip: {request.client.host}")
182
- state = init_state()
183
- textbox = gr.MultimodalTextbox(value=None, interactive=True)
184
- return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5
185
-
186
-
187
- def change_system_prompt(state, system_prompt, request: gr.Request):
188
- logger.info(f"Change system prompt. ip: {request.client.host}")
189
- state.set_system_message(system_prompt)
190
- return state
191
-
192
-
193
- def add_text(state, message, system_prompt, model_selector, request: gr.Request):
194
- print(f"state: {state}")
195
- if not state:
196
- state, model_selector = load_demo_refresh_model_list(request)
197
- images = message.get("files", [])
198
- text = message.get("text", "").strip()
199
- logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
200
- # import pdb; pdb.set_trace()
201
- textbox = gr.MultimodalTextbox(value=None, interactive=False)
202
- if len(text) <= 0 and len(images) == 0:
203
- state.skip_next = True
204
- return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
205
- if args.moderate:
206
- flagged = violates_moderation(text)
207
- if flagged:
208
- state.skip_next = True
209
- textbox = gr.MultimodalTextbox(
210
- value={"text": moderation_msg}, interactive=True
211
- )
212
- return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
213
- images = [Image.open(path).convert("RGB") for path in images]
214
-
215
- if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
216
- state = init_state(state)
217
- state.set_system_message(system_prompt)
218
- state.append_message(Conversation.USER, text, images)
219
- state.skip_next = False
220
- return (state, state.to_gradio_chatbot(), textbox, model_selector) + (
221
- disable_btn,
222
- ) * 5
223
-
224
-
225
- def http_bot(
226
- state,
227
- model_selector,
228
- temperature,
229
- top_p,
230
- repetition_penalty,
231
- max_new_tokens,
232
- max_input_tiles,
233
- # bbox_threshold,
234
- # mask_threshold,
235
- request: gr.Request,
236
- ):
237
- logger.info(f"http_bot. ip: {request.client.host}")
238
- start_tstamp = time.time()
239
- model_name = model_selector
240
- if hasattr(state, "skip_next") and state.skip_next:
241
- # This generate call is skipped due to invalid inputs
242
- yield (
243
- state,
244
- state.to_gradio_chatbot(),
245
- gr.MultimodalTextbox(interactive=False),
246
- ) + (no_change_btn,) * 5
247
- return
248
-
249
- # Query worker address
250
- controller_url = args.controller_url
251
- ret = requests.post(
252
- controller_url + "/get_worker_address", json={"model": model_name}
253
- )
254
- worker_addr = ret.json()["address"]
255
- logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
256
-
257
- # No available worker
258
- if worker_addr == "":
259
- # state.messages[-1][-1] = server_error_msg
260
- state.update_message(Conversation.ASSISTANT, server_error_msg)
261
- yield (
262
- state,
263
- state.to_gradio_chatbot(),
264
- gr.MultimodalTextbox(interactive=False),
265
- disable_btn,
266
- disable_btn,
267
- disable_btn,
268
- enable_btn,
269
- enable_btn,
270
- )
271
- return
272
-
273
- all_images = state.get_images(source=state.USER)
274
- all_image_paths = [state.save_image(image) for image in all_images]
275
-
276
- # Make requests
277
- pload = {
278
- "model": model_name,
279
- "prompt": state.get_prompt(),
280
- "temperature": float(temperature),
281
- "top_p": float(top_p),
282
- "max_new_tokens": max_new_tokens,
283
- "max_input_tiles": max_input_tiles,
284
- # "bbox_threshold": bbox_threshold,
285
- # "mask_threshold": mask_threshold,
286
- "repetition_penalty": repetition_penalty,
287
- "images": f"List of {len(all_images)} images: {all_image_paths}",
288
- }
289
- logger.info(f"==== request ====\n{pload}")
290
- pload.pop("images")
291
- pload["prompt"] = state.get_prompt(inlude_image=True)
292
- state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
293
- yield (
294
- state,
295
- state.to_gradio_chatbot(),
296
- gr.MultimodalTextbox(interactive=False),
297
- ) + (disable_btn,) * 5
298
-
299
- try:
300
- # Stream output
301
- response = requests.post(
302
- worker_addr + "/worker_generate_stream",
303
- headers=headers,
304
- json=pload,
305
- stream=True,
306
- timeout=20,
307
- )
308
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
309
- if chunk:
310
- data = json.loads(chunk.decode())
311
- if data["error_code"] == 0:
312
- if "text" in data:
313
- output = data["text"].strip()
314
- output += state.streaming_placeholder
315
-
316
- image = None
317
- if "image" in data:
318
- image = load_image_from_base64(data["image"])
319
- _ = state.save_image(image)
320
-
321
- state.update_message(Conversation.ASSISTANT, output, image)
322
- yield (
323
- state,
324
- state.to_gradio_chatbot(),
325
- gr.MultimodalTextbox(interactive=False),
326
- ) + (disable_btn,) * 5
327
- else:
328
- output = (
329
- f"**{data['text']}**" + f" (error_code: {data['error_code']})"
330
- )
331
-
332
- state.update_message(Conversation.ASSISTANT, output, None)
333
- yield (
334
- state,
335
- state.to_gradio_chatbot(),
336
- gr.MultimodalTextbox(interactive=True),
337
- ) + (
338
- disable_btn,
339
- disable_btn,
340
- disable_btn,
341
- enable_btn,
342
- enable_btn,
343
- )
344
- return
345
- except requests.exceptions.RequestException as e:
346
- state.update_message(Conversation.ASSISTANT, server_error_msg, None)
347
- yield (
348
- state,
349
- state.to_gradio_chatbot(),
350
- gr.MultimodalTextbox(interactive=True),
351
- ) + (
352
- disable_btn,
353
- disable_btn,
354
- disable_btn,
355
- enable_btn,
356
- enable_btn,
357
- )
358
- return
359
-
360
- ai_response = state.return_last_message()
361
-
362
- state.end_of_current_turn()
363
-
364
- yield (
365
- state,
366
- state.to_gradio_chatbot(),
367
- gr.MultimodalTextbox(interactive=True),
368
- ) + (enable_btn,) * 5
369
-
370
- finish_tstamp = time.time()
371
- logger.info(f"{output}")
372
- data = {
373
- "tstamp": round(finish_tstamp, 4),
374
- "like": None,
375
- "model": model_name,
376
- "start": round(start_tstamp, 4),
377
- "finish": round(start_tstamp, 4),
378
- "state": state.dict(),
379
- "images": all_image_paths,
380
- "ip": request.client.host,
381
- }
382
- write2file(get_log_filename(), json.dumps(data) + "\n")
383
-
384
-
385
- title_html = """
386
- <h2> <span class="gradient-text" id="text">InternVL2</span><span class="plain-text">: Better than the Best—Expanding Performance Boundaries of Open-Source Multimodal Models with the Progressive Scaling Strategy</span></h2>
387
- <a href="https://internvl.github.io/blog/2024-07-02-InternVL-2.0/">[📜 InternVL2 Blog]</a>
388
- <a href="https://huggingface.co/spaces/OpenGVLab/InternVL">[🤗 HF Demo]</a>
389
- <a href="https://github.com/OpenGVLab/InternVL?tab=readme-ov-file#quick-start-with-huggingface">[🚀 Quick Start]</a>
390
- <a href="https://github.com/OpenGVLab/InternVL/blob/main/document/How_to_use_InternVL_API.md">[🌐 API]</a>
391
- """
392
-
393
- tos_markdown = """
394
- ### Terms of use
395
- By using this service, users are required to agree to the following terms:
396
- The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
397
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
398
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
399
- """
400
-
401
-
402
- learn_more_markdown = """
403
- ### License
404
- The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
405
-
406
- ### Acknowledgement
407
- This demo is modified from LLaVA's demo. Thanks for their awesome work!
408
- """
409
- # .gradio-container {margin: 5px 10px 0 10px !important};
410
- block_css = """
411
- .gradio-container {margin: 0.1% 1% 0 1% !important; max-width: 98% !important;};
412
- #buttons button {
413
- min-width: min(120px,100%);
414
- }
415
-
416
- .gradient-text {
417
- font-size: 28px;
418
- width: auto;
419
- font-weight: bold;
420
- background: linear-gradient(45deg, red, orange, yellow, green, blue, indigo, violet);
421
- background-clip: text;
422
- -webkit-background-clip: text;
423
- color: transparent;
424
- }
425
-
426
- .plain-text {
427
- font-size: 22px;
428
- width: auto;
429
- font-weight: bold;
430
- }
431
- """
432
-
433
- js = """
434
- function createWaveAnimation() {
435
- const text = document.getElementById('text');
436
- var i = 0;
437
- setInterval(function() {
438
- const colors = [
439
- 'red, orange, yellow, green, blue, indigo, violet, purple',
440
- 'orange, yellow, green, blue, indigo, violet, purple, red',
441
- 'yellow, green, blue, indigo, violet, purple, red, orange',
442
- 'green, blue, indigo, violet, purple, red, orange, yellow',
443
- 'blue, indigo, violet, purple, red, orange, yellow, green',
444
- 'indigo, violet, purple, red, orange, yellow, green, blue',
445
- 'violet, purple, red, orange, yellow, green, blue, indigo',
446
- 'purple, red, orange, yellow, green, blue, indigo, violet',
447
- ];
448
- const angle = 45;
449
- const colorIndex = i % colors.length;
450
- text.style.background = `linear-gradient(${angle}deg, ${colors[colorIndex]})`;
451
- text.style.webkitBackgroundClip = 'text';
452
- text.style.backgroundClip = 'text';
453
- text.style.color = 'transparent';
454
- text.style.fontSize = '28px';
455
- text.style.width = 'auto';
456
- text.textContent = 'InternVL2';
457
- text.style.fontWeight = 'bold';
458
- i += 1;
459
- }, 200);
460
- const params = new URLSearchParams(window.location.search);
461
- url_params = Object.fromEntries(params);
462
- // console.log(url_params);
463
- // console.log('hello world...');
464
- // console.log(window.location.search);
465
- // console.log('hello world...');
466
- // alert(window.location.search)
467
- // alert(url_params);
468
- return url_params;
469
- }
470
-
471
- """
472
-
473
-
474
- def build_demo(embed_mode):
475
- textbox = gr.MultimodalTextbox(
476
- interactive=True,
477
- file_types=["image", "video"],
478
- placeholder="Enter message or upload file...",
479
- show_label=False,
480
- )
481
-
482
- with gr.Blocks(
483
- title="InternVL-Chat",
484
- theme=gr.themes.Default(),
485
- css=block_css,
486
- ) as demo:
487
- state = gr.State()
488
-
489
- if not embed_mode:
490
- # gr.Markdown(title_markdown)
491
- gr.HTML(title_html)
492
-
493
- with gr.Row():
494
- with gr.Column(scale=2):
495
-
496
- with gr.Row(elem_id="model_selector_row"):
497
- model_selector = gr.Dropdown(
498
- choices=models,
499
- value=models[0] if len(models) > 0 else "",
500
- # value="InternVL-Chat-V1-5",
501
- interactive=True,
502
- show_label=False,
503
- container=False,
504
- )
505
-
506
- with gr.Accordion("System Prompt", open=False) as system_prompt_row:
507
- system_prompt = gr.Textbox(
508
- value="请尽可能详细地回答用户的问题。",
509
- label="System Prompt",
510
- interactive=True,
511
- )
512
- with gr.Accordion("Parameters", open=False) as parameter_row:
513
- temperature = gr.Slider(
514
- minimum=0.0,
515
- maximum=1.0,
516
- value=0.2,
517
- step=0.1,
518
- interactive=True,
519
- label="Temperature",
520
- )
521
- top_p = gr.Slider(
522
- minimum=0.0,
523
- maximum=1.0,
524
- value=0.7,
525
- step=0.1,
526
- interactive=True,
527
- label="Top P",
528
- )
529
- repetition_penalty = gr.Slider(
530
- minimum=1.0,
531
- maximum=1.5,
532
- value=1.1,
533
- step=0.02,
534
- interactive=True,
535
- label="Repetition penalty",
536
- )
537
- max_output_tokens = gr.Slider(
538
- minimum=0,
539
- maximum=4096,
540
- value=1024,
541
- step=64,
542
- interactive=True,
543
- label="Max output tokens",
544
- )
545
- max_input_tiles = gr.Slider(
546
- minimum=1,
547
- maximum=32,
548
- value=12,
549
- step=1,
550
- interactive=True,
551
- label="Max input tiles (control the image size)",
552
- )
553
- examples = gr.Examples(
554
- examples=[
555
- [
556
- {
557
- "files": [
558
- "gallery/prod_9.jpg",
559
- ],
560
- "text": "What's at the far end of the image?",
561
- }
562
- ],
563
- [
564
- {
565
- "files": [
566
- "gallery/astro_on_unicorn.png",
567
- ],
568
- "text": "What does this image mean?",
569
- }
570
- ],
571
- [
572
- {
573
- "files": [
574
- "gallery/prod_12.png",
575
- ],
576
- "text": "What are the consequences of the easy decisions shown in this image?",
577
- }
578
- ],
579
- [
580
- {
581
- "files": [
582
- "gallery/child_1.jpg",
583
- "gallery/child_2.jpg",
584
- f"gallery/child_3.jpg",
585
- ],
586
- "text": "这三帧图片讲述了一件什么事情?",
587
- }
588
- ],
589
- ],
590
- inputs=[textbox],
591
- )
592
-
593
- with gr.Column(scale=8):
594
- chatbot = gr.Chatbot(
595
- elem_id="chatbot",
596
- label="InternVL2",
597
- height=580,
598
- show_copy_button=True,
599
- show_share_button=True,
600
- avatar_images=[
601
- "assets/human.png",
602
- "assets/assistant.png",
603
- ],
604
- bubble_full_width=False,
605
- )
606
- with gr.Row():
607
- with gr.Column(scale=8):
608
- textbox.render()
609
- with gr.Column(scale=1, min_width=50):
610
- submit_btn = gr.Button(value="Send", variant="primary")
611
- with gr.Row(elem_id="buttons") as button_row:
612
- upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
613
- downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
614
- flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
615
- # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
616
- regenerate_btn = gr.Button(
617
- value="🔄 Regenerate", interactive=False
618
- )
619
- clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
620
-
621
- if not embed_mode:
622
- gr.Markdown(tos_markdown)
623
- gr.Markdown(learn_more_markdown)
624
- url_params = gr.JSON(visible=False)
625
-
626
- # Register listeners
627
- btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
628
- upvote_btn.click(
629
- upvote_last_response,
630
- [state, model_selector],
631
- [textbox, upvote_btn, downvote_btn, flag_btn],
632
- )
633
- downvote_btn.click(
634
- downvote_last_response,
635
- [state, model_selector],
636
- [textbox, upvote_btn, downvote_btn, flag_btn],
637
- )
638
- chatbot.like(
639
- vote_selected_response,
640
- [state, model_selector],
641
- [],
642
- )
643
- flag_btn.click(
644
- flag_last_response,
645
- [state, model_selector],
646
- [textbox, upvote_btn, downvote_btn, flag_btn],
647
- )
648
- regenerate_btn.click(
649
- regenerate,
650
- [state, system_prompt],
651
- [state, chatbot, textbox] + btn_list,
652
- ).then(
653
- http_bot,
654
- [
655
- state,
656
- model_selector,
657
- temperature,
658
- top_p,
659
- repetition_penalty,
660
- max_output_tokens,
661
- max_input_tiles,
662
- # bbox_threshold,
663
- # mask_threshold,
664
- ],
665
- [state, chatbot, textbox] + btn_list,
666
- )
667
- clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
668
-
669
- textbox.submit(
670
- add_text,
671
- [state, textbox, system_prompt, model_selector],
672
- [state, chatbot, textbox, model_selector] + btn_list,
673
- ).then(
674
- http_bot,
675
- [
676
- state,
677
- model_selector,
678
- temperature,
679
- top_p,
680
- repetition_penalty,
681
- max_output_tokens,
682
- max_input_tiles,
683
- # bbox_threshold,
684
- # mask_threshold,
685
- ],
686
- [state, chatbot, textbox] + btn_list,
687
- )
688
- submit_btn.click(
689
- add_text,
690
- [state, textbox, system_prompt, model_selector],
691
- [state, chatbot, textbox, model_selector] + btn_list,
692
- ).then(
693
- http_bot,
694
- [
695
- state,
696
- model_selector,
697
- temperature,
698
- top_p,
699
- repetition_penalty,
700
- max_output_tokens,
701
- max_input_tiles,
702
- # bbox_threshold,
703
- # mask_threshold,
704
- ],
705
- [state, chatbot, textbox] + btn_list,
706
- )
707
-
708
- # NOTE: The following code will be not triggered when deployed on HF space.
709
- # It's very strange. I don't know why.
710
- """
711
- if args.model_list_mode == "once":
712
- demo.load(
713
- load_demo,
714
- [url_params],
715
- [state, model_selector],
716
- js=js,
717
- )
718
- elif args.model_list_mode == "reload":
719
- demo.load(
720
- load_demo_refresh_model_list,
721
- None,
722
- [state, model_selector],
723
- js=js,
724
- )
725
- else:
726
- raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
727
- """
728
-
729
- return demo
730
-
731
-
732
- if __name__ == "__main__":
733
- parser = argparse.ArgumentParser()
734
- parser.add_argument("--host", type=str, default="0.0.0.0")
735
- parser.add_argument("--port", type=int, default=7860)
736
- parser.add_argument("--controller-url", type=str, default=None)
737
- parser.add_argument("--concurrency-count", type=int, default=10)
738
- parser.add_argument(
739
- "--model-list-mode", type=str, default="reload", choices=["once", "reload"]
740
- )
741
- parser.add_argument("--share", action="store_true")
742
- parser.add_argument("--moderate", action="store_true")
743
- parser.add_argument("--embed", action="store_true")
744
- args = parser.parse_args()
745
- logger.info(f"args: {args}")
746
- if not args.controller_url:
747
- args.controller_url = os.environ.get("CONTROLLER_URL", None)
748
-
749
- if not args.controller_url:
750
- raise ValueError("controller-url is required.")
751
-
752
- models = get_model_list()
753
-
754
- logger.info(args)
755
- demo = build_demo(args.embed)
756
- demo.queue(api_open=False).launch(
757
- server_name=args.host,
758
- server_port=args.port,
759
- share=args.share,
760
- max_threads=args.concurrency_count,
761
- )
model_worker.py DELETED
@@ -1,541 +0,0 @@
1
- # --------------------------------------------------------
2
- # InternVL
3
- # Copyright (c) 2024 OpenGVLab
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # --------------------------------------------------------
6
-
7
- """
8
- A model worker executes the model.
9
- """
10
- import spaces
11
- import os
12
- import argparse
13
- import asyncio
14
-
15
- import json
16
- import math
17
- import threading
18
- import time
19
- import uuid
20
- import traceback
21
- from functools import partial
22
- from threading import Thread
23
-
24
- import requests
25
- import torch
26
- import torchvision.transforms as T
27
- import uvicorn
28
- from constants import IMAGENET_MEAN, IMAGENET_STD, WORKER_HEART_BEAT_INTERVAL
29
- from fastapi import BackgroundTasks, FastAPI, Request
30
- from fastapi.responses import StreamingResponse
31
- from PIL import Image
32
- from torchvision.transforms.functional import InterpolationMode
33
- from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
34
- from utils import (
35
- build_logger,
36
- pretty_print_semaphore,
37
- server_error_msg,
38
- load_image_from_base64,
39
- )
40
-
41
-
42
- worker_id = str(uuid.uuid4())[:6]
43
- logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
44
- global_counter = 0
45
- model_semaphore = None
46
-
47
-
48
- def build_transform(input_size):
49
- MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
50
- transform = T.Compose(
51
- [
52
- T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
53
- T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
54
- T.ToTensor(),
55
- T.Normalize(mean=MEAN, std=STD),
56
- ]
57
- )
58
- return transform
59
-
60
-
61
- def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
62
- best_ratio_diff = float("inf")
63
- best_ratio = (1, 1)
64
- area = width * height
65
- for ratio in target_ratios:
66
- target_aspect_ratio = ratio[0] / ratio[1]
67
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
68
- if ratio_diff < best_ratio_diff:
69
- best_ratio_diff = ratio_diff
70
- best_ratio = ratio
71
- elif ratio_diff == best_ratio_diff:
72
- if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
73
- best_ratio = ratio
74
- return best_ratio
75
-
76
-
77
- def dynamic_preprocess(
78
- image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
79
- ):
80
- orig_width, orig_height = image.size
81
- aspect_ratio = orig_width / orig_height
82
-
83
- # calculate the existing image aspect ratio
84
- target_ratios = set(
85
- (i, j)
86
- for n in range(min_num, max_num + 1)
87
- for i in range(1, n + 1)
88
- for j in range(1, n + 1)
89
- if i * j <= max_num and i * j >= min_num
90
- )
91
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
92
-
93
- # find the closest aspect ratio to the target
94
- target_aspect_ratio = find_closest_aspect_ratio(
95
- aspect_ratio, target_ratios, orig_width, orig_height, image_size
96
- )
97
-
98
- # calculate the target width and height
99
- target_width = image_size * target_aspect_ratio[0]
100
- target_height = image_size * target_aspect_ratio[1]
101
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
102
-
103
- # resize the image
104
- resized_img = image.resize((target_width, target_height))
105
- processed_images = []
106
- for i in range(blocks):
107
- box = (
108
- (i % (target_width // image_size)) * image_size,
109
- (i // (target_width // image_size)) * image_size,
110
- ((i % (target_width // image_size)) + 1) * image_size,
111
- ((i // (target_width // image_size)) + 1) * image_size,
112
- )
113
- # split the image
114
- split_img = resized_img.crop(box)
115
- processed_images.append(split_img)
116
- assert len(processed_images) == blocks
117
- if use_thumbnail and len(processed_images) != 1:
118
- thumbnail_img = image.resize((image_size, image_size))
119
- processed_images.append(thumbnail_img)
120
- return processed_images
121
-
122
-
123
- def heart_beat_worker(controller):
124
- while True:
125
- time.sleep(WORKER_HEART_BEAT_INTERVAL)
126
- controller.send_heart_beat()
127
-
128
-
129
- def split_model(model_name):
130
- device_map = {}
131
- world_size = torch.cuda.device_count()
132
- num_layers = {
133
- "Vintern-1B-v3": 24,
134
- }[model_name]
135
- # Since the first GPU will be used for ViT, treat it as half a GPU.
136
- num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
137
- num_layers_per_gpu = [num_layers_per_gpu] * world_size
138
- num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
139
- layer_cnt = 0
140
- for i, num_layer in enumerate(num_layers_per_gpu):
141
- for j in range(num_layer):
142
- device_map[f"language_model.model.layers.{layer_cnt}"] = i
143
- layer_cnt += 1
144
- device_map["vision_model"] = 0
145
- device_map["mlp1"] = 0
146
- device_map["language_model.model.tok_embeddings"] = 0
147
- device_map["language_model.model.embed_tokens"] = 0
148
- device_map["language_model.output"] = 0
149
- device_map["language_model.model.norm"] = 0
150
- device_map["language_model.lm_head"] = 0
151
- device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
152
-
153
- return device_map
154
-
155
-
156
- def multi_thread_infer(
157
- model, tokenizer, pixel_values, question, history, generation_config
158
- ):
159
- with torch.no_grad():
160
- thread = Thread(
161
- target=model.chat,
162
- kwargs=dict(
163
- tokenizer=tokenizer,
164
- pixel_values=pixel_values,
165
- question=question,
166
- history=history,
167
- return_history=False,
168
- generation_config=generation_config,
169
- ),
170
- )
171
- thread.start()
172
-
173
-
174
- class ModelWorker:
175
- def __init__(
176
- self,
177
- controller_addr,
178
- worker_addr,
179
- worker_id,
180
- model_path,
181
- model_name,
182
- load_8bit,
183
- device,
184
- context_len=8192,
185
- ):
186
- self.controller_addr = controller_addr
187
- self.worker_addr = worker_addr
188
- self.worker_id = worker_id
189
- if model_path.endswith("/"):
190
- model_path = model_path[:-1]
191
- if model_name is None:
192
- model_paths = model_path.split("/")
193
- if model_paths[-1].startswith("checkpoint-"):
194
- self.model_name = model_paths[-2] + "_" + model_paths[-1]
195
- else:
196
- self.model_name = model_paths[-1]
197
- else:
198
- self.model_name = model_name
199
-
200
- self.import_flash_attn()
201
- logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
202
- tokenizer = AutoTokenizer.from_pretrained(
203
- model_path, trust_remote_code=True, use_fast=False
204
- )
205
- self.tokenizer = tokenizer
206
-
207
- if device == "auto":
208
- device_map = split_model(self.model_name)
209
- self.model = AutoModel.from_pretrained(
210
- model_path,
211
- load_in_8bit=load_8bit,
212
- torch_dtype=torch.bfloat16,
213
- device_map=device_map,
214
- trust_remote_code=True,
215
- ).eval()
216
- else:
217
- self.model = AutoModel.from_pretrained(
218
- model_path,
219
- load_in_8bit=load_8bit,
220
- torch_dtype=torch.bfloat16,
221
- trust_remote_code=True,
222
- ).eval()
223
- if not load_8bit and not device == "auto":
224
- self.model = self.model.cuda()
225
- self.load_8bit = load_8bit
226
- self.device = device
227
- self.model_path = model_path
228
- self.image_size = self.model.config.force_image_size
229
- self.context_len = context_len
230
- self.register_to_controller()
231
- self.heart_beat_thread = threading.Thread(
232
- target=heart_beat_worker, args=(self,)
233
- )
234
- self.heart_beat_thread.start()
235
-
236
- @spaces.GPU(duration=120)
237
- def import_flash_attn(self):
238
- try:
239
- import flash_attn
240
- except ImportError:
241
-
242
- def install_flash_attn():
243
- os.system(
244
- "FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn==2.5.9.post1 --no-build-isolation"
245
- )
246
-
247
- install_flash_attn()
248
- # import flash_attn
249
-
250
- def reload_model(self):
251
- del self.model
252
- torch.cuda.empty_cache()
253
- if self.device == "auto":
254
- device_map = split_model(self.model_name)
255
- self.model = AutoModel.from_pretrained(
256
- self.model_path,
257
- load_in_8bit=self.load_8bit,
258
- torch_dtype=torch.bfloat16,
259
- device_map=device_map,
260
- trust_remote_code=True,
261
- ).eval()
262
- else:
263
- self.model = AutoModel.from_pretrained(
264
- self.model_path,
265
- load_in_8bit=self.load_8bit,
266
- torch_dtype=torch.bfloat16,
267
- trust_remote_code=True,
268
- ).eval()
269
- if not self.load_8bit and not self.device == "auto":
270
- self.model = self.model.cuda()
271
-
272
- def register_to_controller(self):
273
- logger.info("Register to controller")
274
-
275
- url = self.controller_addr + "/register_worker"
276
- data = {
277
- "worker_name": self.worker_addr,
278
- "check_heart_beat": True,
279
- "worker_status": self.get_status(),
280
- }
281
- r = requests.post(url, json=data)
282
- assert r.status_code == 200
283
-
284
- def send_heart_beat(self):
285
- logger.info(
286
- f"Send heart beat. Models: {[self.model_name]}. "
287
- f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
288
- f"global_counter: {global_counter}"
289
- )
290
-
291
- url = self.controller_addr + "/receive_heart_beat"
292
-
293
- while True:
294
- try:
295
- ret = requests.post(
296
- url,
297
- json={
298
- "worker_name": self.worker_addr,
299
- "queue_length": self.get_queue_length(),
300
- },
301
- timeout=5,
302
- )
303
- exist = ret.json()["exist"]
304
- break
305
- except requests.exceptions.RequestException as e:
306
- logger.error(f"heart beat error: {e}")
307
- time.sleep(5)
308
-
309
- if not exist:
310
- self.register_to_controller()
311
-
312
- def get_queue_length(self):
313
- if model_semaphore is None:
314
- return 0
315
- else:
316
- return (
317
- args.limit_model_concurrency
318
- - model_semaphore._value
319
- + (
320
- len(model_semaphore._waiters)
321
- if model_semaphore._waiters is not None
322
- else 0
323
- )
324
- )
325
-
326
- def get_status(self):
327
- return {
328
- "model_names": [self.model_name],
329
- "speed": 1,
330
- "queue_length": self.get_queue_length(),
331
- }
332
-
333
- def generate_stream(self, params):
334
- system_message = params["prompt"][0]["content"]
335
- send_messages = params["prompt"][1:]
336
- max_input_tiles = params["max_input_tiles"]
337
- temperature = params["temperature"]
338
- top_p = params["top_p"]
339
- max_new_tokens = params["max_new_tokens"]
340
- repetition_penalty = params["repetition_penalty"]
341
- do_sample = True if temperature > 0.0 else False
342
-
343
- global_image_cnt = 0
344
- history, pil_images, max_input_tile_list = [], [], []
345
- for message in send_messages:
346
- if message["role"] == "user":
347
- prefix = ""
348
- if "image" in message:
349
- max_input_tile_temp = []
350
- for image_str in message["image"]:
351
- pil_images.append(load_image_from_base64(image_str))
352
- prefix += f"Image-{global_image_cnt + 1}: <image>\n\n"
353
- global_image_cnt += 1
354
- max_input_tile_temp.append(
355
- max(1, max_input_tiles // len(message["image"]))
356
- )
357
- if len(max_input_tile_temp) > 0:
358
- max_input_tile_list.append(max_input_tile_temp)
359
- content = prefix + message["content"]
360
- history.append(
361
- [
362
- content,
363
- ]
364
- )
365
- else:
366
- history[-1].append(message["content"])
367
- question, history = history[-1][0], history[:-1]
368
-
369
- if global_image_cnt == 1:
370
- question = question.replace("Image-1: <image>\n\n", "<image>\n")
371
- history = [
372
- [item[0].replace("Image-1: <image>\n\n", "<image>\n"), item[1]]
373
- for item in history
374
- ]
375
-
376
- # Create a new list to store processed sublists
377
- flattened_list = []
378
- # Iterate through all but the last sublist in max_input_tile_list and process them
379
- for sublist in max_input_tile_list[:-1]:
380
- processed_sublist = [1] * len(
381
- sublist
382
- ) # Change each element in the sublist to 1
383
- flattened_list.extend(
384
- processed_sublist
385
- ) # Flatten the processed sublist and add to the new list
386
- # If max_input_tile_list is not empty, add the last sublist to the new list
387
- if max_input_tile_list:
388
- flattened_list.extend(max_input_tile_list[-1])
389
- max_input_tile_list = flattened_list
390
- assert len(max_input_tile_list) == len(
391
- pil_images
392
- ), "The number of max_input_tile_list and pil_images should be the same."
393
-
394
- old_system_message = self.model.system_message
395
- self.model.system_message = system_message
396
- image_tiles = []
397
- transform = build_transform(input_size=self.image_size)
398
- if len(pil_images) > 0:
399
- for current_max_input_tiles, pil_image in zip(
400
- max_input_tile_list, pil_images
401
- ):
402
- if self.model.config.dynamic_image_size:
403
- tiles = dynamic_preprocess(
404
- pil_image,
405
- image_size=self.image_size,
406
- max_num=current_max_input_tiles,
407
- use_thumbnail=self.model.config.use_thumbnail,
408
- )
409
- else:
410
- tiles = [pil_image]
411
- image_tiles += tiles
412
- pixel_values = [transform(item) for item in image_tiles]
413
- pixel_values = torch.stack(pixel_values).to(
414
- self.model.device, dtype=torch.bfloat16
415
- )
416
- logger.info(f"Split images to {pixel_values.shape}")
417
- else:
418
- pixel_values = None
419
-
420
- streamer = TextIteratorStreamer(
421
- self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
422
- )
423
- generation_config = dict(
424
- num_beams=1,
425
- max_new_tokens=max_new_tokens,
426
- do_sample=do_sample,
427
- temperature=temperature,
428
- repetition_penalty=repetition_penalty,
429
- max_length=self.context_len,
430
- top_p=top_p,
431
- streamer=streamer,
432
- )
433
- logger.info(f"Generation config: {generation_config}")
434
- multi_thread_infer(
435
- self.model,
436
- self.tokenizer,
437
- pixel_values,
438
- question,
439
- history,
440
- generation_config,
441
- )
442
-
443
- generated_text = ""
444
- for new_text in streamer:
445
- generated_text += new_text
446
- if generated_text.endswith(self.model.conv_template.sep):
447
- generated_text = generated_text[: -len(self.model.conv_template.sep)]
448
- yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
449
- logger.info(
450
- f"max_input_tile_list: {max_input_tile_list}, history: {history}, "
451
- f"question: {question}, answer: {generated_text}"
452
- )
453
- self.model.system_message = old_system_message
454
-
455
- def generate_stream_gate(self, params):
456
- try:
457
- for x in self.generate_stream(params):
458
- yield x
459
- except ValueError as e:
460
- print("Caught ValueError:", e)
461
- traceback.print_exc()
462
- ret = {
463
- "text": server_error_msg,
464
- "error_code": 1,
465
- }
466
- yield json.dumps(ret).encode() + b"\0"
467
- except torch.cuda.CudaError as e:
468
- traceback.print_exc()
469
- print("Caught torch.cuda.CudaError:", e)
470
- ret = {
471
- "text": server_error_msg,
472
- "error_code": 1,
473
- }
474
- yield json.dumps(ret).encode() + b"\0"
475
- except Exception as e:
476
- traceback.print_exc()
477
- print("Caught Unknown Error", e)
478
- ret = {
479
- "text": server_error_msg,
480
- "error_code": 1,
481
- }
482
- yield json.dumps(ret).encode() + b"\0"
483
-
484
-
485
- app = FastAPI()
486
-
487
-
488
- def release_model_semaphore(fn=None):
489
- model_semaphore.release()
490
- if fn is not None:
491
- fn()
492
-
493
-
494
- @app.post("/worker_generate_stream")
495
- async def generate_stream(request: Request):
496
- global model_semaphore, global_counter
497
- global_counter += 1
498
- params = await request.json()
499
-
500
- if model_semaphore is None:
501
- model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
502
- await model_semaphore.acquire()
503
- worker.send_heart_beat()
504
- generator = worker.generate_stream_gate(params)
505
- background_tasks = BackgroundTasks()
506
- background_tasks.add_task(
507
- partial(release_model_semaphore, fn=worker.send_heart_beat)
508
- )
509
- return StreamingResponse(generator, background=background_tasks)
510
-
511
-
512
- @app.post("/worker_get_status")
513
- async def get_status(request: Request):
514
- return worker.get_status()
515
-
516
-
517
- if __name__ == "__main__":
518
- parser = argparse.ArgumentParser()
519
- parser.add_argument("--host", type=str, default="0.0.0.0")
520
- parser.add_argument("--port", type=int, default=21002)
521
- parser.add_argument("--worker-url", type=str, default="http://localhost")
522
- parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
523
- parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
524
- parser.add_argument("--model-name", type=str)
525
- parser.add_argument("--device", type=str, default="cuda")
526
- parser.add_argument("--limit-model-concurrency", type=int, default=5)
527
- parser.add_argument("--stream-interval", type=int, default=1)
528
- parser.add_argument("--load-8bit", action="store_true")
529
- args = parser.parse_args()
530
- logger.info(f"args: {args}")
531
-
532
- worker = ModelWorker(
533
- args.controller_url,
534
- args.worker_url + f":{args.port}",
535
- worker_id,
536
- args.model_path,
537
- args.model_name,
538
- args.load_8bit,
539
- args.device,
540
- )
541
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")