-
Notifications
You must be signed in to change notification settings - Fork 154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ArrayDifferentialOperators for Vector calculus #942
base: master
Are you sure you want to change the base?
Changes from 10 commits
bccb4f1
03ac562
69c1265
72b7855
dfc7d94
73352b9
0559b06
caa49cc
75fb18c
09943a4
af60a35
dc4a596
068ad64
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
abstract type Operator <: Function end | ||
abstract type AbstractOperator <: Function end | ||
abstract type Operator <: AbstractOperator end | ||
|
||
""" | ||
$(TYPEDEF) | ||
|
@@ -33,17 +34,17 @@ struct Differential <: Operator | |
x | ||
Differential(x) = new(value(x)) | ||
end | ||
(D::Differential)(x) = Term{symtype(x)}(D, [x]) | ||
(D::Differential)(x::Num) = Num(D(value(x))) | ||
(D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x))))) | ||
(D::Operator)(x) = Term{symtype(x)}(D, [x]) | ||
(D::Operator)(x::Num) = Num(D(value(x))) | ||
(D::Operator)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x))))) | ||
SymbolicUtils.promote_symtype(::Differential, x) = x | ||
|
||
is_derivative(x) = istree(x) ? operation(x) isa Differential : false | ||
|
||
Base.:*(D1, D2::Differential) = D1 ∘ D2 | ||
Base.:*(D1::Differential, D2) = D1 ∘ D2 | ||
Base.:*(D1::Differential, D2::Differential) = D1 ∘ D2 | ||
Base.:^(D::Differential, n::Integer) = _repeat_apply(D, n) | ||
Base.:*(D1, D2::Operator) = D1 ∘ D2 | ||
Base.:*(D1::Operator, D2) = D1 ∘ D2 | ||
Base.:*(D1::Operator, D2::Operator) = D1 ∘ D2 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering why we have these methods @ChrisRackauckas ? this does not make sense in general, only maybe for 2 operators There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's a case where they don't make sense? |
||
Base.:^(D::Operator, n::Integer) = _repeat_apply(D, n) | ||
|
||
Base.show(io::IO, D::Differential) = print(io, "Differential(", D.x, ")") | ||
|
||
|
@@ -785,3 +786,94 @@ end | |
function SymbolicUtils.substitute(op::Differential, dict; kwargs...) | ||
@set! op.x = substitute(op.x, dict; kwargs...) | ||
end | ||
|
||
|
||
####################################################################################################################### | ||
# Vector Calculus | ||
####################################################################################################################### | ||
abstract type ArrayOperator end | ||
|
||
struct ArrayDifferentialOperator <: ArrayOperator | ||
"""The variables to differentiate with resp≈ect to.""" | ||
vars | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ? |
||
differentials | ||
name | ||
ArrayDifferentialOperator(vars, differentials, name) = new(vars, differentials, name) | ||
end | ||
Nabla(vars) = ArrayDifferentialOperator(value.(vars), map(Differential, scalarize(value.(vars))), "∇") | ||
const Grad = Nabla | ||
Div(vars) = (x) -> Nabla(vars) ⋅ x | ||
Curl(vars) = (x) -> Nabla(vars) × x | ||
Laplacian(vars) = Nabla(vars) ⋅ Nabla(vars) | ||
|
||
#? How to get transpose and Jac working? | ||
|
||
function (D::ArrayDifferentialOperator)(x::SymVec) | ||
@assert length(D.vars) == length(x) "Vector must be same length as vars in Operator $(D.name)." | ||
@arrayop (i,) (D.differentials)[i](x[i]) term=D(x) | ||
end | ||
(D::ArrayDifferentialOperator)(x::Arr) = Arr(D(value(x))) | ||
|
||
function (D1::ArrayDifferentialOperator)(D2::ArrayDifferentialOperator) | ||
@assert all(x -> any(isequal.((x,), D2.vars)), D1.vars) | ||
|
||
ArrayDifferentialOperator(D1.vars, scalarize(D1.differentials .∘ D2.differentials), "("*D1.name*"∘"*D2.name*")") | ||
end | ||
|
||
function LinearAlgebra.dot(D::ArrayDifferentialOperator, x::SymVec) | ||
@assert length(D.vars) == length(x) "Vector must be same length as vars in Operator $(D.name)." | ||
@show D(x), scalarize(D(x)) | ||
sum(scalarize(D(x))) | ||
end | ||
LinearAlgebra.dot(D::ArrayDifferentialOperator, x::Arr) = Num(D ⋅ value(x)) | ||
|
||
function LinearAlgebra.dot(x::SymVec, D::ArrayDifferentialOperator) | ||
@assert length(D.vars) == length(x) "Vector must be same length as vars in Operator $(D.name)." | ||
(y) -> sum(@arrayop (i,) x[i]*D.differentials[i](y) term = (x⋅D)(y)) | ||
end | ||
LinearAlgebra.dot(x::Arr, D::ArrayDifferentialOperator) = value(x) ⋅ D | ||
|
||
function LinearAlgebra.dot(D1::ArrayDifferentialOperator, D2::ArrayDifferentialOperator) | ||
@assert all(scalarize(isequal.(D1.vars, D2.vars))) "Operators have different variables and cannot be composed." | ||
lap = x -> sum((D1.differentials[i] ∘ D2.differentials[i])(x) for i in 1:length(D1.vars)) | ||
(x) -> @arrayop (i,) lap(x[i]) term=(D1⋅D2)(x) reduce=+ | ||
end | ||
|
||
function crosscompose(a, b) | ||
v1 = x -> (a[2] ∘ b[3])(x) - (a[3] ∘ b[2])(x) | ||
v2 = x -> (a[3] ∘ b[1])(x) - (a[1] ∘ b[3])(x) | ||
v3 = x -> (a[1] ∘ b[2])(x) - (a[2] ∘ b[1])(x) | ||
return [v1, v2, v3] | ||
end | ||
|
||
function crosscall(a, b) | ||
v1 = a[2](b[3]) - a[3](b[2]) | ||
v2 = a[3](b[1]) - a[1](b[3]) | ||
v3 = a[1](b[2]) - a[2](b[1]) | ||
return [v1, v2, v3] | ||
end | ||
function LinearAlgebra.cross(D::ArrayDifferentialOperator, x::SymVec) | ||
@assert length(D.vars) == length(x) == 3 "Cross product is only defined in 3 dimensions." | ||
curl = crosscall(D.differentials, x) | ||
@arrayop (i,) curl[i] term=D×x | ||
end | ||
LinearAlgebra.cross(D::ArrayDifferentialOperator, x::Arr) = Arr(D × value(x)) | ||
|
||
function LinearAlgebra.cross(D1::ArrayDifferentialOperator, D2::ArrayDifferentialOperator) | ||
@assert length(D1.vars) == length(D2.vars) == 3 "Cross product is only defined in 3 dimensions." | ||
@assert all(scalarize(isequal.(D1.vars, D2.vars))) "Operators have different variables and cannot be composed." | ||
|
||
ArrayDifferentialOperator(D1.vars, crosscompose(D1.differentials, D2.differentials), "("*D1.name*"×"*D2.name*")") | ||
end | ||
|
||
SymbolicUtils.promote_symtype(::ArrayDifferentialOperator, x) = x | ||
|
||
Base.show(io::IO, D::ArrayDifferentialOperator) = print(io, D.name) | ||
Base.nameof(D::ArrayDifferentialOperator) = Symbol(D.name) | ||
|
||
function Base.:(==)(D1::ArrayDifferentialOperator, D2::ArrayDifferentialOperator) | ||
@variables x[1:length(D1.vars)] | ||
all(scalarize(isequal.(D1.vars, D2.vars))) && all(scalarize(isequal.(D1(x), D2(x)))) | ||
end | ||
|
||
# TODO: Add simplification rules for dot and cross products to remove 0 terms and simplify |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need 2 of these?