25
25
# Some important consts
num_examples = 669
batch_size = 30
n_epochs = 1000
save_steps = 500  # Number of training batches between checkpoint saves

checkpoint_dir = "./ckpt/"
model_name = "ConvAutoEnc.model"
logs_dir = "./logs/run1/"

# Fetch input data (faces/trees/imgs)
# NOTE(review): the space before the trailing slash is reproduced verbatim
# from the source; it looks unintended -- confirm the real directory name.
data_dir = "./data/celebG /"
data_path = os.path.join(data_dir, '*.jpg')
data = glob(data_path)

# Fail fast: nothing below can run without at least one input image.
if not data:
    raise Exception("[!] No data found in '" + data_path + "'")
42
42
43
+
44
+ '''
45
+ Some util functions from https://github.com/carpedm20/DCGAN-tensorflow
46
+ '''
47
+
43
48
def path_to_img(path, grayscale=False):
    """Read the image file at *path* and return it as a float array.

    Args:
        path: Filesystem path passed straight to ``scipy.misc.imread``.
        grayscale: When true, the file is read with ``flatten=True``
            (documented by SciPy as collapsing the colour layers into a
            single gray-scale layer).

    Returns:
        The decoded image converted with ``astype(float)``.
    """
    # The builtin ``float`` replaces the ``np.float`` alias, which was
    # removed in NumPy 1.24; both name the same dtype.
    # NOTE(review): scipy.misc.imread itself is gone from SciPy >= 1.2, so
    # this still requires an old SciPy (with Pillow) to run.
    if grayscale:
        return scipy.misc.imread(path, flatten=True).astype(float)
    return scipy.misc.imread(path).astype(float)
48
53
54
def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
    """Cut a centred window out of *x* and resize it.

    Args:
        x: Image array; only ``x.shape[:2]`` (height, width) is inspected.
        crop_h: Height of the window to cut.
        crop_w: Width of the window to cut; ``None`` means "same as crop_h".
        resize_h: Height passed to ``scipy.misc.imresize``.
        resize_w: Width passed to ``scipy.misc.imresize``.

    Returns:
        The result of ``scipy.misc.imresize`` on the cropped window.
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    # Clamp the offsets at 0: when the requested crop is larger than the
    # image the raw offset is negative, and a negative slice start would
    # wrap around and select the wrong (possibly empty) region.
    j = max(int(round((h - crop_h) / 2.)), 0)
    i = max(int(round((w - crop_w) / 2.)), 0)
    return scipy.misc.imresize(
        x[j:j + crop_h, i:i + crop_w], [resize_h, resize_w])
63
+
64
def transform(image, input_height, input_width,
              resize_height=48, resize_width=48, crop=True):
    """Resize *image* (optionally centre-cropping first) and rescale it.

    Args:
        image: Image array to process.
        input_height: Crop height forwarded to ``center_crop`` when *crop*.
        input_width: Crop width forwarded to ``center_crop`` when *crop*.
        resize_height: Target height of the resized image.
        resize_width: Target width of the resized image.
        crop: If true, centre-crop before resizing; otherwise resize the
            whole image with ``scipy.misc.imresize``.

    Returns:
        ``np.array(resized) / 127.5 - 1.`` -- i.e. a value of 0 maps to -1
        and a value of 255 maps to 1.
    """
    if crop:
        resized = center_crop(
            image, input_height, input_width,
            resize_height, resize_width)
    else:
        resized = scipy.misc.imresize(image, [resize_height, resize_width])
    return np.array(resized) / 127.5 - 1.
73
+
74
def autoresize(image_path, input_height, input_width,
               resize_height=48, resize_width=48,
               crop=True, grayscale=False):
    """Load an image from disk and run it through ``transform``.

    Args:
        image_path: Path of the image file, read via ``path_to_img``.
        input_height: Crop height forwarded to ``transform``.
        input_width: Crop width forwarded to ``transform``.
        resize_height: Target height forwarded to ``transform``.
        resize_width: Target width forwarded to ``transform``.
        crop: Forwarded to ``transform``.
        grayscale: Forwarded to ``path_to_img``.

    Returns:
        Whatever ``transform`` returns for the loaded image.
    """
    image = path_to_img(image_path, grayscale)
    return transform(image, input_height, input_width,
                     resize_height, resize_width, crop)
80
+
49
81
np .random .shuffle (data )
50
82
imread_img = path_to_img (data [0 ]) # test read an image
51
83
@@ -56,9 +88,10 @@ def path_to_img(path, grayscale = False):
56
88
57
89
is_grayscale = (c_dim == 1 )
58
90
59
- # tf Graph Input
60
- # face data image of shape 84*84=7056 N.B. originally without the depth 3
61
- x = tf .placeholder (tf .float32 , [None , 42 , 42 , 3 ], name = 'InputData' )
91
+ '''
92
+ tf Graph Input
93
+ '''
94
+ x = tf .placeholder (tf .float32 , [None , 48 , 48 , 3 ], name = 'InputData' )
62
95
63
96
if __debug__ :
64
97
print ("Reading input from:" + data_dir )
@@ -67,11 +100,6 @@ def path_to_img(path, grayscale = False):
67
100
print ("Writing checkpoints to:" + checkpoint_dir )
68
101
print ("Writing TensorBoard logs to:" + logs_dir )
69
102
70
- """
71
- We start by creating the layers with name scopes so that the graph in
72
- the tensorboard looks meaningful
73
- """
74
-
75
103
76
104
# strides = [Batch, Height, Width, Channels] in default NHWC data_format. Batch and Channels
77
105
# must always be set to 1. If channels is set to 3, then we would increment the index for the
@@ -109,17 +137,17 @@ def deconv2d(input, name, kshape, n_outputs, strides=[1, 1]):
109
137
activation_fn = tf .nn .leaky_relu )
110
138
return out
111
139
140
+
112
141
# Input to maxpool: [BatchSize, Width1, Height1, Channels]
113
142
# Output of maxpool: [BatchSize, Width2, Height2, Channels]
114
143
#
115
144
# To calculate the size of the output of maxpool layer:
116
145
# OutWidth = (InWidth - FilterWidth)/Stride + 1
117
146
#
118
147
# The kernel kshape will typically be [1,2,2,1] for a general
119
- # RGB image input of [batch_size,64,64 ,3]
148
+ # RGB image input of [batch_size, 48, 48, 3]
120
149
# kshape is 1 for batch and channels because we don't want to take
121
150
# the maximum over multiple examples or channels.
122
-
123
151
def maxpool2d (x ,name ,kshape = [1 , 2 , 2 , 1 ], strides = [1 , 2 , 2 , 1 ]):
124
152
with tf .variable_scope (name ):
125
153
out = tf .nn .max_pool (x ,
@@ -161,53 +189,40 @@ def ConvAutoEncoder(x, name, reuse=False):
161
189
with tf .variable_scope (name ) as scope :
162
190
if reuse :
163
191
scope .reuse_variables ()
164
- """
165
- We want to get dimensionality reduction of 11664 to 44656
166
- Layers:
167
- input --> 84, 84 (7056)
168
- conv1 --> kernel size: (5,5), n_filters:25 ???make it small so that it runs fast
169
- pool1 --> 42, 42, 25
170
- dropout1 --> keeprate 0.8
171
- reshape --> 42*42*25
172
- FC1 --> 42*42*25, 42*42*5
173
- dropout2 --> keeprate 0.8
174
- FC2 --> 42*42*5, 8820 --> output is the encoder vars
175
- FC3 --> 8820, 42*42*5
176
- dropout3 --> keeprate 0.8
177
- FC4 --> 42*42*5,42*42*25
178
- dropout4 --> keeprate 0.8
179
- reshape --> 42, 42, 25
180
- deconv1 --> kernel size:(5,5,25), n_filters: 25
181
- upsample1 --> 84, 84, 25
182
- FullyConnected (outputlayer) --> 84* 84* 25, 84 * 84 * 1
183
- reshape --> 84 * 84
184
- """
185
- input = tf .reshape (x , shape = [- 1 , 42 , 42 , 3 ])
186
-
187
- # coding part
188
- c1 = conv2d (input , name = 'c1' , kshape = [7 , 7 , 3 , 25 ]) # kshape = [k_h, k_w, in_channels, out_chnnels]
189
- p1 = maxpool2d (c1 , name = 'p1' )
192
+
193
+ input = tf .reshape (x , shape = [- 1 , 48 , 48 , 3 ])
194
+
195
+ # kshape = [k_h, k_w, in_channels, out_channels]
196
+ c1 = conv2d (input , name = 'c1' , kshape = [7 , 7 , 3 , 15 ]) # Input: [48,48,3]; Output: [48,48,15]
197
+ p1 = maxpool2d (c1 , name = 'p1' ) # Input: [48,48,15]; Output: [24,24,15]
190
198
do1 = dropout (p1 , name = 'do1' , keep_rate = 0.75 )
191
- do1 = tf .reshape (do1 , shape = [- 1 , 21 * 21 * 25 ]) # reshape to 1 dimensional (-1 is batch size)
192
- fc1 = fullyConnected (do1 , name = 'fc1' , output_size = 21 * 21 * 5 )
199
+ c2 = conv2d (do1 , name = 'c2' , kshape = [5 , 5 , 15 , 25 ]) # Input: [24,24,15]; Output: [24,24,25]
200
+ p2 = maxpool2d (c2 , name = 'p2' ) # Input: [24,24,25]; Output: [12,12,25]
201
+ p2 = tf .reshape (p2 , shape = [- 1 , 12 * 12 * 25 ]) # Input: [12,12,25]; Output: [12*12*25]
202
+ fc1 = fullyConnected (p2 , name = 'fc1' , output_size = 12 * 12 * 5 ) # Input: [12*12*25]; Output: [12*12*5]
193
203
do2 = dropout (fc1 , name = 'do2' , keep_rate = 0.75 )
194
- fc2 = fullyConnected (do2 , name = 'fc2' , output_size = 21 * 21 * 3 )
204
+ fc2 = fullyConnected (do2 , name = 'fc2' , output_size = 12 * 12 * 3 ) # Input: [12*12*5]; Output: [12*12*3]
205
+ do3 = dropout (fc2 , name = 'do3' , keep_rate = 0.75 )
206
+ fc3 = fullyConnected (do3 , name = 'fc3' , output_size = 64 ) # Input: [12*12*3]; Output: [64] --> bottleneck layer
195
207
# Decoding part
196
- fc3 = fullyConnected (fc2 , name = 'fc3' , output_size = 21 * 21 * 5 )
197
- do3 = dropout (fc3 , name = 'do3' , keep_rate = 0.75 )
198
- fc4 = fullyConnected (do3 , name = 'fc4' , output_size = 21 * 21 * 25 )
199
- do4 = dropout (fc4 , name = 'do3' , keep_rate = 0.75 )
200
- do4 = tf .reshape (do4 , shape = [- 1 , 21 , 21 , 25 ])
201
- dc1 = deconv2d (do4 , name = 'dc1' , kshape = [7 ,7 ],n_outputs = 25 )
202
- up1 = upsample (dc1 , name = 'up1' , factor = [2 , 2 ])
203
- output = fullyConnected (up1 , name = 'output' , output_size = 42 * 42 * 3 )
204
- # print("output.shape"+str(output.shape))
205
- # print("x.shape"+str(x.shape))
208
fc4 = fullyConnected(fc3, name='fc4', output_size=12 * 12 * 3)  # Input: [64]; Output: [12*12*3]
do4 = dropout(fc4, name='do4', keep_rate=0.75)
fc5 = fullyConnected(do4, name='fc5', output_size=12 * 12 * 5)  # Input: [12*12*3]; Output: [12*12*5]
do5 = dropout(fc5, name='do5', keep_rate=0.75)
# fc6 has to emit 12*12*25 values so the reshape below yields a 12x12 map;
# two 2x upsamples then give 48x48, matching the 48x48x3 input. The earlier
# 21*21*25 / [-1, 21, 21, 25] sizes were left over from the 42x42 model.
fc6 = fullyConnected(do5, name='fc6', output_size=12 * 12 * 25)  # Input: [12*12*5]; Output: [12*12*25]
do6 = dropout(fc6, name='do6', keep_rate=0.75)
do6 = tf.reshape(do6, shape=[-1, 12, 12, 25])  # Input: [12*12*25]; Output: [12,12,25]
dc1 = deconv2d(do6, name='dc1', kshape=[5, 5], n_outputs=15)  # Input: [12,12,25]; Output: [12,12,15]
up1 = upsample(dc1, name='up1', factor=[2, 2])  # Input: [12,12,15]; Output: [24,24,15]
dc2 = deconv2d(up1, name='dc2', kshape=[7, 7], n_outputs=3)  # Input: [24,24,15]; Output: [24,24,3]
up2 = upsample(dc2, name='up2', factor=[2, 2])  # Input: [24,24,3]; Output: [48,48,3]
output = fullyConnected(up2, name='output', output_size=48 * 48 * 3)
220
+
206
221
with tf .variable_scope ('cost' ):
207
222
# N.B. reduce_mean is a batch operation! finds the mean across the batch
208
- cost = tf .reduce_mean (tf .square (tf .subtract (output , tf .reshape (x ,shape = [- 1 ,42 * 42 * 3 ]))))
209
- return x , tf .reshape (output ,shape = [- 1 ,42 , 42 ,3 ]), cost # returning, input, output and cost
210
- # ---------------------------------
223
+ cost = tf .reduce_mean (tf .square (tf .subtract (output , tf .reshape (x ,shape = [- 1 ,48 * 48 * 3 ]))))
224
+ return x , tf .reshape (output ,shape = [- 1 ,48 , 48 ,3 ]), cost # returning, input, output and cost
225
+
211
226
212
227
def train_network (x ):
213
228
@@ -248,16 +263,15 @@ def train_network(x):
248
263
print ("epoch " + str (epoch ) + " batch " + str (i ))
249
264
250
265
batch_files = data [i * batch_size :(i + 1 )* batch_size ] # get the current batch of files
251
- # TODO: add functionality to autorescale
252
- # batch = [
253
- # get_image(batch_file,
254
- # input_height=42,
255
- # input_width=42,
256
- # resize_height=42,
257
- # resize_width=42,
258
- # crop=True,
259
- # grayscale=False) for batch_file in batch_files] # get_image will get image from file dir after applying resize operation.
260
- batch = [path_to_img (batch_file ) for batch_file in batch_files ]
266
+
267
+ batch = [autoresize (batch_file ,
268
+ input_height = 48 ,
269
+ input_width = 48 ,
270
+ resize_height = 48 ,
271
+ resize_width = 48 ,
272
+ crop = True ,
273
+ grayscale = False ) for batch_file in batch_files ]
274
+
261
275
batch_images = np .array (batch ).astype (np .float32 )
262
276
263
277
# Get cost function from running optimizer
@@ -287,6 +301,7 @@ def save(saver, step, session):
287
301
os .path .join (checkpoint_dir , model_name ),
288
302
global_step = step )
289
303
304
+
290
305
# Restore from checkpoint
291
306
def restore (saver , session ):
292
307
print (">>> Restoring from checkpoints..." )
@@ -300,4 +315,4 @@ def restore(saver, session):
300
315
else :
301
316
return False , 0
302
317
303
- train_network (x )
318
+ train_network (x )
0 commit comments