Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement GPU Types and Methods, and implement a new hello_gpu example using it #828

Merged
merged 1 commit into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/compile/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,14 +755,25 @@ test!(hello_gpu => r#"
@compute
@workgroup_size(1)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
vals[id.x] = vals[id.x] * bitcast<i32>(id.x);
vals[id.x] = vals[id.x] * i32(id.x);
}
", b);
plan.run;
b.read{i32}.print;
}"#;
stdout "[0, 2, 4, 6]\n";
);
test!(hello_gpu_new => r#"
export fn main {
let b = GBuffer(filled(2.i32, 4));
let id = gFor(4);
// TODO: `save` should be `store`, but there's a function resolution bug
let compute = b[id.x].save(b[id.x] * id.x.gi32);
compute.build.run;
b.read{i32}.print;
}"#;
stdout "[0, 2, 4, 6]\n";
);

// Bitwise Math

Expand Down Expand Up @@ -2045,6 +2056,11 @@ test!(basic_dict => r#"
print(test.len);
print(test.get('foo'));
test['bar'].print;
let test2 = Dict('foo', 3);
test2.store('bay', 4);
test.concat(test2).Array.map(fn (n: (string, i64)) -> string {
return 'key: '.concat(n.0).concat("\nval: ").concat(n.1.string);
}).join("\n").print;
}"#;
stdout r#"key: foo
val: 1
Expand All @@ -2057,6 +2073,14 @@ foo, bar, baz
3
1
2
key: foo
val: 3
key: bar
val: 2
key: baz
val: 99
key: bay
val: 4
"#;
);
test!(keyval_array_to_dict => r#"
Expand Down
1 change: 1 addition & 0 deletions src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ edition = "2021"
flume = "0.11.0"
futures = "0.3.30"
ordered_hash_map = "0.4.0"
uuid = { version = "1.10.0", features = ["v4", "fast-rng"] }
wgpu = "0.20.1""#;
let cargo_path = {
let mut c = project_dir.clone();
Expand Down
29 changes: 3 additions & 26 deletions src/lntors/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,7 @@ pub fn from_microstatement(
| FnKind::Static => {
let mut arg_strs = Vec::new();
for arg in &fun.args {
match typen::ctype_to_rtype(&arg.1, false) {
Err(e) => Err(e),
Ok(s) => {
arg_strs.push(
s.replace(
['<', '>', ',', '[', ']', ';', '-', '(', ')'],
"_",
)
.replace(' ', ""),
);
/* TODO: Handle generic types better, also type inference */
Ok(())
}
}?;
arg_strs.push(arg.1.to_callable_string());
}
// Come up with a function name that is unique so Rust doesn't choke on
// duplicate function names that are allowed in Alan
Expand Down Expand Up @@ -142,7 +129,7 @@ pub fn from_microstatement(
let (_, o) = typen::generate(&arg_type, out)?;
out = o;
arg_types.push(arg_type.clone());
match typen::ctype_to_rtype(&arg_type, false) {
match typen::ctype_to_rtype(&arg_type, true) {
Err(e) => Err(e),
Ok(s) => {
arg_type_strs.push(s);
Expand All @@ -159,17 +146,7 @@ pub fn from_microstatement(
out = o;
let mut arg_strs = Vec::new();
for arg in &function.args {
match typen::ctype_to_rtype(&arg.1, false) {
Err(e) => Err(e),
Ok(s) => {
arg_strs.push(
s.replace(['<', '>', ',', '[', ']', ';', '-', '(', ')'], "_")
.replace(' ', ""),
);
/* TODO: Handle generic types better, also type inference */
Ok(())
}
}?;
arg_strs.push(arg.1.to_callable_string());
}
// Come up with a function name that is unique so Rust doesn't choke on
// duplicate function names that are allowed in Alan
Expand Down
27 changes: 20 additions & 7 deletions src/lntors/typen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,14 @@ pub fn ctype_to_rtype(
CType::Tuple(ts) => {
let mut out = Vec::new();
for t in ts {
out.push(ctype_to_rtype(t, in_function_type)?);
match t {
CType::Field(_, t2) => {
if !matches!(&**t2, CType::Int(_) | CType::Float(_) | CType::Bool(_) | CType::TString(_)) {
out.push(ctype_to_rtype(t, in_function_type)?);
}
}
t => out.push(ctype_to_rtype(t, in_function_type)?),
}
}
Ok(format!("({})", out.join(", ")))
}
Expand Down Expand Up @@ -144,12 +151,18 @@ pub fn generate(
// output, while the `Structlike` type requires a new struct to be created and inserted
// into the source definition, potentially inserting inner types as needed
CType::Bound(_name, rtype) => Ok((rtype.clone(), out)),
CType::Type(name, t) => {
let res = generate(t, out)?;
out = res.1;
out.insert(name.clone(), ctype_to_rtype(typen, false)?);
Ok((name.clone(), out))
}
// TODO: The complexity of this function indicates more fundamental issues in the type
// generation. This needs a rethink and rewrite.
CType::Type(name, t) => match &**t {
CType::Either(_) => {
let res = generate(t, out)?;
out = res.1;
out.insert(name.clone(), ctype_to_rtype(typen, false)?);
Ok((name.clone(), out))
}
_ => Ok((ctype_to_rtype(t, true)?, out)),
},
CType::Tuple(_) => Ok((ctype_to_rtype(typen, true)?, out)),
CType::Void => {
out.insert("void".to_string(), "type void = ();".to_string());
Ok(("()".to_string(), out))
Expand Down
27 changes: 22 additions & 5 deletions src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -828,9 +828,14 @@ impl FullTypename {
.to_string()
}
}
named_and!(typeoperatorswithwhitespace: TypeOperatorsWithWhitespace =>
a: String as optwhitespace,
op: String as typeoperators,
b: String as optwhitespace,
);
named_or!(withtypeoperators: WithTypeOperators =>
TypeBaseList: Vec<TypeBase> as typebaselist,
Operators: String as and!(optwhitespace, typeoperators, optwhitespace),
Operators: TypeOperatorsWithWhitespace as typeoperatorswithwhitespace,
);
impl WithTypeOperators {
#[allow(clippy::inherent_to_string)]
Expand All @@ -842,7 +847,7 @@ impl WithTypeOperators {
.collect::<Vec<String>>()
.join("")
.to_string(),
WithTypeOperators::Operators(o) => o.clone(),
WithTypeOperators::Operators(o) => format!(" {} ", o.op),
}
}
}
Expand Down Expand Up @@ -1383,7 +1388,11 @@ test!(functiontypeline =>
b: "".to_string(),
closeparen: ")".to_string(),
})]),
super::WithTypeOperators::Operators(" -> ".to_string()),
super::WithTypeOperators::Operators(super::TypeOperatorsWithWhitespace {
a: " ".to_string(),
op: "->".to_string(),
b: " ".to_string(),
}),
super::WithTypeOperators::TypeBaseList(vec![super::TypeBase::Variable("string".to_string())]),
]};
);
Expand All @@ -1401,7 +1410,11 @@ test!(interfaceline =>
b: "".to_string(),
closeparen: ")".to_string(),
})]),
super::WithTypeOperators::Operators(" -> ".to_string()),
super::WithTypeOperators::Operators(super::TypeOperatorsWithWhitespace {
a: " ".to_string(),
op: "->".to_string(),
b: " ".to_string(),
}),
super::WithTypeOperators::TypeBaseList(vec![super::TypeBase::Variable("string".to_string())]),
]});
);
Expand All @@ -1417,7 +1430,11 @@ test!(interfacelist =>
b: "".to_string(),
closeparen: ")".to_string(),
})]),
super::WithTypeOperators::Operators(" -> ".to_string()),
super::WithTypeOperators::Operators(super::TypeOperatorsWithWhitespace {
a: " ".to_string(),
op: "->".to_string(),
b: " ".to_string(),
}),
super::WithTypeOperators::TypeBaseList(vec![super::TypeBase::Variable("string".to_string())]),
]})];
);
Expand Down
36 changes: 21 additions & 15 deletions src/program/ctype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,21 +498,21 @@ impl CType {
CType::Int(_) | CType::Float(_) => format!(
"_{}",
self.to_functional_string()
.replace([' ', ',', '{', '}'], "_")
.replace([' ', ',', '{', '}', '"', '\''], "_")
),
CType::Type(_, t) => match **t {
CType::Int(_) | CType::Float(_) => format!(
"_{}",
self.to_functional_string()
.replace([' ', ',', '{', '}'], "_")
.replace([' ', ',', '{', '}', '"', '\''], "_")
),
_ => self
.to_functional_string()
.replace([' ', ',', '{', '}'], "_"),
.replace([' ', ',', '{', '}', '"', '\''], "_"),
},
_ => self
.to_functional_string()
.replace([' ', ',', '{', '}'], "_"),
.replace([' ', ',', '{', '}', '"', '\''], "_"),
}
}
pub fn degroup(&self) -> CType {
Expand Down Expand Up @@ -2742,8 +2742,8 @@ pub fn withtypeoperatorslist_to_ctype(
let mut largest_operator_index: i64 = -1;
for (i, assignable_or_operator) in queue.iter().enumerate() {
if let parse::WithTypeOperators::Operators(o) = assignable_or_operator {
let operatorname = o.trim();
let operator = match scope.resolve_typeoperator(&operatorname.to_string()) {
let operatorname = &o.op;
let operator = match scope.resolve_typeoperator(operatorname) {
Some(o) => Ok(o),
None => Err(format!("Operator {} not found", operatorname)),
}?;
Expand All @@ -2761,10 +2761,10 @@ pub fn withtypeoperatorslist_to_ctype(
if largest_operator_index > -1 {
// We have at least one operator, and this is the one to dig into
let operatorname = match &queue[largest_operator_index as usize] {
parse::WithTypeOperators::Operators(o) => o.trim(),
parse::WithTypeOperators::Operators(o) => &o.op,
_ => unreachable!(),
};
let operator = match scope.resolve_typeoperator(&operatorname.to_string()) {
let operator = match scope.resolve_typeoperator(operatorname) {
Some(o) => Ok(o),
None => Err(format!("Operator {} not found", operatorname)),
}?;
Expand Down Expand Up @@ -2796,15 +2796,15 @@ pub fn withtypeoperatorslist_to_ctype(
parse::WithTypeOperators::TypeBaseList(typebaselist) => Ok(typebaselist),
parse::WithTypeOperators::Operators(o) => Err(format!(
"Operator {} is an infix operator but preceded by another operator {}",
operatorname, o
operatorname, o.op
)),
}?;
let second_arg = match match queue.get(largest_operator_index as usize + 1) {
Some(val) => Ok(val),
None => Err(format!("Operator {} is an infix operator but missing a right-hand side value", operatorname)),
}? {
parse::WithTypeOperators::TypeBaseList(typebaselist) => Ok(typebaselist),
parse::WithTypeOperators::Operators(o) => Err(format!("Operator{} is an infix operator but followed by a lower precedence operator {}", operatorname, o)),
parse::WithTypeOperators::Operators(o) => Err(format!("Operator{} is an infix operator but followed by a lower precedence operator {}", operatorname, o.op)),
}?;
// We're gonna rewrite the operator and base assignables into a function call, eg
// we take `a + b` and turn it into `add(a, b)`
Expand All @@ -2815,7 +2815,13 @@ pub fn withtypeoperatorslist_to_ctype(
a: "".to_string(),
typecalllist: vec![
parse::WithTypeOperators::TypeBaseList(first_arg.to_vec()),
parse::WithTypeOperators::Operators(",".to_string()),
parse::WithTypeOperators::Operators(
parse::TypeOperatorsWithWhitespace {
a: " ".to_string(),
op: ",".to_string(),
b: " ".to_string(),
},
),
parse::WithTypeOperators::TypeBaseList(second_arg.to_vec()),
],
b: "".to_string(),
Expand All @@ -2842,7 +2848,7 @@ pub fn withtypeoperatorslist_to_ctype(
parse::WithTypeOperators::TypeBaseList(typebaselist) => Ok(typebaselist),
parse::WithTypeOperators::Operators(o) => Err(format!(
"Operator {} is an prefix operator but followed by another operator {}",
operatorname, o
operatorname, o.op
)),
}?;
// We're gonna rewrite the operator and base assignables into a function call, eg
Expand Down Expand Up @@ -2875,7 +2881,7 @@ pub fn withtypeoperatorslist_to_ctype(
parse::WithTypeOperators::TypeBaseList(typebaselist) => Ok(typebaselist),
parse::WithTypeOperators::Operators(o) => Err(format!(
"Operator {} is a postfix operator but preceded by another operator {}",
operatorname, o
operatorname, o.op
)),
}?;
// We're gonna rewrite the operator and base assignables into a function call, eg
Expand Down Expand Up @@ -3187,8 +3193,8 @@ pub fn typebaselist_to_ctype(
}
let mut arg_block = Vec::new();
for arg in temp_args {
if let parse::WithTypeOperators::Operators(a) = &arg {
if a.trim() == "," {
if let parse::WithTypeOperators::Operators(o) = &arg {
if o.op == "," {
// Process the arg block that has
// accumulated
args.push(withtypeoperatorslist_to_ctype(&arg_block, scope)?);
Expand Down
Loading
Loading