From 754844991c48732c2a28aac1800e938119cdfe2e Mon Sep 17 00:00:00 2001 From: Thien Nguyen Date: Tue, 9 Nov 2021 10:35:45 -0500 Subject: [PATCH] Fixed nested range-based for loops The unnecessary casting ops between integer and index types cause the Affine for validation failed. This only occurs when the constant values are created inside another affine scope. Fixing the const index value builder to create the index const directly and the loop statement handler in the range-based case to create index type. Added the test case (https://github.com/ORNL-QCI/qcor/issues/240) Signed-off-by: Thien Nguyen --- mlir/parsers/qasm3/tests/test_loop_stmts.cpp | 8 ++++ mlir/parsers/qasm3/utils/qasm3_utils.cpp | 19 +++++--- .../visitor_handlers/loop_stmt_handler.cpp | 46 +++++-------------- 3 files changed, 32 insertions(+), 41 deletions(-) diff --git a/mlir/parsers/qasm3/tests/test_loop_stmts.cpp b/mlir/parsers/qasm3/tests/test_loop_stmts.cpp index fc4652f4..4166bb54 100644 --- a/mlir/parsers/qasm3/tests/test_loop_stmts.cpp +++ b/mlir/parsers/qasm3/tests/test_loop_stmts.cpp @@ -55,6 +55,14 @@ for i in [0:4] { } QCOR_EXPECT_TRUE(loop_count == 12); +loop_count = 0; +for i in [0:4] { + for j in [0:3] { + print(i,j); + loop_count += 1; + } +} +QCOR_EXPECT_TRUE(loop_count == 12); )#"; auto mlir = qcor::mlir_compile(for_stmt, "for_stmt", diff --git a/mlir/parsers/qasm3/utils/qasm3_utils.cpp b/mlir/parsers/qasm3/utils/qasm3_utils.cpp index 1f774a3b..abc69c20 100644 --- a/mlir/parsers/qasm3/utils/qasm3_utils.cpp +++ b/mlir/parsers/qasm3/utils/qasm3_utils.cpp @@ -147,13 +147,18 @@ mlir::Value get_or_create_constant_integer_value( mlir::Value get_or_create_constant_index_value(const std::size_t idx, mlir::Location location, int width, - ScopedSymbolTable& symbol_table, - mlir::OpBuilder& builder) { - auto type = mlir::IntegerType::get(builder.getContext(), width); - auto constant_int = get_or_create_constant_integer_value( - idx, location, type, symbol_table, builder); - return builder.create(location, constant_int, - builder.getIndexType()); + ScopedSymbolTable &symbol_table, + mlir::OpBuilder &builder) { + if (symbol_table.has_constant_integer(idx, width)) { + // If there is a cached constant integer value, cast and return it: + auto constant_int = symbol_table.get_constant_integer(idx, width); + return builder.create(location, constant_int, + builder.getIndexType()); + } else { + // Otherwise, create a new constant index value + auto integer_attr = mlir::IntegerAttr::get(builder.getIndexType(), idx); + return builder.create(location, integer_attr); + } } mlir::Type convertQasm3Type(qasm3::qasm3Parser::ClassicalTypeContext* ctx, diff --git a/mlir/parsers/qasm3/visitor_handlers/loop_stmt_handler.cpp b/mlir/parsers/qasm3/visitor_handlers/loop_stmt_handler.cpp index 15bd5b1d..2ac29522 100644 --- a/mlir/parsers/qasm3/visitor_handlers/loop_stmt_handler.cpp +++ b/mlir/parsers/qasm3/visitor_handlers/loop_stmt_handler.cpp @@ -156,43 +156,21 @@ void qasm3_visitor::createRangeBasedForLoop( auto n_expr = range->expression().size(); int a, b, c; - // First question what type should we use? - mlir::Type int_type = builder.getI64Type(); - if (symbol_table.has_symbol(range->expression(0)->getText())) { - int_type = - symbol_table.get_symbol(range->expression(0)->getText()).getType(); - } - if (n_expr == 3) { - if (symbol_table.has_symbol(range->expression(1)->getText())) { - int_type = - symbol_table.get_symbol(range->expression(1)->getText()).getType(); - } else if (symbol_table.has_symbol(range->expression(2)->getText())) { - int_type = - symbol_table.get_symbol(range->expression(2)->getText()).getType(); - } - } else { - if (symbol_table.has_symbol(range->expression(1)->getText())) { - int_type = - symbol_table.get_symbol(range->expression(1)->getText()).getType(); - } - } - - if (int_type.isa()) { - int_type = int_type.cast().getElementType(); - } + // For loop variables will be index type (casting will be done if needed) + mlir::Type index_type = builder.getIndexType(); c = 1; mlir::Value a_value, b_value, - c_value = get_or_create_constant_integer_value(c, location, int_type, - symbol_table, builder); + c_value = get_or_create_constant_index_value(c, location, 64, + symbol_table, builder); const auto const_eval_a_val = symbol_table.try_evaluate_constant_integer_expression( range->expression(0)->getText()); if (const_eval_a_val.has_value()) { // std::cout << "A val = " << const_eval_a_val.value() << "\n"; - a_value = get_or_create_constant_integer_value( - const_eval_a_val.value(), location, int_type, symbol_table, builder); + a_value = get_or_create_constant_index_value( + const_eval_a_val.value(), location, 64, symbol_table, builder); } else { qasm3_expression_generator exp_generator(builder, symbol_table, file_name); exp_generator.visit(range->expression(0)); @@ -208,8 +186,8 @@ void qasm3_visitor::createRangeBasedForLoop( range->expression(2)->getText()); if (const_eval_b_val.has_value()) { // std::cout << "B val = " << const_eval_b_val.value() << "\n"; - b_value = get_or_create_constant_integer_value( - const_eval_b_val.value(), location, int_type, symbol_table, builder); + b_value = get_or_create_constant_index_value( + const_eval_b_val.value(), location, 64, symbol_table, builder); } else { qasm3_expression_generator exp_generator(builder, symbol_table, file_name); @@ -231,8 +209,8 @@ void qasm3_visitor::createRangeBasedForLoop( } else { c = symbol_table.evaluate_constant_integer_expression( range->expression(1)->getText()); - c_value = get_or_create_constant_integer_value( - c, location, a_value.getType(), symbol_table, builder); + c_value = get_or_create_constant_index_value( + c, location, 64, symbol_table, builder); } } else { @@ -241,8 +219,8 @@ void qasm3_visitor::createRangeBasedForLoop( range->expression(1)->getText()); if (const_eval_b_val.has_value()) { // std::cout << "B val = " << const_eval_b_val.value() << "\n"; - b_value = get_or_create_constant_integer_value( - const_eval_b_val.value(), location, int_type, symbol_table, builder); + b_value = get_or_create_constant_index_value( + const_eval_b_val.value(), location, 64, symbol_table, builder); } else { qasm3_expression_generator exp_generator(builder, symbol_table, file_name);