diff --git a/classy_vision/generic/profiler.py b/classy_vision/generic/profiler.py index 5e9d60e9e5..398dea4a62 100644 --- a/classy_vision/generic/profiler.py +++ b/classy_vision/generic/profiler.py @@ -7,6 +7,7 @@ import collections.abc as abc import logging import operator +from typing import Callable import torch import torch.nn as nn @@ -183,8 +184,12 @@ def flops(self, x): elif layer_type in ["AdaptiveAvgPool2d"]: in_h = x.size()[2] in_w = x.size()[3] - out_h = layer.output_size[0] - out_w = layer.output_size[1] + if isinstance(layer.output_size, int): + out_h, out_w = layer.output_size, layer.output_size + elif len(layer.output_size) == 1: + out_h, out_w = layer.output_size[0], layer.output_size[0] + else: + out_h, out_w = layer.output_size if out_h > in_h or out_w > in_w: raise NotImplementedError() batchsize_per_replica = x.size()[0] @@ -295,6 +300,10 @@ def flops(self, x): for dim_size in x.size(): flops *= dim_size return flops + + elif layer_type == "Identity": + return 0 + elif hasattr(layer, "flops"): # If the module already defines a method to compute flops with the signature # below, we use it to compute flops @@ -312,8 +321,16 @@ def _layer_activations(layer, x, out): """ Computes the number of activations produced by a single layer. - Activations are counted only for convolutional layers. + Activations are counted only for convolutional layers. To override this behavior, a + layer can define a method to compute activations with the signature below, which + will be used to compute the activations instead. + + Class MyModule(nn.Module): + def activations(self, x, out): + ... """ + if hasattr(layer, "activations"): + return layer.activations(x, out) return out.numel() if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d)) else 0 @@ -338,11 +355,25 @@ def summarize_profiler_info(prof): return str -def _patched_computation_module(module, compute_list, compute_fn): +class _ComplexityComputer: + def __init__(self, compute_fn: Callable, count_unique: bool): + self.compute_fn = compute_fn + self.count_unique = count_unique + self.count = 0 + self.seen_modules = set() + + def compute(self, layer, x, out, module_name): + if self.count_unique and module_name in self.seen_modules: + return + self.count += self.compute_fn(layer, x, out) + self.seen_modules.add(module_name) + + +def _patched_computation_module(module, complexity_computer, module_name): """ Patch the module to compute a module's parameters, like FLOPs. - Calls compute_fn and appends the results to compute_list. + Calls compute_fn and passes the results to the complexity computer. """ ty = type(module) typestring = module.__repr__() @@ -355,7 +386,7 @@ def _original_forward(self, *args, **kwargs): def forward(self, *args, **kwargs): out = self._original_forward(*args, **kwargs) - compute_list.append(compute_fn(self, args[0], out)) + complexity_computer.compute(self, args[0], out, module_name) return out def __repr__(self): @@ -364,37 +395,58 @@ def __repr__(self): return ComputeModule -def modify_forward(model, compute_list, compute_fn): +def modify_forward(model, complexity_computer, prefix="", patch_attr=None): """ Modify forward pass to measure a module's parameters, like FLOPs. """ - if is_leaf(model) or hasattr(model, "flops"): - model.__class__ = _patched_computation_module(model, compute_list, compute_fn) + if is_leaf(model) or (patch_attr is not None and hasattr(model, patch_attr)): + model.__class__ = _patched_computation_module( + model, complexity_computer, prefix + ) else: - for child in model.children(): - modify_forward(child, compute_list, compute_fn) + for name, child in model.named_children(): + modify_forward( + child, + complexity_computer, + prefix=f"{prefix}.{name}", + patch_attr=patch_attr, + ) return model -def restore_forward(model): +def restore_forward(model, patch_attr=None): """ - Restore original forward in model: + Restore original forward in model. """ - if is_leaf(model) or hasattr(model, "flops"): + if is_leaf(model) or (patch_attr is not None and hasattr(model, patch_attr)): model.__class__ = model.orig_type else: for child in model.children(): - restore_forward(child) + restore_forward(child, patch_attr=patch_attr) return model -def compute_complexity(model, compute_fn, input_shape, input_key=None): +def compute_complexity( + model, + compute_fn, + input_shape, + input_key=None, + patch_attr=None, + compute_unique=False, +): """ Compute the complexity of a forward pass. + + Args: + compute_unique: If True, the compexity for a given module is only calculated + once. Otherwise, it is counted every time the module is called. + + TODO(@mannatsingh): We have some assumptions about only modules which are leaves + or have patch_attr defined. This should be fixed and generalized if possible. """ # assertions, input, and upvalue in which we will perform the count: assert isinstance(model, nn.Module) @@ -404,10 +456,10 @@ def compute_complexity(model, compute_fn, input_shape, input_key=None): else: input = get_model_dummy_input(model, input_shape, input_key) - compute_list = [] + complexity_computer = _ComplexityComputer(compute_fn, compute_unique) # measure FLOPs: - modify_forward(model, compute_list, compute_fn) + modify_forward(model, complexity_computer, patch_attr=patch_attr) try: # compute complexity in eval mode with eval_model(model), torch.no_grad(): @@ -415,23 +467,27 @@ def compute_complexity(model, compute_fn, input_shape, input_key=None): except NotImplementedError as err: raise err finally: - restore_forward(model) + restore_forward(model, patch_attr=patch_attr) - return sum(compute_list) + return complexity_computer.count def compute_flops(model, input_shape=(3, 224, 224), input_key=None): """ Compute the number of FLOPs needed for a forward pass. """ - return compute_complexity(model, _layer_flops, input_shape, input_key) + return compute_complexity( + model, _layer_flops, input_shape, input_key, patch_attr="flops" + ) def compute_activations(model, input_shape=(3, 224, 224), input_key=None): """ Compute the number of activations created in a forward pass. """ - return compute_complexity(model, _layer_activations, input_shape, input_key) + return compute_complexity( + model, _layer_activations, input_shape, input_key, patch_attr="activations" + ) def count_params(model): @@ -439,15 +495,4 @@ def count_params(model): Count the number of parameters in a model. """ assert isinstance(model, nn.Module) - count = 0 - for child in model.children(): - if is_leaf(child): - if hasattr(child, "_mask"): # for masked modules (like LGC) - count += child._mask.long().sum().item() - # FIXME: BatchNorm parameters in LGC are not counted. - else: # for regular modules - for p in child.parameters(): - count += p.nelement() - else: - count += count_params(child) - return count + return sum((parameter.nelement() for parameter in model.parameters())) diff --git a/test/generic_profiler_test.py b/test/generic_profiler_test.py index 7b39955bb7..cac1c9673c 100644 --- a/test/generic_profiler_test.py +++ b/test/generic_profiler_test.py @@ -7,6 +7,8 @@ import unittest from test.generic.config_utils import get_test_model_configs +import torch +import torch.nn as nn from classy_vision.generic.profiler import ( compute_activations, compute_flops, @@ -15,8 +17,61 @@ from classy_vision.models import build_model +class TestModule(nn.Module): + def __init__(self): + super().__init__() + # add parameters to the module to affect the parameter count + self.linear = nn.Linear(2, 3, bias=False) + + def forward(self, x): + return x + 1 + + def flops(self, x): + # TODO: this should raise an exception if this function is not defined + # since the FLOPs are indeterminable + + # need to define flops since this is an unknown class + return x.numel() + + +class TestConvModule(nn.Conv2d): + def __init__(self): + super().__init__(2, 3, (4, 4), bias=False) + # add another (unused) layer for added complexity and to test parameters + self.linear = nn.Linear(4, 5, bias=False) + + def forward(self, x): + return x + + def activations(self, x, out): + # TODO: this should ideally work without this function being defined + return out.numel() + + def flops(self, x): + # need to define flops since this is an unknown class + return 0 + + +class TestModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(300, 300, bias=False) + self.mod = TestModule() + self.conv = TestConvModule() + # we should be able to pick up user defined parameters as well + self.extra_params = nn.Parameter(torch.randn(10, 10)) + # we shouldn't count flops for an unused layer + self.unused_linear = nn.Linear(2, 2, bias=False) + + def forward(self, x): + out = self.conv(x) + out = out.view(out.shape[0], -1) + out = self.mod(out) + return self.linear(out) + + class TestProfilerFunctions(unittest.TestCase): - def test_complexity_calculation(self) -> None: + def test_complexity_calculation_resnext(self) -> None: model_configs = get_test_model_configs() # make sure there are three configs returned self.assertEqual(len(model_configs), 3) @@ -34,3 +89,21 @@ def test_complexity_calculation(self) -> None: self.assertEqual(compute_activations(model) // 10 ** 6, m_activations) self.assertEqual(compute_flops(model) // 10 ** 6, m_flops) self.assertEqual(count_params(model) // 10 ** 6, m_params) + + def test_complexity_calculation(self) -> None: + model = TestModel() + input_shape = (3, 10, 10) + num_elems = 3 * 10 * 10 + self.assertEqual(compute_activations(model, input_shape=input_shape), num_elems) + self.assertEqual( + compute_flops(model, input_shape=input_shape), + num_elems + + 0 + + (300 * 300), # TestModule + TestConvModule + TestModel.linear; + # TestModel.unused_linear is unused and shouldn't be counted + ) + self.assertEqual( + count_params(model), + (2 * 3) + (2 * 3 * 4 * 4) + (4 * 5) + (300 * 300) + (10 * 10) + (2 * 2), + ) # TestModule.linear + TestConvModule + TestConvModule.linear + + # TestModel.linear + TestModel.extra_params + TestModel.unused_linear