dixitrivedi committed on
Commit
40183cc
1 Parent(s): f9bb2fa

Add application file

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
from transformers import AutoTokenizer

# Wide layout so the two tokenizer panels can sit side by side in st.columns(2).
st.set_page_config(layout="wide")
5
+
6
class TokenizationVisualizer:
    """Holds named Hugging Face tokenizers and renders token boundaries as colored HTML.

    Tokenizers are registered under a display name via :meth:`add_tokenizer`
    and looked up from :attr:`tokenizers` by the UI code.
    """

    def __init__(self):
        # Display name -> loaded tokenizer instance.
        self.tokenizers = {}

    def add_tokenizer(self, name, model_name):
        """Download/load the pretrained tokenizer *model_name* and register it as *name*."""
        self.tokenizers[name] = AutoTokenizer.from_pretrained(model_name)

    def visualize_tokens(self, text, tokenizer):
        """Tokenize *text* with *tokenizer* and return ``(markup, token_ids)``.

        ``markup`` is an HTML string with one ``<mark>`` per token, cycling
        through a fixed palette of background colors; ``token_ids`` is the
        list of integer ids for the same tokens.
        """
        # Scoped import: keeps the stdlib module out of the module namespace.
        from html import escape

        tokens = tokenizer.tokenize(text)
        # Convert each token back to its surface string individually so that
        # per-token boundaries stay visible in the rendered output.
        str_tokens = [tokenizer.convert_tokens_to_string([t]) for t in tokens]
        token_ids = tokenizer.convert_tokens_to_ids(tokens)

        colors = ['#ffdab9', '#e6ee9c', '#9cddc8', '#bcaaa4', '#c5b0d5']

        # Escape token text: the result is rendered with unsafe_allow_html=True,
        # so raw '<', '&', or '"' in a token would break the markup (or inject HTML).
        pieces = []
        for i, token in enumerate(str_tokens):
            color = colors[i % len(colors)]
            safe = escape(token, quote=True)
            pieces.append(f'<mark title="{safe}" style="background-color: {color};">{safe}</mark>')

        return "".join(pieces), token_ids
30
+
31
+
32
def playground_tab(visualizer):
    """Render the playground page: text input plus side-by-side token views.

    Expects *visualizer* to already hold tokenizers registered under the
    names "Llama2" and "Gujju Llama".
    """
    st.title("Tokenization Visualizer for Language Models")
    st.markdown("""
    You can use this playground to visualize Llama2 tokens & Gujarati Llama tokens generated by the tokenizers.
    """)

    text_input = st.text_area("Enter text below to visualize tokens:", height=300)
    if not st.button("Tokenize"):
        return

    st.divider()
    if not text_input.strip():
        st.error("Please enter some text.")
        return

    def _panel(column, heading, markup, ids):
        # One tokenizer's panel: heading, colorized tokens, collapsible id list.
        column.title(heading)
        column.container(height=200, border=True).markdown(markup, unsafe_allow_html=True)
        with column.expander(f"Token IDs (Token Counts = {len(ids)})"):
            st.markdown(ids)

    llama_markup, llama_ids = visualizer.visualize_tokens(text_input, visualizer.tokenizers["Llama2"])
    gujju_markup, gujju_ids = visualizer.visualize_tokens(text_input, visualizer.tokenizers["Gujju Llama"])

    left, right = st.columns(2)
    _panel(left, 'Llama2 Tokenizer', llama_markup, llama_ids)
    _panel(right, 'Gujju Llama Tokenizer', gujju_markup, gujju_ids)
58
+
59
+
60
def main():
    """Entry point: load both tokenizers from the Hub and render the playground."""
    tokenizer_sources = {
        "Gujju Llama": "sampoorna42/Gujju-Llama-Instruct-v0.1",
        "Llama2": "meta-llama/Llama-2-7b-hf",
    }

    visualizer = TokenizationVisualizer()
    for display_name, repo_id in tokenizer_sources.items():
        visualizer.add_tokenizer(display_name, repo_id)

    playground_tab(visualizer)


if __name__ == "__main__":
    main()