Skip to content

Commit

Permalink
update to check_config workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-rakowski committed Jan 3, 2024
1 parent 4243c33 commit c776b5e
Showing 1 changed file with 176 additions and 96 deletions.
272 changes: 176 additions & 96 deletions py4DSTEM/utils/configuration_checker.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,96 @@
#### this file contains a function/s that will check if various
# libaries/compute options are available
import importlib
from operator import mod

# list of modules we expect/may expect to be installed
# as part of a standard py4DSTEM installation
# this needs to be the import name e.g. import mp_api not mp-api
modules = [
"crystal4D",
"cupy",
"dask",
"dill",
"distributed",
"gdown",
"h5py",
"ipyparallel",
"jax",
"matplotlib",
"mp_api",
"ncempy",
"numba",
"numpy",
"pymatgen",
"skimage",
"sklearn",
"scipy",
"tensorflow",
"tensorflow-addons",
"tqdm",
]

# currently this was copy and pasted from setup.py,
# hopefully there's a programatic way to do this.
module_depenencies = {
"base": [
"numpy",
"scipy",
"h5py",
"ncempy",
"matplotlib",
"skimage",
"sklearn",
"tqdm",
"dill",
"gdown",
"dask",
"distributed",
],
"ipyparallel": ["ipyparallel", "dill"],
"cuda": ["cupy"],
"acom": ["pymatgen", "mp_api"],
"aiml": ["tensorflow", "tensorflow-addons", "crystal4D"],
"aiml-cuda": ["tensorflow", "tensorflow-addons", "crystal4D", "cupy"],
"numba": ["numba"],
from importlib.metadata import requires
import re
from importlib.util import find_spec

# need a mapping of pypi/conda names to import names
import_mapping_dict = {
"scikit-image": "skimage",
"scikit-learn": "sklearn",
"scikit-optimize": "skopt",
"mp-api": "mp_api",
}


# programatically get all possible requirements in the import name style
def get_modules_list():
# Get the dependencies from the installed distribution
dependencies = requires("py4DSTEM")

# Define a regular expression pattern for splitting on '>', '>=', '='
delimiter_pattern = re.compile(r">=|>|==|<|<=")

# Extract only the module names without versions
module_names = [
delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip()
for dependency in dependencies
]

# translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
for index, module in enumerate(module_names):
if module in import_mapping_dict.keys():
module_names[index] = import_mapping_dict[module]

return module_names


# programatically get all possible requirements in the import name style,
# split into a dict where optional import names are keys
def get_modules_dict():
package_name = "py4DSTEM"
# Get the dependencies from the installed distribution
dependencies = requires(package_name)

# set the dictionary for modules and packages to go into
# optional dependencies will be added as they are discovered
modules_dict = {
"base": [],
}
# loop over the dependencies
for depend in dependencies:
# all the optional have extra in the name
# if its not there append it to base
if "extra" not in depend:
# String looks like: 'numpy>=1.19'
modules_dict["base"].append(depend)

# if it has extra in the string
else:
# get the name of the optional name
# depend looks like this 'numba>=0.49.1; extra == "numba"'
# grab whatever is in the double quotes i.e. numba
optional_name = re.search(r'"(.*?)"', depend).group(1)
# if the optional name is not in the dict as a key i.e. first requirement of hte optional dependency
if optional_name not in modules_dict:
modules_dict[optional_name] = [depend]
# if the optional_name is already in the dict then just append it to the list
else:
modules_dict[optional_name].append(depend)
# STRIP all the versioning and semi-colons
# Define a regular expression pattern for splitting on '>', '>=', '='
delimiter_pattern = re.compile(r">=|>|==|<|<=")
for key, val in modules_dict.items():
# modules_dict[key] = [dependency.split(';')[0].split(' ')[0] for dependency in val]
modules_dict[key] = [
delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip()
for dependency in val
]

# translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
for key, val in modules_dict.items():
for index, module in enumerate(val):
if module in import_mapping_dict.keys():
val[index] = import_mapping_dict[module]

return modules_dict


# module_depenencies = get_modules_dict()
modules = get_modules_list()


#### Class and Functions to Create Coloured Strings ####
class colours:
CEND = "\x1b[0m"
Expand Down Expand Up @@ -140,6 +175,7 @@ def create_underline(s: str) -> str:
### here I use the term state to define a boolean condition as to whether a libary/module was sucessfully imported/can be used


# get the state of each modules as a dict key-val e.g. "numpy" : True
def get_import_states(modules: list = modules) -> dict:
"""
Check the ability to import modules and store the results as a boolean value. Returns as a dict.
Expand All @@ -163,16 +199,17 @@ def get_import_states(modules: list = modules) -> dict:
return import_states_dict


# Check
def get_module_states(state_dict: dict) -> dict:
"""_summary_
Args:
state_dict (dict): _description_
"""
given a state dict for all modules e.g. "numpy" : True,
this parses through and checks if all modules required for a state are true
Returns:
dict: _description_
returns dict "base": True, "ai-ml": False etc.
"""

# get the modules_dict
module_depenencies = get_modules_dict()
# create an empty dict to put module states into:
module_states = {}

Expand All @@ -196,13 +233,12 @@ def get_module_states(state_dict: dict) -> dict:


def print_import_states(import_states: dict) -> None:
"""_summary_
Args:
import_states (dict): _description_
"""
print with colours if the library could be imported or not
takes dict
"numpy" : True -> prints success
"pymatgen" : False -> prints failure
Returns:
_type_: _description_
"""
# m is the name of the import module
# state is whether it was importable
Expand All @@ -223,13 +259,11 @@ def print_import_states(import_states: dict) -> None:


def print_module_states(module_states: dict) -> None:
"""_summary_
Args:
module_states (dict): _description_
Returns:
_type_: _description_
"""
print with colours if all the imports required for module could be imported or not
takes dict
"base" : True -> prints success
"ai-ml" : Fasle -> prints failure
"""
# Print out the state of all the modules in colour code
# key is the name of a py4DSTEM Module
Expand All @@ -248,25 +282,33 @@ def print_module_states(module_states: dict) -> None:
return None


def perfrom_extra_checks(
def perform_extra_checks(
import_states: dict, verbose: bool, gratuitously_verbose: bool, **kwargs
) -> None:
"""_summary_
Args:
import_states (dict): _description_
verbose (bool): _description_
gratuitously_verbose (bool): _description_
import_states (dict): dict of modules and if they could be imported or not
verbose (bool): will show module states and all import states
gratuitously_verbose (bool): will run extra checks - Currently only for cupy
Returns:
_type_: _description_
"""

# print a output module
extra_checks_message = "Running Extra Checks"
extra_checks_message = create_bold(extra_checks_message)
print(f"{extra_checks_message}")
# For modules that import run any extra checks
if gratuitously_verbose:
# print a output module
extra_checks_message = "Running Extra Checks"
extra_checks_message = create_bold(extra_checks_message)
print(f"{extra_checks_message}")
# For modules that import run any extra checks
# get all the dependencies
dependencies = requires("py4DSTEM")
# Extract only the module names with versions
depends_with_requirements = [
dependency.split(";")[0] for dependency in dependencies
]
# print(depends_with_requirements)
# need to go from
for key, val in import_states.items():
if val:
# s = create_underline(key.capitalize())
Expand All @@ -281,7 +323,10 @@ def perfrom_extra_checks(
if gratuitously_verbose:
s = create_underline(key.capitalize())
print(s)
print_no_extra_checks(key)
# check
generic_versions(
key, depends_with_requires=depends_with_requirements
)
else:
pass

Expand All @@ -304,7 +349,7 @@ def import_tester(m: str) -> bool:
# try and import the module
try:
importlib.import_module(m)
except:
except Exception:
state = False

return state
Expand All @@ -324,6 +369,7 @@ def check_module_functionality(state_dict: dict) -> None:

# create an empty dict to put module states into:
module_states = {}
module_depenencies = get_modules_dict()

# key is the name of the module e.g. ACOM
# val is a list of its dependencies
Expand Down Expand Up @@ -359,6 +405,45 @@ def check_module_functionality(state_dict: dict) -> None:
#### ADDTIONAL CHECKS ####


def generic_versions(module: str, depends_with_requires: list[str]) -> None:
# module will be like numpy, skimage
# depends_with_requires look like: numpy >= 19.0, scikit-image
# get module_translated_name
# mapping scikit-image : skimage
for key, value in import_mapping_dict.items():
# if skimage == skimage get scikit-image
# print(f"{key = } - {value = } - {module = }")
if module in value:
module_depend_name = key
break
else:
# if cant find mapping set the search name to the same
module_depend_name = module
# print(f"{module_depend_name = }")
# find the requirement
for depend in depends_with_requires:
if module_depend_name in depend:
spec_required = depend
# print(f"{spec_required = }")
# get the version installed
spec_installed = find_spec(module)
if spec_installed is None:
s = f"{module} unable to import - {spec_required} required"
s = create_failure(s)
s = f"{s: <80}"
print(s)

else:
try:
version = importlib.metadata.version(module_depend_name)
except Exception:
version = "Couldn't test version"
s = f"{module} imported: {version = } - {spec_required} required"
s = create_warning(s)
s = f"{s: <80}"
print(s)


def check_cupy_gpu(gratuitously_verbose: bool, **kwargs):
"""
This function performs some additional tests which may be useful in
Expand All @@ -375,25 +460,18 @@ def check_cupy_gpu(gratuitously_verbose: bool, **kwargs):
# check that CUDA is detected correctly
cuda_availability = cp.cuda.is_available()
if cuda_availability:
s = " CUDA is Available "
s = f" CUDA is Available "
s = create_success(s)
s = f"{s: <80}"
print(s)
else:
s = " CUDA is Unavailable "
s = f" CUDA is Unavailable "
s = create_failure(s)
s = f"{s: <80}"
print(s)

# Count how many GPUs Cupy can detect
# probably should change this to a while loop ...
for i in range(24):
try:
d = cp.cuda.Device(i)
hasattr(d, "attributes")
except:
num_gpus_detected = i
break
num_gpus_detected = cp.cuda.runtime.getDeviceCount()

# print how many GPUs were detected, filter for a couple of special conditons
if num_gpus_detected == 0:
Expand Down Expand Up @@ -448,7 +526,9 @@ def print_no_extra_checks(m: str):


# dict of extra check functions
funcs_dict = {"cupy": check_cupy_gpu}
funcs_dict = {
"cupy": check_cupy_gpu,
}


#### main function used to check the configuration of the installation
Expand Down Expand Up @@ -493,10 +573,10 @@ def check_config(

print_import_states(states_dict)

perfrom_extra_checks(
perform_extra_checks(
import_states=states_dict,
verbose=verbose,
gratuitously_verbose=gratuitously_verbose,
)

return None
return None

0 comments on commit c776b5e

Please sign in to comment.