Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solve get_source_inputs related errors #75

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion classification_models/keras.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import keras
# import keras
from tensorflow import keras
from .models_factory import ModelsFactory


Expand Down
21 changes: 16 additions & 5 deletions classification_models/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._common_blocks import ChannelSE
from .. import get_submodules_from_kwargs
from ..weights import load_model_weights
from tensorflow.keras.utils import get_source_inputs

backend = None
layers = None
Expand Down Expand Up @@ -209,10 +210,16 @@ def ResNet(model_params, input_shape=None, input_tensor=None, include_top=True,
if input_tensor is None:
img_input = layers.Input(shape=input_shape, name='data')
else:
if not backend.is_keras_tensor(input_tensor):
img_input = layers.Input(tensor=input_tensor, shape=input_shape)
else:
img_input = input_tensor
""" Commented to solve following error:
ValueError: Unexpectedly found an instance of type
`<class 'tensorflow.python.keras.engine.keras_tensor.KerasTensor'>`.
Expected a symbolic tensor instance.
"""
# if not backend.is_keras_tensor(input_tensor):
# img_input = layers.Input(tensor=input_tensor, shape=input_shape)
# else:
# img_input = input_tensor
img_input = input_tensor

# choose residual block type
ResidualBlock = model_params.residual_block
Expand Down Expand Up @@ -266,7 +273,11 @@ def ResNet(model_params, input_shape=None, input_tensor=None, include_top=True,

# Ensure that the model takes into account any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = keras_utils.get_source_inputs(input_tensor)
""" Modified to solve following error:
module 'keras.utils' has no attribute 'get_source_inputs'
"""
# inputs = keras_utils.get_source_inputs(input_tensor)
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input

Expand Down
7 changes: 6 additions & 1 deletion classification_models/weights.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import get_submodules_from_kwargs
from tensorflow.keras.utils import get_file

__all__ = ['load_model_weights']

Expand All @@ -22,7 +23,11 @@ def load_model_weights(model, model_name, dataset, classes, include_top, **kwarg
raise ValueError('If using `weights` and `include_top`'
' as true, `classes` should be {}'.format(weights['classes']))

weights_path = keras_utils.get_file(
""" Modified to solve following error:
module 'keras.utils' has no attribute 'get_file'
"""
# weights_path = keras_utils.get_file(
weights_path = get_file(
weights['name'],
weights['url'],
cache_subdir='models',
Expand Down