(CVPR '18) Social GAN:利用GAN来帮助预测行人运动轨迹

2018/10/06 GAN

这次要阐述的工作是来自于CVPR2018中李飞飞组关于预测行人运动轨迹的工作,通过结合GAN和Sequence Prediction来帮助提高预测效果。由于受限于笔者阅读面窄,对这方面的早期工作不甚了解,就全盘接受本文对早期工作介绍的观点了。

背景

预测行人轨迹的任务主要的问题有三个

  1. 人与人之间的相互影响 某特定行人轨迹是受其他行人的位置而影响的,在早期工作中,每个行人都得走一次LSTM,计算开销大,而且不能进行真正的全局考虑所有行人因素,capacity不够大。

  2. 社交上的可接受程度 行人轨迹预测需要考虑人的社交属性。举个例子,有一对情侣拉着手往前走,理论上来说可以从下面钻过去,但是这显然是不符合人对于私密性的要求。 量化这个指标可能需要一些不够优雅的繁琐处理。

  3. 多情况 行人的轨迹预测显然不止只有一种解,早期的工作大多基于优化欧氏距离之类的方法,只能给出一个“平均”好的路径选择。

本次工作将GAN的对抗思想引入行人运动轨迹预测的任务中。第一,引入variety loss来帮助GAN展开$Distribution_{G}$在空间的分布,来覆盖所有存在可能足够好的解的样本空间。第二,在G和D的中间层引入池化模块(Pooling Module) 来帮助学习到一个全局的池化结果,这个结果携带了场景中所有人的feature,从而实现整体只需要走一次LSTM,降低计算开销的结果。第三,基于GAN的生成效果,可以给出多个解。

整体方法

先简单定义一下符号,我们定义输入为,预测输出为,真正的结果为皆为一个list的时序坐标。表达形式为

图1 上图为整体生成器和判别器的架构,注意虽然LSTM在图中分开了,但是在同一个位置的LSTM是share weights的,也等于是相同的LSTM。

在Encoder的部分中,每个人的位置用多层全连接embedding作为LSTM的输入,得到一个定长向量,t是sequence,i是人,此处有

就是embedding function with ReLU,是这里的weight,就是encoder的weight,在全场景中share weight。e可以多步地feed进去。

为了解决人与人的互动关系,我们引入Pooling Module。在最后一步输入结束之后,我们使用一个池化块(的具体算法后面给出),来获得一个对应每一个人的的pooled tensor。

在Decode的过程中,我们先注意对它的初始化,不同于传统GAN的纯噪音input,我们有

为多层全连接带ReLU,是embedding的weight,为concatenate 上去的噪音。

初始化后的decode 预测过程中,有

为embedding function with ReLU,为embedding weight,为decoder的weights,为全连接层。

注意这里的PM操作是可以在不同步的时候调整的。

在Discriminator就是一个encoder,主要就是区分,,socially acceptable 的问题就放在这里让判别器自己学习。

在Loss上,在对抗 loss的基础上加一个L2 loss,来衡量预测结果和真实结果的距离。

Pool Module

池化块引入的目的是为了帮助统一考虑所有人。这里有两问题,一个是数量可变,一个是人与人交互的信息比较稀疏,但又不能不考虑,所以提出了以下的解决方法 20190509003330.png

这里提出的Pool操作和CVPR 2016 同为李飞飞组的基于grid的social pool方法不同。如上图,而是直接通过坐标算出红人对于绿人和蓝人的相对坐标,然后独立地过MLP,然后过Elementwise Pool(原文用Max Pool),就能得到红人对应的池化后张量$P_{i}$。

鼓励多样性

文章里给出了一个Variety Loss来鼓励生成的多样性,训练的时候,一次生成k个结果,同样每个结果都计算L2距离,取最小的作为Loss值。

代码

github/agrimgupta92/sgan ,基于PyTorch。(有pretrained model)

个人觉得本文有意思的地方是其引入的池化块,在文章背后有给出一系列的直观解释,有兴趣的话可以详细看看,然后也combine了不少其他领域的流行方法,最终能够比较高效地解决早期工作遗留的各种问题。

论文地址

Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks

Search

    Table of Contents