Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optionally check if all env vars match #350

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion extensions/eda/plugins/event_source/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
final payload which can be used to trigger a shutdown of
the rulebook, especially when we are using rulebooks to
forward messages to other running rulebooks.
check_env_vars dict Optionally check if all the defined env vars are set
before generating the events. If any of the env_var is missing
or the value doesn't match the source plugin will end
with an exception


"""
Expand All @@ -53,16 +57,35 @@
from __future__ import annotations

import asyncio
import os
import random
import time
from dataclasses import dataclass, fields
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Any, Dict, Optional

import yaml


class MissingEnvVarError(Exception):
"""Exception class for missing env var."""

def __init__(self: "MissingEnvVarError", env_var: str) -> None:
"""Class constructor with the missing env_var."""
super().__init__(f"Env Var {env_var} is required")


class EnvVarMismatchError(Exception):
"""Exception class for mismatch in the env var value."""

def __init__(
self: "EnvVarMismatchError", env_var: str, value: str, expected: str
) -> None:
"""Class constructor with mismatch in env_var value."""
super().__init__(f"Env Var {env_var} expected: {expected} passed in: {value}")


@dataclass
class Args:
"""Class to store all the passed in args."""
Expand All @@ -84,6 +107,7 @@ class ControlArgs:
loop_count: int = 1
repeat_count: int = 1
timestamp: bool = False
check_env_vars: Optional[Dict[str, str]] = None


@dataclass
Expand Down Expand Up @@ -135,6 +159,7 @@ async def __call__(self: Generic) -> None:
msg = "time_format must be one of local, iso8601, epoch"
raise ValueError(msg)

await self._check_env_vars()
await self._load_payload_from_file()

if not isinstance(self.my_args.payload, list):
Expand Down Expand Up @@ -174,6 +199,14 @@ async def _post_event(self: Generic, event: dict[str, Any], index: int) -> None:
print(data) # noqa: T201
await self.queue.put(data)

async def _check_env_vars(self: Generic) -> None:
if self.control_args.check_env_vars:
for key, value in self.control_args.check_env_vars.items():
if key not in os.environ:
raise MissingEnvVarError(key)
if os.environ[key] != value:
raise EnvVarMismatchError(key, os.environ[key], value)

async def _load_payload_from_file(self: Generic) -> None:
if not self.my_args.payload_file:
return
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/event_source/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import pytest
import yaml

from extensions.eda.plugins.event_source.generic import (
EnvVarMismatchError,
MissingEnvVarError,
)
from extensions.eda.plugins.event_source.generic import main as generic_main


Expand Down Expand Up @@ -243,3 +247,57 @@ def test_generic_parsing_payload_file() -> None:
},
)
)


def test_env_vars_missing() -> None:
"""Test missing env vars"""
myqueue = _MockQueue()
event = {"name": "fred"}

with pytest.raises(MissingEnvVarError):
asyncio.run(
generic_main(
myqueue,
{
"payload": event,
"check_env_vars": {"NAME_MISSING": "Fred"},
},
)
)


def test_env_vars_mismatch() -> None:
"""Test env vars with incorrect values"""
myqueue = _MockQueue()
event = {"name": "fred"}

os.environ["TEST_ENV1"] = "Kaboom"
with pytest.raises(EnvVarMismatchError):
asyncio.run(
generic_main(
myqueue,
{
"payload": event,
"check_env_vars": {"TEST_ENV1": "Fred"},
},
)
)


def test_env_vars() -> None:
"""Test env vars with correct values"""
myqueue = _MockQueue()
event = {"name": "fred"}

os.environ["TEST_ENV1"] = "Fred"
asyncio.run(
generic_main(
myqueue,
{
"payload": event,
"check_env_vars": {"TEST_ENV1": "Fred"},
},
)
)
assert len(myqueue.queue) == 1
assert myqueue.queue[0] == {"name": "fred"}
Loading