Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

isl for1 wrong elimination after splitting #329

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/test02_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,15 @@ ir::Module CreateMatmulLoopPermutation(Target target, int m, int n, int k_) {
}

ir::Module CreateMatmulArrayPacking(Target target, int m, int n, int k_) {
m=n=k_=16;
auto [M, N, K] = std::make_tuple(Expr(m), Expr(n), Expr(k_));

Placeholder<float> A("A", {M, K});
Placeholder<float> B("B", {K, N});

Var k(K.as_int32(), "k0");

Expr bn(32);
Expr bn(16);

auto C_init = Compute(
{M, N}, [&](Var i, Var j) { return Expr(0.f); }, "C_init");
Expand All @@ -207,7 +208,7 @@ ir::Module CreateMatmulArrayPacking(Target target, int m, int n, int k_) {
auto [k_outer, k_inner] = stages[C]->Split("k0", 4); // NOLINT

stages[C]->Reorder({i_outer, j_outer, k_outer, i_inner, k_inner, j_inner});
stages[C]->Vectorize(j_inner, 8);
stages[C]->Vectorize(j_inner, 16);
}

Module::Builder builder("module_array_packing", target);
Expand Down
64 changes: 32 additions & 32 deletions tests/test02_matmul_case.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,25 +129,25 @@ TEST(test02, basic) {
compare(); \
reset();

TEST_FUNC(matmul)
// TEST_FUNC(matmul)

TEST_FUNC(matmul_tile)
// TEST_FUNC(matmul_tile)

TEST_FUNC(matmul_split)
// TEST_FUNC(matmul_split)

TEST_FUNC(matmul_block)
// TEST_FUNC(matmul_block)

TEST_FUNC(matmul_vectorize)
// TEST_FUNC(matmul_vectorize)

TEST_FUNC(matmul_loop_permutation)
// TEST_FUNC(matmul_loop_permutation)

TEST_FUNC1(matmul_array_packing, 1e-5)
// TEST_FUNC1(matmul_array_packing, 1e-5)

TEST_FUNC2(matmul_dynamic_shape, 1e-5);
// TEST_FUNC2(matmul_dynamic_shape, 1e-5);

TEST_FUNC2(matmul_dynamic_shape_tile, 1e-5);
// TEST_FUNC2(matmul_dynamic_shape_tile, 1e-5);

TEST_FUNC3(matmul_array_packing_dynamic_shape, 1e-5);
// TEST_FUNC3(matmul_array_packing_dynamic_shape, 1e-5);

// Currently, the execution of a LoweredFunc is scheduled by the outer framework, so no need to Call inside another
// LoweredFunc.
Expand Down Expand Up @@ -175,34 +175,34 @@ TEST(test02, basic) {
target.bits = cinn::Target::Bit::k32;
target.os = cinn::Target::OS::Linux;

TEST_LLVM_MATMUL(basic, target);
TEST_LLVM_MATMUL(tile, target);
TEST_LLVM_MATMUL(block, target);
TEST_LLVM_MATMUL(vectorize, target);
TEST_LLVM_MATMUL(loop_permutation, target);
// TEST_LLVM_MATMUL(basic, target);
// TEST_LLVM_MATMUL(tile, target);
// TEST_LLVM_MATMUL(block, target);
// TEST_LLVM_MATMUL(vectorize, target);
// TEST_LLVM_MATMUL(loop_permutation, target);
TEST_LLVM_MATMUL1(array_packing, target);

{
auto module = cinn::tests::CreateMatmulBasicModule(target, 1024, 1024, 1024);
auto jit = cinn::tests::CreateSimpleJit(module);
auto matmul_fn = reinterpret_cast<void (*)(void**, int32_t)>(jit->Lookup("matmul_basic"));
TEST_FUNC(matmul_fn);
}
// {
// auto module = cinn::tests::CreateMatmulBasicModule(target, 1024, 1024, 1024);
// auto jit = cinn::tests::CreateSimpleJit(module);
// auto matmul_fn = reinterpret_cast<void (*)(void**, int32_t)>(jit->Lookup("matmul_basic"));
// TEST_FUNC(matmul_fn);
// }

#undef TEST_LLVM_MATMUL
}

// include the generated C source code:
// @{
#include "tests/test02_matmul.cc"
#include "tests/test02_matmul_array_packing.cc"
#include "tests/test02_matmul_array_packing_dynamic_shape.cc"
#include "tests/test02_matmul_block.cc"
#include "tests/test02_matmul_call.cc"
#include "tests/test02_matmul_loop_permutation.cc"
#include "tests/test02_matmul_split.cc"
#include "tests/test02_matmul_tile.cc"
#include "tests/test02_matmul_varient_shape.cc"
#include "tests/test02_matmul_varient_shape_tile.cc"
#include "tests/test02_matmul_vectorize.cc"
// #include "tests/test02_matmul.cc"
// #include "tests/test02_matmul_array_packing.cc"
// #include "tests/test02_matmul_array_packing_dynamic_shape.cc"
// #include "tests/test02_matmul_block.cc"
// #include "tests/test02_matmul_call.cc"
// #include "tests/test02_matmul_loop_permutation.cc"
// #include "tests/test02_matmul_split.cc"
// #include "tests/test02_matmul_tile.cc"
// #include "tests/test02_matmul_varient_shape.cc"
// #include "tests/test02_matmul_varient_shape_tile.cc"
// #include "tests/test02_matmul_vectorize.cc"
// @}