Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ValueError while running policy gradient code, when run on colab #23

Open
Abhishek-Rajendra opened this issue Apr 28, 2020 · 0 comments

Comments

@Abhishek-Rajendra
Copy link

ValueError Traceback (most recent call last)
in ()
56
57 if done:
---> 58 loss = update_network(network, rewards, states, actions, num_actions)
59 tot_reward = sum(rewards)
60 print(f"Episode: {episode}, Reward: {tot_reward}, avg loss: {loss:.5f}")

10 frames
in update_network(network, rewards, states, actions, num_actions)
37 discounted_rewards /= np.std(discounted_rewards)
38 states = np.vstack(states)
---> 39 loss = network.train_on_batch(states, discounted_rewards)
40 return loss
41

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics, return_dict)
1349 class_weight)
1350 train_function = self.make_train_function()
-> 1351 logs = train_function(iterator)
1352
1353 if reset_metrics:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in call(self, *args, **kwds)
578 xla_context.Exit()
579 else:
--> 580 result = self._call(*args, **kwds)
581
582 if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
625 # This is the first call of call, so we have to initialize.
626 initializers = []
--> 627 self._initialize(args, kwds, add_initializers_to=initializers)
628 finally:
629 # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
504 self._concrete_stateful_fn = (
505 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 506 *args, **kwds))
507
508 def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
2444 args, kwargs = None, None
2445 with self._lock:
-> 2446 graph_function, _, _ = self._maybe_define_function(args, kwargs)
2447 return graph_function
2448

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2775
2776 self._function_cache.missed.add(call_context_key)
-> 2777 graph_function = self._create_graph_function(args, kwargs)
2778 self._function_cache.primary[cache_key] = graph_function
2779 return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2665 arg_names=arg_names,
2666 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667 capture_by_value=self._capture_by_value),
2668 self._function_attributes,
2669 # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
979 _, original_func = tf_decorator.unwrap(python_func)
980
--> 981 func_outputs = python_func(*func_args, **func_kwargs)
982
983 # invariant: func_outputs contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
439 # wrapped allows AutoGraph to swap in a converted function. We give
440 # the function a weak reference to itself to avoid a reference cycle.
--> 441 return weak_wrapped_fn().wrapped(*args, **kwds)
442 weak_wrapped_fn = weakref.ref(wrapped_fn)
443

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise

ValueError: in user code:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
    outputs = self.distribute_strategy.run(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:951 run  **
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
    return fn(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:533 train_step  **
    y, y_pred, sample_weight, regularization_losses=self.losses)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/compile_utils.py:204 __call__
    loss_value = loss_obj(y_t, y_p, sample_weight=sw)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py:143 __call__
    losses = self.call(y_true, y_pred)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py:246 call
    return self.fn(y_true, y_pred, **self._fn_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py:1527 categorical_crossentropy
    return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:4561 categorical_crossentropy
    target.shape.assert_is_compatible_with(output.shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_shape.py:1117 assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))

ValueError: Shapes (24, 1) and (24, 2) are incompatible
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant