diff --git a/src/curve_fit.jl b/src/curve_fit.jl index e45fd26..24801bc 100755 --- a/src/curve_fit.jl +++ b/src/curve_fit.jl @@ -16,13 +16,26 @@ StatsAPI.residuals(lfr::LsqFitResult) = lfr.resid mse(lfr::LsqFitResult) = rss(lfr) / dof(lfr) isconverged(lsr::LsqFitResult) = lsr.converged -function check_data_health(xdata, ydata) - if any(ismissing, xdata) || any(ismissing, ydata) - error("Data contains `missing` values and a fit cannot be performed") +function check_data_health(xdata, ydata, wt = []) + if any(ismissing, xdata) + error("The independent variable (`x`) contains `missing` values and a fit cannot be performed") end - if any(isinf, xdata) || any(isinf, ydata) || any(isnan, xdata) || any(isnan, ydata) - error("Data contains `Inf` or `NaN` values and a fit cannot be performed") + if any(ismissing, ydata) + error("The dependent variable (`y`) contains `missing` values and a fit cannot be performed") end + if any(ismissing, wt) + error("Weight data contains `missing` values and a fit cannot be performed") + end + if any(!isfinite, xdata) + error("The independent variable (`x`) contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed") + end + if any(!isfinite, ydata) + error("The dependent variable (`y`) contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed") + end + if any(!isfinite, wt) + error("Weight contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed") + end + end # provide a method for those who have their own Jacobian function @@ -174,7 +187,7 @@ function curve_fit( inplace=false, kwargs..., ) - check_data_health(xdata, ydata) + check_data_health(xdata, ydata, wt) # construct a weighted cost function, with a vector weight for each ydata # for example, this might be wt = 1/sigma where sigma is some error term u = sqrt.(wt) # to be consistant with the matrix form @@ -198,7 +211,7 @@ function curve_fit( inplace=false, kwargs..., ) - check_data_health(xdata, ydata) + check_data_health(xdata, ydata, wt) u = sqrt.(wt) # to be consistant with the matrix form if inplace @@ -220,7 +233,7 @@ function curve_fit( p0::AbstractArray; kwargs..., ) - check_data_health(xdata, ydata) + check_data_health(xdata, ydata, wt) # as before, construct a weighted cost function with where this # method uses a matrix weight. @@ -244,7 +257,7 @@ function curve_fit( p0::AbstractArray; kwargs..., ) - check_data_health(xdata, ydata) + check_data_health(xdata, ydata, wt) u = cholesky(wt).U diff --git a/test/curve_fit.jl b/test/curve_fit.jl index 0be8e3e..50694ab 100755 --- a/test/curve_fit.jl +++ b/test/curve_fit.jl @@ -2,17 +2,15 @@ using LsqFit, Test, StableRNGs, LinearAlgebra @testset "curve fit" begin # before testing the model, check whether missing/null data is rejected tdata = [rand(1:10, 5)..., missing] - @test_throws ErrorException( - "Data contains `missing` values and a fit cannot be performed", - ) LsqFit.check_data_health(tdata, tdata) + @test_throws ErrorException("The independent variable (`x`) contains `missing` values and a fit cannot be performed") LsqFit.check_data_health(tdata, tdata) tdata = [rand(1:10, 5)..., Inf] - @test_throws ErrorException( - "Data contains `Inf` or `NaN` values and a fit cannot be performed", - ) LsqFit.check_data_health(tdata, tdata) + @test_throws ErrorException("The independent variable (`x`) contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed") LsqFit.check_data_health(tdata, tdata) tdata = [rand(1:10, 5)..., NaN] - @test_throws ErrorException( - "Data contains `Inf` or `NaN` values and a fit cannot be performed", - ) LsqFit.check_data_health(tdata, tdata) + @test_throws ErrorException("The independent variable (`x`) contains non-finite (e.g. `Inf`, `NaN`) values and a fit cannot be performed") LsqFit.check_data_health(tdata, tdata) + + # fitting noisy data to an exponential model + # TODO: Change to `.-x` when 0.5 support is dropped + model(x, p) = p[1] .* exp.(-x .* p[2]) for T in (Float64, BigFloat) # fitting noisy data to an exponential model diff --git a/test/curve_fit_inplace.jl b/test/curve_fit_inplace.jl index e20c3f0..9ea63b5 100644 --- a/test/curve_fit_inplace.jl +++ b/test/curve_fit_inplace.jl @@ -168,6 +168,26 @@ using LsqFit, Test, StableRNGs, LinearAlgebra ) @test fit_wt.converged + @testset "bad input" begin + xxdata = copy(collect(xdata)) + yydata = copy(ydata) + WWT = 1 ./ sqrt.(yvars) + x1 = xxdata[1] + y1 = yydata[1] + wt1 = WWT[1] + for x in (x1, Inf, -Inf, NaN), y in (y1, Inf, -Inf, NaN), wt in (wt1, Inf, -Inf, NaN) + + xxdata[1] = x + yydata[1] = y + WWT[1] = wt + if x == x1 && y == y1 && wt == wt1 + @test_nowarn curve_fit(model, jacobian_model, xxdata, yydata, WWT, [0.5, 0.5]; maxIter=100) + else + @test_throws ErrorException curve_fit(model, jacobian_model, xxdata, yydata, WWT, [0.5, 0.5]; maxIter=100) + end + end + end + println("\t Inplace with weights") fit_inplace_wt = @time curve_fit( model_inplace,