diff --git a/app.py b/app.py index ef0aaa09..82f38bba 100644 --- a/app.py +++ b/app.py @@ -9,7 +9,6 @@ from resources.Callback import namespace_auth from resources.DataRequests import namespace_data_requests from resources.GithubDataRequests import namespace_github -from resources.HomepageSearchCache import namespace_homepage_search_cache from resources.LinkToGithub import namespace_link_to_github from resources.LoginWithGithub import namespace_login_with_github from resources.Map import namespace_map @@ -109,9 +108,7 @@ def default(self, o): def create_app() -> Flask: psycopg2_connection = initialize_psycopg_connection() config.connection = psycopg2_connection - api = Api() - for namespace in NAMESPACES: - api.add_namespace(namespace) + api = get_api_with_namespaces() app = Flask(__name__) app.json = UpdatedJSONProvider(app) @@ -132,6 +129,13 @@ def create_app() -> Flask: return app +def get_api_with_namespaces(): + api = Api() + for namespace in NAMESPACES: + api.add_namespace(namespace) + return api + + if __name__ == "__main__": app = create_app() app.run(host=os.getenv("FLASK_RUN_HOST", "127.0.0.1")) diff --git a/database_client/db_client_dataclasses.py b/database_client/db_client_dataclasses.py index b2004b4f..3bff15f4 100644 --- a/database_client/db_client_dataclasses.py +++ b/database_client/db_client_dataclasses.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Any, Callable, Optional +from pydantic import BaseModel from sqlalchemy.sql.expression import UnaryExpression from sqlalchemy.schema import Column from sqlalchemy.sql.expression import asc, desc, BinaryExpression @@ -16,8 +17,7 @@ } -@dataclass -class OrderByParameters: +class OrderByParameters(BaseModel): """ Contains parameters for an order_by clause """ diff --git a/database_client/subquery_logic.py b/database_client/subquery_logic.py index c34ebc6a..365d32db 100644 --- a/database_client/subquery_logic.py +++ b/database_client/subquery_logic.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from functools import partialmethod +from typing import Optional +from pydantic import BaseModel from sqlalchemy.orm import defaultload, joinedload from sqlalchemy.sql.base import ExecutableOption @@ -8,15 +10,14 @@ from middleware.enums import Relations -@dataclass -class SubqueryParameters: +class SubqueryParameters(BaseModel): """ Contains parameters for executing a subquery """ relation_name: str linking_column: str - columns: list[str] = None + columns: Optional[list[str]] = None def set_columns(self, columns: list[str]) -> None: self.columns = columns diff --git a/manual_tests/test_api_spec.py b/manual_tests/test_api_spec.py new file mode 100644 index 00000000..cb45cdde --- /dev/null +++ b/manual_tests/test_api_spec.py @@ -0,0 +1,9 @@ +import json + +from app import get_api_with_namespaces, create_app + + +def test_api_spec(): + app = create_app() + api = get_api_with_namespaces() + print(json.dumps(api.__schema__)) diff --git a/middleware/access_logic.py b/middleware/access_logic.py index 33c95bcf..8c35105c 100644 --- a/middleware/access_logic.py +++ b/middleware/access_logic.py @@ -1,11 +1,12 @@ -from dataclasses import dataclass from http import HTTPStatus from typing import Optional from flask import request from flask_jwt_extended import get_jwt_identity, verify_jwt_in_request +from flask_jwt_extended.exceptions import NoAuthorizationError from flask_restx import abort from jwt import ExpiredSignatureError +from pydantic import BaseModel from database_client.database_client import DatabaseClient from middleware.SimpleJWT import SimpleJWT, JWTPurpose @@ -16,11 +17,11 @@ InvalidAPIKeyException, InvalidAuthorizationHeaderException, ) +from middleware.flask_response_manager import FlaskResponseManager from middleware.primary_resource_logic.permissions_logic import get_user_permissions -@dataclass -class AuthenticationInfo: +class AuthenticationInfo(BaseModel): """ A dataclass providing information on how the user was authenticated """ @@ -62,19 +63,16 @@ def is_access_type_allowed(self, access_type: AccessTypeEnum) -> bool: return access_type in self.allowed_access_methods -@dataclass -class AccessInfoBase: - pass +class AccessInfoBase(BaseModel): + access_type: AccessTypeEnum -@dataclass class AccessInfoPrimary(AccessInfoBase): """ A dataclass providing information on how the endpoint was accessed """ user_email: str - access_type: AccessTypeEnum user_id: Optional[int] = None permissions: list[PermissionsEnum] = None @@ -84,17 +82,15 @@ def get_user_id(self) -> Optional[int]: return self.user_id -@dataclass class PasswordResetTokenAccessInfo(AccessInfoBase): - access_type = AccessTypeEnum.RESET_PASSWORD + access_type: AccessTypeEnum = AccessTypeEnum.RESET_PASSWORD user_id: int user_email: str reset_token: str -@dataclass class ValidateEmailTokenAccessInfo(AccessInfoBase): - access_type = AccessTypeEnum.VALIDATE_EMAIL + access_type: AccessTypeEnum = AccessTypeEnum.VALIDATE_EMAIL validate_email_token: str @@ -104,7 +100,11 @@ def get_identity(): try: verify_jwt_in_request() return get_jwt_identity() - except Exception: + except NoAuthorizationError: + FlaskResponseManager.abort( + HTTPStatus.BAD_REQUEST, message="Token is missing" + ) + except Exception as e: return None @staticmethod @@ -160,7 +160,9 @@ def get_authorization_header_from_request() -> str: try: return headers["Authorization"] except (KeyError, TypeError): - raise InvalidAuthorizationHeaderException + FlaskResponseManager.abort( + code=HTTPStatus.BAD_REQUEST, message="Authorization header missing" + ) def get_key_from_authorization_header( @@ -168,6 +170,11 @@ def get_key_from_authorization_header( ) -> str: try: authorization_header_parts = authorization_header.split(" ") + if len(authorization_header_parts) != 2: + FlaskResponseManager.abort( + code=HTTPStatus.BAD_REQUEST, + message="Improperly formatted authorization header", + ) if authorization_header_parts[0] != scheme: raise InvalidAPIKeyException return authorization_header_parts[1] diff --git a/middleware/column_permission_logic.py b/middleware/column_permission_logic.py index 10627f70..abdc9152 100644 --- a/middleware/column_permission_logic.py +++ b/middleware/column_permission_logic.py @@ -1,9 +1,7 @@ -from dataclasses import dataclass -from enum import Enum from http import HTTPStatus from typing import Optional -from flask_restx import abort +from pydantic import BaseModel, ConfigDict from database_client.database_client import DatabaseClient from database_client.enums import RelationRoleEnum, ColumnPermissionEnum @@ -112,8 +110,9 @@ def get_relation_role(access_info: AccessInfoPrimary) -> RelationRoleEnum: return RelationRoleEnum.STANDARD -@dataclass -class RelationRoleParameters: +class RelationRoleParameters(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + relation_role_function_with_params: DeferredFunction = DeferredFunction( function=get_relation_role, ) diff --git a/middleware/custom_dataclasses.py b/middleware/custom_dataclasses.py index 068cd215..557246e6 100644 --- a/middleware/custom_dataclasses.py +++ b/middleware/custom_dataclasses.py @@ -1,17 +1,18 @@ from dataclasses import dataclass from typing import Optional +from pydantic import BaseModel + from database_client.enums import EntityType, EventType from middleware.enums import CallbackFunctionsEnum -@dataclass -class GithubUserInfo: +class GithubUserInfo(BaseModel): """ Information about a Github user """ - user_id: str + user_id: int user_email: str @@ -25,8 +26,7 @@ class FlaskSessionCallbackInfo: callback_params: dict -@dataclass -class OAuthCallbackInfo: +class OAuthCallbackInfo(BaseModel): """ Contains information returned by OAuth in the callback logic """ diff --git a/middleware/dynamic_request_logic/get_related_resource_logic.py b/middleware/dynamic_request_logic/get_related_resource_logic.py index b733f8d1..c5f42a00 100644 --- a/middleware/dynamic_request_logic/get_related_resource_logic.py +++ b/middleware/dynamic_request_logic/get_related_resource_logic.py @@ -1,6 +1,7 @@ from typing import Optional from flask import Response +from pydantic import BaseModel from database_client.database_client import DatabaseClient from database_client.db_client_dataclasses import WhereMapping diff --git a/middleware/location_logic.py b/middleware/location_logic.py index 07b2f87f..1b9105d0 100644 --- a/middleware/location_logic.py +++ b/middleware/location_logic.py @@ -1,4 +1,3 @@ -from dataclasses import asdict from typing import Union from database_client.database_client import DatabaseClient @@ -19,7 +18,7 @@ def get_location_id( location_info_dict = location_info location_info = LocationInfoDTO(**location_info) else: - location_info_dict = asdict(location_info) + location_info_dict = dict(location_info) location_info_where_mappings = WhereMapping.from_dict(location_info_dict) # Get location id diff --git a/middleware/primary_resource_logic/agencies.py b/middleware/primary_resource_logic/agencies.py index 8787a1b1..1d578386 100644 --- a/middleware/primary_resource_logic/agencies.py +++ b/middleware/primary_resource_logic/agencies.py @@ -96,7 +96,7 @@ def validate_and_add_location_info( def create_agency( db_client: DatabaseClient, dto: AgenciesPostDTO, access_info: AccessInfoPrimary ) -> Response: - entry_data = asdict(dto.agency_info) + entry_data = dict(dto.agency_info) deferred_function = optionally_get_location_info_deferred_function( db_client=db_client, jurisdiction_type=dto.agency_info.jurisdiction_type, diff --git a/middleware/primary_resource_logic/archives_queries.py b/middleware/primary_resource_logic/archives_queries.py index e2b0a625..15b5d780 100644 --- a/middleware/primary_resource_logic/archives_queries.py +++ b/middleware/primary_resource_logic/archives_queries.py @@ -18,13 +18,6 @@ ] -@dataclass -class ArchivesQueryData: - id: Optional[str] = None - broken_source_url_as_of: Optional[str] = None - last_cached: Optional[str] = None - - def archives_get_query( db_client: DatabaseClient, ) -> List[Dict[str, Any]]: diff --git a/middleware/primary_resource_logic/data_requests.py b/middleware/primary_resource_logic/data_requests.py index f26637d4..3a609aea 100644 --- a/middleware/primary_resource_logic/data_requests.py +++ b/middleware/primary_resource_logic/data_requests.py @@ -1,16 +1,13 @@ -from dataclasses import dataclass, asdict from http import HTTPStatus from typing import Optional -from flask import make_response, Response -from marshmallow import fields +from flask import Response from database_client.db_client_dataclasses import WhereMapping, OrderByParameters from database_client.database_client import DatabaseClient from database_client.enums import ( ColumnPermissionEnum, RelationRoleEnum, - RequestUrgency, RequestStatus, ) from database_client.subquery_logic import SubqueryParameterManager, SubqueryParameters @@ -22,7 +19,6 @@ from middleware.custom_dataclasses import ( DeferredFunction, ) -from middleware.dynamic_request_logic.common_functions import check_for_id from middleware.dynamic_request_logic.delete_logic import delete_entry from middleware.dynamic_request_logic.get_by_id_logic import get_by_id from middleware.dynamic_request_logic.get_many_logic import get_many @@ -39,9 +35,7 @@ from middleware.flask_response_manager import FlaskResponseManager from middleware.location_logic import InvalidLocationError from middleware.schema_and_dto_logic.common_schemas_and_dtos import ( - EntryCreateUpdateRequestDTO, GetByIDBaseDTO, - GetByIDBaseSchema, GetManyBaseDTO, ) from middleware.enums import AccessTypeEnum, PermissionsEnum, Relations @@ -51,13 +45,15 @@ created_id_response, ) from middleware.schema_and_dto_logic.primary_resource_dtos.data_requests_dtos import ( - DataRequestLocationInfoPostDTO, GetManyDataRequestsRequestsDTO, DataRequestsPutDTO, DataRequestsPutOuterDTO, + RelatedSourceByIDDTO, + RelatedLocationsByIDDTO, + DataRequestsPostDTO, + DataRequestLocationInfoPostDTO, ) from middleware.util import dataclass_to_filtered_dict -from utilities.enums import SourceMappingEnum RELATION = Relations.DATA_REQUESTS.value RELATED_SOURCES_RELATION = Relations.RELATED_SOURCES.value @@ -70,55 +66,8 @@ def get_data_requests_subquery_params() -> list[SubqueryParameters]: ] -class RelatedSourceByIDSchema(GetByIDBaseSchema): - data_source_id = fields.Str( - required=True, - metadata={ - "description": "The ID of the data source", - "source": SourceMappingEnum.PATH, - }, - ) - - -@dataclass -class RelatedSourceByIDDTO(GetByIDBaseDTO): - data_source_id: int - - def get_where_mapping(self): - return { - "data_source_id": int(self.data_source_id), - "request_id": int(self.resource_id), - } - - -@dataclass -class RelatedLocationsByIDDTO(GetByIDBaseDTO): - location_id: int - - def get_where_mapping(self): - return { - "location_id": int(self.location_id), - "data_request_id": int(self.resource_id), - } - - -@dataclass -class RequestInfoPostDTO: - title: str - submission_notes: str - request_urgency: RequestUrgency - coverage_range: Optional[str] = None - data_requirements: Optional[str] = None - - -@dataclass -class DataRequestsPostDTO: - request_info: RequestInfoPostDTO - location_infos: Optional[list[DataRequestLocationInfoPostDTO]] = None - - def get_location_id_for_data_requests( - db_client: DatabaseClient, location_info: dict + db_client: DatabaseClient, location_info: DataRequestLocationInfoPostDTO ) -> int: """ Get the location id for the data request @@ -128,10 +77,10 @@ def get_location_id_for_data_requests( """ # Rename keys to match where mappings revised_location_info = { - "type": location_info["type"], - "state_name": location_info["state"], - "county_name": location_info["county"], - "locality_name": location_info["locality"], + "type": location_info.type, + "state_name": location_info.state, + "county_name": location_info.county, + "locality_name": location_info.locality, } location_id = db_client.get_location_id( @@ -183,7 +132,7 @@ def create_data_request_wrapper( # Check that location ids are valid, and get location ids for linking location_ids = _get_location_ids(db_client, dto) - column_value_mappings_raw = asdict(dto.request_info) + column_value_mappings_raw = dict(dto.request_info) user_id = access_info.get_user_id() column_value_mappings_raw["creator_user_id"] = user_id diff --git a/middleware/primary_resource_logic/data_sources_logic.py b/middleware/primary_resource_logic/data_sources_logic.py index e6a959e8..75dff3bf 100644 --- a/middleware/primary_resource_logic/data_sources_logic.py +++ b/middleware/primary_resource_logic/data_sources_logic.py @@ -25,7 +25,9 @@ ) from middleware.enums import Relations -from middleware.primary_resource_logic.data_requests import RelatedSourceByIDDTO +from middleware.schema_and_dto_logic.primary_resource_dtos.data_requests_dtos import ( + RelatedSourceByIDDTO, +) from middleware.schema_and_dto_logic.common_schemas_and_dtos import ( GetManyBaseDTO, EntryCreateUpdateRequestDTO, @@ -49,7 +51,6 @@ class DataSourceNotFoundError(Exception): pass -@dataclass class DataSourcesGetManyRequestDTO(GetManyBaseDTO): approval_status: ApprovalStatus = ApprovalStatus.APPROVED page_number: int = 1 diff --git a/middleware/primary_resource_logic/github_oauth_logic.py b/middleware/primary_resource_logic/github_oauth_logic.py index e92682b7..04d4fd8d 100644 --- a/middleware/primary_resource_logic/github_oauth_logic.py +++ b/middleware/primary_resource_logic/github_oauth_logic.py @@ -5,6 +5,7 @@ from flask import Response, make_response from flask_restx import abort from jwt import ExpiredSignatureError +from pydantic import BaseModel from database_client.database_client import DatabaseClient from database_client.enums import ExternalAccountTypeEnum @@ -26,8 +27,7 @@ ) -@dataclass -class LinkToGithubRequestDTO: +class LinkToGithubRequestDTO(BaseModel): gh_access_token: str user_email: str diff --git a/middleware/primary_resource_logic/homepage_search_cache.py b/middleware/primary_resource_logic/homepage_search_cache.py index 25f4cc9d..42e21649 100644 --- a/middleware/primary_resource_logic/homepage_search_cache.py +++ b/middleware/primary_resource_logic/homepage_search_cache.py @@ -2,6 +2,7 @@ from http import HTTPStatus from flask import Response, make_response +from pydantic import BaseModel from database_client.database_client import DatabaseClient from middleware.common_response_formatting import format_list_response, message_response @@ -12,8 +13,7 @@ def get_agencies_without_homepage_urls(database_client: DatabaseClient) -> Respo return make_response(format_list_response(results), 200) -@dataclass -class SearchCacheEntry: +class SearchCacheEntry(BaseModel): agency_airtable_uid: str search_results: list[str] diff --git a/middleware/primary_resource_logic/notifications_logic.py b/middleware/primary_resource_logic/notifications_logic.py index 1cfae896..8821e975 100644 --- a/middleware/primary_resource_logic/notifications_logic.py +++ b/middleware/primary_resource_logic/notifications_logic.py @@ -3,6 +3,7 @@ from http import HTTPStatus from flask import Response +from pydantic import BaseModel from database_client.database_client import DatabaseClient from database_client.enums import EventType @@ -15,8 +16,7 @@ from middleware.third_party_interaction_logic.mailgun_logic import send_via_mailgun -@dataclass -class NotificationEmailContent: +class NotificationEmailContent(BaseModel): html_text: str base_text: str diff --git a/middleware/primary_resource_logic/permissions_logic.py b/middleware/primary_resource_logic/permissions_logic.py index 024c8fc4..dfdcc4a1 100644 --- a/middleware/primary_resource_logic/permissions_logic.py +++ b/middleware/primary_resource_logic/permissions_logic.py @@ -7,12 +7,14 @@ from flask_jwt_extended import get_jwt_identity from flask_restx import abort from marshmallow import Schema, fields +from pydantic import BaseModel from database_client.database_client import DatabaseClient from database_client.helper_functions import get_db_client from middleware.exceptions import UserNotFoundError from middleware.enums import PermissionsEnum, PermissionsActionEnum from middleware.common_response_formatting import message_response +from middleware.schema_and_dto_logic.util import get_query_metadata from utilities.common import get_valid_enum_value from utilities.enums import SourceMappingEnum, ParserLocation @@ -24,7 +26,16 @@ ) -class PermissionsRequestSchema(Schema): +class PermissionsGetRequestSchema(Schema): + user_email = fields.Str( + required=True, + metadata=get_query_metadata( + "The email of the user for which to retrieve permissions." + ), + ) + + +class PermissionsPutRequestSchema(Schema): user_email = fields.Str( required=True, metadata={ @@ -49,8 +60,7 @@ class PermissionsRequestSchema(Schema): ) -@dataclass -class PermissionsRequest: +class PermissionsRequestDTO(BaseModel): user_email: str permission: str action: str @@ -116,7 +126,7 @@ def manage_user_permissions( def update_permissions_wrapper( db_client: DatabaseClient, - dto: PermissionsRequest, + dto: PermissionsRequestDTO, ) -> Response: action = get_valid_enum_value(PermissionsActionEnum, dto.action) permission = get_valid_enum_value(PermissionsEnum, dto.permission) diff --git a/middleware/primary_resource_logic/unique_url_checker.py b/middleware/primary_resource_logic/unique_url_checker.py index c62dcc4b..89ac5f40 100644 --- a/middleware/primary_resource_logic/unique_url_checker.py +++ b/middleware/primary_resource_logic/unique_url_checker.py @@ -3,6 +3,7 @@ from flask import Response from marshmallow import Schema, fields, validate +from pydantic import BaseModel from database_client.database_client import DatabaseClient from database_client.enums import ApprovalStatus @@ -32,8 +33,7 @@ class UniqueURLCheckerRequestSchema(Schema): ) -@dataclass -class UniqueURLCheckerRequestDTO: +class UniqueURLCheckerRequestDTO(BaseModel): url: str diff --git a/middleware/primary_resource_logic/user_queries.py b/middleware/primary_resource_logic/user_queries.py index 49084166..cf249e91 100644 --- a/middleware/primary_resource_logic/user_queries.py +++ b/middleware/primary_resource_logic/user_queries.py @@ -1,15 +1,13 @@ -from dataclasses import dataclass from http import HTTPStatus from flask import Response from marshmallow import Schema, fields +from pydantic import BaseModel from werkzeug.security import generate_password_hash -from typing import Dict from database_client.database_client import DatabaseClient from middleware.common_response_formatting import message_response from middleware.exceptions import UserNotFoundError, DuplicateUserError -from middleware.flask_response_manager import FlaskResponseManager from utilities.enums import SourceMappingEnum @@ -30,8 +28,7 @@ class UserRequestSchema(Schema): ) -@dataclass -class UserRequestDTO: +class UserRequestDTO(BaseModel): email: str password: str diff --git a/middleware/schema_and_dto_logic/common_schemas_and_dtos.py b/middleware/schema_and_dto_logic/common_schemas_and_dtos.py index 41a6a940..78b4b383 100644 --- a/middleware/schema_and_dto_logic/common_schemas_and_dtos.py +++ b/middleware/schema_and_dto_logic/common_schemas_and_dtos.py @@ -4,10 +4,10 @@ """ import ast -from dataclasses import dataclass from typing import Optional from marshmallow import Schema, fields, validate, validates_schema, ValidationError +from pydantic import BaseModel from database_client.enums import SortOrder, LocationType from middleware.schema_and_dto_logic.custom_fields import DataField @@ -58,8 +58,7 @@ class GetManyRequestsBaseSchema(Schema): ) -@dataclass -class GetManyBaseDTO: +class GetManyBaseDTO(BaseModel): """ A base data transfer object for GET requests returning a list of objects """ @@ -86,8 +85,7 @@ class GetByIDBaseSchema(Schema): ) -@dataclass -class GetByIDBaseDTO: +class GetByIDBaseDTO(BaseModel): resource_id: str @@ -101,8 +99,7 @@ class EntryDataRequestSchema(Schema): ) -@dataclass -class EntryCreateUpdateRequestDTO: +class EntryCreateUpdateRequestDTO(BaseModel): """ Contains data for creating or updating an entry """ @@ -128,8 +125,7 @@ class TypeaheadQuerySchema(Schema): ) -@dataclass -class TypeaheadDTO: +class TypeaheadDTO(BaseModel): query: str @@ -230,8 +226,7 @@ class LocationInfoExpandedSchema(LocationInfoSchema): ) -@dataclass -class LocationInfoDTO: +class LocationInfoDTO(BaseModel): type: LocationType state_iso: str county_fips: Optional[str] = None @@ -245,6 +240,5 @@ class EmailOnlySchema(Schema): ) -@dataclass -class EmailOnlyDTO: +class EmailOnlyDTO(BaseModel): email: str diff --git a/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_request_content_population.py b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_request_content_population.py index 46175a20..097694d1 100644 --- a/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_request_content_population.py +++ b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_request_content_population.py @@ -1,9 +1,8 @@ -import ast -from dataclasses import dataclass from http import HTTPStatus -from typing import Type, Optional +from typing import Type import marshmallow +from pydantic import BaseModel from middleware.flask_response_manager import FlaskResponseManager from middleware.schema_and_dto_logic.custom_types import SchemaTypes, DTOTypes @@ -14,14 +13,12 @@ from utilities.enums import SourceMappingEnum -@dataclass -class NestedDTOInfo: +class NestedDTOInfo(BaseModel): key: str class_: Type -@dataclass -class SourceDataInfo: +class SourceDataInfo(BaseModel): data: dict nested_dto_info_list: list[NestedDTOInfo] diff --git a/middleware/schema_and_dto_logic/non_dto_dataclasses.py b/middleware/schema_and_dto_logic/non_dto_dataclasses.py index 33b1ba62..e03600d9 100644 --- a/middleware/schema_and_dto_logic/non_dto_dataclasses.py +++ b/middleware/schema_and_dto_logic/non_dto_dataclasses.py @@ -1,9 +1,10 @@ -from dataclasses import dataclass from typing import Type, Optional, Callable from flask_restx import Model from flask_restx.reqparse import RequestParser from marshmallow import Schema +from pydantic import BaseModel, ConfigDict +from dataclasses import dataclass from middleware.schema_and_dto_logic.custom_types import SchemaTypes, DTOTypes from utilities.enums import SourceMappingEnum @@ -11,18 +12,18 @@ @dataclass class FlaskRestxDocInfo: - parser: RequestParser - model: Optional[Model] = None + parser: Optional[RequestParser] + model: Model = None @dataclass class SchemaPopulateParameters: + schema: SchemaTypes dto_class: Type[DTOTypes] -@dataclass -class DTOPopulateParameters: +class DTOPopulateParameters(BaseModel): """ Parameters for the dynamic DTO population function """ diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/agencies_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/agencies_dtos.py index b945fe5d..d5061ea3 100644 --- a/middleware/schema_and_dto_logic/primary_resource_dtos/agencies_dtos.py +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/agencies_dtos.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Optional from middleware.enums import JurisdictionType, AgencyType @@ -6,28 +5,27 @@ LocationInfoDTO, GetByIDBaseDTO, ) +from pydantic import BaseModel -@dataclass -class AgencyInfoPutDTO: - submitted_name: Optional[str] = None - jurisdiction_type: Optional[JurisdictionType] = None +class AgencyInfoPutDTO(BaseModel): + submitted_name: str = None + jurisdiction_type: JurisdictionType = None agency_type: AgencyType = AgencyType.NONE - multi_agency: Optional[bool] = False - no_web_presence: Optional[bool] = False - approved: Optional[bool] = False - homepage_url: Optional[str] = None - lat: Optional[float] = None - lng: Optional[float] = None - defunct_year: Optional[str] = None - zip_code: Optional[str] = None - rejection_reason: Optional[str] = None - last_approval_editor: Optional[str] = None - submitter_contact: Optional[str] = None + multi_agency: bool = False + no_web_presence: bool = False + approved: bool = False + homepage_url: str = None + lat: float = None + lng: float = None + defunct_year: str = None + zip_code: str = None + rejection_reason: str = None + last_approval_editor: str = None + submitter_contact: str = None -@dataclass -class AgencyInfoPostDTO: +class AgencyInfoPostDTO(BaseModel): submitted_name: str jurisdiction_type: JurisdictionType agency_type: AgencyType @@ -44,19 +42,16 @@ class AgencyInfoPostDTO: submitter_contact: Optional[str] = None -@dataclass -class AgenciesPostDTO: +class AgenciesPostDTO(BaseModel): agency_info: AgencyInfoPostDTO location_info: Optional[LocationInfoDTO] = None -@dataclass -class AgenciesPutDTO: +class AgenciesPutDTO(BaseModel): agency_info: AgencyInfoPutDTO location_info: Optional[LocationInfoDTO] = None -@dataclass class RelatedAgencyByIDDTO(GetByIDBaseDTO): agency_id: int diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/data_requests_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/data_requests_dtos.py index cf8d7bdf..3669aa38 100644 --- a/middleware/schema_and_dto_logic/primary_resource_dtos/data_requests_dtos.py +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/data_requests_dtos.py @@ -1,14 +1,17 @@ -from dataclasses import dataclass from typing import Optional +from pydantic import BaseModel + from database_client.db_client_dataclasses import WhereMapping from database_client.enums import LocationType, RequestStatus, RequestUrgency from middleware.enums import RecordType -from middleware.schema_and_dto_logic.common_schemas_and_dtos import GetManyBaseDTO +from middleware.schema_and_dto_logic.common_schemas_and_dtos import ( + GetManyBaseDTO, + GetByIDBaseDTO, +) -@dataclass -class DataRequestLocationInfoPostDTO: +class DataRequestLocationInfoPostDTO(BaseModel): type: LocationType state: str county: Optional[str] = None @@ -26,13 +29,11 @@ def get_where_mappings(self) -> list[WhereMapping]: return WhereMapping.from_dict(d) -@dataclass class GetManyDataRequestsRequestsDTO(GetManyBaseDTO): request_status: Optional[RequestStatus] = None -@dataclass -class DataRequestsPutDTO: +class DataRequestsPutDTO(BaseModel): title: Optional[str] = None submission_notes: Optional[str] = None request_urgency: Optional[RequestUrgency] = None @@ -47,6 +48,38 @@ class DataRequestsPutDTO: pdap_response: Optional[str] = None -@dataclass -class DataRequestsPutOuterDTO: +class DataRequestsPutOuterDTO(BaseModel): entry_data: DataRequestsPutDTO + + +class RelatedSourceByIDDTO(GetByIDBaseDTO): + data_source_id: int + + def get_where_mapping(self): + return { + "data_source_id": int(self.data_source_id), + "request_id": int(self.resource_id), + } + + +class RelatedLocationsByIDDTO(GetByIDBaseDTO): + location_id: int + + def get_where_mapping(self): + return { + "location_id": int(self.location_id), + "data_request_id": int(self.resource_id), + } + + +class RequestInfoPostDTO(BaseModel): + title: str + submission_notes: str + request_urgency: RequestUrgency + coverage_range: Optional[str] = None + data_requirements: Optional[str] = None + + +class DataRequestsPostDTO(BaseModel): + request_info: RequestInfoPostDTO + location_infos: Optional[list[DataRequestLocationInfoPostDTO]] = None diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/data_sources_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/data_sources_dtos.py index 42a97d58..f3955f92 100644 --- a/middleware/schema_and_dto_logic/primary_resource_dtos/data_sources_dtos.py +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/data_sources_dtos.py @@ -2,6 +2,8 @@ from datetime import date from typing import Optional, List +from pydantic import BaseModel + from database_client.enums import ( ApprovalStatus, AgencyAggregation, @@ -13,8 +15,7 @@ from middleware.enums import RecordType -@dataclass -class DataSourceEntryDataPostDTO: +class DataSourceEntryDataPostDTO(BaseModel): submitted_name: str description: Optional[str] = None approval_status: Optional[ApprovalStatus] = None @@ -49,7 +50,6 @@ class DataSourceEntryDataPostDTO: record_type_name: Optional[RecordType] = None -@dataclass -class DataSourcesPostDTO: +class DataSourcesPostDTO(BaseModel): entry_data: DataSourceEntryDataPostDTO linked_agency_ids: Optional[List[int]] = None diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/request_reset_password_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/request_reset_password_dtos.py index d9652273..2a521d27 100644 --- a/middleware/schema_and_dto_logic/primary_resource_dtos/request_reset_password_dtos.py +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/request_reset_password_dtos.py @@ -1,6 +1,7 @@ from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class RequestResetPasswordRequestDTO: + +class RequestResetPasswordRequestDTO(BaseModel): email: str diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/reset_token_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/reset_token_dtos.py index a8377b41..9417b5a7 100644 --- a/middleware/schema_and_dto_logic/primary_resource_dtos/reset_token_dtos.py +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/reset_token_dtos.py @@ -1,12 +1,12 @@ from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class RequestResetPasswordRequestDTO: + +class RequestResetPasswordRequestDTO(BaseModel): email: str token: str -@dataclass -class ResetPasswordDTO: +class ResetPasswordDTO(BaseModel): password: str diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/user_profile_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/user_profile_dtos.py index 2f125a31..bdaf8a6e 100644 --- a/middleware/schema_and_dto_logic/primary_resource_dtos/user_profile_dtos.py +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/user_profile_dtos.py @@ -1,7 +1,8 @@ from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class UserPutDTO: + +class UserPutDTO(BaseModel): old_password: str new_password: str diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/archives_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/archives_schemas.py new file mode 100644 index 00000000..5c2649c5 --- /dev/null +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/archives_schemas.py @@ -0,0 +1,33 @@ +from marshmallow import Schema, fields + +from middleware.schema_and_dto_logic.util import get_json_metadata + + +class ArchivesGetResponseSchema(Schema): + id = fields.String( + required=True, metadata=get_json_metadata("The ID of the archive") + ) + last_cached = fields.DateTime( + required=True, + metadata=get_json_metadata("The last date the archive was cached"), + ) + source_url = fields.String( + required=True, metadata=get_json_metadata("The URL of the archive") + ) + update_frequency = fields.String( + required=True, metadata=get_json_metadata("The archive update frequency") + ) + + +class ArchivesPutRequestSchema(Schema): + id = fields.String( + required=True, metadata=get_json_metadata("The ID of the archive") + ) + last_cached = fields.DateTime( + required=True, + metadata=get_json_metadata("The last date the archive was cached"), + ) + broken_source_url_as_of = fields.Date( + required=True, + metadata=get_json_metadata("The date the source was marked as broken"), + ) diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/auth_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/auth_schemas.py index 7f32c54b..dd0a3727 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/auth_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/auth_schemas.py @@ -2,6 +2,7 @@ from typing import Optional from marshmallow import Schema, fields +from pydantic import BaseModel from middleware.schema_and_dto_logic.common_response_schemas import MessageSchema from middleware.schema_and_dto_logic.util import get_json_metadata, get_query_metadata @@ -22,8 +23,7 @@ class GithubRequestSchema(Schema): ) -@dataclass -class LoginWithGithubRequestDTO: +class LoginWithGithubRequestDTO(BaseModel): gh_access_token: str @@ -44,6 +44,5 @@ class GithubOAuthRequestSchema(Schema): ) -@dataclass -class GithubOAuthRequestDTO: +class GithubOAuthRequestDTO(BaseModel): redirect_url: Optional[str] = None diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/data_requests_advanced_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/data_requests_advanced_schemas.py index 66da7ac7..9b9e44b8 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/data_requests_advanced_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/data_requests_advanced_schemas.py @@ -1,7 +1,6 @@ from marshmallow import fields, Schema, post_load from database_client.enums import RequestStatus -from middleware.primary_resource_logic.data_requests import RequestInfoPostDTO from middleware.schema_and_dto_logic.primary_resource_schemas.data_requests_base_schema import ( DataRequestsSchema, ) @@ -17,6 +16,7 @@ from middleware.schema_and_dto_logic.primary_resource_dtos.data_requests_dtos import ( DataRequestLocationInfoPostDTO, DataRequestsPutDTO, + RequestInfoPostDTO, ) from middleware.schema_and_dto_logic.primary_resource_schemas.data_sources_base_schemas import ( DataSourceExpandedSchema, @@ -30,6 +30,7 @@ get_query_metadata, get_path_metadata, ) +from utilities.enums import SourceMappingEnum class DataRequestsGetSchemaBase(DataRequestsSchema): @@ -168,3 +169,13 @@ class DataRequestsRelatedLocationAddRemoveSchema(GetByIDBaseSchema): "The ID of the location to add or remove from the data request." ), ) + + +class RelatedSourceByIDSchema(GetByIDBaseSchema): + data_source_id = fields.Str( + required=True, + metadata={ + "description": "The ID of the data source", + "source": SourceMappingEnum.PATH, + }, + ) diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/github_issue_app_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/github_issue_app_schemas.py index c67b251d..b3bd201b 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/github_issue_app_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/github_issue_app_schemas.py @@ -1,4 +1,5 @@ from marshmallow import Schema, fields +from pydantic import BaseModel from utilities.enums import SourceMappingEnum from dataclasses import dataclass @@ -31,6 +32,5 @@ class GithubDataRequestsIssuesPostResponseSchema(Schema): ) -@dataclass -class GithubDataRequestsIssuesPostDTO: +class GithubDataRequestsIssuesPostDTO(BaseModel): data_request_id: int diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/refresh_session_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/refresh_session_schemas.py index 99603b34..aba72c0e 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/refresh_session_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/refresh_session_schemas.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from marshmallow import fields, Schema +from pydantic import BaseModel from middleware.schema_and_dto_logic.util import get_json_metadata @@ -12,6 +13,5 @@ class RefreshSessionRequestSchema(Schema): ) -@dataclass -class RefreshSessionRequestDTO: +class RefreshSessionRequestDTO(BaseModel): refresh_token: str diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py index c8c22337..da34f316 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py @@ -2,6 +2,7 @@ from typing import Optional from marshmallow import Schema, fields, validates_schema, ValidationError +from pydantic import BaseModel from middleware.schema_and_dto_logic.schema_helpers import create_get_many_schema from middleware.schema_and_dto_logic.util import get_json_metadata @@ -197,8 +198,7 @@ class FollowSearchResponseSchema(Schema): ) -@dataclass -class SearchRequests: +class SearchRequests(BaseModel): state: str record_categories: Optional[list[RecordCategories]] = None county: Optional[str] = None diff --git a/middleware/third_party_interaction_logic/github_issue_api_logic.py b/middleware/third_party_interaction_logic/github_issue_api_logic.py index f7dd3cc2..a035de5b 100644 --- a/middleware/third_party_interaction_logic/github_issue_api_logic.py +++ b/middleware/third_party_interaction_logic/github_issue_api_logic.py @@ -3,14 +3,14 @@ from github import Github from github import Auth import requests +from pydantic import BaseModel from database_client.enums import RequestStatus from middleware.util import get_env_variable from dataclasses import dataclass -@dataclass -class GithubIssueInfo: +class GithubIssueInfo(BaseModel): url: str number: int diff --git a/middleware/util.py b/middleware/util.py index 7ef616f7..60bf34ca 100644 --- a/middleware/util.py +++ b/middleware/util.py @@ -4,6 +4,7 @@ from typing import Any, Dict from dotenv import dotenv_values, find_dotenv +from pydantic import BaseModel def get_env_variable(name: str) -> str: @@ -37,12 +38,16 @@ def dataclass_to_filtered_dict(instance: Any) -> Dict[str, Any]: :param instance: :return: """ - if not is_dataclass(instance): + if is_dataclass(instance): + d = asdict(instance) + if isinstance(instance, BaseModel): + d = dict(instance) + else: raise TypeError( - f"Expected a dataclass instance, but got {type(instance).__name__}" + f"Expected a dataclass or basemodel instance, but got {type(instance).__name__}" ) results = {} - for key, value in asdict(instance).items(): + for key, value in d.items(): # Special case for Enum if isinstance(value, Enum): results[key] = value.value diff --git a/requirements.txt b/requirements.txt index b16c7f19..b063c953 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aniso8601==9.0.1 -annotated-types==0.5.0 +annotated-types==0.7.0 anyio==3.7.1 attrs==23.1.0 blinker==1.6.2 @@ -40,8 +40,8 @@ preshed==3.0.8 psycopg~=3.2 py==1.11.0 pycparser==2.21 -pydantic==2.2.1 -pydantic_core==2.6.1 +pydantic==2.10.2 +pydantic_core==2.27.1 pytest==8.3.3 python-dateutil==2.8.2 python-dotenv==1.0.0 @@ -58,7 +58,7 @@ storage3==0.5.4 StrEnum==0.4.15 toml==0.10.2 typer==0.9.0 -typing_extensions==4.7.1 +typing_extensions==4.12.2 urllib3==1.26.18 wasabi==1.1.2 websockets==10.4 diff --git a/resources/ApiKeyResource.py b/resources/ApiKeyResource.py index 0bd92d9b..80413c48 100644 --- a/resources/ApiKeyResource.py +++ b/resources/ApiKeyResource.py @@ -1,3 +1,5 @@ +from http import HTTPStatus + from flask import Response from middleware.access_logic import NO_AUTH_INFO, AccessInfoPrimary @@ -22,7 +24,13 @@ class ApiKeyResource(PsycopgResource): namespace=namespace_api_key, auth_info=NO_AUTH_INFO, description="Generates an API key for authenticated users.", - response_info=ResponseInfo(success_message="API Key generated."), + response_info=ResponseInfo( + response_dictionary={ + HTTPStatus.OK.value: "OK. API key generated.", + HTTPStatus.UNAUTHORIZED.value: "Unauthorized. Forbidden or invalid authentication.", + HTTPStatus.INTERNAL_SERVER_ERROR.value: "Internal server error.", + } + ), schema_config=SchemaConfigs.API_KEY_POST, ) def post(self, access_info: AccessInfoPrimary) -> Response: diff --git a/resources/Archives.py b/resources/Archives.py index dc18fefb..cefc1064 100644 --- a/resources/Archives.py +++ b/resources/Archives.py @@ -1,7 +1,12 @@ from flask import Response, request from flask_restx import fields -from middleware.decorators import api_key_required, permissions_required +from middleware.access_logic import ( + GET_AUTH_INFO, + AccessInfoPrimary, + WRITE_ONLY_AUTH_INFO, +) +from middleware.decorators import api_key_required, permissions_required, endpoint_info from middleware.primary_resource_logic.archives_queries import ( archives_get_query, update_archives_data, @@ -11,24 +16,17 @@ from typing import Any from middleware.enums import PermissionsEnum -from resources.resource_helpers import add_api_key_header_arg, add_jwt_header_arg +from resources.endpoint_schema_config import SchemaConfigs +from resources.resource_helpers import ( + add_api_key_header_arg, + add_jwt_header_arg, + ResponseInfo, +) from utilities.namespace import create_namespace from resources.PsycopgResource import PsycopgResource, handle_exceptions namespace_archives = create_namespace() -archives_get_model = namespace_archives.model( - "ArchivesResponse", - { - "id": fields.String(description="The ID of the data source"), - "last_cached": fields.DateTime(description="The last date the data was cached"), - "source_url": fields.String(description="The URL of the data source"), - "update_frequency": fields.String( - description="The archive update frequency of the data source" - ), - }, -) - archives_post_model = namespace_archives.model( "ArchivesPost", { @@ -53,20 +51,16 @@ class Archives(PsycopgResource): A resource for managing archive data, allowing retrieval and update of archived data sources. """ - @handle_exceptions - @api_key_required - @namespace_archives.response( - 200, "Success: Returns a list of archived data sources", archives_get_model - ) - @namespace_archives.response(400, "Error: Bad request missing or bad API key") - @namespace_archives.response( - 403, "Error: Unauthorized. Forbidden or an invalid API key" - ) - @namespace_archives.doc( + @endpoint_info( + namespace=namespace_archives, + auth_info=GET_AUTH_INFO, + schema_config=SchemaConfigs.ARCHIVES_GET, + response_info=ResponseInfo( + success_message="Returns a list of archived data sources.", + ), description="Retrieves archived data sources from the database.", ) - @namespace_archives.expect(archives_header_get_parser) - def get(self) -> Any: + def get(self, access_info: AccessInfoPrimary) -> Any: """ Retrieves archived data sources from the database. @@ -82,19 +76,21 @@ def get(self) -> Any: return archives_combined_results_clean - @handle_exceptions - @permissions_required(PermissionsEnum.DB_WRITE) - @namespace_archives.doc( - description="Updates the archive data based on the provided JSON payload.", - responses={ - 200: "Success: Returns a status message indicating success or an error message if an exception occurs.", - 400: "Error: Bad request missing or bad API key", - 403: "Error: Unauthorized. Forbidden or an invalid API key", - 500: "Error: Internal server error", - }, + @endpoint_info( + namespace=namespace_archives, + auth_info=WRITE_ONLY_AUTH_INFO, + schema_config=SchemaConfigs.ARCHIVES_PUT, + response_info=ResponseInfo( + success_message="Successfully updated the archive data.", + ), + description=""" + Updates the archive data based on the provided JSON payload. + Note that, for this endpoint only, the schema must be provided first as a json string, + rather than as a typical JSON object. + This will be changed in a later update to conform to the standard JSON schema. + """, ) - @namespace_archives.expect(archives_header_post_parser, archives_post_model) - def put(self) -> Response: + def put(self, access_info: AccessInfoPrimary) -> Response: """ Updates the archive data based on the provided JSON payload. diff --git a/resources/LinkToGithub.py b/resources/LinkToGithub.py index 8640d244..18e0cd73 100644 --- a/resources/LinkToGithub.py +++ b/resources/LinkToGithub.py @@ -4,25 +4,14 @@ from database_client.database_client import DatabaseClient from middleware.access_logic import NO_AUTH_INFO, AccessInfoPrimary from middleware.decorators import endpoint_info -from middleware.schema_and_dto_logic.dynamic_logic.dynamic_schema_request_content_population import ( - populate_schema_with_request_content, -) from middleware.primary_resource_logic.github_oauth_logic import ( - LinkToGithubRequestDTO, link_github_account_request_wrapper, ) from middleware.schema_and_dto_logic.dynamic_logic.dynamic_schema_request_content_population import ( populate_schema_with_request_content, ) -from middleware.third_party_interaction_logic.callback_flask_sessions_logic import ( - setup_callback_session, -) -from middleware.enums import CallbackFunctionsEnum -from middleware.third_party_interaction_logic.callback_oauth_logic import ( - redirect_to_github_authorization, -) from resources.PsycopgResource import PsycopgResource from resources.endpoint_schema_config import SchemaConfigs from resources.resource_helpers import ResponseInfo @@ -41,7 +30,6 @@ class LinkToGithub(PsycopgResource): response_info=ResponseInfo( response_dictionary={ HTTPStatus.OK.value: "Accounts linked.", - HTTPStatus.BAD_REQUEST.value: "Bad request. Provided email doesn't have associated PDAP account or GitHub acccount doesn't match associated PDAP account.", HTTPStatus.UNAUTHORIZED.value: "Unauthorized. Forbidden or invalid authentication.", HTTPStatus.INTERNAL_SERVER_ERROR.value: "Internal Server Error.", }, diff --git a/resources/LoginWithGithub.py b/resources/LoginWithGithub.py index 8e5cd7aa..defd8404 100644 --- a/resources/LoginWithGithub.py +++ b/resources/LoginWithGithub.py @@ -25,7 +25,6 @@ class LoginWithGithub(PsycopgResource): response_info=ResponseInfo( response_dictionary={ HTTPStatus.OK.value: "User logged in.", - HTTPStatus.BAD_REQUEST.value: "Bad request.", HTTPStatus.UNAUTHORIZED.value: "Unauthorized. Forbidden or invalid authentication.", HTTPStatus.INTERNAL_SERVER_ERROR: "Internal Server Error.", }, diff --git a/resources/Permissions.py b/resources/Permissions.py index db559f78..37bf32b6 100644 --- a/resources/Permissions.py +++ b/resources/Permissions.py @@ -2,16 +2,22 @@ from flask import Response, request -from middleware.decorators import permissions_required -from middleware.enums import PermissionsEnum +from middleware.access_logic import ( + AuthenticationInfo, + AccessInfoPrimary, + WRITE_ONLY_AUTH_INFO, +) +from middleware.decorators import permissions_required, endpoint_info +from middleware.enums import PermissionsEnum, AccessTypeEnum from middleware.primary_resource_logic.permissions_logic import ( manage_user_permissions, update_permissions_wrapper, - PermissionsRequest, - PermissionsRequestSchema, + PermissionsRequestDTO, + PermissionsPutRequestSchema, ) from resources.PsycopgResource import handle_exceptions, PsycopgResource -from resources.resource_helpers import add_jwt_header_arg +from resources.endpoint_schema_config import SchemaConfigs +from resources.resource_helpers import add_jwt_header_arg, ResponseInfo from utilities.namespace import AppNamespaces, create_namespace from middleware.schema_and_dto_logic.dynamic_logic.dynamic_schema_documentation_construction import ( get_restx_param_documentation, @@ -22,7 +28,7 @@ doc_info = get_restx_param_documentation( namespace=namespace_permissions, - schema=PermissionsRequestSchema, + schema=PermissionsPutRequestSchema, model_name="PermissionRequest", ) @@ -37,16 +43,17 @@ class Permissions(PsycopgResource): Provides a resource for retrieving permissions for a user. """ - @namespace_permissions.expect(all_routes_parser) - @namespace_permissions.response(HTTPStatus.OK.value, "OK; Permissions retrieved") - @namespace_permissions.response( - HTTPStatus.INTERNAL_SERVER_ERROR.value, "Internal server error" - ) - @namespace_permissions.doc( + @endpoint_info( + namespace=namespace_permissions, + auth_info=AuthenticationInfo( + allowed_access_methods=[AccessTypeEnum.JWT], + restrict_to_permissions=[PermissionsEnum.READ_ALL_USER_INFO], + ), + response_info=ResponseInfo(success_message="Returns a user's permissions."), description="Retrieves a user's permissions.", + schema_config=SchemaConfigs.PERMISSIONS_GET, ) - @permissions_required(PermissionsEnum.READ_ALL_USER_INFO) - def get(self) -> Response: + def get(self, access_info: AccessInfoPrimary) -> Response: """ Retrieves a user's permissions. :return: @@ -57,24 +64,21 @@ def get(self) -> Response: method="get_user_permissions", ) - @handle_exceptions - @namespace_permissions.expect(all_routes_parser, permission_model) - @namespace_permissions.response(HTTPStatus.OK.value, "OK; Permissions added") - @namespace_permissions.response( - HTTPStatus.INTERNAL_SERVER_ERROR.value, "Internal server error" - ) - @namespace_permissions.response( - HTTPStatus.CONFLICT.value, - "Permission already exists for user (if adding) or does not exist for user (if removing)", - ) - @namespace_permissions.response( - HTTPStatus.BAD_REQUEST.value, "Missing or bad query parameter" - ) - @namespace_permissions.doc( + @endpoint_info( + namespace=namespace_permissions, + auth_info=WRITE_ONLY_AUTH_INFO, + schema_config=SchemaConfigs.PERMISSIONS_PUT, + response_info=ResponseInfo( + response_dictionary={ + HTTPStatus.OK.value: "OK; Permissions added", + HTTPStatus.INTERNAL_SERVER_ERROR.value: "Internal server error", + HTTPStatus.CONFLICT.value: "Permission already exists for user (if adding) or does not exist for user (if removing)", + HTTPStatus.BAD_REQUEST.value: "Missing or bad query parameter", + } + ), description="Adds or removes a permission for a user.", ) - @permissions_required(PermissionsEnum.DB_WRITE) - def put(self) -> Response: + def put(self, access_info: AccessInfoPrimary) -> Response: """ Adds or removes a permission for a user. :return: @@ -82,7 +86,7 @@ def put(self) -> Response: return self.run_endpoint( wrapper_function=update_permissions_wrapper, schema_populate_parameters=SchemaPopulateParameters( - schema=PermissionsRequestSchema(), - dto_class=PermissionsRequest, + schema=PermissionsPutRequestSchema(), + dto_class=PermissionsRequestDTO, ), ) diff --git a/resources/endpoint_schema_config.py b/resources/endpoint_schema_config.py index d9335122..e6932460 100644 --- a/resources/endpoint_schema_config.py +++ b/resources/endpoint_schema_config.py @@ -7,15 +7,18 @@ from middleware.primary_resource_logic.github_oauth_logic import ( LinkToGithubRequestDTO, ) -from middleware.primary_resource_logic.data_requests import ( - RelatedSourceByIDSchema, - RelatedSourceByIDDTO, - DataRequestsPostDTO, - RelatedLocationsByIDDTO, +from middleware.primary_resource_logic.permissions_logic import ( + PermissionsPutRequestSchema, + PermissionsRequestDTO, + PermissionsGetRequestSchema, ) from middleware.schema_and_dto_logic.primary_resource_dtos.reset_token_dtos import ( ResetPasswordDTO, ) +from middleware.schema_and_dto_logic.primary_resource_schemas.archives_schemas import ( + ArchivesGetResponseSchema, + ArchivesPutRequestSchema, +) from middleware.schema_and_dto_logic.primary_resource_schemas.reset_token_schemas import ( ResetPasswordSchema, ) @@ -37,6 +40,9 @@ ) from middleware.schema_and_dto_logic.primary_resource_dtos.data_requests_dtos import ( DataRequestsPutOuterDTO, + RelatedSourceByIDDTO, + RelatedLocationsByIDDTO, + DataRequestsPostDTO, ) from middleware.schema_and_dto_logic.primary_resource_schemas.typeahead_suggestion_schemas import ( @@ -105,6 +111,7 @@ DataRequestsRelatedLocationAddRemoveSchema, GetManyDataRequestsRelatedLocationsSchema, DataRequestsPutSchema, + RelatedSourceByIDSchema, ) from middleware.schema_and_dto_logic.primary_resource_schemas.data_sources_advanced_schemas import ( @@ -452,3 +459,20 @@ class SchemaConfigs(Enum): primary_output_schema=APIKeyResponseSchema(), ) # endregion + + # region Archives + ARCHIVES_GET = EndpointSchemaConfig( + primary_output_schema=ArchivesGetResponseSchema(), + ) + ARCHIVES_PUT = EndpointSchemaConfig( + input_schema=ArchivesPutRequestSchema(), + ) + # endregion + + # region Permission + PERMISSIONS_GET = EndpointSchemaConfig(input_schema=PermissionsGetRequestSchema()) + PERMISSIONS_PUT = EndpointSchemaConfig( + input_schema=PermissionsPutRequestSchema(), + input_dto_class=PermissionsRequestDTO, + ) + # endregion diff --git a/tests/helper_scripts/helper_classes/EndpointCaller.py b/tests/helper_scripts/helper_classes/EndpointCaller.py deleted file mode 100644 index 6f31308c..00000000 --- a/tests/helper_scripts/helper_classes/EndpointCaller.py +++ /dev/null @@ -1,41 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -from tests.helper_scripts.run_and_validate_request import run_and_validate_request - - -@dataclass -class EndpointCallInfo: - expected_json_content: Optional[dict] = None - authorization_header: Optional[dict] = None - query_parameters: Optional[dict] = None - - -class EndpointCaller: - """ - Provides a simplified interface for calling endpoints - """ - - def __init__(self, flask_client): - self.flask_client = flask_client - - def notifications_get_endpoint(self, eci: EndpointCallInfo): - return run_and_validate_request( - flask_client=self.flask_client, - http_method="get", - endpoint=NOTIFICATIONS_ENDPOINT, - headers=eci.authorization_header, - expected_json_content=eci.expected_json_content, - expected_schema=SchemaConfigs.NOTIFICATIONS_GET.value.output_schema, - query_params=eci.query_parameters, - ) - - def follow_location(self, eci: EndpointCallInfo): - return run_and_validate_request( - flask_client=self.flask_client, - http_method="post", - endpoint=USER_FOLLOW_LOCATION_ENDPOINT, - headers=eci.authorization_header, - expected_json_content=eci.expected_json_content, - expected_schema=SchemaConfigs.USER_FOLLOW_LOCATION.value.output_schema, - ) diff --git a/tests/helper_scripts/helper_classes/RequestValidator.py b/tests/helper_scripts/helper_classes/RequestValidator.py index 19c87750..739069ef 100644 --- a/tests/helper_scripts/helper_classes/RequestValidator.py +++ b/tests/helper_scripts/helper_classes/RequestValidator.py @@ -452,3 +452,10 @@ def update_password( json={"old_password": old_password, "new_password": new_password}, expected_response_status=expected_response_status, ) + + def get_api_spec( + self, + ): + return self.get( + endpoint="/api/swagger.json", + ) diff --git a/tests/helper_scripts/helper_classes/TestDataCreatorFlask.py b/tests/helper_scripts/helper_classes/TestDataCreatorFlask.py index a9a83536..f25edda0 100644 --- a/tests/helper_scripts/helper_classes/TestDataCreatorFlask.py +++ b/tests/helper_scripts/helper_classes/TestDataCreatorFlask.py @@ -20,7 +20,6 @@ DATA_SOURCES_POST_DELETE_RELATED_AGENCY_ENDPOINT, DATA_REQUESTS_POST_DELETE_RELATED_SOURCE_ENDPOINT, ) -from tests.helper_scripts.helper_classes.EndpointCaller import EndpointCaller from tests.helper_scripts.helper_classes.RequestValidator import RequestValidator from tests.helper_scripts.helper_classes.TestDataCreatorDBClient import ( TestDataCreatorDBClient, @@ -43,7 +42,6 @@ def __init__(self, flask_client: FlaskClient): self.flask_client = flask_client self.tdcdb = TestDataCreatorDBClient() self.request_validator = RequestValidator(flask_client) - self.endpoint_caller = EndpointCaller(flask_client) self.db_client = DatabaseClient() self.admin_tus: Optional[TestUserSetup] = None diff --git a/tests/helper_scripts/helper_classes/UserInfo.py b/tests/helper_scripts/helper_classes/UserInfo.py index 02d44274..adb497ec 100644 --- a/tests/helper_scripts/helper_classes/UserInfo.py +++ b/tests/helper_scripts/helper_classes/UserInfo.py @@ -1,9 +1,7 @@ -from dataclasses import dataclass -from typing import Optional +from pydantic import BaseModel -@dataclass -class UserInfo: +class UserInfo(BaseModel): email: str password: str - user_id: Optional[str] = None + user_id: int = None diff --git a/tests/helper_scripts/helper_functions.py b/tests/helper_scripts/helper_functions.py index d4bcf4ce..562b0f1c 100644 --- a/tests/helper_scripts/helper_functions.py +++ b/tests/helper_scripts/helper_functions.py @@ -47,7 +47,7 @@ def create_test_user_db_client(db_client: DatabaseClient) -> UserInfo: password = get_test_name() password_digest = generate_password_hash(password) user_id = db_client.create_new_user(email=email, password_digest=password_digest) - return UserInfo(email, password, user_id) + return UserInfo(email=email, password=password, user_id=user_id) JWTTokens = namedtuple("JWTTokens", ["access_token", "refresh_token"]) diff --git a/tests/integration/test_check_api_key.py b/tests/integration/test_check_api_key.py index 211fda89..2a791075 100644 --- a/tests/integration/test_check_api_key.py +++ b/tests/integration/test_check_api_key.py @@ -64,38 +64,6 @@ def test_check_api_key_happy_path( check_api_key() -@pytest.mark.parametrize( - "request_headers", - [ - None, - {"Authorization": "Basic"}, - {"Authorization": "BasicAPIKEY"}, - {"Authorization": "Bearer api_key"}, - {"Authrztn": "Basic api_key"}, - ], -) -def test_check_api_key_valid_authorization_header( - monkeypatch, mock_abort, request_headers -): - """ - Test various scenarios where an Invalid API Key response is expected - :param monkeypatch: - :param mock_abort: - :param request_headers: - :return: - """ - patch_request_headers( - monkeypatch, - path=PATCH_REQUESTS_ROOT, - request_headers=request_headers, - ) - - check_api_key() - mock_abort.assert_called_once_with( - code=HTTPStatus.UNAUTHORIZED, message=INVALID_API_KEY_MESSAGE - ) - - # def test_check_api_key_api_key_not_associated_with_user(monkeypatch, mock_abort): diff --git a/tests/integration/test_data_requests.py b/tests/integration/test_data_requests.py index 5cd010ac..fc5a031e 100644 --- a/tests/integration/test_data_requests.py +++ b/tests/integration/test_data_requests.py @@ -215,7 +215,7 @@ def get_data_request(data_request_id: int): post_data_request( json_request, use_authorization_header=False, - expected_response_status=HTTPStatus.UNAUTHORIZED, + expected_response_status=HTTPStatus.BAD_REQUEST, ) # Check that response fails if using invalid columns diff --git a/tests/middleware/test_access_logic.py b/tests/middleware/test_access_logic.py index 182b28bf..f3602beb 100644 --- a/tests/middleware/test_access_logic.py +++ b/tests/middleware/test_access_logic.py @@ -31,44 +31,10 @@ def test_get_authorization_header_from_request_happy_path(monkeypatch): assert "Basic api_key" == get_authorization_header_from_request() -@pytest.mark.parametrize( - "request_headers", - [ - {}, - {"Authrztn": "Basic api_key"}, - ], -) -def test_get_authorization_header_from_request_invalid_authorization_header( - monkeypatch, request_headers -): - patch_request_headers( - monkeypatch, - path="middleware.access_logic", - request_headers=request_headers, - ) - with pytest.raises(InvalidAuthorizationHeaderException): - get_authorization_header_from_request() - - def test_get_api_key_from_authorization_header_happy_path(monkeypatch): assert "api_key" == get_key_from_authorization_header("Basic api_key") -@pytest.mark.parametrize( - "authorization_header", - [ - None, - "Basic", - "Bearer api_key", - ], -) -def test_get_api_key_from_authorization_header_invalid_authorization_header( - monkeypatch, authorization_header -): - with pytest.raises(InvalidAPIKeyException): - get_key_from_authorization_header(authorization_header) - - class GetAPIKeyFromRequestHeaderMock(DynamicMagicMock): get_authorization_header_from_request: MagicMock get_api_key_from_authorization_header: MagicMock diff --git a/tests/middleware/test_dynamic_request_logic.py b/tests/middleware/test_dynamic_request_logic.py index e875c833..9831bd78 100644 --- a/tests/middleware/test_dynamic_request_logic.py +++ b/tests/middleware/test_dynamic_request_logic.py @@ -24,7 +24,9 @@ from middleware.dynamic_request_logic.post_logic import post_entry from middleware.dynamic_request_logic.put_logic import put_entry from middleware.dynamic_request_logic.supporting_classes import IDInfo -from middleware.primary_resource_logic.data_requests import RelatedSourceByIDDTO +from middleware.schema_and_dto_logic.primary_resource_dtos.data_requests_dtos import ( + RelatedSourceByIDDTO, +) from middleware.schema_and_dto_logic.common_response_schemas import ( EntryDataResponseSchema, @@ -301,35 +303,6 @@ def test_delete_entry(monkeypatch): assert result == mock.message_response.return_value -def test_delete_id_not_found(monkeypatch, mock_flask_response_manager): - mock = MagicMock() - mock.flask_response_manager = mock_flask_response_manager - monkeypatch.setattr(mock, f"{PATCH_ROOT}.DatabaseClient", mock.DatabaseClient) - mock.dto = RelatedSourceByIDDTO( - resource_id=mock.resource_id, - data_source_id=mock.data_source_id, - ) - mock.access_info = MagicMock() - mock.db_client._select_from_relation.return_value = [] - with pytest.raises(FakeAbort): - delete_entry( - middleware_parameters=mock.mp, - id_info=mock.id_info, - permission_checking_function=mock.permission_checking_function, - ) - mock.mp.db_client._select_from_relation.assert_called_once_with( - relation_name=mock.mp.relation, - where_mappings=mock.id_info.where_mappings, - columns=[mock.id_info.id_column_name], - ) - mock.db_client.get_user_id.assert_not_called() - mock.user_is_creator_of_data_request.assert_not_called() - mock.flask_response_manager.abort.assert_called_once_with( - message=f"Entry for {mock.id_info.where_mappings} not found.", - code=HTTPStatus.NOT_FOUND, - ) - - def test_check_requested_columns_happy_path(mock_flask_response_manager, monkeypatch): mock = MagicMock() mock.get_invalid_columns.return_value = [] diff --git a/tests/resources/conftest.py b/tests/resources/conftest.py index 500fef58..acd35a45 100644 --- a/tests/resources/conftest.py +++ b/tests/resources/conftest.py @@ -3,6 +3,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch +from pydantic import BaseModel from tests.conftest import ( ClientWithMockDB, diff --git a/tests/resources/test_DataRequests.py b/tests/resources/test_DataRequests.py index a0d65add..784039e3 100644 --- a/tests/resources/test_DataRequests.py +++ b/tests/resources/test_DataRequests.py @@ -3,9 +3,9 @@ import pytest from database_client.enums import RequestUrgency -from middleware.primary_resource_logic.data_requests import ( - DataRequestsPostDTO, +from middleware.schema_and_dto_logic.primary_resource_dtos.data_requests_dtos import ( RequestInfoPostDTO, + DataRequestsPostDTO, ) from tests.conftest import client_with_mock_db, bypass_authentication_required from tests.helper_scripts.common_mocks_and_patches import ( diff --git a/tests/test_database.py b/tests/test_database.py index 8683569a..2153df59 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -403,6 +403,8 @@ def update_data_source(column_edit_mappings): def test_qualifying_notifications_view( test_data_creator_db_client: TestDataCreatorDBClient, ): + # Note: Based on testing on 11/30/2024, this test may be wonky on the last day of the month + tdc = test_data_creator_db_client tdc.clear_test_data() old_date = datetime.now() - timedelta(days=60) @@ -600,6 +602,8 @@ def is_dependent_location(dependent_location_id: int, parent_location_id: int): def test_user_pending_notifications_view( test_data_creator_db_client: TestDataCreatorDBClient, ): + # Note: Based on testing on 11/30/2024, this test may be wonky on the last day of the month + notification_valid_date = get_notification_valid_date() tdc = test_data_creator_db_client diff --git a/tests/test_database_client.py b/tests/test_database_client.py index e8384c0e..1f4e4d57 100644 --- a/tests/test_database_client.py +++ b/tests/test_database_client.py @@ -1174,6 +1174,8 @@ def get_user_notification_queue(db_client: DatabaseClient): def test_optionally_update_user_notification_queue(test_data_creator_db_client): + # Note: Based on testing on 11/30/2024, this test may be wonky on the last day of the month + tdc = test_data_creator_db_client tdc.clear_test_data() diff --git a/tests/utilities/test_populate_dto_with_request_content.py b/tests/utilities/test_populate_dto_with_request_content.py index b2ad7ebc..e1fc82b5 100644 --- a/tests/utilities/test_populate_dto_with_request_content.py +++ b/tests/utilities/test_populate_dto_with_request_content.py @@ -7,6 +7,7 @@ from flask_restx import Namespace, Model from flask_restx.reqparse import RequestParser from marshmallow import Schema, fields, validate, ValidationError +from pydantic import BaseModel from tests.helper_scripts.common_mocks_and_patches import patch_request_args_get from middleware.schema_and_dto_logic.dynamic_logic.dynamic_schema_documentation_construction import ( @@ -211,8 +212,7 @@ class ExampleSchemaWithEnum(Schema): ) -@dataclass -class ExampleDTOWithEnum: +class ExampleDTOWithEnum(BaseModel): example_enum: str @@ -253,8 +253,7 @@ class ExampleNestedSchema(Schema): ) -@dataclass -class ExampleNestedDTO: +class ExampleNestedDTO(BaseModel): example_dto_with_enum: ExampleDTOWithEnum example_dto: ExampleDTO diff --git a/tests_comprehensive/__init__.py b/tests_comprehensive/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests_comprehensive/helper_scripts/SpecManager.py b/tests_comprehensive/helper_scripts/SpecManager.py new file mode 100644 index 00000000..3386b328 --- /dev/null +++ b/tests_comprehensive/helper_scripts/SpecManager.py @@ -0,0 +1,102 @@ +from enum import Enum +from http import HTTPStatus +from typing import Optional + +from pydantic import BaseModel + + +class HTTPMethod(Enum): + GET = "get" + POST = "post" + PUT = "put" + PATCH = "patch" + DELETE = "delete" + + +ALL_METHODS_STR = ["get", "post", "put", "delete", "patch"] +ALL_METHODS_ENUM = [ + HTTPMethod.GET, + HTTPMethod.POST, + HTTPMethod.PUT, + HTTPMethod.DELETE, + HTTPMethod.PATCH, +] + + +class MethodInfo(BaseModel): + parent_path: "PathInfo" + method: HTTPMethod + content: dict + + def get_response_codes(self): + for key, value in self.content["responses"].items(): + yield HTTPStatus(int(key)) + + def has_response(self, response_code: HTTPStatus): + response_codes = list(self.get_response_codes()) + return response_code in response_codes + + def has_any_authorization_header(self): + if "parameters" not in self.content: + return False + parameters = self.content["parameters"] + for parameter in parameters: + if parameter["in"] == "header" and parameter["name"] == "Authorization": + return True + return False + + def pathname(self) -> str: + return self.parent_path.route_name + + def __str__(self): + return f"{self.method.name} {self.pathname()}" + + +class PathInfo(BaseModel): + route_name: str + content: dict + + def get_allowed_method_info(self) -> list[MethodInfo]: + for key, value in self.content.items(): + if key not in ALL_METHODS_STR: + continue + method = HTTPMethod(key) + content = value + yield MethodInfo(method=method, content=content, parent_path=self) + + def get_allowed_methods(self) -> list[HTTPMethod]: + for key, value in self.content.items(): + if key not in ALL_METHODS_STR: + continue + method = HTTPMethod(key) + yield method + + def get_disallowed_methods(self) -> list[HTTPMethod]: + allowed_methods = list(self.get_allowed_methods()) + for method in ALL_METHODS_ENUM: + if method not in allowed_methods: + yield method + + +class SpecManager: + + def __init__(self, spec: dict): + self.spec = spec + + def get_paths(self): + for pathname, pathdict in self.spec["paths"].items(): + yield PathInfo(route_name=pathname, content=pathdict) + + def get_methods_with_response( + self, response_status: HTTPStatus + ) -> list[MethodInfo]: + for path_info in self.get_paths(): + for method_info in path_info.get_allowed_method_info(): + if method_info.has_response(response_status): + yield method_info + + def get_methods_with_header(self) -> list[MethodInfo]: + for path_info in self.get_paths(): + for method_info in path_info.get_allowed_method_info(): + if method_info.has_any_authorization_header(): + yield method_info diff --git a/tests_comprehensive/test_bad_request.py b/tests_comprehensive/test_bad_request.py new file mode 100644 index 00000000..7b0fd523 --- /dev/null +++ b/tests_comprehensive/test_bad_request.py @@ -0,0 +1,123 @@ +""" +All the tests in this file are expected to produce a BAD_REQUEST response +""" + +import pytest +from http import HTTPStatus + +from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( + TestDataCreatorFlask, +) +from conftest import monkeysession, test_data_creator_flask +from tests.helper_scripts.run_and_validate_request import run_and_validate_request +from tests_comprehensive.helper_scripts.SpecManager import SpecManager + + +def test_bad_request_authorization(test_data_creator_flask: TestDataCreatorFlask): + """ + All endpoints which expect a header should produce a BAD REQUEST result when + the header is malformed + """ + tdc = test_data_creator_flask + spec: dict = tdc.request_validator.get_api_spec() + + sm = SpecManager(spec) + for method_info in sm.get_methods_with_header(): + print(f"Testing method {method_info}.") + try: + run_and_validate_request( + flask_client=tdc.flask_client, + http_method=method_info.method.value, + endpoint=method_info.pathname(), + headers={"Authorization": "bad"}, + expected_response_status=HTTPStatus.BAD_REQUEST, + ) + except AssertionError as e: + raise AssertionError(f"Error in method {method_info}: {e}") + + +def test_bad_request_missing_header(test_data_creator_flask: TestDataCreatorFlask): + """ + All endpoints which expect a header should produce a BAD REQUEST result when + the header is missing + """ + tdc = test_data_creator_flask + spec: dict = tdc.request_validator.get_api_spec() + sm = SpecManager(spec) + for method_info in sm.get_methods_with_response( + response_status=HTTPStatus.BAD_REQUEST + ): + try: + if not method_info.has_response(HTTPStatus.BAD_REQUEST): + raise AssertionError("""Method should have a BAD REQUEST response.""") + print(f"Testing method {method_info}.") + run_and_validate_request( + flask_client=tdc.flask_client, + http_method=method_info.method.value, + endpoint=method_info.pathname(), + headers={}, + expected_response_status=HTTPStatus.BAD_REQUEST, + ) + except Exception as e: + raise AssertionError(f"Error in method {method_info}: {e}") + + +def test_bad_request_jwt_not_allowed(): + """ + All endpoints where a JWT (aka, "Bearer" auth) is not listed as an option + in the documentation should produce a BAD REQUEST result when a jwt is included + """ + pytest.fail("Not implemented") + + +def test_bad_request_invalid_types(): + """ + Test that fields which expect a certain type produce an expected + BAD_REQUEST validation error when an incorrect type is produced + For example: + - integer expected, but string given + - datetime expected, but date given + - datetime expected, but string given + - date expected, but datetime given + - enum expected, but invalid string given + """ + pytest.fail("Not implemented") + + +def test_bad_request_missing_required_fields(): + """ + Endpoints should return a BAD_REQUEST if a required field is missing + If an endpoint has multiple required fields, each should be individually tested + """ + pytest.fail("Not implemented") + + +def test_bad_request_negative_numbers(): + """ + For all endpoints which accept a number in the request body, + negative numbers should be rejected with a BAD_REQUEST error + """ + pytest.fail("Not implemented") + + +def test_bad_request_string_length(): + """ + Test that a bad request is returned when the string is too long + """ + pytest.fail("Not implemented") + + +def test_bad_request_extra_fields(): + """ + For json requests, extra fields should either be ignored or, ideally, produce + a validation error + """ + pytest.fail("Not implemented") + + +def test_bad_request_endpoints_disallow_out_of_spec_json_inputs(): + """ + Endpoints which accept json inputs should not allow inputs which are + not defined in the schema + """ + pytest.fail("Not implemented") diff --git a/tests_comprehensive/test_forbidden.py b/tests_comprehensive/test_forbidden.py new file mode 100644 index 00000000..63e4f872 --- /dev/null +++ b/tests_comprehensive/test_forbidden.py @@ -0,0 +1,10 @@ +import pytest + + +def test_jwt_insufficient_permissions(): + """ + All endpoints where a JWT (aka, "Bearer" auth) is listed as an option + in the documentation should produce a FORBIDDEN result where the JWT + in question does not have the correct permissions + """ + pytest.fail("Not implemented") diff --git a/tests_comprehensive/test_get_spec.py b/tests_comprehensive/test_get_spec.py new file mode 100644 index 00000000..0881ed95 --- /dev/null +++ b/tests_comprehensive/test_get_spec.py @@ -0,0 +1,10 @@ +from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( + TestDataCreatorFlask, +) +from conftest import monkeysession, test_data_creator_flask + + +def test_get_spec(test_data_creator_flask: TestDataCreatorFlask): + tdc = test_data_creator_flask + + tdc.request_validator.get_api_spec() diff --git a/tests_comprehensive/test_http_not_allowed.py b/tests_comprehensive/test_http_not_allowed.py new file mode 100644 index 00000000..92c6ea32 --- /dev/null +++ b/tests_comprehensive/test_http_not_allowed.py @@ -0,0 +1,26 @@ +from http import HTTPStatus + +from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( + TestDataCreatorFlask, +) +from conftest import monkeysession, test_data_creator_flask +from tests.helper_scripts.run_and_validate_request import run_and_validate_request +from tests_comprehensive.helper_scripts.SpecManager import SpecManager + +ALL_METHODS = ["get", "post", "put", "delete", "patch"] + + +def test_http_not_allowed(test_data_creator_flask: TestDataCreatorFlask): + tdc = test_data_creator_flask + + spec: dict = tdc.request_validator.get_api_spec() + sm = SpecManager(spec) + for path_info in sm.get_paths(): + for method in path_info.get_disallowed_methods(): + print(f"Testing method {method} on path {path_info.route_name}") + run_and_validate_request( + flask_client=tdc.flask_client, + http_method=method.value, + endpoint=path_info.route_name, + expected_response_status=HTTPStatus.METHOD_NOT_ALLOWED, + ) diff --git a/tests_comprehensive/test_inputs_match_specifications.py b/tests_comprehensive/test_inputs_match_specifications.py new file mode 100644 index 00000000..187033ff --- /dev/null +++ b/tests_comprehensive/test_inputs_match_specifications.py @@ -0,0 +1,10 @@ +import pytest + + +def test_inputs_match_specifications(): + """ + Given the specifications, randomly generated inputs + matching the types provided in the specifications + should not produce an internal error, or an error of validation + """ + pytest.fail("Not implemented") diff --git a/tests_comprehensive/test_unauthorized.py b/tests_comprehensive/test_unauthorized.py new file mode 100644 index 00000000..5b346427 --- /dev/null +++ b/tests_comprehensive/test_unauthorized.py @@ -0,0 +1,31 @@ +import pytest + +""" +All the tests in this file are expected to produce an UNAUTHORIZED response +""" + + +def test_unauthorized_jwt_expired(): + """ + All endpoints where a JWT (aka, "Bearer" auth) is listed as an option + in the documentation should produce a BAD REQUEST result when a jwt is expired + """ + pytest.fail("Not implemented") + + +def test_unauthorized_jwt_csrf(): + """ + All endpoints where a JWT (aka, "Bearer" auth) is listed as an option + in the documentation should produce an UNAUTHORIZED result where the JWT + in question is encrypted differently than the one present on the app + """ + pytest.fail("Not implemented") + + +def test_unauthorized_api_key_not_allowed(): + """ + All endpoints where an API key is listed as an option + in the documentation should produce an UNAUTHORIZED result an API key + is not allowed + """ + pytest.fail("Not implemented")