import os from gradio.themes import ThemeClass as Theme import numpy as np import argparse import gradio as gr from typing import Any, Iterator from typing import Iterator, List, Optional, Tuple import filelock import glob import json import time from gradio.routes import Request from gradio.utils import SyncToAsyncIterator, async_iteration from gradio.helpers import special_args import anyio from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator from gradio_client.documentation import document, set_documentation_group from gradio.components import Button, Component from gradio.events import Dependency, EventListenerMethod from typing import List, Optional, Union, Dict, Tuple from tqdm.auto import tqdm from huggingface_hub import snapshot_download from gradio.themes import ThemeClass as Theme from .base_demo import register_demo, get_demo_class, BaseDemo import inspect from typing import AsyncGenerator, Callable, Literal, Union, cast import anyio from gradio_client import utils as client_utils from gradio_client.documentation import document from gradio.blocks import Blocks from gradio.components import ( Button, Chatbot, Component, Markdown, State, Textbox, get_component_instance, ) from gradio.events import Dependency, on from gradio.helpers import create_examples as Examples # noqa: N812 from gradio.helpers import special_args from gradio.layouts import Accordion, Group, Row from gradio.routes import Request from gradio.themes import ThemeClass as Theme from gradio.utils import SyncToAsyncIterator, async_iteration from ..globals import MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, load_embeddings, get_rag_embeddings from .chat_interface import ( SYSTEM_PROMPT, MODEL_NAME, MAX_TOKENS, TEMPERATURE, CHAT_EXAMPLES, gradio_history_to_openai_conversations, gradio_history_to_conversation_prompt, DATETIME_FORMAT, get_datetime_string, format_conversation, chat_response_stream_multiturn_engine, ChatInterfaceDemo, CustomizedChatInterface, ) from ..configs import ( CHUNK_SIZE, CHUNK_OVERLAP, RAG_EMBED_MODEL_NAME, ) RAG_CURRENT_VECTORSTORE = None def load_document_split_vectorstore(file_path): global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings from langchain_community.vectorstores import Chroma, FAISS from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) if file_path.endswith('.pdf'): loader = PyPDFLoader(file_path) elif file_path.endswith('.docx'): loader = Docx2txtLoader(file_path) elif file_path.endswith('.txt'): loader = TextLoader(file_path) splits = loader.load_and_split(splitter) RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings()) return RAG_CURRENT_VECTORSTORE def docs_to_context_content(docs: List[Any]): content = "\n".join([d.page_content for d in docs]) return content DOC_TEMPLATE = """### {content} ### """ DOC_INSTRUCTION = """Answer the following query exclusively based on the information provided in the document above. \ If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query! """ def docs_to_rag_context(docs: List[Any], doc_instruction=None): doc_instruction = doc_instruction or DOC_INSTRUCTION content = docs_to_context_content(docs) context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=content) return context def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3): doc_context = None if file_input is not None: if file_input == RAG_CURRENT_FILE: # reuse vectorstore = RAG_CURRENT_VECTORSTORE print(f'Reuse vectorstore: {file_input}') else: vectorstore = load_document_split_vectorstore(file_input) print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}') RAG_CURRENT_FILE = file_input docs = vectorstore.similarity_search(message, k=rag_num_docs) doc_context = docs_to_rag_context(docs) return doc_context def chat_response_stream_multiturn_doc_engine( message: str, history: List[Tuple[str, str]], file_input: Optional[str] = None, temperature: float = 0.7, max_tokens: int = 1024, system_prompt: Optional[str] = SYSTEM_PROMPT, rag_num_docs: Optional[int] = 3, doc_instruction: Optional[str] = DOC_INSTRUCTION, # profile: Optional[gr.OAuthProfile] = None, ): global MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE if len(message) == 0: raise gr.Error("The message cannot be empty!") rag_num_docs = int(rag_num_docs) doc_instruction = doc_instruction or DOC_INSTRUCTION doc_context = None if file_input is not None: if file_input == RAG_CURRENT_FILE: # reuse vectorstore = RAG_CURRENT_VECTORSTORE print(f'Reuse vectorstore: {file_input}') else: vectorstore = load_document_split_vectorstore(file_input) print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}') RAG_CURRENT_FILE = file_input docs = vectorstore.similarity_search(message, k=rag_num_docs) # doc_context = docs_to_rag_context(docs) rag_content = docs_to_context_content(docs) doc_context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=rag_content) if doc_context is not None: message = f"{doc_context}\n\n{message}" for response, num_tokens in chat_response_stream_multiturn_engine( message, history, temperature, max_tokens, system_prompt ): # ! yield another content which is doc_context yield response, num_tokens, doc_context class RagChatInterface(CustomizedChatInterface): def __init__( self, fn: Callable[..., Any], *, chatbot: gr.Chatbot | None = None, textbox: gr.Textbox | None = None, additional_inputs: str | Component | list[str | Component] | None = None, additional_inputs_accordion_name: str | None = None, additional_inputs_accordion: str | gr.Accordion | None = None, render_additional_inputs_fn: Callable | None = None, examples: list[str] | None = None, cache_examples: bool | None = None, title: str | None = None, description: str | None = None, theme: Theme | str | None = None, css: str | None = None, js: str | None = None, head: str | None = None, analytics_enabled: bool | None = None, submit_btn: str | Button | None = "Submit", stop_btn: str | Button | None = "Stop", retry_btn: str | Button | None = "đ Retry", undo_btn: str | Button | None = "âŠī¸ Undo", clear_btn: str | Button | None = "đī¸ Clear", autofocus: bool = True, concurrency_limit: int | Literal['default'] | None = "default", fill_height: bool = True ): try: super(gr.ChatInterface, self).__init__( analytics_enabled=analytics_enabled, mode="chat_interface", css=css, title=title or "Gradio", theme=theme, js=js, head=head, fill_height=fill_height, ) except Exception as e: # Handling some old gradio version with out fill_height super(gr.ChatInterface, self).__init__( analytics_enabled=analytics_enabled, mode="chat_interface", css=css, title=title or "Gradio", theme=theme, js=js, head=head, # fill_height=fill_height, ) self.concurrency_limit = concurrency_limit self.fn = fn self.render_additional_inputs_fn = render_additional_inputs_fn self.is_async = inspect.iscoroutinefunction( self.fn ) or inspect.isasyncgenfunction(self.fn) self.is_generator = inspect.isgeneratorfunction( self.fn ) or inspect.isasyncgenfunction(self.fn) self.examples = examples if self.space_id and cache_examples is None: self.cache_examples = True else: self.cache_examples = cache_examples or False self.buttons: list[Button | None] = [] if additional_inputs: if not isinstance(additional_inputs, list): additional_inputs = [additional_inputs] self.additional_inputs = [ get_component_instance(i) for i in additional_inputs # type: ignore ] else: self.additional_inputs = [] if additional_inputs_accordion_name is not None: print( "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead." ) self.additional_inputs_accordion_params = { "label": additional_inputs_accordion_name } if additional_inputs_accordion is None: self.additional_inputs_accordion_params = { "label": "Additional Inputs", "open": False, } elif isinstance(additional_inputs_accordion, str): self.additional_inputs_accordion_params = { "label": additional_inputs_accordion } elif isinstance(additional_inputs_accordion, Accordion): self.additional_inputs_accordion_params = ( additional_inputs_accordion.recover_kwargs( additional_inputs_accordion.get_config() ) ) else: raise ValueError( f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}" ) with self: if title: Markdown( f"