Aratako's picture
Update README.md
92582b8 verified
|
raw
history blame
5.56 kB
metadata
base_model:
  - tokyotech-llm/Swallow-7b-hf
  - tokyotech-llm/Swallow-7b-instruct-hf
  - nitky/Superswallow-7b-v0.1
  - nitky/Superswallow-7b-v0.2
  - nitky/Superswallow-7b-v0.3
library_name: transformers
tags:
  - merge
  - moe
  - lisa
license: cc-by-nc-sa-4.0
datasets:
  - kunishou/amenokaku-code-instruct
  - llm-jp/oasst1-21k-en
  - hieunguyenminh/roleplay
  - meta-math/MetaMathQA
  - kunishou/jp-effective-instructions
language:
  - ja

Swallow-MoE-4x7B-lisa

概要

tokyotech-llm/Swallow-7b-hfをベースに、以下の4モデルをgate_mode=randomでMoEし、その後LISAという手法でインストラクションチューニングを施したモデルです。

お試しで作ってみたものなので、性能にはあまり期待しないでください。以下にベンチマーク結果も記載しております。

なお、この学習で使ったLISAの実装には不具合がある可能性が指摘されており、正常に学習できていない可能性があります。

データセット

以下の合計14327件のデータを学習に利用しました。プロンプトフォーマットはAlpacaを利用しています。

なお、ichikara-instructionの利用によりCC-BY-NC-SAを継承します。

学習の設定

主な学習パラメータは以下の通りです。なお、学習途中でのエラーのため2epochs程度しか学習できておりません。

  • lisa_activated_layers: 8
  • lisa_interval_steps: 13
  • learning_rate: 5e-5
  • num_train_epochs: 約2epochs
  • batch_size: 64
  • max_seq_length: 2048

評価

マージに利用したモデル群と本モデルのjapanese-mt-benchの結果は以下の通りです。(シングルターン)

Swallow-instructよりはスコアが高く、Superswallowよりは低いという何とも言えない結果になっております。 とはいえ、少量のデータセット・たった2epochsの学習でSwallow-instructを超えられているのは一定の成果とも言えるかもしれません。

Model Size Coding Extraction Humanities Math Reasoning Roleplay STEM Writing avg_score
Swallow-7b-instruct-hf 7B 2.0 4.6 5.4 1.7 2.8 5.0 5.9 6.9 4.2875
Superswallow-7b-v0.1 7B 2.0 5.1 7.8 2.1 3.6 6.2 7.3 7.5 5.2000
Superswallow-7b-v0.2 7B 2.2 5.8 6.7 2.5 4.3 5.5 6.6 5.8 4.9250
Superswallow-7b-v0.3 7B 2.1 4.6 8.3 2.1 5.0 6.3 7.7 8.9 5.6250
This model 4x7B 2.0 3.4 7.5 1.9 2.6 5.5 6.3 7.5 4.5875

レーダーチャート

同様に、jsquad(jsquad-1.1-0.3, 2-shots)、jcommonsenseqa(jcommonsenseqa-1.1-0.3, 3-shots)、jnli(jnli-1.3-0.3, 3-shots)、marc_ja(marc_ja-1.1-0.3, 3-shots)結果は以下の通りです。(jsquadは100で割り、それぞれ小数点以下第4位を四捨五入) ここでもSwallow-instructよりはスコアが高く、Superswallowよりは低い結果になっています。なお、こちらは参考として本モデルのインストラクションチューニング前(MoEのみ)のモデルのスコアも載せてあります。

Model Size jsquad(exact_match) jcommonsenseqa(acc) jnli(acc) marc_ja(acc) average
Swallow-7b-instruct-hf 7B 0.757 0.831 0.212 0.945 0.686
Superswallow-7b-v0.1 7B 0.441 0.846 0.374 0.966 0.657
Superswallow-7b-v0.2 7B 0.722 0.846 0.381 0.964 0.728
Superswallow-7b-v0.3 7B 0.721 0.850 0.362 0.964 0.724
This model without fine-tuning 4x7B 0.674 0.809 0.333 0.952 0.692
This model 4x7B 0.741 0.806 0.385 0.948 0.719