diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 55035df02..0375cda91 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -386,7 +386,7 @@ def weights_torch(model, fmt='longform', plot='boxplot'): class WeightsTorch: - def __init__(self, model : torch.nn.Module, fmt : str ='longform', plot : str='boxplot') -> None: + def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None: self.model = model self.fmt = fmt self.plot = plot @@ -396,18 +396,18 @@ def __init__(self, model : torch.nn.Module, fmt : str ='longform', plot : str='b def _find_layers(self, model, module_name): for name, module in model.named_children(): if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)): - self._find_layers(module, module_name+"."+name) - elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module): - if len(list(module.named_children())) != 0: + self._find_layers(module, module_name + "." + name) + elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module): + if len(list(module.named_children())) != 0: # custom nn.Module, continue search - self._find_layers(module, module_name+"."+name) + self._find_layers(module, module_name + "." + name) else: - self._register_layer(module_name+"."+name) + self._register_layer(module_name + "." + name) - def _is_registered(self, name : str) -> bool: + def _is_registered(self, name: str) -> bool: return name in self.registered_layers - def _register_layer(self, name : str) -> None: + def _register_layer(self, name: str) -> None: if self._is_registered(name) == False: self.registered_layers.append(name) @@ -448,8 +448,8 @@ def _get_weights(self) -> pandas.DataFrame: def get_weights(self) -> dict: return self._get_weights() - - def _get_layer(self, layer_name : str, module : torch.nn.Module) -> torch.nn.Module: + + def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module: for name in layer_name.split('.')[1:]: module = getattr(module, name) return module