Skip to content

Commit

Permalink
[db-queries] Allow join expressions in paginated-multicolumn (#6530)
Browse files Browse the repository at this point in the history
Currently, the `paginated_multicolumn` utility in
`nexus_db_queries::pagination` only works when the select expression to
paginate is a table, and both columns to order by come from that table.
This means that it cannot easily be used to fix the bug in
`instance_and_vmm_list_by_sled_agent` that @davepacheco describes in
[this comment][1], which would require using `paginated_multicolumn` to
paginate on two columns in an inner join expression. 

This commit changes the giant wad of Diesel type ceremony on
`paginated_multicolumn` in order to ~~make it even worse~~ allow
expressions which are not tables to be paginated. I've added a test
demonstrating that this does, in fact, work.

Figuring out how to do this was...certainly an experience which I have
had. I think I need to lie down now.

[1]: #6519 (review)
  • Loading branch information
hawkw authored Sep 5, 2024
1 parent da08190 commit 9c81109
Showing 1 changed file with 228 additions and 38 deletions.
266 changes: 228 additions & 38 deletions nexus/db-queries/src/db/pagination.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use diesel::helper_types::*;
use diesel::pg::Pg;
use diesel::query_builder::AsQuery;
use diesel::query_dsl::methods as query_methods;
use diesel::query_source::QuerySource;
use diesel::sql_types::{Bool, SqlType};
use diesel::AppearsOnTable;
use diesel::Column;
Expand Down Expand Up @@ -70,7 +71,7 @@ where
}
}

/// Uses `pagparams` to list a subset of rows in `table`, ordered by `c1`, and
/// Uses `pagparams` to list a subset of rows in `query`, ordered by `c1`, and
/// then by `c2`.
///
/// This is a two-column variation of the [`paginated`] function.
Expand All @@ -79,40 +80,56 @@ where
// columns" implement a subset of ExpressionMethods) or making a macro to generate
// all the necessary bounds we need.
pub fn paginated_multicolumn<T, C1, C2, M1, M2>(
table: T,
query: T,
(c1, c2): (C1, C2),
pagparams: &DataPageParams<'_, (M1, M2)>,
) -> BoxedQuery<T>
) -> <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output
where
// T is a table which can create a BoxedQuery.
T: diesel::Table,
T: query_methods::BoxedDsl<
'static,
Pg,
Output = diesel::internal::table_macro::BoxedSelectStatement<
'static,
TableSqlType<T>,
diesel::internal::table_macro::FromClause<T>,
Pg,
>,
>,
// T is a table^H^H^H^H^Hquery source which can create a BoxedQuery.
T: QuerySource,
T: AsQuery,
<T as QuerySource>::DefaultSelection:
Expression<SqlType = <T as AsQuery>::SqlType>,
T::Query: query_methods::BoxedDsl<'static, Pg>,
// Required for...everything.
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output: QueryDsl,
// C1 & C2 are columns which appear in T.
C1: 'static + Column + Copy + ExpressionMethods + AppearsOnTable<T>,
C2: 'static + Column + Copy + ExpressionMethods + AppearsOnTable<T>,
C1: 'static + Column + Copy + ExpressionMethods,
C2: 'static + Column + Copy + ExpressionMethods,
// Required to compare the columns with the marker types.
C1::SqlType: SqlType,
C2::SqlType: SqlType,
M1: Clone + AsExpression<C1::SqlType>,
M2: Clone + AsExpression<C2::SqlType>,
// Necessary for `query.limit(...)`
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::LimitDsl<
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.order(c1.desc())"
BoxedQuery<T>: query_methods::OrderDsl<Desc<C1>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::OrderDsl<
Desc<C1>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.order(...).then_order_by(c2.desc())"
BoxedQuery<T>:
query_methods::ThenOrderDsl<Desc<C2>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::ThenOrderDsl<
Desc<C2>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.order(c1.asc())"
BoxedQuery<T>: query_methods::OrderDsl<Asc<C1>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::OrderDsl<
Asc<C1>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.order(...).then_order_by(c2.asc())"
BoxedQuery<T>: query_methods::ThenOrderDsl<Asc<C2>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::ThenOrderDsl<
Asc<C2>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,

// We'd like to be able to call:
//
Expand All @@ -126,10 +143,11 @@ where
// The RHS (c2.gt(v2)) must be a boolean expression:
Gt<C2, M2>: Expression<SqlType = Bool>,
// Putting it together, we should be able to filter by LHS.and(RHS):
BoxedQuery<T>: query_methods::FilterDsl<
And<Eq<C1, M1>, Gt<C2, M2>>,
Output = BoxedQuery<T>,
>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::FilterDsl<
And<Eq<C1, M1>, Gt<C2, M2>>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,

// We'd also like to be able to call:
//
Expand All @@ -138,19 +156,30 @@ where
// We've already defined the bound on the LHS, so we add the equivalent
// bounds on the RHS for the "Less than" variant.
Lt<C2, M2>: Expression<SqlType = Bool>,
BoxedQuery<T>: query_methods::FilterDsl<
And<Eq<C1, M1>, Lt<C2, M2>>,
Output = BoxedQuery<T>,
>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::FilterDsl<
And<Eq<C1, M1>, Lt<C2, M2>>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,

// Necessary for "query.or_filter(c1.gt(v1))"
BoxedQuery<T>:
query_methods::OrFilterDsl<Gt<C1, M1>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::OrFilterDsl<
Gt<C1, M1>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
// Necessary for "query.or_filter(c1.lt(v1))"
BoxedQuery<T>:
query_methods::OrFilterDsl<Lt<C1, M1>, Output = BoxedQuery<T>>,
<T::Query as query_methods::BoxedDsl<'static, Pg>>::Output:
query_methods::OrFilterDsl<
Lt<C1, M1>,
Output = <T::Query as query_methods::BoxedDsl<'static, Pg>>::Output,
>,
{
let mut query = table.into_boxed().limit(pagparams.limit.get().into());
use query_methods::BoxedDsl;
let mut query = query
.as_query()
.internal_into_boxed()
.limit(pagparams.limit.get().into());
let marker = pagparams.marker.map(|m| m.clone());
match pagparams.direction {
dropshot::PaginationOrder::Ascending => {
Expand Down Expand Up @@ -315,6 +344,7 @@ mod test {

use crate::db;
use async_bb8_diesel::{AsyncRunQueryDsl, AsyncSimpleConnection};
use diesel::JoinOnDsl;
use diesel::SelectableHelper;
use dropshot::PaginationOrder;
use nexus_test_utils::db::test_setup_database;
Expand All @@ -333,9 +363,18 @@ mod test {
height -> Int8,
}
}

// Second test table, used to exercise pagination over a join with
// `test_users`. Rows are keyed by the composite `(user_id, phone_number)`,
// matching the `PRIMARY KEY` in the `CREATE TABLE` statement issued by
// `populate_users`.
table! {
test_phone_numbers (user_id, phone_number) {
user_id -> Uuid,
phone_number -> Int8,
}
}

allow_tables_to_appear_in_same_query!(test_users, test_phone_numbers,);
}

use schema::test_users;
use schema::{test_phone_numbers, test_users};

#[derive(Clone, Debug, Queryable, Insertable, PartialEq, Selectable)]
#[diesel(table_name = test_users)]
Expand All @@ -345,13 +384,39 @@ mod test {
height: i64,
}

/// A row of the `test_phone_numbers` table: a single phone number owned by a
/// user.
#[derive(Clone, Debug, Queryable, Insertable, PartialEq, Selectable)]
#[diesel(table_name = test_phone_numbers)]
struct PhoneNumber {
// ID of the owning user; the join test matches this against `test_users.id`.
user_id: Uuid,
// The number itself, stored as a plain integer for test purposes.
phone_number: i64,
}

/// One row of the `test_users` / `test_phone_numbers` inner join: a user
/// paired with one of that user's phone numbers.
#[derive(Debug)]
struct UserAndPhoneNumber {
// The `test_users` half of the joined row.
user: User,
// The `test_phone_numbers` half of the joined row.
phone_number: PhoneNumber,
}

/// Lets tests compare a joined row against a compact
/// `((age, height), phone_number)` literal.
impl PartialEq<((i64, i64), i64)> for UserAndPhoneNumber {
    fn eq(&self, other: &((i64, i64), i64)) -> bool {
        // The tuple is made of `Copy` integers, so it can be split by value.
        let (expected_user, expected_phone) = *other;
        self.user == expected_user && self.phone_number == expected_phone
    }
}

impl PartialEq<(i64, i64)> for User {
fn eq(&self, other: &(i64, i64)) -> bool {
self.age == other.0 && self.height == other.1
}
}

/// Lets tests compare a `PhoneNumber` against a bare number, ignoring the
/// owning `user_id`.
impl PartialEq<i64> for PhoneNumber {
    fn eq(&self, number: &i64) -> bool {
        self.phone_number == *number
    }
}

async fn populate_users(pool: &db::Pool, values: &Vec<(i64, i64)>) {
use schema::test_phone_numbers::dsl as phone_numbers_dsl;
use schema::test_users::dsl;

let conn = pool.claim().await.unwrap();
Expand All @@ -365,8 +430,17 @@ mod test {
height INT NOT NULL
);
CREATE TABLE test_phone_numbers (
user_id UUID NOT NULL,
-- This is definitely the correct way to store a
-- phone number in the database. :)
phone_number INT NOT NULL,
PRIMARY KEY (user_id, phone_number)
);
CREATE INDEX ON test_users (age, height);
CREATE INDEX ON test_users (height, age);",
CREATE INDEX ON test_users (height, age);
CREATE INDEX ON test_phone_numbers (user_id);",
)
.await
.unwrap();
Expand All @@ -381,7 +455,22 @@ mod test {
.collect();

diesel::insert_into(dsl::test_users)
.values(users)
.values(users.clone())
.execute_async(&*conn)
.await
.unwrap();

let mut phone_numbers = Vec::new();
for (i, user) in users.iter().enumerate() {
for j in 0..3 {
phone_numbers.push(PhoneNumber {
user_id: user.id,
phone_number: (i as i64 + 1) * 10 + j,
});
}
}
diesel::insert_into(phone_numbers_dsl::test_phone_numbers)
.values(phone_numbers)
.execute_async(&*conn)
.await
.unwrap();
Expand Down Expand Up @@ -574,6 +663,107 @@ mod test {
logctx.cleanup_successful();
}

/// Checks that `paginated_multicolumn` accepts an inner-join expression
/// (rather than a single table) and pages through the joined rows in
/// ascending `(age, phone_number)` order, with the two ordering columns coming
/// from different tables.
#[tokio::test]
async fn test_paginated_multicolumn_works_with_joins() {
use async_bb8_diesel::AsyncConnection;

let logctx =
dev::test_setup_log("test_paginated_multicolumn_works_with_joins");
let mut db = test_setup_database(&logctx.log).await;
let cfg = db::Config { url: db.pg_config().clone() };
let pool = db::Pool::new_single_host(&logctx.log, &cfg);

use schema::test_phone_numbers::dsl as phone_numbers_dsl;
use schema::test_users::dsl;

// Five users as `(age, height)`. `populate_users` also inserts three phone
// numbers per user (`(i + 1) * 10 + j` for `j` in `0..3`), so the join
// below yields 15 rows in total.
populate_users(&pool, &vec![(1, 1), (1, 2), (2, 1), (2, 3), (3, 1)])
.await;

// Runs one paginated query over `test_users JOIN test_phone_numbers` and
// returns the resulting rows as `UserAndPhoneNumber`s.
async fn get_page(
pool: &db::Pool,
pagparams: &DataPageParams<'_, (i64, i64)>,
) -> Vec<UserAndPhoneNumber> {
let conn = pool.claim().await.unwrap();
// NOTE(review): the setting and the query share one transaction,
// presumably so that the full-table-scan allowance applies to the
// paginated query — confirm.
conn.transaction_async(|conn| async move {
// I couldn't figure out how to make this work without requiring a full
// table scan, and I just want the test to work so that I can get on
// with my life...
conn.batch_execute_async(
crate::db::queries::ALLOW_FULL_TABLE_SCAN_SQL,
)
.await
.unwrap();

// Paginate on one column from each side of the join: `age` from
// `test_users` and `phone_number` from `test_phone_numbers`.
paginated_multicolumn(
dsl::test_users.inner_join(
phone_numbers_dsl::test_phone_numbers
.on(phone_numbers_dsl::user_id.eq(dsl::id)),
),
(dsl::age, phone_numbers_dsl::phone_number),
&pagparams,
)
.select((User::as_select(), PhoneNumber::as_select()))
.load_async(&conn)
.await
})
.await
.unwrap()
.into_iter()
.map(|(user, phone_number)| UserAndPhoneNumber {
user,
phone_number,
})
.collect::<Vec<_>>()
}

// Get the first paginated result.
let mut pagparams = DataPageParams::<(i64, i64)> {
marker: None,
direction: PaginationOrder::Ascending,
limit: NonZeroU32::new(1).unwrap(),
};
let observed = get_page(&pool, &pagparams).await;
assert_eq!(dbg!(&observed), &[((1, 1), 10)]);

// Get the next paginated results, check that they arrived in the order
// we expected.
// The marker is the `(age, phone_number)` of the single row seen so far.
let marker =
(observed[0].user.age, observed[0].phone_number.phone_number);
pagparams.marker = Some(&marker);
pagparams.limit = NonZeroU32::new(10).unwrap();
let observed = get_page(&pool, &pagparams).await;
assert_eq!(
dbg!(&observed),
&[
((1, 1), 11),
((1, 1), 12),
((1, 2), 20),
((1, 2), 21),
((1, 2), 22),
((2, 1), 30),
((2, 1), 31),
((2, 1), 32),
((2, 3), 40),
((2, 3), 41),
]
);

// Get the next paginated results, check that they arrived in the order
// we expected.
// Index 9 is the last row of the ten-row page above; the limit is already
// 10 and is simply re-stated here.
let marker =
(observed[9].user.age, observed[9].phone_number.phone_number);
pagparams.marker = Some(&marker);
pagparams.limit = NonZeroU32::new(10).unwrap();
let observed = get_page(&pool, &pagparams).await;
// Only the remaining 4 of the 15 joined rows are left.
assert_eq!(
dbg!(&observed),
&[((2, 3), 42), ((3, 1), 50), ((3, 1), 51), ((3, 1), 52)]
);

let _ = db.cleanup().await;
logctx.cleanup_successful();
}

#[test]
fn test_paginator() {
// The doctest exercises a basic case for Paginator. Here we test some
Expand Down

0 comments on commit 9c81109

Please sign in to comment.