Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Make RecordSet easier to use #3632

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,26 @@ class ErrorCode:
def __new__(cls) -> ErrorCode:
"""Prevent instantiation."""
raise TypeError(f"{cls.__name__} cannot be instantiated.")


class Record:
"""Keys for records in a RecordSet."""

PARAMS = "p"
METRICS = "m"
CONFIGS = "c"

def __new__(cls) -> Record:
"""Prevent instantiation."""
raise TypeError(f"{cls.__name__} cannot be instantiated.")


class Metric:
"""Keys for metrics in a MetricsRecord."""

NUM_EXAMPLES = "num_examples"
LOSS = "loss"

def __new__(cls) -> Metric:
"""Prevent instantiation."""
raise TypeError(f"{cls.__name__} cannot be instantiated.")
jafermarq marked this conversation as resolved.
Show resolved Hide resolved
163 changes: 125 additions & 38 deletions src/py/flwr/common/record/recordset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,63 @@
"""RecordSet."""


from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional, cast
from typing import Literal, TypeVar, cast, overload

from ..constant import Record
from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import ParametersRecord
from .typeddict import TypedDict

T = TypeVar("T", ParametersRecord, MetricsRecord, ConfigsRecord)


class _Checker:
def __init__(
self,
data: RecordSetData,
r_type: type[ParametersRecord] | type[MetricsRecord] | type[ConfigsRecord],
) -> None:
self.data = data
self._type = r_type

def check_key(self, key: str) -> None:
"""Check the validity of the key."""
data = self.data
if not isinstance(key, str):
raise TypeError(
f"Expected `{str.__name__}`, but "
f"received `{type(key).__name__}` for the key."
)

orig_value: ParametersRecord | MetricsRecord | ConfigsRecord | None = None
if key in data.parameters_records:
orig_value = data.parameters_records[key]
elif key in data.metrics_records:
orig_value = data.metrics_records[key]
elif key in data.configs_records:
orig_value = data.configs_records[key]

if orig_value is not None and not isinstance(orig_value, self._type):
raise TypeError(
f"Key '{key}' is already associated with "
f"a '{type(orig_value).__name__}', but a value "
f"of type '{self._type.__name__}' was provided."
)
jafermarq marked this conversation as resolved.
Show resolved Hide resolved

def check_value(
self, value: ParametersRecord | MetricsRecord | ConfigsRecord
) -> None:
"""Check the validity of the value."""
if not isinstance(value, self._type):
raise TypeError(
f"Expected `{self._type.__name__}`, but received "
f"`{type(value).__name__}` for the value."
)


@dataclass
class RecordSetData:
Expand All @@ -34,18 +83,21 @@ class RecordSetData:

def __init__(
self,
parameters_records: Optional[Dict[str, ParametersRecord]] = None,
metrics_records: Optional[Dict[str, MetricsRecord]] = None,
configs_records: Optional[Dict[str, ConfigsRecord]] = None,
parameters_records: dict[str, ParametersRecord] | None = None,
metrics_records: dict[str, MetricsRecord] | None = None,
configs_records: dict[str, ConfigsRecord] | None = None,
) -> None:
params_checker = _Checker(self, ParametersRecord)
metrics_checker = _Checker(self, MetricsRecord)
configs_checker = _Checker(self, ConfigsRecord)
self.parameters_records = TypedDict[str, ParametersRecord](
self._check_fn_str, self._check_fn_params
params_checker.check_key, params_checker.check_value
)
self.metrics_records = TypedDict[str, MetricsRecord](
self._check_fn_str, self._check_fn_metrics
metrics_checker.check_key, metrics_checker.check_value
)
self.configs_records = TypedDict[str, ConfigsRecord](
self._check_fn_str, self._check_fn_configs
configs_checker.check_key, configs_checker.check_value
)
if parameters_records is not None:
self.parameters_records.update(parameters_records)
Expand All @@ -54,43 +106,15 @@ def __init__(
if configs_records is not None:
self.configs_records.update(configs_records)

def _check_fn_str(self, key: str) -> None:
if not isinstance(key, str):
raise TypeError(
f"Expected `{str.__name__}`, but "
f"received `{type(key).__name__}` for the key."
)

def _check_fn_params(self, record: ParametersRecord) -> None:
if not isinstance(record, ParametersRecord):
raise TypeError(
f"Expected `{ParametersRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)

def _check_fn_metrics(self, record: MetricsRecord) -> None:
if not isinstance(record, MetricsRecord):
raise TypeError(
f"Expected `{MetricsRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)

def _check_fn_configs(self, record: ConfigsRecord) -> None:
if not isinstance(record, ConfigsRecord):
raise TypeError(
f"Expected `{ConfigsRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)


class RecordSet:
"""RecordSet stores groups of parameters, metrics and configs."""

def __init__(
self,
parameters_records: Optional[Dict[str, ParametersRecord]] = None,
metrics_records: Optional[Dict[str, MetricsRecord]] = None,
configs_records: Optional[Dict[str, ConfigsRecord]] = None,
parameters_records: dict[str, ParametersRecord] | None = None,
metrics_records: dict[str, MetricsRecord] | None = None,
configs_records: dict[str, ConfigsRecord] | None = None,
) -> None:
data = RecordSetData(
parameters_records=parameters_records,
Expand All @@ -117,6 +141,69 @@ def configs_records(self) -> TypedDict[str, ConfigsRecord]:
data = cast(RecordSetData, self.__dict__["_data"])
return data.configs_records

@overload
def __getitem__(self, key: Literal["p"]) -> ParametersRecord: ... # noqa: E704

@overload
def __getitem__(self, key: Literal["m"]) -> MetricsRecord: ... # noqa: E704

@overload
def __getitem__(self, key: Literal["c"]) -> ConfigsRecord: ... # noqa: E704

def __getitem__(self, key: str) -> ParametersRecord | MetricsRecord | ConfigsRecord:
"""Return the record for the specified key."""
# Initialize default *Record
data = cast(RecordSetData, self.__dict__["_data"])
if key == Record.PARAMS and key not in data.parameters_records:
data.parameters_records[key] = ParametersRecord()
elif key == Record.METRICS and key not in data.metrics_records:
data.metrics_records[key] = MetricsRecord()
elif key == Record.CONFIGS and key not in data.configs_records:
data.configs_records[key] = ConfigsRecord()

# Return the record
if key in data.parameters_records:
return data.parameters_records[key]
if key in data.metrics_records:
return data.metrics_records[key]
if key in data.configs_records:
return data.configs_records[key]
raise KeyError(key)

def __setitem__(
self, key: str, value: ParametersRecord | MetricsRecord | ConfigsRecord
) -> None:
"""Set the record for the specified key."""
data = cast(RecordSetData, self.__dict__["_data"])
builtin_key_name: str | None = None
record_type: (
type[ParametersRecord] | type[MetricsRecord] | type[ConfigsRecord] | None
) = None
if key == Record.PARAMS:
builtin_key_name, record_type = "Record.PARAMS", ParametersRecord
elif key == Record.METRICS:
builtin_key_name, record_type = "Record.METRICS", MetricsRecord
elif key == Record.CONFIGS:
builtin_key_name, record_type = "Record.CONFIGS", ConfigsRecord
if record_type is not None and not isinstance(value, record_type):
raise TypeError(
f"Expected value of type `{record_type.__name__}` for built-in key "
f"`{builtin_key_name}`, but received type `{type(value).__name__}`."
)

if isinstance(value, ParametersRecord):
data.parameters_records[key] = value
elif isinstance(value, MetricsRecord):
data.metrics_records[key] = value
elif isinstance(value, ConfigsRecord):
data.configs_records[key] = value
else:
raise TypeError(
f"Expected {{`{ParametersRecord.__name__}`, "
f"`{MetricsRecord.__name__}`, `{ConfigsRecord.__name__}`}}, "
f"but received `{type(value).__name__}` for the value."
)

def __repr__(self) -> str:
"""Return a string representation of this instance."""
flds = ("parameters_records", "metrics_records", "configs_records")
Expand Down
29 changes: 28 additions & 1 deletion src/py/flwr/common/record/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import pickle
from collections import namedtuple
from copy import deepcopy
from typing import Callable, Dict, List, OrderedDict, Type, Union
from typing import Any, Callable, Dict, List, OrderedDict, Type, Union

import numpy as np
import pytest

from flwr.common.constant import Record
from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common.recordset_compat import (
parameters_to_parametersrecord,
Expand Down Expand Up @@ -430,3 +431,29 @@ def test_recordset_repr() -> None:

# Assert
assert str(rs) == str(expected)


@pytest.mark.parametrize(
"rs, key, value, exc_type",
[
# Invalid key type
(RecordSet(), 12, ConfigsRecord(), TypeError),
# Invalid value type
(RecordSet(), "resnet", 0.9, TypeError),
# Invalid record type for built-in keys
(RecordSet(), Record.PARAMS, MetricsRecord(), TypeError),
(RecordSet(), Record.PARAMS, ConfigsRecord(), TypeError),
(RecordSet(), Record.METRICS, ParametersRecord(), TypeError),
(RecordSet(), Record.METRICS, ConfigsRecord(), TypeError),
(RecordSet(), Record.CONFIGS, ParametersRecord(), TypeError),
(RecordSet(), Record.CONFIGS, MetricsRecord(), TypeError),
# Invalid reassignment (different record types)
(RecordSet({"model": ParametersRecord()}), "model", ConfigsRecord(), TypeError),
],
)
def test_invalid_assignment(
rs: RecordSet, key: str, value: Any, exc_type: Type[Exception]
) -> None:
"""Test invalid assignment to RecordSet."""
with pytest.raises(exc_type):
rs[key] = value