Skip to content

Commit

Permalink
Bugfix for parent_index_range (#3573)
Browse files Browse the repository at this point in the history
* Just a hack for now!

* A real solution this time!

* Make views of views work

* Create tests for views of views

* Refine code

* Use CUDA.@allowscalar

* Slightly compact code

* Delete redundant comments

* Use CUDA.@allowscalar only where necessary

* Remove hardcoded 1s and refine code

---------

Co-authored-by: Gregory L. Wagner <[email protected]>
Co-authored-by: Navid C. Constantinou <[email protected]>
  • Loading branch information
3 people authored May 16, 2024
1 parent c91df92 commit 11c317d
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 20 deletions.
40 changes: 24 additions & 16 deletions src/Fields/field.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Oceananigans.BoundaryConditions: OBC, MCBC, BoundaryCondition
using Oceananigans.Grids: parent_index_range, index_range_offset, default_indices, all_indices, validate_indices
using Oceananigans.Grids: index_range_contains

using Adapt
using KernelAbstractions: @kernel, @index
Expand Down Expand Up @@ -230,27 +231,24 @@ function Base.similar(f::Field, grid=f.grid)
end

"""
offset_windowed_data(data, loc, grid, indices)
offset_windowed_data(data, data_indices, loc, grid, view_indices)
Return an `OffsetArray` of a `view` of `parent(data)` with `indices`.
Return an `OffsetArray` of `parent(data)`.
If `indices` is not (:, :, :), a `view` of `parent(data)` with `indices`.
If `indices === (:, :, :)`, return an `OffsetArray` of `parent(data)`.
"""
function offset_windowed_data(data, Loc, grid, indices)
function offset_windowed_data(data, data_indices, Loc, grid, view_indices)
halo = halo_size(grid)
topo = map(instantiate, topology(grid))
loc = map(instantiate, Loc)

if indices isa typeof(default_indices(3))
windowed_parent = parent(data)
else
parent_indices = map(parent_index_range, indices, loc, topo, halo)
windowed_parent = view(parent(data), parent_indices...)
end
parent_indices = map(parent_index_range, data_indices, view_indices, loc, topo, halo)
windowed_parent = view(parent(data), parent_indices...)

sz = size(grid)

return offset_data(windowed_parent, loc, topo, sz, halo, indices)
return offset_data(windowed_parent, loc, topo, sz, halo, view_indices)
end

"""
Expand Down Expand Up @@ -305,17 +303,27 @@ function Base.view(f::Field, i, j, k)
loc = location(f)

# Validate indices (convert Int to UnitRange, error for invalid indices)
window_indices = validate_indices((i, j, k), loc, f.grid)

view_indices = validate_indices((i, j, k), loc, f.grid)

if view_indices == f.indices # nothing to "view" here
return f # we want the whole field after all.
end

# Check that the indices actually work here
valid_view_indices = map(index_range_contains, f.indices, view_indices)

all(valid_view_indices) ||
throw(ArgumentError("view indices $((i, j, k)) do not intersect field indices $(f.indices)"))

# Choice: OffsetArray of view of OffsetArray, or OffsetArray of view?
# -> the first retains a reference to the original f.data (an OffsetArray)
# -> the second loses it, so we'd have to "re-offset" the underlying data to access.
# -> we choose the second here, opting to "reduce indirection" at the cost of "index recomputation".
#
# OffsetArray around a view of parent with appropriate indices:
windowed_data = offset_windowed_data(f.data, loc, grid, window_indices)
windowed_data = offset_windowed_data(f.data, f.indices, loc, grid, view_indices)

boundary_conditions = FieldBoundaryConditions(window_indices, f.boundary_conditions)
boundary_conditions = FieldBoundaryConditions(view_indices, f.boundary_conditions)

# "Sliced" Fields created here share data with their parent.
# Therefore we set status=nothing so we don't conflate computation
Expand All @@ -326,7 +334,7 @@ function Base.view(f::Field, i, j, k)
grid,
windowed_data,
boundary_conditions,
window_indices,
view_indices,
f.operand,
status)
end
Expand Down
28 changes: 25 additions & 3 deletions src/Grids/grid_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,34 @@ regular_dimensions(grid) = ()
@inline all_parent_y_indices(grid, loc) = all_parent_indices(loc[2](), topology(grid, 2)(), size(grid, 2), halo_size(grid, 2))
@inline all_parent_z_indices(grid, loc) = all_parent_indices(loc[3](), topology(grid, 3)(), size(grid, 3), halo_size(grid, 3))

# Return the index range of "full" parent arrays that span an entire dimension
parent_index_range(::Colon, loc, topo, halo) = Colon()
parent_index_range(::Base.Slice{<:IdOffsetRange}, loc, topo, halo) = Colon()
parent_index_range(index::UnitRange, loc, topo, halo) = index .+ interior_parent_offset(loc, topo, halo)
parent_index_range(view_indices::UnitRange, ::Nothing, ::Flat, halo) = view_indices
parent_index_range(view_indices::UnitRange, ::Nothing, ::AT, halo) = 1:1 # or Colon()
parent_index_range(view_indices::UnitRange, loc, topo, halo) = view_indices .+ interior_parent_offset(loc, topo, halo)

parent_index_range(index::UnitRange, ::Nothing, ::Flat, halo) = index
parent_index_range(index::UnitRange, ::Nothing, ::AT, halo) = 1:1 # or Colon()
# Return the index range of parent arrays that are themselves windowed
parent_index_range(::Colon, args...) = parent_index_range(args...)

function parent_index_range(parent_indices::UnitRange, view_indices, args...)
start = first(view_indices) - first(parent_indices) + 1
stop = start + length(view_indices) - 1
return UnitRange(start, stop)
end

# intersect_index_range(::Colon, ::Colon) = Colon()
index_range_contains(range, subset::UnitRange) = (first(subset) range) & (last(subset) range)
index_range_contains(::Colon, subset::UnitRange) = true
index_range_contains(::Colon, ::Colon) = true

# Note: this choice means subset indices are defined on the whole grid.
# Thus any UnitRange does not contain `:`.
index_range_contains(range::UnitRange, subset::Colon) = false

# Return the index range of "full" parent arrays that span an entire dimension
parent_windowed_indices(::Colon, loc, topo, halo) = Colon()
parent_windowed_indices(indices::UnitRange, loc, topo, halo) = UnitRange(1, length(indices))

index_range_offset(index::UnitRange, loc, topo, halo) = index[1] - interior_parent_offset(loc, topo, halo)
index_range_offset(::Colon, loc, topo, halo) = - interior_parent_offset(loc, topo, halo)
Expand Down
52 changes: 51 additions & 1 deletion test/test_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ include("dependencies_for_runtests.jl")

using Statistics

using Oceananigans.Grids: total_length
using Oceananigans.Fields: ReducedField, has_velocities
using Oceananigans.Fields: VelocityFields, TracerFields, interpolate, interpolate!
using Oceananigans.Fields: reduced_location
Expand Down Expand Up @@ -295,7 +296,6 @@ end
end

@testset "Setting fields" begin

@info " Testing field setting..."

FieldTypes = (CenterField, XFaceField, YFaceField, ZFaceField)
Expand Down Expand Up @@ -451,4 +451,54 @@ end
end
end
end

@testset "Views of field views" begin
@info " Testing views of field views..."

Nx, Ny, Nz = 1, 1, 7

FieldTypes = (CenterField, XFaceField, YFaceField, ZFaceField)
ZTopologies = (Periodic, Bounded)

for arch in archs, FT in float_types, FieldType in FieldTypes, ZTopology in ZTopologies
grid = RectilinearGrid(arch, FT, size=(Nx, Ny, Nz), x=(0, 1), y=(0, 1), z=(0, 1), topology = (Periodic, Periodic, ZTopology))
Hx, Hy, Hz = halo_size(grid)

c = FieldType(grid)
set!(c, (x, y, z) -> rand())

k_top = total_length(location(c, 3)(), topology(c, 3)(), size(grid, 3))

# First test that the regular view is correct
cv = view(c, :, :, 1+1:k_top-1)
@test size(cv) == (Nx, Ny, k_top-2)
@test size(parent(cv)) == (Nx+2Hx, Ny+2Hy, k_top-2)
CUDA.@allowscalar @test all(cv[i, j, k] == c[i, j, k] for k in 1+1:k_top-1, j in 1:Ny, i in 1:Nx)

# Now test the views of views
cvv = view(cv, :, :, 1+2:k_top-2)
@test size(cvv) == (Nx, Ny, k_top-4)
@test size(parent(cvv)) == (Nx+2Hx, Ny+2Hy, k_top-4)
CUDA.@allowscalar @test all(cvv[i, j, k] == cv[i, j, k] for k in 1+2:k_top-2, j in 1:Ny, i in 1:Nx)

cvvv = view(cvv, :, :, 1+3:k_top-3)
@test size(cvvv) == (1, 1, k_top-6)
@test size(parent(cvvv)) == (Nx+2Hx, Ny+2Hy, k_top-6)
CUDA.@allowscalar @test all(cvvv[i, j, k] == cvv[i, j, k] for k in 1+3:k_top-3, j in 1:Ny, i in 1:Nx)

@test_throws ArgumentError view(cv, :, :, 1)
@test_throws ArgumentError view(cv, :, :, k_top)
@test_throws ArgumentError view(cvv, :, :, 1:1+1)
@test_throws ArgumentError view(cvv, :, :, k_top-1:k_top)
@test_throws ArgumentError view(cvvv, :, :, 1:1+2)
@test_throws ArgumentError view(cvvv, :, :, k_top-2:k_top)

@test_throws BoundsError cv[:, :, 1]
@test_throws BoundsError cv[:, :, k_top]
@test_throws BoundsError cvv[:, :, 1:1+1]
@test_throws BoundsError cvv[:, :, k_top-1:k_top]
@test_throws BoundsError cvvv[:, :, 1:1+2]
@test_throws BoundsError cvvv[:, :, k_top-2:k_top]
end
end
end

0 comments on commit 11c317d

Please sign in to comment.