[CELEBORN-1626] group mapTask #2771

Draft pull request merging 33 commits into base: main
Changes from all commits (33 commits)
4b75071 push data part (zaynt4606, Sep 25, 2024)
779a015 push data part (zaynt4606, Sep 29, 2024)
9fa8075 reformat & configuration (zaynt4606, Sep 29, 2024)
a27a6ae miscalculation (zaynt4606, Sep 29, 2024)
5a78a44 clientReader fetch park (zaynt4606, Sep 30, 2024)
21a83dc delete useless code (zaynt4606, Oct 8, 2024)
b6ce81d reformat (zaynt4606, Oct 8, 2024)
29ec3ad test false (zaynt4606, Oct 8, 2024)
10684f6 add ut and change conf to test (zaynt4606, Oct 22, 2024)
4eb0d58 fix ShuffleReader numPartitions bug (zaynt4606, Oct 23, 2024)
d184cbc change for spark2 (zaynt4606, Oct 23, 2024)
5cdd93f test groupSize100 (zaynt4606, Oct 23, 2024)
1bd8156 delete useless code (zaynt4606, Oct 30, 2024)
0e44f0a ditto (zaynt4606, Oct 30, 2024)
bc958a5 reformat (zaynt4606, Oct 30, 2024)
7525ef8 bug tod (zaynt4606, Oct 31, 2024)
36fedc2 todo (zaynt4606, Nov 11, 2024)
bfaba0f group workers (zaynt4606, Nov 11, 2024)
86338a2 inputStream null bug fix (zaynt4606, Nov 11, 2024)
ee6d3f8 bug fix window (zaynt4606, Nov 12, 2024)
16587ff change log (zaynt4606, Nov 12, 2024)
0c729f9 group workers (zaynt4606, Nov 14, 2024)
2306368 worker group update (zaynt4606, Nov 18, 2024)
c814b74 reformat (zaynt4606, Nov 18, 2024)
13ef988 part revive worker group (zaynt4606, Nov 19, 2024)
079cfbb fix leetcode compute (zaynt4606, Nov 19, 2024)
cf0857e part revive worker group (zaynt4606, Nov 19, 2024)
32acd7c finish group worker for revive (zaynt4606, Nov 19, 2024)
db894fe fix worker group revicve bug (zaynt4606, Nov 19, 2024)
f3ec99f delete log (zaynt4606, Nov 19, 2024)
84a4d4e ut fix (zaynt4606, Nov 20, 2024)
41ea018 update (zaynt4606, Nov 20, 2024)
0e5419d modify code (zaynt4606, Nov 21, 2024)
@@ -21,6 +21,8 @@ import java.io.IOException
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.{FetchFailedException, ShuffleReader}
@@ -88,8 +90,29 @@ class CelebornShuffleReader[K, C](
}
}

var partitionIdList = new ArrayBuffer[Int]()
if (!conf.groupMapTaskEnabled) {
partitionIdList = ArrayBuffer[Int]() ++ (startPartition until endPartition)
} else {
val numPartitions = dep.partitioner.numPartitions
val numMappers = handle.numMaps
val partitionGroupCnt =
if (conf.groupMapTaskEnabled)
math.ceil(numMappers.toDouble / conf.groupMapTaskGroupSize).toInt
else 1
val groupNumPartitions = numPartitions * partitionGroupCnt
(startPartition until endPartition).foreach { originalPartitionId =>
(0 until partitionGroupCnt).foreach { groupCnt =>
val tmpPartitionId = {
originalPartitionId + groupCnt * (groupNumPartitions / partitionGroupCnt)
}
partitionIdList += tmpPartitionId
}
}
}

val streams = JavaUtils.newConcurrentHashMap[Integer, CelebornInputStream]()
(startPartition until endPartition).map(partitionId => {
partitionIdList.foreach(partitionId => {
streamCreatorPool.submit(new Runnable {
override def run(): Unit = {
if (exceptionRef.get() == null) {
@@ -115,7 +138,7 @@
})
})

val recordIter = (startPartition until endPartition).iterator.map(partitionId => {
val recordIter = partitionIdList.iterator.map(partitionId => {
if (handle.numMaps > 0) {
val startFetchWait = System.nanoTime()
var inputStream: CelebornInputStream = streams.get(partitionId)
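For context on the CelebornShuffleReader hunk above, here is a minimal sketch of the partition-ID expansion it performs; the names below (groupSize, numMappers, numPartitions) are illustrative stand-ins for conf.groupMapTaskGroupSize, handle.numMaps and dep.partitioner.numPartitions, not the exact API.

// Illustrative sketch, not the PR's exact code: with map tasks split into
// ceil(numMappers / groupSize) groups, reduce partition p has to read the
// grouped partitions p + g * numPartitions for every group index g.
def groupedPartitionIds(
    startPartition: Int,
    endPartition: Int,
    numPartitions: Int,
    numMappers: Int,
    groupSize: Int): Seq[Int] = {
  val partitionGroupCnt = math.ceil(numMappers.toDouble / groupSize).toInt
  for {
    originalPartitionId <- startPartition until endPartition
    group <- 0 until partitionGroupCnt
  } yield originalPartitionId + group * numPartitions
}

// Example: 4 reduce partitions, 250 map tasks, groupSize 100 gives 3 map-task
// groups, so reduce partition 1 reads grouped partitions 1, 5 and 9:
// groupedPartitionIds(1, 2, 4, 250, 100) == Seq(1, 5, 9)

Note that groupNumPartitions / partitionGroupCnt in the hunk reduces to numPartitions, which is what the sketch uses directly.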
@@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
@@ -107,7 +108,7 @@ class CelebornShuffleReader[K, C](
val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
var fileGroups: ReduceFileGroups = null
try {
// startPartition is irrelevant
// startPartition is irrelevant; it is only used for error log printing
fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
} catch {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
@@ -121,8 +122,25 @@
(TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]()

var partCnt = 0
var groupPartitionIdList = new ArrayBuffer[Int]()
if (!conf.groupMapTaskEnabled) {
groupPartitionIdList = ArrayBuffer[Int]() ++ (startPartition until endPartition)
} else {
val numPartitions = dep.partitioner.numPartitions
val numMappers = handle.numMappers
val partitionGroupCnt = math.ceil(numMappers.toDouble / conf.groupMapTaskGroupSize).toInt
val groupNumPartitions = numPartitions * partitionGroupCnt
(startPartition until endPartition).foreach { originalPartitionId =>
(0 until partitionGroupCnt).foreach { groupCnt =>
val tmpPartitionId = {
originalPartitionId + groupCnt * (groupNumPartitions / partitionGroupCnt)
}
groupPartitionIdList += tmpPartitionId
}
}
}

(startPartition until endPartition).foreach { partitionId =>
groupPartitionIdList.foreach { partitionId =>
if (fileGroups.partitionGroups.containsKey(partitionId)) {
fileGroups.partitionGroups.get(partitionId).asScala.foreach { location =>
partCnt += 1
@@ -236,17 +254,17 @@
}

val inputStreamCreationWindow = conf.clientInputStreamCreationWindow
(startPartition until Math.min(
startPartition + inputStreamCreationWindow,
endPartition)).foreach(partitionId => {

(0 until Math.min(inputStreamCreationWindow, groupPartitionIdList.size)).foreach(listIndex => {
streamCreatorPool.submit(new Runnable {
override def run(): Unit = {
createInputStream(partitionId)
createInputStream(groupPartitionIdList(listIndex))
}
})
})

val recordIter = (startPartition until endPartition).iterator.map(partitionId => {
var curIndex = 0
val recordIter = groupPartitionIdList.iterator.map(partitionId => {
if (handle.numMappers > 0) {
val startFetchWait = System.nanoTime()
var inputStream: CelebornInputStream = streams.get(partitionId)
@@ -258,7 +276,7 @@
case e => throw e
}
}
logInfo("inputStream is null, sleeping...")
logInfo(s"inputStream for partition ${partitionId} is null, sleeping...")
Thread.sleep(50)
inputStream = streams.get(partitionId)
}
@@ -268,16 +286,19 @@
context.addTaskCompletionListener[Unit](_ => inputStream.close())

// Advance the input creation window
if (partitionId + inputStreamCreationWindow < endPartition) {
if (curIndex + inputStreamCreationWindow < groupPartitionIdList.size) {
val nextPartitionId = groupPartitionIdList(curIndex + inputStreamCreationWindow)
streamCreatorPool.submit(new Runnable {
override def run(): Unit = {
createInputStream(partitionId + inputStreamCreationWindow)
createInputStream(nextPartitionId)
}
})
}

curIndex = curIndex + 1
(partitionId, inputStream)
} else {
curIndex = curIndex + 1
(partitionId, CelebornInputStream.empty())
}
}).filter {
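Because the grouped partition IDs are no longer contiguous, the input stream creation window above is advanced by position in groupPartitionIdList rather than by partition ID. A minimal sketch of that indexing, with illustrative names only:

// Illustrative sketch: prefetch the stream for the entry `window` positions
// ahead of the current index in the (possibly grouped) partition ID list.
def nextPartitionToPrefetch(
    partitionIdList: IndexedSeq[Int],
    curIndex: Int,
    window: Int): Option[Int] =
  if (curIndex + window < partitionIdList.size) Some(partitionIdList(curIndex + window))
  else None

// Example with partitionIdList = Vector(1, 5, 9) and window = 2:
// nextPartitionToPrefetch(Vector(1, 5, 9), 0, 2) == Some(9)
// nextPartitionToPrefetch(Vector(1, 5, 9), 1, 2) == None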
@@ -965,11 +965,17 @@ public int pushOrMergeData(
return 0;
}

final PartitionLocation loc = map.get(partitionId);
int innerPartitionId =
conf.groupMapTaskEnabled()
? partitionId + (mapId / conf.groupMapTaskGroupSize()) * numPartitions
: partitionId;
PartitionLocation loc = map.get(innerPartitionId);

if (loc == null) {
throw new CelebornIOException(
String.format(
"Partition location for shuffle %s partition %d is NULL!", shuffleId, partitionId));
"Partition location for shuffle %s partition %d groupPartition %d is NULL!",
shuffleId, partitionId, innerPartitionId));
}

PushState pushState = getPushState(mapKey);
@@ -1017,21 +1023,22 @@ public void onSuccess(ByteBuffer response) {
.add(mapId);
}
logger.debug(
"Push data to {} success for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} success for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
loc.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
innerPartitionId,
nextBatchId);
}

@Override
public void onFailure(Throwable e) {
String errorMsg =
String.format(
"Push data to %s failed for shuffle %d map %d attempt %d partition %d batch %d.",
loc, shuffleId, mapId, attemptId, partitionId, nextBatchId);
"Push data to %s failed for shuffle %d map %d attempt %d partition %d groupPartition %d batch %d.",
loc, shuffleId, mapId, attemptId, partitionId, innerPartitionId, nextBatchId);
pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e));
}
};
@@ -1054,21 +1061,25 @@ public void onSuccess(ByteBuffer response) {
byte reason = response.get();
if (reason == StatusCode.SOFT_SPLIT.getValue()) {
logger.debug(
"Push data to {} soft split required for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} soft split required for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
innerPartitionId,
nextBatchId);
if (!newerPartitionLocationExists(
reducePartitionMap.get(shuffleId), partitionId, latest.getEpoch(), false)) {
reducePartitionMap.get(shuffleId),
innerPartitionId,
latest.getEpoch(),
false)) {
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId,
mapId,
attemptId,
partitionId,
innerPartitionId,
latest.getEpoch(),
latest,
StatusCode.SOFT_SPLIT);
@@ -1079,19 +1090,20 @@ public void onSuccess(ByteBuffer response) {
callback.onSuccess(response);
} else if (reason == StatusCode.HARD_SPLIT.getValue()) {
logger.debug(
"Push data to {} hard split required for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} hard split required for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
innerPartitionId,
nextBatchId);
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId,
mapId,
attemptId,
partitionId,
innerPartitionId,
latest.getEpoch(),
latest,
StatusCode.HARD_SPLIT);
@@ -1114,24 +1126,26 @@ public void onSuccess(ByteBuffer response) {
dueTime));
} else if (reason == StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) {
logger.debug(
"Push data to {} primary congestion required for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} primary congestion required for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
innerPartitionId,
nextBatchId);
pushState.onCongestControl(latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
callback.onSuccess(response);
} else if (reason == StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue()) {
logger.debug(
"Push data to {} replica congestion required for shuffle {} map {} attempt {} partition {} batch {}.",
"Push data to {} replica congestion required for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
innerPartitionId,
nextBatchId);
pushState.onCongestControl(latest.hostAndPushPort());
pushState.removeBatch(nextBatchId, latest.hostAndPushPort());
@@ -1166,12 +1180,13 @@ public void onFailure(Throwable e) {
}

logger.error(
"Push data to {} failed for shuffle {} map {} attempt {} partition {} batch {}, remain revive times {}.",
"Push data to {} failed for shuffle {} map {} attempt {} partition {} groupPartition {} batch {}, remain revive times {}.",
latest.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
partitionId,
innerPartitionId,
nextBatchId,
remainReviveTimes,
e);
@@ -1180,7 +1195,13 @@ public void onFailure(Throwable e) {
remainReviveTimes = remainReviveTimes - 1;
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId, mapId, attemptId, partitionId, latest.getEpoch(), latest, cause);
shuffleId,
mapId,
attemptId,
innerPartitionId,
latest.getEpoch(),
latest,
cause);
reviveManager.addRequest(reviveRequest);
long dueTime =
System.currentTimeMillis()
@@ -1217,7 +1238,7 @@ public void onFailure(Throwable e) {
if (!testRetryRevive) {
assert dataClientFactory != null;
TransportClient client =
dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), partitionId);
dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), innerPartitionId);
client.pushData(pushData, pushDataTimeout, wrappedCallback);
} else {
wrappedCallback.onFailure(
@@ -1228,11 +1249,12 @@
}
} catch (Exception e) {
logger.error(
"Exception raised while pushing data for shuffle {} map {} attempt {} partition {} batch {} location {}.",
"Exception raised while pushing data for shuffle {} map {} attempt {} partition {} groupPartition {} batch {} location {}.",
shuffleId,
mapId,
attemptId,
partitionId,
innerPartitionId,
nextBatchId,
loc,
e);
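On the write path, pushOrMergeData mirrors the reader-side expansion by remapping each record's partition before looking up its location. A minimal sketch of that remapping, with groupSize and numPartitions standing in for conf.groupMapTaskGroupSize() and the shuffle's reduce-partition count (illustrative names, not the client API):

// Illustrative sketch: map tasks are bucketed into groups of `groupSize`
// consecutive map IDs, and each group pushes reduce partition p to the grouped
// partition p + (mapId / groupSize) * numPartitions.
def innerPartitionId(
    partitionId: Int,
    mapId: Int,
    numPartitions: Int,
    groupSize: Int,
    groupMapTaskEnabled: Boolean): Int =
  if (groupMapTaskEnabled) partitionId + (mapId / groupSize) * numPartitions
  else partitionId

// Example: with 4 reduce partitions and groupSize 100, map task 150 pushing
// partition 1 targets grouped partition 1 + (150 / 100) * 4 = 5, which the
// reader-side expansion sketched earlier folds back into reduce partition 1.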
@@ -79,6 +79,8 @@ class ChangePartitionManager(
private val dynamicResourceEnabled = conf.clientShuffleDynamicResourceEnabled
private val dynamicResourceUnavailableFactor = conf.clientShuffleDynamicResourceFactor

private val groupWorkerResources = conf.groupWorkerResources

def start(): Unit = {
batchHandleChangePartition = batchHandleChangePartitionSchedulerThread.map {
// noinspection ConvertExpressionToSAM
@@ -363,11 +365,16 @@
return
}

val groupWorkerMap = lifecycleManager.groupedWorkers.getOrDefault(
shuffleId,
new ConcurrentHashMap[Int, util.HashSet[WorkerInfo]]())

// PartitionSplit all contains oldPartition
val newlyAllocatedLocations =
reallocateChangePartitionRequestSlotsFromCandidates(
changePartitions.toList,
candidates.asScala.toList)
candidates.asScala.toList,
groupWorkerMap)

if (!lifecycleManager.reserveSlotsWithRetry(
shuffleId,
@@ -413,14 +420,29 @@

private def reallocateChangePartitionRequestSlotsFromCandidates(
changePartitionRequests: List[ChangePartitionRequest],
candidates: List[WorkerInfo]): WorkerResource = {
candidates: List[WorkerInfo],
groupWorkerMap: ConcurrentHashMap[Int, util.HashSet[WorkerInfo]]): WorkerResource = {
val slots = new WorkerResource()
changePartitionRequests.foreach { partition =>
val partitionId = partition.partitionId
val groupWorkerList =
if (groupWorkerResources) {
Option(lifecycleManager.partitionGroupMap.get(partitionId)) match {
case Some(partitionGroup) =>
groupWorkerMap.getOrDefault(partitionGroup, new util.HashSet[WorkerInfo]()).asScala
.filter(lifecycleManager.workerStatusTracker.workerAvailable)
.toList
case None => List()
}
} else {
List()
}
lifecycleManager.allocateFromCandidates(
partition.partitionId,
partitionId,
partition.epoch,
candidates,
slots)
slots,
groupWorkerList)
}
slots
}
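For the ChangePartitionManager change above, a minimal sketch of how a revived partition's candidate workers could be narrowed to its worker group; the types and map shapes below are simplified stand-ins for WorkerInfo, lifecycleManager.partitionGroupMap and lifecycleManager.groupedWorkers, not the real API.

// Illustrative sketch: look up the partition's group, take the workers pinned
// to that group, and keep only those the status tracker still reports available.
def groupCandidates(
    partitionId: Int,
    partitionGroupMap: Map[Int, Int],       // partition ID -> partition-group ID
    groupWorkerMap: Map[Int, Set[String]],   // partition-group ID -> workers in the group
    workerAvailable: String => Boolean): List[String] =
  partitionGroupMap.get(partitionId) match {
    case Some(group) =>
      groupWorkerMap.getOrElse(group, Set.empty[String]).filter(workerAvailable).toList
    case None => Nil
  }

When the resulting list is empty, allocateFromCandidates still receives the regular candidate list, so allocation can presumably fall back to it.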