From 59f3f55a3111f22076348169f95aee0c1a25a7c9 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 12:01:18 -0700 Subject: [PATCH] start compat --- confection/__init__.py | 77 +++++++++++++++++++++++---------- confection/tests/test_2.py | 36 +++++++++++++++ confection/tests/test_config.py | 6 +-- 3 files changed, 94 insertions(+), 25 deletions(-) create mode 100644 confection/tests/test_2.py diff --git a/confection/__init__.py b/confection/__init__.py index 3358414..a590085 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -1,6 +1,7 @@ from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping -from typing import Iterable, Sequence, Set, cast +from typing import Iterable, Sequence, Set, TypeVar, cast from types import GeneratorType +from inspect import isclass from dataclasses import dataclass, is_dataclass from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH from configparser import InterpolationMissingOptionError, InterpolationSyntaxError @@ -9,6 +10,7 @@ from pathlib import Path from pydantic import BaseModel, create_model, ValidationError, Extra from pydantic.fields import FieldInfo +from pydantic.version import VERSION as PYDANTIC_VERSION import srsly import catalogue import inspect @@ -34,6 +36,8 @@ # Regex to detect whether a value contains a variable VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}") +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + class CustomInterpolation(ExtendedInterpolation): def before_read(self, parser, section, option, value): @@ -667,15 +671,43 @@ def alias_generator(name: str) -> str: return name -def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: +def copy_model_field(field: Union["FieldInfo", "ModelField"], type_: Type) -> Union["FieldInfo", "ModelField"]: """Copy a model field and assign a new type, e.g. to accept an Any type even though the original value is typed differently. """ field_info = copy.deepcopy(field) - field_info.annotation = type_ + if PYDANTIC_V2: + field_info.annotation = type_ + else: + field_info.type_ = type_ return field_info +def get_model_config_extra(model: Type[BaseModel]) -> str: + if PYDANTIC_V2: + extra = model.model_config.get("extra", "forbid") + else: + extra = str(model.Config.extra) or "forbid" + assert isinstance(extra, str) + return extra + + + +_ModelT = TypeVar("_ModelT", bound=BaseModel) + + +def model_validate(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT: + return Schema.model_validate(**data) if PYDANTIC_V2 else Schema(**data) + + +def model_construct(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT: + return Schema.model_construct(**data) if PYDANTIC_V2 else Schema.construct(**data) + + +def model_dump(instance: BaseModel) -> Dict[str, Any]: + return instance.model_dump() if PYDANTIC_V2 else instance.dict() + + class EmptySchema(BaseModel): class Config: extra = "allow" @@ -860,17 +892,18 @@ def _fill( ) validation[v_key] = getter_result final[key] = getter_result - if isinstance(validation[v_key], GeneratorType): - # If value is a generator we can't validate type without - # consuming it (which doesn't work if it's infinite – see - # schedule for examples). So we skip it. - validation[v_key] = [] + # if isinstance(validation[v_key], GeneratorType): + # # If value is a generator we can't validate type without + # # consuming it (which doesn't work if it's infinite – see + # # schedule for examples). So we skip it. + # validation[v_key] = [] elif hasattr(value, "items"): field_type = EmptySchema - if key in schema.model_fields: - field = schema.model_fields[key] - field_type = field.annotation - if field_type is None or not issubclass(field_type, BaseModel): + fields = schema.model_fields if PYDANTIC_V2 else schema.__fields__ + if key in fields: + field = fields[key] + field_type = field.annotation if PYDANTIC_V2 else field.type_ + if field_type is None or not (isclass(field_type) and issubclass(field_type, BaseModel)): # If we don't have a pydantic schema and just a type field_type = EmptySchema filled[key], validation[v_key], final[key] = cls._fill( @@ -888,29 +921,29 @@ def _fill( final[key] = list(final[key].values()) else: filled[key] = value - # Prevent pydantic from consuming generator if part of a union - validation[v_key] = ( - value if not isinstance(value, GeneratorType) else [] - ) + validation[v_key] = value final[key] = value # Now that we've filled in all of the promises, update with defaults # from schema, and validate if validation is enabled exclude = [] if validate: try: - result = schema.model_validate(validation) + result = schema.model_validate(validation) if PYDANTIC_V2 else schema(**validation) except ValidationError as e: + raise ConfigValidationError( config=config, errors=e.errors(), parent=parent ) from None else: # Same as parse_obj, but without validation - result = schema.model_construct(**validation) + result = schema.model_construct(**validation) if PYDANTIC_V2 else schema.construct(**validation) # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything - if schema.model_config.get("extra", "forbid") in ("forbid", "ignore"): - fields = schema.model_fields.keys() - exclude = [k for k in result.model_fields_set if k not in fields] + extra_attr = get_model_config_extra(schema) + if extra_attr in ("forbid", "ignore"): + fields = schema.model_fields.keys() if PYDANTIC_V2 else schema.__fields__.keys() + result_fields_set = result.model_fields_set if PYDANTIC_V2 else result.__fields_set__ + exclude = [k for k in result_fields_set if k not in fields] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) # Do a shallow serialization first # If any of the sub-objects are Pydantic models, first check if they @@ -927,7 +960,7 @@ def _fill( continue result_dict[k] = v if isinstance(v, BaseModel) and k not in resolved_object_keys: - result_dict[k] = v.model_dump() + result_dict[k] = model_dump(v) validation.update(result_dict) filled, final = cls._update_from_parsed(validation, filled, final) if exclude: diff --git a/confection/tests/test_2.py b/confection/tests/test_2.py new file mode 100644 index 0000000..290ecd7 --- /dev/null +++ b/confection/tests/test_2.py @@ -0,0 +1,36 @@ +import dataclasses +from typing import Union, Iterable +import catalogue +from confection import registry, Config +from pydantic import BaseModel + +# Create a new registry. +registry.optimizers = catalogue.create("confection", "optimizers", entry_points=False) + + +# Define a dummy optimizer class. + +@dataclasses.dataclass +class MyCoolOptimizer: + learn_rate: float + gamma: float + + +@registry.optimizers.register("my_cool_optimizer.v1") +def make_my_optimizer(learn_rate: Union[float, Iterable[float]], gamma: float): + return MyCoolOptimizer(learn_rate=learn_rate, gamma=gamma) + + +if __name__ == "__main__": + # Load the config file from disk, resolve it and fetch the instantiated optimizer object. + cfg_str = """ +[optimizer] +@optimizers = "my_cool_optimizer.v1" +learn_rate = 0.001 +gamma = 1e-8 + """ + config = Config().from_str(cfg_str) + resolved = registry.resolve(config) + optimizer = resolved["optimizer"] # MyCoolOptimizer(learn_rate=0.001, gamma=1e-08) + + print(config, resolved, optimizer) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 2e7f458..587b759 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -584,7 +584,7 @@ def test_schedule(): assert isinstance(result, GeneratorType) @my_registry.optimizers("test_optimizer.v2") - def test_optimizer2(rate: Generator) -> Generator: + def test_optimizer2(rate: Iterable[float]) -> Iterable[float]: return rate cfg = { @@ -595,7 +595,7 @@ def test_optimizer2(rate: Generator) -> Generator: assert isinstance(result, GeneratorType) @my_registry.optimizers("test_optimizer.v3") - def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: + def test_optimizer3(schedules: Dict[str, Iterable[float]]) -> Iterable[float]: return schedules["rate"] cfg = { @@ -606,7 +606,7 @@ def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: assert isinstance(result, GeneratorType) @my_registry.optimizers("test_optimizer.v4") - def test_optimizer4(*schedules: Generator) -> Generator: + def test_optimizer4(*schedules: Iterable[float]) -> Iterable[float]: return schedules[0]