diff --git a/src/Bathymetry.jl b/src/Bathymetry.jl
index 6173a40b..bf9ee264 100644
--- a/src/Bathymetry.jl
+++ b/src/Bathymetry.jl
@@ -7,7 +7,7 @@ using ..DataWrangling: download_progress
 
 using Oceananigans
 using Oceananigans.Architectures: architecture, on_architecture
-using Oceananigans.DistributedComputations: child_architecture
+using Oceananigans.DistributedComputations: DistributedGrid, reconstruct_global_grid, barrier!, all_reduce
 using Oceananigans.Grids: halo_size, λnodes, φnodes
 using Oceananigans.Grids: x_domain, y_domain
 using Oceananigans.Grids: topology
@@ -95,11 +95,12 @@ function regrid_bathymetry(target_grid;
     filepath = joinpath(dir, filename)
     fileurl = url * "/" * filename # joinpath on windows creates the wrong url
 
-    @root if !isfile(filepath) # perform all this only on rank 0, aka the "root" rank
+    # No need for @root here: for distributed grids only rank 0 calls this method
+    if !isfile(filepath)
         Downloads.download(fileurl, filepath; progress=download_progress)
     end
 
-    dataset = Dataset(filepath)
+    dataset = Dataset(filepath, "r")
 
     FT = eltype(target_grid)
 
@@ -115,7 +116,7 @@ function regrid_bathymetry(target_grid;
     close(dataset)
 
     # Diagnose target grid information
-    arch = child_architecture(architecture(target_grid))
+    arch = architecture(target_grid)
     φ₁, φ₂ = y_domain(target_grid)
     λ₁, λ₂ = x_domain(target_grid)
 
@@ -247,6 +248,44 @@ function interpolate_bathymetry_in_passes(native_z, target_grid;
     return target_z
 end
 
+# Regridding bathymetry for distributed grids: we handle the whole process
+# on just one rank and share the result with the other ranks.
+# Keyword arguments are forwarded unchanged to the serial `regrid_bathymetry`.
+function regrid_bathymetry(target_grid::DistributedGrid; kw...)
+    global_grid = reconstruct_global_grid(target_grid)
+    global_grid = on_architecture(CPU(), global_grid)
+    arch = architecture(target_grid)
+    Nx, Ny, _ = size(global_grid)
+    FT = eltype(global_grid)
+
+    # If all ranks open a gigantic bathymetry and the memory is
+    # shared, we could easily have OOM errors.
+    # We perform the reconstruction only on rank 0 and share the result.
+    # The other ranks contribute zeros of the same element type, so that the
+    # sum-reduction below reproduces the field computed on rank 0 and every
+    # rank passes a buffer of the same type to the reduction.
+    bottom_height = if arch.local_rank == 0
+        bottom_field = regrid_bathymetry(global_grid; kw...)
+        convert(Array{FT}, bottom_field.data[1:Nx, 1:Ny, 1])
+    else
+        zeros(FT, Nx, Ny)
+    end
+
+    # Synchronize
+    ClimaOcean.global_barrier(arch.communicator)
+
+    # Share the result (can we share SubArrays?)
+    bottom_height = all_reduce(+, bottom_height, arch)
+
+    # Partition the result
+    local_bottom_height = Field{Center, Center, Nothing}(target_grid)
+    set!(local_bottom_height, bottom_height)
+    fill_halo_regions!(local_bottom_height)
+
+    return local_bottom_height
+end
+
+
 """
     remove_minor_basins!(z_data, keep_major_basins)
 
diff --git a/test/test_distributed_utils.jl b/test/test_distributed_utils.jl
index db23d14f..8c790690 100644
--- a/test/test_distributed_utils.jl
+++ b/test/test_distributed_utils.jl
@@ -3,7 +3,9 @@ include("runtests_setup.jl")
 using MPI
 MPI.Init()
 
+using NCDatasets
 using ClimaOcean.ECCO: download_dataset, metadata_path
+using Oceananigans.DistributedComputations: reconstruct_global_grid
 using CFTime
 using Dates
 
@@ -74,3 +76,76 @@ end
         @test isfile(metadata_path(metadatum))
     end
 end
+
+@testset "Distributed Bathymetry interpolation" begin
+    # We start by building a fake bathymetry on rank 0 and save it to file
+    @root begin
+        λ = -180:0.1:180
+        φ = 0:0.1:50
+
+        Nλ = length(λ)
+        Nφ = length(φ)
+
+        ds = NCDataset("./trivial_bathymetry.nc", "c")
+
+        # Define the dimensions "lon" and "lat" with sizes Nλ and Nφ respectively
+        defDim(ds, "lon", Nλ)
+        defDim(ds, "lat", Nφ)
+        defVar(ds, "lat", Float32, ("lat", ))
+        defVar(ds, "lon", Float32, ("lon", ))
+
+        # Define the variables z
+        z = defVar(ds, "z", Float32, ("lon","lat"))
+
+        # Generate some example data
+        data = [Float32(-i) for i = 1:Nλ, j = 1:Nφ]
+
+        # write the complete data set
+        ds["lon"][:] = λ
+        ds["lat"][:] = φ
+        z[:,:] = data
+
+        close(ds)
+    end
+
+    # Wait until rank 0 has finished writing the file before any rank
+    # tries to read it in the serial `regrid_bathymetry` call below
+    MPI.Barrier(MPI.COMM_WORLD)
+
+    global_grid = LatitudeLongitudeGrid(CPU();
+                                        size = (40, 40, 1),
+                                        longitude = (0, 100),
+                                        latitude = (0, 20),
+                                        z = (0, 1))
+
+    global_height = regrid_bathymetry(global_grid;
+                                      dir = "./",
+                                      filename = "trivial_bathymetry.nc",
+                                      interpolation_passes=10)
+
+    arch_x = Distributed(CPU(), partition=Partition(4, 1))
+    arch_y = Distributed(CPU(), partition=Partition(1, 4))
+    arch_xy = Distributed(CPU(), partition=Partition(2, 2))
+
+    for arch in (arch_x, arch_y, arch_xy)
+        local_grid = LatitudeLongitudeGrid(arch;
+                                           size = (40, 40, 1),
+                                           longitude = (0, 100),
+                                           latitude = (0, 20),
+                                           z = (0, 1))
+
+        local_height = regrid_bathymetry(local_grid;
+                                         dir = "./",
+                                         filename = "trivial_bathymetry.nc",
+                                         interpolation_passes=10)
+
+        Nx, Ny, _ = size(local_grid)
+        rx, ry, _ = arch.local_index
+        irange = (rx - 1) * Nx + 1 : rx * Nx
+        jrange = (ry - 1) * Ny + 1 : ry * Ny
+
+        @handshake begin
+            @test interior(global_height, irange, jrange, 1) == interior(local_height, :, :, 1)
+        end
+    end
+end