|
--- |
|
license: mit |
|
datasets: |
|
- hpprc/emb |
|
- hotchpotch/hpprc_emb-scores |
|
- microsoft/ms_marco |
|
language: |
|
- ja |
|
base_model: |
|
- tohoku-nlp/bert-base-japanese-v3 |
|
new_version: hotchpotch/japanese-splade-v2 |
|
--- |
|
|
|
高性能な日本語 [SPLADE](https://github.com/naver/splade) (Sparse Lexical and Expansion Model) モデルです。[テキストからスパースベクトルへの変換デモ](https://huggingface.co/spaces/hotchpotch/japanese-splade-demo-streamlit)で、どのようなスパースベクトルに変換できるか、WebUI から気軽にお試しいただけます。 |
|
|
|
- [高性能な日本語SPLADE(スパース検索)モデルを公開しました](https://secon.dev/entry/2024/10/07/100000/) |
|
- [SPLADE モデルの作り方・日本語SPLADEテクニカルレポート](https://secon.dev/entry/2024/10/23/080000-japanese-splade-tech-report/) |
|
|
|
また、モデルの学習には[YAST - Yet Another SPLADE or Sparse Trainer](https://github.com/hotchpotch/yast)を使っています。 |
|
|
|
|
|
# 利用方法 |
|
|
|
## [YASEM (Yet Another Splade|Sparse Embedder)](https://github.com/hotchpotch/yasem) |
|
|
|
```bash |
|
pip install yasem |
|
``` |
|
|
|
```python |
|
from yasem import SpladeEmbedder |
|
|
|
model_name = "hotchpotch/japanese-splade-base-v1" |
|
embedder = SpladeEmbedder(model_name) |
|
|
|
sentences = [ |
|
"車の燃費を向上させる方法は?", |
|
"急発進や急ブレーキを避け、一定速度で走行することで燃費が向上します。", |
|
"車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。", |
|
] |
|
|
|
embeddings = embedder.encode(sentences) |
|
similarity = embedder.similarity(embeddings, embeddings) |
|
|
|
print(similarity) |
|
# [[21.49299249 10.48868281 6.25582337] |
|
# [10.48868281 12.90587398 3.19429791] |
|
# [ 6.25582337 3.19429791 12.89678271]] |
|
``` |
|
|
|
```python |
|
token_values = embedder.get_token_values(embeddings[0]) |
|
|
|
print(token_values) |
|
|
|
#{ |
|
# '車': 2.1796875, |
|
# '燃費': 2.146484375, |
|
# '向上': 1.7353515625, |
|
# '方法': 1.55859375, |
|
# '燃料': 1.3291015625, |
|
# '効果': 1.1376953125, |
|
# '良い': 0.873046875, |
|
# '改善': 0.8466796875, |
|
# 'アップ': 0.833984375, |
|
# 'いう': 0.70849609375, |
|
# '理由': 0.64453125, |
|
# ... |
|
``` |
|
|
|
## transformers |
|
|
|
```python |
|
|
|
from transformers import AutoModelForMaskedLM, AutoTokenizer |
|
import torch |
|
|
|
model = AutoModelForMaskedLM.from_pretrained(model_name) |
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
def splade_max_pooling(logits, attention_mask): |
|
relu_log = torch.log(1 + torch.relu(logits)) |
|
weighted_log = relu_log * attention_mask.unsqueeze(-1) |
|
max_val, _ = torch.max(weighted_log, dim=1) |
|
return max_val |
|
|
|
tokens = tokenizer( |
|
sentences, return_tensors="pt", padding=True, truncation=True, max_length=512 |
|
) |
|
tokens = {k: v.to(model.device) for k, v in tokens.items()} |
|
|
|
with torch.no_grad(): |
|
outputs = model(**tokens) |
|
embeddings = splade_max_pooling(outputs.logits, tokens["attention_mask"]) |
|
|
|
similarity = torch.matmul(embeddings.unsqueeze(0), embeddings.T).squeeze(0) |
|
print(similarity) |
|
|
|
# tensor([[21.4943, 10.4816, 6.2540], |
|
# [10.4816, 12.9024, 3.1939], |
|
# [ 6.2540, 3.1939, 12.8919]]) |
|
``` |
|
|
|
# ベンチマークスコア |
|
|
|
## retrieval (JMTEB) |
|
|
|
[JMTEB](https://github.com/sbintuitions/JMTEB) の評価結果です。japanese-splade-base-v1 は [JMTEB をスパースベクトルで評価できるように変更したコード](https://github.com/hotchpotch/JMTEB/tree/add_splade)での評価となっています。 |
|
なお、japanese-splade-base-v1 は jaqket, mrtydi のドメインを学習(testのデータ以外)しています。 |
|
|
|
| model_name | Avg. | jagovfaqs_22k | jaqket | mrtydi | nlp_journal_abs_intro | nlp_journal_title_abs | nlp_journal_title_intro | |
|
| :------------------------------------------------------------------------- | ------: | ------------: | -----: | -----: | ---------------------: | ---------------------: | -----------------------: | |
|
| [japanese-splade-base-v1](https://huggingface.co/hotchpotch/japanese-splade-base-v1) | **0.7465** | 0.6499 | **0.6992** | **0.4365** | 0.8967 | **0.9766** | 0.8203 | |
|
| [text-embedding-3-large](https://huggingface.co/OpenAI/text-embedding-3-large) | 0.7448 | 0.7241 | 0.4821 | 0.3488 | **0.9933** | 0.9655 | **0.9547** | |
|
| [GLuCoSE-base-ja-v2](https://huggingface.co/pkshatech/GLuCoSE-base-ja-v2) | 0.7336 | 0.6979 | 0.6729 | 0.4186 | 0.9029 | 0.9511 | 0.7580 | |
|
| [multilingual-e5-large](https://huggingface.co/intfloat/multilingual-e5-large) | 0.7098 | 0.7030 | 0.5878 | 0.4363 | 0.8600 | 0.9470 | 0.7248 | |
|
| [multilingual-e5-small](https://huggingface.co/intfloat/multilingual-e5-small) | 0.6727 | 0.6411 | 0.4997 | 0.3605 | 0.8521 | 0.9526 | 0.7299 | |
|
| [ruri-large](https://huggingface.co/cl-nagoya/ruri-large) | 0.7302 | **0.7668** | 0.6174 | 0.3803 | 0.8712 | 0.9658 | 0.7797 | |
|
|
|
|
|
## reranking |
|
|
|
### [JaCWIR](https://huggingface.co/datasets/hotchpotch/JaCWIR) |
|
|
|
なお、japanese-splade-base-v1 は **JaCWIR のドメインを学習していません**。 |
|
|
|
| model_names | map@10 | hit_rate@10 | |
|
| :------------------------------------------------------------------------------ | -----: | ----------: | |
|
| [japanese-splade-base-v1](https://huggingface.co/hotchpotch/japanese-splade-base-v1) | **0.9122** | **0.9854** | |
|
| [text-embedding-3-small](https://platform.openai.com/docs/guides/embeddings) | 0.8168 | 0.9506 | |
|
| [GLuCoSE-base-ja-v2](https://huggingface.co/pkshatech/GLuCoSE-base-ja-v2) | 0.8567 | 0.9676 | |
|
| [bge-m3+dense](https://huggingface.co/BAAI/bge-m3) | 0.8642 | 0.9684 | |
|
| [multilingual-e5-large](https://huggingface.co/intfloat/multilingual-e5-large) | 0.8759 | 0.9726 | |
|
| [multilingual-e5-small](https://huggingface.co/intfloat/multilingual-e5-small) | 0.869 | 0.97 | |
|
| [ruri-large](https://huggingface.co/cl-nagoya/ruri-large) | 0.8291 | 0.9594 | |
|
|
|
### [JQaRA](https://github.com/hotchpotch/JQaRA) |
|
なお、japanese-splade-base-v1 は JQaRA のドメイン(test以外)を学習したものとなっています。 |
|
|
|
| model_names | ndcg@10 | mrr@10 | |
|
| :------------------------------------------------------------------------------ | ------: | -----: | |
|
| [japanese-splade-base-v1](https://huggingface.co/hotchpotch/japanese-splade-base-v1) | **0.6441** | **0.8616** | |
|
| [text-embedding-3-small](https://platform.openai.com/docs/guides/embeddings) | 0.3881 | 0.6107 | |
|
| [bge-m3+dense](https://huggingface.co/BAAI/bge-m3) | 0.539 | 0.7854 | |
|
| [multilingual-e5-large](https://huggingface.co/intfloat/multilingual-e5-large) | 0.554 | 0.7988 | |
|
| [multilingual-e5-small](https://huggingface.co/intfloat/multilingual-e5-small) | 0.4917 | 0.7291 | |
|
| [GLuCoSE-base-ja-v2](https://huggingface.co/pkshatech/GLuCoSE-base-ja-v2) | 0.606 | 0.8359 | |
|
| [ruri-large](https://huggingface.co/cl-nagoya/ruri-large) | 0.6287 | 0.8418 | |
|
|
|
## 学習元データセット |
|
|
|
[hpprc/emb](https://huggingface.co/datasets/hpprc/emb) から、auto-wiki-qa, mmarco, jsquad jaquad, auto-wiki-qa-nemotron, quiz-works quiz-no-mori, miracl, jqara mr-tydi, baobab-wiki-retrieval, mkqa データセットを利用しています。 |
|
また英語データセットとして、MS Marcoを利用しています。 |
|
|
|
## 注意事項 |
|
|
|
text-embeddings-inference で動かす場合、[hotchpotch/japanese-splade-base-v1-dummy-fast-tokenizer-for-tei](https://huggingface.co/hotchpotch/japanese-splade-base-v1-dummy-fast-tokenizer-for-tei)をご利用ください。 |