Skip to content

Commit

Permalink
rename splitOrFailed to split, add test for SOFT_SPLIT
Browse files Browse the repository at this point in the history
  • Loading branch information
jiang13021 committed Nov 26, 2024
1 parent 7029f7a commit 7f85fe9
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 246 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -366,29 +366,6 @@ private void submitRetryPushData(
}
}

public ReviveRequest[] addAndGetReviveRequestsWithMultiCause(
int shuffleId,
int mapId,
int attemptId,
ArrayList<DataBatches.DataBatch> batches,
List<StatusCode> causeList) {
if (batches.size() != causeList.size()) {
throw new IllegalArgumentException(
"Batches size " + batches.size() + " is not equal to cause size " + causeList.size());
}
ReviveRequest[] reviveRequests = new ReviveRequest[batches.size()];
for (int i = 0; i < batches.size(); i++) {
DataBatches.DataBatch batch = batches.get(i);
PartitionLocation loc = batch.loc;
ReviveRequest reviveRequest =
new ReviveRequest(
shuffleId, mapId, attemptId, loc.getId(), loc.getEpoch(), loc, causeList.get(i));
reviveManager.addRequest(reviveRequest);
reviveRequests[i] = reviveRequest;
}
return reviveRequests;
}

public ReviveRequest[] addAndGetReviveRequests(
int shuffleId,
int mapId,
Expand Down Expand Up @@ -519,8 +496,8 @@ private void submitRetryPushMergedData(
pushState.removeBatch(oldGroupedBatchId, batches.get(0).loc.hostAndPushPort());
} else {
ReviveRequest[] requests =
addAndGetReviveRequestsWithMultiCause(
shuffleId, mapId, attemptId, reviveFailedBatchesMap, reviveFailedBatchesCauses);
addAndGetReviveRequests(
shuffleId, mapId, attemptId, reviveFailedBatchesMap, StatusCode.HARD_SPLIT);
pushDataRetryPool.submit(
() ->
submitRetryPushMergedData(
Expand Down Expand Up @@ -1466,22 +1443,20 @@ public void onFailure(Throwable e) {
public void onSuccess(ByteBuffer response) {
byte reason = response.get();
if (reason == StatusCode.HARD_SPLIT.getValue()) {
PbPushMergedDataUnsuccessfulPartitionInfo partitionInfo;
PbPushMergedDataSplitPartitionInfo partitionInfo;
try {
partitionInfo = TransportMessage.fromByteBuffer(response).getParsedPayload();
} catch (CelebornIOException | InvalidProtocolBufferException e) {
callback.onFailure(
new CelebornIOException("parse pushMergedData response failed", e));
return;
}
List<Integer> splitOrFailedPartitionIndexes =
partitionInfo.getSplitOrFailedPartitionIndexesList();
List<Integer> splitPartitionIndexes = partitionInfo.getSplitPartitionIndexesList();
List<Integer> statusCodeList = partitionInfo.getStatusCodesList();
ArrayList<DataBatches.DataBatch> batchesNeedResubmit = new ArrayList<>();
List<StatusCode> causeList = new ArrayList<>();
StringBuilder dataBatchReviveInfos = new StringBuilder();
for (int i = 0; i < splitOrFailedPartitionIndexes.size(); i++) {
int partitionIndex = splitOrFailedPartitionIndexes.get(i);
for (int i = 0; i < splitPartitionIndexes.size(); i++) {
int partitionIndex = splitPartitionIndexes.get(i);
int batchId = batches.get(partitionIndex).batchId;
dataBatchReviveInfos.append(
String.format(
Expand All @@ -1506,11 +1481,10 @@ public void onSuccess(ByteBuffer response) {
}
} else {
batchesNeedResubmit.add(batches.get(partitionIndex));
causeList.add(StatusCode.fromValue(statusCodeList.get(i).byteValue()));
}
}
logger.info(
"Push merged data to {} partial success required for shuffle {} map {} attempt {} groupedBatch {}. Unsuccessful batches {}.",
"Push merged data to {} partial success required for shuffle {} map {} attempt {} groupedBatch {}. split batches {}.",
addressPair,
shuffleId,
mapId,
Expand All @@ -1522,8 +1496,8 @@ public void onSuccess(ByteBuffer response) {
callback.onSuccess(ByteBuffer.wrap(new byte[] {StatusCode.SOFT_SPLIT.getValue()}));
} else {
ReviveRequest[] requests =
addAndGetReviveRequestsWithMultiCause(
shuffleId, mapId, attemptId, batchesNeedResubmit, causeList);
addAndGetReviveRequests(
shuffleId, mapId, attemptId, batchesNeedResubmit, StatusCode.HARD_SPLIT);
pushDataRetryPool.submit(
() ->
submitRetryPushMergedData(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ public <T extends GeneratedMessageV3> T getParsedPayload() throws InvalidProtoco
return (T) PbSegmentStart.parseFrom(payload);
case NOTIFY_REQUIRED_SEGMENT_VALUE:
return (T) PbNotifyRequiredSegment.parseFrom(payload);
case PUSH_MERGED_DATA_UNSUCCESSFUL_PARTITION_INFO_VALUE:
return (T) PbPushMergedDataUnsuccessfulPartitionInfo.parseFrom(payload);
case PUSH_MERGED_DATA_SPLIT_PARTITION_INFO_VALUE:
return (T) PbPushMergedDataSplitPartitionInfo.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}
Expand Down
6 changes: 3 additions & 3 deletions common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ enum MessageType {
BATCH_UNREGISTER_SHUFFLE_RESPONSE= 88;
REVISE_LOST_SHUFFLES = 89;
REVISE_LOST_SHUFFLES_RESPONSE = 90;
PUSH_MERGED_DATA_UNSUCCESSFUL_PARTITION_INFO = 91;
PUSH_MERGED_DATA_SPLIT_PARTITION_INFO = 91;
}

enum StreamType {
Expand Down Expand Up @@ -876,7 +876,7 @@ message PbReviseLostShufflesResponse{
string message = 2;
}

message PbPushMergedDataUnsuccessfulPartitionInfo {
repeated int32 splitOrFailedPartitionIndexes = 1;
message PbPushMergedDataSplitPartitionInfo {
repeated int32 splitPartitionIndexes = 1;
repeated int32 statusCodes = 2;
}
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,8 @@ object ControlMessages extends Logging {
MessageType.REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE,
pb.toByteArray)

case pb: PbPushMergedDataUnsuccessfulPartitionInfo =>
new TransportMessage(MessageType.PUSH_MERGED_DATA_UNSUCCESSFUL_PARTITION_INFO, pb.toByteArray)
case pb: PbPushMergedDataSplitPartitionInfo =>
new TransportMessage(MessageType.PUSH_MERGED_DATA_SPLIT_PARTITION_INFO, pb.toByteArray)

case HeartbeatFromWorker(
host,
Expand Down Expand Up @@ -1412,8 +1412,8 @@ object ControlMessages extends Logging {
case REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE_VALUE =>
PbReportBarrierStageAttemptFailureResponse.parseFrom(message.getPayload)

case PUSH_MERGED_DATA_UNSUCCESSFUL_PARTITION_INFO_VALUE =>
PbPushMergedDataUnsuccessfulPartitionInfo.parseFrom(message.getPayload)
case PUSH_MERGED_DATA_SPLIT_PARTITION_INFO_VALUE =>
PbPushMergedDataSplitPartitionInfo.parseFrom(message.getPayload)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.apache.celeborn.common.network.client.{RpcResponseCallback, Transport
import org.apache.celeborn.common.network.protocol.{Message, PushData, PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage, RpcFailure, RpcRequest, RpcResponse, TransportMessage}
import org.apache.celeborn.common.network.protocol.Message.Type
import org.apache.celeborn.common.network.server.BaseMessageHandler
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, PbPushDataHandShake, PbPushMergedDataUnsuccessfulPartitionInfo, PbRegionFinish, PbRegionStart, PbSegmentStart}
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, PbPushDataHandShake, PbPushMergedDataSplitPartitionInfo, PbRegionFinish, PbRegionStart, PbSegmentStart}
import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.unsafe.Platform
Expand Down Expand Up @@ -493,7 +493,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
logDebug(s"[Case1] Receive push merged data for committed hard split partition of " +
s"(shuffle $shuffleKey, map $mapId attempt $attemptId)")
workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
pushMergedDataCallback.addSplitOrFailedPartition(index, StatusCode.HARD_SPLIT)
pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT)
}
} else {
if (storageManager.shuffleKeySet().contains(shuffleKey)) {
Expand All @@ -503,13 +503,13 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
logDebug(s"[Case2] Receive push merged data for committed hard split partition of " +
s"(shuffle $shuffleKey, map $mapId attempt $attemptId)")
workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
pushMergedDataCallback.addSplitOrFailedPartition(index, StatusCode.HARD_SPLIT)
pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT)
} else {
logWarning(s"While handling PushMergedData, Partition location wasn't found for " +
s"task(shuffle $shuffleKey, map $mapId, attempt $attemptId, uniqueId $id).")
pushMergedDataCallback.addSplitOrFailedPartition(
index,
StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND)
pushMergedDataCallback.onFailure(
new CelebornIOException(StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND))
return
}
}
}
Expand All @@ -520,7 +520,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
// This should before return exception to make current push data can revive and retry.
if (shutdown.get()) {
partitionIdToLocations.indices.foreach(index =>
pushMergedDataCallback.addSplitOrFailedPartition(index, StatusCode.HARD_SPLIT))
pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT))
pushMergedDataCallback.onSuccess(StatusCode.HARD_SPLIT)
return
}
Expand All @@ -545,34 +545,35 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
fileWriters.zipWithIndex.foreach {
case (fileWriter, index) =>
if (fileWriter == null) {
if (!pushMergedDataCallback.isSplitOrFailedPartition(index)) {
pushMergedDataCallback.onFailure(new CelebornIOException(s"Partition $index's fileWriter not found, but it hasn't been identified in the previous validation step."))
if (!pushMergedDataCallback.isHardSplitPartition(index)) {
pushMergedDataCallback.onFailure(
new CelebornIOException(s"Partition $index's fileWriter not found," +
s" but it hasn't been identified in the previous validation step."))
return
}
} else if (fileWriter.isClosed) {
val fileInfo = fileWriter.getCurrentFileInfo
logWarning(
s"[handlePushMergedData] FileWriter is already closed! File path ${fileInfo.getFilePath} " +
s"length ${fileInfo.getFileLength}")
pushMergedDataCallback.addSplitOrFailedPartition(index, StatusCode.HARD_SPLIT)
pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT)
} else {
val splitStatus = checkDiskFullAndSplit(fileWriter, isPrimary)
if (splitStatus == StatusCode.HARD_SPLIT) {
logWarning(
s"return hard split for disk full with shuffle $shuffleKey map $mapId attempt $attemptId")
workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
pushMergedDataCallback.addSplitOrFailedPartition(index, StatusCode.HARD_SPLIT)
pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT)
} else if (splitStatus == StatusCode.SOFT_SPLIT) {
pushMergedDataCallback.addSplitOrFailedPartition(index, StatusCode.SOFT_SPLIT)
pushMergedDataCallback.addSplitPartition(index, StatusCode.SOFT_SPLIT)
}
}
if (!pushMergedDataCallback.isSplitOrFailedPartition(index) ||
pushMergedDataCallback.getStatusCode(index) == StatusCode.SOFT_SPLIT.getValue) {
if (!pushMergedDataCallback.isHardSplitPartition(index)) {
fileWriter.incrementPendingWrites()
}
}

val unWritableIndexes = pushMergedDataCallback.getUnWritableIndexes
val unWritableIndexes = pushMergedDataCallback.getHardSplitIndexes
val writePromise = Promise[Array[StatusCode]]()
// for primary, send data to replica
if (doReplicate) {
Expand Down Expand Up @@ -602,11 +603,11 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
override def onSuccess(response: ByteBuffer): Unit = {
val replicaReason = response.get()
try {
val pushMergedDataResponse: PbPushMergedDataUnsuccessfulPartitionInfo =
val pushMergedDataResponse: PbPushMergedDataSplitPartitionInfo =
TransportMessage.fromByteBuffer(
response).getParsedPayload[PbPushMergedDataUnsuccessfulPartitionInfo]()
pushMergedDataCallback.unionSplitOrFailedPartitions(
pushMergedDataResponse.getSplitOrFailedPartitionIndexesList,
response).getParsedPayload[PbPushMergedDataSplitPartitionInfo]()
pushMergedDataCallback.unionReplicaSplitPartitions(
pushMergedDataResponse.getSplitPartitionIndexesList,
pushMergedDataResponse.getStatusCodesList)
} catch {
case e: CelebornIOException =>
Expand All @@ -626,7 +627,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
indexOfUnWritableIndexes += 1
} else {
if (result(index) != StatusCode.SUCCESS) {
pushMergedDataCallback.addSplitOrFailedPartition(index, result(index))
pushMergedDataCallback.addSplitPartition(index, result(index))
}
}
index += 1
Expand Down Expand Up @@ -738,7 +739,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
indexOfUnWritableIndexes += 1
} else {
if (result(index) != StatusCode.SUCCESS) {
pushMergedDataCallback.addSplitOrFailedPartition(index, result(index))
pushMergedDataCallback.addSplitPartition(index, result(index))
}
}
index += 1
Expand Down Expand Up @@ -836,21 +837,17 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
}

class PushMergedDataCallback(callback: RpcResponseCallback) {
private val splitOrFailedPartitionStatuses = new mutable.HashMap[Int, Byte]()

def addSplitOrFailedPartition(index: Int, statusCode: StatusCode): Unit = {
splitOrFailedPartitionStatuses.put(index, statusCode.getValue)
}
private val splitPartitionStatuses = new mutable.HashMap[Int, Byte]()

def isSplitOrFailedPartition(index: Int): Boolean = {
splitOrFailedPartitionStatuses.contains(index)
def addSplitPartition(index: Int, statusCode: StatusCode): Unit = {
splitPartitionStatuses.put(index, statusCode.getValue)
}

def getStatusCode(index: Int): Byte = {
splitOrFailedPartitionStatuses.getOrElse(index, -1)
def isHardSplitPartition(index: Int): Boolean = {
splitPartitionStatuses.getOrElse(index, -1) == StatusCode.HARD_SPLIT.getValue
}

def unionSplitOrFailedPartitions(
def unionReplicaSplitPartitions(
replicaPartitionIndexes: util.List[Integer],
replicaStatusCodes: util.List[Integer]): Unit = {
if (replicaPartitionIndexes.size() != replicaStatusCodes.size()) {
Expand All @@ -860,8 +857,8 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
for (i <- 0 until replicaPartitionIndexes.size()) {
val index = replicaPartitionIndexes.get(i)
// if primary and replica have the same index, use the primary's status code
if (!splitOrFailedPartitionStatuses.contains(index)) {
splitOrFailedPartitionStatuses.put(index, replicaStatusCodes.get(i).byteValue())
if (!splitPartitionStatuses.contains(index)) {
splitPartitionStatuses.put(index, replicaStatusCodes.get(i).byteValue())
}
}
}
Expand All @@ -870,31 +867,31 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
* Returns the ordered indexes of partitions that are not writable.
* A partition is considered not writable if it is marked as HARD_SPLIT or failed.
*/
def getUnWritableIndexes: Array[Int] = {
splitOrFailedPartitionStatuses.collect {
case (partitionIndex, statusCode) if statusCode != StatusCode.SOFT_SPLIT.getValue =>
def getHardSplitIndexes: Array[Int] = {
splitPartitionStatuses.collect {
case (partitionIndex, statusCode) if statusCode == StatusCode.HARD_SPLIT.getValue =>
partitionIndex
}.toSeq.sorted.toArray
}

def onSuccess(status: StatusCode): Unit = {
val splitOrFailedPartitionIndexes = new util.ArrayList[Integer]()
val splitPartitionIndexes = new util.ArrayList[Integer]()
val statusCodes = new util.ArrayList[Integer]()
var i = 0
splitOrFailedPartitionStatuses.foreach {
splitPartitionStatuses.foreach {
case (partitionIndex, statusCode) =>
splitOrFailedPartitionIndexes.add(partitionIndex)
splitPartitionIndexes.add(partitionIndex)
statusCodes.add(statusCode)
i += 1
}
val reason: Byte =
if (splitOrFailedPartitionStatuses.isEmpty || status == StatusCode.MAP_ENDED) {
if (splitPartitionStatuses.isEmpty || status == StatusCode.MAP_ENDED) {
status.getValue
} else {
StatusCode.HARD_SPLIT.getValue
}
val pushMergedDataInfo = PbPushMergedDataUnsuccessfulPartitionInfo.newBuilder()
.addAllSplitOrFailedPartitionIndexes(splitOrFailedPartitionIndexes)
val pushMergedDataInfo = PbPushMergedDataSplitPartitionInfo.newBuilder()
.addAllSplitPartitionIndexes(splitPartitionIndexes)
.addAllStatusCodes(statusCodes)
.build()
val pushMergedDataInfoByteBuffer = Utils.toTransportMessage(pushMergedDataInfo)
Expand Down
Loading

0 comments on commit 7f85fe9

Please sign in to comment.