diff --git a/CHANGE_LOG.txt b/CHANGE_LOG.txt index 0ad027b..ac9b8f2 100644 --- a/CHANGE_LOG.txt +++ b/CHANGE_LOG.txt @@ -117,3 +117,7 @@ Adjust to use snowpark-python>=1.1.0 Version 0.0.22 -------------- Adjustment for notebook integration in the case that only one row is returned. Thanks to @naga + + Version 0.0.23 +-------------- +Adding function extension regexp_split \ No newline at end of file diff --git a/README.md b/README.md index 4963744..7518aaa 100644 --- a/README.md +++ b/README.md @@ -213,6 +213,7 @@ df.group_by("ID").applyInPandas( | functions.date_add | returns the date that is n days days after | | functions.date_sub | returns the date that is n days before | | functions.regexp_extract | extract a specific group matched by a regex, from the specified string column. | +| functions.regexp_split | splits a specific group matched by a regex, it is an extension of split wich supports a limit parameter. | | ~~functions.asc~~ | ~~returns a sort expression based on the ascending order of the given column name.~~ **Available in snowpark-python >=1.1.0** | | ~~functions.desc~~ | ~~returns a sort expression based on the descending order of the given column name.~~ **Available in snowpark-python >=1.1.0** | | functions.flatten | creates a single array from an array of arrays @@ -357,6 +358,17 @@ df.select(F.regexp_extract('id', r'(\d+)_(\d+)', 2)).show() # ------------------------------------------------------ ``` +### regexp_split + +```python +session = Session.builder.from_snowsql().create() + +df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',]) +res = df.select(regexp_split(df.s, '[ABC]', 2).alias('s')).collect() +print(str(res)) +# ['one', 'twoBthreeC'] +``` + # utilities | Name | Description | diff --git a/setup.py b/setup.py index 9f4dc22..16408c5 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() -VERSION = '0.0.22' +VERSION = '0.0.23' 
setup(name='snowpark_extensions', version=VERSION, diff --git a/snowpark_extensions/functions_extensions.py b/snowpark_extensions/functions_extensions.py index 3064a55..c4b9958 100644 --- a/snowpark_extensions/functions_extensions.py +++ b/snowpark_extensions/functions_extensions.py @@ -16,6 +16,7 @@ from snowflake.snowpark.column import _to_col_if_str, _to_col_if_lit from snowflake.snowpark.dataframe import _generate_prefix from snowflake.snowpark._internal.analyzer.unary_expression import Alias +import re if not hasattr(F,"___extended"): F.___extended = True @@ -51,6 +52,13 @@ def regexp_extract(value:ColumnOrLiteralStr,regexp:ColumnOrLiteralStr,idx:int) - # we add .* to the expression if needed return coalesce(call_builtin('regexp_substr',value,regexp,lit(1),lit(1),lit('e'),idx),lit('')) + def unix_timestamp(col): + return call_builtin("DATE_PART","epoch_second",col) + + def from_unixtime(col): + col = _to_col_if_str(col,"from_unixtime") + return F.to_timestamp(col).alias('ts') + def format_number(col,d): col = _to_col_if_str(col,"format_number") return F.to_varchar(col,'999,999,999,999,999.' 
def has_special_char(string):
    """Return True if *string* contains any character outside [A-Za-z0-9]."""
    return re.search('[^A-Za-z0-9]+', string) is not None

def is_not_a_regex(pattern):
    """Heuristic: a purely alphanumeric pattern is a literal separator, not a regex."""
    return not has_special_char(pattern)

# Lazily-registered UDF that performs the regex split server-side.
F._split_regex = None
# Keep a handle to the native split before F.regexp_split shadows related helpers.
F.snowflake_split = F.split

def _regexp_split(value: ColumnOrName, pattern: ColumnOrLiteralStr, limit: ColumnOrLiteral = -1):
    """Split ``value`` around matches of ``pattern``, like PySpark's split with a limit.

    Args:
        value: column (or column name) holding the string to split.
        pattern: regex (or Column) used as the separator.
        limit: maximum number of result elements, PySpark semantics:
            limit > 0  -> at most ``limit`` elements, last element keeps the tail;
            limit <= 0 -> no limit.
    Returns:
        Column with the string representation of the resulting list.
    """
    value = _to_col_if_str(value, "split_regex")
    if not F._split_regex:
        session = context.get_active_session()

        def split_regex_definition(value: str, pattern: str, limit: int) -> str:
            # limit == 1 means "one element": the whole input, unsplit.
            if limit == 1:
                return '[\'' + value + '\']'
            # Spark's limit counts result elements; re.split's maxsplit counts
            # splits, hence the -1. Negative/zero limits clamp to 0 = unlimited.
            limit = limit - 1
            if limit < 0:
                limit = 0
            return re.split(pattern, value, limit)

        F._split_regex = session.udf.register(split_regex_definition, is_permanent=False, overwrite=True)
    if isinstance(pattern, str):
        # Strip grouping parentheses: re.split returns what capture groups match
        # as extra list elements, while PySpark's split does not.
        pattern = pattern.replace('(', '').replace(')', '')
        pattern_col = lit(pattern)
    else:
        # BUG FIX: .replace() was previously applied to Column patterns too,
        # which raised AttributeError before the isinstance check ran.
        pattern_col = pattern
    if limit < 0 and isinstance(pattern, str) and is_not_a_regex(pattern):
        # Fast path: a literal separator with no limit can use the native split.
        # BUG FIX: the 'return' was missing, so this branch fell through and
        # the function returned the UDF call on the unstripped pattern instead.
        return F.snowflake_split(value, pattern_col)
    if isinstance(limit, int):
        limit = lit(limit)
    return F._split_regex(value, pattern_col, limit)

F.array = _array
F.array_max = _array_max
F.array_distinct = array_distinct
F.regexp_extract = regexp_extract
F.create_map = create_map
F.unix_timestamp = unix_timestamp
F.from_unixtime = from_unixtime
F.format_number = format_number
F.reverse = reverse
F.daydiff = daydiff
F.date_add = date_add
F.date_sub = date_sub
F.asc = lambda col: _to_col_if_str(col, "asc").asc()
F.desc = lambda col: _to_col_if_str(col, "desc").desc()
# BUG FIX: both *_nulls_first shims previously called .asc() — desc_nulls_first
# therefore sorted ASCENDING, and neither applied nulls-first placement.
# Snowpark Column exposes asc_nulls_first()/desc_nulls_first() directly.
F.asc_nulls_first = lambda col: _to_col_if_str(col, "asc_nulls_first").asc_nulls_first()
F.desc_nulls_first = lambda col: _to_col_if_str(col, "desc_nulls_first").desc_nulls_first()
F.sort_array = _sort_array
F.array_sort = _array_sort
F.struct = _struct
F.bround = _bround
F.regexp_split = _regexp_split


def test_regexp_split():
    """Exercise F.regexp_split against PySpark's documented split/limit semantics."""
    session = Session.builder.from_snowsql().config("schema", "PUBLIC").getOrCreate()
    from snowflake.snowpark.functions import regexp_split

    df = session.createDataFrame([('oneAtwoBthreeC',)], ['s'])

    # Separator not present: always one element, regardless of limit.
    res = df.select(regexp_split(df.s, 'Z', -1).alias('s')).collect()
    assert res[0].S == "['oneAtwoBthreeC']"
    res = df.select(regexp_split(df.s, 'Z', 0).alias('s')).collect()
    assert res[0].S == "['oneAtwoBthreeC']"
    res = df.select(regexp_split(df.s, 'Z', 1).alias('s')).collect()
    assert res[0].S == "['oneAtwoBthreeC']"
    res = df.select(regexp_split(df.s, 'Z', 2).alias('s')).collect()
    assert res[0].S == "['oneAtwoBthreeC']"
    # limit <= 0 -> unlimited; limit == 1 -> whole string.
    res = df.select(regexp_split(df.s, 't', 0).alias('s')).collect()
    assert res[0].S == "['oneA', 'woB', 'hreeC']"
    res = df.select(regexp_split(df.s, 't', 1).alias('s')).collect()
    assert res[0].S == "['oneAtwoBthreeC']"
    res = df.select(regexp_split(df.s, '[ABC]', 0).alias('s')).collect()
    assert res[0].S == "['one', 'two', 'three', '']"
    res = df.select(regexp_split(df.s, '[ABC]', 1).alias('s')).collect()
    assert res[0].S == "['oneAtwoBthreeC']"
    res = df.select(regexp_split(df.s, '[ABC]', 2).alias('s')).collect()
    assert res[0].S == "['one', 'twoBthreeC']"
    res = df.select(regexp_split(df.s, '[ABC]', -1).alias('s')).collect()
    assert res[0].S == "['one', 'two', 'three', '']"
    res = df.select(regexp_split(df.s, '[ABC]').alias('s')).collect()
    assert res[0].S == "['one', 'two', 'three', '']"

    df = session.createDataFrame([('HelloabNewacWorld',)], ['s'])

    # Capture groups in the pattern must NOT leak into the result list.
    res = df.select(regexp_split(df.s, 'abNew(a*)c', 2).alias('s')).collect()
    assert res[0].S == "['Hello', 'World']"
    res = df.select(regexp_split(df.s, 'abNew(ac)', 2).alias('s')).collect()
    assert res[0].S == "['Hello', 'World']"
    res = df.select(regexp_split(df.s, 'abNew[a]c', 2).alias('s')).collect()
    assert res[0].S == "['Hello', 'World']"
    res = df.select(regexp_split(df.s, 'a([b, c]).*?', 3).alias('s')).collect()
    assert res[0].S == "['Hello', 'New', 'World']"
    res = df.select(regexp_split(df.s, 'a([b, c]).*?').alias('s')).collect()
    assert res[0].S == "['Hello', 'New', 'World']"

    df = session.createDataFrame([(r'aa\nbb\nccc\b',)], ['s'])
    res = df.select(regexp_split(df.s, r'\w+.').alias('s')).collect()
    assert res[0].S == "['', '', '', 'b']"

    df = session.createDataFrame([(r'\n\n\n',)], ['s'])
    res = df.select(regexp_split(df.s, '.*', 3).alias('s')).collect()
    assert res[0].S == "['', '', '']"

    df = session.createDataFrame([("""line 1
line 2
line 3""",)], ['s'])
    res = df.select(regexp_split(df.s, r'\n', 3).alias('s')).collect()
    assert res[0].S == "['line 1', 'line 2', 'line 3']"
    res = df.select(regexp_split(df.s, r'line 1(\n)', 3).alias('s')).collect()
    assert res[0].S == "['', 'line 2\\nline 3']"

    df = session.createDataFrame([('The price of PINEAPPLE ice cream is 20',)], ['s'])
    res = df.select(regexp_split(df.s, r"(\b[A-Z]+\b).+(\b\d+)", 4).alias('s')).collect()
    assert res[0].S == "['The price of ', '']"

    # Empty input string yields a single empty element.
    df = session.createDataFrame([('',)], ['s'])
    res = df.select(regexp_split(df.s, '".+?"', 4).alias('s')).collect()
    assert res[0].S == "['']"