From c9a0ff128c8d51d1b9a08f5029918d11849dfaf8 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Tue, 16 Jan 2024 17:04:43 +0000 Subject: [PATCH] Fix coverage issues --- pybamm/expression_tree/array.py | 2 +- pybamm/expression_tree/broadcasts.py | 2 +- pybamm/expression_tree/concatenations.py | 7 ---- .../operations/unpack_symbols.py | 2 +- pybamm/expression_tree/parameter.py | 2 +- pybamm/expression_tree/symbol.py | 33 +++++++------------ 6 files changed, 16 insertions(+), 32 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index a5f1314e01..a433dec822 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -9,7 +9,7 @@ import pybamm from pybamm.util import have_optional_dependency -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover import sympy diff --git a/pybamm/expression_tree/broadcasts.py b/pybamm/expression_tree/broadcasts.py index 19943cbb17..8e22ecd487 100644 --- a/pybamm/expression_tree/broadcasts.py +++ b/pybamm/expression_tree/broadcasts.py @@ -59,7 +59,7 @@ def _diff(self, variable): # Differentiate the child and broadcast the result in the same way return self._unary_new_copy(self.child.diff(variable)) - def reduce_one_dimension(self): + def reduce_one_dimension(self): # pragma: no cover """Reduce the broadcast by one dimension.""" raise NotImplementedError diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 3b8eca9666..1a6d69d64d 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -592,10 +592,3 @@ def domain_concatenation(children: list[pybamm.Symbol], mesh: pybamm.Mesh): """Helper function to create domain concatenations.""" # TODO: add option to turn off simplifications return simplified_domain_concatenation(children, mesh) - - -def all_children_are( - children: list[pybamm.Symbol], - class_type: type[S], -) -> TypeGuard[list[S]]: - return all(isinstance(child, class_type) for child in children) diff --git a/pybamm/expression_tree/operations/unpack_symbols.py b/pybamm/expression_tree/operations/unpack_symbols.py index 56cff5e859..1933eada76 100644 --- a/pybamm/expression_tree/operations/unpack_symbols.py +++ b/pybamm/expression_tree/operations/unpack_symbols.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover import pybamm diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index fb0380cea4..911e0ca59f 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -8,7 +8,7 @@ import numpy as np from typing import TYPE_CHECKING, Literal -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover import sympy import pybamm diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index bbf67bf8d3..3320b361ad 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -7,22 +7,13 @@ import numpy as np from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property -from typing import ( - TYPE_CHECKING, - Sequence, -) +from typing import TYPE_CHECKING, Sequence, cast import pybamm from pybamm.util import have_optional_dependency from pybamm.expression_tree.printing.print_name import prettify_print_name -if TYPE_CHECKING: - from pybamm.expression_tree.binary_operators import ( - Addition, - Subtraction, - Multiplication, - Division, - ) +if TYPE_CHECKING: # pragma: no cover import casadi from hints import S @@ -171,9 +162,9 @@ def simplify_if_constant(symbol: S) -> S: or (isinstance(result, np.ndarray) and result.ndim == 0) or isinstance(result, np.bool_) ): - if isinstance(result, np.ndarray): + if isinstance(result, np.ndarray): # pragma: no cover # type-narrow for Scalar - new_result = result[0] + new_result = cast(float, result) return pybamm.Scalar(new_result) return pybamm.Scalar(result) elif isinstance(result, np.ndarray) or issparse(result): @@ -582,27 +573,27 @@ def __repr__(self): {k: v for k, v in self.domains.items() if v != []}, ) - def __add__(self, other: Symbol | float | np.ndarray) -> Addition: + def __add__(self, other: Symbol | float | np.ndarray) -> pybamm.Addition: """return an :class:`Addition` object.""" return pybamm.add(self, other) - def __radd__(self, other: Symbol | float | np.ndarray) -> Addition: + def __radd__(self, other: Symbol | float | np.ndarray) -> pybamm.Addition: """return an :class:`Addition` object.""" return pybamm.add(other, self) - def __sub__(self, other: Symbol | float | np.ndarray) -> Subtraction: + def __sub__(self, other: Symbol | float | np.ndarray) -> pybamm.Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(self, other) - def __rsub__(self, other: Symbol | float | np.ndarray) -> Subtraction: + def __rsub__(self, other: Symbol | float | np.ndarray) -> pybamm.Subtraction: """return a :class:`Subtraction` object.""" return pybamm.subtract(other, self) - def __mul__(self, other: Symbol | float | np.ndarray) -> Multiplication: + def __mul__(self, other: Symbol | float | np.ndarray) -> pybamm.Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(self, other) - def __rmul__(self, other: Symbol | float | np.ndarray) -> Multiplication: + def __rmul__(self, other: Symbol | float | np.ndarray) -> pybamm.Multiplication: """return a :class:`Multiplication` object.""" return pybamm.multiply(other, self) @@ -618,11 +609,11 @@ def __rmatmul__( """return a :class:`MatrixMultiplication` object.""" return pybamm.matmul(other, self) - def __truediv__(self, other: Symbol | float | np.ndarray) -> Division: + def __truediv__(self, other: Symbol | float | np.ndarray) -> pybamm.Division: """return a :class:`Division` object.""" return pybamm.divide(self, other) - def __rtruediv__(self, other: Symbol | float | np.ndarray) -> Division: + def __rtruediv__(self, other: Symbol | float | np.ndarray) -> pybamm.Division: """return a :class:`Division` object.""" return pybamm.divide(other, self)