Learner: add new unittests using Model.
These tests are similar to how trainer.py uses Learner (a condensed sketch of that usage pattern follows the file summary below).
ds-hwang committed Dec 19, 2024
1 parent a15a3bc commit 314e812
Showing 2 changed files with 212 additions and 7 deletions.
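
For readers skimming the diff, here is a condensed sketch of the usage pattern the new tests exercise (and that trainer.py follows): wrap the model params as OptParams, init the learner state, then drive Learner.forward_and_backward through the functional F API. It only rearranges names that appear in the added test code below; the run_one_step wrapper itself is illustrative, not part of the commit, and it assumes the same imports as learner_test.py.

def run_one_step(model, optimizer_cfg, input_batch, prng_key):
    """Illustrative wrapper; assumes the imports used by learner_test.py below."""
    init_key, fwd_key, learner_key = jax.random.split(prng_key, num=3)

    # Build a Learner and wrap the model params as OptParams.
    learner_cfg = Learner.default_config().set(name="learner", optimizer=optimizer_cfg)
    learner = learner_cfg.instantiate(parent=None)
    params = model.initialize_parameters_recursively(init_key)
    specs = model.create_parameter_specs_recursively()
    opt_params = jax.tree.map(
        lambda param, spec: OptParam(
            value=param,
            factorization_spec=spec.factorization if spec else None,
            weight_decay_scale=spec.weight_decay_scale if spec else 1.0,
        ),
        params,
        specs,
    )
    learner_state = learner.init(model_params=opt_params)

    # The forward fn runs the model inside a child context so that its state,
    # prng key, and output collection are scoped to the "model" child.
    def _forward(*, model_params, inputs):
        output_collection = new_output_collection()
        with child_context(
            "model",
            module=model,
            state=model_params,
            prng_key=inputs["forward_key"],
            output_collection=output_collection,
        ):
            loss, aux = model(input_batch=inputs["input_batch"])
        return ForwardOutputs(loss=loss, aux=aux, output_collection=output_collection)

    # One forward+backward step through the functional API, as trainer.py does.
    fwd_bwd_outputs, learner_output_collection = F(
        learner,
        method="forward_and_backward",
        state=learner_state,
        is_training=True,
        prng_key=learner_key,
        inputs=dict(
            fn=_forward,
            opt_params=opt_params,
            inputs=dict(input_batch=input_batch, forward_key=fwd_key),
        ),
    )
    return fwd_bwd_outputs.backward_outputs.updated_params, learner_output_collection.state_updates
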
9 changes: 3 additions & 6 deletions axlearn/common/learner.py
@@ -526,15 +526,12 @@ def should_apply(tree: Nested[Any]) -> Nested[bool]:
sub_learner_updates = sub_learner_updates.mask(
# pylint: disable-next=cell-var-from-loop
lambda _: should_apply(updates.opt_params),
fields=(
"opt_params",
"delta_updates",
),
fields=("opt_params", "delta_updates"),
)
sub_learner_updated_model_params = getattr(self, name).update(sub_learner_updates)
updated_model_params = jax.tree.map(
lambda apply, new_v, old_v: new_v if apply else old_v,
should_apply(updates.param_values()),
should_apply(updated_model_params),
sub_learner_updated_model_params,
updated_model_params,
)
@@ -712,7 +709,7 @@ def _value_and_grad(

split_params = split_params_fn(opt_params)
model_params_grad, model_params_nograd = jax.tree.map(lambda p: p.value, split_params)
(_, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
(unused_loss, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
model_params_grad, inputs=(model_params_nograd, inputs)
)
return Updates(
210 changes: 209 additions & 1 deletion axlearn/common/learner_test.py
@@ -15,8 +15,10 @@
import axlearn.common.update_transformation_test
from axlearn.common import schedule
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.base_model import BaseModel
from axlearn.common.config import REQUIRED, Required, config_class, config_for_function
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.layers import Linear
from axlearn.common.learner import (
CompositeLearner,
Learner,
@@ -28,7 +30,7 @@
should_update_with_optimizers,
)
from axlearn.common.metrics import MetricAccumulator, WeightedScalar
from axlearn.common.module import OutputCollection
from axlearn.common.module import OutputCollection, child_context
from axlearn.common.module import functional as F
from axlearn.common.module import new_output_collection
from axlearn.common.optimizer_base import OptParam, OptStateSpec
@@ -50,6 +52,7 @@
)
from axlearn.common.utils import (
Nested,
NestedTensor,
PartitionSpec,
Tensor,
VDict,
@@ -59,7 +62,113 @@
)


class TestModel(BaseModel):
"""A simple model for test."""

@config_class
class Config(BaseModel.Config):
dim: int = 4

def __init__(self, cfg, *, parent):
super().__init__(cfg, parent=parent)
enc_cfg = Linear.default_config().set(
input_dim=cfg.dim,
output_dim=cfg.dim,
)
self._add_child("encoder", enc_cfg)

dec_cfg = Linear.default_config().set(
input_dim=cfg.dim,
output_dim=1,
)
self._add_child("decoder", dec_cfg)

def forward(self, input_batch: NestedTensor) -> tuple[Tensor, NestedTensor]:
x = self.encoder(input_batch["x"])
y = self.decoder(x)
loss = jnp.mean(y**2)
aux = dict(discriminator_loss=jnp.mean(jnp.abs(y)))
return loss, aux


class LearnerTest(TestCase):
@parameterized.parameters(None, 0.999)
def test_forward_and_backward(self, ema_decay):
"""Demonstrates how API users should use the API while ensuring that it works correctly."""
# Init a learner.
learning_rate = config_for_function(schedule.stepwise).set(
sub=[0.1, 0.01, 0.001],
start_step=[100, 200],
)
optimizer_cfg = config_for_function(adam_optimizer).set(
learning_rate=learning_rate, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
)
cfg = Learner.default_config().set(name="test", optimizer=optimizer_cfg)
cfg.ema.decay = ema_decay
learner: Learner = cfg.instantiate(parent=None)

# Init a model.
input_dim = 4
model_cfg = TestModel.default_config().set(name="test", dim=input_dim)
model = model_cfg.instantiate(parent=None)
prng_key = jax.random.PRNGKey(123)
init_key, data_key, fwd_key, learner_key, prng_key = jax.random.split(prng_key, num=5)
params = model.initialize_parameters_recursively(init_key)

# Create model and learner states.
model_param_specs = model.create_parameter_specs_recursively()
opt_params = jax.tree.map(
lambda param, spec: OptParam(
value=param,
factorization_spec=spec.factorization if spec else None,
weight_decay_scale=spec.weight_decay_scale if spec else 1.0,
),
params,
model_param_specs,
)
learner_state = learner.init(model_params=opt_params)

# Forward and backward.
def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutputs:
model_output_collection = new_output_collection()
with child_context(
"model",
module=model,
state=model_params,
prng_key=inputs["forward_key"],
output_collection=model_output_collection,
):
loss, aux = model(input_batch=inputs["input_batch"])
return ForwardOutputs(loss=loss, aux=aux, output_collection=model_output_collection)

batch = 2
input_batch = dict(x=jax.random.uniform(data_key, (batch, input_dim)))
fwd_bwd_outputs, learner_output_collection = F(
learner,
method="forward_and_backward",
state=learner_state,
is_training=True,
prng_key=learner_key,
inputs=dict(
fn=_forward,
opt_params=opt_params,
inputs=dict(
input_batch=input_batch,
forward_key=fwd_key,
),
),
)
forward_outputs: ForwardOutputs = fwd_bwd_outputs.forward_outputs
updated_model_params = fwd_bwd_outputs.backward_outputs.updated_params
learner_state = learner_output_collection.state_updates
self.assertGreater(forward_outputs.loss, 0.0)
self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
# The updated params and the Adam mu states have the same tree structure.
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(learner_state["optimizer"][1].mu),
)

def test_prune_empty_state(self):
state = {
"state": {
@@ -816,6 +925,105 @@ def test__value_and_grad(self):


class CompositeLearnerTest(TestCase):
@parameterized.parameters(None, 0.999)
def test_forward_and_backward(self, ema_decay):
"""Demonstrates how API users should use the API while ensuring that it works correctly."""
# Init a learner.
encoder_lr = 0.1
opt1_cfg = config_for_function(sgd_optimizer).set(
learning_rate=encoder_lr, decouple_weight_decay=True, weight_decay=1.0
)
opt2_cfg = config_for_function(adam_optimizer).set(
learning_rate=0.0, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
)
learner_rules = [(".*encoder.*", "encoder"), (".*decoder.*", "decoder")]

cfg = CompositeLearner.default_config().set(
name="test",
rules=learner_rules,
learners={
"encoder": Learner.default_config().set(
optimizer=opt1_cfg, enable_per_variable_summaries=True
),
"decoder": Learner.default_config().set(
optimizer=opt2_cfg, enable_per_variable_summaries=False
),
},
)
cfg.ema.decay = ema_decay
learner: CompositeLearner = cfg.instantiate(parent=None)

# Init a model.
input_dim = 4
model_cfg = TestModel.default_config().set(name="test", dim=input_dim)
model = model_cfg.instantiate(parent=None)
prng_key = jax.random.PRNGKey(123)
init_key, data_key, fwd_key, learner_key, prng_key = jax.random.split(prng_key, num=5)
params = model.initialize_parameters_recursively(init_key)

# Create model and learner states.
model_param_specs = model.create_parameter_specs_recursively()
opt_params = jax.tree.map(
lambda param, spec: OptParam(
value=param,
factorization_spec=spec.factorization if spec else None,
weight_decay_scale=spec.weight_decay_scale if spec else 1.0,
),
params,
model_param_specs,
)
learner_state = learner.init(model_params=opt_params)

# Forward and backward.
def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutputs:
model_output_collection = new_output_collection()
with child_context(
"model",
module=model,
state=model_params,
prng_key=inputs["forward_key"],
output_collection=model_output_collection,
):
loss, aux = model(input_batch=inputs["input_batch"])
return ForwardOutputs(loss=loss, aux=aux, output_collection=model_output_collection)

batch = 2
input_batch = dict(x=jax.random.uniform(data_key, (batch, input_dim)))
fwd_bwd_outputs, learner_output_collection = F(
learner,
method="forward_and_backward",
state=learner_state,
is_training=True,
prng_key=learner_key,
inputs=dict(
fn=_forward,
opt_params=opt_params,
inputs=dict(
input_batch=input_batch,
forward_key=fwd_key,
),
),
)
forward_outputs: ForwardOutputs = fwd_bwd_outputs.forward_outputs
updated_model_params = fwd_bwd_outputs.backward_outputs.updated_params
learner_state = learner_output_collection.state_updates
self.assertGreater(forward_outputs.loss, 0.0)
self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
# The updated params and the optimizer states have the same tree structure.
opt_state_leaf_fn = lambda x: isinstance(x, (Tensor, optax.MaskedNode))
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(
learner_state["encoder"]["optimizer"][0].trace, is_leaf=opt_state_leaf_fn
),
)
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(
learner_state["decoder"]["optimizer"][1].mu, is_leaf=opt_state_leaf_fn
),
)

@parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
# pylint: disable-next=too-many-statements
def test_learner(self, ema_decay: Optional[float], method: str):
