diff --git a/Cargo.lock b/Cargo.lock index 00ff4c2..39e0f87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -394,6 +394,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "1.3.2" @@ -1831,8 +1837,7 @@ checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" [[package]] name = "postgres" version = "0.19.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7915b33ed60abc46040cbcaa25ffa1c7ec240668e0477c4f3070786f5916d451" +source = "git+https://github.com/splitgraph/rust-postgres?rev=88c2c7714a4558aed6a63e2e2b140a8359568858#88c2c7714a4558aed6a63e2e2b140a8359568858" dependencies = [ "bytes", "fallible-iterator", @@ -1845,10 +1850,8 @@ dependencies = [ [[package]] name = "postgres-native-tls" version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d442770e2b1e244bb5eb03b31c79b65bb2568f413b899eaba850fa945a65954" +source = "git+https://github.com/splitgraph/rust-postgres?rev=88c2c7714a4558aed6a63e2e2b140a8359568858#88c2c7714a4558aed6a63e2e2b140a8359568858" dependencies = [ - "futures", "native-tls", "tokio", "tokio-native-tls", @@ -1858,10 +1861,9 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49b6c5ef183cd3ab4ba005f1ca64c21e8bd97ce4699cfea9e8d9a2c4958ca520" +source = "git+https://github.com/splitgraph/rust-postgres?rev=88c2c7714a4558aed6a63e2e2b140a8359568858#88c2c7714a4558aed6a63e2e2b140a8359568858" dependencies = [ - "base64 0.21.7", + "base64 0.22.1", "byteorder", "bytes", "fallible-iterator", @@ -1876,8 +1878,7 @@ dependencies = [ [[package]] name = "postgres-types" version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d2234cdee9408b523530a9b6d2d6b373d1db34f6a8e51dc03ded1828d7fb67c" +source = "git+https://github.com/splitgraph/rust-postgres?rev=88c2c7714a4558aed6a63e2e2b140a8359568858#88c2c7714a4558aed6a63e2e2b140a8359568858" dependencies = [ "bytes", "fallible-iterator", @@ -2543,8 +2544,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d340244b32d920260ae7448cb72b6e238bddc3d4f7603394e7dd46ed8e48f5b8" +source = "git+https://github.com/splitgraph/rust-postgres?rev=88c2c7714a4558aed6a63e2e2b140a8359568858#88c2c7714a4558aed6a63e2e2b140a8359568858" dependencies = [ "async-trait", "byteorder", diff --git a/Cargo.toml b/Cargo.toml index a5c64ae..3489365 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,12 +19,12 @@ log = "0.4" native-tls = "0.2.11" object_store = { version = "0.9", features = ["aws"] } parquet = "50.0.0" -postgres = "0.19.7" -postgres-native-tls = "0.5.0" +postgres = { version = "0.19.7", git = "https://github.com/splitgraph/rust-postgres", rev = "88c2c7714a4558aed6a63e2e2b140a8359568858" } +postgres-native-tls = { version = "0.5.0", git = "https://github.com/splitgraph/rust-postgres", rev = "88c2c7714a4558aed6a63e2e2b140a8359568858" } tempfile = "3" thiserror = "1" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "signal", "process"] } -tokio-postgres = "0.7.10" +tokio-postgres = { version = "0.7.10", git = "https://github.com/splitgraph/rust-postgres", rev = "88c2c7714a4558aed6a63e2e2b140a8359568858" } url = "2.5.0" uuid = "1.2.1" diff --git a/src/lib.rs b/src/lib.rs index 77c1557..2369c64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,6 +40,7 @@ use tokio::sync::Semaphore; use uuid::Uuid; pub mod pg_arrow_source; +mod pg_numeric; use pg_arrow_source::{ArrowBuilder, PgArrowSource}; #[derive(Debug, Parser)] diff --git a/src/pg_arrow_source.rs b/src/pg_arrow_source.rs index 90c6b6b..c4b7322 100644 --- a/src/pg_arrow_source.rs +++ b/src/pg_arrow_source.rs @@ -24,9 +24,16 @@ pub enum ArrowBuilder { Float64Builder(array::Float64Builder), TimestampMicrosecondBuilder(array::TimestampMicrosecondBuilder), DateBuilder(array::Date32Builder), + DecimalBuilder { + builder: array::Decimal128Builder, + scale: i16, + }, StringBuilder(array::StringBuilder), BinaryBuilder(array::BinaryBuilder), } +use crate::pg_numeric::{ + numeric_typmod_precision, numeric_typmod_scale, pg_numeric_to_arrow_decimal, +}; use crate::{ArrowBuilder::*, DataLoadingError}; // tokio-postgres provides awkward Rust type conversions for Postgres TIMESTAMP and TIMESTAMPTZ values @@ -73,9 +80,30 @@ impl From for i64 { } } +struct RawPgBinary(Vec); +impl FromSql<'_> for RawPgBinary { + fn from_sql(_ty: &Type, buf: &[u8]) -> Result> { + Ok(RawPgBinary(buf.to_vec())) + } + + fn accepts(_ty: &Type) -> bool { + true + } +} +impl From for Vec { + fn from(val: RawPgBinary) -> Self { + val.0 + } +} + +pub struct PgTypeInfo { + pg_type: Type, + type_modifier: i32, +} + impl ArrowBuilder { - pub fn from_pg_type(pg_type: &Type) -> Self { - match *pg_type { + pub fn from_pg_type(type_info: &PgTypeInfo) -> Self { + match type_info.pg_type { Type::BOOL => BooleanBuilder(array::BooleanBuilder::new()), Type::CHAR => Int8Builder(array::Int8Builder::new()), Type::INT2 => Int16Builder(array::Int16Builder::new()), @@ -93,9 +121,23 @@ impl ArrowBuilder { )), ), Type::DATE => DateBuilder(array::Date32Builder::new()), + Type::NUMERIC => { + let precision: u8 = numeric_typmod_precision(type_info.type_modifier) + .try_into() + .expect("Unsupported precision"); + let scale: i8 = numeric_typmod_scale(type_info.type_modifier) + .try_into() + .expect("Unsupported scale"); + DecimalBuilder { + builder: array::Decimal128Builder::new() + .with_precision_and_scale(precision, scale) + .expect("Could not create Decimal128Builder"), + scale: scale.into(), + } + } Type::TEXT => StringBuilder(array::StringBuilder::new()), Type::BYTEA => BinaryBuilder(array::BinaryBuilder::new()), - _ => panic!("Unsupported type: {}", pg_type), + _ => panic!("Unsupported type: {}", type_info.pg_type), } } // Append a value from a tokio-postgres row to the ArrowBuilder @@ -130,6 +172,16 @@ impl ArrowBuilder { row.get::>(column_idx) .map(UnixEpochDayOffset::into), ), + DecimalBuilder { + ref mut builder, + scale, + } => { + let maybe_raw_binary = row.get::>(column_idx); + builder.append_option(maybe_raw_binary.map(|raw_binary| { + let buf: Vec = raw_binary.into(); + pg_numeric_to_arrow_decimal(&buf, *scale) + })) + } StringBuilder(ref mut builder) => { builder.append_option(row.get::>(column_idx)) } @@ -149,14 +201,15 @@ impl ArrowBuilder { Float64Builder(builder) => Arc::new(builder.finish()), TimestampMicrosecondBuilder(builder) => Arc::new(builder.finish()), DateBuilder(builder) => Arc::new(builder.finish()), + DecimalBuilder { builder, scale: _ } => Arc::new(builder.finish()), StringBuilder(builder) => Arc::new(builder.finish()), BinaryBuilder(builder) => Arc::new(builder.finish()), } } } -fn pg_type_to_arrow_type(pg_type: &Type) -> DataType { - match *pg_type { +fn pg_type_to_arrow_type(type_info: &PgTypeInfo) -> DataType { + match type_info.pg_type { Type::BOOL => DataType::Boolean, Type::CHAR => DataType::Int8, Type::INT2 => DataType::Int16, @@ -167,16 +220,20 @@ fn pg_type_to_arrow_type(pg_type: &Type) -> DataType { Type::TIMESTAMP => DataType::Timestamp(TimeUnit::Microsecond, None), Type::TIMESTAMPTZ => DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), Type::DATE => DataType::Date32, + Type::NUMERIC => DataType::Decimal128( + numeric_typmod_precision(type_info.type_modifier).try_into().expect("Unsupported precision"), + numeric_typmod_scale(type_info.type_modifier).try_into().expect("Unsupported scale"), + ), Type::TEXT => DataType::Utf8, Type::BYTEA => DataType::Binary, - _ => panic!("Unsupported type: {}. Explicitly cast the relevant columns to text in order to store them as strings.", pg_type), + _ => panic!("Unsupported type: {}. Explicitly cast the relevant columns to text in order to store them as strings.", type_info.pg_type), } } pub struct PgArrowSource { batch_size: usize, pg_row_stream: Pin>, - pg_types: Vec, + pg_types: Vec, arrow_schema: Arc, } @@ -211,7 +268,10 @@ impl PgArrowSource { let (pg_types, arrow_fields): (Vec<_>, Vec<_>) = postgres_columns .iter() .map(|c| { - let pg_type = c.type_().clone(); + let pg_type = PgTypeInfo { + pg_type: c.type_().clone(), + type_modifier: c.type_modifier(), + }; let arrow_type = pg_type_to_arrow_type(&pg_type); (pg_type, Field::new(c.name(), arrow_type, true)) }) diff --git a/src/pg_numeric.rs b/src/pg_numeric.rs new file mode 100644 index 0000000..c111f42 --- /dev/null +++ b/src/pg_numeric.rs @@ -0,0 +1,67 @@ +const VARHDRSZ: i32 = 4; +const NUMERIC_POS: u16 = 0x0000; +const NUMERIC_NEG: u16 = 0x4000; +const NUMERIC_NAN: u16 = 0xC000; +const NUMERIC_PINF: u16 = 0xD000; +const NUMERIC_NINF: u16 = 0xF000; + +// Follows the Postgres implementation in src/backend/utils/adt/numeric.c +pub fn numeric_typmod_precision(typmod: i32) -> u16 { + (((typmod - VARHDRSZ) >> 16) & 0xffff) as u16 +} + +// Follows the Postgres implementation in src/backend/utils/adt/numeric.c +pub fn numeric_typmod_scale(typmod: i32) -> i16 { + ((((typmod - VARHDRSZ) & 0x7ff) ^ 1024) - 1024) as i16 +} + +pub fn pg_numeric_to_arrow_decimal(buf: &[u8], result_scale: i16) -> i128 { + assert!(buf.len() >= 8, "Numeric buffer not long enough"); + // Bytes 0 and 1 encode ndigits, the number of base-10000 digits + let ndigits = u16::from_be_bytes(buf[0..2].try_into().unwrap()); + // Bytes 2 and 3 encode weight, the base-10000 weight of first digit + let weight = i16::from_be_bytes(buf[2..4].try_into().unwrap()); + // Bytes 4 and 5 encode the sign + let sign_word = u16::from_be_bytes(buf[4..6].try_into().unwrap()); + let sign_multiplier: i128 = match sign_word { + NUMERIC_POS => 1, + NUMERIC_NEG => -1, + NUMERIC_NAN => panic!("Cannot convert numeric NaN"), + NUMERIC_PINF => panic!("Cannot convert numeric +Inf"), + NUMERIC_NINF => panic!("Cannot convert numeric 'Inf"), + _ => panic!("Unexpected numeric sign: {}", sign_word), + }; + // Bytes 6 and 7 encode dscale. We ignore them + // The remaining bytes contain the digits. Every two bytes encode + // a base-10000 digit + let digits_bytes = &buf[8..]; + assert!( + digits_bytes.len() >= (2 * ndigits) as usize, + "Not enough digits in numeric buffer" + ); + let mut abs_result: i128 = 0; + for i in 0..ndigits { + let digit_bytes: [u8; 2] = digits_bytes[(2 * i as usize)..(2 * i as usize + 2)] + .try_into() + .unwrap(); + // The value of the current base-10000 digit + let digit = u16::from_be_bytes(digit_bytes); + // The base-10 weight of the current base-10000 digit in abs_result + let digit_multiplier_dweight: i16 = 4 * (weight - i as i16) + result_scale; + if digit_multiplier_dweight <= -4 { + // The weight of this base-10000 digit is too small to contribute to abs_result + } else if digit_multiplier_dweight == -3 { + abs_result += (digit / 1000) as i128; + } else if digit_multiplier_dweight == -2 { + abs_result += (digit / 100) as i128; + } else if digit_multiplier_dweight == -1 { + abs_result += (digit / 10) as i128; + } else { + // digit_multiplier_dweight > 0 + let digit_multplier: i128 = 10_i128.pow(digit_multiplier_dweight as u32); + abs_result += digit_multplier * (digit as i128); + } + } + + abs_result * sign_multiplier +} diff --git a/tests/basic_integration.rs b/tests/basic_integration.rs index 59688cf..8099a12 100644 --- a/tests/basic_integration.rs +++ b/tests/basic_integration.rs @@ -1,7 +1,8 @@ use arrow::array::{ - Array, BinaryArray, BooleanArray, Date32Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, + Array, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, }; +use arrow::datatypes::DataType; use clap::Parser; use futures::{StreamExt, TryStreamExt}; use lakehouse_loader::pg_arrow_source::PgArrowSource; @@ -18,7 +19,7 @@ async fn test_pg_to_delta_e2e() { "pg-to-delta", "postgres://test-user:test-password@localhost:5432/test-db", "-q", - "select * from t1", + "select * from t1 order by id", target_url, ]); // THEN the command runs successfully @@ -58,7 +59,7 @@ async fn test_pg_arrow_source() { // WHEN 25001 rows are split into batches of 10000 let record_batches: Vec<_> = PgArrowSource::new( "postgres://test-user:test-password@localhost:5432/test-db", - "select * from t1", + "select * from t1 order by id", 10000, ) .await @@ -207,9 +208,39 @@ async fn test_pg_arrow_source() { assert!(!cdate_array.is_null(2)); assert_eq!(cdate_array.value(2), elapsed_days as i32 + 2); + // THEN the numeric field data type should be as expected + assert_eq!( + *rb1.schema().field(11).data_type(), + DataType::Decimal128(8, 3) + ); + + // THEN the first few numeric values should be as expected + let cnumeric_array = rb1 + .column(11) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(cnumeric_array.is_null(0)); + assert!(!cnumeric_array.is_null(1)); + assert_eq!(cnumeric_array.value(1), 0_i128); + assert!(!cnumeric_array.is_null(2)); + assert_eq!(cnumeric_array.value(2), 1_i128); + assert!(!cnumeric_array.is_null(3)); + assert_eq!(cnumeric_array.value(3), -2_i128); + assert!(!cnumeric_array.is_null(4)); + assert_eq!(cnumeric_array.value(4), 3000_i128); + assert!(!cnumeric_array.is_null(5)); + assert_eq!(cnumeric_array.value(5), -4000_i128); + assert!(!cnumeric_array.is_null(6)); + assert_eq!(cnumeric_array.value(6), 50001_i128); + assert!(!cnumeric_array.is_null(7)); + assert_eq!(cnumeric_array.value(7), 99999999_i128); + assert!(!cnumeric_array.is_null(8)); + assert_eq!(cnumeric_array.value(8), -99999999_i128); + // THEN the first 3 text values should be as expected let ctext_array = rb1 - .column(11) + .column(12) .as_any() .downcast_ref::() .unwrap(); @@ -221,7 +252,7 @@ async fn test_pg_arrow_source() { // THEN the first 3 bytea values should be as expected let cbytea_array = rb1 - .column(12) + .column(13) .as_any() .downcast_ref::() .unwrap(); diff --git a/tests/postgres-init-scripts/init-pg-data.sql b/tests/postgres-init-scripts/init-pg-data.sql index af69a2e..b53c2b4 100755 --- a/tests/postgres-init-scripts/init-pg-data.sql +++ b/tests/postgres-init-scripts/init-pg-data.sql @@ -10,6 +10,7 @@ CREATE TABLE t1( ctimestamp TIMESTAMP, ctimestamptz TIMESTAMPTZ, cdate DATE, + cnumeric NUMERIC(8, 3), ctext TEXT, cbytea BYTEA ); @@ -27,6 +28,7 @@ INSERT INTO t1( ctimestamp, ctimestamptz, cdate, + cnumeric, ctext, cbytea ) SELECT @@ -40,6 +42,17 @@ INSERT INTO t1( '2024-01-01'::TIMESTAMP + s * INTERVAL '1 second', '2024-01-01 00:00:00+00'::TIMESTAMPTZ + s * INTERVAL '1 second', '2024-01-01'::DATE + s, + s::NUMERIC / 1000, s::TEXT, int4send(s::INT) FROM generate_series(1, 25000) AS s; + +-- Set various cnumeric values +UPDATE t1 SET cnumeric = 0 WHERE id = 2; +UPDATE t1 SET cnumeric = 0.001 WHERE id = 3; +UPDATE t1 SET cnumeric = -0.002 WHERE id = 4; +UPDATE t1 SET cnumeric = 3 WHERE id = 5; +UPDATE t1 SET cnumeric = -4 WHERE id = 6; +UPDATE t1 SET cnumeric = 50.001 WHERE id = 7; +UPDATE t1 SET cnumeric = 99999.999 WHERE id = 8; +UPDATE t1 SET cnumeric = -99999.999 WHERE id = 9;