
Avoid NaN (co)tangents for sqrt(0) #599

Open · wants to merge 2 commits into main
Conversation

sethaxen (Member)
This PR fixes #576 by treating zero (co)tangents in sqrt as strong zeros.

It also partially fixes FluxML/Zygote.jl#1101, but fixing that entirely would require the same change to the rule for ^.
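As an AD-free sketch of the mechanism (plain Julia; `naive_pullback` and `guarded_pullback` are illustrative names, not the PR's code): sqrt's pullback divides the cotangent by 2·sqrt(x), which is 0/0 when both the cotangent and the primal input are zero, and the strong-zero guard short-circuits exactly that case.

```julia
# Why sqrt's pullback produces NaN at zero: the cotangent ΔΩ is divided
# by 2*sqrt(x), which is 0/0 when both ΔΩ and x are zero.
naive_pullback(ΔΩ, x) = ΔΩ / (2 * sqrt(x))

# Treating a zero cotangent at a zero primal as a strong zero
# short-circuits that division.
function guarded_pullback(ΔΩ, x)
    ∂x = ΔΩ / (2 * sqrt(x))
    return ifelse(iszero(ΔΩ) & iszero(x), zero(∂x), ∂x)
end

naive_pullback(0.0, 0.0)    # NaN
guarded_pullback(0.0, 0.0)  # 0.0
```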

Benchmark

This simple benchmark indicates that the performance decrease from this modified rule in Zygote is not extreme.

julia> using Zygote, BenchmarkTools, Random

julia> x = zeros(1_000);

julia> y = rand(MersenneTwister(42), 1_000);

julia> f(x) = sum(x -> max(sqrt(x), 1), x)
f (generic function with 1 method)

julia> Zygote.gradient(f, x)
([NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN  …  NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN],)

julia> b1_1 = @benchmark $(Zygote.gradient)($f, $x)
BenchmarkTools.Trial: 10000 samples with 3 evaluations.
 Range (min … max):   8.663 μs … 730.817 μs  ┊ GC (min … max):  0.00% … 93.96%
 Time  (median):      9.755 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   13.060 μs ±  31.542 μs  ┊ GC (mean ± σ):  15.00% ±  6.24%

  ▅██▆▇▆▅▄▃▂▁▂▂▂▂▂▁▁▁▁▁▂▂▂▁▁                       ▁▁▁         ▂
  ███████████████████████████▇█▇▇▇▆▇▇▆▆▄▆▆▆▆▇█▇▇▆▆▆████▇▆▅▄▇▆▆ █
  8.66 μs       Histogram: log(frequency) by time        26 μs <

 Memory estimate: 71.22 KiB, allocs estimate: 31.

julia> b2_1 = @benchmark $(Zygote.gradient)($f, $y)
BenchmarkTools.Trial: 10000 samples with 3 evaluations.
 Range (min … max):   8.612 μs … 532.969 μs  ┊ GC (min … max):  0.00% … 93.85%
 Time  (median):      9.335 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   12.282 μs ±  30.078 μs  ┊ GC (mean ± σ):  16.07% ±  6.44%

  ▅██▆▆▆▅▃▂ ▁  ▂▂▂▂▁▁     ▁                                    ▂
  ██████████████████████████████▇▇▆▆▅▆▆▄▇▅▄▅▅▆▅▅▆▅▅▆▆▆▆▅▅▅▄▄▅▅ █
  8.61 μs       Histogram: log(frequency) by time      23.9 μs <

 Memory estimate: 71.22 KiB, allocs estimate: 31.

julia> using ChainRulesCore

julia> function ChainRulesCore.frule((_, Δx), ::typeof(sqrt), x::Number)
           Ω = sqrt(x)
           ∂Ω = Δx / 2Ω
           return Ω, ifelse(iszero(Δx) & iszero(x), zero(∂Ω), ∂Ω)
       end

julia> function ChainRulesCore.rrule(::typeof(sqrt), x::Number)
           Ω = sqrt(x)
           function sqrt_pullback(ΔΩ)
               ∂x = ΔΩ / 2conj(Ω)
               return (
                   NoTangent(),
                   ProjectTo(x)(ifelse(iszero(ΔΩ) & iszero(x), zero(∂x), ∂x))
               )
           end
           return Ω, sqrt_pullback
       end

julia> Zygote.gradient(f, x)
([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],)

julia> b1_2 = @benchmark $(Zygote.gradient)($f, $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   8.891 μs …   1.357 ms  ┊ GC (min … max):  0.00% … 96.56%
 Time  (median):      9.832 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   12.484 μs ± 46.625 μs  ┊ GC (mean ± σ):  15.40% ±  4.09%

   ▆██▆▅▃▃▃▃▂▁ ▁▁▁                                            ▂
  ▇████████████████▇▇█▇▆▄▃▄▃▃▄▄▄▅▃▄▂▄▃▂▄▄▃▄▄▄▂▄▄▃▄▅▄▅▄▆▇▇▆▇▆▆ █
  8.89 μs      Histogram: log(frequency) by time      25.3 μs <

 Memory estimate: 86.84 KiB, allocs estimate: 31.

julia> b2_2 = @benchmark $(Zygote.gradient)($f, $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   9.066 μs …   1.304 ms  ┊ GC (min … max):  0.00% … 96.15%
 Time  (median):      9.892 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   11.970 μs ± 43.215 μs  ┊ GC (mean ± σ):  14.43% ±  3.96%

     ▁▆██▆▄▁                                                   
  ▂▃▅████████▆▅▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂ ▃
  9.07 μs         Histogram: frequency by time        15.7 μs <

 Memory estimate: 86.84 KiB, allocs estimate: 31.

julia> judge(mean(b1_2), mean(b1_1))
BenchmarkTools.TrialJudgement: 
  time:   -4.41% => invariant (5.00% tolerance)
  memory: +21.94% => regression (1.00% tolerance)


julia> judge(mean(b2_2), mean(b2_1))
BenchmarkTools.TrialJudgement: 
  time:   -2.54% => invariant (5.00% tolerance)
  memory: +21.94% => regression (1.00% tolerance)

@sethaxen sethaxen requested a review from mcabbott March 12, 2022 12:18
@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Mar 12, 2022
function frule((_, Δx), ::typeof(sqrt), x::Number)
Ω = sqrt(x)
∂Ω = Δx / 2Ω
return Ω, ifelse(iszero(Δx) & iszero(x), zero(∂Ω), ∂Ω)
Member
Is there a specific reason to use ifelse instead of a ternary operator (which does not require evaluating both branches)?

sethaxen (Member Author)
My reasoning was that, so long as the type of ∂Ω is inferrable, the two branches do no extra work. Using & and ifelse could also perform better in an inner loop, and could potentially allow Zygote to perform better for higher-order AD (Zygote tends to be slow when hitting control flow but has a special rule for ifelse). However, I was unable to devise a benchmark that showed a substantial difference.
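The semantic difference between the two forms can be demonstrated in plain Julia (a minimal sketch; `touch` is an illustrative helper): `ifelse` is an ordinary function, so both of its value arguments are evaluated before the call, while the ternary only evaluates the taken branch.

```julia
# Count how many times each form evaluates its arms.
calls = Ref(0)
touch(x) = (calls[] += 1; x)

r1 = true ? touch(1) : touch(2)        # ternary: only the taken arm runs
n_ternary = calls[]

calls[] = 0
r2 = ifelse(true, touch(1), touch(2))  # ifelse: both arms are evaluated
n_ifelse = calls[]

(n_ternary, n_ifelse)  # (1, 2)
```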

Member

I know that in older Julia versions there were cases where we could improve performance in SciML by avoiding zero or moving it out of loops. But I couldn't immediately reproduce this with a simple example; maybe it's not relevant here and/or has been fixed in recent Julia versions.

devmotion (Member) · Mar 13, 2022

An example where it matters:

julia> using BenchmarkTools

julia> function f(x)
           s = zero(x)
           for i in 1:10
               s += iseven(i) ? zero(x) : x
           end
           return s
       end
f (generic function with 1 method)

julia> function g(x)
           s = zero(x)
           for i in 1:10
               s += ifelse(iseven(i), zero(x), x)
           end
           return s
       end
g (generic function with 1 method)

julia> @btime f($(big"1.0"));
  45.640 μs (3002 allocations: 164.17 KiB)

julia> @btime g($(big"1.0"));
  56.341 μs (4002 allocations: 218.86 KiB)

mcabbott (Member)
This seems fine. How many other functions will need this? cbrt is currently a scalar rule. Powers are their own messy thing.

sethaxen (Member Author)

> This seems fine. How many other functions will need this? cbrt is currently a scalar rule. Powers are their own messy thing.

At first glance, /, \, ^, inv, sqrt, cbrt, log, log2, log10, log1p, and a bunch of inverse trig/hyperbolic functions.

I'm reluctant to add custom frules/rrules for all of these without first checking whether making zero (co)tangents strong zeros in the @scalar_rule macro causes a significant performance decrease, so perhaps before merging this I should open a PR on ChainRulesCore with a benchmark.
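A generic version of the guard could be sketched as follows (a hypothetical helper, not ChainRulesCore's actual macro internals; `strong_zero` is an illustrative name): wrap each computed partial so that a zero (co)tangent at a zero primal input yields a zero of the partial's type rather than a possibly-NaN product.

```julia
# If both the incoming (co)tangent Δ and the primal input x are zero,
# return a zero of the partial's type instead of the computed partial,
# which may be NaN from a 0/0 or 0*Inf intermediate.
strong_zero(∂, Δ, x) = ifelse(iszero(Δ) & iszero(x), zero(∂), ∂)

# Applied to sqrt's partial Δ / 2sqrt(x):
Δ, x = 0.0, 0.0
strong_zero(Δ / 2sqrt(x), Δ, x)  # 0.0 instead of NaN
```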

sethaxen (Member Author)

I opened a PR to ChainRulesCore that would supersede this one if merged: JuliaDiff/ChainRulesCore.jl#551

mcabbott (Member)

Functions like inv, log etc. are a slightly different class to sqrt, since the primal is infinite.

The motivating case for sqrt is I think something like f(x) = sqrt(x^2 + 0), which is regular at zero, and can be made to have a continuous derivative there. Is there something like that for inv, less trivial than g(x) = inv(inv(x))?
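The composition can be walked through by hand, without any AD library: the outer derivative of sqrt is infinite at zero, but the inner tangent of x^2 is exactly zero there, and their product Inf * 0 is NaN under the current rule.

```julia
# Chain-rule walk-through of f(x) = sqrt(x^2) at x = 0.
x = 0.0
u = x^2
inner = 2x                 # tangent of x^2 at 0: exactly zero
outer = 1 / (2 * sqrt(u))  # derivative of sqrt at 0: Inf
outer * inner              # NaN; a strong zero would give 0.0
```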

sethaxen (Member Author)

> Functions like inv, log etc. are a slightly different class to sqrt, since the primal is infinite.

Is this difference important, though? There are plenty of cases where an intermediate in a well-behaved primal function can be non-finite, resulting in the introduction of NaNs. Here's another one that hits users of lower-truncated normal distributions in Turing:

julia> using StatsFuns, FiniteDifferences, Zygote

julia> normcdf(0.0, 1.0, Inf)  # a constant function for all finite values of mu and sigma
1.0

julia> FiniteDifferences.grad(central_fdm(5, 1), x -> normcdf(0.0, x, Inf), 1.0)
(6.085449991639748e-14,)

julia> Zygote.gradient(x -> normcdf(0.0, x, Inf), 1.0)
(NaN,)

This happens because the gradient of erfc at Inf is 0, but when that zero gets pulled back through (x - mu)/sigma with x = Inf, the partial for / is infinite, so a NaN is introduced. This case is also resolved by treating zero (co)tangents as strong zeros.
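The propagation step can be sketched in plain Julia (illustrative variable names, not Zygote's internals): the pullback of z = (x - mu)/sigma with respect to sigma multiplies the incoming cotangent by an infinite intermediate, and 0 * Inf is NaN.

```julia
# Pullback of z = (x - mu) / sigma w.r.t. sigma: ∂sigma = -Δz * (x - mu) / sigma^2.
# With x = Inf the intermediate (x - mu) is infinite, so even a zero
# cotangent Δz yields 0 * Inf = NaN.
Δz, x, mu, sigma = 0.0, Inf, 0.0, 1.0
∂sigma = -Δz * (x - mu) / sigma^2  # NaN
```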

> The motivating case for sqrt is I think something like f(x) = sqrt(x^2 + 0), which is regular at zero, and can be made to have a continuous derivative there. Is there something like that for inv, less trivial than g(x) = inv(inv(x))?

Perhaps, but I don't see inv(inv(x)) (or log(exp(x))) as being any more trivial than sqrt(x^2).

Successfully merging this pull request may close these issues:

- Current rules for sqrt produce NaN for zero primal and (co)tangents
- NaN gradients for sqrt
3 participants