diff --git a/api/src/data_migration/transformation/subtask/transform_agency.py b/api/src/data_migration/transformation/subtask/transform_agency.py index 921d90169..d5dd11170 100644 --- a/api/src/data_migration/transformation/subtask/transform_agency.py +++ b/api/src/data_migration/transformation/subtask/transform_agency.py @@ -272,6 +272,25 @@ def update_agency_download_file_types( ) +class TransformAgencyHierarchy(AbstractTransformSubTask): + def __init__(self, task: Task): + super().__init__(task) + + def transform_records(self) -> None: + agencies = self.db_session.scalars(select(Agency)).all() + agency_map = {agency.agency_code: agency for agency in agencies} + + for agency in agencies: + top_level_agency_code = self.get_top_level_agency_code(agency.agency_code) + if top_level_agency_code and top_level_agency_code in agency_map: + agency.top_level_agency = agency_map[top_level_agency_code] + + def get_top_level_agency_code(self, agency_code: str) -> str | None: + if "-" not in agency_code: + return None + return agency_code.split("-")[0] + + ############################ # Transformation / utility functions ############################ diff --git a/api/src/data_migration/transformation/transform_oracle_data_task.py b/api/src/data_migration/transformation/transform_oracle_data_task.py index b7ce8e0fd..93b9bf8b4 100644 --- a/api/src/data_migration/transformation/transform_oracle_data_task.py +++ b/api/src/data_migration/transformation/transform_oracle_data_task.py @@ -5,7 +5,10 @@ import src.data_migration.transformation.transform_constants as transform_constants from src.adapters import db -from src.data_migration.transformation.subtask.transform_agency import TransformAgency +from src.data_migration.transformation.subtask.transform_agency import ( + TransformAgency, + TransformAgencyHierarchy, +) from src.data_migration.transformation.subtask.transform_applicant_type import ( TransformApplicantType, ) @@ -81,3 +84,4 @@ def run_task(self) -> None: if self.transform_config.enable_agency: TransformAgency(self).run() + TransformAgencyHierarchy(self).run() diff --git a/api/tests/src/data_migration/transformation/subtask/test_transform_agency.py b/api/tests/src/data_migration/transformation/subtask/test_transform_agency.py index 4c87247c7..cfa3c6e05 100644 --- a/api/tests/src/data_migration/transformation/subtask/test_transform_agency.py +++ b/api/tests/src/data_migration/transformation/subtask/test_transform_agency.py @@ -10,10 +10,12 @@ from src.data_migration.transformation.subtask.transform_agency import ( TgroupAgency, TransformAgency, + TransformAgencyHierarchy, apply_updates, transform_agency_download_file_types, transform_agency_notify, ) +from src.db.models.agency_models import Agency from tests.src.data_migration.transformation.conftest import ( BaseTransformTestClass, setup_agency, @@ -22,6 +24,43 @@ from tests.src.db.models.factories import AgencyFactory +class TestTransformAgencyHierarchy(BaseTransformTestClass): + @pytest.fixture() + def transform_agency_hierarchy(self, transform_oracle_data_task): + return TransformAgencyHierarchy(transform_oracle_data_task) + + def test_transform_records(self, db_session, transform_agency_hierarchy): + # Create agencies with varying top-level agency codes + [ + AgencyFactory.create(agency_code="DHS"), + AgencyFactory.create(agency_code="DHS-ICE"), + AgencyFactory.create(agency_code="DHS--ICE"), + AgencyFactory.create(agency_code="DHS-ICE-123"), + AgencyFactory.create(agency_code="ABC-ICE"), + ] + + # Run the transformation + transform_agency_hierarchy.transform_records() + + # Fetch the agencies again to verify the changes + agency1 = db_session.query(Agency).filter(Agency.agency_code == "DHS").one_or_none() + agency2 = db_session.query(Agency).filter(Agency.agency_code == "DHS-ICE").one_or_none() + agency3 = db_session.query(Agency).filter(Agency.agency_code == "DHS-ICE-123").one_or_none() + agency4 = db_session.query(Agency).filter(Agency.agency_code == "ABC-ICE").one_or_none() + agency5 = db_session.query(Agency).filter(Agency.agency_code == "DHS--ICE").one_or_none() + + # Verify that the top-level agencies are set correctly + assert agency1.top_level_agency_id is None + assert agency2.top_level_agency_id == agency1.agency_id + assert agency3.top_level_agency_id == agency1.agency_id + assert agency4.top_level_agency_id is None + assert agency5.top_level_agency_id == agency1.agency_id + + def test_get_top_level_agency_code(self, transform_agency_hierarchy): + assert transform_agency_hierarchy.get_top_level_agency_code("DHS-ICE") == "DHS" + assert transform_agency_hierarchy.get_top_level_agency_code("DHS") is None + + class TestTransformAgency(BaseTransformTestClass): @pytest.fixture() def transform_agency(self, transform_oracle_data_task):