# Ransaka's picture
# Update app.py
# c6afe0a
import streamlit as st
import pandas as pd
import numpy as np
import altair as alt
import os
from PIL import Image
from embeddings.embeddings import load_model
from sentence_transformers import util
import warnings
# Suppress every Python warning for the whole process.
warnings.filterwarnings(action='ignore')
# Page metadata for the browser tab; keep this ahead of any other st.* call.
st.set_page_config(
    page_icon=":bar_chart:",
    page_title="Sinhala Embedding Space",
)
# Static PNG of the full clustering, displayed on the "Clustering Results" tab.
image = Image.open('plots/clusters.png')
# Load data
# @st.cache_data
def load_data():
chart_data = pd.read_csv(r"data/top_cluster_dataset.csv",dtype={'Headline': str, 'x': np.float64, 'y': np.float64, 'labels': str})
return chart_data
# Dataset shared by both tabs, loaded once per script execution.
chart_data = load_data()
# Sidebar navigation: the chosen label decides which tab body is rendered.
tabs = [
    "Clustering Results",
    "Sentences Similarity",
]
selected_tab = st.sidebar.radio(label="Select a Tab", options=tabs)
def get_altair_chart(data=None):
    """Build an interactive Altair scatter plot of the clustered headlines.

    Args:
        data: DataFrame with ``x``, ``y``, ``labels`` and ``Headline`` columns.
            When omitted, the module-level ``chart_data`` is used, so the
            original zero-argument call keeps working.

    Returns:
        altair.Chart: circles (size 60) at (x, y), coloured by ``labels``,
        with the headline as tooltip, made interactive via ``.interactive()``.
    """
    source = chart_data if data is None else data
    return (
        alt.Chart(source)
        .mark_circle(size=60)
        .encode(x='x', y='y', color='labels', tooltip=['Headline'])
        .interactive()
    )
# Main content: render the body of whichever tab is selected in the sidebar.
if selected_tab == "Sentences Similarity":
    # Never ask for more rows than exist: Series.sample raises ValueError when
    # n exceeds the population and replace=False.
    n_samples = min(10, len(chart_data))
    sample_sentences = chart_data['Headline'].sample(n_samples, random_state=1).tolist()
    st.title("Calculate Sentences Similarity")
    # select model to use dropdown
    st.subheader("Select a model to use")
    model_list = ["Ransaka/SinhalaRoberta", "keshan/SinhalaBERTo"]
    selected_model = st.selectbox("Select Model", model_list)
    # NOTE(review): load_model runs on every rerun of this branch; confirm it
    # caches the model internally.
    model = load_model(selected_model)
    sentence1 = st.text_input("Enter Sentence 1", "")
    sentence2 = st.text_input("Enter Sentence 2", "")
    if sentence1 and sentence2:
        # add button to calculate similarity
        if st.button("Calculate Similarity"):
            with st.spinner('Calculating Similarity...'):
                # Cosine similarity of the two embeddings; the [0][0] entry is
                # converted to a plain float for comparison and formatting.
                similarity = float(
                    util.pytorch_cos_sim(model.encode(sentence1), model.encode(sentence2))[0][0]
                )
            if similarity > 0.7:
                st.success(f"Sentences are similar (Score: {similarity:.3f})")
            elif similarity > 0.5:
                st.warning(f"Sentences are somewhat similar (Score: {similarity:.3f})")
            else:
                st.error(f"Sentences are not similar (Score: {similarity:.3f})")
    else:
        st.write("Enter two sentences to calculate similarity. Or start with sample sentences below.")
    # Button to draw a fresh (unseeded) random set of sample sentences.
    if st.button("Randomize Sentences"):
        sample_sentences = chart_data['Headline'].sample(n_samples).tolist()
    # List the sample sentences, one per line.
    for sentence in sample_sentences:
        st.write(sentence)
elif selected_tab == "Clustering Results":
    st.title("Clustering Results")
    # Display PNG image
    st.subheader("Full Clustering Results")
    st.image(image, use_column_width=False, caption='Static PNG File', width=750)
    # Display Altair chart
    st.subheader("Interactive Chart")
    chart = get_altair_chart()
    st.altair_chart(chart, use_container_width=True)
    # Dropdown functionality to update DataFrame
    st.subheader("Select a cluster")
    unique_clusters = chart_data['labels'].unique().tolist()
    selected_value = st.selectbox("Select Value", unique_clusters)
    # Filter and display results based on selected cluster
    if selected_value:
        # Exact match: the options come straight from the column's unique
        # values. A case-insensitive regex substring match would also pull in
        # rows from other clusters (e.g. "1" matching "10") and misread regex
        # metacharacters that appear in a label.
        cluster_rows = chart_data[chart_data['labels'] == selected_value]
        # Clusters smaller than 10 rows are shown in full instead of crashing.
        filtered_data = (
            cluster_rows.sample(min(10, len(cluster_rows)))[['Headline']]
            .reset_index(drop=True)
        )
        st.dataframe(filtered_data, width=750)
    else:
        st.write("Select a cluster to display results.")