Spaces:
Running
Running
Upload 9 files
Browse files- .gitignore +11 -0
- agent.py +64 -0
- app.py +318 -0
- chat_gemini.py +264 -0
- llm_providers.py +77 -0
- requirements.txt +14 -0
- sefaria.py +83 -0
- tantivy_search.py +165 -0
- tools.py +118 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.env
|
2 |
+
__pycache__
|
3 |
+
build
|
4 |
+
dist
|
5 |
+
flet_ui.spec
|
6 |
+
web_ui/node_modules
|
7 |
+
web_ui/.nuxt
|
8 |
+
index
|
9 |
+
.venv
|
10 |
+
upload space.py
|
11 |
+
flet_app.spec
|
agent.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langgraph.prebuilt import create_react_agent
|
2 |
+
from langgraph.checkpoint.memory import MemorySaver
|
3 |
+
from typing import Any, Iterator
|
4 |
+
from tools import search, get_commentaries, read_text
|
5 |
+
from llm_providers import LLMProvider
|
6 |
+
|
7 |
+
SYSTEM_PROMPT = """
|
8 |
+
אתה מסייע תורני רב עוצמה, משול לתלמיד חכם הבקיא בכל רזי התורה. התפקיד שלך הוא לסייע למשתמשים בלימוד התורה בצורה מעמיקה וחכמה. עליך להבין את כוונת השואל, לנתח את השאלה לעומק, ולבצע חיפוש מתוחכם בטקסטים יהודיים ותורניים.
|
9 |
+
|
10 |
+
עליך לענות תשובות אך ורק על פי מקורות שמצאת בחיפוש ועיון מעמיק, ולא על פי ידע קודם.
|
11 |
+
|
12 |
+
כאשר אתה מקבל שאלה מהמשתמש, עליך לנסות להבין את כוונתו, את ההקשר ההיסטורי וההלכתי, ואת המקורות הרלוונטיים. עליך ליצור שאילתת חיפוש מתאימה באמצעות הכלי search, תוך שימוש בשפה תורנית מדויקת. עבור תנ"ך השתמש בשפה מקראית, לחיפוש בתלמוד חפש בארמית, וכן הלאה. תוכל לצמצם את החיפוש לפי נושאים, תקופות, מחברים, ואף לפי שם הספר או הקטע הדרוש.
|
13 |
+
|
14 |
+
אם לא מצאת תוצאות רלוונטיות, אל תתייאש. נסה שוב ושוב, תוך שימוש בשאילתות מגוונות, מילים נרדפות, הטיות שונות של מילות המפתח, וצמצום או הרחבת היקף החיפוש. זכור, תלמיד חכם אמיתי אינו מוותר עד שהוא מוצא את האמת.
|
15 |
+
|
16 |
+
כאשר אתה מוצא מקורות רלוונטיים, עליך לקרוא אותם בעיון ובקפידה באמצעות הכלי get_text. אם יש צורך, תוכל להיעזר בכלי get_commentaries כדי לקבל רשימה של פרשנים על טקסט מסוים.
|
17 |
+
|
18 |
+
עליך לשאוף למצוא את המקורות הקדומים והמוסמכים ביותר לכל פרט בשאלה. לדוגמה, אם מצאת הלכה מסוימת בספר שיצא לאחרונה, נסה למצוא את מקורה בשולחן ערוך, ואז בגמרא, ואף במשנה או במקרא. השתמש בספר "באר הגולה" על שולחן ערוך כדי למצוא את המקורות בגמרא.
|
19 |
+
|
20 |
+
לאחר שאספת את כל המידע הרלוונטי, עליך לעבד אותו, לקשר בין מקורות שונים, ולנסח תשובה מפורטת, בהירה ומדויקת. עליך להתייחס לכל היבטי השאלה, תוך ציון המקורות לכל פרט בתשובה.
|
21 |
+
|
22 |
+
זכור, אתה משול לתלמיד חכם, ועל כן עליך להפגין בקיאות, חריפות, עמקות ודייקנות בכל תשובותיך.
|
23 |
+
"""
|
24 |
+
|
25 |
+
class Agent:
    """ReAct agent over Torah-text search tools, backed by a LangGraph graph.

    Wraps an LLM chosen through ``LLMProvider`` with the search/read tools
    and an in-memory checkpointer, so each conversation persists under a
    thread id until :meth:`clear_chat` is called.
    """

    def __init__(self, index_path: str):
        # index_path is not used by the agent directly (the tools module owns
        # the index); stored for interface compatibility with callers.
        self.index_path = index_path
        self.llm_provider = LLMProvider()
        # Default to the first provider that has a configured API key.
        self.llm = self.llm_provider.get_provider(
            self.llm_provider.get_available_providers()[0]
        )
        self.memory_saver = MemorySaver()
        self.tools = [read_text, get_commentaries, search]
        self.graph = self._build_graph()
        # Monotonically increasing thread id; bumping it starts a fresh chat.
        self.current_thread_id = 1

    def _build_graph(self):
        """Build the ReAct agent graph for the currently selected LLM.

        Factored out because __init__ and set_llm previously duplicated the
        same create_react_agent call.
        """
        return create_react_agent(
            model=self.llm,
            checkpointer=self.memory_saver,
            tools=self.tools,
            state_modifier=SYSTEM_PROMPT,
        )

    def set_llm(self, provider_name: str):
        """Switch to another configured provider and rebuild the graph."""
        self.llm = self.llm_provider.get_provider(provider_name)
        self.graph = self._build_graph()

    def get_llm(self):
        """Return the currently active LLM instance.

        (The original annotated the return as ``str`` although it returns
        the chat-model object.)
        """
        return self.llm

    def clear_chat(self):
        """Start a new conversation by moving to a fresh thread id."""
        self.current_thread_id += 1

    def chat(self, message) -> Iterator[dict[str, Any]]:
        """Chat with the agent and stream responses including tool calls and their results."""
        config = {"configurable": {"thread_id": self.current_thread_id}}
        inputs = {"messages": [("user", message)]}
        return self.graph.stream(inputs, stream_mode="values", config=config)

    def get_chat_history(self, id=None):
        """Return the stored checkpoint for a thread.

        Defaults to the current thread. The original ignored the ``id``
        argument and passed an unsupported ``thread_id=`` keyword;
        ``MemorySaver.get`` expects a runnable config mapping.
        """
        if id is None:
            id = self.current_thread_id
        return self.memory_saver.get({"configurable": {"thread_id": str(id)}})
|
app.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
from typing import Optional, List
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
import gdown
|
6 |
+
import llm_providers
|
7 |
+
import tantivy_search
|
8 |
+
import agent
|
9 |
+
import json
|
10 |
+
import zipfile
|
11 |
+
|
12 |
+
|
13 |
+
INDEX_PATH = "./index"
|
14 |
+
|
15 |
+
# Load environment variables
|
16 |
+
load_dotenv()
|
17 |
+
|
18 |
+
class SearchAgentUI:
    """Streamlit chat UI for the Torah search agent.

    Downloads the Tantivy index on first use, lets the user configure API
    keys and pick an LLM provider in the sidebar, then renders the agent's
    streamed messages (assistant text, tool calls and search results).
    """

    # Location of the Tantivy index on disk.
    index_path = INDEX_PATH
    # Google Drive file id of the zipped index (overridable via env var).
    gdrive_index_id = os.getenv("GDRIVE_INDEX_ID", "1lpbBCPimwcNfC0VZOlQueA4SHNGIp5_t")

    @st.cache_resource
    def get_agent(_self):
        """Create the LangGraph agent once per Streamlit session."""
        index_path = INDEX_PATH
        return agent.Agent(index_path)

    def download_index_from_gdrive(self) -> bool:
        """Download and unpack the search index zip from Google Drive.

        Returns:
            True on success, False (after showing a UI error) on failure.
        """
        try:
            zip_path = "index.zip"
            url = f"https://drive.google.com/uc?id={self.gdrive_index_id}"
            gdown.download(url, zip_path, quiet=False)
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(".")
            os.remove(zip_path)
            return True
        except Exception as e:
            st.error(f"Failed to download index: {str(e)}")
            return False

    @st.cache_resource
    def initialize_system(_self, api_keys: dict[str, str]) -> tuple[bool, str, List[str]]:
        """Ensure the index exists and at least one LLM provider is configured.

        Returns:
            (success, status message, available provider names).
        """
        try:
            # download index
            if not os.path.exists(_self.index_path):
                st.warning("Index folder not found. Attempting to download from Google Drive...")
                if not _self.download_index_from_gdrive():
                    return False, "שגיאה: לא ניתן להוריד את האינדקס", []
                st.success("Index downloaded successfully!")
            _self.llm_providers = llm_providers.LLMProvider(api_keys)
            available_providers = _self.llm_providers.get_available_providers()
            if not available_providers:
                return False, "שגיאה: לא נמצאו ספקי AI זמינים. אנא הזן מפתח API אחד לפחות.", []
            # Fixed: the original string contained mojibake (stray non-Hebrew
            # characters) inside the word "לחיפוש".
            return True, "המערכת מוכנה לחיפוש", available_providers
        except Exception as ex:
            return False, f"שגיאה באתחול המערכת: {str(ex)}", []

    def update_messages(self, messages):
        """Store the latest message list in the Streamlit session state."""
        st.session_state.messages = messages

    def main(self):
        """Render the full page: global CSS, settings sidebar and chat area."""
        st.set_page_config(
            page_title="איתוריא",
            layout="wide",
            initial_sidebar_state="expanded"
        )

        # Enhanced styling with better visual hierarchy and modern design
        st.markdown("""
        <style>
        /* Global RTL Support */
        .stApp {
            direction: rtl;
            background-color: #f8f9fa;
        }

        /* Input Fields RTL */
        .stTextInput > div > div > input,
        .stSelectbox > div > div > div,
        .stNumberInput > div > div > input {
            direction: rtl;
            border-radius: 8px !important;
            border: 2px solid #e2e8f0 !important;
            padding: 0.75rem !important;
            transition: all 0.3s ease;
        }

        .stTextInput > div > div > input:focus,
        .stSelectbox > div > div > div:focus {
            border-color: #4299e1 !important;
            box-shadow: 0 0 0 1px #4299e1 !important;
        }

        /* Message Containers */
        .chat-container {
            background: white;
            border-radius: 12px;
            padding: 1.5rem;
            margin: 1rem 0;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }

        /* Tool Calls Styling */
        .tool-call {
            background: #f0f7ff;
            border-radius: 8px;
            padding: 1rem;
            margin: 0.5rem 0;
            border-right: 4px solid #3182ce;
        }

        /* Search Results */
        .search-step {
            background: white;
            border-radius: 10px;
            padding: 1.25rem;
            margin: 1rem 0;
            box-shadow: 0 2px 4px rgba(0,0,0,0.05);
            border: 1px solid #e2e8f0;
        }

        .document-group {
            background: #f7fafc;
            border-radius: 8px;
            padding: 1rem;
            margin: 0.75rem 0;
            border: 1px solid #e2e8f0;
        }

        .document-item {
            background: white;
            border-radius: 6px;
            padding: 1rem;
            margin: 0.5rem 0;
            border: 1px solid #edf2f7;
        }

        /* Sidebar Styling */
        [data-testid="stSidebar"] {
            direction: rtl;
            background-color: #f8fafc;
            padding: 2rem 1rem;
        }

        .sidebar-content {
            padding: 1rem;
        }

        /* Chat Messages */
        .stChatMessage {
            direction: rtl;
            background: white !important;
            border-radius: 12px !important;
            padding: 1rem !important;
            margin: 0.75rem 0 !important;
            box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important;
        }

        /* Buttons */
        .stButton > button {
            border-radius: 8px !important;
            padding: 0.5rem 1.5rem !important;
            background-color: #3182ce !important;
            color: white !important;
            border: none !important;
            transition: all 0.3s ease !important;
        }

        .stButton > button:hover {
            background-color: #2c5282 !important;
            transform: translateY(-1px);
        }

        /* Code Blocks */
        .stCodeBlock {
            direction: ltr;
            text-align: left;
            border-radius: 8px !important;
            background: #2d3748 !important;
        }

        /* Links */
        a {
            color: #3182ce;
            text-decoration: none;
            transition: color 0.2s ease;
        }

        a:hover {
            color: #2c5282;
            text-decoration: underline;
        }

        /* Error Messages */
        .stAlert {
            border-radius: 8px !important;
            border: none !important;
        }
        </style>
        """, unsafe_allow_html=True)

        # Initialize session state for message deduplication.
        if "messages" not in st.session_state:
            st.session_state.messages = []
            # Initialise the per-session API key store together with the
            # message list so user-entered keys are not recreated on rerun.
            st.session_state.api_keys = {
                'google': "",
                'openai': "",
                'anthropic': ""
            }

        # Sidebar settings
        with st.sidebar:
            st.title("הגדרות")

            st.subheader("הגדרת מפתחות API")

            # Where each provider's API key can be obtained (hoisted out of
            # the loop — the original rebuilt this dict every iteration).
            links = {
                'google': 'https://aistudio.google.com/app/apikey',
                'openai': 'https://platform.openai.com/account/api-keys',
                'anthropic': 'https://console.anthropic.com/'
            }

            # API Key inputs with improved styling
            for provider, label in [
                ('google', 'Google API Key'),
                ('openai', 'OpenAI API Key'),
                ('anthropic', 'Anthropic API Key')
            ]:
                key = st.text_input(
                    label,
                    value=st.session_state.api_keys[provider],
                    type="password",
                    key=f"{provider}_key",
                    help=f"הזן את מפתח ה-API של {label}"
                )
                st.session_state.api_keys[provider] = key

                # Provider-specific links
                st.html(f'<small> ניתן להשיג מפתח <a href="{links[provider]}">כאן</a> </small>')

            st.markdown("---")

        # Initialize system
        success, status_msg, available_providers = self.initialize_system(st.session_state.api_keys)

        if not success:
            st.error(status_msg)
            return

        # Renamed from ``agent`` to avoid shadowing the imported module.
        chat_agent = self.get_agent()

        # Provider selection in sidebar
        with st.sidebar:
            if 'provider' not in st.session_state or st.session_state.provider not in available_providers:
                st.session_state.provider = available_providers[0]

            provider = st.selectbox(
                "ספק בינה מלאכותית",
                options=available_providers,
                key='provider',
                help="בחר את מודל הAI לשימוש (רק מודלים עם מפתח API זמין יוצגו)"
            )
            if chat_agent:
                chat_agent.set_llm(provider)

        # Main chat interface
        query = st.chat_input("הזן שאלה", key="chat_input")
        if query:
            stream = chat_agent.chat(query)
            for chunk in stream:
                st.session_state.messages = chunk["messages"]
        if st.button("צ'אט חדש"):
            st.session_state.messages = []
            chat_agent.clear_chat()

        for message in st.session_state.messages:
            if message.type == "tool":
                if message.name == "search":
                    results = json.loads(message.content) if message.content else []
                    with st.expander(f"🔍 תוצאות חיפוש: {len(results)}"):
                        for result in results:
                            st.write(result['reference'])
                            st.info(result['text'])
                elif message.name == "get_text":
                    st.expander(f"📝 טקסט: {message.content}")

            elif message.type == "ai":
                if message.content != "":
                    with st.chat_message(message.type):
                        # Anthropic-style messages carry a list of content
                        # parts; others carry a plain string.
                        if isinstance(message.content, list):
                            for item in message.content:
                                if 'text' in item:
                                    st.write(item['text'])
                        else:
                            st.write(message.content)

                for tool_call in message.tool_calls:
                    # Single-quoted key inside the f-string: the original
                    # nested double quotes, a SyntaxError before Python 3.12.
                    with st.expander(f"🛠️ שימוש בכלי: {tool_call['name']}"):
                        st.json(tool_call["args"])
            else:
                with st.chat_message(message.type):
                    st.write(message.content)
|
314 |
+
|
315 |
+
|
316 |
+
if __name__ == "__main__":
    # Script entry point: build the UI object and render the page.
    SearchAgentUI().main()
|
chat_gemini.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from random import choices
|
3 |
+
import string
|
4 |
+
from langchain.tools import BaseTool
|
5 |
+
import requests
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
9 |
+
from typing import (
|
10 |
+
Any,
|
11 |
+
Callable,
|
12 |
+
Dict,
|
13 |
+
List,
|
14 |
+
Literal,
|
15 |
+
Mapping,
|
16 |
+
Optional,
|
17 |
+
Sequence,
|
18 |
+
Type,
|
19 |
+
Union,
|
20 |
+
cast,
|
21 |
+
)
|
22 |
+
from langchain_core.callbacks import (
|
23 |
+
CallbackManagerForLLMRun,
|
24 |
+
)
|
25 |
+
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
|
26 |
+
from langchain_core.exceptions import OutputParserException
|
27 |
+
from langchain_core.language_models import LanguageModelInput
|
28 |
+
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
|
29 |
+
from langchain_core.messages import (
|
30 |
+
AIMessage,
|
31 |
+
BaseMessage,
|
32 |
+
HumanMessage,
|
33 |
+
ToolMessage,
|
34 |
+
SystemMessage,
|
35 |
+
)
|
36 |
+
from langchain_core.outputs import ChatGeneration, ChatResult
|
37 |
+
from langchain_core.runnables import Runnable
|
38 |
+
from langchain_core.tools import BaseTool
|
39 |
+
|
40 |
+
|
41 |
+
class ChatGemini(BaseChatModel):
    """Minimal LangChain chat model that calls the Gemini REST API directly,
    including tool/function-calling support."""

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "gemini"

    # Gemini API key (required) and REST endpoint.
    api_key: str
    base_url: str = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent"
    # Extra payload fields merged into every request.
    model_kwargs: Any = {}

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat response using the Gemini API.

        This method handles both regular text responses and function calls.
        For function calls, it returns an AIMessage with structured function
        call data that can be processed by Langchain's agent executor.

        Function calls are returned with:
        - name: The name of the function to call
        - id: A unique identifier (name + random suffix, as Gemini doesn't provide one)
        - args: The function arguments

        Args:
            messages: List of input messages
            stop: Optional list of stop sequences
            run_manager: Optional callback manager
            **kwargs: Additional arguments passed to the Gemini API

        Returns:
            ChatResult containing an AIMessage (with tool_calls populated
            when the model requested a function call)
        """
        import os  # local import: only needed for the optional CA bundle below

        # Convert messages to Gemini format
        gemini_messages = []
        system_message = None
        for msg in messages:
            # Handle both dict and LangChain message objects
            if isinstance(msg, BaseMessage):
                if isinstance(msg, SystemMessage):
                    # Gemini takes the system prompt as a separate top-level field.
                    system_message = msg.content
                    kwargs["system_instruction"] = {"parts": [{"text": system_message}]}
                    continue
                if isinstance(msg, HumanMessage):
                    role = "user"
                    content = msg.content
                elif isinstance(msg, AIMessage):
                    role = "model"
                    content = msg.content
                elif isinstance(msg, ToolMessage):
                    # Handle tool messages by adding them as function outputs
                    gemini_messages.append(
                        {
                            "role": "model",
                            "parts": [{
                                "functionResponse": {
                                    "name": msg.name,
                                    "response": {"name": msg.name, "content": msg.content},
                                }}]}
                    )
                    continue
            else:
                role = "user" if msg["role"] == "human" else "model"
                content = msg["content"]

            # AI messages that carry tool calls are encoded as a functionCall
            # part; everything else becomes a plain text part.
            message_part = {
                "role": role,
                "parts": [{"functionCall": {"name": msg.tool_calls[0]["name"], "args": msg.tool_calls[0]["args"]}}] if isinstance(msg, AIMessage) and msg.tool_calls else [{"text": content}]
            }
            gemini_messages.append(message_part)

        # Prepare the request
        headers = {
            "Content-Type": "application/json"
        }

        params = {
            "key": self.api_key
        }

        data = {
            "contents": gemini_messages,
            "generationConfig": {
                "temperature": 0.7,
                "topP": 0.8,
                "topK": 40,
                "maxOutputTokens": 2048,
            },
            **kwargs
        }

        # Honour a configured CA bundle, otherwise use default verification.
        # The original hard-coded a Windows-only NetFree certificate path,
        # which crashed every non-Windows deployment.
        verify = os.environ.get("REQUESTS_CA_BUNDLE") or True

        try:
            response = requests.post(
                self.base_url,
                headers=headers,
                params=params,
                json=data,
                verify=verify
            )
            response.raise_for_status()

            result = response.json()
            if "candidates" in result and len(result["candidates"]) > 0 and "parts" in result["candidates"][0]["content"]:
                parts = result["candidates"][0]["content"]["parts"]
                tool_calls = []
                content = ""
                for part in parts:
                    if "text" in part:
                        content += part["text"]
                    if "functionCall" in part:
                        function_call = part["functionCall"]
                        tool_calls.append({
                            "name": function_call["name"],
                            # Gemini doesn't provide a unique id, so derive one.
                            "id": function_call["name"] + random_string(5),
                            "args": function_call["args"],
                            "type": "tool_call",
                        })
                return ChatResult(generations=[
                    ChatGeneration(
                        message=AIMessage(
                            content=content,
                            tool_calls=tool_calls,
                        ) if len(tool_calls) > 0 else AIMessage(content=content)
                    )
                ])
            else:
                raise Exception("No response generated")

        except Exception as e:
            # Keep the broad Exception type (callers catch it) but preserve
            # the original exception as the cause for debuggability.
            raise Exception(f"Error calling Gemini API: {str(e)}") from e

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Supports any tool definition handled by
                :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
            tool_choice: If provided, which tool for model to call. **This
                parameter is currently ignored; this wrapper does not forward
                it to Gemini.**
            kwargs: Any additional parameters are passed directly to
                ``self.bind(**kwargs)``.
        """
        formatted_tools = {"function_declarations": [convert_to_gemini_tool(tool) for tool in tools]}
        return super().bind(tools=formatted_tools, **kwargs)
|
207 |
+
|
208 |
+
def convert_to_gemini_tool(
    tool: Union[BaseTool],
    *,
    strict: Optional[bool] = None,
) -> dict[str, Any]:
    """Convert a tool-like object to a Gemini tool schema.

    Gemini tool schema reference:
    https://ai.google.dev/gemini-api/docs/function-calling#function_calling_mode

    Args:
        tool:
            BaseTool.
        strict:
            If True, model output is guaranteed to exactly match the JSON Schema
            provided in the function definition. If None, ``strict`` argument will not
            be included in tool definition.

    Returns:
        A dict version of the passed in tool which is compatible with the
        Gemini tool-calling API.
    """
    # Guard clause: only BaseTool instances are supported.
    if not isinstance(tool, BaseTool):
        raise ValueError(f"Unsupported tool type: {type(tool)}")

    # Extract the tool's argument schema (empty object schema when absent).
    if tool.args_schema:
        schema = tool.args_schema.schema()
    else:
        schema = {"type": "object", "properties": {}}

    # Convert each property to the minimal shape Gemini accepts.
    properties = {
        name: {
            "type": prop.get("type", "string"),
            "description": prop.get("title", ""),
        }
        for name, prop in schema.get("properties", {}).items()
    }

    # Assemble the function declaration.
    function_def = {
        "name": tool.name,
        "description": tool.description,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": schema.get("required", []),
        },
    }

    if strict is not None:
        function_def["strict"] = strict

    return function_def
|
261 |
+
|
262 |
+
def random_string(length: int) -> str:
    """Return a random alphanumeric string of the requested length."""
    alphabet = string.ascii_letters + string.digits
    return "".join(choices(alphabet, k=length))
|
264 |
+
|
llm_providers.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_anthropic import ChatAnthropic
|
2 |
+
from langchain_openai import ChatOpenAI
|
3 |
+
from langchain_ollama import ChatOllama
|
4 |
+
from langchain_core.language_models.base import BaseLanguageModel
|
5 |
+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
6 |
+
from typing import Optional, Dict, List, Any, Sequence
|
7 |
+
from langchain.tools import BaseTool
|
8 |
+
import os
|
9 |
+
import requests
|
10 |
+
import json
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
from dataclasses import dataclass
|
13 |
+
import ollama
|
14 |
+
import copy
|
15 |
+
from chat_gemini import ChatGemini
|
16 |
+
|
17 |
+
|
18 |
+
load_dotenv()
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
class LLMProvider:
    """Registry of configured chat-model providers (cloud APIs and local Ollama).

    Providers are registered only when an API key is available (from the
    environment or the supplied mapping), or — for Ollama — when a local
    server responds.
    """

    def __init__(self, api_keys: Optional[Dict[str, str]] = None):
        # Always initialise the mapping. The original only set ``self.api_keys``
        # when a truthy argument was passed, so constructing ``LLMProvider()``
        # crashed with AttributeError inside _setup_providers.
        self.api_keys = api_keys or {}
        self.providers: Dict[str, Any] = {}
        self._setup_providers()

    def _get_ollama_models(self) -> List[str]:
        """Get list of available Ollama models using the ollama package"""
        try:
            models = ollama.list()
            return [model.model for model in models['models']]
        except Exception:
            # No local Ollama server (or it is unreachable).
            return []

    def _setup_providers(self):
        """Register every provider for which credentials are available."""
        # Only force the NetFree CA bundle when it actually exists on this
        # machine; unconditionally pointing REQUESTS_CA_BUNDLE at a missing
        # Windows path breaks TLS on every other host.
        netfree_ca = 'C:\\ProgramData\\NetFree\\CA\\netfree-ca-bundle-curl.crt'
        if os.path.exists(netfree_ca):
            os.environ['REQUESTS_CA_BUNDLE'] = netfree_ca

        # Google Gemini
        if google_key := os.getenv('GOOGLE_API_KEY') or self.api_keys.get('google'):
            self.providers['Gemini'] = ChatGemini(api_key=google_key)

        # Anthropic
        if anthropic_key := os.getenv('ANTHROPIC_API_KEY') or self.api_keys.get('anthropic'):
            self.providers['Claude'] = ChatAnthropic(
                api_key=anthropic_key,
                model_name="claude-3-5-sonnet-20241022",
            )

        # OpenAI
        if openai_key := os.getenv('OPENAI_API_KEY') or self.api_keys.get('openai'):
            self.providers['ChatGPT'] = ChatOpenAI(
                api_key=openai_key,
                model_name="gpt-4o-2024-11-20",
                max_completion_tokens=4096,
            )

        # Ollama (local)
        try:
            # Get available Ollama models using the ollama package
            ollama_models = self._get_ollama_models()
            for model in ollama_models:
                self.providers[f'Ollama-{model}'] = ChatOllama(model=model)
        except Exception:
            pass  # Ollama not available

    def get_available_providers(self) -> list[str]:
        """Return list of available provider names"""
        return list(self.providers.keys())

    def get_provider(self, name: str) -> Optional[Any]:
        """Get LLM provider by name"""
        return self.providers.get(name)
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain
|
2 |
+
langgraph
|
3 |
+
python-dotenv
|
4 |
+
flet
|
5 |
+
langchain-community
|
6 |
+
langchain-core
|
7 |
+
langchain-openai
|
8 |
+
langchain-anthropic
|
9 |
+
langchain-ollama
|
10 |
+
ollama
|
11 |
+
requests
|
12 |
+
tantivy
|
13 |
+
gdown
|
14 |
+
pydantic
|
sefaria.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests, json
|
2 |
+
|
3 |
+
SEFARIA_API_BASE_URL = "http://localhost:8000"
|
4 |
+
|
5 |
+
def _get_request_json_data(endpoint, ref=None, param=None):
    """
    Helper function to make GET requests to the Sefaria API and parse the JSON response.
    """
    # Assemble the URL: base / endpoint [ref] [?param]
    url = f"{SEFARIA_API_BASE_URL}/{endpoint}"
    if ref:
        url = f"{url}{ref}"
    if param:
        url = f"{url}?{param}"

    try:
        response = requests.get(url)
        response.raise_for_status()  # Raise an exception for bad status codes
        return response.json()
    except requests.exceptions.RequestException as e:
        # Best-effort API: report the failure and signal it with None.
        print(f"Error during API request: {e}")
        return None
|
25 |
+
|
26 |
+
|
27 |
+
def get_text(reference: str) -> str:
    """
    Retrieves the text for a given reference.
    """
    hebrew = _get_hebrew_text(reference)
    # Stringify so callers always receive text, even when lookup fails (None).
    return str(hebrew)
|
32 |
+
|
33 |
+
|
34 |
+
def get_weekly_parasha():
    """
    Retrieves the weekly Parasha data using the Calendars API.
    """
    data = _get_request_json_data("api/calendars")

    if data:
        # Scan the calendar for the weekly Torah portion entry.
        for item in data.get('calendar_items', []):
            if item.get('title', {}).get('en') != 'Parashat Hashavua':
                continue
            return {
                "ref": item.get('ref'),
                "description": item.get('description', {}).get('he'),
                "name_he": item.get('displayValue', {}).get('he'),
            }

    print("Could not retrieve Parasha data.")
    return None
|
55 |
+
|
56 |
+
|
57 |
+
def _get_hebrew_text(parasha_ref):
    """Fetch the Hebrew text for *parasha_ref* from the Sefaria v3 texts API.

    Returns the text of the first available version, or None when the
    reference cannot be resolved.
    """
    data = _get_request_json_data("api/v3/texts/", parasha_ref)

    # The first listed version is used, matching the original behaviour.
    if data and data.get("versions"):
        return data["versions"][0]["text"]

    print(f"Could not retrieve Hebrew text for {parasha_ref}")
    return None
|
69 |
+
|
70 |
+
|
71 |
+
def get_commentaries(parasha_ref) -> list[str]:
    """Return Hebrew references of all commentary links on *parasha_ref*.

    Queries the Sefaria "related" API and keeps only links whose type is
    'commentary'. Returns an empty list when the request fails.
    """
    data = _get_request_json_data("api/related/", parasha_ref)

    links = data["links"] if data and "links" in data else []
    return [
        link.get('sourceHeRef')
        for link in links
        if link.get('type') == 'commentary'
    ]
|
tantivy_search.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Any, Optional
|
2 |
+
from tantivy import Index
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
|
7 |
+
|
8 |
+
class TantivySearch:
    """Wrapper around a Tantivy full-text index.

    Exposes query-syntax instructions for an LLM, a search method returning
    scored hits with highlighted snippets, and a basic index health check.
    """

    def __init__(self, index_path: str):
        """Initialize the Tantivy search agent with the index path.

        Raises:
            Exception: re-raised from Tantivy when the index cannot be opened.
        """
        self.index_path = index_path
        self.logger = logging.getLogger(__name__)
        try:
            self.index = Index.open(index_path)
            self.logger.info(f"Successfully opened Tantivy index at {index_path}")
        except Exception as e:
            self.logger.error(f"Failed to open Tantivy index: {e}")
            raise

    def get_query_instructions(self) -> str:
        """Return instructions for the LLM on how to parse and construct Tantivy queries"""
        return """
Instructions for generating a query:

1. Boolean Operators:

- AND: term1 AND term2 (both required)
- OR: term1 OR term2 (either term)
- Multiple words default to OR operation (cloud network = cloud OR network)
- AND takes precedence over OR
- Example: Shabath AND (walk OR go)

2. Field-specific Terms:
- Field-specific terms: field:term
- Example: text:אדם AND reference:בראשית
- available fields: text, reference, topics
- text contains the text of the document
- reference contains the citation of the document, e.g. בראשית, פרק א
- topics contains the topics of the document. available topics includes: תנך, הלכה, מדרש, etc.

3. Required/Excluded Terms:
- Required (+): +term (must contain)
- Excluded (-): -term (must not contain)
- Example: +security cloud -deprecated
- Equivalent to: security AND cloud AND NOT deprecated

4. Phrase Search:
- Use quotes: "exact phrase"
- Both single/double quotes work
- Escape quotes with \\"
- Slop operator: "term1 term2"~N
- Example: "cloud security"~2
- the above will find "cloud framework and security "
- Prefix matching: "start of phrase"*

5. Wildcards:
- ? for single character
- * for any number of characters
- Example: sec?rity cloud*

6. Special Features:
- All docs: *
- Boost terms: term^2.0 (positive numbers only)
- Example: security^2.0 cloud
- the above will boost security by 2.0

Query Examples:
1. Basic: +שבת +חולה +אסור
2. Field-specific: text:סיני AND topics:תנך
3. Phrase with slop: "security framework"~2
4. Complex: +reference:בראשית +text:"הבל"^2.0 +(דמי OR דמים) -הבלים
6. Mixed: (text:"רבנו משה"^2.0 OR reference:"משנה תורה") AND topics:הלכה) AND text:"תורה המלך"~3 AND NOT topics:מדרש

Tips:
- Group complex expressions with parentheses
- Use quotes for exact phrases
- Add + for required terms, - for excluded terms
- Boost important terms with ^N
- use field-specific terms for better results.
"""

    def search(self, query: str, num_results: int = 10) -> List[Dict[str, Any]]:
        """Search the Tantivy index with the given query using Tantivy's query syntax.

        Args:
            query: A Tantivy query string (see get_query_instructions()).
            num_results: Maximum number of hits to return.

        Returns:
            A list of result dicts (score, title, reference, topics,
            file_path, line_number, is_pdf, text, highlights). Returns an
            empty list on any parsing or search failure.
        """
        try:
            # Create a searcher
            searcher = self.index.searcher()

            try:
                # parse_query_lenient returns (query, errors); use the query.
                parsed_query = self.index.parse_query_lenient(query)
                search_results = searcher.search(parsed_query[0], num_results).hits
            except Exception as query_error:
                # BUG FIX: previously this branch only logged and fell
                # through, leaving search_results unbound and triggering a
                # NameError below. Bail out explicitly instead.
                self.logger.error(f"Lenient query parsing failed: {query_error}")
                return []

            # Highlight terms depend only on the query, so compute them once
            # instead of once per hit (loop-invariant hoisting).
            # Strip query syntax for highlighting while preserving Hebrew.
            stripped = re.sub(
                r'[:"()[\]{}^~*\\]|\b(AND|OR|NOT|TO|IN)\b|[-+]',
                ' ',
                query
            ).strip()
            highlight_terms = [term for term in stripped.split() if len(term) > 1]
            # Escape regex special chars but preserve Hebrew.
            pattern = '|'.join(re.escape(term) for term in highlight_terms) if highlight_terms else None

            results = []
            for score, doc_address in search_results:
                doc = searcher.doc(doc_address)
                # Guard: a document may lack the "text" field entirely.
                text = doc.get_first("text") or ""

                highlights = []
                if pattern:
                    # Get surrounding context (50 chars each side) for matches.
                    for match in re.finditer(pattern, text, re.IGNORECASE):
                        start = max(0, match.start() - 50)
                        end = min(len(text), match.end() + 50)
                        snippet = text[start:end]
                        if start > 0:
                            snippet = f"...{snippet}"
                        if end < len(text):
                            snippet = f"{snippet}..."
                        highlights.append(snippet)
                if not highlights:
                    # Fallback: leading slice of the document text.
                    highlights = [text[:100] + "..." if len(text) > 100 else text]

                results.append({
                    "score": float(score),
                    "title": doc.get_first("title") or os.path.basename(doc.get_first("filePath") or ""),
                    "reference": doc.get_first("reference"),
                    "topics": doc.get_first("topics"),
                    "file_path": doc.get_first("filePath"),
                    "line_number": doc.get_first("segment"),
                    "is_pdf": doc.get_first("isPdf"),
                    "text": text,
                    "highlights": highlights,
                })

            self.logger.info(f"Found {len(results)} results for query: {query}")
            return results

        except Exception as e:
            self.logger.error(f"Error during search: {str(e)}")
            return []

    def validate_index(self) -> bool:
        """Validate that the index exists and is accessible.

        Returns:
            True when a trivial match-all search succeeds, False otherwise.
        """
        try:
            # Try to create a searcher and perform a simple search.
            searcher = self.index.searcher()
            query_parser = self.index.parse_query("*")
            searcher.search(query_parser, 1)
            return True
        except Exception as e:
            self.logger.error(f"Index validation failed: {e}")
            return False
|
tools.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.tools import tool
|
2 |
+
from sefaria import get_text as sefaria_get_text, get_commentaries as sefaria_get_commentaries
|
3 |
+
from tantivy_search import TantivySearch
|
4 |
+
from typing import Optional
|
5 |
+
from pydantic import BaseModel, Field
|
6 |
+
|
7 |
+
from app import INDEX_PATH
|
8 |
+
|
9 |
+
class ReadTextArgs(BaseModel):
    """Argument schema for the read_text tool (a single textual reference)."""
    reference: str = Field(description="The reference to retrieve the text for. examples: בראשית א פרק א, Genesis 1:1")
|
11 |
+
|
12 |
+
class SearchArgs(BaseModel):
    """Argument schema for the search tool.

    The query description below is itself part of the tool's runtime schema
    (it is surfaced to the LLM), so it must not be altered casually.
    """
    query: str = Field(description="""the query for the search.
Instructions for generating a query:

1. Boolean Operators:

- AND: term1 AND term2 (both required)
- OR: term1 OR term2 (either term)
- Multiple words default to OR operation (cloud network = cloud OR network)
- AND takes precedence over OR
- Example: Shabath AND (walk OR go)

2. Field-specific Terms:
- Field-specific terms: field:term
- Example: text:אדם AND reference:בראשית
- available fields: text, reference, topics
- text contains the text of the document
- reference contains the citation of the document, e.g. בראשית, פרק א
- topics contains the topics of the document. available topics includes: תנך, הלכה, מדרש, etc.

3. Required/Excluded Terms:
- Required (+): +term (must contain)
- Excluded (-): -term (must not contain)
- Example: +security cloud -deprecated
- Equivalent to: security AND cloud AND NOT deprecated

4. Phrase Search:
- Use quotes: "exact phrase"
- Both single/double quotes work
- Escape quotes with \\"
- Slop operator: "term1 term2"~N
- Example: "cloud security"~2
- the above will find "cloud framework and security "
- Prefix matching: "start of phrase"*

5. Wildcards:
- ? for single character
- * for any number of characters
- Example: sec?rity cloud*

6. Special Features:
- All docs: *
- Boost terms: term^2.0 (positive numbers only)
- Example: security^2.0 cloud
- the above will boost security by 2.0

Query Examples:
1. Basic: +שבת +חולה +אסור
2. Field-specific: text:סיני AND topics:תנך
3. Phrase with slop: "security framework"~2
4. Complex: +reference:בראשית +text:"הבל"^2.0 +(דמי OR דמים) -הבלים
6. Mixed: (text:"רבנו משה"^2.0 OR reference:"משנה תורה") AND topics:הלכה) AND text:"תורה המלך"~3 AND NOT topics:מדרש

Tips:
- Group complex expressions with parentheses
- Use quotes for exact phrases
- Add + for required terms, - for excluded terms
- Boost important terms with ^N
- use field-specific terms for better results.
- the corpus to search in is an ancient Hebrew corpus: Tora and Talmud. so Try to use ancient Hebrew terms and or Talmudic expressions and prevent modern words that are not common in talmudic texts
""")
    num_results: int = Field(description="the maximum number of results to return. Default: 10", default=10)
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
# Open the shared Tantivy index once at import time; the search tool below
# depends on this module-level `tantivy` instance.
index_path = INDEX_PATH
try:
    tantivy = TantivySearch(index_path)
    # BUG FIX: validate_index() reports failure via its return value rather
    # than raising; previously that result was silently discarded.
    if not tantivy.validate_index():
        raise Exception(f"index at {index_path} failed validation")
except Exception as e:
    # Chain the original error so the traceback shows the root cause.
    raise Exception(f"failed to create index: {e}") from e
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
@tool(args_schema=SearchArgs)
def search(query: str, num_results: int = 10):
    """Searches the index for the given query."""
    # Reduce each hit to the two fields the agent actually consumes.
    hits = tantivy.search(query, num_results)
    return [
        {
            'text': hit.get('text', 'N/A'),
            'reference': hit.get('reference', 'N/A'),
        }
        for hit in hits
    ]
|
98 |
+
|
99 |
+
|
100 |
+
@tool(args_schema=ReadTextArgs)
def read_text(reference: str) -> dict:
    """Retrieves the text for a given reference.
    """
    # BUG FIX: the annotation previously claimed -> str, but the tool
    # returns a dict of {'text', 'reference'}.
    text = sefaria_get_text(reference)
    return {
        'text': str(text),
        'reference': reference
    }
|
109 |
+
|
110 |
+
@tool
def get_commentaries(reference: str, num_results: int = 10) -> dict:
    """Retrieves references to all available commentaries on the given verse."""
    # BUG FIX: annotation previously claimed -> str; the tool returns a dict.
    # NOTE: num_results is currently unused but kept for interface
    # compatibility with callers that pass it.
    commentaries = sefaria_get_commentaries(reference)
    return {
        'text': '\n'.join(commentaries) if isinstance(commentaries, list) else str(commentaries),
        'reference': f"Commentaries on {reference}"
    }
|
118 |
+
|