Skip to content

Commit

Permalink
Runtime: Fix missing security policy row filters for new filter expre…
Browse files Browse the repository at this point in the history
…ssions (#3753)

* Runtime: Fix missing security policy row filters for new filter expressions

* Fix column verification

* Fix tests
  • Loading branch information
begelundmuller authored Jan 2, 2024
1 parent 711889d commit 8dbfd93
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 21 deletions.
13 changes: 8 additions & 5 deletions runtime/queries/metricsview.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,27 +203,30 @@ func dimensionSelect(mv *runtimev1.MetricsViewSpec, dim *runtimev1.MetricsViewSp
}

func buildExpression(mv *runtimev1.MetricsViewSpec, expr *runtimev1.Expression, aliases []*runtimev1.MetricsViewComparisonMeasureAlias, dialect drivers.Dialect) (string, []any, error) {
var emptyArg []any
if expr == nil {
return "", nil, nil
}

switch e := expr.Expression.(type) {
case *runtimev1.Expression_Val:
arg, err := pbutil.FromValue(e.Val)
if err != nil {
return "", emptyArg, err
return "", nil, err
}
return "?", []any{arg}, nil

case *runtimev1.Expression_Ident:
expr, isIdent := columnIdentifierExpression(mv, aliases, e.Ident, dialect)
if !isIdent {
return "", emptyArg, fmt.Errorf("unknown column filter: %s", e.Ident)
return "", nil, fmt.Errorf("unknown column filter: %s", e.Ident)
}
return expr, emptyArg, nil
return expr, nil, nil

case *runtimev1.Expression_Cond:
return buildConditionExpression(mv, e.Cond, aliases, dialect)
}

return "", emptyArg, nil
return "", nil, nil
}

func buildConditionExpression(mv *runtimev1.MetricsViewSpec, cond *runtimev1.Condition, aliases []*runtimev1.MetricsViewComparisonMeasureAlias, dialect drivers.Dialect) (string, []any, error) {
Expand Down
3 changes: 3 additions & 0 deletions runtime/queries/metricsview_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ func (q *MetricsViewAggregation) buildMetricsAggregationSQL(mv *runtimev1.Metric
whereClause += " AND " + clause
args = append(args, clauseArgs...)
}
if policy != nil && policy.RowFilter != "" {
whereClause += fmt.Sprintf(" AND (%s)", policy.RowFilter)
}
if len(whereClause) > 0 {
whereClause = "WHERE 1=1" + whereClause
}
Expand Down
34 changes: 19 additions & 15 deletions runtime/queries/metricsview_comparison_toplist.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,10 @@ func (q *MetricsViewComparison) buildMetricsTopListSQL(mv *runtimev1.MetricsView
args = append(args, clauseArgs...)
}

if policy != nil && policy.RowFilter != "" {
baseWhereClause += fmt.Sprintf(" AND (%s)", policy.RowFilter)
}

havingClause := ""
if q.Having != nil {
var havingClauseArgs []any
Expand Down Expand Up @@ -504,20 +508,20 @@ func (q *MetricsViewComparison) buildMetricsComparisonTopListSQL(mv *runtimev1.M

td := safeName(mv.TimeDimension)

whereClause, whereClauseArgs, err := buildExpression(mv, q.Where, nil, dialect)
if err != nil {
return "", nil, err
}

trc, err := timeRangeClause(q.TimeRange, mv, dialect, td, &args)
if err != nil {
return "", nil, err
}
baseWhereClause += trc

if q.Where != nil {
clause, clauseArgs, err := buildExpression(mv, q.Where, nil, dialect)
if err != nil {
return "", nil, err
}
baseWhereClause += " AND " + clause

args = append(args, clauseArgs...)
if whereClause != "" {
baseWhereClause += " AND " + whereClause
args = append(args, whereClauseArgs...)
}

trc, err = timeRangeClause(q.ComparisonTimeRange, mv, dialect, td, &args)
Expand All @@ -526,14 +530,14 @@ func (q *MetricsViewComparison) buildMetricsComparisonTopListSQL(mv *runtimev1.M
}
comparisonWhereClause += trc

if q.Where != nil {
clause, clauseArgs, err := buildExpression(mv, q.Where, nil, dialect)
if err != nil {
return "", nil, err
}
comparisonWhereClause += " AND " + clause
if whereClause != "" {
comparisonWhereClause += " AND " + whereClause
args = append(args, whereClauseArgs...)
}

args = append(args, clauseArgs...)
if policy != nil && policy.RowFilter != "" {
baseWhereClause += fmt.Sprintf(" AND (%s)", policy.RowFilter)
comparisonWhereClause += fmt.Sprintf(" AND (%s)", policy.RowFilter)
}

havingClause := "1=1"
Expand Down
4 changes: 4 additions & 0 deletions runtime/queries/metricsview_rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ func (q *MetricsViewRows) buildMetricsRowsSQL(mv *runtimev1.MetricsViewSpec, dia
args = append(args, clauseArgs...)
}

if policy != nil && policy.RowFilter != "" {
whereClause += fmt.Sprintf(" AND (%s)", policy.RowFilter)
}

sortingCriteria := make([]string, 0, len(q.Sort))
for _, s := range q.Sort {
sortCriterion := safeName(s.Name)
Expand Down
4 changes: 4 additions & 0 deletions runtime/queries/metricsview_timeseries.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ func (q *MetricsViewTimeSeries) buildMetricsTimeseriesSQL(olap drivers.OLAPStore
args = append(args, clauseArgs...)
}

if policy != nil && policy.RowFilter != "" {
whereClause += fmt.Sprintf(" AND (%s)", policy.RowFilter)
}

havingClause := ""
if q.Having != nil {
clause, clauseArgs, err := buildExpression(mv, q.Having, nil, olap.Dialect())
Expand Down
4 changes: 4 additions & 0 deletions runtime/queries/metricsview_toplist.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ func (q *MetricsViewToplist) buildMetricsTopListSQL(mv *runtimev1.MetricsViewSpe
args = append(args, clauseArgs...)
}

if policy != nil && policy.RowFilter != "" {
whereClause += fmt.Sprintf(" AND (%s)", policy.RowFilter)
}

havingClause := ""
if q.Having != nil {
var havingClauseArgs []any
Expand Down
4 changes: 4 additions & 0 deletions runtime/queries/metricsview_totals.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ func (q *MetricsViewTotals) buildMetricsTotalsSQL(mv *runtimev1.MetricsViewSpec,
args = append(args, clauseArgs...)
}

if policy != nil && policy.RowFilter != "" {
whereClause += fmt.Sprintf(" AND (%s)", policy.RowFilter)
}

sql := fmt.Sprintf(
"SELECT %s FROM %q WHERE %s",
strings.Join(selectCols, ", "),
Expand Down
2 changes: 1 addition & 1 deletion runtime/reconcilers/metrics_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (r *MetricsViewReconciler) validate(ctx context.Context, mv *runtimev1.Metr

func validateDimension(ctx context.Context, olap drivers.OLAPStore, t *drivers.Table, d *runtimev1.MetricsViewSpec_DimensionV2, fields map[string]*runtimev1.StructType_Field) error {
if d.Column != "" {
if _, isColumn := fields[d.Column]; !isColumn {
if _, isColumn := fields[strings.ToLower(d.Column)]; !isColumn {
return fmt.Errorf("failed to validate dimension %q: column %q not found in table", d.Name, d.Column)
}
return nil
Expand Down

0 comments on commit 8dbfd93

Please sign in to comment.