From 805f3be8e1f28135b015ddebbe6c8ef3a8c53d13 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 28 Apr 2024 08:18:04 +0200 Subject: [PATCH] Add a sort function. (#2134) --- candle-core/src/sort.rs | 17 +++++++++++++++++ candle-core/tests/tensor_tests.rs | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index bcd098e3d1..6bfa3ca7b1 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -219,4 +219,21 @@ impl Tensor { // No need for a backward pass for arg sort. self.apply_op1_no_bwd(&ArgSort { asc, last_dim }) } + + /// Sorts the tensor along the last dimension, returns the sorted tensor together with the + /// sorted indexes. + /// + /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in + /// descending order. The sort is unstable so there is no guarantees on the final order when it + /// comes to ties. + pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> { + if !self.is_contiguous() { + return Err(crate::Error::RequiresContiguous { + op: "sort_last_dim", + }); + } + let asort = self.arg_sort_last_dim(asc)?; + let sorted = self.gather(&asort, crate::D::Minus1)?; + Ok((sorted, asort)) + } } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 4971f3372b..e57e5a30ad 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -109,6 +109,24 @@ fn asort(device: &Device) -> Result<()> { indexes.to_vec2::()?, [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]], ); + let (sorted, indexes) = tensor.sort_last_dim(true)?; + assert_eq!( + indexes.to_vec2::()?, + [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]], + ); + assert_eq!( + sorted.to_vec2::()?, + [[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]] + ); + let (sorted, indexes) = tensor.sort_last_dim(false)?; + assert_eq!( + indexes.to_vec2::()?, + [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]], + ); + assert_eq!( + sorted.to_vec2::()?, + [[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]] + ); Ok(()) }