Skip to content

Commit

Permalink
[naga] Implement quantizeToF16 (#6519)
Browse files Browse the repository at this point in the history
Implement WGSL frontend and WGSL, SPIR-V, HLSL, MSL, and GLSL
backends. WGSL and SPIR-V backends natively support the instruction.
MSL and HLSL emulate it by casting to f16 and back to f32. GLSL does
similar but must (mis)use (un)pack2x16 to do so.
  • Loading branch information
jamienicol authored Nov 12, 2024
1 parent 6a60458 commit cffc793
Show file tree
Hide file tree
Showing 17 changed files with 206 additions and 76 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Bottom level categories:
- Parse `diagnostic(…)` directives, but don't implement any triggering rules yet. By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456).
- Fix an issue where `naga` CLI would incorrectly skip the first positional argument when `--stdin-file-path` was specified. By @ErichDonGubler in [#6480](https://github.com/gfx-rs/wgpu/pull/6480).
- Fix textureNumLevels in the GLSL backend. By @magcius in [#6483](https://github.com/gfx-rs/wgpu/pull/6483).
- Implement `quantizeToF16()` for WGSL frontend, and WGSL, SPIR-V, HLSL, MSL, and GLSL backends. By @jamienicol in [#6519](https://github.com/gfx-rs/wgpu/pull/6519).

#### General

Expand Down
45 changes: 44 additions & 1 deletion naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,8 @@ impl<'a, W: Write> Writer<'a, W> {
crate::MathFunction::Pack4xI8
| crate::MathFunction::Pack4xU8
| crate::MathFunction::Unpack4xI8
| crate::MathFunction::Unpack4xU8 => {
| crate::MathFunction::Unpack4xU8
| crate::MathFunction::QuantizeToF16 => {
self.need_bake_expressions.insert(arg);
}
crate::MathFunction::ExtractBits => {
Expand Down Expand Up @@ -3495,6 +3496,48 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Inverse => "inverse",
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
Mf::QuantizeToF16 => match *ctx.resolve_type(arg, &self.module.types) {
crate::TypeInner::Scalar { .. } => {
write!(self.out, "unpackHalf2x16(packHalf2x16(vec2(")?;
self.write_expr(arg, ctx)?;
write!(self.out, "))).x")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Bi,
..
} => {
write!(self.out, "unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, "))")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Tri,
..
} => {
write!(self.out, "vec3(unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".zz)).x)")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Quad,
..
} => {
write!(self.out, "vec4(unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".zw)))")?;
return Ok(());
}
_ => unreachable!(
"Correct TypeInner for QuantizeToF16 should be already validated"
),
},
// bits
Mf::CountTrailingZeros => {
match *ctx.resolve_type(arg, &self.module.types) {
Expand Down
7 changes: 7 additions & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3036,6 +3036,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Unpack4x8unorm,
Unpack4xI8,
Unpack4xU8,
QuantizeToF16,
Regular(&'static str),
MissingIntOverload(&'static str),
MissingIntReturnType(&'static str),
Expand Down Expand Up @@ -3102,6 +3103,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
//Mf::Inverse =>,
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
Mf::QuantizeToF16 => Function::QuantizeToF16,
// bits
Mf::CountTrailingZeros => Function::CountTrailingZeros,
Mf::CountLeadingZeros => Function::CountLeadingZeros,
Expand Down Expand Up @@ -3303,6 +3305,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24")?;
}
Function::QuantizeToF16 => {
write!(self.out, "f16tof32(f32tof16(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
}
Function::Regular(fun_name) => {
write!(self.out, "{fun_name}(")?;
self.write_expr(module, arg, func_ctx)?;
Expand Down
17 changes: 17 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1936,6 +1936,7 @@ impl<W: Write> Writer<W> {
Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
Mf::QuantizeToF16 => "",
// bits
Mf::CountTrailingZeros => "ctz",
Mf::CountLeadingZeros => "clz",
Expand Down Expand Up @@ -2144,6 +2145,22 @@ impl<W: Write> Writer<W> {
self.put_expression(arg, context, true)?;
write!(self.out, " >> 24) << 24 >> 24")?;
}
Mf::QuantizeToF16 => {
match *context.resolve_type(arg) {
crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
crate::TypeInner::Vector { size, .. } => write!(
self.out,
"{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
size = back::vector_size_str(size),
)?,
_ => unreachable!(
"Correct TypeInner for QuantizeToF16 should be already validated"
),
};

self.put_expression(arg, context, true)?;
write!(self.out, "))")?;
}
_ => {
write!(self.out, "{NAMESPACE}::{fun_name}")?;
self.put_call_parameters(
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,12 @@ impl<'w> BlockContext<'w> {
arg0_id,
)),
Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
spirv::Op::QuantizeToF16,
result_type_id,
id,
arg0_id,
)),
Mf::ReverseBits => MathOp::Custom(Instruction::unary(
spirv::Op::BitReverse,
result_type_id,
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,7 @@ impl<W: Write> Writer<W> {
Mf::InverseSqrt => Function::Regular("inverseSqrt"),
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
Mf::QuantizeToF16 => Function::Regular("quantizeToF16"),
// bits
Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
Expand Down
1 change: 1 addition & 0 deletions naga/src/front/wgsl/parse/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
"inverseSqrt" => Mf::InverseSqrt,
"transpose" => Mf::Transpose,
"determinant" => Mf::Determinant,
"quantizeToF16" => Mf::QuantizeToF16,
// bits
"countTrailingZeros" => Mf::CountTrailingZeros,
"countLeadingZeros" => Mf::CountLeadingZeros,
Expand Down
1 change: 1 addition & 0 deletions naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,7 @@ pub enum MathFunction {
Inverse,
Transpose,
Determinant,
QuantizeToF16,
// bits
CountTrailingZeros,
CountLeadingZeros,
Expand Down
1 change: 1 addition & 0 deletions naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ impl super::MathFunction {
Self::Inverse => 1,
Self::Transpose => 1,
Self::Determinant => 1,
Self::QuantizeToF16 => 1,
// bits
Self::CountTrailingZeros => 1,
Self::CountLeadingZeros => 1,
Expand Down
3 changes: 2 additions & 1 deletion naga/src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,8 @@ impl<'a> ResolveContext<'a> {
| Mf::Exp2
| Mf::Log
| Mf::Log2
| Mf::Pow => res_arg.clone(),
| Mf::Pow
| Mf::QuantizeToF16 => res_arg.clone(),
Mf::Modf | Mf::Frexp => {
let (size, width) = match res_arg.inner_with(types) {
&Ti::Scalar(crate::Scalar {
Expand Down
20 changes: 20 additions & 0 deletions naga/src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1363,6 +1363,26 @@ impl super::Validator {
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::QuantizeToF16 => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Float,
width: 4,
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Float,
width: 4,
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
// Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
Mf::CountLeadingZeros
| Mf::CountTrailingZeros
Expand Down
4 changes: 4 additions & 0 deletions naga/tests/in/math-functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,8 @@ fn main() {
let frexp_b = frexp(1.5).fract;
let frexp_c: i32 = frexp(1.5).exp;
let frexp_d: i32 = frexp(vec4(1.5, 1.5, 1.5, 1.5)).exp.x;
let quantizeToF16_a: f32 = quantizeToF16(1.0);
let quantizeToF16_b: vec2<f32> = quantizeToF16(vec2(1.0, 1.0));
let quantizeToF16_c: vec3<f32> = quantizeToF16(vec3(1.0, 1.0, 1.0));
let quantizeToF16_d: vec4<f32> = quantizeToF16(vec4(1.0, 1.0, 1.0, 1.0));
}
7 changes: 7 additions & 0 deletions naga/tests/out/glsl/math-functions.main.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,12 @@ void main() {
float frexp_b = naga_frexp(1.5).fract_;
int frexp_c = naga_frexp(1.5).exp_;
int frexp_d = naga_frexp(vec4(1.5, 1.5, 1.5, 1.5)).exp_.x;
float quantizeToF16_a = unpackHalf2x16(packHalf2x16(vec2(1.0))).x;
vec2 _e120 = vec2(1.0, 1.0);
vec2 quantizeToF16_b = unpackHalf2x16(packHalf2x16(_e120));
vec3 _e125 = vec3(1.0, 1.0, 1.0);
vec3 quantizeToF16_c = vec3(unpackHalf2x16(packHalf2x16(_e125.xy)), unpackHalf2x16(packHalf2x16(_e125.zz)).x);
vec4 _e131 = vec4(1.0, 1.0, 1.0, 1.0);
vec4 quantizeToF16_d = vec4(unpackHalf2x16(packHalf2x16(_e131.xy)), unpackHalf2x16(packHalf2x16(_e131.zw)));
}

4 changes: 4 additions & 0 deletions naga/tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,8 @@ void main()
float frexp_b = naga_frexp(1.5).fract;
int frexp_c = naga_frexp(1.5).exp_;
int frexp_d = naga_frexp(float4(1.5, 1.5, 1.5, 1.5)).exp_.x;
float quantizeToF16_a = f16tof32(f32tof16(1.0));
float2 quantizeToF16_b = f16tof32(f32tof16(float2(1.0, 1.0)));
float3 quantizeToF16_c = f16tof32(f32tof16(float3(1.0, 1.0, 1.0)));
float4 quantizeToF16_d = f16tof32(f32tof16(float4(1.0, 1.0, 1.0, 1.0)));
}
4 changes: 4 additions & 0 deletions naga/tests/out/msl/math-functions.msl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,8 @@ fragment void main_(
float frexp_b = naga_frexp(1.5).fract;
int frexp_c = naga_frexp(1.5).exp;
int frexp_d = naga_frexp(metal::float4(1.5, 1.5, 1.5, 1.5)).exp.x;
float quantizeToF16_a = float(half(1.0));
metal::float2 quantizeToF16_b = metal::float2(metal::half2(metal::float2(1.0, 1.0)));
metal::float3 quantizeToF16_c = metal::float3(metal::half3(metal::float3(1.0, 1.0, 1.0)));
metal::float4 quantizeToF16_d = metal::float4(metal::half4(metal::float4(1.0, 1.0, 1.0, 1.0)));
}
Loading

0 comments on commit cffc793

Please sign in to comment.