# pdf-digest/utils/llm_utils.py
# NOTE: Hugging Face file-page chrome removed (author: RJuro; commit
# "Reinitialize repository without offending large file", d3fdae9, 5.6 kB).
# llm_utils.py
import os
import time
import asyncio
import json
import logging
import streamlit as st
from google import genai
from google.genai import types
logger = logging.getLogger(__name__)
# Initialize the Gemini client using the new SDK
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
def get_generation_model(model_type: str):
    """Resolve a model alias to a concrete Gemini model name and config.

    Args:
        model_type: "flash" selects the fast model; any other value selects
            the experimental "thinking" model.

    Returns:
        A ``(model_name, generation_config)`` tuple ready to pass to the
        Gemini generate-content API.
    """
    model_name = (
        "gemini-2.0-flash"
        if model_type == "flash"
        else "gemini-2.0-flash-thinking-exp-01-21"
    )
    config = types.GenerateContentConfig(
        temperature=0.7,
        top_p=0.95,
        top_k=64,
        max_output_tokens=65536,
        response_mime_type="text/plain",
    )
    return model_name, config
async def async_generate_text(prompt, pdf_file=None, model_name=None,
                              generation_config=None, max_retries=None,
                              retry_delay=30):
    """Send a prompt (optionally with an uploaded PDF) to Gemini asynchronously.

    Args:
        prompt: The text prompt for the model.
        pdf_file: Optional uploaded file object; when given, it is sent
            alongside the prompt as part of the contents.
        model_name: Gemini model identifier to use.
        generation_config: ``types.GenerateContentConfig`` for the call.
        max_retries: Maximum number of retries after a failure. ``None``
            (the default) retries forever, preserving the original behavior.
        retry_delay: Seconds to sleep between retries (default 30).

    Returns:
        The model's response text.

    Raises:
        Exception: Re-raises the last API error once ``max_retries`` is
            exhausted (only when ``max_retries`` is not ``None``).
    """
    contents = [pdf_file, prompt] if pdf_file else prompt
    attempt = 0
    while True:
        try:
            st.toast("Sending prompt to the model...")
            response = await client.aio.models.generate_content(
                model=model_name,
                contents=contents,
                config=generation_config,
            )
            st.toast("Received response from the model.")
            logger.debug("Generated text for prompt. Length: %d", len(response.text))
            return response.text
        except Exception as e:
            attempt += 1
            logger.exception("Error during asynchronous LLM API call:")
            st.toast("Error during asynchronous LLM API call: " + str(e))
            # Previously this retried unconditionally forever; a bounded
            # retry count now lets callers opt out of infinite loops on
            # permanent failures (bad API key, invalid model name, ...).
            if max_retries is not None and attempt > max_retries:
                raise
            await asyncio.sleep(retry_delay)
def clean_json_response(response_text: str) -> str:
    """Strip a surrounding Markdown code fence from an LLM response.

    If the trimmed text starts with ``` (e.g. ```json), the opening fence
    line and a trailing ``` line are removed and the remainder is returned
    stripped. Text without a fence is returned exactly as received.
    """
    trimmed = response_text.strip()
    if not trimmed.startswith("```"):
        # No fence: hand the original text back untouched (not even stripped),
        # matching the function's long-standing contract.
        return response_text
    body = trimmed.splitlines()
    if body[0].strip().startswith("```"):
        body = body[1:]
    if body and body[-1].strip() == "```":
        body = body[:-1]
    return "\n".join(body).strip()
class TitleReference:
    """Container for metadata extracted from an uploaded document.

    Attributes:
        title: Generated title, or ``None``.
        apa_reference: APA-formatted reference string, or ``None``.
        classification: Classification label returned by the model
            (e.g. "Good academic paper"), or ``None``.
        bullet_list: Bullet points for context/method/theory/findings;
            always a list, never ``None``.
        error: Error message when the document was rejected, else ``None``.
    """

    def __init__(self, title=None, apa_reference=None, classification=None, bullet_list=None, error=None):
        self.title = title
        self.apa_reference = apa_reference
        self.classification = classification
        # Normalize falsy values (None / empty) to a fresh list so callers
        # can always iterate without a None check.
        self.bullet_list = bullet_list or []
        self.error = error

    def __repr__(self):
        # Added for debuggability: makes instances readable in logs.
        return (
            f"{type(self).__name__}(title={self.title!r}, "
            f"classification={self.classification!r}, error={self.error!r})"
        )
async def generate_title_reference_and_classification(pdf_file, title_model_name, title_generation_config):
    """Ask the model whether the PDF is a valid academic article and extract metadata.

    Args:
        pdf_file: Uploaded file object to send with the prompt.
        title_model_name: Gemini model identifier to use.
        title_generation_config: Generation config for the call.

    Returns:
        A ``TitleReference`` — either populated with title/reference/
        classification/bullet_list, or carrying only ``error`` when the
        document was rejected by the model.

    Raises:
        Exception: When the model's response is not valid JSON or is
            missing required keys.
    """
    title_prompt = (
        "Analyze the uploaded document and determine if it is a valid academic article. "
        "If it is a valid academic article, generate a concise and engaging title, an APA formatted reference, and classify the paper as 'Good academic paper'. "
        "Also, generate a bullet list for the following items: context, method, theory, main findings. "
        "If it is not a valid academic article (for example, if it is too short or just a title page), "
        "classify it as 'Not a valid academic paper' and return an 'error' key with an appropriate message. "
        "Output the result strictly in JSON format with keys 'title', 'apa_reference', 'classification', and 'bullet_list'. "
        "The 'bullet_list' value should be an array of strings. Do not include any extra commentary."
    )
    response_text = await async_generate_text(
        title_prompt,
        pdf_file,
        model_name=title_model_name,
        generation_config=title_generation_config
    )
    logger.debug("Title/Reference generation response: %s", response_text)
    cleaned_response = clean_json_response(response_text)
    logger.debug("Cleaned Title/Reference JSON: %s", cleaned_response)
    try:
        data = json.loads(cleaned_response)
    except json.JSONDecodeError as e:
        # Narrowed from a blanket `except Exception` and chained so the
        # original decode error is preserved in the traceback.
        logger.exception("Invalid JSON returned: %s", e)
        raise Exception("Invalid JSON returned: " + str(e)) from e
    if "error" in data:
        return TitleReference(error=data["error"])
    required_keys = ("title", "apa_reference", "classification", "bullet_list")
    if any(key not in data for key in required_keys):
        raise Exception("Expected keys 'title', 'apa_reference', 'classification', and 'bullet_list' not found in response.")
    return TitleReference(
        title=data["title"],
        apa_reference=data["apa_reference"],
        classification=data["classification"],
        bullet_list=data["bullet_list"]
    )
# Add these functions so they can be imported elsewhere
def upload_to_gemini(file_path, mime_type=None):
    """Upload a local file to the Gemini Files API and return the file handle.

    Args:
        file_path: Path of the file to upload.
        mime_type: Optional MIME type to attach to the upload. Previously
            this parameter was accepted but silently ignored; it is now
            forwarded to the API when provided.

    Returns:
        The uploaded file object (exposes ``name``, ``display_name``,
        ``uri`` and ``state``).
    """
    # Reuse the module-level client instead of constructing a new one per call.
    if mime_type is not None:
        # NOTE(review): UploadFileConfig carries the explicit MIME type in the
        # new google-genai SDK — confirm against the installed SDK version.
        file = client.files.upload(
            file=file_path,
            config=types.UploadFileConfig(mime_type=mime_type),
        )
    else:
        file = client.files.upload(file=file_path)
    st.toast(f"Uploaded file '{file.display_name}' as: {file.uri}")
    logger.debug("Uploaded file: %s with URI: %s", file.display_name, file.uri)
    return file
def wait_for_files_active(files, poll_interval=10):
    """Block until every uploaded file leaves the PROCESSING state.

    Args:
        files: Iterable of uploaded file objects (each with a ``name``).
        poll_interval: Seconds between polls while a file is PROCESSING
            (default 10, matching the previous hard-coded value).

    Raises:
        Exception: If any file ends in a state other than ACTIVE.
    """
    st.toast("Waiting for file processing...")
    # Reuse the module-level client instead of constructing a new one per call.
    for file in files:
        current_file = client.files.get(name=file.name)
        logger.debug("Initial state for file %s: %s", file.name, current_file.state.name)
        while current_file.state.name == "PROCESSING":
            time.sleep(poll_interval)
            current_file = client.files.get(name=file.name)
            logger.debug("Polling file %s, state: %s", file.name, current_file.state.name)
        if current_file.state.name != "ACTIVE":
            error_msg = f"File {current_file.name} failed to process, state: {current_file.state.name}"
            logger.error(error_msg)
            raise Exception(error_msg)
    st.toast("All files processed and ready.")
    logger.debug("All files are active.")