From 14bdd3430c8da26f9d762850ef242ef3414174c7 Mon Sep 17 00:00:00 2001 From: Ar-Kareem Date: Mon, 2 Oct 2023 05:07:38 -0400 Subject: [PATCH] fix recursion error when setting tp_wrapped_module #122 --- src/tensor_parallel/wrapper.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/tensor_parallel/wrapper.py b/src/tensor_parallel/wrapper.py index 3d907c9..ab1fd39 100644 --- a/src/tensor_parallel/wrapper.py +++ b/src/tensor_parallel/wrapper.py @@ -73,3 +73,8 @@ def forward(self, *args, **kwargs): def __getattr__(self, attr): return getattr(self.tp_wrapped_module, attr) + + def __setattr__(self, attr, value): + super().__setattr__(attr, value) + if attr == "tp_wrapped_module": + self.__dict__["tp_wrapped_module"] = value # to access without getattr, nn.Module removed it from __dict__