Spaces:
Running
Running
File size: 5,057 Bytes
4e00df7 dfd217b 4e00df7 dfd217b 4e00df7 8a70a7b 4e00df7 033cc04 4e00df7 dfd217b 4e00df7 dfd217b 4e00df7 dfd217b 4e00df7 7b2a1d4 81971c5 3fd401e dfd217b 4e00df7 dfd217b 1e311d8 dfd217b 1f57c51 dfd217b 1ca7761 cae23e1 1ca7761 4e00df7 1ca7761 beae2a2 dfd217b 4e00df7 dfd217b e129c7d dfd217b 033cc04 dfd217b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# main.py
import os
import streamlit as st
import anthropic
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage
# --- Credentials and deployment settings, all read from Streamlit secrets ---
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
# NOTE(review): the OpenAI / Anthropic keys are read but not used anywhere in
# the visible code; kept so a missing secret still fails the same way.
openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key = st.secrets.hf_api_key
# Used by the retriever as the `user` metadata filter value.
username = st.secrets.username

supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)

# Embedding model used by the vector store for similarity search.
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=hf_api_key,
    model_name="BAAI/bge-large-en-v1.5",
)

# Per-session list of {"role", "content"} dicts rendered by the chat UI.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

vector_store = SupabaseVectorStore(
    supabase,
    embeddings,
    query_name='match_documents',
    table_name="documents",
)

# Streamlit re-executes this script on every user interaction, so a memory
# object created directly at module level would be rebuilt empty on each turn
# and the conversational chain would never see earlier questions. Keeping it
# in session state makes it survive reruns for the lifetime of the session.
if 'memory' not in st.session_state:
    st.session_state['memory'] = ConversationBufferMemory(
        memory_key="chat_history",
        input_key='question',
        output_key='answer',
        return_messages=True,
    )
memory = st.session_state['memory']

# Generation settings for the Hugging Face hosted model.
model = "meta-llama/Llama-3.3-70B-Instruct"
temperature = 0.1
max_tokens = 500

# Usage counter shown in the page header, fetched once per script run.
stats = str(get_usage(supabase))
def response_generator(query):
    """Answer *query* using the retrieval-augmented conversational chain.

    Records the request through ``add_usage``, builds a Hugging Face
    text-generation endpoint for the module-level ``model``, and runs a
    ``ConversationalRetrievalChain`` over ``vector_store`` using the
    module-level ``memory``.

    Args:
        query: The user's question as plain text.

    Returns:
        The chain's answer when at least one source document was returned
        alongside it; otherwise a fixed apology message.
    """
    add_usage(supabase, "chat", "prompt" + query, {"model": model, "temperature": temperature})
    logger.info('Using HF model %s', model)

    endpoint_url = "https://api-inference.huggingface.co/models/" + model
    model_kwargs = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "return_full_text": False,
    }
    hf = HuggingFaceEndpoint(
        endpoint_url=endpoint_url,
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        model_kwargs=model_kwargs,
    )

    # Retrieval is restricted to documents whose `user` metadata matches the
    # configured username.
    retriever = vector_store.as_retriever(
        search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
    )
    qa = ConversationalRetrievalChain.from_llm(
        hf,
        retriever=retriever,
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )

    # `invoke` replaces the deprecated direct call `qa({...})`.
    model_response = qa.invoke({"question": query})
    logger.info('Result: %s', model_response["answer"])
    sources = model_response["source_documents"]
    logger.info('Sources: %s', sources)

    # Only surface the generated answer when it is backed by retrieved sources.
    if sources:
        return model_response["answer"]

    # NOTE(review): "[email protected]" looks like a scrape-obfuscated
    # placeholder rather than a real address - confirm against the deployed file.
    return "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."
# Page-level configuration: title, icon, layout and the "..." menu entries.
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:[email protected]",
    },
)

# Header: title, tagline and the running count of answered queries.
st.title("👷♂️ Safety Copilot 🦺")
st.markdown("Chat with your personal safety assistant about any health & safety related queries.")
# st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")
st.markdown(f"_{stats} queries answered!_")

# Make sure the transcript list exists before it is read below.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

# Replay the stored transcript so earlier turns stay visible after a rerun.
for entry in st.session_state.chat_history:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])

# Handle a newly submitted question, if any.
prompt = st.chat_input("Ask a question")
if prompt:
    # Record and echo the user's message.
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate the answer while showing a progress spinner.
    with st.spinner('Safety briefing in progress...'):
        response = response_generator(prompt)

    # Show the assistant's reply, then record it in the transcript.
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.chat_history.append({"role": "assistant", "content": response})
# query = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
# columns = st.columns(2)
# with columns[0]:
# button = st.button("Ask")
# with columns[1]:
# clear_history = st.button("Clear History", type='secondary')
# st.markdown("---\n\n")
# if clear_history:
# # Clear memory in Langchain
# memory.clear()
# st.session_state['chat_history'] = []
# st.experimental_rerun() |