[Custom Operator] Custom op support inplace mechanism (#51620)
* init unit test commit, contains register thinking

* support inplace

* get inplaced x.grad

* Try support inplace and hook at the same time

* Support inplace, need debug

* Support inplace successfully

* Inplace use Tensor&, consistent with Tensor*

* fix MapPlainOutputs bug

* fix double grad inplace error
jiahy0825 authored Mar 16, 2023
1 parent 0b778bd commit f824bc0
Showing 10 changed files with 713 additions and 36 deletions.
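
The diffs below are the framework-side plumbing; the user-visible effect is that a custom operator can declare an output as an in-place view of one of its inputs. The following is a minimal, hypothetical sketch of such an operator (the op name, the element-wise kernel, and the exact SetInplaceMap builder call are assumptions inferred from the commit messages, not shown on this page):

// custom_inplace_add.cc -- hypothetical usage sketch, not part of this commit.
#include "paddle/extension.h"

// Per the commit message "Inplace use Tensor&", the inplaced input is taken by
// mutable reference and modified directly. The framework maps the declared
// output "Out" back onto "X" (see MapPlainOutputs / AssignInplaceOutputs in the
// diffs below), so this kernel produces no plain outputs of its own.
void AddForward(paddle::Tensor& x, const paddle::Tensor& y) {  // NOLINT
  PD_CHECK(x.is_cpu(), "this sketch only handles CPU tensors");
  PD_CHECK(x.numel() == y.numel(), "x and y must have the same number of elements");
  auto* x_data = x.data<float>();
  const auto* y_data = y.data<float>();
  for (int64_t i = 0; i < x.numel(); ++i) {
    x_data[i] += y_data[i];
  }
}

PD_BUILD_OP(custom_add)
    .Inputs({"X", "Y"})
    .Outputs({"Out"})
    .SetInplaceMap({{"X", "Out"}})  // assumed builder: Out is the inplaced view of X
    .SetKernelFn(PD_KERNEL(AddForward));

In dygraph mode such a call mutates the input tensor in place, which is why the changes in eager_functions.cc below bump the tensor's inplace version and adjust stop_gradient handling for the inplaced inputs.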
12 changes: 12 additions & 0 deletions paddle/fluid/eager/custom_operator/custom_operator_node.cc
@@ -174,6 +174,9 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
auto grad_outputs_names = paddle::framework::OpMetaInfoHelper::GetOutputs(
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
const auto& grad_inplace_map =
paddle::framework::OpMetaInfoHelper::GetInplaceMap(
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_);
auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap();

@@ -205,6 +208,9 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
}
VLOG(6) << "Prepare Grad attrs";
ctx.EmplaceBackAttrs(attrs_);
// NOTE(HongyuJia): grad_outputs_names.size() <= OutputMeta().size():
// OutputMeta().size() is the number of inputs of the forward op, while
// grad_outputs_names.size() is the number of outputs of the backward op.
paddle::small_vector<std::vector<paddle::Tensor>, kSlotSmallVectorSize> outs(
OutputMeta().size());
paddle::small_vector<std::vector<paddle::Tensor>, kSlotSmallVectorSize>
@@ -234,8 +240,10 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
}
VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad";

ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map);
(*paddle::framework::OpMetaInfoHelper::GetKernelFn(
kernel_map.at(op_type_)[1]))(&ctx);
ctx.AssignInplaceOutputs();

VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op";
std::vector<std::vector<egr::AutogradMeta*>> ins_auto_grad_metas;
@@ -353,6 +361,8 @@ RunCustomOpDoubleGradNode::operator()(
paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[2]);
auto grad_outputs_names =
paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[2]);
const auto& grad_inplace_map =
paddle::framework::OpMetaInfoHelper::GetInplaceMap(vec_map[2]);
auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_);
auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap();

@@ -419,8 +429,10 @@ RunCustomOpDoubleGradNode::operator()(
}
VLOG(7) << "Run Kernel of Grad Custom Op: " << name();

ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map);
(*paddle::framework::OpMetaInfoHelper::GetKernelFn(
kernel_map.at(op_type_)[2]))(&ctx);
ctx.AssignInplaceOutputs();

return outs;
}
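In both grad nodes the same three-step pattern appears: MapPlainOutputs before the kernel, the kernel call itself, then AssignInplaceOutputs. The sketch below is an illustration of the bookkeeping this pair presumably performs; it is not the actual CustomOpKernelContext implementation, which is not shown on this page. Before the kernel runs, each output named in the inplace map is associated with its source input, so the kernel only has to produce the remaining plain outputs; afterwards the recorded output slots are filled from the (now modified) inputs.

// Illustration only: a stripped-down stand-in for the inplace bookkeeping
// implied by MapPlainOutputs / AssignInplaceOutputs. Real field names and
// behavior in CustomOpKernelContext may differ.
#include <string>
#include <unordered_map>
#include <vector>

struct InplaceBookkeeping {
  // output slot index -> input slot index, for inplaced outputs
  std::unordered_map<size_t, size_t> inplace_out_to_in;

  void MapPlainOutputs(
      const std::vector<std::string>& input_names,
      const std::vector<std::string>& output_names,
      const std::unordered_map<std::string, std::string>& inplace_map) {
    for (size_t in = 0; in < input_names.size(); ++in) {
      auto it = inplace_map.find(input_names[in]);
      if (it == inplace_map.end()) continue;  // plain (non-inplace) input
      for (size_t out = 0; out < output_names.size(); ++out) {
        if (output_names[out] == it->second) inplace_out_to_in[out] = in;
      }
    }
  }

  // After the kernel has run, AssignInplaceOutputs would share each inplaced
  // input into the output slot recorded above, e.g. outs[out] = ins[in].
};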
66 changes: 51 additions & 15 deletions paddle/fluid/framework/custom_operator.cc
@@ -130,11 +130,13 @@ static std::vector<std::string> ParseAttrStr(const std::string& attr) {
////////////////// Kernel Define ////////////////////

// custom op kernel call function define
-static void RunKernelFunc(const framework::ExecutionContext& ctx,
-const paddle::KernelFunc& func,
-const std::vector<std::string>& inputs,
-const std::vector<std::string>& outputs,
-const std::vector<std::string>& attrs) {
+static void RunKernelFunc(
+const framework::ExecutionContext& ctx,
+const paddle::KernelFunc& func,
+const std::vector<std::string>& inputs,
+const std::vector<std::string>& outputs,
+const std::vector<std::string>& attrs,
+const std::unordered_map<std::string, std::string>& inplace_map) {
VLOG(3) << "Custom Operator: Start run KernelFunc.";
// prepare CustomOpKernelContext
paddle::CustomOpKernelContext kernel_ctx;
@@ -283,7 +285,10 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
VLOG(4) << "Initialize phi tensor operants successfully";
}

// handle inplace case
kernel_ctx.MapPlainOutputs(inputs, outputs, inplace_map);
func(&kernel_ctx);
kernel_ctx.AssignInplaceOutputs();

// sync output tensor data into original output
auto* calc_outs = kernel_ctx.AllMutableOutput();
@@ -686,12 +691,14 @@ static void RegisterOperatorKernelWithPlace(
OperatorWithKernel::AllOpKernels()[name][key] = op_kernel_func;
}

-static void RegisterOperatorKernel(const std::string& name,
-const paddle::KernelFunc& kernel_func,
-const std::vector<std::string>& inputs,
-const std::vector<std::string>& outputs,
-const std::vector<std::string>& attrs,
-void* dso_handle) {
+static void RegisterOperatorKernel(
+const std::string& name,
+const paddle::KernelFunc& kernel_func,
+const std::vector<std::string>& inputs,
+const std::vector<std::string>& outputs,
+const std::vector<std::string>& attrs,
+const std::unordered_map<std::string, std::string>& inplace_map,
+void* dso_handle) {
VLOG(3) << "Custom Operator: op name in kernel: " << name;
// NOTE [ Dummy Op Kernel Key ]
// TODO(chenweihang): Because execute engine need get device context based
@@ -701,10 +708,10 @@ static void RegisterOperatorKernel(const std::string& name,
OperatorWithKernel::OpKernelFunc op_kernel_func;
if (kernel_func) {
VLOG(3) << "Register custom operator " << name << " with kernel func";
-op_kernel_func = [kernel_func, inputs, outputs, attrs](
+op_kernel_func = [kernel_func, inputs, outputs, attrs, inplace_map](
const framework::ExecutionContext& ctx) {
VLOG(3) << "Custom Operator: run custom kernel func in lambda.";
-RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
+RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs, inplace_map);
};
} else {
VLOG(3) << "Register custom operator " << name
@@ -760,6 +767,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta);
auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta);
auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta);
auto& op_inplace_map = OpMetaInfoHelper::GetInplaceMap(base_op_meta);
auto& kernel_fn = OpMetaInfoHelper::GetKernelFn(base_op_meta);
auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta);
auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta);
@@ -771,6 +779,12 @@
<< string::join_strings(op_outputs, ',');
VLOG(3) << "Custom Operator: forward, op attrs: "
<< string::join_strings(op_attrs, ',');
if (!op_inplace_map.empty()) {
VLOG(3) << "Custom Operator: forward, op inplace_map: "
<< string::join_strings(op_inplace_map, ',', [](auto& pair) {
return pair.first + ": " + pair.second;
});
}

// Op
info.creator_ = [](const std::string& op_name,
@@ -795,6 +809,13 @@
op_name,
info.proto_->InitializationErrorString()));

// Inplace
if (!op_inplace_map.empty()) {
info.infer_inplace_ = [op_inplace_map](bool use_cuda) {
return op_inplace_map;
};
}

// InferShape
if (infer_shape_func == nullptr) {
// use default InferShape
@@ -908,8 +929,13 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
}

// Kernel func
-RegisterOperatorKernel(
-op_name, kernel_fn, op_inputs, op_outputs, op_attrs, dso_handle);
+RegisterOperatorKernel(op_name,
+kernel_fn,
+op_inputs,
+op_outputs,
+op_attrs,
+op_inplace_map,
+dso_handle);

// If grad op or double grad op exists
std::string cur_op_name = op_name;
@@ -920,6 +946,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op);
auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
auto& grad_op_inplace_map = OpMetaInfoHelper::GetInplaceMap(cur_grad_op);
auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op);

@@ -928,6 +955,14 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
<< string::join_strings(grad_op_inputs, ',');
VLOG(3) << "Custom Operator: backward, op outputs: "
<< string::join_strings(grad_op_outputs, ',');
VLOG(3) << "Custom Operator: backward, op attrs: "
<< string::join_strings(grad_op_attrs, ',');
if (!grad_op_inplace_map.empty()) {
VLOG(3) << "Custom Operator: backward, op inplace_map: "
<< string::join_strings(grad_op_inplace_map, ',', [](auto& pair) {
return pair.first + ": " + pair.second;
});
}

bool is_double_grad = (i == 2);

@@ -1040,6 +1075,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
grad_op_inputs,
grad_op_outputs,
grad_op_attrs,
grad_op_inplace_map,
dso_handle);

// update current info
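For the static-graph path, the inplace map is additionally exposed through info.infer_inplace_, a functor the framework can call to ask which inputs may share memory with which outputs. Below is a self-contained illustration of that registration shape; the InplaceMap alias and the standalone main are for demonstration only, while the lambda signature (bool use_cuda returning the declared pairs) comes from the diff above.

// Standalone illustration of the infer_inplace_ registration pattern above:
// a callable capturing the op's inplace map by value and returning it when
// queried. Compiles and runs independently of Paddle.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

using InplaceMap = std::unordered_map<std::string, std::string>;

int main() {
  InplaceMap op_inplace_map = {{"X", "Out"}};
  // Same shape as the lambda assigned to info.infer_inplace_ in the diff:
  // it ignores use_cuda and simply returns the declared input->output pairs.
  std::function<InplaceMap(bool)> infer_inplace =
      [op_inplace_map](bool /*use_cuda*/) { return op_inplace_map; };
  for (const auto& kv : infer_inplace(/*use_cuda=*/false)) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}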
4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_meta_info_helper.h
@@ -39,6 +39,10 @@ class OpMetaInfoHelper {
const paddle::OpMetaInfo& info) {
return info.attrs_;
}
static const std::unordered_map<std::string, std::string>& GetInplaceMap(
const paddle::OpMetaInfo& info) {
return info.inplace_map_;
}
static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) {
return info.kernel_fn_;
}
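The new accessor implies that paddle::OpMetaInfo now carries an inplace_map_ member populated when the op is built. The sketch below shows the assumed builder-side counterpart; the SetInplaceMap method and the OpMetaInfoSketch class name are illustrative, since only the inplace_map_ member and the GetInplaceMap accessor are visible on this page.

// Sketch of the assumed builder-side counterpart to GetInplaceMap(); the
// real method would live on paddle::OpMetaInfo and may differ in signature.
#include <string>
#include <unordered_map>
#include <utility>

class OpMetaInfoSketch {
 public:
  // Stores the user-declared {input, output} pairs that GetInplaceMap()
  // later reads back; returns *this to keep the fluent PD_BUILD_OP style.
  OpMetaInfoSketch& SetInplaceMap(
      std::unordered_map<std::string, std::string>&& inplace_map) {
    inplace_map_ = std::move(inplace_map);
    return *this;
  }

 private:
  std::unordered_map<std::string, std::string> inplace_map_;
};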
42 changes: 42 additions & 0 deletions paddle/fluid/pybind/eager_functions.cc
@@ -531,7 +531,18 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
meta_info_map.at(op_type)[0]));
ctx.EmplaceBackAttrs(res_attrs);
const auto& vec_map = meta_info_map.at(op_type);

// handle inplace case
const auto& inputs = paddle::framework::OpMetaInfoHelper::GetInputs(
meta_info_map.at(op_type)[0]);
const auto& outputs = paddle::framework::OpMetaInfoHelper::GetOutputs(
meta_info_map.at(op_type)[0]);
const auto& inplace_map =
paddle::framework::OpMetaInfoHelper::GetInplaceMap(
meta_info_map.at(op_type)[0]);
ctx.MapPlainOutputs(inputs, outputs, inplace_map);
(*paddle::framework::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx);
ctx.AssignInplaceOutputs();

VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op";
std::vector<std::vector<egr::AutogradMeta*>> ins_auto_grad_metas;
@@ -557,12 +568,43 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
require_any_grad || egr::EagerUtils::ComputeRequireGrad(
trace_backward, &(ins_auto_grad_metas[i]));
}

// handle inplace case
for (size_t i = 0; i < ctx.InputRange().size(); i++) {
if (inplace_map.find(inputs[i]) != inplace_map.end()) {
size_t input_size =
ctx.InputRangeAt(i).second - ctx.InputRangeAt(i).first;
size_t start_idx = ctx.InputRangeAt(i).first;
for (size_t j = 0; j < input_size; j++) {
egr::EagerUtils::CheckInplace(ctx.InputAt(start_idx + j),
ins_auto_grad_metas[i][j],
require_any_grad);
// Bump Inplace Version
ctx.MutableInputAt(start_idx + j).bump_inplace_version();
VLOG(3) << "Custom operator: Tensor("
<< ctx.InputAt(start_idx + j).name()
<< ") uses Inplace Strategy.";
}
}
}

if (require_any_grad && (vec_map.size() > 1)) {
VLOG(6) << " Construct Grad for Custom Op: " << op_type;
ConstructFwdAndBwdMap(vec_map, op_type);
for (size_t i = 0; i < outs_auto_grad_metas.size(); i++) {
egr::EagerUtils::PassStopGradient(false, &(outs_auto_grad_metas[i]));
}
// Note(HongyuJia): In dygraph eager mode, CheckInplace makes sure that
// inplaced leaf nodes have stop_gradient=True. However, dygraph mode can
// also output leaf nodes' gradients (for example, we can get x.grad after
// x.add_(y)). To be consistent with dygraph mode, we have to
// PassStopGradient for all inplaced ins_auto_grad_metas.
std::unordered_map<size_t, size_t> inplace_tensor_map =
ctx.GetInplaceTensorMap();
for (auto pair : inplace_tensor_map) {
egr::EagerUtils::PassStopGradient(false,
&(ins_auto_grad_metas[pair.first]));
}
auto grad_node = std::make_shared<egr::RunCustomOpNode>(
outs_auto_grad_metas.size(), ins_auto_grad_metas.size(), op_type);
auto slot_map =
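Bumping the inplace version after CheckInplace is what allows the autograd engine to detect when a tensor that was saved for backward has since been modified in place. Below is a conceptual sketch of that version check; it is not Paddle's actual autograd code, and the types and names are illustrative.

// Conceptual sketch of inplace-version checking; illustrative types only.
#include <cstdint>
#include <stdexcept>

struct VersionedTensor {
  uint32_t inplace_version = 0;
  void bump_inplace_version() { ++inplace_version; }  // mirrors the call above
};

struct SavedTensorRecord {
  const VersionedTensor* tensor;
  uint32_t version_when_saved;

  // Called at backward time: if the tensor was written in place after being
  // saved, its current version no longer matches and backward must fail.
  void CheckVersion() const {
    if (tensor->inplace_version != version_when_saved) {
      throw std::runtime_error(
          "tensor was modified in place after being saved for backward");
    }
  }
};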
3 changes: 1 addition & 2 deletions paddle/fluid/pybind/eager_utils.cc
@@ -609,8 +609,7 @@ paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj,
return ::pybind11::handle(obj).cast<paddle::CustomOpKernelContext>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
"one of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace), "
"argument (position %d) must be CustomOpKernelContext, "
"but got %s",
arg_pos + 1,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));