Introduction

A Generative Adversarial Network (GAN) is an unsupervised learning approach in which two neural networks learn by competing against each other; the method was proposed by Ian Goodfellow et al. in 2014. A GAN consists of a generator network and a discriminator network. The generator takes random samples from a latent space as input, and its output should imitate real samples from the training set as closely as possible. The discriminator takes either a real sample or the generator's output as input, and its goal is to distinguish the generator's output from real samples as reliably as possible, while the generator tries to fool the discriminator. The two networks compete and continually adjust their parameters. GANs are most often used to generate photo-realistic images, and have also been applied to generating video, 3-D object models, and more.
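The adversarial game described above can be written down concretely. The following is a rough numerical sketch, not the project's code: the two "networks" here are untrained random linear maps and all dimensions are illustrative. It computes the standard discriminator loss and the non-saturating generator loss for one real and one generated sample:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
# Toy stand-ins for the two networks: untrained random linear maps.
W_g = rng.normal(size=(10, 784))           # "generator": 10-d latent z -> 784-d sample
w_d = rng.normal(size=784) / np.sqrt(784)  # "discriminator": sample -> realness score

z = rng.normal(size=10)                    # draw z from the latent space
x_real = rng.normal(size=784)              # placeholder for a real training sample
x_fake = np.tanh(z @ W_g)                  # generator output G(z)

d_real = sigmoid(x_real @ w_d)             # D(x): belief that the real sample is real
d_fake = sigmoid(x_fake @ w_d)             # D(G(z)): belief that the fake is real

# Discriminator maximizes log D(x) + log(1 - D(G(z))); we minimize the negative:
d_loss = -(np.log(d_real) + np.log(1.0 - d_fake))
# Generator (non-saturating form) minimizes -log D(G(z)):
g_loss = -np.log(d_fake)
print(d_loss, g_loss)
```

In real training the two losses are minimized alternately with respect to the two networks' parameters, which is the back-and-forth "competition" the paragraph above describes.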

Download and installation commands

## Install command (CPU version)
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## Install command (GPU version)
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

Compared with other generative models, the original GAN's competitive setup no longer requires an assumed data distribution: there is no need to explicitly formulate p(x); instead, samples are drawn directly from a distribution, so in theory the model can approximate the real data arbitrarily well. This is GAN's greatest advantage. However, the drawback of not modeling the distribution in advance is that the approach is too unconstrained: for larger images with many pixels, a plain GAN becomes hard to control. A natural fix for this excessive freedom is to add constraints to the GAN, which leads to Conditional Generative Adversarial Nets (CGAN).

CGAN, the conditional generative adversarial network, is a GAN with a conditional constraint: a conditional variable y is introduced into both the generator (G) and the discriminator (D), using the extra information y to condition the model and guide the data-generation process. The condition y can be based on many kinds of information, such as class labels, partial data for image inpainting, or data from a different modality. When y is a class label, CGAN can be viewed as an improvement that turns the purely unsupervised GAN into a supervised model. This simple, direct modification has proven very effective and is widely used in follow-up work. The network structure is shown in the figure below:

[Figure: Conditional GAN network structure]
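The conditioning mechanism is easy to sketch. In the common CGAN setup the condition y is simply concatenated onto the inputs of both networks; the snippet below is a minimal illustration with made-up dimensions (not taken from c_gan.py) showing the shapes involved when y is a one-hot class label:

```python
import numpy as np

num_classes, z_dim, img_dim = 10, 100, 784  # illustrative sizes
rng = np.random.default_rng(0)

def one_hot(label, depth=num_classes):
    v = np.zeros(depth, dtype=np.float32)
    v[label] = 1.0
    return v

# Generator input: noise z concatenated with the condition y.
z = rng.normal(size=z_dim).astype(np.float32)
y = one_hot(3)                            # condition: "generate a digit 3"
g_input = np.concatenate([z, y])          # shape (z_dim + num_classes,)

# Discriminator input: an image concatenated with the same y, so D judges
# "is this a real sample *of class y*" rather than just "is this real".
x = rng.normal(size=img_dim).astype(np.float32)  # placeholder image
d_input = np.concatenate([x, y])          # shape (img_dim + num_classes,)

print(g_input.shape, d_input.shape)       # (110,) (794,)
```

Because both networks see the same y, the generator can only fool the discriminator by producing samples that match the requested class, which is what lets the trained model generate a specific digit on demand.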

The predictions of the conditionalGAN model after 19 training epochs are shown below:

[Figure: conditionalGAN predictions after 19 training epochs]

Before working through this project, it is recommended to read the original paper (https://arxiv.org/pdf/1411.1784.pdf) and this excellent explanatory blog post: https://blog.csdn.net/stalbo/article/details/79359380

 

This project uses the MNIST dataset.

In[1]
# Code structure
# ├── network.py   # Defines the basic generator and discriminator networks.
# ├── utility.py   # Common utility functions.
# ├── c_gan.py     # conditionalGAN training script.
# ├── infer.py     # Inference script.
# └── reader.py    # Data reader.
In[2]
# During training, one batch of data is used for testing every fixed number of epochs; the test results are saved as images to the path given by the --output option.
# Run `python c_gan.py --help` to see more usage options and detailed parameter descriptions.
# Train the CGAN on GPU; test results are saved as images to the path given by --output.
!python c_gan/c_gan.py --epoch=5 --output="./C_result" --use_gpu=True
2020-02-18 11:03:15,594-INFO: font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
2020-02-18 11:03:16,011-INFO: generated new fontManager
-----------  Configuration Arguments -----------
batch_size: 128
epoch: 5
output: ./C_result
run_ce: False
use_gpu: 1
------------------------------------------------
W0218 11:03:17.667068    93 device_context.cc:236] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W0218 11:03:17.672333    93 device_context.cc:244] device: 0, cuDNN Version: 7.3.
Epoch ID=0
 Batch ID=0
 D-Loss=1.4754687547683716
 DG-Loss=0.48697683215141296
 gen=[0.020384092, -0.628501, 0.8264506, -0.059517004, 0.08393577]
 Batch_time_cost=0.07
Epoch ID=0
 Batch ID=50
 D-Loss=1.3325045108795166
 DG-Loss=0.5968529582023621
 gen=[-0.32408723, -0.99999803, 0.9999368, -0.8411597, 0.089144364]
 Batch_time_cost=0.07
Epoch ID=0
 Batch ID=100
 D-Loss=1.3283110857009888
 DG-Loss=0.6360062956809998
 gen=[-0.48128778, -1.0, 0.9999977, -0.9618741, -0.15321578]
 Batch_time_cost=0.06
Epoch ID=0
 Batch ID=150
 D-Loss=1.3380171060562134
 DG-Loss=0.6329773664474487
 gen=[-0.50825375, -0.9999999, 0.99999666, -0.96668917, -0.24650867]
 Batch_time_cost=0.06
Epoch ID=0
 Batch ID=200
 D-Loss=1.3771660327911377
 DG-Loss=0.6287590861320496
 gen=[-0.48687252, -0.99999964, 0.99999994, -0.96308786, -0.18785018]
 Batch_time_cost=0.06
Epoch ID=0
 Batch ID=250
 D-Loss=1.3531957864761353
 DG-Loss=0.6571791768074036
 gen=[-0.39359128, -0.9999998, 1.0, -0.9617758, 0.14621706]
 Batch_time_cost=0.07
Epoch ID=0
 Batch ID=300
 D-Loss=1.363043189048767
 DG-Loss=0.6445130109786987
 gen=[-0.39853948, -0.99999535, 0.9999999, -0.9558876, 0.10362165]
 Batch_time_cost=0.07
Epoch ID=0
 Batch ID=350
 D-Loss=1.3479816913604736
 DG-Loss=0.6430233716964722
 gen=[-0.6183855, -1.0, 0.9999999, -0.98799354, -0.5378779]
 Batch_time_cost=0.06
Epoch ID=0
 Batch ID=400
 D-Loss=1.345306634902954
 DG-Loss=0.6640985012054443
 gen=[-0.39159825, -0.99999994, 1.0, -0.96443295, 0.1392249]
 Batch_time_cost=0.06
Epoch ID=0
 Batch ID=450
 D-Loss=1.3394019603729248
 DG-Loss=0.6445193886756897
 gen=[-0.4453725, -1.0, 1.0, -0.97341233, -0.037463646]
 Batch_time_cost=0.06
Epoch ID=1
 Batch ID=0
 D-Loss=1.3356835842132568
 DG-Loss=0.6343369483947754
 gen=[-0.73340905, -1.0, 0.9999999, -0.9945303, -0.78381747]
 Batch_time_cost=0.07
Epoch ID=1
 Batch ID=50
 D-Loss=1.3558433055877686
 DG-Loss=0.6355264186859131
 gen=[-0.5320427, -1.0, 1.0, -0.9876131, -0.31378862]
 Batch_time_cost=0.07
Epoch ID=1
 Batch ID=100
 D-Loss=1.3402173519134521
 DG-Loss=0.6352111101150513
 gen=[-0.55672026, -1.0, 1.0, -0.9907406, -0.41422915]
 Batch_time_cost=0.06
Epoch ID=1
 Batch ID=150
 D-Loss=1.3480899333953857
 DG-Loss=0.6375709772109985
 gen=[-0.5827232, -1.0, 1.0, -0.9946148, -0.47664627]
 Batch_time_cost=0.06
Epoch ID=1
 Batch ID=200
 D-Loss=1.3259913921356201
 DG-Loss=0.641555905342102
 gen=[-0.73367566, -1.0, 0.99999887, -0.99813175, -0.7840915]
 Batch_time_cost=0.06
Epoch ID=1
 Batch ID=250
 D-Loss=1.3230714797973633
 DG-Loss=0.6369626522064209
 gen=[-0.49091578, -1.0, 1.0, -0.98903924, -0.18111996]
 Batch_time_cost=0.07
Epoch ID=1
 Batch ID=300
 D-Loss=1.3248735666275024
 DG-Loss=0.627005398273468
 gen=[-0.48370016, -1.0, 1.0, -0.98776495, -0.15307435]
 Batch_time_cost=0.07
Epoch ID=1
 Batch ID=350
 D-Loss=1.3656104803085327
 DG-Loss=0.6305217742919922
 gen=[-0.7182529, -1.0, 1.0, -0.9988163, -0.80599004]
 Batch_time_cost=0.06
Epoch ID=1
 Batch ID=400
 D-Loss=1.3329648971557617
 DG-Loss=0.6695771217346191
 gen=[-0.4926114, -1.0, 1.0, -0.9927192, -0.18072756]
 Batch_time_cost=0.06
Epoch ID=1
 Batch ID=450
 D-Loss=1.3494174480438232
 DG-Loss=0.6599210500717163
 gen=[-0.5725556, -1.0, 1.0, -0.99533355, -0.46978894]
 Batch_time_cost=0.06
Epoch ID=2
 Batch ID=0
 D-Loss=1.3190510272979736
 DG-Loss=0.6534481048583984
 gen=[-0.6627138, -1.0, 1.0, -0.9986317, -0.75535697]
 Batch_time_cost=0.07
Epoch ID=2
 Batch ID=50
 D-Loss=1.3075931072235107
 DG-Loss=0.662345826625824
 gen=[-0.5798778, -1.0, 1.0, -0.9970621, -0.46988058]
 Batch_time_cost=0.07
Epoch ID=2
 Batch ID=100
 D-Loss=1.2966625690460205
 DG-Loss=0.6551341414451599
 gen=[-0.64662725, -1.0, 1.0, -0.9989106, -0.69297016]
 Batch_time_cost=0.06
Epoch ID=2
 Batch ID=150
 D-Loss=1.344698429107666
 DG-Loss=0.642096221446991
 gen=[-0.6564131, -1.0, 1.0, -0.9989565, -0.7414377]
 Batch_time_cost=0.06
Epoch ID=2
 Batch ID=200
 D-Loss=1.2718031406402588
 DG-Loss=0.6544758081436157
 gen=[-0.7044133, -1.0, 1.0, -0.99934673, -0.805151]
 Batch_time_cost=0.06
Epoch ID=2
 Batch ID=250
 D-Loss=1.2980257272720337
 DG-Loss=0.6246281862258911
 gen=[-0.80004555, -1.0, 1.0, -0.9998431, -0.9489559]
 Batch_time_cost=0.07
Epoch ID=2
 Batch ID=300
 D-Loss=1.3250139951705933
 DG-Loss=0.6461377143859863
 gen=[-0.5818067, -1.0, 1.0, -0.9984216, -0.5537344]
 Batch_time_cost=0.06
Epoch ID=2
 Batch ID=350
 D-Loss=1.3336296081542969
 DG-Loss=0.6458700895309448
 gen=[-0.6561196, -1.0, 1.0, -0.9994945, -0.8055736]
 Batch_time_cost=0.07
Epoch ID=2
 Batch ID=400
 D-Loss=1.332587480545044
 DG-Loss=0.6567844152450562
 gen=[-0.6696944, -1.0, 1.0, -0.99964476, -0.8398167]
 Batch_time_cost=0.06
Epoch ID=2
 Batch ID=450
 D-Loss=1.3682767152786255
 DG-Loss=0.649559736251831
 gen=[-0.7100771, -1.0, 1.0, -0.99959046, -0.85634416]
 Batch_time_cost=0.06
Epoch ID=3
 Batch ID=0
 D-Loss=1.3280253410339355
 DG-Loss=0.6489391326904297
 gen=[-0.7317555, -1.0, 1.0, -0.9997766, -0.89817333]
 Batch_time_cost=0.07
Epoch ID=3
 Batch ID=50
 D-Loss=1.3578753471374512
 DG-Loss=0.6147778034210205
 gen=[-0.72396195, -1.0, 1.0, -0.9998474, -0.90283734]
 Batch_time_cost=0.06
Epoch ID=3
 Batch ID=100
 D-Loss=1.3406777381896973
 DG-Loss=0.6485224962234497
 gen=[-0.718904, -1.0, 1.0, -0.99990094, -0.9153805]
 Batch_time_cost=0.06
Epoch ID=3
 Batch ID=150
 D-Loss=1.312596082687378
 DG-Loss=0.6498677730560303
 gen=[-0.67478544, -1.0, 1.0, -0.9998216, -0.8468633]
 Batch_time_cost=0.06
Epoch ID=3
 Batch ID=200
 D-Loss=1.3386080265045166
 DG-Loss=0.6432048082351685
 gen=[-0.6547149, -1.0, 1.0, -0.99988073, -0.83443314]
 Batch_time_cost=0.06
Epoch ID=3
 Batch ID=250
 D-Loss=1.326845645904541
 DG-Loss=0.6399492025375366
 gen=[-0.6931246, -1.0, 1.0, -0.99992764, -0.881379]
 Batch_time_cost=0.06
Epoch ID=3
 Batch ID=300
 D-Loss=1.3271305561065674
 DG-Loss=0.6585754156112671
 gen=[-0.6919609, -1.0, 1.0, -0.999896, -0.8784841]
 Batch_time_cost=0.07
Epoch ID=3
 Batch ID=350
 D-Loss=1.3486320972442627
 DG-Loss=0.6269784569740295
 gen=[-0.6996336, -1.0, 1.0, -0.99992526, -0.8959064]
 Batch_time_cost=0.06
Epoch ID=3
 Batch ID=400
 D-Loss=1.342246651649475
 DG-Loss=0.6233903169631958
 gen=[-0.77165926, -1.0, 1.0, -0.99998367, -0.9691971]
 Batch_time_cost=0.06
Epoch ID=3
 Batch ID=450
 D-Loss=1.3426618576049805
 DG-Loss=0.6478078365325928
 gen=[-0.6975712, -1.0, 1.0, -0.9999601, -0.9236312]
 Batch_time_cost=0.06
Epoch ID=4
 Batch ID=0
 D-Loss=1.3098371028900146
 DG-Loss=0.6660643815994263
 gen=[-0.7207169, -1.0, 1.0, -0.9999414, -0.92297864]
 Batch_time_cost=0.07
Epoch ID=4
 Batch ID=50
 D-Loss=1.3217291831970215
 DG-Loss=0.6328650712966919
 gen=[-0.71822643, -1.0, 1.0, -0.99995583, -0.9326522]
 Batch_time_cost=0.07
Epoch ID=4
 Batch ID=100
 D-Loss=1.3669278621673584
 DG-Loss=0.6450607776641846
 gen=[-0.6679931, -1.0, 1.0, -0.99994075, -0.8517525]
 Batch_time_cost=0.06
Epoch ID=4
 Batch ID=150
 D-Loss=1.333882451057434
 DG-Loss=0.6543253660202026
 gen=[-0.6812095, -1.0, 1.0, -0.99996954, -0.91442144]
 Batch_time_cost=0.06
Epoch ID=4
 Batch ID=200
 D-Loss=1.3355824947357178
 DG-Loss=0.6455910205841064
 gen=[-0.7346105, -1.0, 1.0, -0.9999781, -0.9356306]
 Batch_time_cost=0.06
Epoch ID=4
 Batch ID=250
 D-Loss=1.3421473503112793
 DG-Loss=0.6508429050445557
 gen=[-0.7158679, -1.0, 1.0, -0.9999834, -0.9500096]
 Batch_time_cost=0.07
Epoch ID=4
 Batch ID=300
 D-Loss=1.3312251567840576
 DG-Loss=0.6514350771903992
 gen=[-0.6991075, -1.0, 1.0, -0.99997765, -0.9242321]
 Batch_time_cost=0.07
Epoch ID=4
 Batch ID=350
 D-Loss=1.351043701171875
 DG-Loss=0.6400759220123291
 gen=[-0.7066042, -1.0, 1.0, -0.99998295, -0.9380742]
 Batch_time_cost=0.06
Epoch ID=4
 Batch ID=400
 D-Loss=1.3609528541564941
 DG-Loss=0.6274992823600769
 gen=[-0.7641589, -1.0, 1.0, -0.99999577, -0.97579837]
 Batch_time_cost=0.06
Epoch ID=4
 Batch ID=450
 D-Loss=1.3415582180023193
 DG-Loss=0.652595043182373
 gen=[-0.7490909, -1.0, 1.0, -0.9999947, -0.97528887]
 Batch_time_cost=0.06
In[3]
# Run inference with the frozen model (trained for 5 epochs); results are saved under output, and batch_size is the number of images to generate.
!python c_gan/infer.py --output="./infer_result" --use_gpu=True --batch_size=4

# Visualize the prediction results
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import cv2

img = cv2.imread('infer_result/generated_image.png')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
plt.imshow(img)
plt.show()
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py:795: UserWarning: The current program is empty.
  warnings.warn(error_info)
W0218 11:06:43.294374   178 device_context.cc:236] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W0218 11:06:43.299666   178 device_context.cc:244] device: 0, cuDNN Version: 7.3.
condition: [[5.]
 [0.]
 [4.]
 [1.]]
[Figure: generated digits for conditions 5, 0, 4, 1]

Click the link to try this project hands-on in AI Studio: https://aistudio.baidu.com/aistudio/projectdetail/169443
