-
Notifications
You must be signed in to change notification settings - Fork 1
/
unit_tests.py
executable file
·92 lines (85 loc) · 3.05 KB
/
unit_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import ID3, parse, random
def testID3AndEvaluate():
    """Smoke-test ID3.ID3 and ID3.evaluate on a two-example dataset.

    Both training examples have Class=1, so the tree returned by ID3.ID3 is
    expected to evaluate dict(a=1, b=0) to 1. The second argument passed to
    ID3.ID3 (0 here) appears to be a default class label -- assumption, not
    visible from this file.

    Prints a success or failure message; returns None.
    """
    data = [dict(a=1, b=0, Class=1), dict(a=1, b=1, Class=1)]
    tree = ID3.ID3(data, 0)
    # Identity check: a tree object could define __eq__/__ne__, which would
    # make `tree != None` unreliable.
    if tree is None:
        print("ID3 test failed -- no tree returned")
        return
    ans = ID3.evaluate(tree, dict(a=1, b=0))
    if ans != 1:
        print("ID3 test failed.")
    else:
        print("ID3 test succeeded.")
def testPruning():
    """Check that ID3.prune with validation data yields the expected prediction.

    A tree is grown on `data`, pruned with `validationData`, and then asked
    to classify dict(a=0, b=0). The test expects 0, the label both validation
    examples give to (a=0, b=0), even though the training data labels that
    point 1.

    Prints a success or failure message; returns None.
    """
    data = [dict(a=1, b=0, Class=1), dict(a=1, b=1, Class=1),
            dict(a=0, b=1, Class=0), dict(a=0, b=0, Class=1)]
    validationData = [dict(a=1, b=0, Class=1), dict(a=1, b=1, Class=1),
                      dict(a=0, b=0, Class=0), dict(a=0, b=0, Class=0)]
    tree = ID3.ID3(data, 0)
    # Check for a missing tree BEFORE pruning, so a None result is reported
    # here rather than handed to ID3.prune (which presumably cannot handle it).
    if tree is None:
        print("pruning test failed -- no tree returned.")
        return
    # The return value of ID3.prune is ignored; this test assumes it prunes
    # `tree` in place.
    ID3.prune(tree, validationData)
    ans = ID3.evaluate(tree, dict(a=0, b=0))
    if ans != 0:
        print("pruning test failed.")
    else:
        print("pruning test succeeded.")
def testID3AndTest():
    """Check ID3.test accuracy on the training data and on held-out data.

    A tree grown on trainData is expected to score exactly 1.0 on trainData
    and exactly 0.75 on testData. Three of the four test examples share
    (a, b) and label with a training example. The fourth, (a=0, b=1), has
    the opposite label.

    Prints one message per check, then either a failure count or an overall
    success message; returns None.
    """
    trainData = [dict(a=1, b=0, c=0, Class=1), dict(a=1, b=1, c=0, Class=1),
                 dict(a=0, b=0, c=0, Class=0), dict(a=0, b=1, c=0, Class=1)]
    testData = [dict(a=1, b=0, c=1, Class=1), dict(a=1, b=1, c=1, Class=1),
                dict(a=0, b=0, c=1, Class=0), dict(a=0, b=1, c=1, Class=0)]
    tree = ID3.ID3(trainData, 0)
    if tree is None:
        # Message now matches the function's actual name (was "testID3andTest").
        print("testID3AndTest failed -- no tree returned.")
        return
    # (label used in messages, dataset, expected accuracy). Exact comparison
    # is safe here: 1.0 and 0.75 are exactly representable floats.
    checks = [("train", trainData, 1.0), ("test", testData, 0.75)]
    fails = 0
    for label, dataset, expected in checks:
        acc = ID3.test(tree, dataset)
        if acc == expected:
            print("testing on %s data succeeded." % label)
        else:
            print("testing on %s data failed." % label)
            fails += 1
    if fails > 0:
        print("Failures: %d" % fails)
    else:
        print("testID3AndTest succeeded.")
def testPruningOnHouseData(inFile, numTrials=100):
    """Compare test accuracy of pruned vs. unpruned ID3 trees on house data.

    Each trial proceeds in three steps:
      1. Shuffle the parsed data in place.
      2. Split it into train (first 25%), validation (next 50%) and test
         (last 25%).
      3. Build two trees and record the test accuracy of each:
         - grow a tree on train, print its train/validation/test accuracy,
           prune it with the validation set, and print the three accuracies
           again;
         - grow an unpruned tree on train+valid for comparison.

    After all trials, prints both accuracy lists and their averages.

    Args:
        inFile: string location of the house data file, read via parse.parse.
        numTrials: number of shuffle/split trials (default 100, the original
            hard-coded count).

    Returns None; all results are printed. The list returned by parse.parse
    is reordered in place by random.shuffle.
    """

    def report(label, value):
        # Reproduces the spacing of the old Python 2 `print label, value`.
        print(label + " " + str(value))

    withPruning = []
    withoutPruning = []
    data = parse.parse(inFile)
    n = len(data)
    # Shuffling never changes the length, so report it once, not per trial.
    report("length of data is:", n)  # 435 records expected for house data
    # Integer division (//) keeps slice bounds ints on Python 2 and 3;
    # `/` yields floats on Python 3 and slicing would raise TypeError.
    # NOTE(review): earlier inline comments described a 50/25/25 split, but
    # these bounds give 25% train / 50% validation / 25% test -- confirm
    # which split is intended.
    trainEnd = n // 4
    validEnd = 3 * n // 4
    for _ in range(numTrials):
        random.shuffle(data)
        train = data[:trainEnd]           # first 25%
        valid = data[trainEnd:validEnd]   # next 50%
        test = data[validEnd:]            # last 25%

        # Second argument to ID3.ID3 appears to be a default label (assumed).
        tree = ID3.ID3(train, 'democrat')
        report("training accuracy: ", ID3.test(tree, train))
        report("validation accuracy: ", ID3.test(tree, valid))
        report("test accuracy: ", ID3.test(tree, test))

        # Return value ignored: assumes ID3.prune modifies `tree` in place.
        ID3.prune(tree, valid)
        report("pruned tree train accuracy: ", ID3.test(tree, train))
        report("pruned tree validation accuracy: ", ID3.test(tree, valid))
        prunedTestAcc = ID3.test(tree, test)
        report("pruned tree test accuracy: ", prunedTestAcc)
        withPruning.append(prunedTestAcc)

        # Baseline: no pruning, so the validation data is used for training.
        unprunedTree = ID3.ID3(train + valid, 'democrat')
        unprunedTestAcc = ID3.test(unprunedTree, test)
        report("no pruning test accuracy: ", unprunedTestAcc)
        withoutPruning.append(unprunedTestAcc)

    print(withPruning)
    print(withoutPruning)
    if not withPruning:
        # numTrials <= 0: nothing to average (avoids ZeroDivisionError).
        return
    # float() avoids Python 2 integer truncation if accuracies are ints.
    avgPruned = sum(withPruning) / float(len(withPruning))
    avgUnpruned = sum(withoutPruning) / float(len(withoutPruning))
    print("average with pruning " + str(avgPruned)
          + "  without:  " + str(avgUnpruned))