Skip to content

Commit

Permalink
Fix coverage issues
Browse files Browse the repository at this point in the history
  • Loading branch information
pipliggins committed Jan 16, 2024
1 parent 64f3fc9 commit c9a0ff1
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pybamm
from pybamm.util import have_optional_dependency

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
import sympy


Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _diff(self, variable):
# Differentiate the child and broadcast the result in the same way
return self._unary_new_copy(self.child.diff(variable))

def reduce_one_dimension(self):
def reduce_one_dimension(self): # pragma: no cover
"""Reduce the broadcast by one dimension."""
raise NotImplementedError

Expand Down
7 changes: 0 additions & 7 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,10 +592,3 @@ def domain_concatenation(children: list[pybamm.Symbol], mesh: pybamm.Mesh):
"""Helper function to create domain concatenations."""
# TODO: add option to turn off simplifications
return simplified_domain_concatenation(children, mesh)


def all_children_are(
children: list[pybamm.Symbol],
class_type: type[S],
) -> TypeGuard[list[S]]:
return all(isinstance(child, class_type) for child in children)
2 changes: 1 addition & 1 deletion pybamm/expression_tree/operations/unpack_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
import pybamm


Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from typing import TYPE_CHECKING, Literal

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
import sympy

import pybamm
Expand Down
33 changes: 12 additions & 21 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,13 @@
import numpy as np
from scipy.sparse import csr_matrix, issparse
from functools import lru_cache, cached_property
from typing import (
TYPE_CHECKING,
Sequence,
)
from typing import TYPE_CHECKING, Sequence, cast

import pybamm
from pybamm.util import have_optional_dependency
from pybamm.expression_tree.printing.print_name import prettify_print_name

if TYPE_CHECKING:
from pybamm.expression_tree.binary_operators import (
Addition,
Subtraction,
Multiplication,
Division,
)
if TYPE_CHECKING: # pragma: no cover
import casadi
from hints import S

Expand Down Expand Up @@ -171,9 +162,9 @@ def simplify_if_constant(symbol: S) -> S:
or (isinstance(result, np.ndarray) and result.ndim == 0)
or isinstance(result, np.bool_)
):
if isinstance(result, np.ndarray):
if isinstance(result, np.ndarray): # pragma: no cover
# type-narrow for Scalar
new_result = result[0]
new_result = cast(float, result)
return pybamm.Scalar(new_result)
return pybamm.Scalar(result)
elif isinstance(result, np.ndarray) or issparse(result):
Expand Down Expand Up @@ -582,27 +573,27 @@ def __repr__(self):
{k: v for k, v in self.domains.items() if v != []},
)

def __add__(self, other: Symbol | float | np.ndarray) -> Addition:
def __add__(self, other: Symbol | float | np.ndarray) -> pybamm.Addition:
"""return an :class:`Addition` object."""
return pybamm.add(self, other)

def __radd__(self, other: Symbol | float | np.ndarray) -> Addition:
def __radd__(self, other: Symbol | float | np.ndarray) -> pybamm.Addition:
"""return an :class:`Addition` object."""
return pybamm.add(other, self)

def __sub__(self, other: Symbol | float | np.ndarray) -> Subtraction:
def __sub__(self, other: Symbol | float | np.ndarray) -> pybamm.Subtraction:
"""return a :class:`Subtraction` object."""
return pybamm.subtract(self, other)

def __rsub__(self, other: Symbol | float | np.ndarray) -> Subtraction:
def __rsub__(self, other: Symbol | float | np.ndarray) -> pybamm.Subtraction:
"""return a :class:`Subtraction` object."""
return pybamm.subtract(other, self)

def __mul__(self, other: Symbol | float | np.ndarray) -> Multiplication:
def __mul__(self, other: Symbol | float | np.ndarray) -> pybamm.Multiplication:
"""return a :class:`Multiplication` object."""
return pybamm.multiply(self, other)

def __rmul__(self, other: Symbol | float | np.ndarray) -> Multiplication:
def __rmul__(self, other: Symbol | float | np.ndarray) -> pybamm.Multiplication:
"""return a :class:`Multiplication` object."""
return pybamm.multiply(other, self)

Expand All @@ -618,11 +609,11 @@ def __rmatmul__(
"""return a :class:`MatrixMultiplication` object."""
return pybamm.matmul(other, self)

def __truediv__(self, other: Symbol | float | np.ndarray) -> Division:
def __truediv__(self, other: Symbol | float | np.ndarray) -> pybamm.Division:
"""return a :class:`Division` object."""
return pybamm.divide(self, other)

def __rtruediv__(self, other: Symbol | float | np.ndarray) -> Division:
def __rtruediv__(self, other: Symbol | float | np.ndarray) -> pybamm.Division:
"""return a :class:`Division` object."""
return pybamm.divide(other, self)

Expand Down

0 comments on commit c9a0ff1

Please sign in to comment.