
TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type float64 of argument 'x'. #592

Open
FederCO23 opened this issue Sep 28, 2024 · 2 comments


Hi Folks,
I'm working through the example notebook binary segmentation (camvid).ipynb, and I've made good progress so far. However, I'm encountering an issue when trying to fit the model.

The error seems to come from a type mismatch between the ground truth (gt) and the predictions (pr) in the IoU metric: gt is float64 while pr is float32. I was wondering whether this is caused by incompatible library versions, i.e. whether I'm using the wrong versions of the dependencies.

Any insights or guidance? Thanks in advance!

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[77], line 13
      1 # # train model
      2 # history = model.fit(
      3 # #history = model.fit_generator(
   (...)
     11 
     12 # Train the model
---> 13 history = model.fit(
     14     train_dataloader,
     15     steps_per_epoch=len(train_dataloader),
     16     epochs=EPOCHS,
     17     callbacks=callbacks,
     18     validation_data=valid_dataloader,
     19     validation_steps=len(valid_dataloader),
     20 )

File E:\Devs\pyEnvs_experiments\segmentations\segmentation-Env\Lib\site-packages\keras\src\utils\traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File E:\Devs\pyEnvs_experiments\segmentations\segmentation-Env\Lib\site-packages\segmentation_models\metrics.py:54, in IOUScore.__call__(self, gt, pr)
     53 def __call__(self, gt, pr):
---> 54     return F.iou_score(
     55         gt,
     56         pr,
     57         class_weights=self.class_weights,
     58         class_indexes=self.class_indexes,
     59         smooth=self.smooth,
     60         per_image=self.per_image,
     61         threshold=self.threshold,
     62         **self.submodules
     63     )

File E:\Devs\pyEnvs_experiments\segmentations\segmentation-Env\Lib\site-packages\segmentation_models\base\functional.py:93, in iou_score(gt, pr, class_weights, class_indexes, smooth, per_image, threshold, **kwargs)
     90 axes = get_reduce_axes(per_image, **kwargs)
     92 # score calculation
---> 93 intersection = backend.sum(gt * pr, axis=axes)
     94 union = backend.sum(gt + pr, axis=axes) - intersection
     96 score = (intersection + smooth) / (union + smooth)

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type float64 of argument 'x'.
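For reference, lines 93 to 96 of functional.py compute a soft IoU. The NumPy sketch below re-implements only that formula (it is not the library's code), assuming channels-last masks, per_image=False (sum over batch, height and width), equal class weights, no threshold, and smooth=1e-5 as an assumed default. The one addition is the up-front cast that aligns gt with pr's dtype. NumPy would silently upcast a float64 * float32 product, whereas TensorFlow's Mul op rejects mixed dtypes, which is why the failure only shows up in TF.

```python
import numpy as np

def soft_iou(gt, pr, smooth=1e-5):
    """NumPy sketch of the IoU formula used by segmentation_models' iou_score.

    Assumes channels-last arrays of shape (batch, height, width, classes),
    per_image=False and equal class weights; no threshold is applied.
    """
    pr = np.asarray(pr)
    # Align the ground truth with the prediction dtype (e.g. float64 -> float32)
    gt = np.asarray(gt).astype(pr.dtype)
    axes = (0, 1, 2)  # reduce over batch, height and width; keep the class axis
    intersection = np.sum(gt * pr, axis=axes)
    union = np.sum(gt + pr, axis=axes) - intersection
    score = (intersection + smooth) / (union + smooth)
    return score.mean()  # average over classes

gt = np.ones((1, 2, 2, 1), dtype=np.float64)       # float64 mask, as in the error
pr = np.full((1, 2, 2, 1), 0.5, dtype=np.float32)  # float32 model output
print(soft_iou(gt, pr))
```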

FYI, this is the 'pip list' output of my virtual environment:

Package                 Version
----------------------- -----------
absl-py                 2.1.0
albucore                0.0.17
albumentations          1.4.16
annotated-types         0.7.0
asttokens               2.4.1
astunparse              1.6.3
certifi                 2024.8.30
charset-normalizer      3.3.2
colorama                0.4.6
comm                    0.2.2
contourpy               1.3.0
cycler                  0.12.1
debugpy                 1.8.6
decorator               5.1.1
efficientnet            1.1.1
eval_type_backport      0.2.0
executing               2.1.0
flatbuffers             24.3.25
fonttools               4.54.1
gast                    0.6.0
google-pasta            0.2.0
grpcio                  1.66.2
h5py                    3.12.1
idna                    3.10
image-classifiers       1.0.0
imageio                 2.35.1
imgaug                  0.2.6
ipykernel               6.29.5
ipython                 8.27.0
jedi                    0.19.1
jupyter_client          8.6.3
jupyter_core            5.7.2
keras                   3.5.0
Keras-Applications      1.0.8
kiwisolver              1.4.7
lazy_loader             0.4
libclang                18.1.1
Markdown                3.7
markdown-it-py          3.0.0
MarkupSafe              2.1.5
matplotlib              3.9.2
matplotlib-inline       0.1.7
mdurl                   0.1.2
ml-dtypes               0.4.1
namex                   0.0.8
nest-asyncio            1.6.0
networkx                3.3
numpy                   1.26.4
opencv-python-headless  4.10.0.84
opt_einsum              3.4.0
optree                  0.12.1
packaging               24.1
parso                   0.8.4
pillow                  10.4.0
pip                     24.2
platformdirs            4.3.6
prompt_toolkit          3.0.48
protobuf                4.25.5
psutil                  6.0.0
pure_eval               0.2.3
pydantic                2.9.2
pydantic_core           2.23.4
Pygments                2.18.0
pyparsing               3.1.4
python-dateutil         2.9.0.post0
pywin32                 306
PyYAML                  6.0.2
pyzmq                   26.2.0
requests                2.32.3
rich                    13.8.1
scikit-image            0.24.0
scipy                   1.14.1
segmentation-models     1.0.1
setuptools              75.1.0
six                     1.16.0
stack-data              0.6.3
tensorboard             2.17.1
tensorboard-data-server 0.7.2
tensorflow              2.17.0
tensorflow-intel        2.17.0
termcolor               2.4.0
tifffile                2024.9.20
tornado                 6.4.1
traitlets               5.14.3
typing_extensions       4.12.2
urllib3                 2.2.3
wcwidth                 0.2.13
Werkzeug                3.0.4
wheel                   0.44.0
wrapt                   1.16.0

naragana commented Nov 9, 2024

You need to cast 'gt' to the data type of 'pr' in iou_score (segmentation_models/base/functional.py):

import tensorflow as tf

# Cast the ground truth to the prediction's dtype (float32 here) before combining them
gt_cast = tf.cast(gt, pr.dtype)
intersection = backend.sum(gt_cast * pr, axis=axes)
union = backend.sum(gt_cast + pr, axis=axes) - intersection
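Alternatively, the cast can be done before the data reaches the metric, so the installed library stays untouched. A minimal sketch, assuming the dataloader builds one-hot masks with NumPy; build_mask and class_values here are hypothetical names, not taken from the notebook. Note that NumPy's .astype('float') produces float64, which would match the float64 'x' in the error, while float32 matches the default Keras float type.

```python
import numpy as np

def build_mask(label_map, class_values):
    """Hypothetical helper: one-hot encode an integer label map (H, W) -> (H, W, C)."""
    masks = [(label_map == v) for v in class_values]
    # .astype('float') would give float64; cast to float32 instead
    return np.stack(masks, axis=-1).astype(np.float32)

label_map = np.array([[0, 1], [1, 8]])
mask = build_mask(label_map, class_values=[8])
print(mask.shape, mask.dtype)
```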

FederCO23 (Author)

Hi Naragana,
I updated the functional.py file in the library, and it worked! Thank you for your advice, that was very kind of you!
Now, I’m debugging an issue with the logs object in the Keras trainer...
