Skip to content

Commit

Permalink
Create test_fd.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Barry57 authored Oct 27, 2024
1 parent 94f19ce commit 51ba14d
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions pytest/test_fd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
import numpy as np
from GENetLib.fd import fd
from GENetLib.create_basis import create_bspline_basis

# Test fd function with default coefficients
def test_fd_default_coef():
basisobj = create_bspline_basis(rangeval=[0, 1], nbasis=4, norder=4)
fdobj = fd(basisobj=basisobj)
assert 'coefs' in fdobj
assert fdobj['coefs'].shape == (4,)

# Test fd function with custom coefficients
def test_fd_custom_coef():
basisobj = create_bspline_basis(rangeval=[0, 1], nbasis=4, norder=4)
coef = np.random.rand(4)
fdobj = fd(coef=coef, basisobj=basisobj)
assert 'coefs' in fdobj
assert np.all(fdobj['coefs'] == coef)

# Test fd function with multidimensional coefficients
def test_fd_multidim_coef():
basisobj = create_bspline_basis(rangeval=[0, 1], nbasis=4, norder=4)
coef = np.random.rand(4, 2)
fdobj = fd(coef=coef, basisobj=basisobj)
assert 'coefs' in fdobj
assert fdobj['coefs'].shape == (4, 2)

# Test fd function with invalid coefficient type
def test_fd_invalid_coef_type():
basisobj = create_bspline_basis(rangeval=[0, 1], nbasis=4, norder=4)
coef = "not a number"
with pytest.raises(ValueError):
fd(coef=coef, basisobj=basisobj)

# Test fd function with invalid coefficient dimension
def test_fd_invalid_coef_dim():
basisobj = create_bspline_basis(rangeval=[0, 1], nbasis=4, norder=4)
coef = np.random.rand(4, 2, 2)
with pytest.raises(ValueError):
fd(coef=coef, basisobj=basisobj)

# Test fd function with custom fdnames
def test_fd_custom_fdnames():
basisobj = create_bspline_basis(rangeval=[0, 1], nbasis=4, norder=4)
coef = np.random.rand(4)
fdnames = {"args": ["time"], "reps": ["rep1"], "funs": ["value"]}
fdobj = fd(coef=coef, basisobj=basisobj, fdnames=fdnames)
assert 'fdnames' in fdobj
assert fdobj['fdnames'] == fdnames

# Test fd function without basis object
def test_fd_without_basisobj():
coef = np.random.rand(4)
fdobj = fd(coef=coef)
assert 'coefs' in fdobj
assert fdobj['coefs'].shape == (4,)

# Test fd function with invalid basis object
def test_fd_invalid_basisobj():
coef = np.random.rand(4)
basisobj = "not a basis object"
with pytest.raises(ValueError):
fd(coef=coef, basisobj=basisobj)

0 comments on commit 51ba14d

Please sign in to comment.