From 3d451219c33c3fcb279c809237db9ba9622ac4a1 Mon Sep 17 00:00:00 2001 From: Orellabac Date: Wed, 8 Feb 2023 19:59:47 -0600 Subject: [PATCH] 0.0.26 (#29) * progress on table function changes * reorganize functions * save progress * upgrade version and notes in change log --- CHANGE_LOG.txt | 8 + README.md | 54 ++- extras/README.md | 6 + extras/notebooks/runner/README.md | 15 + extras/notebooks/runner/example1.ipynb | 54 +++ runner => extras/notebooks/runner/runner.py | 46 +- requirements.txt | 1 - runner.bat | 9 - setup.py | 10 +- snowpark_extensions/__init__.py | 1 + snowpark_extensions/dataframe_extensions.py | 414 ++++++------------ .../dataframe_reader_extensions.py | 59 +++ snowpark_extensions/functions_extensions.py | 200 ++++++--- .../session_builder_extensions.py | 5 +- tests/data/test1_0.csv | 6 + tests/data/test1_1.csv | 6 + tests/test_dataframe_extensions.py | 101 +++-- tests/test_dataframe_reader_extensions.py | 23 + tests/test_functions.py | 11 +- 19 files changed, 618 insertions(+), 411 deletions(-) create mode 100644 extras/README.md create mode 100644 extras/notebooks/runner/README.md create mode 100644 extras/notebooks/runner/example1.ipynb rename runner => extras/notebooks/runner/runner.py (86%) delete mode 100644 runner.bat create mode 100644 snowpark_extensions/dataframe_reader_extensions.py create mode 100644 tests/data/test1_0.csv create mode 100644 tests/data/test1_1.csv create mode 100644 tests/test_dataframe_reader_extensions.py diff --git a/CHANGE_LOG.txt b/CHANGE_LOG.txt index 0c6d085..7e723cd 100644 --- a/CHANGE_LOG.txt +++ b/CHANGE_LOG.txt @@ -129,3 +129,11 @@ Fixing an issue with the current implementation of applyInPandas Version 0.0.25 -------------- Change in implementation of regexp_split to support different regular expression cases + +Version 0.0.26 +-------------- +- Changes in the implementation for explode / explode_outer / array_zip / flatten +to take advantege of changes in snowpark lib. +- adding a stack method similar in functionality to unpivot +- removing dependency to shortuuid +- adding extensions for DataFrameReader diff --git a/README.md b/README.md index bae3530..1dcfda8 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ order by start_time desc; | DataFrame.groupby.applyInPandas| Maps each group of the current DataFrame using a pandas udf and returns the result as a DataFrame. 
|
 | DataFrame.replace | extends replace to allow using a regex
 | DataFrame.groupBy.pivot | extends the snowpark groupby to add a pivot operator
-
+| DataFrame.stack | An operator similar to the unpivot operator
 
 ### Examples
@@ -193,6 +193,58 @@ df.group_by("ID").applyInPandas(
 ------------------------------
 ```
+
+### stack
+
+Assuming you have a DataFrame like:
+
+# +-------+---------+-----+---------+----+
+# |   Name|Analytics|   BI|Ingestion|  ML|
+# +-------+---------+-----+---------+----+
+# | Mickey|     null|12000|     null|8000|
+# | Martin|     null| 5000|     null|null|
+# |  Jerry|     null| null|     1000|null|
+# |  Riley|     null| null|     null|9000|
+# | Donald|     1000| null|     null|null|
+# |   John|     null| null|     1000|null|
+# |Patrick|     null| null|     null|1000|
+# |  Emily|     8000| null|     3000|null|
+# |   Arya|    10000| null|     2000|null|
+# +-------+---------+-----+---------+----+
+
+```python
+df.select("NAME",df.stack(4,lit('Analytics'), "ANALYTICS", lit('BI'), "BI", lit('Ingestion'), "INGESTION", lit('ML'), "ML").alias("Project", "Cost_To_Project")).filter(col("Cost_To_Project").is_not_null()).orderBy("NAME","Project")
+```
+
+That will return:
+```
+-------------------------------------------
+|"NAME"   |"PROJECT"    |"COST_TO_PROJECT"|
+-------------------------------------------
+|Arya     |Analytics    |10000            |
+|Arya     |Ingestion    |2000             |
+|Donald   |Analytics    |1000             |
+|Emily    |Analytics    |8000             |
+|Emily    |Ingestion    |3000             |
+|Jerry    |Ingestion    |1000             |
+|John     |Ingestion    |1000             |
+|Martin   |BI           |5000             |
+|Mickey   |BI           |12000            |
+|Mickey   |ML           |8000             |
+|Patrick  |ML           |1000             |
+|Riley    |ML           |9000             |
+-------------------------------------------
+```
+
+## DataFrameReader Extensions
+
+| Name | Description |
+|--------------------------------|-------------------------------------------------------------------------------------|
+| DataFrameReader.format | Specifies the format of the file to load
+| DataFrameReader.load | Loads a DataFrame from a file. It will upload the files to a stage if needed
+
+### Example
+
+(See the usage sketch below.)
 ## Functions Extensions
 
 | Name | Description |
diff --git a/extras/README.md b/extras/README.md
new file mode 100644
index 0000000..becc352
--- /dev/null
+++ b/extras/README.md
@@ -0,0 +1,6 @@
+# Snowpark Extensions Extras
+
+These "extras" are experimental extensions meant to exercise some Snowpark capabilities.
+They are marked experimental because they may require additional testing or apply only in some scenarios.
+
+
diff --git a/extras/notebooks/runner/README.md b/extras/notebooks/runner/README.md
new file mode 100644
index 0000000..1db6821
--- /dev/null
+++ b/extras/notebooks/runner/README.md
@@ -0,0 +1,15 @@
+# Notebook Runner
+
+The notebook runner is a small example that allows you to run a notebook from within Snowflake.
+
+The runner script will:
+1. connect to Snowflake
+2. upload the notebook,
+3. publish a stored procedure,
+4. run the stored procedure,
+5. save the results of the notebook and
+6. download the results as an HTML file
+
+This script can also be used to publish a permanent stored procedure that can then be used to run any notebook that is already on a stage,
+or to schedule a task to run a notebook.
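
> The `### Example` heading under the DataFrameReader Extensions above is left empty in this patch. As a rough illustration (not part of the patch itself), the sketch below follows the `load()` signature added in `dataframe_reader_extensions.py` and the test in `tests/test_dataframe_reader_extensions.py`; the file paths, schema, and column names are made up for the example.

```python
from snowflake.snowpark import Session
from snowflake.snowpark.types import StructType, StructField, StringType, IntegerType
import snowpark_extensions  # patches DataFrameReader.format / .load / .option

# from_snowsql() / getOrCreate() come from the session builder extensions
session = Session.builder.from_snowsql().getOrCreate()

# Hypothetical local CSV files and schema, for illustration only
schema = StructType([
    StructField("case_id", StringType()),
    StructField("province", StringType()),
    StructField("confirmed", IntegerType()),
])

# load() uploads local files to a temporary stage (or the one passed via `stage=`)
# and reads them back in the given format; sep/header are remapped to
# FIELD_DELIMITER/SKIP_HEADER by the extended option().
df = session.read.load(
    ["./data/part_0.csv", "./data/part_1.csv"],
    format="csv",
    schema=schema,
    sep=",",
    header="true",
)
df.show()
```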
+ diff --git a/extras/notebooks/runner/example1.ipynb b/extras/notebooks/runner/example1.ipynb new file mode 100644 index 0000000..0c28867 --- /dev/null +++ b/extras/notebooks/runner/example1.ipynb @@ -0,0 +1,54 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from snowflake.snowpark import Session\n", + "import snowpark_extensions\n", + "# will try to setup credential from the snowsql CLI if present of from SNOW_xxx or SNOWSQL_xxx variables\n", + "# if not configuration can be retrieve you will receive an error\n", + "session = Session.builder.from_snowsql().from_env().getOrCreate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',])\n", + "df.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16 (default, Jan 10 2023, 15:23:34) \n[GCC 9.4.0]" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "9ac03a0a6051494cc606d484d27d20fce22fb7b4d169f583271e11d5ba46a56e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/runner b/extras/notebooks/runner/runner.py similarity index 86% rename from runner rename to extras/notebooks/runner/runner.py index cb76244..2f7b5b2 100755 --- a/runner +++ b/extras/notebooks/runner/runner.py @@ -4,33 +4,61 @@ import argparse from rich import print -import shortuuid import os from snowflake.snowpark import Session from snowflake.snowpark.functions import sproc import snowpark_extensions -print("[cyan]Snowpark Extensions Utilities") +print("[cyan]Snowpark Extensions Extras") +print("[cyan]Notebook Runner") print("[cyan]=============================") -print("This tool will connect using snowconfig file") arguments = argparse.ArgumentParser() -arguments.add_argument("--notebook",help="Jupyter Notebook to run",required=True) +arguments.add_argument("--notebook",help="Jupyter Notebook to run") +arguments.add_argument("--registerproc",default="",help="Register an stored proc that can then be used to run notebooks") arguments.add_argument("--stage",help="stage",default="NOTEBOOK_RUN") arguments.add_argument("--packages",help="packages",default="") +arguments.add_argument("--imports" ,help="imports" ,default="") +arguments.add_argument("--connection",dest="connection_args",nargs="*",required=True,help="Connect options, for example snowsql, snowsql connection,env") args = arguments.parse_args() -session = Session.builder.from_snowsql().getOrCreate() +print(args) +session = None +try: + if len(args.connection_args) >= 1: + first_arg = args.connection_args[0] + rest_args = args.connection_args[1:] + if first_arg == "snowsql": + session = Session.builder.from_snowsql(*rest_args).create() + elif first_arg == "env": + session = Session.builder.from_env().create() + else: + connection_args={} + for arg in args.connection_args: + key, value = arg.split("=") + connection_args[key] = value + session = Session.builder.configs(connection_args).create() +except Exception as e: + print(e) + print("[red] An error happened while trying to connect") + exit(1) +if not session: + print("[red] Not connected. 
Aborting") + exit(2) session.sql(f"CREATE STAGE IF NOT EXISTS {args.stage}").show() -session.file.put(args.notebook,f'@{args.stage}',auto_compress=False,overwrite=True) - +print(f"Uploading notebook to stage {args.stage}") +session.file.put(f"file://{args.notebook}",f'@{args.stage}',auto_compress=False,overwrite=True) +print(f"Notebook uploaded") packages=["snowflake-snowpark-python","nbconvert","nbformat","ipython","jinja2==3.0.3","plotly"] packages.extend(set(filter(None, args.packages.split(',')))) print(f"Using packages [magenta]{packages}") - -@sproc(replace=True,is_permanent=False,packages=packages,imports=["@test/snowpark_extensions.zip","@test/shortuuid.zip"]) #,"@test/IPython.zip" +imports=[] +if args.imports: + imports.extend(args.imports.split(',')) +is_permanent=False +@sproc(name=args.registerproc,replace=True,is_permanent=is_permanent,packages=packages,imports=[]) def run_notebook(session:Session,stage:str,notebook_filename:str) -> str: # (c) Matthew Wardrop 2019; Licensed under the MIT license # diff --git a/requirements.txt b/requirements.txt index 430cc3c..9a715a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ snowflake-snowpark-python[pandas] pandas -shortuuid rich nest_asyncio jinja2 diff --git a/runner.bat b/runner.bat deleted file mode 100644 index c50e991..0000000 --- a/runner.bat +++ /dev/null @@ -1,9 +0,0 @@ -@echo off -setlocal -REM Set the python io encoding to UTF-8 by default if not set. -IF "%PYTHONIOENCODING%"=="" ( - SET PYTHONIOENCODING="UTF-8" -) -python "%~dp0\runner" %* - -endlocal \ No newline at end of file diff --git a/setup.py b/setup.py index 4db518d..8f88c2b 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() -VERSION = '0.0.25' +VERSION = '0.0.26' setup(name='snowpark_extensions', version=VERSION, @@ -14,12 +14,8 @@ long_description_content_type='text/markdown', url='http://github.com/MobilizeNet/snowpark-extensions-py', author='mauricio.rojas', - install_requires=['snowflake-snowpark-python[pandas]>=1.1.0', - 'shortuuid', 'nest_asyncio', 'jinja2', 'rich'], + install_requires=['snowflake-snowpark-python[pandas]==1.1.0', + 'nest_asyncio', 'jinja2', 'rich'], author_email='mauricio.rojas@mobilize.net', packages=['snowpark_extensions'], - scripts=[ - 'runner', - 'runner.bat' - ], zip_safe=False) diff --git a/snowpark_extensions/__init__.py b/snowpark_extensions/__init__.py index 8b776dd..ee979c8 100644 --- a/snowpark_extensions/__init__.py +++ b/snowpark_extensions/__init__.py @@ -2,6 +2,7 @@ from .dataframe_extensions import * +from .dataframe_reader_extensions import * from .functions_extensions import * from .session_builder_extensions import * from .types_extensions import * diff --git a/snowpark_extensions/dataframe_extensions.py b/snowpark_extensions/dataframe_extensions.py index 8139fe5..75be104 100644 --- a/snowpark_extensions/dataframe_extensions.py +++ b/snowpark_extensions/dataframe_extensions.py @@ -2,17 +2,20 @@ from snowflake.snowpark.functions import col, lit, udtf, regexp_replace from snowflake.snowpark import functions as F from snowflake.snowpark.dataframe import _generate_prefix -from snowflake.snowpark.functions import table_function +from snowflake.snowpark.functions import table_function, udf from snowflake.snowpark.column import _to_col_if_str, _to_col_if_lit import pandas as pd import numpy as np from snowpark_extensions.utils import map_to_python_type, schema_str_to_schema -import shortuuid from 
snowflake.snowpark import context from snowflake.snowpark.types import StructType,StructField from snowflake.snowpark._internal.analyzer.expression import Expression, FunctionExpression from snowflake.snowpark._internal.analyzer.unary_expression import Alias from snowflake.snowpark._internal.analyzer.analyzer_utils import quote_name +from snowflake.snowpark import Window, Column +from snowflake.snowpark.types import * +from snowflake.snowpark.functions import udtf, col +from snowflake.snowpark.relational_grouped_dataframe import RelationalGroupedDataFrame from typing import ( TYPE_CHECKING, @@ -32,8 +35,50 @@ LiteralType, ) +from snowflake.snowpark._internal.utils import ( + parse_positional_args_to_list +) + +from snowflake.snowpark.table_function import ( + TableFunctionCall, + _create_table_function_expression, + _get_cols_after_join_table, +) + +from snowflake.snowpark._internal.analyzer.table_function import ( + TableFunctionJoin +) + +from snowflake.snowpark._internal.analyzer.select_statement import ( + SelectStatement, + SelectSnowflakePlan +) + if not hasattr(DataFrame,"___extended"): + DataFrame.___extended = True + + # we need to extend the alias function for + # table function to allow the situation where + # the function returns several columns + def adjusted_table_alias(self,*aliases) -> "TableFunctionCall": + canon_aliases = [quote_name(col) for col in aliases] + if len(set(canon_aliases)) != len(aliases): + raise ValueError("All output column names after aliasing must be unique.") + if hasattr(self, "alias_adjust"): + """ + currently tablefunctions are rendered as table(func(....)) + One option later on could be to render this is (select col1,col2,col3,col4 from table(func(...))) + aliases can the be use as (select col1 alias1,col2 alias2 from table(func(...))) + """ + self._aliases = self.alias_adjust(*canon_aliases) + else: + self._aliases = canon_aliases + return self + + TableFunctionCall.alias = adjusted_table_alias + TableFunctionCall.as_ = adjusted_table_alias + def get_dtypes(schema): data = np.array([map_to_python_type(x.datatype) for x in schema.fields]) # providing an index @@ -44,13 +89,13 @@ def get_dtypes(schema): def map(self,func,output_types,input_types=None,input_cols=None,to_row=False): - clazz="map"+shortuuid.uuid()[:8] + clazz= _generate_prefix("map") output_schema=[] if not input_types: input_types = [x.datatype for x in self.schema.fields] input_cols_len=len(input_types) if not input_cols: - input_col_names=self.columns[:input_cols_len] + input_col_names=self.columns[:input_cols_len] _input_cols = [self[x] for x in input_col_names] else: input_col_names=input_cols @@ -75,8 +120,8 @@ def simple_map(self,func): DataFrame.simple_map = simple_map - DataFrameNaFunctions.__oldreplace = DataFrameNaFunctions.replace - + DataFrameNaFunctions.__oldreplace = DataFrameNaFunctions.replace + def extended_replace( self, to_replace: Union[ @@ -92,287 +137,88 @@ def extended_replace( return self._df.select([regexp_replace(col(x.name), to_replace,value).alias(x.name) if isinstance(x.datatype,StringType) else col(x.name) for x in self._df.schema]) else: return self.__oldreplace(to_replace,value,subset) - + DataFrameNaFunctions.replace = extended_replace def has_null(col): return F.array_contains(F.sql_expr("parse_json('null')"),col) | F.coalesce(F.array_contains(lit(None) ,col),lit(False)) - - # SPECIAL COLUMN HELPERS - - class SpecialColumn(Column): - def __init__(self,special_column_base_name="special_column"): - self.special_col_name = 
_generate_prefix(special_column_base_name) - super().__init__(self.special_col_name) - self._is_special_column = True - self._expression._is_special_column = True - sc = self - def _expand(df): - nonlocal sc - return sc.expand(df) - self._expression.expand = _expand - self._special_column_dependencies=[] - import uuid - self._hash = hash(str(uuid.uuid4())) - self.expanded = False - def __hash__(self): - return self._hash - def __eq__(self, other): - return self._hash == other._hash - def gen_unique_value_name(self,idx,base_name): - return base_name if idx == 0 else f"{base_name}_{idx}" - def add_columns(self,new_cols, alias:str=None): - pass - def expand(self,df): - self.expanded = True - @classmethod - def extract_specials(cls,c): - if isinstance(c, Expression): - if c.children: - for child in c.children: - is_child_special = cls.has_special_column(child) - if is_child_special: - return True - if hasattr(c,"_is_special_column"): - return True - if hasattr(c,"_expression") and c._expression and c._expression.children: - for child in c._expression.children: - is_child_special = cls.has_special_column(child) - if is_child_special: - return True - @classmethod - def any_specials(cls,*cols): - for c in cols: - for special in cls.specials(c): - return True - @classmethod - def specials(cls,c): - """ Returns special columns that might add a join clause to the dataframe """ - if isinstance(c, Expression): - if c.children: - for child in c.children: - for special in cls.specials(child): - yield special - elif hasattr(c,"_expression") and c._expression.children: - for child in c._expression.children: - for special in cls.specials(child): - yield special - if hasattr(c,"_special_column_dependencies") and c._special_column_dependencies: - for dependency in c._special_column_dependencies: - for special in cls.specials(dependency): - yield special - if hasattr(c,"_is_special_column"): - yield c - - class MapValues(SpecialColumn): - def __init__(self,array_col): - super().__init__("values") - self.array_col = array_col - self._special_column_dependencies = [array_col] - def add_columns(self, new_cols, alias:str = None): - # add itself as column - new_cols.append(self.alias(alias) if alias else self) - def expand(self,df): - if not self.expanded: - array_col = _to_col_if_str(self.array_col, "values") - df = df.with_column("__IDX",F.seq8()) - flatten = table_function("flatten") - seq=_generate_prefix("SEQ") - key=_generate_prefix("KEY") - path=_generate_prefix("PATH") - index=_generate_prefix("INDEX") - value=_generate_prefix("VALUE") - this=_generate_prefix("THIS") - df_values=df.join_table_function(flatten(input=array_col,outer=lit(True)).alias(seq,key,path,index,value,this)).group_by("__IDX").agg(F.array_agg(value).alias(self.special_col_name)).distinct() - df = df.join(df_values,on="__IDX").drop("__IDX") - self.expanded = True - return df - - class ArrayFlatten(SpecialColumn): - def __init__(self,flatten_col,remove_arrays_when_there_is_a_null): - super().__init__("flatten") - self.flatten_col = flatten_col - self.remove_arrays_when_there_is_a_null = remove_arrays_when_there_is_a_null - self._special_column_dependencies = [flatten_col] - def add_columns(self, new_cols, alias:str = None): - new_cols.append(self.alias(alias) if alias else self) - def expand(self,df): - if not self.expanded: - array_col = _to_col_if_str(self.flatten_col, "flatten") - flatten = table_function("flatten") - 
df=df.join_table_function(flatten(array_col).alias("__SEQ_FLATTEN","KEY","PATH","__INDEX_FLATTEN","__FLATTEN_VALUE","THIS")) - df = df.drop("KEY","PATH","THIS") - if self.remove_arrays_when_there_is_a_null: - df_with_has_null=df.withColumn("__HAS_NULL",has_null(array_col)) - df_flattened= df_with_has_null.group_by(col("__SEQ_FLATTEN")).agg(F.call_builtin("BOOLOR_AGG",col("__HAS_NULL")).alias("__HAS_NULL"),F.call_builtin("ARRAY_UNION_AGG",col("__FLATTEN_VALUE")).alias("__FLATTEN_VALUE")) - df_flattened=df_flattened.with_column("__FLATTEN_VALUE",F.iff("__HAS_NULL", lit(None), col("__FLATTEN_VALUE"))).drop("__HAS_NULL") - df=df.drop("__FLATTEN_VALUE").where(col("__INDEX_FLATTEN")==0).join(df_flattened,on="__SEQ_FLATTEN").drop("__SEQ_FLATTEN","__INDEX_FLATTEN").rename("__FLATTEN_VALUE",self.special_col_name) - else: - df_flattened= df.group_by(col("__SEQ_FLATTEN")).agg(F.call_builtin("ARRAY_UNION_AGG",col("__FLATTEN_VALUE")).alias("__FLATTEN_VALUE")) - df=df.drop("__FLATTEN_VALUE").where(col("__INDEX_FLATTEN")==0).join(df_flattened,on="__SEQ_FLATTEN").drop("__SEQ_FLATTEN","__INDEX_FLATTEN").rename("__FLATTEN_VALUE",self.special_col_name) - self.expanded = True - return df - - class ArrayZip(SpecialColumn): - def __init__(self,left,*right,use_compat=False): - super().__init__("zipped") - self.left_col = left - self.right_cols = right - self._special_column_dependencies = [left,*right] - self._use_compat = use_compat - def add_columns(self,new_cols,alias:str = None): - new_cols.append(self.alias(alias) if alias else self) - def expand(self,df): - if not self.expanded: - if (not hasattr(ArrayZip,"_added_compat_func")) and self._use_compat: - context.get_active_session().sql(""" - CREATE OR REPLACE TEMPORARY FUNCTION PUBLIC.ARRAY_UNDEFINED_COMPACT(ARR VARIANT) RETURNS ARRAY - LANGUAGE JAVASCRIPT AS - $$ - if (ARR.includes(undefined)){ - filtered = ARR.filter(x => x === undefined); - if (filtered.length==0) - return filtered; - } - return ARR; - $$; - """).count() - ArrayZip._added_compat_func = True - df_with_idx = df.with_column("_IDX",F.seq8()) - flatten = table_function("flatten") - right = df_with_idx.select("_IDX",self.left_col)\ - .join_table_function(flatten(input=self.left_col,outer=lit(True))\ - .alias("SEQ","KEY","PATH","INDEX","__VALUE_0","THIS")) \ - .drop(self.left_col,"SEQ","KEY","PATH","THIS") - vals=["__VALUE_0"] - for right_col in self.right_cols: - prior=len(vals)-1 - next=len(vals) - left_col_name=f"__VALUE_{prior}" - right_col_name=f"__VALUE_{next}" - vals.append(right_col_name) - new_right=df_with_idx.select("_IDX",right_col).join_table_function(flatten(input=right_col,outer=lit(False))\ - .alias("SEQ","KEY","PATH","INDEX",right_col_name,"THIS")) \ - .drop(right_col,"SEQ","KEY","PATH","THIS") #.with_column("INDEX",F.coalesce(col("INDEX"),lit(0))) \ - if right: - right = right.join(new_right,on=["_IDX","INDEX"],how="left",lsuffix="___LEFT") - else: - right = new_right - zipped = right.select("_IDX","INDEX",F.array_construct(*vals).alias("NGROUP")) - if self._use_compat: - zipped = zipped.with_column("NGROUP",F.call_builtin("ARRAY_UNDEFINED_COMPACT",col("NGROUP"))) - zipped=zipped.group_by("_IDX").agg(F.sql_expr(f'arrayagg(ngroup) within group (order by INDEX) {self.special_col_name}')) - result = df_with_idx.join(zipped,on="_IDX").drop("_IDX") - df = result - self.expanded = True - return df - - class Explode(SpecialColumn): - def __init__(self,expr,map=False,outer=False,use_compat=False): - """ Right not it must be explictly stated if the value is a map. 
By default it is assumed it is not""" - super().__init__("value" if map else "col") - self.expr = expr - self.map = map - self.outer = outer - self.key_col_name = None - self.key_col = None - self.value_col_name = self.special_col_name - self._special_column_dependencies = [expr] - self.use_compat=use_compat - def add_columns(self,new_cols,alias:str = None): - if self.map: - self.key_col_name = _generate_prefix("key") - self.key_col = col(self.key_col_name) - new_cols.append(self.key_col.alias(alias + "_1") if alias else self.key_col) - new_cols.append(self.alias(alias) if alias else self) - def expand(self,df): - if not self.expanded: - if self.key_col_name is not None: - df = df.join_table_function(flatten(input=self.expr,outer=lit(self.outer)).alias("SEQ",self.key_col_name,"PATH","INDEX",self.value_col_name,"THIS")).drop(["SEQ","PATH","INDEX","THIS"]) - else: - df = df.join_table_function(flatten(input=self.expr,outer=lit(self.outer)).alias("SEQ","KEY","PATH","INDEX",self.value_col_name,"THIS")).drop(["SEQ","KEY","PATH","INDEX","THIS"]) - if self.use_compat: - df=df.with_column(self.value_col_name, - F.iff( - F.cast(self.value_col_name,ArrayType()) == F.array_construct(), - lit(None), - F.cast(self.value_col_name,ArrayType()))) - self.expanded = True - return df - - def explode(expr,outer=False,map=False,use_compat=False): - return Explode(expr,map,outer,use_compat=use_compat) - - def explode_outer(expr,map=False, use_compat=False): - return Explode(expr,map,True,use_compat=use_compat) - - F.explode = explode - F.explode_outer = explode_outer - def _arrays_zip(left,*right,use_compat=False): - """ In SF zip might return [undefined,...undefined] instead of [] """ - return ArrayZip(left,*right,use_compat=use_compat) - def _arrays_flatten(array_col,remove_arrays_when_there_is_a_null=True): - return ArrayFlatten(array_col,remove_arrays_when_there_is_a_null) - def _map_values(col:ColumnOrName): - col = _to_col_if_str(col,"map_values") - return MapValues(col) - - F.map_values = _map_values - F.arrays_zip = _arrays_zip - F.flatten = _arrays_flatten - flatten = table_function("flatten") - _oldwithColumn = DataFrame.withColumn - _oldSelect = DataFrame.select - def withColumnExtended(self,colname,expr): - if isinstance(expr, SpecialColumn): - new_cols = [] - expr.add_columns(new_cols, alias=colname) - df = self - for s in SpecialColumn.specials(expr): - df=s.expand(df) - return _oldSelect(df,*self.columns,*new_cols) - #return self.with_columns(df,new_cols,[col(x) for x in new_cols]) - else: - return _oldwithColumn(self,colname,expr) - - DataFrame.withColumn = withColumnExtended - - - def selectExtended(self,*cols): - if SpecialColumn.any_specials(*cols): - new_cols = [] - extended_cols = [] - # extend only the main cols - for x in cols: - if isinstance(x, SpecialColumn): - x.add_columns(new_cols) - extended_cols.append(x) + def selectExtended(self,*cols) -> "DataFrame": + exprs = parse_positional_args_to_list(*cols) + if not exprs: + raise ValueError("The input of select() cannot be empty") + names = [] + table_func = None + join_plan = None + for e in exprs: + if isinstance(e, Column): + names.append(e._named()) + elif isinstance(e, str): + names.append(Column(e)._named()) + elif isinstance(e, TableFunctionCall): + if table_func: + raise ValueError( + f"At most one table function can be called inside a select(). " + f"Called '{table_func.name}' and '{e.name}'." 
+ ) + table_func = e + func_expr = _create_table_function_expression(func=table_func) + join_plan = self._session._analyzer.resolve( + TableFunctionJoin(self._plan, func_expr) + ) + _, new_cols = _get_cols_after_join_table( + func_expr, self._plan, join_plan + ) + names.extend(new_cols) else: - new_cols.append(x) - df = self - # but expand all tables, because there could be several joins - for c in cols: - for extended_col in SpecialColumn.specials(c): - df = extended_col.expand(df) - return _oldSelect(df,*[_to_col_if_str(x,"extended") for x in new_cols]) - else: - return _oldSelect(self,*cols) - - DataFrame.select = selectExtended + raise TypeError( + "The input of select() must be Column, column name, TableFunctionCall, or a list of them" + ) + if self._select_statement: + if join_plan: + result=self._with_plan( + SelectStatement( + from_=SelectSnowflakePlan( + join_plan, analyzer=self._session._analyzer + ), + analyzer=self._session._analyzer, + ).select(names) + ) + if table_func and hasattr(table_func,"post_action"): + result = table_func.post_action(result) + return result + return self._with_plan(self._select_statement.select(names)) + result = self._with_plan(Project(names, join_plan or self._plan)) + if table_func and hasattr(table_func,"post_action"): + result = table_func.post_action(result) + return result -import shortuuid -from snowflake.snowpark import Window, Column -from snowflake.snowpark.types import * -from snowflake.snowpark.functions import udtf, col -from snowflake.snowpark.relational_grouped_dataframe import RelationalGroupedDataFrame + DataFrame.select = selectExtended -def group_by_pivot(self,pivot_col): - return GroupByPivot(self, pivot_col) -RelationalGroupedDataFrame.pivot = group_by_pivot + def stack(self,rows:int,*cols): + count_cols = len(cols) + if count_cols % rows != 0: + raise Exception("Invalid parameter. 
The given cols cannot be arrange in the give rows") + out_count_cols = int(count_cols / rows) + # determine the input schema + input_schema = self.select(cols).limit(1).schema + from snowflake.snowpark.functions import col + input_types = [x.datatype for x in input_schema.fields] + input_cols = [x.name for x in input_schema.fields] + output_cols = [f"col{x}" for x in range(1,out_count_cols)] + clazz=_generate_prefix("stack") + def process(self, *row): + for i in range(0, len(row), out_count_cols): + yield tuple(row[i:i+out_count_cols]) + output_schema = StructType([StructField(f"col{i+1}",input_schema.fields[i].datatype) for i in range(0,out_count_cols)]) + udtf_class = type(clazz, (object, ), {"process":process}) + tfunc = udtf(udtf_class,output_schema=output_schema, input_types=input_types,name=clazz,replace=True,is_permanent=False,packages=["snowflake-snowpark-python"]) + return tfunc(*cols) + + DataFrame.stack = stack -class GroupByPivot(): + class GroupByPivot(): def __init__(self,old_groupby_col,pivot_col): self.old_groupby_col = old_groupby_col self.pivot_col=pivot_col @@ -414,12 +260,16 @@ def max(self, col: ColumnOrName) -> DataFrame: def count(self) -> DataFrame: """Return the number of rows for each group.""" return self.clean(self.prepare(col).count("__firstAggregate")) - def agg(self, aggregated_col: ColumnOrName) -> DataFrame: + def agg(self, aggregated_col: ColumnOrName) -> DataFrame: if hasattr(aggregated_col, "_expression") and isinstance(aggregated_col._expression, FunctionExpression): name = aggregated_col._expression.name return self.clean(self.prepare(aggregated_col).function(name)(col("__firstAggregate"))) else: - raise Exception("Also functions expressions are supported") + raise Exception("Also functions expressions are supported") + + def group_by_pivot(self,pivot_col): + return GroupByPivot(self, pivot_col) + RelationalGroupedDataFrame.pivot = group_by_pivot if not hasattr(RelationalGroupedDataFrame, "applyInPandas"): def applyInPandas(self,func,schema,batch_size=16000): @@ -431,7 +281,7 @@ def applyInPandas(self,func,schema,batch_size=16000): input_cols = [x.name for x in self._df.schema.fields] output_cols = [x.name for x in output_schema.fields] grouping_exprs = [Column(x) for x in self._grouping_exprs] - clazz="applyInPandas"+shortuuid.uuid()[:8] + clazz=_generate_prefix("applyInPandas") def __init__(self): self.rows = [] self.dfs = [] @@ -453,7 +303,7 @@ def end_partition(self): df = pd.DataFrame(self.rows, columns=input_cols) self.dfs.append(df) self.rows = [] - pandas_input = pd.concat(self.dfs) + pandas_input = pd.concat(self.dfs) pandas_output = func(pandas_input) for row in pandas_output.itertuples(index=False): yield tuple(row) @@ -464,4 +314,4 @@ def end_partition(self): return self._df.join_table_function(tfunc(*input_cols).over(partition_by=grouping_exprs, order_by=grouping_exprs)).select(*renamed_back) RelationalGroupedDataFrame.applyInPandas = applyInPandas - ###### HELPER END + ###### HELPER END diff --git a/snowpark_extensions/dataframe_reader_extensions.py b/snowpark_extensions/dataframe_reader_extensions.py new file mode 100644 index 0000000..d6a7149 --- /dev/null +++ b/snowpark_extensions/dataframe_reader_extensions.py @@ -0,0 +1,59 @@ +from snowflake.snowpark import DataFrame, Row, DataFrameReader +from snowflake.snowpark.types import StructType +from snowflake.snowpark import context +from typing import Any, Union, List, Optional +from snowflake.snowpark.functions import lit +from snowflake.snowpark.dataframe import _generate_prefix + +if 
not hasattr(DataFrameReader,"___extended"): + + DataFrameReader.___extended = True + DataFrameReader.__option = DataFrameReader.option + def _option(self, key: str, value: Any) -> "DataFrameReader": + key = key.upper() + if key == "SEP": + key = "FIELD_DELIMITER" + elif key == "HEADER": + key ="SKIP_HEADER" + value = 1 if value == True or str(value).upper() == "TRUE" else 0 + self.__option(key,value) + + def _load(self,path: Union[str, List[str], None] = None, format: Optional[str] = None, schema: Union[StructType, str, None] = None,stage=None, **options) -> "DataFrame": + self.options(dict(options)) + self.format(format) + if schema: + self.schema(schema) + files = [] + if isinstance(path,list): + files.extend(path) + else: + files.append(path) + session = context.get_active_session() + if stage is None: + stage = f'{session.get_fully_qualified_current_schema()}.{_generate_prefix("TEMP_STAGE")}' + session.sql(f'create TEMPORARY stage if not exists {stage}').show() + stage_files = [x for x in path if x.startswith("@")] + if len(stage_files) > 1: + raise Exception("Currently only one staged file can be specified. You can use a pattern if you want to specify several files") + print(f"Uploading files using stage {stage}") + for file in files: + if file.startswith("file://"): # upload local file + session.file.put(file,stage) + elif file.startswith("@"): #ignore it is on an stage + return self._read_semi_structured_file(file,format) + else: #assume it is file too + session.file.put(f"file://{file}",f"@{stage}") + if self._file_type == "csv": + return self.csv(f"@{stage}") + return self._read_semi_structured_file(f"@{stage}",format) + + def _format(self, file_type: str) -> "DataFrameReader": + file_type = str(file_type).lower() + if file_type in ["csv","json","avro","orc","parquet","xml"]: + self._file_type = file_type + else: + raise Exception(f"Unsupported file format {file_type}") + + DataFrameReader.format = _format + DataFrameReader.load = _load + DataFrameReader.option = _option \ No newline at end of file diff --git a/snowpark_extensions/functions_extensions.py b/snowpark_extensions/functions_extensions.py index 6eb02a3..8cac1b0 100644 --- a/snowpark_extensions/functions_extensions.py +++ b/snowpark_extensions/functions_extensions.py @@ -3,7 +3,7 @@ from snowflake.snowpark import functions as F from snowflake.snowpark import context -from snowflake.snowpark.functions import call_builtin, col,lit, concat, coalesce, object_construct_keep_null +from snowflake.snowpark.functions import call_builtin, col,lit, concat, coalesce, object_construct_keep_null, table_function, udf from snowflake.snowpark import DataFrame, Column from snowflake.snowpark.types import ArrayType, BooleanType from snowflake.snowpark._internal.type_utils import ( @@ -52,13 +52,6 @@ def regexp_extract(value:ColumnOrLiteralStr,regexp:ColumnOrLiteralStr,idx:int) - # we add .* to the expression if needed return coalesce(call_builtin('regexp_substr',value,regexp,lit(1),lit(1),lit('e'),idx),lit('')) - def unix_timestamp(col): - return call_builtin("DATE_PART","epoch_second",col) - - def from_unixtime(col): - col = _to_col_if_str(col,"from_unixtime") - return F.to_timestamp(col).alias('ts') - def format_number(col,d): col = _to_col_if_str(col,"format_number") return F.to_varchar(col,'999,999,999,999,999.' 
+ '0'*d) @@ -100,7 +93,7 @@ def create_map(*col_names): col_list.append(value) return object_construct(*col_list) - def array_distinct(col): + def _array_distinct(col): col = _to_col_if_str(col,"array_distinct") return F.call_builtin('array_distinct',col) @@ -108,15 +101,15 @@ def array_distinct(col): def _array(*cols): return F.array_construct(*cols) - F._sort_array_function = None + F._sort_array_udf = None def _sort_array(col:ColumnOrName,asc:ColumnOrLiteral=True): - if not F._sort_array_function: + if not F._sort_array_udf: session = context.get_active_session() current_database = session.get_current_database() function_name =_generate_prefix("_sort_array_helper") - F._sort_array_function = f"{current_database}.public.{function_name}" + F._sort_array_udf = f"{current_database}.public.{function_name}" session.sql(f""" - create or replace temporary function {F._sort_array_function}(ARR ARRAY,ASC BOOLEAN) returns ARRAY + create or replace temporary function {F._sort_array_udf}(ARR ARRAY,ASC BOOLEAN) returns ARRAY language javascript as $$ ARRLENGTH = ARR.length; @@ -168,18 +161,18 @@ def _sort_array(col:ColumnOrName,asc:ColumnOrLiteral=True): var RES = new Array(ARRLENGTH-ARR.length).fill(null).concat(ARR); if (ASC) return RES; else return RES.reverse(); $$;""").show() - return call_builtin(F._sort_array_function,col,asc) + return call_builtin(F._sort_array_udf,col,asc) - F._array_sort_function = None + F._array_sort_udf = None def _array_sort(col:ColumnOrName): - if not F._array_sort_function: + if not F._array_sort_udf: session = context.get_active_session() current_database = session.get_current_database() function_name =_generate_prefix("_array_sort_helper") - F._array_sort_function = f"{current_database}.public.{function_name}" + F._array_sort_udf = f"{current_database}.public.{function_name}" session.sql(f""" - create or replace temporary function {F._array_sort_function}(ARR ARRAY) returns ARRAY + create or replace temporary function {F._array_sort_udf}(ARR ARRAY) returns ARRAY language javascript as $$ ARRLENGTH = ARR.length; @@ -231,37 +224,37 @@ def _array_sort(col:ColumnOrName): var RES = ARR.concat(new Array(ARRLENGTH-ARR.length).fill(null)); return RES; $$;""").show() - return call_builtin(F._array_sort_function,col) - F._array_max_function = None + return call_builtin(F._array_sort_udf,col) + F._array_max_udf = None def _array_max(col:ColumnOrName): - if not F._array_max_function: + if not F._array_max_udf: session = context.get_active_session() current_database = session.get_current_database() function_name =_generate_prefix("_array_max_function") - F._array_max_function = f"{current_database}.public.{function_name}" + F._array_max_udf = f"{current_database}.public.{function_name}" session.sql(f""" - create or replace temporary function {F._array_max_function}(ARR ARRAY) returns VARIANT + create or replace temporary function {F._array_max_udf}(ARR ARRAY) returns VARIANT language javascript as $$ return Math.max(...ARR); $$ """).show() - return call_builtin(F._array_max_function,col) - F._array_min_function = None + return call_builtin(F._array_max_udf,col) + F._array_min_udf = None def _array_min(col:ColumnOrName): - if not F._array_min_function: + if not F._array_min_udf: session = context.get_active_session() current_database = session.get_current_database() - function_name =_generate_prefix("_array_min_function") - F._array_min_function = f"{current_database}.public.{function_name}" + function_name =_generate_prefix("_array_min_udf") + F._array_min_udf = 
f"{current_database}.public.{function_name}" session.sql(f""" - create or replace temporary function {F._array_min_function}(ARR ARRAY) returns VARIANT + create or replace temporary function {F._array_min_udf}(ARR ARRAY) returns VARIANT language javascript as $$ return Math.min(...ARR); $$ """).show() - return call_builtin(F._array_min_function,col) + return call_builtin(F._array_min_udf,col) def _struct(*cols): new_cols = [] @@ -280,6 +273,52 @@ def _struct(*cols): new_cols.append(c) return object_construct_keep_null(*new_cols) + F._array_flatten_udf = None + def _array_flatten(array): + if not F._array_flatten_udf: + @udf + def _array_flatten(array_in:list) -> list: + flat_list = [] + for sublist in array_in: + if type(sublist) == list: + flat_list.extend(sublist) + else: + flat_list.append(sublist) + return flat_list + F._array_flatten_udf = _array_flatten + array = _to_col_if_str(array, "array_flatten") + return F._array_flatten_udf(array) + + F._array_zip_udfs = {} + + def build_array_zip_ddl(nargs:int): + function_name = _generate_prefix(f"array_zip_{nargs}") + args = ",".join([f"list{x} ARRAY" for x in range(1,nargs+1)]) + args_names = ",".join([f"list{x}" for x in range(1,nargs+1)]) + return function_name,f""" +CREATE OR REPLACE TEMPORARY FUNCTION {function_name}({args}) +returns ARRAY language python runtime_version = '3.8' +handler = 'zip_list' +as +$$ +def zip_list({args_names}): + return list(zip({args_names})) +$$;""" + + def _arrays_zip(*lists): + nargs = len(lists) + if nargs < 2: + raise Exception("At least two list are needed for array_zip") + if not str(nargs) in F._array_zip_udfs: + try: + function_name, udf_ddl = build_array_zip_ddl(nargs) + context.get_active_session().sql(udf_ddl).show() + F._array_zip_udfs[str(nargs)] = function_name + except Exception as e: + raise Exception(f"Could not register support udf for array_zip. 
Error: {e}") + list_cols = [_to_col_if_str(x, "array_zip") for x in lists] + return F.call_builtin(F._array_zip_udfs[str(nargs)],*list_cols) + def _bround(col: Column, scale: int = 0): power = pow(F.lit(10), F.lit(scale)) elevatedColumn = F.when(F.lit(0) == F.lit(scale), col).otherwise(col * power) @@ -297,24 +336,21 @@ def has_special_char(string): def is_not_a_regex(pattern): return not has_special_char(pattern) - F._split_regex_function = None - F.snowflake_split = F.split - def _regexp_split(value:ColumnOrName, pattern:ColumnOrLiteralStr, limit:int = -1): - + F._split_regex_udf = None + def _regexp_split(value:ColumnOrName, pattern:ColumnOrLiteralStr, limit:int = -1): value = _to_col_if_str(value,"split_regex") - pattern_col = pattern if isinstance(pattern, str): pattern_col = lit(pattern) if limit < 0 and isinstance(pattern, str) and is_not_a_regex(pattern): - return F.snowflake_split(value, pattern_col) + return F.split(value, pattern_col) session = context.get_active_session() current_database = session.get_current_database() function_name =_generate_prefix("_regex_split_helper") - F._split_regex_function = f"{current_database}.public.{function_name}" + F._split_regex_udf = f"{current_database}.public.{function_name}" - session.sql(f"""CREATE OR REPLACE FUNCTION {F._split_regex_function} (input String, regex String, limit INT) + session.sql(f"""CREATE OR REPLACE FUNCTION {F._split_regex_udf} (input String, regex String, limit INT) RETURNS ARRAY LANGUAGE JAVA RUNTIME_VERSION = '11' @@ -328,29 +364,73 @@ def _regexp_split(value:ColumnOrName, pattern:ColumnOrLiteralStr, limit:int = -1 Pattern pattern = Pattern.compile(regex); return pattern.split(input, limit); }}}}$$;""").show() + return call_builtin(F._split_regex_udf, value, pattern_col, limit) + + def _explode(expr,outer=False,map=False,use_compat=False): + value_col = "explode" + if map: + key = "key" + value_col = "value" + else: + key = _generate_prefix("KEY") + seq = _generate_prefix("SEQ") + path = _generate_prefix("PATH") + index = _generate_prefix("INDEX") + this = _generate_prefix("THIS") + flatten = table_function("flatten") + explode_res = flatten(input=expr,outer=lit(outer)).alias(seq,key,path,index,value_col,this) + # we patch the alias, to simplify explode use case where only one column is used + if not map: + explode_res.alias_adjust = lambda alias1 : [seq,key,path,index,alias1,this] + # post action to execute after join + def post_action(df): + drop_columns = [seq,path,index,this] if map else [seq,key,path,index,this] + df = df.drop(drop_columns) + if use_compat: + # in case we need backwards compatibility with spark behavior + df=df.with_column(value_col, + F.iff(F.cast(value_col,ArrayType()) == F.array_construct(),lit(None),F.cast(value_col,ArrayType()))) + return df + explode_res.post_action = post_action + return explode_res + + def _explode_outer(expr,map=False, use_compat=False): + return _explode(expr,outer=True,map=map,use_compat=use_compat) + + F._map_values_udf = None + def _map_values(col:ColumnOrName): + col = _to_col_if_str(col,"map_values") + if not F._map_values_udf: + @udf(replace=True,is_permanent=False) + def map_values(obj:dict)->list: + return list(obj.values()) + F._map_values_udf = map_values + return F._map_values_udf(col) + + + - return call_builtin(F._split_regex_function, value, pattern_col, limit) - F.array = _array - F.array_max = _array_max - F.array_min = _array_min - F.array_distinct = array_distinct + F.array = _array + F.array_max = _array_max + F.array_min = _array_min + 
F.array_flatten = _array_flatten + F.array_distinct = _array_distinct + F.array_sort = _array_sort + F.arrays_zip = _arrays_zip + F.bround = _bround + F.create_map = create_map + F.daydiff = daydiff + F.date_add = date_add + F.date_sub = date_sub + F.explode = _explode + F.explode_outer = _explode_outer + F.format_number = format_number + F.flatten = _array_flatten + F.map_values = _map_values F.regexp_extract = regexp_extract - F.create_map = create_map - F.unix_timestamp = unix_timestamp - F.from_unixtime = from_unixtime - F.format_number = format_number - F.reverse = reverse - F.daydiff = daydiff - F.date_add = date_add - F.date_sub = date_sub - F.asc = lambda col: _to_col_if_str(col, "asc").asc() - F.desc = lambda col: _to_col_if_str(col, "desc").desc() - F.asc_nulls_first = lambda col: _to_col_if_str(col, "asc_nulls_first").asc() - F.desc_nulls_first = lambda col: _to_col_if_str(col, "desc_nulls_first").asc() - F.sort_array = _sort_array - F.array_sort = _array_sort - F.struct = _struct - F.bround = _bround - F.regexp_split = _regexp_split \ No newline at end of file + F.regexp_split = _regexp_split + F.reverse = reverse + F.sort_array = _sort_array + F.struct = _struct \ No newline at end of file diff --git a/snowpark_extensions/session_builder_extensions.py b/snowpark_extensions/session_builder_extensions.py index 971c9fe..ea1329f 100644 --- a/snowpark_extensions/session_builder_extensions.py +++ b/snowpark_extensions/session_builder_extensions.py @@ -3,12 +3,11 @@ from pathlib import Path from snowflake.snowpark import Session import os -import shortuuid import io import sys if not hasattr(Session.SessionBuilder,"___extended"): - + from snowflake.snowpark._internal.utils import generate_random_alphanumeric _logger = logging.getLogger(__name__) def console_handler(stream='stdout'): @@ -38,7 +37,7 @@ def SessionBuilder_extendedcreate(self): session = self.___create() if hasattr(self,"__appname__"): setattr(session, "__appname__", self.__appname__) - uuid = shortuuid.uuid() + uuid = generate_random_alphanumeric() session.query_tag = f"APPNAME={session.__appname__};execution_id={uuid}" return session Session.SessionBuilder.create = SessionBuilder_extendedcreate diff --git a/tests/data/test1_0.csv b/tests/data/test1_0.csv new file mode 100644 index 0000000..3793cda --- /dev/null +++ b/tests/data/test1_0.csv @@ -0,0 +1,6 @@ +case_id,province,city,group,infection_case,confirmed,latitude,longitude +1000001,Seoul,Yongsan-gu,TRUE,Itaewon Clubs,139,37.538621,126.992652 +1000002,Seoul,Gwanak-gu,TRUE,Richway,119,37.48208,126.901384 +1000003,Seoul,Guro-gu,TRUE,Guro-gu Call Center,95,37.508163,126.884387 +1000004,Seoul,Yangcheon-gu,TRUE,Yangcheon Table Tennis Club,43,37.546061,126.874209 +1000005,Seoul,Dobong-gu,TRUE,Day Care Center,43,37.679422,127.044374 \ No newline at end of file diff --git a/tests/data/test1_1.csv b/tests/data/test1_1.csv new file mode 100644 index 0000000..41e8615 --- /dev/null +++ b/tests/data/test1_1.csv @@ -0,0 +1,6 @@ +case_id,province,city,group,infection_case,confirmed,latitude,longitude +1000006,Seoul,Guro-gu,TRUE,Manmin Central Church,41,37.481059,126.894343 +1000007,Seoul,from other city,TRUE,SMR Newly Planted Churches Group,36,0,0 +1000008,Seoul,Dongdaemun-gu,TRUE,Dongan Church,17,37.592888,127.056766 +1000009,Seoul,from other city,TRUE,Coupang Logistics Center,25,0,0 +1000010,Seoul,Gwanak-gu,TRUE,Wangsung Church,30,37.481735,126.930121 \ No newline at end of file diff --git a/tests/test_dataframe_extensions.py b/tests/test_dataframe_extensions.py index 
1e9ddca..3990987 100644 --- a/tests/test_dataframe_extensions.py +++ b/tests/test_dataframe_extensions.py @@ -217,17 +217,17 @@ def test_explode_outer_with_array(): assert results[2].ID == 2 and results[2].COL == None assert results[3].ID == 3 and results[3].COL == None -def test_array_zip_compat(): +def test_array_zip(): session = Session.builder.from_snowsql().getOrCreate() df = session.createDataFrame([([2, None, 3],),([1],),([],)], ['data']) # +---------------+ # | data| # +---------------+ - # |[2, 1, null, 3]| + # |[2, null, 3] | # | [1]| # | []| # +---------------+ - df = df.withColumn("FIELDS", F.arrays_zip("data","data",use_compat=True)) + df = df.withColumn("FIELDS", F.arrays_zip("data","data")) # +------------+------------------------------+ # |data |FIELDS | # +------------+------------------------------+ @@ -239,46 +239,27 @@ def test_array_zip_compat(): assert len(res)==3 res1 = eval(res[0][1].replace("null","None")) res2 = eval(res[1][1].replace("null","None")) - # NOTE: SF will not return null but undefined res3 = eval(res[2][1]) assert res1==[[2,2],[None,None],[3,3]] assert res2==[[1,1]] - assert res3==[[]] - -def test_array_zip(): - session = Session.builder.from_snowsql().getOrCreate() - df = session.createDataFrame([([2, None, 3],),([1],),([],)], ['data']) - # +---------------+ - # | data| - # +---------------+ - # |[2, 1, null, 3]| - # | [1]| - # | []| - # +---------------+ - df = df.withColumn("FIELDS", F.arrays_zip("data","data")) - # +------------+------------------------------+ - # |data |FIELDS | - # +------------+------------------------------+ - # |[2, null, 3]|[{2, 2}, {null, null}, {3, 3}]| - # |[1] |[{1, 1}] | - # |[] |[] | - # +------------+------------------------------+ + assert res3==[] + df = df.withColumn("FIELDS", F.arrays_zip("data","data","data")).orderBy("data") res = df.collect() - assert len(res)==3 res1 = eval(res[0][1].replace("null","None")) res2 = eval(res[1][1].replace("null","None")) - # NOTE: SF will not return null but undefined - res3 = eval(re.sub("undefined","None",res[2][1])) - assert res1==[[2,2],[None,None],[3,3]] - assert res2==[[1,1]] - assert res3==[[None,None]] + res3 = eval(res[2][1].replace("null","None")) + assert len(res)==3 + assert res1==[] + assert res2==[[1,1,1]] + assert res3==[[2,2,2],[None,None,None],[3,3,3]] + def test_nested_specials(): session = Session.builder.from_snowsql().getOrCreate() df = session.createDataFrame([([2, None, 3],),([1],),([],)], ['data']) - df2 = df.withColumn("FIELDS", F.arrays_zip("data","data",use_compat=True)) - df = df.withColumn("FIELDS", F.explode_outer(F.arrays_zip("data","data",use_compat=True),use_compat=True)) + #df2 = df.withColumn("FIELDS", F.arrays_zip("data","data")) + df = df.withColumn("FIELDS", F.explode_outer(F.arrays_zip("data","data"))) res = df.collect() # +------------+------------+ # | data| FIELDS| @@ -300,3 +281,59 @@ def test_nested_specials(): assert array3 == [3,3] assert array4 == [1,1] assert array5 == None + + +def test_stack(): +# +-------+---------+-----+---------+----+ +# | Name|Analytics| BI|Ingestion| ML| +# +-------+---------+-----+---------+----+ +# | Mickey| null|12000| null|8000| +# | Martin| null| 5000| null|null| +# | Jerry| null| null| 1000|null| +# | Riley| null| null| null|9000| +# | Donald| 1000| null| null|null| +# | John| null| null| 1000|null| +# |Patrick| null| null| null|1000| +# | Emily| 8000| null| 3000|null| +# | Arya| 10000| null| 2000|null| +# +-------+---------+-----+---------+----+ + session = Session.builder.from_snowsql().getOrCreate() 
+ data0 = [ + ('Mickey' , None,12000,None,8000), + ('Martin' , None, 5000,None,None), + ('Jerry' , None, None,1000,None), + ('Riley' , None, None,None,9000), + ('Donald' , 1000, None,None,None), + ('John' , None, None,1000,None), + ('Patrick', None, None,None,1000), + ('Emily' , 8000, None,3000,None), + ('Arya' ,10000, None,2000,None)] + + schema_df = StructType([ + StructField('Name' , StringType(), True), + StructField('Analytics' , IntegerType(), True), + StructField('BI' , IntegerType(), True), + StructField('Ingestion' , IntegerType(), True), + StructField('ML' , IntegerType(), True) + ]) + + df = session.createDataFrame(data0,schema_df) + df.show() + unstacked = df.select("NAME",df.stack(4,F.lit('Analytics'), "ANALYTICS", F.lit('BI'), "BI", F.lit('Ingestion'), "INGESTION", F.lit('ML'), "ML").alias("Project", "Cost_To_Project")) + res = unstacked.collect() + assert len(res) == 36 + res = unstacked.filter(F.col("Cost_To_Project").is_not_null()).orderBy("NAME","Project").collect() + assert len(res) == 12 + assert list(res[ 0]) == ['Arya', 'Analytics', 10000] + assert list(res[ 1]) == ['Arya', 'Ingestion', 2000] + assert list(res[ 2]) == ['Donald', 'Analytics', 1000] + assert list(res[ 3]) == ['Emily', 'Analytics', 8000] + assert list(res[ 4]) == ['Emily', 'Ingestion', 3000] + assert list(res[ 5]) == ['Jerry', 'Ingestion', 1000] + assert list(res[ 6]) == ['John', 'Ingestion', 1000] + assert list(res[ 7]) == ['Martin', 'BI', 5000] + assert list(res[ 8]) == ['Mickey', 'BI', 12000] + assert list(res[ 9]) == ['Mickey', 'ML', 8000] + assert list(res[10]) == ['Patrick', 'ML', 1000] + assert list(res[11]) == ['Riley', 'ML', 9000] + diff --git a/tests/test_dataframe_reader_extensions.py b/tests/test_dataframe_reader_extensions.py new file mode 100644 index 0000000..1458a49 --- /dev/null +++ b/tests/test_dataframe_reader_extensions.py @@ -0,0 +1,23 @@ +import pytest +from snowflake.snowpark import Session, Row +from snowflake.snowpark.types import * +import snowpark_extensions + +def test_load(): + session = Session.builder.from_snowsql().getOrCreate() + schema = StructType([ \ + StructField("case_id", StringType()), \ + StructField("province", StringType()), \ + StructField("city", StringType()), \ + StructField("group", BooleanType()), \ + StructField("infection_case",StringType()), \ + StructField("confirmed", IntegerType()), \ + StructField("latitude", FloatType()), \ + StructField("cilongitudety", FloatType()) \ + ]) + cases = session.read.load(["./tests/data/test1_0.csv","./tests/data/test1_1.csv"], + schema=schema, + format="csv", + sep=",", + header="true") + assert 10 == len(cases.collect()) \ No newline at end of file diff --git a/tests/test_functions.py b/tests/test_functions.py index f771bb8..094f762 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -44,9 +44,9 @@ def test_array_flatten(): res=res.collect() assert len(res)==2 array1 = eval(res[0]['FLATTEN']) - array2 = eval(res[1]['FLATTEN'] or 'None') + array2 = eval(res[1]['FLATTEN'].replace("null",'None')) assert array1[0]==1 and array1[1]==2 and array1[2]==3 and array1[3]==4 and array1[4]==5 and array1[5]==6 - assert array2 is None + assert array2 == [1, None, 4, 5] def test_create_map(): def do_assert(res): @@ -277,7 +277,6 @@ def test_bround(): assert res0[4].ROUNDING == -2.0 assert res0[5].ROUNDING == -2.0 - res1 = df_1.withColumn("rounding",F.bround(F.col('value'),1) ).collect() assert len(res1) == 10 assert res1[0].ROUNDING == 2.2 @@ -291,7 +290,6 @@ def test_bround(): assert res1[8].ROUNDING == 1.5 assert 
res1[9].ROUNDING == 1.5 - resNull = df_null.withColumn("rounding",F.bround(F.col('value'),None) ).collect() assert len(resNull) == 6 assert resNull[0].ROUNDING == None @@ -304,9 +302,7 @@ def test_bround(): def test_regexp_split(): session = Session.builder.from_snowsql().config("schema","PUBLIC").getOrCreate() from snowflake.snowpark.functions import regexp_split - df = session.createDataFrame([('testAandtestBareTwoBBtests',)], ['s',]) - res = df.select(regexp_split(df.s, "test(A|BB)" , 3).alias('s')).collect() assert res[0].S == '[\n "",\n "andtestBareTwoBBtests"\n]' res = df.select(regexp_split(df.s, "test(A|BB)", 1).alias('s')).collect() @@ -364,4 +360,5 @@ def test_regexp_split(): df = session.createDataFrame([('',)], ['s',]) res = df.select(regexp_split(df.s, '".+?"', 4).alias('s')).collect() - assert res[0].S == '[\n ""\n]' \ No newline at end of file + assert res[0].S == '[\n ""\n]' +
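
> To round out the tests above, here is a small, hedged sketch of how two of the functions reworked in this release (`arrays_zip` and `array_flatten`) are meant to be called. It follows the patterns in `test_dataframe_extensions.py` and `test_functions.py`; the column names and sample data are illustrative only.

```python
from snowflake.snowpark import Session
from snowflake.snowpark import functions as F
import snowpark_extensions  # patches F.arrays_zip, F.array_flatten, F.explode_outer, ...

session = Session.builder.from_snowsql().getOrCreate()

# arrays_zip pairs up the elements of two or more array columns
# (registered behind the scenes as a temporary Python UDF).
df = session.createDataFrame([([2, None, 3],), ([1],), ([],)], ['data'])
df.withColumn("FIELDS", F.arrays_zip("data", "data")).show()

# array_flatten collapses an array of arrays into a single array.
nested = session.createDataFrame([([[1, 2], [3, 4]],)], ['data'])
nested.withColumn("FLAT", F.array_flatten("data")).show()
```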