Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distributed regridding of bathymetry #309

Merged
merged 12 commits into from
Dec 21, 2024
44 changes: 39 additions & 5 deletions src/Bathymetry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using ..DataWrangling: download_progress

using Oceananigans
using Oceananigans.Architectures: architecture, on_architecture
using Oceananigans.DistributedComputations: child_architecture
using Oceananigans.DistributedComputations: DistributedGrid, reconstruct_global_grid, barrier!, all_reduce
using Oceananigans.Grids: halo_size, λnodes, φnodes
using Oceananigans.Grids: x_domain, y_domain
using Oceananigans.Grids: topology
Expand Down Expand Up @@ -95,11 +95,12 @@ function regrid_bathymetry(target_grid;
filepath = joinpath(dir, filename)
fileurl = url * "/" * filename # joinpath on windows creates the wrong url

@root if !isfile(filepath) # perform all this only on rank 0, aka the "root" rank
# No need for @root here, because only rank 0 accesses this function
if !isfile(filepath)
Downloads.download(fileurl, filepath; progress=download_progress)
end
dataset = Dataset(filepath)

dataset = Dataset(filepath, "r")

FT = eltype(target_grid)

Expand All @@ -115,7 +116,7 @@ function regrid_bathymetry(target_grid;
close(dataset)

# Diagnose target grid information
arch = child_architecture(architecture(target_grid))
arch = architecture(target_grid)
φ₁, φ₂ = y_domain(target_grid)
λ₁, λ₂ = x_domain(target_grid)

Expand Down Expand Up @@ -247,6 +248,39 @@ function interpolate_bathymetry_in_passes(native_z, target_grid;
return target_z
end

"""
    regrid_bathymetry(target_grid::DistributedGrid; kw...)

Regrid the bathymetry onto a distributed `target_grid`.

The whole regridding is performed on rank 0 only, on a CPU copy of the reconstructed
global grid, and the result is then shared with all the other ranks through an
`all_reduce`. All keyword arguments `kw...` are forwarded unchanged to the
`regrid_bathymetry` call made on the global grid.

Return a `Field{Center, Center, Nothing}` built on `target_grid`, set from the shared
global bottom height, with halo regions filled.
"""
function regrid_bathymetry(target_grid::DistributedGrid; kw...)
    global_grid = reconstruct_global_grid(target_grid)
    global_grid = on_architecture(CPU(), global_grid)
    arch = architecture(target_grid)
    Nx, Ny, _ = size(global_grid)
    FT = eltype(global_grid)

    # If all ranks open a gigantic bathymetry and the memory is
    # shared, we could easily have OOM errors.
    # We perform the reconstruction only on rank 0 and share the result.
    #
    # Every rank must contribute a buffer of the same element type and size to the
    # reduction below, so both branches produce a `Matrix{FT}` (a bare `zeros(Nx, Ny)`
    # would be `Float64` even when the grid is `Float32`).
    bottom_height = if arch.local_rank == 0
        bottom_field = regrid_bathymetry(global_grid; kw...)
        convert(Matrix{FT}, bottom_field.data[1:Nx, 1:Ny, 1])
    else
        zeros(FT, Nx, Ny)
    end

    # Synchronize
    ClimaOcean.global_barrier(arch.communicator)

    # Share the result: all ranks but rank 0 contribute zeros, so the sum
    # is rank 0's regridded bathymetry (can we share SubArrays?)
    bottom_height = all_reduce(+, bottom_height, arch)

    # Partition the result
    local_bottom_height = Field{Center, Center, Nothing}(target_grid)
    set!(local_bottom_height, bottom_height)
    fill_halo_regions!(local_bottom_height)

    return local_bottom_height
end


"""
remove_minor_basins!(z_data, keep_major_basins)

Expand Down
71 changes: 71 additions & 0 deletions test/test_distributed_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ include("runtests_setup.jl")
using MPI
MPI.Init()

using NCDatasets
using ClimaOcean.ECCO: download_dataset, metadata_path
using Oceananigans.DistributedComputations: reconstruct_global_grid
using CFTime
using Dates

Expand Down Expand Up @@ -74,3 +76,72 @@ end
@test isfile(metadata_path(metadatum))
end
end

@testset "Distributed Bathymetry interpolation" begin
    # Start by building a fake bathymetry on rank 0 and saving it to file
    @root begin
        λ = -180:0.1:180
        φ = 0:0.1:50

        nlon = length(λ)
        nlat = length(φ)

        nc = NCDataset("./trivial_bathymetry.nc", "c")

        # Define the dimensions "lon" and "lat", sized after the coordinate
        # ranges above (3601 and 501 points respectively)
        defDim(nc, "lon", nlon)
        defDim(nc, "lat", nlat)
        defVar(nc, "lat", Float32, ("lat", ))
        defVar(nc, "lon", Float32, ("lon", ))

        # Define the bathymetry variable z
        z = defVar(nc, "z", Float32, ("lon","lat"))

        # Example data: the height depends only on the longitude index
        depths = [Float32(-i) for i in 1:nlon, _ in 1:nlat]

        # Write the complete data set
        nc["lon"][:] = λ
        nc["lat"][:] = φ
        z[:,:] = depths

        close(nc)
    end

    # The same grid extent and regridding options are used for the
    # serial reference and for every distributed configuration
    grid_spec = (size = (40, 40, 1),
                 longitude = (0, 100),
                 latitude = (0, 20),
                 z = (0, 1))

    regrid_spec = (dir = "./",
                   filename = "trivial_bathymetry.nc",
                   interpolation_passes = 10)

    # Serial reference solution
    global_grid = LatitudeLongitudeGrid(CPU(); grid_spec...)
    global_height = regrid_bathymetry(global_grid; regrid_spec...)

    # Partitions in x only, in y only, and in both directions
    architectures = (Distributed(CPU(), partition=Partition(4, 1)),
                     Distributed(CPU(), partition=Partition(1, 4)),
                     Distributed(CPU(), partition=Partition(2, 2)))

    for arch in architectures
        local_grid = LatitudeLongitudeGrid(arch; grid_spec...)
        local_height = regrid_bathymetry(local_grid; regrid_spec...)

        # Index window of the global grid that corresponds to this rank
        Nx, Ny, _ = size(local_grid)
        rx, ry, _ = arch.local_index
        i₀ = (rx - 1) * Nx
        j₀ = (ry - 1) * Ny
        irange = i₀ + 1 : i₀ + Nx
        jrange = j₀ + 1 : j₀ + Ny

        @handshake begin
            @test interior(global_height, irange, jrange, 1) == interior(local_height, :, :, 1)
        end
    end
end
Loading