Skip to content

Commit

Permalink
fix python black
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrench-Git committed Aug 1, 2024
1 parent 8a474a3 commit 7c3e4b4
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions dipu/scripts/ci/ci_run_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,20 @@ def run_cmd(cmd: str) -> None:


def parse_device_task(device):
#get json file path
# get json file path
device_config = dict()
current_path = os.path.dirname(os.path.realpath(__file__))
config_path = current_path+"/test_perf_config.json"
config_path = current_path + "/test_perf_config.json"
with open(config_path) as json_config:
json_content = json.loads(json_config.read())
if device in json_content:
device_config = json_content[device]
return device_config


def process_test_perf(log_file, clear_log, task: dict) -> None:
# READ CONFIG

task_name = task["name"]
storage_path = os.getcwd() + "/perf_data/" + task_name
partition = task["partition"]
Expand All @@ -47,20 +48,22 @@ def process_test_perf(log_file, clear_log, task: dict) -> None:
op_args = task["op_args"]

os.environ["ONE_ITER_TOOL_STORAGE_PATH"] = storage_path
os.environ["DIPU_FORCE_FALLBACK_OPS_LIST"] = task["fallback_op_list"] if "fallback_op_list" in task else ""
os.environ["DIPU_FORCE_FALLBACK_OPS_LIST"] = (
task["fallback_op_list"] if "fallback_op_list" in task else ""
)

logging.info(f"task_name = {task_name}")

if not os.path.exists(storage_path):
os.makedirs(storage_path)

#GENERATE RUN COMMAND
# GENERATE RUN COMMAND
cmd_run_test_perf = f"srun --job-name={job_name} --partition={partition} --gres={gpu_requests} python {task_script} {op_args}"
if device == "sco":
current_path = os.getcwd()
parent_directory = os.path.dirname(current_path)
cmd_run_test_perf = f"""srun --job-name={job_name} bash -c "cd {parent_directory} && source scripts/ci/ci_one_iter.sh export_pythonpath_cuda {current_path} && source /mnt/cache/share/deeplinkci/github/dipu_env && cd mmlab_pack && source environment_exported && export ONE_ITER_TOOL_STORAGE_PATH={storage_path} && python {current_path}/{task_script}" """

print(cmd_run_test_perf)

current_path = os.getcwd()
Expand All @@ -70,23 +73,25 @@ def process_test_perf(log_file, clear_log, task: dict) -> None:
else:
run_cmd(cmd_run_test_perf + f" 2>&1 >> {current_path}/{log_file}")
os.chdir(current_path)
print("MATCH_PATTERN:",filter_pattern)

print("MATCH_PATTERN:", filter_pattern)
import re

log_content = open(f"{current_path}/{log_file}").read()
pattern = re.compile(filter_pattern)
match_result = pattern.search(log_content)
run_perf = 0.0

if match_result:
match_result = match_result.group(0)
float_pattern = re.compile("\d+(\.\d+)?")
run_perf = float(float_pattern.search(match_result).group(0))
print("RUNNING PERF:{}".format(run_perf))


def run_perf_task(device_config):
error_flag = multiprocessing.Value("i", 0) # if encount error

device = device_config["name"]

logging.info("we use {}!".format(device))
Expand All @@ -96,20 +101,20 @@ def run_perf_task(device_config):
os.environ["DIPU_DUMP_OP_ARGS"] = "0"
os.environ["DIPU_DEBUG_ALLOCATOR"] = "0"
os.environ["ONE_ITER_TOOL_DEVICE"] = "dipu"

current_path = os.path.dirname(os.path.realpath(__file__))
env_file_path = os.path.join(current_path, "environment_exported")
env_variables = os.environ
keywords_to_filter = ["DIPU", "ONE_ITER"]
if os.path.exists(env_file_path):
os.remove(env_file_path)

with open("environment_exported", "w") as file:
file.write("pwd\n")
for key, value in env_variables.items():
if any(keyword in key for keyword in keywords_to_filter):
file.write(f'export {key}="{value}"\n')

tasks = device_config["tasks"]
logging.info(f"tasks nums: {len(tasks)}")

Expand All @@ -127,8 +132,8 @@ def run_perf_task(device_config):
process_test_perf,
args=(
log_file,
True,
task,
True,
task,
),
error_callback=handle_error,
)
Expand All @@ -144,6 +149,7 @@ def run_perf_task(device_config):
logging.error(e)
exit(1)


def handle_error(error: str) -> None:
logging.error(f"Error: {error}")
if pool is not None:
Expand Down Expand Up @@ -172,7 +178,5 @@ def print_file(file_name):
device_config = parse_device_task(device)
print(device_config)

logging.info(
f"device: {device}, job_name: {job_name}"
)
logging.info(f"device: {device}, job_name: {job_name}")
run_perf_task(device_config)

0 comments on commit 7c3e4b4

Please sign in to comment.