Skip to content

Commit

Permalink
Fixed nested range-based for loops
Browse files Browse the repository at this point in the history
The unnecessary casting ops between integer and index types cause the Affine for validation failed. This only occurs when the constant values are created inside another affine scope.

Fixing the const index value builder to create the index const directly and the loop statement handler in the range-based case to create index type.

Added the test case (https://github.com/ORNL-QCI/qcor/issues/240)

Signed-off-by: Thien Nguyen <[email protected]>
  • Loading branch information
Thien Nguyen committed Nov 9, 2021
1 parent 356a3e2 commit 7548449
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 41 deletions.
8 changes: 8 additions & 0 deletions mlir/parsers/qasm3/tests/test_loop_stmts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ for i in [0:4] {
}
QCOR_EXPECT_TRUE(loop_count == 12);
loop_count = 0;
for i in [0:4] {
for j in [0:3] {
print(i,j);
loop_count += 1;
}
}
QCOR_EXPECT_TRUE(loop_count == 12);
)#";
auto mlir = qcor::mlir_compile(for_stmt, "for_stmt",
Expand Down
19 changes: 12 additions & 7 deletions mlir/parsers/qasm3/utils/qasm3_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,18 @@ mlir::Value get_or_create_constant_integer_value(
mlir::Value get_or_create_constant_index_value(const std::size_t idx,
mlir::Location location,
int width,
ScopedSymbolTable& symbol_table,
mlir::OpBuilder& builder) {
auto type = mlir::IntegerType::get(builder.getContext(), width);
auto constant_int = get_or_create_constant_integer_value(
idx, location, type, symbol_table, builder);
return builder.create<mlir::IndexCastOp>(location, constant_int,
builder.getIndexType());
ScopedSymbolTable &symbol_table,
mlir::OpBuilder &builder) {
if (symbol_table.has_constant_integer(idx, width)) {
// If there is a cached constant integer value, cast and return it:
auto constant_int = symbol_table.get_constant_integer(idx, width);
return builder.create<mlir::IndexCastOp>(location, constant_int,
builder.getIndexType());
} else {
// Otherwise, create a new constant index value
auto integer_attr = mlir::IntegerAttr::get(builder.getIndexType(), idx);
return builder.create<mlir::ConstantOp>(location, integer_attr);
}
}

mlir::Type convertQasm3Type(qasm3::qasm3Parser::ClassicalTypeContext* ctx,
Expand Down
46 changes: 12 additions & 34 deletions mlir/parsers/qasm3/visitor_handlers/loop_stmt_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,43 +156,21 @@ void qasm3_visitor::createRangeBasedForLoop(
auto n_expr = range->expression().size();
int a, b, c;

// First question what type should we use?
mlir::Type int_type = builder.getI64Type();
if (symbol_table.has_symbol(range->expression(0)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(0)->getText()).getType();
}
if (n_expr == 3) {
if (symbol_table.has_symbol(range->expression(1)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(1)->getText()).getType();
} else if (symbol_table.has_symbol(range->expression(2)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(2)->getText()).getType();
}
} else {
if (symbol_table.has_symbol(range->expression(1)->getText())) {
int_type =
symbol_table.get_symbol(range->expression(1)->getText()).getType();
}
}

if (int_type.isa<mlir::MemRefType>()) {
int_type = int_type.cast<mlir::MemRefType>().getElementType();
}
// For loop variables will be index type (casting will be done if needed)
mlir::Type index_type = builder.getIndexType();

c = 1;
mlir::Value a_value, b_value,
c_value = get_or_create_constant_integer_value(c, location, int_type,
symbol_table, builder);
c_value = get_or_create_constant_index_value(c, location, 64,
symbol_table, builder);

const auto const_eval_a_val =
symbol_table.try_evaluate_constant_integer_expression(
range->expression(0)->getText());
if (const_eval_a_val.has_value()) {
// std::cout << "A val = " << const_eval_a_val.value() << "\n";
a_value = get_or_create_constant_integer_value(
const_eval_a_val.value(), location, int_type, symbol_table, builder);
a_value = get_or_create_constant_index_value(
const_eval_a_val.value(), location, 64, symbol_table, builder);
} else {
qasm3_expression_generator exp_generator(builder, symbol_table, file_name);
exp_generator.visit(range->expression(0));
Expand All @@ -208,8 +186,8 @@ void qasm3_visitor::createRangeBasedForLoop(
range->expression(2)->getText());
if (const_eval_b_val.has_value()) {
// std::cout << "B val = " << const_eval_b_val.value() << "\n";
b_value = get_or_create_constant_integer_value(
const_eval_b_val.value(), location, int_type, symbol_table, builder);
b_value = get_or_create_constant_index_value(
const_eval_b_val.value(), location, 64, symbol_table, builder);
} else {
qasm3_expression_generator exp_generator(builder, symbol_table,
file_name);
Expand All @@ -231,8 +209,8 @@ void qasm3_visitor::createRangeBasedForLoop(
} else {
c = symbol_table.evaluate_constant_integer_expression(
range->expression(1)->getText());
c_value = get_or_create_constant_integer_value(
c, location, a_value.getType(), symbol_table, builder);
c_value = get_or_create_constant_index_value(
c, location, 64, symbol_table, builder);
}

} else {
Expand All @@ -241,8 +219,8 @@ void qasm3_visitor::createRangeBasedForLoop(
range->expression(1)->getText());
if (const_eval_b_val.has_value()) {
// std::cout << "B val = " << const_eval_b_val.value() << "\n";
b_value = get_or_create_constant_integer_value(
const_eval_b_val.value(), location, int_type, symbol_table, builder);
b_value = get_or_create_constant_index_value(
const_eval_b_val.value(), location, 64, symbol_table, builder);
} else {
qasm3_expression_generator exp_generator(builder, symbol_table,
file_name);
Expand Down

0 comments on commit 7548449

Please sign in to comment.