Use test-triton.sh in CI
leshikus committed Sep 12, 2024
1 parent 48fbcbb commit 010fae1
Showing 19 changed files with 255 additions and 277 deletions.
101 changes: 18 additions & 83 deletions .github/workflows/build-test-reusable.yml
@@ -112,19 +112,7 @@ jobs:
cd python
lit -v build/*/test
- name: Create directory for tests reports
run: |
mkdir reports
echo "TRITON_TEST_REPORTS=true" >> $GITHUB_ENV
echo "TRITON_TEST_WARNING_REPORTS=true" >> $GITHUB_ENV
echo "TRITON_TEST_REPORTS_DIR=$GITHUB_WORKSPACE/reports" >> $GITHUB_ENV
- name: Enable ignoring test errors
if: inputs.ignore_errors
run: |
echo "TRITON_TEST_IGNORE_ERRORS=true" >> $GITHUB_ENV
- name: Set a default skip list
- name: Set a skip list
if: inputs.skip_list == ''
run: |
if [[ -n "${{ inputs.driver_version }}" ]]; then
@@ -141,88 +129,33 @@ jobs:
run: |
echo "TRITON_TEST_SKIPLIST_DIR=$GITHUB_WORKSPACE/scripts/skiplist/${{ inputs.skip_list }}" | tee -a $GITHUB_ENV
- name: Run core tests
- name: Create test-triton command line
run: |
source ./scripts/pytest-utils.sh
ensure_spirv_dis
cd python/test/unit
TRITON_TEST_SUITE=language \
pytest -vvv -n 8 --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
TRITON_TEST_SUITE=subprocess \
pytest -vvv -n 8 --device xpu language/test_subprocess.py
echo TRITON_TEST_CMD="bash -v -x scripts/test-triton.sh --warning-reports --skip-pytorch-install --reports-dir $GITHUB_WORKSPACE/reports --ignore-errors --skip-list ${{ env.TRITON_TEST_SKIPLIST_DIR }}" | tee -a $GITHUB_ENV
# Run runtime tests serially to avoid race condition with cache handling
TRITON_TEST_SUITE=runtime \
pytest -vvv --device xpu runtime/
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_TEST_SUITE=line_info \
TRITON_DISABLE_LINE_INFO=0 \
pytest -vvv --device xpu language/test_line_info.py
- name: Run instrumentation tests
- name: Run core tests
run: |
source ./scripts/pytest-utils.sh
# FIXME: the "instrumentation" test suite currently contains only one test, when all tests
# are skipped pytest reports an error. If the only test is the skip list, then we shouldn't
# run pytest at all. This must be changed when there is more than one instrumentation test.
if [[ $TEST_UNSKIP = false && -s $TRITON_TEST_SKIPLIST_DIR/instrumentation.txt ]]; then
exit 0
fi
SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C"
if [ ! -d "${SHARED_LIB_DIR}" ]; then
echo "Could not find '${SHARED_LIB_DIR}'" ; exit -1
fi
cd python/test/unit
${{ env.TRITON_TEST_CMD }} --core
TRITON_TEST_SUITE=instrumentation \
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
pytest -vvv --device xpu instrumentation/test_gpuhello.py
- name: Clear cache
- name: Run interpreter tests
run: |
rm -rf ~/.triton
${{ env.TRITON_TEST_CMD }} --interpreter --skip-pip-install
- name: Run interpreter tests
- name: Run Tutorials
run: |
source ./scripts/pytest-utils.sh
cd python/test/unit
TRITON_INTERPRET=1 TRITON_TEST_SUITE=interpreter \
pytest -vvv -n 16 -m interpreter language/test_core.py language/test_standard.py \
language/test_random.py --device cpu
${{ env.TRITON_TEST_CMD }} --tutorial --skip-pip-install
- name: Regression tests
- name: Run CXX unittests
run: |
source ./scripts/pytest-utils.sh
cd python/test/regression
TRITON_TEST_SUITE=regression \
pytest -vvv -s --device xpu . --reruns 10 --ignore=test_performance.py
${{ env.TRITON_TEST_CMD }} --unit --skip-pip-install
- name: Run Tutorials
- name: Run instrumentation tests
run: |
source ./scripts/pytest-utils.sh
cd python/tutorials
run_tutorial_test "01-vector-add"
run_tutorial_test "02-fused-softmax"
run_tutorial_test "03-matrix-multiplication"
run_tutorial_test "04-low-memory-dropout"
run_tutorial_test "05-layer-norm"
run_tutorial_test "06-fused-attention"
run_tutorial_test "07-extern-functions"
run_tutorial_test "08-grouped-gemm"
run_tutorial_test "10-experimental-block-pointer"
run_tutorial_test "10i-experimental-block-pointer"
${{ env.TRITON_TEST_CMD }} --instrumentation --skip-pip-install
- name: Run CXX unittests
- name: Clear cache
run: |
cd python/build/*cmake*
ctest
rm -rf ~/.triton
- name: Get transformers version
run: |
@@ -271,16 +204,18 @@ jobs:
- name: Pass rate
run: |
source ./scripts/capture-hw-details.sh
python3 scripts/pass_rate.py --reports reports
python3 scripts/pass_rate.py --reports reports --json > pass_rate.json
python3 scripts/pass_rate.py --reports reports --suite tutorials --json > pass_rate_tutorials.json
- name: Upload pass rate report
# upload reports only for the default branch
if: github.ref_name == 'llvm-target' || github.ref_name == 'main'
uses: actions/upload-artifact@v4
with:
name: pass_rate-${{ inputs.python_version }}-${{ inputs.runner_label || inputs.driver_version }}
path: pass_rate.json
path: pass_rate*.json

- name: Upload test reports
if: inputs.upload_test_reports
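
For orientation: after this change every test step reuses one command line stored in TRITON_TEST_CMD and only varies the suite selector. A minimal sketch of what the "Run core tests" step expands to at run time, assuming the flag set shown in the diff and an illustrative skip-list directory (the real one is derived from inputs.skip_list or the driver version):

# Sketch only: the reports directory and skip list come from $GITHUB_WORKSPACE and the workflow inputs.
bash -v -x scripts/test-triton.sh \
  --warning-reports --skip-pytorch-install \
  --reports-dir "$GITHUB_WORKSPACE/reports" \
  --ignore-errors \
  --skip-list "$GITHUB_WORKSPACE/scripts/skiplist/default" \
  --core
# The later steps append --interpreter, --tutorial, --unit or --instrumentation,
# each together with --skip-pip-install, instead of --core.
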
2 changes: 1 addition & 1 deletion scripts/check-update-translator-cid.sh
@@ -32,7 +32,7 @@ for cid in $COMMIT_IDS; do
fi

# execute full tests
if ./scripts/test-triton.sh --skip-deps; then
if ./scripts/test-triton.sh --skip-pytorch; then
echo "Tests passed for translator commit $cid"
echo "A newer commit found: $cid"
FOUND=true
75 changes: 52 additions & 23 deletions scripts/pass_rate.py
@@ -23,6 +23,20 @@ class ReportStats:
fixme: int = 0
total: int = 0

@property
def pass_rate(self):
"""Pass rate."""
if self.total == 0:
return 0.0
return round(100 * self.passed / self.total, 2)

@property
def pass_rate_without_xfailed(self):
"""Pass rate without xfailed."""
if self.total - self.xfailed == 0:
return 0.0
return round(100 * self.passed / (self.total - self.xfailed), 2)


def create_argument_parser() -> argparse.ArgumentParser:
"""Creates ArgumentParser."""
@@ -38,6 +52,12 @@ def create_argument_parser() -> argparse.ArgumentParser:
action='store_true',
help='print stats in JSON',
)
argument_parser.add_argument(
'--suite',
type=str,
default='all',
help='name of the test suite, default: %(default)s',
)
return argument_parser


Expand Down Expand Up @@ -103,6 +123,14 @@ def overall_stats(stats: List[ReportStats]) -> ReportStats:
return overall


def find_stats(stats: List[ReportStats], name: str) -> ReportStats:
"""Finds stats by name."""
for item in stats:
if item.name == name:
return item
raise ValueError(f'{name} not found')


def parse_junit_reports(reports_path: pathlib.Path) -> List[ReportStats]:
"""Parses junit report in the specified directory."""
return [parse_report(report) for report in reports_path.glob('*.xml')]
@@ -120,15 +148,15 @@ def parse_tutorials_reports(reports_path: pathlib.Path) -> List[ReportStats]:
stats.skipped += 1
elif result == 'FAIL':
stats.failed += 1
return [stats] if stats.total > 0 else []
return [stats]


def parse_reports(reports_path: pathlib.Path) -> List[ReportStats]:
"""Parses all report in the specified directory."""
return parse_junit_reports(reports_path) + parse_tutorials_reports(reports_path)


def print_stats(stats: ReportStats):
def print_text_stats(stats: ReportStats):
"""Prints report stats."""
print(
f'{stats.name}:'
@@ -138,20 +166,12 @@ def print_stats(stats: ReportStats):
f' xfailed: {stats.xfailed},'
f' total: {stats.total},'
f' fixme: {stats.fixme},'
f' pass rate (w/o xfailed): {round(100 * stats.passed / (stats.total - stats.xfailed), 2)}%'
f' pass rate (w/o xfailed): {stats.pass_rate_without_xfailed}%'
) # yapf: disable


def print_text_stats(stats: List[ReportStats]):
"""Prints human readable stats."""
for item in sorted(stats, key=lambda x: x.name):
print_stats(item)
print_stats(overall_stats(stats))


def print_json_stats(stats: List[ReportStats]):
def print_json_stats(stats: ReportStats):
"""Print JSON stats."""
overall = overall_stats(stats)
data = {
'ts': datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S'),
'git_ref': os.getenv('GITHUB_REF_NAME', ''),
@@ -162,15 +182,15 @@ def print_json_stats(stats: List[ReportStats]):
'gpu_device': os.getenv('GPU_DEVICE', ''),
'python_version': platform.python_version(),
'pytorch_version': os.getenv('PYTORCH_VERSION', ''),
'testsuite': overall.name,
'passed': overall.passed,
'failed': overall.failed,
'skipped': overall.skipped,
'xfailed': overall.xfailed,
'total': overall.total,
'fixme': overall.fixme,
'pass_rate_1': round(100 * overall.passed / overall.total, 2),
'pass_rate_2': round(100 * overall.passed / (overall.total - overall.xfailed), 2)
'testsuite': stats.name,
'passed': stats.passed,
'failed': stats.failed,
'skipped': stats.skipped,
'xfailed': stats.xfailed,
'total': stats.total,
'fixme': stats.fixme,
'pass_rate_1': stats.pass_rate,
'pass_rate_2': stats.pass_rate_without_xfailed,
} # yapf: disable
print(json.dumps(data, indent=2))

@@ -179,10 +199,19 @@ def main():
"""Main."""
args = create_argument_parser().parse_args()
stats = parse_reports(pathlib.Path(args.reports))

if args.suite == 'all':
summary = overall_stats(stats)
else:
summary = find_stats(stats, args.suite)

if args.json:
print_json_stats(stats)
print_json_stats(summary)
else:
print_text_stats(stats)
if args.suite == 'all':
for item in sorted(stats, key=lambda x: x.name):
print_text_stats(item)
print_text_stats(summary)


if __name__ == '__main__':
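
The new --suite option reports a single suite instead of the aggregate, which is what the updated "Pass rate" workflow step uses for the tutorials breakdown. A usage sketch mirroring those commands (the suite name 'tutorials' comes from the workflow; other names depend on the junit reports present in the reports directory):

# Aggregate stats over every report, human readable (the previous default behaviour):
python3 scripts/pass_rate.py --reports reports
# Aggregate stats as JSON:
python3 scripts/pass_rate.py --reports reports --json > pass_rate.json
# Stats for the tutorials suite only, as JSON:
python3 scripts/pass_rate.py --reports reports --suite tutorials --json > pass_rate_tutorials.json

As a worked example for the new properties: with, say, passed = 950, xfailed = 20 and total = 1000, pass_rate evaluates to 95.0 and pass_rate_without_xfailed to 100 * 950 / 980 ≈ 96.94; these feed the pass_rate_1 and pass_rate_2 fields of the JSON output.
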
5 changes: 5 additions & 0 deletions scripts/pytest-utils.sh
@@ -15,6 +15,11 @@ TRITON_TEST_SKIPLIST_DIR="$(cd "$TRITON_TEST_SKIPLIST_DIR" && pwd)"
# absolute path for the current skip list
CURRENT_SKIPLIST_DIR="$SCRIPTS_DIR/skiplist/current"

err() {
echo $@
exit 1
}

pytest() {
pytest_extra_args=()

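
The new err helper prints its arguments and exits with status 1, giving the rest of the script a one-line way to fail fast (note that, as written, the message goes to stdout). A hypothetical call site; the variable and message below are illustrative, not taken from the script:

# Hypothetical usage: abort early if the reports directory was never created.
[ -d "$TRITON_TEST_REPORTS_DIR" ] || err "reports dir '$TRITON_TEST_REPORTS_DIR' not found"
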
(15 more changed files not shown)
