Skip to content

Commit

Permalink
[naga/wgsl-out]: polyfill inverse function (#6385)
Browse files Browse the repository at this point in the history
  • Loading branch information
chyyran authored Oct 11, 2024
1 parent d70ef62 commit 73764fd
Show file tree
Hide file tree
Showing 12 changed files with 367 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
- Support for more atomic ops in the SPIR-V frontend. By @schell in [#5824](https://github.com/gfx-rs/wgpu/pull/5824).
- Support local `const` declarations in WGSL. By @sagudev in [#6156](https://github.com/gfx-rs/wgpu/pull/6156).
- Implemented `const_assert` in WGSL. By @sagudev in [#6198](https://github.com/gfx-rs/wgpu/pull/6198).
- Support polyfilling `inverse` in WGSL. By @chyyran in [#6385](https://github.com/gfx-rs/wgpu/pull/6385).

#### General

Expand Down
1 change: 1 addition & 0 deletions naga/src/back/wgsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Backend for [WGSL][wgsl] (WebGPU Shading Language).
[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
*/

mod polyfill;
mod writer;

use thiserror::Error;
Expand Down
10 changes: 10 additions & 0 deletions naga/src/back/wgsl/polyfill/inverse/inverse_2x2_f16.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
fn _naga_inverse_2x2_f16(m: mat2x2<f16>) -> mat2x2<f16> {
var adj: mat2x2<f16>;
adj[0][0] = m[1][1];
adj[0][1] = -m[0][1];
adj[1][0] = -m[1][0];
adj[1][1] = m[0][0];

let det: f16 = m[0][0] * m[1][1] - m[1][0] * m[0][1];
return adj * (1 / det);
}
10 changes: 10 additions & 0 deletions naga/src/back/wgsl/polyfill/inverse/inverse_2x2_f32.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
fn _naga_inverse_2x2_f32(m: mat2x2<f32>) -> mat2x2<f32> {
var adj: mat2x2<f32>;
adj[0][0] = m[1][1];
adj[0][1] = -m[0][1];
adj[1][0] = -m[1][0];
adj[1][1] = m[0][0];

let det: f32 = m[0][0] * m[1][1] - m[1][0] * m[0][1];
return adj * (1 / det);
}
19 changes: 19 additions & 0 deletions naga/src/back/wgsl/polyfill/inverse/inverse_3x3_f16.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
fn _naga_inverse_3x3_f16(m: mat3x3<f16>) -> mat3x3<f16> {
var adj: mat3x3<f16>;

adj[0][0] = (m[1][1] * m[2][2] - m[2][1] * m[1][2]);
adj[1][0] = - (m[1][0] * m[2][2] - m[2][0] * m[1][2]);
adj[2][0] = (m[1][0] * m[2][1] - m[2][0] * m[1][1]);
adj[0][1] = - (m[0][1] * m[2][2] - m[2][1] * m[0][2]);
adj[1][1] = (m[0][0] * m[2][2] - m[2][0] * m[0][2]);
adj[2][1] = - (m[0][0] * m[2][1] - m[2][0] * m[0][1]);
adj[0][2] = (m[0][1] * m[1][2] - m[1][1] * m[0][2]);
adj[1][2] = - (m[0][0] * m[1][2] - m[1][0] * m[0][2]);
adj[2][2] = (m[0][0] * m[1][1] - m[1][0] * m[0][1]);

let det: f16 = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
- m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
+ m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]));

return adj * (1 / det);
}
19 changes: 19 additions & 0 deletions naga/src/back/wgsl/polyfill/inverse/inverse_3x3_f32.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
fn _naga_inverse_3x3_f32(m: mat3x3<f32>) -> mat3x3<f32> {
var adj: mat3x3<f32>;

adj[0][0] = (m[1][1] * m[2][2] - m[2][1] * m[1][2]);
adj[1][0] = - (m[1][0] * m[2][2] - m[2][0] * m[1][2]);
adj[2][0] = (m[1][0] * m[2][1] - m[2][0] * m[1][1]);
adj[0][1] = - (m[0][1] * m[2][2] - m[2][1] * m[0][2]);
adj[1][1] = (m[0][0] * m[2][2] - m[2][0] * m[0][2]);
adj[2][1] = - (m[0][0] * m[2][1] - m[2][0] * m[0][1]);
adj[0][2] = (m[0][1] * m[1][2] - m[1][1] * m[0][2]);
adj[1][2] = - (m[0][0] * m[1][2] - m[1][0] * m[0][2]);
adj[2][2] = (m[0][0] * m[1][1] - m[1][0] * m[0][1]);

let det: f32 = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
- m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
+ m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]));

return adj * (1 / det);
}
43 changes: 43 additions & 0 deletions naga/src/back/wgsl/polyfill/inverse/inverse_4x4_f16.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
fn _naga_inverse_4x4_f16(m: mat4x4<f16>) -> mat4x4<f16> {
let sub_factor00: f16 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
let sub_factor01: f16 = m[2][1] * m[3][3] - m[3][1] * m[2][3];
let sub_factor02: f16 = m[2][1] * m[3][2] - m[3][1] * m[2][2];
let sub_factor03: f16 = m[2][0] * m[3][3] - m[3][0] * m[2][3];
let sub_factor04: f16 = m[2][0] * m[3][2] - m[3][0] * m[2][2];
let sub_factor05: f16 = m[2][0] * m[3][1] - m[3][0] * m[2][1];
let sub_factor06: f16 = m[1][2] * m[3][3] - m[3][2] * m[1][3];
let sub_factor07: f16 = m[1][1] * m[3][3] - m[3][1] * m[1][3];
let sub_factor08: f16 = m[1][1] * m[3][2] - m[3][1] * m[1][2];
let sub_factor09: f16 = m[1][0] * m[3][3] - m[3][0] * m[1][3];
let sub_factor10: f16 = m[1][0] * m[3][2] - m[3][0] * m[1][2];
let sub_factor11: f16 = m[1][1] * m[3][3] - m[3][1] * m[1][3];
let sub_factor12: f16 = m[1][0] * m[3][1] - m[3][0] * m[1][1];
let sub_factor13: f16 = m[1][2] * m[2][3] - m[2][2] * m[1][3];
let sub_factor14: f16 = m[1][1] * m[2][3] - m[2][1] * m[1][3];
let sub_factor15: f16 = m[1][1] * m[2][2] - m[2][1] * m[1][2];
let sub_factor16: f16 = m[1][0] * m[2][3] - m[2][0] * m[1][3];
let sub_factor17: f16 = m[1][0] * m[2][2] - m[2][0] * m[1][2];
let sub_factor18: f16 = m[1][0] * m[2][1] - m[2][0] * m[1][1];

var adj: mat4x4<f16>;
adj[0][0] = (m[1][1] * sub_factor00 - m[1][2] * sub_factor01 + m[1][3] * sub_factor02);
adj[1][0] = - (m[1][0] * sub_factor00 - m[1][2] * sub_factor03 + m[1][3] * sub_factor04);
adj[2][0] = (m[1][0] * sub_factor01 - m[1][1] * sub_factor03 + m[1][3] * sub_factor05);
adj[3][0] = - (m[1][0] * sub_factor02 - m[1][1] * sub_factor04 + m[1][2] * sub_factor05);
adj[0][1] = - (m[0][1] * sub_factor00 - m[0][2] * sub_factor01 + m[0][3] * sub_factor02);
adj[1][1] = (m[0][0] * sub_factor00 - m[0][2] * sub_factor03 + m[0][3] * sub_factor04);
adj[2][1] = - (m[0][0] * sub_factor01 - m[0][1] * sub_factor03 + m[0][3] * sub_factor05);
adj[3][1] = (m[0][0] * sub_factor02 - m[0][1] * sub_factor04 + m[0][2] * sub_factor05);
adj[0][2] = (m[0][1] * sub_factor06 - m[0][2] * sub_factor07 + m[0][3] * sub_factor08);
adj[1][2] = - (m[0][0] * sub_factor06 - m[0][2] * sub_factor09 + m[0][3] * sub_factor10);
adj[2][2] = (m[0][0] * sub_factor11 - m[0][1] * sub_factor09 + m[0][3] * sub_factor12);
adj[3][2] = - (m[0][0] * sub_factor08 - m[0][1] * sub_factor10 + m[0][2] * sub_factor12);
adj[0][3] = - (m[0][1] * sub_factor13 - m[0][2] * sub_factor14 + m[0][3] * sub_factor15);
adj[1][3] = (m[0][0] * sub_factor13 - m[0][2] * sub_factor16 + m[0][3] * sub_factor17);
adj[2][3] = - (m[0][0] * sub_factor14 - m[0][1] * sub_factor16 + m[0][3] * sub_factor18);
adj[3][3] = (m[0][0] * sub_factor15 - m[0][1] * sub_factor17 + m[0][2] * sub_factor18);

let det = (m[0][0] * adj[0][0] + m[0][1] * adj[1][0] + m[0][2] * adj[2][0] + m[0][3] * adj[3][0]);

return adj * (1 / det);
}
43 changes: 43 additions & 0 deletions naga/src/back/wgsl/polyfill/inverse/inverse_4x4_f32.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
fn _naga_inverse_4x4_f32(m: mat4x4<f32>) -> mat4x4<f32> {
let sub_factor00: f32 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
let sub_factor01: f32 = m[2][1] * m[3][3] - m[3][1] * m[2][3];
let sub_factor02: f32 = m[2][1] * m[3][2] - m[3][1] * m[2][2];
let sub_factor03: f32 = m[2][0] * m[3][3] - m[3][0] * m[2][3];
let sub_factor04: f32 = m[2][0] * m[3][2] - m[3][0] * m[2][2];
let sub_factor05: f32 = m[2][0] * m[3][1] - m[3][0] * m[2][1];
let sub_factor06: f32 = m[1][2] * m[3][3] - m[3][2] * m[1][3];
let sub_factor07: f32 = m[1][1] * m[3][3] - m[3][1] * m[1][3];
let sub_factor08: f32 = m[1][1] * m[3][2] - m[3][1] * m[1][2];
let sub_factor09: f32 = m[1][0] * m[3][3] - m[3][0] * m[1][3];
let sub_factor10: f32 = m[1][0] * m[3][2] - m[3][0] * m[1][2];
let sub_factor11: f32 = m[1][1] * m[3][3] - m[3][1] * m[1][3];
let sub_factor12: f32 = m[1][0] * m[3][1] - m[3][0] * m[1][1];
let sub_factor13: f32 = m[1][2] * m[2][3] - m[2][2] * m[1][3];
let sub_factor14: f32 = m[1][1] * m[2][3] - m[2][1] * m[1][3];
let sub_factor15: f32 = m[1][1] * m[2][2] - m[2][1] * m[1][2];
let sub_factor16: f32 = m[1][0] * m[2][3] - m[2][0] * m[1][3];
let sub_factor17: f32 = m[1][0] * m[2][2] - m[2][0] * m[1][2];
let sub_factor18: f32 = m[1][0] * m[2][1] - m[2][0] * m[1][1];

var adj: mat4x4<f32>;
adj[0][0] = (m[1][1] * sub_factor00 - m[1][2] * sub_factor01 + m[1][3] * sub_factor02);
adj[1][0] = - (m[1][0] * sub_factor00 - m[1][2] * sub_factor03 + m[1][3] * sub_factor04);
adj[2][0] = (m[1][0] * sub_factor01 - m[1][1] * sub_factor03 + m[1][3] * sub_factor05);
adj[3][0] = - (m[1][0] * sub_factor02 - m[1][1] * sub_factor04 + m[1][2] * sub_factor05);
adj[0][1] = - (m[0][1] * sub_factor00 - m[0][2] * sub_factor01 + m[0][3] * sub_factor02);
adj[1][1] = (m[0][0] * sub_factor00 - m[0][2] * sub_factor03 + m[0][3] * sub_factor04);
adj[2][1] = - (m[0][0] * sub_factor01 - m[0][1] * sub_factor03 + m[0][3] * sub_factor05);
adj[3][1] = (m[0][0] * sub_factor02 - m[0][1] * sub_factor04 + m[0][2] * sub_factor05);
adj[0][2] = (m[0][1] * sub_factor06 - m[0][2] * sub_factor07 + m[0][3] * sub_factor08);
adj[1][2] = - (m[0][0] * sub_factor06 - m[0][2] * sub_factor09 + m[0][3] * sub_factor10);
adj[2][2] = (m[0][0] * sub_factor11 - m[0][1] * sub_factor09 + m[0][3] * sub_factor12);
adj[3][2] = - (m[0][0] * sub_factor08 - m[0][1] * sub_factor10 + m[0][2] * sub_factor12);
adj[0][3] = - (m[0][1] * sub_factor13 - m[0][2] * sub_factor14 + m[0][3] * sub_factor15);
adj[1][3] = (m[0][0] * sub_factor13 - m[0][2] * sub_factor16 + m[0][3] * sub_factor17);
adj[2][3] = - (m[0][0] * sub_factor14 - m[0][1] * sub_factor16 + m[0][3] * sub_factor18);
adj[3][3] = (m[0][0] * sub_factor15 - m[0][1] * sub_factor17 + m[0][2] * sub_factor18);

let det = (m[0][0] * adj[0][0] + m[0][1] * adj[1][0] + m[0][2] * adj[2][0] + m[0][3] * adj[3][0]);

return adj * (1 / det);
}
66 changes: 66 additions & 0 deletions naga/src/back/wgsl/polyfill/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use crate::{ScalarKind, TypeInner, VectorSize};

#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct InversePolyfill {
pub fun_name: &'static str,
pub source: &'static str,
}

impl InversePolyfill {
pub fn find_overload(ty: &TypeInner) -> Option<InversePolyfill> {
let &TypeInner::Matrix {
columns,
rows,
scalar,
} = ty
else {
return None;
};

if columns != rows || scalar.kind != ScalarKind::Float {
return None;
};

Self::polyfill_overload(columns, scalar.width)
}

const fn polyfill_overload(
dimension: VectorSize,
width: crate::Bytes,
) -> Option<InversePolyfill> {
const INVERSE_2X2_F32: &str = include_str!("inverse/inverse_2x2_f32.wgsl");
const INVERSE_3X3_F32: &str = include_str!("inverse/inverse_3x3_f32.wgsl");
const INVERSE_4X4_F32: &str = include_str!("inverse/inverse_4x4_f32.wgsl");
const INVERSE_2X2_F16: &str = include_str!("inverse/inverse_2x2_f16.wgsl");
const INVERSE_3X3_F16: &str = include_str!("inverse/inverse_3x3_f16.wgsl");
const INVERSE_4X4_F16: &str = include_str!("inverse/inverse_4x4_f16.wgsl");

match (dimension, width) {
(VectorSize::Bi, 4) => Some(InversePolyfill {
fun_name: "_naga_inverse_2x2_f32",
source: INVERSE_2X2_F32,
}),
(VectorSize::Tri, 4) => Some(InversePolyfill {
fun_name: "_naga_inverse_3x3_f32",
source: INVERSE_3X3_F32,
}),
(VectorSize::Quad, 4) => Some(InversePolyfill {
fun_name: "_naga_inverse_4x4_f32",
source: INVERSE_4X4_F32,
}),
(VectorSize::Bi, 2) => Some(InversePolyfill {
fun_name: "_naga_inverse_2x2_f16",
source: INVERSE_2X2_F16,
}),
(VectorSize::Tri, 2) => Some(InversePolyfill {
fun_name: "_naga_inverse_3x3_f16",
source: INVERSE_3X3_F16,
}),
(VectorSize::Quad, 2) => Some(InversePolyfill {
fun_name: "_naga_inverse_4x4_f16",
source: INVERSE_4X4_F16,
}),
_ => None,
}
}
}
31 changes: 28 additions & 3 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::Error;
use crate::back::wgsl::polyfill::InversePolyfill;
use crate::{
back::{self, Baked},
proc::{self, ExpressionKindTracker, NameKey},
Expand Down Expand Up @@ -68,6 +69,7 @@ pub struct Writer<W> {
namer: proc::Namer,
named_expressions: crate::NamedExpressions,
ep_results: Vec<(ShaderStage, Handle<crate::Type>)>,
required_polyfills: crate::FastIndexSet<InversePolyfill>,
}

impl<W: Write> Writer<W> {
Expand All @@ -79,6 +81,7 @@ impl<W: Write> Writer<W> {
namer: proc::Namer::default(),
named_expressions: crate::NamedExpressions::default(),
ep_results: vec![],
required_polyfills: crate::FastIndexSet::default(),
}
}

Expand All @@ -90,11 +93,12 @@ impl<W: Write> Writer<W> {
// an identifier must not start with two underscore
&[],
&[],
&["__"],
&["__", "_naga"],
&mut self.names,
);
self.named_expressions.clear();
self.ep_results.clear();
self.required_polyfills.clear();
}

fn is_builtin_wgsl_struct(&self, module: &Module, handle: Handle<crate::Type>) -> bool {
Expand Down Expand Up @@ -203,6 +207,13 @@ impl<W: Write> Writer<W> {
}
}

// Write any polyfills that were required.
for polyfill in &self.required_polyfills {
writeln!(self.out)?;
write!(self.out, "{}", polyfill.source)?;
writeln!(self.out)?;
}

Ok(())
}

Expand Down Expand Up @@ -1653,6 +1664,7 @@ impl<W: Write> Writer<W> {

enum Function {
Regular(&'static str),
InversePolyfill(InversePolyfill),
}

let function = match fun {
Expand Down Expand Up @@ -1736,9 +1748,16 @@ impl<W: Write> Writer<W> {
Mf::Unpack2x16float => Function::Regular("unpack2x16float"),
Mf::Unpack4xI8 => Function::Regular("unpack4xI8"),
Mf::Unpack4xU8 => Function::Regular("unpack4xU8"),
Mf::Inverse | Mf::Outer => {
return Err(Error::UnsupportedMathFunction(fun));
Mf::Inverse => {
let typ = func_ctx.resolve_type(arg, &module.types);

let Some(overload) = InversePolyfill::find_overload(typ) else {
return Err(Error::UnsupportedMathFunction(fun));
};

Function::InversePolyfill(overload)
}
Mf::Outer => return Err(Error::UnsupportedMathFunction(fun)),
};

match function {
Expand All @@ -1751,6 +1770,12 @@ impl<W: Write> Writer<W> {
}
write!(self.out, ")")?
}
Function::InversePolyfill(inverse) => {
write!(self.out, "{}(", inverse.fun_name)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")")?;
self.required_polyfills.insert(inverse);
}
}
}

Expand Down
17 changes: 17 additions & 0 deletions naga/tests/in/glsl/inverse-polyfill.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#version 450

void main() {
vec4 a4 = vec4(1.0);
vec4 b4 = vec4(2.0);
mat4 m4 = mat4(a4, b4, a4, b4);

vec3 a3 = vec3(1.0);
vec3 b3 = vec3(2.0);
mat3 m3 = mat3(a3, b3, a3);

mat2 m2 = mat2(1.0, 2.0, 3.0, 4.0);

mat4 m4_inverse = inverse(m4);
mat3 m3_inverse = inverse(m3);
mat2 m2_inverse = inverse(m2);
}
Loading

0 comments on commit 73764fd

Please sign in to comment.