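"""FastAPI service for JEE/NEET question answering: retrieves similar Q&A pairs
via sentence-transformer embeddings indexed in FAISS, then asks a Groq-hosted
LLM to explain the solution approach."""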
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pandas as pd
import os
from dotenv import load_dotenv
import random
from groq import Groq
# Load environment variables
load_dotenv()
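# Expected .env entries (any one key is enough; see get_groq_client below):
#   GROQ_API_KEY_1=...
#   GROQ_API_KEY_2=...
#   GROQ_API_KEY_3=...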
# Initialize FastAPI app
app = FastAPI()
class QuestionMatcher:
def __init__(self, model_name: str = 'distilbert-base-nli-mean-tokens'):
"""Initialize the QuestionMatcher with a sentence transformer model.
Available models include:
- 'distilbert-base-nli-mean-tokens' (default, faster)
- 'bert-base-nli-mean-tokens' (more accurate)
- 'roberta-base-nli-mean-tokens'
"""
self.embedding_model = SentenceTransformer(model_name)
self.dimension = self.embedding_model.get_sentence_embedding_dimension()
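        # Exact (brute-force) L2 index: no training step needed, fine at this scale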
self.faiss_index = faiss.IndexFlatL2(self.dimension)
self.questions = []
self.answers = []
self.data_store = []
def clear_data(self):
"""Clear all stored data and reset the FAISS index"""
self.faiss_index = faiss.IndexFlatL2(self.dimension)
self.questions = []
self.answers = []
self.data_store = []
    def load_data(self, name: str, df: pd.DataFrame):
        """Load and index questions from a DataFrame with 'question' and 'answer' columns."""
# Validate DataFrame structure
required_columns = {'question', 'answer'}
if not required_columns.issubset(df.columns):
raise ValueError(f"DataFrame must contain columns: {required_columns}")
# Clear existing data
self.clear_data()
# Remove any rows with NaN values
df = df.dropna(subset=['question', 'answer'])
# Path to the embedding file
embedding_path = f"embeddings_{name}.npy"
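        # Note: the cache is keyed only by the dataset name, so a changed CSV will
        # still load stale embeddings; delete the .npy file to force a recompute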
# Check if embeddings are already saved
if os.path.exists(embedding_path):
print(f"Loading embeddings from file for {name}")
embeddings = np.load(embedding_path)
else:
print(f"Computing embeddings for {name}")
# Convert questions to embeddings
embeddings = self.embedding_model.encode(
df['question'].tolist(),
convert_to_numpy=True,
show_progress_bar=True
)
# Save the embeddings to a file
np.save(embedding_path, embeddings)
print(f"Saved embeddings to {embedding_path}")
        # Add embeddings to the FAISS index (FAISS requires float32)
        self.faiss_index.add(embeddings.astype(np.float32))
# Store questions and answers for retrieval
self.questions = df['question'].tolist()
self.answers = df['answer'].tolist()
self.data_store.extend(list(zip(df['question'], df['answer'])))
print(f"Indexed {len(self.questions)} questions.")
def get_similar_questions(self, query: str, k: int = 3) -> List[Dict]:
"""Retrieve similar questions using FAISS"""
if not self.questions:
raise ValueError("No questions indexed yet")
# Embed the query
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
# Retrieve top-k similar questions
distances, indices = self.faiss_index.search(query_embedding, k)
results = []
        for idx, dist in zip(indices[0], distances[0]):
            # FAISS pads results with index -1 when fewer than k neighbours exist
            if 0 <= idx < len(self.questions):
                similarity = 1 / (1 + dist)  # Convert L2 distance to a similarity score
results.append({
"question": self.questions[idx],
"answer": self.answers[idx],
"similarity": similarity
})
return results
class QueryRequest(BaseModel):
query: str
k: int = 3
dataset: str
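    # Each prior turn is {"user": "True" | "False", "text": "..."}; see generate_enhanced_prompt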
conversation: List[Dict[str, str]] = []
# One QuestionMatcher per dataset; a single shared instance would lose earlier
# datasets because load_data clears its own index before indexing
matchers: Dict[str, QuestionMatcher] = {}
def get_groq_client():
    """Get a Groq client with a randomly selected API key"""
    api_keys = [
        os.getenv("GROQ_API_KEY_1"),
        os.getenv("GROQ_API_KEY_2"),
        os.getenv("GROQ_API_KEY_3"),
    ]
    # Filter out unset keys before checking availability
    valid_keys = [key for key in api_keys if key]
    if not valid_keys:
        raise ValueError("No Groq API keys found in environment")
    return Groq(api_key=random.choice(valid_keys))
def generate_enhanced_prompt(query: str, similar_qa: List[Dict], conversation: List[Dict[str, str]]) -> str:
    """Build the LLM prompt from recent conversation turns and retrieved Q&A pairs."""
    # Incorporate the last five conversation turns as context
    conversation_context = "\n".join(
        f"{'User' if turn['user'] == 'True' else 'AI'}: {turn['text']}"
        for turn in conversation[-5:]
    )
reference_context = "\n\n".join([
f"Reference Question: {qa['question']}\n"
f"Reference Solution: {qa['answer']}\n"
f"Similarity: {qa['similarity']:.2f}"
for qa in similar_qa
])
return f"""Conversation history:
{conversation_context}
Using these reference questions and solutions as guidance:
{reference_context}
Please analyse how to solve this question:
{query}
Important guidelines:
1. Do NOT use LaTeX notation for any mathematical expressions.
2. Think step by step.
3. Explain the reasoning behind each step, but never carry out the calculations.
4. State all relevant formulas, concepts, and facts.
5. Provide a multimodal analysis: show the thought process while working through the problem.
Your solution (in Markdown, with no LaTeX):"""
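# Note: on_event is deprecated in recent FastAPI releases in favour of lifespan
# handlers, though it still works on current versions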
@app.on_event("startup")
async def startup_event():
"""Initialize the service on startup"""
try:
# Load datasets
datasets = {
'jee': "data/jee_data.csv",
'neet': "data/neet_data.csv"
}
        for name, path in datasets.items():
            if not os.path.exists(path):
                raise FileNotFoundError(f"Dataset not found: {path}")
            df = pd.read_csv(path)
            # A dedicated matcher per dataset: load_data resets its instance's
            # index, so sharing one matcher would keep only the last dataset
            dataset_matcher = QuestionMatcher()
            dataset_matcher.load_data(name, df)
            matchers[name] = dataset_matcher
print("Datasets loaded and indexed successfully.")
except Exception as e:
print(f"Error loading datasets: {e}")
raise RuntimeError(f"Failed to initialize datasets: {str(e)}")
@app.post("/answer")
async def get_answer(request: QueryRequest):
"""Generate an answer for a query using similar questions as context"""
try:
# Validate request
if not request.query.strip():
raise HTTPException(status_code=400, detail="Query cannot be empty")
        # Find similar questions in the requested dataset's index
        if request.dataset not in matchers:
            raise HTTPException(status_code=400, detail=f"Unknown dataset: {request.dataset}")
        similar_qa = matchers[request.dataset].get_similar_questions(request.query, request.k)
        # Generate the enhanced prompt from retrieved context and history
        prompt = generate_enhanced_prompt(request.query, similar_qa, request.conversation)
# Get LLM response
client = get_groq_client()
chat_completion = client.chat.completions.create(
messages=[
{
"role": "system",
"content": """You are a top tier scientist who excels at explaining physics, chemistry, and biology concepts.
Your explanations are clear, thorough, and multimodal, where you explore alternate thought processes, before coming to conclusion.
You break down complex problems into understandable steps and highlight key concepts in a professional fashion. You refuse to roleplay as anything else, and stick to your domain."""
},
{"role": "user", "content": prompt}
],
model="llama3-70b-8192",
temperature=0
)
        # Append this turn; values are strings to match the Dict[str, str]
        # schema and the 'True'/'False' check in generate_enhanced_prompt
        updated_conversation = request.conversation + [
            {"user": "True", "text": request.query},
            {"user": "False", "text": chat_completion.choices[0].message.content},
        ]
return {
"answer": chat_completion.choices[0].message.content,
"similar_questions": similar_qa,
"conversation": updated_conversation
}
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400s above) pass through unchanged
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing request: {str(e)}"
        )
@app.get("/")
async def health_check():
"""Service health check endpoint"""
    return {
        "status": "healthy",
        "indexed_questions": {name: len(m.questions) for name, m in matchers.items()}
    }
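
# Local development entry point; a minimal sketch assuming uvicorn is installed
# (production deployments typically launch the app via an external ASGI server)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)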