3.4 多模态的Loss设计
category
type
status
slug
date
summary
tags
password
icon
如果轨迹预测不考虑多模态, 仅仅作为简单的回归问题会怎么样?
如图, 假设数据集里的车辆在这个路口, 左转和右转的数据刚好各50%. 那么不考虑多模态的轨迹预测网络, 训练后会输出右图的结果:直行.
这显然是不符合轨迹预测的要求, 我们需要去学习车辆左转, 右转, 停车等等各种行为的轨迹. 如何能够准确的让网络学习到多模态的特性, 并且不会导致每个模态轨迹受到其他数据的影响呢? Mixture Density Network 混合密度网络是一个很好的选择.
3.4.1 Mixture Density Network 混合密度网络
对于回归问题, 简单的方案是直接输出对应的结果, 比如预测直接输出预测点的坐标. 但是很多情况下很难准确预测agent行为, 更合理的方案是网络先输出轨迹的概率分布. 以高斯分布为例, 高斯分布是由均值和标准差决定的, 输出了分布的参数, 就可以代表这个分布.
但是现实世界是复杂的, 单一高斯分布很难表达大部分的不确定分布. 不过我们可以由多个高斯分布整合来拟合, 也就是混合高斯模型(GMM). 而这就是混合密度网络的目的, 利用神经网络输出多个分布的参数.
再将这些多个分布混合, 得到最终的混合分布. 下图是利用网络输出的两个分布拟合真实分布的例子.
3.4.2 多模态轨迹预测
3.4.2.1 Winner Takes All(WTA) 赢家通吃
在轨迹预测网络中, 我们可以设计个模态. 每个模态对应一种行为:左转, 右转等等, 每个模态输出混合分布的参数. 这样就确保了, 在训练时每个模态都尽量不受其他模态数据的干扰. 事实上, 为了使每个模态能够完全不受不必要的数据干扰, Winner Takes All(WTA)赢家通吃策略得到了广泛运用.
比如在训练中, 网络输出了个模态的轨迹. 在计算loss的时候, 我们会先挑选出离真值最接近的那一条模态轨迹. 如图, 左边的轨迹最接近真值绿色轨迹, 只计算这一条轨迹的loss, 其他的轨迹全部抛弃. 这也就是所谓的WTA, 这种策略非常适合线上挑战赛.
WTA的优点在于, 可以专注的去优化与真值接近的网络, 而避免了因为平均效应导致的预测失准.
3.4.2.2 模态轨迹概率
网络还需要输出每条模态轨迹的概率, 简单的方法可以由网络输出每条模态轨迹的评分, 之后经过softmax得到每条模态轨迹的概率. 这样确保了所有的模态轨迹的概率之和为1.
但是数据集无法提供轨迹的概率真值, 如何设计loss使网络能够学习呢? 可以将这个考虑成分类问题, 类似的采取WTA策略, 将最接近真值的那条轨迹的概率设置为1.0, 其他的设置为0, 用Softmax得到的概率与设置的概率进行交叉熵损失.
但是为了避免模型走向极端, 我们不一定会做这样的强分类, 可以将最接近的那一条概率设置为大值(比如0.9), 其他的设置为小值.
比如是网络输出的概率, 是我们设置的真值概率:
但是实际上, 即使是最接近真值的模态轨迹, 也不一定就是完全贴合的. 因此我们还可以根据不同模态轨迹与真值的平均偏差(比如ADE), 进行Softmax来得到相对真实的模态轨迹概率. 这样就把每条模态轨迹与真值的距离偏差映射到了概率上, 用这个映射后的概率作为真值计算交叉熵损失:
3.4.2.3 缺点
这种方式来处理多模态问题,也存在着一些缺点:
- 由于未明确定义模态,模态之间存在可交换性(exchangeability)
- 由于可交换性导致的模糊问题,仍然可能遭受模式崩溃(mode collapse)。
参考链接
上一篇
动手学控制理论
下一篇
端到端-理论与实战视频课程
Loading...