-
Notifications
You must be signed in to change notification settings - Fork 0
/
common.py
170 lines (144 loc) · 6.07 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
## 通用代码
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
## Map Fashion-MNIST label indices to their text names
def get_mnist_label(labels):
    """Translate Fashion-MNIST class indices into their text labels.

    Args:
        labels: Iterable of class indices. Each element is converted with
            ``int()`` before it is used as an index.

    Returns:
        A list holding one label string per element of ``labels``.
    """
    text_labels = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
    return [text_labels[index] for index in map(int, labels)]
## Load the Fashion-MNIST dataset
def load_data_mnist(batch_size, resize=None):
    """Build batched ``tf.data`` pipelines for the Fashion-MNIST splits.

    Args:
        batch_size: Number of examples per batch.
        resize: Optional target height/width. When truthy, every batch of
            images goes through ``tf.image.resize_with_pad(X, resize, resize)``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``tf.data.Dataset``
        objects yielding ``(images, labels)`` batches.
    """
    train_split, test_split = tf.keras.datasets.fashion_mnist.load_data()

    def preprocess(images, labels):
        # Append a trailing axis (axis=3) and divide by 255 so 8-bit pixel
        # values land in [0, 1]; labels are cast to int32.
        return (tf.expand_dims(images, axis=3) / 255,
                tf.cast(labels, dtype='int32'))

    def maybe_resize(images, labels):
        # Images pass through untouched unless a resize target was given.
        if resize:
            images = tf.image.resize_with_pad(images, resize, resize)
        return images, labels

    def batched(split):
        return tf.data.Dataset.from_tensor_slices(
            preprocess(*split)).batch(batch_size)

    # NOTE: shuffle() runs after batch(), so it reorders whole batches; the
    # buffer size is the number of training images.
    train_dataset = batched(train_split).shuffle(
        len(train_split[0])).map(maybe_resize)
    test_dataset = batched(test_split).map(maybe_resize)
    return train_dataset, test_dataset
def use_svg_display():
    """Ask the matplotlib inline backend to render figures as SVG."""
    figure_format = 'svg'
    backend_inline.set_matplotlib_formats(figure_format)
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Configure labels, scales, limits, legend and grid on an axes object.

    Args:
        axes: Object exposing the matplotlib ``Axes`` setter methods used
            below (``set_xlabel``, ``set_ylabel``, ``set_xscale``,
            ``set_yscale``, ``set_xlim``, ``set_ylim``, ``legend``, ``grid``).
        xlabel: Value passed to ``axes.set_xlabel``.
        ylabel: Value passed to ``axes.set_ylabel``.
        xlim: Value passed to ``axes.set_xlim``.
        ylim: Value passed to ``axes.set_ylim``.
        xscale: Value passed to ``axes.set_xscale``.
        yscale: Value passed to ``axes.set_yscale``.
        legend: Legend entries; ``axes.legend`` is called only when truthy.
    """
    # Setters are applied in this order: labels, then scales, then limits.
    settings = (
        ('set_xlabel', xlabel),
        ('set_ylabel', ylabel),
        ('set_xscale', xscale),
        ('set_yscale', yscale),
        ('set_xlim', xlim),
        ('set_ylim', ylim),
    )
    for setter_name, value in settings:
        getattr(axes, setter_name)(value)
    if legend:
        axes.legend(legend)
    axes.grid()
def sgd(params, grads, lr, batch_size):
    """Apply one minibatch stochastic gradient descent step in place.

    Each parameter is decreased by ``lr * grad / batch_size`` through its
    ``assign_sub`` method.

    Args:
        params: Iterable of parameter objects exposing ``assign_sub``.
        grads: Iterable of gradients, paired with ``params`` by position.
        lr: Learning rate.
        batch_size: Divisor applied to each gradient.
    """
    # Singular loop names keep the ``params``/``grads`` arguments intact
    # instead of rebinding them on every iteration.
    for param, grad in zip(params, grads):
        param.assign_sub(lr * grad / batch_size)
def accuracy(y_hat, y):
    """Count how many predictions equal their labels.

    Args:
        y_hat: Predictions. When it has more than one dimension and its
            second dimension is larger than 1, it is reduced with
            ``tf.argmax(..., axis=1)`` first.
        y: Labels; predictions are cast to ``y.dtype`` before comparing.

    Returns:
        The number of matching entries, as a Python float.
    """
    needs_argmax = len(y_hat.shape) > 1 and y_hat.shape[1] > 1
    predictions = tf.argmax(y_hat, axis=1) if needs_argmax else y_hat
    matches = tf.cast(predictions, y.dtype) == y
    match_count = tf.reduce_sum(tf.cast(matches, y.dtype))
    return float(match_count)
def evaluate_accuracy(net, data_iter):
    """Compute the accuracy of ``net`` over every batch in ``data_iter``.

    Args:
        net: Callable mapping a batch of inputs to predictions.
        data_iter: Iterable of ``(X, y)`` batches.

    Returns:
        Correct predictions divided by the total number of labels seen.
    """
    correct, total = 0.0, 0.0
    for X, y in data_iter:
        correct += float(accuracy(net(X), y))
        total += float(tf.size(y))
    return correct / total
class Animator:
    """Incrementally plot one or more data series on a matplotlib figure.

    Points are appended with :meth:`add`; every call clears the first axes
    and redraws all series collected so far.
    """

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        """Create the figure and remember how its first axes is configured.

        Args:
            xlabel, ylabel: Axis labels forwarded to ``set_axes``.
            legend: Legend entries; ``None`` is treated as an empty list.
            xlim, ylim: Axis limits forwarded to ``set_axes``.
            xscale, yscale: Axis scales forwarded to ``set_axes``.
            fmts: Format strings, one per plotted series.
            nrows, ncols: Subplot grid shape passed to ``plt.subplots``.
            figsize: Figure size passed to ``plt.subplots``.
        """
        if legend is None:
            legend = []
        use_svg_display()
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        # plt.subplots returns a bare Axes only for a 1x1 grid; wrap it so
        # ``self.axes[0]`` works for every grid shape. The check must use
        # nrows * ncols: testing nrows alone would also wrap the array
        # returned for a 1xN grid.
        if nrows * ncols == 1:
            self.axes = [self.axes]
        # The lambda captures the styling arguments for later redraws.
        self.config_axes = lambda: set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append one point per series and redraw the first axes.

        Args:
            x: A single x value shared by all series, or one value per series.
            y: A single y value, or a sequence with one value per series.
                Pairs where either coordinate is ``None`` are skipped.
        """
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        # Storage is created lazily, sized by the first call.
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(xs, ys, fmt)
        self.config_axes()
class Accumulator:
    """Keep running float sums over ``n`` quantities."""

    def __init__(self, n):
        """Start with ``n`` totals, all zero."""
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add each argument, converted with ``float()``, to its total."""
        updated = []
        for total, value in zip(self.data, args):
            updated.append(total + float(value))
        self.data = updated

    def reset(self):
        """Set every total back to zero."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, item):
        """Return the total stored at position ``item``."""
        return self.data[item]
class Updater:
    """Update parameters with minibatch stochastic gradient descent."""

    def __init__(self, params, lr):
        """Store the parameters to update and the learning rate."""
        self.lr = lr
        self.params = params

    def __call__(self, batch_size, grads):
        """Run one ``sgd`` step on the stored parameters with ``grads``."""
        sgd(params=self.params, grads=grads, lr=self.lr,
            batch_size=batch_size)
def train_epoch_ch3(net, train_iter, loss, updater):
    """Train ``net`` for one pass over ``train_iter``.

    Args:
        net: Callable model. ``net.trainable_variables`` is read when
            ``updater`` is a Keras optimizer.
        train_iter: Iterable of ``(X, y)`` batches.
        loss: A ``tf.keras.losses.Loss`` instance (called as
            ``loss(y, y_hat)``) or any other callable (called as
            ``loss(y_hat, y)``).
        updater: A ``tf.keras.optimizers.Optimizer``, or a callable taking
            ``(batch_size, grads)`` and exposing a ``params`` attribute.

    Returns:
        A ``(mean training loss, training accuracy)`` tuple.
    """
    # Running totals: summed loss, correct predictions, number of examples.
    metric = Accumulator(3)
    # Neither type check changes between batches, so name them once.
    keras_loss = isinstance(loss, tf.keras.losses.Loss)
    keras_optimizer = isinstance(updater, tf.keras.optimizers.Optimizer)
    for X, y in train_iter:
        # Record the forward pass so gradients can be taken afterwards.
        with tf.GradientTape() as tape:
            y_hat = net(X)
            if keras_loss:
                batch_loss = loss(y, y_hat)
            else:
                batch_loss = loss(y_hat, y)
        if keras_optimizer:
            variables = net.trainable_variables
            gradients = tape.gradient(batch_loss, variables)
            updater.apply_gradients(zip(gradients, variables))
        else:
            updater(X.shape[0], tape.gradient(batch_loss, updater.params))
        # A Keras loss returns the batch mean by default, so it is scaled
        # back up by the number of labels; other losses are summed.
        if keras_loss:
            loss_total = batch_loss * float(tf.size(y))
        else:
            loss_total = tf.reduce_sum(batch_loss)
        metric.add(loss_total, accuracy(y_hat, y), tf.size(y))
    # Mean loss and accuracy over everything seen this epoch.
    seen = metric[2]
    return metric[0] / seen, metric[1] / seen
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """Train ``net`` for ``num_epochs`` epochs, printing and plotting metrics.

    After each epoch the training loss, training accuracy and test accuracy
    are printed and added to an ``Animator`` plot. After the last epoch the
    final metrics are checked with assertions.

    Args:
        net: Model passed to ``train_epoch_ch3`` and ``evaluate_accuracy``.
        train_iter: Iterable of training batches.
        test_iter: Iterable of test batches.
        loss: Loss forwarded to ``train_epoch_ch3``.
        num_epochs: Number of epochs to run.
        updater: Updater forwarded to ``train_epoch_ch3``.
    """
    animator = Animator(
        xlabel='epoch',
        xlim=[1, num_epochs],
        ylim=[0.3, 0.9],
        legend=['train loss', 'train acc', 'test acc'],
    )
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        message = (f"Epoch [{epoch + 1}/{num_epochs}], "
                   f"Train Loss: {train_metrics[0]:.4f}, "
                   f"Train Acc: {train_metrics[1]:.4f}, "
                   f"Test Acc: {test_acc:.4f}")
        print(message)
        animator.add(epoch + 1, (*train_metrics, test_acc))
    # Checks on the metrics from the final epoch.
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert 0.7 < train_acc <= 1, train_acc
    assert 0.7 < test_acc <= 1, test_acc
def predict_ch3(net, test_iter, n=6):
    """Show images from the first test batch with true and predicted labels.

    Args:
        net: Callable model; ``tf.argmax(net(X), axis=1)`` gives the
            predicted class indices.
        test_iter: Iterable of ``(X, y)`` batches; only the first is used.
        n: Maximum number of images to show. Fewer are shown when the first
            batch holds fewer than ``n`` examples.

    Raises:
        ValueError: If ``test_iter`` yields no batches.
    """
    try:
        X, y = next(iter(test_iter))
    except StopIteration:
        raise ValueError("test_iter yielded no batches") from None
    trues = get_mnist_label(y)
    preds = get_mnist_label(tf.argmax(net(X), axis=1))
    # Never request more images than the batch holds; otherwise the reshape
    # below fails on a short batch.
    count = min(n, len(trues))
    # Images are reshaped to 28x28 for display.
    images = tf.reshape(X[0:count], (count, 28, 28))
    # squeeze=False keeps ``axes`` a 2-D array even for a single column, so
    # the loop below also works when count == 1.
    _, axes = plt.subplots(1, count, figsize=(15, 15), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(images[i])
        ax.set_title(f'true label: {trues[i]}\npredict label: {preds[i]}')
        ax.axis('off')