From 9ec721f72438600250c85086cc486d7930b2cb26 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:41:34 +0800 Subject: [PATCH] [PIR] Fix IfOp exe bug (#58132) * fix * fix * fix * fix * fix * fix * fix * fix bug * fix bug * fix bug * fix bug * fix bug --- .../instruction/cond_instruction.cc | 78 +++++++++++-------- .../instruction/cond_instruction.h | 4 + .../instruction/instruction_util.cc | 5 +- .../instruction/instruction_util.h | 2 +- .../ir_adaptor/translator/op_translator.cc | 10 +-- .../translator/program_translator.cc | 10 +-- .../translator/program_translator.h | 1 - paddle/fluid/ir_adaptor/translator/utils.cc | 2 +- 8 files changed, 61 insertions(+), 51 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc index 780219e406bff..2422597ece0d1 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc @@ -62,8 +62,36 @@ CondInstruction::CondInstruction(size_t id, } VLOG(6) << "finish process cond_var and output_vars"; + // NOTE(zhangbo): IfOp sub_block's inputs include two kind of value: one is + // OpOperand of IfOp, and the other is external Values used in true_block or + // false_block. auto true_branch_block = if_op.true_block(); - auto true_branch_yied_inputs = GetYiedOpInputs(true_branch_block); + auto false_branch_block = if_op.false_block(); + std::unordered_map> inputs; + GetInputIds(op, *value_exec_info, &inputs); + auto true_outside_inputs = + GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs); + auto false_outside_inputs = + GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs); + SetInputs(inputs); + + std::unordered_map> outputs; + for (size_t i = 0; i < op->num_results(); i++) { + pir::Value value = op->result(i); + if (value && value.type()) { + PADDLE_ENFORCE_EQ( + value_exec_info->HasValue(value), + true, + phi::errors::PreconditionNotMet( + "input should in name map, [%d] 'th input of [%s] op", + i, + "if op")); + outputs.emplace(value, GetValueIds(value, *value_exec_info)); + } + } + SetOutputs(outputs); + VLOG(6) << "finish process inputs outputs index"; + Scope* true_scope = &(value_exec_info->GetScope()->NewScope()); true_branch_inter_ = new NewIRInterpreter(place, @@ -74,15 +102,20 @@ CondInstruction::CondInstruction(size_t id, {}); std::set true_skip_gc_names_set; - for (auto value : true_branch_yied_inputs) { + for (auto value : GetYiedOpInputs(true_branch_block)) { + true_branch_outputs_.push_back(true_branch_inter_->GetNameByValue(value)); + true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value)); + true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value)); + } + // NOTE(zhangbo): According to the concept of control flow, child scopes + // should not control the lifecycle of parent scope variables. + for (auto value : true_outside_inputs) { true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value)); true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value)); } true_branch_inter_->SetSkipGcVars(true_skip_gc_names_set); VLOG(6) << "finish process true branch interpreter"; - auto false_branch_block = if_op.false_block(); - auto false_branch_yied_inputs = GetYiedOpInputs(false_branch_block); Scope* false_scope = &(value_exec_info->GetScope()->NewScope()); false_branch_inter_ = new NewIRInterpreter(place, @@ -93,38 +126,17 @@ CondInstruction::CondInstruction(size_t id, {}); std::set false_skip_gc_names_set; - for (auto value : false_branch_yied_inputs) { + for (auto value : GetYiedOpInputs(false_branch_block)) { + false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); + } + for (auto value : false_outside_inputs) { false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); } false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set); VLOG(6) << "finish process false branch interpreter"; - - // NOTE(zhangbo): IfOp sub_block's inputs include two kind of value: one is - // OpOperand of IfOp, and the other is external Values used in true_block or - // false_block. - std::unordered_map> inputs; - GetInputIds(op, *value_exec_info, &inputs); - GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs); - GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs); - SetInputs(inputs); - - std::unordered_map> outputs; - for (size_t i = 0; i < op->num_results(); i++) { - pir::Value value = op->result(i); - if (value && value.type()) { - PADDLE_ENFORCE_EQ( - value_exec_info->HasValue(value), - true, - phi::errors::PreconditionNotMet( - "input should in name map, [%d] 'th input of [%s] op", - i, - "if op")); - outputs.emplace(value, GetValueIds(value, *value_exec_info)); - } - } - SetOutputs(outputs); - VLOG(6) << "finish process inputs outputs index"; } CondInstruction::~CondInstruction() { @@ -150,10 +162,10 @@ void CondInstruction::Run() { DeviceContext().Wait(); if (cond_var_->Get().data()[0]) { true_branch_inter_->Run({}, false); - CopyBranchOutput(true_skip_gc_names_, true_branch_inter_); + CopyBranchOutput(true_branch_outputs_, true_branch_inter_); } else { false_branch_inter_->Run({}, false); - CopyBranchOutput(false_skip_gc_names_, false_branch_inter_); + CopyBranchOutput(false_branch_outputs_, false_branch_inter_); } // copy ouptut diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h index 1cdc4a388126a..469c0ed0ae1ab 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h @@ -58,6 +58,10 @@ class CondInstruction : public InstructionBase { NewIRInterpreter* false_branch_inter_; + std::vector true_branch_outputs_; + + std::vector false_branch_outputs_; + // TODO(zhangbo): Currently, only the output of IfOp is included. In the // future, need to consider how to support IfGradOp using IfOp value. std::vector true_skip_gc_names_; diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index cf845ca482437..4066bc7afb3dc 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -221,7 +221,7 @@ void GetInputIds(pir::Operation* op, } } -void GetOutsideOpInputs( +std::vector GetOutsideOpInputs( pir::Block* block, const ValueExecutionInfo& value_exec_info, std::unordered_map>* input_ids) { @@ -232,6 +232,7 @@ void GetOutsideOpInputs( } } + std::vector outside_op_inputs; for (auto op : (*block)) { for (size_t i = 0; i < op->num_operands(); ++i) { pir::Value value = op->operand_source(i); @@ -244,9 +245,11 @@ void GetOutsideOpInputs( i, op->name())); input_ids->emplace(value, GetValueIds(value, value_exec_info)); + outside_op_inputs.push_back(value); } } } + return outside_op_inputs; } } // namespace framework diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.h b/paddle/fluid/framework/new_executor/instruction/instruction_util.h index fdc0e8774c1c5..8304b134e0534 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.h @@ -49,7 +49,7 @@ void GetInputIds(pir::Operation* op, const ValueExecutionInfo& value_exec_info, std::unordered_map>* input_ids); -void GetOutsideOpInputs( +std::vector GetOutsideOpInputs( pir::Block* block, const ValueExecutionInfo& value_exec_info, std::unordered_map>* input_ids); diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 3b665d174df55..92cb9504bb47b 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -677,7 +677,7 @@ void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx, pir::OpResult value = operation->result(idx_in_op); bool generated_by_vector = value.type().isa(); - param_map->UpdateValue( + param_map->PushValue( arg_name, VariableDefiningInfo( value, @@ -1697,8 +1697,8 @@ struct ElementwiseGradTranscriber : public OpTranscriber { pir::OpResult value = operation->result(idx_in_op); pir::Builder builder(ctx, operation->GetParent()); auto reshape_op = builder.Build(value, y_shape); - param_map->UpdateValue(y_grad_var_name, - VariableDefiningInfo(reshape_op.out(), false, -1)); + param_map->PushValue(y_grad_var_name, + VariableDefiningInfo(reshape_op.out(), false, -1)); } }; @@ -1865,8 +1865,8 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber { auto output_var = output_vars[0]; auto fused_feedforward_op = operation->dyn_cast(); - param_map->UpdateValue(output_var, - VariableDefiningInfo{fused_feedforward_op.out()}); + param_map->PushValue(output_var, + VariableDefiningInfo{fused_feedforward_op.out()}); } } }; diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index b899991f0f994..4dbf9d8707409 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -261,14 +261,6 @@ void TranslationContext::PopValue(const Key& key) { container_[key].pop_back(); } -void TranslationContext::UpdateValue(const Key& key, const Value& value) { - auto& vec = container_[key]; - if (vec.empty()) - vec.push_back(value); - else - vec.back() = value; -} - ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, pir::Program* program) : legacy_program_(legacy_program), program_(program) { @@ -496,7 +488,7 @@ void ProgramTranslator::TranslateWhileOperation(const OpDesc* op, } auto name_iter = loop_vars_reverse.rbegin(); for (size_t idx = 0; idx < while_op->num_results(); ++idx) { - param_map_.UpdateValue(name_iter++->first, while_op->result(idx)); + param_map_.PushValue(name_iter++->first, while_op->result(idx)); } while_op->Verify(); VLOG(8) << "=============>end to translate while op:" << op; diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 9d9e1b99552af..668f4db2c9682 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -78,7 +78,6 @@ class TranslationContext { size_t count(const Key& key) const; // Caution: not exactly same as count in stl library - void UpdateValue(const Key& key, const Value& value); void PushValue(const Key& key, const Value& value); void PopValue(const Key& key); diff --git a/paddle/fluid/ir_adaptor/translator/utils.cc b/paddle/fluid/ir_adaptor/translator/utils.cc index e8102e4e686a2..7f50115c5c578 100644 --- a/paddle/fluid/ir_adaptor/translator/utils.cc +++ b/paddle/fluid/ir_adaptor/translator/utils.cc @@ -59,7 +59,7 @@ pir::Operation* InsertSliceOperationForTarget( op_info); block->push_back(operation); pir::OpResult target_op_result = operation->result(0); - param_map->UpdateValue(arg_name, VariableDefiningInfo(target_op_result)); + param_map->PushValue(arg_name, VariableDefiningInfo(target_op_result)); return operation; }