Skip to content

Commit

Permalink
start compat
Browse files Browse the repository at this point in the history
  • Loading branch information
Kabir Khan committed Jun 30, 2023
1 parent a822d3c commit 59f3f55
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 25 deletions.
77 changes: 55 additions & 22 deletions confection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping
from typing import Iterable, Sequence, Set, cast
from typing import Iterable, Sequence, Set, TypeVar, cast
from types import GeneratorType
from inspect import isclass
from dataclasses import dataclass, is_dataclass
from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH
from configparser import InterpolationMissingOptionError, InterpolationSyntaxError
Expand All @@ -9,6 +10,7 @@
from pathlib import Path
from pydantic import BaseModel, create_model, ValidationError, Extra
from pydantic.fields import FieldInfo
from pydantic.version import VERSION as PYDANTIC_VERSION
import srsly
import catalogue
import inspect
Expand All @@ -34,6 +36,8 @@
# Regex to detect whether a value contains a variable
VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}")

PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")


class CustomInterpolation(ExtendedInterpolation):
def before_read(self, parser, section, option, value):
Expand Down Expand Up @@ -667,15 +671,43 @@ def alias_generator(name: str) -> str:
return name


def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo:
def copy_model_field(field: Union["FieldInfo", "ModelField"], type_: Type) -> Union["FieldInfo", "ModelField"]:
"""Copy a model field and assign a new type, e.g. to accept an Any type
even though the original value is typed differently.
"""
field_info = copy.deepcopy(field)
field_info.annotation = type_
if PYDANTIC_V2:
field_info.annotation = type_
else:
field_info.type_ = type_
return field_info


def get_model_config_extra(model: Type[BaseModel]) -> str:
if PYDANTIC_V2:
extra = model.model_config.get("extra", "forbid")
else:
extra = str(model.Config.extra) or "forbid"
assert isinstance(extra, str)
return extra



_ModelT = TypeVar("_ModelT", bound=BaseModel)


def model_validate(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT:
return Schema.model_validate(**data) if PYDANTIC_V2 else Schema(**data)


def model_construct(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT:
return Schema.model_construct(**data) if PYDANTIC_V2 else Schema.construct(**data)


def model_dump(instance: BaseModel) -> Dict[str, Any]:
return instance.model_dump() if PYDANTIC_V2 else instance.dict()


class EmptySchema(BaseModel):
class Config:
extra = "allow"
Expand Down Expand Up @@ -860,17 +892,18 @@ def _fill(
)
validation[v_key] = getter_result
final[key] = getter_result
if isinstance(validation[v_key], GeneratorType):
# If value is a generator we can't validate type without
# consuming it (which doesn't work if it's infinite – see
# schedule for examples). So we skip it.
validation[v_key] = []
# if isinstance(validation[v_key], GeneratorType):
# # If value is a generator we can't validate type without
# # consuming it (which doesn't work if it's infinite – see
# # schedule for examples). So we skip it.
# validation[v_key] = []
elif hasattr(value, "items"):
field_type = EmptySchema
if key in schema.model_fields:
field = schema.model_fields[key]
field_type = field.annotation
if field_type is None or not issubclass(field_type, BaseModel):
fields = schema.model_fields if PYDANTIC_V2 else schema.__fields__
if key in fields:
field = fields[key]
field_type = field.annotation if PYDANTIC_V2 else field.type_
if field_type is None or not (isclass(field_type) and issubclass(field_type, BaseModel)):
# If we don't have a pydantic schema and just a type
field_type = EmptySchema
filled[key], validation[v_key], final[key] = cls._fill(
Expand All @@ -888,29 +921,29 @@ def _fill(
final[key] = list(final[key].values())
else:
filled[key] = value
# Prevent pydantic from consuming generator if part of a union
validation[v_key] = (
value if not isinstance(value, GeneratorType) else []
)
validation[v_key] = value
final[key] = value
# Now that we've filled in all of the promises, update with defaults
# from schema, and validate if validation is enabled
exclude = []
if validate:
try:
result = schema.model_validate(validation)
result = schema.model_validate(validation) if PYDANTIC_V2 else schema(**validation)
except ValidationError as e:

raise ConfigValidationError(
config=config, errors=e.errors(), parent=parent
) from None
else:
# Same as parse_obj, but without validation
result = schema.model_construct(**validation)
result = schema.model_construct(**validation) if PYDANTIC_V2 else schema.construct(**validation)
# If our schema doesn't allow extra values, we need to filter them
# manually because .construct doesn't parse anything
if schema.model_config.get("extra", "forbid") in ("forbid", "ignore"):
fields = schema.model_fields.keys()
exclude = [k for k in result.model_fields_set if k not in fields]
extra_attr = get_model_config_extra(schema)
if extra_attr in ("forbid", "ignore"):
fields = schema.model_fields.keys() if PYDANTIC_V2 else schema.__fields__.keys()
result_fields_set = result.model_fields_set if PYDANTIC_V2 else result.__fields_set__
exclude = [k for k in result_fields_set if k not in fields]
exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
# Do a shallow serialization first
# If any of the sub-objects are Pydantic models, first check if they
Expand All @@ -927,7 +960,7 @@ def _fill(
continue
result_dict[k] = v
if isinstance(v, BaseModel) and k not in resolved_object_keys:
result_dict[k] = v.model_dump()
result_dict[k] = model_dump(v)
validation.update(result_dict)
filled, final = cls._update_from_parsed(validation, filled, final)
if exclude:
Expand Down
36 changes: 36 additions & 0 deletions confection/tests/test_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import dataclasses
from typing import Union, Iterable
import catalogue
from confection import registry, Config
from pydantic import BaseModel

# Create a new registry.
registry.optimizers = catalogue.create("confection", "optimizers", entry_points=False)


# Define a dummy optimizer class.

@dataclasses.dataclass
class MyCoolOptimizer:
learn_rate: float
gamma: float


@registry.optimizers.register("my_cool_optimizer.v1")
def make_my_optimizer(learn_rate: Union[float, Iterable[float]], gamma: float):
return MyCoolOptimizer(learn_rate=learn_rate, gamma=gamma)


if __name__ == "__main__":
# Load the config file from disk, resolve it and fetch the instantiated optimizer object.
cfg_str = """
[optimizer]
@optimizers = "my_cool_optimizer.v1"
learn_rate = 0.001
gamma = 1e-8
"""
config = Config().from_str(cfg_str)
resolved = registry.resolve(config)
optimizer = resolved["optimizer"] # MyCoolOptimizer(learn_rate=0.001, gamma=1e-08)

print(config, resolved, optimizer)
6 changes: 3 additions & 3 deletions confection/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def test_schedule():
assert isinstance(result, GeneratorType)

@my_registry.optimizers("test_optimizer.v2")
def test_optimizer2(rate: Generator) -> Generator:
def test_optimizer2(rate: Iterable[float]) -> Iterable[float]:
return rate

cfg = {
Expand All @@ -595,7 +595,7 @@ def test_optimizer2(rate: Generator) -> Generator:
assert isinstance(result, GeneratorType)

@my_registry.optimizers("test_optimizer.v3")
def test_optimizer3(schedules: Dict[str, Generator]) -> Generator:
def test_optimizer3(schedules: Dict[str, Iterable[float]]) -> Iterable[float]:
return schedules["rate"]

cfg = {
Expand All @@ -606,7 +606,7 @@ def test_optimizer3(schedules: Dict[str, Generator]) -> Generator:
assert isinstance(result, GeneratorType)

@my_registry.optimizers("test_optimizer.v4")
def test_optimizer4(*schedules: Generator) -> Generator:
def test_optimizer4(*schedules: Iterable[float]) -> Iterable[float]:
return schedules[0]


Expand Down

0 comments on commit 59f3f55

Please sign in to comment.