diff --git a/adaptive/learner/average_mixin.py b/adaptive/learner/average_mixin.py
index 1fb8b608d..27e29f6bc 100644
--- a/adaptive/learner/average_mixin.py
+++ b/adaptive/learner/average_mixin.py
@@ -26,31 +26,18 @@ def data_sem(self):
 
     def mean_values_per_point(self):
         return np.mean([x.n for x in self._data.values()])
 
-    def _next_seed(self, point):
+    def _next_seed(self, point, exclude=None):
+        exclude = set(exclude) if exclude is not None else set()
         _data = self._data.get(point, {})
         pending_seeds = self.pending_points.get(point, set())
-        seed = len(_data) + len(pending_seeds)
-        if seed in _data or seed in pending_seeds:
+        seed = len(_data) + len(pending_seeds) + len(exclude)
+        if seed in _data or seed in pending_seeds or seed in exclude:
             # Means that the seed already exists, for example
             # when '_data[point].keys() | pending_points[point] == {0, 2}'.
             # Only happens when starting the learner after cancelling/loading.
-            return (set(range(seed)) - pending_seeds - _data.keys()).pop()
+            return (set(range(seed)) - pending_seeds - _data.keys() - exclude).pop()
         return seed
 
-    def loss_per_existing_point(self):
-        scale = self.value_scale()
-        points = []
-        loss_improvements = []
-        for p, sem in self.data_sem.items():
-            points.append((p, self._next_seed(p)))
-            N = self.n_values(p)
-            sem_improvement = (1 - sqrt(N - 1) / sqrt(N)) * sem
-            loss_improvement = self.weight * sem_improvement / scale
-            loss_improvements.append(loss_improvement)
-        loss_improvements = self._normalize_existing_points_loss_improvements(
-            points, loss_improvements)
-        return points, loss_improvements
-
     def _add_to_pending(self, point):
         x, seed = self.unpack_point(point)
         if x not in self.pending_points:
@@ -74,18 +61,35 @@ def _add_to_data(self, point, value):
 
     def ask(self, n, tell_pending=True):
         """Return n points that are expected to maximally reduce the loss."""
         points, loss_improvements = [], []
-        self._fill_seed_stack(till=n)
         # Take from the _seed_stack if there are any points.
+        self._fill_seed_stack(till=n)
         for i in range(n):
-            point, loss_improvement = self._seed_stack[i]
-            points.append(point)
-            loss_improvements.append(loss_improvement)
+            exclude_seeds = set()
+            (point, nseeds), loss_improvement = self._seed_stack[i]
+            for j in range(nseeds):
+                seed = self._next_seed(point, exclude_seeds)
+                exclude_seeds.add(seed)
+                points.append((point, seed))
+                loss_improvements.append(loss_improvement / nseeds)
+                if len(points) >= n:
+                    break
+            if len(points) >= n:
+                break
         if tell_pending:
             for p in points:
                 self.tell_pending(p)
-        self._seed_stack = self._seed_stack[n:]
+        nseeds_left = nseeds - j - 1  # of self._seed_stack[i]
+        if nseeds_left > 0:  # not all seeds have been asked
+            (point, nseeds), loss_improvement = self._seed_stack[i]
+            self._seed_stack[i] = (
+                (point, nseeds_left),
+                loss_improvement * nseeds_left / nseeds
+            )
+            self._seed_stack = self._seed_stack[i:]
+        else:
+            self._seed_stack = self._seed_stack[i+1:]
         return points, loss_improvements
 
@@ -94,23 +98,29 @@ def _fill_seed_stack(self, till):
         if n < 1:
             return
         points, loss_improvements = [], []
-        new_points, new_points_loss_improvements = (
-            self._ask_points_without_adding(n))
-        loss_improvements += self._normalize_new_points_loss_improvements(
-            new_points, new_points_loss_improvements)
+
+        new_points, new_points_loss_improvements = \
+            self.loss_per_new_point(n)
+
+        loss_improvements += new_points_loss_improvements
         points += new_points
 
         existing_points, existing_points_loss_improvements = \
             self.loss_per_existing_point()
+
         points += existing_points
         loss_improvements += existing_points_loss_improvements
         loss_improvements, points = zip(*sorted(
             zip(loss_improvements, points), reverse=True))
 
-        points = list(points)[:n]
-        loss_improvements = list(loss_improvements)[:n]
-        self._seed_stack += list(zip(points, loss_improvements))
+        n_left = n
+        for loss_improvement, (point, nseeds) in zip(
+                loss_improvements, points):
+            self._seed_stack.append(((point, nseeds), loss_improvement))
+            n_left -= nseeds
+            if n_left <= 0:
+                break
 
     def n_values(self, point):
         pending_points = self.pending_points.get(point, [])
@@ -121,40 +131,49 @@ def _mean_values_per_neighbor(self, neighbors):
         return {p: sum(self.n_values(n) for n in ns) / len(ns)
                 for p, ns in neighbors.items()}
 
-    def _normalize_new_points_loss_improvements(self, points, loss_improvements):
-        """If we are suggesting a new (not yet suggested) point, then its
-        'loss_improvement' should be divided by the average number of values
-        of its neigbors.
-
-        This is because it will take a similar amount of points to reach
-        that loss. """
+    def loss_per_new_point(self, n):
+        """Add new points with at least self.min_values_per_point points
+        or with as many points as the neighbors have on average."""
+        points, loss_improvements = self._ask_points_without_adding(n)
         if len(self._data) < 4:
-            return loss_improvements
+            points = [(p, self.min_values_per_point) for p, s in points]
+            return points, loss_improvements
 
-        only_points = [p for p, s in points]
+        only_points = [p for p, s in points]  # points are [(x, seed), ...]
         neighbors = self._get_neighbor_mapping_new_points(only_points)
         mean_values_per_neighbor = self._mean_values_per_neighbor(neighbors)
 
-        return [loss / mean_values_per_neighbor[p]
-                for (p, seed), loss in zip(points, loss_improvements)]
+        points = []
+        for p in only_points:
+            n_neighbors = int(mean_values_per_neighbor[p])
+            nseeds = max(n_neighbors, self.min_values_per_point)
+            points.append((p, nseeds))
 
-    def _normalize_existing_points_loss_improvements(self, points, loss_improvements):
-        """If the neighbors of 'point' have twice as much values
-        on average, then that 'point' should have an infinite loss.
+        return points, loss_improvements
 
-        We do this because this point probably has a incorrect
-        estimate of the sem."""
-        if len(self._data) < 4:
-            return loss_improvements
+    def loss_per_existing_point(self):
+        """Increase the number of seeds by 10%."""
+        if len(self.data) < 4:
+            return [], []
+        scale = self.value_scale()
+        points = []
+        loss_improvements = []
         neighbors = self._get_neighbor_mapping_existing_points()
         mean_values_per_neighbor = self._mean_values_per_neighbor(neighbors)
 
-        def needs_more_data(p):
-            return mean_values_per_neighbor[p] > 1.5 * self.n_values(p)
-
-        return [inf if needs_more_data(p) else loss
-                for (p, seed), loss in zip(points, loss_improvements)]
+        for p, sem in self.data_sem.items():
+            n_neighbors = mean_values_per_neighbor[p]
+            N = self.n_values(p)
+            n_more = int(0.1 * N)  # increase the amount of points by 10%
+            n_more = max(n_more, 1)  # at least 1 point
+            points.append((p, n_more))
+            # This is the improvement considering we will add
+            # n_more seeds to the stack.
+            sem_improvement = (1 / sqrt(N) - 1 / sqrt(N + n_more)) * sem
+            loss_improvement = self.weight * sem_improvement / scale  # XXX: Do I need to divide by the scale?
+            loss_improvements.append(loss_improvement)
+        return points, loss_improvements
 
     def _get_data(self):
         # change DataPoint -> dict for saving
@@ -165,9 +184,7 @@ def add_average_mixin(cls):
     names = ('data', 'data_sem', 'mean_values_per_point', '_next_seed',
             'loss_per_existing_point', '_add_to_pending',
             '_remove_from_to_pending', '_add_to_data', 'ask', 'n_values',
-            '_normalize_new_points_loss_improvements',
-            '_normalize_existing_points_loss_improvements',
-            '_mean_values_per_neighbor',
+            'loss_per_new_point', '_mean_values_per_neighbor',
            '_get_data', '_fill_seed_stack')

    for name in names:
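
Reviewer sketch (not part of the patch). The heart of the change is that '_seed_stack' entries are now '((point, nseeds), loss_improvement)' pairs, which 'ask' expands into individual '(point, seed)' tuples using the new 'exclude' argument of '_next_seed', so that several fresh seeds can be handed out for one point before any of them is registered as pending. The standalone snippet below reproduces that bookkeeping; the 'SeedDemo' class and its attribute values are illustrative stand-ins for the mixin's '_data' and 'pending_points', not code from this repository.

    class SeedDemo:
        def __init__(self):
            # point -> {seed: value}; seeds 0 and 2 exist (e.g. after
            # cancelling/loading), so seed 1 is a hole that must be reused.
            self._data = {0.5: {0: 1.1, 2: 0.9}}
            # point -> seeds that were asked but not yet told
            self.pending_points = {0.5: {3}}

        def _next_seed(self, point, exclude=None):
            # Same logic as the patched method above.
            exclude = set(exclude) if exclude is not None else set()
            _data = self._data.get(point, {})
            pending_seeds = self.pending_points.get(point, set())
            seed = len(_data) + len(pending_seeds) + len(exclude)
            if seed in _data or seed in pending_seeds or seed in exclude:
                # A lower seed is free: reuse the hole.
                return (set(range(seed)) - pending_seeds
                        - _data.keys() - exclude).pop()
            return seed

    demo = SeedDemo()
    # Expand one stack entry, ((point, nseeds), loss_improvement), as ask() does.
    (point, nseeds), loss_improvement = ((0.5, 3), 1.5)
    exclude_seeds, points, loss_improvements = set(), [], []
    for _ in range(nseeds):
        seed = demo._next_seed(point, exclude_seeds)
        exclude_seeds.add(seed)
        points.append((point, seed))
        # The entry's improvement is split evenly over its nseeds requests.
        loss_improvements.append(loss_improvement / nseeds)

    print(points)             # [(0.5, 1), (0.5, 4), (0.5, 5)]
    print(loss_improvements)  # [0.5, 0.5, 0.5]

The hole (seed 1) is filled first; the growing 'exclude' set then bumps the candidate seed past the data, pending, and already-excluded seeds, yielding 4 and 5.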