import json
import os
import re
import requests
from typing import List, Dict, Tuple

from toolbox import get_conf, encode_image, get_pictures_list, to_markdown_tabs

proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
|
|
|
""" |
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- |
|
第五部分 一些文件处理方法 |
|
files_filter_handler 根据type过滤文件 |
|
input_encode_handler 提取input中的文件,并解析 |
|
file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本 |
|
link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件 |
|
html_view_blank 超链接 |
|
html_local_file 本地文件取相对路径 |
|
to_markdown_tabs 文件list 转换为 md tab |
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- |
|
""" |
|
|
|
|
|
def files_filter_handler(file_list):
    """Keep only the entries of `file_list` that exist on disk and carry an image extension."""
    new_list = []
    filter_ = [
        "png",
        "jpg",
        "jpeg",
        "bmp",
        "svg",
        "webp",
        "ico",
        "tif",
        "tiff",
        "raw",
        "eps",
    ]
    for file in file_list:
        file = str(file).replace("file=", "")
        if os.path.exists(file):
            # compare extensions case-insensitively so "IMG.PNG" is also kept
            if str(os.path.basename(file)).split(".")[-1].lower() in filter_:
                new_list.append(file)
    return new_list
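# Hypothetical example: given ["file=/tmp/a.png", "/tmp/notes.txt"], and assuming
# /tmp/a.png exists, files_filter_handler returns ["/tmp/a.png"].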
|
|
|
|
|
def input_encode_handler(inputs, llm_kwargs):
    """Collect the most recently uploaded pictures and base64-encode them for the API."""
    md_encode = []  # initialized up front so the return works even when nothing was uploaded
    if llm_kwargs["most_recent_uploaded"].get("path"):
        image_paths = get_pictures_list(llm_kwargs["most_recent_uploaded"]["path"])
        for md_path in image_paths:
            type_ = os.path.splitext(md_path)[1].replace(".", "")
            type_ = "jpeg" if type_ == "jpg" else type_  # the mime subtype is "jpeg", not "jpg"
            md_encode.append({"data": encode_image(md_path), "type": type_})
    return inputs, md_encode
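# Hypothetical example: after an image upload this returns the untouched inputs plus
# a list such as [{"data": "<base64>", "type": "jpeg"}]; with no upload, (inputs, []).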
|
|
|
|
|
def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
    """Render each file as an inline image (if its type matches), an md link, or raw text."""
    new_list = []
    if not filter_:
        filter_ = [
            "png",
            "jpg",
            "jpeg",
            "bmp",
            "svg",
            "webp",
            "ico",
            "tif",
            "tiff",
            "raw",
            "eps",
        ]
    for file in file_list:
        if str(os.path.basename(file)).split(".")[-1].lower() in filter_:
            new_list.append(html_local_img(file, md=md_type))
        elif os.path.exists(file):
            new_list.append(link_mtime_to_md(file))
        else:
            new_list.append(file)
    return new_list
|
|
|
|
|
def link_mtime_to_md(file):
    """Build an md link carrying the file's mtime as a query string, defeating stale caches."""
    link_local = html_local_file(file)
    link_name = os.path.basename(file)
    a = f"[{link_name}]({link_local}?{os.path.getmtime(file)})"
    return a
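# Hypothetical example: link_mtime_to_md("./docs/a.pdf") yields something like
# "[a.pdf](file=./docs/a.pdf?1700000000.0)"; the query string changes whenever the
# file does, so a previously cached copy is never served.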
|
|
|
|
|
def html_local_file(file):
    """Rewrite an absolute path below this module's directory as a relative "file=" link."""
    base_path = os.path.dirname(__file__)
    if os.path.exists(str(file)):
        file = f'file={file.replace(base_path, ".")}'
    return file
|
|
|
|
|
def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
    """Render a local image as a styled html <div><img> or, when md=True, as an md image."""
    style = ""
    if max_width is not None:
        style += f"max-width: {max_width};"
    if max_height is not None:
        style += f"max-height: {max_height};"
    __file = html_local_file(__file)
    a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
    if md:
        # the markdown form ignores the layout and style options
        a = f"![{__file}]({__file})"
    return a
|
|
|
|
|
|
|
class GoogleChatInit:
    def __init__(self):
        # Keep the endpoint as a template; %m (model) and %k (api key) are filled in
        # per request, so a reused instance can switch models cleanly.
        self.url_template = "https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k"
        self.url_gemini = self.url_template

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        """POST a streaming generateContent request and return the response line iterator."""
        headers, payload = self.generate_message_payload(
            inputs, llm_kwargs, history, system_prompt
        )
        response = requests.post(
            url=self.url_gemini,
            headers=headers,
            data=json.dumps(payload),
            stream=True,
            proxies=proxies,
            timeout=TIMEOUT_SECONDS,
        )
        return response.iter_lines()
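    # Note: streamGenerateContent streams its reply as chunked JSON, so each item
    # yielded by iter_lines() is a fragment of the overall JSON document rather than
    # a self-contained object; callers accumulate the lines before parsing.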
|
|
|
    def __conversation_user(self, user_input, llm_kwargs):
        """Build the "user" turn; vision models additionally receive uploaded images inline."""
        what_i_have_asked = {"role": "user", "parts": []}
        if "vision" not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
        what_i_have_asked["parts"].append({"text": input_})
        for data in encode_img:
            what_i_have_asked["parts"].append(
                {
                    "inline_data": {
                        "mime_type": f"image/{data['type']}",
                        "data": data["data"],
                    }
                }
            )
        return what_i_have_asked
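    # A turn produced above has the Gemini "parts" shape (sketch):
    # {"role": "user", "parts": [{"text": "..."},
    #                            {"inline_data": {"mime_type": "image/jpeg", "data": "<base64>"}}]}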
|
|
|
    def __conversation_history(self, history, llm_kwargs):
        """Convert the flat [user, model, user, model, ...] history into Gemini turns."""
        messages = []
        conversation_cnt = len(history) // 2
        if conversation_cnt:
            for index in range(0, 2 * conversation_cnt, 2):
                what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
                what_gpt_answer = {
                    "role": "model",
                    "parts": [{"text": history[index + 1]}],
                }
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
        return messages
|
|
|
    def generate_message_payload(
        self, inputs, llm_kwargs, history, system_prompt
    ) -> Tuple[Dict, Dict]:
        """Assemble the request header and JSON payload for the Gemini API.

        Note: system_prompt is accepted for interface compatibility but is not yet
        forwarded to the API.
        """
        messages = []
        self.url_gemini = self.url_template.replace(
            "%m", llm_kwargs["llm_model"]
        ).replace("%k", get_conf("GEMINI_API_KEY"))
        header = {"Content-Type": "application/json"}
        if "vision" not in self.url_gemini:  # the vision model does not accept history
            messages.extend(self.__conversation_history(history, llm_kwargs))
        messages.append(self.__conversation_user(inputs, llm_kwargs))
        payload = {
            "contents": messages,
            "generationConfig": {
                # drop empty strings so an unset "stop" does not send [""]
                "stopSequences": [s for s in str(llm_kwargs.get("stop", "")).split(" ") if s],
                "temperature": llm_kwargs.get("temperature", 1),
                "topP": llm_kwargs.get("top_p", 0.8),
                "topK": 10,
            },
        }
        return header, payload
|
|
|
|
|
if __name__ == "__main__":
    google = GoogleChatInit()
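    # Minimal smoke test, left as a commented sketch: it assumes GEMINI_API_KEY is
    # configured and that the proxy settings above allow reaching the API.
    # llm_kwargs = {"llm_model": "gemini-pro", "most_recent_uploaded": {}}
    # for line in google.generate_chat("Hello", llm_kwargs, history=[], system_prompt=""):
    #     print(line)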