# Pixtral-12B / app.py
# Author: aixsatoshi; commit 45159cb ("Update app.py", verified), 4.77 kB.
import gradio as gr
import spaces
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from huggingface_hub import snapshot_download
from pathlib import Path
# Download and prepare the model.
# Files are placed under ~/mistral_models/Pixtral; only the three files the
# code below reads (params.json, consolidated.safetensors, tekken.json) are
# requested from the Hub repository.
mistral_models_path = Path.home().joinpath('mistral_models', 'Pixtral')
mistral_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=mistral_models_path,
)

# Load the tokenizer and the model once, at import time; mistral_inference()
# reuses these module-level objects for every request.
# Join the path with pathlib instead of string formatting; pass it as str
# because from_file takes a filename.
tokenizer = MistralTokenizer.from_file(str(mistral_models_path / "tekken.json"))
model = Transformer.from_folder(mistral_models_path)
# Inference
@spaces.GPU
def mistral_inference(prompt, image_url, max_tokens=1024, temperature=0.35):
    """Run Pixtral on one image plus a text prompt and return the decoded reply.

    Args:
        prompt: Text instruction; sent after the image in a single user message.
        image_url: URL of the image, passed to the tokenizer as an ImageURLChunk.
        max_tokens: Maximum number of tokens to generate (default 1024).
        temperature: Sampling temperature passed to generate() (default 0.35).

    Returns:
        The decoded text generated for this single request.
    """
    completion_request = ChatCompletionRequest(
        messages=[
            UserMessage(content=[ImageURLChunk(image_url=image_url), TextChunk(text=prompt)])
        ]
    )
    encoded = tokenizer.encode_chat_completion(completion_request)
    # generate() takes a batch, so this one request's tokens and images are
    # each wrapped in a single-element list; only the first output is used.
    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])
# UI label strings per language code.
# Built once at import time rather than on every get_labels() call.
_UI_LABELS = {
    'en': {
        'title': "Pixtral Model Image Description",
        'text_prompt': "Text Prompt",
        'image_url': "Image URL",
        'output': "Model Output",
        'image_display': "Input Image",
        'submit': "Run Inference"
    },
    'zh': {
        'title': "Pixtral模型图像描述",
        'text_prompt': "文本提示",
        'image_url': "图片网址",
        'output': "模型输出",
        'image_display': "输入图片",
        'submit': "运行推理"
    },
    'jp': {
        'title': "Pixtralモデルによる画像説明生成",
        'text_prompt': "テキストプロンプト",
        'image_url': "画像URL",
        'output': "モデルの出力結果",
        'image_display': "入力された画像",
        'submit': "推論を実行"
    }
}


def get_labels(language):
    """Return the UI label strings for *language*.

    Args:
        language: Language code, one of 'en', 'zh' or 'jp'.

    Returns:
        A new dict with the keys 'title', 'text_prompt', 'image_url',
        'output', 'image_display' and 'submit'. Any other code, including
        None (for example a cleared dropdown), falls back to English instead
        of raising KeyError.
    """
    # Return a copy so a caller mutating the result cannot alter the shared table.
    return dict(_UI_LABELS.get(language, _UI_LABELS['en']))
# Gradio interface
def process_input(text, image_url):
    """Gradio click handler: run inference and build an HTML preview of the image.

    Args:
        text: Prompt from the text box.
        image_url: Image URL from the URL text box (user-typed).

    Returns:
        A tuple of (model output text, HTML <img> snippet at width 300).
    """
    import html

    result = mistral_inference(text, image_url)
    # image_url comes straight from the user; escape it (including quotes) so
    # it cannot close the src attribute and inject markup into the page.
    safe_url = html.escape(image_url, quote=True)
    return result, f'<img src="{safe_url}" alt="Input Image" width="300">'
def update_ui(language):
    """Return Gradio updates that relabel the UI for *language*.

    The text boxes and the HTML preview only receive a new ``label``. This
    keeps the prompt, image URL, model output and preview image the user
    already has; returning plain strings would overwrite those values. The
    title Markdown and the button are updated by value, because their value
    is the visible text.

    Returns:
        Six updates in the order of the outputs wired to the language
        dropdown: title, text prompt, image URL, model output, image
        preview, submit button.
    """
    labels = get_labels(language)
    return (
        # Keep the "## " heading markup used for the initial title.
        gr.update(value=f"## {labels['title']}"),
        gr.update(label=labels['text_prompt']),
        gr.update(label=labels['image_url']),
        gr.update(label=labels['output']),
        gr.update(label=labels['image_display']),
        gr.update(value=labels['submit']),
    )
# Image URL pre-filled in the URL box and shown in the preview on first load.
initial_url = "https://huggingface.co/spaces/aixsatoshi/Pixtral-12B/resolve/main/llamagiant.jpg"

# (prompt, image URL) pairs offered as clickable examples below the form.
examples = [
    ["Describe the scene.", "https://assets.st-note.com/production/uploads/images/138094970/rectangle_large_type_2_bc1a73623dc0e9bf8799832ddb4cd53e.png"],
    ["Describe the image.", "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"],
    ["Describe the random generated image.", "https://picsum.photos/seed/picsum/200/300"],
    ["Describe the image.", "https://picsum.photos/id/32/512/512"]
]

with gr.Blocks() as demo:
    language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')
    title = gr.Markdown("## Pixtral Model Image Description")

    # Prompt and image URL side by side.
    with gr.Row():
        text_input = gr.Textbox(label="Text Prompt", placeholder="e.g. Describe the image.")
        image_input = gr.Textbox(label="Image URL", value=initial_url)  # pre-filled with the initial URL

    # Output box: at least 8 rows, at most 20 (Gradio lines / max_lines).
    result_output = gr.Textbox(label="Model Output", lines=8, max_lines=20)
    # Image preview; shows the initial image before any inference has run.
    image_output = gr.HTML(f'<img src="{initial_url}" alt="Input Image" width="300">')
    submit_button = gr.Button("Run Inference")

    submit_button.click(
        fn=process_input,
        inputs=[text_input, image_input],
        outputs=[result_output, image_output],
    )

    # Components whose text depends on the language, in the order update_ui returns them.
    localized_components = [title, text_input, image_input, result_output, image_output, submit_button]
    # Relabel the UI whenever a different language is selected.
    language_choice.change(update_ui, [language_choice], localized_components)

    gr.Examples(examples=examples, inputs=[text_input, image_input], label="Example Inputs")

demo.launch()