From 84cb0436ad3ca6e3c28b20380454178c65852adc Mon Sep 17 00:00:00 2001 From: aherreraNet <123779283+aherreraNet@users.noreply.github.com> Date: Mon, 6 Feb 2023 10:39:22 -0600 Subject: [PATCH] Feature/updating regex split (#26) * Updating regexp_split * Minor change * Update setup version --------- Co-authored-by: aherreraNet --- CHANGE_LOG.txt | 6 +- README.md | 2 +- setup.py | 2 +- snowpark_extensions/functions_extensions.py | 49 ++++++++------- tests/test_functions.py | 66 ++++++++++----------- 5 files changed, 66 insertions(+), 59 deletions(-) diff --git a/CHANGE_LOG.txt b/CHANGE_LOG.txt index 3942fa1..0c6d085 100644 --- a/CHANGE_LOG.txt +++ b/CHANGE_LOG.txt @@ -124,4 +124,8 @@ Adding function extension regexp_split Version 0.0.24 -------------- -Fixing an issue with the current implementation of applyInPandas \ No newline at end of file +Fixing an issue with the current implementation of applyInPandas + +Version 0.0.25 +-------------- +Change in implementation of regexp_split to support different regular expression cases diff --git a/README.md b/README.md index 7518aaa..bae3530 100644 --- a/README.md +++ b/README.md @@ -366,7 +366,7 @@ session = Session.builder.from_snowsql().create() df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',]) res = df.select(regexp_split(df.s, '[ABC]', 2).alias('s')).collect() print(str(res)) -# ['one', 'twoBthreeC'] +# [\n "one",\n "twoBthreeC"\n] ``` # utilities diff --git a/setup.py b/setup.py index fb0ba77..4db518d 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() -VERSION = '0.0.24' +VERSION = '0.0.25' setup(name='snowpark_extensions', version=VERSION, diff --git a/snowpark_extensions/functions_extensions.py b/snowpark_extensions/functions_extensions.py index c4b9958..6eb02a3 100644 --- a/snowpark_extensions/functions_extensions.py +++ b/snowpark_extensions/functions_extensions.py @@ -297,33 +297,40 @@ def has_special_char(string): def is_not_a_regex(pattern): return not has_special_char(pattern) - F._split_regex = None + F._split_regex_function = None F.snowflake_split = F.split - def _regexp_split(value:ColumnOrName, pattern:ColumnOrLiteralStr, limit:ColumnOrLiteral = -1): + def _regexp_split(value:ColumnOrName, pattern:ColumnOrLiteralStr, limit:int = -1): + value = _to_col_if_str(value,"split_regex") - if not F._split_regex: - session = context.get_active_session() - current_database = session.get_current_database() - def split_regex_definition(value:str, pattern:str, limit:int)->str: - if limit == 1: - return '[\''+ value +'\']' - else: - limit = limit - 1 - if limit < 0: - limit = 0 - return re.split(pattern,value,limit) - F._split_regex = session.udf.register(split_regex_definition,is_permanent=False,overwrite=True) - # Replace parenthesis because re.split adds what is inside them into the result list - # while Pyspark doesn't take in count - pattern = pattern.replace('(','').replace(')','') + pattern_col = pattern if isinstance(pattern, str): pattern_col = lit(pattern) if limit < 0 and isinstance(pattern, str) and is_not_a_regex(pattern): - F.snowflake_split(value, pattern_col) - if isinstance(limit, int): - limit = lit(limit) - return F._split_regex (value, pattern_col, limit) + return F.snowflake_split(value, pattern_col) + + session = context.get_active_session() + current_database = session.get_current_database() + function_name =_generate_prefix("_regex_split_helper") + F._split_regex_function = f"{current_database}.public.{function_name}" + + session.sql(f"""CREATE OR REPLACE FUNCTION {F._split_regex_function} (input String, regex String, limit INT) +RETURNS ARRAY +LANGUAGE JAVA +RUNTIME_VERSION = '11' +PACKAGES = ('com.snowflake:snowpark:latest') +HANDLER = 'MyJavaClass.regex_split_run' +AS +$$ +import java.util.regex.Pattern; +public class MyJavaClass {{ + public String[] regex_split_run(String input,String regex, int limit) {{ + Pattern pattern = Pattern.compile(regex); + return pattern.split(input, limit); + }}}}$$;""").show() + + return call_builtin(F._split_regex_function, value, pattern_col, limit) + F.array = _array F.array_max = _array_max diff --git a/tests/test_functions.py b/tests/test_functions.py index d1c3de2..f771bb8 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -305,67 +305,63 @@ def test_regexp_split(): session = Session.builder.from_snowsql().config("schema","PUBLIC").getOrCreate() from snowflake.snowpark.functions import regexp_split - df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',]) + df = session.createDataFrame([('testAandtestBareTwoBBtests',)], ['s',]) + + res = df.select(regexp_split(df.s, "test(A|BB)" , 3).alias('s')).collect() + assert res[0].S == '[\n "",\n "andtestBareTwoBBtests"\n]' + res = df.select(regexp_split(df.s, "test(A|BB)", 1).alias('s')).collect() + assert res[0].S == '[\n "testAandtestBareTwoBBtests"\n]' + + df = session.createDataFrame([('From: mauricio@mobilize.net',)], ['s',]) + + res = df.select(regexp_split(df.s, "((From|To)|Subject): (\w+@\w+\.[a-z]+)").alias('s')).collect() + assert res[0].S == '[\n "",\n ""\n]' - res = df.select(regexp_split(df.s, 'Z', -1).alias('s')).collect() - assert res[0].S == "['oneAtwoBthreeC']" - res = df.select(regexp_split(df.s, 'Z', 0).alias('s')).collect() - assert res[0].S == "['oneAtwoBthreeC']" - res = df.select(regexp_split(df.s, 'Z', 1).alias('s')).collect() - assert res[0].S == "['oneAtwoBthreeC']" - res = df.select(regexp_split(df.s, 'Z', 2).alias('s')).collect() - assert res[0].S == "['oneAtwoBthreeC']" - res = df.select(regexp_split(df.s, 't', 0).alias('s')).collect() - assert res[0].S == "['oneA', 'woB', 'hreeC']" + df = session.createDataFrame([('oneAtwoBthreeC',)], ['s',]) + + res = df.select(regexp_split(df.s, 'Z').alias('s')).collect() + assert res[0].S == '[\n "oneAtwoBthreeC"\n]' + res = df.select(regexp_split(df.s, 't').alias('s')).collect() + assert res[0].S == '[\n "oneA",\n "woB",\n "hreeC"\n]' res = df.select(regexp_split(df.s, 't', 1).alias('s')).collect() - assert res[0].S == "['oneAtwoBthreeC']" - res = df.select(regexp_split(df.s, '[ABC]', 0).alias('s')).collect() - assert res[0].S == "['one', 'two', 'three', '']" + assert res[0].S == '[\n "oneAtwoBthreeC"\n]' + res = df.select(regexp_split(df.s, 't', 2).alias('s')).collect() + assert res[0].S == '[\n "oneA",\n "woBthreeC"\n]' + res = df.select(regexp_split(df.s, '[ABC]').alias('s')).collect() + assert res[0].S == '[\n "one",\n "two",\n "three",\n ""\n]' res = df.select(regexp_split(df.s, '[ABC]', 1).alias('s')).collect() - assert res[0].S == "['oneAtwoBthreeC']" + assert res[0].S == '[\n "oneAtwoBthreeC"\n]' res = df.select(regexp_split(df.s, '[ABC]', 2).alias('s')).collect() - assert res[0].S == "['one', 'twoBthreeC']" - res = df.select(regexp_split(df.s, '[ABC]', -1).alias('s')).collect() - assert res[0].S == "['one', 'two', 'three', '']" - res = df.select(regexp_split(df.s, '[ABC]').alias('s')).collect() - assert res[0].S == "['one', 'two', 'three', '']" + assert res[0].S == '[\n "one",\n "twoBthreeC"\n]' df = session.createDataFrame([('HelloabNewacWorld',)], ['s',]) - res = df.select(regexp_split(df.s, 'abNew(a*)c', 2).alias('s')).collect() - assert res[0].S == "['Hello', 'World']" - res = df.select(regexp_split(df.s, 'abNew(ac)', 2).alias('s')).collect() - assert res[0].S == "['Hello', 'World']" - res = df.select(regexp_split(df.s, 'abNew[a]c', 2).alias('s')).collect() - assert res[0].S == "['Hello', 'World']" - res = df.select(regexp_split(df.s, 'a([b, c]).*?', 3).alias('s')).collect() - assert res[0].S == "['Hello', 'New', 'World']" res = df.select(regexp_split(df.s, 'a([b, c]).*?').alias('s')).collect() - assert res[0].S == "['Hello', 'New', 'World']" + assert res[0].S == '[\n "Hello",\n "New",\n "World"\n]' df = session.createDataFrame([(r'aa\nbb\nccc\b',)], ['s',]) res = df.select(regexp_split(df.s, r'\w+.').alias('s')).collect() - assert res[0].S == "['', '', '', 'b']" + assert res[0].S == '[\n "",\n "",\n "",\n "b"\n]' df = session.createDataFrame([(r'\n\n\n',)], ['s',]) res = df.select(regexp_split(df.s, '.*', 3).alias('s')).collect() - assert res[0].S == "['', '', '']" + assert res[0].S == '[\n "",\n "",\n ""\n]' df = session.createDataFrame([("""line 1 line 2 line 3""",)], ['s',]) res = df.select(regexp_split(df.s, r'\n', 3).alias('s')).collect() - assert res[0].S == "['line 1', 'line 2', 'line 3']" + assert res[0].S == '[\n "line 1",\n "line 2",\n "line 3"\n]' res = df.select(regexp_split(df.s, r'line 1(\n)', 3).alias('s')).collect() - assert res[0].S == "['', 'line 2\\nline 3']" + assert res[0].S == '[\n "",\n "line 2\\nline 3"\n]' df = session.createDataFrame([('The price of PINEAPPLE ice cream is 20',)], ['s',]) res = df.select(regexp_split(df.s, r"(\b[A-Z]+\b).+(\b\d+)", 4).alias('s')).collect() - assert res[0].S == "['The price of ', '']" + assert res[0].S == '[\n "The price of ",\n ""\n]' df = session.createDataFrame([('',)], ['s',]) res = df.select(regexp_split(df.s, '".+?"', 4).alias('s')).collect() - assert res[0].S == "['']" \ No newline at end of file + assert res[0].S == '[\n ""\n]' \ No newline at end of file