Spaces:
Runtime error
Runtime error
Asaad Almutareb
committed on
Commit
•
2e6490e
1
Parent(s):
5c0a79d
added sqlite schema and handling
Browse files
innovation_pathfinder_ai/database/db_handler.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sqlmodel import SQLModel, create_engine, Session, select
|
2 |
+
from innovation_pathfinder_ai.database.schema import Sources
|
3 |
+
from innovation_pathfinder_ai.utils.logger import get_console_logger
|
4 |
+
|
5 |
+
sqlite_file_name = "database.db"
|
6 |
+
sqlite_url = f"sqlite:///{sqlite_file_name}"
|
7 |
+
engine = create_engine(sqlite_url, echo=False)
|
8 |
+
|
9 |
+
logger = get_console_logger("db_handler")
|
10 |
+
|
11 |
+
SQLModel.metadata.create_all(engine)
|
12 |
+
|
13 |
+
|
14 |
+
def read_one(hash_id: str):
    """Fetch a single ``Sources`` row by its unique hash id.

    Args:
        hash_id: unique hash identifier of the source record. (The original
            annotation said ``dict``, but the value is compared against the
            ``str`` column ``Sources.hash_id`` — annotation corrected.)

    Returns:
        The matching ``Sources`` instance, or ``None`` when no row matches.
    """
    with Session(engine) as session:
        statement = select(Sources).where(Sources.hash_id == hash_id)
        sources = session.exec(statement).first()
        return sources
|
19 |
+
|
20 |
+
|
21 |
+
def add_one(data: dict):
    """Insert a new ``Sources`` row built from *data*, skipping duplicates.

    Args:
        data: column-name -> value mapping; must contain ``hash_id``.

    Returns:
        The persisted ``Sources`` instance, or ``None`` when a row with the
        same ``hash_id`` already exists.
    """
    with Session(engine) as session:
        # Duplicate guard: hash_id is unique, so bail out early if present.
        duplicate = session.exec(
            select(Sources).where(Sources.hash_id == data.get("hash_id"))
        ).first()
        if duplicate:
            logger.warning(f"Item with hash_id {data.get('hash_id')} already exists")
            return None  # or raise an exception, or handle as needed
        record = Sources(**data)
        session.add(record)
        session.commit()
        # Refresh so the returned object carries DB-assigned state (e.g. id).
        session.refresh(record)
        logger.info(f"Item with hash_id {data.get('hash_id')} added to the database")
        return record
|
34 |
+
|
35 |
+
|
36 |
+
def update_one(hash_id: str, data: dict):
    """Update fields of an existing ``Sources`` row identified by *hash_id*.

    Args:
        hash_id: unique hash identifier of the record to update. (Annotation
            corrected from ``dict``: the value is compared to the ``str``
            column ``Sources.hash_id``.)
        data: attribute-name -> new-value mapping applied with ``setattr``.

    Returns:
        The updated ``Sources`` instance, or ``None`` when no row matches.
    """
    with Session(engine) as session:
        # Check if the item with the given hash_id exists
        sources = session.exec(
            select(Sources).where(Sources.hash_id == hash_id)
        ).first()
        if not sources:
            logger.warning(f"No item with hash_id {hash_id} found for update")
            return None  # or raise an exception, or handle as needed
        for key, value in data.items():
            setattr(sources, key, value)
        session.commit()
        # Refresh before the session closes (consistent with add_one) so the
        # returned object's attributes are not expired after commit.
        session.refresh(sources)
        logger.info(f"Item with hash_id {hash_id} updated in the database")
        return sources
|
50 |
+
|
51 |
+
|
52 |
+
def delete_one(id: int):
    """Delete the ``Sources`` row whose ``hash_id`` equals *id*.

    NOTE(review): the parameter is annotated ``int`` but is matched against
    the ``str`` column ``Sources.hash_id`` — looks like a type mismatch;
    confirm what callers actually pass.

    Returns:
        The deleted ``Sources`` instance on success, or ``None`` when no
        row matches. (Fix: the original fell off the end and returned
        ``None`` unconditionally, so delete_many's ``result is None`` check
        always logged the "not found" warning even on success.)
    """
    with Session(engine) as session:
        # Check if the item with the given hash_id exists
        sources = session.exec(
            select(Sources).where(Sources.hash_id == id)
        ).first()
        if not sources:
            logger.warning(f"No item with hash_id {id} found for deletion")
            return None  # or raise an exception, or handle as needed
        session.delete(sources)
        session.commit()
        logger.info(f"Item with hash_id {id} deleted from the database")
        return sources
|
64 |
+
|
65 |
+
|
66 |
+
def add_many(data: list):
    """Insert multiple source records, one by one.

    Each item is delegated to ``add_one``, which opens its own session and
    commits individually — so the outer ``Session`` the original opened was
    never used, and its trailing ``commit()`` committed nothing. Removed.

    Args:
        data: list of dicts, each acceptable to ``add_one``.
    """
    for info in data:
        # Reuse add_one function for each item
        result = add_one(info)
        if result is None:
            logger.warning(
                f"Item with hash_id {info.get('hash_id')} could not be added"
            )
        else:
            logger.info(
                f"Item with hash_id {info.get('hash_id')} added to the database"
            )
|
80 |
+
|
81 |
+
|
82 |
+
def delete_many(ids: list):
    """Delete multiple source records, one by one.

    Each id is delegated to ``delete_one``, which opens its own session and
    commits individually — the outer ``Session`` the original opened was
    never used, and its trailing ``commit()`` committed nothing. Removed.

    Args:
        ids: list of hash ids to delete.
    """
    for id in ids:
        # Reuse delete_one function for each item
        result = delete_one(id)
        if result is None:
            logger.warning(f"No item with hash_id {id} found for deletion")
        else:
            logger.info(f"Item with hash_id {id} deleted from the database")
|
92 |
+
|
93 |
+
|
94 |
+
def read_all(query: dict = None):
    """Return all ``Sources`` rows, optionally filtered.

    Args:
        query: optional column-name -> value mapping; each pair becomes an
            equality filter, and all filters are AND-ed together.

    Returns:
        A list of matching ``Sources`` instances.
    """
    with Session(engine) as session:
        statement = select(Sources)
        if query:
            filters = [
                getattr(Sources, field) == wanted
                for field, wanted in query.items()
            ]
            statement = statement.where(*filters)
        return session.exec(statement).all()
|
103 |
+
|
104 |
+
|
105 |
+
def delete_all():
    """Delete every row in the sources table.

    Fix: the original ``session.exec(Sources).delete()`` passes a model
    class where SQLModel's ``exec`` expects a statement, which fails at
    runtime. Delete row-by-row using the ``select`` construct this module
    already imports.
    """
    with Session(engine) as session:
        for row in session.exec(select(Sources)).all():
            session.delete(row)
        session.commit()
        logger.info("All items deleted from the database")
|
innovation_pathfinder_ai/database/schema.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sqlmodel import SQLModel, Field
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import datetime
|
5 |
+
|
6 |
+
class Sources(SQLModel, table=True):
    """A single ingested source (paper / web page / wiki entry) tracked by the app."""

    # Auto-increment surrogate primary key.
    id: Optional[int] = Field(default=None, primary_key=True)
    url: str = Field()
    title: Optional[str] = Field(default="NA", unique=False)
    # Unique content hash used by db_handler for lookups and de-duplication.
    hash_id: str = Field(unique=True)
    # Fix: ``default=datetime.datetime.now().timestamp()`` is evaluated ONCE
    # at import time, so every row would share the same timestamp. A
    # default_factory is called per instance instead.
    created_at: float = Field(
        default_factory=lambda: datetime.datetime.now().timestamp()
    )
    summary: str = Field(default="")
    embedded: bool = Field(default=False)

    __table_args__ = {"extend_existing": True}
|
innovation_pathfinder_ai/structured_tools/structured_tools.py
CHANGED
@@ -6,31 +6,32 @@ from langchain_community.utilities import WikipediaAPIWrapper
|
|
6 |
#from langchain.tools import Tool
|
7 |
from langchain_community.utilities import GoogleSearchAPIWrapper
|
8 |
import arxiv
|
9 |
-
|
10 |
# hacky and should be replaced with a database
|
11 |
from innovation_pathfinder_ai.source_container.container import (
|
12 |
all_sources
|
13 |
)
|
14 |
-
from innovation_pathfinder_ai.utils import
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
@tool
|
17 |
def arxiv_search(query: str) -> str:
|
18 |
"""Search arxiv database for scientific research papers and studies. This is your primary information source.
|
19 |
always check it first when you search for information, before using any other tool."""
|
20 |
-
# return "LangChain"
|
21 |
global all_sources
|
22 |
-
arxiv_retriever = ArxivRetriever(load_max_docs=
|
23 |
data = arxiv_retriever.invoke(query)
|
24 |
meta_data = [i.metadata for i in data]
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
# formatted_info = format_info_list(all_sources)
|
32 |
-
|
33 |
-
return meta_data.__str__()
|
34 |
|
35 |
@tool
|
36 |
def get_arxiv_paper(paper_id:str) -> None:
|
@@ -52,17 +53,13 @@ def get_arxiv_paper(paper_id:str) -> None:
|
|
52 |
@tool
|
53 |
def google_search(query: str) -> str:
|
54 |
"""Search Google for additional results when you can't answer questions using arxiv search or wikipedia search."""
|
55 |
-
# return "LangChain"
|
56 |
global all_sources
|
57 |
|
58 |
websearch = GoogleSearchAPIWrapper()
|
59 |
-
search_results:dict = websearch.results(query,
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
# formatted_string = "Title: {title}, link: {link}, snippet: {snippet}".format(**organic_source)
|
64 |
-
cleaner_sources = ["Title: {title}, link: {link}, snippet: {snippet}".format(**i) for i in search_results]
|
65 |
-
|
66 |
all_sources += cleaner_sources
|
67 |
|
68 |
return cleaner_sources.__str__()
|
@@ -75,5 +72,9 @@ def wikipedia_search(query: str) -> str:
|
|
75 |
api_wrapper = WikipediaAPIWrapper()
|
76 |
wikipedia_search = WikipediaQueryRun(api_wrapper=api_wrapper)
|
77 |
wikipedia_results = wikipedia_search.run(query)
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
6 |
#from langchain.tools import Tool
|
7 |
from langchain_community.utilities import GoogleSearchAPIWrapper
|
8 |
import arxiv
|
9 |
+
import ast
|
10 |
# hacky and should be replaced with a database
|
11 |
from innovation_pathfinder_ai.source_container.container import (
|
12 |
all_sources
|
13 |
)
|
14 |
+
from innovation_pathfinder_ai.utils.utils import (
|
15 |
+
parse_list_to_dicts, format_wiki_summaries, format_arxiv_documents, format_search_results
|
16 |
+
)
|
17 |
+
from innovation_pathfinder_ai.database.db_handler import (
|
18 |
+
add_many
|
19 |
+
)
|
20 |
|
21 |
@tool
def arxiv_search(query: str) -> str:
    """Search arxiv database for scientific research papers and studies. This is your primary information source.
    always check it first when you search for information, before using any other tool."""
    global all_sources
    arxiv_retriever = ArxivRetriever(load_max_docs=3)
    data = arxiv_retriever.invoke(query)
    # (Removed unused local: meta_data = [i.metadata for i in data] was
    # computed but never read.)
    # Record citations both in the in-memory source list and in the sqlite
    # sources table.
    formatted_sources = format_arxiv_documents(data)
    all_sources += formatted_sources
    parsed_sources = parse_list_to_dicts(formatted_sources)
    add_many(parsed_sources)

    return data.__str__()
|
|
|
|
|
|
|
35 |
|
36 |
@tool
|
37 |
def get_arxiv_paper(paper_id:str) -> None:
|
|
|
53 |
@tool
def google_search(query: str) -> str:
    """Search Google for additional results when you can't answer questions using arxiv search or wikipedia search."""
    # Fetch the top 3 results, record them as citations (in memory and in
    # the sqlite sources table), then hand the formatted list back as text.
    global all_sources

    engine_wrapper = GoogleSearchAPIWrapper()
    raw_results: dict = engine_wrapper.results(query, 3)
    formatted = format_search_results(raw_results)
    add_many(parse_list_to_dicts(formatted))
    all_sources += formatted

    return str(formatted)
|
|
|
72 |
api_wrapper = WikipediaAPIWrapper()
|
73 |
wikipedia_search = WikipediaQueryRun(api_wrapper=api_wrapper)
|
74 |
wikipedia_results = wikipedia_search.run(query)
|
75 |
+
formatted_summaries = format_wiki_summaries(wikipedia_results)
|
76 |
+
all_sources += formatted_summaries
|
77 |
+
parsed_summaries = parse_list_to_dicts(formatted_summaries)
|
78 |
+
add_many(parsed_summaries)
|
79 |
+
|
80 |
+
return wikipedia_results.__str__()
|