manquyuan00 2019-11-17
GAN,叫做生成对抗网络 (Generative Adversarial Network) 。其基本原理是生成器网络 G(Generator) 和判别器网络 D(Discriminator) 相互博弈。生成器网络 G 的主要作用是生成图片,在输入一个随机编码 (random code) z后,自动的生成假样本 G(z) 。判别器网络 D 的主要作用是判断输入是否为真实样本并提供反馈机制,真样本则输出 1 ,反之为 0 。在两个网络相互博弈的过程中,两个网络的能力都越来越高:G 生成的图片越来越像真样本,D 也越来越会判断图片的真假,然后我们在最大化 D 的前提下,最小化 D 对 G 的判断能力,这实际上就是最小最大值问题,或者说二人零和博弈,其目标函数表达式:
其中表达式中的第一项 D(G(z)) 处理的是假图像 G(z) ,我们尽量降低评分 D(G(z)) ;第二项处理的是真图像 x ,此时评分要高。但是 GAN 并不是完美的,也有自己的局限性。比如说没有用户控制的能力和低分辨率与低质量的问题。
为了提高 GAN 的用户控制能力,人类进行了一些列的探索研究。比如 Pix2Pix 模型采用有条件的使用用户输入,使用成对的数据 (paired data) 进行训练; CycleGAN 模型使用不成对的数据 (unpaired data) 的就能训练 。但无论是 Pix2Pix 还是 CycleGAN ,都是解决了从一个领域到另一个领域的图像转换问题。当有很多领域需要转换时,对于每一个领域转换,都需要重新训练一个模型去解决。目前,存在的模型处理多领域图像生成任务时,学习 k 个领域之间所有映射就必须训练 k * (k-1) 个生成器。如果训练一对一的图像多领域生成任务时,主要会导致两个问题:
上图中 (a) 模型说明如何训练 12 个不同生成器网络以达到 4 个不同领域图像之间转换任务。很明显每个生成器不能够充分利用整个训练数据,只能从 4 个领域中 2 个领域相互学习,这样就会生成图片质量不好。而上图(b)中的模型就可以解决这些问题,该模型接受多个领域训练数据,并仅使用一个生成器来学习多领域图像之间映射关系。根据模型的长相将该模型称为星形网络,外文名就是 StarGAN 。
上图是根据 StarGAN 模型训练出的效果。在同一种模型下,可以做多领域图像之间的转换,比如更换头发颜色、更换表情、更换年龄等。
上图是对 StarGAN 的简单介绍,主要包含判别器 D 和生成器 G 。
(a)D 对真假图片进行判别,真图片判真,假图片判假,真图片被分类到相应域。
(b)G 接受真图片和目标域标签并生成假图片;
(c)G 在给定原始域标签的情况下将假图片重建为原始图片(重构损失);
(d)G 尽可能生成与真实图像无法区分的图像,并且通过 D 分类到目标域。
首先描述 StarGAN 网络,在一个数据集中进行多领域的图像转换任务;然后我们讨论了如何使 StarGAN 能合并包含不同标签的数据集以及对其中任意的标签属性灵活进行图像转换。
训练一个生成器 G ,能够多领域映射。将带有领域标签 c 的输入图像 x 转换为输出图像 y,即 。随机生成目标领域标签 c 使得 G 能够灵活的转换输入图像,同时使用 D 控制多领域。这样 D 就在图像源和域标签上产生概率分布,即。
使用对抗损失函数提高生成图像质量,达到 D 无法区分出来输出图像和生成图像之间的差别:
根据输入图像 x 和目标领域标签 c ,由 G 生成输出图像,同时 D 区分出真实图像和生成图像。将
作为输入图像 x 经过 D 之后得到的可能性分布。生成器 G 使这个式子尽可能的小,而 D 则尽可能使其最大化。
对于一个输入图像 x 和目标分布标签 c ,我们的目标是将 x 转换为输出图像 y后能够被正确分类为目标分布 c 。为了实现这一目标,我们在 D 之上添加一个辅助分类器,并在优化 G 和 D 时采用目标域分类损失函数。简单来说,我们将这个式子分解为两部分:一个真实图像的分布分类损失用于约束 D ,一个假图像的分布分类损失用于约束 G 。其表达式如下所示:
其中,代表 D 计算出来的领域标签的可能性分布。一方面,通过将这个式子最小化, D 将真实图像 x 正确分类到与其相关分布 c' 。另一方面,假图像的分类分布的损失函数定义如下:
即 G 使这个式子最小化,使得生成的图像能够被 D 判别为目标领域 c。
通过最小化对抗损失和分类损失, G 训练生成的图像尽可能与真实图像一样,并且能够被分类到正确的目标领域。然而,最小化这两个损失函数不能保证 , 转换后的图像中,只改变领域差异的部分, 而保留输入图像中的其他内容 。故对 G 使用循环一致性损失函数 (cycle consistency loss) ,如下:
其中: G 以生成图像 G(x,c) 以及原始输入图像领域标签 c' 为输入,努力重构出原始图像 x 。我们选择L范数作为重构损失函数。注意到我们两次使用了同一个生成器,第一次将原始图像转换到目标领域的图像,然后将生成的图像重构回原始图像。
最终 G 和 D 的损失函数表示如下:
其中 _ 和 _ 是控制分类误差和重构误差相对于对抗误差的相对权重的超参数。在所有实验中,我们设置。
为了 GAN 训练过程稳定,生成高质量的图像,论文中采用自定义梯度惩罚来代替对抗误差损失:
其中: 表示真实和生成图像之间均匀采样的直线,试验时。
starGAN 的一个重要优势在于它能够同时合并包含不同标签的不同数据集,使得其在测试阶段能够控制所有的标签。从多个数据集学习的问题在于标签信息对每一个数据集而言只是部分已知。在 CelebA 和 RaFD 的例子中,前一个数据集包含诸如发色,性别等信息,但它不包含任何后一个数据集中包含的诸如开心生气等表情标签。这会引起问题,因为在将 G(x,c) 重构回输入图像 x 时需要完整的标签信息 c' 。
为了缓解这一问题,我们引入了向量掩码 m,使 StarGAN 模型能够忽略不确定的标签,专注于特定数据集提供的明确的已知标签。在 StarGAN 中我们使用 n 维的 one-hot 向量来代表 m ,n 表示数据集的数量。除此之外,我们将标签的同一版本定义为一个数组:
其中:[·]表示串联,其中 c表示第 i 个数据集的标签,已知标签 c 的向量能用二值标签表示二值属性或者用 one-hot 的形式表示多类属性。对于剩下的 n-1 个未 i 知标签我们简单的置为 0 。
利用多数据集训练 StarGAN 时,我们使用上面定义的 作为生成器的输入。如此,生成器学会忽略非特定的标签,而专注于指定的标签。除了输入标签 ,此处的生成器与单数据集训练的生成器网络结构一样。另一方面我们也扩展判别器的辅助分类器的分类类别到到所属聚集的所有标签。最后,我们将我们的模型按照多任务学习的方式进行训练,其中,判别器只将已知标签相关的分类误差最小化即可。
以 celebA 数据为例,下载后的数据包括 label 文件和图像。
(1, '5_o_Clock_Shadow'), (2, 'Arched_Eyebrows'), (3, 'Attractive'), (4, 'Bags_Under_Eyes'), (5, 'Bald'), (6, 'Bangs'), (7, 'Big_Lips'), (8, 'Big_Nose'), (9, 'Black_Hair'), (10, 'Blond_Hair'), (11, 'Blurry'), (12, 'Brown_Hair'), (13, 'Bushy_Eyebrows'), (14, 'Chubby'), (15, 'Double_Chin'), (16, 'Eyeglasses'), (17, 'Goatee'), (18, 'Gray_Hair'), (19, 'Heavy_Makeup'), (20, 'High_Cheekbones'), (21, 'Male'), (22, 'Mouth_Slightly_Open'), (23, 'Mustache'), (24, 'Narrow_Eyes'), (25, 'No_Beard'), (26, 'Oval_Face'), (27, 'Pale_Skin'), (28, 'Pointy_Nose'), (29, 'Receding_Hairline'), (30, 'Rosy_Cheeks'), (31, 'Sideburns'), (32, 'Smiling'), (33, 'Straight_Hair'), (34, 'Wavy_Hair'), (35, 'Wearing_Earrings'), (36, 'Wearing_Hat'), (37, 'Wearing_Lipstick'), (38, 'Wearing_Necklace'), (39, 'Wearing_Necktie'), (40, 'Young')
000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
通过本文学习,您应该初步了解 StarGAN 模型的网络结构和实现原理,以及关键部分代码的初步实现。如果您对深度学习 Tensorflow 比较了解,可以参考 Tensorflow版实现starGAN;如果您对pytorch框架比较熟悉,可以参考 pytorch实现starGAN;如果您想更深入的学习了解starGAN原理,可以参考 论文。
如果想体验项目效果,您可以登陆 Mo 平台,在 应用中心 中找到 StarGAN,可以体验以下五种特征['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'] 的风格变换。考虑到代码较长,我们在StarGAN 项目源码中对相关代码做了详细解释。您在学习的过程中,遇到困难或者发现我们的错误,可以随时联系我们。
1.论文:https://arxiv.org/pdf/1711.09020.pdf
2.博客:https://blog.csdn.net/stdcoutzyx/article/details/78829232
3.博客:https://www.cnblogs.com/Thinker-pcw/p/9785379.html
4.pytorch原版github地址:https://github.com/yunjey/StarGAN
5.tensorflow版github地址:https://github.com/taki0112/StarGAN-Tensorflow
6.Celeba数据集:https://www.dropbox.com/s/d1kjpkqklf0uw77/celeba.zip?dl=0
Mo(网址:momodel.cn)是一个支持 Python 的人工智能在线建模平台,能帮助你快速开发、训练并部署模型。
Mo 人工智能俱乐部 是由网站的研发与产品设计团队发起、致力于降低人工智能开发与使用门槛的俱乐部。团队具备大数据处理分析、可视化与数据建模经验,已承担多领域智能项目,具备从底层到前端的全线设计开发能力。主要研究方向为大数据管理分析与人工智能技术,并以此来促进数据驱动的科学研究。
目前俱乐部每周六在杭州举办以机器学习为主题的线下技术沙龙活动,不定期进行论文分享与学术交流。希望能汇聚来自各行各业对人工智能感兴趣的朋友,不断交流共同成长,推动人工智能民主化、应用普及化。