Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid device-to-host copy in ∇getindex! #801

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

pxl-th
Copy link

@pxl-th pxl-th commented Jun 21, 2024

  • Implement kernel for accumulation in ∇getindex! for generic index types.
    Fallback to old implementation if KernelAbstractions backend does not support atomics (currently this is Metal.jl).
    This is implemented as an extension which is triggered by Atomix, KernelAbstractions, GPUArrays (which every GPU backend has as dependencies). This way we won't need to manually specify every GPU backend.

Closes #800.

Benchmarking:

using AMDGPU
using BenchmarkTools
using Zygote
using Random

function main()
    x = ROCArray(zeros(Float32, 8192, 1, 32))
    ids = randperm(8192)
    Δ = ROCArray(ones(Float32, 1))

    y, back = Zygote.pullback(x) do x
        xd = x[ids, :, :]
        sum(xd; dims=(1:ndims(xd)...,))
    end
    back(Δ)
    AMDGPU.synchronize()

    @btime AMDGPU.@sync $back($Δ)
    return
end

Before:

julia> main()
  1.528 ms (69 allocations: 2.00 MiB)

Now:

julia> main()
  102.913 μs (223 allocations: 7.84 KiB)

@oxinabox oxinabox requested a review from mcabbott June 21, 2024 12:24
@oxinabox
Copy link
Member

Can this package extension live in GPUArrays or one of the other packages this is a extension on?

It;s out general policy not to have rules that are applicable to specific packages in ChainRules.jl but rather just have the rules for Base and StdLibs.
Its beyond our capacity to maintain rules for the whole ecosystem.
Generally the authors of the packages themselves are best placed to do that, since the rules often depend on the fine details of the packages in question

This case is a little less clear since it overloading ChainRules.∇getindex! rather than ChainRulesCore.rrule(::typeof(getindex), ...)
but perhaps it can be written in as a normal rrule or two.

@pxl-th
Copy link
Author

pxl-th commented Jun 21, 2024

Would it be fine to remove ∇getindex!(dx::AbstractGPUArray, dy, inds...) from ChainRules then?
Otherwise it'd break precompilation I think, if another package overwrites it.

@pxl-th
Copy link
Author

pxl-th commented Jun 21, 2024

Is the problem that it uses other packages and adds them as weak dependencies?
Because the ∇getindex!(dx::AbstractGPUArray, dy, inds...) method already existed in ChainRules.jl, but it was not optimal w.r.t. performance.
This PR just adds a more efficient path.

@oxinabox
Copy link
Member

That is a fair point, we did apparently already make an exception for AbstractGPUArrays for some reason I can not recall

@ToucheSir
Copy link
Contributor

ToucheSir commented Jun 28, 2024

Could this Gordion Knot be sliced by creating an API in CRC which both ChainRules and GPU array packages could use and overload respectively? Maybe something adjacent to the current accumulation functionality (add!! and co). I feel like array indexing is sufficiently fundamental that it wouldn't be too outlandish of an API to expose.

@pxl-th
Copy link
Author

pxl-th commented Jul 2, 2024

Could this Gordion Knot be sliced by creating an API in CRC which both ChainRules and GPU array packages could use and overload respectively? Maybe something adjacent to the current accumulation functionality (add!! and co).

Since this is basically scatter with +, we can add something like this in CRC which is defined only for AbstractArray.
Then in an extension for GPUArrays or KernelAbstractions package (which every GPU backend has as a dependency) we'd have GPU-specific version for AbstractGPUArray. It does not make sense to define this separately in CUDA.jl, AMDGPU.jl, etc.

BTW, I'm not entirely sure if it is OK that an extension for KernelAbstractions extends methods for CRC?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Avoid device-to-host copy in ∇getindex!
3 participants