Commit

Merge pull request #64 from engelhardtnick-at-TW/fix-mypy-and-linter
Fix mypy and linter
lauris-tw authored Nov 5, 2024
2 parents f0e3761 + 604a870 commit 6d9ad34
Showing 5 changed files with 48 additions and 38 deletions.
tests/integration/conftest.py (2 changes: 1 addition & 1 deletion)
@@ -3,5 +3,5 @@


 @pytest.fixture(scope="session")
-def SPARK():
+def spark_session() -> SparkSession:
     return SparkSession.builder.appName("IntegrationTests").getOrCreate()
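Side note, not part of this commit: pytest injects the session-scoped fixture above into any integration test that declares a parameter named spark_session, and the new `-> SparkSession` return annotation is what lets mypy type-check calls made on it. A minimal sketch of a consuming test, assuming pyspark is installed and this conftest.py is on the pytest path (the test name and sample rows are hypothetical):

    # hypothetical_spark_fixture_test.py -- illustration only, not a file in this PR
    from pyspark.sql import SparkSession


    def test_fixture_provides_a_working_spark_session(spark_session: SparkSession) -> None:
        # pytest matches the parameter name "spark_session" to the fixture in conftest.py;
        # the annotation documents the type for mypy and for readers.
        dataframe = spark_session.createDataFrame([("a", 1), ("b", 2)], ["letter", "number"])
        assert dataframe.count() == 2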
tests/integration/test_distance_transformer.py (22 changes: 12 additions & 10 deletions)
@@ -81,12 +81,13 @@
 ]


-def test_should_maintain_all_data_it_reads(SPARK) -> None:
-    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(SPARK)
-    given_dataframe = SPARK.read.parquet(given_ingest_folder)
-    distance_transformer.run(SPARK, given_ingest_folder, given_transform_folder)
+def test_should_maintain_all_data_it_reads(spark_session: SparkSession) -> None:
+    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(
+        spark_session)
+    given_dataframe = spark_session.read.parquet(given_ingest_folder)
+    distance_transformer.run(spark_session, given_ingest_folder, given_transform_folder)

-    actual_dataframe = SPARK.read.parquet(given_transform_folder)
+    actual_dataframe = spark_session.read.parquet(given_transform_folder)
     actual_columns = set(actual_dataframe.columns)
     actual_schema = set(actual_dataframe.schema)
     expected_columns = set(given_dataframe.columns)
@@ -97,12 +98,13 @@ def test_should_maintain_all_data_it_reads(SPARK) -> None:


 @pytest.mark.skip
-def test_should_add_distance_column_with_calculated_distance(SPARK) -> None:
-    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(SPARK)
-    distance_transformer.run(SPARK, given_ingest_folder, given_transform_folder)
+def test_should_add_distance_column_with_calculated_distance(spark_session: SparkSession) -> None:
+    given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders(
+        spark_session)
+    distance_transformer.run(spark_session, given_ingest_folder, given_transform_folder)

-    actual_dataframe = SPARK.read.parquet(given_transform_folder)
-    expected_dataframe = SPARK.createDataFrame(
+    actual_dataframe = spark_session.read.parquet(given_transform_folder)
+    expected_dataframe = spark_session.createDataFrame(
         [
             SAMPLE_DATA[0] + [1.07],
             SAMPLE_DATA[1] + [0.92],
tests/integration/test_ingest.py (10 changes: 6 additions & 4 deletions)
@@ -3,10 +3,12 @@
 import tempfile
 from typing import Tuple, List

+from pyspark.sql import SparkSession
+
 from data_transformations.citibike import ingest


-def test_should_sanitize_column_names(SPARK) -> None:
+def test_should_sanitize_column_names(spark_session: SparkSession) -> None:
     given_ingest_folder, given_transform_folder = __create_ingest_and_transform_folders()
     input_csv_path = given_ingest_folder + 'input.csv'
     csv_content = [
@@ -15,10 +17,10 @@ def test_should_sanitize_column_names(SPARK) -> None:
         ['1', '5', '2'],
     ]
     __write_csv_file(input_csv_path, csv_content)
-    ingest.run(SPARK, input_csv_path, given_transform_folder)
+    ingest.run(spark_session, input_csv_path, given_transform_folder)

-    actual = SPARK.read.parquet(given_transform_folder)
-    expected = SPARK.createDataFrame(
+    actual = spark_session.read.parquet(given_transform_folder)
+    expected = spark_session.createDataFrame(
         [
             ['3', '4', '1'],
             ['1', '5', '2']
tests/integration/test_validate_spark_environment.py (43 changes: 24 additions & 19 deletions)
@@ -5,39 +5,44 @@
 import pytest


-def test_java_home_is_set():
+def test_java_home_is_set() -> None:
     java_home = os.environ.get("JAVA_HOME")
-    assert java_home is not None, "Environment variable 'JAVA_HOME' is not set but is required by pySpark to work."
+    assert java_home is not None, \
+        "Environment variable 'JAVA_HOME' is not set but is required by pySpark to work."


-def test_java_version_minimum_requirement(expected_major_version=11):
+def test_java_version_minimum_requirement(expected_major_version: int = 11) -> None:
     version_line = __extract_version_line(__java_version_output())
     major_version = __parse_major_version(version_line)
-    assert major_version >= expected_major_version, (f"Major version {major_version} is not recent enough, "
-                                                     f"we need at least version {expected_major_version}.")
+    assert major_version >= expected_major_version, (
+        f"Major version {major_version} is not recent enough, "
+        f"we need at least version {expected_major_version}.")


-def __java_version_output():
-    java_version = subprocess.check_output(['java', '-version'], stderr=subprocess.STDOUT).decode("utf-8")
+def __java_version_output() -> str:
+    java_version = subprocess.check_output(['java', '-version'], stderr=subprocess.STDOUT).decode(
+        "utf-8")
     print(f"\n`java -version` returned\n{java_version}")
     return java_version


-def __extract_version_line(java_version_output):
-    version_line = next((line for line in java_version_output.splitlines() if "version" in line), None)
+def __extract_version_line(java_version_output: str) -> str:
+    version_line = next((line for line in java_version_output.splitlines() if "version" in line),
+                        None)
     if not version_line:
         pytest.fail("Couldn't find version information in `java -version` output.")
     return version_line


-def __parse_major_version(version_line):
-    version_regex = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\d+"')
+# pylint: disable=R1710
+def __parse_major_version(version_line: str) -> int:
+    version_regex = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\w+"')
     match = version_regex.search(version_line)
-    if not match:
-        return None
-    major_version = int(match.group("major"))
-    if major_version == 1:
-        major_version = int(match.group("minor"))
-    if major_version is None:
-        pytest.fail(f"Couldn't parse Java version from {version_line}.")
-    return major_version
+    if match is not None:
+        major_version = int(match.group("major"))
+        if major_version == 1:
+            # we need to jump this hoop due to Java version naming conventions - it's fun:
+            # https://softwareengineering.stackexchange.com/questions/175075/why-is-java-version-1-x-referred-to-as-java-x
+            major_version = int(match.group("minor"))
+        return major_version
+    pytest.fail(f"Couldn't parse Java version from {version_line}.")
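Side note, not part of this commit: the regex change from `\.\d+"` to `\.\w+"` matters because legacy Java releases report versions like `1.8.0_292`, where the third component contains an underscore that `\d+` would not match; the `major == 1` branch then maps the legacy `1.x` naming onto the modern major number. A standalone sketch of that behaviour (the sample `java -version` lines are illustrative, not captured from this repository):

    import re

    # Same pattern as the updated code: \w+ tolerates third components such as "0_292".
    VERSION_REGEX = re.compile(r'version "(?P<major>\d+)\.(?P<minor>\d+)\.\w+"')

    for line in ['openjdk version "11.0.2" 2019-01-15', 'java version "1.8.0_292"']:
        match = VERSION_REGEX.search(line)
        assert match is not None
        major_version = int(match.group("major"))
        if major_version == 1:
            # Legacy "1.x" naming: the real major version is the second component.
            major_version = int(match.group("minor"))
        print(line, "->", major_version)  # prints 11 for the first line, 8 for the second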
tests/integration/test_word_count.py (9 changes: 5 additions & 4 deletions)
@@ -3,6 +3,7 @@
 from typing import Tuple, List

 import pytest
+from pyspark.sql import SparkSession

 from data_transformations.wordcount import word_count_transformer

@@ -19,7 +20,7 @@ def _get_file_paths(input_file_lines: List[str]) -> Tuple[str, str]:


 @pytest.mark.skip
-def test_should_tokenize_words_and_count_them(SPARK) -> None:
+def test_should_tokenize_words_and_count_them(spark_session: SparkSession) -> None:
     lines = [
         "In my younger and more vulnerable years my father gave me some advice that I've been "
         "turning over in my mind ever since. \"Whenever you feel like criticising any one,\""
@@ -46,9 +47,9 @@ def test_should_tokenize_words_and_count_them(SPARK) -> None:
     ]
     input_file_path, output_path = _get_file_paths(lines)

-    word_count_transformer.run(SPARK, input_file_path, output_path)
+    word_count_transformer.run(spark_session, input_file_path, output_path)

-    actual = SPARK.read.csv(output_path, header=True, inferSchema=True)
+    actual = spark_session.read.csv(output_path, header=True, inferSchema=True)
     expected_data = [
         ["a", 4],
         ["across", 1],
@@ -258,6 +259,6 @@ def test_should_tokenize_words_and_count_them(SPARK) -> None:
         ["you've", 1],
         ["younger", 1],
     ]
-    expected = SPARK.createDataFrame(expected_data, ["word", "count"])
+    expected = spark_session.createDataFrame(expected_data, ["word", "count"])

     assert actual.collect() == expected.collect()
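Side note, not part of this commit: word_count_transformer itself is outside this diff and the test above remains skipped. Purely to illustrate what the expected DataFrame implies (lower-cased tokens, apostrophes preserved, one count per word), a minimal PySpark word count could look like the sketch below; the function name and tokenisation rule are assumptions, not the repository's implementation:

    from pyspark.sql import DataFrame, SparkSession
    from pyspark.sql import functions as F


    def count_words(spark_session: SparkSession, input_path: str, output_path: str) -> None:
        # Hypothetical sketch: read raw text, split each line into lower-cased tokens made of
        # letters and apostrophes, drop empty tokens, and count occurrences per word.
        lines: DataFrame = spark_session.read.text(input_path)
        words = lines.select(
            F.explode(F.split(F.lower(F.col("value")), r"[^a-z']+")).alias("word")
        ).where(F.col("word") != "")
        words.groupBy("word").count().write.csv(output_path, header=True, mode="overwrite")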
