Skip to content

Commit

Permalink
Merge pull request #215 from mkanoor/payload_from_file
Browse files Browse the repository at this point in the history
feat: generic source supports events in yaml file
  • Loading branch information
mkanoor authored Jul 18, 2024
2 parents db79d23 + fe56732 commit 76c044d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
25 changes: 25 additions & 0 deletions extensions/eda/plugins/event_source/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
parameter payload and is an array of events.
Optional Parameters:
payload_file A yaml with an array of events can be used instead of payload
randomize True|False Randomize the events in the payload, default False
display True|False Display the event data in stdout, default False
timestamp True|False Add an event timestamp, default False
Expand Down Expand Up @@ -56,8 +57,11 @@
import time
from dataclasses import dataclass, fields
from datetime import datetime
from pathlib import Path
from typing import Any

import yaml


@dataclass
class Args:
Expand All @@ -67,6 +71,7 @@ class Args:
final_payload: Any = None
display: bool = False
create_index: str = ""
payload_file: str = ""


@dataclass
Expand Down Expand Up @@ -99,6 +104,10 @@ def __init__(self: Generic, queue: asyncio.Queue, args: dict[str, Any]) -> None:
"""Insert event data into the queue."""
self.queue = queue
field_names = [f.name for f in fields(Args)]

if "payload_file" in args:
args["payload"] = ""

self.my_args = Args(**{k: v for k, v in args.items() if k in field_names})
field_names = [f.name for f in fields(ControlArgs)]
self.control_args = ControlArgs(
Expand All @@ -124,6 +133,8 @@ async def __call__(self: Generic) -> None:
msg = "time_format must be one of local, iso8601, epoch"
raise ValueError(msg)

await self._load_payload_from_file()

if not isinstance(self.my_args.payload, list):
self.my_args.payload = [self.my_args.payload]

Expand Down Expand Up @@ -161,6 +172,20 @@ async def _post_event(self: Generic, event: dict, index: int) -> None:
print(data) # noqa: T201
await self.queue.put(data)

async def _load_payload_from_file(self: Generic) -> None:
if not self.my_args.payload_file:
return
path = Path(self.my_args.payload_file)
if not path.is_file():
msg = f"File {self.my_args.payload_file} not found"
raise ValueError(msg)
with path.open(mode="r", encoding="utf-8") as file:
try:
self.my_args.payload = yaml.safe_load(file)
except yaml.YAMLError as exc:
msg = f"File {self.my_args.payload_file} parsing error {exc}"
raise ValueError(msg) from exc

def _create_data(
self: Generic,
index: int,
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/event_source/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
""" Tests for generic source plugin """

import asyncio
import os
import tempfile

import pytest
import yaml

from extensions.eda.plugins.event_source.generic import main as generic_main

Expand Down Expand Up @@ -176,3 +179,66 @@ def test_generic_bad_time_format():
},
)
)


def test_generic_payload_file():
"""Test reading events from file."""
myqueue = _MockQueue()
event = {"name": "fred"}
loop_count = 2

with tempfile.NamedTemporaryFile() as tmpfile:
with open(tmpfile.name, "w") as f:
yaml.dump(event, f)
asyncio.run(
generic_main(
myqueue,
{
"payload_file": tmpfile.name,
"loop_count": loop_count,
"create_index": "sequence",
},
)
)

assert len(myqueue.queue) == loop_count
index = 0
for i in range(loop_count):
expected_event = {"name": "fred", "sequence": i}
assert myqueue.queue[index] == expected_event
index += 1


def test_generic_missing_payload_file():
"""Test reading events from missing file."""
myqueue = _MockQueue()
with tempfile.TemporaryDirectory() as tmpdir:
fname = os.path.join(tmpdir, "missing.yaml")
with pytest.raises(ValueError):
asyncio.run(
generic_main(
myqueue,
{
"payload_file": fname,
},
)
)


def test_generic_parsing_payload_file():
"""Test parsing failure events from file."""
myqueue = _MockQueue()
with tempfile.TemporaryDirectory() as tmpdir:
fname = os.path.join(tmpdir, "bogus.yaml")
with open(fname, "w") as f:
f.write("fail_text: 'Hello, I'm testing!'")

with pytest.raises(ValueError):
asyncio.run(
generic_main(
myqueue,
{
"payload_file": fname,
},
)
)

0 comments on commit 76c044d

Please sign in to comment.