[New IR] Bind core structure (#55665)
* bind ir core

* perfect code

* deal with conflict
YuanRisheng authored Jul 26, 2023
1 parent e838a4b commit ee506c2
Showing 9 changed files with 292 additions and 27 deletions.
2 changes: 0 additions & 2 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
@@ -44,8 +44,6 @@ const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
"fetch",
};

-constexpr char kAttrStopGradients[] = "stop_gradient";
-
ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
ir::Program* program)
: legacy_program_(legacy_program), program_(program) {
174 changes: 167 additions & 7 deletions paddle/fluid/pybind/ir.cc
@@ -21,14 +21,25 @@
#include <unordered_set>
#include <utility>

#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/enforce.h"
#include "pybind11/stl.h"

namespace py = pybind11;
using ir::Block;
using ir::Operation;
using ir::OpOperand;
using ir::OpResult;
using ir::Program;
using ir::Type;
using ir::Value;
using paddle::dialect::DenseTensorType;
using pybind11::return_value_policy;

namespace paddle {
@@ -53,24 +64,173 @@ void BindProgram(py::module *m) {
void BindBlock(py::module *m) {
py::class_<Block> block(*m, "Block");
block.def("front", &Block::front, return_value_policy::reference)
.def("get_op_list", [](Block &self) -> py::list {
py::list op_list;
for (auto iter = self.begin(); iter != self.end(); iter++) {
op_list.append(*iter);
}
return op_list;
.def("get_ops",
[](Block &self) -> py::list {
py::list op_list;
for (auto iter = self.begin(); iter != self.end(); iter++) {
op_list.append(*iter);
}
return op_list;
})
.def("remove_op", [](Block &self, Operation *op) {
auto op_iter = std::find(self.begin(), self.end(), op);
self.erase(op_iter);
});
}
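
For reference, a minimal Python usage sketch of the Block bindings above (not part of the diff), assuming a Paddle development build that includes this commit; the toy program mirrors get_ir_program() from the test file below:

import paddle
from paddle import ir

paddle.enable_static()

main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data('x', [4, 4], dtype='float32')
    y = paddle.matmul(x, x)
    y = paddle.add(x, y)
    y = paddle.tanh(y)

newir_program = ir.translate_to_new_ir(main_program.desc)
block = newir_program.block()

ops = block.get_ops()          # every Operation in the block, in order
print([op.name() for op in ops])
block.remove_op(ops[-1])       # erase the trailing op through the new binding
print(len(block.get_ops()))    # one op fewer than before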

void BindOperation(py::module *m) {
py::class_<Operation> op(*m, "Operation");
op.def("name", &Operation::name);
op.def("name", &Operation::name)
.def("get_parent", &Operation::GetParent, return_value_policy::reference)
.def("num_results", &Operation::num_results)
.def("result", &Operation::result)
.def("operands",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_operands(); i++) {
op_list.append(self.op_operand(i));
}
return op_list;
})
.def("results",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_results(); i++) {
op_list.append(self.result(i));
}
return op_list;
})
.def("get_input_names",
[](Operation &self) -> py::list {
py::list op_list;
paddle::dialect::OpYamlInfoInterface yaml_interface =
self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto inputs_info = std::get<0>(yaml_interface.GetOpInfo());
for (auto input_info : inputs_info) {
op_list.append(input_info.name);
}
return op_list;
})
.def("get_attr_names",
[](Operation &self) -> py::list {
py::list op_list;
paddle::dialect::OpYamlInfoInterface yaml_interface =
self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto attrs_info = std::get<1>(yaml_interface.GetOpInfo());
for (auto attr_info : attrs_info) {
op_list.append(attr_info.name);
}
return op_list;
})
.def("get_output_names",
[](Operation &self) -> py::list {
py::list op_list;
paddle::dialect::OpYamlInfoInterface yaml_interface =
self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto outputs_info = std::get<2>(yaml_interface.GetOpInfo());
for (auto output_info : outputs_info) {
op_list.append(output_info.name);
}
return op_list;
})
.def("replace_all_uses_with",
[](Operation &self, const std::vector<OpResult> &op_results) {
self.ReplaceAllUsesWith(op_results);
});
}
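
A sketch of the Operation introspection surface, continuing from the Block sketch above (before its remove_op call); the get_*_names bindings assume the op implements OpYamlInfoInterface:

ops = newir_program.block().get_ops()
tanh_op = ops[-1]
print(tanh_op.name())               # e.g. "pd.tanh"
print(tanh_op.num_results())        # 1
print(tanh_op.get_input_names())    # input names from the op's YAML info
print(tanh_op.get_attr_names())     # attribute names from the YAML info
print(tanh_op.get_output_names())   # output names from the YAML info
for operand in tanh_op.operands():
    print(operand.source().get_defining_op().name())   # the producing op

# as in the test below, ops[1] is the matmul and ops[2] the add;
# route every use of the add op's results to the matmul op's results
ops[2].replace_all_uses_with(ops[1].results())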

void BindValue(py::module *m) {
py::class_<Value> value(*m, "Value");
value.def(
"get_defining_op", &Value::GetDefiningOp, return_value_policy::reference);
}

void BindOpOperand(py::module *m) {
py::class_<OpOperand> op_operand(*m, "OpOperand");
op_operand.def("source", &OpOperand::source)
.def("set_source", &OpOperand::set_source);
}
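
A sketch of the OpOperand bindings, continuing from the same setup; set_source takes a Value, so this sketch simply re-installs the operand's existing source:

operand = tanh_op.operands()[0]
value = operand.source()    # the Value currently feeding this operand
operand.set_source(value)   # rewire the use-def edge (here back to the same Value)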

void BindOpResult(py::module *m) {
py::class_<OpResult> op_result(*m, "OpResult");
op_result
.def("get_defining_op",
&OpResult::GetDefiningOp,
return_value_policy::reference)
.def("use_empty", &OpResult::use_empty)
.def("type", &OpResult::type)
.def("set_stop_gradient",
[](OpResult &self, bool stop_gradient) {
auto *defining_op = self.owner();
std::vector<ir::Attribute> stop_gradients;
if (defining_op->HasAttribute(kAttrStopGradients)) {
stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.AsVector();
} else {
stop_gradients = std::vector<ir::Attribute>(
defining_op->num_results(),
ir::BoolAttribute::get(ir::IrContext::Instance(), false));
}
stop_gradients[self.GetResultIndex()] = ir::BoolAttribute::get(
ir::IrContext::Instance(), stop_gradient);
defining_op->set_attribute(
kAttrStopGradients,
ir::ArrayAttribute::get(ir::IrContext::Instance(),
stop_gradients));
})
.def("get_stop_gradient", [](OpResult &self) {
auto *defining_op = self.owner();
if (defining_op->HasAttribute(kAttrStopGradients)) {
auto stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.AsVector();
return stop_gradients[self.GetResultIndex()]
.dyn_cast<ir::BoolAttribute>()
.data();
} else {
return false;
}
});
}
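
set_stop_gradient stores a per-result "stop_gradient" ArrayAttribute on the defining op, and get_stop_gradient reads it back, defaulting to false when the attribute has never been written. A sketch, continuing from the same setup:

result = tanh_op.result(0)
print(result.get_stop_gradient())   # False: the attribute has not been written yet
result.set_stop_gradient(True)      # writes the ArrayAttribute on the defining op
print(result.get_stop_gradient())   # True
print(result.use_empty())           # True when nothing consumes this result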

void BindType(py::module *m) {
py::class_<Type> ir_type(*m, "Type");
ir_type.def("__eq__", [](Type &self, Type &other) { return self == other; })
.def("print", [](Type &self) { LOG(INFO) << self; });
}
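
A sketch of the Type bindings, continuing from the same setup; __eq__ compares the underlying types, so the matmul and add results compare equal here (both 4x4 float32 dense tensors):

t1 = ops[1].result(0).type()   # matmul result type
t2 = ops[2].result(0).type()   # add result type
t1.print()                     # logs the textual form of the type via LOG(INFO)
print(t1 == t2)                # True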

void BindUtils(pybind11::module *m) {
m->def("get_op_result_shape", [](const OpResult &op_result) {
if (op_result.type().isa<DenseTensorType>()) {
return phi::vectorize(
op_result.type().dyn_cast<DenseTensorType>().dims());
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"get_op_result_shape currently only support op_result that is a "
"DenseTensorType"));
}
});
m->def("get_op_result_dtype", [](const OpResult &op_result) {
if (op_result.type().isa<DenseTensorType>()) {
return op_result.type().dyn_cast<DenseTensorType>().dtype();
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"get_op_result_dtype currently only support op_result that is a "
"DenseTensorType"));
}
});
}
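
Both helpers accept only results whose type is DenseTensorType and throw InvalidArgument otherwise. A sketch, continuing from the same setup:

shape = ir.get_op_result_shape(ops[1].result(0))   # [4, 4]
dtype = ir.get_op_result_dtype(ops[1].result(0))   # the element type, as an ir.Type
dtype.print()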

void BindNewIR(pybind11::module *m) {
BindProgram(m);
BindBlock(m);
BindOperation(m);
BindValue(m);
BindOpOperand(m);
BindOpResult(m);
BindType(m);
BindUtils(m);
}

} // namespace pybind
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pybind.cc
@@ -2748,7 +2748,7 @@ All parameter, weight, gradient are variables in Paddle.
// Add skipped op list
m.def("set_skipped_op_list",
[](const std::string &op_list) { egr::SetSkipOpList(op_list); });
m.def("translate_newirprogram", &paddle::TranslateLegacyProgramToProgram);
m.def("translate_to_new_ir", &paddle::TranslateLegacyProgramToProgram);
BindFleetWrapper(&m);
BindIO(&m);
BindParallelExecutor(m);
2 changes: 2 additions & 0 deletions paddle/ir/core/attribute.h
@@ -17,6 +17,8 @@
#include "paddle/ir/core/cast_utils.h"
#include "paddle/ir/core/type_id.h"

constexpr char kAttrStopGradients[] = "stop_gradient";

namespace ir {
class AttributeStorage;
class AbstractAttribute;
8 changes: 8 additions & 0 deletions paddle/ir/core/operation.cc
@@ -250,6 +250,14 @@ void Operation::ReplaceAllUsesWith(const std::vector<Value> &values) {
}
}

void Operation::ReplaceAllUsesWith(const std::vector<OpResult> &op_results) {
IR_ENFORCE(num_results_ == op_results.size(),
"the num of result should be the same.");
for (uint32_t i = 0; i < num_results_; ++i) {
result(i).ReplaceAllUsesWith(op_results[i]);
}
}

void Operation::Verify() {
if (info_) {
info_.Verify(this);
2 changes: 2 additions & 0 deletions paddle/ir/core/operation.h
@@ -124,6 +124,8 @@ class IR_API alignas(8) Operation final {
/// Replace all uses of results of this operation with the provided 'values'.
void ReplaceAllUsesWith(const std::vector<Value> &values);

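/// Replace all uses of results of this operation with the provided 'op_results'.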
void ReplaceAllUsesWith(const std::vector<OpResult> &op_results);

inline void ReplaceAllUsesWith(Value value) {
ReplaceAllUsesWith(std::vector<Value>{value});
}
41 changes: 41 additions & 0 deletions python/paddle/ir/__init__.py
@@ -0,0 +1,41 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.core import (
Program,
Block,
Operation,
Value,
OpOperand,
OpResult,
Type,
) # noqa: F401
from paddle.fluid.core import (
get_op_result_shape,
get_op_result_dtype,
translate_to_new_ir,
) # noqa: F401

__all__ = [ # noqa
'Program',
'Block',
'Operation',
'Value',
'OpOperand',
'OpResult',
'Type',
'get_op_result_shape',
'get_op_result_dtype',
'translate_to_new_ir',
]
71 changes: 62 additions & 9 deletions test/ir/new_ir/test_ir_pybind.py
@@ -15,7 +15,7 @@
import unittest

import paddle
-from paddle.fluid import core
from paddle import ir

paddle.enable_static()

@@ -32,24 +32,77 @@ def get_ir_program():
y_s = paddle.matmul(x_s, x_s)
y_s = paddle.add(x_s, y_s)
y_s = paddle.tanh(y_s)
-    newir_program = core.translate_newirprogram(main_program.desc)
newir_program = ir.translate_to_new_ir(main_program.desc)
return newir_program


class TestPybind(unittest.TestCase):
    def test_program(self):
        newir_program = get_ir_program()
        newir_program.print()
-        ops = newir_program.block().get_op_list()
-        for op in ops:
-            # check op.name function
-            if op.name() == 'pd.tanh':
-                self.assertTrue(True)
-                return
-        self.assertTrue(False)

    def test_block(self):
        newir_program = get_ir_program()
        block = newir_program.block()
        ops = block.get_ops()
        self.assertEqual(
            len(ops), 4
        )  # the ir program adds "builtin.get_parameter" by default, so the size is 4
        block.remove_op(ops[3])
        self.assertEqual(len(block.get_ops()), 3)

    def test_operation(self):
        newir_program = get_ir_program()
        ops = newir_program.block().get_ops()
        matmul_op = ops[1]
        add_op = ops[2]
        tanh_op = ops[3]
        parent_block = tanh_op.get_parent()
        parent_ops_num = len(parent_block.get_ops())
        self.assertEqual(parent_ops_num, 4)
        self.assertEqual(tanh_op.num_results(), 1)
        self.assertEqual(len(matmul_op.get_input_names()), 2)
        self.assertEqual(len(matmul_op.get_attr_names()), 2)
        self.assertEqual(len(matmul_op.get_output_names()), 1)

    def test_value(self):
        newir_program = get_ir_program()
        ops = newir_program.block().get_ops()
        matmul_op = ops[1]
        add_op = ops[2]
        tanh_op = ops[3]
        self.assertEqual(
            matmul_op.results()[0].get_defining_op().name(), "pd.matmul"
        )
        self.assertEqual(
            matmul_op.result(0).get_defining_op().name(), "pd.matmul"
        )
        matmul_op.result(0).set_stop_gradient(True)
        self.assertEqual(matmul_op.result(0).get_stop_gradient(), True)

        self.assertEqual(
            tanh_op.operands()[0].source().get_defining_op().name(), "pd.add"
        )

        add_op.replace_all_uses_with(matmul_op.results())
        self.assertEqual(
            tanh_op.operands()[0].source().get_defining_op().name(), "pd.matmul"
        )
        self.assertEqual(add_op.result(0).use_empty(), True)

    def test_type(self):
        newir_program = get_ir_program()
        ops = newir_program.block().get_ops()
        matmul_op = ops[1]
        add_op = ops[2]
        matmul_op.result(0).type().print()
        self.assertTrue(matmul_op.result(0).type() == add_op.result(0).type())

    def test_utils(self):
        newir_program = get_ir_program()
        matmul_op = newir_program.block().get_ops()[1]
        ir.get_op_result_dtype(matmul_op.result(0)).print()
        self.assertEqual(ir.get_op_result_shape(matmul_op.result(0)), [4, 4])


if __name__ == "__main__":