Skip to content

Commit

Permalink
Fixes from review
Browse files Browse the repository at this point in the history
  • Loading branch information
kratman committed Sep 19, 2023
1 parent 446ada2 commit 4df4850
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
11 changes: 10 additions & 1 deletion pybamm/parameters/parameter_sets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import warnings
import importlib.metadata
import textwrap
Expand Down Expand Up @@ -37,9 +38,17 @@ class ParameterSets(Mapping):
def __init__(self):
# Dict of entry points for parameter sets, lazily load entry points as
self.__all_parameter_sets = dict()
for entry_point in importlib.metadata.entry_points()["pybamm_parameter_sets"]:
for entry_point in self.get_entries("pybamm_parameter_sets"):
self.__all_parameter_sets[entry_point.name] = entry_point

@staticmethod
def get_entries(group_name):
# Wrapper for the importlib version logic
if sys.version_info < (3, 10):
return importlib.metadata.entry_points()[group_name]
else:
return importlib.metadata.entry_points(group=group_name)

def __new__(cls):
"""Ensure only one instance of ParameterSets exists"""
if not hasattr(cls, "instance"):
Expand Down
6 changes: 3 additions & 3 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from warnings import warn

import numpy as np
import importlib.metadata as il
import importlib.metadata

import pybamm

Expand Down Expand Up @@ -272,8 +272,8 @@ def have_jax():
def is_jax_compatible():
"""Check if the available version of jax and jaxlib are compatible with PyBaMM"""
return (
il.distribution("jax").version.startswith(JAX_VERSION)
and il.distribution("jaxlib").version.startswith(JAXLIB_VERSION)
importlib.metadata.distribution("jax").version.startswith(JAX_VERSION)
and importlib.metadata.distribution("jaxlib").version.startswith(JAXLIB_VERSION)
)


Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_parameters/test_parameter_sets_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from tests import TestCase

import pybamm
import importlib.metadata as il
import unittest


Expand All @@ -25,7 +24,7 @@ def test_all_registered(self):
"""Check that all parameter sets have been registered with the
``pybamm_parameter_sets`` entry point"""
known_entry_points = set(
ep.name for ep in il.entry_points()["pybamm_parameter_sets"]
ep.name for ep in pybamm.parameter_sets.get_entries("pybamm_parameter_sets")
)
self.assertEqual(set(pybamm.parameter_sets.keys()), known_entry_points)
self.assertEqual(len(known_entry_points), len(pybamm.parameter_sets))
Expand Down

0 comments on commit 4df4850

Please sign in to comment.