Skip to content

Commit

Permalink
implement unsafe trait support
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshuawuyts committed Jun 29, 2023
1 parent 0c78cd6 commit 7e83206
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 18 deletions.
21 changes: 19 additions & 2 deletions crates/formality-check/src/impls.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::bail;

use fn_error_context::context;
use formality_prove::Env;
use formality_prove::{Env, Safety};
use formality_rust::{
grammar::{
AssociatedTy, AssociatedTyBoundData, AssociatedTyValue, AssociatedTyValueBoundData, Fn,
Expand All @@ -19,7 +19,7 @@ use formality_types::{
impl super::Check<'_> {
#[context("check_trait_impl({v:?})")]
pub(super) fn check_trait_impl(&self, v: &TraitImpl) -> Fallible<()> {
let TraitImpl { binder } = v;
let TraitImpl { binder, safety } = v;

let mut env = Env::default();

Expand All @@ -45,6 +45,8 @@ impl super::Check<'_> {
trait_items,
} = trait_decl.binder.instantiate_with(&trait_ref.parameters)?;

self.check_safety_matches(&trait_decl.safety, safety)?;

for impl_item in &impl_items {
self.check_trait_impl_item(&env, &where_clauses, &trait_items, impl_item)?;
}
Expand All @@ -71,6 +73,21 @@ impl super::Check<'_> {
Ok(())
}

/// Validate `unsafe trait` and `unsafe impl` line up
fn check_safety_matches(&self, trait_decl: &Safety, trait_impl: &Safety) -> Fallible<()> {
match trait_decl {
Safety::Safe => anyhow::ensure!(
matches!(trait_impl, Safety::Safe),
"implementing the trait is not `unsafe`"
),
Safety::Unsafe => anyhow::ensure!(
matches!(trait_impl, Safety::Unsafe),
"the trait requires an `unsafe impl` declaration"
),
}
Ok(())
}

fn check_trait_impl_item(
&self,
env: &Env,
Expand Down
6 changes: 5 additions & 1 deletion crates/formality-check/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ use formality_types::grammar::Fallible;
impl super::Check<'_> {
#[context("check_trait({:?})", t.id)]
pub(super) fn check_trait(&self, t: &Trait) -> Fallible<()> {
let Trait { id: _, binder } = t;
let Trait {
safety: _,
id: _,
binder,
} = t;
let mut env = Env::default();

let TraitBoundData {
Expand Down
15 changes: 12 additions & 3 deletions crates/formality-prove/src/decls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ impl Decls {
}
}

#[term(impl $binder)]
#[term($safety impl $binder)]
pub struct ImplDecl {
pub safety: Safety,
pub binder: Binder<ImplDeclBoundData>,
}

Expand All @@ -106,8 +107,9 @@ pub struct ImplDeclBoundData {
pub where_clause: Wcs,
}

#[term(impl $binder)]
#[term($safety impl $binder)]
pub struct NegImplDecl {
pub safety: Safety,
pub binder: Binder<NegImplDeclBoundData>,
}

Expand All @@ -117,8 +119,15 @@ pub struct NegImplDeclBoundData {
pub where_clause: Wcs,
}

#[term(trait $id $binder)]
#[term]
pub enum Safety {
Safe,
Unsafe,
}

#[term($safety trait $id $binder)]
pub struct TraitDecl {
pub safety: Safety,
pub id: TraitId,
pub binder: Binder<TraitDeclBoundData>,
}
Expand Down
10 changes: 7 additions & 3 deletions crates/formality-rust/src/grammar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use formality_macros::term;
use formality_prove::Safety;
use formality_types::{
cast::Upcast,
grammar::{
Expand Down Expand Up @@ -161,8 +162,9 @@ pub struct Variant {
pub fields: Vec<Field>,
}

#[term(trait $id $binder)]
#[term($safety trait $id $binder)]
pub struct Trait {
pub safety: Safety,
pub id: TraitId,
pub binder: TraitBinder<TraitBoundData>,
}
Expand Down Expand Up @@ -242,8 +244,9 @@ pub struct AssociatedTyBoundData {
pub where_clauses: Vec<WhereClause>,
}

#[term(impl $binder)]
#[term($safety impl $binder)]
pub struct TraitImpl {
pub safety: Safety,
pub binder: Binder<TraitImplBoundData>,
}

Expand All @@ -268,8 +271,9 @@ impl TraitImplBoundData {
}
}

#[term(impl $binder)]
#[term($safety impl $binder)]
pub struct NegTraitImpl {
pub safety: Safety,
pub binder: Binder<NegTraitImplBoundData>,
}

Expand Down
12 changes: 8 additions & 4 deletions crates/formality-rust/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl Crate {
self.items
.iter()
.flat_map(|item| match item {
CrateItem::Trait(Trait { id, binder }) => {
CrateItem::Trait(Trait { id, binder, safety }) => {
let (
vars,
TraitBoundData {
Expand All @@ -89,6 +89,7 @@ impl Crate {
},
) = binder.open();
Some(prove::TraitDecl {
safety: safety.clone(),
id: id.clone(),
binder: Binder::new(
&vars,
Expand All @@ -110,7 +111,7 @@ impl Crate {
self.items
.iter()
.flat_map(|item| match item {
CrateItem::TraitImpl(TraitImpl { binder }) => {
CrateItem::TraitImpl(TraitImpl { binder, safety }) => {
let (
vars,
TraitImplBoundData {
Expand All @@ -122,6 +123,7 @@ impl Crate {
},
) = binder.open();
Some(prove::ImplDecl {
safety: safety.clone(),
binder: Binder::new(
&vars,
prove::ImplDeclBoundData {
Expand All @@ -140,7 +142,7 @@ impl Crate {
self.items
.iter()
.flat_map(|item| match item {
CrateItem::NegTraitImpl(NegTraitImpl { binder }) => {
CrateItem::NegTraitImpl(NegTraitImpl { binder, safety }) => {
let (
vars,
NegTraitImplBoundData {
Expand All @@ -151,6 +153,7 @@ impl Crate {
},
) = binder.open();
Some(prove::NegImplDecl {
safety: safety.clone(),
binder: Binder::new(
&vars,
prove::NegImplDeclBoundData {
Expand All @@ -169,7 +172,7 @@ impl Crate {
self.items
.iter()
.flat_map(|item| match item {
CrateItem::TraitImpl(TraitImpl { binder }) => {
CrateItem::TraitImpl(TraitImpl { binder, safety: _ }) => {
let (
impl_vars,
TraitImplBoundData {
Expand Down Expand Up @@ -225,6 +228,7 @@ impl Crate {
.iter()
.flat_map(|item| match item {
CrateItem::Trait(Trait {
safety: _,
id: trait_id,
binder,
}) => {
Expand Down
13 changes: 8 additions & 5 deletions tests/unsafe-trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ use formality::test_program_ok;
fn unsafe_trait_requires_unsafe_impl_err() {
expect_test::expect![[r#"
Err(
(),
Error {
context: "check_trait_impl(safe impl <> SendTrait < > for (rigid (adt SendStruct)) where [] { })",
source: "the trait requires an `unsafe impl` declaration",
},
)
"#]]
.assert_debug_eq(&test_program_ok(
"[
crate core {
unsafe trait SendTrait<> where [] {}
struct SendStruct<> where [] {}
impl<> SendTrait for SendStruct where [] {}
safe impl<> SendTrait<> for SendStruct<> where [] {}
}
]",
));
Expand All @@ -28,9 +31,9 @@ fn unsafe_trait_requires_unsafe_impl() {
.assert_debug_eq(&test_program_ok(
"[
crate core {
unsafe trait SendTrait<> where [] {}
struct SendStruct<> where [] {}
unsafe impl<> SendTrait for SendStruct where [] {}
unsafe trait CoreTrait<> where [] {}
struct CoreStruct<> where [] {}
unsafe impl<> CoreTrait<> for CoreStruct<> where [] {}
}
]",
));
Expand Down

0 comments on commit 7e83206

Please sign in to comment.