【Python】利用GAN生成MNIST数据集

2018-08-07 17:50:19 浏览数 (7238)

本文转载至知乎ID:Charles(白露未晞)知乎个人专栏
下载W3Cschool手机App,0基础随时随地学编程>>戳此了解
导语

利用Python搭建简单的GAN网络来生成MNIST数据集。其中GAN,即生成对抗网络

英文全称:

Generative Adversarial Networks

偷闲入门了一波心心念念的GAN,过来发一波文。

毕竟自从2014年被Ian Goodfellow提出后便一直是深度学习中的热门方向且十分有趣。

Let's Go~~~


相关文件

百度网盘下载链接: https://pan.baidu.com/s/1h6haOWnQojZU67igrNveeA

密码: 51pr


开发工具

相关模块:

tensorflow-gpu模块;

numpy模块;

matplotlib模块;

以及一些Python自带的模块。

其中TensorFlow-GPU版本为:

1.7.0


环境搭建

安装Python并添加到环境变量,pip安装需要的相关模块即可。

其中,TensorFlow-GPU的环境搭建请自行参考相关的网络教程,注意版本和驱动严格对应即可。


原理简介

生成对抗网络(GAN)的基本思想源自博弈论中的二人零和博弈,由一个生成器和一个判别器组成,通过对抗学习的方式来训练。

具体而言:

GAN的生成器主要用来学习真实图像的特征分布从而让自身生成的图像更加真实,以骗过判别器。判别器则需要对输入的图片进行真假判别。

整个过程串起来就是,生成器努力地让生成的图像更加真实,从而让判别器认为自己生成的图像是真实的图像,而判别器则努力地去识别出图像的真假,让生成器无法骗过自己。

随着训练的进行,生成器和判别器也在不断地斗智斗勇,最后的结果当然就是:

生成器生成的图像接近于真实图像,而判别器对于生成器生成的图像判别正确的概率接近于0.5。

其过程也可以总结为下图:

就生成MNIST的GAN网络具体模型而言:

生成器结构为:

判别器结构为:

至于具体的实现细节详见相关文件中源代码。

代码里注释的还算详细吧T_T


结果展示

训练方式:

在cmd窗口运行GanMnist.py文件即可。

训练过程中的Loss走势:

生成器生成的图像演变(show.py文件):

测试方式:

在cmd窗口运行Test.py文件即可。

利用训练好的模型生成MNIST数据集结果:

更多

T_T好吧这个例子也许并不能体现出GAN多有趣。

那么立个Flag,下次有空利用GAN网络来干一些有趣的事情,非常非常有趣的事情~~~