Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -973,6 +973,53 @@ gr.ChatInterface._setup_stop_events = _setup_stop_events
|
|
973 |
gr.ChatInterface._setup_events = _setup_events
|
974 |
|
975 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
976 |
def vllm_abort(self: Any):
|
977 |
sh = self.llm_engine.scheduler
|
978 |
for g in (sh.waiting + sh.running + sh.swapped):
|
@@ -1251,7 +1298,7 @@ def format_conversation(history):
|
|
1251 |
|
1252 |
def maybe_upload_to_dataset():
|
1253 |
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1254 |
-
if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH
|
1255 |
with open(LOG_PATH, 'r', encoding='utf-8') as f:
|
1256 |
convos = {}
|
1257 |
for l in f:
|
@@ -1350,6 +1397,7 @@ def maybe_delete_folder():
|
|
1350 |
except Exception as e:
|
1351 |
print('Failed to delete %s. Reason: %s' % (file_path, e))
|
1352 |
|
|
|
1353 |
AGREE_POP_SCRIPTS = """
|
1354 |
async () => {
|
1355 |
alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
|
@@ -1366,6 +1414,7 @@ def debug_file_function(
|
|
1366 |
stop_strings: str = "[STOP],<s>,</s>",
|
1367 |
current_time: Optional[float] = None,
|
1368 |
):
|
|
|
1369 |
files = files if isinstance(files, list) else [files]
|
1370 |
print(files)
|
1371 |
filenames = [f.name for f in files]
|
@@ -1391,7 +1440,9 @@ def debug_file_function(
|
|
1391 |
|
1392 |
|
1393 |
def validate_file_item(filename, index, item: Dict[str, str]):
|
1394 |
-
|
|
|
|
|
1395 |
message = item['prompt'].strip()
|
1396 |
|
1397 |
if len(message) == 0:
|
@@ -1399,7 +1450,7 @@ def validate_file_item(filename, index, item: Dict[str, str]):
|
|
1399 |
|
1400 |
message_safety = safety_check(message, history=None)
|
1401 |
if message_safety is not None:
|
1402 |
-
raise gr.Error(f'Prompt {index}
|
1403 |
|
1404 |
tokenizer = llm.get_tokenizer() if llm is not None else None
|
1405 |
if tokenizer is None or len(tokenizer.encode(message, add_special_tokens=False)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
|
@@ -1423,25 +1474,33 @@ def read_validate_json_files(files: Union[str, List[str]]):
|
|
1423 |
validate_file_item(fname, i, x)
|
1424 |
|
1425 |
all_items.extend(items)
|
|
|
1426 |
if len(all_items) > BATCH_INFER_MAX_ITEMS:
|
1427 |
raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
|
1428 |
|
1429 |
-
return all_items
|
1430 |
|
1431 |
|
1432 |
-
def remove_gradio_cache():
|
|
|
1433 |
import shutil
|
1434 |
for root, dirs, files in os.walk('/tmp/gradio/'):
|
1435 |
for f in files:
|
1436 |
-
|
1437 |
-
|
1438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1439 |
|
1440 |
|
1441 |
def maybe_upload_batch_set(pred_json_path):
|
1442 |
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1443 |
|
1444 |
-
if SAVE_LOGS and DATA_SET_REPO_PATH
|
1445 |
try:
|
1446 |
from huggingface_hub import upload_file
|
1447 |
path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
|
@@ -1470,7 +1529,7 @@ def batch_inference(
|
|
1470 |
system_prompt: Optional[str] = SYSTEM_PROMPT_1
|
1471 |
):
|
1472 |
"""
|
1473 |
-
|
1474 |
|
1475 |
"""
|
1476 |
global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
|
@@ -1493,11 +1552,10 @@ def batch_inference(
|
|
1493 |
frequency_penalty = float(frequency_penalty)
|
1494 |
max_tokens = int(max_tokens)
|
1495 |
|
1496 |
-
all_items = read_validate_json_files(files)
|
1497 |
|
1498 |
# remove all items in /tmp/gradio/
|
1499 |
-
remove_gradio_cache()
|
1500 |
-
|
1501 |
|
1502 |
if prompt_mode == 'chat':
|
1503 |
prompt_format_fn = llama_chat_multiturn_sys_input_seq_constructor
|
@@ -1552,60 +1610,14 @@ def batch_inference(
|
|
1552 |
|
1553 |
|
1554 |
# BATCH_INFER_MAX_ITEMS
|
1555 |
-
|
1556 |
-
|
1557 |
```
|
1558 |
-
[ {
|
1559 |
```
|
1560 |
"""
|
1561 |
|
1562 |
|
1563 |
-
# https://huggingface.co/spaces/yuntian-deng/ChatGPT4Turbo/blob/main/app.py
|
1564 |
-
@document()
|
1565 |
-
class CusTabbedInterface(gr.Blocks):
|
1566 |
-
def __init__(
|
1567 |
-
self,
|
1568 |
-
interface_list: list[gr.Interface],
|
1569 |
-
tab_names: Optional[list[str]] = None,
|
1570 |
-
title: Optional[str] = None,
|
1571 |
-
description: Optional[str] = None,
|
1572 |
-
theme: Optional[gr.Theme] = None,
|
1573 |
-
analytics_enabled: Optional[bool] = None,
|
1574 |
-
css: Optional[str] = None,
|
1575 |
-
):
|
1576 |
-
"""
|
1577 |
-
Parameters:
|
1578 |
-
interface_list: a list of interfaces to be rendered in tabs.
|
1579 |
-
tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
|
1580 |
-
title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
|
1581 |
-
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
1582 |
-
css: custom css or path to custom css file to apply to entire Blocks
|
1583 |
-
Returns:
|
1584 |
-
a Gradio Tabbed Interface for the given interfaces
|
1585 |
-
"""
|
1586 |
-
super().__init__(
|
1587 |
-
title=title or "Gradio",
|
1588 |
-
theme=theme,
|
1589 |
-
analytics_enabled=analytics_enabled,
|
1590 |
-
mode="tabbed_interface",
|
1591 |
-
css=css,
|
1592 |
-
)
|
1593 |
-
self.description = description
|
1594 |
-
if tab_names is None:
|
1595 |
-
tab_names = [f"Tab {i}" for i in range(len(interface_list))]
|
1596 |
-
with self:
|
1597 |
-
if title:
|
1598 |
-
gr.Markdown(
|
1599 |
-
f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
|
1600 |
-
)
|
1601 |
-
if description:
|
1602 |
-
gr.Markdown(description)
|
1603 |
-
with gr.Tabs():
|
1604 |
-
for interface, tab_name in zip(interface_list, tab_names):
|
1605 |
-
with gr.Tab(label=tab_name):
|
1606 |
-
interface.render()
|
1607 |
-
|
1608 |
-
|
1609 |
def launch():
|
1610 |
global demo, llm, DEBUG, LOG_FILE
|
1611 |
model_desc = MODEL_DESC
|
@@ -1703,33 +1715,33 @@ def launch():
|
|
1703 |
|
1704 |
if ENABLE_BATCH_INFER:
|
1705 |
|
1706 |
-
|
1707 |
batch_inference,
|
1708 |
inputs=[
|
1709 |
gr.File(file_count='single', file_types=['json']),
|
1710 |
gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
|
1711 |
-
gr.Number(value=temperature, label='Temperature
|
1712 |
-
gr.Number(value=max_tokens, label='Max
|
1713 |
-
gr.Number(value=frequence_penalty, label='Frequency penalty
|
1714 |
-
gr.Number(value=presence_penalty, label='Presence penalty
|
1715 |
-
gr.Textbox(value="[STOP],[END],<s>,</s>", label='Comma-separated
|
1716 |
gr.Number(value=0, label='current_time', visible=False),
|
1717 |
],
|
1718 |
outputs=[
|
1719 |
# "file",
|
1720 |
gr.File(label="Generated file"),
|
1721 |
-
# gr.Textbox(),
|
1722 |
# "json"
|
1723 |
-
gr.JSON(label='Example outputs (
|
|
|
|
|
|
|
|
|
|
|
|
|
1724 |
],
|
1725 |
-
#
|
1726 |
-
# os.path.join(os.path.dirname(__file__),"files/titanic.csv"),
|
1727 |
-
# os.path.join(os.path.dirname(__file__),"files/titanic.csv")]]],
|
1728 |
-
# cache_examples=True
|
1729 |
-
description=FILE_UPLOAD_DESCRIPTION
|
1730 |
)
|
1731 |
|
1732 |
-
|
1733 |
demo_chat = gr.ChatInterface(
|
1734 |
response_fn,
|
1735 |
chatbot=ChatBot(
|
@@ -1757,8 +1769,8 @@ def launch():
|
|
1757 |
# gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
|
1758 |
],
|
1759 |
)
|
1760 |
-
demo =
|
1761 |
-
interface_list=[demo_chat,
|
1762 |
tab_names=["Chat Interface", "Batch Inference"],
|
1763 |
title=f"{model_title}",
|
1764 |
description=f"{model_desc}",
|
|
|
973 |
gr.ChatInterface._setup_events = _setup_events
|
974 |
|
975 |
|
976 |
+
|
977 |
+
@document()
|
978 |
+
class CustomTabbedInterface(gr.Blocks):
|
979 |
+
def __init__(
|
980 |
+
self,
|
981 |
+
interface_list: list[gr.Interface],
|
982 |
+
tab_names: Optional[list[str]] = None,
|
983 |
+
title: Optional[str] = None,
|
984 |
+
description: Optional[str] = None,
|
985 |
+
theme: Optional[gr.Theme] = None,
|
986 |
+
analytics_enabled: Optional[bool] = None,
|
987 |
+
css: Optional[str] = None,
|
988 |
+
):
|
989 |
+
"""
|
990 |
+
Parameters:
|
991 |
+
interface_list: a list of interfaces to be rendered in tabs.
|
992 |
+
tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
|
993 |
+
title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
|
994 |
+
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
995 |
+
css: custom css or path to custom css file to apply to entire Blocks
|
996 |
+
Returns:
|
997 |
+
a Gradio Tabbed Interface for the given interfaces
|
998 |
+
"""
|
999 |
+
super().__init__(
|
1000 |
+
title=title or "Gradio",
|
1001 |
+
theme=theme,
|
1002 |
+
analytics_enabled=analytics_enabled,
|
1003 |
+
mode="tabbed_interface",
|
1004 |
+
css=css,
|
1005 |
+
)
|
1006 |
+
self.description = description
|
1007 |
+
if tab_names is None:
|
1008 |
+
tab_names = [f"Tab {i}" for i in range(len(interface_list))]
|
1009 |
+
with self:
|
1010 |
+
if title:
|
1011 |
+
gr.Markdown(
|
1012 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
|
1013 |
+
)
|
1014 |
+
if description:
|
1015 |
+
gr.Markdown(description)
|
1016 |
+
with gr.Tabs():
|
1017 |
+
for interface, tab_name in zip(interface_list, tab_names):
|
1018 |
+
with gr.Tab(label=tab_name):
|
1019 |
+
interface.render()
|
1020 |
+
|
1021 |
+
|
1022 |
+
|
1023 |
def vllm_abort(self: Any):
|
1024 |
sh = self.llm_engine.scheduler
|
1025 |
for g in (sh.waiting + sh.running + sh.swapped):
|
|
|
1298 |
|
1299 |
def maybe_upload_to_dataset():
|
1300 |
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1301 |
+
if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH != "":
|
1302 |
with open(LOG_PATH, 'r', encoding='utf-8') as f:
|
1303 |
convos = {}
|
1304 |
for l in f:
|
|
|
1397 |
except Exception as e:
|
1398 |
print('Failed to delete %s. Reason: %s' % (file_path, e))
|
1399 |
|
1400 |
+
|
1401 |
AGREE_POP_SCRIPTS = """
|
1402 |
async () => {
|
1403 |
alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
|
|
|
1414 |
stop_strings: str = "[STOP],<s>,</s>",
|
1415 |
current_time: Optional[float] = None,
|
1416 |
):
|
1417 |
+
"""This is only for debug purpose"""
|
1418 |
files = files if isinstance(files, list) else [files]
|
1419 |
print(files)
|
1420 |
filenames = [f.name for f in files]
|
|
|
1440 |
|
1441 |
|
1442 |
def validate_file_item(filename, index, item: Dict[str, str]):
|
1443 |
+
"""
|
1444 |
+
check safety for items in files
|
1445 |
+
"""
|
1446 |
message = item['prompt'].strip()
|
1447 |
|
1448 |
if len(message) == 0:
|
|
|
1450 |
|
1451 |
message_safety = safety_check(message, history=None)
|
1452 |
if message_safety is not None:
|
1453 |
+
raise gr.Error(f'Prompt {index} invalid: {message_safety}')
|
1454 |
|
1455 |
tokenizer = llm.get_tokenizer() if llm is not None else None
|
1456 |
if tokenizer is None or len(tokenizer.encode(message, add_special_tokens=False)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
|
|
|
1474 |
validate_file_item(fname, i, x)
|
1475 |
|
1476 |
all_items.extend(items)
|
1477 |
+
|
1478 |
if len(all_items) > BATCH_INFER_MAX_ITEMS:
|
1479 |
raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
|
1480 |
|
1481 |
+
return all_items, filenames
|
1482 |
|
1483 |
|
1484 |
+
def remove_gradio_cache(exclude_names=None):
|
1485 |
+
"""remove gradio cache to avoid flooding"""
|
1486 |
import shutil
|
1487 |
for root, dirs, files in os.walk('/tmp/gradio/'):
|
1488 |
for f in files:
|
1489 |
+
# if not any(f in ef for ef in except_files):
|
1490 |
+
if exclude_names is None or not any(ef in f for ef in exclude_names):
|
1491 |
+
print(f'Remove: {f}')
|
1492 |
+
os.unlink(os.path.join(root, f))
|
1493 |
+
# for d in dirs:
|
1494 |
+
# # if not any(d in ef for ef in except_files):
|
1495 |
+
# if exclude_names is None or not any(ef in d for ef in exclude_names):
|
1496 |
+
# print(f'Remove d: {d}')
|
1497 |
+
# shutil.rmtree(os.path.join(root, d))
|
1498 |
|
1499 |
|
1500 |
def maybe_upload_batch_set(pred_json_path):
|
1501 |
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1502 |
|
1503 |
+
if SAVE_LOGS and DATA_SET_REPO_PATH != "":
|
1504 |
try:
|
1505 |
from huggingface_hub import upload_file
|
1506 |
path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
|
|
|
1529 |
system_prompt: Optional[str] = SYSTEM_PROMPT_1
|
1530 |
):
|
1531 |
"""
|
1532 |
+
Handle file upload batch inference
|
1533 |
|
1534 |
"""
|
1535 |
global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
|
|
|
1552 |
frequency_penalty = float(frequency_penalty)
|
1553 |
max_tokens = int(max_tokens)
|
1554 |
|
1555 |
+
all_items, filenames = read_validate_json_files(files)
|
1556 |
|
1557 |
# remove all items in /tmp/gradio/
|
1558 |
+
remove_gradio_cache(exclude_names=['upload_chat.json', 'upload_few_shot.json'])
|
|
|
1559 |
|
1560 |
if prompt_mode == 'chat':
|
1561 |
prompt_format_fn = llama_chat_multiturn_sys_input_seq_constructor
|
|
|
1610 |
|
1611 |
|
1612 |
# BATCH_INFER_MAX_ITEMS
|
1613 |
+
FILE_UPLOAD_DESCRIPTION = f"""Upload JSON file as list of dict with < {BATCH_INFER_MAX_ITEMS} items, \
|
1614 |
+
each item has `prompt` key. We put guardrails to enhance safety, so do not input any harmful content or personal information! Re-upload the file after every submit. See the examples below.
|
1615 |
```
|
1616 |
+
[ {{"id": 0, "prompt": "Hello world"}} , {{"id": 1, "prompt": "Hi there?"}}]
|
1617 |
```
|
1618 |
"""
|
1619 |
|
1620 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1621 |
def launch():
|
1622 |
global demo, llm, DEBUG, LOG_FILE
|
1623 |
model_desc = MODEL_DESC
|
|
|
1715 |
|
1716 |
if ENABLE_BATCH_INFER:
|
1717 |
|
1718 |
+
demo_file_upload = gr.Interface(
|
1719 |
batch_inference,
|
1720 |
inputs=[
|
1721 |
gr.File(file_count='single', file_types=['json']),
|
1722 |
gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
|
1723 |
+
gr.Number(value=temperature, label='Temperature', info="Higher -> more random"),
|
1724 |
+
gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
|
1725 |
+
gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
|
1726 |
+
gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
|
1727 |
+
gr.Textbox(value="[STOP],[END],<s>,</s>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
|
1728 |
gr.Number(value=0, label='current_time', visible=False),
|
1729 |
],
|
1730 |
outputs=[
|
1731 |
# "file",
|
1732 |
gr.File(label="Generated file"),
|
|
|
1733 |
# "json"
|
1734 |
+
gr.JSON(label='Example outputs (display 2 samples)')
|
1735 |
+
],
|
1736 |
+
description=FILE_UPLOAD_DESCRIPTION,
|
1737 |
+
allow_flagging=False,
|
1738 |
+
examples=[
|
1739 |
+
["examples/upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "[STOP],[END],<s>,</s>"],
|
1740 |
+
["examples/upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "[STOP],[END],<s>,</s>,\\n"]
|
1741 |
],
|
1742 |
+
# cache_examples=True,
|
|
|
|
|
|
|
|
|
1743 |
)
|
1744 |
|
|
|
1745 |
demo_chat = gr.ChatInterface(
|
1746 |
response_fn,
|
1747 |
chatbot=ChatBot(
|
|
|
1769 |
# gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
|
1770 |
],
|
1771 |
)
|
1772 |
+
demo = CustomTabbedInterface(
|
1773 |
+
interface_list=[demo_chat, demo_file_upload],
|
1774 |
tab_names=["Chat Interface", "Batch Inference"],
|
1775 |
title=f"{model_title}",
|
1776 |
description=f"{model_desc}",
|