From 987abb56d45a710d16985c64b194cae342c1a191 Mon Sep 17 00:00:00 2001 From: Federico Ficarelli <1379149+nazavode@users.noreply.github.com> Date: Wed, 10 Jul 2024 17:32:55 +0200 Subject: [PATCH] dialects: (riscv_snitch) Add f32 mul, add, pack from Snitch packed SIMD extension (#2872) This PR adds a small set of operations to `riscv_snitch` that map directly onto the Snitch custom packed SIMD ISA. Further ops will need operand constraints in order to be register-allocated correctly. --- .../riscv_snitch/assembly_emission.mlir | 8 ++ .../filecheck/dialects/riscv_snitch/ops.mlir | 23 ++++++ xdsl/dialects/riscv_snitch.py | 80 +++++++++++++++++++ 3 files changed, 111 insertions(+) diff --git a/tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir b/tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir index 0ba89d799f..197d222214 100644 --- a/tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir +++ b/tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir @@ -22,6 +22,11 @@ riscv_func.func @main() { %4 = riscv_snitch.dmstat %3 : (!riscv.reg) -> !riscv.reg %5 = riscv_snitch.dmstati 22 : () -> !riscv.reg + %f0 = riscv.get_float_register : !riscv.freg + %f1 = riscv_snitch.vfmul.s %f0, %f0 : (!riscv.freg, !riscv.freg) -> !riscv.freg + %f2 = riscv_snitch.vfadd.s %f0, %f0 : (!riscv.freg, !riscv.freg) -> !riscv.freg + %f3 = riscv_snitch.vfcpka.s.s %f0, %f0 : (!riscv.freg, !riscv.freg) -> !riscv.freg + riscv_func.return } @@ -36,4 +41,7 @@ riscv_func.func @main() { // CHECK-NEXT: dmcpy a3, a0, a2 // CHECK-NEXT: dmstat a4, a3 // CHECK-NEXT: dmstati a5, 22 +// CHECK-NEXT: vfmul.s ft1, ft0, ft0 +// CHECK-NEXT: vfadd.s ft1, ft0, ft0 +// CHECK-NEXT: vfcpka.s.s ft1, ft0, ft0 // CHECK-NEXT: ret diff --git a/tests/filecheck/dialects/riscv_snitch/ops.mlir b/tests/filecheck/dialects/riscv_snitch/ops.mlir index cfd41f4bdc..1cc39c1965 100644 --- a/tests/filecheck/dialects/riscv_snitch/ops.mlir +++ b/tests/filecheck/dialects/riscv_snitch/ops.mlir @@ -86,6 
+86,22 @@ riscv_func.func @xdma() { riscv_func.return } +riscv_func.func @simd() { + %v = riscv.get_float_register : !riscv.freg + // CHECK: %v = riscv.get_float_register : !riscv.freg + + %0 = riscv_snitch.vfmul.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg + // CHECK-NEXT: %0 = riscv_snitch.vfmul.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg + + %1 = riscv_snitch.vfadd.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg + // CHECK-NEXT: %1 = riscv_snitch.vfadd.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg + + %2 = riscv_snitch.vfcpka.s.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg + // CHECK-NEXT: %2 = riscv_snitch.vfcpka.s.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg + + riscv_func.return +} + // CHECK-GENERIC-NEXT: "builtin.module"() ({ // CHECK-GENERIC-NEXT: "riscv_func.func"() ({ @@ -129,6 +145,13 @@ riscv_func.func @xdma() { // CHECK-GENERIC-NEXT: %{{.*}} = "riscv_snitch.dmstati"() <{"status" = 0 : ui5}> : () -> !riscv.reg // CHECK-GENERIC-NEXT: "riscv_func.return"() : () -> () // CHECK-GENERIC-NEXT: }) {"sym_name" = "xdma", "function_type" = () -> ()} : () -> () +// CHECK-GENERIC-NEXT: "riscv_func.func"() ({ +// CHECK-GENERIC-NEXT: %v = "riscv.get_float_register"() : () -> !riscv.freg +// CHECK-GENERIC-NEXT: %0 = "riscv_snitch.vfmul.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg +// CHECK-GENERIC-NEXT: %1 = "riscv_snitch.vfadd.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg +// CHECK-GENERIC-NEXT: %2 = "riscv_snitch.vfcpka.s.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg +// CHECK-GENERIC-NEXT: "riscv_func.return"() : () -> () +// CHECK-GENERIC-NEXT: }) {"sym_name" = "simd", "function_type" = () -> ()} : () -> () // CHECK-GENERIC-NEXT: }) : () -> () diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index c12d1391c6..6179740796 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -16,6 +16,7 @@ ) from xdsl.dialects.riscv import ( 
AssemblyInstructionArg, + FloatRegisterType, IntRegisterType, RdRsImmIntegerOperation, RdRsRsOperation, @@ -675,6 +676,82 @@ def parse(cls, parser: Parser) -> Self: return op +# endregion + +# region Snitch Packed SIMD Extension + +# Operations that map directly to the packed SIMD ISA provided by the Snitch FPU. +# The implemented ISA is *almost* the one specified here: +# * https://iis-git.ee.ethz.ch/smach/smallFloat-spec/-/blob/master/smallFloat_isa.pdf +# Beware of the main undocumented differences from the spec: +# * Additional reductions (e.g.: vfsum.*) +# * Missing reductions (e.g.: vfdotp.*) +# * Control of alternative FP formats (e.g.: IEEE fp16 vs BF16) delegated to the +# RISC-V float CSR instead of being part of the encoding + + +@irdl_op_definition +class VFCpkASSOp( + RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType] +): + """ + Packs the two scalar f32 values from rs1 and rs2 as two adjacent + entries into the vectorial 2xf32 rd operand, such as: + + f[rd][lo] = f[rs1] + f[rd][hi] = f[rs2] + """ + + name = "riscv_snitch.vfcpka.s.s" + + def assembly_instruction_name(self) -> str: + return "vfcpka.s.s" + + traits = frozenset((Pure(),)) + + +@irdl_op_definition +class VFMulSOp( + RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType] +): + """ + Performs vectorial multiplication of corresponding f32 values from + rs1 and rs2 and stores the results in the corresponding f32 lanes + into the vectorial 2xf32 rd operand, such as: + + f[rd][lo] = f[rs1][lo] * f[rs2][lo] + f[rd][hi] = f[rs1][hi] * f[rs2][hi] + """ + + name = "riscv_snitch.vfmul.s" + + def assembly_instruction_name(self) -> str: + return "vfmul.s" + + traits = frozenset((Pure(),)) + + +@irdl_op_definition +class VFAddSOp( + RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType] +): + """ + Performs vectorial addition of corresponding f32 values from + rs1 and rs2 and stores the results in the corresponding f32 lanes + into the 
vectorial 2xf32 rd operand, such as: + + f[rd][lo] = f[rs1][lo] + f[rs2][lo] + f[rd][hi] = f[rs1][hi] + f[rs2][hi] + """ + + name = "riscv_snitch.vfadd.s" + + def assembly_instruction_name(self) -> str: + return "vfadd.s" + + traits = frozenset((Pure(),)) + + # endregion RISCV_Snitch = Dialect( @@ -696,6 +773,9 @@ def parse(cls, parser: Parser) -> Self: DMCopyImmOp, DMStatOp, DMStatImmOp, + VFMulSOp, + VFAddSOp, + VFCpkASSOp, ], [], )