From 87547a26b86848f23fc2a7785fef17a07853c16f Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Thu, 14 Dec 2023 18:36:17 +0000 Subject: [PATCH] [aotinductor] add no weight change version of fuse_parallel_linear (#115791) Summary: We need a new version of fuse_parallel_linear w/o creating new weights for real-time update. Reviewed By: khabinov Differential Revision: D52128296 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115791 Approved by: https://github.com/khabinov --- torch/fx/passes/pass_manager.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index 37c31fdff19b6..4f103640fa30a 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -228,6 +228,16 @@ def remove_pass(self, _passes: List[Callable]): self.passes = passes_left self._validated = False + def replace_pass(self, _target, _replacement): + passes_left = [] + for ps in self.passes: + if ps.__name__ == _target.__name__: + passes_left.append(_replacement) + else: + passes_left.append(ps) + self.passes = passes_left + self._validated = False + def validate(self): """ Validates that current pass schedule defined by `self.passes` is valid