今天小编分享的科学经验:大模型“取长补短”新思路入选NeurIPS‘24,显著优于现有路由方法,南科大港科大出品,欢迎阅读。
高效组合多个大模型 " 取长补短 " 新思路,被顶会 NeurIPS 2024 接收。
名为RouterDC,是一种基于双重对比学习的路由架构,具有参数高效性(小于 100M 的参数)和计算高效性(不需要对于 LLM 进行梯度回传)的优势。
在具有挑战性语言理解、代码生成和数学推理等推理任务实验中,RouterDC 在分布内(+2.76%)和分布外(+1.90%)设定下,都远超于现有的 routing 方法。
众所周知,LLM 通常在不同数据集上预训练和微调,导致它们在不同任务上的性能强弱不同。
LLM 路由则是一种组合多个 LLM 的新思路,它通过学习一个路由器(Router)来为每一个请求(query)选择最合适的 LLM。在推理时,LLM 路由只需要调用所选的 LLM 进行推理,使其在保持计算高效性的同时利用多个 LLM 的互补能力。
RouterDC 这种新方法,包括一个较小的语言模型作为编码器和一系列与候选 LLM 对应的可学习的LLM embeddings。
对于训练数据中的每个 query,首先将候选 LLM 的预测与真实标签进行比较获得表现最好和最差的 LLM,然后构造两个对比损失:
sample-LLM 对比损失:使得 query embedding(由编码器提取)与表现最佳的 LLM embeddings 相似,同时与表现最差的 LLM embeddings 不相似。
sample-sample 对比损失:提高训练的稳定性,将所有训练 query 聚类成多个组,最大化同组 query 之间的相似性的同时最小化不同组 query 之间的相似性。
这项研究由来自南方科技大学,香港科技大学的研究团队提出,以下是更为详细的介绍。
双对比学习实现 Router 训练
Router 架构
如图 1 所示,RouterDC 包括一个较小的语言模型(mDeBERTaV3-base)作为编码器 ε,和一系列的与候选 LLM 对应的可学习 LLM 嵌入 kT。对于每个 query xi,RouterDC 生成对于 T 个 LLMs 的选择概率如下:
其中,sim ( · , · ) 表示 cosine 相似度。
△图 1:RouterDC 方法示意图
sample-LLM 对比损失
为了训练 router,研究者将 query 的样本嵌入和在其上表现最好的 K+ 个 LLM 对应嵌入拉进,和在其上表现最差的 K- 个 LLM 对应嵌入拉远。因此,样本 -LLM 对比损失可以表示为:
sample-sample 对比损失
研究者通过实验发现,在 routing 问题中只使用样本 -LLM 对比损失并不稳定,使得相似的 query 可能具有不相似的嵌入。
为了提升训练的鲁棒性,训练样本被聚类成不同的组,从而在训练中拉近同一个组内的样本,拉远不同组的样本。和样本 -LLM 对比损失类似,样本 - 样本对比损失可以公式化为:
训练及推理
最终的优化目标为最小化样本 -LLM 对比损失和样本 - 样本对比损失的结合:
推理时,每个测试 query 只需要通过训练好的 router 选取概率最大的 LLM,并使用选择的 LLM 对 query 进行回答。
RouterDC 在训练时不需要任何经过 LLM 的梯度回传,并且在推理时只需要调用进行一次 LLM,同时具有训练和推理的高效性。
实验效果如何?
主要结果
RouterDC 在分布内数据集的测试准确率结果如表 1 所示。可以发现:
RouterDC 显著好于最优的单个模型,平均具有 3.98% 性能提升。在单个任务的层面,RouterDC 在三个任务上相比表现最优的单个模型取得了准确率的提升,其中 GSM8K 提升了 0.51%,ARC-C 提升了 0.57%,HumanEval 提升了 1.63%。
和现有路由方法 CosineClassifier 以及 ZOOTER 对比,RouterDC 在所有任务上都具有更好的表现。和 LoraRetriever 对比,RouterDC 具有平均 2.77% 的准确率提升。
△表 1:分布内任务的测试准确率(%)
为了评估 RouterDC 的泛化能力,表 2 展示了 RouterDC 在三个分布外数据集(PreAlgebra,MBPP,C-EVAL)的测试准确率。
可以看出,RouterDC 再次达到最高的测试准确率,显著超过表现最佳的单个 LLM(dolphin-2.9-llama3-8b)1.9%。
△表 2:分布外任务的测试准确率(%)
sample-sample 损失的作用
为了探究样本 - 样本损失的作用,图 3 展示了在是否有样本 - 样本损失的条件下训练和测试准确率曲线。可以看出,RouterDC(w/o Lsample-sample)有明显的震荡现象,而 RouterDC 则稳定得多。
△图 2:RouterDC 在 GSM8K 任务上的训练和测试准确率曲线
图 3(a)可视化了使用 RouterDC(w/o Lsample-sample)提取的训练样本的 TSNE 特征,可以看到,属于不同任务的训练样本粗略地混合在一起。而在结合 Lsample-sample 之后,训练样本有了清晰的聚类结构(如图 3(b)所示)。
△图 3:学习到的 router 所提取出训练样本 embedding 的 t-SNE 可视化
RouterDC 具有成本高效性
由于价格(cost)同样是一个评估 LLM 的重要指标,研究者通过 RouterBench 上的两个任务的实验来格外考虑 cost 的影响。如图 16 所示,RouterDC 相比于 CosineClassifier 和 ZOOTER 更加的成本高效。
△图 4:在 RouterBench 上使用不同的 Cost 获取的测试准确率
论文地址:https://arxiv.org/abs/2409.19886
代码地址:https://github.com/shuhao02/RouterDC
— 完 —
投稿请发邮件到:
标题注明【投稿】,告诉我们:
你是谁,从哪来,投稿内容
附上论文 / 项目主页链接,以及联系方式哦
我们会(尽量)及时回复你
点这里关注我,记得标星哦~
一键三连「分享」、「点赞」和「在看」
科技前沿进展日日相见 ~
>