
WIP: experiment with SMJ last buffered batch #12082

Closed
wants to merge 2 commits into from

Conversation

comphead
Contributor

Which issue does this PR close?

Related to #11555

Closes #.

Rationale for this change

Experiment with an approach to identify the last buffered batch for a given streamed row's join key

What changes are included in this PR?

Are these changes tested?

Are there any user-facing changes?

)
.run_test(&[JoinTestType::HjSmj], false)
.await
for i in 0..1000 {
Contributor Author

If I run this test 1000 times, there is a chance the test will fail.

// Try to calculate whether the buffered batch we scan is the last one for the specific stream row and join key
// For batch_size == 1, self.buffered_data.scanning_finished() works well
// For other scenarios it is an attempt to figure out whether there are more rows matching the same join key
let last_batch = if self.batch_size == 1 {
Contributor Author

@korowa @viirya @alamb
Hi folks, I would appreciate any other ideas on how to calculate the last batch.
AntiJoin relies exactly on having the last batch to calculate the predicate for the join key correctly.
I'm trying to determine that no more buffered rows are incoming for the given streamed join key. The approach is still not perfect, as it still allows the tests to fail from time to time, although it has become more stable.

Contributor

I will try and give it a look tomorrow

Contributor

I would expect the right check, based on function names, to be

let last_batch = self.buffered_data.scanning_finished()

However, I tried that and the test still fails.

@richox I wonder if you have any ideas (as it appears you are the original author of SortMergeJoin in #2242)

I am having a hard time following the logic in such a large function (looks like freeze_streamed is something like 300 lines long).

If I were debugging this issue more, what I would probably do is

  1. break the logic down into a few more named functions so the logic boundaries and the intended action are clearer.
  2. try to document, in comments, the intended invariants of BufferedBatch / ScanningBatch. My hope would be that in the process of writing that documentation I would learn the code better and get a clearer idea of which invariant isn't being upheld in this function.

Contributor Author

Documentation would definitely help. By the way, here is a ticket for SMJ documentation:
#10357

Contributor

I'm starting to look at this PR (will take some time though)

Contributor Author

I'm starting to look at this PR (will take some time though)

Thanks, it's not an actual PR; it is more a set of attempts/directions to find a solution and discuss. I'm experimenting more in parallel and would love to hear your ideas as well.

Contributor

@korowa korowa Aug 25, 2024

I've finally got your idea (and the fact that the problem is not related to fetching the buffered side, but to processing already-joined buffered-side data). For anti join, the following will probably be helpful:

  1. get_filtered_join_mask -- maybe it should only update matched_indices (in case the filters are evaluated as true at least once), and the data emission logic should live somewhere else (currently there is a problem where streamed records without any filter matches are duplicated for each joined buffered chunk, as "negative" filter results are not tracked across joined batches). In any case, its output doesn't seem to be sufficient for anti join.
  2. the filtered anti join should return only the records for which buffered-side scanning is completed (as freeze_streamed may be called in the middle of buffered-data scanning, due to output batch size) and for which there were no true filters (from p.1) -- so maybe we should split filter evaluation and output emission in freeze_streamed (since the filters should be checked for all matched indices, but at the same time the current streamed index can be filtered out of the output because it has further buffered batches to be joined with)?

Contributor Author

Thanks @korowa for the directions. This week I will try to find out whether such an approach works for us; alternatively, I'm planning to play with the pair scanning_batch().range.start and self.scanning_offset, which can perhaps give a hint on how to identify the last joined buffered-side batch for the streaming row.

Contributor Author

Another option just came to my mind: we know the number of matched rows in advance. We can calculate it as

let matched_indices: usize = self.buffered_data.batches.iter().map(|b| b.range.end - b.range.start).sum();

No matter how SMJ distributes the data, we can take buffered_indices.len() on each iteration and subtract it from matched_indices. Once we hit 0, no more matched rows are expected.
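To make the arithmetic concrete, here is a standalone sketch of that countdown (a hypothetical simplification, not the actual SMJ code: plain `Range<usize>` values stand in for `BufferedBatch` ranges; note the total needs `sum`, not `count`, since `count` would return the number of batches rather than the number of rows):

```rust
use std::ops::Range;

/// Tracks how many buffered rows with the current join key remain to be emitted.
/// `ranges` are the per-batch row ranges sharing the streamed row's join key.
struct MatchedCountdown {
    remaining: usize,
}

impl MatchedCountdown {
    fn new(ranges: &[Range<usize>]) -> Self {
        // Total matched rows across all buffered batches for this join key.
        let remaining: usize = ranges.iter().map(|r| r.end - r.start).sum();
        MatchedCountdown { remaining }
    }

    /// Call once per output chunk with the number of buffered indices emitted.
    /// Returns true when no more matched rows are expected (last chunk seen).
    fn emit(&mut self, buffered_indices_len: usize) -> bool {
        self.remaining = self.remaining.saturating_sub(buffered_indices_len);
        self.remaining == 0
    }
}
```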

@github-actions github-actions bot added physical-expr Physical Expressions core Core DataFusion crate labels Aug 20, 2024
@comphead
Contributor Author

To create a reproducer, you need to run the test in debug mode:

async fn test_anti_join_1k_filtered() {
    // NLJ vs HJ gives wrong result
    // Tracked in https://github.com/apache/datafusion/issues/11537
    for i in 0..1000 {
        JoinFuzzTestCase::new(
            make_staggered_batches(1000),
            make_staggered_batches(1000),
            JoinType::LeftAnti,
            Some(Box::new(col_lt_col_filter)),
        )
            .run_test(&[JoinTestType::HjSmj], true)
            .await
    }
}

the test creates a dump of the data locally on disk, for example fuzz_test_debug/batch_size_7

The test below is a reproducer (just set the paths) for what step 1 outputs

#[tokio::test]
async fn test1() {
    let left: Vec<RecordBatch> = JoinFuzzTestCase::load_partitioned_batches_from_parquet(
        "fuzz_test_debug/batch_size_7/input1",
    )
    .await
    .unwrap();

    let right: Vec<RecordBatch> =
        JoinFuzzTestCase::load_partitioned_batches_from_parquet(
            "fuzz_test_debug/batch_size_7/input2",
        )
        .await
        .unwrap();

    JoinFuzzTestCase::new(
        left,
        right,
        JoinType::LeftAnti,
        Some(Box::new(col_lt_col_filter)),
    )
    .run_test(&[JoinTestType::HjSmj], false)
    .await;
}

@comphead
Contributor Author

@korowa @viirya please help me understand the scenario with ranges.

If there is a left streamed row with join key (1), from the right side we are going to have joined buffered batches where the range shows which indices share the same join key.

For example

Streamed data        Buffered data
[1]               -> [0, 1, 1], [1, 1, 2]

Should have ranges [1..3], [0..2]

What I see now is that in some extreme cases, when freeze_streamed is called, I can get joined buffered data that doesn't match the join key.

Like [1..3] for join key 1 and then [0..1] for join key 2, which looks weird to me and seems unexpected. WDYT?

@viirya
Member

viirya commented Aug 29, 2024

If there is a left streamed row with join key (1), from the right side we are going to have joined buffered batches where the range shows which indices share the same join key.

For example

Streamed data        Buffered data
[1]               -> [0, 1, 1], [1, 1, 2]

Should have ranges [1..3], [0..2]

I don't get the question clearly.

You have [0, 1, 1] as buffered indices for the same streamed row? Why do you have the same buffered row id 1 twice?

@comphead
Contributor Author

If there is a left streamed row with join key (1), from the right side we are going to have joined buffered batches where the range shows which indices share the same join key.
For example

Streamed data        Buffered data
[1]               -> [0, 1, 1], [1, 1, 2]

Should have ranges [1..3], [0..2]

I don't get the question clearly.

You have [0, 1, 1] as buffered indices for the same streamed row? Why do you have the same buffered row id 1 twice?

Thanks @viirya, it's not indices, it is raw data. Let me rephrase it.

If I have a left table

a b
10 20

and right table

a b
5 20
10 20
10 21
10 21
10 22
15 22

And join key is A and Filter is on column B

In freeze_streamed I can observe that the right table comes as 3 batches:

Batch 1: join_array [10], Range 1..3 - correct, as row numbers 1 and 2 relate to join key 10
Batch 2: join_array [10], Range 0..2 - correct, as row numbers 0 and 1 relate to join key 10
Batch 3: join_array [15], Range 0..1 - weird; why is this batch associated?

@comphead
Contributor Author

comphead commented Aug 29, 2024

#[tokio::test]
async fn test_ranges() {
    let left: Vec<RecordBatch> = make_staggered_batches(1);

    let left = vec![
        RecordBatch::try_new(
            left[0].schema().clone(),
            vec![
                Arc::new(Int32Array::from(vec![1])),
                Arc::new(Int32Array::from(vec![10])),
                Arc::new(Int32Array::from(vec![10])),
                Arc::new(Int32Array::from(vec![1000])),
            ],
        ).unwrap()
    ];

    let right = vec![
        RecordBatch::try_new(
            left[0].schema().clone(),
            vec![
                Arc::new(Int32Array::from(vec![0, 1, 1, 2])),
                Arc::new(Int32Array::from(vec![0, 10, 11, 20])),
                Arc::new(Int32Array::from(vec![0, 1100, 0, 2100])),
                Arc::new(Int32Array::from(vec![0, 11000, 0, 21000])),
            ],
        ).unwrap(),
        RecordBatch::try_new(
            left[0].schema().clone(),
            vec![
                Arc::new(Int32Array::from(vec![2, 2])),
                Arc::new(Int32Array::from(vec![20, 21])),
                Arc::new(Int32Array::from(vec![2101, 0])),
                Arc::new(Int32Array::from(vec![21001, 0])),
            ],
        ).unwrap(),

    ];

    JoinFuzzTestCase::new(
        left,
        right,
        JoinType::LeftAnti,
        Some(Box::new(col_lt_col_filter)),
    )
        .run_test(&[JoinTestType::HjSmj], false)
        .await;
}

If you debug freeze_streamed you can see that one of the buffered data batches has range 0..1 but for another join key. Do you think this is correct? Probably we need to check the join array from the first batch against subsequent batches.

@viirya
Member

viirya commented Aug 29, 2024

If I have a left table

a b
10 20
and right table

a b
5 20
10 20
10 21
10 21
10 22
15 22
And join key is A and Filter is on column B

In freeze_streamed I can observe that the right table comes as 3 batches:

Batch 1: join_array [10], Range 1..3 - correct, as row numbers 1 and 2 relate to join key 10
Batch 2: join_array [10], Range 0..2 - correct, as row numbers 0 and 1 relate to join key 10
Batch 3: join_array [15], Range 0..1 - weird; why is this batch associated?

Would you let me know how you cut the 3 batches among the 6 buffered rows?

@comphead
Contributor Author

If I have a left table
a b
10 20
and right table
a b
5 20
10 20
10 21
10 21
10 22
15 22
And join key is A and Filter is on column B
In freeze_streamed I can observe that the right table comes as 3 batches:
Batch 1: join_array [10], Range 1..3 - correct, as row numbers 1 and 2 relate to join key 10
Batch 2: join_array [10], Range 0..2 - correct, as row numbers 0 and 1 relate to join key 10
Batch 3: join_array [15], Range 0..1 - weird; why is this batch associated?

Would you let me know how you cut the 3 batches among the 6 buffered rows?

I believe it depends on batch_size and output_size. What I have observed is that the buffered batch of 6 rows can be processed differently: 3 + 1 + 1 + 1, or 1 + 1 + 1 + 1 + 1 + 1, or 1 batch of 6 rows.

I think @korowa mentioned it here

the filtered anti join should return only the records for which buffered-side scanning is completed (as freeze_streamed may be called in the middle of buffered-data scanning, due to output batch size) and for which there were no true filters (from p.1) -- so maybe we should split filter evaluation and output emission in freeze_streamed (since the filters should be checked for all matched indices, but at the same time the current streamed index can be filtered out of the output because it has further buffered batches to be joined with)?

For simplicity, let's consider the test in #12082 (comment)

When I debug freeze_streamed I can see the buffered data coming as

[datafusion/physical-plan/src/joins/sort_merge_join.rs:1500:25] &self.buffered_data.batches = [
    BufferedBatch {
        batch: Some(
            RecordBatch {
                columns: [
                    PrimitiveArray<Int32>
                    [
                      0,
                      1,
                      1,
                      2,
                    ],
                    PrimitiveArray<Int32>
                    [
                      0,
                      10,
                      11,
                      20,
                    ],
                    PrimitiveArray<Int32>
                    [
                      0,
                      1100,
                      0,
                      2100,
                    ],
                    PrimitiveArray<Int32>
                    [
                      0,
                      11000,
                      0,
                      21000,
                    ],
                ],
                row_count: 4,
            },
        ),
        range: 3..4,
        join_arrays: [
            PrimitiveArray<Int32>
            [
              0,
              1,
              1,
              2,
            ],
            PrimitiveArray<Int32>
            [
              0,
              10,
              11,
              20,
            ],
        ],
    },
    BufferedBatch {
        batch: Some(
            RecordBatch { 
                columns: [
                    PrimitiveArray<Int32>
                    [
                      2,
                      2,
                    ],
                    PrimitiveArray<Int32>
                    [
                      20,
                      21,
                    ],
                    PrimitiveArray<Int32>
                    [
                      2101,
                      0,
                    ],
                    PrimitiveArray<Int32>
                    [
                      21001,
                      0,
                    ],
                ],
                row_count: 2,
            },
        ),
        range: 0..1,
        join_arrays: [
            PrimitiveArray<Int32>
            [
              2,
              2,
            ],
            PrimitiveArray<Int32>
            [
              20,
              21,
            ],
        ],
    },
]

What are the ranges here? The doc says

    /// The range in which the rows share the same join key
    pub range: Range<usize>,

but how do range: 3..4 in the first batch and range: 0..1 in the second match the join key at all? They point to non-matched rows.

@viirya
Member

viirya commented Aug 30, 2024

What are the ranges here? The doc says

    /// The range in which the rows share the same join key
    pub range: Range<usize>,

but how do range: 3..4 in the first batch and range: 0..1 in the second match the join key at all? They point to non-matched rows.

range holds the row indices of the batch in the BufferedBatch which share the same join key. It is not related to whether the rows matched or not.
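As an illustration of that definition, a minimal standalone sketch (hypothetical helper; a sorted `&[i32]` stands in for the real join arrays) computes the half-open interval of row indices sharing the key at a given row, independent of whether those rows matched anything:

```rust
/// For a batch whose join-key column `keys` is sorted, return the half-open
/// range of row indices that share the same key as the row at `start`.
/// This mirrors the documented meaning of `BufferedBatch::range`; it says
/// nothing about whether those rows matched a streamed row.
fn same_key_range(keys: &[i32], start: usize) -> std::ops::Range<usize> {
    let key = keys[start];
    // Walk backwards to the first row with this key.
    let mut lo = start;
    while lo > 0 && keys[lo - 1] == key {
        lo -= 1;
    }
    // Walk forwards past the last row with this key.
    let mut hi = start + 1;
    while hi < keys.len() && keys[hi] == key {
        hi += 1;
    }
    lo..hi
}
```

For the buffered data [0, 1, 1, 2] from the example above, the rows with join key 1 occupy 1..3.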

@korowa
Contributor

korowa commented Sep 1, 2024

range holds the row indices of the batch in the BufferedBatch which share the same join key. It is not related to whether the rows matched or not.

That matches my understanding of these ranges in buffered batches.

Like [1..3] for join key 1 and then [0..1] for join key 2, which looks weird to me and seems unexpected. WDYT?

@comphead, I've tried your example and what I see while debugging is that there are 3 "versions" of buffered data with the following ranges

0..1 // join key 0

1..3 // join key 1, first right batch
0..2 // join key 1, second right batch

2..3 // join key 2

I'm able to see them before and after join_partial call.

At what point in the code are you able to observe 0..1 for key 2?

@comphead
Contributor Author

comphead commented Sep 2, 2024

At what point in the code are you able to observe 0..1 for key 2?

I'm running the test from #12082 (comment) and debugging the freeze_streamed function. For batch size 2 I'm seeing a batch distribution like #12082 (comment)

You can see there the buffered batch with join array

        join_arrays: [
            PrimitiveArray<Int32>
            [
              2,
              2,
            ],
            PrimitiveArray<Int32>
            [
              20,
              21,
            ],
        ],

which confuses me; I was thinking only buffered batches that contain the streaming key should be there, but it looks like that's not the case.

I believe we can do the following:

  • get the current join key from streamed_batch.join_arrays by self.streamed_batch.idx
  • find all batches in buffered_data that contain the join key from step 1
  • if buffered_data.scanning_batch_idx equals the batch count from step 2 and this batch's range.end == num_rows, that probably means SMJ has already emitted all the indices from this batch and we are done for this particular key

@viirya @korowa do you think this would be enough to identify that all rows have been processed for the given join key?
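For discussion, the three steps could be sketched roughly like this (hypothetical and heavily simplified: plain `keys`/`range` pairs stand in for `BufferedBatch`, and the `range.end == num_rows` refinement is omitted):

```rust
use std::ops::Range;

/// Stand-in for BufferedBatch: the batch's join-key column and the range of
/// rows sharing the current join key.
struct Buffered {
    keys: Vec<i32>,
    range: Range<usize>,
}

/// Steps 1-3 from the comment above: given the streamed row's join key, the
/// buffered batches, and the index of the batch currently being scanned,
/// decide whether all buffered rows for that key have already been emitted.
fn all_rows_emitted(stream_key: i32, batches: &[Buffered], scanning_batch_idx: usize) -> bool {
    // Step 2: count batches whose in-range join keys contain the streamed key.
    let containing = batches
        .iter()
        .filter(|b| b.keys[b.range.clone()].contains(&stream_key))
        .count();
    // Step 3 (simplified): scanning has moved past the last such batch.
    scanning_batch_idx >= containing
}
```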

@korowa
Contributor

korowa commented Sep 8, 2024

@comphead I've finally got it -- it's like in this case SMJ is trying to produce output for each join key pair (streamed-buffered) -- I guess that's how SMJ state management works now -- the streamed-side index won't move until all buffered-side data has been processed, since it's required to identify the current ordering.

- get the current join key from streamed_batch.join_arrays by self.streamed_batch.idx
- find all batches in buffered_data that contain the join key from step 1
- if buffered_data.scanning_batch_idx equals the batch count from step 2 and this batch's range.end == num_rows, that probably means SMJ has already emitted all the indices from this batch and we are done for this particular key

I'd say that normally you don't need to compare join keys, and you should rely on buffered_data.scanning_finished() (or self.current_ordering == Less), but in your example both of these conditions are either not working or not intended to work (not sure which of the two is the correct statement).

I also hope to start spending some time on SMJ due to #12359

@comphead
Contributor Author

comphead commented Sep 8, 2024

@comphead I've finally got it -- it's like in this case SMJ is trying to produce output for each join key pair (streamed-buffered) -- I guess that's how SMJ state management works now -- the streamed-side index won't move until all buffered-side data has been processed, since it's required to identify the current ordering.

- get the current join key from streamed_batch.join_arrays by self.streamed_batch.idx
- find all batches in buffered_data that contain the join key from step 1
- if buffered_data.scanning_batch_idx equals the batch count from step 2 and this batch's range.end == num_rows, that probably means SMJ has already emitted all the indices from this batch and we are done for this particular key

I'd say that normally you don't need to compare join keys, and you should rely on buffered_data.scanning_finished() (or self.current_ordering == Less), but in your example both of these conditions are either not working or not intended to work (not sure which of the two is the correct statement).

I also hope to start spending some time on SMJ due to #12359

Thanks @korowa. I have been experimenting a lot with different parts of SMJ, and it showed that:
buffered_data.scanning_finished() is not working;
self.current_ordering == Less cannot be relied on in freeze_streamed, as freeze_streamed is called only when self.current_ordering == Equal. Now I'm trying to calculate whether it's possible to predict that the ordering is going to change from Equal to Less.

And yes, I was also trying to compare join arrays, which could potentially give us a clue that everything has been processed, but it might be very expensive.
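For what it's worth, the Equal-to-Less prediction could in principle reduce to peeking at the next buffered key (a standalone sketch under that assumption, not the actual SMJ state machine):

```rust
use std::cmp::Ordering;

/// Predict whether the ordering for the current streamed row is about to flip
/// from Equal to Less: true when the next buffered key (if any) no longer
/// equals the streamed key, i.e. no further buffered rows can match it.
fn ordering_will_drop(stream_key: i32, next_buffered_key: Option<i32>) -> bool {
    match next_buffered_key {
        // Buffered side exhausted: nothing more can match.
        None => true,
        // Sorted input: once the buffered key exceeds the streamed key,
        // the ordering for this streamed row becomes Less.
        Some(k) => stream_key.cmp(&k) == Ordering::Less,
    }
}
```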

@comphead comphead closed this Oct 4, 2024