from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from typing import List, Optional

import os
import json
import re

from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import CSVLoader
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_google_genai import ChatGoogleGenerativeAI

load_dotenv()

# Re-export the API keys into the environment, skipping any that are unset so
# a missing key fails at call time instead of raising TypeError on import.
for _var in ("GOOGLE_API_KEY", "GROQ_API_KEY", "OPENAI_API_KEY"):
    _value = os.getenv(_var)
    if _value:
        os.environ[_var] = _value

key = os.getenv("GOOGLE_API_KEY")

DB_FAISS_PATH = "bgi/db_faiss"

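# A minimal .env for local runs might look like this (illustrative values only):
#
#   GOOGLE_API_KEY=your-google-key
#   GROQ_API_KEY=your-groq-key
#   OPENAI_API_KEY=your-openai-key
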
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Populated on startup by load_vector_store().
embeddings = None
db = None

@app.on_event("startup")
def load_vector_store():
    """Load the FAISS index from disk, building it from the CSV on first run."""
    global embeddings, db
    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en", model_kwargs={"device": "cpu"}
    )
    if os.path.exists(DB_FAISS_PATH):
        print("Loading existing FAISS vector store.")
        db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
        print("Vector store loaded.")
    else:
        print("Creating new FAISS vector store.")
        loader = CSVLoader(
            file_path="Final_Research_Dataset_2.csv",
            encoding="utf-8",
            csv_args={"delimiter": ","},
        )
        data = loader.load()
        db = FAISS.from_documents(data, embeddings)
        db.save_local(DB_FAISS_PATH)

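# Based on the regex parsing in the "faiss" branch below, each row of
# Final_Research_Dataset_2.csv is expected to carry at least these columns:
# Name, JIF, Category, Keywords, Publisher, and "Decsion Time" (the
# misspelling appears to match the dataset's own header).
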
class FilterCriteria(BaseModel):
    impactFactor: float
    firstDecisionTime: int
    publisher: Optional[str] = None  # default needed for the field to actually be optional
    llmModel: str


class QueryRequest(BaseModel):
    abstract: str
    criteria: FilterCriteria


class Journal(BaseModel):
    id: int
    Name: str
    JIF: float
    Category: str
    Keywords: str
    Publisher: str
    Decision_Time: int


class QueryResponse(BaseModel):
    result: List[Journal]

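# An illustrative /query request body, using the field names defined above
# (values are made up):
#
#   {
#     "abstract": "We propose a transformer-based approach to ...",
#     "criteria": {
#       "impactFactor": 3.0,
#       "firstDecisionTime": 30,
#       "publisher": "Elsevier",
#       "llmModel": "groq"
#     }
#   }
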
@app.get("/", response_class=PlainTextResponse) |
|
def read_root(): |
|
return "Welcome to the Journal Recommender API!" |
|
|
|
@app.get("/models") |
|
def get_models(): |
|
return {"available_models": ["openai", "groq","mixtral","gemini-pro","faiss"]} |
|
|
|
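# Quick smoke test (hypothetical host/port):
#   curl http://localhost:8000/models
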
def fix_incomplete_json(raw_response):
    """
    Attempt to repair truncated JSON by trimming a trailing comma and appending
    missing closing braces/brackets. Returns the parsed object, or None if the
    string still cannot be decoded.
    """
    if raw_response.endswith("},"):
        raw_response = raw_response[:-1]
    if raw_response.count("{") > raw_response.count("}"):
        raw_response += "}"
    if raw_response.count("[") > raw_response.count("]"):
        raw_response += "]"

    try:
        return json.loads(raw_response)
    except json.JSONDecodeError as e:
        print(f"Error fixing JSON: {e}")
        return None

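# fix_incomplete_json is not currently called; a sketch of how it could back
# up the json.loads calls in the /query branches below:
#
#   json_response = fix_incomplete_json(cleaned_response)
#   if json_response is None:
#       raise HTTPException(status_code=502, detail="Model returned invalid JSON")
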
@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    global db
    if not db:
        raise HTTPException(status_code=500, detail="Vector store not loaded.")

    query_text = request.abstract
    model_choice = request.criteria.llmModel
    impact_factor = request.criteria.impactFactor
    preferred_publisher = request.criteria.publisher

    # First pass: extract exactly 15 keywords from the abstract with a small,
    # deterministic Groq model.
    messages = [
        {
            "role": "system",
            "content": (
                "Give a strict comma-separated list of exactly 15 keywords from the following text. "
                "Do not include any bullet points, introductory text, or ending text. "
                "Do not say anything like 'Here are the keywords.' "
                "Only return the keywords, strictly comma-separated, without any additional words."
            ),
        },
        {"role": "user", "content": query_text},
    ]
    llm = ChatGroq(model="llama3-8b-8192", temperature=0)
    ai_msg = llm.invoke(messages)
    # Defensive split in case the model prefixes the list with boilerplate.
    keywords = ai_msg.content.split("keywords extracted from the text:\n")[-1].strip()
    print("Keywords:", keywords)

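    # With a well-behaved run, `keywords` is a flat comma-separated string,
    # e.g. "machine learning, transformers, attention, ..." (illustrative).
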
    if model_choice == "openai":
        retriever = db.as_retriever()

        system_prompt = (
            f"You are a specialized journal recommender that compares every journal in the database "
            f"against the given research paper keywords and ranks results by JIF and publisher. "
            f"From the provided context, recommend all journals that suit a research paper with these "
            f"keywords: {keywords}. "
            f"Include **every** journal with a Journal Impact Factor (JIF) strictly greater than "
            f"{impact_factor}, and only journals from publishers in this list: {preferred_publisher}. "
            f"Report the JIF exactly as it appears in the context database. "
            f"Include both exact matches and related journals, prioritizing **all** relevant high-JIF "
            f"journals without repetition. "
            f"Present the results as JSON with the fields: Journal Name, Publisher, JIF, Decision Time. "
            f"Do not include any introductory or ending text. Give at most 30 results. "
            "Context: {context}"
        )

        prompt = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("user", "{input}")]
        )

        question_answer_chain = create_stuff_documents_chain(ChatOpenAI(model="gpt-4o"), prompt)
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)

        answer = rag_chain.invoke(
            {"input": f"Keywords: {keywords}, Minimum JIF: {impact_factor}, Publisher list: {preferred_publisher}"}
        )

        raw_response = answer["answer"]
        # Strip a Markdown code fence if the model wrapped its JSON in one.
        cleaned_response = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw_response.strip())

        result = []
        try:
            json_response = json.loads(cleaned_response)
            # Accept either a bare list or an object wrapping it under "journals".
            journals = json_response.get("journals", []) if isinstance(json_response, dict) else json_response

            for i, journal in enumerate(journals):
                try:
                    journal_name = journal.get("Journal Name")
                    publisher = journal.get("Publisher")
                    jif = float(journal.get("JIF", 0))
                    # Accept both spellings; the source dataset misspells the column.
                    decision_time = journal.get("Decision Time", journal.get("Decsion Time", 0))

                    if jif > impact_factor:
                        result.append(
                            Journal(
                                id=i + 1,
                                Name=journal_name,
                                Publisher=publisher,
                                JIF=jif,
                                Category="",
                                Keywords=keywords,
                                Decision_Time=decision_time,
                            )
                        )
                except Exception as e:
                    print(f"Error processing journal data: {e}")
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON response: {e}")
            result = []

        return QueryResponse(result=result)

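    # The parser above accepts model output shaped either as a bare list,
    #   [{"Journal Name": "...", "Publisher": "...", "JIF": 5.2, "Decision Time": 25}, ...]
    # or as an object wrapping that list under a "journals" key (illustrative).
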
elif model_choice == "groq": |
|
retriever = db.as_retriever() |
|
|
|
|
|
system_prompt = ( |
|
f"You are a specialized Journal recommender that compares all journals in database to given research paper keywords and based on JIF and publisher gives result." |
|
f"From the provided context, recommend all journals that are suitable for research paper with {keywords} keywords." |
|
f"Ensure that you include **every** journal with a Journal Impact Factor (JIF) strictly greater than {impact_factor}, and the Journal must be only from any Publishers in list: {preferred_publisher}. And Pls show that jif as in Context database " |
|
f"Make sure to include both exact matches and related journals, and prioritize including **all relevant high-JIF journals without repetition**. " |
|
f"Present the results in a json format with the following information: Journal Name, Publisher, JIF, Decsion Time. " |
|
f"Ensure no introductory or ending texts are included. Dont give more than 10 results" |
|
"Context: {context}" |
|
) |
|
|
|
|
|
|
|
prompt = ChatPromptTemplate.from_messages( |
|
[("system", system_prompt), ("user", "{input}")] |
|
) |
|
|
|
|
|
async def create_chain(): |
|
client = ChatGroq(model="llama-3.2-3b-preview", temperature=0) |
|
return create_stuff_documents_chain(client, prompt) |
|
|
|
|
|
question_answer_chain = await create_chain() |
|
rag_chain = create_retrieval_chain(retriever, question_answer_chain) |
|
|
|
|
|
|
|
|
|
answer = rag_chain.invoke( |
|
{"input": f"Keywords: {keywords}, Minimum JIF: {impact_factor},Publisher list: {preferred_publisher}"} |
|
) |
|
|
|
|
|
result = [] |
|
raw_response = answer['answer'] |
|
|
|
cleaned_response = raw_response.strip('```json\n').strip('```').strip() |
|
|
|
|
|
try: |
|
|
|
print("Cleaned Response:", cleaned_response) |
|
json_response = json.loads(cleaned_response) |
|
|
|
|
|
result = [] |
|
|
|
|
|
for i, journal in enumerate(json_response["journals"]): |
|
print("Journal entry:", journal) |
|
|
|
try: |
|
if isinstance(journal, dict): |
|
journal_name = journal.get('Journal Name') |
|
publisher = journal.get('Publisher') |
|
jif = float(journal.get('JIF', 0)) |
|
decision_time = journal.get('Decision Time', 0) |
|
|
|
|
|
if jif > impact_factor: |
|
result.append( |
|
Journal( |
|
id=i + 1, |
|
Name=journal_name, |
|
Publisher=publisher, |
|
JIF=jif, |
|
Category="", |
|
Keywords=keywords, |
|
Decision_Time=decision_time, |
|
) |
|
) |
|
else: |
|
print(f"Skipping invalid journal entry: {journal}") |
|
except Exception as e: |
|
print(f"Error processing journal data: {e}") |
|
|
|
except json.JSONDecodeError as e: |
|
print(f"Error parsing JSON response: {e}") |
|
result = [] |
|
|
|
|
|
return QueryResponse(result=result) |
|
|
|
|
|
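    # The "groq", "mixtral", and "gemini-pro" branches repeat the "openai" flow
    # with a different chat client; a possible refactor (sketch, not wired in):
    #
    #   def build_rag_chain(client, prompt):
    #       return create_retrieval_chain(
    #           db.as_retriever(), create_stuff_documents_chain(client, prompt)
    #       )
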
elif model_choice == "mixtral": |
|
retriever = db.as_retriever() |
|
|
|
|
|
system_prompt = ( |
|
f"You are a specialized Journal recommender that compares all journals in database to given research paper keywords and based on JIF and publisher gives result." |
|
f"From the provided context, recommend all journals that are suitable for research paper with {keywords} keywords." |
|
f"Ensure that you include **every** journal with a Journal Impact Factor (JIF) strictly greater than {impact_factor}, and the Journal must be only from any Publishers in list: {preferred_publisher}. And Pls show that jif as in Context database " |
|
f"Make sure to include both exact matches and related journals, and prioritize including **all relevant high-JIF journals without repetition**. " |
|
f"Present the results in a json format with the following information: Journal Name, Publisher, JIF, Decsion Time. " |
|
f"Ensure no introductory or ending texts are included. Dont give more than 10 results" |
|
"Context: {context}" |
|
) |
|
|
|
prompt = ChatPromptTemplate.from_messages( |
|
[("system", system_prompt), ("user", "{input}")] |
|
) |
|
|
|
|
|
|
|
|
|
async def create_chain(): |
|
client = ChatGroq(model="mixtral-8x7b-32768",temperature=0) |
|
return create_stuff_documents_chain(client, prompt) |
|
|
|
|
|
question_answer_chain = await create_chain() |
|
rag_chain = create_retrieval_chain(retriever, question_answer_chain) |
|
|
|
|
|
|
|
|
|
answer = rag_chain.invoke( |
|
{"input": f"Keywords: {keywords}, Minimum JIF: {impact_factor},Publisher list: {preferred_publisher}"} |
|
) |
|
|
|
|
|
result = [] |
|
raw_response = answer['answer'] |
|
|
|
cleaned_response = raw_response.strip('```json\n').strip('```').strip() |
|
|
|
|
|
try: |
|
|
|
print("Cleaned Response:", cleaned_response) |
|
json_response = json.loads(cleaned_response) |
|
|
|
|
|
result = [] |
|
|
|
|
|
for i, journal in enumerate(json_response): |
|
print("Journal entry:", journal) |
|
|
|
try: |
|
if isinstance(journal, dict): |
|
journal_name = journal.get('Journal Name') |
|
publisher = journal.get('Publisher') |
|
jif = float(journal.get('JIF', 0)) |
|
decision_time = journal.get('Decsion Time', 0) |
|
|
|
|
|
if jif > impact_factor: |
|
result.append( |
|
Journal( |
|
id=i + 1, |
|
Name=journal_name, |
|
Publisher=publisher, |
|
JIF=jif, |
|
Category="", |
|
Keywords=keywords, |
|
Decision_Time=decision_time, |
|
) |
|
) |
|
else: |
|
print(f"Skipping invalid journal entry: {journal}") |
|
except Exception as e: |
|
print(f"Error processing journal data: {e}") |
|
|
|
except json.JSONDecodeError as e: |
|
print(f"Error parsing JSON response: {e}") |
|
result = [] |
|
|
|
|
|
return QueryResponse(result=result) |
|
|
|
elif model_choice == "gemini-pro": |
|
print("Using Gemini-Pro model") |
|
retriever = db.as_retriever() |
|
|
|
|
|
system_prompt = ( |
|
f"You are a specialized Journal recommender that compares all journals in database to given research paper keywords and based on JIF and publisher gives result." |
|
f"From the provided context, recommend all journals that are suitable for research paper with {keywords} keywords." |
|
f"Ensure that you include **every** journal with a Journal Impact Factor (JIF) strictly greater than {impact_factor}, and the Journal must be only from any Publishers in list: {preferred_publisher}. And Pls show that jif as in Context database " |
|
f"Make sure to include both exact matches and related journals, and prioritize including **all relevant high-JIF journals without repetition**. " |
|
f"Present the results in a json format with the following information: Journal Name, Publisher, JIF, Decsion Time. " |
|
f"Ensure no introductory or ending texts are included." |
|
"Context: {context}" |
|
) |
|
|
|
prompt = ChatPromptTemplate.from_messages( |
|
[("system", system_prompt), ("user", "{input}")] |
|
) |
|
|
|
async def create_chain(): |
|
client = ChatGoogleGenerativeAI( |
|
model="gemini-pro", |
|
google_api_key=key, |
|
convert_system_message_to_human=True, |
|
) |
|
return create_stuff_documents_chain(client, prompt) |
|
|
|
|
|
question_answer_chain = await create_chain() |
|
rag_chain = create_retrieval_chain(retriever, question_answer_chain) |
|
|
|
|
|
|
|
|
|
|
|
answer = rag_chain.invoke( |
|
{"input": f"Keywords: {keywords}, Minimum JIF: {impact_factor},Publisher list: {preferred_publisher}"} |
|
) |
|
|
|
|
|
result = [] |
|
raw_response = answer['answer'] |
|
cleaned_response = raw_response.strip('```json\n').strip('```').strip() |
|
|
|
|
|
try: |
|
json_response = json.loads(cleaned_response) |
|
|
|
|
|
result = [] |
|
|
|
|
|
for i, journal in enumerate(json_response): |
|
try: |
|
journal_name = journal.get('Journal Name') |
|
publisher = journal.get('Publisher') |
|
jif = float(journal.get('JIF', 0)) |
|
decision_time = journal.get('Decsion Time', 0) |
|
|
|
|
|
if jif > impact_factor: |
|
result.append( |
|
Journal( |
|
id=i + 1, |
|
Name=journal_name, |
|
Publisher=publisher, |
|
JIF=jif, |
|
Category="", |
|
Keywords=keywords, |
|
Decision_Time=decision_time, |
|
) |
|
) |
|
except Exception as e: |
|
print(f"Error processing journal data: {e}") |
|
|
|
except json.JSONDecodeError as e: |
|
print(f"Error parsing JSON response: {e}") |
|
result = [] |
|
|
|
|
|
return QueryResponse(result=result) |
|
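    # The "faiss" branch below skips the LLM entirely and parses raw CSVLoader
    # output. One retrieved page_content entry looks roughly like this
    # (illustrative values):
    #
    #   Name: Journal of Example Studies
    #   JIF: 4.1
    #   Category: Computer Science
    #   Keywords: machine learning; nlp
    #   Publisher: Springer
    #   Decsion Time: 30
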
elif model_choice == "faiss": |
|
embeddings = HuggingFaceEmbeddings( |
|
model_name="BAAI/bge-small-en", model_kwargs={"device": "cpu"} |
|
) |
|
jif = impact_factor |
|
publisher = preferred_publisher.split() if preferred_publisher else [] |
|
|
|
|
|
db1 = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True) |
|
|
|
|
|
query_embedding = embeddings.embed_query(keywords) |
|
|
|
|
|
results = db1.similarity_search_by_vector(query_embedding, k=20) |
|
|
|
|
|
context = "\n\n".join(doc.page_content for doc in results) |
|
|
|
|
|
min_jif = jif |
|
valid_publishers = publisher if publisher != ["no preference"] else None |
|
|
|
|
|
entries = re.split(r"\n(?=Name:)", context.strip()) |
|
|
|
|
|
journal_list = [] |
|
|
|
|
|
for entry in entries: |
|
|
|
name = re.search(r"Name: (.+)", entry) |
|
jif_match = re.search(r"JIF: (.+)", entry) |
|
category = re.search(r"Category: (.+)", entry) |
|
keywords_match = re.search(r"Keywords: (.+)", entry) |
|
publisher_match = re.search(r"Publisher: (.+)", entry) |
|
first_decision_match = re.search(r"Decsion Time: (.+)", entry) |
|
|
|
if jif_match : |
|
|
|
name_value = name.group(1).strip() |
|
jif_value = float(jif_match.group(1).strip()) |
|
category_value = category.group(1).strip() |
|
keywords_value = keywords_match.group(1).strip() |
|
publisher_value = publisher_match.group(1).strip() |
|
decision_time = first_decision_match.group(1).strip() |
|
|
|
if jif_value >= min_jif : |
|
|
|
if publisher and (publisher_value in publisher): |
|
print("inside pubisher match") |
|
journal = Journal( |
|
id=len(journal_list) + 1, |
|
Name=name_value, |
|
JIF=jif_value, |
|
Category=category_value, |
|
Keywords=keywords_value, |
|
Publisher=publisher_value, |
|
Decision_Time=decision_time |
|
) |
|
journal_list.append(journal) |
|
elif not publisher: |
|
journal = Journal( |
|
id=len(journal_list) + 1, |
|
Name=name_value, |
|
JIF=jif_value, |
|
Category=category_value, |
|
Keywords=keywords_value, |
|
Publisher=publisher_value, |
|
Decision_Time=decision_time |
|
) |
|
journal_list.append(journal) |
|
|
|
|
|
|
|
|
|
return {"result": [journal.dict() for journal in journal_list]} |
|
    else:
        raise HTTPException(status_code=400, detail="Invalid model choice.")
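

# Run locally (the module name "main" is hypothetical; substitute this file's name):
#   uvicorn main:app --reload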