diff --git a/sway-lib-core/src/ops.sw b/sway-lib-core/src/ops.sw index 371bbff5382..a79068a0e95 100644 --- a/sway-lib-core/src/ops.sw +++ b/sway-lib-core/src/ops.sw @@ -56,38 +56,62 @@ impl Add for u64 { // Emulate overflowing arithmetic for non-64-bit integer types impl Add for u32 { fn add(self, other: Self) -> Self { - // any non-64-bit value is compiled to a u64 value under-the-hood - // constants (like Self::max() below) are also automatically promoted to u64 - let res = __add(self, other); - // integer overflow - if __gt(res, Self::max()) { + let self_u64 = asm(input: self) { + input: u64 + }; + let other_u64 = asm(input: other) { + input: u64 + }; + let res_u64 = __add(self_u64, other_u64); + let max_u32_u64 = asm(input: Self::max()) { + input: u64 + }; + if __gt(res_u64, max_u32_u64) { if panic_on_overflow_is_enabled() { __revert(0) } else { // overflow enabled // res % (Self::max() + 1) - __mod(res, __add(Self::max(), 1)) + let res_u64 = __mod(res_u64, __add(max_u32_u64, 1)); + asm(input: res_u64) { + input: u32 + } } } else { - // no overflow - res + asm(input: res_u64) { + input: u32 + } } } } impl Add for u16 { fn add(self, other: Self) -> Self { - let res = __add(self, other); - if __gt(res, Self::max()) { + let self_u64 = asm(input: self) { + input: u64 + }; + let other_u64 = asm(input: other) { + input: u64 + }; + let res_u64 = __add(self_u64, other_u64); + let max_u16_u64 = asm(input: Self::max()) { + input: u64 + }; + if __gt(res_u64, max_u16_u64) { if panic_on_overflow_is_enabled() { __revert(0) } else { // overflow enabled // res % (Self::max() + 1) - __mod(res, __add(Self::max(), 1)) + let res_u64 = __mod(res_u64, __add(max_u16_u64, 1)); + asm(input: res_u64) { + input: u16 + } } } else { - res + asm(input: res_u64) { + input: u16 + } } } } @@ -173,23 +197,96 @@ impl Subtract for u64 { } } -// unlike addition, underflowing subtraction does not need special treatment -// because VM handles underflow impl Subtract for u32 { fn subtract(self, other: Self) -> Self { - __sub(self, other) + let self_u64 = asm(input: self) { + input: u64 + }; + let other_u64 = asm(input: other) { + input: u64 + }; + let res_u64 = __sub(self_u64, other_u64); + let max_u32_u64 = asm(input: Self::max()) { + input: u64 + }; + if __gt(res_u64, max_u32_u64) { + if panic_on_overflow_is_enabled() { + __revert(0) + } else { + // overflow enabled + // res % (Self::max() + 1) + let res_u64 = __mod(res_u64, __add(max_u32_u64, 1)); + asm(input: res_u64) { + input: u32 + } + } + } else { + asm(input: res_u64) { + input: u32 + } + } } } impl Subtract for u16 { fn subtract(self, other: Self) -> Self { - __sub(self, other) + let self_u64 = asm(input: self) { + input: u64 + }; + let other_u64 = asm(input: other) { + input: u64 + }; + let res_u64 = __sub(self_u64, other_u64); + let max_u16_u64 = asm(input: Self::max()) { + input: u64 + }; + if __gt(res_u64, max_u16_u64) { + if panic_on_overflow_is_enabled() { + __revert(0) + } else { + // overflow enabled + // res % (Self::max() + 1) + let res_u64 = __mod(res_u64, __add(max_u16_u64, 1)); + asm(input: res_u64) { + input: u16 + } + } + } else { + asm(input: res_u64) { + input: u16 + } + } } } impl Subtract for u8 { fn subtract(self, other: Self) -> Self { - __sub(self, other) + let self_u64 = asm(input: self) { + input: u64 + }; + let other_u64 = asm(input: other) { + input: u64 + }; + let res_u64 = __sub(self_u64, other_u64); + let max_u8_u64 = asm(input: Self::max()) { + input: u64 + }; + if __gt(res_u64, max_u8_u64) { + if panic_on_overflow_is_enabled() { + __revert(0) + } else { + // overflow enabled + // res % (Self::max() + 1) + let res_u64 = __mod(res_u64, __add(max_u8_u64, 1)); + asm(input: res_u64) { + input: u8 + } + } + } else { + asm(input: res_u64) { + input: u8 + } + } } } @@ -246,36 +343,62 @@ impl Multiply for u64 { // Emulate overflowing arithmetic for non-64-bit integer types impl Multiply for u32 { fn multiply(self, other: Self) -> Self { - // any non-64-bit value is compiled to a u64 value under-the-hood - // constants (like Self::max() below) are also automatically promoted to u64 - let res = __mul(self, other); - if __gt(res, Self::max()) { + let self_u64 = asm(input: self) { + input: u64 + }; + let other_u64 = asm(input: other) { + input: u64 + }; + let res_u64 = __mul(self_u64, other_u64); + let max_u32_u64 = asm(input: Self::max()) { + input: u64 + }; + if __gt(res_u64, max_u32_u64) { if panic_on_overflow_is_enabled() { - // integer overflow __revert(0) } else { // overflow enabled // res % (Self::max() + 1) - __mod(res, __add(Self::max(), 1)) + let res_u64 = __mod(res_u64, __add(max_u32_u64, 1)); + asm(input: res_u64) { + input: u32 + } } } else { - // no overflow - res + asm(input: res_u64) { + input: u32 + } } } } impl Multiply for u16 { fn multiply(self, other: Self) -> Self { - let res = __mul(self, other); - if __gt(res, Self::max()) { + let self_u64 = asm(input: self) { + input: u64 + }; + let other_u64 = asm(input: other) { + input: u64 + }; + let res_u64 = __mul(self_u64, other_u64); + let max_u16_u64 = asm(input: Self::max()) { + input: u64 + }; + if __gt(res_u64, max_u16_u64) { if panic_on_overflow_is_enabled() { __revert(0) } else { - __mod(res, __add(Self::max(), 1)) + // overflow enabled + // res % (Self::max() + 1) + let res_u64 = __mod(res_u64, __add(max_u16_u64, 1)); + asm(input: res_u64) { + input: u16 + } } } else { - res + asm(input: res_u64) { + input: u16 + } } } } diff --git a/test/src/in_language_tests/test_programs/math_inline_tests/src/main.sw b/test/src/in_language_tests/test_programs/math_inline_tests/src/main.sw index c2641d71272..de0c4841544 100644 --- a/test/src/in_language_tests/test_programs/math_inline_tests/src/main.sw +++ b/test/src/in_language_tests/test_programs/math_inline_tests/src/main.sw @@ -900,6 +900,22 @@ fn math_u256_overflow_mul_revert() { log(b); } +#[test(should_revert)] +fn math_u16_underflow_sub_revert() { + let a = 0u16; + let b = 1u16; + let c = a - b; + log(c); +} + +#[test(should_revert)] +fn math_u32_underflow_sub_revert() { + let a = 0u32; + let b = 1u32; + let c = a - b; + log(c); +} + #[test] fn math_u8_overflow_add() { let _ = disable_panic_on_overflow(); @@ -922,6 +938,26 @@ fn math_u8_overflow_add() { require(e == u8::max() - 2, e); } +#[test] +fn math_u8_underflow_sub() { + assert((u8::max() - u8::max()) == 0u8); + assert((u8::min() - u8::min()) == 0u8); + assert((10u8 - 5u8) == 5u8); + + let _ = disable_panic_on_overflow(); + + let a = 0u8; + let b = 1u8; + + let c = a - b; + assert(c == u8::max()); + + let d = u8::max(); + + let e = a - d; + assert(e == b); +} + #[test] fn math_u16_overflow_add() { let _ = disable_panic_on_overflow(); @@ -944,6 +980,26 @@ fn math_u16_overflow_add() { require(e == u16::max() - 2, e); } +#[test] +fn math_u16_underflow_sub() { + assert((u16::max() - u16::max()) == 0u16); + assert((u16::min() - u16::min()) == 0u16); + assert((10u16 - 5u16) == 5u16); + + let _ = disable_panic_on_overflow(); + + let a = 0u16; + let b = 1u16; + + let c = a - b; + assert(c == u16::max()); + + let d = u16::max(); + + let e = a - d; + assert(e == b); +} + #[test] fn math_u32_overflow_add() { let _ = disable_panic_on_overflow(); @@ -966,6 +1022,26 @@ fn math_u32_overflow_add() { require(e == u32::max() - 2, e); } +#[test] +fn math_u32_underflow_sub() { + assert((u32::max() - u32::max()) == 0u32); + assert((u32::min() - u32::min()) == 0u32); + assert((10u32 - 5u32) == 5u32); + + let _ = disable_panic_on_overflow(); + + let a = 0u32; + let b = 1u32; + + let c = a - b; + assert(c == u32::max()); + + let d = u32::max(); + + let e = a - d; + assert(e == b); +} + #[test] fn math_u64_overflow_add() { let _ = disable_panic_on_overflow();