diff --git a/app.py b/app.py index 54f400b8..48382791 100644 --- a/app.py +++ b/app.py @@ -6,6 +6,7 @@ from flask_cors import CORS from middleware.util import get_env_variable +from resources.Batch import namespace_batch from resources.Callback import namespace_auth from resources.DataRequests import namespace_data_requests from resources.GithubDataRequests import namespace_github @@ -64,6 +65,7 @@ namespace_notifications, namespace_map, namespace_signup, + namespace_batch, ] MY_PREFIX = "/api" diff --git a/middleware/access_logic.py b/middleware/access_logic.py index 85d5420c..a1c2335e 100644 --- a/middleware/access_logic.py +++ b/middleware/access_logic.py @@ -139,13 +139,15 @@ def get_identity(): @staticmethod def get_access_info(token: str): - simple_jwt = SimpleJWT.decode(token, purpose=JWTPurpose.STANDARD_ACCESS_TOKEN) - identity = JWTService.get_identity() - if identity: - return get_jwt_access_info_with_permissions( - user_email=identity["user_email"], user_id=identity["id"] + try: + simple_jwt = SimpleJWT.decode( + token, purpose=JWTPurpose.STANDARD_ACCESS_TOKEN ) - return None + except Exception: + return None + return get_jwt_access_info_with_permissions( + user_email=simple_jwt.sub["user_email"], user_id=simple_jwt.sub["id"] + ) def get_token_from_request_header(scheme: AuthScheme): diff --git a/middleware/dynamic_request_logic/post_logic.py b/middleware/dynamic_request_logic/post_logic.py index dd39e7d1..c1c5bbbd 100644 --- a/middleware/dynamic_request_logic/post_logic.py +++ b/middleware/dynamic_request_logic/post_logic.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Optional, Callable, Type +from typing import Optional, Callable, Type, Any import psycopg.errors import sqlalchemy @@ -13,11 +13,25 @@ from middleware.dynamic_request_logic.supporting_classes import ( MiddlewareParameters, PutPostBase, + PostPutHandler, + PutPostRequestInfo, ) from middleware.flask_response_manager import FlaskResponseManager from middleware.util_dynamic import execute_if_not_none +class PostHandler(PostPutHandler): + + def call_database_client_method(self, request: PutPostRequestInfo): + """ + Runs the database client method + and sets the request entry id in-place with the result + """ + request.entry_id = self.mp.db_client_method( + self.mp.db_client, column_value_mappings=request.entry + ) + + class PostLogic(PutPostBase): def __init__( @@ -54,6 +68,21 @@ def make_response(self) -> Response: ) +def post_entry_with_handler( + handler: PostHandler, + dto: Any, +): + request = PutPostRequestInfo(entry=dict(dto), dto=dto) + handler.mass_execute([request]) + if request.error_message is not None: + FlaskResponseManager.abort( + code=HTTPStatus.BAD_REQUEST, message=request.error_message + ) + return created_id_response( + new_id=str(request.entry_id), message=f"{handler.mp.entry_name} created." 
+ ) + + def post_entry( middleware_parameters: MiddlewareParameters, entry: dict, @@ -61,6 +90,7 @@ def post_entry( relation_role_parameters: RelationRoleParameters = RelationRoleParameters(), post_logic_class: Optional[Type[PostLogic]] = PostLogic, check_for_permission: bool = True, + make_response: bool = True, ) -> Response: post_logic = post_logic_class( @@ -70,7 +100,7 @@ def post_entry( relation_role_parameters=relation_role_parameters, check_for_permission=check_for_permission, ) - return post_logic.execute() + return post_logic.execute(make_response=make_response) def try_to_add_entry(middleware_parameters: MiddlewareParameters, entry: dict) -> str: diff --git a/middleware/dynamic_request_logic/put_logic.py b/middleware/dynamic_request_logic/put_logic.py index f55cbbda..ed6856c9 100644 --- a/middleware/dynamic_request_logic/put_logic.py +++ b/middleware/dynamic_request_logic/put_logic.py @@ -11,9 +11,21 @@ from middleware.dynamic_request_logic.supporting_classes import ( MiddlewareParameters, PutPostBase, + PostPutHandler, + PutPostRequestInfo, ) +class PutHandler(PostPutHandler): + + def call_database_client_method(self, request: PutPostRequestInfo): + self.mp.db_client_method( + self.mp.db_client, + column_edit_mappings=request.entry, + entry_id=request.entry_id, + ) + + class PutLogic(PutPostBase): def __init__( diff --git a/middleware/dynamic_request_logic/supporting_classes.py b/middleware/dynamic_request_logic/supporting_classes.py index 6bb3c42e..0251031c 100644 --- a/middleware/dynamic_request_logic/supporting_classes.py +++ b/middleware/dynamic_request_logic/supporting_classes.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional, Union +from enum import Enum, auto +from typing import Optional, Union, Any from flask import Response +from pydantic import BaseModel from database_client.database_client import DatabaseClient from database_client.enums import RelationRoleEnum @@ -22,9 +24,9 @@ class MiddlewareParameters: Contains parameters for the middleware functions """ - access_info: AccessInfoPrimary relation: str db_client_method: callable + access_info: Optional[AccessInfoPrimary] = None db_client: DatabaseClient = DatabaseClient() # Additional arguments for the Database Client method beyond those provided in the given method db_client_additional_args: dict = field(default_factory=dict) @@ -46,6 +48,58 @@ def __init__( self.where_mappings[id_column_name] = id_column_value +class PutPostRequestInfo(BaseModel): + """ + A DTO for the post/put request + + :param request_id: the id of the request + :param entry: the entry information for the request + :param entry_id: the id of the entry to be updated, OR the id of the entry created + :param error_message: the error message, if any + """ + + request_id: int = 1 + entry: dict + dto: Optional[Any] = None + entry_id: Optional[int] = None + error_message: Optional[str] = None + + +class PostPutHandler(ABC): + + def __init__( + self, + middleware_parameters: MiddlewareParameters, + ): + self.mp = middleware_parameters + + def pre_execute(self, request: PutPostRequestInfo): + return + + @abstractmethod + def call_database_client_method(self, request: PutPostRequestInfo): + raise NotImplementedError + + def execute( + self, + request: PutPostRequestInfo, + ): + self.pre_execute(request=request) + result = self.call_database_client_method(request=request) + self.post_execute(request=request) + return result + + def post_execute(self, request: PutPostRequestInfo): 
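+ """Post-call hook; a no-op by default. Subclasses override this for follow-up work, as DataSourcesPostHandler does to link agencies to a newly created data source."""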
+ return + + def mass_execute(self, requests: list[PutPostRequestInfo]): + for request in requests: + try: + self.execute(request=request) + except Exception as e: + request.error_message = str(e) + + class PutPostBase(ABC): def __init__( @@ -83,7 +137,7 @@ def check_can_edit_columns(self, relation_role: RelationRoleEnum): def make_response(self) -> Response: raise NotImplementedError - def execute(self) -> Response: + def execute(self, make_response: bool = True) -> Optional[Response]: if self.check_for_permission: relation_role = ( self.relation_role_parameters.get_relation_role_from_parameters( @@ -94,7 +148,8 @@ def execute(self) -> Response: self.pre_database_client_method_logic() self.call_database_client_method() self.post_database_client_method_logic() - return self.make_response() + if make_response: + return self.make_response() def post_database_client_method_logic(self): return diff --git a/middleware/exceptions.py b/middleware/exceptions.py index 504165db..7e13e7ae 100644 --- a/middleware/exceptions.py +++ b/middleware/exceptions.py @@ -15,16 +15,6 @@ class DuplicateUserError(Exception): pass -class TokenNotFoundError(Exception): - """Raised when the token is not found in the database.""" - - pass - - -class AccessTokenNotFoundError(Exception): - pass - - class InvalidAPIKeyException(Exception): pass diff --git a/middleware/primary_resource_logic/agencies.py b/middleware/primary_resource_logic/agencies.py index 1d578386..5d65d7e4 100644 --- a/middleware/primary_resource_logic/agencies.py +++ b/middleware/primary_resource_logic/agencies.py @@ -10,11 +10,12 @@ from middleware.dynamic_request_logic.delete_logic import delete_entry from middleware.dynamic_request_logic.get_by_id_logic import get_by_id from middleware.dynamic_request_logic.get_many_logic import get_many -from middleware.dynamic_request_logic.post_logic import post_entry -from middleware.dynamic_request_logic.put_logic import put_entry +from middleware.dynamic_request_logic.post_logic import post_entry, PostHandler +from middleware.dynamic_request_logic.put_logic import put_entry, PutHandler from middleware.dynamic_request_logic.supporting_classes import ( MiddlewareParameters, IDInfo, + PutPostRequestInfo, ) from middleware.location_logic import get_location_id from middleware.schema_and_dto_logic.primary_resource_schemas.agencies_advanced_schemas import ( @@ -29,6 +30,7 @@ LocationInfoDTO, ) from middleware.enums import Relations, JurisdictionType +from middleware.util import dict_enums_to_values SUBQUERY_PARAMS = [SubqueryParameterManager.data_sources()] @@ -93,11 +95,55 @@ def validate_and_add_location_info( entry_data["location_id"] = location_id +class AgencyPostRequestInfo(PutPostRequestInfo): + dto: AgenciesPostDTO + + +class AgencyPostHandler(PostHandler): + + def __init__(self): + super().__init__(middleware_parameters=AGENCY_POST_MIDDLEWARE_PARAMETERS) + + def pre_execute(self, request: AgencyPostRequestInfo): + validate_and_add_location_info( + db_client=DatabaseClient(), + entry_data=request.entry, + location_info=request.dto.location_info, + ) + + +class AgencyPutHandler(PutHandler): + + def __init__(self): + super().__init__(middleware_parameters=AGENCY_PUT_MIDDLEWARE_PARAMETERS) + + def pre_execute(self, request: PutPostRequestInfo): + # TODO: Confirm these mappings; the values below are provisional placeholders + entry_data = dict(request.dto.agency_info) + request.entry = dict_enums_to_values(entry_data) + location_info = request.dto.location_info + if location_info is not None: + validate_and_add_location_info( +
db_client=DatabaseClient(), + entry_data=entry_data, + location_info=location_info, + ) + + +AGENCY_POST_MIDDLEWARE_PARAMETERS = MiddlewareParameters( + entry_name="agency", + relation=Relations.AGENCIES.value, + db_client_method=DatabaseClient.create_agency, +) + + def create_agency( - db_client: DatabaseClient, dto: AgenciesPostDTO, access_info: AccessInfoPrimary + db_client: DatabaseClient, + dto: AgenciesPostDTO, + make_response: bool = True, ) -> Response: entry_data = dict(dto.agency_info) - deferred_function = optionally_get_location_info_deferred_function( + pre_insertion_function = optionally_get_location_info_deferred_function( db_client=db_client, jurisdiction_type=dto.agency_info.jurisdiction_type, entry_data=entry_data, @@ -106,14 +152,14 @@ def create_agency( return post_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, - access_info=access_info, entry_name="agency", relation=Relations.AGENCIES.value, db_client_method=DatabaseClient.create_agency, ), entry=entry_data, - pre_insertion_function_with_parameters=deferred_function, + pre_insertion_function_with_parameters=pre_insertion_function, + check_for_permission=False, + make_response=make_response, ) @@ -135,6 +181,13 @@ def optionally_get_location_info_deferred_function( return deferred_function +AGENCY_PUT_MIDDLEWARE_PARAMETERS = MiddlewareParameters( + entry_name="agency", + relation=Relations.AGENCIES.value, + db_client_method=DatabaseClient.update_agency, +) + + def update_agency( db_client: DatabaseClient, access_info: AccessInfoPrimary, diff --git a/middleware/primary_resource_logic/batch_logic.py b/middleware/primary_resource_logic/batch_logic.py new file mode 100644 index 00000000..4fb23029 --- /dev/null +++ b/middleware/primary_resource_logic/batch_logic.py @@ -0,0 +1,270 @@ +from dataclasses import dataclass +from http import HTTPStatus +from io import BytesIO + +from marshmallow import Schema, ValidationError + +from database_client.database_client import DatabaseClient +from middleware.dynamic_request_logic.supporting_classes import ( + PutPostRequestInfo, + PostPutHandler, +) +from middleware.flask_response_manager import FlaskResponseManager +from middleware.primary_resource_logic.agencies import ( + AgencyPostRequestInfo, + AgencyPostHandler, + AgencyPutHandler, +) +from middleware.primary_resource_logic.data_sources_logic import ( + DataSourcesPostHandler, + DataSourcesPutHandler, +) +from middleware.schema_and_dto_logic.dynamic_logic.dynamic_csv_to_schema_conversion_logic import ( + SchemaUnflattener, +) +from middleware.schema_and_dto_logic.dynamic_logic.dynamic_schema_request_content_population import ( + setup_dto_class, +) +from middleware.schema_and_dto_logic.primary_resource_dtos.agencies_dtos import ( + AgenciesPostDTO, +) +from middleware.schema_and_dto_logic.primary_resource_dtos.batch_dtos import ( + BatchRequestDTO, +) +from csv import DictReader + + +def replace_empty_strings_with_none(row: dict): + for key, value in row.items(): + if value == "": + row[key] = None + + +def _get_raw_rows_from_csv( + file: BytesIO, +): + _abort_if_csv(file) + raw_rows = _get_raw_rows(file) + return raw_rows + + +def _get_raw_rows(file: BytesIO): + try: + text_file = (line.decode("utf-8") for line in file) + reader = DictReader(text_file) + rows = list(reader) + return rows + except Exception as e: + FlaskResponseManager.abort( + code=HTTPStatus.BAD_REQUEST, message=f"Error reading csv file: {e}" + ) + + +def _abort_if_csv(file): + if file.filename.split(".")[-1] != "csv": + 
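# Note: despite the name, this aborts when the uploaded file is NOT a csv +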
FlaskResponseManager.abort( + code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, message="File must be of type csv" + ) + + +class BatchRequestManager: + + def __init__(self): + self.requests = [] + + def add_request(self, request: PutPostRequestInfo): + self.requests.append(request) + + def get_requests_with_error(self): + return [ + request for request in self.requests if request.error_message is not None + ] + + def get_requests_without_error(self): + return [request for request in self.requests if request.error_message is None] + + def get_error_dict(self): + d = {} + for request in self.requests: + if request.error_message is not None: + d[request.request_id] = request.error_message + return d + + def all_requests_errored_out(self): + return len(self.get_requests_without_error()) == 0 + + +class BatchRowProcessor: + + def __init__(self, raw_row: dict, request_id: int): + self.raw_row = raw_row + self.request_id = request_id + self.request = None + if "id" in raw_row: + self.entry_id = raw_row["id"] + else: + self.entry_id = None + + def create_request_info_with_error(self, error_message: str): + request = PutPostRequestInfo( + request_id=self.request_id, + entry=self.raw_row, + error_message=error_message, + dto=None, + ) + return request + + def process( + self, unflattener: SchemaUnflattener, inner_dto_class: type, schema: Schema + ): + try: + replace_empty_strings_with_none(row=self.raw_row) + loaded_row = schema.load(self.raw_row) + except ValidationError as e: + # Handle validation error + self.request = self.create_request_info_with_error(error_message=str(e)) + return + unflattened_row = unflattener.unflatten(flat_data=loaded_row) + inner_dto = setup_dto_class( + data=unflattened_row, + dto_class=inner_dto_class, + nested_dto_info_list=unflattener.nested_dto_info_list, + ) + self.request = self.create_completed_request(inner_dto) + + def create_completed_request(self, inner_dto): + return PutPostRequestInfo( + request_id=self.request_id, + entry=dict(inner_dto), + dto=inner_dto, + entry_id=self.entry_id, + ) + + +class AgenciesPostBRP(BatchRowProcessor): + + def create_completed_request(self, inner_dto): + return AgencyPostRequestInfo( + request_id=self.request_id, entry=dict(inner_dto.agency_info), dto=inner_dto + ) + + +@dataclass +class BatchConfig: + dto: BatchRequestDTO + handler: PostPutHandler + brp_class: type[BatchRowProcessor] + schema: Schema + + +def listify_strings(raw_rows: list[dict]): + for raw_row in raw_rows: + for k, v in raw_row.items(): + if isinstance(v, str) and "," in v: + raw_row[k] = v.split(",") + + +def run_batch( + batch_config: BatchConfig, +): + unflattener = SchemaUnflattener( + flat_schema_class=batch_config.dto.csv_schema.__class__ + ) + raw_rows = _get_raw_rows_from_csv(file=batch_config.dto.file) + listify_strings(raw_rows) + schema = batch_config.schema + brm = BatchRequestManager() + for idx, raw_row in enumerate(raw_rows): + brp = batch_config.brp_class(raw_row=raw_row, request_id=idx) + brp.process( + unflattener=unflattener, + inner_dto_class=batch_config.dto.inner_dto_class, + schema=schema, + ) + brm.add_request(request=brp.request) + + handler = batch_config.handler + handler.mass_execute(requests=brm.get_requests_without_error()) + return brm + + +def manage_response( + brm: BatchRequestManager, resource_name: str, verb: str, include_ids: bool = True +): + errors = brm.get_error_dict() + if brm.all_requests_errored_out(): + return FlaskResponseManager.make_response( + status_code=HTTPStatus.OK, + data={ + "message": f"No {resource_name} were 
{verb} from the provided csv file.", + "errors": errors, + }, + ) + + if include_ids: + kwargs = { + "ids": [request.entry_id for request in brm.get_requests_without_error()] + } + else: + kwargs = {} + + return FlaskResponseManager.make_response( + status_code=HTTPStatus.OK, + data={ + "message": f"At least some {resource_name} {verb} successfully.", + "errors": errors, + **kwargs, + }, + ) + + +def batch_post_agencies(db_client: DatabaseClient, dto: BatchRequestDTO): + brm = run_batch( + batch_config=BatchConfig( + dto=dto, + handler=AgencyPostHandler(), + brp_class=AgenciesPostBRP, + schema=dto.csv_schema.__class__(exclude=["file"]), + ) + ) + return manage_response(brm=brm, resource_name="agencies", verb="created") + + +def batch_put_agencies(db_client: DatabaseClient, dto: BatchRequestDTO): + brm = run_batch( + batch_config=BatchConfig( + dto=dto, + handler=AgencyPutHandler(), + brp_class=BatchRowProcessor, + schema=dto.csv_schema.__class__(exclude=["file"]), + ) + ) + return manage_response( + brm=brm, resource_name="agencies", verb="updated", include_ids=False + ) + + +def batch_post_data_sources(db_client: DatabaseClient, dto: BatchRequestDTO): + brm = run_batch( + batch_config=BatchConfig( + dto=dto, + handler=DataSourcesPostHandler(), + brp_class=BatchRowProcessor, + schema=dto.csv_schema.__class__(exclude=["file"]), + ) + ) + return manage_response(brm=brm, resource_name="data sources", verb="created") + + +def batch_put_data_sources(db_client: DatabaseClient, dto: BatchRequestDTO): + brm = run_batch( + batch_config=BatchConfig( + dto=dto, + handler=DataSourcesPutHandler(), + brp_class=BatchRowProcessor, + schema=dto.csv_schema.__class__(exclude=["file"]), + ) + ) + return manage_response( + brm=brm, resource_name="data sources", verb="updated", include_ids=False + ) diff --git a/middleware/primary_resource_logic/data_sources_logic.py b/middleware/primary_resource_logic/data_sources_logic.py index 75dff3bf..014f0e4c 100644 --- a/middleware/primary_resource_logic/data_sources_logic.py +++ b/middleware/primary_resource_logic/data_sources_logic.py @@ -17,11 +17,17 @@ GetRelatedResourcesParameters, get_related_resource, ) -from middleware.dynamic_request_logic.post_logic import post_entry, PostLogic -from middleware.dynamic_request_logic.put_logic import put_entry +from middleware.dynamic_request_logic.post_logic import ( + post_entry, + PostLogic, + PostHandler, + post_entry_with_handler, +) +from middleware.dynamic_request_logic.put_logic import put_entry, PutHandler from middleware.dynamic_request_logic.supporting_classes import ( MiddlewareParameters, IDInfo, + PutPostRequestInfo, ) from middleware.enums import Relations @@ -37,6 +43,7 @@ from middleware.schema_and_dto_logic.primary_resource_dtos.data_sources_dtos import ( DataSourceEntryDataPostDTO, DataSourcesPostDTO, + DataSourcesPutDTO, ) from middleware.util import dataclass_to_filtered_dict @@ -199,23 +206,70 @@ def post_database_client_method_logic(self): ) + +DATA_SOURCES_POST_MIDDLEWARE_PARAMETERS = MiddlewareParameters( + entry_name="Data source", + relation=RELATION, + db_client_method=DatabaseClient.add_new_data_source, +) + + +class DataSourcesPostRequestInfo(PutPostRequestInfo): + dto: DataSourcesPostDTO + + +class DataSourcesPutRequestInfo(PutPostRequestInfo): + dto: DataSourcesPutDTO + + +class DataSourcesPostHandler(PostHandler): + + def __init__(self): + super().__init__(middleware_parameters=DATA_SOURCES_POST_MIDDLEWARE_PARAMETERS) + + def pre_execute(self, request: DataSourcesPostRequestInfo): + request.entry
= dataclass_to_filtered_dict(request.dto.entry_data) + optionally_swap_record_type_name_with_id( + db_client=self.mp.db_client, entry_data=request.entry + ) + + def post_execute(self, request: DataSourcesPostRequestInfo): + if request.dto.linked_agency_ids is None: + return + for agency_id in request.dto.linked_agency_ids: + self.mp.db_client.create_data_source_agency_relation( + column_value_mappings={ + "data_source_id": request.entry_id, + "agency_id": agency_id, + } + ) + + +DATA_SOURCES_PUT_MIDDLEWARE_PARAMETERS = MiddlewareParameters( + entry_name="Data source", + relation=RELATION, + db_client_method=DatabaseClient.update_data_source, +) + + +class DataSourcesPutHandler(PutHandler): + + def __init__(self): + super().__init__(middleware_parameters=DATA_SOURCES_PUT_MIDDLEWARE_PARAMETERS) + + def pre_execute(self, request: DataSourcesPutRequestInfo): + request.entry = dataclass_to_filtered_dict(request.dto.entry_data) + optionally_swap_record_type_name_with_id( + db_client=self.mp.db_client, entry_data=request.entry + ) + + def add_new_data_source_wrapper( db_client: DatabaseClient, dto: DataSourcesPostDTO, access_info: AccessInfoPrimary ) -> Response: - entry_data = dataclass_to_filtered_dict(dto.entry_data) - optionally_swap_record_type_name_with_id(db_client, entry_data) - post_logic = DataSourcesPostLogic( - middleware_parameters=MiddlewareParameters( - db_client=db_client, - access_info=access_info, - entry_name="Data source", - relation=RELATION, - db_client_method=DatabaseClient.add_new_data_source, - ), - entry=entry_data, - agency_ids=dto.linked_agency_ids, + return post_entry_with_handler( + handler=DataSourcesPostHandler(), + dto=dto, ) - return post_logic.execute() # region Related Resources diff --git a/middleware/primary_resource_logic/login_queries.py b/middleware/primary_resource_logic/login_queries.py index 8d7e1f00..14a6c103 100644 --- a/middleware/primary_resource_logic/login_queries.py +++ b/middleware/primary_resource_logic/login_queries.py @@ -1,3 +1,4 @@ +from datetime import timedelta, timezone, datetime from http import HTTPStatus from flask import Response, make_response, jsonify @@ -9,7 +10,7 @@ from werkzeug.security import check_password_hash from database_client.database_client import DatabaseClient -from middleware.SimpleJWT import JWTPurpose +from middleware.SimpleJWT import JWTPurpose, SimpleJWT from middleware.access_logic import AccessInfoPrimary from middleware.exceptions import UserNotFoundError from middleware.primary_resource_logic.user_queries import UserRequestDTO @@ -25,10 +26,12 @@ def __init__(self, email: str): "user_email": email, "id": DatabaseClient().get_user_id(email), } - self.access_token = create_access_token( - identity=identity, - additional_claims={"purpose": JWTPurpose.STANDARD_ACCESS_TOKEN.value}, + simple_jwt = SimpleJWT( + sub=identity, + exp=(datetime.now(tz=timezone.utc) + timedelta(minutes=15)).timestamp(), + purpose=JWTPurpose.STANDARD_ACCESS_TOKEN, ) + self.access_token = simple_jwt.encode() self.refresh_token = create_refresh_token(identity=identity) diff --git a/middleware/schema_and_dto_logic/common_schemas_and_dtos.py b/middleware/schema_and_dto_logic/common_schemas_and_dtos.py index 78b4b383..e1cbb62c 100644 --- a/middleware/schema_and_dto_logic/common_schemas_and_dtos.py +++ b/middleware/schema_and_dto_logic/common_schemas_and_dtos.py @@ -11,6 +11,7 @@ from
database_client.enums import SortOrder, LocationType from middleware.schema_and_dto_logic.custom_fields import DataField +from middleware.schema_and_dto_logic.enums import CSVColumnCondition from middleware.schema_and_dto_logic.non_dto_dataclasses import ( SchemaPopulateParameters, DTOPopulateParameters, @@ -135,6 +136,7 @@ class TypeaheadDTO(BaseModel): metadata={ "description": "The 2 letter ISO code of the state.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, validate=validate.Length(2), ) @@ -145,6 +147,7 @@ class TypeaheadDTO(BaseModel): "description": "The unique 5-digit FIPS code of the county." "Does not apply to state or federal agencies.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, validate=validate.Length(5), ) @@ -154,6 +157,7 @@ class TypeaheadDTO(BaseModel): metadata={ "description": "The name of the locality.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) @@ -166,6 +170,7 @@ class LocationInfoSchema(Schema): metadata={ "description": "The type of location. ", "source": SourceMappingEnum.JSON, + "csv_column_name": "location_type", }, ) state_iso = STATE_ISO_FIELD diff --git a/middleware/schema_and_dto_logic/dynamic_logic/dynamic_csv_to_schema_conversion_logic.py b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_csv_to_schema_conversion_logic.py new file mode 100644 index 00000000..95fda4fd --- /dev/null +++ b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_csv_to_schema_conversion_logic.py @@ -0,0 +1,202 @@ +import copy +from dataclasses import dataclass +from typing import Optional, Type, Dict, Any + +from marshmallow.fields import Field + +from middleware.schema_and_dto_logic.dynamic_logic.dynamic_schema_request_content_population import ( + get_nested_dto_info_list, + SourceDataInfo, +) +from middleware.schema_and_dto_logic.enums import CSVColumnCondition +from middleware.schema_and_dto_logic.primary_resource_schemas.data_sources_advanced_schemas import ( + DataSourcesPostSchema, +) + + +from marshmallow import Schema, fields, ValidationError + + +class FlatSchema(Schema): + """ + A custom schema class that enforces flatness. + """ + + origin_schema: Type[Schema] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ensure_no_nested_fields() + + def _ensure_no_nested_fields(self): + for field_name, field_obj in self.fields.items(): + if isinstance(field_obj, fields.Nested): + raise ValidationError( + f"Field '{field_name}' is a nested schema, which is not permitted in FlatSchema." + ) + + # Override the add_field method to enforce flatness + def add_field(self, name, field_obj): + if isinstance(field_obj, fields.Nested): + raise ValidationError( + f"Cannot add nested schema field '{name}' to FlatSchema." 
+ ) + super().add_field(name, field_obj) + + +@dataclass +class FieldPathValueMapping: + + def __init__( + self, + field_path, + field, + name: str | CSVColumnCondition = CSVColumnCondition.SAME_AS_FIELD, + ): + self.field_path = field_path + self.field = field + if name == CSVColumnCondition.SAME_AS_FIELD: + self.name = field_path[-1] + else: + self.name = name + + def get_schema_notation(self): + return ".".join(self.field_path) + + def get_field_name(self): + return self.field_path[-1] + + +def extract_field_path_value_mapping( + schema: Schema, key: str, parent_path: Optional[list[str]] = None +) -> list[FieldPathValueMapping]: + """ + Recursive function to traverse schema for specific metadata key + + :param schema: The schema to extract metadata from + :param key: The key to extract from the metadata + :param parent_path: The path to the parent field + :return: A list of field path value mappings + """ + if parent_path is None: + parent_path = [] + + metadata = [] + for field_name, field in schema.fields.items(): + # Construct the path to the field + current_path = parent_path + [field_name] + + # Check if the key exists in the field's metadata + if key in field.metadata: + metadata.append( + FieldPathValueMapping( + field_path=current_path, name=field.metadata[key], field=field + ), + ) + + # Check if the field is a nested schema + if isinstance(field, fields.Nested) and isinstance(field.schema, Schema): + # Recursively extract nested schema metadata + nested_metadata = extract_field_path_value_mapping( + field.schema, key, current_path + ) + metadata.extend(nested_metadata) + + # Check if the field is a list with a nested schema + elif isinstance(field, fields.List) and isinstance(field.inner, fields.Nested): + if isinstance(field.inner.schema, Schema): + nested_metadata = extract_field_path_value_mapping( + field.inner.schema, key, current_path + ["[]"] + ) + metadata.extend(nested_metadata) + + return metadata + + +def get_csv_columns_from_schema(schema: Schema) -> list[FieldPathValueMapping]: + """ + Go through all fields which have `csv_column_name` metadata + and add their full path to the mapping + """ + fieldpath_value_mapping = extract_field_path_value_mapping( + schema, "csv_column_name" + ) + return fieldpath_value_mapping + + +def generate_flat_csv_schema( + schema: Schema, additional_fields: Optional[Dict[str, Field]] = None +) -> Type[FlatSchema]: + """ + Dynamically generate flat csv schema from a nested schema + """ + fpvms = get_csv_columns_from_schema(schema) + new_fields = {} + schema_class = schema.__class__ + new_fields["origin_schema"] = schema_class + for fpvm in fpvms: + new_fields[fpvm.name] = copy.copy(fpvm.field) + if additional_fields is not None: + new_fields.update(additional_fields) + NewSchema = type(f"Flat{schema_class.__name__}", (FlatSchema,), new_fields) + + return NewSchema + + +def create_csv_column_to_field_map( + fpvms: list[FieldPathValueMapping], +) -> dict[str, list[str]]: + csv_column_to_field_map = {} + for fpvm in fpvms: + if fpvm.name is CSVColumnCondition.SAME_AS_FIELD: + field_name = fpvm.get_field_name() + else: + field_name = fpvm.name + csv_column_to_field_map[field_name] = fpvm.field_path + + return csv_column_to_field_map + + +def set_nested_value(d: dict, path: list[str], value: Any): + """ + Sets a value in a nested dictionary based on a given path. + Creates nested dictionaries for intermediate keys if they don't exist. + + :param d: The dictionary to modify. 
+ :param path: A list of strings representing the path to the desired key. + :param value: The value to set for the final key. + """ + current = d + for key in path[:-1]: # Iterate over all keys except the last one + if key not in current: + current[key] = {} # Create a new dictionary if key is missing + current = current[key] + current[path[-1]] = value + + +class SchemaUnflattener: + + def __init__(self, flat_schema_class: Type[FlatSchema]): + self.flat_schema_class = flat_schema_class + self.origin_schema_class = flat_schema_class.origin_schema + self.origin_schema = self.origin_schema_class() + self.fpvms = get_csv_columns_from_schema(schema=self.origin_schema) + self.nested_dto_info_list = get_nested_dto_info_list(schema=self.origin_schema) + + def unflatten(self, flat_data: dict) -> dict: + data = {} + for fpvm in self.fpvms: + if fpvm.name not in flat_data: + continue + set_nested_value( + d=data, + path=fpvm.field_path, + value=flat_data[fpvm.name], + ) + return data + + +if __name__ == "__main__": + mapping = get_csv_columns_from_schema(schema=DataSourcesPostSchema()) + new_schema = generate_flat_csv_schema(schema=DataSourcesPostSchema()) + print(new_schema) diff --git a/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_documentation_construction.py b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_documentation_construction.py index 7587bad1..8f7d94bb 100644 --- a/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_documentation_construction.py +++ b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_documentation_construction.py @@ -45,13 +45,22 @@ from resources.resource_helpers import create_variable_columns_model from utilities.enums import SourceMappingEnum +PARSER_FIELDS = [ + SourceMappingEnum.QUERY_ARGS, + SourceMappingEnum.PATH, + SourceMappingEnum.FILE, +] + +PARSER_SOURCE_LOCATION_MAP = { + SourceMappingEnum.QUERY_ARGS: "query", + SourceMappingEnum.PATH: "path", + SourceMappingEnum.FILE: "file", +} + # region Supporting Functions def get_location(source: SourceMappingEnum) -> str: - if source == SourceMappingEnum.PATH: - return "path" - if source == SourceMappingEnum.QUERY_ARGS: - return "query" + return PARSER_SOURCE_LOCATION_MAP[source] def add_description_info_from_validators( @@ -290,7 +299,7 @@ def sort_fields(self, schema: SchemaTypes): self._sort_field(fi) def _sort_field(self, fi: FieldInfo): - if fi.source in (SourceMappingEnum.QUERY_ARGS, SourceMappingEnum.PATH): + if fi.source in PARSER_FIELDS: self.parser_fields.append(fi) elif fi.source == SourceMappingEnum.JSON: self.model_fields.append(fi) diff --git a/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_request_content_population.py b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_request_content_population.py index 73426410..11aecc7c 100644 --- a/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_request_content_population.py +++ b/middleware/schema_and_dto_logic/dynamic_logic/dynamic_schema_request_content_population.py @@ -1,11 +1,16 @@ +from dataclasses import dataclass from http import HTTPStatus from typing import Type import marshmallow +from flask import request from pydantic import BaseModel from middleware.flask_response_manager import FlaskResponseManager from middleware.schema_and_dto_logic.custom_types import SchemaTypes, DTOTypes +from middleware.schema_and_dto_logic.primary_resource_dtos.batch_dtos import ( + BatchRequestDTO, +) from middleware.schema_and_dto_logic.util import ( _get_required_argument, 
_get_source_getting_function, @@ -24,7 +29,7 @@ class SourceDataInfo(BaseModel): def populate_schema_with_request_content( - schema: SchemaTypes, dto_class: Type[DTOTypes] + schema: SchemaTypes, dto_class: Type[DTOTypes], load_file: bool = False ) -> DTOTypes: """ Populates a marshmallow schema with request content, given custom arguments in the schema fields @@ -36,15 +41,19 @@ def populate_schema_with_request_content( :return: """ # Get all declared fields from the schema + if load_file: + return BatchRequestDTO( + file=request.files.get("file"), csv_schema=schema, inner_dto_class=dto_class + ) fields = schema.fields - source_data_info = _get_data_from_sources(fields, schema) - + source_data_info = get_source_data_info_from_sources(schema) intermediate_data = validate_data(source_data_info.data, schema) - _apply_transformation_functions_to_dict(fields, intermediate_data) return setup_dto_class( - intermediate_data, dto_class, source_data_info.nested_dto_info_list + data=intermediate_data, + dto_class=dto_class, + nested_dto_info_list=source_data_info.nested_dto_info_list, ) @@ -72,29 +81,46 @@ class InvalidSourceMappingError(Exception): pass -def _get_data_from_sources(fields: dict, schema: SchemaTypes) -> SourceDataInfo: - """ - Get data from sources specified in the schema and field metadata - :param fields: - :param schema: - :return: - """ - data = {} +def get_nested_dto_info_list(schema: SchemaTypes) -> list[NestedDTOInfo]: nested_dto_info_list = [] - for field_name, field_value in fields.items(): + for field_name, field_value in schema.fields.items(): + if not isinstance(field_value, marshmallow.fields.Nested): + continue metadata = field_value.metadata source: SourceMappingEnum = _get_required_argument("source", metadata, schema) - if isinstance(field_value, marshmallow.fields.Nested): - _check_for_errors(metadata, source) - nested_dto_class = metadata["nested_dto_class"] - nested_dto_info_list.append( - NestedDTOInfo(key=field_name, class_=nested_dto_class) - ) + _check_for_errors(metadata=metadata, source=source) + nested_dto_class = metadata["nested_dto_class"] + nested_dto_info_list.append( + NestedDTOInfo(key=field_name, class_=nested_dto_class) + ) + return nested_dto_info_list + + +def _get_data_from_sources(schema: SchemaTypes) -> dict: + data = {} + for field_name, field_value in schema.fields.items(): + metadata = field_value.metadata + source: SourceMappingEnum = _get_required_argument( + argument_name="source", metadata=metadata, schema_class=schema + ) source_getting_function = _get_source_getting_function(source) val = source_getting_function(field_name) if val is not None: data[field_name] = val + + return data + + +def get_source_data_info_from_sources(schema: SchemaTypes) -> SourceDataInfo: + """ + Get data from sources specified in the schema and field metadata + :param fields: + :param schema: + :return: + """ + data = _get_data_from_sources(schema=schema) + nested_dto_info_list = get_nested_dto_info_list(schema=schema) return SourceDataInfo(data=data, nested_dto_info_list=nested_dto_info_list) diff --git a/middleware/schema_and_dto_logic/enums.py b/middleware/schema_and_dto_logic/enums.py index a276fcfe..ccb8b824 100644 --- a/middleware/schema_and_dto_logic/enums.py +++ b/middleware/schema_and_dto_logic/enums.py @@ -5,3 +5,12 @@ class RestxModelPlaceholder(Enum): # Placeholders for nested models VARIABLE_COLUMNS = "variable_columns" LIST_VARIABLE_COLUMNS = "list_variable_columns" + + +class CSVColumnCondition(Enum): + """ + SAME_AS_FIELD: Indicates that the csv 
column + should be the same as the field in the schema + """ + + SAME_AS_FIELD = "SAME_AS_FIELD" diff --git a/middleware/schema_and_dto_logic/non_dto_dataclasses.py b/middleware/schema_and_dto_logic/non_dto_dataclasses.py index e03600d9..c3c96628 100644 --- a/middleware/schema_and_dto_logic/non_dto_dataclasses.py +++ b/middleware/schema_and_dto_logic/non_dto_dataclasses.py @@ -21,6 +21,7 @@ class SchemaPopulateParameters: schema: SchemaTypes dto_class: Type[DTOTypes] + load_file: bool = False class DTOPopulateParameters(BaseModel): diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/batch_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/batch_dtos.py new file mode 100644 index 00000000..29e367cb --- /dev/null +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/batch_dtos.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import Any + +from marshmallow import Schema +from pydantic import BaseModel +from werkzeug.datastructures import FileStorage + + +@dataclass +class BatchRequestDTO: + file: FileStorage + csv_schema: Schema + inner_dto_class: Any diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/data_sources_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/data_sources_dtos.py index f3955f92..da970269 100644 --- a/middleware/schema_and_dto_logic/primary_resource_dtos/data_sources_dtos.py +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/data_sources_dtos.py @@ -11,6 +11,7 @@ AccessType, RetentionSchedule, URLStatus, + UpdateMethod, ) from middleware.enums import RecordType @@ -30,7 +31,7 @@ class DataSourceEntryDataPostDTO(BaseModel): access_types: Optional[List[AccessType]] = None data_portal_type: Optional[str] = None record_formats: Optional[List[str]] = None - update_method: Optional[str] = None + update_method: Optional[UpdateMethod] = None tags: Optional[List[str]] = None readme_url: Optional[str] = None originating_entity: Optional[str] = None @@ -50,6 +51,41 @@ class DataSourceEntryDataPostDTO(BaseModel): record_type_name: Optional[RecordType] = None +class DataSourceEntryDataPutDTO(BaseModel): + submitted_name: str + description: Optional[str] = None + approval_status: Optional[ApprovalStatus] = None + source_url: Optional[str] = None + agency_supplied: Optional[bool] = None + supplying_entity: Optional[str] = None + agency_originated: Optional[bool] = None + agency_aggregation: Optional[AgencyAggregation] = None + coverage_start: Optional[date] = None + coverage_end: Optional[date] = None + detail_level: Optional[DetailLevel] = None + access_types: Optional[List[AccessType]] = None + data_portal_type: Optional[str] = None + record_formats: Optional[List[str]] = None + update_method: Optional[UpdateMethod] = None + tags: Optional[List[str]] = None + readme_url: Optional[str] = None + originating_entity: Optional[str] = None + retention_schedule: Optional[RetentionSchedule] = None + scraper_url: Optional[str] = None + submission_notes: Optional[str] = None + submitter_contact_info: Optional[str] = None + agency_described_submitted: Optional[str] = None + agency_described_not_in_database: Optional[str] = None + data_portal_type_other: Optional[str] = None + access_notes: Optional[str] = None + url_status: Optional[URLStatus] = None + record_type_name: Optional[RecordType] = None + + +class DataSourcesPutDTO(BaseModel): + entry_data: DataSourceEntryDataPutDTO + + class DataSourcesPostDTO(BaseModel): entry_data: DataSourceEntryDataPostDTO linked_agency_ids: Optional[List[int]] = None diff 
--git a/middleware/schema_and_dto_logic/primary_resource_schemas/agencies_base_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/agencies_base_schemas.py index 8695d3d0..e130fe50 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/agencies_base_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/agencies_base_schemas.py @@ -6,6 +6,7 @@ COUNTY_FIPS_FIELD, LOCALITY_NAME_FIELD, ) +from middleware.schema_and_dto_logic.enums import CSVColumnCondition from utilities.enums import SourceMappingEnum @@ -15,6 +16,7 @@ def get_submitted_name_field(required: bool) -> fields.Str: metadata={ "description": "The name of the agency.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) @@ -27,6 +29,7 @@ def get_jurisdiction_type_field(required: bool) -> fields.Enum: metadata={ "description": "The highest level of jurisdiction of the agency.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) @@ -37,6 +40,7 @@ class AgencyInfoBaseSchema(Schema): metadata={ "description": "The URL of the agency's homepage.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) lat = fields.Float( @@ -45,6 +49,7 @@ class AgencyInfoBaseSchema(Schema): metadata={ "description": "The latitude of the agency's location.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) lng = fields.Float( @@ -53,6 +58,7 @@ class AgencyInfoBaseSchema(Schema): metadata={ "description": "The longitude of the agency's location.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) defunct_year = fields.Str( @@ -61,6 +67,7 @@ class AgencyInfoBaseSchema(Schema): metadata={ "description": "If present, denotes an agency which has defunct but may still have relevant records.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) agency_type = fields.Enum( @@ -72,6 +79,7 @@ class AgencyInfoBaseSchema(Schema): metadata={ "description": "The type of agency.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) multi_agency = fields.Bool( @@ -80,6 +88,7 @@ class AgencyInfoBaseSchema(Schema): metadata={ "description": "Whether the agency is a multi-agency.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) zip_code = fields.Str( @@ -88,6 +97,7 @@ class AgencyInfoBaseSchema(Schema): metadata={ "description": "The zip code of the agency's location.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, # TODO: Re-enable when all zip codes are of expected length # validate=validate.Length(min=5), @@ -98,6 +108,7 @@ class AgencyInfoBaseSchema(Schema): metadata={ "description": "True when an agency does not have a dedicated website.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) approved = fields.Bool( @@ -130,6 +141,7 @@ class AgencyInfoBaseSchema(Schema): metadata={ "description": "The contact information of the user who submitted the agency.", "source": SourceMappingEnum.JSON, + "csv_column_name": CSVColumnCondition.SAME_AS_FIELD, }, ) diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/batch_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/batch_schemas.py new file mode 100644 index 00000000..97bbc53e --- /dev/null +++ 
b/middleware/schema_and_dto_logic/primary_resource_schemas/batch_schemas.py @@ -0,0 +1,94 @@ +from marshmallow import Schema, fields + +from middleware.schema_and_dto_logic.common_response_schemas import MessageSchema +from middleware.schema_and_dto_logic.dynamic_logic.dynamic_csv_to_schema_conversion_logic import ( + generate_flat_csv_schema, +) +from middleware.schema_and_dto_logic.primary_resource_schemas.agencies_advanced_schemas import ( + AgenciesPostSchema, + AgenciesPutSchema, +) +from middleware.schema_and_dto_logic.primary_resource_schemas.data_sources_advanced_schemas import ( + DataSourcesPostSchema, + DataSourcesPutSchema, +) +from middleware.schema_and_dto_logic.util import get_json_metadata +from utilities.enums import SourceMappingEnum + + +class BatchRequestSchema(Schema): + + file = fields.String( + required=True, + metadata={ + "source": SourceMappingEnum.FILE, + "description": "The file to upload", + }, + ) + + +class BatchPutRequestSchema(BatchRequestSchema): + id = fields.Integer( + required=True, metadata=get_json_metadata("The id of the resource to update") + ) + + +# region Base Schemas +DataSourcesPostRequestFlatBaseSchema = generate_flat_csv_schema( + schema=DataSourcesPostSchema() +) +DataSourcesPutRequestFlatBaseSchema = generate_flat_csv_schema( + schema=DataSourcesPutSchema(), +) +AgenciesPostRequestFlatBaseSchema = generate_flat_csv_schema( + schema=AgenciesPostSchema() +) +AgenciesPutRequestFlatBaseSchema = generate_flat_csv_schema(schema=AgenciesPutSchema()) +# endregion + + +# region RequestSchemas +class DataSourcesPostBatchRequestSchema( + BatchRequestSchema, DataSourcesPostRequestFlatBaseSchema +): + pass + + +class DataSourcesPutBatchRequestSchema( + BatchPutRequestSchema, DataSourcesPutRequestFlatBaseSchema +): + pass + + +class AgenciesPostBatchRequestSchema( + BatchRequestSchema, AgenciesPostRequestFlatBaseSchema +): + pass + + +class AgenciesPutBatchRequestSchema( + BatchPutRequestSchema, AgenciesPutRequestFlatBaseSchema +): + pass + + +# endregion + + +class BatchPostResponseSchema(MessageSchema): + ids = fields.List( + fields.Integer(metadata=get_json_metadata("The ids of the resources created")), + required=True, + metadata=get_json_metadata("The ids of the resources created"), + ) + errors = fields.Dict( + required=True, + metadata=get_json_metadata("The errors associated with resources not created"), + ) + + +class BatchPutResponseSchema(MessageSchema): + errors = fields.Dict( + required=True, + metadata=get_json_metadata("The errors associated with resources not updated"), + ) diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/data_sources_advanced_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/data_sources_advanced_schemas.py index 4a7822e8..f9ed65ca 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/data_sources_advanced_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/data_sources_advanced_schemas.py @@ -6,8 +6,10 @@ from middleware.schema_and_dto_logic.common_schemas_and_dtos import ( GetManyRequestsBaseSchema, ) +from middleware.schema_and_dto_logic.enums import CSVColumnCondition from middleware.schema_and_dto_logic.primary_resource_dtos.data_sources_dtos import ( DataSourceEntryDataPostDTO, + DataSourceEntryDataPutDTO, ) from middleware.schema_and_dto_logic.primary_resource_schemas.agencies_base_schemas import ( AgenciesExpandedSchema, @@ -120,10 +122,14 @@ class DataSourcesPostSchema(Schema): fields.Integer( allow_none=True, 
metadata=get_json_metadata( - "The agency ids associated with the data source." + "The agency ids associated with the data source.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ), - metadata=get_json_metadata("The agency ids associated with the data source."), + metadata=get_json_metadata( + "The agency ids associated with the data source.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, + ), ) @@ -145,7 +151,10 @@ class DataSourcesPutSchema(Schema): ] ), required=True, - metadata=get_json_metadata("The data source to be updated"), + metadata=get_json_metadata( + "The data source to be updated", + nested_dto_class=DataSourceEntryDataPutDTO, + ), ) diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/data_sources_base_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/data_sources_base_schemas.py index b08b691e..572f1446 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/data_sources_base_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/data_sources_base_schemas.py @@ -10,6 +10,7 @@ ApprovalStatus, ) from middleware.enums import RecordType +from middleware.schema_and_dto_logic.enums import CSVColumnCondition from middleware.schema_and_dto_logic.util import get_json_metadata @@ -27,7 +28,8 @@ class DataSourceBaseSchema(Schema): required=True, allow_none=True, metadata=get_json_metadata( - "The name of the data source as originally submitted." + "The name of the data source as originally submitted.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) description = fields.String( @@ -36,7 +38,8 @@ class DataSourceBaseSchema(Schema): metadata=get_json_metadata( description="Information to give clarity and confidence about what this source is, how it was " "processed, and whether the person reading the description might want to use it. " - "Especially important if the source is difficult to preview or categorize." + "Especially important if the source is difficult to preview or categorize.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) @@ -44,7 +47,8 @@ class DataSourceBaseSchema(Schema): required=True, allow_none=True, metadata=get_json_metadata( - "A URL where these records can be found or are referenced." + "A URL where these records can be found or are referenced.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) agency_supplied = fields.Boolean( @@ -52,39 +56,47 @@ class DataSourceBaseSchema(Schema): metadata=get_json_metadata( 'Is the relevant Agency also the entity supplying the data? This may be "no" if the Agency or local ' "government contracted with a third party to publish this data, or if a third party was the original " - "record-keeper." + "record-keeper.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) supplying_entity = fields.String( allow_none=True, - metadata=get_json_metadata("If the Agency didn't publish this, who did?"), + metadata=get_json_metadata( + "If the Agency didn't publish this, who did?", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, + ), ) agency_originated = fields.Boolean( allow_none=True, metadata=get_json_metadata( 'Is the relevant Agency also the original record-keeper? This is usually "yes", unless a third party ' - "collected data about a police Agency." + "collected data about a police Agency.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) agency_aggregation = fields.Enum( enum=AgencyAggregation, by_value=fields.Str, metadata=get_json_metadata( - "If present, the Data Source describes multiple agencies." 
+ "If present, the Data Source describes multiple agencies.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), allow_none=True, ) coverage_start = fields.Date( allow_none=True, metadata=get_json_metadata( - "The start date of the data source's coverage, in the format YYYY-MM-DD." + "The start date of the data source's coverage, in the format YYYY-MM-DD.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), format="iso", ) coverage_end = fields.Date( allow_none=True, metadata=get_json_metadata( - "The end date of the data source's coverage, in the format YYYY-MM-DD." + "The end date of the data source's coverage, in the format YYYY-MM-DD.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), format="iso", ) @@ -100,7 +112,8 @@ class DataSourceBaseSchema(Schema): enum=DetailLevel, by_value=fields.Str, metadata=get_json_metadata( - "Is this an individual record, an aggregated set of records, or a summary without underlying data?" + "Is this an individual record, an aggregated set of records, or a summary without underlying data?", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) access_types = fields.List( @@ -108,32 +121,43 @@ class DataSourceBaseSchema(Schema): enum=AccessType, by_value=fields.Str, metadata=get_json_metadata( - "The ways the data source can be accessed. Editable only by admins." + "The ways the data source can be accessed. Editable only by admins.", ), ), allow_none=True, metadata=get_json_metadata( - "The ways the data source can be accessed. Editable only by admins." + "The ways the data source can be accessed. Editable only by admins.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) data_portal_type = fields.String( allow_none=True, - metadata=get_json_metadata("The data portal type of the data source."), + metadata=get_json_metadata( + "The data portal type of the data source.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, + ), ) record_formats = fields.List( fields.String( metadata=get_json_metadata( - "What formats the data source can be obtained in." + "What formats the data source can be obtained in.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ), allow_none=True, - metadata=get_json_metadata("What formats the data source can be obtained in."), + metadata=get_json_metadata( + "What formats the data source can be obtained in.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, + ), ) update_method = fields.Enum( enum=UpdateMethod, by_value=fields.Str, allow_none=True, - metadata=get_json_metadata("How is the data source updated?"), + metadata=get_json_metadata( + "How is the data source updated?", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, + ), ) tags = fields.List( fields.String( @@ -143,25 +167,31 @@ class DataSourceBaseSchema(Schema): ), ), metadata=get_json_metadata( - "Are there any keyword descriptors which might help people find this in a search? Try to limit tags to information which can't be contained in other properties." + "Are there any keyword descriptors which might help people find this in a search? Try to limit tags to information which can't be contained in other properties.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), allow_none=True, ) readme_url = fields.String( metadata=get_json_metadata( - "A URL where supplementary information about the source is published." 
+ "A URL where supplementary information about the source is published.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), allow_none=True, ) originating_entity = fields.String( allow_none=True, - metadata=get_json_metadata("Who is the originator of the data source?"), + metadata=get_json_metadata( + "Who is the originator of the data source?", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, + ), ) retention_schedule = fields.Enum( enum=RetentionSchedule, by_value=fields.Str, metadata=get_json_metadata( - "How long are records kept? Are there published guidelines regarding how long important information must remain accessible for future use? Editable only by admins." + "How long are records kept? Are there published guidelines regarding how long important information must remain accessible for future use? Editable only by admins.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), allow_none=True, ) @@ -171,7 +201,10 @@ class DataSourceBaseSchema(Schema): ) scraper_url = fields.String( allow_none=True, - metadata=get_json_metadata("URL for the webscraper that produces this source"), + metadata=get_json_metadata( + "URL for the webscraper that produces this source", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, + ), ) created_at = fields.DateTime( metadata=get_json_metadata("The date and time the data source was created."), @@ -179,7 +212,8 @@ class DataSourceBaseSchema(Schema): submission_notes = fields.String( allow_none=True, metadata=get_json_metadata( - "What are you trying to learn? Are you trying to answer a specific question, or complete a specific project? Is there anything you've already tried?" + "What are you trying to learn? Are you trying to answer a specific question, or complete a specific project? Is there anything you've already tried?", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) rejection_note = fields.String( @@ -194,25 +228,29 @@ class DataSourceBaseSchema(Schema): submitter_contact_info = fields.String( allow_none=True, metadata=get_json_metadata( - "Contact information for the individual who provided the data source" + "Contact information for the individual who provided the data source", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) agency_described_submitted = fields.String( allow_none=True, metadata=get_json_metadata( - "To which criminal legal systems agency or agencies does this Data Source refer?" + "To which criminal legal systems agency or agencies does this Data Source refer?", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) agency_described_not_in_database = fields.String( allow_none=True, metadata=get_json_metadata( - "If the agency associated is not in the database, why?" 
+ "If the agency associated is not in the database, why?", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) data_portal_type_other = fields.String( allow_none=True, metadata=get_json_metadata( - "What unconventional data portal this data source is derived from" + "What unconventional data portal this data source is derived from", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, ), ) data_source_request = fields.String( @@ -227,7 +265,10 @@ class DataSourceBaseSchema(Schema): metadata=get_json_metadata("When the url was marked as broken."), ) access_notes = fields.String( - metadata=get_json_metadata("How the source can be accessed,"), + metadata=get_json_metadata( + "How the source can be accessed,", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, + ), allow_none=True, ) url_status = fields.Enum( @@ -266,7 +307,10 @@ class DataSourceExpandedSchema(DataSourceBaseSchema): enum=RecordType, by_value=fields.Str, allow_none=True, - metadata=get_json_metadata("The type of data source. Editable only by admins."), + metadata=get_json_metadata( + "The record type of the data source.", + csv_column_name=CSVColumnCondition.SAME_AS_FIELD, + ), ) diff --git a/middleware/schema_and_dto_logic/schema_helpers.py b/middleware/schema_and_dto_logic/schema_helpers.py index b033d5c3..03afb940 100644 --- a/middleware/schema_and_dto_logic/schema_helpers.py +++ b/middleware/schema_and_dto_logic/schema_helpers.py @@ -1,24 +1,33 @@ -from importlib.metadata import metadata -from typing import Type +from typing import Type, Annotated from marshmallow import Schema, fields +from pydantic import BaseModel from middleware.schema_and_dto_logic.common_response_schemas import ( GetManyResponseSchemaBase, MessageSchema, ) +from middleware.schema_and_dto_logic.enums import CSVColumnCondition from middleware.schema_and_dto_logic.util import get_json_metadata - - -def create_post_schema(entry_data_schema: Schema, description: str) -> Type[Schema]: - class PostSchema(Schema): - entry_data = fields.Nested( - entry_data_schema, - required=True, - metadata=get_json_metadata(description), - ) - - return PostSchema +from utilities.enums import SourceMappingEnum + + +class SchemaMetadata(BaseModel): + description: Annotated[ + str, "The description of the field to be displayed in the API docs" + ] + source: Annotated[ + SourceMappingEnum, + "The source of the field when dynamically populated from requests", + ] + csv_column_name: Annotated[ + str | CSVColumnCondition | None, + ( + "The name of the field when represented as a column in a CSV file." + "Also exists as a flag to indicate that this field should be included " + "in a CSV representation of the data" + ), + ] = None def create_get_many_schema(data_list_schema: Schema, description: str) -> Type[Schema]: diff --git a/middleware/schema_and_dto_logic/util.py b/middleware/schema_and_dto_logic/util.py index 01900478..687451ed 100644 --- a/middleware/schema_and_dto_logic/util.py +++ b/middleware/schema_and_dto_logic/util.py @@ -1,4 +1,4 @@ -from typing import Type, Any, Callable, Optional +from typing import Any, Callable, Optional from flask import request diff --git a/middleware/util.py b/middleware/util.py index 60bf34ca..052e1ba7 100644 --- a/middleware/util.py +++ b/middleware/util.py @@ -32,6 +32,16 @@ def get_enum_values(en: type[Enum]) -> list[str]: return [e.value for e in en] +def dict_enums_to_values(d: dict[str, Any]) -> dict[str, Any]: + """ + Convert enums within a dictionary to their values. 
+ """ + for key, value in d.items(): + if isinstance(value, Enum): + d[key] = value.value + return d + + def dataclass_to_filtered_dict(instance: Any) -> Dict[str, Any]: """ Convert a dataclass instance to a dictionary, filtering out any None values. diff --git a/middleware/webhook_logic.py b/middleware/webhook_logic.py index bcaf2015..007acd71 100644 --- a/middleware/webhook_logic.py +++ b/middleware/webhook_logic.py @@ -1,6 +1,3 @@ -import json -import os - import requests from middleware.third_party_interaction_logic.mailgun_logic import send_via_mailgun diff --git a/resources/Agencies.py b/resources/Agencies.py index 90555015..d56dfc32 100644 --- a/resources/Agencies.py +++ b/resources/Agencies.py @@ -76,7 +76,6 @@ def post(self, access_info: AccessInfoPrimary): return self.run_endpoint( wrapper_function=create_agency, schema_populate_parameters=SchemaConfigs.AGENCIES_POST.value.get_schema_populate_parameters(), - access_info=access_info, ) diff --git a/resources/Batch.py b/resources/Batch.py new file mode 100644 index 00000000..35deedd7 --- /dev/null +++ b/resources/Batch.py @@ -0,0 +1,105 @@ +from flask import Response + +from middleware.access_logic import ( + WRITE_ONLY_AUTH_INFO, + STANDARD_JWT_AUTH_INFO, + AccessInfoPrimary, +) +from middleware.decorators import endpoint_info +from middleware.primary_resource_logic.batch_logic import ( + batch_post_agencies, + batch_post_data_sources, + batch_put_agencies, + batch_put_data_sources, +) +from resources.PsycopgResource import PsycopgResource +from resources.endpoint_schema_config import SchemaConfigs +from resources.resource_helpers import ResponseInfo +from utilities.namespace import create_namespace, AppNamespaces + +namespace_batch = create_namespace(AppNamespaces.BATCH) + + +def add_csv_description(initial_description: str) -> str: + return ( + f"{initial_description}\n\n" + f"Note: Only file upload should be provided.\n" + f"The json arguments simply denote the columns of the csv" + ) + + +@namespace_batch.route("/agencies") +class AgenciesBatch(PsycopgResource): + + @endpoint_info( + namespace=namespace_batch, + auth_info=STANDARD_JWT_AUTH_INFO, + schema_config=SchemaConfigs.BATCH_AGENCIES_POST, + description=add_csv_description( + initial_description="Adds multiple agencies from a CSV file." + ), + response_info=ResponseInfo( + success_message="At least some resources created successfully." + ), + ) + def post(self, access_info: AccessInfoPrimary) -> Response: + return self.run_endpoint( + wrapper_function=batch_post_agencies, + schema_populate_parameters=SchemaConfigs.BATCH_AGENCIES_POST.value.get_schema_populate_parameters(), + ) + + @endpoint_info( + namespace=namespace_batch, + auth_info=WRITE_ONLY_AUTH_INFO, + schema_config=SchemaConfigs.BATCH_AGENCIES_PUT, + description=add_csv_description( + initial_description="Updates multiple agencies from a CSV file." + ), + response_info=ResponseInfo( + success_message="At least some resources updated successfully." + ), + ) + def put(self, access_info: AccessInfoPrimary): + return self.run_endpoint( + wrapper_function=batch_put_agencies, + schema_populate_parameters=SchemaConfigs.BATCH_AGENCIES_PUT.value.get_schema_populate_parameters(), + ) + + +@namespace_batch.route("/data-sources") +class DataSourcesBatch(PsycopgResource): + + @endpoint_info( + namespace=namespace_batch, + auth_info=STANDARD_JWT_AUTH_INFO, + schema_config=SchemaConfigs.BATCH_DATA_SOURCES_POST, + description=add_csv_description( + initial_description="Adds multiple data sources from a CSV file." 
+ ), + response_info=ResponseInfo( + success_message="At least some resources created successfully." + ), + ) + def post(self, access_info: AccessInfoPrimary): + return self.run_endpoint( + wrapper_function=batch_post_data_sources, + schema_populate_parameters=SchemaConfigs.BATCH_DATA_SOURCES_POST.value.get_schema_populate_parameters(), + ) + + @endpoint_info( + namespace=namespace_batch, + auth_info=WRITE_ONLY_AUTH_INFO, + schema_config=SchemaConfigs.BATCH_DATA_SOURCES_PUT, + description=add_csv_description( + initial_description="Updates multiple data sources from a CSV file." + ), + response_info=ResponseInfo( + success_message="At least some resources updated successfully." + ), + ) + def put(self, access_info: AccessInfoPrimary): + + return self.run_endpoint( + wrapper_function=batch_put_data_sources, + schema_populate_parameters=SchemaConfigs.BATCH_DATA_SOURCES_PUT.value.get_schema_populate_parameters(), + ) diff --git a/resources/CreateTestUserWithElevatedPermissions.py b/resources/CreateTestUserWithElevatedPermissions.py index d945286b..a1bc189b 100644 --- a/resources/CreateTestUserWithElevatedPermissions.py +++ b/resources/CreateTestUserWithElevatedPermissions.py @@ -17,7 +17,8 @@ from flask import request from flask_restx import fields, abort -from middleware.enums import PermissionsEnum +from middleware.access_logic import AccessInfoPrimary +from middleware.enums import PermissionsEnum, AccessTypeEnum from middleware.primary_resource_logic.api_key_logic import create_api_key_for_user from middleware.primary_resource_logic.user_queries import ( user_post_results, @@ -98,12 +99,16 @@ def post(self): PermissionsEnum.DB_WRITE, ]: db_client.add_user_permission( - user_email=auto_user_email, + user_id=db_client.get_user_id(email=auto_user_email), permission=permission, ) - api_key = create_api_key_for_user(db_client=db_client, dto=dto).json[ - "api_key" - ] + access_info = AccessInfoPrimary( + access_type=AccessTypeEnum.JWT, + user_email=auto_user_email, + ) + api_key = create_api_key_for_user( + db_client=db_client, access_info=access_info + ).json["api_key"] return { "email": auto_user_email, "password": auto_user_password, diff --git a/resources/Login.py b/resources/Login.py index a5cb9506..3a1c4799 100644 --- a/resources/Login.py +++ b/resources/Login.py @@ -6,10 +6,9 @@ from middleware.primary_resource_logic.login_queries import try_logging_in from resources.endpoint_schema_config import SchemaConfigs from resources.resource_helpers import ResponseInfo - from utilities.namespace import create_namespace -from resources.PsycopgResource import PsycopgResource, handle_exceptions +from resources.PsycopgResource import PsycopgResource namespace_login = create_namespace() diff --git a/resources/PsycopgResource.py b/resources/PsycopgResource.py index 0cc093e2..6ac49fc9 100644 --- a/resources/PsycopgResource.py +++ b/resources/PsycopgResource.py @@ -139,6 +139,7 @@ def run_endpoint( dto = populate_schema_with_request_content( schema=schema_populate_parameters.schema, dto_class=schema_populate_parameters.dto_class, + load_file=schema_populate_parameters.load_file, ) with self.setup_database_client() as db_client: response = wrapper_function(db_client, dto=dto, **wrapper_kwargs) diff --git a/resources/endpoint_schema_config.py b/resources/endpoint_schema_config.py index ac4c71f4..929cc2ef 100644 --- a/resources/endpoint_schema_config.py +++ b/resources/endpoint_schema_config.py @@ -12,6 +12,9 @@ PermissionsRequestDTO, PermissionsGetRequestSchema, ) +from 
middleware.schema_and_dto_logic.primary_resource_dtos.batch_dtos import ( + BatchRequestDTO, +) from middleware.schema_and_dto_logic.primary_resource_dtos.reset_token_dtos import ( ResetPasswordDTO, ) @@ -19,6 +22,15 @@ ArchivesGetResponseSchema, ArchivesPutRequestSchema, ) +from middleware.schema_and_dto_logic.primary_resource_schemas.batch_schemas import ( + BatchRequestSchema, + BatchPostResponseSchema, + BatchPutResponseSchema, + AgenciesPutBatchRequestSchema, + AgenciesPostBatchRequestSchema, + DataSourcesPostBatchRequestSchema, + DataSourcesPutBatchRequestSchema, +) from middleware.schema_and_dto_logic.primary_resource_schemas.reset_token_schemas import ( ResetPasswordSchema, ) @@ -89,6 +101,7 @@ TypeaheadQuerySchema, EmailOnlyDTO, EmailOnlySchema, + EntryCreateUpdateRequestDTO, ) from middleware.schema_and_dto_logic.custom_types import DTOTypes from middleware.schema_and_dto_logic.non_dto_dataclasses import SchemaPopulateParameters @@ -102,6 +115,7 @@ from middleware.schema_and_dto_logic.primary_resource_dtos.agencies_dtos import ( AgenciesPostDTO, RelatedAgencyByIDDTO, + AgenciesPutDTO, ) from middleware.schema_and_dto_logic.primary_resource_schemas.data_requests_advanced_schemas import ( GetManyDataRequestsResponseSchema, @@ -124,6 +138,7 @@ ) from middleware.schema_and_dto_logic.primary_resource_dtos.data_sources_dtos import ( DataSourcesPostDTO, + DataSourcesPutDTO, ) from middleware.schema_and_dto_logic.common_response_schemas import ( IDAndMessageSchema, @@ -179,8 +194,14 @@ def __init__( ) def get_schema_populate_parameters(self) -> SchemaPopulateParameters: + if "file" in self.input_schema.fields: + load_file = True + else: + load_file = False return SchemaPopulateParameters( - schema=self.input_schema, dto_class=self.input_dto_class + schema=self.input_schema, + dto_class=self.input_dto_class, + load_file=load_file, ) @@ -476,3 +497,27 @@ class SchemaConfigs(Enum): input_dto_class=PermissionsRequestDTO, ) # endregion + + # region Batch + BATCH_DATA_SOURCES_POST = EndpointSchemaConfig( + input_schema=DataSourcesPostBatchRequestSchema(), + input_dto_class=DataSourcesPostDTO, + primary_output_schema=BatchPostResponseSchema(), + ) + BATCH_DATA_SOURCES_PUT = EndpointSchemaConfig( + input_schema=DataSourcesPutBatchRequestSchema(), + input_dto_class=DataSourcesPutDTO, + primary_output_schema=BatchPutResponseSchema(), + ) + BATCH_AGENCIES_POST = EndpointSchemaConfig( + input_schema=AgenciesPostBatchRequestSchema(), + input_dto_class=AgenciesPostDTO, + primary_output_schema=BatchPostResponseSchema(), + ) + BATCH_AGENCIES_PUT = EndpointSchemaConfig( + input_schema=AgenciesPutBatchRequestSchema(), + input_dto_class=AgenciesPutDTO, + primary_output_schema=BatchPutResponseSchema(), + ) + + # endregion diff --git a/tests/helper_scripts/common_test_functions.py b/tests/helper_scripts/common_test_functions.py index d09c4fa6..fe018e66 100644 --- a/tests/helper_scripts/common_test_functions.py +++ b/tests/helper_scripts/common_test_functions.py @@ -47,6 +47,6 @@ def assert_contains_key_value_pairs( key_value_pairs: dict, ): for key, value in key_value_pairs.items(): - assert key in dict_to_check + assert key in dict_to_check, f"Expected {key} to be in {dict_to_check}" dict_value = dict_to_check[key] assert dict_value == value, f"Expected {key} to be {value}, was {dict_value}" diff --git a/tests/helper_scripts/helper_classes/MultipleTemporaryFiles.py b/tests/helper_scripts/helper_classes/MultipleTemporaryFiles.py new file mode 100644 index 00000000..61e65284 --- /dev/null +++ 
b/tests/helper_scripts/helper_classes/MultipleTemporaryFiles.py @@ -0,0 +1,35 @@ +import os +import tempfile + + +class MultipleTemporaryFiles: + def __init__(self, suffixes: list[str], mode="w+", encoding="utf-8", delete=False): + if not isinstance(suffixes, list) or len(suffixes) == 0: + raise ValueError("Suffixes must be a non-empty list.") + + self.suffixes = suffixes + self.mode = mode + self.encoding = encoding + self.delete = delete + self.temp_files = [] + + def __enter__(self): + # Create temporary files with unique suffixes + for suffix in self.suffixes: + temp_file = tempfile.NamedTemporaryFile( + mode=self.mode, + encoding=self.encoding, + suffix=suffix, + delete=self.delete, + ) + self.temp_files.append(temp_file) + return self.temp_files + + def __exit__(self, exc_type, exc_value, traceback): + for temp_file in self.temp_files: + try: + temp_file.close() + if self.delete and os.path.exists(temp_file.name): + os.unlink(temp_file.name) + except Exception as e: + print(f"Error cleaning up temporary file {temp_file.name}: {e}") diff --git a/tests/helper_scripts/helper_classes/RequestValidator.py b/tests/helper_scripts/helper_classes/RequestValidator.py index 739069ef..4c7682e1 100644 --- a/tests/helper_scripts/helper_classes/RequestValidator.py +++ b/tests/helper_scripts/helper_classes/RequestValidator.py @@ -2,7 +2,9 @@ Class based means to run and validate requests """ +from dataclasses import dataclass from http import HTTPStatus +from io import BytesIO from typing import Optional, Type, Union from flask.testing import FlaskClient @@ -459,3 +461,75 @@ def get_api_spec( return self.get( endpoint="/api/swagger.json", ) + + @dataclass + class BatchOperationParams: + file: BytesIO + headers: dict + expected_response_status: HTTPStatus = HTTPStatus.OK + + def insert_agencies_batch( + self, + bop: BatchOperationParams, + expected_schema=SchemaConfigs.BATCH_AGENCIES_POST.value.primary_output_schema, + ): + return self.post( + endpoint="/api/batch/agencies", + headers=bop.headers, + file=bop.file, + expected_schema=expected_schema, + expected_response_status=bop.expected_response_status, + ) + + def update_agencies_batch( + self, + bop: BatchOperationParams, + expected_schema=SchemaConfigs.BATCH_AGENCIES_PUT.value.primary_output_schema, + ): + return self.put( + endpoint="/api/batch/agencies", + headers=bop.headers, + file=bop.file, + expected_schema=expected_schema, + expected_response_status=bop.expected_response_status, + ) + + def insert_data_sources_batch( + self, + bop: BatchOperationParams, + expected_schema=SchemaConfigs.BATCH_DATA_SOURCES_POST.value.primary_output_schema, + ): + return self.post( + endpoint="/api/batch/data-sources", + headers=bop.headers, + file=bop.file, + expected_schema=expected_schema, + expected_response_status=bop.expected_response_status, + ) + + def update_data_sources_batch( + self, + bop: BatchOperationParams, + expected_schema=SchemaConfigs.BATCH_DATA_SOURCES_PUT.value.primary_output_schema, + ): + return self.put( + endpoint="/api/batch/data-sources", + headers=bop.headers, + file=bop.file, + expected_schema=expected_schema, + expected_response_status=bop.expected_response_status, + ) + + def get_agency_by_id(self, headers: dict, id: int): + return self.get( + endpoint=f"/api/agencies/{id}", + headers=headers, + expected_schema=SchemaConfigs.AGENCIES_BY_ID_GET.value.primary_output_schema, + ) + + def get_data_source_by_id(self, headers: dict, id: int): + return self.get( + endpoint=f"/api/data-sources/{id}", + headers=headers, + 
expected_schema=SchemaConfigs.DATA_SOURCES_GET_BY_ID.value.primary_output_schema, + ) diff --git a/tests/helper_scripts/run_and_validate_request.py b/tests/helper_scripts/run_and_validate_request.py index 54573882..d75d6a04 100644 --- a/tests/helper_scripts/run_and_validate_request.py +++ b/tests/helper_scripts/run_and_validate_request.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Optional, Type, Union, Literal +from typing import Optional, Type, Union, Literal, TextIO from flask.testing import FlaskClient from marshmallow import Schema @@ -18,6 +18,7 @@ def run_and_validate_request( expected_json_content: Optional[dict] = None, expected_schema: Optional[Union[Type[Schema], Schema]] = None, query_parameters: Optional[dict] = None, + file: Optional[TextIO] = None, **request_kwargs, ): """ @@ -35,7 +36,21 @@ def run_and_validate_request( if query_parameters is not None: endpoint = add_query_params(endpoint, query_parameters) - response = flask_client.open(endpoint, method=http_method, **request_kwargs) + if file is not None: + # Check file is open + if file.closed is True: + # Open file + file = open(file.name, "rb") + assert file.closed is False, "File must be opened before sending request." + response = flask_client.open( + endpoint, + method=http_method, + data={"file": file}, + content_type="multipart/form-data", + **request_kwargs, + ) + else: + response = flask_client.open(endpoint, method=http_method, **request_kwargs) check_response_status(response, expected_response_status.value) # All of our requests should return some json message providing information. diff --git a/tests/integration/test_batch.py b/tests/integration/test_batch.py new file mode 100644 index 00000000..46b05ad1 --- /dev/null +++ b/tests/integration/test_batch.py @@ -0,0 +1,471 @@ +import os +from dataclasses import dataclass +from enum import Enum +from http import HTTPStatus +from typing import TextIO, Optional, Annotated, Callable + +import pytest + +from marshmallow import Schema +from csv import DictWriter +import tempfile + +from conftest import test_data_creator_flask, monkeysession +from database_client.enums import LocationType +from middleware.primary_resource_logic.batch_logic import listify_strings +from middleware.schema_and_dto_logic.common_response_schemas import MessageSchema +from middleware.schema_and_dto_logic.dynamic_logic.dynamic_csv_to_schema_conversion_logic import ( + SchemaUnflattener, +) +from middleware.schema_and_dto_logic.primary_resource_schemas.batch_schemas import ( + AgenciesPostRequestFlatBaseSchema, + DataSourcesPostRequestFlatBaseSchema, + AgenciesPutRequestFlatBaseSchema, + DataSourcesPutRequestFlatBaseSchema, + AgenciesPutBatchRequestSchema, + DataSourcesPutBatchRequestSchema, +) +from tests.helper_scripts.common_test_data import get_test_name +from tests.helper_scripts.common_test_functions import assert_contains_key_value_pairs +from tests.helper_scripts.helper_classes.RequestValidator import RequestValidator +from tests.helper_scripts.helper_classes.SchemaTestDataGenerator import ( + generate_test_data_from_schema, +) +from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( + TestDataCreatorFlask, +) + +SUFFIX_ARRAY = [".csv", ".csv", ".inv", ".csv"] + + +# Helper scripts +# +class ResourceType(Enum): + AGENCY = "agencies" + SOURCE = "data_sources" + REQUEST = "data_requests" + + +def get_endpoint_nomenclature( + resource_type: ResourceType, +) -> str: + """ + Get a string version of the resource, appropriate for endpoints. 
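+
+    Illustrative doctest (underscores in the enum value become hyphens):
+
+    >>> get_endpoint_nomenclature(ResourceType.SOURCE)
+    'data-sources'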
+ """ + return resource_type.value.replace("_", "-") + + +class TestCSVCreator: + + def __init__(self, schema: Schema): + self.fields = schema.fields + + def create_csv(self, file: TextIO, rows: list[dict]): + header = list(self.fields.keys()) + writer = DictWriter(file, fieldnames=header) + # Write header row using header + writer.writerow({field: field for field in header}) + + for row in rows: + writer.writerow(row) + file.close() + + +def stringify_list_of_ints(l: list): + for i in range(len(l)): + l[i] = str(l[i]) + return l + + +def stringify_lists(d: dict): + for k, v in d.items(): + if isinstance(v, dict): + stringify_lists(v) + if isinstance(v, list): + v = stringify_list_of_ints(v) + d[k] = ",".join(v) + return d + + +@dataclass +class BatchTestRunner: + tdc: TestDataCreatorFlask + csv_creator: TestCSVCreator + schema: Optional[Annotated[Schema, "The schema to use in the test"]] = None + + def generate_test_data(self, override: Optional[dict] = None): + test_data = generate_test_data_from_schema( + schema=self.schema, override=override + ) + stringify_lists(d=test_data) + return test_data + + +@pytest.fixture +def runner( + test_data_creator_flask: TestDataCreatorFlask, +): + return BatchTestRunner( + tdc=test_data_creator_flask, + csv_creator=TestCSVCreator(AgenciesPostRequestFlatBaseSchema()), + ) + + +@pytest.fixture +def agencies_post_runner( + test_data_creator_flask: TestDataCreatorFlask, +): + return BatchTestRunner( + tdc=test_data_creator_flask, + csv_creator=TestCSVCreator(AgenciesPostRequestFlatBaseSchema()), + schema=AgenciesPostRequestFlatBaseSchema(), + ) + + +@pytest.fixture +def data_sources_post_runner(test_data_creator_flask: TestDataCreatorFlask): + return BatchTestRunner( + tdc=test_data_creator_flask, + csv_creator=TestCSVCreator(DataSourcesPostRequestFlatBaseSchema()), + schema=DataSourcesPostRequestFlatBaseSchema(), + ) + + +@pytest.fixture +def agencies_put_runner(test_data_creator_flask): + return BatchTestRunner( + tdc=test_data_creator_flask, + csv_creator=TestCSVCreator(AgenciesPutBatchRequestSchema(exclude=["file"])), + schema=AgenciesPutRequestFlatBaseSchema(), + ) + + +@pytest.fixture +def data_sources_put_runner(runner: BatchTestRunner): + return BatchTestRunner( + tdc=runner.tdc, + csv_creator=TestCSVCreator(DataSourcesPutBatchRequestSchema(exclude=["file"])), + schema=DataSourcesPutRequestFlatBaseSchema(), + ) + + +class SimpleTempFile: + + def __init__(self, suffix: str = ".csv"): + self.suffix = suffix + self.temp_file = None + + def __enter__(self): + self.temp_file = tempfile.NamedTemporaryFile( + mode="w+", encoding="utf-8", suffix=self.suffix, delete=False + ) + return self.temp_file + + def __exit__(self, exc_type, exc_value, traceback): + try: + self.temp_file.close() + os.unlink(self.temp_file.name) + except Exception as e: + print(f"Error cleaning up temporary file {self.temp_file.name}: {e}") + + +def generate_agencies_locality_data(): + locality_name = get_test_name() + return { + "location_type": LocationType.LOCALITY.value, + "locality_name": locality_name, + "county_fips": "42003", + "state_iso": "PA", + } + + +def create_csv_and_run( + runner: BatchTestRunner, + rows: list[dict], + request_validator_method: Callable, + suffix: str = ".csv", + expected_response_status: HTTPStatus = HTTPStatus.OK, + expected_schema: Optional[Schema] = None, +): + if expected_schema is not None: + kwargs = { + "expected_schema": expected_schema, + } + else: + kwargs = {} + + with SimpleTempFile(suffix=suffix) as temp_file: + 
runner.csv_creator.create_csv(file=temp_file, rows=rows)
+        return request_validator_method(
+            bop=RequestValidator.BatchOperationParams(
+                file=temp_file,
+                headers=runner.tdc.get_admin_tus().jwt_authorization_header,
+                expected_response_status=expected_response_status,
+            ),
+            **kwargs,
+        )
+
+
+def check_for_errors(data: dict, check_ids: bool = True):
+    if check_ids:
+        ids = data["ids"]
+        assert len(ids) == 1
+    assert len(data["errors"]) == 2, f"Incorrect number of errors: {data['errors']}"
+    # Rows 0 and 2 carry the invalid values in each test, so both must be
+    # reported as error keys.
+    assert "0" in data["errors"] and "2" in data["errors"]
+
+
+def test_batch_agencies_insert_happy_path(
+    agencies_post_runner: BatchTestRunner,
+):
+    runner = agencies_post_runner
+
+    locality_info = generate_agencies_locality_data()
+    rows = [runner.generate_test_data(override=locality_info) for _ in range(3)]
+    data = create_csv_and_run(
+        runner=runner,
+        rows=rows,
+        request_validator_method=runner.tdc.request_validator.insert_agencies_batch,
+    )
+
+    ids = data["ids"]
+
+    unflattener = SchemaUnflattener(flat_schema_class=AgenciesPostRequestFlatBaseSchema)
+
+    for row, id in zip(rows, ids):
+        unflattened_row = unflattener.unflatten(flat_data=row)
+        data = runner.tdc.request_validator.get_agency_by_id(
+            id=id, headers=runner.tdc.get_admin_tus().jwt_authorization_header
+        )
+        assert_contains_key_value_pairs(
+            dict_to_check=data["data"], key_value_pairs=unflattened_row["agency_info"]
+        )
+        assert data["data"]["locality_name"] == row["locality_name"]
+
+
+def test_batch_agencies_insert_some_errors(
+    agencies_post_runner: BatchTestRunner,
+):
+    runner = agencies_post_runner
+
+    rows = [
+        runner.generate_test_data(override={"lat": "not a number"}),
+        runner.generate_test_data(override=generate_agencies_locality_data()),
+        runner.generate_test_data(override={"location_type": ""}),
+    ]
+    data = create_csv_and_run(
+        runner=runner,
+        rows=rows,
+        request_validator_method=runner.tdc.request_validator.insert_agencies_batch,
+    )
+
+    check_for_errors(data)
+
+
+def test_batch_agencies_insert_wrong_file_type(
+    agencies_post_runner: BatchTestRunner,
+):
+    runner = agencies_post_runner
+    create_csv_and_run(
+        runner=agencies_post_runner,
+        rows=[],
+        suffix=".json",
+        request_validator_method=runner.tdc.request_validator.insert_agencies_batch,
+        expected_response_status=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+        expected_schema=MessageSchema(),
+    )
+
+
+def test_batch_agencies_update_happy_path(
+    agencies_put_runner: BatchTestRunner,
+):
+    runner = agencies_put_runner
+    agencies = [runner.tdc.agency() for _ in range(3)]
+    locality_info = generate_agencies_locality_data()
+    rows = [
+        runner.generate_test_data(override={**locality_info, "id": agencies[i].id})
+        for i in range(3)
+    ]
+    data = create_csv_and_run(
+        runner=runner,
+        rows=rows,
+        request_validator_method=runner.tdc.request_validator.update_agencies_batch,
+    )
+
+    ids = [agencies[i].id for i in range(3)]
+
+    unflattener = SchemaUnflattener(flat_schema_class=AgenciesPutRequestFlatBaseSchema)
+
+    for row, id in zip(rows, ids):
+        unflattened_row = unflattener.unflatten(flat_data=row)
+        data = runner.tdc.request_validator.get_agency_by_id(
+            id=id, headers=runner.tdc.get_admin_tus().jwt_authorization_header
+        )
+        assert_contains_key_value_pairs(
+            dict_to_check=data["data"], key_value_pairs=unflattened_row["agency_info"]
+        )
+        assert data["data"]["locality_name"] == row["locality_name"]
+
+
+def test_batch_agencies_update_some_errors(
+    agencies_put_runner: BatchTestRunner,
+):
+    runner = agencies_put_runner
+    locality_info = generate_agencies_locality_data()
+    agencies 
= [runner.tdc.agency() for _ in range(3)] + rows = [ + runner.generate_test_data( + override={"lat": "not a number", "id": agencies[0].id} + ), + runner.generate_test_data(override={**locality_info, "id": agencies[1].id}), + runner.generate_test_data(override={"location_type": "", "id": agencies[2].id}), + ] + data = create_csv_and_run( + runner=runner, + rows=rows, + request_validator_method=runner.tdc.request_validator.update_agencies_batch, + ) + check_for_errors(data, check_ids=False) + + +def test_batch_agencies_update_wrong_file_type( + agencies_put_runner: BatchTestRunner, +): + runner = agencies_put_runner + create_csv_and_run( + runner=agencies_put_runner, + rows=[], + suffix=".json", + request_validator_method=runner.tdc.request_validator.update_agencies_batch, + expected_response_status=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + expected_schema=MessageSchema(), + ) + + +def test_batch_data_sources_insert_happy_path( + data_sources_post_runner: BatchTestRunner, +): + runner = data_sources_post_runner + rows = [ + runner.generate_test_data(override={"linked_agency_ids": [1, 2, 3]}) + for _ in range(3) + ] + data = create_csv_and_run( + runner=runner, + rows=rows, + request_validator_method=runner.tdc.request_validator.insert_data_sources_batch, + ) + ids = data["ids"] + unflattener = SchemaUnflattener( + flat_schema_class=DataSourcesPostRequestFlatBaseSchema + ) + + for row, id in zip(rows, ids): + unflattened_row = unflattener.unflatten(flat_data=row) + data = runner.tdc.request_validator.get_data_source_by_id( + id=id, headers=runner.tdc.get_admin_tus().jwt_authorization_header + ) + listify_strings([unflattened_row["entry_data"]]) + assert_contains_key_value_pairs( + dict_to_check=data["data"], + key_value_pairs=unflattened_row["entry_data"], + ) + + +def test_batch_data_sources_insert_some_errors( + data_sources_post_runner: BatchTestRunner, +): + runner = data_sources_post_runner + rows = [ + runner.generate_test_data(override={"linked_agency_ids": "not a list"}), + runner.generate_test_data(override={"linked_agency_ids": [1, 2, 3]}), + runner.generate_test_data(override={"coverage_start": "not a date"}), + ] + data = create_csv_and_run( + runner=runner, + rows=rows, + request_validator_method=runner.tdc.request_validator.insert_data_sources_batch, + ) + check_for_errors(data) + + +def test_batch_data_sources_insert_wrong_file_type( + data_sources_post_runner: BatchTestRunner, +): + runner = data_sources_post_runner + create_csv_and_run( + runner=data_sources_post_runner, + rows=[], + suffix=".json", + request_validator_method=runner.tdc.request_validator.insert_data_sources_batch, + expected_response_status=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + expected_schema=MessageSchema(), + ) + + +def test_batch_data_sources_update_happy_path( + data_sources_put_runner: BatchTestRunner, +): + runner = data_sources_put_runner + data_sources = [runner.tdc.data_source() for i in range(3)] + rows = [ + runner.generate_test_data(override={"id": data_sources[i].id}) for i in range(3) + ] + data = create_csv_and_run( + runner=runner, + rows=rows, + request_validator_method=runner.tdc.request_validator.update_data_sources_batch, + ) + + ids = [data_source.id for data_source in data_sources] + + unflattener = SchemaUnflattener( + flat_schema_class=DataSourcesPutRequestFlatBaseSchema + ) + for row, id in zip(rows, ids): + unflattened_row = unflattener.unflatten(flat_data=row) + data = runner.tdc.request_validator.get_data_source_by_id( + id=id, headers=runner.tdc.get_admin_tus().jwt_authorization_header 
+ ) + listify_strings([unflattened_row["entry_data"]]) + assert_contains_key_value_pairs( + dict_to_check=data["data"], + key_value_pairs=unflattened_row["entry_data"], + ) + + +def test_batch_data_sources_update_some_errors( + data_sources_put_runner: BatchTestRunner, +): + runner = data_sources_put_runner + + data_sources = [runner.tdc.data_source() for i in range(3)] + rows = [ + runner.generate_test_data( + override={"id": data_sources[0].id, "access_types": "A String"} + ), + runner.generate_test_data(override={"id": data_sources[1].id}), + runner.generate_test_data( + override={"id": data_sources[2].id, "record_type_name": "Not a record type"} + ), + ] + + data = create_csv_and_run( + runner=runner, + rows=rows, + request_validator_method=runner.tdc.request_validator.update_data_sources_batch, + ) + check_for_errors(data, check_ids=False) + + +def test_batch_data_sources_update_wrong_file_type( + data_sources_put_runner: BatchTestRunner, +): + runner = data_sources_put_runner + create_csv_and_run( + runner=data_sources_put_runner, + rows=[], + suffix=".json", + request_validator_method=runner.tdc.request_validator.update_data_sources_batch, + expected_response_status=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + expected_schema=MessageSchema(), + ) diff --git a/tests/integration/test_reset_password.py b/tests/integration/test_reset_password.py index c9c903d2..a1cd5543 100644 --- a/tests/integration/test_reset_password.py +++ b/tests/integration/test_reset_password.py @@ -55,7 +55,7 @@ def login(password: str, expected_response_status: HTTPStatus = HTTPStatus.OK): tdc.request_validator.get( endpoint=DATA_SOURCES_BASE_ENDPOINT, headers=get_authorization_header(scheme="Bearer", token=token), - expected_response_status=HTTPStatus.UNAUTHORIZED, + expected_response_status=HTTPStatus.BAD_REQUEST, ) new_password = str(uuid.uuid4()) diff --git a/tests/integration/test_signup.py b/tests/integration/test_signup.py index 4e00ed89..8760420b 100644 --- a/tests/integration/test_signup.py +++ b/tests/integration/test_signup.py @@ -1,8 +1,6 @@ from datetime import datetime, timezone, timedelta from http import HTTPStatus -import pytest - from conftest import test_data_creator_flask, monkeysession from middleware.schema_and_dto_logic.common_response_schemas import MessageSchema from tests.helper_scripts.common_test_data import get_test_email diff --git a/utilities/namespace.py b/utilities/namespace.py index 45ced1fb..f6eba507 100644 --- a/utilities/namespace.py +++ b/utilities/namespace.py @@ -25,6 +25,7 @@ class AppNamespaces(Enum): path="notifications", description="Notifications Namespace" ) MAP = NamespaceAttributes(path="map", description="Map Namespace") + BATCH = NamespaceAttributes(path="batch", description="Batch Namespace") def create_namespace(