Skip to content

Commit

Permalink
Deriving Group trait
Browse files Browse the repository at this point in the history
  • Loading branch information
nanoqsh committed Jan 5, 2024
1 parent af140aa commit 2a01a9d
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 104 deletions.
56 changes: 27 additions & 29 deletions dunge/src/bind.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use {
crate::{
group::{BoundTexture, Group, Visitor},
group::{BoundTexture, Group},
shader::Shader,
state::State,
texture::Sampler,
Expand All @@ -11,44 +11,42 @@ use {
},
};

#[derive(Default)]
pub struct VisitGroup<'g>(Vec<BindGroupEntry<'g>>);

impl<'g> VisitGroup<'g> {
fn visit_texture(&mut self, texture: BoundTexture<'g>) {
self.push_resource(BindingResource::TextureView(texture.get().view()));
}
pub trait Visit: Group {
fn visit<'a>(&'a self, visitor: &mut Visitor<'a>);
}

fn visit_sampler(&mut self, sampler: &'g Sampler) {
self.push_resource(BindingResource::Sampler(sampler.inner()));
}
pub struct Visitor<'a>(Vec<BindGroupEntry<'a>>);

fn push_resource(&mut self, resource: BindingResource<'g>) {
impl<'a> Visitor<'a> {
fn push(&mut self, resource: BindingResource<'a>) {
let binding = self.0.len() as u32;
self.0.push(BindGroupEntry { binding, resource });
}
}

impl<'g> Visitor for VisitGroup<'g> {
type Texture = BoundTexture<'g>;
type Sampler = &'g Sampler;
pub trait VisitMember<'a> {
fn visit_member(self, visitor: &mut Visitor<'a>);
}

fn visit_texture(&mut self, texture: Self::Texture) {
self.visit_texture(texture);
impl<'a> VisitMember<'a> for BoundTexture<'a> {
fn visit_member(self, visitor: &mut Visitor<'a>) {
visitor.push(BindingResource::TextureView(self.get().view()));
}
}

fn visit_sampler(&mut self, sampler: Self::Sampler) {
self.visit_sampler(sampler);
impl<'a> VisitMember<'a> for &'a Sampler {
fn visit_member(self, visitor: &mut Visitor<'a>) {
visitor.push(BindingResource::Sampler(self.inner()));
}
}

fn visit<'g, G>(group: &'g G) -> Vec<BindGroupEntry<'g>>
fn visit<G>(group: &G) -> Vec<BindGroupEntry>
where
G: Group<Visitor<'g> = VisitGroup<'g>>,
G: Visit,
{
let mut visit = VisitGroup::default();
group.group(&mut visit);
visit.0
let mut visitor = Visitor(vec![]);
group.visit(&mut visitor);
visitor.0
}

pub struct GroupHandler<G> {
Expand Down Expand Up @@ -101,14 +99,14 @@ impl Binding for GroupBinding {

pub type Update = Result<(), ForeignShader>;

pub(crate) fn update<'g, G>(
pub(crate) fn update<G>(
state: &State,
uni: &mut UniqueGroupBinding,
handler: GroupHandler<G>,
group: &'g G,
group: &G,
) -> Update
where
G: Group<Visitor<'g> = VisitGroup<'g>>,
G: Visit,
{
if handler.shader_id != uni.0.shader_id {
return Err(ForeignShader);
Expand Down Expand Up @@ -181,9 +179,9 @@ impl<'a> Binder<'a> {
}
}

pub fn bind<'g, G>(&mut self, group: &'g G) -> GroupHandler<G>
pub fn bind<G>(&mut self, group: &G) -> GroupHandler<G>
where
G: Group<Visitor<'g> = VisitGroup<'g>>,
G: Visit,
{
let id = self.groups.len();
let Some(layout) = self.layout.get(id) else {
Expand Down
9 changes: 4 additions & 5 deletions dunge/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use {
crate::{
bind::{self, Binder, GroupHandler, UniqueGroupBinding, Update, VisitGroup},
bind::{self, Binder, GroupHandler, UniqueGroupBinding, Update, Visit},
draw::Draw,
group::Group,
layer::Layer,
mesh::{self, Mesh},
shader::Shader,
Expand Down Expand Up @@ -88,14 +87,14 @@ impl Context {
self.0.draw(render, view, draw)
}

pub fn update_group<'g, G>(
pub fn update_group<G>(
&self,
uni: &mut UniqueGroupBinding,
handler: GroupHandler<G>,
group: &'g G,
group: &G,
) -> Update
where
G: Group<Visitor<'g> = VisitGroup<'g>>,
G: Visit,
{
bind::update(&self.0, uni, handler, group)
}
Expand Down
2 changes: 1 addition & 1 deletion dunge/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod time;
pub mod vertex;

pub use {
dunge_macros::Vertex,
dunge_macros::{Group, Vertex},
dunge_shader::{group::Group, sl, types, vertex::Vertex},
glam,
};
49 changes: 7 additions & 42 deletions dunge/tests/triangle_group.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use {
dunge::{
bind::VisitGroup,
color::Rgba,
context::Context,
draw,
group::{BoundTexture, DeclareGroup, Group, MemberProjection, Projection, Visitor},
group::BoundTexture,
mesh,
sl::{self, GlobalOut, Groups, Input, Out},
sl::{self, Groups, Input, Out},
state::{Options, Render},
texture::{self, Format, Sampler},
Vertex,
Group, Vertex,
},
glam::Vec2,
helpers::Image,
Expand All @@ -29,44 +28,10 @@ fn render() -> Result<(), Error> {
tex: Vec2,
}

struct Map<'g> {
tex: BoundTexture<'g>,
sam: &'g Sampler,
}

impl Group for Map<'_> {
type Projection = MapProjection;
type Visitor<'g> = VisitGroup<'g>
where
Self: 'g;

const DECL: DeclareGroup = DeclareGroup::new(&[
<BoundTexture<'static> as MemberProjection>::TYPE,
<&'static Sampler as MemberProjection>::TYPE,
]);

fn group<'g>(&'g self, visit: &mut Self::Visitor<'g>) {
visit.visit_texture(self.tex);
visit.visit_sampler(self.sam);
}
}

struct MapProjection {
tex: <BoundTexture<'static> as MemberProjection>::Field,
sam: <&'static Sampler as MemberProjection>::Field,
}

impl Projection for MapProjection {
fn projection(id: u32, out: GlobalOut) -> Self {
Self {
tex: <BoundTexture<'static> as MemberProjection>::member_projection(
id,
0,
out.clone(),
),
sam: <&'static Sampler as MemberProjection>::member_projection(id, 1, out.clone()),
}
}
#[derive(Group)]
struct Map<'a> {
tex: BoundTexture<'a>,
sam: &'a Sampler,
}

let triangle = |vert: Input<Vert>, groups: Groups<Map>| {
Expand Down
161 changes: 161 additions & 0 deletions dunge_macros/src/group.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use {
crate::utils,
proc_macro2::{Span, TokenStream},
syn::{spanned::Spanned, Data, DataStruct, DeriveInput, GenericParam, Ident, Lifetime},
};

pub(crate) fn derive(input: DeriveInput) -> TokenStream {
use std::iter;

let Data::Struct(DataStruct { fields, .. }) = input.data else {
return quote::quote_spanned! { input.ident.span() =>
::std::compile_error!("the group type must be a struct");
};
};

let mut lts = Vec::with_capacity(input.generics.params.len());
for param in input.generics.params {
let GenericParam::Lifetime(param) = param else {
return quote::quote_spanned! { param.span() =>
::std::compile_error!("the group struct cannot have non-lifetime generic parameters");
};
};

if !param.attrs.is_empty() {
return quote::quote_spanned! { param.span() =>
::std::compile_error!("the lifetime cannot have any attributes");
};
}

if !param.bounds.is_empty() {
return quote::quote_spanned! { param.span() =>
::std::compile_error!("the lifetime cannot have any bounds");
};
}

lts.push(param.lifetime)
}

if fields.is_empty() {
return quote::quote_spanned! { fields.span() =>
::std::compile_error!("the group struct must have some fields");
};
}

let static_lt = Lifetime {
apostrophe: Span::call_site(),
ident: Ident::new("static", Span::call_site()),
};

let static_lts = lts.iter().map(|_| &static_lt);

let anon_lt = Lifetime {
apostrophe: Span::call_site(),
ident: Ident::new("_", Span::call_site()),
};

let anon_lts = lts
.iter()
.map(|lt| if lt.ident == "static" { lt } else { &anon_lt });

let name = input.ident;
let projection_name = quote::format_ident!("{name}Projection");
let group_types = fields.iter().map(|field| {
let ty = &field.ty;
quote::quote! { <#ty as ::dunge::group::MemberProjection>::TYPE }
});

let group_visit_members = iter::zip(0.., &fields).map(|(index, field)| {
let ident = utils::make_ident(index, field.ident.as_ref());
quote::quote! { ::dunge::bind::VisitMember::visit_member(self.#ident, visitor) }
});

let group_fields = iter::zip(0.., &fields).map(|(index, field)| {
let ident = utils::make_ident(index, field.ident.as_ref());
let ty = &field.ty;
quote::quote! { #ident: <#ty as ::dunge::group::MemberProjection>::Field }
});

let group_member_projections = iter::zip(0.., &fields).map(|(index, field)| {
let ident = utils::make_ident(index, field.ident.as_ref());
let ty = &field.ty;
quote::quote! { #ident: <#ty as ::dunge::group::MemberProjection>::member_projection(id, #index, out.clone()) }
});

quote::quote! {
impl<#(#lts),*> ::dunge::group::Group for #name<#(#lts),*> {
type Projection = #projection_name<#(#static_lts),*>;
const DECL: ::dunge::group::DeclareGroup = ::dunge::group::DeclareGroup::new(&[
#(#group_types),*,
]);
}

impl ::dunge::bind::Visit for Map<#(#anon_lts),*> {
fn visit<'a>(&'a self, visitor: &mut ::dunge::bind::Visitor<'a>) {
#(#group_visit_members);*;
}
}

struct #projection_name<#(#lts),*> {
#(#group_fields),*,
}

impl<#(#lts),*> ::dunge::group::Projection for #projection_name<#(#lts),*> {
fn projection(id: ::core::primitive::u32, out: ::dunge::sl::GlobalOut) -> Self {
Self {
#(#group_member_projections),*,
}
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn derive_group() {
let input = quote::quote! {
struct Map<'a> {
tex: BoundTexture<'a>,
sam: &'a Sampler,
}
};

let input = syn::parse2(input).expect("parse input");
let actual = derive(input);
let expected = quote::quote! {
impl<'a> ::dunge::group::Group for Map<'a> {
type Projection = MapProjection<'static>;
const DECL: ::dunge::group::DeclareGroup = ::dunge::group::DeclareGroup::new(&[
<BoundTexture<'a> as ::dunge::group::MemberProjection>::TYPE,
<&'a Sampler as ::dunge::group::MemberProjection>::TYPE,
]);
}

impl ::dunge::bind::Visit for Map<'_> {
fn visit<'a>(&'a self, visitor: &mut ::dunge::bind::Visitor<'a>) {
::dunge::bind::VisitMember::visit_member(self.tex, visitor);
::dunge::bind::VisitMember::visit_member(self.sam, visitor);
}
}

struct MapProjection<'a> {
tex: <BoundTexture<'a> as ::dunge::group::MemberProjection>::Field,
sam: <&'a Sampler as ::dunge::group::MemberProjection>::Field,
}

impl<'a> ::dunge::group::Projection for MapProjection<'a> {
fn projection(id: ::core::primitive::u32, out: ::dunge::sl::GlobalOut) -> Self {
Self {
tex: <BoundTexture<'a> as ::dunge::group::MemberProjection>::member_projection(id, 0u32, out.clone()),
sam: <&'a Sampler as ::dunge::group::MemberProjection>::member_projection(id, 1u32, out.clone()),
}
}
}
};

assert_eq!(actual.to_string(), expected.to_string());
}
}
11 changes: 10 additions & 1 deletion dunge_macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
mod group;
mod utils;
mod vertex;

use proc_macro::TokenStream;
Expand All @@ -6,5 +8,12 @@ use proc_macro::TokenStream;
#[proc_macro_derive(Vertex)]
pub fn derive_vertex(input: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(input);
vertex::impl_vertex(input).into()
vertex::derive(input).into()
}

/// Derive implementation for the group type.
#[proc_macro_derive(Group)]
pub fn derive_group(input: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(input);
group::derive(input).into()
}
8 changes: 8 additions & 0 deletions dunge_macros/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use {proc_macro2::Ident, std::borrow::Cow};

pub(crate) fn make_ident(index: u32, ident: Option<&Ident>) -> Cow<Ident> {
match ident {
Some(ident) => Cow::Borrowed(ident),
None => Cow::Owned(quote::format_ident!("{index}")),
}
}
Loading

0 comments on commit 2a01a9d

Please sign in to comment.