From 4285230b80afecdf9042f64bfae1c25c7f945a30 Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Wed, 24 May 2023 03:40:02 +0000 Subject: [PATCH 1/3] Add Reference Type --- cinn/common/type.cc | 18 ++++++++++++++++++ cinn/common/type.h | 4 ++++ 2 files changed, 22 insertions(+) diff --git a/cinn/common/type.cc b/cinn/common/type.cc index c83911f559..709a072b73 100644 --- a/cinn/common/type.cc +++ b/cinn/common/type.cc @@ -111,6 +111,21 @@ Type &Type::set_cpp_handle2(bool x) { return *this; } +Type &Type::set_cpp_reference(bool x) { + auto &v = (*reinterpret_cast(&GetStorage().cpp_type_)); + + // unset the other handle-related bits. + v &= ~static_cast(cpp_type_t::Handle); + v &= ~static_cast(cpp_type_t::HandleHandle); + + if (x) + v |= static_cast(cpp_type_t::Reference); + else + v &= ~static_cast(cpp_type_t::Reference); + + return *this; +} + Type Type::VectorOf(int w) const { CheckTypeValid(); return Type(type(), bits(), w, specific_type()); @@ -263,6 +278,9 @@ bool Type::is_cpp_handle() const { bool Type::is_cpp_handle2() const { return static_cast(GetStorage().cpp_type_) & static_cast(cpp_type_t::HandleHandle); } +bool Type::is_cpp_reference() const { + return static_cast(GetStorage().cpp_type_) & static_cast(cpp_type_t::Reference); +} bool Type::is_cpp_const() const { return static_cast(cpp_type_t::Const) & static_cast(GetStorage().cpp_type_); } diff --git a/cinn/common/type.h b/cinn/common/type.h index 490b80c64f..9829728ab6 100644 --- a/cinn/common/type.h +++ b/cinn/common/type.h @@ -65,6 +65,7 @@ struct Type { Const = 1, // const. Handle = 1 << 1, // pointer type, such as `cinn_buffer_t*`. HandleHandle = 1 << 2, // pointer of pointer, such as `cinn_buffer_t**`. + Reference = 1 << 4, // reference type, such as `cinn_buffer_t&`. }; Type(); @@ -100,6 +101,9 @@ struct Type { Type& set_cpp_handle2(bool x = true); CINN_NODISCARD bool is_cpp_handle2() const; + Type& set_cpp_reference(bool x = true); + CINN_NODISCARD bool is_cpp_reference() const; + Type& set_cpp_const(bool is_const = true); CINN_NODISCARD bool is_cpp_const() const; From ed3cb518ee339919f30eda1101ce36b1e0b5d192 Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Wed, 24 May 2023 06:46:46 +0000 Subject: [PATCH 2/3] Add GetReference IR and its transformers --- .../auto_schedule/cost_model/feature_extractor.cc | 1 + cinn/backends/codegen_c.cc | 3 +++ cinn/backends/llvm/codegen_llvm.cc | 5 +++++ cinn/ir/ir.cc | 9 +++++++++ cinn/ir/ir.h | 15 +++++++++++++++ cinn/ir/ir_base.h | 1 + cinn/ir/ir_printer.cc | 4 ++++ 7 files changed, 38 insertions(+) mode change 100755 => 100644 cinn/ir/ir.cc diff --git a/cinn/auto_schedule/cost_model/feature_extractor.cc b/cinn/auto_schedule/cost_model/feature_extractor.cc index 5f44b2e3f0..13acb06e1c 100644 --- a/cinn/auto_schedule/cost_model/feature_extractor.cc +++ b/cinn/auto_schedule/cost_model/feature_extractor.cc @@ -150,6 +150,7 @@ VisitForMultiOperandsDtypePattern(Product, mul); VisitCountMemberPattern(And, bool_op); VisitCountMemberPattern(Or, bool_op); VisitCountMemberPattern(Not, bool_op); +VisitCountMemberPattern(GetReference, mem_read); VisitCountMemberPattern(Max, select_op); VisitCountMemberPattern(Min, select_op); VisitCountMemberPattern(IfThenElse, select_op); diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc index a5a26ecea0..680531c221 100644 --- a/cinn/backends/codegen_c.cc +++ b/cinn/backends/codegen_c.cc @@ -143,7 +143,10 @@ std::string CodeGenC::GetTypeRepr(Type type) { str += "*"; } else if (type.is_cpp_handle2()) { str += "**"; + } else if (type.is_cpp_reference()) { + str += "&"; } + return str; } void CodeGenC::Visit(const ir::IntImm *op) { IrPrinter::Visit(op); } diff --git a/cinn/backends/llvm/codegen_llvm.cc b/cinn/backends/llvm/codegen_llvm.cc index 318f1d02b8..d3cf717039 100644 --- a/cinn/backends/llvm/codegen_llvm.cc +++ b/cinn/backends/llvm/codegen_llvm.cc @@ -344,6 +344,11 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Minus *op) { return (op->type().is_int() || op->type().is_uint()) ? Neg(v) : FNeg(v); } +llvm::Value *CodeGenLLVM::Visit(const ir::GetReference *op) { + LOG(FATAL) << "TODO: Unimplementd CodeGenLLVM::Visit(const ir::GetReference *op)"; + return nullptr; +} + llvm::Value *CodeGenLLVM::Visit(const ir::Not *op) { return Not(Visit(&op->v())); } llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) { diff --git a/cinn/ir/ir.cc b/cinn/ir/ir.cc old mode 100755 new mode 100644 index 7db39fa704..94fa8a9ac8 --- a/cinn/ir/ir.cc +++ b/cinn/ir/ir.cc @@ -185,6 +185,15 @@ void Not::Verify() const { CHECK_EQ(v().type(), type_of()); } Type Not::type() const { return type_; } +Expr GetReference::Make(Expr v) { + auto node = make_shared(v); + return Expr(node); +} + +void GetReference::Verify() const { CHECK(v().defined()); } + +Type GetReference::type() const { return type_; } + Expr Let::Make(Expr symbol, Expr body) { auto *n = make_shared(); CHECK(symbol.type().valid()); diff --git a/cinn/ir/ir.h b/cinn/ir/ir.h index 894ab40e4c..4eb0b153c0 100644 --- a/cinn/ir/ir.h +++ b/cinn/ir/ir.h @@ -295,6 +295,21 @@ struct Not : public UnaryOpNode { static const IrNodeTy _node_type_ = IrNodeTy::Not; }; +/** + * Get reference, such as C++ &x + * + */ +struct GetReference : public UnaryOpNode { + explicit GetReference(Expr v) : UnaryOpNode(common::Int(32).set_cpp_reference(), v) {} + + static Expr Make(Expr v); + + Type type() const override; + void Verify() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::GetReference; +}; + struct Let : public ExprNode { Expr symbol; Expr body; diff --git a/cinn/ir/ir_base.h b/cinn/ir/ir_base.h index b1baf1d59f..d76b2d1bae 100644 --- a/cinn/ir/ir_base.h +++ b/cinn/ir/ir_base.h @@ -79,6 +79,7 @@ class ScheduleBlockRealize; #define NODETY_UNARY_OP_FOR_EACH(macro__) \ macro__(Minus) \ macro__(Not) \ + macro__(GetReference) \ #define NODETY_OP_FOR_EACH(macro__) NODETY_BINARY_OP_FOR_EACH(macro__) NODETY_UNARY_OP_FOR_EACH(macro__) diff --git a/cinn/ir/ir_printer.cc b/cinn/ir/ir_printer.cc index 66604da970..65e4b8a978 100644 --- a/cinn/ir/ir_printer.cc +++ b/cinn/ir/ir_printer.cc @@ -140,6 +140,10 @@ void IrPrinter::Visit(const Minus *x) { Print(x->v()); os_ << ")"; } +void IrPrinter::Visit(const GetReference *x) { + os_ << "&"; + Print(x->v()); +} void IrPrinter::Visit(const For *x) { if (x->is_parallel()) { os() << "parallel for ("; From f5f5fcc9bbeaf8c5250d4c715556b4c551cecf57 Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Thu, 25 May 2023 07:53:46 +0000 Subject: [PATCH 3/3] Add fusion conditions --- cinn/backends/codegen_c.cc | 1 + cinn/backends/codegen_cuda_dev.cc | 6 +++ cinn/hlir/framework/op_lowering.cc | 7 +++ cinn/hlir/framework/op_lowering_util.cc | 63 +++++++++++++++++++++++++ cinn/hlir/framework/op_lowering_util.h | 14 ++++++ cinn/hlir/pass/fusion_merge_pass_util.h | 19 ++++++++ cinn/hlir/pass/op_fusion_pass_util.h | 20 ++++++++ cinn/ir/ir_schedule.cc | 62 ++++++++++++++++++++++++ cinn/ir/ir_schedule.h | 8 ++++ cinn/ir/lowered_func.cc | 12 +++++ cinn/ir/lowered_func.h | 2 + cinn/ir/schedule_desc.cc | 4 ++ 12 files changed, 218 insertions(+) diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc index 680531c221..c3c0b7cec1 100644 --- a/cinn/backends/codegen_c.cc +++ b/cinn/backends/codegen_c.cc @@ -189,6 +189,7 @@ void CodeGenC::Visit(const ir::Not *op) { IrPrinter::Print(op->v()); os() << ")"; } +void CodeGenC::Visit(const ir::GetReference *op) { IrPrinter::Visit(op); } void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); } void CodeGenC::Visit(const ir::For *op) { Expr extent = op->extent; diff --git a/cinn/backends/codegen_cuda_dev.cc b/cinn/backends/codegen_cuda_dev.cc index 21fc8961fa..6f4b13f223 100644 --- a/cinn/backends/codegen_cuda_dev.cc +++ b/cinn/backends/codegen_cuda_dev.cc @@ -112,6 +112,12 @@ std::vector CodeGenCUDA_Dev::GenerateBufferAliasExprs(const ir::_LoweredFu } void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) { + std::set device_count_exprs = op->PrepareDeviceCountExprs(); + for (auto dce : device_count_exprs) { + VLOG(6) << "PrepareDeviceCountExprs " << dce; + os() << "__device__ int " << dce.As()->name << " = 0;\n"; + } + // clear names valid within scope when enter a new function vectorized_tensor_names_.clear(); os() << "__global__\n"; diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index 069a1b2e59..dd3cd1ba3b 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -109,6 +109,7 @@ std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, group, /*apply_impl_schedule = */ true); } else { for (auto& sub_group : group->fused_sub_groups) { + VLOG(4) << "sub_group->group_id = " << sub_group->group_id; auto exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, sub_group, /*apply_impl_schedule = */ true); ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end()); } @@ -1277,8 +1278,14 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, } VLOG(3) << "Before loop fusion, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); VLOG(4) << " FUSION " << node->op()->name; + Node* fusion_master = master ? master : nodes_in_order.front(); + + // if (CanFuseReduceByBlockSync(ir_sch, node, fusion_master, group, this->shape_dict_, tensor_map)) { + // SyncGpuBlocks(ir_sch, node, fusion_master, group, this->shape_dict_, tensor_map); + // } else { // do loop fuse. LoopComputeAt(ir_sch, node, master ? master : nodes_in_order.front(), group, this->shape_dict_, tensor_map); + //} VLOG(3) << "After loop fusion, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); } diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index ba993a219f..18d18a5d67 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -15,6 +15,7 @@ #include "cinn/hlir/framework/op_lowering_util.h" #include "cinn/hlir/pe/nn_util.h" +#include "cinn/utils/string.h" #ifdef CINN_WITH_CUDA #include "cinn/common/bfloat16.h" #include "cinn/common/float16.h" @@ -1429,6 +1430,68 @@ void LoopComputeAt(ir::IRSchedule& ir_sch, } while (--index >= 0); } +bool CanFuseReduceByBlockSync(ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const GroupPtr& group, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + if (op_pattern_dict[node->op()] == framework::kReduction && op_pattern_dict[master->op()] == framework::kReduction && + node != master) { + auto node_shape = shape_dict.at(node->inlinks_in_order()[0]->source()->id()); + auto master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); + + VLOG(6) << "Checking CanFuseReduceByBlockSync"; + VLOG(6) << "node->id() = " << node->id() << ", node_shape.size() = " << node_shape.size(); + for (auto x : node_shape) { + VLOG(6) << x; + } + VLOG(6) << "master->id() = " << master->id() << ", master_shape.size() = " << master_shape.size(); + for (auto x : master_shape) { + VLOG(6) << x; + } + + static std::unordered_set reduce_op_type = { + "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"}; + for (const std::string& op_type : reduce_op_type) { + // TODO: this may speed up not only reduce_xxx_split nodes, but we limit it to reduce_xxx_split nodes for accuracy + // safety + if (cinn::utils::Startswith(master->id(), op_type + "_split") && + cinn::utils::Startswith(node->id(), op_type + "_split")) { + // Returns true only when shape is not equal. Shape equal is handled by other fusion + if (node_shape.size() != master_shape.size()) { + return true; + } + for (int i = 0; i < node_shape.size(); ++i) { + if (node_shape[i] != master_shape[i]) { + return true; + } + } + } + } + } + return false; +} + +void SyncGpuBlocks(ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const GroupPtr& group, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map) { + VLOG(6) << "Calling SyncGpuBlocks"; + if (!group->output_nodes.count(node)) { + auto block = ir_sch.GetBlock(GetNodeData(node)->id()); + ir_sch.SetBuffer(block, "local", true); + } + auto node_data = GetNodeData(node); + auto master_data = GetNodeData(master); + auto node_block = ir_sch.GetBlock(node->id()); + auto master_block = ir_sch.GetBlock(master_data->id()); + ir_sch.SyncGpuBlocks(master_block, node_block); +} + std::unordered_map GetNodeDataSet(const std::unordered_set& nodes_set) { std::unordered_map node_data_set; for (auto node : nodes_set) { diff --git a/cinn/hlir/framework/op_lowering_util.h b/cinn/hlir/framework/op_lowering_util.h index f081411ec0..c7404176c2 100644 --- a/cinn/hlir/framework/op_lowering_util.h +++ b/cinn/hlir/framework/op_lowering_util.h @@ -90,6 +90,20 @@ void LoopComputeAt(ir::IRSchedule& ir_sch, const absl::flat_hash_map& shape_dict, const std::unordered_map& tensor_map); +bool CanFuseReduceByBlockSync(ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const GroupPtr& group, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map); + +void SyncGpuBlocks(ir::IRSchedule& ir_sch, + Node* node, + const Node* master, + const GroupPtr& group, + const absl::flat_hash_map& shape_dict, + const std::unordered_map& tensor_map); + void SyncThreadWithShared(ir::IRSchedule& ir_sch, const GroupPtr& group, const std::unordered_set& nodes_inline, diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index 82bbabd20f..bf3822c3ca 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -285,6 +285,20 @@ CONDITION_FUNC(injective_horizontal_with_reduce) { return elementwise_fuse_reduce(helper, first, second); } +inline bool ReduceSplitCanFuse(const Node* producer, const Node* reducer) { + static std::unordered_set reduce_op_type = { + "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"}; + VLOG(6) << "Checking ReduceSplitCanFuse"; + VLOG(6) << "producer->id() = " << producer->id(); + VLOG(6) << "reducer->id() = " << reducer->id(); + for (const std::string& op_type : reduce_op_type) { + if (utils::Startswith(producer->id(), op_type + "_split") && utils::Startswith(reducer->id(), op_type + "_split")) { + return true; + } + } + return false; +} + CONDITION_FUNC(reduce_fuse_broadcast) { // if same shape with horizontal relation if (is_same_size(helper, first, second)) { @@ -388,6 +402,7 @@ CONDITION_FUNC(reduce_fuse_broadcast) { } CONDITION_FUNC(reduce_fuse_reduce) { + VLOG(6) << "In reduce_fuse_reduce"; if (!limit_args(helper, first, second)) { return false; } @@ -409,6 +424,10 @@ CONDITION_FUNC(reduce_fuse_reduce) { } CHECK(reducer_1) << "Can't find reduce op in group " << second->group_id; + if (ReduceSplitCanFuse(reducer_0, reducer_1)) { + return true; + } + // check reduce has same input shape and output shape auto reducer_0_input_shape = helper->shape_dict_.at(reducer_0->inlinks_in_order()[0]->source()->id()); auto reducer_0_output_shape = helper->shape_dict_.at(reducer_0->outlinks_in_order()[0]->sink()->id()); diff --git a/cinn/hlir/pass/op_fusion_pass_util.h b/cinn/hlir/pass/op_fusion_pass_util.h index 778f7bdbf1..9f59daab7e 100644 --- a/cinn/hlir/pass/op_fusion_pass_util.h +++ b/cinn/hlir/pass/op_fusion_pass_util.h @@ -51,7 +51,22 @@ CONDITION_FUNC(without_last_dimension_in_reduce) { return helper->WithoutLastDimInReduce(in_shape, reduce_axes); } +inline bool ReduceSplitCanFuse(const Node* producer, const Node* reducer) { + static std::unordered_set reduce_op_type = { + "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"}; + VLOG(6) << "Checking ReduceSplitCanFuse"; + VLOG(6) << "producer->id() = " << producer->id(); + VLOG(6) << "reducer->id() = " << reducer->id(); + for (const std::string& op_type : reduce_op_type) { + if (utils::Startswith(producer->id(), op_type + "_split") && utils::Startswith(reducer->id(), op_type + "_split")) { + return true; + } + } + return false; +} + CONDITION_FUNC(reduce_fuse_reduce) { + VLOG(6) << "In reduce_fuse_reduce"; Node* reducer = NULL; for (auto* master : consumer->master_nodes) { if (helper->GetOpKind(master) == framework::kReduction) { @@ -59,6 +74,11 @@ CONDITION_FUNC(reduce_fuse_reduce) { break; } } + + if (ReduceSplitCanFuse(producer, reducer)) { + return true; + } + // check reduce has same input shape and output shape auto producer_input_shape = helper->shape_dict_.at(producer->inlinks_in_order()[0]->source()->id()); auto producer_output_shape = helper->shape_dict_.at(producer->outlinks_in_order()[0]->sink()->id()); diff --git a/cinn/ir/ir_schedule.cc b/cinn/ir/ir_schedule.cc index 3b5f8a6671..8f356551f5 100644 --- a/cinn/ir/ir_schedule.cc +++ b/cinn/ir/ir_schedule.cc @@ -86,6 +86,7 @@ class ScheduleImpl { Expr CacheRead(const Expr& block, int read_buffer_index, const std::string& memory_type); Expr CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type); void SyncThreads(const Expr& ir_node, bool after_node = true); + void SyncGpuBlocks(const Expr& master_block, const Expr& sequential_block); void SetBuffer(Expr& block, const std::string& memory_type, bool fixed = false); Expr Reorder(const std::vector& loops); Expr Reorder(const std::string& block_name, const std::vector& loops_index); @@ -854,6 +855,56 @@ void ScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) { return; } +void ScheduleImpl::SyncGpuBlocks(const Expr& master_block, const Expr& sequential_block) { + VLOG(6) << "Call ScheduleImpl::SyncGpuBlocks"; + CHECK(master_block.As() || master_block.As()); + CHECK(sequential_block.As() || sequential_block.As()); + Expr master_root = GetRootBlock(master_block); + ChangeBodyToBlock::Change(&master_root); + + Expr sync_threads = runtime::IntrinsicCall(Void(), "__syncthreads", {}); + + Var block_count_var(common::UniqName("block_count")); + Expr atomic_add = + runtime::IntrinsicCall(common::I32(), "atomicAdd", {block_count_var, common::make_const(Int(32), 1)}); + + Expr only_first_thread_add = ir::For::Make(Var(common::UniqName("sync_block_thread_x")), + ir::Expr(0), + ir::Expr(1), + ir::ForType::GPUThread, + ir::DeviceAPI::UNK, + atomic_add); + + int block_number = 96; // = TODO + Expr atomic_max = + runtime::IntrinsicCall(common::I32(), "atomicAdd", {block_count_var, common::make_const(Int(32), -1)}); + Expr loop_waiting = ir::PolyFor::Make(Var(common::UniqName("useless_tmp")), + ir::Expr(0), + ir::LE::Make(atomic_max, ir::Expr(block_number)), + ir::Expr(0), + ir::ForType::Serial, + ir::DeviceAPI::UNK, + ir::Block::Make({})); + + Expr loop_gpu_block = ir::For::Make(Var(common::UniqName("only_first_block_run")), + ir::Expr(0), + ir::Expr(1), + ir::ForType::GPUBlock, + ir::DeviceAPI::UNK, + ir::Block::Make({loop_waiting, sequential_block})); + + Expr sync_statements = ir::Block::Make(std::vector{sync_threads, only_first_thread_add, loop_gpu_block}); + + VLOG(6) << "Debug sync_statements = "; + VLOG(6) << sync_statements; + InsertExpr::Insert(master_block, sync_statements, /* after_node = */ true, &master_root); + + Expr source_expr{nullptr}; + Expr target_expr{nullptr}; + LeafBlockRemovalPlan remove_plan(sequential_block, &source_expr, &target_expr); + remove_plan(&master_root); +} + /** * Replace a For node to another For node. * @param src_sref The For node to be changed. @@ -959,6 +1010,7 @@ Expr ScheduleImpl::GetRootBlock(const Expr& expr) const { } } LOG(FATAL) << "Didn't find expr \n" << expr << "in ScheduleImpl:\n" << exprs[0]; + return expr; } // The struct used to reconstruct the new For node to replace the old For node. @@ -1629,6 +1681,7 @@ Expr ScheduleImpl::GetBlock(const std::string& block_name) const { } } LOG(FATAL) << "Didn't find a block with name " << block_name << " in this ModuleExpr!"; + return result; } void ScheduleImpl::Annotate(const Expr& block, const std::string& key, const attr_t& value) { @@ -2145,6 +2198,15 @@ void IRSchedule::SyncThreads(const Expr& ir_node, bool after_node) { ScheduleDesc::Step("SyncThreads", {{"ir_node", std::vector({ir_node})}}, {{"after_node", after_node}}, {})); } +void IRSchedule::SyncGpuBlocks(const Expr& master_block, const Expr& sequential_block) { + impl_->SyncGpuBlocks(master_block, sequential_block); + trace_.Append(ScheduleDesc::Step("SyncGpuBlocks", + {{"master_block", std::vector({master_block})}, + {"sequential_block", std::vector({sequential_block})}}, + {}, + {})); +} + void IRSchedule::SetBuffer(Expr& block, const std::string& memory_type, bool fixed) { impl_->SetBuffer(block, memory_type, fixed); trace_.Append(ScheduleDesc::Step( diff --git a/cinn/ir/ir_schedule.h b/cinn/ir/ir_schedule.h index 6b7b252a57..c4e7136307 100644 --- a/cinn/ir/ir_schedule.h +++ b/cinn/ir/ir_schedule.h @@ -221,6 +221,14 @@ class IRSchedule { */ void SyncThreads(const Expr& ir_node, bool after_node = true); + /** + * \brief Add GPU Block sync statements in AST of a master block, then change sequential block loop to match the + * master block if needed + * @param master_block Block that add GPU Block sync statements + * @param sequential_block Block that may change loop info + */ + void SyncGpuBlocks(const Expr& master_block, const Expr& sequential_block); + /*! * \brief Set a tensor's buffer type(memory_type) * \param block The ScheduleBlockRealize corresponding to an unique tensor. diff --git a/cinn/ir/lowered_func.cc b/cinn/ir/lowered_func.cc index 36b0dcf601..7fb45bf831 100644 --- a/cinn/ir/lowered_func.cc +++ b/cinn/ir/lowered_func.cc @@ -79,6 +79,18 @@ void _LoweredFunc_::CheckValid() const { std::vector _LoweredFunc_::expr_fields() { return {&body}; } std::vector _LoweredFunc_::expr_fields() const { return {&body}; } +std::set _LoweredFunc_::PrepareDeviceCountExprs() const { + std::set device_count_vars = ir::CollectIRNodes(body, [](const Expr* expr) { + const ir::_Var_* var = expr->As(); + if (var != nullptr) { + return utils::Startswith(var->name, "device_count"); + } + return false; + }); + + return device_count_vars; +} + void _LoweredFunc_::PrepareCudaAxisInfoFromBody() { std::set bound_for_exprs = ir::CollectIRNodes(body, [](const Expr* expr) { const ir::For* for_expr = expr->As(); diff --git a/cinn/ir/lowered_func.h b/cinn/ir/lowered_func.h index f237232b1c..9fad668e57 100755 --- a/cinn/ir/lowered_func.h +++ b/cinn/ir/lowered_func.h @@ -14,6 +14,7 @@ #pragma once #include +#include #include #include @@ -170,6 +171,7 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> { static const IrNodeTy _node_type_ = IrNodeTy::_LoweredFunc_; + std::set PrepareDeviceCountExprs() const; std::vector PrepareCreateTempBufferExprs() const; //! Prepare the expressions for `alloc_tmp_buffer_exprs`. std::vector PrepareAllocTempBufferExprs() const; diff --git a/cinn/ir/schedule_desc.cc b/cinn/ir/schedule_desc.cc index cb50cc2ab9..92ca35138f 100644 --- a/cinn/ir/schedule_desc.cc +++ b/cinn/ir/schedule_desc.cc @@ -368,6 +368,10 @@ CINN_BUILD_STEP_KIND(SyncThreads) .Attrs({"after_node"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SyncThreads))); +CINN_BUILD_STEP_KIND(SyncGpuBlocks) + .Inputs({"master_block", "sequential_block"}) + .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SyncGpuBlocks))); + CINN_BUILD_STEP_KIND(SetBuffer) .Inputs({"block"}) .Attrs({"memory_type", "fixed"})