Skip to content

Commit

Permalink
Implement using generator
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Jun 3, 2024
1 parent 3523343 commit aee9e17
Showing 1 changed file with 35 additions and 23 deletions.
58 changes: 35 additions & 23 deletions adaptive/learner/balancing_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools
import sys
from collections import defaultdict
from collections.abc import Iterable, Sequence
from collections.abc import Generator, Iterable, Sequence
from contextlib import suppress
from functools import partial
from operator import itemgetter
Expand Down Expand Up @@ -126,11 +126,10 @@ def __init__(
self._cdims_default = cdims

if len({learner.__class__ for learner in self.learners}) > 1:
raise TypeError(
"A BalacingLearner can handle only one type" " of learners."
)
raise TypeError("A BalacingLearner can handle only one type of learners.")

self.strategy: STRATEGY_TYPE = strategy
self._gen: Generator | None = None

def new(self) -> BalancingLearner:
"""Create a new `BalancingLearner` with the same parameters."""
Expand Down Expand Up @@ -288,27 +287,16 @@ def _ask_and_tell_based_on_cycle(
def _ask_and_tell_based_on_sequential(
self, n: int
) -> tuple[list[tuple[Int, Any]], list[float]]:
if self._gen is None:
self._gen = _sequential_generator(self.learners)
points: list[tuple[Int, Any]] = []
loss_improvements: list[float] = []
learner_index = 0

while len(points) < n:
learner = self.learners[learner_index]
if learner.done(): # type: ignore[attr-defined]
if learner_index == len(self.learners) - 1:
break
learner_index += 1
continue

point, loss_improvement = learner.ask(n=1)
if not point: # if learner is exhausted, we don't get points
if learner_index == len(self.learners) - 1:
break
learner_index += 1
continue
points.append((learner_index, point[0]))
loss_improvements.append(loss_improvement[0])
self.tell_pending((learner_index, point[0]))
for learner_index, point, loss_improvement in self._gen:
points.append((learner_index, point))
loss_improvements.append(loss_improvement)
self.tell_pending((learner_index, point))
if len(points) >= n:
break

return points, loss_improvements

Expand Down Expand Up @@ -629,3 +617,27 @@ def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
learners, cdims, strategy = state
self.__init__(learners, cdims=cdims, strategy=strategy) # type: ignore[misc]


def _sequential_generator(
learners: list[BaseLearner],
) -> Generator[tuple[int, Any, float], None, None]:
learner_index = 0
if not hasattr(learners[0], "done"):
msg = "All learners must have a `done` method to use the 'sequential' strategy."
raise ValueError(msg)
while True:
learner = learners[learner_index]
if learner.done(): # type: ignore[attr-defined]
if learner_index == len(learners) - 1:
return
learner_index += 1
continue

point, loss_improvement = learner.ask(n=1)
if not point: # if learner is exhausted, we don't get points
if learner_index == len(learners) - 1:
return
learner_index += 1
continue
yield learner_index, point[0], loss_improvement[0]

0 comments on commit aee9e17

Please sign in to comment.