diff --git a/service/models/shop_cart.py b/service/models/shop_cart.py index cc33182..116f901 100644 --- a/service/models/shop_cart.py +++ b/service/models/shop_cart.py @@ -3,22 +3,22 @@ """ -# from enum import Enum +from enum import Enum from .persistent_base import db, logger, DataValidationError, PersistentBase from .shop_cart_item import ShopCartItem # from decimal import Decimal -# class ShopCartStatus(Enum): -# """Enumeration of different shop cart statuses""" +class ShopCartStatus(Enum): + """Enumeration of different shop cart statuses""" -# # An item has been added to the shop cart -# ACTIVE = 0 -# # User reached last step of checkout -# PENDING = 1 -# # Order was fulfilled or cart was abandoned -# INACTIVE = 3 + # An item has been added to the shop cart + ACTIVE = 0 + # User reached last step of checkout + PENDING = 1 + # Order was fulfilled or cart was abandoned + INACTIVE = 3 class ShopCart(db.Model, PersistentBase): @@ -33,11 +33,12 @@ class ShopCart(db.Model, PersistentBase): user_id = db.Column(db.Integer) name = db.Column(db.String(63)) total_price = db.Column(db.Numeric(precision=10, scale=2)) - # status = db.Column( - # db.Enum( - # ShopCartStatus, nullable=False, server_default=(ShopCartStatus.ACTIVE.name) - # ) - # ) + status = db.Column( + db.Enum(ShopCartStatus), + nullable=False, + server_default=(ShopCartStatus.ACTIVE.name), + ) + items = db.relationship("ShopCartItem", backref="shop_cart", passive_deletes=True) def __repr__(self): @@ -50,7 +51,7 @@ def serialize(self) -> dict: "user_id": self.user_id, "name": self.name, "total_price": self.total_price, - # "status": self.status.name, + "status": self.status.name, "items": [], } for item in self.items: @@ -68,7 +69,12 @@ def deserialize(self, data): self.user_id = data["user_id"] self.name = data["name"] self.total_price = data["total_price"] - # self.status = getattr(ShopCartStatus, data["status"]) + # Check if the status in data is already a ShopCartStatus instance + if isinstance(data["status"], ShopCartStatus): + self.status = data["status"] + else: + # If it's not, assume it's a string and try to convert it + self.status = ShopCartStatus[data["status"]] item_list = data.get("items") if item_list: for json_item in item_list: diff --git a/tests/factories.py b/tests/factories.py index 3886451..2427467 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,8 +1,9 @@ """Test Factory""" import factory -from factory.fuzzy import FuzzyDecimal +from factory.fuzzy import FuzzyDecimal, FuzzyChoice from service.models import ShopCart, ShopCartItem +from service.models.shop_cart import ShopCartStatus # pylint: disable=too-few-public-methods @@ -18,9 +19,9 @@ class Meta: user_id = factory.Sequence(lambda n: n) name = factory.Sequence(lambda n: f"sc-{n}") total_price = FuzzyDecimal(0.00, 200.00) - # status = FuzzyChoice( - # choices=[ShopCartStatus.ACTIVE, ShopCartStatus.PENDING, ShopCartStatus.INACTIVE] - # ) + status = FuzzyChoice( + choices=[ShopCartStatus.ACTIVE, ShopCartStatus.PENDING, ShopCartStatus.INACTIVE] + ) @factory.post_generation def items( diff --git a/tests/test_routes.py b/tests/test_routes.py index 040d0bc..cdac57c 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -181,6 +181,7 @@ def test_update_shopcart(self): "total_price": Decimal(new_shopcart["total_price"]) + 100, # Example of updating the price # Include updates to other fields here + "status": new_shopcart["status"], } # Update the shopcart @@ -210,6 +211,7 @@ def test_update_shop_cart_with_invalid_fields(self): "user_id": new_shopcart["user_id"], "name": "Updated Name", "total_price": new_shopcart["total_price"], + "status": new_shopcart["status"], "non_existent_field": "test", } diff --git a/tests/test_shop_cart.py b/tests/test_shop_cart.py index 8b8e489..ad43183 100644 --- a/tests/test_shop_cart.py +++ b/tests/test_shop_cart.py @@ -63,6 +63,7 @@ def test_create_a_shop_cart(self): "user_id": fake_cart.user_id, "name": fake_cart.name, "total_price": fake_cart.total_price, + "status": fake_cart.status, } cart = ShopCart() cart.deserialize(fake_cart_dict) @@ -73,6 +74,7 @@ def test_create_a_shop_cart(self): self.assertEqual(cart.user_id, fake_cart.user_id) self.assertEqual(cart.name, fake_cart.name) self.assertEqual(cart.total_price, fake_cart.total_price) + self.assertEqual(cart.status, fake_cart.status) def test_add_a_shop_cart(self): """It should Create a shopcart and add it to the database"""