Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -6,6 +6,7 @@
 VLLM-based demo script to launch Language chat model for Southeast Asian Languages
 """
 
+
 import os
 import numpy as np
 import argparse
@@ -972,53 +973,6 @@ gr.ChatInterface._setup_stop_events = _setup_stop_events
 gr.ChatInterface._setup_events = _setup_events
 
 
-
-@document()
-class CustomTabbedInterface(gr.Blocks):
-    def __init__(
-        self,
-        interface_list: list[gr.Interface],
-        tab_names: Optional[list[str]] = None,
-        title: Optional[str] = None,
-        description: Optional[str] = None,
-        theme: Optional[gr.Theme] = None,
-        analytics_enabled: Optional[bool] = None,
-        css: Optional[str] = None,
-    ):
-        """
-        Parameters:
-            interface_list: a list of interfaces to be rendered in tabs.
-            tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
-            title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
-            analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
-            css: custom css or path to custom css file to apply to entire Blocks
-        Returns:
-            a Gradio Tabbed Interface for the given interfaces
-        """
-        super().__init__(
-            title=title or "Gradio",
-            theme=theme,
-            analytics_enabled=analytics_enabled,
-            mode="tabbed_interface",
-            css=css,
-        )
-        self.description = description
-        if tab_names is None:
-            tab_names = [f"Tab {i}" for i in range(len(interface_list))]
-        with self:
-            if title:
-                gr.Markdown(
-                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
-                )
-            if description:
-                gr.Markdown(description)
-            with gr.Tabs():
-                for interface, tab_name in zip(interface_list, tab_names):
-                    with gr.Tab(label=tab_name):
-                        interface.render()
-
-
-
 def vllm_abort(self: Any):
     sh = self.llm_engine.scheduler
     for g in (sh.waiting + sh.running + sh.swapped):
@@ -1297,7 +1251,7 @@ def format_conversation(history):
 
 def maybe_upload_to_dataset():
     global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
-    if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH
+    if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH is not "":
         with open(LOG_PATH, 'r', encoding='utf-8') as f:
             convos = {}
             for l in f:
@@ -1396,7 +1350,6 @@ def maybe_delete_folder():
         except Exception as e:
             print('Failed to delete %s. Reason: %s' % (file_path, e))
 
-
 AGREE_POP_SCRIPTS = """
 async () => {
     alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
@@ -1413,7 +1366,6 @@ def debug_file_function(
     stop_strings: str = "[STOP],<s>,</s>",
     current_time: Optional[float] = None,
 ):
-    """This is only for debug purpose"""
     files = files if isinstance(files, list) else [files]
     print(files)
     filenames = [f.name for f in files]
@@ -1439,9 +1391,7 @@ def debug_file_function(
 
 
 def validate_file_item(filename, index, item: Dict[str, str]):
-    """
-    check safety for items in files
-    """
+    # BATCH_INFER_MAX_PROMPT_TOKENS
     message = item['prompt'].strip()
 
     if len(message) == 0:
@@ -1449,7 +1399,7 @@ def validate_file_item(filename, index, item: Dict[str, str]):
 
     message_safety = safety_check(message, history=None)
     if message_safety is not None:
-        raise gr.Error(f'Prompt {index}
+        raise gr.Error(f'Prompt {index} unsafe or supported: {message_safety}')
 
     tokenizer = llm.get_tokenizer() if llm is not None else None
     if tokenizer is None or len(tokenizer.encode(message, add_special_tokens=False)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
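The guard above rejects prompts that the tokenizer would expand past BATCH_INFER_MAX_PROMPT_TOKENS. A minimal, self-contained sketch of the same length check using a Hugging Face tokenizer; the model name and the limit below are placeholders, not values taken from app.py:

```python
# Sketch of the prompt-length guard, assuming a Hugging Face tokenizer.
from transformers import AutoTokenizer

BATCH_INFER_MAX_PROMPT_TOKENS = 512                 # placeholder limit
tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder tokenizer

def prompt_too_long(prompt: str) -> bool:
    # Count tokens without special tokens, mirroring the check in validate_file_item.
    n_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
    return n_tokens >= BATCH_INFER_MAX_PROMPT_TOKENS

print(prompt_too_long("Hello world"))  # False for a short prompt
```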
@@ -1473,33 +1423,25 @@ def read_validate_json_files(files: Union[str, List[str]]):
             validate_file_item(fname, i, x)
 
         all_items.extend(items)
-
     if len(all_items) > BATCH_INFER_MAX_ITEMS:
         raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
 
-    return all_items
+    return all_items
 
 
-def remove_gradio_cache(
-    """remove gradio cache to avoid flooding"""
+def remove_gradio_cache():
     import shutil
     for root, dirs, files in os.walk('/tmp/gradio/'):
         for f in files:
-
-
-
-            os.unlink(os.path.join(root, f))
-            # for d in dirs:
-            #     # if not any(d in ef for ef in except_files):
-            #     if exclude_names is None or not any(ef in d for ef in exclude_names):
-            #         print(f'Remove d: {d}')
-            #         shutil.rmtree(os.path.join(root, d))
+            os.unlink(os.path.join(root, f))
+        for d in dirs:
+            shutil.rmtree(os.path.join(root, d))
 
 
 def maybe_upload_batch_set(pred_json_path):
     global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
 
-    if SAVE_LOGS and DATA_SET_REPO_PATH
+    if SAVE_LOGS and DATA_SET_REPO_PATH is not "":
         try:
             from huggingface_hub import upload_file
             path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
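The upload path in maybe_upload_batch_set comes down to a single huggingface_hub.upload_file call. A minimal sketch under assumed values; the local file path, the dataset repo id, and the HF_TOKEN environment variable are placeholders, not values from app.py:

```python
# Sketch of the dataset-upload pattern used above, assuming a dataset repo exists
# and an access token is available in the environment.
import os
import time
from huggingface_hub import upload_file

pred_json_path = "/tmp/batch_predictions.json"      # hypothetical local file
DATA_SET_REPO_PATH = "your-org/your-logs-dataset"   # hypothetical dataset repo id

path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f".{time.time()}.json")
upload_file(
    path_or_fileobj=pred_json_path,
    path_in_repo=path_in_repo,
    repo_id=DATA_SET_REPO_PATH,
    repo_type="dataset",
    token=os.environ.get("HF_TOKEN"),
)
```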
@@ -1528,7 +1470,7 @@ def batch_inference(
     system_prompt: Optional[str] = SYSTEM_PROMPT_1
 ):
     """
-
+    Must handle
 
     """
     global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
@@ -1551,10 +1493,11 @@ def batch_inference(
     frequency_penalty = float(frequency_penalty)
     max_tokens = int(max_tokens)
 
-    all_items
+    all_items = read_validate_json_files(files)
 
     # remove all items in /tmp/gradio/
-    remove_gradio_cache(
+    remove_gradio_cache()
+
 
     if prompt_mode == 'chat':
         prompt_format_fn = llama_chat_multiturn_sys_input_seq_constructor
@@ -1594,6 +1537,7 @@ def batch_inference(
     for res, item in zip(responses, all_items):
         item['response'] = res
 
+    # save_path = "/mnt/workspace/workgroup/phi/test.json"
     save_path = BATCH_INFER_SAVE_TMP_FILE
     os.makedirs(os.path.dirname(save_path), exist_ok=True)
     with open(save_path, 'w', encoding='utf-8') as f:
@@ -1608,14 +1552,60 @@ def batch_inference(
 
 
 # BATCH_INFER_MAX_ITEMS
-
-
+FILE_UPLOAD_DESC = f"""File upload json format, with JSON object as list of dict with < {BATCH_INFER_MAX_ITEMS} items"""
+FILE_UPLOAD_DESCRIPTION = FILE_UPLOAD_DESC + """
 ```
-[ {
+[ {\"id\": 0, \"prompt\": \"Hello world\"} , {\"id\": 1, \"prompt\": \"Hi there?\"}]
 ```
 """
 
 
+# https://huggingface.co/spaces/yuntian-deng/ChatGPT4Turbo/blob/main/app.py
+@document()
+class CusTabbedInterface(gr.Blocks):
+    def __init__(
+        self,
+        interface_list: list[gr.Interface],
+        tab_names: Optional[list[str]] = None,
+        title: Optional[str] = None,
+        description: Optional[str] = None,
+        theme: Optional[gr.Theme] = None,
+        analytics_enabled: Optional[bool] = None,
+        css: Optional[str] = None,
+    ):
+        """
+        Parameters:
+            interface_list: a list of interfaces to be rendered in tabs.
+            tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
+            title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
+            analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
+            css: custom css or path to custom css file to apply to entire Blocks
+        Returns:
+            a Gradio Tabbed Interface for the given interfaces
+        """
+        super().__init__(
+            title=title or "Gradio",
+            theme=theme,
+            analytics_enabled=analytics_enabled,
+            mode="tabbed_interface",
+            css=css,
+        )
+        self.description = description
+        if tab_names is None:
+            tab_names = [f"Tab {i}" for i in range(len(interface_list))]
+        with self:
+            if title:
+                gr.Markdown(
+                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
+                )
+            if description:
+                gr.Markdown(description)
+            with gr.Tabs():
+                for interface, tab_name in zip(interface_list, tab_names):
+                    with gr.Tab(label=tab_name):
+                        interface.render()
+
+
 def launch():
     global demo, llm, DEBUG, LOG_FILE
     model_desc = MODEL_DESC
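The new FILE_UPLOAD_DESC documents the upload format for batch inference: a JSON list of objects with "id" and "prompt" keys. A minimal sketch that writes such a file; the filename and prompts below are illustrative only:

```python
# Write a batch-inference input file in the format described by FILE_UPLOAD_DESC.
import json

items = [
    {"id": 0, "prompt": "Hello world"},
    {"id": 1, "prompt": "Hi there?"},
]

with open("batch_input.json", "w", encoding="utf-8") as f:  # illustrative filename
    json.dump(items, f, ensure_ascii=False, indent=2)
```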
@@ -1713,33 +1703,33 @@ def launch():
 
     if ENABLE_BATCH_INFER:
 
-
+        demo_file = gr.Interface(
             batch_inference,
             inputs=[
                 gr.File(file_count='single', file_types=['json']),
                 gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
-                gr.Number(value=temperature, label='Temperature
-                gr.Number(value=max_tokens, label='Max tokens
-                gr.Number(value=frequence_penalty, label='Frequency penalty
-                gr.Number(value=presence_penalty, label='Presence penalty
-                gr.Textbox(value="[STOP],[END],<s>,</s>", label='
+                gr.Number(value=temperature, label='Temperature (higher -> more random)'),
+                gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
+                gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
+                gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
+                gr.Textbox(value="[STOP],[END],<s>,</s>", label='Comma-separated STOP string to stop generation only in few-shot mode', lines=1),
                 gr.Number(value=0, label='current_time', visible=False),
             ],
             outputs=[
                 # "file",
                 gr.File(label="Generated file"),
+                # gr.Textbox(),
                 # "json"
-                gr.JSON(label='Example outputs (
+                gr.JSON(label='Example outputs (max 2 samples)')
             ],
-
-
-
-
-
-            ],
-            # cache_examples=True,
+            # examples=[[[os.path.join(os.path.dirname(__file__),"files/titanic.csv"),
+            #             os.path.join(os.path.dirname(__file__),"files/titanic.csv"),
+            #             os.path.join(os.path.dirname(__file__),"files/titanic.csv")]]],
+            # cache_examples=True
+            description=FILE_UPLOAD_DESCRIPTION
         )
 
+
         demo_chat = gr.ChatInterface(
             response_fn,
             chatbot=ChatBot(
@@ -1767,8 +1757,8 @@ def launch():
             # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
         ],
     )
-    demo =
-        interface_list=[demo_chat,
+    demo = CusTabbedInterface(
+        interface_list=[demo_chat, demo_file],
         tab_names=["Chat Interface", "Batch Inference"],
         title=f"{model_title}",
         description=f"{model_desc}",
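CusTabbedInterface mirrors Gradio's stock gr.TabbedInterface while also rendering a description line under the title. A minimal sketch of the same tab wiring using the built-in class; the toy functions and tab names are illustrative and not part of app.py:

```python
# Equivalent tab wiring with Gradio's built-in TabbedInterface (no description line).
import gradio as gr

def echo(text: str) -> str:   # toy function standing in for the chat tab
    return text

def shout(text: str) -> str:  # toy function standing in for the batch tab
    return text.upper()

tab_a = gr.Interface(echo, inputs="text", outputs="text")
tab_b = gr.Interface(shout, inputs="text", outputs="text")

demo = gr.TabbedInterface(
    interface_list=[tab_a, tab_b],
    tab_names=["Echo", "Shout"],
    title="Tabbed demo",
)

if __name__ == "__main__":
    demo.launch()
```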
@@ -1834,4 +1824,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-