Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added check function for weighted data #189

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions src/curve_fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,26 @@ StatsAPI.residuals(lfr::LsqFitResult) = lfr.resid
mse(lfr::LsqFitResult) = rss(lfr) / dof(lfr)
isconverged(lsr::LsqFitResult) = lsr.converged

function check_data_health(xdata, ydata)
if any(ismissing, xdata) || any(ismissing, ydata)
error("Data contains `missing` values and a fit cannot be performed")
function check_data_health(xdata, ydata, wt = [])
if any(ismissing, xdata)
error("The independent variable (`x`) contains `missing` values and a fit cannot be performed")
end
if any(isinf, xdata) || any(isinf, ydata) || any(isnan, xdata) || any(isnan, ydata)
error("Data contains `Inf` or `NaN` values and a fit cannot be performed")
if any(ismissing, ydata)
error("The dependent variable (`y`) contains `missing` values and a fit cannot be performed")
end
if any(ismissing, wt)
error("Weight data contains `missing` values and a fit cannot be performed")
end
if any(!isfinite, xdata)
error("The independent variable (`x`) contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed")
end
if any(!isfinite, ydata)
error("The dependent variable (`y`) contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed")
end
if any(!isfinite, wt)
error("Weight contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed")
end

end

# provide a method for those who have their own Jacobian function
Expand Down Expand Up @@ -174,7 +187,7 @@ function curve_fit(
inplace=false,
kwargs...,
)
check_data_health(xdata, ydata)
check_data_health(xdata, ydata, wt)
# construct a weighted cost function, with a vector weight for each ydata
# for example, this might be wt = 1/sigma where sigma is some error term
u = sqrt.(wt) # to be consistent with the matrix form
Expand All @@ -198,7 +211,7 @@ function curve_fit(
inplace=false,
kwargs...,
)
check_data_health(xdata, ydata)
check_data_health(xdata, ydata, wt)
u = sqrt.(wt) # to be consistent with the matrix form

if inplace
Expand All @@ -220,7 +233,7 @@ function curve_fit(
p0::AbstractArray;
kwargs...,
)
check_data_health(xdata, ydata)
check_data_health(xdata, ydata, wt)

# as before, construct a weighted cost function, except that this
# method uses a matrix weight.
Expand All @@ -244,7 +257,7 @@ function curve_fit(
p0::AbstractArray;
kwargs...,
)
check_data_health(xdata, ydata)
check_data_health(xdata, ydata, wt)

u = cholesky(wt).U

Expand Down
16 changes: 7 additions & 9 deletions test/curve_fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@ using LsqFit, Test, StableRNGs, LinearAlgebra
@testset "curve fit" begin
# before testing the model, check whether missing/null data is rejected
tdata = [rand(1:10, 5)..., missing]
@test_throws ErrorException(
"Data contains `missing` values and a fit cannot be performed",
) LsqFit.check_data_health(tdata, tdata)
@test_throws ErrorException("The independent variable (`x`) contains `missing` values and a fit cannot be performed") LsqFit.check_data_health(tdata, tdata)
tdata = [rand(1:10, 5)..., Inf]
@test_throws ErrorException(
"Data contains `Inf` or `NaN` values and a fit cannot be performed",
) LsqFit.check_data_health(tdata, tdata)
@test_throws ErrorException("The independent variable (`x`) contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed") LsqFit.check_data_health(tdata, tdata)
tdata = [rand(1:10, 5)..., NaN]
@test_throws ErrorException(
"Data contains `Inf` or `NaN` values and a fit cannot be performed",
) LsqFit.check_data_health(tdata, tdata)
@test_throws ErrorException("The independent variable (`x`) contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed") LsqFit.check_data_health(tdata, tdata)

# fitting noisy data to an exponential model
# TODO: Change to `.-x` when 0.5 support is dropped
model(x, p) = p[1] .* exp.(-x .* p[2])

for T in (Float64, BigFloat)
# fitting noisy data to an exponential model
Expand Down
20 changes: 20 additions & 0 deletions test/curve_fit_inplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,26 @@ using LsqFit, Test, StableRNGs, LinearAlgebra
)
@test fit_wt.converged

@testset "bad input" begin
    # Mutate private copies rather than the shared fixtures: the first sample
    # of each array is overwritten on every iteration below.
    xs = copy(collect(xdata))
    ys = copy(ydata)
    ws = 1 ./ sqrt.(yvars)

    # Remember the original (valid) first entries so the single all-valid
    # combination can be recognised inside the loop.
    good_x, good_y, good_w = xs[1], ys[1], ws[1]
    bad_values = (Inf, -Inf, NaN)

    # Try every combination of {valid, Inf, -Inf, NaN} in the first slot of
    # x, y and the weights.
    for x in (good_x, bad_values...),
        y in (good_y, bad_values...),
        wt in (good_w, bad_values...)

        xs[1], ys[1], ws[1] = x, y, wt

        all_valid = x == good_x && y == good_y && wt == good_w
        if all_valid
            # Only the untouched data set is expected to fit without complaint.
            @test_nowarn curve_fit(
                model, jacobian_model, xs, ys, ws, [0.5, 0.5]; maxIter=100
            )
        else
            # Any non-finite entry is expected to raise an ErrorException.
            @test_throws ErrorException curve_fit(
                model, jacobian_model, xs, ys, ws, [0.5, 0.5]; maxIter=100
            )
        end
    end
end

println("\t Inplace with weights")
fit_inplace_wt = @time curve_fit(
model_inplace,
Expand Down
Loading