Skip to content

Commit

Permalink
Distributed regridding of bathymetry (#309)
Browse files Browse the repository at this point in the history
* distributed regridding

* another strategy

* fix imports

* cannot share subarrays apparently

* remember not to redownload

* fix test

* add corect tests

* test more architectures

* test other architectures

* fix  distributed tests
  • Loading branch information
simone-silvestri authored Dec 21, 2024
1 parent ff4b883 commit d8992d5
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 5 deletions.
44 changes: 39 additions & 5 deletions src/Bathymetry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using ..DataWrangling: download_progress

using Oceananigans
using Oceananigans.Architectures: architecture, on_architecture
using Oceananigans.DistributedComputations: child_architecture
using Oceananigans.DistributedComputations: DistributedGrid, reconstruct_global_grid, barrier!, all_reduce
using Oceananigans.Grids: halo_size, λnodes, φnodes
using Oceananigans.Grids: x_domain, y_domain
using Oceananigans.Grids: topology
Expand Down Expand Up @@ -95,11 +95,12 @@ function regrid_bathymetry(target_grid;
filepath = joinpath(dir, filename)
fileurl = url * "/" * filename # joinpath on windows creates the wrong url

@root if !isfile(filepath) # perform all this only on rank 0, aka the "root" rank
# No need for @root here, because only rank 0 accesses this function
if !isfile(filepath)
Downloads.download(fileurl, filepath; progress=download_progress)
end
dataset = Dataset(filepath)

dataset = Dataset(filepath, "r")

FT = eltype(target_grid)

Expand All @@ -115,7 +116,7 @@ function regrid_bathymetry(target_grid;
close(dataset)

# Diagnose target grid information
arch = child_architecture(architecture(target_grid))
arch = architecture(target_grid)
φ₁, φ₂ = y_domain(target_grid)
λ₁, λ₂ = x_domain(target_grid)

Expand Down Expand Up @@ -247,6 +248,39 @@ function interpolate_bathymetry_in_passes(native_z, target_grid;
return target_z
end

# Regridding bathymetry for distributed grids, we handle the whole process
# on just one rank, and share the results with the other processors.
function regrid_bathymetry(target_grid::DistributedGrid; kw...)
global_grid = reconstruct_global_grid(target_grid)
global_grid = on_architecture(CPU(), global_grid)
arch = architecture(target_grid)
Nx, Ny, _ = size(global_grid)

# If all ranks open a gigantic bathymetry and the memory is
# shared, we could easily have OOM errors.
# We perform the reconstruction only on rank 0 and share the result.
bottom_height = if arch.local_rank == 0
bottom_field = regrid_bathymetry(global_grid; kw...)
bottom_field.data[1:Nx, 1:Ny, 1]
else
zeros(Nx, Ny)
end

# Synchronize
ClimaOcean.global_barrier(arch.communicator)

# Share the result (can we share SubArrays?)
bottom_height = all_reduce(+, bottom_height, arch)

# Partition the result
local_bottom_height = Field{Center, Center, Nothing}(target_grid)
set!(local_bottom_height, bottom_height)
fill_halo_regions!(local_bottom_height)

return local_bottom_height
end


"""
remove_minor_basins!(z_data, keep_major_basins)
Expand Down
71 changes: 71 additions & 0 deletions test/test_distributed_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ include("runtests_setup.jl")
using MPI
MPI.Init()

using NCDatasets
using ClimaOcean.ECCO: download_dataset, metadata_path
using Oceananigans.DistributedComputations: reconstruct_global_grid
using CFTime
using Dates

Expand Down Expand Up @@ -74,3 +76,72 @@ end
@test isfile(metadata_path(metadatum))
end
end

@testset "Distributed Bathymetry interpolation" begin
# We start by building a fake bathyemtry on rank 0 and save it to file
@root begin
λ = -180:0.1:180
φ = 0:0.1:50

= length(λ)
= length(φ)

ds = NCDataset("./trivial_bathymetry.nc", "c")

# Define the dimension "lon" and "lat" with the size 361 and 51 resp.
defDim(ds, "lon", Nλ)
defDim(ds, "lat", Nφ)
defVar(ds, "lat", Float32, ("lat", ))
defVar(ds, "lon", Float32, ("lon", ))

# Define the variables z
z = defVar(ds, "z", Float32, ("lon","lat"))

# Generate some example data
data = [Float32(-i) for i = 1:Nλ, j = 1:Nφ]

# write a the complete data set
ds["lon"][:] = λ
ds["lat"][:] = φ
z[:,:] = data

close(ds)
end

global_grid = LatitudeLongitudeGrid(CPU();
size = (40, 40, 1),
longitude = (0, 100),
latitude = (0, 20),
z = (0, 1))

global_height = regrid_bathymetry(global_grid;
dir = "./",
filename = "trivial_bathymetry.nc",
interpolation_passes=10)

arch_x = Distributed(CPU(), partition=Partition(4, 1))
arch_y = Distributed(CPU(), partition=Partition(1, 4))
arch_xy = Distributed(CPU(), partition=Partition(2, 2))

for arch in (arch_x, arch_y, arch_xy)
local_grid = LatitudeLongitudeGrid(arch;
size = (40, 40, 1),
longitude = (0, 100),
latitude = (0, 20),
z = (0, 1))

local_height = regrid_bathymetry(local_grid;
dir = "./",
filename = "trivial_bathymetry.nc",
interpolation_passes=10)

Nx, Ny, _ = size(local_grid)
rx, ry, _ = arch.local_index
irange = (rx - 1) * Nx + 1 : rx * Nx
jrange = (ry - 1) * Ny + 1 : ry * Ny

@handshake begin
@test interior(global_height, irange, jrange, 1) == interior(local_height, :, :, 1)
end
end
end

0 comments on commit d8992d5

Please sign in to comment.