【Python】利用GAN生成动漫头像

2018-08-07 17:56:55 浏览数 (9545)

本文转载至知乎ID:Charles(白露未晞)知乎个人专栏

下载W3Cschool手机App,0基础随时随地学编程>>戳此了解

参考文献

用深层卷积生成对抗网络进行无监督表示学习;

GenerativeAdversarialNets。

甘论文汇总:

https://github.com/zhangqianhui/AdversarialNetsPapers


声明

本教程提供的所有源码以及素材仅供学习交流使用,禁止商用/非法使用。


导语

发现之前两篇关于GAN的文章效果都比较一般,想试着优化一下,训练个像样一些的模型,至少不能太丢GAN的脸,于是就有了这篇文章。

OK,让我们愉快地开始吧〜


相关文件

百度网盘下载链接:https//pan.baidu.com/s/1t-d5wq3TeBWcVzTOoraPtQ

密码:84ky


开发工具

蟒蛇中的版本:3.6.4

相关模块:

pytorch模块;

torchvision模块;

PIL模块;

以及一些Python中的中自带的模块。

PyTorch版本:

0.3.0


环境搭建

安装的Python的中并添加到环境变量,PIP安装需要的相关模块即可。

补充说明:

PyTorch0.3.0不支持直接的PIP安装(Windows)中中中。

有两个选择:

(1)安装anaconda3后在anaconda3的环境下安装(直接PIP安装即可);

(2)使用编译好的WHL文件安装,下载链接为:

https://pan.baidu.com/s/1dF6ayLr#list/path=%2Fpytorch

原理简介

关于生成对抗网络的核心思想,请参考之前的文章:

【Python中的中】利用甘生成MNIST数据集。

顺便补充一下甘训练目标的数学语言描述:

公式解释如下:

X:真实图片;

Z:输入ģ网络的噪声;

G(Z):G ^网络生成的图片;

d(X):真实图片是否真实的概率;

d(G(X)):G ^网络生成的图片是否真实的概率。

正如之前的文章所述,生成网络ģ的训练目标是尽可能生成真实的图片去欺骗判别网络d;而判别网络d的训练目标就是尽可能把生成网络ģ生成的图片和真实的图片区别开来,即训练过程是一个动态的“博弈过程”。

因此,公式中的ģ网络希望d(G(Z))尽可能得大; d网络希望d(x)的的的尽可能得大,d(G(X))尽可能得小故而公式的训练目标为

更多关于甘的原理介绍和应用可参考参考文献 ”部分的内容。

具体模型

【的的的Python】利用GAN神奇生成宝贝一文中使用的网络结构不同,本文使用了全卷积网络结构(即不再加入全连接层FC)。同时本文增加了训练数据量,使用了大约5万张动漫头像作为训练数据。

具体而言,生成器结构为:

判别器结构为:

具体实现详见相关文件中的源代码。


模型训练

一。训练数据集

使用了大约5万张动漫头像作为训练数据集,数据集源:

https://zhuanlan.zhihu.com/p/24767059。

二,模型训练

修改config.json文件中的训练数据集路径:

在CMD显示窗口显示运行显示train.py文件即可。

训练截图:

效果展示

Epoch0:

Epoch5:

Epoch10:

Epoch15:

Epoch20:

Epoch25:

Epoch29:

更多

代码截止2018年7月4日测试无误。

相关文件中提供了训练好的模型以及调用模型的简单脚本,直接在cmd窗口运行“ test.py ”文件即可生成动漫头像: