Skip to content

Commit

Permalink
Apply suggestions by Linus Sellberg
Browse files Browse the repository at this point in the history
Simpler PointerPairingHeap: let the Timers object deal with knowing
whether we added to HEAD or actually deleted a timer, etc.

Fix potential stack-overflow when merging pairs after inserting _lots_
of timers (over 1 million fills a 8KB thread stack).
  • Loading branch information
ysbaddaden committed Nov 21, 2024
1 parent de224a7 commit 4a63e07
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@ private struct Node
def heap_compare(other : Pointer(self)) : Bool
key < other.value.key
end

def inspect(io : IO, indent = 0) : Nil
prv = @heap_previous
nxt = @heap_next
chd = @heap_child

indent.times { io << ' ' }
io << "Node value=" << key
io << " prv=" << prv.try(&.value.key)
io << " nxt=" << nxt.try(&.value.key)
io << " chd=" << chd.try(&.value.key)
io.puts

node = heap_child?
while node
node.value.inspect(io, indent + 2)
node = node.value.heap_next?
end
end
end

describe Crystal::PointerPairingHeap do
Expand Down Expand Up @@ -52,36 +71,30 @@ describe Crystal::PointerPairingHeap do

# removes in ascending order
10.times do |i|
heap.shift?.should eq(nodes.to_unsafe + i)
node = heap.shift?
node.should eq(nodes.to_unsafe + i)
end
end

it "#delete" do
heap = Crystal::PointerPairingHeap(Node).new
nodes = StaticArray(Node, 10).new { |i| Node.new(i) }

# noop: empty
heap.delete(nodes.to_unsafe + 0).should eq({false, false})

# insert in random order
(0..9).to_a.shuffle.each do |i|
heap.add nodes.to_unsafe + i
end

# noop: unknown node
node11 = Node.new(11)
heap.delete(pointerof(node11)).should eq({false, false})

# remove some values
heap.delete(nodes.to_unsafe + 3).should eq({true, false})
heap.delete(nodes.to_unsafe + 7).should eq({true, false})
heap.delete(nodes.to_unsafe + 1).should eq({true, false})
heap.delete(nodes.to_unsafe + 3)
heap.delete(nodes.to_unsafe + 7)
heap.delete(nodes.to_unsafe + 1)

# remove tail
heap.delete(nodes.to_unsafe + 9).should eq({true, false})
heap.delete(nodes.to_unsafe + 9)

# remove head
heap.delete(nodes.to_unsafe + 0).should eq({true, true})
heap.delete(nodes.to_unsafe + 0)

# repeatedly delete min
[2, 4, 5, 6, 8].each do |i|
Expand All @@ -106,7 +119,7 @@ describe Crystal::PointerPairingHeap do
heap.shift?.should be_nil
end

it "randomly adds while we shift nodes" do
it "randomly shift while we add nodes" do
heap = Crystal::PointerPairingHeap(Node).new

nodes = uninitialized StaticArray(Node, 1000)
Expand Down
87 changes: 42 additions & 45 deletions src/crystal/pointer_pairing_heap.cr
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Crystal::PointerPairingHeap(T)
abstract def heap_compare(other : Pointer(self)) : Bool
end

@head : T* | Nil
@head : Pointer(T)?

private def head=(head)
@head = head
Expand All @@ -29,44 +29,26 @@ class Crystal::PointerPairingHeap(T)
@head.nil?
end

def first? : T* | Nil
def first? : Pointer(T)?
@head
end

def shift? : T* | Nil
def shift? : Pointer(T)?
if node = @head
self.head = merge_pairs(node.value.heap_child?)
node.value.heap_child = nil
node
end
end

def add(node : T*) : Bool
def add(node : Pointer(T)) : Nil
if node.value.heap_previous? || node.value.heap_next? || node.value.heap_child?
raise ArgumentError.new("The node is already in a Pairing Heap tree")
end
self.head = meld(@head, node)
node == @head
end

def delete(node : T*) : {Bool, Bool}
if node == @head
self.head = merge_pairs(node.value.heap_child?)
node.value.heap_child = nil
return {true, true}
end

if remove?(node)
subtree = merge_pairs(node.value.heap_child?)
self.head = meld(@head, subtree)
unlink(node)
return {true, false}
end

{false, false}
end

private def remove?(node)
def delete(node : Pointer(T)) : Nil
if previous_node = node.value.heap_previous?
next_sibling = node.value.heap_next?

Expand All @@ -80,9 +62,13 @@ class Crystal::PointerPairingHeap(T)
next_sibling.value.heap_previous = previous_node
end

true
subtree = merge_pairs(node.value.heap_child?)
clear(node)
self.head = meld(@head, subtree)
else
false
# removing head
self.head = merge_pairs(node.value.heap_child?)
node.value.heap_child = nil
end
end

Expand All @@ -99,29 +85,29 @@ class Crystal::PointerPairingHeap(T)
clear_recursive(child)
child = child.value.heap_next?
end
unlink(node)
clear(node)
end

private def meld(a : T*, b : T*) : T*
private def meld(a : Pointer(T), b : Pointer(T)) : Pointer(T)
if a.value.heap_compare(b)
add_child(a, b)
else
add_child(b, a)
end
end

private def meld(a : T*, b : Nil) : T*
private def meld(a : Pointer(T), b : Nil) : Pointer(T)
a
end

private def meld(a : Nil, b : T*) : T*
private def meld(a : Nil, b : Pointer(T)) : Pointer(T)
b
end

private def meld(a : Nil, b : Nil) : Nil
end

private def add_child(parent : T*, node : T*) : T*
private def add_child(parent : Pointer(T), node : Pointer(T)) : Pointer(T)
first_child = parent.value.heap_child?
parent.value.heap_child = node

Expand All @@ -132,28 +118,39 @@ class Crystal::PointerPairingHeap(T)
parent
end

# Twopass merge of the children of *node* into pairs of two.
private def merge_pairs(a : T*) : T* | Nil
a.value.heap_previous = nil
private def merge_pairs(node : Pointer(T)?) : Pointer(T)?
return unless node

if b = a.value.heap_next?
a.value.heap_next = nil
b.value.heap_previous = nil
else
return a
# 1st pass: meld children into pairs (left to right)
tail = nil

while a = node
if b = a.value.heap_next?
node = b.value.heap_next?
root = meld(a, b)
root.value.heap_previous = tail
tail = root
else
a.value.heap_previous = tail
tail = a
break
end
end

rest = merge_pairs(b.value.heap_next?)
b.value.heap_next = nil
# 2nd pass: meld the pairs back into a single tree (right to left)
root = nil

pair = meld(a, b)
meld(pair, rest)
end
while tail
node = tail.value.heap_previous?
root = meld(root, tail)
tail = node
end

private def merge_pairs(node : Nil) : Nil
root.value.heap_next = nil if root
root
end

private def unlink(node) : Nil
private def clear(node) : Nil
node.value.heap_previous = nil
node.value.heap_next = nil
node.value.heap_child = nil
Expand Down
11 changes: 10 additions & 1 deletion src/crystal/system/unix/evented/timers.cr
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,21 @@ struct Crystal::Evented::Timers
# Add a new timer into the list. Returns true if it is the next ready timer.
def add(event : Evented::Event*) : Bool
@heap.add(event)
@heap.first? == event
end

# Remove a timer from the list. Returns a tuple(dequeued, was_next_ready) of
# booleans. The first bool tells whether the event was dequeued, in which case
# the second one tells if it was the next ready event.
def delete(event : Evented::Event*) : {Bool, Bool}
@heap.delete(event)
if @heap.first? == event
@heap.shift?
{true, true}
elsif event.value.heap_previous?
@heap.delete(event)
{true, false}
else
{false, false}
end
end
end

0 comments on commit 4a63e07

Please sign in to comment.