stateful has an implicit sequence of 1
swfsql committed Feb 7, 2024
1 parent 66ff785 · commit cadf65c
Showing 1 changed file with 35 additions and 51 deletions.
dfdx/src/nn/layers/mamba_minimal.rs
@@ -810,7 +810,7 @@ pub mod stateful {
T: Tape<E, D>,
>
Module<(
- Tensor<(Batch, C1, DModel), E, D, T>,
+ Tensor<(Batch, DModel), E, D, T>,
MambaStateCache<Batch, DState, DConv, DInner, E, D, T>,
)> for MambaBlock<DModel, DState, DtRank, DConv, DInner, E, D>
where
@@ -842,35 +842,34 @@
): dfdx_core::tensor_ops::TryConcatShapeAlong<Axis<2>, Output = (Batch, DInner, DConv)>,
{
type Output = (
- Tensor<(Batch, C1, DModel), E, D, T>,
+ Tensor<(Batch, DModel), E, D, T>,
MambaStateCache<Batch, DState, DConv, DInner, E, D, T>,
);

/// Mamba block forward.
fn try_forward(
&self,
x: (
- Tensor<(Batch, C1, DModel), E, D, T>,
+ Tensor<(Batch, DModel), E, D, T>,
MambaStateCache<Batch, DState, DConv, DInner, E, D, T>,
),
) -> Result<Self::Output, Error> {
let (x, mut cache) = x;

- // let (batch, _d_model) = *x.shape();
let (batch, d_inner, d_conv) = *cache.conv_state.shape();

// layer 1 (in_proj)
let (xs, res): (
- Tensor<(Batch, C1, DInner), _, _, _>,
- Tensor<(Batch, C1, DInner), _, _, _>,
+ Tensor<(Batch, DInner), _, _, _>,
+ Tensor<(Batch, DInner), _, _, _>,
) = {
// projects the input DModel into 2*DInner
- let xs_and_res: Tensor<(Batch, C1, <DInner as Mul<C2>>::Output), _, _, _> =
+ let xs_and_res: Tensor<(Batch, <DInner as Mul<C2>>::Output), _, _, _> =
self.in_proj.try_forward(x)?;

// splits xs_and_res into (xs, res)
let (xs, res, _tape) =
- xs_and_res.try_split_tensor_along(Axis::<2>, d_inner, d_inner)?;
+ xs_and_res.try_split_tensor_along(Axis::<1>, d_inner, d_inner)?;

(xs, res)
};
@@ -893,12 +892,11 @@
)?;
// then concat xs as the last column (on the right side)
let xs: Tensor<(Batch, DInner, C1), _, _, _> =
- xs.try_permute::<_, Axes3<0, 2, 1>>()?;
- // let xs = xs.try_reshape_like(&(batch, d_inner, Const::<1>))?;
+ xs.try_reshape_like(&(batch, d_inner, Const::<1>))?;
(conv_state, xs).try_concat_tensor_along(Axis::<2>)?
};

- let xs: Tensor<(Batch, C1, DInner), E, _, _> = {
+ let xs: Tensor<(Batch, DInner), E, _, _> = {
let conv1d = self
.conv1d
.weight
@@ -913,9 +911,7 @@
let xs = self.conv1d_bias.try_forward(xs)?;

// activation
- let xs = xs.try_silu()?;
-
- xs.try_reshape_like(&(batch, Const::<1>, d_inner))?
+ xs.try_silu()?
};

let (ss, cache_ssm_state) = ss_step::<Batch, DState, DtRank, DInner, E, D, T>(
@@ -929,7 +925,7 @@
)?;

let ys = ss.try_mul(res.try_silu()?)?;
- let y: Tensor<(Batch, C1, DModel), _, _, _> = self.out_proj.try_forward(ys)?;
+ let y: Tensor<(Batch, DModel), _, _, _> = self.out_proj.try_forward(ys)?;

cache.ssm_state = cache_ssm_state;

@@ -957,13 +953,13 @@
//
a: Tensor<(DInner, DState), E, D, T>,
d: Tensor<(DInner,), E, D, T>,
- u: Tensor<(Batch, C1, DInner), E, D, T>,
+ u: Tensor<(Batch, DInner), E, D, T>,
x_proj: &MatMul<DInner, <DtRank as Add<<DState as Mul<C2>>::Output>>::Output, E, D>,
dt_proj: &Linear<DtRank, DInner, E, D>,
ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>,
) -> Result<
(
- Tensor<(Batch, C1, DInner), E, D, T>,
+ Tensor<(Batch, DInner), E, D, T>,
Tensor<(Batch, DInner, DState), E, D, T>,
),
dfdx::tensor::Error,
@@ -987,25 +983,25 @@
// this is input-independent (see Section 3.5.2 "Interpretation of A" from the Mamba paper for why A isn't selective)
let a: Tensor<(DInner, DState), _, _, _> = a.try_exp()?.try_negate()?;

- // (Batch, 1, DtRank + DState * 2)
- let x_dbl: Tensor<(Batch, C1, _), _, _, _> = x_proj.try_forward(u.retaped::<T>())?;
+ // (Batch, DtRank + DState * 2)
+ let x_dbl: Tensor<(Batch, _), _, _, _> = x_proj.try_forward(u.retaped::<T>())?;

// ∆ (part 1/2)
// ∆ is input-dependent
- let (delta, x_dbl_tail, _tape): (Tensor<(Batch, C1, DtRank), _, _, _>, _, _) =
- x_dbl.try_split_tensor_along(Axis::<2>, dt_rank, d_state * Const::<2>)?;
+ let (delta, x_dbl_tail, _tape): (Tensor<(Batch, DtRank), _, _, _>, _, _) =
+ x_dbl.try_split_tensor_along(Axis::<1>, dt_rank, d_state * Const::<2>)?;

// B and C
// B and C are input-dependent
let (b, c, _tape): (
- Tensor<(Batch, C1, DState), _, _, _>,
- Tensor<(Batch, C1, DState), _, _, _>,
+ Tensor<(Batch, DState), _, _, _>,
+ Tensor<(Batch, DState), _, _, _>,
_,
- ) = x_dbl_tail.try_split_tensor_along(Axis::<2>, d_state, d_state)?;
+ ) = x_dbl_tail.try_split_tensor_along(Axis::<1>, d_state, d_state)?;

// ∆ (part 2/2)
// ∆ is input-dependent
- let delta: Tensor<(Batch, C1, DInner), _, _, _> = {
+ let delta: Tensor<(Batch, DInner), _, _, _> = {
// note: don't add dt_proj bias
let delta = delta.try_matmul(
dt_proj
@@ -1021,22 +1017,14 @@
dt_proj
.bias
.retaped::<T>()
- .try_broadcast_like(&(batch, Const::<1>, d_inner))?,
+ .try_broadcast_like(&(batch, d_inner))?,
)?
.try_exp()?
.try_add(one)?)
.try_ln()?
};

- selective_scan_step::<Batch, DState, DInner, E, D, T>(
- delta.try_permute::<_, Axes3<0, 2, 1>>()?,
- a,
- b,
- c.try_permute::<_, Axes3<1, 0, 2>>()?,
- d,
- u,
- ssm_state_cache,
- )
+ selective_scan_step::<Batch, DState, DInner, E, D, T>(delta, a, b, c, d, u, ssm_state_cache)
}

// Selective Scan.
@@ -1057,16 +1045,16 @@
D: Device<E>,
T: Tape<E, D>,
>(
- delta: Tensor<(Batch, DInner, C1), E, D, T>,
+ delta: Tensor<(Batch, DInner), E, D, T>,
a: Tensor<(DInner, DState), E, D, T>,
- b: Tensor<(Batch, C1, DState), E, D, T>,
- c: Tensor<(C1, Batch, DState), E, D, T>,
+ b: Tensor<(Batch, DState), E, D, T>,
+ c: Tensor<(Batch, DState), E, D, T>,
d: Tensor<(DInner,), E, D, T>,
- u: Tensor<(Batch, C1, DInner), E, D, T>,
+ u: Tensor<(Batch, DInner), E, D, T>,
mut ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>,
) -> Result<
(
- Tensor<(Batch, C1, DInner), E, D, T>,
+ Tensor<(Batch, DInner), E, D, T>,
Tensor<(Batch, DInner, DState), E, D, T>,
),
dfdx::tensor::Error,
@@ -1078,15 +1066,15 @@
// - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
// "A is the more important term and the performance doesn't change much with the simplification on B"
let (delta_a, delta_bu): (
- Tensor<(Batch, DInner, C1, DState), _, _, _>,
- Tensor<(Batch, DInner, C1, DState), _, _, _>,
+ Tensor<(Batch, DInner, DState), _, _, _>,
+ Tensor<(Batch, DInner, DState), _, _, _>,
) = {
- let target_shape = (batch, d_inner, Const::<1>, d_state);
+ let target_shape = (batch, d_inner, d_state);

let delta_broadcasted = delta.try_broadcast_like(&target_shape)?;

let a = a.try_broadcast_like(&target_shape)?;
- let delta_a: Tensor<(Batch, DInner, C1, DState), _, _, _> =
+ let delta_a: Tensor<(Batch, DInner, DState), _, _, _> =
delta_broadcasted.retaped::<T>().try_mul(a)?.try_exp()?;

let b = b.try_broadcast_like(&target_shape)?;
@@ -1106,13 +1094,9 @@

let y = ssm_state_cache
.retaped::<T>()
- .try_matmul(c.try_permute::<_, Axes3<1, 2, 0>>()?)?;
- let du = d
- .try_broadcast_like(&(batch, Const::<1>, d_inner))?
- .try_mul(u)?;
- let y = y
- .try_reshape_like(&(batch, Const::<1>, d_inner))?
- .try_add(du)?;
+ .try_matmul(c.try_reshape_like(&(batch, d_state, Const::<1>))?)?;
+ let du = d.try_broadcast_like(&(batch, d_inner))?.try_mul(u)?;
+ let y = y.try_reshape_like(&(batch, d_inner))?.try_add(du)?;

Ok((y, ssm_state_cache))
}
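
For orientation, a minimal sketch of the single-step conv_state update that the stateful forward above performs: drop the oldest timestep, append the current input as the last column, then apply the depthwise kernel, bias, and SiLU. Plain Rust over per-channel Vec<f32> buffers; the names and scalar layout are illustrative, not dfdx's API.

fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// One decode step for a single channel of the (Batch, DInner, DConv) conv_state:
/// shift the rolling window left, append the new input, then depthwise conv + bias + SiLU.
fn conv_step(conv_state: &mut Vec<f32>, x_new: f32, kernel: &[f32], bias: f32) -> f32 {
    conv_state.remove(0); // drop the oldest timestep
    conv_state.push(x_new); // append the current input as the last column
    let acc: f32 = conv_state.iter().zip(kernel).map(|(s, k)| s * k).sum();
    silu(acc + bias)
}

fn main() {
    let mut state = vec![0.0; 4]; // d_conv = 4, zero-initialized
    let kernel = [0.25, 0.25, 0.25, 0.25];
    println!("{}", conv_step(&mut state, 1.0, &kernel, 0.0)); // first step after a reset
}

This is the intuition behind dropping the explicit C1 axis in the commit: with one timestep per call, the input is a (Batch, DInner) column whose only sequence role is to be appended to conv_state.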
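
The Δ path in ss_step is a DtRank to DInner linear map whose bias is added separately and then passed through softplus, which the code writes as ln(exp(x + bias) + 1). A hedged per-batch-element sketch under the same plain-Vec assumptions (w is DInner rows of length DtRank):

fn softplus(x: f32) -> f32 {
    (x.exp() + 1.0).ln()
}

/// Δ = softplus(W_dt · x + b_dt) for one batch element.
fn delta_step(x: &[f32], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    w.iter()
        .zip(b)
        .map(|(row, bi)| {
            let dot: f32 = row.iter().zip(x).map(|(wi, xi)| wi * xi).sum();
            softplus(dot + bi)
        })
        .collect()
}

The naive ln(exp(x) + 1) form mirrors the diff; a numerically careful implementation would branch on large x to avoid overflow in exp.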
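Finally, the recurrence that selective_scan_step computes once per call, with the simplified Euler discretization of B noted in the comments: per (DInner, DState) entry, h = exp(Δ·A)·h + Δ·B·u, then y = h·C + D·u. A sketch for one batch element under the same illustrative conventions (a already holds -exp(a_log), as computed in ss_step):

/// One selective-scan step: updates the (d_inner, d_state) state in place
/// and returns the (d_inner) output.
fn selective_scan_step(
    h: &mut [Vec<f32>], // (d_inner, d_state) ssm_state
    delta: &[f32],      // (d_inner)
    a: &[Vec<f32>],     // (d_inner, d_state), negative entries
    b: &[f32],          // (d_state)
    c: &[f32],          // (d_state)
    d: &[f32],          // (d_inner)
    u: &[f32],          // (d_inner)
) -> Vec<f32> {
    let mut y = vec![0.0; delta.len()];
    for i in 0..delta.len() {
        for n in 0..b.len() {
            // discretized state update: h = exp(Δ·A)·h + Δ·B·u
            h[i][n] = (delta[i] * a[i][n]).exp() * h[i][n] + delta[i] * b[n] * u[i];
            // output accumulation: y += h·C
            y[i] += h[i][n] * c[n];
        }
        // skip connection: + D·u
        y[i] += d[i] * u[i];
    }
    y
}

With the C1 axis gone, the dfdx version expresses the same contraction as a single matmul of the (Batch, DInner, DState) state against c reshaped to (Batch, DState, 1), replacing the permutes the previous code needed.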