import os

import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer

st.set_page_config(layout="wide")

# Authenticate with the Hugging Face Hub. The token is read from the
# "hf_token" environment variable (e.g. a Space secret); without it, gated
# repositories such as meta-llama/Llama-2-7b-hf cannot be downloaded.
token = os.environ.get("hf_token")
if token:
    login(token=token)
class TokenizationVisualizer:
    """Loads tokenizers from the Hub and renders their tokenizations as HTML."""

    def __init__(self):
        self.tokenizers = {}

    def add_tokenizer(self, name, model_name):
        """Download a tokenizer from the Hub and register it under `name`."""
        self.tokenizers[name] = AutoTokenizer.from_pretrained(model_name)

    def visualize_tokens(self, text, tokenizer):
        """Return (html, token_ids), with each token wrapped in a colored <mark> tag."""
        tokens = tokenizer.tokenize(text)
        str_tokens = [tokenizer.convert_tokens_to_string([token]) for token in tokens]
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        colors = ["#ffdab9", "#e6ee9c", "#9cddc8", "#bcaaa4", "#c5b0d5"]
        html = ""
        for i, token in enumerate(str_tokens):
            color = colors[i % len(colors)]  # cycle through the palette
            html += f'<mark title="{token}" style="background-color: {color};">{token}</mark>'
        return html, token_ids
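

# A minimal standalone sketch of how TokenizationVisualizer can be used
# outside Streamlit. The "gpt2" checkpoint here is an illustrative, ungated
# substitute; any Hub tokenizer works the same way:
#
#   viz = TokenizationVisualizer()
#   viz.add_tokenizer("GPT-2", "gpt2")
#   html_out, ids = viz.visualize_tokens("Hello world", viz.tokenizers["GPT-2"])
#   print(ids)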
def playground_tab(visualizer):
    """Render the playground UI: a text box plus side-by-side token views."""
    st.title("Tokenization Visualizer for Language Models")
    st.markdown(
        """
        Use this playground to visualize the tokens that the Llama2 and
        Gujju Llama tokenizers produce for your text.
        """
    )
    text_input = st.text_area("Enter text below to visualize tokens:", height=300)
    if st.button("Tokenize"):
        st.divider()
        if text_input.strip():
            # Tokenize the same input with both tokenizers for comparison.
            llama_html, llama_token_ids = visualizer.visualize_tokens(
                text_input, visualizer.tokenizers["Llama2"]
            )
            gujju_html, gujju_token_ids = visualizer.visualize_tokens(
                text_input, visualizer.tokenizers["Gujju Llama"]
            )
            col1, col2 = st.columns(2)
            col1.title("Llama2 Tokenizer")
            col1.container(height=200, border=True).markdown(llama_html, unsafe_allow_html=True)
            with col1.expander(f"Token IDs (Token Count = {len(llama_token_ids)})"):
                st.markdown(llama_token_ids)
            col2.title("Gujju Llama Tokenizer")
            col2.container(height=200, border=True).markdown(gujju_html, unsafe_allow_html=True)
            with col2.expander(f"Token IDs (Token Count = {len(gujju_token_ids)})"):
                st.markdown(gujju_token_ids)
        else:
            st.error("Please enter some text.")
def main():
    # Map display names to Hub repositories for the tokenizers being compared.
    huggingface_tokenizers = {
        "Gujju Llama": "sampoorna42/Gujju-Llama-Instruct-v0.1",
        "Llama2": "meta-llama/Llama-2-7b-hf",
    }
    visualizer = TokenizationVisualizer()
    for name, repo in huggingface_tokenizers.items():
        visualizer.add_tokenizer(name, repo)
    playground_tab(visualizer)


if __name__ == "__main__":
    main()
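

# To run locally (a sketch, assuming this file is saved as app.py and that
# "hf_token" holds a Hub token with access to the gated
# meta-llama/Llama-2-7b-hf repository; sentencepiece may be needed for the
# slow Llama tokenizer):
#
#   pip install streamlit transformers huggingface_hub sentencepiece
#   hf_token=<your token> streamlit run app.py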