Skip to content

Commit

Permalink
feat: Introduce CometTaskMemoryManager and native side memory pool
Browse files Browse the repository at this point in the history
  • Loading branch information
sunchao committed Feb 22, 2024
1 parent 637dba9 commit 5165385
Show file tree
Hide file tree
Showing 15 changed files with 270 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.apache.comet;

/** OOM error specific for Comet memory management */
public class CometOutOfMemoryError extends OutOfMemoryError {
public CometOutOfMemoryError(String msg) {
super(msg);
}
}
6 changes: 5 additions & 1 deletion core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ pub enum CometError {
#[from]
source: std::num::ParseFloatError,
},

#[error(transparent)]
BoolFormat {
#[from]
source: std::str::ParseBoolError,
},
#[error(transparent)]
Format {
#[from]
Expand Down
42 changes: 28 additions & 14 deletions core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use jni::{
};
use std::{collections::HashMap, sync::Arc, task::Poll};

use super::{serde, utils::SparkArrowConvert};
use super::{serde, utils::SparkArrowConvert, CometMemoryPool};

use crate::{
errors::{try_unwrap_or_throw, CometError, CometResult},
Expand Down Expand Up @@ -103,6 +103,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
iterators: jobjectArray,
serialized_query: jbyteArray,
metrics_node: JObject,
task_memory_manager_obj: JObject,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| {
// Init JVM classes
Expand Down Expand Up @@ -147,11 +148,12 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
input_sources.push(input_source);
}
let task_memory_manager = Arc::new(jni_new_global_ref!(env, task_memory_manager_obj)?);

// We need to keep the session context alive. Some session state like temporary
// dictionaries are stored in session context. If it is dropped, the temporary
// dictionaries will be dropped as well.
let session = prepare_datafusion_session_context(&configs)?;
let session = prepare_datafusion_session_context(&configs, task_memory_manager)?;

let exec_context = Box::new(ExecutionContext {
id,
Expand All @@ -175,6 +177,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
/// Parse Comet configs and configure DataFusion session context.
fn prepare_datafusion_session_context(
conf: &HashMap<String, String>,
task_memory_manager: Arc<GlobalRef>,
) -> CometResult<SessionContext> {
// Get the batch size from Comet JVM side
let batch_size = conf
Expand All @@ -186,18 +189,29 @@ fn prepare_datafusion_session_context(

let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);

// Set up memory limit if specified
if conf.contains_key("memory_limit") {
let memory_limit = conf.get("memory_limit").unwrap().parse::<usize>()?;

let memory_fraction = conf
.get("memory_fraction")
.ok_or(CometError::Internal(
"Config 'memory_fraction' is not specified from Comet JVM side".to_string(),
))?
.parse::<f64>()?;

rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction);
let use_unified_memory_manager = conf
.get("use_unified_memory_manager")
.ok_or(CometError::Internal(
"Config 'use_unified_memory_manager' is not specified from Comet JVM side".to_string(),
))?
.parse::<bool>()?;

if use_unified_memory_manager {
// Set Comet memory pool for native
let memory_pool = CometMemoryPool::new(task_memory_manager);
rt_config = rt_config.with_memory_pool(Arc::new(memory_pool));
} else {
// Use the memory pool from DF
if conf.contains_key("memory_limit") {
let memory_limit = conf.get("memory_limit").unwrap().parse::<usize>()?;
let memory_fraction = conf
.get("memory_fraction")
.ok_or(CometError::Internal(
"Config 'memory_fraction' is not specified from Comet JVM side".to_string(),
))?
.parse::<f64>()?;
rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction)
}
}

// Get Datafusion configuration from Spark Execution context
Expand Down
85 changes: 85 additions & 0 deletions core/src/execution/memory_pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use std::{
fmt::{Debug, Formatter, Result as FmtResult},
sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
Arc,
},
};

use jni::objects::GlobalRef;

use datafusion::{
common::DataFusionError,
execution::memory_pool::{MemoryPool, MemoryReservation},
};

use crate::jvm_bridge::{jni_call, JVMClasses};

pub struct CometMemoryPool {
task_memory_manager_handle: Arc<GlobalRef>,
used: AtomicUsize,
}

impl Debug for CometMemoryPool {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
f.debug_struct("CometMemoryPool")
.field("used", &self.used.load(Relaxed))
.finish()
}
}

impl CometMemoryPool {
pub fn new(task_memory_manager_handle: Arc<GlobalRef>) -> CometMemoryPool {
Self {
task_memory_manager_handle,
used: AtomicUsize::new(0),
}
}
}

unsafe impl Send for CometMemoryPool {}
unsafe impl Sync for CometMemoryPool {}

impl MemoryPool for CometMemoryPool {
fn grow(&self, _: &MemoryReservation, additional: usize) {
self.used.fetch_add(additional, Relaxed);
}

fn shrink(&self, _: &MemoryReservation, size: usize) {
let mut env = JVMClasses::get_env();
let handle = self.task_memory_manager_handle.as_obj();
unsafe {
jni_call!(&mut env, comet_task_memory_manager(handle).release_memory(size as i64) -> ())
.unwrap();
}
self.used.fetch_sub(size, Relaxed);
}

fn try_grow(&self, _: &MemoryReservation, additional: usize) -> Result<(), DataFusionError> {
if additional > 0 {
let mut env = JVMClasses::get_env();
let handle = self.task_memory_manager_handle.as_obj();
unsafe {
let acquired = jni_call!(&mut env,
comet_task_memory_manager(handle).acquire_memory(additional as i64) -> i64)?;

// If the number of bytes we acquired is less than the requested, return an error,
// and hopefully will trigger spilling from the caller side.
if acquired < additional as i64 {
return Err(DataFusionError::Execution(format!(
"Failed to acquire {} bytes, only got {}. Reserved: {}",
additional,
acquired,
self.reserved(),
)));
}
}
self.used.fetch_add(additional, Relaxed);
}
Ok(())
}

fn reserved(&self) -> usize {
self.used.load(Relaxed)
}
}
3 changes: 3 additions & 0 deletions core/src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub(crate) mod sort;
mod timezone;
pub(crate) mod utils;

mod memory_pool;
pub use memory_pool::*;

// Include generated modules from .proto files.
#[allow(missing_docs)]
pub mod spark_expression {
Expand Down
45 changes: 45 additions & 0 deletions core/src/jvm_bridge/comet_task_memory_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
signature::{Primitive, ReturnType},
JNIEnv,
};

use crate::jvm_bridge::get_global_jclass;

/// A DataFusion `MemoryPool` implementation for Comet, which delegate to the JVM
/// side `CometTaskMemoryManager`.
#[derive(Debug)]
pub struct CometTaskMemoryManager<'a> {
pub class: JClass<'a>,
pub method_acquire_memory: JMethodID,
pub method_release_memory: JMethodID,

pub method_acquire_memory_ret: ReturnType,
pub method_release_memory_ret: ReturnType,
}

impl<'a> CometTaskMemoryManager<'a> {
pub const JVM_CLASS: &'static str = "org/apache/spark/CometTaskMemoryManager";

pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometTaskMemoryManager<'a>> {
let class = get_global_jclass(env, Self::JVM_CLASS)?;

let result = CometTaskMemoryManager {
class,
method_acquire_memory: env.get_method_id(
Self::JVM_CLASS,
"acquireMemory",
"(J)J".to_string(),
)?,
method_release_memory: env.get_method_id(
Self::JVM_CLASS,
"releaseMemory",
"(J)V".to_string(),
)?,
method_acquire_memory_ret: ReturnType::Primitive(Primitive::Long),
method_release_memory_ret: ReturnType::Primitive(Primitive::Void),
};
Ok(result)
}
}
8 changes: 7 additions & 1 deletion core/src/jvm_bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ macro_rules! jni_call {
let ret = $env.call_method_unchecked($obj, method_id, ret_type, args);

// Check if JVM has thrown any exception, and handle it if so.
let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? {
let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() {
Err(exception.into())
} else {
$crate::jvm_bridge::jni_map_error!($env, ret)
Expand Down Expand Up @@ -194,10 +194,12 @@ mod comet_exec;
pub use comet_exec::*;
mod batch_iterator;
mod comet_metric_node;
mod comet_task_memory_manager;

use crate::{errors::CometError, JAVA_VM};
use batch_iterator::CometBatchIterator;
pub use comet_metric_node::*;
pub use comet_task_memory_manager::*;

/// The JVM classes that are used in the JNI calls.
pub struct JVMClasses<'a> {
Expand All @@ -216,6 +218,9 @@ pub struct JVMClasses<'a> {
pub comet_exec: CometExec<'a>,
/// The CometBatchIterator class. Used for iterating over the batches.
pub comet_batch_iterator: CometBatchIterator<'a>,
/// The CometTaskMemoryManager used for interacting with JVM side to
/// acquire & release native memory.
pub comet_task_memory_manager: CometTaskMemoryManager<'a>,
}

unsafe impl<'a> Send for JVMClasses<'a> {}
Expand Down Expand Up @@ -261,6 +266,7 @@ impl JVMClasses<'_> {
comet_metric_node: CometMetricNode::new(env).unwrap(),
comet_exec: CometExec::new(env).unwrap(),
comet_batch_iterator: CometBatchIterator::new(env).unwrap(),
comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(),
}
});
}
Expand Down
2 changes: 2 additions & 0 deletions dev/ensure-jars-have-correct-contents.sh
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ allowed_expr+="|^org/apache/spark/shuffle/comet/.*$"
allowed_expr+="|^org/apache/spark/sql/$"
allowed_expr+="|^org/apache/spark/CometPlugin.class$"
allowed_expr+="|^org/apache/spark/CometDriverPlugin.*$"
allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.class$"
allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.*$"

allowed_expr+=")"
declare -i bad_artifacts=0
Expand Down
49 changes: 49 additions & 0 deletions spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package org.apache.spark;

import java.io.IOException;

import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TaskMemoryManager;

/**
* A adapter class that is used by Comet native to acquire & release memory through Spark's unified
* memory manager. This assumes Spark's off-heap memory mode is enabled.
*/
public class CometTaskMemoryManager {
private final TaskMemoryManager internal;
private final NativeMemoryConsumer nativeMemoryConsumer;

public CometTaskMemoryManager() {
this.internal = TaskContext$.MODULE$.get().taskMemoryManager();
this.nativeMemoryConsumer = new NativeMemoryConsumer();
}

// Called by Comet native through JNI.
// Returns the actual amount of memory (in bytes) granted.
public long acquireMemory(long size) {
return internal.acquireExecutionMemory(size, nativeMemoryConsumer);
}

// Called by Comet native through JNI
public void releaseMemory(long size) {
internal.releaseExecutionMemory(size, nativeMemoryConsumer);
}

/**
* A dummy memory consumer that does nothing when spilling. At the moment, Comet native doesn't
* share the same API as Spark and cannot trigger spill when acquire memory. Therefore, when
* acquiring memory from native or JVM, spilling can only be triggered from JVM operators.
*/
private class NativeMemoryConsumer extends MemoryConsumer {
protected NativeMemoryConsumer() {
super(CometTaskMemoryManager.this.internal, 0, MemoryMode.OFF_HEAP);
}

@Override
public long spill(long size, MemoryConsumer trigger) throws IOException {
// No spilling
return 0;
}
}
}
14 changes: 13 additions & 1 deletion spark/src/main/scala/org/apache/comet/CometExecIterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ class CometExecIterator(
}.toArray
private val plan = {
val configs = createNativeConf
nativeLib.createPlan(id, configs, cometBatchIterators, protobufQueryPlan, nativeMetrics)
nativeLib.createPlan(
id,
configs,
cometBatchIterators,
protobufQueryPlan,
nativeMetrics,
new CometTaskMemoryManager)
}

private var nextBatch: Option[ColumnarBatch] = None
Expand Down Expand Up @@ -83,6 +89,12 @@ class CometExecIterator(
val conf = SparkEnv.get.conf

val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
// Only enable unified memory manager when off-heap mode is enabled. Otherwise,
// we'll use the built-in memory pool from DF, and initializes with `memory_limit`
// and `memory_fraction` below.
result.put(
"use_unified_memory_manager",
String.valueOf(conf.get("spark.memory.offHeap.enabled", "false")))
result.put("memory_limit", String.valueOf(maxMemory))
result.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
Expand Down
7 changes: 6 additions & 1 deletion spark/src/main/scala/org/apache/comet/Native.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package org.apache.comet

import java.util.Map

import org.apache.spark.CometTaskMemoryManager
import org.apache.spark.sql.comet.CometMetricNode

class Native extends NativeBase {
Expand All @@ -38,6 +39,9 @@ class Native extends NativeBase {
* the bytes of serialized SparkPlan.
* @param metrics
* the native metrics of SparkPlan.
* @param taskMemoryManager
* the task-level memory manager that is responsible for tracking memory usage across JVM and
* native side.
* @return
* the address to native query plan.
*/
Expand All @@ -46,7 +50,8 @@ class Native extends NativeBase {
configMap: Map[String, String],
iterators: Array[CometBatchIterator],
plan: Array[Byte],
metrics: CometMetricNode): Long
metrics: CometMetricNode,
taskMemoryManager: CometTaskMemoryManager): Long

/**
* Execute a native query plan based on given input Arrow arrays.
Expand Down
Loading

0 comments on commit 5165385

Please sign in to comment.