|
import streamlit as st |
|
|
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer |
|
from st_keyup import st_keyup |
|
|
|
from utils import is_hiragana_or_katakana, search_candidates |
|
|
|
|
|
|
|
model_name_or_path = "tokyotech-llm/Swallow-7b-hf"


@st.cache_resource
def _load_model_and_tokenizer(name_or_path: str):
    """Load and cache the tokenizer and fp16 causal-LM for *name_or_path*.

    Cached with st.cache_resource because Streamlit reruns this whole script
    on every widget event (st_keyup fires per keystroke); without caching,
    the 7B model would be reloaded from disk on each rerun.
    """
    tok = LlamaTokenizer.from_pretrained(name_or_path)
    # The Llama tokenizer has no pad token by default; reuse EOS for padding.
    tok.pad_token = tok.eos_token
    mdl = AutoModelForCausalLM.from_pretrained(name_or_path, torch_dtype=torch.float16)
    return tok, mdl


tokenizer, model = _load_model_and_tokenizer(model_name_or_path)
|
|
|
|
|
|
|
|
|
st.title("丸点棒AI")

st.write(
    ""
)

# Per-session memo of query -> (candidates, losses). Kept in session_state
# because the whole script reruns on every keystroke from st_keyup; a plain
# module-level dict would be rebuilt — and the memo lost — on every rerun,
# making the cache-hit branch below unreachable.
if "query_candidates" not in st.session_state:
    st.session_state.query_candidates = {"": ([""], 0)}
query_candidates = st.session_state.query_candidates

query = st_keyup(
    "お題",
    placeholder="ひらがな/カタカナのみを入力",
)

if query != "" and is_hiragana_or_katakana(query):
    if query in query_candidates:
        # Cache hit: reuse the ranking computed on a previous rerun.
        top_candidates, top_losses = query_candidates[query]
    else:
        top_candidates, top_losses = search_candidates(
            query, query_candidates, model=model, tokenizer=tokenizer, top_k=10
        )
        # Memoize explicitly so the next rerun hits the cache. Harmless even
        # if search_candidates already records into query_candidates
        # (can't tell from here — NOTE(review): confirm against utils).
        query_candidates[query] = (top_candidates, top_losses)

    # Show at most the 10 best candidates, each with its loss score.
    answers = [
        "{}: {:.2f}".format(candidate, loss)
        for candidate, loss in zip(top_candidates[:10], top_losses[:10])
    ]
    value = "\n".join(answers)
    value += f"\n({len(top_candidates)}候補)"
    st.info(value)
else:
    st.info("ひらがな/カタカナのみを入力してください")
|
|