diff --git a/pybamm/parameters/parameter_sets.py b/pybamm/parameters/parameter_sets.py index 1da7f239dd..6c6201d9af 100644 --- a/pybamm/parameters/parameter_sets.py +++ b/pybamm/parameters/parameter_sets.py @@ -1,3 +1,4 @@ +import sys import warnings import importlib.metadata import textwrap @@ -37,9 +38,17 @@ class ParameterSets(Mapping): def __init__(self): # Dict of entry points for parameter sets, lazily load entry points as self.__all_parameter_sets = dict() - for entry_point in importlib.metadata.entry_points()["pybamm_parameter_sets"]: + for entry_point in self.get_entries("pybamm_parameter_sets"): self.__all_parameter_sets[entry_point.name] = entry_point + @staticmethod + def get_entries(group_name): + # Wrapper for the importlib version logic + if sys.version_info < (3, 10): # pragma: no cover + return importlib.metadata.entry_points()[group_name] + else: + return importlib.metadata.entry_points(group=group_name) + def __new__(cls): """Ensure only one instance of ParameterSets exists""" if not hasattr(cls, "instance"): diff --git a/pybamm/util.py b/pybamm/util.py index b0fa9c822e..562352bfac 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -18,7 +18,7 @@ from warnings import warn import numpy as np -import pkg_resources +import importlib.metadata import pybamm @@ -271,9 +271,10 @@ def have_jax(): def is_jax_compatible(): """Check if the available version of jax and jaxlib are compatible with PyBaMM""" - return pkg_resources.get_distribution("jax").version.startswith( - JAX_VERSION - ) and pkg_resources.get_distribution("jaxlib").version.startswith(JAXLIB_VERSION) + return ( + importlib.metadata.distribution("jax").version.startswith(JAX_VERSION) + and importlib.metadata.distribution("jaxlib").version.startswith(JAXLIB_VERSION) + ) def is_constant_and_can_evaluate(symbol): diff --git a/tests/unit/test_parameters/test_parameter_sets_class.py b/tests/unit/test_parameters/test_parameter_sets_class.py index f548fd7955..b14000f987 100644 --- a/tests/unit/test_parameters/test_parameter_sets_class.py +++ b/tests/unit/test_parameters/test_parameter_sets_class.py @@ -4,7 +4,6 @@ from tests import TestCase import pybamm -import pkg_resources import unittest @@ -25,7 +24,7 @@ def test_all_registered(self): """Check that all parameter sets have been registered with the ``pybamm_parameter_sets`` entry point""" known_entry_points = set( - ep.name for ep in pkg_resources.iter_entry_points("pybamm_parameter_sets") + ep.name for ep in pybamm.parameter_sets.get_entries("pybamm_parameter_sets") ) self.assertEqual(set(pybamm.parameter_sets.keys()), known_entry_points) self.assertEqual(len(known_entry_points), len(pybamm.parameter_sets))