Skip to content

Commit

Permalink
Implement GPU comparisons and if functionality (#857)
Browse files Browse the repository at this point in the history
* Implement GPU comparisons and if functionality

* I may as well add the eagerly-evaled if functions for more trivial ternary-like assignments
  • Loading branch information
dfellis authored Aug 23, 2024
1 parent c210082 commit c12373b
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 6 deletions.
12 changes: 12 additions & 0 deletions src/compile/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,18 @@ test!(gpu_map => r#"
stdout "[3, 4, 5, 6]\n";
);

test!(gpu_if => r#"
export fn main {
let b = GBuffer([1, 2, 3, 4]);
let out = b.map(fn (val: gi32, i: gu32) -> gi32 = if(
i % 2 == 0,
val * i.gi32,
val - i.gi32));
out.read{i32}.print;
}"#;
stdout "[0, 1, 6, 1]\n";
);

// Bitwise Math

test!(i8_bitwise => r#"
Expand Down
17 changes: 13 additions & 4 deletions src/lntors/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub fn from_microstatement(
format!(
"|{}| {{\n {};\n }}",
arg_names.join(", "),
inner_statements.join(";\n ")
inner_statements.join(";\n "),
),
out,
))
Expand Down Expand Up @@ -209,9 +209,18 @@ pub fn from_microstatement(
// Static functions just replace the function call with their static value
// calculated at compile time.
match &function.microstatements[0] {
Microstatement::Value { representation, .. } => {
Ok((representation.clone(), out))
}
Microstatement::Value {
representation,
typen,
} => match &typen {
CType::Type(n, _) if n == "string" => {
Ok((format!("{}.to_string()", representation).to_string(), out))
}
CType::Binds(a) if a == "String" => {
Ok((format!("{}.to_string()", representation).to_string(), out))
}
_ => Ok((representation.clone(), out)),
},
_ => unreachable!(),
}
}
Expand Down
11 changes: 10 additions & 1 deletion src/lntors/typen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,16 @@ pub fn ctype_to_rtype(
} else {
Ok(format!(
"impl Fn(&{}) -> {}",
ctype_to_rtype(i, true)?,
match &**i {
CType::Tuple(ts) => {
let mut out = Vec::new();
for t in ts {
out.push(ctype_to_rtype(t, true)?);
}
out.join(", &")
},
otherwise => ctype_to_rtype(otherwise, true)?,
},
ctype_to_rtype(o, true)?
))
}
Expand Down
214 changes: 213 additions & 1 deletion src/std/root.ln
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ export fn xnor(a: bool, b: bool) -> bool binds xnorbool;
export fn eq(a: bool, b: bool) -> bool binds eqbool;
export fn neq(a: bool, b: bool) -> bool binds neqbool;
export fn if{T}(c: bool, t: () -> T, f: () -> T) -> T binds ifbool;
export fn if{T}(c: bool, t: () -> T) -> Maybe{T} = if(c, fn = Maybe{T}(t()), fn = Maybe{T}());
export fn if{T}(c: bool, t: () -> T) -> Maybe{T} = if(c, fn () -> Maybe{T} = Maybe{T}(t()), fn () -> Maybe{T} = Maybe{T}());
export fn if{T}(c: bool, t: T, f: T) -> T = if(c, fn () -> T = t, fn () -> T = f);
export fn if{T}(c: bool, t: T) -> Maybe{T} = if(c, fn () -> Maybe{T} = Maybe{T}(t), fn () -> Maybe{T} = Maybe{T}());

/// Array related bindings
export fn get{T}(a: T[], i: i64) -> Maybe{T} binds getarray;
Expand Down Expand Up @@ -4451,6 +4453,207 @@ export fn mod(a: gf32, b: gvec4f) -> gvec4f {
return gvec4f(varName, statements, buffers);
}

// GPU Comparison methods

fn geq{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' == ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn eq(a: gu32, b: gu32) -> gbool = geq{gu32, gbool}(a, b);
export fn eq{T}(a: gu32, b: T) -> gbool = geq{gu32, gbool}(a, b.gu32);
export fn eq{T}(a: T, b: gu32) -> gbool = geq{gu32, gbool}(a.gu32, b);
export fn eq(a: gi32, b: gi32) -> gbool = geq{gi32, gbool}(a, b);
export fn eq{T}(a: gi32, b: T) -> gbool = geq{gi32, gbool}(a, b.gi32);
export fn eq{T}(a: T, b: gi32) -> gbool = geq{gi32, gbool}(a.gi32, b);
export fn eq(a: gf32, b: gf32) -> gbool = geq{gf32, gbool}(a, b);
export fn eq{T}(a: gf32, b: T) -> gbool = geq{gf32, gbool}(a, b.gf32);
export fn eq{T}(a: T, b: gf32) -> gbool = geq{gf32, gbool}(a.gf32, b);
export fn eq(a: gbool, b: gbool) -> gbool = geq{gbool, gbool}(a, b);
export fn eq{T}(a: gbool, b: T) -> gbool = geq{gbool, gbool}(a, b.gbool);
export fn eq{T}(a: T, b: gbool) -> gbool = geq{gbool, gbool}(a.gbool, b);
export fn eq(a: gvec2u, b: gvec2u) -> gvec2b = geq{gvec2u, gvec2b}(a, b);
export fn eq(a: gvec2i, b: gvec2i) -> gvec2b = geq{gvec2i, gvec2b}(a, b);
export fn eq(a: gvec2f, b: gvec2f) -> gvec2b = geq{gvec2f, gvec2b}(a, b);
export fn eq(a: gvec2b, b: gvec2b) -> gvec2b = geq{gvec2b, gvec2b}(a, b);
export fn eq(a: gvec3u, b: gvec3u) -> gvec3b = geq{gvec3u, gvec3b}(a, b);
export fn eq(a: gvec3i, b: gvec3i) -> gvec3b = geq{gvec3i, gvec3b}(a, b);
export fn eq(a: gvec3f, b: gvec3f) -> gvec3b = geq{gvec3f, gvec3b}(a, b);
export fn eq(a: gvec3b, b: gvec3b) -> gvec3b = geq{gvec3b, gvec3b}(a, b);
export fn eq(a: gvec4u, b: gvec4u) -> gvec4b = geq{gvec4u, gvec4b}(a, b);
export fn eq(a: gvec4i, b: gvec4i) -> gvec4b = geq{gvec4i, gvec4b}(a, b);
export fn eq(a: gvec4f, b: gvec4f) -> gvec4b = geq{gvec4f, gvec4b}(a, b);
export fn eq(a: gvec4b, b: gvec4b) -> gvec4b = geq{gvec4b, gvec4b}(a, b);

fn gneq{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' != ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn neq(a: gu32, b: gu32) -> gbool = gneq{gu32, gbool}(a, b);
export fn neq{T}(a: gu32, b: T) -> gbool = gneq{gu32, gbool}(a, b.gu32);
export fn neq{T}(a: T, b: gu32) -> gbool = gneq{gu32, gbool}(a.gu32, b);
export fn neq(a: gi32, b: gi32) -> gbool = gneq{gi32, gbool}(a, b);
export fn neq{T}(a: gi32, b: T) -> gbool = gneq{gi32, gbool}(a, b.gi32);
export fn neq{T}(a: T, b: gi32) -> gbool = gneq{gi32, gbool}(a.gi32, b);
export fn neq(a: gf32, b: gf32) -> gbool = gneq{gf32, gbool}(a, b);
export fn neq{T}(a: gf32, b: T) -> gbool = gneq{gf32, gbool}(a, b.gf32);
export fn neq{T}(a: T, b: gf32) -> gbool = gneq{gf32, gbool}(a.gf32, b);
export fn neq(a: gbool, b: gbool) -> gbool = gneq{gbool, gbool}(a, b);
export fn neq{T}(a: gbool, b: T) -> gbool = gneq{gbool, gbool}(a, b.gbool);
export fn neq{T}(a: T, b: gbool) -> gbool = gneq{gbool, gbool}(a.gbool, b);
export fn neq(a: gvec2u, b: gvec2u) -> gvec2b = gneq{gvec2u, gvec2b}(a, b);
export fn neq(a: gvec2i, b: gvec2i) -> gvec2b = gneq{gvec2i, gvec2b}(a, b);
export fn neq(a: gvec2f, b: gvec2f) -> gvec2b = gneq{gvec2f, gvec2b}(a, b);
export fn neq(a: gvec2b, b: gvec2b) -> gvec2b = gneq{gvec2b, gvec2b}(a, b);
export fn neq(a: gvec3u, b: gvec3u) -> gvec3b = gneq{gvec3u, gvec3b}(a, b);
export fn neq(a: gvec3i, b: gvec3i) -> gvec3b = gneq{gvec3i, gvec3b}(a, b);
export fn neq(a: gvec3f, b: gvec3f) -> gvec3b = gneq{gvec3f, gvec3b}(a, b);
export fn neq(a: gvec3b, b: gvec3b) -> gvec3b = gneq{gvec3b, gvec3b}(a, b);
export fn neq(a: gvec4u, b: gvec4u) -> gvec4b = gneq{gvec4u, gvec4b}(a, b);
export fn neq(a: gvec4i, b: gvec4i) -> gvec4b = gneq{gvec4i, gvec4b}(a, b);
export fn neq(a: gvec4f, b: gvec4f) -> gvec4b = gneq{gvec4f, gvec4b}(a, b);
export fn neq(a: gvec4b, b: gvec4b) -> gvec4b = gneq{gvec4b, gvec4b}(a, b);

fn glt{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' < ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn lt(a: gu32, b: gu32) -> gbool = glt{gu32, gbool}(a, b);
export fn lt{T}(a: gu32, b: T) -> gbool = glt{gu32, gbool}(a, b.gu32);
export fn lt{T}(a: T, b: gu32) -> gbool = glt{gu32, gbool}(a.gu32, b);
export fn lt(a: gi32, b: gi32) -> gbool = glt{gi32, gbool}(a, b);
export fn lt{T}(a: gi32, b: T) -> gbool = glt{gi32, gbool}(a, b.gi32);
export fn lt{T}(a: T, b: gi32) -> gbool = glt{gi32, gbool}(a.gi32, b);
export fn lt(a: gf32, b: gf32) -> gbool = glt{gf32, gbool}(a, b);
export fn lt{T}(a: gf32, b: T) -> gbool = glt{gf32, gbool}(a, b.gf32);
export fn lt{T}(a: T, b: gf32) -> gbool = glt{gf32, gbool}(a.gf32, b);
export fn lt(a: gvec2u, b: gvec2u) -> gvec2b = glt{gvec2u, gvec2b}(a, b);
export fn lt(a: gvec2i, b: gvec2i) -> gvec2b = glt{gvec2i, gvec2b}(a, b);
export fn lt(a: gvec2f, b: gvec2f) -> gvec2b = glt{gvec2f, gvec2b}(a, b);
export fn lt(a: gvec3u, b: gvec3u) -> gvec3b = glt{gvec3u, gvec3b}(a, b);
export fn lt(a: gvec3i, b: gvec3i) -> gvec3b = glt{gvec3i, gvec3b}(a, b);
export fn lt(a: gvec3f, b: gvec3f) -> gvec3b = glt{gvec3f, gvec3b}(a, b);
export fn lt(a: gvec4u, b: gvec4u) -> gvec4b = glt{gvec4u, gvec4b}(a, b);
export fn lt(a: gvec4i, b: gvec4i) -> gvec4b = glt{gvec4i, gvec4b}(a, b);
export fn lt(a: gvec4f, b: gvec4f) -> gvec4b = glt{gvec4f, gvec4b}(a, b);

fn glte{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' <= ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn lte(a: gu32, b: gu32) -> gbool = glte{gu32, gbool}(a, b);
export fn lte{T}(a: gu32, b: T) -> gbool = glte{gu32, gbool}(a, b.gu32);
export fn lte{T}(a: T, b: gu32) -> gbool = glte{gu32, gbool}(a.gu32, b);
export fn lte(a: gi32, b: gi32) -> gbool = glte{gi32, gbool}(a, b);
export fn lte{T}(a: gi32, b: T) -> gbool = glte{gi32, gbool}(a, b.gi32);
export fn lte{T}(a: T, b: gi32) -> gbool = glte{gi32, gbool}(a.gi32, b);
export fn lte(a: gf32, b: gf32) -> gbool = glte{gf32, gbool}(a, b);
export fn lte{T}(a: gf32, b: T) -> gbool = glte{gf32, gbool}(a, b.gf32);
export fn lte{T}(a: T, b: gf32) -> gbool = glte{gf32, gbool}(a.gf32, b);
export fn lte(a: gvec2u, b: gvec2u) -> gvec2b = glte{gvec2u, gvec2b}(a, b);
export fn lte(a: gvec2i, b: gvec2i) -> gvec2b = glte{gvec2i, gvec2b}(a, b);
export fn lte(a: gvec2f, b: gvec2f) -> gvec2b = glte{gvec2f, gvec2b}(a, b);
export fn lte(a: gvec3u, b: gvec3u) -> gvec3b = glte{gvec3u, gvec3b}(a, b);
export fn lte(a: gvec3i, b: gvec3i) -> gvec3b = glte{gvec3i, gvec3b}(a, b);
export fn lte(a: gvec3f, b: gvec3f) -> gvec3b = glte{gvec3f, gvec3b}(a, b);
export fn lte(a: gvec4u, b: gvec4u) -> gvec4b = glte{gvec4u, gvec4b}(a, b);
export fn lte(a: gvec4i, b: gvec4i) -> gvec4b = glte{gvec4i, gvec4b}(a, b);
export fn lte(a: gvec4f, b: gvec4f) -> gvec4b = glte{gvec4f, gvec4b}(a, b);

fn ggt{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' < ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn gt(a: gu32, b: gu32) -> gbool = ggt{gu32, gbool}(a, b);
export fn gt{T}(a: gu32, b: T) -> gbool = ggt{gu32, gbool}(a, b.gu32);
export fn gt{T}(a: T, b: gu32) -> gbool = ggt{gu32, gbool}(a.gu32, b);
export fn gt(a: gi32, b: gi32) -> gbool = ggt{gi32, gbool}(a, b);
export fn gt{T}(a: gi32, b: T) -> gbool = ggt{gi32, gbool}(a, b.gi32);
export fn gt{T}(a: T, b: gi32) -> gbool = ggt{gi32, gbool}(a.gi32, b);
export fn gt(a: gf32, b: gf32) -> gbool = ggt{gf32, gbool}(a, b);
export fn gt{T}(a: gf32, b: T) -> gbool = ggt{gf32, gbool}(a, b.gf32);
export fn gt{T}(a: T, b: gf32) -> gbool = ggt{gf32, gbool}(a.gf32, b);
export fn gt(a: gvec2u, b: gvec2u) -> gvec2b = ggt{gvec2u, gvec2b}(a, b);
export fn gt(a: gvec2i, b: gvec2i) -> gvec2b = ggt{gvec2i, gvec2b}(a, b);
export fn gt(a: gvec2f, b: gvec2f) -> gvec2b = ggt{gvec2f, gvec2b}(a, b);
export fn gt(a: gvec3u, b: gvec3u) -> gvec3b = ggt{gvec3u, gvec3b}(a, b);
export fn gt(a: gvec3i, b: gvec3i) -> gvec3b = ggt{gvec3i, gvec3b}(a, b);
export fn gt(a: gvec3f, b: gvec3f) -> gvec3b = ggt{gvec3f, gvec3b}(a, b);
export fn gt(a: gvec4u, b: gvec4u) -> gvec4b = ggt{gvec4u, gvec4b}(a, b);
export fn gt(a: gvec4i, b: gvec4i) -> gvec4b = ggt{gvec4i, gvec4b}(a, b);
export fn gt(a: gvec4f, b: gvec4f) -> gvec4b = ggt{gvec4f, gvec4b}(a, b);

fn ggte{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' <= ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn gte(a: gu32, b: gu32) -> gbool = ggte{gu32, gbool}(a, b);
export fn gte{T}(a: gu32, b: T) -> gbool = ggte{gu32, gbool}(a, b.gu32);
export fn gte{T}(a: T, b: gu32) -> gbool = ggte{gu32, gbool}(a.gu32, b);
export fn gte(a: gi32, b: gi32) -> gbool = ggte{gi32, gbool}(a, b);
export fn gte{T}(a: gi32, b: T) -> gbool = ggte{gi32, gbool}(a, b.gi32);
export fn gte{T}(a: T, b: gi32) -> gbool = ggte{gi32, gbool}(a.gi32, b);
export fn gte(a: gf32, b: gf32) -> gbool = ggte{gf32, gbool}(a, b);
export fn gte{T}(a: gf32, b: T) -> gbool = ggte{gf32, gbool}(a, b.gf32);
export fn gte{T}(a: T, b: gf32) -> gbool = ggte{gf32, gbool}(a.gf32, b);
export fn gte(a: gvec2u, b: gvec2u) -> gvec2b = ggte{gvec2u, gvec2b}(a, b);
export fn gte(a: gvec2i, b: gvec2i) -> gvec2b = ggte{gvec2i, gvec2b}(a, b);
export fn gte(a: gvec2f, b: gvec2f) -> gvec2b = ggte{gvec2f, gvec2b}(a, b);
export fn gte(a: gvec3u, b: gvec3u) -> gvec3b = ggte{gvec3u, gvec3b}(a, b);
export fn gte(a: gvec3i, b: gvec3i) -> gvec3b = ggte{gvec3i, gvec3b}(a, b);
export fn gte(a: gvec3f, b: gvec3f) -> gvec3b = ggte{gvec3f, gvec3b}(a, b);
export fn gte(a: gvec4u, b: gvec4u) -> gvec4b = ggte{gvec4u, gvec4b}(a, b);
export fn gte(a: gvec4i, b: gvec4i) -> gvec4b = ggte{gvec4i, gvec4b}(a, b);
export fn gte(a: gvec4f, b: gvec4f) -> gvec4b = ggte{gvec4f, gvec4b}(a, b);

fn if{T}(c: gbool, t: T, f: T) -> T {
let varName = "if_".concat(uuid().string.replace('-', '_'));
let tBody = t.statements.Array.map(fn (kv: (string, string)) -> string {
return if(kv.0.eq("@builtin(global_invocation_id) id: vec3u"), fn () -> string = "", fn () -> string {
return " ".concat(kv.1).concat(";\n");
});
}).join("");
let fBody = f.statements.Array.map(fn (kv: (string, string)) -> string {
return if(kv.0.eq("@builtin(global_invocation_id) id: vec3u"), fn () -> string = "", fn () -> string {
return " ".concat(kv.1).concat(";\n");
});
}).join("");
let statement = "var "
.concat(varName)
.concat(": ")
.concat(t.typeName)
.concat("; if ")
.concat(c.varName)
.concat(" { ")
.concat(tBody)
.concat('; ')
.concat(varName)
.concat(' = ')
.concat(t.varName)
.concat("; } else { ")
.concat(fBody)
.concat('; ')
.concat(varName)
.concat(' = ')
.concat(f.varName)
.concat('; }');
let statements = c.statements.concat(Dict(varName, statement));
let buffers = c.buffers.union(t.buffers).union(f.buffers);
return {T}(varName, statements, buffers);
}
fn if{T}(c: gbool, t: () -> T, f: () -> T) -> T = if(c, t(), f());

// GBuffer methods

// TODO: Support more than i32 for GBuffer
Expand All @@ -4462,6 +4665,15 @@ export fn map(gb: GBuffer, f: (gi32) -> gi32) -> GBuffer {
compute.build.run;
return out;
}
export fn map(gb: GBuffer, f: (gi32, gu32) -> gi32) -> GBuffer {
let idx = gFor(gb.len);
let val = gb[idx];
let out = GBuffer(gb.len.mul(4));
let compute = out[idx].store(f(val, idx));
compute.build.run;
return out;
}


/// Stdout/stderr-related bindings
// TODO: Rework this to just print anything that can be converted to `string` via interfaces
Expand Down

0 comments on commit c12373b

Please sign in to comment.