diff --git a/CHANGE_LOG.txt b/CHANGE_LOG.txt
index 3942fa1..0c6d085 100644
--- a/CHANGE_LOG.txt
+++ b/CHANGE_LOG.txt
@@ -124,4 +124,8 @@ Adding function extension regexp_split
Version 0.0.24
--------------
-Fixing an issue with the current implementation of applyInPandas
\ No newline at end of file
+Fixing an issue with the current implementation of applyInPandas
+
+Version 0.0.25
+--------------
+Change in implementation of regexp_split to support different regular expression cases
diff --git a/README.md b/README.md
index 7518aaa..bae3530 100644
--- a/README.md
+++ b/README.md
@@ -366,7 +366,7 @@ session = Session.builder.from_snowsql().create()
df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',])
res = df.select(regexp_split(df.s, '[ABC]', 2).alias('s')).collect()
print(str(res))
-# ['one', 'twoBthreeC']
+# [\n "one",\n "twoBthreeC"\n]
```
# utilities
diff --git a/setup.py b/setup.py
index fb0ba77..4db518d 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()
-VERSION = '0.0.24'
+VERSION = '0.0.25'
setup(name='snowpark_extensions',
version=VERSION,
diff --git a/snowpark_extensions/functions_extensions.py b/snowpark_extensions/functions_extensions.py
index c4b9958..6eb02a3 100644
--- a/snowpark_extensions/functions_extensions.py
+++ b/snowpark_extensions/functions_extensions.py
@@ -297,33 +297,40 @@ def has_special_char(string):
def is_not_a_regex(pattern):
return not has_special_char(pattern)
- F._split_regex = None
+ F._split_regex_function = None
F.snowflake_split = F.split
- def _regexp_split(value:ColumnOrName, pattern:ColumnOrLiteralStr, limit:ColumnOrLiteral = -1):
+ def _regexp_split(value:ColumnOrName, pattern:ColumnOrLiteralStr, limit:int = -1):
+
value = _to_col_if_str(value,"split_regex")
- if not F._split_regex:
- session = context.get_active_session()
- current_database = session.get_current_database()
- def split_regex_definition(value:str, pattern:str, limit:int)->str:
- if limit == 1:
- return '[\''+ value +'\']'
- else:
- limit = limit - 1
- if limit < 0:
- limit = 0
- return re.split(pattern,value,limit)
- F._split_regex = session.udf.register(split_regex_definition,is_permanent=False,overwrite=True)
- # Replace parenthesis because re.split adds what is inside them into the result list
- # while Pyspark doesn't take in count
- pattern = pattern.replace('(','').replace(')','')
+
pattern_col = pattern
if isinstance(pattern, str):
pattern_col = lit(pattern)
if limit < 0 and isinstance(pattern, str) and is_not_a_regex(pattern):
- F.snowflake_split(value, pattern_col)
- if isinstance(limit, int):
- limit = lit(limit)
- return F._split_regex (value, pattern_col, limit)
+ return F.snowflake_split(value, pattern_col)
+
+ session = context.get_active_session()
+ current_database = session.get_current_database()
+ function_name =_generate_prefix("_regex_split_helper")
+ F._split_regex_function = f"{current_database}.public.{function_name}"
+
+ session.sql(f"""CREATE OR REPLACE FUNCTION {F._split_regex_function} (input String, regex String, limit INT)
+RETURNS ARRAY
+LANGUAGE JAVA
+RUNTIME_VERSION = '11'
+PACKAGES = ('com.snowflake:snowpark:latest')
+HANDLER = 'MyJavaClass.regex_split_run'
+AS
+$$
+import java.util.regex.Pattern;
+public class MyJavaClass {{
+ public String[] regex_split_run(String input,String regex, int limit) {{
+ Pattern pattern = Pattern.compile(regex);
+ return pattern.split(input, limit);
+ }}}}$$;""").show()
+
+ return call_builtin(F._split_regex_function, value, pattern_col, limit)
+
F.array = _array
F.array_max = _array_max
diff --git a/tests/test_functions.py b/tests/test_functions.py
index d1c3de2..f771bb8 100644
--- a/tests/test_functions.py
+++ b/tests/test_functions.py
@@ -305,67 +305,63 @@ def test_regexp_split():
session = Session.builder.from_snowsql().config("schema","PUBLIC").getOrCreate()
from snowflake.snowpark.functions import regexp_split
- df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',])
+ df = session.createDataFrame([('testAandtestBareTwoBBtests',)], ['s',])
+
+ res = df.select(regexp_split(df.s, "test(A|BB)" , 3).alias('s')).collect()
+ assert res[0].S == '[\n "",\n "andtestBareTwoBBtests"\n]'
+ res = df.select(regexp_split(df.s, "test(A|BB)", 1).alias('s')).collect()
+ assert res[0].S == '[\n "testAandtestBareTwoBBtests"\n]'
+
+ df = session.createDataFrame([('From: mauricio@mobilize.net',)], ['s',])
+
+ res = df.select(regexp_split(df.s, "((From|To)|Subject): (\w+@\w+\.[a-z]+)").alias('s')).collect()
+ assert res[0].S == '[\n "",\n ""\n]'
- res = df.select(regexp_split(df.s, 'Z', -1).alias('s')).collect()
- assert res[0].S == "['oneAtwoBthreeC']"
- res = df.select(regexp_split(df.s, 'Z', 0).alias('s')).collect()
- assert res[0].S == "['oneAtwoBthreeC']"
- res = df.select(regexp_split(df.s, 'Z', 1).alias('s')).collect()
- assert res[0].S == "['oneAtwoBthreeC']"
- res = df.select(regexp_split(df.s, 'Z', 2).alias('s')).collect()
- assert res[0].S == "['oneAtwoBthreeC']"
- res = df.select(regexp_split(df.s, 't', 0).alias('s')).collect()
- assert res[0].S == "['oneA', 'woB', 'hreeC']"
+ df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',])
+
+ res = df.select(regexp_split(df.s, 'Z').alias('s')).collect()
+ assert res[0].S == '[\n "oneAtwoBthreeC"\n]'
+ res = df.select(regexp_split(df.s, 't').alias('s')).collect()
+ assert res[0].S == '[\n "oneA",\n "woB",\n "hreeC"\n]'
res = df.select(regexp_split(df.s, 't', 1).alias('s')).collect()
- assert res[0].S == "['oneAtwoBthreeC']"
- res = df.select(regexp_split(df.s, '[ABC]', 0).alias('s')).collect()
- assert res[0].S == "['one', 'two', 'three', '']"
+ assert res[0].S == '[\n "oneAtwoBthreeC"\n]'
+ res = df.select(regexp_split(df.s, 't', 2).alias('s')).collect()
+ assert res[0].S == '[\n "oneA",\n "woBthreeC"\n]'
+ res = df.select(regexp_split(df.s, '[ABC]').alias('s')).collect()
+ assert res[0].S == '[\n "one",\n "two",\n "three",\n ""\n]'
res = df.select(regexp_split(df.s, '[ABC]', 1).alias('s')).collect()
- assert res[0].S == "['oneAtwoBthreeC']"
+ assert res[0].S == '[\n "oneAtwoBthreeC"\n]'
res = df.select(regexp_split(df.s, '[ABC]', 2).alias('s')).collect()
- assert res[0].S == "['one', 'twoBthreeC']"
- res = df.select(regexp_split(df.s, '[ABC]', -1).alias('s')).collect()
- assert res[0].S == "['one', 'two', 'three', '']"
- res = df.select(regexp_split(df.s, '[ABC]').alias('s')).collect()
- assert res[0].S == "['one', 'two', 'three', '']"
+ assert res[0].S == '[\n "one",\n "twoBthreeC"\n]'
df = session.createDataFrame([('HelloabNewacWorld',)], ['s',])
- res = df.select(regexp_split(df.s, 'abNew(a*)c', 2).alias('s')).collect()
- assert res[0].S == "['Hello', 'World']"
- res = df.select(regexp_split(df.s, 'abNew(ac)', 2).alias('s')).collect()
- assert res[0].S == "['Hello', 'World']"
- res = df.select(regexp_split(df.s, 'abNew[a]c', 2).alias('s')).collect()
- assert res[0].S == "['Hello', 'World']"
- res = df.select(regexp_split(df.s, 'a([b, c]).*?', 3).alias('s')).collect()
- assert res[0].S == "['Hello', 'New', 'World']"
res = df.select(regexp_split(df.s, 'a([b, c]).*?').alias('s')).collect()
- assert res[0].S == "['Hello', 'New', 'World']"
+ assert res[0].S == '[\n "Hello",\n "New",\n "World"\n]'
df = session.createDataFrame([(r'aa\nbb\nccc\b',)], ['s',])
res = df.select(regexp_split(df.s, r'\w+.').alias('s')).collect()
- assert res[0].S == "['', '', '', 'b']"
+ assert res[0].S == '[\n "",\n "",\n "",\n "b"\n]'
df = session.createDataFrame([(r'\n\n\n',)], ['s',])
res = df.select(regexp_split(df.s, '.*', 3).alias('s')).collect()
- assert res[0].S == "['', '', '']"
+ assert res[0].S == '[\n "",\n "",\n ""\n]'
df = session.createDataFrame([("""line 1
line 2
line 3""",)], ['s',])
res = df.select(regexp_split(df.s, r'\n', 3).alias('s')).collect()
- assert res[0].S == "['line 1', 'line 2', 'line 3']"
+ assert res[0].S == '[\n "line 1",\n "line 2",\n "line 3"\n]'
res = df.select(regexp_split(df.s, r'line 1(\n)', 3).alias('s')).collect()
- assert res[0].S == "['', 'line 2\\nline 3']"
+ assert res[0].S == '[\n "",\n "line 2\\nline 3"\n]'
df = session.createDataFrame([('The price of PINEAPPLE ice cream is 20',)], ['s',])
res = df.select(regexp_split(df.s, r"(\b[A-Z]+\b).+(\b\d+)", 4).alias('s')).collect()
- assert res[0].S == "['The price of ', '']"
+ assert res[0].S == '[\n "The price of ",\n ""\n]'
df = session.createDataFrame([('',)], ['s',])
res = df.select(regexp_split(df.s, '".+?"', 4).alias('s')).collect()
- assert res[0].S == "['']"
\ No newline at end of file
+ assert res[0].S == '[\n ""\n]'
\ No newline at end of file