diff --git a/crates/formality-check/src/impls.rs b/crates/formality-check/src/impls.rs index a8e04a03..b7b9b491 100644 --- a/crates/formality-check/src/impls.rs +++ b/crates/formality-check/src/impls.rs @@ -2,12 +2,12 @@ use anyhow::bail; use fn_error_context::context; use formality_core::Downcasted; -use formality_prove::Env; +use formality_prove::{Env, Safety}; use formality_rust::{ grammar::{ AssociatedTy, AssociatedTyBoundData, AssociatedTyValue, AssociatedTyValueBoundData, Fn, - FnBoundData, ImplItem, NegTraitImpl, NegTraitImplBoundData, TraitBoundData, TraitImpl, - TraitImplBoundData, TraitItem, + FnBoundData, ImplItem, NegTraitImpl, NegTraitImplBoundData, Trait, TraitBoundData, + TraitImpl, TraitImplBoundData, TraitItem, }, prove::ToWcs, }; @@ -17,10 +17,8 @@ use formality_types::{ }; impl super::Check<'_> { - #[context("check_trait_impl({v:?})")] - pub(super) fn check_trait_impl(&self, v: &TraitImpl) -> Fallible<()> { - let TraitImpl { binder } = v; - + #[context("check_trait_impl({trait_impl:?})")] + pub(super) fn check_trait_impl(&self, trait_impl: &TraitImpl) -> Fallible<()> { let mut env = Env::default(); let TraitImplBoundData { @@ -29,7 +27,7 @@ impl super::Check<'_> { trait_parameters, where_clauses, impl_items, - } = env.instantiate_universally(binder); + } = env.instantiate_universally(&trait_impl.binder); let trait_ref = trait_id.with(self_ty, trait_parameters); @@ -45,6 +43,8 @@ impl super::Check<'_> { trait_items, } = trait_decl.binder.instantiate_with(&trait_ref.parameters)?; + self.check_safety_matches(&trait_decl, &trait_impl)?; + for impl_item in &impl_items { self.check_trait_impl_item(&env, &where_clauses, &trait_items, impl_item)?; } @@ -52,7 +52,8 @@ impl super::Check<'_> { Ok(()) } - pub(super) fn check_neg_trait_impl(&self, i: &NegTraitImpl) -> Fallible<()> { + #[context("check_neg_trait_impl({trait_impl:?})")] + pub(super) fn check_neg_trait_impl(&self, trait_impl: &NegTraitImpl) -> Fallible<()> { let mut env = Env::default(); let NegTraitImplBoundData { @@ -60,10 +61,15 @@ impl super::Check<'_> { self_ty, trait_parameters, where_clauses, - } = env.instantiate_universally(&i.binder); + } = env.instantiate_universally(&trait_impl.binder); let trait_ref = trait_id.with(self_ty, trait_parameters); + // Negative impls are always safe (rustc E0198) regardless of the trait's safety. + if trait_impl.safety == Safety::Unsafe { + bail!("negative impls cannot be unsafe"); + } + self.prove_where_clauses_well_formed(&env, &where_clauses, &where_clauses)?; self.prove_goal(&env, &where_clauses, trait_ref.not_implemented())?; @@ -71,6 +77,20 @@ impl super::Check<'_> { Ok(()) } + /// Validate that the declared safety of an impl matches the one from the trait declaration. + fn check_safety_matches(&self, trait_decl: &Trait, trait_impl: &TraitImpl) -> Fallible<()> { + if trait_decl.safety != trait_impl.safety { + match trait_decl.safety { + Safety::Safe => bail!("implementing the trait `{:?}` is not unsafe", trait_decl.id), + Safety::Unsafe => bail!( + "the trait `{:?}` requires an `unsafe impl` declaration", + trait_decl.id + ), + } + } + Ok(()) + } + fn check_trait_impl_item( &self, env: &Env, diff --git a/crates/formality-check/src/traits.rs b/crates/formality-check/src/traits.rs index 1347ce92..22791031 100644 --- a/crates/formality-check/src/traits.rs +++ b/crates/formality-check/src/traits.rs @@ -8,7 +8,11 @@ use formality_types::grammar::Fallible; impl super::Check<'_> { #[context("check_trait({:?})", t.id)] pub(super) fn check_trait(&self, t: &Trait) -> Fallible<()> { - let Trait { id: _, binder } = t; + let Trait { + safety: _, + id: _, + binder, + } = t; let mut env = Env::default(); let TraitBoundData { diff --git a/crates/formality-core/src/judgment/test_filtered.rs b/crates/formality-core/src/judgment/test_filtered.rs index d162c22a..d36f912d 100644 --- a/crates/formality-core/src/judgment/test_filtered.rs +++ b/crates/formality-core/src/judgment/test_filtered.rs @@ -1,8 +1,8 @@ #![cfg(test)] use std::sync::Arc; -use crate::cast_impl; +use crate::cast_impl; use crate::judgment_fn; #[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Debug, Hash)] @@ -24,14 +24,14 @@ impl Graph { judgment_fn!( fn transitive_reachable(g: Arc, node: u32) => u32 { debug(node, g) - + ( (graph.successors(a) => b) (if b % 2 == 0) --------------------------------------- ("base") (transitive_reachable(graph, a) => b) ) - + ( (transitive_reachable(&graph, a) => b) (transitive_reachable(&graph, b) => c) diff --git a/crates/formality-macros/src/debug.rs b/crates/formality-macros/src/debug.rs index 98401013..69c1a206 100644 --- a/crates/formality-macros/src/debug.rs +++ b/crates/formality-macros/src/debug.rs @@ -242,7 +242,7 @@ fn debug_variant_with_attr( fn debug_field_with_mode(name: &Ident, mode: &FieldMode) -> TokenStream { match mode { - FieldMode::Single | FieldMode::Optional => { + FieldMode::Single => { quote_spanned! { name.span() => write!(fmt, "{}", sep)?; write!(fmt, "{:?}", #name)?; @@ -250,6 +250,16 @@ fn debug_field_with_mode(name: &Ident, mode: &FieldMode) -> TokenStream { } } + FieldMode::Optional => { + quote_spanned! { name.span() => + if !::formality_core::util::is_default(#name) { + write!(fmt, "{}", sep)?; + write!(fmt, "{:?}", #name)?; + sep = " "; + } + } + } + FieldMode::Many => { quote_spanned! { name.span() => for e in #name { diff --git a/crates/formality-prove/src/db.rs b/crates/formality-prove/src/db.rs index e69de29b..8b137891 100644 --- a/crates/formality-prove/src/db.rs +++ b/crates/formality-prove/src/db.rs @@ -0,0 +1 @@ + diff --git a/crates/formality-prove/src/decls.rs b/crates/formality-prove/src/decls.rs index 38307b36..6f725c01 100644 --- a/crates/formality-prove/src/decls.rs +++ b/crates/formality-prove/src/decls.rs @@ -1,3 +1,5 @@ +use std::fmt; + use formality_core::{set, Set, Upcast}; use formality_macros::term; use formality_types::grammar::{ @@ -104,8 +106,10 @@ impl Decls { /// An "impl decl" indicates that a trait is implemented for a given set of types. /// One "impl decl" is created for each impl in the Rust source. -#[term(impl $binder)] +#[term($?safety impl $binder)] pub struct ImplDecl { + /// The safety this impl declares, which needs to match the implemented trait's safety. + pub safety: Safety, /// The binder covers the generic variables from the impl pub binder: Binder, } @@ -122,8 +126,11 @@ pub struct ImplDeclBoundData { /// A declaration that some trait will *not* be implemented for a type; derived from negative impls /// like `impl !Foo for Bar`. -#[term(impl $binder)] +#[term($?safety impl $binder)] pub struct NegImplDecl { + /// The safety this negative impl declares + pub safety: Safety, + /// Binder comes the generics on the impl pub binder: Binder, } @@ -135,15 +142,37 @@ pub struct NegImplDeclBoundData { pub where_clause: Wcs, } +/// Mark a trait or trait impl as `unsafe`. +#[term] +#[customize(debug)] +#[derive(Default)] +pub enum Safety { + #[default] + Safe, + Unsafe, +} + +impl fmt::Debug for Safety { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Safety::Safe => write!(f, "safe"), + Safety::Unsafe => write!(f, "unsafe"), + } + } +} + /// A "trait declaration" declares a trait that exists, its generics, and its where-clauses. /// It doesn't capture the trait items, which will be transformed into other sorts of rules. /// /// In Rust syntax, it covers the `trait Foo: Bar` part of the declaration, but not what appears in the `{...}`. -#[term(trait $id $binder)] +#[term($?safety trait $id $binder)] pub struct TraitDecl { /// The name of the trait pub id: TraitId, + /// Whether the trait is `unsafe` or not + pub safety: Safety, + /// The binder here captures the generics of the trait; it always begins with a `Self` type. pub binder: Binder, } diff --git a/crates/formality-prove/src/prove/minimize/test.rs b/crates/formality-prove/src/prove/minimize/test.rs index a80688a2..c62e9e8f 100644 --- a/crates/formality-prove/src/prove/minimize/test.rs +++ b/crates/formality-prove/src/prove/minimize/test.rs @@ -23,7 +23,7 @@ fn minimize_a() { let (mut env_min, term_min, m) = minimize(env, term); expect!["(Env { variables: [?ty_0, ?ty_1], coherence_mode: false }, [?ty_0, ?ty_1])"] - .assert_eq(&format!("{:?}", (&env_min, &term_min))); + .assert_eq(&format!("{:?}", (&env_min, &term_min))); let ty0 = term_min[0].as_variable().unwrap(); let ty1 = term_min[1].as_variable().unwrap(); diff --git a/crates/formality-rust/src/grammar.rs b/crates/formality-rust/src/grammar.rs index f36f52f9..69ec1ba6 100644 --- a/crates/formality-rust/src/grammar.rs +++ b/crates/formality-rust/src/grammar.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use formality_core::{term, Upcast}; +use formality_prove::Safety; use formality_types::{ grammar::{ AdtId, AliasTy, AssociatedItemId, Binder, Const, CrateId, Fallible, FieldId, FnId, Lt, @@ -160,8 +161,9 @@ pub struct Variant { pub fields: Vec, } -#[term(trait $id $binder)] +#[term($?safety trait $id $binder)] pub struct Trait { + pub safety: Safety, pub id: TraitId, pub binder: TraitBinder, } @@ -241,8 +243,9 @@ pub struct AssociatedTyBoundData { pub where_clauses: Vec, } -#[term(impl $binder)] +#[term($?safety impl $binder)] pub struct TraitImpl { + pub safety: Safety, pub binder: Binder, } @@ -267,8 +270,9 @@ impl TraitImplBoundData { } } -#[term(impl $binder)] +#[term($?safety impl $binder)] pub struct NegTraitImpl { + pub safety: Safety, pub binder: Binder, } diff --git a/crates/formality-rust/src/prove.rs b/crates/formality-rust/src/prove.rs index f7bc8996..6224e894 100644 --- a/crates/formality-rust/src/prove.rs +++ b/crates/formality-rust/src/prove.rs @@ -80,7 +80,7 @@ impl Crate { self.items .iter() .flat_map(|item| match item { - CrateItem::Trait(Trait { id, binder }) => { + CrateItem::Trait(Trait { id, binder, safety }) => { let ( vars, TraitBoundData { @@ -89,6 +89,7 @@ impl Crate { }, ) = binder.open(); Some(prove::TraitDecl { + safety: safety.clone(), id: id.clone(), binder: Binder::new( vars, @@ -110,7 +111,7 @@ impl Crate { self.items .iter() .flat_map(|item| match item { - CrateItem::TraitImpl(TraitImpl { binder }) => { + CrateItem::TraitImpl(TraitImpl { binder, safety }) => { let ( vars, TraitImplBoundData { @@ -122,6 +123,7 @@ impl Crate { }, ) = binder.open(); Some(prove::ImplDecl { + safety: safety.clone(), binder: Binder::new( vars, prove::ImplDeclBoundData { @@ -140,7 +142,7 @@ impl Crate { self.items .iter() .flat_map(|item| match item { - CrateItem::NegTraitImpl(NegTraitImpl { binder }) => { + CrateItem::NegTraitImpl(NegTraitImpl { binder, safety }) => { let ( vars, NegTraitImplBoundData { @@ -151,6 +153,7 @@ impl Crate { }, ) = binder.open(); Some(prove::NegImplDecl { + safety: safety.clone(), binder: Binder::new( vars, prove::NegImplDeclBoundData { @@ -169,7 +172,7 @@ impl Crate { self.items .iter() .flat_map(|item| match item { - CrateItem::TraitImpl(TraitImpl { binder }) => { + CrateItem::TraitImpl(TraitImpl { binder, safety: _ }) => { let ( impl_vars, TraitImplBoundData { @@ -226,6 +229,7 @@ impl Crate { .iter() .flat_map(|item| match item { CrateItem::Trait(Trait { + safety: _, id: trait_id, binder, }) => { diff --git a/examples/formality-eg/type_system.rs b/examples/formality-eg/type_system.rs index e69de29b..8b137891 100644 --- a/examples/formality-eg/type_system.rs +++ b/examples/formality-eg/type_system.rs @@ -0,0 +1 @@ + diff --git a/rust-toolchain b/rust-toolchain index 0a823a6e..c192ab89 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1,3 +1,3 @@ [toolchain] channel = "nightly-2023-10-08" -components = [ "rustc-dev", "llvm-tools" ] +components = ["rustc-dev", "llvm-tools", "rustfmt"] diff --git a/tests/ui/decl_safety/safe_trait-negative_impl_mismatch.stderr b/tests/ui/decl_safety/safe_trait-negative_impl_mismatch.stderr new file mode 100644 index 00000000..00127498 --- /dev/null +++ b/tests/ui/decl_safety/safe_trait-negative_impl_mismatch.stderr @@ -0,0 +1,4 @@ +Error: check_neg_trait_impl(unsafe impl ! Foo for u32 {}) + +Caused by: + negative impls cannot be unsafe diff --git "a/tests/ui/decl_safety/safe_trait-negative_impl_mismatch.\360\237\224\254" "b/tests/ui/decl_safety/safe_trait-negative_impl_mismatch.\360\237\224\254" new file mode 100644 index 00000000..cad4149a --- /dev/null +++ "b/tests/ui/decl_safety/safe_trait-negative_impl_mismatch.\360\237\224\254" @@ -0,0 +1,6 @@ +[ + crate baguette { + trait Foo {} + unsafe impl !Foo for u32 {} + } +] diff --git "a/tests/ui/decl_safety/safe_trait.\360\237\224\254" "b/tests/ui/decl_safety/safe_trait.\360\237\224\254" new file mode 100644 index 00000000..dd6610a4 --- /dev/null +++ "b/tests/ui/decl_safety/safe_trait.\360\237\224\254" @@ -0,0 +1,7 @@ +//@check-pass +[ + crate baguette { + safe trait Foo {} + safe impl Foo for u32 {} + } +] diff --git a/tests/ui/decl_safety/safe_trait_mismatch.stderr b/tests/ui/decl_safety/safe_trait_mismatch.stderr new file mode 100644 index 00000000..374a2982 --- /dev/null +++ b/tests/ui/decl_safety/safe_trait_mismatch.stderr @@ -0,0 +1,4 @@ +Error: check_trait_impl(unsafe impl Foo for u32 { }) + +Caused by: + implementing the trait `Foo` is not unsafe diff --git "a/tests/ui/decl_safety/safe_trait_mismatch.\360\237\224\254" "b/tests/ui/decl_safety/safe_trait_mismatch.\360\237\224\254" new file mode 100644 index 00000000..7a2f8d41 --- /dev/null +++ "b/tests/ui/decl_safety/safe_trait_mismatch.\360\237\224\254" @@ -0,0 +1,6 @@ +[ + crate baguette { + trait Foo {} + unsafe impl Foo for u32 {} + } +] diff --git "a/tests/ui/decl_safety/unsafe_trait-negative_impl.\360\237\224\254" "b/tests/ui/decl_safety/unsafe_trait-negative_impl.\360\237\224\254" new file mode 100644 index 00000000..a1fdf102 --- /dev/null +++ "b/tests/ui/decl_safety/unsafe_trait-negative_impl.\360\237\224\254" @@ -0,0 +1,7 @@ +//@check-pass +[ + crate baguette { + unsafe trait Foo {} + impl !Foo for u32 {} + } +] diff --git a/tests/ui/decl_safety/unsafe_trait-negative_impl_mismatch.stderr b/tests/ui/decl_safety/unsafe_trait-negative_impl_mismatch.stderr new file mode 100644 index 00000000..00127498 --- /dev/null +++ b/tests/ui/decl_safety/unsafe_trait-negative_impl_mismatch.stderr @@ -0,0 +1,4 @@ +Error: check_neg_trait_impl(unsafe impl ! Foo for u32 {}) + +Caused by: + negative impls cannot be unsafe diff --git "a/tests/ui/decl_safety/unsafe_trait-negative_impl_mismatch.\360\237\224\254" "b/tests/ui/decl_safety/unsafe_trait-negative_impl_mismatch.\360\237\224\254" new file mode 100644 index 00000000..2ba98a14 --- /dev/null +++ "b/tests/ui/decl_safety/unsafe_trait-negative_impl_mismatch.\360\237\224\254" @@ -0,0 +1,6 @@ +[ + crate baguette { + unsafe trait Foo {} + unsafe impl !Foo for u32 {} + } +] diff --git "a/tests/ui/decl_safety/unsafe_trait.\360\237\224\254" "b/tests/ui/decl_safety/unsafe_trait.\360\237\224\254" new file mode 100644 index 00000000..f93b0dac --- /dev/null +++ "b/tests/ui/decl_safety/unsafe_trait.\360\237\224\254" @@ -0,0 +1,7 @@ +//@check-pass +[ + crate baguette { + unsafe trait Foo {} + unsafe impl Foo for u32 {} + } +] diff --git a/tests/ui/decl_safety/unsafe_trait_mismatch.stderr b/tests/ui/decl_safety/unsafe_trait_mismatch.stderr new file mode 100644 index 00000000..662e2398 --- /dev/null +++ b/tests/ui/decl_safety/unsafe_trait_mismatch.stderr @@ -0,0 +1,4 @@ +Error: check_trait_impl(impl Foo for u32 { }) + +Caused by: + the trait `Foo` requires an `unsafe impl` declaration diff --git "a/tests/ui/decl_safety/unsafe_trait_mismatch.\360\237\224\254" "b/tests/ui/decl_safety/unsafe_trait_mismatch.\360\237\224\254" new file mode 100644 index 00000000..48c1b0cd --- /dev/null +++ "b/tests/ui/decl_safety/unsafe_trait_mismatch.\360\237\224\254" @@ -0,0 +1,6 @@ +[ + crate baguette { + unsafe trait Foo {} + impl Foo for u32 {} + } +]