Skip to content

Commit

Permalink
3.3: support for null and lenient equal
Browse files Browse the repository at this point in the history
(both have been implemented at the same time because the `null` sample depended on the lenient equal comparisons)
  • Loading branch information
divarvel committed Dec 1, 2024
1 parent 8ccb646 commit f7ad685
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 25 deletions.
49 changes: 47 additions & 2 deletions biscuit/src/Auth/Biscuit/Datalog/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ module Auth.Biscuit.Datalog.AST
, queryHasNoV4Operators
, ruleHasNoScope
, ruleHasNoV4Operators
, ruleHasNoV6Values
, predicateHasNoV6Values
, checkHasNoV6Values
, isCheckOne
, isReject
, renderBlock
Expand Down Expand Up @@ -213,6 +216,8 @@ data Term' (inSet :: IsWithinSet) (pof :: PredicateOrFact) (ctx :: DatalogContex
-- ^ A slice (eg. @{name}@)
| TermSet (SetType inSet ctx)
-- ^ A set (eg. @[true, false]@)
| LNull
-- ^ @null@

deriving instance ( Eq (VariableType inSet pof)
, Eq (SliceType ctx)
Expand Down Expand Up @@ -250,6 +255,7 @@ instance ( Lift (VariableType inSet pof)
lift (LBool b) = [| LBool b |]
lift (TermSet terms) = [| TermSet terms |]
lift (LDate t) = [| LDate (read $(lift $ show t)) |]
lift LNull = [| LNull |]
lift (Antiquote s) = [| s |]

#if MIN_VERSION_template_haskell(2,17,0)
Expand Down Expand Up @@ -324,6 +330,7 @@ valueToSetTerm = \case
LDate i -> Just $ LDate i
LBytes i -> Just $ LBytes i
LBool i -> Just $ LBool i
LNull -> Just LNull
TermSet _ -> Nothing
Variable v -> absurd v
Antiquote v -> absurd v
Expand All @@ -335,6 +342,7 @@ valueToTerm = \case
LDate i -> LDate i
LBytes i -> LBytes i
LBool i -> LBool i
LNull -> LNull
TermSet i -> TermSet i
Variable v -> absurd v
Antiquote v -> absurd v
Expand All @@ -351,6 +359,7 @@ renderId' var set slice = \case
LBytes bs -> "hex:" <> encodeHex bs
LBool True -> "true"
LBool False -> "false"
LNull -> "null"
TermSet terms -> set terms
Antiquote v -> slice v

Expand Down Expand Up @@ -609,6 +618,35 @@ ruleHasNoV4Operators :: Rule -> Bool
ruleHasNoV4Operators Rule{expressions} =
all expressionHasNoV4Operators expressions

expressionHasNoV6ValuesOrOperators :: Expression -> Bool
expressionHasNoV6ValuesOrOperators = \case
EBinary HeterogeneousEqual _ _ -> False
EBinary HeterogeneousNotEqual _ _ -> False
EBinary _ l r -> expressionHasNoV6ValuesOrOperators l && expressionHasNoV6ValuesOrOperators r
EUnary _ l -> expressionHasNoV6ValuesOrOperators l
EValue LNull -> False
EValue _ -> True

ruleHasNoV6Values :: Rule -> Bool
ruleHasNoV6Values Rule{rhead, body, expressions} =
predicateHasNoV6Values rhead
&& all predicateHasNoV6Values body
&& all expressionHasNoV6ValuesOrOperators expressions

predicateHasNoV6Values :: Predicate' a b -> Bool
predicateHasNoV6Values Predicate{terms} =
let hasV6 = \case
LNull -> True
_ -> False
in all (not . hasV6) terms

checkHasNoV6Values :: Check -> Bool
checkHasNoV6Values Check{cQueries} =
let hasNoV6 QueryItem{qBody, qExpressions} =
all predicateHasNoV6Values qBody
&& all expressionHasNoV6ValuesOrOperators qExpressions
in all hasNoV6 cQueries

renderRule :: Rule -> Text
renderRule Rule{rhead,body,expressions,scope} =
renderPredicate rhead <> " <- "
Expand Down Expand Up @@ -685,6 +723,8 @@ data Binary =
| BitwiseOr
| BitwiseXor
| NotEqual
| HeterogeneousEqual
| HeterogeneousNotEqual
deriving (Eq, Ord, Show, Lift)

data Expression' (ctx :: DatalogContext) =
Expand Down Expand Up @@ -750,7 +790,7 @@ renderExpression =
EBinary GreaterThan e e' -> rOp ">" e e'
EBinary LessOrEqual e e' -> rOp "<=" e e'
EBinary GreaterOrEqual e e' -> rOp ">=" e e'
EBinary Equal e e' -> rOp "==" e e'
EBinary Equal e e' -> rOp "===" e e'
EBinary Contains e e' -> rm "contains" e e'
EBinary Prefix e e' -> rm "starts_with" e e'
EBinary Suffix e e' -> rm "ends_with" e e'
Expand All @@ -766,7 +806,9 @@ renderExpression =
EBinary BitwiseAnd e e' -> rOp "&" e e'
EBinary BitwiseOr e e' -> rOp "|" e e'
EBinary BitwiseXor e e' -> rOp "^" e e'
EBinary NotEqual e e' -> rOp "!=" e e'
EBinary NotEqual e e' -> rOp "!==" e e'
EBinary HeterogeneousEqual e e' -> rOp "==" e e'
EBinary HeterogeneousNotEqual e e' -> rOp "!=" e e'

-- | A biscuit block, containing facts, rules and checks.
--
Expand Down Expand Up @@ -1086,6 +1128,7 @@ substitutePTerm termMapping = \case
LDate i -> pure $ LDate i
LBytes i -> pure $ LBytes i
LBool i -> pure $ LBool i
LNull -> pure $ LNull
TermSet i ->
TermSet . Set.fromList <$> traverse (substituteSetTerm termMapping) (Set.toList i)
Variable i -> pure $ Variable i
Expand All @@ -1100,6 +1143,7 @@ substituteTerm termMapping = \case
LDate i -> pure $ LDate i
LBytes i -> pure $ LBytes i
LBool i -> pure $ LBool i
LNull -> pure $ LNull
TermSet i ->
TermSet . Set.fromList <$> traverse (substituteSetTerm termMapping) (Set.toList i)
Variable v -> absurd v
Expand All @@ -1114,6 +1158,7 @@ substituteSetTerm termMapping = \case
LDate i -> pure $ LDate i
LBytes i -> pure $ LBytes i
LBool i -> pure $ LBool i
LNull -> pure $ LNull
TermSet v -> absurd v
Variable v -> absurd v
Antiquote (Slice v) ->
Expand Down
5 changes: 5 additions & 0 deletions biscuit/src/Auth/Biscuit/Datalog/Executor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ applyBindings p@Predicate{terms} (origins, bindings) =
replaceTerm (LDate t) = Just $ LDate t
replaceTerm (LBytes t) = Just $ LBytes t
replaceTerm (LBool t) = Just $ LBool t
replaceTerm LNull = Just LNull
replaceTerm (TermSet t) = Just $ TermSet t
replaceTerm (Antiquote t) = absurd t
in (\nt -> (origins, p { terms = nt})) <$> newTerms
Expand Down Expand Up @@ -341,6 +342,7 @@ isSame (LDate t) (LDate t') = t == t'
isSame (LBytes t) (LBytes t') = t == t'
isSame (LBool t) (LBool t') = t == t'
isSame (TermSet t) (TermSet t') = t == t'
isSame LNull LNull = True
isSame _ _ = False

-- | Given a predicate and a fact, try to match the fact to the predicate,
Expand Down Expand Up @@ -378,6 +380,7 @@ applyVariable bindings = \case
LDate t -> Right $ LDate t
LBytes t -> Right $ LBytes t
LBool t -> Right $ LBool t
LNull -> Right LNull
TermSet t -> Right $ TermSet t
Antiquote v -> absurd v

Expand Down Expand Up @@ -406,6 +409,8 @@ evalBinary _ NotEqual (LBytes t) (LBytes t') = pure $ LBool (t /= t')
evalBinary _ NotEqual (LBool t) (LBool t') = pure $ LBool (t /= t')
evalBinary _ NotEqual (TermSet t) (TermSet t') = pure $ LBool (t /= t')
evalBinary _ NotEqual _ _ = Left "Inequity mismatch"
evalBinary _ HeterogeneousEqual t t' = pure $ LBool (t == t')
evalBinary _ HeterogeneousNotEqual t t' = pure $ LBool (t /= t')
evalBinary _ LessThan (LInteger i) (LInteger i') = pure $ LBool (i < i')
evalBinary _ LessThan (LDate t) (LDate t') = pure $ LBool (t < t')
evalBinary _ LessThan _ _ = Left "< mismatch"
Expand Down
3 changes: 3 additions & 0 deletions biscuit/src/Auth/Biscuit/Datalog/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ termParser parseVar parseSet = l $ choice
, False <$ chunk "false"
]
<?> "boolean value (eg. true or false)"
, LNull <$ chunk "null" <?> "null value"
]

intParser :: Parser Int64
Expand Down Expand Up @@ -256,6 +257,8 @@ table =
, infixN ">" GreaterThan
, infixN "===" Equal
, infixN "!==" NotEqual
, infixN "==" HeterogeneousEqual
, infixN "!=" HeterogeneousNotEqual
]
, [ infixL "&&" And ]
, [ infixL "||" Or ]
Expand Down
8 changes: 8 additions & 0 deletions biscuit/src/Auth/Biscuit/Proto.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ module Auth.Biscuit.Proto
, TermV2 (..)
, ExpressionV2 (..)
, TermSet (..)
, Empty (..)
, Op (..)
, OpUnary (..)
, UnaryKind (..)
Expand Down Expand Up @@ -160,6 +161,11 @@ data TermV2 =
| TermBytes (Required 5 (Value ByteString))
| TermBool (Required 6 (Value Bool))
| TermTermSet (Required 7 (Message TermSet))
| TermNull (Required 8 (Message Empty))
deriving stock (Generic, Show)
deriving anyclass (Decode, Encode)

data Empty = Empty {}
deriving stock (Generic, Show)
deriving anyclass (Decode, Encode)

Expand Down Expand Up @@ -211,6 +217,8 @@ data BinaryKind =
| BitwiseOr
| BitwiseXor
| NotEqual
| HeterogeneousEqual
| HeterogeneousNotEqual
deriving stock (Show, Enum, Bounded)

newtype OpBinary = OpBinary
Expand Down
59 changes: 37 additions & 22 deletions biscuit/src/Auth/Biscuit/ProtoBufAdapter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,12 @@ pbToBlock ePk PB.Block{..} = do
bRules <- traverse (pbToRule s) $ PB.getField rules_v2
bChecks <- traverse (pbToCheck s) $ PB.getField checks_v2
bScope <- Set.fromList <$> traverse (pbToScope s) (PB.getField scope)
let v6Plus = any isReject bChecks
let v6Plus = or
[ any isReject bChecks
, any (not . predicateHasNoV6Values) bFacts
, any (not . ruleHasNoV6Values) bRules
, any (not . checkHasNoV6Values) bChecks
]
v5Plus = isJust ePk
v4Plus = not $ and
[ Set.null bScope
Expand Down Expand Up @@ -299,6 +304,7 @@ pbToTerm s = \case
PB.TermBool f -> pure $ LBool $ PB.getField f
PB.TermVariable f -> Variable <$> getSymbol s (SymbolRef $ PB.getField f)
PB.TermTermSet f -> TermSet . Set.fromList <$> traverse (pbToSetValue s) (PB.getField . PB.set $ PB.getField f)
PB.TermNull _ -> pure LNull

termToPb :: ReverseSymbols -> Term -> PB.TermV2
termToPb s = \case
Expand All @@ -309,6 +315,7 @@ termToPb s = \case
LBytes v -> PB.TermBytes $ PB.putField v
LBool v -> PB.TermBool $ PB.putField v
TermSet vs -> PB.TermTermSet $ PB.putField $ PB.TermSet $ PB.putField $ setValueToPb s <$> Set.toList vs
LNull -> PB.TermNull $ PB.putField $ PB.Empty {}

Antiquote v -> absurd v

Expand All @@ -321,6 +328,7 @@ pbToValue s = \case
PB.TermBool f -> pure $ LBool $ PB.getField f
PB.TermVariable _ -> Left "Variables can't appear in facts"
PB.TermTermSet f -> TermSet . Set.fromList <$> traverse (pbToSetValue s) (PB.getField . PB.set $ PB.getField f)
PB.TermNull _ -> pure LNull

valueToPb :: ReverseSymbols -> Value -> PB.TermV2
valueToPb s = \case
Expand All @@ -330,6 +338,7 @@ valueToPb s = \case
LBytes v -> PB.TermBytes $ PB.putField v
LBool v -> PB.TermBool $ PB.putField v
TermSet vs -> PB.TermTermSet $ PB.putField $ PB.TermSet $ PB.putField $ setValueToPb s <$> Set.toList vs
LNull -> PB.TermNull $ PB.putField PB.Empty

Variable v -> absurd v
Antiquote v -> absurd v
Expand All @@ -341,6 +350,7 @@ pbToSetValue s = \case
PB.TermDate f -> pure $ LDate $ pbTimeToUtcTime $ PB.getField f
PB.TermBytes f -> pure $ LBytes $ PB.getField f
PB.TermBool f -> pure $ LBool $ PB.getField f
PB.TermNull _ -> pure $ LNull
PB.TermVariable _ -> Left "Variables can't appear in facts or sets"
PB.TermTermSet _ -> Left "Sets can't be nested"

Expand All @@ -351,6 +361,7 @@ setValueToPb s = \case
LDate v -> PB.TermDate $ PB.putField $ round $ utcTimeToPOSIXSeconds v
LBytes v -> PB.TermBytes $ PB.putField v
LBool v -> PB.TermBool $ PB.putField v
LNull -> PB.TermNull $ PB.putField $ PB.Empty {}

TermSet v -> absurd v
Variable v -> absurd v
Expand Down Expand Up @@ -392,27 +403,29 @@ unaryToPb = PB.OpUnary . PB.putField . \case

pbToBinary :: PB.OpBinary -> Binary
pbToBinary PB.OpBinary{kind} = case PB.getField kind of
PB.LessThan -> LessThan
PB.GreaterThan -> GreaterThan
PB.LessOrEqual -> LessOrEqual
PB.GreaterOrEqual -> GreaterOrEqual
PB.Equal -> Equal
PB.Contains -> Contains
PB.Prefix -> Prefix
PB.Suffix -> Suffix
PB.Regex -> Regex
PB.Add -> Add
PB.Sub -> Sub
PB.Mul -> Mul
PB.Div -> Div
PB.And -> And
PB.Or -> Or
PB.Intersection -> Intersection
PB.Union -> Union
PB.BitwiseAnd -> BitwiseAnd
PB.BitwiseOr -> BitwiseOr
PB.BitwiseXor -> BitwiseXor
PB.NotEqual -> NotEqual
PB.LessThan -> LessThan
PB.GreaterThan -> GreaterThan
PB.LessOrEqual -> LessOrEqual
PB.GreaterOrEqual -> GreaterOrEqual
PB.Equal -> Equal
PB.Contains -> Contains
PB.Prefix -> Prefix
PB.Suffix -> Suffix
PB.Regex -> Regex
PB.Add -> Add
PB.Sub -> Sub
PB.Mul -> Mul
PB.Div -> Div
PB.And -> And
PB.Or -> Or
PB.Intersection -> Intersection
PB.Union -> Union
PB.BitwiseAnd -> BitwiseAnd
PB.BitwiseOr -> BitwiseOr
PB.BitwiseXor -> BitwiseXor
PB.NotEqual -> NotEqual
PB.HeterogeneousEqual -> HeterogeneousEqual
PB.HeterogeneousNotEqual -> HeterogeneousNotEqual

binaryToPb :: Binary -> PB.OpBinary
binaryToPb = PB.OpBinary . PB.putField . \case
Expand All @@ -437,6 +450,8 @@ binaryToPb = PB.OpBinary . PB.putField . \case
BitwiseOr -> PB.BitwiseOr
BitwiseXor -> PB.BitwiseXor
NotEqual -> PB.NotEqual
HeterogeneousEqual -> PB.HeterogeneousEqual
HeterogeneousNotEqual -> PB.HeterogeneousNotEqual

pbToThirdPartyBlockRequest :: PB.ThirdPartyBlockRequest -> Either String Crypto.Signature
pbToThirdPartyBlockRequest PB.ThirdPartyBlockRequest{legacyPk, pkTable, prevSig} = do
Expand Down
18 changes: 18 additions & 0 deletions biscuit/test/Spec/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ parseBlock = runWithNoParams blockParser $ substituteBlock mempty mempty
specs :: TestTree
specs = testGroup "datalog parser"
[ factWithDate
, factWithNull
, simpleFact
, oneLetterFact
, simpleRule
Expand Down Expand Up @@ -134,6 +135,11 @@ factWithDate = testCase "Parse fact containing a date" $ do
parsePredicate "date(2019-12-02T13:49:53+00:00)" @?=
Right (Predicate "date" [LDate $ read "2019-12-02 13:49:53 UTC"])

factWithNull :: TestTree
factWithNull = testCase "Parse fact containing a null value" $ do
parsePredicate "date(null)" @?=
Right (Predicate "date" [LNull])

simpleRule :: TestTree
simpleRule = testCase "Parse simple rule" $
parseRule "right($0, \"read\") <- resource( $0), operation(\"read\")" @?=
Expand Down Expand Up @@ -218,6 +224,18 @@ constraints = testGroup "Parse expressions"
(EValue (Variable "0"))
(EValue (LInteger 1))
)
, testCase "int comparison (HEQ)" $
parseExpression "$0 == 1" @?=
Right (EBinary HeterogeneousEqual
(EValue (Variable "0"))
(EValue (LInteger 1))
)
, testCase "int comparison (HNEQ)" $
parseExpression "$0 != 1" @?=
Right (EBinary HeterogeneousNotEqual
(EValue (Variable "0"))
(EValue (LInteger 1))
)
, testCase "negative int comparison (GTE)" $
parseExpression "$0 >= -1234" @?=
Right (EBinary GreaterOrEqual
Expand Down
2 changes: 1 addition & 1 deletion biscuit/test/Spec/SampleReader.hs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ processTestCase step rootPk TestCase{..} =
if fst filename == "test018_unbound_variables_in_rule.bc"
then
step "Skipping for now (unbound variables are now caught before evaluation)"
else if fst filename `elem` ["test030_null.bc", "test031_heterogeneous_equal.bc", "test032_laziness_closures.bc", "test033_typeof.bc", "test034_array_map.bc", "test035_ffi.bc", "test036_secp256r1.bc"]
else if fst filename `elem` ["test032_laziness_closures.bc", "test033_typeof.bc", "test034_array_map.bc", "test035_ffi.bc", "test036_secp256r1.bc"]
then
step "Skipping for now (not supported yet)"
else do
Expand Down

0 comments on commit f7ad685

Please sign in to comment.