Skip to content

Commit

Permalink
Merge branch 'master' into literal-pow
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer authored Dec 26, 2024
2 parents d875abb + 5dbd23a commit f49ea9d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <[email protected]>"]
version = "1.5.0"
version = "1.5.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ const Dual = ForwardDiff.Dual
#binary: mod
#unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign.

const FloatOrDual = Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}
const FloatOrDual = Union{AbstractFloat,Dual}
# Note that a complex dual is Complex{<:Dual}, so we are safe to use this signature.

# Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl
# Define allowed operators. Any julia operator can also be used.
Expand Down
38 changes: 38 additions & 0 deletions test/test_composable_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,41 @@ end
@test c2() == 0
@test typeof(c2()) === Int
end

@testitem "Test higher-order derivatives of safe_log with DynamicDiff" tags = [:part3] begin
using SymbolicRegression
using SymbolicRegression: D, safe_log, ValidVector
using DynamicExpressions: OperatorEnum
using ForwardDiff: DimensionMismatch

operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(safe_log,))
variable_names = ["x"]
x = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names)

# Test first and second derivatives of log(x)
structure = TemplateStructure{(:f,)}(
((; f), (x,)) ->
ValidVector([(f(x).x[1], D(f, 1)(x).x[1], D(D(f, 1), 1)(x).x[1])], true),
)
expr = TemplateExpression((; f=log(x)); structure, operators, variable_names)

# Test at x = 2.0 where log(x) is well-defined
X = [2.0]'
result = only(expr(X))
@test result !== nothing
@test result[1] == log(2.0) # function value
@test result[2] == 1 / 2.0 # first derivative
@test result[3] == -1 / 4.0 # second derivative

# We handle invalid ranges gracefully:
X_invalid = [-1.0]'
result = only(expr(X_invalid))
@test result !== nothing
@test isnan(result[1])
@test result[2] == 0.0
@test result[3] == 0.0

# Eventually we want to support complex numbers:
X_complex = [-1.0 - 1.0im]'
@test_throws DimensionMismatch expr(X_complex)
end

0 comments on commit f49ea9d

Please sign in to comment.