Skip to content

Commit

Permalink
[aotinductor] add no weight change version of fuse_parallel_linear (p…
Browse files Browse the repository at this point in the history
…ytorch#115791)

Summary: We need a new version of fuse_parallel_linear w/o creating new weights for real-time update.

Reviewed By: khabinov

Differential Revision: D52128296

Pull Request resolved: pytorch#115791
Approved by: https://github.com/khabinov
  • Loading branch information
frank-wei authored and pytorchmergebot committed Dec 14, 2023
1 parent ca4caf4 commit 87547a2
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions torch/fx/passes/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ def remove_pass(self, _passes: List[Callable]):
self.passes = passes_left
self._validated = False

def replace_pass(self, _target, _replacement):
passes_left = []
for ps in self.passes:
if ps.__name__ == _target.__name__:
passes_left.append(_replacement)
else:
passes_left.append(ps)
self.passes = passes_left
self._validated = False

def validate(self):
"""
Validates that current pass schedule defined by `self.passes` is valid
Expand Down

0 comments on commit 87547a2

Please sign in to comment.