Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace deprecated pkg_resources #3335

Merged
merged 10 commits into from
Sep 22, 2023
11 changes: 10 additions & 1 deletion pybamm/parameters/parameter_sets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import warnings
import importlib.metadata
import textwrap
Expand Down Expand Up @@ -37,9 +38,17 @@ class ParameterSets(Mapping):
def __init__(self):
# Dict of entry points for parameter sets, lazily load entry points as
self.__all_parameter_sets = dict()
for entry_point in importlib.metadata.entry_points()["pybamm_parameter_sets"]:
kratman marked this conversation as resolved.
Show resolved Hide resolved
for entry_point in self.get_entries("pybamm_parameter_sets"):
self.__all_parameter_sets[entry_point.name] = entry_point

@staticmethod
def get_entries(group_name):
# Wrapper for the importlib version logic
if sys.version_info < (3, 10): # pragma: no cover
return importlib.metadata.entry_points()[group_name]
kratman marked this conversation as resolved.
Show resolved Hide resolved
else:
return importlib.metadata.entry_points(group=group_name)

def __new__(cls):
"""Ensure only one instance of ParameterSets exists"""
if not hasattr(cls, "instance"):
Expand Down
9 changes: 5 additions & 4 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from warnings import warn

import numpy as np
import pkg_resources
import importlib.metadata

import pybamm

Expand Down Expand Up @@ -271,9 +271,10 @@ def have_jax():

def is_jax_compatible():
"""Check if the available version of jax and jaxlib are compatible with PyBaMM"""
return pkg_resources.get_distribution("jax").version.startswith(
JAX_VERSION
) and pkg_resources.get_distribution("jaxlib").version.startswith(JAXLIB_VERSION)
return (
importlib.metadata.distribution("jax").version.startswith(JAX_VERSION)
and importlib.metadata.distribution("jaxlib").version.startswith(JAXLIB_VERSION)
)


def is_constant_and_can_evaluate(symbol):
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_parameters/test_parameter_sets_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from tests import TestCase

import pybamm
import pkg_resources
import unittest


Expand All @@ -25,7 +24,7 @@ def test_all_registered(self):
"""Check that all parameter sets have been registered with the
``pybamm_parameter_sets`` entry point"""
known_entry_points = set(
ep.name for ep in pkg_resources.iter_entry_points("pybamm_parameter_sets")
ep.name for ep in pybamm.parameter_sets.get_entries("pybamm_parameter_sets")
)
self.assertEqual(set(pybamm.parameter_sets.keys()), known_entry_points)
self.assertEqual(len(known_entry_points), len(pybamm.parameter_sets))
Expand Down
Loading