curry tang committed b6ec8b9
Parent: 99a9a6e

update
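Summary of the change: in app.py the hard-coded assistant-prompt branch in predict is replaced by a system_prompt_map keyed by assistant name ("前端开发助手" = frontend developer assistant, "后端开发助手" = backend developer assistant, "数据分析师" = data analyst), and three helpers are extracted: get_chat_or_default, convert_history_to_langchain_history, and append_system_prompt. Handlers that previously called get_default_chat() unconditionally (predict, explain_code, optimize_code, debug_code, function_gen, translate_doc) now keep a caller-supplied chat engine and fall back to the default only when it is None. The vision check in predict is unchanged and still raises a gr.Error reading "当前模型不支持图片,请更换模型。" ("the current model does not support images, please switch models") when files are attached to a non-vision model. In llm.py, OpenRouterLLM._support_models gains 'openai/gpt-4o-2024-08-06' and 'google/gemini-pro-1.5-exp'.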
app.py
CHANGED
@@ -2,7 +2,10 @@ import gradio as gr
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from llm import DeepSeekLLM, OpenRouterLLM, TongYiLLM
 from config import settings
-from prompts import
+from prompts import (
+    web_prompt, explain_code_template, optimize_code_template, debug_code_template,
+    function_gen_template, translate_doc_template, backend_developer_prompt, analyst_prompt
+)
 from langchain_core.prompts import PromptTemplate
 from log import logging
 from utils import convert_image_to_base64
@@ -21,6 +24,12 @@ provider_model_map = dict(
     Tongyi=tongyi_llm,
 )
 
+system_prompt_map = {
+    "前端开发助手": web_prompt,
+    "后端开发助手": backend_developer_prompt,
+    "数据分析师": analyst_prompt,
+}
+
 support_vision_models = [
     'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', 'google/gemini-pro-1.5-exp',
     'openai/gpt-4o', 'google/gemini-flash-1.5', 'liuhaotian/llava-yi-34b', 'anthropic/claude-3-haiku',
@@ -33,29 +42,39 @@ def get_default_chat():
     return _llm.get_chat_engine()
 
 
+def get_chat_or_default(chat):
+    if chat is None:
+        chat = get_default_chat()
+    return chat
+
+
+def convert_history_to_langchain_history(history, lc_history):
+    for his_msg in history:
+        if his_msg['role'] == 'user':
+            if not hasattr(his_msg['content'], 'file'):
+                lc_history.append(HumanMessage(content=his_msg['content']))
+        if his_msg['role'] == 'assistant':
+            lc_history.append(AIMessage(content=his_msg['content']))
+    return lc_history
+
+
+def append_system_prompt(key: str, lc_history):
+    prompt = system_prompt_map[key]
+    lc_history.append(SystemMessage(content=prompt))
+    return lc_history
+
+
 def predict(message, history, _chat, _current_assistant: str):
     logger.info(f"chat predict: {message}, {history}, {_chat}, {_current_assistant}")
     files_len = len(message.files)
-
-    _chat = get_default_chat()
+    _chat = get_chat_or_default(_chat)
     if files_len > 0:
         if _chat.model_name not in support_vision_models:
             raise gr.Error("当前模型不支持图片,请更换模型。")
 
     _lc_history = []
-
-
-    assistant_prompt = backend_developer_prompt
-    if _current_assistant == '数据分析师':
-        assistant_prompt = analyst_prompt
-    _lc_history.append(SystemMessage(content=assistant_prompt))
-
-    for his_msg in history:
-        if his_msg['role'] == 'user':
-            if not hasattr(his_msg['content'], 'file'):
-                _lc_history.append(HumanMessage(content=his_msg['content']))
-        if his_msg['role'] == 'assistant':
-            _lc_history.append(AIMessage(content=his_msg['content']))
+    _lc_history = append_system_prompt(_current_assistant, _lc_history)
+    _lc_history = convert_history_to_langchain_history(history, _lc_history)
 
     if files_len == 0:
         _lc_history.append(HumanMessage(content=message.text))
@@ -81,8 +100,7 @@ def update_chat(_provider: str, _model: str, _temperature: float, _max_tokens: i
 
 
 def explain_code(_code_type: str, _code: str, _chat):
-
-    _chat = get_default_chat()
+    _chat = get_chat_or_default(_chat)
     chat_messages = [
         SystemMessage(content=explain_code_template),
         HumanMessage(content=_code),
@@ -94,8 +112,7 @@ def explain_code(_code_type: str, _code: str, _chat):
 
 
 def optimize_code(_code_type: str, _code: str, _chat):
-
-    _chat = get_default_chat()
+    _chat = get_chat_or_default(_chat)
     prompt = PromptTemplate.from_template(optimize_code_template)
     prompt = prompt.format(code_type=_code_type)
     chat_messages = [
@@ -109,8 +126,7 @@ def optimize_code(_code_type: str, _code: str, _chat):
 
 
 def debug_code(_code_type: str, _code: str, _chat):
-
-    _chat = get_default_chat()
+    _chat = get_chat_or_default(_chat)
     prompt = PromptTemplate.from_template(debug_code_template)
     prompt = prompt.format(code_type=_code_type)
     chat_messages = [
@@ -124,8 +140,7 @@ def debug_code(_code_type: str, _code: str, _chat):
 
 
 def function_gen(_code_type: str, _code: str, _chat):
-
-    _chat = get_default_chat()
+    _chat = get_chat_or_default(_chat)
     prompt = PromptTemplate.from_template(function_gen_template)
     prompt = prompt.format(code_type=_code_type)
     chat_messages = [
@@ -139,8 +154,7 @@ def function_gen(_code_type: str, _code: str, _chat):
 
 
 def translate_doc(_language_input, _language_output, _doc, _chat):
-
-    _chat = get_default_chat()
+    _chat = get_chat_or_default(_chat)
     prompt = PromptTemplate.from_template(translate_doc_template)
     prompt = prompt.format(language_input=_language_input, language_output=_language_output)
     chat_messages = [
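One behavioral note on the prompt refactor: the old inline logic defaulted to backend_developer_prompt for any assistant other than '数据分析师', while the new append_system_prompt indexes system_prompt_map[key] directly, so an unrecognized assistant name now raises KeyError instead of silently using the backend prompt. If the old fallback is wanted, a drop-in variant (a sketch against app.py's module scope, not part of this commit) could use dict.get:

    def append_system_prompt(key: str, lc_history):
        # Sketch only: system_prompt_map, backend_developer_prompt and
        # SystemMessage are the names already in app.py's module scope.
        # .get() restores the pre-refactor default instead of raising
        # KeyError for assistant names missing from the map.
        prompt = system_prompt_map.get(key, backend_developer_prompt)
        lc_history.append(SystemMessage(content=prompt))
        return lc_history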
llm.py
CHANGED
@@ -60,8 +60,8 @@ class DeepSeekLLM(BaseLLM):
 
 class OpenRouterLLM(BaseLLM):
     _support_models = [
-        'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet',
-        'mistralai/mistral-large', 'meta-llama/llama-3.1-405b-instruct',
+        'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', 'openai/gpt-4o-2024-08-06',
+        'google/gemini-pro-1.5-exp', 'mistralai/mistral-large', 'meta-llama/llama-3.1-405b-instruct',
         'nvidia/nemotron-4-340b-instruct', 'deepseek/deepseek-coder', 'google/gemma-2-27b-it',
         'google/gemini-flash-1.5', 'deepseek/deepseek-chat', 'qwen/qwen-2-72b-instruct',
         'liuhaotian/llava-yi-34b', 'qwen/qwen-110b-chat',
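A small consistency note: support_vision_models in app.py already lists 'google/gemini-pro-1.5-exp', which only now appears in OpenRouterLLM._support_models. Since the two lists are maintained by hand, a startup check in app.py could flag drift; a minimal sketch, assuming the class-level list stays accessible and that the imported logging object behaves like the stdlib module:

    # Sketch only, not part of this commit: warn at startup when a model
    # that app.py treats as vision-capable is absent from OpenRouterLLM's
    # offering.
    missing = [m for m in support_vision_models
               if m not in OpenRouterLLM._support_models]
    if missing:
        logging.warning(f'vision models not offered by OpenRouterLLM: {missing}')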