File size: 5,602 Bytes
d3fdae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# llm_utils.py
import os
import time
import asyncio
import json
import logging
import streamlit as st

from google import genai
from google.genai import types

logger = logging.getLogger(__name__)

# Initialize the Gemini client using the new SDK
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

def get_generation_model(model_type: str):
    """Resolve a Gemini model name and a default generation config.

    Args:
        model_type: "flash" selects the standard flash model; any other
            value selects the experimental "thinking" model.

    Returns:
        Tuple of (model name string, types.GenerateContentConfig).
    """
    name = (
        "gemini-2.0-flash"
        if model_type == "flash"
        else "gemini-2.0-flash-thinking-exp-01-21"
    )
    config = types.GenerateContentConfig(
        temperature=0.7,
        top_p=0.95,
        top_k=64,
        max_output_tokens=65536,
        response_mime_type="text/plain",
    )
    return name, config

async def async_generate_text(prompt, pdf_file=None, model_name=None,
                              generation_config=None, max_retries=None,
                              retry_delay=30):
    """Call the Gemini API asynchronously, retrying on failure.

    Args:
        prompt: Text prompt to send to the model.
        pdf_file: Optional uploaded file object included before the prompt.
        model_name: Model identifier to use for the request.
        generation_config: types.GenerateContentConfig for the request.
        max_retries: Maximum number of retries after a failed attempt.
            None (the default) retries indefinitely, preserving the
            original behavior.
        retry_delay: Seconds to wait between attempts (default 30).

    Returns:
        The generated text from the model response (may be None if the
        response carries no text).

    Raises:
        Exception: Re-raises the last API error once max_retries is
            exhausted (only when max_retries is not None).
    """
    contents = [pdf_file, prompt] if pdf_file else prompt
    attempt = 0
    while True:
        try:
            st.toast("Sending prompt to the model...")
            response = await client.aio.models.generate_content(
                model=model_name,
                contents=contents,
                config=generation_config,
            )
            st.toast("Received response from the model.")
            # response.text may be None for a blocked/empty response;
            # guard the length computation so logging cannot itself raise
            # (and get swallowed by the retry handler below).
            logger.debug("Generated text for prompt. Length: %d",
                         len(response.text or ""))
            return response.text
        except Exception as e:
            logger.exception("Error during asynchronous LLM API call:")
            st.toast("Error during asynchronous LLM API call: " + str(e))
            attempt += 1
            # Bounded retries prevent an infinite loop on permanent
            # errors (e.g. invalid API key or unknown model name).
            if max_retries is not None and attempt > max_retries:
                raise
            await asyncio.sleep(retry_delay)

def clean_json_response(response_text: str) -> str:
    """Strip a Markdown code fence from a model response, if present.

    When the trimmed response does not begin with ``` the input is
    returned unchanged (including its original surrounding whitespace).
    Otherwise the opening fence line (which may carry a language tag,
    e.g. ```json) and a trailing ``` line are removed and the remainder
    is returned stripped.
    """
    trimmed = response_text.strip()
    if not trimmed.startswith("```"):
        return response_text
    body = trimmed.splitlines()
    if body and body[0].strip().startswith("```"):
        body = body[1:]
    if body and body[-1].strip() == "```":
        body = body[:-1]
    return "\n".join(body).strip()

class TitleReference:
    """Result of the title/reference/classification extraction.

    Attributes:
        title: Generated paper title, or None.
        apa_reference: APA-formatted reference string, or None.
        classification: Paper classification label, or None.
        bullet_list: Bullet-point summary strings (always a list).
        error: Error message when the document is not a valid paper.
    """

    def __init__(self, title=None, apa_reference=None, classification=None,
                 bullet_list=None, error=None):
        self.title = title
        self.apa_reference = apa_reference
        self.classification = classification
        # Normalize falsy values (None / empty) to a fresh list so
        # callers can always iterate bullet_list safely.
        self.bullet_list = [] if not bullet_list else bullet_list
        self.error = error

async def generate_title_reference_and_classification(pdf_file, title_model_name, title_generation_config):
    """Ask the model to classify the PDF and extract title/reference metadata.

    Args:
        pdf_file: Uploaded file object sent alongside the prompt.
        title_model_name: Model identifier to use for this request.
        title_generation_config: Generation config for this request.

    Returns:
        TitleReference with the extracted fields populated, or with only
        `error` set when the model flags the document as invalid.

    Raises:
        ValueError: If the model output is not valid JSON, or is missing
            one of the required keys.
    """
    title_prompt = (
        "Analyze the uploaded document and determine if it is a valid academic article. "
        "If it is a valid academic article, generate a concise and engaging title, an APA formatted reference, and classify the paper as 'Good academic paper'. "
        "Also, generate a bullet list for the following items: context, method, theory, main findings. "
        "If it is not a valid academic article (for example, if it is too short or just a title page), "
        "classify it as 'Not a valid academic paper' and return an 'error' key with an appropriate message. "
        "Output the result strictly in JSON format with keys 'title', 'apa_reference', 'classification', and 'bullet_list'. "
        "The 'bullet_list' value should be an array of strings. Do not include any extra commentary."
    )
    response_text = await async_generate_text(
        title_prompt,
        pdf_file,
        model_name=title_model_name,
        generation_config=title_generation_config
    )
    logger.debug("Title/Reference generation response: %s", response_text)
    cleaned_response = clean_json_response(response_text)
    logger.debug("Cleaned Title/Reference JSON: %s", cleaned_response)
    try:
        data = json.loads(cleaned_response)
    except json.JSONDecodeError as e:
        # Narrowed from a blanket `except Exception`; chain the cause so
        # the original parse error stays visible in tracebacks.
        logger.exception("Invalid JSON returned: %s", e)
        raise ValueError("Invalid JSON returned: " + str(e)) from e

    if "error" in data:
        # Model declared the document not a valid academic paper.
        return TitleReference(error=data["error"])

    required_keys = ("title", "apa_reference", "classification", "bullet_list")
    if any(key not in data for key in required_keys):
        # ValueError is a subclass of Exception, so existing callers that
        # catch Exception still work.
        raise ValueError("Expected keys 'title', 'apa_reference', 'classification', and 'bullet_list' not found in response.")
    return TitleReference(
        title=data["title"],
        apa_reference=data["apa_reference"],
        classification=data["classification"],
        bullet_list=data["bullet_list"]
    )
        
        
# Add these functions so they can be imported elsewhere
def upload_to_gemini(file_path, mime_type=None):
    """Upload a local file to the Gemini Files API and return its handle.

    Args:
        file_path: Path of the file to upload.
        mime_type: Accepted for backward compatibility but currently not
            forwarded to the upload call — TODO(review): confirm whether
            the new SDK needs it passed via an upload config.

    Returns:
        The uploaded file object (exposes .name, .display_name, .uri).
    """
    # Reuse the module-level client instead of constructing a fresh
    # genai.Client on every call, consistent with the rest of this module.
    file = client.files.upload(file=file_path)
    st.toast(f"Uploaded file '{file.display_name}' as: {file.uri}")
    logger.debug("Uploaded file: %s with URI: %s", file.display_name, file.uri)
    return file

def wait_for_files_active(files):
    """Block until every uploaded file leaves the PROCESSING state.

    Polls the Files API every 10 seconds for each file in turn.

    Args:
        files: Iterable of uploaded file objects (must expose .name).

    Raises:
        Exception: If any file ends in a state other than ACTIVE.
    """
    st.toast("Waiting for file processing...")
    # Reuse the module-level client instead of constructing a fresh
    # genai.Client on every call, consistent with the rest of this module.
    for file in files:
        current_file = client.files.get(name=file.name)
        logger.debug("Initial state for file %s: %s", file.name, current_file.state.name)
        while current_file.state.name == "PROCESSING":
            time.sleep(10)
            current_file = client.files.get(name=file.name)
            logger.debug("Polling file %s, state: %s", file.name, current_file.state.name)
        if current_file.state.name != "ACTIVE":
            error_msg = f"File {current_file.name} failed to process, state: {current_file.state.name}"
            logger.error(error_msg)
            raise Exception(error_msg)
    st.toast("All files processed and ready.")
    logger.debug("All files are active.")