Skip to content

Commit

Permalink
Merge pull request #1359 from AayushSabharwal/as/discontinuous-interface
Browse files Browse the repository at this point in the history
feat: add discontinuity handling API
  • Loading branch information
ChrisRackauckas authored Nov 13, 2024
2 parents 195c95d + d4335e9 commit 833a05c
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,7 @@ end
export inverse, left_inverse, right_inverse, @register_inverse, has_inverse, has_left_inverse, has_right_inverse
include("inverse.jl")

export rootfunction, left_continuous_function, right_continuous_function, @register_discontinuity
include("discontinuities.jl")

end # module
106 changes: 106 additions & 0 deletions src/discontinuities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
rootfunction(f)
Given a function `f` with a discontinuity or discontinuous derivative, return the rootfinding
function of `f`. The rootfinding function `g` takes the same arguments as `f`, and is such
that `f` can be described as a piecewise function based on the sign of `g`, where each piece
is continuous and has a continuous derivative. The pieces are obtained using
`left_continuous_function(f)` and `right_continuous_function(f)`.
More formally,
```julia
f(args...) = if g(args...) < 0
left_continuous_function(f)(args...)
else
right_continuous_function(f)(args...)
end
```
For example, if `f` is `max(x, y)`, the root function is `(x, y) -> x - y` with
`left_continuous_function` as `(x, y) -> y` and `right_continuous_function` as
`(x, y) -> x`.
See also: [`left_continuous_function`](@ref), [`right_continuous_function`](@ref).
"""
function rootfunction end

"""
left_continuous_function(f)
Given a function `f` with a discontinuity or discontinuous derivative, return a function
taking the same arguments as `f` which is continuous and has a continuous derivative
when `rootfinding_function(f)` is negative.
See also: [`rootfunction`](@ref).
"""
function left_continuous_function end

"""
right_continuous_function(f)
Given a function `f` with a discontinuity or discontinuous derivative, return a function
taking the same arguments as `f` which is continuous and has a continuous derivative
when `rootfinding_function(f)` is positive.
See also: [`rootfunction`](@ref).
"""
function right_continuous_function end

"""
@register_discontinuity f(arg1, arg2, ...) root_expr left_expr right_expr
Utility macro to register functions with discontinuities. The function `f` with
arguments `arg1, arg2, ...` has a `rootfunction` of `root_expr`, a
`left_continuous_function` of `left_expr` and `right_continuous_function` of
`right_expr`. `root_expr`, `left_expr` and `right_expr` are all expressions in terms
of `arg1, arg2, ...`.
For example, `max(x, y)` can be registered as `@register_discontinuity max(x, y) x - y y x`.
See also: [`rootfunction`](@ref)
"""
macro register_discontinuity(f, root, left, right)
Meta.isexpr(f, :call) || error("Expected function call as first argument")
args = f.args[2:end]
fn = esc(f.args[1])
rootname = gensym(:root)
rootfn = :(function $rootname($(args...))
$root
end)
leftname = gensym(:left)
leftfn = :(function $leftname($(args...))
$left
end)
rightname = gensym(:right)
rightfn = :(function $rightname($(args...))
$right
end)
return quote
$rootfn
(::$typeof($rootfunction))(::$typeof($fn)) = $rootname
$leftfn
(::$typeof($left_continuous_function))(::$typeof($fn)) = $leftname
$rightfn
(::$typeof($right_continuous_function))(::$typeof($fn)) = $rightname
end
end

# a triangle function which is zero when x is a multiple of period
function _triangle(x, period)
x /= 2period
abs(x + 1 // 4 - floor(x + 3 // 4)) - 1 // 2
end

@register_discontinuity abs(x) x -x x
# just needs a rootfind to hit the discontinuity
@register_discontinuity mod(x, y) _triangle(x, y) mod(x, y) mod(x, y)
@register_discontinuity rem(x, y) _triangle(x, y) rem(x, y) rem(x, y)
@register_discontinuity div(x, y) _triangle(x, y) div(x, y) div(x, y)
@register_discontinuity max(x, y) x - y y x
@register_discontinuity min(x, y) x - y x y
@register_discontinuity NaNMath.max(x, y) x - y y x
@register_discontinuity NaNMath.min(x, y) x - y x y
@register_discontinuity <(x, y) x - y true false
@register_discontinuity <=(x, y) y - x false true
@register_discontinuity >(x, y) y - x true false
@register_discontinuity >=(x, y) x - y false true
22 changes: 12 additions & 10 deletions src/inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,26 @@ inverse.
"""
macro register_inverse(f, g, dir::QuoteNode = :(:both))
dir = dir.value
f = esc(f)
g = esc(g)
if dir == :both
quote
(::typeof($inverse))(::typeof($f)) = $g
(::typeof($inverse))(::typeof($g)) = $f
(::typeof($left_inverse))(::typeof($f)) = $(inverse)($f)
(::typeof($right_inverse))(::typeof($f)) = $(inverse)($f)
(::typeof($left_inverse))(::typeof($g)) = $(inverse)($g)
(::typeof($right_inverse))(::typeof($g)) = $(inverse)($g)
(::$typeof($inverse))(::$typeof($f)) = $g
(::$typeof($inverse))(::$typeof($g)) = $f
(::$typeof($left_inverse))(::$typeof($f)) = $(inverse)($f)
(::$typeof($right_inverse))(::$typeof($f)) = $(inverse)($f)
(::$typeof($left_inverse))(::$typeof($g)) = $(inverse)($g)
(::$typeof($right_inverse))(::$typeof($g)) = $(inverse)($g)
end
elseif dir == :left
quote
(::typeof($left_inverse))(::typeof($f)) = $g
(::typeof($right_inverse))(::typeof($g)) = $f
(::$typeof($left_inverse))(::$typeof($f)) = $g
(::$typeof($right_inverse))(::$typeof($g)) = $f
end
elseif dir == :right
quote
(::typeof($right_inverse))(::typeof($f)) = $g
(::typeof($left_inverse))(::typeof($g)) = $f
(::$typeof($right_inverse))(::$typeof($f)) = $g
(::$typeof($left_inverse))(::$typeof($g)) = $f
end
else
throw(ArgumentError("The third argument to `@register_inverse` must be `left` or `right`"))
Expand Down
30 changes: 30 additions & 0 deletions test/discontinuities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using Symbolics, NaNMath, Test

function discontinuity_eval(fn, args...)
if rootfunction(fn)(args...) < 0
left_continuous_function(fn)(args...)
else
right_continuous_function(fn)(args...)
end
end

@testset "abs" begin
for x in -1.0:0.001:1.0
@test abs(x) discontinuity_eval(abs, x)
end
end

@testset "$(nameof(f))" for f in (mod, rem, div)
y = 0.7
for x in -2y:0.001:2y
@test f(x, y) discontinuity_eval(f, x, y)
end
end

@testset "$(nameof(f))" for f in (min, max, NaNMath.min, NaNMath.max, <, <=, >, >=)
for x in 0.0:0.1:1.0
for y in 0.0:0.1:1.0
@test f(x, y) discontinuity_eval(f, x, y)
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "RootFinding solver" begin include("solver.jl") end
@safetestset "Function inverses test" begin include("inverse.jl") end
@safetestset "Taylor Series Test" begin include("taylor.jl") end
@safetestset "Discontinuity registration test" begin include("discontinuities.jl") end
end
end

Expand Down

0 comments on commit 833a05c

Please sign in to comment.