
[WIP] ONNX demo #2114

Closed
wants to merge 152 commits

Commits

305b6e0
wip generate add.mlir
superlopuh Jan 18, 2024
bc5db77
res_type default argument
superlopuh Jan 18, 2024
c39e77f
implement basic frontend
superlopuh Jan 18, 2024
34a2b34
implicit builder all the way
superlopuh Jan 18, 2024
25c5de7
move building logic to xdsl
superlopuh Jan 18, 2024
11132b6
remove Sequence
superlopuh Jan 19, 2024
5994238
Merge branch 'main' into sasha/onnx/frontend
superlopuh Feb 2, 2024
2ea27c4
add typing
superlopuh Feb 2, 2024
55853fe
remove unnecessary things
superlopuh Feb 3, 2024
6c21d83
add linalg.add
superlopuh Feb 7, 2024
c853fae
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 7, 2024
e3645a6
remove add.onnx
superlopuh Feb 7, 2024
15c6669
add results to linalg.generic custom format
superlopuh Feb 7, 2024
2b0f752
wip wip wip don't print generic format by default in mlir-opt
superlopuh Feb 7, 2024
2aa07d8
add tensor dialect with tensor.empty
superlopuh Feb 7, 2024
51e9e4c
add custom syntax parsing to memref alloc
superlopuh Feb 7, 2024
da85e96
wip wip wip
superlopuh Feb 7, 2024
b18b0df
transformations: (snitch_stream) factor out stride pattern ops lowerin…
superlopuh Feb 12, 2024
df804d9
dialects: (snitch_stream) add stride pattern attribute
superlopuh Feb 12, 2024
c2df141
tests: (snitch_stream) modify test to minimise upcoming PR diff [NFC]
superlopuh Feb 12, 2024
d21eef3
dialects: (snitch_stream) replace stride pattern op with attribute
superlopuh Feb 12, 2024
05113cf
dialects: (memref_stream) add streaming region
superlopuh Feb 13, 2024
a0d8ebe
transforms: (memref_stream) add streaming_region lowering
superlopuh Feb 13, 2024
db747f5
dialects: (memref_stream) add memref_stream.generic
superlopuh Feb 13, 2024
c891d0b
transformations: add convert-linalg-to-memref-stream
superlopuh Feb 13, 2024
71ec3b4
transformations: add memref-streamify
superlopuh Feb 13, 2024
d6aae8d
transformations: add lowering from memref_stream generic to loops
superlopuh Feb 13, 2024
d00f231
Merge branch 'main' into sasha/snitch_stream/stride-pattern-op
superlopuh Feb 13, 2024
4b67cd6
Merge branch 'sasha/snitch_stream/stride-pattern-op' into sasha/memre…
superlopuh Feb 13, 2024
5e714da
Merge branch 'sasha/memref_stream/streaming-region' into sasha/memref…
superlopuh Feb 13, 2024
c4eb7bf
fix doc comment
superlopuh Feb 13, 2024
b6cb33e
Merge branch 'sasha/memref_stream/streaming-region' into sasha/memref…
superlopuh Feb 13, 2024
70bae88
fix doc string
superlopuh Feb 13, 2024
f5cb14b
Merge branch 'sasha/memref_stream/generic' into sasha/memref_stream/c…
superlopuh Feb 13, 2024
65c1c98
Merge branch 'sasha/memref_stream/convert-linalg' into sasha/memref_s…
superlopuh Feb 13, 2024
487be38
Merge branch 'sasha/memref_stream/streamify' into sasha/memref_stream…
superlopuh Feb 13, 2024
7ad8963
Merge branch 'main' into sasha/memref_stream/lower-streaming-region
superlopuh Feb 13, 2024
2048f4e
Merge branch 'main' into sasha/memref_stream/lower-streaming-region
superlopuh Feb 13, 2024
385e509
handle unique streaming pattern
superlopuh Feb 13, 2024
f3d1b07
Merge branch 'sasha/memref_stream/lower-streaming-region' into sasha/…
superlopuh Feb 15, 2024
caae3b4
Merge branch 'sasha/memref_stream/generic' into sasha/memref_stream/c…
superlopuh Feb 15, 2024
05c1e48
Merge branch 'sasha/memref_stream/convert-linalg' into sasha/memref_s…
superlopuh Feb 15, 2024
22e0f0f
Merge branch 'sasha/memref_stream/streamify' into sasha/memref_stream…
superlopuh Feb 15, 2024
5878ee3
Merge branch 'main' into sasha/memref_stream/lower-streaming-region
superlopuh Feb 15, 2024
86d7d3a
Merge branch 'sasha/memref_stream/lower-streaming-region' into sasha/…
superlopuh Feb 15, 2024
d5c0e24
Merge branch 'sasha/memref_stream/generic' into sasha/memref_stream/c…
superlopuh Feb 15, 2024
173ea1d
Merge branch 'sasha/memref_stream/convert-linalg' into sasha/memref_s…
superlopuh Feb 15, 2024
5651ea8
Merge branch 'sasha/memref_stream/streamify' into sasha/memref_stream…
superlopuh Feb 15, 2024
5fcf1c3
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 15, 2024
8c0e612
use existing onnx frontend helpers
superlopuh Feb 15, 2024
eca2779
revert unnecessary filecheck changes
superlopuh Feb 15, 2024
aa4c3a6
add onnx to pyright ci
superlopuh Feb 15, 2024
a4bf47a
Merge branch 'main' into sasha/memref_stream/convert-linalg
superlopuh Feb 15, 2024
bc4156e
add a comment
superlopuh Feb 15, 2024
5783ee8
Merge branch 'sasha/memref_stream/convert-linalg' into sasha/memref_s…
superlopuh Feb 15, 2024
fa74e22
Merge branch 'sasha/memref_stream/streamify' into sasha/memref_stream…
superlopuh Feb 15, 2024
e9d833d
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 15, 2024
f98f6b0
add some comments
superlopuh Feb 15, 2024
fd8d625
add marimo as a dependency
superlopuh Feb 15, 2024
02e5d34
Merge branch 'main' into sasha/memref_stream/streamify
superlopuh Feb 15, 2024
4156ff2
fix test
superlopuh Feb 15, 2024
7a7c568
Merge branch 'sasha/memref_stream/streamify' into sasha/memref_stream…
superlopuh Feb 15, 2024
33a02e4
fix doc comment
superlopuh Feb 15, 2024
1f11fcf
Merge branch 'sasha/memref_stream/streamify' into sasha/memref_stream…
superlopuh Feb 15, 2024
fb80836
Merge branch 'sasha/memref_stream/loops' into sasha/onnx-demo
superlopuh Feb 15, 2024
a567cf9
interpreter: add infrastructure to collect interpreter traces
superlopuh Feb 15, 2024
38663c5
the walls have ears
superlopuh Feb 15, 2024
cccbdca
Merge branch 'interpreter/trace' into sasha/onnx-demo
superlopuh Feb 15, 2024
8b5cc0f
interpreter: (riscv) add impl for riscv.sub
superlopuh Feb 15, 2024
fbf9d17
add snitch and runtime op comparison to marimo
superlopuh Feb 15, 2024
340ae06
print assembly
superlopuh Feb 16, 2024
858079c
add Snitch assembly
superlopuh Feb 16, 2024
9fdaf81
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 18, 2024
fa1cb00
revert bad merge
superlopuh Feb 18, 2024
e156e50
revert unnecessary init change
superlopuh Feb 18, 2024
7ae8fe3
remove unnecessary file
superlopuh Feb 18, 2024
6691eef
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 22, 2024
5cd1d15
wip wip wip
superlopuh Feb 22, 2024
6c68193
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 22, 2024
d94c19e
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 26, 2024
e7e829c
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 27, 2024
dfdaf08
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 29, 2024
fb1d042
comment register checking again
superlopuh Feb 29, 2024
d86c355
add sub
superlopuh Feb 29, 2024
a53609f
Merge branch 'main' into sasha/onnx-demo
superlopuh Feb 29, 2024
8e56c89
reorganize init code in different files
alecerio Mar 7, 2024
a747c99
reorganise in two files
alecerio Mar 7, 2024
c9abbe0
fix imports
alecerio Mar 11, 2024
edf9331
test get_shape
alecerio Mar 12, 2024
1a1a5f3
test for get_tensor_type
alecerio Mar 12, 2024
190ddc6
move implementation test get_type in get_tensor_type
alecerio Mar 12, 2024
017d2ac
implementation test for get_test
alecerio Mar 12, 2024
ce1033a
Merge branch 'main' into sasha/onnx-demo
alecerio Mar 12, 2024
c371164
export build_module
superlopuh Mar 19, 2024
edde9bb
update snitch lowering pass name
superlopuh Mar 19, 2024
bdf726e
Merge branch 'main' into sasha/onnx-demo
superlopuh Mar 19, 2024
6efe6e6
dependencies: (onnx) use versioned onnx instead of weekly
superlopuh Apr 30, 2024
b212217
Merge branch 'sasha/onnx/dep' into sasha/onnx-demo
superlopuh Apr 30, 2024
17cb9d1
fix onnx notebook
superlopuh May 3, 2024
2286160
test_lower_snitch_stream_to_asm
superlopuh May 3, 2024
4b39064
remove old file for ir building
superlopuh May 3, 2024
032a52a
Merge branch 'main' into sasha/onnx-demo
superlopuh May 3, 2024
e7c963f
transformations: add malloc to riscv function call lowering
superlopuh May 3, 2024
52393d1
Merge branch 'sasha/memref/malloc-riscv' into sasha/onnx-demo
superlopuh May 3, 2024
4345642
fix registers
superlopuh May 3, 2024
198a039
Merge branch 'sasha/memref/malloc-riscv' into sasha/onnx-demo
superlopuh May 3, 2024
c366844
Merge branch 'main' into sasha/memref/malloc-riscv
superlopuh May 4, 2024
d0a116e
Merge branch 'sasha/memref/malloc-riscv' into sasha/onnx-demo
superlopuh May 4, 2024
8ea560f
wip add test file
superlopuh May 4, 2024
afad428
fix errors
superlopuh May 4, 2024
525ddda
core: add argument and result types getters to CallableOpInterface
superlopuh May 4, 2024
20b7427
Merge branch 'sasha/core/callable-types' into sasha/onnx-demo
superlopuh May 4, 2024
4bb3520
interpreter: use function type when parsing xdsl-run arguments
superlopuh May 4, 2024
2bdfcbe
add test for mixed values
superlopuh May 4, 2024
502a1d8
Merge branch 'sasha/interpreter/dialect-specific-values' into sasha/o…
superlopuh May 4, 2024
538f0c6
remove test things
superlopuh May 4, 2024
7e072b3
disable riscv reg checking again
superlopuh Jun 26, 2024
d5eb51f
docs: allow marimo notebooks as documentation
superlopuh Jun 26, 2024
70a87a8
add marimo to CI
superlopuh Jun 26, 2024
54b9710
Revert "add marimo to CI"
superlopuh Jun 26, 2024
c1ed347
use correct CI
superlopuh Jun 26, 2024
4798ce0
try again
superlopuh Jun 26, 2024
d5b5864
remove exception
superlopuh Jun 26, 2024
b6d6088
rename file
superlopuh Jun 26, 2024
850d80a
Revert "disable riscv reg checking again"
superlopuh Jun 27, 2024
20875a1
Merge branch 'sasha/marimo/init' into sasha/onnx-demo
superlopuh Jun 27, 2024
b413933
move notebook
superlopuh Jun 27, 2024
d5b32d7
disable riscv reg checking again
superlopuh Jun 26, 2024
3edf261
Merge branch 'main' into sasha/onnx-demo
superlopuh Jun 27, 2024
487d409
docs: add onnx demo notebook
superlopuh Jun 27, 2024
e0935a6
try to cd
superlopuh Jun 27, 2024
660b2ef
add bin to path
superlopuh Jun 27, 2024
e5c43e5
Merge branch 'sasha/marimo/onnx-init' into sasha/onnx-demo
superlopuh Jun 27, 2024
60eef0a
Merge branch 'main' into sasha/onnx-demo
superlopuh Jul 3, 2024
16a3b38
add interpreter thing
superlopuh Jul 3, 2024
75a4599
revert extra changes
superlopuh Jul 3, 2024
d479f3c
Merge branch 'main' into sasha/onnx-demo
superlopuh Jul 3, 2024
961a60f
prettify notebook
superlopuh Jul 3, 2024
62151b8
add pretty printing
superlopuh Jul 3, 2024
d97e743
Merge branch 'main' into sasha/onnx-demo
superlopuh Jul 12, 2024
a14e562
Merge branch 'main' into sasha/onnx-demo
superlopuh Jul 12, 2024
4b3df02
docs: (marimo) prettify onnx notebook
superlopuh Jul 12, 2024
7abc47f
Merge branch 'sasha/marimo/prettify-onnx' into sasha/onnx-demo
superlopuh Jul 17, 2024
714aa83
Merge branch 'main' into sasha/onnx-demo
superlopuh Jul 17, 2024
2c5ccf0
wip wip wip
superlopuh Jul 17, 2024
948cbcc
add linalg_snitch notebook
superlopuh Jul 17, 2024
e1a0c2e
clone arg name hints
superlopuh Jul 17, 2024
fc7ef1c
unhide things
superlopuh Jul 17, 2024
72312ea
prettify notebook and remove onnx references
superlopuh Jul 18, 2024
70e8369
move around xdsl imports and repeat sliders
superlopuh Jul 18, 2024
780c46e
fix test and remove fastmath flags in input
superlopuh Jul 18, 2024
923947f
Merge branch 'main' into sasha/onnx-demo
superlopuh Jul 19, 2024
7 changes: 7 additions & 0 deletions hello.mlir
@@ -0,0 +1,7 @@
builtin.module {
func.func @main_graph(%0 : tensor<3x2xf32>, %1 : tensor<3x2xf32>) -> tensor<3x2xf32> {
%2 = tensor.empty() : tensor<3x2xf32>
%3 = linalg.add ins(%0, %1 : tensor<3x2xf32>, tensor<3x2xf32>) outs(%2 : tensor<3x2xf32>) -> tensor<3x2xf32>
func.return %3 : tensor<3x2xf32>
}
}
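For reference, `main_graph` above computes an elementwise sum of its two `3x2` operands: `tensor.empty` allocates the result and `linalg.add` fills it with the pairwise sums. A minimal stdlib-only Python sketch of the same semantics, with illustrative names that are not part of this PR:

```python
# Reference semantics of hello.mlir's @main_graph: elementwise addition
# over two 3x2 tensors, written against plain nested lists.
def main_graph(lhs, rhs):
    # tensor.empty followed by linalg.add amounts to a fresh result
    # tensor holding lhs[i][j] + rhs[i][j] at every index.
    return [
        [a + b for a, b in zip(row_l, row_r)]
        for row_l, row_r in zip(lhs, rhs)
    ]

lhs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
rhs = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
print(main_graph(lhs, rhs))  # → [[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]]
```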
300 changes: 300 additions & 0 deletions onnx-frontend.py
@@ -0,0 +1,300 @@
import marimo

__generated_with = "0.1.77"
app = marimo.App()


@app.cell
def __():
import marimo as mo

return (mo,)


@app.cell
def __(mo):
a = mo.ui.slider(1, 4, value=2)
a
return (a,)


@app.cell
def __(a):
dims = list(range(2, 2 + a.value))
dims
return (dims,)


@app.cell
def __(dims):
import onnx
from onnx import AttributeProto, GraphProto, TensorProto, ValueInfoProto, helper

# Create one input (ValueInfoProto)
X1 = helper.make_tensor_value_info("X1", TensorProto.DOUBLE, dims)
X2 = helper.make_tensor_value_info("X2", TensorProto.DOUBLE, dims)

# Create one output (ValueInfoProto)
Y = helper.make_tensor_value_info("Y", TensorProto.DOUBLE, dims)

# Create a node (NodeProto) for the elementwise Add operation
node_def = helper.make_node(
"Add", # node name
["X1", "X2"], # inputs
["Y"], # outputs
)

# Create the graph (GraphProto)
graph_def = helper.make_graph(
[node_def],
"main_graph",
[X1, X2],
[Y],
)

# Set opset version to 18
opset_import = [helper.make_operatorsetid("", 18)]

# Create the model (ModelProto) without using helper.make_model
model_def = helper.make_model(
graph_def, producer_name="onnx-example", opset_imports=opset_import
)

print(f"The model is:\n{model_def}")
onnx.checker.check_model(model_def)
# onnx.save(model_def, "add.onnx")
print("The model is checked!")
return (
AttributeProto,
GraphProto,
TensorProto,
ValueInfoProto,
X1,
X2,
Y,
graph_def,
helper,
model_def,
node_def,
onnx,
opset_import,
)


@app.cell
def __():
from xdsl.ir import Attribute, SSAValue

return Attribute, SSAValue


@app.cell
def __(model_def):
from xdsl.frontend.onnx import build_module

init_module = build_module(model_def.graph)

str(init_module)
return build_module, init_module


@app.cell
def __(init_module):
from xdsl.ir import MLContext
from xdsl.tools.command_line_tool import get_all_dialects
from xdsl.transforms.convert_onnx_to_linalg import ConvertOnnxToLinalgPass

ctx = MLContext()

for dialect_name, dialect_factory in get_all_dialects().items():
ctx.register_dialect(dialect_name, dialect_factory)

linalg_module = init_module

ConvertOnnxToLinalgPass().apply(ctx, linalg_module)

str(linalg_module)
return (
ConvertOnnxToLinalgPass,
MLContext,
ctx,
dialect_factory,
dialect_name,
get_all_dialects,
linalg_module,
)


@app.cell
def __(ctx, linalg_module):
from xdsl.transforms.mlir_opt import MLIROptPass

generalized_module = linalg_module

MLIROptPass(arguments=["--linalg-generalize-named-ops"]).apply(
ctx, generalized_module
)

str(generalized_module)
return MLIROptPass, generalized_module


@app.cell
def __(MLIROptPass, ctx, generalized_module):
bufferized_module = generalized_module

MLIROptPass(
arguments=[
"--empty-tensor-to-alloc-tensor",
"--one-shot-bufferize=bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map",
]
).apply(ctx, bufferized_module)

str(bufferized_module)
return (bufferized_module,)


@app.cell
def __(MLIROptPass, bufferized_module, ctx):
scf_module = bufferized_module

MLIROptPass(
arguments=["--convert-linalg-to-loops", "--lower-affine", "--canonicalize"]
).apply(ctx, scf_module)

str(scf_module)
return (scf_module,)


@app.cell
def __(ctx, scf_module):
from xdsl.backend.riscv.lowering import (
convert_arith_to_riscv,
convert_func_to_riscv_func,
convert_memref_to_riscv,
convert_scf_to_riscv_scf,
)
from xdsl.passes import PipelinePass
from xdsl.transforms import reconcile_unrealized_casts

riscv_module = scf_module

lower_to_riscv = PipelinePass(
[
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(),
convert_memref_to_riscv.ConvertMemrefToRiscvPass(),
convert_arith_to_riscv.ConvertArithToRiscvPass(),
convert_scf_to_riscv_scf.ConvertScfToRiscvPass(),
reconcile_unrealized_casts.ReconcileUnrealizedCastsPass(),
]
).apply(ctx, riscv_module)

str(riscv_module)
return (
PipelinePass,
convert_arith_to_riscv,
convert_func_to_riscv_func,
convert_memref_to_riscv,
convert_scf_to_riscv_scf,
lower_to_riscv,
reconcile_unrealized_casts,
riscv_module,
)


@app.cell
def __(PipelinePass, ctx, riscv_module):
from xdsl.backend.riscv.lowering.convert_snitch_stream_to_snitch import (
ConvertSnitchStreamToSnitch,
)
from xdsl.transforms.canonicalize import CanonicalizePass
from xdsl.transforms.lower_snitch import LowerSnitchPass
from xdsl.transforms.riscv_register_allocation import RISCVRegisterAllocation
from xdsl.transforms.riscv_scf_loop_range_folding import (
RiscvScfLoopRangeFoldingPass,
)
from xdsl.transforms.snitch_register_allocation import SnitchRegisterAllocation

regalloc_module = riscv_module.clone()

PipelinePass(
[
RISCVRegisterAllocation(),
CanonicalizePass(),
]
).apply(ctx, regalloc_module)

str(regalloc_module)
return (
CanonicalizePass,
ConvertSnitchStreamToSnitch,
LowerSnitchPass,
RISCVRegisterAllocation,
RiscvScfLoopRangeFoldingPass,
SnitchRegisterAllocation,
regalloc_module,
)


@app.cell
def __(CanonicalizePass, ctx, regalloc_module):
from xdsl.backend.riscv.lowering.convert_riscv_scf_to_riscv_cf import (
ConvertRiscvScfToRiscvCfPass,
)
from xdsl.dialects.riscv import riscv_code

assembly_module = regalloc_module.clone()

ConvertRiscvScfToRiscvCfPass().apply(ctx, assembly_module)
CanonicalizePass().apply(ctx, assembly_module)

str(assembly_module)
return ConvertRiscvScfToRiscvCfPass, assembly_module, riscv_code


@app.cell
def __(assembly_module, riscv_code):
assembly = riscv_code(assembly_module)

assembly
return (assembly,)


@app.cell
def __(dims):
from math import prod

from xdsl.interpreters.riscv import RawPtr

n = prod(dims)

lhs = RawPtr.new_float64([i + 1 for i in range(n)])
rhs = RawPtr.new_float64([(i + 1) / 100 for i in range(n)])

lhs.float64.get_list(n), rhs.float64.get_list(n)
return RawPtr, lhs, n, prod, rhs


@app.cell
def __(ctx, lhs, n, rhs, riscv_module):
from xdsl.interpreter import Interpreter
from xdsl.interpreters import register_implementations

interpreter = Interpreter(riscv_module)

register_implementations(interpreter, ctx, include_wgpu=False)

(res,) = interpreter.call_op("main_graph", (lhs, rhs))

res.float64.get_list(n)
return Interpreter, interpreter, register_implementations, res


@app.cell
def __(a):
a
return


if __name__ == "__main__":
app.run()
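The interpreter cell near the end of the notebook feeds `main_graph` with `lhs[i] = i + 1` and `rhs[i] = (i + 1) / 100`, so an elementwise Add should yield `(i + 1) * 1.01` at every position. A stdlib-only reference computation for sanity-checking that output (the `dims` value mirrors the slider default of `a.value == 2`, which is an assumption about the notebook's initial state):

```python
from math import prod

# dims = list(range(2, 2 + a.value)) with the default slider value 2.
dims = [2, 3]
n = prod(dims)

# Same inputs as the notebook's RawPtr buffers, as plain Python floats.
lhs = [i + 1 for i in range(n)]
rhs = [(i + 1) / 100 for i in range(n)]

# What the interpreted main_graph is expected to return elementwise.
expected = [a + b for a, b in zip(lhs, rhs)]
print(expected)
```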
17 changes: 17 additions & 0 deletions tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir
@@ -13,6 +13,16 @@ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1)
linalg.yield %arg3 : f32
}

%2, %3 = "test.op"() : () -> (tensor<2x3xf32>, tensor<2x3xf32>)

%sum = linalg.add ins(%2, %2 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%3 : tensor<2x3xf32>) -> tensor<2x3xf32>

%sum_2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2, %2 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%2 : tensor<2x3xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%acc = arith.addf %in, %in_0 : f32
linalg.yield %acc : f32
} -> tensor<2x3xf32>

// CHECK-NEXT: #map = affine_map<(d0, d1) -> ()>
// CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-NEXT: module {
@@ -25,4 +35,11 @@ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1)
// CHECK-NEXT: ^bb0(%in: f32, %out: f32):
// CHECK-NEXT: linalg.yield %in : f32
// CHECK-NEXT: }
// CHECK-NEXT: %1:2 = "test.op"() : () -> (tensor<2x3xf32>, tensor<2x3xf32>)
// CHECK-NEXT: %2 = linalg.add ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#1 : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %3 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) {
// CHECK-NEXT: ^bb0(%in: f32, %in_0: f32, %out: f32):
// CHECK-NEXT: %4 = arith.addf %in, %in_0 : f32
// CHECK-NEXT: linalg.yield %4 : f32
// CHECK-NEXT: } -> tensor<2x3xf32>
// CHECK-NEXT: }
@@ -19,7 +19,9 @@ memref.store %v0, %m[%i0, %i1] {"nontemporal" = true} : memref<2x3xi32>
%v2 = memref.load %m[%i0, %i1] {"nontemporal" = false} : memref<2x3xi32>
%v3 = memref.load %m[%i0, %i1] {"nontemporal" = true} : memref<2x3xi32>
%r1 = memref.expand_shape %r [[0, 1], [2]] : memref<10x3xi32> into memref<5x2x3xi32>
%r2 = memref.collapse_shape %r [[0, 1]] : memref<10x3xi32> into memref<30xi32>
%r2 = memref.collapse_shape %r [[0, 1]] : memref<10x3xi32> into memref<30xi32>

%bla = tensor.empty() : tensor<2x3xf32>

// CHECK: module {
// CHECK-NEXT: func.func @memref_alloca_scope() {
@@ -39,5 +41,5 @@ memref.store %v0, %m[%i0, %i1] {"nontemporal" = true} : memref<2x3xi32>
// CHECK-NEXT: %{{.*}} = memref.load %3[%1, %2] : memref<2x3xi32>
// CHECK-NEXT: %{{.*}} = memref.load %3[%1, %2] {nontemporal = true} : memref<2x3xi32>
// CHECK-NEXT: %{{.*}} = memref.expand_shape %4 [[0, 1], [2]] : memref<10x3xi32> into memref<5x2x3xi32>
// CHECK-NEXT: %{{.*}} = memref.collapse_shape %4 [[0, 1]] : memref<10x3xi32> into memref<30xi32>
// CHECK-NEXT: %{{.*}} = memref.collapse_shape %4 [[0, 1]] : memref<10x3xi32> into memref<30xi32>
// CHECK-NEXT: }
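The `memref.collapse_shape` line exercised in this test folds groups of source dimensions according to reassociation indices: each group collapses into the product of its extents. A small stdlib-only sketch of that shape computation (illustrative only, not part of the PR):

```python
from math import prod

# Collapse a shape per reassociation groups, as in
#   memref.collapse_shape %r [[0, 1]] : memref<10x3xi32> into memref<30xi32>
# where source dims 0 and 1 fold into a single dim of extent 10 * 3.
def collapse_shape(shape, reassociation):
    return [prod(shape[i] for i in group) for group in reassociation]

print(collapse_shape([10, 3], [[0, 1]]))         # → [30], the case above
print(collapse_shape([5, 2, 3], [[0, 1], [2]]))  # → [10, 3], inverse of the expand_shape case
```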