diff --git a/src/decoder.rs b/src/decoder.rs index 66d64b1..38b8b01 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -155,10 +155,6 @@ impl SourceBlockDecoder { ) -> SourceBlockDecoder { let source_symbols = int_div_ceil(block_length, config.symbol_size() as u64); - let mut received_esi = Set::new(); - for i in source_symbols..extended_source_block_symbols(source_symbols) { - received_esi.insert(i); - } SourceBlockDecoder { source_block_id, symbol_size: config.symbol_size(), @@ -168,7 +164,7 @@ impl SourceBlockDecoder { source_symbols: vec![None; source_symbols as usize], repair_packets: vec![], received_source_symbols: 0, - received_esi, + received_esi: Set::new(), decoded: false, sparse_threshold: SPARSE_MATRIX_THRESHOLD, } @@ -253,15 +249,12 @@ impl SourceBlockDecoder { ); let (payload_id, payload) = packet.split(); - let num_extended_symbols = extended_source_block_symbols(self.source_block_symbols); if self.received_esi.insert(payload_id.encoding_symbol_id()) { - if payload_id.encoding_symbol_id() >= num_extended_symbols { + if payload_id.encoding_symbol_id() >= self.source_block_symbols { // Repair symbol self.repair_packets .push(EncodingPacket::new(payload_id, payload)); } else { - // Check that this is not an extended symbol (which aren't explicitly sent) - assert!(payload_id.encoding_symbol_id() < self.source_block_symbols); // Source symbol self.source_symbols[payload_id.encoding_symbol_id() as usize] = Some(Symbol::new(payload)); @@ -271,6 +264,14 @@ impl SourceBlockDecoder { } let num_extended_symbols = extended_source_block_symbols(self.source_block_symbols); + let num_padding_symbols = num_extended_symbols - self.source_block_symbols; + + // Case 1: the number of received packets is insufficient for decoding + if self.received_esi.len() < self.source_block_symbols as usize { + return None; + } + + // Case 2: we have all source symbols and can return them without decoding if self.received_source_symbols == self.source_block_symbols { let mut result = vec![0; self.symbol_size as usize * self.source_block_symbols as usize]; @@ -282,46 +283,46 @@ impl SourceBlockDecoder { return Some(result); } - if self.received_esi.len() as u32 >= num_extended_symbols { - let s = num_ldpc_symbols(self.source_block_symbols) as usize; - let h = num_hdpc_symbols(self.source_block_symbols) as usize; - - let mut encoded_indices = vec![]; - // See section 5.3.3.4.2. There are S + H zero symbols to start the D vector - let mut d = vec![Symbol::zero(self.symbol_size); s + h]; - for (i, source) in self.source_symbols.iter().enumerate() { - if let Some(symbol) = source { - encoded_indices.push(i as u32); - d.push(symbol.clone()); - } + // Case 3: we may have sufficient symbols to do a standard decoding + let s = num_ldpc_symbols(self.source_block_symbols) as usize; + let h = num_hdpc_symbols(self.source_block_symbols) as usize; + + let mut encoded_isis = vec![]; + // See section 5.3.3.4.2. There are S + H zero symbols to start the D vector + let mut d = vec![Symbol::zero(self.symbol_size); s + h]; + for (i, source) in self.source_symbols.iter().enumerate() { + if let Some(symbol) = source { + encoded_isis.push(i as u32); + d.push(symbol.clone()); } + } - // Append the extended padding symbols - for i in self.source_block_symbols..num_extended_symbols { - encoded_indices.push(i); - d.push(Symbol::zero(self.symbol_size)); - } + // Append the extended padding symbols + for i in self.source_block_symbols..num_extended_symbols { + encoded_isis.push(i); + d.push(Symbol::zero(self.symbol_size)); + } - for repair_packet in self.repair_packets.iter() { - encoded_indices.push(repair_packet.payload_id.encoding_symbol_id()); - d.push(Symbol::new(repair_packet.data.clone())); - } + // Append the received repair symbols + for repair_packet in self.repair_packets.iter() { + // We need to convert from ESI to ISI + encoded_isis.push(repair_packet.payload_id.encoding_symbol_id() + num_padding_symbols); + d.push(Symbol::new(repair_packet.data.clone())); + } - if extended_source_block_symbols(self.source_block_symbols) >= self.sparse_threshold { - let (constraint_matrix, hdpc) = generate_constraint_matrix::( - self.source_block_symbols, - &encoded_indices, - ); - return self.try_pi_decode(constraint_matrix, hdpc, d); - } else { - let (constraint_matrix, hdpc) = generate_constraint_matrix::( - self.source_block_symbols, - &encoded_indices, - ); - return self.try_pi_decode(constraint_matrix, hdpc, d); - } + if num_extended_symbols >= self.sparse_threshold { + let (constraint_matrix, hdpc) = generate_constraint_matrix::( + self.source_block_symbols, + &encoded_isis, + ); + self.try_pi_decode(constraint_matrix, hdpc, d) + } else { + let (constraint_matrix, hdpc) = generate_constraint_matrix::( + self.source_block_symbols, + &encoded_isis, + ); + self.try_pi_decode(constraint_matrix, hdpc, d) } - None } fn rebuild_source_symbol( @@ -346,16 +347,6 @@ impl SourceBlockDecoder { #[cfg(feature = "std")] #[cfg(test)] mod codec_tests { - #[cfg(not(feature = "python"))] - use crate::Decoder; - use crate::SourceBlockEncoder; - use crate::SourceBlockEncodingPlan; - #[cfg(not(feature = "python"))] - use crate::{Encoder, EncoderBuilder}; - use crate::{ObjectTransmissionInformation, SourceBlockDecoder}; - #[cfg(not(feature = "python"))] - use rand::seq::SliceRandom; - use rand::Rng; use std::{ iter, sync::{ @@ -365,6 +356,19 @@ mod codec_tests { vec::Vec, }; + #[cfg(not(feature = "python"))] + use rand::seq::SliceRandom; + use rand::Rng; + + #[cfg(not(feature = "python"))] + use crate::Decoder; + #[cfg(not(feature = "python"))] + use crate::{Encoder, EncoderBuilder}; + use crate::{ + ObjectTransmissionInformation, SourceBlockDecoder, SourceBlockEncoder, + SourceBlockEncodingPlan, + }; + #[cfg(not(feature = "python"))] #[test] fn random_erasure_dense() { diff --git a/src/encoder.rs b/src/encoder.rs index 998b4de..1b68120 100644 --- a/src/encoder.rs +++ b/src/encoder.rs @@ -321,7 +321,10 @@ impl SourceBlockEncoder { for i in 0..packets { let tuple = intermediate_tuple(start_encoding_symbol_id + i, lt_symbols, sys_index, p1); result.push(EncodingPacket::new( - PayloadId::new(self.source_block_id, start_encoding_symbol_id + i), + PayloadId::new( + self.source_block_id, + self.source_symbols.len() as u32 + start_repair_symbol_id + i, + ), enc( self.source_symbols.len() as u32, &self.intermediate_symbols, @@ -443,15 +446,16 @@ mod tests { use rand::Rng; use std::vec::Vec; + use super::*; + use crate::base::intermediate_tuple; - use crate::encoder::enc; - use crate::encoder::gen_intermediate_symbols; use crate::symbol::Symbol; use crate::systematic_constants::num_lt_symbols; use crate::systematic_constants::num_pi_symbols; use crate::systematic_constants::{ calculate_p1, num_ldpc_symbols, systematic_index, MAX_SOURCE_SYMBOLS_PER_BLOCK, }; + use crate::PayloadId; #[cfg(not(feature = "python"))] use crate::{Encoder, EncoderBuilder, EncodingPacket, ObjectTransmissionInformation}; #[cfg(not(feature = "python"))] @@ -555,6 +559,33 @@ mod tests { } } + #[test] + fn encoding_creates_expected_packets() { + let symbol_size = 2; + let data: [u8; 6] = [0, 1, 2, 3, 4, 5]; + let encoder = SourceBlockEncoder::new2( + 0, + &ObjectTransmissionInformation::new(0, symbol_size, 1, 1, 1), + &data, + ); + assert_eq!( + encoder.source_packets(), + [[0, 1], [2, 3], [4, 5]] + .into_iter() + .enumerate() + .map(|(i, d)| EncodingPacket::new(PayloadId::new(0, i as u32), d.into())) + .collect::>() + ); + assert_eq!( + encoder + .repair_packets(2, 4) + .into_iter() + .map(|p| p.payload_id.encoding_symbol_id()) + .collect::>(), + &[5, 6, 7, 8] + ); + } + #[cfg(not(feature = "python"))] #[test] fn test_builder() {