Skip to content

Commit

Permalink
properly override get_rng
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Aug 1, 2024
1 parent 9cb0cff commit 18c4085
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/ScalarRandomFeature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ $(DocStringExtensions.TYPEDSIGNATURES)
gets the rng field
"""
get_rng(srfi::ScalarRandomFeatureInterface) = srfi.rng
EKP.get_rng(srfi::ScalarRandomFeatureInterface) = srfi.rng

"""
$(DocStringExtensions.TYPEDSIGNATURES)
Expand Down
2 changes: 1 addition & 1 deletion src/VectorRandomFeature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ $(DocStringExtensions.TYPEDSIGNATURES)
Gets the rng field
"""
get_rng(vrfi::VectorRandomFeatureInterface) = vrfi.rng
EKP.get_rng(vrfi::VectorRandomFeatureInterface) = vrfi.rng

"""
$(DocStringExtensions.TYPEDSIGNATURES)
Expand Down
12 changes: 4 additions & 8 deletions test/RandomFeature/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ using CalibrateEmulateSample.DataContainers
using CalibrateEmulateSample.EnsembleKalmanProcesses
using CalibrateEmulateSample.ParameterDistributions

# resolve function import conflicts:
const RF = RandomFeatures


seed = 10101010
rng = Random.MersenneTwister(seed)

Expand Down Expand Up @@ -166,14 +162,14 @@ rng = Random.MersenneTwister(seed)
@test get_batch_sizes(srfi) == batch_sizes
@test get_n_features(srfi) == n_features
@test get_input_dim(srfi) == input_dim
@test RF.get_rng(srfi) == rng
@test get_rng(srfi) == rng
@test get_kernel_structure(srfi) == kernel_structure
@test get_optimizer_options(srfi) == optimizer_options

# check defaults
srfi2 = ScalarRandomFeatureInterface(n_features, input_dim)
@test get_batch_sizes(srfi2) === nothing
@test RF.get_rng(srfi2) == Random.GLOBAL_RNG
@test get_rng(srfi2) == Random.GLOBAL_RNG
@test get_kernel_structure(srfi2) ==
SeparableKernel(cov_structure_from_string("lowrank", input_dim), OneDimFactor())

Expand Down Expand Up @@ -234,14 +230,14 @@ rng = Random.MersenneTwister(seed)
@test get_input_dim(vrfi) == input_dim
@test get_output_dim(vrfi) == output_dim
@test get_kernel_structure(vrfi) == kernel_structure
@test RF.get_rng(vrfi) == rng
@test get_rng(vrfi) == rng
@test get_optimizer_options(vrfi) == optimizer_options

#check defaults
vrfi2 = VectorRandomFeatureInterface(n_features, input_dim, output_dim)

@test get_batch_sizes(vrfi2) === nothing
@test RF.get_rng(vrfi2) == Random.GLOBAL_RNG
@test get_rng(vrfi2) == Random.GLOBAL_RNG
@test get_kernel_structure(vrfi2) == SeparableKernel(
cov_structure_from_string("lowrank", input_dim),
cov_structure_from_string("lowrank", output_dim),
Expand Down

0 comments on commit 18c4085

Please sign in to comment.