news 2026/1/9 6:58:11

【论文阅读笔记-meta rl】多任务批强化学习与度量学习方法 MBML (Multi-task Batch RL with Metric Learning):解决任务推断中的伪相关性

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【论文阅读笔记-meta rl】多任务批强化学习与度量学习方法 MBML (Multi-task Batch RL with Metric Learning):解决任务推断中的伪相关性

NeurIPS 2020

Li, J., Vuong, Q., Liu, S., Liu, M., Ciosek, K., Christensen, H., & Su, H. (2020). Multi-task batch reinforcement learning with metric learning.Advances in neural information processing systems,33, 6197-6210.

在多任务强化学习中,我们通常希望训练一个能够泛化到多个任务的策略。然而,当训练数据来自不同任务且它们的状态-动作分布差异较大时,策略可能会学习到错误的任务推断方式——即仅根据状态-动作对判断任务,而忽略奖励信号。这会导致在未见任务上表现不佳。本文提出了一种结合三元组损失过渡重新标记的方法,MBML(Multi-task Batch RL with Metric Learning),通过近似奖励函数对跨任务数据进行重标注,构造"难负样本",再用三元组损失迫使任务推断模块必须依赖奖励信息而非仅靠状态-动作模式,强制任务推断模块同时考虑状态、动作和奖励,从而提升泛化性能。此外,训练好的策略作为初始化可大幅提升后续训练的收敛速度。

文章目录

    • 一、直观例子:为什么只看状态-动作对会出问题?
    • 二、研究背景与问题定义
      • 强化学习基础
      • 批强化学习
      • 多任务批强化学习问题
      • 任务推断模块
    • 三、核心挑战:伪相关性与错误的任务依赖
    • 四、方法:MBML —— 通过度量学习增强任务推断
      • 1. 过渡重新标记:构建硬负例
      • 2. 三元组损失:强制模型考虑奖励
      • 3. 总损失函数
      • 4. 算法流程
    • 五、实验场景与结果
      • 实验环境
      • 基线方法
      • 主要结果
      • 分析与讨论
    • 六、总结与展望

一、直观例子:为什么只看状态-动作对会出问题?

假设我们要训练一个智能体在二维平面上导航到不同目标位置的任务。我们有两个训练任务:

  • 任务1:导航到目标位置 Goal 1
  • 任务2:导航到目标位置 Goal 2

我们收集了两个数据集:

  • 任务1的数据(红色方块)主要集中在 Goal 1 周围
  • 任务2的数据(蓝色方块)主要集中在 Goal 2 周围

由于两个目标位置相距较远,红色与蓝色方块在状态-动作空间中没有重叠。如果我们训练一个任务推断模块,它可能学会:

“红色方块 → 任务1,蓝色方块 → 任务2”

而完全忽略了奖励信号(例如距离目标越近奖励越高)。

测试时的问题:在未见任务中(例如真实目标是 Goal 1),智能体随机收集了一些过渡数据(绿色方块)。如果这些绿色方块在状态空间上与蓝色方块更接近,模型会错误地推断当前任务是任务2,导致智能体向错误的目标移动。

这个例子揭示了多任务批强化学习中的一个核心挑战:当训练任务的数据分布差异大时,模型容易学习到伪相关性,仅依赖状态-动作对进行任务推断,而忽略了奖励信号。

二、研究背景与问题定义

强化学习基础

强化学习(RL)中,智能体通过与环境交互学习一个策略,以最大化累积奖励。一个任务通常建模为一个马尔可夫决策过程(MDP):

  • 状态空间S \mathcal{S}S
  • 动作空间A \mathcal{A}A
  • 转移函数T ( s ′ ∣ s , a ) T(s' \mid s, a)T(ss,a)
  • 奖励函数R ( s , a , s ′ ) R(s, a, s')R(s,a,s)
  • 初始状态分布T 0 T_0T0

策略π ( a ∣ s ) \pi(a \mid s)π(as)是一个从状态到动作的映射。目标是最优化:
J ( π ) = E τ ∼ π [ ∑ t = 0 H − 1 R ( s t , a t , s t ′ ) ] J(\pi) = \mathbb{E}_{\tau \sim \pi} \left[ \sum_{t=0}^{H-1} R(s_t, a_t, s'_t) \right]J(π)=Eτπ[t=0H1R(st,at,st)]

批强化学习

批强化学习(Batch RL)指的是仅使用一个预先收集的离线数据集B = { ( s t , a t , r t , s t ′ ) } \mathcal{B} = \{(s_t, a_t, r_t, s'_t)\}B={(st,at,rt,st)}训练策略,而不允许与环境进行额外交互。典型的算法如BCQ(Batch Constrained Q-Learning),它通过生成候选动作并添加微小扰动来进行受限探索。

多任务批强化学习问题

给定K KK个批数据集{ B i } i = 1 K \{\mathcal{B}_i\}_{i=1}^K{Bi}i=1K,每个数据集来自一个不同的任务M i M_iMi。我们的目标是训练一个多任务策略π θ \pi_\thetaπθ,使其在从同一任务分布p ( M ) p(M)p(M)中采样的未见任务上表现良好。关键挑战是:

  1. 测试时任务身份未知,策略必须从收集的过渡数据中推断任务;
  2. 不同任务的数据集可能在状态-动作分布上差异很大(即分布不重叠),导致任务推断模块可能忽略奖励信号。

在在线多任务RL中,策略可以通过持续采集数据逐步消除分布差异,实现"自纠错"。但在Batch RL中,数据是静态的,任务推断模块一旦学到错误依赖关系就无法修正。这正是本文聚焦的核心难题。

任务推断模块

我们使用一个任务推断模块q ϕ q_\phiqϕ,它接收一个上下文集c i \mathbf{c}_ici(来自任务i ii的一组过渡数据),输出一个任务身份的后验分布q ϕ ( z ∣ c i ) q_\phi(z \mid \mathbf{c}_i)qϕ(zci)。策略则同时接收状态s ss和推断出的任务身份z zzπ ( s , z ) \pi(s, z)π(s,z)

三、核心挑战:伪相关性与错误的任务依赖

在训练中,我们希望通过蒸馏多个单任务策略来得到一个多任务策略。具体来说,我们:

  1. 为每个训练任务训练一个 BCQ 策略(得到Q i , G i , ξ i Q_i, G_i, \xi_iQi,Gi,ξi);
  2. 将这些策略蒸馏为一个多任务策略(Q D , G D , ξ D Q_D, G_D, \xi_DQD,GD,ξD),其输入除了状态外还包括推断的任务身份z zz

蒸馏损失函数(以值函数为例)为:
L Q = 1 K ∑ i = 1 K E ( s , a ) , c i ∼ B i [ ( Q i ( s , a ) − Q D ( s , a , z i ) ) 2 + β KL ( q ϕ ( c i ) ∥ N ( 0 , 1 ) ) ] , z i ∼ q ϕ ( c i ) \mathcal{L}_Q = \frac{1}{K} \sum_{i=1}^K \mathbb{E}_{(s,a),\mathbf{c}_i \sim \mathcal{B}_i} \left[ (Q_i(s,a) - Q_D(s,a,z_i))^2 + \beta \text{KL}(q_\phi(\mathbf{c}_i) \| \mathcal{N}(0,1)) \right], \quad z_i \sim q_\phi(\mathbf{c}_i)LQ=K1i=1KE(s,a),ciBi[(Qi(s,a)QD(s,a,zi))2+βKL(qϕ(ci)N(0,1))],ziqϕ(ci)

然而,当不同任务的数据集在状态-动作分布上不重叠时,模型可能学会:
P ( Z ∣ S , A ) 而非正确的 P ( Z ∣ S , A , R ) P(Z \mid S, A) \quad \text{而非正确的} \quad P(Z \mid S, A, R)P(ZS,A)而非正确的P(ZS,A,R)

即,模型仅依赖状态-动作对推断任务,而忽略了奖励。这会导致在测试时,如果收集的过渡数据与某个训练任务的状态-动作分布更接近,即使奖励模式不同,模型也会错误地推断任务身份。

四、方法:MBML —— 通过度量学习增强任务推断

为了解决上述问题,我们提出了MBML(Multi-task Batch RL with Metric Learning),其核心是三元组损失过渡重新标记

1. 过渡重新标记:构建硬负例

我们为每个训练任务i ii学习一个奖励函数近似器R ^ i \hat{R}_iR^i。给定一个来自任务j jj的上下文集c j \mathbf{c}_jcj,我们用R ^ i \hat{R}_iR^i重新标记其中的奖励,得到:
c j → i = { ( s j , t , a j , t , R ^ i ( s j , t , a j , t ) , s j , t ′ ) } \mathbf{c}_{j \to i} = \{(s_{j,t}, a_{j,t}, \hat{R}_i(s_{j,t}, a_{j,t}), s'_{j,t})\}cji={(sj,t,aj,t,R^i(sj,t,aj,t),sj,t)}
这相当于将任务j jj的数据“伪装”成任务i ii的数据(状态-动作对相同,但奖励不同)。

2. 三元组损失:强制模型考虑奖励

对于每个任务i ii,我们构建三元组:

  • 锚点(Anchor)c j → i \mathbf{c}_{j \to i}cji(重新标记后的数据)
  • 正例(Positive)c i \mathbf{c}_ici(原始任务i ii的数据)
  • 负例(Negative)c j \mathbf{c}_jcj(原始任务j jj的数据)

三元组损失定义为:
L triplet i = 1 K − 1 ∑ j ≠ i [ d ( q ϕ ( c j → i ) , q ϕ ( c i ) ) ⏟ 锚点-正样本距离 − d ( q ϕ ( c j → i ) , q ϕ ( c j ) ) ⏟ 锚点-负样本距离 + m ] + \mathcal{L}_{\text{triplet}}^i = \frac{1}{K-1} \sum_{j\neq i} \Big[ \underbrace{d\big(q_\phi(\mathbf{c}_{j\to i}), q_\phi(\mathbf{c}_i)\big)}_{\text{锚点-正样本距离}} - \underbrace{d\big(q_\phi(\mathbf{c}_{j\to i}), q_\phi(\mathbf{c}_j)\big)}_{\text{锚点-负样本距离}} + m \Big]_+Ltripleti=K11j=i[锚点-正样本距离d(qϕ(cji),qϕ(ci))锚点-负样本距离d(qϕ(cji),qϕ(cj))+m]+
其中:

  • d ( ⋅ , ⋅ ) d(\cdot,\cdot)d(,)是散度度量,本文使用 KL 散度
  • m > 0 m > 0m>0是 margin,确保正样本比负样本至少近m mm
  • [ ⋅ ] + = max ⁡ ( ⋅ , 0 ) [\cdot]_+ = \max(\cdot, 0)[]+=max(,0)是 ReLU

直观理解

  • 第一项:鼓励c j → i \mathbf{c}_{j \to i}cjic i \mathbf{c}_ici推断出相似的任务身份;
  • 第二项:鼓励c j → i \mathbf{c}_{j \to i}cjic j \mathbf{c}_jcj推断出不同的任务身份。

关键点:由于c j → i \mathbf{c}_{j \to i}cjic j \mathbf{c}_jcj的状态-动作对完全相同,唯一的区别是奖励。因此,为了最小化该损失,任务推断模块必须考虑奖励信息P ( Z ∣ S , A , R ) P(Z|S,A,R)P(ZS,A,R)

3. 总损失函数

最终损失为蒸馏损失与三元组损失的加权和:
L = L triplet + L Q + L G + L ξ \mathcal{L} = \mathcal{L}_{\text{triplet}} + \mathcal{L}_Q + \mathcal{L}_G + \mathcal{L}_\xiL=Ltriplet+LQ+LG+Lξ

其中,
L Q = 1 K ∑ i = 1 K E ( s , a ) , c i ∼ B i [ ( Q i ( s , a ) − Q D ( s , a , z i ) ) 2 + β KL ( q ϕ ( c i ) ∥ N ( 0 , 1 ) ) ] , z i ∼ q ϕ ( c i ) \mathcal{L}_Q = \frac{1}{K} \sum_{i=1}^K \mathbb{E}_{(s,a),\mathbf{c}_i \sim \mathcal{B}_i} \left[ (Q_i(s,a) - Q_D(s,a,z_i))^2 + \beta \text{KL}(q_\phi(\mathbf{c}_i) \| \mathcal{N}(0,1)) \right], \quad z_i \sim q_\phi(\mathbf{c}_i)LQ=K1i=1KE(s,a),ciBi[(Qi(s,a)QD(s,a,zi))2+βKL(qϕ(ci)N(0,1))],ziqϕ(ci)

L G = 1 K ∑ i = 1 K E s , c i ∼ B i ∥ G i ( s , ν ) − G D ( s , ν , z ˉ i ) ∥ 2 \mathcal{L}_G = \frac{1}{K} \sum_{i=1}^K \mathbb{E}_{s,\mathbf{c}_i \sim \mathcal{B}_i} \| G_i(s,\nu) - G_D(s,\nu,\bar{\mathbf{z}}_i) \|^2LG=K1i=1KEs,ciBiGi(s,ν)GD(s,ν,zˉi)2

L ξ = 1 K ∑ i = 1 K E s , c i ∼ B i ν ∼ N ( 0 , 1 ) ∥ ξ i ( s , a ) − ξ D ( s , a , z ˉ i ) ∥ 2 , a = G i ( s , ν ) \mathcal{L}_\xi = \frac{1}{K} \sum_{i=1}^K \mathbb{E}_{\substack{s,\mathbf{c}_i \sim \mathcal{B}_i \\ \nu \sim \mathcal{N}(0,1)}} \| \xi_i(s,a) - \xi_D(s,a,\bar{\mathbf{z}}_i) \|^2, \quad a = G_i(s,\nu)Lξ=K1i=1KEs,ciBiνN(0,1)ξi(s,a)ξD(s,a,zˉi)2,a=Gi(s,ν)

记号z ˉ i \bar{\mathbf{z}}_izˉi表示梯度停止(stop gradient),即L G \mathcal{L}_GLGL ξ \mathcal{L}_\xiLξ不用于更新q ϕ q_\phiqϕ,避免任务推断模块被生成质量差的动作误导。

4. 算法流程

MBML 分为两个阶段:

  1. 单任务策略训练:使用 BCQ 为每个训练任务训练独立的策略;
  2. 多任务策略蒸馏:结合三元组损失,将单任务策略蒸馏为多任务策略。

详细的伪代码见原文附录 E。

五、实验场景与结果

实验环境

我们在 5 个 MuJoCo 任务分布和 1 个修改后的 D4RL 任务上评估 MBML:

  • AntDir:蚂蚁朝目标方向奔跑
  • HumanoidDir-M:人形机器人朝目标方向奔跑(修改版本,避免平凡解)
  • AntGoal:蚂蚁导航至目标位置
  • UmazeGoal-M:在 U 型迷宫中导航至目标位置
  • HalfCheetahVel:猎豹维持目标速度
  • WalkerParam:通过随机物理参数改变转移函数
任务分布任务定义方式训练任务数测试任务数状态空间挑战点
AntDir目标奔跑方向(120°弧内)108proprioceptive state状态不含方向信息
HumanoidDir-M目标奔跑方向108同上奖励系数调整后任务差异显著
AntGoal目标位置(120°弧内)108同上需导航到不同位置
HalfCheetahVel目标速度108同上速度控制精度
WalkerParam物理参数(质量、摩擦等)308同上转移函数变化
UmazeGoal-M迷宫目标位置108位置+速度稀疏奖励场景

基线方法

  • PEARL(修改为批训练版本)
  • Contextual BCQ(在 BCQ 基础上增加任务推断模块)
  • MetaGenRL(基于 DDPG,在批设定中容易发散)

主要结果

  1. 在未见任务上的表现:MBML 在所有任务分布上均优于基线方法。

    • PEARL 在某些任务上表现尚可,但未针对 Batch RL 的离线特性优化;Contextual BCQ 稳定但收敛到次优解;MetaGenRL 快速发散。
  2. 消融实验

    • Full Model (MBML):完整方法
    • No Relabeling (NR):仅使用原始上下文,通过大批量采样构造难负样本(计算复杂度高)
    • No Triplet Loss (NT):仅将重标注数据加入输入,但无三元组损失
    • Neither:简单蒸馏 + 任务推断模块
    • Ground Truth (GT):使用真实奖励函数重标注(作为性能上限)

    结果

    • 在 5/6 个任务上,完整模型显著优于所有消融版本
    • NR因计算效率低且难负样本质量差,性能下降
    • NT虽利用奖励信息,但无显式约束,提升有限
    • Neither完全失效,验证了三元组损失的必要性
    • GT与 MBML 性能接近,说明学习的奖励函数足够准确
  3. 作为初始化的加速效果:将训练好的多任务策略用于初始化 SAC,在未见任务上训练时,收敛速度提升高达80%

分析与讨论

  • 为什么三元组损失有效:它强制模型区分仅奖励不同的过渡数据,从而学习到正确的任务依赖P ( Z ∣ S , A , R ) P(Z \mid S, A, R)P(ZS,A,R)
  • 计算效率:使用重新标记的三元组损失计算复杂度为O ( K 2 ) O(K^2)O(K2),而传统硬负例挖掘需要O ( K 2 N 2 ) O(K^2 N^2)O(K2N2)
  • 奖励预测的泛化性:即使状态-动作分布不重叠,学到的奖励函数也能在一定程度上泛化,足以支撑三元组损失的有效性。

六、总结与展望

本文提出了 MBML,一种通过三元组损失过渡重新标记增强任务推断能力的多任务批强化学习方法。本文的核心在于通过度量学习增强任务推断的鲁棒性,尤其是在数据分布不重叠的批强化学习设定中。该方法不仅提升了泛化性能,还提供了一种高效的策略初始化方案,为实际应用中的样本效率问题提供了有希望的解决方案。核心贡献在于:

  1. 指出了在多任务批 RL 中,由于数据集分布不重叠导致任务推断模块可能忽略奖励信号的问题;
  2. 提出了一种新颖的三元组损失设计,通过奖励重新标记构建硬负例,强制模型考虑奖励;
  3. 实验表明,MBML 在多个任务分布上优于基线方法,并且训练好的策略可作为优质初始化,大幅提升后续训练效率。

Limitation

  • 文章主要讨论了奖励函数差异的情况,其技术细节和公式也集中围绕奖励重标注展开。但对于转移函数(transition function)不同的情况,文章的讨论相对简略。

  • 在原文Algorithm 3Appendix D的重标注过程中,仅对奖励进行替换,而保留了原始的下一个状态

  • 文章明确承认这是一个局限)在Discussion (Sec. 4)的 Limitations 部分,作者写道:

    “We also assume the learnt reward function of one task can generalize to state-action pairs from the other tasks, even when their state-action visitation frequencies do not overlap significantly.”

    这表明方法的核心假设是奖励函数可跨任务泛化,但未对转移函数的可泛化性做出类似假设

  • 实验中做了转移函数差异任务)文章在Sec. 5.1提到:

    “We also consider the WalkerParam environment where random physical parameters parameterize the agent, inducingdifferent transition functionsin each task.”

    WalkerParam 在实验中确实被包含,且 MBML 表现良好(+50% 超过 Contextual BCQ)。这说明即使不直接处理转移函数差异,方法依然有效

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/1/2 5:22:33

uni-app跨平台开发终极指南:一次编写,多端运行

uni-app跨平台开发终极指南:一次编写,多端运行 【免费下载链接】uni-app A cross-platform framework using Vue.js 项目地址: https://gitcode.com/dcloud/uni-app 还在为不同平台重复编写代码而烦恼吗?uni-app正是你需要的解决方案&…

作者头像 李华
网站建设 2026/1/6 21:38:03

终极指南:如何在5分钟内掌握SmoothScroll平滑滚动技术

终极指南:如何在5分钟内掌握SmoothScroll平滑滚动技术 【免费下载链接】smoothscroll Scroll Behavior polyfill 项目地址: https://gitcode.com/gh_mirrors/smo/smoothscroll SmoothScroll是一个轻量级JavaScript库,专门为网页提供流畅的平滑滚动…

作者头像 李华
网站建设 2026/1/7 21:02:47

AlphaFold解码蛋白质进化足迹:从分子化石到功能重建

AlphaFold解码蛋白质进化足迹:从分子化石到功能重建 【免费下载链接】alphafold Open source code for AlphaFold. 项目地址: https://gitcode.com/GitHub_Trending/al/alphafold 在生命演化的长河中,蛋白质如同分子化石,记录着亿万年…

作者头像 李华
网站建设 2026/1/8 4:36:18

2025视频生成平民化:WanVideo_comfy如何让RTX 4060也能做电影级视频

导语 【免费下载链接】WanVideo_comfy 项目地址: https://ai.gitcode.com/hf_mirrors/Kijai/WanVideo_comfy 阿里WanVideo_comfy开源项目通过多模型融合与量化技术,将专业级视频生成硬件门槛降至消费级GPU,重构AI内容创作生态。 行业现状&#…

作者头像 李华
网站建设 2026/1/6 16:35:20

Fiddly:3分钟将Readme.md转化为精美HTML页面的神奇工具

Fiddly:3分钟将Readme.md转化为精美HTML页面的神奇工具 【免费下载链接】fiddly Create beautiful and simple HTML pages from your Readme.md files 项目地址: https://gitcode.com/gh_mirrors/fi/fiddly 还在为你的开源项目文档不够美观而烦恼吗&#xff…

作者头像 李华
网站建设 2026/1/7 3:48:54

11、管理 OpenLDAP 与配置邮件服务器指南

管理 OpenLDAP 与配置邮件服务器指南 1. OpenLDAP 信息查询 在 SUSE Linux Enterprise Server 10 系统中,可使用 ldapsearch 命令从 LDAP 目录读取数据。以下是具体操作步骤: - 读取整个树 :使用 -x 选项,该选项会强制 ldapsearch 使用简单认证方法,适用于 LDA…

作者头像 李华