MTPは推論速度だけでなくコーディング性能も向上させる

論文紹介

MTPによる投機的デコーディングが、推論速度だけでなくコーディングなど一部のタスクの回答精度も向上させる可能性があります。

NTPとMTP

NTP(Next token prediction)とは、従来どおりひとつのヘッドが順番に次のトークンを生成する仕組みです。MTP(Multi token prediction)は、NTPヘッドに加えて複数の小さなヘッドが数トークン先読みしてトークンを生成します。それらの先読みトークンの扱いは学習時と推論時で異なります。学習時は教師による正解のトークンと比較してロスの計算に使われるのに対して、推論時はNTPが採用・棄却を決定します。

コーディング性能を向上させる

学習時からMTPヘッドを有効化しておくと、HumanEval, MBPPのベンチマークテストの成績が向上したといいます。13BパラメータのLLMでは、HumanEvalで12%、MBPPで17%の向上が見られたとのことです。
ただし、自然言語による選択式の問題を扱うベンチマークテストであるARC, HellaSwagなどを7BパラメータのLLMで試したところ、n=4(席y見するトークンの数)の場合成績が僅かに低下し、n=2ではNTPのみでの学習時と同等でした。一方で、自然言語の要約タスク(ROGUE-L)ではn=2, n=4ともに性能が改善されました。
○✕クイズではなく文章生成系のタスクで性能が向上するということかもしれません。

性能向上のメカニズム

私自身正確に理解していないので簡単な説明にとどめておきます。

Induction head(帰納ヘッド)の形成促進

帰納ヘッドとは、「以前ABというパターンが現れたため、またAが現れたらBが現れることを予測する」ということを行うヘッドです。30Mパラメータ以下の言語モデルの学習実験で、NTPのみだとこの帰納ヘッドがほとんど形成されないのに対して、MTPを有効化すると形成されやすくなることが確認されたといいます。

算術の汎化

多項式を計算するタスクにおいて、多くの演算ステップを必要とする式を解く際に、NTPではそのステップ数が増えると急激に計算精度が落ちるのに対して、MTPでは各ステップを正確にこなせました。これは局所的な計算パターンの過学習(この多項式が出たら、解き方はわからないけどこの答えになるよねという考え方?)を抑制し、より大域的な構造を学習している(式と答えのパターンではなく解き方を理解している?)ということだそうです(AI談)。疑問符つきの括弧書きは私の推測です。

但し書き

スケールしないと効かない

だいたい1Bパラメータ以下のモデルでは、性能面でMTPがNTPに負ける場合があります。

nの最適な値はモデルのアーキテクチャやタスクによって変わる

コーディングではn=4, トークンでなくバイトを予測するモデルでのコーディングはn=8が最適だといいます。

学習時点からMTPを有効化する必要がある

学習済みのモデルにMTPヘッドを後付けすることもできますが、それでは性能の向上は見られません。出力自体はNTPのみの場合と同等であり、生成速度だけが向上します。

引用文献

Better & Faster Large Language Models via Multi-token Prediction

コメント

タイトルとURLをコピーしました