Skip to content

Commit

Permalink
[security] fix 'assert' & 'xml' & 'md5' vulnerability in python files (
Browse files Browse the repository at this point in the history
…#1962)

Fix: #1961

## 1. Use raise exception instead of  `assert` in python files.
Test files are not included.

## 2. Replace `xml` with `defusedxml` 
## 3. Fix MD5 issue
~~In Triton code, md5 is used for hash key generation, not for security
purpose, so this PR adds `usedforsecurity=False` to solve severity issue
found by Bandit~~
Replace md5 with sha256
  • Loading branch information
AshburnLee authored Aug 23, 2024
1 parent 9b5b553 commit 2d59963
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-test-reusable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ jobs:

- name: Install test dependencies
run: |
pip install pytest pytest-xdist pytest-rerunfailures pytest-select pytest-timeout expecttest
pip install pytest pytest-xdist pytest-rerunfailures pytest-select pytest-timeout expecttest defusedxml
pip install git+https://github.com/kwasd/[email protected]
- name: Setup Triton
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/triton_kernels_benchmark/benchmark_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


def compile_module_from_src(src, name):
key = hashlib.md5(src.encode("utf-8")).hexdigest()
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
cache_path = cache.get_file(f"{name}.so")
if cache_path is None:
Expand Down
6 changes: 4 additions & 2 deletions scripts/build_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def parse_args():

def check_cols(target_cols, all_cols):
diff = set(target_cols).difference(all_cols)
assert (len(diff) == 0), f"Couldn't find required columns: '{diff}' among available '{all_cols}'"
if len(diff) != 0:
raise ValueError(f"Couldn't find required columns: '{diff}' among available '{all_cols}'")


def transform_df(df, param_cols, tflops_col, hbm_col, benchmark, compiler, tag):
Expand All @@ -48,7 +49,8 @@ def transform_df(df, param_cols, tflops_col, hbm_col, benchmark, compiler, tag):
n: os.getenv(n.upper(), default="")
for n in ["libigc1_version", "level_zero_version", "gpu_device", "agama_version"]
}
assert host_info['gpu_device'], "Could not find GPU device description, was capture_device.sh called?"
if not host_info["gpu_device"]:
raise RuntimeError(f"Could not find GPU device description, was capture_device.sh called?")
for name, val in host_info.items():
df_results[name] = val

Expand Down
4 changes: 2 additions & 2 deletions scripts/get_failed_cases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import xml.etree.ElementTree as ET
from defusedxml.ElementTree import parse
import argparse


Expand All @@ -12,7 +12,7 @@ def create_argument_parser() -> argparse.ArgumentParser:

def extract_failed_from_xml(in_file: str, out_file: str):
"""Process XML log file and output failed cases."""
root = ET.parse(in_file).getroot()
root = parse(in_file).getroot()
failed = []

for testcase in root.findall('.//testcase'):
Expand Down
4 changes: 2 additions & 2 deletions scripts/pass_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import pathlib
import platform
import xml.etree.ElementTree as ET
from defusedxml.ElementTree import parse
from typing import List


Expand Down Expand Up @@ -54,7 +54,7 @@ def get_deselected(report_path: pathlib.Path) -> int:
def parse_report(report_path: pathlib.Path) -> ReportStats:
"""Parses the specified report."""
stats = ReportStats(name=report_path.stem)
root = ET.parse(report_path).getroot()
root = parse(report_path).getroot()
for testsuite in root:
testsuite_fixme_tests = set()
stats.total += int(testsuite.get('tests'))
Expand Down
2 changes: 1 addition & 1 deletion scripts/test-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ export TRITON_PROJ=$BASE/intel-xpu-backend-for-triton
export TRITON_PROJ_BUILD=$TRITON_PROJ/python/build
export SCRIPTS_DIR=$(cd $(dirname "$0") && pwd)

python3 -m pip install lit pytest pytest-xdist pytest-rerunfailures pytest-select pytest-timeout setuptools==69.5.1
python3 -m pip install lit pytest pytest-xdist pytest-rerunfailures pytest-select pytest-timeout setuptools==69.5.1 defusedxml

if [ "$TRITON_TEST_WARNING_REPORTS" == true ]; then
python3 -m pip install git+https://github.com/kwasd/[email protected]
Expand Down
9 changes: 5 additions & 4 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def __post_init__(self):
extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH",
str(default_libdir / 'libsycl-spir64-unknown-unknown.bc'))
object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
"num_warps must be a power of 2"
if self.num_warps <= 0 or (self.num_warps & (self.num_warps - 1)) != 0:
raise AssertionError(f"num_warps must be a power of 2")

def hash(self):
key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
return hashlib.md5(key.encode("utf-8")).hexdigest()
return hashlib.sha256(key.encode("utf-8")).hexdigest()


def min_dot_size(device_props: dict):
Expand Down Expand Up @@ -117,7 +117,8 @@ def supports_target(target: tuple):

def __init__(self, target: tuple) -> None:
super().__init__(target)
assert isinstance(target.arch, dict)
if not isinstance(target.arch, dict):
raise TypeError(f"target.arch is not a dict")
self.properties = self.parse_target(target.arch)
self.binary_ext = "spv"

Expand Down
2 changes: 1 addition & 1 deletion third_party/intel/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


def compile_module_from_src(src, name):
key = hashlib.md5(src.encode("utf-8")).hexdigest()
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
cache_path = cache.get_file(f"{name}.so")
if cache_path is None:
Expand Down

0 comments on commit 2d59963

Please sign in to comment.