Skip to content

Commit

Permalink
Merge pull request #861 from stan-dev/complex-support
Browse files Browse the repository at this point in the history
Add support/tests for exporting functions with complex types
  • Loading branch information
andrjohns authored Sep 30, 2023
2 parents a2af167 + ccab663 commit ede72e3
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 1 deletion.
3 changes: 2 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,8 @@ get_function_name <- function(fun_start, fun_end, model_lines) {
"double",
"Eigen::Matrix<(.*)>",
"std::vector<(.*)>",
"std::tuple<(.*)>"
"std::tuple<(.*)>",
"std::complex<(.*)>"
)
pattern <- paste0(
# Only match if the type occurs at start of string
Expand Down
114 changes: 114 additions & 0 deletions tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,36 @@ functions {
tuple(int, tuple(array[] vector, array[] vector)) rtn_nest_tuple_vec_array(tuple(int, tuple(array[] vector, array[] vector)) x) { return x; }
tuple(int, tuple(array[] row_vector, array[] row_vector)) rtn_nest_tuple_rowvec_array(tuple(int, tuple(array[] row_vector, array[] row_vector)) x) { return x; }
tuple(int, tuple(array[] matrix, array[] matrix)) rtn_nest_tuple_matrix_array(tuple(int, tuple(array[] matrix, array[] matrix)) x) { return x; }
complex rtn_complex(complex x) { return x; }
complex_vector rtn_complex_vec(complex_vector x) { return x; }
complex_row_vector rtn_complex_rowvec(complex_row_vector x) { return x; }
complex_matrix rtn_complex_matrix(complex_matrix x) { return x; }
array[] complex rtn_complex_array(array[] complex x) { return x; }
array[] complex_vector rtn_complex_vec_array(array[] complex_vector x) { return x; }
array[] complex_row_vector rtn_complex_rowvec_array(array[] complex_row_vector x) { return x; }
array[] complex_matrix rtn_complex_matrix_array(array[] complex_matrix x) { return x; }
tuple(complex, complex) rtn_tuple_complex(tuple(complex, complex) x) { return x; }
tuple(complex_vector, complex_vector) rtn_tuple_complex_vec(tuple(complex_vector, complex_vector) x) { return x; }
tuple(complex_row_vector, complex_row_vector) rtn_tuple_complex_rowvec(tuple(complex_row_vector, complex_row_vector) x) { return x; }
tuple(complex_matrix, complex_matrix) rtn_tuple_complex_matrix(tuple(complex_matrix, complex_matrix) x) { return x; }
tuple(array[] complex, array[] complex) rtn_tuple_complex_array(tuple(array[] complex, array[] complex) x) { return x; }
tuple(array[] complex_vector, array[] complex_vector) rtn_tuple_complex_vec_array(tuple(array[] complex_vector, array[] complex_vector) x) { return x; }
tuple(array[] complex_row_vector, array[] complex_row_vector) rtn_tuple_complex_rowvec_array(tuple(array[] complex_row_vector, array[] complex_row_vector) x) { return x; }
tuple(array[] complex_matrix, array[] complex_matrix) rtn_tuple_complex_matrix_array(tuple(array[] complex_matrix, array[] complex_matrix) x) { return x; }
tuple(int, tuple(complex, complex)) rtn_nest_tuple_complex(tuple(int, tuple(complex, complex)) x) { return x; }
tuple(int, tuple(complex_vector, complex_vector)) rtn_nest_tuple_complex_vec(tuple(int, tuple(complex_vector, complex_vector)) x) { return x; }
tuple(int, tuple(complex_row_vector, complex_row_vector)) rtn_nest_tuple_complex_rowvec(tuple(int, tuple(complex_row_vector, complex_row_vector)) x) { return x; }
tuple(int, tuple(complex_matrix, complex_matrix)) rtn_nest_tuple_complex_matrix(tuple(int, tuple(complex_matrix, complex_matrix)) x) { return x; }
tuple(int, tuple(array[] complex, array[] complex)) rtn_nest_tuple_complex_array(tuple(int, tuple(array[] complex, array[] complex)) x) { return x; }
tuple(int, tuple(array[] complex_vector, array[] complex_vector)) rtn_nest_tuple_complex_vec_array(tuple(int, tuple(array[] complex_vector, array[] complex_vector)) x) { return x; }
tuple(int, tuple(array[] complex_row_vector, array[] complex_row_vector)) rtn_nest_tuple_complex_rowvec_array(tuple(int, tuple(array[] complex_row_vector, array[] complex_row_vector)) x) { return x; }
tuple(int, tuple(array[] complex_matrix, array[] complex_matrix)) rtn_nest_tuple_complex_matrix_array(tuple(int, tuple(array[] complex_matrix, array[] complex_matrix)) x) { return x; }
}"
stan_prog <- paste(function_decl,
paste(readLines(testing_stan_file("bernoulli")),
Expand Down Expand Up @@ -147,6 +177,90 @@ test_that("Functions handle types correctly", {
expect_equal(mod$functions$rtn_nest_tuple_matrix_array(nest_tuple_matrix_array), nest_tuple_matrix_array)
})

test_that("Functions handle complex types correctly", {
skip_if(os_is_wsl())

### Scalar

complex_scalar <- complex(real = 2.1, imaginary = 21.3)

expect_equal(mod$functions$rtn_complex(complex_scalar), complex_scalar)

### Container

complex_vec <- complex(real = c(2,1.5,0.11, 1.2), imaginary = c(11.2,21.5,6.1,3.2))
complex_rowvec <- t(complex_vec)
complex_matrix <- matrix(complex_vec, nrow=2, ncol=2)

expect_equal(mod$functions$rtn_complex_vec(complex_vec), complex_vec)
expect_equal(mod$functions$rtn_complex_rowvec(complex_rowvec), complex_rowvec)
expect_equal(mod$functions$rtn_complex_matrix(complex_matrix), complex_matrix)
expect_equal(mod$functions$rtn_complex_array(complex_vec), complex_vec)

### Array of Container

complex_vec_array <- list(complex_vec, complex_vec * 2, complex_vec + 0.1)
complex_rowvec_array <- list(complex_rowvec, complex_rowvec * 2, complex_rowvec + 0.1)
complex_matrix_array <- list(complex_matrix, complex_matrix * 2, complex_matrix + 0.1)

expect_equal(mod$functions$rtn_complex_vec_array(complex_vec_array), complex_vec_array)
expect_equal(mod$functions$rtn_complex_rowvec_array(complex_rowvec_array), complex_rowvec_array)
expect_equal(mod$functions$rtn_complex_matrix_array(complex_matrix_array), complex_matrix_array)

### Tuple of Scalar

tuple_complex <- list(complex_vec[1], complex_vec[2])
expect_equal(mod$functions$rtn_tuple_complex(tuple_complex), tuple_complex)

### Tuple of Container

tuple_complex_vec <- list(complex_vec, complex_vec * 1.2)
tuple_complex_rowvec <- list(complex_rowvec, complex_rowvec * 0.5)
tuple_complex_matrix <- list(complex_matrix, complex_matrix * 10.2)

expect_equal(mod$functions$rtn_tuple_complex_array(tuple_complex_vec), tuple_complex_vec)
expect_equal(mod$functions$rtn_tuple_complex_vec(tuple_complex_vec), tuple_complex_vec)
expect_equal(mod$functions$rtn_tuple_complex_rowvec(tuple_complex_rowvec), tuple_complex_rowvec)
expect_equal(mod$functions$rtn_tuple_complex_matrix(tuple_complex_matrix), tuple_complex_matrix)

### Tuple of Container Arrays

tuple_complex_vec_array <- list(complex_vec_array, complex_vec_array)
tuple_complex_rowvec_array <- list(complex_rowvec_array, complex_rowvec_array)
tuple_complex_matrix_array <- list(complex_matrix_array, complex_matrix_array)

expect_equal(mod$functions$rtn_tuple_complex_vec_array(tuple_complex_vec_array), tuple_complex_vec_array)
expect_equal(mod$functions$rtn_tuple_complex_rowvec_array(tuple_complex_rowvec_array), tuple_complex_rowvec_array)
expect_equal(mod$functions$rtn_tuple_complex_matrix_array(tuple_complex_matrix_array), tuple_complex_matrix_array)

### Nested Tuple of Scalar

nest_tuple_complex <- list(31, tuple_complex)
expect_equal(mod$functions$rtn_nest_tuple_complex(nest_tuple_complex), nest_tuple_complex)

### Nested Tuple of Container

nest_tuple_complex_vec <- list(12, tuple_complex_vec)
nest_tuple_complex_rowvec <- list(2, tuple_complex_rowvec)
nest_tuple_complex_matrix <- list(-23, tuple_complex_matrix)
nest_tuple_complex_array <- list(21, tuple_complex_vec)

expect_equal(mod$functions$rtn_nest_tuple_complex_array(nest_tuple_complex_vec), nest_tuple_complex_vec)
expect_equal(mod$functions$rtn_nest_tuple_complex_vec(nest_tuple_complex_vec), nest_tuple_complex_vec)
expect_equal(mod$functions$rtn_nest_tuple_complex_rowvec(nest_tuple_complex_rowvec), nest_tuple_complex_rowvec)
expect_equal(mod$functions$rtn_nest_tuple_complex_matrix(nest_tuple_complex_matrix), nest_tuple_complex_matrix)

### Nested Tuple of Container Arrays

nest_tuple_complex_vec_array <- list(-21, tuple_complex_vec_array)
nest_tuple_complex_rowvec_array <- list(1000, tuple_complex_rowvec_array)
nest_tuple_complex_matrix_array <- list(0, tuple_complex_matrix_array)

expect_equal(mod$functions$rtn_nest_tuple_complex_vec_array(nest_tuple_complex_vec_array), nest_tuple_complex_vec_array)
expect_equal(mod$functions$rtn_nest_tuple_complex_rowvec_array(nest_tuple_complex_rowvec_array), nest_tuple_complex_rowvec_array)
expect_equal(mod$functions$rtn_nest_tuple_complex_matrix_array(nest_tuple_complex_matrix_array), nest_tuple_complex_matrix_array)
})

test_that("Functions can be exposed in fit object", {
skip_if(os_is_wsl())
fit$expose_functions(verbose = TRUE)
Expand Down

0 comments on commit ede72e3

Please sign in to comment.