Skip to content

Commit

Permalink
small corrrections around new model_construct behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Kabir Khan committed Jun 30, 2023
1 parent 8ee7a24 commit 3b51749
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
27 changes: 15 additions & 12 deletions confection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,15 +675,17 @@ def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo:


class EmptySchema(BaseModel):
class Config:
extra = "allow"
arbitrary_types_allowed = True
model_config = {
"extra": "allow",
"arbitrary_types_allowed": True
}


class _PromiseSchemaConfig:
extra = "forbid"
arbitrary_types_allowed = True
alias_generator = alias_generator
_promise_schema_config = {
"extra": "forbid",
"arbitrary_types_allowed": True,
"alias_generator": alias_generator
}


@dataclass
Expand Down Expand Up @@ -902,13 +904,14 @@ def _fill(
config=config, errors=e.errors(), parent=parent
) from None
else:
# Same as parse_obj, but without validation
result = schema.model_construct(**validation)
# Same as model_validate, but without validation
fields_set = set(schema.model_fields.keys())
result = schema.model_construct(fields_set, **validation)
# If our schema doesn't allow extra values, we need to filter them
# manually because .construct doesn't parse anything
if schema.model_config.get("extra", "forbid") in ("forbid", "ignore"):
fields = schema.model_fields.keys()
exclude = [k for k in result.model_fields_set if k not in fields]
fields = result.model_fields_set
exclude = [k for k in dict(result).keys() if k not in fields]
exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
# Do a shallow serialization first
# If any of the sub-objects are Pydantic models, first check if they
Expand Down Expand Up @@ -1055,7 +1058,7 @@ def make_promise_schema(
else:
name = RESERVED_FIELDS.get(param.name, param.name)
sig_args[name] = (annotation, default)
sig_args["__config__"] = _PromiseSchemaConfig
sig_args["__config__"] = _promise_schema_config
return create_model("ArgModel", **sig_args)


Expand Down
8 changes: 5 additions & 3 deletions confection/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from types import GeneratorType
import pickle

from pydantic import BaseModel, StrictFloat, PositiveInt, constr
from pydantic import BaseModel, StrictFloat, PositiveInt
from pydantic.fields import Field
from pydantic.types import StrictBool

Expand Down Expand Up @@ -1205,7 +1205,9 @@ class TestSchemaContent(BaseModel):
a: str
b: int

model_config = {"extra": "forbid"}
model_config = {
"extra": "forbid",
}

class TestSchema(BaseModel):
cfg: TestSchemaContent
Expand Down Expand Up @@ -1282,7 +1284,7 @@ class BaseSchema(BaseModel):
assert filled["catsie"]["cute"] is True
with pytest.raises(ConfigValidationError):
my_registry.resolve(config, schema=BaseSchema)
filled2 = my_registry.fill(config, schema=BaseSchema, validate=False)
filled2 = my_registry.fill(config, schema=BaseSchema)
assert filled2["catsie"]["cute"] is True
resolved = my_registry.resolve(filled2)
assert resolved["catsie"] == "meow"
Expand Down

0 comments on commit 3b51749

Please sign in to comment.