Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
split error
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzSean committed Jun 26, 2023
1 parent 9b9d7f3 commit 4a19ff4
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 23 deletions.
43 changes: 37 additions & 6 deletions cinn/ir/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ namespace ir {
class ScheduleImpl {
public:
ScheduleImpl() = default;
explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false)
: module_expr_(module_expr), debug_flag_(debug_flag) {}
explicit ScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level = ScheduleErrorMessageLevel::kBlank)
: module_expr_(module_expr), debug_flag_(debug_flag), err_msg_level_(err_msg_level) {}
explicit ScheduleImpl(ModuleExpr&& module_expr) : module_expr_(std::move(module_expr)) {}

//! Set the debug flag.
Expand Down Expand Up @@ -114,8 +116,32 @@ class ScheduleImpl {

ModuleExpr module_expr_;
bool debug_flag_{false};
ScheduleErrorMessageLevel err_msg_level_;
};

/** \brief A macro that guards the beginning of each implementation of schedule */
#define CINN_IR_SCHEDULE_BEGIN() try {
/**
* \brief A macro that pairs with `CINN_IR_SCHEDULE_BEGIN`, handling potential errors and error
* message printing
* \param primitive A string representing the kind of schedule primitive
* \param err_msg_level A ScheduleErrorMessageLevel enum, level of error message printing
*/
#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \
} \
catch (const IRScheduleErrorHandler& err_hanlder) { \
switch (err_msg_level) { \
case ScheduleErrorMessageLevel::kDetailed: \
throw std::runtime_error(err_hanlder.FormatErrorMessage(primitive)); \
case ScheduleErrorMessageLevel::kGenearl: \
throw std::runtime_error(err_hanlder.GeneralErrorMessage()); \
case ScheduleErrorMessageLevel::kBlank: \
throw std::runtime_error("IRScheduleError occurred! (No more error message)"); \
default: \
throw std::runtime_error("IRScheduleError occurred! (No more error message)"); \
} \
}

std::vector<Expr> ScheduleImpl::Split(const Expr& loop, const std::vector<int>& factors) {
CHECK(loop.As<ir::For>()) << "Expr param of Split must be For node! Please check.";
auto* for_node = loop.As<ir::For>();
Expand All @@ -126,8 +152,10 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop, const std::vector<int>&
VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " << tot_extent << ") to ("
<< cinn::utils::Join(factors, ", ") << ") at loop:\n"
<< loop;

auto processed_factors = ValidateFactors(factors, tot_extent);
std::vector<int> processed_factors;
CINN_IR_SCHEDULE_BEGIN();
processed_factors = ValidateFactors(factors, tot_extent);
CINN_IR_SCHEDULE_END("split", this->err_msg_level_);
int prod_size = std::accumulate(processed_factors.begin(), processed_factors.end(), 1, std::multiplies<int>());
std::vector<Var> new_loop_vars;
Expr substitute_value(0);
Expand Down Expand Up @@ -1971,8 +1999,11 @@ Expr ScheduleImpl::SampleCategorical(utils::LinearRandomEngine::StateType* rand_

IRSchedule::IRSchedule() {}

IRSchedule::IRSchedule(const ModuleExpr& module_expr, utils::LinearRandomEngine::StateType rand_seed, bool debug_flag) {
impl_ = std::make_unique<ScheduleImpl>(module_expr, debug_flag);
IRSchedule::IRSchedule(const ModuleExpr& module_expr,
utils::LinearRandomEngine::StateType rand_seed,
bool debug_flag,
ScheduleErrorMessageLevel err_msg_level) {
impl_ = std::make_unique<ScheduleImpl>(module_expr, debug_flag, err_msg_level);
this->InitSeed(rand_seed);
}

Expand Down
4 changes: 3 additions & 1 deletion cinn/ir/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/ir_mutator.h"
#include "cinn/ir/ir_schedule_error.h"
#include "cinn/ir/schedule_desc.h"
#include "cinn/ir/tensor.h"
#include "cinn/utils/random_engine.h"
Expand Down Expand Up @@ -67,7 +68,8 @@ class IRSchedule {
IRSchedule();
explicit IRSchedule(const ModuleExpr& modexpr,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool debug_flag = false);
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level = ScheduleErrorMessageLevel::kBlank);
IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed = -1);
IRSchedule(const IRSchedule& other);
IRSchedule& operator=(const IRSchedule& src);
Expand Down
4 changes: 2 additions & 2 deletions cinn/ir/ir_schedule_error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
namespace cinn {
namespace ir {

std::string IRScheduleErrorHandler::FormatErrorMessage(const std::string &primitive) {
std::string IRScheduleErrorHandler::FormatErrorMessage(const std::string &primitive) const {
std::ostringstream os;
std::string err_msg = DetailedErrorMessage();

os << "[IRScheduleError] An error occurred in the schedue primitive <" << primitive << ">. " << std::endl;
os << "[IRScheduleError] An error occurred in the scheduel primitive <" << primitive << ">. " << std::endl;
os << "Error info: " << err_msg;
return os.str();
}
Expand Down
12 changes: 6 additions & 6 deletions cinn/ir/ir_schedule_error.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ enum class ScheduleErrorMessageLevel : int32_t {
};

/**
* This handler is to deal with the errors happens in in the current Scheduling.
* This handler is dealing with the errors happen in in the current Scheduling.
*/
class IRScheduleErrorHandler : public std::runtime_error {
public:
Expand All @@ -50,17 +50,17 @@ class IRScheduleErrorHandler : public std::runtime_error {
/**
* \brief Returns a detailed error message corresponding to the kDetailed error level.
*/
std::string FormatErrorMessage(const std::string &primitive);
std::string FormatErrorMessage(const std::string &primitive) const;

/**
* \brief Returns a detailed error message corresponding to the kDetailed error level.
* \brief Returns a short error message corresponding to the kGeneral error level.
*/
virtual std::string DetailedErrorMessage() const = 0;
virtual std::string GeneralErrorMessage() const = 0;

/**
* \brief Returns a short error message corresponding to the kGeneral error level.
* \brief Returns a detailed error message corresponding to the kDetailed error level.
*/
virtual std::string GeneralErrorMessage() const = 0;
virtual std::string DetailedErrorMessage() const = 0;
};

} // namespace ir
Expand Down
71 changes: 63 additions & 8 deletions cinn/ir/ir_schedule_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/ir/ir_schedule_error.h"
#include "cinn/ir/ir_visitor.h"
#include "cinn/lang/compute.h"
#include "cinn/optim/ir_copy.h"
Expand Down Expand Up @@ -196,27 +197,81 @@ void ReplaceExpr(Expr* source, const std::vector<Var>& replaced, const std::vect
}

std::vector<int> ValidateFactors(const std::vector<int>& factors, int total_extent) {
class NegativeFactorErrorHandler : public IRScheduleErrorHandler {
public:
explicit NegativeFactorErrorHandler(int64_t factor, size_t idx) : factor_(factor), idx_(idx) {}

std::string GeneralErrorMessage() const final {
return "[IRScheduleError]: The params in factors of Split should be positive. However, some "
"factor is zero or negative.";
}

std::string DetailedErrorMessage() const final {
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the factor at position " << idx_ << " is "
<< factor_;
return os.str();
}

private:
int64_t factor_;
size_t idx_;
};

class InferFactorErrorHandler : public IRScheduleErrorHandler {
public:
std::string GeneralErrorMessage() const final {
return "[IRScheduleError]: The params in factors of Split should not be less than -1 or have more than one -1!";
}

std::string DetailedErrorMessage() const final {
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or have more than one -1!";
return os.str();
}
};

class FactorProductErrorHandler : public IRScheduleErrorHandler {
public:
std::string GeneralErrorMessage() const final {
return "[IRScheduleError]: In Split, the factors' product should be not larger than or equal to original loop's "
"extent!";
}

std::string DetailedErrorMessage() const final {
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal to original loop's extent!";
return os.str();
}
};

CHECK(!factors.empty()) << "The factors param of Split should not be empty! Please check.";
bool has_minus_one = false;
int product = 1;
int idx = -1;
for (auto& i : factors) {
CHECK(i != 0) << "The params in factors of Split should not be 0! Please check.";
CHECK(i >= -1) << "The params in factors of Split should not be less than -1! Please check.";
if (i == -1) {
CHECK(!has_minus_one) << "The params in factors of Split should not have more than one -1! Please check.";
idx++;
if (i == 0 || i < -1) {
throw NegativeFactorErrorHandler(i, idx);
} else if (i == -1) {
if (has_minus_one) {
throw InferFactorErrorHandler();
}
has_minus_one = true;
} else {
product *= i;
}
}
std::vector<int> validated_factors = factors;
if (!has_minus_one) {
CHECK_GE(product, total_extent)
<< "In Split, the factors' product should be equal to original loop's extent! Please check.";
if (product < total_extent) {
throw FactorProductErrorHandler();
}
return validated_factors;
} else {
CHECK_LE(product, total_extent) << "In Split, when there is -1 in factors, the other factors' product should be <= "
"original loop's extent! Please check.";
if (product > total_extent) {
throw FactorProductErrorHandler();
}
int minus_one_candidate = (int)ceil((double)total_extent / (double)product);
for (int i = 0; i < validated_factors.size(); ++i) {
if (validated_factors[i] == -1) {
Expand Down

0 comments on commit 4a19ff4

Please sign in to comment.