diff --git a/run.py b/run.py index f0a5f9a..c4a4eee 100755 --- a/run.py +++ b/run.py @@ -7,7 +7,6 @@ Created: 2021-11-12 Updated: 2022-06-17 """ - # Import standard libraries import argparse from datetime import datetime @@ -47,9 +46,9 @@ def find_myself(flg): as_cli_attr, as_cli_arg, correct_chirality, create_anatomical_average, crop_image, dict_has, dilate_LR_mask, exit_with_time_info, extract_from_json, get_and_make_preBIBSnet_work_dirs, get_optional_args_in, - get_stage_name, get_subj_ID_and_session, make_given_or_default_dir, - resize_images, run_all_stages, valid_readable_json, - validate_parameter_types, valid_readable_dir, will_run_stage + get_stage_name, get_subj_ID_and_session, get_template_age_closest_to, + make_given_or_default_dir, resize_images, run_all_stages, + valid_readable_json, validate_parameter_types, valid_readable_dir ) @@ -356,12 +355,22 @@ def run_BIBSnet(j_args, logger): #sys.path.append("/home/cabinet/SW/BIBSnet") #from BIBSnet.run import run_nnUNet_predict - # Run BIBSnet - run_nnUNet_predict({"model": j_args["BIBSnet"]["model"], - "nnUNet": j_args["BIBSnet"]["nnUNet_predict_path"], - "input": dir_BIBS.format("in"), - "output": dir_BIBS.format("out"), - "task": str(j_args["BIBSnet"]["task"])}) + try: # Run BIBSnet + inputs_BIBSnet = {"model": j_args["BIBSnet"]["model"], + "nnUNet": j_args["BIBSnet"]["nnUNet_predict_path"], + "input": dir_BIBS.format("in"), + "output": dir_BIBS.format("out"), + "task": str(j_args["BIBSnet"]["task"])} + run_nnUNet_predict(**inputs_BIBSnet) + except subprocess.CalledProcessError as e: + # BIBSnet will crash even after correctly creating a segmentation, + # so only crash CABINET if that segmentation is not made. + outfpath = os.path.join(dir_BIBS.format("out"), + "{}_{}_optimal_resized.nii.gz".format(*sub_ses)) + if not os.path.exists(outfpath): + logger.error("BIBSnet failed to create this segmentation file...\n{}\n...from these inputs:\n{}".format(outfpath, inputs_BIBSnet)) + sys.exit(e) + # TODO hardcoded below call to run_nnUNet_predict. Will likely want to change and integrate into j_args #run_nnUNet_predict({"model": "3d_fullres", @@ -384,12 +393,18 @@ def run_postBIBSnet(j_args, logger): # Template selection values age_months = j_args["common"]["age_months"] logger.info("Age of participant: {} months".format(age_months)) - if age_months > 33: - age_months = "34-38" + + # Get template closest to age + tmpl_age = get_template_age_closest_to( + age_months, os.path.join(SCRIPT_DIR, "data", "chirality_masks") + ) + if j_args["common"]["verbose"]: + logger.info("Closest template-age is {} months".format(tmpl_age)) + # if age_months > 33: age_months = "34-38" # Run left/right registration script and chirality correction left_right_mask_nifti_fpath = run_left_right_registration( - j_args, sub_ses, age_months, 2 if int(age_months) < 22 else 1, logger # NOTE 22 cutoff might change + j_args, sub_ses, tmpl_age, 2 if int(age_months) < 22 else 1, logger # NOTE 22 cutoff might change ) logger.info("Left/right image registration completed") @@ -463,13 +478,18 @@ def run_left_right_registration(j_args, sub_ses, age_months, t1or2, logger): msg = "{} left/right registration on {}" if (j_args["common"]["overwrite"] or not os.path.exists(left_right_mask_nifti_fpath)): - logger.info(msg.format("Running", first_subject_head)) + # logger.info(msg.format("Running", first_subject_head)) try: # SubjectHead TemplateHead TemplateMask OutputMaskFile - subprocess.check_call((LR_REGISTR_PATH, first_subject_head, - tmpl_head.format(age_months, t1or2), - tmpl_mask.format(age_months), - left_right_mask_nifti_fpath)) + cmd_LR_reg = (LR_REGISTR_PATH, first_subject_head, + tmpl_head.format(age_months, t1or2), + tmpl_mask.format(age_months), + left_right_mask_nifti_fpath) + if j_args["common"]["verbose"]: + logger.info(msg.format("Now running", "\n".join( + (first_subject_head, " ".join(cmd_LR_reg)) + ))) + subprocess.check_call(cmd_LR_reg) # Tell the user if ANTS crashes due to a memory error except subprocess.CalledProcessError as e: @@ -484,6 +504,7 @@ def run_left_right_registration(j_args, sub_ses, age_months, t1or2, logger): ))) logger.info(msg.format("Finished", first_subject_head)) # TODO Only print this message if not skipped (and do the same for all other stages) return left_right_mask_nifti_fpath + def run_chirality_correction(l_r_mask_nifti_fpath, j_args, logger): diff --git a/src/utilities.py b/src/utilities.py index 2c4ebc2..b30637c 100755 --- a/src/utilities.py +++ b/src/utilities.py @@ -614,6 +614,29 @@ def get_subj_ses(j_args): return "_".join(get_subj_ID_and_session(j_args)) +def get_template_age_closest_to(age, templates_dir): + template_ages = list() + template_ranges = dict() + + # Get list of all int ages (in months) that have template files + for tmpl_path in glob(os.path.join(templates_dir, + "*mo_template_LRmask.nii.gz")): + tmpl_age = os.path.basename(tmpl_path).split("mo", 1)[0] + if "-" in tmpl_age: # len(tmpl_age) <3: + for each_age in tmpl_age.split("-"): + template_ages.append(int(each_age)) + template_ranges[template_ages[-1]] = tmpl_age + # template_ages.append(int(tmpl_age.split("-"))) + else: + template_ages.append(int(tmpl_age)) + + # Get template age closest to subject age, then return template age + closest_age = template_ages[np.argmin(np.abs(np.array(template_ages)-age))] + return (template_ranges[closest_age] if closest_age + in template_ranges else str(closest_age)) #final_template_age + # template_ages = [os.path.basename(f).split("mo", 1)[0] for f in glob(globber)] + + def glob_and_copy(dest_dirpath, *path_parts_to_glob): """ Collect all files matching a glob string, then copy those files