diff --git a/axlearn/common/learner.py b/axlearn/common/learner.py
index 3e4c68cd..d49c4ded 100644
--- a/axlearn/common/learner.py
+++ b/axlearn/common/learner.py
@@ -526,15 +526,12 @@ def should_apply(tree: Nested[Any]) -> Nested[bool]:
             sub_learner_updates = sub_learner_updates.mask(
                 # pylint: disable-next=cell-var-from-loop
                 lambda _: should_apply(updates.opt_params),
-                fields=(
-                    "opt_params",
-                    "delta_updates",
-                ),
+                fields=("opt_params", "delta_updates"),
             )
             sub_learner_updated_model_params = getattr(self, name).update(sub_learner_updates)
             updated_model_params = jax.tree.map(
                 lambda apply, new_v, old_v: new_v if apply else old_v,
-                should_apply(updates.param_values()),
+                should_apply(updated_model_params),
                 sub_learner_updated_model_params,
                 updated_model_params,
             )
@@ -712,7 +709,7 @@ def _value_and_grad(
     split_params = split_params_fn(opt_params)
     model_params_grad, model_params_nograd = jax.tree.map(lambda p: p.value, split_params)
 
-    (_, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
+    (unused_loss, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
         model_params_grad, inputs=(model_params_nograd, inputs)
     )
     return Updates(
diff --git a/axlearn/common/learner_test.py b/axlearn/common/learner_test.py
index 1a11a479..d769a05c 100644
--- a/axlearn/common/learner_test.py
+++ b/axlearn/common/learner_test.py
@@ -15,8 +15,10 @@
 import axlearn.common.update_transformation_test
 from axlearn.common import schedule
 from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
+from axlearn.common.base_model import BaseModel
 from axlearn.common.config import REQUIRED, Required, config_class, config_for_function
 from axlearn.common.gradient_accumulation import with_minibatch_steps
+from axlearn.common.layers import Linear
 from axlearn.common.learner import (
     CompositeLearner,
     Learner,
@@ -28,7 +30,7 @@
     should_update_with_optimizers,
 )
 from axlearn.common.metrics import MetricAccumulator, WeightedScalar
-from axlearn.common.module import OutputCollection
+from axlearn.common.module import OutputCollection, child_context
 from axlearn.common.module import functional as F
 from axlearn.common.module import new_output_collection
 from axlearn.common.optimizer_base import OptParam, OptStateSpec
@@ -50,6 +52,7 @@
 )
 from axlearn.common.utils import (
     Nested,
+    NestedTensor,
     PartitionSpec,
     Tensor,
     VDict,
@@ -59,7 +62,113 @@
 )
 
 
+class TestModel(BaseModel):
+    """A simple model for test."""
+
+    @config_class
+    class Config(BaseModel.Config):
+        dim: int = 4
+
+    def __init__(self, cfg, *, parent):
+        super().__init__(cfg, parent=parent)
+        enc_cfg = Linear.default_config().set(
+            input_dim=cfg.dim,
+            output_dim=cfg.dim,
+        )
+        self._add_child("encoder", enc_cfg)
+
+        dec_cfg = Linear.default_config().set(
+            input_dim=cfg.dim,
+            output_dim=1,
+        )
+        self._add_child("decoder", dec_cfg)
+
+    def forward(self, input_batch: NestedTensor) -> tuple[Tensor, NestedTensor]:
+        x = self.encoder(input_batch["x"])
+        y = self.decoder(x)
+        loss = jnp.mean(y**2)
+        aux = dict(discriminator_loss=jnp.mean(jnp.abs(y)))
+        return loss, aux
+
+
 class LearnerTest(TestCase):
+    @parameterized.parameters(None, 0.999)
+    def test_forward_and_backward(self, ema_decay):
+        """Demonstrates how API users should use the API while ensuring that it works correctly."""
+        # Init a learner.
+        learning_rate = config_for_function(schedule.stepwise).set(
+            sub=[0.1, 0.01, 0.001],
+            start_step=[100, 200],
+        )
+        optimizer_cfg = config_for_function(adam_optimizer).set(
+            learning_rate=learning_rate, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
+        )
+        cfg = Learner.default_config().set(name="test", optimizer=optimizer_cfg)
+        cfg.ema.decay = ema_decay
+        learner: Learner = cfg.instantiate(parent=None)
+
+        # Init a model.
+        input_dim = 4
+        model_cfg = TestModel.default_config().set(name="test", dim=input_dim)
+        model = model_cfg.instantiate(parent=None)
+        prng_key = jax.random.PRNGKey(123)
+        init_key, data_key, fwd_key, learner_key, prng_key = jax.random.split(prng_key, num=5)
+        params = model.initialize_parameters_recursively(init_key)
+
+        # Create model and learner states.
+        model_param_specs = model.create_parameter_specs_recursively()
+        opt_params = jax.tree.map(
+            lambda param, spec: OptParam(
+                value=param,
+                factorization_spec=spec.factorization if spec else None,
+                weight_decay_scale=spec.weight_decay_scale if spec else 1.0,
+            ),
+            params,
+            model_param_specs,
+        )
+        learner_state = learner.init(model_params=opt_params)
+
+        # Forward and backward.
+        def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutputs:
+            model_output_collection = new_output_collection()
+            with child_context(
+                "model",
+                module=model,
+                state=model_params,
+                prng_key=inputs["forward_key"],
+                output_collection=model_output_collection,
+            ):
+                loss, aux = model(input_batch=inputs["input_batch"])
+            return ForwardOutputs(loss=loss, aux=aux, output_collection=model_output_collection)
+
+        batch = 2
+        input_batch = dict(x=jax.random.uniform(data_key, (batch, input_dim)))
+        fwd_bwd_outputs, learner_output_collection = F(
+            learner,
+            method="forward_and_backward",
+            state=learner_state,
+            is_training=True,
+            prng_key=learner_key,
+            inputs=dict(
+                fn=_forward,
+                opt_params=opt_params,
+                inputs=dict(
+                    input_batch=input_batch,
+                    forward_key=fwd_key,
+                ),
+            ),
+        )
+        forward_outputs: ForwardOutputs = fwd_bwd_outputs.forward_outputs
+        updated_model_params = fwd_bwd_outputs.backward_outputs.updated_params
+        learner_state = learner_output_collection.state_updates
+        self.assertGreater(forward_outputs.loss, 0.0)
+        self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
+        # The updated params and the Adam mu state have the same tree structure.
+        self.assertNestedEqual(
+            jax.tree_util.tree_structure(updated_model_params),
+            jax.tree_util.tree_structure(learner_state["optimizer"][1].mu),
+        )
+
     def test_prune_empty_state(self):
         state = {
             "state": {
@@ -816,6 +925,105 @@ def test__value_and_grad(self):
 
 
 class CompositeLearnerTest(TestCase):
+    @parameterized.parameters(None, 0.999)
+    def test_forward_and_backward(self, ema_decay):
+        """Demonstrates how API users should use the API while ensuring that it works correctly."""
+        # Init a learner.
+        encoder_lr = 0.1
+        opt1_cfg = config_for_function(sgd_optimizer).set(
+            learning_rate=encoder_lr, decouple_weight_decay=True, weight_decay=1.0
+        )
+        opt2_cfg = config_for_function(adam_optimizer).set(
+            learning_rate=0.0, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
+        )
+        learner_rules = [(".*encoder.*", "encoder"), (".*decoder.*", "decoder")]
+
+        cfg = CompositeLearner.default_config().set(
+            name="test",
+            rules=learner_rules,
+            learners={
+                "encoder": Learner.default_config().set(
+                    optimizer=opt1_cfg, enable_per_variable_summaries=True
+                ),
+                "decoder": Learner.default_config().set(
+                    optimizer=opt2_cfg, enable_per_variable_summaries=False
+                ),
+            },
+        )
+        cfg.ema.decay = ema_decay
+        learner: CompositeLearner = cfg.instantiate(parent=None)
+
+        # Init a model.
+        input_dim = 4
+        model_cfg = TestModel.default_config().set(name="test", dim=input_dim)
+        model = model_cfg.instantiate(parent=None)
+        prng_key = jax.random.PRNGKey(123)
+        init_key, data_key, fwd_key, learner_key, prng_key = jax.random.split(prng_key, num=5)
+        params = model.initialize_parameters_recursively(init_key)
+
+        # Create model and learner states.
+        model_param_specs = model.create_parameter_specs_recursively()
+        opt_params = jax.tree.map(
+            lambda param, spec: OptParam(
+                value=param,
+                factorization_spec=spec.factorization if spec else None,
+                weight_decay_scale=spec.weight_decay_scale if spec else 1.0,
+            ),
+            params,
+            model_param_specs,
+        )
+        learner_state = learner.init(model_params=opt_params)
+
+        # Forward and backward.
+        def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutputs:
+            model_output_collection = new_output_collection()
+            with child_context(
+                "model",
+                module=model,
+                state=model_params,
+                prng_key=inputs["forward_key"],
+                output_collection=model_output_collection,
+            ):
+                loss, aux = model(input_batch=inputs["input_batch"])
+            return ForwardOutputs(loss=loss, aux=aux, output_collection=model_output_collection)
+
+        batch = 2
+        input_batch = dict(x=jax.random.uniform(data_key, (batch, input_dim)))
+        fwd_bwd_outputs, learner_output_collection = F(
+            learner,
+            method="forward_and_backward",
+            state=learner_state,
+            is_training=True,
+            prng_key=learner_key,
+            inputs=dict(
+                fn=_forward,
+                opt_params=opt_params,
+                inputs=dict(
+                    input_batch=input_batch,
+                    forward_key=fwd_key,
+                ),
+            ),
+        )
+        forward_outputs: ForwardOutputs = fwd_bwd_outputs.forward_outputs
+        updated_model_params = fwd_bwd_outputs.backward_outputs.updated_params
+        learner_state = learner_output_collection.state_updates
+        self.assertGreater(forward_outputs.loss, 0.0)
+        self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
+        # The updated params and the optimizer states have the same tree structure.
+        opt_state_leaf_fn = lambda x: isinstance(x, (Tensor, optax.MaskedNode))
+        self.assertNestedEqual(
+            jax.tree_util.tree_structure(updated_model_params),
+            jax.tree_util.tree_structure(
+                learner_state["encoder"]["optimizer"][0].trace, is_leaf=opt_state_leaf_fn
+            ),
+        )
+        self.assertNestedEqual(
+            jax.tree_util.tree_structure(updated_model_params),
+            jax.tree_util.tree_structure(
+                learner_state["decoder"]["optimizer"][1].mu, is_leaf=opt_state_leaf_fn
+            ),
+        )
+
     @parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
     # pylint: disable-next=too-many-statements
     def test_learner(self, ema_decay: Optional[float], method: str):