
Commit be24cc7
Improve Pylint score
johnHostetter committed Aug 15, 2024
1 parent 8570491 commit be24cc7
Showing 5 changed files with 36 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/fuzzy/relations/continuous/tnorm.py
@@ -28,7 +28,7 @@ class TNorm(Enum):
     DIF = "dif"


-class AlgebraicProduct(torch.nn.Module):  # TODO: remove this class
+class AlgebraicProduct(torch.nn.Module):
     """
     Implementation of the Algebraic Product t-norm (Fuzzy AND).
     """
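For context, the algebraic product t-norm realizes fuzzy AND by multiplying membership degrees. A minimal sketch of such a module (an illustration under assumed shapes, not the class body from this repository):

    import torch

    class AlgebraicProductSketch(torch.nn.Module):
        """Fuzzy AND: multiply membership degrees along the last dimension."""

        def forward(self, memberships: torch.Tensor) -> torch.Tensor:
            # memberships: tensor of degrees in [0, 1], shape (..., n_inputs)
            return torch.prod(memberships, dim=-1)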
15 changes: 6 additions & 9 deletions src/fuzzy/sets/continuous/abstract.py
@@ -14,7 +14,9 @@
 import torch
 import torchquad
 import numpy as np
-import scienceplots
+
+# import scienceplots is used via plt.style.context(["science", "no-latex", "high-contrast"])
+import scienceplots  # noqa # pylint: disable=unused-import
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 from torchquad.utils.set_up_backend import set_up_backend
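The noqa/pylint pragmas are needed because scienceplots is a side-effect import: merely importing it registers its Matplotlib styles, so the module name is never referenced directly. A small usage sketch (style names follow the comment above):

    import matplotlib.pyplot as plt
    import scienceplots  # noqa # pylint: disable=unused-import

    # The import registers the styles; afterwards they are referenced by name.
    with plt.style.context(["science", "no-latex", "high-contrast"]):
        plt.plot([0, 1], [0, 1])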
@@ -46,7 +48,6 @@ def __init__(
         widths: np.ndarray,
         device: torch.device,
         use_sparse_tensor=False,
-        labels: List[str] = None,
     ):
         super().__init__()
         self.device = device
@@ -94,7 +95,6 @@ def __init__(
         #     ]
         # )
         self._mask = [self.make_mask(widths)]
-        self.labels = labels  # TODO: possibly remove this attribute

     def to(self, *args, **kwargs):
         """
@@ -204,7 +204,7 @@ def __hash__(self):
         Returns:
             The hash of the fuzzy set.
         """
-        return hash((type(self), self.get_centers(), self.get_widths(), self.labels))
+        return hash((type(self), self.get_centers(), self.get_widths()))

     def get_centers(self) -> torch.Tensor:
         """
@@ -272,7 +272,6 @@ def save(self, path: Path):
         state_dict["centers"] = self.get_centers()  # concatenate the centers
         state_dict["widths"] = self.get_widths()  # concatenate the widths
         state_dict["mask"] = self.get_mask()  # currently not used
-        state_dict["labels"] = self.labels
         state_dict["class_name"] = self.__class__.__name__
         if ".pt" not in path.name and ".pth" not in path.name:
             raise ValueError(
@@ -325,12 +324,10 @@ def load(cls, path: Path, device: torch.device) -> "ContinuousFuzzySet":
         state_dict: MutableMapping = torch.load(path)
         centers = state_dict.pop("centers")
         widths = state_dict.pop("widths")
-        labels = state_dict.pop("labels")
         class_name = state_dict.pop("class_name")
         return cls.get_subclass(class_name)(
             centers=centers.cpu().detach().numpy(),
             widths=widths.cpu().detach().numpy(),
-            labels=labels,
             device=device,
         )
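After this change a save/load round trip carries only centers, widths, mask, and class_name. A usage sketch (the Gaussian subclass name is an assumption; the numpy constructor arguments mirror the load call above):

    from pathlib import Path

    import numpy as np
    import torch

    from fuzzy.sets.continuous.impl import Gaussian  # assumed subclass name

    fuzzy_set = Gaussian(
        centers=np.array([0.0, 1.0]),
        widths=np.array([0.5, 0.5]),
        device=torch.device("cpu"),
    )
    fuzzy_set.save(Path("fuzzy_set.pt"))
    loaded = Gaussian.load(Path("fuzzy_set.pt"), device=torch.device("cpu"))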

@@ -362,7 +359,7 @@ def extend(self, centers: torch.Tensor, widths: torch.Tensor, mode: str):
             method_of_extension([self._widths[0], widths])
         )

-    def area_helper(self, fuzzy_sets) -> List[List[float]]:
+    def _area_helper(self, fuzzy_sets) -> List[List[float]]:
         """
         Splits the fuzzy set (if representing a fuzzy variable) into individual fuzzy sets (the
         fuzzy variable's possible fuzzy terms), and does so recursively until the base case is
@@ -429,7 +426,7 @@ def area(self) -> torch.Tensor:
             torch.Tensor
         """
         return torch.tensor(
-            self.area_helper(self), device=self.device, dtype=torch.float32
+            self._area_helper(self), device=self.device, dtype=torch.float32
         )

     def split_by_variables(self) -> Union[list, List[Type["ContinuousFuzzySet"]]]:
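area() delegates to the now-private _area_helper, which ultimately integrates each membership function numerically (torchquad is imported at the top of this file). The same idea in isolation (a sketch with an assumed Gaussian membership function, not the library's implementation):

    import torch
    from torchquad import Trapezoid

    def gaussian_membership(x: torch.Tensor) -> torch.Tensor:
        # x arrives with shape (N, 1); return membership degrees with shape (N,).
        center, width = 0.0, 0.5
        return torch.exp(-((x[:, 0] - center) ** 2) / (2 * width**2))

    area = Trapezoid().integrate(
        gaussian_membership, dim=1, N=1001, integration_domain=[[-3.0, 3.0]]
    )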
39 changes: 25 additions & 14 deletions src/fuzzy/sets/continuous/group.py
@@ -41,7 +41,6 @@ def __init__(self, *args, modules_list=None, expandable=False, **kwargs):
             modules_list = []
         self.modules_list = torch.nn.ModuleList(modules_list)
         self.expandable = expandable
-        self.pruning = False
         self.epsilon = 1.5  # epsilon-completeness
         # keep track of minimums and maximums if for fuzzy set width calculation
         self.minimums: torch.Tensor = torch.empty(0, 0)
@@ -247,8 +246,12 @@ def expand(
             self.minimums = minimums
             self.maximums = maximums
         else:
-            self.minimums = torch.min(minimums, self.minimums).detach()
-            self.maximums = torch.max(maximums, self.maximums).detach()
+            self.minimums = torch.min(
+                minimums, self.minimums
+            ).detach()
+            self.maximums = torch.max(
+                maximums, self.maximums
+            ).detach()

         # find where the new centers should be added, if any
         # LogGaussian was used, then use following to check for real membership degrees:
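The reformatted lines above track running extrema with elementwise torch.min/torch.max; detaching keeps these statistics out of autograd. A standalone illustration:

    import torch

    running_min = torch.tensor([0.2, 0.5])
    batch_min = torch.tensor([0.1, 0.9])
    # Elementwise minimum across the two tensors, kept out of the graph.
    running_min = torch.min(batch_min, running_min).detach()  # tensor([0.1, 0.5])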
@@ -258,8 +261,8 @@
         # ):
         #     with torch.no_grad():
         #         assert (
-        #             module_responses.exp() * module_masks
-        #         ).max().item() <= 1.0, "Membership degrees are not in the range [0, 1]."
+        #             module_responses.exp() * module_masks
+        #         ).max().item() <= 1.0, "Memberships are not in the range [0, 1]."

         exemplars: List[torch.Tensor] = []
@@ -295,15 +298,17 @@ def expand(
            new_centers = torch.where(
                self.calculate_module_responses(exemplars)
                .degrees.exp()
-                .max(dim=-1)  # TODO: assuming LogGaussian was used (exp)
+                .max(
+                    dim=-1
+                )  # TODO: assume LogGaussian is used (exp) # pylint: disable=fixme
                .values
                < self.epsilon,
                exemplars,
                new_centers,
            )

            if not new_centers.isnan().all():  # add new centers
-                # TODO: this find_centers_and_widths call is problematic
+                # TODO: find_centers_and_widths call is problematic # pylint: disable=fixme
                new_widths: torch.Tensor = find_widths(
                    data_point=new_centers.nan_to_num(0.0).mean(dim=0),
                    minimums=self.minimums,
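The torch.where above implements an epsilon-completeness check: an exemplar becomes a candidate new center only where its best membership degree falls below self.epsilon. The selection pattern in isolation (illustrative shapes and values):

    import torch

    epsilon = 0.5  # illustrative threshold; the class uses self.epsilon
    exemplars = torch.tensor([[0.1], [2.0], [4.0]])          # candidate inputs
    peak_degrees = torch.tensor([[0.9], [0.3], [0.05]])      # best membership per exemplar
    new_centers = torch.full_like(exemplars, float("nan"))   # NaN means "no new center"

    # Keep an exemplar as a new center only where coverage is inadequate.
    new_centers = torch.where(peak_degrees < epsilon, exemplars, new_centers)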
@@ -337,14 +342,20 @@ def expand(
                    new_widths.transpose(0, 1).max(dim=-1, keepdim=True).values
                )

-                # TODO: this code does not work for torch.jit.script
+                # TODO: this code does not work for torch.jit.script # pylint: disable=fixme
                # the following assumes only the first module is to be expanded
                module = self.modules_list[0]
-                module._centers.append(module.make_parameter(parameter=new_centers))
-                module._widths.append(module.make_parameter(parameter=new_widths))
-                module._mask.append(module.make_mask(widths=new_widths))
+                module._centers.append(  # pylint: disable=protected-access
+                    module.make_parameter(parameter=new_centers)
+                )
+                module._widths.append(  # pylint: disable=protected-access
+                    module.make_parameter(parameter=new_widths)
+                )
+                module._mask.append(  # pylint: disable=protected-access
+                    module.make_mask(widths=new_widths)
+                )

-                # TODO: this code does not work for torch.jit.script
+                # TODO: this code does not work for torch.jit.script # pylint: disable=fixme
                # the following assumes an entire new module is to be added
                # module_type = type(self.modules_list[0])  # cannot call type
                # if issubclass(module_type, ContinuousFuzzySet):
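On the pragmas used throughout this hunk: an inline # pylint: disable=... silences a message on that line only, while a standalone pragma applies from its position to the end of the enclosing scope. A generic illustration (not code from this repository):

    class _Box:
        def __init__(self) -> None:
            self._internal = 42

    # Inline form: silences protected-access for this line only.
    value = _Box()._internal  # pylint: disable=protected-access

    def helper() -> None:
        # Block form: applies to the rest of this function's body.
        # pylint: disable=fixme
        # TODO: exempt from the fixme check within helper()
        ...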
@@ -425,7 +436,7 @@ def prune(self, module_type: Type[ContinuousFuzzySet]) -> None:
         collapsing the rest of the modules into a single module. This is done to reduce the
         number of torch.nn.Modules in the list for computational efficiency.
         """
-        if self.pruning and len(self.modules_list) > 5:
+        if len(self.modules_list) > 5:
            centers, widths = [], []
            for module in self.modules_list[1:]:
                if module.centers.shape[-1] > 1:
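Per the docstring, prune collapses the trailing modules by gathering their centers and widths into a single module. The core of that idea (a sketch, not the method body):

    import torch

    centers = [torch.tensor([[0.0]]), torch.tensor([[1.0]])]
    widths = [torch.tensor([[0.5]]), torch.tensor([[0.4]])]
    # Concatenate per-module parameters into a single module's parameters.
    merged_centers = torch.cat(centers, dim=0)
    merged_widths = torch.cat(widths, dim=0)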
@@ -458,7 +469,7 @@ def forward(self, observations) -> Membership:
             module_masks,
         ) = self.calculate_module_responses(observations)

-        # TODO: this code does not work for torch.jit.script
+        # TODO: this code does not work for torch.jit.script # pylint: disable=fixme
         # self.expand(observations, module_responses, module_masks)

         return Membership(
11 changes: 4 additions & 7 deletions src/fuzzy/sets/continuous/impl.py
@@ -2,7 +2,7 @@
 Implements various membership functions by inheriting from ContinuousFuzzySet.
 """

-from typing import List, Union
+from typing import Union

 import sympy
 import torch
@@ -23,10 +23,9 @@ def __init__(
         centers=None,
         widths=None,
         width_multiplier: float = 1.0,  # in fuzzy logic, convention is usually 1.0, but can be 2.0
-        labels: List[str] = None,
         device: Union[str, torch.device] = torch.device("cpu"),
     ):
-        super().__init__(centers=centers, widths=widths, labels=labels, device=device)
+        super().__init__(centers=centers, widths=widths, device=device)
         self.width_multiplier = width_multiplier
         assert int(self.width_multiplier) in [1, 2]

@@ -220,10 +219,9 @@ def __init__(
         self,
         centers=None,
         widths=None,
-        labels: List[str] = None,
         device: Union[str, torch.device] = torch.device("cpu"),
     ):
-        super().__init__(centers=centers, widths=widths, labels=labels, device=device)
+        super().__init__(centers=centers, widths=widths, device=device)

     @property
     @torch.jit.ignore
@@ -371,10 +369,9 @@ def __init__(
         self,
         centers=None,
         widths=None,
-        labels: List[str] = None,
         device: Union[str, torch.device] = torch.device("cpu"),
     ):
-        super().__init__(centers=centers, widths=widths, labels=labels, device=device)
+        super().__init__(centers=centers, widths=widths, device=device)

     @staticmethod
     def internal_calculate_membership(
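With labels removed from these constructors, call sites simply omit the keyword; e.g. (a sketch assuming the first constructor above belongs to a Gaussian-style class in this module):

    import numpy as np
    import torch

    from fuzzy.sets.continuous.impl import Gaussian  # assumed class name

    gaussian = Gaussian(
        centers=np.array([0.0, 1.0]),
        widths=np.array([0.5, 0.5]),
        width_multiplier=1.0,  # convention noted in the signature above
        device=torch.device("cpu"),
    )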
2 changes: 0 additions & 2 deletions tests/test_sets/continuous/test_continuous.py
@@ -73,7 +73,6 @@ def test_save_and_load(self) -> None:
         )
         # except the saved state dict includes additional information not captured by
         # the original state dict, such as the class name and the labels
-        assert "labels" in saved_state_dict.keys()
         assert "class_name" in saved_state_dict.keys() and saved_state_dict[
             "class_name"
         ] in (subclass.__name__ for subclass in ContinuousFuzzySet.__subclasses__())
@@ -95,7 +94,6 @@ def test_save_and_load(self) -> None:
            assert torch.allclose(
                membership_func.sigmas, loaded_membership_func.sigmas
            )
-            assert membership_func.labels == loaded_membership_func.labels
            # check some functionality that it is still working
            assert torch.allclose(membership_func.area(), loaded_membership_func.area())
            assert torch.allclose(
