Skip to content

Commit

Permalink
Test underlying search logic and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
anthonyhashemi committed Oct 26, 2023
1 parent 3196280 commit c62dd78
Show file tree
Hide file tree
Showing 9 changed files with 402 additions and 107 deletions.
61 changes: 61 additions & 0 deletions app/main/aws/open_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Tuple

import boto3
from opensearchpy import AWSV4SignerAuth, OpenSearch

from app.main.aws.parameter import (
get_aws_environment_prefix,
get_parameter_store_key_value,
)


def get_open_search_index_from_aws_params() -> str:
return get_parameter_store_key_value(
get_aws_environment_prefix() + "AWS_OPEN_SEARCH_INDEX"
)


def generate_open_search_client_from_aws_params() -> OpenSearch:
host = get_parameter_store_key_value(
get_aws_environment_prefix() + "AWS_OPEN_SEARCH_HOST"
)
http_auth = _get_open_search_http_auth()

open_search_client = OpenSearch(
hosts=[{"host": host, "port": 443}],
http_auth=http_auth,
use_ssl=True,
verify_certs=True,
http_compress=True,
ssl_assert_hostname=False,
ssl_show_warn=True,
)
return open_search_client


def _get_open_search_http_auth(
auth_method: str = "username_password",
) -> Tuple[str, str] | AWSV4SignerAuth:
if auth_method == "username_password":
return _get_open_search_username_password_auth()
return _get_open_search_iam_auth()


def _get_open_search_username_password_auth() -> Tuple[str, str]:
username = get_parameter_store_key_value(
get_aws_environment_prefix() + "AWS_OPEN_SEARCH_USERNAME"
)
password = get_parameter_store_key_value(
get_aws_environment_prefix() + "AWS_OPEN_SEARCH_PASSWORD"
)
return (username, password)


def _get_open_search_iam_auth() -> AWSV4SignerAuth:
credentials = boto3.Session().get_credentials()
aws_region = get_parameter_store_key_value(
get_aws_environment_prefix() + "AWS_REGION"
)
service = "es"
aws_auth = AWSV4SignerAuth(credentials, aws_region, service)
return aws_auth
4 changes: 2 additions & 2 deletions app/main/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from flask_wtf.csrf import CSRFError

from app.main.search import open_search
from app.main.search import search_logic
from .forms import SearchForm
from werkzeug.exceptions import HTTPException
import os
Expand Down Expand Up @@ -96,7 +96,7 @@ def poc_search():

if query:
open_search_response = (
open_search.generate_open_search_client_and_make_poc_search(query)
search_logic.generate_open_search_client_and_make_poc_search(query)
)
results = open_search_response["hits"]["hits"]

Expand Down
101 changes: 0 additions & 101 deletions app/main/search/open_search.py

This file was deleted.

48 changes: 48 additions & 0 deletions app/main/search/search_logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
from typing import Any

from opensearchpy import ImproperlyConfigured
from app.main.aws.open_search import (
generate_open_search_client_from_aws_params,
get_open_search_index_from_aws_params,
)


def generate_open_search_client_and_make_poc_search(query: str) -> Any:
fields = [
"legal_status",
"description",
"closure_type",
"Internal-Sender_Identifier",
"id",
"Contact_Email",
"Source_Organization",
"Consignment_Series.keyword",
"Consignment_Series",
"Contact_Name",
]
open_search_client = generate_open_search_client_from_aws_params()

try:
open_search_client.ping()
except ImproperlyConfigured as e:
logging.error("OpenSearch client improperly configured: " + str(e))
raise e

logging.info("OpenSearch client has been connected successfully")

open_search_index = get_open_search_index_from_aws_params()
open_search_query = {
"query": {
"multi_match": {
"query": query,
"fields": fields,
"fuzziness": "AUTO",
"type": "best_fields",
}
}
}
search_results = open_search_client.search(
body=open_search_query, index=open_search_index
)
return search_results
71 changes: 71 additions & 0 deletions app/tests/test_aws_open_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from unittest.mock import patch
from moto import mock_ssm
import boto3

from app.main.aws.open_search import (
get_open_search_index_from_aws_params,
generate_open_search_client_from_aws_params,
)


@mock_ssm
def test_get_open_search_index_from_aws_params():
ssm_client = boto3.client("ssm", region_name="eu-west-2")
ssm_client.put_parameter(
Name="ENVIRONMENT_NAME",
Value="test_env",
Type="String",
Overwrite=True,
)
ssm_client.put_parameter(
Name="/test_env/AWS_OPEN_SEARCH_INDEX",
Value="test_index",
Type="String",
Overwrite=True,
)

assert get_open_search_index_from_aws_params() == "test_index"


@mock_ssm
@patch("app.main.aws.open_search.OpenSearch")
def test_generate_open_search_client_from_aws_params(mock_open_search):
ssm_client = boto3.client("ssm", region_name="eu-west-2")
ssm_client.put_parameter(
Name="ENVIRONMENT_NAME",
Value="test_env",
Type="String",
Overwrite=True,
)
ssm_client.put_parameter(
Name="/test_env/AWS_OPEN_SEARCH_HOST",
Value="mock_opensearch_host",
Type="String",
Overwrite=True,
)
ssm_client.put_parameter(
Name="/test_env/AWS_OPEN_SEARCH_USERNAME",
Value="mock_username",
Type="String",
Overwrite=True,
)
ssm_client.put_parameter(
Name="/test_env/AWS_OPEN_SEARCH_PASSWORD",
Value="mock_password",
Type="String",
Overwrite=True,
)

assert (
generate_open_search_client_from_aws_params() == mock_open_search.return_value
)

mock_open_search.assert_called_once_with(
hosts=[{"host": "mock_opensearch_host", "port": 443}],
http_auth=("mock_username", "mock_password"),
use_ssl=True,
verify_certs=True,
http_compress=True,
ssl_assert_hostname=False,
ssl_show_warn=True,
)
6 changes: 3 additions & 3 deletions app/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from bs4 import BeautifulSoup
from flask.testing import FlaskClient


def test_poc_search_get(client: FlaskClient):
"""
Given a user accessing the search page
Expand Down Expand Up @@ -29,7 +30,7 @@ def test_poc_search_no_query(client: FlaskClient):
assert b"records found" not in response.data


@patch("app.main.routes.open_search.generate_open_search_client_and_make_poc_search")
@patch("app.main.routes.search_logic.generate_open_search_client_and_make_poc_search")
def test_poc_search_with_no_results(mock_open_search, client: FlaskClient):
"""
Given a user with a search query
Expand All @@ -45,7 +46,7 @@ def test_poc_search_with_no_results(mock_open_search, client: FlaskClient):
assert b"records found" not in response.data


@patch("app.main.routes.open_search.generate_open_search_client_and_make_poc_search")
@patch("app.main.routes.search_logic.generate_open_search_client_and_make_poc_search")
def test_poc_search_results_displayed(mock_open_search, client: FlaskClient):
"""
Given a user with a search query which should return n results
Expand Down Expand Up @@ -114,4 +115,3 @@ def test_poc_search_results_displayed(mock_open_search, client: FlaskClient):
assert [result.text for result in row.find_all("td")] == expected_results_table[
row_index + 1
]

Loading

0 comments on commit c62dd78

Please sign in to comment.