之前在DCGAN文章简单解读里说明了DCGAN的原理。本次来实现一个DCGAN,并在数据集上实际测试它的效果。本次的代码来自github开源代码DCGAN-tensorflow,感谢carpedm20的贡献!
1. 代码结构
代码结构如下图1所示:
我们主要关注的文件为download.py,main.py,model.py,ops.py以及utils.py。其实看文件名字就大概可以猜出各个文件的作用了。
- download.py主要下载数据集到本地,这里我们需要下载三个数据集:MNIST,lsun以及celebA。
- main.py是主函数,用于配置命令行参数以及模型的训练和测试。
- model.py 是定义DCGAN模型的地方,也是我们要重点关注的代码。
- ops.py 定义了很多构造模型的重要函数,比如batch_norm(BN操作),conv2d(卷积操作),deconv2d(翻卷积操作)等。
utils.py 定义很多有用的全局辅助函数。
2. 代码简单解读
2.1 download.py
download.py代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187"""
Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py
Downloads the following:
- Celeb-A dataset
- LSUN dataset
- MNIST dataset
"""
from __future__ import print_function
import os
import sys
import gzip
import json
import shutil
import zipfile
import argparse
import requests
import subprocess
from tqdm import tqdm
from six.moves import urllib
parser = argparse.ArgumentParser(description='Download dataset for DCGAN.')
parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'],
help='name of dataset to download [celebA, lsun, mnist]')
def download(url, dirpath):
filename = url.split('/')[-1]
filepath = os.path.join(dirpath, filename)
u = urllib.request.urlopen(url)
f = open(filepath, 'wb')
filesize = int(u.headers["Content-Length"])
print("Downloading: %s Bytes: %s" % (filename, filesize))
downloaded = 0
block_sz = 8192
status_width = 70
while True:
buf = u.read(block_sz)
if not buf:
print('')
break
else:
print('', end='\r')
downloaded += len(buf)
f.write(buf)
status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize))
print(status, end='')
sys.stdout.flush()
f.close()
return filepath
def download_file_from_google_drive(id, destination):
URL = "https://docs.google.com/uc?export=download"
session = requests.Session()
response = session.get(URL, params={ 'id': id }, stream=True)
token = get_confirm_token(response)
if token:
params = { 'id' : id, 'confirm' : token }
response = session.get(URL, params=params, stream=True)
save_response_content(response, destination)
def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
def save_response_content(response, destination, chunk_size=32*1024):
total_size = int(response.headers.get('content-length', 0))
with open(destination, "wb") as f:
# 显示进度条
for chunk in tqdm(response.iter_content(chunk_size), total=total_size,
unit='B', unit_scale=True, desc=destination):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
def unzip(filepath):
print("Extracting: " + filepath)
dirpath = os.path.dirname(filepath)
with zipfile.ZipFile(filepath) as zf:
zf.extractall(dirpath)
os.remove(filepath)
def download_celeb_a(dirpath):
data_dir = 'celebA'
# ./data/celebA
if os.path.exists(os.path.join(dirpath, data_dir)):
print('Found Celeb-A - skip')
return
filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
# ./data/img_align_celeba.zip
save_path = os.path.join(dirpath, filename)
if os.path.exists(save_path):
print('[*] {} already exists'.format(save_path)) # 文件已经存在
else:
download_file_from_google_drive(drive_id, save_path)
zip_dir = ''
with zipfile.ZipFile(save_path) as zf:
zip_dir = zf.namelist()[0] # 解压以后默认文件夹的名字
zf.extractall(dirpath) # 提取文件到该文件夹
os.remove(save_path) # 移除压缩文件
# 重命名文件夹
os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))
def _list_categories(tag):
url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag
f = urllib.request.urlopen(url)
return json.loads(f.read())
def _download_lsun(out_dir, category, set_name, tag):
# locals(),Return a dictionary containing the current scope's local variables
url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \
'&category={category}&set={set_name}'.format(**locals())
print(url)
if set_name == 'test':
out_name = 'test_lmdb.zip'
else:
out_name = '{category}_{set_name}_lmdb.zip'.format(**locals())
# out_path:./data/lsun/xxx.zip
out_path = os.path.join(out_dir, out_name)
cmd = ['curl', url, '-o', out_path]
print('Downloading', category, set_name, 'set')
# 调用linux命令
subprocess.call(cmd)
def download_lsun(dirpath):
data_dir = os.path.join(dirpath, 'lsun')
if os.path.exists(data_dir):
print('Found LSUN - skip')
return
else:
os.mkdir(data_dir)
tag = 'latest'
#categories = _list_categories(tag)
categories = ['bedroom']
for category in categories:
_download_lsun(data_dir, category, 'train', tag)
_download_lsun(data_dir, category, 'val', tag)
_download_lsun(data_dir, '', 'test', tag)
def download_mnist(dirpath):
data_dir = os.path.join(dirpath, 'mnist')
if os.path.exists(data_dir):
print('Found MNIST - skip')
return
else:
os.mkdir(data_dir)
url_base = 'http://yann.lecun.com/exdb/mnist/'
file_names = ['train-images-idx3-ubyte.gz',
'train-labels-idx1-ubyte.gz',
't10k-images-idx3-ubyte.gz',
't10k-labels-idx1-ubyte.gz']
for file_name in file_names:
url = (url_base+file_name).format(**locals())
print(url)
out_path = os.path.join(data_dir,file_name)
cmd = ['curl', url, '-o', out_path]
print('Downloading ', file_name)
subprocess.call(cmd)
cmd = ['gzip', '-d', out_path]
print('Decompressing ', file_name)
subprocess.call(cmd)
def prepare_data_dir(path = './data'):
if not os.path.exists(path):
os.mkdir(path)
if __name__ == '__main__':
args = parser.parse_args()
prepare_data_dir()
# 如果datasets参数是 ['CelebA', 'celebA', 'celebA'] 其中之一
if any(name in args.datasets for name in ['CelebA', 'celebA', 'celebA']):
download_celeb_a('./data')
if 'lsun' in args.datasets:
download_lsun('./data')
if 'mnist' in args.datasets:
download_mnist('./data')首先需要导入的包中,gzip和zipfile用于文件压缩和解压缩相关;argparse用于构建命令行参数;requests用于http请求下载网络文件资源;subprocess用于运行shell命令;tqdm用于进度条显示;six包用于python2和python3的兼容,比如 from six.moves import urllib 这句就是导入python2.x的urllib库。
- 上面的代码除了原作者加的注释之外,我也已经加了一部分注释,意思应该比较好理解了。主要做的事情,就是利用requests库从网络上将mnist,lsun以及celebA这三个数据集下载下来,保存在data目录下。注意mnist和celebA数据集下载下来之后还进行了解压缩。
- 上面的三个数据集,mnist是著名的手写数字数据库,大家应该都已经很熟悉了;lsun是大型场景理解数据集(large-scale-scene-understanding);celebA是一个开源的人脸数据库。除了mnist之外,其余两个数据集体积都较大,celebA大概有20w+的图像,压缩文件体积为1.4G;而lsun有很多个场景不同的数据集,如果按照上面的脚本下载,下载的文件为bedroom数据集,压缩文件有46G之大,而且其实下载下来的文件解压后为mdb(Access数据库)格式,不是原始图片格式,不方便处理。所以我们实际会下载其他的数据集作为替代,比如这个room layout estimation(2G)数据。如果使用download.py脚本下载速度较慢的话,可以自行下载好数据集,然后放在data目录下即可。
2.2 main.py
main.py代码如下:
1 | import os |
- 这里需要注意的是 flags = tf.app.flags 用于tensorflow构建命令行参数, flags.DEFINE_xxx(param,default,description) 用于定义命令行参数及其取值,第一个参数param是具体参数值,第二个参数default是参数默认取值,第三个参数description是参数描述字符串。
- 在构建了sess之后,我们需要区分数据集是mnist还是其他数据集。因为mnist比较特殊,它有10个类别的数字图像,所以我们在构建DCGAN的时候需要额外多传递一个y_dim=10参数。 show_all_variables 函数用于显示model所有变量的具体信息。
- 接下来如果是训练状态( FLAGS.train == True ),则进行模型训练( dcgan.train(FLAGS) ;否则进行测试,即加载之前训练时候保存的checkpoint文件,然后调用 visualize 函数进行test(该函数可以生成image或者gif,可视化展示训练的效果)。
- tf.app.run() 是常用的tensorflow运行的起始命令。
2.3 model.py
model.py代码如下:
1 | from __future__ import division |
- from __future__ import division 这句话当python的版本为2.x时生效,可以让两个整数数字相除的结果返回一个浮点数(在python2中默认是整数,python3默认为浮点数)。glob可以以简单的正则表达式筛选的方式返回某个文件夹下符合要求的文件名列表。
- DCGAN的构造方法除了设置一大堆的属性之外,还要注意区分dataset是否是mnist,因为mnist是灰度图像,所以应该设置channel = 1( self.c_dim = 1 ),如果是彩色图像,则 self.c_dim = 3 or self.c_dim = 4 。然后就是build_model。
- self.generator 用于构造生成器; self.discriminator 用于构造鉴别器; self.sampler 用于随机采样(用于生成样本)。这里需要注意的是, self.y 只有当dataset是mnist的时候才不为None,不是mnist的情况下,只需要 self.z 即可生成samples。
- sigmoid_cross_entropy_with_logits 函数被重新定义了,是为了兼容不同版本的tensorflow。该函数首先使用sigmoid activation,然后计算cross-entropy loss。
- self.g_loss 是生成器损失; self.d_loss_real 是真实图片的鉴别器损失; self.d_loss_fake 是虚假图片(由生成器生成的fake images)的损失; self.d_loss 是总的鉴别器损失。
- 这里的 histogram_summary 和 scalar_summary 是为了在后续在tensorboard中对各个损失函数进行可视化。
- tf.trainable_variables() 可以获取model的全部可训练参数,由于我们在定义生成器和鉴别器变量的时候使用了不同的name,因此我们可以通过variable的name来获取得到self.d_vars(鉴别器相关变量),self.g_vars(生成器相关变量)。 self.saver = tf.train.Saver() 用于保存训练好的模型参数到checkpoint。
- train 函数是核心的训练函数。这里optimizer和DCGAN的原文保持一直,选用Adam优化函数, lr=0.0002 , beta1=0.5 。 merge_summary 函数和 SummaryWriter 用于构建summary,在tensorboard中显示。
- sample_z 是从[-1,1]的均匀分布产生的。如果dataset是mnist,则可以直接读取sample_inputs和sample_labels。否则需要手动逐个处理图像, get_image
返回的是取值为(-1,1)的,shape为(resize_height,resize_width)的ndarray。如果处理的图像是灰度图像,则需要再增加一个dim,表示图像的channel=1,对应的代码是 sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None] 。 - 接下来通过 self.sess.run([d_optim,… 和 self.sess.run([g_optim,…) 来更新鉴别器和生成器。 self.writer.add_summary(summary_str, counter) 增加summary到writer。由于同样的原因,这里仍然需要区分mnist和其他的数据集,所以计算最优化函数的过程需要一个if和一个else。
- np.mod(counter, config.print_every) == 1 表示每print_every次生成一次samples; np.mod(counter, config.checkpoint_every) == 2 表示每checkpoint_every次保存一下checkpoint file。
- 下面是discriminator(鉴别器)的具体实现。首先鉴别器使用conv(卷积)操作,激活函数使用leaky-relu,每一个layer需要使用batch normalization。tensorflow的batch normalization使用 tf.contrib.layers.batch_norm 实现。如果不是mnist,则第一层使用leaky-relu+conv2d,后面三层都使用conv2d+BN+leaky-relu,最后加上一个one hidden unit的linear layer,再送入sigmoid函数即可;如果是mnist,则 yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) 首先给y增加两维,以便可以和image连接起来,这里实际上是使用了conditional GAN(条件GAN)的思想。 x = conv_cond_concat(image, yb) 得到condition和image合并之后的结果,然后 h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name=’d_h0_conv’)) 进行卷积操作。第二次进行conv2d+leaky-relu+concat操作。第三次进行conv2d+BN+leaky-relu+reshape+concat操作。第四次进行linear+BN+leaky-relu+concat操作。最后同样是linear+sigmoid操作。
- 下面是generator(生成器)的具体实现。和discriminator不同的是,generator需要使用deconv(反卷积)以及relu 激活函数。generator的结构是:1.如果不是mnist:linear+reshape+BN+relu—->(deconv+BN+relu)x3 —->deconv+tanh;2.如果是mnist,则除了需要考虑输入z之外,还需要考虑label y,即需要将z和y连接起来(Conditional GAN),具体的结构是:reshape+concat—->linear+BN+relu+concat—->linear+BN+relu+reshape+concat—->deconv+BN+relu+concat—->deconv+sigmoid。注意的最后的激活函数没有采用通常的tanh,而是采用了sigmoid(其输出会直接映射到0-1之间)。
- sampler函数是采样函数,用于生成样本送入当前训练的生成器,查看训练效果。其逻辑和generator函数基本类似,也是需要区分是否是mnist,二者需要采用不同的结构。不是mnist时,y=None即可;否则mnist还需要考虑y。
- load_mnist 函数用于加载mnist数据集; save 函数用于保存checkpoint; load 函数用于加载checkpoint。
2.4 ops.py
ops.py代码如下:
1 | import math |
- 第9行到第20行的代码是为了保持tf0.x和tf1.x版本的兼容性。tf0.x版本使用tf.xxx_summary风格的函数,而tf1.x版本则使用tf.summary.xxx风格的函数。为了保持一致性,通过重命名统一成tf.xxx_summary风格了。
- 22行到27行重新定义了concat函数,也是为了兼容性考虑, if “concat_v2” in dir(tf): 这句话是说如果tf有concat_v2这个方法的话,tf0.x中使用concat_v2函数,而tf1.x版本中使用concat函数。
- 29行到44行定义了batch_norm类。需要注意的是37-44行定义了类的__call__特殊方法,这个方法的作用是可以将类像普通的函数那样直接调用,而不用先构造一个对象再调用方法,这是常用的一个技巧。tf中的batch normalization 是函数 tf.contrib.layers.batch_norm
- conv_cond_concat函数的作用是将conv(卷积)和cond(条件)concat起来。在mnist的generator和discriminator中会用到。
- 54行到65行的conv2d函数重新定义了卷积操作,主要是封装了 tf.nn.conv2d 函数。
- 68行到91行定义了deconv2d(反卷积)函数。tf0.x的反卷积函数为 tf.nn.deconv2d ,tf1.x的反卷积函数为 tf.nn.conv2d_transpose 。最后还加上了一个bias( tf.nn.bias_add )。
- 94到95行定义了leaky-relu函数lrelu。其实就一行代码: tf.maximum(x, leak*x) 。
- 97行到109行定义了linear函数,其实就是一个fully_connected layer。
2.5 utils.py
utils.py代码如下:
1 | """ |
utils.py定义了很多有用的全局工具函数,可以直接被其他的脚本调用。
- glob库用来list 某一个文件夹下的files;os库用来操作路径和文件夹等;pprint用于美观打印;gtime和strftime有用格式化日期;scipy.misc包含了很多和图像相关的有用的函数。
- 24-27行的show_all_variables函数,调用了 slim.model_analyzer.analyze_vars(vars,print_info) 函数来打印model所有variables的信息。
- 39-43行的imread函数封装了 scipy.misc.imread 函数,该函数参数 flatten = True 表示将color layer 展平成一个single gray-scale layer。
- 48-73行的merge函数用于从一系列小图产生大图,images[0]表示小图的个数,h=images[1]表示小图的高,w = images[2]表示小图的宽,x_h = size[0]表示最终大图height应该扩展的倍数,x_w = size[1]表示最终大图width应该扩展的倍数。该函数最终生成一个高为h*x_h,宽为w*x_w的大图。表示大图的高度方向包含x_h个小图,宽度方向包含x_w个小图。
- 75-85行定义了保存图像的imsave函数。注意 np.squeeze 可以去除数组中维度为1的那些维(降维),与之相反的操作是 np.expand_dims(arr,axis) 函数,可以给指定的axis维度增加一维。
- 87-104行的center_crop函数的作用是中心化剪切处理,同时对图像进行了resize操作。
- 106-126行的transform函数,也是对图像进行center_crop(可选)以及resize操作,只不过它最后将image array的每个元素的取值范围从(0,255)映射到(-1,1),(-1,1)是tanh函数的取值范围。
- 132-193行的to_json函数将各个layers结构保存到json文件,我们不用这个函数,就不细说了。
- 195-215行的make_gif函数可以将生成的序列图像转换为gif图像,这里使用moviepy库来完成这个工作,关于moviepy的介绍和使用,可以参考我之前的一篇文章。
- 217-298行的visualize用于测试阶段生成图像样本,可以是单个jpg格式的图像,也可以是gif图像,还可以是小图拼接成的大图。visualize函数通过option变量的取值(可以取0,1,2,3,4五个值)来控制以五种不同的方式保存结果。
- option=0:这种情况只适用于dataset 不等于mnist的情况,直接将samples merge成一个大图,然后保存即可,其中大图共有batch_size张小图,每行和每列各有ceil(sqrt(batch_size))个;
- option=1:这种情况和option=0类似,只是它考虑到了dataset为mnist的情况,如果是mnist,则会随机生成batch_size个digit labels,然后从generator生成相应的数字,最后拼接成一个大图,这里我自己定义了一个save_random_digits函数用于将每次随机生成的数字保存到txt文件中去,这样后续可以验证生成的数字图像是否是我们希望生成的;
- option=2:这种情况下,不会生成一张大图,而是生成含有batch_size帧的gif图,默认时间是2s,如果生成gif失败,则会生成和option=1一样的大图;
- option=3:不能是mnist数据集,生成和option=2一样的gif。
- option=4:合成一张大图的gif,一共有batch_size个大图,每个大图由z_dim(生成样本数目)个小图组成。
- 300-316行的save_random_digits函数是我自定义的函数,用于将随机数字保存到txt文件;
- 最后326-346行的resize_imgs函数是我自己添加的,作用就是将指定文件夹下的图像resize成指定的大小,这样我们就可以利用自己的数据集训练model了。
3. 代码运行结果(生成图像效果验证)
1. mnist
根据我们上面的解读,运行如下命令即可以使用mnist训练DCGAN:
python3 main.py --dataset=mnist --input_height=28 --output_height=28 --train True你需要确保main.py目录下的data/mnist文件夹下有已经解压缩的mnist数据文件。由于mnist数据规模不大,所以使用gpu训练大概只需要几十分钟。训练完成之后,训练过程中采样得到的生成图片保存在samples文件夹下,第一次采样和最后一次采样得到图片分别为下图1和图2所示:
python3 main.py --dataset=mnist --input_height=28 --output_height=28 --train False测试默认会生成100张合成的大图,我们随机抽取一张,比如第66张吧,其真实的随机数字排列和生成的手写数字如下图3和图4所示:
python3 main.py --dataset celebA --input_height=108 --crop --train True \ --epoch 2 --sample_dir ./celebA_samples --visualize True注意默认训练采样保存的文件夹是samples文件夹,由于我们已经把mnist的结果保存在那里了,如果继续使用这个文件夹,celebA的结果会把之前的文件覆盖掉。为了避免这样的情况,我们重新设定保存sample的文件夹为celebA_samples文件夹,这个文件夹会在运行过程中自动创建,不需要手动创建。由于celebA的数据集规模较大,我电脑的配置是:ubuntu 16.04,tensorflow1.4.1,cuda8+cudnn6,显卡是nvidia GTX950M,显存4G。在batch_size = 64的情况下,大概1.5s可以训练一个batch,因此如果按照默认配置epoch=25,一个epoch的batch_num = ceil(202602/64)=3166,因此全部训练完大约需要的时间为1.5\*3166\*25/3600 ≈33h。由于我没有台式机,自己的笔记本不太可能一直训练这么长时间;机房的电脑配置太渣,train不动。所以我只能随便train一下了。我甚至一轮都没有训练完就停下来了。第1个epoch第100个batch生成的图像如下图5所示:
python3 main.py --dataset celebA --input_height=108 --crop --train False \ --checkpoint_dir ./checkpoint --sample_dir ./celebA_samples当然你仍然可以通过设定option的值来控制test的输出。下面的图7和图8是生成的gif图(图8由于体积太大已经转为jpg格式),由于训练非常不充分,因此效果不佳,但是仍然有脸部的轮廓:
python3 main.py --dataset beauty_girls --input_height=108 --crop --train True \ --epoch 500 --sample_dir ./beauty_girls_samples --visualize True \ --print_every 10 --checkpoint_every 240
这一次因为图片数量只有2000,所以我设定要训练500轮,我在晚上睡觉的时候用笔记本跑了一下,这下却翻车了,训练采样得到的图片是这样的:
可以发现从第1轮到第300轮生成图片的质量是提高的,但是再往后训练,特别是到了最后500轮的时候,图像明显花了,很多小图都是相似的看不懂的模式(也就是论文里说的mode collapse),这说明最多训练到300轮左右模型就已经差不多收敛了,再往后效果可能会更差,也许会发生mode collapse这种现象。这一点和论文最后提到的是一致的。而且可以发现即使是最好的生成图片,质量也不是特别好,这可能主要是与训练样本数太少(只有2000)而且图像风格差异太大引起的。最后,不要问我要原始训练图片,是拿什么图片训练的,你看生成图片难道猜不到么?哈哈哈。
#### 5. girl_face
#### 这个数据集来自知乎网友Best July的文章:用DCGAN生成女朋友,有兴趣大家可以看看这篇文章。该数据集包含了剪切好的8000多张妹子的头像,大小都是96x96的。差不多是下面这种:
数据集大家可以去faces下载,密码:09h9。运行下面的命令即可以开始训练:
python3 main.py –dataset girl_face –input_height=96 –crop –train True \
–epoch 200 –sample_dir ./girl_face –visualize True \
–print_every 30 –checkpoint_every 300
你需要确保将包含图片数据的girl_face文件夹放在data目录下,我们设定训练200轮,全部训练完成估计要5,6个小时。下图11(从上至下)是分别训练1轮,30轮,70轮,100轮,130轮以及170轮时候产生的图像,可以发现随着训练轮数的增加,生成图像的质量是逐渐增加的,大概到100轮左右的时候,其实生成的头像质量已经很不错了(可以发现是美女了),后续个别位置的小图质量有所增加,但是始终有一些小图有一些畸变,不是特别自然。但是总体上来说,生成的图片质量很不错了。
训练完成之后,我们使用训练得到的model进行test,但是其实有一个问题我们之前没有提到,那就是如果训练轮数设定的过多,那么最新的一个checkpoint加载得到的model未必是最优的,最优的可能在中间的某一个epoch。但是原代码只能加载最新的一个checkpoint,所以我们将model.py中的 load 函数修改如下:
1 | # load checkpoints file |
主要的修改就是增加了一个checkpoint_name参数,用于指定特定的而不是最新的checkpoint file。同时我们增加了一个checkpoint_name命令行参数: flags.DEFINE_string(“checkpoint_name”,None,”the name of the loaded checkpoint file,default is the lastest checkpoint”) 用来指定checkpoint_name参数,默认值是None。
另外还有一个问题就是,在train的时候sample的样本,输入噪声z是服从(-1,1)的均匀分布,而原代码的visualize函数在option=1,2,3,4的时候,sample不是通过(-1,1)的均匀分布采样得到的,经过我的实验,如果在option=1,2,3,4的时候直接用原代码进行test,得到的生成图片几乎都是模糊的。我猜想这是因为test和train的时候的输入采样分布不一致导致的结果。因此我也对utils.py的visualize函数进行了修改如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95def visualize(sess, dcgan, config, option):
# 用于可视化
image_frame_dim = int(math.ceil(config.batch_size**.5)) # 图片尺寸
if option == -1:
# noise
z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
save_images(samples, [image_frame_dim, image_frame_dim],
'./%s/test_%s.png' % (config.sample_dir, strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
elif option == 0:
# noise
z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
save_images(samples, [image_frame_dim, image_frame_dim], './%s/test_%s.png' % (config.sample_dir,strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
elif option == 1: # 将samples生成大图
#values = np.arange(0, 1, 1./config.batch_size)
for idx in xrange(dcgan.z_dim):
print(" [*] %d" % idx)
z_sample = np.random.uniform(-1, 1, size=(config.batch_size , dcgan.z_dim))
# for kdx, z in enumerate(z_sample):
# z[idx] = values[kdx]
if config.dataset == "mnist":
# y是batch_size个0-9之间的随机数
y = np.random.choice(10, config.batch_size)
save_random_digits(y,image_frame_dim,image_frame_dim,'./%s/test_arange_%s.txt' % (config.sample_dir,idx))
y_one_hot = np.zeros((config.batch_size, 10))
y_one_hot[np.arange(config.batch_size), y] = 1
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
else:
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
save_images(samples, [image_frame_dim, image_frame_dim], './%s/test_arange_%s.png' % (config.sample_dir,idx))
elif option == 2:
# values = np.arange(0, 1, 1./config.batch_size)
# idx是随机的
# for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]:
for idx in xrange(dcgan.z_dim):
print(" [*] %d" % idx)
# z_dim:test_images_num
#z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
# np.tile:按照指定的维度将array重复
# z_sample shape:(batch_size,z_dim)
#z_sample = np.tile(z, (config.batch_size, 1))
#z_sample = np.zeros([config.batch_size, dcgan.z_dim])
# for kdx, z in enumerate(z_sample):
# z[idx] = values[kdx]
z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
if config.dataset == "mnist":
y = np.random.choice(10, config.batch_size)
#save_random_digits(y, image_frame_dim, image_frame_dim, './%s/test_%s.txt' % % (config.sample_dir,strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
y_one_hot = np.zeros((config.batch_size, 10))
y_one_hot[np.arange(config.batch_size), y] = 1
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
else:
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
try:
make_gif(samples, './%s/test_gif_%s.gif' % (config.sample_dir,idx),4)
except:
save_images(samples, [image_frame_dim, image_frame_dim], './%s/test_%s.png' % (config.sample_dir,strftime("%Y-%m-%d-%H-%M-%S", gmtime())))
elif option == 3: # 不能是mnist,直接生成gif
# values = np.arange(0, 1, 1./config.batch_size)
for idx in xrange(dcgan.z_dim):
print(" [*] %d" % idx)
# z_sample = np.zeros([config.batch_size, dcgan.z_dim])
# for kdx, z in enumerate(z_sample):
# z[idx] = values[kdx]
z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
make_gif(samples, './%s/test_gif_%s.gif' % (config.sample_dir,idx),4)
elif option == 4:
image_set = []
# values = np.arange(0, 1, 1./config.batch_size)
for idx in xrange(dcgan.z_dim):
print(" [*] %d" % idx)
# z_sample = np.zeros([config.batch_size, dcgan.z_dim])
# for kdx, z in enumerate(z_sample): z[idx] = values[kdx]
z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
make_gif(image_set[-1], './%s/test_gif_%s.gif' % (config.sample_dir,idx),12)
# 合成一张大图gif(64张大图)
new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
for idx in range(63, -1, -1)] # 63-0
make_gif(new_image_set, './%s/test_gif_merged.gif' % config.sample_dir, duration=8)
elif option == 5:
#保存单个的小图
z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim))
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
for i,sample in enumerate(samples):
scipy.misc.imsave("./%s/single_test_%s.png" %(config.sample_dir,i),sample)
主要的修改是将所有的采样方式都改为(-1,1)的均匀分布: z_sample = np.random.uniform(-1, 1, size=(config.batch_size, dcgan.z_dim)) 。实验发现,这种方式在test的时候是非常有效的。另外,我保留了option=0的情况不变,增加了option=-1的情况以及option=5的情况。option=5表示将生成的图片按小图保存。下面的几张图展示了test的结果: