Skip to content

Commit

Permalink
Merge pull request #2847 from SciML/nlls_modelingtoolkitize
Browse files Browse the repository at this point in the history
Handle modelingtoolkitize for nonlinearleastsquaresproblem
  • Loading branch information
ChrisRackauckas committed Jul 6, 2024
2 parents 622408b + 13c427a commit 4aeaaea
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/systems/nonlinear/modelingtoolkitize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ $(TYPEDSIGNATURES)
Generate `NonlinearSystem`, dependent variables, and parameters from an `NonlinearProblem`.
"""
function modelingtoolkitize(
prob::NonlinearProblem; u_names = nothing, p_names = nothing, kwargs...)
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem};
u_names = nothing, p_names = nothing, kwargs...)
p = prob.p
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})

Expand Down Expand Up @@ -37,13 +38,22 @@ function modelingtoolkitize(
end

if DiffEqBase.isinplace(prob)
rhs = ArrayInterface.restructure(prob.u0, similar(vars, Num))
prob.f(rhs, vars, params)
if prob isa NonlinearLeastSquaresProblem
rhs = ArrayInterface.restructure(
prob.f.resid_prototype, similar(prob.f.resid_prototype, Num))
prob.f(rhs, vars, params)
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(prob.f.resid_prototype)]...)
else
rhs = ArrayInterface.restructure(prob.u0, similar(vars, Num))
prob.f(rhs, vars, params)
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(rhs)]...)
end

else
rhs = prob.f(vars, params)
out_def = prob.f(prob.u0, prob.p)
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(out_def)]...)
end
out_def = prob.f(prob.u0, prob.p)
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(out_def)]...)

sts = vec(collect(vars))
_params = params
Expand Down
14 changes: 14 additions & 0 deletions test/modelingtoolkitize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,17 @@ sys = modelingtoolkitize(prob)
end
end
end

## NonlinearLeastSquaresProblem

function nlls!(du, u, p)
du[1] = 2u[1] - 2
du[2] = u[1] - 4u[2]
du[3] = 0
end
u0 = [0.0, 0.0]
prob = NonlinearLeastSquaresProblem(
NonlinearFunction(nlls!, resid_prototype = zeros(3)), u0)
sys = modelingtoolkitize(prob)
@test length(equations(sys)) == 3
@test length(equations(structural_simplify(sys; fully_determined = false))) == 0

0 comments on commit 4aeaaea

Please sign in to comment.