From adcd715d078e106b6fd33ad498231453040355cf Mon Sep 17 00:00:00 2001 From: John-John Tedro Date: Sun, 3 Nov 2024 15:08:00 +0100 Subject: [PATCH] rune: Clean up binary operations --- crates/rune/src/runtime/hasher.rs | 10 - crates/rune/src/runtime/value.rs | 342 ++++++++---------------- crates/rune/src/runtime/value/inline.rs | 93 +++++-- 3 files changed, 171 insertions(+), 274 deletions(-) diff --git a/crates/rune/src/runtime/hasher.rs b/crates/rune/src/runtime/hasher.rs index 34a7183c3..ee5300644 100644 --- a/crates/rune/src/runtime/hasher.rs +++ b/crates/rune/src/runtime/hasher.rs @@ -28,16 +28,6 @@ impl Hasher { self.hasher.write(string.as_bytes()); } - /// Hash an 64-bit float. - /// - /// You should ensure that the float is normal per the [`f64::is_normal`] - /// function before hashing it, since otherwise equality tests against the - /// float won't work as intended. Otherwise, know what you're doing. - pub(crate) fn write_f64(&mut self, value: f64) { - let bits = value.to_bits(); - self.hasher.write_u64(bits); - } - /// Construct a hash. pub fn finish(&self) -> u64 { self.hasher.finish() diff --git a/crates/rune/src/runtime/value.rs b/crates/rune/src/runtime/value.rs index df7f1bc32..fefcceef6 100644 --- a/crates/rune/src/runtime/value.rs +++ b/crates/rune/src/runtime/value.rs @@ -15,8 +15,6 @@ pub use self::data::{EmptyStruct, Struct, TupleStruct}; use core::any; use core::cmp::Ordering; use core::fmt; -#[cfg(feature = "alloc")] -use core::hash::Hasher as _; use core::mem::replace; use core::ptr::NonNull; @@ -26,7 +24,6 @@ use crate::alloc::fmt::TryWrite; use crate::alloc::prelude::*; use crate::alloc::{self, String}; use crate::compile::meta; -use crate::runtime; use crate::{Any, Hash, TypeHash}; use super::{ @@ -1151,135 +1148,19 @@ impl Value { b: &Value, caller: &mut dyn ProtocolCaller, ) -> VmResult { - match (vm_try!(self.as_ref()), vm_try!(b.as_ref())) { - (ReprRef::Inline(lhs), ReprRef::Inline(rhs)) => { - return VmResult::Ok(vm_try!(lhs.partial_eq(rhs))); - } - (ReprRef::Inline(lhs), rhs) => { - return err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::PARTIAL_EQ.name, - lhs: lhs.type_info(), - rhs: vm_try!(rhs.type_info()), - }); - } - (ReprRef::Mutable(lhs), ReprRef::Mutable(rhs)) => { - let lhs = vm_try!(lhs.borrow_ref()); - let rhs = vm_try!(rhs.borrow_ref()); - - let (lhs_rtti, lhs) = lhs.as_parts(); - let (rhs_rtti, rhs) = rhs.as_parts(); - - if lhs_rtti.hash == rhs_rtti.hash { - if lhs_rtti.variant_hash != rhs_rtti.variant_hash { - return VmResult::Ok(false); - } - - return Vec::eq_with(lhs, rhs, Value::partial_eq_with, caller); - } - - return err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::PARTIAL_EQ.name, - lhs: lhs_rtti.clone().type_info(), - rhs: rhs_rtti.clone().type_info(), - }); - } - (ReprRef::Any(value), _) => match value.type_hash() { - runtime::Vec::HASH => { - let vec = vm_try!(value.borrow_ref::()); - return Vec::partial_eq_with(&vec, b.clone(), caller); - } - runtime::OwnedTuple::HASH => { - let tuple = vm_try!(value.borrow_ref::()); - return Vec::partial_eq_with(&tuple, b.clone(), caller); - } - _ => {} - }, - _ => {} - } - - if let CallResultOnly::Ok(value) = vm_try!(caller.try_call_protocol_fn( + self.bin_op_with( + b, + caller, Protocol::PARTIAL_EQ, - self.clone(), - &mut Some((b.clone(),)) - )) { - return VmResult::Ok(vm_try!(<_>::from_value(value))); - } - - err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::PARTIAL_EQ.name, - lhs: vm_try!(self.type_info()), - rhs: vm_try!(b.type_info()), - }) - } - - /// Hash the current value. - #[cfg(feature = "alloc")] - pub fn hash(&self, hasher: &mut Hasher) -> VmResult<()> { - self.hash_with(hasher, &mut EnvProtocolCaller) - } - - /// Hash the current value. - #[cfg_attr(feature = "bench", inline(never))] - pub(crate) fn hash_with( - &self, - hasher: &mut Hasher, - caller: &mut dyn ProtocolCaller, - ) -> VmResult<()> { - match vm_try!(self.as_ref()) { - ReprRef::Inline(value) => match value { - Inline::Unsigned(value) => { - hasher.write_u64(*value); - return VmResult::Ok(()); - } - Inline::Signed(value) => { - hasher.write_i64(*value); - return VmResult::Ok(()); + Inline::partial_eq, + |lhs, rhs, caller| { + if lhs.0.variant_hash != rhs.0.variant_hash { + return VmResult::Ok(false); } - // Care must be taken whan hashing floats, to ensure that `hash(v1) - // === hash(v2)` if `eq(v1) === eq(v2)`. Hopefully we accomplish - // this by rejecting NaNs and rectifying subnormal values of zero. - Inline::Float(value) => { - if value.is_nan() { - return VmResult::err(VmErrorKind::IllegalFloatOperation { value: *value }); - } - let zero = *value == 0.0; - hasher.write_f64((zero as u8 as f64) * 0.0 + (!zero as u8 as f64) * *value); - return VmResult::Ok(()); - } - operand => { - return err(VmErrorKind::UnsupportedUnaryOperation { - op: Protocol::HASH.name, - operand: operand.type_info(), - }); - } - }, - ReprRef::Any(value) => match value.type_hash() { - Vec::HASH => { - let vec = vm_try!(value.borrow_ref::()); - return Vec::hash_with(&vec, hasher, caller); - } - OwnedTuple::HASH => { - let tuple = vm_try!(value.borrow_ref::()); - return Tuple::hash_with(&tuple, hasher, caller); - } - _ => {} + Vec::eq_with(lhs.1, rhs.1, Value::partial_eq_with, caller) }, - _ => {} - } - - let mut args = DynGuardedArgs::new((hasher,)); - - if let CallResultOnly::Ok(value) = - vm_try!(caller.try_call_protocol_fn(Protocol::HASH, self.clone(), &mut args)) - { - return VmResult::Ok(vm_try!(<_>::from_value(value))); - } - - err(VmErrorKind::UnsupportedUnaryOperation { - op: Protocol::HASH.name, - operand: vm_try!(self.type_info()), - }) + ) } /// Perform a total equality test between two values. @@ -1301,53 +1182,12 @@ impl Value { /// This is the basis for the eq operation (`==`). #[cfg_attr(feature = "bench", inline(never))] pub(crate) fn eq_with(&self, b: &Value, caller: &mut dyn ProtocolCaller) -> VmResult { - match (vm_try!(self.as_ref()), vm_try!(b.as_ref())) { - (ReprRef::Inline(lhs), ReprRef::Inline(rhs)) => { - return lhs.eq(rhs); - } - (ReprRef::Inline(lhs), rhs) => { - return err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::EQ.name, - lhs: lhs.type_info(), - rhs: vm_try!(rhs.type_info()), - }); + self.bin_op_with(b, caller, Protocol::EQ, Inline::eq, |lhs, rhs, caller| { + if lhs.0.variant_hash != rhs.0.variant_hash { + return VmResult::Ok(false); } - (ReprRef::Mutable(lhs), ReprRef::Mutable(rhs)) => { - let lhs = vm_try!(lhs.borrow_ref()); - let rhs = vm_try!(rhs.borrow_ref()); - let (lhs_rtti, lhs) = lhs.as_parts(); - let (rhs_rtti, rhs) = rhs.as_parts(); - - if lhs_rtti.hash == rhs_rtti.hash { - if lhs_rtti.variant_hash != rhs_rtti.variant_hash { - return VmResult::Ok(false); - } - - return Vec::eq_with(lhs, rhs, Value::eq_with, caller); - } - - return err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::EQ.name, - lhs: lhs_rtti.clone().type_info(), - rhs: rhs_rtti.clone().type_info(), - }); - } - _ => {} - } - - if let CallResultOnly::Ok(value) = vm_try!(caller.try_call_protocol_fn( - Protocol::EQ, - self.clone(), - &mut Some((b.clone(),)) - )) { - return VmResult::Ok(vm_try!(<_>::from_value(value))); - } - - err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::EQ.name, - lhs: vm_try!(self.type_info()), - rhs: vm_try!(b.type_info()), + Vec::eq_with(lhs.1, rhs.1, Value::eq_with, caller) }) } @@ -1374,56 +1214,21 @@ impl Value { b: &Value, caller: &mut dyn ProtocolCaller, ) -> VmResult> { - match (vm_try!(self.as_ref()), vm_try!(b.as_ref())) { - (ReprRef::Inline(lhs), ReprRef::Inline(rhs)) => { - return VmResult::Ok(vm_try!(lhs.partial_cmp(rhs))) - } - (ReprRef::Inline(lhs), rhs) => { - return err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::PARTIAL_CMP.name, - lhs: lhs.type_info(), - rhs: vm_try!(rhs.type_info()), - }) - } - (ReprRef::Mutable(lhs), ReprRef::Mutable(rhs)) => { - let lhs = vm_try!(lhs.borrow_ref()); - let rhs = vm_try!(rhs.borrow_ref()); - - let (lhs_rtti, lhs) = lhs.as_parts(); - let (rhs_rtti, rhs) = rhs.as_parts(); - - if lhs_rtti.hash == rhs_rtti.hash { - let ord = lhs_rtti.variant_hash.cmp(&rhs_rtti.variant_hash); - - if ord != Ordering::Equal { - return VmResult::Ok(Some(ord)); - } + self.bin_op_with( + b, + caller, + Protocol::PARTIAL_CMP, + Inline::partial_cmp, + |lhs, rhs, caller| { + let ord = lhs.0.variant_hash.cmp(&rhs.0.variant_hash); - return Vec::partial_cmp_with(lhs, rhs, caller); + if ord != Ordering::Equal { + return VmResult::Ok(Some(ord)); } - return err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::PARTIAL_CMP.name, - lhs: lhs_rtti.clone().type_info(), - rhs: rhs_rtti.clone().type_info(), - }); - } - _ => {} - } - - if let CallResultOnly::Ok(value) = vm_try!(caller.try_call_protocol_fn( - Protocol::PARTIAL_CMP, - self.clone(), - &mut Some((b.clone(),)) - )) { - return VmResult::Ok(vm_try!(<_>::from_value(value))); - } - - err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::PARTIAL_CMP.name, - lhs: vm_try!(self.type_info()), - rhs: vm_try!(b.type_info()), - }) + Vec::partial_cmp_with(lhs.1, rhs.1, caller) + }, + ) } /// Perform a total ordering comparison between two values. @@ -1449,11 +1254,82 @@ impl Value { b: &Value, caller: &mut dyn ProtocolCaller, ) -> VmResult { + self.bin_op_with(b, caller, Protocol::CMP, Inline::cmp, |lhs, rhs, caller| { + let ord = lhs.0.variant_hash.cmp(&rhs.0.variant_hash); + + if ord != Ordering::Equal { + return VmResult::Ok(ord); + } + + Vec::cmp_with(lhs.1, rhs.1, caller) + }) + } + + /// Hash the current value. + #[cfg(feature = "alloc")] + pub fn hash(&self, hasher: &mut Hasher) -> VmResult<()> { + self.hash_with(hasher, &mut EnvProtocolCaller) + } + + /// Hash the current value. + #[cfg_attr(feature = "bench", inline(never))] + pub(crate) fn hash_with( + &self, + hasher: &mut Hasher, + caller: &mut dyn ProtocolCaller, + ) -> VmResult<()> { + match vm_try!(self.as_ref()) { + ReprRef::Inline(value) => return VmResult::Ok(vm_try!(value.hash(hasher))), + ReprRef::Any(value) => match value.type_hash() { + Vec::HASH => { + let vec = vm_try!(value.borrow_ref::()); + return Vec::hash_with(&vec, hasher, caller); + } + OwnedTuple::HASH => { + let tuple = vm_try!(value.borrow_ref::()); + return Tuple::hash_with(&tuple, hasher, caller); + } + _ => {} + }, + _ => {} + } + + let mut args = DynGuardedArgs::new((hasher,)); + + if let CallResultOnly::Ok(value) = + vm_try!(caller.try_call_protocol_fn(Protocol::HASH, self.clone(), &mut args)) + { + return VmResult::Ok(vm_try!(<_>::from_value(value))); + } + + err(VmErrorKind::UnsupportedUnaryOperation { + op: Protocol::HASH.name, + operand: vm_try!(self.type_info()), + }) + } + + fn bin_op_with( + &self, + b: &Value, + caller: &mut dyn ProtocolCaller, + protocol: Protocol, + inline: fn(&Inline, &Inline) -> Result, + dynamic: fn( + (&Arc, &[Value]), + (&Arc, &[Value]), + &mut dyn ProtocolCaller, + ) -> VmResult, + ) -> VmResult + where + T: FromValue, + { match (vm_try!(self.as_ref()), vm_try!(b.as_ref())) { - (ReprRef::Inline(lhs), ReprRef::Inline(rhs)) => return lhs.cmp(rhs), + (ReprRef::Inline(lhs), ReprRef::Inline(rhs)) => { + return VmResult::Ok(vm_try!(inline(lhs, rhs))) + } (ReprRef::Inline(lhs), rhs) => { return VmResult::err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::CMP.name, + op: protocol.name, lhs: lhs.type_info(), rhs: vm_try!(rhs.type_info()), }); @@ -1466,17 +1342,11 @@ impl Value { let (rhs_rtti, rhs) = rhs.as_parts(); if lhs_rtti.hash == rhs_rtti.hash { - let ord = lhs_rtti.variant_hash.cmp(&rhs_rtti.variant_hash); - - if ord != Ordering::Equal { - return VmResult::Ok(ord); - } - - return Vec::cmp_with(lhs, rhs, caller); + return dynamic((lhs_rtti, lhs), (rhs_rtti, rhs), caller); } return VmResult::err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::CMP.name, + op: protocol.name, lhs: lhs_rtti.clone().type_info(), rhs: rhs_rtti.clone().type_info(), }); @@ -1484,16 +1354,14 @@ impl Value { _ => {} } - if let CallResultOnly::Ok(value) = vm_try!(caller.try_call_protocol_fn( - Protocol::CMP, - self.clone(), - &mut Some((b.clone(),)) - )) { - return VmResult::Ok(vm_try!(<_>::from_value(value))); + if let CallResultOnly::Ok(value) = + vm_try!(caller.try_call_protocol_fn(protocol, self.clone(), &mut Some((b.clone(),)))) + { + return VmResult::Ok(vm_try!(T::from_value(value))); } err(VmErrorKind::UnsupportedBinaryOperation { - op: Protocol::CMP.name, + op: protocol.name, lhs: vm_try!(self.type_info()), rhs: vm_try!(b.type_info()), }) diff --git a/crates/rune/src/runtime/value/inline.rs b/crates/rune/src/runtime/value/inline.rs index ceb970c23..7cab10f1f 100644 --- a/crates/rune/src/runtime/value/inline.rs +++ b/crates/rune/src/runtime/value/inline.rs @@ -1,18 +1,17 @@ use core::any; use core::cmp::Ordering; use core::fmt; +use core::hash::Hash as _; use musli::{Decode, Encode}; use serde::{Deserialize, Serialize}; use crate::hash::Hash; use crate::runtime::{ - OwnedTuple, Protocol, RuntimeError, Type, TypeInfo, VmErrorKind, VmIntegerRepr, VmResult, + Hasher, OwnedTuple, Protocol, RuntimeError, Type, TypeInfo, VmErrorKind, VmIntegerRepr, }; use crate::TypeHash; -use super::err; - /// An inline value. #[derive(Clone, Copy, Encode, Decode, Deserialize, Serialize)] pub enum Inline { @@ -93,27 +92,30 @@ impl Inline { } /// Perform a total equality check over two inline values. - pub(crate) fn eq(&self, other: &Self) -> VmResult { + pub(crate) fn eq(&self, other: &Self) -> Result { match (self, other) { - (Inline::Unit, Inline::Unit) => VmResult::Ok(true), - (Inline::Bool(a), Inline::Bool(b)) => VmResult::Ok(*a == *b), - (Inline::Char(a), Inline::Char(b)) => VmResult::Ok(*a == *b), - (Inline::Unsigned(a), Inline::Unsigned(b)) => VmResult::Ok(*a == *b), - (Inline::Signed(a), Inline::Signed(b)) => VmResult::Ok(*a == *b), + (Inline::Unit, Inline::Unit) => Ok(true), + (Inline::Bool(a), Inline::Bool(b)) => Ok(*a == *b), + (Inline::Char(a), Inline::Char(b)) => Ok(*a == *b), + (Inline::Unsigned(a), Inline::Unsigned(b)) => Ok(*a == *b), + (Inline::Signed(a), Inline::Signed(b)) => Ok(*a == *b), (Inline::Float(a), Inline::Float(b)) => { let Some(ordering) = a.partial_cmp(b) else { - return VmResult::err(VmErrorKind::IllegalFloatComparison { lhs: *a, rhs: *b }); + return Err(RuntimeError::new(VmErrorKind::IllegalFloatComparison { + lhs: *a, + rhs: *b, + })); }; - VmResult::Ok(matches!(ordering, Ordering::Equal)) + Ok(matches!(ordering, Ordering::Equal)) } - (Inline::Type(a), Inline::Type(b)) => VmResult::Ok(*a == *b), - (Inline::Ordering(a), Inline::Ordering(b)) => VmResult::Ok(*a == *b), - (lhs, rhs) => err(VmErrorKind::UnsupportedBinaryOperation { + (Inline::Type(a), Inline::Type(b)) => Ok(*a == *b), + (Inline::Ordering(a), Inline::Ordering(b)) => Ok(*a == *b), + (lhs, rhs) => Err(RuntimeError::new(VmErrorKind::UnsupportedBinaryOperation { op: Protocol::EQ.name, lhs: lhs.type_info(), rhs: rhs.type_info(), - }), + })), } } @@ -147,29 +149,66 @@ impl Inline { } /// Total comparison implementation for inline. - pub(crate) fn cmp(&self, other: &Self) -> VmResult { + pub(crate) fn cmp(&self, other: &Self) -> Result { match (self, other) { - (Inline::Unit, Inline::Unit) => VmResult::Ok(Ordering::Equal), - (Inline::Bool(a), Inline::Bool(b)) => VmResult::Ok(a.cmp(b)), - (Inline::Char(a), Inline::Char(b)) => VmResult::Ok(a.cmp(b)), - (Inline::Unsigned(a), Inline::Unsigned(b)) => VmResult::Ok(a.cmp(b)), - (Inline::Signed(a), Inline::Signed(b)) => VmResult::Ok(a.cmp(b)), + (Inline::Unit, Inline::Unit) => Ok(Ordering::Equal), + (Inline::Bool(a), Inline::Bool(b)) => Ok(a.cmp(b)), + (Inline::Char(a), Inline::Char(b)) => Ok(a.cmp(b)), + (Inline::Unsigned(a), Inline::Unsigned(b)) => Ok(a.cmp(b)), + (Inline::Signed(a), Inline::Signed(b)) => Ok(a.cmp(b)), (Inline::Float(a), Inline::Float(b)) => { let Some(ordering) = a.partial_cmp(b) else { - return VmResult::err(VmErrorKind::IllegalFloatComparison { lhs: *a, rhs: *b }); + return Err(RuntimeError::new(VmErrorKind::IllegalFloatComparison { + lhs: *a, + rhs: *b, + })); }; - VmResult::Ok(ordering) + Ok(ordering) } - (Inline::Type(a), Inline::Type(b)) => VmResult::Ok(a.cmp(b)), - (Inline::Ordering(a), Inline::Ordering(b)) => VmResult::Ok(a.cmp(b)), - (lhs, rhs) => VmResult::err(VmErrorKind::UnsupportedBinaryOperation { + (Inline::Type(a), Inline::Type(b)) => Ok(a.cmp(b)), + (Inline::Ordering(a), Inline::Ordering(b)) => Ok(a.cmp(b)), + (lhs, rhs) => Err(RuntimeError::new(VmErrorKind::UnsupportedBinaryOperation { op: Protocol::CMP.name, lhs: lhs.type_info(), rhs: rhs.type_info(), - }), + })), } } + + /// Hash an inline value. + pub(crate) fn hash(&self, hasher: &mut Hasher) -> Result<(), RuntimeError> { + match self { + Inline::Unsigned(value) => { + value.hash(hasher); + } + Inline::Signed(value) => { + value.hash(hasher); + } + // Care must be taken whan hashing floats, to ensure that `hash(v1) + // === hash(v2)` if `eq(v1) === eq(v2)`. Hopefully we accomplish + // this by rejecting NaNs and rectifying subnormal values of zero. + Inline::Float(value) => { + if value.is_nan() { + return Err(RuntimeError::new(VmErrorKind::IllegalFloatOperation { + value: *value, + })); + } + + let zero = *value == 0.0; + let value = ((zero as u8 as f64) * 0.0 + (!zero as u8 as f64) * *value).to_bits(); + value.hash(hasher); + } + operand => { + return Err(RuntimeError::new(VmErrorKind::UnsupportedUnaryOperation { + op: Protocol::HASH.name, + operand: operand.type_info(), + })); + } + } + + Ok(()) + } } impl fmt::Debug for Inline {