Skip to content

Commit

Permalink
Fixes random_flax_module with flax.linen.BatchNorm (#1823)
Browse files Browse the repository at this point in the history
* filter oout tests waiting for next tfp release

* fix issue 1446

* add feddback (not working)

* feedbackl 2

* default handler

* rm prng_key from substitute

* remove class from function
  • Loading branch information
juanitorduz committed Jul 1, 2024
1 parent 5af9ebd commit d40f0e9
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def process_message(self, msg):
return

if self.data is not None:
value = self.data.get(msg["name"])
value = self.data.get(msg.get("name"))
else:
value = self.substitute_fn(msg)

Expand Down
12 changes: 11 additions & 1 deletion numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from numpyro.distributions.util import is_identically_one, sum_rightmost
from numpyro.handlers import condition, replay, seed, substitute, trace
from numpyro.infer.initialization import init_to_uniform, init_to_value
from numpyro.primitives import Messenger
from numpyro.util import (
_validate_model,
find_stack_level,
Expand All @@ -46,6 +47,12 @@
ParamInfo = namedtuple("ParamInfo", ["z", "potential_energy", "z_grad"])


class _substitute_default_key(Messenger):
def process_message(self, msg):
if msg["type"] == "prng_key" and msg["value"] is None:
msg["value"] = random.PRNGKey(0)


def log_density(model, model_args, model_kwargs, params):
"""
(EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
Expand Down Expand Up @@ -660,9 +667,12 @@ def initialize_model(
data={
k: site["value"]
for k, site in model_trace.items()
if site["type"] in ["param"]
if site["type"] in ["param", "mutable"]
},
)

model = _substitute_default_key(model)

constrained_values = {
k: v["value"]
for k, v in model_trace.items()
Expand Down
13 changes: 12 additions & 1 deletion test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
random_haiku_module,
)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta

pytestmark = pytest.mark.filterwarnings(
"ignore:jax.tree_.+ is deprecated:FutureWarning"
Expand Down Expand Up @@ -256,6 +257,11 @@ def model():
else:
assert set(tr.keys()) == {"nn$params", "x", "y"}

# test svi
guide = AutoDelta(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi.run(random.PRNGKey(100), 10)


@pytest.mark.parametrize("dropout", [True, False])
@pytest.mark.parametrize("batchnorm", [True, False])
Expand Down Expand Up @@ -300,3 +306,8 @@ def model():
assert tr["nn$state"]["type"] == "mutable"
else:
assert set(tr.keys()) == {"nn$params", "x", "y"}

# test svi
guide = AutoDelta(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi.run(random.PRNGKey(100), 10)

0 comments on commit d40f0e9

Please sign in to comment.