File size: 3,774 Bytes
010edb7
 
 
 
 
 
 
 
 
 
3eacaec
010edb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eacaec
010edb7
 
 
 
 
 
 
 
 
 
3eacaec
 
010edb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eacaec
 
 
010edb7
3eacaec
010edb7
3eacaec
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import streamlit as st
import vec2text
import torch
from umap import UMAP
import plotly.express as px
import numpy as np
from streamlit_plotly_events import plotly_events
from resources import reduce_embeddings
import utils
import pandas as pd
from scipy.spatial import distance


def diffs(embeddings: np.ndarray, corrector):
    """Show the shape of ``embeddings`` followed by the icon attribution link.

    Args:
        embeddings: Array whose ``shape`` attribute is written to the page.
        corrector: Accepted for signature compatibility; not used in this
            function.
    """
    shape_summary = f"Embedding shape: {embeddings.shape}"
    attribution_link = (
        '<a href="https://www.flaticon.com/free-icons/array" '
        'title="array icons">Array icons created by Voysla - Flaticon</a>'
    )

    st.text(shape_summary)
    st.html(attribution_link)

def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, corrector) -> None:
    """Render the interactive 2-D scatter plot and the embedding-inversion panel.

    The left column shows a Plotly scatter of ``vectors_2d`` and a form with
    X/Y inputs. When a point is clicked or the form is submitted, the 2-D
    coordinate is mapped back through ``reducer.inverse_transform`` and the
    result is passed to ``vec2text.invert_embeddings``; both the generated
    output and the inferred embedding are printed.

    The right column looks the coordinate up in ``vectors_2d`` with
    ``utils.find_exact_match`` (3 decimals), shows the matching title, and,
    when an inferred embedding is also available, shows the Euclidean and
    cosine distances between the matched row of ``embeddings`` and the
    inferred embedding.

    Args:
        df: Frame with a ``title`` column; it is used for hover text and for
            the selected-text card.
        embeddings: Indexed with the position returned by
            ``utils.find_exact_match`` -- presumably row-aligned with ``df``
            and ``vectors_2d``; verify against the caller.
        vectors_2d: 2-D array; columns 0 and 1 are plotted as x and y.
        reducer: Object exposing ``inverse_transform`` for a ``(1, 2)`` array.
        corrector: Passed unchanged to ``vec2text.invert_embeddings``.
    """
    fig = px.scatter(
        x=vectors_2d[:, 0],
        y=vectors_2d[:, 1],
        opacity=0.6,
        hover_data={"Title": df["title"]},
        labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
        title="UMAP Scatter Plot of Reddit Titles",
        color_discrete_sequence=["#ff504c"],  # single red marker colour for all points
    )

    # No explicit template and fully transparent backgrounds, so the
    # surrounding page colours show through the figure.
    fig.update_layout(
        template=None,
        plot_bgcolor="rgba(0, 0, 0, 0)",
        paper_bgcolor="rgba(0, 0, 0, 0)",
    )

    x, y = 0.0, 0.0
    inferred_embedding = None
    # 60/40 split: plot and form on the left, details card on the right.
    col1, col2 = st.columns([0.6, 0.4])

    with col1:
        selected_points = plotly_events(fig, click_event=True, hover_event=False)
        with st.form(key="form1_main"):
            if selected_points:
                clicked_point = selected_points[0]
                # Cast so the default value always matches the float format
                # of the number inputs below.
                x = float(clicked_point['x'])
                y = float(clicked_point['y'])

            x = st.number_input("X Coordinate", value=x, format="%.10f")
            y = st.number_input("Y Coordinate", value=y, format="%.10f")
            vec = np.array([x, y]).astype("float32")

            submit_button = st.form_submit_button("Submit")

            if selected_points or submit_button:
                inferred_embedding = reducer.inverse_transform(np.array([[x, y]]))
                inferred_embedding = inferred_embedding.astype("float32")

                # Fall back to CPU instead of raising when CUDA is unavailable.
                # NOTE(review): assumes the corrector lives on the same
                # device -- confirm where it is loaded.
                device = "cuda" if torch.cuda.is_available() else "cpu"
                output = vec2text.invert_embeddings(
                    embeddings=torch.tensor(inferred_embedding).to(device),
                    corrector=corrector,
                    num_steps=20,
                )

                st.text(str(output))
                st.text(str(inferred_embedding))
            else:
                st.text("Click on a point in the scatterplot to see its coordinates.")

    with col2:
        closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
        # A negative index is treated as "no match"; one flag keeps the
        # lookups and the distance panel in agreement.
        has_match = closest_sentence_index > -1
        selected_sentence = df.title.iloc[closest_sentence_index] if has_match else None
        selected_sentence_embedding = embeddings[closest_sentence_index] if has_match else None

        st.markdown(
            f"### Selected text:\n```console\n{selected_sentence}\n```"
        )

        if inferred_embedding is not None and has_match:
            couple = selected_sentence_embedding.squeeze(), inferred_embedding.squeeze()
            st.markdown("### Inferred embedding distance:")
            st.number_input("Euclidean", value=distance.euclidean(*couple), disabled=True)
            st.number_input("Cosine", value=distance.cosine(*couple), disabled=True)