from flask import Flask, render_template, request, jsonify, redirect, url_for, flash, session
from flask_login import LoginManager, UserMixin, login_user, login_required, logout_user, current_user
from flask_wtf.csrf import CSRFProtect
from flask_wtf import FlaskForm
from wtforms import StringField, PasswordField, SubmitField
from wtforms.validators import DataRequired
from werkzeug.security import generate_password_hash, check_password_hash
import arxiv
import requests
import PyPDF2
from io import BytesIO
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain.memory import ConversationBufferMemory
from langchain_community.embeddings import HuggingFaceEmbeddings
import numpy as np
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from functools import lru_cache
import time
import os
from dotenv import load_dotenv
import json
from datetime import datetime
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import text
from config import Config

# Load environment variables
load_dotenv()
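
# The keys expected in .env are not spelled out in this file; judging from the code
# below, at least GROQ_API_KEY is required, and config.Config presumably supplies the
# Flask settings such as SECRET_KEY and the SQLALCHEMY_DATABASE_URI for the PostgreSQL
# connection (assumption).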

# Initialize Flask extensions
db = SQLAlchemy()
login_manager = LoginManager()

def create_app():
    app = Flask(__name__)
    app.config.from_object(Config)

    # Initialize extensions
    db.init_app(app)
    login_manager.init_app(app)
    login_manager.login_view = 'login'

    with app.app_context():
        # Import routes after db initialization
        from routes import init_routes
        init_routes(app)

        # Create database tables
        db.create_all()

        # Test database connection (raw SQL must be wrapped in text() for SQLAlchemy 2.x)
        try:
            version = db.session.execute(text('SELECT VERSION()')).scalar()
            print(f"Connected to PostgreSQL: {version}")
        except Exception as e:
            print(f"Database connection error: {str(e)}")
            raise e

    return app

# Create the application instance so the module-level routes below can register against it
app = create_app()

# Initialize CSRF protection
csrf = CSRFProtect()
csrf.init_app(app)
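
# Note: with CSRFProtect enabled, the POST routes below (including the JSON endpoints
# /search, /perform-rag and /chat-with-paper) are rejected with a 400 unless the client
# sends the CSRF token, e.g. in an X-CSRFToken header. How the frontend supplies it is
# not shown in this file.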

# Initialize Groq
groq_api_key = os.getenv('GROQ_API_KEY')
llm = ChatGroq(
    temperature=0.1,
    groq_api_key=groq_api_key,
    model_name="mixtral-8x7b-32768"
)

# Initialize embeddings
embeddings_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
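
# all-MiniLM-L6-v2 encodes each text into a 384-dimensional vector; embed_documents()
# and embed_query() return plain Python lists of floats, which is why np.dot and
# np.linalg.norm are used directly for cosine similarity in /chat-with-paper below.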

# Constants
MAX_CHUNKS = 50             # cap on text chunks kept per paper
MAX_RESPONSE_LENGTH = 4000  # cap on characters kept from each LLM response
CACHE_DURATION = 3600       # 1 hour in seconds

# Form Classes
class LoginForm(FlaskForm):
    username = StringField('Username', validators=[DataRequired()])
    password = PasswordField('Password', validators=[DataRequired()])
    submit = SubmitField('Login')


class RegisterForm(FlaskForm):
    username = StringField('Username', validators=[DataRequired()])
    password = PasswordField('Password', validators=[DataRequired()])
    submit = SubmitField('Register')

# User class
class User(UserMixin):
    def __init__(self, user_id, username):
        self.id = user_id
        self.username = username

    @staticmethod
    def get(user_id):
        users = load_users()
        user_data = users.get(str(user_id))
        if user_data:
            return User(user_id=user_data['id'], username=user_data['username'])
        return None

# User management functions
def load_users():
    try:
        with open('users.json', 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def save_users(users):
    with open('users.json', 'w') as f:
        json.dump(users, f)


@login_manager.user_loader
def load_user(user_id):
    return User.get(user_id)
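
# Note: user accounts live in a flat users.json file rather than in the SQLAlchemy
# database initialised above, and writes are not locked, so this storage is only
# suitable for a single-process demo deployment.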

# PDF Processing and Analysis
def process_pdf(pdf_url):
    try:
        print(f"Starting PDF processing for: {pdf_url}")
        response = requests.get(pdf_url, timeout=30)
        response.raise_for_status()

        pdf_file = BytesIO(response.content)
        pdf_reader = PyPDF2.PdfReader(pdf_file)

        # Clean and normalize the text (extract_text() may return None for image-only pages)
        text_content = " ".join(
            (page.extract_text() or "").encode('ascii', 'ignore').decode('ascii')
            for page in pdf_reader.pages
        )
        if not text_content.strip():
            return {'error': 'No text could be extracted from the PDF'}

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = text_splitter.split_text(text_content)[:MAX_CHUNKS]
        analysis = generate_analysis(chunks)

        return {
            'success': True,
            'analysis': analysis
        }
    except Exception as e:
        return {'error': f"PDF processing failed: {str(e)}"}

def generate_analysis(chunks):
    analysis_prompts = {
        'executive_summary': "Provide a concise executive summary of this research paper.",
        'problem_analysis': "What is the main research problem and objectives?",
        'methodology': "Describe the key methodology and approach.",
        'findings': "What are the main findings and conclusions?",
        'contributions': "What are the key contributions of this work?"
    }
    analysis_results = {}
    for aspect, prompt in analysis_prompts.items():
        try:
            # Clean and join the chunks
            context = "\n\n".join(
                chunk.encode('ascii', 'ignore').decode('ascii')
                for chunk in chunks[:3]
            )
            response = llm.invoke(
                f"""Based on the following context from a research paper, {prompt}

Context:
{context}

Please provide a clear and specific response."""
            )
            analysis_results[aspect] = response.content[:MAX_RESPONSE_LENGTH]
        except Exception as e:
            analysis_results[aspect] = f"Analysis failed: {str(e)}"
    return analysis_results
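
# Note: each aspect is answered from chunks[:3] only (roughly the first ~3,000
# characters of the paper), so the generated analysis reflects the opening sections
# rather than the full document.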

# Routes
@app.route('/')
@login_required
def index():
    return render_template('index.html')

@app.route('/login', methods=['GET', 'POST'])
def login():
    if current_user.is_authenticated:
        return redirect(url_for('index'))

    form = LoginForm()
    if form.validate_on_submit():
        username = form.username.data
        password = form.password.data

        users = load_users()
        user_found = None
        for user_id, user_data in users.items():
            if user_data['username'] == username:
                user_found = user_data
                break

        if user_found and check_password_hash(user_found['password_hash'], password):
            user = User(user_id=user_found['id'], username=username)
            login_user(user, remember=True)
            return redirect(url_for('index'))

        flash('Invalid username or password')
    return render_template('login.html', form=form)

@app.route('/register', methods=['GET', 'POST'])
def register():
    if current_user.is_authenticated:
        return redirect(url_for('index'))

    form = RegisterForm()
    if form.validate_on_submit():
        username = form.username.data
        password = form.password.data

        users = load_users()
        if any(user['username'] == username for user in users.values()):
            flash('Username already exists')
            return render_template('register.html', form=form)

        user_id = str(len(users) + 1)
        users[user_id] = {
            'id': user_id,
            'username': username,
            'password_hash': generate_password_hash(password)
        }
        save_users(users)

        user = User(user_id=user_id, username=username)
        login_user(user)
        return redirect(url_for('index'))
    return render_template('register.html', form=form)

@app.route('/logout')
@login_required
def logout():
    logout_user()
    return redirect(url_for('login'))

@app.route('/search', methods=['POST'])
@login_required
def search():
    try:
        data = request.get_json()
        paper_name = data.get('paper_name')
        sort_by = data.get('sort_by', 'relevance')
        max_results = data.get('max_results', 10)

        if not paper_name:
            return jsonify({'error': 'No search query provided'}), 400

        # Map sort_by to arxiv.SortCriterion
        sort_mapping = {
            'relevance': arxiv.SortCriterion.Relevance,
            'lastUpdated': arxiv.SortCriterion.LastUpdatedDate,
            'submitted': arxiv.SortCriterion.SubmittedDate
        }
        sort_criterion = sort_mapping.get(sort_by, arxiv.SortCriterion.Relevance)

        # Perform the search (renamed to avoid shadowing this view function)
        search_query = arxiv.Search(
            query=paper_name,
            max_results=max_results,
            sort_by=sort_criterion
        )

        results = []
        for paper in search_query.results():
            results.append({
                'title': paper.title,
                'authors': ', '.join(author.name for author in paper.authors),
                'abstract': paper.summary,
                'pdf_link': paper.pdf_url,
                'arxiv_link': paper.entry_id,
                'published': paper.published.strftime('%Y-%m-%d'),
                'category': paper.primary_category,
                'comment': paper.comment if hasattr(paper, 'comment') else None,
                'doi': paper.doi if hasattr(paper, 'doi') else None
            })
        return jsonify(results)
    except Exception as e:
        print(f"Search error: {str(e)}")
        return jsonify({'error': f'Failed to search papers: {str(e)}'}), 500
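
# Example request body for /search (values illustrative):
#   {"paper_name": "retrieval augmented generation", "sort_by": "lastUpdated", "max_results": 5}
# The response is a JSON list of paper records with title, authors, abstract, pdf_link,
# arxiv_link, published, category, comment and doi fields.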

@app.route('/perform-rag', methods=['POST'])
@login_required
def perform_rag():
    try:
        pdf_url = request.json.get('pdf_url')
        if not pdf_url:
            return jsonify({'error': 'PDF URL is required'}), 400

        result = process_pdf(pdf_url)
        if 'error' in result:
            return jsonify({'error': result['error']}), 500
        return jsonify(result)
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/chat-with-paper', methods=['POST'])
@login_required
def chat_with_paper():
    try:
        pdf_url = request.json.get('pdf_url')
        question = request.json.get('question')
        if not pdf_url or not question:
            return jsonify({'error': 'PDF URL and question are required'}), 400

        # Get PDF text and create chunks
        response = requests.get(pdf_url, timeout=30)
        response.raise_for_status()

        pdf_file = BytesIO(response.content)
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # extract_text() may return None for image-only pages
        paper_text = " ".join((page.extract_text() or "") for page in pdf_reader.pages)
        if not paper_text.strip():
            return jsonify({'error': 'No text could be extracted from the PDF'}), 500

        # Create text chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len
        )
        chunks = text_splitter.split_text(paper_text)[:MAX_CHUNKS]

        # Generate embeddings for chunks
        chunk_embeddings = embeddings_model.embed_documents(chunks)
        # Generate embedding for the question
        question_embedding = embeddings_model.embed_query(question)

        # Find most relevant chunks using cosine similarity
        similarities = []
        for chunk_embedding in chunk_embeddings:
            similarity = np.dot(question_embedding, chunk_embedding) / (
                np.linalg.norm(question_embedding) * np.linalg.norm(chunk_embedding)
            )
            similarities.append(similarity)

        # Get top 3 most relevant chunks
        top_chunk_indices = np.argsort(similarities)[-3:][::-1]
        relevant_chunks = [chunks[i] for i in top_chunk_indices]
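
        # Cosine similarity used above: sim(q, c) = dot(q, c) / (||q|| * ||c||), a value
        # in [-1, 1]; np.argsort(similarities)[-3:][::-1] picks the indices of the three
        # highest-scoring chunks in descending order of relevance to the question.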

        # Construct prompt with relevant context
        context = "\n\n".join(relevant_chunks)
        prompt = f"""Based on the following relevant excerpts from the research paper, please answer this question: {question}

Context from paper:
{context}

Please provide a clear, specific, and accurate response based solely on the information provided in these excerpts. If the answer cannot be fully determined from the given context, please indicate this in your response."""

        # Generate response using Groq
        response = llm.invoke(prompt)

        # Format and return response
        formatted_response = response.content.strip()
        # Add source citations
        source_info = "\n\nThis response is based on specific sections from the paper."

        return jsonify({
            'response': formatted_response + source_info,
            'relevance_scores': [float(similarities[i]) for i in top_chunk_indices]
        })
    except Exception as e:
        print(f"Chat error: {str(e)}")
        return jsonify({'error': f'Failed to process request: {str(e)}'}), 500
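
# Example request to this endpoint (illustrative; assumes a logged-in session cookie in
# cookies.txt and, because CSRFProtect is enabled, a valid X-CSRFToken header):
#   curl -X POST http://127.0.0.1:5000/chat-with-paper \
#        -H "Content-Type: application/json" \
#        -H "X-CSRFToken: <token>" \
#        -b cookies.txt \
#        -d '{"pdf_url": "https://arxiv.org/pdf/1706.03762", "question": "What problem does the paper address?"}'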

if __name__ == '__main__':
    app.run(debug=True)
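
# To run locally (assuming this module is saved as app.py): `python app.py` starts the
# Flask development server with debug=True on http://127.0.0.1:5000 by default.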