# breeztest / app.py
# simonzhang5429's picture
# Update app.py
# f669dd9 verified
import gradio as gr
import os
import shutil
from pypdf import PdfReader
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import fitz
# Hub repository used for both the tokenizer and the model weights.
TOKENIZER_REPO = "MediaTek-Research/Breeze-7B-Instruct-v1_0"

# Instruction placed in front of every chunk of text sent to the model
# (English: "Please convert the following text to Traditional Chinese:").
tran_hints = "请将以下的文字转为繁体:"

# Sequence start / end markers as literal strings.
start_flag = "<s>"
end_flag = "</s>"

tokenizer = AutoTokenizer.from_pretrained(
    TOKENIZER_REPO,
    use_fast=True,
    local_files_only=False,
)

# Loaded once at import time; every call to the model reuses this instance.
model = AutoModelForCausalLM.from_pretrained(
    TOKENIZER_REPO,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    local_files_only=False,
)
def generate(text):
    """Send one user message to the chat model and return its reply.

    Args:
        text: The user message. Leading and trailing whitespace is ignored.

    Returns:
        The decoded text of the newly generated tokens only; the prompt and
        special tokens are not part of the result. A blank ``text`` yields an
        empty string without calling the model.
    """
    text = text.strip()
    if not text:
        # Nothing to ask: do not build an empty conversation.
        return ""

    chat_data = [{"role": "user", "content": text}]
    prompt_ids = tokenizer.apply_chat_template(chat_data, return_tensors="pt")
    # The model is loaded with device_map="auto", so its weights may not be on
    # the CPU; the prompt tensor has to live on the same device.
    prompt_ids = prompt_ids.to(model.device)

    # Deterministic (greedy) decoding is requested explicitly. A temperature
    # of 0 is not a valid sampling temperature, and top_p / top_k only take
    # effect when sampling, so they are not passed.
    outputs = model.generate(
        prompt_ids,
        max_new_tokens=2048,
        do_sample=False,
        repetition_penalty=1.1,
    )

    # The returned sequence starts with the prompt; keep only what was added.
    new_tokens = outputs[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def tran_txt(input_txt):
    """Ask the model to convert ``input_txt`` and return its reply.

    The prompt is the module-level ``tran_hints`` instruction on the first
    line, followed by the whitespace-trimmed input on the next line. The
    reply from ``generate`` is echoed to stdout and returned unmodified.
    """
    prompt = f"{tran_hints}\n{input_txt.strip()}"
    reply = generate(prompt)
    print(f"tran_result={reply}")
    return reply
def exec_tran(file):
    """Translate an uploaded PDF and write the result to a text file.

    The upload is copied into the local data folder, split into text chunks
    with ``read_paragraphs``, and every chunk is passed through ``tran_txt``.
    One result is written per line.

    Args:
        file: The uploaded file: either its path as a string, or an object
            that exposes the path through a ``name`` attribute.

    Returns:
        Path of the result file. It is the upload path with a trailing
        ``.pdf`` extension (any letter case) replaced by ``_result.txt``;
        for any other name ``_result.txt`` is simply appended.
    """
    source_path = file if isinstance(file, str) else file.name

    local_copy = upload_file(source_path)
    page_texts = read_paragraphs(local_copy)

    # splitext only looks at the real extension, so a ".pdf" elsewhere in the
    # path is left alone and a missing extension does not raise.
    stem, extension = os.path.splitext(source_path)
    if extension.lower() != ".pdf":
        stem = source_path
    result_path = stem + "_result.txt"

    # The output is Chinese text; do not depend on the platform's default
    # encoding.
    with open(result_path, "w", encoding="utf-8") as fw:
        for page_content in page_texts:
            fw.write(tran_txt(page_content) + "\n")
    return result_path
def upload_file(file):
    """Copy an uploaded file into the local ``./data`` folder.

    Args:
        file: Path of the file to copy (``str`` or ``os.PathLike``), or an
            object that exposes that path through a ``name`` attribute.

    Returns:
        The path of the copy, as returned by ``shutil.copy``.
    """
    upload_folder = "./data"
    # exist_ok avoids the gap between a separate existence check and the
    # directory creation.
    os.makedirs(upload_folder, exist_ok=True)

    if isinstance(file, (str, os.PathLike)):
        source = file
    else:
        source = file.name
    return shutil.copy(source, upload_folder)
def read_paragraphs(pdf_path):
    """Extract the text of a PDF as a flat list of sentence-sized chunks.

    The plain text of each page is split at the Chinese full stop ("。").
    A chunk that was terminated by a full stop keeps it; chunks that contain
    only whitespace are dropped.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        A list of text chunks in document order.
    """
    paragraphs = []
    # The context manager closes the document even when extraction fails.
    with fitz.open(pdf_path) as document:
        for page in document:
            # "text" is the plain-text extraction mode.
            *terminated, remainder = page.get_text("text").split("。")
            # str.split discards the separator, so put the full stop back on
            # every chunk that actually ended with one.
            paragraphs.extend(
                chunk + "。" for chunk in terminated if chunk.strip()
            )
            if remainder.strip():
                paragraphs.append(remainder)
    return paragraphs
def load_pdf_pages(filename):
    """Return the extracted text of every page of ``filename``, in page order."""
    return [page.extract_text() for page in PdfReader(filename).pages]
def exec_translate(file):
    """Copy the upload into the data folder and extract its page texts.

    ``file`` must expose its path as ``file.name``. The extracted texts are
    not used any further and nothing is returned.
    """
    upload_file(file)
    load_pdf_pages(file.name)
# Minimal UI: an upload button for the PDF and a file slot that receives the
# path returned by exec_tran.
with gr.Blocks() as app:
    file_output = gr.File()
    upload_button = gr.UploadButton(
        "上传pdf文件",
        # Extensions are given with their leading dot.
        file_types=[".pdf"],
        file_count="single",
    )
    upload_button.upload(exec_tran, upload_button, file_output)

app.launch()