今天小編分享的科學經驗:大模型“取長補短”新思路入選NeurIPS‘24,顯著優于現有路由方法,南科大港科大出品,歡迎閱讀。
高效組合多個大模型 " 取長補短 " 新思路,被頂會 NeurIPS 2024 接收。
名為RouterDC,是一種基于雙重對比學習的路由架構,具有參數高效性(小于 100M 的參數)和計算高效性(不需要對于 LLM 進行梯度回傳)的優勢。
在具有挑戰性語言理解、代碼生成和數學推理等推理任務實驗中,RouterDC 在分布内(+2.76%)和分布外(+1.90%)設定下,都遠超于現有的 routing 方法。
眾所周知,LLM 通常在不同數據集上預訓練和微調,導致它們在不同任務上的性能強弱不同。
LLM 路由則是一種組合多個 LLM 的新思路,它通過學習一個路由器(Router)來為每一個請求(query)選擇最合适的 LLM。在推理時,LLM 路由只需要調用所選的 LLM 進行推理,使其在保持計算高效性的同時利用多個 LLM 的互補能力。
RouterDC 這種新方法,包括一個較小的語言模型作為編碼器和一系列與候選 LLM 對應的可學習的LLM embeddings。
對于訓練數據中的每個 query,首先将候選 LLM 的預測與真實标籤進行比較獲得表現最好和最差的 LLM,然後構造兩個對比損失:
sample-LLM 對比損失:使得 query embedding(由編碼器提取)與表現最佳的 LLM embeddings 相似,同時與表現最差的 LLM embeddings 不相似。
sample-sample 對比損失:提高訓練的穩定性,将所有訓練 query 聚類成多個組,最大化同組 query 之間的相似性的同時最小化不同組 query 之間的相似性。
這項研究由來自南方科技大學,香港科技大學的研究團隊提出,以下是更為詳細的介紹。
雙對比學習實現 Router 訓練
Router 架構
如圖 1 所示,RouterDC 包括一個較小的語言模型(mDeBERTaV3-base)作為編碼器 ε,和一系列的與候選 LLM 對應的可學習 LLM 嵌入 kT。對于每個 query xi,RouterDC 生成對于 T 個 LLMs 的選擇概率如下:
其中,sim ( · , · ) 表示 cosine 相似度。
△圖 1:RouterDC 方法示意圖
sample-LLM 對比損失
為了訓練 router,研究者将 query 的樣本嵌入和在其上表現最好的 K+ 個 LLM 對應嵌入拉進,和在其上表現最差的 K- 個 LLM 對應嵌入拉遠。因此,樣本 -LLM 對比損失可以表示為:
sample-sample 對比損失
研究者通過實驗發現,在 routing 問題中只使用樣本 -LLM 對比損失并不穩定,使得相似的 query 可能具有不相似的嵌入。
為了提升訓練的魯棒性,訓練樣本被聚類成不同的組,從而在訓練中拉近同一個組内的樣本,拉遠不同組的樣本。和樣本 -LLM 對比損失類似,樣本 - 樣本對比損失可以公式化為:
訓練及推理
最終的優化目标為最小化樣本 -LLM 對比損失和樣本 - 樣本對比損失的結合:
推理時,每個測試 query 只需要通過訓練好的 router 選取概率最大的 LLM,并使用選擇的 LLM 對 query 進行回答。
RouterDC 在訓練時不需要任何經過 LLM 的梯度回傳,并且在推理時只需要調用進行一次 LLM,同時具有訓練和推理的高效性。
實驗效果如何?
主要結果
RouterDC 在分布内數據集的測試準确率結果如表 1 所示。可以發現:
RouterDC 顯著好于最優的單個模型,平均具有 3.98% 性能提升。在單個任務的層面,RouterDC 在三個任務上相比表現最優的單個模型取得了準确率的提升,其中 GSM8K 提升了 0.51%,ARC-C 提升了 0.57%,HumanEval 提升了 1.63%。
和現有路由方法 CosineClassifier 以及 ZOOTER 對比,RouterDC 在所有任務上都具有更好的表現。和 LoraRetriever 對比,RouterDC 具有平均 2.77% 的準确率提升。
△表 1:分布内任務的測試準确率(%)
為了評估 RouterDC 的泛化能力,表 2 展示了 RouterDC 在三個分布外數據集(PreAlgebra,MBPP,C-EVAL)的測試準确率。
可以看出,RouterDC 再次達到最高的測試準确率,顯著超過表現最佳的單個 LLM(dolphin-2.9-llama3-8b)1.9%。
△表 2:分布外任務的測試準确率(%)
sample-sample 損失的作用
為了探究樣本 - 樣本損失的作用,圖 3 展示了在是否有樣本 - 樣本損失的條件下訓練和測試準确率曲線。可以看出,RouterDC(w/o Lsample-sample)有明顯的震蕩現象,而 RouterDC 則穩定得多。
△圖 2:RouterDC 在 GSM8K 任務上的訓練和測試準确率曲線
圖 3(a)可視化了使用 RouterDC(w/o Lsample-sample)提取的訓練樣本的 TSNE 特征,可以看到,屬于不同任務的訓練樣本粗略地混合在一起。而在結合 Lsample-sample 之後,訓練樣本有了清晰的聚類結構(如圖 3(b)所示)。
△圖 3:學習到的 router 所提取出訓練樣本 embedding 的 t-SNE 可視化
RouterDC 具有成本高效性
由于價格(cost)同樣是一個評估 LLM 的重要指标,研究者通過 RouterBench 上的兩個任務的實驗來格外考慮 cost 的影響。如圖 16 所示,RouterDC 相比于 CosineClassifier 和 ZOOTER 更加的成本高效。
△圖 4:在 RouterBench 上使用不同的 Cost 獲取的測試準确率
論文地址:https://arxiv.org/abs/2409.19886
代碼地址:https://github.com/shuhao02/RouterDC
— 完 —
投稿請發郵件到:
标題注明【投稿】,告訴我們:
你是誰,從哪來,投稿内容
附上論文 / 項目主頁鏈接,以及聯系方式哦
我們會(盡量)及時回復你
點這裡關注我,記得标星哦~
一鍵三連「分享」、「點贊」和「在看」
科技前沿進展日日相見 ~
>