Merge branch 'main' into ss/data-dir
navidcy authored Dec 22, 2024
2 parents 395a2d5 + 0d7da30 commit 631d6e7
Showing 6 changed files with 158 additions and 36 deletions.
44 changes: 39 additions & 5 deletions src/Bathymetry.jl
@@ -7,7 +7,7 @@ using ..DataWrangling: download_progress

using Oceananigans
using Oceananigans.Architectures: architecture, on_architecture
using Oceananigans.DistributedComputations: child_architecture
using Oceananigans.DistributedComputations: DistributedGrid, reconstruct_global_grid, barrier!, all_reduce
using Oceananigans.Grids: halo_size, λnodes, φnodes
using Oceananigans.Grids: x_domain, y_domain
using Oceananigans.Grids: topology
@@ -95,11 +95,12 @@ function regrid_bathymetry(target_grid;
filepath = joinpath(dir, filename)
fileurl = url * "/" * filename # joinpath on windows creates the wrong url

@root if !isfile(filepath) # perform all this only on rank 0, aka the "root" rank
# No need for @root here, because only rank 0 accesses this function
if !isfile(filepath)
Downloads.download(fileurl, filepath; progress=download_progress)
end
dataset = Dataset(filepath)

dataset = Dataset(filepath, "r")

FT = eltype(target_grid)

@@ -115,7 +116,7 @@ function regrid_bathymetry(target_grid;
close(dataset)

# Diagnose target grid information
arch = child_architecture(architecture(target_grid))
arch = architecture(target_grid)
φ₁, φ₂ = y_domain(target_grid)
λ₁, λ₂ = x_domain(target_grid)

@@ -247,6 +248,39 @@ function interpolate_bathymetry_in_passes(native_z, target_grid;
return target_z
end

# When regridding bathymetry for distributed grids, we handle the whole process
# on just one rank and share the result with the other processors.
function regrid_bathymetry(target_grid::DistributedGrid; kw...)
global_grid = reconstruct_global_grid(target_grid)
global_grid = on_architecture(CPU(), global_grid)
arch = architecture(target_grid)
Nx, Ny, _ = size(global_grid)

# If all ranks open a gigantic bathymetry file and the memory is
# shared, we could easily run into OOM errors.
# We perform the reconstruction only on rank 0 and share the result.
bottom_height = if arch.local_rank == 0
bottom_field = regrid_bathymetry(global_grid; kw...)
bottom_field.data[1:Nx, 1:Ny, 1]
else
zeros(Nx, Ny)
end

# Synchronize
ClimaOcean.global_barrier(arch.communicator)

# Share the result (can we share SubArrays?)
bottom_height = all_reduce(+, bottom_height, arch)

# Partition the result
local_bottom_height = Field{Center, Center, Nothing}(target_grid)
set!(local_bottom_height, bottom_height)
fill_halo_regions!(local_bottom_height)

return local_bottom_height
end
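For reference, a minimal usage sketch of the new distributed code path (hedged: the 4-rank slab partition and grid extents are illustrative and mirror the test added at the bottom of this commit; `regrid_bathymetry` is assumed to be exported from ClimaOcean, and the default bathymetry dataset is downloaded on rank 0):

using MPI
MPI.Init()

using Oceananigans
using Oceananigans.DistributedComputations: Distributed, Partition
using ClimaOcean

# Illustrative 4-rank slab decomposition in x; other partitions work the same way.
arch = Distributed(CPU(), partition=Partition(4, 1))

grid = LatitudeLongitudeGrid(arch;
                             size = (40, 40, 1),
                             longitude = (0, 100),
                             latitude = (0, 20),
                             z = (0, 1))

# Rank 0 downloads and regrids the global bathymetry; the result is shared via
# all_reduce and then partitioned onto each rank's local grid.
local_height = regrid_bathymetry(grid; interpolation_passes=10)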


"""
remove_minor_basins!(z_data, keep_major_basins)
3 changes: 3 additions & 0 deletions src/DataWrangling/ECCO/ECCO_metadata.jl
@@ -168,6 +168,7 @@ ECCO4_short_names = Dict(
:salinity => "SALT",
:u_velocity => "EVEL",
:v_velocity => "NVEL",
:free_surface => "SSH",
:sea_ice_thickness => "SIheff",
:sea_ice_area_fraction => "SIarea",
:net_heat_flux => "oceQnet"
@@ -178,6 +179,7 @@ ECCO2_short_names = Dict(
:salinity => "SALT",
:u_velocity => "UVEL",
:v_velocity => "VVEL",
:free_surface => "SSH",
:sea_ice_thickness => "SIheff",
:sea_ice_area_fraction => "SIarea",
:net_heat_flux => "oceQnet"
@@ -186,6 +188,7 @@ ECCO2_short_names = Dict(
ECCO_location = Dict(
:temperature => (Center, Center, Center),
:salinity => (Center, Center, Center),
:free_surface => (Center, Center, Nothing),
:sea_ice_thickness => (Center, Center, Nothing),
:sea_ice_area_fraction => (Center, Center, Nothing),
:net_heat_flux => (Center, Center, Nothing),
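With these entries, the new :free_surface variable resolves through the same lookup tables as the existing fields. A minimal sketch of the added mappings (values taken directly from the dictionaries above; assumes access to these internal dictionaries of the ECCO module):

# Short names and native location for the newly added free-surface variable
ECCO4_short_names[:free_surface]   # "SSH"
ECCO2_short_names[:free_surface]   # "SSH"
ECCO_location[:free_surface]       # (Center, Center, Nothing)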
53 changes: 29 additions & 24 deletions src/DataWrangling/ECCO/ECCO_restoring.jl
@@ -6,7 +6,7 @@ using Oceananigans.Utils: Time

using Base
using NCDatasets
using JLD2
using JLD2

using Dates: Second
using ClimaOcean: stateindex
@@ -35,7 +35,10 @@ end
Adapt.adapt_structure(to, b::ECCONetCDFBackend{N, C}) where {N, C} = ECCONetCDFBackend{N, C}(b.start, b.length, nothing, nothing)

"""
ECCONetCDFBackend(length; on_native_grid=false, inpainting=NearestNeighborInpainting(Inf))
ECCONetCDFBackend(length, metadata;
on_native_grid = false,
cache_inpainted_data = false,
inpainting = NearestNeighborInpainting(Inf))
Represent an ECCO FieldTimeSeries backed by ECCO native netCDF files.
Each time instance is stored in an individual file.
@@ -76,7 +79,7 @@ function set!(fts::ECCOFieldTimeSeries)
end

"""
ECCO_times(metadata; start_time = metadata.dates[1])
ECCO_times(metadata; start_time = first(metadata).dates)
Extract the time values from the given metadata and calculate the time difference
from the start time.
@@ -106,7 +109,8 @@ end
ECCOFieldTimeSeries(metadata::ECCOMetadata [, arch_or_grid=CPU() ];
time_indices_in_memory = 2,
time_indexing = Cyclical(),
inpainting = nothing)
inpainting = nothing,
cache_inpainted_data = true)
Create a field time series object for ECCO data.
@@ -115,7 +119,7 @@ Arguments
- `metadata`: `ECCOMetadata` containing information about the ECCO dataset.
- `arch_or_grid`: Either a grid to interpolate ECCO data to, or an `arch`itecture
- `arch_or_grid`: Either a grid to interpolate the ECCO data to, or an `arch`itecture
to use for the native ECCO grid. Default: CPU().
Keyword Arguments
@@ -236,45 +240,45 @@ end
ECCORestoring(variable_name::Symbol, [ arch_or_grid = CPU(), ];
version = ECCO4Monthly(),
dates = all_ECCO_dates(version),
time_indices_in_memory = 2,
time_indices_in_memory = 2,
time_indexing = Cyclical(),
mask = 1,
rate = 1,
data_dir = download_ECCO_cache,
inpainting = NearestNeighborInpainting(Inf),
cache_inpainted_data = true)
Build a forcing term that restores to values stored in an ECCO field time series.
The restoring is applied as a forcing on the right hand side of the evolution equations calculated as
Return a forcing term that restores to values stored in an ECCO field time series.
The restoring is applied as a forcing on the right hand side of the evolution
equations calculated as:
```math
Fψ = r μ (ψ_{ECCO} - ψ)
```
where ``μ`` is the mask, ``r`` is the restoring rate, ``ψ`` is the simulation variable,
and the ECCO variable ``ψ_ECCO`` is linearly interpolated in space and time from the
ECCO dataset of choice to the simulation grid and time.
and ``ψ_{ECCO}`` is the ECCO variable that is linearly interpolated in space and time
from the ECCO dataset of choice to the simulation grid and time.
Arguments
=========
- `variable_name`: The name of the variable to restore. Choices include:
* `:temperature`,
* `:salinity`,
* `:u_velocity`,
* `:v_velocity`,
* `:sea_ice_thickness`,
* `:sea_ice_area_fraction`.
Note that `ECCOMetadata` may be provided as the first argument instead
of `variable_name`. In this case the `version` and `dates` kwargs (described below)
cannot be provided.
* `:temperature`,
* `:salinity`,
* `:u_velocity`,
* `:v_velocity`,
* `:sea_ice_thickness`,
* `:sea_ice_area_fraction`.
- `arch_or_grid`: Either the architecture of the simulation, or a grid on which the ECCO data
is pre-interpolated when loaded. If an `arch`itecture is provided, such as
`arch_or_grid = CPU()` or `arch_or_grid = GPU()`, ECCO data
will be interpolated on-the-fly when the forcing tendency is computed.
Default: CPU().
`arch_or_grid = CPU()` or `arch_or_grid = GPU()`, ECCO data are interpolated
on-the-fly when the forcing tendency is computed. Default: CPU().
!!! info "Providing `ECCOMetadata` instead of `variable_name`"
Note that `ECCOMetadata` may be provided as the first argument instead of `variable_name`.
In this case the `version` and `dates` kwargs (described below) cannot be provided.
Keyword Arguments
=================
@@ -344,7 +348,8 @@ function Base.show(io::IO, p::ECCORestoring)
"├── restored variable: ", summary(p.variable_name), '\n',
"├── restoring dataset: ", summary(p.field_time_series.backend.metadata), '\n',
"├── restoring rate: ", p.rate, '\n',
"└── mask: ", summary(p.mask))
"├── mask: ", summary(p.mask), '\n',
"└── grid: ", summary(p.native_grid))
end

regularize_forcing(forcing::ECCORestoring, field, field_name, model_field_names) = forcing
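A short usage sketch of the signature documented above (hedged: the restored variable, rate, and architecture are illustrative choices, and `ECCORestoring` is assumed to be exported from ClimaOcean):

using Oceananigans
using ClimaOcean

# Restore the model temperature toward the ECCO dataset with a ~30-day timescale.
# Passing an architecture (rather than a grid) means the ECCO data are interpolated
# on-the-fly when the forcing tendency is computed. By default `dates` spans all
# ECCO dates; pass `dates = ...` to restrict the range.
T_restoring = ECCORestoring(:temperature, CPU(); rate = 1/(30 * 24 * 3600))

The returned object is a forcing term, so it can be supplied to the ocean model through its `forcing` keyword, for example `forcing = (; T = T_restoring)`.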
@@ -56,6 +56,13 @@ const celsius_to_kelvin = 273.15
Base.summary(crf::OceanSeaIceSurfaceFluxes) = "OceanSeaIceSurfaceFluxes"
Base.show(io::IO, crf::OceanSeaIceSurfaceFluxes) = print(io, summary(crf))

function Base.show(io::IO, crf::OceanSeaIceSurfaceFluxes)
print(io, summary(crf), "\n")
print(io, "├── radiation: ", summary(crf.radiation), "\n")
print(io, "└── turbulent coefficients: ", summary(crf.turbulent), "\n")
return nothing
end

const SeaIceSimulation = Simulation{<:SeaIceModel}

function OceanSeaIceSurfaceFluxes(ocean, sea_ice=nothing;
16 changes: 9 additions & 7 deletions src/OceanSeaIceModels/ocean_sea_ice_model.jl
@@ -7,6 +7,7 @@ using SeawaterPolynomials: TEOS10EquationOfState

# Simulations interface
import Oceananigans: fields, prognostic_fields
import Oceananigans.Architectures: architecture
import Oceananigans.Fields: set!
import Oceananigans.Models: timestepper, NaNChecker, default_nan_checker
import Oceananigans.OutputWriters: default_included_properties
@@ -15,9 +16,8 @@ import Oceananigans.TimeSteppers: time_step!, update_state!, time
import Oceananigans.Utils: prettytime
import Oceananigans.Models: timestepper, NaNChecker, default_nan_checker

struct OceanSeaIceModel{I, A, O, F, C, G} <: AbstractModel{Nothing}
struct OceanSeaIceModel{I, A, O, F, C} <: AbstractModel{Nothing}
clock :: C
grid :: G # TODO: make it so Oceananigans.Simulation does not require this
atmosphere :: A
sea_ice :: I
ocean :: O
@@ -27,20 +27,23 @@ end
const OSIM = OceanSeaIceModel

function Base.summary(model::OSIM)
A = nameof(typeof(architecture(model.grid)))
G = nameof(typeof(model.grid))
return string("OceanSeaIceModel{$A, $G}",
A = nameof(typeof(architecture(model)))
return string("OceanSeaIceModel{$A}",
"(time = ", prettytime(model.clock.time), ", iteration = ", model.clock.iteration, ")")
end

function Base.show(io::IO, cm::OSIM)
print(io, summary(cm), "\n")
print(io, "├── ocean: ", summary(cm.ocean.model), "\n")
print(io, "├── atmosphere: ", summary(cm.atmosphere), "\n")
print(io, "└── sea_ice: ", summary(cm.sea_ice), "\n")
print(io, "├── sea_ice: ", summary(cm.sea_ice), "\n")
print(io, "└── fluxes: ", summary(cm.fluxes))
return nothing
end

# Assumption: We have an ocean!
architecture(model::OSIM) = architecture(model.ocean)

prettytime(model::OSIM) = prettytime(model.clock.time)
iteration(model::OSIM) = model.clock.iteration
timestepper(::OSIM) = nothing
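A tiny sketch of the new accessor introduced here (hedged: `coupled_model` stands for any existing OceanSeaIceModel instance):

# The architecture is now diagnosed from the ocean component instead of a stored grid
arch = architecture(coupled_model)   # equivalent to architecture(coupled_model.ocean)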
@@ -98,7 +101,6 @@ function OceanSeaIceModel(ocean, sea_ice=FreezingLimitedOceanTemperature();
radiation)

ocean_sea_ice_model = OceanSeaIceModel(clock,
ocean.model.grid,
atmosphere,
sea_ice,
ocean,
71 changes: 71 additions & 0 deletions test/test_distributed_utils.jl
@@ -3,7 +3,9 @@ include("runtests_setup.jl")
using MPI
MPI.Init()

using NCDatasets
using ClimaOcean.ECCO: download_dataset, metadata_path
using Oceananigans.DistributedComputations: reconstruct_global_grid
using CFTime
using Dates

@@ -74,3 +76,72 @@ end
@test isfile(metadata_path(metadatum))
end
end

@testset "Distributed Bathymetry interpolation" begin
# We start by building a fake bathymetry on rank 0 and saving it to a file
@root begin
λ = -180:0.1:180
φ = 0:0.1:50

Nλ = length(λ)
Nφ = length(φ)

ds = NCDataset("./trivial_bathymetry.nc", "c")

# Define the dimensions "lon" and "lat" with sizes Nλ and Nφ, respectively
defDim(ds, "lon", Nλ)
defDim(ds, "lat", Nφ)
defVar(ds, "lat", Float32, ("lat", ))
defVar(ds, "lon", Float32, ("lon", ))

# Define the variable z
z = defVar(ds, "z", Float32, ("lon","lat"))

# Generate some example data
data = [Float32(-i) for i = 1:Nλ, j = 1:Nφ]

# Write the complete data set
ds["lon"][:] = λ
ds["lat"][:] = φ
z[:,:] = data

close(ds)
end

global_grid = LatitudeLongitudeGrid(CPU();
size = (40, 40, 1),
longitude = (0, 100),
latitude = (0, 20),
z = (0, 1))

global_height = regrid_bathymetry(global_grid;
dir = "./",
filename = "trivial_bathymetry.nc",
interpolation_passes=10)

arch_x = Distributed(CPU(), partition=Partition(4, 1))
arch_y = Distributed(CPU(), partition=Partition(1, 4))
arch_xy = Distributed(CPU(), partition=Partition(2, 2))

for arch in (arch_x, arch_y, arch_xy)
local_grid = LatitudeLongitudeGrid(arch;
size = (40, 40, 1),
longitude = (0, 100),
latitude = (0, 20),
z = (0, 1))

local_height = regrid_bathymetry(local_grid;
dir = "./",
filename = "trivial_bathymetry.nc",
interpolation_passes=10)

Nx, Ny, _ = size(local_grid)
rx, ry, _ = arch.local_index
irange = (rx - 1) * Nx + 1 : rx * Nx
jrange = (ry - 1) * Ny + 1 : ry * Ny

@handshake begin
@test interior(global_height, irange, jrange, 1) == interior(local_height, :, :, 1)
end
end
end
