Skip to content

Commit

Permalink
fix error (PaddlePaddle#56572)
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 authored and BeingGod committed Sep 9, 2023
1 parent 18e3c38 commit 02411b9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 30 deletions.
6 changes: 3 additions & 3 deletions python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def _setitem_impl_(var, item, value):
ProgramTranslator,
)

ProgramTranslator.get_instance()._params_map.add(
ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, var.desc.id(), output
)

Expand Down Expand Up @@ -935,7 +935,7 @@ def _setitem_static(x, indices, values):

if not paddle.in_dynamic_mode():
# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._params_map.add(
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, x.desc.id(), output
)
return output
Expand Down Expand Up @@ -1008,7 +1008,7 @@ def _setitem_static(x, indices, values):
)
if not paddle.in_dynamic_mode():
# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._params_map.add(
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, x.desc.id(), output
)
return output
Expand Down
1 change: 1 addition & 0 deletions python/paddle/jit/dy2static/basic_api_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import astor

from paddle.utils import gast
Expand Down
19 changes: 12 additions & 7 deletions python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def convert_load(x):

from paddle.jit.dy2static.program_translator import ProgramTranslator

new_var = ProgramTranslator.get_instance()._params_map.get(
new_var = ProgramTranslator.get_instance()._inplace_map.get(
cur_block.program, x.desc.id()
)
if new_var is not None:
Expand Down Expand Up @@ -381,9 +381,13 @@ def _run_paddle_cond(
_convert_tensor_arrray_if_necessary(helper, push_pop_names)
pred = cast_bool_if_necessary(pred)
init_args = helper.get(return_name_ids)
from paddle.jit.dy2static.program_translator import ProgramTranslator

inplace_map = ProgramTranslator.get_instance()._inplace_map

def new_true_fn():
# init args may contain mutable python container like [var, 2], we copy then like in while_loop
inplace_map_checkpoint = inplace_map.save_checkpoint()
helper.set(
return_name_ids,
paddle.utils.copy_mutable_vars(init_args),
Expand All @@ -392,21 +396,22 @@ def new_true_fn():
# IfExpr will return a non-None return value, so we just return ret.
# We assume normal return has no return value.
if ret is None:
return helper.get(return_name_ids)
else:
return ret
ret = helper.get(return_name_ids)
inplace_map.restore_checkpoint(inplace_map_checkpoint)
return ret

def new_false_fn():
# init args may contain mutable python container like [var, 2], we copy then like in while_loop
inplace_map_checkpoint = inplace_map.save_checkpoint()
helper.set(
return_name_ids,
paddle.utils.copy_mutable_vars(init_args),
)
ret = false_fn()
if ret is None:
return helper.get(return_name_ids)
else:
return ret
ret = helper.get(return_name_ids)
inplace_map.restore_checkpoint(inplace_map_checkpoint)
return ret

try:
cond_outs = paddle.static.nn.cond(
Expand Down
48 changes: 28 additions & 20 deletions python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,50 +1256,51 @@ def from_func_spec(
)


def _program_hash(program):
"""
because program is not deleted while calling from_func_spec.
so it's ok to use id(program)
"""
return id(program)


class ParametersRecorder:
def __init__(self):
self.params_dict = {}

@synchronized
def add(self, program, param):
"""use the default_program as key, append param the parameter list."""
key = self._program_hash(program)
key = _program_hash(program)
if key not in self.params_dict:
self.params_dict[key] = set()
params = self.params_dict[key]
params.add(param)

def pop(self, program):
params = self.params_dict.get(self._program_hash(program))
params = self.params_dict.get(_program_hash(program))
if params is None:
return []
del self.params_dict[self._program_hash(program)]
del self.params_dict[_program_hash(program)]
return list(params)

def _program_hash(self, program):
"""
because program is not deleted while calling from_func_spec.
so it's ok to use id(program)
"""
return id(program)


class ParametersMap:
class InplaceMap:
def __init__(self):
self.params_dict = {}

@synchronized
def add(self, program, id, param):
"""use the default_program as key, append param the parameter list."""
key = self._program_hash(program)
key = _program_hash(program)
if key not in self.params_dict:
self.params_dict[key] = {}

params = self.params_dict[key]
params[id] = param

def get(self, program, id):
params = self.params_dict.get(self._program_hash(program))
params = self.params_dict.get(_program_hash(program))
if params is None:
return None
if id not in params:
Expand All @@ -1313,12 +1314,19 @@ def get(self, program, id):
params[var.desc.id()] = root_var
return root_var

def _program_hash(self, program):
"""
because program is not deleted while calling from_func_spec.
so it's ok to use id(program)
"""
return id(program)
def restore_checkpoint(self, checkpoint):
# InplaceMap is a nested effect.
# when enter a block, we should save a checkpoint
# when exit a block, we should restore a checkpoint
# for example:
# if cond > 0:
# x [:] = 0
# return x
# x[:] only effect current cond block, we should restore in false block.
self.params_dict = checkpoint

def save_checkpoint(self):
return dict(self.params_dict.items())


class FallbackProgramLayer:
Expand Down Expand Up @@ -1582,7 +1590,7 @@ def __init__(self):
self._initialized = True
self._program_cache = ProgramCache()
self._params_recorder = ParametersRecorder()
self._params_map = ParametersMap()
self._inplace_map = InplaceMap()
self.enable_to_static = True

def enable(self, enable_to_static):
Expand Down

0 comments on commit 02411b9

Please sign in to comment.