|
import html
import os

import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer
|
|
|
# Use the full browser width: the playground renders two side-by-side columns.
st.set_page_config(layout="wide")

# Hugging Face access token for gated repos (e.g. meta-llama); read from the
# environment so the secret never lives in source control.
token = os.environ.get("hf_token")
# Only authenticate when a token is actually present. Calling login(token=None)
# falls back to huggingface_hub's interactive prompt, which hangs or errors on
# a headless Streamlit server. Without a token, public models still load and
# any cached credentials are picked up automatically by from_pretrained.
if token:
    login(token=token)
|
|
|
class TokenizationVisualizer:
    """Loads Hugging Face tokenizers and renders a text's tokens as colored HTML.

    Tokenizers are registered by display name via :meth:`add_tokenizer` and
    kept in :attr:`tokenizers`; :meth:`visualize_tokens` produces the HTML
    markup shown in the Streamlit playground.
    """

    # Pastel backgrounds cycled through so adjacent tokens are visually distinct.
    TOKEN_COLORS = ('#ffdab9', '#e6ee9c', '#9cddc8', '#bcaaa4', '#c5b0d5')

    def __init__(self):
        # Maps display name -> loaded tokenizer instance.
        self.tokenizers = {}

    def add_tokenizer(self, name, model_name):
        """Load the tokenizer for *model_name* from the Hub and register it under *name*."""
        self.tokenizers[name] = AutoTokenizer.from_pretrained(model_name)

    def visualize_tokens(self, text, tokenizer):
        """Tokenize *text* and return ``(markup, token_ids)``.

        Args:
            text: Raw input string to tokenize.
            tokenizer: A Hugging Face tokenizer (anything exposing ``tokenize``,
                ``convert_tokens_to_string`` and ``convert_tokens_to_ids``).

        Returns:
            markup: HTML string of one ``<mark>`` span per token, with cycling
                background colors.
            token_ids: List of the tokens' integer ids.
        """
        tokens = tokenizer.tokenize(text)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        # Convert each token back to a display string individually so the
        # colored spans line up exactly with the tokenizer's segmentation.
        str_tokens = [tokenizer.convert_tokens_to_string([tok]) for tok in tokens]

        spans = []
        for i, tok in enumerate(str_tokens):
            color = self.TOKEN_COLORS[i % len(self.TOKEN_COLORS)]
            # Escape the token text: the result is rendered with
            # unsafe_allow_html=True, so raw <, >, & or quotes in user input
            # would otherwise break the markup (or inject arbitrary HTML).
            safe = html.escape(tok, quote=True)
            spans.append(
                f'<mark title="{safe}" style="background-color: {color};">{safe}</mark>'
            )
        return "".join(spans), token_ids
|
|
|
|
|
def playground_tab(visualizer):
    """Render the tokenizer playground: a text box plus side-by-side token views.

    Args:
        visualizer: A TokenizationVisualizer whose ``tokenizers`` dict contains
            the "Llama2" and "Gujju Llama" entries rendered here.
    """
    st.title("Tokenization Visualizer for Language Models")
    st.markdown("""

    You can use this playground to visualize Llama2 tokens & Gujarati Llama tokens generated by the tokenizers.

    """)

    text_input = st.text_area("Enter text below to visualize tokens:", height=300)

    # Nothing to do until the user presses the button.
    if not st.button("Tokenize"):
        return

    st.divider()

    if not text_input.strip():
        st.error("Please enter some text.")
        return

    # One column per tokenizer, rendered identically.
    columns = st.columns(2)
    panels = (("Llama2", "Llama2 Tokenizer"), ("Gujju Llama", "Gujju Llama Tokenizer"))
    for col, (key, heading) in zip(columns, panels):
        markup, ids = visualizer.visualize_tokens(text_input, visualizer.tokenizers[key])
        col.title(heading)
        col.container(height=200, border=True).markdown(markup, unsafe_allow_html=True)
        with col.expander(f"Token IDs (Token Counts = {len(ids)})"):
            st.markdown(ids)
|
|
|
|
|
def main():
    """Entry point: load both tokenizers, then render the playground UI."""
    # Display name -> Hugging Face Hub repo id.
    tokenizer_sources = {
        "Gujju Llama": "sampoorna42/Gujju-Llama-Instruct-v0.1",
        "Llama2": "meta-llama/Llama-2-7b-hf",
    }

    visualizer = TokenizationVisualizer()
    for display_name, repo_id in tokenizer_sources.items():
        visualizer.add_tokenizer(display_name, repo_id)

    playground_tab(visualizer)
|
|
|
|
|
|
|
# Run the Streamlit app when executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":

    main()
|
|