Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

尹相楠作者里昂中央理工博士在读学校

Self-Attention GAN 中的 self-attention 机制

Self Attention GAN 用到了很多新的技术。最大的亮点当然是 self-attention 机制,该机制是Non-local Neural Networks [1] 这篇文章提出的。其作用是能够更好地学习到全局特征之间的依赖关系。因为传统的 GAN 模型很容易学习到纹理特征:如皮毛,天空,草地等,不容易学习到特定的结构和几何特征,例如狗有四条腿,既不能多也不能少。 

除此之外,文章还用到了 Spectral Normalization for GANs [2] 提出的谱归一化。谱归一化的解释见本人这篇文章:详解GAN的谱归一化(Spectral Normalization)

但是,该文代码中的谱归一化和原始的谱归一化运用方式略有差别: 

1. 原始的谱归一化基于 W-GAN 的理论,只用在 Discriminator 中,用以约束 Discriminator 函数为 1-Lipschitz 连续。而在 Self-Attention GAN 中,Spectral Normalization 同时出现在了 Discriminator 和 Generator 中,用于使梯度更稳定。除了生成器和判别器的最后一层外,每个卷积/反卷积单元都会上一个 SpectralNorm。 

2. 当把谱归一化用在 Generator 上时,同时还保留了 BatchNorm。Discriminator 上则没有 BatchNorm,只有 SpectralNorm。 

3. 谱归一化用在 Discriminator 上时最后一层不加 Spectral Norm。 

最后,self-attention GAN 还用到了 cGANs With Projection Discriminator 提出的conditional normalizationprojection in the discriminator。这两个技术我还没有来得及看,而且 PyTorch 版本的 self-attention GAN 代码中也没有实现,就先不管它们了。

本文主要说的是 self-attention 这部分内容。

 图1. Self-Attention

Self-Attention

卷积神经网络中,每个卷积核的尺寸都是很有限的(基本上不会大于 5),因此每次卷积操作只能覆盖像素点周围很小一块邻域。

对于距离较远的特征,例如狗有四条腿这类特征,就不容易捕获到了(也不是完全捕获不到,因为多层的卷积、池化操作会把 feature map 的高和宽变得越来越小,越靠后的层,其卷积核覆盖的区域映射回原图对应的面积越大。但总而言之,毕竟还得需要经过多层映射,不够直接)。

Self-Attention 通过直接计算图像中任意两个像素点之间的关系,一步到位地获取图像的全局几何特征。 

论文中的公式不够直观,我们直接看文章的 PyTorch 的代码,核心部分为 sagan_models.py:

class Self_Attn(nn.Module):
    """ Self attention Layer"""
    def __init__(self,in_dim,activation):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1) #
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # BX (N) X (N) 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N

        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)

        out = self.gamma*out + x
        return out,attention

构造函数中定义了三个 1 × 1 的卷积核,分别被命名为 query_conv , key_conv 和value_conv 。

为啥命名为这三个名字呢?这和作者给它们赋予的含义有关。query 意为查询,我们希望输入一个像素点,查询(计算)到 feature map 上所有像素点对这一点的影响。而 key 代表字典中的键,相当于所查询数据库。query 和 key 都是输入的 feature map,可以看成把 feature map 复制了两份,一份作为 query 一份作为 key。 

需要用一个什么样的函数,才能针对 query 的 feature map 中的某一个位置,计算出 key 的 feature map 中所有位置对它的影响呢?作者认为这个函数应该是可以通过“学习”得到的。那么,自然而然就想到要对这两个 feature map 分别做卷积核为 1 × 1 的卷积了,因为卷积核的权重是可以学习得到的。 

至于 value_conv ,可以看成对原 feature map 多加了一层卷积映射,这样可以学习到的参数就更多了,否则 query_conv 和 key_conv 的参数太少,按代码中只有 in_dims × in_dims//8个。 

接下来逐行研究 forward 函数:

proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1)

这行代码先对输入的 feature map 卷积了一次,相当于对 query feature map 做了一次投影,所以叫做 proj_query。由于是 1 × 1 的卷积,所以不改变 feature map 的长和宽。feature map 的每个通道为如 (1) 所示的矩阵,矩阵共有 N 个元素(像素)。

然后重新改变了输出的维度,变成:

 (m_batchsize,-1,width*height) 

batch size 保持不变,width 和 height 融合到一起,把如 (1) 所示二维的 feature map 每个 channel 拉成一个长度为 N 的向量。

因此,如果 m_batchsize 取 1,即单独观察一个样本,该操作的结果是得到一个矩阵,矩阵的的行数为 query_conv 卷积输出的 channel 的数目 C( in_dim//8 ),列数为 feature map 像素数 N。

然后作者又通过 .permute(0, 2, 1) 转置了矩阵,矩阵的行数变成了 feature map 的像素数 N,列数变成了通道数 C。因此矩阵维度为 N × C 。该矩阵每行代表一个像素位置上所有通道的值,每列代表某个通道中所有的像素值。

 图2. proj_query 的维度

proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height)

这行代码和上一行类似,只不过取消了转置操作。得到的矩阵行数为通道数 C,列数为像素数 N,即矩阵维度为 C × N。该矩阵每行代表一个通道中所有的像素值,每列代表一个像素位置上所有通道的值。

 图3. proj_key的维度

energy =  torch.bmm(proj_query,proj_key)

这行代码中, torch.bmm 的意思是 batch matrix multiplication。就是说把相同 batch size 的两组 matrix 一一对应地做矩阵乘法,最后得到同样 batchsize 的新矩阵。

若 batch size=1,就是普通的矩阵乘法。已知 proj_query 维度是 N × C, proj_key 的维度是 C × N,因此 energy 的维度是 N × N:

 图4. energy的维度

energy 是 attention 的核心,其中第 i 行 j 列的元素,是由 proj_query 第 i 行,和 proj_key 第 j 列通过向量点乘得到的。而 proj_query 第 i 行表示的是 feature map 上第 i 个像素位置上所有通道的值,也就是第 i 个像素位置的所有信息,而 proj_key 第 j 列表示的是 feature map 上第 j 个像素位置上的所有通道值,也就是第 j 个像素位置的所有信息。

这俩相乘,可以看成是第 j 个像素对第 i 个像素的影响。即,energy 中第 i 行 j 列的元素值,表示第 j 个像素点对第 i 个像素点的影响。

attention = self.softmax(energy)

这里 sofmax 是构造函数中定义的,为按“行”归一化。这个操作之后的矩阵,各行元素之和为 1。这也比较好理解,因为 energy 中第 i 行元素,代表 feature map 中所有位置的像素对第 i 个像素的影响,而这个影响被解释为权重,故加起来应该是 1,故应对其按行归一化。attention 的维度也是 N × N。

proj_value = self.value_conv(x).view(m_batchsize,-1,width*height)

上面的代码中,先对原 feature map 作一次卷积映射,然后把得到的新 feature map 改变形状,维度变为 C × N ,其中 C 为通道数(注意和上面计算 proj_query   proj_key 的 C 不同,上面的 C 为 feature map 通道数的 1/8,这里的 C 与 feature map 通道数相同),N 为 feature map 的像素数。

 图5. proj_value的维度

out = torch.bmm(proj_value,attention.permute(0,2,1) )
out = out.view(m_batchsize,C,width,height)

然后,再把 proj_value (C × N)矩阵同  attention 矩阵的转置(N × N)相乘,得到 out(C × N)。之所以转置,是因为 attention 中每行的和为 1,其意义是权重,需要转置后变为每列的和为 1,施加于 proj_value 的行上,作为该行的加权平均。 proj_value 第 i 行代表第 i 个通道所有的像素值, attention 第 j 列,代表所有像素施加到第 j 个像素的影响。

因此, out 中第 i 行包含了输出的第 i 个通道中的所有像素,第 j 列表示所有像素中的第 j 个像素,合起来也就是: out 中的第 i 行第 j 列的元素,表示被 attention 加权之后的 feature map 的第 i 个通道的第 j 个像素的像素值。再改变一下形状, out 就恢复了 channel×width×height 的结构。

 图6. out的维度

out = self.gamma*out + x

最后一行代码,借鉴了残差神经网络(residual neural networks)的操作, gamma 是一个参数,表示整体施加了 attention 之后的 feature map 的权重,需要通过反向传播更新。而 x 就是输入的 feature map。

在初始阶段, gamma 为 0,该 attention 模块直接返回输入的 feature map,之后随着学习,该 attention 模块逐渐学习到了将 attention 加权过的 feature map 加在原始的 feature map 上,从而强调了需要施加注意力的部分 feature map。

总结

可以把 self attention 看成是 feature map 和它自身的转置相乘,让任意两个位置的像素直接发生关系,这样就可以学习到任意两个像素之间的依赖关系,从而得到全局特征了。看论文时会被它复杂的符号迷惑,但是一看代码就发现其实是很 naive 的操作。

参考文献

[1] Xiaolong Wang, Ross Girshick, Abhinav Gupta, Kaiming He, Non-local Neural Networks, CVPR 2018.

[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida, Spectral Normalization for Generative Adversarial Networks, ICLR 2018.

PaperWeekly
PaperWeekly

推荐、解读、讨论和报道人工智能前沿论文成果的学术平台。

理论attention 模型注意力机制GAN卷积神经网络
121
相关数据
来也科技机构

来也科技是中国乃至全球的智能自动化领军品牌,为客户提供变革性的智能自动化解决方案,提升组织生产力和办公效率,释放员工潜力,助力政企实现智能时代的人机协同。 来也科技的产品是一套智能自动化平台,包含机器人流程自动化(RPA)、智能文档处理(IDP)、对话式AI(Conversational AI)等。基于这一平台,能够根据客户需要,构造各种不同类型的数字化劳动力,实现业务流程的自动化,全面提升业务效率。

www.laiye.com/
池化技术

池化(Pooling)是卷积神经网络中的一个重要的概念,它实际上是一种形式的降采样。有多种不同形式的非线性池化函数,而其中“最大池化(Max pooling)”是最为常见的。它是将输入的图像划分为若干个矩形区域,对每个子区域输出最大值。直觉上,这种机制能够有效的原因在于,在发现一个特征之后,它的精确位置远不及它和其他特征的相对位置的关系重要。池化层会不断地减小数据的空间大小,因此参数的数量和计算量也会下降,这在一定程度上也控制了过拟合。通常来说,CNN的卷积层之间都会周期性地插入池化层。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

数据库技术

数据库,简而言之可视为电子化的文件柜——存储电子文件的处所,用户可以对文件中的数据运行新增、截取、更新、删除等操作。 所谓“数据库”系以一定方式储存在一起、能予多个用户共享、具有尽可能小的冗余度、与应用程序彼此独立的数据集合。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

卷积神经网络技术

卷积神经网路(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。卷积神经网路由一个或多个卷积层和顶端的全连通层(对应经典的神经网路)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网路能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网路在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网路,卷积神经网路需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。 卷积网络是一种专门用于处理具有已知的、网格状拓扑的数据的神经网络。例如时间序列数据,它可以被认为是以一定时间间隔采样的一维网格,又如图像数据,其可以被认为是二维像素网格。

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

查询技术

一般来说,查询是询问的一种形式。它在不同的学科里涵义有所不同。在信息检索领域,查询指的是数据库和信息系统对信息检索的精确要求

测试