From f7ad685f8cb98c5f9137ab5569d34ddca1b68c4e Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Mon, 2 Dec 2024 00:12:36 +0100 Subject: [PATCH] 3.3: support for `null` and lenient equal (both have been implemented at the same time because the `null` sample depended on the lenient equal comparisons) --- biscuit/src/Auth/Biscuit/Datalog/AST.hs | 49 +++++++++++++++- biscuit/src/Auth/Biscuit/Datalog/Executor.hs | 5 ++ biscuit/src/Auth/Biscuit/Datalog/Parser.hs | 3 + biscuit/src/Auth/Biscuit/Proto.hs | 8 +++ biscuit/src/Auth/Biscuit/ProtoBufAdapter.hs | 59 ++++++++++++-------- biscuit/test/Spec/Parser.hs | 18 ++++++ biscuit/test/Spec/SampleReader.hs | 2 +- 7 files changed, 119 insertions(+), 25 deletions(-) diff --git a/biscuit/src/Auth/Biscuit/Datalog/AST.hs b/biscuit/src/Auth/Biscuit/Datalog/AST.hs index 561d98a..50a0691 100644 --- a/biscuit/src/Auth/Biscuit/Datalog/AST.hs +++ b/biscuit/src/Auth/Biscuit/Datalog/AST.hs @@ -89,6 +89,9 @@ module Auth.Biscuit.Datalog.AST , queryHasNoV4Operators , ruleHasNoScope , ruleHasNoV4Operators + , ruleHasNoV6Values + , predicateHasNoV6Values + , checkHasNoV6Values , isCheckOne , isReject , renderBlock @@ -213,6 +216,8 @@ data Term' (inSet :: IsWithinSet) (pof :: PredicateOrFact) (ctx :: DatalogContex -- ^ A slice (eg. @{name}@) | TermSet (SetType inSet ctx) -- ^ A set (eg. @[true, false]@) + | LNull + -- ^ @null@ deriving instance ( Eq (VariableType inSet pof) , Eq (SliceType ctx) @@ -250,6 +255,7 @@ instance ( Lift (VariableType inSet pof) lift (LBool b) = [| LBool b |] lift (TermSet terms) = [| TermSet terms |] lift (LDate t) = [| LDate (read $(lift $ show t)) |] + lift LNull = [| LNull |] lift (Antiquote s) = [| s |] #if MIN_VERSION_template_haskell(2,17,0) @@ -324,6 +330,7 @@ valueToSetTerm = \case LDate i -> Just $ LDate i LBytes i -> Just $ LBytes i LBool i -> Just $ LBool i + LNull -> Just LNull TermSet _ -> Nothing Variable v -> absurd v Antiquote v -> absurd v @@ -335,6 +342,7 @@ valueToTerm = \case LDate i -> LDate i LBytes i -> LBytes i LBool i -> LBool i + LNull -> LNull TermSet i -> TermSet i Variable v -> absurd v Antiquote v -> absurd v @@ -351,6 +359,7 @@ renderId' var set slice = \case LBytes bs -> "hex:" <> encodeHex bs LBool True -> "true" LBool False -> "false" + LNull -> "null" TermSet terms -> set terms Antiquote v -> slice v @@ -609,6 +618,35 @@ ruleHasNoV4Operators :: Rule -> Bool ruleHasNoV4Operators Rule{expressions} = all expressionHasNoV4Operators expressions +expressionHasNoV6ValuesOrOperators :: Expression -> Bool +expressionHasNoV6ValuesOrOperators = \case + EBinary HeterogeneousEqual _ _ -> False + EBinary HeterogeneousNotEqual _ _ -> False + EBinary _ l r -> expressionHasNoV6ValuesOrOperators l && expressionHasNoV6ValuesOrOperators r + EUnary _ l -> expressionHasNoV6ValuesOrOperators l + EValue LNull -> False + EValue _ -> True + +ruleHasNoV6Values :: Rule -> Bool +ruleHasNoV6Values Rule{rhead, body, expressions} = + predicateHasNoV6Values rhead + && all predicateHasNoV6Values body + && all expressionHasNoV6ValuesOrOperators expressions + +predicateHasNoV6Values :: Predicate' a b -> Bool +predicateHasNoV6Values Predicate{terms} = + let hasV6 = \case + LNull -> True + _ -> False + in all (not . hasV6) terms + +checkHasNoV6Values :: Check -> Bool +checkHasNoV6Values Check{cQueries} = + let hasNoV6 QueryItem{qBody, qExpressions} = + all predicateHasNoV6Values qBody + && all expressionHasNoV6ValuesOrOperators qExpressions + in all hasNoV6 cQueries + renderRule :: Rule -> Text renderRule Rule{rhead,body,expressions,scope} = renderPredicate rhead <> " <- " @@ -685,6 +723,8 @@ data Binary = | BitwiseOr | BitwiseXor | NotEqual + | HeterogeneousEqual + | HeterogeneousNotEqual deriving (Eq, Ord, Show, Lift) data Expression' (ctx :: DatalogContext) = @@ -750,7 +790,7 @@ renderExpression = EBinary GreaterThan e e' -> rOp ">" e e' EBinary LessOrEqual e e' -> rOp "<=" e e' EBinary GreaterOrEqual e e' -> rOp ">=" e e' - EBinary Equal e e' -> rOp "==" e e' + EBinary Equal e e' -> rOp "===" e e' EBinary Contains e e' -> rm "contains" e e' EBinary Prefix e e' -> rm "starts_with" e e' EBinary Suffix e e' -> rm "ends_with" e e' @@ -766,7 +806,9 @@ renderExpression = EBinary BitwiseAnd e e' -> rOp "&" e e' EBinary BitwiseOr e e' -> rOp "|" e e' EBinary BitwiseXor e e' -> rOp "^" e e' - EBinary NotEqual e e' -> rOp "!=" e e' + EBinary NotEqual e e' -> rOp "!==" e e' + EBinary HeterogeneousEqual e e' -> rOp "==" e e' + EBinary HeterogeneousNotEqual e e' -> rOp "!=" e e' -- | A biscuit block, containing facts, rules and checks. -- @@ -1086,6 +1128,7 @@ substitutePTerm termMapping = \case LDate i -> pure $ LDate i LBytes i -> pure $ LBytes i LBool i -> pure $ LBool i + LNull -> pure $ LNull TermSet i -> TermSet . Set.fromList <$> traverse (substituteSetTerm termMapping) (Set.toList i) Variable i -> pure $ Variable i @@ -1100,6 +1143,7 @@ substituteTerm termMapping = \case LDate i -> pure $ LDate i LBytes i -> pure $ LBytes i LBool i -> pure $ LBool i + LNull -> pure $ LNull TermSet i -> TermSet . Set.fromList <$> traverse (substituteSetTerm termMapping) (Set.toList i) Variable v -> absurd v @@ -1114,6 +1158,7 @@ substituteSetTerm termMapping = \case LDate i -> pure $ LDate i LBytes i -> pure $ LBytes i LBool i -> pure $ LBool i + LNull -> pure $ LNull TermSet v -> absurd v Variable v -> absurd v Antiquote (Slice v) -> diff --git a/biscuit/src/Auth/Biscuit/Datalog/Executor.hs b/biscuit/src/Auth/Biscuit/Datalog/Executor.hs index 2a03ff6..c752055 100644 --- a/biscuit/src/Auth/Biscuit/Datalog/Executor.hs +++ b/biscuit/src/Auth/Biscuit/Datalog/Executor.hs @@ -286,6 +286,7 @@ applyBindings p@Predicate{terms} (origins, bindings) = replaceTerm (LDate t) = Just $ LDate t replaceTerm (LBytes t) = Just $ LBytes t replaceTerm (LBool t) = Just $ LBool t + replaceTerm LNull = Just LNull replaceTerm (TermSet t) = Just $ TermSet t replaceTerm (Antiquote t) = absurd t in (\nt -> (origins, p { terms = nt})) <$> newTerms @@ -341,6 +342,7 @@ isSame (LDate t) (LDate t') = t == t' isSame (LBytes t) (LBytes t') = t == t' isSame (LBool t) (LBool t') = t == t' isSame (TermSet t) (TermSet t') = t == t' +isSame LNull LNull = True isSame _ _ = False -- | Given a predicate and a fact, try to match the fact to the predicate, @@ -378,6 +380,7 @@ applyVariable bindings = \case LDate t -> Right $ LDate t LBytes t -> Right $ LBytes t LBool t -> Right $ LBool t + LNull -> Right LNull TermSet t -> Right $ TermSet t Antiquote v -> absurd v @@ -406,6 +409,8 @@ evalBinary _ NotEqual (LBytes t) (LBytes t') = pure $ LBool (t /= t') evalBinary _ NotEqual (LBool t) (LBool t') = pure $ LBool (t /= t') evalBinary _ NotEqual (TermSet t) (TermSet t') = pure $ LBool (t /= t') evalBinary _ NotEqual _ _ = Left "Inequity mismatch" +evalBinary _ HeterogeneousEqual t t' = pure $ LBool (t == t') +evalBinary _ HeterogeneousNotEqual t t' = pure $ LBool (t /= t') evalBinary _ LessThan (LInteger i) (LInteger i') = pure $ LBool (i < i') evalBinary _ LessThan (LDate t) (LDate t') = pure $ LBool (t < t') evalBinary _ LessThan _ _ = Left "< mismatch" diff --git a/biscuit/src/Auth/Biscuit/Datalog/Parser.hs b/biscuit/src/Auth/Biscuit/Datalog/Parser.hs index 2e366ae..74aeec5 100644 --- a/biscuit/src/Auth/Biscuit/Datalog/Parser.hs +++ b/biscuit/src/Auth/Biscuit/Datalog/Parser.hs @@ -136,6 +136,7 @@ termParser parseVar parseSet = l $ choice , False <$ chunk "false" ] "boolean value (eg. true or false)" + , LNull <$ chunk "null" "null value" ] intParser :: Parser Int64 @@ -256,6 +257,8 @@ table = , infixN ">" GreaterThan , infixN "===" Equal , infixN "!==" NotEqual + , infixN "==" HeterogeneousEqual + , infixN "!=" HeterogeneousNotEqual ] , [ infixL "&&" And ] , [ infixL "||" Or ] diff --git a/biscuit/src/Auth/Biscuit/Proto.hs b/biscuit/src/Auth/Biscuit/Proto.hs index 8c4b477..fc716ed 100644 --- a/biscuit/src/Auth/Biscuit/Proto.hs +++ b/biscuit/src/Auth/Biscuit/Proto.hs @@ -29,6 +29,7 @@ module Auth.Biscuit.Proto , TermV2 (..) , ExpressionV2 (..) , TermSet (..) + , Empty (..) , Op (..) , OpUnary (..) , UnaryKind (..) @@ -160,6 +161,11 @@ data TermV2 = | TermBytes (Required 5 (Value ByteString)) | TermBool (Required 6 (Value Bool)) | TermTermSet (Required 7 (Message TermSet)) + | TermNull (Required 8 (Message Empty)) + deriving stock (Generic, Show) + deriving anyclass (Decode, Encode) + +data Empty = Empty {} deriving stock (Generic, Show) deriving anyclass (Decode, Encode) @@ -211,6 +217,8 @@ data BinaryKind = | BitwiseOr | BitwiseXor | NotEqual + | HeterogeneousEqual + | HeterogeneousNotEqual deriving stock (Show, Enum, Bounded) newtype OpBinary = OpBinary diff --git a/biscuit/src/Auth/Biscuit/ProtoBufAdapter.hs b/biscuit/src/Auth/Biscuit/ProtoBufAdapter.hs index d738400..7493539 100644 --- a/biscuit/src/Auth/Biscuit/ProtoBufAdapter.hs +++ b/biscuit/src/Auth/Biscuit/ProtoBufAdapter.hs @@ -132,7 +132,12 @@ pbToBlock ePk PB.Block{..} = do bRules <- traverse (pbToRule s) $ PB.getField rules_v2 bChecks <- traverse (pbToCheck s) $ PB.getField checks_v2 bScope <- Set.fromList <$> traverse (pbToScope s) (PB.getField scope) - let v6Plus = any isReject bChecks + let v6Plus = or + [ any isReject bChecks + , any (not . predicateHasNoV6Values) bFacts + , any (not . ruleHasNoV6Values) bRules + , any (not . checkHasNoV6Values) bChecks + ] v5Plus = isJust ePk v4Plus = not $ and [ Set.null bScope @@ -299,6 +304,7 @@ pbToTerm s = \case PB.TermBool f -> pure $ LBool $ PB.getField f PB.TermVariable f -> Variable <$> getSymbol s (SymbolRef $ PB.getField f) PB.TermTermSet f -> TermSet . Set.fromList <$> traverse (pbToSetValue s) (PB.getField . PB.set $ PB.getField f) + PB.TermNull _ -> pure LNull termToPb :: ReverseSymbols -> Term -> PB.TermV2 termToPb s = \case @@ -309,6 +315,7 @@ termToPb s = \case LBytes v -> PB.TermBytes $ PB.putField v LBool v -> PB.TermBool $ PB.putField v TermSet vs -> PB.TermTermSet $ PB.putField $ PB.TermSet $ PB.putField $ setValueToPb s <$> Set.toList vs + LNull -> PB.TermNull $ PB.putField $ PB.Empty {} Antiquote v -> absurd v @@ -321,6 +328,7 @@ pbToValue s = \case PB.TermBool f -> pure $ LBool $ PB.getField f PB.TermVariable _ -> Left "Variables can't appear in facts" PB.TermTermSet f -> TermSet . Set.fromList <$> traverse (pbToSetValue s) (PB.getField . PB.set $ PB.getField f) + PB.TermNull _ -> pure LNull valueToPb :: ReverseSymbols -> Value -> PB.TermV2 valueToPb s = \case @@ -330,6 +338,7 @@ valueToPb s = \case LBytes v -> PB.TermBytes $ PB.putField v LBool v -> PB.TermBool $ PB.putField v TermSet vs -> PB.TermTermSet $ PB.putField $ PB.TermSet $ PB.putField $ setValueToPb s <$> Set.toList vs + LNull -> PB.TermNull $ PB.putField PB.Empty Variable v -> absurd v Antiquote v -> absurd v @@ -341,6 +350,7 @@ pbToSetValue s = \case PB.TermDate f -> pure $ LDate $ pbTimeToUtcTime $ PB.getField f PB.TermBytes f -> pure $ LBytes $ PB.getField f PB.TermBool f -> pure $ LBool $ PB.getField f + PB.TermNull _ -> pure $ LNull PB.TermVariable _ -> Left "Variables can't appear in facts or sets" PB.TermTermSet _ -> Left "Sets can't be nested" @@ -351,6 +361,7 @@ setValueToPb s = \case LDate v -> PB.TermDate $ PB.putField $ round $ utcTimeToPOSIXSeconds v LBytes v -> PB.TermBytes $ PB.putField v LBool v -> PB.TermBool $ PB.putField v + LNull -> PB.TermNull $ PB.putField $ PB.Empty {} TermSet v -> absurd v Variable v -> absurd v @@ -392,27 +403,29 @@ unaryToPb = PB.OpUnary . PB.putField . \case pbToBinary :: PB.OpBinary -> Binary pbToBinary PB.OpBinary{kind} = case PB.getField kind of - PB.LessThan -> LessThan - PB.GreaterThan -> GreaterThan - PB.LessOrEqual -> LessOrEqual - PB.GreaterOrEqual -> GreaterOrEqual - PB.Equal -> Equal - PB.Contains -> Contains - PB.Prefix -> Prefix - PB.Suffix -> Suffix - PB.Regex -> Regex - PB.Add -> Add - PB.Sub -> Sub - PB.Mul -> Mul - PB.Div -> Div - PB.And -> And - PB.Or -> Or - PB.Intersection -> Intersection - PB.Union -> Union - PB.BitwiseAnd -> BitwiseAnd - PB.BitwiseOr -> BitwiseOr - PB.BitwiseXor -> BitwiseXor - PB.NotEqual -> NotEqual + PB.LessThan -> LessThan + PB.GreaterThan -> GreaterThan + PB.LessOrEqual -> LessOrEqual + PB.GreaterOrEqual -> GreaterOrEqual + PB.Equal -> Equal + PB.Contains -> Contains + PB.Prefix -> Prefix + PB.Suffix -> Suffix + PB.Regex -> Regex + PB.Add -> Add + PB.Sub -> Sub + PB.Mul -> Mul + PB.Div -> Div + PB.And -> And + PB.Or -> Or + PB.Intersection -> Intersection + PB.Union -> Union + PB.BitwiseAnd -> BitwiseAnd + PB.BitwiseOr -> BitwiseOr + PB.BitwiseXor -> BitwiseXor + PB.NotEqual -> NotEqual + PB.HeterogeneousEqual -> HeterogeneousEqual + PB.HeterogeneousNotEqual -> HeterogeneousNotEqual binaryToPb :: Binary -> PB.OpBinary binaryToPb = PB.OpBinary . PB.putField . \case @@ -437,6 +450,8 @@ binaryToPb = PB.OpBinary . PB.putField . \case BitwiseOr -> PB.BitwiseOr BitwiseXor -> PB.BitwiseXor NotEqual -> PB.NotEqual + HeterogeneousEqual -> PB.HeterogeneousEqual + HeterogeneousNotEqual -> PB.HeterogeneousNotEqual pbToThirdPartyBlockRequest :: PB.ThirdPartyBlockRequest -> Either String Crypto.Signature pbToThirdPartyBlockRequest PB.ThirdPartyBlockRequest{legacyPk, pkTable, prevSig} = do diff --git a/biscuit/test/Spec/Parser.hs b/biscuit/test/Spec/Parser.hs index fde936c..73fa711 100644 --- a/biscuit/test/Spec/Parser.hs +++ b/biscuit/test/Spec/Parser.hs @@ -67,6 +67,7 @@ parseBlock = runWithNoParams blockParser $ substituteBlock mempty mempty specs :: TestTree specs = testGroup "datalog parser" [ factWithDate + , factWithNull , simpleFact , oneLetterFact , simpleRule @@ -134,6 +135,11 @@ factWithDate = testCase "Parse fact containing a date" $ do parsePredicate "date(2019-12-02T13:49:53+00:00)" @?= Right (Predicate "date" [LDate $ read "2019-12-02 13:49:53 UTC"]) +factWithNull :: TestTree +factWithNull = testCase "Parse fact containing a null value" $ do + parsePredicate "date(null)" @?= + Right (Predicate "date" [LNull]) + simpleRule :: TestTree simpleRule = testCase "Parse simple rule" $ parseRule "right($0, \"read\") <- resource( $0), operation(\"read\")" @?= @@ -218,6 +224,18 @@ constraints = testGroup "Parse expressions" (EValue (Variable "0")) (EValue (LInteger 1)) ) + , testCase "int comparison (HEQ)" $ + parseExpression "$0 == 1" @?= + Right (EBinary HeterogeneousEqual + (EValue (Variable "0")) + (EValue (LInteger 1)) + ) + , testCase "int comparison (HNEQ)" $ + parseExpression "$0 != 1" @?= + Right (EBinary HeterogeneousNotEqual + (EValue (Variable "0")) + (EValue (LInteger 1)) + ) , testCase "negative int comparison (GTE)" $ parseExpression "$0 >= -1234" @?= Right (EBinary GreaterOrEqual diff --git a/biscuit/test/Spec/SampleReader.hs b/biscuit/test/Spec/SampleReader.hs index 5034dca..882d0a4 100644 --- a/biscuit/test/Spec/SampleReader.hs +++ b/biscuit/test/Spec/SampleReader.hs @@ -227,7 +227,7 @@ processTestCase step rootPk TestCase{..} = if fst filename == "test018_unbound_variables_in_rule.bc" then step "Skipping for now (unbound variables are now caught before evaluation)" - else if fst filename `elem` ["test030_null.bc", "test031_heterogeneous_equal.bc", "test032_laziness_closures.bc", "test033_typeof.bc", "test034_array_map.bc", "test035_ffi.bc", "test036_secp256r1.bc"] + else if fst filename `elem` ["test032_laziness_closures.bc", "test033_typeof.bc", "test034_array_map.bc", "test035_ffi.bc", "test036_secp256r1.bc"] then step "Skipping for now (not supported yet)" else do