From 66882071ad023e3a3dd4538bedee1f18932dbb6b Mon Sep 17 00:00:00 2001 From: Douglas Dwyer Date: Fri, 22 Sep 2023 16:30:02 -0400 Subject: [PATCH] Add the ability to provide a safe context for encoders to use Remove unsafe methods and clean up MaybeOwned enums --- src/stream/raw.rs | 124 ++++++++++++++++++++++++++++++---------- src/stream/read/mod.rs | 13 +++++ src/stream/write/mod.rs | 13 +++++ 3 files changed, 121 insertions(+), 29 deletions(-) diff --git a/src/stream/raw.rs b/src/stream/raw.rs index f79241df..25b5aff1 100644 --- a/src/stream/raw.rs +++ b/src/stream/raw.rs @@ -132,7 +132,7 @@ pub struct Status { /// An in-memory decoder for streams of data. pub struct Decoder<'a> { - context: zstd_safe::DCtx<'a>, + context: MaybeOwnedDCtx<'a>, } impl Decoder<'static> { @@ -148,11 +148,20 @@ impl Decoder<'static> { context .load_dictionary(dictionary) .map_err(map_error_code)?; - Ok(Decoder { context }) + Ok(Decoder { + context: MaybeOwnedDCtx::Owned(context), + }) } } impl<'a> Decoder<'a> { + /// Creates a new decoder which employs the provided context for deserialization. + pub fn with_context(context: &'a mut zstd_safe::DCtx<'static>) -> Self { + Self { + context: MaybeOwnedDCtx::Borrowed(context), + } + } + /// Creates a new decoder, using an existing `DecoderDictionary`. pub fn with_prepared_dictionary<'b>( dictionary: &DecoderDictionary<'b>, @@ -164,7 +173,9 @@ impl<'a> Decoder<'a> { context .ref_ddict(dictionary.as_ddict()) .map_err(map_error_code)?; - Ok(Decoder { context }) + Ok(Decoder { + context: MaybeOwnedDCtx::Owned(context), + }) } /// Creates a new decoder, using a ref prefix @@ -183,9 +194,11 @@ impl<'a> Decoder<'a> { /// Sets a decompression parameter for this decoder. pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> { - self.context - .set_parameter(parameter) - .map_err(map_error_code)?; + match &mut self.context { + MaybeOwnedDCtx::Owned(x) => x.set_parameter(parameter), + MaybeOwnedDCtx::Borrowed(x) => x.set_parameter(parameter), + } + .map_err(map_error_code)?; Ok(()) } } @@ -196,9 +209,11 @@ impl Operation for Decoder<'_> { input: &mut InBuffer<'_>, output: &mut OutBuffer<'_, C>, ) -> io::Result { - self.context - .decompress_stream(output, input) - .map_err(map_error_code) + match &mut self.context { + MaybeOwnedDCtx::Owned(x) => x.decompress_stream(output, input), + MaybeOwnedDCtx::Borrowed(x) => x.decompress_stream(output, input), + } + .map_err(map_error_code) } fn flush( @@ -219,9 +234,15 @@ impl Operation for Decoder<'_> { } fn reinit(&mut self) -> io::Result<()> { - self.context - .reset(zstd_safe::ResetDirective::SessionOnly) - .map_err(map_error_code)?; + match &mut self.context { + MaybeOwnedDCtx::Owned(x) => { + x.reset(zstd_safe::ResetDirective::SessionOnly) + } + MaybeOwnedDCtx::Borrowed(x) => { + x.reset(zstd_safe::ResetDirective::SessionOnly) + } + } + .map_err(map_error_code)?; Ok(()) } @@ -243,7 +264,7 @@ impl Operation for Decoder<'_> { /// An in-memory encoder for streams of data. pub struct Encoder<'a> { - context: zstd_safe::CCtx<'a>, + context: MaybeOwnedCCtx<'a>, } impl Encoder<'static> { @@ -264,11 +285,20 @@ impl Encoder<'static> { .load_dictionary(dictionary) .map_err(map_error_code)?; - Ok(Encoder { context }) + Ok(Encoder { + context: MaybeOwnedCCtx::Owned(context), + }) } } impl<'a> Encoder<'a> { + /// Creates a new encoder that uses the provided context for serialization. + pub fn with_context(context: &'a mut zstd_safe::CCtx<'static>) -> Self { + Self { + context: MaybeOwnedCCtx::Borrowed(context), + } + } + /// Creates a new encoder using an existing `EncoderDictionary`. pub fn with_prepared_dictionary<'b>( dictionary: &EncoderDictionary<'b>, @@ -280,7 +310,9 @@ impl<'a> Encoder<'a> { context .ref_cdict(dictionary.as_cdict()) .map_err(map_error_code)?; - Ok(Encoder { context }) + Ok(Encoder { + context: MaybeOwnedCCtx::Owned(context), + }) } /// Creates a new encoder initialized with the given ref prefix. @@ -306,9 +338,11 @@ impl<'a> Encoder<'a> { /// Sets a compression parameter for this encoder. pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> { - self.context - .set_parameter(parameter) - .map_err(map_error_code)?; + match &mut self.context { + MaybeOwnedCCtx::Owned(x) => x.set_parameter(parameter), + MaybeOwnedCCtx::Borrowed(x) => x.set_parameter(parameter), + } + .map_err(map_error_code)?; Ok(()) } @@ -324,9 +358,15 @@ impl<'a> Encoder<'a> { &mut self, pledged_src_size: Option, ) -> io::Result<()> { - self.context - .set_pledged_src_size(pledged_src_size) - .map_err(map_error_code)?; + match &mut self.context { + MaybeOwnedCCtx::Owned(x) => { + x.set_pledged_src_size(pledged_src_size) + } + MaybeOwnedCCtx::Borrowed(x) => { + x.set_pledged_src_size(pledged_src_size) + } + } + .map_err(map_error_code)?; Ok(()) } } @@ -337,16 +377,22 @@ impl<'a> Operation for Encoder<'a> { input: &mut InBuffer<'_>, output: &mut OutBuffer<'_, C>, ) -> io::Result { - self.context - .compress_stream(output, input) - .map_err(map_error_code) + match &mut self.context { + MaybeOwnedCCtx::Owned(x) => x.compress_stream(output, input), + MaybeOwnedCCtx::Borrowed(x) => x.compress_stream(output, input), + } + .map_err(map_error_code) } fn flush( &mut self, output: &mut OutBuffer<'_, C>, ) -> io::Result { - self.context.flush_stream(output).map_err(map_error_code) + match &mut self.context { + MaybeOwnedCCtx::Owned(x) => x.flush_stream(output), + MaybeOwnedCCtx::Borrowed(x) => x.flush_stream(output), + } + .map_err(map_error_code) } fn finish( @@ -354,17 +400,37 @@ impl<'a> Operation for Encoder<'a> { output: &mut OutBuffer<'_, C>, _finished_frame: bool, ) -> io::Result { - self.context.end_stream(output).map_err(map_error_code) + match &mut self.context { + MaybeOwnedCCtx::Owned(x) => x.end_stream(output), + MaybeOwnedCCtx::Borrowed(x) => x.end_stream(output), + } + .map_err(map_error_code) } fn reinit(&mut self) -> io::Result<()> { - self.context - .reset(zstd_safe::ResetDirective::SessionOnly) - .map_err(map_error_code)?; + match &mut self.context { + MaybeOwnedCCtx::Owned(x) => { + x.reset(zstd_safe::ResetDirective::SessionOnly) + } + MaybeOwnedCCtx::Borrowed(x) => { + x.reset(zstd_safe::ResetDirective::SessionOnly) + } + } + .map_err(map_error_code)?; Ok(()) } } +enum MaybeOwnedCCtx<'a> { + Owned(zstd_safe::CCtx<'a>), + Borrowed(&'a mut zstd_safe::CCtx<'static>), +} + +enum MaybeOwnedDCtx<'a> { + Owned(zstd_safe::DCtx<'a>), + Borrowed(&'a mut zstd_safe::DCtx<'static>), +} + #[cfg(test)] mod tests { diff --git a/src/stream/read/mod.rs b/src/stream/read/mod.rs index ee5e4137..06e6022f 100644 --- a/src/stream/read/mod.rs +++ b/src/stream/read/mod.rs @@ -46,6 +46,19 @@ impl Decoder<'static, R> { } } impl<'a, R: BufRead> Decoder<'a, R> { + /// Creates a new decoder which employs the provided context for deserialization. + pub fn with_context( + reader: R, + context: &'a mut zstd_safe::DCtx<'static>, + ) -> Self { + Self { + reader: zio::Reader::new( + reader, + raw::Decoder::with_context(context), + ), + } + } + /// Sets this `Decoder` to stop after the first frame. /// /// By default, it keeps concatenating frames until EOF is reached. diff --git a/src/stream/write/mod.rs b/src/stream/write/mod.rs index b9dd895c..3bc305f4 100644 --- a/src/stream/write/mod.rs +++ b/src/stream/write/mod.rs @@ -193,6 +193,19 @@ impl Encoder<'static, W> { } impl<'a, W: Write> Encoder<'a, W> { + /// Creates an encoder that uses the provided context to compress a stream. + pub fn with_context( + writer: W, + context: &'a mut zstd_safe::CCtx<'static>, + ) -> Self { + Self { + writer: zio::Writer::new( + writer, + raw::Encoder::with_context(context), + ), + } + } + /// Creates a new encoder, using an existing prepared `EncoderDictionary`. /// /// (Provides better compression ratio for small files,