Skip to content

Commit

Permalink
refactor: improve usage of env checks (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
dinhani-cw authored Feb 16, 2024
1 parent 4224723 commit d77a08d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 21 deletions.
19 changes: 10 additions & 9 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use crate::eth::storage::test_accounts;
use crate::eth::storage::EthStorage;
use crate::eth::storage::InMemoryStorage;
use crate::eth::EthExecutor;
use crate::ext::not;
use crate::infra::postgres::Postgres;

/// Configuration for main Stratus service.
Expand Down Expand Up @@ -121,7 +120,11 @@ impl CommonConfig {
storage.enable_genesis(BlockMiner::genesis()).await?;
}
if self.enable_test_accounts {
storage.save_initial_accounts(test_accounts()).await?;
if self.env.is_production() {
tracing::warn!("cannot enable test accounts in production environment");
} else {
storage.save_initial_accounts(test_accounts()).await?;
}
}
Ok(storage)
}
Expand Down Expand Up @@ -193,10 +196,9 @@ impl FromStr for StorageConfig {
}

/// Enviroment where the application is running.
#[derive(clap::ValueEnum, Debug, Clone, PartialEq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Environment {
Development,
Staging,
Production,
}

Expand All @@ -206,9 +208,9 @@ impl Environment {
matches!(self, Self::Production)
}

/// Checks if the current environment is NOT production.
pub fn is_not_production(&self) -> bool {
not(self.is_production())
/// Checks if the current environment is development.
pub fn is_development(&self) -> bool {
matches!(self, Self::Development)
}
}

Expand All @@ -219,9 +221,8 @@ impl FromStr for Environment {
let s = s.trim().to_lowercase();
match s.as_str() {
"dev" | "development" => Ok(Self::Development),
"stag" | "staging" => Ok(Self::Staging),
"prod" | "production" => Ok(Self::Production),
&_ => todo!(),
s => Err(anyhow!("unknown environment: {}", s)),
}
}
}
6 changes: 0 additions & 6 deletions src/eth/rpc/rpc_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,3 @@ impl Debug for RpcContext {
.finish_non_exhaustive()
}
}

impl RpcContext {
pub fn is_production(&self) -> bool {
self.env == Environment::Production
}
}
13 changes: 7 additions & 6 deletions src/eth/rpc/rpc_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use jsonrpsee::IntoSubscriptionCloseResponse;
use jsonrpsee::PendingSubscriptionSink;
use serde_json::Value as JsonValue;

use crate::config::Environment;
use crate::config::StratusConfig;
use crate::eth::primitives::Address;
use crate::eth::primitives::BlockNumber;
Expand Down Expand Up @@ -67,8 +68,9 @@ pub async fn serve_rpc(executor: EthExecutor, eth_storage: Arc<dyn EthStorage>,
tracing::info!(%address, ?ctx, "starting rpc server");

// configure module
let env = ctx.env;
let mut module = RpcModule::<RpcContext>::new(ctx);
module = register_methods(module)?;
module = register_methods(module, env)?;

// configure middleware
let rpc_middleware = RpcServiceBuilder::new().layer_fn(RpcMiddleware::new);
Expand Down Expand Up @@ -98,9 +100,11 @@ pub async fn serve_rpc(executor: EthExecutor, eth_storage: Arc<dyn EthStorage>,
Ok(())
}

fn register_methods(mut module: RpcModule<RpcContext>) -> anyhow::Result<RpcModule<RpcContext>> {
fn register_methods(mut module: RpcModule<RpcContext>, env: Environment) -> anyhow::Result<RpcModule<RpcContext>> {
// debug
module.register_async_method("debug_setHead", debug_set_head)?;
if env.is_development() {
module.register_async_method("debug_setHead", debug_set_head)?;
}

// blockchain
module.register_async_method("net_version", net_version)?;
Expand Down Expand Up @@ -147,9 +151,6 @@ fn register_methods(mut module: RpcModule<RpcContext>) -> anyhow::Result<RpcModu

// Debug
async fn debug_set_head(params: Params<'_>, ctx: Arc<RpcContext>) -> anyhow::Result<JsonValue, RpcError> {
if ctx.is_production() {
return Err(RpcError::Response(rpc_internal_error("method is only available in development environment")));
}
let (_, number) = next_rpc_param::<BlockNumber>(params.sequence())?;
ctx.storage.reset(number).await?;
Ok(serde_json::to_value(number).unwrap())
Expand Down

0 comments on commit d77a08d

Please sign in to comment.