Skip to content

Commit

Permalink
dialects: (riscv_snitch) Add f32 mul, add, pack from Snitch packed SI…
Browse files Browse the repository at this point in the history
…MD extension (#2872)

This PR adds a small set of operations to `riscv_snitch` that map directly
onto the Snitch custom packed SIMD ISA. Further ops will need operand
constraints in order to be register allocated correctly.
  • Loading branch information
nazavode authored Jul 10, 2024
1 parent 57e3075 commit 987abb5
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ riscv_func.func @main() {
%4 = riscv_snitch.dmstat %3 : (!riscv.reg<a3>) -> !riscv.reg<a4>
%5 = riscv_snitch.dmstati 22 : () -> !riscv.reg<a5>

%f0 = riscv.get_float_register : !riscv.freg<ft0>
%f1 = riscv_snitch.vfmul.s %f0, %f0 : (!riscv.freg<ft0>, !riscv.freg<ft0>) -> !riscv.freg<ft1>
%f2 = riscv_snitch.vfadd.s %f0, %f0 : (!riscv.freg<ft0>, !riscv.freg<ft0>) -> !riscv.freg<ft1>
%f3 = riscv_snitch.vfcpka.s.s %f0, %f0 : (!riscv.freg<ft0>, !riscv.freg<ft0>) -> !riscv.freg<ft1>

riscv_func.return
}

Expand All @@ -36,4 +41,7 @@ riscv_func.func @main() {
// CHECK-NEXT: dmcpy a3, a0, a2
// CHECK-NEXT: dmstat a4, a3
// CHECK-NEXT: dmstati a5, 22
// CHECK-NEXT: vfmul.s ft1, ft0, ft0
// CHECK-NEXT: vfadd.s ft1, ft0, ft0
// CHECK-NEXT: vfcpka.s.s ft1, ft0, ft0
// CHECK-NEXT: ret
23 changes: 23 additions & 0 deletions tests/filecheck/dialects/riscv_snitch/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ riscv_func.func @xdma() {
riscv_func.return
}

// Exercises the Snitch packed SIMD ops (vfmul.s, vfadd.s, vfcpka.s.s) in custom syntax.
// Every value uses the float register type without a register name (!riscv.freg), and
// each op is expected to be printed back exactly as it is written here.
riscv_func.func @simd() {
%v = riscv.get_float_register : !riscv.freg
// CHECK: %v = riscv.get_float_register : !riscv.freg

%0 = riscv_snitch.vfmul.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-NEXT: %0 = riscv_snitch.vfmul.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg

%1 = riscv_snitch.vfadd.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-NEXT: %1 = riscv_snitch.vfadd.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg

%2 = riscv_snitch.vfcpka.s.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-NEXT: %2 = riscv_snitch.vfcpka.s.s %v, %v : (!riscv.freg, !riscv.freg) -> !riscv.freg

riscv_func.return
}


// CHECK-GENERIC-NEXT: "builtin.module"() ({
// CHECK-GENERIC-NEXT: "riscv_func.func"() ({
Expand Down Expand Up @@ -129,6 +145,13 @@ riscv_func.func @xdma() {
// CHECK-GENERIC-NEXT: %{{.*}} = "riscv_snitch.dmstati"() <{"status" = 0 : ui5}> : () -> !riscv.reg
// CHECK-GENERIC-NEXT: "riscv_func.return"() : () -> ()
// CHECK-GENERIC-NEXT: }) {"sym_name" = "xdma", "function_type" = () -> ()} : () -> ()
// CHECK-GENERIC-NEXT: "riscv_func.func"() ({
// CHECK-GENERIC-NEXT: %v = "riscv.get_float_register"() : () -> !riscv.freg
// CHECK-GENERIC-NEXT: %0 = "riscv_snitch.vfmul.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-GENERIC-NEXT: %1 = "riscv_snitch.vfadd.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-GENERIC-NEXT: %2 = "riscv_snitch.vfcpka.s.s"(%v, %v) : (!riscv.freg, !riscv.freg) -> !riscv.freg
// CHECK-GENERIC-NEXT: "riscv_func.return"() : () -> ()
// CHECK-GENERIC-NEXT: }) {"sym_name" = "simd", "function_type" = () -> ()} : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()


80 changes: 80 additions & 0 deletions xdsl/dialects/riscv_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from xdsl.dialects.riscv import (
AssemblyInstructionArg,
FloatRegisterType,
IntRegisterType,
RdRsImmIntegerOperation,
RdRsRsOperation,
Expand Down Expand Up @@ -675,6 +676,82 @@ def parse(cls, parser: Parser) -> Self:
return op


# endregion

# region Snitch Packed SIMD Extension

# Operations that map directly to the packed SIMD ISA provided by Snitch FPU.
# The implemented ISA is *almost* the one specified here:
# * https://iis-git.ee.ethz.ch/smach/smallFloat-spec/-/blob/master/smallFloat_isa.pdf
# Beware of the main undocumented differences from the spec:
# * Additional reductions (e.g.: vfsum.*)
# * Missing reductions (e.g.: vfdotp.*)
# * Control of alternative FP formats (e.g.: IEEE fp16 vs BF16) delegated to the
# RISC-V float CSR instead of being part of the encoding


@irdl_op_definition
class VFCpkASSOp(
    RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType]
):
    """
    Pack the scalar f32 values held in rs1 and rs2 into the two adjacent
    f32 lanes of the vectorial 2xf32 register rd:

        f[rd][lo] = f[rs1]
        f[rd][hi] = f[rs2]
    """

    name = "riscv_snitch.vfcpka.s.s"

    traits = frozenset([Pure()])

    def assembly_instruction_name(self) -> str:
        # The emitted mnemonic is the op name without the dialect prefix.
        return "vfcpka.s.s"


@irdl_op_definition
class VFMulSOp(
    RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType]
):
    """
    Lane-wise f32 multiplication: each f32 lane of rs1 is multiplied by the
    matching lane of rs2 and the product is written to the same lane of the
    vectorial 2xf32 register rd:

        f[rd][lo] = f[rs1][lo] * f[rs2][lo]
        f[rd][hi] = f[rs1][hi] * f[rs2][hi]
    """

    name = "riscv_snitch.vfmul.s"

    traits = frozenset([Pure()])

    def assembly_instruction_name(self) -> str:
        # The emitted mnemonic is the op name without the dialect prefix.
        return "vfmul.s"


@irdl_op_definition
class VFAddSOp(
    RdRsRsOperation[FloatRegisterType, FloatRegisterType, FloatRegisterType]
):
    """
    Lane-wise f32 addition: each f32 lane of rs1 is added to the matching
    lane of rs2 and the sum is written to the same lane of the vectorial
    2xf32 register rd:

        f[rd][lo] = f[rs1][lo] + f[rs2][lo]
        f[rd][hi] = f[rs1][hi] + f[rs2][hi]
    """

    name = "riscv_snitch.vfadd.s"

    traits = frozenset([Pure()])

    def assembly_instruction_name(self) -> str:
        # The emitted mnemonic is the op name without the dialect prefix.
        return "vfadd.s"


# endregion

RISCV_Snitch = Dialect(
Expand All @@ -696,6 +773,9 @@ def parse(cls, parser: Parser) -> Self:
DMCopyImmOp,
DMStatOp,
DMStatImmOp,
VFMulSOp,
VFAddSOp,
VFCpkASSOp,
],
[],
)

0 comments on commit 987abb5

Please sign in to comment.