[RISCV] Use vnclip for scalable vector saturating truncation. (llvm#88648)

Similar to llvm#75145, but for scalable vectors.

Specifically, this patch handles the kind of saturating-truncation pattern shown below:

## Source Code
```
define void @trunc_sat_i8i16_maxmin(ptr %x, ptr %y) {
  %1 = load <vscale x 4 x i16>, ptr %x, align 16
  %2 = tail call <vscale x 4 x i16> @llvm.smax.nxv4i16(<vscale x 4 x i16> %1, <vscale x 4 x i16> splat (i16 -128))
  %3 = tail call <vscale x 4 x i16> @llvm.smin.nxv4i16(<vscale x 4 x i16> %2, <vscale x 4 x i16> splat (i16 127))
  %4 = trunc <vscale x 4 x i16> %3 to <vscale x 4 x i8>
  store <vscale x 4 x i8> %4, ptr %y, align 8
  ret void
}
```
## Before this patch
[Compiler Explorer](https://godbolt.org/z/EKc9eGvo8)
```
trunc_sat_i8i16_maxmin:
        vl1re16.v       v8, (a0)
        li      a0, -128
        vsetvli a2, zero, e16, m1, ta, ma
        vmax.vx v8, v8, a0
        li      a0, 127
        vmin.vx v8, v8, a0
        vsetvli zero, zero, e8, mf2, ta, ma
        vnsrl.wi        v8, v8, 0
        vse8.v  v8, (a1)
        ret
```
## After this patch
```
trunc_sat_i8i16_maxmin:
        vsetivli zero, 4, e8, mf4, ta, ma
        vle16.v v8, (a0)
        vnclip.wi v8, v8, 0
        vse8.v v8, (a1)
        ret
```
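The new patterns also match the commuted clamp order (`smin` applied before `smax`). A minimal hand-written variant of the example above (my own sketch, not taken from the patch or its tests) would be:

```
declare <vscale x 4 x i16> @llvm.smin.nxv4i16(<vscale x 4 x i16>, <vscale x 4 x i16>)
declare <vscale x 4 x i16> @llvm.smax.nxv4i16(<vscale x 4 x i16>, <vscale x 4 x i16>)

define void @trunc_sat_i8i16_minmax(ptr %x, ptr %y) {
  ; Clamp to [-128, 127] with smin first, then smax, then truncate to i8.
  %1 = load <vscale x 4 x i16>, ptr %x, align 16
  %2 = tail call <vscale x 4 x i16> @llvm.smin.nxv4i16(<vscale x 4 x i16> %1, <vscale x 4 x i16> splat (i16 127))
  %3 = tail call <vscale x 4 x i16> @llvm.smax.nxv4i16(<vscale x 4 x i16> %2, <vscale x 4 x i16> splat (i16 -128))
  %4 = trunc <vscale x 4 x i16> %3 to <vscale x 4 x i8>
  store <vscale x 4 x i8> %4, ptr %y, align 8
  ret void
}
```

This should likewise select a single `vnclip.wi` with a shift amount of 0, since the `.td` change adds patterns for both the `smin(smax(...))` and the `smax(smin(...))` forms around the truncate.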
sun-jacobi authored Apr 18, 2024
1 parent 808d794 commit 0afc884
Showing 4 changed files with 455 additions and 61 deletions.
41 changes: 41 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -1166,6 +1166,47 @@ defm : VPatBinarySDNode_VV_VX<usubsat, "PseudoVSSUBU">;
defm : VPatAVGADD_VV_VX_RM<avgflooru, 0b10>;
defm : VPatAVGADD_VV_VX_RM<avgceilu, 0b00>;

// 12.5. Vector Narrowing Fixed-Point Clip Instructions
multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
defvar sew = vti.SEW;
defvar uminval = !sub(!shl(1, sew), 1);
defvar sminval = !sub(!shl(1, !sub(sew, 1)), 1);
defvar smaxval = !sub(0, !shl(1, !sub(sew, 1)));

let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
def : Pat<(vti.Vector (riscv_trunc_vector_vl
(wti.Vector (smin
(wti.Vector (smax (wti.Vector wti.RegClass:$rs1),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))))),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))))),
(vti.Mask V0), VLOpFrag)),
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;

def : Pat<(vti.Vector (riscv_trunc_vector_vl
(wti.Vector (smax
(wti.Vector (smin (wti.Vector wti.RegClass:$rs1),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))))),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))))),
(vti.Mask V0), VLOpFrag)),
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;

def : Pat<(vti.Vector (riscv_trunc_vector_vl
(wti.Vector (umin (wti.Vector wti.RegClass:$rs1),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))))), (vti.Mask V0), VLOpFrag)),
(!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
}
}

foreach vtiToWti = AllWidenableIntVectors in
defm : VPatTruncSatClipSDNode<vtiToWti.Vti, vtiToWti.Wti>;

// 15. Vector Mask Instructions

// 15.1. Vector Mask-Register Logical Instructions
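As a side note (the arithmetic is mine, not part of the diff), the `defvar` expressions in `VPatTruncSatClipSDNode` compute the saturation bounds of the narrow element type. For `SEW = 8` they evaluate to

```math
\mathrm{uminval} = 2^{8}-1 = 255,\qquad \mathrm{sminval} = 2^{7}-1 = 127,\qquad \mathrm{smaxval} = -2^{7} = -128
```

which are exactly the splat constants the matched `umin`/`smin`/`smax` must carry for the clamp-plus-truncate to be a saturating narrowing, and hence replaceable by `vnclip.wi`/`vnclipu.wi` with a shift of 0.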
66 changes: 29 additions & 37 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2373,30 +2373,41 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10>;
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00>;

// 12.5. Vector Narrowing Fixed-Point Clip Instructions
class VPatTruncSatClipMaxMinBase<string inst,
VTypeInfo vti,
VTypeInfo wti,
SDPatternOperator op1,
int op1_value,
SDPatternOperator op2,
int op2_value> :
Pat<(vti.Vector (riscv_trunc_vector_vl
(wti.Vector (op1
(wti.Vector (op2
multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
defvar sew = vti.SEW;
defvar uminval = !sub(!shl(1, sew), 1);
defvar sminval = !sub(!shl(1, !sub(sew, 1)), 1);
defvar smaxval = !sub(0, !shl(1, !sub(sew, 1)));

let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
def : Pat<(vti.Vector (riscv_trunc_vector_vl
(wti.Vector (riscv_smin_vl
(wti.Vector (riscv_smax_vl
(wti.Vector wti.RegClass:$rs1),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), op2_value, (XLenVT srcvalue))),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))),
(wti.Vector undef),(wti.Mask V0), VLOpFrag)),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), op1_value, (XLenVT srcvalue))),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))),
(wti.Vector undef), (wti.Mask V0), VLOpFrag)),
(vti.Mask V0), VLOpFrag)),
(!cast<Instruction>(inst#"_WI_"#vti.LMul.MX#"_MASK")
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;

class VPatTruncSatClipUMin<VTypeInfo vti,
VTypeInfo wti,
int uminval> :
Pat<(vti.Vector (riscv_trunc_vector_vl
def : Pat<(vti.Vector (riscv_trunc_vector_vl
(wti.Vector (riscv_smax_vl
(wti.Vector (riscv_smin_vl
(wti.Vector wti.RegClass:$rs1),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))),
(wti.Vector undef),(wti.Mask V0), VLOpFrag)),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))),
(wti.Vector undef), (wti.Mask V0), VLOpFrag)),
(vti.Mask V0), VLOpFrag)),
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;

def : Pat<(vti.Vector (riscv_trunc_vector_vl
(wti.Vector (riscv_umin_vl
(wti.Vector wti.RegClass:$rs1),
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))),
@@ -2405,30 +2416,11 @@ class VPatTruncSatClipUMin<VTypeInfo vti,
(!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;

multiclass VPatTruncSatClipMaxMin<string inst, VTypeInfo vti, VTypeInfo wti,
SDPatternOperator max, int maxval, SDPatternOperator min, int minval> {
def : VPatTruncSatClipMaxMinBase<inst, vti, wti, max, maxval, min, minval>;
def : VPatTruncSatClipMaxMinBase<inst, vti, wti, min, minval, max, maxval>;
}

multiclass VPatTruncSatClip<VTypeInfo vti, VTypeInfo wti> {
defvar sew = vti.SEW;
defvar uminval = !sub(!shl(1, sew), 1);
defvar sminval = !sub(!shl(1, !sub(sew, 1)), 1);
defvar smaxval = !sub(0, !shl(1, !sub(sew, 1)));

let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
defm : VPatTruncSatClipMaxMin<"PseudoVNCLIP", vti, wti, riscv_smin_vl,
sminval, riscv_smax_vl, smaxval>;
def : VPatTruncSatClipUMin<vti, wti, uminval>;
}

}

foreach vtiToWti = AllWidenableIntVectors in
defm : VPatTruncSatClip<vtiToWti.Vti, vtiToWti.Wti>;
defm : VPatTruncSatClipVL<vtiToWti.Vti, vtiToWti.Wti>;

// 13. Vector Floating-Point Instructions

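For the VL patterns above, the change is the same consolidation: the old `VPatTruncSatClipMaxMinBase`, `VPatTruncSatClipUMin`, `VPatTruncSatClipMaxMin`, and `VPatTruncSatClip` helpers are folded into a single `VPatTruncSatClipVL` multiclass that computes `uminval`/`sminval`/`smaxval` from `vti.SEW` and defines the two signed clamp orders plus the unsigned `umin`-only form, mirroring the new `VPatTruncSatClipSDNode` multiclass.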
Expand Up @@ -8,11 +8,8 @@ declare <4 x i32> @llvm.smin.v4i32(<4 x i32>, <4 x i32>)
declare <4 x i64> @llvm.smax.v4i64(<4 x i64>, <4 x i64>)
declare <4 x i64> @llvm.smin.v4i64(<4 x i64>, <4 x i64>)

declare <4 x i16> @llvm.umax.v4i16(<4 x i16>, <4 x i16>)
declare <4 x i16> @llvm.umin.v4i16(<4 x i16>, <4 x i16>)
declare <4 x i32> @llvm.umax.v4i32(<4 x i32>, <4 x i32>)
declare <4 x i32> @llvm.umin.v4i32(<4 x i32>, <4 x i32>)
declare <4 x i64> @llvm.umax.v4i64(<4 x i64>, <4 x i64>)
declare <4 x i64> @llvm.umin.v4i64(<4 x i64>, <4 x i64>)

define void @trunc_sat_i8i16_maxmin(ptr %x, ptr %y) {
@@ -110,10 +107,9 @@ define void @trunc_sat_u8u16_maxmin(ptr %x, ptr %y) {
; CHECK-NEXT: vse8.v v8, (a1)
; CHECK-NEXT: ret
%1 = load <4 x i16>, ptr %x, align 16
%2 = tail call <4 x i16> @llvm.umax.v4i16(<4 x i16> %1, <4 x i16> <i16 0, i16 0, i16 0, i16 0>)
%3 = tail call <4 x i16> @llvm.umin.v4i16(<4 x i16> %2, <4 x i16> <i16 255, i16 255, i16 255, i16 255>)
%4 = trunc <4 x i16> %3 to <4 x i8>
store <4 x i8> %4, ptr %y, align 8
%2 = tail call <4 x i16> @llvm.umin.v4i16(<4 x i16> %1, <4 x i16> <i16 255, i16 255, i16 255, i16 255>)
%3 = trunc <4 x i16> %2 to <4 x i8>
store <4 x i8> %3, ptr %y, align 8
ret void
}

@@ -127,9 +123,8 @@ define void @trunc_sat_u8u16_minmax(ptr %x, ptr %y) {
; CHECK-NEXT: ret
%1 = load <4 x i16>, ptr %x, align 16
%2 = tail call <4 x i16> @llvm.umin.v4i16(<4 x i16> %1, <4 x i16> <i16 255, i16 255, i16 255, i16 255>)
%3 = tail call <4 x i16> @llvm.umax.v4i16(<4 x i16> %2, <4 x i16> <i16 0, i16 0, i16 0, i16 0>)
%4 = trunc <4 x i16> %3 to <4 x i8>
store <4 x i8> %4, ptr %y, align 8
%3 = trunc <4 x i16> %2 to <4 x i8>
store <4 x i8> %3, ptr %y, align 8
ret void
}

@@ -231,10 +226,9 @@ define void @trunc_sat_u16u32_minmax(ptr %x, ptr %y) {
; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: ret
%1 = load <4 x i32>, ptr %x, align 32
%2 = tail call <4 x i32> @llvm.umax.v4i32(<4 x i32> %1, <4 x i32> <i32 0, i32 0, i32 0, i32 0>)
%3 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %2, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>)
%4 = trunc <4 x i32> %3 to <4 x i16>
store <4 x i16> %4, ptr %y, align 16
%2 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %1, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>)
%3 = trunc <4 x i32> %2 to <4 x i16>
store <4 x i16> %3, ptr %y, align 16
ret void
}

@@ -248,9 +242,8 @@ define void @trunc_sat_u16u32_maxmin(ptr %x, ptr %y) {
; CHECK-NEXT: ret
%1 = load <4 x i32>, ptr %x, align 32
%2 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %1, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>)
%3 = tail call <4 x i32> @llvm.umax.v4i32(<4 x i32> %2, <4 x i32> <i32 0, i32 0, i32 0, i32 0>)
%4 = trunc <4 x i32> %3 to <4 x i16>
store <4 x i16> %4, ptr %y, align 16
%3 = trunc <4 x i32> %2 to <4 x i16>
store <4 x i16> %3, ptr %y, align 16
ret void
}

@@ -355,10 +348,9 @@ define void @trunc_sat_u32u64_maxmin(ptr %x, ptr %y) {
; CHECK-NEXT: vse32.v v10, (a1)
; CHECK-NEXT: ret
%1 = load <4 x i64>, ptr %x, align 64
%2 = tail call <4 x i64> @llvm.umax.v4i64(<4 x i64> %1, <4 x i64> <i64 0, i64 0, i64 0, i64 0>)
%3 = tail call <4 x i64> @llvm.umin.v4i64(<4 x i64> %2, <4 x i64> <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>)
%4 = trunc <4 x i64> %3 to <4 x i32>
store <4 x i32> %4, ptr %y, align 32
%2 = tail call <4 x i64> @llvm.umin.v4i64(<4 x i64> %1, <4 x i64> <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>)
%3 = trunc <4 x i64> %2 to <4 x i32>
store <4 x i32> %3, ptr %y, align 32
ret void
}

@@ -372,8 +364,7 @@ define void @trunc_sat_u32u64_minmax(ptr %x, ptr %y) {
; CHECK-NEXT: ret
%1 = load <4 x i64>, ptr %x, align 64
%2 = tail call <4 x i64> @llvm.umin.v4i64(<4 x i64> %1, <4 x i64> <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>)
%3 = tail call <4 x i64> @llvm.umax.v4i64(<4 x i64> %2, <4 x i64> <i64 0, i64 0, i64 0, i64 0>)
%4 = trunc <4 x i64> %3 to <4 x i32>
store <4 x i32> %4, ptr %y, align 32
%3 = trunc <4 x i64> %2 to <4 x i32>
store <4 x i32> %3, ptr %y, align 32
ret void
}
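On the unsigned side, the updated tests above drop the `umax` with zero (a no-op for unsigned values), leaving only the `umin` clamp. A scalable-vector version of that shape, which the new `umin`-only pattern is meant to catch (again a hand-written sketch with a made-up function name, not copied from the patch's tests), would be:

```
declare <vscale x 4 x i16> @llvm.umin.nxv4i16(<vscale x 4 x i16>, <vscale x 4 x i16>)

define void @trunc_sat_u8u16_nxv4(ptr %x, ptr %y) {
  ; Clamp to [0, 255] via umin only, then truncate to i8.
  %1 = load <vscale x 4 x i16>, ptr %x, align 16
  %2 = tail call <vscale x 4 x i16> @llvm.umin.nxv4i16(<vscale x 4 x i16> %1, <vscale x 4 x i16> splat (i16 255))
  %3 = trunc <vscale x 4 x i16> %2 to <vscale x 4 x i8>
  store <vscale x 4 x i8> %3, ptr %y, align 8
  ret void
}
```

With the `PseudoVNCLIPU_WI` pattern added here, this should lower to a single `vnclipu.wi` rather than a separate unsigned-min clamp followed by a narrowing shift.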