Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][CELEBORN-1577][Phase2] QuotaManager should support interrupt shuffle. #2819

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ private void initializeLifecycleManager(String appId) {
synchronized (this) {
if (lifecycleManager == null) {
lifecycleManager = new LifecycleManager(appId, celebornConf);
lifecycleManager.registerCancelShuffleCallback(SparkUtils::cancelShuffle);
if (celebornConf.clientFetchThrowsFetchFailure()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,19 @@
import java.lang.reflect.Method;
import java.util.concurrent.atomic.LongAdder;

import scala.Option;
import scala.Some;
import scala.Tuple2;

import org.apache.spark.BarrierTaskContext;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
import org.apache.spark.TaskContext;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.storage.BlockManagerId;
Expand All @@ -39,6 +44,7 @@
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;

public class SparkUtils {
private static final Logger logger = LoggerFactory.getLogger(SparkUtils.class);
Expand Down Expand Up @@ -179,4 +185,22 @@ public static void addFailureListenerIfBarrierTask(
shuffleClient.reportBarrierTaskFailure(appShuffleId, appShuffleIdentifier);
});
}

// Reflective handle to DAGScheduler's private "shuffleIdToMapStage" field; used by
// cancelShuffle to map a shuffle id back to its ShuffleMapStage. Accessed via
// reflection because Spark does not expose this map publicly.
private static final DynFields.UnboundField shuffleIdToMapStage_FIELD =
DynFields.builder().hiddenImpl(DAGScheduler.class, "shuffleIdToMapStage").build();

/**
 * Cancels the Spark stage that corresponds to the given Celeborn shuffle id.
 *
 * <p>Looks up DAGScheduler's private {@code shuffleIdToMapStage} map (reflectively, via
 * {@code shuffleIdToMapStage_FIELD}) and, when a matching ShuffleMapStage exists, asks the
 * scheduler to cancel that stage with the supplied reason. If there is no active
 * SparkContext, or no stage is registered for the shuffle id, the call is a logged no-op.
 *
 * @param shuffleId the shuffle whose producing stage should be cancelled
 * @param reason human-readable cancellation reason forwarded to DAGScheduler
 */
public static void cancelShuffle(int shuffleId, String reason) {
  // Capture the active context once: calling getActive() twice could observe different
  // values if the SparkContext is stopped concurrently between the check and the get().
  Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
  if (activeContext.nonEmpty()) {
    DAGScheduler scheduler = activeContext.get().dagScheduler();
    @SuppressWarnings("unchecked")
    scala.collection.mutable.Map<Integer, ShuffleMapStage> shuffleIdToMapStage =
        (scala.collection.mutable.Map<Integer, ShuffleMapStage>)
            shuffleIdToMapStage_FIELD.bind(scheduler).get();
    Option<ShuffleMapStage> shuffleMapStage = shuffleIdToMapStage.get(shuffleId);
    if (shuffleMapStage.nonEmpty()) {
      scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
    } else {
      // Nothing to cancel; surface it for debugging instead of failing silently.
      logger.warn("Can not find ShuffleMapStage for shuffle {}, skip cancelShuffle.", shuffleId);
    }
  } else {
    logger.error("Can not get active SparkContext, skip cancelShuffle.");
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ private void initializeLifecycleManager() {
synchronized (this) {
if (lifecycleManager == null) {
lifecycleManager = new LifecycleManager(appUniqueId, celebornConf);
lifecycleManager.registerCancelShuffleCallback(SparkUtils::cancelShuffle);
if (celebornConf.clientFetchThrowsFetchFailure()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,20 @@

import java.util.concurrent.atomic.LongAdder;

import scala.Option;
import scala.Some;
import scala.Tuple2;

import org.apache.spark.BarrierTaskContext;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
import org.apache.spark.TaskContext;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
import org.apache.spark.shuffle.ShuffleReader;
Expand Down Expand Up @@ -266,6 +271,9 @@ public static <K, C> CelebornShuffleReader<K, C> createColumnarShuffleReader(
.orNoop()
.build();

// Reflective handle to DAGScheduler's private "shuffleIdToMapStage" field; used by
// cancelShuffle to map a shuffle id back to its ShuffleMapStage. Accessed via
// reflection because Spark does not expose this map publicly.
private static final DynFields.UnboundField shuffleIdToMapStage_FIELD =
DynFields.builder().hiddenImpl(DAGScheduler.class, "shuffleIdToMapStage").build();

public static void unregisterAllMapOutput(
MapOutputTrackerMaster mapOutputTracker, int shuffleId) {
if (!UnregisterAllMapAndMergeOutput_METHOD.isNoop()) {
Expand Down Expand Up @@ -296,4 +304,19 @@ public static void addFailureListenerIfBarrierTask(
shuffleClient.reportBarrierTaskFailure(appShuffleId, appShuffleIdentifier);
});
}

/**
 * Cancels the Spark stage that corresponds to the given Celeborn shuffle id.
 *
 * <p>Looks up DAGScheduler's private {@code shuffleIdToMapStage} map (reflectively, via
 * {@code shuffleIdToMapStage_FIELD}) and, when a matching ShuffleMapStage exists, asks the
 * scheduler to cancel that stage with the supplied reason. If there is no active
 * SparkContext, or no stage is registered for the shuffle id, the call is a logged no-op.
 *
 * @param shuffleId the shuffle whose producing stage should be cancelled
 * @param reason human-readable cancellation reason forwarded to DAGScheduler
 */
public static void cancelShuffle(int shuffleId, String reason) {
  // Capture the active context once: calling getActive() twice could observe different
  // values if the SparkContext is stopped concurrently between the check and the get().
  Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
  if (activeContext.nonEmpty()) {
    DAGScheduler scheduler = activeContext.get().dagScheduler();
    @SuppressWarnings("unchecked")
    scala.collection.mutable.Map<Integer, ShuffleMapStage> shuffleIdToMapStage =
        (scala.collection.mutable.Map<Integer, ShuffleMapStage>)
            shuffleIdToMapStage_FIELD.bind(scheduler).get();
    Option<ShuffleMapStage> shuffleMapStage = shuffleIdToMapStage.get(shuffleId);
    if (shuffleMapStage.nonEmpty()) {
      scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
    } else {
      // Nothing to cancel; surface it for debugging instead of failing silently.
      LOG.warn("Can not find ShuffleMapStage for shuffle {}, skip cancelShuffle.", shuffleId);
    }
  } else {
    LOG.error("Can not get active SparkContext, skip cancelShuffle.");
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.client.MasterClient
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.protocol.message.ControlMessages.{ApplicationLost, ApplicationLostResponse, HeartbeatFromApplication, HeartbeatFromApplicationResponse, ZERO_UUID}
import org.apache.celeborn.common.protocol.message.ControlMessages._
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.util.{ThreadUtils, Utils}

Expand All @@ -33,7 +33,8 @@ class ApplicationHeartbeater(
conf: CelebornConf,
masterClient: MasterClient,
shuffleMetrics: () => (Long, Long),
workerStatusTracker: WorkerStatusTracker) extends Logging {
workerStatusTracker: WorkerStatusTracker,
cancelAllActiveStages: String => Unit) extends Logging {

private var stopped = false

Expand Down Expand Up @@ -68,6 +69,7 @@ class ApplicationHeartbeater(
if (response.statusCode == StatusCode.SUCCESS) {
logDebug("Successfully send app heartbeat.")
workerStatusTracker.handleHeartbeatResponse(response)
checkQuotaExceeds(response.checkQuotaResponse)
}
} catch {
case it: InterruptedException =>
Expand Down Expand Up @@ -97,7 +99,8 @@ class ApplicationHeartbeater(
StatusCode.REQUEST_FAILED,
List.empty.asJava,
List.empty.asJava,
List.empty.asJava)
List.empty.asJava,
CheckQuotaResponse(isAvailable = true, ""))
}
}

Expand All @@ -114,6 +117,12 @@ class ApplicationHeartbeater(
}
}

/**
 * Reacts to the quota-check result carried by the application heartbeat response:
 * when the quota is reported as unavailable, asks the owner (via the injected
 * `cancelAllActiveStages` function) to abort all active stages with the given reason.
 */
private def checkQuotaExceeds(response: CheckQuotaResponse): Unit =
  if (!response.isAvailable) cancelAllActiveStages(response.reason)

def stop(): Unit = {
stopped.synchronized {
if (!stopped) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import java.nio.ByteBuffer
import java.security.SecureRandom
import java.util
import java.util.{function, List => JList}
import java.util.concurrent.{Callable, ConcurrentHashMap, LinkedBlockingQueue, ScheduledFuture, TimeUnit}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicInteger
import java.util.function.Consumer
import java.util.function.{BiConsumer, Consumer}

import scala.collection.JavaConverters._
import scala.collection.generic.CanBuildFrom
Expand Down Expand Up @@ -54,7 +54,6 @@ import org.apache.celeborn.common.rpc.{ClientSaslContextBuilder, RpcSecurityCont
import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext}
import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, Utils}
// Can be removed once Celeborn no longer supports Scala 2.11.
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.ThreadUtils.awaitResult
import org.apache.celeborn.common.util.Utils.UNKNOWN_APP_SHUFFLE_ID

Expand Down Expand Up @@ -209,7 +208,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
conf,
masterClient,
() => commitManager.commitMetrics(),
workerStatusTracker)
workerStatusTracker,
reason => cancelAllActiveStages(reason))
private val changePartitionManager = new ChangePartitionManager(conf, this)
private val releasePartitionManager = new ReleasePartitionManager(conf, this)

Expand Down Expand Up @@ -1760,6 +1760,11 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
appShuffleDeterminateMap.put(appShuffleId, determinate)
}

// Engine-provided hook (shuffleId, reason) used by cancelAllActiveStages to abort a
// shuffle; remains None until the shuffle manager registers it. @volatile because it
// is written by the engine thread and read from the heartbeat path.
@volatile private var cancelShuffleCallback: Option[BiConsumer[Integer, String]] = None

/** Registers the engine-side callback invoked to cancel a running shuffle by id. */
def registerCancelShuffleCallback(callback: BiConsumer[Integer, String]): Unit =
  cancelShuffleCallback = Some(callback)

// Initialize at the end of LifecycleManager construction.
initialize()

Expand All @@ -1778,4 +1783,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
rnd.nextBytes(secretBytes)
JavaUtils.bytesToString(ByteBuffer.wrap(secretBytes))
}

/**
 * Invokes the registered cancel-shuffle callback for every shuffle that has allocated
 * workers and has not yet reached stage end. No-op when no callback was registered
 * (i.e. the engine never called registerCancelShuffleCallback).
 *
 * @param reason human-readable reason propagated to each cancellation
 */
def cancelAllActiveStages(reason: String): Unit =
  cancelShuffleCallback.foreach { callback =>
    val allocatedShuffleIds = shuffleAllocatedWorkers.asScala.keys
    allocatedShuffleIds
      .filter(shuffleId => !commitManager.isStageEnd(shuffleId))
      .foreach(shuffleId => callback.accept(shuffleId, reason))
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
StatusCode.SUCCESS,
excludedWorkers,
unknownWorkers,
shuttingWorkers)
shuttingWorkers,
null)
}

private def mockWorkers(workerHosts: Array[String]): util.ArrayList[WorkerInfo] = {
Expand Down
1 change: 1 addition & 0 deletions common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ message PbHeartbeatFromApplicationResponse {
repeated PbWorkerInfo excludedWorkers = 2;
repeated PbWorkerInfo unknownWorkers = 3;
repeated PbWorkerInfo shuttingWorkers = 4;
PbCheckQuotaResponse checkQuotaResponse = 5;
}

message PbCheckQuota {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
// Interval (ms) between estimated-partition-size updates.
def estimatedPartitionSizeForEstimationUpdateInterval: Long =
get(ESTIMATED_PARTITION_SIZE_UPDATE_INTERVAL)
// Interval (ms) at which the master samples resource consumption.
def masterResourceConsumptionInterval: Long = get(MASTER_RESOURCE_CONSUMPTION_INTERVAL)
// Per-user disk-usage ceiling in bytes (Long.MaxValue disables the check).
def masterUserDiskUsageThreshold: Long = get(MASTER_USER_DISK_USAGE_THRESHOLD)
// Cluster-wide disk-usage ceiling in bytes (Long.MaxValue disables the check).
def masterClusterDiskUsageThreshold: Long = get(MASTER_CLUSTER_DISK_USAGE_THRESHOLD)
// Logical name of this Celeborn cluster.
def clusterName: String = get(CLUSTER_NAME)

// //////////////////////////////////////////////////////
Expand Down Expand Up @@ -1061,6 +1063,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
// Whether registerShuffle filters out workers currently marked as excluded.
def registerShuffleFilterExcludedWorkerEnabled: Boolean =
get(REGISTER_SHUFFLE_FILTER_EXCLUDED_WORKER_ENABLED)

// Whether quota enforcement may select and interrupt running shuffles.
def interruptShuffleEnabled: Boolean = get(QUOTA_INTERRUPT_SHUFFLE_ENABLED)

// //////////////////////////////////////////////////////
// Worker //
// //////////////////////////////////////////////////////
Expand Down Expand Up @@ -2841,6 +2845,26 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("30s")

// Per-user disk-usage threshold (bytes). When exceeded, the master interrupts apps
// until usage drops below it; the Long.MaxValue default disables the check.
val MASTER_USER_DISK_USAGE_THRESHOLD: ConfigEntry[Long] =
buildConf("celeborn.master.userResourceConsumption.user.threshold")
.categories("master")
.doc("When user resource consumption exceeds quota, Master will " +
"interrupt some apps until user resource consumption is less " +
"than this value. Default value is Long.MaxValue which means disable check.")
.version("0.6.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefault(Long.MaxValue)

// Cluster-wide disk-usage threshold (bytes). When exceeded, the master interrupts apps
// until usage drops below it; the Long.MaxValue default disables the check.
val MASTER_CLUSTER_DISK_USAGE_THRESHOLD: ConfigEntry[Long] =
buildConf("celeborn.master.userResourceConsumption.cluster.threshold")
.categories("master")
.doc("When cluster resource consumption exceeds quota, Master will " +
"interrupt some apps until cluster resource consumption is less " +
"than this value. Default value is Long.MaxValue which means disable check.")
.version("0.6.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefault(Long.MaxValue)

val CLUSTER_NAME: ConfigEntry[String] =
buildConf("celeborn.cluster.name")
.categories("master", "worker")
Expand Down Expand Up @@ -5185,7 +5209,7 @@ object CelebornConf extends Logging {
.dynamic
.doc("Quota dynamic configuration for written disk bytes.")
.version("0.5.0")
.longConf
.bytesConf(ByteUnit.BYTE)
.createWithDefault(Long.MaxValue)

val QUOTA_DISK_FILE_COUNT: ConfigEntry[Long] =
Expand All @@ -5203,7 +5227,7 @@ object CelebornConf extends Logging {
.dynamic
.doc("Quota dynamic configuration for written hdfs bytes.")
.version("0.5.0")
.longConf
.bytesConf(ByteUnit.BYTE)
.createWithDefault(Long.MaxValue)

val QUOTA_HDFS_FILE_COUNT: ConfigEntry[Long] =
Expand Down Expand Up @@ -5765,4 +5789,51 @@ object CelebornConf extends Logging {
.booleanConf
.createWithDefault(false)

// Cluster-level quota: total disk bytes written across the whole cluster
// (Long.MaxValue default means unlimited). Dynamic — may be updated at runtime.
val QUOTA_CLUSTER_DISK_BYTES_WRITTEN: ConfigEntry[Long] =
buildConf("celeborn.quota.cluster.diskBytesWritten")
.categories("quota")
.dynamic
.doc("Quota dynamic configuration for cluster written disk bytes.")
.version("0.6.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefault(Long.MaxValue)

// Cluster-level quota: number of disk files written (count, hence longConf rather
// than bytesConf). Long.MaxValue default means unlimited.
val QUOTA_CLUSTER_DISK_FILE_COUNT: ConfigEntry[Long] =
buildConf("celeborn.quota.cluster.diskFileCount")
.categories("quota")
.dynamic
.doc("Quota dynamic configuration for cluster written disk file count.")
.version("0.6.0")
.longConf
.createWithDefault(Long.MaxValue)

// Cluster-level quota: total HDFS bytes written across the whole cluster
// (Long.MaxValue default means unlimited). Dynamic — may be updated at runtime.
val QUOTA_CLUSTER_HDFS_BYTES_WRITTEN: ConfigEntry[Long] =
buildConf("celeborn.quota.cluster.hdfsBytesWritten")
.categories("quota")
.dynamic
.doc("Quota dynamic configuration for cluster written hdfs bytes.")
.version("0.6.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefault(Long.MaxValue)

// Cluster-level quota: number of HDFS files written (count). Long.MaxValue default
// means unlimited.
val QUOTA_CLUSTER_HDFS_FILE_COUNT: ConfigEntry[Long] =
buildConf("celeborn.quota.cluster.hdfsFileCount")
.categories("quota")
.dynamic
.doc("Quota dynamic configuration for cluster written hdfs file count.")
.version("0.6.0")
.longConf
.createWithDefault(Long.MaxValue)

// Master switch for quota-based shuffle interruption: when a tenant or the whole
// cluster exceeds its quota, some shuffles are selected and interrupted.
// Defined without the redundant `{ ... }` wrapper, matching every other ConfigEntry
// definition in this object.
val QUOTA_INTERRUPT_SHUFFLE_ENABLED: ConfigEntry[Boolean] =
  buildConf("celeborn.quota.interruptShuffle.enabled")
    .categories("quota")
    .dynamic
    .doc("If enabled, the resource consumption used by the tenant exceeds " +
      "celeborn.quota.tenant.xx, or the resource consumption of the entire cluster " +
      "exceeds celeborn.quota.cluster.xx, some shuffles will be selected and interrupted.")
    .version("0.6.0")
    .booleanConf
    .createWithDefault(false)
}
Loading