diff --git a/src/cmd/pivotp.rs b/src/cmd/pivotp.rs index 472fe632d..a4c048fba 100644 --- a/src/cmd/pivotp.rs +++ b/src/cmd/pivotp.rs @@ -5,7 +5,9 @@ The pivot operation consists of: - One or more index columns (these will be the new rows) - A column that will be pivoted (this will create the new columns) - A values column that will be aggregated -- An aggregation function to apply +- An aggregation function to apply. Features "smart" aggregation auto-selection. + +For examples, see https://github.com/dathere/qsv/blob/master/tests/test_pivotp.rs. Usage: qsv pivotp [options] @@ -36,7 +38,12 @@ pivotp options: median - Median value count - Count of values last - Last value encountered - [default: count] + none - No aggregation is done. Raises error if multiple values are in group. + smart - use value column statistics to pick an aggregation. + Will only work if there is one value column, otherwise + it falls back to `first` + smartq - same as smart, but no messages. + [default: smart] --sort-columns Sort the transposed columns by name. Default is by order of discovery. --col-separator The separator in generated column names in case of multiple --values columns. [default: _] @@ -44,7 +51,8 @@ pivotp options: --try-parsedates When set, will attempt to parse columns as dates. --infer-len Number of rows to scan when inferring schema. Set to 0 to scan entire file. [default: 10000] - --decimal-comma Use comma as decimal separator. + --decimal-comma Use comma as decimal separator when READING the input. + Note that you will need to specify an alternate --delimiter. --ignore-errors Skip rows that can't be parsed. Common options: @@ -54,20 +62,24 @@ Common options: Must be a single character. 
(default: ,) "#; -use std::{fs::File, io, io::Write, path::Path}; +use std::{fs::File, io, io::Write, path::Path, sync::OnceLock}; +use csv::ByteRecord; use indicatif::HumanCount; use polars::prelude::*; use polars_ops::pivot::{pivot_stable, PivotAgg}; use serde::Deserialize; use crate::{ - config::Delimiter, + cmd::stats::StatsData, + config::{Config, Delimiter}, util, util::{get_stats_records, StatsMode}, CliResult, }; +static STATS_RECORDS: OnceLock<(ByteRecord, Vec<StatsData>)> = OnceLock::new(); + #[derive(Deserialize)] struct Args { arg_on_cols: String, @@ -115,11 +127,10 @@ fn calculate_pivot_metadata( flag_memcheck: false, }; - let Ok((csv_fields, csv_stats)) = + let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| { get_stats_records(&schema_args, StatsMode::FrequencyForceStats) - else { - return Ok(None); - }; + .unwrap_or_else(|_| (ByteRecord::new(), Vec::new())) + }); if csv_stats.is_empty() { return Ok(None); @@ -183,6 +194,113 @@ fn validate_pivot_operation(metadata: &PivotMetadata) -> CliResult<()> { Ok(()) } +/// Suggest an appropriate aggregation function based on column statistics +#[allow(clippy::cast_precision_loss)] +fn suggest_agg_function( + args: &Args, + value_cols: &[String], + quiet: bool, +) -> CliResult<Option<PivotAgg>> { + let schema_args = util::SchemaArgs { + flag_enum_threshold: 0, + flag_ignore_case: false, + flag_strict_dates: false, + flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(), + flag_dates_whitelist: String::new(), + flag_prefer_dmy: false, + flag_force: false, + flag_stdout: false, + flag_jobs: None, + flag_no_headers: false, + flag_delimiter: args.flag_delimiter, + arg_input: Some(args.arg_input.clone()), + flag_memcheck: false, + }; + + let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| { + get_stats_records(&schema_args, StatsMode::FrequencyForceStats) + .unwrap_or_else(|_| (ByteRecord::new(), Vec::new())) + }); + + // If multiple value columns, default to First + if value_cols.len() > 1 { + return 
Ok(Some(PivotAgg::First)); + } + + // Get stats for the value column + let value_col = &value_cols[0]; + let field_pos = csv_fields + .iter() + .position(|f| std::str::from_utf8(f).unwrap_or("") == value_col); + + if let Some(pos) = field_pos { + let stats = &csv_stats[pos]; + let rconfig = Config::new(Some(&args.arg_input)); + let row_count = util::count_rows(&rconfig)? as u64; + + // Suggest aggregation based on field type and statistics + let suggested_agg = match stats.r#type.as_str() { + "NULL" => { + if !quiet { + eprintln!("Info: \"{value_col}\" contains only NULL values"); + } + PivotAgg::Count + }, + "Integer" | "Float" => { + if stats.nullcount as f64 / row_count as f64 > 0.5 { + if !quiet { + eprintln!("Info: \"{value_col}\" contains >50% NULL values, using Count"); + } + PivotAgg::Count + } else { + PivotAgg::Sum + } + }, + "Date" | "DateTime" => { + if stats.cardinality as f64 / row_count as f64 > 0.9 { + if !quiet { + eprintln!( + "Info: {} column \"{value_col}\" has high cardinality, using First", + stats.r#type + ); + } + PivotAgg::First + } else { + if !quiet { + eprintln!( + "Info: \"{value_col}\" is a {} column, using Count", + stats.r#type + ); + } + PivotAgg::Count + } + }, + _ => { + if stats.cardinality == row_count { + if !quiet { + eprintln!("Info: \"{value_col}\" contains all unique values, using First"); + } + PivotAgg::First + } else if stats.cardinality as f64 / row_count as f64 > 0.5 { + if !quiet { + eprintln!("Info: \"{value_col}\" has high cardinality, using Count"); + } + PivotAgg::Count + } else { + if !quiet { + eprintln!("Info: \"{value_col}\" is a String column, using Count"); + } + PivotAgg::Count + } + }, + }; + + Ok(Some(suggested_agg)) + } else { + Ok(None) + } +} + pub fn run(argv: &[&str]) -> CliResult<()> { let args: Args = util::get_args(USAGE, argv)?; @@ -226,17 +344,42 @@ pub fn run(argv: &[&str]) -> CliResult<()> { // Get aggregation function let agg_fn = if let Some(ref agg) = args.flag_agg { - Some(match 
agg.to_lowercase().as_str() { - "first" => PivotAgg::First, - "sum" => PivotAgg::Sum, - "min" => PivotAgg::Min, - "max" => PivotAgg::Max, - "mean" => PivotAgg::Mean, - "median" => PivotAgg::Median, - "count" => PivotAgg::Count, - "last" => PivotAgg::Last, - _ => return fail_clierror!("Invalid pivot aggregation function: {agg}"), - }) + let lower_agg = agg.to_lowercase(); + if lower_agg == "none" { + None + } else { + Some(match lower_agg.as_str() { + "first" => PivotAgg::First, + "sum" => PivotAgg::Sum, + "min" => PivotAgg::Min, + "max" => PivotAgg::Max, + "mean" => PivotAgg::Mean, + "median" => PivotAgg::Median, + "count" => PivotAgg::Count, + "last" => PivotAgg::Last, + "smart" | "smartq" => { + if let Some(value_cols) = &value_cols { + // Try to suggest an appropriate aggregation function + if let Some(suggested_agg) = + suggest_agg_function(&args, value_cols, lower_agg == "smartq")? + { + suggested_agg + } else { + // fallback to first, which always works + PivotAgg::First + } + } else { + // Default to Count if no value columns specified + PivotAgg::Count + } + }, + _ => { + return fail_incorrectusage_clierror!( + "Invalid pivot aggregation function: {agg}" + ) + }, + }) + } } else { None }; @@ -248,6 +391,12 @@ pub fn run(argv: &[&str]) -> CliResult<()> { b',' }; + if args.flag_decimal_comma && delim == b',' { + return fail_incorrectusage_clierror!( + "You need to specify an alternate --delimiter when using --decimal-comma." + ); + } + // Create CSV reader config let csv_reader = LazyCsvReader::new(&args.arg_input) .with_has_header(true)