diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index f1495958945..3f828098482 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -106,3 +106,26 @@ class ErrorCode: def __new__(cls) -> ErrorCode: """Prevent instantiation.""" raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class Record: + """Keys for records in a RecordSet.""" + + PARAMS = "p" + METRICS = "m" + CONFIGS = "c" + + def __new__(cls) -> Record: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class Metric: + """Keys for metrics in a MetricsRecord.""" + + NUM_EXAMPLES = "num_examples" + LOSS = "loss" + + def __new__(cls) -> Metric: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py index 74eed46ad86..9c17152c81a 100644 --- a/src/py/flwr/common/record/recordset.py +++ b/src/py/flwr/common/record/recordset.py @@ -15,14 +15,63 @@ """RecordSet.""" +from __future__ import annotations + from dataclasses import dataclass -from typing import Dict, Optional, cast +from typing import Literal, TypeVar, cast, overload +from ..constant import Record from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parametersrecord import ParametersRecord from .typeddict import TypedDict +T = TypeVar("T", ParametersRecord, MetricsRecord, ConfigsRecord) + + +class _Checker: + def __init__( + self, + data: RecordSetData, + r_type: type[ParametersRecord] | type[MetricsRecord] | type[ConfigsRecord], + ) -> None: + self.data = data + self._type = r_type + + def check_key(self, key: str) -> None: + """Check the validity of the key.""" + data = self.data + if not isinstance(key, str): + raise TypeError( + f"Expected `{str.__name__}`, but " + f"received `{type(key).__name__}` for the key." + ) + + orig_value: ParametersRecord | MetricsRecord | ConfigsRecord | None = None + if key in data.parameters_records: + orig_value = data.parameters_records[key] + elif key in data.metrics_records: + orig_value = data.metrics_records[key] + elif key in data.configs_records: + orig_value = data.configs_records[key] + + if orig_value is not None and not isinstance(orig_value, self._type): + raise TypeError( + f"Key '{key}' is already associated with " + f"a '{type(orig_value).__name__}', but a value " + f"of type '{self._type.__name__}' was provided." + ) + + def check_value( + self, value: ParametersRecord | MetricsRecord | ConfigsRecord + ) -> None: + """Check the validity of the value.""" + if not isinstance(value, self._type): + raise TypeError( + f"Expected `{self._type.__name__}`, but received " + f"`{type(value).__name__}` for the value." + ) + @dataclass class RecordSetData: @@ -34,18 +83,21 @@ class RecordSetData: def __init__( self, - parameters_records: Optional[Dict[str, ParametersRecord]] = None, - metrics_records: Optional[Dict[str, MetricsRecord]] = None, - configs_records: Optional[Dict[str, ConfigsRecord]] = None, + parameters_records: dict[str, ParametersRecord] | None = None, + metrics_records: dict[str, MetricsRecord] | None = None, + configs_records: dict[str, ConfigsRecord] | None = None, ) -> None: + params_checker = _Checker(self, ParametersRecord) + metrics_checker = _Checker(self, MetricsRecord) + configs_checker = _Checker(self, ConfigsRecord) self.parameters_records = TypedDict[str, ParametersRecord]( - self._check_fn_str, self._check_fn_params + params_checker.check_key, params_checker.check_value ) self.metrics_records = TypedDict[str, MetricsRecord]( - self._check_fn_str, self._check_fn_metrics + metrics_checker.check_key, metrics_checker.check_value ) self.configs_records = TypedDict[str, ConfigsRecord]( - self._check_fn_str, self._check_fn_configs + configs_checker.check_key, configs_checker.check_value ) if parameters_records is not None: self.parameters_records.update(parameters_records) @@ -54,43 +106,15 @@ def __init__( if configs_records is not None: self.configs_records.update(configs_records) - def _check_fn_str(self, key: str) -> None: - if not isinstance(key, str): - raise TypeError( - f"Expected `{str.__name__}`, but " - f"received `{type(key).__name__}` for the key." - ) - - def _check_fn_params(self, record: ParametersRecord) -> None: - if not isinstance(record, ParametersRecord): - raise TypeError( - f"Expected `{ParametersRecord.__name__}`, but " - f"received `{type(record).__name__}` for the value." - ) - - def _check_fn_metrics(self, record: MetricsRecord) -> None: - if not isinstance(record, MetricsRecord): - raise TypeError( - f"Expected `{MetricsRecord.__name__}`, but " - f"received `{type(record).__name__}` for the value." - ) - - def _check_fn_configs(self, record: ConfigsRecord) -> None: - if not isinstance(record, ConfigsRecord): - raise TypeError( - f"Expected `{ConfigsRecord.__name__}`, but " - f"received `{type(record).__name__}` for the value." - ) - class RecordSet: """RecordSet stores groups of parameters, metrics and configs.""" def __init__( self, - parameters_records: Optional[Dict[str, ParametersRecord]] = None, - metrics_records: Optional[Dict[str, MetricsRecord]] = None, - configs_records: Optional[Dict[str, ConfigsRecord]] = None, + parameters_records: dict[str, ParametersRecord] | None = None, + metrics_records: dict[str, MetricsRecord] | None = None, + configs_records: dict[str, ConfigsRecord] | None = None, ) -> None: data = RecordSetData( parameters_records=parameters_records, @@ -117,6 +141,69 @@ def configs_records(self) -> TypedDict[str, ConfigsRecord]: data = cast(RecordSetData, self.__dict__["_data"]) return data.configs_records + @overload + def __getitem__(self, key: Literal["p"]) -> ParametersRecord: ... # noqa: E704 + + @overload + def __getitem__(self, key: Literal["m"]) -> MetricsRecord: ... # noqa: E704 + + @overload + def __getitem__(self, key: Literal["c"]) -> ConfigsRecord: ... # noqa: E704 + + def __getitem__(self, key: str) -> ParametersRecord | MetricsRecord | ConfigsRecord: + """Return the record for the specified key.""" + # Initialize default *Record + data = cast(RecordSetData, self.__dict__["_data"]) + if key == Record.PARAMS and key not in data.parameters_records: + data.parameters_records[key] = ParametersRecord() + elif key == Record.METRICS and key not in data.metrics_records: + data.metrics_records[key] = MetricsRecord() + elif key == Record.CONFIGS and key not in data.configs_records: + data.configs_records[key] = ConfigsRecord() + + # Return the record + if key in data.parameters_records: + return data.parameters_records[key] + if key in data.metrics_records: + return data.metrics_records[key] + if key in data.configs_records: + return data.configs_records[key] + raise KeyError(key) + + def __setitem__( + self, key: str, value: ParametersRecord | MetricsRecord | ConfigsRecord + ) -> None: + """Set the record for the specified key.""" + data = cast(RecordSetData, self.__dict__["_data"]) + builtin_key_name: str | None = None + record_type: ( + type[ParametersRecord] | type[MetricsRecord] | type[ConfigsRecord] | None + ) = None + if key == Record.PARAMS: + builtin_key_name, record_type = "Record.PARAMS", ParametersRecord + elif key == Record.METRICS: + builtin_key_name, record_type = "Record.METRICS", MetricsRecord + elif key == Record.CONFIGS: + builtin_key_name, record_type = "Record.CONFIGS", ConfigsRecord + if record_type is not None and not isinstance(value, record_type): + raise TypeError( + f"Expected value of type `{record_type.__name__}` for built-in key " + f"`{builtin_key_name}`, but received type `{type(value).__name__}`." + ) + + if isinstance(value, ParametersRecord): + data.parameters_records[key] = value + elif isinstance(value, MetricsRecord): + data.metrics_records[key] = value + elif isinstance(value, ConfigsRecord): + data.configs_records[key] = value + else: + raise TypeError( + f"Expected {{`{ParametersRecord.__name__}`, " + f"`{MetricsRecord.__name__}`, `{ConfigsRecord.__name__}`}}, " + f"but received `{type(value).__name__}` for the value." + ) + def __repr__(self) -> str: """Return a string representation of this instance.""" flds = ("parameters_records", "metrics_records", "configs_records") diff --git a/src/py/flwr/common/record/recordset_test.py b/src/py/flwr/common/record/recordset_test.py index 01260793cb4..91507ceaa2f 100644 --- a/src/py/flwr/common/record/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -17,11 +17,12 @@ import pickle from collections import namedtuple from copy import deepcopy -from typing import Callable, Dict, List, OrderedDict, Type, Union +from typing import Any, Callable, Dict, List, OrderedDict, Type, Union import numpy as np import pytest +from flwr.common.constant import Record from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays from flwr.common.recordset_compat import ( parameters_to_parametersrecord, @@ -430,3 +431,29 @@ def test_recordset_repr() -> None: # Assert assert str(rs) == str(expected) + + +@pytest.mark.parametrize( + "rs, key, value, exc_type", + [ + # Invalid key type + (RecordSet(), 12, ConfigsRecord(), TypeError), + # Invalid value type + (RecordSet(), "resnet", 0.9, TypeError), + # Invalid record type for built-in keys + (RecordSet(), Record.PARAMS, MetricsRecord(), TypeError), + (RecordSet(), Record.PARAMS, ConfigsRecord(), TypeError), + (RecordSet(), Record.METRICS, ParametersRecord(), TypeError), + (RecordSet(), Record.METRICS, ConfigsRecord(), TypeError), + (RecordSet(), Record.CONFIGS, ParametersRecord(), TypeError), + (RecordSet(), Record.CONFIGS, MetricsRecord(), TypeError), + # Invalid reassignment (different record types) + (RecordSet({"model": ParametersRecord()}), "model", ConfigsRecord(), TypeError), + ], +) +def test_invalid_assignment( + rs: RecordSet, key: str, value: Any, exc_type: Type[Exception] +) -> None: + """Test invalid assignment to RecordSet.""" + with pytest.raises(exc_type): + rs[key] = value