diff --git a/containers-tests/containers-tests.cabal b/containers-tests/containers-tests.cabal index 3dfa65802..dacc04084 100644 --- a/containers-tests/containers-tests.cabal +++ b/containers-tests/containers-tests.cabal @@ -38,6 +38,7 @@ common deps , base >=4.9.1 && <5 , deepseq >=1.2 && <1.5 , template-haskell + , vector common test-deps import: deps diff --git a/containers/containers.cabal b/containers/containers.cabal index 30ac269e6..0223a7f44 100644 --- a/containers/containers.cabal +++ b/containers/containers.cabal @@ -33,7 +33,7 @@ source-repository head Library default-language: Haskell2010 - build-depends: base >= 4.9.1 && < 5, array >= 0.4.0.0, deepseq >= 1.2 && < 1.5, template-haskell + build-depends: base >= 4.9.1 && < 5, array >= 0.4.0.0, deepseq >= 1.2 && < 1.5, template-haskell, vector hs-source-dirs: src ghc-options: -O2 -Wall -fwarn-incomplete-uni-patterns -fwarn-incomplete-record-updates diff --git a/containers/src/Data/Set/Internal.hs b/containers/src/Data/Set/Internal.hs index d5d7e29d9..6163cac57 100644 --- a/containers/src/Data/Set/Internal.hs +++ b/containers/src/Data/Set/Internal.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE PatternGuards #-} +{-# LANGUAGE ScopedTypeVariables #-} #if !defined(TESTING) && defined(__GLASGOW_HASKELL__) {-# LANGUAGE Trustworthy #-} #endif @@ -247,6 +248,12 @@ import Data.Functor.Identity (Identity) import qualified Data.Foldable as Foldable import Control.DeepSeq (NFData(rnf)) +import qualified Data.Vector as V +import qualified Data.Vector.Unboxed as VU + +-- import Data.Bits ((.&.),(.|.),xor,countTrailingZeros,popCount,complement, bit) +import Data.Bits ((.&.),(.|.),xor,countTrailingZeros,popCount) + import Utils.Containers.Internal.StrictPair import Utils.Containers.Internal.PtrEquality @@ -1824,6 +1831,90 @@ splitRoot orig = -- -- @since 0.5.11 +powerSet :: forall a . Set a -> Set (Set a) +powerSet xs = + let !w = length xs + !u = V.fromListN w $ toList xs + -- v ! m is the set with bit pattern m, + -- e.g., for xs = [1,2,3], + -- we have fmap Foldable.toList v = array (0,7) + -- [(0,[]),(1,[3]),(2,[2]),(3,[2,3]),(4,[1]),(5,[1,3]),(6,[1,2]),(7,[1,2,3])] + !v = V.generate (2^w) $ \ m -> + if m == 0 + then Tip + else let ST up med lo = splitBits m + in bin (u V.! (w - 1 - med)) + (v V.! up) (v V.! lo) + + full = 2^(w+1)-1 :: Int + stp = VU.iterateN (2^w) (next_pattern full) 0 + make :: Int -> Int -> Set (Set a) + make !begin !s = + if s == 0 + then Tip + else + let !sl = shiftR (s-1) 1; !sr = s - 1 - sl + in bin (v V.! (stp VU.! (begin + sl))) + (make begin sl) + (make (begin + sl+1) sr) + + in make 0 (2^w) + +{- + +-- | @bit_pattern w i@ is the bit pattern at position i +-- in the lexicographic enumeration of their meanings as sets. +-- map (bit_pattern 3) [0..7] +-- = [0,4,6,7,5,2,3,1] +-- = [000,100,110,111,101,010,011,001] +-- This function is called often. It takes 1/3 of run-time, +-- but it does not allocate. +bit_pattern :: Int -> Int -> Int +bit_pattern 0 _ = 0 +bit_pattern !width !i = + let go :: Int -> Int -> Int -> Int + go !topmask !n !set = + if n == 0 then set + else if 0 == ((n-1) .&. topmask) + then go (shiftR topmask 1) + (n-1) (set .|. topmask) + else go (shiftR topmask 1) + (n .&. complement topmask) set + in go (bit $ width-1) i 0 +-} + +-- | next bitpattern, first arg. is 2^(w+1)-1 +next_pattern :: Int -> Int -> Int +{-# inline next_pattern #-} +next_pattern full m = + if even m + then -- switch highest trailing zero bit to one + -- ex.: m = 10100 000000 + let lo = full .&. xor m (m-1) -- 00111 000111 + b = xor lo (shiftR lo 1) -- 00100 100 + in m .|. shiftR b 1 -- 10110 + else -- remove lowest one bit (at index 0) + -- then move now-lowest one bit on place to the right + -- ex.: m = 1101111 111 101 + let mm = m - 1 -- 1101110 110 100 + lo = xor mm (mm-1) -- 0000011 011 111 + b = xor lo (shiftR lo 1) -- 0000010 010 100 + in xor mm (b .|. shiftR b 1) -- 1101101 101 010 + +data StrictTriple = ST !Int !Int !Int + +-- | return bitmask for upper half, +-- index of middle bit, bitmask for lower half +splitBits :: Int -> StrictTriple +splitBits m = + let clearLowest !x = x .&. (x-1) + go 0 !x = x; go k !x = go (k-1) (clearLowest x) + up_med = go (div (popCount m) 2) m + lo = xor m up_med + up = clearLowest up_med + med = xor up_med up + in ST up (countTrailingZeros med) lo + -- Proof of complexity: step executes n times. At the ith step, -- "insertMin x `mapMonotonic` pxs" takes O(2^i log i) time since pxs has size -- 2^i - 1 and we insertMin into its elements which are sets of size <= i. @@ -1834,9 +1925,9 @@ splitRoot orig = -- = O(log n * \sum_{i=1}^{n-1} 2^i) -- = O(2^n log n) -powerSet :: Set a -> Set (Set a) -powerSet xs0 = insertMin empty (foldr' step Tip xs0) where - step x pxs = insertMin (singleton x) (insertMin x `mapMonotonic` pxs) `glue` pxs +-- powerSet_orig :: Set a -> Set (Set a) +-- powerSet_orig xs0 = insertMin empty (foldr' step Tip xs0) where +-- step x pxs = insertMin (singleton x) (insertMin x `mapMonotonic` pxs) `glue` pxs -- | \(O(nm)\). Calculate the Cartesian product of two sets. --