Skip to content

Commit

Permalink
Generalize GPU buffer create/read/update. Part 1
Browse files Browse the repository at this point in the history
  • Loading branch information
dfellis committed Nov 7, 2024
1 parent 40b091f commit baa9084
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 37 deletions.
29 changes: 21 additions & 8 deletions alan_std.js
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ export class Int {
export class I8 extends Int {
constructor(v) {
super(v, 8, 256, -128, 127);
this.ArrayKind = Int32Array; // GPUs don't support 8-bits (uniformly)
}

build(v) {
Expand All @@ -217,6 +218,7 @@ export class I8 extends Int {
export class U8 extends Int {
constructor(v) {
super(v, 8, 256, 0, 255);
this.ArrayKind = Uint32Array; // GPUs don't support 8-bits (uniformly)
}

build(v) {
Expand All @@ -227,6 +229,7 @@ export class U8 extends Int {
export class I16 extends Int {
constructor(v) {
super(v, 16, 65_536, -32_768, 32_767);
this.ArrayKind = Int32Array; // GPUs don't support 16-bits (uniformly)
}

build(v) {
Expand All @@ -237,6 +240,7 @@ export class I16 extends Int {
export class U16 extends Int {
constructor(v) {
super(v, 16, 65_536, 0, 65_535);
this.ArrayKind = Uint32Array; // GPUs don't support 16-bits (uniformly)
}

build(v) {
Expand All @@ -247,6 +251,7 @@ export class U16 extends Int {
export class I32 extends Int {
constructor(v) {
super(v, 32, 4_294_967_296, -2_147_483_648, 2_147_483_647);
this.ArrayKind = Int32Array;
}

build(v) {
Expand All @@ -257,6 +262,7 @@ export class I32 extends Int {
export class U32 extends Int {
constructor(v) {
super(v, 32, 4_294_967_296, 0, 4_294_967_295);
this.ArrayKind = Uint32Array;
}

build(v) {
Expand All @@ -267,6 +273,7 @@ export class U32 extends Int {
export class I64 extends Int {
constructor(v) {
super(v, 64, 18_446_744_073_709_551_616n, -9_223_372_036_854_775_808n, 9_223_372_036_854_775_807n);
this.ArrayKind = Int32Array; // GPUs don't support 64-bits
}

build(v) {
Expand All @@ -277,6 +284,7 @@ export class I64 extends Int {
export class U64 extends Int {
constructor(v) {
super(v, 64, 18_446_744_073_709_551_616n, 0n, 18_446_744_073_709_551_615n);
this.ArrayKind = Uint32Array; // GPUs don't support 64-bits
}

build(v) {
Expand All @@ -302,6 +310,7 @@ export class Float {
export class F32 extends Float {
constructor(v) {
super(Number(v), 32);
this.ArrayKind = Float32Array;
}

build(v) {
Expand All @@ -312,6 +321,7 @@ export class F32 extends Float {
export class F64 extends Float {
constructor(v) {
super(Number(v), 64);
this.ArrayKind = Float32Array; // GPUs don't support 64-bit vals
}

build(v) {
Expand All @@ -322,6 +332,7 @@ export class F64 extends Float {
export class Bool {
constructor(val) {
this.val = Boolean(val);
this.ArrayKind = Int8Array;
}

valueOf() {
Expand Down Expand Up @@ -400,26 +411,28 @@ export async function createBufferInit(usage, vals) {
let g = await gpu();
let b = await g.device.createBuffer({
mappedAtCreation: true,
size: vals.length * 4,
size: vals.length * (vals[0].bits ?? 32) / 8,
usage,
label: `buffer_${uuidv4().replaceAll('-', '_')}`,
});
let ab = b.getMappedRange();
let i32v = new Int32Array(ab);
let v = new (vals[0].ArrayKind ?? Int32Array)(ab);
for (let i = 0; i < vals.length; i++) {
i32v[i] = vals[i].valueOf();
v[i] = vals[i].valueOf();
}
b.unmap();
b.ValType = vals[0].constructor;
return b;
}

/**
 * Allocates an unmapped GPU buffer with room for `size` elements.
 *
 * @param usage   GPU buffer usage flags, passed straight to `createBuffer`.
 * @param size    Element count; `size.valueOf()` is used, so a plain number
 *                or a wrapped value both work.
 * @param ValKind Optional element class (the tests pass `alanStd.I32`). It is
 *                stored on the buffer as `b.ValKind`.
 * @returns the new GPUBuffer, labelled `buffer_<uuid>` with dashes replaced
 *          by underscores.
 */
export async function createEmptyBuffer(usage, size, ValKind) {
  let g = await gpu();
  // `ValKind` was added as a third parameter; optional chaining keeps the
  // older two-argument call form working (32-bit elements) instead of
  // throwing on `undefined.bits`.
  // NOTE(review): `bits` is read from ValKind itself here, while
  // createBufferInit reads it from a value instance (`vals[0].bits`). Unless
  // the classes also expose a static `bits`, this always falls back to 32 —
  // confirm.
  let b = await g.device.createBuffer({
    size: size.valueOf() * (ValKind?.bits ?? 32) / 8,
    usage,
    label: `buffer_${uuidv4().replaceAll('-', '_')}`,
  });
  b.ValKind = ValKind;
  return b;
}

Expand All @@ -436,7 +449,7 @@ export function storageBufferType() {
}

/**
 * Returns the number of elements in a GPU buffer as an `I64`.
 *
 * The byte size is divided by the per-element byte width taken from the
 * element class recorded on the buffer, defaulting to 32-bit elements.
 */
export function bufferlen(b) {
  // createEmptyBuffer records the element class as `ValKind`, while
  // createBufferInit records it as `ValType`; accept either, and tolerate
  // buffers that carry neither rather than throwing on `undefined.bits`.
  const kind = b.ValKind ?? b.ValType;
  return new I64(b.size / ((kind?.bits ?? 32) / 8));
}

export function bufferid(b) {
Expand Down Expand Up @@ -501,10 +514,10 @@ export async function readBuffer(b) {
g.queue.submit([encoder.finish()]);
await tempBuffer.mapAsync(GPUMapMode.READ);
let data = tempBuffer.getMappedRange(0, b.size);
let vals = new Int32Array(data);
let vals = new b.ArrayKind(data);
let out = [];
for (let i = 0; i < vals.length; i++) {
out[i] = new I32(vals[i]);
out[i] = new b.ValKind(vals[i]);
}
tempBuffer.unmap();
tempBuffer.destroy();
Expand Down
4 changes: 2 additions & 2 deletions alan_std.test_gpgpu.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import { chromium } from 'playwright';
}), "");

assert.strictEqual(await page.evaluate(async () => {
let b = await alanStd.createEmptyBuffer(alanStd.storageBufferType(), 4);
let b = await alanStd.createEmptyBuffer(alanStd.storageBufferType(), 4, alanStd.I32);
return alanStd.bufferlen(b).valueOf();
}), 4n);

assert((await page.evaluate(async () => {
let b = await alanStd.createEmptyBuffer(alanStd.storageBufferType(), 4);
let b = await alanStd.createEmptyBuffer(alanStd.storageBufferType(), 4, alanStd.I32);
return alanStd.bufferid(b).valueOf();
})).startsWith("buffer_"));

Expand Down
63 changes: 36 additions & 27 deletions alan_std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -784,65 +784,75 @@ fn gpu() -> &'static GPU {
}

#[derive(Clone)]
pub struct GBuffer(Rc<wgpu::Buffer>, String); // TODO: Temporary during transition
pub struct GBuffer {
buffer: Rc<wgpu::Buffer>,
id: String,
element_size: i8,
}

/// Two `GBuffer`s are equal when they carry the same generated id.
impl PartialEq for GBuffer {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}

impl Eq for GBuffer {}

/// Hashes the underlying buffer handle.
///
/// NOTE(review): equality compares `id` while hashing uses `buffer`; this is
/// consistent as long as equal ids always mean the same shared buffer —
/// confirm no code path builds a `GBuffer` that breaks that pairing.
impl Hash for GBuffer {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.buffer.hash(state);
    }
}

/// Lets a `GBuffer` be used wherever the inner `Rc<wgpu::Buffer>` is
/// expected, so buffer methods such as `size()` can be called on it directly.
impl Deref for GBuffer {
    type Target = Rc<wgpu::Buffer>;

    fn deref(&self) -> &Self::Target {
        &self.buffer
    }
}

impl DerefMut for GBuffer {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
&mut self.buffer
}
}

pub fn create_buffer_init(usage: &wgpu::BufferUsages, vals: &Vec<i32>) -> GBuffer {
pub fn create_buffer_init<T>(
usage: &wgpu::BufferUsages,
vals: &Vec<T>,

Check warning

Code scanning / clippy

writing &Vec instead of &[_] involves a new object where a slice will do Warning

writing &Vec instead of &[\_] involves a new object where a slice will do
element_size: &i8,
) -> GBuffer {
let g = gpu();
let val_slice = &vals[..];
let val_ptr = val_slice.as_ptr();
let val_u8_len = vals.len() * 4;
let val_u8_len = vals.len() * (*element_size as usize);
let val_u8: &[u8] = unsafe { std::slice::from_raw_parts(val_ptr as *const u8, val_u8_len) };
GBuffer(
Rc::new(wgpu::util::DeviceExt::create_buffer_init(
GBuffer {
buffer: Rc::new(wgpu::util::DeviceExt::create_buffer_init(
&g.device,
&wgpu::util::BufferInitDescriptor {
label: None, // TODO: Add a label for easier debugging?
contents: val_u8,
usage: *usage,
},
)),
format!("buffer_{}", format!("{}", Uuid::new_v4()).replace("-", "_")),
)
id: format!("buffer_{}", format!("{}", Uuid::new_v4()).replace("-", "_")),
element_size: *element_size,
}
}

pub fn create_empty_buffer(usage: &wgpu::BufferUsages, size: &i64) -> GBuffer {
pub fn create_empty_buffer(usage: &wgpu::BufferUsages, size: &i64, element_size: &i8) -> GBuffer {
let g = gpu();
GBuffer(
Rc::new(g.device.create_buffer(&wgpu::BufferDescriptor {
GBuffer {
buffer: Rc::new(g.device.create_buffer(&wgpu::BufferDescriptor {
label: None, // TODO: Add a label for easier debugging?
size: *size as u64,
size: (*size as u64) * (*element_size as u64),
usage: *usage,
mapped_at_creation: false, // TODO: With `create_buffer_init` does this make any sense?
})),
format!("buffer_{}", format!("{}", Uuid::new_v4()).replace("-", "_")),
)
id: format!("buffer_{}", format!("{}", Uuid::new_v4()).replace("-", "_")),
element_size: *element_size,
}
}

// TODO: Either add the ability to bind to const values, or come up with a better solution. For
Expand All @@ -864,12 +874,12 @@ pub fn storage_buffer_type() -> wgpu::BufferUsages {

#[inline(always)]
pub fn bufferlen(gb: &GBuffer) -> i64 {
(gb.size() / 4) as i64 // TODO: Support more than i32/u32/f32 values
(gb.size() as i64) / (gb.element_size as i64)
}

#[inline(always)]
pub fn buffer_id(b: &GBuffer) -> String {
b.1.clone()
b.id.clone()
}

pub struct GPGPU {
Expand Down Expand Up @@ -946,12 +956,12 @@ pub fn gpu_run(gg: &GPGPU) {
g.queue.submit(Some(encoder.finish()));
}

pub fn read_buffer(b: &GBuffer) -> Vec<i32> {
// TODO: Support other value types
pub fn read_buffer<T: std::clone::Clone>(b: &GBuffer) -> Vec<T> {
let g = gpu();
let temp_buffer = create_empty_buffer(
&mut map_read_buffer_type(),
&mut b.size().try_into().unwrap(),
&mut b.element_size.clone(),

Check warning — Code scanning / clippy (Warning): the function `create_empty_buffer` doesn't need a mutable reference
);
let mut encoder = g
.device
Expand All @@ -965,27 +975,26 @@ pub fn read_buffer(b: &GBuffer) -> Vec<i32> {
if let Ok(Ok(())) = receiver.recv() {
let data = temp_slice.get_mapped_range();
let data_ptr = data.as_ptr();
let data_len = data.len() / 4; // From u8 to i32
let data_i32: &[i32] =
unsafe { std::slice::from_raw_parts(data_ptr as *const i32, data_len) };
let data_len = data.len() / (b.element_size as usize);
let data_i32: &[T] = unsafe { std::slice::from_raw_parts(data_ptr as *const T, data_len) };
let result = data_i32.to_vec();
drop(data);
temp_buffer.unmap();
result
} else {
panic!("failed to run compute on gpu!")
panic!("Failed to run compute on gpu!")
}
}

#[allow(clippy::ptr_arg)]
pub fn replace_buffer(b: &GBuffer, v: &Vec<i32>) -> Result<(), AlanError> {
pub fn replace_buffer<T>(b: &GBuffer, v: &Vec<T>) -> Result<(), AlanError> {
if v.len() as i64 != bufferlen(b) {
Err("The input array is not the same size as the buffer".into())
} else {
// TODO: Support other value types
let val_slice = &v[..];
let val_ptr = val_slice.as_ptr();
let val_u8_len = v.len() * 4;
let val_u8_len = v.len() * (b.element_size as usize);
let val_u8: &[u8] = unsafe { std::slice::from_raw_parts(val_ptr as *const u8, val_u8_len) };
let g = gpu();
let temp_buffer = wgpu::util::DeviceExt::create_buffer_init(
Expand Down

0 comments on commit baa9084

Please sign in to comment.