dialects: (linalg) Fix matmul custom syntax and custom init (#2852)
Co-authored-by: hpompougnac <[email protected]>
Co-authored-by: Sasha Lopoukhine <[email protected]>
3 people authored Jul 10, 2024
1 parent abfd34c commit 2469213
Showing 4 changed files with 33 additions and 4 deletions.
10 changes: 10 additions & 0 deletions tests/dialects/test_linalg.py
@@ -36,6 +36,16 @@ def body(args: tuple[Any, ...]):
     func.FuncOp("foo", ([], []), funcBody)
 
 
+def test_matmul_on_memrefs():
+    a = memref.Alloc.get(f32, shape=[100, 50])
+    b = memref.Alloc.get(f32, shape=[50, 100])
+    c = memref.Alloc.get(f32, shape=[100, 100])
+
+    matmul_op = linalg.MatmulOp(inputs=(a.memref, b.memref), outputs=(c.memref,))
+
+    assert tuple(result.type for result in matmul_op.results) == ()
+
+
 def test_loop_range_methods():
     A = memref.Alloc.get(f32, shape=[100, 50])
     B = memref.Alloc.get(f32, shape=[50, 100])
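For contrast, a minimal sketch of the tensor case this constructor change targets: each tensor-typed output should yield one result. This is illustrative only and not part of the commit; TestSSAValue is assumed to live in xdsl.utils.test_value, and the helper's name may differ between xdsl versions.

from xdsl.dialects import linalg
from xdsl.dialects.builtin import TensorType, f32
from xdsl.utils.test_value import TestSSAValue  # assumed helper for dummy SSA values

out_t = TensorType(f32, [100, 100])
a = TestSSAValue(TensorType(f32, [100, 50]))
b = TestSSAValue(TensorType(f32, [50, 100]))
c = TestSSAValue(out_t)

op = linalg.MatmulOp(inputs=(a, b), outputs=(c,))
# Tensor outputs contribute result types; memref outputs, as in
# test_matmul_on_memrefs above, contribute none.
assert tuple(r.type for r in op.results) == (out_t,)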
7 changes: 7 additions & 0 deletions tests/filecheck/dialects/linalg/linalg_ops.mlir
@@ -18,6 +18,10 @@ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1)
     linalg.yield %arg3 : f32
 }
 
+%2, %3 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
+%4 = "test.op"() : () -> (memref<64x4096xf32>)
+linalg.matmul {id} ins(%2, %3 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%4 : memref<64x4096xf32>)
+
 // CHECK: module {
 // CHECK-NEXT: %0, %1 = "test.op"() : () -> (f32, memref<1x256xf32>)
 // CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : f32) outs(%1 : memref<1x256xf32>) {
@@ -32,6 +36,9 @@ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1)
 // CHECK-NEXT: ^{{.*}}(%{{.*}}: f32, %{{.*}}: f32):
 // CHECK-NEXT: linalg.yield %{{.*}} : f32
 // CHECK-NEXT: }
+// CHECK-NEXT: %2, %3 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
+// CHECK-NEXT: %4 = "test.op"() : () -> memref<64x4096xf32>
+// CHECK-NEXT: linalg.matmul {"id"} ins(%2, %3 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%4 : memref<64x4096xf32>)
 // CHECK-NEXT: }
 
 // CHECK-GENERIC: "linalg.generic"(%0, %1) <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 1, 1>}> ({
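Note the quoted {"id"} in the new CHECK line: xdsl's printer quotes attribute names, while MLIR itself prints {id} (compare the next test file). A hedged sketch of reproducing that output, assuming the Printer API:

from io import StringIO

from xdsl.dialects import linalg, memref
from xdsl.dialects.builtin import UnitAttr, f32
from xdsl.printer import Printer

a = memref.Alloc.get(f32, shape=[4, 8])
b = memref.Alloc.get(f32, shape=[8, 4])
c = memref.Alloc.get(f32, shape=[4, 4])
op = linalg.MatmulOp(inputs=(a.memref, b.memref), outputs=(c.memref,))
op.attributes["id"] = UnitAttr()  # the unit attribute that prints as {"id"}

buf = StringIO()
Printer(stream=buf).print_op(op)  # assumed Printer API
print(buf.getvalue())  # expected to contain: linalg.matmul {"id"} ins(...) outs(...)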
8 changes: 8 additions & 0 deletions
@@ -56,6 +56,11 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>)
 
 %diff = linalg.sub ins(%2, %2 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%3 : tensor<2x3xf32>) -> tensor<2x3xf32>
 
+%18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
+%20 = "test.op"() : () -> (memref<64x4096xf32>)
+
+linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>)
+
 // CHECK-NEXT: #map = affine_map<(d0, d1) -> ()>
 // CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-NEXT: module {
@@ -91,5 +96,8 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>)
 // CHECK-NEXT: linalg.yield %{{.*}} : f32
 // CHECK-NEXT: } -> tensor<2x3xf32>
 // CHECK-NEXT: %{{.*}} = linalg.sub ins(%{{.*}}, %{{.*}} : tensor<2x3xf32>, tensor<2x3xf32>) outs(%{{.*}} : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK-NEXT: %16:2 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
+// CHECK-NEXT: %17 = "test.op"() : () -> memref<64x4096xf32>
+// CHECK-NEXT: linalg.matmul {id} ins(%16#0, %16#1 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%17 : memref<64x4096xf32>)
 // CHECK-NEXT: }

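The memref form above has no results, which is exactly what the optional arrow group in the new format permits. A rough round-trip sketch, assuming xdsl's mid-2024 API (the MLContext and Parser import paths have moved between versions):

from xdsl.dialects.builtin import Builtin
from xdsl.dialects.linalg import Linalg
from xdsl.dialects.test import Test
from xdsl.ir import MLContext  # assumed location; newer versions use xdsl.context
from xdsl.parser import Parser

ctx = MLContext()
for dialect in (Builtin, Test, Linalg):
    ctx.load_dialect(dialect)

source = '''
%t0, %t1 = "test.op"() : () -> (memref<4x8xf32>, memref<8x4xf32>)
%t2 = "test.op"() : () -> (memref<4x4xf32>)
linalg.matmul {id} ins(%t0, %t1 : memref<4x8xf32>, memref<8x4xf32>) outs(%t2 : memref<4x4xf32>)
'''
# Before this commit the format demanded `-> type($res)` even for this
# zero-result form and expected attr-dict after it, so the matmul line
# above would not parse.
module = Parser(ctx, source).parse_module()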
12 changes: 8 additions & 4 deletions xdsl/dialects/linalg.py
@@ -668,8 +668,8 @@ class MatmulOp(IRDLOperation):
     res = var_result_def(AnyTensorType)
 
     assembly_format = (
-        "`ins` `(` $inputs `:` type($inputs) `)` ` ` "
-        "`outs` `(` $outputs `:` type($outputs) `)` `->` type($res) attr-dict"
+        "attr-dict `ins` `(` $inputs `:` type($inputs) `)` ` ` "
+        "`outs` `(` $outputs `:` type($outputs) `)` (`->` type($res)^)?"
     )
 
     irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]
@@ -681,12 +681,16 @@ def __init__(
         res: Sequence[Attribute] | None = None,
     ):
         if res is None:
-            result_types = tuple(output.type for output in outputs)
+            result_types = tuple(
+                cast(AnyTensorType, output_type)
+                for output in outputs
+                if isinstance(output_type := output.type, TensorType)
+            )
         else:
             result_types = res
         super().__init__(
             operands=(inputs, outputs),
-            result_types=result_types,
+            result_types=(result_types,),
         )
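Two separate fixes land here. In the assembly format, attr-dict now precedes ins, matching MLIR's own matmul syntax (linalg.matmul {id} ins(...) outs(...)), and the result type sits in an optional group (`->` type($res)^)? anchored on type($res), so the arrow is only printed and parsed when the op actually has results. In __init__, only tensor-typed outputs become result types, since a memref matmul writes its buffer in place and returns nothing, and the tuple is wrapped as (result_types,) because res is a variadic result definition that expects one sequence of types per variadic group. A minimal sketch of the inference, assuming the builtin type constructors:

from xdsl.dialects.builtin import MemRefType, TensorType, f32

def infer_matmul_result_types(output_types):
    # Mirrors the new __init__ logic: only tensor outputs become results.
    return tuple(t for t in output_types if isinstance(t, TensorType))

# memref outs -> no results (the op mutates the buffer in place)
assert infer_matmul_result_types([MemRefType(f32, [100, 100])]) == ()

# tensor outs -> one result per tensor output
t = TensorType(f32, [100, 100])
assert infer_matmul_result_types([t]) == (t,)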
