今天小編分享的科學經驗:LIama 3+Mamba強強聯手!蒸餾到線性RNN,推理速度提升1.6倍,歡迎閲讀。
把 Llama 3 蒸餾到 Mamba,推理速度最高可提升 1.6 倍!
而且性能不減,甚至表現比原始模型還要優異。
這是來自 Together AI 的新作,通過蒸餾将 Transformer 和 Mamba 模型結合到了一起,同時還為混合模型涉及了推理加速算法
提出 Mamba 架構的大神、FlashAttention 作者 Tri Dao,也參與了這一項目。
Together AI 創始人兼 CEO 表示,Transformer 和 Mamba 的混合,是未來大模型的一大發展方向。
将 Transformer 蒸餾進 Mamba
在蒸餾正式開始之前,需要先進行從 Transformer 到線性 RNN 的初始化。
作者觀察到,Transformer 的注意力機制與 RNN 的計算之間存在一定的相似性。
因此可以将 Transformer 的注意力線性化,從而建立二者的聯系。
利用這種對應關系,可以将預訓練的 Transformer 模型的參數復制到 Mamba 模型中。
在完成參數初始化後,作者采用了一個三階段的蒸餾流程進一步提升 Mamba 模型的性能,使其更好地學習 Transformer 的知識。
第一階段是基于偽标籤的蒸餾——使用預訓練的 Transformer 教師模型在無标籤數據上生成偽标籤,然後讓 Mamba 學生模型在這些偽标籤上訓練。
這一過程的損失函數結合了 KL 散度損失和交叉熵損失,分别用于模仿教師模型輸出分布以及偽标籤的拟合。
第二階段是在指令數據集上進行的監督微調,使用帶标籤的指令數據集(如 OpenHermes 2.5)進行訓練。
最後一個階段,是用人類反饋數據,通過基于獎勵模型進行優化。
作者收集了人類對模型輸出的反饋數據,然後據此構建一個獎勵模型并使用 RL 算法(如 PPO)來優化模型在該獎勵模型下的表現。
在 8 塊 80G A100 GPU 上,每個混合模型的整個蒸餾過程,只需不到五天的時間。
通過以上的蒸餾過程,作者得到了 Transformer-Mamba 混合模型,之後又提出了 Speculative Decoding(推測解碼)算法來加速推理過程。
混合模型推理加速算法
推測解碼算法的基本思想是使用一個輕量級的 Draft 模型來預測多個 token,然後再用驗證模型(Verifier)來驗證這些預測。
這樣可以顯著提高解碼的并行性,加速生成過程。
Draft 模型通常是一個小的 Transformer,根據當前的上下文預測出接下來的 K 個 token。
對于預測出的 K 個 token,Transformer 層可以直接并行地處理這 K 個 token,計算它們的隐狀态;
Mamba 層則需要按照順序依次處理每個 token,首先計算當前 token 的隐狀态,并将其與之前的隐狀态進行比較。
如果當前 token 是正确的,則将其添加到已接受的序列中,并更新最新的隐狀态(但不保存中間狀态)。
如果當前 token 是錯誤的,則停止處理後續 token,并将最新的隐狀态回退到上一個已接受的 token 處。
如果序列中的所有 K 個 token 都被接受,則将它們添加到輸出序列中,并繼續預測下一組 token。
如果有 token 被拒絕,則從第一個被拒絕的 token 處截斷預測序列,并返回初始步驟從該位置開始重新預測。
Llama 3 推理速度提升 1.6 倍
測試結果表明,混合模型在單論(AlpacaEval)和多輪(MT-Bench)聊天對話任務上與 Llama-3 相當甚至更優。
并且還對不同混合比例的模型表現進行了測試,發現其中按照 1:1 比例混合的模型表現最佳。
在零樣本的通用 NLP 任務評測中,混合模型的平均成績優于同等規模的 RNN 模型。
在少樣本的 OpenLLM Leaderboard 榜單上,混合模型的表現與最好的開源 RNN 模型相當,并在 GSM8K 和 CRUX 任務上超過了對應的 Instruct 模型。
除了模型性能,作者也對推測解碼算法帶來的加速效果進行了測試。
首先測試的是純 Mamba 模型,結果在 2.8B 和 7B 的模型上,相比原來的解碼方式,推理速度提升了 1.7-2.6 倍。
進一步地,作者在蒸餾的 Zephyr 和 Llama 混合模型上進行了測試,結果 Zephyr 混合模型的推理速度提升了 1.8 倍以上,Llama 混合模型也有 1.6 倍左右的加速。
論文地址:
https://www.together.ai/blog/the-mamba-in-the-llama-distilling-and-accelerating-hybrid-models