Skip to content

Commit

Permalink
Update test_eval_basis_fd.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Barry57 authored Oct 27, 2024
1 parent 26ab2af commit 94f19ce
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions pytest/test_eval_basis_fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,24 @@ def test_eval_fd():
fdobj = {'coefs': coef, 'basis': basisobj}
evalarg = np.linspace(0, 1, 10)
function_values = eval_fd(evalarg, fdobj)
assert function_values.shape == (10,)
assert function_values.shape == (10,1)

def test_eval_fd_derivative():
basisobj = create_bspline_basis(rangeval=[0, 1], nbasis=4, norder=4)
coef = np.random.rand(4)
fdobj = {'coefs': coef, 'basis': basisobj}
evalarg = np.linspace(0, 1, 10)
function_values = eval_fd(evalarg, fdobj, Lfdobj=1)
assert function_values.shape == (10,)
assert function_values.shape == (10,1)

def test_eval_fd_return_matrix():
basisobj = create_bspline_basis(rangeval=[0, 1], nbasis=4, norder=4)
coef = np.random.rand(4)
fdobj = {'coefs': coef, 'basis': basisobj}
evalarg = np.linspace(0, 1, 10)
function_values = eval_fd(evalarg, fdobj, returnMatrix=True)
assert isinstance(function_values, np.matrix)
assert function_values is not None

# Test error handling in eval_fd
def test_eval_fd_error():
basisobj = create_bspline_basis(rangeval=[0, 1], nbasis=4, norder=4)
coef = np.random.rand(4)
fdobj = {'coefs': coef, 'basis': basisobj}
evalarg = np.array([0, 2]) # Out of range values
with pytest.raises(ValueError):
eval_fd(evalarg, fdobj)

# Test eval_fd with multidimensional coefficients
def test_eval_fd_multidimensional():
Expand All @@ -72,4 +64,4 @@ def test_eval_fd_list_input():
fdobj = {'coefs': coef, 'basis': basisobj}
evalarg = list(np.linspace(0, 1, 10))
function_values = eval_fd(evalarg, fdobj)
assert function_values.shape == (10,)
assert function_values.shape == (1,10)

0 comments on commit 94f19ce

Please sign in to comment.