import gradio as gr
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import anthropic
import json
import re
import os
from pathlib import Path
# Initialize the SentenceTransformer model
model = SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
def load_json_prompts(file_path):
    """Load prompts and configuration values from a JSON file."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
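# For reference, prompts.json is expected to contain at least the keys read
# in the __main__ block below (the values shown here are illustrative only):
#   {
#       "section_draft_prompt": "<system prompt for the final drafting pass>",
#       "naive_system_prompt": "<system prompt for the first-pass draft>",
#       "protest_file_name": "<path to the JSON file of document chunks>"
#   }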
def get_embeddings(chunk):
    """Get embedding for chunk using SentenceTransformer"""
    embedding = model.encode(chunk)
    return embedding
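# Note: nomic-embed-text-v1.5 is documented as working best with task prefixes
# (e.g. "search_query: " for queries, "search_document: " for documents).
# This script encodes raw text; if the stored chunk embeddings were built with
# prefixes, queries should use the matching prefix, for example:
#   query_embedding = model.encode("search_query: " + thesis)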
def search_similar(query_embedding, chunk_df, k=6, max_per_citation=2):
    """Search chunks using cosine similarity with citation limit"""
    similarities = []
    for idx, row in chunk_df.iterrows():
        sim = np.dot(query_embedding, row['embeddings']) / (
            np.linalg.norm(query_embedding) * np.linalg.norm(row['embeddings'])
        )
        similarities.append({
            'citation_id': f"[{idx + 1}]",
            'citation': row['citation'],
            'text': row['text_chunk'],
            'chunk_label': row['chunk_label'],
            'similarity': sim
        })
    similarities = sorted(similarities, key=lambda x: x['similarity'], reverse=True)
    # Keep at most max_per_citation chunks from any single source
    citations_count = {}
    filtered_results = []
    for result in similarities:
        citation = result['citation']
        citations_count[citation] = citations_count.get(citation, 0)
        if citations_count[citation] < max_per_citation:
            citations_count[citation] += 1
            filtered_results.append({
                'citation': result['citation'],
                'text': result['text'],
                'chunk_label': result['chunk_label']
            })
        if len(filtered_results) == k:
            break
    return filtered_results
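# The loop above recomputes norms row by row; a vectorized sketch (assuming
# the embeddings column stacks into a 2-D array of equal-length vectors):
#   emb = np.stack(chunk_df['embeddings'].to_numpy())
#   sims = emb @ query_embedding / (
#       np.linalg.norm(emb, axis=1) * np.linalg.norm(query_embedding))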
def naive_search(thesis, context, client):
    """First-pass draft: ask Claude to write about the thesis from the retrieved chunks."""
    message = client.messages.create(
        model="claude-3-sonnet-20240229",
        max_tokens=3000,
        temperature=0,
        system=naive_system_prompt,
        messages=[{
            "role": "user",
            "content": [{
                "type": "text",
                "text": f"Here are the references you should use:\n{context}\n\n\nHere is the topic you need to write about:\n\n{thesis}\n\nReturn only the text without preface."
            }]
        }]
    )
    return message.content[0].text
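# Note that `context` is a list of dicts, so the f-string above interpolates
# its Python repr. A sketch of a more readable formatting, via a hypothetical
# helper that is not part of the original script:
#   def format_context(context):
#       return "\n\n".join(f"{c['chunk_label']}: {c['text']}" for c in context)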
def section_drafter(thesis, context, style_notes, client):
    """Final drafting pass: write the section using the combined context and optional style notes."""
    if len(style_notes) > 0:
        style_notes = f'''Here are some **important** style guidelines:\n\n{style_notes}'''
    message = client.messages.create(
        model="claude-3-5-sonnet-20241022",
        max_tokens=5000,
        temperature=0,
        system=section_draft_prompt,
        messages=[{
            "role": "user",
            "content": [{
                "type": "text",
                "text": f"Write a part of this literature review. {style_notes} Here are the references you should use:\n{context}\n\n\nWe are currently working on this section:\n\n{thesis}\n\nReturn only the text, prefaced only with an appropriate markdown subheader for the specific section. The text should be comprehensive and detailed, being sure to cite existing work and the work it engages with."
            }]
        }]
    )
    return message.content[0].text
def format_asa_citation(bibtex_string):
    """Converts a BibTeX string to an ASA style citation."""
    def field(name):
        match = re.search(rf"{name}\s*=\s*{{(.*?)}}", bibtex_string, re.DOTALL)
        return match.group(1).strip() if match else ""

    author = field("author")
    title = field("title")
    year = field("year")
    journal = field("journal")
    volume = field("volume")
    number = field("number")
    pages = field("pages")
    citation = f"{author}. {year}. {title}. {journal} {volume}({number}): {pages}."
    return citation
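# For example, a hypothetical entry such as
#   @article{smith2020, author = {Smith, Jane}, title = {Protest Cycles},
#   year = {2020}, journal = {Mobilization}, volume = {25}, number = {2},
#   pages = {101--120}}
# would be rendered roughly as:
#   Smith, Jane. 2020. Protest Cycles. Mobilization 25(2): 101--120.
# Missing fields default to "", so non-journal entries keep an empty "():".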
def extract_cites(context):
    """Deduplicate and format the citations attached to a list of retrieved chunks."""
    cites = [format_asa_citation(item['citation']) for item in context]
    cites = sorted(set(cites))  # deduplicate; sorted for a stable ordering
    return '\n'.join([f'* {cite}' for cite in cites])
def generate_literature_review(thesis, style_notes, api_key, progress=gr.Progress()):
    yield gr.update(value="")
    output = []
    # Check if using the special password "quote" to load from environment
    if api_key == "quote":
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("Environment variable ANTHROPIC_API_KEY not found")
    # Initialize Anthropic client
    client = anthropic.Anthropic(api_key=api_key)
    progress(0.1, desc="Loading document chunks...")
    # First pass: retrieve chunks for the thesis and draft from them
    progress(0.2, desc="Finding initial references...")
    context = search_similar(get_embeddings(thesis), chunk_df, k=8)
    first_cites = extract_cites(context)
    progress(0.4, desc="Generating first draft...")
    naive_results = naive_search(thesis, context, client)
    # Second pass: retrieve again using the first draft as the query,
    # then merge in any chunks not already retrieved
    progress(0.6, desc="Finding additional references...")
    context2 = search_similar(get_embeddings(naive_results), chunk_df, k=16)
    text_chunks = [c['text'] for c in context]
    combo_context = context + [c for c in context2 if c['text'] not in text_chunks]
    final_cites = extract_cites(combo_context)
    ref_count = len(final_cites.split('\n'))
    progress(0.8, desc=f"Generating final draft from {ref_count} sources...")
    draft = section_drafter(thesis, combo_context, style_notes, client)
    output.append(draft)
    output.append("\n## Sources\n" + final_cites)
    progress(1.0, desc="Complete!")
    yield "\n\n".join(output)
def create_interface():
    theme = gr.themes.Soft(
        primary_hue="slate",
    )
    with gr.Blocks(theme=theme) as app:
        gr.Markdown("# CiteCraft")
        intro_text = f"Using {chunk_count} pages from {source_count} sources."
        gr.Markdown(intro_text)
        with gr.Row():
            with gr.Column(scale=1):
                thesis_input = gr.Textbox(
                    label="Section Topic",
                    placeholder="Enter your research question or section theme here",
                    lines=4
                )
                style_notes_input = gr.Textbox(
                    label="Style notes (Optional)",
                    placeholder="Enter any writing style modifications (optional)",
                    lines=4
                )
                api_key = gr.Textbox(
                    label="Anthropic API Key",
                    placeholder="Enter your Anthropic API key",
                    type="password"
                )
                generate_button = gr.Button("Generate Review", variant="primary")
            with gr.Column(scale=2):
                output_text = gr.Markdown(label="Generated Review")
        generate_button.click(
            generate_literature_review,
            inputs=[thesis_input, style_notes_input, api_key],
            outputs=output_text
        )
    return app
if __name__ == "__main__":
    # Load necessary prompts and configurations
    try:
        json_prompts = load_json_prompts('prompts.json')
        print("Successfully loaded prompts.json")
        # Get system prompts and protest file name
        section_draft_prompt = json_prompts['section_draft_prompt']
        naive_system_prompt = json_prompts['naive_system_prompt']
        protest_file_name = json_prompts['protest_file_name']
        print(f"Will load document chunks from: {protest_file_name}")
        chunk_df = pd.read_json(protest_file_name)
        chunk_count = len(chunk_df)
        source_count = chunk_df['citation'].nunique()
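        # For reference, the chunk file is expected to provide at least the
        # columns used in search_similar; an illustrative record:
        #   {"text_chunk": "...", "chunk_label": "...",
        #    "citation": "@article{...}", "embeddings": [0.01, -0.02, ...]}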
    except Exception as e:
        print(f"Error loading prompts or document chunks: {e}")
        raise
    # Launch the app
    app = create_interface()
    app.launch(share=True)  # share=True creates a public URL