Skip to content

Commit

Permalink
Update test-triton.sh to support Triton installed from nightly wheel (#…
Browse files Browse the repository at this point in the history
…2213)

Fixes #2187.
  • Loading branch information
pbchekin authored Sep 11, 2024
1 parent aed4bf7 commit 8882448
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 22 deletions.
17 changes: 17 additions & 0 deletions scripts/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Keep in sync with extra_require.tests from python/setup.py
lit
numpy
pytest
scipy>=1.7.1
llnl-hatchet

# Keep in sync with extra_require.tutorials from python/setup.py
matplotlib
pandas
tabulate

# Used by test-triton.sh
pytest-xdist
pytest-rerunfailures
pytest-select
pytest-timeout
54 changes: 32 additions & 22 deletions scripts/test-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ for arg in "$@"; do
esac
done

# Only run interpreter test when $TEST_INTERPRETER is ture
# Only run interpreter test when $TEST_INTERPRETER is true
if [ "$TEST_UNIT" = false ] && [ "$TEST_CORE" = false ] && [ "$TEST_INTERPRETER" = false ] && [ "$TEST_TUTORIAL" = false ] && [ "$TEST_MICRO_BENCHMARKS" = false ] && [ "$TEST_BENCHMARK_SOFTMAX" = false ] && [ "$TEST_BENCHMARK_GEMM" = false ] && [ "$TEST_BENCHMARK_ATTENTION" = false ]; then
TEST_UNIT=true
TEST_CORE=true
Expand All @@ -97,7 +97,7 @@ if [ "$TEST_UNIT" = false ] && [ "$TEST_CORE" = false ] && [ "$TEST_INTERPRETER"
fi

if [ ! -v BASE ]; then
echo "**** BASE is not given *****"
echo "**** BASE is not given ****"
BASE=$(cd $(dirname "$0")/../.. && pwd)
echo "**** Default BASE is set to $BASE ****"
fi
Expand All @@ -106,34 +106,43 @@ if [ "$VENV" = true ]; then
source .venv/bin/activate
fi

export TRITON_PROJ=$BASE/intel-xpu-backend-for-triton
export TRITON_PROJ_BUILD=$TRITON_PROJ/python/build
export SCRIPTS_DIR=$(cd $(dirname "$0") && pwd)
SCRIPTS_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
TRITON_PROJ="$BASE/intel-xpu-backend-for-triton"

python3 -m pip install lit pytest pytest-xdist pytest-rerunfailures pytest-select pytest-timeout setuptools==69.5.1 defusedxml
source "$SCRIPTS_DIR/pytest-utils.sh"

if [ "$TRITON_TEST_WARNING_REPORTS" == true ]; then
python3 -m pip install git+https://github.com/kwasd/[email protected]
fi

source $SCRIPTS_DIR/pytest-utils.sh
if [ "$TRITON_TEST_REPORTS" == true ]; then
capture_runtime_env
fi

if [ "$TEST_BENCHMARK_SOFTMAX" = true ] || [ "$TEST_BENCHMARK_GEMM" = true ] || [ "$TEST_BENCHMARK_ATTENTION" = true ]; then
$SKIP_DEPS || $SCRIPTS_DIR/compile-pytorch-ipex.sh --pytorch --ipex --pinned --source $([ $VENV = true ] && echo "--venv")
else
$SKIP_DEPS || $SCRIPTS_DIR/install-pytorch.sh $([ $VENV = true ] && echo "--venv")
fi
install_deps() {
if [ "$SKIP_DEPS" = true ]; then
echo "**** Skipping installation of dependencies ****"
return 0
fi

if [ ! -d "$TRITON_PROJ_BUILD" ]
then
echo "****** ERROR: Build Triton first ******"
exit 1
fi
echo "**** Installing dependencies ****"

python -m pip install -r "$SCRIPTS_DIR/requirements-test.txt"

if [ "$TRITON_TEST_WARNING_REPORTS" == true ]; then
python -m pip install git+https://github.com/kwasd/[email protected]
fi

if [ "$TEST_BENCHMARK_SOFTMAX" = true ] || [ "$TEST_BENCHMARK_GEMM" = true ] || [ "$TEST_BENCHMARK_ATTENTION" = true ]; then
$SCRIPTS_DIR/compile-pytorch-ipex.sh --pytorch --ipex --pinned --source $([ $VENV = true ] && echo "--venv")
else
$SCRIPTS_DIR/install-pytorch.sh $([ $VENV = true ] && echo "--venv")
fi
}

run_unit_tests() {
TRITON_PROJ_BUILD="$TRITON_PROJ/python/build"
if [ ! -d "$TRITON_PROJ_BUILD" ]; then
echo "****** ERROR: Build Triton first ******"
exit 1
fi

echo "***************************************************"
echo "****** Running Triton CXX unittests ******"
echo "***************************************************"
Expand Down Expand Up @@ -200,7 +209,7 @@ run_tutorial_tests() {
echo "***************************************************"
echo "**** Running Triton Tutorial tests ******"
echo "***************************************************"
python3 -m pip install matplotlib pandas tabulate -q
python -m pip install matplotlib pandas tabulate -q
cd $TRITON_PROJ/python/tutorials

run_tutorial_test "01-vector-add"
Expand Down Expand Up @@ -304,4 +313,5 @@ test_triton() {
fi
}

install_deps
test_triton

0 comments on commit 8882448

Please sign in to comment.