
SGDW doesn't work #43

Closed
MarcelSimon opened this issue May 30, 2020 · 1 comment

@MarcelSimon

Hi!
Thanks for all your effort. This code really helps when implementing a custom optimizer.

There seems to be an issue with SGDW. The sample code from the README works fine with AdamW, but crashes when using SGDW:

import os; os.environ["TF_KERAS"] = '1'  # select the tf.keras implementation of keras_adamw
import numpy as np
from tensorflow.keras.layers import Input, Dense, LSTM
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l1, l2, l1_l2
import keras_adamw

ipt   = Input(shape=(120, 4))
x     = LSTM(60, activation='relu', name='lstm_1',
             kernel_regularizer=l1(1e-4), recurrent_regularizer=l2(2e-4))(ipt)
out   = Dense(1, activation='sigmoid', kernel_regularizer=l1_l2(1e-4, 2e-4))(x)
model = Model(ipt, out)

lr_multipliers = {'lstm_1': 0.5}

optimizer = keras_adamw.SGDW(lr=1e-4, model=model, lr_multipliers=lr_multipliers)
model.compile(optimizer, loss='binary_crossentropy')

for epoch in range(3):
    for iteration in range(24):
        x = np.random.rand(10, 120, 4) # dummy data
        y = np.random.randint(0, 2, (10, 1)) # dummy labels
        loss = model.train_on_batch(x, y)
        print("Iter {} loss: {}".format(iteration + 1, "%.3f" % loss))
    print("EPOCH {} COMPLETED\n".format(epoch + 1))

which raises:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-8-d2bca98bfb4f> in <module>()
     21         x = np.random.rand(10, 120, 4) # dummy data
     22         y = np.random.randint(0, 2, (10, 1)) # dummy labels
---> 23         loss = model.train_on_batch(x, y)
     24         print("Iter {} loss: {}".format(iteration + 1, "%.3f" % loss))
     25     print("EPOCH {} COMPLETED\n".format(epoch + 1))

8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

ValueError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
        outputs = self.distribute_strategy.run(
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:951 run  **
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:541 train_step  **
        self.trainable_variables)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:1814 _minimize
        optimizer.apply_gradients(zip(gradients, trainable_variables))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:508 apply_gradients
        "name": name,
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2420 merge_call
        return self._merge_call(merge_fn, args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2427 _merge_call
        return merge_fn(self._strategy, *args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:592 _distributed_apply  **
        var, apply_grad_to_update_var, args=(grad,), group=False))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2013 update
        return self._update(var, fn, args, kwargs, group)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2659 _update
        return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2665 _update_non_slot
        result = fn(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:567 apply_grad_to_update_var  **
        update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
    /usr/local/lib/python3.6/dist-packages/keras_adamw/optimizers_v2.py:672 _resource_apply_dense
        m = K.zeros(K.int_shape(var))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:1333 zeros
        return variable(v, dtype=dtype, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:845 variable
        constraint=constraint)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:261 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:255 _variable_v2_call
        shape=shape)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:66 getter
        return captured_getter(captured_previous, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2562 creator
        return next_creator(**kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:66 getter
        return captured_getter(captured_previous, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2562 creator
        return next_creator(**kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:66 getter
        return captured_getter(captured_previous, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2562 creator
        return next_creator(**kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:66 getter
        return captured_getter(captured_previous, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2562 creator
        return next_creator(**kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:66 getter
        return captured_getter(captured_previous, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:511 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

I'm using tensorflow-gpu version 2.2 and tf.keras.
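
If I read the traceback correctly, the crux is the last line: train_on_batch runs the update step inside a tf.function, and _resource_apply_dense calling K.zeros(...) creates a fresh variable on every traced call, which TF2 forbids. A standalone sketch of that constraint (hypothetical names, nothing from keras_adamw):

import tensorflow as tf

@tf.function
def step():
    # A tf.function may only create variables during its first trace;
    # re-creating one on a later call raises the ValueError above.
    m = tf.Variable(tf.zeros((2, 2)))
    return m + 1.0

try:
    step()  # depending on the TF version, even the first call may fail
    step()
except ValueError as e:
    print(e)  # "tf.function-decorated function tried to create variables
              #  on non-first call."

# The usual fix: create the state once, outside the traced function,
# and only read/assign it inside.
m = tf.Variable(tf.zeros((2, 2)))

@tf.function
def fixed_step():
    return m.assign_add(tf.ones((2, 2)))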

@OverLordGoldDragon (Owner)

@MarcelSimon Glad you found it helpful, and thanks for reporting. This is indeed a bug in the untested case of SGDW(momentum=0), and it is now fixed in v1.31. By the way, with momentum=0, SGDW should be equivalent to SGD in terms of weight decay (feel free to compare); an explanation here.
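
A quick numerical check of that equivalence (toy NumPy with made-up values, not the library's implementation):

import numpy as np

lr, wd, momentum = 1e-4, 1e-2, 0.0  # hypothetical hyperparameters
w = np.random.randn(5)              # toy weights
g = np.random.randn(5)              # toy gradient
v = np.zeros_like(w)                # momentum buffer

# SGDW-style step: decoupled weight decay (Loshchilov & Hutter) is applied
# directly to the weights, separately from the gradient step.
v = momentum * v + lr * g
w_sgdw = w - v - wd * w

# Plain SGD step with the same decoupled decay term:
w_sgd = w - lr * g - wd * w

assert np.allclose(w_sgdw, w_sgd)   # identical when momentum == 0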

For learning, I also recommend these Q&As: _get_hyper vs. _set_hyper, and others.
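
As a rough illustration of the pattern those Q&As cover, here is a minimal toy optimizer (TF 2.x OptimizerV2 API; the names are made up, and this is not keras_adamw's code):

import tensorflow as tf

class TinySGD(tf.keras.optimizers.Optimizer):
    """Toy optimizer illustrating _set_hyper/_get_hyper (sketch only)."""

    def __init__(self, learning_rate=0.01, name="TinySGD", **kwargs):
        super().__init__(name, **kwargs)
        # _set_hyper registers the value so it can be a float, a tensor, or
        # a schedule, and is picked up by the serialization helpers.
        self._set_hyper("learning_rate", learning_rate)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        # _get_hyper returns the hyperparameter as a tensor in var's dtype.
        lr = self._get_hyper("learning_rate", var.dtype.base_dtype)
        return var.assign_sub(lr * grad)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        raise NotImplementedError("sparse updates omitted in this sketch")

    def get_config(self):
        config = super().get_config()
        config["learning_rate"] = self._serialize_hyperparameter("learning_rate")
        return config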

Feel free to reopen if problems persist.
