Skip to content

Commit

Permalink
Different supports in component distributions for mixture models (#1791)
Browse files Browse the repository at this point in the history
* different support of component distributions in mixture distributions

* Refactor MixtureGeneral class to support different component distribution supports + unit tests

* Refactor MixtureGeneral class to support different component distribution supports and changes in unit tests

* CI fail fix

* masking probs in mixture tesst
  • Loading branch information
Qazalbash committed May 12, 2024
1 parent 2b85765 commit 1d0cedb
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 11 deletions.
63 changes: 52 additions & 11 deletions numpyro/distributions/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ class MixtureGeneral(_MixtureBase):
``mixture_size``.
:param component_distributions: A list of ``mixture_size``
:class:`~numpyro.distributions.Distribution` objects.
:param support: A :class:`~numpyro.distributions.constraints.Constraint`
object specifying the support of the mixture distribution. If not
provided, the support will be inferred from the component distributions.
**Example**
Expand All @@ -288,13 +291,36 @@ class MixtureGeneral(_MixtureBase):
>>> mixture = dist.MixtureGeneral(mixing_dist, component_dists)
>>> mixture.sample(jax.random.PRNGKey(42)).shape
()
.. doctest::
>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(2) / 2.)
>>> component_dists = [
... dist.Normal(loc=0.0, scale=1.0),
... dist.HalfNormal(scale=0.3),
... ]
>>> mixture = dist.MixtureGeneral(mixing_dist, component_dists, support=dist.constraints.real)
>>> mixture.sample(jax.random.PRNGKey(42)).shape
()
"""

pytree_data_fields = ("_mixing_distribution", "_component_distributions")
pytree_data_fields = (
"_mixing_distribution",
"_component_distributions",
"_support",
)
pytree_aux_fields = ("_mixture_size",)

def __init__(
self, mixing_distribution, component_distributions, *, validate_args=None
self,
mixing_distribution,
component_distributions,
*,
support=None,
validate_args=None,
):
_check_mixing_distribution(mixing_distribution)

Expand All @@ -308,7 +334,7 @@ def __init__(
for d in component_distributions:
if not isinstance(d, Distribution):
raise ValueError(
"All elements of 'component_distributions' must be instaces of "
"All elements of 'component_distributions' must be instances of "
"numpyro.distributions.Distribution subclasses"
)
if len(component_distributions) != self.mixture_size:
Expand All @@ -320,11 +346,19 @@ def __init__(
# TODO: It would be good to check that the support of all the component
# distributions match, but for now we just check the type, since __eq__
# isn't consistently implemented for all support types.
support_type = type(component_distributions[0].support)
if any(
type(d.support) is not support_type for d in component_distributions[1:]
):
raise ValueError("All component distributions must have the same support.")
self._support = support
if support is None:
support_type = type(component_distributions[0].support)
if any(
type(d.support) is not support_type for d in component_distributions[1:]
):
raise ValueError(
"All component distributions must have the same support."
)
else:
assert isinstance(
support, constraints.Constraint
), "support must be a Constraint object"

self._mixing_distribution = mixing_distribution
self._component_distributions = component_distributions
Expand Down Expand Up @@ -357,6 +391,8 @@ def component_distributions(self):

@constraints.dependent_property
def support(self):
if self._support is not None:
return self._support
return self.component_distributions[0].support

@property
Expand Down Expand Up @@ -389,9 +425,14 @@ def component_sample(self, key, sample_shape=()):
return jnp.stack(samples, axis=self.mixture_dim)

def component_log_probs(self, value):
component_log_probs = jnp.stack(
[d.log_prob(value) for d in self.component_distributions], axis=-1
)
component_log_probs = []
for d in self.component_distributions:
log_prob = d.log_prob(value)
if (self._support is not None) and (not d._validate_args):
mask = d.support(value)
log_prob = jnp.where(mask, log_prob, -jnp.inf)
component_log_probs.append(log_prob)
component_log_probs = jnp.stack(component_log_probs, axis=-1)
return jax.nn.log_softmax(self.mixing_distribution.logits) + component_log_probs


Expand Down
45 changes: 45 additions & 0 deletions test/test_distributions_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def get_normal(batch_shape):
return normal


def get_half_normal(batch_shape):
"""Get parameterized HalfNormal with given batch shape."""
scale = jnp.ones(batch_shape)
half_normal = dist.HalfNormal(scale=scale)
return half_normal


def get_mvn(batch_shape):
"""Get parameterized MultivariateNormal with given batch shape."""
dimensions = 2
Expand Down Expand Up @@ -78,6 +85,44 @@ def test_mixture_broadcast_batch_shape(
_test_mixture(mixing_distribution, component_distribution)


@pytest.mark.parametrize("batch_shape", [(), (1,), (7,), (2, 5)])
@pytest.mark.filterwarnings(
"ignore:Out-of-support values provided to log prob method."
" The value argument should be within the support.:UserWarning"
)
def test_mixture_with_different_support(batch_shape):
mixing_probabilities = jnp.ones(2) / 2
mixing_distribution = dist.Categorical(probs=mixing_probabilities)
component_distribution = [
get_normal(batch_shape),
get_half_normal(batch_shape),
]
mixture = dist.MixtureGeneral(
mixing_distribution=mixing_distribution,
component_distributions=component_distribution,
support=dist.constraints.real,
)
assert mixture.batch_shape == batch_shape
sample_shape = (11,)
component_distribution[0]._validate_args = True
component_distribution[1]._validate_args = True
xx = component_distribution[0].sample(rng_key, sample_shape)
log_prob_0 = component_distribution[0].log_prob(xx)
log_prob_1 = component_distribution[1].log_prob(xx)
expected_log_prob = jax.scipy.special.logsumexp(
jnp.stack(
[
log_prob_0 + jnp.log(mixing_probabilities[0]),
log_prob_1 + jnp.log(mixing_probabilities[1]),
],
axis=-1,
),
axis=-1,
)
result = mixture.log_prob(xx)
assert jnp.allclose(result, expected_log_prob)


def _test_mixture(mixing_distribution, component_distribution):
# Create mixture
mixture = dist.Mixture(
Expand Down

0 comments on commit 1d0cedb

Please sign in to comment.