diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cade77684..e06cd3a93a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ - Serialisation added so models can be written to/read from JSON ([#3397](https://github.com/pybamm-team/PyBaMM/pull/3397)) - Added a `get_parameter_info` method for models and modified "print_parameter_info" functionality to extract all parameters and their type in a tabular and readable format ([#3584](https://github.com/pybamm-team/PyBaMM/pull/3584)) - Mechanical parameters are now a function of stoichiometry and temperature ([#3576](https://github.com/pybamm-team/PyBaMM/pull/3576)) +- Added `by_submodel` feature in `print_parameter_info` method to allow users to print parameters and types of submodels in a tabular and readable format ([#3628](https://github.com/pybamm-team/PyBaMM/pull/3628)) ## Bug fixes diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index b6b5a9b2da..7fe0535849 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -104,6 +104,7 @@ def __init__(self, name="Unnamed model"): self._algebraic = {} self._initial_conditions = {} self._boundary_conditions = {} + self._variables_by_submodel = {} self._variables = pybamm.FuzzyDict({}) self._events = [] self._concatenated_rhs = None @@ -422,83 +423,214 @@ def input_parameters(self): self._input_parameters = self._find_symbols(pybamm.InputParameter) return self._input_parameters - def get_parameter_info(self): + def get_parameter_info(self, by_submodel=False): """ Extracts the parameter information and returns it as a dictionary. To get a list of all parameter-like objects without extra information, use :py:attr:`model.parameters`. + + Parameters + ---------- + by_submodel : bool, optional + Whether to return the parameter info sub-model wise or not (default False) """ parameter_info = {} - parameters = self._find_symbols(pybamm.Parameter) - for param in parameters: - parameter_info[param.name] = (param, "Parameter") - - input_parameters = self._find_symbols(pybamm.InputParameter) - for input_param in input_parameters: - if not input_param.domain: - parameter_info[input_param.name] = (input_param, "InputParameter") - else: - parameter_info[input_param.name] = ( - input_param, - f"InputParameter in {input_param.domain}", + + if by_submodel: + for submodel_name, submodel_vars in self._variables_by_submodel.items(): + submodel_info = {} + for var_name, var_symbol in submodel_vars.items(): + if isinstance(var_symbol, pybamm.Parameter): + submodel_info[var_name] = (var_symbol, "Parameter") + elif isinstance(var_symbol, pybamm.InputParameter): + if not var_symbol.domain: + submodel_info[var_name] = (var_symbol, "InputParameter") + else: + submodel_info[var_name] = ( + var_symbol, + f"InputParameter in {var_symbol.domain}", + ) + elif isinstance(var_symbol, pybamm.FunctionParameter): + input_names = "', '".join(var_symbol.input_names) + submodel_info[var_name] = ( + var_symbol, + f"FunctionParameter with inputs(s) '{input_names}'", + ) + + parameters = self._find_symbols_by_submodel( + pybamm.Parameter, submodel_name ) + for param in parameters: + submodel_info[param.name] = (param, "Parameter") - function_parameters = self._find_symbols(pybamm.FunctionParameter) - for func_param in function_parameters: - if func_param.name not in parameter_info: - input_names = "', '".join(func_param.input_names) - parameter_info[func_param.name] = ( - func_param, - f"FunctionParameter with inputs(s) '{input_names}'", + input_parameters = self._find_symbols_by_submodel( + pybamm.InputParameter, submodel_name ) + for input_param in input_parameters: + if not input_param.domain: + submodel_info[input_param.name] = ( + input_param, + "InputParameter", + ) + else: + submodel_info[input_param.name] = ( + input_param, + f"InputParameter in {input_param.domain}", + ) - return parameter_info + function_parameters = self._find_symbols_by_submodel( + pybamm.FunctionParameter, submodel_name + ) + for func_param in function_parameters: + if func_param.name not in parameter_info: + input_names = "', '".join(func_param.input_names) + submodel_info[func_param.name] = ( + func_param, + f"FunctionParameter with inputs(s) '{input_names}'", + ) + + parameter_info[submodel_name] = submodel_info + + else: + parameters = self._find_symbols(pybamm.Parameter) + for param in parameters: + parameter_info[param.name] = (param, "Parameter") + + input_parameters = self._find_symbols(pybamm.InputParameter) + for input_param in input_parameters: + if not input_param.domain: + parameter_info[input_param.name] = (input_param, "InputParameter") + else: + parameter_info[input_param.name] = ( + input_param, + f"InputParameter in {input_param.domain}", + ) - def print_parameter_info(self): - """Print parameter information in a formatted table from a dictionary of parameters""" - info = self.get_parameter_info() - max_param_name_length = 0 - max_param_type_length = 0 + function_parameters = self._find_symbols(pybamm.FunctionParameter) + for func_param in function_parameters: + if func_param.name not in parameter_info: + input_names = "', '".join(func_param.input_names) + parameter_info[func_param.name] = ( + func_param, + f"FunctionParameter with inputs(s) '{input_names}'", + ) - for param, param_type in info.values(): - param_name_length = len(getattr(param, "name", str(param))) - param_type_length = len(param_type) - max_param_name_length = max(max_param_name_length, param_name_length) - max_param_type_length = max(max_param_type_length, param_type_length) + return parameter_info - header_format = ( - f"| {{:<{max_param_name_length}}} | {{:<{max_param_type_length}}} |" + def _calculate_max_lengths(self, parameter_dict): + """ + Calculate the maximum length of parameters and parameter type in a dictionary + + Parameters + ---------- + parameter_dict : dict + The dict from which maximum lengths are calculated + """ + max_name_length = max( + len(getattr(parameter, "name", str(parameter))) + for parameter, _ in parameter_dict.values() ) - row_format = ( - f"| {{:<{max_param_name_length}}} | {{:<{max_param_type_length}}} |" + max_type_length = max( + len(parameter_type) for _, parameter_type in parameter_dict.values() ) - table = [ - header_format.format("Parameter", "Type of parameter"), - header_format.format( - "=" * max_param_name_length, "=" * max_param_type_length - ), + return max_name_length, max_type_length + + def _format_table_row( + self, param_name, param_type, max_name_length, max_type_length + ): + """ + Format the parameter information in a formatted table + + Parameters + ---------- + param_name : str + The name of the parameter + param_type : str + The type of the parameter + max_name_length : int + The maximum length of the parameter in the dictionary + max_type_length : int + The maximum length of the parameter type in the dictionary + """ + param_name_lines = [ + param_name[i : i + max_name_length] + for i in range(0, len(param_name), max_name_length) ] + param_type_lines = [ + param_type[i : i + max_type_length] + for i in range(0, len(param_type), max_type_length) + ] + max_lines = max(len(param_name_lines), len(param_type_lines)) - for param, param_type in info.values(): - param_name = getattr(param, "name", str(param)) - param_name_lines = [ - param_name[i : i + max_param_name_length] - for i in range(0, len(param_name), max_param_name_length) - ] - param_type_lines = [ - param_type[i : i + max_param_type_length] - for i in range(0, len(param_type), max_param_type_length) + return [ + f"| {param_name_lines[i]:<{max_name_length}} | {param_type_lines[i]:<{max_type_length}} |" + for i in range(max_lines) + ] + + def print_parameter_info(self, by_submodel=False): + """ + Print parameter information in a formatted table from a dictionary of parameters + + Parameters + ---------- + by_submodel : bool, optional + Whether to print the parameter info sub-model wise or not (default False) + """ + + if by_submodel: + parameter_info = self.get_parameter_info(by_submodel=True) + for submodel_name, submodel_vars in parameter_info.items(): + if not submodel_vars: + print(f"'{submodel_name}' submodel parameters: \nNo parameters\n") + else: + print(f"'{submodel_name}' submodel parameters:") + ( + max_param_name_length, + max_param_type_length, + ) = self._calculate_max_lengths(submodel_vars) + + table = [ + f"| {'Parameter':<{max_param_name_length}} | {'Type of parameter':<{max_param_type_length}} |", + f"| {'=' * max_param_name_length} | {'=' * max_param_type_length} |", + ] + + for param, param_type in submodel_vars.values(): + param_name = getattr(param, "name", str(param)) + table.extend( + self._format_table_row( + param_name, + param_type, + max_param_name_length, + max_param_type_length, + ) + ) + + print("\n".join(table) + "\n") + + else: + info = self.get_parameter_info() + max_param_name_length, max_param_type_length = self._calculate_max_lengths( + info + ) + + table = [ + f"| {'Parameter':<{max_param_name_length}} | {'Type of parameter':<{max_param_type_length}} |", + f"| {'=' * max_param_name_length} | {'=' * max_param_type_length} |", ] - max_lines = max(len(param_name_lines), len(param_type_lines)) - for i in range(max_lines): - param_line = param_name_lines[i] if i < len(param_name_lines) else "" - type_line = param_type_lines[i] if i < len(param_type_lines) else "" - table.append(row_format.format(param_line, type_line)) + for param, param_type in info.values(): + param_name = getattr(param, "name", str(param)) + table.extend( + self._format_table_row( + param_name, + param_type, + max_param_name_length, + max_param_type_length, + ) + ) - for line in table: - print(line) + print("\n".join(table) + "\n") def _find_symbols(self, typ): """Find all the instances of `typ` in the model""" @@ -517,6 +649,23 @@ def _find_symbols(self, typ): ) return list(all_input_parameters) + def _find_symbols_by_submodel(self, typ, submodel): + """Find all the instances of `typ` in the submodel""" + unpacker = pybamm.SymbolUnpacker(typ) + all_input_parameters = unpacker.unpack_list_of_symbols( + list(self.submodels[submodel].rhs.values()) + + list(self.submodels[submodel].algebraic.values()) + + list(self.submodels[submodel].initial_conditions.values()) + + [ + x[side][0] + for x in self.submodels[submodel].boundary_conditions.values() + for side in x.keys() + ] + + list(self._variables_by_submodel[submodel].values()) + + [event.expression for event in self.submodels[submodel].events] + ) + return list(all_input_parameters) + def new_copy(self): """ Creates a copy of the model, explicitly copying all the mutable attributes @@ -556,13 +705,18 @@ def update(self, *submodels): def build_fundamental(self): # Get the fundamental variables + self._variables_by_submodel = {submodel: {} for submodel in self.submodels} for submodel_name, submodel in self.submodels.items(): pybamm.logger.debug( "Getting fundamental variables for {} submodel ({})".format( submodel_name, self.name ) ) - self.variables.update(submodel.get_fundamental_variables()) + submodel_fundamental_variables = submodel.get_fundamental_variables() + self._variables_by_submodel[submodel_name].update( + submodel_fundamental_variables + ) + self.variables.update(submodel_fundamental_variables) self._built_fundamental = True @@ -586,9 +740,18 @@ def build_coupled_variables(self): ) ) try: - self.variables.update( - submodel.get_coupled_variables(self.variables) + model_var_copy = self.variables.copy() + updated_variables = submodel.get_coupled_variables( + self.variables + ) + self._variables_by_submodel[submodel_name].update( + { + key: updated_variables[key] + for key in updated_variables + if key not in model_var_copy + } ) + self.variables.update(updated_variables) submodels.remove(submodel_name) except KeyError as key: if len(submodels) == 1 or count == 100: diff --git a/pybamm/models/submodels/base_submodel.py b/pybamm/models/submodels/base_submodel.py index 225ae83705..f2c415cb64 100644 --- a/pybamm/models/submodels/base_submodel.py +++ b/pybamm/models/submodels/base_submodel.py @@ -70,7 +70,6 @@ def __init__( super().__init__(name) self.domain = domain self.name = name - self.external = external if options is None or type(options) == dict: # noqa: E721 @@ -135,6 +134,23 @@ def domain(self, domain): def domain_Domain(self): return self._domain, self._Domain + def get_parameter_info(self, by_submodel=False): + """ + Extracts the parameter information and returns it as a dictionary. + To get a list of all parameter-like objects without extra information, + use :py:attr:`model.parameters`. + + Returns + ------- + NotImplementedError: + This method is not available for direct use on submodels since the submodel may contain coupled variables + that depend on other submodels, in which case some parameters may be missed. + It is recommended to use on the full model. + """ + raise NotImplementedError( + "Cannot use get_parameter_info OR print_parameter_info directly on a submodel. Please use it on the full model." + ) + def get_fundamental_variables(self): """ A public method that creates and returns the variables in a submodel which can diff --git a/tests/unit/test_models/test_base_model.py b/tests/unit/test_models/test_base_model.py index 538765c48d..0e6db5e15f 100644 --- a/tests/unit/test_models/test_base_model.py +++ b/tests/unit/test_models/test_base_model.py @@ -6,6 +6,8 @@ import platform import subprocess # nosec import unittest +from io import StringIO +import sys import casadi import numpy as np @@ -160,6 +162,243 @@ def test_read_parameters(self): } model.print_parameter_info() + def test_get_parameter_info(self): + model = pybamm.BaseModel() + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b", "test") + c = pybamm.InputParameter("c") + d = pybamm.InputParameter("d") + e = pybamm.InputParameter("e") + f = pybamm.InputParameter("f") + g = pybamm.Parameter("g") + h = pybamm.Parameter("h") + i = pybamm.Parameter("i") + + u = pybamm.Variable("u") + v = pybamm.Variable("v") + model.rhs = {u: -u * a} + model.algebraic = {v: v - b} + model.initial_conditions = {u: c, v: d} + model.events = [pybamm.Event("u=e", u - e)] + model.variables = {"v+f+i": v + f + i} + model.boundary_conditions = { + u: {"left": (g, "Dirichlet"), "right": (0, "Neumann")}, + v: {"left": (0, "Dirichlet"), "right": (h, "Neumann")}, + } + + parameter_info = model.get_parameter_info() + self.assertEqual(parameter_info["a"][1], "InputParameter") + self.assertEqual(parameter_info["b"][1], "InputParameter in ['test']") + self.assertIn("c", parameter_info) + self.assertIn("d", parameter_info) + self.assertIn("e", parameter_info) + self.assertIn("f", parameter_info) + self.assertEqual(parameter_info["g"][1], "Parameter") + self.assertIn("h", parameter_info) + self.assertIn("i", parameter_info) + + def test_get_parameter_info_submodel(self): + submodel = pybamm.lithium_ion.SPM().submodels["electrolyte diffusion"] + + class SubModel1(pybamm.BaseSubModel): + def get_fundamental_variables(self): + u = pybamm.Variable("u") + + variables = {"u": u} + return variables + + def get_coupled_variables(self, variables): + x = pybamm.Parameter("x") + w = pybamm.InputParameter("w") + f = pybamm.InputParameter("f", "test") + variables.update({"w": w, "x": x, "f": f}) + return variables + + def set_rhs(self, variables): + a = pybamm.InputParameter("a") + u = variables["u"] + self.rhs = {u: -u * a} + + def set_boundary_conditions(self, variables): + g = pybamm.Parameter("g") + u = variables["u"] + self.boundary_conditions = { + u: {"left": (g, "Dirichlet"), "right": (0, "Neumann")}, + } + + def set_initial_conditions(self, variables): + c = pybamm.FunctionParameter("c", {}) + u = variables["u"] + self.initial_conditions = {u: c} + + def set_events(self, variables): + e = pybamm.InputParameter("e") + u = variables["u"] + self.events = [pybamm.Event("u=e", u - e)] + + class SubModel2(pybamm.BaseSubModel): + def get_fundamental_variables(self): + v = pybamm.Variable("v") + i = pybamm.FunctionParameter("i", {}) + variables = {"v": v, "i": i} + return variables + + def set_rhs(self, variables): + b = pybamm.InputParameter("b", "test") + v = variables["v"] + self.rhs = {v: v - b} + + def set_boundary_conditions(self, variables): + h = pybamm.Parameter("h") + v = variables["v"] + self.boundary_conditions = { + v: {"left": (0, "Dirichlet"), "right": (h, "Neumann")}, + } + + def set_initial_conditions(self, variables): + d = pybamm.FunctionParameter("d", {}) + v = variables["v"] + self.initial_conditions = {v: d} + + sub1 = SubModel1(None) + sub2 = SubModel2(None) + model = pybamm.BaseModel() + model.submodels = {"sub1": sub1, "sub2": sub2} + model.build_model() + + parameter_info = model.get_parameter_info(by_submodel=True) + + expected_error_message = "Cannot use get_parameter_info" + + with self.assertRaisesRegex(NotImplementedError, expected_error_message): + submodel.get_parameter_info(by_submodel=True) + + with self.assertRaisesRegex(NotImplementedError, expected_error_message): + submodel.get_parameter_info(by_submodel=False) + + self.assertIn("a", parameter_info["sub1"]) + self.assertIn("b", parameter_info["sub2"]) + self.assertEqual(parameter_info["sub1"]["a"][1], "InputParameter") + self.assertEqual(parameter_info["sub1"]["w"][1], "InputParameter") + self.assertEqual(parameter_info["sub1"]["e"][1], "InputParameter") + self.assertEqual(parameter_info["sub1"]["g"][1], "Parameter") + self.assertEqual(parameter_info["sub1"]["x"][1], "Parameter") + self.assertEqual(parameter_info["sub1"]["f"][1], "InputParameter in ['test']") + self.assertEqual(parameter_info["sub2"]["b"][1], "InputParameter in ['test']") + self.assertEqual(parameter_info["sub2"]["h"][1], "Parameter") + self.assertEqual( + parameter_info["sub1"]["c"][1], + "FunctionParameter with inputs(s) ''", + ) + self.assertEqual( + parameter_info["sub2"]["d"][1], + "FunctionParameter with inputs(s) ''", + ) + self.assertEqual( + parameter_info["sub2"]["i"][1], + "FunctionParameter with inputs(s) ''", + ) + + def test_print_parameter_info(self): + model = pybamm.BaseModel() + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b", "test") + c = pybamm.FunctionParameter("c", {}) + d = pybamm.FunctionParameter("d", {}) + e = pybamm.InputParameter("e") + f = pybamm.InputParameter("f") + g = pybamm.Parameter("g") + h = pybamm.Parameter("h") + i = pybamm.Parameter("i") + + u = pybamm.Variable("u") + v = pybamm.Variable("v") + + sub1 = pybamm.BaseSubModel(None) + sub1.rhs = {u: -u * a} + sub1.initial_conditions = {u: c} + sub1.variables = {"u": u} + sub1.boundary_conditions = { + u: {"left": (g, "Dirichlet"), "right": (0, "Neumann")}, + } + sub2 = pybamm.BaseSubModel(None) + sub2.algebraic = {v: v - b} + sub2.variables = {"v": v, "v+f+i": v + f + i} + sub2.initial_conditions = {v: d} + sub2.boundary_conditions = { + v: {"left": (0, "Dirichlet"), "right": (h, "Neumann")}, + } + sub3 = pybamm.BaseSubModel(None) + model.submodels = {"sub1": sub1, "sub2": sub2, "sub3": sub3} + model.events = [pybamm.Event("u=e", u - e)] + model.build_model() + captured_output = StringIO() + sys.stdout = captured_output + + model.print_parameter_info() + sys.stdout = sys.__stdout__ + + result = captured_output.getvalue().strip() + self.assertIn("a", result) + self.assertIn("b", result) + self.assertIn("InputParameter", result) + self.assertIn("InputParameter in ['test']", result) + self.assertIn("Parameter", result) + self.assertIn("FunctionParameter with inputs(s) ''", result) + + def test_print_parameter_info_submodel(self): + model = pybamm.BaseModel() + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b", "test") + c = pybamm.FunctionParameter("c", {}) + d = pybamm.FunctionParameter("d", {}) + e = pybamm.InputParameter("e") + f = pybamm.InputParameter("f") + g = pybamm.Parameter("g") + h = pybamm.Parameter("h") + i = pybamm.Parameter("i") + + u = pybamm.Variable("u") + v = pybamm.Variable("v") + + sub1 = pybamm.BaseSubModel(None) + sub1.rhs = {u: -u * a} + sub1.initial_conditions = {u: c} + sub1.variables = {"u": u} + sub1.boundary_conditions = { + u: {"left": (g, "Dirichlet"), "right": (0, "Neumann")}, + } + sub2 = pybamm.BaseSubModel(None) + sub2.algebraic = {v: v - b} + sub2.variables = {"v": v, "v+f+i": v + f + i} + sub2.initial_conditions = {v: d} + sub2.boundary_conditions = { + v: {"left": (0, "Dirichlet"), "right": (h, "Neumann")}, + } + sub3 = pybamm.BaseSubModel(None) + model.submodels = {"sub1": sub1, "sub2": sub2, "sub3": sub3} + model.events = [pybamm.Event("u=e", u - e)] + model.build_model() + captured_output = StringIO() + sys.stdout = captured_output + + model.print_parameter_info(by_submodel=True) + sys.stdout = sys.__stdout__ + + result = captured_output.getvalue().strip() + self.assertIn("'sub1' submodel parameters:", result) + self.assertIn("'sub2' submodel parameters:", result) + self.assertIn("Parameter", result) + self.assertIn("InputParameter", result) + self.assertIn("FunctionParameter with inputs(s) ''", result) + self.assertIn("InputParameter in ['test']", result) + self.assertIn("g", result) + self.assertIn("a", result) + self.assertIn("c", result) + self.assertIn("h", result) + self.assertIn("b", result) + self.assertIn("d", result) + def test_read_input_parameters(self): # Read input parameters from different parts of the model model = pybamm.BaseModel()