Add Mamba (minimal) #918

Status: Closed (wants to merge 27 commits)

Commits (27, changes shown from all commits):
5c532ec  remove deprecated ftz intrinsics  (rainiwu, Jan 26, 2024)
fb91f13  suppress spurious cargo clippy warning  (rainiwu, Jan 26, 2024)
a832f51  Update safetensors module and naming  (swfsql, Jan 31, 2024)
901cfe4  impl core::ops::Sub for Dim types  (swfsql, Jan 31, 2024)
a14b40b  add SiLU activation function  (swfsql, Jan 31, 2024)
b52932c  add RMS normalization  (swfsql, Jan 31, 2024)
693b699  Add split_tensor_along method  (swfsql, Feb 1, 2024)
de55567  rm unrelated derive  (swfsql, Feb 1, 2024)
3122f78  Merge branch 'silu' into mamba-minimal  (swfsql, Feb 2, 2024)
ace3808  Merge branch 'split-tensor-along' into mamba-minimal  (swfsql, Feb 2, 2024)
f6d06e0  Merge branch 'rms-norm' into mamba-minimal  (swfsql, Feb 2, 2024)
ea424c3  Added `TryUnstack` for tensors.  (swfsql, Feb 6, 2024)
5994ac5  fix wgpu signature  (swfsql, Feb 6, 2024)
24a8593  Merge pull request #1 from rainiwu/remove-ftz  (swfsql, Feb 7, 2024)
5ffff2d  add continuity requirement for unstack  (swfsql, Feb 8, 2024)
e883b28  Added {load/read/save/write}_safetensor_with methods  (swfsql, Feb 9, 2024)
c695a15  unstack fixes  (swfsql, Feb 9, 2024)
4141e06  Merge branch 'unstack' into mamba-root  (swfsql, Feb 9, 2024)
8202b20  Merge branch 'safetensors-change' into mamba-root  (swfsql, Feb 9, 2024)
34234e2  Merge remote-tracking branch 'origin/avoid-ci-errors' into mamba-root  (swfsql, Feb 9, 2024)
93202ad  silu: fix cpu df  (swfsql, Feb 20, 2024)
eb70a88  allow to load safetensors from a byte array  (swfsql, Feb 20, 2024)
fde7a40  avoid conv1d bound for cudnn  (swfsql, Feb 6, 2024)
75d63cd  bump gemm  (swfsql, Feb 9, 2024)
f0bcb9a  clippy fix  (swfsql, Feb 9, 2024)
cac2f33  Add mamba-minimal  (swfsql, Feb 2, 2024)
bff1b65  add nightly requirement for mamba-minimal  (swfsql, Feb 9, 2024)
2 changes: 1 addition & 1 deletion dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
safetensors = { workspace = true, optional = true }
memmap2 = { workspace = true, optional = true }
half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
1 change: 1 addition & 0 deletions dfdx-core/src/data/collate.rs
@@ -1,4 +1,4 @@
use std::{mem::MaybeUninit, vec::Vec};

[CI warning, cargo-test-nightly (reported 5 times on line 1): the item `Vec` is imported redundantly]

/// Collates `Self` into some other type.
/// Generally similar to an unzip method;
@@ -55,6 +55,7 @@
impl<'a, A, B> Collate for Vec<&'a (A, B)> {
type Collated = (Vec<&'a A>, Vec<&'a B>);
fn collated(self) -> Self::Collated {
+#[allow(clippy::map_identity)]
self.into_iter().map(|(a, b)| (a, b)).unzip()
}
}
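Context for the new allow: the closure is not actually an identity map. The iterator yields `&(A, B)` items, and the `(a, b)` pattern rebinds the two fields as `(&A, &B)` through match ergonomics, which is exactly what `unzip` needs; recent clippy flags it as `map_identity` anyway, hence the commit "suppress spurious cargo clippy warning". A minimal standalone sketch of the same pattern (not code from this PR):

```rust
// The closure looks like an identity map but converts `&(A, B)` items
// into `(&A, &B)` pairs via match ergonomics.
fn collate<'a, A, B>(pairs: &'a [(A, B)]) -> (Vec<&'a A>, Vec<&'a B>) {
    #[allow(clippy::map_identity)]
    pairs.iter().map(|(a, b)| (a, b)).unzip()
}

fn main() {
    let data = [(1, 'x'), (2, 'y')];
    let (nums, chars) = collate(&data);
    assert_eq!(nums, [&1, &2]);
    assert_eq!(chars, [&'x', &'y']);
}
```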
38 changes: 0 additions & 38 deletions dfdx-core/src/lib.rs
@@ -9,7 +9,7 @@
//! The following sections provide some high level core concepts & examples, and
//! there is more detailed documentation in each of dfdx's submodules.
//!
//! See [feature_flags] for details on feature flags.

[CI warning, cargo-check: unresolved link to `feature_flags`]
//!
//! # Shapes & Tensors
//!
@@ -59,7 +59,7 @@
//! There are two options for this currently, with more planned to be added in the future:
//!
//! 1. [tensor::Cpu] - for tensors stored on the heap
//! 2. [tensor::Cuda] - for tensors stored in GPU memory

[CI warning, cargo-check: unresolved link to `tensor::Cuda`]
//!
//! Both devices implement [Default], you can also create them with a certain seed
//! and ordinal.
@@ -85,8 +85,8 @@
//! | Unary Operations | `a.sqrt()` | `a.sqrt()` | `a.sqrt()` |
//! | Binary Operations | `a + b` | `a + b` | `a + b` |
//! | gemm/gemv | [tensor_ops::matmul] | `a @ b` | `a @ b` |
//! | 2d Convolution | [tensor_ops::TryConv2D] | - | `torch.conv2d` |

[CI warning, cargo-check: unresolved link to `tensor_ops::TryConv2D`]
//! | 2d Transposed Convolution | [tensor_ops::TryConvTrans2D] | - | `torch.conv_transpose2d` |

[CI warning, cargo-check: unresolved link to `tensor_ops::TryConvTrans2D`]
//! | Slicing | [tensor_ops::slice] | `a[...]` | `a[...]` |
//! | Select | [tensor_ops::SelectTo] | `a[...]` | `torch.select` |
//! | Gather | [tensor_ops::GatherTo] | `np.take` | `torch.gather` |
@@ -128,44 +128,6 @@
pub use crate::tensor_ops::*;
}

-/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn flush_denormals_to_zero() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-}
-
-/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn keep_denormals() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-}

#[cfg(test)]
pub(crate) mod tests {
pub use num_traits::{Float, NumCast, Zero};
90 changes: 76 additions & 14 deletions dfdx-core/src/nn_traits/mod.rs
@@ -1,7 +1,7 @@
mod tuples;
mod vecs;

use std::vec::Vec;

[CI warning, cargo-test-nightly (reported 4 times on line 4): the item `Vec` is imported redundantly]

use crate::prelude::{Device, Dtype, Error, Gradients, Shape, Tensor, UniqueId};

@@ -116,12 +116,13 @@
#[cfg(feature = "safetensors")]
/// Something that can be saved to a .safetensors file.
pub trait SaveSafeTensors {
-fn save_safetensors<P: AsRef<std::path::Path>>(
+fn save_safetensors_with<P: AsRef<std::path::Path>, F: FnMut(String) -> String>(
&self,
path: P,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
let mut tensors = Vec::new();
-self.write_safetensors("", &mut tensors);
+self.write_safetensors_with("", &mut tensors, key_map);
let data = tensors.iter().map(|(k, dtype, shape, data)| {
(
k.clone(),
@@ -131,53 +131,103 @@

safetensors::serialize_to_file(data, &None, path.as_ref())
}
-fn write_safetensors(
+fn save_safetensors<P: AsRef<std::path::Path>>(
&self,
path: P,
) -> Result<(), safetensors::SafeTensorError> {
self.save_safetensors_with(path, &mut core::convert::identity)
}
fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
);
fn write_safetensors(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
) {
self.write_safetensors_with(location, tensors, &mut core::convert::identity)
}
}

#[cfg(feature = "safetensors")]
/// Something that can be loaded from a .safetensors file.
pub trait LoadSafeTensors {
-fn load_safetensors<P: AsRef<std::path::Path>>(
+fn load_safetensors_with<P: AsRef<std::path::Path>, F: FnMut(String) -> String>(
&mut self,
path: P,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
let f = std::fs::File::open(path)?;
let buffer = unsafe { memmap2::MmapOptions::new().map(&f)? };
let tensors = safetensors::SafeTensors::deserialize(&buffer)?;
-self.read_safetensors("", &tensors)
+self.read_safetensors_with("", &tensors, skip_missing, key_map)
}
fn load_safetensors<P: AsRef<std::path::Path>>(
&mut self,
path: P,
) -> Result<(), safetensors::SafeTensorError> {
self.load_safetensors_with(path, false, &mut core::convert::identity)
}
fn load_safetensors_from_bytes_with<F: FnMut(String) -> String>(
&mut self,
bytes: &[u8],
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
let tensors = safetensors::SafeTensors::deserialize(&bytes)?;
self.read_safetensors_with("", &tensors, skip_missing, key_map)
}
fn load_safetensors_from_bytes(
&mut self,
bytes: &[u8],
) -> Result<(), safetensors::SafeTensorError> {
self.load_safetensors_from_bytes_with(bytes, false, &mut core::convert::identity)
}

-fn read_safetensors(
+fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError>;
fn read_safetensors(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
) -> Result<(), safetensors::SafeTensorError> {
self.read_safetensors_with(location, tensors, false, &mut core::convert::identity)
}
}

#[cfg(feature = "safetensors")]
impl<S: Shape, E: Dtype, D: Device<E>, T> LoadSafeTensors for Tensor<S, E, D, T> {
-fn read_safetensors(
+fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
-self.load_safetensor(tensors, location)
+self.load_safetensor(tensors, location, skip_missing, key_map)
}
}

#[cfg(feature = "safetensors")]
impl<S: Shape, E: Dtype, D: Device<E>, T> SaveSafeTensors for Tensor<S, E, D, T> {
-fn write_safetensors(
+fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
) {
let location = key_map(location.to_string());
tensors.push((
-location.to_string(),
+location,
<E as crate::dtypes::SafeTensorsDtype>::DTYPE,
self.shape.concrete().into(),
self.as_vec().iter().flat_map(|e| e.to_le_bytes()).collect(),
@@ -189,15 +240,17 @@
($Ty:ty) => {
#[cfg(feature = "safetensors")]
impl SaveSafeTensors for $Ty {
-fn write_safetensors(
+fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
) {
let location = key_map(location.to_string());
#[allow(unused_imports)]
use crate::dtypes::ToLeBytes;
tensors.push((
-location.to_string(),
+location,
<$Ty as crate::dtypes::SafeTensorsDtype>::DTYPE,
Vec::new(),
self.to_le_bytes().to_vec(),
@@ -207,14 +260,23 @@

#[cfg(feature = "safetensors")]
impl LoadSafeTensors for $Ty {
-fn read_safetensors(
+fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
let location = key_map(location.to_string());
#[allow(unused_imports)]
use crate::dtypes::FromLeBytes;
-let view = tensors.tensor(location)?;
+let view = match tensors.tensor(&location) {
+Ok(ok) => ok,
+Err(safetensors::SafeTensorError::TensorNotFound(_name)) if skip_missing => {
+return Ok(());
+}
+Err(e) => return Err(e),
+};
*self = Self::from_le_bytes(view.data().try_into().unwrap());
Ok(())
}
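The net effect of the changes above: each save/load entry point gains a `_with` variant that threads a `key_map` closure over every tensor key, the load path gains a `skip_missing` flag, and `load_safetensors_from_bytes` deserializes from an in-memory buffer. The old method names remain as thin wrappers around `core::convert::identity`. A sketch of calling the new surface (the import path, file name, and key scheme are assumptions for illustration, not from this PR):

```rust
// Sketch only: assumes the traits are importable roughly like this and
// that the checkpoint prefixes every dfdx key with "backbone.".
use dfdx_core::nn_traits::{LoadSafeTensors, SaveSafeTensors};

fn roundtrip<M: SaveSafeTensors + LoadSafeTensors>(
    model: &mut M,
) -> Result<(), safetensors::SafeTensorError> {
    // Rewrite each dfdx location into the on-disk key when saving...
    model.save_safetensors_with("ckpt.safetensors", &mut |k| format!("backbone.{k}"))?;
    // ...and apply the same mapping on load; `true` skips tensors the
    // file does not contain instead of failing with TensorNotFound.
    model.load_safetensors_with("ckpt.safetensors", true, &mut |k| format!("backbone.{k}"))
}
```

The same `key_map` has to be used on both sides for a round trip; `load_safetensors_from_bytes_with` takes the identical arguments with `bytes: &[u8]` in place of the path.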
17 changes: 13 additions & 4 deletions dfdx-core/src/nn_traits/tuples.rs
@@ -4,7 +4,7 @@
tensor_ops::Device,
};

use std::vec::Vec;

[CI warning, cargo-test-nightly (reported 4 times on line 7): the item `Vec` is imported redundantly]

macro_rules! tuple_impls {
([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),*]) => {
@@ -20,23 +20,32 @@

#[cfg(feature = "safetensors")]
impl<$($name: crate::nn_traits::SaveSafeTensors, )+> crate::nn_traits::SaveSafeTensors for ($($name,)+) {
-fn write_safetensors(
+fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
) {
-$(self.$idx.write_safetensors(&format!("{location}{}.", $idx), tensors);)+
+$(
+let name = &format!("{location}.{}", $idx);
+self.$idx.write_safetensors_with(name, tensors, key_map);
+)+
}
}

#[cfg(feature = "safetensors")]
impl<$($name: crate::nn_traits::LoadSafeTensors, )+> crate::nn_traits::LoadSafeTensors for ($($name,)+) {
-fn read_safetensors(
+fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
-$(self.$idx.read_safetensors(&format!("{location}{}.", $idx), tensors)?;)+
+$(
+let name = &format!("{location}.{}", $idx);
+self.$idx.read_safetensors_with(name, tensors, skip_missing, key_map)?;
+)+
Ok(())
}
}
13 changes: 9 additions & 4 deletions dfdx-core/src/nn_traits/vecs.rs
@@ -4,7 +4,7 @@
tensor_ops::Device,
};

use std::vec::Vec;

[CI warning, cargo-test-nightly (reported 4 times on line 7): the item `Vec` is imported redundantly]

impl<E: Dtype, D: Device<E>, T: crate::nn_traits::BuildOnDevice<E, D>>
crate::nn_traits::BuildOnDevice<E, D> for Vec<T>
@@ -60,26 +60,31 @@

#[cfg(feature = "safetensors")]
impl<T: crate::nn_traits::SaveSafeTensors> crate::nn_traits::SaveSafeTensors for Vec<T> {
-fn write_safetensors(
+fn write_safetensors_with<F: FnMut(String) -> String>(
&self,
location: &str,
tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
key_map: &mut F,
) {
for (i, t) in self.iter().enumerate() {
-t.write_safetensors(&format!("{location}{i}."), tensors);
+let name = &format!("{location}.{i}");
+t.write_safetensors_with(name, tensors, key_map);
}
}
}

#[cfg(feature = "safetensors")]
impl<T: crate::nn_traits::LoadSafeTensors> crate::nn_traits::LoadSafeTensors for Vec<T> {
-fn read_safetensors(
+fn read_safetensors_with<F: FnMut(String) -> String>(
&mut self,
location: &str,
tensors: &safetensors::SafeTensors,
skip_missing: bool,
key_map: &mut F,
) -> Result<(), safetensors::SafeTensorError> {
for (i, t) in self.iter_mut().enumerate() {
-t.read_safetensors(&format!("{location}{i}."), tensors)?;
+let name = &format!("{location}.{i}");
+t.read_safetensors_with(name, tensors, skip_missing, key_map)?;
}
Ok(())
}
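Note the location scheme change in both the tuple and `Vec` impls: children are now visited at `format!("{location}.{i}")`, with the index appended as a dot-separated segment, instead of the old `format!("{location}{i}.")` trailing-dot prefix. A tiny sketch of the rule in isolation (not dfdx code):

```rust
// Each child index becomes a dot-separated path segment of the key.
fn child_locations(location: &str, children: usize) -> Vec<String> {
    (0..children).map(|i| format!("{location}.{i}")).collect()
}

fn main() {
    assert_eq!(child_locations("mlp", 2), ["mlp.0", "mlp.1"]);
    // The old scheme would have produced prefixes like "mlp0." and "mlp1.".
}
```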
39 changes: 39 additions & 0 deletions dfdx-core/src/shapes/shape.rs
@@ -69,6 +69,30 @@ where
}
}

impl<const N: usize> core::ops::Sub<Const<N>> for usize {
type Output = usize;
fn sub(self, _: Const<N>) -> Self::Output {
self.size() - N
}
}
impl<const N: usize> core::ops::Sub<usize> for Const<N> {
type Output = usize;
fn sub(self, rhs: usize) -> Self::Output {
N - rhs.size()
}
}

#[cfg(feature = "nightly")]
impl<const N: usize, const M: usize> core::ops::Sub<Const<N>> for Const<M>
where
Const<{ M - N }>: Sized,
{
type Output = Const<{ M - N }>;
fn sub(self, _: Const<N>) -> Self::Output {
Const
}
}

impl<const N: usize> core::ops::Mul<Const<N>> for usize {
type Output = usize;
fn mul(self, _: Const<N>) -> Self::Output {
@@ -121,18 +145,33 @@ where
pub trait Array<T>: IntoIterator<Item = T> {
type Dim: Dim;
fn dim(&self) -> Self::Dim;
fn from_fn<F>(cb: F, len: Self::Dim) -> Self
where
F: FnMut(usize) -> T;
}
impl<T, const N: usize> Array<T> for [T; N] {
type Dim = Const<N>;
fn dim(&self) -> Self::Dim {
Const
}
fn from_fn<F>(cb: F, _len: Self::Dim) -> Self
where
F: FnMut(usize) -> T,
{
std::array::from_fn(cb)
}
}
impl<T> Array<T> for std::vec::Vec<T> {
type Dim = usize;
fn dim(&self) -> Self::Dim {
self.len()
}
fn from_fn<F>(cb: F, len: Self::Dim) -> Self
where
F: FnMut(usize) -> T,
{
(0..len).map(cb).collect()
}
}

/// A collection of dimensions ([Dim]) that change how a multi-dimensional
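These additions give `Dim` arithmetic a subtraction story symmetric with the existing `Mul` impls: mixing a compile-time `Const<N>` with a runtime `usize` yields a `usize`, and on nightly `Const<M> - Const<N>` stays at the type level via `generic_const_exprs`. `Array::from_fn` mirrors `std::array::from_fn`, with the length carried by the `Dim` (a `Const` for arrays, a `usize` for `Vec`). A short sketch, assuming the usual `dfdx_core::shapes` import path:

```rust
// Sketch of the new Sub impls; the import path is an assumption
// based on this diff.
use dfdx_core::shapes::Const;

fn main() {
    let n: usize = 10;
    assert_eq!(n - Const::<3>, 7); // usize - Const<N> -> usize
    assert_eq!(Const::<10> - 4, 6); // Const<N> - usize -> usize
    // On nightly (feature = "nightly"), Const::<10> - Const::<3> is
    // Const::<7>, checked at compile time via generic_const_exprs.
}
```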
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/gradients.rs
@@ -153,7 +153,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
#[inline]
pub(crate) fn many_and_ref<L: Shape, R: Shape>(
&mut self,
-ls: &Vec<impl Tensorlike<L, E, D>>,
+ls: &[impl Tensorlike<L, E, D>],
r: &impl Tensorlike<R, E, D>,
) -> (Vec<&mut D::Vec>, &D::Vec) {
for i in 0..ls.len() {
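The `&Vec<impl Tensorlike<...>>` to `&[impl Tensorlike<...>]` change is the standard fix for clippy's `ptr_arg` lint: a slice parameter drops a pointless level of indirection and also accepts arrays and subslices, while existing `&vec` callers still coerce. A generic illustration of the same change:

```rust
// A slice parameter accepts strictly more callers than &Vec<T>.
fn total_len(items: &[String]) -> usize {
    items.iter().map(|s| s.len()).sum()
}

fn main() {
    let v = vec!["ab".to_string(), "cde".to_string()];
    assert_eq!(total_len(&v), 5); // &Vec<String> coerces to &[String]
    assert_eq!(total_len(&v[..1]), 2); // subslices work too
}
```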
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/mod.rs
@@ -1,10 +1,10 @@
//! The [Tensor] struct, [Cpu] & [Cuda] devices, and

[CI warning, cargo-check: unresolved link to `Cuda`]
//! traits like [ZerosTensor], [OnesTensor], [SampleTensor].
//!
//! At a high level a tensor is made up of:
//! 1. The [crate::shapes::Shape] of the array it stores
//! 2. The [crate::shapes::Dtype] of the elements of the array
//! 3. The [Storage] (e.g. [Cpu] or [Cuda]) that it uses to store the nd array

[CI warning, cargo-check: unresolved link to `Cuda`]
//! 4. A [Tape], which can either actually be a tape ([OwnedTape])
//! or be empty ([NoneTape]).
//!
@@ -151,7 +151,7 @@
pub use numpy::NumpyDtype;
mod error;
#[cfg(feature = "safetensors")]
-pub mod safetensors;
+mod safetensors;
mod tensorlike;
mod unique_id;
