Skip to content

Commit

Permalink
Merge pull request #733 from openedx/ammar/process-expired-licenses
Browse files Browse the repository at this point in the history
feat: process expired licenses and unlink enterprise learners
  • Loading branch information
muhammad-ammar authored Nov 6, 2024
2 parents f3f6fe0 + 00485c3 commit 2d55d76
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 0 deletions.
8 changes: 8 additions & 0 deletions license_manager/apps/api_client/enterprise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class EnterpriseApiClient(BaseOAuthClient):
course_enrollments_revoke_endpoint = api_base_url + 'licensed-enterprise-course-enrollment/license_revoke/'
bulk_licensed_enrollments_expiration_endpoint = api_base_url \
+ 'licensed-enterprise-course-enrollment/bulk_licensed_enrollments_expiration/'
unlink_users_endpoint = api_base_url + 'enterprise-customer/'

def get_enterprise_customer_data(self, enterprise_customer_uuid):
"""
Expand Down Expand Up @@ -189,3 +190,10 @@ def bulk_enroll_enterprise_learners(self, enterprise_id, options):
"""
enrollment_url = '{}{}/enroll_learners_in_courses/'.format(self.enterprise_customer_endpoint, enterprise_id)
return self.client.post(enrollment_url, json=options, timeout=settings.BULK_ENROLL_REQUEST_TIMEOUT_SECONDS)

def bulk_unlink_enterprise_users(self, enterprise_uuid, options):
"""
Calls the Enterprise `unlink_users` API to unlink learners for an enterprise.
"""
enrollment_url = '{}{}/unlink_users/'.format(self.unlink_users_endpoint, enterprise_uuid)
return self.client.post(enrollment_url, json=options, timeout=settings.BULK_UNLINK_REQUEST_TIMEOUT_SECONDS)
2 changes: 2 additions & 0 deletions license_manager/apps/subscriptions/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,5 @@ class SegmentEvents:
}

ENTERPRISE_BRAZE_ALIAS_LABEL = 'Enterprise' # Do Not change this, this is consistent with other uses across edX repos.

EXPIRED_LICENSE_UNLINKED = 'edx.server.license-manager.expired.license.unlinked'
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from datetime import timedelta
from unittest import mock

import pytest
from django.core.management import call_command
from django.test import TestCase
from django.test.utils import override_settings

from license_manager.apps.subscriptions.constants import (
ACTIVATED,
ASSIGNED,
EXPIRED_LICENSE_UNLINKED,
REVOKED,
UNASSIGNED,
)
from license_manager.apps.subscriptions.models import LicenseEvent
from license_manager.apps.subscriptions.tests.factories import (
CustomerAgreementFactory,
LicenseFactory,
SubscriptionPlanFactory,
)
from license_manager.apps.subscriptions.utils import localized_utcnow


@pytest.mark.django_db
class UnlinkExpiredLicensesCommandTests(TestCase):
command_name = 'unlink_expired_licenses'
today = localized_utcnow()
customer_uuid = '76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae'

def _create_expired_plan_with_licenses(
self,
unassigned_licenses_count=1,
assigned_licenses_count=2,
activated_licenses_count=3,
revoked_licenses_count=4,
start_date=today - timedelta(days=7),
expiration_date=today,
expiration_processed=False
):
"""
Creates a plan with licenses. The plan is expired by default.
"""
customer_agreement = CustomerAgreementFactory(enterprise_customer_uuid=self.customer_uuid)
expired_plan = SubscriptionPlanFactory.create(
customer_agreement=customer_agreement,
start_date=start_date,
expiration_date=expiration_date,
expiration_processed=expiration_processed
)

LicenseFactory.create_batch(unassigned_licenses_count, status=UNASSIGNED, subscription_plan=expired_plan)
LicenseFactory.create_batch(assigned_licenses_count, status=ASSIGNED, subscription_plan=expired_plan)
LicenseFactory.create_batch(activated_licenses_count, status=ACTIVATED, subscription_plan=expired_plan)
LicenseFactory.create_batch(revoked_licenses_count, status=REVOKED, subscription_plan=expired_plan)

return expired_plan

def _get_allocated_license_uuids(self, subscription_plan):
return [str(license.uuid) for license in subscription_plan.licenses.filter(status__in=[ASSIGNED, ACTIVATED])]

@override_settings(
CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED=['76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae']
)
@mock.patch(
'license_manager.apps.subscriptions.management.commands.unlink_expired_licenses.EnterpriseApiClient',
return_value=mock.MagicMock()
)
def test_expired_licenses_unlinking(self, mock_enterprise_client):
"""
Verify that expired licenses unlinking working as expected.
"""
today = localized_utcnow()

# create a plan that is expired but difference between expiration_date and today is less than 90
self._create_expired_plan_with_licenses()
# create a plan that is expired 90 days ago
plan_expired_90_days_ago = self._create_expired_plan_with_licenses(
start_date=today - timedelta(days=150),
expiration_date=today - timedelta(days=90)
)

call_command(self.command_name)

# verify that correct licenses from desired subscription plan were recorded in database
for license_event in LicenseEvent.objects.all():
assert license_event.license.subscription_plan.uuid == plan_expired_90_days_ago.uuid
assert license_event.event_name == EXPIRED_LICENSE_UNLINKED

# verify that call to unlink_users endpoint has correct user emails
mock_client_call_args = mock_enterprise_client().bulk_unlink_enterprise_users.call_args_list[0]
assert mock_client_call_args.args[0] == self.customer_uuid
assert sorted(mock_client_call_args.args[1]['user_emails']) == sorted([
license.user_email for license in plan_expired_90_days_ago.licenses.filter(
status__in=[ASSIGNED, ACTIVATED]
)
])

@override_settings(
CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED=['76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae']
)
@mock.patch(
'license_manager.apps.subscriptions.management.commands.unlink_expired_licenses.EnterpriseApiClient',
return_value=mock.MagicMock()
)
def test_expired_licenses_other_active_licenses(self, mock_enterprise_client):
"""
Verify that no unlinking happens when all expired licenses has other active licenses.
"""
assert LicenseEvent.objects.count() == 0
today = localized_utcnow()

# create a plan that is expired 90 days ago
plan_expired_90_days_ago = self._create_expired_plan_with_licenses(
start_date=today - timedelta(days=150),
expiration_date=today - timedelta(days=90)
)
# just another plan
another_plan = self._create_expired_plan_with_licenses(
start_date=today - timedelta(days=150),
expiration_date=today + timedelta(days=10)
)

# fetch user emails from the expired plan
user_emails = list(plan_expired_90_days_ago.licenses.filter(
status__in=[ASSIGNED, ACTIVATED]
).values_list('user_email', flat=True))

# assigned the above emails to licenses to create the test scenario where a learner has other active licenses
for license in another_plan.licenses.filter(status__in=[ASSIGNED, ACTIVATED]):
license.user_email = user_emails.pop()
license.save()

call_command(self.command_name)

# verify that no records were created in database for LicenseEvent
assert LicenseEvent.objects.count() == 0

# verify that no calls have been made to the unlink_users endpoint.
assert mock_enterprise_client().bulk_unlink_enterprise_users.call_count == 0
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@

import logging
from datetime import timedelta

from django.conf import settings
from django.core.management.base import BaseCommand
from django.core.paginator import Paginator
from django.db.models import Exists, OuterRef

from license_manager.apps.api_client.enterprise import EnterpriseApiClient
from license_manager.apps.subscriptions.constants import (
ACTIVATED,
ASSIGNED,
EXPIRED_LICENSE_UNLINKED,
)
from license_manager.apps.subscriptions.models import (
CustomerAgreement,
License,
LicenseEvent,
SubscriptionPlan,
)
from license_manager.apps.subscriptions.utils import localized_utcnow


logger = logging.getLogger(__name__)


class Command(BaseCommand):
help = (
'Unlink expired licenses.'
)

def add_arguments(self, parser):
"""
Entry point to add arguments.
"""
parser.add_argument(
'--dry-run',
action='store_true',
dest='dry_run',
default=False,
help='Dry Run, print log messages without unlinking the learners.',
)

def expired_licenses(self, log_prefix, enterprise_customer_uuid):
"""
Get expired licenses.
"""
now = localized_utcnow()
expired_subscription_plan_uuids = []

customer_agreement = CustomerAgreement.objects.get(enterprise_customer_uuid=enterprise_customer_uuid)

# fetch expired subscription plans where the expiration date is older than 90 days.
expired_subscription_plans = SubscriptionPlan.objects.filter(
customer_agreement=customer_agreement,
expiration_date__lt=now - timedelta(days=90),
).prefetch_related(
'licenses'
).values('uuid', 'expiration_date')

# log expired plan uuids and their expiration dates
for plan in expired_subscription_plans:
logger.info(
'%s Expired plan. UUID: [%s], ExpirationDate: [%s]',
log_prefix,
plan.get('uuid'),
plan.get('expiration_date')
)

expired_subscription_plan_uuids = [
plan.get('uuid') for plan in expired_subscription_plans
]

queryset = License.objects.filter(
status__in=[ASSIGNED, ACTIVATED],
renewed_to=None,
subscription_plan__uuid__in=expired_subscription_plan_uuids,
).select_related(
'subscription_plan',
).values('uuid', 'lms_user_id', 'user_email')

# subquery to check for the existence of `EXPIRED_LICENSE_UNLINKED`
event_exists_subquery = LicenseEvent.objects.filter(
license=OuterRef('pk'),
event_name=EXPIRED_LICENSE_UNLINKED
).values('pk')

# exclude previously processed licenses.
queryset = queryset.exclude(Exists(event_exists_subquery))

return queryset

def handle(self, *args, **options):
"""
Unlink expired licenses.
"""
unlink = not options['dry_run']

log_prefix = '[UNLINK_EXPIRED_LICENSES]'
if not unlink:
log_prefix = '[DRY RUN]'

logger.info('%s Command started.', log_prefix)

enterprise_customer_uuids = settings.CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED
for enterprise_customer_uuid in enterprise_customer_uuids:
logger.info('%s Unlinking started for licenses. Enterprise: [%s]', log_prefix, enterprise_customer_uuid)
self.unlink_expired_licenses(log_prefix, enterprise_customer_uuid, unlink)
logger.info('%s Unlinking completed for licenses. Enterprise: [%s]', log_prefix, enterprise_customer_uuid)

logger.info('%s Command completed.', log_prefix)

def unlink_expired_licenses(self, log_prefix, enterprise_customer_uuid, unlink):
"""
Unlink expired licenses.
"""
expired_licenses = self.expired_licenses(log_prefix, enterprise_customer_uuid)

if not expired_licenses:
logger.info(
'%s No expired licenses were found for enterprise: [%s].',
log_prefix, enterprise_customer_uuid
)
return

paginator = Paginator(expired_licenses, 100)
for page_number in paginator.page_range:
licenses = paginator.page(page_number)

license_uuids = []
user_emails = []

for license in licenses:
# check if the user associated with the expired license
# has any other active licenses with the same customer
other_active_licenses = License.for_user_and_customer(
user_email=license.get('user_email'),
lms_user_id=license.get('lms_user_id'),
enterprise_customer_uuid=enterprise_customer_uuid,
active_plans_only=True,
current_plans_only=True,
).exists()
if other_active_licenses:
continue

license_uuids.append(license.get('uuid'))
user_emails.append(license.get('user_email'))

if unlink and user_emails:
EnterpriseApiClient().bulk_unlink_enterprise_users(
enterprise_customer_uuid,
{
'user_emails': user_emails,
'is_relinkable': True
},
)

# Create license events for unlinked licenses to avoid processing them again.
unlinked_license_events = [
LicenseEvent(license_id=license_uuid, event_name=EXPIRED_LICENSE_UNLINKED)
for license_uuid in license_uuids
]
LicenseEvent.objects.bulk_create(unlinked_license_events, batch_size=100)

logger.info(
"%s learners unlinked for licenses. Enterprise: [%s], LicenseUUIDs: [%s].",
log_prefix,
enterprise_customer_uuid,
license_uuids
)
2 changes: 2 additions & 0 deletions license_manager/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,5 @@
]

CUSTOMERS_WITH_CUSTOM_LICENSE_EVENTS = ['00000000-1111-2222-3333-444444444444']
CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED = []
BULK_UNLINK_REQUEST_TIMEOUT_SECONDS = os.environ.get('BULK_UNLINK_REQUEST_TIMEOUT_SECONDS', 120)

0 comments on commit 2d55d76

Please sign in to comment.