diff --git a/addon/decompressor.h b/addon/decompressor.h index 917a170..1e365fa 100644 --- a/addon/decompressor.h +++ b/addon/decompressor.h @@ -6,25 +6,34 @@ using namespace Napi; +using DCTX_Deleter = void (*)(ZSTD_DCtx*); + struct Decompressor { std::vector data; - size_t buffer_size; - Decompressor(std::vector data, size_t buffer_size) - : data(data), buffer_size(buffer_size) {} + Decompressor(std::vector data) : data(data) {} CompressionResult operator()() { - std::vector decompressed(buffer_size); + std::vector decompressed; - size_t _result = - ZSTD_decompress(decompressed.data(), decompressed.size(), data.data(), data.size()); + std::unique_ptr decompression_context(ZSTD_createDCtx(), + (DCTX_Deleter)ZSTD_freeDCtx); - if (ZSTD_isError(_result)) { - std::string error(ZSTD_getErrorName(_result)); - return CompressionResult::Error(error); - } + ZSTD_inBuffer input = {data.data(), data.size(), 0}; - decompressed.resize(_result); + while (input.pos < input.size) { + std::vector chunk(ZSTD_DStreamOutSize()); + ZSTD_outBuffer output = {chunk.data(), chunk.size(), 0}; + size_t const ret = ZSTD_decompressStream(decompression_context.get(), &output, &input); + if (ZSTD_isError(ret)) { + std::string error(ZSTD_getErrorName(ret)); + return CompressionResult::Error(error); + } + + for (size_t i = 0; i < output.pos; ++i) { + decompressed.push_back(chunk[i]); + } + } return CompressionResult::Ok(decompressed); } @@ -36,6 +45,6 @@ struct Decompressor { std::vector data(total); std::copy(input_data, input_data + total, data.data()); - return Decompressor(data, total * 1000); + return Decompressor(data); } }; diff --git a/test/index.test.js b/test/index.test.js index 213179e..3c53841 100644 --- a/test/index.test.js +++ b/test/index.test.js @@ -33,6 +33,10 @@ describe('compress', function () { test('compress() returns a Nodejs buffer', async function () { expect(await compress(Buffer.from([1, 2, 3]))).to.be.instanceOf(Buffer); }); + + test('decompress() with empty buffer', async function () { + expect(await decompress(Buffer.from([]))).to.deep.equal(Buffer.from([])) + }) }); /**