diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 8a94187bfced1..e73cdd7b80c3f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -1047,7 +1047,6 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_arrow_grouped_map", "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map", "pyspark.sql.tests.connect.test_parity_python_datasource", - "pyspark.sql.tests.connect.test_parity_python_streaming_datasource", "pyspark.sql.tests.connect.test_utils", "pyspark.sql.tests.connect.client.test_artifact", "pyspark.sql.tests.connect.client.test_client", diff --git a/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py b/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py deleted file mode 100644 index 65bb4c021f4d2..0000000000000 --- a/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from pyspark.sql.tests.test_python_streaming_datasource import ( - BasePythonStreamingDataSourceTestsMixin, -) -from pyspark.testing.connectutils import ReusedConnectTestCase - - -class PythonStreamingDataSourceParityTests( - BasePythonStreamingDataSourceTestsMixin, ReusedConnectTestCase -): - pass - - -if __name__ == "__main__": - import unittest - from pyspark.sql.tests.connect.test_parity_python_streaming_datasource import * # noqa: F401 - - try: - import xmlrunner # type: ignore[import] - - testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) - except ImportError: - testRunner = None - unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py index 5125e9ad6dec1..90f06223e0091 100644 --- a/python/pyspark/sql/tests/test_python_streaming_datasource.py +++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py @@ -141,11 +141,15 @@ def test_stream_reader(self): self.spark.dataSource.register(self._get_test_data_source()) df = self.spark.readStream.format("TestDataSource").load() + current_batch_id = -1 + def check_batch(df, batch_id): + nonlocal current_batch_id + current_batch_id = batch_id assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)]) q = df.writeStream.foreachBatch(check_batch).start() - while len(q.recentProgress) < 10: + while current_batch_id < 10: time.sleep(0.2) q.stop() q.awaitTermination