Skip to content

Commit

Permalink
In progress commit on branch delay.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 24, 2022
1 parent 92ee12c commit b0758f4
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 13 deletions.
16 changes: 14 additions & 2 deletions diffrax/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def loop(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
Expand Down Expand Up @@ -194,6 +195,7 @@ def _loop_backsolve_bwd(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
Expand Down Expand Up @@ -232,6 +234,7 @@ def _loop_backsolve_bwd(
solver=solver,
stepsize_controller=stepsize_controller,
discrete_terminating_event=discrete_terminating_event,
delays=delays,
terms=adjoint_terms,
dt0=None if dt0 is None else -dt0,
max_steps=max_steps,
Expand Down Expand Up @@ -398,12 +401,16 @@ def __init__(self, **kwargs):
)
self.kwargs = kwargs

def loop(self, *, args, terms, saveat, init_state, **kwargs):
def loop(self, *, args, terms, saveat, init_state, delays, **kwargs):
if saveat.steps or saveat.dense:
raise NotImplementedError(
"Cannot use `adjoint=BacksolveAdjoint()` with "
"`saveat=Steps(steps=True)` or `saveat=Steps(dense=True)`."
)
if delays is not None:
raise NotImplementedError(
"Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
)

y = init_state.y
sentinel = object()
Expand All @@ -412,7 +419,12 @@ def loop(self, *, args, terms, saveat, init_state, **kwargs):
)

final_state, aux_stats = _loop_backsolve(
(y, args, terms), self=self, saveat=saveat, init_state=init_state, **kwargs
(y, args, terms),
self=self,
saveat=saveat,
init_state=init_state,
delays=delays,
**kwargs,
)

# We only allow backpropagation through `ys`; in particular not through
Expand Down
65 changes: 54 additions & 11 deletions diffrax/integrate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools as ft
import warnings
from typing import Optional
from typing import Callable, Optional, Sequence

import equinox as eqx
import jax
Expand Down Expand Up @@ -35,7 +35,7 @@
ConstantStepSize,
StepTo,
)
from .term import AbstractTerm, WrapTerm
from .term import AbstractTerm, VectorFieldWrapper, WrapTerm


class _State(eqx.Module):
Expand Down Expand Up @@ -102,6 +102,7 @@ def loop(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
Expand Down Expand Up @@ -130,21 +131,52 @@ def body_fun(state, inplace):
# step sizes, all that jazz.
#

(y, y_error, dense_info, solver_state, solver_result) = solver.step(
terms,
state.tprev,
state.tnext,
state.y,
args,
state.solver_state,
state.made_jump,
)
if delays is None:
(y, y_error, dense_info, solver_state, solver_result) = solver.step(
terms,
state.tprev,
state.tnext,
state.y,
args,
state.solver_state,
state.made_jump,
)
else:
# TODO: double-check that these are the correct `ts_size` and
# `direction`.
history = DenseInterpolation(
ts=state.dense_ts,
ts_size=state.dense_save_index + 1,
interpolation_cls=solver.interpolation_cls,
infos=state.dense_infos,
direction=1,
)
# Evaluate the dense interpolation of the solution computed so far at every
# delayed time. Each `delay` is called as `delay(t, y, args)` with the state
# at the start of the step, and its result is passed straight to
# `history.evaluate`.
#
# Fix: the loop used to call `history_val.append(history_val)`, i.e. it
# appended each evaluated value to itself instead of collecting it into
# `history_vals`. The collection therefore stayed empty (or the `.append`
# call failed outright on a value with no such method). Build the tuple
# directly so there is no accumulator to mistype.
history_vals = tuple(
    history.evaluate(delay(state.tprev, state.y, args)) for delay in delays
)

is_vf_wrapper = lambda x: isinstance(x, VectorFieldWrapper)

def _apply_history(x):
if is_vf_wrapper(x):
vector_field = jtu.Partial(x.vector_field, history=history_vals)
return VectorFieldWrapper(vector_field)
else:
return x

terms_ = jtu.tree_map(_apply_history, terms, is_leaf=is_vf_wrapper)
# TODO: write down implicit problem wrt dense_info, using `terms_`
(y, y_error, dense_info, solver_state, solver_result) = terms_ # ...

# e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
# we get a negative value for y, and then get a NaN vector field. (And then
# everything breaks.) See #143.
y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)

# TODO: handle discontinuity detection for delays
error_order = solver.error_order(terms)
(
keep_step,
Expand Down Expand Up @@ -510,6 +542,7 @@ def diffeqsolve(
stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(),
discrete_terminating_event: Optional[AbstractDiscreteTerminatingEvent] = None,
delays: Optional[Sequence[Callable]] = None,
max_steps: Optional[int] = 16**3,
throw: bool = True,
solver_state: Optional[PyTree] = None,
Expand Down Expand Up @@ -563,6 +596,10 @@ def diffeqsolve(
- `discrete_terminating_event`: A discrete event at which to terminate the solve
early. See the page on [Events](./events.md) for more information.
- `delays`: A tuple of functions, which describe the delays used in a delay
differential equation. See the page on [Delays](./delays.md) for more
information.
- `max_steps`: The maximum number of steps to take before quitting the computation
unconditionally.
Expand Down Expand Up @@ -626,6 +663,9 @@ def diffeqsolve(
# Initial set-up
#

if delays is not None and not saveat.dense:
raise ValueError("Delay differential equations require saving dense output")

# Error checking
if dt0 is not None:
msg = (
Expand Down Expand Up @@ -728,6 +768,8 @@ def _promote(yi):
terms,
is_leaf=lambda x: isinstance(x, AbstractTerm),
)
if delays is not None:
delays = [lambda t, y, args, fn=fn: fn(t, y, args) * direction for fn in delays]

# Stepsize controller gets an opportunity to modify the solver.
# Note that at this point the solver could be anything so we must check any
Expand Down Expand Up @@ -841,6 +883,7 @@ def _promote(yi):
solver=solver,
stepsize_controller=stepsize_controller,
discrete_terminating_event=discrete_terminating_event,
delays=delays,
saveat=saveat,
t0=t0,
t1=t1,
Expand Down
14 changes: 14 additions & 0 deletions diffrax/term.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ def is_vf_expensive(
return False


class VectorFieldWrapper(eqx.Module):
    """Callable module that holds a vector field and forwards calls to it.

    Calling an instance as `wrapper(t, y, args)` returns whatever the wrapped
    `vector_field(t, y, args)` returns; no extra processing is applied.
    """

    # The wrapped callable, invoked with the positional arguments (t, y, args).
    vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]

    def __call__(self, t, y, args):
        """Evaluate the wrapped vector field at `(t, y, args)`."""
        wrapped = self.vector_field
        return wrapped(t, y, args)


class ODETerm(AbstractTerm):
r"""A term representing $f(t, y(t), args) \mathrm{d}t$. That is to say, the term
appearing on the right hand side of an ODE, in which the control is time.
Expand All @@ -169,6 +176,9 @@ class ODETerm(AbstractTerm):
"""
vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]

def __init__(self, vector_field):
    """Store `vector_field` on the term, wrapped in a `VectorFieldWrapper`.

    **Arguments:**

    - `vector_field`: the callable to wrap; it is later invoked as
        `vector_field(t, y, args)` by `vf`.
    """
    wrapped = VectorFieldWrapper(vector_field)
    self.vector_field = wrapped

def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
    """Return the stored vector field evaluated at `(t, y, args)`."""
    field = self.vector_field
    return field(t, y, args)

Expand Down Expand Up @@ -200,6 +210,10 @@ class _ControlTerm(AbstractTerm):
vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]
control: AbstractPath

def __init__(self, vector_field, control):
    """Store the control and the vector field on the term.

    **Arguments:**

    - `vector_field`: the callable to wrap in a `VectorFieldWrapper`; it is
        later invoked as `vector_field(t, y, args)` by `vf`.
    - `control`: stored unchanged as `self.control`.
    """
    self.control = control
    wrapped = VectorFieldWrapper(vector_field)
    self.vector_field = wrapped

def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
    """Return the stored vector field evaluated at `(t, y, args)`."""
    field = self.vector_field
    return field(t, y, args)

Expand Down

0 comments on commit b0758f4

Please sign in to comment.