diff --git a/crates/formality-check/src/impls.rs b/crates/formality-check/src/impls.rs index 1fb4ae7d..fe18f224 100644 --- a/crates/formality-check/src/impls.rs +++ b/crates/formality-check/src/impls.rs @@ -1,7 +1,7 @@ use anyhow::bail; use fn_error_context::context; -use formality_prove::Env; +use formality_prove::{Env, Safety}; use formality_rust::{ grammar::{ AssociatedTy, AssociatedTyBoundData, AssociatedTyValue, AssociatedTyValueBoundData, Fn, @@ -19,7 +19,7 @@ use formality_types::{ impl super::Check<'_> { #[context("check_trait_impl({v:?})")] pub(super) fn check_trait_impl(&self, v: &TraitImpl) -> Fallible<()> { - let TraitImpl { binder } = v; + let TraitImpl { binder, safety } = v; let mut env = Env::default(); @@ -45,6 +45,8 @@ impl super::Check<'_> { trait_items, } = trait_decl.binder.instantiate_with(&trait_ref.parameters)?; + self.check_safety_matches(&trait_decl.safety, safety)?; + for impl_item in &impl_items { self.check_trait_impl_item(&env, &where_clauses, &trait_items, impl_item)?; } @@ -71,6 +73,21 @@ impl super::Check<'_> { Ok(()) } + /// Validate `unsafe trait` and `unsafe impl` line up + fn check_safety_matches(&self, trait_decl: &Safety, trait_impl: &Safety) -> Fallible<()> { + match trait_decl { + Safety::Safe => anyhow::ensure!( + matches!(trait_impl, Safety::Safe), + "implementing the trait is not `unsafe`" + ), + Safety::Unsafe => anyhow::ensure!( + matches!(trait_impl, Safety::Unsafe), + "the trait requires an `unsafe impl` declaration" + ), + } + Ok(()) + } + fn check_trait_impl_item( &self, env: &Env, diff --git a/crates/formality-check/src/traits.rs b/crates/formality-check/src/traits.rs index 1347ce92..22791031 100644 --- a/crates/formality-check/src/traits.rs +++ b/crates/formality-check/src/traits.rs @@ -8,7 +8,11 @@ use formality_types::grammar::Fallible; impl super::Check<'_> { #[context("check_trait({:?})", t.id)] pub(super) fn check_trait(&self, t: &Trait) -> Fallible<()> { - let Trait { id: _, binder } = t; + let Trait { + safety: _, + id: _, + binder, + } = t; let mut env = Env::default(); let TraitBoundData { diff --git a/crates/formality-prove/src/decls.rs b/crates/formality-prove/src/decls.rs index 4b7b170c..1fafa26b 100644 --- a/crates/formality-prove/src/decls.rs +++ b/crates/formality-prove/src/decls.rs @@ -2,11 +2,16 @@ use formality_macros::term; use formality_types::{ cast::Upcast, collections::Set, + derive_links::{DowncastTo, UpcastFrom}, + fold::Fold, grammar::{ AdtId, AliasName, AliasTy, Binder, Parameter, Predicate, Relation, TraitId, TraitRef, Ty, Wc, Wcs, PR, }, + parse::{self, Parse}, set, + term::Term, + visit::Visit, }; #[term] @@ -95,8 +100,9 @@ impl Decls { } } -#[term(impl $binder)] +#[term($safety impl $binder)] pub struct ImplDecl { + pub safety: Safety, pub binder: Binder, } @@ -106,8 +112,9 @@ pub struct ImplDeclBoundData { pub where_clause: Wcs, } -#[term(impl $binder)] +#[term($safety impl $binder)] pub struct NegImplDecl { + pub safety: Safety, pub binder: Binder, } @@ -117,8 +124,71 @@ pub struct NegImplDeclBoundData { pub where_clause: Wcs, } -#[term(trait $id $binder)] +/// Mark a trait or trait impl as `unsafe`. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Safety { + Safe, + Unsafe, +} + +// NOTE(yosh): `Debug` is currently used to print error messages with. In order +// to not print `safe impl` / `safe trait` where none is written, we leave the impl blank. +impl std::fmt::Debug for Safety { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Safe => write!(f, ""), + Self::Unsafe => write!(f, "unsafe"), + } + } +} + +impl Term for Safety {} + +impl DowncastTo for Safety { + fn downcast_to(&self) -> Option { + Some(Self::clone(self)) + } +} + +impl UpcastFrom for Safety { + fn upcast_from(term: Self) -> Self { + term + } +} + +impl Fold for Safety { + fn substitute(&self, _substitution_fn: formality_types::fold::SubstitutionFn<'_>) -> Self { + self.clone() + } +} + +impl Visit for Safety { + fn free_variables(&self) -> Vec { + vec![] + } + + fn size(&self) -> usize { + 1 + } + + fn assert_valid(&self) {} +} + +impl Parse for Safety { + fn parse<'t>( + _scope: &formality_types::parse::Scope, + text0: &'t str, + ) -> formality_types::parse::ParseResult<'t, Self> { + match parse::expect_optional_keyword("unsafe", text0) { + Some(text1) => Ok((Self::Unsafe, text1)), + None => Ok((Self::Safe, text0)), + } + } +} + +#[term($safety trait $id $binder)] pub struct TraitDecl { + pub safety: Safety, pub id: TraitId, pub binder: Binder, } diff --git a/crates/formality-rust/src/grammar.rs b/crates/formality-rust/src/grammar.rs index a77d8ab9..6cb2cb0c 100644 --- a/crates/formality-rust/src/grammar.rs +++ b/crates/formality-rust/src/grammar.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use formality_macros::term; +use formality_prove::Safety; use formality_types::{ cast::Upcast, grammar::{ @@ -161,8 +162,9 @@ pub struct Variant { pub fields: Vec, } -#[term(trait $id $binder)] +#[term($safety trait $id $binder)] pub struct Trait { + pub safety: Safety, pub id: TraitId, pub binder: TraitBinder, } @@ -242,8 +244,9 @@ pub struct AssociatedTyBoundData { pub where_clauses: Vec, } -#[term(impl $binder)] +#[term($safety impl $binder)] pub struct TraitImpl { + pub safety: Safety, pub binder: Binder, } @@ -268,8 +271,9 @@ impl TraitImplBoundData { } } -#[term(impl $binder)] +#[term($safety impl $binder)] pub struct NegTraitImpl { + pub safety: Safety, pub binder: Binder, } diff --git a/crates/formality-rust/src/prove.rs b/crates/formality-rust/src/prove.rs index 7264326f..ff2ddf8e 100644 --- a/crates/formality-rust/src/prove.rs +++ b/crates/formality-rust/src/prove.rs @@ -80,7 +80,7 @@ impl Crate { self.items .iter() .flat_map(|item| match item { - CrateItem::Trait(Trait { id, binder }) => { + CrateItem::Trait(Trait { id, binder, safety }) => { let ( vars, TraitBoundData { @@ -89,6 +89,7 @@ impl Crate { }, ) = binder.open(); Some(prove::TraitDecl { + safety: safety.clone(), id: id.clone(), binder: Binder::new( &vars, @@ -110,7 +111,7 @@ impl Crate { self.items .iter() .flat_map(|item| match item { - CrateItem::TraitImpl(TraitImpl { binder }) => { + CrateItem::TraitImpl(TraitImpl { binder, safety }) => { let ( vars, TraitImplBoundData { @@ -122,6 +123,7 @@ impl Crate { }, ) = binder.open(); Some(prove::ImplDecl { + safety: safety.clone(), binder: Binder::new( &vars, prove::ImplDeclBoundData { @@ -140,7 +142,7 @@ impl Crate { self.items .iter() .flat_map(|item| match item { - CrateItem::NegTraitImpl(NegTraitImpl { binder }) => { + CrateItem::NegTraitImpl(NegTraitImpl { binder, safety }) => { let ( vars, NegTraitImplBoundData { @@ -151,6 +153,7 @@ impl Crate { }, ) = binder.open(); Some(prove::NegImplDecl { + safety: safety.clone(), binder: Binder::new( &vars, prove::NegImplDeclBoundData { @@ -169,7 +172,7 @@ impl Crate { self.items .iter() .flat_map(|item| match item { - CrateItem::TraitImpl(TraitImpl { binder }) => { + CrateItem::TraitImpl(TraitImpl { binder, safety: _ }) => { let ( impl_vars, TraitImplBoundData { @@ -225,6 +228,7 @@ impl Crate { .iter() .flat_map(|item| match item { CrateItem::Trait(Trait { + safety: _, id: trait_id, binder, }) => { diff --git a/crates/formality-types/src/parse.rs b/crates/formality-types/src/parse.rs index ce0c58ac..ab43587e 100644 --- a/crates/formality-types/src/parse.rs +++ b/crates/formality-types/src/parse.rs @@ -347,6 +347,15 @@ pub fn expect_keyword<'t>(expected: &str, text0: &'t str) -> ParseResult<'t, ()> } } +/// Attempt to consume next identifier if it is equal to `expected`. +#[tracing::instrument(level = "trace", ret)] +pub fn expect_optional_keyword<'t>(expected: &str, text0: &'t str) -> Option<&'t str> { + match identifier(text0) { + Ok((ident, text1)) if &*ident == expected => Some(text1), + _ => None, + } +} + /// Reject next identifier if it is the given keyword. Consumes nothing. #[tracing::instrument(level = "trace", ret)] pub fn reject_keyword<'t>(expected: &str, text0: &'t str) -> ParseResult<'t, ()> { diff --git a/tests/coherence_orphan.rs b/tests/coherence_orphan.rs index 40f7df5a..1af9efaf 100644 --- a/tests/coherence_orphan.rs +++ b/tests/coherence_orphan.rs @@ -8,7 +8,7 @@ fn test_orphan_CoreTrait_for_CoreStruct_in_Foo() { expect_test::expect![[r#" Err( Error { - context: "orphan_check(impl <> CoreTrait < > for (rigid (adt CoreStruct)) where [] { })", + context: "orphan_check( impl <> CoreTrait < > for (rigid (adt CoreStruct)) where [] { })", source: "failed to prove {@ IsLocal(CoreTrait((rigid (adt CoreStruct))))} given {}, got {}", }, ) @@ -31,7 +31,7 @@ fn test_orphan_neg_CoreTrait_for_CoreStruct_in_Foo() { expect_test::expect![[r#" Err( Error { - context: "orphan_check_neg(impl <> ! CoreTrait < > for (rigid (adt CoreStruct)) where [] {})", + context: "orphan_check_neg( impl <> ! CoreTrait < > for (rigid (adt CoreStruct)) where [] {})", source: "failed to prove {@ IsLocal(CoreTrait((rigid (adt CoreStruct))))} given {}, got {}", }, ) @@ -54,7 +54,7 @@ fn test_orphan_mirror_CoreStruct() { expect_test::expect![[r#" Err( Error { - context: "orphan_check(impl <> CoreTrait < > for (alias (Mirror :: Assoc) (rigid (adt CoreStruct))) where [] { })", + context: "orphan_check( impl <> CoreTrait < > for (alias (Mirror :: Assoc) (rigid (adt CoreStruct))) where [] { })", source: "failed to prove {@ IsLocal(CoreTrait((alias (Mirror :: Assoc) (rigid (adt CoreStruct)))))} given {}, got {}", }, ) @@ -119,7 +119,7 @@ fn test_orphan_alias_to_unit() { expect_test::expect![[r#" Err( Error { - context: "orphan_check(impl <> CoreTrait < > for (alias (Unit :: Assoc) (rigid (adt FooStruct))) where [] { })", + context: "orphan_check( impl <> CoreTrait < > for (alias (Unit :: Assoc) (rigid (adt FooStruct))) where [] { })", source: "failed to prove {@ IsLocal(CoreTrait((alias (Unit :: Assoc) (rigid (adt FooStruct)))))} given {}, got {}", }, ) @@ -150,7 +150,7 @@ fn test_orphan_uncovered_T() { expect_test::expect![[r#" Err( Error { - context: "orphan_check(impl CoreTrait < (rigid (adt FooStruct)) > for ^ty0_0 where [] { })", + context: "orphan_check( impl CoreTrait < (rigid (adt FooStruct)) > for ^ty0_0 where [] { })", source: "failed to prove {@ IsLocal(CoreTrait(!ty_1, (rigid (adt FooStruct))))} given {}, got {}", }, ) diff --git a/tests/coherence_overlap.rs b/tests/coherence_overlap.rs index e3cca598..90cfc703 100644 --- a/tests/coherence_overlap.rs +++ b/tests/coherence_overlap.rs @@ -27,7 +27,7 @@ fn test_u32_u32_impls() { // Test that we detect duplicate impls. expect_test::expect![[r#" Err( - "duplicate impl in current crate: impl <> Foo < > for (rigid (scalar u32)) where [] { }", + "duplicate impl in current crate: impl <> Foo < > for (rigid (scalar u32)) where [] { }", ) "#]] .assert_debug_eq(&test_program_ok( @@ -46,7 +46,7 @@ fn test_u32_T_impls() { // Test that we detect overlap involving generic parameters. expect_test::expect![[r#" Err( - "impls may overlap: `impl <> Foo < > for (rigid (scalar u32)) where [] { }` vs `impl Foo < > for ^ty0_0 where [] { }`", + "impls may overlap: ` impl <> Foo < > for (rigid (scalar u32)) where [] { }` vs ` impl Foo < > for ^ty0_0 where [] { }`", ) "#]] .assert_debug_eq(&test_program_ok( @@ -90,7 +90,7 @@ fn test_u32_T_where_T_Is_impls() { // and also all `T: Is`, and `u32: Is`. expect_test::expect![[r#" Err( - "impls may overlap: `impl <> Foo < > for (rigid (scalar u32)) where [] { }` vs `impl Foo < > for ^ty0_0 where [^ty0_0 : Is < >] { }`", + "impls may overlap: ` impl <> Foo < > for (rigid (scalar u32)) where [] { }` vs ` impl Foo < > for ^ty0_0 where [^ty0_0 : Is < >] { }`", ) "#]] .assert_debug_eq(&test_program_ok( @@ -113,7 +113,7 @@ fn test_u32_not_u32_impls() { expect_test::expect![[r#" Err( Error { - context: "check_trait_impl(impl <> Foo < > for (rigid (scalar u32)) where [] { })", + context: "check_trait_impl( impl <> Foo < > for (rigid (scalar u32)) where [] { })", source: "failed to disprove {! Foo((rigid (scalar u32)))} given {}, got {Constraints { env: Env { variables: [], coherence_mode: false }, known_true: true, substitution: {} }}", }, ) @@ -139,7 +139,7 @@ fn test_T_where_Foo_not_u32_impls() { expect_test::expect![[r#" Err( Error { - context: "check_trait_impl(impl Foo < > for ^ty0_0 where [^ty0_0 : Foo < >] { })", + context: "check_trait_impl( impl Foo < > for ^ty0_0 where [^ty0_0 : Foo < >] { })", source: "failed to disprove {! Foo(!ty_1)} given {Foo(!ty_1)}, got {Constraints { env: Env { variables: [?ty_1], coherence_mode: false }, known_true: true, substitution: {?ty_1 => (rigid (scalar u32))} }}", }, ) @@ -159,7 +159,7 @@ fn test_T_where_Foo_not_u32_impls() { fn test_foo_crate_cannot_assume_CoreStruct_does_not_impl_CoreTrait() { expect_test::expect![[r#" Err( - "impls may overlap: `impl FooTrait < > for ^ty0_0 where [^ty0_0 : CoreTrait < >] { }` vs `impl <> FooTrait < > for (rigid (adt CoreStruct)) where [] { }`", + "impls may overlap: ` impl FooTrait < > for ^ty0_0 where [^ty0_0 : CoreTrait < >] { }` vs ` impl <> FooTrait < > for (rigid (adt CoreStruct)) where [] { }`", ) "#]] .assert_debug_eq(&test_program_ok( @@ -253,7 +253,7 @@ fn test_overlap_normalize_alias_to_LocalType() { expect_test::expect![[r#" Err( - "impls may overlap: `impl LocalTrait < > for ^ty0_0 where [^ty0_0 : Iterator < >] { }` vs `impl <> LocalTrait < > for (alias (Mirror :: T) (rigid (adt LocalType))) where [] { }`", + "impls may overlap: ` impl LocalTrait < > for ^ty0_0 where [^ty0_0 : Iterator < >] { }` vs ` impl <> LocalTrait < > for (alias (Mirror :: T) (rigid (adt LocalType))) where [] { }`", ) "#]] .assert_debug_eq(&test_program_ok(&gen_program("impl<> Iterator<> for LocalType<> where [] {}"))); @@ -312,7 +312,7 @@ fn test_overlap_alias_not_normalizable() { expect_test::expect![[r#" Err( - "impls may overlap: `impl LocalTrait < > for ^ty0_0 where [^ty0_0 : Iterator < >] { }` vs `impl LocalTrait < > for (alias (Mirror :: T) ^ty0_0) where [^ty0_0 : Mirror < >] { }`", + "impls may overlap: ` impl LocalTrait < > for ^ty0_0 where [^ty0_0 : Iterator < >] { }` vs ` impl LocalTrait < > for (alias (Mirror :: T) ^ty0_0) where [^ty0_0 : Mirror < >] { }`", ) "#]] // FIXME .assert_debug_eq(&test_program_ok(&gen_program( diff --git a/tests/unsafe-trait.rs b/tests/unsafe-trait.rs new file mode 100644 index 00000000..fe335eb0 --- /dev/null +++ b/tests/unsafe-trait.rs @@ -0,0 +1,40 @@ +use formality::test_program_ok; + +#[test] +fn unsafe_trait_requires_unsafe_impl_err() { + expect_test::expect![[r#" + Err( + Error { + context: "check_trait_impl( impl <> SendTrait < > for (rigid (adt SendStruct)) where [] { })", + source: "the trait requires an `unsafe impl` declaration", + }, + ) + "#]] + .assert_debug_eq(&test_program_ok( + "[ + crate core { + unsafe trait SendTrait<> where [] {} + struct SendStruct<> where [] {} + impl<> SendTrait<> for SendStruct<> where [] {} + } + ]", + )); +} + +#[test] +fn unsafe_trait_requires_unsafe_impl() { + expect_test::expect![[r#" + Ok( + (), + ) + "#]] + .assert_debug_eq(&test_program_ok( + "[ + crate core { + unsafe trait CoreTrait<> where [] {} + struct CoreStruct<> where [] {} + unsafe impl<> CoreTrait<> for CoreStruct<> where [] {} + } + ]", + )); +}