diff --git a/CHANGELOG.md b/CHANGELOG.md index 036e55ef09..b763d71ef0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -115,6 +115,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148] - Expose Ray Query flags as constants in WGSL. Implement candidate intersections. By @kvark in [#5429](https://github.com/gfx-rs/wgpu/pull/5429) - Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635). - Add support for OpAtomicCompareExchange in SPIR-V frontend. By @schell in [#6590](https://github.com/gfx-rs/wgpu/pull/6590). +- Implement type inference for abstract arguments to user-defined functions. By @jamienicol in [#6577](https://github.com/gfx-rs/wgpu/pull/6577). #### General diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index f221ff97c6..95a4902d16 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2184,7 +2184,25 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Some(&LoweredGlobalDecl::Function(function)) => { let arguments = arguments .iter() - .map(|&arg| self.expression(arg, ctx)) + .enumerate() + .map(|(i, &arg)| { + // Try to convert abstract values to the known argument types + let Some(&crate::FunctionArgument { + ty: parameter_ty, .. + }) = ctx.module.functions[function].arguments.get(i) + else { + // Wrong number of arguments... just concretize the type here + // and let the validator report the error. + return self.expression(arg, ctx); + }; + + let expr = self.expression_for_abstract(arg, ctx)?; + ctx.try_automatic_conversions( + expr, + &crate::proc::TypeResolution::Handle(parameter_ty), + ctx.ast_expressions.get_span(arg), + ) + }) .collect::, _>>()?; let has_result = ctx.module.functions[function].result.is_some(); diff --git a/naga/src/front/wgsl/tests.rs b/naga/src/front/wgsl/tests.rs index 1ac8f33472..3ae006f9d4 100644 --- a/naga/src/front/wgsl/tests.rs +++ b/naga/src/front/wgsl/tests.rs @@ -84,6 +84,28 @@ fn parse_type_cast() { .is_err()); } +#[test] +fn parse_type_coercion() { + parse_str( + " + fn foo(bar: f32) {} + fn main() { + foo(0); + } + ", + ) + .unwrap(); + assert!(parse_str( + " + fn foo(bar: i32) {} + fn main() { + foo(0.0); + } + ", + ) + .is_err()); +} + #[test] fn parse_struct() { parse_str( @@ -461,7 +483,7 @@ fn binary_expression_mixed_scalar_and_vector_operands() { #[test] fn parse_pointers() { parse_str( - "fn foo(a: ptr) -> f32 { return *a; } + "fn foo(a: ptr) -> f32 { return *a; } fn bar() { var x: f32 = 1.0; let px = &x; diff --git a/naga/tests/in/abstract-types-function-calls.wgsl b/naga/tests/in/abstract-types-function-calls.wgsl new file mode 100644 index 0000000000..9c3c28ef06 --- /dev/null +++ b/naga/tests/in/abstract-types-function-calls.wgsl @@ -0,0 +1,38 @@ +fn func_f(a: f32) {} +fn func_i(a: i32) {} +fn func_u(a: u32) {} + +fn func_vf(a: vec2) {} +fn func_vi(a: vec2) {} +fn func_vu(a: vec2) {} + +fn func_mf(a: mat2x2) {} + +fn func_af(a: array) {} +fn func_ai(a: array) {} +fn func_au(a: array) {} + +fn func_f_i(a: f32, b: i32) {} + +fn main() { + func_f(0.0); + func_f(0); + func_i(0); + func_u(0); + + func_vf(vec2(0.0)); + func_vf(vec2(0)); + func_vi(vec2(0)); + func_vu(vec2(0)); + + func_mf(mat2x2(vec2(0.0), vec2(0.0))); + func_mf(mat2x2(vec2(0), vec2(0))); + + func_af(array(0.0, 0.0)); + func_af(array(0, 0)); + func_ai(array(0, 0)); + func_au(array(0, 0)); + + func_f_i(0.0, 0); + func_f_i(0, 0); +} diff --git a/naga/tests/out/hlsl/abstract-types-function-calls.ron b/naga/tests/out/hlsl/abstract-types-function-calls.ron new file mode 100644 index 0000000000..4d056ac29b --- /dev/null +++ b/naga/tests/out/hlsl/abstract-types-function-calls.ron @@ -0,0 +1,8 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ], +) diff --git a/naga/tests/out/msl/abstract-types-function-calls.msl b/naga/tests/out/msl/abstract-types-function-calls.msl new file mode 100644 index 0000000000..bf62d87a54 --- /dev/null +++ b/naga/tests/out/msl/abstract-types-function-calls.msl @@ -0,0 +1,103 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct type_7 { + float inner[2]; +}; +struct type_8 { + int inner[2]; +}; +struct type_9 { + uint inner[2]; +}; + +void func_f( + float a +) { + return; +} + +void func_i( + int a_1 +) { + return; +} + +void func_u( + uint a_2 +) { + return; +} + +void func_vf( + metal::float2 a_3 +) { + return; +} + +void func_vi( + metal::int2 a_4 +) { + return; +} + +void func_vu( + metal::uint2 a_5 +) { + return; +} + +void func_mf( + metal::float2x2 a_6 +) { + return; +} + +void func_af( + type_7 a_7 +) { + return; +} + +void func_ai( + type_8 a_8 +) { + return; +} + +void func_au( + type_9 a_9 +) { + return; +} + +void func_f_i( + float a_10, + int b +) { + return; +} + +void main_( +) { + func_f(0.0); + func_f(0.0); + func_i(0); + func_u(0u); + func_vf(metal::float2(0.0)); + func_vf(metal::float2(0.0)); + func_vi(metal::int2(0)); + func_vu(metal::uint2(0u)); + func_mf(metal::float2x2(metal::float2(0.0), metal::float2(0.0))); + func_mf(metal::float2x2(metal::float2(0.0), metal::float2(0.0))); + func_af(type_7 {0.0, 0.0}); + func_af(type_7 {0.0, 0.0}); + func_ai(type_8 {0, 0}); + func_au(type_9 {0u, 0u}); + func_f_i(0.0, 0); + func_f_i(0.0, 0); + return; +} diff --git a/naga/tests/out/spv/abstract-types-function-calls.spvasm b/naga/tests/out/spv/abstract-types-function-calls.spvasm new file mode 100644 index 0000000000..58e064bb4e --- /dev/null +++ b/naga/tests/out/spv/abstract-types-function-calls.spvasm @@ -0,0 +1,145 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 100 +OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpDecorate %10 ArrayStride 4 +OpDecorate %12 ArrayStride 4 +OpDecorate %13 ArrayStride 4 +%2 = OpTypeVoid +%3 = OpTypeFloat 32 +%4 = OpTypeInt 32 1 +%5 = OpTypeInt 32 0 +%6 = OpTypeVector %3 2 +%7 = OpTypeVector %4 2 +%8 = OpTypeVector %5 2 +%9 = OpTypeMatrix %6 2 +%11 = OpConstant %5 2 +%10 = OpTypeArray %3 %11 +%12 = OpTypeArray %4 %11 +%13 = OpTypeArray %5 %11 +%17 = OpTypeFunction %2 %3 +%22 = OpTypeFunction %2 %4 +%27 = OpTypeFunction %2 %5 +%32 = OpTypeFunction %2 %6 +%37 = OpTypeFunction %2 %7 +%42 = OpTypeFunction %2 %8 +%47 = OpTypeFunction %2 %9 +%52 = OpTypeFunction %2 %10 +%57 = OpTypeFunction %2 %12 +%62 = OpTypeFunction %2 %13 +%68 = OpTypeFunction %2 %3 %4 +%72 = OpTypeFunction %2 +%73 = OpConstant %3 0.0 +%74 = OpConstant %4 0 +%75 = OpConstant %5 0 +%76 = OpConstantComposite %6 %73 %73 +%77 = OpConstantComposite %7 %74 %74 +%78 = OpConstantComposite %8 %75 %75 +%79 = OpConstantComposite %9 %76 %76 +%80 = OpConstantComposite %10 %73 %73 +%81 = OpConstantComposite %12 %74 %74 +%82 = OpConstantComposite %13 %75 %75 +%16 = OpFunction %2 None %17 +%15 = OpFunctionParameter %3 +%14 = OpLabel +OpBranch %18 +%18 = OpLabel +OpReturn +OpFunctionEnd +%21 = OpFunction %2 None %22 +%20 = OpFunctionParameter %4 +%19 = OpLabel +OpBranch %23 +%23 = OpLabel +OpReturn +OpFunctionEnd +%26 = OpFunction %2 None %27 +%25 = OpFunctionParameter %5 +%24 = OpLabel +OpBranch %28 +%28 = OpLabel +OpReturn +OpFunctionEnd +%31 = OpFunction %2 None %32 +%30 = OpFunctionParameter %6 +%29 = OpLabel +OpBranch %33 +%33 = OpLabel +OpReturn +OpFunctionEnd +%36 = OpFunction %2 None %37 +%35 = OpFunctionParameter %7 +%34 = OpLabel +OpBranch %38 +%38 = OpLabel +OpReturn +OpFunctionEnd +%41 = OpFunction %2 None %42 +%40 = OpFunctionParameter %8 +%39 = OpLabel +OpBranch %43 +%43 = OpLabel +OpReturn +OpFunctionEnd +%46 = OpFunction %2 None %47 +%45 = OpFunctionParameter %9 +%44 = OpLabel +OpBranch %48 +%48 = OpLabel +OpReturn +OpFunctionEnd +%51 = OpFunction %2 None %52 +%50 = OpFunctionParameter %10 +%49 = OpLabel +OpBranch %53 +%53 = OpLabel +OpReturn +OpFunctionEnd +%56 = OpFunction %2 None %57 +%55 = OpFunctionParameter %12 +%54 = OpLabel +OpBranch %58 +%58 = OpLabel +OpReturn +OpFunctionEnd +%61 = OpFunction %2 None %62 +%60 = OpFunctionParameter %13 +%59 = OpLabel +OpBranch %63 +%63 = OpLabel +OpReturn +OpFunctionEnd +%67 = OpFunction %2 None %68 +%65 = OpFunctionParameter %3 +%66 = OpFunctionParameter %4 +%64 = OpLabel +OpBranch %69 +%69 = OpLabel +OpReturn +OpFunctionEnd +%71 = OpFunction %2 None %72 +%70 = OpLabel +OpBranch %83 +%83 = OpLabel +%84 = OpFunctionCall %2 %16 %73 +%85 = OpFunctionCall %2 %16 %73 +%86 = OpFunctionCall %2 %21 %74 +%87 = OpFunctionCall %2 %26 %75 +%88 = OpFunctionCall %2 %31 %76 +%89 = OpFunctionCall %2 %31 %76 +%90 = OpFunctionCall %2 %36 %77 +%91 = OpFunctionCall %2 %41 %78 +%92 = OpFunctionCall %2 %46 %79 +%93 = OpFunctionCall %2 %46 %79 +%94 = OpFunctionCall %2 %51 %80 +%95 = OpFunctionCall %2 %51 %80 +%96 = OpFunctionCall %2 %56 %81 +%97 = OpFunctionCall %2 %61 %82 +%98 = OpFunctionCall %2 %67 %73 %74 +%99 = OpFunctionCall %2 %67 %73 %74 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/abstract-types-function-calls.wgsl b/naga/tests/out/wgsl/abstract-types-function-calls.wgsl new file mode 100644 index 0000000000..419125bc6b --- /dev/null +++ b/naga/tests/out/wgsl/abstract-types-function-calls.wgsl @@ -0,0 +1,64 @@ +fn func_f(a: f32) { + return; +} + +fn func_i(a_1: i32) { + return; +} + +fn func_u(a_2: u32) { + return; +} + +fn func_vf(a_3: vec2) { + return; +} + +fn func_vi(a_4: vec2) { + return; +} + +fn func_vu(a_5: vec2) { + return; +} + +fn func_mf(a_6: mat2x2) { + return; +} + +fn func_af(a_7: array) { + return; +} + +fn func_ai(a_8: array) { + return; +} + +fn func_au(a_9: array) { + return; +} + +fn func_f_i(a_10: f32, b: i32) { + return; +} + +fn main() { + func_f(0f); + func_f(0f); + func_i(0i); + func_u(0u); + func_vf(vec2(0f)); + func_vf(vec2(0f)); + func_vi(vec2(0i)); + func_vu(vec2(0u)); + func_mf(mat2x2(vec2(0f), vec2(0f))); + func_mf(mat2x2(vec2(0f), vec2(0f))); + func_af(array(0f, 0f)); + func_af(array(0f, 0f)); + func_ai(array(0i, 0i)); + func_au(array(0u, 0u)); + func_f_i(0f, 0i); + func_f_i(0f, 0i); + return; +} + diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 8c75058040..2460a69365 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -892,6 +892,10 @@ fn convert_wgsl() { "abstract-types-const", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, ), + ( + "abstract-types-function-calls", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, + ), ( "abstract-types-var", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL,