feat: an even smarter pivotp #2368

Merged · 3 commits · Dec 22, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -63,7 +63,7 @@
| [lens](/src/cmd/lens.rs#L2)✨ | Interactively view, search & filter a CSV using the [csvlens](https://github.com/YS-L/csvlens#csvlens) engine. |
| <a name="luau_deeplink"></a><br>[luau](/src/cmd/luau.rs#L2) 👑✨<br>📇🌐🔣📚 ![CKAN](docs/images/ckan.png) | Create multiple new computed columns, filter rows, compute aggregations and build complex data pipelines by executing a [Luau](https://luau-lang.org) [0.653](https://github.com/Roblox/luau/releases/tag/0.653) expression/script for every row of a CSV file ([sequential mode](https://github.com/dathere/qsv/blob/bb72c4ef369d192d85d8b7cc6e972c1b7df77635/tests/test_luau.rs#L254-L298)), or using [random access](https://www.webopedia.com/definitions/random-access/) with an index ([random access mode](https://github.com/dathere/qsv/blob/bb72c4ef369d192d85d8b7cc6e972c1b7df77635/tests/test_luau.rs#L367-L415)).<br>Can process a single Luau expression or [full-fledged data-wrangling scripts using lookup tables](https://github.com/dathere/qsv-lookup-tables#example) with discrete BEGIN, MAIN and END sections.<br> It is not just another qsv command, it is qsv's [Domain-specific Language](https://en.wikipedia.org/wiki/Domain-specific_language) (DSL) with [numerous qsv-specific helper functions](https://github.com/dathere/qsv/blob/113eee17b97882dc368b2e65fec52b86df09f78b/src/cmd/luau.rs#L1356-L2290) to build production data pipelines. |
| [partition](/src/cmd/partition.rs#L2)<br>👆 | Partition a CSV based on a column value. |
-| [pivotp](/src/cmd/pivotp.rs#L2)✨<br>🚀🐻‍❄️🪄 | Pivot CSV data. |
+| [pivotp](/src/cmd/pivotp.rs#L2)✨<br>🚀🐻‍❄️🪄 | Pivot CSV data. Features "smart" aggregation auto-selection based on data type & stats. |
| [pro](/src/cmd/pro.rs#L2) | Interact with the [qsv pro](https://qsvpro.dathere.com) API. |
| [prompt](/src/cmd/prompt.rs#L2)✨ | Open a file dialog to either pick a file as input or save output to a file. |
| [pseudo](/src/cmd/pseudo.rs#L2)<br>🔣👆 | [Pseudonymise](https://en.wikipedia.org/wiki/Pseudonymization) the values of the given column by replacing them with incremental identifiers. |
189 changes: 169 additions & 20 deletions src/cmd/pivotp.rs
@@ -5,7 +5,9 @@ The pivot operation consists of:
- One or more index columns (these will be the new rows)
- A column that will be pivoted (this will create the new columns)
- A values column that will be aggregated
-- An aggregation function to apply
+- An aggregation function to apply. Features "smart" aggregation auto-selection.
+
+For examples, see https://github.com/dathere/qsv/blob/master/tests/test_pivotp.rs.
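
A concrete sketch of these components (hypothetical data; the --index and --values flags are part of pivotp's full usage text, which this diff truncates):

    Given sales.csv:
        date,product,amount
        2023-01-01,A,100
        2023-01-01,B,150
        2023-01-02,A,300
        2023-01-02,B,600

    qsv pivotp product --index date --values amount sales.csv

    yields (the default smart aggregation picks Sum for the numeric amount column):
        date,A,B
        2023-01-01,100,150
        2023-01-02,300,600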

Usage:
qsv pivotp [options] <on-cols> <input>
@@ -36,15 +38,21 @@ pivotp options:
median - Median value
count - Count of values
last - Last value encountered
-[default: count]
+none - No aggregation is done. Raises an error if a group contains multiple values.
+smart - Use the value column's statistics to pick an aggregation.
+        Only works when there is one value column; otherwise it
+        falls back to `first`.
+smartq - Same as smart, but suppresses the informational messages.
+[default: smart]
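
For instance (hypothetical column name), --agg smart on a low-cardinality Date value column prints a note like this to stderr before pivoting, while --agg smartq makes the same choice silently:

    Info: "sale_date" is a Date column, using Count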
--sort-columns Sort the transposed columns by name. Default is by order of discovery.
--col-separator <arg> The separator in generated column names in case of multiple --values columns.
[default: _]
--validate Validate a pivot by checking the pivot column(s)' cardinality.
--try-parsedates When set, will attempt to parse columns as dates.
--infer-len <arg> Number of rows to scan when inferring schema.
Set to 0 to scan entire file. [default: 10000]
---decimal-comma Use comma as decimal separator.
+--decimal-comma Use comma as decimal separator when READING the input.
+                Note that you will need to specify an alternate --delimiter.
--ignore-errors Skip rows that can't be parsed.

Common options:
@@ -54,20 +62,24 @@ Common options:
Must be a single character. (default: ,)
"#;

-use std::{fs::File, io, io::Write, path::Path};
+use std::{fs::File, io, io::Write, path::Path, sync::OnceLock};

use csv::ByteRecord;
use indicatif::HumanCount;
use polars::prelude::*;
use polars_ops::pivot::{pivot_stable, PivotAgg};
use serde::Deserialize;

use crate::{
-config::Delimiter,
+cmd::stats::StatsData,
+config::{Config, Delimiter},
util,
+util::{get_stats_records, StatsMode},
CliResult,
};

+static STATS_RECORDS: OnceLock<(ByteRecord, Vec<StatsData>)> = OnceLock::new();
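
A note on the line above: the OnceLock memoizes the (headers, stats) pair, so whichever of calculate_pivot_metadata() and suggest_agg_function() runs first pays the cost of get_stats_records(), and the other simply reuses the cached result.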

#[derive(Deserialize)]
struct Args {
arg_on_cols: String,
Expand Down Expand Up @@ -115,11 +127,10 @@ fn calculate_pivot_metadata(
flag_memcheck: false,
};

-let Ok((csv_fields, csv_stats)) =
+let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| {
get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
-else {
-return Ok(None);
-};
+.unwrap_or_else(|_| (ByteRecord::new(), Vec::new()))
+});

if csv_stats.is_empty() {
return Ok(None);
@@ -183,6 +194,113 @@ fn validate_pivot_operation(metadata: &PivotMetadata) -> CliResult<()> {
Ok(())
}

/// Suggest an appropriate aggregation function based on column statistics
#[allow(clippy::cast_precision_loss)]
fn suggest_agg_function(
args: &Args,
value_cols: &[String],
quiet: bool,
) -> CliResult<Option<PivotAgg>> {
let schema_args = util::SchemaArgs {
flag_enum_threshold: 0,
flag_ignore_case: false,
flag_strict_dates: false,
flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(),
flag_dates_whitelist: String::new(),
flag_prefer_dmy: false,
flag_force: false,
flag_stdout: false,
flag_jobs: None,
flag_no_headers: false,
flag_delimiter: args.flag_delimiter,
arg_input: Some(args.arg_input.clone()),
flag_memcheck: false,
};

let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| {
get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
.unwrap_or_else(|_| (ByteRecord::new(), Vec::new()))
});

// If multiple value columns, default to First
if value_cols.len() > 1 {
return Ok(Some(PivotAgg::First));
}

// Get stats for the value column
let value_col = &value_cols[0];
let field_pos = csv_fields
.iter()
.position(|f| std::str::from_utf8(f).unwrap_or("") == value_col);

if let Some(pos) = field_pos {
let stats = &csv_stats[pos];
let rconfig = Config::new(Some(&args.arg_input));
let row_count = util::count_rows(&rconfig)? as u64;

// Suggest aggregation based on field type and statistics
let suggested_agg = match stats.r#type.as_str() {
"NULL" => {
if !quiet {
eprintln!("Info: \"{value_col}\" contains only NULL values");
}
PivotAgg::Count
},
"Integer" | "Float" => {
if stats.nullcount as f64 / row_count as f64 > 0.5 {
if !quiet {
eprintln!("Info: \"{value_col}\" contains >50% NULL values, using Count");
}
PivotAgg::Count
} else {
PivotAgg::Sum
}
},
"Date" | "DateTime" => {
if stats.cardinality as f64 / row_count as f64 > 0.9 {
if !quiet {
eprintln!(
"Info: {} column \"{value_col}\" has high cardinality, using First",
stats.r#type
);
}
PivotAgg::First
} else {
if !quiet {
eprintln!(
"Info: \"{value_col}\" is a {} column, using Count",
stats.r#type
);
}
PivotAgg::Count
}
},
_ => {
if stats.cardinality == row_count {
if !quiet {
eprintln!("Info: \"{value_col}\" contains all unique values, using First");
}
PivotAgg::First
} else if stats.cardinality as f64 / row_count as f64 > 0.5 {
if !quiet {
eprintln!("Info: \"{value_col}\" has high cardinality, using Count");
}
PivotAgg::Count
} else {
if !quiet {
eprintln!("Info: \"{value_col}\" is a String column, using Count");
}
PivotAgg::Count
}
},
};

Ok(Some(suggested_agg))
} else {
Ok(None)
}
}
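
To summarize the heuristics encoded in the match above:

| Value column type | Condition | Chosen aggregation |
|---|---|---|
| NULL | always | Count |
| Integer, Float | >50% NULL values | Count |
| Integer, Float | otherwise | Sum |
| Date, DateTime | cardinality/rows > 0.9 | First |
| Date, DateTime | otherwise | Count |
| String & others | all values unique | First |
| String & others | otherwise | Count |

More than one --values column short-circuits to First before the stats are consulted, and an unrecognized value column yields None, which the caller turns into the First fallback.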

pub fn run(argv: &[&str]) -> CliResult<()> {
let args: Args = util::get_args(USAGE, argv)?;

@@ -226,17 +344,42 @@ pub fn run(argv: &[&str]) -> CliResult<()> {

// Get aggregation function
let agg_fn = if let Some(ref agg) = args.flag_agg {
-Some(match agg.to_lowercase().as_str() {
-"first" => PivotAgg::First,
-"sum" => PivotAgg::Sum,
-"min" => PivotAgg::Min,
-"max" => PivotAgg::Max,
-"mean" => PivotAgg::Mean,
-"median" => PivotAgg::Median,
-"count" => PivotAgg::Count,
-"last" => PivotAgg::Last,
-_ => return fail_clierror!("Invalid pivot aggregation function: {agg}"),
-})
+let lower_agg = agg.to_lowercase();
+if lower_agg == "none" {
+None
+} else {
+Some(match lower_agg.as_str() {
+"first" => PivotAgg::First,
+"sum" => PivotAgg::Sum,
+"min" => PivotAgg::Min,
+"max" => PivotAgg::Max,
+"mean" => PivotAgg::Mean,
+"median" => PivotAgg::Median,
+"count" => PivotAgg::Count,
+"last" => PivotAgg::Last,
+"smart" | "smartq" => {
+if let Some(value_cols) = &value_cols {
+// Try to suggest an appropriate aggregation function
+if let Some(suggested_agg) =
+suggest_agg_function(&args, value_cols, lower_agg == "smartq")?
+{
+suggested_agg
+} else {
+// fallback to first, which always works
+PivotAgg::First
+}
+} else {
+// Default to Count if no value columns specified
+PivotAgg::Count
+}
+},
+_ => {
+return fail_incorrectusage_clierror!(
+"Invalid pivot aggregation function: {agg}"
+)
+},
+})
+}
} else {
None
};
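
A usage sketch of the dispatch above (hypothetical file and column names):

    qsv pivotp product --values amount --agg sum sales.csv    # explicit aggregation
    qsv pivotp product --values amount --agg none sales.csv   # errors if a group holds >1 value
    qsv pivotp product --values amount sales.csv              # default smart: stats pick the aggregation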
@@ -248,6 +391,12 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
b','
};

+if args.flag_decimal_comma && delim == b',' {
+return fail_incorrectusage_clierror!(
+"You need to specify an alternate --delimiter when using --decimal-comma."
+);
+}
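
For example (hypothetical file), pairing --decimal-comma with a semicolon delimiter keeps the comma free to act as the decimal separator:

    qsv pivotp product --values amount --decimal-comma --delimiter ';' sales_eu.csv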

// Create CSV reader config
let csv_reader = LazyCsvReader::new(&args.arg_input)
.with_has_header(true)
12 changes: 6 additions & 6 deletions tests/test_pivotp.rs
@@ -352,7 +352,7 @@ pivotp_test!(
wrk.assert_success(&mut cmd);

let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
-let expected = vec![svec!["date;A;B"], svec!["2023-01-01;1;1"]];
+let expected = vec![svec!["date;A;B"], svec!["2023-01-01;100;150"]];
assert_eq!(got, expected);
}
);
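
(This and the expectation changes below reflect the new default: smart picks Sum for numeric value columns, so the expected cells now hold summed amounts instead of the tallies the old count default produced.)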
@@ -549,7 +549,7 @@ pivotp_test!(
wrk.assert_success(&mut cmd);

let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
-let expected = vec![svec!["date;A;B"], svec!["2023-01-01;1;1"]];
+let expected = vec![svec!["date;A;B"], svec!["2023-01-01;100.5;150.75"]];
assert_eq!(got, expected);
}
);
@@ -577,8 +577,8 @@ pivotp_test!(
let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
let expected = vec![
svec!["date", "A", "B"],
svec!["2023-01-01", "2", "1"],
svec!["2023-01-02", "1", "2"],
svec!["2023-01-01", "300", "150"],
svec!["2023-01-02", "300", "600"],
];
assert_eq!(got, expected);
}
@@ -604,8 +604,8 @@ pivotp_test!(
let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
let expected = vec![
svec!["date", "A", "B"],
svec!["2023-01-01", "2", "1"],
svec!["2023-01-02", "1", "2"],
svec!["2023-01-01", "300", "150"],
svec!["2023-01-02", "300", "600"],
];
assert_eq!(got, expected);
}