# Feature: High-level network specification (#2050)
Implement a high-level network specification as proposed in
#418. Gap junctions are not supported, which allows domain
decomposition to be used for distributed network generation.
The general idea is a DSL based on set algebra that operates on the
set of all possible connections, selecting by criteria such as the
distance between cells or lists of labels. Because it operates on all
possible connections, a separate definition of cell populations becomes
unnecessary. An example selecting all inter-cell connections with a
given source and destination label:
`(intersect (inter-cell) (source-label "detector") (destination-label "syn"))`

For parameters such as weight and delay, a value can be defined in the
DSL in a similar way, with the usual mathematical operations available.
An example would be:
`(max 0.1 (exp (mul -0.5 (distance))))`
which decays exponentially with distance and is clamped below at 0.1.
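
As a sanity check, the same formula transcribed into plain Python (a
direct translation, nothing Arbor-specific):

```py
import math

def weight(distance_um: float) -> float:
    # mirrors (max 0.1 (exp (mul -0.5 (distance))))
    return max(0.1, math.exp(-0.5 * distance_um))
```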

The position of each connection site is calculated by resolving its
local position on the cell and applying an isometry, which is provided
by a new optional function of the recipe. In contrast to using policies
to select a member within a locset, each site is treated individually
and can be distinguished by its position.

Internally, several steps have been implemented to reduce the overhead
of generating connections:
- Pre-select source and destination sites based on the selection to
  reduce the sampling space when possible.
- If the selection is limited to a maximum distance, use an octree for
  efficient spatial sampling.
- When using MPI, only instantiate local cells and exchange source sites
  in a ring communication pattern to overlap communication and sampling
  (sketched below). This also reduces memory usage, since only the
  current and next source sites have to be stored during the exchange.
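
To illustrate the ring pattern, a schematic sketch in Python with
mpi4py; `local_source_sites` and `sample_connections` are hypothetical
placeholders, and the exchange is shown blocking for brevity, whereas
the actual implementation (`distributed_for_each` below) overlaps it
with sampling via non-blocking MPI:

```py
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
left, right = (rank - 1) % size, (rank + 1) % size

block = local_source_sites()   # hypothetical: source sites of this rank's cells
for step in range(size):
    sample_connections(block)  # hypothetical: sample against this rank's targets
    if step < size - 1:
        # forward the current block around the ring; only the current and
        # next blocks are ever held in memory
        block = comm.sendrecv(block, dest=right, source=left)
```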

Custom selection and value functions can still be provided by storing
the wrapped function in a dictionary under an associated label, which
can then be referenced in the DSL.
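
A hedged sketch of what this could look like from Python; the callable
signature and the `named` accessor used here are assumptions for
illustration, not documented API:

```py
# hypothetical custom selection registered under the label "my-sel"
def my_selection(src, dst):
    return src.gid % 2 == 0  # e.g. keep only even source gids

desc = arbor.network_description(
    '(intersect (inter-cell) (named "my-sel"))',  # accessor name assumed
    "(scalar 0.01)",  # weight
    "(scalar 5.0)",   # delay, ms
    {"my-sel": my_selection},
)
```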

Some challenges remain. In particular, it is unclear how explicit
connections returned by `connections_on` should be combined with the
new network description. Also, non-blocking MPI is not easily integrated
into the current context types, and the dry-run context is not supported
so far.


# Example
A (trimmed) example in Python, in which a ring connection is combined
with random, distance-dependent connections:
```py
import math

import arbor

# trimmed: the constructor setting self.ncells and the remaining recipe
# methods are omitted
class recipe(arbor.recipe):
    def cell_isometry(self, gid):
        # place cells with equal distance on a circle
        radius = 500.0 # μm
        angle = 2.0 * math.pi * gid / self.ncells
        return arbor.isometry.translate(radius * math.cos(angle), radius * math.sin(angle), 0)

    def network_description(self):
        seed = 42

        # create a chain
        ring = f"(chain (gid-range 0 {self.ncells}))"
        # connect front and back of chain to form ring
        ring = f"(join {ring} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))"

        # Create random connections with probability decreasing linearly
        # with distance, within a maximum radius
        max_dist = 400.0 # μm
        probability = f"(div (sub {max_dist} (distance)) {max_dist})"
        rand = f"(intersect (random {seed} {probability}) (distance-lt {max_dist}))"

        # combine ring with random selection
        s = f"(join {ring} {rand})"
        # restrict to inter-cell connections and certain source / destination labels
        s = f"(intersect {s} (inter-cell) (source-label \"detector\") (destination-label \"syn\"))"

        # normally distributed weight with mean 0.02 μS, standard deviation 0.01 μS
        # and truncated to [0.005, 0.035]
        w = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)"
        # fixed delay
        d = "(scalar 5.0)"  # ms delay

        return arbor.network_description(s, w, d, {})
```
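
Assuming the omitted recipe methods (`num_cells`, `cell_kind`,
`cell_description`, and a constructor setting `self.ncells`) are
implemented as usual, the recipe is then used like any other; a minimal
sketch:

```py
rec = recipe()
sim = arbor.simulation(rec)  # default context and load balancing
sim.record(arbor.spike_recording.all)
sim.run(1000)  # ms
```
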
Co-authored-by: Thorsten Hater <[email protected]>
AdhocMan authored May 22, 2024
1 parent 9f20e83 commit 689eea3
Showing 54 changed files with 6,259 additions and 89 deletions.
2 changes: 2 additions & 0 deletions arbor/CMakeLists.txt
```diff
@@ -44,6 +44,8 @@ set(arbor_sources
     morph/segment_tree.cpp
     morph/stitch.cpp
     merge_events.cpp
+    network.cpp
+    network_impl.cpp
     simulation.cpp
     partition_load_balance.cpp
     profile/clock.cpp
```
53 changes: 35 additions & 18 deletions arbor/communication/communicator.cpp
```diff
@@ -14,6 +14,7 @@
 #include "connection.hpp"
 #include "distributed_context.hpp"
 #include "execution_context.hpp"
+#include "network_impl.hpp"
 #include "profile/profiler_macro.hpp"
 #include "threading/threading.hpp"
 #include "util/partition.hpp"
@@ -24,14 +25,12 @@
 
 namespace arb {
 
-communicator::communicator(const recipe& rec,
-                           const domain_decomposition& dom_dec,
-                           execution_context& ctx): num_total_cells_{rec.num_cells()},
-                                                    num_local_cells_{dom_dec.num_local_cells()},
-                                                    num_local_groups_{dom_dec.num_groups()},
-                                                    num_domains_{(cell_size_type) ctx.distributed->size()},
-                                                    distributed_{ctx.distributed},
-                                                    thread_pool_{ctx.thread_pool} {}
+communicator::communicator(const recipe& rec, const domain_decomposition& dom_dec, context ctx):
+    num_total_cells_{rec.num_cells()},
+    num_local_cells_{dom_dec.num_local_cells()},
+    num_local_groups_{dom_dec.num_groups()},
+    num_domains_{(cell_size_type)ctx->distributed->size()},
+    ctx_(std::move(ctx)) {}
 
 constexpr inline
 bool is_external(cell_gid_type c) {
@@ -55,7 +54,7 @@ cell_member_type global_cell_of(const cell_member_type& c) {
     return {c.gid | msb, c.index};
 }
 
-void communicator::update_connections(const connectivity& rec,
+void communicator::update_connections(const recipe& rec,
                                       const domain_decomposition& dom_dec,
                                       const label_resolution_map& source_resolution_map,
                                       const label_resolution_map& target_resolution_map) {
@@ -67,6 +66,9 @@ void communicator::update_connections(const connectivity& rec,
     index_divisions_.clear();
     PL();
 
+    // Construct connections from high-level specification
+    auto generated_connections = generate_connections(rec, ctx_, dom_dec);
+
     // Make a list of local cells' connections
     //   -> gid_connections
     // Count the number of local connections (i.e. connections terminating on this domain)
@@ -114,9 +116,18 @@
         }
         part_ext_connections.push_back(gid_ext_connections.size());
     }
+    for (const auto& c: generated_connections) {
+        auto sgid = c.source.gid;
+        if (sgid >= num_total_cells_) {
+            throw arb::bad_connection_source_gid(c.source.gid, sgid, num_total_cells_);
+        }
+        const auto src = dom_dec.gid_domain(sgid);
+        src_domains.push_back(src);
+        src_counts[src]++;
+    }
 
     util::make_partition(connection_part_, src_counts);
-    auto n_cons = gid_connections.size();
+    auto n_cons = gid_connections.size() + generated_connections.size();
     auto n_ext_cons = gid_ext_connections.size();
     PL();
@@ -132,6 +143,7 @@
     auto target_resolver = resolver(&target_resolution_map);
     for (const auto index: util::make_span(num_local_cells_)) {
         const auto tgt_gid = gids[index];
+        const auto iod = dom_dec.index_on_domain(tgt_gid);
         auto source_resolver = resolver(&source_resolution_map);
         for (const auto cidx: util::make_span(part_connections[index], part_connections[index+1])) {
             const auto& conn = gid_connections[cidx];
@@ -141,18 +153,23 @@
             auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
             auto offset = offsets[*src_domain]++;
             ++src_domain;
-            connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
+            connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, iod};
         }
         for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) {
             const auto& conn = gid_ext_connections[cidx];
             auto src = global_cell_of(conn.source);
             auto src_gid = conn.source.rid;
             if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
             auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
-            ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
+            ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, iod};
             ++ext;
         }
     }
+    for (const auto& c: generated_connections) {
+        auto offset = offsets[*src_domain]++;
+        ++src_domain;
+        connections[offset] = c;
+    }
     PL();
 
     PE(init:communicator:update:index);
@@ -167,7 +184,7 @@
     // Sort the connections for each domain.
     // This is num_domains_ independent sorts, so it can be parallelized trivially.
     const auto& cp = connection_part_;
-    threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
+    threading::parallel_for::apply(0, num_domains_, ctx_->thread_pool.get(),
                                    [&](cell_size_type i) {
                                        util::sort(util::subrange_view(connections, cp[i], cp[i+1]));
                                    });
@@ -193,7 +210,7 @@ time_type communicator::min_delay() {
     res = std::accumulate(ext_connections_.delays.begin(), ext_connections_.delays.end(),
                           res,
                           [](auto&& acc, time_type del) { return std::min(acc, del); });
-    res = distributed_->min(res);
+    res = ctx_->distributed->min(res);
     return res;
 }
@@ -206,7 +223,7 @@ communicator::exchange(std::vector<spike> local_spikes) {
 
     PE(communication:exchange:gather);
     // global all-to-all to gather a local copy of the global spike list on each node.
-    auto global_spikes = distributed_->gather_spikes(local_spikes);
+    auto global_spikes = ctx_->distributed->gather_spikes(local_spikes);
     num_spikes_ += global_spikes.size();
     PL();
 
@@ -217,7 +234,7 @@
                       local_spikes.end(),
                       [this] (const auto& s) { return !remote_spike_filter_(s); }));
     }
-    auto remote_spikes = distributed_->remote_gather_spikes(local_spikes);
+    auto remote_spikes = ctx_->distributed->remote_gather_spikes(local_spikes);
     PL();
 
     PE(communication:exchange:gather:remote:post_process);
@@ -231,8 +248,8 @@
 }
 
 void communicator::set_remote_spike_filter(const spike_predicate& p) { remote_spike_filter_ = p; }
-void communicator::remote_ctrl_send_continue(const epoch& e) { distributed_->remote_ctrl_send_continue(e); }
-void communicator::remote_ctrl_send_done() { distributed_->remote_ctrl_send_done(); }
+void communicator::remote_ctrl_send_continue(const epoch& e) { ctx_->distributed->remote_ctrl_send_continue(e); }
+void communicator::remote_ctrl_send_done() { ctx_->distributed->remote_ctrl_send_done(); }
 
 // Given
 //   * a set of connections and an index into the set
```
13 changes: 6 additions & 7 deletions arbor/communication/communicator.hpp
```diff
@@ -1,11 +1,11 @@
 #pragma once
 
 #include <vector>
 #include <unordered_set>
 
-#include <arbor/export.hpp>
 #include <arbor/common_types.hpp>
+#include <arbor/context.hpp>
 #include <arbor/domain_decomposition.hpp>
+#include <arbor/export.hpp>
 #include <arbor/recipe.hpp>
 #include <arbor/spike.hpp>
 
@@ -40,7 +40,7 @@ class ARB_ARBOR_API communicator {
 
     explicit communicator(const recipe& rec,
                           const domain_decomposition& dom_dec,
-                          execution_context& ctx);
+                          context ctx);
 
     /// The range of event queues that belong to cells in group i.
     std::pair<cell_size_type, cell_size_type> group_queue_range(cell_size_type i);
@@ -78,7 +78,7 @@ class ARB_ARBOR_API communicator {
     void remote_ctrl_send_continue(const epoch&);
     void remote_ctrl_send_done();
 
-    void update_connections(const connectivity& rec,
+    void update_connections(const recipe& rec,
                             const domain_decomposition& dom_dec,
                             const label_resolution_map& source_resolution_map,
                             const label_resolution_map& target_resolution_map);
@@ -98,7 +98,7 @@ class ARB_ARBOR_API communicator {
         for (const auto& con: cons) {
             idx_on_domain.push_back(con.index_on_domain);
             srcs.push_back(con.source);
-            dests.push_back(con.destination);
+            dests.push_back(con.target);
             weights.push_back(con.weight);
             delays.push_back(con.delay);
         }
@@ -136,10 +136,9 @@ class ARB_ARBOR_API communicator {
     // Currently we have no partitions/indices/acceleration structures
     connection_list ext_connections_;
 
-    distributed_context_handle distributed_;
-    task_system_handle thread_pool_;
     std::uint64_t num_spikes_ = 0u;
     std::uint64_t num_local_events_ = 0u;
+    context ctx_;
 };
 
 } // namespace arb
```
185 changes: 185 additions & 0 deletions arbor/communication/distributed_for_each.hpp
```cpp
#pragma once

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <type_traits>
#include <utility>

#include "distributed_context.hpp"
#include "util/range.hpp"

namespace arb {

namespace impl {
template <class FUNC, typename... T, std::size_t... Is>
void for_each_in_tuple(FUNC&& func, std::tuple<T...>& t, std::index_sequence<Is...>) {
    (func(Is, std::get<Is>(t)), ...);
}

template <class FUNC, typename... T>
void for_each_in_tuple(FUNC&& func, std::tuple<T...>& t) {
    for_each_in_tuple(func, t, std::index_sequence_for<T...>());
}

template <class FUNC, typename... T1, typename... T2, std::size_t... Is>
void for_each_in_tuple_pair(FUNC&& func,
                            std::tuple<T1...>& t1,
                            std::tuple<T2...>& t2,
                            std::index_sequence<Is...>) {
    (func(Is, std::get<Is>(t1), std::get<Is>(t2)), ...);
}

template <class FUNC, typename... T1, typename... T2>
void for_each_in_tuple_pair(FUNC&& func, std::tuple<T1...>& t1, std::tuple<T2...>& t2) {
    for_each_in_tuple_pair(func, t1, t2, std::index_sequence_for<T1...>());
}

} // namespace impl


/*
 * Collective operation, calling func on args supplied by each rank exactly once. The order of calls
 * is unspecified. Requires
 *
 * - Item = util::range<ARGS>::value_type to be identical across all ranks
 * - Item is trivially_copyable
 * - Alignment of Item must not exceed std::max_align_t
 * - func to be a callable type with signature
 *   void func(util::range<Item*>...)
 * - func must not modify contents of range
 * - All ranks in distributed must call this collectively.
 */
template <typename FUNC, typename... ARGS>
void distributed_for_each(FUNC&& func,
                          const distributed_context& distributed,
                          const util::range<ARGS>&... args) {

    static_assert(sizeof...(args) > 0);
    auto arg_tuple = std::forward_as_tuple(args...);

    struct vec_info {
        std::size_t offset;  // offset in bytes
        std::size_t size;    // size in bytes
    };

    std::array<vec_info, sizeof...(args)> info;
    std::size_t buffer_size = 0;

    // Compute offsets in bytes for each vector when placed in common buffer
    {
        std::size_t offset = info.size() * sizeof(vec_info);
        impl::for_each_in_tuple(
            [&](std::size_t i, auto&& vec) {
                using T = typename std::remove_reference_t<decltype(vec)>::value_type;
                static_assert(std::is_trivially_copyable_v<T>);
                static_assert(alignof(std::max_align_t) >= alignof(T));
                static_assert(alignof(std::max_align_t) % alignof(T) == 0);

                // make sure alignment of offset fulfills requirement
                const auto alignment_excess = offset % alignof(T);
                offset += alignment_excess > 0 ? alignof(T) - (alignment_excess) : 0;

                const auto size_in_bytes = vec.size() * sizeof(T);

                info[i].size = size_in_bytes;
                info[i].offset = offset;

                buffer_size = offset + size_in_bytes;
                offset += size_in_bytes;
            },
            arg_tuple);
    }

    // compute maximum buffer size between ranks, such that we only allocate once
    const std::size_t max_buffer_size = distributed.max(buffer_size);

    std::tuple<util::range<typename std::remove_reference_t<decltype(args)>::value_type*>...>
        ranges;

    if (max_buffer_size == info.size() * sizeof(vec_info)) {
        // if all empty, call function with empty ranges for each step and exit
        impl::for_each_in_tuple_pair(
            [&](std::size_t i, auto&& vec, auto&& r) {
                using T = typename std::remove_reference_t<decltype(vec)>::value_type;
                r = util::range<T*>(nullptr, nullptr);
            },
            arg_tuple,
            ranges);

        for (int step = 0; step < distributed.size(); ++step) { std::apply(func, ranges); }
        return;
    }

    // use malloc for std::max_align_t alignment
    auto deleter = [](char* ptr) { std::free(ptr); };
    std::unique_ptr<char[], void (*)(char*)> buffer((char*)std::malloc(max_buffer_size), deleter);
    std::unique_ptr<char[], void (*)(char*)> recv_buffer(
        (char*)std::malloc(max_buffer_size), deleter);

    // copy offset and size info to front of buffer
    std::memcpy(buffer.get(), info.data(), info.size() * sizeof(vec_info));

    // copy each vector to each location in buffer
    impl::for_each_in_tuple(
        [&](std::size_t i, auto&& vec) {
            using T = typename std::remove_reference_t<decltype(vec)>::value_type;
            std::copy(vec.begin(), vec.end(), (T*)(buffer.get() + info[i].offset));
        },
        arg_tuple);


    const auto my_rank = distributed.id();
    const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1;
    const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1;

    // exchange buffer in ring pattern and apply function at each step
    for (int step = 0; step < distributed.size() - 1; ++step) {
        // always expect to receive the max size but send the actual size. MPI_recv only expects
        // a max size, not the actual size.
        const auto current_info = (const vec_info*)buffer.get();

        auto request = distributed.send_recv_nonblocking(max_buffer_size,
            recv_buffer.get(),
            right_rank,
            current_info[info.size() - 1].offset + current_info[info.size() - 1].size,
            buffer.get(),
            left_rank,
            0);

        // update ranges
        impl::for_each_in_tuple_pair(
            [&](std::size_t i, auto&& vec, auto&& r) {
                using T = typename std::remove_reference_t<decltype(vec)>::value_type;
                r = util::range<T*>((T*)(buffer.get() + current_info[i].offset),
                    (T*)(buffer.get() + current_info[i].offset + current_info[i].size));
            },
            arg_tuple,
            ranges);

        // call provided function with ranges pointing to current buffer
        std::apply(func, ranges);

        request.finalize();
        buffer.swap(recv_buffer);
    }

    // final step does not require any exchange
    const auto current_info = (const vec_info*)buffer.get();
    impl::for_each_in_tuple_pair(
        [&](std::size_t i, auto&& vec, auto&& r) {
            using T = typename std::remove_reference_t<decltype(vec)>::value_type;
            r = util::range<T*>((T*)(buffer.get() + current_info[i].offset),
                (T*)(buffer.get() + current_info[i].offset + current_info[i].size));
        },
        arg_tuple,
        ranges);

    // call provided function with ranges pointing to current buffer
    std::apply(func, ranges);
}

} // namespace arb
```