Skip to content

Commit

Permalink
Merge pull request #582 from tlm-adjoint/Tidying
Browse files Browse the repository at this point in the history
Tidying
  • Loading branch information
jrmaddison committed Jun 24, 2024
2 parents ccb5d9b + e93db9e commit d76b2fa
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 35 deletions.
2 changes: 1 addition & 1 deletion docs/source/examples/8_hessian_uq.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@
"source": [
"Next we solve the optimization problem. The considered problem seeks to infer the transport in the advection-diffusion equation, in terms of a stream function. Initially the tracer is concentrated on the left, and the observation taken a time $T$ later has the tracer moved to the right. The variable `m_0`, which will define our initial guess for the optimization, is set so that the velocity at the center has approximately the correct magnitude for this transport.\n",
"\n",
"As usual we need to define an appropriate inner product associated with derivatives. Here the prior defines a natural inner product – specifically we can use the prior covariance, $B$ to define an inner product for the dual space $V^*$."
"As usual we need to define an appropriate inner product associated with derivatives. Here the prior defines a natural inner product – specifically we can use the prior covariance, $B$, to define an inner product for the dual space $V^*$."
]
},
{
Expand Down
10 changes: 10 additions & 0 deletions tests/firedrake/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,13 @@ def test_5_optimization(setup_test, tmp_path):
@seed_test
def test_6_custom_operations(setup_test, tmp_path):
run_example_notebook("6_custom_operations.ipynb", tmp_path)


@pytest.mark.firedrake
@pytest.mark.example
@pytest.mark.skipif(complex_mode, reason="real only")
@pytest.mark.skipif(DEFAULT_COMM.size > 1, reason="serial only")
@pytest.mark.skip # Long example
@seed_test
def test_8_hessian_uq(setup_test, tmp_path):
run_example_notebook("8_hessian_uq.ipynb", tmp_path)
10 changes: 5 additions & 5 deletions tlm_adjoint/block_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
var_locked, var_zero)
from .manager import manager_disabled
from .petsc import (
PETScOptions, PETScVec, PETScVecInterface, attach_destroy_finalizer)
PETScOptions, PETScVecInterface, attach_destroy_finalizer)

from abc import ABC, abstractmethod
from collections.abc import Mapping, MutableMapping, Sequence
Expand Down Expand Up @@ -950,9 +950,9 @@ def pc_fn(u, b):
self._A.nullspace.correct_soln(u)
self._A.nullspace.correct_rhs(b_c)

u_petsc = PETScVec(self._A.arg_space)
u_petsc = self._A.arg_space.new_vec()
u_petsc.to_petsc(u)
b_petsc = PETScVec(self._A.action_space)
b_petsc = self._A.action_space.new_vec()
b_petsc.to_petsc(b_c)
del b_c

Expand Down Expand Up @@ -1310,8 +1310,8 @@ def solve(self, u, v):
u = packed(u)
v = packed(v)

u_petsc = PETScVec(self._A.arg_space)
v_petsc = PETScVec(self._A.arg_space)
u_petsc = self._A.arg_space.new_vec()
v_petsc = self._A.arg_space.new_vec()
with paused_space_type_checking():
u_petsc.to_petsc(u)
v_petsc.to_petsc(v)
Expand Down
6 changes: 3 additions & 3 deletions tlm_adjoint/checkpoint_schedules/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,11 +374,11 @@ def finalize(self, n):

if n < 1:
raise ValueError("n must be positive")
if self._max_n is None:
if self._n >= n:
if self.max_n is None:
if self.n >= n:
self._n = n
self._max_n = n
else:
raise RuntimeError("Invalid checkpointing state")
elif self._n != n or self._max_n != n:
elif self.n != n or self.max_n != n:
raise RuntimeError("Invalid checkpointing state")
8 changes: 4 additions & 4 deletions tlm_adjoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _store(self, *, x_id=None, key=None, value, refs=True, copy):
return key, self._storage[key]

def _add_initial_condition(self, *, x_id, value, refs=True, copy):
if self._store_ics and x_id not in self._seen_ics:
if self.store_ics and x_id not in self._seen_ics:
key, _ = self._store(x_id=x_id, value=value, refs=refs, copy=copy)
if key not in self._refs_keys:
self._cp_keys.add(key)
Expand All @@ -264,7 +264,7 @@ def add_equation(self, n, i, eq, *, deps=None, nl_deps=None):

self.update_keys(n, i, eq)

if self._store_ics:
if self.store_ics:
for eq_x in eq.X():
self._seen_ics.add(var_id(eq_x))

Expand Down Expand Up @@ -295,7 +295,7 @@ def add_equation_data(self, n, i, eq, *, nl_deps=None):
if nl_deps is None:
nl_deps = eq_nl_deps

if self._store_ics:
if self.store_ics:
for eq_dep, dep in zip(eq_nl_deps, nl_deps):
self._add_initial_condition(
x_id=var_id(eq_dep), value=dep,
Expand All @@ -314,7 +314,7 @@ def copy(x):
def copy(x):
return _copy

if self._store_data:
if self.store_data:
if (n, i) in self._data:
raise RuntimeError("Non-linear dependency data already stored")

Expand Down
2 changes: 1 addition & 1 deletion tlm_adjoint/firedrake/block_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class DirichletBCNullspace(Nullspace):
bcs : :class:`firedrake.bcs.DirichletBC` or \
Sequence[:class:`firedrake.bcs.DirichletBC`]
Homogeneous Dirichlet boundary conditions
Homogeneous Dirichlet boundary conditions.
alpha : scalar
Defines the linear constraint matrix :math:`S = \\alpha M`.
"""
Expand Down
2 changes: 1 addition & 1 deletion tlm_adjoint/hessian_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def spectral_approximation_solve(self, b):
----------
b : variable or Sequence[variable]
Defines :math:`b`.
The conjugate of the right-hand-side :math:`b`.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion tlm_adjoint/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def rdtype(self):
"""The real data type associated with the space.
"""

return self._dtype(0.0).real.dtype.type
return self.dtype(0.0).real.dtype.type

@property
def comm(self):
Expand Down
10 changes: 5 additions & 5 deletions tlm_adjoint/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .hessian import GeneralHessian as Hessian
from .manager import manager as _manager
from .petsc import (
PETScOptions, PETScVec, PETScVecInterface, attach_destroy_finalizer,
PETScOptions, PETScVecInterface, attach_destroy_finalizer,
petsc_option_setdefault)
from .manager import (
compute_gradient, manager_disabled, reset_manager, restore_manager,
Expand Down Expand Up @@ -528,13 +528,13 @@ def objective_gradient(taols, x, g):
taols.setFromOptions()
taols.setUp()

x = PETScVec(vec_interface)
x = vec_interface.new_vec()
x.to_petsc(X)

g = PETScVec(vec_interface)
g = vec_interface.new_vec()
g.to_petsc(old_Fp_val)

s = PETScVec(vec_interface)
s = vec_interface.new_vec()
s.to_petsc(minus_P)
s.vec.scale(-1.0)

Expand Down Expand Up @@ -1105,7 +1105,7 @@ def solve(self, M):
"""

M = packed(M)
x = PETScVec(self._vec_interface)
x = self._vec_interface.new_vec()
x.to_petsc(M)

self._M[0] = tuple(var_new(m, static=var_is_static(m),
Expand Down
2 changes: 1 addition & 1 deletion tlm_adjoint/overloaded_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def rdtype(self):
"""The real data type associated with the space.
"""

return self._dtype(0.0).real.dtype.type
return self.dtype(0.0).real.dtype.type

@property
def comm(self):
Expand Down
26 changes: 13 additions & 13 deletions tlm_adjoint/tlm_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def new(self, cp_method=None, cp_parameters=None):
raise TypeError("cp_parameters must be supplied if cp_method is "
"supplied")

return EquationManager(comm=self._comm, cp_method=cp_method,
return EquationManager(comm=self.comm, cp_method=cp_method,
cp_parameters=cp_parameters)

@gc_disabled
Expand Down Expand Up @@ -225,7 +225,7 @@ def reset(self, cp_method=None, cp_parameters=None):
"supplied")

self.drop_references()
garbage_cleanup(self._comm)
garbage_cleanup(self.comm)

self._annotation_state = AnnotationState.STOPPED
self._tlm_state = TangentLinearState.STOPPED
Expand Down Expand Up @@ -364,20 +364,20 @@ def configure_checkpointing(self, cp_method, cp_parameters):
cp_path = cp_parameters.get("path", "checkpoints~")
cp_format = cp_parameters.get("format", "hdf5")

self._comm.barrier()
if self._comm.rank == 0:
self.comm.barrier()
if self.comm.rank == 0:
if not os.path.exists(cp_path):
os.makedirs(cp_path)
self._comm.barrier()
self.comm.barrier()

if cp_format == "pickle":
cp_disk = PickleCheckpoints(
os.path.join(cp_path, f"checkpoint_{self._id:d}_"),
comm=self._comm)
comm=self.comm)
elif cp_format == "hdf5":
cp_disk = HDF5Checkpoints(
os.path.join(cp_path, f"checkpoint_{self._id:d}_"),
comm=self._comm)
comm=self.comm)
else:
raise ValueError(f"Unrecognized checkpointing format: "
f"{cp_format:s}")
Expand Down Expand Up @@ -932,7 +932,7 @@ def action_forward(cp_action):

storage_state = storage.popleft()
assert storage_state == (n1, i)
garbage_cleanup(self._comm)
garbage_cleanup(self.comm)
cp_n = cp_action.n1
assert cp_n <= n + 1
assert cp_n < n + 1 or len(storage) == 0
Expand All @@ -959,7 +959,7 @@ def action_read(cp_action):

storage = ReplayStorage(self._blocks, cp_n, n + 1,
transpose_deps=transpose_deps)
garbage_cleanup(self._comm)
garbage_cleanup(self.comm)
initialize_storage_cp = True
storage.update(self._cp.initial_conditions(cp=False,
refs=True,
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def new_block(self):
"""

self.drop_references()
garbage_cleanup(self._comm)
garbage_cleanup(self.comm)

if self._annotation_state in {AnnotationState.STOPPED,
AnnotationState.FINAL}:
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def finalize(self):
"""

self.drop_references()
garbage_cleanup(self._comm)
garbage_cleanup(self.comm)

if self._annotation_state == AnnotationState.FINAL:
return
Expand Down Expand Up @@ -1299,7 +1299,7 @@ def callback(J_i, n, i, eq, adj_X):
dJ[J_i] = tuple(map(var_copy, adj_X))

self.drop_references()
garbage_cleanup(self._comm)
garbage_cleanup(self.comm)

for B in Bs:
assert B.is_empty()
Expand Down Expand Up @@ -1337,7 +1337,7 @@ def action_pass(cp_action):
pass
del action

garbage_cleanup(self._comm)
garbage_cleanup(self.comm)
return Js_packed.unpack(tuple(M_packed.unpack(dJ_) for dJ_ in dJ))


Expand Down

0 comments on commit d76b2fa

Please sign in to comment.