Skip to content

Commit

Permalink
add comments and make methods private
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Mar 7, 2019
1 parent d0dab50 commit 5f3aeba
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions adaptive/learner/average_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ def data_sem(self):
def mean_values_per_point(self):
return np.mean([x.n for x in self._data.values()])

def get_seed(self, point):
def _next_seed(self, point):
_data = self._data.get(point, {})
pending_seeds = self.pending_points.get(point, set())
seed = len(_data) + len(pending_seeds)
if seed in _data or seed in pending_seeds:
# means that the seed already exists, for example
# Means that the seed already exists, for example
# when '_data[point].keys() | pending_points[point] == {0, 2}'.
# Only happens when starting the learner after cancelling/loading.
return (set(range(seed)) - pending_seeds - _data.keys()).pop()
return seed

Expand All @@ -42,7 +43,7 @@ def loss_per_existing_point(self):
points = []
loss_improvements = []
for p, sem in self.data_sem.items():
points.append((p, self.get_seed(p)))
points.append((p, self._next_seed(p)))
N = self.n_values(p)
sem_improvement = (1 - sqrt(N - 1) / sqrt(N)) * sem
loss_improvement = self.weight * sem_improvement / scale
Expand Down Expand Up @@ -136,7 +137,7 @@ def _get_data(self):

def add_average_mixin(cls):
names = ('data', 'data_sem', 'mean_values_per_point',
'get_seed', 'loss_per_existing_point', '_add_to_pending',
'_next_seed', 'loss_per_existing_point', '_add_to_pending',
'_remove_from_to_pending', '_add_to_data', 'ask', 'n_values',
'_normalize_new_points_loss_improvements',
'_normalize_existing_points_loss_improvements',
Expand Down

0 comments on commit 5f3aeba

Please sign in to comment.