Skip to content

Commit

Permalink
using importlib to populate requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-rakowski committed Nov 13, 2023
1 parent 5275729 commit 96c96ae
Showing 1 changed file with 109 additions and 76 deletions.
185 changes: 109 additions & 76 deletions py4DSTEM/utils/configuration_checker.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,93 @@
#### this file contains a function/s that will check if various
# libaries/compute options are available
import importlib

# list of modules we expect/may expect to be installed
# as part of a standard py4DSTEM installation
# this needs to be the import name e.g. import mp_api not mp-api
# TODO use importlib.metadata.requirements to populate
modules = [
"crystal4D",
"cupy",
"dask",
"dill",
"distributed",
"gdown",
"h5py",
"ipyparallel",
"jax",
"matplotlib",
"mp_api",
"ncempy",
"numba",
"numpy",
"pymatgen",
"skimage",
"sklearn",
"scipy",
"tensorflow",
"tensorflow-addons",
"tqdm",
]

# currently this was copy and pasted from setup.py,
# hopefully there's a programatic way to do this.
module_depenencies = {
"base": [
"numpy",
"scipy",
"h5py",
"ncempy",
"matplotlib",
"skimage",
"sklearn",
"tqdm",
"dill",
"gdown",
"dask",
"distributed",
],
"ipyparallel": ["ipyparallel", "dill"],
"cuda": ["cupy"],
"acom": ["pymatgen", "mp_api"],
"aiml": ["tensorflow", "tensorflow-addons", "crystal4D"],
"aiml-cuda": ["tensorflow", "tensorflow-addons", "crystal4D", "cupy"],
"numba": ["numba"],
from importlib.metadata import requires
import re

# need a mapping of pypi/conda names to import names
import_mapping_dict = {
"scikit-image": "skimage",
"scikit-learn": "sklearn",
"scikit-optimize": "skopt",
"mp-api": "mp_api",
}


# programatically get all possible requirements in the import name style
def get_modules_list():
# Get the dependencies from the installed distribution
dependencies = requires("py4DSTEM")

# Define a regular expression pattern for splitting on '>', '>=', '='
delimiter_pattern = re.compile(r">=|>|==|<|<=")

# Extract only the module names without versions
module_names = [
delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip()
for dependency in dependencies
]

# translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
for index, module in enumerate(module_names):
if module in import_mapping_dict.keys():
module_names[index] = import_mapping_dict[module]

return module_names


# programatically get all possible requirements in the import name style,
# split into a dict where optional import names are keys
def get_modules_dict():
package_name = "py4DSTEM"
# Get the dependencies from the installed distribution
dependencies = requires(package_name)

# set the dictionary for modules and packages to go into
modules_dict = {
"base": [],
"acom": [],
"aiml": [],
"aiml-cuda": [],
"cuda": [],
"numba": [],
}
# loop over the dependencies
for depend in dependencies:
# all the optional have extra in the name
# if its not there append it to base
if "extra" not in depend:
modules_dict["base"].append(depend)

# if it has extra
else:
# loop over the keys and check if its in there
for key in modules_dict.keys():
if key in depend:
modules_dict[key].append(depend)

# STRIP all the versioning and semi-colons
# Define a regular expression pattern for splitting on '>', '>=', '='
delimiter_pattern = re.compile(r">=|>|==|<|<=")
for key, val in modules_dict.items():
# modules_dict[key] = [dependency.split(';')[0].split(' ')[0] for dependency in val]
modules_dict[key] = [
delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip()
for dependency in val
]

# translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
for key, val in modules_dict.items():
for index, module in enumerate(val):
if module in import_mapping_dict.keys():
val[index] = import_mapping_dict[module]

return modules_dict


# module_depenencies = get_modules_dict()
modules = get_modules_list()


#### Class and Functions to Create Coloured Strings ####
class colours:
CEND = "\x1b[0m"
Expand Down Expand Up @@ -140,6 +172,7 @@ def create_underline(s: str) -> str:
### here I use the term state to define a boolean condition as to whether a libary/module was sucessfully imported/can be used


# get the state of each modules as a dict key-val e.g. "numpy" : True
def get_import_states(modules: list = modules) -> dict:
"""
Check the ability to import modules and store the results as a boolean value. Returns as a dict.
Expand All @@ -163,16 +196,17 @@ def get_import_states(modules: list = modules) -> dict:
return import_states_dict


# Check
def get_module_states(state_dict: dict) -> dict:
"""_summary_
Args:
state_dict (dict): _description_
"""
given a state dict for all modules e.g. "numpy" : True,
this parses through and checks if all modules required for a state are true
Returns:
dict: _description_
returns dict "base": True, "ai-ml": False etc.
"""

# get the modules_dict
module_depenencies = get_modules_dict()
# create an empty dict to put module states into:
module_states = {}

Expand All @@ -190,19 +224,18 @@ def get_module_states(state_dict: dict) -> dict:

# check that all the depencies could be imported i.e. state == True
# and set the state of the module to that
module_states[key] = all(temp_lst) == True
module_states[key] = all(temp_lst) is True

return module_states


def print_import_states(import_states: dict) -> None:
"""_summary_
Args:
import_states (dict): _description_
"""
print with colours if the library could be imported or not
takes dict
"numpy" : True -> prints success
"pymatgen" : Fasle -> prints failure
Returns:
_type_: _description_
"""
# m is the name of the import module
# state is whether it was importable
Expand All @@ -223,13 +256,11 @@ def print_import_states(import_states: dict) -> None:


def print_module_states(module_states: dict) -> None:
"""_summary_
Args:
module_states (dict): _description_
Returns:
_type_: _description_
"""
print with colours if all the imports required for module could be imported or not
takes dict
"base" : True -> prints success
"ai-ml" : Fasle -> prints failure
"""
# Print out the state of all the modules in colour code
# key is the name of a py4DSTEM Module
Expand All @@ -254,9 +285,9 @@ def perfrom_extra_checks(
"""_summary_
Args:
import_states (dict): _description_
verbose (bool): _description_
gratuitously_verbose (bool): _description_
import_states (dict): dict of modules and if they could be imported or not
verbose (bool): will show module states and all import states
gratuitously_verbose (bool): will run extra checks - Currently only for cupy
Returns:
_type_: _description_
Expand Down Expand Up @@ -324,6 +355,7 @@ def check_module_functionality(state_dict: dict) -> None:

# create an empty dict to put module states into:
module_states = {}
module_depenencies = get_modules_dict()

# key is the name of the module e.g. ACOM
# val is a list of its dependencies
Expand All @@ -338,7 +370,7 @@ def check_module_functionality(state_dict: dict) -> None:

# check that all the depencies could be imported i.e. state == True
# and set the state of the module to that
module_states[key] = all(temp_lst) == True
module_states[key] = all(temp_lst) is True

# Print out the state of all the modules in colour code
for key, val in module_states.items():
Expand Down Expand Up @@ -421,6 +453,7 @@ def check_cupy_gpu(gratuitously_verbose: bool, **kwargs):
return None


# TODO add generic version which will print version
def print_no_extra_checks(m: str):
"""
This function prints a warning style message that the module m
Expand Down

0 comments on commit 96c96ae

Please sign in to comment.