苏剑林作者NLP、神经网络研究方向

变分自编码器系列:VAE + BN = 更好的VAE

本文我们继续之前的变分自编码器系列,分析一下如何防止 NLP 中的 VAE 模型出现“KL 散度消失(KL Vanishing)”现象。本文受到参考文献是 ACL 2020 的论文 A Batch Normalized Inference Network Keeps the KL Vanishing Away [1] 的启发,并自行做了进一步的完善。

值得一提的是,本文最后得到的方案还是颇为简洁的——只需往编码输出加入BN(Batch Normalization),然后加个简单的 scale——但确实很有效,因此值得正在研究相关问题的读者一试。同时,相关结论也适用于一般的 VAE 模型(包括 CV 的),如果按照笔者的看法,它甚至可以作为 VAE 模型的“标配”。

最后,要提醒读者这算是一篇 VAE 的进阶论文,所以请读者对 VAE 有一定了解后再来阅读本文。

VAE简单回顾

这里我们简单回顾一下 VAE 模型,并且讨论一下 VAE 在 NLP 中所遇到的困难。关于 VAE 的更详细介绍,请读者参考笔者的旧作变分自编码器 VAE:原来是这么一回事再谈变分自编码器 VAE:从贝叶斯观点出发等

1.1 VAE的训练流程

VAE 的训练流程大概可以图示为:

▲ VAE训练流程图示

写成公式就是:

其中第一项就是重构项, 是通过重参数来实现;第二项则称为 KL 散度项,这是它跟普通自编码器的显式差别,如果没有这一项,那么基本上退化为常规的 AE。更详细的符号含义可以参考再谈变分自编码器 VAE:从贝叶斯观点出发

1.2 NLP中的VAE

在 NLP 中,句子被编码为离散的整数 ID,所以 q(x|z) 是一个离散型分布,可以用万能的“条件语言模型”来实现,因此理论上 q(x|z) 可以精确地拟合生成分布,问题就出在 q(x|z) 太强了,训练时重参数操作会来噪声,噪声一大,z 的利用就变得困难起来,所以它干脆不要 z 了,退化为无条件语言模型(依然很强), 则随之下降到 0,这就出现了 KL 散度消失现象

这种情况下的 VAE 模型并没有什么价值:KL 散度为 0 说明编码器输出的是 0 向量,而解码器则是一个普通的语言模型。而我们使用 VAE 通常来说是看中了它无监督构建编码向量的能力,所以要应用 VAE 的话还是得解决 KL 散度消失问题。

事实上从 2016 开始,有不少工作在做这个问题,相应地也提出了很多方案,比如退火策略、更换先验分布等,读者 Google 一下“KL Vanishing”就可以找到很多文献了,这里不一一溯源。

1.3 BN的巧与秒

本文的方案则是直接针对 KL 散度项入手,简单有效而且没什么超参数。其思想很简单:
KL 散度消失不就是 KL 散度项变成 0 吗?我调整一下编码器输出,让 KL 散度有一个大于零的下界,这样它不就肯定不会消失了吗?
这个简单的思想的直接结果就是:在  后面加入 BN 层,如图:
▲ 往VAE里加入BN

1.4 推导过程简述

为什么会跟 BN 联系起来呢?我们来看 KL 散度项的形式:

上式是采样了 b 个样本进行计算的结果,而编码向量的维度则是 d 维。由于我们总是有 ,所以 ,因此:

留意到括号里边的量,其实它就是  在 batch 内的二阶矩,如果我们往  加入 BN 层,那么大体上可以保证  的均值为 ,方差为 ( 是 BN 里边的可训练参数),这时候:

所以只要控制好 (主要是固定  为某个常数),就可以让 KL 散度项有个正的下界,因此就不会出现 KL 散度消失现象了。这样一来,KL 散度消失现象跟 BN 就被巧妙地联系起来了,通过 BN 来“杜绝”了 KL 散度消失的可能性。

1.5 为什么不是LN?

善于推导的读者可能会想到,按照上述思路,如果只是为了让 KL 散度项有个正的下界,其实 LN(Layer Normalization)也可以,也就是在式(3)中按 j 那一维归一化。

那为什么用BN而不是LN呢?

这个问题的答案也是 BN 的巧妙之处。直观来理解,KL 散度消失是因为  的噪声比较大,解码器无法很好地辨别出 z 中的非噪声成分,所以干脆弃之不用。
而当给  加上 BN 后,相当于适当地拉开了不同样本的 z 的距离,使得哪怕 z 带了噪声,区分起来也容易一些,所以这时候解码器乐意用 z 的信息,因此能缓解这个问题;相比之下,LN 是在样本内进的行归一化,没有拉开样本间差距的作用,所以 LN 的效果不会有 BN 那么好。

进一步的结果

事实上,原论文的推导到上面基本上就结束了,剩下的都是实验部分,包括通过实验来确定 的值。然而,笔者认为目前为止的结论还有一些美中不足的地方,比如没有提供关于加入 BN 的更深刻理解,倒更像是一个工程的技巧,又比如只是  加上了 BN, 没有加上,未免有些不对称之感。

经过笔者的推导,发现上面的结论可以进一步完善。

2.1 联系到先验分布

对于 VAE 来说,它希望训练好后的模型的隐变量分布为先验分布 ,而后验分布则是 ,所以 VAE 希望下式成立:

两边乘以 z,并对 z 积分,得到:

两边乘以 ,并对 z 积分,得到:

如果往  都加入 BN,那么我们就有:

所以现在我们知道  一定是 0,而如果我们也固定 ,那么我们就有约束关系:

2.2 参考的实现方案

经过这样的推导,我们发现可以往  都加入 BN,并且可以固定 ,但此时需要满足约束(9)。
要注意的是,这部分讨论还仅仅是对 VAE 的一般分析,并没有涉及到 KL 散度消失问题,哪怕这些条件都满足了,也无法保证 KL 项不趋于 0。结合式(4)我们可以知道,保证 KL 散度不消失的关键是确保 ,所以,笔者提出的最终策略是:

其中  是一个常数,笔者在自己的实验中取了 ,而  是可训练参数,上式利用了恒等式
关键代码参考(Keras):
class Scaler(Layer):
    """特殊的scale层
    """
    def __init__(self, tau=0.5, **kwargs):
        super(Scaler, self).__init__(**kwargs)
        self.tau = tau

    def build(self, input_shape):
        super(Scaler, self).build(input_shape)
        self.scale = self.add_weight(
            name='scale', shape=(input_shape[-1],), initializer='zeros'
        )

    def call(self, inputs, mode='positive'):
        if mode == 'positive':
            scale = self.tau + (1 - self.tau) * K.sigmoid(self.scale)
        else:
            scale = (1 - self.tau) * K.sigmoid(-self.scale)
        return inputs * K.sqrt(scale)

    def get_config(self):
        config = {'tau': self.tau}
        base_config = super(Scaler, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def sampling(inputs):
    """重参数采样
    """
    z_mean, z_std = inputs
    noise = K.random_normal(shape=K.shape(z_mean))
    return z_mean + z_std * noise


e_outputs  # 假设e_outputs是编码器的输出向量
scaler = Scaler()
z_mean = Dense(hidden_dims)(e_outputs)
z_mean = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_mean)
z_mean = scaler(z_mean, mode='positive')
z_std = Dense(hidden_dims)(e_outputs)
z_std = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_std)
z_std = scaler(z_std, mode='negative')
z = Lambda(sampling, name='Sampling')([z_mean, z_std])

文章内容小结

本文简单分析了 VAE 在 NLP 中的 KL 散度消失现象,并介绍了通过 BN 层来防止 KL 散度消失、稳定训练流程的方法。这是一种简洁有效的方案,不单单是原论文,笔者私下也做了简单的实验,结果确实也表明了它的有效性,值得各位读者试用。因为其推导具有一般性,所以甚至任意场景(比如 CV)中的 VAE 模型都可以尝试一下。


参考链接


[1] https://arxiv.org/abs/2004.12585

PaperWeekly
PaperWeekly

推荐、解读、讨论和报道人工智能前沿论文成果的学术平台。

理论BNVAE变分自编码器
1
相关数据
变分自编码器技术

变分自编码器可用于对先验数据分布进行建模。从名字上就可以看出,它包括两部分:编码器和解码器。编码器将数据分布的高级特征映射到数据的低级表征,低级表征叫作本征向量(latent vector)。解码器吸收数据的低级表征,然后输出同样数据的高级表征。变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成。在自动编码器中,需要输入一张图片,然后将一张图片编码之后得到一个隐含向量,这比原始方法的随机取一个随机噪声更好,因为这包含着原图片的信息,然后隐含向量解码得到与原图片对应的照片。但是这样其实并不能任意生成图片,因为没有办法自己去构造隐藏向量,所以它需要通过一张图片输入编码才知道得到的隐含向量是什么,这时就可以通过变分自动编码器来解决这个问题。解决办法就是在编码过程给它增加一些限制,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这就是其与一般的自动编码器最大的不同。这样生成一张新图片就比较容易,只需要给它一个标准正态分布的随机隐含向量,这样通过解码器就能够生成想要的图片,而不需要给它一张原始图片先编码。

超参数技术

在机器学习中,超参数是在学习过程开始之前设置其值的参数。 相反,其他参数的值是通过训练得出的。 不同的模型训练算法需要不同的超参数,一些简单的算法(如普通最小二乘回归)不需要。 给定这些超参数,训练算法从数据中学习参数。相同种类的机器学习模型可能需要不同的超参数来适应不同的数据模式,并且必须对其进行调整以便模型能够最优地解决机器学习问题。 在实际应用中一般需要对超参数进行优化,以找到一个超参数元组(tuple),由这些超参数元组形成一个最优化模型,该模型可以将在给定的独立数据上预定义的损失函数最小化。

隐变量技术

在统计学中,隐变量或潜变量指的是不可观测的随机变量。隐变量可以通过使用数学模型依据观测得的数据被推断出来。

语言模型技术

语言模型经常使用在许多自然语言处理方面的应用,如语音识别,机器翻译,词性标注,句法分析和资讯检索。由于字词与句子都是任意组合的长度,因此在训练过的语言模型中会出现未曾出现的字串(资料稀疏的问题),也使得在语料库中估算字串的机率变得很困难,这也是要使用近似的平滑n元语法(N-gram)模型之原因。

卡路里科技机构

Keep 致力于提供健身教学、跑步、骑行、交友、健康饮食指导及装备购买等一站式运动解决方案,持续打造「自由运动场」来帮助人们随时随地尽享运动。 Keep APP提供丰富的运动课程、社区交友、产品功能;Keepland线下城市运动空间,轻便的小团课精品课程使城市人群可以随时随地享受运动的乐趣;KeepKit智能硬件是部署「家庭」场景,硬件产品KeepKit连接运动与家庭场景,以内容为核心的智能运动产品平台,重塑家庭运动体验;KeepUp 是 Keep 的运动服饰品牌。年轻、酷感和运动是 KeepUp 一脉相承的品牌特点。

https://www.gotokeep.com/
推荐文章
暂无评论
暂无评论~