diff --git a/CHANGE_LOG.txt b/CHANGE_LOG.txt
index 0ad027b..ac9b8f2 100644
--- a/CHANGE_LOG.txt
+++ b/CHANGE_LOG.txt
@@ -117,3 +117,7 @@ Adjust to use snowpark-python>=1.1.0
Version 0.0.22
--------------
Adjustment for notebook integration in the case that only one row is returned. Thanks to @naga
+
+Version 0.0.23
+--------------
+Adding function extension regexp_split
\ No newline at end of file
diff --git a/README.md b/README.md
index 4963744..7518aaa 100644
--- a/README.md
+++ b/README.md
@@ -213,6 +213,7 @@ df.group_by("ID").applyInPandas(
| functions.date_add | returns the date that is n days days after |
| functions.date_sub | returns the date that is n days before |
| functions.regexp_extract | extract a specific group matched by a regex, from the specified string column. |
+| functions.regexp_split | splits a string around matches of a regex; it is an extension of split which supports a limit parameter. |
| ~~functions.asc~~ | ~~returns a sort expression based on the ascending order of the given column name.~~ **Available in snowpark-python >=1.1.0** |
| ~~functions.desc~~ | ~~returns a sort expression based on the descending order of the given column name.~~ **Available in snowpark-python >=1.1.0** |
| functions.flatten | creates a single array from an array of arrays
@@ -357,6 +358,17 @@ df.select(F.regexp_extract('id', r'(\d+)_(\d+)', 2)).show()
# ------------------------------------------------------
```
+### regexp_split
+
+```python
+session = Session.builder.from_snowsql().create()
+
+df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',])
+res = df.select(regexp_split(df.s, '[ABC]', 2).alias('s')).collect()
+print(str(res))
+# [Row(S="['one', 'twoBthreeC']")]
+```
+
# utilities
| Name | Description |
diff --git a/setup.py b/setup.py
index 9f4dc22..16408c5 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()
-VERSION = '0.0.22'
+VERSION = '0.0.23'
setup(name='snowpark_extensions',
version=VERSION,
diff --git a/snowpark_extensions/functions_extensions.py b/snowpark_extensions/functions_extensions.py
index 3064a55..c4b9958 100644
--- a/snowpark_extensions/functions_extensions.py
+++ b/snowpark_extensions/functions_extensions.py
@@ -16,6 +16,7 @@
from snowflake.snowpark.column import _to_col_if_str, _to_col_if_lit
from snowflake.snowpark.dataframe import _generate_prefix
from snowflake.snowpark._internal.analyzer.unary_expression import Alias
+import re
if not hasattr(F,"___extended"):
F.___extended = True
@@ -51,6 +52,13 @@ def regexp_extract(value:ColumnOrLiteralStr,regexp:ColumnOrLiteralStr,idx:int) -
# we add .* to the expression if needed
return coalesce(call_builtin('regexp_substr',value,regexp,lit(1),lit(1),lit('e'),idx),lit(''))
+ def unix_timestamp(col):
+ return call_builtin("DATE_PART","epoch_second",col)
+
+ def from_unixtime(col):
+ col = _to_col_if_str(col,"from_unixtime")
+ return F.to_timestamp(col).alias('ts')
+
def format_number(col,d):
col = _to_col_if_str(col,"format_number")
return F.to_varchar(col,'999,999,999,999,999.' + '0'*d)
@@ -281,6 +289,41 @@ def _bround(col: Column, scale: int = 0):
, F.when(columnFloor % F.lit(2) == F.lit(0), columnFloor).otherwise(columnFloor + F.lit(1))
).otherwise(F.round(elevatedColumn)) / F.when(F.lit(0) == F.lit(scale), F.lit(1)).otherwise(power)
+ def has_special_char(string):
+ pattern = '[^A-Za-z0-9]+'
+ result = re.search(pattern, string)
+ return bool(result)
+
+ def is_not_a_regex(pattern):
+ return not has_special_char(pattern)
+
+ F._split_regex = None
+ F.snowflake_split = F.split
+ def _regexp_split(value:ColumnOrName, pattern:ColumnOrLiteralStr, limit:ColumnOrLiteral = -1):
+ value = _to_col_if_str(value,"split_regex")
+ if not F._split_regex:
+ session = context.get_active_session()
+ current_database = session.get_current_database()
+ def split_regex_definition(value:str, pattern:str, limit:int)->str:
+ if limit == 1:
+ return '[\''+ value +'\']'
+ else:
+ limit = limit - 1
+ if limit < 0:
+ limit = 0
+ return re.split(pattern,value,limit)
+ F._split_regex = session.udf.register(split_regex_definition,is_permanent=False,overwrite=True)
+ # Replace parenthesis because re.split adds what is inside them into the result list
+ # while Pyspark doesn't take in count
+ pattern = pattern.replace('(','').replace(')','')
+ pattern_col = pattern
+ if isinstance(pattern, str):
+ pattern_col = lit(pattern)
+    # NOTE(review): a fast-path via F.snowflake_split for plain (non-regex) patterns
+    # was dropped: its result was never returned (dead code), and its ARRAY output
+    # format differs from the Python-list string this UDF produces (see tests).
+    if isinstance(limit, int):
+        limit = lit(limit)
+    return F._split_regex(value, pattern_col, limit)
F.array = _array
F.array_max = _array_max
@@ -288,12 +331,19 @@ def _bround(col: Column, scale: int = 0):
F.array_distinct = array_distinct
F.regexp_extract = regexp_extract
F.create_map = create_map
+ F.unix_timestamp = unix_timestamp
+ F.from_unixtime = from_unixtime
F.format_number = format_number
F.reverse = reverse
F.daydiff = daydiff
F.date_add = date_add
F.date_sub = date_sub
+ F.asc = lambda col: _to_col_if_str(col, "asc").asc()
+ F.desc = lambda col: _to_col_if_str(col, "desc").desc()
+    F.asc_nulls_first = lambda col: _to_col_if_str(col, "asc_nulls_first").asc_nulls_first()
+    F.desc_nulls_first = lambda col: _to_col_if_str(col, "desc_nulls_first").desc_nulls_first()
F.sort_array = _sort_array
F.array_sort = _array_sort
F.struct = _struct
- F.bround = _bround
\ No newline at end of file
+ F.bround = _bround
+ F.regexp_split = _regexp_split
\ No newline at end of file
diff --git a/tests/test_functions.py b/tests/test_functions.py
index eb5b0fb..d1c3de2 100644
--- a/tests/test_functions.py
+++ b/tests/test_functions.py
@@ -230,6 +230,7 @@ def test_daydiff():
res = df.select(F.daydiff(F.to_date(df.d2), F.to_date(df.d1)).alias('diff')).collect()
assert res[0].DIFF == 32
+
def test_bround():
session = Session.builder.from_snowsql().getOrCreate()
data0 = [(1.5,0),
@@ -300,3 +301,71 @@ def test_bround():
assert resNull[4].ROUNDING == None
assert resNull[5].ROUNDING == None
+def test_regexp_split():
+ session = Session.builder.from_snowsql().config("schema","PUBLIC").getOrCreate()
+ from snowflake.snowpark.functions import regexp_split
+
+ df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',])
+
+ res = df.select(regexp_split(df.s, 'Z', -1).alias('s')).collect()
+ assert res[0].S == "['oneAtwoBthreeC']"
+ res = df.select(regexp_split(df.s, 'Z', 0).alias('s')).collect()
+ assert res[0].S == "['oneAtwoBthreeC']"
+ res = df.select(regexp_split(df.s, 'Z', 1).alias('s')).collect()
+ assert res[0].S == "['oneAtwoBthreeC']"
+ res = df.select(regexp_split(df.s, 'Z', 2).alias('s')).collect()
+ assert res[0].S == "['oneAtwoBthreeC']"
+ res = df.select(regexp_split(df.s, 't', 0).alias('s')).collect()
+ assert res[0].S == "['oneA', 'woB', 'hreeC']"
+ res = df.select(regexp_split(df.s, 't', 1).alias('s')).collect()
+ assert res[0].S == "['oneAtwoBthreeC']"
+ res = df.select(regexp_split(df.s, '[ABC]', 0).alias('s')).collect()
+ assert res[0].S == "['one', 'two', 'three', '']"
+ res = df.select(regexp_split(df.s, '[ABC]', 1).alias('s')).collect()
+ assert res[0].S == "['oneAtwoBthreeC']"
+ res = df.select(regexp_split(df.s, '[ABC]', 2).alias('s')).collect()
+ assert res[0].S == "['one', 'twoBthreeC']"
+ res = df.select(regexp_split(df.s, '[ABC]', -1).alias('s')).collect()
+ assert res[0].S == "['one', 'two', 'three', '']"
+ res = df.select(regexp_split(df.s, '[ABC]').alias('s')).collect()
+ assert res[0].S == "['one', 'two', 'three', '']"
+
+ df = session.createDataFrame([('HelloabNewacWorld',)], ['s',])
+
+ res = df.select(regexp_split(df.s, 'abNew(a*)c', 2).alias('s')).collect()
+ assert res[0].S == "['Hello', 'World']"
+ res = df.select(regexp_split(df.s, 'abNew(ac)', 2).alias('s')).collect()
+ assert res[0].S == "['Hello', 'World']"
+ res = df.select(regexp_split(df.s, 'abNew[a]c', 2).alias('s')).collect()
+ assert res[0].S == "['Hello', 'World']"
+ res = df.select(regexp_split(df.s, 'a([b, c]).*?', 3).alias('s')).collect()
+ assert res[0].S == "['Hello', 'New', 'World']"
+ res = df.select(regexp_split(df.s, 'a([b, c]).*?').alias('s')).collect()
+ assert res[0].S == "['Hello', 'New', 'World']"
+
+ df = session.createDataFrame([(r'aa\nbb\nccc\b',)], ['s',])
+
+ res = df.select(regexp_split(df.s, r'\w+.').alias('s')).collect()
+ assert res[0].S == "['', '', '', 'b']"
+
+ df = session.createDataFrame([(r'\n\n\n',)], ['s',])
+
+ res = df.select(regexp_split(df.s, '.*', 3).alias('s')).collect()
+ assert res[0].S == "['', '', '']"
+
+ df = session.createDataFrame([("""line 1
+line 2
+line 3""",)], ['s',])
+
+ res = df.select(regexp_split(df.s, r'\n', 3).alias('s')).collect()
+ assert res[0].S == "['line 1', 'line 2', 'line 3']"
+ res = df.select(regexp_split(df.s, r'line 1(\n)', 3).alias('s')).collect()
+ assert res[0].S == "['', 'line 2\\nline 3']"
+
+ df = session.createDataFrame([('The price of PINEAPPLE ice cream is 20',)], ['s',])
+ res = df.select(regexp_split(df.s, r"(\b[A-Z]+\b).+(\b\d+)", 4).alias('s')).collect()
+ assert res[0].S == "['The price of ', '']"
+
+ df = session.createDataFrame([('',)], ['s',])
+ res = df.select(regexp_split(df.s, '".+?"', 4).alias('s')).collect()
+ assert res[0].S == "['']"
\ No newline at end of file