-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
91 lines (72 loc) · 2.86 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from PIL import Image
from glob import glob
import tensorflow as tf
def get_loader(root, batch_size, scale_size, data_format, split=None, is_grayscale=False, seed=None):
    """Build a TF1 queue-based input pipeline streaming image batches from `root`.

    Args:
        root: Directory containing .jpg or .png images. Its basename selects
            dataset-specific preprocessing (CelebA-style face crop).
        batch_size: Number of images per batch.
        scale_size: Batches are resized to (scale_size, scale_size).
        data_format: 'NCHW' or 'NHWC'; 'NCHW' batches are transposed.
        split: Unused; kept for interface compatibility.
        is_grayscale: If True, convert decoded images to grayscale.
        seed: Seed for the filename queue. NOTE(review): shuffle=False below,
            so the seed has no effect on file ordering — confirm intent.

    Returns:
        A float32 tensor of batched, resized images.

    Raises:
        IOError: If no .jpg or .png files are found under `root`.
    """
    dataset_name = os.path.basename(root)
    # Prefer jpg, fall back to png; the decoder must match the extension found.
    paths = []
    tf_decode = None
    for ext, decoder in [("jpg", tf.image.decode_jpeg), ("png", tf.image.decode_png)]:
        paths = glob("{}/*.{}".format(root, ext))
        if paths:
            tf_decode = decoder
            break
    if not paths:
        # Previously this fell through to an opaque IndexError on paths[0].
        raise IOError("[!] No jpg/png images found in {}".format(root))
    print("get loader paths is {}".format(paths[0]))

    # Image dimensions are taken from the first file only; all images are
    # assumed to share the same size — TODO confirm against the dataset.
    with Image.open(paths[0]) as img:
        w, h = img.size
        shape = [h, w, 3]

    filename_queue = tf.train.string_input_producer(list(paths), shuffle=False, seed=seed)
    reader = tf.WholeFileReader()
    filename, data = reader.read(filename_queue)
    image = tf_decode(data, channels=3)
    if is_grayscale:
        image = tf.image.rgb_to_grayscale(image)
    image.set_shape(shape)

    # Standard TF1 shuffled-batch sizing: capacity leaves headroom of
    # 3 batches above the minimum kept for shuffling quality.
    min_after_dequeue = 5000
    capacity = min_after_dequeue + 3 * batch_size
    queue = tf.train.shuffle_batch(
        [image], batch_size=batch_size,
        num_threads=4, capacity=capacity,
        min_after_dequeue=min_after_dequeue, name='synthetic_inputs')

    # CelebA-style datasets get a fixed face crop before resizing.
    # ('CelebA' and 'celeba_1' previously had identical duplicated branches.)
    if dataset_name in ['CelebA', 'celeba_1']:
        queue = tf.image.crop_to_bounding_box(queue, 50, 25, 128, 128)
    queue = tf.image.resize_nearest_neighbor(queue, [scale_size, scale_size])

    if data_format == 'NCHW':
        queue = tf.transpose(queue, [0, 3, 1, 2])
    elif data_format != 'NHWC':
        raise Exception("[!] Unkown data_format: {}".format(data_format))
    return tf.to_float(queue)
# Smoke-test entry point: builds the data loader from the project config and
# prints a few config fields. Depends on project-local my_utils / my_config.
if __name__ == "__main__":
    from my_utils import prepare_dirs_and_logger
    from my_config import get_config
    import numpy as np
    config, unparsed = get_config()
    print prepare_dirs_and_logger(config)
    # Seed NumPy and TF for reproducibility. `rng` is not used below;
    # presumably kept for parity with the training script — confirm.
    rng = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)
    if config.is_train:
        print "line 72 ..."
        data_path = config.data_path
        batch_size = config.batch_size
        do_shuffle = True
    else:
        print "line 77 ..."
        # Test mode forces batch_size=64 on the config object itself.
        setattr(config, 'batch_size', 64)
        if config.test_data_path is None:
            data_path = config.data_path
        else:
            data_path = config.test_data_path
        batch_size = config.sample_per_image
        do_shuffle = False
    # NOTE(review): the locally computed `batch_size` and `do_shuffle` are
    # never used — the call below passes config.batch_size (64 in test mode),
    # not `batch_size`, and get_loader takes no shuffle flag. Looks
    # unintentional; confirm intended behavior before changing.
    data_loader = get_loader(data_path, config.batch_size, config.input_scale_size, config.data_format)
    print "model_name1: ", config.model_name
    print "data_path2: ", config.data_path
    print "model_dir3: ", config.model_dir