Skip to content

Commit

Permalink
Merge pull request #1532 from jqnatividad/sample_rng_kinds
Browse files Browse the repository at this point in the history
`sample`: replace `--faster` RNG sampling option with `--rng <kind>` option
  • Loading branch information
jqnatividad authored Jan 7, 2024
2 parents c7b704b + 87b6ffe commit ffe098c
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 69 deletions.
20 changes: 20 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ qsv-sniffer = { version = "0.10", default-features = false, features = [
"runtime-dispatch-simd",
] }
rand = "0.8"
rand_hc = "0.3"
rand_xoshiro = "0.6"
rayon = "1.8"
redis = { version = "0.24", features = [
"ahash",
Expand Down
211 changes: 150 additions & 61 deletions src/cmd/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,16 @@ sample arguments:
sample options:
--seed <number> Random Number Generator (RNG) seed.
--faster Use a faster RNG that uses the Wyrand algorithm instead
of the ChaCha algorithm used by the standard RNG.
--rng <kind> The RNG algorithm to use.
Three RNGs are supported:
- standard: Use the standard RNG.
1.5 GB/s throughput.
- faster: Use faster RNG using the Xoshiro256Plus algorithm.
8 GB/s throughput.
- cryptosecure: Use cryptographically secure HC128 algorithm.
Recommended by eSTREAM (https://www.ecrypt.eu.org/stream/).
2.1 GB/s throughput though slow initialization.
[default: standard]
--user-agent <agent> Specify custom user agent to use when the input is a URL.
It supports the following variables -
$QSV_VERSION, $QSV_TARGET, $QSV_BIN_NAME, $QSV_KIND and $QSV_COMMAND.
Expand All @@ -55,11 +63,13 @@ Common options:
Must be a single character. (default: ,)
"#;

use std::io;
use std::{io, str::FromStr};

use fastrand; //DevSkim: ignore DS148264
use rand::{self, rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
use rand_hc::Hc128Rng;
use rand_xoshiro::Xoshiro256Plus;
use serde::Deserialize;
use strum_macros::EnumString;
use tempfile::NamedTempFile;
use url::Url;

Expand All @@ -75,19 +85,34 @@ struct Args {
flag_output: Option<String>,
flag_no_headers: bool,
flag_delimiter: Option<Delimiter>,
flag_seed: Option<usize>,
flag_faster: bool,
flag_seed: Option<u64>,
flag_rng: String,
flag_user_agent: Option<String>,
flag_timeout: Option<u16>,
}

/// The random number generator selected with the `--rng <kind>` option.
///
/// `run` parses the option's string value with `RngKind::from_str`; the
/// `ascii_case_insensitive` strum attribute below makes that parse ignore
/// ASCII letter case. A value matching no variant is rejected by `run` as
/// incorrect usage.
#[derive(Debug, EnumString, PartialEq)]
#[strum(ascii_case_insensitive)]
enum RngKind {
/// Sample with `rand::rngs::StdRng`.
Standard,
/// Sample with `rand_xoshiro::Xoshiro256Plus`.
Faster,
/// Sample with `rand_hc::Hc128Rng`.
Cryptosecure,
}

pub fn run(argv: &[&str]) -> CliResult<()> {
let mut args: Args = util::get_args(USAGE, argv)?;

if args.arg_sample_size.is_sign_negative() {
return fail_incorrectusage_clierror!("Sample size cannot be negative.");
}

let Ok(rng_kind) = RngKind::from_str(&args.flag_rng) else {
return fail_incorrectusage_clierror!(
"Invalid RNG algorithm `{}`. Supported RNGs are: standard, faster, cryptosecure.",
args.flag_rng
);
};

let temp_download = NamedTempFile::new()?;

args.arg_input = match args.arg_input {
Expand Down Expand Up @@ -135,25 +160,50 @@ pub fn run(argv: &[&str]) -> CliResult<()> {

let mut all_indices = (0..idx.count()).collect::<Vec<_>>();

if args.flag_faster {
log::info!(
"doing --faster sample_random_access. Seed: {:?}",
args.flag_seed
);
if let Some(seed) = args.flag_seed {
fastrand::seed(seed as u64); //DevSkim: ignore DS148264
}
all_indices = fastrand::choose_multiple(all_indices.into_iter(), sample_size as usize); //DevSkim: ignore DS148264
} else {
log::info!(
"doing standard sample_random_access. Seed: {:?}",
args.flag_seed
);
let mut rng: StdRng = match args.flag_seed {
None => StdRng::from_rng(rand::thread_rng()).unwrap(),
Some(seed) => StdRng::seed_from_u64(seed as u64), //DevSkim: ignore DS148264
};
SliceRandom::shuffle(&mut *all_indices, &mut rng); //DevSkim: ignore DS148264
match rng_kind {
RngKind::Standard => {
log::info!(
"doing standard sample_random_access. Seed: {:?}",
args.flag_seed
);
let mut rng: StdRng = match args.flag_seed {
None => StdRng::from_rng(rand::thread_rng()).unwrap(),
Some(seed) => StdRng::seed_from_u64(seed), //DevSkim: ignore DS148264
};
SliceRandom::shuffle(&mut *all_indices, &mut rng); //DevSkim: ignore DS148264
},
RngKind::Faster => {
log::info!(
"doing --faster sample_random_access. Seed: {:?}",
args.flag_seed
);

let mut rng = match args.flag_seed {
None => Xoshiro256Plus::from_rng(rand::thread_rng()).unwrap(),
Some(seed) => Xoshiro256Plus::seed_from_u64(seed), //DevSkim: ignore DS148264
};
SliceRandom::shuffle(&mut *all_indices, &mut rng); //DevSkim: ignore DS148264
},
RngKind::Cryptosecure => {
log::info!(
"doing cryptosecure sample_random_access. Seed: {:?}",
args.flag_seed
);
let seed_32 = match args.flag_seed {
None => rand::thread_rng().gen::<[u8; 32]>(),
Some(seed) => {
let seed_u8 = seed.to_le_bytes();
let mut seed_32 = [0u8; 32];
seed_32[..8].copy_from_slice(&seed_u8);
seed_32
},
};
let mut rng: Hc128Rng = match args.flag_seed {
None => Hc128Rng::from_rng(rand::thread_rng()).unwrap(),
Some(_) => Hc128Rng::from_seed(seed_32),
};
SliceRandom::shuffle(&mut *all_indices, &mut rng);
},
}

for i in all_indices.into_iter().take(sample_size as usize) {
Expand All @@ -171,12 +221,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
}
let mut rdr = rconfig.reader()?;
rconfig.write_headers(&mut rdr, &mut wtr)?;
let sampled = sample_reservoir(
&mut rdr,
sample_size as u64,
args.flag_seed,
args.flag_faster,
)?;
let sampled = sample_reservoir(&mut rdr, sample_size as u64, args.flag_seed, &rng_kind)?;
for row in sampled {
wtr.write_byte_record(&row)?;
}
Expand All @@ -188,8 +233,8 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
fn sample_reservoir<R: io::Read>(
rdr: &mut csv::Reader<R>,
sample_size: u64,
seed: Option<usize>,
faster: bool,
seed: Option<u64>,
rng_kind: &RngKind,
) -> CliResult<Vec<csv::ByteRecord>> {
// The following algorithm has been adapted from:
// https://en.wikipedia.org/wiki/Reservoir_sampling
Expand All @@ -199,37 +244,81 @@ fn sample_reservoir<R: io::Read>(
reservoir.push(row?);
}

if faster {
log::info!("doing --faster sample_reservoir. Seed: {seed:?}");
if let Some(seed) = seed {
fastrand::seed(seed as u64); //DevSkim: ignore DS148264
}
match *rng_kind {
RngKind::Standard => {
log::info!("doing standard sample_random_access. Seed: {seed:?}",);
let mut rng: StdRng = match seed {
None => StdRng::from_rng(rand::thread_rng()).unwrap(),
// the non-cryptographic seed_from_u64 is sufficient for our use case
// as we're optimizing for performance
Some(seed) => StdRng::seed_from_u64(seed), //DevSkim: ignore DS148264
};

let mut random: usize;
for (i, row) in records {
random = fastrand::usize(0..=i); //DevSkim: ignore DS148264
if random < sample_size as usize {
reservoir[random] = row?;
let mut random: usize;
// Now do the sampling.
for (i, row) in records {
random = rng.gen_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
}
}
}
} else {
log::info!("doing standard sample_reservoir. Seed: {seed:?}");
// Seeding RNG
let mut rng: StdRng = match seed {
None => StdRng::from_rng(rand::thread_rng()).unwrap(),
// the non-cryptographic seed_from_u64 is sufficient for our use case
// as we're optimizing for performance
Some(seed) => StdRng::seed_from_u64(seed as u64), //DevSkim: ignore DS148264
};

let mut random: usize;
// Now do the sampling.
for (i, row) in records {
random = rng.gen_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
},
RngKind::Faster => {
log::info!("doing --faster sample_random_access. Seed: {seed:?}",);

let mut rng = match seed {
None => Xoshiro256Plus::from_rng(rand::thread_rng()).unwrap(),
// the non-cryptographic seed_from_u64 is sufficient for our use case
// as we're optimizing for performance
Some(seed) => Xoshiro256Plus::seed_from_u64(seed), //DevSkim: ignore DS148264
};

let mut random: usize;
// Now do the sampling.
for (i, row) in records {
random = rng.gen_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
}
}
}

// if let Some(seed) = seed {
// fastrand::seed(seed); //DevSkim: ignore DS148264
// }

// let mut random: usize;
// for (i, row) in records {
// random = fastrand::usize(0..=i); //DevSkim: ignore DS148264
// if random < sample_size as usize {
// reservoir[random] = row?;
// }
// }
},
RngKind::Cryptosecure => {
log::info!("doing cryptosecure sample_random_access. Seed: {seed:?}",);

let seed_32 = match seed {
None => rand::thread_rng().gen::<[u8; 32]>(),
Some(seed) => {
let seed_u8 = seed.to_le_bytes();
let mut seed_32 = [0u8; 32];
seed_32[..8].copy_from_slice(&seed_u8);
seed_32
},
};
let mut rng: Hc128Rng = match seed {
None => Hc128Rng::from_rng(rand::thread_rng()).unwrap(),
Some(_) => Hc128Rng::from_seed(seed_32),
};

for (i, row) in records {
let random = rng.gen_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
}
}
},
}

Ok(reservoir)
}
Loading

0 comments on commit ffe098c

Please sign in to comment.