-
Notifications
You must be signed in to change notification settings - Fork 13
/
plotting.py
57 lines (46 loc) · 1.82 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
def patch_plot(ax, x, y, yl, yu, label):
    """Draw a line with a shaded band around it on ``ax``.

    ``y`` is plotted against ``x`` as a thin line, and the region between
    ``yl`` and ``yu`` is filled with the same colour as that line.  Series
    whose label starts with ``'random'`` are drawn dashed and slightly
    transparent; every other series is drawn as a solid line.

    Parameters
    ----------
    ax : object
        Target axes; must provide ``plot`` and ``fill_between``.
    x, y : sequence
        Horizontal positions and the central curve.
    yl, yu : sequence
        Lower and upper edges of the shaded band.
    label : object
        Legend label for the line.  It is passed through unchanged; only the
        style decision looks at its string form.
    """
    # str() keeps non-string labels (e.g. integer dictionary keys) from
    # failing on .startswith; string labels behave exactly as before.
    is_random_series = str(label).startswith('random')
    if is_random_series:
        lines = ax.plot(x, y, '--', label=label, linewidth=1, alpha=0.6)
    else:
        lines = ax.plot(x, y, '-', label=label, linewidth=1)
    # The band reuses the line colour and gets an empty label of its own.
    band_colour = lines[0].get_color()
    ax.fill_between(x, yl, yu, label='', alpha=0.15, facecolor=band_colour)
def plot_comparison(n_evals, regrets, PLOT_TYPE='mean', pct_values=(25, 50, 75)):
    """Plot one summary curve with an uncertainty band per entry of ``regrets``.

    Everything is drawn on the current matplotlib axes (``plt.gca()``); the
    function labels both axes and adds a legend but does not call
    ``plt.show()``.

    Parameters
    ----------
    n_evals : int
        Number of points on the x-axis; curves are plotted against
        ``range(n_evals)``.
    regrets : mapping
        Maps a series label to a 2-D array.  Statistics are taken along
        axis 1, so axis 0 is the one plotted along the x-axis.
    PLOT_TYPE : str
        ``'mean'`` plots the mean along axis 1 with a band of one standard
        error (``std / sqrt(shape[1])``) either side.  Any other value plots
        percentiles instead.
    pct_values : sequence of three numbers
        Percentiles used when ``PLOT_TYPE`` is not ``'mean'``, in the order
        (lower band edge, central curve, upper band edge).
    """
    plt.xlabel('Iterations')
    ax = plt.gca()
    use_mean = PLOT_TYPE == 'mean'
    iterations = range(n_evals)

    for name, r in regrets.items():
        if use_mean:
            n_test = r.shape[1]
            regret = r.mean(axis=1)
            # Standard error of the mean along axis 1.
            std_err = r.std(axis=1) / np.sqrt(n_test)
            regret_low = regret - std_err
            regret_high = regret + std_err
        else:
            pcts = np.percentile(r, pct_values, axis=1)
            regret_low = pcts[0, :]
            regret = pcts[1, :]
            regret_high = pcts[2, :]
        patch_plot(ax, iterations, regret, regret_low, regret_high, label=name)

    plt.legend(loc='upper right')
    if use_mean:
        # Raw string: '\p' is not a valid escape sequence in a normal literal.
        plt.ylabel(r'Regret (mean $\pm$ SE)')
    else:
        plt.ylabel('Regret (quartiles)')
def compare_regrets(regrets):
    """Plot every regret series in ``regrets`` and display the figure.

    The number of x-axis points is taken from axis 0 of the first entry's
    array; the plot itself is produced by ``plot_comparison`` with its
    default settings, then shown with ``plt.show()``.
    """
    first_key = list(regrets.keys())[0]
    first_series = regrets[first_key]
    num_iterations = first_series.shape[0]
    plot_comparison(num_iterations, regrets)
    plt.show()
def compare_ranks(regrets, num_decimals=10):
    """Rank the series in ``regrets`` against each other and plot the ranks.

    The arrays in ``regrets`` are stacked along a new last axis, rounded to
    ``num_decimals`` decimals, and ranked along that axis with
    ``scipy.stats.rankdata``.  The resulting per-series rank arrays are
    plotted with ``plot_comparison`` and the figure is shown.

    Parameters
    ----------
    regrets : mapping
        Maps a series label to a 2-D array; all arrays must share one shape
        so that they can be stacked.
    num_decimals : int
        Number of decimals the values are rounded to before ranking, so that
        values equal up to that precision are ranked as ties.
    """
    names = list(regrets.keys())
    all_results = np.stack([regrets[name] for name in names], axis=2)
    all_results = np.around(all_results, num_decimals)
    ranks0 = np.apply_along_axis(st.rankdata, 2, all_results)

    # Keep each slice 2-D.  Squeezing would drop a singleton axis (a single
    # iteration or a single test problem) and plot_comparison, which reads
    # shape[1], would then fail.
    ranks = {name: ranks0[:, :, i] for i, name in enumerate(names)}

    n_evals = regrets[names[0]].shape[0]
    plot_comparison(n_evals, ranks)
    # Replace the 'Regret' y-label set by plot_comparison.  Raw string:
    # '\p' is not a valid escape sequence in a normal literal.
    plt.ylabel(r'Rank (mean $\pm$ SE)')
    plt.show()