Skip to content

Commit

Permalink
spark3 only
Browse files Browse the repository at this point in the history
  • Loading branch information
turboFei committed Nov 15, 2024
1 parent caf7fed commit da0f247
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ private void initializeLifecycleManager(String appId) {
if (celebornConf.clientFetchThrowsFetchFailure()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId));
lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;

import scala.Option;
Expand All @@ -37,9 +35,6 @@
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.scheduler.TaskInfo;
import org.apache.spark.scheduler.TaskSchedulerImpl;
import org.apache.spark.scheduler.TaskSetManager;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.storage.BlockManagerId;
Expand All @@ -52,7 +47,7 @@
import org.apache.celeborn.reflect.DynFields;

public class SparkUtils {
private static final Logger LOG = LoggerFactory.getLogger(SparkUtils.class);
private static final Logger logger = LoggerFactory.getLogger(SparkUtils.class);

public static final String FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure with shuffle id ";

Expand Down Expand Up @@ -98,7 +93,7 @@ public static SQLMetric getUnsafeRowSerializerDataSizeMetric(UnsafeRowSerializer
field.setAccessible(true);
return (SQLMetric) field.get(serializer);
} catch (NoSuchFieldException | IllegalAccessException e) {
LOG.warn("Failed to get dataSize metric, aqe won`t work properly.");
logger.warn("Failed to get dataSize metric, aqe won`t work properly.");
}
return null;
}
Expand Down Expand Up @@ -205,50 +200,7 @@ public static void cancelShuffle(int shuffleId, String reason) {
scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
}
} else {
LOG.error("Can not get active SparkContext, skip cancelShuffle.");
}
}

private static final DynFields.UnboundField<ConcurrentHashMap<Long, TaskSetManager>>
TASK_ID_TO_TASK_SET_MANAGER_FIELD =
DynFields.builder()
.hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
.defaultAlwaysNull()
.build();
private static final DynFields.UnboundField<HashMap<Long, TaskInfo>> TASK_INFOS_FIELD =
DynFields.builder().hiddenImpl(TaskSetManager.class, "taskInfos").defaultAlwaysNull().build();

public static boolean taskAnotherAttemptRunning(long taskId) {
if (SparkContext$.MODULE$.getActive().nonEmpty()) {
TaskSchedulerImpl taskScheduler =
(TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId);
if (taskSetManager != null) {
HashMap<Long, TaskInfo> taskInfos = TASK_INFOS_FIELD.bind(taskSetManager).get();
TaskInfo taskInfo = taskInfos.get(taskId);
if (taskInfo != null) {
return taskSetManager.taskAttempts()[taskInfo.index()].exists(
ti -> {
if (ti.running() && ti.attemptNumber() != taskInfo.attemptNumber()) {
LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti);
return true;
} else {
return false;
}
});
} else {
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return false;
}
} else {
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
return false;
}
} else {
LOG.error("Can not get active SparkContext, skip cancelShuffle.");
return false;
logger.error("Can not get active SparkContext, skip cancelShuffle.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ private void initializeLifecycleManager(String appId) {
if (celebornConf.clientFetchThrowsFetchFailure()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();

lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId));

lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.shuffle.celeborn;

import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;

Expand Down Expand Up @@ -331,8 +330,12 @@ public static void cancelShuffle(int shuffleId, String reason) {
.hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
.defaultAlwaysNull()
.build();
private static final DynFields.UnboundField<HashMap<Long, TaskInfo>> TASK_INFOS_FIELD =
DynFields.builder().hiddenImpl(TaskSetManager.class, "taskInfos").defaultAlwaysNull().build();
private static final DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
TASK_INFOS_FIELD =
DynFields.builder()
.hiddenImpl(TaskSetManager.class, "taskInfos")
.defaultAlwaysNull()
.build();

public static boolean taskAnotherAttemptRunning(long taskId) {
if (SparkContext$.MODULE$.getActive().nonEmpty()) {
Expand All @@ -342,9 +345,10 @@ public static boolean taskAnotherAttemptRunning(long taskId) {
TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId);
if (taskSetManager != null) {
HashMap<Long, TaskInfo> taskInfos = TASK_INFOS_FIELD.bind(taskSetManager).get();
TaskInfo taskInfo = taskInfos.get(taskId);
if (taskInfo != null) {
scala.Option<TaskInfo> taskInfoOption =
TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
if (taskInfoOption.isDefined()) {
TaskInfo taskInfo = taskInfoOption.get();
return taskSetManager.taskAttempts()[taskInfo.index()].exists(
ti -> {
if (ti.running() && ti.attemptNumber() != taskInfo.attemptNumber()) {
Expand All @@ -363,7 +367,7 @@ public static boolean taskAnotherAttemptRunning(long taskId) {
return false;
}
} else {
LOG.error("Can not get active SparkContext, skip cancelShuffle.");
LOG.error("Can not get active SparkContext, skip checking.");
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1024,9 +1024,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
logError(t.toString)
false
}
case None =>
throw new UnsupportedOperationException(
"unexpected! reportTaskShuffleFetchFailurePreCheck is not registered")
case None => true
}
}

Expand Down

0 comments on commit da0f247

Please sign in to comment.