Numeric support
SergeiPatiakin committed May 31, 2024
1 parent 5fe1c74 commit 467dd48
Showing 5 changed files with 179 additions and 14 deletions.
1 change: 1 addition & 0 deletions src/lib.rs
@@ -40,6 +40,7 @@ use tokio::sync::Semaphore;
use uuid::Uuid;

pub mod pg_arrow_source;
mod pg_numeric;
use pg_arrow_source::{ArrowBuilder, PgArrowSource};

#[derive(Debug, Parser)]
76 changes: 68 additions & 8 deletions src/pg_arrow_source.rs
@@ -24,9 +24,16 @@ pub enum ArrowBuilder {
Float64Builder(array::Float64Builder),
TimestampMicrosecondBuilder(array::TimestampMicrosecondBuilder),
DateBuilder(array::Date32Builder),
DecimalBuilder {
builder: array::Decimal128Builder,
scale: i16,
},
StringBuilder(array::StringBuilder),
BinaryBuilder(array::BinaryBuilder),
}
use crate::pg_numeric::{
numeric_typmod_precision, numeric_typmod_scale, pg_numeric_to_arrow_decimal,
};
use crate::{ArrowBuilder::*, DataLoadingError};

// tokio-postgres provides awkward Rust type conversions for Postgres TIMESTAMP and TIMESTAMPTZ values
@@ -73,9 +80,30 @@ impl From<UnixEpochMicrosecondOffset> for i64 {
}
}

struct RawPgBinary(Vec<u8>);
impl FromSql<'_> for RawPgBinary {
fn from_sql(_ty: &Type, buf: &[u8]) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
Ok(RawPgBinary(buf.to_vec()))
}

fn accepts(_ty: &Type) -> bool {
true
}
}
impl From<RawPgBinary> for Vec<u8> {
fn from(val: RawPgBinary) -> Self {
val.0
}
}

pub struct PgTypeInfo {
pg_type: Type,
type_modifier: i32,
}

impl ArrowBuilder {
pub fn from_pg_type(pg_type: &Type) -> Self {
match *pg_type {
pub fn from_pg_type(type_info: &PgTypeInfo) -> Self {
match type_info.pg_type {
Type::BOOL => BooleanBuilder(array::BooleanBuilder::new()),
Type::CHAR => Int8Builder(array::Int8Builder::new()),
Type::INT2 => Int16Builder(array::Int16Builder::new()),
@@ -93,9 +121,23 @@ impl ArrowBuilder {
)),
),
Type::DATE => DateBuilder(array::Date32Builder::new()),
Type::NUMERIC => {
let precision: u8 = numeric_typmod_precision(type_info.type_modifier)
.try_into()
.expect("Unsupported precision");
let scale: i8 = numeric_typmod_scale(type_info.type_modifier)
.try_into()
.expect("Unsupported scale");
DecimalBuilder {
builder: array::Decimal128Builder::new()
.with_precision_and_scale(precision, scale)
.expect("Could not create Decimal128Builder"),
scale: scale.into(),
}
}
Type::TEXT => StringBuilder(array::StringBuilder::new()),
Type::BYTEA => BinaryBuilder(array::BinaryBuilder::new()),
_ => panic!("Unsupported type: {}", pg_type),
_ => panic!("Unsupported type: {}", type_info.pg_type),
}
}
// Append a value from a tokio-postgres row to the ArrowBuilder
@@ -130,6 +172,16 @@ impl ArrowBuilder {
row.get::<usize, Option<UnixEpochDayOffset>>(column_idx)
.map(UnixEpochDayOffset::into),
),
DecimalBuilder {
ref mut builder,
scale,
} => {
let maybe_raw_binary = row.get::<usize, Option<RawPgBinary>>(column_idx);
builder.append_option(maybe_raw_binary.map(|raw_binary| {
let buf: Vec<u8> = raw_binary.into();
pg_numeric_to_arrow_decimal(&buf, *scale)
}))
}
StringBuilder(ref mut builder) => {
builder.append_option(row.get::<usize, Option<&str>>(column_idx))
}
@@ -149,14 +201,15 @@ impl ArrowBuilder {
Float64Builder(builder) => Arc::new(builder.finish()),
TimestampMicrosecondBuilder(builder) => Arc::new(builder.finish()),
DateBuilder(builder) => Arc::new(builder.finish()),
DecimalBuilder { builder, scale: _ } => Arc::new(builder.finish()),
StringBuilder(builder) => Arc::new(builder.finish()),
BinaryBuilder(builder) => Arc::new(builder.finish()),
}
}
}

fn pg_type_to_arrow_type(pg_type: &Type) -> DataType {
match *pg_type {
fn pg_type_to_arrow_type(type_info: &PgTypeInfo) -> DataType {
match type_info.pg_type {
Type::BOOL => DataType::Boolean,
Type::CHAR => DataType::Int8,
Type::INT2 => DataType::Int16,
@@ -167,16 +220,20 @@ fn pg_type_to_arrow_type(pg_type: &Type) -> DataType {
Type::TIMESTAMP => DataType::Timestamp(TimeUnit::Microsecond, None),
Type::TIMESTAMPTZ => DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
Type::DATE => DataType::Date32,
Type::NUMERIC => DataType::Decimal128(
numeric_typmod_precision(type_info.type_modifier).try_into().expect("Unsupported precision"),
numeric_typmod_scale(type_info.type_modifier).try_into().expect("Unsupported scale"),
),
Type::TEXT => DataType::Utf8,
Type::BYTEA => DataType::Binary,
_ => panic!("Unsupported type: {}. Explicitly cast the relevant columns to text in order to store them as strings.", pg_type),
_ => panic!("Unsupported type: {}. Explicitly cast the relevant columns to text in order to store them as strings.", type_info.pg_type),
}
}

pub struct PgArrowSource {
batch_size: usize,
pg_row_stream: Pin<Box<RowStream>>,
pg_types: Vec<Type>,
pg_types: Vec<PgTypeInfo>,
arrow_schema: Arc<Schema>,
}

@@ -211,7 +268,10 @@ impl PgArrowSource {
let (pg_types, arrow_fields): (Vec<_>, Vec<_>) = postgres_columns
.iter()
.map(|c| {
let pg_type = c.type_().clone();
let pg_type = PgTypeInfo {
pg_type: c.type_().clone(),
type_modifier: c.type_modifier(),
};
let arrow_type = pg_type_to_arrow_type(&pg_type);
(pg_type, Field::new(c.name(), arrow_type, true))
})
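For context, here is a minimal standalone sketch (not part of the commit) of what this mapping produces for the NUMERIC(8, 3) column used in the tests: Postgres encodes precision and scale in the column's type modifier, and the helpers in src/pg_numeric.rs below recover them to build an Arrow Decimal128(8, 3) type. The typmod constant and the main function here are illustrative only.

use arrow::datatypes::DataType;

fn main() {
    // Postgres encodes NUMERIC(p, s) as typmod = ((p << 16) | s) + VARHDRSZ (4).
    let typmod: i32 = ((8 << 16) | 3) + 4; // NUMERIC(8, 3) => 524295
    // Same arithmetic as numeric_typmod_precision / numeric_typmod_scale below.
    let precision = (((typmod - 4) >> 16) & 0xffff) as u8; // 8
    let scale = ((((typmod - 4) & 0x7ff) ^ 1024) - 1024) as i8; // 3
    // pg_type_to_arrow_type maps the column to this Arrow type.
    assert_eq!(DataType::Decimal128(precision, scale), DataType::Decimal128(8, 3));
    println!("NUMERIC(8, 3) -> Decimal128({precision}, {scale})");
}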
67 changes: 67 additions & 0 deletions src/pg_numeric.rs
@@ -0,0 +1,67 @@
const VARHDRSZ: i32 = 4;
const NUMERIC_POS: u16 = 0x0000;
const NUMERIC_NEG: u16 = 0x4000;
const NUMERIC_NAN: u16 = 0xC000;
const NUMERIC_PINF: u16 = 0xD000;
const NUMERIC_NINF: u16 = 0xF000;

// Follows the Postgres implementation in src/backend/utils/adt/numeric.c
pub fn numeric_typmod_precision(typmod: i32) -> u16 {
(((typmod - VARHDRSZ) >> 16) & 0xffff) as u16
}

// Follows the Postgres implementation in src/backend/utils/adt/numeric.c
pub fn numeric_typmod_scale(typmod: i32) -> i16 {
((((typmod - VARHDRSZ) & 0x7ff) ^ 1024) - 1024) as i16
}

pub fn pg_numeric_to_arrow_decimal(buf: &[u8], result_scale: i16) -> i128 {
assert!(buf.len() >= 8, "Numeric buffer not long enough");
// Bytes 0 and 1 encode ndigits, the number of base-10000 digits
let ndigits = u16::from_be_bytes(buf[0..2].try_into().unwrap());
// Bytes 2 and 3 encode weight, the base-10000 weight of first digit
let weight = i16::from_be_bytes(buf[2..4].try_into().unwrap());
// Bytes 4 and 5 encode the sign
let sign_word = u16::from_be_bytes(buf[4..6].try_into().unwrap());
let sign_multiplier: i128 = match sign_word {
NUMERIC_POS => 1,
NUMERIC_NEG => -1,
NUMERIC_NAN => panic!("Cannot convert numeric NaN"),
NUMERIC_PINF => panic!("Cannot convert numeric +Inf"),
NUMERIC_NINF => panic!("Cannot convert numeric -Inf"),
_ => panic!("Unexpected numeric sign: {}", sign_word),
};
// Bytes 6 and 7 encode dscale. We ignore them
// The remaining bytes contain the digits. Every two bytes encode
// a base-10000 digit
let digits_bytes = &buf[8..];
assert!(
digits_bytes.len() >= (2 * ndigits) as usize,
"Not enough digits in numeric buffer"
);
let mut abs_result: i128 = 0;
for i in 0..ndigits {
let digit_bytes: [u8; 2] = digits_bytes[(2 * i as usize)..(2 * i as usize + 2)]
.try_into()
.unwrap();
// The value of the current base-10000 digit
let digit = u16::from_be_bytes(digit_bytes);
// The base-10 weight of the current base-10000 digit in abs_result
let digit_multiplier_dweight: i16 = 4 * (weight - i as i16) + result_scale;
if digit_multiplier_dweight <= -4 {
// The weight of this base-10000 digit is too small to contribute to abs_result
} else if digit_multiplier_dweight == -3 {
abs_result += (digit / 1000) as i128;
} else if digit_multiplier_dweight == -2 {
abs_result += (digit / 100) as i128;
} else if digit_multiplier_dweight == -1 {
abs_result += (digit / 10) as i128;
} else {
// digit_multiplier_dweight >= 0
let digit_multiplier: i128 = 10_i128.pow(digit_multiplier_dweight as u32);
abs_result += digit_multiplier * (digit as i128);
}
}

abs_result * sign_multiplier
}
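As a worked example (a hedged sketch, not part of the commit), the value 50.001 arrives over the binary protocol as two base-10000 digits, 50 at weight 0 and 10 at weight -1, with a positive sign and dscale 3; at result scale 3 the function returns the unscaled integer 50001, matching the integration test below. The buffer construction and test name are illustrative, assuming the test sits next to the function in src/pg_numeric.rs.

#[test]
fn numeric_50_001_at_scale_3() {
    // ndigits=2, weight=0, sign=NUMERIC_POS, dscale=3, digits=[50, 10]
    let buf: Vec<u8> = [
        2u16.to_be_bytes(),      // ndigits
        0u16.to_be_bytes(),      // weight of the first digit
        0x0000u16.to_be_bytes(), // sign (NUMERIC_POS)
        3u16.to_be_bytes(),      // dscale (ignored by the conversion)
        50u16.to_be_bytes(),     // 50 * 10000^0
        10u16.to_be_bytes(),     // 10 * 10000^-1 = 0.001
    ]
    .concat();
    assert_eq!(pg_numeric_to_arrow_decimal(&buf, 3), 50_001_i128);
}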
36 changes: 30 additions & 6 deletions tests/basic_integration.rs
@@ -1,6 +1,6 @@
use arrow::array::{
Array, BinaryArray, BooleanArray, Date32Array, Float32Array, Float64Array, Int16Array,
Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray,
Array, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray,
};
use clap::Parser;
use futures::{StreamExt, TryStreamExt};
@@ -18,7 +18,7 @@ async fn test_pg_to_delta_e2e() {
"pg-to-delta",
"postgres://test-user:test-password@localhost:5432/test-db",
"-q",
"select * from t1",
"select * from t1 order by id",
target_url,
]);
// THEN the command runs successfully
@@ -58,7 +58,7 @@ async fn test_pg_arrow_source() {
// WHEN 25001 rows are split into batches of 10000
let record_batches: Vec<_> = PgArrowSource::new(
"postgres://test-user:test-password@localhost:5432/test-db",
"select * from t1",
"select * from t1 order by id",
10000,
)
.await
@@ -207,9 +207,33 @@ async fn test_pg_arrow_source() {
assert!(!cdate_array.is_null(2));
assert_eq!(cdate_array.value(2), elapsed_days as i32 + 2);

// THEN the first few numeric values should be as expected
let cnumeric_array = rb1
.column(11)
.as_any()
.downcast_ref::<Decimal128Array>()
.unwrap();
assert!(cnumeric_array.is_null(0));
assert!(!cnumeric_array.is_null(1));
assert_eq!(cnumeric_array.value(1), 0_i128);
assert!(!cnumeric_array.is_null(2));
assert_eq!(cnumeric_array.value(2), 1_i128);
assert!(!cnumeric_array.is_null(3));
assert_eq!(cnumeric_array.value(3), -2_i128);
assert!(!cnumeric_array.is_null(4));
assert_eq!(cnumeric_array.value(4), 3000_i128);
assert!(!cnumeric_array.is_null(5));
assert_eq!(cnumeric_array.value(5), -4000_i128);
assert!(!cnumeric_array.is_null(6));
assert_eq!(cnumeric_array.value(6), 50001_i128);
assert!(!cnumeric_array.is_null(7));
assert_eq!(cnumeric_array.value(7), 99999999_i128);
assert!(!cnumeric_array.is_null(8));
assert_eq!(cnumeric_array.value(8), -99999999_i128);

// THEN the first 3 text values should be as expected
let ctext_array = rb1
.column(11)
.column(12)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
@@ -221,7 +245,7 @@

// THEN the first 3 bytea values should be as expected
let cbytea_array = rb1
.column(12)
.column(13)
.as_any()
.downcast_ref::<BinaryArray>()
.unwrap();
13 changes: 13 additions & 0 deletions tests/postgres-init-scripts/init-pg-data.sql
@@ -10,6 +10,7 @@ CREATE TABLE t1(
ctimestamp TIMESTAMP,
ctimestamptz TIMESTAMPTZ,
cdate DATE,
cnumeric NUMERIC(8, 3),
ctext TEXT,
cbytea BYTEA
);
@@ -27,6 +28,7 @@ INSERT INTO t1(
ctimestamp,
ctimestamptz,
cdate,
cnumeric,
ctext,
cbytea
) SELECT
@@ -40,6 +42,17 @@ INSERT INTO t1(
'2024-01-01'::TIMESTAMP + s * INTERVAL '1 second',
'2024-01-01 00:00:00+00'::TIMESTAMPTZ + s * INTERVAL '1 second',
'2024-01-01'::DATE + s,
s::NUMERIC / 1000,
s::TEXT,
int4send(s::INT)
FROM generate_series(1, 25000) AS s;

-- Set various cnumeric values
UPDATE t1 SET cnumeric = 0 WHERE id = 2;
UPDATE t1 SET cnumeric = 0.001 WHERE id = 3;
UPDATE t1 SET cnumeric = -0.002 WHERE id = 4;
UPDATE t1 SET cnumeric = 3 WHERE id = 5;
UPDATE t1 SET cnumeric = -4 WHERE id = 6;
UPDATE t1 SET cnumeric = 50.001 WHERE id = 7;
UPDATE t1 SET cnumeric = 99999.999 WHERE id = 8;
UPDATE t1 SET cnumeric = -99999.999 WHERE id = 9;

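For reference, these seeded values line up with the unscaled Decimal128 integers asserted in tests/basic_integration.rs: with scale 3, each stored value is the decimal multiplied by 10^3. The snippet below is an illustrative sketch, not part of the commit.

// Illustrative check of the scale-3 unscaled representation the tests expect.
fn unscaled_at_scale_3(value: f64) -> i128 {
    (value * 1_000.0).round() as i128
}

fn main() {
    assert_eq!(unscaled_at_scale_3(0.001), 1);
    assert_eq!(unscaled_at_scale_3(-0.002), -2);
    assert_eq!(unscaled_at_scale_3(3.0), 3_000);
    assert_eq!(unscaled_at_scale_3(-4.0), -4_000);
    assert_eq!(unscaled_at_scale_3(50.001), 50_001);
    assert_eq!(unscaled_at_scale_3(99999.999), 99_999_999);
    assert_eq!(unscaled_at_scale_3(-99999.999), -99_999_999);
}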