Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Winograd dev #432

Open
wants to merge 37 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
eaeebe2
add winograd conv2d
yeliang2258 Aug 17, 2021
25b2fe3
update code
yeliang2258 Aug 17, 2021
ea15e25
update code
yeliang2258 Aug 18, 2021
8dfb12f
fix bugs
haozech Aug 18, 2021
e9a81c3
update code
yeliang2258 Aug 19, 2021
8c82786
fix bugs
haozech Aug 19, 2021
2aead88
update util and test
yeliang2258 Aug 19, 2021
a661baa
update code
yeliang2258 Aug 20, 2021
7c5cbfc
fix test
haozech Aug 24, 2021
7f44064
add original winograd winograd schedule
yeliang2258 Aug 24, 2021
8e7fdee
update schedule
yeliang2258 Aug 25, 2021
c13a29c
fix Heap type problem
haozech Aug 27, 2021
0bbd885
update schedule
yeliang2258 Aug 31, 2021
fd3d821
update schedule 80ms
yeliang2258 Sep 1, 2021
6d44964
update code
yeliang2258 Sep 1, 2021
a03aa5a
Merge branch 'develop' into winograd_dev
yeliang2258 Sep 1, 2021
02196bb
update code
yeliang2258 Sep 1, 2021
1fe28e3
Update nn.cc
yeliang2258 Sep 1, 2021
5b6d517
update code
yeliang2258 Sep 1, 2021
3f3c772
update code
yeliang2258 Sep 1, 2021
2e77e00
Merge remote-tracking branch 'upstream/develop' into winograd_dev
yeliang2258 Sep 2, 2021
89008d2
update nn.cc
yeliang2258 Sep 2, 2021
32f28cd
Merge branch 'winograd_dev' of https://github.com/yeliang2258/CINN in…
yeliang2258 Sep 2, 2021
a99e44d
uodate code
yeliang2258 Sep 2, 2021
b2d364c
update
yeliang2258 Sep 2, 2021
2117728
update
yeliang2258 Sep 2, 2021
ef28781
add test
yeliang2258 Sep 2, 2021
1a1f59c
update nn
yeliang2258 Sep 2, 2021
e65a31c
update schedule
yeliang2258 Sep 2, 2021
76cdf15
update code for debug
yeliang2258 Sep 3, 2021
037216e
Merge remote-tracking branch 'upstream/develop' into winograd_dev
yeliang2258 Sep 15, 2021
1b6b1f0
add temp change in cinn/optim/replace_var_with_expr
yeliang2258 Sep 15, 2021
316a6a7
update code
yeliang2258 Sep 16, 2021
7d3e6b6
Merge remote-tracking branch 'upstream/develop' into winograd_dev
yeliang2258 Sep 17, 2021
1336436
update code
yeliang2258 Sep 17, 2021
26a3cd3
fix bugs
haozech Oct 21, 2021
4654143
add winograd params
haozech Oct 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cinn/backends/codegen_cuda_dev.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
auto alloca_temp_buffers = op->PrepareAllocTempBufferExprs();
auto temp_buffer_alias = GenerateBufferAliasExprs(op, op->temp_bufs);
auto alis_var_exprs = op->CudaAliasVarExprs();
LOG(INFO) << "Op is : " << op->name;
for (auto &i : alloca_temp_buffers) {
LOG(INFO) << "alloca_temp_buffers is : " << i;
}

#define APPEND_TO_NEW_BODY(field__) new_body.insert(std::end(new_body), std::begin(field__), std::end(field__));
APPEND_TO_NEW_BODY(alloca_temp_buffers)
Expand Down
407 changes: 404 additions & 3 deletions cinn/backends/codegen_cuda_dev_test.cc
100644 → 100755

Large diffs are not rendered by default.

185 changes: 168 additions & 17 deletions cinn/common/cas.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ bool IsDivisible(int64_t a, int64_t b) {
return a % b == 0;
}
bool IsDivisible(const Sum* a, int b);
// Tells whether the integer `a` is divisible by at least one integer-immediate
// factor of the product `b`. Non-IntImm factors are ignored. The per-factor
// check is delegated to the integer overload of IsDivisible.
bool IsDivisible(int a, const Product* b) {
  for (auto& factor : b->operands()) {
    auto* imm = factor.As<IntImm>();
    if (!imm) continue;  // only constant integer factors can be tested
    if (IsDivisible(a, imm->value)) return true;
  }
  return false;
}
bool IsDivisible(const Product* a, int b) {
for (auto& item : a->operands()) {
if (item.As<IntImm>() && IsDivisible(item.As<IntImm>()->value, b)) return true;
Expand Down Expand Up @@ -172,28 +178,47 @@ Expr Divide(const Sum* a, int b) {
return Sum::Make(args);
}
Expr Divide(const Product* a, int b) {
LOG(INFO) << "Divide(const Product* a, int b) Before simplify, b is: "
<< " / " << b;
std::vector<Expr> args;
int i = 0;
int times = -1;
bool is_divisible = false;
LOG(INFO) << "Divide(const Product* a, int b) Before simplify, a is: ";
for (i = 0; i < a->operands().size(); i++) {
LOG(INFO) << "a->operand is " << a->operand(i);
auto* a_i = a->operand(i).As<IntImm>();
if (a_i && a_i->value % b == 0) {
times = a_i->value / b;
is_divisible = true;
break;
}
}
// NOTE that a should be divisible by b.
CHECK(is_divisible) << "a should be divisible by b";
if (times != 1) {
args.push_back(make_const(a->type(), times));
}
for (int j = 0; j < a->operands().size(); j++) {
if (j == i) continue;
args.push_back(a->operand(j));
if (is_divisible) {
// NOTE that a should be divisible by b.
if (times != 1) {
args.push_back(make_const(a->type(), times));
}
for (int j = 0; j < a->operands().size(); j++) {
if (j == i) continue;
args.push_back(a->operand(j));
}
return Product::Make(args);
} else {
for (i = 0; i < a->operands().size(); i++) {
auto* a_i = a->operand(i).As<IntImm>();
if (a_i && b % a_i->value == 0) {
b = b / a_i->value;
} else {
args.push_back(a->operand(i));
}
}
LOG(INFO) << "Divide(const Product* a, int b) After simplify, b is: " << b;
for (auto& i : args) {
LOG(INFO) << "a->operand is " << i;
}
return FracOp::Make(Product::Make(args), Expr(b));
}
return Product::Make(args);
}

// @}
Expand Down Expand Up @@ -1103,16 +1128,40 @@ bool CasSimplifyMutator::GetMaxBound(Expr* lower_bound, Expr* upper_bound, Expr
}

bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) {
// case1: (32+(-x))%33 = 32-x%33 (0<=x<=32)
// case2: (x-32))%33 = x%33 - 32%33 (0<=x<=32)
#ifdef CINN_WITH_CUDA
return false;
#else
// case1: (32+(-x))%33 = 32-x%33 (0<=x<=32)
// case2: (x-32))%33 = x%33 - 32%33 (0<=x<=32)
auto a_sum = a.As<Sum>();
auto b_i = b.As<IntImm>();
if (!a_sum || !b_i) {
return false;
}
LOG(INFO) << "SimplifySpecificSumMod with a, b is : " << a << ", " << b;
// if 0 < b < 3, (3a+b) % 6 = (3a % 6) + (b % 6)
if (a_sum->operands().size() == 2) {
a_sum->operands()[0] = CasSimplify(a_sum->operands()[0], var_intervals);
auto sum_a_prod = a_sum->operands()[0].As<Product>();
auto sum_b_var = a_sum->operands()[1].As<_Var_>();
if (sum_a_prod && sum_b_var && var_intervals.count(sum_b_var->name)) {
auto sum_a_prod_b_int = sum_a_prod->operand(1).As<IntImm>();
if (sum_a_prod_b_int) std::swap(sum_a_prod->operand(0), sum_a_prod->operand(1));
auto sum_a_prod_a_int = sum_a_prod->operand(0).As<IntImm>();
auto& interval = var_intervals.at(sum_b_var->name);
int b_abs = std::abs(b_i->value);
int sum_prod_a_abs = std::abs(sum_a_prod_a_int->value);
if (sum_a_prod_a_int && (b_abs % sum_prod_a_abs == 0)) {
if (std::abs(interval.l) < sum_prod_a_abs && std::abs(interval.r) < sum_prod_a_abs) {
*result = CasSimplify(Sum::Make({CasSimplify(Mod::Make(a_sum->operands()[0], b), var_intervals),
CasSimplify(Mod::Make(a_sum->operands()[1], b), var_intervals)}),
var_intervals);
return true;
}
}
}
}
#ifdef CINN_WITH_CUDA
return false;
#else

int const_value = 0;
Expr lower_bound;
Expr upper_bound;
Expand Down Expand Up @@ -1180,6 +1229,7 @@ bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) {
}

Expr CasSimplifyMutator::SimplifyMod(Expr u) {
LOG(INFO) << "CasSimplifyMutator::SimplifyMod with Expr u = " << u;
auto* node = u.As<Mod>();
CHECK(node);

Expand All @@ -1202,9 +1252,24 @@ Expr CasSimplifyMutator::SimplifyMod(Expr u) {
// x % 1 = 0
if (b_i && b_i->value == 1) return make_const(b_i->type(), 0);

// 2x % 2 = 0
// 4x % 2 = 0
// 2x % 4 = 2 * (x % 2)
if (b_i && a_product && a_product->operand(0).As<IntImm>()) {
if (a_product->operand(0).As<IntImm>()->value % b_i->value == 0) return make_const(a_product->type(), 0);
if (b_i->value % a_product->operand(0).As<IntImm>()->value == 0) {
int a_product_int = a_product->operand(0).As<IntImm>()->value;
int new_b = b_i->value / a_product_int;
return Product::Make({Expr(a_product_int), SimplifyMod(Mod::Make(a_product->operand(1), Expr(new_b)))});
}
}

if (b_i && a_product && a_product->operand(1).As<IntImm>()) {
if (a_product->operand(1).As<IntImm>()->value % b_i->value == 0) return make_const(a_product->type(), 0);
if (b_i->value % a_product->operand(1).As<IntImm>()->value == 0) {
int a_product_int = a_product->operand(1).As<IntImm>()->value;
int new_b = b_i->value / a_product_int;
return Product::Make({Expr(a_product_int), SimplifyMod(Mod::Make(a_product->operand(0), Expr(new_b)))});
}
}

// 0 % x = 1, 1 % x = 1
Expand Down Expand Up @@ -1466,6 +1531,7 @@ Expr CasSimplifyMutator::SimplifySpecificSum(Expr tmp) {
}

Expr CasSimplifyMutator::operator()(Expr u) {
LOG(INFO) << "CasSimplifyMutator::operator() with Expr u = " << u;
if (u.As<Min>() || u.As<Max>()) {
return SimplifyMinAndMax(u);
}
Expand Down Expand Up @@ -1636,6 +1702,64 @@ Expr ConvertCinnToCAS(Expr expr) {
return copied;
}

// Returns a copy of `expr` in which every Min node that has exactly one
// constant operand is replaced by that constant. The input expression is not
// modified; all rewriting happens on an IRCopy of it.
//
// Min nodes whose operands are both constant or both non-constant are kept as
// Min nodes (rebuilt from the already-visited operands).
Expr ReplaceMinToConstant(Expr expr) {
  Expr copied = optim::IRCopy(expr);
  struct Mutator : public ir::IRMutator<ir::Expr*> {
    void operator()(Expr* expr) { Visit(expr); }
    void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

   private:
    void Visit(const Min* op, Expr* expr) override {
      // Work on local handles so the recursive visit can replace an operand
      // that is itself a Min (e.g. min(min(4, x), y)).
      Expr lhs = op->a();
      Expr rhs = op->b();

      Visit(&lhs);
      Visit(&rhs);

      // Decide on the *visited* operands. Re-reading op->a()/op->b() here would
      // throw away any replacement made by the recursive visits above.
      const bool lhs_is_const = lhs.is_constant();
      const bool rhs_is_const = rhs.is_constant();
      if (lhs_is_const && !rhs_is_const) {
        CHECK(lhs->type().is_integer());
        *expr = optim::IRCopy(lhs);
      } else if (rhs_is_const && !lhs_is_const) {
        CHECK(rhs->type().is_integer());
        *expr = optim::IRCopy(rhs);
      } else {
        // Keep the Min, but built from the visited operands so that rewrites of
        // directly nested Min children are not lost.
        *expr = Min::Make(lhs, rhs);
      }
    }
  };
  Mutator()(&copied);
  return copied;
}

// Returns a copy of `expr` in which every Max node that has exactly one
// constant operand is replaced by that constant. The input expression is not
// modified; all rewriting happens on an IRCopy of it.
//
// Max nodes whose operands are both constant or both non-constant are kept as
// Max nodes (rebuilt from the already-visited operands).
Expr ReplaceMaxToConstant(Expr expr) {
  Expr copied = optim::IRCopy(expr);
  struct Mutator : public ir::IRMutator<ir::Expr*> {
    void operator()(Expr* expr) { Visit(expr); }
    void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

   private:
    void Visit(const Max* op, Expr* expr) override {
      // Work on local handles so the recursive visit can replace an operand
      // that is itself a Max (e.g. max(max(0, x), y)).
      Expr lhs = op->a();
      Expr rhs = op->b();

      Visit(&lhs);
      Visit(&rhs);

      // Decide on the *visited* operands. Re-reading op->a()/op->b() here would
      // throw away any replacement made by the recursive visits above.
      const bool lhs_is_const = lhs.is_constant();
      const bool rhs_is_const = rhs.is_constant();
      if (lhs_is_const && !rhs_is_const) {
        CHECK(lhs->type().is_integer());
        *expr = optim::IRCopy(lhs);
      } else if (rhs_is_const && !lhs_is_const) {
        CHECK(rhs->type().is_integer());
        *expr = optim::IRCopy(rhs);
      } else {
        // Keep the Max, but built from the visited operands so that rewrites of
        // directly nested Max children are not lost.
        *expr = Max::Make(lhs, rhs);
      }
    }
  };
  Mutator()(&copied);
  return copied;
}

Expr ConvertCasToCinn(Expr expr) {
Expr copied = optim::IRCopy(expr);

Expand Down Expand Up @@ -1786,7 +1910,7 @@ Expr DividePartially(Sum* a, int b) {
std::vector<Expr> external_sum_args, sum_args;

for (auto& item : a->operands()) {
if (item.As<Product>() && IsDivisible(item.As<Product>(), b)) {
if (item.As<Product>() && (IsDivisible(item.As<Product>(), b) || IsDivisible(b, item.As<Product>()))) {
external_sum_args.push_back(Divide(item.As<Product>(), b));
} else if (item.As<IntImm>() && IsDivisible(item.As<IntImm>()->value, b)) {
external_sum_args.push_back(make_const(item.type(), item.As<IntImm>()->value / b));
Expand Down Expand Up @@ -1916,12 +2040,35 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
// divisible
if (a_sum && IsDivisible(a_sum, bi->value)) return Divide(a_sum, bi->value);
if (a_product) {
if (IsDivisible(a_product, bi->value)) {
if (IsDivisible(a_product, bi->value) || IsDivisible(bi->value, a_product)) {
return Divide(a_product, bi->value);
} else {
return FracOp::Make(a, b);
}
}

// if 0 < b < 3, (3a+b) / 6 = (3a / 6) + (b / 6)
if (a_sum && a_sum->operands().size() == 2) {
LOG(INFO) << "SimplifyFracOp with expr = " << expr;
a_sum->operands()[0] = CasSimplify(a_sum->operands()[0], var_intervals);
auto sum_a_prod = a_sum->operands()[0].As<Product>();
auto sum_b_var = a_sum->operands()[1].As<_Var_>();
if (sum_a_prod && sum_b_var && var_intervals.count(sum_b_var->name)) {
auto sum_a_prod_b_int = sum_a_prod->operand(1).As<IntImm>();
if (sum_a_prod_b_int) std::swap(sum_a_prod->operand(0), sum_a_prod->operand(1));
auto sum_a_prod_a_int = sum_a_prod->operand(0).As<IntImm>();
auto& interval = var_intervals.at(sum_b_var->name);
int b_abs = std::abs(bi->value);
int sum_prod_a_abs = std::abs(sum_a_prod_a_int->value);
if (sum_a_prod_a_int && (b_abs % sum_prod_a_abs == 0)) {
if (std::abs(interval.l) < sum_prod_a_abs && std::abs(interval.r) < sum_prod_a_abs) {
return CasSimplify(Sum::Make({CasSimplify(FracOp::Make(a_sum->operands()[0], b), var_intervals),
CasSimplify(FracOp::Make(a_sum->operands()[1], b), var_intervals)}),
var_intervals);
}
}
}
}
// not divisible
/*
if (a_sum) {
Expand All @@ -1948,6 +2095,10 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
}

if (av && bi) {
LOG(INFO) << "When simplifying " << expr << ", in var_intervals:";
for (auto& i : var_intervals) {
LOG(INFO) << "Var [" << i.first << "] has interval " << i.second;
}
if (var_intervals.count(av->name)) {
auto& interval = var_intervals.at(av->name);
int b_abs = std::abs(bi->value);
Expand Down
29 changes: 26 additions & 3 deletions cinn/common/cas.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@

#include "cinn/ir/ir.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/optim/ir_simplify.h"

namespace cinn {
namespace common {

namespace detail {
Expr ReplaceMinToConstant(Expr expr);
Expr ReplaceMaxToConstant(Expr expr);
} // namespace detail

/**
* Interval of a _Var_.
*/
Expand All @@ -18,16 +24,33 @@ struct CasInterval {
CasInterval(T l, T r) : l(l), r(r) {
CHECK_LE(l, r) << "left shoud not be larger than right";
}
CasInterval(Expr e_l, Expr e_r) : e_l(e_l), e_r(e_r) {}
// Builds an interval from two bound expressions. The right bound first has its
// Min nodes passed through detail::ReplaceMinToConstant and the left bound its
// Max nodes through detail::ReplaceMaxToConstant; both are then simplified.
// If both bounds end up constant they are stored as the integers (l, r);
// otherwise they are stored as the expressions (e_l, e_r).
CasInterval(Expr expr_l, Expr expr_r) {
  LOG(INFO) << "CasInterval is : [" << expr_l << ", " << expr_r << "].";
  expr_r = detail::ReplaceMinToConstant(expr_r);
  expr_l = detail::ReplaceMaxToConstant(expr_l);
  optim::Simplify(&expr_l);
  optim::Simplify(&expr_r);
  LOG(INFO) << "After simplify, CasInterval is : [" << expr_l << ", " << expr_r << "].";

  const bool bounds_are_constant = expr_l.is_constant() && expr_r.is_constant();
  if (!bounds_are_constant) {
    // Symbolic bounds: keep them as expressions.
    e_l = expr_l;
    e_r = expr_r;
    return;
  }

  // Constant bounds: only integer constants are accepted.
  CHECK(expr_l->type().is_integer());
  CHECK(expr_r->type().is_integer());
  l = expr_l.as_int32();
  r = expr_r.as_int32();
}
int l, r;
// Note: not verify l <= r and (e_l, e_r) has higher priority than (l, r)
Expr e_l, e_r;

friend std::ostream& operator<<(std::ostream& os, const CasInterval& i) {
if (i.e_l.defined() && i.e_r.defined()) {
os << "Interval[" << i.e_l << ", " << i.e_r << "]";
os << "Expr e_l Interval[" << i.e_l << ", " << i.e_r << "]";
} else {
os << "Interval[" << i.l << ", " << i.r << "]";
os << "Int l Interval[" << i.l << ", " << i.r << "]";
}
return os;
}
Expand Down
7 changes: 7 additions & 0 deletions cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const Node* node) {
ir::Expr temp = C[i];
inputs.push_back(temp.as_tensor_ref());
}
for (auto& i : inputs) {
LOG(INFO) << "In inputs, it has: " << i->name;
}

auto func = lang::LowerVec(GenOpFuncName(node), stages, inputs, {}, {}, nullptr, this->target_);
VLOG(3) << "The [" << func.size() << "] functions of node [" << node->attrs.node_name << "] are:\n";
Expand Down Expand Up @@ -261,6 +264,10 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const std::vector<Node*>&
}
}

for (auto& i : inputs) {
LOG(INFO) << "In inputs, it has: " << i->name;
}

auto func = lang::LowerVec(fuse_name, stages, inputs, {}, {}, nullptr, this->target_);
VLOG(3) << "The [" << func.size() << "] functions are:\n";
for (auto& i : func) {
Expand Down
Loading