Skip to content

Commit

Permalink
[SPARK-50388][PYTHON][TESTS][FOLLOW-UP] Move have_flameprof to `pys…
Browse files Browse the repository at this point in the history
…park.testing.utils`

### What changes were proposed in this pull request?
Move `have_flameprof` to `pyspark.testing.utils`

### Why are the changes needed?
to centralize the import check

### Does this PR introduce _any_ user-facing change?
no, test only

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48973 from zhengruifeng/fix_has_flameprof.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Nov 27, 2024
1 parent 5425d45 commit e55511c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from pyspark.sql.tests.test_udf_profiler import (
UDFProfiler2TestsMixin,
_do_computation,
has_flameprof,
)
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.utils import have_flameprof


class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase):
Expand Down Expand Up @@ -65,7 +65,7 @@ def action(df):
io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))


Expand Down
38 changes: 16 additions & 22 deletions python/pyspark/sql/tests/test_udf_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,15 @@
from pyspark.sql.functions import col, pandas_udf, udf
from pyspark.sql.window import Window
from pyspark.profiler import UDFBasicProfiler
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import (
have_pandas,
have_pyarrow,
have_flameprof,
pandas_requirement_message,
pyarrow_requirement_message,
)

try:
import flameprof # noqa: F401

has_flameprof = True
except ImportError:
has_flameprof = False


def _do_computation(spark, *, action=lambda df: df.collect(), use_arrow=False):
@udf("long", useArrow=use_arrow)
Expand Down Expand Up @@ -208,7 +202,7 @@ def test_perf_profiler_udf(self):
)
self.assertTrue(f"udf_{id}_perf.pstats" in os.listdir(d))

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

@unittest.skipIf(
Expand All @@ -230,7 +224,7 @@ def test_perf_profiler_udf_with_arrow(self):
io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

def test_perf_profiler_udf_multiple_actions(self):
Expand All @@ -252,7 +246,7 @@ def action(df):
io.getvalue(), f"20.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

def test_perf_profiler_udf_registered(self):
Expand All @@ -276,7 +270,7 @@ def add1(x):
io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

@unittest.skipIf(
Expand Down Expand Up @@ -309,7 +303,7 @@ def add2(x):
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

@unittest.skipIf(
Expand Down Expand Up @@ -345,7 +339,7 @@ def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

@unittest.skipIf(
Expand Down Expand Up @@ -395,7 +389,7 @@ def mean_udf(v: pd.Series) -> float:
io.getvalue(), f"5.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

@unittest.skipIf(
Expand Down Expand Up @@ -427,7 +421,7 @@ def min_udf(v: pd.Series) -> float:
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

@unittest.skipIf(
Expand Down Expand Up @@ -458,7 +452,7 @@ def normalize(pdf):
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

@unittest.skipIf(
Expand Down Expand Up @@ -496,7 +490,7 @@ def asof_join(left, right):
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

@unittest.skipIf(
Expand Down Expand Up @@ -530,7 +524,7 @@ def normalize(table):
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

@unittest.skipIf(
Expand Down Expand Up @@ -562,7 +556,7 @@ def summarize(left, right):
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))

def test_perf_profiler_render(self):
Expand All @@ -572,7 +566,7 @@ def test_perf_profiler_render(self):

id = list(self.profile_results.keys())[0]

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))
self.assertIn("svg", self.spark.profile.render(id, type="perf"))
self.assertIn("svg", self.spark.profile.render(id, renderer="flameprof"))
Expand Down
2 changes: 2 additions & 0 deletions python/pyspark/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def have_package(name: str) -> bool:
have_graphviz = have_package("graphviz")
graphviz_requirement_message = None if have_graphviz else "No module named 'graphviz'"

have_flameprof = have_package("flameprof")
flameprof_requirement_message = None if have_flameprof else "No module named 'flameprof'"

pandas_requirement_message = None
try:
Expand Down

0 comments on commit e55511c

Please sign in to comment.