Skip to content

Commit

Permalink
Update branch
Browse files Browse the repository at this point in the history
  • Loading branch information
nanoqsh committed Feb 16, 2024
1 parent a7313ae commit fe14ef9
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 105 deletions.
47 changes: 38 additions & 9 deletions dunge/tests/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,46 @@ fn shader_if() -> Result<(), Error> {
fn shader_branch() -> Result<(), Error> {
use dunge::sl::{self, Out};

let compute = || Out {
place: sl::default(|| sl::splat_vec4(3.))
.when(true, || sl::splat_vec4(1.))
.when(false, || sl::splat_vec4(2.)),
color: sl::splat_vec4(1.),
let cx = helpers::block_on(dunge::context())?;
let shader0 = {
let compute = || Out {
place: sl::default(|| sl::splat_vec4(1.)).when(false, || sl::splat_vec4(2.)),
color: sl::splat_vec4(1.),
};

cx.make_shader(compute)
};

let cx = helpers::block_on(dunge::context())?;
let shader = cx.make_shader(compute);
// helpers::eq_lines(shader.debug_wgsl(), include_str!("shader_branch.wgsl"));
_ = std::fs::write("tests/shader_branch.wgsl", shader.debug_wgsl());
let shader1 = {
let compute = || Out {
place: sl::default(|| sl::splat_vec4(1.))
.when(true, || sl::splat_vec4(2.))
.when(false, || sl::splat_vec4(3.)),
color: sl::splat_vec4(1.),
};

cx.make_shader(compute)
};

let shader2 = {
let compute = || {
let p = sl::default(|| sl::splat_vec4(1.))
.when(true, || sl::splat_vec4(2.))
.when(true, || sl::splat_vec4(3.))
.when(false, || sl::splat_vec4(4.));

Out {
place: p,
color: sl::splat_vec4(1.),
}
};

cx.make_shader(compute)
};

helpers::eq_lines(shader0.debug_wgsl(), include_str!("shader_branch0.wgsl"));
helpers::eq_lines(shader1.debug_wgsl(), include_str!("shader_branch1.wgsl"));
helpers::eq_lines(shader2.debug_wgsl(), include_str!("shader_branch2.wgsl"));
Ok(())
}

Expand Down
21 changes: 21 additions & 0 deletions dunge/tests/shader_branch0.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
struct VertexOutput {
@builtin(position) member: vec4<f32>,
}

@vertex
fn vs() -> VertexOutput {
var local: vec4<f32>;

if false {
local = vec4<f32>(2f, 2f, 2f, 2f);
} else {
local = vec4<f32>(1f, 1f, 1f, 1f);
}
let _e6: vec4<f32> = local;
return VertexOutput(_e6);
}

@fragment
fn fs(param: VertexOutput) -> @location(0) vec4<f32> {
return vec4<f32>(1f, 1f, 1f, 1f);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ struct VertexOutput {
fn vs() -> VertexOutput {
var local: vec4<f32>;

if true {
local = vec4<f32>(1f, 1f, 1f, 1f);
if false {
local = vec4<f32>(3f, 3f, 3f, 3f);
} else {
if false {
if true {
local = vec4<f32>(2f, 2f, 2f, 2f);
} else {
local = vec4<f32>(3f, 3f, 3f, 3f);
local = vec4<f32>(1f, 1f, 1f, 1f);
}
}
let _e9: vec4<f32> = local;
Expand Down
29 changes: 29 additions & 0 deletions dunge/tests/shader_branch2.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
struct VertexOutput {
@builtin(position) member: vec4<f32>,
}

@vertex
fn vs() -> VertexOutput {
var local: vec4<f32>;

if false {
local = vec4<f32>(4f, 4f, 4f, 4f);
} else {
if true {
local = vec4<f32>(3f, 3f, 3f, 3f);
} else {
if true {
local = vec4<f32>(2f, 2f, 2f, 2f);
} else {
local = vec4<f32>(1f, 1f, 1f, 1f);
}
}
}
let _e12: vec4<f32> = local;
return VertexOutput(_e12);
}

@fragment
fn fs(param: VertexOutput) -> @location(0) vec4<f32> {
return vec4<f32>(1f, 1f, 1f, 1f);
}
42 changes: 18 additions & 24 deletions dunge_shader/src/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ where
let IfThenElse { c, a, b, .. } = self.get();
let c = c.eval(en);
let a = |en: &mut E| a().eval(en);
let b = |en: &mut E| Some(b().eval(en));
let b = |branch: &mut Branch<_>| Some(b().eval(branch.entry()));
let valty = <X::Out as types::Value>::VALUE_TYPE;
let ty = en.get_entry().new_type(valty.ty());
let branch = Branch::new(en.get_entry(), ty);
branch.add(en, c, a, b);
branch.load(en.get_entry())
let mut branch = Branch::new(en, ty);
branch.add(c, a, b);
branch.load()
}
}

Expand Down Expand Up @@ -94,22 +94,16 @@ pub struct When<C, A, B, E> {

impl<C, A, B, E, O> Ret<When<C, A, B, E>, O> {
#[allow(clippy::type_complexity)]
pub fn when<D, F, Z>(self, cond: D, expr: F) -> Ret<When<C, A, When<D, F, B, E>, E>, O>
pub fn when<D, F, Z>(self, cond: D, expr: F) -> Ret<When<D, F, When<C, A, B, E>, E>, O>
where
D: Eval<E, Out = bool>,
F: FnOnce() -> Z,
Z: Eval<E, Out = O>,
{
let when = self.get();
Ret::new(When {
c: when.c,
a: when.a,
b: When {
c: cond,
a: expr,
b: when.b,
e: PhantomData,
},
c: cond,
a: expr,
b: self.get(),
e: PhantomData,
})
}
Expand All @@ -130,23 +124,23 @@ where
let when = self.get();
let valty = <X::Out as types::Value>::VALUE_TYPE;
let ty = en.get_entry().new_type(valty.ty());
let branch = Branch::new(en.get_entry(), ty);
when.eval_else(en, &branch);
branch.load(en.get_entry())
let mut branch = Branch::new(en, ty);
when.eval_branch(&mut branch);
branch.load()
}
}

pub trait EvalBranch<E> {
fn eval_else(self, en: &mut E, branch: &Branch) -> Option<Expr>;
fn eval_branch(self, branch: &mut Branch<E>) -> Option<Expr>;
}

impl<F, R, E> EvalBranch<E> for F
where
F: FnOnce() -> R,
R: Eval<E>,
{
fn eval_else(self, en: &mut E, _: &Branch) -> Option<Expr> {
Some(self().eval(en))
fn eval_branch(self, branch: &mut Branch<E>) -> Option<Expr> {
Some(self().eval(branch.entry()))
}
}

Expand All @@ -158,12 +152,12 @@ where
B: EvalBranch<E>,
E: GetEntry,
{
fn eval_else(self, en: &mut E, branch: &Branch) -> Option<Expr> {
fn eval_branch(self, branch: &mut Branch<E>) -> Option<Expr> {
let Self { c, a, b, .. } = self;
let c = c.eval(en);
let c = c.eval(branch.entry());
let a = |en: &mut E| a().eval(en);
let b = |en: &mut E| b.eval_else(en, branch);
branch.add(en, c, a, b);
let b = |branch: &mut Branch<_>| b.eval_branch(branch);
branch.add(c, a, b);
None
}
}
100 changes: 32 additions & 68 deletions dunge_shader/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -831,31 +831,46 @@ impl Entry {
}
}

pub struct Branch {
pub struct Branch<'a, E> {
en: &'a mut E,
expr: Expr,
}

impl Branch {
pub(crate) fn new(en: &mut Entry, ty: Handle<Type>) -> Self {
let v = en.add_local(ty);
let expr = en.local(v);
Self { expr }
impl<'a, E> Branch<'a, E> {
pub(crate) fn new(en: &'a mut E, ty: Handle<Type>) -> Self
where
E: GetEntry,
{
let expr = {
let en = en.get_entry();
let v = en.add_local(ty);
en.local(v)
};

Self { en, expr }
}

pub(crate) fn load(&self, en: &mut Entry) -> Expr {
en.load(self.expr)
pub(crate) fn entry(&mut self) -> &mut E {
self.en
}

pub(crate) fn add<E, A, B>(&self, en: &mut E, c: Expr, a: A, b: B)
pub(crate) fn load(&mut self) -> Expr
where
E: GetEntry,
{
self.en.get_entry().load(self.expr)
}

pub(crate) fn add<A, B>(&mut self, c: Expr, a: A, b: B)
where
E: GetEntry,
A: FnOnce(&mut E) -> Expr,
B: FnOnce(&mut E) -> Option<Expr>,
B: FnOnce(&mut Self) -> Option<Expr>,
{
let a_branch = {
en.get_entry().push();
let a = a(en);
let en = en.get_entry();
self.en.get_entry().push();
let a = a(self.entry());
let en = self.en.get_entry();
let mut s = en.pop();
let st = Statement::Store {
pointer: self.expr.0,
Expand All @@ -867,9 +882,9 @@ impl Branch {
};

let b_branch = {
en.get_entry().push();
let b = b(en);
let en = en.get_entry();
self.en.get_entry().push();
let b = b(self);
let en = self.en.get_entry();
let mut s = en.pop();
if let Some(b) = b {
let st = Statement::Store {
Expand All @@ -889,62 +904,11 @@ impl Branch {
reject: b_branch.0.into(),
};

let en = en.get_entry();
let en = self.en.get_entry();
en.stack.insert(st, &en.exprs);
}
}

// pub(crate) fn branch<E, A, B>(en: &mut E, ty: Handle<Type>, c: Expr, a: A, b: B) -> Expr
// where
// E: GetEntry,
// A: FnOnce(&mut E) -> Expr,
// B: FnOnce(&mut E) -> Expr,
// {
// let pointer = {
// let en = en.get_entry();
// let v = en.add_local(ty);
// en.local(v)
// };

// let a_branch = {
// en.get_entry().push();
// let a = a(en);
// let en = en.get_entry();
// let mut s = en.pop();
// let st = Statement::Store {
// pointer: pointer.0,
// value: a.0,
// };

// s.insert(st, &en.exprs);
// s
// };

// let b_branch = {
// en.get_entry().push();
// let b = b(en);
// let en = en.get_entry();
// let mut s = en.pop();
// let st = Statement::Store {
// pointer: pointer.0,
// value: b.0,
// };

// s.insert(st, &en.exprs);
// s
// };

// let st = Statement::If {
// condition: c.0,
// accept: a_branch.0.into(),
// reject: b_branch.0.into(),
// };

// let en = en.get_entry();
// en.stack.insert(st, &en.exprs);
// en.load(pointer)
// }

struct Stack(Vec<Statements>);

impl Stack {
Expand Down

0 comments on commit fe14ef9

Please sign in to comment.