-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.lua
161 lines (111 loc) · 4.3 KB
/
train.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
require 'torch'
require 'nn'
require 'optim'
require 'image'
require 'nninit'
local model = require 'src/model'
local dataproc = require 'src/dataproc'

-- Working tensor type. FloatTensor is faster (and half the memory) of the
-- default DoubleTensor for training.
local dtype = 'torch.FloatTensor'
local useOpenCl = true

-- With OpenCL enabled, switch the working dtype to the Cl tensor type
-- reported by cltorch (':cl()' moves a tensor to the OpenCL device).
if useOpenCl then
  require 'cltorch'
  require 'clnn'
  dtype = torch.FloatTensor():cl():type()
end

-- Loss: summed (not size-averaged) MSE over the whole minibatch.
local criterion = nn.MSECriterion():type(dtype)
criterion.sizeAverage = false

-- VDSR-style residual CNN (8 layers here):
-- http://cv.snu.ac.kr/research/VDSR/VDSR_CVPR2016.pdf
-- BUGFIX: was an accidental global; declared 'local' (the whole chunk is
-- one file, so all later code still sees this file-scope local).
local vdsrcnn = model.create(8)
vdsrcnn:type(dtype)
--Create training data
--- Stack an array of same-sized tensors into one tensor whose first
--- dimension indexes the array entries.
-- BUGFIX: the parameter used to be named 'table', shadowing the standard
-- table library, and 'merge' was an accidental global.
-- @param tensors array of tensors, all sharing an identical size()
-- @return one tensor of shape (#tensors, <entry dimensions>)
function TableToTensor(tensors)
  -- Build the View target size; -1 lets nn.View infer the batch dimension.
  local entrySize = tensors[1]:size()
  local viewSize = {-1}
  for i = 1, entrySize:size(1) do
    viewSize[i + 1] = entrySize[i]
  end
  -- Small throwaway Join+View pipeline; kept local so it is rebuilt and
  -- garbage-collected per call instead of leaking into _G.
  local merge = nn.Sequential()
    :add(nn.JoinTable(1))
    :add(nn.View(unpack(viewSize)))
  return merge:forward(tensors)
end
-- Training-set configuration and the fixed evaluation image.
local imagesn = 12 --Number of images in the folder ./train/
local batchsize = 10 --Reduce the batch size if you have memory problems (C++ Exception or Out of memory error)
local minibatch = (imagesn*4)/batchsize --#Of iterations before going through entire batch
-- NOTE(review): 12*4/10 = 4.8, i.e. NOT an integer; the epoch bookkeeping
-- in the training loop compares against this value -- verify intent.
local hr, lr = dataproc.getImages(imagesn)
-- Fixed test image: 'thr' is the ground-truth high-res image; 'tlr' is the
-- same image downscaled by half then bicubic-rescaled back to full size --
-- the blurry input whose missing detail (residual) the network learns.
local timg = image.load("train/test.png", 3, "float")
local thr = timg:type(dtype)
local tlr = image.scale(image.scale(timg, "*1/2"), thr:size(3), thr:size(2), "bicubic"):type(dtype)
-- Current minibatch (input x, target y); populated by setBatch().
local x;
local y;
--- Draw a fresh training minibatch into the file-scope x/y tensors.
-- x holds low-res (input) patches, y the matching high-res targets.
function setBatch()
  -- NOTE(review): n, w and h are never defined in this file, so nil is
  -- passed for batch size / patch width / height -- presumably
  -- dataproc.getBatch falls back to defaults; confirm against src/dataproc.
  -- BUGFIX: 'ay'/'ax' were accidental globals; now local.
  local ay, ax = dataproc.getBatch(hr, lr, n, w, h)
  x = TableToTensor(ax):type(dtype)
  y = TableToTensor(ay):type(dtype)
end
setBatch()
--Initialise training variables
-- Flattened views of all network weights and their gradients. BUGFIX:
-- declared 'local' (they were accidental globals). The checkpoint code in
-- the training loop re-fetches them after :float()/:type() round-trips,
-- because those calls reallocate the underlying storages.
local params, gradParams = vdsrcnn:getParameters()
local optimState = {learningRate = 0.05, weightDecay = 0.0001, momentum = 0.9}
local cnorm = 0.001 * optimState.learningRate --Gradient Clipping (c * Initial_Learning_Rate)
local showlossevery = 100 -- diagnostic print cadence (iterations)
local loss = 1 -- most recent criterion value; updated inside f()
--Training function
function f(params)
--vdsrcnn:zeroGradParameters();
gradParams:zero()
local imagein = x:clone():csub(0.5) --Removing 0.5 to normalise the input images to [-0.5, 0.5] helps prevent gradient explosion
--if the image has values of [0, 1], all the gradients initially will be positive at the same time
--TODO: Better to substract with the mean of all images
--Forward the image values
local out = vdsrcnn:forward(imagein)
local diff = y:clone():csub(x)
--The loss is the difference between the output residual and the ground truth residual
loss = criterion:forward(out, diff)
--Compute the gradient
local lrate = optimState.learningRate
local grad_out = criterion:backward(out, diff)
--Zero the previous gradient, and backpropagate the new gradient
local grad_in = vdsrcnn:backward(imagein, grad_out):clamp(-cnorm/lrate, cnorm/lrate)
gradParams:clamp(-cnorm/lrate, cnorm/lrate) --Clip the gradients
--Return the loss and new gradient parameters to the optim.sgd() function
return loss, gradParams
end
-- Factor applied to the learning rate every 10000 iterations.
local decreaseRate = 0.1
--Saves a ground truth residual for testing
-- ':add(0.5)' shifts the residual from roughly [-0.5, 0.5] into [0, 1] for
-- saving; it mutates in place, which is safe since the clone is not reused.
local Truthdiff = thr:clone():csub(tlr)
image.save("test/Truth.png", Truthdiff:add(0.5))
--image.save("test/TestInput.png", x[1])
--image.save("test/TestOutput.png", y[1])
-- Ground-truth residual of the first patch of the current batch, saved for
-- visual comparison with the network's outputs during training.
local Truthdiff2 = y[1]:clone():csub(x[1])
image.save("test/TestGT1.png", Truthdiff2:add(0.5))
local epoch = 0;
-- Main SGD training loop.
for iter = 1, 30000 do
  -- Learning-rate schedule: multiply by decreaseRate every 10000 iters.
  if iter % 10000 == 0 then
    optimState.learningRate = optimState.learningRate * decreaseRate
    print("Reducing learning rate by a factor of " .. decreaseRate .. ". New learning rate: " .. optimState.learningRate)
  end
  optim.sgd(f, params, optimState)
  -- Progress reporting: dense early on, then every 'showlossevery' iters.
  if (iter % showlossevery == 0) or (iter % 20 == 0 and iter < 200) or (iter < 20) then
    print("Epoch " .. epoch .. " Iteration " .. iter .. " Training Loss " .. loss)
    -- Residual the network predicts for the fixed test image (same
    -- [-0.5, 0.5] input shift as in training; +0.5 restores [0, 1]).
    local epochdiff = vdsrcnn:forward(tlr:clone():csub(0.5))
    image.save("test/" .. iter .. "resid.png", epochdiff:add(0.5))
  end
  if iter % 100 == 0 then -- checkpoint the model every 100 iterations
    vdsrcnn:clearState()
    vdsrcnn:float() -- save as FloatTensor so the checkpoint loads anywhere
    torch.save("save/nn" .. iter .. ".cv", vdsrcnn)
    vdsrcnn:type(dtype)
    -- :float()/:type() reallocate storages, so the flat parameter views
    -- must be re-fetched or optim.sgd would update stale memory.
    params, gradParams = vdsrcnn:getParameters()
    collectgarbage()
  end
  -- BUGFIX: 'minibatch' is 4.8 here (12*4/10), so the old test
  -- 'iter % minibatch == minibatch - 1' compared fractional floats for
  -- exact equality and essentially never fired, leaving 'epoch' stuck at
  -- 0. Count an epoch every floor(minibatch) iterations instead.
  if iter % math.max(1, math.floor(minibatch)) == 0 then
    epoch = epoch + 1
  end
  setBatch()
end