Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed Dec 19, 2024
1 parent 855b138 commit a314242
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions hls4ml/model/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def weights_torch(model, fmt='longform', plot='boxplot'):


class WeightsTorch:
def __init__(self, model : torch.nn.Module, fmt : str ='longform', plot : str='boxplot') -> None:
def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None:
self.model = model
self.fmt = fmt
self.plot = plot
Expand All @@ -396,18 +396,18 @@ def __init__(self, model : torch.nn.Module, fmt : str ='longform', plot : str='b
def _find_layers(self, model, module_name):
for name, module in model.named_children():
if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)):
self._find_layers(module, module_name+"."+name)
elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module):
if len(list(module.named_children())) != 0:
self._find_layers(module, module_name + "." + name)
elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module):
if len(list(module.named_children())) != 0:
# custom nn.Module, continue search
self._find_layers(module, module_name+"."+name)
self._find_layers(module, module_name + "." + name)
else:
self._register_layer(module_name+"."+name)
self._register_layer(module_name + "." + name)

def _is_registered(self, name : str) -> bool:
def _is_registered(self, name: str) -> bool:
return name in self.registered_layers

def _register_layer(self, name : str) -> None:
def _register_layer(self, name: str) -> None:
if self._is_registered(name) == False:
self.registered_layers.append(name)

Expand Down Expand Up @@ -448,8 +448,8 @@ def _get_weights(self) -> pandas.DataFrame:

def get_weights(self) -> dict:
    """Collect the model's registered weights via the internal helper.

    NOTE(review): the ``-> dict`` annotation looks inconsistent with
    ``_get_weights``, whose signature is annotated ``-> pandas.DataFrame`` —
    confirm which return type is actually produced.
    """
    return self._get_weights()
def _get_layer(self, layer_name : str, module : torch.nn.Module) -> torch.nn.Module:

def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module:
for name in layer_name.split('.')[1:]:
module = getattr(module, name)
return module
Expand Down

0 comments on commit a314242

Please sign in to comment.