Skip to content

Commit

Permalink
feat: update EnterpriseGroupMembershipSerializer to include enrollmen…
Browse files Browse the repository at this point in the history
…t count (#2286)

* feat: update EnterpriseGroupMembershipSerializer to include learner course enrollment count and updated query to filter by name
  • Loading branch information
katrinan029 authored Dec 2, 2024
1 parent 12f3e25 commit 0b70eea
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 82 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ Unreleased
----------
* nothing unreleased

[5.1.0]
--------
* feat: update EnterpriseGroupMembershipSerializer to include learner course enrollment count
* feat: updated learner query to filter by full name

[5.0.0]
--------
* refactor: Removed `plotly_token/` API endpoint and related views from enterprise API.
Expand Down
2 changes: 1 addition & 1 deletion enterprise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Your project description goes here.
"""

__version__ = "5.0.0"
__version__ = "5.1.0"
13 changes: 13 additions & 0 deletions enterprise/api/v1/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ class EnterpriseGroupMembershipSerializer(serializers.ModelSerializer):
member_details = serializers.SerializerMethodField()
recent_action = serializers.SerializerMethodField()
status = serializers.CharField(required=False)
enrollments = serializers.SerializerMethodField()

class Meta:
model = models.EnterpriseGroupMembership
Expand All @@ -676,6 +677,7 @@ class Meta:
'recent_action',
'status',
'activated_at',
'enrollments',
)

def get_member_details(self, obj):
Expand All @@ -698,6 +700,17 @@ def get_recent_action(self, obj):
return f"Accepted: {obj.activated_at.strftime('%B %d, %Y')}"
return f"Invited: {obj.created.strftime('%B %d, %Y')}"

def get_enrollments(self, obj):
"""
Fetch all of user's enterprise enrollments
"""
if user := obj.enterprise_customer_user:
enrollments = models.EnterpriseCourseEnrollment.objects.filter(
enterprise_customer_user=user.user_id,
)
return len(enrollments)
return 0


class EnterpriseCustomerUserReadOnlySerializer(serializers.ModelSerializer):
"""
Expand Down
2 changes: 1 addition & 1 deletion enterprise/api/v1/views/enterprise_customer_members.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_members(self, request, *args, **kwargs):
au.id,
au.email,
au.date_joined,
coalesce(NULLIF(aup.name, ''), concat(au.first_name, ' ', au.last_name)) as full_name
coalesce(NULLIF(aup.name, ''), (au.first_name || ' ' || au.last_name)) as full_name
FROM enterprise_enterprisecustomeruser ecu
INNER JOIN auth_user as au on ecu.user_id = au.id
LEFT JOIN auth_userprofile as aup on au.id = aup.user_id
Expand Down
14 changes: 10 additions & 4 deletions enterprise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4692,13 +4692,19 @@ def _get_filtered_ecu_ids(self, user_query):
# https://docs.djangoproject.com/en/5.0/topics/security/
var_q = f"%{user_query}%"
sql_string = """
select entcu.id from enterprise_enterprisecustomeruser entcu
join auth_user au on entcu.user_id = au.id
where entcu.enterprise_customer_id = %s and au.email like %s;
with users as (
select ecu.id,
au.email,
coalesce(NULLIF(aup.name, ''), (au.first_name || ' ' || au.last_name)) as full_name
from enterprise_enterprisecustomeruser ecu
inner join auth_user au on ecu.user_id = au.id
left join auth_userprofile aup on au.id = aup.user_id
where ecu.enterprise_customer_id = %s
) select id from users where email like %s or full_name like %s;
"""
# Raw sql is picky about uuid format
customer_id = str(self.enterprise_customer.pk).replace("-", "")
ecus = EnterpriseCustomerUser.objects.raw(sql_string, (customer_id, var_q))
ecus = EnterpriseCustomerUser.objects.raw(sql_string, (customer_id, var_q, var_q))
return [ecu.id for ecu in ecus]

def _get_explicit_group_members(self, user_query=None, fetch_removed=False, pending_users_only=False,):
Expand Down
78 changes: 2 additions & 76 deletions tests/test_enterprise/api/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8341,65 +8341,6 @@ def test_list_learners_bad_sort_by(self):
assert response.status_code == 400
assert response.data.get('user_query')

def test_list_learners_filtered(self):
"""
Test that the list learners endpoint can be filtered by user details
"""
group = EnterpriseGroupFactory(
enterprise_customer=self.enterprise_customer,
)
pending_user = PendingEnterpriseCustomerUserFactory(
user_email="[email protected]",
enterprise_customer=self.enterprise_customer,
)
pending_user_query_string = f'?user_query={pending_user.user_email}'
url = settings.TEST_SERVER + reverse(
'enterprise-group-learners',
kwargs={'group_uuid': group.uuid},
) + pending_user_query_string
response = self.client.get(url)

assert response.json().get('count') == 0

group.save()
pending_membership = EnterpriseGroupMembershipFactory(
group=group,
pending_enterprise_customer_user=pending_user,
enterprise_customer_user=None,
)
existing_membership = EnterpriseGroupMembershipFactory(
group=group,
pending_enterprise_customer_user=None,
enterprise_customer_user__enterprise_customer=self.enterprise_customer,
)
existing_user = existing_membership.enterprise_customer_user.user
# Changing email to something that we know will be unique for collision purposes
existing_user.email = "[email protected]"
existing_user.save()
existing_user_query_string = '?user_query=ayylmao'
url = settings.TEST_SERVER + reverse(
'enterprise-group-learners',
kwargs={'group_uuid': group.uuid},
) + existing_user_query_string
response = self.client.get(url)

assert response.json().get('count') == 1
assert response.json().get('results')[0].get(
'enterprise_customer_user_id'
) == existing_membership.enterprise_customer_user.id

url = settings.TEST_SERVER + reverse(
'enterprise-group-learners',
kwargs={'group_uuid': group.uuid},
) + pending_user_query_string

response = self.client.get(url)

assert response.json().get('count') == 1
assert response.json().get('results')[0].get(
'pending_enterprise_customer_user_id'
) == pending_membership.pending_enterprise_customer_user.id

def test_list_removed_learners(self):
group = EnterpriseGroupFactory(
enterprise_customer=self.enterprise_customer,
Expand Down Expand Up @@ -8500,6 +8441,7 @@ def test_successful_list_learners(self):
},
'recent_action': f'Accepted: {datetime.now().strftime("%B %d, %Y")}',
'status': 'pending',
'enrollments': 0,
},
)
expected_response = {
Expand Down Expand Up @@ -8541,6 +8483,7 @@ def test_successful_list_learners(self):
},
'recent_action': f'Accepted: {datetime.now().strftime("%B %d, %Y")}',
'status': 'pending',
'enrollments': 0,
}
],
}
Expand Down Expand Up @@ -8632,23 +8575,6 @@ def test_successful_list_with_filters(self):
assert len(enterprise_filtered_response.json().get('results')) == 1
assert learner_filtered_response.json().get('results')[0].get('uuid') == str(new_group.uuid)

def test_list_members_little_bobby_tables(self):
"""
Test that we properly sanitize member user query filters
https://xkcd.com/327/
"""
# url: 'http://testserver/enterprise/api/v1/enterprise_group/<group uuid>/learners/'
url = settings.TEST_SERVER + reverse(
'enterprise-group-learners',
kwargs={'group_uuid': self.group_1.uuid},
)
# The problematic child
filter_query_param = "?user_query=Robert`); DROP TABLE enterprise_enterprisecustomeruser;--"
sql_injection_protected_response = self.client.get(url + filter_query_param)
assert sql_injection_protected_response.status_code == 200
assert not sql_injection_protected_response.json().get('results')
assert EnterpriseCustomerUser.objects.all()

def test_successful_post_group(self):
"""
Test creating a new group record
Expand Down

0 comments on commit 0b70eea

Please sign in to comment.