-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
109 lines (87 loc) · 4.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import math
from datasets import dataset_factory
from preprocessing import preprocessing_factory
FLAGS = tf.app.flags.FLAGS
def prepare_traindata(dataset_dir, batch_size):
    """Build the shuffled input pipeline for training.

    Args:
        dataset_dir: directory containing the dataset files.
        batch_size: number of examples per output batch.

    Returns:
        A (images, labels) pair of batched tensors drawn from a
        shuffling queue.
    """
    # Training-mode preprocessing (augmentation enabled).
    preprocess = preprocessing_factory.get_preprocessing(
        FLAGS.preprocessing, is_training=True)
    train_set = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', dataset_dir)
    # Multiple readers + shuffle=True decorrelate the example stream.
    reader = slim.dataset_data_provider.DatasetDataProvider(
        dataset=train_set, num_readers=4, shuffle=True)
    raw_image, raw_label = reader.get(['image', 'label'])
    processed = preprocess(raw_image, FLAGS.image_size, FLAGS.image_size)
    # Shuffle-batch with a buffer of several batches for extra mixing.
    return tf.train.shuffle_batch(
        [processed, raw_label],
        batch_size=batch_size,
        num_threads=4,
        capacity=8 * batch_size,
        min_after_dequeue=4 * batch_size)
def prepare_testdata(dataset_dir, batch_size):
    """Build the deterministic input pipeline for evaluation.

    Args:
        dataset_dir: directory containing the dataset files.
        batch_size: number of examples per output batch.

    Returns:
        A (images, labels) pair of batched tensors in dataset order.
    """
    # Evaluation-mode preprocessing (no augmentation).
    preprocess = preprocessing_factory.get_preprocessing(
        FLAGS.preprocessing, is_training=False)
    test_set = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'test', dataset_dir)
    # Single reader, no shuffling: evaluation order must be stable.
    reader = slim.dataset_data_provider.DatasetDataProvider(
        test_set, num_readers=1, shuffle=False)
    raw_image, raw_label = reader.get(['image', 'label'])
    processed = preprocess(raw_image, FLAGS.image_size, FLAGS.image_size)
    # Plain FIFO batching; partial final batches are dropped.
    return tf.train.batch(
        [processed, raw_label],
        batch_size=batch_size,
        num_threads=1,
        capacity=4 * batch_size,
        allow_smaller_final_batch=False)
def config_lr(max_steps):
    """Return the per-dataset learning-rate schedule configuration.

    Args:
        max_steps: total number of training steps in the run.

    Returns:
        boundaries: single-element list with the step index at which
            learning-rate decay begins.
        values: single-element list with the initial learning rate.

    Raises:
        ValueError: if FLAGS.dataset_name matches neither a cifar nor an
            svhn variant. (Previously this case fell through and crashed
            with UnboundLocalError at the return statement.)
    """
    name = FLAGS.dataset_name
    if 'cifar' in name:
        # Start to decay at the 250th of 500 epochs.
        boundaries = [int(250.0 / 500.0 * max_steps)]
        values = [0.1]
    elif 'svhn' in name:
        # Decay from the very beginning: the 0th epoch.
        boundaries = [int(0 * max_steps)]
        values = [0.02]
    else:
        raise ValueError('dataset_name %s was not recognized.' % name)
    return boundaries, values
def linear_decay_lr(step, boundaries, values, max_steps):
    """Anneal the learning rate linearly toward zero.

    For svhn datasets the decay spans the entire run; otherwise the rate
    is held constant at values[0] until boundaries[0] and then decays
    linearly over the remaining steps.
    """
    steps_left = float(max_steps - (step + 1))
    if 'svhn' in FLAGS.dataset_name:
        # svhn: decay across the whole training run.
        return (steps_left / float(max_steps)) * values[0]
    if step < boundaries[0]:
        # Before the decay boundary the rate stays constant.
        return values[0]
    # After the boundary: linear ramp down over the remaining span.
    return (steps_left / float(max_steps - boundaries[0])) * values[0]
def cos_decay_lr(step, boundaries, values, max_steps):
    """Anneal the learning rate along a quarter cosine curve.

    The decay progress runs from 0 (decay start) to 1 (end of training),
    so the multiplier cos(pi/2 * progress) falls smoothly from 1 to 0.
    For svhn the decay spans the whole run; otherwise the rate stays at
    values[0] until boundaries[0].
    """
    if 'svhn' in FLAGS.dataset_name:
        # svhn: decay across the entire run.
        span = float(max_steps)
    elif step < boundaries[0]:
        # Before the decay boundary the rate stays constant.
        return values[0]
    else:
        span = float(max_steps - boundaries[0])
    progress = 1. - (float(max_steps - (step + 1)) / span)
    return np.cos(math.pi / 2 * progress) * values[0]
def sin_decay_lr(step, boundaries, values, max_steps):
    """Anneal the learning rate along an inverted quarter sine curve.

    The decay progress runs from 0 (decay start) to 1 (end of training),
    so the multiplier 1 - sin(pi/2 * progress) falls from 1 to 0. For
    svhn the decay spans the whole run; otherwise the rate stays at
    values[0] until boundaries[0].
    """
    if 'svhn' in FLAGS.dataset_name:
        # svhn: decay across the entire run.
        span = float(max_steps)
    elif step < boundaries[0]:
        # Before the decay boundary the rate stays constant.
        return values[0]
    else:
        span = float(max_steps - boundaries[0])
    progress = 1. - (float(max_steps - (step + 1)) / span)
    return (1 - np.sin(math.pi / 2 * progress)) * values[0]
def decay_lr(step, boundaries, values, max_steps):
    """Dispatch to the decay schedule selected by FLAGS.decay_lr_type.

    Using a cosine or sine learning-rate decay schedule may further
    improve results over linear decay.

    Args:
        step: current training step (0-based).
        boundaries: single-element list with the decay-start step.
        values: single-element list with the initial learning rate.
        max_steps: total number of training steps.

    Returns:
        The decayed learning rate for this step, as computed by the
        selected schedule.

    Raises:
        ValueError: if FLAGS.decay_lr_type is not one of 'linear',
            'cosine', or 'sine'. (Bug fix: the original message never
            interpolated the offending value — the literal '%s' was
            raised; the format argument is now supplied.)
    """
    if FLAGS.decay_lr_type == 'linear':
        return linear_decay_lr(step, boundaries, values, max_steps)
    elif FLAGS.decay_lr_type == 'cosine':
        return cos_decay_lr(step, boundaries, values, max_steps)
    elif FLAGS.decay_lr_type == 'sine':
        return sin_decay_lr(step, boundaries, values, max_steps)
    else:
        raise ValueError('decay_lr_type %s was not recognized.'
                         % FLAGS.decay_lr_type)