Skip to content

Commit

Permalink
refactor: join for performance
Browse files Browse the repository at this point in the history
1. Created a reusable output ByteRecord for building the final record
2. Instead of using the chain() iterator, we now:
- Clear the output record
- Extend it with the left record
- Extend it with the right record
3. Reuse the same ByteRecord instances throughout the operation
4. Avoid creating temporary iterators and vectors

This reduces allocations and improves data locality.

Also, explicitly declared the key type as Vec<ByteString>.
  • Loading branch information
jqnatividad committed Dec 25, 2024
1 parent a9af236 commit 94172f0
Showing 1 changed file with 48 additions and 18 deletions.
66 changes: 48 additions & 18 deletions src/cmd/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,20 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
let mut scratch = csv::ByteRecord::new();
let mut validx = ValueIndex::new(self.rdr2, &self.sel2, self.casei, self.nulls)?;
let mut row = csv::ByteRecord::new();
let mut key;
let mut key: Vec<ByteString>;
let mut output = csv::ByteRecord::new();
while self.rdr1.read_byte_record(&mut row)? {
key = get_row_key(&self.sel1, &row, self.casei);
if let Some(rows) = validx.values.get(&key) {
for &rowi in rows {
validx.idx.seek(rowi as u64)?;

validx.idx.read_byte_record(&mut scratch)?;
let combined = row.iter().chain(scratch.iter());
self.wtr.write_record(combined)?;

output.clear();
output.extend(&row);
output.extend(&scratch);
self.wtr.write_record(&output)?;
}
}
}
Expand All @@ -235,32 +239,44 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {

fn outer_join(mut self, right: bool) -> CliResult<()> {
if right {
::std::mem::swap(&mut self.rdr1, &mut self.rdr2);
::std::mem::swap(&mut self.sel1, &mut self.sel2);
swap(&mut self.rdr1, &mut self.rdr2);
swap(&mut self.sel1, &mut self.sel2);
}

let mut scratch = csv::ByteRecord::new();
let (_, pad2) = self.get_padding()?;
let mut validx = ValueIndex::new(self.rdr2, &self.sel2, self.casei, self.nulls)?;
let mut row = csv::ByteRecord::new();
let mut key;
let mut key: Vec<ByteString>;
let mut output = csv::ByteRecord::new();

while self.rdr1.read_byte_record(&mut row)? {
key = get_row_key(&self.sel1, &row, self.casei);
if let Some(rows) = validx.values.get(&key) {
for &rowi in rows {
validx.idx.seek(rowi as u64)?;
let row1 = row.iter();
let mut row1 = row.iter();
validx.idx.read_byte_record(&mut scratch)?;
output.clear();
if right {
self.wtr.write_record(scratch.iter().chain(row1))?;
output.extend(&scratch);
output.extend(&mut row1);
} else {
self.wtr.write_record(row1.chain(&scratch))?;
output.extend(&mut row1);
output.extend(&scratch);
}
self.wtr.write_record(&output)?;
}
} else if right {
self.wtr.write_record(pad2.iter().chain(&row))?;
} else {
self.wtr.write_record(row.iter().chain(&pad2))?;
output.clear();
if right {
output.extend(&pad2);
output.extend(&row);
} else {
output.extend(&row);
output.extend(&pad2);
}
self.wtr.write_record(&output)?;
}
}
Ok(self.wtr.flush()?)
Expand All @@ -269,7 +285,7 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
fn left_join(mut self, anti: bool) -> CliResult<()> {
let validx = ValueIndex::new(self.rdr2, &self.sel2, self.casei, self.nulls)?;
let mut row = csv::ByteRecord::new();
let mut key;
let mut key: Vec<ByteString>;
while self.rdr1.read_byte_record(&mut row)? {
key = get_row_key(&self.sel1, &row, self.casei);
if validx.values.get(&key).is_none() {
Expand All @@ -287,11 +303,12 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
let mut scratch = csv::ByteRecord::new();
let (pad1, pad2) = self.get_padding()?;
let mut validx = ValueIndex::new(self.rdr2, &self.sel2, self.casei, self.nulls)?;
let mut output = csv::ByteRecord::new();

// Keep track of which rows we've written from rdr2.
let mut rdr2_written: Vec<_> = repeat(false).take(validx.num_rows).collect();
let mut row1 = csv::ByteRecord::new();
let mut key;
let mut key: Vec<ByteString>;
while self.rdr1.read_byte_record(&mut row1)? {
key = get_row_key(&self.sel1, &row1, self.casei);
if let Some(rows) = validx.values.get(&key) {
Expand All @@ -300,10 +317,16 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {

validx.idx.seek(rowi as u64)?;
validx.idx.read_byte_record(&mut scratch)?;
self.wtr.write_record(row1.iter().chain(&scratch))?;
output.clear();
output.extend(&row1);
output.extend(&scratch);
self.wtr.write_record(&output)?;
}
} else {
self.wtr.write_record(row1.iter().chain(&pad2))?;
output.clear();
output.extend(&row1);
output.extend(&pad2);
self.wtr.write_record(&output)?;
}
}

Expand All @@ -313,7 +336,10 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
if !written {
validx.idx.seek(i as u64)?;
validx.idx.read_byte_record(&mut scratch)?;
self.wtr.write_record(pad1.iter().chain(&scratch))?;
output.clear();
output.extend(&pad1);
output.extend(&scratch);
self.wtr.write_record(&output)?;
}
}
Ok(self.wtr.flush()?)
Expand All @@ -324,6 +350,7 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
pos.set_byte(0);
let mut row2 = csv::ByteRecord::new();
let mut row1 = csv::ByteRecord::new();
let mut output = csv::ByteRecord::new();
let rdr2_has_headers = self.rdr2.has_headers();
while self.rdr1.read_byte_record(&mut row1)? {
self.rdr2.seek(pos.clone())?;
Expand All @@ -333,7 +360,10 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
self.rdr2.read_byte_record(&mut row2)?;
}
while self.rdr2.read_byte_record(&mut row2)? {
self.wtr.write_record(row1.iter().chain(&row2))?;
output.clear();
output.extend(&row1);
output.extend(&row2);
self.wtr.write_record(&output)?;
}
}
Ok(self.wtr.flush()?)
Expand Down

0 comments on commit 94172f0

Please sign in to comment.