# llm_utils.py
import os
import time
import asyncio
import json
import logging

import streamlit as st
from google import genai
from google.genai import types

logger = logging.getLogger(__name__)

# Initialize the Gemini client using the new SDK
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

def get_generation_model(model_type: str):
    """Return the model name and a GenerateContentConfig for the requested type."""
    if model_type == "flash":
        model_name = "gemini-2.0-flash"
    else:
        model_name = "gemini-2.0-flash-thinking-exp-01-21"
    generation_config = types.GenerateContentConfig(
        temperature=0.7,
        top_p=0.95,
        top_k=64,
        max_output_tokens=65536,
        response_mime_type="text/plain",
    )
    return model_name, generation_config
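
# Usage sketch: "flash" selects the fast production model; any other value
# falls through to the experimental thinking model.
#   model_name, config = get_generation_model("flash")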

async def async_generate_text(prompt, pdf_file=None, model_name=None, generation_config=None):
    """Send a prompt (optionally with an uploaded PDF) to the model and return the response text.

    Retries indefinitely, sleeping 30 seconds between attempts, so transient
    API errors do not abort the pipeline.
    """
    contents = [pdf_file, prompt] if pdf_file else prompt
    while True:
        try:
            st.toast("Sending prompt to the model...")
            response = await client.aio.models.generate_content(
                model=model_name,
                contents=contents,
                config=generation_config,
            )
            st.toast("Received response from the model.")
            logger.debug("Generated text for prompt. Length: %d", len(response.text))
            return response.text
        except Exception as e:
            logger.exception("Error during asynchronous LLM API call:")
            st.toast("Error during asynchronous LLM API call: " + str(e))
            await asyncio.sleep(30)
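
# Usage sketch (name/config come from get_generation_model; `my_file` is a
# hypothetical Files API handle):
#   name, config = get_generation_model("flash")
#   text = asyncio.run(async_generate_text("Summarize this paper.", pdf_file=my_file,
#                                          model_name=name, generation_config=config))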

def clean_json_response(response_text: str) -> str:
    """Strip a surrounding Markdown code fence (```json ... ```) from a model response."""
    stripped = response_text.strip()
    if stripped.startswith("```"):
        lines = stripped.splitlines()
        if lines[0].strip().startswith("```"):
            lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        return "\n".join(lines).strip()
    return response_text
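
# Example: a fenced model reply is reduced to its JSON payload, e.g.
#   clean_json_response('```json\n{"title": "x"}\n```')  returns  '{"title": "x"}'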

class TitleReference:
    """Container for the title/reference analysis of an uploaded paper."""

    def __init__(self, title=None, apa_reference=None, classification=None, bullet_list=None, error=None):
        self.title = title
        self.apa_reference = apa_reference
        self.classification = classification
        self.bullet_list = bullet_list or []
        self.error = error
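
# Callers should check `.error` first: a non-None value means the document was
# rejected as not a valid academic paper (sketch, using the coroutine below):
#   result = await generate_title_reference_and_classification(pdf, name, config)
#   if result.error:
#       st.error(result.error)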

async def generate_title_reference_and_classification(pdf_file, title_model_name, title_generation_config):
    """Ask the model to validate the PDF and extract title, APA reference, classification, and bullets."""
    title_prompt = (
        "Analyze the uploaded document and determine if it is a valid academic article. "
        "If it is a valid academic article, generate a concise and engaging title, an APA formatted reference, and classify the paper as 'Good academic paper'. "
        "Also, generate a bullet list for the following items: context, method, theory, main findings. "
        "If it is not a valid academic article (for example, if it is too short or just a title page), "
        "classify it as 'Not a valid academic paper' and return an 'error' key with an appropriate message. "
        "Output the result strictly in JSON format with keys 'title', 'apa_reference', 'classification', and 'bullet_list'. "
        "The 'bullet_list' value should be an array of strings. Do not include any extra commentary."
    )
    response_text = await async_generate_text(
        title_prompt,
        pdf_file,
        model_name=title_model_name,
        generation_config=title_generation_config,
    )
    logger.debug("Title/Reference generation response: %s", response_text)
    cleaned_response = clean_json_response(response_text)
    logger.debug("Cleaned Title/Reference JSON: %s", cleaned_response)
    try:
        data = json.loads(cleaned_response)
    except json.JSONDecodeError as e:
        logger.exception("Invalid JSON returned: %s", e)
        raise Exception("Invalid JSON returned: " + str(e)) from e
    if "error" in data:
        return TitleReference(error=data["error"])
    required_keys = ["title", "apa_reference", "classification", "bullet_list"]
    if any(key not in data for key in required_keys):
        raise Exception("Expected keys 'title', 'apa_reference', 'classification', and 'bullet_list' not found in response.")
    return TitleReference(
        title=data["title"],
        apa_reference=data["apa_reference"],
        classification=data["classification"],
        bullet_list=data["bullet_list"],
    )

# Add these functions so they can be imported elsewhere
def upload_to_gemini(file_path, mime_type=None):
    """Upload a local file to the Gemini Files API and return the file handle."""
    # Reuse the module-level client; forward the MIME type when one is given
    # (types.UploadFileConfig carries per-upload options in the google-genai SDK).
    config = types.UploadFileConfig(mime_type=mime_type) if mime_type else None
    file = client.files.upload(file=file_path, config=config)
    st.toast(f"Uploaded file '{file.name}' as: {file.uri}")
    logger.debug("Uploaded file: %s with URI: %s", file.name, file.uri)
    return file

def wait_for_files_active(files):
    """Block until every uploaded file leaves the PROCESSING state, polling every 10 seconds."""
    st.toast("Waiting for file processing...")
    for file in files:
        current_file = client.files.get(name=file.name)
        logger.debug("Initial state for file %s: %s", file.name, current_file.state.name)
        while current_file.state.name == "PROCESSING":
            time.sleep(10)
            current_file = client.files.get(name=file.name)
            logger.debug("Polling file %s, state: %s", file.name, current_file.state.name)
        if current_file.state.name != "ACTIVE":
            error_msg = f"File {current_file.name} failed to process, state: {current_file.state.name}"
            logger.error(error_msg)
            raise Exception(error_msg)
    st.toast("All files processed and ready.")
    logger.debug("All files are active.")