Skip to content

Commit

Permalink
feat: split discount
Browse files Browse the repository at this point in the history
resolves #257
  • Loading branch information
igobranco committed Mar 22, 2024
1 parent d88421b commit d286ffa
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 69 deletions.
4 changes: 4 additions & 0 deletions apps/billing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ classDiagram
-vat_identification_country: CountryField
-total_amount_exclude_vat: DecimalField
-total_amount_include_vat: DecimalField
-total_discount_excl_tax: DecimalField
-total_discount_incl_tax: DecimalField
-currency: CharField
-document_id: CharField
-payment_type: CharField
Expand All @@ -41,6 +43,8 @@ classDiagram
-vat_tax: DecimalField
-unit_price_excl_vat: DecimalField
-unit_price_incl_vat: DecimalField
-discount_excl_tax: DecimalField
-discount_incl_tax: DecimalField
-organization: CharField
-product_id: CharField
-product_code: CharField
Expand Down
5 changes: 4 additions & 1 deletion apps/billing/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class Meta:
"date_time_between", start_date="-5d", end_date="-1d", tzinfo=timezone.get_current_timezone()
)
document_id = factory.Faker("pystr_format", string_format="DCI-######{{random_int}}")
total_discount_excl_tax = 0.00
total_discount_incl_tax = 0.00

@factory.lazy_attribute
def total_amount_include_vat(self):
Expand All @@ -56,7 +58,8 @@ class Meta:
vat_tax = factory.Faker("pydecimal", min_value=1, max_value=100, left_digits=3, right_digits=2)
organization_code = factory.Sequence(lambda n: f"Org {n}")
product_code = "".join([random.choice(string.ascii_uppercase) for _ in range(5)])
discount = factory.Faker("pydecimal", min_value=0, max_value=1, left_digits=1, right_digits=2)
discount_excl_tax = 0.00
discount_incl_tax = 0.00

@factory.lazy_attribute
def product_id(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Generated by Django 4.2.8 on 2024-03-22 14:40

import django_countries.fields
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("billing", "0006_sagex3transactioninformation_series_and_more"),
]

operations = [
migrations.RemoveField(
model_name="transactionitem",
name="discount",
),
migrations.AddField(
model_name="transaction",
name="total_discount_excl_tax",
field=models.DecimalField(decimal_places=2, default=0.0, max_digits=10),
),
migrations.AddField(
model_name="transaction",
name="total_discount_incl_tax",
field=models.DecimalField(decimal_places=2, default=0.0, max_digits=10),
),
migrations.AddField(
model_name="transactionitem",
name="discount_excl_tax",
field=models.DecimalField(decimal_places=2, default=0.0, max_digits=10),
),
migrations.AddField(
model_name="transactionitem",
name="discount_incl_tax",
field=models.DecimalField(decimal_places=2, default=0.0, max_digits=10),
),
migrations.AlterField(
model_name="transaction",
name="country_code",
field=django_countries.fields.CountryField(blank=True, max_length=255, null=True),
),
]
33 changes: 21 additions & 12 deletions apps/billing/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from django.conf import settings
from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models
from django.utils.translation import gettext_lazy as _
from django_countries.fields import CountryField
Expand Down Expand Up @@ -53,6 +52,8 @@ class Transaction(BaseModel):
vat_identification_country = CountryField(max_length=255, null=True, blank=True)
total_amount_exclude_vat = models.DecimalField(max_digits=10, decimal_places=2)
total_amount_include_vat = models.DecimalField(max_digits=10, decimal_places=2)
total_discount_excl_tax = models.DecimalField(max_digits=10, decimal_places=2, default=0.00)
total_discount_incl_tax = models.DecimalField(max_digits=10, decimal_places=2, default=0.00)
currency = models.CharField(max_length=7, default="EUR")
payment_type = models.CharField(max_length=20, default="credit_card")
transaction_type = models.CharField(max_length=15, choices=TRANSACTION_TYPE)
Expand All @@ -68,17 +69,19 @@ class TransactionItem(BaseModel):
One-to-many relationship with Transaction model (related_name='transaction_items').
The fields for this model was defined in the following documentation:
ecommerce_integration_specification
https://github.com/fccn/nau-financial-manager/blob/main/docs/integrations/ecommerce_integration_specification.md
```
docs/integrations/ecommerce_integration_specification.md
```
- Description
- Quantity
- Amount excluding VAT
- Amount including VAT
- Unit price excluding VAT
- Unit price including VAT
- Product id
- Organization
- Product code
- Discount
- Total Discount excluding TAX
- Total Discount including TAX
"""

Expand All @@ -88,15 +91,21 @@ class TransactionItem(BaseModel):
vat_tax = models.DecimalField(max_digits=5, decimal_places=2)
unit_price_excl_vat = models.DecimalField(max_digits=10, decimal_places=2)
unit_price_incl_vat = models.DecimalField(max_digits=10, decimal_places=2)
discount_excl_tax = models.DecimalField(max_digits=10, decimal_places=2, default=0.00)
discount_incl_tax = models.DecimalField(max_digits=10, decimal_places=2, default=0.00)
organization_code = models.CharField(max_length=255)
product_id = models.CharField(max_length=50)
product_code = models.CharField(max_length=50)
discount = models.DecimalField(
default=0.0,
max_digits=3,
validators=[MaxValueValidator(1), MinValueValidator(0)],
decimal_places=2,
)

@property
def discount_rate(self):
"""
The discount rate
"""
try:
round(self.discount_incl_tax / (self.unit_price_incl_vat + self.discount_incl_tax), 2)
except ZeroDivisionError:
return 0

def __str__(self):
return self.product_id
Expand Down
27 changes: 7 additions & 20 deletions apps/billing/serializers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from django.core.validators import MaxValueValidator, MinValueValidator
from django_countries.serializers import CountryFieldMixin
from rest_framework import serializers

Expand All @@ -11,20 +10,8 @@
class TransactionItemSerializer(CountryFieldMixin, serializers.ModelSerializer):
"""
A serializer class for the `TransactionItem` model.
This serializer includes the `transaction`, `description`, `quantity`, `vat_tax`, `unit_price_excl_vat`,
`unit_price_incl_vat`, `organization_code`, `product_code`, `product_id` and `discount` fields of the `TransactionItem` model.
"""

# Redefined the discount field because for some reason it isn't using the model default value.
# So the solution was to define it again in the Serializer.
discount = serializers.DecimalField(
default=0.00,
max_digits=3,
decimal_places=2,
validators=[MaxValueValidator(1), MinValueValidator(0)],
)

class Meta:
model = TransactionItem
fields = [
Expand All @@ -35,21 +22,17 @@ class Meta:
"vat_tax",
"unit_price_excl_vat",
"unit_price_incl_vat",
"discount_excl_tax",
"discount_incl_tax",
"organization_code",
"product_id",
"product_code",
"discount",
]


class TransactionSerializer(CountryFieldMixin, serializers.ModelSerializer):
"""
A serializer class for the `Transaction` model.
This serializer includes the `id`, `client_name`, `email`, `address_line_1`, `address_line_2,` `vat_identification_country`,
`vat_identification_number`, `city`, `postal_code`, `state`, `country_code`, `total_amount_exclude_vat`, `total_amount_include_vat`, `payment_type`,
`transaction_id`, `currency`, `transaction_date`, `transaction_type`, `document_id` and `transaction_items` fields of the `Transaction` model. The `transaction_items` field is a nested
serializer that includes the `TransactionItem` model fields.
"""

class Meta:
Expand All @@ -68,6 +51,8 @@ class Meta:
"vat_identification_country",
"total_amount_exclude_vat",
"total_amount_include_vat",
"total_discount_excl_tax",
"total_discount_incl_tax",
"currency",
"payment_type",
"transaction_type",
Expand Down Expand Up @@ -95,7 +80,7 @@ class ProcessTransactionSerializer(CountryFieldMixin, serializers.ModelSerialize
to_representation(instance): Returns the given instance.
"""

items = serializers.ListField()
items = TransactionItemSerializerWithoutTransaction(many=True)

class Meta:
model = Transaction
Expand All @@ -115,6 +100,8 @@ class Meta:
"vat_identification_country",
"total_amount_exclude_vat",
"total_amount_include_vat",
"total_discount_excl_tax",
"total_discount_incl_tax",
"currency",
"payment_type",
"transaction_date",
Expand Down
2 changes: 1 addition & 1 deletion apps/billing/services/processor_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __generate_items_as_xml(self, items: list[TransactionItem]) -> str:
<FLD NAME="QTY">{item.quantity}</FLD>
<FLD NAME="STU">UN</FLD>
<FLD NAME="GROPRI">{item.unit_price_excl_vat}</FLD>
<FLD NAME="DISCRGVAL1">{item.discount}</FLD>
<FLD NAME="DISCRGVAL1">{item.discount_rate}</FLD>
<FLD NAME="VACITM1">{self.__vacitm1}</FLD>
</LIN>
"""
Expand Down
71 changes: 37 additions & 34 deletions apps/billing/tests/test_process_transaction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import decimal
import json
from copy import deepcopy
import logging
from unittest import mock

import factory
Expand All @@ -15,6 +15,8 @@

from .test_transaction_service import processor_success_response

log = logging.getLogger(__name__)


class ProcessTransactionTest(TestCase):
"""
Expand Down Expand Up @@ -48,8 +50,9 @@ def test_create_transaction(self, mock):
Test that a transaction can be created with a valid token.
"""
self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key)

log.info(self.payload)
response = self.client.post(self.endpoint, self.payload, format="json")
log.info(response.content)
self.assertEqual(response.status_code, 201)

transaction = Transaction.objects.get(transaction_id=self.payload["transaction_id"])
Expand Down Expand Up @@ -145,46 +148,46 @@ def test_valid_transaction_item_discount(self, mock):
response = self.client.post(self.endpoint, self.payload, format="json")

for item in self.payload["items"]:
self.assertTrue(decimal.Decimal(item["discount"]) >= 0)
self.assertTrue(decimal.Decimal(item["discount"]) <= 1)
self.assertTrue(decimal.Decimal(item["discount_excl_tax"]) == 0.0)
self.assertTrue(decimal.Decimal(item["discount_incl_tax"]) == 0.0)

self.assertEqual(response.status_code, 201)

def test_invalid_transaction_item_discount_greater_than_1(self):
"""
This test ensures that is not possible to process a transaction with invalid discount value in items
"""
# def test_invalid_transaction_item_discount_greater_than_1(self):
# """
# This test ensures that is not possible to process a transaction with invalid discount value in items
# """

self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key)
invalid_payload = deepcopy(self.payload)
invalid_payload["items"][0]["discount"] = 1.1
response = self.client.post(self.endpoint, invalid_payload, format="json")
# self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key)
# invalid_payload = deepcopy(self.payload)
# invalid_payload["items"][0]["total_discount_incl_tax"] = 1.1
# response = self.client.post(self.endpoint, invalid_payload, format="json")

self.assertEqual(response.status_code, 400)
self.assertEqual(str(response.data["discount"][0]), "Ensure this value is less than or equal to 1.")
# self.assertEqual(response.status_code, 400)
# self.assertEqual(str(response.data["total_discount_incl_tax"][0]), "Ensure this value is less than or equal to 1.")

def test_invalid_transaction_item_discount_smaller_than_0(self):
"""
This test ensures that is not possible to process a transaction with invalid discount value in items
"""
# def test_invalid_transaction_item_discount_smaller_than_0(self):
# """
# This test ensures that is not possible to process a transaction with invalid discount value in items
# """

self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key)
invalid_payload = deepcopy(self.payload)
invalid_payload["items"][0]["discount"] = -1
response = self.client.post(self.endpoint, invalid_payload, format="json")
# self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key)
# invalid_payload = deepcopy(self.payload)
# invalid_payload["items"][0]["discount"] = -1
# response = self.client.post(self.endpoint, invalid_payload, format="json")

self.assertEqual(response.status_code, 400)
self.assertEqual(str(response.data["discount"][0]), "Ensure this value is greater than or equal to 0.")
# self.assertEqual(response.status_code, 400)
# self.assertEqual(str(response.data["discount"][0]), "Ensure this value is greater than or equal to 0.")

@mock.patch("requests.post", side_effect=processor_success_response)
def test_invalid_transaction_item_discount_none(self, mock):
"""
This test ensures that is not possible to process a transaction without a discount value in items
"""
# @mock.patch("requests.post", side_effect=processor_success_response)
# def test_invalid_transaction_item_discount_none(self, mock):
# """
# This test ensures that is not possible to process a transaction without a discount value in items
# """

self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key)
invalid_payload = deepcopy(self.payload)
invalid_payload["items"][0].pop("discount", None)
response = self.client.post(self.endpoint, invalid_payload, format="json")
# self.client.credentials(HTTP_AUTHORIZATION="Token " + self.token.key)
# invalid_payload = deepcopy(self.payload)
# invalid_payload["items"][0].pop("discount", None)
# response = self.client.post(self.endpoint, invalid_payload, format="json")

self.assertEqual(response.status_code, 201)
# self.assertEqual(response.status_code, 201)
2 changes: 1 addition & 1 deletion apps/shared_revenue/services/split_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _assembly_each_result(
"percentage_for_nau": configuration.nau_percentage,
"amount_for_nau_including_vat": (item.unit_price_incl_vat * configuration.nau_percentage) * item.quantity,
"amount_for_nau_exclude_vat": (item.unit_price_excl_vat * configuration.nau_percentage) * item.quantity,
"discount": item.discount,
"discount_rate": item.discount_rate,
}

def _calculate_nau_percentage(self, product_id: str):
Expand Down

0 comments on commit d286ffa

Please sign in to comment.