"""Streamlit page for querying and filtering INC negotiation documents."""

import base64
import time
from pathlib import Path

import pandas as pd
import streamlit as st
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

from src.document_store.get_index import get_index
from src.rag.pipeline import RAGPipeline
from src.utils.data import load_css, load_json
from src.utils.writer import typewriter

DATA_BASE_PATH = Path(__file__).parent.parent.parent.parent / "data"


def get_base64_image(image_path):
    """Return the bytes of the file at *image_path* as a base64 string."""
    with open(image_path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode()


@st.cache_data
def load_css_style() -> None:
    """Load the application's ``style/style.css`` via ``load_css``."""
    load_css(Path(__file__).parent.parent.parent.parent / "style" / "style.css")


@st.cache_data
def load_template() -> str:
    """Return the text of the INC prompt template file."""
    path = (
        Path(__file__).parent.parent.parent
        / "rag"
        / "prompt_templates"
        / "inc_template.txt"
    )
    # Explicit encoding so the result does not depend on the platform default.
    # NOTE(review): assumes the template is stored as UTF-8 - confirm.
    return path.read_text(encoding="utf-8")


@st.cache_resource
def load_inc_pipeline(template: str) -> tuple[QdrantDocumentStore, RAGPipeline]:
    """Create the INC document index and the RAG pipeline built on top of it.

    Args:
        template: Prompt template text handed to ``RAGPipeline``.

    Returns:
        A ``(index, pipeline)`` tuple.
    """
    inc_index = get_index(index="inc_data_update_2")
    inc_rag = RAGPipeline(document_store=inc_index, top_k=7, template=template)
    return inc_index, inc_rag


@st.cache_data
def get_authors_taxonomy() -> list[str]:
    """Return the selectable authors from the authors taxonomy file.

    Only the "Countries" and "International and Regional State Associations"
    entries of the taxonomy's "Members" mapping are used, in file order.
    """
    taxonomy = load_json(DATA_BASE_PATH / "taxonomies" / "authors_taxonomy.json")
    wanted_groups = {"Countries", "International and Regional State Associations"}
    return [
        author
        for group, group_members in taxonomy["Members"].items()
        if group in wanted_groups
        for author in group_members
    ]


@st.cache_data
def get_draft_cat_taxonomy() -> list[str]:
    """Return every draft-category label, flattened across all sub-parts."""
    taxonomy = load_json(
        DATA_BASE_PATH / "taxonomies" / "draftcat_taxonomy_filter.json"
    )
    return [label for subpart in taxonomy.values() for label in subpart]


@st.cache_data
def get_negotiations_rounds() -> list[int]:
    """Return the negotiation round numbers offered in the session filter."""
    return [1, 2, 3, 4, 5]


@st.cache_data
def get_example_prompts() -> list[str]:
    """Return the "question" field of every entry in the example prompt file."""
    examples = load_json(
        DATA_BASE_PATH / "example_prompts" / "example_prompts_inc.json"
    )
    return [example["question"] for example in examples]


@st.cache_data
def set_trigger_state_values() -> tuple[bool, bool]:
    """Return the initial ``(filter, ask)`` trigger flags.

    NOTE(review): both flags are read from the same session key
    ("trigger_inc"); kept as in the original - confirm whether separate keys
    were intended.
    """
    trigger_filter_inc = st.session_state.setdefault("trigger_inc", False)
    trigger_ask_inc = st.session_state.setdefault("trigger_inc", False)
    return trigger_filter_inc, trigger_ask_inc


@st.cache_data
def load_app_init() -> None:
    """Render the collapsible "About" description at the top of the page."""
    description_inc_col_1, _ = st.columns([0.66, 1])
    with description_inc_col_1:
        with st.expander("About", icon=":material/info:"):
            st.markdown(
                """

The Interactive Treaty Assistant will support you on your research and analysis of documents submitted by INC members in the previous rounds to quickly pinpoint crucial information. Together with treaty-specific queries make use of the filters to get more precise responses. Along with the answer, the Chatbot also provides you with direct links to relevant documents enabling a deeper examination.
The tool excels at providing targeted information on countries and their positions in negotiations. Filter options by author and sections of the negotiation draft enhance accuracy, while direct links to filtered documents ensure quick access to detailed information. While the generated answers take into account up to eight documents at a time due to technical limitations, users can still access the full set of filtered documents via direct links for comprehensive exploration.

""",
                unsafe_allow_html=True,
            )
    st.write("\n")
    st.write("\n")


@st.cache_data
def about_inc() -> None:
    """Render the feedback section (form and survey links) at the page bottom."""
    st.markdown("""

Help us Improve!

""", unsafe_allow_html=True)
    st.markdown(
        """

We would appreciate your feedback and support to improve the app. You can fill out a quick feedback form (maximal 5 minutes) or use the in-depth survey (maximal 15 minutes).

""",
        unsafe_allow_html=True,
    )
    review, in_depth_review, _ = st.columns(spec=[0.7, 1.0, 4], gap="large")
    with review:
        st.link_button(
            label="Feedback",
            url="https://forms.gle/PPT5g558utGDUAGh6",
            icon=":material/reviews:",
        )
    with in_depth_review:
        st.link_button(
            label="Survey",
            url="https://docs.google.com/forms/d/1-WNS0ZdAuystajf2i6iSR5HpRfvV1LYq_TcQfaIMvkA",
            icon=":material/rate_review:",
        )
    logo = get_base64_image("static/images/logo.png")
    st.write("\n")
    st.write("\n")
    st.write("\n")
    # NOTE(review): `logo` is not interpolated below; the markup of this
    # f-string looks like it was lost - confirm against the intended template.
    st.markdown(
        f""" """,
        unsafe_allow_html=True,
    )


def _no_filters_selected(*selections) -> bool:
    """Return True when every given filter selection is empty."""
    return not any(selections)


def _unique_document_rows(documents) -> list[dict]:
    """Return one metadata row per distinct ``retriever_id``, first hit wins."""
    seen_ids = set()
    rows = []
    for doc in documents:
        retriever_id = doc.meta["retriever_id"]
        if retriever_id in seen_ids:
            continue
        seen_ids.add(retriever_id)
        rows.append(
            {
                "author": doc.meta["author"],
                "doc_type": doc.meta["doc_type"],
                "session": doc.meta["round"],
                "href": doc.meta["href"],
                "draft_labs": doc.meta["draft_labs"],
            }
        )
    return rows


def _build_references(documents) -> list[str]:
    """Return the reference lines for *documents*, de-duplicated in order.

    The leading "\\n" entry stays first and each document keeps its retrieval
    position; a plain ``set`` would scramble both.
    """
    lines = ["\n"]
    lines.extend(
        f"-[{doc.meta['retriever_id']}]: {doc.meta['href']} \n" for doc in documents
    )
    return list(dict.fromkeys(lines))


def _format_passages(documents) -> str:
    """Return a markdown list of text passages grouped by ``retriever_id``."""
    ordered = sorted(documents, key=lambda doc: doc.meta["retriever_id"])
    parts = []
    current_id = None
    for doc in ordered:
        if doc.meta["retriever_id"] != current_id:
            # A new document starts: emit its header before the passage.
            parts.append(f"- Document: {doc.meta['retriever_id']}\n")
            parts.append(" - Text passages\n")
        parts.append(f" - {doc.content}\n")
        current_id = doc.meta["retriever_id"]
    return "".join(parts)


def init_inc_page():
    """Render the INC page: filters, question tab, document tab and feedback."""
    load_css_style()
    load_app_init()

    # Load cached data and resources.
    authors = get_authors_taxonomy()
    draft_labs = get_draft_cat_taxonomy()
    negotiation_rounds = get_negotiations_rounds()
    example_prompts = get_example_prompts()
    template = load_template()
    trigger_filter_inc, trigger_ask_inc = set_trigger_state_values()
    inc_index, inc_rag = load_inc_pipeline(template=template)

    # Only bound to real values once the matching button was pressed.
    result = None
    result_df = None

    application_col_inc = st.columns(1)
    with application_col_inc[0]:
        st.markdown(
            """

1 Select Filters

""",
            unsafe_allow_html=True,
        )
        text_design_col_1, _ = st.columns([1, 1])
        with text_design_col_1:
            st.markdown(
                """

Selecting at least one filter is mandatory, because otherwise the model would have to analyze all available documents which results in inaccurate answers and long processing times. Please select at least one filter. We especially recommend to select countries you are interested in. """,
                unsafe_allow_html=True,
            )
            st.write("\n")

        col_1, col_2, col_3 = st.columns([1, 1, 1])
        with col_1:
            selected_authors_inc = st.multiselect(
                label="Countries or Associations",
                options=authors,
                label_visibility="visible",
                placeholder="Select",
                key="selected_authors_inc",
                help="Please select the countries of interest. Your selection will refine the database to include documents submitted by these countries or recognized groupings such as Small Developing States, the African States Group, etc.",
            )
        with col_2:
            # The original tooltip here was a copy of the countries tooltip.
            selected_rounds_inc = st.multiselect(
                label="Session",
                options=negotiation_rounds,
                label_visibility="visible",
                placeholder="Select",
                key="selected_rounds_inc",
                help="Please select the negotiation sessions of interest. Your selection will refine the database to include documents submitted in these sessions.",
            )
        with col_3:
            selected_draft_cats_inc = st.multiselect(
                label="Draft Categories",
                options=draft_labs,
                label_visibility="visible",
                placeholder="Select",
                key="selected_draft_cats",
                help=" Please select the parts of the negotiation draft of interest. The negotiation draft can be accessed (https://www.unep.org/inc-plastic-pollution/session-4/documents)",
            )
        st.write("\n")
        st.write("\n")
        st.markdown(
            """

2 Ask a question or show documents based on selected filters

""",
            unsafe_allow_html=True,
        )

        asking_inc, filtering_inc = st.tabs(["Ask a question", "Filter documents"])

        with asking_inc:
            application_col_ask_inc, output_col_ask_inc = st.columns([1, 1.5])
            with application_col_ask_inc:
                st.markdown(
                    """

Ask a question, noting that the database has been restricted by filters and that your question should pertain to the selected data. \n """,
                    unsafe_allow_html=True,
                )
                # Pre-fill the box only when an example prompt was picked.
                if "prompt" not in st.session_state:
                    prompt_inc = st.text_area("")
                elif st.session_state.prompt in example_prompts:
                    prompt_inc = st.text_area(
                        "Enter a question", value=st.session_state.prompt
                    )
                else:
                    del st.session_state["prompt"]
                    prompt_inc = st.text_area("Enter a question")

                trigger_ask_inc = st.session_state.setdefault("trigger_inc", False)

                if st.button("Ask", icon=":material/send:", type="primary"):
                    if prompt_inc == "":
                        st.error(
                            "Please enter a question. Reloading the app in few seconds"
                        )
                        time.sleep(3)
                        st.rerun()

                    with st.spinner("Filtering data..."):
                        if _no_filters_selected(
                            selected_authors_inc,
                            selected_draft_cats_inc,
                            selected_rounds_inc,
                        ):
                            st.error(
                                "Selecting a filter is mandatory. We especially recommend to select countries you are interested in. Selecting at least one filter is mandatory, because otherwise the model would have to analyze all available documents which results in inaccurate answers and long processing times. Please select at least one filter."
                            )
                            st.stop()

                    with st.spinner("Analyzing Filters"):
                        filter_selection = {
                            "author": selected_authors_inc,
                            "draft_labs": selected_draft_cats_inc,
                            "round": selected_rounds_inc,
                        }
                        filters = inc_rag.build_filter(
                            filter_selections=filter_selection
                        )
                        docs = inc_index.filter_documents(filters)
                        if not docs:
                            st.error(
                                "The combination of filters you've chosen does not match any documents. Please try another combination of filters. If a filter combination does not return any documents, it means that there are no documents that match the selected filters and therefore no answer can be given."
                            )
                            st.stop()
                        else:
                            st.success("Filtering completed.")

                    with st.spinner("Answering question..."):
                        result = inc_rag.run(
                            query=prompt_inc, filter_selections=filter_selection
                        )
                        trigger_ask_inc = True
                        st.success("Answering question completed.")

                st.markdown(
                    "***≡ Examples***",
                    help="These are example prompts that can be used to ask questions to the model. Click on a prompt to use it as a question. You can also type your own question in the text area above. In general we highly recommend to use the filter functions to narrow down the data.",
                )
                st.caption("Double click to select the prompt")
                for example_prompt in example_prompts:
                    if st.button(example_prompt):
                        # NOTE(review): the guard checks the literal key "key",
                        # kept as in the original - confirm the intended key.
                        if "key" not in st.session_state:
                            st.session_state["prompt"] = example_prompt

        with filtering_inc:
            application_col_filter, output_col_filter = st.columns([1, 1.5])
            with application_col_filter:
                st.markdown(
                    """

This filter function allows you to see all documents that match the selected filters. The documents can be accessed via a link. \n """,
                    unsafe_allow_html=True,
                )
                if st.button("Filter", icon=":material/filter_alt:", type="primary"):
                    if _no_filters_selected(
                        selected_authors_inc,
                        selected_draft_cats_inc,
                        selected_rounds_inc,
                    ):
                        st.info(
                            "No filters selected. All documents will be shown. Longer processing time expected."
                        )
                    with st.spinner("Filtering documents..."):
                        doc_filter = RAGPipeline.build_filter(
                            filter_selections={
                                "author": selected_authors_inc,
                                "draft_labs": selected_draft_cats_inc,
                                "round": selected_rounds_inc,
                            }
                        )
                        filtered_docs = inc_index.filter_documents(doc_filter)
                        result_df = pd.DataFrame(_unique_document_rows(filtered_docs))
                    if result_df.empty:
                        st.info(
                            "No documents found for the combination of filters you've chosen. All countries are represented at least once in the data. Remove the draft categories to see all documents for the countries selected or try other draft categories and/or sessions."
                        )
                        trigger_filter_inc = False
                    else:
                        trigger_filter_inc = True

        if trigger_filter_inc:
            with output_col_filter:
                st.markdown("### Overview of all filtered documents")
                st.dataframe(
                    result_df,
                    hide_index=True,
                    column_config={
                        "author": st.column_config.ListColumn("Authors"),
                        "href": st.column_config.LinkColumn("Link to Document"),
                        "draft_labs": st.column_config.ListColumn("Draft Categories"),
                        "session": st.column_config.NumberColumn("Session"),
                        "doc_type": st.column_config.TextColumn("Document Type"),
                    },
                )

        if trigger_ask_inc:
            with output_col_ask_inc:
                if result is None:
                    st.error(
                        "Open AI rate limit exceeded. Please try again in a few seconds."
                    )
                    st.stop()
                retrieved_docs = result["retriever"]["documents"]
                st.markdown(
                    """ Answer""",
                    unsafe_allow_html=True,
                )
                typewriter(
                    text=result["llm"]["replies"][0],
                    references=_build_references(retrieved_docs),
                    speed=100,
                )
                with st.expander("Show more information to the documents"):
                    st.write(_format_passages(retrieved_docs))

    st.markdown(
        """


""",
        unsafe_allow_html=True,
    )
    about_inc()