Skip to content

Commit

Permalink
[PIR] Fix IfOp exe bug (#58132)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug
  • Loading branch information
zhangbo9674 authored Oct 18, 2023
1 parent ec23983 commit 9ec721f
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 51 deletions.
78 changes: 45 additions & 33 deletions paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,36 @@ CondInstruction::CondInstruction(size_t id,
}
VLOG(6) << "finish process cond_var and output_vars";

// NOTE(zhangbo): IfOp sub_block's inputs include two kind of value: one is
// OpOperand of IfOp, and the other is external Values used in true_block or
// false_block.
auto true_branch_block = if_op.true_block();
auto true_branch_yied_inputs = GetYiedOpInputs(true_branch_block);
auto false_branch_block = if_op.false_block();
std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op, *value_exec_info, &inputs);
auto true_outside_inputs =
GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs);
auto false_outside_inputs =
GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs);
SetInputs(inputs);

std::unordered_map<pir::Value, std::vector<int>> outputs;
for (size_t i = 0; i < op->num_results(); i++) {
pir::Value value = op->result(i);
if (value && value.type()) {
PADDLE_ENFORCE_EQ(
value_exec_info->HasValue(value),
true,
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
i,
"if op"));
outputs.emplace(value, GetValueIds(value, *value_exec_info));
}
}
SetOutputs(outputs);
VLOG(6) << "finish process inputs outputs index";

Scope* true_scope = &(value_exec_info->GetScope()->NewScope());
true_branch_inter_ =
new NewIRInterpreter(place,
Expand All @@ -74,15 +102,20 @@ CondInstruction::CondInstruction(size_t id,
{});

std::set<std::string> true_skip_gc_names_set;
for (auto value : true_branch_yied_inputs) {
for (auto value : GetYiedOpInputs(true_branch_block)) {
true_branch_outputs_.push_back(true_branch_inter_->GetNameByValue(value));
true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value));
true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value));
}
// NOTE(zhangbo): According to the concept of control flow, child scopes
// should not control the lifecycle of parent scope variables.
for (auto value : true_outside_inputs) {
true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value));
true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value));
}
true_branch_inter_->SetSkipGcVars(true_skip_gc_names_set);
VLOG(6) << "finish process true branch interpreter";

auto false_branch_block = if_op.false_block();
auto false_branch_yied_inputs = GetYiedOpInputs(false_branch_block);
Scope* false_scope = &(value_exec_info->GetScope()->NewScope());
false_branch_inter_ =
new NewIRInterpreter(place,
Expand All @@ -93,38 +126,17 @@ CondInstruction::CondInstruction(size_t id,
{});

std::set<std::string> false_skip_gc_names_set;
for (auto value : false_branch_yied_inputs) {
for (auto value : GetYiedOpInputs(false_branch_block)) {
false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value));
false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value));
false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value));
}
for (auto value : false_outside_inputs) {
false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value));
false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value));
}
false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set);
VLOG(6) << "finish process false branch interpreter";

// NOTE(zhangbo): IfOp sub_block's inputs include two kind of value: one is
// OpOperand of IfOp, and the other is external Values used in true_block or
// false_block.
std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op, *value_exec_info, &inputs);
GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs);
GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs);
SetInputs(inputs);

std::unordered_map<pir::Value, std::vector<int>> outputs;
for (size_t i = 0; i < op->num_results(); i++) {
pir::Value value = op->result(i);
if (value && value.type()) {
PADDLE_ENFORCE_EQ(
value_exec_info->HasValue(value),
true,
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
i,
"if op"));
outputs.emplace(value, GetValueIds(value, *value_exec_info));
}
}
SetOutputs(outputs);
VLOG(6) << "finish process inputs outputs index";
}

CondInstruction::~CondInstruction() {
Expand All @@ -150,10 +162,10 @@ void CondInstruction::Run() {
DeviceContext().Wait();
if (cond_var_->Get<phi::DenseTensor>().data<bool>()[0]) {
true_branch_inter_->Run({}, false);
CopyBranchOutput(true_skip_gc_names_, true_branch_inter_);
CopyBranchOutput(true_branch_outputs_, true_branch_inter_);
} else {
false_branch_inter_->Run({}, false);
CopyBranchOutput(false_skip_gc_names_, false_branch_inter_);
CopyBranchOutput(false_branch_outputs_, false_branch_inter_);
}

// copy ouptut
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class CondInstruction : public InstructionBase {

NewIRInterpreter* false_branch_inter_;

std::vector<std::string> true_branch_outputs_;

std::vector<std::string> false_branch_outputs_;

// TODO(zhangbo): Currently, only the output of IfOp is included. In the
// future, need to consider how to support IfGradOp using IfOp value.
std::vector<std::string> true_skip_gc_names_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ void GetInputIds(pir::Operation* op,
}
}

void GetOutsideOpInputs(
std::vector<pir::Value> GetOutsideOpInputs(
pir::Block* block,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* input_ids) {
Expand All @@ -232,6 +232,7 @@ void GetOutsideOpInputs(
}
}

std::vector<pir::Value> outside_op_inputs;
for (auto op : (*block)) {
for (size_t i = 0; i < op->num_operands(); ++i) {
pir::Value value = op->operand_source(i);
Expand All @@ -244,9 +245,11 @@ void GetOutsideOpInputs(
i,
op->name()));
input_ids->emplace(value, GetValueIds(value, value_exec_info));
outside_op_inputs.push_back(value);
}
}
}
return outside_op_inputs;
}

} // namespace framework
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void GetInputIds(pir::Operation* op,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* input_ids);

void GetOutsideOpInputs(
std::vector<pir::Value> GetOutsideOpInputs(
pir::Block* block,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* input_ids);
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx,
pir::OpResult value = operation->result(idx_in_op);
bool generated_by_vector = value.type().isa<pir::VectorType>();

param_map->UpdateValue(
param_map->PushValue(
arg_name,
VariableDefiningInfo(
value,
Expand Down Expand Up @@ -1697,8 +1697,8 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
pir::OpResult value = operation->result(idx_in_op);
pir::Builder builder(ctx, operation->GetParent());
auto reshape_op = builder.Build<dialect::ReshapeOp>(value, y_shape);
param_map->UpdateValue(y_grad_var_name,
VariableDefiningInfo(reshape_op.out(), false, -1));
param_map->PushValue(y_grad_var_name,
VariableDefiningInfo(reshape_op.out(), false, -1));
}
};

Expand Down Expand Up @@ -1865,8 +1865,8 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber {
auto output_var = output_vars[0];
auto fused_feedforward_op =
operation->dyn_cast<dialect::FusedFeedforwardOp>();
param_map->UpdateValue(output_var,
VariableDefiningInfo{fused_feedforward_op.out()});
param_map->PushValue(output_var,
VariableDefiningInfo{fused_feedforward_op.out()});
}
}
};
Expand Down
10 changes: 1 addition & 9 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,6 @@ void TranslationContext::PopValue(const Key& key) {
container_[key].pop_back();
}

void TranslationContext::UpdateValue(const Key& key, const Value& value) {
auto& vec = container_[key];
if (vec.empty())
vec.push_back(value);
else
vec.back() = value;
}

ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
pir::Program* program)
: legacy_program_(legacy_program), program_(program) {
Expand Down Expand Up @@ -496,7 +488,7 @@ void ProgramTranslator::TranslateWhileOperation(const OpDesc* op,
}
auto name_iter = loop_vars_reverse.rbegin();
for (size_t idx = 0; idx < while_op->num_results(); ++idx) {
param_map_.UpdateValue(name_iter++->first, while_op->result(idx));
param_map_.PushValue(name_iter++->first, while_op->result(idx));
}
while_op->Verify();
VLOG(8) << "=============>end to translate while op:" << op;
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/ir_adaptor/translator/program_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class TranslationContext {
size_t count(const Key& key)
const; // Caution: not exactly same as count in stl library

void UpdateValue(const Key& key, const Value& value);
void PushValue(const Key& key, const Value& value);
void PopValue(const Key& key);

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pir::Operation* InsertSliceOperationForTarget(
op_info);
block->push_back(operation);
pir::OpResult target_op_result = operation->result(0);
param_map->UpdateValue(arg_name, VariableDefiningInfo(target_op_result));
param_map->PushValue(arg_name, VariableDefiningInfo(target_op_result));
return operation;
}

Expand Down

0 comments on commit 9ec721f

Please sign in to comment.