# llm_utils.py
import os
import time
import asyncio
import json
import logging
import streamlit as st
from google import genai
from google.genai import types
# Module-level logger; handlers/level are expected to be configured by the host app.
logger = logging.getLogger(__name__)
# Initialize the Gemini client using the new SDK
# NOTE(review): if GEMINI_API_KEY is unset, api_key=None is passed here —
# confirm the SDK's failure mode is acceptable at import time.
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
def get_generation_model(model_type: str):
    """Return a (model_name, generation_config) pair for the requested type.

    "flash" selects the standard flash model; any other value selects the
    experimental "thinking" model. Both share the same generation config.
    """
    chosen = (
        "gemini-2.0-flash"
        if model_type == "flash"
        else "gemini-2.0-flash-thinking-exp-01-21"
    )
    config = types.GenerateContentConfig(
        temperature=0.7,
        top_p=0.95,
        top_k=64,
        max_output_tokens=65536,
        response_mime_type="text/plain",
    )
    return chosen, config
async def async_generate_text(prompt, pdf_file=None, model_name=None,
                              generation_config=None, max_retries=None,
                              retry_delay=30):
    """Call the Gemini API asynchronously and return the generated text.

    Args:
        prompt: Text prompt to send to the model.
        pdf_file: Optional uploaded file object; when given it is sent
            alongside the prompt as multi-part contents.
        model_name: Name of the Gemini model to call.
        generation_config: A types.GenerateContentConfig for the request.
        max_retries: Maximum number of retries after a failed call.
            None (the default) retries forever — the original behavior.
        retry_delay: Seconds to sleep between retries (default 30, as before).

    Returns:
        The response text (may be None if the model produced no text).

    Raises:
        Exception: re-raises the last API error once max_retries is exhausted.
    """
    contents = [pdf_file, prompt] if pdf_file else prompt
    attempt = 0
    while True:
        try:
            st.toast("Sending prompt to the model...")
            response = await client.aio.models.generate_content(
                model=model_name,
                contents=contents,
                config=generation_config,
            )
            st.toast("Received response from the model.")
            # response.text can be None for empty candidates; guard the len().
            logger.debug("Generated text for prompt. Length: %d",
                         len(response.text or ""))
            return response.text
        except Exception as e:
            attempt += 1
            logger.exception("Error during asynchronous LLM API call:")
            st.toast("Error during asynchronous LLM API call: " + str(e))
            # Previously this retried unconditionally forever; a caller can now
            # bound the retries while the default keeps the old behavior.
            if max_retries is not None and attempt > max_retries:
                raise
            await asyncio.sleep(retry_delay)
def clean_json_response(response_text: str) -> str:
    """Strip a Markdown code fence (``` ... ```) from an LLM response.

    When the trimmed response opens with a fence line (e.g. ``` or ```json),
    that line and a matching closing ``` line are removed and the remaining
    text is returned trimmed. Responses without an opening fence are returned
    unchanged, including any surrounding whitespace.
    """
    trimmed = response_text.strip()
    if not trimmed.startswith("```"):
        return response_text
    body = trimmed.splitlines()
    # Drop the opening fence line (it may carry a language tag like ```json).
    if body and body[0].strip().startswith("```"):
        body = body[1:]
    # Drop the closing fence line, if one is present.
    if body and body[-1].strip() == "```":
        body = body[:-1]
    return "\n".join(body).strip()
class TitleReference:
    """Result of analyzing an uploaded paper for title/reference/classification.

    Attributes:
        title: Generated title, or None when absent or on error.
        apa_reference: APA-formatted reference, or None.
        classification: e.g. "Good academic paper" / "Not a valid academic paper".
        bullet_list: Bullet strings (context, method, theory, main findings).
        error: Error message when the document is not a valid paper, else None.
    """

    def __init__(self, title=None, apa_reference=None, classification=None,
                 bullet_list=None, error=None):
        self.title = title
        self.apa_reference = apa_reference
        self.classification = classification
        # A falsy bullet_list (None or []) normalizes to a fresh empty list,
        # so every instance owns its own list object.
        self.bullet_list = bullet_list or []
        self.error = error

    def __repr__(self):
        # Debug-friendly representation; added for log readability.
        return (f"{type(self).__name__}(title={self.title!r}, "
                f"classification={self.classification!r}, error={self.error!r})")
async def generate_title_reference_and_classification(pdf_file, title_model_name, title_generation_config):
    """Ask the model to title, reference, and classify an uploaded paper.

    Sends the uploaded PDF plus a fixed prompt to the model and parses the
    strictly-JSON reply into a TitleReference.

    Args:
        pdf_file: Uploaded file object to analyze.
        title_model_name: Gemini model name to use.
        title_generation_config: Generation config for the request.

    Returns:
        TitleReference: populated on success, or with only `error` set when
        the model classified the document as invalid.

    Raises:
        Exception: if the response is not valid JSON or lacks required keys.
    """
    title_prompt = (
        "Analyze the uploaded document and determine if it is a valid academic article. "
        "If it is a valid academic article, generate a concise and engaging title, an APA formatted reference, and classify the paper as 'Good academic paper'. "
        "Also, generate a bullet list for the following items: context, method, theory, main findings. "
        "If it is not a valid academic article (for example, if it is too short or just a title page), "
        "classify it as 'Not a valid academic paper' and return an 'error' key with an appropriate message. "
        "Output the result strictly in JSON format with keys 'title', 'apa_reference', 'classification', and 'bullet_list'. "
        "The 'bullet_list' value should be an array of strings. Do not include any extra commentary."
    )
    response_text = await async_generate_text(
        title_prompt,
        pdf_file,
        model_name=title_model_name,
        generation_config=title_generation_config
    )
    logger.debug("Title/Reference generation response: %s", response_text)
    cleaned_response = clean_json_response(response_text)
    logger.debug("Cleaned Title/Reference JSON: %s", cleaned_response)
    try:
        data = json.loads(cleaned_response)
    except ValueError as e:
        # json.loads failures are ValueError (JSONDecodeError); chain the
        # cause so the original traceback is preserved for debugging.
        logger.exception("Invalid JSON returned: %s", e)
        raise Exception("Invalid JSON returned: " + str(e)) from e
    if "error" in data:
        # Model judged the document invalid; surface only the error message.
        return TitleReference(error=data["error"])
    required_keys = ("title", "apa_reference", "classification", "bullet_list")
    if any(key not in data for key in required_keys):
        raise Exception("Expected keys 'title', 'apa_reference', 'classification', and 'bullet_list' not found in response.")
    return TitleReference(
        title=data["title"],
        apa_reference=data["apa_reference"],
        classification=data["classification"],
        bullet_list=data["bullet_list"]
    )
# Add these functions so they can be imported elsewhere
def upload_to_gemini(file_path, mime_type=None):
    """Upload a local file to the Gemini Files API and return the file object.

    Args:
        file_path: Path to the file on disk.
        mime_type: Optional MIME type. The original implementation accepted
            but silently ignored this parameter; it is now forwarded to the
            API when provided.

    Returns:
        The uploaded file object (exposes .name, .uri, .display_name, .state).
    """
    # Reuse the module-level client instead of constructing a new one per call.
    if mime_type is not None:
        # NOTE(review): the google-genai SDK accepts a dict for `config`
        # (UploadFileConfig fields) — confirm against the installed version.
        file = client.files.upload(file=file_path, config={"mime_type": mime_type})
    else:
        file = client.files.upload(file=file_path)
    st.toast(f"Uploaded file '{file.display_name}' as: {file.uri}")
    logger.debug("Uploaded file: %s with URI: %s", file.display_name, file.uri)
    return file
def wait_for_files_active(files, poll_interval=10, timeout=None):
    """Block until every uploaded file leaves the PROCESSING state.

    Polls the Files API for each file until its state is no longer
    PROCESSING, then verifies it reached ACTIVE.

    Args:
        files: Iterable of file objects returned by the Files API.
        poll_interval: Seconds between state polls (default 10, as before).
        timeout: Optional cap in seconds per file; None (the default) waits
            indefinitely — the original behavior.

    Raises:
        Exception: if a file ends in any state other than ACTIVE, or the
            timeout elapses while it is still PROCESSING.
    """
    st.toast("Waiting for file processing...")
    for file in files:
        # Reuse the module-level client instead of building a new one per call.
        deadline = None if timeout is None else time.monotonic() + timeout
        current_file = client.files.get(name=file.name)
        logger.debug("Initial state for file %s: %s", file.name, current_file.state.name)
        while current_file.state.name == "PROCESSING":
            if deadline is not None and time.monotonic() > deadline:
                raise Exception(f"Timed out waiting for file {file.name} to become ACTIVE")
            time.sleep(poll_interval)
            current_file = client.files.get(name=file.name)
            logger.debug("Polling file %s, state: %s", file.name, current_file.state.name)
        if current_file.state.name != "ACTIVE":
            error_msg = f"File {current_file.name} failed to process, state: {current_file.state.name}"
            logger.error(error_msg)
            raise Exception(error_msg)
    st.toast("All files processed and ready.")
    logger.debug("All files are active.")