From caef93089d81f39fe211a0ffe3dced6089a7a972 Mon Sep 17 00:00:00 2001 From: ajasnosz <139114006+ajasnosz@users.noreply.github.com> Date: Wed, 18 Dec 2024 14:05:51 +0100 Subject: [PATCH] fix: refactor create query (#1140) * fix: refactor create query * chore: move to common --- splunk_connect_for_snmp/common/hummanbool.py | 14 +++ splunk_connect_for_snmp/inventory/tasks.py | 116 +++++++++---------- test/common/test_humanbool.py | 22 +++- 3 files changed, 89 insertions(+), 63 deletions(-) diff --git a/splunk_connect_for_snmp/common/hummanbool.py b/splunk_connect_for_snmp/common/hummanbool.py index 217302c9c..f134e5b6e 100644 --- a/splunk_connect_for_snmp/common/hummanbool.py +++ b/splunk_connect_for_snmp/common/hummanbool.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import typing from typing import Union @@ -42,3 +43,16 @@ def human_bool(flag: Union[str, bool], default: bool = False) -> bool: return False else: return default + + +class BadlyFormattedFieldError(Exception): + pass + + +def convert_to_float(value: typing.Any, ignore_error: bool = False) -> typing.Any: + try: + return float(value) + except ValueError: + if ignore_error: + return value + raise BadlyFormattedFieldError(f"Value '{value}' should be numeric") diff --git a/splunk_connect_for_snmp/inventory/tasks.py b/splunk_connect_for_snmp/inventory/tasks.py index 101690a30..edc1d913c 100644 --- a/splunk_connect_for_snmp/inventory/tasks.py +++ b/splunk_connect_for_snmp/inventory/tasks.py @@ -36,7 +36,11 @@ from celery.utils.log import get_task_logger from splunk_connect_for_snmp import customtaskmanager -from splunk_connect_for_snmp.common.hummanbool import human_bool +from splunk_connect_for_snmp.common.hummanbool import ( + BadlyFormattedFieldError, + convert_to_float, + human_bool, +) from ..poller import app @@ -51,10 +55,6 @@ POLL_BASE_PROFILES = human_bool(os.getenv("POLL_BASE_PROFILES", "true")) -class BadlyFormattedFieldError(Exception): - pass - - class InventoryTask(Task): def __init__(self): self.mongo_client = pymongo.MongoClient(MONGO_URI) @@ -305,86 +305,78 @@ def create_profile(profile_name, frequency, varbinds, records): def create_query(conditions: typing.List[dict], address: str) -> dict: - conditional_profiles_mapping = { - "equals": "$eq", - "gt": "$gt", - "lt": "$lt", - "in": "$in", - "regex": "$regex", - } - - negative_profiles_mapping = { - "equals": "$ne", - "gt": "$lte", - "lt": "$gte", - "in": "$nin", - "regex": "$regex", + # Define mappings for conditional and negative profiles + profile_mappings = { + "positive": { + "equals": "$eq", + "gt": "$gt", + "lt": "$lt", + "in": "$in", + "regex": "$regex", + }, + "negative": { + "equals": "$ne", + "gt": "$lte", + "lt": "$gte", + "in": "$nin", + "regex": "$regex", + }, } + # Helper functions def _parse_mib_component(field: str) -> str: - mib_component = field.split("|") - if len(mib_component) < 2: + components = field.split("|") + if len(components) < 2: raise BadlyFormattedFieldError(f"Field {field} is badly formatted") - return mib_component[0] - - def _convert_to_float(value: typing.Any, ignore_error=False) -> typing.Any: - try: - return float(value) - except ValueError: - if ignore_error: - return value - else: - raise BadlyFormattedFieldError(f"Value '{value}' should be numeric") + return components[0] def _prepare_regex(value: str) -> typing.Union[list, str]: pattern = value.strip("/").split("/") - if len(pattern) > 1: - return pattern - else: - return pattern[0] + return pattern if len(pattern) > 1 else pattern[0] - def _get_value_for_operation(operation: str, value: str) -> typing.Any: - if operation in ["lt", "gt"]: - return _convert_to_float(value) - elif operation == "in": - return [_convert_to_float(v, True) for v in value] - elif operation == "regex": - return _prepare_regex(value) - return value + def _get_value_for_operation(operation: str, value: typing.Any) -> typing.Any: + operation_handlers = { + "lt": lambda v: convert_to_float(v), + "gt": lambda v: convert_to_float(v), + "in": lambda v: [convert_to_float(item, True) for item in v], + "regex": lambda v: _prepare_regex(v), + } + return operation_handlers.get(operation, lambda v: v)(value) def _prepare_query_input( - operation: str, value: typing.Any, field: str, negate_operation: bool + operation: str, value: typing.Any, field: str, negate: bool, mongo_op: str ) -> dict: - if operation == "regex" and isinstance(value, list): - query = {mongo_operation: value[0], "$options": value[1]} - else: - query = {mongo_operation: value} - if operation == "regex" and negate_operation: + query = ( + {mongo_op: value} + if not (operation == "regex" and isinstance(value, list)) + else {mongo_op: value[0], "$options": value[1]} + ) + if operation == "regex" and negate: query = {"$not": query} return {f"fields.{field}.value": query} + # Main processing loop filters = [] - field = "" for condition in conditions: - field = condition["field"] - # fields in databases are written in convention "IF-MIB|ifInOctets" - field = field.replace(".", "|") + field = condition["field"].replace(".", "|") # Standardize field format value = condition["value"] - negate_operation = human_bool( - condition.get("negate_operation", False), default=False - ) + negate = human_bool(condition.get("negate_operation", False), default=False) operation = condition["operation"].lower() - value_for_querying = _get_value_for_operation(operation, value) - mongo_operation = ( - negative_profiles_mapping.get(operation) - if negate_operation - else conditional_profiles_mapping.get(operation) + + # Determine MongoDB operator and prepare query + mongo_op = profile_mappings["negative" if negate else "positive"].get( + operation, "" ) + value_for_query = _get_value_for_operation(operation, value) query = _prepare_query_input( - operation, value_for_querying, field, negate_operation + operation, value_for_query, field, negate, mongo_op ) filters.append(query) + + # Parse MIB component for address matching mib_component = _parse_mib_component(field) + + # Construct final query return { "$and": [ {"address": address}, diff --git a/test/common/test_humanbool.py b/test/common/test_humanbool.py index 8c5752b9d..73a1c0f57 100644 --- a/test/common/test_humanbool.py +++ b/test/common/test_humanbool.py @@ -1,6 +1,10 @@ from unittest import TestCase -from splunk_connect_for_snmp.common.hummanbool import human_bool +from splunk_connect_for_snmp.common.hummanbool import ( + BadlyFormattedFieldError, + convert_to_float, + human_bool, +) class TestHumanBool(TestCase): @@ -32,3 +36,19 @@ def test_human_bool_default(self): self.assertTrue(human_bool("foo", True)) self.assertFalse(human_bool("1foo", False)) self.assertFalse(human_bool("1FoO")) + + def test_convert_to_float(self): + value = 1 + result = convert_to_float(value) + self.assertIsInstance(result, float) + + def test_convert_to_float_ignore(self): + value = "up" + result = convert_to_float(value, True) + self.assertEqual(result, value) + + def test_convert_to_float_error(self): + value = "up" + with self.assertRaises(BadlyFormattedFieldError) as context: + convert_to_float(value) + self.assertEqual("Value 'up' should be numeric", context.exception.args[0])