|
from fastapi import FastAPI, HTTPException, Query
|
|
from pydantic import BaseModel
|
|
import os
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
from langchain_community.vectorstores import FAISS
|
|
from langchain_community.document_loaders import CSVLoader
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_groq import ChatGroq
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
|
from langchain.chains import create_retrieval_chain
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
from dotenv import load_dotenv
|
|
from fastapi.responses import PlainTextResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import asyncio
|
|
import json
|
|
import re
|
|
|
|
load_dotenv()

# Propagate provider API keys from .env into the process environment.
# os.getenv returns None for a missing key, and assigning None to os.environ
# raises TypeError — so only set the ones that are actually present.
for _key_name in ("GOOGLE_API_KEY", "GROQ_API_KEY", "OPENAI_API_KEY"):
    _key_value = os.getenv(_key_name)
    if _key_value is not None:
        os.environ[_key_name] = _key_value

# Kept as a module-level value because the gemini-pro branch passes it explicitly.
key = os.getenv("GOOGLE_API_KEY")

# On-disk location of the FAISS index built from the journal CSV.
DB_FAISS_PATH = "bgi/db_faiss"
|
|
|
|
|
|
app = FastAPI()

# Wide-open CORS so any browser frontend can call the API during development.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
|
# Module-level handles populated once by load_vector_store() at startup;
# both stay None until that event fires.
embeddings = None
db = None
|
|
|
|
|
|
@app.on_event("startup")
def load_vector_store():
    """Load the FAISS vector store at startup, building it from the CSV on first run.

    Populates the module-level ``embeddings`` and ``db`` globals that /query uses.
    """
    global embeddings, db
    # Both branches need the same embedding model, so construct it exactly once
    # (the old code duplicated this call in each branch).
    embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-small-en', model_kwargs={'device': 'cpu'})
    if os.path.exists(DB_FAISS_PATH):
        print("Loading existing FAISS vector store.")
        # allow_dangerous_deserialization: the index is produced locally by this
        # app (see the else branch), so unpickling it is trusted.
        db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
        print("Vector store loaded.")
    else:
        print("Creating new FAISS vector store.")
        loader = CSVLoader(file_path="Final_Research_Dataset_2.csv", encoding="utf-8", csv_args={'delimiter': ','})
        data = loader.load()
        db = FAISS.from_documents(data, embeddings)
        db.save_local(DB_FAISS_PATH)
|
|
|
|
|
|
|
|
from typing import List, Optional
|
|
|
|
class FilterCriteria(BaseModel):
    """Filter settings sent alongside a query abstract."""

    impactFactor: float  # minimum JIF; LLM branches keep journals strictly above this
    firstDecisionTime: int  # NOTE(review): not read anywhere in /query — confirm intent
    # NOTE(review): /query compares this value against ["no preference"] and
    # formats it as a "Publisher list", which suggests the frontend actually
    # sends a list of strings — confirm against the caller before tightening.
    # Default added so the field stays optional under Pydantic v2 as well.
    publisher: Optional[str] = None
    llmModel: str  # one of the ids returned by GET /models
|
|
|
|
class QueryRequest(BaseModel):
    # Request body of POST /query.
    abstract: str  # research-paper abstract that keywords are extracted from
    criteria: FilterCriteria  # filtering and model-selection options
|
|
|
|
class Journal(BaseModel):
    # One recommended journal row in a /query response.
    id: int  # 1-based position; LLM branches use the raw answer index, so gaps appear where entries are filtered out
    Name: str
    JIF: float  # Journal Impact Factor as reported by the LLM / dataset
    Category: str  # empty string in the LLM-backed branches; filled only by the faiss branch
    Keywords: str
    Publisher: str
    # NOTE(review): the faiss branch assigns a regex-captured *string* here and
    # relies on pydantic coercion to int — confirm the dataset values are integral.
    Decision_Time: int
|
|
|
|
|
|
class QueryResponse(BaseModel):
    # Envelope for /query responses.
    result: List[Journal]  # empty when JSON parsing fails or nothing passes the JIF filter
|
|
|
|
|
|
@app.get("/", response_class=PlainTextResponse)
def read_root():
    """Plain-text landing message for the API root."""
    greeting = "Welcome to the Journal Recommender API!"
    return greeting
|
|
|
|
@app.get("/models")
def get_models():
    """List the model identifiers accepted in FilterCriteria.llmModel."""
    available = ["openai", "groq", "mixtral", "gemini-pro", "faiss"]
    return {"available_models": available}
|
|
|
|
def fix_incomplete_json(raw_response):
    """Best-effort repair of JSON truncated mid-stream by an LLM.

    Strips a dangling trailing comma and appends as many closing braces and
    brackets as are needed to balance the text (the old code appended at most
    one of each, so doubly-nested truncations stayed broken).

    Args:
        raw_response: the possibly-incomplete JSON text.

    Returns:
        The parsed object on success, or None if the text is not fixable.
    """
    candidate = raw_response.strip()

    # A truncation right after an element often leaves "...}," or "...],".
    if candidate.endswith(","):
        candidate = candidate[:-1]

    # Close every unbalanced brace, then every unbalanced bracket.  Braces
    # first: truncation cuts innermost (object) scopes before outer lists.
    while candidate.count("{") > candidate.count("}"):
        candidate += "}"
    while candidate.count("[") > candidate.count("]"):
        candidate += "]"

    try:
        return json.loads(candidate)
    except json.JSONDecodeError as e:
        print(f"Error fixing JSON: {e}")
        return None
|
|
|
|
|
|
|
|
@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """Recommend journals for a paper abstract.

    Extracts keywords from the abstract with a small Groq model, then — based on
    ``criteria.llmModel`` — either runs a retrieval-augmented chain over the FAISS
    store with the chosen chat model, or (for "faiss") filters raw similarity-search
    hits with regexes and no second LLM call.

    Raises:
        HTTPException 500: vector store not loaded at startup.
        HTTPException 400: unknown model choice.
    """
    global db
    if not db:
        raise HTTPException(status_code=500, detail="Vector store not loaded.")

    query_text = request.abstract
    model_choice = request.criteria.llmModel
    impact_factor = request.criteria.impactFactor
    preferred_publisher = request.criteria.publisher

    keywords = _extract_keywords(query_text)
    print("Keywords:", keywords)

    # Per-model dispatch.  ``decision_key`` preserves each branch's historical JSON
    # field name (the "Decsion Time" typo matches the prompt wording, so the LLM
    # echoes it back); groq answers arrive wrapped in {"journals": [...]}.
    if model_choice == "openai":
        client = ChatOpenAI(model="gpt-4o")
        result = _run_rag_recommendation(
            client, keywords, impact_factor, preferred_publisher,
            limit_text="Give max 30 results", decision_key="Decsion Time",
        )
        return QueryResponse(result=result)

    if model_choice == "groq":
        client = ChatGroq(model="llama-3.2-3b-preview", temperature=0)
        result = _run_rag_recommendation(
            client, keywords, impact_factor, preferred_publisher,
            limit_text="Dont give more than 10 results",
            decision_key="Decision Time", unwrap_key="journals",
        )
        return QueryResponse(result=result)

    if model_choice == "mixtral":
        client = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
        result = _run_rag_recommendation(
            client, keywords, impact_factor, preferred_publisher,
            limit_text="Dont give more than 10 results", decision_key="Decsion Time",
        )
        return QueryResponse(result=result)

    if model_choice == "gemini-pro":
        print("Using Gemini-Pro model")
        client = ChatGoogleGenerativeAI(
            model="gemini-pro",
            google_api_key=key,
            convert_system_message_to_human=True,
        )
        result = _run_rag_recommendation(
            client, keywords, impact_factor, preferred_publisher,
            limit_text="", decision_key="Decsion Time",
        )
        return QueryResponse(result=result)

    if model_choice == "faiss":
        return _faiss_recommendation(keywords, impact_factor)

    raise HTTPException(status_code=400, detail="Invalid model choice.")


def _extract_keywords(query_text):
    """Ask a small Groq model for a strict comma-separated 15-keyword list."""
    messages = [
        {
            "role": "system",
            "content": (
                "Give a strict comma-separated list of exactly 15 keywords from the following text. "
                "Do not include any bullet points, introductory text, or ending text. "
                "No introductory or ending text strictly"
                "Do not say anything like 'Here are the keywords.' "
                "Only return the keywords, strictly comma-separated, without any additional words."
            ),
        },
        {"role": "user", "content": query_text},
    ]
    llm = ChatGroq(model="llama3-8b-8192", temperature=0)
    ai_msg = llm.invoke(messages)
    # Defensive: if the model still prefixes boilerplate, keep only what follows it.
    return ai_msg.content.split("keywords extracted from the text:\n")[-1].strip()


def _build_system_prompt(keywords, impact_factor, preferred_publisher, limit_text):
    """Shared RAG system prompt; ``{context}`` is filled in by the documents chain."""
    return (
        f"You are a specialized Journal recommender that compares all journals in database to given research paper keywords and based on JIF and publisher gives result."
        f"From the provided context, recommend all journals that are suitable for research paper with {keywords} keywords."
        f"Ensure that you include **every** journal with a Journal Impact Factor (JIF) strictly greater than {impact_factor}, and the Journal must be only from any Publishers in list: {preferred_publisher}. And Pls show that jif as in Context database "
        f"Make sure to include both exact matches and related journals, and prioritize including **all relevant high-JIF journals without repetition**. "
        f"Present the results in a json format with the following information: Journal Name, Publisher, JIF, Decsion Time. "
        f"Ensure no introductory or ending texts are included. {limit_text}"
        "Context: {context}"
    )


def _run_rag_recommendation(client, keywords, impact_factor, preferred_publisher,
                            limit_text, decision_key, unwrap_key=None):
    """Run a retrieval chain with ``client`` and parse its JSON answer into Journal rows.

    Args:
        client: a LangChain chat model.
        limit_text: result-count instruction appended to the system prompt.
        decision_key: JSON field name to read the decision time from.
        unwrap_key: when set, the answer is a dict and the journal list lives
            under this key (groq wraps its output in {"journals": [...]}).
    """
    retriever = db.as_retriever()
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", _build_system_prompt(keywords, impact_factor, preferred_publisher, limit_text)),
            ("user", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(client, prompt)
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)

    answer = rag_chain.invoke(
        {"input": f"Keywords: {keywords}, Minimum JIF: {impact_factor},Publisher list: {preferred_publisher}"}
    )

    raw_response = answer['answer']
    # Remove a markdown ```json fence if present.  (The old str.strip('```json\n')
    # stripped a *character set* and could eat legitimate JSON edge characters.)
    cleaned_response = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw_response.strip())
    print("Cleaned Response:", cleaned_response)

    try:
        json_response = json.loads(cleaned_response)
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        return []

    # .get avoids the KeyError-turned-500 the old groq branch had when the
    # model omitted the wrapper object.
    entries = json_response.get(unwrap_key, []) if unwrap_key else json_response
    return _journals_from_entries(entries, impact_factor, keywords, decision_key)


def _journals_from_entries(entries, impact_factor, keywords, decision_key):
    """Convert LLM JSON entries into Journal models, keeping only JIF > impact_factor."""
    result = []
    for i, journal in enumerate(entries):
        if not isinstance(journal, dict):
            print(f"Skipping invalid journal entry: {journal}")
            continue
        try:
            jif = float(journal.get('JIF', 0))
            if jif > impact_factor:
                result.append(
                    Journal(
                        # 1-based index into the raw answer, so ids keep gaps
                        # where entries were filtered out (historical behavior).
                        id=i + 1,
                        Name=journal.get('Journal Name'),
                        Publisher=journal.get('Publisher'),
                        JIF=jif,
                        Category="",
                        Keywords=keywords,
                        Decision_Time=journal.get(decision_key, 0),
                    )
                )
        except Exception as e:
            print(f"Error processing journal data: {e}")
    return result


def _faiss_recommendation(keywords, impact_factor):
    """LLM-free path: similarity-search the store and regex-parse the raw CSV rows."""
    local_embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en", model_kwargs={"device": "cpu"}
    )
    store = FAISS.load_local(DB_FAISS_PATH, local_embeddings, allow_dangerous_deserialization=True)

    query_embedding = local_embeddings.embed_query(keywords)
    hits = store.similarity_search_by_vector(query_embedding, k=20)
    context = "\n\n".join(doc.page_content for doc in hits)

    journal_list = []
    # Each CSV row was loaded as "Name: ...\nJIF: ...\n..." — split on row starts.
    # ("Decsion Time" reproduces the dataset's own column-name typo.)
    for entry in re.split(r"\n(?=Name:)", context.strip()):
        fields = {
            label: re.search(rf"{label}: (.+)", entry)
            for label in ("Name", "JIF", "Category", "Keywords", "Publisher", "Decsion Time")
        }
        if not all(fields.values()):
            continue  # skip rows missing any expected field (old code crashed with AttributeError)
        try:
            jif_value = float(fields["JIF"].group(1).strip())
        except ValueError:
            continue  # unparseable JIF cell
        # NOTE: this branch keeps JIF >= threshold, unlike the LLM branches' strict >.
        if jif_value >= impact_factor:
            journal_list.append(
                Journal(
                    id=len(journal_list) + 1,
                    Name=fields["Name"].group(1).strip(),
                    JIF=jif_value,
                    Category=fields["Category"].group(1).strip(),
                    Keywords=fields["Keywords"].group(1).strip(),
                    Publisher=fields["Publisher"].group(1).strip(),
                    Decision_Time=fields["Decsion Time"].group(1).strip(),
                )
            )
    return {"result": [journal.dict() for journal in journal_list]}
|
|
|
|
|
|
|
|
|