Skip to content

Commit

Permalink
Less than eq constaint (#1822)
Browse files Browse the repository at this point in the history
* chore: less than eq constraint

* chore: unit tests for constraints and transforms
  • Loading branch information
Qazalbash committed Jun 25, 2024
1 parent 616a811 commit 209dad9
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 0 deletions.
11 changes: 11 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,16 @@ def __eq__(self, other):
return jnp.array_equal(self.upper_bound, other.upper_bound)


class _LessThanEq(_LessThan):
def __call__(self, x):
return x <= self.upper_bound

def __eq__(self, other):
if not isinstance(other, _LessThanEq):
return False
return jnp.array_equal(self.upper_bound, other.upper_bound)


class _IntegerInterval(Constraint):
is_discrete = True

Expand Down Expand Up @@ -768,6 +778,7 @@ def tree_flatten(self):
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
less_than_eq = _LessThanEq
independent = _IndependentConstraint
integer_interval = _IntegerInterval
integer_greater_than = _IntegerGreaterThan
Expand Down
1 change: 1 addition & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,7 @@ def _transform_to_greater_than(constraint):


@biject_to.register(constraints.less_than)
@biject_to.register(constraints.less_than_eq)
def _transform_to_less_than(constraint):
return ComposeTransform(
[
Expand Down
1 change: 1 addition & 0 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])):
"greater_than": T(constraints.greater_than, (_a(0.0),), dict()),
"greater_than_eq": T(constraints.greater_than_eq, (_a(0.0),), dict()),
"less_than": T(constraints.less_than, (_a(-1.0),), dict()),
"less_than_eq": T(constraints.less_than_eq, (_a(-1.0),), dict()),
"independent": T(
constraints.independent,
(constraints.greater_than(jnp.zeros((2,))),),
Expand Down
1 change: 1 addition & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def test_batched_recursive_linear_transform():
(constraints.interval(8, 13), (17,)),
(constraints.l1_ball, (4,)),
(constraints.less_than(-1), ()),
(constraints.less_than_eq(-1), ()),
(constraints.lower_cholesky, (15,)),
(constraints.open_interval(3, 4), ()),
(constraints.ordered_vector, (5,)),
Expand Down

0 comments on commit 209dad9

Please sign in to comment.