|
import ast
import json
import os
import re
from typing import Dict, Iterable, List, Optional

import pandas as pd
from langchain.prompts import PromptTemplate
from langchain_google_genai import GoogleGenerativeAI
from pypdf import PdfReader
|
|
|
|
|
# Google Generative AI key, read from the environment instead of being committed
# to source control (None when GOOGLE_API_KEY is unset).
# NOTE(review): the key that used to be hard-coded on this line must be treated
# as leaked and revoked/rotated.
api_key = os.environ.get("GOOGLE_API_KEY")
|
|
|
|
|
class InvoicePipeline:
    """Extract structured invoice fields from PDF files with a Gemini LLM.

    For each path, the PDF's text is extracted with ``pypdf``, inserted into an
    extraction prompt, and sent to ``GoogleGenerativeAI``. The dict parsed from
    the reply becomes one row of the DataFrame returned by :meth:`run`.
    """

    def __init__(
        self,
        paths: Iterable[str],
        *,
        model: str = "gemini-1.0-pro",
        google_api_key: Optional[str] = None,
    ) -> None:
        """Configure the pipeline.

        Args:
            paths: PDF file paths to process, in order. It is iterated on every
                call to :meth:`run`, so pass a re-iterable (e.g. a list) if
                ``run`` will be called more than once.
            model: Gemini model name passed to ``GoogleGenerativeAI``.
            google_api_key: API key passed to ``GoogleGenerativeAI``. When
                ``None``, the module-level ``api_key`` is used instead.
        """
        self._paths = paths
        self._llm = GoogleGenerativeAI(
            model=model,
            google_api_key=google_api_key if google_api_key is not None else api_key,
        )
        self._prompt_template = self._get_default_prompt_template()

    def run(self) -> pd.DataFrame:
        """Process every PDF in ``paths`` and return one row per invoice.

        Returns:
            A DataFrame whose leading columns are the expected invoice fields.
            Any extra keys the LLM returns are appended as additional columns.
            When ``paths`` is empty, the empty, typed schema frame is returned.

        Raises:
            ValueError: if an LLM reply cannot be parsed into a dict. Errors
                from PDF reading or the LLM call propagate unchanged.
        """
        # Schema frame that fixes the leading column order. The prompt asks for
        # string values, so the int-typed columns become object dtype once rows
        # are added.
        df = pd.DataFrame({
            "Invoice ID": pd.Series(dtype="int"),
            "DESCRIPTION": pd.Series(dtype="str"),
            "Issue Data": pd.Series(dtype="str"),
            "UNIT PRICE": pd.Series(dtype="str"),
            "AMOUNT": pd.Series(dtype="int"),
            "Bill For": pd.Series(dtype="str"),
            "From": pd.Series(dtype="str"),
            "Terms": pd.Series(dtype="str"),
        })

        rows = []
        for path in self._paths:
            raw_text = self._get_raw_text_from_pdf(path)
            llm_resp = self._extract_data_from_llm(raw_text)
            rows.append(self._parse_response(llm_resp))

        if not rows:
            return df
        # Concatenate once. Calling pd.concat inside the loop would copy the
        # accumulated frame on every iteration (quadratic in the number of PDFs).
        return pd.concat([df, pd.DataFrame(rows)], ignore_index=True)

    def _get_default_prompt_template(self) -> "PromptTemplate":
        """Build the extraction prompt; ``{pages}`` receives the PDF text.

        The keys in the example output mirror the DataFrame columns built in
        :meth:`run` so the parsed reply lines up with them. ``Issue Data`` is
        spelled exactly as the column is. Doubled braces are literal braces in
        ``PromptTemplate``'s default f-string template format.
        """
        template = (
            "Extract all the following values: Invoice ID, DESCRIPTION, Issue Data, "
            "UNIT PRICE, AMOUNT, Bill For, From and Terms for: {pages}\n"
            "\n"
            "Expected Outcome: remove any dollar symbols "
            '{{"Invoice ID": "12341234", "DESCRIPTION": "Website design", '
            '"Issue Data": "2/1/2021", "UNIT PRICE": "3", "AMOUNT": "100", '
            '"Bill For": "Dev", "From": "Coca Cola", "Terms": "Net for 30 days"}}\n'
        )
        return PromptTemplate(input_variables=["pages"], template=template)

    def _get_raw_text_from_pdf(self, path: str) -> str:
        """Return the extracted text of every page of the PDF at ``path``.

        Pages are joined with a newline so the last word of one page does not
        fuse with the first word of the next. A page whose extraction yields no
        text (empty or None) contributes an empty string.
        """
        pdf_reader = PdfReader(path)
        return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)

    def _extract_data_from_llm(self, raw_data: str) -> str:
        """Fill the prompt with ``raw_data`` and return the LLM's text reply."""
        prompt = self._prompt_template.format(pages=raw_data)
        # ``invoke`` is the Runnable entry point. Calling the LLM object
        # directly (``self._llm(prompt)``) is deprecated in LangChain.
        return self._llm.invoke(prompt)

    def _parse_response(self, response: str) -> Dict[str, str]:
        """Parse the first-``{``-to-last-``}`` span of ``response`` into a dict.

        The span is parsed first as a Python literal (``ast.literal_eval``).
        If that fails, it is parsed as JSON, which also accepts
        ``null``/``true``/``false``. Neither parser executes code. That matters
        because the reply is derived from untrusted PDF text, so ``eval`` must
        not be used here.

        Raises:
            ValueError: if no non-empty ``{...}`` span is found, the span cannot
                be parsed, or it does not evaluate to a dict.
        """
        # Greedy + DOTALL: spans from the first "{" to the last "}" across lines.
        re_match = re.search(r"\{.+\}", response, re.DOTALL)
        if re_match is None:
            raise ValueError("No match found.")

        payload = re_match.group(0)
        try:
            data = ast.literal_eval(payload)
        except (ValueError, SyntaxError, TypeError):
            try:
                data = json.loads(payload)
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"Could not parse LLM response as a dict: {payload!r}"
                ) from err

        if not isinstance(data, dict):
            raise ValueError(f"LLM response is not a dict: {payload!r}")
        return data
|
|
|
|
|
|
|
|