Skip to content

Commit

Permalink
feat: Nonlinear observation operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Giles authored and Dan Giles committed Nov 3, 2023
1 parent 0095bdc commit 748e6f9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
22 changes: 19 additions & 3 deletions test/models/lorenz63.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Base.@kwdef struct Lorenz63ModelParameters{S <: Real, T <: Real}
initial_state_std::Union{S, Vector{S}} = 0.05
state_noise_std::Union{S, Vector{S}} = 0.05
observation_noise_std::Union{T, Vector{T}} = 2.
operator_type::String = "linear"
end

function get_params(
Expand Down Expand Up @@ -103,15 +104,28 @@ function ParticleDA.update_state_stochastic!(
rand!(rng, state + model.state_noise_distribution, state)
end

function observation_operator!(
observation::AbstractVector{T},
operator_type::String
) where {T <: Real}
if operator_type == "log"
observation .= log.(abs.(observation))
else
observation .= observation
end
end

function ParticleDA.sample_observation_given_state!(
observation::AbstractVector{T},
state::AbstractVector{S},
model::Lorenz63Model{S, T},
rng::Random.AbstractRNG,
) where {S <: Real, T <: Real}

observation .= view(state, model.parameters.observed_indices)
rand!(
rng,
view(state, model.parameters.observed_indices)
observation_operator!(observation, model.parameters.operator_type)
+ model.observation_noise_distribution,
observation
)
Expand All @@ -120,9 +134,11 @@ end
function ParticleDA.get_log_density_observation_given_state(
observation::AbstractVector{T}, state::AbstractVector{S}, model::Lorenz63Model{S, T}
) where {S <: Real, T <: Real}

obs_given_state = (view(state, model.parameters.observed_indices) + model.observation_noise_distribution)
observation_operator!(obs_given_state.μ, model.parameters.operator_type)
return logpdf(
view(state, model.parameters.observed_indices)
+ model.observation_noise_distribution,
obs_given_state,
observation
)
end
Expand Down
2 changes: 0 additions & 2 deletions test/models/lorenz96.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ function observation_operator!(
) where {T <: Real}
if operator_type == "log"
observation .= log.(abs.(observation))
elseif operator_type == "square"
observation .= (observation).^2
else
observation .= observation
end
Expand Down

0 comments on commit 748e6f9

Please sign in to comment.