From 02411b936f2276fc5bd58c71524fb42ef0e689c7 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 25 Aug 2023 17:00:39 +0800 Subject: [PATCH] fix error (#56572) --- python/paddle/fluid/variable_index.py | 6 +-- .../jit/dy2static/basic_api_transformer.py | 1 + .../paddle/jit/dy2static/convert_operators.py | 19 +++++--- .../jit/dy2static/program_translator.py | 48 +++++++++++-------- 4 files changed, 44 insertions(+), 30 deletions(-) diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 519ad7481b906b..bdc3d740509121 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -589,7 +589,7 @@ def _setitem_impl_(var, item, value): ProgramTranslator, ) - ProgramTranslator.get_instance()._params_map.add( + ProgramTranslator.get_instance()._inplace_map.add( cur_block.program, var.desc.id(), output ) @@ -935,7 +935,7 @@ def _setitem_static(x, indices, values): if not paddle.in_dynamic_mode(): # map var to the new output - paddle.jit.api.ProgramTranslator.get_instance()._params_map.add( + paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add( cur_block.program, x.desc.id(), output ) return output @@ -1008,7 +1008,7 @@ def _setitem_static(x, indices, values): ) if not paddle.in_dynamic_mode(): # map var to the new output - paddle.jit.api.ProgramTranslator.get_instance()._params_map.add( + paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add( cur_block.program, x.desc.id(), output ) return output diff --git a/python/paddle/jit/dy2static/basic_api_transformer.py b/python/paddle/jit/dy2static/basic_api_transformer.py index 34b8708f6a22e2..30af698923de33 100644 --- a/python/paddle/jit/dy2static/basic_api_transformer.py +++ b/python/paddle/jit/dy2static/basic_api_transformer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import astor from paddle.utils import gast diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 0089b4622cf705..5ffb3ebfce9783 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -53,7 +53,7 @@ def convert_load(x): from paddle.jit.dy2static.program_translator import ProgramTranslator - new_var = ProgramTranslator.get_instance()._params_map.get( + new_var = ProgramTranslator.get_instance()._inplace_map.get( cur_block.program, x.desc.id() ) if new_var is not None: @@ -381,9 +381,13 @@ def _run_paddle_cond( _convert_tensor_arrray_if_necessary(helper, push_pop_names) pred = cast_bool_if_necessary(pred) init_args = helper.get(return_name_ids) + from paddle.jit.dy2static.program_translator import ProgramTranslator + + inplace_map = ProgramTranslator.get_instance()._inplace_map def new_true_fn(): # init args may contain mutable python container like [var, 2], we copy then like in while_loop + inplace_map_checkpoint = inplace_map.save_checkpoint() helper.set( return_name_ids, paddle.utils.copy_mutable_vars(init_args), @@ -392,21 +396,22 @@ def new_true_fn(): # IfExpr will return a non-None return value, so we just return ret. # We assume normal return has no return value. if ret is None: - return helper.get(return_name_ids) - else: - return ret + ret = helper.get(return_name_ids) + inplace_map.restore_checkpoint(inplace_map_checkpoint) + return ret def new_false_fn(): # init args may contain mutable python container like [var, 2], we copy then like in while_loop + inplace_map_checkpoint = inplace_map.save_checkpoint() helper.set( return_name_ids, paddle.utils.copy_mutable_vars(init_args), ) ret = false_fn() if ret is None: - return helper.get(return_name_ids) - else: - return ret + ret = helper.get(return_name_ids) + inplace_map.restore_checkpoint(inplace_map_checkpoint) + return ret try: cond_outs = paddle.static.nn.cond( diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 182b9d786412c1..ed044d1fd2293e 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -1256,6 +1256,14 @@ def from_func_spec( ) +def _program_hash(program): + """ + because program is not deleted while calling from_func_spec. + so it's ok to use id(program) + """ + return id(program) + + class ParametersRecorder: def __init__(self): self.params_dict = {} @@ -1263,35 +1271,28 @@ def __init__(self): @synchronized def add(self, program, param): """use the default_program as key, append param the parameter list.""" - key = self._program_hash(program) + key = _program_hash(program) if key not in self.params_dict: self.params_dict[key] = set() params = self.params_dict[key] params.add(param) def pop(self, program): - params = self.params_dict.get(self._program_hash(program)) + params = self.params_dict.get(_program_hash(program)) if params is None: return [] - del self.params_dict[self._program_hash(program)] + del self.params_dict[_program_hash(program)] return list(params) - def _program_hash(self, program): - """ - because program is not deleted while calling from_func_spec. - so it's ok to use id(program) - """ - return id(program) - -class ParametersMap: +class InplaceMap: def __init__(self): self.params_dict = {} @synchronized def add(self, program, id, param): """use the default_program as key, append param the parameter list.""" - key = self._program_hash(program) + key = _program_hash(program) if key not in self.params_dict: self.params_dict[key] = {} @@ -1299,7 +1300,7 @@ def add(self, program, id, param): params[id] = param def get(self, program, id): - params = self.params_dict.get(self._program_hash(program)) + params = self.params_dict.get(_program_hash(program)) if params is None: return None if id not in params: @@ -1313,12 +1314,19 @@ def get(self, program, id): params[var.desc.id()] = root_var return root_var - def _program_hash(self, program): - """ - because program is not deleted while calling from_func_spec. - so it's ok to use id(program) - """ - return id(program) + def restore_checkpoint(self, checkpoint): + # InplaceMap is a nested effect. + # when enter a block, we should save a checkpoint + # when exit a block, we should restore a checkpoint + # for example: + # if cond > 0: + # x [:] = 0 + # return x + # x[:] only effect current cond block, we should restore in false block. + self.params_dict = checkpoint + + def save_checkpoint(self): + return dict(self.params_dict.items()) class FallbackProgramLayer: @@ -1582,7 +1590,7 @@ def __init__(self): self._initialized = True self._program_cache = ProgramCache() self._params_recorder = ParametersRecorder() - self._params_map = ParametersMap() + self._inplace_map = InplaceMap() self.enable_to_static = True def enable(self, enable_to_static):