diff --git a/node-wasm/Cargo.toml b/node-wasm/Cargo.toml
index 169b97be0..a32c3be71 100644
--- a/node-wasm/Cargo.toml
+++ b/node-wasm/Cargo.toml
@@ -51,12 +51,16 @@ web-sys = { version = "0.3.69", features = [
   "Blob",
   "BlobPropertyBag",
   "BroadcastChannel",
+  "DedicatedWorkerGlobalScope",
   "MessageEvent",
   "MessagePort",
+  "Navigator",
   "SharedWorker",
   "SharedWorkerGlobalScope",
   "Url",
   "Worker",
+  "WorkerGlobalScope",
+  "WorkerNavigator",
   "WorkerOptions",
-  "WorkerType",
+  "WorkerType"
 ] }
diff --git a/node-wasm/js/worker.js b/node-wasm/js/worker.js
index d41dfd5ae..5fba1529f 100644
--- a/node-wasm/js/worker.js
+++ b/node-wasm/js/worker.js
@@ -7,18 +7,23 @@ export function worker_script_url() {
 }
 
 // if we are in a worker
-if (
-    typeof WorkerGlobalScope !== 'undefined'
-    && self instanceof WorkerGlobalScope
-) {
+if (typeof WorkerGlobalScope !== 'undefined' && self instanceof WorkerGlobalScope) {
     Error.stackTraceLimit = 99;
 
+    // for SharedWorker we queue incoming connections
+    // for dedicated Worker we queue incoming messages (coming from the single client)
     let queued = [];
-    onconnect = (event) => {
-        console.log("Queued connection", event);
-        queued.push(event.ports[0]);
+    if (typeof SharedWorkerGlobalScope !== 'undefined' && self instanceof SharedWorkerGlobalScope) {
+        onconnect = (event) => {
+            queued.push(event)
+        }
+    } else {
+        onmessage = (event) => {
+            queued.push(event);
+        }
     }
 
     await init();
+    console.log("starting worker, queued messages: ", queued.length);
     await run_worker(queued);
 }
diff --git a/node-wasm/src/node.rs b/node-wasm/src/node.rs
index 5f40aba25..ba500cb2b 100644
--- a/node-wasm/src/node.rs
+++ b/node-wasm/src/node.rs
@@ -6,22 +6,22 @@ use libp2p::identity::Keypair;
 use libp2p::multiaddr::Protocol;
 use serde::{Deserialize, Serialize};
 use serde_wasm_bindgen::to_value;
-use tracing::error;
+use tracing::info;
 use wasm_bindgen::prelude::*;
-use web_sys::{MessageEvent, SharedWorker};
+use web_sys::{SharedWorker, Worker, WorkerOptions, WorkerType};
 
 use lumina_node::blockstore::IndexedDbBlockstore;
 use lumina_node::network::{canonical_network_bootnodes, network_genesis, network_id};
 use lumina_node::node::NodeConfig;
 use lumina_node::store::IndexedDbStore;
 
-use crate::utils::{js_value_from_display, Network};
+use crate::utils::{is_chrome, js_value_from_display, JsValueToJsError, Network};
 use crate::worker::commands::{CheckableResponseExt, NodeCommand, SingleHeaderQuery};
-use crate::worker::{spawn_worker, WorkerClient, WorkerError};
+use crate::worker::{worker_script_url, WorkerClient, WorkerError};
 use crate::wrapper::libp2p::NetworkInfoSnapshot;
 use crate::Result;
 
-const LUMINA_SHARED_WORKER_NAME: &str = "lumina";
+const LUMINA_WORKER_NAME: &str = "lumina";
 
 /// Config for the lumina wasm node.
 #[wasm_bindgen(js_name = NodeConfig)]
@@ -37,11 +37,25 @@ pub struct WasmNodeConfig {
     pub bootnodes: Vec<String>,
 }
 
+/// `NodeDriver` represents a lumina node running in a dedicated Worker/SharedWorker.
+/// It's responsible for sending commands and receiving responses from the node.
 #[wasm_bindgen(js_name = NodeClient)]
 struct NodeDriver {
-    _worker: SharedWorker,
-    _onerror_callback: Closure<dyn Fn(MessageEvent)>,
-    channel: WorkerClient,
+    client: WorkerClient,
+}
+
+/// Type of worker to run lumina in. Allows overriding the automatically detected worker kind
+/// (which should usually be appropriate).
+#[wasm_bindgen]
+pub enum NodeWorkerKind {
+    /// Run in [`SharedWorker`]
+    ///
+    /// [`SharedWorker`]: https://developer.mozilla.org/en-US/docs/Web/API/SharedWorker
+    Shared,
+    /// Run in [`Worker`]
+    ///
+    /// [`Worker`]: https://developer.mozilla.org/en-US/docs/Web/API/Worker
+    Dedicated,
 }
 
 #[wasm_bindgen(js_class = NodeClient)]
 impl NodeDriver {
@@ -50,27 +64,40 @@ impl NodeDriver {
     /// Note that single Shared Worker can be accessed from multiple tabs, so Lumina may
     /// already have been started. Otherwise it needs to be started with [`NodeDriver::start`].
     #[wasm_bindgen(constructor)]
-    pub async fn new() -> Result {
-        let worker = spawn_worker(LUMINA_SHARED_WORKER_NAME)?;
-
-        let onerror_callback: Closure<dyn Fn(MessageEvent)> = Closure::new(|ev: MessageEvent| {
-            error!("received error from SharedWorker: {:?}", ev.to_string());
-        });
-        worker.set_onerror(Some(onerror_callback.as_ref().unchecked_ref()));
+    pub async fn new(worker_type: Option<NodeWorkerKind>) -> Result {
+        let url = worker_script_url();
+        let mut opts = WorkerOptions::new();
+        opts.type_(WorkerType::Module);
+        opts.name(LUMINA_WORKER_NAME);
+
+        let default_worker_type = if is_chrome() {
+            NodeWorkerKind::Dedicated
+        } else {
+            NodeWorkerKind::Shared
+        };
 
-        let channel = WorkerClient::new(worker.port());
+        let client = match worker_type.unwrap_or(default_worker_type) {
+            NodeWorkerKind::Shared => {
+                info!("Starting SharedWorker");
+                let worker = SharedWorker::new_with_worker_options(&url, &opts)
+                    .to_error("could not create SharedWorker")?;
+                WorkerClient::from(worker)
+            }
+            NodeWorkerKind::Dedicated => {
+                info!("Starting Worker");
+                let worker =
+                    Worker::new_with_options(&url, &opts).to_error("could not create Worker")?;
+                WorkerClient::from(worker)
+            }
+        };
 
-        Ok(Self {
-            _worker: worker,
-            _onerror_callback: onerror_callback,
-            channel,
-        })
+        Ok(Self { client })
     }
 
     /// Check whether Lumina is currently running
     pub async fn is_running(&self) -> Result {
         let command = NodeCommand::IsRunning;
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let running = response.into_is_running().check_variant()?;
 
         Ok(running)
@@ -79,7 +106,7 @@ impl NodeDriver {
     /// Start a node with the provided config, if it's not running
     pub async fn start(&self, config: WasmNodeConfig) -> Result<()> {
         let command = NodeCommand::StartNode(config);
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let started = response.into_node_started().check_variant()?;
 
         Ok(started?)
@@ -88,7 +115,7 @@ impl NodeDriver {
     /// Get node's local peer ID.
     pub async fn local_peer_id(&self) -> Result {
         let command = NodeCommand::GetLocalPeerId;
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let peer_id = response.into_local_peer_id().check_variant()?;
 
         Ok(peer_id)
@@ -97,7 +124,7 @@ impl NodeDriver {
     /// Get current [`PeerTracker`] info.
     pub async fn peer_tracker_info(&self) -> Result {
         let command = NodeCommand::GetPeerTrackerInfo;
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let peer_info = response.into_peer_tracker_info().check_variant()?;
 
         Ok(to_value(&peer_info)?)
@@ -106,7 +133,7 @@ impl NodeDriver {
     /// Wait until the node is connected to at least 1 peer.
     pub async fn wait_connected(&self) -> Result<()> {
         let command = NodeCommand::WaitConnected { trusted: false };
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let result = response.into_connected().check_variant()?;
 
         Ok(result?)
@@ -115,7 +142,7 @@ impl NodeDriver {
     /// Wait until the node is connected to at least 1 trusted peer.
     pub async fn wait_connected_trusted(&self) -> Result<()> {
         let command = NodeCommand::WaitConnected { trusted: true };
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let result = response.into_connected().check_variant()?;
 
         Ok(result?)
@@ -124,7 +151,7 @@ impl NodeDriver {
     /// Get current network info.
     pub async fn network_info(&self) -> Result {
         let command = NodeCommand::GetNetworkInfo;
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let network_info = response.into_network_info().check_variant()?;
 
         Ok(network_info?)
@@ -133,7 +160,7 @@ impl NodeDriver {
     /// Get all the multiaddresses on which the node listens.
     pub async fn listeners(&self) -> Result {
         let command = NodeCommand::GetListeners;
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let listeners = response.into_listeners().check_variant()?;
         let result = listeners?.iter().map(js_value_from_display).collect();
 
@@ -143,7 +170,7 @@ impl NodeDriver {
     /// Get all the peers that node is connected to.
     pub async fn connected_peers(&self) -> Result {
         let command = NodeCommand::GetConnectedPeers;
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let peers = response.into_connected_peers().check_variant()?;
         let result = peers?.iter().map(js_value_from_display).collect();
 
@@ -156,7 +183,7 @@ impl NodeDriver {
             peer_id: peer_id.parse()?,
             is_trusted,
         };
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let set_result = response.into_set_peer_trust().check_variant()?;
 
         Ok(set_result?)
@@ -165,7 +192,7 @@ impl NodeDriver {
     /// Request the head header from the network.
     pub async fn request_head_header(&self) -> Result {
         let command = NodeCommand::RequestHeader(SingleHeaderQuery::Head);
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let header = response.into_header().check_variant()?;
 
         Ok(header.into_result()?)
@@ -174,7 +201,7 @@ impl NodeDriver {
     /// Request a header for the block with a given hash from the network.
     pub async fn request_header_by_hash(&self, hash: &str) -> Result {
         let command = NodeCommand::RequestHeader(SingleHeaderQuery::ByHash(hash.parse()?));
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let header = response.into_header().check_variant()?;
 
         Ok(header.into_result()?)
@@ -183,7 +210,7 @@ impl NodeDriver {
     /// Request a header for the block with a given height from the network.
     pub async fn request_header_by_height(&self, height: u64) -> Result {
         let command = NodeCommand::RequestHeader(SingleHeaderQuery::ByHeight(height));
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let header = response.into_header().check_variant()?;
 
         Ok(header.into_result()?)
@@ -201,7 +228,7 @@ impl NodeDriver {
             from: from_header,
             amount,
         };
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let headers = response.into_headers().check_variant()?;
 
         Ok(headers.into_result()?)
@@ -210,7 +237,7 @@ impl NodeDriver {
     /// Get current header syncing info.
     pub async fn syncer_info(&self) -> Result {
         let command = NodeCommand::GetSyncerInfo;
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let syncer_info = response.into_syncer_info().check_variant()?;
 
         Ok(to_value(&syncer_info?)?)
@@ -219,7 +246,7 @@ impl NodeDriver {
     /// Get the latest header announced in the network.
     pub async fn get_network_head_header(&self) -> Result {
         let command = NodeCommand::LastSeenNetworkHead;
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let header = response.into_last_seen_network_head().check_variant()?;
 
         Ok(header)
@@ -228,7 +255,7 @@ impl NodeDriver {
     /// Get the latest locally synced header.
     pub async fn get_local_head_header(&self) -> Result {
         let command = NodeCommand::GetHeader(SingleHeaderQuery::Head);
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let header = response.into_header().check_variant()?;
 
         Ok(header.into_result()?)
@@ -237,7 +264,7 @@ impl NodeDriver {
     /// Get a synced header for the block with a given hash.
     pub async fn get_header_by_hash(&self, hash: &str) -> Result {
         let command = NodeCommand::GetHeader(SingleHeaderQuery::ByHash(hash.parse()?));
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let header = response.into_header().check_variant()?;
 
         Ok(header.into_result()?)
@@ -246,7 +273,7 @@ impl NodeDriver {
     /// Get a synced header for the block with a given height.
     pub async fn get_header_by_height(&self, height: u64) -> Result {
         let command = NodeCommand::GetHeader(SingleHeaderQuery::ByHeight(height));
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let header = response.into_header().check_variant()?;
 
         Ok(header.into_result()?)
@@ -270,7 +297,7 @@ impl NodeDriver {
             start_height,
             end_height,
         };
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let headers = response.into_headers().check_variant()?;
 
         Ok(headers.into_result()?)
@@ -279,7 +306,7 @@ impl NodeDriver {
     /// Get data sampling metadata of an already sampled height.
     pub async fn get_sampling_metadata(&self, height: u64) -> Result {
         let command = NodeCommand::GetSamplingMetadata { height };
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         let metadata = response.into_sampling_metadata().check_variant()?;
 
         Ok(to_value(&metadata?)?)
@@ -289,7 +316,7 @@ impl NodeDriver {
     /// be processed and new NodeClient needs to be created to restart a node.
     pub async fn close(&self) -> Result<()> {
         let command = NodeCommand::CloseWorker;
-        let response = self.channel.exec(command).await?;
+        let response = self.client.exec(command).await?;
         if response.is_worker_closed() {
             Ok(())
         } else {
diff --git a/node-wasm/src/utils.rs b/node-wasm/src/utils.rs
index f940ef832..4d2cfffd3 100644
--- a/node-wasm/src/utils.rs
+++ b/node-wasm/src/utils.rs
@@ -12,7 +12,10 @@ use tracing_subscriber::fmt::time::UtcTime;
 use tracing_subscriber::prelude::*;
 use tracing_web::{performance_layer, MakeConsoleWriter};
 use wasm_bindgen::prelude::*;
-use web_sys::{SharedWorker, SharedWorkerGlobalScope};
+use web_sys::{
+    window, DedicatedWorkerGlobalScope, SharedWorker, SharedWorkerGlobalScope, Worker,
+    WorkerGlobalScope,
+};
 
 use lumina_node::network;
 
@@ -104,6 +107,7 @@ pub(crate) trait WorkerSelf {
     type GlobalScope;
 
     fn worker_self() -> Self::GlobalScope;
+    fn is_worker_type() -> bool;
 }
 
 impl WorkerSelf for SharedWorker {
@@ -112,6 +116,22 @@ impl WorkerSelf for SharedWorker {
     fn worker_self() -> Self::GlobalScope {
         JsValue::from(js_sys::global()).into()
     }
+
+    fn is_worker_type() -> bool {
+        js_sys::global().has_type::<SharedWorkerGlobalScope>()
+    }
+}
+
+impl WorkerSelf for Worker {
+    type GlobalScope = DedicatedWorkerGlobalScope;
+
+    fn worker_self() -> Self::GlobalScope {
+        JsValue::from(js_sys::global()).into()
+    }
+
+    fn is_worker_type() -> bool {
+        js_sys::global().has_type::<DedicatedWorkerGlobalScope>()
+    }
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -161,3 +181,27 @@ where
         }
     }
 }
+
+const CHROME_USER_AGENT_DETECTION: &str = "Chrome/";
+
+// currently there's an issue with SharedWorkers on Chrome, where restarting lumina's worker
+// causes all network connections to fail. Until that's resolved, detect Chrome and apply
+// a workaround.
+pub(crate) fn is_chrome() -> bool {
+    let mut user_agent = None;
+    if let Some(window) = window() {
+        user_agent = Some(window.navigator().user_agent());
+    };
+    if let Some(worker_scope) = JsValue::from(js_sys::global()).dyn_ref::<WorkerGlobalScope>() {
+        user_agent = Some(worker_scope.navigator().user_agent());
+    }
+
+    if let Some(user_agent) = user_agent {
+        user_agent
+            .as_deref()
+            .unwrap_or("")
+            .contains(CHROME_USER_AGENT_DETECTION)
+    } else {
+        false
+    }
+}
diff --git a/node-wasm/src/worker.rs b/node-wasm/src/worker.rs
index b743280a6..07a2242d8 100644
--- a/node-wasm/src/worker.rs
+++ b/node-wasm/src/worker.rs
@@ -8,15 +8,17 @@ use thiserror::Error;
 use tokio::sync::mpsc;
 use tracing::{debug, error, info, warn};
 use wasm_bindgen::prelude::*;
-use web_sys::{MessagePort, SharedWorker, WorkerOptions, WorkerType};
+use web_sys::{MessageEvent, SharedWorker};
 
 use lumina_node::node::{Node, NodeError};
 use lumina_node::store::{IndexedDbStore, SamplingMetadata, Store, StoreError};
 use lumina_node::syncer::SyncingInfo;
 
 use crate::node::WasmNodeConfig;
-use crate::utils::{to_jsvalue_or_undefined, JsValueToJsError, WorkerSelf};
-use crate::worker::channel::{WorkerMessage, WorkerMessageServer};
+use crate::utils::{to_jsvalue_or_undefined, WorkerSelf};
+use crate::worker::channel::{
+    DedicatedWorkerMessageServer, MessageServer, SharedWorkerMessageServer, WorkerMessage,
+};
 use crate::worker::commands::{NodeCommand, SingleHeaderQuery, WorkerResponse};
 use crate::wrapper::libp2p::NetworkInfoSnapshot;
 
@@ -251,16 +253,17 @@ impl NodeWorker {
 }
 
 #[wasm_bindgen]
-pub async fn run_worker(queued_connections: Vec<MessagePort>) {
+pub async fn run_worker(queued_events: Vec<MessageEvent>) {
     info!("Entered run_worker");
     let (tx, mut rx) = mpsc::channel(WORKER_MESSAGE_SERVER_INCOMING_QUEUE_LENGTH);
 
-    let mut message_server = WorkerMessageServer::new(tx.clone());
-    for connection in queued_connections {
-        message_server.add(connection);
-    }
+    let mut message_server: Box<dyn MessageServer> = if SharedWorker::is_worker_type() {
+        Box::new(SharedWorkerMessageServer::new(tx.clone(), queued_events))
+    } else {
+        Box::new(DedicatedWorkerMessageServer::new(tx.clone(), queued_events).await)
+    };
 
-    info!("Entering SharedWorker message loop");
+    info!("Entering worker message loop");
     let mut worker = None;
     while let Some(message) = rx.recv().await {
         match message {
@@ -307,22 +310,8 @@ pub async fn run_worker(queued_connections: Vec<MessagePort>) {
     info!("Channel to WorkerMessageServer closed, exiting the SharedWorker");
 }
 
-/// Spawn a new SharedWorker.
-pub(crate) fn spawn_worker(name: &str) -> Result {
-    let url = worker_script_url();
-
-    let mut opts = WorkerOptions::new();
-    opts.type_(WorkerType::Module);
-    opts.name(name);
-
-    let worker = SharedWorker::new_with_worker_options(&url, &opts)
-        .to_error("could not create SharedWorker")?;
-
-    Ok(worker)
-}
-
 #[wasm_bindgen(module = "/js/worker.js")]
 extern "C" {
     // must be called in order to include this script in generated package
-    fn worker_script_url() -> String;
+    pub(crate) fn worker_script_url() -> String;
 }
diff --git a/node-wasm/src/worker/channel.rs b/node-wasm/src/worker/channel.rs
index 807596911..56896c532 100644
--- a/node-wasm/src/worker/channel.rs
+++ b/node-wasm/src/worker/channel.rs
@@ -6,7 +6,7 @@ use tokio::sync::{mpsc, Mutex};
 use tracing::{debug, error, info, warn};
 use wasm_bindgen::prelude::*;
 use wasm_bindgen_futures::spawn_local;
-use web_sys::{MessageEvent, MessagePort, SharedWorker};
+use web_sys::{DedicatedWorkerGlobalScope, MessageEvent, MessagePort, SharedWorker, Worker};
 
 use crate::utils::WorkerSelf;
 use crate::worker::commands::{NodeCommand, WorkerResponse};
@@ -15,24 +15,22 @@ use crate::worker::WorkerError;
 
 type WireMessage = Result<WorkerResponse, WorkerError>;
 type WorkerClientConnection = (MessagePort, Closure<dyn Fn(MessageEvent)>);
 
+/// Access to the sending channel is protected by a mutex to make sure we can only hold a single
+/// writable instance from JS. Thus we expect to have at most 1 message in-flight.
 const WORKER_CHANNEL_SIZE: usize = 1;
 
 /// `WorkerClient` is responsible for sending messages to and receiving responses from [`WorkerMessageServer`].
 /// It covers JS details like callbacks, having to synchronise requests and responses and exposes
 /// simple RPC-like function call interface.
 pub(crate) struct WorkerClient {
-    _onmessage: Closure<dyn Fn(MessageEvent)>,
-    message_port: MessagePort,
+    worker: AnyWorker,
     response_channel: Mutex<mpsc::Receiver<WireMessage>>,
+    _onmessage: Closure<dyn Fn(MessageEvent)>,
+    _onerror: Closure<dyn Fn(MessageEvent)>,
 }
 
-impl WorkerClient {
-    /// Create a new `WorkerClient` out of a [`MessagePort`] that should be connected
-    /// to a [`SharedWorker`] running lumina.
-    ///
-    /// [`SharedWorker`]: https://developer.mozilla.org/en-US/docs/Web/API/SharedWorker
-    /// [`MessagePort`]: https://developer.mozilla.org/en-US/docs/Web/API/MessagePort
-    pub fn new(message_port: MessagePort) -> Self {
+impl From<Worker> for WorkerClient {
+    fn from(worker: Worker) -> WorkerClient {
         let (response_tx, response_rx) = mpsc::channel(WORKER_CHANNEL_SIZE);
 
         let onmessage_callback = move |ev: MessageEvent| {
@@ -53,15 +51,68 @@ impl WorkerClient {
         };
 
         let onmessage = Closure::new(onmessage_callback);
+        worker.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
+
+        let onerror = Closure::new(|ev: MessageEvent| {
+            error!("received error from Worker: {:?}", ev.to_string());
+        });
+        worker.set_onerror(Some(onerror.as_ref().unchecked_ref()));
+
+        Self {
+            worker: AnyWorker::DedicatedWorker(worker),
+            response_channel: Mutex::new(response_rx),
+            _onmessage: onmessage,
+            _onerror: onerror,
+        }
+    }
+}
+
+impl From<SharedWorker> for WorkerClient {
+    fn from(worker: SharedWorker) -> WorkerClient {
+        let (response_tx, response_rx) = mpsc::channel(WORKER_CHANNEL_SIZE);
+
+        let onmessage_callback = move |ev: MessageEvent| {
+            let response_tx = response_tx.clone();
+            spawn_local(async move {
+                let data: WireMessage = match from_value(ev.data()) {
+                    Ok(jsvalue) => jsvalue,
+                    Err(e) => {
+                        error!("WorkerClient could not convert from JsValue: {e}");
+                        Err(WorkerError::CouldNotDeserialiseResponse(e.to_string()))
+                    }
+                };
+
+                if let Err(e) = response_tx.send(data).await {
+                    error!("message forwarding channel closed, should not happen: {e}");
+                }
+            })
+        };
+
+        let onmessage = Closure::new(onmessage_callback);
+        let message_port = worker.port();
         message_port.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
+
+        let onerror: Closure<dyn Fn(MessageEvent)> = Closure::new(|ev: MessageEvent| {
+            error!("received error from SharedWorker: {:?}", ev.to_string());
+        });
+        worker.set_onerror(Some(onerror.as_ref().unchecked_ref()));
+
         message_port.start();
 
         Self {
+            worker: AnyWorker::SharedWorker(worker),
             response_channel: Mutex::new(response_rx),
             _onmessage: onmessage,
-            message_port,
+            _onerror: onerror,
         }
     }
+}
 
+enum AnyWorker {
+    DedicatedWorker(Worker),
+    SharedWorker(SharedWorker),
+}
+
+impl WorkerClient {
     /// Send command to lumina and wait for a response.
     ///
     /// Response enum variant can be converted into appropriate type at runtime with a provided
@@ -82,13 +133,22 @@ impl WorkerClient {
     fn send(&self, command: NodeCommand) -> Result<(), WorkerError> {
         let command_value =
             to_value(&command).map_err(|e| WorkerError::CouldNotSerialiseCommand(e.to_string()))?;
-        self.message_port
-            .post_message(&command_value)
-            .map_err(|e| WorkerError::CouldNotSendCommand(format!("{:?}", e.dyn_ref::<js_sys::Error>())))
+        match &self.worker {
+            AnyWorker::DedicatedWorker(worker) => {
+                worker.post_message(&command_value).map_err(|e| {
+                    WorkerError::CouldNotSendCommand(format!("{:?}", e.dyn_ref::<js_sys::Error>()))
+                })
+            }
+            AnyWorker::SharedWorker(worker) => {
+                worker.port().post_message(&command_value).map_err(|e| {
+                    WorkerError::CouldNotSendCommand(format!("{:?}", e.dyn_ref::<js_sys::Error>()))
+                })
+            }
+        }
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone, Copy)]
 pub(super) struct ClientId(usize);
 
 impl fmt::Display for ClientId {
@@ -103,10 +163,23 @@ pub(super) enum WorkerMessage {
     Command((NodeCommand, ClientId)),
 }
 
-pub(super) struct WorkerMessageServer {
+pub(super) trait MessageServer {
+    fn send_response(&self, client: ClientId, message: WireMessage);
+    fn add(&mut self, port: MessagePort);
+
+    fn respond_to(&self, client: ClientId, msg: WorkerResponse) {
+        self.send_response(client, Ok(msg))
+    }
+
+    fn respond_err_to(&self, client: ClientId, error: WorkerError) {
+        self.send_response(client, Err(error))
+    }
+}
+
+pub(super) struct SharedWorkerMessageServer {
     // same onconnect callback is used throughtout entire Worker lifetime.
     // Keep a reference to make sure it doesn't get dropped.
-    _onconnect_callback: Closure<dyn Fn(MessageEvent)>,
+    _onconnect: Closure<dyn Fn(MessageEvent)>,
 
     // keep a MessagePort for each client to send messages over, as well as callback responsible
     // for forwarding messages back
@@ -116,70 +189,36 @@ pub(super) struct WorkerMessageServer {
     clients: Vec<WorkerClientConnection>,
 
     command_channel: mpsc::Sender<WorkerMessage>,
 }
 
-impl WorkerMessageServer {
-    pub fn new(command_channel: mpsc::Sender<WorkerMessage>) -> Self {
-        let closure_command_channel = command_channel.clone();
-        let onconnect: Closure<dyn Fn(MessageEvent)> = Closure::new(move |ev: MessageEvent| {
-            let command_channel = closure_command_channel.clone();
-            spawn_local(async move {
-                let Ok(port) = ev.ports().at(0).dyn_into() else {
-                    error!("received onconnect event without MessagePort, should not happen");
-                    return;
-                };
-
-                if let Err(e) = command_channel
-                    .send(WorkerMessage::NewConnection(port))
-                    .await
-                {
-                    error!("command channel inside worker closed, should not happen: {e}");
-                }
-            })
-        });
-
+impl SharedWorkerMessageServer {
+    pub fn new(command_channel: mpsc::Sender<WorkerMessage>, queued: Vec<MessageEvent>) -> Self {
         let worker_scope = SharedWorker::worker_self();
+        let onconnect = get_client_connect_callback(command_channel.clone());
         worker_scope.set_onconnect(Some(onconnect.as_ref().unchecked_ref()));
 
-        Self {
-            _onconnect_callback: onconnect,
+        let mut server = Self {
+            _onconnect: onconnect,
             clients: Vec::with_capacity(1), // we usually expect to have exactly one client
            command_channel,
+        };
+
+        for event in queued {
+            if let Ok(port) = event.ports().at(0).dyn_into() {
+                server.add(port);
+            } else {
+                error!("received onconnect event without MessagePort, should not happen");
+            }
         }
-    }
 
-    pub fn respond_to(&self, client: ClientId, msg: WorkerResponse) {
-        self.send_response(client, Ok(msg))
+        server
     }
+}
 
-    pub fn respond_err_to(&self, client: ClientId, error: WorkerError) {
-        self.send_response(client, Err(error))
-    }
+impl MessageServer for SharedWorkerMessageServer {
+    fn add(&mut self, port: MessagePort) {
+        let client_id = ClientId(self.clients.len());
-    pub fn add(&mut self, port: MessagePort) {
-        let client_id = self.clients.len();
-
-        let near_tx = self.command_channel.clone();
-        let client_message_callback: Closure<dyn Fn(MessageEvent)> =
-            Closure::new(move |ev: MessageEvent| {
-                let local_tx = near_tx.clone();
-                spawn_local(async move {
-                    let client_id = ClientId(client_id);
-
-                    let message = match from_value(ev.data()) {
-                        Ok(command) => {
-                            debug!("received command from client {client_id}: {command:#?}");
-                            WorkerMessage::Command((command, client_id))
-                        }
-                        Err(e) => {
-                            warn!("could not deserialize message from client {client_id}: {e}");
-                            WorkerMessage::InvalidCommandReceived(client_id)
-                        }
-                    };
-
-                    if let Err(e) = local_tx.send(message).await {
-                        error!("command channel inside worker closed, should not happen: {e}");
-                    }
-                })
-            });
+        let client_message_callback =
+            get_client_message_callback(self.command_channel.clone(), client_id);
 
         self.clients.push((port, client_message_callback));
 
@@ -209,3 +248,106 @@ impl WorkerMessageServer {
         }
     }
 }
+
+pub(super) struct DedicatedWorkerMessageServer {
+    // same onmessage callback is used throughout entire Worker lifetime.
+    // Keep a reference to make sure it doesn't get dropped.
+    _onmessage: Closure<dyn Fn(MessageEvent)>,
+    // global scope we use to send messages
+    worker: DedicatedWorkerGlobalScope,
+}
+
+impl DedicatedWorkerMessageServer {
+    pub async fn new(
+        command_channel: mpsc::Sender<WorkerMessage>,
+        queued: Vec<MessageEvent>,
+    ) -> Self {
+        for event in queued {
+            let message = parse_message_event_to_worker_message(event, ClientId(0));
+
+            if let Err(e) = command_channel.send(message).await {
+                error!("command channel inside worker closed, should not happen: {e}");
+            }
+        }
+
+        let worker = Worker::worker_self();
+        let onmessage = get_client_message_callback(command_channel, ClientId(0));
+        worker.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
+
+        Self {
+            _onmessage: onmessage,
+            worker,
+        }
+    }
+}
+
+impl MessageServer for DedicatedWorkerMessageServer {
+    fn add(&mut self, _port: MessagePort) {
+        warn!("DedicatedWorkerMessageServer::add called, should not happen");
+    }
+
+    fn send_response(&self, client: ClientId, message: WireMessage) {
+        let message = match to_value(&message) {
+            Ok(jsvalue) => jsvalue,
+            Err(e) => {
+                warn!("provided response could not be converted to JsValue: {e}");
+                to_value(&WorkerError::CouldNotSerialiseResponse(e.to_string()))
+                    .expect("something's wrong, couldn't serialise serialisation error")
+            }
+        };
+
+        if let Err(e) = self.worker.post_message(&message) {
+            error!("could not post response message to client {client}: {e:?}");
+        }
+    }
+}
+
+fn get_client_connect_callback(
+    command_channel: mpsc::Sender<WorkerMessage>,
+) -> Closure<dyn Fn(MessageEvent)> {
+    Closure::new(move |ev: MessageEvent| {
+        let command_channel = command_channel.clone();
+        spawn_local(async move {
+            let Ok(port) = ev.ports().at(0).dyn_into() else {
+                error!("received onconnect event without MessagePort, should not happen");
+                return;
+            };
+
+            if let Err(e) = command_channel
+                .send(WorkerMessage::NewConnection(port))
+                .await
+            {
+                error!("command channel inside worker closed, should not happen: {e}");
+            }
+        })
+    })
+}
+
+fn get_client_message_callback(
+    command_channel: mpsc::Sender<WorkerMessage>,
+    client: ClientId,
+) -> Closure<dyn Fn(MessageEvent)> {
+    Closure::new(move |ev: MessageEvent| {
+        let command_channel = command_channel.clone();
+        spawn_local(async move {
+            let message = parse_message_event_to_worker_message(ev, client);
+
+            if let Err(e) = command_channel.send(message).await {
+                error!("command channel inside worker closed, should not happen: {e}");
+            }
+        })
+    })
+}
+
+fn parse_message_event_to_worker_message(ev: MessageEvent, client: ClientId) -> WorkerMessage {
+    match from_value(ev.data()) {
+        Ok(command) => {
+            debug!("received command from client {client}: {command:#?}");
+            WorkerMessage::Command((command, client))
+        }
+        Err(e) => {
+            warn!("could not deserialize message from client {client}: {e}");
+            WorkerMessage::InvalidCommandReceived(client)
+        }
+    }
+}
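
The only API-visible change above is the `NodeClient` constructor: it now accepts an optional `NodeWorkerKind` and otherwise picks the worker kind itself (dedicated `Worker` on Chrome, `SharedWorker` elsewhere, via `is_chrome()`). Below is a minimal sketch of driving this from JS. The import path and the prebuilt `config` object are assumptions; only `NodeClient`, `NodeWorkerKind`, and the method names come from the diff.

```js
// Sketch only: assumes the generated package is importable as "lumina-node-wasm"
// and that `config` is a NodeConfig built elsewhere.
import init, { NodeClient, NodeWorkerKind } from "lumina-node-wasm";

async function startLumina(config) {
    await init();

    // No argument: let lumina auto-detect the worker kind.
    const client = await new NodeClient();
    // Or force one explicitly:
    // const client = await new NodeClient(NodeWorkerKind.Dedicated);

    if (!(await client.is_running())) {
        await client.start(config);
    }
    await client.wait_connected();
    return await client.request_head_header();
}
```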