File size: 1,493 Bytes
d38229e 4e5672e 57e364c fef3565 d38229e 4e5672e fef3565 a13e156 7db7298 a13e156 fef3565 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from st_keyup import st_keyup
from utils import is_hiragana_or_katakana, search_candidates
# Candidate checkpoints; Swallow-7b is the one currently in use.
# model_name_or_path = "tokyotech-llm/Llama-3-Swallow-8B-v0.1"
model_name_or_path = "tokyotech-llm/Swallow-7b-hf"
# model_name_or_path = "llm-jp/llm-jp-1.3b-v1.0"


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(name_or_path: str):
    """Load the tokenizer and model exactly once per server process.

    Streamlit reruns the whole script on every widget interaction (here:
    every keystroke from st_keyup); without st.cache_resource the multi-GB
    checkpoint would be reloaded on each rerun.
    """
    tok = LlamaTokenizer.from_pretrained(name_or_path)
    # LLaMA tokenizers ship without a pad token; reuse EOS so batching works.
    tok.pad_token = tok.eos_token
    mdl = AutoModelForCausalLM.from_pretrained(
        name_or_path, torch_dtype=torch.float16
    )
    return tok, mdl


# Keep the original module-level names so downstream code is unchanged.
tokenizer, model = _load_model_and_tokenizer(model_name_or_path)
# Show title and description.
st.title("丸点棒AI")
st.write(
    ""
)

# Persist the per-query candidate cache across Streamlit reruns.
# A plain module-level dict would be re-created on every keystroke, so the
# memoization branch below could never hit and search_candidates would run
# on every rerun.
if "query_candidates" not in st.session_state:
    st.session_state["query_candidates"] = {"": ([""], 0)}
query_candidates = st.session_state["query_candidates"]

# st_keyup reruns the script on each keystroke, giving live feedback.
query = st_keyup(
    "お題",
    placeholder="ひらがな/カタカナのみを入力",
)

if query != "" and is_hiragana_or_katakana(query):
    if query in query_candidates:
        # Cache hit: reuse the candidates computed for this query earlier.
        top_candidates, top_losses = query_candidates[query]
    else:
        # Cache miss: run the (expensive) model-based search.
        # search_candidates also records its result in query_candidates.
        top_candidates, top_losses = search_candidates(
            query, query_candidates, model=model, tokenizer=tokenizer, top_k=10
        )
    # Display up to 10 best candidates with their loss scores.
    answers = [
        "{}: {:.2f}".format(top_candidates[index], top_losses[index])
        for index in range(min(len(top_candidates), 10))
    ]
    value = "\n".join(answers)
    value += f"\n({len(top_candidates)}候補)"
    st.info(value)
else:
    st.info("ひらがな/カタカナのみを入力してください")
|