Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable user accounts script #2898

Merged
7 changes: 7 additions & 0 deletions backend/data_tools/data/user_data.json5
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,13 @@
oidc_id: "00000000-0000-1111-a111-000000000022",
groups: [],
status: "ACTIVE"
},
{ // 526
first_name: "System",
last_name: "Admin",
email: "[email protected]",
oidc_id: "00000000-0000-1111-a111-000000000026",
status: "LOCKED"
}
],
notification: [
Expand Down
Empty file.
124 changes: 124 additions & 0 deletions backend/data_tools/src/disable_users/disable_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import logging
import os

from data_tools.src.disable_users.queries import (
ALL_ACTIVE_USER_SESSIONS_QUERY,
EXCLUDED_USER_OIDC_IDS,
GET_USER_ID_BY_OIDC_QUERY,
INACTIVE_USER_QUERY,
SYSTEM_ADMIN_EMAIL,
SYSTEM_ADMIN_OIDC_ID,
)
from data_tools.src.import_static_data.import_data import get_config, init_db
from sqlalchemy import text
from sqlalchemy.orm import Mapper, Session

from models import * # noqa: F403, F401

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_ids_from_oidc_ids(se, oidc_ids: list):
"""Retrieve user IDs corresponding to a list of OIDC IDs."""
if not all(isinstance(oidc_id, str) for oidc_id in oidc_ids):
raise ValueError("All oidc_ids must be strings.")

ids = []
for oidc_id in oidc_ids:
user_id = se.execute(text(GET_USER_ID_BY_OIDC_QUERY), {"oidc_id": oidc_id}).scalar()

if user_id is not None:
ids.append(user_id)

return ids


def create_system_admin(se):
"""Create system user if it doesn't exist."""
system_admin = se.execute(
text(GET_USER_ID_BY_OIDC_QUERY),
{"oidc_id": SYSTEM_ADMIN_OIDC_ID}
).fetchone()

if system_admin is None:
sys_user = User(
email=SYSTEM_ADMIN_EMAIL,
oidc_id=SYSTEM_ADMIN_OIDC_ID,
status=UserStatus.LOCKED
)
se.add(sys_user)
se.commit()
return sys_user.id

return system_admin[0]


def disable_user(se, user_id, system_admin_id):
"""Deactivate a single user and log the change."""
updated_user = User(id=user_id, status=UserStatus.INACTIVE, updated_by=system_admin_id)
se.merge(updated_user)

db_audit = build_audit(updated_user, OpsDBHistoryType.UPDATED)
ops_db_history = OpsDBHistory(
event_type=OpsDBHistoryType.UPDATED,
created_by=system_admin_id,
class_name=updated_user.__class__.__name__,
row_key=db_audit.row_key,
changes=db_audit.changes,
)
se.add(ops_db_history)

ops_event = OpsEvent(
event_type=OpsEventType.UPDATE_USER,
event_status=OpsEventStatus.SUCCESS,
created_by=system_admin_id,
)
se.add(ops_event)

all_user_sessions = se.execute(text(ALL_ACTIVE_USER_SESSIONS_QUERY), {"user_id": user_id})
for session in all_user_sessions:
updated_user_session = UserSession(
id=session[0],
is_active=False,
updated_by=system_admin_id
)
se.merge(updated_user_session)


def update_disabled_users_status(conn: sqlalchemy.engine.Engine):
"""Update the status of disabled users in the database."""
with Session(conn) as se:
logger.info("Checking for System User.")
system_admin_id = create_system_admin(se)

logger.info("Fetching inactive users.")
results = se.execute(text(INACTIVE_USER_QUERY)).scalars().all()
excluded_ids = get_ids_from_oidc_ids(se, EXCLUDED_USER_OIDC_IDS)
user_ids = [uid for uid in results if uid not in excluded_ids]

if not user_ids:
logger.info("No inactive users found.")
return

logger.info("Inactive users found:", user_ids)

for user_id in user_ids:
logger.info("Deactivating user", user_id)
disable_user(se, user_id, system_admin_id)

se.commit()


if __name__ == "__main__":
logger.info("Starting Disable Inactive Users process.")

script_env = os.getenv("ENV")
script_config = get_config(script_env)
db_engine, db_metadata_obj = init_db(script_config)

event.listen(Mapper, "after_configured", setup_schema(BaseModel))

update_disabled_users_status(db_engine)

logger.info("Disable Inactive Users process complete.")
31 changes: 31 additions & 0 deletions backend/data_tools/src/disable_users/queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
SYSTEM_ADMIN_OIDC_ID = "00000000-0000-1111-a111-000000000026"
SYSTEM_ADMIN_EMAIL = "[email protected]"

EXCLUDED_USER_OIDC_IDS = [
"00000000-0000-1111-a111-000000000018", # Admin Demo
"00000000-0000-1111-a111-000000000019", # User Demo
"00000000-0000-1111-a111-000000000020", # Director Dave
"00000000-0000-1111-a111-000000000021", # Budget Team
"00000000-0000-1111-a111-000000000022", # Director Derrek
SYSTEM_ADMIN_OIDC_ID # System Admin
]

INACTIVE_USER_QUERY = (
"SELECT id "
"FROM ops_user "
"WHERE id IN ( "
" SELECT ou.id "
" FROM user_session JOIN ops_user ou ON user_session.user_id = ou.id "
" WHERE ou.status = 'ACTIVE' "
" AND user_session.last_active_at < CURRENT_TIMESTAMP - INTERVAL '60 days'"
");"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there it is!


ALL_ACTIVE_USER_SESSIONS_QUERY = (
"SELECT * "
"FROM user_session "
"WHERE user_id = :user_id AND is_active = TRUE "
"ORDER BY created_on DESC"
)

GET_USER_ID_BY_OIDC_QUERY = "SELECT id FROM ops_user WHERE oidc_id = :oidc_id"
Empty file.
108 changes: 108 additions & 0 deletions backend/data_tools/tests/disable_users/test_disable_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from unittest.mock import MagicMock, patch

import pytest
from data_tools.src.disable_users.disable_users import (
create_system_admin,
disable_user,
get_ids_from_oidc_ids,
update_disabled_users_status,
)
from data_tools.src.disable_users.queries import SYSTEM_ADMIN_EMAIL, SYSTEM_ADMIN_OIDC_ID

from models import OpsDBHistoryType, OpsEventStatus, OpsEventType, UserStatus

system_admin_id = 111

@pytest.fixture
def mock_session():
"""Fixture for creating a mock SQLAlchemy session."""
session = MagicMock()
session.execute.return_value.fetchone.return_value = None
return session

def test_create_system_admin(mock_session):
create_system_admin(mock_session)

se_add = mock_session.add.call_args[0][0]
mock_session.execute.assert_called_once()
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
assert se_add.email == SYSTEM_ADMIN_EMAIL
assert se_add.oidc_id == SYSTEM_ADMIN_OIDC_ID
assert se_add.first_name is None
assert se_add.last_name is None

def test_return_existing_system_admin(mock_session):
mock_session.execute.return_value.fetchone.return_value = (system_admin_id,)

result = create_system_admin(mock_session)

assert result == system_admin_id
mock_session.add.assert_not_called()
mock_session.commit.assert_not_called()

def test_deactivate_user(mock_session):
user_id = 1
db_history_changes = {
"id": {"new": user_id, "old": None},
"status": {"new": "INACTIVE", "old": None},
"updated_by": {"new": system_admin_id, "old": None}
}

mock_session.execute.return_value = [(1,), (2,)]

disable_user(mock_session, user_id, system_admin_id)

assert mock_session.merge.call_count == 3
assert mock_session.add.call_count == 2

user_call = mock_session.merge.call_args_list[0]
assert user_call[0][0].id == user_id
assert user_call[0][0].status == UserStatus.INACTIVE
assert user_call[0][0].updated_by == system_admin_id

user_session_call_1 = mock_session.merge.call_args_list[1]
assert user_session_call_1[0][0].id == user_id
assert user_session_call_1[0][0].is_active is False
assert user_session_call_1[0][0].updated_by == system_admin_id

ops_db_history_call = mock_session.add.call_args_list[0]
assert ops_db_history_call[0][0].event_type == OpsDBHistoryType.UPDATED
assert ops_db_history_call[0][0].created_by == system_admin_id
assert ops_db_history_call[0][0].class_name == 'User'
assert ops_db_history_call[0][0].row_key == str(user_id)
assert ops_db_history_call[0][0].changes == db_history_changes

ops_events_call = mock_session.add.call_args_list[1]
assert ops_events_call[0][0].event_type == OpsEventType.UPDATE_USER
assert ops_events_call[0][0].event_status == OpsEventStatus.SUCCESS
assert ops_events_call[0][0].created_by == system_admin_id

@patch("data_tools.src.disable_users.disable_users.logger")
def test_no_inactive_users(mock_logger, mock_session):
mock_session.execute.return_value.all.return_value = None
update_disabled_users_status(mock_session)

mock_logger.info.assert_any_call("Checking for System User.")
mock_logger.info.assert_any_call("Fetching inactive users.")
mock_logger.info.assert_any_call("No inactive users found.")

def test_valid_oidc_ids(mock_session):
mock_session.execute.return_value.scalar.side_effect = [1, 2, None] # Mock responses for OIDC IDs

oidc_ids = ["oidc_1", "oidc_2", "oidc_3"]
expected_ids = [1, 2]

result = get_ids_from_oidc_ids(mock_session, oidc_ids)
assert result == expected_ids

empty_result = get_ids_from_oidc_ids(mock_session, [])
assert empty_result == []

def test_invalid_oidc_id_type(mock_session):
oidc_ids = ["valid_oidc", 123, "another_valid_oidc"]

with pytest.raises(ValueError) as context:
get_ids_from_oidc_ids(mock_session, oidc_ids)

assert str(context.value) == "All oidc_ids must be strings."
84 changes: 83 additions & 1 deletion backend/models/history.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from collections import namedtuple
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from types import NoneType

from sqlalchemy import Column, Index, String, desc
from sqlalchemy import Column, Index, String, desc, inspect
from sqlalchemy.dialects.postgresql import ENUM, JSONB
from sqlalchemy.orm.attributes import get_history

from .base import BaseModel

DbRecordAudit = namedtuple("DbRecordAudit", "row_key changes")

class OpsDBHistoryType(Enum):
NEW = 1
Expand All @@ -30,3 +36,79 @@ class OpsDBHistory(BaseModel):
OpsDBHistory.row_key,
desc(OpsDBHistory.created_on),
)


def build_audit(obj, event_type: OpsDBHistoryType) -> DbRecordAudit: # noqa: C901
row_key = "|".join([str(getattr(obj, pk.name)) for pk in inspect(obj.__table__).primary_key.columns.values()])

changes = {}

mapper = obj.__mapper__

# collect changes in column values
auditable_columns = list(filter(lambda c: c.key in obj.__dict__, mapper.columns))
for col in auditable_columns:
key = col.key
hist = get_history(obj, key)
if hist.has_changes():
# this assumes columns are primitives, not lists
old_val = convert_for_jsonb(hist.deleted[0]) if hist.deleted else None
new_val = convert_for_jsonb(hist.added[0]) if hist.added else None
# exclude Enums that didn't really change
if hist.deleted and isinstance(hist.deleted[0], Enum) and old_val == new_val:
continue
if event_type == OpsDBHistoryType.NEW:
if new_val:
changes[key] = {
"new": new_val,
}
else:
changes[key] = {
"new": new_val,
"old": old_val,
}

# collect changes in relationships, such as agreement.team_members
# limit this to relationships that aren't being logged as their own Classes
# and only include them on the editable side
auditable_relationships = list(
filter(
lambda rel: rel.secondary is not None and not rel.viewonly,
mapper.relationships,
)
)

for relationship in auditable_relationships:
key = relationship.key
hist = get_history(obj, key)
if hist.has_changes():
related_class_name = (
relationship.argument if isinstance(relationship.argument, str) else relationship.argument.__name__
)
changes[key] = {
"collection_of": related_class_name,
"added": convert_for_jsonb(hist.added),
}
if event_type != OpsDBHistoryType.NEW:
changes[key]["deleted"] = convert_for_jsonb(hist.deleted)
return DbRecordAudit(row_key, changes)


def convert_for_jsonb(value):
if isinstance(value, (str, bool, int, float, NoneType)):
return value
if isinstance(value, Enum):
return value.name
if isinstance(value, Decimal):
return float(value)
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, date):
return value.isoformat()
if isinstance(value, BaseModel):
if callable(getattr(value, "to_slim_dict", None)):
return value.to_slim_dict()
return value.to_dict()
if isinstance(value, (list, tuple)):
return [convert_for_jsonb(item) for item in value]
return str(value)
Loading