# dtcda / app.py (author: zmbfeng)
# Last commit 32d373f: "move single sentence query to bottom and optional"
import streamlit as st
import os
import fitz
import re
from transformers import AutoModelForSequenceClassification, BertTokenizer, BertModel, \
AutoTokenizer
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
def is_new_txt_file_upload(uploaded_txt_file):
    """Tell whether *uploaded_txt_file* differs from the previously seen secondary .txt upload.

    Files are identified by their (name, size) pair, remembered in
    ``st.session_state.last_uploaded_txt_file``. The remembered pair is replaced
    whenever a different file (or the very first file) is seen.

    Returns:
        True for a first or changed upload, False when the same file is seen again.
    """
    state_key = 'last_uploaded_txt_file'
    fingerprint = {'name': uploaded_txt_file.name, 'size': uploaded_txt_file.size}
    if state_key in st.session_state:
        previous = st.session_state[state_key]
        if previous['name'] == fingerprint['name'] and previous['size'] == fingerprint['size']:
            # Same file re-submitted: keep the stored fingerprint untouched.
            return False
    st.session_state[state_key] = fingerprint
    return True
def is_new_file_upload(uploaded_file):
    """Tell whether *uploaded_file* (the master PDF) differs from the last one seen.

    The (name, size) pair of the most recent distinct upload is kept in
    ``st.session_state.last_uploaded_file``. It is overwritten only when a new
    file is detected.

    Returns:
        True for a first or changed upload, False for a repeat of the same file.
    """
    fingerprint = {'name': uploaded_file.name, 'size': uploaded_file.size}
    seen_before = (
        'last_uploaded_file' in st.session_state
        and (st.session_state.last_uploaded_file['name'],
             st.session_state.last_uploaded_file['size']) == (fingerprint['name'], fingerprint['size'])
    )
    if seen_before:
        return False
    st.session_state.last_uploaded_file = fingerprint
    return True
def add_commonality_to_similarity_score(similarity, sentence_to_find_similarity_score, query_to_find_similiarty_score,
                                        stop_words=None):
    """Combine an embedding similarity with a word-overlap ("commonality") bonus.

    Both texts are split on whitespace and words whose lower-cased form is a stop
    word are dropped. The commonality is the fraction of distinct remaining query
    words that also occur in the sentence. The match is exact and case-sensitive.

    Args:
        similarity: similarity between the sentence and query embeddings.
        sentence_to_find_similarity_score: candidate sentence from the master document.
        query_to_find_similiarty_score: query sentence.
        stop_words: collection of lower-case stop words to ignore. Defaults to
            ``st.session_state.stop_words``.

    Returns:
        ``(combined_score, similarity, commonality)``. ``commonality`` lies in
        [0, 1] and ``combined_score = similarity + commonality``. The combined
        score is therefore NOT bounded to [-1, 1]: for a cosine similarity in
        [-1, 1] it can reach 2.
    """
    if stop_words is None:
        stop_words = st.session_state.stop_words
    sentence_words = {word for word in sentence_to_find_similarity_score.split() if word.lower() not in stop_words}
    query_words = {word for word in query_to_find_similiarty_score.split() if word.lower() not in stop_words}
    # Normalise by the query vocabulary size; max(..., 1) guards an all-stop-word or empty query.
    commonality = len(sentence_words & query_words) / max(len(query_words), 1)
    return similarity + commonality, similarity, commonality
def contradiction_detection(premise, hypothesis):
    """Classify a (premise, hypothesis) pair with the MNLI model stored in session state.

    Uses ``st.session_state.roberta_tokenizer`` / ``st.session_state.roberta_model``.

    Returns:
        A one-element set holding the predicted label: ``{"Contradiction"}``,
        ``{"Neutral"}`` or ``{"Entailment"}``. The set form is kept because
        callers compare the result against ``{"Contradiction"}``.
    """
    inputs = st.session_state.roberta_tokenizer.encode_plus(premise, hypothesis, return_tensors="pt", truncation=True)
    # Inference only: no_grad avoids building (and holding) an autograd graph on every call.
    with torch.no_grad():
        outputs = st.session_state.roberta_model(**inputs)
    # Logits are raw per-class scores; softmax turns them into probabilities.
    probabilities = torch.softmax(outputs.logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    # Index order of the labels below: 0 = contradiction, 1 = neutral, 2 = entailment
    # (the id2label order of roberta-large-mnli's config).
    labels = ["Contradiction", "Neutral", "Entailment"]
    print(f"Prediction: {labels[predicted_class]}")
    return {labels[predicted_class]}
if 'is_initialized' not in st.session_state:
    # One-time setup per Streamlit session (guarded by a session_state flag):
    # NLTK data, the English stop-word set, and the BERT / RoBERTa-MNLI models.
    nltk.download('punkt')
    nltk.download('punkt_tab')
    nltk.download('stopwords')
    st.session_state.stop_words = set(stopwords.words('english'))
    # Fall back to CPU instead of crashing on machines without a CUDA device.
    bert_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    st.session_state.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased").to(bert_device)
    st.session_state.roberta_tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
    st.session_state.roberta_model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli")
    # Set the flag only after everything loaded. If a download or model load fails,
    # the next rerun retries instead of continuing with a half-initialised session.
    st.session_state['is_initialized'] = True
def encode_sentence(sentence_to_be_encoded):
    """Return the BERT first-token ([CLS]) embedding of a sentence, or None for very short input.

    Strings with fewer than 4 characters after stripping are skipped and yield
    None, so callers must handle None. Otherwise the result is
    ``last_hidden_state[:, 0, :]`` for a batch of one sentence, moved to the CPU
    as a NumPy array with a single row.
    """
    if len(sentence_to_be_encoded.strip()) < 4:
        return None
    bert_model = st.session_state.bert_model
    # Send the inputs to whatever device the model lives on instead of assuming CUDA.
    model_device = next(bert_model.parameters()).device
    sentence_tokens = st.session_state.bert_tokenizer(
        sentence_to_be_encoded, return_tensors="pt", padding=True, truncation=True).to(model_device)
    with torch.no_grad():
        sentence_encoding = bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
    return sentence_encoding
def encode_paragraph(paragraph_to_be_encoded):
    """Split a paragraph into sentences and embed each one.

    Returns:
        A list of ``[sentence_text, encoding]`` pairs in sentence order.
        ``encoding`` is whatever ``encode_sentence`` returns, which is None for
        very short sentences.
    """
    # PDF blocks wrap lines with "\n". Join them with a space so the last word of one
    # line is not glued to the first word of the next ("at\nWork" -> "at Work").
    paragraph_without_newline = paragraph_to_be_encoded.replace("\n", " ")
    return [
        [sentence, encode_sentence(sentence)]
        for sentence in sent_tokenize(paragraph_without_newline)
    ]
if 'list_count' in st.session_state:
    st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
    if 'paragraph_sentence_encodings' not in st.session_state:
        # Embed every sentence of every master-document paragraph once, then rerun.
        print("start embedding paragarphs")
        read_progress_bar = st.progress(0)
        encodings = []
        # max(..., 1) avoids a ZeroDivisionError for a document with a single paragraph.
        progress_denominator = max(st.session_state.list_count - 1, 1)
        for index, paragraph in enumerate(st.session_state.restored_paragraphs):
            read_progress_bar.progress(min(index / progress_denominator, 1.0))
            encodings.append([paragraph, encode_paragraph(paragraph['paragraph'])])
        # Store only the complete result. If the run is interrupted part-way, the
        # key stays absent and the embedding is redone on the next run instead of
        # silently using a partial list.
        st.session_state.paragraph_sentence_encodings = encodings
        st.rerun()
big_text = """
<div style='text-align: center;'>
    <h1 style='font-size: 30px;'>Contradiction Detection</h1>
</div>
"""
# Render the centred page title (raw HTML, hence unsafe_allow_html).
st.markdown(big_text, unsafe_allow_html=True)
def convert_pdf_to_paragraph_list(pdf_doc_to_paragraph_list, start_page=1):
    """Split a PDF into paragraph records, one or more per non-empty text block.

    Args:
        pdf_doc_to_paragraph_list: an opened PyMuPDF document, or anything with
            ``len()`` and ``load_page(i)`` whose page ``get_text("blocks")``
            yields ``(x0, y0, x1, y1, text, block_no, block_type)`` tuples.
        start_page: 1-based first page to read (default 1, i.e. the whole document).

    Returns:
        A list of dicts with these keys:
            ``paragraph``: the paragraph text.
            ``containsList``: True when the enclosing block has a line starting
                with ``1.``, ``a.``, ``*`` or ``-``.
            ``page_num``: 0-based page index.
            ``text``: the whole block text when the block was split on blank
                lines, otherwise None.
    """
    # Compiled once; matches a numbered, lettered or bulleted list line anywhere in a block.
    list_pattern = re.compile(r'^\s*((?:\d+\.|[a-zA-Z]\.|[*-])\s+.+)', re.MULTILINE)
    paragraphs = []
    first_page_index = max(start_page - 1, 0)  # 1-based -> 0-based, never negative
    for page_num in range(first_page_index, len(pdf_doc_to_paragraph_list)):
        page = pdf_doc_to_paragraph_list.load_page(page_num)
        for block in page.get_text("blocks"):
            block_text, block_type = block[4], block[6]
            # block_type 1 is an image block whose "text" is image metadata, not document prose.
            if block_type != 0:
                continue
            text = block_text.strip()
            if not text:
                continue
            # Turn whitespace-only lines into plain blank lines so they act as paragraph breaks.
            text = re.sub(r'\n\s+\n', '\n\n', text)
            contains_list = bool(list_pattern.search(text))
            if re.search(r'\n{2,}', text):
                for substring in re.split(r'\n{2,}', text):
                    if substring.strip():
                        paragraphs.append({"paragraph": substring, "containsList": contains_list,
                                           "page_num": page_num, "text": text})
            else:
                paragraphs.append({"paragraph": text, "containsList": contains_list,
                                   "page_num": page_num, "text": None})
    return paragraphs
uploaded_pdf_file = st.file_uploader("Upload a PDF file",
                                     type=['pdf'])
st.markdown(
    '<a href="https://ikmtechnology.github.io/ikmtechnology/Sample_Master_Sample_Life_Insurance_Policy.pdf" target="_blank">Sample Master PDF download and then upload to above</a>',
    unsafe_allow_html=True)
if uploaded_pdf_file is not None:
    if is_new_file_upload(uploaded_pdf_file):
        print("is new file uploaded")
        # Drop everything derived from the previous master document. Forgetting the
        # last secondary .txt fingerprint makes the same .txt file be re-checked
        # against the new master.
        for stale_key in ('prev_query', 'paragraph_sentence_encodings', 'last_uploaded_txt_file'):
            if stale_key in st.session_state:
                del st.session_state[stale_key]
        save_path = './uploaded_files'
        os.makedirs(save_path, exist_ok=True)
        # The file name is client-supplied: keep only its last path component so it
        # cannot escape save_path.
        safe_file_name = os.path.basename(uploaded_pdf_file.name)
        st.session_state.uploaded_path = os.path.join(save_path, safe_file_name)
        with open(st.session_state.uploaded_path, "wb") as f:
            f.write(uploaded_pdf_file.getbuffer())
        st.success(f'Saved file {safe_file_name} in {save_path}')
        doc = fitz.open(st.session_state.uploaded_path)
        try:
            st.session_state.restored_paragraphs = convert_pdf_to_paragraph_list(doc)
        finally:
            doc.close()  # release the PDF file handle even if parsing fails
        if isinstance(st.session_state.restored_paragraphs, list):
            # Count the top-level paragraph records.
            st.session_state.list_count = len(st.session_state.restored_paragraphs)
            st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
        st.rerun()
def contradiction_detection_for_sentence(cd_query, cd_query_line_number, top_paragraph_count=3,
                                         min_commonality=0.25):
    """Compare one secondary-document sentence with the master document and report contradictions.

    Steps:
        1. Embed the sentence and score every master sentence against it with
           ``find_sentences_scores``.
        2. Take the ``top_paragraph_count`` paragraphs with the highest average
           score.
        3. In each paragraph, run the NLI model on the paragraph's top sentences,
           best first. Stop at the first sentence not labelled a contradiction.
        4. Write a contradiction to the page only when that master sentence's
           word-overlap score is at least ``min_commonality``.

    Args:
        cd_query: sentence from the secondary document (the NLI premise).
        cd_query_line_number: line number shown to the user.
        top_paragraph_count: number of best-scoring paragraphs to examine (default 3).
        min_commonality: minimum word-overlap score for a contradiction to be
            reported (default 0.25).
    """
    query_encoding = encode_sentence(cd_query)
    if query_encoding is None:
        # encode_sentence skips very short strings; there is no embedding to compare.
        return
    total_count = len(st.session_state.paragraph_sentence_encodings)
    processing_progress_bar = st.progress(0)
    _, paragraph_scores = find_sentences_scores(
        st.session_state.paragraph_sentence_encodings, query_encoding, cd_query, processing_progress_bar, total_count)
    processing_progress_bar.empty()  # don't leave one progress bar per checked sentence on the page
    sorted_paragraph_scores = sorted(paragraph_scores, key=lambda x: x[0], reverse=True)
    for _avg_score, _avg_commonality, paragraph_info in sorted_paragraph_scores[:top_paragraph_count]:
        for _combined_score, master_sentence, commonality_score in paragraph_info['top_three_sentences']:
            contradiction_detection_result = contradiction_detection(cd_query, master_sentence)
            if contradiction_detection_result != {"Contradiction"}:
                # Sentences are ordered best-first; stop at the first non-contradiction.
                break
            if commonality_score >= min_commonality:
                st.write("master document page number ",
                         paragraph_info['original_text']['page_num'])
                st.write("master document sentence: ", master_sentence)
                st.write("secondary document line number", cd_query_line_number)
                st.write("secondary document sentence: ", cd_query)
                st.write(contradiction_detection_result)
def find_sentences_scores(paragraph_sentence_encodings, query_encoding, query_plain, processing_progress_bar, total_count):
    """Score every master-document sentence against a query sentence.

    Args:
        paragraph_sentence_encodings: list of ``[paragraph_dict, sentence_pairs]``
            entries, where each item of ``sentence_pairs`` is
            ``[sentence_text, encoding_or_None]`` (as built by ``encode_paragraph``).
        query_encoding: query embedding as returned by ``encode_sentence``. When
            it is None, no sentence is scored.
        query_plain: raw query text, used for the word-overlap bonus.
        processing_progress_bar: Streamlit progress element, updated once per paragraph.
        total_count: number of paragraphs, used to scale the progress bar.

    Returns:
        ``(sentence_scores, paragraph_scores)``.
        ``sentence_scores``: ``(combined_score, sentence_text)`` tuples for all
            scored sentences, best first.
        ``paragraph_scores``: one ``(avg_combined, avg_commonality, info)`` tuple
            per paragraph, in input order. The averages are over the paragraph's
            (up to) three best sentences. ``info`` holds ``top_three_sentences``,
            a list of ``(combined, text, commonality)`` tuples, and
            ``original_text``, the paragraph dict.
    """
    paragraph_scores = []
    sentence_scores = []
    # max(..., 1) avoids a ZeroDivisionError when there is a single paragraph.
    progress_denominator = max(total_count - 1, 1)
    for paragraph_index, (paragraph_info, sentence_pairs) in enumerate(paragraph_sentence_encodings):
        processing_progress_bar.progress(min(paragraph_index / progress_denominator, 1.0))
        sentence_similarities = []
        if query_encoding is not None:
            for pair in sentence_pairs:
                # Short sentences are stored as [text, None] and cannot be compared.
                if not pair or pair[1] is None:
                    continue
                sentence_text, sentence_encoding = pair[0], pair[1]
                similarity = cosine_similarity(query_encoding, sentence_encoding)[0][0]
                combined_score, _, commonality_score = add_commonality_to_similarity_score(
                    similarity, sentence_text, query_plain)
                sentence_similarities.append((combined_score, sentence_text, commonality_score))
                sentence_scores.append((combined_score, sentence_text))
        sentence_similarities.sort(reverse=True, key=lambda s: s[0])
        top_three_sentences = sentence_similarities[:3]
        if top_three_sentences:
            top_three_avg_similarity = np.mean([s[0] for s in top_three_sentences])
            top_three_avg_commonality = np.mean([s[2] for s in top_three_sentences])
        else:
            top_three_avg_similarity = 0
            top_three_avg_commonality = 0
        paragraph_scores.append(
            (top_three_avg_similarity, top_three_avg_commonality,
             {'top_three_sentences': top_three_sentences, 'original_text': paragraph_info})
        )
    sentence_scores.sort(key=lambda s: s[0], reverse=True)
    return sentence_scores, paragraph_scores
if 'paragraph_sentence_encodings' in st.session_state:
    # Shown only once the master document has been embedded, because the checks
    # below read st.session_state.paragraph_sentence_encodings.
    uploaded_text_file = st.file_uploader("Choose a .txt file", type="txt")
    st.markdown(
        '<a href="https://ikmtechnology.github.io/ikmtechnology/Sample_Secondary.txt" target="_blank">Sample Secondary txt download and then upload to above</a>',
        unsafe_allow_html=True)
    if uploaded_text_file is not None and is_new_txt_file_upload(uploaded_text_file):
        # Check every sentence of every line against the master document.
        # Line numbers are 1-based, as shown to the user.
        for line_number, raw_line in enumerate(uploaded_text_file.readlines(), start=1):
            # utf-8-sig drops a leading BOM; errors="replace" keeps one bad byte
            # from aborting the whole file.
            decoded_line = raw_line.decode("utf-8-sig", errors="replace").strip()
            for sentence in sent_tokenize(decoded_line):
                contradiction_detection_for_sentence(sentence, line_number)
    # Sample sentences used for manual testing:
    # A Member will be considered Actively at Work if he or she is able and available for active performance of all of his or her regular duties
    # A Member will be considered as inactive at Work if he or she is able and available for active performance of all of his or her regular duties
    # A Member shall be deemed inactive at Work if he or she is capable and available to perform all of his or her regular responsibilities.
    toggle_single_sentence_input = st.checkbox("optional, if you want to enter a single sentence instead of an entire text file", value=False)
    if toggle_single_sentence_input:
        st.markdown(
            "sample queries to invoke contradiction: <br/> A Member shall be deemed inactive at Work if he or she is capable and available to perform all of his or her regular responsibilities.",
            unsafe_allow_html=True)
        query = st.text_input("Enter your query")
        # Re-run the check only when the query text actually changed.
        if query and ('prev_query' not in st.session_state or st.session_state.prev_query != query):
            st.session_state.prev_query = query
            st.session_state.premise = query
            contradiction_detection_for_sentence(query, 1)