|
import streamlit as st |
|
import os |
|
|
|
import fitz |
|
import re |
|
from transformers import AutoModelForSequenceClassification, BertTokenizer, BertModel, \ |
|
AutoTokenizer |
|
|
|
import torch |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import numpy as np |
|
import nltk |
|
from nltk.tokenize import sent_tokenize |
|
from nltk.corpus import stopwords |
|
|
|
|
|
def is_new_txt_file_upload(uploaded_txt_file):
    """Report whether ``uploaded_txt_file`` differs from the last text upload seen.

    An upload is identified by its ``name`` and ``size``.  That pair is stored
    in ``st.session_state.last_uploaded_txt_file`` every time this function
    returns True.

    Args:
        uploaded_txt_file: Object exposing ``name`` and ``size`` attributes.

    Returns:
        True when no text upload has been recorded yet, or when the name or the
        size differs from the recorded one; False otherwise.
    """
    current = {'name': uploaded_txt_file.name, 'size': uploaded_txt_file.size}

    # Nothing recorded yet: remember this upload and treat it as new.
    if 'last_uploaded_txt_file' not in st.session_state:
        st.session_state.last_uploaded_txt_file = current
        return True

    previous = st.session_state.last_uploaded_txt_file
    if previous['name'] == current['name'] and previous['size'] == current['size']:
        return False

    st.session_state.last_uploaded_txt_file = current
    return True
|
|
|
|
|
def is_new_file_upload(uploaded_file):
    """Report whether ``uploaded_file`` differs from the last upload seen.

    An upload is identified by its ``name`` and ``size``.  That pair is stored
    in ``st.session_state.last_uploaded_file`` every time this function returns
    True.

    Args:
        uploaded_file: Object exposing ``name`` and ``size`` attributes.

    Returns:
        True when no upload has been recorded yet, or when the name or the size
        differs from the recorded one; False otherwise.
    """
    current = {'name': uploaded_file.name, 'size': uploaded_file.size}

    # Nothing recorded yet: remember this upload and treat it as new.
    if 'last_uploaded_file' not in st.session_state:
        st.session_state.last_uploaded_file = current
        return True

    previous = st.session_state.last_uploaded_file
    if previous['name'] == current['name'] and previous['size'] == current['size']:
        return False

    st.session_state.last_uploaded_file = current
    return True
|
|
|
|
|
def add_commonality_to_similarity_score(similarity, sentence_to_find_similarity_score, query_to_find_similiarty_score):
    """Boost an embedding similarity with a word-overlap ("commonality") score.

    Both texts are split on whitespace and words whose lowercase form is in
    ``st.session_state.stop_words`` are discarded.  The commonality is the
    number of remaining words shared by both texts divided by the number of
    remaining query words (the divisor is at least 1).

    Args:
        similarity: Similarity score to be boosted.
        sentence_to_find_similarity_score: Candidate sentence text.
        query_to_find_similiarty_score: Query text.

    Returns:
        A tuple ``(similarity + commonality, similarity, commonality)``.
    """
    stop_words = st.session_state.stop_words

    def content_words(text):
        # The stop-word test is case-insensitive; the kept words retain their case.
        return {token for token in text.split() if token.lower() not in stop_words}

    query_terms = content_words(query_to_find_similiarty_score)
    shared_terms = content_words(sentence_to_find_similarity_score) & query_terms

    commonality = len(shared_terms) / max(len(query_terms), 1)
    return similarity + commonality, similarity, commonality
|
|
|
|
|
def contradiction_detection(premise, hypothesis):
    """Classify the relation between ``premise`` and ``hypothesis``.

    The pair is tokenized with ``st.session_state.roberta_tokenizer`` and scored
    by ``st.session_state.roberta_model``; the highest-probability class is
    mapped to a label, printed, and returned.

    Args:
        premise: First text of the pair.
        hypothesis: Second text of the pair.

    Returns:
        A one-element ``set`` holding "Contradiction", "Neutral" or
        "Entailment".  The set wrapper is kept on purpose: callers compare the
        result against literals such as ``{"Contradiction"}``.
    """
    inputs = st.session_state.roberta_tokenizer.encode_plus(
        premise, hypothesis, return_tensors="pt", truncation=True)

    # Pure inference: without no_grad() every call would record an autograd
    # graph that is never used, wasting memory and time.
    with torch.no_grad():
        outputs = st.session_state.roberta_model(**inputs)

    probabilities = torch.softmax(outputs.logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

    # NOTE(review): the index-to-label order is hard-coded here; confirm it
    # matches the loaded checkpoint's ``config.id2label``.
    labels = ["Contradiction", "Neutral", "Entailment"]

    print(f"Prediction: {labels[predicted_class]}")
    return {labels[predicted_class]}
|
|
|
|
|
# One-time, per-session setup: fetch the NLTK resources and load both models
# into st.session_state so later script runs can reuse them.
if 'is_initialized' not in st.session_state:
    nltk.download('punkt')
    nltk.download('punkt_tab')
    nltk.download('stopwords')

    st.session_state.stop_words = set(stopwords.words('english'))

    # Use the GPU when one is present, otherwise fall back to the CPU instead of
    # crashing on hosts without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    st.session_state.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)
    st.session_state.roberta_tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
    st.session_state.roberta_model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli")

    # Set the flag last: if any step above raises, the next script run retries
    # the setup instead of continuing with half-initialised session state.
    st.session_state['is_initialized'] = True
|
|
|
|
|
def encode_sentence(sentence_to_be_encoded):
    """Embed one sentence with the BERT model held in ``st.session_state``.

    Args:
        sentence_to_be_encoded: Sentence text.

    Returns:
        ``None`` when the stripped text is shorter than 4 characters; otherwise
        a NumPy array taken from position 0 of the model's last hidden state
        (``last_hidden_state[:, 0, :]``).
    """
    if len(sentence_to_be_encoded.strip()) < 4:
        return None

    bert_model = st.session_state.bert_model

    # Send the inputs to whatever device the model lives on rather than a
    # hard-coded 'cuda', so this works regardless of where the model was placed.
    sentence_tokens = st.session_state.bert_tokenizer(
        sentence_to_be_encoded, return_tensors="pt", padding=True, truncation=True
    ).to(bert_model.device)

    with torch.no_grad():
        sentence_encoding = bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
    return sentence_encoding
|
|
|
|
|
def encode_paragraph(paragraph_to_be_encoded):
    """Split a paragraph into sentences and embed each one.

    Args:
        paragraph_to_be_encoded: Paragraph text, possibly containing newlines.

    Returns:
        A list of ``[sentence, encoding]`` pairs in sentence order, where
        ``encoding`` is whatever ``encode_sentence`` returns for that sentence
        (which is ``None`` for very short sentences).
    """
    # Newlines become spaces.  Deleting them outright would glue the last word
    # of one line to the first word of the next ("policy\nholder" ->
    # "policyholder"), corrupting both tokenization and word overlap.
    flattened_paragraph = paragraph_to_be_encoded.replace("\n", " ")

    return [
        [sentence, encode_sentence(sentence)]
        for sentence in sent_tokenize(flattened_paragraph)
    ]
|
|
|
|
|
# Once a document has been parsed (list_count is set), show its size and, if the
# embeddings are missing, compute them for every paragraph and rerun the script.
if 'list_count' in st.session_state:
    st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
    if 'paragraph_sentence_encodings' not in st.session_state:
        print("start embedding paragarphs")
        read_progress_bar = st.progress(0)

        # max(..., 1) avoids a ZeroDivisionError for a one-paragraph document.
        progress_denominator = max(st.session_state.list_count - 1, 1)

        # Build the list locally and publish it only when complete; the mere
        # presence of the session key is what marks the embeddings as ready, so
        # a failure part-way through must not leave a partial list behind.
        paragraph_sentence_encodings = []
        for index, paragraph in enumerate(st.session_state.restored_paragraphs):
            # Clamp so the value can never exceed the 1.0 upper bound.
            read_progress_bar.progress(min(index / progress_denominator, 1.0))
            sentence_encodings = encode_paragraph(paragraph['paragraph'])
            paragraph_sentence_encodings.append([paragraph, sentence_encodings])

        st.session_state.paragraph_sentence_encodings = paragraph_sentence_encodings
        st.rerun()
|
|
|
# Page title, centred with inline HTML.  The font size needs a real CSS unit
# ("30px"); the previous "30x" is not a valid length.
big_text = """
<div style='text-align: center;'>
<h1 style='font-size: 30px;'>Contradiction Detection</h1>
</div>
"""

st.markdown(big_text, unsafe_allow_html=True)
|
|
|
|
|
def convert_pdf_to_paragraph_list(pdf_doc_to_paragraph_list):
    """Extract the text paragraphs of a PDF document.

    Every page is read with ``page.get_text("blocks")``.  Each non-empty text
    block is stripped, and lines containing only whitespace are collapsed into
    blank lines.  A block containing blank lines is then split into one
    paragraph per chunk; any other block becomes a single paragraph.

    Args:
        pdf_doc_to_paragraph_list: Document object supporting ``len()`` (page
            count) and ``load_page(index)``.

    Returns:
        A list of dicts, in document order, with the keys:
          * ``paragraph``    - the paragraph text;
          * ``containsList`` - True when the enclosing block has a line that
            starts with ``1.``, ``a.``, ``*`` or ``-`` followed by text;
          * ``page_num``     - 0-based page index;
          * ``text``         - the whole block text when the block was split
            into several paragraphs, otherwise ``None``.
    """
    # Compiled once here rather than once per block.
    list_pattern = re.compile(r'^\s*((?:\d+\.|[a-zA-Z]\.|[*-])\s+.+)', re.MULTILINE)
    whitespace_only_line = re.compile(r'\n\s+\n')
    blank_line_run = re.compile(r'\n{2,}')

    paragraphs = []
    for page_num in range(len(pdf_doc_to_paragraph_list)):
        page = pdf_doc_to_paragraph_list.load_page(page_num)
        for block in page.get_text("blocks"):
            # PyMuPDF block tuples are (x0, y0, x1, y1, text, block_no, block_type).
            _, _, _, _, text, _, block_type = block

            # block_type 0 is text; image blocks carry only a placeholder
            # description and must not be treated as document text.
            if block_type != 0:
                continue

            text = text.strip()
            if not text:
                continue

            text = whitespace_only_line.sub('\n\n', text)
            contains_list = list_pattern.search(text) is not None

            if blank_line_run.search(text):
                for substring in blank_line_run.split(text):
                    if substring.strip() != "":
                        paragraphs.append(
                            {"paragraph": substring, "containsList": contains_list, "page_num": page_num,
                             "text": text})
            else:
                paragraphs.append(
                    {"paragraph": text, "containsList": contains_list, "page_num": page_num, "text": None})
    return paragraphs
|
|
|
|
|
# Master-document upload: save the PDF, split it into paragraphs and rerun so
# that the embedding step picks the new document up.
uploaded_pdf_file = st.file_uploader("Upload a PDF file", type=['pdf'])
st.markdown(
    '<a href="https://ikmtechnology.github.io/ikmtechnology/Sample_Master_Sample_Life_Insurance_Policy.pdf" target="_blank">Sample Master PDF download and then upload to above</a>',
    unsafe_allow_html=True)

if uploaded_pdf_file is not None and is_new_file_upload(uploaded_pdf_file):
    print("is new file uploaded")

    # Drop everything derived from the previous master document.  The text
    # upload fingerprint is cleared as well so that a secondary file which is
    # still selected gets re-checked against the new master, just like the
    # single-sentence query does once 'prev_query' is removed.
    for stale_key in ('prev_query', 'paragraph_sentence_encodings', 'last_uploaded_txt_file'):
        if stale_key in st.session_state:
            del st.session_state[stale_key]

    save_path = './uploaded_files'
    os.makedirs(save_path, exist_ok=True)

    # The file name is supplied by the client; keep only its final component so
    # that it cannot point outside save_path.
    safe_file_name = os.path.basename(uploaded_pdf_file.name) or 'uploaded.pdf'
    st.session_state.uploaded_path = os.path.join(save_path, safe_file_name)

    with open(st.session_state.uploaded_path, "wb") as f:
        f.write(uploaded_pdf_file.getbuffer())
    # Report the name the file was actually written under.
    st.success(f'Saved file {safe_file_name} in {save_path}')

    # Close the document as soon as its paragraphs have been extracted.
    with fitz.open(st.session_state.uploaded_path) as doc:
        st.session_state.restored_paragraphs = convert_pdf_to_paragraph_list(doc)

    if isinstance(st.session_state.restored_paragraphs, list):
        st.session_state.list_count = len(st.session_state.restored_paragraphs)
        st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
        st.rerun()
|
|
|
|
|
def contradiction_detection_for_sentence(cd_query, cd_query_line_number):
    """Check one secondary-document sentence against the master document.

    The sentence is embedded and scored against every paragraph in
    ``st.session_state.paragraph_sentence_encodings``.  For each of the three
    best-scoring paragraphs, its top sentences are run through
    ``contradiction_detection`` in score order until one is not classified as a
    contradiction.  Contradictions whose commonality score is at least 0.25 are
    written to the page.

    Args:
        cd_query: Sentence text from the secondary document.
        cd_query_line_number: Line number shown alongside the sentence.

    Returns:
        None.  Results are rendered with ``st.write``.
    """
    query_encoding = encode_sentence(cd_query)
    if query_encoding is None:
        # encode_sentence declines very short inputs; there is nothing to
        # compare, and passing None on would fail inside cosine_similarity.
        return

    total_count = len(st.session_state.paragraph_sentence_encodings)
    processing_progress_bar = st.progress(0)

    _sentence_scores, paragraph_scores = find_sentences_scores(
        st.session_state.paragraph_sentence_encodings, query_encoding, cd_query, processing_progress_bar, total_count)

    top_paragraphs = sorted(paragraph_scores, key=lambda x: x[0], reverse=True)[:3]

    for _similarity_score, _commonality_score, paragraph_info in top_paragraphs:
        # Each entry is (combined_score, sentence_text, commonality_score).
        for top_sentence in paragraph_info['top_three_sentences']:
            contradiction_detection_result = contradiction_detection(cd_query, top_sentence[1])

            # Stop with this paragraph at the first sentence that does not
            # contradict the query.
            if contradiction_detection_result != {"Contradiction"}:
                break

            # Only report contradictions that share enough vocabulary with the
            # query.
            if top_sentence[2] >= 0.25:
                # NOTE(review): page_num is stored as a 0-based page index and
                # is shown as-is; confirm whether a 1-based number is wanted.
                st.write("master document page number ",
                         paragraph_info['original_text']['page_num'])
                st.write("master document sentence: ", top_sentence[1])
                st.write("secondary document line number", cd_query_line_number)
                st.write("secondary document sentence: ", cd_query)
                st.write(contradiction_detection_result)
|
|
|
|
|
def find_sentences_scores(paragraph_sentence_encodings, query_encoding, query_plain, processing_progress_bar, total_count):
    """Score every encoded sentence and every paragraph against a query.

    Args:
        paragraph_sentence_encodings: Sequence of ``[paragraph, sentences]``
            entries, where ``sentences`` is a list of ``[sentence_text,
            encoding]`` pairs.  Pairs whose encoding is ``None`` are skipped.
        query_encoding: Encoding of the query, passed to ``cosine_similarity``.
        query_plain: Query text, used for the word-overlap score.
        processing_progress_bar: Object with a ``progress(fraction)`` method.
        total_count: Number of paragraphs, used to scale the progress bar.

    Returns:
        A tuple ``(sentence_scores, paragraph_scores)``:
          * ``sentence_scores`` - ``(combined_score, sentence_text)`` tuples
            for all scored sentences, sorted by descending score;
          * ``paragraph_scores`` - one ``(avg_combined, avg_commonality, info)``
            tuple per paragraph in input order.  The averages are taken over
            the paragraph's (up to) three best sentences, or are 0 if it has
            none.  ``info`` is ``{'top_three_sentences': [(combined_score,
            sentence_text, commonality_score), ...], 'original_text':
            paragraph}``.
    """
    paragraph_scores = []
    sentence_scores = []

    # max(..., 1) avoids a ZeroDivisionError when there is a single paragraph.
    progress_denominator = max(total_count - 1, 1)

    for paragraph_index, paragraph_entry in enumerate(paragraph_sentence_encodings):
        # Clamp so the value can never exceed the 1.0 upper bound.
        processing_progress_bar.progress(min(paragraph_index / progress_denominator, 1.0))

        sentence_similarities = []
        for sentence_entry in paragraph_entry[1]:
            # A [text, encoding] pair is always truthy as a list, so the
            # encoding itself has to be tested: sentences that were too short
            # to embed carry None, which cosine_similarity cannot accept.
            if not sentence_entry or sentence_entry[1] is None:
                continue

            similarity = cosine_similarity(query_encoding, sentence_entry[1])[0][0]
            combined_score, _similarity_score, commonality_score = add_commonality_to_similarity_score(
                similarity, sentence_entry[0], query_plain)

            sentence_similarities.append((combined_score, sentence_entry[0], commonality_score))
            sentence_scores.append((combined_score, sentence_entry[0]))

        sentence_similarities.sort(reverse=True, key=lambda x: x[0])

        # Slicing already copes with paragraphs that have fewer than three
        # scored sentences.
        top_three_sentences = sentence_similarities[:3]
        if top_three_sentences:
            top_three_avg_similarity = np.mean([s[0] for s in top_three_sentences])
            top_three_avg_commonality = np.mean([s[2] for s in top_three_sentences])
        else:
            top_three_avg_similarity = 0
            top_three_avg_commonality = 0

        paragraph_scores.append(
            (top_three_avg_similarity, top_three_avg_commonality,
             {'top_three_sentences': top_three_sentences, 'original_text': paragraph_entry[0]})
        )

    sentence_scores = sorted(sentence_scores, key=lambda x: x[0], reverse=True)
    return sentence_scores, paragraph_scores
|
|
|
|
|
# Secondary-document UI, shown only once the master document has been embedded.
if 'paragraph_sentence_encodings' in st.session_state:
    uploaded_text_file = st.file_uploader("Choose a .txt file", type="txt")
    st.markdown(
        '<a href="https://ikmtechnology.github.io/ikmtechnology/Sample_Secondary.txt" target="_blank">Sample Secondary txt download and then upload to above</a>',
        unsafe_allow_html=True)

    if uploaded_text_file is not None and is_new_txt_file_upload(uploaded_text_file):
        for line_number, raw_line in enumerate(uploaded_text_file.readlines(), start=1):
            # "utf-8-sig" drops a leading byte-order mark, and errors="replace"
            # keeps a single undecodable byte from aborting the whole check.
            decoded_line = raw_line.decode("utf-8-sig", errors="replace").strip()
            for sentence in sent_tokenize(decoded_line):
                contradiction_detection_for_sentence(sentence, line_number)

    toggle_single_sentence_input = st.checkbox(
        "optional, if you want to enter a single sentence instead of an entire text file", value=False)

    if toggle_single_sentence_input:
        st.markdown(
            "sample queries to invoke contradiction: <br/> A Member shall be deemed inactive at Work if he or she is capable and available to perform all of his or her regular responsibilities.",
            unsafe_allow_html=True)
        query = st.text_input("Enter your query")
        if query:
            # Only run the check when the query text has changed since the
            # last run.
            if 'prev_query' not in st.session_state or st.session_state.prev_query != query:
                st.session_state.prev_query = query
                st.session_state.premise = query
                contradiction_detection_for_sentence(query, 1)
|
|