leo-bourrel committed on
Commit
5c20978
·
1 Parent(s): e87a6a0

feat: replace postgres with sqlite

Browse files
Files changed (2) hide show
  1. app.py +10 -2
  2. custom_pgvector.py +47 -25
app.py CHANGED
@@ -2,6 +2,7 @@ import json
2
  import os
3
 
4
  import sqlalchemy
 
5
  import streamlit as st
6
  import streamlit.components.v1 as components
7
  from langchain import OpenAI
@@ -9,13 +10,14 @@ from langchain.callbacks import get_openai_callback
9
  from langchain.chains import ConversationalRetrievalChain
10
  from langchain.chains.conversation.memory import ConversationBufferMemory
11
  from langchain.embeddings import GPT4AllEmbeddings
 
12
 
13
  from chat_history import insert_chat_history, insert_chat_history_articles
14
  from css import load_css
15
  from custom_pgvector import CustomPGVector
16
  from message import Message
17
 
18
- CONNECTION_STRING = "postgresql+psycopg2://localhost/sorbobot"
19
 
20
  st.set_page_config(layout="wide")
21
 
@@ -26,10 +28,16 @@ chat_column, doc_column = st.columns([2, 1])
26
 
27
  def connect() -> sqlalchemy.engine.Connection:
28
  engine = sqlalchemy.create_engine(CONNECTION_STRING)
 
 
 
 
 
 
 
29
  conn = engine.connect()
30
  return conn
31
 
32
-
33
  conn = connect()
34
 
35
 
 
2
  import os
3
 
4
  import sqlalchemy
5
+ import sqlite_vss
6
  import streamlit as st
7
  import streamlit.components.v1 as components
8
  from langchain import OpenAI
 
10
  from langchain.chains import ConversationalRetrievalChain
11
  from langchain.chains.conversation.memory import ConversationBufferMemory
12
  from langchain.embeddings import GPT4AllEmbeddings
13
+ from sqlalchemy import event
14
 
15
  from chat_history import insert_chat_history, insert_chat_history_articles
16
  from css import load_css
17
  from custom_pgvector import CustomPGVector
18
  from message import Message
19
 
20
+ CONNECTION_STRING = "sqlite:///data/sorbobot.db"
21
 
22
  st.set_page_config(layout="wide")
23
 
 
28
 
29
  def connect() -> sqlalchemy.engine.Connection:
30
  engine = sqlalchemy.create_engine(CONNECTION_STRING)
31
+
32
+ @event.listens_for(engine, "connect")
33
+ def receive_connect(connection, _):
34
+ connection.enable_load_extension(True)
35
+ sqlite_vss.load(connection)
36
+ connection.enable_load_extension(False)
37
+
38
  conn = engine.connect()
39
  return conn
40
 
 
41
  conn = connect()
42
 
43
 
custom_pgvector.py CHANGED
@@ -4,6 +4,7 @@ import contextlib
4
  import enum
5
  import json
6
  import logging
 
7
  from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type
8
 
9
  import pandas as pd
@@ -348,33 +349,54 @@ class CustomPGVector(VectorStore):
348
  k: int = 4,
349
  ) -> List[Any]:
350
  """Query the collection."""
351
- with Session(self._conn) as session:
352
- results = session.execute(
353
- text(
354
- f"""
 
355
  select
356
- a.id,
357
- a.title,
358
- a.doi,
359
- a.abstract,
360
- string_agg(distinct keyword."name", ',') as keywords,
361
- string_agg(distinct author."name", ',') as authors,
362
- abstract_embedding <-> '{str(embedding)}' as distance
363
- from article a
364
- left join article_keyword ON article_keyword.article_id = a.id
365
- left join keyword on article_keyword.keyword_id = keyword.id
366
- left join article_author ON article_author.article_id = a.id
367
- left join author on author.id = article_author.author_id
368
- where abstract != 'NaN'
369
- GROUP BY a.id
370
- ORDER BY distance
371
- LIMIT {k};
372
- """
373
  )
374
- )
375
- results = results.fetchall()
376
- results = pd.DataFrame(results, columns=["id", "title", "doi", "abstract", "keywords", "authors", "distance"])
377
- results = results.to_dict(orient="records")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  return results
379
 
380
  def similarity_search_by_vector(
 
4
  import enum
5
  import json
6
  import logging
7
+ import struct
8
  from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type
9
 
10
  import pandas as pd
 
349
  k: int = 4,
350
  ) -> List[Any]:
351
  """Query the collection."""
352
+ vector = bytearray(struct.pack("f" * len(embedding), *embedding))
353
+
354
+ cursor = self._conn.execute(
355
+ text("""
356
+ with matches as (
357
  select
358
+ rowid,
359
+ distance
360
+ from vss_article
361
+ where vss_search(
362
+ abstract_embedding,
363
+ :vector
364
+ )
365
+ limit :limit
 
 
 
 
 
 
 
 
 
366
  )
367
+ select
368
+ article.id,
369
+ article.title,
370
+ article.doi,
371
+ article.abstract,
372
+ group_concat(keyword."name", ',') as keywords,
373
+ group_concat(author."name", ',') as authors,
374
+ matches.distance
375
+ from matches
376
+ left join article on matches.rowid = article.rowid
377
+ left join article_keyword ak ON ak.article_id = article.id
378
+ left join keyword on ak.keyword_id = keyword.id
379
+ left join article_author ON article_author.article_id = article.id
380
+ left join author on author.id = article_author.author_id
381
+ group by article.id
382
+ order by distance;
383
+ """),
384
+ {"vector": vector, "limit": k}
385
+ )
386
+ results = cursor.fetchall()
387
+ results = pd.DataFrame(
388
+ results,
389
+ columns=[
390
+ "id",
391
+ "title",
392
+ "doi",
393
+ "abstract",
394
+ "keywords",
395
+ "authors",
396
+ "distance",
397
+ ],
398
+ )
399
+ results = results.to_dict(orient="records")
400
  return results
401
 
402
  def similarity_search_by_vector(