Skip to content

Commit

Permalink
feat: Introduce CometTaskMemoryManager and native side memory pool (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sunchao authored Mar 6, 2024
1 parent a028132 commit e83635a
Show file tree
Hide file tree
Showing 15 changed files with 370 additions and 41 deletions.
27 changes: 27 additions & 0 deletions common/src/main/java/org/apache/comet/CometOutOfMemoryError.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet;

/**
 * OOM error specific for Comet memory management.
 *
 * <p>Extends {@link OutOfMemoryError}, so existing handlers for that type also see this error.
 */
public class CometOutOfMemoryError extends OutOfMemoryError {

  /**
   * Creates the error with a detail message.
   *
   * @param msg detail message, passed unchanged to the {@link OutOfMemoryError} constructor
   */
  public CometOutOfMemoryError(final String msg) {
    super(msg);
  }
}
6 changes: 5 additions & 1 deletion core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ pub enum CometError {
#[from]
source: std::num::ParseFloatError,
},

#[error(transparent)]
BoolFormat {
#[from]
source: std::str::ParseBoolError,
},
#[error(transparent)]
Format {
#[from]
Expand Down
44 changes: 30 additions & 14 deletions core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use jni::{
};
use std::{collections::HashMap, sync::Arc, task::Poll};

use super::{serde, utils::SparkArrowConvert};
use super::{serde, utils::SparkArrowConvert, CometMemoryPool};

use crate::{
errors::{try_unwrap_or_throw, CometError, CometResult},
Expand Down Expand Up @@ -103,6 +103,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
iterators: jobjectArray,
serialized_query: jbyteArray,
metrics_node: JObject,
comet_task_memory_manager_obj: JObject,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| {
// Init JVM classes
Expand Down Expand Up @@ -147,11 +148,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
input_sources.push(input_source);
}
let task_memory_manager =
Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);

// We need to keep the session context alive. Some session state like temporary
// dictionaries are stored in session context. If it is dropped, the temporary
// dictionaries will be dropped as well.
let session = prepare_datafusion_session_context(&configs)?;
let session = prepare_datafusion_session_context(&configs, task_memory_manager)?;

let exec_context = Box::new(ExecutionContext {
id,
Expand All @@ -175,6 +178,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
/// Parse Comet configs and configure DataFusion session context.
fn prepare_datafusion_session_context(
conf: &HashMap<String, String>,
comet_task_memory_manager: Arc<GlobalRef>,
) -> CometResult<SessionContext> {
// Get the batch size from Comet JVM side
let batch_size = conf
Expand All @@ -186,18 +190,30 @@ fn prepare_datafusion_session_context(

let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);

// Set up memory limit if specified
if conf.contains_key("memory_limit") {
let memory_limit = conf.get("memory_limit").unwrap().parse::<usize>()?;

let memory_fraction = conf
.get("memory_fraction")
.ok_or(CometError::Internal(
"Config 'memory_fraction' is not specified from Comet JVM side".to_string(),
))?
.parse::<f64>()?;

rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction);
// Check if we are using unified memory manager integrated with Spark. Default to false if not
// set.
let use_unified_memory_manager = conf
.get("use_unified_memory_manager")
.map(String::as_str)
.unwrap_or("false")
.parse::<bool>()?;

if use_unified_memory_manager {
// Set Comet memory pool for native
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
rt_config = rt_config.with_memory_pool(Arc::new(memory_pool));
} else {
// Use the memory pool from DF
if conf.contains_key("memory_limit") {
let memory_limit = conf.get("memory_limit").unwrap().parse::<usize>()?;
let memory_fraction = conf
.get("memory_fraction")
.ok_or(CometError::Internal(
"Config 'memory_fraction' is not specified from Comet JVM side".to_string(),
))?
.parse::<f64>()?;
rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction)
}
}

// Get Datafusion configuration from Spark Execution context
Expand Down
119 changes: 119 additions & 0 deletions core/src/execution/memory_pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{
fmt::{Debug, Formatter, Result as FmtResult},
sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
Arc,
},
};

use jni::objects::GlobalRef;

use datafusion::{
common::DataFusionError,
execution::memory_pool::{MemoryPool, MemoryReservation},
};

use crate::{
errors::CometResult,
jvm_bridge::{jni_call, JVMClasses},
};

/// A DataFusion `MemoryPool` implementation for Comet. Internally this is
/// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`].
///
/// Grow and shrink requests are forwarded over JNI to the JVM object behind
/// `task_memory_manager_handle`, while `used` keeps a native-side running total
/// so that `reserved()` can be answered without a JNI call.
pub struct CometMemoryPool {
/// Global reference to the JVM-side object on which the `acquire_memory` /
/// `release_memory` JNI calls are invoked.
task_memory_manager_handle: Arc<GlobalRef>,
/// Number of bytes currently accounted to this pool on the native side.
used: AtomicUsize,
}

impl Debug for CometMemoryPool {
    /// Formats the pool as `CometMemoryPool { used: <bytes> }`.
    ///
    /// Only the native-side byte counter is printed; the JVM handle is left out.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let used_bytes = self.used.load(Relaxed);
        let mut builder = f.debug_struct("CometMemoryPool");
        builder.field("used", &used_bytes);
        builder.finish()
    }
}

impl CometMemoryPool {
pub fn new(task_memory_manager_handle: Arc<GlobalRef>) -> CometMemoryPool {
Self {
task_memory_manager_handle,
used: AtomicUsize::new(0),
}
}

fn acquire(&self, additional: usize) -> CometResult<i64> {
let mut env = JVMClasses::get_env();
let handle = self.task_memory_manager_handle.as_obj();
unsafe {
jni_call!(&mut env,
comet_task_memory_manager(handle).acquire_memory(additional as i64) -> i64)
}
}

fn release(&self, size: usize) -> CometResult<()> {
let mut env = JVMClasses::get_env();
let handle = self.task_memory_manager_handle.as_obj();
unsafe {
jni_call!(&mut env, comet_task_memory_manager(handle).release_memory(size as i64) -> ())
}
}
}

// The pool's state is an `Arc<GlobalRef>` plus an `AtomicUsize`, and every update of
// `used` in this file goes through atomic operations.
// NOTE(review): both field types look `Send + Sync` on their own, in which case these
// manual impls would be redundant — confirm against the `jni` crate version in use
// before relying on or removing them.
unsafe impl Send for CometMemoryPool {}
unsafe impl Sync for CometMemoryPool {}

impl MemoryPool for CometMemoryPool {
    /// Infallibly grows the pool by `additional` bytes.
    ///
    /// The request is forwarded to the JVM-side task memory manager and then added
    /// to the native-side `used` counter.
    ///
    /// # Panics
    ///
    /// Panics if the JNI call fails. The underlying error is included in the panic
    /// message so the real cause (e.g. a JVM exception) is not lost.
    fn grow(&self, _: &MemoryReservation, additional: usize) {
        // NOTE(review): the byte count returned by `acquire` is ignored here, whereas
        // `try_grow` treats a short grant as a failure — confirm that a short grant
        // is acceptable on this infallible path.
        self.acquire(additional)
            .unwrap_or_else(|e| panic!("Failed to acquire {} bytes: {}", additional, e));
        self.used.fetch_add(additional, Relaxed);
    }

    /// Returns `size` bytes to the JVM-side task memory manager and subtracts them
    /// from the native-side `used` counter.
    ///
    /// # Panics
    ///
    /// Panics if the JNI call fails. The underlying error is included in the panic
    /// message.
    fn shrink(&self, _: &MemoryReservation, size: usize) {
        self.release(size)
            .unwrap_or_else(|e| panic!("Failed to release {} bytes: {}", size, e));
        self.used.fetch_sub(size, Relaxed);
    }

    /// Tries to grow the pool by `additional` bytes.
    ///
    /// A request for zero bytes succeeds without contacting the JVM.
    ///
    /// # Errors
    ///
    /// Returns an error if the JNI call fails, or a `DataFusionError::Execution` if
    /// the JVM granted fewer bytes than requested. In the latter case the partial
    /// grant is released first and `used` is left unchanged.
    fn try_grow(&self, _: &MemoryReservation, additional: usize) -> Result<(), DataFusionError> {
        if additional == 0 {
            return Ok(());
        }

        let acquired = self.acquire(additional)?;
        // If the number of bytes we acquired is less than the requested, return an error,
        // and hopefully will trigger spilling from the caller side.
        if acquired < additional as i64 {
            // Release the acquired bytes before returning the error.
            self.release(acquired as usize)?;

            return Err(DataFusionError::Execution(format!(
                "Failed to acquire {} bytes, only got {}. Reserved: {}",
                additional,
                acquired,
                self.reserved(),
            )));
        }

        self.used.fetch_add(additional, Relaxed);
        Ok(())
    }

    /// Returns the number of bytes currently accounted to this pool on the native side.
    fn reserved(&self) -> usize {
        self.used.load(Relaxed)
    }
}
3 changes: 3 additions & 0 deletions core/src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub(crate) mod sort;
mod timezone;
pub(crate) mod utils;

mod memory_pool;
pub use memory_pool::*;

// Include generated modules from .proto files.
#[allow(missing_docs)]
pub mod spark_expression {
Expand Down
62 changes: 62 additions & 0 deletions core/src/jvm_bridge/comet_task_memory_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
signature::{Primitive, ReturnType},
JNIEnv,
};

use crate::jvm_bridge::get_global_jclass;

/// A wrapper which delegates acquire/release memory calls to the
/// JVM side `CometTaskMemoryManager`.
///
/// Holds the class handle plus the method ids and return-type descriptors that
/// are looked up once in [`CometTaskMemoryManager::new`].
#[derive(Debug)]
pub struct CometTaskMemoryManager<'a> {
/// Class handle for [`Self::JVM_CLASS`], obtained via `get_global_jclass`.
pub class: JClass<'a>,
/// Method id of the Java method `acquireMemory` (JNI signature `(J)J`).
pub method_acquire_memory: JMethodID,
/// Method id of the Java method `releaseMemory` (JNI signature `(J)V`).
pub method_release_memory: JMethodID,

/// Return type descriptor paired with `method_acquire_memory` (primitive `long`).
pub method_acquire_memory_ret: ReturnType,
/// Return type descriptor paired with `method_release_memory` (primitive `void`).
pub method_release_memory_ret: ReturnType,
}

impl<'a> CometTaskMemoryManager<'a> {
    /// JNI-style (slash separated) name of the JVM class this wrapper binds to.
    pub const JVM_CLASS: &'static str = "org/apache/spark/CometTaskMemoryManager";

    /// Resolves the JVM class and the method ids of `acquireMemory` (`(J)J`) and
    /// `releaseMemory` (`(J)V`).
    ///
    /// Returns a JNI error if the class or either method cannot be resolved.
    pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometTaskMemoryManager<'a>> {
        let class = get_global_jclass(env, Self::JVM_CLASS)?;

        // Lookups happen in the same order as the struct fields: acquire, then release.
        let acquire_id =
            env.get_method_id(Self::JVM_CLASS, "acquireMemory", String::from("(J)J"))?;
        let release_id =
            env.get_method_id(Self::JVM_CLASS, "releaseMemory", String::from("(J)V"))?;

        Ok(CometTaskMemoryManager {
            class,
            method_acquire_memory: acquire_id,
            method_release_memory: release_id,
            method_acquire_memory_ret: ReturnType::Primitive(Primitive::Long),
            method_release_memory_ret: ReturnType::Primitive(Primitive::Void),
        })
    }
}
8 changes: 7 additions & 1 deletion core/src/jvm_bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ macro_rules! jni_call {
let ret = $env.call_method_unchecked($obj, method_id, ret_type, args);

// Check if JVM has thrown any exception, and handle it if so.
let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? {
let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() {
Err(exception.into())
} else {
$crate::jvm_bridge::jni_map_error!($env, ret)
Expand Down Expand Up @@ -194,10 +194,12 @@ mod comet_exec;
pub use comet_exec::*;
mod batch_iterator;
mod comet_metric_node;
mod comet_task_memory_manager;

use crate::{errors::CometError, JAVA_VM};
use batch_iterator::CometBatchIterator;
pub use comet_metric_node::*;
pub use comet_task_memory_manager::*;

/// The JVM classes that are used in the JNI calls.
pub struct JVMClasses<'a> {
Expand All @@ -216,6 +218,9 @@ pub struct JVMClasses<'a> {
pub comet_exec: CometExec<'a>,
/// The CometBatchIterator class. Used for iterating over the batches.
pub comet_batch_iterator: CometBatchIterator<'a>,
/// The CometTaskMemoryManager used for interacting with JVM side to
/// acquire & release native memory.
pub comet_task_memory_manager: CometTaskMemoryManager<'a>,
}

unsafe impl<'a> Send for JVMClasses<'a> {}
Expand Down Expand Up @@ -261,6 +266,7 @@ impl JVMClasses<'_> {
comet_metric_node: CometMetricNode::new(env).unwrap(),
comet_exec: CometExec::new(env).unwrap(),
comet_batch_iterator: CometBatchIterator::new(env).unwrap(),
comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(),
}
});
}
Expand Down
2 changes: 2 additions & 0 deletions dev/ensure-jars-have-correct-contents.sh
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ allowed_expr+="|^org/apache/spark/shuffle/comet/.*$"
allowed_expr+="|^org/apache/spark/sql/$"
allowed_expr+="|^org/apache/spark/CometPlugin.class$"
allowed_expr+="|^org/apache/spark/CometDriverPlugin.*$"
allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.class$"
allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.*$"

allowed_expr+=")"
declare -i bad_artifacts=0
Expand Down
Loading

0 comments on commit e83635a

Please sign in to comment.