Skip to content

Commit

Permalink
Fix bug in write_to_socket function & increase test coverage (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
JanusAsmussen authored Mar 21, 2023
1 parent 1f570c6 commit d22e286
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 8 deletions.
1 change: 1 addition & 0 deletions spark_utils/common/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def write_to_socket(
found here: https://spark.apache.org/docs/latest/sql-data-sources-parquet.html#data-source-option)
"""
write_options = write_options or {}
partition_by = partition_by or []
if partition_count:
data = data.repartition(partition_count, *partition_by)

Expand Down
24 changes: 17 additions & 7 deletions test/test_common_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from typing import List, Union

import pytest
from collections import namedtuple

Expand All @@ -11,7 +13,8 @@

def are_dfs_equal(df1: DataFrame, df2: DataFrame) -> bool:
    """Return True if *df1* and *df2* have the same schema and the same rows.

    Rows are collected after sorting both frames by the first column so the
    comparison is insensitive to row ordering (Spark gives no ordering
    guarantee after repartitioning or a round-trip through storage).

    NOTE(review): sorting on a single column is only fully deterministic when
    that column's values are unique — confirm this holds for the test fixtures.

    :param df1: First dataframe to compare; its first column is used as the sort key.
    :param df2: Second dataframe to compare.
    :return: True when schemas match and the sorted row contents are equal.
    """
    sort_col = df1.columns[0]
    # Compare schemas first (cheap), then the sorted row contents.
    same_schema = df1.schema == df2.schema
    same_rows = df1.sort(sort_col).collect() == df2.sort(sort_col).collect()
    return same_schema and same_rows


Format = namedtuple("format", ["format", "read_options"])
Expand All @@ -26,7 +29,6 @@ def are_dfs_equal(df1: DataFrame, df2: DataFrame) -> bool:
],
)
def test_read_from_socket(format_: Format, spark_session: SparkSession, test_base_path: str):

test_data_path = os.path.join(test_base_path, "test_common_functions")

socket = JobSocket(
Expand All @@ -45,13 +47,20 @@ def test_read_from_socket(format_: Format, spark_session: SparkSession, test_bas
@pytest.mark.parametrize(
"format_",
[
Format(format="parquet", read_options={}),
Format(format="csv", read_options={"header": True}),
Format(format="json", read_options={}),
Format(format="parquet", read_options={}),
],
)
def test_write_to_socket(format_: Format, spark_session: SparkSession, test_base_path: str):

@pytest.mark.parametrize("partition_by", [["strings"], None, []])
@pytest.mark.parametrize("partition_count", [None, 1, 2])
def test_write_to_socket(
format_: Format,
spark_session: SparkSession,
test_base_path: str,
partition_by: Union[None, List[str]],
partition_count: Union[None, int],
):
test_data_path = os.path.join(test_base_path, "test_common_functions")
socket = JobSocket(
alias="test",
Expand All @@ -69,16 +78,17 @@ def test_write_to_socket(format_: Format, spark_session: SparkSession, test_base
data=df,
socket=output_socket,
write_options=format_.read_options,
partition_by=partition_by,
partition_count=partition_count,
)

df_read = read_from_socket(socket=output_socket, spark_session=spark_session, read_options=format_.read_options)

assert are_dfs_equal(df, df_read)
assert are_dfs_equal(df, df_read.select(df.columns))


@pytest.mark.parametrize("sep", ["|", ";"])
def test_job_socket_serialize(sep: str, test_base_path: str):

test_data_path = os.path.join(test_base_path, "test_common_functions/data.parquet")
socket = JobSocket(
alias="test",
Expand Down
1 change: 0 additions & 1 deletion test/test_common_functions/data.json

This file was deleted.

Binary file not shown.
Binary file not shown.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"strings":"abc","ints":1,"floats":1.0}
{"strings":"def","ints":2,"floats":2.0}
{"strings":"ghe","ints":3,"floats":3.0}

0 comments on commit d22e286

Please sign in to comment.