From aafd968f7b2868f4a6dc1cb2212072b8e56404b0 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 11 Jun 2024 11:18:02 +0100 Subject: [PATCH 1/2] update rs data --- src/py/flwr/common/constant.py | 23 ++++ src/py/flwr/common/record/recordset.py | 145 ++++++++++++++++++------- 2 files changed, 130 insertions(+), 38 deletions(-) diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index b6d39b6e893..b5625ea2e50 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -88,3 +88,26 @@ class ErrorCode: def __new__(cls) -> ErrorCode: """Prevent instantiation.""" raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class Record: + """Keys for records in a RecordSet.""" + + PARAMS = "p" + METRICS = "m" + CONFIGS = "c" + + def __new__(cls) -> Record: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class Metric: + """Keys for metrics in a MetricsRecord.""" + + NUM_EXAMPLES = "num_examples" + LOSS = "loss" + + def __new__(cls) -> Metric: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py index 74eed46ad86..46158b3f923 100644 --- a/src/py/flwr/common/record/recordset.py +++ b/src/py/flwr/common/record/recordset.py @@ -15,14 +15,61 @@ """RecordSet.""" +from __future__ import annotations + from dataclasses import dataclass -from typing import Dict, Optional, cast +from typing import Literal, TypeVar, cast, overload +from ..constant import Record from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parametersrecord import ParametersRecord from .typeddict import TypedDict +T = TypeVar("T", ParametersRecord, MetricsRecord, ConfigsRecord) + + +class _Checker: + def __init__( + self, + data: RecordSetData, + type: type[ParametersRecord] | type[MetricsRecord] | type[ConfigsRecord], + ) -> None: + self.data = data + self._type = type + + def check_key(self, key: str) -> None: + data = self.data + if not isinstance(key, str): + raise TypeError( + f"Expected `{str.__name__}`, but " + f"received `{type(key).__name__}` for the key." + ) + + orig_value: ParametersRecord | MetricsRecord | ConfigsRecord | None = None + if key in data.parameters_records: + orig_value = data.parameters_records[key] + elif key in data.metrics_records: + orig_value = data.metrics_records[key] + elif key in data.configs_records: + orig_value = data.configs_records[key] + + if orig_value is not None and not isinstance(orig_value, self._type): + raise KeyError( + f"Key '{key}' is already associated with " + f"a '{type(orig_value).__name__}', but a value " + f"of type '{self._type.__name__}' was provided." + ) + + def check_value( + self, value: ParametersRecord | MetricsRecord | ConfigsRecord + ) -> None: + if not isinstance(value, self._type): + raise TypeError( + f"Expected `{self._type.__name__}`, but received " + f"`{type(value).__name__}` for the value." + ) + @dataclass class RecordSetData: @@ -34,18 +81,21 @@ class RecordSetData: def __init__( self, - parameters_records: Optional[Dict[str, ParametersRecord]] = None, - metrics_records: Optional[Dict[str, MetricsRecord]] = None, - configs_records: Optional[Dict[str, ConfigsRecord]] = None, + parameters_records: dict[str, ParametersRecord] | None = None, + metrics_records: dict[str, MetricsRecord] | None = None, + configs_records: dict[str, ConfigsRecord] | None = None, ) -> None: + params_checker = _Checker(self, ParametersRecord) + metrics_checker = _Checker(self, MetricsRecord) + configs_checker = _Checker(self, ConfigsRecord) self.parameters_records = TypedDict[str, ParametersRecord]( - self._check_fn_str, self._check_fn_params + params_checker.check_key, params_checker.check_value ) self.metrics_records = TypedDict[str, MetricsRecord]( - self._check_fn_str, self._check_fn_metrics + metrics_checker.check_key, metrics_checker.check_value ) self.configs_records = TypedDict[str, ConfigsRecord]( - self._check_fn_str, self._check_fn_configs + configs_checker.check_key, configs_checker.check_value ) if parameters_records is not None: self.parameters_records.update(parameters_records) @@ -54,43 +104,15 @@ def __init__( if configs_records is not None: self.configs_records.update(configs_records) - def _check_fn_str(self, key: str) -> None: - if not isinstance(key, str): - raise TypeError( - f"Expected `{str.__name__}`, but " - f"received `{type(key).__name__}` for the key." - ) - - def _check_fn_params(self, record: ParametersRecord) -> None: - if not isinstance(record, ParametersRecord): - raise TypeError( - f"Expected `{ParametersRecord.__name__}`, but " - f"received `{type(record).__name__}` for the value." - ) - - def _check_fn_metrics(self, record: MetricsRecord) -> None: - if not isinstance(record, MetricsRecord): - raise TypeError( - f"Expected `{MetricsRecord.__name__}`, but " - f"received `{type(record).__name__}` for the value." - ) - - def _check_fn_configs(self, record: ConfigsRecord) -> None: - if not isinstance(record, ConfigsRecord): - raise TypeError( - f"Expected `{ConfigsRecord.__name__}`, but " - f"received `{type(record).__name__}` for the value." - ) - class RecordSet: """RecordSet stores groups of parameters, metrics and configs.""" def __init__( self, - parameters_records: Optional[Dict[str, ParametersRecord]] = None, - metrics_records: Optional[Dict[str, MetricsRecord]] = None, - configs_records: Optional[Dict[str, ConfigsRecord]] = None, + parameters_records: dict[str, ParametersRecord] | None = None, + metrics_records: dict[str, MetricsRecord] | None = None, + configs_records: dict[str, ConfigsRecord] | None = None, ) -> None: data = RecordSetData( parameters_records=parameters_records, @@ -117,6 +139,53 @@ def configs_records(self) -> TypedDict[str, ConfigsRecord]: data = cast(RecordSetData, self.__dict__["_data"]) return data.configs_records + @overload + def __getitem__(self, key: Literal["p"]) -> ParametersRecord: ... + + @overload + def __getitem__(self, key: Literal["m"]) -> MetricsRecord: ... + + @overload + def __getitem__(self, key: Literal["c"]) -> ConfigsRecord: ... + + def __getitem__(self, key: str) -> ParametersRecord | MetricsRecord | ConfigsRecord: + """Return the record for the specified key.""" + # Initialize default *Record + data = cast(RecordSetData, self.__dict__["_data"]) + if key == Record.PARAMS and key not in data.parameters_records: + data.parameters_records[key] = ParametersRecord() + elif key == Record.METRICS and key not in data.metrics_records: + data.metrics_records[key] = MetricsRecord() + elif key == Record.CONFIGS and key not in data.configs_records: + data.configs_records[key] = ConfigsRecord() + + # Return the record + if key in data.parameters_records: + return data.parameters_records[key] + if key in data.metrics_records: + return data.metrics_records[key] + if key in data.configs_records: + return data.configs_records[key] + raise KeyError(key) + + def __setitem__( + self, key: str, value: ParametersRecord | MetricsRecord | ConfigsRecord + ) -> None: + """Set the record for the specified key.""" + data = cast(RecordSetData, self.__dict__["_data"]) + if isinstance(value, ParametersRecord): + data.parameters_records[key] = value + elif isinstance(value, MetricsRecord): + data.metrics_records[key] = value + elif isinstance(value, ConfigsRecord): + data.configs_records[key] = value + else: + raise TypeError( + f"Expected {{`{ParametersRecord.__name__}`, " + f"`{MetricsRecord.__name__}`, `{ConfigsRecord.__name__}`}}, " + f"but received `{type(value).__name__}` for the value." + ) + def __repr__(self) -> str: """Return a string representation of this instance.""" flds = ("parameters_records", "metrics_records", "configs_records") From 5cef2df061b6131ab2c495cb551ad3dd2e151dd5 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 17 Jun 2024 16:58:39 +0100 Subject: [PATCH 2/2] add unittest --- src/py/flwr/common/record/recordset.py | 30 ++++++++++++++++----- src/py/flwr/common/record/recordset_test.py | 29 +++++++++++++++++++- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py index 46158b3f923..9c17152c81a 100644 --- a/src/py/flwr/common/record/recordset.py +++ b/src/py/flwr/common/record/recordset.py @@ -33,12 +33,13 @@ class _Checker: def __init__( self, data: RecordSetData, - type: type[ParametersRecord] | type[MetricsRecord] | type[ConfigsRecord], + r_type: type[ParametersRecord] | type[MetricsRecord] | type[ConfigsRecord], ) -> None: self.data = data - self._type = type + self._type = r_type def check_key(self, key: str) -> None: + """Check the validity of the key.""" data = self.data if not isinstance(key, str): raise TypeError( @@ -55,7 +56,7 @@ def check_key(self, key: str) -> None: orig_value = data.configs_records[key] if orig_value is not None and not isinstance(orig_value, self._type): - raise KeyError( + raise TypeError( f"Key '{key}' is already associated with " f"a '{type(orig_value).__name__}', but a value " f"of type '{self._type.__name__}' was provided." @@ -64,6 +65,7 @@ def check_key(self, key: str) -> None: def check_value( self, value: ParametersRecord | MetricsRecord | ConfigsRecord ) -> None: + """Check the validity of the value.""" if not isinstance(value, self._type): raise TypeError( f"Expected `{self._type.__name__}`, but received " @@ -140,13 +142,13 @@ def configs_records(self) -> TypedDict[str, ConfigsRecord]: return data.configs_records @overload - def __getitem__(self, key: Literal["p"]) -> ParametersRecord: ... + def __getitem__(self, key: Literal["p"]) -> ParametersRecord: ... # noqa: E704 @overload - def __getitem__(self, key: Literal["m"]) -> MetricsRecord: ... + def __getitem__(self, key: Literal["m"]) -> MetricsRecord: ... # noqa: E704 @overload - def __getitem__(self, key: Literal["c"]) -> ConfigsRecord: ... + def __getitem__(self, key: Literal["c"]) -> ConfigsRecord: ... # noqa: E704 def __getitem__(self, key: str) -> ParametersRecord | MetricsRecord | ConfigsRecord: """Return the record for the specified key.""" @@ -173,6 +175,22 @@ def __setitem__( ) -> None: """Set the record for the specified key.""" data = cast(RecordSetData, self.__dict__["_data"]) + builtin_key_name: str | None = None + record_type: ( + type[ParametersRecord] | type[MetricsRecord] | type[ConfigsRecord] | None + ) = None + if key == Record.PARAMS: + builtin_key_name, record_type = "Record.PARAMS", ParametersRecord + elif key == Record.METRICS: + builtin_key_name, record_type = "Record.METRICS", MetricsRecord + elif key == Record.CONFIGS: + builtin_key_name, record_type = "Record.CONFIGS", ConfigsRecord + if record_type is not None and not isinstance(value, record_type): + raise TypeError( + f"Expected value of type `{record_type.__name__}` for built-in key " + f"`{builtin_key_name}`, but received type `{type(value).__name__}`." + ) + if isinstance(value, ParametersRecord): data.parameters_records[key] = value elif isinstance(value, MetricsRecord): diff --git a/src/py/flwr/common/record/recordset_test.py b/src/py/flwr/common/record/recordset_test.py index 01260793cb4..91507ceaa2f 100644 --- a/src/py/flwr/common/record/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -17,11 +17,12 @@ import pickle from collections import namedtuple from copy import deepcopy -from typing import Callable, Dict, List, OrderedDict, Type, Union +from typing import Any, Callable, Dict, List, OrderedDict, Type, Union import numpy as np import pytest +from flwr.common.constant import Record from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays from flwr.common.recordset_compat import ( parameters_to_parametersrecord, @@ -430,3 +431,29 @@ def test_recordset_repr() -> None: # Assert assert str(rs) == str(expected) + + +@pytest.mark.parametrize( + "rs, key, value, exc_type", + [ + # Invalid key type + (RecordSet(), 12, ConfigsRecord(), TypeError), + # Invalid value type + (RecordSet(), "resnet", 0.9, TypeError), + # Invalid record type for built-in keys + (RecordSet(), Record.PARAMS, MetricsRecord(), TypeError), + (RecordSet(), Record.PARAMS, ConfigsRecord(), TypeError), + (RecordSet(), Record.METRICS, ParametersRecord(), TypeError), + (RecordSet(), Record.METRICS, ConfigsRecord(), TypeError), + (RecordSet(), Record.CONFIGS, ParametersRecord(), TypeError), + (RecordSet(), Record.CONFIGS, MetricsRecord(), TypeError), + # Invalid reassignment (different record types) + (RecordSet({"model": ParametersRecord()}), "model", ConfigsRecord(), TypeError), + ], +) +def test_invalid_assignment( + rs: RecordSet, key: str, value: Any, exc_type: Type[Exception] +) -> None: + """Test invalid assignment to RecordSet.""" + with pytest.raises(exc_type): + rs[key] = value