diff --git a/src/problem.jl b/src/problem.jl index f95f3f9..79fdfc2 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -99,6 +99,6 @@ end Overwrite `Base.getindex` to allow for slicing of input/output-based problems. """ Base.getindex(p::Problem{Vector{IOExample}}, indices) = Problem(p.spec[indices]) -Base.getindex(p::MetricProblem{Vector{IOExample}}, indices) = Problem(p.spec[indices]) +Base.getindex(p::MetricProblem{Vector{IOExample}}, indices) = MetricProblem(p.cost_function, p.spec[indices]) diff --git a/test/test_ioproblem.jl b/test/test_ioproblem.jl index 42c3efc..8f969db 100644 --- a/test/test_ioproblem.jl +++ b/test/test_ioproblem.jl @@ -69,7 +69,8 @@ end # Test getindex submetric = metric2[1:2] - @test isa(submetric, Problem) + @test isa(submetric, MetricProblem) @test submetric.spec == spec[1:2] @test submetric.name == "" + @test submetric.cost_function === cost_function end