Skip to content

Commit

Permalink
adding support for map_values
Browse files Browse the repository at this point in the history
  • Loading branch information
orellabac committed Dec 27, 2022
1 parent 420765e commit ca8bb94
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 5 deletions.
6 changes: 5 additions & 1 deletion CHANGE_LOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,8 @@ Adding support for dataframe.groupBy.pivot

Version 0.0.13
--------------
Added support for sort_array, array_max, array_min
Added support for sort_array, array_max, array_min

Version 0.0.14
--------------
Added support for map_values
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()

VERSION = '0.0.13'
VERSION = '0.0.14'

setup(name='snowpark_extensions',
version=VERSION,
Expand Down
27 changes: 26 additions & 1 deletion snowpark_extensions/dataframe_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,28 @@ def specials(cls,c):
if hasattr(c,"_is_special_column"):
yield c


class MapValues(SpecialColumn):
def __init__(self,array_col):
super().__init__("values")
self.array_col = array_col
self._special_column_dependencies = [array_col]
def add_columns(self, new_cols, alias:str = None):
# add itself as column
new_cols.append(self.alias(alias) if alias else self)
def expand(self,df):
array_col = _to_col_if_str(self.array_col, "values")
df = df.with_column("__IDX",F.seq8())
flatten = table_function("flatten")
seq=_generate_prefix("SEQ")
key=_generate_prefix("KEY")
path=_generate_prefix("PATH")
index=_generate_prefix("INDEX")
value=_generate_prefix("VALUE")
this=_generate_prefix("THIS")
df_values=df.join_table_function(flatten(input=array_col,outer=lit(True)).alias(seq,key,path,index,value,this)).group_by("__IDX").agg(F.array_agg(value).alias(self.special_col_name)).distinct()
df = df.join(df_values,on="__IDX").drop("__IDX")
return df

class ArraySort(SpecialColumn):
def __init__(self,array_col):
super().__init__("sorted")
Expand Down Expand Up @@ -300,7 +321,11 @@ def _arrays_flatten(array_col,remove_arrays_when_there_is_a_null=True):
return ArrayFlatten(array_col,remove_arrays_when_there_is_a_null)
def _array_sort(array_col):
return ArraySort(array_col)
def _map_values(col:ColumnOrName):
col = _to_col_if_str(col,"map_values")
return MapValues(col)

F.map_values = _map_values
F.arrays_zip = _arrays_zip
F.flatten = _arrays_flatten
F.array_sort = _array_sort
Expand Down
22 changes: 20 additions & 2 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import snowpark_extensions
from snowflake.snowpark import Session
from snowflake.snowpark.types import *
from snowflake.snowpark.functions import col,lit, sort_array, array_max, array_min
from snowflake.snowpark.functions import col,lit, sort_array, array_max, array_min, map_values
from snowflake.snowpark import functions as F

import re

def test_asc():
session = Session.builder.from_snowsql().getOrCreate()
Expand Down Expand Up @@ -130,3 +130,21 @@ def test_array_min():
res=df.select(array_min(df.data).alias('min')).collect()
assert res[0].MIN == '1' and res[1].MIN == '-1'
#[Row(min=1), Row(min=-1)]

def test_map_values():
session = Session.builder.from_snowsql().getOrCreate()
df = session.sql("SELECT object_construct('1', 'a', '2', 'b') as data")
res = df.select(map_values("data").alias("values")).collect()
# +------+
# |values|
# +------+
# |[a, b]|
# +------+
assert len(res)==1
array=re.sub(r"\s","",res[0].VALUES)
assert array == '["a","b"]'
df = session.sql("SELECT object_construct('1', 'value1', '2', parse_json('null')) as data")
res = df.select(map_values("data").alias("values")).collect()
assert len(res)==1
array=re.sub(r"\s","",res[0].VALUES)
assert array == '["value1",null]'

0 comments on commit ca8bb94

Please sign in to comment.