diff --git a/src/fuzzy/relations/continuous/tnorm.py b/src/fuzzy/relations/continuous/tnorm.py index e990283..d0f92dc 100644 --- a/src/fuzzy/relations/continuous/tnorm.py +++ b/src/fuzzy/relations/continuous/tnorm.py @@ -28,7 +28,7 @@ class TNorm(Enum): DIF = "dif" -class AlgebraicProduct(torch.nn.Module): # TODO: remove this class +class AlgebraicProduct(torch.nn.Module): """ Implementation of the Algebraic Product t-norm (Fuzzy AND). """ diff --git a/src/fuzzy/sets/continuous/abstract.py b/src/fuzzy/sets/continuous/abstract.py index faafbec..33d5ea0 100644 --- a/src/fuzzy/sets/continuous/abstract.py +++ b/src/fuzzy/sets/continuous/abstract.py @@ -14,7 +14,9 @@ import torch import torchquad import numpy as np -import scienceplots + +# import scienceplots is used via plt.style.context(["science", "no-latex", "high-contrast"]) +import scienceplots # noqa # pylint: disable=unused-import import matplotlib as mpl import matplotlib.pyplot as plt from torchquad.utils.set_up_backend import set_up_backend @@ -46,7 +48,6 @@ def __init__( widths: np.ndarray, device: torch.device, use_sparse_tensor=False, - labels: List[str] = None, ): super().__init__() self.device = device @@ -94,7 +95,6 @@ def __init__( # ] # ) self._mask = [self.make_mask(widths)] - self.labels = labels # TODO: possibly remove this attribute def to(self, *args, **kwargs): """ @@ -204,7 +204,7 @@ def __hash__(self): Returns: The hash of the fuzzy set. """ - return hash((type(self), self.get_centers(), self.get_widths(), self.labels)) + return hash((type(self), self.get_centers(), self.get_widths())) def get_centers(self) -> torch.Tensor: """ @@ -272,7 +272,6 @@ def save(self, path: Path): state_dict["centers"] = self.get_centers() # concatenate the centers state_dict["widths"] = self.get_widths() # concatenate the widths state_dict["mask"] = self.get_mask() # currently not used - state_dict["labels"] = self.labels state_dict["class_name"] = self.__class__.__name__ if ".pt" not in path.name and ".pth" not in path.name: raise ValueError( @@ -325,12 +324,10 @@ def load(cls, path: Path, device: torch.device) -> "ContinuousFuzzySet": state_dict: MutableMapping = torch.load(path) centers = state_dict.pop("centers") widths = state_dict.pop("widths") - labels = state_dict.pop("labels") class_name = state_dict.pop("class_name") return cls.get_subclass(class_name)( centers=centers.cpu().detach().numpy(), widths=widths.cpu().detach().numpy(), - labels=labels, device=device, ) @@ -362,7 +359,7 @@ def extend(self, centers: torch.Tensor, widths: torch.Tensor, mode: str): method_of_extension([self._widths[0], widths]) ) - def area_helper(self, fuzzy_sets) -> List[List[float]]: + def _area_helper(self, fuzzy_sets) -> List[List[float]]: """ Splits the fuzzy set (if representing a fuzzy variable) into individual fuzzy sets (the fuzzy variable's possible fuzzy terms), and does so recursively until the base case is @@ -429,7 +426,7 @@ def area(self) -> torch.Tensor: torch.Tensor """ return torch.tensor( - self.area_helper(self), device=self.device, dtype=torch.float32 + self._area_helper(self), device=self.device, dtype=torch.float32 ) def split_by_variables(self) -> Union[list, List[Type["ContinuousFuzzySet"]]]: diff --git a/src/fuzzy/sets/continuous/group.py b/src/fuzzy/sets/continuous/group.py index 9a5a734..b0c5132 100644 --- a/src/fuzzy/sets/continuous/group.py +++ b/src/fuzzy/sets/continuous/group.py @@ -41,7 +41,6 @@ def __init__(self, *args, modules_list=None, expandable=False, **kwargs): modules_list = [] self.modules_list = torch.nn.ModuleList(modules_list) self.expandable = expandable - self.pruning = False self.epsilon = 1.5 # epsilon-completeness # keep track of minimums and maximums if for fuzzy set width calculation self.minimums: torch.Tensor = torch.empty(0, 0) @@ -247,8 +246,12 @@ def expand( self.minimums = minimums self.maximums = maximums else: - self.minimums = torch.min(minimums, self.minimums).detach() - self.maximums = torch.max(maximums, self.maximums).detach() + self.minimums = torch.min( + minimums, self.minimums + ).detach() + self.maximums = torch.max( + maximums, self.maximums + ).detach() # find where the new centers should be added, if any # LogGaussian was used, then use following to check for real membership degrees: @@ -258,8 +261,8 @@ def expand( # ): # with torch.no_grad(): # assert ( - # module_responses.exp() * module_masks - # ).max().item() <= 1.0, "Membership degrees are not in the range [0, 1]." + # module_responses.exp() * module_masks + # ).max().item() <= 1.0, "Memberships are not in the range [0, 1]." exemplars: List[torch.Tensor] = [] @@ -295,7 +298,9 @@ def expand( new_centers = torch.where( self.calculate_module_responses(exemplars) .degrees.exp() - .max(dim=-1) # TODO: assuming LogGaussian was used (exp) + .max( + dim=-1 + ) # TODO: assume LogGaussian is used (exp) # pylint: disable=fixme .values < self.epsilon, exemplars, @@ -303,7 +308,7 @@ def expand( ) if not new_centers.isnan().all(): # add new centers - # TODO: this find_centers_and_widths call is problematic + # TODO: find_centers_and_widths call is problematic # pylint: disable=fixme new_widths: torch.Tensor = find_widths( data_point=new_centers.nan_to_num(0.0).mean(dim=0), minimums=self.minimums, @@ -337,14 +342,20 @@ def expand( new_widths.transpose(0, 1).max(dim=-1, keepdim=True).values ) - # TODO: this code does not work for torch.jit.script + # TODO: this code does not work for torch.jit.script # pylint: disable=fixme # the following assumes only the first module is to be expanded module = self.modules_list[0] - module._centers.append(module.make_parameter(parameter=new_centers)) - module._widths.append(module.make_parameter(parameter=new_widths)) - module._mask.append(module.make_mask(widths=new_widths)) + module._centers.append( # pylint: disable=protected-access + module.make_parameter(parameter=new_centers) + ) + module._widths.append( # pylint: disable=protected-access + module.make_parameter(parameter=new_widths) + ) + module._mask.append( # pylint: disable=protected-access + module.make_mask(widths=new_widths) + ) - # TODO: this code does not work for torch.jit.script + # TODO: this code does not work for torch.jit.script # pylint: disable=fixme # the following assumes an entire new module is to be added # module_type = type(self.modules_list[0]) # cannot call type # if issubclass(module_type, ContinuousFuzzySet): @@ -425,7 +436,7 @@ def prune(self, module_type: Type[ContinuousFuzzySet]) -> None: collapsing the rest of the modules into a single module. This is done to reduce the number of torch.nn.Modules in the list for computational efficiency. """ - if self.pruning and len(self.modules_list) > 5: + if len(self.modules_list) > 5: centers, widths = [], [] for module in self.modules_list[1:]: if module.centers.shape[-1] > 1: @@ -458,7 +469,7 @@ def forward(self, observations) -> Membership: module_masks, ) = self.calculate_module_responses(observations) - # TODO: this code does not work for torch.jit.script + # TODO: this code does not work for torch.jit.script # pylint: disable=fixme # self.expand(observations, module_responses, module_masks) return Membership( diff --git a/src/fuzzy/sets/continuous/impl.py b/src/fuzzy/sets/continuous/impl.py index 7054adf..c244dc6 100644 --- a/src/fuzzy/sets/continuous/impl.py +++ b/src/fuzzy/sets/continuous/impl.py @@ -2,7 +2,7 @@ Implements various membership functions by inheriting from ContinuousFuzzySet. """ -from typing import List, Union +from typing import Union import sympy import torch @@ -23,10 +23,9 @@ def __init__( centers=None, widths=None, width_multiplier: float = 1.0, # in fuzzy logic, convention is usually 1.0, but can be 2.0 - labels: List[str] = None, device: Union[str, torch.device] = torch.device("cpu"), ): - super().__init__(centers=centers, widths=widths, labels=labels, device=device) + super().__init__(centers=centers, widths=widths, device=device) self.width_multiplier = width_multiplier assert int(self.width_multiplier) in [1, 2] @@ -220,10 +219,9 @@ def __init__( self, centers=None, widths=None, - labels: List[str] = None, device: Union[str, torch.device] = torch.device("cpu"), ): - super().__init__(centers=centers, widths=widths, labels=labels, device=device) + super().__init__(centers=centers, widths=widths, device=device) @property @torch.jit.ignore @@ -371,10 +369,9 @@ def __init__( self, centers=None, widths=None, - labels: List[str] = None, device: Union[str, torch.device] = torch.device("cpu"), ): - super().__init__(centers=centers, widths=widths, labels=labels, device=device) + super().__init__(centers=centers, widths=widths, device=device) @staticmethod def internal_calculate_membership( diff --git a/tests/test_sets/continuous/test_continuous.py b/tests/test_sets/continuous/test_continuous.py index 3db469e..5c71668 100644 --- a/tests/test_sets/continuous/test_continuous.py +++ b/tests/test_sets/continuous/test_continuous.py @@ -73,7 +73,6 @@ def test_save_and_load(self) -> None: ) # except the saved state dict includes additional information not captured by # the original state dict, such as the class name and the labels - assert "labels" in saved_state_dict.keys() assert "class_name" in saved_state_dict.keys() and saved_state_dict[ "class_name" ] in (subclass.__name__ for subclass in ContinuousFuzzySet.__subclasses__()) @@ -95,7 +94,6 @@ def test_save_and_load(self) -> None: assert torch.allclose( membership_func.sigmas, loaded_membership_func.sigmas ) - assert membership_func.labels == loaded_membership_func.labels # check some functionality that it is still working assert torch.allclose(membership_func.area(), loaded_membership_func.area()) assert torch.allclose(