Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make union*, difference*, and intersect* run linearly by sorting elements first #203

Closed
wants to merge 8 commits into from
192 changes: 181 additions & 11 deletions src/Data/Array.purs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ module Data.Array
, deleteBy

, (\\), difference
, differenceBy
, intersect
, intersectBy

Expand All @@ -135,12 +136,16 @@ import Control.Alternative (class Alternative)
import Control.Lazy (class Lazy, defer)
import Control.Monad.Rec.Class (class MonadRec, Step(..), tailRecM2)
import Control.Monad.ST as ST
import Control.Monad.ST.Ref as STRef
import Data.Array.NonEmpty.Internal (NonEmptyArray(..))
import Data.Array.ST as STA
import Data.Array.ST.Iterator as STAI
import Data.Foldable (class Foldable, foldl, foldr, traverse_)
import Data.Foldable (foldl, foldr, foldMap, fold, intercalate) as Exports
import Data.Function (on)
import Data.Maybe (Maybe(..), maybe, isJust, fromJust, isNothing)
import Data.Ordering (invert)
import Data.Ord (abs)
import Data.Traversable (sequence, traverse)
import Data.Tuple (Tuple(..), fst, snd)
import Data.Unfoldable (class Unfoldable, unfoldr)
Expand Down Expand Up @@ -1071,8 +1076,8 @@ nubByEq eq xs = ST.run do
-- | union [1, 2, 1, 1] [3, 3, 3, 4] = [1, 2, 1, 1, 3, 4]
-- | ```
-- |
union :: forall a. Eq a => Array a -> Array a -> Array a
union = unionBy (==)
union :: forall a. Ord a => Array a -> Array a -> Array a
union = unionBy compare

-- | Calculate the union of two arrays, using the specified function to
-- | determine equality of elements. Note that duplicates in the first array
Expand All @@ -1083,8 +1088,33 @@ union = unionBy (==)
-- | unionBy mod3eq [1, 5, 1, 2] [3, 4, 3, 3] = [1, 5, 1, 2, 3]
-- | ```
-- |
unionBy :: forall a. (a -> a -> Boolean) -> Array a -> Array a -> Array a
unionBy eq xs ys = xs <> foldl (flip (deleteBy eq)) (nubByEq eq ys) xs
unionBy :: forall a. (a -> a -> Ordering) -> Array a -> Array a -> Array a
unionBy cmp left right = map snd $ ST.run do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be more efficient to take advantage of the fact that the output is always the first array plus some other stuff appended to it. Here's some non-ST pseudo-ish code describing that option:

union left right = left <> other where
  other =
    combineIndex right left -- intentionally putting right first and assuming stable sort
    # sortBy valueFunc
    # groupBy valueFunc
    # map head
    # filter keepRightOnly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would be the runtime cost of this approach versus mine?

I made the rather stupid assumption that mine would be n+m because when Harry said the current version was n*m, I thought he was implying we should make the code linear. I didn't actually analyze my code using Big O notation to see how many steps it takes.

Copy link
Contributor

@milesfrain milesfrain Jan 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe all reasonable approaches are O(n log n) (where n is the sum of array sizes). But benchmarking could still reveal a 3x speedup with a different approach, which would still be O(n log n). I'll play around with this and report back with findings. The choice of input data could also significantly change the relative performance of different algorithms.

result <- STA.new
ST.foreach indexedAndSorted \(Tuple fromLeftArray pair@(Tuple _ x')) -> do
if fromLeftArray then do
void $ STA.push pair result
else do
maybePreviousValue <- last <$> STA.unsafeFreeze result
case maybePreviousValue of
Just (Tuple _ y)
| cmp y x' /= EQ -> void $ STA.push pair result
| otherwise -> pure unit
Nothing -> do
void $ STA.push pair result
_ <- STA.sortWith fst result
STA.unsafeFreeze result
where
-- Note: when elements are equal, left array elements
-- (i.e. `Tuple true (Tuple _ _)`) appear "before" right array elements
-- (i.e. `Tuple false (Tuple _ _)`) in the resulting array.
-- This is different from `differenceBy` and `intersectBy`
valueThenLeftFirst :: Tuple Boolean (Tuple Int a) -> Tuple Boolean (Tuple Int a) -> Ordering
valueThenLeftFirst (Tuple lb (Tuple _ lv)) (Tuple rb (Tuple _ rv)) =
cmp lv rv <> invert (compare lb rb)

indexedAndSorted = sortBy valueThenLeftFirst $ combineIndex left right


-- | Delete the first element of an array which is equal to the specified value,
-- | creating a new array.
Expand Down Expand Up @@ -1118,13 +1148,129 @@ deleteBy eq x ys = maybe ys (\i -> unsafePartial $ fromJust (deleteAt i ys)) (fi
-- | difference [2, 1] [2, 3] = [1]
-- | ```
-- |
-- | Running time: `O(n*m)`, where n is the length of the first array, and m is
-- | Running time: `O((n+m) log (n+m))` (dominated by sorting), where n is the length of the first array, and m is
-- | the length of the second.
difference :: forall a. Eq a => Array a -> Array a -> Array a
difference = foldr delete
difference :: forall a. Ord a => Array a -> Array a -> Array a
difference = differenceBy compare

infix 5 difference as \\

-- | Calculate the difference of two arrays, using the specified ordering
-- | function to decide which elements are equal. For every element of the
-- | second array, the leftmost not-yet-removed element of the first array
-- | that compares `EQ` to it is dropped; the surviving elements keep their
-- | original relative order.
-- |
-- | ```purescript
-- | differenceBy compare [1, 2, 3, 4, 3, 2, 1] [1, 1, 2, 3] = [4, 3, 2]
-- | differenceBy compare [1, 1, 1] [1, 1] = [1]
-- | ```
differenceBy :: forall a. (a -> a -> Ordering) -> Array a -> Array a -> Array a
differenceBy _ left [] = left
differenceBy _ left@[] _ = left
differenceBy cmp left right = ST.run do
  indices <- STA.new
  -- `Just (Tuple v n)` means: `n` right-array elements equal to `v` have been
  -- seen and are still waiting to cancel out a left-array element.
  latestRightArrayValue <- STRef.new Nothing
  ST.foreach indexedAndSorted \idx -> do
    maybeValueToRemove <- STRef.read latestRightArrayValue
    let
      fromLeftArray = idx >= 0
      x' = safeishIndex idx
    case fromLeftArray, maybeValueToRemove of
      true, Just (Tuple valueToRemove count) | cmp valueToRemove x' == EQ ->
        -- This left element is cancelled by a pending right element:
        -- do not keep it, and use up one pending removal.
        if count <= 1 then
          void $ STRef.write Nothing latestRightArrayValue
        else
          void $ STRef.write (Just (Tuple valueToRemove (count - 1))) latestRightArrayValue

      -- No pending removal for this value: keep the left element.
      true, _ -> void $ STA.push idx indices

      -- Another right element with the same value: one MORE left element
      -- has to be removed, so the pending count goes up, not down.
      false, Just (Tuple valueToRemove count) | cmp valueToRemove x' == EQ ->
        void $ STRef.write (Just (Tuple valueToRemove (count + 1))) latestRightArrayValue

      -- First right element of a new run of equal values.
      false, _ ->
        void $ STRef.write (Just (Tuple x' 1)) latestRightArrayValue
  -- Only left-array indices were pushed; sorting them restores the original
  -- order of the surviving elements.
  _ <- STA.sortBy (compare `on` abs) indices
  map safeishIndex <$> STA.unsafeFreeze indices

  where
  leftLen :: Int
  leftLen = length left

  rightLen :: Int
  rightLen = length right

  -- Decode an index produced by `combineIndex'`: non-negative indices point
  -- into `left`; negative ones encode `negate (leftLen + i)` for position `i`
  -- of `right`. Only called with indices produced by `combineIndex'`.
  safeishIndex :: Int -> a
  safeishIndex i
    | i < 0 = unsafePartial (unsafeIndex right ((abs i) - leftLen))
    | otherwise = unsafePartial (unsafeIndex left i)

  -- Order by value first. Among equal values the (negative) right-array
  -- indices sort before the left-array ones, and left-array indices stay in
  -- ascending order, so the leftmost occurrences are the ones removed.
  valueThenRightFirst :: Int -> Int -> Ordering
  valueThenRightFirst li ri =
    ((cmp `on` safeishIndex) li ri) <> compare li ri

  indexedAndSorted =
    sortBy valueThenRightFirst $ combineIndex' leftLen left rightLen right

-- Internal use only
-- Builds one array of "tagged" indices covering both inputs. Essentially...
-- ```
-- combineIndex' leftLen left _ right =
--   let
--     leftIndices = mapWithIndex (\idx _ -> idx) left
--
--     adjustedIdx idx _ = negate (idx + leftLen)
--     rightIndices = mapWithIndex adjustedIdx right
--
--   in leftIndices <> rightIndices
-- ```
-- ... but without the two intermediate arrays that `mapWithIndex` would
-- create. Only the two lengths are actually used; the array arguments are
-- ignored.
--
-- Left-array positions are emitted unchanged (`0 .. leftLen - 1`). Position
-- `i` of the right array is emitted as `negate (leftLen + i)`, so it can be
-- recovered with `abs encoded - leftLen`. The sign distinguishes the two
-- arrays only when `leftLen > 0` (otherwise right position 0 encodes to `0`),
-- so callers must deal with an empty left array themselves.
-- ```
-- combineIndex' 1 [7] 2 [9, 5]
--   == [0, (-1), (-2)]
-- ```
combineIndex' :: forall a. Int -> Array a -> Int -> Array a -> Array Int
combineIndex' leftLen _ rightLen _ = ST.run do
  encoded <- STA.new
  ST.for 0 leftLen \i ->
    void $ STA.push i encoded
  ST.for 0 rightLen \j ->
    void $ STA.push (negate (leftLen + j)) encoded
  STA.unsafeFreeze encoded

-- Internal use only
-- Essentially...
-- ```
-- combineWithIndex cmp left right =
-- let
-- t3 bool idx val = Tuple bool (Tuple idx val)
-- leftIndexed = mapWithIndex (\idx v -> t3 true idx v) left
--
-- tupleAdjusted idx v = t3 false (idx + (length left)) v
-- rightIndexedPlus = mapWithIndex tupleAdjusted right
--
-- in leftIndexed <> rightIndexedPlus
-- ```
-- ... but without creating two intermediate arrays due to the `mapWithIndex`
-- on both arrays.
-- ```
-- combineIndex compare [7] [9, 5]
-- == [t3 true 0 7, t3 false 1 9, t3 false 2 5]
-- ```
combineIndex :: forall a. Array a -> Array a -> Array (Tuple Boolean (Tuple Int a))
combineIndex left right = ST.run do
out <- STA.new
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any benefits to preallocating the array to a known size? Tried to find a clear answer for this in JS. And if there is a benefit that we want to take advantage of, we'd need to expand the Array.ST API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've wondered about that myself...

let leftLen = length left
ST.for 0 leftLen \idx -> do
let val = unsafePartial $ unsafeIndex left idx
void $ STA.push (Tuple true (Tuple idx val)) out
let rightLen = length right
ST.for 0 rightLen \idx -> do
let val = unsafePartial $ unsafeIndex right idx
void $ STA.push (Tuple false (Tuple (leftLen + idx) val)) out
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any performance penalties to using a record instead of a tuple? Luckily, this tuple happens to be type-safe because of the distinct types (Boolean, Int, a).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure. I also tried it out with data Tuple3 a b c = Tuple3 a b c and that sped things up only slightly.

STA.unsafeFreeze out

-- | Calculate the intersection of two arrays, creating a new array. Note that
-- | duplicates in the first array are preserved while duplicates in the second
-- | array are removed.
Expand All @@ -1133,8 +1279,8 @@ infix 5 difference as \\
-- | intersect [1, 1, 2] [2, 2, 1] = [1, 1, 2]
-- | ```
-- |
intersect :: forall a. Eq a => Array a -> Array a -> Array a
intersect = intersectBy eq
intersect :: forall a. Ord a => Array a -> Array a -> Array a
intersect = intersectBy compare

-- | Calculate the intersection of two arrays, using the specified equivalence
-- | relation to compare elements, creating a new array. Note that duplicates
Expand All @@ -1146,8 +1292,32 @@ intersect = intersectBy eq
-- | intersectBy mod3eq [1, 2, 3] [4, 6, 7] = [1, 3]
-- | ```
-- |
intersectBy :: forall a. (a -> a -> Boolean) -> Array a -> Array a -> Array a
intersectBy eq xs ys = filter (\x -> isJust (findIndex (eq x) ys)) xs
intersectBy :: forall a. (a -> a -> Ordering) -> Array a -> Array a -> Array a
-- Nothing can be common with an empty array, so the intersection is empty
-- (not the left array).
intersectBy _ _ [] = []
intersectBy _ left@[] _ = left
intersectBy cmp left right = map snd $ ST.run do
  result <- STA.new
  -- The most recently visited right-array value. Because right elements sort
  -- before equal left elements, a left element belongs to the intersection
  -- exactly when it compares `EQ` to this value.
  latestRightArrayValue <- STRef.new Nothing
  ST.foreach indexedAndSorted \(Tuple fromLeftArray pair@(Tuple _ x')) ->
    if fromLeftArray then do
      maybeValueToKeep <- STRef.read latestRightArrayValue
      case maybeValueToKeep of
        Just valueToKeep | cmp valueToKeep x' == EQ ->
          void $ STA.push pair result
        _ -> pure unit
    else
      void $ STRef.write (Just x') latestRightArrayValue
  -- Restore the original order of the kept left elements; the index is
  -- stripped afterwards by `map snd`.
  _ <- STA.sortWith fst result
  STA.unsafeFreeze result
  where
  -- Note: when elements are equal, right array elements
  -- (i.e. `Tuple false (Tuple _ _)`) appear "before" left array elements
  -- (i.e. `Tuple true (Tuple _ _)`) in the resulting array.
  valueThenRightFirst :: Tuple Boolean (Tuple Int a) -> Tuple Boolean (Tuple Int a) -> Ordering
  valueThenRightFirst (Tuple lb (Tuple _ lv)) (Tuple rb (Tuple _ rv)) =
    cmp lv rv <> compare lb rb

  indexedAndSorted = sortBy valueThenRightFirst $ combineIndex left right

-- | Apply a function to pairs of elements at the same index in two arrays,
-- | collecting the results in a new array.
Expand Down
40 changes: 24 additions & 16 deletions src/Data/Array/NonEmpty.purs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ module Data.Array.NonEmpty

, (\\), difference
, difference'
, differenceBy
, differenceBy'
, intersect
, intersect'
, intersectBy
Expand Down Expand Up @@ -408,57 +410,63 @@ nubBy f = unsafeAdapt $ A.nubBy f
nubByEq :: forall a. (a -> a -> Boolean) -> NonEmptyArray a -> NonEmptyArray a
nubByEq f = unsafeAdapt $ A.nubByEq f

union :: forall a. Eq a => NonEmptyArray a -> NonEmptyArray a -> NonEmptyArray a
union = unionBy (==)
union :: forall a. Ord a => NonEmptyArray a -> NonEmptyArray a -> NonEmptyArray a
union = unionBy compare

union' :: forall a. Eq a => NonEmptyArray a -> Array a -> NonEmptyArray a
union' = unionBy' (==)
union' :: forall a. Ord a => NonEmptyArray a -> Array a -> NonEmptyArray a
union' = unionBy' compare

unionBy
:: forall a
. (a -> a -> Boolean)
. (a -> a -> Ordering)
-> NonEmptyArray a
-> NonEmptyArray a
-> NonEmptyArray a
unionBy eq xs = unionBy' eq xs <<< toArray
unionBy cmp xs = unionBy' cmp xs <<< toArray

unionBy'
:: forall a
. (a -> a -> Boolean)
. (a -> a -> Ordering)
-> NonEmptyArray a
-> Array a
-> NonEmptyArray a
unionBy' eq xs = unsafeFromArray <<< A.unionBy eq (toArray xs)
unionBy' cmp xs = unsafeFromArray <<< A.unionBy cmp (toArray xs)

delete :: forall a. Eq a => a -> NonEmptyArray a -> Array a
delete x = adaptAny $ A.delete x

deleteBy :: forall a. (a -> a -> Boolean) -> a -> NonEmptyArray a -> Array a
deleteBy f x = adaptAny $ A.deleteBy f x

difference :: forall a. Eq a => NonEmptyArray a -> NonEmptyArray a -> Array a
difference :: forall a. Ord a => NonEmptyArray a -> NonEmptyArray a -> Array a
difference xs = adaptAny $ difference' xs

difference' :: forall a. Eq a => NonEmptyArray a -> Array a -> Array a
difference' :: forall a. Ord a => NonEmptyArray a -> Array a -> Array a
difference' xs = A.difference $ toArray xs

intersect :: forall a . Eq a => NonEmptyArray a -> NonEmptyArray a -> Array a
intersect = intersectBy eq
-- | Difference of two non-empty arrays, using the given comparison function.
-- | Delegates to `differenceBy'` through `adaptAny`; the result is a plain
-- | `Array`, per the signature.
differenceBy :: forall a. (a -> a -> Ordering) -> NonEmptyArray a -> NonEmptyArray a -> Array a
differenceBy cmp xs = adaptAny $ differenceBy' cmp xs

intersect' :: forall a . Eq a => NonEmptyArray a -> Array a -> Array a
intersect' = intersectBy' eq
-- | Difference between a non-empty array and a plain array, using the given
-- | comparison function. Converts the first argument with `toArray` and
-- | delegates to `A.differenceBy`.
differenceBy' :: forall a. (a -> a -> Ordering) -> NonEmptyArray a -> Array a -> Array a
differenceBy' cmp xs = A.differenceBy cmp $ toArray xs

intersect :: forall a . Ord a => NonEmptyArray a -> NonEmptyArray a -> Array a
intersect = intersectBy compare

intersect' :: forall a . Ord a => NonEmptyArray a -> Array a -> Array a
intersect' = intersectBy' compare

intersectBy
:: forall a
. (a -> a -> Boolean)
. (a -> a -> Ordering)
-> NonEmptyArray a
-> NonEmptyArray a
-> Array a
intersectBy eq xs = intersectBy' eq xs <<< toArray

intersectBy'
:: forall a
. (a -> a -> Boolean)
. (a -> a -> Ordering)
-> NonEmptyArray a
-> Array a
-> Array a
Expand Down
7 changes: 5 additions & 2 deletions test/Test/Data/Array.purs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ testArray = do
assert $ A.union [1, 1, 2, 3] [2, 3, 4] == [1, 1, 2, 3, 4]

log "unionBy should produce the union of two arrays using the specified equality relation"
assert $ A.unionBy (\_ y -> y < 5) [1, 2, 3] [2, 3, 4, 5, 6] == [1, 2, 3, 5, 6]
assert $ A.unionBy (\x y -> compare (x `mod` 4) (y `mod` 4))
[1, 2, 3] [2, 3, 4, 5, 6] == [1, 2, 3, 4]

log "delete should remove the first matching item from an array"
assert $ A.delete 1 [1, 2, 1] == [2, 1]
Expand All @@ -413,12 +414,14 @@ testArray = do

log "(\\\\) should return the difference between two lists"
assert $ [1, 2, 3, 4, 3, 2, 1] \\ [1, 1, 2, 3] == [4, 3, 2]
assert $ [1, 2, 3, 4, 3, 2, 1] \\ [1, 1, 1, 1, 1, 2, 3] == [4, 3, 2]
assert $ [1, 4, 6, 8, 6, 4, 1] \\ [1, 4, 4, 1, 6] == [8, 6]

log "intersect should return the intersection of two arrays"
assert $ A.intersect [1, 2, 3, 4, 3, 2, 1] [1, 1, 2, 3] == [1, 2, 3, 3, 2, 1]

log "intersectBy should return the intersection of two arrays using the specified equivalence relation"
assert $ A.intersectBy (\x y -> (x * 2) == y) [1, 2, 3] [2, 6] == [1, 3]
assert $ A.intersectBy (\x y -> compare (x `mod` 3) (y `mod` 3)) [1, 2, 3] [2, 6] == [2, 3]

log "zipWith should use the specified function to zip two lists together"
assert $ A.zipWith (\x y -> [show x, y]) [1, 2, 3] ["a", "b", "c"] == [["1", "a"], ["2", "b"], ["3", "c"]]
Expand Down
5 changes: 3 additions & 2 deletions test/Test/Data/Array/NonEmpty.purs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ testNonEmptyArray = do
assert $ NEA.union (fromArray [1, 1, 2, 3]) (fromArray [2, 3, 4]) == fromArray [1, 1, 2, 3, 4]

log "unionBy should produce the union of two arrays using the specified equality relation"
assert $ NEA.unionBy (\_ y -> y < 5) (fromArray [1, 2, 3]) (fromArray [2, 3, 4, 5, 6]) == fromArray [1, 2, 3, 5, 6]
assert $ NEA.unionBy (\x y -> compare (x `mod` 4) (y `mod` 4))
(fromArray [1, 2, 3]) (fromArray [2, 3, 4, 5, 6]) == fromArray [1, 2, 3, 4]

log "delete should remove the first matching item from an array"
assert $ NEA.delete 1 (fromArray [1, 2, 1]) == [2, 1]
Expand All @@ -295,7 +296,7 @@ testNonEmptyArray = do
assert $ NEA.intersect (fromArray [1, 2, 3, 4, 3, 2, 1]) (fromArray [1, 1, 2, 3]) == [1, 2, 3, 3, 2, 1]

log "intersectBy should return the intersection of two arrays using the specified equivalence relation"
assert $ NEA.intersectBy (\x y -> (x * 2) == y) (fromArray [1, 2, 3]) (fromArray [2, 6]) == [1, 3]
assert $ NEA.intersectBy (\x y -> compare (x `mod` 3) (y `mod` 3)) (fromArray [1, 2, 3]) (fromArray [2, 6]) == [2, 3]

log "zipWith should use the specified function to zip two arrays together"
assert $ NEA.zipWith (\x y -> [show x, y]) (fromArray [1, 2, 3]) (fromArray ["a", "b", "c"]) == fromArray [["1", "a"], ["2", "b"], ["3", "c"]]
Expand Down