Skip to content

Commit

Permalink
forego fmm calculation if empty sources and/or targets
Browse files Browse the repository at this point in the history
  • Loading branch information
rymanderson committed May 18, 2024
1 parent 7dca307 commit 393c80c
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 173 deletions.
27 changes: 16 additions & 11 deletions src/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ const SCALAR_STRENGTH = ScalarStrength()
# Field-less singleton type (a subtype of `Indexable`) together with a shared constant instance.
# NOTE(review): presumably used as a dispatch tag for selecting vector-valued strengths — confirm against callers.
struct VectorStrength <: Indexable end
const VECTOR_STRENGTH = VectorStrength()

#####
##### dispatch convenience functions for multipole creation definition
#####
##### dispatch convenience functions for multipole creation definition
#####
# Abstract supertype for kernels, carrying `sign` as a type parameter so it is available to dispatch.
# NOTE(review): the meaning/allowed values of `sign` are not visible in this excerpt — confirm.
abstract type AbstractKernel{sign} end

Expand Down Expand Up @@ -117,6 +117,11 @@ function DerivativesSwitch(scalar_potential::Bool, vector_potential::Bool, veloc
return Tuple(DerivativesSwitch{scalar_potential, vector_potential, velocity, velocity_gradient}() for _ in target_systems)
end

"""
    DerivativesSwitch(scalar_potential, vector_potential, velocity, velocity_gradient, target_systems::Tuple)

Build one `DerivativesSwitch` per entry of `target_systems`, where the `i`th switch takes its
four type parameters from the `i`th element of each of the indexable flag collections.

All four flag collections must have the same length as `target_systems`; otherwise the
assertion fails. Returns a `Tuple` of switches with one entry per target system.
"""
function DerivativesSwitch(scalar_potential, vector_potential, velocity, velocity_gradient, target_systems::Tuple)
    n_systems = length(target_systems)

    # every flag collection must provide exactly one entry per target system
    @assert length(scalar_potential) == length(vector_potential) == length(velocity) == length(velocity_gradient) == n_systems "length of inputs to DerivativesSwitch inconsistent"

    # the i-th switch lifts the i-th flag of each collection into the type domain
    return ntuple(n_systems) do i
        DerivativesSwitch{scalar_potential[i], vector_potential[i], velocity[i], velocity_gradient[i]}()
    end
end

"""
    DerivativesSwitch(scalar_potential::Bool, vector_potential::Bool, velocity::Bool, velocity_gradient::Bool, target_system)

Return a single `DerivativesSwitch` whose type parameters are the four given flags.
`target_system` is accepted only to select this method; its value is not read.
"""
DerivativesSwitch(scalar_potential::Bool, vector_potential::Bool, velocity::Bool, velocity_gradient::Bool, target_system) =
    DerivativesSwitch{scalar_potential, vector_potential, velocity, velocity_gradient}()
Expand All @@ -140,13 +145,13 @@ DerivativesSwitch() = DerivativesSwitch{true, true, true, true}()
# end

# SingleCostParameters(;
# alloc_M2M_L2L = ALLOC_M2M_L2L_DEFAULT,
# tau_M2M_L2L = TAU_M2M_DEFAULT,
# alloc_M2L = ALLOC_M2L_DEFAULT,
# tau_M2L = TAU_M2L_DEFAULT,
# tau_L2L = TAU_L2L_DEFAULT,
# alloc_L2B = ALLOC_L2B_DEFAULT,
# tau_L2B = TAU_L2B_DEFAULT,
# alloc_M2M_L2L = ALLOC_M2M_L2L_DEFAULT,
# tau_M2M_L2L = TAU_M2M_DEFAULT,
# alloc_M2L = ALLOC_M2L_DEFAULT,
# tau_M2L = TAU_M2L_DEFAULT,
# tau_L2L = TAU_L2L_DEFAULT,
# alloc_L2B = ALLOC_L2B_DEFAULT,
# tau_L2B = TAU_L2B_DEFAULT,
# C_nearfield = C_NEARFIELD_DEFAULT,
# tau_B2M = TAU_B2M_DEFAULT
# ) = SingleCostParameters(alloc_M2M_L2L, tau_M2M_L2L, alloc_M2L, tau_M2L, alloc_L2B, tau_L2B, C_nearfield, tau_B2M)
Expand Down Expand Up @@ -176,10 +181,10 @@ DerivativesSwitch() = DerivativesSwitch{true, true, true, true}()
# CostParameters(systems::Tuple) = MultiCostParameters()
# CostParameters(system) = SingleCostParameters()

# CostParameters(alloc_M2M_L2L, tau_B2M, alloc_M2L, tau_M2L, tau_L2L, alloc_L2B, tau_L2B, C_nearfield::Float64, tau_M2M_L2L) =
# CostParameters(alloc_M2M_L2L, tau_B2M, alloc_M2L, tau_M2L, tau_L2L, alloc_L2B, tau_L2B, C_nearfield::Float64, tau_M2M_L2L) =
# SingleCostParameters(alloc_M2M_L2L, tau_B2M, alloc_M2L, tau_M2L, tau_L2L, alloc_L2B, tau_L2B, C_nearfield, tau_M2M_L2L)

# CostParameters(alloc_M2M_L2L, tau_B2M, alloc_M2L, tau_M2L, tau_L2L, alloc_L2B, tau_L2B, C_nearfield::SVector, tau_M2M_L2L) =
# CostParameters(alloc_M2M_L2L, tau_B2M, alloc_M2L, tau_M2L, tau_L2L, alloc_L2B, tau_L2B, C_nearfield::SVector, tau_M2M_L2L) =
# MultiCostParameters(alloc_M2M_L2L, tau_B2M, alloc_M2L, tau_M2L, tau_L2L, alloc_L2B, tau_L2B, C_nearfield, tau_M2M_L2L)

#####
Expand Down
164 changes: 83 additions & 81 deletions src/fmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function body_2_multipole_multithread!(branches, systems::Tuple, expansion_order
end
i_thread <= n_threads && (leaf_assignments[i_system,i_thread] = i_start:length(leaf_index))
end

## compute multipole expansion coefficients
Threads.@threads for i_thread in 1:n_threads
for (i_system,system) in enumerate(systems)
Expand Down Expand Up @@ -147,10 +147,10 @@ end
function translate_multipoles_multithread!(branches, expansion_order::Val{P}, levels_index) where P
# initialize memory
n_threads = Threads.nthreads()

# iterate over levels
for level_index in view(levels_index,length(levels_index):-1:2)

# load balance
n_branches = length(level_index)
n_per_thread, rem = divrem(n_branches, n_threads)
Expand Down Expand Up @@ -197,7 +197,7 @@ function nearfield_multithread!(target_system, target_branches, derivatives_swit
## load balance
n_threads = Threads.nthreads()
assignments = Vector{UnitRange{Int64}}(undef,n_threads)

for (i_source_system, source_system) in enumerate(source_systems)
# total number of interactions
n_interactions = 0
Expand All @@ -206,7 +206,7 @@ function nearfield_multithread!(target_system, target_branches, derivatives_swit
source_leaf = view(source_branches,i_source)
n_interactions += get_n_bodies(target_leaf[].bodies_index) * get_n_bodies(source_leaf[].bodies_index[i_source_system])
end

# interactions per thread
n_per_thread, rem = divrem(n_interactions, n_threads)
rem > 0 && (n_per_thread += 1)
Expand All @@ -233,7 +233,7 @@ function nearfield_multithread!(target_system, target_branches, derivatives_swit
end
end
i_thread <= n_threads && (assignments[i_thread] = i_start:length(direct_list))

# execute tasks
Threads.@threads for i_thread in eachindex(assignments)
assignment = assignments[i_thread]
Expand All @@ -254,15 +254,15 @@ function nearfield_multithread!(target_system, target_branches, derivatives_swit
## load balance
n_threads = Threads.nthreads()
assignments = Vector{UnitRange{Int64}}(undef,n_threads)

# total number of interactions
n_interactions = 0
for (i_target, i_source) in direct_list
target_leaf = view(target_branches,i_target)
source_leaf = view(source_branches,i_source)
n_interactions += get_n_bodies(target_leaf[].bodies_index) * get_n_bodies(source_leaf[].bodies_index)
end

# interactions per thread
n_per_thread, rem = divrem(n_interactions, n_threads)
rem > 0 && (n_per_thread += 1)
Expand All @@ -289,7 +289,7 @@ function nearfield_multithread!(target_system, target_branches, derivatives_swit
end
end
i_thread <= n_threads && (assignments[i_thread] = i_start:length(direct_list))

# execute tasks
Threads.@threads for assignment in assignments
for (i_target, j_source) in view(direct_list, assignment)
Expand Down Expand Up @@ -354,7 +354,7 @@ function horizontal_pass_multithread!(target_branches, source_branches::Vector{<
# Threads.@lock target_branch.lock M2L!(target_branch, source_branches[j_source], expansion_order)
end
end

return nothing
end

Expand Down Expand Up @@ -382,10 +382,10 @@ end
function translate_locals_multithread!(branches, expansion_order::Val{P}, levels_index) where P
# initialize memory
n_threads = Threads.nthreads()

# iterate over levels
for level_index in view(levels_index,2:length(levels_index))

# divide chunks
n_per_thread, rem = divrem(length(level_index),n_threads)
rem > 0 && (n_per_thread += 1)
Expand Down Expand Up @@ -450,8 +450,8 @@ end
# Multithreaded downward pass: first translate the local expansions through the tree
# levels, then evaluate the local expansions at the bodies of the leaf branches.
# Returns whatever `local_2_body_multithread!` returns.
function downward_pass_multithread!(branches, systems, derivatives_switch, expansion_order, levels_index, leaf_index)
    # step 1: local expansion translation, level by level (via `translate_locals_multithread!`)
    translate_locals_multithread!(branches, expansion_order, levels_index)

    # step 2: local-to-body interaction on the leaves listed in `leaf_index`
    return local_2_body_multithread!(branches, systems, derivatives_switch, expansion_order, leaf_index)
end

Expand Down Expand Up @@ -503,12 +503,12 @@ Apply all interactions of `source_systems` acting on `target_systems` using the
- a system object for which compatibility functions have been overloaded, or
- a tuple of system objects for which compatibility functions have been overloaded
- `source_systems`: either
- a system object for which compatibility functions have been overloaded, or
- a tuple of system objects for which compatibility functions have been overloaded
# Optional Arguments
- `expansion_order::Int`: the expansion order to be used
Expand All @@ -532,10 +532,10 @@ Apply all interactions of `source_systems` acting on `target_systems` using the
"""
function fmm!(target_systems, source_systems;
scalar_potential=true, vector_potential=true, velocity=true, velocity_gradient=true,
expansion_order=5, n_per_branch_source=50, n_per_branch_target=50, multipole_acceptance_criterion=0.4,
nearfield=true, farfield=true, self_induced=true,
unsort_source_bodies=true, unsort_target_bodies=true,
source_shrink_recenter=true, target_shrink_recenter=true,
expansion_order=5, n_per_branch_source=50, n_per_branch_target=50, multipole_acceptance_criterion=0.4,
nearfield=true, farfield=true, self_induced=true,
unsort_source_bodies=true, unsort_target_bodies=true,
source_shrink_recenter=true, target_shrink_recenter=true,
save_tree=false, save_name="tree"
)
# check for duplicate systems
Expand All @@ -546,17 +546,17 @@ function fmm!(target_systems, source_systems;
target_tree = Tree(target_systems; expansion_order, n_per_branch=n_per_branch_target, shrink_recenter=target_shrink_recenter)

# perform fmm
fmm!(target_tree, target_systems, source_tree, source_systems;
scalar_potential, vector_potential, velocity, velocity_gradient,
multipole_acceptance_criterion,
reset_source_tree=false, reset_target_tree=false,
nearfield, farfield, self_induced,
fmm!(target_tree, target_systems, source_tree, source_systems;
scalar_potential, vector_potential, velocity, velocity_gradient,
multipole_acceptance_criterion,
reset_source_tree=false, reset_target_tree=false,
nearfield, farfield, self_induced,
unsort_source_bodies, unsort_target_bodies
)

# visualize
save_tree && (visualize(save_name, systems, tree))

return source_tree, target_tree
end

Expand Down Expand Up @@ -588,26 +588,26 @@ Apply all interactions of `systems` acting on itself using the fast multipole me
- `shrink_recenter::Bool`: indicates whether or not to resize branches for the octree after it is created to increase computational efficiency
- `save_tree::Bool`: indicates whether or not to save a VTK file for visualizing the octree
- `save_name::String`: name and path of the octree visualization if `save_tree == true`
"""
function fmm!(systems;
function fmm!(systems;
scalar_potential=true, vector_potential=true, velocity=true, velocity_gradient=true,
expansion_order=5, n_per_branch=50, multipole_acceptance_criterion=0.4,
nearfield=true, farfield=true, self_induced=true,
unsort_bodies=true, shrink_recenter=true,
expansion_order=5, n_per_branch=50, multipole_acceptance_criterion=0.4,
nearfield=true, farfield=true, self_induced=true,
unsort_bodies=true, shrink_recenter=true,
save_tree=false, save_name="tree"
)
# create tree
tree = Tree(systems; expansion_order, n_per_branch, shrink_recenter)

# perform fmm
fmm!(tree, systems;
scalar_potential, vector_potential, velocity, velocity_gradient,
multipole_acceptance_criterion, reset_tree=false,
nearfield, farfield, self_induced,
multipole_acceptance_criterion, reset_tree=false,
nearfield, farfield, self_induced,
unsort_bodies
)

# visualize
save_tree && (visualize(save_name, systems, tree))

Expand Down Expand Up @@ -638,18 +638,18 @@ Dispatches `fmm!` using an existing `::Tree`.
- `farfield::Bool`: indicates whether far-field (computed with multipoles) interactions should be included
- `self_induced::Bool`: indicates whether to include the interactions of each leaf-level branch on itself
- `unsort_bodies::Bool`: indicates whether or not to undo the sort operation used to generate the octree for `systems`
"""
function fmm!(tree::Tree, systems;
        scalar_potential=true, vector_potential=true, velocity=true, velocity_gradient=true,
        multipole_acceptance_criterion=0.4, reset_tree=true,
        nearfield=true, farfield=true, self_induced=true,
        unsort_bodies=true
    )
    # `systems` acts on itself, so the same tree and systems are passed as both target and source.
    # `reset_tree` and `unsort_bodies` are forwarded through the source-side keywords only;
    # the target-side keywords are switched off so the shared tree/systems are not processed twice.
    fmm!(tree, systems, tree, systems;
        scalar_potential, vector_potential, velocity, velocity_gradient,
        multipole_acceptance_criterion, reset_source_tree=reset_tree, reset_target_tree=false,
        nearfield, farfield, self_induced,
        unsort_source_bodies=unsort_bodies, unsort_target_bodies=false
    )
end
Expand Down Expand Up @@ -690,49 +690,51 @@ Dispatches `fmm!` using existing `::Tree` objects.
"""
function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_systems;
        scalar_potential=true, vector_potential=true, velocity=true, velocity_gradient=true,
        multipole_acceptance_criterion=0.4,
        reset_source_tree=true, reset_target_tree=true,
        nearfield=true, farfield=true, self_induced=true,
        unsort_source_bodies=true, unsort_target_bodies=true
    )
    # check if systems are empty; the calculation is skipped (with a warning) when
    # either side has no bodies
    n_sources = get_n_bodies(source_systems)
    n_targets = get_n_bodies(target_systems)

    if n_sources > 0 && n_targets > 0

        # reset multipole/local expansions; each flag resets its own tree
        # (previously `reset_target_tree` reset `source_tree`, so the target tree was never reset)
        reset_target_tree && reset_expansions!(target_tree)
        reset_source_tree && reset_expansions!(source_tree)

        # create interaction lists
        m2l_list, direct_list = build_interaction_lists(target_tree.branches, source_tree.branches, multipole_acceptance_criterion, farfield, nearfield, self_induced)

        # assemble derivatives switch
        derivatives_switch = DerivativesSwitch(scalar_potential, vector_potential, velocity, velocity_gradient, target_systems)

        # run FMM
        if Threads.nthreads() == 1
            nearfield && (nearfield_singlethread!(target_systems, target_tree.branches, derivatives_switch, source_systems, source_tree.branches, direct_list))
            if farfield
                upward_pass_singlethread!(source_tree.branches, source_systems, source_tree.expansion_order)
                horizontal_pass_singlethread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order)
                downward_pass_singlethread!(target_tree.branches, target_systems, derivatives_switch, target_tree.expansion_order)
            end
        else # multithread
            # the multithreaded near-field and horizontal passes are only invoked for non-empty lists
            nearfield && length(direct_list) > 0 && (nearfield_multithread!(target_systems, target_tree.branches, derivatives_switch, source_systems, source_tree.branches, direct_list))
            if farfield
                upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index)
                length(m2l_list) > 0 && (horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order))
                downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switch, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index)
            end
        end

    else
        n_sources == 0 && (@warn "fmm! called but the source system is empty; foregoing calculation")
        n_targets == 0 && (@warn "fmm! called but the target system is empty; foregoing calculation")
    end

    # unsort bodies (skipped for an empty system)
    n_sources > 0 && unsort_source_bodies && unsort!(source_systems, source_tree)
    n_targets > 0 && unsort_target_bodies && unsort!(target_systems, target_tree)

end
1 change: 0 additions & 1 deletion src/probe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ Adds `n_probes` probes in a line between `x1` and `x2`. Specifically, they are a
dx = (x2 - x1) / n_probes
x = x1 + dx/2
for i in 1:n_probes
i_start += 1
probes.position[i_last + i] = x
x += dx
end
Expand Down
Loading

0 comments on commit 393c80c

Please sign in to comment.