Cumulative Update for Keras 2.2.0 and some fixes and Explanations in the README.md #103

Merged · 18 commits · Jul 1, 2018
38 changes: 28 additions & 10 deletions README.md
@@ -24,8 +24,8 @@ from keras.layers import *
from keras.models import *
from recurrentshop import *

x_t = Input(5,)) # The input to the RNN at time t
h_tm1 = Input((10,)) # Previous hidden state
x_t = Input(shape=(5,)) # The input to the RNN at time t
h_tm1 = Input(shape=(10,)) # Previous hidden state

# Compute new hidden state
h_t = add([Dense(10)(x_t), Dense(10, use_bias=False)(h_tm1)])
@@ -34,17 +34,35 @@ h_t = add([Dense(10)(x_t), Dense(10, use_bias=False)(h_tm1)])
h_t = Activation('tanh')(h_t)

# Build the RNN
rnn = RecurrentModel(input=x_t, initial_states=[h_tm1], output=h_t, output_states=[h_t])

# rnn is a standard Keras `Recurrent` instance. RecuurentModel also accepts arguments such as unroll, return_sequences etc
# RecurrentModel is a standard Keras `Recurrent` layer.
# RecurrentModel also accepts arguments such as unroll, return_sequences etc
rnn = RecurrentModel(input=x_t, initial_states=[h_tm1], output=h_t, final_states=[h_t])
# return_sequences is False by default
# so it only returns the last h_t state

# Build a Keras Model using our RNN layer
# input dimensions are (Time_steps, Depth)
x = Input(shape=(7,5))
y = rnn(x)
model = Model(x, y)

# Run the RNN over a random sequence
# Don't forget the batch shape when calling the model!
out = model.predict(np.random.random((1, 7, 5)))
print(out.shape)#->(1,10)

x = Input((7,5))
y = rnn(x)

model = Model(x, y)
model.predict(np.random.random((7, 5)))
# to get one output per input sequence element, set return_sequences=True
rnn2 = RecurrentModel(input=x_t, initial_states=[h_tm1], output=h_t, final_states=[h_t],return_sequences=True)

# Time_steps can also be None to allow variable Sequence Length
# Note that this is not compatible with unroll=True
x = Input(shape=(None ,5))
y = rnn2(x)
model2 = Model(x, y)

out2 = model2.predict(np.random.random((1, 7, 5)))
print(out2.shape)#->(1,7,10)

```
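
For reference, a self-contained version of the updated snippet above, with the `numpy` and Keras imports it assumes, looks roughly like this (a sketch mirroring the README, not additional PR content):

```python
import numpy as np
from keras.layers import Input, Dense, Activation, add
from keras.models import Model
from recurrentshop import RecurrentModel

# One step of the RNN: compute the new hidden state from the current input
# and the previous hidden state.
x_t = Input(shape=(5,))
h_tm1 = Input(shape=(10,))
h_t = Activation('tanh')(add([Dense(10)(x_t), Dense(10, use_bias=False)(h_tm1)]))

# Wrap the step as a recurrent layer and run it over a batch of sequences.
rnn = RecurrentModel(input=x_t, initial_states=[h_tm1], output=h_t, final_states=[h_t])
x = Input(shape=(7, 5))
model = Model(x, rnn(x))
print(model.predict(np.random.random((1, 7, 5))).shape)  # (1, 10)
```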

@@ -129,7 +147,7 @@ See docs/ directory for more features.
# Installation

```shell
git clone https://www.github.com/datalogai/recurrentshop.git
git clone https://www.github.com/farizrahman4u/recurrentshop.git
cd recurrentshop
python setup.py install
```
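
A quick post-install check (a sketch; it just confirms the package imports and that the Keras version matches what the updated requirements.txt below expects):

```python
# Post-install smoke test (sketch): both imports should succeed, and Keras
# should report a version of at least 2.2.0 to match requirements.txt.
import keras
from recurrentshop import RecurrentModel

print(keras.__version__)
print(RecurrentModel)
```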
27 changes: 19 additions & 8 deletions recurrentshop/engine.py
@@ -3,10 +3,13 @@
from keras import initializers
from .backend import rnn, learning_phase_scope
from .generic_utils import serialize_function, deserialize_function
from keras.engine.topology import Node, _collect_previous_mask, _collect_input_shape
from keras.engine.base_layer import Node,_collect_previous_mask, _collect_input_shape
import inspect


if K.backend() == 'tensorflow':
import tensorflow as tf

def _to_list(x):
if type(x) is not list:
x = [x]
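
The import move above is needed because Keras 2.2 split `keras.engine.topology` into new modules such as `keras.engine.base_layer`. A version-tolerant alternative (a sketch of one option, not what this PR does) would be a guarded import:

```python
# Sketch: fall back to the pre-2.2 module path if the new one is unavailable.
try:
    from keras.engine.base_layer import Node, _collect_previous_mask, _collect_input_shape  # Keras >= 2.2
except ImportError:
    from keras.engine.topology import Node, _collect_previous_mask, _collect_input_shape   # Keras < 2.2
```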
@@ -549,9 +552,12 @@ def call(self, inputs, initial_state=None, initial_readout=None, ground_truth=No
if self.teacher_force:
if ground_truth is None or self._is_optional_input_placeholder(ground_truth):
raise Exception('ground_truth must be provided for RecurrentModel with teacher_force=True.')
# counter = K.zeros((1,), dtype='int32')
counter = K.zeros((1,))
counter = K.cast(counter, 'int32')
if K.backend() == 'tensorflow':
with tf.control_dependencies(None):
counter = K.zeros((1,))
else:
counter = K.zeros((1,))
counter = K.cast(counter, 'int32')
initial_states.insert(-1, counter)
initial_states[-2]
initial_states.insert(-1, ground_truth)
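
The `tf.control_dependencies(None)` wrapper above clears TensorFlow's current control-dependency context, so the counter tensor is created independent of whatever scope the layer is being built in. A minimal standalone sketch of that behaviour (assuming TensorFlow 1.x graph mode, as used with Keras 2.2):

```python
import tensorflow as tf  # assumes TensorFlow 1.x graph mode

with tf.Graph().as_default():
    gate = tf.constant(1.0)
    with tf.control_dependencies([gate]):
        # Ops created here would normally inherit `gate` as a control dependency.
        with tf.control_dependencies(None):
            # Passing None clears the control-dependency stack, so this tensor
            # is created free of the enclosing context -- the same trick used
            # for the teacher-forcing counter above.
            counter = tf.zeros((1,))
    print(counter)
```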
@@ -654,8 +660,13 @@ def step(self, inputs, states):
ground_truth = states.pop()
assert K.ndim(ground_truth) == 3, K.ndim(ground_truth)
counter = states.pop()
zero = K.cast(K.zeros((1,))[0], 'int32')
one = K.cast(K.zeros((1,))[0], 'int32')
if K.backend() == 'tensorflow':
with tf.control_dependencies(None):
zero = K.cast(K.zeros((1,))[0], 'int32')
one = K.cast(K.zeros((1,))[0], 'int32')
else:
zero = K.cast(K.zeros((1,))[0], 'int32')
one = K.cast(K.zeros((1,))[0], 'int32')
slices = [slice(None), counter[0] - K.switch(counter[0], one, zero)] + [slice(None)] * (K.ndim(ground_truth) - 2)
ground_truth_slice = ground_truth[slices]
readout = K.in_train_phase(K.switch(counter[0], ground_truth_slice, readout), readout)
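
To see what the slicing and switching above accomplish: during training (and after the first step), the readout fed back into the cell is replaced by the previous ground-truth step, while at inference time the model's own readout is kept. A small standalone sketch of that switch with hypothetical tensors:

```python
import numpy as np
from keras import backend as K

readout = K.constant(np.zeros((1, 10)))            # the model's own previous output
ground_truth_slice = K.constant(np.ones((1, 10)))  # target from the previous step
counter = K.constant([3], dtype='int32')           # current time step

# Use the ground truth only when counter > 0 and only in the training phase;
# otherwise fall back to the model's readout.
readout = K.in_train_phase(K.switch(counter[0], ground_truth_slice, readout),
                           readout)
```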
@@ -839,13 +850,13 @@ def _get_optional_input_placeholder(self, name=None, num=1):
self._optional_input_placeholders[name] = self._get_optional_input_placeholder()
return self._optional_input_placeholders[name]
if num == 1:
optional_input_placeholder = _to_list(_OptionalInputPlaceHolder().inbound_nodes[0].output_tensors)[0]
optional_input_placeholder = _to_list(_OptionalInputPlaceHolder()._inbound_nodes[0].output_tensors)[0]
assert self._is_optional_input_placeholder(optional_input_placeholder)
return optional_input_placeholder
else:
y = []
for _ in range(num):
optional_input_placeholder = _to_list(_OptionalInputPlaceHolder().inbound_nodes[0].output_tensors)[0]
optional_input_placeholder = _to_list(_OptionalInputPlaceHolder()._inbound_nodes[0].output_tensors)[0]
assert self._is_optional_input_placeholder(optional_input_placeholder)
y.append(optional_input_placeholder)
return y
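
The underscore change above tracks Keras 2.2, which renamed the layer attribute `inbound_nodes` to the private `_inbound_nodes`. If both Keras versions had to be supported, a small compatibility helper (hypothetical, not part of this PR) could paper over the rename:

```python
def _layer_inbound_nodes(layer):
    # Hypothetical helper: use the Keras >= 2.2 private attribute when present,
    # otherwise the pre-2.2 public name.
    if hasattr(layer, '_inbound_nodes'):
        return layer._inbound_nodes
    return layer.inbound_nodes
```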
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,3 +1,3 @@
keras>=2.0.0
keras>=2.2.0
numpy>=1.8.0
theano>=0.8.0
tensorflow>=1.9