Skip to content

Commit

Permalink
Add type-hints to adaptive/learner/base_learner.py
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Oct 12, 2022
1 parent 1b7e84d commit b91ef59
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions adaptive/learner/base_learner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import abc
from contextlib import suppress
from typing import Any, Callable

import cloudpickle

from adaptive.utils import _RequireAttrsABCMeta, load, save


def uses_nth_neighbors(n: int):
def uses_nth_neighbors(n: int) -> Callable[[int], Callable[[BaseLearner], float]]:
"""Decorator to specify how many neighboring intervals the loss function uses.
Wraps loss functions to indicate that they expect intervals together
Expand Down Expand Up @@ -53,7 +56,9 @@ def uses_nth_neighbors(n: int):
... return loss
"""

def _wrapped(loss_per_interval):
def _wrapped(
loss_per_interval: Callable[[BaseLearner], float]
) -> Callable[[BaseLearner], float]:
loss_per_interval.nth_neighbors = n
return loss_per_interval

Expand Down Expand Up @@ -82,10 +87,15 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
"""

data: dict
npoints: int
pending_points: set
function: Callable

@property
@abc.abstractmethod
def npoints(self) -> int:
"""Number of learned points."""

def tell(self, x, y):
def tell(self, x: Any, y: Any) -> None:
"""Tell the learner about a single value.
Parameters
Expand All @@ -95,7 +105,7 @@ def tell(self, x, y):
"""
self.tell_many([x], [y])

def tell_many(self, xs, ys):
def tell_many(self, xs: Any, ys: Any) -> None:
"""Tell the learner about some values.
Parameters
Expand All @@ -107,16 +117,16 @@ def tell_many(self, xs, ys):
self.tell(x, y)

@abc.abstractmethod
def tell_pending(self, x):
def tell_pending(self, x: Any) -> None:
"""Tell the learner that 'x' has been requested such
that it's not suggested again."""

@abc.abstractmethod
def remove_unfinished(self):
def remove_unfinished(self) -> None:
"""Remove uncomputed data from the learner."""

@abc.abstractmethod
def loss(self, real=True):
def loss(self, real: bool = True) -> float:
"""Return the loss for the current state of the learner.
Parameters
Expand All @@ -128,7 +138,7 @@ def loss(self, real=True):
"""

@abc.abstractmethod
def ask(self, n, tell_pending=True):
def ask(self, n: int, tell_pending: bool = True):
"""Choose the next 'n' points to evaluate.
Parameters
Expand All @@ -142,11 +152,11 @@ def ask(self, n, tell_pending=True):
"""

@abc.abstractmethod
def _get_data(self):
def _get_data(self) -> Any:
pass

@abc.abstractmethod
def _set_data(self):
def _set_data(self, data: Any):
pass

@abc.abstractmethod
Expand All @@ -164,7 +174,7 @@ def copy_from(self, other):
"""
self._set_data(other._get_data())

def save(self, fname, compress=True):
def save(self, fname: str, compress: bool = True) -> None:
"""Save the data of the learner into a pickle file.
Parameters
Expand All @@ -178,7 +188,7 @@ def save(self, fname, compress=True):
data = self._get_data()
save(fname, data, compress)

def load(self, fname, compress=True):
def load(self, fname: str, compress: bool = True) -> None:
"""Load the data of a learner from a pickle file.
Parameters
Expand All @@ -193,8 +203,8 @@ def load(self, fname, compress=True):
data = load(fname, compress)
self._set_data(data)

def __getstate__(self):
def __getstate__(self) -> bytes:
return cloudpickle.dumps(self.__dict__)

def __setstate__(self, state):
def __setstate__(self, state: bytes) -> None:
self.__dict__ = cloudpickle.loads(state)

0 comments on commit b91ef59

Please sign in to comment.