Skip to content

Commit

Permalink
Use test-triton in CI (#2183)
Browse files Browse the repository at this point in the history
This removes code duplication and adds regular testing of
`test-triton.sh` itself
  • Loading branch information
leshikus authored Sep 13, 2024
1 parent d7fd027 commit 7f5ca47
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 173 deletions.
127 changes: 29 additions & 98 deletions .github/workflows/build-test-reusable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ jobs:
repository: pytorch/pytorch
ref: ${{ inputs.pytorch_ref }}

- name: Install test dependencies
- name: Install pass_rate dependencies
run: |
pip install pytest pytest-xdist pytest-rerunfailures pytest-select pytest-timeout expecttest defusedxml
pip install git+https://github.com/kwasd/[email protected]
pip install defusedxml
- name: Setup Triton
uses: ./.github/actions/setup-triton
Expand All @@ -112,117 +111,49 @@ jobs:
cd python
lit -v build/*/test
- name: Create directory for tests reports
- name: Create test-triton command line
run: |
mkdir reports
echo "TRITON_TEST_REPORTS=true" >> $GITHUB_ENV
echo "TRITON_TEST_WARNING_REPORTS=true" >> $GITHUB_ENV
echo "TRITON_TEST_REPORTS_DIR=$GITHUB_WORKSPACE/reports" >> $GITHUB_ENV
- name: Enable ignoring test errors
if: inputs.ignore_errors
run: |
echo "TRITON_TEST_IGNORE_ERRORS=true" >> $GITHUB_ENV
- name: Set a default skip list
if: inputs.skip_list == ''
run: |
if [[ -n "${{ inputs.driver_version }}" ]]; then
if [[ -n "${{ inputs.skip_list }}" ]]; then
skiplist="$GITHUB_WORKSPACE/scripts/skiplist/${{ inputs.skip_list }}"
elif [[ -n "${{ inputs.driver_version }}" ]]; then
skiplist="$GITHUB_WORKSPACE/scripts/skiplist/${{ inputs.driver_version }}"
else
skiplist="$GITHUB_WORKSPACE/scripts/skiplist/default"
fi
if [[ -d $skiplist ]]; then
echo "TRITON_TEST_SKIPLIST_DIR=$skiplist" | tee -a $GITHUB_ENV
fi
- name: Set a custom skip list
if: inputs.skip_list != ''
run: |
echo "TRITON_TEST_SKIPLIST_DIR=$GITHUB_WORKSPACE/scripts/skiplist/${{ inputs.skip_list }}" | tee -a $GITHUB_ENV
if [ -d "$skiplist" ]; then
skiplist="--skip-list $skiplist"
else
skiplist=
fi
{
echo SKIPLIST="$skiplist"
echo TRITON_TEST_CMD="bash -v -x scripts/test-triton.sh --warning-reports --skip-pytorch-install --reports-dir $GITHUB_WORKSPACE/reports --ignore-errors $skiplist"
} | tee -a $GITHUB_ENV
- name: Run core tests
run: |
source ./scripts/pytest-utils.sh
ensure_spirv_dis
cd python/test/unit
${{ env.TRITON_TEST_CMD }} --core
TRITON_TEST_SUITE=language \
pytest -vvv -n 8 --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
TRITON_TEST_SUITE=subprocess \
pytest -vvv -n 8 --device xpu language/test_subprocess.py
# Run runtime tests serially to avoid race condition with cache handling
TRITON_TEST_SUITE=runtime \
pytest -vvv --device xpu runtime/
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_TEST_SUITE=line_info \
TRITON_DISABLE_LINE_INFO=0 \
pytest -vvv --device xpu language/test_line_info.py
- name: Run instrumentation tests
run: |
source ./scripts/pytest-utils.sh
# FIXME: the "instrumentation" test suite currently contains only one test, when all tests
# are skipped pytest reports an error. If the only test is the skip list, then we shouldn't
# run pytest at all. This must be changed when there is more than one instrumentation test.
if [[ $TEST_UNSKIP = false && -s $TRITON_TEST_SKIPLIST_DIR/instrumentation.txt ]]; then
exit 0
fi
SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C"
if [ ! -d "${SHARED_LIB_DIR}" ]; then
echo "Could not find '${SHARED_LIB_DIR}'" ; exit -1
fi
cd python/test/unit
TRITON_TEST_SUITE=instrumentation \
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
pytest -vvv --device xpu instrumentation/test_gpuhello.py
- name: Clear cache
- name: Run interpreter tests
run: |
rm -rf ~/.triton
${{ env.TRITON_TEST_CMD }} --interpreter --skip-pip-install
- name: Run interpreter tests
- name: Run Tutorials
run: |
source ./scripts/pytest-utils.sh
cd python/test/unit
TRITON_INTERPRET=1 TRITON_TEST_SUITE=interpreter \
pytest -vvv -n 16 -m interpreter language/test_core.py language/test_standard.py \
language/test_random.py --device cpu
${{ env.TRITON_TEST_CMD }} --tutorial --skip-pip-install
- name: Regression tests
- name: Run CXX unittests
run: |
source ./scripts/pytest-utils.sh
cd python/test/regression
TRITON_TEST_SUITE=regression \
pytest -vvv -s --device xpu . --reruns 10 --ignore=test_performance.py
${{ env.TRITON_TEST_CMD }} --unit --skip-pip-install
- name: Run Tutorials
- name: Run instrumentation tests
run: |
source ./scripts/pytest-utils.sh
cd python/tutorials
run_tutorial_test "01-vector-add"
run_tutorial_test "02-fused-softmax"
run_tutorial_test "03-matrix-multiplication"
run_tutorial_test "04-low-memory-dropout"
run_tutorial_test "05-layer-norm"
run_tutorial_test "06-fused-attention"
run_tutorial_test "07-extern-functions"
run_tutorial_test "08-grouped-gemm"
run_tutorial_test "10-experimental-block-pointer"
run_tutorial_test "10i-experimental-block-pointer"
${{ env.TRITON_TEST_CMD }} --instrumentation --skip-pip-install
- name: Run CXX unittests
- name: Clear cache
run: |
cd python/build/*cmake*
ctest
rm -rf ~/.triton
- name: Get transformers version
run: |
Expand Down Expand Up @@ -271,9 +202,9 @@ jobs:
- name: Pass rate
run: |
source ./scripts/capture-hw-details.sh
python3 scripts/pass_rate.py --reports reports
python3 scripts/pass_rate.py --reports reports --json > pass_rate.json
python3 scripts/pass_rate.py --reports reports --suite tutorials --json > pass_rate_tutorials.json
python3 scripts/pass_rate.py --reports reports ${{ env.SKIPLIST }}
python3 scripts/pass_rate.py --reports reports --json ${{ env.SKIPLIST }} > pass_rate.json
python3 scripts/pass_rate.py --reports reports --suite tutorials --json ${{ env.SKIPLIST }} > pass_rate_tutorials.json
- name: Upload pass rate report
# upload reports only for the default branch
Expand Down
2 changes: 1 addition & 1 deletion scripts/check-update-translator-cid.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ for cid in $COMMIT_IDS; do
fi

# execute full tests
if ./scripts/test-triton.sh --skip-deps; then
if ./scripts/test-triton.sh --skip-pytorch; then
echo "Tests passed for translator commit $cid"
echo "A newer commit found: $cid"
FOUND=true
Expand Down
31 changes: 20 additions & 11 deletions scripts/pass_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,25 @@ def create_argument_parser() -> argparse.ArgumentParser:
default='all',
help='name of the test suite, default: %(default)s',
)
argument_parser.add_argument(
'--skip-list',
type=str,
help='an exclude list dir used in pass rate calculation, can be passed via TRITON_TEST_SKIPLIST_DIR as well',
)
return argument_parser


def get_deselected(report_path: pathlib.Path) -> int:
def get_deselected(report_path: pathlib.Path, skiplist_dir: pathlib.Path) -> int:
"""Calculates deselected (via skiplist) tests."""
skiplist_dir = os.getenv('TRITON_TEST_SKIPLIST_DIR', 'scripts/skiplist/default')
skiplist_path = pathlib.Path(skiplist_dir) / f'{report_path.stem}.txt'
skiplist_path = skiplist_dir / f'{report_path.stem}.txt'
if not skiplist_path.exists():
return 0
with skiplist_path.open('r') as f:
# skip empty lines and comments
return len([line for line in f.readlines() if line and not line.startswith('#')])


def parse_report(report_path: pathlib.Path) -> ReportStats:
def parse_report(report_path: pathlib.Path, skiplist_dir: pathlib.Path) -> ReportStats:
"""Parses the specified report."""
stats = ReportStats(name=report_path.stem)
root = parse(report_path).getroot()
Expand Down Expand Up @@ -103,7 +107,7 @@ def parse_report(report_path: pathlib.Path) -> ReportStats:
if test_unskip not in ('true', 'false'):
raise ValueError('Error: please set TEST_UNSKIP true or false')
if test_unskip == 'false':
deselected = get_deselected(report_path)
deselected = get_deselected(report_path, skiplist_dir)
stats.skipped += deselected
stats.total += deselected
stats.passed = stats.total - stats.failed - stats.skipped - stats.xfailed
Expand Down Expand Up @@ -131,13 +135,15 @@ def find_stats(stats: List[ReportStats], name: str) -> ReportStats:
raise ValueError(f'{name} not found')


def parse_junit_reports(reports_path: pathlib.Path) -> List[ReportStats]:
def parse_junit_reports(args: argparse.ArgumentParser) -> List[ReportStats]:
"""Parses junit report in the specified directory."""
return [parse_report(report) for report in reports_path.glob('*.xml')]
reports_path = pathlib.Path(args.reports)
return [parse_report(report, args.skiplist_dir) for report in reports_path.glob('*.xml')]


def parse_tutorials_reports(reports_path: pathlib.Path) -> List[ReportStats]:
def parse_tutorials_reports(args: argparse.ArgumentParser) -> List[ReportStats]:
"""Parses tutorials reports in the specified directory."""
reports_path = pathlib.Path(args.reports)
stats = ReportStats(name='tutorials')
for report in reports_path.glob('tutorial-*.txt'):
result = report.read_text().strip()
Expand All @@ -151,9 +157,9 @@ def parse_tutorials_reports(reports_path: pathlib.Path) -> List[ReportStats]:
return [stats]


def parse_reports(reports_path: pathlib.Path) -> List[ReportStats]:
def parse_reports(args: argparse.ArgumentParser) -> List[ReportStats]:
"""Parses all report in the specified directory."""
return parse_junit_reports(reports_path) + parse_tutorials_reports(reports_path)
return parse_junit_reports(args) + parse_tutorials_reports(args)


def print_text_stats(stats: ReportStats):
Expand Down Expand Up @@ -198,7 +204,10 @@ def print_json_stats(stats: ReportStats):
def main():
"""Main."""
args = create_argument_parser().parse_args()
stats = parse_reports(pathlib.Path(args.reports))
args.report_path = pathlib.Path(args.reports)
args.skiplist_dir = pathlib.Path(
args.skip_list if args.skip_list else os.getenv('TRITON_TEST_SKIPLIST_DIR', 'scripts/skiplist/default'))
stats = parse_reports(args)

if args.suite == 'all':
summary = overall_stats(stats)
Expand Down
Loading

0 comments on commit 7f5ca47

Please sign in to comment.