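# Streamlit app that visualizes, side by side, how the Llama2 and Gujju Llama
# tokenizers split the same input text.
# Launch with `streamlit run <this file>`.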
import os
from html import escape

import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer

st.set_page_config(layout="wide")

# Authenticate with the Hugging Face Hub (needed for gated repos such as
# meta-llama). The token is read from the "hf_token" environment variable;
# skip login if it is not set rather than failing at startup.
token = os.environ.get("hf_token")
if token:
    login(token=token)
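
# Streamlit re-executes this script on every widget interaction; cache the
# tokenizers with st.cache_resource (one instance per model_name) so they are
# not rebuilt on every rerun. add_tokenizer below uses this loader.
@st.cache_resource
def load_tokenizer(model_name):
    return AutoTokenizer.from_pretrained(model_name)
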

class TokenizationVisualizer:
    """Holds named tokenizers and renders a text's tokens as colored HTML."""

    def __init__(self):
        self.tokenizers = {}

    def add_tokenizer(self, name, model_name):
        # Uses the cached loader defined above so reruns reuse one instance.
        self.tokenizers[name] = load_tokenizer(model_name)

    def visualize_tokens(self, text, tokenizer):
        # Tokenize, then map each token back to its surface string so the
        # display shows readable text rather than raw subword markers.
        tokens = tokenizer.tokenize(text)
        str_tokens = [tokenizer.convert_tokens_to_string([tok]) for tok in tokens]
        token_ids = tokenizer.convert_tokens_to_ids(tokens)

        # Cycle through a small palette so adjacent tokens stay distinguishable.
        colors = ['#ffdab9', '#e6ee9c', '#9cddc8', '#bcaaa4', '#c5b0d5']

        html = ""
        for i, str_token in enumerate(str_tokens):
            color = colors[i % len(colors)]
            # Escape each token so characters like "<" or "&" cannot break the markup.
            safe_token = escape(str_token)
            html += f'<mark title="{safe_token}" style="background-color: {color};">{safe_token}</mark>'

        return html, token_ids
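
# Illustrative usage of TokenizationVisualizer (hypothetical; "gpt2" is just a
# public Hub tokenizer, not one this app actually loads):
#
#   viz = TokenizationVisualizer()
#   viz.add_tokenizer("GPT2", "gpt2")
#   html_str, ids = viz.visualize_tokens("Hello world", viz.tokenizers["GPT2"])
#   # html_str holds the <mark> spans; ids is the parallel list of token ids.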

        
def playground_tab(visualizer):
    st.title("Tokenization Visualizer for Language Models")
    st.markdown(
        "Use this playground to compare how the Llama2 and Gujju Llama "
        "tokenizers split the same text into tokens."
    )

    text_input = st.text_area("Enter text below to visualize tokens:", height=300)
    if st.button("Tokenize"):
        st.divider()

        if text_input.strip():
            llama_tokenization_results, llama_token_ids = visualizer.visualize_tokens(
                text_input, visualizer.tokenizers["Llama2"]
            )
            gujju_tokenization_results, gujju_token_ids = visualizer.visualize_tokens(
                text_input, visualizer.tokenizers["Gujju Llama"]
            )

            # Render both tokenizations side by side for direct comparison.
            col1, col2 = st.columns(2)
            col1.title("Llama2 Tokenizer")
            col1.container(height=200, border=True).markdown(
                llama_tokenization_results, unsafe_allow_html=True
            )
            with col1.expander(f"Token IDs (Token count = {len(llama_token_ids)})"):
                st.write(llama_token_ids)
            col2.title("Gujju Llama Tokenizer")
            col2.container(height=200, border=True).markdown(
                gujju_tokenization_results, unsafe_allow_html=True
            )
            with col2.expander(f"Token IDs (Token count = {len(gujju_token_ids)})"):
                st.write(gujju_token_ids)
        else:
            st.error("Please enter some text.")


def main():
    # Display names mapped to Hugging Face Hub repo ids. The meta-llama repo
    # is gated, which is why the Hub login at the top of the script is needed.
    huggingface_tokenizers = {
        "Gujju Llama": "sampoorna42/Gujju-Llama-Instruct-v0.1",
        "Llama2": "meta-llama/Llama-2-7b-hf",
    }

    visualizer = TokenizationVisualizer()

    for name, src in huggingface_tokenizers.items():
        visualizer.add_tokenizer(name, src)

    playground_tab(visualizer)


if __name__ == "__main__":
    main()