File size: 2,106 Bytes
05edb6c
 
 
 
 
 
 
 
 
 
 
 
091aa4f
05edb6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6defcba
 
91d64f8
 
6defcba
05edb6c
 
6defcba
05edb6c
6defcba
05edb6c
6defcba
05edb6c
 
6defcba
 
05edb6c
 
6defcba
 
05edb6c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os

import streamlit as st
from yasem import SpladeEmbedder

if os.getenv("SPACE_ID"):
    USE_HF_SPACE = True
    os.environ["HF_HOME"] = "/data/.huggingface"
    os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
    USE_HF_SPACE = False

MODEL_NAME = os.environ.get("MODEL_NAME", "hotchpotch/japanese-splade-v2")


@st.cache_resource
def get_embedder(model_name: str = MODEL_NAME) -> SpladeEmbedder:
    embedder = SpladeEmbedder(
        model_name,
    )
    return embedder


def get_token_values_sorted(input_text: str) -> list[tuple[float, str]]:
    embedder = get_embedder()
    embeddings = embedder.encode([input_text])
    token_values = embedder.get_token_values(embeddings[0])
    sorted_tokens = sorted(token_values.items(), key=lambda item: item[1], reverse=True)  # type: ignore
    return [(value, key) for key, value in sorted_tokens]


def main():
    st.set_page_config(
        page_title="SPLADE 日本語 demo",
        layout="centered",
        initial_sidebar_state="auto",
    )

    st.title("SPLADE 日本語 demo")

    get_embedder()

    st.markdown(f"""
    [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})を使って、テキストからSPLADEのスパースベクトルに変換するデモです。

    """)

    input_text = st.text_area("テキスト", height=200)

    if st.button("変換"):
        if input_text.strip():
            with st.spinner("変換中..."):
                sorted_tokens = get_token_values_sorted(input_text)

            total_tokens = len(sorted_tokens)
            st.markdown(f"### 結果 (トークン数: {total_tokens})")
            if sorted_tokens:
                formatted_data = [
                    {"スコア": freq, "単語(vocab)": word}
                    for freq, word in sorted_tokens
                ]
                st.table(formatted_data)
            else:
                st.warning("入力テキストから有効な単語が見つかりませんでした。")
        else:
            st.warning("テキストを入力してください。")


if __name__ == "__main__":
    main()