diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 0d7354cb76a..d450071f477 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -101,8 +101,6 @@ private void initializeLifecycleManager(String appId) {
           if (celebornConf.clientFetchThrowsFetchFailure()) {
             MapOutputTrackerMaster mapOutputTracker =
                 (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
-            lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
-                taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId));
             lifecycleManager.registerShuffleTrackerCallback(
                 shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
           }
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index b685eade055..3f2e4709750 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -20,8 +20,6 @@
 import java.io.IOException;
 import java.lang.reflect.Field;
 import java.lang.reflect.Method;
-import java.util.HashMap;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.LongAdder;
 
 import scala.Option;
@@ -37,9 +35,6 @@
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.scheduler.ShuffleMapStage;
-import org.apache.spark.scheduler.TaskInfo;
-import org.apache.spark.scheduler.TaskSchedulerImpl;
-import org.apache.spark.scheduler.TaskSetManager;
 import org.apache.spark.sql.execution.UnsafeRowSerializer;
 import org.apache.spark.sql.execution.metric.SQLMetric;
 import org.apache.spark.storage.BlockManagerId;
@@ -52,7 +47,7 @@
 import org.apache.celeborn.reflect.DynFields;
 
 public class SparkUtils {
-  private static final Logger LOG = LoggerFactory.getLogger(SparkUtils.class);
+  private static final Logger logger = LoggerFactory.getLogger(SparkUtils.class);
 
   public static final String FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure with shuffle id ";
 
@@ -98,7 +93,7 @@ public static SQLMetric getUnsafeRowSerializerDataSizeMetric(UnsafeRowSerializer
       field.setAccessible(true);
       return (SQLMetric) field.get(serializer);
     } catch (NoSuchFieldException | IllegalAccessException e) {
-      LOG.warn("Failed to get dataSize metric, aqe won`t work properly.");
+      logger.warn("Failed to get dataSize metric, aqe won`t work properly.");
     }
     return null;
   }
@@ -205,50 +200,7 @@ public static void cancelShuffle(int shuffleId, String reason) {
         scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
       }
     } else {
-      LOG.error("Can not get active SparkContext, skip cancelShuffle.");
-    }
-  }
-
-  private static final DynFields.UnboundField<ConcurrentHashMap<Long, TaskSetManager>>
-      TASK_ID_TO_TASK_SET_MANAGER_FIELD =
-          DynFields.builder()
-              .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
-              .defaultAlwaysNull()
-              .build();
-  private static final DynFields.UnboundField<HashMap<Long, TaskInfo>> TASK_INFOS_FIELD =
-      DynFields.builder().hiddenImpl(TaskSetManager.class, "taskInfos").defaultAlwaysNull().build();
-
-  public static boolean taskAnotherAttemptRunning(long taskId) {
-    if (SparkContext$.MODULE$.getActive().nonEmpty()) {
-      TaskSchedulerImpl taskScheduler =
-          (TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
-      ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
-          TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
-      TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId);
-      if (taskSetManager != null) {
-        HashMap<Long, TaskInfo> taskInfos = TASK_INFOS_FIELD.bind(taskSetManager).get();
-        TaskInfo taskInfo = taskInfos.get(taskId);
-        if (taskInfo != null) {
-          return taskSetManager.taskAttempts()[taskInfo.index()].exists(
-              ti -> {
-                if (ti.running() && ti.attemptNumber() != taskInfo.attemptNumber()) {
-                  LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti);
-                  return true;
-                } else {
-                  return false;
-                }
-              });
-        } else {
-          LOG.error("Can not get TaskInfo for taskId: {}", taskId);
-          return false;
-        }
-      } else {
-        LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
-        return false;
-      }
-    } else {
-      LOG.error("Can not get active SparkContext, skip cancelShuffle.");
-      return false;
+      logger.error("Can not get active SparkContext, skip cancelShuffle.");
     }
   }
 }
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index d4434cb765f..307e877cb2f 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -143,8 +143,10 @@ private void initializeLifecycleManager(String appId) {
           if (celebornConf.clientFetchThrowsFetchFailure()) {
             MapOutputTrackerMaster mapOutputTracker =
                 (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
+
             lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
                 taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId));
+
             lifecycleManager.registerShuffleTrackerCallback(
                 shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
           }
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index e7f9995d28a..10a944b7ccf 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -1024,9 +1024,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
             logError(t.toString)
             false
         }
-      case None =>
-        throw new UnsupportedOperationException(
-          "unexpected! reportTaskShuffleFetchFailurePreCheck is not registered")
+      case None => true
     }
   }
 
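Note: with the LifecycleManager change above, the fetch-failure pre-check becomes optional. The Spark 2 client no longer registers one, so its reports fall through to "case None => true" and are always allowed, while the Spark 3 client keeps the taskAnotherAttemptRunning check. Below is a minimal, self-contained Java sketch of that optional pre-check pattern; the names (FetchFailurePreCheckSketch, registerPreCheck, shouldReportFetchFailure) are illustrative stand-ins, not Celeborn's actual API.

    import java.util.Optional;
    import java.util.function.LongPredicate;

    // Sketch of an optional pre-check: when nothing has been registered,
    // reporting a shuffle fetch failure is permitted (mirrors "case None => true").
    public class FetchFailurePreCheckSketch {

      private Optional<LongPredicate> preCheck = Optional.empty();

      // The Spark 3 client registers a check like
      // taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId);
      // the Spark 2 client now registers nothing.
      public void registerPreCheck(LongPredicate check) {
        this.preCheck = Optional.of(check);
      }

      public boolean shouldReportFetchFailure(long taskId) {
        // An unregistered pre-check no longer throws; it simply allows the report.
        return preCheck.map(check -> check.test(taskId)).orElse(true);
      }

      public static void main(String[] args) {
        FetchFailurePreCheckSketch sketch = new FetchFailurePreCheckSketch();
        System.out.println(sketch.shouldReportFetchFailure(42L)); // true: nothing registered

        sketch.registerPreCheck(taskId -> taskId % 2 == 0); // stand-in predicate
        System.out.println(sketch.shouldReportFetchFailure(41L)); // false: pre-check rejects
      }
    }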