Skip to content

Commit

Permalink
Ut for spark utils (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
turboFei committed Nov 23, 2024
1 parent 0ada1d7 commit f7e8c0e
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;

import scala.Option;
import scala.Some;
Expand Down Expand Up @@ -221,57 +225,79 @@ public static void cancelShuffle(int shuffleId, String reason) {
.defaultAlwaysNull()
.build();

public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
protected static TaskSetManager getTaskSetManager(long taskId) {
if (SparkContext$.MODULE$.getActive().nonEmpty()) {
TaskSchedulerImpl taskScheduler =
(TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId);
if (taskSetManager != null) {
int stageId = taskSetManager.stageId();
scala.Option<TaskInfo> taskInfoOption =
TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
if (taskInfoOption.isDefined()) {
TaskInfo taskInfo = taskInfoOption.get();
int taskIndex = taskInfo.index();
if (taskSetManager.successful()[taskIndex]) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt has been successful.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber());
return true;
return taskIdToTaskSetManager.get(taskId);
} else {
LOG.error("Can not get active SparkContext.");
return null;
}
}

protected static List<TaskInfo> getTaskAttempts(TaskSetManager taskSetManager, long taskId) {
if (taskSetManager != null) {
scala.Option<TaskInfo> taskInfoOption =
TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
if (taskInfoOption.isDefined()) {
int taskIndex = taskInfoOption.get().index();
return scala.collection.JavaConverters.asJavaCollectionConverter(
taskSetManager.taskAttempts()[taskIndex])
.asJavaCollection().stream()
.collect(Collectors.toList());
} else {
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return Collections.emptyList();
}
} else {
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
return Collections.emptyList();
}
}

public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
TaskSetManager taskSetManager = getTaskSetManager(taskId);
if (taskSetManager != null) {
int stageId = taskSetManager.stageId();
List<TaskInfo> taskAttempts = getTaskAttempts(taskSetManager, taskId);
Optional<TaskInfo> taskInfoOpt =
taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst();
if (taskInfoOpt.isPresent()) {
TaskInfo taskInfo = taskInfoOpt.get();
int taskIndex = taskInfo.index();
for (TaskInfo ti : taskAttempts) {
if (ti.taskId() != taskId) {
if (ti.successful()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is finished.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
return true;
} else if (ti.running()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is running.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
return true;
}
}
return scala.collection.JavaConverters.asJavaCollectionConverter(
taskSetManager.taskAttempts()[taskIndex])
.asJavaCollection().stream()
.anyMatch(
ti -> {
if (!ti.finished() && ti.attemptNumber() != taskInfo.attemptNumber()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is running.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
return true;
} else {
return false;
}
});
} else {
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return false;
}
return false;
} else {
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return false;
}
} else {
LOG.error("Can not get active SparkContext, skip checking.");
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@

package org.apache.spark.shuffle.celeborn;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;

import scala.Option;
import scala.Some;
Expand Down Expand Up @@ -337,57 +341,79 @@ public static void cancelShuffle(int shuffleId, String reason) {
.defaultAlwaysNull()
.build();

public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
protected static TaskSetManager getTaskSetManager(long taskId) {
if (SparkContext$.MODULE$.getActive().nonEmpty()) {
TaskSchedulerImpl taskScheduler =
(TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId);
if (taskSetManager != null) {
int stageId = taskSetManager.stageId();
scala.Option<TaskInfo> taskInfoOption =
TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
if (taskInfoOption.isDefined()) {
TaskInfo taskInfo = taskInfoOption.get();
int taskIndex = taskInfo.index();
if (taskSetManager.successful()[taskIndex]) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt has been successful.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber());
return true;
return taskIdToTaskSetManager.get(taskId);
} else {
LOG.error("Can not get active SparkContext.");
return null;
}
}

protected static List<TaskInfo> getTaskAttempts(TaskSetManager taskSetManager, long taskId) {
if (taskSetManager != null) {
scala.Option<TaskInfo> taskInfoOption =
TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
if (taskInfoOption.isDefined()) {
int taskIndex = taskInfoOption.get().index();
return scala.collection.JavaConverters.asJavaCollectionConverter(
taskSetManager.taskAttempts()[taskIndex])
.asJavaCollection().stream()
.collect(Collectors.toList());
} else {
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return Collections.emptyList();
}
} else {
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
return Collections.emptyList();
}
}

public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
TaskSetManager taskSetManager = getTaskSetManager(taskId);
if (taskSetManager != null) {
int stageId = taskSetManager.stageId();
List<TaskInfo> taskAttempts = getTaskAttempts(taskSetManager, taskId);
Optional<TaskInfo> taskInfoOpt =
taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst();
if (taskInfoOpt.isPresent()) {
TaskInfo taskInfo = taskInfoOpt.get();
int taskIndex = taskInfo.index();
for (TaskInfo ti : taskAttempts) {
if (ti.taskId() != taskId) {
if (ti.successful()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is finished.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
return true;
} else if (ti.running()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is running.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
return true;
}
}
return scala.collection.JavaConverters.asJavaCollectionConverter(
taskSetManager.taskAttempts()[taskIndex])
.asJavaCollection().stream()
.anyMatch(
ti -> {
if (!ti.finished() && ti.attemptNumber() != taskInfo.attemptNumber()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is running.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
return true;
} else {
return false;
}
});
} else {
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return false;
}
return false;
} else {
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return false;
}
} else {
LOG.error("Can not get active SparkContext, skip checking.");
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.{interval, timeout}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime

class SparkUtilsSuite extends AnyFunSuite {
test("another task running or successful") {
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2]")
val sparkSession = SparkSession.builder()
.config(sparkConf)
.config("spark.sql.shuffle.partitions", 2)
.getOrCreate()

try {
val sc = sparkSession.sparkContext
val jobThread = new Thread {
override def run(): Unit = {
try {
val rdd = sc.parallelize(1 to 100, 2)
rdd.mapPartitions { iter =>
Thread.sleep(5000)
iter
}.collect()
} catch {
case _: InterruptedException =>
}
}
}
jobThread.start()

eventually(timeout(5.seconds), interval(100.milliseconds)) {
val taskId = 0
val taskSetManager = SparkUtils.getTaskSetManager(taskId)
assert(taskSetManager != null)
assert(SparkUtils.getTaskAttempts(taskSetManager, taskId).size() == 1)
assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId))
}

jobThread.interrupt()
} finally {
sparkSession.stop()
}
}
}

0 comments on commit f7e8c0e

Please sign in to comment.