zmbfeng committed on
Commit
561a0db
1 Parent(s): 897250e

embedding paragraphs

Files changed (2)
  1. app.py +47 -5
  2. requirements.txt +4 -0
app.py CHANGED
@@ -1,6 +1,16 @@
 import streamlit as st
 import os
 import json
+
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, BertTokenizer, BertModel, T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
+
+import torch
+from sklearn.metrics.pairwise import cosine_similarity
+import numpy as np
+import nltk
+from nltk.tokenize import sent_tokenize
+
+
 def is_new_file_upload(uploaded_file):
     if 'last_uploaded_file' in st.session_state:
         # Check if the newly uploaded file is different from the last one
@@ -44,18 +54,50 @@ if uploaded_json_file is not None:
         # print("page_count=",st.session_state.page_count)
         content = uploaded_json_file.read()
         try:
-            data = json.loads(content)
+            st.session_state.restored_paragraphs = json.loads(content)
             #print(data)
             # Check if the parsed data is a dictionary
-            if isinstance(data, list):
-                # Count the number of top-level elements
-                st.session_state.list_count = len(data)
+            if isinstance(st.session_state.restored_paragraphs, list):
+                # Count the number of top-level elements in restored_paragraphs
+                st.session_state.list_count = len(st.session_state.restored_paragraphs)
                 st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
             else:
                 st.write('The JSON content is not a dictionary.')
         except json.JSONDecodeError:
             st.write('Invalid JSON file.')
         st.rerun()
+if 'is_initialized' not in st.session_state:
+    st.session_state['is_initialized'] = True
+
+    nltk.download('punkt')
+    st.session_state.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+    st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased").to('cuda')
 
 if 'list_count' in st.session_state:
-    st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
+    st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
+    if 'paragraph_sentence_encodings' not in st.session_state:
+        print("start embedding paragraphs")
+        read_progress_bar = st.progress(0)
+        st.session_state.paragraph_sentence_encodings = []
+        for index, paragraph in enumerate(st.session_state.restored_paragraphs):
+            #print(paragraph)
+
+            progress_percentage = index / (st.session_state.list_count - 1)
+            print(progress_percentage)
+            read_progress_bar.progress(progress_percentage)
+
+            sentence_encodings = []
+            sentences = sent_tokenize(paragraph['text'])
+            for sentence in sentences:
+                if sentence.strip().endswith('?'):
+                    sentence_encodings.append(None)
+                    continue
+                if len(sentence.strip()) < 4:
+                    sentence_encodings.append(None)
+                    continue
+                sentence_tokens = st.session_state.bert_tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to('cuda')
+                with torch.no_grad():
+                    sentence_encoding = st.session_state.bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
+                sentence_encodings.append([sentence, sentence_encoding])
+                # sentence_encodings.append([sentence,bert_model(**sentence_tokens).last_hidden_state[:, 0, :].detach().numpy()])
+            st.session_state.paragraph_sentence_encodings.append([paragraph, sentence_encodings])
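Note on the input format: the loop above reads paragraph['text'] from each element after the isinstance(..., list) check, so the uploader expects a JSON file whose top level is a list of objects with a 'text' key. A minimal sketch of a valid input file (the contents and filename here are hypothetical, not from the commit):

import json

# Hypothetical example input: a top-level list of objects, each carrying a
# paragraph under a 'text' key, matching the checks in app.py above.
example_paragraphs = [
    {"text": "First paragraph. It has two sentences."},
    {"text": "Second paragraph, with a question? And a short statement."},
]
with open("paragraphs.json", "w", encoding="utf-8") as f:
    json.dump(example_paragraphs, f)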
requirements.txt ADDED
@@ -0,0 +1,4 @@
+transformers
+torch
+scikit-learn
+nltk
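The commit imports cosine_similarity and numpy but does not call them yet. A minimal sketch of how the stored [sentence, CLS-encoding] pairs could be queried later: encode the query the same way the sentences were encoded, then score it against every kept sentence. The helper name find_best_paragraphs and the top_k parameter are hypothetical (not part of this commit), and it assumes the app's session state and CUDA device from app.py:

import streamlit as st
import torch
from sklearn.metrics.pairwise import cosine_similarity

def find_best_paragraphs(query, top_k=3):
    # Hypothetical follow-up: assumes bert_tokenizer, bert_model, and
    # paragraph_sentence_encodings are already in st.session_state.
    query_tokens = st.session_state.bert_tokenizer(
        query, return_tensors="pt", padding=True, truncation=True).to('cuda')
    with torch.no_grad():
        query_encoding = st.session_state.bert_model(
            **query_tokens).last_hidden_state[:, 0, :].cpu().numpy()
    scored = []
    for paragraph, sentence_encodings in st.session_state.paragraph_sentence_encodings:
        best_score = 0.0
        for entry in sentence_encodings:
            if entry is None:  # placeholder for skipped questions / short fragments
                continue
            _sentence, encoding = entry
            # Both arrays are shape (1, 768), so cosine_similarity returns (1, 1)
            best_score = max(best_score, float(cosine_similarity(query_encoding, encoding)[0][0]))
        scored.append((best_score, paragraph))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:top_k]

Ranking each paragraph by its best-matching sentence is one design choice; averaging the per-sentence scores would be the obvious alternative.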