-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
84 lines (68 loc) · 2.68 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from __future__ import division
import argparse
import os

import numpy as np
import torch

from dim_red.triplet import train_triplet
from dim_red.angular import train_angular
from dim_red.support_func import sanitize
from dim_red.data import load_dataset
# Default location of the per-database results logs. Overridable with
# --results_dir; the default reproduces the original hard-coded path.
DEFAULT_RESULTS_DIR = "/home/shekhale/results/dim_red"

# Training methods dispatched by main(); used to validate --method early.
TRAIN_METHOD_NAMES = ("triplet", "angular")


def build_parser():
    """Build the command-line parser for dimensionality-reduction training.

    Returns:
        argparse.ArgumentParser with dataset, model and computation options.
    """
    parser = argparse.ArgumentParser()

    group = parser.add_argument_group('dataset options')
    group.add_argument("--database", default="sift")
    # Restricting --method to the dispatched names makes a typo fail before
    # the dataset is loaded and before anything is appended to a results log.
    group.add_argument("--method", type=str, default="triplet",
                       choices=TRAIN_METHOD_NAMES)
    group.add_argument("--val_fraction", type=float, default=0.01,
                       help="fraction of the database held out for validation")

    group = parser.add_argument_group('Model hyperparameters')
    group.add_argument("--dout", type=int, default=16,
                       help="output dimension")
    group.add_argument("--dint", type=int, default=1024)

    group = parser.add_argument_group('Computation params')
    group.add_argument("--seed", type=int, default=1234)
    group.add_argument("--device", choices=["cuda", "cpu", "auto"], default="auto")
    group.add_argument("--val_freq", type=int, default=10,
                       help="frequency of validation calls")
    group.add_argument("--optim", type=str, default="sgd")
    group.add_argument("--print_results", type=int, default=0)
    group.add_argument("--save", type=int, default=0)
    group.add_argument("--full", type=int, default=0)
    group.add_argument("--val_freq_search", type=int, default=5,
                       help="frequency of validation calls")
    group.add_argument("--save_knn_1k", type=int, default=0)
    group.add_argument("--save_optimal", type=int, default=0)
    group.add_argument("--batch_size", type=int, default=64)
    group.add_argument("--epochs", type=int, default=40)
    group.add_argument("--lr_schedule", type=str, default="0.1,0.1,0.05,0.01")
    group.add_argument("--momentum", type=float, default=0.9)
    group.add_argument("--results_dir", type=str, default=DEFAULT_RESULTS_DIR,
                       help="root directory of the per-database results logs")
    return parser


def results_path(results_dir, database, method):
    """Return the results-log path ``<results_dir>/<database>/train_results_<method>.txt``."""
    return os.path.join(results_dir, database, "train_results_" + method + ".txt")


def split_train_val(xb, full=False, val_fraction=0.01):
    """Hold out a random subset of the rows of *xb* for validation.

    A permutation of the row indices is drawn from NumPy's global RNG. The
    first ``int(n_rows * val_fraction)`` permuted rows form the validation set.
    The training set is the remaining rows, or all of *xb* when *full* is
    truthy (training and validation then overlap).

    Args:
        xb: indexable array whose first axis enumerates the database rows.
        full: if truthy, train on the whole of *xb*.
        val_fraction: fraction of rows to hold out, in the open range (0, 1).

    Returns:
        ``(xt, xv, perm)``: training rows, validation rows, and the permutation.

    Raises:
        ValueError: if *val_fraction* is outside (0, 1).
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be in (0, 1), got %r" % (val_fraction,))
    base_size = xb.shape[0]
    threshold = int(base_size * val_fraction)
    perm = np.random.permutation(base_size)
    xv = xb[perm[:threshold]]
    xt = xb if full else xb[perm[threshold:]]
    return xt, xv, perm


def main(argv=None):
    """Parse *argv*, load the dataset, split it and run the selected trainer.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    if args.device == "auto":
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    print(args)

    results_file_name = results_path(args.results_dir, args.database, args.method)
    if args.print_results > 0:
        # Appending to a log in a missing directory would raise; create it.
        log_dir = os.path.dirname(results_file_name)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(results_file_name, "a") as rfile:
            rfile.write("\n\n")
            rfile.write("START TRAINING \n")

    print("load dataset %s" % args.database)
    (_, xb, xq, _) = load_dataset(args.database, args.device, calc_gt=False, mnt=True)

    # The permutation is drawn after load_dataset from the seeded global RNG,
    # exactly as before, so the train/validation split is reproducible.
    xt, xv, perm = split_train_val(xb, full=args.full, val_fraction=args.val_fraction)
    print(xb.shape, xt.shape, xv.shape, xq.shape)

    xt = sanitize(xt)
    xv = sanitize(xv)
    xb = sanitize(xb)
    xq = sanitize(xq)

    if args.method == "triplet":
        train_triplet(xb, xt, xv, xq, args, results_file_name)
    else:  # "angular" -- argparse choices exclude anything else
        train_angular(xb, xt, xv, xq, args, results_file_name, perm)


if __name__ == '__main__':
    main()