From d6464176143df0738957473675ecca8f998751b0 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Thu, 1 Feb 2024 14:50:38 -0500 Subject: [PATCH] Fix docs (#1229) --- burn-autodiff/Cargo.toml | 9 ++++++++- burn-core/Cargo.toml | 10 ++++++++++ burn-ndarray/Cargo.toml | 16 ++++++++++------ burn-tensor/src/lib.rs | 1 + burn-tensor/src/tensor/api/check.rs | 24 ++++++++++++++---------- 5 files changed, 43 insertions(+), 17 deletions(-) diff --git a/burn-autodiff/Cargo.toml b/burn-autodiff/Cargo.toml index 47d86af3fb..ba13c32025 100644 --- a/burn-autodiff/Cargo.toml +++ b/burn-autodiff/Cargo.toml @@ -11,7 +11,7 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-autodiff" version.workspace = true [features] -default = ["export_tests"] +default = [] export_tests = ["burn-tensor-testgen"] [dependencies] @@ -21,3 +21,10 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.12.1", opt derive-new = { workspace = true } spin = { workspace = true } + +[dev-dependencies] +burn-tensor = { path = "../burn-tensor", version = "0.12.1", default-features = false, features = [ + "export_tests", +] } + + diff --git a/burn-core/Cargo.toml b/burn-core/Cargo.toml index 88538ca029..1cd00aae04 100644 --- a/burn-core/Cargo.toml +++ b/burn-core/Cargo.toml @@ -40,6 +40,16 @@ std = [ ] doc = [ "std", + # Backends + "dataset", + "candle", + "fusion", + "ndarray", + "tch", + "wgpu", + "vision", + "autodiff", + # Doc features "burn-candle/doc", "burn-common/doc", "burn-dataset/doc", diff --git a/burn-ndarray/Cargo.toml b/burn-ndarray/Cargo.toml index 2fc7968b23..e2b67ccb67 100644 --- a/burn-ndarray/Cargo.toml +++ b/burn-ndarray/Cargo.toml @@ -48,13 +48,9 @@ blas-openblas-system = [ # ** Please make sure all dependencies support no_std when std is disabled ** -burn-autodiff = { path = "../burn-autodiff", version = "0.12.1", features = [ - "export_tests", -], optional = true } +burn-autodiff = { path = "../burn-autodiff", version = "0.12.1", optional = true } burn-common = { path = "../burn-common", version = "0.12.1", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.12.1", default-features = false, features = [ - "export_tests", -] } +burn-tensor = { path = "../burn-tensor", version = "0.12.1", default-features = false } matrixmultiply = { workspace = true, default-features = false } rayon = { workspace = true, optional = true } @@ -67,5 +63,13 @@ openblas-src = { workspace = true, optional = true } rand = { workspace = true } spin = { workspace = true } # using in place of use std::sync::Mutex; +[dev-dependencies] +burn-autodiff = { path = "../burn-autodiff", version = "0.12.1", default-features = false, features = [ + "export_tests", +] } +burn-tensor = { path = "../burn-tensor", version = "0.12.1", default-features = false, features = [ + "export_tests", +] } + [package.metadata.docs.rs] features = ["doc"] diff --git a/burn-tensor/src/lib.rs b/burn-tensor/src/lib.rs index ac61d8c2b2..d16a02e220 100644 --- a/burn-tensor/src/lib.rs +++ b/burn-tensor/src/lib.rs @@ -16,6 +16,7 @@ mod tensor; mod tests; pub use half::{bf16, f16}; +pub(crate) use tensor::check::macros::check; pub use tensor::*; pub use burn_common::reader::Reader; // Useful so that backends don't have to add `burn_common` as diff --git a/burn-tensor/src/tensor/api/check.rs b/burn-tensor/src/tensor/api/check.rs index 928b3fee0c..5d55ad9b5c 100644 --- a/burn-tensor/src/tensor/api/check.rs +++ b/burn-tensor/src/tensor/api/check.rs @@ -818,21 +818,25 @@ impl TensorError { } } -/// We use a macro for all checks, since the panic message file and line number will match the -/// function that does the check instead of a the generic error.rs crate private unrelated file -/// and line number. -#[macro_export(local_inner_macros)] -macro_rules! check { - ($check:expr) => { - if let TensorCheck::Failed(check) = $check { - core::panic!("{}", check.format()); - } - }; +/// Module where we defined macros that can be used only in the project. +pub(crate) mod macros { + /// We use a macro for all checks, since the panic message file and line number will match the + /// function that does the check instead of a the generic error.rs crate private unrelated file + /// and line number. + macro_rules! check { + ($check:expr) => { + if let TensorCheck::Failed(check) = $check { + core::panic!("{}", check.format()); + } + }; + } + pub(crate) use check; } #[cfg(test)] mod tests { use super::*; + use macros::check; #[test] #[should_panic]