diff --git a/.DS_Store b/.DS_Store
index 75e07537ef..e486903e52 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/src/pinns_pde_solve.jl b/src/pinns_pde_solve.jl
index a7a2dfa9fb..249ae3b32a 100644
--- a/src/pinns_pde_solve.jl
+++ b/src/pinns_pde_solve.jl
@@ -298,6 +298,35 @@ SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every; weight_
                                        pde_loss_weights=pde_loss_weights, bc_loss_weights=bc_loss_weights, additional_loss_weights=additional_loss_weights)
 end
+"""
+A way of adaptively reweighting the components of the loss function in the total sum such that the PDE and BC loss weights are based on the gradient variance. In particular, the weights are chosen so that the
+variances of the back-propagated weighted gradients are equal across all objectives.
+
+* `reweight_every`: how often to reweight the loss functions, measured in iterations. Reweighting is somewhat expensive since it involves evaluating the gradient of each component loss function,
+* `weight_change_inertia`: a real number that represents the inertia of the exponential moving average of the weight changes,
+* `pde_loss_weights`: either a scalar (which will be broadcast) or a vector of length equal to the number of PDE equations, which describes the initial weight the respective PDE loss has in the full loss sum,
+* `bc_loss_weights`: either a scalar (which will be broadcast) or a vector of length equal to the number of BC equations, which describes the initial weight the respective BC loss has in the full loss sum,
+* `additional_loss_weights`: a scalar which describes the weight the additional loss function has in the full loss sum; this is currently not adaptive and will remain constant with this adaptive loss,
+
+from paper
+Inverse Dirichlet weighting enables reliable training of physics informed neural networks
+Suryanarayana Maddu, Dominik Sturm, Christian L. Müller, and Ivo F. Sbalzarini
+https://iopscience.iop.org/article/10.1088/2632-2153/ac3712/pdf
+with code reference
+https://github.com/mosaic-group/inverse-dirichlet-pinn
+"""
+mutable struct InverseDirichletAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
+    reweight_every::Int64
+    weight_change_inertia::T
+    pde_loss_weights::Vector{T}
+    bc_loss_weights::Vector{T}
+    additional_loss_weights::Vector{T}
+    SciMLBase.@add_kwonly function InverseDirichletAdaptiveLoss{T}(reweight_every; weight_change_inertia=0.5, pde_loss_weights=1, bc_loss_weights=1, additional_loss_weights=1) where T <: Real
+        new(convert(Int64, reweight_every), convert(T, weight_change_inertia), vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
+    end
+end
+# default to Float64
+InverseDirichletAdaptiveLoss(args...; kwargs...) = InverseDirichletAdaptiveLoss{Float64}(args...; kwargs...)
 """
 A way of adaptively reweighting the components of the loss function in the total sum such that the loss weights are maximized by an internal optimiser, which leads to a behavior where loss functions that have not been satisfied get a greater weight,
@@ -1406,6 +1435,31 @@ function discretize_inner_functions(pde_system::PDESystem, discretization::Physi
                 end
                 nothing
             end
+        elseif adaloss isa InverseDirichletAdaptiveLoss
+            # TODO I think the numerator and denominator are not quite right here.
+            weight_change_inertia = discretization.adaptive_loss.weight_change_inertia
+            function run_loss_inverse_dirichlet_adaptive_loss(θ)
+                if iteration[1] % adaloss.reweight_every == 0
+                    pde_grads_std = [std(Zygote.gradient(pde_loss_function, θ)[1]) for pde_loss_function in pde_loss_functions]
+                    bc_grads_std = [std(Zygote.gradient(bc_loss_function, θ)[1]) for bc_loss_function in bc_loss_functions]
+                    pde_grads_std_max = maximum(pde_grads_std)
+                    bc_grads_std_max = maximum(bc_grads_std)
+                    grads_std_max = max(pde_grads_std_max, bc_grads_std_max)
+
+                    bc_loss_weights_proposed = grads_std_max ./ (bc_grads_std)
+                    adaloss.bc_loss_weights .= weight_change_inertia .* adaloss.bc_loss_weights .+ (1 .- weight_change_inertia) .* bc_loss_weights_proposed
+
+                    pde_loss_weights_proposed = grads_std_max ./ (pde_grads_std)
+                    adaloss.pde_loss_weights .= weight_change_inertia .* adaloss.pde_loss_weights .+ (1 .- weight_change_inertia) .* pde_loss_weights_proposed
+
+                    logscalar(logger, grads_std_max, "adaptive_loss/grads_std_max", iteration[1])
+                    logvector(logger, pde_grads_std, "adaptive_loss/pde_grad_std", iteration[1])
+                    logvector(logger, bc_grads_std, "adaptive_loss/bc_grad_std", iteration[1])
+                    logvector(logger, adaloss.bc_loss_weights, "adaptive_loss/bc_loss_weights", iteration[1])
+                    logvector(logger, adaloss.pde_loss_weights, "adaptive_loss/pde_loss_weights", iteration[1])
+                end
+                nothing
+            end
         elseif adaloss isa MiniMaxAdaptiveLoss
             pde_max_optimiser = adaloss.pde_max_optimiser
             bc_max_optimiser = adaloss.bc_max_optimiser
@@ -1488,6 +1542,7 @@ function discretize_inner_functions(pde_system::PDESystem, discretization::Physi
                            inner_pde_loss_functions=_pde_loss_functions,
                            inner_bc_loss_functions=_bc_loss_functions)
 end
+
 # Convert a PDE problem into an OptimizationProblem
 function SciMLBase.discretize(pde_system::PDESystem, discretization::PhysicsInformedNN)
     discretized_functions = discretize_inner_functions(pde_system, discretization)
diff --git a/test/adaptive_loss_tests.jl b/test/adaptive_loss_tests.jl
index 93ab213210..f25719d5ba 100644
--- a/test/adaptive_loss_tests.jl
+++ b/test/adaptive_loss_tests.jl
@@ -18,7 +18,9 @@ using Random
 nonadaptive_loss = NeuralPDE.NonAdaptiveLoss(pde_loss_weights=1, bc_loss_weights=1)
 gradnormadaptive_loss = NeuralPDE.GradientScaleAdaptiveLoss(100, pde_loss_weights=1e3, bc_loss_weights=1)
 adaptive_loss = NeuralPDE.MiniMaxAdaptiveLoss(100; pde_loss_weights=1, bc_loss_weights=1)
-adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss,adaptive_loss]
+invdirichletadaptive_loss = NeuralPDE.InverseDirichletAdaptiveLoss(100, pde_loss_weights=1e3, bc_loss_weights=1)
+adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss, adaptive_loss, invdirichletadaptive_loss]
+
 maxiters=4000
 seed=60
@@ -96,12 +98,16 @@ error_results_no_logs = map(test_2d_poisson_equation_adaptive_loss_no_logs_run_s
 @show error_results_no_logs[1][:total_diff_rel]
 @show error_results_no_logs[2][:total_diff_rel]
 @show error_results_no_logs[3][:total_diff_rel]
+@show error_results_no_logs[4][:total_diff_rel]
+
 # accuracy tests, these work for this specific seed but might not for others
 # note that this doesn't test that the adaptive losses are outperforming the nonadaptive loss, which is not guaranteed, and seed/arch/hyperparam/pde etc dependent
 @test error_results_no_logs[1][:total_diff_rel] < 0.4
 @test error_results_no_logs[2][:total_diff_rel] < 0.4
 @test error_results_no_logs[3][:total_diff_rel] < 0.4
+@test error_results_no_logs[4][:total_diff_rel] < 0.4
 #plots_diffs[1][:plot]
 #plots_diffs[2][:plot]
 #plots_diffs[3][:plot]
+#plots_diffs[4][:plot]
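
For reviewers, here is a minimal, self-contained sketch of the inverse-Dirichlet reweighting rule that the new `elseif adaloss isa InverseDirichletAdaptiveLoss` branch implements. Toy random vectors stand in for the per-loss gradients that the real code obtains from `Zygote.gradient`; the variable names are illustrative only, not package API.

```julia
# Toy sketch of the inverse-Dirichlet reweighting rule; only Statistics is required.
using Statistics

pde_grads = [randn(100), 0.1 .* randn(100)]  # pretend gradients of two PDE losses w.r.t. θ
bc_grads  = [0.01 .* randn(100)]             # pretend gradient of one BC loss w.r.t. θ

pde_grads_std = std.(pde_grads)
bc_grads_std  = std.(bc_grads)
grads_std_max = max(maximum(pde_grads_std), maximum(bc_grads_std))

# Proposed weights equalize the gradient standard deviations across objectives.
pde_loss_weights_proposed = grads_std_max ./ pde_grads_std
bc_loss_weights_proposed  = grads_std_max ./ bc_grads_std

# Blend with the previous weights via an exponential moving average,
# matching the weight_change_inertia update in the new branch.
weight_change_inertia = 0.5
pde_loss_weights = ones(length(pde_grads))
bc_loss_weights  = ones(length(bc_grads))
pde_loss_weights .= weight_change_inertia .* pde_loss_weights .+ (1 - weight_change_inertia) .* pde_loss_weights_proposed
bc_loss_weights  .= weight_change_inertia .* bc_loss_weights .+ (1 - weight_change_inertia) .* bc_loss_weights_proposed
```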
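
Construction mirrors the other adaptive losses in the updated test file. The `adaptive_loss` keyword of `PhysicsInformedNN` below is an assumption inferred from the `discretization.adaptive_loss` field the new branch reads, so treat the commented line as a sketch rather than a confirmed signature.

```julia
using NeuralPDE

# Same constructor call as in test/adaptive_loss_tests.jl.
invdirichletadaptive_loss = NeuralPDE.InverseDirichletAdaptiveLoss(100, pde_loss_weights=1e3, bc_loss_weights=1)

# Hypothetical wiring into a discretization (keyword assumed, not confirmed by this diff):
# discretization = NeuralPDE.PhysicsInformedNN(chain, strategy; adaptive_loss = invdirichletadaptive_loss)
```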