diff --git a/axlearn/common/learner.py b/axlearn/common/learner.py
index d49c4ded..64262ec9 100644
--- a/axlearn/common/learner.py
+++ b/axlearn/common/learner.py
@@ -8,6 +8,7 @@
 """
 from __future__ import annotations
 
+import collections
 import dataclasses
 import enum
 from collections.abc import Mapping, Sequence
@@ -402,6 +403,15 @@ class Config(BaseLearner.Config):
         # It will raise an error if none of the rules matches the param path.
         rules: Required[Sequence[tuple[str, str]]] = REQUIRED
 
+        # Maps a sublearner name to its corresponding loss name for gradient computation.
+        # `model.forward` returns a `(loss, aux)` tuple; the loss name selects which loss to use.
+        # - If set to `None`, all sublearners use the default loss (the first value of the tuple).
+        # - If the value is `"loss"`, the default loss is used.
+        # - Any other loss name retrieves the loss from the `aux` dict
+        #   (e.g., `aux["loss_name"]`). The specified loss must be a scalar float Tensor.
+        # - Sublearner names not explicitly mapped are automatically assigned to the default loss.
+        sublearner_to_loss: Optional[dict[str, str]] = None
+
         # Ema config. All parameters should share the same ema config.
         # See Learner ema for more details.
         ema: InstantiableConfig = config_for_function(param_ema)
@@ -428,6 +438,30 @@ def __init__(self, cfg: Config, *, parent: Module):
         # Create a global model ema.
         self.ema: PartitionedGradientTransformation = cfg.ema.instantiate()
+        self.loss_to_sublearner = self._create_loss_to_sublearner()
+
+    def _create_loss_to_sublearner(self) -> dict[str, list[str]]:
+        # Creates a dict mapping each loss name to its sublearner names, with defaults resolved.
+        cfg = self.config
+        sublearner_set = set(cfg.learners.keys())
+        loss_to_sublearner = collections.defaultdict(list)
+
+        if not cfg.sublearner_to_loss:
+            # sorted() makes the output deterministic for unit tests.
+            loss_to_sublearner["loss"] = sorted(list(sublearner_set))
+            return loss_to_sublearner
+
+        for sublearner_name, loss_name in cfg.sublearner_to_loss.items():
+            if sublearner_name not in sublearner_set:
+                raise ValueError(f"'{sublearner_name}' is not found in the known learners.")
+            loss_to_sublearner[loss_name].append(sublearner_name)
+            sublearner_set.remove(sublearner_name)
+
+        loss_to_sublearner["loss"].extend(list(sublearner_set))
+        if not loss_to_sublearner["loss"]:
+            raise ValueError("'loss' must map to at least one sublearner.")
+        return loss_to_sublearner
 
     def _learner_tree(self, params: Nested[Any]) -> Nested[str]:
         """Returns a tree of the same structure as params where each leaf is the name of the
         sublearner to apply.
@@ -551,36 +585,99 @@ def should_apply(tree: Nested[Any]) -> Nested[bool]:
     def forward_and_backward(
         self, *, fn: ForwardFn, inputs: Nested[Tensor], opt_params: Nested[OptParam]
     ) -> ForwardBackwardOutputs:
-        with child_context(
-            "should_compute_gradients", module=self, output_collection=new_output_collection()
+        loss_to_updates = {}
+        loss_to_should_apply = {}
+        # 1. Collect the updates for each loss name.
+        # Note: Since there are no overlapping updated parameters between loss names, this does not
+        # use additional memory.
+        for loss, sublearner_list in self.loss_to_sublearner.items():
+            with child_context(
+                f"should_compute_gradients_{loss}",
+                module=self,
+                output_collection=new_output_collection(),
+            ):
+                should_compute_gradients = self.should_update_with_optimizers(
+                    opt_params, sublearner_list
+                )
+            loss_to_should_apply[loss] = should_compute_gradients
+            # Note: Calling jax.value_and_grad for each loss name results in an overhead where
+            # forward() is executed as many times as the number of loss names.
+            loss_to_updates[loss] = _value_and_grad(
+                fn,
+                opt_params=opt_params,
+                inputs=inputs,
+                should_compute_gradients=should_compute_gradients,
+                loss_name=loss,
+            )
+
+        # 2. Merge all updates for each loss name.
+        updates = forward_pass = forward_outputs = inplace_updates = None
+        for loss, this_updates in loss_to_updates.items():
+            if loss == "loss":
+                # The default "loss" provides `forward_pass` and `inplace_updates`.
+                forward_pass = this_updates.forward_pass["default"]
+                forward_outputs = forward_pass.outputs  # type: ignore
+                inplace_updates = this_updates.inplace_updates
+
+            if updates is None:
+                updates = this_updates
+                continue
+
+            should_apply = loss_to_should_apply[loss]
+            replacements = {}
+            replacements["opt_params"] = jax.tree.map(
+                lambda each_apply, new, old: new if each_apply else old,
+                should_apply,
+                this_updates.opt_params,
+                updates.opt_params,
+            )
+            replacements["delta_updates"] = jax.tree.map(
+                lambda each_apply, new, old: new if each_apply else old,
+                should_apply,
+                this_updates.delta_updates,
+                updates.delta_updates,
+            )
+            updates = dataclasses.replace(updates, **replacements)
+
+        if (
+            (updates is None)
+            or (forward_pass is None)
+            or (forward_outputs is None)
+            or (inplace_updates is None)
         ):
-            should_compute_gradients = self.should_update_with_optimizers(opt_params)
-            updates = _value_and_grad(
-                fn,
-                opt_params=opt_params,
-                inputs=inputs,
-                should_compute_gradients=should_compute_gradients,
-            )
-            forward_outputs = updates.forward_pass.get("default").outputs  # type: ignore
+            raise ValueError(
+                f"No updates found; {updates=}, {forward_pass=}, {forward_outputs=}, "
+                f"{inplace_updates=}"
+            )
+
+        # The default "loss" provides `forward_pass` and `inplace_updates`.
+        updates = dataclasses.replace(updates, forward_pass=dict(default=forward_pass))
+        updates = dataclasses.replace(updates, inplace_updates=inplace_updates)
+
+        # 3. Update params and optimizer states using the merged updates.
         updated_params = self.update(updates)
         return ForwardBackwardOutputs(
             forward_outputs=forward_outputs,
             backward_outputs=BackwardOutputs(updated_params=updated_params),
         )
 
-    def should_update_with_optimizers(self, model_params: Nested[OptParam]) -> dict:
+    def should_update_with_optimizers(
+        self, model_params: Nested[OptParam], sublearner_list: Optional[list[str]] = None
+    ) -> dict:
         """Returns whether each parameter should be updated with the optimizers.
 
         Args:
             model_params: A nested structure with OptParams as leaf nodes.
+            sublearner_list: The sublearner names to consider. Defaults to all sublearners.
 
         Returns:
             A nested dict with the same structure as `model_params` with boolean leaf values.
         """
-        cfg = self.config
+        if sublearner_list is None:
+            sublearner_list = self.config.learners.keys()
         learner_tree = self._learner_tree(params=model_params)
         should_update = jax.tree.map(lambda p: False, model_params)
-        for name in cfg.learners.keys():
+        for name in sublearner_list:
             # Whether each parameter should apply the sub learner.
             should_apply = jax.tree.map(
                 lambda learner_name, n=name: learner_name == n,
@@ -652,7 +749,7 @@ def split_params_fn(model_params: Nested) -> tuple[Nested, Nested]:
 
     return filtered_forward, split_params_fn
 
 
-def _as_loss_fn(fun: ForwardFn) -> Callable:
+def _as_loss_fn(fun: ForwardFn, *, loss_name: str) -> Callable:
     """Convert a `ForwardFn` to a function with the same signature execpt that it outputs
     `loss, forward_pass`.
@@ -660,6 +757,7 @@ def _as_loss_fn(fun: ForwardFn) -> Callable:
 
     Args:
         fun: The function to wrap.
+        loss_name: The loss name for which the gradient will be calculated.
 
     Returns:
         The wrapped function.
@@ -667,7 +765,13 @@ def forward(model_params: Nested[Tensor], *, inputs: Any) -> tuple[Tensor, ForwardPass]:
         outputs = fun(model_params=model_params, inputs=inputs)  # type: ignore
-        return outputs.loss, ForwardPass(
+        if loss_name == "loss":
+            loss = outputs.loss
+        else:
+            if loss_name not in outputs.aux:
+                raise ValueError(f"{loss_name=} not found in aux: {list(outputs.aux.keys())}")
+            loss = outputs.aux[loss_name]
+        return loss, ForwardPass(
             # We don't use `forward` here since it is not technically a `ForwardFn`.
             forward_fn=fun,
             model_params=model_params,
@@ -684,6 +788,7 @@ def _value_and_grad(
     opt_params: Nested[OptParam],
     inputs: Nested[Tensor],
     should_compute_gradients: Optional[Nested[bool]] = None,
+    loss_name: str = "loss",
 ) -> Updates:
     """Computes the value and grad of `fun`.
 
@@ -694,6 +799,7 @@
         should_compute_gradients: The model parameters to compute gradients for.
             Has the same tree structure as `model_params`. If None, all parameters have
             their gradients computed.
+        loss_name: The loss name for which the gradient will be calculated.
 
     Returns:
         The gradient `Updates` for `fun`. The returned `updates` include a "default" key for
@@ -705,7 +811,7 @@
         fun, should_compute_gradients=should_compute_gradients
     )
 
-    loss_fun = _as_loss_fn(fun)
+    loss_fun = _as_loss_fn(fun, loss_name=loss_name)
 
     split_params = split_params_fn(opt_params)
     model_params_grad, model_params_nograd = jax.tree.map(lambda p: p.value, split_params)
diff --git a/axlearn/common/learner_test.py b/axlearn/common/learner_test.py
index d769a05c..7de07f4b 100644
--- a/axlearn/common/learner_test.py
+++ b/axlearn/common/learner_test.py
@@ -1,5 +1,6 @@
 # Copyright © 2023 Apple Inc.
 """Tests learner."""
+import contextlib
 import copy
 import re
 from numbers import Number
@@ -925,16 +926,18 @@ def test__value_and_grad(self):
 
 
 class CompositeLearnerTest(TestCase):
-    @parameterized.parameters(None, 0.999)
-    def test_forward_and_backward(self, ema_decay):
+    @parameterized.product(
+        ema_decay=[None, 0.999],
+        sublearner_to_loss=[None, dict(encoder="loss", decoder="discriminator_loss")],
+    )
+    def test_forward_and_backward(self, ema_decay, sublearner_to_loss):
         """Demonstrates how API users should use the API while ensuring that it works correctly."""
         # Init a learner.
-        encoder_lr = 0.1
         opt1_cfg = config_for_function(sgd_optimizer).set(
-            learning_rate=encoder_lr, decouple_weight_decay=True, weight_decay=1.0
+            learning_rate=0.1, decouple_weight_decay=True, weight_decay=1.0
         )
-        opt2_cfg = config_for_function(adam_optimizer).set(
-            learning_rate=0.0, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
+        opt2_cfg = config_for_function(sgd_optimizer).set(
+            learning_rate=0.2, decouple_weight_decay=True, weight_decay=1.0
         )
 
         learner_rules = [(".*encoder.*", "encoder"), (".*decoder.*", "decoder")]
@@ -949,6 +952,7 @@ def test_forward_and_backward(self, ema_decay):
                     optimizer=opt2_cfg, enable_per_variable_summaries=False
                 ),
             },
+            sublearner_to_loss=sublearner_to_loss,
         )
         cfg.ema.decay = ema_decay
         learner: CompositeLearner = cfg.instantiate(parent=None)
@@ -1009,21 +1013,82 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutp
         learner_state = learner_output_collection.state_updates
         self.assertGreater(forward_outputs.loss, 0.0)
         self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
-        # The structure of updated params and optimizer states are same.
-        opt_state_leaf_fn = lambda x: isinstance(x, (Tensor, optax.MaskedNode))
-        self.assertNestedEqual(
-            jax.tree_util.tree_structure(updated_model_params),
-            jax.tree_util.tree_structure(
-                learner_state["encoder"]["optimizer"][0].trace, is_leaf=opt_state_leaf_fn
-            ),
+        # For each sublearner, its subset of params must match its non-masked optimizer state.
+        # pylint: disable-next=protected-access
+        learner_tree = learner._learner_tree(params=updated_model_params)
+        for sublearner_name in ("encoder", "decoder"):
+            sublearner_apply = jax.tree.map(
+                lambda learner_name, n=sublearner_name: learner_name == n, learner_tree
+            )
+            optimizer_apply = jax.tree.map(
+                lambda updated_params: isinstance(updated_params, Tensor),
+                learner_state[sublearner_name]["optimizer"][0].trace,
+                is_leaf=lambda x: isinstance(x, (Tensor, optax.MaskedNode)),
+            )
+            self.assertNestedEqual(sublearner_apply, optimizer_apply, f"{sublearner_name=}")
+
+    @parameterized.parameters(
+        dict(
+            sublearner_to_loss=None,
+            expected=dict(loss=["decoder", "encoder"]),
+        ),
+        dict(
+            sublearner_to_loss=dict(encoder="loss"),
+            expected=dict(loss=["encoder", "decoder"]),
+        ),
+        dict(
+            sublearner_to_loss=dict(encoder="loss", decoder="loss"),
+            expected=dict(loss=["encoder", "decoder"]),
+        ),
+        dict(
+            sublearner_to_loss=dict(decoder="discriminator_loss"),
+            expected=dict(loss=["encoder"], discriminator_loss=["decoder"]),
+        ),
+        dict(
+            sublearner_to_loss=dict(encoder="loss", decoder="discriminator_loss"),
+            expected=dict(loss=["encoder"], discriminator_loss=["decoder"]),
+        ),
+        dict(
+            sublearner_to_loss=dict(encoder="loss", decoder="discriminator_loss", what="loss"),
+            expected=ValueError("'what' is not found in the known learners."),
+        ),
+        dict(
+            sublearner_to_loss=dict(encoder="discriminator_loss", decoder="discriminator_loss"),
+            expected=ValueError("'loss' must map to at least one sublearner."),
+        ),
+    )
+    def test_create_loss_to_sublearner(self, sublearner_to_loss, expected):
+        opt1_cfg = config_for_function(sgd_optimizer).set(
+            learning_rate=0.1, decouple_weight_decay=True, weight_decay=1.0
         )
-        self.assertNestedEqual(
-            jax.tree_util.tree_structure(updated_model_params),
-            jax.tree_util.tree_structure(
-                learner_state["decoder"]["optimizer"][1].mu, is_leaf=opt_state_leaf_fn
-            ),
+        opt2_cfg = config_for_function(adam_optimizer).set(
+            learning_rate=0.2, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
+        )
+
+        learner_rules = [(".*encoder.*", "encoder"), (".*decoder.*", "decoder")]
+
+        cfg = CompositeLearner.default_config().set(
+            name="test",
+            rules=learner_rules,
+            learners={
+                "encoder": Learner.default_config().set(
+                    optimizer=opt1_cfg, enable_per_variable_summaries=True
+                ),
+                "decoder": Learner.default_config().set(
+                    optimizer=opt2_cfg, enable_per_variable_summaries=False
+                ),
+            },
+            sublearner_to_loss=sublearner_to_loss,
         )
+        if isinstance(expected, Exception):
+            ctx = self.assertRaisesRegex(type(expected), str(expected))
+        else:
+            ctx = contextlib.nullcontext()
+
+        with ctx:
+            learner: CompositeLearner = cfg.instantiate(parent=None)
+            self.assertEqual(learner.loss_to_sublearner, expected)
+
     @parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
     # pylint: disable-next=too-many-statements
     def test_learner(self, ema_decay: Optional[float], method: str):
diff --git a/axlearn/common/test_utils.py b/axlearn/common/test_utils.py
index ed0ee77c..fe036ae6 100644
--- a/axlearn/common/test_utils.py
+++ b/axlearn/common/test_utils.py
@@ -234,10 +234,10 @@ def assertNestedAllClose(self, a, b, atol=1e-6, rtol=1e-3):
             else:
                 self.assertAlmostEqual(a_value, b_value, msg=f"{a_name}")
 
-    def assertNestedEqual(self, a, b):
+    def assertNestedEqual(self, a, b, msg=None):
         a_kv = flatten_items(a)
         b_kv = flatten_items(b)
-        self.assertCountEqual([k for k, _ in a_kv], [k for k, _ in b_kv])
+        self.assertCountEqual([k for k, _ in a_kv], [k for k, _ in b_kv], msg=msg)
         a_dict = dict(a_kv)
         b_dict = dict(b_kv)
         for k in a_dict:
@@ -245,7 +245,7 @@ def assertNestedEqual(self, a, b):
             b_value = b_dict[k]
             np.testing.assert_array_equal(a_value, b_value, err_msg=k)
             if hasattr(a_value, "dtype"):
-                self.assertEqual(a_value.dtype, b_value.dtype)
+                self.assertEqual(a_value.dtype, b_value.dtype, msg=msg)
 
 
 # TODO(markblee): Move this to axlearn/experiments/test_utils.py, where it's used.
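For reviewers, a minimal usage sketch of the new `sublearner_to_loss` field, mirroring the test setup above; it is not part of the diff. The `encoder`/`decoder` sublearner names and the `discriminator_loss` aux key are illustrative and assume a model whose `forward` returns that key in `aux`.

# Usage sketch (not part of the diff). Assumes the wrapped model's forward() returns
# aux["discriminator_loss"] in addition to the default loss.
from axlearn.common.config import config_for_function
from axlearn.common.learner import CompositeLearner, Learner
from axlearn.common.optimizers import adam_optimizer, sgd_optimizer

# One optimizer per sublearner, as in the tests above.
encoder_opt = config_for_function(sgd_optimizer).set(
    learning_rate=0.1, decouple_weight_decay=True, weight_decay=1.0
)
decoder_opt = config_for_function(adam_optimizer).set(
    learning_rate=0.2, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
)

cfg = CompositeLearner.default_config().set(
    name="learner",
    # Route parameters to sublearners by parameter-path regex.
    rules=[(".*encoder.*", "encoder"), (".*decoder.*", "decoder")],
    learners={
        "encoder": Learner.default_config().set(optimizer=encoder_opt),
        "decoder": Learner.default_config().set(optimizer=decoder_opt),
    },
    # "encoder" params are updated with gradients of the default loss (outputs.loss);
    # "decoder" params with gradients of outputs.aux["discriminator_loss"].
    sublearner_to_loss=dict(encoder="loss", decoder="discriminator_loss"),
)
learner = cfg.instantiate(parent=None)
assert learner.loss_to_sublearner == {"loss": ["encoder"], "discriminator_loss": ["decoder"]}

Each distinct loss name triggers its own forward and gradient computation inside `forward_and_backward` (see the note in the diff), so keeping the number of distinct loss names small avoids redundant recomputation.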