import time
import urllib.error
import re
import json
import os
import string
import random
import hashlib
import codecs
import sys
import importlib.util

import openai

import folder_paths


def is_installed(package):
    """Return True if the given package can be imported in this environment."""
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False
    return spec is not None


def get_unique_hash(text):
    """Return the SHA-1 hex digest (40 characters) of the given text."""
    hash_object = hashlib.sha1(text.encode())
    return hash_object.hexdigest()
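
# Example: get_unique_hash("hello") -> "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d"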


def generate_random_string(length):
    """Return a random alphanumeric string of the given length."""
    letters = string.ascii_letters + string.digits
    return ''.join(random.choice(letters) for _ in range(length))
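
# Example (output is random on each call): generate_random_string(8) -> e.g. "aZ3kQ9xB"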


class AnyType(str):
    """A special type that is never unequal, so it matches any socket type. Credit to pythongosssss."""

    def __ne__(self, __value: object) -> bool:
        return False


any_type = AnyType("*")
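
# ComfyUI validates connections with `!=`, which AnyType always answers False to,
# so a socket declared as `any_type` accepts a link from any other node output.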


def is_azure_url(url):
    """Return True if the URL host ends with .azure.com (an Azure OpenAI endpoint)."""
    return re.match(r'.*\.azure\.com$', url) is not None
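
# Example (hypothetical endpoint): is_azure_url("https://myres.openai.azure.com") -> True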


def azure_client(key, url):
    """Create an Azure OpenAI client for the given API key and endpoint."""
    return openai.AzureOpenAI(
        api_key=key,
        api_version="2023-07-01-preview",
        azure_endpoint=url
    )


def openai_client(key, url):
    """Create an OpenAI-compatible client for the given API key and base URL."""
    return openai.OpenAI(
        api_key=key,
        base_url=url
    )


def ZhipuAI_client(key):
    """Create a ZhipuAI client, installing the `zhipuai` package on demand."""
    try:
        if not is_installed('zhipuai'):
            import subprocess
            print('#pip install zhipuai')
            result = subprocess.run(
                [sys.executable, '-s', '-m', 'pip', 'install', 'zhipuai'],
                capture_output=True, text=True)
            if result.returncode == 0:
                print("#install success")
            else:
                print("#install error")
        from zhipuai import ZhipuAI
    except Exception as e:
        print("#install zhipuai error", e)
        raise

    return ZhipuAI(api_key=key)


def phi_sort(lst):
    """Sort file names so that those containing 'phi' come first (case-insensitive)."""
    return sorted(lst, key=lambda x: x.lower().count('phi'), reverse=True)
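
# Example: phi_sort(["llama.gguf", "phi-2.gguf"]) -> ["phi-2.gguf", "llama.gguf"]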


def get_llama_path():
    """Return the folder that holds local llamafile/GGUF models."""
    try:
        return folder_paths.get_folder_paths('llamafile')[0]
    except Exception:
        return os.path.join(folder_paths.models_dir, "llamafile")


def get_llama_models():
    """List model files in the llamafile folder, 'phi' models first."""
    res = []
    model_path = get_llama_path()
    if os.path.exists(model_path):
        for file in os.listdir(model_path):
            if os.path.isfile(os.path.join(model_path, file)):
                res.append(file)
    return phi_sort(res)


llama_models_list = get_llama_models()


def get_llama_model_path(file_name):
    """Return the absolute path of a model file inside the llamafile folder."""
    return os.path.join(get_llama_path(), file_name)


def llama_cpp_client(file_name):
    """Load a local GGUF model via llama-cpp-python, installing the package on demand."""
    try:
        if not is_installed('llama_cpp'):
            import subprocess
            print('#pip install llama-cpp-python')
            result = subprocess.run([sys.executable, '-s', '-m', 'pip',
                                     'install',
                                     'llama-cpp-python',
                                     '--extra-index-url',
                                     'https://abetlen.github.io/llama-cpp-python/whl/cu121'
                                     ], capture_output=True, text=True)
            if result.returncode == 0:
                print("#install success")
                # Also install the optional server extras.
                subprocess.run([sys.executable, '-s', '-m', 'pip',
                                'install',
                                'llama-cpp-python[server]'
                                ], capture_output=True, text=True)
            else:
                print("#install error")
        from llama_cpp import Llama
    except Exception as e:
        print("#install llama-cpp-python error", e)
        raise

    if not file_name:
        raise ValueError("No llama model file was selected")

    mp = get_llama_model_path(file_name)
    # chatml prompt format, all layers offloaded to GPU, 512-token context window.
    llm = Llama(model_path=mp, chat_format="chatml", n_gpu_layers=-1, n_ctx=512)

    return llm
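
# A minimal usage sketch (the file name below is hypothetical; any GGUF file
# placed in the llamafile models folder works):
#   llm = llama_cpp_client("phi-2.Q4_K_M.gguf")
#   out = llm.create_chat_completion_openai_v1(
#       messages=[{"role": "user", "content": "Hi"}])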


def chat(client, model_name, messages):
    """Send a chat-completion request, retrying transient errors up to 3 times."""
    try_count = 0
    while True:
        try_count += 1
        try:
            if hasattr(client, "chat"):
                # OpenAI / Azure / ZhipuAI style client.
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages
                )
            else:
                # llama-cpp-python Llama instance (OpenAI-compatible response).
                response = client.create_chat_completion_openai_v1(
                    messages=messages
                )
            break
        except openai.AuthenticationError as ex:
            # Bad credentials will not recover on retry.
            raise ex
        except (urllib.error.HTTPError, openai.OpenAIError) as ex:
            if try_count >= 3:
                raise ex
            time.sleep(3)
            continue

    finish_reason = response.choices[0].finish_reason
    if finish_reason != "stop":
        raise RuntimeError("API finished with unexpected reason: " + finish_reason)

    try:
        content = response.choices[0].message.content
    except AttributeError:
        # Streaming-style chunks expose the text on `delta` instead of `message`.
        content = response.choices[0].delta['content']

    return content
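
# A minimal usage sketch (the key and URL are placeholders):
#   client = openai_client("sk-...", "https://api.openai.com/v1")
#   reply = chat(client, "gpt-3.5-turbo", [{"role": "user", "content": "Hello"}])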


class ChatGPTNode:
    def __init__(self):
        # Rolling conversation history, kept across executions of this node.
        self.session_history = []
        self.system_content = "You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible."

    @classmethod
    def INPUT_TYPES(cls):
        model_list = llama_models_list + [
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
            "gpt-4o",
            "gpt-4o-2024-05-13",
            "gpt-4",
            "gpt-4-0314",
            "gpt-4-0613",
            "gpt-3.5-turbo-0301",
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k-0613",
            "qwen-turbo",
            "qwen-plus",
            "qwen-long",
            "qwen-max",
            "qwen-max-longcontext",
            "glm-4",
            "glm-3-turbo",
            "moonshot-v1-8k",
            "moonshot-v1-32k",
            "moonshot-v1-128k",
            "deepseek-chat"
        ]
        return {
            "required": {
                "api_key": ("KEY", {"default": "", "multiline": True, "dynamicPrompts": False}),
                "api_url": ("URL", {"default": "", "multiline": True, "dynamicPrompts": False}),
                "prompt": ("STRING", {"multiline": True, "dynamicPrompts": False}),
                "system_content": ("STRING",
                                   {
                                       "default": "You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible.",
                                       "multiline": True, "dynamicPrompts": False
                                   }),
                "model": (model_list,
                          {"default": model_list[0]}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
                "context_size": ("INT", {"default": 1, "min": 0, "max": 30, "step": 1}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
                "extra_pnginfo": "EXTRA_PNGINFO",
            },
        }

    RETURN_TYPES = ("STRING", "STRING", "STRING",)
    RETURN_NAMES = ("text", "messages", "session_history",)
    FUNCTION = "generate_contextual_text"
    CATEGORY = "♾️Mixlab/GPT"
    INPUT_IS_LIST = False
    OUTPUT_IS_LIST = (False, False, False,)

    def generate_contextual_text(self,
                                 api_key,
                                 api_url,
                                 prompt,
                                 system_content,
                                 model,
                                 seed, context_size, unique_id=None, extra_pnginfo=None):
        # `seed` only exists so ComfyUI re-runs the node when it changes;
        # it is not forwarded to the API.

        if system_content:
            self.system_content = system_content

        # Pick the client that matches the endpoint / model.
        if is_azure_url(api_url):
            client = azure_client(api_key, api_url)
        elif model == "glm-4":
            client = ZhipuAI_client(api_key)
            print('using ZhipuAI interface')
        elif model in llama_models_list:
            client = llama_cpp_client(model)
        else:
            client = openai_client(api_key, api_url)
            print('using ChatGPT interface')

        def crop_list_tail(lst, size):
            """Return at most the last `size` items of `lst`."""
            if size >= len(lst):
                return lst
            elif size == 0:
                return []
            return lst[-size:]
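
        # Example: crop_list_tail([1, 2, 3, 4], 2) -> [3, 4]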

        session_history = crop_list_tail(self.session_history, context_size)

        messages = ([{"role": "system", "content": self.system_content}]
                    + session_history
                    + [{"role": "user", "content": prompt}])

        response_content = chat(client, model, messages)

        # Append the new exchange to the rolling history.
        self.session_history = self.session_history + [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response_content}
        ]

        return (response_content,
                json.dumps(messages, indent=4),
                json.dumps(self.session_history, indent=4),)


class ShowTextForGPT:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "text": ("STRING", {"forceInput": True, "dynamicPrompts": False}),
            },
            "optional": {
                "output_dir": ("STRING", {"forceInput": True, "default": "", "multiline": True, "dynamicPrompts": False}),
            }
        }

    INPUT_IS_LIST = True
    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    OUTPUT_NODE = True
    OUTPUT_IS_LIST = (True,)

    CATEGORY = "♾️Mixlab/Text"

    def run(self, text, output_dir=[""]):
        # Inputs arrive as lists because INPUT_IS_LIST is True; coerce items to str.
        text = [t if isinstance(t, str) else str(t) for t in text]

        if len(output_dir) == 1 and (output_dir[0] == '' or os.path.dirname(output_dir[0]) == ''):
            # No directory given: write a single file into the temp directory.
            t = '\n'.join(text)
            output_dir = [
                os.path.join(folder_paths.get_temp_directory(),
                             get_unique_hash(t) + '.txt')
            ]
        elif len(output_dir) == 1:
            base = os.path.basename(output_dir[0])
            t = '\n'.join(text)
            if base == '' or os.path.splitext(base)[1] == '':
                # A directory was given: derive a file name from the content hash.
                base = get_unique_hash(t) + '.txt'
                output_dir = [os.path.join(output_dir[0], base)]

        if len(output_dir) == 1 and len(text) > 1:
            # Reuse the same target path for every text item.
            output_dir = [output_dir[0] for _ in range(len(text))]

        for i in range(len(text)):
            o_fp = output_dir[i]
            dirp = os.path.dirname(o_fp)
            if dirp == '':
                dirp = folder_paths.get_temp_directory()
                o_fp = os.path.join(dirp, o_fp)

            if not os.path.exists(dirp):
                os.makedirs(dirp, exist_ok=True)

            if os.path.splitext(o_fp)[1].lower() != '.txt':
                o_fp = o_fp + '.txt'

            with open(o_fp, 'w', encoding='utf-8') as file:
                file.write(text[i])

        return {"ui": {"text": text}, "result": (text,)}


class CharacterInText:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "text": ("STRING", {"multiline": True, "dynamicPrompts": False}),
                "character": ("STRING", {"multiline": True, "dynamicPrompts": False}),
                "start_index": ("INT", {
                    "default": 1,
                    "min": 0,
                    "max": 1024,
                    "step": 1,
                    "display": "number"
                }),
            }
        }

    INPUT_IS_LIST = False
    RETURN_TYPES = ("INT",)
    FUNCTION = "run"

    OUTPUT_IS_LIST = (False,)

    CATEGORY = "♾️Mixlab/Text"

    def run(self, text, character, start_index):
        # 1 if `character` occurs in `text` (case-insensitive), else 0, offset by start_index.
        b = 1 if character.lower() in text.lower() else 0
        return (b + start_index,)


class TextSplitByDelimiter:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "text": ("STRING", {"multiline": True, "dynamicPrompts": False}),
                "delimiter": ("STRING", {"multiline": False, "default": ",", "dynamicPrompts": False}),
                "start_index": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
                "skip_every": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 10,
                    "step": 1,
                    "display": "number"
                }),
                "max_count": ("INT", {
                    "default": 10,
                    "min": 1,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
            }
        }

    INPUT_IS_LIST = False
    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"

    OUTPUT_IS_LIST = (True,)

    CATEGORY = "♾️Mixlab/Text"

    def run(self, text, delimiter, start_index, skip_every, max_count):
        if delimiter == "":
            arr = [text.strip()]
        else:
            # Decode escape sequences so "\n" typed in the UI splits on newlines.
            delimiter = codecs.decode(delimiter, 'unicode_escape')
            arr = [item for item in text.split(delimiter) if item.strip()]

        # Keep every (skip_every + 1)-th item starting at start_index,
        # at most max_count items.
        arr = arr[start_index:start_index + max_count * (skip_every + 1):skip_every + 1]
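
        # Example: "a,b,c,d,e" with start_index=1, skip_every=1, max_count=2 -> ["b", "d"]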

        return (arr,)