Spaces:
Running
Running
ugmSorcero
commited on
Commit
Β·
39503cb
1
Parent(s):
8d3aacc
Adds linter and fixes linting
Browse files- app.py +1 -3
- core/pipelines.py +14 -4
- core/search_index.py +9 -5
- interface/components.py +23 -13
- interface/pages.py +25 -19
- linter.sh +1 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -5,9 +5,7 @@ st.set_page_config(
|
|
5 |
page_icon="π",
|
6 |
layout="wide",
|
7 |
initial_sidebar_state="expanded",
|
8 |
-
menu_items={
|
9 |
-
'About': "https://github.com/ugm2/neural-search-demo"
|
10 |
-
}
|
11 |
)
|
12 |
|
13 |
from streamlit_option_menu import option_menu
|
|
|
5 |
page_icon="π",
|
6 |
layout="wide",
|
7 |
initial_sidebar_state="expanded",
|
8 |
+
menu_items={"About": "https://github.com/ugm2/neural-search-demo"},
|
|
|
|
|
9 |
)
|
10 |
|
11 |
from streamlit_option_menu import option_menu
|
core/pipelines.py
CHANGED
@@ -9,9 +9,10 @@ from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
|
|
9 |
from haystack.nodes.preprocessor import PreProcessor
|
10 |
import streamlit as st
|
11 |
|
|
|
12 |
@st.cache(allow_output_mutation=True)
|
13 |
def keyword_search(
|
14 |
-
index=
|
15 |
):
|
16 |
document_store = InMemoryDocumentStore(index=index)
|
17 |
keyword_retriever = TfidfRetriever(document_store=(document_store))
|
@@ -31,16 +32,25 @@ def keyword_search(
|
|
31 |
# INDEXING PIPELINE
|
32 |
index_pipeline = Pipeline()
|
33 |
index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
|
34 |
-
index_pipeline.add_node(
|
|
|
|
|
35 |
index_pipeline.add_node(
|
36 |
document_store, name="DocumentStore", inputs=["TfidfRetriever"]
|
37 |
)
|
38 |
|
39 |
return search_pipeline, index_pipeline
|
40 |
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
def dense_passage_retrieval(
|
43 |
-
index=
|
44 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
45 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
46 |
):
|
|
|
9 |
from haystack.nodes.preprocessor import PreProcessor
|
10 |
import streamlit as st
|
11 |
|
12 |
+
|
13 |
@st.cache(allow_output_mutation=True)
|
14 |
def keyword_search(
|
15 |
+
index="documents",
|
16 |
):
|
17 |
document_store = InMemoryDocumentStore(index=index)
|
18 |
keyword_retriever = TfidfRetriever(document_store=(document_store))
|
|
|
32 |
# INDEXING PIPELINE
|
33 |
index_pipeline = Pipeline()
|
34 |
index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
|
35 |
+
index_pipeline.add_node(
|
36 |
+
keyword_retriever, name="TfidfRetriever", inputs=["Preprocessor"]
|
37 |
+
)
|
38 |
index_pipeline.add_node(
|
39 |
document_store, name="DocumentStore", inputs=["TfidfRetriever"]
|
40 |
)
|
41 |
|
42 |
return search_pipeline, index_pipeline
|
43 |
|
44 |
+
|
45 |
+
@st.cache(
|
46 |
+
hash_funcs={
|
47 |
+
tokenizers.Tokenizer: lambda _: None,
|
48 |
+
tokenizers.AddedToken: lambda _: None,
|
49 |
+
},
|
50 |
+
allow_output_mutation=True,
|
51 |
+
)
|
52 |
def dense_passage_retrieval(
|
53 |
+
index="documents",
|
54 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
55 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
56 |
):
|
core/search_index.py
CHANGED
@@ -6,9 +6,9 @@ def format_docs(documents):
|
|
6 |
"""Given a list of documents, format the documents and return the documents and doc ids."""
|
7 |
db_docs: list = []
|
8 |
for doc in documents:
|
9 |
-
doc_id = doc[
|
10 |
db_doc = {
|
11 |
-
"content": doc[
|
12 |
"content_type": "text",
|
13 |
"id": str(uuid.uuid4()),
|
14 |
"meta": {"id": doc_id},
|
@@ -16,11 +16,13 @@ def format_docs(documents):
|
|
16 |
db_docs.append(Document(**db_doc))
|
17 |
return db_docs, [doc.meta["id"] for doc in db_docs]
|
18 |
|
|
|
19 |
def index(documents, pipeline):
|
20 |
documents, doc_ids = format_docs(documents)
|
21 |
pipeline.run(documents=documents)
|
22 |
return doc_ids
|
23 |
|
|
|
24 |
def search(queries, pipeline):
|
25 |
results = []
|
26 |
matches_queries = pipeline.run_batch(queries=queries)
|
@@ -35,10 +37,12 @@ def search(queries, pipeline):
|
|
35 |
"text": res.content,
|
36 |
"score": res.score,
|
37 |
"id": res.meta["id"],
|
38 |
-
"fragment_id": res.id
|
39 |
}
|
40 |
)
|
41 |
if not score_is_empty:
|
42 |
-
query_results = sorted(
|
|
|
|
|
43 |
results.append(query_results)
|
44 |
-
return results
|
|
|
6 |
"""Given a list of documents, format the documents and return the documents and doc ids."""
|
7 |
db_docs: list = []
|
8 |
for doc in documents:
|
9 |
+
doc_id = doc["id"] if doc["id"] is not None else str(uuid.uuid4())
|
10 |
db_doc = {
|
11 |
+
"content": doc["text"],
|
12 |
"content_type": "text",
|
13 |
"id": str(uuid.uuid4()),
|
14 |
"meta": {"id": doc_id},
|
|
|
16 |
db_docs.append(Document(**db_doc))
|
17 |
return db_docs, [doc.meta["id"] for doc in db_docs]
|
18 |
|
19 |
+
|
20 |
def index(documents, pipeline):
|
21 |
documents, doc_ids = format_docs(documents)
|
22 |
pipeline.run(documents=documents)
|
23 |
return doc_ids
|
24 |
|
25 |
+
|
26 |
def search(queries, pipeline):
|
27 |
results = []
|
28 |
matches_queries = pipeline.run_batch(queries=queries)
|
|
|
37 |
"text": res.content,
|
38 |
"score": res.score,
|
39 |
"id": res.meta["id"],
|
40 |
+
"fragment_id": res.id,
|
41 |
}
|
42 |
)
|
43 |
if not score_is_empty:
|
44 |
+
query_results = sorted(
|
45 |
+
query_results, key=lambda x: x["score"], reverse=True
|
46 |
+
)
|
47 |
results.append(query_results)
|
48 |
+
return results
|
interface/components.py
CHANGED
@@ -3,36 +3,47 @@ import core.pipelines as pipelines_functions
|
|
3 |
from inspect import getmembers, isfunction
|
4 |
from networkx.drawing.nx_agraph import to_agraph
|
5 |
|
|
|
6 |
def component_select_pipeline(container):
|
7 |
-
pipeline_names, pipeline_funcs = list(
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
with container:
|
10 |
selected_pipeline = st.selectbox(
|
11 |
-
|
12 |
pipeline_names,
|
13 |
-
index=pipeline_names.index(
|
|
|
|
|
14 |
)
|
15 |
-
|
16 |
-
st.session_state[
|
17 |
-
|
|
|
|
|
18 |
|
19 |
def component_show_pipeline(container, pipeline):
|
20 |
"""Draw the pipeline"""
|
21 |
-
with st.expander(
|
22 |
graphviz = to_agraph(pipeline.graph)
|
23 |
graphviz.layout("dot")
|
24 |
st.graphviz_chart(graphviz.string())
|
25 |
-
|
|
|
26 |
def component_show_search_result(container, results):
|
27 |
with container:
|
28 |
for idx, document in enumerate(results):
|
29 |
st.markdown(f"### Match {idx+1}")
|
30 |
st.markdown(f"**Text**: {document['text']}")
|
31 |
st.markdown(f"**Document**: {document['id']}")
|
32 |
-
if document[
|
33 |
st.markdown(f"**Score**: {document['score']:.3f}")
|
34 |
st.markdown("---")
|
35 |
|
|
|
36 |
def component_text_input(container):
|
37 |
"""Draw the Text Input widget"""
|
38 |
with container:
|
@@ -48,7 +59,6 @@ def component_text_input(container):
|
|
48 |
else:
|
49 |
break
|
50 |
corpus = [
|
51 |
-
{"text": doc["text"], "id": doc_id}
|
52 |
-
for doc_id, doc in enumerate(texts)
|
53 |
]
|
54 |
-
return corpus
|
|
|
3 |
from inspect import getmembers, isfunction
|
4 |
from networkx.drawing.nx_agraph import to_agraph
|
5 |
|
6 |
+
|
7 |
def component_select_pipeline(container):
|
8 |
+
pipeline_names, pipeline_funcs = list(
|
9 |
+
zip(*getmembers(pipelines_functions, isfunction))
|
10 |
+
)
|
11 |
+
pipeline_names = [
|
12 |
+
" ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
|
13 |
+
]
|
14 |
with container:
|
15 |
selected_pipeline = st.selectbox(
|
16 |
+
"Select pipeline",
|
17 |
pipeline_names,
|
18 |
+
index=pipeline_names.index("Keyword Search")
|
19 |
+
if "Keyword Search" in pipeline_names
|
20 |
+
else 0,
|
21 |
)
|
22 |
+
(
|
23 |
+
st.session_state["search_pipeline"],
|
24 |
+
st.session_state["index_pipeline"],
|
25 |
+
) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
|
26 |
+
|
27 |
|
28 |
def component_show_pipeline(container, pipeline):
|
29 |
"""Draw the pipeline"""
|
30 |
+
with st.expander("Show pipeline"):
|
31 |
graphviz = to_agraph(pipeline.graph)
|
32 |
graphviz.layout("dot")
|
33 |
st.graphviz_chart(graphviz.string())
|
34 |
+
|
35 |
+
|
36 |
def component_show_search_result(container, results):
|
37 |
with container:
|
38 |
for idx, document in enumerate(results):
|
39 |
st.markdown(f"### Match {idx+1}")
|
40 |
st.markdown(f"**Text**: {document['text']}")
|
41 |
st.markdown(f"**Document**: {document['id']}")
|
42 |
+
if document["score"] is not None:
|
43 |
st.markdown(f"**Score**: {document['score']:.3f}")
|
44 |
st.markdown("---")
|
45 |
|
46 |
+
|
47 |
def component_text_input(container):
|
48 |
"""Draw the Text Input widget"""
|
49 |
with container:
|
|
|
59 |
else:
|
60 |
break
|
61 |
corpus = [
|
62 |
+
{"text": doc["text"], "id": doc_id} for doc_id, doc in enumerate(texts)
|
|
|
63 |
]
|
64 |
+
return corpus
|
interface/pages.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1 |
import streamlit as st
|
2 |
from streamlit_option_menu import option_menu
|
3 |
from core.search_index import index, search
|
4 |
-
from interface.components import
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
def page_landing_page(container):
|
7 |
with container:
|
@@ -22,33 +27,34 @@ def page_landing_page(container):
|
|
22 |
"\n - Include file/url indexing"
|
23 |
"\n - [Optional] Include text to audio to read responses"
|
24 |
)
|
25 |
-
|
|
|
26 |
def page_search(container):
|
27 |
with container:
|
28 |
st.title("Query me!")
|
29 |
-
|
30 |
## SEARCH ##
|
31 |
query = st.text_input("Query")
|
32 |
-
|
33 |
-
component_show_pipeline(container, st.session_state[
|
34 |
-
|
35 |
if st.button("Search"):
|
36 |
-
st.session_state[
|
37 |
queries=[query],
|
38 |
-
pipeline=st.session_state[
|
39 |
)
|
40 |
-
if
|
41 |
component_show_search_result(
|
42 |
-
container=container,
|
43 |
-
results=st.session_state['search_results'][0]
|
44 |
)
|
45 |
-
|
|
|
46 |
def page_index(container):
|
47 |
with container:
|
48 |
st.title("Index time!")
|
49 |
-
|
50 |
-
component_show_pipeline(container, st.session_state[
|
51 |
-
|
52 |
input_funcs = {
|
53 |
"Raw Text": (component_text_input, "card-text"),
|
54 |
}
|
@@ -60,15 +66,15 @@ def page_index(container):
|
|
60 |
default_index=0,
|
61 |
orientation="horizontal",
|
62 |
)
|
63 |
-
|
64 |
corpus = input_funcs[selected_input][0](container)
|
65 |
-
|
66 |
if len(corpus) > 0:
|
67 |
index_results = None
|
68 |
if st.button("Index"):
|
69 |
index_results = index(
|
70 |
corpus,
|
71 |
-
st.session_state[
|
72 |
)
|
73 |
if index_results:
|
74 |
-
st.write(index_results)
|
|
|
1 |
import streamlit as st
|
2 |
from streamlit_option_menu import option_menu
|
3 |
from core.search_index import index, search
|
4 |
+
from interface.components import (
|
5 |
+
component_show_pipeline,
|
6 |
+
component_show_search_result,
|
7 |
+
component_text_input,
|
8 |
+
)
|
9 |
+
|
10 |
|
11 |
def page_landing_page(container):
|
12 |
with container:
|
|
|
27 |
"\n - Include file/url indexing"
|
28 |
"\n - [Optional] Include text to audio to read responses"
|
29 |
)
|
30 |
+
|
31 |
+
|
32 |
def page_search(container):
|
33 |
with container:
|
34 |
st.title("Query me!")
|
35 |
+
|
36 |
## SEARCH ##
|
37 |
query = st.text_input("Query")
|
38 |
+
|
39 |
+
component_show_pipeline(container, st.session_state["search_pipeline"])
|
40 |
+
|
41 |
if st.button("Search"):
|
42 |
+
st.session_state["search_results"] = search(
|
43 |
queries=[query],
|
44 |
+
pipeline=st.session_state["search_pipeline"],
|
45 |
)
|
46 |
+
if "search_results" in st.session_state:
|
47 |
component_show_search_result(
|
48 |
+
container=container, results=st.session_state["search_results"][0]
|
|
|
49 |
)
|
50 |
+
|
51 |
+
|
52 |
def page_index(container):
|
53 |
with container:
|
54 |
st.title("Index time!")
|
55 |
+
|
56 |
+
component_show_pipeline(container, st.session_state["index_pipeline"])
|
57 |
+
|
58 |
input_funcs = {
|
59 |
"Raw Text": (component_text_input, "card-text"),
|
60 |
}
|
|
|
66 |
default_index=0,
|
67 |
orientation="horizontal",
|
68 |
)
|
69 |
+
|
70 |
corpus = input_funcs[selected_input][0](container)
|
71 |
+
|
72 |
if len(corpus) > 0:
|
73 |
index_results = None
|
74 |
if st.button("Index"):
|
75 |
index_results = index(
|
76 |
corpus,
|
77 |
+
st.session_state["index_pipeline"],
|
78 |
)
|
79 |
if index_results:
|
80 |
+
st.write(index_results)
|
linter.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python -m black app.py interface core
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
streamlit
|
2 |
streamlit_option_menu
|
3 |
farm-haystack
|
4 |
-
pygraphviz
|
|
|
|
1 |
streamlit
|
2 |
streamlit_option_menu
|
3 |
farm-haystack
|
4 |
+
pygraphviz
|
5 |
+
black
|