Skip to content

Commit

Permalink
change: support strict field for required_file policy and multiple re…
Browse files Browse the repository at this point in the history
…pos in repo_pattern
  • Loading branch information
netomi committed Oct 29, 2024
1 parent 4e8f38a commit a80aa5c
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 54 deletions.
21 changes: 16 additions & 5 deletions otterdog/webapp/policies/required_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,30 @@


class RepoSelector(BaseModel):
name_pattern: str | None
name_pattern: str | list[str] | None

@cached_property
def _pattern(self):
return re.compile(self.name_pattern)
def _pattern(self) -> re.Pattern | None:
if self.name_pattern is None:
return None
elif isinstance(self.name_pattern, str):
return re.compile(self.name_pattern)
else:
return re.compile("|".join(self.name_pattern))

def matches(self, repo: Repository) -> bool:
return self._pattern.match(repo.name)
pattern = self._pattern
if pattern is not None:
return bool(pattern.fullmatch(repo.name))
else:
return False


class RequiredFile(BaseModel):
path: str
repo_selector: RepoSelector
content: str
strict: bool = False


class RequiredFilePolicy(Policy):
Expand All @@ -67,7 +77,7 @@ async def evaluate(self, github_id: str) -> None:
logger.debug(f"checking for required file '{required_file.path}' in repo '{github_id}/{repo.name}'")

title = f"Adding required file {required_file.path}"
body = "This PR has been automatically created by otterdog due to a violated policy."
body = "This PR has been automatically created by otterdog due to a policy."

current_app.add_background_task(
CheckFileTask(
Expand All @@ -76,6 +86,7 @@ async def evaluate(self, github_id: str) -> None:
repo.name,
required_file.path,
required_file.content,
required_file.strict,
"policy",
title,
body,
Expand Down
36 changes: 20 additions & 16 deletions otterdog/webapp/tasks/check_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# SPDX-License-Identifier: EPL-2.0
# *******************************************************************************

import os
from dataclasses import dataclass

from otterdog.providers.github.rest import RestApi
Expand All @@ -20,6 +21,7 @@ class CheckFileTask(InstallationBasedTask, Task[None]):
repo_name: str
path: str
content: str
strict: bool
branch_prefix: str
pr_title: str
pr_body: str
Expand All @@ -42,8 +44,9 @@ async def _execute(self) -> None:
rest_api = await self.rest_api

try:
await rest_api.content.get_content(self.org_id, self.repo_name, self.path)
return
content = await rest_api.content.get_content(self.org_id, self.repo_name, self.path)
if self.strict is False or self.content == content:
return
except RuntimeError:
# file does not exist, so let's create it
pass
Expand All @@ -53,7 +56,8 @@ async def _execute(self) -> None:

async def _create_pull_request_if_necessary(self, rest_api: RestApi) -> None:
default_branch = await rest_api.repo.get_default_branch(self.org_id, self.repo_name)
branch_name = f"otterdog/{self.branch_prefix}/{self.path}"
file_name = os.path.basename(self.path)
branch_name = f"otterdog/{self.branch_prefix}/{file_name}"

try:
await rest_api.reference.get_branch_reference(self.org_id, self.repo_name, branch_name)
Expand All @@ -73,20 +77,20 @@ async def _create_pull_request_if_necessary(self, rest_api: RestApi) -> None:
default_branch_sha,
)

# FIXME: once the otterdog-app is added to the ECA allow list, this can be removed again
short_name = self.org_id if "-" not in self.org_id else self.org_id.partition("-")[2]
# FIXME: once the otterdog-app is added to the ECA allow list, this can be removed again
short_name = self.org_id if "-" not in self.org_id else self.org_id.partition("-")[2]

await rest_api.content.update_content(
self.org_id,
self.repo_name,
self.path,
self.content,
branch_name,
f"Updating file {self.path}",
f"{self.org_id}-bot",
f"{short_name}[email protected]",
author_is_committer=True,
)
await rest_api.content.update_content(
self.org_id,
self.repo_name,
self.path,
self.content,
branch_name,
f"Updating file {self.path}",
f"{self.org_id}-bot",
f"{short_name}[email protected]",
author_is_committer=True,
)

open_pull_requests = await rest_api.pull_request.get_pull_requests(
self.org_id, self.repo_name, "open", default_branch
Expand Down
60 changes: 30 additions & 30 deletions otterdog/webapp/tasks/fetch_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

from dataclasses import dataclass

import yaml

from otterdog.providers.github.rest import RestApi
from otterdog.utils import print_error
from otterdog.webapp.db.models import TaskModel
from otterdog.webapp.db.service import (
cleanup_policies_of_owner,
update_or_create_policy,
)
from otterdog.webapp.policies import Policy, PolicyType
from otterdog.webapp.policies import Policy, PolicyType, read_policy
from otterdog.webapp.tasks import InstallationBasedTask, Task


Expand Down Expand Up @@ -43,40 +44,39 @@ async def _execute(self) -> None:
async with self.get_organization_config() as org_config:
rest_api = await self.rest_api

policies = await fetch_policies(rest_api, self.org_id, org_config.config_repo, self.global_policies)
policies = await self._fetch_policies(rest_api, self.org_id, org_config.config_repo, self.global_policies)

valid_types = [x.value for x in policies]
await cleanup_policies_of_owner(self.org_id, valid_types)

for policy in list(policies.values()):
await update_or_create_policy(self.org_id, policy)

async def _fetch_policies(
self,
rest_api: RestApi,
org_id: str,
repo: str,
global_policies: list[Policy],
) -> dict[PolicyType, Policy]:
config_file_path = "otterdog/policies"
policies = {p.type: p for p in global_policies}
try:
entries = await rest_api.content.get_content_object(org_id, repo, config_file_path)
except RuntimeError:
entries = []

for entry in entries:
path = entry["path"]
if path.endswith((".yml", "yaml")):
content = await rest_api.content.get_content(org_id, repo, path)
try:
policy = read_policy(yaml.safe_load(content))
policies[policy.type] = policy
except (ValueError, RuntimeError) as ex:
self.logger.error(f"failed reading policy from path '{path}'", exc_info=ex)

return policies

def __repr__(self) -> str:
return f"FetchPoliciesTask(repo='{self.org_id}/{self.repo_name}')"


async def fetch_policies(
rest_api: RestApi, org_id: str, repo: str, global_policies: list[Policy]
) -> dict[PolicyType, Policy]:
import yaml

from otterdog.webapp.policies import read_policy

config_file_path = "otterdog/policies"
policies = {p.type: p for p in global_policies}
try:
entries = await rest_api.content.get_content_object(org_id, repo, config_file_path)
except RuntimeError:
entries = []

for entry in entries:
path = entry["path"]
if path.endswith((".yml", "yaml")):
content = await rest_api.content.get_content(org_id, repo, path)
try:
policy = read_policy(yaml.safe_load(content))
policies[policy.type] = policy
except RuntimeError as ex:
print_error(f"failed reading policy from path '{path}': {ex!s}")

return policies
5 changes: 2 additions & 3 deletions otterdog/webapp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from otterdog.providers.github.cache.redis import redis_cache
from otterdog.providers.github.graphql import GraphQLClient
from otterdog.providers.github.rest import RestApi
from otterdog.utils import print_error
from otterdog.webapp.policies import Policy, read_policy

logger = getLogger(__name__)
Expand Down Expand Up @@ -190,8 +189,8 @@ async def _load_global_policies(ref: str | None = None) -> list[Policy]:
try:
policy = read_policy(yaml.safe_load(content))
policies.append(policy)
except RuntimeError as e:
print_error(f"failed reading global policy from path '{path}': {e!s}")
except (ValueError, RuntimeError) as e:
logger.error(f"failed reading global policy from path '{path}'", exc_info=e)

return policies

Expand Down
58 changes: 58 additions & 0 deletions tests/webapp/policies/test_required_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
# SPDX-License-Identifier: EPL-2.0
# *******************************************************************************

from dataclasses import fields
from unittest import mock

import yaml

from otterdog.models.repository import Repository
from otterdog.webapp.policies import PolicyType, read_policy
from otterdog.webapp.policies.required_file import RequiredFilePolicy

Expand All @@ -24,6 +28,30 @@
Please head over to....
"""

multiple_repo_selection = """
name: Require file
type: required_file
config:
files:
- path: .github/workflows/dependabot-auto-merge.yml
repo_selector:
name_pattern:
- .github
- repo-1
- repo-2
- repo-3
content: |
name: Dependabot auto-merge
on: pull_request_target
permissions: read-all
jobs:
dependabot:
permissions:
contents: write
pull-requests: write
uses: adoptium/.github/.github/workflows/dependabot-auto-merge.yml@main
"""


def test_read():
config = yaml.safe_load(yaml_content)
Expand All @@ -32,3 +60,33 @@ def test_read():
assert policy.type == PolicyType.REQUIRED_FILE

assert isinstance(policy, RequiredFilePolicy)


def test_repo_selector_multiple_repos():
config = yaml.safe_load(multiple_repo_selection)

policy = read_policy(config)
assert policy.type == PolicyType.REQUIRED_FILE

assert isinstance(policy, RequiredFilePolicy)
assert len(policy.files) == 1

required_file = policy.files[0]

assert required_file.strict is False

selector = required_file.repo_selector

assert selector.matches(create_repo_with_name(".github"))
assert selector.matches(create_repo_with_name("repo-1"))
assert selector.matches(create_repo_with_name("repo-10")) is False


def create_repo_with_name(name: str) -> Repository:
repo = create_dataclass_mock(Repository)
repo.name = name
return repo


def create_dataclass_mock(obj):
return mock.Mock(spec_set=[field.name for field in fields(obj)])

0 comments on commit a80aa5c

Please sign in to comment.