From 4dd901cc654455fff64cf49f7677f92a3cbf16b2 Mon Sep 17 00:00:00 2001 From: SirLynix Date: Tue, 26 Mar 2024 14:21:40 +0100 Subject: [PATCH] Simplify constant propagation code --- include/NZSL/Ast/SanitizeVisitor.hpp | 2 +- src/NZSL/Ast/ConstantPropagationVisitor.cpp | 36 +- ...ntPropagationVisitor_BinaryArithmetics.cpp | 449 +++++------------- ...antPropagationVisitor_BinaryComparison.cpp | 293 +++--------- tests/src/Tests/ErrorsTests.cpp | 4 +- 5 files changed, 199 insertions(+), 585 deletions(-) diff --git a/include/NZSL/Ast/SanitizeVisitor.hpp b/include/NZSL/Ast/SanitizeVisitor.hpp index 42437aa..2f84381 100644 --- a/include/NZSL/Ast/SanitizeVisitor.hpp +++ b/include/NZSL/Ast/SanitizeVisitor.hpp @@ -16,8 +16,8 @@ #include #include #include -#include #include +#include namespace nzsl::Ast { diff --git a/src/NZSL/Ast/ConstantPropagationVisitor.cpp b/src/NZSL/Ast/ConstantPropagationVisitor.cpp index b814494..24d7106 100644 --- a/src/NZSL/Ast/ConstantPropagationVisitor.cpp +++ b/src/NZSL/Ast/ConstantPropagationVisitor.cpp @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include #include #include @@ -13,25 +14,6 @@ namespace nzsl::Ast { namespace NAZARA_ANONYMOUS_NAMESPACE { - template - struct IsCompleteHelper - { - // SFINAE: sizeof of an incomplete type is an error, but since there's another specialization it won't result in a compilation error - template - static auto test(U*) -> std::bool_constant; - - // less specialized overload - static auto test(...) -> std::false_type; - - using type = decltype(test(static_cast(nullptr))); - }; - - template - struct IsComplete : IsCompleteHelper::type {}; - - template - inline constexpr bool IsComplete_v = IsComplete::value; - template struct VectorInfo { @@ -639,7 +621,7 @@ namespace nzsl::Ast { using T = std::decay_t; - if constexpr (IsComplete_v>) + if constexpr (Nz::IsComplete_v>) { ArrayBuilder builder; optimized = builder(expressions, node.sourceLocation); @@ -1004,7 +986,7 @@ namespace nzsl::Ast using T = std::decay_t; using CCType = CastConstant; - if constexpr (IsComplete_v) + if constexpr (Nz::IsComplete_v) optimized = CCType{}(arg, sourceLocation); }, operand.value); @@ -1026,7 +1008,7 @@ namespace nzsl::Ast using SPType = SwizzlePropagation; - if constexpr (IsComplete_v) + if constexpr (Nz::IsComplete_v) optimized = SPType{}(components, arg, sourceLocation); }, operand.value); @@ -1044,10 +1026,10 @@ namespace nzsl::Ast using T = std::decay_t; using PCType = UnaryConstantPropagation; - if constexpr (IsComplete_v) + if constexpr (Nz::IsComplete_v) { using Op = typename PCType::Op; - if constexpr (IsComplete_v) + if constexpr (Nz::IsComplete_v) optimized = Op{}(arg, sourceLocation); } }, operand.value); @@ -1068,7 +1050,7 @@ namespace nzsl::Ast using CCType = CastConstant, TargetType, TargetType>; - if constexpr (IsComplete_v) + if constexpr (Nz::IsComplete_v) optimized = CCType{}(v1, v2, sourceLocation); return optimized; @@ -1088,7 +1070,7 @@ namespace nzsl::Ast using CCType = CastConstant, TargetType, TargetType, TargetType>; - if constexpr (IsComplete_v) + if constexpr (Nz::IsComplete_v) optimized = CCType{}(v1, v2, v3, sourceLocation); return optimized; @@ -1109,7 +1091,7 @@ namespace nzsl::Ast using CCType = CastConstant, TargetType, TargetType, TargetType, TargetType>; - if constexpr (IsComplete_v) + if constexpr (Nz::IsComplete_v) optimized = CCType{}(v1, v2, v3, v4, sourceLocation); return optimized; diff --git a/src/NZSL/Ast/ConstantPropagationVisitor_BinaryArithmetics.cpp b/src/NZSL/Ast/ConstantPropagationVisitor_BinaryArithmetics.cpp index 0457dd7..d3360fe 100644 --- a/src/NZSL/Ast/ConstantPropagationVisitor_BinaryArithmetics.cpp +++ b/src/NZSL/Ast/ConstantPropagationVisitor_BinaryArithmetics.cpp @@ -2,10 +2,11 @@ // This file is part of the "Nazara Shading Language" project // For conditions of distribution and use, see copyright notice in Config.hpp +#include #include #include +#include #include -#include #include #include #include @@ -15,25 +16,6 @@ namespace nzsl::Ast { namespace NAZARA_ANONYMOUS_NAMESPACE { - template - struct IsCompleteHelper - { - // SFINAE: sizeof in an incomplete type is an error, but since there's another specialization it won't result in a compilation error - template - static auto test(U*) -> std::bool_constant; - - // less specialized overload - static auto test(...) -> std::false_type; - - using type = decltype(test(static_cast(nullptr))); - }; - - template - struct IsComplete : IsCompleteHelper::type {}; - - template - inline constexpr bool IsComplete_v = IsComplete::value; - /*************************************************************************************************/ template @@ -41,211 +23,144 @@ namespace nzsl::Ast // Addition template - struct BinaryAdditionBase + struct BinaryAddition { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::Add; + static constexpr bool AllowSingleOperand = false; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs + rhs); + return lhs + rhs; } }; - template - struct BinaryAddition; - - template - struct BinaryConstantPropagation - { - using Op = BinaryAddition; - }; - // BitwiseAnd template - struct BinaryBitwiseAndBase + struct BinaryBitwiseAnd { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::BitwiseAnd; + static constexpr bool AllowSingleOperand = false; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs & rhs); + return lhs & rhs; } }; - template - struct BinaryBitwiseAnd; - - template - struct BinaryConstantPropagation - { - using Op = BinaryBitwiseAnd; - }; - // BitwiseOr template - struct BinaryBitwiseOrBase + struct BinaryBitwiseOr { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::BitwiseOr; + static constexpr bool AllowSingleOperand = false; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs | rhs); + return lhs | rhs; } }; - template - struct BinaryBitwiseOr; - - template - struct BinaryConstantPropagation - { - using Op = BinaryBitwiseOr; - }; - // BitwiseXor template - struct BinaryBitwiseXorBase + struct BinaryBitwiseXor { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::BitwiseXor; + static constexpr bool AllowSingleOperand = false; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs ^ rhs); + return lhs ^ rhs; } }; - template - struct BinaryBitwiseXor; - - template - struct BinaryConstantPropagation - { - using Op = BinaryBitwiseXor; - }; - // Division template - struct BinaryDivisionBase + struct BinaryDivision { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& sourceLocation) + static constexpr BinaryType Type = BinaryType::Divide; + static constexpr bool AllowSingleOperand = true; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& sourceLocation) { if constexpr (std::is_integral_v) { if (rhs == 0) throw CompilerIntegralDivisionByZeroError{ sourceLocation, ConstantToString(lhs), ConstantToString(rhs) }; } - else if constexpr (IsVector_v) - { - for (std::size_t i = 0; i < T2::Dimensions; ++i) - { - if (rhs[i] == 0) - throw CompilerIntegralDivisionByZeroError{ sourceLocation, ConstantToString(lhs), ConstantToString(rhs) }; - } - } - return ShaderBuilder::ConstantValue(lhs / rhs); + return lhs / rhs; } }; - template - struct BinaryDivision; - - template - struct BinaryConstantPropagation - { - using Op = BinaryDivision; - }; - // LogicalAnd template - struct BinaryLogicalAndBase + struct BinaryLogicalAnd { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::LogicalAnd; + static constexpr bool AllowSingleOperand = false; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs && rhs); + return lhs && rhs; } }; - template - struct BinaryLogicalAnd; - - template - struct BinaryConstantPropagation - { - using Op = BinaryLogicalAnd; - }; - // LogicalOr template - struct BinaryLogicalOrBase + struct BinaryLogicalOr { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::LogicalOr; + static constexpr bool AllowSingleOperand = false; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs || rhs); + return lhs || rhs; } }; - template - struct BinaryLogicalOr; - - template - struct BinaryConstantPropagation - { - using Op = BinaryLogicalOr; - }; - // Modulo template - struct BinaryModuloBase + struct BinaryModulo { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& sourceLocation) + static constexpr BinaryType Type = BinaryType::Modulo; + static constexpr bool AllowSingleOperand = true; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& sourceLocation) { if constexpr (std::is_integral_v) { if (rhs == 0) throw CompilerIntegralModuloByZeroError{ sourceLocation, ConstantToString(lhs), ConstantToString(rhs) }; } - else if constexpr (IsVector_v) - { - for (std::size_t i = 0; i < T2::Dimensions; ++i) - { - if (rhs[i] == 0) - throw CompilerIntegralModuloByZeroError{ sourceLocation, ConstantToString(lhs), ConstantToString(rhs) }; - } - } if constexpr (std::is_floating_point_v && std::is_floating_point_v) - return ShaderBuilder::ConstantValue(std::fmod(lhs, rhs)); + return std::fmod(lhs, rhs); else - return ShaderBuilder::ConstantValue(lhs % rhs); + return lhs % rhs; } }; - template - struct BinaryModulo; - - template - struct BinaryConstantPropagation - { - using Op = BinaryModulo; - }; - // Multiplication template - struct BinaryMultiplicationBase + struct BinaryMultiplication { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::Multiply; + static constexpr bool AllowSingleOperand = true; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs * rhs); + return lhs * rhs; } }; - template - struct BinaryMultiplication; - - template - struct BinaryConstantPropagation - { - using Op = BinaryMultiplication; - }; - // ShiftLeft template - struct BinaryShiftLeftBase + struct BinaryShiftLeft { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& sourceLocation) + static constexpr BinaryType Type = BinaryType::ShiftLeft; + static constexpr bool AllowSingleOperand = false; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& sourceLocation) { if constexpr (std::is_integral_v) { @@ -259,24 +174,18 @@ namespace nzsl::Ast throw CompilerBinaryTooLargeShiftError{ sourceLocation, ConstantToString(lhs), "<<", ConstantToString(rhs), ToString(GetConstantExpressionType()) }; } - return ShaderBuilder::ConstantValue(lhs << rhs); + return lhs << rhs; } }; - template - struct BinaryShiftLeft; - - template - struct BinaryConstantPropagation - { - using Op = BinaryShiftLeft; - }; - // ShiftRight template - struct BinaryShiftRightBase + struct BinaryShiftRight { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& sourceLocation) + static constexpr BinaryType Type = BinaryType::ShiftRight; + static constexpr bool AllowSingleOperand = false; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& sourceLocation) { if constexpr (std::is_integral_v) { @@ -290,60 +199,31 @@ namespace nzsl::Ast throw CompilerBinaryTooLargeShiftError{ sourceLocation, ConstantToString(lhs), ">>", ConstantToString(rhs), ToString(GetConstantExpressionType()) }; } - return ShaderBuilder::ConstantValue(Nz::ArithmeticRightShift(lhs, rhs)); + return Nz::ArithmeticRightShift(lhs, rhs); } }; - template - struct BinaryShiftRight; - - template - struct BinaryConstantPropagation - { - using Op = BinaryShiftRight; - }; - // Subtraction template - struct BinarySubtractionBase + struct BinarySubtraction { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::Subtract; + static constexpr bool AllowSingleOperand = false; + + NAZARA_FORCEINLINE auto operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs - rhs); + return lhs - rhs; } }; - template - struct BinarySubtraction; - - template - struct BinaryConstantPropagation - { - using Op = BinarySubtraction; - }; - /*************************************************************************************************/ -#define EnableOptimisation(Op, ...) template<> struct Op<__VA_ARGS__> : Op##Base<__VA_ARGS__> {} - - // Binary +#define EnableOptimisation(Impl, ...) template<> struct BinaryConstantPropagation::Type, __VA_ARGS__> : Impl<__VA_ARGS__> {} EnableOptimisation(BinaryAddition, double, double); EnableOptimisation(BinaryAddition, float, float); EnableOptimisation(BinaryAddition, std::int32_t, std::int32_t); EnableOptimisation(BinaryAddition, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryAddition, Vector2f32, Vector2f32); - EnableOptimisation(BinaryAddition, Vector3f32, Vector3f32); - EnableOptimisation(BinaryAddition, Vector4f32, Vector4f32); - EnableOptimisation(BinaryAddition, Vector2f64, Vector2f64); - EnableOptimisation(BinaryAddition, Vector3f64, Vector3f64); - EnableOptimisation(BinaryAddition, Vector4f64, Vector4f64); - EnableOptimisation(BinaryAddition, Vector2i32, Vector2i32); - EnableOptimisation(BinaryAddition, Vector3i32, Vector3i32); - EnableOptimisation(BinaryAddition, Vector4i32, Vector4i32); - EnableOptimisation(BinaryAddition, Vector2u32, Vector2u32); - EnableOptimisation(BinaryAddition, Vector3u32, Vector3u32); - EnableOptimisation(BinaryAddition, Vector4u32, Vector4u32); EnableOptimisation(BinaryBitwiseAnd, std::int32_t, std::uint32_t); EnableOptimisation(BinaryBitwiseAnd, std::uint32_t, std::int32_t); @@ -361,130 +241,22 @@ namespace nzsl::Ast EnableOptimisation(BinaryBitwiseXor, std::int32_t, std::int32_t); EnableOptimisation(BinaryDivision, double, double); - EnableOptimisation(BinaryDivision, double, Vector2f64); - EnableOptimisation(BinaryDivision, double, Vector3f64); - EnableOptimisation(BinaryDivision, double, Vector4f64); EnableOptimisation(BinaryDivision, float, float); - EnableOptimisation(BinaryDivision, float, Vector2f32); - EnableOptimisation(BinaryDivision, float, Vector3f32); - EnableOptimisation(BinaryDivision, float, Vector4f32); EnableOptimisation(BinaryDivision, std::int32_t, std::int32_t); - EnableOptimisation(BinaryDivision, std::int32_t, Vector2i32); - EnableOptimisation(BinaryDivision, std::int32_t, Vector3i32); - EnableOptimisation(BinaryDivision, std::int32_t, Vector4i32); EnableOptimisation(BinaryDivision, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryDivision, std::uint32_t, Vector2u32); - EnableOptimisation(BinaryDivision, std::uint32_t, Vector3u32); - EnableOptimisation(BinaryDivision, std::uint32_t, Vector4u32); - EnableOptimisation(BinaryDivision, Vector2f32, float); - EnableOptimisation(BinaryDivision, Vector2f32, Vector2f32); - EnableOptimisation(BinaryDivision, Vector3f32, float); - EnableOptimisation(BinaryDivision, Vector3f32, Vector3f32); - EnableOptimisation(BinaryDivision, Vector4f32, float); - EnableOptimisation(BinaryDivision, Vector4f32, Vector4f32); - EnableOptimisation(BinaryDivision, Vector2f64, double); - EnableOptimisation(BinaryDivision, Vector2f64, Vector2f64); - EnableOptimisation(BinaryDivision, Vector3f64, double); - EnableOptimisation(BinaryDivision, Vector3f64, Vector3f64); - EnableOptimisation(BinaryDivision, Vector4f64, double); - EnableOptimisation(BinaryDivision, Vector4f64, Vector4f64); - EnableOptimisation(BinaryDivision, Vector2i32, std::int32_t); - EnableOptimisation(BinaryDivision, Vector2i32, Vector2i32); - EnableOptimisation(BinaryDivision, Vector3i32, std::int32_t); - EnableOptimisation(BinaryDivision, Vector3i32, Vector3i32); - EnableOptimisation(BinaryDivision, Vector4i32, std::int32_t); - EnableOptimisation(BinaryDivision, Vector4i32, Vector4i32); - EnableOptimisation(BinaryDivision, Vector2u32, std::uint32_t); - EnableOptimisation(BinaryDivision, Vector2u32, Vector2u32); - EnableOptimisation(BinaryDivision, Vector3u32, std::uint32_t); - EnableOptimisation(BinaryDivision, Vector3u32, Vector3u32); - EnableOptimisation(BinaryDivision, Vector4u32, std::uint32_t); - EnableOptimisation(BinaryDivision, Vector4u32, Vector4u32); EnableOptimisation(BinaryLogicalAnd, bool, bool); EnableOptimisation(BinaryLogicalOr, bool, bool); EnableOptimisation(BinaryModulo, double, double); - EnableOptimisation(BinaryModulo, double, Vector2f64); - EnableOptimisation(BinaryModulo, double, Vector3f64); - EnableOptimisation(BinaryModulo, double, Vector4f64); EnableOptimisation(BinaryModulo, float, float); - EnableOptimisation(BinaryModulo, float, Vector2f32); - EnableOptimisation(BinaryModulo, float, Vector3f32); - EnableOptimisation(BinaryModulo, float, Vector4f32); EnableOptimisation(BinaryModulo, std::int32_t, std::int32_t); - EnableOptimisation(BinaryModulo, std::int32_t, Vector2i32); - EnableOptimisation(BinaryModulo, std::int32_t, Vector3i32); - EnableOptimisation(BinaryModulo, std::int32_t, Vector4i32); EnableOptimisation(BinaryModulo, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryModulo, std::uint32_t, Vector2u32); - EnableOptimisation(BinaryModulo, std::uint32_t, Vector3u32); - EnableOptimisation(BinaryModulo, std::uint32_t, Vector4u32); - EnableOptimisation(BinaryModulo, Vector2f32, float); - EnableOptimisation(BinaryModulo, Vector2f32, Vector2f32); - EnableOptimisation(BinaryModulo, Vector3f32, float); - EnableOptimisation(BinaryModulo, Vector3f32, Vector3f32); - EnableOptimisation(BinaryModulo, Vector4f32, float); - EnableOptimisation(BinaryModulo, Vector4f32, Vector4f32); - EnableOptimisation(BinaryModulo, Vector2f64, double); - EnableOptimisation(BinaryModulo, Vector2f64, Vector2f64); - EnableOptimisation(BinaryModulo, Vector3f64, double); - EnableOptimisation(BinaryModulo, Vector3f64, Vector3f64); - EnableOptimisation(BinaryModulo, Vector4f64, double); - EnableOptimisation(BinaryModulo, Vector4f64, Vector4f64); - EnableOptimisation(BinaryModulo, Vector2i32, std::int32_t); - EnableOptimisation(BinaryModulo, Vector2i32, Vector2i32); - EnableOptimisation(BinaryModulo, Vector3i32, std::int32_t); - EnableOptimisation(BinaryModulo, Vector3i32, Vector3i32); - EnableOptimisation(BinaryModulo, Vector4i32, std::int32_t); - EnableOptimisation(BinaryModulo, Vector4i32, Vector4i32); - EnableOptimisation(BinaryModulo, Vector2u32, std::uint32_t); - EnableOptimisation(BinaryModulo, Vector2u32, Vector2u32); - EnableOptimisation(BinaryModulo, Vector3u32, std::uint32_t); - EnableOptimisation(BinaryModulo, Vector3u32, Vector3u32); - EnableOptimisation(BinaryModulo, Vector4u32, std::uint32_t); - EnableOptimisation(BinaryModulo, Vector4u32, Vector4u32); EnableOptimisation(BinaryMultiplication, double, double); - EnableOptimisation(BinaryMultiplication, double, Vector2f64); - EnableOptimisation(BinaryMultiplication, double, Vector3f64); - EnableOptimisation(BinaryMultiplication, double, Vector4f64); EnableOptimisation(BinaryMultiplication, float, float); - EnableOptimisation(BinaryMultiplication, float, Vector2f32); - EnableOptimisation(BinaryMultiplication, float, Vector3f32); - EnableOptimisation(BinaryMultiplication, float, Vector4f32); EnableOptimisation(BinaryMultiplication, std::int32_t, std::int32_t); - EnableOptimisation(BinaryMultiplication, std::int32_t, Vector2i32); - EnableOptimisation(BinaryMultiplication, std::int32_t, Vector3i32); - EnableOptimisation(BinaryMultiplication, std::int32_t, Vector4i32); EnableOptimisation(BinaryMultiplication, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryMultiplication, std::uint32_t, Vector2u32); - EnableOptimisation(BinaryMultiplication, std::uint32_t, Vector3u32); - EnableOptimisation(BinaryMultiplication, std::uint32_t, Vector4u32); - EnableOptimisation(BinaryMultiplication, Vector2f32, float); - EnableOptimisation(BinaryMultiplication, Vector2f32, Vector2f32); - EnableOptimisation(BinaryMultiplication, Vector3f32, float); - EnableOptimisation(BinaryMultiplication, Vector3f32, Vector3f32); - EnableOptimisation(BinaryMultiplication, Vector4f32, float); - EnableOptimisation(BinaryMultiplication, Vector4f32, Vector4f32); - EnableOptimisation(BinaryMultiplication, Vector2f64, double); - EnableOptimisation(BinaryMultiplication, Vector2f64, Vector2f64); - EnableOptimisation(BinaryMultiplication, Vector3f64, double); - EnableOptimisation(BinaryMultiplication, Vector3f64, Vector3f64); - EnableOptimisation(BinaryMultiplication, Vector4f64, double); - EnableOptimisation(BinaryMultiplication, Vector4f64, Vector4f64); - EnableOptimisation(BinaryMultiplication, Vector2i32, std::int32_t); - EnableOptimisation(BinaryMultiplication, Vector2i32, Vector2i32); - EnableOptimisation(BinaryMultiplication, Vector3i32, std::int32_t); - EnableOptimisation(BinaryMultiplication, Vector3i32, Vector3i32); - EnableOptimisation(BinaryMultiplication, Vector4i32, std::int32_t); - EnableOptimisation(BinaryMultiplication, Vector4i32, Vector4i32); - EnableOptimisation(BinaryMultiplication, Vector2u32, std::uint32_t); - EnableOptimisation(BinaryMultiplication, Vector2u32, Vector2u32); - EnableOptimisation(BinaryMultiplication, Vector3u32, std::uint32_t); - EnableOptimisation(BinaryMultiplication, Vector3u32, Vector3u32); - EnableOptimisation(BinaryMultiplication, Vector4u32, std::uint32_t); - EnableOptimisation(BinaryMultiplication, Vector4u32, Vector4u32); EnableOptimisation(BinaryShiftLeft, std::int32_t, std::int32_t); EnableOptimisation(BinaryShiftLeft, std::int32_t, std::uint32_t); @@ -500,18 +272,6 @@ namespace nzsl::Ast EnableOptimisation(BinarySubtraction, float, float); EnableOptimisation(BinarySubtraction, std::int32_t, std::int32_t); EnableOptimisation(BinarySubtraction, std::uint32_t, std::uint32_t); - EnableOptimisation(BinarySubtraction, Vector2f32, Vector2f32); - EnableOptimisation(BinarySubtraction, Vector3f32, Vector3f32); - EnableOptimisation(BinarySubtraction, Vector4f32, Vector4f32); - EnableOptimisation(BinarySubtraction, Vector2f64, Vector2f64); - EnableOptimisation(BinarySubtraction, Vector3f64, Vector3f64); - EnableOptimisation(BinarySubtraction, Vector4f64, Vector4f64); - EnableOptimisation(BinarySubtraction, Vector2i32, Vector2i32); - EnableOptimisation(BinarySubtraction, Vector3i32, Vector3i32); - EnableOptimisation(BinarySubtraction, Vector4i32, Vector4i32); - EnableOptimisation(BinarySubtraction, Vector2u32, Vector2u32); - EnableOptimisation(BinarySubtraction, Vector3u32, Vector3u32); - EnableOptimisation(BinarySubtraction, Vector4u32, Vector4u32); #undef EnableOptimisation } @@ -547,13 +307,60 @@ namespace nzsl::Ast { using T1 = std::decay_t; using T2 = std::decay_t; - using PCType = BinaryConstantPropagation; + using Op = BinaryConstantPropagation; + + if constexpr (Nz::IsComplete_v) + optimized = ShaderBuilder::ConstantValue(Op{}(arg1, arg2, sourceLocation)); + else if constexpr (IsVector_v && IsVector_v) + { + using SubOp = BinaryConstantPropagation; + if constexpr (Nz::IsComplete_v && T1::Dimensions == T2::Dimensions) + { + using RetBaseType = std::decay_t>; + using RetType = Vector; + + RetType value; + for (std::size_t i = 0; i < T1::Dimensions; ++i) + value[i] = SubOp{}(arg1[i], arg2[i], sourceLocation); + + optimized = ShaderBuilder::ConstantValue(value); + } + } + else if constexpr (IsVector_v) + { + using SubOp = BinaryConstantPropagation; + if constexpr (Nz::IsComplete_v) + { + if constexpr (SubOp::AllowSingleOperand) + { + using RetBaseType = std::decay_t>; + using RetType = Vector; + + RetType value; + for (std::size_t i = 0; i < T1::Dimensions; ++i) + value[i] = SubOp{}(arg1[i], arg2, sourceLocation); - if constexpr (IsComplete_v) + optimized = ShaderBuilder::ConstantValue(value); + } + } + } + else if constexpr (IsVector_v) { - using Op = typename PCType::Op; - if constexpr (IsComplete_v) - optimized = Op{}(arg1, arg2, sourceLocation); + using SubOp = BinaryConstantPropagation; + if constexpr (Nz::IsComplete_v) + { + if constexpr (SubOp::AllowSingleOperand) + { + using RetBaseType = std::decay_t>; + using RetType = Vector; + + RetType value; + for (std::size_t i = 0; i < T2::Dimensions; ++i) + value[i] = SubOp{}(arg1, arg2[i], sourceLocation); + + optimized = ShaderBuilder::ConstantValue(value); + } + } } }, lhs.value, rhs.value); diff --git a/src/NZSL/Ast/ConstantPropagationVisitor_BinaryComparison.cpp b/src/NZSL/Ast/ConstantPropagationVisitor_BinaryComparison.cpp index e59564a..ffdfe6f 100644 --- a/src/NZSL/Ast/ConstantPropagationVisitor_BinaryComparison.cpp +++ b/src/NZSL/Ast/ConstantPropagationVisitor_BinaryComparison.cpp @@ -2,8 +2,9 @@ // This file is part of the "Nazara Shading Language" project // For conditions of distribution and use, see copyright notice in Config.hpp -#include +#include #include +#include #include #include #include @@ -13,40 +14,6 @@ namespace nzsl::Ast { namespace NAZARA_ANONYMOUS_NAMESPACE { - template - struct IsCompleteHelper - { - // SFINAE: sizeof in an incomplete type is an error, but since there's another specialization it won't result in a compilation error - template - static auto test(U*) -> std::bool_constant; - - // less specialized overload - static auto test(...) -> std::false_type; - - using type = decltype(test(static_cast(nullptr))); - }; - - template - struct IsComplete : IsCompleteHelper::type {}; - - template - inline constexpr bool IsComplete_v = IsComplete::value; - - /*************************************************************************************************/ - - template - struct IsVector : std::false_type - { - }; - - template - struct IsVector> : std::true_type - { - }; - - template - inline constexpr bool IsVector_v = IsVector::value; - /*************************************************************************************************/ template @@ -54,177 +21,103 @@ namespace nzsl::Ast // CompEq template - struct BinaryCompEqBase + struct BinaryCompEq { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::CompEq; + + NAZARA_FORCEINLINE bool operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - if constexpr (IsVector_v) - return ShaderBuilder::ConstantValue(lhs.ComponentEq(rhs)); - else - return ShaderBuilder::ConstantValue(lhs == rhs); + return lhs == rhs; } }; - template - struct BinaryCompEq; - - template - struct BinaryConstantPropagation - { - using Op = BinaryCompEq; - }; - // CompGe template - struct BinaryCompGeBase + struct BinaryCompGe { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::CompGe; + + NAZARA_FORCEINLINE bool operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - if constexpr (IsVector_v) - return ShaderBuilder::ConstantValue(lhs.ComponentGe(rhs)); - else - return ShaderBuilder::ConstantValue(lhs >= rhs); + return lhs >= rhs; } }; - template - struct BinaryCompGe; - - template - struct BinaryConstantPropagation - { - using Op = BinaryCompGe; - }; - // CompGt template - struct BinaryCompGtBase + struct BinaryCompGt { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::CompGt; + + NAZARA_FORCEINLINE bool operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - if constexpr (IsVector_v) - return ShaderBuilder::ConstantValue(lhs.ComponentGt(rhs)); - else - return ShaderBuilder::ConstantValue(lhs > rhs); + return lhs > rhs; } }; - template - struct BinaryCompGt; - - template - struct BinaryConstantPropagation - { - using Op = BinaryCompGt; - }; - // CompLe template - struct BinaryCompLeBase + struct BinaryCompLe { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::CompLe; + + NAZARA_FORCEINLINE bool operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - if constexpr (IsVector_v) - return ShaderBuilder::ConstantValue(lhs.ComponentLe(rhs)); - else - return ShaderBuilder::ConstantValue(lhs <= rhs); + return lhs <= rhs; } }; - template - struct BinaryCompLe; - - template - struct BinaryConstantPropagation - { - using Op = BinaryCompLe; - }; - // CompLt template - struct BinaryCompLtBase + struct BinaryCompLt { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::CompLt; + + NAZARA_FORCEINLINE bool operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - if constexpr (IsVector_v) - return ShaderBuilder::ConstantValue(lhs.ComponentLt(rhs)); - else - return ShaderBuilder::ConstantValue(lhs < rhs); + return lhs < rhs; } }; - template - struct BinaryCompLt; - - template - struct BinaryConstantPropagation - { - using Op = BinaryCompLe; - }; - // CompNe template - struct BinaryCompNeBase + struct BinaryCompNe { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::CompNe; + + NAZARA_FORCEINLINE bool operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - if constexpr (IsVector_v) - return ShaderBuilder::ConstantValue(lhs.ComponentNe(rhs)); - else - return ShaderBuilder::ConstantValue(lhs != rhs); + return lhs != rhs; } }; - template - struct BinaryCompNe; - - template - struct BinaryConstantPropagation - { - using Op = BinaryCompNe; - }; - // LogicalAnd template - struct BinaryLogicalAndBase + struct BinaryLogicalAnd { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::LogicalAnd; + + NAZARA_FORCEINLINE bool operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs && rhs); + return lhs && rhs; } }; - template - struct BinaryLogicalAnd; - - template - struct BinaryConstantPropagation - { - using Op = BinaryLogicalAnd; - }; - // LogicalOr template - struct BinaryLogicalOrBase + struct BinaryLogicalOr { - std::unique_ptr operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) + static constexpr BinaryType Type = BinaryType::LogicalOr; + + NAZARA_FORCEINLINE bool operator()(const T1& lhs, const T2& rhs, const SourceLocation& /*sourceLocation*/) { - return ShaderBuilder::ConstantValue(lhs || rhs); + return lhs || rhs; } }; - template - struct BinaryLogicalOr; - - template - struct BinaryConstantPropagation - { - using Op = BinaryLogicalOr; - }; - /*************************************************************************************************/ -#define EnableOptimisation(Op, ...) template<> struct Op<__VA_ARGS__> : Op##Base<__VA_ARGS__> {} +#define EnableOptimisation(Impl, ...) template<> struct BinaryConstantPropagation::Type, __VA_ARGS__> : Impl<__VA_ARGS__> {} // Binary @@ -233,113 +126,35 @@ namespace nzsl::Ast EnableOptimisation(BinaryCompEq, float, float); EnableOptimisation(BinaryCompEq, std::int32_t, std::int32_t); EnableOptimisation(BinaryCompEq, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryCompEq, Vector2, Vector2); - EnableOptimisation(BinaryCompEq, Vector3, Vector3); - EnableOptimisation(BinaryCompEq, Vector4, Vector4); - EnableOptimisation(BinaryCompEq, Vector2f32, Vector2f32); - EnableOptimisation(BinaryCompEq, Vector3f32, Vector3f32); - EnableOptimisation(BinaryCompEq, Vector4f32, Vector4f32); - EnableOptimisation(BinaryCompEq, Vector2f64, Vector2f64); - EnableOptimisation(BinaryCompEq, Vector3f64, Vector3f64); - EnableOptimisation(BinaryCompEq, Vector4f64, Vector4f64); - EnableOptimisation(BinaryCompEq, Vector2i32, Vector2i32); - EnableOptimisation(BinaryCompEq, Vector3i32, Vector3i32); - EnableOptimisation(BinaryCompEq, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompEq, Vector2u32, Vector2u32); - EnableOptimisation(BinaryCompEq, Vector3u32, Vector3u32); - EnableOptimisation(BinaryCompEq, Vector4u32, Vector4u32); EnableOptimisation(BinaryCompGe, double, double); EnableOptimisation(BinaryCompGe, float, float); EnableOptimisation(BinaryCompGe, std::int32_t, std::int32_t); EnableOptimisation(BinaryCompGe, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryCompGe, Vector2f32, Vector2f32); - EnableOptimisation(BinaryCompGe, Vector3f32, Vector3f32); - EnableOptimisation(BinaryCompGe, Vector4f32, Vector4f32); - EnableOptimisation(BinaryCompGe, Vector2f64, Vector2f64); - EnableOptimisation(BinaryCompGe, Vector3f64, Vector3f64); - EnableOptimisation(BinaryCompGe, Vector4f64, Vector4f64); - EnableOptimisation(BinaryCompGe, Vector2i32, Vector2i32); - EnableOptimisation(BinaryCompGe, Vector3i32, Vector3i32); - EnableOptimisation(BinaryCompGe, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompGe, Vector2u32, Vector2u32); - EnableOptimisation(BinaryCompGe, Vector3u32, Vector3u32); - EnableOptimisation(BinaryCompGe, Vector4u32, Vector4u32); EnableOptimisation(BinaryCompGt, double, double); EnableOptimisation(BinaryCompGt, float, float); EnableOptimisation(BinaryCompGt, std::int32_t, std::int32_t); EnableOptimisation(BinaryCompGt, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryCompGt, Vector2f32, Vector2f32); - EnableOptimisation(BinaryCompGt, Vector3f32, Vector3f32); - EnableOptimisation(BinaryCompGt, Vector4f32, Vector4f32); - EnableOptimisation(BinaryCompGt, Vector2f64, Vector2f64); - EnableOptimisation(BinaryCompGt, Vector3f64, Vector3f64); - EnableOptimisation(BinaryCompGt, Vector4f64, Vector4f64); - EnableOptimisation(BinaryCompGt, Vector2i32, Vector2i32); - EnableOptimisation(BinaryCompGt, Vector3i32, Vector3i32); - EnableOptimisation(BinaryCompGt, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompGt, Vector2u32, Vector2u32); - EnableOptimisation(BinaryCompGt, Vector3u32, Vector3u32); - EnableOptimisation(BinaryCompGt, Vector4u32, Vector4u32); EnableOptimisation(BinaryCompLe, double, double); EnableOptimisation(BinaryCompLe, float, float); EnableOptimisation(BinaryCompLe, std::int32_t, std::int32_t); EnableOptimisation(BinaryCompLe, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryCompLe, Vector2f32, Vector2f32); - EnableOptimisation(BinaryCompLe, Vector3f32, Vector3f32); - EnableOptimisation(BinaryCompLe, Vector4f32, Vector4f32); - EnableOptimisation(BinaryCompLe, Vector2f64, Vector2f64); - EnableOptimisation(BinaryCompLe, Vector3f64, Vector3f64); - EnableOptimisation(BinaryCompLe, Vector4f64, Vector4f64); - EnableOptimisation(BinaryCompLe, Vector2i32, Vector2i32); - EnableOptimisation(BinaryCompLe, Vector3i32, Vector3i32); - EnableOptimisation(BinaryCompLe, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompLe, Vector2u32, Vector2u32); - EnableOptimisation(BinaryCompLe, Vector3u32, Vector3u32); - EnableOptimisation(BinaryCompLe, Vector4u32, Vector4u32); EnableOptimisation(BinaryCompLt, double, double); EnableOptimisation(BinaryCompLt, float, float); EnableOptimisation(BinaryCompLt, std::int32_t, std::int32_t); EnableOptimisation(BinaryCompLt, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryCompLt, Vector2f32, Vector2f32); - EnableOptimisation(BinaryCompLt, Vector3f32, Vector3f32); - EnableOptimisation(BinaryCompLt, Vector4f32, Vector4f32); - EnableOptimisation(BinaryCompLt, Vector2f64, Vector2f64); - EnableOptimisation(BinaryCompLt, Vector3f64, Vector3f64); - EnableOptimisation(BinaryCompLt, Vector4f64, Vector4f64); - EnableOptimisation(BinaryCompLt, Vector2i32, Vector2i32); - EnableOptimisation(BinaryCompLt, Vector3i32, Vector3i32); - EnableOptimisation(BinaryCompLt, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompLt, Vector2u32, Vector2u32); - EnableOptimisation(BinaryCompLt, Vector3u32, Vector3u32); - EnableOptimisation(BinaryCompLt, Vector4u32, Vector4u32); EnableOptimisation(BinaryCompNe, bool, bool); EnableOptimisation(BinaryCompNe, double, double); EnableOptimisation(BinaryCompNe, float, float); EnableOptimisation(BinaryCompNe, std::int32_t, std::int32_t); EnableOptimisation(BinaryCompNe, std::uint32_t, std::uint32_t); - EnableOptimisation(BinaryCompNe, Vector2, Vector2); - EnableOptimisation(BinaryCompNe, Vector3, Vector3); - EnableOptimisation(BinaryCompNe, Vector4, Vector4); - EnableOptimisation(BinaryCompNe, Vector2f32, Vector2f32); - EnableOptimisation(BinaryCompNe, Vector3f32, Vector3f32); - EnableOptimisation(BinaryCompNe, Vector4f32, Vector4f32); - EnableOptimisation(BinaryCompNe, Vector2f64, Vector2f64); - EnableOptimisation(BinaryCompNe, Vector3f64, Vector3f64); - EnableOptimisation(BinaryCompNe, Vector4f64, Vector4f64); - EnableOptimisation(BinaryCompNe, Vector2i32, Vector2i32); - EnableOptimisation(BinaryCompNe, Vector3i32, Vector3i32); - EnableOptimisation(BinaryCompNe, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompNe, Vector2u32, Vector2u32); - EnableOptimisation(BinaryCompNe, Vector3u32, Vector3u32); - EnableOptimisation(BinaryCompNe, Vector4u32, Vector4u32); EnableOptimisation(BinaryLogicalAnd, bool, bool); - EnableOptimisation(BinaryLogicalOr, bool, bool); + EnableOptimisation(BinaryLogicalOr, bool, bool); #undef EnableOptimisation } @@ -369,13 +184,23 @@ namespace nzsl::Ast { using T1 = std::decay_t; using T2 = std::decay_t; - using PCType = BinaryConstantPropagation; + using Op = BinaryConstantPropagation; - if constexpr (IsComplete_v) + if constexpr (Nz::IsComplete_v) + optimized = ShaderBuilder::ConstantValue(Op{}(arg1, arg2, sourceLocation)); + else if constexpr (IsVector_v && IsVector_v) { - using Op = typename PCType::Op; - if constexpr (IsComplete_v) - optimized = Op{}(arg1, arg2, sourceLocation); + using SubOp = BinaryConstantPropagation; + if constexpr (Nz::IsComplete_v && T1::Dimensions == T2::Dimensions) + { + using RetType = Vector; + + RetType value; + for (std::size_t i = 0; i < T1::Dimensions; ++i) + value[i] = SubOp{}(arg1[i], arg2[i], sourceLocation); + + optimized = ShaderBuilder::ConstantValue(value); + } } }, lhs.value, rhs.value); diff --git a/tests/src/Tests/ErrorsTests.cpp b/tests/src/Tests/ErrorsTests.cpp index a9aa422..d00d981 100644 --- a/tests/src/Tests/ErrorsTests.cpp +++ b/tests/src/Tests/ErrorsTests.cpp @@ -325,7 +325,7 @@ module; const V = vec4[i32](7, 6, 5, 4) / vec4[i32](3, 2, 1, 0); -)"), "(5,11 -> 55): CIntegralDivisionByZero error: integral division by zero in expression (vec4[i32](7, 6, 5, 4) / vec4[i32](3, 2, 1, 0))"); +)"), "(5,11 -> 55): CIntegralDivisionByZero error: integral division by zero in expression (4 / 0)"); CHECK_THROWS_WITH(Compile(R"( [nzsl_version("1.0")] @@ -341,7 +341,7 @@ module; const V = vec4[i32](7, 6, 5, 4) % vec4[i32](3, 2, 1, 0); -)"), "(5,11 -> 55): CIntegralModuloByZero error: integral modulo by zero in expression (vec4[i32](7, 6, 5, 4) % vec4[i32](3, 2, 1, 0))"); +)"), "(5,11 -> 55): CIntegralModuloByZero error: integral modulo by zero in expression (4 % 0)"); CHECK_THROWS_WITH(Compile(R"( [nzsl_version("1.0")]