diff --git a/Cargo.lock b/Cargo.lock index 057a36d03..c2f9a31c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,36 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bincode" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode-store-tests" +version = "0.1.0" +dependencies = [ + "bincode", + "ndarray", + "rmp", + "rmp-serde", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + [[package]] name = "bitflags" version = "2.9.1" @@ -460,6 +490,7 @@ name = "ndarray" version = "0.16.1" dependencies = [ "approx", + "bincode", "cblas-sys", "defmac", "itertools", @@ -1129,6 +1160,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + [[package]] name = "ureq" version = "2.10.1" @@ -1167,6 +1204,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 14226986e..3bddf2943 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } +bincode = { version = "2.0", optional = true, default-features = false, features = ["alloc", "derive"] } + [dev-dependencies] defmac = "0.2" quickcheck = { workspace = true } @@ -60,6 +62,8 @@ blas = ["dep:cblas-sys", "dep:libc"] serde = ["dep:serde"] +bincode = ["dep:bincode"] + std = ["num-traits/std", "matrixmultiply/std"] rayon = ["dep:rayon", "std"] diff --git a/README.rst b/README.rst index 49558b1c1..2a33ed807 100644 --- a/README.rst +++ b/README.rst @@ -78,6 +78,10 @@ your `Cargo.toml`. - Enables serialization support for serde 1.x +- ``bincode`` + + - Enables bincode store support for bincode 2.x + - ``rayon`` - Enables parallel iterators, parallelized methods and ``par_azip!``. diff --git a/crates/bincode-store-tests/Cargo.toml b/crates/bincode-store-tests/Cargo.toml new file mode 100644 index 000000000..4f86faaf6 --- /dev/null +++ b/crates/bincode-store-tests/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "bincode-store-tests" +version = "0.1.0" +authors = ["MiyakoMeow"] +publish = false +edition = "2018" + +[lib] +test = false +doc = false +doctest = false + +[dependencies] +ndarray = { workspace = true, features = ["bincode"] } + +bincode = { version = "2" } + +[dev-dependencies] +# >=0.8.11 to avoid rmp-serde security vulnerability +# <0.8.14 to allows MSRV 1.64.0 +rmp = { version = ">=0.8.11,<0.8.14" } +# Old version to work with Rust 1.64+ +rmp-serde = { version = ">=1.1.1" } diff --git a/crates/bincode-store-tests/src/lib.rs b/crates/bincode-store-tests/src/lib.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/crates/bincode-store-tests/src/lib.rs @@ -0,0 +1 @@ + diff --git a/crates/bincode-store-tests/tests/store_and_extract.rs b/crates/bincode-store-tests/tests/store_and_extract.rs new file mode 100644 index 000000000..325318b2c --- /dev/null +++ b/crates/bincode-store-tests/tests/store_and_extract.rs @@ -0,0 +1,98 @@ +use ndarray::{arr0, arr1, arr2, s, ArcArray, ArrayBase, Dim, IxDyn, IxDynImpl, OwnedArcRepr, OwnedRepr}; +// No test: ArcArray2, ArrayD + +#[test] +fn store_many_dim_excrate() { + { + let a = arr0::(2.72); + let store_bytes = bincode::encode_to_vec(&a, bincode::config::standard()).unwrap(); + println!("Bincode encode {:?} => {:?}", &a, store_bytes); + let res = bincode::decode_from_slice::, Dim<[usize; 0]>>, _>( + &store_bytes, + bincode::config::standard(), + ); + println!("{:?}", res); + assert_eq!(a, res.unwrap().0); + } + + { + let a = arr1::(&[2.72, 1., 2.]); + let store_bytes = bincode::encode_to_vec(&a, bincode::config::standard()).unwrap(); + println!("Bincode encode {:?} => {:?}", &a, store_bytes); + let res = bincode::decode_from_slice::, Dim<[usize; 1]>>, _>( + &store_bytes, + bincode::config::standard(), + ); + println!("{:?}", res); + assert_eq!(a, res.unwrap().0); + } + + { + let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]]); + let store_bytes = bincode::encode_to_vec(&a, bincode::config::standard()).unwrap(); + println!("Bincode encode {:?} => {:?}", &a, store_bytes); + let res = bincode::decode_from_slice::, Dim<[usize; 2]>>, _>( + &store_bytes, + bincode::config::standard(), + ); + println!("{:?}", res); + assert_eq!(a, res.unwrap().0); + } + + { + // Test a sliced array. + let mut a = ArcArray::from_iter(0..32) + .into_shape_with_order((2, 2, 2, 4)) + .unwrap(); + a.slice_collapse(s![..;-1, .., .., ..2]); + let store_bytes = bincode::encode_to_vec(&a, bincode::config::standard()).unwrap(); + println!("Bincode encode {:?} => {:?}", &a, store_bytes); + let res = bincode::decode_from_slice::, Dim<[usize; 4]>>, _>( + &store_bytes, + bincode::config::standard(), + ); + println!("{:?}", res); + assert_eq!(a, res.unwrap().0); + } +} + +#[test] +fn serial_ixdyn_serde() { + { + let a = arr0::(2.72).into_dyn(); + let store_bytes = bincode::encode_to_vec(&a, bincode::config::standard()).unwrap(); + println!("Bincode encode {:?} => {:?}", &a, store_bytes); + let res = bincode::decode_from_slice::, Dim>, _>( + &store_bytes, + bincode::config::standard(), + ); + println!("{:?}", res); + assert_eq!(a, res.unwrap().0); + } + + { + let a = arr1::(&[2.72, 1., 2.]).into_dyn(); + let store_bytes = bincode::encode_to_vec(&a, bincode::config::standard()).unwrap(); + println!("Bincode encode {:?} => {:?}", &a, store_bytes); + let res = bincode::decode_from_slice::, Dim>, _>( + &store_bytes, + bincode::config::standard(), + ); + println!("{:?}", res); + assert_eq!(a, res.unwrap().0); + } + + { + let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]]) + .into_shape_with_order(IxDyn(&[3, 1, 1, 1, 2, 1])) + .unwrap(); + let store_bytes = bincode::encode_to_vec(&a, bincode::config::standard()).unwrap(); + println!("Bincode encode {:?} => {:?}", &a, store_bytes); + let res = bincode::decode_from_slice::, Dim>, _>( + &store_bytes, + bincode::config::standard(), + ); + println!("{:?}", res); + assert_eq!(a, res.unwrap().0); + } +} diff --git a/src/array_bincode.rs b/src/array_bincode.rs new file mode 100644 index 000000000..bf6976657 --- /dev/null +++ b/src/array_bincode.rs @@ -0,0 +1,124 @@ +// Copyright 2014-2025 MiyakoMeow and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. +use bincode::{ + de::{BorrowDecoder, Decoder}, + enc::Encoder, + error::{DecodeError, EncodeError}, + BorrowDecode, Decode, Encode, +}; + +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; + +use crate::{imp_prelude::*, IxDynImpl, ShapeError}; + +use super::arraytraits::ARRAY_FORMAT_VERSION; + +/// **Requires crate feature `"bincode"`** +impl Encode for Dim +where + I: Encode, +{ + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + Encode::encode(&self.ix(), encoder) + } +} + +/// **Requires crate feature `"bincode"`** +impl Decode for Dim +where + I: Decode, +{ + fn decode>(decoder: &mut D) -> Result { + Decode::decode(decoder).map(Dim::new) + } +} + +/// **Requires crate feature `"bincode"`** +impl<'de, Context, I> BorrowDecode<'de, Context> for Dim +where + I: BorrowDecode<'de, Context>, +{ + fn borrow_decode>(decoder: &mut D) -> Result { + BorrowDecode::borrow_decode(decoder).map(Dim::new) + } +} + +/// **Requires crate feature `"bincode"`** +impl Encode for IxDyn { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + let ix: &IxDynImpl = self.ix(); + Encode::encode(&ix.len(), encoder)?; + ix.into_iter() + .try_for_each(|ix| Encode::encode(ix, encoder)) + } +} + +/// **Requires crate feature `"bincode"`** +impl Decode for IxDynImpl { + fn decode>(decoder: &mut D) -> Result { + let len: usize = Decode::decode(decoder)?; + (0..len) + .map(|_| Decode::decode(decoder)) + .collect::, DecodeError>>() + .map(IxDynImpl::from) + } +} + +/// **Requires crate feature `"bincode"`** +impl<'de, Context> bincode::BorrowDecode<'de, Context> for IxDynImpl { + fn borrow_decode>(decoder: &mut D) -> Result { + let len: usize = BorrowDecode::borrow_decode(decoder)?; + (0..len) + .map(|_| BorrowDecode::borrow_decode(decoder)) + .collect::, DecodeError>>() + .map(IxDynImpl::from) + } +} + +/// **Requires crate feature `"serde"`** +impl Encode for ArrayBase +where + A: Encode, + D: Dimension + Encode, + S: Data, +{ + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + Encode::encode(&ARRAY_FORMAT_VERSION, encoder)?; + Encode::encode(&self.raw_dim(), encoder)?; + let iter = self.iter(); + Encode::encode(&iter.len(), encoder)?; + iter.into_iter() + .try_for_each(|elt| Encode::encode(elt, encoder)) + } +} + +/// **Requires crate feature `"bincode"`** +impl Decode for ArrayBase +where + A: Decode, + D: Dimension + Decode, + S: DataOwned, +{ + fn decode>(decoder: &mut De) -> Result { + let data_version: u8 = Decode::decode(decoder)?; + (data_version == ARRAY_FORMAT_VERSION) + .then_some(()) + .ok_or(DecodeError::Other("ARRAY_FORMAT_VERSION not match!"))?; + let dim: D = Decode::decode(decoder)?; + let data_len: usize = Decode::decode(decoder)?; + let data: Vec<_> = (0..data_len) + .map(|_| Decode::decode(decoder)) + .collect::, DecodeError>>()?; + let expected_size = dim.size(); + ArrayBase::from_shape_vec(dim, data).map_err(|_err: ShapeError| DecodeError::ArrayLengthMismatch { + required: expected_size, + found: data_len, + }) + } +} diff --git a/src/arraytraits.rs b/src/arraytraits.rs index a34b1985e..7920bd48d 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -432,8 +432,8 @@ where { } -#[cfg(feature = "serde")] -#[cfg_attr(docsrs, doc(cfg(feature = "serde")))] +#[cfg(any(feature = "serde", feature = "bincode"))] +#[cfg_attr(docsrs, doc(cfg(any(feature = "serde", feature = "bincode"))))] // Use version number so we can add a packed format later. pub const ARRAY_FORMAT_VERSION: u8 = 1u8; diff --git a/src/lib.rs b/src/lib.rs index 3efb378ce..a1fe1fa2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -175,6 +175,8 @@ mod itertools; mod argument_traits; #[cfg(feature = "serde")] mod array_serde; +#[cfg(feature = "bincode")] +mod array_bincode; mod arrayformat; mod arraytraits; pub use crate::argument_traits::AssignElem;