Skip to content

Commit

Permalink
feat: smart pivot - use cardinality from stats to inform pivot with `…
Browse files Browse the repository at this point in the history
…--validate` option; make the default agg count
  • Loading branch information
jqnatividad committed Dec 22, 2024
1 parent 138a501 commit ab9029a
Showing 1 changed file with 119 additions and 6 deletions.
125 changes: 119 additions & 6 deletions src/cmd/pivotp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ Usage:
qsv pivotp --help
pivotp arguments:
<on-cols> The column(s) to pivot on (creates new columns).
<input> is the input CSV file. The file must have headers.
Stdin is not supported.
<on-cols> The column(s) to pivot (creates new columns).
pivotp options:
Expand All @@ -36,9 +36,11 @@ pivotp options:
median - Median value
count - Count of values
last - Last value encountered
[default: count]
--sort-columns Sort the transposed columns by name. Default is by order of discovery.
--col-separator <arg> The separator in generated column names in case of multiple --values columns.
[default: _]
--validate Validate a pivot by checking the pivot column(s)' cardinality.
--try-parsedates When set, will attempt to parse columns as dates.
--infer-len <arg> Number of rows to scan when inferring schema.
Set to 0 to scan entire file. [default: 10000]
Expand All @@ -54,21 +56,28 @@ Common options:

use std::{fs::File, io, io::Write, path::Path};

use indicatif::HumanCount;
use polars::prelude::*;
use polars_ops::pivot::{pivot_stable, PivotAgg};
use serde::Deserialize;

use crate::{config::Delimiter, util, CliResult};
use crate::{
config::Delimiter,
util,
util::{get_stats_records, StatsMode},
CliResult,
};

#[derive(Deserialize)]
struct Args {
arg_input: String,
arg_on_cols: String,
arg_input: String,
flag_index: Option<String>,
flag_values: Option<String>,
flag_agg: Option<String>,
flag_sort_columns: bool,
flag_col_separator: String,
flag_validate: bool,
flag_try_parsedates: bool,
flag_infer_len: usize,
flag_decimal_comma: bool,
Expand All @@ -77,6 +86,103 @@ struct Args {
flag_delimiter: Option<Delimiter>,
}

/// Structure to hold pivot operation metadata
struct PivotMetadata {
estimated_columns: u64,
on_col_cardinalities: Vec<(String, u64)>,
}

/// Calculate pivot operation metadata using stats information
fn calculate_pivot_metadata(
args: &Args,
on_cols: &[String],
value_cols: Option<&Vec<String>>,
) -> CliResult<Option<PivotMetadata>> {
// Get stats records
let schema_args = util::SchemaArgs {
flag_enum_threshold: 0,
flag_ignore_case: false,
flag_strict_dates: false,
flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(),
flag_dates_whitelist: String::new(),
flag_prefer_dmy: false,
flag_force: false,
flag_stdout: false,
flag_jobs: None,
flag_no_headers: false,
flag_delimiter: args.flag_delimiter,
arg_input: Some(args.arg_input.clone()),
flag_memcheck: false,
};

let Ok((csv_fields, csv_stats)) =
get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
else {
return Ok(None);
};

if csv_stats.is_empty() {
return Ok(None);
}

// Get cardinalities for pivot columns
let mut on_col_cardinalities = Vec::with_capacity(on_cols.len());
let mut total_new_columns: u64 = 1;

for on_col in on_cols {
if let Some(pos) = csv_fields
.iter()
.position(|f| std::str::from_utf8(f).unwrap_or("") == on_col)
{
let cardinality = csv_stats[pos].cardinality;
total_new_columns = total_new_columns.saturating_mul(cardinality);
on_col_cardinalities.push((on_col.clone(), cardinality));
}
}

// Calculate total columns in result
let value_cols_count = match value_cols {
Some(cols) => cols.len() as u64,
None => 1,
};
let total_columns = total_new_columns.saturating_mul(value_cols_count);

Ok(Some(PivotMetadata {
estimated_columns: total_columns,
on_col_cardinalities,
}))
}

/// Validate pivot operation using metadata
fn validate_pivot_operation(metadata: &PivotMetadata) -> CliResult<()> {
const COLUMN_WARNING_THRESHOLD: u64 = 1000;

// Print cardinality information
eprintln!("Pivot column cardinalities:");
for (col, card) in &metadata.on_col_cardinalities {
eprintln!(" {col}: {}", HumanCount(*card));
}

// Warn about large number of columns
if metadata.estimated_columns > COLUMN_WARNING_THRESHOLD {
eprintln!(
"Warning: Pivot will create {} columns. This might impact performance.",
HumanCount(metadata.estimated_columns)
);
}

// Error if operation would create an unreasonable number of columns
if metadata.estimated_columns > 100_000 {
return fail_clierror!(
"Pivot would create too many columns ({}). Consider reducing the number of pivot \
columns or using a different approach.",
HumanCount(metadata.estimated_columns)
);
}

Ok(())
}

pub fn run(argv: &[&str]) -> CliResult<()> {
let args: Args = util::get_args(USAGE, argv)?;

Expand All @@ -89,7 +195,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
.collect();

// Parse index column(s)
let index_cols = if let Some(flag_index) = args.flag_index {
let index_cols = if let Some(ref flag_index) = args.flag_index {
let idx_cols: Vec<String> = flag_index
.as_str()
.split(',')
Expand All @@ -101,7 +207,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
};

// Parse values column(s)
let value_cols = if let Some(flag_values) = args.flag_values {
let value_cols = if let Some(ref flag_values) = args.flag_values {
let val_cols: Vec<String> = flag_values
.as_str()
.split(',')
Expand All @@ -119,7 +225,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
}

// Get aggregation function
let agg_fn = if let Some(agg) = args.flag_agg {
let agg_fn = if let Some(ref agg) = args.flag_agg {
Some(match agg.to_lowercase().as_str() {
"first" => PivotAgg::First,
"sum" => PivotAgg::Sum,
Expand Down Expand Up @@ -154,6 +260,13 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
// Read the CSV into a DataFrame
let df = csv_reader.finish()?.collect()?;

if args.flag_validate {
// Validate the operation
if let Some(metadata) = calculate_pivot_metadata(&args, &on_cols, value_cols.as_ref())? {
validate_pivot_operation(&metadata)?;
}
}

// Perform pivot operation
let mut pivot_result = pivot_stable(
&df,
Expand Down

0 comments on commit ab9029a

Please sign in to comment.