Skip to content

Commit

Permalink
Add BracketedSort a new, faster algorithm for partialsort and fri…
Browse files Browse the repository at this point in the history
…ends (#52006)
  • Loading branch information
LilithHafner authored Nov 23, 2023
1 parent 79a845c commit 187e8c2
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 1 deletion.
199 changes: 198 additions & 1 deletion base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,15 @@ issorted(itr;
issorted(itr, ord(lt,by,rev,order))

# Partially sort `v` in place under the ordering `o` for the position(s) named by `k`
# (a single index or an ordinal range of indices), then return `maybeview(v, k)`.
#
# Exactly one `_sort!` pass is performed. Running an additional
# `InitialOptimizations(ScratchQuickSort(k))` pass beforehand would already do the entire
# partial sort and make the BracketedSort pipeline below pure overhead.
function partialsort!(v::AbstractVector, k::Union{Integer,OrdinalRange}, o::Ordering)
    # TODO move k from `alg` to `kw`
    # Don't perform InitialOptimizations before Bracketing. The optimizations take O(n)
    # time and so does the whole sort. But do perform them before recursive calls because
    # that can cause significant speedups when the target range is large so the runtime is
    # dominated by k log k and the optimizations runs in O(k) time.
    _sort!(v, BoolOptimization(
        Small{12}( # Very small inputs should go straight to insertion sort
            BracketedSort(k))),
        o, (;))
    maybeview(v, k)
end

Expand Down Expand Up @@ -1138,6 +1146,195 @@ function _sort!(v::AbstractVector, a::ScratchQuickSort, o::Ordering, kw;
end


"""
    BracketedSort(target[, get_next]) <: Algorithm

Perform a partialsort for the elements that fall into the indices specified by the `target`
using BracketedSort with the `next` algorithm for subproblems. The `next` algorithm for a
given (sub)target is obtained by calling `get_next` on that target.

BracketedSort takes a random* sample of the input, estimates the quantiles of the input
using the quantiles of the sample to find signposts that almost certainly bracket the target
values, filters the values in the input that fall between the signpost values to the front
of the input, and then, if that "almost certainly" turned out to be true, finds the target
within the small chunk that is, by value, between the signposts and now, by position, at the
front of the vector. On small inputs or when target is close to the size of the input,
BracketedSort falls back to the `next` algorithm directly. Otherwise, BracketedSort uses the
`next` algorithm only to compute quantiles of the sample and to find the target within the
small chunk.

## Performance

If the `next` algorithm has `O(n * log(n))` runtime and the input is not pathological then
the runtime of this algorithm is `O(n + k * log(k))` where `n` is the length of the input
and `k` is `length(target)`. On pathological inputs the asymptotic runtime is the same as
the runtime of the `next` algorithm.

BracketedSort itself does not allocate. If `next` is in-place then BracketedSort is also
in-place. If `next` is not in place, and its space usage increases monotonically with input
length then BracketedSort's maximum space usage will never be more than the space usage
of `next` on the input BracketedSort receives. For large nonpathological inputs and targets
substantially smaller than the size of the input, BracketedSort's maximum memory usage will
be much less than `next`'s. If the maximum additional space usage of `next` scales linearly
then for small k the average* maximum additional space usage of BracketedSort will be
`O(n^(2.3/3))`.

By default, BracketedSort uses the `O(n)` space and `O(n + k log k)` runtime
`ScratchQuickSort` algorithm recursively.

*Sorting is unable to depend on Random.jl because Random.jl depends on sorting.
Consequently, we use `hash` as a source of randomness. The average runtime guarantees
assume that `hash(x::Int)` produces a random result. However, as this randomization is
deterministic, if you try hard enough you can find inputs that consistently reach the
worst case bounds. Actually constructing such inputs is an exercise left to the reader.
Have fun :).

Characteristics:
  * *unstable*: does not preserve the ordering of elements that compare equal
    (e.g. "a" and "A" in a sort of letters that ignores case).
  * *in-place* in memory if the `next` algorithm is in-place.
  * *estimate-and-filter* strategy.
  * *linear runtime* if `length(target)` is constant and `next` is reasonable.
  * *n + k log k* worst case runtime if `next` has that runtime.
  * *pathological inputs* can significantly increase constant factors.
"""
struct BracketedSort{T, F} <: Algorithm
# Index or range of indices whose final (sorted) contents are wanted.
target::T
# Callable mapping a (sub)target to the algorithm used for that subproblem; it is invoked
# as `get_next(target)` and the result is passed to `_sort!` as the algorithm.
get_next::F
end

# TODO: this composition between BracketedSort and ScratchQuickSort does not bring me joy
# Convenience constructor: when no `get_next` is supplied, every subproblem is handed to
# `InitialOptimizations(ScratchQuickSort(subtarget))`.
function BracketedSort(k)
    default_next(subtarget) = InitialOptimizations(ScratchQuickSort(subtarget))
    return BracketedSort(k, default_next)
end

# Filter the elements of `v[lo:hi]` that lie between the two signposts to the front of that
# range, in a single pass.
#
# An element `x` is kept when it is neither below `lo_signpost` (`lt(o, x, lo_signpost)`)
# nor above `hi_signpost` (`lt(o, hi_signpost, x)`). Passing `nothing` for a signpost
# disables the corresponding bound.
#
# Returns `(n_below, last_kept)`: the number of elements found below `lo_signpost`, and the
# index of the last kept element once the kept elements occupy `v[lo:last_kept]`.
function bracket_kernel!(v::AbstractVector, lo, hi, lo_signpost, hi_signpost, o)
    shift = 0    # minus the number of rejected elements seen so far
    n_below = 0
    checkbounds(v, lo:hi)  # justifies the `@inbounds` accesses in the loop
    for pos in lo:hi
        elt = @inbounds v[pos]
        is_below = lo_signpost !== nothing && lt(o, elt, lo_signpost)
        not_above = hi_signpost === nothing || !lt(o, hi_signpost, elt)
        n_below += is_below
        keep = is_below != not_above
        # Branch-free on purpose: the swap always executes, and degenerates to a self-swap
        # (dest == pos) for rejected elements. Per the original author's note this form
        # measured faster than guarding the swap with an explicit `if`.
        dest = shift * keep + pos
        # Invariant: firstindex(v) ≤ lo ≤ shift + pos ≤ dest ≤ pos ≤ hi ≤ lastindex(v)
        @inbounds v[pos], v[dest] = v[dest], v[pos]
        shift -= !keep
    end
    return n_below, shift + hi
end

# Relocate the elements of `v` stored at the indices `source` to the indices `target`,
# keeping their relative order. `target` and `source` must have equal length; when they
# overlap, `source` must not start after `target`.
function move!(v, target, source)
    # This function never dominates runtime—only add `@inbounds` if you can demonstrate a
    # performance improvement. And if you do, also double check behavior when `target`
    # is out of bounds.
    @assert length(target) == length(source)
    overlapping = length(target) != 1 && !isdisjoint(target, source)
    if overlapping
        @assert minimum(source) <= minimum(target)
        src_first = minimum(source)
        dst_first = minimum(target)
        dst_last = maximum(target)
        # Two reversals: the first carries the source block (reversed) onto the target
        # positions, the second restores its internal order.
        reverse!(v, src_first, dst_last)
        reverse!(v, dst_first, dst_last)
    else
        # A single position or non-overlapping ranges: plain pairwise swaps suffice.
        for (dst, src) in zip(target, source)
            v[dst], v[src] = v[src], v[dst]
        end
    end
end

# Partial sort of `v[lo:hi]` for the positions in `a.target` using the BracketedSort
# strategy:
#   1. move a pseudo-random sample of `k2 ≈ ln^(2/3)` elements to the front of the range,
#   2. (partially) sort the sample with the `next` algorithm to read off one or two
#      signpost values expected to bracket the target,
#   3. filter the elements between the signposts to the front with `bracket_kernel!`,
#   4. if the target really lies in that middle chunk, solve the subproblem with `next`
#      and `move!` the result to the target positions.
# Small inputs, inputs where bracketing is not expected to pay off, and inputs that defeat
# three sampling attempts are delegated wholesale to `a.get_next(a.target)`.
#
# `lo`, `hi` and `scratch` are read from `kw`. Returns `scratch` or whatever the delegated
# `_sort!` call returns.
function _sort!(v::AbstractVector, a::BracketedSort, o::Ordering, kw)
    @getkw lo hi scratch
    # TODO for further optimization: reuse scratch between trials better, from signpost
    # selection to recursive calls, and from the fallback (but be aware of type stability,
    # especially when sorting IEEE floats).

    # We don't need to bounds check target because that is done higher up in the stack
    # However, we cannot assume the target is inbounds.
    lo < hi || return scratch
    ln = hi - lo + 1

    # This is simply a precomputed short-circuit to avoid doing scalar math for small inputs.
    # It does not change dispatch at all.
    ln < 260 && return _sort!(v, a.get_next(a.target), o, kw)

    target = a.target
    k = cbrt(ln)
    k2 = round(Int, k^2)  # sample size
    k2ln = k2/ln          # ratio of sample length to input length
    # Safety margin, in sample positions, by which the signposts are pushed outward.
    # It must scale with k ≈ sqrt(k2), the order of the sampling error of a quantile
    # estimated from k2 samples, times a logarithmic factor. Scaling it with k2 itself
    # makes the margin at least ~0.9*k2 for every ln ≥ 260, so the signposts span nearly
    # the whole sample and the dispatch heuristic below always takes the fallback.
    offset = .15k*top_set_bit(k2) # TODO for further optimization: tune this
    # Positions within the sorted sample of the low and high signposts: the extremes of
    # the target rescaled to sample coordinates and widened by `offset`.
    lo_signpost_i, hi_signpost_i =
        (floor(Int, (tar - lo) * k2ln + lo + off) for (tar, off) in
            ((minimum(target), -offset), (maximum(target), offset)))
    lastindex_sample = lo+k2-1
    expected_middle_ln = (min(lastindex_sample, hi_signpost_i) - max(lo, lo_signpost_i) + 1) / k2ln
    # This heuristic is complicated because it fairly accurately reflects the runtime of
    # this algorithm which is necessary to get good dispatch when both the target is large
    # and the input are large.
    # expected_middle_ln is a float and k2 is significantly below typemax(Int), so this will
    # not overflow:
    # TODO move target from alg to kw to avoid this ickiness:
    ln <= 130 + 2k2 + 2expected_middle_ln && return _sort!(v, a.get_next(a.target), o, kw)

    # We store the random sample in
    #     sample = view(v, lo:lo+k2-1)
    # but views are not quite as fast as using the input array directly,
    # so we don't actually construct this view at runtime.

    # TODO for further optimization: handle lots of duplicates better.
    # Right now lots of duplicates rounds up when it could use some super fast optimizations
    # in some cases.
    # e.g.
    #
    # Target:                           |----|
    # Sorted input: 000000000000000000011111112222223333333333
    #
    # Will filter all zeros and ones to the front when it could just take the first few
    # it encounters. This optimization would be especially potent when `allequal(ans)` and
    # equal elements are egal.

    # 3 random trials should typically give us 0.99999 reliability; we can assume
    # the input is pathological and abort to fallback if we fail three trials.
    seed = hash(ln, Int === Int64 ? 0x85eb830e0216012d : 0xae6c4e15)
    for attempt in 1:3
        seed = hash(attempt, seed)
        # Partial shuffle: swap a hash-chosen element from i:hi into each sample slot i.
        for i in lo:lo+k2-1
            j = mod(hash(i, seed), i:hi) # TODO for further optimization: be sneaky and remove this division
            v[i], v[j] = v[j], v[i]
        end
        count_below, lastindex_middle = if lo_signpost_i <= lo && lastindex_sample <= hi_signpost_i
            # The heuristics higher up in this function that dispatch to the `next`
            # algorithm should prevent this from happening.
            # Specifically, this means that expected_middle_ln == ln, so
            #     ln <= ... + 2.0expected_middle_ln && return ...
            # will trigger.
            @assert false
            # But if it does happen, the kernel reduces to
            0, hi
        elseif lo_signpost_i <= lo
            # Only an upper signpost falls inside the sample.
            _sort!(v, a.get_next(hi_signpost_i), o, (;kw..., hi=lastindex_sample))
            bracket_kernel!(v, lo, hi, nothing, v[hi_signpost_i], o)
        elseif lastindex_sample <= hi_signpost_i
            # Only a lower signpost falls inside the sample.
            _sort!(v, a.get_next(lo_signpost_i), o, (;kw..., hi=lastindex_sample))
            bracket_kernel!(v, lo, hi, v[lo_signpost_i], nothing, o)
        else
            # TODO for further optimization: don't sort the middle elements
            _sort!(v, a.get_next(lo_signpost_i:hi_signpost_i), o, (;kw..., hi=lastindex_sample))
            bracket_kernel!(v, lo, hi, v[lo_signpost_i], v[hi_signpost_i], o)
        end
        # The middle chunk now sits at v[lo:lastindex_middle]; `count_below` elements that
        # sort before it were filtered out, so the target shifts left by that amount.
        target_in_middle = target .- count_below
        if lo <= minimum(target_in_middle) && maximum(target_in_middle) <= lastindex_middle
            scratch = _sort!(v, a.get_next(target_in_middle), o, (;kw..., hi=lastindex_middle))
            move!(v, target, target_in_middle)
            return scratch
        end
        # This line almost never runs.
    end
    # This line only runs on pathological inputs. Make sure it's covered by tests :)
    _sort!(v, a.get_next(target), o, kw)
end


"""
StableCheckSorted(next) <: Algorithm
Expand Down
37 changes: 37 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,8 @@ end
for alg in safe_algs
@test sort(1:n, alg=alg, lt = (i,j) -> v[i]<=v[j]) == perm
end
# This could easily break with minor heuristic adjustments
# because partialsort is not even guaranteed to be stable:
@test partialsort(1:n, 172, lt = (i,j) -> v[i]<=v[j]) == perm[172]
@test partialsort(1:n, 315:415, lt = (i,j) -> v[i]<=v[j]) == perm[315:415]

Expand Down Expand Up @@ -1034,6 +1036,41 @@ end
@test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward))
end

# Checks `partialsort` against a full `sort` on a 1000-element vector for a spread of
# single-index and range targets, then on inputs built to stress the sampling strategy.
@testset "partialsort tests added for BracketedSort #52006" begin
x = rand(Int, 1000)
@test partialsort(x, 1) == minimum(x)
@test partialsort(x, 1000) == maximum(x)
sx = sort(x)
# Single-index targets, including both ends of the vector.
for i in [1, 2, 4, 10, 11, 425, 500, 845, 991, 997, 999, 1000]
@test partialsort(x, i) == sx[i]
end
# Range targets of varying width and position.
for i in [1:1, 1:2, 1:5, 1:8, 1:9, 1:11, 1:108, 135:812, 220:586, 363:368, 450:574, 458:597, 469:638, 487:488, 500:501, 584:594, 1000:1000]
@test partialsort(x, i) == sx[i]
end

# Semi-pathological input
# The seed and index computation below mirror what BracketedSort's `_sort!` does on its
# first sampling attempt for a length-1000 input. Planting `typemax(Int)` at those
# positions is presumably meant to make that first sample unrepresentative.
seed = hash(1000, Int === Int64 ? 0x85eb830e0216012d : 0xae6c4e15)
seed = hash(1, seed)
for i in 1:100
j = mod(hash(i, seed), i:1000)
x[j] = typemax(Int)
end
@test partialsort(x, 500) == sort(x)[500]

# Fully pathological input
# it would be too much trouble to actually construct a valid pathological input, so we
# construct an invalid pathological input.
# This test is kind of sketchy because it passes invalid inputs to the function
# (`lt=(<=)` is not a strict ordering, and `x` holds only the values 1:5).
for i in [1:6, 1:483, 1:957, 77:86, 118:478, 223:227, 231:970, 317:958, 500:501, 500:501, 500:501, 614:620, 632:635, 658:665, 933:940, 937:942, 997:1000, 999:1000]
# A fresh random vector is drawn for every target.
x = rand(1:5, 1000)
@test partialsort(x, i, lt=(<=)) == sort(x)[i]
end
for i in [1, 7, 8, 490, 495, 852, 993, 996, 1000]
x = rand(1:5, 1000)
@test partialsort(x, i, lt=(<=)) == sort(x)[i]
end
end

# This testset is at the end of the file because it is slow.
@testset "searchsorted" begin
numTypes = [ Int8, Int16, Int32, Int64, Int128,
Expand Down

0 comments on commit 187e8c2

Please sign in to comment.