class MapValues(SpecialColumn):
    """Special column that gathers the flattened values of a column into one array per row.

    The wrapped column is passed to the ``flatten`` table function and the
    resulting values are re-aggregated with ``array_agg``, grouped by a
    per-row id, so every input row ends up with a single array.
    """

    def __init__(self, array_col):
        """Remember the source column and register it as a dependency.

        :param array_col: column (or column name) whose values are collected.
        """
        super().__init__("values")
        self.array_col = array_col
        self._special_column_dependencies = [array_col]

    def add_columns(self, new_cols, alias: str = None):
        """Append this column to ``new_cols``, aliased when ``alias`` is given."""
        # add itself as column
        new_cols.append(self.alias(alias) if alias else self)

    def expand(self, df):
        """Return ``df`` with an extra column holding the aggregated values.

        :param df: dataframe containing the source column.
        :return: ``df`` joined with the per-row arrays; the temporary row-id
            column is dropped again before returning.
        """
        array_col = _to_col_if_str(self.array_col, "values")

        # Use a generated name for the scratch row-id column (like the
        # flatten output aliases below) so that an existing user column
        # is never overwritten by with_column and then dropped.
        # NOTE(review): this relies on F.seq8() producing the same id for a
        # given row on both sides of the join below -- confirm.
        row_id = _generate_prefix("IDX")
        df = df.with_column(row_id, F.seq8())

        flatten = table_function("flatten")
        seq = _generate_prefix("SEQ")
        key = _generate_prefix("KEY")
        path = _generate_prefix("PATH")
        index = _generate_prefix("INDEX")
        value = _generate_prefix("VALUE")
        this = _generate_prefix("THIS")

        flattened = df.join_table_function(
            flatten(input=array_col, outer=lit(True)).alias(seq, key, path, index, value, this)
        )
        # Grouping by the row id already yields exactly one row per id,
        # so no additional distinct() is required on the aggregate.
        df_values = flattened.group_by(row_id).agg(
            F.array_agg(value).alias(self.special_col_name)
        )
        return df.join(df_values, on=row_id).drop(row_id)
def test_map_values():
    """map_values on an OBJECT column returns its values as an array."""
    # Local import so the test does not depend on the module's import block.
    import json

    session = Session.builder.from_snowsql().getOrCreate()

    df = session.sql("SELECT object_construct('1', 'a', '2', 'b') as data")
    res = df.select(map_values("data").alias("values")).collect()
    # +------+
    # |values|
    # +------+
    # |[a, b]|
    # +------+
    assert len(res) == 1
    # The array comes back as JSON text; parse it instead of stripping
    # whitespace with a regex, which would also alter whitespace inside
    # the values and ties the assertion to the exact serialisation format.
    assert json.loads(res[0].VALUES) == ["a", "b"]

    # A JSON null stored in the object is expected to be kept in the result.
    df = session.sql("SELECT object_construct('1', 'value1', '2', parse_json('null')) as data")
    res = df.select(map_values("data").alias("values")).collect()
    assert len(res) == 1
    assert json.loads(res[0].VALUES) == ["value1", None]