diff --git a/src/Fields/field.jl b/src/Fields/field.jl index 0f5f476a0b..81021f3743 100644 --- a/src/Fields/field.jl +++ b/src/Fields/field.jl @@ -1,5 +1,6 @@ using Oceananigans.BoundaryConditions: OBC, MCBC, BoundaryCondition using Oceananigans.Grids: parent_index_range, index_range_offset, default_indices, all_indices, validate_indices +using Oceananigans.Grids: index_range_contains using Adapt using KernelAbstractions: @kernel, @index @@ -230,27 +231,24 @@ function Base.similar(f::Field, grid=f.grid) end """ - offset_windowed_data(data, loc, grid, indices) + offset_windowed_data(data, data_indices, loc, grid, view_indices) -Return an `OffsetArray` of a `view` of `parent(data)` with `indices`. +Return an `OffsetArray` of `parent(data)`. + +If `indices` is not (:, :, :), a `view` of `parent(data)` with `indices`. If `indices === (:, :, :)`, return an `OffsetArray` of `parent(data)`. """ -function offset_windowed_data(data, Loc, grid, indices) +function offset_windowed_data(data, data_indices, Loc, grid, view_indices) halo = halo_size(grid) topo = map(instantiate, topology(grid)) loc = map(instantiate, Loc) - if indices isa typeof(default_indices(3)) - windowed_parent = parent(data) - else - parent_indices = map(parent_index_range, indices, loc, topo, halo) - windowed_parent = view(parent(data), parent_indices...) - end + parent_indices = map(parent_index_range, data_indices, view_indices, loc, topo, halo) + windowed_parent = view(parent(data), parent_indices...) sz = size(grid) - - return offset_data(windowed_parent, loc, topo, sz, halo, indices) + return offset_data(windowed_parent, loc, topo, sz, halo, view_indices) end """ @@ -305,17 +303,27 @@ function Base.view(f::Field, i, j, k) loc = location(f) # Validate indices (convert Int to UnitRange, error for invalid indices) - window_indices = validate_indices((i, j, k), loc, f.grid) - + view_indices = validate_indices((i, j, k), loc, f.grid) + + if view_indices == f.indices # nothing to "view" here + return f # we want the whole field after all. + end + + # Check that the indices actually work here + valid_view_indices = map(index_range_contains, f.indices, view_indices) + + all(valid_view_indices) || + throw(ArgumentError("view indices $((i, j, k)) do not intersect field indices $(f.indices)")) + # Choice: OffsetArray of view of OffsetArray, or OffsetArray of view? # -> the first retains a reference to the original f.data (an OffsetArray) # -> the second loses it, so we'd have to "re-offset" the underlying data to access. # -> we choose the second here, opting to "reduce indirection" at the cost of "index recomputation". # # OffsetArray around a view of parent with appropriate indices: - windowed_data = offset_windowed_data(f.data, loc, grid, window_indices) + windowed_data = offset_windowed_data(f.data, f.indices, loc, grid, view_indices) - boundary_conditions = FieldBoundaryConditions(window_indices, f.boundary_conditions) + boundary_conditions = FieldBoundaryConditions(view_indices, f.boundary_conditions) # "Sliced" Fields created here share data with their parent. # Therefore we set status=nothing so we don't conflate computation @@ -326,7 +334,7 @@ function Base.view(f::Field, i, j, k) grid, windowed_data, boundary_conditions, - window_indices, + view_indices, f.operand, status) end diff --git a/src/Grids/grid_utils.jl b/src/Grids/grid_utils.jl index 90b34d603f..b71a2855fd 100644 --- a/src/Grids/grid_utils.jl +++ b/src/Grids/grid_utils.jl @@ -221,12 +221,34 @@ regular_dimensions(grid) = () @inline all_parent_y_indices(grid, loc) = all_parent_indices(loc[2](), topology(grid, 2)(), size(grid, 2), halo_size(grid, 2)) @inline all_parent_z_indices(grid, loc) = all_parent_indices(loc[3](), topology(grid, 3)(), size(grid, 3), halo_size(grid, 3)) +# Return the index range of "full" parent arrays that span an entire dimension parent_index_range(::Colon, loc, topo, halo) = Colon() parent_index_range(::Base.Slice{<:IdOffsetRange}, loc, topo, halo) = Colon() -parent_index_range(index::UnitRange, loc, topo, halo) = index .+ interior_parent_offset(loc, topo, halo) +parent_index_range(view_indices::UnitRange, ::Nothing, ::Flat, halo) = view_indices +parent_index_range(view_indices::UnitRange, ::Nothing, ::AT, halo) = 1:1 # or Colon() +parent_index_range(view_indices::UnitRange, loc, topo, halo) = view_indices .+ interior_parent_offset(loc, topo, halo) -parent_index_range(index::UnitRange, ::Nothing, ::Flat, halo) = index -parent_index_range(index::UnitRange, ::Nothing, ::AT, halo) = 1:1 # or Colon() +# Return the index range of parent arrays that are themselves windowed +parent_index_range(::Colon, args...) = parent_index_range(args...) + +function parent_index_range(parent_indices::UnitRange, view_indices, args...) + start = first(view_indices) - first(parent_indices) + 1 + stop = start + length(view_indices) - 1 + return UnitRange(start, stop) +end + +# intersect_index_range(::Colon, ::Colon) = Colon() +index_range_contains(range, subset::UnitRange) = (first(subset) ∈ range) & (last(subset) ∈ range) +index_range_contains(::Colon, subset::UnitRange) = true +index_range_contains(::Colon, ::Colon) = true + +# Note: this choice means subset indices are defined on the whole grid. +# Thus any UnitRange does not contain `:`. +index_range_contains(range::UnitRange, subset::Colon) = false + +# Return the index range of "full" parent arrays that span an entire dimension +parent_windowed_indices(::Colon, loc, topo, halo) = Colon() +parent_windowed_indices(indices::UnitRange, loc, topo, halo) = UnitRange(1, length(indices)) index_range_offset(index::UnitRange, loc, topo, halo) = index[1] - interior_parent_offset(loc, topo, halo) index_range_offset(::Colon, loc, topo, halo) = - interior_parent_offset(loc, topo, halo) diff --git a/test/test_field.jl b/test/test_field.jl index 2383afaab3..65c8787b94 100644 --- a/test/test_field.jl +++ b/test/test_field.jl @@ -2,6 +2,7 @@ include("dependencies_for_runtests.jl") using Statistics +using Oceananigans.Grids: total_length using Oceananigans.Fields: ReducedField, has_velocities using Oceananigans.Fields: VelocityFields, TracerFields, interpolate, interpolate! using Oceananigans.Fields: reduced_location @@ -295,7 +296,6 @@ end end @testset "Setting fields" begin - @info " Testing field setting..." FieldTypes = (CenterField, XFaceField, YFaceField, ZFaceField) @@ -451,4 +451,54 @@ end end end end + + @testset "Views of field views" begin + @info " Testing views of field views..." + + Nx, Ny, Nz = 1, 1, 7 + + FieldTypes = (CenterField, XFaceField, YFaceField, ZFaceField) + ZTopologies = (Periodic, Bounded) + + for arch in archs, FT in float_types, FieldType in FieldTypes, ZTopology in ZTopologies + grid = RectilinearGrid(arch, FT, size=(Nx, Ny, Nz), x=(0, 1), y=(0, 1), z=(0, 1), topology = (Periodic, Periodic, ZTopology)) + Hx, Hy, Hz = halo_size(grid) + + c = FieldType(grid) + set!(c, (x, y, z) -> rand()) + + k_top = total_length(location(c, 3)(), topology(c, 3)(), size(grid, 3)) + + # First test that the regular view is correct + cv = view(c, :, :, 1+1:k_top-1) + @test size(cv) == (Nx, Ny, k_top-2) + @test size(parent(cv)) == (Nx+2Hx, Ny+2Hy, k_top-2) + CUDA.@allowscalar @test all(cv[i, j, k] == c[i, j, k] for k in 1+1:k_top-1, j in 1:Ny, i in 1:Nx) + + # Now test the views of views + cvv = view(cv, :, :, 1+2:k_top-2) + @test size(cvv) == (Nx, Ny, k_top-4) + @test size(parent(cvv)) == (Nx+2Hx, Ny+2Hy, k_top-4) + CUDA.@allowscalar @test all(cvv[i, j, k] == cv[i, j, k] for k in 1+2:k_top-2, j in 1:Ny, i in 1:Nx) + + cvvv = view(cvv, :, :, 1+3:k_top-3) + @test size(cvvv) == (1, 1, k_top-6) + @test size(parent(cvvv)) == (Nx+2Hx, Ny+2Hy, k_top-6) + CUDA.@allowscalar @test all(cvvv[i, j, k] == cvv[i, j, k] for k in 1+3:k_top-3, j in 1:Ny, i in 1:Nx) + + @test_throws ArgumentError view(cv, :, :, 1) + @test_throws ArgumentError view(cv, :, :, k_top) + @test_throws ArgumentError view(cvv, :, :, 1:1+1) + @test_throws ArgumentError view(cvv, :, :, k_top-1:k_top) + @test_throws ArgumentError view(cvvv, :, :, 1:1+2) + @test_throws ArgumentError view(cvvv, :, :, k_top-2:k_top) + + @test_throws BoundsError cv[:, :, 1] + @test_throws BoundsError cv[:, :, k_top] + @test_throws BoundsError cvv[:, :, 1:1+1] + @test_throws BoundsError cvv[:, :, k_top-1:k_top] + @test_throws BoundsError cvvv[:, :, 1:1+2] + @test_throws BoundsError cvvv[:, :, k_top-2:k_top] + end + end end