Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement GPU comparisons and if functionality #857

Merged
merged 2 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/compile/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,18 @@ test!(gpu_map => r#"
stdout "[3, 4, 5, 6]\n";
);

test!(gpu_if => r#"
export fn main {
let b = GBuffer([1, 2, 3, 4]);
let out = b.map(fn (val: gi32, i: gu32) -> gi32 = if(
i % 2 == 0,
val * i.gi32,
val - i.gi32));
out.read{i32}.print;
}"#;
stdout "[0, 1, 6, 1]\n";
);

// Bitwise Math

test!(i8_bitwise => r#"
Expand Down
17 changes: 13 additions & 4 deletions src/lntors/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub fn from_microstatement(
format!(
"|{}| {{\n {};\n }}",
arg_names.join(", "),
inner_statements.join(";\n ")
inner_statements.join(";\n "),
),
out,
))
Expand Down Expand Up @@ -209,9 +209,18 @@ pub fn from_microstatement(
// Static functions just replace the function call with their static value
// calculated at compile time.
match &function.microstatements[0] {
Microstatement::Value { representation, .. } => {
Ok((representation.clone(), out))
}
Microstatement::Value {
representation,
typen,
} => match &typen {
CType::Type(n, _) if n == "string" => {
Ok((format!("{}.to_string()", representation).to_string(), out))
}
CType::Binds(a) if a == "String" => {
Ok((format!("{}.to_string()", representation).to_string(), out))
}
_ => Ok((representation.clone(), out)),
},
_ => unreachable!(),
}
}
Expand Down
11 changes: 10 additions & 1 deletion src/lntors/typen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,16 @@ pub fn ctype_to_rtype(
} else {
Ok(format!(
"impl Fn(&{}) -> {}",
ctype_to_rtype(i, true)?,
match &**i {
CType::Tuple(ts) => {
let mut out = Vec::new();
for t in ts {
out.push(ctype_to_rtype(t, true)?);
}
out.join(", &")
},
otherwise => ctype_to_rtype(otherwise, true)?,
},
ctype_to_rtype(o, true)?
))
}
Expand Down
214 changes: 213 additions & 1 deletion src/std/root.ln
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ export fn xnor(a: bool, b: bool) -> bool binds xnorbool;
export fn eq(a: bool, b: bool) -> bool binds eqbool;
export fn neq(a: bool, b: bool) -> bool binds neqbool;
export fn if{T}(c: bool, t: () -> T, f: () -> T) -> T binds ifbool;
export fn if{T}(c: bool, t: () -> T) -> Maybe{T} = if(c, fn = Maybe{T}(t()), fn = Maybe{T}());
export fn if{T}(c: bool, t: () -> T) -> Maybe{T} = if(c, fn () -> Maybe{T} = Maybe{T}(t()), fn () -> Maybe{T} = Maybe{T}());
export fn if{T}(c: bool, t: T, f: T) -> T = if(c, fn () -> T = t, fn () -> T = f);
export fn if{T}(c: bool, t: T) -> Maybe{T} = if(c, fn () -> Maybe{T} = Maybe{T}(t), fn () -> Maybe{T} = Maybe{T}());

/// Array related bindings
export fn get{T}(a: T[], i: i64) -> Maybe{T} binds getarray;
Expand Down Expand Up @@ -4451,6 +4453,207 @@ export fn mod(a: gf32, b: gvec4f) -> gvec4f {
return gvec4f(varName, statements, buffers);
}

// GPU Comparison methods

fn geq{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' == ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn eq(a: gu32, b: gu32) -> gbool = geq{gu32, gbool}(a, b);
export fn eq{T}(a: gu32, b: T) -> gbool = geq{gu32, gbool}(a, b.gu32);
export fn eq{T}(a: T, b: gu32) -> gbool = geq{gu32, gbool}(a.gu32, b);
export fn eq(a: gi32, b: gi32) -> gbool = geq{gi32, gbool}(a, b);
export fn eq{T}(a: gi32, b: T) -> gbool = geq{gi32, gbool}(a, b.gi32);
export fn eq{T}(a: T, b: gi32) -> gbool = geq{gi32, gbool}(a.gi32, b);
export fn eq(a: gf32, b: gf32) -> gbool = geq{gf32, gbool}(a, b);
export fn eq{T}(a: gf32, b: T) -> gbool = geq{gf32, gbool}(a, b.gf32);
export fn eq{T}(a: T, b: gf32) -> gbool = geq{gf32, gbool}(a.gf32, b);
export fn eq(a: gbool, b: gbool) -> gbool = geq{gbool, gbool}(a, b);
export fn eq{T}(a: gbool, b: T) -> gbool = geq{gbool, gbool}(a, b.gbool);
export fn eq{T}(a: T, b: gbool) -> gbool = geq{gbool, gbool}(a.gbool, b);
export fn eq(a: gvec2u, b: gvec2u) -> gvec2b = geq{gvec2u, gvec2b}(a, b);
export fn eq(a: gvec2i, b: gvec2i) -> gvec2b = geq{gvec2i, gvec2b}(a, b);
export fn eq(a: gvec2f, b: gvec2f) -> gvec2b = geq{gvec2f, gvec2b}(a, b);
export fn eq(a: gvec2b, b: gvec2b) -> gvec2b = geq{gvec2b, gvec2b}(a, b);
export fn eq(a: gvec3u, b: gvec3u) -> gvec3b = geq{gvec3u, gvec3b}(a, b);
export fn eq(a: gvec3i, b: gvec3i) -> gvec3b = geq{gvec3i, gvec3b}(a, b);
export fn eq(a: gvec3f, b: gvec3f) -> gvec3b = geq{gvec3f, gvec3b}(a, b);
export fn eq(a: gvec3b, b: gvec3b) -> gvec3b = geq{gvec3b, gvec3b}(a, b);
export fn eq(a: gvec4u, b: gvec4u) -> gvec4b = geq{gvec4u, gvec4b}(a, b);
export fn eq(a: gvec4i, b: gvec4i) -> gvec4b = geq{gvec4i, gvec4b}(a, b);
export fn eq(a: gvec4f, b: gvec4f) -> gvec4b = geq{gvec4f, gvec4b}(a, b);
export fn eq(a: gvec4b, b: gvec4b) -> gvec4b = geq{gvec4b, gvec4b}(a, b);

fn gneq{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' != ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn neq(a: gu32, b: gu32) -> gbool = gneq{gu32, gbool}(a, b);
export fn neq{T}(a: gu32, b: T) -> gbool = gneq{gu32, gbool}(a, b.gu32);
export fn neq{T}(a: T, b: gu32) -> gbool = gneq{gu32, gbool}(a.gu32, b);
export fn neq(a: gi32, b: gi32) -> gbool = gneq{gi32, gbool}(a, b);
export fn neq{T}(a: gi32, b: T) -> gbool = gneq{gi32, gbool}(a, b.gi32);
export fn neq{T}(a: T, b: gi32) -> gbool = gneq{gi32, gbool}(a.gi32, b);
export fn neq(a: gf32, b: gf32) -> gbool = gneq{gf32, gbool}(a, b);
export fn neq{T}(a: gf32, b: T) -> gbool = gneq{gf32, gbool}(a, b.gf32);
export fn neq{T}(a: T, b: gf32) -> gbool = gneq{gf32, gbool}(a.gf32, b);
export fn neq(a: gbool, b: gbool) -> gbool = gneq{gbool, gbool}(a, b);
export fn neq{T}(a: gbool, b: T) -> gbool = gneq{gbool, gbool}(a, b.gbool);
export fn neq{T}(a: T, b: gbool) -> gbool = gneq{gbool, gbool}(a.gbool, b);
export fn neq(a: gvec2u, b: gvec2u) -> gvec2b = gneq{gvec2u, gvec2b}(a, b);
export fn neq(a: gvec2i, b: gvec2i) -> gvec2b = gneq{gvec2i, gvec2b}(a, b);
export fn neq(a: gvec2f, b: gvec2f) -> gvec2b = gneq{gvec2f, gvec2b}(a, b);
export fn neq(a: gvec2b, b: gvec2b) -> gvec2b = gneq{gvec2b, gvec2b}(a, b);
export fn neq(a: gvec3u, b: gvec3u) -> gvec3b = gneq{gvec3u, gvec3b}(a, b);
export fn neq(a: gvec3i, b: gvec3i) -> gvec3b = gneq{gvec3i, gvec3b}(a, b);
export fn neq(a: gvec3f, b: gvec3f) -> gvec3b = gneq{gvec3f, gvec3b}(a, b);
export fn neq(a: gvec3b, b: gvec3b) -> gvec3b = gneq{gvec3b, gvec3b}(a, b);
export fn neq(a: gvec4u, b: gvec4u) -> gvec4b = gneq{gvec4u, gvec4b}(a, b);
export fn neq(a: gvec4i, b: gvec4i) -> gvec4b = gneq{gvec4i, gvec4b}(a, b);
export fn neq(a: gvec4f, b: gvec4f) -> gvec4b = gneq{gvec4f, gvec4b}(a, b);
export fn neq(a: gvec4b, b: gvec4b) -> gvec4b = gneq{gvec4b, gvec4b}(a, b);

fn glt{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' < ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn lt(a: gu32, b: gu32) -> gbool = glt{gu32, gbool}(a, b);
export fn lt{T}(a: gu32, b: T) -> gbool = glt{gu32, gbool}(a, b.gu32);
export fn lt{T}(a: T, b: gu32) -> gbool = glt{gu32, gbool}(a.gu32, b);
export fn lt(a: gi32, b: gi32) -> gbool = glt{gi32, gbool}(a, b);
export fn lt{T}(a: gi32, b: T) -> gbool = glt{gi32, gbool}(a, b.gi32);
export fn lt{T}(a: T, b: gi32) -> gbool = glt{gi32, gbool}(a.gi32, b);
export fn lt(a: gf32, b: gf32) -> gbool = glt{gf32, gbool}(a, b);
export fn lt{T}(a: gf32, b: T) -> gbool = glt{gf32, gbool}(a, b.gf32);
export fn lt{T}(a: T, b: gf32) -> gbool = glt{gf32, gbool}(a.gf32, b);
export fn lt(a: gvec2u, b: gvec2u) -> gvec2b = glt{gvec2u, gvec2b}(a, b);
export fn lt(a: gvec2i, b: gvec2i) -> gvec2b = glt{gvec2i, gvec2b}(a, b);
export fn lt(a: gvec2f, b: gvec2f) -> gvec2b = glt{gvec2f, gvec2b}(a, b);
export fn lt(a: gvec3u, b: gvec3u) -> gvec3b = glt{gvec3u, gvec3b}(a, b);
export fn lt(a: gvec3i, b: gvec3i) -> gvec3b = glt{gvec3i, gvec3b}(a, b);
export fn lt(a: gvec3f, b: gvec3f) -> gvec3b = glt{gvec3f, gvec3b}(a, b);
export fn lt(a: gvec4u, b: gvec4u) -> gvec4b = glt{gvec4u, gvec4b}(a, b);
export fn lt(a: gvec4i, b: gvec4i) -> gvec4b = glt{gvec4i, gvec4b}(a, b);
export fn lt(a: gvec4f, b: gvec4f) -> gvec4b = glt{gvec4f, gvec4b}(a, b);

fn glte{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' <= ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn lte(a: gu32, b: gu32) -> gbool = glte{gu32, gbool}(a, b);
export fn lte{T}(a: gu32, b: T) -> gbool = glte{gu32, gbool}(a, b.gu32);
export fn lte{T}(a: T, b: gu32) -> gbool = glte{gu32, gbool}(a.gu32, b);
export fn lte(a: gi32, b: gi32) -> gbool = glte{gi32, gbool}(a, b);
export fn lte{T}(a: gi32, b: T) -> gbool = glte{gi32, gbool}(a, b.gi32);
export fn lte{T}(a: T, b: gi32) -> gbool = glte{gi32, gbool}(a.gi32, b);
export fn lte(a: gf32, b: gf32) -> gbool = glte{gf32, gbool}(a, b);
export fn lte{T}(a: gf32, b: T) -> gbool = glte{gf32, gbool}(a, b.gf32);
export fn lte{T}(a: T, b: gf32) -> gbool = glte{gf32, gbool}(a.gf32, b);
export fn lte(a: gvec2u, b: gvec2u) -> gvec2b = glte{gvec2u, gvec2b}(a, b);
export fn lte(a: gvec2i, b: gvec2i) -> gvec2b = glte{gvec2i, gvec2b}(a, b);
export fn lte(a: gvec2f, b: gvec2f) -> gvec2b = glte{gvec2f, gvec2b}(a, b);
export fn lte(a: gvec3u, b: gvec3u) -> gvec3b = glte{gvec3u, gvec3b}(a, b);
export fn lte(a: gvec3i, b: gvec3i) -> gvec3b = glte{gvec3i, gvec3b}(a, b);
export fn lte(a: gvec3f, b: gvec3f) -> gvec3b = glte{gvec3f, gvec3b}(a, b);
export fn lte(a: gvec4u, b: gvec4u) -> gvec4b = glte{gvec4u, gvec4b}(a, b);
export fn lte(a: gvec4i, b: gvec4i) -> gvec4b = glte{gvec4i, gvec4b}(a, b);
export fn lte(a: gvec4f, b: gvec4f) -> gvec4b = glte{gvec4f, gvec4b}(a, b);

fn ggt{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' < ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn gt(a: gu32, b: gu32) -> gbool = ggt{gu32, gbool}(a, b);
export fn gt{T}(a: gu32, b: T) -> gbool = ggt{gu32, gbool}(a, b.gu32);
export fn gt{T}(a: T, b: gu32) -> gbool = ggt{gu32, gbool}(a.gu32, b);
export fn gt(a: gi32, b: gi32) -> gbool = ggt{gi32, gbool}(a, b);
export fn gt{T}(a: gi32, b: T) -> gbool = ggt{gi32, gbool}(a, b.gi32);
export fn gt{T}(a: T, b: gi32) -> gbool = ggt{gi32, gbool}(a.gi32, b);
export fn gt(a: gf32, b: gf32) -> gbool = ggt{gf32, gbool}(a, b);
export fn gt{T}(a: gf32, b: T) -> gbool = ggt{gf32, gbool}(a, b.gf32);
export fn gt{T}(a: T, b: gf32) -> gbool = ggt{gf32, gbool}(a.gf32, b);
export fn gt(a: gvec2u, b: gvec2u) -> gvec2b = ggt{gvec2u, gvec2b}(a, b);
export fn gt(a: gvec2i, b: gvec2i) -> gvec2b = ggt{gvec2i, gvec2b}(a, b);
export fn gt(a: gvec2f, b: gvec2f) -> gvec2b = ggt{gvec2f, gvec2b}(a, b);
export fn gt(a: gvec3u, b: gvec3u) -> gvec3b = ggt{gvec3u, gvec3b}(a, b);
export fn gt(a: gvec3i, b: gvec3i) -> gvec3b = ggt{gvec3i, gvec3b}(a, b);
export fn gt(a: gvec3f, b: gvec3f) -> gvec3b = ggt{gvec3f, gvec3b}(a, b);
export fn gt(a: gvec4u, b: gvec4u) -> gvec4b = ggt{gvec4u, gvec4b}(a, b);
export fn gt(a: gvec4i, b: gvec4i) -> gvec4b = ggt{gvec4i, gvec4b}(a, b);
export fn gt(a: gvec4f, b: gvec4f) -> gvec4b = ggt{gvec4f, gvec4b}(a, b);

fn ggte{I, O}(a: I, b: I) -> O {
let varName = '('.concat(a.varName).concat(' <= ').concat(b.varName).concat(')');
let statements = a.statements.concat(b.statements);
let buffers = a.buffers.union(b.buffers);
return {O}(varName, statements, buffers);
}
export fn gte(a: gu32, b: gu32) -> gbool = ggte{gu32, gbool}(a, b);
export fn gte{T}(a: gu32, b: T) -> gbool = ggte{gu32, gbool}(a, b.gu32);
export fn gte{T}(a: T, b: gu32) -> gbool = ggte{gu32, gbool}(a.gu32, b);
export fn gte(a: gi32, b: gi32) -> gbool = ggte{gi32, gbool}(a, b);
export fn gte{T}(a: gi32, b: T) -> gbool = ggte{gi32, gbool}(a, b.gi32);
export fn gte{T}(a: T, b: gi32) -> gbool = ggte{gi32, gbool}(a.gi32, b);
export fn gte(a: gf32, b: gf32) -> gbool = ggte{gf32, gbool}(a, b);
export fn gte{T}(a: gf32, b: T) -> gbool = ggte{gf32, gbool}(a, b.gf32);
export fn gte{T}(a: T, b: gf32) -> gbool = ggte{gf32, gbool}(a.gf32, b);
export fn gte(a: gvec2u, b: gvec2u) -> gvec2b = ggte{gvec2u, gvec2b}(a, b);
export fn gte(a: gvec2i, b: gvec2i) -> gvec2b = ggte{gvec2i, gvec2b}(a, b);
export fn gte(a: gvec2f, b: gvec2f) -> gvec2b = ggte{gvec2f, gvec2b}(a, b);
export fn gte(a: gvec3u, b: gvec3u) -> gvec3b = ggte{gvec3u, gvec3b}(a, b);
export fn gte(a: gvec3i, b: gvec3i) -> gvec3b = ggte{gvec3i, gvec3b}(a, b);
export fn gte(a: gvec3f, b: gvec3f) -> gvec3b = ggte{gvec3f, gvec3b}(a, b);
export fn gte(a: gvec4u, b: gvec4u) -> gvec4b = ggte{gvec4u, gvec4b}(a, b);
export fn gte(a: gvec4i, b: gvec4i) -> gvec4b = ggte{gvec4i, gvec4b}(a, b);
export fn gte(a: gvec4f, b: gvec4f) -> gvec4b = ggte{gvec4f, gvec4b}(a, b);

fn if{T}(c: gbool, t: T, f: T) -> T {
let varName = "if_".concat(uuid().string.replace('-', '_'));
let tBody = t.statements.Array.map(fn (kv: (string, string)) -> string {
return if(kv.0.eq("@builtin(global_invocation_id) id: vec3u"), fn () -> string = "", fn () -> string {
return " ".concat(kv.1).concat(";\n");
});
}).join("");
let fBody = f.statements.Array.map(fn (kv: (string, string)) -> string {
return if(kv.0.eq("@builtin(global_invocation_id) id: vec3u"), fn () -> string = "", fn () -> string {
return " ".concat(kv.1).concat(";\n");
});
}).join("");
let statement = "var "
.concat(varName)
.concat(": ")
.concat(t.typeName)
.concat("; if ")
.concat(c.varName)
.concat(" { ")
.concat(tBody)
.concat('; ')
.concat(varName)
.concat(' = ')
.concat(t.varName)
.concat("; } else { ")
.concat(fBody)
.concat('; ')
.concat(varName)
.concat(' = ')
.concat(f.varName)
.concat('; }');
let statements = c.statements.concat(Dict(varName, statement));
let buffers = c.buffers.union(t.buffers).union(f.buffers);
return {T}(varName, statements, buffers);
}
fn if{T}(c: gbool, t: () -> T, f: () -> T) -> T = if(c, t(), f());

// GBuffer methods

// TODO: Support more than i32 for GBuffer
Expand All @@ -4462,6 +4665,15 @@ export fn map(gb: GBuffer, f: (gi32) -> gi32) -> GBuffer {
compute.build.run;
return out;
}
export fn map(gb: GBuffer, f: (gi32, gu32) -> gi32) -> GBuffer {
let idx = gFor(gb.len);
let val = gb[idx];
let out = GBuffer(gb.len.mul(4));
let compute = out[idx].store(f(val, idx));
compute.build.run;
return out;
}


/// Stdout/stderr-related bindings
// TODO: Rework this to just print anything that can be converted to `string` via interfaces
Expand Down
Loading