diff --git a/hstream-kafka/HStream/Kafka/Network.hs b/hstream-kafka/HStream/Kafka/Network.hs index b9f1542e4..7f6098322 100644 --- a/hstream-kafka/HStream/Kafka/Network.hs +++ b/hstream-kafka/HStream/Kafka/Network.hs @@ -55,6 +55,7 @@ import HStream.Kafka.Server.Types (ServerContext (..), initConnectionContext) import qualified HStream.Logger as Log import Kafka.Protocol.Encoding +import Kafka.Protocol.Error import Kafka.Protocol.Message import Kafka.Protocol.Service @@ -387,7 +388,7 @@ runParseIO more parser = more >>= go Nothing Done l r -> pure (r, l) More f -> do msg <- more go (Just f) msg - Fail _ err -> E.throwIO $ DecodeError $ "Fail, " <> err + Fail _ err -> E.throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Fail, " <> err) showSockAddrHost :: N.SockAddr -> String showSockAddrHost (N.SockAddrUnix str) = str diff --git a/hstream-kafka/HStream/Kafka/Network/IO.hs b/hstream-kafka/HStream/Kafka/Network/IO.hs index 902bfb216..ec8ca846f 100644 --- a/hstream-kafka/HStream/Kafka/Network/IO.hs +++ b/hstream-kafka/HStream/Kafka/Network/IO.hs @@ -22,6 +22,7 @@ import qualified Network.Socket.ByteString as N import qualified Network.Socket.ByteString.Lazy as NL import Kafka.Protocol.Encoding +import Kafka.Protocol.Error import Kafka.Protocol.Message -- | Receive a kafka message with its request header from socket. @@ -54,20 +55,20 @@ recvKafkaMsgBS peer m_more s = do headerResult <- liftIO $ runParser @RequestHeader get reqBs case headerResult of Done l h -> return $ Just (h, l) - Fail _ err -> E.throw $ DecodeError $ "Fail, " <> err - More _ -> E.throw $ DecodeError $ "More" + Fail _ err -> E.throw $ DecodeError $ (CORRUPT_MESSAGE, "Fail, " <> err) + More _ -> E.throw $ DecodeError $ (CORRUPT_MESSAGE, "More") Done l reqBs -> do State.put l headerResult <- liftIO $ runParser @RequestHeader get reqBs case headerResult of Done l' h -> return $ Just (h, l') - Fail _ err -> E.throw $ DecodeError $ "Fail, " <> err - More _ -> E.throw $ DecodeError $ "More" + Fail _ err -> E.throw $ DecodeError $ (CORRUPT_MESSAGE, "Fail, " <> err) + More _ -> E.throw $ DecodeError $ (CORRUPT_MESSAGE, "More") More f -> do i_new <- liftIO $ N.recv s 1024 State.put i_new recvKafkaMsgBS peer (Just f) s - Fail _ err -> liftIO . E.throwIO $ DecodeError $ "Fail, " <> err + Fail _ err -> liftIO . E.throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Fail, " <> err) -- | Send a kafka message to socket. Note the message should be packed -- with its response header. diff --git a/hstream-kafka/hstream-kafka.cabal b/hstream-kafka/hstream-kafka.cabal index cb0f18330..21596a8b0 100644 --- a/hstream-kafka/hstream-kafka.cabal +++ b/hstream-kafka/hstream-kafka.cabal @@ -57,6 +57,7 @@ library kafka-protocol Kafka.Protocol.Encoding.Encode Kafka.Protocol.Encoding.Internal Kafka.Protocol.Encoding.Parser + Kafka.Protocol.Encoding.Types Kafka.Protocol.Message.Struct Kafka.Protocol.Message.Total diff --git a/hstream-kafka/protocol/Kafka/Protocol/Encoding.hs b/hstream-kafka/protocol/Kafka/Protocol/Encoding.hs index 4119e5d08..c0c429702 100644 --- a/hstream-kafka/protocol/Kafka/Protocol/Encoding.hs +++ b/hstream-kafka/protocol/Kafka/Protocol/Encoding.hs @@ -1,35 +1,53 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ViewPatterns #-} --- As of GHC 8.8.1, GHC started complaining about -optP--cpp when profling --- is enabled. See https://gitlab.haskell.org/ghc/ghc/issues/17185. -{-# OPTIONS_GHC -pgmP "hpp --cpp -P" #-} module Kafka.Protocol.Encoding - ( Serializable (..) - , putEither, getEither - , putMaybe, getMaybe + ( -- * Parser + Parser, DecodeError (..) , runGet , runGet' , runPut , runPutLazy - , DecodeError (..) - -- * Defined types - , VarInt32 (..) - , VarInt64 (..) - , NullableString - , CompactString (..) - , CompactNullableString (..) - , NullableBytes - , CompactBytes (..) - , CompactNullableBytes (..) - , TaggedFields (EmptyTaggedFields) -- TODO - , KaArray (..) - , CompactKaArray (..) + -- * Message Format + , RecordBatch (..) + , decodeRecordBatch + , updateRecordBatchBaseOffset + , unsafeUpdateRecordBatchBaseOffset + -- ** Attributes + , Attributes + , CompressionType (..) + , TimestampType (..) + , compressionType + , timestampType + -- ** Record + -- * Old Message Format() + -- TODO + -- * Internals + -- ** Parser + , runParser + , runParser' + , Result (..) + , takeBytes + -- ** Builder + , Builder + , builderLength + , toLazyByteString + -- * Misc + , pattern NonNullKaArray + , unNonNullKaArray + , kaArrayToCompact + , kaArrayFromCompact + + , module Kafka.Protocol.Encoding.Types + + -- TODO: The following are outdated, and need to be updated. + + -- * Exports for Produce + , decodeBatchRecordsForProduce + , unsafeAlterBatchRecordsBsForProduce + -- * Records , BatchRecord (..) , decodeBatchRecords @@ -41,43 +59,20 @@ module Kafka.Protocol.Encoding , RecordV1 (..) , RecordV2 (..) , RecordV2_ (..) - , RecordBatch (..) - , RecordBatch_ (..) - , unsafeAlterRecordBatchBs - , unsafeAlterMessageSetBs - , RecordKey (..) - , RecordValue (..) - , RecordArray (..) + , RecordBatch1 (..) , RecordHeader - , RecordHeaderKey (..) - , RecordHeaderValue (..) -- ** Helpers , decodeLegacyRecordBatch , decodeRecordMagic , decodeNextRecordOffset - -- ** Misc - , pattern NonNullKaArray - , unNonNullKaArray - , kaArrayToCompact - , kaArrayFromCompact - , decodeBatchRecordsForProduce - , unsafeAlterBatchRecordsBsForProduce - -- * Internals - -- ** Parser - , Parser - , runParser - , runParser' - , Result (..) - , takeBytes - -- ** Builder - , Builder - , builderLength - , toLazyByteString + , unsafeAlterRecordBatchBs + , unsafeAlterMessageSetBs ) where import Control.DeepSeq (NFData) import Control.Exception import Control.Monad +import Data.Bits import Data.ByteString (ByteString) import qualified Data.ByteString as BS import Data.ByteString.Internal (ByteString (BS), w2c) @@ -87,13 +82,10 @@ import Data.Digest.CRC32 (crc32) import Data.Digest.CRC32C (crc32c) import Data.Int import Data.Maybe -import Data.String (IsString) -import Data.Text (Text) import Data.Typeable (Typeable, showsTypeRep, typeOf) import Data.Vector (Vector) import qualified Data.Vector as V -import Data.Word import Foreign.Ptr (plusPtr) import GHC.ForeignPtr (unsafeWithForeignPtr) import GHC.Generics @@ -102,81 +94,13 @@ import qualified HStream.Base.Growing as Growing import Kafka.Protocol.Encoding.Encode import Kafka.Protocol.Encoding.Internal import Kafka.Protocol.Encoding.Parser +import Kafka.Protocol.Encoding.Types +import Kafka.Protocol.Error ------------------------------------------------------------------------------- +-- Parser -class Typeable a => Serializable a where - get :: Parser a - - default get :: (Generic a, GSerializable (Rep a)) => Parser a - get = to <$> gget - - put :: a -> Builder - - default put :: (Generic a, GSerializable (Rep a)) => a -> Builder - put a = gput (from a) - -class GSerializable f where - gget :: Parser (f a) - gput :: f a -> Builder - --- | Unit: used for constructors without arguments -instance GSerializable U1 where - gget = pure U1 - gput U1 = mempty - --- | Products: encode multiple arguments to constructors -instance (GSerializable a, GSerializable b) => GSerializable (a :*: b) where - gget = do - !a <- gget - !b <- gget - pure $ a :*: b - gput (a :*: b) = gput a <> gput b - --- | Meta-information -instance (GSerializable a) => GSerializable (M1 i c a) where - gget = M1 <$> gget - gput (M1 x) = gput x - -instance (Serializable a) => GSerializable (K1 i a) where - gget = K1 <$> get - gput (K1 x) = put x - --- There is no easy way to support Sum types for Generic instance. --- --- So here we give a special case for Either -putEither :: (Serializable a, Serializable b) => Either a b -> Builder -putEither (Left x) = put x -putEither (Right x) = put x -{-# INLINE putEither #-} - --- There is no way to support Sum types for Generic instance. --- --- So here we give a special case for Either -getEither - :: (Serializable a, Serializable b) - => Bool -- ^ True for Right, False for Left - -> Parser (Either a b) -getEither True = Right <$> get -getEither False = Left <$> get -{-# INLINE getEither #-} - -putMaybe :: (Serializable a) => Maybe a -> Builder -putMaybe (Just x) = put x -putMaybe Nothing = mempty -{-# INLINE putMaybe #-} - -getMaybe - :: (Serializable a) - => Bool -- ^ True for Just, False for Nothing - -> Parser (Maybe a) -getMaybe True = Just <$> get -getMaybe False = pure Nothing -{-# INLINE getMaybe #-} - -------------------------------------------------------------------------------- - -newtype DecodeError = DecodeError String +newtype DecodeError = DecodeError (ErrorCode, String) deriving (Show) instance Exception DecodeError @@ -186,8 +110,9 @@ runParser' parser bs = do result <- runParser parser bs case result of Done l r -> pure (r, l) - Fail _ err -> throwIO $ DecodeError $ "Fail, " <> err - More _ -> throwIO $ DecodeError $ showsTypeRep (typeOf parser) ", need more" + Fail _ err -> throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Fail, " <> err) + More _ -> throwIO $ DecodeError $ + (CORRUPT_MESSAGE, showsTypeRep (typeOf parser) ", need more") {-# INLINE runParser' #-} runGet :: Serializable a => ByteString -> IO a @@ -195,7 +120,8 @@ runGet bs = do (r, l) <- runParser' get bs if BS.null l then pure r else throwIO $ DecodeError $ - "runGet done, but left " <> map w2c (BS.unpack l) + let msg = "runGet done, but left " <> map w2c (BS.unpack l) + in (CORRUPT_MESSAGE, msg) {-# INLINE runGet #-} runGet' :: Serializable a => ByteString -> IO (a, ByteString) @@ -211,158 +137,186 @@ runPut = BL.toStrict . toLazyByteString . put {-# INLINE runPut #-} ------------------------------------------------------------------------------- --- Extra Primitive Types -newtype VarInt32 = VarInt32 { unVarInt32 :: Int32 } - deriving newtype (Show, Num, Integral, Real, Enum, Ord, Eq, Bounded, NFData) +-- | Common Record base for all versions. +-- +-- To help parse all Record version. +data RecordBase = RecordBase + { baseOffset :: {-# UNPACK #-} !Int64 + , batchLength :: {-# UNPACK #-} !Int32 + -- ^ The total size of the record batch in bytes + -- (from partitionLeaderEpochOrCrc to the end) + , partitionLeaderEpochOrCrc :: {-# UNPACK #-} !Int32 + -- ^ For version 0-1, this is the CRC32 of the remainder of the record. + -- For version 2, this is the partition leader epoch. + , magic :: {-# UNPACK #-} !Int8 + } deriving (Generic, Show, Eq) -newtype VarInt64 = VarInt64 { unVarInt64 :: Int64 } - deriving newtype (Show, Num, Integral, Real, Enum, Ord, Eq, Bounded, NFData) +instance Serializable RecordBase -type NullableString = Maybe Text +------------------------------------------------------------------------------- +-- RecordBatch: v2 -newtype CompactString = CompactString { unCompactString :: Text } - deriving newtype (Show, Eq, Ord, IsString, Monoid, Semigroup) +-- Ref: https://kafka.apache.org/documentation/#recordbatch +-- +-- Introduced in Kafka 0.11.0 +data RecordBatch = RecordBatch + { baseOffset :: {-# UNPACK #-} !Int64 + , batchLength :: {-# UNPACK #-} !Int32 + -- ^ The total size of the record batch in bytes, + -- from partitionLeaderEpoch(included) to the end. + , partitionLeaderEpoch :: {-# UNPACK #-} !Int32 + , magic :: {-# UNPACK #-} !Int8 + , crc :: {-# UNPACK #-} !Int32 + , attributes :: {-# UNPACK #-} !Attributes + , lastOffsetDelta :: {-# UNPACK #-} !Int32 + , baseTimestamp :: {-# UNPACK #-} !Int64 + , maxTimestamp :: {-# UNPACK #-} !Int64 + , producerId :: {-# UNPACK #-} !Int64 + , producerEpoch :: {-# UNPACK #-} !Int16 + , baseSequence :: {-# UNPACK #-} !Int32 + , recordsCount :: {-# UNPACK #-} !Int32 + , recordsData :: !ByteString + -- ^ Note that when compression is enabled, this is the compressed data + } deriving (Generic, Show, Eq) + +instance NFData RecordBatch -newtype CompactNullableString = CompactNullableString - { unCompactNullableString :: Maybe Text } - deriving newtype (Show, Eq, Ord) +-- 49: partitionLeaderEpoch(4) + magic(1) +-- + crc(4) + attributes(2) + lastOffsetDelta(4) +-- + baseTimestamp(8) + maxTimestamp(8) +-- + producerId(8) + producerEpoch(2) +-- + baseSequence(4) +-- + recordsCount(4) +sizeOfPartitionLeaderEpochToRecordCount :: Int +sizeOfPartitionLeaderEpochToRecordCount = 49 + +decodeRecordBatch :: Bool -> ByteString -> IO RecordBatch +decodeRecordBatch shouldValidateCrc bs = do + (RecordBase{..}, bs') <- runGet' @RecordBase bs + case magic of + 2 -> do + let partitionLeaderEpoch = partitionLeaderEpochOrCrc + -- The CRC covers the data from the attributes to the end of + -- the batch (i.e. all the bytes that follow the CRC). + -- + -- The CRC-32C (Castagnoli) polynomial is used for the + -- computation. + (crc, bs'') <- runGet' @Int32 bs' + when shouldValidateCrc $ do + let crcPayload = BS.take (fromIntegral batchLength - 9) bs'' + when (fromIntegral (crc32c crcPayload) /= crc) $ + throwIO $ DecodeError (CORRUPT_MESSAGE, "Invalid CRC32") + (RecordBatchHead{..}, recordBs) <- runGet' @RecordBatchHead bs'' + (recordsCount, recordsData) <- runGet' @Int32 recordBs + if BS.length recordsData == fromIntegral batchLength - sizeOfPartitionLeaderEpochToRecordCount + then pure RecordBatch{..} + else throwIO $ DecodeError (INVALID_RECORD, "There are some bytes left") + _ -> throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Invalid magic " <> show magic) + +-- Be sure to use this function after the calling of 'decodeRecordBatch', +-- since we do not check the bounds. +updateRecordBatchBaseOffset :: ByteString -> (Int64 -> Int64) -> IO ByteString +updateRecordBatchBaseOffset bs f = do + let (BS ofp _) = BSU.unsafeTake 8 bs + offset <- unsafeWithForeignPtr ofp $ fmap (f . fromIntegral) . peek64BE + pure $ runPut offset <> (BSU.unsafeDrop 8 bs) + +-- FIXME: I don't know how the ghc sharing does, so be careful. +-- +-- https://hackage.haskell.org/package/bytestring-0.12.1.0/docs/Data-ByteString-Unsafe.html#v:unsafeUseAsCString +-- https://hackage.haskell.org/package/vector-0.13.1.0/docs/Data-Vector.html#v:unsafeThaw +unsafeUpdateRecordBatchBaseOffset :: ByteString -> (Int64 -> Int64) -> IO () +unsafeUpdateRecordBatchBaseOffset (BS fp len) f = + unsafeWithForeignPtr fp $ \p -> do + -- FIXME improvement: does ghc provide a modifyPtr function? + origin <- fromIntegral <$> peek64BE p + poke64BE p (fromIntegral $ f origin) -type NullableBytes = Maybe ByteString +-- | Internal type to help parse RecordBatch +-- +-- RecordBatch = RecordBase + CRC32 + RecordBatchHead + recordsCount + recordsData +data RecordBatchHead = RecordBatchHead + { attributes :: {-# UNPACK #-} !Attributes + , lastOffsetDelta :: {-# UNPACK #-} !Int32 + , baseTimestamp :: {-# UNPACK #-} !Int64 + , maxTimestamp :: {-# UNPACK #-} !Int64 + , producerId :: {-# UNPACK #-} !Int64 + , producerEpoch :: {-# UNPACK #-} !Int16 + , baseSequence :: {-# UNPACK #-} !Int32 + } deriving (Generic, Show, Eq) -newtype CompactBytes = CompactBytes { unCompactBytes :: ByteString } - deriving newtype (Show, Eq, Ord, IsString, Monoid, Semigroup) +instance NFData RecordBatchHead +instance Serializable RecordBatchHead -newtype CompactNullableBytes = CompactNullableBytes - { unCompactNullableBytes :: Maybe ByteString } - deriving newtype (Show, Eq, Ord) +newtype Attributes = Attributes Int16 + deriving newtype (Show, Eq, NFData, Serializable) --- TODO: Currently we just ignore the tagged fields -data TaggedFields = EmptyTaggedFields - deriving (Show, Eq) +compressionCodecMask :: Int16 +compressionCodecMask = 0x07 -newtype KaArray a = KaArray - { unKaArray :: Maybe (Vector a) } - deriving newtype (Show, Eq, Ord, NFData) +timestampTypeMask :: Int16 +timestampTypeMask = 0x08 -instance Functor KaArray where - fmap f (KaArray xs) = KaArray $ fmap f <$> xs +transactionalFlagMask :: Int16 +transactionalFlagMask = 0x10 -newtype CompactKaArray a = CompactKaArray - { unCompactKaArray :: Maybe (Vector a) } - deriving newtype (Show, Eq, Ord) +controlFlagMask :: Int16 +controlFlagMask = 0x20 -instance Functor CompactKaArray where - fmap f (CompactKaArray xs) = CompactKaArray $ fmap f <$> xs +deleteHorizonFlagMask :: Int16 +deleteHorizonFlagMask = 0x40 -newtype RecordKey = RecordKey { unRecordKey :: Maybe ByteString } - deriving newtype (Show, Eq, Ord, NFData) +data CompressionType + = CompressionTypeNone + | CompressionTypeGzip + | CompressionTypeSnappy + | CompressionTypeLz4 + | CompressionTypeZstd + deriving (Show, Eq) -newtype RecordValue = RecordValue { unRecordValue :: Maybe ByteString } - deriving newtype (Show, Eq, Ord, NFData) +instance Enum CompressionType where + toEnum 0 = CompressionTypeNone + toEnum 1 = CompressionTypeGzip + toEnum 2 = CompressionTypeSnappy + toEnum 3 = CompressionTypeLz4 + toEnum 4 = CompressionTypeZstd + toEnum x = error $ "Unknown compression type id: " <> show x + {-# INLINE toEnum #-} + + fromEnum CompressionTypeNone = 0 + fromEnum CompressionTypeGzip = 1 + fromEnum CompressionTypeSnappy = 2 + fromEnum CompressionTypeLz4 = 3 + fromEnum CompressionTypeZstd = 4 + {-# INLINE fromEnum #-} + +data TimestampType + = TimestampTypeCreateTime + | TimestampTypeLogAppendTime + | TimestampTypeNone + deriving (Show, Eq) -newtype RecordArray a = RecordArray { unRecordArray :: Vector a } - deriving newtype (Show, Eq, Ord, NFData) +instance Enum TimestampType where + toEnum 0 = TimestampTypeCreateTime + toEnum 1 = TimestampTypeLogAppendTime + toEnum (-1) = TimestampTypeNone + toEnum x = error $ "Invalid timestamp type: " <> show x + {-# INLINE toEnum #-} -newtype RecordHeaderKey = RecordHeaderKey { unRecordHeaderKey :: Text } - deriving newtype (Show, Eq, Ord, IsString, Monoid, Semigroup, NFData) + fromEnum TimestampTypeCreateTime = 0 + fromEnum TimestampTypeLogAppendTime = 1 + fromEnum TimestampTypeNone = (-1) + {-# INLINE fromEnum #-} -newtype RecordHeaderValue = RecordHeaderValue - { unRecordHeaderValue :: Maybe ByteString } - deriving newtype (Show, Eq, Ord, NFData) +compressionType :: Attributes -> CompressionType +compressionType (Attributes attr) = + toEnum $ fromIntegral (attr .&. compressionCodecMask) +{-# INLINE compressionType #-} -------------------------------------------------------------------------------- --- Instances - -#define INSTANCE(ty, n, getfun, patmt, pat) \ -instance Serializable ty where \ - get = getfun get##n; \ - {-# INLINE get #-}; \ - put patmt = put##n pat; \ - {-# INLINE put #-} - -#define INSTANCE_BUILTIN(t) INSTANCE(t, t, , , ) -#define INSTANCE_BUILTIN_1(t, n) INSTANCE(t, n, , , ) -#define INSTANCE_NEWTYPE(t) INSTANCE(t, t, t <$>, (t x), x) -#define INSTANCE_NEWTYPE_1(t, n) INSTANCE(t, n, t <$>, (t x), x) - -INSTANCE_BUILTIN(Bool) -INSTANCE_BUILTIN(Int8) -INSTANCE_BUILTIN(Int16) -INSTANCE_BUILTIN(Int32) -INSTANCE_BUILTIN(Int64) -INSTANCE_BUILTIN(Word32) -INSTANCE_BUILTIN(Double) -INSTANCE_BUILTIN(NullableString) -INSTANCE_BUILTIN(NullableBytes) -INSTANCE_BUILTIN_1(Text, String) -INSTANCE_BUILTIN_1(ByteString, Bytes) - -INSTANCE_NEWTYPE(VarInt32) -INSTANCE_NEWTYPE(VarInt64) -INSTANCE_NEWTYPE(CompactString) -INSTANCE_NEWTYPE(CompactNullableString) -INSTANCE_NEWTYPE(CompactBytes) -INSTANCE_NEWTYPE(CompactNullableBytes) - -INSTANCE_NEWTYPE_1(RecordKey, RecordNullableBytes) -INSTANCE_NEWTYPE_1(RecordValue, RecordNullableBytes) -INSTANCE_NEWTYPE_1(RecordHeaderKey, RecordString) -INSTANCE_NEWTYPE_1(RecordHeaderValue, RecordNullableBytes) - -instance Serializable TaggedFields where - get = do !n <- fromIntegral <$> getVarWord32 - replicateM_ n $ do - tag <- getVarWord32 - dataLen <- getVarWord32 - val <- takeBytes (fromIntegral dataLen) - pure (tag, val) - pure EmptyTaggedFields - {-# INLINE get #-} - - put _ = putVarWord32 0 - {-# INLINE put #-} - -instance Serializable a => Serializable (KaArray a) where - get = getArray - {-# INLINE get #-} - put = putArray - {-# INLINE put #-} - -instance Serializable a => Serializable (CompactKaArray a) where - get = CompactKaArray <$> getCompactArray - {-# INLINE get #-} - put (CompactKaArray xs) = putCompactArray xs - {-# INLINE put #-} - -instance Serializable a => Serializable (RecordArray a) where - get = RecordArray <$> getRecordArray - {-# INLINE get #-} - put (RecordArray xs) = putRecordArray xs - {-# INLINE put #-} - -instance - ( Serializable a - , Serializable b - ) => Serializable (a, b) -instance - ( Serializable a - , Serializable b - , Serializable c - ) => Serializable (a, b, c) -instance - ( Serializable a - , Serializable b - , Serializable c - , Serializable d - ) => Serializable (a, b, c, d) -instance - ( Serializable a - , Serializable b - , Serializable c - , Serializable d - , Serializable e - ) => Serializable (a, b, c, d, e) +timestampType :: Attributes -> TimestampType +timestampType (Attributes attr) = toEnum $ fromIntegral (attr .&. timestampTypeMask) +{-# INLINE timestampType #-} ------------------------------------------------------------------------------- -- Records @@ -370,7 +324,7 @@ instance data BatchRecord = BatchRecordV0 RecordV0 | BatchRecordV1 RecordV1 - | BatchRecordV2 RecordBatch + | BatchRecordV2 RecordBatch1 deriving (Show, Eq, Generic) instance NFData BatchRecord @@ -396,12 +350,12 @@ decodeBatchRecords' shouldValidateCrc batchBs = Growing.new >>= decode 0 batchBs 0 -> do let crc = partitionLeaderEpochOrCrc messageSize = batchLength when (messageSize < fromIntegral minRecordSizeV0) $ - throwIO $ DecodeError $ "Invalid messageSize" + throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Invalid messageSize") when shouldValidateCrc $ do - -- NOTE: pass the origin inputs to validLegacyCrc, not the bs' - -- -- The crc field contains the CRC32 (and not CRC-32C) of the -- subsequent message bytes (i.e. from magic byte to the value). + -- + -- NOTE: pass the origin inputs to validLegacyCrc, not the bs' validLegacyCrc (fromIntegral batchLength) crc bs (RecordBodyV0{..}, remainder) <- runGet' @RecordBodyV0 bs' !v' <- Growing.append v (BatchRecordV0 RecordV0{..}) @@ -409,12 +363,12 @@ decodeBatchRecords' shouldValidateCrc batchBs = Growing.new >>= decode 0 batchBs 1 -> do let crc = partitionLeaderEpochOrCrc messageSize = batchLength when (messageSize < fromIntegral minRecordSizeV1) $ - throwIO $ DecodeError $ "Invalid messageSize" + throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Invalid messageSize") when shouldValidateCrc $ do - -- NOTE: pass the origin inputs to validLegacyCrc, not the bs' - -- -- The crc field contains the CRC32 (and not CRC-32C) of the -- subsequent message bytes (i.e. from magic byte to the value). + -- + -- NOTE: pass the origin inputs to validLegacyCrc, not the bs' validLegacyCrc (fromIntegral batchLength) crc bs (RecordBodyV1{..}, remainder) <- runGet' @RecordBodyV1 bs' !v' <- Growing.append v (BatchRecordV1 RecordV1{..}) @@ -429,14 +383,14 @@ decodeBatchRecords' shouldValidateCrc batchBs = Growing.new >>= decode 0 batchBs when shouldValidateCrc $ do let crcPayload = BS.take (fromIntegral batchLength - 9) bs'' when (fromIntegral (crc32c crcPayload) /= crc) $ - throwIO $ DecodeError "Invalid CRC32" + throwIO $ DecodeError (CORRUPT_MESSAGE, "Invalid CRC32") (RecordBodyV2{..}, remainder) <- runGet' @RecordBodyV2 bs'' - !v' <- Growing.append v (BatchRecordV2 RecordBatch{..}) + !v' <- Growing.append v (BatchRecordV2 RecordBatch1{..}) let !batchLen = maybe 0 V.length (unKaArray records) -- Actually, there should be only one batch record here, but -- we don't require it. decode (len + batchLen) remainder v' - _ -> throwIO $ DecodeError $ "Invalid magic " <> show magic + _ -> throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Invalid magic " <> show magic) {-# INLINABLE decodeBatchRecords' #-} encodeBatchRecordsLazy :: Vector BatchRecord -> BL.ByteString @@ -515,66 +469,11 @@ decodeNextRecordOffset bs = fst <$> runParser' parser' bs _ -> fail $ "Invalid magic " <> show magic {-# INLINE decodeNextRecordOffset #-} --- Internal type to help parse all Record version. --- --- Common Record base for all versions. -data RecordBase = RecordBase - { baseOffset :: {-# UNPACK #-} !Int64 - , batchLength :: {-# UNPACK #-} !Int32 - -- ^ The total size of the record batch in bytes - -- (from partitionLeaderEpochOrCrc to the end) - , partitionLeaderEpochOrCrc :: {-# UNPACK #-} !Int32 - -- ^ For version 0-1, this is the CRC32 of the remainder of the record. - -- For version 2, this is the partition leader epoch. - , magic :: {-# UNPACK #-} !Int8 - } deriving (Generic, Show, Eq) - -instance Serializable RecordBase - --- Internal type to help parse all Record version. --- --- RecordV0 = RecordBase + RecordBodyV0 -data RecordBodyV0 = RecordBodyV0 - { attributes :: {-# UNPACK #-} !Int8 - , key :: !NullableBytes - , value :: !NullableBytes - } deriving (Generic, Show, Eq) - -instance Serializable RecordBodyV0 - --- Internal type to help parse all Record version. --- --- RecordV1 = RecordBase + RecordBodyV1 -data RecordBodyV1 = RecordBodyV1 - { attributes :: {-# UNPACK #-} !Int8 - , timestamp :: {-# UNPACK #-} !Int64 - , key :: !NullableBytes - , value :: !NullableBytes - } deriving (Generic, Show, Eq) - -instance Serializable RecordBodyV1 - --- Internal type to help parse all Record version. --- --- RecordBatch = RecordBase + CRC32 + RecordBodyV2 -data RecordBodyV2 = RecordBodyV2 - { attributes :: {-# UNPACK #-} !Int16 - , lastOffsetDelta :: {-# UNPACK #-} !Int32 - , baseTimestamp :: {-# UNPACK #-} !Int64 - , maxTimestamp :: {-# UNPACK #-} !Int64 - , producerId :: {-# UNPACK #-} !Int64 - , producerEpoch :: {-# UNPACK #-} !Int16 - , baseSequence :: {-# UNPACK #-} !Int32 - , records :: !(KaArray RecordV2) - } deriving (Generic, Show, Eq) - -instance Serializable RecordBodyV2 - validLegacyCrc :: Int -> Int32 -> ByteString -> IO () validLegacyCrc batchLength crc bs = do crcPayload <- getLegacyCrcPayload batchLength bs when (fromIntegral (crc32 crcPayload) /= crc) $ - throwIO $ DecodeError "Invalid CRC32" + throwIO $ DecodeError (CORRUPT_MESSAGE, "Invalid CRC32") {-# INLINE validLegacyCrc #-} getLegacyCrcPayload :: Int -> ByteString -> IO ByteString @@ -584,6 +483,89 @@ getLegacyCrcPayload msgSize bs = in fst <$> runParser' parser bs {-# INLINE getLegacyCrcPayload #-} +------------------------------------------------------------------------------- +-- For Handler + +-- FIXME: support magic 0 and 1 are incomplete, donot use it. +decodeBatchRecordsForProduce :: Bool -> ByteString -> IO (Int, [Int]) +decodeBatchRecordsForProduce shouldValidateCrc = decode 0 0 [] + where + decode len _consumed offsetOffsets "" = pure (len, offsetOffsets) + decode !len !consumed !offsetOffsets !bs = do + (RecordBase{..}, bs') <- runGet' @RecordBase bs + case magic of + 2 -> do let partitionLeaderEpoch = partitionLeaderEpochOrCrc + -- The CRC covers the data from the attributes to the end of + -- the batch (i.e. all the bytes that follow the CRC). + -- + -- The CRC-32C (Castagnoli) polynomial is used for the + -- computation. + (crc, bs'') <- runGet' @Int32 bs' + when shouldValidateCrc $ do + let crcPayload = BS.take (fromIntegral batchLength - 9) bs'' + when (fromIntegral (crc32c crcPayload) /= crc) $ + throwIO $ DecodeError (CORRUPT_MESSAGE, "Invalid CRC32") + (batchRecordsLen, remainder) <- runParser' + (do batchRecordsLen <- unsafePeekInt32At 36 + -- 36: attributes(2) + lastOffsetDelta(4) + -- + baseTimestamp(8) + maxTimestamp(8) + -- + producerId(8) + producerEpoch(2) + -- + baseSequence(4) + dropBytes (fromIntegral batchLength - 9) + pure batchRecordsLen + ) bs'' + let batchRecordsLen' = if batchRecordsLen >= 0 + then fromIntegral batchRecordsLen + else 0 + -- Actually, there should be only one batch record here, but + -- we don't require it. + decode (len + batchRecordsLen') + (consumed + fromIntegral batchLength + 12) + (consumed:offsetOffsets) + remainder + 0 -> do let crc = partitionLeaderEpochOrCrc + messageSize = batchLength + when (messageSize < fromIntegral minRecordSizeV0) $ + throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Invalid messageSize") + when shouldValidateCrc $ do + -- The crc field contains the CRC32 (and not CRC-32C) of the + -- subsequent message bytes (i.e. from magic byte to the value). + -- + -- NOTE: pass the origin inputs to validLegacyCrc, not the bs' + validLegacyCrc (fromIntegral batchLength) crc bs + let totalSize = fromIntegral $ messageSize + 12 + remainder <- snd <$> runParser' (dropBytes totalSize) bs + decode (len + 1) (consumed + totalSize) + (consumed:offsetOffsets) remainder + 1 -> do let crc = partitionLeaderEpochOrCrc + messageSize = batchLength + when (messageSize < fromIntegral minRecordSizeV1) $ + throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Invalid messageSize") + when shouldValidateCrc $ do + -- The crc field contains the CRC32 (and not CRC-32C) of the + -- subsequent message bytes (i.e. from magic byte to the value). + -- + -- NOTE: pass the origin inputs to validLegacyCrc, not the bs' + validLegacyCrc (fromIntegral batchLength) crc bs + let totalSize = fromIntegral $ messageSize + 12 + remainder <- snd <$> runParser' (dropBytes totalSize) bs + decode (len + 1) (consumed + totalSize) + (consumed:offsetOffsets) remainder + _ -> throwIO $ DecodeError $ (CORRUPT_MESSAGE, "Invalid magic " <> show magic) +{-# INLINABLE decodeBatchRecordsForProduce #-} + +unsafeAlterBatchRecordsBsForProduce + :: (Int64 -> Int64) -- Update baseOffsets + -> [Int] -- All bytes offsets of baseOffset + -> ByteString + -> IO () +unsafeAlterBatchRecordsBsForProduce boof boos bs@(BS fp len) = do + unsafeWithForeignPtr fp $ \p -> do + forM_ boos $ \boo -> do + -- FIXME improvement: does ghc provide a modifyPtr function? + origin <- fromIntegral <$> peek64BE (p `plusPtr` boo) + poke64BE (p `plusPtr` boo) (fromIntegral $ boof origin) + ------------------------------------------------------------------------------- -- LegacyRecord(MessageSet): v0-1 -- @@ -593,6 +575,17 @@ getLegacyCrcPayload msgSize bs = -- (which is indicated in the magic value) was 0. Message format version 1 was -- introduced with timestamp support in version 0.10. +-- Internal type to help parse RecordV0 +-- +-- RecordV0 = RecordBase + RecordBodyV0 +data RecordBodyV0 = RecordBodyV0 + { attributes :: {-# UNPACK #-} !Int8 + , key :: !NullableBytes + , value :: !NullableBytes + } deriving (Generic, Show, Eq) + +instance Serializable RecordBodyV0 + data RecordV0 = RecordV0 { baseOffset :: {-# UNPACK #-} !Int64 , messageSize :: {-# UNPACK #-} !Int32 @@ -610,6 +603,18 @@ minRecordSizeV0 :: Int minRecordSizeV0 = 4{- crc -} + 1{- magic -} + 1{- attributes -} + 4{- key -} + 4{- value -} +-- Internal type to help parse all RecordV1 +-- +-- RecordV1 = RecordBase + RecordBodyV1 +data RecordBodyV1 = RecordBodyV1 + { attributes :: {-# UNPACK #-} !Int8 + , timestamp :: {-# UNPACK #-} !Int64 + , key :: !NullableBytes + , value :: !NullableBytes + } deriving (Generic, Show, Eq) + +instance Serializable RecordBodyV1 + data RecordV1 = RecordV1 { baseOffset :: {-# UNPACK #-} !Int64 , messageSize :: {-# UNPACK #-} !Int32 @@ -656,11 +661,43 @@ unsafeAlterMessageSetBs (BS fp len) = unsafeWithForeignPtr fp $ \p -> do poke32BE (p `plusPtr` 12) crc ------------------------------------------------------------------------------- --- RecordBatch: v2 --- --- Ref: https://kafka.apache.org/documentation/#recordbatch +-- RecordBatch: v2 (TODO: Outdated) + +-- TODO: Outdated type, use RecordBatch instead. +data RecordBatch1 = RecordBatch1 + { baseOffset :: {-# UNPACK #-} !Int64 + , batchLength :: {-# UNPACK #-} !Int32 + , partitionLeaderEpoch :: {-# UNPACK #-} !Int32 + , magic :: {-# UNPACK #-} !Int8 + , crc :: {-# UNPACK #-} !Int32 + , attributes :: {-# UNPACK #-} !Int16 + , lastOffsetDelta :: {-# UNPACK #-} !Int32 + , baseTimestamp :: {-# UNPACK #-} !Int64 + , maxTimestamp :: {-# UNPACK #-} !Int64 + , producerId :: {-# UNPACK #-} !Int64 + , producerEpoch :: {-# UNPACK #-} !Int16 + , baseSequence :: {-# UNPACK #-} !Int32 + , records :: !(KaArray RecordV2) + } deriving (Generic, Show, Eq) + +instance Serializable RecordBatch1 +instance NFData RecordBatch1 + +-- Internal type to help parse all RecordBatch -- --- Introduced in Kafka 0.11.0 +-- RecordBatch = RecordBase + CRC32 + RecordBodyV2 +data RecordBodyV2 = RecordBodyV2 + { attributes :: {-# UNPACK #-} !Int16 + , lastOffsetDelta :: {-# UNPACK #-} !Int32 + , baseTimestamp :: {-# UNPACK #-} !Int64 + , maxTimestamp :: {-# UNPACK #-} !Int64 + , producerId :: {-# UNPACK #-} !Int64 + , producerEpoch :: {-# UNPACK #-} !Int16 + , baseSequence :: {-# UNPACK #-} !Int32 + , records :: !(KaArray RecordV2) + } deriving (Generic, Show, Eq) + +instance Serializable RecordBodyV2 type RecordHeader = (RecordHeaderKey, RecordHeaderValue) @@ -700,25 +737,6 @@ data RecordV2_ = RecordV2_ instance Serializable RecordV2_ instance NFData RecordV2_ -data RecordBatch = RecordBatch - { baseOffset :: {-# UNPACK #-} !Int64 - , batchLength :: {-# UNPACK #-} !Int32 - , partitionLeaderEpoch :: {-# UNPACK #-} !Int32 - , magic :: {-# UNPACK #-} !Int8 - , crc :: {-# UNPACK #-} !Int32 - , attributes :: {-# UNPACK #-} !Int16 - , lastOffsetDelta :: {-# UNPACK #-} !Int32 - , baseTimestamp :: {-# UNPACK #-} !Int64 - , maxTimestamp :: {-# UNPACK #-} !Int64 - , producerId :: {-# UNPACK #-} !Int64 - , producerEpoch :: {-# UNPACK #-} !Int16 - , baseSequence :: {-# UNPACK #-} !Int32 - , records :: !(KaArray RecordV2) - } deriving (Generic, Show, Eq) - -instance Serializable RecordBatch -instance NFData RecordBatch - -- | The same as 'RecordBatch' but without records. -- -- This may useful for constructing 'RecordBatch' bytes. @@ -770,146 +788,3 @@ kaArrayToCompact = CompactKaArray . unKaArray kaArrayFromCompact :: CompactKaArray a -> KaArray a kaArrayFromCompact = KaArray . unCompactKaArray {-# INLINE kaArrayFromCompact #-} - --- Currently this is used for Produce handler. -decodeBatchRecordsForProduce :: Bool -> ByteString -> IO (Int, [Int]) -decodeBatchRecordsForProduce shouldValidateCrc = decode 0 0 [] - where - decode len _consumed offsetOffsets "" = pure (len, offsetOffsets) - decode !len !consumed !offsetOffsets !bs = do - (RecordBase{..}, bs') <- runGet' @RecordBase bs - case magic of - 0 -> do let crc = partitionLeaderEpochOrCrc - messageSize = batchLength - when (messageSize < fromIntegral minRecordSizeV0) $ - throwIO $ DecodeError $ "Invalid messageSize" - when shouldValidateCrc $ do - -- The crc field contains the CRC32 (and not CRC-32C) of the - -- subsequent message bytes (i.e. from magic byte to the value). - -- - -- NOTE: pass the origin inputs to validLegacyCrc, not the bs' - validLegacyCrc (fromIntegral batchLength) crc bs - let totalSize = fromIntegral $ messageSize + 12 - remainder <- snd <$> runParser' (dropBytes totalSize) bs - decode (len + 1) (consumed + totalSize) - (consumed:offsetOffsets) remainder - 1 -> do let crc = partitionLeaderEpochOrCrc - messageSize = batchLength - when (messageSize < fromIntegral minRecordSizeV1) $ - throwIO $ DecodeError $ "Invalid messageSize" - when shouldValidateCrc $ do - -- The crc field contains the CRC32 (and not CRC-32C) of the - -- subsequent message bytes (i.e. from magic byte to the value). - -- - -- NOTE: pass the origin inputs to validLegacyCrc, not the bs' - validLegacyCrc (fromIntegral batchLength) crc bs - let totalSize = fromIntegral $ messageSize + 12 - remainder <- snd <$> runParser' (dropBytes totalSize) bs - decode (len + 1) (consumed + totalSize) - (consumed:offsetOffsets) remainder - 2 -> do let partitionLeaderEpoch = partitionLeaderEpochOrCrc - -- The CRC covers the data from the attributes to the end of - -- the batch (i.e. all the bytes that follow the CRC). - -- - -- The CRC-32C (Castagnoli) polynomial is used for the - -- computation. - (crc, bs'') <- runGet' @Int32 bs' - when shouldValidateCrc $ do - let crcPayload = BS.take (fromIntegral batchLength - 9) bs'' - when (fromIntegral (crc32c crcPayload) /= crc) $ - throwIO $ DecodeError "Invalid CRC32" - (batchRecordsLen, remainder) <- runParser' - (do batchRecordsLen <- unsafePeekInt32At 36 - -- 36: attributes(2) + lastOffsetDelta(4) - -- + baseTimestamp(8) + maxTimestamp(8) - -- + producerId(8) + producerEpoch(2) - -- + baseSequence(4) - dropBytes (fromIntegral batchLength - 9) - pure batchRecordsLen - ) bs'' - let batchRecordsLen' = if batchRecordsLen >= 0 - then fromIntegral batchRecordsLen - else 0 - -- Actually, there should be only one batch record here, but - -- we don't require it. - decode (len + batchRecordsLen') - (consumed + fromIntegral batchLength + 12) - (consumed:offsetOffsets) - remainder - _ -> throwIO $ DecodeError $ "Invalid magic " <> show magic -{-# INLINABLE decodeBatchRecordsForProduce #-} - -unsafeAlterBatchRecordsBsForProduce - :: (Int64 -> Int64) -- Update baseOffsets - -> [Int] -- All bytes offsets of baseOffset - -> ByteString - -> IO () -unsafeAlterBatchRecordsBsForProduce boof boos bs@(BS fp len) = do - unsafeWithForeignPtr fp $ \p -> do - forM_ boos $ \boo -> do - -- FIXME improvement: does ghc provide a modifyPtr function? - origin <- fromIntegral <$> peek64BE (p `plusPtr` boo) - poke64BE (p `plusPtr` boo) (fromIntegral $ boof origin) - -------------------------------------------------------------------------------- --- Internals - --- | Represents a sequence of objects of a given type T. --- --- Type T can be either a primitive type (e.g. STRING) or a structure. --- First, the length N is given as an INT32. Then N instances of type T follow. --- A null array is represented with a length of -1. In protocol documentation --- an array of T instances is referred to as [T]. -getArray :: Serializable a => Parser (KaArray a) -getArray = do - !n <- getInt32 - if n >= 0 - then KaArray . Just <$!> V.replicateM (fromIntegral n) get - else do - if n == (-1) - then pure $ KaArray Nothing - else fail $! "Length of null array must be -1 " <> show n - --- | Represents a sequence of objects of a given type T. --- --- Type T can be either a primitive type (e.g. STRING) or a structure. First, --- the length N + 1 is given as an UNSIGNED_VARINT. Then N instances of type T --- follow. A null array is represented with a length of 0. In protocol --- documentation an array of T instances is referred to as [T]. -getCompactArray :: Serializable a => Parser (Maybe (Vector a)) -getCompactArray = do - !n_1 <- fromIntegral <$> getVarWord32 - let !n = n_1 - 1 - if n >= 0 - then Just <$!> V.replicateM n get - else do - if n == (-1) - then pure Nothing - else fail $! "Length of null compact array must be -1 " <> show n - -putArray :: Serializable a => KaArray a -> Builder -putArray (KaArray (Just xs)) = - let !len = V.length xs - put_len = putInt32 (fromIntegral len) - in put_len <> V.foldl' (\s x -> s <> put x) mempty xs -putArray (KaArray Nothing) = putInt32 (-1) - -putCompactArray :: Serializable a => Maybe (Vector a) -> Builder -putCompactArray (Just xs) = - let !len = V.length xs - put_len = putVarWord32 (fromIntegral len + 1) - in put_len <> V.foldl' (\s x -> s <> put x) mempty xs -putCompactArray Nothing = putVarWord32 0 - -getRecordArray :: Serializable a => Parser (Vector a) -getRecordArray = do - !n <- fromIntegral <$> getVarInt32 - if | n > 0 -> V.replicateM n get - | n == 0 -> pure V.empty - | otherwise -> fail $! "Length of RecordArray must not be negative " <> show n - -putRecordArray :: Serializable a => Vector a -> Builder -putRecordArray xs = - let !len = V.length xs - put_len = putVarInt32 (fromIntegral len) - in put_len <> V.foldl' (\s x -> s <> put x) mempty xs diff --git a/hstream-kafka/protocol/Kafka/Protocol/Encoding/Types.hs b/hstream-kafka/protocol/Kafka/Protocol/Encoding/Types.hs new file mode 100644 index 000000000..0e9fcfe6b --- /dev/null +++ b/hstream-kafka/protocol/Kafka/Protocol/Encoding/Types.hs @@ -0,0 +1,332 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE MultiWayIf #-} +-- As of GHC 8.8.1, GHC started complaining about -optP--cpp when profling +-- is enabled. See https://gitlab.haskell.org/ghc/ghc/issues/17185. +{-# OPTIONS_GHC -pgmP "hpp --cpp -P" #-} + +module Kafka.Protocol.Encoding.Types + ( Serializable (..) + , putEither, getEither + , putMaybe, getMaybe + -- * Defined types + , VarInt32 (..) + , VarInt64 (..) + , NullableString + , CompactString (..) + , CompactNullableString (..) + , NullableBytes + , CompactBytes (..) + , CompactNullableBytes (..) + , TaggedFields (EmptyTaggedFields) -- TODO + , KaArray (..) + , CompactKaArray (..) + , RecordKey (..) + , RecordValue (..) + , RecordHeaderKey (..) + , RecordArray (..) + , RecordHeaderValue (..) + ) where + +import Control.DeepSeq (NFData) +import Control.Monad +import Data.ByteString (ByteString) +import Data.Int +import Data.String (IsString) +import Data.Text (Text) +import Data.Typeable (Typeable) +import Data.Vector (Vector) +import qualified Data.Vector as V +import Data.Word +import GHC.Generics + +import Kafka.Protocol.Encoding.Encode +import Kafka.Protocol.Encoding.Parser + +------------------------------------------------------------------------------- + +class Typeable a => Serializable a where + get :: Parser a + + default get :: (Generic a, GSerializable (Rep a)) => Parser a + get = to <$> gget + + put :: a -> Builder + + default put :: (Generic a, GSerializable (Rep a)) => a -> Builder + put a = gput (from a) + +class GSerializable f where + gget :: Parser (f a) + gput :: f a -> Builder + +-- | Unit: used for constructors without arguments +instance GSerializable U1 where + gget = pure U1 + gput U1 = mempty + +-- | Products: encode multiple arguments to constructors +instance (GSerializable a, GSerializable b) => GSerializable (a :*: b) where + gget = do + !a <- gget + !b <- gget + pure $ a :*: b + gput (a :*: b) = gput a <> gput b + +-- | Meta-information +instance (GSerializable a) => GSerializable (M1 i c a) where + gget = M1 <$> gget + gput (M1 x) = gput x + +instance (Serializable a) => GSerializable (K1 i a) where + gget = K1 <$> get + gput (K1 x) = put x + +-- There is no easy way to support Sum types for Generic instance. +-- +-- So here we give a special case for Either +putEither :: (Serializable a, Serializable b) => Either a b -> Builder +putEither (Left x) = put x +putEither (Right x) = put x +{-# INLINE putEither #-} + +-- There is no way to support Sum types for Generic instance. +-- +-- So here we give a special case for Either +getEither + :: (Serializable a, Serializable b) + => Bool -- ^ True for Right, False for Left + -> Parser (Either a b) +getEither True = Right <$> get +getEither False = Left <$> get +{-# INLINE getEither #-} + +putMaybe :: (Serializable a) => Maybe a -> Builder +putMaybe (Just x) = put x +putMaybe Nothing = mempty +{-# INLINE putMaybe #-} + +getMaybe + :: (Serializable a) + => Bool -- ^ True for Just, False for Nothing + -> Parser (Maybe a) +getMaybe True = Just <$> get +getMaybe False = pure Nothing +{-# INLINE getMaybe #-} + +------------------------------------------------------------------------------- +-- Extra Primitive Types + +newtype VarInt32 = VarInt32 { unVarInt32 :: Int32 } + deriving newtype (Show, Num, Integral, Real, Enum, Ord, Eq, Bounded, NFData) + +newtype VarInt64 = VarInt64 { unVarInt64 :: Int64 } + deriving newtype (Show, Num, Integral, Real, Enum, Ord, Eq, Bounded, NFData) + +type NullableString = Maybe Text + +newtype CompactString = CompactString { unCompactString :: Text } + deriving newtype (Show, Eq, Ord, IsString, Monoid, Semigroup) + +newtype CompactNullableString = CompactNullableString + { unCompactNullableString :: Maybe Text } + deriving newtype (Show, Eq, Ord) + +type NullableBytes = Maybe ByteString + +newtype CompactBytes = CompactBytes { unCompactBytes :: ByteString } + deriving newtype (Show, Eq, Ord, IsString, Monoid, Semigroup) + +newtype CompactNullableBytes = CompactNullableBytes + { unCompactNullableBytes :: Maybe ByteString } + deriving newtype (Show, Eq, Ord) + +-- TODO: Currently we just ignore the tagged fields +data TaggedFields = EmptyTaggedFields + deriving (Show, Eq) + +newtype KaArray a = KaArray + { unKaArray :: Maybe (Vector a) } + deriving newtype (Show, Eq, Ord, NFData) + +instance Functor KaArray where + fmap f (KaArray xs) = KaArray $ fmap f <$> xs + +newtype CompactKaArray a = CompactKaArray + { unCompactKaArray :: Maybe (Vector a) } + deriving newtype (Show, Eq, Ord) + +instance Functor CompactKaArray where + fmap f (CompactKaArray xs) = CompactKaArray $ fmap f <$> xs + +newtype RecordKey = RecordKey { unRecordKey :: Maybe ByteString } + deriving newtype (Show, Eq, Ord, NFData) + +newtype RecordValue = RecordValue { unRecordValue :: Maybe ByteString } + deriving newtype (Show, Eq, Ord, NFData) + +newtype RecordArray a = RecordArray { unRecordArray :: Vector a } + deriving newtype (Show, Eq, Ord, NFData) + +newtype RecordHeaderKey = RecordHeaderKey { unRecordHeaderKey :: Text } + deriving newtype (Show, Eq, Ord, IsString, Monoid, Semigroup, NFData) + +newtype RecordHeaderValue = RecordHeaderValue + { unRecordHeaderValue :: Maybe ByteString } + deriving newtype (Show, Eq, Ord, NFData) + +------------------------------------------------------------------------------- +-- Instances + +#define INSTANCE(ty, n, getfun, patmt, pat) \ +instance Serializable ty where \ + get = getfun get##n; \ + {-# INLINE get #-}; \ + put patmt = put##n pat; \ + {-# INLINE put #-} + +#define INSTANCE_BUILTIN(t) INSTANCE(t, t, , , ) +#define INSTANCE_BUILTIN_1(t, n) INSTANCE(t, n, , , ) +#define INSTANCE_NEWTYPE(t) INSTANCE(t, t, t <$>, (t x), x) +#define INSTANCE_NEWTYPE_1(t, n) INSTANCE(t, n, t <$>, (t x), x) + +INSTANCE_BUILTIN(Bool) +INSTANCE_BUILTIN(Int8) +INSTANCE_BUILTIN(Int16) +INSTANCE_BUILTIN(Int32) +INSTANCE_BUILTIN(Int64) +INSTANCE_BUILTIN(Word32) +INSTANCE_BUILTIN(Double) +INSTANCE_BUILTIN(NullableString) +INSTANCE_BUILTIN(NullableBytes) +INSTANCE_BUILTIN_1(Text, String) +INSTANCE_BUILTIN_1(ByteString, Bytes) + +INSTANCE_NEWTYPE(VarInt32) +INSTANCE_NEWTYPE(VarInt64) +INSTANCE_NEWTYPE(CompactString) +INSTANCE_NEWTYPE(CompactNullableString) +INSTANCE_NEWTYPE(CompactBytes) +INSTANCE_NEWTYPE(CompactNullableBytes) + +INSTANCE_NEWTYPE_1(RecordKey, RecordNullableBytes) +INSTANCE_NEWTYPE_1(RecordValue, RecordNullableBytes) +INSTANCE_NEWTYPE_1(RecordHeaderKey, RecordString) +INSTANCE_NEWTYPE_1(RecordHeaderValue, RecordNullableBytes) + +instance Serializable TaggedFields where + get = do !n <- fromIntegral <$> getVarWord32 + replicateM_ n $ do + tag <- getVarWord32 + dataLen <- getVarWord32 + val <- takeBytes (fromIntegral dataLen) + pure (tag, val) + pure EmptyTaggedFields + {-# INLINE get #-} + + put _ = putVarWord32 0 + {-# INLINE put #-} + +instance Serializable a => Serializable (KaArray a) where + get = getArray + {-# INLINE get #-} + put = putArray + {-# INLINE put #-} + +instance Serializable a => Serializable (CompactKaArray a) where + get = CompactKaArray <$> getCompactArray + {-# INLINE get #-} + put (CompactKaArray xs) = putCompactArray xs + {-# INLINE put #-} + +instance Serializable a => Serializable (RecordArray a) where + get = RecordArray <$> getRecordArray + {-# INLINE get #-} + put (RecordArray xs) = putRecordArray xs + {-# INLINE put #-} + +instance + ( Serializable a + , Serializable b + ) => Serializable (a, b) +instance + ( Serializable a + , Serializable b + , Serializable c + ) => Serializable (a, b, c) +instance + ( Serializable a + , Serializable b + , Serializable c + , Serializable d + ) => Serializable (a, b, c, d) +instance + ( Serializable a + , Serializable b + , Serializable c + , Serializable d + , Serializable e + ) => Serializable (a, b, c, d, e) + +------------------------------------------------------------------------------- +-- Internals + +-- | Represents a sequence of objects of a given type T. +-- +-- Type T can be either a primitive type (e.g. STRING) or a structure. +-- First, the length N is given as an INT32. Then N instances of type T follow. +-- A null array is represented with a length of -1. In protocol documentation +-- an array of T instances is referred to as [T]. +getArray :: Serializable a => Parser (KaArray a) +getArray = do + !n <- getInt32 + if n >= 0 + then KaArray . Just <$!> V.replicateM (fromIntegral n) get + else do + if n == (-1) + then pure $ KaArray Nothing + else fail $! "Length of null array must be -1 " <> show n + +-- | Represents a sequence of objects of a given type T. +-- +-- Type T can be either a primitive type (e.g. STRING) or a structure. First, +-- the length N + 1 is given as an UNSIGNED_VARINT. Then N instances of type T +-- follow. A null array is represented with a length of 0. In protocol +-- documentation an array of T instances is referred to as [T]. +getCompactArray :: Serializable a => Parser (Maybe (Vector a)) +getCompactArray = do + !n_1 <- fromIntegral <$> getVarWord32 + let !n = n_1 - 1 + if n >= 0 + then Just <$!> V.replicateM n get + else do + if n == (-1) + then pure Nothing + else fail $! "Length of null compact array must be -1 " <> show n + +putArray :: Serializable a => KaArray a -> Builder +putArray (KaArray (Just xs)) = + let !len = V.length xs + put_len = putInt32 (fromIntegral len) + in put_len <> V.foldl' (\s x -> s <> put x) mempty xs +putArray (KaArray Nothing) = putInt32 (-1) + +putCompactArray :: Serializable a => Maybe (Vector a) -> Builder +putCompactArray (Just xs) = + let !len = V.length xs + put_len = putVarWord32 (fromIntegral len + 1) + in put_len <> V.foldl' (\s x -> s <> put x) mempty xs +putCompactArray Nothing = putVarWord32 0 + +getRecordArray :: Serializable a => Parser (Vector a) +getRecordArray = do + !n <- fromIntegral <$> getVarInt32 + if | n > 0 -> V.replicateM n get + | n == 0 -> pure V.empty + | otherwise -> fail $! "Length of RecordArray must not be negative " <> show n + +putRecordArray :: Serializable a => Vector a -> Builder +putRecordArray xs = + let !len = V.length xs + put_len = putVarInt32 (fromIntegral len) + in put_len <> V.foldl' (\s x -> s <> put x) mempty xs diff --git a/hstream-kafka/protocol/Kafka/Protocol/Error.hs b/hstream-kafka/protocol/Kafka/Protocol/Error.hs index 86beb1dd8..0a0a1d280 100644 --- a/hstream-kafka/protocol/Kafka/Protocol/Error.hs +++ b/hstream-kafka/protocol/Kafka/Protocol/Error.hs @@ -121,9 +121,9 @@ module Kafka.Protocol.Error , pattern UNSUPPORTED_ASSIGNOR ) where -import Data.Int (Int16) +import Data.Int (Int16) -import Kafka.Protocol.Encoding (Serializable) +import Kafka.Protocol.Encoding.Types (Serializable) -------------------------------------------------------------------------------