Inverse Dirichlet Adaptive Loss #504

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Binary file modified .DS_Store
Binary file not shown.
55 changes: 55 additions & 0 deletions src/pinns_pde_solve.jl
@@ -298,6 +298,35 @@ SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every; weight_
pde_loss_weights=pde_loss_weights, bc_loss_weights=bc_loss_weights, additional_loss_weights=additional_loss_weights)
end

"""
A way of adaptively reweighting the components of the loss function in the total sum such that the PDE and BC loss weights are based on the spread of the back-propagated gradients. In particular, each weight is proposed as the ratio of the largest gradient standard deviation over all objectives to the standard deviation of that objective's gradient, so that the variances of the weighted gradients are equalized across all objectives; the proposed weights are then smoothed with an exponential moving average.

* `reweight_every`: how often to reweight the PDE and BC loss functions, measured in iterations. Reweighting is somewhat expensive since it involves evaluating the gradient of each component loss function,
* `weight_change_inertia`: a real number that represents the inertia of the exponential moving average of the weight changes,
* `pde_loss_weights`: either a scalar (which will be broadcast) or a vector of length equal to the number of PDE equations, which describes the initial weight the respective PDE loss has in the full loss sum,
* `bc_loss_weights`: either a scalar (which will be broadcast) or a vector of length equal to the number of BC equations, which describes the initial weight the respective BC loss has in the full loss sum,
* `additional_loss_weights`: a scalar which describes the weight the additional loss function has in the full loss sum; this weight is not adapted and stays constant with this adaptive loss,

from paper
Inverse Dirichlet weighting enables reliable training of physics informed neural networks
Suryanarayana Maddu, Dominik Sturm, Christian L Müller, and Ivo F Sbalzarini
https://iopscience.iop.org/article/10.1088/2632-2153/ac3712/pdf
with code reference
https://github.com/mosaic-group/inverse-dirichlet-pinn
"""
mutable struct InverseDirichletAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
reweight_every::Int64
Member: Rename this reweights_steps? That's a bit more consistent with the rest of the ecosystem.

Author: I chose reweight_every to be consistent with the variable names chosen for MiniMaxAdaptiveLoss and GradientScaleAdaptiveLoss. Is the naming convention reweight_steps across other files?

Member: progress_steps, timeseries_steps, p_steps, etc. It's not the biggest deal, but it's slightly more consistent with what else is out there in SciML.

weight_change_inertia::T
pde_loss_weights::Vector{T}
bc_loss_weights::Vector{T}
additional_loss_weights::Vector{T}
SciMLBase.@add_kwonly function InverseDirichletAdaptiveLoss{T}(reweight_every; weight_change_inertia=0.5, pde_loss_weights=1, bc_loss_weights=1, additional_loss_weights=1) where T <: Real
new(convert(Int64, reweight_every), convert(T, weight_change_inertia), vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
end
end
# default to Float64
InverseDirichletAdaptiveLoss(args...; kwargs...) = InverseDirichletAdaptiveLoss{Float64}(args...; kwargs...)
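
A minimal usage sketch (illustration only, not part of this diff), assuming the usual NeuralPDE workflow in which the adaptive loss is passed to `PhysicsInformedNN` through its `adaptive_loss` keyword; the network and training strategy below are placeholders:

using NeuralPDE, Flux

# Reweight every 100 iterations; `weight_change_inertia` keeps its default of 0.5.
adaloss = NeuralPDE.InverseDirichletAdaptiveLoss(100; weight_change_inertia=0.5,
                                                 pde_loss_weights=1, bc_loss_weights=1)

chain = Flux.Chain(Flux.Dense(2, 16, Flux.σ), Flux.Dense(16, 1))  # placeholder 2-input network
discretization = NeuralPDE.PhysicsInformedNN(chain, NeuralPDE.GridTraining(0.1);
                                             adaptive_loss=adaloss)
# `SciMLBase.discretize(pde_system, discretization)` then builds the OptimizationProblem,
# with the loss weights updated every `reweight_every` iterations during training.
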

"""
A way of adaptively reweighting the components of the loss function in the total sum such that the loss weights are maximized by an internal optimiser, which leads to a behavior where loss functions that have not been satisfied get a greater weight,
@@ -1406,6 +1435,31 @@ function discretize_inner_functions(pde_system::PDESystem, discretization::Physi
end
nothing
end
elseif adaloss isa InverseDirichletAdaptiveLoss
# TODO I think the numerator and denominator are not quite right here.
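# Inverse Dirichlet rule as implemented below: each objective's proposed weight is
# max_i std(grad_i) / std(grad_k), so the loss term whose gradient has the largest
# spread keeps a weight near one and every other term is scaled up until the
# standard deviations of the weighted gradients roughly match; the proposal is then
# blended with the previous weights through an exponential moving average controlled
# by `weight_change_inertia`.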
weight_change_inertia = discretization.adaptive_loss.weight_change_inertia
function run_loss_inverse_dirichlet_adaptive_loss(θ)
if iteration[1] % adaloss.reweight_every == 0
pde_grads_std = [std(Zygote.gradient(pde_loss_function, θ)[1]) for pde_loss_function in pde_loss_functions]
bc_grads_std = [std(Zygote.gradient(bc_loss_function, θ)[1]) for bc_loss_function in bc_loss_functions]
pde_grads_std_max = maximum(pde_grads_std)
bc_grads_std_max = maximum(bc_grads_std)
grads_std_max = max(pde_grads_std_max, bc_grads_std_max)

bc_loss_weights_proposed = grads_std_max ./ (bc_grads_std)
adaloss.bc_loss_weights .= weight_change_inertia .* adaloss.bc_loss_weights .+ (1 .- weight_change_inertia) .* bc_loss_weights_proposed

pde_loss_weights_proposed = grads_std_max ./ (pde_grads_std)
adaloss.pde_loss_weights .= weight_change_inertia .* adaloss.pde_loss_weights .+ (1 .- weight_change_inertia) .* pde_loss_weights_proposed

logscalar(logger, grads_std_max, "adaptive_loss/grads_std_max", iteration[1])
logvector(logger, pde_grads_std, "adaptive_loss/pde_grad_std", iteration[1])
logvector(logger, bc_grads_std, "adaptive_loss/bc_grad_std", iteration[1])
logvector(logger, adaloss.bc_loss_weights, "adaptive_loss/bc_loss_weights", iteration[1])
logvector(logger, adaloss.pde_loss_weights, "adaptive_loss/pde_loss_weights", iteration[1])
end
nothing
end
elseif adaloss isa MiniMaxAdaptiveLoss
pde_max_optimiser = adaloss.pde_max_optimiser
bc_max_optimiser = adaloss.bc_max_optimiser
@@ -1488,6 +1542,7 @@ function discretize_inner_functions(pde_system::PDESystem, discretization::Physi
inner_pde_loss_functions=_pde_loss_functions, inner_bc_loss_functions=_bc_loss_functions)
end


# Convert a PDE problem into an OptimizationProblem
function SciMLBase.discretize(pde_system::PDESystem, discretization::PhysicsInformedNN)
discretized_functions = discretize_inner_functions(pde_system, discretization)
8 changes: 7 additions & 1 deletion test/adaptive_loss_tests.jl
@@ -18,7 +18,9 @@ using Random
nonadaptive_loss = NeuralPDE.NonAdaptiveLoss(pde_loss_weights=1, bc_loss_weights=1)
gradnormadaptive_loss = NeuralPDE.GradientScaleAdaptiveLoss(100, pde_loss_weights=1e3, bc_loss_weights=1)
adaptive_loss = NeuralPDE.MiniMaxAdaptiveLoss(100; pde_loss_weights=1, bc_loss_weights=1)
adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss,adaptive_loss]
invdirichletadaptive_loss = NeuralPDE.InverseDirichletAdaptiveLoss(100, pde_loss_weights=1e3, bc_loss_weights=1)
adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss, adaptive_loss, invdirichletadaptive_loss]

maxiters=4000
seed=60

@@ -96,12 +98,16 @@ error_results_no_logs = map(test_2d_poisson_equation_adaptive_loss_no_logs_run_s
@show error_results_no_logs[1][:total_diff_rel]
@show error_results_no_logs[2][:total_diff_rel]
@show error_results_no_logs[3][:total_diff_rel]
@show error_results_no_logs[4][:total_diff_rel]

# accuracy tests, these work for this specific seed but might not for others
# note that this doesn't test that the adaptive losses are outperforming the nonadaptive loss, which is not guaranteed, and seed/arch/hyperparam/pde etc dependent
@test error_results_no_logs[1][:total_diff_rel] < 0.4
@test error_results_no_logs[2][:total_diff_rel] < 0.4
@test error_results_no_logs[3][:total_diff_rel] < 0.4
@test error_results_no_logs[4][:total_diff_rel] < 0.4

#plots_diffs[1][:plot]
#plots_diffs[2][:plot]
#plots_diffs[3][:plot]
#plots_diffs[4][:plot]