From 97d73449ba0995eddc2afd6f28fa06c43c56f138 Mon Sep 17 00:00:00 2001 From: Ali Ramadhan Date: Tue, 29 Oct 2024 17:56:13 -0600 Subject: [PATCH] Allow `FieldTimeSeries` to pass keyword arguments to `jldopen` (#3739) Co-authored-by: Simone Silvestri --- src/OutputReaders/field_dataset.jl | 24 ++++-- src/OutputReaders/field_time_series.jl | 84 +++++++++++-------- .../field_time_series_indexing.jl | 23 +++-- src/OutputReaders/set_field_time_series.jl | 9 +- test/test_forcings.jl | 12 +-- test/test_output_readers.jl | 40 +++++++-- 6 files changed, 122 insertions(+), 70 deletions(-) diff --git a/src/OutputReaders/field_dataset.jl b/src/OutputReaders/field_dataset.jl index dc6072eb6d..cf22389ce2 100644 --- a/src/OutputReaders/field_dataset.jl +++ b/src/OutputReaders/field_dataset.jl @@ -1,7 +1,8 @@ -struct FieldDataset{F, M, P} - fields :: F - metadata :: M - filepath :: P +struct FieldDataset{F, M, P, KW} + fields :: F + metadata :: M + filepath :: P + reader_kw :: KW end """ @@ -22,17 +23,24 @@ linearly. `file["metadata"]`. - `grid`: May be specified to override the grid used in the JLD2 file. + +- `reader_kw`: A dictionary of keyword arguments to pass to the reader (currently only JLD2) + to be used when opening files. """ function FieldDataset(filepath; - architecture=CPU(), grid=nothing, backend=InMemory(), metadata_paths=["metadata"]) + architecture = CPU(), + grid = nothing, + backend = InMemory(), + metadata_paths = ["metadata"], + reader_kw = Dict{Symbol, Any}()) - file = jldopen(filepath) + file = jldopen(filepath; reader_kw...) field_names = keys(file["timeseries"]) filter!(k -> k != "t", field_names) # Time is not a field. ds = Dict{String, FieldTimeSeries}( - name => FieldTimeSeries(filepath, name; architecture, backend, grid) + name => FieldTimeSeries(filepath, name; architecture, backend, grid, reader_kw) for name in field_names ) @@ -44,7 +52,7 @@ function FieldDataset(filepath; close(file) - return FieldDataset(ds, metadata, abspath(filepath)) + return FieldDataset(ds, metadata, abspath(filepath), reader_kw) end Base.getindex(fds::FieldDataset, inds...) = Base.getindex(fds.fields, inds...) diff --git a/src/OutputReaders/field_time_series.jl b/src/OutputReaders/field_time_series.jl index 23704ac625..a07377e22c 100644 --- a/src/OutputReaders/field_time_series.jl +++ b/src/OutputReaders/field_time_series.jl @@ -85,7 +85,7 @@ period = t[end] - t[1] + Δt """ struct Cyclical{FT} period :: FT -end +end Cyclical() = Cyclical(nothing) @@ -164,7 +164,7 @@ Nt = 5 backend = InMemory(4, 3) # so we have (4, 5, 1) n = 1 # so, the right answer is m̃ = 3 m = 1 - (4 - 1) # = -2 -m̃ = mod1(-2, 5) # = 3 ✓ +m̃ = mod1(-2, 5) # = 3 ✓ ``` # Another shifting + wrapping example @@ -213,7 +213,7 @@ Base.length(backend::PartlyInMemory) = backend.length ##### FieldTimeSeries ##### -mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: AbstractField{LX, LY, LZ, G, ET, 4} +mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW} <: AbstractField{LX, LY, LZ, G, ET, 4} data :: D grid :: G backend :: K @@ -223,16 +223,18 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A path :: P name :: N time_indexing :: TI - + reader_kw :: KW + function FieldTimeSeries{LX, LY, LZ}(data::D, grid::G, backend::K, bcs::B, - indices::I, + indices::I, times, path, name, - time_indexing) where {LX, LY, LZ, K, D, G, B, I} + time_indexing, + reader_kw) where {LX, LY, LZ, K, D, G, B, I} ET = eltype(data) @@ -250,7 +252,7 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A times = on_architecture(architecture(grid), times) end - + if time_indexing isa Cyclical{Nothing} # we have to infer the period Δt = @allowscalar times[end] - times[end-1] period = @allowscalar times[end] - times[1] + Δt @@ -261,23 +263,25 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A TI = typeof(time_indexing) P = typeof(path) N = typeof(name) + KW = typeof(reader_kw) - return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N}(data, grid, backend, bcs, - indices, times, path, name, - time_indexing) + return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW}(data, grid, backend, bcs, + indices, times, path, name, + time_indexing, reader_kw) end end -on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} = +on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} = FieldTimeSeries{LX, LY, LZ}(on_architecture(to, fts.data), on_architecture(to, fts.grid), on_architecture(to, fts.backend), on_architecture(to, fts.bcs), - on_architecture(to, fts.indices), + on_architecture(to, fts.indices), on_architecture(to, fts.times), on_architecture(to, fts.path), on_architecture(to, fts.name), - on_architecture(to, fts.time_indexing)) + on_architecture(to, fts.time_indexing), + on_architecture(to, fts.reader_kw)) ##### ##### Minimal implementation of FieldTimeSeries for use in GPU kernels @@ -290,7 +294,7 @@ struct GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K, ET, D, χ} <: AbstractField{ times :: χ backend :: K time_indexing :: TI - + function GPUAdaptedFieldTimeSeries{LX, LY, LZ}(data::D, times::χ, backend::K, @@ -313,7 +317,7 @@ const FTS{LX, LY, LZ, TI, K} = FieldTimeSeries{LX, LY, LZ, TI, K} w const GPUFTS{LX, LY, LZ, TI, K} = GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K} where {LX, LY, LZ, TI, K} const FlavorOfFTS{LX, LY, LZ, TI, K} = Union{GPUFTS{LX, LY, LZ, TI, K}, - FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K} + FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K} const InMemoryFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:AbstractInMemoryBackend} const OnDiskFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:OnDisk} @@ -345,7 +349,7 @@ instantiate(T::Type) = T() new_data(FT, grid, loc, indices, ::Nothing) = nothing # Apparently, not explicitly specifying Int64 in here makes this function -# fail on x86 processors where `Int` is implied to be `Int32` +# fail on x86 processors where `Int` is implied to be `Int32` # see ClimaOcean commit 3c47d887659d81e0caed6c9df41b7438e1f1cd52 at https://github.com/CliMA/ClimaOcean.jl/actions/runs/8804916198/job/24166354095) function new_data(FT, grid, loc, indices, Nt::Union{Int, Int64}) space_size = total_size(grid, loc, indices) @@ -360,12 +364,13 @@ time_indices_length(backend::PartlyInMemory, times) = length(backend) time_indices_length(::OnDisk, times) = nothing function FieldTimeSeries(loc, grid, times=(); - indices = (:, :, :), + indices = (:, :, :), backend = InMemory(), - path = nothing, + path = nothing, name = nothing, time_indexing = Linear(), - boundary_conditions = nothing) + boundary_conditions = nothing, + reader_kw = Dict{Symbol, Any}()) LX, LY, LZ = loc @@ -376,9 +381,9 @@ function FieldTimeSeries(loc, grid, times=(); isnothing(path) && error(ArgumentError("Must provide the keyword argument `path` when `backend=OnDisk()`.")) isnothing(name) && error(ArgumentError("Must provide the keyword argument `name` when `backend=OnDisk()`.")) end - - return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, - indices, times, path, name, time_indexing) + + return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices, + times, path, name, time_indexing, reader_kw) end """ @@ -405,10 +410,16 @@ end struct UnspecifiedBoundaryConditions end """ - FieldTimeSeries(path, name, backend = InMemory(); + FieldTimeSeries(path, name; + backend = InMemory(), + architecture = nothing, grid = nothing, + location = nothing, + boundary_conditions = UnspecifiedBoundaryConditions(), + time_indexing = Linear(), iterations = nothing, - times = nothing) + times = nothing, + reader_kw = Dict{Symbol, Any}()) Return a `FieldTimeSeries` containing a time-series of the field `name` load from JLD2 output located at `path`. @@ -427,6 +438,9 @@ Keyword arguments - `times`: Save times to load, as determined through an approximate floating point comparison to recorded save times. Defaults to times associated with `iterations`. Takes precedence over `iterations` if `times` is specified. + +- `reader_kw`: A dictionary of keyword arguments to pass to the reader (currently only JLD2) + to be used when opening files. """ function FieldTimeSeries(path::String, name::String; backend = InMemory(), @@ -436,9 +450,10 @@ function FieldTimeSeries(path::String, name::String; boundary_conditions = UnspecifiedBoundaryConditions(), time_indexing = Linear(), iterations = nothing, - times = nothing) + times = nothing, + reader_kw = Dict{Symbol, Any}()) - file = jldopen(path) + file = jldopen(path; reader_kw...) # Defaults isnothing(iterations) && (iterations = parse.(Int, keys(file["timeseries/t"]))) @@ -520,8 +535,8 @@ function FieldTimeSeries(path::String, name::String; Nt = time_indices_length(backend, times) data = new_data(eltype(grid), grid, loc, indices, Nt) - time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, - indices, times, path, name, time_indexing) + time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices, + times, path, name, time_indexing, reader_kw) set!(time_series, path, name) @@ -533,7 +548,8 @@ end grid = nothing, architecture = nothing, indices = (:, :, :), - boundary_conditions = nothing) + boundary_conditions = nothing, + reader_kw = Dict{Symbol, Any}()) Load a field called `name` saved in a JLD2 file at `path` at `iter`ation. Unless specified, the `grid` is loaded from `path`. @@ -542,7 +558,8 @@ function Field(location, path::String, name::String, iter; grid = nothing, architecture = nothing, indices = (:, :, :), - boundary_conditions = nothing) + boundary_conditions = nothing, + reader_kw = Dict{Symbol, Any}()) # Default to CPU if neither architecture nor grid is specified if isnothing(architecture) @@ -552,9 +569,9 @@ function Field(location, path::String, name::String, iter; architecture = Architectures.architecture(grid) end end - + # Load the grid and data from file - file = jldopen(path) + file = jldopen(path; reader_kw...) isnothing(grid) && (grid = file["serialized/grid"]) raw_data = file["timeseries/$name/$iter"] @@ -565,7 +582,7 @@ function Field(location, path::String, name::String, iter; grid = on_architecture(architecture, grid) raw_data = on_architecture(architecture, raw_data) data = offset_data(raw_data, grid, location, indices) - + return Field(location, grid; boundary_conditions, indices, data) end @@ -625,4 +642,3 @@ function fill_halo_regions!(fts::InMemoryFTS) return nothing end - diff --git a/src/OutputReaders/field_time_series_indexing.jl b/src/OutputReaders/field_time_series_indexing.jl index 6a4683f9d0..59868330e7 100644 --- a/src/OutputReaders/field_time_series_indexing.jl +++ b/src/OutputReaders/field_time_series_indexing.jl @@ -14,7 +14,7 @@ import Oceananigans.Fields: interpolate # Cyclical implementation if out-of-bounds (wrap around the time-series) @inline function interpolating_time_indices(ti::Cyclical, times, t) Nt = length(times) - t¹ = first(times) + t¹ = first(times) tᴺ = last(times) T = ti.period @@ -32,14 +32,14 @@ import Oceananigans.Fields: interpolate uncycled_indices = (ñ, n₁, n₂) return ifelse(cycling, cycled_indices, uncycled_indices) -end +end # Clamp mode if out-of-bounds, i.e get the neareast neighbor @inline function interpolating_time_indices(::Clamp, times, t) n, n₁, n₂ = time_index_binary_search(times, t) beyond_indices = (0, n₂, n₂) # Beyond the last time: return n₂ - before_indices = (0, n₁, n₁) # Before the first time: return n₁ + before_indices = (0, n₁, n₁) # Before the first time: return n₁ unclamped_indices = (n, n₁, n₂) # Business as usual Nt = length(times) @@ -53,13 +53,13 @@ end @inline function time_index_binary_search(times, t) Nt = length(times) - # n₁ and n₂ are the index to interpolate inbetween and + # n₁ and n₂ are the index to interpolate inbetween and # n is a fractional index where 0 ≤ n ≤ 1 n₁, n₂ = index_binary_search(times, t, Nt) @inbounds begin - t₁ = times[n₁] - t₂ = times[n₂] + t₁ = times[n₁] + t₂ = times[n₂] end # "Fractional index" ñ ∈ (0, 1) @@ -79,7 +79,7 @@ import Base: getindex function getindex(fts::OnDiskFTS, n::Int) # Load data arch = architecture(fts) - file = jldopen(fts.path) + file = jldopen(fts.path; fts.reader_kw...) iter = keys(file["timeseries/t"])[n] raw_data = on_architecture(arch, file["timeseries/$(fts.name)/$iter"]) close(file) @@ -117,7 +117,7 @@ const YZFTS = FlavorOfFTS{Nothing, <:Any, <:Any, <:Any, <:Any} @inline function interpolating_getindex(fts, i, j, k, time_index) ñ, n₁, n₂ = interpolating_time_indices(fts.time_indexing, fts.times, time_index.time) - + @inbounds begin ψ₁ = getindex(fts, i, j, k, n₁) ψ₂ = getindex(fts, i, j, k, n₂) @@ -229,14 +229,14 @@ end ##### FieldTimeSeries updating ##### -# Let's make sure `times` is available on the CPU. This is always the case -# for ranges. if `times` is a vector that resides on the GPU, it has to be moved to the CPU for safe indexing. +# Let's make sure `times` is available on the CPU. This is always the case +# for ranges. if `times` is a vector that resides on the GPU, it has to be moved to the CPU for safe indexing. # TODO: Copying the whole array is a bit unclean, maybe find a way that avoids the penalty of allocating and copying memory. # This would require refactoring `FieldTimeSeries` to include a cpu-allocated times array cpu_interpolating_time_indices(::CPU, times, time_indexing, t, arch) = interpolating_time_indices(time_indexing, times, t) cpu_interpolating_time_indices(::CPU, times::AbstractVector, time_indexing, t) = interpolating_time_indices(time_indexing, times, t) -function cpu_interpolating_time_indices(::GPU, times::AbstractVector, time_indexing, t) +function cpu_interpolating_time_indices(::GPU, times::AbstractVector, time_indexing, t) cpu_times = on_architecture(CPU(), times) return interpolating_time_indices(time_indexing, cpu_times, t) end @@ -279,4 +279,3 @@ function getindex(fts::InMemoryFTS, n::Int) return Field(location(fts), fts.grid; data, fts.boundary_conditions, fts.indices) end - diff --git a/src/OutputReaders/set_field_time_series.jl b/src/OutputReaders/set_field_time_series.jl index d450926b3e..577082782e 100644 --- a/src/OutputReaders/set_field_time_series.jl +++ b/src/OutputReaders/set_field_time_series.jl @@ -11,7 +11,7 @@ find_time_index(time::Number, file_times) = findfirst(t -> t ≈ time, fil find_time_index(time::AbstractTime, file_times) = findfirst(t -> t == time, file_times) function set!(fts::InMemoryFTS, path::String=fts.path, name::String=fts.name) - file = jldopen(path) + file = jldopen(path; fts.reader_kw...) file_iterations = iterations_from_file(file) file_times = [file["timeseries/t/$i"] for i in file_iterations] close(file) @@ -33,7 +33,7 @@ function set!(fts::InMemoryFTS, path::String=fts.path, name::String=fts.name) end file_iter = file_iterations[file_index] - + # Note: use the CPU for this step field_n = Field(location(fts), path, name, file_iter, architecture = cpu_architecture(arch), @@ -51,7 +51,7 @@ set!(fts::InMemoryFTS, value, n::Int) = set!(fts[n], value) function set!(fts::InMemoryFTS, fields_vector::AbstractVector{<:AbstractField}) raw_data = parent(fts) - file = jldopen(path) + file = jldopen(path; fts.reader_kw...) for (n, field) in enumerate(fields_vector) nth_raw_data = view(raw_data, :, :, :, n) @@ -68,7 +68,7 @@ end function maybe_write_property!(file, property, data) try test = file[property] - catch + catch file[property] = data end end @@ -101,4 +101,3 @@ function initialize_file!(file, name, fts) end set!(fts::OnDiskFTS, path::String, name::String) = nothing - diff --git a/test/test_forcings.jl b/test/test_forcings.jl index e4e4e48294..c4056ea7e0 100644 --- a/test/test_forcings.jl +++ b/test/test_forcings.jl @@ -118,7 +118,7 @@ end function time_step_with_field_time_series_forcing(arch) grid = RectilinearGrid(arch, size=(1, 1, 1), extent=(1, 1, 1)) - + u_forcing = FieldTimeSeries{Face, Center, Center}(grid, 0:1:3) for (t, time) in enumerate(u_forcing.times) @@ -134,14 +134,14 @@ function time_step_with_field_time_series_forcing(arch) model = NonhydrostaticModel(; grid, forcing=(; u=u_forcing)) time_step!(model, 2) time_step!(model, 2) - + @test u_forcing.backend.start == 4 return true end function relaxed_time_stepping(arch) - x_relax = Relaxation(rate = 1/60, mask = GaussianMask{:x}(center=0.5, width=0.1), + x_relax = Relaxation(rate = 1/60, mask = GaussianMask{:x}(center=0.5, width=0.1), target = LinearTarget{:x}(intercept=π, gradient=ℯ)) y_relax = Relaxation(rate = 1/60, mask = GaussianMask{:y}(center=0.5, width=0.1), @@ -197,7 +197,7 @@ end function two_forcings(arch) grid = RectilinearGrid(arch, size=(4, 5, 6), extent=(1, 1, 1), halo=(4, 4, 4)) - + forcing1 = Relaxation(rate=1) forcing2 = Relaxation(rate=2) @@ -221,7 +221,7 @@ function seven_forcings(arch) peculiar_forcing(x, y, z, t) = 2t / z eccentric_forcing(x, y, z, t) = x + y + z + t unconventional_forcing(x, y, z, t) = 10x * y - + F1 = Forcing(weird_forcing) F2 = Forcing(wonky_forcing) F3 = Forcing(strange_forcing) @@ -269,7 +269,7 @@ end @test time_step_with_multiple_field_dependent_forcing(arch) @test time_step_with_parameterized_field_dependent_forcing(arch) - end + end @testset "Relaxation forcing functions [$A]" begin @info " Testing relaxation forcing functions [$A]..." diff --git a/test/test_output_readers.jl b/test/test_output_readers.jl index fd51e1ca7a..7d4657cd81 100644 --- a/test/test_output_readers.jl +++ b/test/test_output_readers.jl @@ -290,8 +290,8 @@ end @test t[1, 1, 1] == 3.8 end - @testset "Test chunked abstraction" begin - @info " Testing Chunked abstraction..." + @testset "Test chunked abstraction" begin + @info " Testing Chunked abstraction..." filepath = "testfile.jld2" fts = FieldTimeSeries(filepath, "c") fts_chunked = FieldTimeSeries(filepath, "c"; backend = InMemory(2), time_indexing = Cyclical()) @@ -342,8 +342,8 @@ end end for Backend in [InMemory, OnDisk] - @testset "FieldDataset{$Backend}" begin - @info " Testing FieldDataset{$Backend}..." + @testset "FieldDataset{$Backend} indexing" begin + @info " Testing FieldDataset{$Backend} indexing..." ds = FieldDataset(filepath3d, backend=Backend()) @@ -354,7 +354,7 @@ end @test ds[var_str] isa FieldTimeSeries @test ds[var_str][1] isa Field end - + for var_sym in (:u, :v, :w, :T, :S, :b, :ζ, :ke) @test ds[var_sym] isa FieldTimeSeries @test ds[var_sym][2] isa Field @@ -371,6 +371,36 @@ end end end + for Backend in [InMemory, OnDisk] + @testset "FieldTimeSeries{$Backend} parallel reading" begin + @info " Testing FieldTimeSeries{$Backend} parallel reading..." + + reader_kw = Dict(:parallel_read => true) + u3 = FieldTimeSeries(filepath3d, "u"; backend=Backend(), reader_kw) + b3 = FieldTimeSeries(filepath3d, "b"; backend=Backend(), reader_kw) + + @test u3 isa FieldTimeSeries + @test b3 isa FieldTimeSeries + @test u3[1] isa Field + @test b3[1] isa Field + end + end + + for Backend in [InMemory, OnDisk] + @testset "FieldDataset{$Backend} parallel reading" begin + @info " Testing FieldDataset{$Backend} parallel reading..." + + reader_kw = Dict(:parallel_read => true) + ds = FieldDataset(filepath3d; backend=Backend(), reader_kw) + + @test ds isa FieldDataset + @test ds.u isa FieldTimeSeries + @test ds.b isa FieldTimeSeries + @test ds.u[1] isa Field + @test ds.b[1] isa Field + end + end + rm(filepath1d) rm(filepath2d) rm(filepath3d)