diff --git a/t5x/examples/scalable_t5/layers.py b/t5x/examples/scalable_t5/layers.py
index 969f158b0..91d9c98b4 100644
--- a/t5x/examples/scalable_t5/layers.py
+++ b/t5x/examples/scalable_t5/layers.py
@@ -52,64 +52,7 @@
     1.0, 'fan_in', 'normal', out_axis=0
 )
-
-# ------------------------------------------------------------------------------
-# Temporary inlined JAX N-d initializer code
-# TODO(levskaya): remove once new JAX release is out.
-# ------------------------------------------------------------------------------
-def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
-  """Inlined JAX `nn.initializer._compute_fans`."""
-  if isinstance(in_axis, int):
-    in_size = shape[in_axis]
-  else:
-    in_size = int(np.prod([shape[i] for i in in_axis]))
-  if isinstance(out_axis, int):
-    out_size = shape[out_axis]
-  else:
-    out_size = int(np.prod([shape[i] for i in out_axis]))
-  receptive_field_size = shape.total / in_size / out_size
-  fan_in = in_size * receptive_field_size
-  fan_out = out_size * receptive_field_size
-  return fan_in, fan_out
-
-
-def variance_scaling(
-    scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_
-):
-  """Inlined JAX `nn.initializer.variance_scaling`."""
-
-  def init(key, shape, dtype=dtype):
-    dtype = jax.dtypes.canonicalize_dtype(dtype)
-    shape = jax.core.as_named_shape(shape)
-    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-    if mode == 'fan_in':
-      denominator = fan_in
-    elif mode == 'fan_out':
-      denominator = fan_out
-    elif mode == 'fan_avg':
-      denominator = (fan_in + fan_out) / 2
-    else:
-      raise ValueError(
-          'invalid mode for variance scaling initializer: {}'.format(mode)
-      )
-    variance = jnp.array(scale / denominator, dtype=dtype)
-
-    if distribution == 'truncated_normal':
-      # constant is stddev of standard normal truncated to (-2, 2)
-      stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
-      return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
-    elif distribution == 'normal':
-      return random.normal(key, shape, dtype) * jnp.sqrt(variance)
-    elif distribution == 'uniform':
-      return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
-    else:
-      raise ValueError(
-          'invalid distribution for variance scaling initializer: {}'.format(
-              distribution
-          )
-      )
-
-  return init
+variance_scaling = nn.initializers.variance_scaling
 
 
 # ------------------------------------------------------------------------------
@@ -420,7 +363,7 @@ def __call__(
     return out
 
 
-def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
+def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
   # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
   return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
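
Note on the replacement: the inlined `_compute_fans`/`variance_scaling` pair was a
stopgap copy of the JAX N-d initializer, and this patch swaps it for the upstream
implementation re-exported by Flax. A minimal sketch of the aliased API, assuming
`flax.linen` is imported as `nn`; the (128, 64) embedding shape is hypothetical,
chosen only for illustration:

import jax
import flax.linen as nn

# Same call pattern as the deleted inline helper: scale, mode, distribution,
# plus in_axis/out_axis controlling the N-d fan computation.
embed_init = nn.initializers.variance_scaling(
    1.0, 'fan_in', 'normal', out_axis=0
)

# The returned initializer is invoked as init(key, shape[, dtype]).
key = jax.random.PRNGKey(0)
table = embed_init(key, (128, 64))  # hypothetical (vocab, features) table
print(table.shape)  # (128, 64)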