diff --git a/apps/accounts/admin.py b/apps/accounts/admin.py index d266ba24..80aca7ff 100644 --- a/apps/accounts/admin.py +++ b/apps/accounts/admin.py @@ -1350,6 +1350,9 @@ class ProviderRequestIPRangeInline(AdminOnlyTabularInline): model = ProviderRequestIPRange extra = 0 + readonly_fields = ["ip_range_size"] + fields = ["start", "end", "ip_range_size"] + class ProviderRequestEvidenceInline(AdminOnlyTabularInline): model = ProviderRequestEvidence diff --git a/apps/accounts/models/provider_request.py b/apps/accounts/models/provider_request.py index 6dc3750f..ba8295f9 100644 --- a/apps/accounts/models/provider_request.py +++ b/apps/accounts/models/provider_request.py @@ -23,6 +23,8 @@ Service, ) +import ipaddress + logger = logging.getLogger(__name__) # noqa @@ -407,6 +409,23 @@ class ProviderRequestIPRange(models.Model): def __str__(self) -> str: return f"{self.start} - {self.end}" + def ip_range_size(self) -> int: + """ + Return the size of the IP range, based on the start and end ip address. + """ + if not self.start or not self.end: + return 0 + + # Convert string IP addresses to IP address objects + start_ip = ipaddress.ip_address(self.start) + end_ip = ipaddress.ip_address(self.end) + + # Calculate the difference and add 1 (if start and end IP addresses are the same, we still want it to show as 1) + return int(end_ip) - int(start_ip) + 1 + + # Add a short description for the admin + ip_range_size.short_description = "IP Range Size" + def clean(self) -> None: """ Validates an IP range. @@ -419,7 +438,12 @@ def clean(self) -> None: according to the ModelForm validation logic. """ if self.start and self.end: - validate_ip_range(self.start, self.end) + try: + validate_ip_range(self.start, self.end) + except ValueError as e: + raise ValidationError({"start": e}) + except TypeError as e: + raise ValidationError({"Mismatching IP ranges": e}) class ProviderRequestEvidence(models.Model): diff --git a/apps/accounts/tests/test_provider_ip_range.py b/apps/accounts/tests/test_provider_ip_range.py new file mode 100644 index 00000000..0406727c --- /dev/null +++ b/apps/accounts/tests/test_provider_ip_range.py @@ -0,0 +1,74 @@ +import pytest +from ..models import ProviderRequest, ProviderRequestIPRange + +# add import for django validation error +from django.core.exceptions import ValidationError + + +@pytest.fixture +def ip_range(): + return ProviderRequestIPRange() + + +@pytest.mark.parametrize( + "start_ip,end_ip,expected_size", + [ + # IPv4 test cases + ("192.168.1.1", "192.168.1.1", 1), # Single IP + ("192.168.1.1", "192.168.1.10", 10), # Small range + ("192.168.1.0", "192.168.1.255", 256), # Full subnet + ("10.0.0.0", "10.0.1.0", 257), # Across subnet boundary + # IPv6 test cases + ("2001:db8::1", "2001:db8::1", 1), # Single IP + ("2001:db8::1", "2001:db8::10", 16), # Small range + ("2001:db8::0", "2001:db8::ff", 256), # Larger range + # Edge cases + ("0.0.0.0", "0.0.0.255", 256), # Start of IPv4 range + ("255.255.255.0", "255.255.255.255", 256), # End of IPv4 range + ], +) +def test_ip_range_size_calculation(ip_range, start_ip, end_ip, expected_size): + ip_range.start = start_ip + ip_range.end = end_ip + assert ip_range.ip_range_size() == expected_size + + +def test_ip_range_size_with_empty_values(ip_range): + # Test with no values set + assert ip_range.ip_range_size() == 0 + + # Test with only start IP + ip_range.start = "192.168.1.1" + assert ip_range.ip_range_size() == 0 + + # Test with only end IP + ip_range.start = None + ip_range.end = "192.168.1.10" + assert ip_range.ip_range_size() == 0 + + +# @pytest.mark.django_db +# def test_full_model_creation(): +# """Test creating and saving a model instance""" +# provider_request = ProviderRequest.objects.create() # Add necessary fields +# ip_range = ProviderRequestIPRange.objects.create( +# start="192.168.1.1", end="192.168.1.10", request=provider_request +# ) +# assert ip_range.ip_range_size() == 10 +# assert str(ip_range) == "192.168.1.1 - 192.168.1.10" + + +@pytest.mark.parametrize( + "start_ip,end_ip", + [ + ("192.168.1.10", "192.168.1.1"), # End IP before start IP + ("invalid_ip", "192.168.1.1"), # Invalid IP format + ("192.168.1.1", "invalid_ip"), # Invalid IP format + ("2001:db8::1", "192.168.1.1"), # Mixed IPv6 and IPv4 + ], +) +def test_invalid_ip_ranges(ip_range, start_ip, end_ip): + ip_range.start = start_ip + ip_range.end = end_ip + with pytest.raises(ValidationError): + ip_range.clean()