Skip to content

Commit

Permalink
Fixed issue where config with * could not be filled (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed Dec 31, 2023
1 parent 43c9281 commit 274c40a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
17 changes: 16 additions & 1 deletion confection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,7 @@ def copy_model_field(field: ModelField, type_: Any) -> ModelField:
default=field.default,
default_factory=field.default_factory,
required=field.required,
alias=field.alias,
)


Expand Down Expand Up @@ -912,6 +913,15 @@ def _fill(
# created via config blocks), only use its values
validation[v_key] = list(validation[v_key].values())
final[key] = list(final[key].values())

if ARGS_FIELD_ALIAS in schema.__fields__ and not resolve:
# If we're not resolving the config, make sure that the field
# expecting the promise is typed Any so it doesn't fail
# validation if it doesn't receive the function return value
field = schema.__fields__[ARGS_FIELD_ALIAS]
schema.__fields__[ARGS_FIELD_ALIAS] = copy_model_field(
field, Any
)
else:
filled[key] = value
# Prevent pydantic from consuming generator if part of a union
Expand All @@ -936,7 +946,12 @@ def _fill(
# manually because .construct doesn't parse anything
if schema.Config.extra in (Extra.forbid, Extra.ignore):
fields = schema.__fields__.keys()
exclude = [k for k in result.__fields_set__ if k not in fields]
# If we have a reserved field, we need to use its alias
field_set = [
k if k != ARGS_FIELD else ARGS_FIELD_ALIAS
for k in result.__fields_set__
]
exclude = [k for k in field_set if k not in fields]
exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
validation.update(result.dict(exclude=exclude_validation))
filled, final = cls._update_from_parsed(validation, filled, final)
Expand Down
22 changes: 22 additions & 0 deletions confection/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,28 @@ def catsie_567(*args: Optional[str], foo: str = "bar"):
assert my_registry.resolve(cfg)["config"] == "^_^"


def test_fill_config_positional_args_w_promise():
@my_registry.cats("catsie.v568")
def catsie_568(*args: str, foo: str = "bar"):
assert args[0] == "^(*.*)^"
assert foo == "baz"
return args[0]

@my_registry.cats("cat_promise.v568")
def cat_promise() -> str:
return "^(*.*)^"

cfg = {
"config": {
"@cats": "catsie.v568",
"*": {"promise": {"@cats": "cat_promise.v568"}},
}
}
filled = my_registry.fill(cfg, validate=True)
assert filled["config"]["foo"] == "bar"
assert filled["config"]["*"] == {"promise": {"@cats": "cat_promise.v568"}}


def test_make_config_positional_args_complex():
@my_registry.cats("catsie.v890")
def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]):
Expand Down

0 comments on commit 274c40a

Please sign in to comment.