import gradio as gr
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import anthropic
import json
import re
import os

# Initialize the SentenceTransformer embedding model
model = SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)


def load_json_prompts(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def get_embeddings(chunk):
    """Get the embedding for a chunk using the SentenceTransformer model."""
    return model.encode(chunk)


def search_similar(query_embedding, chunk_df, k=6, max_per_citation=2):
    """Rank chunks by cosine similarity, capping results per citation."""
    similarities = []
    for idx, row in chunk_df.iterrows():
        sim = np.dot(query_embedding, row['embeddings']) / (
            np.linalg.norm(query_embedding) * np.linalg.norm(row['embeddings'])
        )
        similarities.append({
            'citation_id': f"[{idx + 1}]",
            'citation': row['citation'],
            'text': row['text_chunk'],
            'chunk_label': row['chunk_label'],
            'similarity': sim
        })

    similarities = sorted(similarities, key=lambda x: x['similarity'], reverse=True)

    # Keep the top k results, allowing at most max_per_citation chunks per source
    citations_count = {}
    filtered_results = []
    for result in similarities:
        citation = result['citation']
        citations_count[citation] = citations_count.get(citation, 0)
        if citations_count[citation] < max_per_citation:
            citations_count[citation] += 1
            filtered_results.append({
                'citation': result['citation'],
                'text': result['text'],
                'chunk_label': result['chunk_label']
            })
        if len(filtered_results) == k:
            break

    return filtered_results


def naive_search(thesis, context, client):
    # naive_system_prompt is a module-level global set in __main__
    message = client.messages.create(
        model="claude-3-sonnet-20240229",
        max_tokens=3000,
        temperature=0,
        system=naive_system_prompt,
        messages=[{
            "role": "user",
            "content": [{
                "type": "text",
                "text": f"Here are the references you should use:\n{context}\n\n\nHere is the topic you need to write about:\n\n{thesis}\n\nReturn only the text without preface."
            }]
        }]
    )
    return message.content[0].text


def section_drafter(thesis, context, style_notes, client):
    # section_draft_prompt is a module-level global set in __main__
    if len(style_notes) > 0:
        style_notes = f"Here are some **important** style guidelines:\n\n{style_notes}"
    message = client.messages.create(
        model="claude-3-5-sonnet-20241022",
        max_tokens=5000,
        temperature=0,
        system=section_draft_prompt,
        messages=[{
            "role": "user",
            "content": [{
                "type": "text",
                "text": f"Write a part of this literature review. {style_notes} Here are the references you should use:\n{context}\n\n\nWe are currently working on this section:\n\n{thesis}\n\nReturn only the text, prefaced with an appropriate markdown subheader for the specific section. The text should be comprehensive and detailed, being sure to cite existing work and the work it engages with."
            }]
        }]
    )
    return message.content[0].text


def format_asa_citation(bibtex_string):
    """Convert a BibTeX string to an ASA-style citation."""
    author_match = re.search(r"author\s*=\s*{(.*?)}", bibtex_string, re.DOTALL)
    title_match = re.search(r"title\s*=\s*{(.*?)}", bibtex_string, re.DOTALL)
    year_match = re.search(r"year\s*=\s*{(.*?)}", bibtex_string, re.DOTALL)
    journal_match = re.search(r"journal\s*=\s*{(.*?)}", bibtex_string, re.DOTALL)
    volume_match = re.search(r"volume\s*=\s*{(.*?)}", bibtex_string, re.DOTALL)
    number_match = re.search(r"number\s*=\s*{(.*?)}", bibtex_string, re.DOTALL)
    pages_match = re.search(r"pages\s*=\s*{(.*?)}", bibtex_string, re.DOTALL)

    author = author_match.group(1).strip() if author_match else ""
    title = title_match.group(1).strip() if title_match else ""
    year = year_match.group(1).strip() if year_match else ""
    journal = journal_match.group(1).strip() if journal_match else ""
    volume = volume_match.group(1).strip() if volume_match else ""
    number = number_match.group(1).strip() if number_match else ""
    pages = pages_match.group(1).strip() if pages_match else ""

    return f"{author}. {year}. {title}. {journal} {volume}({number}): {pages}."


def extract_cites(context):
    """Format and deduplicate the citations in a context list."""
    cites = [format_asa_citation(item['citation']) for item in context]
    cites = list(set(cites))
    return '\n'.join([f'* {cite}' for cite in cites])


def generate_literature_review(thesis, style_notes, api_key, progress=gr.Progress()):
    yield gr.update(value="")
    output = []

    # The special password "quote" loads the key from the environment instead
    if api_key == "quote":
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("Environment variable ANTHROPIC_API_KEY not found")

    # Initialize the Anthropic client
    client = anthropic.Anthropic(api_key=api_key)

    progress(0.1, desc="Loading document chunks...")

    # First pass: retrieve references for the raw topic and draft from them
    progress(0.2, desc="Finding initial references...")
    context = search_similar(get_embeddings(thesis), chunk_df, k=8)
    first_cites = extract_cites(context)
    progress(0.4, desc="Generating first draft...")
    naive_results = naive_search(thesis, context, client)

    # Second pass: retrieve again using the first draft as the query,
    # then merge in any chunks not already found
    progress(0.6, desc="Finding additional references...")
    context2 = search_similar(get_embeddings(naive_results), chunk_df, k=16)
    text_chunks = [c['text'] for c in context]
    combo_context = context + [c for c in context2 if c['text'] not in text_chunks]
    final_cites = extract_cites(combo_context)
    ref_count = len(final_cites.split('\n'))

    progress(0.8, desc=f"Generating final draft from {ref_count} sources...")
    draft = section_drafter(thesis, combo_context, style_notes, client)
    output.append(draft)
    output.append("\n## Sources\n" + final_cites)

    progress(1.0, desc="Complete!")
    yield "\n\n".join(output)


def create_interface():
    theme = gr.themes.Soft(
        primary_hue="slate",
    )
    with gr.Blocks(theme=theme) as app:
        gr.Markdown("# CiteCraft")
        intro_text = f"Using {chunk_count} pages from {source_count} sources."
        gr.Markdown(intro_text)

        with gr.Row():
            with gr.Column(scale=1):
                thesis_input = gr.Textbox(
                    label="Section Topic",
                    placeholder="Enter your research question or section theme here",
                    lines=4
                )
                style_notes_input = gr.Textbox(
                    label="Style Notes (Optional)",
                    placeholder="Enter any writing style modifications",
                    lines=4
                )
                api_key = gr.Textbox(
                    label="Anthropic API Key",
                    placeholder="Enter your Anthropic API key",
                    type="password"
                )
                generate_button = gr.Button("Generate Review", variant="primary")
            with gr.Column(scale=2):
                output_text = gr.Markdown(label="Generated Review")

        generate_button.click(
            generate_literature_review,
            inputs=[thesis_input, style_notes_input, api_key],
            outputs=output_text
        )

    return app


if __name__ == "__main__":
    # Load prompts and configuration, then the precomputed document chunks
    try:
        json_prompts = load_json_prompts('prompts.json')
        print("Successfully loaded prompts.json")

        # Get the system prompts and the chunk (protest) file name
        section_draft_prompt = json_prompts['section_draft_prompt']
        naive_system_prompt = json_prompts['naive_system_prompt']
        protest_file_name = json_prompts['protest_file_name']
        print(f"Will load document chunks from: {protest_file_name}")

        chunk_df = pd.read_json(protest_file_name)
        chunk_count = len(chunk_df)
        source_count = chunk_df['citation'].nunique()
    except Exception as e:
        print(f"Error loading prompts or document chunks: {e}")
        raise

    # Launch the app; share=True creates a public URL
    app = create_interface()
    app.launch(share=True)
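

# ---------------------------------------------------------------------------
# Expected shape of prompts.json: a hedged sketch reconstructed from the keys
# read in __main__ above; the values shown are placeholders, not the actual
# prompts shipped with this app.
#
#     {
#         "section_draft_prompt": "<system prompt for section_drafter>",
#         "naive_system_prompt": "<system prompt for naive_search>",
#         "protest_file_name": "<path to the document-chunk JSON file>"
#     }
# ---------------------------------------------------------------------------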
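# ---------------------------------------------------------------------------
# The chunk file named by protest_file_name must be readable by pd.read_json
# and provide the columns used in search_similar: 'text_chunk', 'citation'
# (a BibTeX string, parsed by format_asa_citation), 'chunk_label', and
# 'embeddings' (one vector per chunk). A minimal sketch of how such a file
# might be built with the same embedding model; the input lists below are
# hypothetical, not part of this app:
#
#     chunks_df = pd.DataFrame({
#         'text_chunk': page_texts,       # one string per page-sized chunk
#         'citation': bibtex_entries,     # BibTeX entry for each chunk's source
#         'chunk_label': chunk_labels,    # short label for each chunk
#     })
#     chunks_df['embeddings'] = chunks_df['text_chunk'].apply(
#         lambda t: model.encode(t).tolist()
#     )
#     chunks_df.to_json('protest_chunks.json')
# ---------------------------------------------------------------------------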