From a2cf4be4eea61731305c54247cae06b7be43367a Mon Sep 17 00:00:00 2001 From: Victor Lopez Date: Fri, 6 Feb 2026 14:38:42 +0100 Subject: [PATCH] feat: add serde support --- Cargo.toml | 13 +- README.md | 42 +- msgpacker-bench/Cargo.toml | 11 +- msgpacker-bench/benches/msgpacker.rs | 29 +- msgpacker-derive/Cargo.toml | 10 +- msgpacker-derive/src/lib.rs | 13 +- msgpacker/Cargo.toml | 33 +- msgpacker/src/error.rs | 2 + msgpacker/src/lib.rs | 7 +- msgpacker/src/pack/binary.rs | 32 +- msgpacker/src/pack/collections.rs | 56 ++- msgpacker/src/pack/common.rs | 19 +- msgpacker/src/pack/mod.rs | 2 +- msgpacker/src/serde/deserializer.rs | 710 +++++++++++++++++++++++++++ msgpacker/src/serde/mod.rs | 42 ++ msgpacker/src/serde/serializer.rs | 365 ++++++++++++++ msgpacker/src/unpack/binary.rs | 122 ++--- msgpacker/src/unpack/collections.rs | 46 +- msgpacker/src/unpack/common.rs | 51 +- msgpacker/src/unpack/mod.rs | 5 +- msgpacker/tests/serde.rs | 204 ++++++++ 21 files changed, 1645 insertions(+), 169 deletions(-) create mode 100644 msgpacker/src/serde/deserializer.rs create mode 100644 msgpacker/src/serde/mod.rs create mode 100644 msgpacker/src/serde/serializer.rs create mode 100644 msgpacker/tests/serde.rs diff --git a/Cargo.toml b/Cargo.toml index d58e65f..f9ac003 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,13 @@ [workspace] resolver = "2" -members = [ - "msgpacker", - "msgpacker-bench", - "msgpacker-derive" -] +members = ["msgpacker", "msgpacker-bench", "msgpacker-derive"] + +[workspace.package] +version = "0.5.0" +authors = ["Victor Lopez "] +edition = "2021" +license = "MIT/Apache-2.0" +repository = "https://github.com/codx-dev/msgpacker" [profile.bench] lto = true diff --git a/README.md b/README.md index 7fdf66f..5f350f3 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,11 @@ It will implement `Packable` and `Unpackable` for Rust atomic types. The traits - derive: Enables `MsgPacker` derive convenience macro. - strict: Will panic if there is a protocol violation of the size of a buffer; the maximum allowed size is `u32::MAX`. - std: Will implement the `Packable` and `Unpackable` for `std` collections. +- serde: Adds support for [serde](https://crates.io/crates/serde) + +## Known issues + +- The library, as noted in issue [#18](https://github.com/codx-dev/msgpacker/issues/18), exhibits a stricter approach when importing data generated by external tools due to the support of mixed types in languages like Python for collections. A pertinent instance from that issue involves an array serialization where the initial element is a u64, followed by a f64. While a solution can be devised in Rust using a wrapper that abstracts primitive types, it introduces an undesirable overhead that may not be suitable for typical use cases. Although this feature could potentially be added in future updates, it remains unimplemented at present. ## Example @@ -60,19 +65,38 @@ println!("deserialized {} bytes", n); assert_eq!(city, deserialized); ``` -## Benchmarks +## Serde -Results obtained with `Intel(R) Core(TM) i9-9900X CPU @ 3.50GHz`. +Version `0.5.0` introduces [serde](https://crates.io/crates/serde) support. -The simplicity of the implementation unlocks a performance more than ~10x better than [rmp-serde](https://crates.io/crates/rmp-serde). +```rust +use msgpacker::serde; +use serde_json::{json, Value}; +let val = serde_json::json!({"foo": "bar"}); +let ser = serde::to_vec(&val); +let des: Value = serde::from_slice(&ser).unwrap(); -#### Pack 1.000 elements +assert_eq!(val, des); +``` -![image](https://github.com/codx-dev/msgpacker/assets/8730839/ef69622d-0e2f-4bb1-b47c-6412d89fc19a) -![image](https://github.com/codx-dev/msgpacker/assets/8730839/ce2de037-252a-4c90-b429-430d131ccf7e) +While it's important to recognize that `serde`'s performance can be notably slower, this is primarily due to its implementation of a visitor pattern for type serialization, rather than solely relying on the static structure of declarations. However, `serde` is broadly used and having its support is helpful since a plethora of other libraries will be automatically supported just by having this feature enabled. -#### Unpack 1.000 elements +For more information, refer to `Benchmarks`. -![image](https://github.com/codx-dev/msgpacker/assets/8730839/5576f99d-6f37-4907-89db-5d666b13f9d5) -![image](https://github.com/codx-dev/msgpacker/assets/8730839/234c31d2-f319-414b-9418-4103e97d0a9c) +## Benchmarks + +Results obtained with `AMD EPYC 7402P 24-Core Processor`. + +![Image](https://github.com/user-attachments/assets/4d695e79-59bc-40c9-9e53-5a203c703462) +![Image](https://github.com/user-attachments/assets/f6a72499-9b5c-4b47-b6ea-ec4acbfea5f3) +![Image](https://github.com/user-attachments/assets/60809961-f058-4a86-952b-b6f7d7b3c9a5) +![Image](https://github.com/user-attachments/assets/de1a2be4-50e0-4dac-94c2-e4fb2ca24e2d) +![Image](https://github.com/user-attachments/assets/f88696f0-0479-43b7-a8f1-8a8ad7dab911) +![Image](https://github.com/user-attachments/assets/98277148-e2c1-4878-abd0-6b8ab5371317) + +To run the benchmarks: + +```sh +cd msgpacker-bench && cargo bench +``` diff --git a/msgpacker-bench/Cargo.toml b/msgpacker-bench/Cargo.toml index 5392c8c..a9a39ba 100644 --- a/msgpacker-bench/Cargo.toml +++ b/msgpacker-bench/Cargo.toml @@ -1,20 +1,21 @@ [package] name = "msgpacker-bench" version = "0.0.0" -authors = ["Victor Lopez "] -edition = "2021" -repository = "https://github.com/codx-dev/msgpacker" +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true description = "Benchmarks for msgpacker." publish = false [dependencies] msgpacker = { path = "../msgpacker" } -rmp-serde = "1.1" +rmp-serde = "1.3" serde = { version = "1.0", features = ["derive"] } rand = "0.8" [dev-dependencies] -criterion = { version = "0.5", features = ["html_reports"] } +criterion = { version = "0.8", features = ["html_reports"] } [[bench]] name = "msgpacker" diff --git a/msgpacker-bench/benches/msgpacker.rs b/msgpacker-bench/benches/msgpacker.rs index 75afd84..b847de9 100644 --- a/msgpacker-bench/benches/msgpacker.rs +++ b/msgpacker-bench/benches/msgpacker.rs @@ -1,4 +1,6 @@ -use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use std::hint::black_box; + +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use msgpacker_bench::Value; use rand::{distributions::Standard, prelude::*}; use rmp_serde::{decode::Deserializer, encode::Serializer}; @@ -14,12 +16,17 @@ pub fn pack(c: &mut Criterion) { // preallocate the required bytes let mut bufs_msgpacker = Vec::new(); + let mut bufs_msgpacker_serde = Vec::new(); let mut bufs_rmps = Vec::new(); for count in counts { let mut buf = Vec::new(); msgpacker::pack_array(&mut buf, values.iter().take(count)); bufs_msgpacker.push(buf); + let mut buf = Vec::new(); + msgpacker::serde::to_buffer(&mut buf, &values[..count]); + bufs_msgpacker_serde.push(buf); + let mut buf = Vec::new(); let mut serializer = Serializer::new(&mut buf); (&values[..count]).serialize(&mut serializer).unwrap(); @@ -41,6 +48,18 @@ pub fn pack(c: &mut Criterion) { }, ); + group.bench_with_input( + format!("msgpacker serde {count}"), + &(&values[..*count], bufs_msgpacker_serde[i].capacity()), + |b, (val, buf)| { + b.iter_batched( + || Vec::with_capacity(*buf), + |mut buf| msgpacker::serde::to_buffer(black_box(&mut buf), val), + BatchSize::LargeInput, + ); + }, + ); + group.bench_with_input( format!("rmps {count}"), &(&values[..*count], bufs_rmps[i].capacity()), @@ -71,6 +90,14 @@ pub fn pack(c: &mut Criterion) { }, ); + group.bench_with_input( + format!("msgpacker serde {count}"), + &bufs_msgpacker_serde[i], + |b, buf| { + b.iter(|| msgpacker::serde::from_slice::>(black_box(buf))); + }, + ); + group.bench_with_input(format!("rmps {count}"), &bufs_rmps[i], |b, buf| { b.iter(|| { >::deserialize(&mut Deserializer::new(black_box(&buf[..]))).unwrap() diff --git a/msgpacker-derive/Cargo.toml b/msgpacker-derive/Cargo.toml index 28ef82b..4c1ac04 100644 --- a/msgpacker-derive/Cargo.toml +++ b/msgpacker-derive/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "msgpacker-derive" -version = "0.3.2" -authors = ["Victor Lopez "] +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true categories = ["compression", "encoding", "parser-implementations"] -edition = "2021" keywords = ["messagepack", "msgpack"] -license = "MIT/Apache-2.0" readme = "README.md" -repository = "https://github.com/codx-dev/msgpacker" description = "Derive macros for the MessagePack protocol implementation for Rust." [lib] diff --git a/msgpacker-derive/src/lib.rs b/msgpacker-derive/src/lib.rs index 1d47d19..e274d90 100644 --- a/msgpacker-derive/src/lib.rs +++ b/msgpacker-derive/src/lib.rs @@ -18,16 +18,15 @@ fn contains_attribute(field: &Field, name: &str) -> bool { let name = name.to_string(); if let Some(attr) = field.attrs.first() { if let Meta::List(list) = &attr.meta { - if list.path.is_ident("msgpacker") { - if list + if list.path.is_ident("msgpacker") + && list .tokens .clone() .into_iter() .find(|a| a.to_string() == name) .is_some() - { - return true; - } + { + return true; } } } @@ -67,12 +66,12 @@ fn impl_fields_named(name: Ident, f: FieldsNamed) -> impl Into { let mut is_vec_u8 = false; match &ty { - Type::Path(p) if p.path.segments.last().filter(|p| p.ident.to_string() == "Vec").is_some() => { + Type::Path(p) if p.path.segments.last().filter(|p| p.ident == "Vec").is_some() => { is_vec = true; match &p.path.segments.last().unwrap().arguments { PathArguments::AngleBracketed(a) if a.args.len() == 1 => { if let Some(GenericArgument::Type(Type::Path(p))) = a.args.first() { - if p.path.segments.last().filter(|p| p.ident.to_string() == "u8").is_some() { + if p.path.segments.last().filter(|p| p.ident == "u8").is_some() { is_vec_u8 = true; } } diff --git a/msgpacker/Cargo.toml b/msgpacker/Cargo.toml index 01ce44d..b999466 100644 --- a/msgpacker/Cargo.toml +++ b/msgpacker/Cargo.toml @@ -1,29 +1,40 @@ [package] name = "msgpacker" -version = "0.4.8" -authors = ["Victor Lopez "] +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true categories = ["compression", "encoding", "parser-implementations"] -edition = "2021" keywords = ["messagepack", "msgpack"] -license = "MIT/Apache-2.0" readme = "README.md" -repository = "https://github.com/codx-dev/msgpacker" description = "MessagePack protocol implementation for Rust." [dependencies] -msgpacker-derive = { version = "0.3", optional = true } +msgpacker-derive = { path = "../msgpacker-derive", optional = true } +serde = { version = "1.0", default-features = false, optional = true } [dev-dependencies] -proptest = "1.2" -proptest-derive = "0.3" +arbitrary = "1.4" +arbitrary-json = "0.1" +msgpacker-derive.path = "../msgpacker-derive" +proptest = "1.10" +proptest-derive = "0.8" +serde = { version = "1.0", features = ["derive"] } +serde_bytes = "0.11" +serde_json = "1.0" [features] -default = ["std", "derive"] -alloc = [] +default = ["derive", "std", "serde"] +alloc = ["serde?/alloc"] derive = ["msgpacker-derive"] strict = [] -std = ["alloc"] +std = ["alloc", "serde?/std"] [[test]] name = "collections" required-features = ["derive"] + +[[test]] +name = "serde" +required-features = ["alloc", "derive", "serde"] diff --git a/msgpacker/src/error.rs b/msgpacker/src/error.rs index 1dd3821..b53091b 100644 --- a/msgpacker/src/error.rs +++ b/msgpacker/src/error.rs @@ -15,6 +15,8 @@ pub enum Error { UnexpectedFormatTag, /// The provided bin length is not valid. UnexpectedBinLength, + /// Not yet implemented. + NotImplemented, } impl fmt::Display for Error { diff --git a/msgpacker/src/lib.rs b/msgpacker/src/lib.rs index 845444d..21a00ca 100644 --- a/msgpacker/src/lib.rs +++ b/msgpacker/src/lib.rs @@ -11,8 +11,11 @@ mod extension; mod error; mod format; mod helpers; -mod pack; -mod unpack; +pub(crate) mod pack; +pub(crate) mod unpack; + +#[cfg(feature = "serde")] +pub mod serde; pub use error::Error; use format::Format; diff --git a/msgpacker/src/pack/binary.rs b/msgpacker/src/pack/binary.rs index 01f1d2c..a9a6276 100644 --- a/msgpacker/src/pack/binary.rs +++ b/msgpacker/src/pack/binary.rs @@ -1,26 +1,30 @@ use super::{Format, Packable}; use core::iter; +pub fn pack_bytes_slice_len>(buf: &mut T, slice: &[u8]) -> usize { + if slice.len() <= u8::MAX as usize { + buf.extend(iter::once(Format::BIN8).chain(iter::once(slice.len() as u8))); + 2 + } else if slice.len() <= u16::MAX as usize { + buf.extend(iter::once(Format::BIN16).chain((slice.len() as u16).to_be_bytes())); + 3 + } else if slice.len() <= u32::MAX as usize { + buf.extend(iter::once(Format::BIN32).chain((slice.len() as u32).to_be_bytes())); + 5 + } else { + #[cfg(feature = "strict")] + panic!("strict serialization enabled; the buffer is too large"); + return 0; + } +} + impl Packable for [u8] { #[allow(unreachable_code)] fn pack(&self, buf: &mut T) -> usize where T: Extend, { - let n = if self.len() <= u8::MAX as usize { - buf.extend(iter::once(Format::BIN8).chain(iter::once(self.len() as u8))); - 2 - } else if self.len() <= u16::MAX as usize { - buf.extend(iter::once(Format::BIN16).chain((self.len() as u16).to_be_bytes())); - 3 - } else if self.len() <= u32::MAX as usize { - buf.extend(iter::once(Format::BIN32).chain((self.len() as u32).to_be_bytes())); - 5 - } else { - #[cfg(feature = "strict")] - panic!("strict serialization enabled; the buffer is too large"); - return 0; - }; + let n = pack_bytes_slice_len(buf, self); buf.extend(self.iter().copied()); n + self.len() } diff --git a/msgpacker/src/pack/collections.rs b/msgpacker/src/pack/collections.rs index bafee5f..f50ca30 100644 --- a/msgpacker/src/pack/collections.rs +++ b/msgpacker/src/pack/collections.rs @@ -1,18 +1,12 @@ use super::{Format, Packable}; use core::{borrow::Borrow, iter}; -/// Packs an array into the extendable buffer, returning the amount of written bytes. -#[allow(unreachable_code)] -pub fn pack_array(buf: &mut T, iter: A) -> usize +/// Packs the length of an array. +pub fn pack_array_len(buf: &mut T, len: usize) -> usize where T: Extend, - A: IntoIterator, - I: Iterator + ExactSizeIterator, - V: Packable, { - let values = iter.into_iter(); - let len = values.len(); - let n = if len <= 15 { + if len <= 15 { buf.extend(iter::once(((len & 0x0f) as u8) | 0x90)); 1 } else if len <= u16::MAX as usize { @@ -25,24 +19,30 @@ where #[cfg(feature = "strict")] panic!("strict serialization enabled; the buffer is too large"); return 0; - }; - n + values.map(|v| v.pack(buf)).sum::() + } } -/// Packs a map into the extendable buffer, returning the amount of written bytes. +/// Packs an array into the extendable buffer, returning the amount of written bytes. #[allow(unreachable_code)] -pub fn pack_map(buf: &mut T, iter: A) -> usize +pub fn pack_array(buf: &mut T, iter: A) -> usize where T: Extend, A: IntoIterator, - B: Borrow<(K, V)>, - I: Iterator + ExactSizeIterator, - K: Packable, + I: Iterator + ExactSizeIterator, V: Packable, { - let map = iter.into_iter(); - let len = map.len(); - let n = if len <= 15 { + let values = iter.into_iter(); + let len = values.len(); + let n = pack_array_len(buf, len); + n + values.map(|v| v.pack(buf)).sum::() +} + +/// Packs the length of a map. +pub fn pack_map_len(buf: &mut T, len: usize) -> usize +where + T: Extend, +{ + if len <= 15 { buf.extend(iter::once(((len & 0x0f) as u8) | 0x80)); 1 } else if len <= u16::MAX as usize { @@ -55,7 +55,23 @@ where #[cfg(feature = "strict")] panic!("strict serialization enabled; the buffer is too large"); return 0; - }; + } +} + +/// Packs a map into the extendable buffer, returning the amount of written bytes. +#[allow(unreachable_code)] +pub fn pack_map(buf: &mut T, iter: A) -> usize +where + T: Extend, + A: IntoIterator, + B: Borrow<(K, V)>, + I: Iterator + ExactSizeIterator, + K: Packable, + V: Packable, +{ + let map = iter.into_iter(); + let len = map.len(); + let n = pack_map_len(buf, len); n + map .map(|b| { let (k, v) = b.borrow(); diff --git a/msgpacker/src/pack/common.rs b/msgpacker/src/pack/common.rs index 9c946d3..f1599ec 100644 --- a/msgpacker/src/pack/common.rs +++ b/msgpacker/src/pack/common.rs @@ -2,20 +2,22 @@ use super::{Format, Packable}; use core::{iter, marker::PhantomData}; impl Packable for () { - fn pack(&self, _buf: &mut T) -> usize + fn pack(&self, buf: &mut T) -> usize where T: Extend, { - 0 + buf.extend(iter::once(Format::NIL)); + 1 } } impl Packable for PhantomData { - fn pack(&self, _buf: &mut T) -> usize + fn pack(&self, buf: &mut T) -> usize where T: Extend, { - 0 + buf.extend(iter::once(Format::NIL)); + 1 } } @@ -33,6 +35,15 @@ impl Packable for bool { } } +impl Packable for char { + fn pack(&self, buf: &mut T) -> usize + where + T: Extend, + { + (*self as u32).pack(buf) + } +} + impl Packable for Option where X: Packable, diff --git a/msgpacker/src/pack/mod.rs b/msgpacker/src/pack/mod.rs index d31ce85..5e51047 100644 --- a/msgpacker/src/pack/mod.rs +++ b/msgpacker/src/pack/mod.rs @@ -1,7 +1,7 @@ use super::{Format, Packable}; mod binary; -mod collections; +pub(crate) mod collections; mod common; mod float; mod int; diff --git a/msgpacker/src/serde/deserializer.rs b/msgpacker/src/serde/deserializer.rs new file mode 100644 index 0000000..f94f519 --- /dev/null +++ b/msgpacker/src/serde/deserializer.rs @@ -0,0 +1,710 @@ +use core::fmt; + +use serde::{de, Deserializer as _}; + +use crate::{ + format::Format, + unpack::{binary, collections}, + Error, Unpackable as _, +}; + +pub struct MsgpackDeserializer<'a>(pub &'a [u8]); + +impl de::Error for Error { + fn custom(_msg: T) -> Self + where + T: fmt::Display, + { + Error::NotImplemented + } +} + +impl<'de, 'a> de::Deserializer<'de> for &'a mut MsgpackDeserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + if self.0.is_empty() { + return Err(Error::BufferTooShort); + } + match self.0[0] { + 0x00..=Format::POSITIVE_FIXINT => self.deserialize_u8(visitor), + 0x80..=0x8f => self.deserialize_map(visitor), + 0x90..=0x9f => self.deserialize_seq(visitor), + 0xa0..=0xbf => self.deserialize_str(visitor), + 0xe0..=0xff => self.deserialize_i8(visitor), + Format::NIL => self.deserialize_option(visitor), + Format::TRUE => self.deserialize_bool(visitor), + Format::FALSE => self.deserialize_bool(visitor), + Format::UINT8 => self.deserialize_u8(visitor), + Format::UINT16 => self.deserialize_u16(visitor), + Format::UINT32 => self.deserialize_u32(visitor), + Format::UINT64 => self.deserialize_u64(visitor), + Format::INT8 => self.deserialize_i8(visitor), + Format::INT16 => self.deserialize_i16(visitor), + Format::INT32 => self.deserialize_i32(visitor), + Format::INT64 => self.deserialize_i64(visitor), + Format::FLOAT32 => self.deserialize_f32(visitor), + Format::FLOAT64 => self.deserialize_f64(visitor), + Format::BIN8 | Format::BIN16 | Format::BIN32 => self.deserialize_bytes(visitor), + Format::STR8 | Format::STR16 | Format::STR32 => self.deserialize_str(visitor), + Format::ARRAY16 | Format::ARRAY32 => self.deserialize_seq(visitor), + Format::MAP16 | Format::MAP32 => self.deserialize_map(visitor), + #[cfg(feature = "alloc")] + Format::FIXEXT1 + | Format::FIXEXT2 + | Format::FIXEXT4 + | Format::FIXEXT8 + | Format::FIXEXT16 + | Format::EXT8 + | Format::EXT16 + | Format::EXT32 => self.deserialize_byte_buf(visitor), + _ => Err(Error::UnexpectedFormatTag), + } + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = bool::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_bool(v) + } + + fn deserialize_i8(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = i8::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_i8(v) + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = i16::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_i16(v) + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = i32::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_i32(v) + } + + fn deserialize_i64(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = i64::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_i64(v) + } + + fn deserialize_i128(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = i128::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_i128(v) + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = u8::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_u8(v) + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = u16::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_u16(v) + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = u32::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_u32(v) + } + + fn deserialize_u64(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = u64::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_u64(v) + } + + fn deserialize_u128(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = u128::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_u128(v) + } + + fn deserialize_f32(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = f32::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_f32(v) + } + + fn deserialize_f64(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = f64::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_f64(v) + } + + fn deserialize_char(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = char::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_char(v) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = binary::unpack_str(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_borrowed_str(v) + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + #[cfg(not(feature = "alloc"))] + { + let _ = visitor; + return Err(Error::NotImplemented); + } + + #[cfg(feature = "alloc")] + { + let (n, v) = ::alloc::string::String::unpack(self.0)?; + self.0 = &self.0[n..]; + return visitor.visit_string(v); + } + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, v) = binary::unpack_bytes(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_borrowed_bytes(v) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + #[cfg(not(feature = "alloc"))] + { + let _ = visitor; + return Err(Error::NotImplemented); + } + + #[cfg(feature = "alloc")] + { + let (n, v) = ::alloc::vec::Vec::unpack(self.0)?; + self.0 = &self.0[n..]; + return visitor.visit_byte_buf(v); + } + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + if self.0.is_empty() { + return Err(Error::BufferTooShort); + } + if self.0[0] == Format::NIL { + self.0 = &self.0[1..]; + visitor.visit_none() + } else { + visitor.visit_some(self) + } + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, _) = <()>::unpack(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, len) = collections::unpack_array_len(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_seq(MsgpackDeserializerSeq { + m: self, + count: len, + }) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_seq(MsgpackDeserializerSeq { + m: self, + count: len, + }) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_seq(MsgpackDeserializerSeq { + m: self, + count: len, + }) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + let (n, len) = collections::unpack_map_len(self.0)?; + self.0 = &self.0[n..]; + visitor.visit_map(MsgpackDeserializerSeq { + m: self, + count: len, + }) + } + + fn deserialize_struct( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_seq(MsgpackDeserializerSeq { + m: self, + count: fields.len(), + }) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_enum(MsgpackEnumHandler { de: self }) + } + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.deserialize_u32(visitor) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.deserialize_any(visitor) + } +} + +struct MsgpackDeserializerSeq<'a, 'de: 'a> { + m: &'a mut MsgpackDeserializer<'de>, + count: usize, +} + +impl<'de, 'a> de::SeqAccess<'de> for MsgpackDeserializerSeq<'a, 'de> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: de::DeserializeSeed<'de>, + { + if self.count == 0 { + return Ok(None); + } + self.count -= 1; + seed.deserialize(&mut *self.m).map(Some) + } +} + +impl<'de, 'a> de::MapAccess<'de> for MsgpackDeserializerSeq<'a, 'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: de::DeserializeSeed<'de>, + { + if self.count == 0 { + return Ok(None); + } + self.count -= 1; + seed.deserialize(&mut *self.m).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.m) + } +} + +struct MsgpackEnumHandler<'a, 'de: 'a> { + de: &'a mut MsgpackDeserializer<'de>, +} + +impl<'de, 'a> de::VariantAccess<'de> for MsgpackEnumHandler<'a, 'de> { + type Error = Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + seed.deserialize(self.de) + } + + fn tuple_variant(self, len: usize, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_tuple(len, visitor) + } + + fn struct_variant( + self, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_struct("", fields, visitor) + } +} + +impl<'de, 'a> de::EnumAccess<'de> for MsgpackEnumHandler<'a, 'de> { + type Error = Error; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: de::DeserializeSeed<'de>, + { + let discriminant = seed.deserialize(MsgpackTagDeserializer { de: self.de })?; + + Ok((discriminant, self)) + } +} + +struct MsgpackTagDeserializer<'a, 'de: 'a> { + de: &'a mut MsgpackDeserializer<'de>, +} + +impl<'de, 'a> serde::Deserializer<'de> for MsgpackTagDeserializer<'a, 'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.deserialize_identifier(visitor) + } + + fn deserialize_bool(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_i8(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_i16(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_i32(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_i64(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_u8(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_u16(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.de.deserialize_u32(visitor) + } + + fn deserialize_u64(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_f32(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_f64(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_char(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_str(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_string(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_bytes(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_byte_buf(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_option(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_unit(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + _visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + _visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_seq(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_tuple(self, _len: usize, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + _visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_map(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.deserialize_u32(visitor) + } + + fn deserialize_ignored_any(self, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unreachable!() + } +} diff --git a/msgpacker/src/serde/mod.rs b/msgpacker/src/serde/mod.rs new file mode 100644 index 0000000..df2d988 --- /dev/null +++ b/msgpacker/src/serde/mod.rs @@ -0,0 +1,42 @@ +//! [serde] implementations. + +use serde::{Deserialize, Serialize}; + +use crate::Error; + +mod deserializer; +mod serializer; + +/// Serializes the provided value into the extendable buffer. +/// +/// This operation is infallible as it will only allocate bytes. +pub fn to_buffer(buffer: &mut X, value: &T) +where + X: Extend, + T: Serialize + ?Sized, +{ + value + .serialize(&mut serializer::MsgpackSerializer::from(buffer)) + .ok(); +} + +/// Serializes the provided value into a [Vec]. +#[cfg(feature = "alloc")] +pub fn to_vec(value: &T) -> ::alloc::vec::Vec +where + T: Serialize + ?Sized, +{ + let mut v = Vec::new(); + + to_buffer(&mut v, value); + + v +} + +/// Deserializes the data from the given slice. +pub fn from_slice<'a, T>(s: &'a [u8]) -> Result +where + T: Deserialize<'a>, +{ + T::deserialize(&mut deserializer::MsgpackDeserializer(s)) +} diff --git a/msgpacker/src/serde/serializer.rs b/msgpacker/src/serde/serializer.rs new file mode 100644 index 0000000..685d131 --- /dev/null +++ b/msgpacker/src/serde/serializer.rs @@ -0,0 +1,365 @@ +use core::fmt; + +use serde::{ser, Serialize}; + +use crate::{pack::collections, Error, Packable}; + +pub struct MsgpackSerializer<'a, X: Extend> { + pub b: &'a mut X, +} + +impl<'a, X: Extend> From<&'a mut X> for MsgpackSerializer<'a, X> { + fn from(b: &'a mut X) -> Self { + Self { b } + } +} + +impl ser::Error for Error { + fn custom(_msg: T) -> Self + where + T: fmt::Display, + { + Error::NotImplemented + } +} + +impl<'a, X: Extend> ser::Serializer for &'a mut MsgpackSerializer<'a, X> { + type Ok = (); + type Error = Error; + type SerializeSeq = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + type SerializeMap = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + + fn serialize_bool(self, v: bool) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_i8(self, v: i8) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_i16(self, v: i16) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_i32(self, v: i32) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_i64(self, v: i64) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_i128(self, v: i128) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_u8(self, v: u8) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_u16(self, v: u16) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_u32(self, v: u32) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_u64(self, v: u64) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_u128(self, v: u128) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_f32(self, v: f32) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_f64(self, v: f64) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_char(self, v: char) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_str(self, v: &str) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + v.pack(self.b); + Ok(()) + } + + fn serialize_none(self) -> Result { + Option::<()>::None.pack(self.b); + Ok(()) + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, value); + Ok(()) + } + + fn serialize_unit(self) -> Result { + ().pack(self.b); + Ok(()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + ().pack(self.b); + Ok(()) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + ) -> Result { + variant_index.pack(self.b); + Ok(()) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, value); + Ok(()) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + variant_index.pack(self.b); + super::to_buffer(self.b, value); + Ok(()) + } + + fn serialize_seq(self, len: Option) -> Result { + let len = len.ok_or(Error::NotImplemented)?; + collections::pack_array_len(self.b, len); + Ok(self) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Ok(self) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Ok(self) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + variant_index.pack(self.b); + Ok(self) + } + + fn serialize_map(self, len: Option) -> Result { + let len = len.ok_or(Error::NotImplemented)?; + collections::pack_map_len(self.b, len); + Ok(self) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Ok(self) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + variant_index.pack(self.b); + Ok(self) + } + + #[cfg(not(feature = "alloc"))] + fn collect_str(self, _value: &T) -> Result + where + T: ?Sized + fmt::Display, + { + Err(Error::NotImplemented) + } +} + +impl<'a, X: Extend> ser::SerializeSeq for &'a mut MsgpackSerializer<'a, X> { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, value); + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, X: Extend> ser::SerializeTuple for &'a mut MsgpackSerializer<'a, X> { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, value); + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, X: Extend> ser::SerializeTupleStruct for &'a mut MsgpackSerializer<'a, X> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, value); + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, X: Extend> ser::SerializeTupleVariant for &'a mut MsgpackSerializer<'a, X> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, value); + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, X: Extend> ser::SerializeMap for &'a mut MsgpackSerializer<'a, X> { + type Ok = (); + type Error = Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, key); + Ok(()) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, value); + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, X: Extend> ser::SerializeStruct for &'a mut MsgpackSerializer<'a, X> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, value); + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, X: Extend> ser::SerializeStructVariant for &'a mut MsgpackSerializer<'a, X> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + super::to_buffer(self.b, value); + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} diff --git a/msgpacker/src/unpack/binary.rs b/msgpacker/src/unpack/binary.rs index b49bf7a..9456223 100644 --- a/msgpacker/src/unpack/binary.rs +++ b/msgpacker/src/unpack/binary.rs @@ -7,7 +7,6 @@ use super::{ helpers::{take_byte, take_num}, Error, Format, }; -use alloc::{string::String, vec::Vec}; use core::str; pub fn unpack_bytes(mut buf: &[u8]) -> Result<(usize, &[u8]), Error> { @@ -40,70 +39,77 @@ pub fn unpack_str(mut buf: &[u8]) -> Result<(usize, &str), Error> { Ok((n + len, str)) } -impl Unpackable for Vec { - type Error = Error; +#[cfg(feature = "alloc")] +mod alloc { + use super::*; - fn unpack(buf: &[u8]) -> Result<(usize, Self), Self::Error> { - unpack_bytes(buf).map(|(n, b)| (n, b.to_vec())) - } + use ::alloc::{string::String, vec::Vec}; + + impl Unpackable for Vec { + type Error = Error; + + fn unpack(buf: &[u8]) -> Result<(usize, Self), Self::Error> { + unpack_bytes(buf).map(|(n, b)| (n, b.to_vec())) + } - fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> - where - I: IntoIterator, - { - let mut bytes = bytes.into_iter(); - let format = take_byte_iter(bytes.by_ref())?; - let (n, len) = match format { - Format::BIN8 => (2, take_byte_iter(bytes.by_ref())? as usize), - Format::BIN16 => ( - 3, - take_num_iter(bytes.by_ref(), u16::from_be_bytes)? as usize, - ), - Format::BIN32 => ( - 5, - take_num_iter(bytes.by_ref(), u32::from_be_bytes)? as usize, - ), - _ => return Err(Error::UnexpectedFormatTag), - }; - let v: Vec<_> = bytes.take(len).collect(); - if v.len() < len { - return Err(Error::BufferTooShort); + fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> + where + I: IntoIterator, + { + let mut bytes = bytes.into_iter(); + let format = take_byte_iter(bytes.by_ref())?; + let (n, len) = match format { + Format::BIN8 => (2, take_byte_iter(bytes.by_ref())? as usize), + Format::BIN16 => ( + 3, + take_num_iter(bytes.by_ref(), u16::from_be_bytes)? as usize, + ), + Format::BIN32 => ( + 5, + take_num_iter(bytes.by_ref(), u32::from_be_bytes)? as usize, + ), + _ => return Err(Error::UnexpectedFormatTag), + }; + let v: Vec<_> = bytes.take(len).collect(); + if v.len() < len { + return Err(Error::BufferTooShort); + } + Ok((n + len, v)) } - Ok((n + len, v)) } -} -impl Unpackable for String { - type Error = Error; + impl Unpackable for String { + type Error = Error; - fn unpack(buf: &[u8]) -> Result<(usize, Self), Self::Error> { - unpack_str(buf).map(|(n, s)| (n, s.into())) - } + fn unpack(buf: &[u8]) -> Result<(usize, Self), Self::Error> { + unpack_str(buf).map(|(n, s)| (n, s.into())) + } - fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> - where - I: IntoIterator, - { - let mut bytes = bytes.into_iter(); - let format = take_byte_iter(bytes.by_ref())?; - let (n, len) = match format { - 0xa0..=0xbf => (1, format as usize & 0x1f), - Format::STR8 => (2, take_byte_iter(bytes.by_ref())? as usize), - Format::STR16 => ( - 3, - take_num_iter(bytes.by_ref(), u16::from_be_bytes)? as usize, - ), - Format::STR32 => ( - 5, - take_num_iter(bytes.by_ref(), u32::from_be_bytes)? as usize, - ), - _ => return Err(Error::UnexpectedFormatTag), - }; - let v: Vec<_> = bytes.take(len).collect(); - if v.len() < len { - return Err(Error::BufferTooShort); + fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> + where + I: IntoIterator, + { + let mut bytes = bytes.into_iter(); + let format = take_byte_iter(bytes.by_ref())?; + let (n, len) = match format { + 0xa0..=0xbf => (1, format as usize & 0x1f), + Format::STR8 => (2, take_byte_iter(bytes.by_ref())? as usize), + Format::STR16 => ( + 3, + take_num_iter(bytes.by_ref(), u16::from_be_bytes)? as usize, + ), + Format::STR32 => ( + 5, + take_num_iter(bytes.by_ref(), u32::from_be_bytes)? as usize, + ), + _ => return Err(Error::UnexpectedFormatTag), + }; + let v: Vec<_> = bytes.take(len).collect(); + if v.len() < len { + return Err(Error::BufferTooShort); + } + let s = String::from_utf8(v).map_err(|_| Error::InvalidUtf8)?; + Ok((n + len, s)) } - let s = String::from_utf8(v).map_err(|_| Error::InvalidUtf8)?; - Ok((n + len, s)) } } diff --git a/msgpacker/src/unpack/collections.rs b/msgpacker/src/unpack/collections.rs index 4bf9ff5..bd86ff6 100644 --- a/msgpacker/src/unpack/collections.rs +++ b/msgpacker/src/unpack/collections.rs @@ -3,14 +3,10 @@ use super::{ Error, Format, Unpackable, }; -/// Unpacks an array from the buffer, returning a collectable type and the amount of read bytes. -pub fn unpack_array(mut buf: &[u8]) -> Result<(usize, C), ::Error> -where - V: Unpackable, - C: FromIterator, -{ +/// Unpacks the array length from the buffer. +pub fn unpack_array_len(mut buf: &[u8]) -> Result<(usize, usize), Error> { let format = take_byte(&mut buf)?; - let (mut n, len) = match format { + let (n, len) = match format { 0x90..=0x9f => (1, (format & 0x0f) as usize), Format::ARRAY16 => ( 3, @@ -22,6 +18,17 @@ where ), _ => return Err(Error::UnexpectedFormatTag.into()), }; + Ok((n, len)) +} + +/// Unpacks an array from the buffer, returning a collectable type and the amount of read bytes. +pub fn unpack_array(mut buf: &[u8]) -> Result<(usize, C), ::Error> +where + V: Unpackable, + C: FromIterator, +{ + let (mut n, len) = unpack_array_len(buf)?; + buf = &buf[n..]; let array: C = (0..len) .map(|_| { let (count, v) = V::unpack(buf)?; @@ -64,16 +71,10 @@ where Ok((n, array)) } -/// Unpacks a map from the buffer, returning a collectable type and the amount of read bytes. -pub fn unpack_map(mut buf: &[u8]) -> Result<(usize, C), ::Error> -where - K: Unpackable, - V: Unpackable, - ::Error: From<::Error>, - C: FromIterator<(K, V)>, -{ +/// Unpacks a map length from the buffer. +pub fn unpack_map_len(mut buf: &[u8]) -> Result<(usize, usize), Error> { let format = take_byte(&mut buf)?; - let (mut n, len) = match format { + let (n, len) = match format { 0x80..=0x8f => (1, (format & 0x0f) as usize), Format::MAP16 => ( 3, @@ -85,6 +86,19 @@ where ), _ => return Err(Error::UnexpectedFormatTag.into()), }; + Ok((n, len)) +} + +/// Unpacks a map from the buffer, returning a collectable type and the amount of read bytes. +pub fn unpack_map(mut buf: &[u8]) -> Result<(usize, C), ::Error> +where + K: Unpackable, + V: Unpackable, + ::Error: From<::Error>, + C: FromIterator<(K, V)>, +{ + let (mut n, len) = unpack_map_len(buf)?; + buf = &buf[n..]; let map: C = (0..len) .map(|_| { let (count, k) = K::unpack(buf)?; diff --git a/msgpacker/src/unpack/common.rs b/msgpacker/src/unpack/common.rs index 58821b5..f7ea64b 100644 --- a/msgpacker/src/unpack/common.rs +++ b/msgpacker/src/unpack/common.rs @@ -7,30 +7,48 @@ use core::{marker::PhantomData, mem::MaybeUninit}; impl Unpackable for () { type Error = Error; - fn unpack(_buf: &[u8]) -> Result<(usize, Self), Self::Error> { - Ok((0, ())) + fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> { + let format = take_byte(&mut buf)?; + if format != Format::NIL { + return Err(Error::UnexpectedFormatTag); + } + Ok((1, ())) } - fn unpack_iter(_buf: I) -> Result<(usize, Self), Self::Error> + fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> where I: IntoIterator, { - Ok((0, ())) + let mut bytes = bytes.into_iter(); + let format = take_byte_iter(bytes.by_ref())?; + if format != Format::NIL { + return Err(Error::UnexpectedFormatTag); + } + Ok((1, ())) } } impl Unpackable for PhantomData { type Error = Error; - fn unpack(_buf: &[u8]) -> Result<(usize, Self), Self::Error> { - Ok((0, PhantomData)) + fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> { + let format = take_byte(&mut buf)?; + if format != Format::NIL { + return Err(Error::UnexpectedFormatTag); + } + Ok((1, PhantomData)) } - fn unpack_iter(_buf: I) -> Result<(usize, Self), Self::Error> + fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> where I: IntoIterator, { - Ok((0, PhantomData)) + let mut bytes = bytes.into_iter(); + let format = take_byte_iter(bytes.by_ref())?; + if format != Format::NIL { + return Err(Error::UnexpectedFormatTag); + } + Ok((1, PhantomData)) } } @@ -58,6 +76,23 @@ impl Unpackable for bool { } } +impl Unpackable for char { + type Error = Error; + + fn unpack(buf: &[u8]) -> Result<(usize, Self), Self::Error> { + u32::unpack(buf) + .and_then(|(n, v)| char::from_u32(v).ok_or(Error::InvalidUtf8).map(|c| (n, c))) + } + + fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> + where + I: IntoIterator, + { + u32::unpack_iter(bytes) + .and_then(|(n, v)| char::from_u32(v).ok_or(Error::InvalidUtf8).map(|c| (n, c))) + } +} + impl Unpackable for Option where X: Unpackable, diff --git a/msgpacker/src/unpack/mod.rs b/msgpacker/src/unpack/mod.rs index 5f20331..5d8a92a 100644 --- a/msgpacker/src/unpack/mod.rs +++ b/msgpacker/src/unpack/mod.rs @@ -1,8 +1,7 @@ use super::{helpers, Error, Format, Unpackable}; -#[cfg(feature = "alloc")] -mod binary; -mod collections; +pub(crate) mod binary; +pub(crate) mod collections; mod common; mod float; mod int; diff --git a/msgpacker/tests/serde.rs b/msgpacker/tests/serde.rs new file mode 100644 index 0000000..9c6bb8f --- /dev/null +++ b/msgpacker/tests/serde.rs @@ -0,0 +1,204 @@ +use core::marker::PhantomData; + +use arbitrary::{Arbitrary as _, Unstructured}; +use arbitrary_json::ArbitraryValue; +use msgpacker::Packable; +use msgpacker_derive::MsgPacker; +use proptest::prelude::*; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, MsgPacker)] +pub struct Bar { + #[serde(with = "serde_bytes")] + pub b: Vec, + pub s: String, + pub t: (u64, u64, bool, String), + pub u: (), + pub p: PhantomData, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, MsgPacker)] +pub enum Foo { + Bar, + Baz(u32, String), + Qux { + #[serde(with = "serde_bytes")] + a: Vec, + b: u64, + }, +} + +#[test] +fn serde_works_bool() { + case(true); +} + +#[test] +fn serde_works_i8() { + case(i8::MAX - 3); +} + +#[test] +fn serde_works_i16() { + case(i16::MAX - 3); +} + +#[test] +fn serde_works_i32() { + case(i32::MAX - 3); +} + +#[test] +fn serde_works_i64() { + case(i64::MAX - 3); +} + +#[test] +fn serde_works_i128() { + case(i128::MAX - 3); +} + +#[test] +fn serde_works_isize() { + case(isize::MAX - 3); +} + +#[test] +fn serde_works_u8() { + case(u8::MAX - 3); +} + +#[test] +fn serde_works_u16() { + case(u16::MAX - 3); +} + +#[test] +fn serde_works_u32() { + case(u32::MAX - 3); +} + +#[test] +fn serde_works_u64() { + case(u64::MAX - 3); +} + +#[test] +fn serde_works_u128() { + case(u128::MAX - 3); +} + +#[test] +fn serde_works_usize() { + case(usize::MAX - 3); +} + +#[test] +fn serde_works_f32() { + case(f32::MAX - 3.0); +} + +#[test] +fn serde_works_f64() { + case(f64::MAX - 3.0); +} + +#[test] +fn serde_works_char() { + case('x'); +} + +#[test] +fn serde_works_string() { + case("foo".to_string()); +} + +#[test] +fn serde_works_bytes() { + // we need to specify to serialize this vec as bytes since serde cannot distinguish the + // concrete type of vec, but messagepack has a special treatment for bytes array + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, MsgPacker)] + struct BytesWrapper(#[serde(with = "serde_bytes")] Vec); + case(BytesWrapper(b"foo".to_vec())); +} + +#[test] +fn serde_works_option_some() { + case(Some("foo".to_string())); +} + +#[test] +fn serde_works_option_none() { + case(Option::::None); +} + +#[test] +fn serde_works_unit() { + case(()); +} + +#[test] +fn serde_works_tuple() { + case((18, u64::MAX, false, "zz".to_string())); +} + +#[test] +fn serde_works_struct() { + case(Bar { + b: b"xxxx".to_vec(), + s: "yyy".to_string(), + t: (18, u64::MAX, false, "zz".to_string()), + u: (), + p: PhantomData, + }); +} + +#[test] +fn serde_works_enum_variant() { + case(Foo::Bar); +} + +#[test] +fn serde_works_enum_variant_tuple() { + case(Foo::Baz(15, "xxx".into())); +} + +#[test] +fn serde_works_enum_struct() { + case(Foo::Qux { + a: vec![1, 2, 3], + b: 42, + }); +} + +proptest! { + #[test] + fn serde_proptest_json(seed: [u8; 32]) { + let seed = Unstructured::new(&seed); + let value = ArbitraryValue::arbitrary_take_rest(seed).unwrap().take(); + + let mut bytes = vec![]; + msgpacker::serde::to_buffer(&mut bytes, &value); + + let y: Value = msgpacker::serde::from_slice(&bytes).unwrap(); + assert_eq!(value, y); + } +} + +pub fn case(x: T) +where + T: Packable + Serialize + for<'de> Deserialize<'de> + PartialEq + core::fmt::Debug, +{ + let mut bytes = vec![]; + + msgpacker::serde::to_buffer(&mut bytes, &x); + + let mut mbt = vec![]; + x.pack(&mut mbt); + assert_eq!(bytes, mbt); + + let y: T = msgpacker::serde::from_slice(&bytes).unwrap(); + + assert_eq!(x, y); +}