Skip to content

Commit

Permalink
fixed ML tests
Browse files Browse the repository at this point in the history
  • Loading branch information
liferoad committed Nov 27, 2024
1 parent 5a6c902 commit b403f71
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 18 deletions.
21 changes: 10 additions & 11 deletions sdks/python/apache_beam/ml/transforms/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import shutil
import tempfile
import time
import typing
import unittest
from collections.abc import Sequence
from typing import Any
Expand Down Expand Up @@ -140,8 +139,8 @@ def test_ml_transform_on_list_dict(self):
'x': int, 'y': float
},
expected_dtype={
'x': typing.Sequence[np.float32],
'y': typing.Sequence[np.float32],
'x': Sequence[np.float32],
'y': Sequence[np.float32],
},
),
param(
Expand All @@ -153,8 +152,8 @@ def test_ml_transform_on_list_dict(self):
'x': np.int32, 'y': np.float32
},
expected_dtype={
'x': typing.Sequence[np.float32],
'y': typing.Sequence[np.float32],
'x': Sequence[np.float32],
'y': Sequence[np.float32],
},
),
param(
Expand All @@ -165,21 +164,21 @@ def test_ml_transform_on_list_dict(self):
'x': list[int], 'y': list[float]
},
expected_dtype={
'x': typing.Sequence[np.float32],
'y': typing.Sequence[np.float32],
'x': Sequence[np.float32],
'y': Sequence[np.float32],
},
),
param(
input_data=[{
'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0]
}],
input_types={
'x': typing.Sequence[int],
'y': typing.Sequence[float],
'x': Sequence[int],
'y': Sequence[float],
},
expected_dtype={
'x': typing.Sequence[np.float32],
'y': typing.Sequence[np.float32],
'x': Sequence[np.float32],
'y': Sequence[np.float32],
},
),
])
Expand Down
8 changes: 4 additions & 4 deletions sdks/python/apache_beam/ml/transforms/handlers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import shutil
import sys
import tempfile
import typing
import unittest
import uuid
from collections.abc import Sequence
from typing import NamedTuple
from typing import Union

Expand Down Expand Up @@ -276,9 +276,9 @@ def test_tft_process_handler_transformed_data_schema(self):
schema_utils.schema_from_feature_spec(raw_data_feature_spec))

expected_transformed_data_schema = {
'x': typing.Sequence[np.float32],
'y': typing.Sequence[np.float32],
'z': typing.Sequence[bytes]
'x': Sequence[np.float32],
'y': Sequence[np.float32],
'z': Sequence[bytes]
}

actual_transformed_data_schema = (
Expand Down
7 changes: 4 additions & 3 deletions sdks/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,11 @@ def get_portability_package_data():
'pillow',
# Support TF 2.16.0: https://github.com/apache/beam/issues/31294
# Once TF version is unpinned, also don't restrict Python version.
'tensorflow<2.16.0;python_version<"3.12"',
# 'tensorflow<2.16.0;python_version<"3.12"',
# limit this to 2.12.x to make tests stable
'tensorflow>=2.12rc1,<2.13',
'tensorflow-hub',
# https://github.com/tensorflow/transform/issues/313
'tensorflow-transform;python_version<"3.11"',
'tensorflow-transform',
'tf2onnx',
'torch',
'transformers',
Expand Down
2 changes: 2 additions & 0 deletions sdks/python/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ commands =
setenv =
extras = test,gcp,dataframe,ml_test
commands =
# Log tensorflow version for debugging
/bin/sh -c "pip freeze | grep -E tensorflow"
bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}"

[testenv:py{39,310,311,312}-dask]
Expand Down

0 comments on commit b403f71

Please sign in to comment.