Upgrade Pydantic to v2 #31

Open · wants to merge 51 commits into base: main

Commits (51)
42696e4
convert library to compatibility w pydantic v2, start fixing tests
Jun 6, 2023
ccedeb0
fix constr and model config access
Jun 6, 2023
05ce65b
add support for Pydantic models and dataclasses out of registered fun…
Jun 8, 2023
a822d3c
update reqs
Jun 30, 2023
59f3f55
start compat
Jun 30, 2023
8ee7a24
Merge branch 'main' of ssh://github.com/explosion/confection into kab…
Jun 30, 2023
3b51749
small corrrections around new model_construct behavior
Jun 30, 2023
2df560f
use Iterator instead of Generator and GeneratorType
Jun 30, 2023
5213981
don't validate in fill_without_resolve test
Jun 30, 2023
ff3b55f
bump reqs
Jun 30, 2023
1999477
refactor and fix for mypy
Jun 30, 2023
45f12ba
disable python 3.6
Jun 30, 2023
efd1737
rm extra python 3.6 ref
Jun 30, 2023
ca99729
check that pydantic and dataclass versions of Optimizer both work
Jun 30, 2023
caf94a2
Merge branch 'kab/pydantic-v2' of ssh://github.com/explosion/confecti…
Jul 1, 2023
beb567e
fix conflict
Jul 1, 2023
07870e7
move back to old Config nested class
Jul 1, 2023
36bc368
fix tests
Jul 1, 2023
0b31287
update from model_extra
Jul 7, 2023
2247265
fix pydantic generator equals
Jul 7, 2023
712e0ed
fixes for organization
Jul 7, 2023
ee6b10c
allow pydantic v1/v2 in reqs/setup and test both in CI
Jul 7, 2023
81fa915
only run CI push to main, not other branches
Jul 7, 2023
0fb6858
fix issue with model_validate
Jul 7, 2023
04354f2
fix filter warnings for tests
Jul 7, 2023
a5f2d5a
try run ci
Jul 7, 2023
21196e5
try run ci
Jul 7, 2023
20974fd
smaller test matrix
Jul 7, 2023
995b9af
print pydantic version before tests
Jul 7, 2023
3acfe90
fixes for mypy
Jul 7, 2023
0bac230
test fixes
Jul 7, 2023
6d95b50
re-enable mypy
Jul 7, 2023
7aa2207
Merge pull request #36 from explosion/kab/pydantic-v2-compat
Jul 7, 2023
aa7d13b
Undo unrelated changes to CI tests
adrianeboyd Aug 3, 2023
fc29ccd
Ignore tests for mypy
adrianeboyd Aug 3, 2023
65e69c1
Add mypy for pydantic v1
adrianeboyd Aug 3, 2023
a9dd2a3
Format
adrianeboyd Aug 3, 2023
36d1d1c
Lower typing_extensions pin for python 3.6
adrianeboyd Aug 3, 2023
de14431
Undo changes to typing_extensions
adrianeboyd Aug 3, 2023
3283e4a
Allow older pydantic v1 for tests for python 3.6
adrianeboyd Aug 3, 2023
cbada4b
Add CI test for spacy init config
adrianeboyd Aug 3, 2023
bf6624f
Merge branch 'main' of ssh://github.com/explosion/confection into kab…
Aug 3, 2023
680e224
black formatting
Aug 3, 2023
4f7d5b3
Fix spacy init issue (#37)
Aug 4, 2023
c4f78e8
Simplify pydantic requirements
adrianeboyd Aug 4, 2023
10b45d5
Merge remote-tracking branch 'upstream/main' into kab/pydantic-v2
adrianeboyd Aug 4, 2023
e6ac8ec
Merge branch 'kab/pydantic-v2' of ssh://github.com/explosion/confecti…
Aug 31, 2023
8ae9252
add a spacy init config regression test if spacy is installed
Aug 31, 2023
883d76b
rm unused import
Aug 31, 2023
d25d81a
rm spacy init config step in gha in favor of test in pytest
Aug 31, 2023
ba068d6
fix checks
Sep 1, 2023
8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
@@ -97,22 +97,22 @@ jobs:
- name: Install test requirements with Pydantic v1
run: |
python -m pip install -U -r requirements.txt
python -m pip install -U "pydantic<2.0"
python -m pip install -U "pydantic<2.0" "spacy"

- name: Run tests for Pydantic v1
run: |
python -c "import pydantic; print(pydantic.VERSION)"
- python -m pytest --pyargs confection -Werror
+ python -m pytest --pyargs confection

- name: Install test requirements with Pydantic v2
run: |
- python -m pip install -U -r requirements.txt
+ python -m pip install -U -r requirements.txt spacy
python -m pip install -U pydantic

- name: Run tests for Pydantic v2
run: |
python -c "import pydantic; print(pydantic.VERSION)"
- python -m pytest --pyargs confection -Werror
+ python -m pytest --pyargs confection

- name: Test for import conflicts with hypothesis
run: |
188 changes: 160 additions & 28 deletions confection/__init__.py
@@ -14,7 +14,7 @@
NoSectionError,
ParsingError,
)
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from pathlib import Path
from types import GeneratorType
from typing import (
@@ -26,21 +26,24 @@
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)

import catalogue
import srsly
from pydantic import BaseModel, ValidationError, create_model
from pydantic.fields import FieldInfo

try:
from pydantic.v1 import BaseModel, Extra, ValidationError, create_model
from pydantic.v1.fields import ModelField
from pydantic.v1.main import ModelMetaclass
except ImportError:
from pydantic import BaseModel, create_model, ValidationError, Extra # type: ignore
from pydantic.main import ModelMetaclass # type: ignore
from .util import PYDANTIC_V2, Decorator, SimpleFrozenDict, SimpleFrozenList

if PYDANTIC_V2:
from pydantic.v1.fields import ModelField # type: ignore
else:
from pydantic.fields import ModelField # type: ignore

from .util import SimpleFrozenDict, SimpleFrozenList # noqa: F401
@@ -689,10 +692,7 @@ def alias_generator(name: str) -> str:
return name


def copy_model_field(field: ModelField, type_: Any) -> ModelField:
"""Copy a model field and assign a new type, e.g. to accept an Any type
even though the original value is typed differently.
"""
def _copy_model_field_v1(field: ModelField, type_: Any) -> ModelField:
return ModelField(
name=field.name,
type_=type_,
@@ -704,6 +704,107 @@ def copy_model_field(field: ModelField, type_: Any) -> ModelField:
)


def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo:
"""Copy a model field and assign a new type, e.g. to accept an Any type
even though the original value is typed differently.
"""
if PYDANTIC_V2:
field_info = copy.deepcopy(field)
field_info.annotation = type_ # type: ignore
return field_info
else:
return _copy_model_field_v1(field, type_) # type: ignore


def get_model_config_extra(model: Type[BaseModel]) -> str:
if PYDANTIC_V2:
extra = str(model.model_config.get("extra", "forbid")) # type: ignore
else:
extra = str(model.Config.extra) or "forbid" # type: ignore
assert isinstance(extra, str)
return extra


_ModelT = TypeVar("_ModelT", bound=BaseModel)


def _schema_is_pydantic_v2(Schema: Union[Type[BaseModel], BaseModel]) -> bool:
"""If `model_fields` attr is present, it means we have a schema or instance
of a pydantic v2 BaseModel. Even if we're using Pydantic V2, users could still
import from `pydantic.v1` and that would break our compat checks.
Schema (Union[Type[BaseModel], BaseModel]): Input schema or instance.
RETURNS (bool): True if the pydantic model is a v2 model or not
"""
return hasattr(Schema, "model_fields")


def model_validate(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT:
if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
return Schema.model_validate(data) # type: ignore
else:
return Schema.validate(data) # type: ignore


def model_construct(
Schema: Type[_ModelT], fields_set: Optional[Set[str]], data: Dict[str, Any]
) -> _ModelT:
if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
return Schema.model_construct(fields_set, **data) # type: ignore
else:
return Schema.construct(fields_set, **data) # type: ignore


def model_dump(instance: BaseModel) -> Dict[str, Any]:
if PYDANTIC_V2 and _schema_is_pydantic_v2(instance):
return instance.model_dump() # type: ignore
else:
return instance.dict()


def get_field_annotation(field: FieldInfo) -> Type:
return field.annotation if PYDANTIC_V2 else field.type_ # type: ignore


def get_model_fields(Schema: Union[Type[BaseModel], BaseModel]) -> Dict[str, FieldInfo]:
if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
return Schema.model_fields # type: ignore
else:
return Schema.__fields__ # type: ignore


def get_model_fields_set(Schema: Union[Type[BaseModel], BaseModel]) -> Set[str]:
if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
return Schema.model_fields_set # type: ignore
else:
return Schema.__fields_set__ # type: ignore


def get_model_extra(instance: BaseModel) -> Dict[str, FieldInfo]:
if PYDANTIC_V2 and _schema_is_pydantic_v2(instance):
return instance.model_extra # type: ignore
else:
return {}


def set_model_field(Schema: Type[BaseModel], key: str, field: FieldInfo):
if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
Schema.model_fields[key] = field # type: ignore
else:
Schema.__fields__[key] = field # type: ignore


def update_from_model_extra(
shallow_result_dict: Dict[str, Any], result: BaseModel
) -> None:
if PYDANTIC_V2 and _schema_is_pydantic_v2(result):
if result.model_extra is not None: # type: ignore
shallow_result_dict.update(result.model_extra) # type: ignore


def _safe_is_subclass(cls: type, expected: type) -> bool:
return inspect.isclass(cls) and issubclass(cls, BaseModel)
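
# ---------------------------------------------------------------------------
# Illustrative sketch only -- not part of this PR's diff. It shows how the
# compat helpers above are expected to dispatch, assuming Pydantic v2 is
# installed (so the legacy `pydantic.v1` namespace is available) and that the
# helpers defined above are in scope.
import pydantic
import pydantic.v1


class ExampleV2Model(pydantic.BaseModel):
    name: str


class ExampleV1Model(pydantic.v1.BaseModel):
    name: str


assert _schema_is_pydantic_v2(ExampleV2Model)      # v2 models expose `model_fields`
assert not _schema_is_pydantic_v2(ExampleV1Model)  # pydantic.v1 models use the v1 paths

example = model_validate(ExampleV2Model, {"name": "demo"})
assert model_dump(example) == {"name": "demo"}
# ---------------------------------------------------------------------------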


class EmptySchema(BaseModel):
class Config:
extra = "allow"
Expand Down Expand Up @@ -829,6 +930,7 @@ def _fill(
resolve: bool = True,
parent: str = "",
overrides: Dict[str, Dict[str, Any]] = {},
resolved_object_keys: Set[str] = set(),
) -> Tuple[
Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any]
]:
@@ -850,12 +952,14 @@
value = overrides[key_parent]
config[key] = value
if cls.is_promise(value):
if key in schema.__fields__ and not resolve:
model_fields = get_model_fields(schema)
if key in model_fields and not resolve:
# If we're not resolving the config, make sure that the field
# expecting the promise is typed Any so it doesn't fail
# validation if it doesn't receive the function return value
field = schema.__fields__[key]
schema.__fields__[key] = copy_model_field(field, Any)
field = model_fields[key]
new_field = copy_model_field(field, Any)
set_model_field(schema, key, new_field)
promise_schema = cls.make_promise_schema(value, resolve=resolve)
filled[key], validation[v_key], final[key] = cls._fill(
value,
@@ -864,6 +968,7 @@
resolve=resolve,
parent=key_parent,
overrides=overrides,
resolved_object_keys=resolved_object_keys,
)
reg_name, func_name = cls.get_constructor(final[key])
args, kwargs = cls.parse_args(final[key])
@@ -875,6 +980,11 @@
# We don't want to try/except this and raise our own error
# here, because we want the traceback if the function fails.
getter_result = getter(*args, **kwargs)

if isinstance(getter_result, BaseModel) or is_dataclass(
getter_result
):
resolved_object_keys.add(key)
else:
# We're not resolving and calling the function, so replace
# the getter_result with a Promise class
@@ -890,12 +1000,14 @@
validation[v_key] = []
elif hasattr(value, "items"):
field_type = EmptySchema
if key in schema.__fields__:
field = schema.__fields__[key]
field_type = field.type_
if not isinstance(field.type_, ModelMetaclass):
# If we don't have a pydantic schema and just a type
field_type = EmptySchema
fields = get_model_fields(schema)
if key in fields:
field = fields[key]
annotation = get_field_annotation(field)
if annotation is not None and _safe_is_subclass(
annotation, BaseModel
):
field_type = annotation
filled[key], validation[v_key], final[key] = cls._fill(
value,
field_type,
@@ -921,21 +1033,39 @@
exclude = []
if validate:
try:
result = schema.parse_obj(validation)
result = model_validate(schema, validation)
except ValidationError as e:
raise ConfigValidationError(
config=config, errors=e.errors(), parent=parent
) from None
else:
# Same as parse_obj, but without validation
result = schema.construct(**validation)
# Same as model_validate, but without validation
fields_set = set(get_model_fields(schema).keys())
result = model_construct(schema, fields_set, validation)
# If our schema doesn't allow extra values, we need to filter them
# manually because .construct doesn't parse anything
if schema.Config.extra in (Extra.forbid, Extra.ignore):
fields = schema.__fields__.keys()
exclude = [k for k in result.__fields_set__ if k not in fields]
if get_model_config_extra(schema) in ("forbid", "extra"):
result_field_names = get_model_fields_set(result)
exclude = [
k for k in dict(result).keys() if k not in result_field_names
]
exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
validation.update(result.dict(exclude=exclude_validation))
# Do a shallow serialization first
# If any of the sub-objects are Pydantic models, first check if they
# were resolved earlier from a registry. If they weren't resolved
# they are part of a nested schema and need to be serialized with
# model.dict()
# Allows for returning Pydantic models from a registered function
shallow_result_dict = dict(result)
update_from_model_extra(shallow_result_dict, result)
result_dict = {}
for k, v in shallow_result_dict.items():
if k in exclude_validation:
continue
result_dict[k] = v
if isinstance(v, BaseModel) and k not in resolved_object_keys:
result_dict[k] = model_dump(v)
validation.update(result_dict)
Review comment from @kabirkhan (Author), Jun 30, 2023:

This change might seem a bit weird, but it came up because Pydantic v2 treats Pydantic dataclasses and BaseModels the same way during serialization.

In Pydantic v1, calling model.dict() on a model containing a dataclass instance would not JSON-serialize that dataclass; it would just return the dataclass instance as-is.

This allowed our Optimizer README example to work and resolve to that dataclass instance.

For example, with Pydantic v1 this still worked:

import dataclasses
from typing import Union, Iterable
import catalogue
from confection import registry, Config

# Create a new registry.
registry.optimizers = catalogue.create("confection", "optimizers", entry_points=False)


# Define a dummy optimizer class.
@dataclasses.dataclass
class MyCoolOptimizer:
    learn_rate: float
    gamma: float


@registry.optimizers.register("my_cool_optimizer.v1")
def make_my_optimizer(learn_rate: Union[float, Iterable[float]], gamma: float):
    return MyCoolOptimizer(learn_rate, gamma)


# Load the config file from disk, resolve it and fetch the instantiated optimizer object.
config = Config().from_disk("./config.cfg")
resolved = registry.resolve(config)
optimizer = resolved["optimizer"]  # MyCoolOptimizer(learn_rate=0.001, gamma=1e-08)

However, after switching to Pydantic v2, resolving the same config returned a plain dict for the optimizer:

...
optimizer = resolved["optimizer"]  # {"learn_rate": 0.001, "gamma": 1e-08}

This has actually always been a bug: previously we could not make MyCoolOptimizer a Pydantic model; it had to be a dataclass (or any other class that wasn't a Pydantic model).

With the code above, we can make MyCoolOptimizer a Pydantic model and get a Pydantic model back.
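
To make the serialization difference concrete, here is a minimal sketch (assuming Pydantic v2 is installed; the Resolved wrapper model is illustrative and not part of confection):

import dataclasses
from pydantic import BaseModel


@dataclasses.dataclass
class MyCoolOptimizer:
    learn_rate: float
    gamma: float


class Resolved(BaseModel):
    optimizer: MyCoolOptimizer


result = Resolved(optimizer=MyCoolOptimizer(learn_rate=0.001, gamma=1e-08))

# Pydantic v2 serializes dataclass (and nested model) fields recursively:
print(result.model_dump())  # {'optimizer': {'learn_rate': 0.001, 'gamma': 1e-08}}

# dict(model) stays shallow, which is what the shallow serialization plus the
# resolved_object_keys bookkeeping in this PR relies on:
print(dict(result))  # {'optimizer': MyCoolOptimizer(learn_rate=0.001, gamma=1e-08)}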

filled, final = cls._update_from_parsed(validation, filled, final)
if exclude:
filled = {k: v for k, v in filled.items() if k not in exclude}
@@ -969,6 +1099,8 @@ def _update_from_parsed(
# Check numpy first, just in case. Use stringified type so that numpy dependency can be ditched.
elif str(type(value)) == "<class 'numpy.ndarray'>":
final[key] = value
elif isinstance(value, BaseModel) and isinstance(final[key], BaseModel):
Review comment from the author:

This case can occur because of the change above that passes resolved Pydantic models through.
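
A minimal sketch of the situation this branch handles (names are illustrative, not part of the diff): the value coming back from validation and the previously resolved value are both Pydantic models, so the validated model is passed through unchanged.

from pydantic import BaseModel


class ExampleOptimizer(BaseModel):
    learn_rate: float


final = {"optimizer": ExampleOptimizer(learn_rate=0.001)}  # resolved earlier
value = ExampleOptimizer(learn_rate=0.001)                 # value after validation

if isinstance(value, BaseModel) and isinstance(final["optimizer"], BaseModel):
    final["optimizer"] = value  # keep the model instance rather than serializing it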

final[key] = value
elif (
value != final[key] or not isinstance(type(value), type(final[key]))
) and not isinstance(final[key], GeneratorType):