import base64
import os
from typing import Callable, Generator, Optional, Union

import gradio as gr
from openai import OpenAI

# Connection settings for the OpenAI-compatible endpoint, read from the environment.
END_POINT = os.environ.get("ENDPOINT")
SECRET_KEY = os.environ.get("SECRETKEY")
# USERS / PWD are not referenced in this part of the file; presumably used for
# login elsewhere.
# NOTE(review): POSIX shells export PWD as the current working directory, so if
# the intended secret is not set this silently holds a directory path. Confirm
# the variable name is intentional.
USERS = os.environ.get("USERS")
PWD = os.environ.get("PWD")

# File extensions that handle_user_msg will inline into a message.
_SUPPORTED_EXTENSIONS = frozenset({"png", "jpg", "jpeg", "gif", "pdf"})
# MIME types for extensions whose type is not simply "image/<ext>".
_MIME_TYPE_OVERRIDES = {"jpg": "image/jpeg", "pdf": "application/pdf"}


def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a streaming chat function bound to ``model_name``.

    Args:
        model_name: Model identifier sent with every completion request.
        **model_kwargs: Accepted for registry compatibility; currently unused.

    Returns:
        A generator function ``predict(messages, temperature, max_tokens, top_p)``
        that yields the cumulative (stripped) response text as it streams in.

    Raises:
        Exception: Whatever the ``OpenAI`` constructor raises (e.g. when no API
            key is available) is logged and re-raised unchanged.
    """
    # Instantiate an OpenAI client for the custom endpoint.
    try:
        client = OpenAI(base_url=END_POINT, api_key=SECRET_KEY)
    except Exception as e:
        print(f"The API or base URL were not defined: {str(e)}")
        raise

    def predict(
        messages: list, temperature: float, max_tokens: int, top_p: float
    ) -> Generator[str, None, None]:
        """Stream a completion for ``messages``, yielding the text so far.

        Errors are not raised to the caller; they are yielded as a final
        user-facing message instead.
        """
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True,
                response_format={"type": "text"},
            )
            response_text = ""
            for chunk in response:
                # Stream chunks may carry no text: e.g. a role-only first delta
                # or a final delta whose content is None, and some servers send
                # usage-only chunks with an empty `choices` list. Skip those
                # rather than letting len(None) / [0] abort the whole stream.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    response_text += content
                    yield response_text.strip()
            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            yield f"An error occurred: {str(e)}"

    return predict


def get_image_base64(url: str, ext: str) -> str:
    """Read a local file and return it as a base64 ``data:`` URI.

    Args:
        url: Path of the file on local disk (despite the name, not a URL).
        ext: File extension without the leading dot, e.g. ``"png"``.

    Returns:
        ``data:<mime>;base64,<payload>``. Extensions map to ``image/<ext>``
        except ``jpg`` (registered type ``image/jpeg``) and ``pdf``
        (``application/pdf``).
    """
    ext = ext.lower()
    mime_type = _MIME_TYPE_OVERRIDES.get(ext, f"image/{ext}")
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded_string}"


def handle_user_msg(message: Union[str, dict]) -> str:
    """Convert a chat message into the text content sent to the model.

    Args:
        message: Plain text, or a multimodal dict with optional ``"text"`` and
            ``"files"`` keys. Only the last entry of ``"files"`` is used.

    Returns:
        The message text. If a supported file is attached, it is appended on a
        new line as a markdown image whose target is a base64 data URI. A
        missing or ``None`` text becomes ``""``.

    Raises:
        NotImplementedError: For an unsupported file extension or message type.
    """
    if isinstance(message, str):
        return message
    if not isinstance(message, dict):
        raise NotImplementedError("Unsupported message type")

    text = message.get("text") or ""
    files = message.get("files")
    if not files:
        return text

    path = files[-1]
    ext = os.path.splitext(path)[1].strip(".").lower()
    if ext not in _SUPPORTED_EXTENSIONS:
        raise NotImplementedError(f"Unsupported file type: {ext}")
    encoded_str = get_image_base64(path, ext)
    return f"{text}\n![Image]({encoded_str})"


def _attach_files(message, files):
    """Return ``message`` with ``files`` attached unless it already has files.

    A bare tuple/list of paths (a file-only history entry) is first turned into
    a multimodal dict with empty text. Dicts are copied, never mutated.
    """
    if isinstance(message, (tuple, list)):
        message = {"text": "", "files": message}
    if files is None:
        return message
    if isinstance(message, str):
        return {"text": message, "files": files}
    if isinstance(message, dict) and not message.get("files"):
        return {**message, "files": files}
    return message


def get_interface_args(pipeline: str):
    """Return ``(inputs, outputs, preprocess, postprocess)`` for ``pipeline``.

    Only the ``"chat"`` pipeline is supported. Its ``inputs``/``outputs`` are
    ``None``.

    Raises:
        ValueError: If ``pipeline`` is not ``"chat"``.
    """
    if pipeline != "chat":
        raise ValueError(f"Unsupported pipeline type: {pipeline}")

    inputs = None
    outputs = None

    def preprocess(message, history):
        """Build ``{"messages": [...]}`` in OpenAI format from tuple history.

        History entries are ``(user_msg, assistant_msg)`` pairs. An entry with
        no assistant reply whose user part is a tuple/list is treated as a file
        upload. It is attached to the next user message (or to ``message`` if
        none follows) and then cleared, so an old upload is not re-sent with
        later turns. Other unanswered entries are skipped.
        """
        messages = []
        pending_files = None
        for user_msg, assistant_msg in history:
            if assistant_msg is None:
                if isinstance(user_msg, (tuple, list)):
                    pending_files = user_msg
                continue
            user_content = handle_user_msg(_attach_files(user_msg, pending_files))
            messages.append({"role": "user", "content": user_content})
            messages.append({"role": "assistant", "content": assistant_msg})
            pending_files = None

        current = _attach_files(message, pending_files)
        messages.append({"role": "user", "content": handle_user_msg(current)})
        return {"messages": messages}

    def postprocess(x):
        """Identity: chat output needs no additional postprocessing."""
        return x

    return inputs, outputs, preprocess, postprocess


def registry(name: Optional[str] = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio chat interface backed by model ``name``.

    Args:
        name: Model identifier forwarded to ``get_fn``.
        **kwargs: Forwarded to ``get_fn``.

    Returns:
        A ``gr.ChatInterface`` whose system prompt, temperature, max-token and
        top-p controls live in a collapsed "Parameters" accordion.
    """
    _, _, preprocess, postprocess = get_interface_args("chat")
    predict_fn = get_fn(model_name=name, **kwargs)

    def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
        """Stream a reply to ``message`` using ``history`` and the UI parameters."""
        messages = preprocess(message, history)["messages"]
        # Send the user-configured system prompt first when it is non-blank.
        if system_prompt and system_prompt.strip():
            messages.insert(0, {"role": "system", "content": system_prompt})
        # Gradio sliders may deliver floats; the API expects an integer count.
        response_generator = predict_fn(
            messages=messages,
            temperature=temperature,
            max_tokens=int(max_tokens),
            top_p=top_p,
        )
        # Each partial is the full text so far; Gradio renders it as a stream.
        for partial_response in response_generator:
            yield postprocess(partial_response)

    interface = gr.ChatInterface(
        fn=wrapper,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.", label="System prompt"
            ),
            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
        ],
    )
    return interface