diff --git a/care/facility/api/viewsets/patient.py b/care/facility/api/viewsets/patient.py index 2cc07ac9a5..0a502dc06a 100644 --- a/care/facility/api/viewsets/patient.py +++ b/care/facility/api/viewsets/patient.py @@ -147,8 +147,24 @@ def filter_by_category(self, queryset, name, value): field_name="last_consultation__symptoms_onset_date" ) last_consultation_admitted_bed_type_list = MultiSelectFilter( - field_name="last_consultation__current_bed__bed__bed_type" + method="filter_by_bed_type", ) + + def filter_by_bed_type(self, queryset, name, value): + if not value: + return queryset + + values = value.split(",") + filter_q = Q() + + if "None" in values: + filter_q |= Q(last_consultation__current_bed__isnull=True) + values.remove("None") + if values: + filter_q |= Q(last_consultation__current_bed__bed__bed_type__in=values) + + return queryset.filter(filter_q) + last_consultation_admitted_bed_type = CareChoiceFilter( field_name="last_consultation__current_bed__bed__bed_type", choice_dict=REVERSE_BED_TYPES, diff --git a/care/facility/tests/test_patientfilterset.py b/care/facility/tests/test_patientfilterset.py new file mode 100644 index 0000000000..be962b791c --- /dev/null +++ b/care/facility/tests/test_patientfilterset.py @@ -0,0 +1,108 @@ +from django.utils import timezone + +from care.facility.api.viewsets.patient import PatientFilterSet +from care.facility.models import ( + AssetLocation, + Bed, + ConsultationBed, + PatientRegistration, +) +from care.utils.tests.test_base import TestBase + + +class PatientFilterSetTestCase(TestBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_filter_by_bed_type(self): + patient1 = self.create_patient(name="patient1") + patient2 = self.create_patient(name="patient2") + patient3 = self.patient + + # create asset + asset1 = AssetLocation.objects.create( + name="asset1", location_type=1, facility=self.facility + ) + + # create beds + bed1_data = { + "name": "bed 1", + "bed_type": 1, + "location": asset1, + "facility": self.facility, + } + bed2_data = { + "name": "bed 2", + "bed_type": 2, + "location": asset1, + "facility": self.facility, + } + + bed1 = Bed.objects.create(**bed1_data) + bed2 = Bed.objects.create(**bed2_data) + + consultation1 = self.create_consultation( + patient=patient1, facility=self.facility + ) + consultation2 = self.create_consultation( + patient=patient2, facility=self.facility + ) + consultation3 = self.create_consultation( + patient=patient3, facility=self.facility + ) + + # consultation beds + consultation_bed1 = ConsultationBed.objects.create( + consultation=consultation1, bed=bed1, start_date=timezone.now() + ) + consultation_bed2 = ConsultationBed.objects.create( + consultation=consultation2, bed=bed2, start_date=timezone.now() + ) + + consultation1.current_bed = consultation_bed1 + consultation1.save(update_fields=["current_bed"]) + consultation2.current_bed = consultation_bed2 + consultation2.save(update_fields=["current_bed"]) + # None for consultation 3 + + patient1.last_consultation = consultation1 + patient1.save(update_fields=["last_consultation"]) + patient2.last_consultation = consultation2 + patient2.save(update_fields=["last_consultation"]) + patient3.last_consultation = consultation3 + patient3.save(update_fields=["last_consultation"]) + + # Create the filter set instance + filterset = PatientFilterSet(queryset=PatientRegistration.objects.all()) + + # filter + filtered_queryset = filterset.filter_by_bed_type( + name="last_consultation_admitted_bed_type_list", + value="1,None", + queryset=PatientRegistration.objects.all(), + ) + self.assertEqual(len(filtered_queryset), 2) # patient, patient1 and patient3 + self.assertTrue(patient1 in filtered_queryset) + self.assertFalse(patient2 in filtered_queryset) + self.assertTrue(patient3 in filtered_queryset) + + filtered_queryset = filterset.filter_by_bed_type( + name="last_consultation_admitted_bed_type_list", + value="None", + queryset=PatientRegistration.objects.all(), + ) + self.assertEqual(len(filtered_queryset), 1) + self.assertFalse(patient1 in filtered_queryset) + self.assertFalse(patient2 in filtered_queryset) + self.assertTrue(patient3 in filtered_queryset) + + filtered_queryset = filterset.filter_by_bed_type( + name="last_consultation_admitted_bed_type_list", + value="2", + queryset=PatientRegistration.objects.all(), + ) + self.assertEqual(len(filtered_queryset), 1) + self.assertFalse(patient1 in filtered_queryset) + self.assertTrue(patient2 in filtered_queryset) + self.assertFalse(patient3 in filtered_queryset)