
# 通过PyTorch实现对抗自编码器

"Most of human and animal learning is unsupervised learning. If intelligence were a cake, unsupervised learning would be the cake itself, supervised learning would be the icing on the cake, and reinforcement learning would be the cherry on top. We know how to make the icing and the cherry, but we don't know how to make the cake."

```python
# Encoder
class Q_net(nn.Module):
    def __init__(self):
        super(Q_net, self).__init__()
        self.lin1 = nn.Linear(X_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3gauss = nn.Linear(N, z_dim)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.25, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.25, training=self.training)
        x = F.relu(x)
        xgauss = self.lin3gauss(x)
        return xgauss
```

```python
# Decoder
class P_net(nn.Module):
    def __init__(self):
        super(P_net, self).__init__()
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, X_dim)

    def forward(self, x):
        x = self.lin1(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.lin3(x)
        return F.sigmoid(x)
```

```python
# Discriminator
class D_net_gauss(nn.Module):
    def __init__(self):
        super(D_net_gauss, self).__init__()
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, 1)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.2, training=self.training)
        x = F.relu(x)
        return F.sigmoid(self.lin3(x))
```

```python
torch.manual_seed(10)
Q, P = Q_net(), P_net()     # Encoder/Decoder
D_gauss = D_net_gauss()     # Adversarial discriminator
if torch.cuda.is_available():
    Q = Q.cuda()
    P = P.cuda()
    D_gauss = D_gauss.cuda()

# Set learning rates
gen_lr, reg_lr = 0.0006, 0.0008

# Set optimizers
P_decoder = optim.Adam(P.parameters(), lr=gen_lr)
Q_encoder = optim.Adam(Q.parameters(), lr=gen_lr)
Q_generator = optim.Adam(Q.parameters(), lr=reg_lr)
D_gauss_solver = optim.Adam(D_gauss.parameters(), lr=reg_lr)
```

1) Run the forward path through the encoder/decoder, compute the reconstruction loss, and update the parameters of the encoder Q and decoder P networks.

```python
z_sample = Q(X)
X_sample = P(z_sample)
recon_loss = F.binary_cross_entropy(X_sample + TINY,
                                    X.resize(train_batch_size, X_dim) + TINY)
recon_loss.backward()
P_decoder.step()
Q_encoder.step()
```

2) Create the latent representation z = Q(x) and draw samples z' from the prior p(z), then run each sample through the discriminator and compute the score it assigns to each (D(z) and D(z')).

```python
Q.eval()    # Switch off dropout while generating codes for the discriminator
z_real_gauss = Variable(torch.randn(train_batch_size, z_dim) * 5.)   # Sample from N(0, 5)
if torch.cuda.is_available():
    z_real_gauss = z_real_gauss.cuda()
z_fake_gauss = Q(X)
```

3) Compute the discriminator's loss and backpropagate it through the discriminator network to update its weights. In code:

```python
# Compute discriminator outputs and loss
D_real_gauss, D_fake_gauss = D_gauss(z_real_gauss), D_gauss(z_fake_gauss)
D_loss_gauss = -torch.mean(torch.log(D_real_gauss + TINY)
                           + torch.log(1 - D_fake_gauss + TINY))
D_loss_gauss.backward()   # Backpropagate loss
D_gauss_solver.step()     # Apply optimization step
```

4) Compute the generator network's loss and update the Q network accordingly.

```python
# Generator
Q.train()   # Back to using dropout
z_fake_gauss = Q(X)
D_fake_gauss = D_gauss(z_fake_gauss)

G_loss = -torch.mean(torch.log(D_fake_gauss + TINY))
G_loss.backward()
Q_generator.step()
```
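Putting steps 1)–4) together, one training iteration can be sketched as the following self-contained example. It is written against the modern PyTorch API (no `Variable`), with a random batch standing in for MNIST data and illustrative layer sizes; the `nn.Sequential` networks are simplified stand-ins for the `Q_net`/`P_net`/`D_net_gauss` classes above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
X_dim, z_dim, N, batch, TINY = 784, 2, 64, 32, 1e-15

# Simplified encoder, decoder, and discriminator
Q = nn.Sequential(nn.Linear(X_dim, N), nn.ReLU(), nn.Linear(N, z_dim))
P = nn.Sequential(nn.Linear(z_dim, N), nn.ReLU(), nn.Linear(N, X_dim), nn.Sigmoid())
D = nn.Sequential(nn.Linear(z_dim, N), nn.ReLU(), nn.Linear(N, 1), nn.Sigmoid())

P_decoder = optim.Adam(P.parameters(), lr=6e-4)
Q_encoder = optim.Adam(Q.parameters(), lr=6e-4)
Q_generator = optim.Adam(Q.parameters(), lr=8e-4)
D_solver = optim.Adam(D.parameters(), lr=8e-4)

X = torch.rand(batch, X_dim)  # stand-in for a flattened MNIST batch in [0, 1]

# 1) Reconstruction phase: update encoder and decoder
P_decoder.zero_grad(); Q_encoder.zero_grad()
recon_loss = F.binary_cross_entropy(P(Q(X)), X)
recon_loss.backward()
P_decoder.step(); Q_encoder.step()

# 2)-3) Regularization phase: update the discriminator
D_solver.zero_grad()
z_real = torch.randn(batch, z_dim) * 5.   # samples from the prior N(0, 5)
z_fake = Q(X).detach()                    # encoder output, held fixed here
D_loss = -torch.mean(torch.log(D(z_real) + TINY)
                     + torch.log(1 - D(z_fake) + TINY))
D_loss.backward()
D_solver.step()

# 4) Regularization phase: update the encoder as "generator"
Q_generator.zero_grad()
G_loss = -torch.mean(torch.log(D(Q(X)) + TINY))
G_loss.backward()
Q_generator.step()
```

Note the `detach()` in the discriminator phase: the encoder's output is treated as a constant there, so only the discriminator's weights receive gradients, mirroring the separate `Q_generator` / `D_gauss_solver` optimizers above.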

AAEs learn disentangled representations
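One way an AAE can disentangle style from content (following the supervised variant in the Adversarial Autoencoders paper) is to feed the class label to the decoder as a one-hot vector alongside z, so that z is free to model only style. The sketch below is illustrative, not the paper's exact architecture; the class `P_net_supervised` and all layer sizes are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

X_dim, z_dim, n_classes, N = 784, 2, 10, 64

class P_net_supervised(nn.Module):
    """Decoder that receives a one-hot class label in addition to the
    latent code z; the label carries content, so z only models style."""
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(z_dim + n_classes, N)
        self.lin2 = nn.Linear(N, X_dim)

    def forward(self, z, y_onehot):
        h = F.relu(self.lin1(torch.cat([z, y_onehot], dim=1)))
        return torch.sigmoid(self.lin2(h))

# Usage: reconstruct a batch from latent codes plus labels
z = torch.randn(8, z_dim)
y = F.one_hot(torch.randint(0, n_classes, (8,)), n_classes).float()
x_rec = P_net_supervised()(z, y)
print(x_rec.shape)  # torch.Size([8, 784])
```

At test time, holding z fixed while varying the label generates the same "style" of digit across all ten classes, which is the disentanglement effect described above.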

• What is a variational autoencoder (https://jaan.io/what-is-variational-autoencoder-vae-tutorial) (Tutorial)

• Auto-encoding Variational Bayes (https://arxiv.org/abs/1312.6114) (original paper)

• Adversarial Autoencoders (https://arxiv.org/abs/1511.05644) (original paper)

• Building Machines that Imagine and Reason: Principles and Applications of Deep Generative Models (http://videolectures.net/deeplearning2016_mohamed_generative_models/) (Video Lecture)