Skip to content

Commit

Permalink
Cumulative Update for Keras 2.2.0 and some fixes and Explanations in …
Browse files Browse the repository at this point in the history
…the README.md (#103)

* attempt 1

* Updated engine.py

Fixed the bug!

* Update Readme.md to correct some mistakes.

* Update README.md

* Fix Keras 2.2?

* Update README.md

* Update requirements.txt

* Put back Fariz's URL into the installation URL

* Update engine.py

* more fix

* Spaces around = for prettiness
  • Loading branch information
Pyrestone authored and farizrahman4u committed Jul 1, 2018
1 parent b24f453 commit 43635f3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 20 deletions.
38 changes: 28 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ from keras.layers import *
from keras.models import *
from recurrentshop import *

x_t = Input(5,)) # The input to the RNN at time t
h_tm1 = Input((10,)) # Previous hidden state
x_t = Input(shape=(5,)) # The input to the RNN at time t
h_tm1 = Input(shape=(10,)) # Previous hidden state

# Compute new hidden state
h_t = add([Dense(10)(x_t), Dense(10, use_bias=False)(h_tm1)])
Expand All @@ -34,17 +34,35 @@ h_t = add([Dense(10)(x_t), Dense(10, use_bias=False)(h_tm1)])
h_t = Activation('tanh')(h_t)

# Build the RNN
rnn = RecurrentModel(input=x_t, initial_states=[h_tm1], output=h_t, output_states=[h_t])

# rnn is a standard Keras `Recurrent` instance. RecurrentModel also accepts arguments such as unroll, return_sequences etc
# RecurrentModel is a standard Keras `Recurrent` layer.
# RecurrentModel also accepts arguments such as unroll, return_sequences etc
rnn = RecurrentModel(input=x_t, initial_states=[h_tm1], output=h_t, final_states=[h_t])
# return_sequences is False by default
# so it only returns the last h_t state

# Build a Keras Model using our RNN layer
# input dimensions are (Time_steps, Depth)
x = Input(shape=(7,5))
y = rnn(x)
model = Model(x, y)

# Run the RNN over a random sequence
# Don't forget the batch shape when calling the model!
out = model.predict(np.random.random((1, 7, 5)))
print(out.shape)#->(1,10)

x = Input((7,5))
y = rnn(x)

model = Model(x, y)
model.predict(np.random.random((7, 5)))
# to get one output per input sequence element, set return_sequences=True
rnn2 = RecurrentModel(input=x_t, initial_states=[h_tm1], output=h_t, final_states=[h_t],return_sequences=True)

# Time_steps can also be None to allow variable Sequence Length
# Note that this is not compatible with unroll=True
x = Input(shape=(None ,5))
y = rnn2(x)
model2 = Model(x, y)

out2 = model2.predict(np.random.random((1, 7, 5)))
print(out2.shape)#->(1,7,10)

```

Expand Down Expand Up @@ -129,7 +147,7 @@ See docs/ directory for more features.
# Installation

```shell
git clone https://www.github.com/datalogai/recurrentshop.git
git clone https://www.github.com/farizrahman4u/recurrentshop.git
cd recurrentshop
python setup.py install
```
Expand Down
27 changes: 19 additions & 8 deletions recurrentshop/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from keras import initializers
from .backend import rnn, learning_phase_scope
from .generic_utils import serialize_function, deserialize_function
from keras.engine.topology import Node, _collect_previous_mask, _collect_input_shape
from keras.engine.base_layer import Node,_collect_previous_mask, _collect_input_shape
import inspect


if K.backend() == 'tensorflow':
import tensorflow as tf

def _to_list(x):
if type(x) is not list:
x = [x]
Expand Down Expand Up @@ -549,9 +552,12 @@ def call(self, inputs, initial_state=None, initial_readout=None, ground_truth=No
if self.teacher_force:
if ground_truth is None or self._is_optional_input_placeholder(ground_truth):
raise Exception('ground_truth must be provided for RecurrentModel with teacher_force=True.')
# counter = K.zeros((1,), dtype='int32')
counter = K.zeros((1,))
counter = K.cast(counter, 'int32')
if K.backend() == 'tensorflow':
with tf.control_dependencies(None):
counter = K.zeros((1,))
else:
counter = K.zeros((1,))
counter = K.cast(counter, 'int32')
initial_states.insert(-1, counter)
initial_states[-2]
initial_states.insert(-1, ground_truth)
Expand Down Expand Up @@ -654,8 +660,13 @@ def step(self, inputs, states):
ground_truth = states.pop()
assert K.ndim(ground_truth) == 3, K.ndim(ground_truth)
counter = states.pop()
zero = K.cast(K.zeros((1,))[0], 'int32')
one = K.cast(K.zeros((1,))[0], 'int32')
if K.backend() == 'tensorflow':
with tf.control_dependencies(None):
zero = K.cast(K.zeros((1,))[0], 'int32')
one = K.cast(K.zeros((1,))[0], 'int32')
else:
zero = K.cast(K.zeros((1,))[0], 'int32')
one = K.cast(K.zeros((1,))[0], 'int32')
slices = [slice(None), counter[0] - K.switch(counter[0], one, zero)] + [slice(None)] * (K.ndim(ground_truth) - 2)
ground_truth_slice = ground_truth[slices]
readout = K.in_train_phase(K.switch(counter[0], ground_truth_slice, readout), readout)
Expand Down Expand Up @@ -839,13 +850,13 @@ def _get_optional_input_placeholder(self, name=None, num=1):
self._optional_input_placeholders[name] = self._get_optional_input_placeholder()
return self._optional_input_placeholders[name]
if num == 1:
optional_input_placeholder = _to_list(_OptionalInputPlaceHolder().inbound_nodes[0].output_tensors)[0]
optional_input_placeholder = _to_list(_OptionalInputPlaceHolder()._inbound_nodes[0].output_tensors)[0]
assert self._is_optional_input_placeholder(optional_input_placeholder)
return optional_input_placeholder
else:
y = []
for _ in range(num):
optional_input_placeholder = _to_list(_OptionalInputPlaceHolder().inbound_nodes[0].output_tensors)[0]
optional_input_placeholder = _to_list(_OptionalInputPlaceHolder()._inbound_nodes[0].output_tensors)[0]
assert self._is_optional_input_placeholder(optional_input_placeholder)
y.append(optional_input_placeholder)
return y
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
keras>=2.0.0
keras>=2.2.0
numpy>=1.8.0
theano>=0.8.0
tensorflow>=1.9

0 comments on commit 43635f3

Please sign in to comment.