init!
final.py
ADDED
@@ -0,0 +1,205 @@
# Final app for Hugging Face Spaces

import os
import streamlit as st
import pandas as pd
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from groq import Groq
from mirascope.core import groq
from pydantic import BaseModel

# Set page config
st.set_page_config(page_title="Smart Course Search with Faiss", page_icon="🔍", layout="wide")

# Groq API key: read it from the environment instead of committing it in the source
# (on Hugging Face Spaces, set GROQ_API_KEY as a repository secret)
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")

# Initialize Groq client
groq_client = Groq(api_key=GROQ_API_KEY)

# Initialize Sentence Transformer model
@st.cache_resource
def init_model():
    return SentenceTransformer("all-MiniLM-L6-v2")

model = init_model()

# Load the CSV file
@st.cache_data
def load_data():
    # In Hugging Face Spaces
    csv_path = "Analytics_vidhya_final_data.csv"
    if os.path.exists(csv_path):
        return pd.read_csv(csv_path)
    else:
        st.error(f"CSV file not found at {csv_path}. Please make sure it's uploaded to your Space.")
        return pd.DataFrame()

courses_df = load_data()

# Initialize Faiss index and document store
dimension = 384  # Embedding dimension for "all-MiniLM-L6-v2"

if 'faiss_index' not in st.session_state:
    st.session_state['faiss_index'] = faiss.IndexFlatL2(dimension)
    st.session_state['document_store'] = []

faiss_index = st.session_state['faiss_index']
document_store = st.session_state['document_store']

# Function to prepare course data
def prepare_course_data(df):
    documents = []
    metadatas = []

    for index, row in df.iterrows():
        # Empty CSV cells come back from pandas as NaN floats, so coerce to strings before stripping
        title = "" if pd.isna(row.get('Title')) else str(row.get('Title')).strip()
        curriculum = "" if pd.isna(row.get('Course Curriculum')) else str(row.get('Course Curriculum')).strip()
        description = "" if pd.isna(row.get('Course Description')) else str(row.get('Course Description')).strip()

        if not title or not curriculum or not description:
            continue

        content = f"{title} {curriculum} {description}".strip()
        documents.append(content)

        metadata = {
            "title": title,
            "curriculum": curriculum,
            "description": description,
        }
        metadatas.append(metadata)

    return documents, metadatas
+
# Add courses to Faiss index
|
80 |
+
def add_courses_to_faiss(df):
|
81 |
+
documents, metadatas = prepare_course_data(df)
|
82 |
+
|
83 |
+
if not documents:
|
84 |
+
st.warning("No valid documents to add to the database")
|
85 |
+
return 0
|
86 |
+
|
87 |
+
try:
|
88 |
+
embeddings = model.encode(documents)
|
89 |
+
faiss_index.add(np.array(embeddings, dtype="float32"))
|
90 |
+
document_store.extend(metadatas)
|
91 |
+
return len(documents)
|
92 |
+
except Exception as e:
|
93 |
+
st.error(f"Error adding documents to Faiss: {str(e)}")
|
94 |
+
return 0
|
95 |
+
|

# Faiss search function
def faiss_search(query, k=3):
    if faiss_index.ntotal == 0:
        st.warning("Faiss index is empty. Cannot perform search.")
        return []

    query_embedding = model.encode([query])
    distances, indices = faiss_index.search(np.array(query_embedding, dtype="float32"), k)

    results = []
    for i, idx in enumerate(indices[0]):
        # Faiss pads results with -1 when fewer than k vectors are indexed, so skip those
        if 0 <= idx < len(document_store):
            results.append({
                "content": document_store[idx],
                "metadata": document_store[idx],
                "score": -distances[0][i]  # negated L2 distance, so higher means more relevant
            })

    return results

# Groq search function
def groq_search(user_query):
    prompt = f"""
    You are an AI assistant specializing in data science, machine learning, artificial intelligence, generative AI, data engineering, and data analytics. Your task is to analyze the following user query and determine if it's related to these fields:

    User Query: "{user_query}"

    Please provide a detailed response that includes:
    1. Whether the query is related to the mentioned fields (data science, ML, AI, GenAI, data engineering, or data analytics).
    2. If related, explain how it connects to these fields and suggest potential subtopics or courses that might be relevant.
    3. If not directly related, try to find any indirect connections to the mentioned fields.

    Your response should be informative and help guide a course recommendation system. End your response with a clear YES if the query is related to the mentioned fields, or NO if it's completely unrelated.
    """

    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model="llama3-8b-8192",
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        st.error(f"Error in Groq API call: {str(e)}")
        return "ERROR: Unable to process the query"
+
# Mirascope analysis
|
147 |
+
class SearchResult(BaseModel):
|
148 |
+
final_output: int
|
149 |
+
|
150 |
+
@groq.call("llama-3.1-70b-versatile", response_model=SearchResult)
|
151 |
+
def extract_relevance(text: str) -> str:
|
152 |
+
return f"""Extract the integer from text whether we move forward or not it can be either 0 or 1: {text}"""
|
153 |
+
|

# Streamlit UI
st.title("Smart Course Search System")

# Show Faiss index count
db_count = faiss_index.ntotal
st.write(f"Current number of courses in the Faiss index: {db_count}")

# Add courses to database if not already added
if db_count == 0 and not courses_df.empty:
    added_count = add_courses_to_faiss(courses_df)
    st.success(f"{added_count} courses added to the Faiss index!")
    db_count = faiss_index.ntotal
    st.write(f"Updated number of courses in the Faiss index: {db_count}")

# Search query input
user_query = st.text_input("Enter your search query")

if user_query:
    with st.spinner("Analyzing your query..."):
        groq_response = groq_search(user_query)
        search_result = extract_relevance(groq_response)

    if search_result.final_output == 1:
        st.success("Your query is relevant to our course catalog. Here are the search results:")
        results = faiss_search(user_query)

        if results:
            st.subheader("Search Results")
            for i, result in enumerate(results, 1):
                with st.expander(f"Result {i}: {result['metadata']['title']}"):
                    st.write(f"Relevance Score: {result['score']:.2f}")
                    st.subheader("Course Curriculum")
                    st.write(result['metadata']['curriculum'])
                    st.subheader("Course Description")
                    st.write(result['metadata']['description'])
        else:
            st.info("No specific courses found for your query. Try a different search term.")
    else:
        st.warning("We're sorry, but we couldn't find any courses directly related to your search query.")
        st.write("Our current catalog focuses on data science, machine learning, artificial intelligence, generative AI, data engineering, and data analytics. Please try a different search term related to these fields.")

# Debug information
if st.checkbox("Show Debug Information"):
    st.subheader("Debug Information")
    st.write(f"Database count: {db_count}")
    if db_count > 0:
        st.write("Sample document:")
        if document_store:
            st.json(document_store[0])

if __name__ == "__main__":
    pass
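
For reference, here is a minimal standalone sketch of the same embed-then-search flow the app uses, runnable outside Streamlit to sanity-check that sentence-transformers and Faiss are wired up correctly. The sample documents and query below are illustrative placeholders, not data from the Space.

# sanity_check.py -- illustrative only; not part of final.py
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

model = SentenceTransformer("all-MiniLM-L6-v2")

# Placeholder documents standing in for the course texts in the CSV
docs = [
    "Introduction to Machine Learning with Python",
    "Data engineering pipelines and ETL basics",
    "Generative AI with large language models",
]

# Embed the documents and add them to a flat L2 index (384-dim for all-MiniLM-L6-v2)
embeddings = model.encode(docs)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(np.array(embeddings, dtype="float32"))

# Embed a query and retrieve the two nearest documents
query_embedding = model.encode(["beginner course on neural networks"])
distances, indices = index.search(np.array(query_embedding, dtype="float32"), 2)
for dist, idx in zip(distances[0], indices[0]):
    print(f"{docs[idx]} (L2 distance {dist:.3f})")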