Skip to content

Commit

Permalink
Replaced Conditions with cel::Predicates
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Snaps <[email protected]>
  • Loading branch information
alexsnaps committed Nov 27, 2024
1 parent d2084cf commit ad9c90d
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 90 deletions.
6 changes: 3 additions & 3 deletions limitador/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@
//! "my_namespace",
//! 2,
//! 60,
//! vec!["req.method == 'GET'"],
//! vec!["req_method == 'GET'"],
//! vec!["user_id"],
//! ).unwrap();
//! rate_limiter.add_limit(limit);
//!
//! // We've defined a limit of 2. So we can report 2 times before being
//! // rate-limited
//! let mut values_to_report: HashMap<String, String> = HashMap::new();
//! values_to_report.insert("req.method".to_string(), "GET".to_string());
//! values_to_report.insert("req_method".to_string(), "GET".to_string());
//! values_to_report.insert("user_id".to_string(), "1".to_string());
//!
//! // Check if we can report
Expand Down Expand Up @@ -167,7 +167,7 @@
//! "my_namespace",
//! 10,
//! 60,
//! vec!["req.method == 'GET'"],
//! vec!["req_method == 'GET'"],
//! vec!["user_id"],
//! ).unwrap();
//!
Expand Down
33 changes: 17 additions & 16 deletions limitador/src/limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub struct Limit {

// Need to sort to generate the same object when using the JSON as a key or
// value in Redis.
conditions: BTreeSet<Condition>,
conditions: BTreeSet<cel::Predicate>,
variables: BTreeSet<String>,
}

Expand Down Expand Up @@ -273,6 +273,7 @@ pub enum Predicate {
NotEqual,
}

#[allow(dead_code)]
impl Predicate {
fn test(&self, lhs: &str, rhs: &str) -> bool {
match self {
Expand Down Expand Up @@ -305,8 +306,11 @@ impl Limit {
LimitadorError: From<<T as TryInto<Condition>>::Error>,
{
// the above where-clause is needed in order to call unwrap().
let conditions: Result<BTreeSet<_>, _> =
conditions.into_iter().map(|cond| cond.try_into()).collect();
let conditions: Result<BTreeSet<_>, _> = conditions
.into_iter()
.map(|cond| cond.try_into())
.map(|r| r.map(|c| cel::Predicate::parse::<String>(c.into()).unwrap()))
.collect();
match conditions {
Ok(conditions) => Ok(Self {
id: None,
Expand All @@ -332,7 +336,12 @@ impl Limit {
where
LimitadorError: From<<T as TryInto<Condition>>::Error>,
{
match conditions.into_iter().map(|cond| cond.try_into()).collect() {
match conditions
.into_iter()
.map(|cond| cond.try_into())
.map(|r| r.map(|c| cel::Predicate::parse::<String>(c.into()).unwrap()))
.collect()
{
Ok(conditions) => Ok(Self {
id: Some(id.into()),
namespace: namespace.into(),
Expand Down Expand Up @@ -400,25 +409,16 @@ impl Limit {
}

pub fn applies(&self, values: &HashMap<String, String>) -> bool {
let ctx = Context::new(self, String::default(), values);
let all_conditions_apply = self
.conditions
.iter()
.all(|cond| Self::condition_applies(cond, values));
.all(|predicate| predicate.test(&ctx).unwrap());

let all_vars_are_set = self.variables.iter().all(|var| values.contains_key(var));

all_conditions_apply && all_vars_are_set
}

fn condition_applies(condition: &Condition, values: &HashMap<String, String>) -> bool {
let left_operand = condition.var_name.as_str();
let right_operand = condition.operand.as_str();

match values.get(left_operand) {
Some(val) => condition.predicate.test(val, right_operand),
None => false,
}
}
}

impl Hash for Limit {
Expand Down Expand Up @@ -855,6 +855,7 @@ mod conditions {
}
}

use crate::limit::cel::Context;
pub use cel::Expression as CelExpression;
pub use cel::ParseError;
pub use cel::Predicate as CelPredicate;
Expand Down Expand Up @@ -1086,7 +1087,7 @@ mod tests {
limit1.namespace.clone(),
limit1.max_value + 10,
limit1.seconds,
limit1.conditions.clone(),
vec!["req.method == 'GET'"],
limit1.variables.clone(),
)
.expect("This must be a valid limit!");
Expand Down
111 changes: 96 additions & 15 deletions limitador/src/limit/cel.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use crate::limit::cel::errors::EvaluationError;
use crate::limit::Limit;
use cel_interpreter::{ExecutionError, Value};
pub use errors::ParseError;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::Arc;

pub(super) mod errors {
use cel_interpreter::ExecutionError;
Expand Down Expand Up @@ -69,7 +73,40 @@ pub(super) mod errors {
}
}

pub struct Context {}
pub struct Context<'a> {
variables: HashSet<String>,
ctx: cel_interpreter::Context<'a>,
}

impl Context<'_> {
pub(crate) fn new(limit: &Limit, root: String, values: &HashMap<String, String>) -> Self {
let mut ctx = cel_interpreter::Context::default();

if root.is_empty() {
for (binding, value) in values {
ctx.add_variable_from_value(binding, value.clone())
}
} else {
let map = cel_interpreter::objects::Map::from(values.clone());
ctx.add_variable_from_value(root, Value::Map(map));
}

let limit_data = cel_interpreter::objects::Map::from(HashMap::from([(
"name",
limit
.name
.as_ref()
.map(|n| Value::String(Arc::new(n.to_string())))
.unwrap_or(Value::Null),
)]));
ctx.add_variable_from_value("limit", Value::Map(limit_data));

Self {
variables: values.keys().cloned().collect(),
ctx,
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
Expand Down Expand Up @@ -99,9 +136,8 @@ impl Expression {
}
}

pub fn resolve(&self, _ctx: &Context) -> Result<Value, ExecutionError> {
let ctx = cel_interpreter::Context::default();
Value::resolve(&self.expression, &ctx)
pub fn resolve(&self, ctx: &Context) -> Result<Value, ExecutionError> {
Value::resolve(&self.expression, &ctx.ctx)
}
}

Expand Down Expand Up @@ -158,16 +194,33 @@ impl Ord for Expression {
}
}

#[derive(Clone, Debug, Serialize)]
pub struct Predicate(Expression);
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct Predicate {
#[serde(skip_serializing, default)]
variables: HashSet<String>,
expression: Expression,
}

impl Predicate {
pub fn parse<T: ToString>(source: T) -> Result<Self, ParseError> {
Expression::parse(source).map(Self)
Expression::parse(source).map(|e| Self {
variables: e
.expression
.references()
.variables()
.into_iter()
.map(String::from)
.collect(),
expression: e,
})
}

pub fn test(&self, ctx: &Context) -> Result<bool, EvaluationError> {
match self.0.resolve(ctx)? {
if !self.variables.iter().all(|v| ctx.variables.contains(v)) {
return Ok(false);
}
match self.expression.resolve(ctx)? {
Value::Bool(b) => Ok(b),
v => Err(err_on_value(v)),
}
Expand All @@ -178,7 +231,7 @@ impl Eq for Predicate {}

impl PartialEq<Self> for Predicate {
fn eq(&self, other: &Self) -> bool {
self.0.source == other.0.source
self.expression.source == other.expression.source
}
}

Expand All @@ -190,18 +243,39 @@ impl PartialOrd<Self> for Predicate {

impl Ord for Predicate {
fn cmp(&self, other: &Self) -> Ordering {
self.0.cmp(&other.0)
self.expression.cmp(&other.expression)
}
}

impl Hash for Predicate {
fn hash<H: Hasher>(&self, state: &mut H) {
self.expression.source.hash(state);
}
}

impl TryFrom<String> for Predicate {
type Error = ParseError;

fn try_from(value: String) -> Result<Self, Self::Error> {
Self::parse(value)
}
}

impl From<Predicate> for String {
fn from(value: Predicate) -> Self {
value.expression.source
}
}

#[cfg(test)]
mod tests {
use super::{Context, Expression, Predicate};
use std::collections::HashSet;

#[test]
fn expression() {
let exp = Expression::parse("100").expect("failed to parse");
assert_eq!(exp.eval(&Context {}), Ok(String::from("100")));
assert_eq!(exp.eval(&ctx()), Ok(String::from("100")));
}

#[test]
Expand All @@ -210,30 +284,37 @@ mod tests {
let serialized = serde_json::to_string(&exp).expect("failed to serialize");
let deserialized: Expression =
serde_json::from_str(&serialized).expect("failed to deserialize");
assert_eq!(exp.eval(&Context {}), deserialized.eval(&Context {}));
assert_eq!(exp.eval(&ctx()), deserialized.eval(&ctx()));
}

#[test]
fn unexpected_value_type_expression() {
let exp = Expression::parse("['100']").expect("failed to parse");
assert_eq!(
exp.eval(&Context {}).map_err(|e| format!("{e}")),
exp.eval(&ctx()).map_err(|e| format!("{e}")),
Err("unexpected value of type list: `[String(\"100\")]`".to_string())
);
}

#[test]
fn predicate() {
let pred = Predicate::parse("42 == uint('42')").expect("failed to parse");
assert_eq!(pred.test(&Context {}), Ok(true));
assert_eq!(pred.test(&ctx()), Ok(true));
}

#[test]
fn unexpected_value_predicate() {
let pred = Predicate::parse("42").expect("failed to parse");
assert_eq!(
pred.test(&Context {}).map_err(|e| format!("{e}")),
pred.test(&ctx()).map_err(|e| format!("{e}")),
Err("unexpected value of type integer: `42`".to_string())
);
}

fn ctx<'a>() -> Context<'a> {
Context {
variables: HashSet::default(),
ctx: cel_interpreter::Context::default(),
}
}
}
Loading

0 comments on commit ad9c90d

Please sign in to comment.