Skip to content

Commit

Permalink
Support aux loss gradient to CompositeLearner.
Browse files Browse the repository at this point in the history
Currently, `CompositeLearner` computes gradients for all sublayers using the
default `loss` from `forward`. However, in cases like GANs, the discriminative
network requires gradients computed using the discriminative loss for parameter
updates.

To generalize this functionality, a `sublearner_to_loss` flag is added to
`CompositeLearner`. This flag maps a sublearner name to its corresponding loss
name for gradient computation. `model.forward` returns `tuple(loss, aux)`, and
the loss name picks which loss to use.
- If set to `None`, all sublearners use the default loss (the first value of the tuple).
- If the value is `"loss"`, the default loss is used.
- For a specific loss name, it retrieves the loss from the `aux` dict
  (e.g., aux["loss_name"]). The specified loss must be a scalar float Tensor.
- Sublearner names not explicitly mapped are automatically assigned to the default loss.

Example: For a GAN, users can configure `CompositeLearner.Config` as follows:
```
cfg = CompositeLearner.default_config().set(
    rules=[(".*encoder.*", "encoder"), (".*decoder.*", "decoder")],
    learners={
        "encoder": Learner.default_config(),
        "decoder": Learner.default_config(),
    },
    sublearner_to_loss=dict(encoder="loss", decoder="discriminative_loss"),
)
```

The implementation calculates gradients in `forward_and_backward()` by invoking
`value_and_grad` for each loss name and its corresponding parameters. Updates
are then merged.
  • Loading branch information
ds-hwang committed Dec 23, 2024
1 parent f91709f commit c677fd7
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 37 deletions.
138 changes: 122 additions & 16 deletions axlearn/common/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""
from __future__ import annotations

import collections
import dataclasses
import enum
from collections.abc import Mapping, Sequence
Expand Down Expand Up @@ -402,6 +403,15 @@ class Config(BaseLearner.Config):
# It will raise an error if none of the rules matches the param path.
rules: Required[Sequence[tuple[str, str]]] = REQUIRED

# Maps a sublearner name to its corresponding loss name for gradient computation.
# `model.forward` returns `tuple(loss, aux)`, and the loss name picks which loss to use.
# - If set to `None`, all sublearners use the default loss (the first value of the tuple).
# - If the value is `"loss"`, the default loss is used.
# - For a specific loss name, it retrieves the loss from the `aux` dict
# (e.g., aux["loss_name"]). The specified loss must be a scalar float Tensor.
# - Sublearner names not explicitly mapped are automatically assigned to the default loss.
sublearner_to_loss: Optional[dict[str, str]] = None

# Ema config. All parameters should share the same ema config.
# See Learner ema for more details.
ema: InstantiableConfig = config_for_function(param_ema)
Expand All @@ -428,6 +438,30 @@ def __init__(self, cfg: Config, *, parent: Module):
# Create a global model ema.
self.ema: PartitionedGradientTransformation = cfg.ema.instantiate()

self.loss_to_sublearner = self._create_loss_to_sublearner()

def _create_loss_to_sublearner(self) -> dict[str, list[str]]:
    """Builds the mapping from each loss name to the sublearners trained with that loss.

    Every sublearner in `cfg.learners` ends up under exactly one loss name. Sublearners that
    `cfg.sublearner_to_loss` does not mention are assigned to the default "loss".

    Returns:
        A plain dict mapping a loss name to a list of sublearner names. Explicitly mapped
        sublearners keep the iteration order of `cfg.sublearner_to_loss`; implicitly assigned
        sublearners are appended to "loss" in sorted order.

    Raises:
        ValueError: If `cfg.sublearner_to_loss` names a sublearner that is not in
            `cfg.learners`, or if no sublearner is left to use the default "loss".
    """
    cfg = self.config
    unassigned = set(cfg.learners.keys())
    loss_to_sublearner = collections.defaultdict(list)

    if not cfg.sublearner_to_loss:
        # No explicit mapping: every sublearner uses the default loss. Sorting keeps the
        # output deterministic.
        loss_to_sublearner["loss"] = sorted(unassigned)
        return dict(loss_to_sublearner)

    for sublearner_name, loss_name in cfg.sublearner_to_loss.items():
        if sublearner_name not in unassigned:
            raise ValueError(f"'{sublearner_name}' is not found in the known learners.")
        loss_to_sublearner[loss_name].append(sublearner_name)
        unassigned.remove(sublearner_name)

    # Set iteration order over strings is not stable across processes, so sort the remaining
    # sublearners before appending them; otherwise the result differs from run to run.
    loss_to_sublearner["loss"].extend(sorted(unassigned))
    if not loss_to_sublearner["loss"]:
        raise ValueError("'loss' must map to at least one sublearner.")
    # Return a plain dict so that a lookup of an unknown loss name fails loudly instead of
    # silently inserting an empty entry.
    return dict(loss_to_sublearner)

def _learner_tree(self, params: Nested[Any]) -> Nested[str]:
"""Returns a tree of the same structure as params where each leaf is the name of the
sublearner to apply.
Expand Down Expand Up @@ -551,36 +585,99 @@ def should_apply(tree: Nested[Any]) -> Nested[bool]:
def forward_and_backward(
self, *, fn: ForwardFn, inputs: Nested[Tensor], opt_params: Nested[OptParam]
) -> ForwardBackwardOutputs:
with child_context(
"should_compute_gradients", module=self, output_collection=new_output_collection()
loss_to_updates = {}
loss_to_should_apply = {}
# 1. The updates for each loss name are collected.
# Note: Since there are no overlapping updated parameters between loss names, this does not
# use additional memory.
for loss, sublearner_list in self.loss_to_sublearner.items():
with child_context(
f"should_compute_gradients_{loss}",
module=self,
output_collection=new_output_collection(),
):
should_compute_gradients = self.should_update_with_optimizers(
opt_params, sublearner_list
)
loss_to_should_apply[loss] = should_compute_gradients
# Note: Calling jax.value_and_grad for each loss name results in an overhead where
# forward() is executed as many times as the number of loss names.
loss_to_updates[loss] = _value_and_grad(
fn,
opt_params=opt_params,
inputs=inputs,
should_compute_gradients=should_compute_gradients,
loss_name=loss,
)

# 2. Merge all updates for each loss name.
updates = forward_pass = forward_outputs = inplace_updates = None
for loss, this_updates in loss_to_updates.items():
if loss == "loss":
# The default "loss" provides `forward_pass` and `inplace_updates`.
forward_pass = this_updates.forward_pass["default"]
forward_outputs = forward_pass.outputs # type: ignore
inplace_updates = this_updates.inplace_updates

if updates is None:
updates = this_updates
continue

should_apply = loss_to_should_apply[loss]
replacements = {}
replacements["opt_params"] = jax.tree.map(
lambda each_apply, new, old: new if each_apply else old,
should_apply,
this_updates.opt_params,
updates.opt_params,
)
replacements["delta_updates"] = jax.tree.map(
lambda each_apply, new, old: new if each_apply else old,
should_apply,
this_updates.delta_updates,
updates.delta_updates,
)
updates = dataclasses.replace(updates, **replacements)

if (
(updates is None)
or (forward_pass is None)
or (forward_outputs is None)
or (inplace_updates is None)
):
should_compute_gradients = self.should_update_with_optimizers(opt_params)
updates = _value_and_grad(
fn,
opt_params=opt_params,
inputs=inputs,
should_compute_gradients=should_compute_gradients,
)
forward_outputs = updates.forward_pass.get("default").outputs # type: ignore
raise ValueError(
f"No updates found; {updates=}, {forward_pass=}, {forward_outputs=}, "
f"{inplace_updates=}"
)

# The default "loss" provides `forward_pass` and `inplace_updates`.
updates = dataclasses.replace(updates, forward_pass=dict(default=forward_pass))
updates = dataclasses.replace(updates, inplace_updates=inplace_updates)

# 3. Update params and optimizer states using the merged updates.
updated_params = self.update(updates)
return ForwardBackwardOutputs(
forward_outputs=forward_outputs,
backward_outputs=BackwardOutputs(updated_params=updated_params),
)

def should_update_with_optimizers(self, model_params: Nested[OptParam]) -> dict:
def should_update_with_optimizers(
self, model_params: Nested[OptParam], sublearner_list: Optional[list[str]] = None
) -> dict:
"""Returns whether each parameter should be updated with the optimizers.
Args:
model_params: A nested structure with OptParams as leaf nodes.
sublearner_list: The list of sublearner names.
Returns:
A nested dict with the same structure as `model_params` with boolean leaf values.
"""
cfg = self.config
if sublearner_list is None:
sublearner_list = self.config.learners.keys()
learner_tree = self._learner_tree(params=model_params)
should_update = jax.tree.map(lambda p: False, model_params)
for name in cfg.learners.keys():
for name in sublearner_list:
# Whether each parameter should apply the sub learner.
should_apply = jax.tree.map(
lambda learner_name, n=name: learner_name == n,
Expand Down Expand Up @@ -652,22 +749,29 @@ def split_params_fn(model_params: Nested) -> tuple[Nested, Nested]:
return filtered_forward, split_params_fn


def _as_loss_fn(fun: ForwardFn) -> Callable:
def _as_loss_fn(fun: ForwardFn, *, loss_name: str) -> Callable:
"""Convert a `ForwardFn` to a function with the same signature except that it outputs
`loss, forward_pass`.
This wrapping makes it compatible with `jax.grad()`.
Args:
fun: The function to wrap.
loss_name: The loss name for which the gradient will be calculated.
Returns:
The wrapped function.
"""

def forward(model_params: Nested[Tensor], *, inputs: Any) -> tuple[Tensor, ForwardPass]:
outputs = fun(model_params=model_params, inputs=inputs) # type: ignore
return outputs.loss, ForwardPass(
if loss_name == "loss":
loss = outputs.loss
else:
if loss_name not in outputs.aux:
raise ValueError(f"{loss_name=} not found in aux: {list(outputs.aux.keys())}")
loss = outputs.aux[loss_name]
return loss, ForwardPass(
# We don't use `forward` here since it is not technically a `ForwardFn`.
forward_fn=fun,
model_params=model_params,
Expand All @@ -684,6 +788,7 @@ def _value_and_grad(
opt_params: Nested[OptParam],
inputs: Nested[Tensor],
should_compute_gradients: Optional[Nested[bool]] = None,
loss_name: str = "loss",
) -> Updates:
"""Computes the value and grad of `fun`.
Expand All @@ -694,6 +799,7 @@ def _value_and_grad(
should_compute_gradients: The model parameters to compute gradients for.
Has the same tree structure as `model_params`.
If None, all parameters have their gradients computed.
loss_name: The loss name for which the gradient will be calculated.
Returns:
The gradient `Updates` for `fun`. The returned `updates` include a "default" key for
Expand All @@ -705,7 +811,7 @@ def _value_and_grad(
fun, should_compute_gradients=should_compute_gradients
)

loss_fun = _as_loss_fn(fun)
loss_fun = _as_loss_fn(fun, loss_name=loss_name)

split_params = split_params_fn(opt_params)
model_params_grad, model_params_nograd = jax.tree.map(lambda p: p.value, split_params)
Expand Down
101 changes: 83 additions & 18 deletions axlearn/common/learner_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
"""Tests learner."""
import contextlib
import copy
import re
from numbers import Number
Expand Down Expand Up @@ -925,16 +926,18 @@ def test__value_and_grad(self):


class CompositeLearnerTest(TestCase):
@parameterized.parameters(None, 0.999)
def test_forward_and_backward(self, ema_decay):
@parameterized.product(
ema_decay=[None, 0.999],
sublearner_to_loss=[None, dict(encoder="loss", decoder="discriminator_loss")],
)
def test_forward_and_backward(self, ema_decay, sublearner_to_loss):
"""Demonstrates how API users should use the API while ensuring that it works correctly."""
# Init a learner.
encoder_lr = 0.1
opt1_cfg = config_for_function(sgd_optimizer).set(
learning_rate=encoder_lr, decouple_weight_decay=True, weight_decay=1.0
learning_rate=0.1, decouple_weight_decay=True, weight_decay=1.0
)
opt2_cfg = config_for_function(adam_optimizer).set(
learning_rate=0.0, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
opt2_cfg = config_for_function(sgd_optimizer).set(
learning_rate=0.2, decouple_weight_decay=True, weight_decay=1.0
)
learner_rules = [(".*encoder.*", "encoder"), (".*decoder.*", "decoder")]

Expand All @@ -949,6 +952,7 @@ def test_forward_and_backward(self, ema_decay):
optimizer=opt2_cfg, enable_per_variable_summaries=False
),
},
sublearner_to_loss=sublearner_to_loss,
)
cfg.ema.decay = ema_decay
learner: CompositeLearner = cfg.instantiate(parent=None)
Expand Down Expand Up @@ -1009,21 +1013,82 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutp
learner_state = learner_output_collection.state_updates
self.assertGreater(forward_outputs.loss, 0.0)
self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
# The structure of updated params and optimizer states are same.
opt_state_leaf_fn = lambda x: isinstance(x, (Tensor, optax.MaskedNode))
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(
learner_state["encoder"]["optimizer"][0].trace, is_leaf=opt_state_leaf_fn
),
# The substructure of the updated params must be the same as that of the updated optimizer states.
# pylint: disable-next=protected-access
learner_tree = learner._learner_tree(params=updated_model_params)
for sublearner_name in ("encoder", "decoder"):
sublearner_apply = jax.tree.map(
lambda learner_name, n=sublearner_name: learner_name == n, learner_tree
)
optimizer_apply = jax.tree.map(
lambda updated_params: isinstance(updated_params, Tensor),
learner_state[sublearner_name]["optimizer"][0].trace,
is_leaf=lambda x: isinstance(x, (Tensor, optax.MaskedNode)),
)
self.assertNestedEqual(sublearner_apply, optimizer_apply, f"{sublearner_name=}")

@parameterized.parameters(
dict(
sublearner_to_loss=None,
expected=dict(loss=["decoder", "encoder"]),
),
dict(
sublearner_to_loss=dict(encoder="loss"),
expected=dict(loss=["encoder", "decoder"]),
),
dict(
sublearner_to_loss=dict(encoder="loss", decoder="loss"),
expected=dict(loss=["encoder", "decoder"]),
),
dict(
sublearner_to_loss=dict(decoder="discriminator_loss"),
expected=dict(loss=["encoder"], discriminator_loss=["decoder"]),
),
dict(
sublearner_to_loss=dict(encoder="loss", decoder="discriminator_loss"),
expected=dict(loss=["encoder"], discriminator_loss=["decoder"]),
),
dict(
sublearner_to_loss=dict(encoder="loss", decoder="discriminator_loss", what="loss"),
expected=ValueError("'what' is not found in the known learners."),
),
dict(
sublearner_to_loss=dict(encoder="discriminator_loss", decoder="discriminator_loss"),
expected=ValueError("'loss' must map to at least one sublearner."),
),
)
def test_create_loss_to_sublearner(self, sublearner_to_loss, expected):
opt1_cfg = config_for_function(sgd_optimizer).set(
learning_rate=0.1, decouple_weight_decay=True, weight_decay=1.0
)
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(
learner_state["decoder"]["optimizer"][1].mu, is_leaf=opt_state_leaf_fn
),
opt2_cfg = config_for_function(adam_optimizer).set(
learning_rate=0.2, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
)
learner_rules = [(".*encoder.*", "encoder"), (".*decoder.*", "decoder")]

cfg = CompositeLearner.default_config().set(
name="test",
rules=learner_rules,
learners={
"encoder": Learner.default_config().set(
optimizer=opt1_cfg, enable_per_variable_summaries=True
),
"decoder": Learner.default_config().set(
optimizer=opt2_cfg, enable_per_variable_summaries=False
),
},
sublearner_to_loss=sublearner_to_loss,
)

if isinstance(expected, Exception):
ctx = self.assertRaisesRegex(type(expected), str(expected))
else:
ctx = contextlib.nullcontext()

with ctx:
learner: CompositeLearner = cfg.instantiate(parent=None)
self.assertEqual(learner.loss_to_sublearner, expected)

@parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
# pylint: disable-next=too-many-statements
def test_learner(self, ema_decay: Optional[float], method: str):
Expand Down
6 changes: 3 additions & 3 deletions axlearn/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,18 +234,18 @@ def assertNestedAllClose(self, a, b, atol=1e-6, rtol=1e-3):
else:
self.assertAlmostEqual(a_value, b_value, msg=f"{a_name}")

def assertNestedEqual(self, a, b):
def assertNestedEqual(self, a, b, msg=None):
a_kv = flatten_items(a)
b_kv = flatten_items(b)
self.assertCountEqual([k for k, _ in a_kv], [k for k, _ in b_kv])
self.assertCountEqual([k for k, _ in a_kv], [k for k, _ in b_kv], msg=msg)
a_dict = dict(a_kv)
b_dict = dict(b_kv)
for k in a_dict:
a_value = a_dict[k]
b_value = b_dict[k]
np.testing.assert_array_equal(a_value, b_value, err_msg=k)
if hasattr(a_value, "dtype"):
self.assertEqual(a_value.dtype, b_value.dtype)
self.assertEqual(a_value.dtype, b_value.dtype, msg=msg)


# TODO(markblee): Move this to axlearn/experiments/test_utils.py, where it's used.
Expand Down

0 comments on commit c677fd7

Please sign in to comment.