Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assume commutative multiplication exactly when necessary #540

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from

Conversation

sethaxen
Copy link
Member

As noted in #504, there are a number of cases where types of rules were constrained to CommutativeMulNumber where commutation of multiplication did not need to be assumed. Likewise, there were places where commutativity was assumed but not enforced by a type constraint.

This PR fixes #504 by removing constraints where un-needed and adding others where needed. Because the trigonometric, hyperbolic, logarithmic, and exponential function rules all assume cummutativity, this puts constraints on a _large_number of rules. It's possible we don't want to do this, because there are certainly numeric types out there that are real (and therefore commutative) but do not subtype Real, and in this case these would not directly hit our rules. However, the approach taken here is much safer.

@@ -10,7 +10,7 @@ end
function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
function inv_pullback(ΔΩ)
return NoTangent(), -Ω' * ΔΩ * Ω'
return NoTangent(), Ω' * -ΔΩ * Ω'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I ask why you moved the minus?

If it was -true * Ω' * ΔΩ * Ω' then I think you'd save a copy (since this gets fused into mul!).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that if ΔΩ is an AbstractZero or a UniformScaling, then the negation is cheaper.

If it was -true * Ω' * ΔΩ * Ω' then I think you'd save a copy (since this gets fused into mul!).

I didn't follow this. How is this fused into the mul!?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this was not an important change, and I'm happy to remove)

Copy link
Member

@mcabbott mcabbott Oct 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I didn't think about those. For dense matrices there's a 4-arg method which fuses this:

julia> f1(Ω, ΔΩ) = Ω' * -ΔΩ * Ω';

julia> f2(Ω, ΔΩ) = -true * Ω' * ΔΩ * Ω';

julia> @btime f1(Ω, ΔΩ) setup=(N=100; Ω=rand(N,N); ΔΩ=rand(N,N));
  min 74.708 μs, mean 101.133 μs (6 allocations, 234.52 KiB. GC mean 6.51%)

julia> @btime f2(Ω, ΔΩ) setup=(N=100; Ω=rand(N,N); ΔΩ=rand(N,N));
  min 73.125 μs, mean 92.756 μs (4 allocations, 156.34 KiB. GC mean 4.82%)

julia> @which -1 * ones(2,2) * ones(2,2) * ones(2,2)
*(α::Union{Real, Complex}, B::AbstractMatrix{<:Union{Real, Complex}}, C::AbstractMatrix{<:Union{Real, Complex}}, D::AbstractMatrix{<:Union{Real, Complex}}) in LinearAlgebra at /Users/me/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:1134

But with I, no fusion, hence f2 is slower. Maybe * should have some extra methods for cases with I.

@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Feb 12, 2022
@oxinabox
Copy link
Member

I will leave this to @mcabbott to review.
and unsubscribe.
ping me if i am needed

@sethaxen
Copy link
Member Author

I propose we explicitly test rules with ::Number types using a non-commutative number. We could use Quaternions.Quaternion, but it lacks features like rand(::Quaternion) (see JuliaGeometry/Quaternions.jl#42). We could pirate in the test suite to add this, or we could roll our own minimal quaternion implementation, like Julia base does (https://github.com/JuliaLang/julia/blob/bb5b98e72a151c41471d8cc14cacb495d647fb7f/test/testhelpers/Quaternions.jl). I think this would be the cleanest option, as if Quaternions adds explicit ChainRules support, that could mask potential issues.

Thoughts, @mcabbott?

@mcabbott
Copy link
Member

Implementing them here (like Base) sounds fine, but is testing with Quaternions going to quadruple the time for tests to run? Would be nice to avoid that if possible.

This closes #275 presume?

Not deep, but I have notation comments:

  • Can I vote for using LinearAlgebra's RealOrComplex instead of the more verbose CommutativeMulNumber everywhere? They should be exactly equivalent, but the short one nobody will have to look up to check.
  • If touching so many lines, I wonder if we ought to standardise letters for sensitivities a little, if not perfectly. I'd like to suggest that Δx be reserved for forward mode, and maybe we should start writing ∇x for reverse, at least for arrays?

(To my eye, vs are usually too small to mark what's a pretty important difference, at least when muddled up in things with NoTangent and so on, big bold words.)

Comment on lines +106 to +110
function rrule(::typeof(muladd), x::Number, y::Number, z::Number)
projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z)
muladd_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ)
muladd(x, y, z), muladd_pullback
end
Copy link
Member

@mcabbott mcabbott Feb 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g. here I think this is very clear, the pattern of where the Δs go is important:

function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number)
    return muladd(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz))
end

but I'd like to make the corresponding rrule more distinct:

Suggested change
function rrule(::typeof(muladd), x::Number, y::Number, z::Number)
projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z)
muladd_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ)
muladd(x, y, z), muladd_pullback
end
function rrule(::typeof(muladd), x::Number, y::Number, z::Number)
muladd_pullback(∇Ω) = NoTangent(), ProjectTo(x)(∇Ω * y'), ProjectTo(y)(x' * ∇Ω), ProjectTo(z)(∇Ω)
return muladd(x, y, z), muladd_pullback
end

or perhaps ∂Ω is a sort-of lower-case ∇Ω for scalars?

And since it closes over x,y,z already, there is nothing gained by constructing projectors outside.

@sethaxen
Copy link
Member Author

Implementing them here (like Base) sounds fine, but is testing with Quaternions going to quadruple the time for tests to run? Would be nice to avoid that if possible.

I think a bunch of scalar rules should have Quaternion tests, but probably only a few array rules, so I don't expect testing time to increase by that much (but we can check).

This closes #275 presume?

It supersedes #275 (forgot about that one; it's pretty stale now) and closes #504.

* Can I vote for using LinearAlgebra's `RealOrComplex` instead of the more verbose `CommutativeMulNumber` everywhere? They should be exactly equivalent, but the short one nobody will have to look up to check.

I think though that CommutativeMulNumber carries with it the semantic reason why we limit to Real and Complex, which is a safeguard against someone in the future accidentally changing it to be more or less restrictive.

* If touching so many lines, I wonder if we ought to standardise letters for sensitivities a little, if not perfectly. I'd like to suggest that `Δx` be reserved for forward mode, and maybe we should start writing `∇x` for reverse, at least for arrays?

(To my eye, vs are usually too small to mark what's a pretty important difference, at least when muddled up in things with NoTangent and so on, big bold words.)

I agree, the dots and bars are not great (and they're unicode characters often missing from people's devices). I seem to recall at some point we expressly advised people to use Δ for both input tangents and cotangents and for output tangents and cotangents. I'm not opposed to for cotangents, but I'd prefer we discuss that and agree on it outside of this PR. IMO this PR shouldn't be changing notation at all.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
needs version bump Version needs to be incremented or set to -DEV in Project.toml
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use CommutativeMulNumber everywhere it is needed (and no where it isn't)
3 participants