Skip to content

Commit

Permalink
Check that dag_ids passed in request are consistent (apache#34366)
Browse files Browse the repository at this point in the history
There are several ways to pass dag_ids in the request - via args
via kwargs, or via form requests or via json. If you pass several
of those, they should all be the same.
  • Loading branch information
potiuk authored Sep 14, 2023
1 parent b49f4a7 commit 4f1b500
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 7 deletions.
37 changes: 30 additions & 7 deletions airflow/www/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import logging
from functools import wraps
from typing import Callable, Sequence, TypeVar, cast

Expand All @@ -27,6 +28,8 @@

T = TypeVar("T", bound=Callable)

log = logging.getLogger(__name__)


def get_access_denied_message():
return conf.get("webserver", "access_denied_message")
Expand All @@ -42,13 +45,33 @@ def decorated(*args, **kwargs):

appbuilder = current_app.appbuilder

dag_id = (
kwargs.get("dag_id")
or request.args.get("dag_id")
or request.form.get("dag_id")
or (request.is_json and request.json.get("dag_id"))
or None
)
dag_id_kwargs = kwargs.get("dag_id")
dag_id_args = request.args.get("dag_id")
dag_id_form = request.form.get("dag_id")
dag_id_json = request.json.get("dag_id") if request.is_json else None
all_dag_ids = [dag_id_kwargs, dag_id_args, dag_id_form, dag_id_json]
unique_dag_ids = set(dag_id for dag_id in all_dag_ids if dag_id is not None)

if len(unique_dag_ids) > 1:
log.warning(
f"There are different dag_ids passed in the request: {unique_dag_ids}. Returning 403."
)
log.warning(
f"kwargs: {dag_id_kwargs}, args: {dag_id_args}, "
f"form: {dag_id_form}, json: {dag_id_json}"
)
return (
render_template(
"airflow/no_roles_permissions.html",
hostname=get_hostname()
if conf.getboolean("webserver", "EXPOSE_HOSTNAME")
else "redact",
logout_url=get_auth_manager().get_url_logout(),
),
403,
)
dag_id = unique_dag_ids.pop() if unique_dag_ids else None

if appbuilder.sm.check_authorization(permissions, dag_id):
return func(*args, **kwargs)
elif get_auth_manager().is_logged_in() and not g.user.perms:
Expand Down
93 changes: 93 additions & 0 deletions tests/www/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from unittest.mock import patch

import pytest

from airflow.security import permissions
from airflow.settings import json
from tests.test_utils.api_connexion_utils import create_user_scope
from tests.www.test_security import SomeBaseView, SomeModelView


@pytest.fixture(scope="module")
def app_builder(app):
app_builder = app.appbuilder
app_builder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews")
app_builder.add_view(SomeModelView, "SomeModelView", category="ModelViews")
return app.appbuilder


@pytest.mark.parametrize(
"dag_id_args, dag_id_kwargs, dag_id_form, dag_id_json, fail",
[
("a", None, None, None, False),
(None, "b", None, None, False),
(None, None, "c", None, False),
(None, None, None, "d", False),
("a", "a", None, None, False),
("a", "a", "a", None, False),
("a", "a", "a", "a", False),
(None, "a", "a", "a", False),
(None, None, "a", "a", False),
("a", None, None, "a", False),
("a", None, "a", None, False),
("a", None, "c", None, True),
(None, "b", "c", None, True),
(None, None, "c", "d", True),
("a", "b", "c", "d", True),
],
)
def test_dag_id_consistency(
app,
dag_id_args: str | None,
dag_id_kwargs: str | None,
dag_id_form: str | None,
dag_id_json: str | None,
fail: bool,
):
with app.test_request_context() as mock_context:
from airflow.www.auth import has_access

mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args else {}
kwargs = {"dag_id": dag_id_kwargs} if dag_id_kwargs else {}
mock_context.request.form = {"dag_id": dag_id_form} if dag_id_form else {}
if dag_id_json:
mock_context.request._cached_data = json.dumps({"dag_id": dag_id_json})
mock_context.request._parsed_content_type = ["application/json"]

with create_user_scope(
app,
username="test-user",
role_name="limited-role",
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)],
) as user:
with patch("airflow.www.security_manager.g") as mock_g:
mock_g.user = user

@has_access(permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
def test_func(**kwargs):
return True

result = test_func(**kwargs)
if fail:
assert result[1] == 403
else:
assert result is True

0 comments on commit 4f1b500

Please sign in to comment.