This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit: fix codestyle
haozech committed Sep 27, 2021
1 parent fd253d8 commit b9ad508
Showing 10 changed files with 182 additions and 117 deletions.
2 changes: 1 addition & 1 deletion cinn/frontend/paddle_model_to_program.cc
@@ -144,7 +144,7 @@ void PaddleModelToProgram::AddOpMapper_reshape2() {
   auto x = GetVar(utils::TransValidVarName(x_name));
   std::vector<int> shape = op_desc.GetAttr<std::vector<int>>("shape");
   VLOG(4) << "x shape: " << utils::Join(x->shape, ",");
-  auto out = program_->reshape2(x, shape);
+  auto out = program_->reshape(x, shape);
   CHECK_EQ(op_desc.Output("Out").size(), 1UL);
   auto out_name = op_desc.Output("Out").front();
   AddVar(utils::TransValidVarName(out_name), out);
4 changes: 2 additions & 2 deletions cinn/frontend/syntax.cc
@@ -241,8 +241,8 @@ Variable Program::matmul(const Variable& a, const Variable& b, bool trans_a, boo
   return instr.GetOutput(0);
 }
 
-Variable Program::reshape2(const Variable& a, const std::vector<int>& shape) {
-  Instruction instr("reshape2", {a});
+Variable Program::reshape(const Variable& a, const std::vector<int>& shape) {
+  Instruction instr("reshape", {a});
   instr.SetAttr("shape", shape);
   AppendInstruction(instr);
   return instr.GetOutput(0);
9 changes: 8 additions & 1 deletion cinn/frontend/syntax.h
@@ -180,11 +180,18 @@ struct Program {
 
   /**
    * Reshape a tensor.
+   * @param a The input tensor.
+   * @param shape The output tensor's shape we specified.
+   * @return The reshaped output tensor.
    */
-  Variable reshape2(const Variable& a, const std::vector<int>& shape);
+  Variable reshape(const Variable& a, const std::vector<int>& shape);
 
   /**
    * Concat 2 tensors.
+   * @param a The first input tensor.
+   * @param b The second input tensor.
+   * @param axis The axis specified to do the concat operation.
+   * @return The concatenated output tensor.
    */
   Variable concat(const Variable& a, const Variable& b, int axis = 0);
 
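For orientation, here is a minimal usage sketch of the renamed frontend API. This is hypothetical illustration code, not part of the commit; it assumes a `Program` and input `Variable`s constructed as elsewhere in cinn/frontend/syntax.h, and the `{4, -1}` target shape is invented for the example.

```cpp
// Hypothetical usage of the renamed ops (sketch only, not committed code).
#include <vector>
#include "cinn/frontend/syntax.h"

using cinn::frontend::Program;
using cinn::frontend::Variable;

// Builds reshape + concat on an existing program; `x` and `y` are assumed
// to be variables already registered with `program`.
Variable BuildReshapeConcat(Program* program, const Variable& x, const Variable& y) {
  // reshape (formerly reshape2): the total element count must stay unchanged;
  // a single -1 entry is inferred from the remaining size.
  Variable x2 = program->reshape(x, {4, -1});
  // concat joins exactly two tensors along `axis` (default 0).
  return program->concat(x2, y, /*axis=*/0);
}
```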
3 changes: 2 additions & 1 deletion cinn/hlir/op/nn.cc
@@ -1264,7 +1264,8 @@ std::vector<std::vector<int>> InferShapeForPool2d(const std::vector<std::vector<
 
   if (adaptive) {
     kernel_size = std::get<std::vector<int>>(attr_store["kernel_size"]);
-    if (kernel_size.size() == 1) kernel_size.push_back(kernel_size[0]);
+    if (kernel_size.size() == 1UL) kernel_size.push_back(kernel_size[0]);
+    CHECK(kernel_size.size() >= 2UL) << "In pool2d, kernel_size's size should be >= 2, please check!";
     output_shape1[height_axis] = kernel_size[0];
     output_shape1[width_axis] = kernel_size[1];
   }
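As a standalone illustration of the guard added above: a one-element kernel_size `{k}` is broadcast to `{k, k}`, and anything still shorter than two entries is rejected. The helper below is a hypothetical sketch, not CINN code.

```cpp
// Sketch of the pool2d kernel_size normalization (hypothetical helper).
#include <stdexcept>
#include <vector>

std::vector<int> NormalizePool2dKernelSize(std::vector<int> kernel_size) {
  if (kernel_size.size() == 1UL) kernel_size.push_back(kernel_size[0]);  // {k} -> {k, k}
  if (kernel_size.size() < 2UL) {
    throw std::invalid_argument("In pool2d, kernel_size's size should be >= 2");
  }
  return kernel_size;  // [0] is used for height, [1] for width
}
```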
64 changes: 39 additions & 25 deletions cinn/hlir/op/transform.cc
@@ -259,15 +259,15 @@ std::vector<Type> InferDtypeForMatMul(const std::vector<Type> &inputs_type,
   return res;
 }
 
-std::shared_ptr<OpStrategy> StrategyForReshape2(const framework::NodeAttr &attrs,
-                                                const std::vector<ir::Tensor> &inputs,
-                                                const std::vector<Type> &out_type,
-                                                const std::vector<std::vector<int>> &output_shapes,
-                                                const Target &target) {
-  framework::CINNCompute reshape2_compute([=](lang::Args args, lang::RetValue *ret) {
+std::shared_ptr<OpStrategy> StrategyForReshape(const framework::NodeAttr &attrs,
+                                               const std::vector<ir::Tensor> &inputs,
+                                               const std::vector<Type> &out_type,
+                                               const std::vector<std::vector<int>> &output_shapes,
+                                               const Target &target) {
+  framework::CINNCompute reshape_compute([=](lang::Args args, lang::RetValue *ret) {
     CHECK(!args.empty()) << "The input arguments of Matmul compute is empty! Please check.\n";
     CINNValuePack a = args[0];
-    CHECK_GE(a.size(), 1U) << "at least 1 input tensor for Reshape2 compute\n";
+    CHECK_GE(a.size(), 1U) << "at least 1 input tensor for Reshape compute\n";
     Expr A = a[0];
     CHECK(A.as_tensor());
     CHECK(!output_shapes.empty());
@@ -283,17 +283,17 @@ std::shared_ptr<OpStrategy> StrategyForReshape2(const framework::NodeAttr &attrs
     auto tensor_A = A.as_tensor_ref();
     auto stages = CreateStages({tensor_A});
     ir::Tensor out;
-    out = pe::Reshape2(tensor_A, output_shapes.back(), UniqName("Reshape2_output"));
+    out = pe::Reshape(tensor_A, output_shapes.back(), UniqName("Reshape_output"));
     std::vector<CINNValue> res;
     stages->InsertLazily(out);
     res.push_back(CINNValue(out));
-    CHECK(!out_type.empty()) << "Output type of Reshape2 is empty! Please check.\n";
+    CHECK(!out_type.empty()) << "Output type of Reshape is empty! Please check.\n";
     res.push_back(CINNValue(stages));
     *ret = CINNValuePack{res};
   });
 
-  framework::CINNSchedule reshape2_schedule([=](lang::Args args, lang::RetValue *ret) {
-    CHECK(!args.empty()) << "The input argument of reshape2 schedule is empty! Please check.\n";
+  framework::CINNSchedule reshape_schedule([=](lang::Args args, lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input argument of reshape schedule is empty! Please check.\n";
     CINNValuePack arg_pack = args[0];
     int arg_size = arg_pack.size();
     poly::StageMap stages = arg_pack.back();
@@ -308,33 +308,42 @@ std::shared_ptr<OpStrategy> StrategyForReshape2(const framework::NodeAttr &attrs
   });
 
   auto strategy = std::make_shared<framework::OpStrategy>();
-  strategy->AddImpl(reshape2_compute, reshape2_schedule, "strategy.reshape2.x86", 1);
+  strategy->AddImpl(reshape_compute, reshape_schedule, "strategy.reshape.x86", 1);
   return strategy;
 }
 
-std::vector<std::vector<int>> InferShapeForReshape2(const std::vector<std::vector<int>> &inputs_shape,
-                                                    framework::NodeAttr &attrs,
-                                                    const Target &target) {
+std::vector<std::vector<int>> InferShapeForReshape(const std::vector<std::vector<int>> &inputs_shape,
+                                                   framework::NodeAttr &attrs,
+                                                   const Target &target) {
   CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again.";
   std::vector<int> output_shape;
   for (auto &iter : attrs.attr_store) {
     if (iter.first == "shape") {
       output_shape = std::get<std::vector<int>>(iter.second);
+      break;
     }
   }
   int tensor_size = 1;
   for (auto i : inputs_shape[0]) tensor_size *= i;
-  CHECK(!output_shape.empty()) << "infer_shape for reshape2 turns out to be empty. Please check\n";
+  CHECK(!output_shape.empty()) << "infer_shape for reshape turns out to be empty. Please check\n";
   int flag_index = -1;
   for (int i = 0; i < output_shape.size(); i++) {
     if (output_shape[i] > 0) {
       CHECK_EQ(tensor_size % output_shape[i], 0)
-          << "Incompatible input shape and output shape in op reshape2: " << tensor_size << ", " << output_shape[i];
+          << "Incompatible input shape and output shape in op reshape: " << tensor_size << ", " << output_shape[i];
       tensor_size /= output_shape[i];
+    } else if (output_shape[i] == 0) {
+      CHECK_LT(i, inputs_shape[0].size())
+          << "In op reshape, when attribute shape[i] == 0, shape[i] = input_shape[i]. But now the size of input_shape "
+             "<= i, which is incompatible. Please check!";
+      output_shape[i] = inputs_shape[0][i];
+      CHECK_EQ(tensor_size % output_shape[i], 0)
+          << "Incompatible input shape and output shape in op reshape: " << tensor_size << ", " << output_shape[i];
+      tensor_size /= output_shape[i];
     } else if (output_shape[i] == -1 && flag_index == -1) {
       flag_index = i;
     } else if (output_shape[i] == -1) {
-      LOG(FATAL) << "More than one -1 in output_shape of op reshape2.";
+      LOG(FATAL) << "More than one -1 in output_shape of op reshape.";
     } else {
       LOG(FATAL) << "Unsupported output_shape " << output_shape[i];
     }
@@ -344,9 +353,9 @@ std::vector<std::vector<int>> InferShapeForReshape2(const std::vector<std::vecto
   return res;
 }
 
-std::vector<Type> InferDtypeForReshape2(const std::vector<Type> &inputs_type,
-                                        const framework::NodeAttr &attrs,
-                                        const Target &target) {
+std::vector<Type> InferDtypeForReshape(const std::vector<Type> &inputs_type,
+                                       const framework::NodeAttr &attrs,
+                                       const Target &target) {
   CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
   std::vector<Type> res{inputs_type[0]};
   return res;
@@ -416,10 +425,15 @@ std::vector<std::vector<int>> InferShapeForConcat(const std::vector<std::vector<
   for (auto &iter : attrs.attr_store) {
     if (iter.first == "axis") {
       axis = std::get<int>(iter.second);
+      break;
     }
   }
   if (axis < 0) axis += inputs_shape[0].size();
   std::vector<int> output_shape = inputs_shape[0];
+  CHECK_EQ(inputs_shape[0].size(), inputs_shape[1].size())
+      << "In Concat op, the 2 input tensors' shape should be the same, please check!";
+  CHECK(axis >= 0 && axis < inputs_shape[0].size())
+      << "In Concat op, the attribute `axis` should be >= 0 and < input shape's size, please check!";
   output_shape[axis] += inputs_shape[1][axis];
   std::vector<std::vector<int>> res{output_shape};
   return res;
@@ -888,13 +902,13 @@ CINN_REGISTER_HELPER(transform_ops) {
       .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
       .set_support_level(4);
 
-  CINN_REGISTER_OP(reshape2)
+  CINN_REGISTER_OP(reshape)
       .describe("This operator is used to reshape input tensor X.")
       .set_num_inputs(1)
      .set_num_outputs(1)
-      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForReshape2)
-      .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForReshape2))
-      .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForReshape2))
+      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForReshape)
+      .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForReshape))
+      .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForReshape))
       .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
       .set_support_level(4);
 
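To make the reshape shape-inference rules above concrete: a positive entry is kept as-is, a 0 copies the corresponding input dimension, and a single -1 absorbs whatever element count remains. The sketch below mirrors that logic in a self-contained form; it is illustrative only and does not reproduce InferShapeForReshape.

```cpp
// Minimal re-implementation of the reshape shape-inference rules, for
// illustration only: positive entries are kept, 0 copies the matching input
// dimension, and a single -1 is inferred from the remaining element count.
#include <cassert>
#include <iostream>
#include <vector>

std::vector<int> InferReshape(const std::vector<int>& in, std::vector<int> shape) {
  int total = 1;
  for (int d : in) total *= d;
  int minus_one = -1;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (shape[i] == 0) shape[i] = in.at(i);  // 0 => copy input dim i (at() bounds-checks like CHECK_LT)
    if (shape[i] == -1) {
      assert(minus_one == -1);  // at most one -1, as in the op's LOG(FATAL) branch
      minus_one = i;
      continue;
    }
    assert(total % shape[i] == 0);  // mirrors the divisibility CHECK_EQ
    total /= shape[i];
  }
  if (minus_one >= 0) shape[minus_one] = total;  // -1 => inferred remainder
  return shape;
}

int main() {
  // {2, 3, 4} reshaped with {0, -1} => {2, 12}: dim 0 copied, 24 / 2 = 12 fills the -1.
  for (int d : InferReshape({2, 3, 4}, {0, -1})) std::cout << d << " ";
  std::cout << "\n";
}
```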
11 changes: 10 additions & 1 deletion cinn/hlir/pe/transform.cc
@@ -78,12 +78,21 @@ std::vector<Tensor> Matmul(
   }
 }
 
-ir::Tensor Reshape2(const ir::Tensor& A, const std::vector<int>& new_shape, const std::string& name) {
+ir::Tensor Reshape(const ir::Tensor& A, const std::vector<int>& new_shape, const std::string& name) {
   std::vector<Expr> new_expr_shape;
   std::vector<Expr> A_expr_shape = A->shape;
+  int input_total_size = 1;
+  int output_total_size = 1;
+  for (auto& i : A_expr_shape) {
+    CHECK(i.is_constant()) << "Input tensor's shape should be constant value.";
+    input_total_size *= (int)(i.get_constant());
+  }
   for (auto& i : new_shape) {
+    output_total_size *= i;
     new_expr_shape.push_back(Expr(i));
   }
+  CHECK_EQ(input_total_size, output_total_size)
+      << "In op reshape, the input tensor and output tensor's total size should be equal, please check!";
   auto res = Compute(
       new_expr_shape,
       [=](const std::vector<Expr>& indice) {
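The Compute body is collapsed in this diff, but a reshape kernel of this kind typically flattens the output index into a row-major linear offset and unflattens it into input coordinates. Below is a plain-integer sketch of that remapping, an assumption about the general technique rather than CINN's actual ir::Expr code.

```cpp
// Sketch of the index remapping a row-major reshape performs (hypothetical,
// plain ints instead of CINN's symbolic ir::Expr arithmetic).
#include <vector>

std::vector<int> RemapIndex(const std::vector<int>& out_idx,
                            const std::vector<int>& out_shape,
                            const std::vector<int>& in_shape) {
  // Flatten: offset = ((i0 * s1 + i1) * s2 + i2) ...
  int offset = 0;
  for (size_t i = 0; i < out_shape.size(); ++i) offset = offset * out_shape[i] + out_idx[i];
  // Unflatten into the input shape, last dimension varying fastest.
  std::vector<int> in_idx(in_shape.size());
  for (int i = static_cast<int>(in_shape.size()) - 1; i >= 0; --i) {
    in_idx[i] = offset % in_shape[i];
    offset /= in_shape[i];
  }
  return in_idx;
}
```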
6 changes: 3 additions & 3 deletions cinn/hlir/pe/transform.h
@@ -30,9 +30,9 @@ std::vector<ir::Tensor> Matmul(const ir::Tensor& A,
                                float alpha = 1,
                                const std::string& name = UniqName("T_Transform_Matmul_out"));
 
-ir::Tensor Reshape2(const ir::Tensor& A,
-                    const std::vector<int>& new_shape,
-                    const std::string& name = UniqName("T_Transform_Matmul_out"));
+ir::Tensor Reshape(const ir::Tensor& A,
+                   const std::vector<int>& new_shape,
+                   const std::string& name = UniqName("T_Transform_Matmul_out"));
 
 ir::Tensor Concat(const ir::Tensor& A,
                   const ir::Tensor& B,
6 changes: 6 additions & 0 deletions python/CMakeLists.txt
@@ -45,6 +45,12 @@ ADD_TEST(NAME test_cinn_frontend
          "${WITH_CUDA}" WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
 )
 
+ADD_TEST(NAME test_cinn_ops_check
+  COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}
+  python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_ops.py "${WITH_CUDA}"
+  WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
+)
+
 ADD_TEST(NAME test_cinn_op_benchmark
   COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}
   python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_op_benchmark.py "${WITH_CUDA}"
83 changes: 0 additions & 83 deletions python/tests/test_frontend.py
@@ -93,86 +93,3 @@ def test_basic(self):
         result = result.numpy(self.target).reshape(-1)
         tensor_data.append(result)
         self.paddle_verify(tensor_data)
-
-
-class TestLoadPaddleModel_FC(unittest.TestCase):
-    def setUp(self):
-        if enable_gpu == "ON":
-            self.target = DefaultNVGPUTarget()
-        else:
-            self.target = DefaultHostTarget()
-
-        self.model_dir = naive_model_dir
-
-    def get_paddle_inference_result(self, model_dir, data):
-        config = fluid.core.AnalysisConfig(model_dir)
-        config.disable_gpu()
-        config.switch_ir_optim(False)
-        self.paddle_predictor = fluid.core.create_paddle_predictor(config)
-        data = fluid.core.PaddleTensor(data)
-        results = self.paddle_predictor.run([data])
-
-        return results[0].as_ndarray()
-
-    def test_model(self):
-        np.random.seed(0)
-        self.x_shape = [4, 30]
-        x_data = np.random.random(
-            self.x_shape).astype("float16").astype("float32")
-        print('x_data', x_data)
-
-        self.executor = Interpreter(["A"], [self.x_shape])
-        self.executor.load_paddle_model(self.model_dir, self.target, False)
-        a_t = self.executor.get_tensor("A")
-        a_t.from_numpy(x_data, self.target)
-
-        self.executor.run()
-
-        out = self.executor.get_tensor("save_infer_model/scale_0.tmp_0")
-        target_data = self.get_paddle_inference_result(self.model_dir, x_data)
-        print("target_data's shape is: ", target_data.shape)
-        out_np = out.numpy(self.target)
-        print("cinn data's shape is: ", out_np.shape)
-
-        self.assertTrue(np.allclose(out_np, target_data, atol=1e-4))
-
-
-class TestLoadPaddleModel_MultiFC(unittest.TestCase):
-    def setUp(self):
-        if enable_gpu == "ON":
-            self.target = DefaultNVGPUTarget()
-        else:
-            self.target = DefaultHostTarget()
-
-        self.model_dir = multi_fc_model_dir
-
-    def get_paddle_inference_result(self, model_dir, data):
-        config = fluid.core.AnalysisConfig(model_dir)
-        config.disable_gpu()
-        config.switch_ir_optim(False)
-        self.paddle_predictor = fluid.core.create_paddle_predictor(config)
-        data = fluid.core.PaddleTensor(data)
-        results = self.paddle_predictor.run([data])
-
-        return results[0].as_ndarray()
-
-    def test_model(self):
-        np.random.seed(0)
-        self.x_shape = [8, 64]
-        x_data = np.random.random(self.x_shape).astype("float32")
-
-        self.executor = Interpreter(["A"], [self.x_shape])
-        self.executor.load_paddle_model(self.model_dir, self.target, False)
-        a_t = self.executor.get_tensor("A")
-        a_t.from_numpy(x_data, self.target)
-
-        self.executor.run()
-
-        out = self.executor.get_tensor("save_infer_model/scale_0.tmp_0")
-        target = self.get_paddle_inference_result(self.model_dir, x_data)
-
-        self.assertTrue(np.allclose(out.numpy(self.target), target, atol=1e-4))
-
-
-if __name__ == "__main__":
-    unittest.main()
