From 6dbe917a5a2d37c6696db057342d0b44fe7c9fc5 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Fri, 17 May 2024 12:23:18 -0500 Subject: [PATCH] switch to ruff --- .github/workflows/pr_qa.yml | 4 +- Makefile | 7 +- app.py | 12 +- requirements-dev.txt | 3 +- src/blueprints/guards.py | 10 +- src/blueprints/root.py | 2 +- src/classes/data_type_struct.py | 46 ++-- src/classes/guard_struct.py | 10 +- src/classes/rail_spec_struct.py | 17 +- src/classes/schema_element_struct.py | 16 +- src/classes/schema_struct.py | 16 +- src/classes/validation_output.py | 5 +- src/clients/guard_client.py | 8 +- src/clients/postgres_client.py | 4 +- src/models/guard_item.py | 2 +- src/utils/get_llm_callable.py | 4 +- src/utils/handle_error.py | 4 +- src/utils/payload_validator.py | 4 +- src/utils/remove_nones.py | 5 +- tests/blueprints/test_guards.py | 256 ++++++++++++--------- tests/blueprints/test_root.py | 20 +- tests/clients/test_guard_client.py | 320 +++++++++++++++------------ tests/mocks/mock_blueprint.py | 11 +- tests/mocks/mock_guard_client.py | 18 +- tests/mocks/mock_postgres_client.py | 14 +- tests/mocks/mock_request.py | 5 +- tests/mocks/mock_trace.py | 8 +- tests/utils/test_pluck.py | 7 +- tests/utils/test_remove_nones.py | 58 ++--- 29 files changed, 453 insertions(+), 443 deletions(-) diff --git a/.github/workflows/pr_qa.yml b/.github/workflows/pr_qa.yml index 0e059be..4a32324 100644 --- a/.github/workflows/pr_qa.yml +++ b/.github/workflows/pr_qa.yml @@ -18,8 +18,8 @@ jobs: - name: Quality Checks run: | python -m pip install --upgrade pip; - make install-lock - make build; + make install-lock; + opentelemetry-bootstrap -a install; make qa; ## diff --git a/Makefile b/Makefile index 0b1b285..4f0149f 100644 --- a/Makefile +++ b/Makefile @@ -36,10 +36,13 @@ refresh: format: - black -l 80 ./src app.py wsgi.py + ruff check app.py wsgi.py src/ tests/ --fix + ruff format app.py wsgi.py src/ tests/ + lint: - flake8 --count ./src app.py wsgi.py + ruff check app.py wsgi.py src/ tests/ + ruff format app.py wsgi.py src/ tests/ qa: make lint diff --git a/app.py b/app.py index 3f00f9c..aad8666 100644 --- a/app.py +++ b/app.py @@ -14,25 +14,19 @@ def __init__(self, app): def __call__(self, environ, start_response): self_endpoint = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") url = urlparse(self_endpoint) - environ['wsgi.url_scheme'] = url.scheme + environ["wsgi.url_scheme"] = url.scheme return self.app(environ, start_response) def create_app(): app = Flask(__name__) - app.config['APPLICATION_ROOT'] = '/' + app.config["APPLICATION_ROOT"] = "/" app.config["PREFERRED_URL_SCHEME"] = "https" app.wsgi_app = ReverseProxied(app.wsgi_app) CORS(app) - app.wsgi_app = ProxyFix( - app.wsgi_app, - x_for=1, - x_proto=1, - x_host=1, - x_port=1 - ) + app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1) guardrails_log_level = os.environ.get("GUARDRAILS_LOG_LEVEL", "INFO") configure_logging(log_level=guardrails_log_level) diff --git a/requirements-dev.txt b/requirements-dev.txt index 359040a..bb740cd 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,4 @@ -black -flake8 +ruff pytest coverage pytest-mock \ No newline at end of file diff --git a/src/blueprints/guards.py b/src/blueprints/guards.py index bccb5be..fef6bb1 100644 --- a/src/blueprints/guards.py +++ b/src/blueprints/guards.py @@ -63,9 +63,7 @@ def guard(guard_name: str): 405, "Method Not Allowed", "/guard/ only supports the GET, PUT, and DELETE methods." - " You specified {request_method}".format( - request_method=request.method - ), + " You specified {request_method}".format(request_method=request.method), ) @@ -151,7 +149,7 @@ def validate(guard_name: str): result.validation_passed, result.validated_output, guard.history, - result.raw_llm_output + result.raw_llm_output, ) prompt = guard.history.last.inputs.prompt @@ -175,7 +173,9 @@ def validate(guard_name: str): ) validate_span.set_attribute("validated_output", valid_output) - validate_span.set_attribute("tokens_consumed", guard.history.last.tokens_consumed) + validate_span.set_attribute( + "tokens_consumed", guard.history.last.tokens_consumed + ) num_of_reasks = ( guard.history.last.iterations.length - 1 diff --git a/src/blueprints/root.py b/src/blueprints/root.py index 9d74d76..759a584 100644 --- a/src/blueprints/root.py +++ b/src/blueprints/root.py @@ -46,7 +46,7 @@ def health_check(): def api_docs(): global cached_api_spec if not cached_api_spec: - with open('./open-api-spec.json') as api_spec_file: + with open("./open-api-spec.json") as api_spec_file: cached_api_spec = json.loads(api_spec_file.read()) return json.dumps(cached_api_spec) diff --git a/src/classes/data_type_struct.py b/src/classes/data_type_struct.py index 4af44b7..f687e02 100644 --- a/src/classes/data_type_struct.py +++ b/src/classes/data_type_struct.py @@ -125,13 +125,9 @@ def from_dict(cls, data_type: dict): class_children = {} elem_type = element["type"] if element is not None else None elem_is_list = elem_type == "list" - child_entries = ( - children.get("item", {}) if elem_is_list else children - ) + child_entries = children.get("item", {}) if elem_is_list else children for child_key in child_entries: - class_children[child_key] = cls.from_dict( - child_entries[child_key] - ) + class_children[child_key] = cls.from_dict(child_entries[child_key]) children_data_types = ( {"item": class_children} if elem_is_list else class_children ) @@ -155,13 +151,9 @@ def to_dict(self): self.children.get("item", {}) if elem_is_list else self.children ) for child_key in child_entries: - serialized_children[child_key] = child_entries[ - child_key - ].to_dict() + serialized_children[child_key] = child_entries[child_key].to_dict() response["children"] = ( - {"item": serialized_children} - if elem_is_list - else serialized_children + {"item": serialized_children} if elem_is_list else serialized_children ) if self.plugins is not None: @@ -174,14 +166,16 @@ def from_request(cls, data_type: dict): if data_type: children, formatters, element, plugins = pluck( data_type, ["children", "formatters", "element", "plugins"] - ) + ) children_data_types = None if children: class_children = {} elem_type = element.get("type") if element is not None else None elem_is_list = elem_type == "list" child_entries = ( - children.get("item", {}).get("children", {}) if elem_is_list else children + children.get("item", {}).get("children", {}) + if elem_is_list + else children ) for child_key in child_entries: class_children[child_key] = cls.from_request( @@ -211,13 +205,9 @@ def to_response(self): self.children.get("item", {}) if elem_is_list else self.children ) for child_key in child_entries: - serialized_children[child_key] = child_entries[ - child_key - ].to_response() + serialized_children[child_key] = child_entries[child_key].to_response() response["children"] = ( - {"item": serialized_children} - if elem_is_list - else serialized_children + {"item": serialized_children} if elem_is_list else serialized_children ) if self.plugins is not None: @@ -255,9 +245,7 @@ def from_xml(cls, elem: _Element): return cls(children, formatters, element, plugins) - def to_xml( - self, parent: _Element, as_parent: Optional[bool] = False - ) -> _Element: + def to_xml(self, parent: _Element, as_parent: Optional[bool] = False) -> _Element: element = None if as_parent: element = parent @@ -267,9 +255,7 @@ def to_xml( else: element = self.element.to_element() - format = ( - "; ".join(self.formatters) if len(self.formatters) > 0 else None - ) + format = "; ".join(self.formatters) if len(self.formatters) > 0 else None if format is not None: element.attrib["format"] = format @@ -278,7 +264,6 @@ def to_xml( if plugins is not None: element.attrib["plugins"] = plugins - stringified_attribs = {} for k, v in element.attrib.items(): stringified_attribs[k] = str(v) @@ -295,8 +280,7 @@ def to_xml( ) _parent = xml_data_type if self_is_list and ( - len(child_entries) > 0 - or child_entries[0].element.name is not None + len(child_entries) > 0 or child_entries[0].element.name is not None ): _parent = SubElement(xml_data_type, "object") for child_key in child_entries: @@ -309,9 +293,7 @@ def get_all_plugins(self) -> List[str]: self_is_list = self.element.type == "list" if self.children is not None: - children = ( - self.children.get("item", {}) if self_is_list else self.children - ) + children = self.children.get("item", {}) if self_is_list else self.children for child_key in children: plugins.extend(children[child_key].get_all_plugins()) return plugins diff --git a/src/classes/guard_struct.py b/src/classes/guard_struct.py index 596916d..ea0dc4e 100644 --- a/src/classes/guard_struct.py +++ b/src/classes/guard_struct.py @@ -14,7 +14,7 @@ def __init__( name: str, railspec: RailSpecStruct, num_reasks: int = None, - description: str = None + description: str = None, # base_model: dict = None, ): self.name = name @@ -52,9 +52,7 @@ def from_dict(cls, guard: dict): name, railspec, num_reasks, description = pluck( guard, ["name", "railspec", "num_reasks", "description"] ) - return cls( - name, RailSpecStruct.from_dict(railspec), num_reasks, description - ) + return cls(name, RailSpecStruct.from_dict(railspec), num_reasks, description) def to_dict(self) -> dict: return { @@ -100,6 +98,4 @@ def from_railspec( num_reasks: int = None, description: str = None, ): - return cls( - name, RailSpecStruct.from_xml(railspec), num_reasks, description - ) + return cls(name, RailSpecStruct.from_xml(railspec), num_reasks, description) diff --git a/src/classes/rail_spec_struct.py b/src/classes/rail_spec_struct.py index 61cece1..2658bda 100644 --- a/src/classes/rail_spec_struct.py +++ b/src/classes/rail_spec_struct.py @@ -33,12 +33,8 @@ def from_rail(cls, rail: Rail): ) def to_rail(self) -> Rail: - input_schema = ( - self.input_schema.to_schema() if self.input_schema else None - ) - output_schema = ( - self.output_schema.to_schema() if self.output_schema else None - ) + input_schema = self.input_schema.to_schema() if self.input_schema else None + output_schema = self.output_schema.to_schema() if self.output_schema else None # TODO: This might not be necessary anymore since we stopped # BasePrompt from formatting on init escaped_instructions = escape_curlys(self.instructions) @@ -51,9 +47,7 @@ def to_rail(self) -> Rail: # TODO: This might not be necessary anymore since we stopped # BasePrompt from formatting on init escaped_prompt = escape_curlys(self.prompt) - prompt = ( - Prompt(escaped_prompt, output_schema) if escaped_prompt else None - ) + prompt = Prompt(escaped_prompt, output_schema) if escaped_prompt else None prompt.source = descape_curlys(prompt.source) return Rail( input_schema, @@ -148,10 +142,7 @@ def from_xml(cls, railspec: str): xml_parser = etree.XMLParser(encoding="utf-8") elem_tree = etree.fromstring(railspec, parser=xml_parser) - if ( - "version" not in elem_tree.attrib - or elem_tree.attrib["version"] != "0.1" - ): + if "version" not in elem_tree.attrib or elem_tree.attrib["version"] != "0.1": raise ValueError( "RAIL file must have a version attribute set to 0.1." "Change the opening element to: ." diff --git a/src/classes/schema_element_struct.py b/src/classes/schema_element_struct.py index de5e20b..07e9e6f 100644 --- a/src/classes/schema_element_struct.py +++ b/src/classes/schema_element_struct.py @@ -41,9 +41,9 @@ def to_element(self) -> ElementStub: for validator_on_fail in self.on_fails: validator_tag = validator_on_fail.get("validatorTag", "") escaped_validator_tag = validator_tag.replace("/", "_") - elem_dict[ - f"on-fail-{escaped_validator_tag}" - ] = validator_on_fail.get("method") + elem_dict[f"on-fail-{escaped_validator_tag}"] = validator_on_fail.get( + "method" + ) elem_dict.pop("date_format", None) elem_dict.pop("time_format", None) elem_dict.pop("on_fails", None) @@ -184,11 +184,7 @@ def from_xml(cls, xml: _Element): strict = None strict_tag = xml.get("strict", "False") if strict_tag: - strict = ( - True - if strict_tag.lower() == "true" - else False - ) + strict = True if strict_tag.lower() == "true" else False date_format = xml.get("date-format") time_format = xml.get("time-format") on_fail = xml.get("on-fail") @@ -209,9 +205,7 @@ def from_xml(cls, xml: _Element): if attr_key.startswith("on-fail") and attr_key != "on-fail": on_fail_method = xml.get(attr_key) on_fail_tag = attr_key - on_fails.append( - {"validatorTag": on_fail_tag, "method": on_fail_method} - ) + on_fails.append({"validatorTag": on_fail_tag, "method": on_fail_method}) elif attr_key not in handled_keys: kwargs[attr_key] = xml.get(attr_key) model = xml.get("model") diff --git a/src/classes/schema_struct.py b/src/classes/schema_struct.py index 837e55e..cc2f339 100644 --- a/src/classes/schema_struct.py +++ b/src/classes/schema_struct.py @@ -22,9 +22,7 @@ def from_schema(cls, schema: Schema): else: for key in schema.root_datatype: schema_element = schema.root_datatype[key] - serialized_schema[key] = DataTypeStruct.from_data_type( - schema_element - ) + serialized_schema[key] = DataTypeStruct.from_data_type(schema_element) return cls({"schema": serialized_schema}) def to_schema(self) -> Schema: @@ -34,9 +32,7 @@ def to_schema(self) -> Schema: if isinstance(inner_schema, DataTypeStruct): string_schema = StringSchema() string_schema.string_key = inner_schema.element.name - string_schema[ - string_schema.string_key - ] = inner_schema.to_data_type() + string_schema[string_schema.string_key] = inner_schema.to_data_type() return string_schema for key in inner_schema: @@ -55,9 +51,7 @@ def from_dict(cls, schema: dict): else: for key in inner_schema: schema_element = inner_schema[key] - serialized_schema[key] = DataTypeStruct.from_dict( - schema_element - ) + serialized_schema[key] = DataTypeStruct.from_dict(schema_element) return cls({"schema": serialized_schema}) def to_dict(self): @@ -84,9 +78,7 @@ def from_request(cls, schema: dict): # JsonSchema for key in inner_schema: schema_element = inner_schema[key] - serialized_schema[key] = DataTypeStruct.from_request( - schema_element - ) + serialized_schema[key] = DataTypeStruct.from_request(schema_element) return cls({"schema": serialized_schema}) def to_response(self): diff --git a/src/classes/validation_output.py b/src/classes/validation_output.py index f114d34..827a3f1 100644 --- a/src/classes/validation_output.py +++ b/src/classes/validation_output.py @@ -21,7 +21,8 @@ def __init__( "instructions": i.inputs.instructions.source if i.inputs.instructions is not None else None, - "output": i.outputs.raw_output or i.outputs.llm_response_info.output, + "output": i.outputs.raw_output + or i.outputs.llm_response_info.output, "parsedOutput": i.parsed_output, "prompt": { "source": i.inputs.prompt.source @@ -49,7 +50,7 @@ def __init__( "propertyPath": fv.property_path, } for fv in i.failed_validations - ) + ), } for i in c.iterations ] diff --git a/src/clients/guard_client.py b/src/clients/guard_client.py index 7178909..13815cd 100644 --- a/src/clients/guard_client.py +++ b/src/clients/guard_client.py @@ -13,9 +13,7 @@ def __init__(self): def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: latest_guard_item = ( - self.pgClient.db.session.query(GuardItem) - .filter_by(name=guard_name) - .first() + self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() ) audit_item = None if as_of_date is not None: @@ -37,9 +35,7 @@ def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: def get_guard_item(self, guard_name: str) -> GuardItem: return ( - self.pgClient.db.session.query(GuardItem) - .filter_by(name=guard_name) - .first() + self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() ) def get_guards(self) -> List[GuardStruct]: diff --git a/src/clients/postgres_client.py b/src/clients/postgres_client.py index 13a64d9..6a868c1 100644 --- a/src/clients/postgres_client.py +++ b/src/clients/postgres_client.py @@ -52,9 +52,7 @@ def initialize(self, app: Flask): else f"{pg_host}:{pg_port}" ) - conf = ( - f"postgresql://{pg_user}:{pg_password}@{pg_endpoint}/{pg_database}" - ) + conf = f"postgresql://{pg_user}:{pg_password}@{pg_endpoint}/{pg_database}" if os.environ.get("NODE_ENV") == "production": conf = f"{conf}?sslmode=verify-ca&sslrootcert=global-bundle.pem" diff --git a/src/models/guard_item.py b/src/models/guard_item.py index dc41d61..663603f 100644 --- a/src/models/guard_item.py +++ b/src/models/guard_item.py @@ -17,7 +17,7 @@ def __init__( name, railspec, num_reasks, - description + description, # owner = None ): self.name = name diff --git a/src/utils/get_llm_callable.py b/src/utils/get_llm_callable.py index ba3ad28..bc68728 100644 --- a/src/utils/get_llm_callable.py +++ b/src/utils/get_llm_callable.py @@ -3,7 +3,7 @@ get_static_openai_create_func, get_static_openai_chat_create_func, get_static_openai_acreate_func, - get_static_openai_chat_acreate_func + get_static_openai_chat_acreate_func, ) from guardrails_api_client.models.validate_payload_llm_api import ( ValidatePayloadLlmApi, @@ -38,4 +38,4 @@ def get_llm_callable( else: pass except Exception: - pass \ No newline at end of file + pass diff --git a/src/utils/handle_error.py b/src/utils/handle_error.py index 8da18af..b21796c 100644 --- a/src/utils/handle_error.py +++ b/src/utils/handle_error.py @@ -15,9 +15,7 @@ def decorator(*args, **kwargs): traceback.print_exception(http_error) return http_error.to_dict(), http_error.status except HTTPException as http_exception: - http_error = HttpError( - http_exception.code, http_exception.description - ) + http_error = HttpError(http_exception.code, http_exception.description) return http_error.to_dict(), http_error.status except Exception as e: logger.error(e) diff --git a/src/utils/payload_validator.py b/src/utils/payload_validator.py index b3ba0cb..514af3d 100644 --- a/src/utils/payload_validator.py +++ b/src/utils/payload_validator.py @@ -4,9 +4,9 @@ from src.classes.http_error import HttpError from src.utils.remove_nones import remove_nones -with open('./open-api-spec.json') as api_spec_file: +with open("./open-api-spec.json") as api_spec_file: api_spec = json.loads(api_spec_file.read()) - + registry = Registry().with_resources( [ ( diff --git a/src/utils/remove_nones.py b/src/utils/remove_nones.py index 9bfb7bd..c85a879 100644 --- a/src/utils/remove_nones.py +++ b/src/utils/remove_nones.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List -def remove_nones (dictionary: Dict[str, Any]) -> Dict[str, Any]: + +def remove_nones(dictionary: Dict[str, Any]) -> Dict[str, Any]: filtered = {} for key, value in list(dictionary.items()): if isinstance(value, Dict): @@ -16,4 +17,4 @@ def remove_nones (dictionary: Dict[str, Any]) -> Dict[str, Any]: filtered[key] = filtered_list elif value is not None: filtered[key] = dictionary[key] - return filtered \ No newline at end of file + return filtered diff --git a/tests/blueprints/test_guards.py b/tests/blueprints/test_guards.py index 0845241..154fa2b 100644 --- a/tests/blueprints/test_guards.py +++ b/tests/blueprints/test_guards.py @@ -11,184 +11,207 @@ def test_route_setup(mocker): mocker.patch("flask.Blueprint", new=MockBlueprint) - + from src.blueprints.guards import guards_bp - - + assert guards_bp.route_call_count == 3 assert guards_bp.routes == ["/", "/", "//validate"] - def test_guards__get(mocker): mock_guard = MockGuardStruct() mock_request = MockRequest("GET") - + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mock_get_guards = mocker.patch("src.blueprints.guards.GuardClient.get_guards", return_value=[mock_guard]) + mock_get_guards = mocker.patch( + "src.blueprints.guards.GuardClient.get_guards", return_value=[mock_guard] + ) mocker.patch("src.blueprints.guards.get_tracer") from src.blueprints.guards import guards response = guards() - + assert mock_get_guards.call_count == 1 - assert response == [{ "name": "mock-guard" }] + assert response == [{"name": "mock-guard"}] def test_guards__post(mocker): mock_guard = MockGuardStruct() mock_request = MockRequest("POST", mock_guard.to_response()) - + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mock_from_request = mocker.patch("src.blueprints.guards.GuardStruct.from_request", return_value=mock_guard) - mock_create_guard = mocker.patch("src.blueprints.guards.GuardClient.create_guard", return_value=mock_guard) + mock_from_request = mocker.patch( + "src.blueprints.guards.GuardStruct.from_request", return_value=mock_guard + ) + mock_create_guard = mocker.patch( + "src.blueprints.guards.GuardClient.create_guard", return_value=mock_guard + ) mocker.patch("src.blueprints.guards.get_tracer") from src.blueprints.guards import guards - response = guards() - + assert mock_from_request.called_once_with(mock_guard) assert mock_create_guard.called_once_with(mock_guard) - assert response == { "name": "mock-guard" } + assert response == {"name": "mock-guard"} def test_guards__raises(mocker): mock_request = MockRequest("PUT") - + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) mocker.patch("src.blueprints.guards.get_tracer") mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import guards - + response = guards() - + assert isinstance(response, Tuple) error, status = response assert isinstance(error, Dict) assert error.get("status") == 405 assert error.get("message") == "Method Not Allowed" - assert error.get("cause") == "/guards only supports the GET and POST methods. You specified PUT" + assert ( + error.get("cause") + == "/guards only supports the GET and POST methods. You specified PUT" + ) assert status == 405 + def test_guard__get(mocker): mock_guard = MockGuardStruct() timestamp = "2024-03-04T14:11:42-06:00" - mock_request = MockRequest("GET", args={ "asOf": timestamp }) - + mock_request = MockRequest("GET", args={"asOf": timestamp}) + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch("src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard) + mock_get_guard = mocker.patch( + "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + ) mocker.patch("src.blueprints.guards.get_tracer") from src.blueprints.guards import guard - + response = guard("My%20Guard's%20Name") - + assert mock_get_guard.called_once_with("My Guard's Name", timestamp) - assert response == { "name": "mock-guard" } + assert response == {"name": "mock-guard"} def test_guard__put(mocker): mock_guard = MockGuardStruct() - mock_request = MockRequest("PUT", json={ "name": "mock-guard" }) - + mock_request = MockRequest("PUT", json={"name": "mock-guard"}) + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mock_from_request = mocker.patch("src.blueprints.guards.GuardStruct.from_request", return_value=mock_guard) - mock_upsert_guard = mocker.patch("src.blueprints.guards.GuardClient.upsert_guard", return_value=mock_guard) + mock_from_request = mocker.patch( + "src.blueprints.guards.GuardStruct.from_request", return_value=mock_guard + ) + mock_upsert_guard = mocker.patch( + "src.blueprints.guards.GuardClient.upsert_guard", return_value=mock_guard + ) mocker.patch("src.blueprints.guards.get_tracer") from src.blueprints.guards import guard - + response = guard("My%20Guard's%20Name") - + assert mock_from_request.called_once_with(mock_guard) assert mock_upsert_guard.called_once_with("My Guard's Name", mock_guard) - assert response == { "name": "mock-guard" } + assert response == {"name": "mock-guard"} def test_guard__delete(mocker): mock_guard = MockGuardStruct() mock_request = MockRequest("DELETE") - + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mock_delete_guard = mocker.patch("src.blueprints.guards.GuardClient.delete_guard", return_value=mock_guard) + mock_delete_guard = mocker.patch( + "src.blueprints.guards.GuardClient.delete_guard", return_value=mock_guard + ) mocker.patch("src.blueprints.guards.get_tracer") from src.blueprints.guards import guard - + response = guard("my-guard-name") - + assert mock_delete_guard.called_once_with("my-guard-name") - assert response == { "name": "mock-guard" } + assert response == {"name": "mock-guard"} def test_guard__raises(mocker): mock_request = MockRequest("POST") - + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) mocker.patch("src.blueprints.guards.get_tracer") mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import guard - + response = guard("guard") - + assert isinstance(response, Tuple) error, status = response assert isinstance(error, Dict) assert error.get("status") == 405 assert error.get("message") == "Method Not Allowed" - assert error.get("cause") == "/guard/ only supports the GET, PUT, and DELETE methods. You specified POST" + assert ( + error.get("cause") + == "/guard/ only supports the GET, PUT, and DELETE methods. You specified POST" + ) assert status == 405 def test_validate__raises_method_not_allowed(mocker): mock_request = MockRequest("PUT") - + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) mocker.patch("src.blueprints.guards.get_tracer") mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import validate - + response = validate("guard") - + assert isinstance(response, Tuple) error, status = response assert isinstance(error, Dict) assert error.get("status") == 405 assert error.get("message") == "Method Not Allowed" - assert error.get("cause") == "/guards//validate only supports the POST method. You specified PUT" + assert ( + error.get("cause") + == "/guards//validate only supports the POST method. You specified PUT" + ) assert status == 405 - - + + def test_validate__raises_bad_request__openai_api_key(mocker): mock_guard = MockGuardStruct() mock_tracer = MockTracer() - mock_request = MockRequest("POST", json={ "llmApi": "bar" }) - + mock_request = MockRequest("POST", json={"llmApi": "bar"}) + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch("src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard) + mock_get_guard = mocker.patch( + "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + ) mock_prep_environment = mocker.patch("src.blueprints.guards.prep_environment") mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import validate - + response = validate("My%20Guard's%20Name") - + assert mock_prep_environment.call_count == 1 assert mock_get_guard.called_once_with("My Guard's Name") - + assert isinstance(response, Tuple) error, status = response assert isinstance(error, Dict) @@ -205,22 +228,24 @@ def test_validate__raises_bad_request__openai_api_key(mocker): def test_validate__raises_bad_request__num_reasks(mocker): mock_guard = MockGuardStruct() mock_tracer = MockTracer() - mock_request = MockRequest("POST", json={ "numReasks": 3 }) - + mock_request = MockRequest("POST", json={"numReasks": 3}) + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch("src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard) + mock_get_guard = mocker.patch( + "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + ) mock_prep_environment = mocker.patch("src.blueprints.guards.prep_environment") mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) mocker.patch("src.utils.handle_error.logger.error") mocker.patch("src.utils.handle_error.traceback.print_exception") from src.blueprints.guards import validate - + response = validate("My%20Guard's%20Name") - + assert mock_prep_environment.call_count == 1 assert mock_get_guard.called_once_with("My Guard's Name") - + assert isinstance(response, Tuple) error, status = response assert isinstance(error, Dict) @@ -238,47 +263,51 @@ def test_validate__parse(mocker): mock_parse.return_value = ValidationOutcome( raw_llm_output="Hello world!", validated_output="Hello world!", - validation_passed=True + validation_passed=True, ) mock_guard = MockGuardStruct() mock_tracer = MockTracer() mock_request = MockRequest( "POST", - json={ "llmOutput": "Hello world!", "args": [1, 2, 3], "some_kwarg": "foo" } + json={"llmOutput": "Hello world!", "args": [1, 2, 3], "some_kwarg": "foo"}, ) - + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch("src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard) + mock_get_guard = mocker.patch( + "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + ) mock_prep_environment = mocker.patch("src.blueprints.guards.prep_environment") mock_cleanup_environment = mocker.patch("src.blueprints.guards.cleanup_environment") mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) - + set_attribute_spy = mocker.spy(mock_tracer.span, "set_attribute") - - mock_status = mocker.patch("guardrails.classes.history.call.Call.status", new_callable=PropertyMock) - mock_status.return_value = "pass" - mock_guard.history = Stack( - Call(inputs=CallInputs(prompt="Hello world prompt!")) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock ) + mock_status.return_value = "pass" + mock_guard.history = Stack(Call(inputs=CallInputs(prompt="Hello world prompt!"))) from src.blueprints.guards import validate - + response = validate("My%20Guard's%20Name") - + assert mock_prep_environment.call_count == 1 assert mock_get_guard.called_once_with("My Guard's Name") - + assert mock_parse.call_count == 1 - + assert mock_parse.called_once_with( - 1, 2, 3, + 1, + 2, + 3, llm_output="Hello world!", num_reasks=0, prompt_params=None, llm_api=None, - some_kwarg="foo" + some_kwarg="foo", ) - + assert set_attribute_spy.call_count == 7 expected_calls = [ call("guardName", "My Guard's Name"), @@ -290,66 +319,76 @@ def test_validate__parse(mocker): call("num_of_reasks", 0), ] set_attribute_spy.assert_has_calls(expected_calls) - + assert mock_cleanup_environment.call_count == 1 - - assert response == {"result": True, "validatedOutput": "Hello world!", "sessionHistory": [{"history": []}], "rawLlmResponse": "Hello world!"} - - + + assert response == { + "result": True, + "validatedOutput": "Hello world!", + "sessionHistory": [{"history": []}], + "rawLlmResponse": "Hello world!", + } + + def test_validate__call(mocker): mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") mock___call__.return_value = ValidationOutcome( - raw_llm_output="Hello world!", - validated_output=None, - validation_passed=False + raw_llm_output="Hello world!", validated_output=None, validation_passed=False ) mock_guard = MockGuardStruct() mock_tracer = MockTracer() - mock_request = MockRequest("POST", json={ + mock_request = MockRequest( + "POST", + json={ "llmApi": "openai.Completion.create", - "promptParams": { "p1": "bar" }, + "promptParams": {"p1": "bar"}, "args": [1, 2, 3], - "some_kwarg": "foo" + "some_kwarg": "foo", }, - headers={ "x-openai-api-key": "mock-key" } + headers={"x-openai-api-key": "mock-key"}, ) - + mocker.patch("flask.Blueprint", new=MockBlueprint) mocker.patch("src.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch("src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard) + mock_get_guard = mocker.patch( + "src.blueprints.guards.GuardClient.get_guard", return_value=mock_guard + ) mock_prep_environment = mocker.patch("src.blueprints.guards.prep_environment") mock_cleanup_environment = mocker.patch("src.blueprints.guards.cleanup_environment") mocker.patch("src.blueprints.guards.get_tracer", return_value=mock_tracer) - + set_attribute_spy = mocker.spy(mock_tracer.span, "set_attribute") - - mock_status = mocker.patch("guardrails.classes.history.call.Call.status", new_callable=PropertyMock) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock + ) mock_status.return_value = "fail" mock_guard.history = Stack( Call( inputs=CallInputs( - prompt="Hello world prompt!", - instructions="Hello world instructions!" + prompt="Hello world prompt!", instructions="Hello world instructions!" ) ) ) from src.blueprints.guards import validate - + response = validate("My%20Guard's%20Name") - + assert mock_prep_environment.call_count == 1 assert mock_get_guard.called_once_with("My Guard's Name") - + assert mock___call__.call_count == 1 - + assert mock___call__.called_once_with( - 1, 2, 3, + 1, + 2, + 3, llm_api="openai.Completion.create", - prompt_params={ "p1": "bar" }, + prompt_params={"p1": "bar"}, num_reasks=0, - some_kwarg="foo" + some_kwarg="foo", ) - + assert set_attribute_spy.call_count == 8 expected_calls = [ call("guardName", "My Guard's Name"), @@ -362,9 +401,12 @@ def test_validate__call(mocker): call("num_of_reasks", 0), ] set_attribute_spy.assert_has_calls(expected_calls) - + assert mock_cleanup_environment.call_count == 1 - - assert response == {"result": False, "validatedOutput": None, "sessionHistory": [{"history": []}], "rawLlmResponse": "Hello world!"} - - \ No newline at end of file + + assert response == { + "result": False, + "validatedOutput": None, + "sessionHistory": [{"history": []}], + "rawLlmResponse": "Hello world!", + } diff --git a/tests/blueprints/test_root.py b/tests/blueprints/test_root.py index add92b1..e0a85fe 100644 --- a/tests/blueprints/test_root.py +++ b/tests/blueprints/test_root.py @@ -15,29 +15,29 @@ def test_home(mocker): mocker.resetall() + def test_health_check(mocker): mocker.patch("flask.Blueprint", new=MockBlueprint) mock_pg = MockPostgresClient() mock_pg.db.session._set_rows([(1,)]) mocker.patch("src.blueprints.root.PostgresClient", return_value=mock_pg) - + def text_side_effect(query: str): return query + mock_text = mocker.patch("src.blueprints.root.text", side_effect=text_side_effect) - - from src.blueprints.root import ( - health_check - ) + + from src.blueprints.root import health_check + info_spy = mocker.spy(logger, "info") - + response = health_check() - + assert mock_text.called_once_with("SELECT count(datid) FROM pg_stat_activity;") assert mock_pg.db.session.queries == ["SELECT count(datid) FROM pg_stat_activity;"] - + info_spy.assert_called_once_with("response: %s", [(1,)]) - assert response == { "status": 200, "message": "Ok" } + assert response == {"status": 200, "message": "Ok"} mocker.resetall() - diff --git a/tests/clients/test_guard_client.py b/tests/clients/test_guard_client.py index e746d22..660027f 100644 --- a/tests/clients/test_guard_client.py +++ b/tests/clients/test_guard_client.py @@ -12,42 +12,49 @@ def test_init(mocker): mock_pg_client = MockPostgresClient() mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - + guard_client = GuardClient() - + assert guard_client.initialized is True assert isinstance(guard_client.pgClient, MockPostgresClient) assert guard_client.pgClient == mock_pg_client - + + class TestGetGuard: def test_get_latest(self, mocker): mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - + mocker.patch( + "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + ) + query_spy = mocker.spy(mock_pg_client.db.session, "query") filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") mock_first = mocker.patch.object(mock_pg_client.db.session, "first") latest_guard = MockGuardStruct("latest") mock_first.return_value = latest_guard - - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") + + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) mock_from_guard_item.return_value = latest_guard - + guard_client = GuardClient() - + result = guard_client.get_guard("guard") - + query_spy.assert_called_once_with(GuardItem) filter_by_spy.assert_called_once_with(name="guard") assert mock_first.call_count == 1 assert mock_from_guard_item.called_once_with(latest_guard) - + assert result == latest_guard - + def test_with_as_of_date(self, mocker): mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - + mocker.patch( + "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + ) + query_spy = mocker.spy(mock_pg_client.db.session, "query") filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") filter_spy = mocker.spy(mock_pg_client.db.session, "filter") @@ -56,141 +63,141 @@ def test_with_as_of_date(self, mocker): latest_guard = MockGuardStruct("latest") previous_guard = MockGuardStruct("previous") mock_first.side_effect = [latest_guard, previous_guard] - - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") + + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) mock_from_guard_item.return_value = previous_guard - + guard_client = GuardClient() - + result = guard_client.get_guard("guard", as_of_date="2024-03-06") - + assert query_spy.call_count == 2 - query_calls = [ - call(GuardItem), - call(GuardItemAudit) - ] + query_calls = [call(GuardItem), call(GuardItemAudit)] query_spy.assert_has_calls(query_calls) - - filter_by_calls = [ - call(name="guard"), - call(name="guard") - ] + + filter_by_calls = [call(name="guard"), call(name="guard")] assert filter_by_spy.call_count == 2 filter_by_spy.assert_has_calls(filter_by_calls) - + replaced_on_exp = GuardItemAudit.replaced_on > "2024-03-06" filter_spy_call = filter_spy.call_args[0][0] assert replaced_on_exp.compare(filter_spy_call) - + replaced_on_order_exp = GuardItemAudit.replaced_on.asc() order_by_spy_call = order_by_spy.call_args[0][0] assert replaced_on_order_exp.compare(order_by_spy_call) - + assert mock_first.call_count == 2 assert mock_from_guard_item.called_once_with(previous_guard) - + assert result == previous_guard - - + def test_raises_not_found(self, mocker): mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - + mocker.patch( + "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + ) + mock_first = mocker.patch.object(mock_pg_client.db.session, "first") mock_first.return_value = None - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") - + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) + guard_client = GuardClient() - + with pytest.raises(HttpError) as exc_info: guard_client.get_guard("guard") - + assert mock_first.call_count == 1 assert mock_from_guard_item.call_count == 0 - + assert isinstance(exc_info.value, HttpError) assert exc_info.value.status == 404 assert exc_info.value.message == "NotFound" assert exc_info.value.cause == "A Guard with the name guard does not exist!" - - + + def test_get_guard_item(mocker): mock_pg_client = MockPostgresClient() mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - + query_spy = mocker.spy(mock_pg_client.db.session, "query") filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") mock_first = mocker.patch.object(mock_pg_client.db.session, "first") latest_guard = MockGuardStruct("latest") mock_first.return_value = latest_guard - + guard_client = GuardClient() - + result = guard_client.get_guard_item("guard") - + query_spy.assert_called_once_with(GuardItem) filter_by_spy.assert_called_once_with(name="guard") assert mock_first.call_count == 1 - + assert result == latest_guard def test_get_guards(mocker): mock_pg_client = MockPostgresClient() mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - + query_spy = mocker.spy(mock_pg_client.db.session, "query") mock_all = mocker.patch.object(mock_pg_client.db.session, "all") guard_one = MockGuardStruct("guard one") guard_two = MockGuardStruct("guard two") guards = [guard_one, guard_two] mock_all.return_value = guards - - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") + + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) mock_from_guard_item.side_effect = [guard_one, guard_two] - + guard_client = GuardClient() - + result = guard_client.get_guards() - + query_spy.assert_called_once_with(GuardItem) assert mock_all.call_count == 1 - + assert mock_from_guard_item.call_count == 2 - from_guard_item_calls = [ - call(guard_one), - call(guard_two) - ] + from_guard_item_calls = [call(guard_one), call(guard_two)] mock_from_guard_item.assert_has_calls(from_guard_item_calls) - + assert result == [guard_one, guard_two] - + + def test_create_guard(mocker): mock_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() mock_guard_struct_init_spy = mocker.spy(MockGuardStruct, "__init__") mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) mocker.patch("src.clients.guard_client.GuardItem", new=MockGuardStruct) - + add_spy = mocker.spy(mock_pg_client.db.session, "add") commit_spy = mocker.spy(mock_pg_client.db.session, "commit") - - - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") + + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) mock_from_guard_item.return_value = mock_guard - + guard_client = GuardClient() - + result = guard_client.create_guard(mock_guard) - + mock_guard_struct_init_spy.assert_called_once_with( AnyMatcher, name="mock-guard", railspec={}, num_reasks=0, - description="mock guard description" + description="mock guard description", ) - + assert add_spy.call_count == 1 mock_guard_item = add_spy.call_args[0][0] assert isinstance(mock_guard_item, MockGuardStruct) @@ -199,62 +206,74 @@ def test_create_guard(mocker): assert mock_guard_item.num_reasks == 0 assert mock_guard_item.description == "mock guard description" assert commit_spy.call_count == 1 - + assert result == mock_guard - - + + class TestUpdateGuard: def test_raises_not_found(self, mocker): mock_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - mock_get_guard_item = mocker.patch("src.clients.guard_client.GuardClient.get_guard_item") + mocker.patch( + "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + ) + mock_get_guard_item = mocker.patch( + "src.clients.guard_client.GuardClient.get_guard_item" + ) mock_get_guard_item.return_value = None - + commit_spy = mocker.spy(mock_pg_client.db.session, "commit") - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") - + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) + guard_client = GuardClient() - + with pytest.raises(HttpError) as exc_info: guard_client.update_guard("mock-guard", mock_guard) - + assert isinstance(exc_info.value, HttpError) assert exc_info.value.status == 404 assert exc_info.value.message == "NotFound" - assert exc_info.value.cause == "A Guard with the name mock-guard does not exist!" - + assert ( + exc_info.value.cause == "A Guard with the name mock-guard does not exist!" + ) + assert commit_spy.call_count == 0 assert mock_from_guard_item.call_count == 0 - + def test_updates_guard_item(self, mocker): old_guard = MockGuardStruct() - updated_guard = MockGuardStruct( - num_reasks=2 - ) + updated_guard = MockGuardStruct(num_reasks=2) mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - mock_get_guard_item = mocker.patch("src.clients.guard_client.GuardClient.get_guard_item") + mocker.patch( + "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + ) + mock_get_guard_item = mocker.patch( + "src.clients.guard_client.GuardClient.get_guard_item" + ) mock_get_guard_item.return_value = old_guard - - to_dict_spy = mocker.spy(updated_guard.railspec, 'to_dict') + + to_dict_spy = mocker.spy(updated_guard.railspec, "to_dict") commit_spy = mocker.spy(mock_pg_client.db.session, "commit") - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) mock_from_guard_item.return_value = updated_guard - + guard_client = GuardClient() - + result = guard_client.update_guard("mock-guard", updated_guard) - + mock_get_guard_item.assert_called_once_with("mock-guard") assert to_dict_spy.call_count == 1 assert commit_spy.call_count == 1 mock_from_guard_item.assert_called_once_with(old_guard) - + # These would have been updated by reference assert old_guard.railspec == updated_guard.railspec.to_dict() assert old_guard.num_reasks == 2 - + assert result == updated_guard @@ -263,100 +282,125 @@ def test_guard_doesnt_exist_yet(self, mocker): input_guard = MockGuardStruct() new_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - mock_get_guard_item = mocker.patch("src.clients.guard_client.GuardClient.get_guard_item") + mocker.patch( + "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + ) + mock_get_guard_item = mocker.patch( + "src.clients.guard_client.GuardClient.get_guard_item" + ) mock_get_guard_item.return_value = None - + commit_spy = mocker.spy(mock_pg_client.db.session, "commit") - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") - mock_create_guard = mocker.patch("src.clients.guard_client.GuardClient.create_guard") + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) + mock_create_guard = mocker.patch( + "src.clients.guard_client.GuardClient.create_guard" + ) mock_create_guard.return_value = new_guard - + guard_client = GuardClient() - + result = guard_client.upsert_guard("mock-guard", input_guard) - + assert mock_get_guard_item.called_once_with("mock-guard") assert commit_spy.call_count == 0 assert mock_from_guard_item.call_count == 0 assert mock_create_guard.called_once_with(input_guard) - + assert result == new_guard - + def test_guard_already_exists(self, mocker): old_guard = MockGuardStruct() - updated_guard = MockGuardStruct( - num_reasks=2, - description="updated description" - ) + updated_guard = MockGuardStruct(num_reasks=2, description="updated description") mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - mock_get_guard_item = mocker.patch("src.clients.guard_client.GuardClient.get_guard_item") + mocker.patch( + "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + ) + mock_get_guard_item = mocker.patch( + "src.clients.guard_client.GuardClient.get_guard_item" + ) mock_get_guard_item.return_value = old_guard - - to_dict_spy = mocker.spy(updated_guard.railspec, 'to_dict') + + to_dict_spy = mocker.spy(updated_guard.railspec, "to_dict") commit_spy = mocker.spy(mock_pg_client.db.session, "commit") - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) mock_from_guard_item.return_value = updated_guard - + guard_client = GuardClient() - + result = guard_client.upsert_guard("mock-guard", updated_guard) - + mock_get_guard_item.assert_called_once_with("mock-guard") assert to_dict_spy.call_count == 1 assert commit_spy.call_count == 1 mock_from_guard_item.assert_called_once_with(old_guard) - + # These would have been updated by reference assert old_guard.railspec == updated_guard.railspec.to_dict() assert old_guard.num_reasks == 2 - + assert result == updated_guard class TestDeleteGuard: def test_raises_not_found(self, mocker): mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - mock_get_guard_item = mocker.patch("src.clients.guard_client.GuardClient.get_guard_item") + mocker.patch( + "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + ) + mock_get_guard_item = mocker.patch( + "src.clients.guard_client.GuardClient.get_guard_item" + ) mock_get_guard_item.return_value = None - + commit_spy = mocker.spy(mock_pg_client.db.session, "commit") - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") - + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) + guard_client = GuardClient() - + with pytest.raises(HttpError) as exc_info: guard_client.delete_guard("mock-guard") - + assert isinstance(exc_info.value, HttpError) assert exc_info.value.status == 404 assert exc_info.value.message == "NotFound" - assert exc_info.value.cause == "A Guard with the name mock-guard does not exist!" - + assert ( + exc_info.value.cause == "A Guard with the name mock-guard does not exist!" + ) + assert commit_spy.call_count == 0 assert mock_from_guard_item.call_count == 0 - + def test_deletes_guard_item(self, mocker): old_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() - mocker.patch("src.clients.guard_client.PostgresClient", return_value=mock_pg_client) - mock_get_guard_item = mocker.patch("src.clients.guard_client.GuardClient.get_guard_item") + mocker.patch( + "src.clients.guard_client.PostgresClient", return_value=mock_pg_client + ) + mock_get_guard_item = mocker.patch( + "src.clients.guard_client.GuardClient.get_guard_item" + ) mock_get_guard_item.return_value = old_guard - + delete_spy = mocker.spy(mock_pg_client.db.session, "delete") commit_spy = mocker.spy(mock_pg_client.db.session, "commit") - mock_from_guard_item = mocker.patch("src.clients.guard_client.GuardStruct.from_guard_item") + mock_from_guard_item = mocker.patch( + "src.clients.guard_client.GuardStruct.from_guard_item" + ) mock_from_guard_item.return_value = old_guard - + guard_client = GuardClient() - + result = guard_client.delete_guard("mock-guard") - + mock_get_guard_item.assert_called_once_with("mock-guard") assert delete_spy.call_count == 1 assert commit_spy.call_count == 1 mock_from_guard_item.assert_called_once_with(old_guard) - + assert result == old_guard diff --git a/tests/mocks/mock_blueprint.py b/tests/mocks/mock_blueprint.py index eca4641..2cca30c 100644 --- a/tests/mocks/mock_blueprint.py +++ b/tests/mocks/mock_blueprint.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List class MockBlueprint: @@ -7,6 +7,7 @@ class MockBlueprint: routes: List[str] methods: List[str] route_call_count: int + def __init__(self, name: str, module_name: str, **kwargs): self.name = name self.module_name = module_name @@ -19,6 +20,7 @@ def __init__(self, name: str, module_name: str, **kwargs): def route(self, route_name: str, methods: List[str] = []): def no_op(fn, *args): return fn + self.routes.append(route_name) self.methods.extend(methods) unique_methods = list(set(self.methods)) @@ -26,14 +28,11 @@ def no_op(fn, *args): self.route_call_count = self.route_call_count + 1 return no_op - def __getitem__(self, key): return getattr(self, key) def __setitem__(self, key, value): - setattr(self, key, value) - + setattr(self, key, value) + def __delitem__(self, key): delattr(self, key) - - \ No newline at end of file diff --git a/tests/mocks/mock_guard_client.py b/tests/mocks/mock_guard_client.py index d3d3840..5ffe424 100644 --- a/tests/mocks/mock_guard_client.py +++ b/tests/mocks/mock_guard_client.py @@ -2,39 +2,41 @@ class MockRailspec: def to_dict(self, *args, **kwargs): return {} + class MockGuardStruct: name: str description: str num_reasks: int history = [] - + def __init__( self, name: str = "mock-guard", num_reasks: str = 0, description: str = "mock guard description", - railspec = {} + railspec={}, ): self.name = name self.description = description self.num_reasks = num_reasks self.railspec = MockRailspec() - + def to_response(self): - return { "name": "mock-guard" } + return {"name": "mock-guard"} def to_guard(self, *args): return self - + def parse(self, *args, **kwargs): pass def __call__(self, *args, **kwargs): - pass + pass + class MockGuardClient: def get_guards(self): return [MockGuardStruct()] - + def create_guard(self, guard: MockGuardStruct): - return MockGuardStruct() \ No newline at end of file + return MockGuardStruct() diff --git a/tests/mocks/mock_postgres_client.py b/tests/mocks/mock_postgres_client.py index 3066024..4197882 100644 --- a/tests/mocks/mock_postgres_client.py +++ b/tests/mocks/mock_postgres_client.py @@ -1,5 +1,6 @@ from typing import Any, List + class MockSession: rows: List[Any] queries: List[str] @@ -8,17 +9,17 @@ def __init__(self) -> None: self.rows = [] self.queries = [] self.execute_calls = [] - + def execute(self, query): self.queries.append(query) return self def all(self): return self.rows - + def _set_rows(self, rows: List[Any]): self.rows = rows - + def query(self, *args, **kwargs): return self @@ -33,7 +34,7 @@ def order_by(self, *args, **kwargs): def first(self, *args, **kwargs): return self - + def add(self, *args, **kwargs): return self @@ -43,11 +44,12 @@ def delete(self, *args, **kwargs): def commit(self, *args, **kwargs): return self + class MockDb: def __init__(self) -> None: self.session = MockSession() -class MockPostgresClient: + +class MockPostgresClient: def __init__(self): self.db = MockDb() - \ No newline at end of file diff --git a/tests/mocks/mock_request.py b/tests/mocks/mock_request.py index d15313d..0285fd8 100644 --- a/tests/mocks/mock_request.py +++ b/tests/mocks/mock_request.py @@ -1,5 +1,6 @@ from typing import Dict, Optional + class MockRequest: method: str json: Optional[Dict] @@ -11,9 +12,9 @@ def __init__( method: str, json: Optional[Dict] = {}, args: Optional[Dict] = {}, - headers: Optional[Dict] = {} + headers: Optional[Dict] = {}, ): self.method = method self.json = json self.args = args - self.headers = headers \ No newline at end of file + self.headers = headers diff --git a/tests/mocks/mock_trace.py b/tests/mocks/mock_trace.py index 9406c31..5f398db 100644 --- a/tests/mocks/mock_trace.py +++ b/tests/mocks/mock_trace.py @@ -39,7 +39,9 @@ def __init__(self, span: Optional[MockSpan] = None): def start_as_current_span(self, *args, **kwargs): return self.span + class MockContext: - _id: str - def __init__(self): - self._id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6)) \ No newline at end of file + _id: str + + def __init__(self): + self._id = "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) diff --git a/tests/utils/test_pluck.py b/tests/utils/test_pluck.py index 6a4fa26..c757bf2 100644 --- a/tests/utils/test_pluck.py +++ b/tests/utils/test_pluck.py @@ -1,6 +1,7 @@ from src.utils.pluck import pluck + def test_pluck(): - input = { "a": 1, "b": 2, "c": 3 } - response = pluck(input, ["a", "c"]) - assert response == [1, 3] \ No newline at end of file + input = {"a": 1, "b": 2, "c": 3} + response = pluck(input, ["a", "c"]) + assert response == [1, 3] diff --git a/tests/utils/test_remove_nones.py b/tests/utils/test_remove_nones.py index ec1d9c6..9d5fb5b 100644 --- a/tests/utils/test_remove_nones.py +++ b/tests/utils/test_remove_nones.py @@ -3,52 +3,24 @@ def test_remove_nones(): dictionary = { - "complete_dictionary": { - "a": 1, - "b": 2 - }, - "partial_dictionary": { - "a": 1, - "b": None - }, - "empty_dictionary": { - "a": None, - "b": None - }, - "complete_list": [ - 1, - 2 - ], - "partial_list": [ - 1, - None - ], - "empty_list": [ - None, - None - ], + "complete_dictionary": {"a": 1, "b": 2}, + "partial_dictionary": {"a": 1, "b": None}, + "empty_dictionary": {"a": None, "b": None}, + "complete_list": [1, 2], + "partial_list": [1, None], + "empty_list": [None, None], "complete_primitive": 1, - "empty_primitive": None + "empty_primitive": None, } - + filtered = remove_nones(dictionary) - + assert filtered == { - "complete_dictionary": { - "a": 1, - "b": 2 - }, - "partial_dictionary": { - "a": 1 - }, + "complete_dictionary": {"a": 1, "b": 2}, + "partial_dictionary": {"a": 1}, "empty_dictionary": {}, - "complete_list": [ - 1, - 2 - ], - "partial_list": [ - 1 - ], + "complete_list": [1, 2], + "partial_list": [1], "empty_list": [], - "complete_primitive": 1 - } \ No newline at end of file + "complete_primitive": 1, + }