diff --git a/neetbox/utils/mvc.py b/neetbox/utils/mvc.py index 6ba81564..ecb7acf8 100644 --- a/neetbox/utils/mvc.py +++ b/neetbox/utils/mvc.py @@ -23,23 +23,33 @@ def __call__(cls, *args, **kwargs): return cls._instances[cls] -def patch(func): +def patch(func=None, *, name=None, overwrite=True): """Patch a function into a class type Args: func (Function): A function that takes at least one argument with a specific class type 'self:YourClass' + name (str, optional): The name to assign to the method in the class. Defaults to the function's name. + overwrite (bool, optional): Whether to overwrite an existing method in the class. Defaults to True. Returns: - function: patched function + function: The patched function """ + if func is None: + return lambda f: patch(f, name=name, overwrite=overwrite) + + # Extract the class from the first parameter's type annotation cls = next(iter(func.__annotations__.values())) - defaults = func.__defaults__ or () - name = defaults[0] if defaults else None - func.__qualname__ = f"{cls.__name__}.{func.__name__}" + method_name = name or func.__name__ + + # Check if the method already exists in the class + if not overwrite and hasattr(cls, method_name): + # Do not overwrite; return the original function unmodified + return func + + # Update function metadata + func.__qualname__ = f"{cls.__name__}.{method_name}" func.__module__ = cls.__module__ - if name is None: - setattr(cls, func.__name__, func) - else: - func.__qualname__ = f"{cls.__name__}.{name}" - setattr(cls, name, func) + + # Patch the function into the class + setattr(cls, method_name, func) return func