diff --git a/middleware/primary_resource_logic/data_requests.py b/middleware/primary_resource_logic/data_requests.py index 58eba153..4a809767 100644 --- a/middleware/primary_resource_logic/data_requests.py +++ b/middleware/primary_resource_logic/data_requests.py @@ -190,7 +190,7 @@ def get_data_requests_wrapper( } if dto.request_statuses is not None: db_client_additional_args["where_mappings"] = { - "request_statuses": [rs.value for rs in dto.request_statuses] + "request_status": [rs.value for rs in dto.request_statuses] } return get_many( middleware_parameters=MiddlewareParameters( diff --git a/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_documentation_construction.py b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_documentation_construction.py index 8f7d94bb..9c08e09a 100644 --- a/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_documentation_construction.py +++ b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_documentation_construction.py @@ -90,6 +90,46 @@ def add_description_info_from_enum( return description +class DescriptionBuilder: + """ + Appends information to description based on data + in field + """ + + def __init__(self, field: marshmallow_fields, description: str): + self.field = field + self.description = description + self.metadata = field.metadata + + def __str__(self): + return self.description + + def build_description(self): + self.enum_case() + self.validator_case() + self.list_query_case() + + def enum_case(self): + if isinstance(self.field, marshmallow_fields.Enum): + self.description = add_description_info_from_enum( + field_value=self.field, description=self.description + ) + + def validator_case(self): + for validator in self.field.validators: + if isinstance(validator, OneOf): + self.description = add_description_info_from_validators( + field=self.field, description=self.description + ) + + def list_query_case(self): + if not isinstance(self.field, marshmallow_fields.List): + return + if not self.metadata.get("source") == SourceMappingEnum.QUERY_ARGS: + return + self.description += " \n(Comma-delimited list)" + + class FieldInfo: def __init__( @@ -125,9 +165,9 @@ def _get_description( description = _get_required_argument( "description", metadata, schema_class, field_name ) - if isinstance(field_value, marshmallow_fields.Enum): - description = add_description_info_from_enum(field_value, description) - return add_description_info_from_validators(field_value, description) + desc_builder = DescriptionBuilder(field=field_value, description=description) + desc_builder.build_description() + return str(desc_builder) def _map_field_type(self, field_type: type[MarshmallowFields]) -> Type[RestxFields]: try: diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/data_requests_advanced_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/data_requests_advanced_schemas.py index 4fc8389b..f9cd03d7 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/data_requests_advanced_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/data_requests_advanced_schemas.py @@ -1,4 +1,4 @@ -from marshmallow import fields, Schema, post_load +from marshmallow import fields, Schema, post_load, pre_load from database_client.enums import RequestStatus from middleware.schema_and_dto_logic.primary_resource_schemas.data_requests_base_schema import ( @@ -139,15 +139,24 @@ def location_infos_convert_empty_list_to_none(self, in_data, **kwargs): class GetManyDataRequestsRequestsSchema(GetManyRequestsBaseSchema): request_statuses = fields.List( - fields.Enum( - enum=RequestStatus, - by_value=fields.Str, - allow_none=True, - metadata=get_query_metadata("The status of the requests to return."), - ), + fields.Enum( + enum=RequestStatus, + by_value=fields.Str, + allow_none=True, + metadata=get_query_metadata("The status of the requests to return."), + ), metadata=get_query_metadata("The status of the requests to return."), ) + @pre_load + def listify_request_statuses(self, in_data, **kwargs): + request_statuses = in_data.get("request_statuses", None) + if request_statuses is None: + return in_data + in_data["request_statuses"] = request_statuses.split(",") + + return in_data + GetManyDataRequestsResponseSchema = create_get_many_schema( data_list_schema=DataRequestsGetSchemaBase, diff --git a/middleware/util.py b/middleware/util.py index 1367a2f6..47a4314f 100644 --- a/middleware/util.py +++ b/middleware/util.py @@ -142,3 +142,19 @@ def get_temp_directory() -> str: # Go to root directory return os.path.join(os.getcwd(), "temp") + + +def stringify_list_of_ints(l: list[int]): + for i in range(len(l)): + l[i] = str(l[i]) + return l + + +def stringify_lists(d: dict): + for k, v in d.items(): + if isinstance(v, dict): + stringify_lists(v) + if isinstance(v, list): + v = stringify_list_of_ints(v) + d[k] = ",".join(v) + return d diff --git a/tests/helper_scripts/helper_classes/RequestValidator.py b/tests/helper_scripts/helper_classes/RequestValidator.py index e4dfe4d4..34147399 100644 --- a/tests/helper_scripts/helper_classes/RequestValidator.py +++ b/tests/helper_scripts/helper_classes/RequestValidator.py @@ -354,7 +354,6 @@ def update_data_request( def get_data_requests( self, headers: dict, - request_status: Optional[RequestStatus] = None, sort_by: Optional[str] = None, sort_order: Optional[SortOrder] = None, request_statuses: Optional[list[RequestStatus]] = None, @@ -366,7 +365,9 @@ def get_data_requests( "sort_by": sort_by, "sort_order": sort_order.value if sort_order is not None else None, "request_statuses": ( - [rs.value for rs in request_statuses] if request_statuses is not None else None + [rs.value for rs in request_statuses] + if request_statuses is not None + else None ), }, ) diff --git a/tests/helper_scripts/helper_functions_simple.py b/tests/helper_scripts/helper_functions_simple.py index fa8d2676..6d809230 100644 --- a/tests/helper_scripts/helper_functions_simple.py +++ b/tests/helper_scripts/helper_functions_simple.py @@ -1,6 +1,8 @@ from datetime import datetime, timezone, timedelta from urllib.parse import urlparse, parse_qs, urlencode, urlunparse +from middleware.util import stringify_lists + """ `Simple` in this case refers to helper functions which rely on no external dependencies @@ -22,6 +24,8 @@ def add_query_params(url, params: dict): :return: """ + stringify_lists(d=params) + # Parse the original URL into components url_parts = list(urlparse(url)) diff --git a/tests/integration/test_bulk.py b/tests/integration/test_bulk.py index 32e964bb..1d1ff3c0 100644 --- a/tests/integration/test_bulk.py +++ b/tests/integration/test_bulk.py @@ -21,6 +21,7 @@ AgenciesPutBatchRequestSchema, DataSourcesPutBatchRequestSchema, ) +from middleware.util import stringify_lists from tests.helper_scripts.common_test_data import get_test_name from tests.helper_scripts.common_asserts import assert_contains_key_value_pairs from tests.helper_scripts.helper_classes.RequestValidator import RequestValidator @@ -34,22 +35,6 @@ ) -def stringify_list_of_ints(l: list[int]): - for i in range(len(l)): - l[i] = str(l[i]) - return l - - -def stringify_lists(d: dict): - for k, v in d.items(): - if isinstance(v, dict): - stringify_lists(v) - if isinstance(v, list): - v = stringify_list_of_ints(v) - d[k] = ",".join(v) - return d - - @dataclass class BatchTestRunner: tdc: TestDataCreatorFlask diff --git a/tests/integration/test_data_requests.py b/tests/integration/test_data_requests.py index 619b6b2a..70efdcb9 100644 --- a/tests/integration/test_data_requests.py +++ b/tests/integration/test_data_requests.py @@ -99,15 +99,13 @@ def test_data_requests_get( headers=tus_creator.jwt_authorization_header, )[DATA_KEY] - - # Assert admin columns are greater than user columns assert len(admin_data[0]) > len(data[0]) # Run get again, this time filtering the request status to be active data = tdc.request_validator.get_data_requests( headers=tus_creator.jwt_authorization_header, - request_status=RequestStatus.ACTIVE, + request_statuses=[RequestStatus.ACTIVE], )[DATA_KEY] # The more recent data request should be returned, but the old one should be filtered out @@ -121,7 +119,7 @@ def test_data_requests_get( def get_sorted_data_requests(sort_order: SortOrder): return tdc.request_validator.get_data_requests( headers=tus_creator.jwt_authorization_header, - request_status=RequestStatus.INTAKE, + request_statuses=[RequestStatus.INTAKE], sort_by="id", sort_order=sort_order, )