Skip to content

Commit

Permalink
add scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
mokeyish committed Dec 23, 2023
1 parent 675c790 commit 45508b1
Show file tree
Hide file tree
Showing 7 changed files with 405 additions and 2 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ default = ["all"]

all = [
"chunk",
"cumsum",
# "cumsum",
"equal",
"eye",
# "eye",
"logical_not",
"logical_or",
"masked_fill",
"outer",
"scatter",
"scaled_dot_product_attention",
"triangular",
"unbind",
Expand All @@ -44,6 +45,7 @@ outer = []
scaled_dot_product_attention = ["masked_fill", "logical_not", "tril"]
triangular = []
to_tuple = []
scatter = []
tril = ["triangular"]
trilu = ["triangular"]
unbind = ["to_tuple"]
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Currently provides (see also [tests](https://github.com/mokeyish/candle-ext/tree

- F::full_like / Tensor::full_like

- F::scatter / Tensor::scatter

- F::triu / Tensor::triu

- F::tril / Tensor::tril
Expand Down
2 changes: 2 additions & 0 deletions src/kernels.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
/// Contents of `$OUT_DIR/src/kernels/indexing.ptx`, embedded into the binary at compile time.
// NOTE(review): presumably the build script compiles `src/kernels/indexing.cu` into this
// PTX file — confirm against build.rs.
#[rustfmt::skip]
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels/indexing.ptx"));
/// Contents of `$OUT_DIR/src/kernels/custom_binary.ptx`, embedded into the binary at compile time.
#[rustfmt::skip]
pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels/custom_binary.ptx"));
57 changes: 57 additions & 0 deletions src/kernels/indexing.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include "compatibility.cuh"
#include<stdint.h>

// SCATTER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME, OP)
//
// Defines an `extern "C"` kernel named FN_NAME that scatters `inp` into `out`
// along a single dimension, applying `out[dst] OP inp[src]` (e.g. OP is `=`).
//
// Parameters of the generated kernel:
//   ids           index buffer, read at the same linear offset as `inp`
//   inp           source values
//   out           destination buffer, modified in place
//   left_size     number of (left) slices before the scattered dimension
//   src_dim_size  length of the scattered dimension in `inp` / `ids`
//   dst_dim_size  length of the scattered dimension in `out`
//   right_size    number of elements after the scattered dimension
//
// Offsets are computed as `(pre * dim_size + k) * right_size + post`.
// NOTE(review): this assumes contiguous row-major buffers — confirm with the caller.
//
// Work is distributed with a grid-stride loop over the `left_size * right_size`
// (pre, post) pairs; each thread walks the whole scattered dimension serially.
// All loop counters are `size_t` so that tensors with more than 2^32 elements
// are not truncated or wrapped by 32-bit arithmetic.
//
// Indices outside `[0, dst_dim_size)` are skipped instead of writing out of
// bounds. A negative `int64_t` index converts to a very large `size_t` and is
// therefore rejected by the same check.
#define SCATTER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME, OP) \
extern "C" __global__ void FN_NAME( \
    const INDEX_TYPENAME *ids, \
    const TYPENAME *inp, \
    TYPENAME *out, \
    const size_t left_size, \
    const size_t src_dim_size, \
    const size_t dst_dim_size, \
    const size_t right_size \
) { \
    const size_t numel = left_size * right_size; \
    /* widen before multiplying so the products cannot overflow 32 bits */ \
    const size_t first = (size_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const size_t stride = (size_t)blockDim.x * gridDim.x; \
    for (size_t i = first; i < numel; i += stride) { \
        const size_t pre = i / right_size; \
        const size_t post = i % right_size; \
        for (size_t j = 0; j < src_dim_size; ++j) { \
            const size_t src_i = (pre * src_dim_size + j) * right_size + post; \
            const size_t idx = (size_t)ids[src_i]; \
            /* skip invalid indices rather than corrupting memory */ \
            if (idx < dst_dim_size) { \
                const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; \
                out[dst_i] OP inp[src_i]; \
            } \
        } \
    } \
}

// bfloat16 kernels: only compiled when __CUDA_ARCH__ is at least 800.
#if __CUDA_ARCH__ >= 800
SCATTER_OP(__nv_bfloat16, int64_t, scatter_i64_bf16, =)
SCATTER_OP(__nv_bfloat16, uint32_t, scatter_u32_bf16, =)
SCATTER_OP(__nv_bfloat16, uint8_t, scatter_u8_bf16, =)
#endif

// float16 kernels: only compiled when __CUDA_ARCH__ is at least 530.
// The int64 index variant is provided so that f16 supports the same index
// types (i64 / u32 / u8) as every other value type in this file.
#if __CUDA_ARCH__ >= 530
SCATTER_OP(__half, int64_t, scatter_i64_f16, =)
SCATTER_OP(__half, uint32_t, scatter_u32_f16, =)
SCATTER_OP(__half, uint8_t, scatter_u8_f16, =)
#endif


// Assignment (`=`) scatter kernels for the always-available value types.
// Kernel names follow `scatter_<index type>_<value type>`; the three groups
// below cover int64, uint32 and uint8 indices respectively, each instantiated
// for f32, f64, u8, i64 and u32 values.
#pragma region scatter_assign
SCATTER_OP(float, int64_t, scatter_i64_f32, =)
SCATTER_OP(double, int64_t, scatter_i64_f64, =)
SCATTER_OP(uint8_t, int64_t, scatter_i64_u8, =)
SCATTER_OP(int64_t, int64_t, scatter_i64_i64, =)
SCATTER_OP(uint32_t, int64_t, scatter_i64_u32, =)

SCATTER_OP(float, uint32_t, scatter_u32_f32, =)
SCATTER_OP(double, uint32_t, scatter_u32_f64, =)
SCATTER_OP(uint8_t, uint32_t, scatter_u32_u8, =)
SCATTER_OP(int64_t, uint32_t, scatter_u32_i64, =)
SCATTER_OP(uint32_t, uint32_t, scatter_u32_u32, =)

SCATTER_OP(float, uint8_t, scatter_u8_f32, =)
SCATTER_OP(double, uint8_t, scatter_u8_f64, =)
SCATTER_OP(uint8_t, uint8_t, scatter_u8_u8, =)
SCATTER_OP(uint32_t, uint8_t, scatter_u8_u32, =)
SCATTER_OP(int64_t, uint8_t, scatter_u8_i64, =)
#pragma endregion scatter_assign
9 changes: 9 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ mod logical_or;
mod masked_fill;
mod outer;
mod scaled_dot_product_attention;
mod scatter;
mod to_tuple;
mod triangular;
mod unbind;
Expand Down Expand Up @@ -100,6 +101,8 @@ pub trait TensorExt: Sized {
fn masked_fill<D: WithDType>(&self, mask: &Tensor, value: D) -> Result<Self>;
#[cfg(feature = "outer")]
fn outer(&self, vec2: &Tensor) -> Result<Self>;
/// Scatters `src` along `dim` at the positions given by `indexes` and returns the result.
///
/// Only available with the `scatter` feature.
// NOTE(review): presumably mirrors `torch.Tensor.scatter` (self is the base tensor, not
// modified in place) — confirm against the `scatter` module.
#[cfg(feature = "scatter")]
fn scatter<D: Dim>(&self, indexes: &Tensor, src: &Tensor, dim: D) -> Result<Tensor>;
#[cfg(feature = "triangular")]
fn tril(&self, diagonal: isize) -> Result<Self>;
#[cfg(feature = "triangular")]
Expand Down Expand Up @@ -146,6 +149,12 @@ impl TensorExt for Tensor {
F::masked_fill(self, mask, value)
}

/// Method form of `F::scatter`: forwards `self`, `indexes`, `src` and `dim` unchanged.
#[cfg(feature = "scatter")]
#[inline]
fn scatter<D: Dim>(&self, indexes: &Tensor, src: &Tensor, dim: D) -> Result<Tensor> {
F::scatter(self, indexes, src, dim)
}

#[cfg(feature = "outer")]
#[inline]
fn outer(&self, vec2: &Tensor) -> Result<Self> {
Expand Down
Loading

0 comments on commit 45508b1

Please sign in to comment.