# tabel_ocr / predict.py
# Source: Hugging Face Space by Raj-Master ("Create predict.py", commit 4635598)
import os
import tempfile
import random
import string
from ultralyticsplus import YOLO
import streamlit as st
import numpy as np
import pandas as pd
from process import (
filter_columns,
extract_text_of_col,
prepare_cols,
process_cols,
finalize_data,
)
from file_utils import (
get_img,
save_excel_file,
concat_csv,
convert_pdf_to_image,
filter_color,
plot,
delete_file,
)
def process_img(
    img,
    page_enumeration: int = 0,
    filter=False,
    foldername: str = "",
    filename: str = "",
):
    """Detect tables in *img*, split each into columns, OCR the text, and
    save one spreadsheet per extracted table via ``save_excel_file``.

    Args:
        img: Page image as a numpy array.
        page_enumeration: 1-based page number when the source is a PDF,
            0 for a single uploaded image; forwarded to ``finalize_data``
            and ``save_excel_file``.
        filter: When True, pass each column crop through ``filter_color``
            (keeps only black-ish pixels) before OCR.  Shadows the builtin,
            but the name is part of the public interface and is kept.
        foldername: Directory where the per-table result files are written.
        filename: Base name for the saved result files.
    """
    # * detect table bounding boxes on the whole page
    tables = PaddleOCR.table_model(img, conf=0.75)
    tables = tables[0].boxes.xyxy.cpu().numpy()
    results = []
    for table in tables:
        try:
            # * crop the table as an image from the original image
            sub_img = img[
                int(table[1].item()): int(table[3].item()),
                int(table[0].item()): int(table[2].item()),
            ]
            columns_detect = PaddleOCR.column_model(sub_img, conf=0.75)
            cols_data = columns_detect[0].boxes.data.cpu().numpy()
            # * Sort columns according to the x coordinate
            cols_data = np.array(
                sorted(cols_data, key=lambda x: x[0]), dtype=np.ndarray
            )
            # * merge the duplicated columns
            cols_data = filter_columns(cols_data)
            st.image(plot(sub_img, cols_data), channels="RGB")
        except Exception:
            st.warning("No Detection")
            # FIX: the original fell through here, so the extraction step
            # below read `cols_data`/`sub_img` that were either undefined
            # (NameError on the first table) or stale from the previous
            # iteration — both silently masked by the bare except below.
            continue
        try:
            columns = cols_data[:, 0:4]
            sub_imgs = []
            for column in columns:
                # * Create list of cropped images for each column
                sub_imgs.append(sub_img[:, int(column[0]): int(column[2])])
            # FIX: guard against zero detected columns (ZeroDivisionError
            # in the threshold average below).
            if not sub_imgs:
                st.warning("Text Extraction Failed")
                continue
            cols = []
            thr = 0
            for image in sub_imgs:
                if filter:
                    # * keep only black color in the image
                    image = filter_color(image)
                # * extract text of each column and get the length threshold
                res, threshold = extract_text_of_col(image)
                thr += threshold
                # * arrange the rows of each column w.r.t. row length threshold
                cols.append(prepare_cols(res, threshold * 0.6))
            thr = thr / len(sub_imgs)
            # * append each element in each column to its place in the dataframe
            data = process_cols(cols, thr * 0.6)
            # * merge the related rows together
            data: pd.DataFrame = finalize_data(data, page_enumeration)
            results.append(data)
            print("data : ", data)
            print("results : ", results)
        except Exception:
            st.warning("Text Extraction Failed")
            continue
    # * persist every extracted table (index, DataFrame) to its own file
    list(
        map(
            lambda x: save_excel_file(
                *x,
                foldername,
                filename,
                page_enumeration,
            ),
            enumerate(results),
        )
    )
class PaddleOCR:
    """Streamlit entry point: run table/column detection + OCR over an
    uploaded image or PDF and offer the merged CSV for download."""

    # Load detection models once at class-definition time so every call
    # (and process_img, which references them as class attributes) shares
    # the same weights.
    table_model = YOLO("table.pt")
    column_model = YOLO("columns.pt")

    def __call__(self, uploaded, filter=False):
        """Process an uploaded file end to end.

        Args:
            uploaded: Streamlit ``UploadedFile`` (has ``.name`` and
                ``.read()``).  PDFs are rendered page by page; anything
                else is treated as a single image.
            filter: Forwarded to ``process_img`` — keep only black pixels
                before OCR when True.
        """
        foldername = tempfile.TemporaryDirectory(dir=os.getcwd())
        # FIX: use splitext so names containing extra dots (e.g.
        # "scan.v2.pdf") still resolve to the right stem/extension;
        # the original split(".")[0]/[1] mis-detected such PDFs as images.
        stem, ext = os.path.splitext(uploaded.name)
        filename = stem
        if ext.lstrip(".").lower() == "pdf":
            pdf_pages = convert_pdf_to_image(uploaded.read())
            for page_enumeration, page in enumerate(pdf_pages, start=1):
                process_img(
                    np.asarray(page),
                    page_enumeration,
                    filter=filter,
                    foldername=foldername.name,
                    filename=filename,
                )
        else:
            img = get_img(uploaded)
            process_img(
                img,
                filter=filter,
                foldername=foldername.name,
                filename=filename,
            )
        # * concatenate all csv files if many; random suffix avoids
        # collisions between concurrent/repeated runs
        extra = "".join(random.choices(string.ascii_uppercase, k=5))
        filename = f"{filename}_{extra}.csv"
        try:
            # NOTE(review): process_img received foldername.name but
            # concat_csv gets the TemporaryDirectory object itself —
            # verify concat_csv expects the object and not the path.
            concat_csv(foldername, filename)
        except Exception:
            st.warning("No results found")
        foldername.cleanup()
        if os.path.exists(filename):
            # FIX: open the same file the exists() guard checked; the
            # original opened a different hard-coded path, so the download
            # button could serve the wrong (or no) file.
            with open(filename, "rb") as fp:
                st.download_button(
                    label="Download CSV file",
                    data=fp,
                    file_name=filename,
                    mime="text/csv",
                )
            delete_file(filename)
        else:
            st.warning("No results found")