diff --git a/py_ballisticcalc/tests.py b/py_ballisticcalc/tests.py index e16344b..6afcdb6 100644 --- a/py_ballisticcalc/tests.py +++ b/py_ballisticcalc/tests.py @@ -334,12 +334,21 @@ def test_weight(self): test_back_n_forth(self, 3, u) -class TestUnitConversionSyntax(unittest.TestCase): +class TestAbstractUnit(unittest.TestCase): def setUp(self) -> None: self.low = Distance.Yard(10) self.high = Distance.Yard(100) + def test__add__(self): + self.assertEqual(Distance.Inch(1) + Distance.Foot(1) + 2, 15) + try: + Distance.Inch(1) + 2 + Velocity.MPS(9) + raise ArithmeticError + except TypeError as expected: + print(expected) + return + def test__eq__(self): self.assertEqual(self.low, 360) self.assertEqual(360, self.low) @@ -348,6 +357,7 @@ def test__eq__(self): def test__ne__(self): self.assertNotEqual(Distance.Yard(100), Distance.Yard(90)) + self.assertNotEqual(Velocity.MPS(10), Distance.Inch(10)) def test__lt__(self): self.assertLess(self.low, self.high) diff --git a/py_ballisticcalc/unit.py b/py_ballisticcalc/unit.py index e9cc0ea..e13515e 100644 --- a/py_ballisticcalc/unit.py +++ b/py_ballisticcalc/unit.py @@ -6,6 +6,7 @@ from enum import IntEnum from math import pi, atan, tan from typing import NamedTuple, Callable +from warnings import warn __all__ = ('Unit', 'AbstractUnit', 'UnitPropsDict', 'Distance', 'Velocity', 'Angular', 'Temperature', 'Pressure', @@ -170,19 +171,55 @@ class UnitProps(NamedTuple): } -class AbstractUnit: +class UnitsComparisonWarning(UserWarning): + pass + + +class AbstractUnit(float): """Abstract class for unit of measure instance definition Stores defined unit and value, applies conversions to other units """ - __slots__ = ('_value', '_defined_units') + # __slots__ = ('_value', '_defined_units') + __slots__ = ('_defined_units',) - def __init__(self, value: [float, int], units: Unit): - """ - :param units: unit as Unit enum - :param value: numeric value of the unit - """ - self._value: float = self.to_raw(value, units) - self._defined_units: Unit = units + def __new__(cls, value: [float, int], units: Unit): + instance = super().__new__(cls, cls.to_raw(cls, value, units)) + instance._defined_units = units + return instance + + def __add__(self, other: [object, float]) -> 'AbstractUnit': + return self.__arithmetical(super().__add__, other) + + def __radd__(self, other: [object, float]) -> 'AbstractUnit': + return self.__arithmetical(super().__radd__, other) + + def __sub__(self, other: [object, float]) -> 'AbstractUnit': + return self.__arithmetical(super().__sub__, other) + + def __rsub__(self, other: [object, float]) -> 'AbstractUnit': + return self.__arithmetical(super().__rsub__, other) + + def __arithmetical(self, func, other): + self.__validate_arithmetical(func, other) + return self.__class__(func(other), self.units) + + def __validate_arithmetical(self, func, other): + print(type(other)) + if isinstance(other, self.__class__): + return True + if not isinstance(other, AbstractUnit) and isinstance(other, (float, int)): + return True + raise TypeError( + f"Operation {func.__name__} between " + f"<{self.__class__.__name__}> and <{other.__class__.__name__}> isn't support") + + # def __init__(self, value: [float, int], units: Unit): + # """ + # :param units: unit as Unit enum + # :param value: numeric value of the unit + # """ + # self._value: float = self.to_raw(value, units) + # self._defined_units: Unit = units def __str__(self) -> str: """Returns readable unit value @@ -190,14 +227,14 @@ def __str__(self) -> str: """ units = self._defined_units props = UnitPropsDict[units] - v = self.from_raw(self._value, units) + v = self.from_raw(float(self), units) return f'{round(v, props.accuracy)} {props.symbol}' def __repr__(self): """Returns instance as readable view :return: instance as readable view """ - return f'<{self.__class__.__name__}: {self << self.units} ({round(self._value, 4)})>' + return f'<{self.__class__.__name__}: {self << self.units} ({round(self, 4)})>' # def __format__(self, format_spec: str = "{v:.{a}f} {s}"): # """ @@ -206,23 +243,45 @@ def __repr__(self): # return format_spec.format(v=self._value, a=self._defined_units.key, \ # s=self._defined_units.symbol) - def __float__(self): - return float(self._value) - - def __eq__(self, other): - return float(self) == other + # def __float__(self): + # return self - def __lt__(self, other): - return float(self) < other + def __comparing(self, func, other): + print(f'ok <{self.__class__.__name__}> and <{other.__class__.__name__}>') - def __gt__(self, other): - return float(self) > other + if self.__validate_comparing(other): + return func(other) + return False - def __le__(self, other): - return float(self) <= other + def __validate_comparing(self, other): + if isinstance(other, self.__class__): + return True + if not isinstance(other, AbstractUnit) and isinstance(other, (float, int)): + return True + warn(f"\nDo not recommended" + f"\nto compare different types of measure units," + f"\na.g. <{self.__class__.__name__}> and <{other.__class__.__name__}>", + UnitsComparisonWarning) + return False - def __ge__(self, other): - return float(self) >= other + def __eq__(self, other): + return self.__comparing(super().__eq__, other) + + def __ne__(self, other): + return self.__comparing(super().__ne__, other) + + # + # def __lt__(self, other): + # return float(self) < other + # + # def __gt__(self, other): + # return float(self) > other + # + # def __le__(self, other): + # return float(self) <= other + # + # def __ge__(self, other): + # return float(self) >= other def __lshift__(self, other: Unit): return self.convert(other) @@ -233,7 +292,7 @@ def __rshift__(self, other: Unit): def __rlshift__(self, other: Unit): return self.convert(other) - def _unit_support_error(self, value: float, units: Unit): + def _unit_support_error(self, value: float, units: Unit) -> [float | None]: """Validates the units :param value: value of the unit :param units: Unit enum type @@ -245,9 +304,9 @@ def _unit_support_error(self, value: float, units: Unit): raise TypeError(err_msg) if units not in self.__dict__.values(): raise ValueError(f'{self.__class__.__name__}: unit {units} is not supported') - return 0 + return None - def to_raw(self, value: float, units: Unit) -> float: + def to_raw(self, value: float, units: Unit) -> [float, None]: """Converts value with specified units to raw value :param value: value of the unit :param units: Unit enum type @@ -255,7 +314,7 @@ def to_raw(self, value: float, units: Unit) -> float: """ return self._unit_support_error(value, units) - def from_raw(self, value: float, units: Unit) -> float: + def from_raw(self, value: float, units: Unit) -> [float, None]: """Converts raw value to specified units :param value: raw value of the unit :param units: Unit enum type @@ -271,12 +330,12 @@ def convert(self, units: Unit) -> 'AbstractUnit': value = self.get_in(units) return self.__class__(value, units) - def get_in(self, units: Unit) -> float: + def get_in(self, units: Unit) -> [float, None]: """Returns value in specified units :param units: Unit enum type :return: value in specified units """ - return self.from_raw(self._value, units) + return self.from_raw(float(self), units) @property def units(self) -> Unit: @@ -285,12 +344,12 @@ def units(self) -> Unit: """ return self._defined_units - @property - def raw_value(self) -> float: - """Raw unit value getter - :return: raw unit value - """ - return self._value + # @property + # def raw_value(self) -> float: + # """Raw unit value getter + # :return: raw unit value + # """ + # return self._value class Distance(AbstractUnit): @@ -674,3 +733,5 @@ def is_unit(obj: [AbstractUnit, float, int]): # Velocity.MPS # Temperature.Fahrenheit # Pressure.MmHg + +# print(Distance.Inch(10) == Velocity.MPS(10))