Skip to content

Commit

Permalink
Conditional expression
Browse files Browse the repository at this point in the history
  • Loading branch information
nanoqsh committed Feb 8, 2024
1 parent 89b2359 commit 174f3f4
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 9 deletions.
30 changes: 27 additions & 3 deletions dunge/tests/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
type Error = Box<dyn std::error::Error>;

#[test]
fn render() -> Result<(), Error> {
fn shader_calc() -> Result<(), Error> {
use dunge::{
glam::Vec4,
sl::{self, Out},
Expand All @@ -12,7 +12,7 @@ fn render() -> Result<(), Error> {
let compute = || {
let m = -sl::mat2(sl::vec2(1., 0.), sl::vec2(0., 1.));
let [m0, m1, m3] = sl::thunk(m);
let v = m0.x() + (-m1.y());
let v = m0.x() + m1.y();
let z = sl::splat_vec3(1.).z();

Out {
Expand All @@ -23,6 +23,30 @@ fn render() -> Result<(), Error> {

let cx = helpers::block_on(dunge::context())?;
let shader = cx.make_shader(compute);
assert_eq!(shader.debug_wgsl(), include_str!("shader.wgsl"));
assert_eq!(shader.debug_wgsl(), include_str!("shader_calc.wgsl"));
Ok(())
}

#[test]
fn shader_if() -> Result<(), Error> {
use dunge::{
glam::Vec4,
sl::{self, Out},
};

let compute = || {
let a = Vec4::splat(3.);
let b = sl::splat_vec4(2.) * 2.;
let x = sl::if_then_else(true, a, b);

Out {
place: x,
color: sl::splat_vec4(1.),
}
};

let cx = helpers::block_on(dunge::context())?;
let shader = cx.make_shader(compute);
assert_eq!(shader.debug_wgsl(), include_str!("shader_if.wgsl"));
Ok(())
}
2 changes: 1 addition & 1 deletion dunge/tests/shader.wgsl → dunge/tests/shader_calc.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct VertexOutput {
@vertex
fn vs() -> VertexOutput {
let _e7: mat2x2<f32> = -(mat2x2<f32>(vec2<f32>(1f, 0f), vec2<f32>(0f, 1f)));
return VertexOutput(((vec4<f32>(_e7[0], (_e7[0] + -(_e7[1]))) * f32(1i)) * vec3<f32>(1f, 1f, 1f).z));
return VertexOutput(((vec4<f32>(_e7[0], (_e7[0] + _e7[1])) * f32(1i)) * vec3<f32>(1f, 1f, 1f).z));
}

@fragment
Expand Down
21 changes: 21 additions & 0 deletions dunge/tests/shader_if.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
struct VertexOutput {
@builtin(position) member: vec4<f32>,
}

@vertex
fn vs() -> VertexOutput {
var local: vec4<f32>;

if true {
local = vec4<f32>(3f, 3f, 3f, 3f);
} else {
local = (vec4<f32>(2f, 2f, 2f, 2f) * 2f);
}
let _e11: vec4<f32> = local;
return VertexOutput(_e11);
}

@fragment
fn fs(param: VertexOutput) -> @location(0) vec4<f32> {
return vec4<f32>(1f, 1f, 1f, 1f);
}
100 changes: 95 additions & 5 deletions dunge_shader/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use {
types::{self, MemberType, ScalarType, ValueType, VectorType},
},
naga::{
AddressSpace, Arena, BinaryOperator, Binding, Block, BuiltIn, EntryPoint, Expression,
Function, FunctionArgument, FunctionResult, GlobalVariable, Handle, Literal, Range,
AddressSpace, Arena, BinaryOperator, Binding, BuiltIn, EntryPoint, Expression, Function,
FunctionArgument, FunctionResult, GlobalVariable, Handle, Literal, LocalVariable, Range,
ResourceBinding, SampleLevel, ShaderStage, Span, Statement, StructMember, Type, TypeInner,
UnaryOperator, UniqueArena,
},
Expand Down Expand Up @@ -338,11 +338,11 @@ impl<A, E> Clone for Thunk<A, E> {
}
}

impl<A, O, E> Eval<E> for Ret<Thunk<A, E>, O>
impl<A, E> Eval<E> for Ret<Thunk<A, E>, A::Out>
where
A: Eval<E>,
{
type Out = O;
type Out = A::Out;

fn eval(self, en: &mut E) -> Expr {
let Thunk { s, .. } = self.get();
Expand All @@ -363,6 +363,50 @@ enum State<A> {
Expr(Expr),
}

pub fn if_then_else<C, A, B, E>(c: C, a: A, b: B) -> Ret<IfThenElse<C, A, B, E>, A::Out>
where
C: Eval<E, Out = bool>,
A: Eval<E>,
A::Out: types::Value,
B: Eval<E, Out = A::Out>,
{
Ret::new(IfThenElse {
c,
a,
b,
e: PhantomData,
})
}

pub struct IfThenElse<C, A, B, E> {
c: C,
a: A,
b: B,
e: PhantomData<E>,
}

impl<C, A, B, E> Eval<E> for Ret<IfThenElse<C, A, B, E>, A::Out>
where
C: Eval<E>,
A: Eval<E>,
A::Out: types::Value,
B: Eval<E>,
E: GetEntry,
{
type Out = A::Out;

fn eval(self, en: &mut E) -> Expr {
let IfThenElse { c, a, b, .. } = self.get();
let c = c.eval(en);
let a = a.eval(en);
let b = b.eval(en);
let en = en.get_entry();
let valty = <A::Out as types::Value>::VALUE_TYPE;
let ty = en.new_type(valty.ty());
en.if_then_else(c, a, b, ty)
}
}

#[derive(Default)]
pub(crate) struct Evaluated([Option<Expr>; 4]);

Expand Down Expand Up @@ -651,19 +695,23 @@ impl Sampled {

pub struct Entry {
compl: Compiler,
locls: Arena<LocalVariable>,
exprs: Arena<Expression>,
stats: Statements,
cached_glob: HashMap<Handle<GlobalVariable>, Expr>,
cached_locl: HashMap<Handle<LocalVariable>, Expr>,
cached_args: HashMap<u32, Expr>,
}

impl Entry {
fn new(compl: Compiler) -> Self {
Self {
compl,
locls: Arena::default(),
exprs: Arena::default(),
stats: Statements::default(),
cached_glob: HashMap::default(),
cached_locl: HashMap::default(),
cached_args: HashMap::default(),
}
}
Expand All @@ -672,6 +720,16 @@ impl Entry {
self.compl.types.insert(ty, Span::UNDEFINED)
}

fn add_local(&mut self, ty: Handle<Type>) -> Handle<LocalVariable> {
let local = LocalVariable {
name: None,
ty,
init: None,
};

self.locls.append(local, Span::UNDEFINED)
}

fn literal(&mut self, literal: Literal) -> Expr {
let ex = Expression::Literal(literal);
Expr(self.exprs.append(ex, Span::UNDEFINED))
Expand All @@ -691,6 +749,13 @@ impl Entry {
})
}

fn local(&mut self, v: Handle<LocalVariable>) -> Expr {
*self.cached_locl.entry(v).or_insert_with(|| {
let ex = Expression::LocalVariable(v);
Expr(self.exprs.append(ex, Span::UNDEFINED))
})
}

fn load(&mut self, ptr: Expr) -> Expr {
let ex = Expression::Load { pointer: ptr.0 };
let handle = self.exprs.append(ex, Span::UNDEFINED);
Expand Down Expand Up @@ -777,6 +842,30 @@ impl Entry {
Expr(handle)
}

// TODO: Lazy evaluation
fn if_then_else(&mut self, cond: Expr, a: Expr, b: Expr, ty: Handle<Type>) -> Expr {
let v = self.add_local(ty);
let pointer = self.local(v);
let a = Statements(vec![Statement::Store {
pointer: pointer.0,
value: a.0,
}]);

let b = Statements(vec![Statement::Store {
pointer: pointer.0,
value: b.0,
}]);

let st = Statement::If {
condition: cond.0,
accept: a.0.into(),
reject: b.0.into(),
};

self.stats.push(st, &self.exprs);
self.load(pointer)
}

fn ret(&mut self, value: Expr) {
let st = Statement::Return {
value: Some(value.0),
Expand Down Expand Up @@ -807,8 +896,9 @@ impl Entry {
function: Function {
arguments: args.map(Argument::into_function).collect(),
result: Some(res),
local_variables: self.locls,
expressions: self.exprs,
body: Block::from_vec(self.stats.0),
body: self.stats.0.into(),
..Default::default()
},
};
Expand Down

0 comments on commit 174f3f4

Please sign in to comment.