diff --git a/enterprise/api_client/lms.py b/enterprise/api_client/lms.py index 47e08edb49..d677bd4276 100644 --- a/enterprise/api_client/lms.py +++ b/enterprise/api_client/lms.py @@ -6,6 +6,8 @@ import time from urllib.parse import urljoin +import requests + from opaque_keys.edx.keys import CourseKey from requests.exceptions import ( # pylint: disable=redefined-builtin ConnectionError, @@ -274,6 +276,34 @@ def get_enrolled_courses(self, username): response.raise_for_status() return response.json() + def allow_enrollment(self, email, course_id, auto_enroll=False): + """ + Call the enrollment API to allow enrollment for the given email and course_id. + + Args: + email (str): The email address of the user to be allowed to enroll in the course. + course_id (str): The string value of the course's unique identifier. + auto_enroll (bool): Whether to auto-enroll the user in the course upon registration / activation. + + Returns: + dict: A dictionary containing details of the created CourseEnrollmentAllowed object. + + """ + api_url = self.get_api_url("enrollment_allowed") + response = self.client.post( + f"{api_url}/", + json={ + 'email': email, + 'course_id': course_id, + 'auto_enroll': auto_enroll, + } + ) + if response.status_code == requests.codes.conflict: + LOGGER.info(response.json()["message"]) + else: + response.raise_for_status() + return response.json() + class CourseApiClient(NoAuthAPIClient): """ diff --git a/enterprise/utils.py b/enterprise/utils.py index 99debbc724..8b25aa1849 100644 --- a/enterprise/utils.py +++ b/enterprise/utils.py @@ -2407,19 +2407,14 @@ def truncate_string(string, max_length=MAX_ALLOWED_TEXT_LENGTH): def ensure_course_enrollment_is_allowed(course_id, email, enrollment_api_client): """ - Create a CourseEnrollmentAllowed object for invitation-only courses. + Calls the enrollment API to create a CourseEnrollmentAllowed object for + invitation-only courses. Arguments: course_id (str): ID of the course to allow enrollment email (str): email of the user whose enrollment should be allowed enrollment_api_client (:class:`enterprise.api_client.lms.EnrollmentApiClient`): Enrollment API Client """ - if not CourseEnrollmentAllowed: - raise NotConnectedToOpenEdX() - course_details = enrollment_api_client.get_course_details(course_id) if course_details["invite_only"]: - CourseEnrollmentAllowed.objects.update_or_create( - course_id=course_id, - email=email, - ) + enrollment_api_client.allow_enrollment(email, course_id) diff --git a/enterprise/views.py b/enterprise/views.py index 1086fdb4d8..8029e62501 100644 --- a/enterprise/views.py +++ b/enterprise/views.py @@ -683,6 +683,15 @@ def _enroll_learner_in_course( existing_enrollment.get('mode') == constants.CourseModes.AUDIT or existing_enrollment.get('is_active') is False ): + if enterprise_customer.allow_enrollment_in_invite_only_courses: + ensure_course_enrollment_is_allowed(course_id, request.user.email, enrollment_api_client) + LOGGER.info( + 'User {user} is allowed to enroll in Course {course_id}.'.format( + user=request.user.username, + course_id=course_id + ) + ) + course_mode = get_best_mode_from_course_key(course_id) LOGGER.info( 'Retrieved Course Mode: {course_modes} for Course {course_id}'.format( diff --git a/tests/test_enterprise/test_utils.py b/tests/test_enterprise/test_utils.py index be10f1ede4..fd01130210 100644 --- a/tests/test_enterprise/test_utils.py +++ b/tests/test_enterprise/test_utils.py @@ -539,10 +539,9 @@ def test_truncate_string(self): self.assertEqual(len(truncated_string), MAX_ALLOWED_TEXT_LENGTH) @ddt.data(True, False) - @mock.patch("enterprise.utils.CourseEnrollmentAllowed") - def test_ensure_course_enrollment_is_allowed(self, invite_only, mock_cea): + def test_ensure_course_enrollment_is_allowed(self, invite_only): """ - Test that the CourseEnrollmentAllowed is created only for the "invite_only" courses. + Test that the enrollment allow endpoint is called for the "invite_only" courses. """ self.create_user() mock_enrollment_api = mock.Mock() @@ -551,9 +550,9 @@ def test_ensure_course_enrollment_is_allowed(self, invite_only, mock_cea): ensure_course_enrollment_is_allowed("test-course-id", self.user.email, mock_enrollment_api) if invite_only: - mock_cea.objects.update_or_create.assert_called_with( + mock_enrollment_api.return_value.allow_enrollment.assert_called_with( + email=self.user.email, course_id="test-course-id", - email=self.user.email ) else: - mock_cea.objects.update_or_create.assert_not_called() + mock_enrollment_api.return_value.allow_enrollment.assert_not_called() diff --git a/tests/test_enterprise/views/test_course_enrollment_view.py b/tests/test_enterprise/views/test_course_enrollment_view.py index 8ed1819d5a..ae2881245f 100644 --- a/tests/test_enterprise/views/test_course_enrollment_view.py +++ b/tests/test_enterprise/views/test_course_enrollment_view.py @@ -1623,10 +1623,8 @@ def test_post_course_specific_enrollment_view_premium_mode( @mock.patch('enterprise.views.EnrollmentApiClient') @mock.patch('enterprise.views.get_data_sharing_consent') @mock.patch('enterprise.utils.Registry') - @mock.patch('enterprise.utils.CourseEnrollmentAllowed') def test_post_course_specific_enrollment_view_invite_only_courses( self, - mock_cea, registry_mock, get_data_sharing_consent_mock, enrollment_api_client_mock, @@ -1664,9 +1662,9 @@ def test_post_course_specific_enrollment_view_invite_only_courses( } ) - mock_cea.objects.update_or_create.assert_called_with( + enrollment_api_client_mock.return_value.allow_enrollment.assert_called_with( + email=self.user.email, course_id=course_id, - email=self.user.email ) assert response.status_code == 302