Skip to content

Commit

Permalink
feat: an even smarter pivotp, featuring auto-aggregation selection …
Browse files Browse the repository at this point in the history
…based on type/stats
  • Loading branch information
jqnatividad committed Dec 22, 2024
1 parent 3f66c50 commit 01c9d24
Showing 1 changed file with 169 additions and 20 deletions.
189 changes: 169 additions & 20 deletions src/cmd/pivotp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ The pivot operation consists of:
- One or more index columns (these will be the new rows)
- A column that will be pivoted (this will create the new columns)
- A values column that will be aggregated
- An aggregation function to apply
- An aggregation function to apply. Features "smart" aggregation auto-selection.
For examples, see https://github.com/dathere/qsv/blob/master/tests/test_pivotp.rs.
Usage:
qsv pivotp [options] <on-cols> <input>
Expand Down Expand Up @@ -36,15 +38,21 @@ pivotp options:
median - Median value
count - Count of values
last - Last value encountered
[default: count]
none - No aggregation is done. Raises an error if a group
contains multiple values.
smart - Use value column statistics to pick an aggregation.
Only works if there is a single value column; otherwise
it falls back to `first`.
smartq - Same as smart, but suppresses informational messages.
[default: smart]
--sort-columns Sort the transposed columns by name. Default is by order of discovery.
--col-separator <arg> The separator in generated column names in case of multiple --values columns.
[default: _]
--validate Validate a pivot by checking the pivot column(s)' cardinality.
--try-parsedates When set, will attempt to parse columns as dates.
--infer-len <arg> Number of rows to scan when inferring schema.
Set to 0 to scan entire file. [default: 10000]
--decimal-comma Use comma as decimal separator.
--decimal-comma Use comma as decimal separator when READING the input.
Note that you will need to specify an alternate --delimiter.
--ignore-errors Skip rows that can't be parsed.
Common options:
Expand All @@ -54,20 +62,24 @@ Common options:
Must be a single character. (default: ,)
"#;

use std::{fs::File, io, io::Write, path::Path};
use std::{fs::File, io, io::Write, path::Path, sync::OnceLock};

use csv::ByteRecord;
use indicatif::HumanCount;
use polars::prelude::*;
use polars_ops::pivot::{pivot_stable, PivotAgg};
use serde::Deserialize;

use crate::{
config::Delimiter,
cmd::stats::StatsData,
config::{Config, Delimiter},
util,
util::{get_stats_records, StatsMode},
CliResult,
};

static STATS_RECORDS: OnceLock<(ByteRecord, Vec<StatsData>)> = OnceLock::new();

#[derive(Deserialize)]
struct Args {
arg_on_cols: String,
Expand Down Expand Up @@ -115,11 +127,10 @@ fn calculate_pivot_metadata(
flag_memcheck: false,
};

let Ok((csv_fields, csv_stats)) =
let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| {
get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
else {
return Ok(None);
};
.unwrap_or_else(|_| (ByteRecord::new(), Vec::new()))
});

if csv_stats.is_empty() {
return Ok(None);
Expand Down Expand Up @@ -183,6 +194,113 @@ fn validate_pivot_operation(metadata: &PivotMetadata) -> CliResult<()> {
Ok(())
}

/// Suggest an appropriate aggregation function based on column statistics.
///
/// Heuristics (applied only when there is exactly one value column):
/// - NULL-typed column                     -> Count
/// - Integer/Float with >50% NULL values   -> Count, otherwise Sum
/// - Date/DateTime with >90% unique values -> First, otherwise Count
/// - anything else (String): all unique    -> First,
///   >50% unique -> Count, otherwise Count
///
/// With multiple value columns, per-column heuristics don't apply and
/// `First` is returned. Returns `Ok(None)` when no suggestion can be made
/// (no value column, or the column is missing from the stats), letting the
/// caller fall back to its own default. Unless `quiet` is set, an
/// informational message explaining the choice is printed to stderr.
#[allow(clippy::cast_precision_loss)]
fn suggest_agg_function(
    args: &Args,
    value_cols: &[String],
    quiet: bool,
) -> CliResult<Option<PivotAgg>> {
    // If multiple value columns, default to First. Checked before computing
    // stats so we don't pay for a stats pass whose result we'd discard.
    if value_cols.len() > 1 {
        return Ok(Some(PivotAgg::First));
    }

    // Guard against an empty slice — indexing would panic.
    let Some(value_col) = value_cols.first() else {
        return Ok(None);
    };

    let schema_args = util::SchemaArgs {
        flag_enum_threshold: 0,
        flag_ignore_case: false,
        flag_strict_dates: false,
        flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(),
        flag_dates_whitelist: String::new(),
        flag_prefer_dmy: false,
        flag_force: false,
        flag_stdout: false,
        flag_jobs: None,
        flag_no_headers: false,
        flag_delimiter: args.flag_delimiter,
        arg_input: Some(args.arg_input.clone()),
        flag_memcheck: false,
    };

    // Stats are computed at most once per process; a failed stats run is
    // memoized as an empty header/stats pair so we degrade gracefully.
    let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| {
        get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
            .unwrap_or_else(|_| (ByteRecord::new(), Vec::new()))
    });

    // Locate the value column in the stats header.
    let Some(pos) = csv_fields
        .iter()
        .position(|f| std::str::from_utf8(f).unwrap_or("") == value_col)
    else {
        return Ok(None);
    };

    // Use checked indexing: a stats vector shorter than the header (e.g. the
    // empty fallback above) must not panic.
    let Some(stats) = csv_stats.get(pos) else {
        return Ok(None);
    };

    let rconfig = Config::new(Some(&args.arg_input));
    let row_count = util::count_rows(&rconfig)? as u64;

    // Avoid NaN ratios on an empty input; with no rows, any aggregation is
    // equivalent, so pick the always-safe Count.
    if row_count == 0 {
        return Ok(Some(PivotAgg::Count));
    }

    // Suggest aggregation based on field type and statistics
    let suggested_agg = match stats.r#type.as_str() {
        "NULL" => {
            if !quiet {
                eprintln!("Info: \"{value_col}\" contains only NULL values");
            }
            PivotAgg::Count
        },
        "Integer" | "Float" => {
            // Summing a mostly-NULL column is misleading; count instead.
            if stats.nullcount as f64 / row_count as f64 > 0.5 {
                if !quiet {
                    eprintln!("Info: \"{value_col}\" contains >50% NULL values, using Count");
                }
                PivotAgg::Count
            } else {
                PivotAgg::Sum
            }
        },
        "Date" | "DateTime" => {
            // Near-unique temporal columns are effectively identifiers.
            if stats.cardinality as f64 / row_count as f64 > 0.9 {
                if !quiet {
                    eprintln!(
                        "Info: {} column \"{value_col}\" has high cardinality, using First",
                        stats.r#type
                    );
                }
                PivotAgg::First
            } else {
                if !quiet {
                    eprintln!(
                        "Info: \"{value_col}\" is a {} column, using Count",
                        stats.r#type
                    );
                }
                PivotAgg::Count
            }
        },
        _ => {
            if stats.cardinality == row_count {
                if !quiet {
                    eprintln!("Info: \"{value_col}\" contains all unique values, using First");
                }
                PivotAgg::First
            } else if stats.cardinality as f64 / row_count as f64 > 0.5 {
                if !quiet {
                    eprintln!("Info: \"{value_col}\" has high cardinality, using Count");
                }
                PivotAgg::Count
            } else {
                if !quiet {
                    eprintln!("Info: \"{value_col}\" is a String column, using Count");
                }
                PivotAgg::Count
            }
        },
    };

    Ok(Some(suggested_agg))
}

pub fn run(argv: &[&str]) -> CliResult<()> {
let args: Args = util::get_args(USAGE, argv)?;

Expand Down Expand Up @@ -226,17 +344,42 @@ pub fn run(argv: &[&str]) -> CliResult<()> {

// Get aggregation function
let agg_fn = if let Some(ref agg) = args.flag_agg {
Some(match agg.to_lowercase().as_str() {
"first" => PivotAgg::First,
"sum" => PivotAgg::Sum,
"min" => PivotAgg::Min,
"max" => PivotAgg::Max,
"mean" => PivotAgg::Mean,
"median" => PivotAgg::Median,
"count" => PivotAgg::Count,
"last" => PivotAgg::Last,
_ => return fail_clierror!("Invalid pivot aggregation function: {agg}"),
})
let lower_agg = agg.to_lowercase();
if lower_agg == "none" {
None
} else {
Some(match lower_agg.as_str() {
"first" => PivotAgg::First,
"sum" => PivotAgg::Sum,
"min" => PivotAgg::Min,
"max" => PivotAgg::Max,
"mean" => PivotAgg::Mean,
"median" => PivotAgg::Median,
"count" => PivotAgg::Count,
"last" => PivotAgg::Last,
"smart" | "smartq" => {
if let Some(value_cols) = &value_cols {
// Try to suggest an appropriate aggregation function
if let Some(suggested_agg) =
suggest_agg_function(&args, value_cols, lower_agg == "smartq")?
{
suggested_agg
} else {
// fallback to first, which always works
PivotAgg::First
}
} else {
// Default to Count if no value columns specified
PivotAgg::Count
}
},
_ => {
return fail_incorrectusage_clierror!(
"Invalid pivot aggregation function: {agg}"
)
},
})
}
} else {
None
};
Expand All @@ -248,6 +391,12 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
b','
};

if args.flag_decimal_comma && delim == b',' {
return fail_incorrectusage_clierror!(
"You need to specify an alternate --delimiter when using --decimal-comma."
);
}

// Create CSV reader config
let csv_reader = LazyCsvReader::new(&args.arg_input)
.with_has_header(true)
Expand Down

0 comments on commit 01c9d24

Please sign in to comment.