import gradio as gr
import PyPDF2
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from deep_translator import GoogleTranslator
import logging
import time
import os

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

LANGUAGE_MAPPING = {
    "hi": {
        "name": "Hindi - हिन्दी",
        "description": "Official language of India, written in Devanagari script",
        "deep_translator_code": "hi"
    },
    "ta": {
        "name": "Tamil - தமிழ்",
        "description": "Classical language of Tamil Nadu, written in Tamil script",
        "deep_translator_code": "ta"
    },
    "te": {
        "name": "Telugu - తెలుగు",
        "description": "Official language of Andhra Pradesh and Telangana",
        "deep_translator_code": "te"
    },
    "bn": {
        "name": "Bengali - বাংলা",
        "description": "Official language of West Bengal and Bangladesh",
        "deep_translator_code": "bn"
    },
    "mr": {
        "name": "Marathi - मराठी",
        "description": "Official language of Maharashtra",
        "deep_translator_code": "mr"
    }
}
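# Extension point: any target language GoogleTranslator supports can be added here.
# Hypothetical example entry (the "kn" code is an assumption to verify):
#   "kn": {"name": "Kannada - ಕನ್ನಡ", "description": "Official language of Karnataka",
#          "deep_translator_code": "kn"}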


class PDFQueryTranslator:
    def __init__(self, max_retries=3, retry_delay=1):
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.setup_device()
        self.setup_model()
        logger.info(f"Initialization complete. Using device: {self.device}")

    def setup_device(self):
        """Set up the CUDA device, falling back to CPU on any error."""
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"Total CUDA memory: {torch.cuda.get_device_properties(0).total_memory} bytes")
        except Exception as e:
            logger.warning(f"Error setting up CUDA device: {e}. Falling back to CPU.")
            self.device = torch.device("cpu")

    def setup_model(self):
        """Initialize the model and tokenizer with a retry mechanism."""
        for attempt in range(self.max_retries):
            try:
                model_name = "facebook/opt-125m"
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    # fp16 halves GPU memory use; fp32 is the safe default on CPU
                    torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
                )
                self.model = self.model.to(self.device)
                if self.device.type == "cuda":
                    torch.cuda.empty_cache()
                logger.info(f"Model loaded successfully on {self.device}")
                break
            except Exception as e:
                logger.error(f"Attempt {attempt + 1} failed: {e}")
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)
                else:
                    raise RuntimeError("Failed to load model after maximum retries") from e
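
    # Note: facebook/opt-125m is a tiny base LM picked for its footprint, not answer
    # quality; substituting a larger instruction-tuned causal LM checkpoint here
    # (any Hugging Face model id compatible with AutoModelForCausalLM) should give
    # noticeably better answers.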

    def extract_text_from_pdf(self, pdf_file: str) -> str:
        """Extract text from a PDF with per-page error handling."""
        try:
            if not os.path.exists(pdf_file):
                raise FileNotFoundError(f"PDF file not found: {pdf_file}")

            pdf_reader = PyPDF2.PdfReader(pdf_file)
            text = []
            for page_num, page in enumerate(pdf_reader.pages):
                try:
                    # extract_text() can return None for pages without a text layer
                    text.append(page.extract_text() or "")
                except Exception as e:
                    logger.error(f"Error extracting text from page {page_num}: {e}")
                    text.append(f"[Error extracting page {page_num}]")
            return "\n".join(text)
        except Exception as e:
            logger.error(f"Error processing PDF: {e}")
            return f"Error processing PDF: {e}"

    def translate_text(self, text: str, target_lang: str) -> str:
        """Translate text using deep_translator with a retry mechanism."""
        for attempt in range(self.max_retries):
            try:
                translator = GoogleTranslator(source='auto', target=target_lang)

                # Chunk the text to stay under Google Translate's ~5000-character limit
                max_chunk_size = 4500
                chunks = [text[i:i + max_chunk_size] for i in range(0, len(text), max_chunk_size)]

                translated_chunks = []
                for chunk in chunks:
                    translated_chunks.append(translator.translate(chunk))
                    time.sleep(0.5)  # crude rate limiting between requests

                return ' '.join(translated_chunks)
            except Exception as e:
                logger.error(f"Translation attempt {attempt + 1} failed: {e}")
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)
                else:
                    return f"Translation error: {e}"

    def process_query(self, pdf_file: str, query: str, language: str) -> str:
        """Answer a query about a PDF and translate the answer."""
        try:
            if not pdf_file or not os.path.exists(pdf_file):
                return "Please provide a valid PDF file."
            if not query.strip():
                return "Please provide a valid query."
            if language not in LANGUAGE_MAPPING:
                return "Please select a valid language."

            pdf_text = self.extract_text_from_pdf(pdf_file)
            if pdf_text.startswith("Error"):
                return pdf_text

            # Truncate the document so the prompt stays within the model's context window
            prompt = f"Query: {query}\n\nContent: {pdf_text[:1000]}\n\nAnswer:"
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
            with torch.no_grad():
                output = self.model.generate(
                    input_ids,
                    max_new_tokens=200,  # cap the answer length, not prompt + answer
                    num_return_sequences=1,
                    do_sample=True,  # temperature is ignored without sampling
                    temperature=0.7,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens, not the echoed prompt
            response = self.tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)

            target_lang = LANGUAGE_MAPPING[language]["deep_translator_code"]
            return self.translate_text(response, target_lang)
        except Exception as e:
            logger.error(f"Error in process_query: {e}")
            return f"An error occurred: {e}"


def create_interface():
    pdf_processor = PDFQueryTranslator()

    with gr.Blocks() as demo:
        gr.Markdown("### PDF Query and Translation System")

        with gr.Row():
            with gr.Column():
                pdf_input = gr.File(
                    label="Upload PDF Document",
                    type="filepath"
                )
                query_input = gr.Textbox(
                    label="Enter your question about the PDF",
                    placeholder="What would you like to know about the document?"
                )
                language_input = gr.Dropdown(
                    label="Select Output Language",
                    choices=[f"{code} - {info['name']}" for code, info in LANGUAGE_MAPPING.items()],
                    value="hi - Hindi - हिन्दी"
                )
                language_description = gr.Textbox(
                    label="Language Information",
                    value=LANGUAGE_MAPPING['hi']['description'],
                    interactive=False
                )

        with gr.Row():
            output_text = gr.Textbox(
                label="Translated Answer",
                placeholder="Translation will appear here...",
                lines=5
            )

        def update_description(selected):
            # Dropdown labels look like "hi - Hindi - हिन्दी"; the code is the first field
            code = selected.split(" - ")[0]
            return LANGUAGE_MAPPING[code]['description']

        def process_and_translate(pdf_file, query, language):
            try:
                lang_code = language.split(" - ")[0]
                return pdf_processor.process_query(pdf_file, query, lang_code)
            except Exception as e:
                return f"Error processing request: {e}"

        language_input.change(
            fn=update_description,
            inputs=[language_input],
            outputs=[language_description]
        )

        submit_button = gr.Button("Get Answer")
        submit_button.click(
            fn=process_and_translate,
            inputs=[pdf_input, query_input, language_input],
            outputs=output_text
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.queue()
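    # share=True also requests a temporary public gradio.live link; set it to
    # False to keep the demo local-only.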
    demo.launch(share=True)