XDRush

深度学习在CTR预估中的应用之Attention FM

1 背景说明

Attention,也叫注意力模型,关于Attention,这篇文章中有详细的描述。Attention在计算机视觉、自然语言处理中都有典型的应用。因为Attention能够学习到不同特征对于目标的重要程度,目前的研究结果表明,Attention在绝大多数场景中都能发挥积极作用。

Attention FM(后文简称AFM)正是在原有FM基础上,考虑了不同二阶交叉特征对于目标的重要程度,实验结果表明,AFM性能明显要优于FM及FM的衍生版本。下面我们就来探索下AFM原理。

2 AFM模型结构及原理

2.1 AFM模型结构

FM模型结构及原理这里就不再讲解,不了解的参考这篇,AFM结构如下图所示:
AFM模型结构

上面这张模型图省去了FM模型中的线性部分。只考虑二阶交叉部分,这部分又分为两层:Pair-wise Interaction Layer和Attention-based Pooling Layer。

2.2 Pair-wise Interaction Layer

FM中使用特征对应隐向量的内积来表示两个特征之间的关联程度,受到这个启发,作者提出了一个新的Pair-wise Interaction层(以下简称PI层),它将$m$个交叉向量扩展到$m(m-1)/2$个交叉向量,每一个交叉向量都是由原始隐向量的element-wise product(元素积)来表示,Pair-wise Interaction层的输出结果为:
pair-wise interaction层的输出

其中$\boldsymbol{v}_i\bigodot\boldsymbol{v}_j$表示两个向量的元素积,不知道元素积的可自行搜索下,因为这一步表示非常关键。其中,

如果将PI层的结果直接输出做预测,那么预测得分表达式为:
PI层的预测得分

这里提醒一点,$\boldsymbol{v}_i\bigodot\boldsymbol{v}_j$的结果是一个向量,$x_ix_j$是一个数值,因此有$\boldsymbol{p}^T\in{R^k}$,$b\in{R}$,即$\boldsymbol{p}^T$是一个向量,$b$是一个实数值。

并且,并且,并且,重要的地方说三遍,当$\boldsymbol{p}^T=[1,1,…,1],b=0$是,PI层就退化为FM!这一点需要特别注意。

2.3 Attention-based Pooling Layer

接下来就是重点啦,Attention层!Attention的思想就是不同的部分对结果的贡献程度不一样!我们知道,FM中并没有考虑这种重要程度,受此启发,基于Attention的FM应运而生。

不妨再来看下网络中的Attention这一层,
Attention层

上图其实有一定的迷惑性,PI层和Attention-based Pooling之间的直线连接表示不经过Attention直接输出,也就是2.2节中的描述。PI层经过Attention Net,最终形成系数$a{ij}$,然后再求和输出,这才是Attention部分,也就是说,Attention部分目的就是求出不同交叉项的系数$a{ij}$(重要程度),因此,经过Attention的输出为:
Attention输出表达式

接下来的问题就是,$a{ij}$如何获取?一个最简单的思路就是,直接将$a{ij}$当做超参数,参与训练,通过迭代直接求取。但是这样会有一个问题:对于没有共现的特征,$a_{ij}$就无法获取了。

其实在这篇文章中,看图说话那部分,求Attention系数有异曲同工之妙,个人感觉作者是借鉴了paper中获取系数的方式。总之,求系数很简单,一层MLP就可以!也就是PI到Attention Net其实就是一层MLP:
Attention系数获取

这样做就避免了没有共现的特征无法求系数的问题。到了这一步,Attention部分的输出就很简单了:

以上就是AFM模型结构和原理,到这里我想大家基本上就都能明白啦!

2.4 AFM学习过程

AFM通过设置不同的目标函数,能够适用于多种任务:分类问题,回归问题等。对于推荐系统或者CTR预估,一般采用square loss:
损失函数初步

同样为了避免过拟合问题,作者在MLP那一层使用了$L_2$正则,还提到了这里没有用dropout是因为模型稳定性和性能问题,那么最终的目标函数为:
最终的目标函数

最终的性能如下:
性能对比
性能对比

可以看出,对比其他的模型,AFM性能在不同程度上都有一定提升!笔者好奇的是,作者在提出AFM之前不久,提出了NFM,这里并没有对NFM和AFM性能做对比,不知道是什么原因,哈哈!另外一点比较好奇的是,作者这两篇文章都没有对比AUC这个最重要的性能,都是之比较RMSE,不知道是基于什么原因,暂且保留这个疑问吧!

2.5 开源实现

作者还是非常好的,NFM和AFM都给出了开源实现,这个就非常良心啦,不用再造轮子啦!https://github.com/hexiangnan/attentional_factorization_machine

3 总结

个人认为,这篇文章的创新点有两个:(1)PI层的提出,有了PI层,其实还可以基于PI做些其他的事情;(2)首次将Attention思想应用在CTR预估上。不得不佩服作者的嗅觉和灵活应用能力!