Skip to content

Commit

Permalink
fix recursion error when setting tp_wrapped_module #122
Browse files Browse the repository at this point in the history
  • Loading branch information
Ar-Kareem committed Oct 2, 2023
1 parent a7d1939 commit 14bdd34
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/tensor_parallel/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,8 @@ def forward(self, *args, **kwargs):

def __getattr__(self, attr):
return getattr(self.tp_wrapped_module, attr)

def __setattr__(self, attr, value):
super().__setattr__(attr, value)
if attr == "tp_wrapped_module":
self.__dict__["tp_wrapped_module"] = value # to access without getattr, nn.Module removed it from __dict__

0 comments on commit 14bdd34

Please sign in to comment.