diff --git a/py4DSTEM/utils/configuration_checker.py b/py4DSTEM/utils/configuration_checker.py index 904dceb29..0382cda9c 100644 --- a/py4DSTEM/utils/configuration_checker.py +++ b/py4DSTEM/utils/configuration_checker.py @@ -1,61 +1,96 @@ #### this file contains a function/s that will check if various # libaries/compute options are available import importlib -from operator import mod - -# list of modules we expect/may expect to be installed -# as part of a standard py4DSTEM installation -# this needs to be the import name e.g. import mp_api not mp-api -modules = [ - "crystal4D", - "cupy", - "dask", - "dill", - "distributed", - "gdown", - "h5py", - "ipyparallel", - "jax", - "matplotlib", - "mp_api", - "ncempy", - "numba", - "numpy", - "pymatgen", - "skimage", - "sklearn", - "scipy", - "tensorflow", - "tensorflow-addons", - "tqdm", -] - -# currently this was copy and pasted from setup.py, -# hopefully there's a programatic way to do this. -module_depenencies = { - "base": [ - "numpy", - "scipy", - "h5py", - "ncempy", - "matplotlib", - "skimage", - "sklearn", - "tqdm", - "dill", - "gdown", - "dask", - "distributed", - ], - "ipyparallel": ["ipyparallel", "dill"], - "cuda": ["cupy"], - "acom": ["pymatgen", "mp_api"], - "aiml": ["tensorflow", "tensorflow-addons", "crystal4D"], - "aiml-cuda": ["tensorflow", "tensorflow-addons", "crystal4D", "cupy"], - "numba": ["numba"], +from importlib.metadata import requires +import re +from importlib.util import find_spec + +# need a mapping of pypi/conda names to import names +import_mapping_dict = { + "scikit-image": "skimage", + "scikit-learn": "sklearn", + "scikit-optimize": "skopt", + "mp-api": "mp_api", } +# programatically get all possible requirements in the import name style +def get_modules_list(): + # Get the dependencies from the installed distribution + dependencies = requires("py4DSTEM") + + # Define a regular expression pattern for splitting on '>', '>=', '=' + delimiter_pattern = re.compile(r">=|>|==|<|<=") + + # Extract only the module names without versions + module_names = [ + delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip() + for dependency in dependencies + ] + + # translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api + for index, module in enumerate(module_names): + if module in import_mapping_dict.keys(): + module_names[index] = import_mapping_dict[module] + + return module_names + + +# programatically get all possible requirements in the import name style, +# split into a dict where optional import names are keys +def get_modules_dict(): + package_name = "py4DSTEM" + # Get the dependencies from the installed distribution + dependencies = requires(package_name) + + # set the dictionary for modules and packages to go into + # optional dependencies will be added as they are discovered + modules_dict = { + "base": [], + } + # loop over the dependencies + for depend in dependencies: + # all the optional have extra in the name + # if its not there append it to base + if "extra" not in depend: + # String looks like: 'numpy>=1.19' + modules_dict["base"].append(depend) + + # if it has extra in the string + else: + # get the name of the optional name + # depend looks like this 'numba>=0.49.1; extra == "numba"' + # grab whatever is in the double quotes i.e. numba + optional_name = re.search(r'"(.*?)"', depend).group(1) + # if the optional name is not in the dict as a key i.e. first requirement of hte optional dependency + if optional_name not in modules_dict: + modules_dict[optional_name] = [depend] + # if the optional_name is already in the dict then just append it to the list + else: + modules_dict[optional_name].append(depend) + # STRIP all the versioning and semi-colons + # Define a regular expression pattern for splitting on '>', '>=', '=' + delimiter_pattern = re.compile(r">=|>|==|<|<=") + for key, val in modules_dict.items(): + # modules_dict[key] = [dependency.split(';')[0].split(' ')[0] for dependency in val] + modules_dict[key] = [ + delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip() + for dependency in val + ] + + # translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api + for key, val in modules_dict.items(): + for index, module in enumerate(val): + if module in import_mapping_dict.keys(): + val[index] = import_mapping_dict[module] + + return modules_dict + + +# module_depenencies = get_modules_dict() +modules = get_modules_list() + + #### Class and Functions to Create Coloured Strings #### class colours: CEND = "\x1b[0m" @@ -140,6 +175,7 @@ def create_underline(s: str) -> str: ### here I use the term state to define a boolean condition as to whether a libary/module was sucessfully imported/can be used +# get the state of each modules as a dict key-val e.g. "numpy" : True def get_import_states(modules: list = modules) -> dict: """ Check the ability to import modules and store the results as a boolean value. Returns as a dict. @@ -163,16 +199,17 @@ def get_import_states(modules: list = modules) -> dict: return import_states_dict +# Check def get_module_states(state_dict: dict) -> dict: - """_summary_ - - Args: - state_dict (dict): _description_ + """ + given a state dict for all modules e.g. "numpy" : True, + this parses through and checks if all modules required for a state are true - Returns: - dict: _description_ + returns dict "base": True, "ai-ml": False etc. """ + # get the modules_dict + module_depenencies = get_modules_dict() # create an empty dict to put module states into: module_states = {} @@ -196,13 +233,12 @@ def get_module_states(state_dict: dict) -> dict: def print_import_states(import_states: dict) -> None: - """_summary_ - - Args: - import_states (dict): _description_ + """ + print with colours if the library could be imported or not + takes dict + "numpy" : True -> prints success + "pymatgen" : False -> prints failure - Returns: - _type_: _description_ """ # m is the name of the import module # state is whether it was importable @@ -223,13 +259,11 @@ def print_import_states(import_states: dict) -> None: def print_module_states(module_states: dict) -> None: - """_summary_ - - Args: - module_states (dict): _description_ - - Returns: - _type_: _description_ + """ + print with colours if all the imports required for module could be imported or not + takes dict + "base" : True -> prints success + "ai-ml" : Fasle -> prints failure """ # Print out the state of all the modules in colour code # key is the name of a py4DSTEM Module @@ -248,25 +282,33 @@ def print_module_states(module_states: dict) -> None: return None -def perfrom_extra_checks( +def perform_extra_checks( import_states: dict, verbose: bool, gratuitously_verbose: bool, **kwargs ) -> None: """_summary_ Args: - import_states (dict): _description_ - verbose (bool): _description_ - gratuitously_verbose (bool): _description_ + import_states (dict): dict of modules and if they could be imported or not + verbose (bool): will show module states and all import states + gratuitously_verbose (bool): will run extra checks - Currently only for cupy Returns: _type_: _description_ """ - - # print a output module - extra_checks_message = "Running Extra Checks" - extra_checks_message = create_bold(extra_checks_message) - print(f"{extra_checks_message}") - # For modules that import run any extra checks + if gratuitously_verbose: + # print a output module + extra_checks_message = "Running Extra Checks" + extra_checks_message = create_bold(extra_checks_message) + print(f"{extra_checks_message}") + # For modules that import run any extra checks + # get all the dependencies + dependencies = requires("py4DSTEM") + # Extract only the module names with versions + depends_with_requirements = [ + dependency.split(";")[0] for dependency in dependencies + ] + # print(depends_with_requirements) + # need to go from for key, val in import_states.items(): if val: # s = create_underline(key.capitalize()) @@ -281,7 +323,10 @@ def perfrom_extra_checks( if gratuitously_verbose: s = create_underline(key.capitalize()) print(s) - print_no_extra_checks(key) + # check + generic_versions( + key, depends_with_requires=depends_with_requirements + ) else: pass @@ -304,7 +349,7 @@ def import_tester(m: str) -> bool: # try and import the module try: importlib.import_module(m) - except: + except Exception: state = False return state @@ -324,6 +369,7 @@ def check_module_functionality(state_dict: dict) -> None: # create an empty dict to put module states into: module_states = {} + module_depenencies = get_modules_dict() # key is the name of the module e.g. ACOM # val is a list of its dependencies @@ -359,6 +405,45 @@ def check_module_functionality(state_dict: dict) -> None: #### ADDTIONAL CHECKS #### +def generic_versions(module: str, depends_with_requires: list[str]) -> None: + # module will be like numpy, skimage + # depends_with_requires look like: numpy >= 19.0, scikit-image + # get module_translated_name + # mapping scikit-image : skimage + for key, value in import_mapping_dict.items(): + # if skimage == skimage get scikit-image + # print(f"{key = } - {value = } - {module = }") + if module in value: + module_depend_name = key + break + else: + # if cant find mapping set the search name to the same + module_depend_name = module + # print(f"{module_depend_name = }") + # find the requirement + for depend in depends_with_requires: + if module_depend_name in depend: + spec_required = depend + # print(f"{spec_required = }") + # get the version installed + spec_installed = find_spec(module) + if spec_installed is None: + s = f"{module} unable to import - {spec_required} required" + s = create_failure(s) + s = f"{s: <80}" + print(s) + + else: + try: + version = importlib.metadata.version(module_depend_name) + except Exception: + version = "Couldn't test version" + s = f"{module} imported: {version = } - {spec_required} required" + s = create_warning(s) + s = f"{s: <80}" + print(s) + + def check_cupy_gpu(gratuitously_verbose: bool, **kwargs): """ This function performs some additional tests which may be useful in @@ -375,25 +460,18 @@ def check_cupy_gpu(gratuitously_verbose: bool, **kwargs): # check that CUDA is detected correctly cuda_availability = cp.cuda.is_available() if cuda_availability: - s = " CUDA is Available " + s = f" CUDA is Available " s = create_success(s) s = f"{s: <80}" print(s) else: - s = " CUDA is Unavailable " + s = f" CUDA is Unavailable " s = create_failure(s) s = f"{s: <80}" print(s) # Count how many GPUs Cupy can detect - # probably should change this to a while loop ... - for i in range(24): - try: - d = cp.cuda.Device(i) - hasattr(d, "attributes") - except: - num_gpus_detected = i - break + num_gpus_detected = cp.cuda.runtime.getDeviceCount() # print how many GPUs were detected, filter for a couple of special conditons if num_gpus_detected == 0: @@ -448,7 +526,9 @@ def print_no_extra_checks(m: str): # dict of extra check functions -funcs_dict = {"cupy": check_cupy_gpu} +funcs_dict = { + "cupy": check_cupy_gpu, +} #### main function used to check the configuration of the installation @@ -493,10 +573,10 @@ def check_config( print_import_states(states_dict) - perfrom_extra_checks( + perform_extra_checks( import_states=states_dict, verbose=verbose, gratuitously_verbose=gratuitously_verbose, ) - return None + return None \ No newline at end of file