TypeError: Test.__call__() got an unexpected keyword argument 'a' #138

Open
sylvorg opened this issue Apr 9, 2024 · 16 comments

@sylvorg

sylvorg commented Apr 9, 2024

Hello!

With the following code:

from rich.traceback import install

install(show_locals=True)

from plum import dispatch


class Test:
    @dispatch
    def __call__(self):
        pass

    @dispatch
    def __call__(self, *args, **kwargs):
        pass


test = Test()
test(a=1)

I get the following traceback:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/wsl/sylvorg/sylvorg/sylveon/siluam/oreo/./test15.py:20 in <module>                          │
│                                                                                                  │
│   17 │   │   pass                                                                                │
│   18                                                                                             │
│   19 test = Test()                                                                               │
│ ❱ 20 test(a=1)                                                                                   │
│   21                                                                                             │
│                                                                                                  │
│ ╭───────────────────────────── locals ─────────────────────────────╮                             │
│ │ dispatch = <plum.dispatcher.Dispatcher object at 0x7f4713206350> │                             │
│ │  install = <function install at 0x7f4713d0bce0>                  │                             │
│ │     Test = <class '__main__.Test'>                               │                             │
│ │     test = <__main__.Test object at 0x7f47134f7950>              │                             │
│ ╰──────────────────────────────────────────────────────────────────╯                             │
│                                                                                                  │
│ /nix/store/0rqp0565hq7xhg1n0ld36icv3031f87s-python3-3.11.8-env/lib/python3.11/site-packages/plum │
│ /function.py:484 in __call__                                                                     │
│                                                                                                  │
│   481 │   │   pass                                                                               │
│   482 │                                                                                          │
│   483 │   def __call__(self, _, *args, **kw_args):                                               │
│ ❱ 484 │   │   return self._f(self._instance, *args, **kw_args)                                   │
│   485 │                                                                                          │
│   486 │   def invoke(self, *types):                                                              │
│   487 │   │   """See :meth:`.Function.invoke`."""                                                │
│                                                                                                  │
│ ╭───────────────────────────── locals ──────────────────────────────╮                            │
│ │       _ = <__main__.Test object at 0x7f47134f7950>                │                            │
│ │    args = ()                                                      │                            │
│ │ kw_args = {'a': 1}                                                │                            │
│ │    self = <plum.function._BoundFunction object at 0x7f47132bc0d0> │                            │
│ ╰───────────────────────────────────────────────────────────────────╯                            │
│                                                                                                  │
│ /nix/store/0rqp0565hq7xhg1n0ld36icv3031f87s-python3-3.11.8-env/lib/python3.11/site-packages/plum │
│ /function.py:368 in __call__                                                                     │
│                                                                                                  │
│   365 │   def __call__(self, *args, **kw_args):                                                  │
│   366 │   │   __tracebackhide__ = True                                                           │
│   367 │   │   method, return_type = self._resolve_method_with_cache(args=args)                   │
│ ❱ 368 │   │   return _convert(method(*args, **kw_args), return_type)                             │
│   369 │                                                                                          │
│   370 │   def _resolve_method_with_cache(                                                        │
│   371 │   │   self,                                                                              │
│                                                                                                  │
│ ╭─────────────────────────────────────────── locals ───────────────────────────────────────────╮ │
│ │        args = (<__main__.Test object at 0x7f47134f7950>,)                                    │ │
│ │     kw_args = {'a': 1}                                                                       │ │
│ │      method = <function Test.__call__ at 0x7f471383c540>                                     │ │
│ │ return_type = typing.Any                                                                     │ │
│ │        self = <multiple-dispatch function Test.__call__ (with 2 registered and 0 pending     │ │
│ │               method(s))>                                                                    │ │
│ ╰──────────────────────────────────────────────────────────────────────────────────────────────╯ │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: Test.__call__() got an unexpected keyword argument 'a'

It doesn't seem to happen with either of the following blocks, though:

from rich.traceback import install

install(show_locals=True)

from plum import dispatch


class Test:
    @dispatch
    def __call__(self, *args):
        pass

    @dispatch
    def __call__(self, *args, **kwargs):
        pass


test = Test()
test(a=1)

from rich.traceback import install

install(show_locals=True)

from plum import dispatch


class Test:
    @dispatch
    def __call__(self):
        pass

    @dispatch
    def __call__(self, **kwargs):
        pass


test = Test()
test(a=1)

Thank you kindly for the help!

@sylvorg
Author

sylvorg commented Apr 9, 2024

Interestingly, this does not seem to be working either, failing with the same message:

from rich.traceback import install

install(show_locals=True)

from plum import dispatch


class Test:
    @dispatch
    def __call__(self, **kwargs):
        pass

    @dispatch
    def __call__(self):
        pass


test = Test()
test(a=1)

@wesselb
Member

wesselb commented Apr 21, 2024

Hey @sylvorg, what's going on is the following:

Firstly, it is important to know that Plum ignores keyword arguments when determining whether two methods are the same:

from plum import dispatch


@dispatch
def f(x: int):
    return 1


@dispatch
def f(x: int, **kw_args):  # Overwrites the above, because keyword arguments are ignored in registering functions!
    return 2


@dispatch
def f(x: int, *, a=1, **kw_args):  # Overwrites both of the above!
    return 3


assert f(1) == 3


@dispatch
def f(x: int, *, b=1):  # Overwrites all above methods! Now only keyword `b` is available!
    return 4


f(1, a=1)  # TypeError: f() got an unexpected keyword argument 'a'

However, Plum does not ignore *args:

from plum import dispatch


@dispatch
def f(x: int):
    return 1


@dispatch
def f(x: int, *args):
    return 2


assert f(1) == 1
assert f(1, 1) == 2
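Both rules can be mimicked with a small helper built on `inspect.signature`. `registration_key` and `f1` to `f5` are hypothetical names for this sketch, which only illustrates the behaviour described above and is not Plum's actual implementation: the key is made of the positional parameters' annotations plus whether `*args` is present, while keyword-only parameters and `**kwargs` are ignored.

```python
import inspect

P = inspect.Parameter


def registration_key(func):
    """Hypothetical sketch: the part of a signature that identifies a method.

    Annotations of positional parameters and the presence of *args count;
    keyword-only parameters and **kwargs are deliberately ignored.
    """
    params = inspect.signature(func).parameters.values()
    positional = tuple(
        p.annotation
        for p in params
        if p.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD)
    )
    has_varargs = any(p.kind is P.VAR_POSITIONAL for p in params)
    return positional, has_varargs


def f1(x: int): ...
def f2(x: int, **kw_args): ...
def f3(x: int, *, a=1, **kw_args): ...
def f4(x: int, *, b=1): ...
def f5(x: int, *args): ...


# f1 to f4 share one key, so each later definition would replace the earlier one.
print({registration_key(f) for f in (f1, f2, f3, f4)})  # {((<class 'int'>,), False)}
# f5 takes *args, so it gets its own key and coexists with the others.
print(registration_key(f5))  # ((<class 'int'>,), True)
```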

Therefore, in your case, the following is happening:

class Test:
    @dispatch
    def __call__(self):
        pass

    @dispatch
    def __call__(self, *args, **kwargs):  # New definition for `__call__`!
        pass

When you call test(a=1), the keyword argument is ignored for dispatch, so Plum effectively looks for test(). The most specific method matching this is the first definition, def __call__(self), and that definition does not accept keyword arguments, which gives you the error.

In the other two cases, the second __call__ overwrites the first __call__, because either both definitions have *args or neither does.
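A deliberately simplified resolver shows the order of events. `resolve_positional`, `call_no_args` and `call_anything` are hypothetical names, and matching on arity alone is an assumption made for brevity (Plum also checks types): a method is chosen from the positional arguments only, and the keyword arguments are checked by Python only when the chosen method is actually called. Here `object()` stands in for the `Test` instance.

```python
import inspect

P = inspect.Parameter


def resolve_positional(methods, args):
    """Hypothetical sketch: pick a method from positional arguments only.

    A method without *args applies if it takes exactly len(args) positional
    parameters; a method with *args applies if it takes at most that many.
    Fixed-arity methods are preferred over *args methods. Keyword arguments
    play no role at all.
    """
    fixed, variadic = [], []
    for method in methods:
        params = list(inspect.signature(method).parameters.values())
        n_pos = sum(p.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD) for p in params)
        if any(p.kind is P.VAR_POSITIONAL for p in params):
            if n_pos <= len(args):
                variadic.append(method)
        elif n_pos == len(args):
            fixed.append(method)
    candidates = fixed or variadic
    if not candidates:
        raise LookupError("no method matches the positional arguments")
    return candidates[0]


def call_no_args(self):
    return "no arguments"


def call_anything(self, *args, **kwargs):
    return "anything"


methods = [call_no_args, call_anything]
instance = object()  # stands in for `test`

method = resolve_positional(methods, (instance,))  # `a=1` is never consulted
print(method.__name__)  # call_no_args
try:
    method(instance, a=1)  # only now does Python check the keyword argument
except TypeError as error:
    print("TypeError:", error)
```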

I hope that makes sense!

@sylvorg
Author

sylvorg commented Apr 27, 2024

Hmm... Would it be particularly difficult to implement support for keyword arguments? I would have thought that the keyword names make resolution easier: any keyword names and values that don't match the ones in the resolved functions would either go to a function with **kwargs (or any other variable keyword argument) or result in a NotFoundLookupError.

@wesselb
Member

wesselb commented Apr 27, 2024

@sylvorg We've had a fair bit of discussion around keyword arguments. Very long story short is that supporting keyword arguments isn't entirely straightforward.

Plum's design basically mimics how multiple dispatch works in the Julia programming language, including the treatment of keyword arguments.

@sylvorg
Author

sylvorg commented Apr 27, 2024

Hmm... Does this mean none of my plum methods can use keyword arguments? At all? Since they would always be overridden?

@wesselb
Member

wesselb commented Apr 27, 2024

@sylvorg Definitely not! It just means that which methods are considered the same (i.e. what would be a redefinition) isn't influenced by keyword arguments. For example, this should work fine:

from plum import dispatch

@dispatch
def f(x: int, *, option1 = None):
    return x

@dispatch
def f(x: str, *, option2 = None):
    return x

>>> f(1, option1="value")  # OK
>>> f(1, option2="value")  # Not OK
>>> f("test", option1="value")  # Not OK
>>> f("test", option2="value")  # OK

Arguments with default values that are not keyword-only (that is, not placed after *,) are treated in this way:

@dispatch
def f(x: int, y: str="hey"):
    ...  # Implementation

is converted to

@dispatch
def f(x: int):
    y = "hey"
    ...  # Implementation

@dispatch
def f(x: int, y: str):
    ...  # Implementation
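The expansion can be imitated with `inspect.signature`. `expanded_signatures` is a hypothetical helper for this sketch, not a Plum function: it lists one tuple of positional annotations per number of trailing defaults that are left out, which for the example gives the two signatures shown.

```python
import inspect

P = inspect.Parameter


def expanded_signatures(func):
    """Hypothetical sketch: the positional signatures a function with default
    values would be registered under, one per number of omitted defaults."""
    positional = [
        p
        for p in inspect.signature(func).parameters.values()
        if p.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD)
    ]
    # Python requires defaults to come last, so this is the count of required ones.
    n_required = sum(p.default is P.empty for p in positional)
    return [
        tuple(p.annotation for p in positional[:n])
        for n in range(n_required, len(positional) + 1)
    ]


def f(x: int, y: str = "hey"):
    return x, y


print(expanded_signatures(f))  # [(<class 'int'>,), (<class 'int'>, <class 'str'>)]
```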

@sylvorg
Author

sylvorg commented Apr 27, 2024

Ah; so the names are matched in the case of keyword arguments, not the types. Got it. I just can't accept variable keyword arguments in a plum method, right? Those would be overridden, so just one of them or the last one would work.

@sylvorg
Author

sylvorg commented Apr 28, 2024

Actually, could you help me create a patch for my own version of plum that would redirect any unknown keyword arguments to a function with **kwargs? Which file(s) would I have to modify to get that functionality?

@wesselb
Member

wesselb commented Apr 29, 2024

@sylvorg You should be able to use variable keyword arguments. Let me try to put together an example of what the desired behaviour might look like. I'll do that a little later.

@wesselb
Member

wesselb commented Apr 29, 2024

Since keyword arguments are ignored in the dispatch, you won't be able to create versions of methods for specific keyword arguments. However, you could emulate this behaviour in a way like this:

from plum import dispatch


@dispatch
def f(x, **kw_args):
    if kw_args:
        return f.invoke(object, object)(x, **kw_args)
    print("Only one argument and no keywords!")


@dispatch
def f(*args, **kw_args):
    print("Fallback:", args, kw_args)

> Actually, could you help me create a patch for my own version of plum that would redirect any unknown keyword arguments to a function with **kwargs?

I think this might be quite hard to get right. If a simple solution like the above suffices, my advice would be to go with that.

@sylvorg
Author

sylvorg commented Apr 29, 2024

So basically I should just make sure that, if I use **kwargs, I should put them in every function of the same name, then act based on whether kwargs is empty or not?

@wesselb
Member

wesselb commented May 18, 2024

@sylvorg, hmm, if you truly want to achieve dispatch based on which keyword arguments are given, you will unfortunately need to emulate that behaviour.

Perhaps another approach is possible, which has been proposed in other issues:

from plum import dispatch


def f(x, *, option=None):
    return _f(x, option)

@dispatch
def _f(x, option: None):
    ...

@dispatch
def _f(x, option: str):
    ...

The idea is that you convert keyword arguments into positional arguments in a fixed order by using a wrapper function, and then use Plum's dispatch for positional arguments.
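The wrapper pattern itself can be tried without Plum. As a stdlib-only stand-in (an assumption for illustration, not what was proposed above), this sketch uses `functools.singledispatch`, which dispatches on the first argument only, so the wrapper moves `option` to the front instead of passing it second. The registered implementations and their return strings are made up for the example.

```python
from functools import singledispatch


def f(x, *, option=None):
    # Public entry point: turn the keyword argument into a positional one.
    # singledispatch only looks at the first argument, so `option` goes first.
    return _f(option, x)


@singledispatch
def _f(option, x):
    # Fallback for option types that have no registered implementation.
    raise TypeError(f"unsupported option type: {type(option).__name__}")


@_f.register(type(None))
def _(option, x):
    return f"{x!r} with no option"


@_f.register(str)
def _(option, x):
    return f"{x!r} with option {option!r}"


print(f(1))                 # 1 with no option
print(f(1, option="fast"))  # 1 with option 'fast'
```

With Plum the wrapper could keep the natural `_f(x, option)` order, since Plum dispatches on all positional arguments.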

Could something like that be OK for your use case?

@sylvorg
Author

sylvorg commented Jun 4, 2024

Hello, and sorry for the extreme delay; I'm taking two summer courses right now! 😅

In trying to create my own version of beartype and plum, I may have solved the scoring and TypeVar detection system, I think...?

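import typing
from collections import abc, defaultdict
from inspect import Parameter, signature

# Not shown in this snippet: normalize_signature(), the TypeVar T used at the
# end, and an isinstance() variant that accepts exact=True. They come from the
# author's own beartype-like code; the built-in isinstance() has no `exact` keyword.


def as_decorator():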
    return True


class NO_CONSTRAINT:
    pass


class Function:
    def __init__(self, func):
        func.__signature__ = normalize_signature(signature(func))
        self.__paramessages__ = []
        self.__positional__ = dict()
        self.__keywords__ = dict()
        self.__var_positional__ = False
        self.__var_keyword__ = False
        for name, param in func.__signature__.parameters.items():
            if param.kind is Parameter.VAR_POSITIONAL:
                self.__var_positional__ = True
            if param.kind is Parameter.VAR_KEYWORD:
                self.__var_keyword__ = True
            if param.kind in (
                Parameter.POSITIONAL_ONLY,
                Parameter.POSITIONAL_OR_KEYWORD,
            ):
                self.__positional__[name] = param
            if param.kind in (
                Parameter.KEYWORD_ONLY,
                Parameter.POSITIONAL_OR_KEYWORD,
            ):
                self.__keywords__[name] = param
            if param.default is Parameter.empty:
                self.__paramessages__.append(f"{name}: {param.annotation.__name__}")
            else:
                self.__paramessages__.append(
                    f"{name}: {param.annotation.__name__} = {param.default}"
                )

        self.__constraint__ = NO_CONSTRAINT
        self.__func__ = func

    def __call__(self, *args, **kwargs):
        return self.__func__(*args, **kwargs)

    def __hash__(self):
        return hash(self.__func__)

    def __id__(self):
        return id(self.__func__)

    def __constrain__(self, param, arg):
        if isinstance(param.annotation, typing.TypeVar):
            if param.annotation.__constraints__:
                for constraint in param.annotation.__constraints__:
                    if isinstance(arg, constraint, exact=True):
                        self.__constraint__ = constraint

    def __score__(self, args, kwargs):
        score = -1 if self.__var_positional__ or self.__var_keyword__ else 0
        for index, (name, param) in enumerate(self.__positional__.items()):
            try:
                arg = args[index]
            except IndexError:
                score -= 1
            else:
                self.__constrain__(param, arg)
                if isinstance(arg, param.annotation):
                    score += 1
                else:
                    score -= 1
        if len(args) > len(self.__positional__) and not self.__var_positional__:
            score -= len(args) - len(self.__positional__)
        for name, param in self.__keywords__.items():
            if name in kwargs:
                self.__constrain__(param, kwargs[name])
                if isinstance(kwargs[name], param.annotation):
                    score += 1
                else:
                    score -= 1
            else:
                score -= 1
        for name in kwargs:
            if name not in self.__keywords__ and not self.__var_keyword__:
                score -= 1
        return score

    def __getattr__(self, attr):
        return getattr(self.__func__, attr)

    def __repr__(self):
        return f"{self.__func__.__name__}({', '.join(self.__paramessages__)}) -> {self.__func__.__signature__.return_annotation.__name__}"


class Treads(list):
    def __init__(self, func):
        super().__init__()
        self.append(Function(func))

    def __call__(self, *args, **kwargs):
        if as_decorator() and len(args) == 1 and isinstance(args[0], abc.Callable):
            self.append(Function(args[0]))
            return self
        scores = defaultdict(list)
        var_positional_scores = defaultdict(list)
        var_keyword_scores = defaultdict(list)
        var_both_scores = defaultdict(list)
        var_no_scores = defaultdict(list)

        for func in self:
            score = func.__score__(args, kwargs)
            scores[score].append(func)
            if func.__var_positional__ and func.__var_keyword__:
                var_both_scores[score].append(func)
            elif func.__var_positional__:
                var_positional_scores[score].append(func)
            elif func.__var_keyword__:
                var_keyword_scores[score].append(func)
            elif not (func.__positional__ or func.__keywords__):
                var_no_scores[score].append(func)

        message_prefix = lambda funcs: f"multiple callables from {funcs} match the"
        message = lambda funcs: f"{message_prefix(funcs)} arguments {args}"

        # TODO: What happens with functions with:
        #       - no arguments
        #       - no keyword arguments
        #       - only arguments
        #       - only keyword arguments
        #       - both arguments and keyword arguments
        funcs = scores[max(scores)]
        if len(funcs) == 1:
            pass

        elif args and kwargs and var_both_scores:
            funcs = var_both_scores[max(var_both_scores)]
            if len(funcs) > 1:
                raise TypeError(f"{message(funcs)} and keyword arguments {kwargs}")
        elif args and var_positional_scores:
            funcs = var_positional_scores[max(var_positional_scores)]
            if len(funcs) > 1:
                raise TypeError(message(funcs))
        elif kwargs and var_keyword_scores:
            funcs = var_keyword_scores[max(var_keyword_scores)]
            if len(funcs) > 1:
                raise TypeError(f"{message_prefix(funcs)} keyword arguments {kwargs}")
        elif not (args or kwargs) and var_no_scores:
            funcs = var_no_scores[max(var_no_scores)]
            if len(funcs) > 1:
                raise TypeError(f"{message_prefix(funcs)} missing arguments")
        else:
            raise TypeError(f"{message(funcs)} and keyword arguments {kwargs}")

        func = funcs[0]
        result = func(*args, **kwargs)
        return_annotation = func.__signature__.return_annotation

        message = (
            f"callable {func} returned {result} of type {type(result)}"
            f", which is not of type {return_annotation}"
        )

        if isinstance(return_annotation, typing.TypeVar) and not isinstance(
            result, func.__constraint__
        ):
            raise TypeError(f"{message}, specifically {func.__constraint__}")
        if not isinstance(result, return_annotation):
            raise TypeError(message)
        return result

@Treads
def f(x: int):
    return 1


@f
def f(x: int, **kw_args):
    return 2


@f
def f(x: int, *, a=1, **kw_args):
    return 3


@f
def f(x: int, *, b=1):
    return 4


assert f(1) == 1
assert f(1, z=1) == 2
assert f(1, a=1) == 3
assert f(1, b=1) == 4


@Treads
def f(si: T) -> T:
    return str(si)


try:
    f(0)
except TypeError:
    assert True
else:
    assert False

Sorry for the long code, and again for the even longer delay!

@wesselb
Member

wesselb commented Jun 5, 2024

Hey @sylvorg! Not a problem, of course. Hope the summer courses are interesting. :) Which courses are you taking?

Thanks for sharing your proof of concept! That looks interesting. Could you explain the principle of what it's doing? Are you assigning some kind of score and then choosing the method with the best score?

@sylvorg
Author

sylvorg commented Jun 5, 2024

Thanks for asking! I'm taking Introductory Psychology and Environment and Globalization, the latter of which is kinda stressing me out with a 2000-word essay with proper references due on the 17th. 😹 Ah, well. I've got some time to complete it.

> Thanks for sharing your proof of concept! That looks interesting. Could you explain the principle of what it's doing? Are you assigning some kind of score and then choosing the method with the best score?

No problem, and that's exactly what I'm doing; I assume plum was also using a sort-of scoring system...? Because that's the only way I could think of, at the moment. My next steps will be to score the function based on how far up the mro an argument type is. This is just a little something I wrote down last night:

If the type of the argument matches the annotation exactly, award 1 point; otherwise, the further up the argument type's MRO the annotation sits, the more points it loses. For example, if A is a parent of B, and B is a parent of C, then for an argument of type C: if the annotation is C, award 1 point; if the annotation is B, award 0 points; if the annotation is A, award -1 point. Create a function which goes up an annotation's MRO, builds a tree, and stores it in a dictionary.
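The proposed score can be sketched directly from `type.__mro__`, which Python already stores as a linearised tuple, so this version reads it instead of building a separate tree. `mro_score` is a hypothetical name, and returning `None` for non-matching arguments is an assumption made for the sketch.

```python
def mro_score(annotation, arg):
    """Hypothetical sketch of the proposed score: 1 for an exact type match,
    minus one point per step the annotation sits up the argument's MRO.

    Returns None if the annotation is not in the argument type's MRO.
    Virtual subclasses registered with an ABC are not handled.
    """
    mro = type(arg).__mro__
    if annotation not in mro:
        return None
    return 1 - mro.index(annotation)


class A: pass
class B(A): pass
class C(B): pass


c = C()
print(mro_score(C, c), mro_score(B, c), mro_score(A, c))  # 1 0 -1
```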

If the annotation is a generic like Callable, where you have to check the inner annotations, maybe use the maximum score across all the inner annotations. For Union, use the MRO tree of the first matching type.

Otherwise, treat it as though the generic matches the type exactly.

I think that should work, though I'm still a little confused on how Callable, tuple, etc. should be treated.

@wesselb
Member

wesselb commented Jun 13, 2024

Good luck with the essay! You got this. :)

Hmm, that's an interesting approach! Plum currently does not use a scoring system. The algorithm is quite simple: determine all methods which are applicable to the given arguments, and eliminate a method if another applicable method is strictly more specific, i.e. if the other method's types are all subtypes of its types and the two signatures differ. If exactly one method is left, resolution succeeds. If none are left, no method could be determined. If more than one is left, an ambiguity error is raised.
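The procedure can be sketched in a few lines. This is a hypothetical, heavily simplified version (fixed arity, plain classes, no `*args`, no caching), not Plum's code, and it raises a plain `LookupError` in both failure cases: it collects the applicable methods, keeps only those for which no other applicable method is strictly more specific, and then succeeds, reports no match, or reports an ambiguity.

```python
def resolve(signatures, arg_types):
    """Hypothetical sketch: keep only the most specific applicable methods.

    `signatures` maps a method name to a tuple of parameter types and
    `arg_types` holds the types of the actual arguments.
    """

    def applicable(sig):
        return len(sig) == len(arg_types) and all(
            issubclass(a, s) for a, s in zip(arg_types, sig)
        )

    def strictly_more_specific(s1, s2):
        # Every type in s1 is a subtype of the matching type in s2, and s1 != s2.
        return s1 != s2 and all(issubclass(a, b) for a, b in zip(s1, s2))

    candidates = {name: sig for name, sig in signatures.items() if applicable(sig)}
    remaining = [
        name
        for name, sig in candidates.items()
        if not any(strictly_more_specific(other, sig) for other in candidates.values())
    ]
    if not remaining:
        raise LookupError(f"no method found for {arg_types}")
    if len(remaining) > 1:
        raise LookupError(f"ambiguous between {sorted(remaining)}")
    return remaining[0]


signatures = {
    "any_any": (object, object),
    "int_any": (int, object),
    "any_int": (object, int),
}

print(resolve(signatures, (int, str)))  # int_any
# resolve(signatures, (int, int)) raises: int_any and any_int are equally specific.
# resolve(signatures, (str,)) raises: no method takes a single argument.
```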

I'll have to think about your scoring system a bit.
