From 53b0055972d2a30294b11d0b51eab9d90246e0f6 Mon Sep 17 00:00:00 2001 From: EdwardLi-coder <2023edwardll@gmail.com> Date: Sun, 18 Aug 2024 20:49:03 +0800 Subject: [PATCH] add regexp replace --- src/datachain/sql/functions/string.py | 13 +++++++++++++ src/datachain/sql/sqlite/base.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/datachain/sql/functions/string.py b/src/datachain/sql/functions/string.py index 25b7b2c16..ee623d78f 100644 --- a/src/datachain/sql/functions/string.py +++ b/src/datachain/sql/functions/string.py @@ -26,5 +26,18 @@ class split(GenericFunction): # noqa: N801 inherit_cache = True +class regexp_replace(GenericFunction): # noqa: N801 + """ + Replaces substrings that match a regular expression. + """ + + type = String() + package = "string" + name = "regexp_replace" + inherit_cache = True + + +compiler_not_implemented(regexp_replace) + compiler_not_implemented(length) compiler_not_implemented(split) diff --git a/src/datachain/sql/sqlite/base.py b/src/datachain/sql/sqlite/base.py index c677de39d..e8a16b59a 100644 --- a/src/datachain/sql/sqlite/base.py +++ b/src/datachain/sql/sqlite/base.py @@ -1,4 +1,5 @@ import logging +import re import sqlite3 from collections.abc import Iterable from datetime import MAXYEAR, MINYEAR, datetime, timezone @@ -178,9 +179,15 @@ def create_vector_functions(conn): _registered_function_creators["vector_functions"] = create_vector_functions + def sqlite_regexp_replace(string: str, pattern: str, replacement: str) -> str: + return re.sub(pattern, replacement, string) + def create_string_functions(conn): conn.create_function("split", 2, sqlite_string_split, deterministic=True) conn.create_function("split", 3, sqlite_string_split, deterministic=True) + conn.create_function( + "regexp_replace", 3, sqlite_regexp_replace, deterministic=True + ) _registered_function_creators["string_functions"] = create_string_functions @@ -239,6 +246,10 @@ def path_file_ext(path): return func.substr(path, 
func.length(path) - path_file_ext_length(path) + 1) +def compile_regexp_replace(element, compiler, **kwargs): + return f"regexp_replace({compiler.process(element.clauses, **kwargs)})" + + def compile_path_parent(element, compiler, **kwargs): return compiler.process(path_parent(*element.clauses.clauses), **kwargs) @@ -370,3 +381,8 @@ def load_usearch_extension(conn) -> bool: except Exception: # noqa: BLE001 return False + + +@compiles(string.regexp_replace, "sqlite") +def _compile_regexp_replace_sqlite(element, compiler, **kwargs): + return compile_regexp_replace(element, compiler, **kwargs)