Skip to content

Commit

Permalink
wasm_js: remove TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
newpavlov committed Dec 4, 2024
1 parent 96ed30e commit fad53d3
Showing 1 changed file with 80 additions and 54 deletions.
134 changes: 80 additions & 54 deletions src/backends/wasm_js.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use crate::Error;
extern crate std;
use std::{mem::MaybeUninit, thread_local};

use core::sync::atomic::{AtomicU8, Ordering};

pub use crate::util::{inner_u32, inner_u64};

#[cfg(not(all(target_arch = "wasm32", target_os = "unknown",)))]
Expand All @@ -18,63 +20,89 @@ const WEB_CRYPTO_BUFFER_SIZE: u16 = 256;
// Node.js's crypto.randomFillSync requires the size to be less than 2**31.
const NODE_MAX_BUFFER_SIZE: usize = (1 << 31) - 1;

enum RngSource {
Node(NodeCrypto),
Web(WebCrypto, Uint8Array),
}
const KIND_UNINIT: u8 = 0;
const KIND_WEB: u8 = 1;
const KIND_NODE: u8 = 2;
const KIND_NA: u8 = 255;

// JsValues are always per-thread, so we initialize RngSource for each thread.
// See: https://github.com/rustwasm/wasm-bindgen/pull/955
thread_local!(
static RNG_SOURCE: Result<RngSource, Error> = getrandom_init();
);
static KIND: AtomicU8 = AtomicU8::new(KIND_UNINIT);

pub fn fill_inner(dest: &mut [MaybeUninit<u8>]) -> Result<(), Error> {
RNG_SOURCE.with(|result| {
let source = result.as_ref().map_err(|&e| e)?;

match source {
RngSource::Node(n) => {
for chunk in dest.chunks_mut(NODE_MAX_BUFFER_SIZE) {
// SAFETY: chunk is never used directly, the memory is only
// modified via the Uint8Array view, which is passed
// directly to JavaScript. Also, crypto.randomFillSync does
// not resize the buffer. We know the length is less than
// u32::MAX because of the chunking above.
// Note that this uses the fact that JavaScript doesn't
// have a notion of "uninitialized memory", this is purely
// a Rust/C/C++ concept.
let res = n.random_fill_sync(unsafe {
Uint8Array::view_mut_raw(chunk.as_mut_ptr().cast::<u8>(), chunk.len())
});
if res.is_err() {
return Err(Error::NODE_RANDOM_FILL_SYNC);
}
}
}
RngSource::Web(crypto, buf) => {
// getRandomValues does not work with all types of WASM memory,
// so we initially write to browser memory to avoid exceptions.
for chunk in dest.chunks_mut(WEB_CRYPTO_BUFFER_SIZE.into()) {
let chunk_len: u32 = chunk
.len()
.try_into()
.expect("chunk length is bounded by WEB_CRYPTO_BUFFER_SIZE");
// The chunk can be smaller than buf's length, so we call to
// JS to create a smaller view of buf without allocation.
let sub_buf = buf.subarray(0, chunk_len);

if crypto.get_random_values(&sub_buf).is_err() {
return Err(Error::WEB_GET_RANDOM_VALUES);
}

// SAFETY: `sub_buf`'s length is the same length as `chunk`
unsafe { sub_buf.raw_copy_to_ptr(chunk.as_mut_ptr().cast::<u8>()) };
}
loop {
break match KIND.load(Ordering::Relaxed) {
KIND_UNINIT => {
let global: Global = global().unchecked_into();
let crypto = global.crypto();
let val = if crypot.is_object() {
KIND_WEB
} else if is_node(&global) {
KIND_NODE
} else {
KIND_NA
};
KIND.store(val, Ordering::Relaxed);
continue;
}
KIND_WEB => web_fill(dest),
KIND_NODE => node_fill(dest),
_ => Err(Error::WEB_CRYPTO),
};
Ok(())
})
}
}

pub fn web_fill(dest: &mut [MaybeUninit<u8>]) -> Result<(), Error> {
let global: Global = global().unchecked_into();
let crypto = global.crypto();

// getRandomValues does not work with all types of WASM memory,
// so we initially write to browser memory to avoid exceptions.
let buf = Uint8Array::new_with_length(WEB_CRYPTO_BUFFER_SIZE.into());
for chunk in dest.chunks_mut(WEB_CRYPTO_BUFFER_SIZE.into()) {
let chunk_len: u32 = chunk
.len()
.try_into()
.expect("chunk length is bounded by WEB_CRYPTO_BUFFER_SIZE");
// The chunk can be smaller than buf's length, so we call to
// JS to create a smaller view of buf without allocation.
let sub_buf = buf.subarray(0, chunk_len);

if crypto.get_random_values(&sub_buf).is_err() {
return Err(Error::WEB_GET_RANDOM_VALUES);
}

// SAFETY: `sub_buf`'s length is the same length as `chunk`
unsafe { sub_buf.raw_copy_to_ptr(chunk.as_mut_ptr().cast::<u8>()) };
}
Ok(())
}

pub fn web_fill(dest: &mut [MaybeUninit<u8>]) -> Result<(), Error> {
// If module.require isn't a valid function, we are in an ES module.
let require_fn = Module::require_fn()
.and_then(JsCast::dyn_into::<Function>)
.map_err(|_| Error::NODE_ES_MODULE)?;
let n = require_fn
.call1(&global, &JsValue::from_str("crypto"))
.map_err(|_| Error::NODE_CRYPTO)?
.unchecked_into();

for chunk in dest.chunks_mut(NODE_MAX_BUFFER_SIZE) {
// SAFETY: chunk is never used directly, the memory is only
// modified via the Uint8Array view, which is passed
// directly to JavaScript. Also, crypto.randomFillSync does
// not resize the buffer. We know the length is less than
// u32::MAX because of the chunking above.
// Note that this uses the fact that JavaScript doesn't
// have a notion of "uninitialized memory", this is purely
// a Rust/C/C++ concept.
let res = n.random_fill_sync(unsafe {
Uint8Array::view_mut_raw(chunk.as_mut_ptr().cast::<u8>(), chunk.len())
});
if res.is_err() {
return Err(Error::NODE_RANDOM_FILL_SYNC);
}
}
Ok(())
}

fn getrandom_init() -> Result<RngSource, Error> {
Expand Down Expand Up @@ -124,8 +152,6 @@ extern "C" {
// Getters for the WebCrypto API
#[wasm_bindgen(method, getter)]
fn crypto(this: &Global) -> WebCrypto;
#[wasm_bindgen(method, getter, js_name = msCrypto)]
fn ms_crypto(this: &Global) -> WebCrypto;
// Crypto.getRandomValues()
#[wasm_bindgen(method, js_name = getRandomValues, catch)]
fn get_random_values(this: &WebCrypto, buf: &Uint8Array) -> Result<(), JsValue>;
Expand Down

0 comments on commit fad53d3

Please sign in to comment.