# Spaces: Running -- page-header text captured along with this file; it is not part of the program.
import os
import zipfile
from typing import Optional, List

import gdown
import streamlit as st
from dotenv import load_dotenv

from agent_workflow import SearchAgent
from tantivy_search_agent import TantivySearchAgent
# Load variables from a .env file into the process environment. This must run
# before SearchAgentUI is instantiated, because its constructor reads
# GDRIVE_INDEX_ID through os.getenv().
load_dotenv()
class SearchAgentUI:
    """Streamlit front end for the Tantivy-backed search agent.

    The class opens (or first downloads) the on-disk index, builds a
    ``SearchAgent`` on top of a ``TantivySearchAgent``, and renders a
    right-to-left page with a provider selector, two numeric settings,
    a query box and the step-by-step progress of each search.
    """

    # Page-wide CSS: right-to-left layout for the app and its input widgets,
    # plus the card styles used when rendering search steps and documents.
    # Built from plain string pieces so the markup never depends on how the
    # surrounding Python source happens to be indented.
    _PAGE_CSS = (
        "<style>\n"
        ".stApp {\n"
        "    direction: rtl;\n"
        "}\n"
        ".stTextInput > div > div > input {\n"
        "    direction: rtl;\n"
        "}\n"
        ".stSelectbox > div > div > div {\n"
        "    direction: rtl;\n"
        "}\n"
        ".stNumberInput > div > div > input {\n"
        "    direction: rtl;\n"
        "}\n"
        ".search-step {\n"
        "    border: 1px solid #e0e0e0;\n"
        "    border-radius: 5px;\n"
        "    padding: 10px;\n"
        "    margin: 5px 0;\n"
        "    background-color: #f8f9fa;\n"
        "}\n"
        ".document-group {\n"
        "    border: 1px solid #e3f2fd;\n"
        "    border-radius: 5px;\n"
        "    padding: 10px;\n"
        "    margin: 5px 0;\n"
        "    background-color: #f5f9ff;\n"
        "}\n"
        ".document-item {\n"
        "    border: 1px solid #e0e0e0;\n"
        "    border-radius: 5px;\n"
        "    padding: 10px;\n"
        "    margin: 5px 0;\n"
        "    background-color: white;\n"
        "}\n"
        "</style>"
    )

    def __init__(self):
        """Set up empty agent slots and the index location settings."""
        # Both agents are created later by initialize_system().
        self.tantivy_agent: Optional[TantivySearchAgent] = None
        self.agent: Optional[SearchAgent] = None
        # The index location is fixed; it is not read from the environment.
        self.index_path = "./index"
        # Google Drive file ID of the zipped index (overridable via env var).
        self.gdrive_index_id = os.getenv(
            "GDRIVE_INDEX_ID", "1lpbBCPimwcNfC0VZOlQueA4SHNGIp5_t"
        )

    def download_index_from_gdrive(self) -> bool:
        """Download the zipped index from Google Drive and unpack it.

        The archive is saved as ``index.zip`` in the working directory,
        extracted into the working directory, and then removed.

        Returns:
            True when the archive was downloaded and extracted; False when
            any step failed (the error is also shown with ``st.error``).
        """
        zip_path = "index.zip"
        url = f"https://drive.google.com/uc?id={self.gdrive_index_id}"
        status_text = st.empty()
        try:
            status_text.text("מוריד...")
            # gdown.download() accepts no progress-callback argument; passing
            # one raises TypeError, so only a textual status is shown here
            # (gdown prints its own progress because quiet=False).
            downloaded = gdown.download(url, zip_path, quiet=False)
            if not downloaded:
                # Some gdown versions report failure by returning None.
                raise RuntimeError("gdown did not produce the index archive")

            status_text.text("מחלץ קבצים...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(".")
            return True
        except Exception as e:
            st.error(f"Failed to download index: {str(e)}")
            return False
        finally:
            # Clean up on success and on failure alike, so a partial archive
            # or a stale status line is never left behind.
            if os.path.exists(zip_path):
                os.remove(zip_path)
            status_text.empty()

    def get_available_providers(self) -> List[str]:
        """Return the provider names reported by ``SearchAgent``.

        A temporary ``TantivySearchAgent`` and ``SearchAgent`` pair is built
        only to ask for the list; ``self.agent`` is left untouched.
        """
        temp_tantivy = TantivySearchAgent(self.index_path)
        temp_agent = SearchAgent(temp_tantivy)
        return temp_agent.get_available_providers()

    def initialize_system(self):
        """Make sure the index exists, then build the search agents.

        Returns:
            A ``(success, status_message, available_providers)`` tuple. On
            failure ``success`` is False and the provider list is empty.
        """
        try:
            if not os.path.exists(self.index_path):
                st.warning("Index folder not found. Attempting to download from Google Drive...")
                if not self.download_index_from_gdrive():
                    return False, "שגיאה: לא ניתן להוריד את האינדקס", []
                st.success("Index downloaded successfully!")

            self.tantivy_agent = TantivySearchAgent(self.index_path)
            if not self.tantivy_agent.validate_index():
                return False, "שגיאה: אינדקס לא תקין", []

            available_providers = self.get_available_providers()
            if not available_providers:
                # Report the real cause instead of an IndexError from
                # indexing into an empty list.
                return False, "שגיאה באתחול המערכת: no AI providers available", []

            # Fall back to the first provider when nothing usable is stored
            # (missing, None, or a name that is no longer offered).
            provider_name = st.session_state.get('provider')
            if provider_name not in available_providers:
                provider_name = available_providers[0]

            self.agent = SearchAgent(
                self.tantivy_agent,
                provider_name=provider_name
            )
            return True, "המערכת מוכנה לחיפוש", available_providers
        except Exception as ex:
            return False, f"שגיאה באתחול המערכת: {str(ex)}", []

    @staticmethod
    def _render_final_result(final_result, answer_container, sources_container):
        """Show the final answer and, when present, its source documents."""
        with answer_container:
            st.subheader("תשובה סופית")
            st.info(final_result['answer'])

        sources = final_result['sources']
        if not sources:
            return
        with sources_container:
            st.subheader("מסמכי מקור")
            st.markdown(f"נמצאו {len(sources)} תוצאות")
            for position, source in enumerate(sources, start=1):
                with st.expander(f"תוצאה {position}: {source['reference']} (ציון: {source['score']:.2f})"):
                    st.write(source['text'])

    @staticmethod
    def _render_step(step, step_number):
        """Render one intermediate search step into the current container."""
        st.markdown(
            f"<div class='search-step'>"
            f"<strong>צעד {step_number}. {step['action']}</strong>"
            f"</div>",
            unsafe_allow_html=True,
        )
        st.markdown(f"**{step['description']}**")

        if 'results' in step:
            documents = []
            for r in step['results']:
                kind = r['type']
                if kind == 'query':
                    st.markdown("**שאילתת חיפוש:**")
                    st.code(r['content'])
                elif kind == 'document':
                    # Documents are collected and shown together below.
                    documents.append(r['content'])
                elif kind == 'evaluation':
                    content = r['content']
                    confidence = f"ביטחון: {content['confidence']}"
                    if content['status'] == 'accepted':
                        st.success(f"✓ {confidence}")
                    else:
                        st.warning(f"↻ {confidence}")
                    if content['explanation']:
                        st.info(content['explanation'])
                elif kind == 'new_query':
                    st.markdown("**ניסיון הבא:**")
                    st.code(r['content'])

            for doc in documents:
                with st.expander(f"{doc['reference']} (ציון: {doc['score']:.2f})"):
                    # A document without highlights must not abort the whole
                    # search; fall back to its 'text' field if it has one
                    # (assumes documents are dicts -- TODO confirm schema).
                    highlights = doc.get('highlights') or []
                    st.write(highlights[0] if highlights else doc.get('text', ''))

        st.markdown("---")

    def main(self):
        """Build the page and, when a query is present, run one search."""
        st.set_page_config(
            page_title="איתוריא",
            layout="wide",
            initial_sidebar_state="collapsed"
        )

        # RTL support and card styling
        st.markdown(self._PAGE_CSS, unsafe_allow_html=True)

        success, status_msg, available_providers = self.initialize_system()

        # Header layout: status | provider selector | numeric settings
        col1, col2, col3 = st.columns([2, 1, 1])

        with col1:
            if success:
                st.success(status_msg)
            else:
                st.error(status_msg)

        with col2:
            # The selectbox is bound to session key 'provider', so the stored
            # value must be one of the options before the widget is created.
            if available_providers and st.session_state.get('provider') not in available_providers:
                st.session_state.provider = available_providers[0]
            elif 'provider' not in st.session_state:
                st.session_state.provider = None

            if available_providers:
                provider = st.selectbox(
                    "ספק בינה מלאכותית",
                    options=available_providers,
                    key='provider'
                )
                if self.agent:
                    self.agent.set_provider(provider)

        with col3:
            col3_1, col3_2 = st.columns(2)
            with col3_1:
                max_iterations = st.number_input(
                    "מספר נסיונות מקסימלי",
                    min_value=1,
                    value=3,
                    key='max_iterations'
                )
            with col3_2:
                results_per_search = st.number_input(
                    "תוצאות לכל חיפוש",
                    min_value=1,
                    value=5,
                    key='results_per_search'
                )

        query = st.text_input(
            "הכנס שאילתת חיפוש",
            disabled=not success,
            placeholder="הקלד את שאילתת החיפוש שלך כאן...",
            key='search_query'
        )

        # The button is always rendered; a search runs whenever the query box
        # holds text (pressing Enter in the box or clicking the button both
        # cause a rerun with the query set). Blank queries are ignored.
        st.button('חפש', disabled=not success)
        if not (query and query.strip() and self.agent):
            return

        try:
            steps_container = st.container()
            answer_container = st.container()
            sources_container = st.container()

            with steps_container:
                st.subheader("צעדי תהליך החיפוש")

            # Every search starts with an empty step log; its length is what
            # numbers the steps as they arrive.
            st.session_state.steps = []

            def handle_step_update(step):
                if 'final_result' in step:
                    self._render_final_result(
                        step['final_result'], answer_container, sources_container
                    )
                else:
                    with steps_container:
                        self._render_step(step, len(st.session_state.steps) + 1)
                        st.session_state.steps.append(step)

            self.agent.search_and_answer(
                query=query,
                num_results=results_per_search,
                max_iterations=max_iterations,
                on_step=handle_step_update
            )
        except Exception as ex:
            st.error(f"שגיאת חיפוש: {str(ex)}")
# Script entry point: build the UI object and render the page.
if __name__ == "__main__":
    app = SearchAgentUI()
    app.main()