Skip to content

Commit

Permalink
add comments and make methods private
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Mar 7, 2019
1 parent d0dab50 commit 51f4292
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions adaptive/learner/average_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,23 @@ def data_sem(self):
def mean_values_per_point(self):
return np.mean([x.n for x in self._data.values()])

def get_seed(self, point):
def _next_seed(self, point):
_data = self._data.get(point, {})
pending_seeds = self.pending_points.get(point, set())
seed = len(_data) + len(pending_seeds)
if seed in _data or seed in pending_seeds:
# means that the seed already exists, for example
# Means that the seed already exists, for example
# when '_data[point].keys() | pending_points[point] == {0, 2}'.
# Only happens when starting the learner after cancelling/loading.
return (set(range(seed)) - pending_seeds - _data.keys()).pop()
return seed

def loss_per_existing_point(self):
scale = self.value_scale()

points = []
loss_improvements = []
for p, sem in self.data_sem.items():
points.append((p, self.get_seed(p)))
points.append((p, self._next_seed(p)))
N = self.n_values(p)
sem_improvement = (1 - sqrt(N - 1) / sqrt(N)) * sem
loss_improvement = self.weight * sem_improvement / scale
Expand Down Expand Up @@ -102,8 +102,12 @@ def _mean_values_per_neighbor(self, neighbors):
for p, ns in neighbors.items()}

def _normalize_new_points_loss_improvements(self, points, loss_improvements):
"""If we are suggesting a new point, then its 'loss_improvement' should
be divided by the average number of values of its neigbors."""
"""If we are suggesting a new (not yet suggested) point, then its
'loss_improvement' should be divided by the average number of values
of its neigbors.
This is because it will take a similar amount of points to reach
that loss. """
if len(self._data) < 4:
return loss_improvements

Expand All @@ -116,7 +120,10 @@ def _normalize_new_points_loss_improvements(self, points, loss_improvements):

def _normalize_existing_points_loss_improvements(self, points, loss_improvements):
"""If the neighbors of 'point' have twice as much values
on average, then that 'point' should have an infinite loss."""
on average, then that 'point' should have an infinite loss.
We do this because this point probably has a incorrect
estimate of the sem."""
if len(self._data) < 4:
return loss_improvements

Expand All @@ -136,7 +143,7 @@ def _get_data(self):

def add_average_mixin(cls):
names = ('data', 'data_sem', 'mean_values_per_point',
'get_seed', 'loss_per_existing_point', '_add_to_pending',
'_next_seed', 'loss_per_existing_point', '_add_to_pending',
'_remove_from_to_pending', '_add_to_data', 'ask', 'n_values',
'_normalize_new_points_loss_improvements',
'_normalize_existing_points_loss_improvements',
Expand Down

0 comments on commit 51f4292

Please sign in to comment.