diff --git a/snowpark_extensions/dataframe_reader_extensions.py b/snowpark_extensions/dataframe_reader_extensions.py
index 994076d..b189bc4 100644
--- a/snowpark_extensions/dataframe_reader_extensions.py
+++ b/snowpark_extensions/dataframe_reader_extensions.py
@@ -9,6 +9,8 @@
 import logging
 DataFrameReader.___extended = True
 DataFrameReader.__option = DataFrameReader.option
+DataFrameReader.__csv = DataFrameReader.csv
+
 def _option(self, key: str, value: Any) -> "DataFrameReader":
     key = key.upper()
     if key == "SEP" or key == "DELIMITER":
@@ -72,7 +74,55 @@ def _format(self, file_type: str) -> "DataFrameReader":
         self._file_type = file_type
     else:
         raise Exception(f"Unsupported file format {file_type}")
-
+
+def _csv(self,
+         path: str,
+         schema: Optional[Union[StructType, str]] = None,
+         sep: Optional[str] = None,
+         encoding: Optional[str] = None,
+         quote: Optional[str] = None,
+         escape: Optional[str] = None,
+         comment: Optional[str] = None,
+         header: Optional[Union[bool, str]] = None,
+         inferSchema: Optional[Union[bool, str]] = None,
+         ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
+         ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
+         nullValue: Optional[str] = None,
+         nanValue: Optional[str] = None,
+         positiveInf: Optional[str] = None,
+         negativeInf: Optional[str] = None,
+         dateFormat: Optional[str] = None,
+         timestampFormat: Optional[str] = None,
+         maxColumns: Optional[Union[int, str]] = None,
+         maxCharsPerColumn: Optional[Union[int, str]] = None,
+         maxMalformedLogPerPartition: Optional[Union[int, str]] = None,
+         mode: Optional[str] = None,
+         columnNameOfCorruptRecord: Optional[str] = None,
+         multiLine: Optional[Union[bool, str]] = None,
+         charToEscapeQuoteEscaping: Optional[str] = None,
+         samplingRatio: Optional[Union[float, str]] = None,
+         enforceSchema: Optional[Union[bool, str]] = None,
+         emptyValue: Optional[str] = None,
+         locale: Optional[str] = None,
+         lineSep: Optional[str] = None,
+         pathGlobFilter: Optional[Union[bool, str]] = None,
+         recursiveFileLookup: Optional[Union[bool, str]] = None,
+         modifiedBefore: Optional[Union[bool, str]] = None,
+         modifiedAfter: Optional[Union[bool, str]] = None,
+         unescapedQuoteHandling: Optional[str] = None
+         ) -> "DataFrame":
+    params = {k: v for k, v in locals().items() if v is not None}
+    params.pop("self", None)
+    params.pop("path", None)
+    params.pop("schema", None)
+    if schema:
+        self.schema(schema)
+    for key, value in params.items():
+        self = self.option(key, value)
+    return self.__csv(path)
+
 DataFrameReader.format = _format
 DataFrameReader.load = _load
-DataFrameReader.option = _option
\ No newline at end of file
+DataFrameReader.option = _option
+DataFrameReader.csv = _csv
+
\ No newline at end of file
diff --git a/tests/test_dataframe_reader_extensions.py b/tests/test_dataframe_reader_extensions.py
index 1458a49..7a9c2bc 100644
--- a/tests/test_dataframe_reader_extensions.py
+++ b/tests/test_dataframe_reader_extensions.py
@@ -1,10 +1,34 @@
 import pytest
-from snowflake.snowpark import Session, Row
+from snowflake.snowpark import Session, Row, DataFrameReader
 from snowflake.snowpark.types import *
+from snowflake.snowpark.dataframe import _generate_prefix
 import snowpark_extensions
 
 def test_load():
     session = Session.builder.from_snowsql().getOrCreate()
+    cases = session.read.load(["./tests/data/test1_0.csv","./tests/data/test1_1.csv"],
+                              schema=get_schema(),
+                              format="csv",
+                              sep=",",
+                              header="true")
+    assert 10 == len(cases.collect())
+
+def test_csv():
+    session = Session.builder.from_snowsql().getOrCreate()
+    stage = f'{session.get_fully_qualified_current_schema()}.{_generate_prefix("TEST_STAGE")}'
+    session.sql(f'CREATE TEMPORARY STAGE IF NOT EXISTS {stage}').show()
+    session.file.put(f"file://./tests/data/test1_0.csv", f"@{stage}")
+    session.file.put(f"file://./tests/data/test1_1.csv", f"@{stage}")
+    dfReader = session.read
+    csvInfo = dfReader.csv(f"@{stage}",
+                           schema=get_schema(),
+                           sep=",",
+                           header="true")
+    assert 10 == len(csvInfo.collect())
+    assert dfReader._cur_options["FIELD_DELIMITER"] == ","
+    assert dfReader._cur_options["SKIP_HEADER"] == 1
+
+def get_schema():
     schema = StructType([ \
         StructField("case_id", StringType()), \
         StructField("province", StringType()), \
@@ -13,11 +37,6 @@
         StructField("infection_case",StringType()), \
         StructField("confirmed", IntegerType()), \
         StructField("latitude", FloatType()), \
-        StructField("cilongitudety", FloatType()) \
+        StructField("longitude", FloatType()) \
         ])
-    cases = session.read.load(["./tests/data/test1_0.csv","./tests/data/test1_1.csv"],
-                              schema=schema,
-                              format="csv",
-                              sep=",",
-                              header="true")
-    assert 10 == len(cases.collect())
\ No newline at end of file
+    return schema
\ No newline at end of file