From cc609bf0e4c4da8a63b21c8b21f3a2b5b3577ab5 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 27 Feb 2023 19:40:04 +0100 Subject: [PATCH 01/84] replace domain index with gid in connection struct --- arbor/communication/communicator.cpp | 41 ++++++++++++++-------------- arbor/communication/communicator.hpp | 3 ++ arbor/connection.hpp | 21 ++++---------- 3 files changed, 29 insertions(+), 36 deletions(-) diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp index 70fa08ba2e..9bffc2703f 100644 --- a/arbor/communication/communicator.cpp +++ b/arbor/communication/communicator.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -39,16 +40,15 @@ void communicator::update_connections(const connectivity& rec, connections_.clear(); connection_part_.clear(); index_divisions_.clear(); + index_on_domain_.clear(); // For caching information about each cell struct gid_info { using connection_list = decltype(std::declval().connections_on(0)); cell_gid_type gid; // global identifier of cell - cell_size_type index_on_domain; // index of cell in this domain connection_list conns; // list of connections terminating at this cell gid_info() = default; // so we can in a std::vector - gid_info(cell_gid_type g, cell_size_type di, connection_list c): - gid(g), index_on_domain(di), conns(std::move(c)) {} + gid_info(cell_gid_type g, connection_list c): gid(g), conns(std::move(c)) {} }; // Make a list of local gid with their group index and connections @@ -60,22 +60,20 @@ void communicator::update_connections(const connectivity& rec, // Also the count of presynaptic sources from each domain // -> src_counts: array with one entry for each domain - // Record all the gid in a flat vector. - // These are used to map from local index to gid in the parallel loop - // that populates gid_infos. - std::vector gids; - gids.reserve(num_local_cells_); - for (auto g: dom_dec.groups()) { - util::append(gids, g.gids); + // create gid_infos and store their enumeration in a map + std::vector gid_infos; + gid_infos.reserve(num_local_cells_); + { + cell_size_type index = 0; + for (const auto& group: dom_dec.groups()) { + for (const auto& gid: group.gids) { + index_on_domain_.insert({gid, index++}); + gid_infos.emplace_back(gid, rec.connections_on(gid)); + } + } } + // Build the connection information for local cells in parallel. - std::vector gid_infos; - gid_infos.resize(num_local_cells_); - threading::parallel_for::apply(0, gids.size(), thread_pool_.get(), - [&](cell_size_type i) { - auto gid = gids[i]; - gid_infos[i] = gid_info(gid, i, rec.connections_on(gid)); - }); cell_local_size_type n_cons = util::sum_by(gid_infos, [](const gid_info& g){ return g.conns.size(); }); std::vector src_domains; @@ -103,14 +101,14 @@ void communicator::update_connections(const connectivity& rec, auto src_domain = src_domains.begin(); auto target_resolver = resolver(&target_resolution_map); for (const auto& cell: gid_infos) { - auto index = cell.index_on_domain; auto source_resolver = resolver(&source_resolution_map); for (const auto& c: cell.conns) { auto src_lid = source_resolver.resolve(c.source); auto tgt_lid = target_resolver.resolve({cell.gid, c.dest}); auto offset = offsets[*src_domain]++; ++src_domain; - connections_[offset] = {{c.source.gid, src_lid}, tgt_lid, c.weight, c.delay, index}; + connections_[offset] = { + {c.source.gid, src_lid}, {cell.gid, tgt_lid}, c.weight, c.delay}; } } @@ -182,7 +180,7 @@ void communicator::make_event_queues(const gathered_vector& global_spikes // number of spikes, and C is the number of connections. if (cons.size()index_on_domain]; + auto& queue = queues[index_on_domain_.at(cn->destination.gid)]; auto src = cn->source; auto sources = std::equal_range(sp, se, src, spike_pred()); for (auto s: util::make_range(sources)) queue.push_back(cn->make_event(s)); @@ -193,7 +191,8 @@ void communicator::make_event_queues(const gathered_vector& global_spikes else { while (cn!=ce && sp!=se) { auto targets = std::equal_range(cn, ce, sp->source); - for (auto c: util::make_range(targets)) queues[c.index_on_domain].push_back(c.make_event(*sp)); + for (auto c: util::make_range(targets)) + queues[index_on_domain_.at(c.destination.gid)].push_back(c.make_event(*sp)); cn = targets.first; ++sp; } diff --git a/arbor/communication/communicator.hpp b/arbor/communication/communicator.hpp index c9eab08294..850ce0f119 100644 --- a/arbor/communication/communicator.hpp +++ b/arbor/communication/communicator.hpp @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include @@ -82,6 +84,7 @@ class ARB_ARBOR_API communicator { std::vector connection_part_; std::vector index_divisions_; util::partition_view_type> index_part_; + std::unordered_map index_on_domain_; distributed_context_handle distributed_; task_system_handle thread_pool_; diff --git a/arbor/connection.hpp b/arbor/connection.hpp index 7fab0162bf..9ea5cc8623 100644 --- a/arbor/connection.hpp +++ b/arbor/connection.hpp @@ -10,25 +10,18 @@ namespace arb { class connection { public: connection() = default; - connection(cell_member_type src, - cell_lid_type dest, - float w, - float d, - cell_gid_type didx=cell_gid_type(-1)): + connection(cell_member_type src, cell_member_type dest, float w, float d): source(src), destination(dest), weight(w), - delay(d), - index_on_domain(didx) - {} + delay(d) {} - spike_event make_event(const spike& s) { return { destination, s.time + delay, weight}; } + spike_event make_event(const spike& s) { return {destination.index, s.time + delay, weight}; } cell_member_type source; - cell_lid_type destination; + cell_member_type destination; float weight; float delay; - cell_size_type index_on_domain; }; // connections are sorted by source id @@ -40,8 +33,6 @@ static inline bool operator<(cell_member_type lhs, const connection& rhs) { ret } // namespace arb static inline std::ostream& operator<<(std::ostream& o, arb::connection const& con) { - return o << "con [" << con.source << " -> " << con.destination - << " : weight " << con.weight - << ", delay " << con.delay - << ", index " << con.index_on_domain << "]"; + return o << "con [" << con.source << " -> " << con.destination << " : weight " << con.weight + << ", delay " << con.delay << "]"; } From dac34adf72013271b053d1f1d9512656603d804a Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 3 Mar 2023 17:45:06 +0100 Subject: [PATCH 02/84] wip --- arbor/include/arbor/math.hpp | 13 ++ arbor/include/arbor/network.hpp | 252 ++++++++++++++++++++++++++++++++ arbor/include/arbor/recipe.hpp | 5 +- arbor/label_resolution.cpp | 13 ++ arbor/label_resolution.hpp | 2 + arbor/network_generation.cpp | 158 ++++++++++++++++++++ arbor/network_generation.hpp | 23 +++ arbor/util/spatial_tree.hpp | 174 ++++++++++++++++++++++ 8 files changed, 639 insertions(+), 1 deletion(-) create mode 100644 arbor/include/arbor/network.hpp create mode 100644 arbor/network_generation.cpp create mode 100644 arbor/network_generation.hpp create mode 100644 arbor/util/spatial_tree.hpp diff --git a/arbor/include/arbor/math.hpp b/arbor/include/arbor/math.hpp index 17e12b3c2c..0e10a303bc 100644 --- a/arbor/include/arbor/math.hpp +++ b/arbor/include/arbor/math.hpp @@ -32,6 +32,19 @@ T constexpr area_circle(T r) { return pi * square(r); } +template >> +T constexpr pow(T base, U exp) { + if (exp == 0) return 1; + + const U exp_half = exp / 2; + if (2 * exp_half == exp) { + const auto r = ::arb::math::pow(base, exp_half); + return r * r; + } + + return base * ::arb::math::pow(base, exp - 1); +} + // Surface area of conic frustrum excluding the discs at each end, // with length L, end radii r1, r2. template diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp new file mode 100644 index 0000000000..504842a27c --- /dev/null +++ b/arbor/include/arbor/network.hpp @@ -0,0 +1,252 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace arb { + +using network_location = std::array; + +using network_hash_type = std::uint64_t; + +class network_site_selection { +public: + static network_site_selection all(); + + static network_site_selection none(); + + static network_site_selection has_cell_kind(cell_kind kind); + + static network_site_selection has_tag(std::vector tags); + + static network_site_selection has_gid(std::vector gids); + + static network_site_selection has_gid_in_range(cell_gid_type begin, cell_gid_type end); + + static network_site_selection invert(network_site_selection s); + + network_site_selection operator&(network_site_selection right) const; + + network_site_selection operator|(network_site_selection right) const; + + network_site_selection operator^(network_site_selection right) const; + + bool operator()(cell_gid_type gid, + cell_kind kind, + const cell_tag_type& tag, + const mlocation& local_loc, + const network_location& global_loc) const; + +private: +}; + +class ARB_SYMBOL_VISIBLE network_connection_selection { +public: + // Random selection using the bernoulli random distribution with probability "p" between 0.0 + // and 1.0 + static network_connection_selection bernoulli_random(unsigned seed, double p); + + // Custom selection using the provided function "func". Repeated calls with the same arguments + // to "func" must yield the same result. For gap junction selection, + // "func" must be symmetric (func(a,b) = func(b,a)). + static network_connection_selection custom(std::function< + bool(cell_gid_type, const network_location&, cell_gid_type, const network_location&)> func); + + // Select all + static network_connection_selection all(); + + // Select none + static network_connection_selection none(); + + // Invert the selection + static network_connection_selection invert(network_connection_selection s); + + // Only select connections between different cells + static network_connection_selection inter_cell(); + + // Only select connections when the global labels are not equal. May select intra-cell + // connections, if the local label is not equal. + static network_connection_selection not_equal(); + + // only select within given distance. This may enable more efficient sampling through an + // internal spatial data structure. + static network_connection_selection within_distance(double distance); + + // random bernoulli sampling with a linear interpolated probabilty based on distance. Returns + // "false" for any distance outside of the interval [distance_begin, distance_end]. + static network_connection_selection linear_bernoulli_random(unsigned seed, + double distance_begin, + double p_begin, + double distance_end, + double p_end); + + // Returns true if a connection between src and dest is selected. + bool operator()(const cell_global_label_type& src, + const network_location& src_location, + const cell_global_label_type& dest, + const network_location& dest_location) const; + + network_connection_selection operator&(network_connection_selection right) const; + + network_connection_selection operator|(network_connection_selection right) const; + + network_connection_selection operator^(network_connection_selection right) const; + + // Returns true if a connection between src and dest is selected. + inline bool operator()(cell_gid_type src_gid, + const network_location& global_src_location, + network_hash_type src_hash, + cell_gid_type dest_gid, + const network_location& global_dest_location, + network_hash_type dest_hash) const { + return impl_->select( + src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); + } + + inline std::optional max_distance() const { return impl_->max_distance(); } + +private: + struct selection_impl { + virtual bool select(cell_gid_type src_gid, + const network_location& global_src_location, + network_hash_type src_hash, + cell_gid_type dest_gid, + const network_location& global_dest_location, + network_hash_type dest_hash) const = 0; + + virtual std::optional max_distance() const { return std::nullopt; } + + virtual ~selection_impl() = default; + }; + + struct bernoulli_random_impl; + struct custom_impl; + struct inter_cell_impl; + struct not_equal_impl; + struct all_impl; + struct none_impl; + struct and_impl; + struct or_impl; + struct xor_impl; + struct invert_impl; + struct within_distance_impl; + struct linear_bernoulli_random_impl; + + network_connection_selection(std::shared_ptr impl); + + inline bool select(cell_gid_type src_gid, + const network_location& global_src_location, + network_hash_type src_hash, + cell_gid_type dest_gid, + const network_location& global_dest_location, + network_hash_type dest_hash) const { + return impl_->select( + src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); + } + + std::shared_ptr impl_; +}; + +class ARB_SYMBOL_VISIBLE network_value { +public: + // Uniform value + network_value(double value); + + // Uniform value. Will always return the same value given at construction. + static network_value uniform(double value); + + // Uniform random value in (range[0], range[1]]. Always returns the same value for repeated + // calls with the same arguments and calls are symmetric v(a, b) = v(b, a). + static network_value uniform_distribution(unsigned seed, const std::array& range); + + // Radom value from a normal distribution with given mean and standard deviation. Always returns + // the same value for repeated calls with the same arguments and calls are symmetric v(a, b) = + // v(b, a). + static network_value normal_distribution(unsigned seed, double mean, double std_deviation); + + // Radom value from a truncated normal distribution with given mean and standard deviation (of a + // non-truncated normal distribution), where the value is always in (range[0], range[1]]. Always + // returns the same value for repeated calls with the same arguments and calls are symmetric + // v(a, b) = v(b, a). Note: Values are generated by reject-accept method from a normal + // distribution. Low acceptance rate can leed to poor performance, for example with very small + // ranges or a mean far outside the range. + static network_value truncated_normal_distribution(unsigned seed, + double mean, + double std_deviation, + const std::array& range); + + // Custom value using the provided function "func". Repeated calls with the same arguments + // to "func" must yield the same result. For gap junction values, + // "func" must be symmetric (func(a,b) = func(b,a)). + static network_value custom(std::function func); + + inline double operator()(cell_gid_type src_gid, + const network_location& global_src_location, + network_hash_type src_hash, + cell_gid_type dest_gid, + const network_location& global_dest_location, + network_hash_type dest_hash) const { + return impl_->get( + src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); + } + +private: + + struct value_impl { + virtual double get(cell_gid_type src_gid, + const network_location& global_src_location, + network_hash_type src_hash, + cell_gid_type dest_gid, + const network_location& global_dest_location, + network_hash_type dest_hash) const = 0; + + virtual ~value_impl() = default; + }; + + struct uniform_distribution_impl; + struct normal_distribution_impl; + struct truncated_normal_distribution_impl; + struct custom_impl; + struct uniform_impl; + + network_value(std::shared_ptr impl); + + inline double get(cell_gid_type src_gid, + const network_location& global_src_location, + network_hash_type src_hash, + cell_gid_type dest_gid, + const network_location& global_dest_location, + network_hash_type dest_hash) const { + return impl_->get( + src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); + } + + std::shared_ptr impl_; +}; + +struct network_description { + network_site_selection src_selection; + network_site_selection dest_selection; + network_connection_selection connection_selection; + network_value weight; + network_value delay; +}; + +} // namespace arb diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index 1f116b970e..15ac14cd7f 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -4,9 +4,10 @@ #include #include -#include #include #include +#include +#include #include namespace arb { @@ -104,6 +105,8 @@ struct ARB_ARBOR_API recipe: public has_gap_junctions, has_probes, connectivity virtual cell_kind get_cell_kind(cell_gid_type) const = 0; // Global property type will be specific to given cell kind. virtual std::any get_global_properties(cell_kind) const { return std::any{}; }; + // Optional network descriptions for generating cell connections + virtual std::vector network_descriptions() const { return {}; }; virtual ~recipe() {} }; diff --git a/arbor/label_resolution.cpp b/arbor/label_resolution.cpp index 8d5ddadbb7..7b126878cd 100644 --- a/arbor/label_resolution.cpp +++ b/arbor/label_resolution.cpp @@ -90,6 +90,19 @@ std::size_t label_resolution_map::count(const cell_gid_type& gid, const cell_tag return map.at(gid).count(tag); } +const cell_tag_type& label_resolution_map::tag_at(const cell_gid_type& gid, + const cell_lid_type& lid) const { + for(const auto& [tag, r_set] : map.at(gid)) { + for(const auto& range: r_set.ranges) { + if(lid>= range.begin && lid < range.end) { + return tag; + } + } + } + throw arbor_internal_error("gid and lid mismatch"); + return map.begin()->second.begin()->first; +} + label_resolution_map::label_resolution_map(const cell_labels_and_gids& clg) { arb_assert(clg.label_range.check_invariant()); const auto& gids = clg.gids; diff --git a/arbor/label_resolution.hpp b/arbor/label_resolution.hpp index dbfe2014a9..3373ea54ba 100644 --- a/arbor/label_resolution.hpp +++ b/arbor/label_resolution.hpp @@ -83,6 +83,8 @@ class ARB_ARBOR_API label_resolution_map { const range_set& at(const cell_gid_type& gid, const cell_tag_type& tag) const; std::size_t count(const cell_gid_type& gid, const cell_tag_type& tag) const; + const cell_tag_type& tag_at(const cell_gid_type& gid, const cell_lid_type& lid) const; + private: std::unordered_map> map; }; diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp new file mode 100644 index 0000000000..68458989ac --- /dev/null +++ b/arbor/network_generation.cpp @@ -0,0 +1,158 @@ +#include "network_generation.hpp" +#include "util/spatial_tree.hpp" + +#include +#include +#include +#include +#include + +namespace arb { + +namespace { +struct dest_site_info { + cell_gid_type gid; + cell_lid_type lid; + network_hash_type hash; +}; + +struct src_site_info { + cell_gid_type gid; + cell_lid_type lid; + double x, y, z; + network_hash_type hash; +}; +} // namespace + +std::vector generate_network_connections( + const std::vector& descriptions, + const recipe& rec, + const distributed_context& distributed, + const domain_decomposition& dom_dec, + const label_resolution_map& source_resolution_map, + const label_resolution_map& target_resolution_map) { + if (descriptions.empty()) return {}; + + std::vector> local_src_sites(descriptions.size()); + std::vector>> local_dest_sites( + descriptions.size()); + + // populate network sites for source and destination + for (const auto& group: dom_dec.groups()) { + switch (group.kind) { + case cell_kind::cable: { + cable_cell cell; + for (const auto& gid: group.gids) { + try { + cell = util::any_cast(rec.get_cell_description(gid)); + } + catch (std::bad_any_cast&) { + throw bad_cell_description(rec.get_cell_kind(gid), gid); + } + + place_pwlin location_resolver(cell.morphology()); + + // check all synapses of cell for potential destination + for (const auto& [name, placed_synapses]: cell.synapses()) { + for (const auto& p_syn: placed_synapses) { + // TODO: compute rotation and global offset + const mpoint point = location_resolver.at(p_syn.loc); + network_location location = {point.x, point.y, point.z}; + // TODO check if tag correct + const auto& tag = target_resolution_map.tag_at(gid, p_syn.lid); + + for (std::size_t i = 0; i < descriptions.size(); ++i) { + const auto& desc = descriptions[i]; + if (desc.dest_selection( + gid, cell_kind::cable, tag, p_syn.loc, location)) { + // TODO : compute hash + network_hash_type hash = 0; + local_dest_sites[i].push_back({location, {gid, p_syn.lid, hash}}); + } + } + } + } + + // check all detectors of cell for potential source + for (const auto& p_det: cell.detectors()) { + // TODO: compute rotation and global offset + const mpoint point = location_resolver.at(p_det.loc); + network_location location = {point.x, point.y, point.z}; + // TODO check if tag correct + const auto& tag = target_resolution_map.tag_at(gid, p_det.lid); + + for (std::size_t i = 0; i < descriptions.size(); ++i) { + const auto& desc = descriptions[i]; + if (desc.src_selection(gid, cell_kind::cable, tag, p_det.loc, location)) { + // TODO : compute hash + network_hash_type hash = 0; + local_src_sites[i].push_back( + {gid, p_det.lid, location[0], location[1], location[2], hash}); + } + } + } + } + } break; + case cell_kind::lif: { + // TODO + for (const auto& gid: group.gids) {} + } break; + case cell_kind::benchmark: { + // TODO + for (const auto& gid: group.gids) {} + } break; + case cell_kind::spike_source: { + // TODO + for (const auto& gid: group.gids) {} + } break; + } + } + + // create octrees + std::vector> local_dest_trees; + local_dest_trees.reserve(descriptions.size()); + for (std::size_t i = 0; i < descriptions.size(); ++i) { + const auto& desc = descriptions[i]; + const std::size_t max_depth = desc.connection_selection.max_distance().has_value() ? 10 : 1; + local_dest_trees.emplace_back(max_depth, 100, std::move(local_dest_sites[i])); + } + + // select connections + std::vector connections; + + for (std::size_t i = 0; i < descriptions.size(); ++i) { + const auto& desc = descriptions[i]; + const auto& src_sites = local_src_sites[i]; + const auto& dest_tree = local_dest_trees[i]; + + for (const auto& src: src_sites) { + auto sample_dest = [&](const network_location& dest_loc, const dest_site_info& dest) { + // TODO precompute distance + if (desc.connection_selection( + src.gid, {src.x, src.y, src.z}, src.hash, dest.gid, dest_loc, dest.hash)) { + const double w = desc.weight( + src.gid, {src.x, src.y, src.z}, src.hash, dest.gid, dest_loc, dest.hash); + const double d = desc.delay( + src.gid, {src.x, src.y, src.z}, src.hash, dest.gid, dest_loc, dest.hash); + + connections.emplace_back(cell_member_type{src.gid, src.lid}, + cell_member_type{dest.gid, dest.lid}, + w, + d); + } + }; + + if(desc.connection_selection.max_distance().has_value()) { + const double d = desc.connection_selection.max_distance().value(); + dest_tree.bounding_box_for_each(network_location{src.x - d, src.y - d, src.z - d}, + network_location{src.x + d, src.y + d, src.z + d}, + sample_dest); + } + else { dest_tree.for_each(sample_dest); } + } + } + + return connections; +} + +} // namespace arb diff --git a/arbor/network_generation.hpp b/arbor/network_generation.hpp new file mode 100644 index 0000000000..37a9fa56f8 --- /dev/null +++ b/arbor/network_generation.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include + +#include +#include +#include + +#include "connection.hpp" +#include "distributed_context.hpp" +#include "label_resolution.hpp" + +namespace arb { + +std::vector generate_network_connections( + const std::vector& descriptions, + const connectivity& rec, + const distributed_context& distributed, + const domain_decomposition& dom_dec, + const label_resolution_map& source_resolution_map, + const label_resolution_map& target_resolution_map); + +} // namespace arb diff --git a/arbor/util/spatial_tree.hpp b/arbor/util/spatial_tree.hpp new file mode 100644 index 0000000000..6a2f29ac88 --- /dev/null +++ b/arbor/util/spatial_tree.hpp @@ -0,0 +1,174 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace arb { + +// An immutable spatial data structure for storing and iterating over data in "DIM" dimensional +// space. If DIM = 1 it's a binary tree, if DIM = 2 it's a quad tree and so on. +template +class spatial_tree { +public: + static_assert(DIM >= 1, "Dimension of tree must be at least 1."); + + using value_type = T; + using point_type = std::array; + using node_data = std::vector; + using leaf_data = std::vector>; + + spatial_tree(): size_(0), data_(leaf_data()) {} + + // Create a tree of given maximum depth and target leaf size. If any leaf holds more than the + // target size, it is recursively split into up to 2^DIM nodes until reaching the maximum depth. + spatial_tree(std::size_t max_depth, std::size_t leaf_size_target, leaf_data data): + size_(data.size()), + data_(std::move(data)) { + auto &leaf_d = std::get(data_); + if (leaf_d.empty()) return; + + min_.fill(std::numeric_limits::max()); + max_.fill(-std::numeric_limits::max()); + + for (const auto &[p, _]: leaf_d) { + for (std::size_t i = 0; i < DIM; ++i) { + if (p[i] < min_[i]) min_[i] = p[i]; + if (p[i] > max_[i]) max_[i] = p[i]; + } + } + + value_type mid; + for (std::size_t i = 0; i < DIM; ++i) { mid[i] = (max_[i] - min_[i]) / 2.0 + min_[i]; } + + if (max_depth > 1 && leaf_d.size() > leaf_size_target) { + constexpr auto divisor = math::pow(2, DIM); + + // The initial index of the sub node containing p + auto sub_node_index = [&](const point_type &p) { + std::size_t index = 0; + for (std::size_t i = 0; i < DIM; ++i) { index += i * 2 * (p[i] >= mid[i]); } + return index; + }; + + node_data new_nodes; + new_nodes.reserve(divisor); + + // assign each point to sub-node + std::array new_leaf_data; + for (const auto &[p, d]: leaf_d) { + new_leaf_data[sub_node_index(p)].emplace_back(p, d); + } + + // move data into new sub-nodes if not empty + for (auto &l_d: new_leaf_data) { + if (l_d.size()) + new_nodes.emplace_back(max_depth - 1, leaf_size_target, std::move(l_d)); + } + + // replace current data_ with new sub-nodes + this->data_ = std::move(new_nodes); + } + } + + spatial_tree(const spatial_tree &) = default; + + spatial_tree(spatial_tree &&t) { *this = std::move(t); } + + spatial_tree &operator=(const spatial_tree &) = default; + + spatial_tree &operator=(spatial_tree &&t) { + data_ = std::move(t.data_); + size_ = t.size_; + min_ = t.min_; + max_ = t.max_; + + t.data_ = leaf_data(); + t.size_ = 0; + t.min_ = point_type(); + t.max_ = point_type(); + + return *this; + } + + // Iterate over all points recursively. + // func must have signature `void func(const point_type&, const T&)`. + template + inline void for_each(const F &func) const { + std::visit( + [&](auto &&arg) { + using arg_type = std::decay_t; + if constexpr (std::is_same_v) { + for (const auto &node: arg) { node.for_each(func); } + } + if constexpr (std::is_same_v) { + for (const auto &[p, d]: arg) { func(p, d); } + } + }, + data_); + } + + // Iterate over all points within the given bounding box recursively. + // func must have signature `void func(const point_type&, const T&)`. + template + inline void bounding_box_for_each(const point_type &box_min, + const point_type &box_max, + const F &func) const { + auto all_smaller_eq = [](const point_type &lhs, const point_type &rhs) { + bool result = true; + for (std::size_t i = 0; i < DIM; ++i) { result &= lhs[i] <= rhs[i]; } + return result; + }; + + std::visit( + [&](auto &&arg) { + using arg_type = std::decay_t; + + if (all_smaller_eq(box_min, min_) && all_smaller_eq(max_, box_max)) { + // sub-nodes fully inside box -> call without further boundary + // checks + if constexpr (std::is_same_v) { + for (const auto &node: arg) { node.template for_each(func); } + } + if constexpr (std::is_same_v) { + for (const auto &[p, d]: arg) { func(p, d); } + } + } + else { + // sub-nodes partially overlap bounding box + if constexpr (std::is_same_v) { + for (const auto &node: arg) { + if (all_smaller_eq(node.min_, box_max) && + all_smaller_eq(box_min, node.max_)) + node.template bounding_box_for_each(box_min, box_max, func); + } + } + if constexpr (std::is_same_v) { + for (const auto &[p, d]: arg) { + if (all_smaller_eq(p, box_max) && all_smaller_eq(box_min, p)) { + func(p, d); + } + } + } + } + }, + data_); + } + + inline std::size_t size() const { return size_; } + + inline bool empty() const { return !size_; } + +private: + std::size_t size_; + point_type min_, max_; + std::variant data_; +}; + +} // namespace arb From 76109ba16321c9dea66d0352747bf29ca575af1f Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 6 Mar 2023 15:03:01 +0100 Subject: [PATCH 03/84] reset label resolution --- arbor/label_resolution.cpp | 13 ------------- arbor/label_resolution.hpp | 2 -- 2 files changed, 15 deletions(-) diff --git a/arbor/label_resolution.cpp b/arbor/label_resolution.cpp index 7b126878cd..8d5ddadbb7 100644 --- a/arbor/label_resolution.cpp +++ b/arbor/label_resolution.cpp @@ -90,19 +90,6 @@ std::size_t label_resolution_map::count(const cell_gid_type& gid, const cell_tag return map.at(gid).count(tag); } -const cell_tag_type& label_resolution_map::tag_at(const cell_gid_type& gid, - const cell_lid_type& lid) const { - for(const auto& [tag, r_set] : map.at(gid)) { - for(const auto& range: r_set.ranges) { - if(lid>= range.begin && lid < range.end) { - return tag; - } - } - } - throw arbor_internal_error("gid and lid mismatch"); - return map.begin()->second.begin()->first; -} - label_resolution_map::label_resolution_map(const cell_labels_and_gids& clg) { arb_assert(clg.label_range.check_invariant()); const auto& gids = clg.gids; diff --git a/arbor/label_resolution.hpp b/arbor/label_resolution.hpp index 3373ea54ba..dbfe2014a9 100644 --- a/arbor/label_resolution.hpp +++ b/arbor/label_resolution.hpp @@ -83,8 +83,6 @@ class ARB_ARBOR_API label_resolution_map { const range_set& at(const cell_gid_type& gid, const cell_tag_type& tag) const; std::size_t count(const cell_gid_type& gid, const cell_tag_type& tag) const; - const cell_tag_type& tag_at(const cell_gid_type& gid, const cell_lid_type& lid) const; - private: std::unordered_map> map; }; From 27798a1dcba30ec905fdcea8e997b6b2e4f1e88c Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 6 Mar 2023 18:55:47 +0100 Subject: [PATCH 04/84] network_impl --- arbor/CMakeLists.txt | 1 + arbor/include/arbor/network.hpp | 195 +++------ arbor/network.cpp | 710 ++++++++++++++++++++++++++++++++ arbor/network_impl.hpp | 62 +++ 4 files changed, 837 insertions(+), 131 deletions(-) create mode 100644 arbor/network.cpp create mode 100644 arbor/network_impl.hpp diff --git a/arbor/CMakeLists.txt b/arbor/CMakeLists.txt index b58ebff44c..948205b2ef 100644 --- a/arbor/CMakeLists.txt +++ b/arbor/CMakeLists.txt @@ -45,6 +45,7 @@ set(arbor_sources morph/segment_tree.cpp morph/stitch.cpp merge_events.cpp + network.cpp simulation.cpp partition_load_balance.cpp profile/clock.cpp diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 504842a27c..8d747609cb 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -1,9 +1,8 @@ #pragma once +#include #include #include -#include -#include #include #include @@ -22,142 +21,78 @@ using network_location = std::array; using network_hash_type = std::uint64_t; -class network_site_selection { +struct network_site_info { + cell_gid_type gid; + cell_kind kind; + const cell_tag_type& tag; + const network_location& location; + network_hash_type hash; +}; + +struct network_selection_impl; + +class ARB_SYMBOL_VISIBLE network_selection { public: - static network_site_selection all(); + network_selection() { *this = network_selection::all(); } - static network_site_selection none(); + // Select all + static network_selection all(); - static network_site_selection has_cell_kind(cell_kind kind); + // Select none + static network_selection none(); - static network_site_selection has_tag(std::vector tags); + static network_selection source_cell_kind(cell_kind kind); - static network_site_selection has_gid(std::vector gids); + static network_selection destination_cell_kind(cell_kind kind); - static network_site_selection has_gid_in_range(cell_gid_type begin, cell_gid_type end); + static network_selection source_label(std::vector labels); - static network_site_selection invert(network_site_selection s); + static network_selection destination_label(std::vector labels); - network_site_selection operator&(network_site_selection right) const; + static network_selection source_gid(std::vector gids); - network_site_selection operator|(network_site_selection right) const; + static network_selection destination_gid(std::vector gids); - network_site_selection operator^(network_site_selection right) const; + // Invert the selection + static network_selection invert(network_selection s); - bool operator()(cell_gid_type gid, - cell_kind kind, - const cell_tag_type& tag, - const mlocation& local_loc, - const network_location& global_loc) const; + // Only select connections between different cells + static network_selection inter_cell(); -private: -}; + // Only select connections when the global labels are not equal. May select intra-cell + // connections, if the local label is not equal. + static network_selection not_equal(); -class ARB_SYMBOL_VISIBLE network_connection_selection { -public: // Random selection using the bernoulli random distribution with probability "p" between 0.0 // and 1.0 - static network_connection_selection bernoulli_random(unsigned seed, double p); + static network_selection bernoulli_random(unsigned seed, double p); // Custom selection using the provided function "func". Repeated calls with the same arguments // to "func" must yield the same result. For gap junction selection, // "func" must be symmetric (func(a,b) = func(b,a)). - static network_connection_selection custom(std::function< - bool(cell_gid_type, const network_location&, cell_gid_type, const network_location&)> func); - - // Select all - static network_connection_selection all(); - - // Select none - static network_connection_selection none(); - - // Invert the selection - static network_connection_selection invert(network_connection_selection s); - - // Only select connections between different cells - static network_connection_selection inter_cell(); - - // Only select connections when the global labels are not equal. May select intra-cell - // connections, if the local label is not equal. - static network_connection_selection not_equal(); + static network_selection custom( + std::function func); // only select within given distance. This may enable more efficient sampling through an // internal spatial data structure. - static network_connection_selection within_distance(double distance); + static network_selection within_distance(double distance); // random bernoulli sampling with a linear interpolated probabilty based on distance. Returns // "false" for any distance outside of the interval [distance_begin, distance_end]. - static network_connection_selection linear_bernoulli_random(unsigned seed, + static network_selection linear_bernoulli_random(unsigned seed, double distance_begin, double p_begin, double distance_end, double p_end); - // Returns true if a connection between src and dest is selected. - bool operator()(const cell_global_label_type& src, - const network_location& src_location, - const cell_global_label_type& dest, - const network_location& dest_location) const; + network_selection operator&(network_selection right) const; - network_connection_selection operator&(network_connection_selection right) const; + network_selection operator|(network_selection right) const; - network_connection_selection operator|(network_connection_selection right) const; - - network_connection_selection operator^(network_connection_selection right) const; - - // Returns true if a connection between src and dest is selected. - inline bool operator()(cell_gid_type src_gid, - const network_location& global_src_location, - network_hash_type src_hash, - cell_gid_type dest_gid, - const network_location& global_dest_location, - network_hash_type dest_hash) const { - return impl_->select( - src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); - } - - inline std::optional max_distance() const { return impl_->max_distance(); } + network_selection operator^(network_selection right) const; private: - struct selection_impl { - virtual bool select(cell_gid_type src_gid, - const network_location& global_src_location, - network_hash_type src_hash, - cell_gid_type dest_gid, - const network_location& global_dest_location, - network_hash_type dest_hash) const = 0; - - virtual std::optional max_distance() const { return std::nullopt; } - - virtual ~selection_impl() = default; - }; - - struct bernoulli_random_impl; - struct custom_impl; - struct inter_cell_impl; - struct not_equal_impl; - struct all_impl; - struct none_impl; - struct and_impl; - struct or_impl; - struct xor_impl; - struct invert_impl; - struct within_distance_impl; - struct linear_bernoulli_random_impl; - - network_connection_selection(std::shared_ptr impl); - - inline bool select(cell_gid_type src_gid, - const network_location& global_src_location, - network_hash_type src_hash, - cell_gid_type dest_gid, - const network_location& global_dest_location, - network_hash_type dest_hash) const { - return impl_->select( - src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); - } - - std::shared_ptr impl_; + std::shared_ptr impl_; }; class ARB_SYMBOL_VISIBLE network_value { @@ -197,25 +132,25 @@ class ARB_SYMBOL_VISIBLE network_value { const network_location&, double)> func); - inline double operator()(cell_gid_type src_gid, - const network_location& global_src_location, - network_hash_type src_hash, - cell_gid_type dest_gid, - const network_location& global_dest_location, - network_hash_type dest_hash) const { - return impl_->get( - src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); - } + // inline double operator()(cell_gid_type src_gid, + // const network_location& global_src_location, + // network_hash_type src_hash, + // cell_gid_type dest_gid, + // const network_location& global_dest_location, + // network_hash_type dest_hash) const { + // return impl_->get( + // src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); + // } private: struct value_impl { - virtual double get(cell_gid_type src_gid, - const network_location& global_src_location, - network_hash_type src_hash, - cell_gid_type dest_gid, - const network_location& global_dest_location, - network_hash_type dest_hash) const = 0; + // virtual double get(cell_gid_type src_gid, + // const network_location& global_src_location, + // network_hash_type src_hash, + // cell_gid_type dest_gid, + // const network_location& global_dest_location, + // network_hash_type dest_hash) const = 0; virtual ~value_impl() = default; }; @@ -228,23 +163,21 @@ class ARB_SYMBOL_VISIBLE network_value { network_value(std::shared_ptr impl); - inline double get(cell_gid_type src_gid, - const network_location& global_src_location, - network_hash_type src_hash, - cell_gid_type dest_gid, - const network_location& global_dest_location, - network_hash_type dest_hash) const { - return impl_->get( - src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); - } + // inline double get(cell_gid_type src_gid, + // const network_location& global_src_location, + // network_hash_type src_hash, + // cell_gid_type dest_gid, + // const network_location& global_dest_location, + // network_hash_type dest_hash) const { + // return impl_->get( + // src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); + // } std::shared_ptr impl_; }; struct network_description { - network_site_selection src_selection; - network_site_selection dest_selection; - network_connection_selection connection_selection; + network_selection selection; network_value weight; network_value delay; }; diff --git a/arbor/network.cpp b/arbor/network.cpp new file mode 100644 index 0000000000..186deb68ba --- /dev/null +++ b/arbor/network.cpp @@ -0,0 +1,710 @@ +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "network_impl.hpp" + +namespace arb { + +namespace { + +// Partial seed to use for network_value and network_selection generation. +// Different seed for each type to avoid unintentional correlation. +enum class network_seed : unsigned { + selection_bernoulli = 2058443, + spatial_selection_bernoulli = 839033, + value_uniform = 48202, + value_normal = 8405, + value_truncated_normal = 380237 +}; + +double uniform_rand_from_key_pair(std::array seed, + network_hash_type key_a, + network_hash_type key_b) { + using rand_type = r123::Threefry2x64; + const rand_type::ctr_type seed_input = {{seed[0], seed[1]}}; + + const rand_type::key_type key = {{std::min(key_a, key_b), std::max(key_a, key_b)}}; + rand_type gen; + return r123::u01(gen(seed_input, key)[0]); +} + +double network_location_distance(const network_location& a, const network_location& b) { + return std::sqrt(a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); +} + +template +struct network_selection_crtp: public network_selection_impl { + bool select_source(cell_gid_type gid, + const cable_cell& cell, + const cell_tag_type& tag) const override { + return static_cast(this)->select_source_impl(gid, cell, tag); + } + + bool select_destination(cell_gid_type gid, + const cable_cell& cell, + const cell_tag_type& tag) const override { + return static_cast(this)->select_destination_impl(gid, cell, tag); + } + + bool select_source(cell_gid_type gid, + const lif_cell& cell, + const cell_tag_type& tag) const override { + return static_cast(this)->select_source_impl(gid, cell, tag); + } + + bool select_destination(cell_gid_type gid, + const lif_cell& cell, + const cell_tag_type& tag) const override { + return static_cast(this)->select_destination_impl(gid, cell, tag); + } + + bool select_source(cell_gid_type gid, + const spike_source_cell& cell, + const cell_tag_type& tag) const override { + return static_cast(this)->select_source_impl(gid, cell, tag); + } + + bool select_destination(cell_gid_type gid, + const spike_source_cell& cell, + const cell_tag_type& tag) const override { + return static_cast(this)->select_destination_impl(gid, cell, tag); + } + + bool select_source(cell_gid_type gid, + const benchmark_cell& cell, + const cell_tag_type& tag) const override { + return static_cast(this)->select_source_impl(gid, cell, tag); + } + + bool select_destination(cell_gid_type gid, + const benchmark_cell& cell, + const cell_tag_type& tag) const override { + return static_cast(this)->select_destination_impl(gid, cell, tag); + } +}; + +struct network_selection_all_impl: public network_selection_crtp { + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return true; + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } +}; + +struct network_selection_none_impl: + public network_selection_crtp { + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return false; + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return false; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return false; + } +}; + +struct network_selection_source_cell_kind_impl: public network_selection_impl { + cell_kind kind; + + explicit network_selection_source_cell_kind_impl(cell_kind k): kind(k) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return src.kind == kind; + } + + bool select_source(cell_gid_type gid, + const cable_cell& cell, + const cell_tag_type& tag) const override { + return kind == cell_kind::cable; + } + + bool select_source(cell_gid_type gid, + const lif_cell& cell, + const cell_tag_type& tag) const override { + return kind == cell_kind::lif; + } + + bool select_source(cell_gid_type gid, + const benchmark_cell& cell, + const cell_tag_type& tag) const override { + return kind == cell_kind::benchmark; + } + + bool select_source(cell_gid_type gid, + const spike_source_cell& cell, + const cell_tag_type& tag) const override { + return kind == cell_kind::spike_source; + } + + bool select_destination(cell_gid_type gid, + const cable_cell& cell, + const cell_tag_type& tag) const override { + return true; + } + + bool select_destination(cell_gid_type gid, + const lif_cell& cell, + const cell_tag_type& tag) const override { + return true; + } + + bool select_destination(cell_gid_type gid, + const benchmark_cell& cell, + const cell_tag_type& tag) const override { + return true; + } + + bool select_destination(cell_gid_type gid, + const spike_source_cell& cell, + const cell_tag_type& tag) const override { + return true; + } +}; + +struct network_selection_destination_cell_kind_impl: + public network_selection_impl { + cell_kind kind; + + explicit network_selection_destination_cell_kind_impl(cell_kind k): kind(k) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return src.kind == kind; + } + + bool select_source(cell_gid_type gid, + const cable_cell& cell, + const cell_tag_type& tag) const override { + return true; + } + + bool select_source(cell_gid_type gid, + const lif_cell& cell, + const cell_tag_type& tag) const override { + return true; + } + + bool select_source(cell_gid_type gid, + const benchmark_cell& cell, + const cell_tag_type& tag) const override { + return true; + } + + bool select_source(cell_gid_type gid, + const spike_source_cell& cell, + const cell_tag_type& tag) const override { + return true; + } + + bool select_destination(cell_gid_type gid, + const cable_cell& cell, + const cell_tag_type& tag) const override { + return kind == cell_kind::cable; + } + + bool select_destination(cell_gid_type gid, + const lif_cell& cell, + const cell_tag_type& tag) const override { + return kind == cell_kind::lif; + } + + bool select_destination(cell_gid_type gid, + const benchmark_cell& cell, + const cell_tag_type& tag) const override { + return kind == cell_kind::benchmark; + } + + bool select_destination(cell_gid_type gid, + const spike_source_cell& cell, + const cell_tag_type& tag) const override { + return kind == cell_kind::spike_source; + } +}; + + +struct network_selection_source_label_impl: + public network_selection_crtp { + std::vector sorted_labels; + + explicit network_selection_source_label_impl(std::vector labels): + sorted_labels(std::move(labels)) { + std::sort(sorted_labels.begin(), sorted_labels.end()); + } + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), src.tag); + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), tag); + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } +}; + +struct network_selection_destination_label_impl: + public network_selection_crtp { + std::vector sorted_labels; + + explicit network_selection_destination_label_impl(std::vector labels): + sorted_labels(std::move(labels)) { + std::sort(sorted_labels.begin(), sorted_labels.end()); + } + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), dest.tag); + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), tag); + } +}; + +struct network_selection_source_gid_impl: + public network_selection_crtp { + std::vector sorted_gids; + + explicit network_selection_source_gid_impl(std::vector gids): + sorted_gids(std::move(gids)) { + std::sort(sorted_gids.begin(), sorted_gids.end()); + } + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), src.gid); + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } +}; + +struct network_selection_destination_gid_impl: + public network_selection_crtp { + std::vector sorted_gids; + + explicit network_selection_destination_gid_impl(std::vector gids): + sorted_gids(std::move(gids)) { + std::sort(sorted_gids.begin(), sorted_gids.end()); + } + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), dest.gid); + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); + } +}; + + +struct network_selection_invert_impl: + public network_selection_crtp { + std::shared_ptr selection; + + explicit network_selection_invert_impl(std::shared_ptr s): + selection(std::move(s)) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return !selection->select_connection(src, dest); + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; // cannot exclude any because source selection cannot be inverted without + // knowing selection criteria. + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; // cannot exclude any because destination selection cannot be inverted without + // knowing selection criteria. + } +}; + + +struct network_selection_inter_cell_impl: public network_selection_crtp { + bool select_connection(const network_site_info& src, + const network_site_info& dest) const { + return src.gid != dest.gid; + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } +}; + +struct network_selection_not_equal_impl: public network_selection_crtp { + bool select_connection(const network_site_info& src, + const network_site_info& dest) const { + return src.gid != dest.gid || src.tag != dest.tag || src.location != dest.location; + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } +}; + +struct network_selection_custom_impl: + public network_selection_crtp { + std::function func; + + explicit network_selection_custom_impl( + std::function f): + func(std::move(f)) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return func(src, dest); + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } +}; + +struct network_selection_within_distance_impl: + public network_selection_crtp { + double distance; + + explicit network_selection_within_distance_impl(double distance): distance(distance) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return network_location_distance(src.location, dest.location) <= distance; + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + std::optional max_distance() const override { return distance; } +}; + + +struct network_selection_bernoulli_random_impl: + public network_selection_crtp { + unsigned seed; + double probability; + + network_selection_bernoulli_random_impl(unsigned seed, double p): seed(seed), probability(p) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return uniform_rand_from_key_pair({unsigned(network_seed::selection_bernoulli), seed}, + src.hash, + dest.hash) < probability; + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } +}; + + +struct network_selection_linear_bernoulli_random_impl: + public network_selection_crtp { + unsigned seed; + double distance_begin; + double p_begin; + double distance_end; + double p_end; + + network_selection_linear_bernoulli_random_impl(unsigned seed_, + double distance_begin_, + double p_begin_, + double distance_end_, + double p_end_): + seed(seed_), + distance_begin(distance_begin_), + p_begin(p_begin_), + distance_end(distance_end_), + p_end(p_end_) { + if (distance_begin > distance_end) { + std::swap(distance_begin, distance_end); + std::swap(p_begin, p_end); + } + } + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + const double distance = network_location_distance(src.location, dest.location); + + if(distance < distance_begin || distance > distance_end) return false; + + const double p = + (p_begin * (distance_end - distance) + p_end * (distance - distance_begin)) / + (distance_end - distance_begin); + + return uniform_rand_from_key_pair( + {unsigned(network_seed::spatial_selection_bernoulli), seed}, + src.hash, + dest.hash) < p; + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + std::optional max_distance() const override { return distance_end; } +}; + +struct network_selection_and_impl: public network_selection_crtp { + std::shared_ptr left, right; + + network_selection_and_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return left->select_connection(src, dest) && right->select_connection(src, dest); + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return left->select_source(gid, cell, tag) && right->select_source(gid, cell, tag); + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return left->select_destination(gid, cell, tag) && + right->select_destination(gid, cell, tag); + } + + std::optional max_distance() const override { + const auto d_left = left->max_distance(); + const auto d_right = right->max_distance(); + + if (d_left && d_right) return std::min(d_left.value(), d_right.value()); + if (d_left) return d_left.value(); + if (d_right) return d_right.value(); + + return std::nullopt; + } +}; + +struct network_selection_or_impl: public network_selection_crtp { + std::shared_ptr left, right; + + network_selection_or_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return left->select_connection(src, dest) || right->select_connection(src, dest); + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return left->select_source(gid, cell, tag) || right->select_source(gid, cell, tag); + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return left->select_destination(gid, cell, tag) || + right->select_destination(gid, cell, tag); + } + + std::optional max_distance() const override { + const auto d_left = left->max_distance(); + const auto d_right = right->max_distance(); + + if (d_left && d_right) return std::max(d_left.value(), d_right.value()); + + return std::nullopt; + } +}; + +struct network_selection_xor_impl: public network_selection_crtp { + std::shared_ptr left, right; + + network_selection_xor_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return left->select_connection(src, dest) ^ right->select_connection(src, dest); + } + + template + bool select_source_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + template + bool select_destination_impl(cell_gid_type gid, + const CellType& cell, + const cell_tag_type& tag) const { + return true; + } + + std::optional max_distance() const override { + const auto d_left = left->max_distance(); + const auto d_right = right->max_distance(); + + if (d_left && d_right) return std::max(d_left.value(), d_right.value()); + + return std::nullopt; + } +}; + +} + +} // namespace arb diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp new file mode 100644 index 0000000000..01efdf02fd --- /dev/null +++ b/arbor/network_impl.hpp @@ -0,0 +1,62 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "connection.hpp" +#include "distributed_context.hpp" +#include "label_resolution.hpp" + +namespace arb { + + +struct network_selection_impl { + virtual std::optional max_distance() const { return std::nullopt; } + + virtual bool select_connection(const network_site_info& src, + const network_site_info& dest) const = 0; + + virtual bool select_source(cell_gid_type gid, + const cable_cell& cell, + const cell_tag_type& tag) const = 0; + + virtual bool select_destination(cell_gid_type gid, + const cable_cell& cell, + const cell_tag_type& tag) const = 0; + + virtual bool select_source(cell_gid_type gid, + const lif_cell& cell, + const cell_tag_type& tag) const = 0; + + virtual bool select_destination(cell_gid_type gid, + const lif_cell& cell, + const cell_tag_type& tag) const = 0; + + virtual bool select_source(cell_gid_type gid, + const spike_source_cell& cell, + const cell_tag_type& tag) const = 0; + + virtual bool select_destination(cell_gid_type gid, + const spike_source_cell& cell, + const cell_tag_type& tag) const = 0; + + virtual bool select_source(cell_gid_type gid, + const benchmark_cell& cell, + const cell_tag_type& tag) const = 0; + + virtual bool select_destination(cell_gid_type gid, + const benchmark_cell& cell, + const cell_tag_type& tag) const = 0; + + virtual ~network_selection_impl() = default; +}; + +} // namespace arb From ba9a1995ccfc54e4b330fca01641c7f333d9c97a Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 6 Mar 2023 19:22:12 +0100 Subject: [PATCH 05/84] simplify --- arbor/network.cpp | 435 +++++++++++------------------------------ arbor/network_impl.hpp | 33 +--- 2 files changed, 118 insertions(+), 350 deletions(-) diff --git a/arbor/network.cpp b/arbor/network.cpp index 186deb68ba..9907c7cd45 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -6,8 +6,8 @@ #include #include -#include #include +#include #include #include @@ -21,7 +21,7 @@ namespace { // Different seed for each type to avoid unintentional correlation. enum class network_seed : unsigned { selection_bernoulli = 2058443, - spatial_selection_bernoulli = 839033, + selection_linear_bernoulli = 839033, value_uniform = 48202, value_normal = 8405, value_truncated_normal = 380237 @@ -42,223 +42,84 @@ double network_location_distance(const network_location& a, const network_locati return std::sqrt(a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); } -template -struct network_selection_crtp: public network_selection_impl { - bool select_source(cell_gid_type gid, - const cable_cell& cell, - const cell_tag_type& tag) const override { - return static_cast(this)->select_source_impl(gid, cell, tag); - } - - bool select_destination(cell_gid_type gid, - const cable_cell& cell, - const cell_tag_type& tag) const override { - return static_cast(this)->select_destination_impl(gid, cell, tag); - } - - bool select_source(cell_gid_type gid, - const lif_cell& cell, - const cell_tag_type& tag) const override { - return static_cast(this)->select_source_impl(gid, cell, tag); - } - - bool select_destination(cell_gid_type gid, - const lif_cell& cell, - const cell_tag_type& tag) const override { - return static_cast(this)->select_destination_impl(gid, cell, tag); - } - - bool select_source(cell_gid_type gid, - const spike_source_cell& cell, - const cell_tag_type& tag) const override { - return static_cast(this)->select_source_impl(gid, cell, tag); - } - - bool select_destination(cell_gid_type gid, - const spike_source_cell& cell, - const cell_tag_type& tag) const override { - return static_cast(this)->select_destination_impl(gid, cell, tag); - } - - bool select_source(cell_gid_type gid, - const benchmark_cell& cell, - const cell_tag_type& tag) const override { - return static_cast(this)->select_source_impl(gid, cell, tag); - } - - bool select_destination(cell_gid_type gid, - const benchmark_cell& cell, - const cell_tag_type& tag) const override { - return static_cast(this)->select_destination_impl(gid, cell, tag); - } -}; - -struct network_selection_all_impl: public network_selection_crtp { +struct network_selection_all_impl: public network_selection_impl { bool select_connection(const network_site_info& src, const network_site_info& dest) const override { return true; } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } }; -struct network_selection_none_impl: - public network_selection_crtp { +struct network_selection_none_impl: public network_selection_impl { bool select_connection(const network_site_info& src, const network_site_info& dest) const override { return false; } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return false; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return false; } }; struct network_selection_source_cell_kind_impl: public network_selection_impl { - cell_kind kind; + cell_kind select_kind; - explicit network_selection_source_cell_kind_impl(cell_kind k): kind(k) {} + explicit network_selection_source_cell_kind_impl(cell_kind k): select_kind(k) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return src.kind == kind; - } - - bool select_source(cell_gid_type gid, - const cable_cell& cell, - const cell_tag_type& tag) const override { - return kind == cell_kind::cable; - } - - bool select_source(cell_gid_type gid, - const lif_cell& cell, - const cell_tag_type& tag) const override { - return kind == cell_kind::lif; - } - - bool select_source(cell_gid_type gid, - const benchmark_cell& cell, - const cell_tag_type& tag) const override { - return kind == cell_kind::benchmark; - } - - bool select_source(cell_gid_type gid, - const spike_source_cell& cell, - const cell_tag_type& tag) const override { - return kind == cell_kind::spike_source; - } - - bool select_destination(cell_gid_type gid, - const cable_cell& cell, - const cell_tag_type& tag) const override { - return true; + return src.kind == select_kind; } - bool select_destination(cell_gid_type gid, - const lif_cell& cell, - const cell_tag_type& tag) const override { - return true; + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + return kind == select_kind; } - bool select_destination(cell_gid_type gid, - const benchmark_cell& cell, - const cell_tag_type& tag) const override { - return true; - } - - bool select_destination(cell_gid_type gid, - const spike_source_cell& cell, + bool select_destination(cell_kind kind, + cell_gid_type gid, const cell_tag_type& tag) const override { return true; } }; -struct network_selection_destination_cell_kind_impl: - public network_selection_impl { - cell_kind kind; +struct network_selection_destination_cell_kind_impl: public network_selection_impl { + cell_kind select_kind; - explicit network_selection_destination_cell_kind_impl(cell_kind k): kind(k) {} + explicit network_selection_destination_cell_kind_impl(cell_kind k): select_kind(k) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return src.kind == kind; - } - - bool select_source(cell_gid_type gid, - const cable_cell& cell, - const cell_tag_type& tag) const override { - return true; - } - - bool select_source(cell_gid_type gid, - const lif_cell& cell, - const cell_tag_type& tag) const override { - return true; - } - - bool select_source(cell_gid_type gid, - const benchmark_cell& cell, - const cell_tag_type& tag) const override { - return true; + return src.kind == select_kind; } - bool select_source(cell_gid_type gid, - const spike_source_cell& cell, - const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - bool select_destination(cell_gid_type gid, - const cable_cell& cell, - const cell_tag_type& tag) const override { - return kind == cell_kind::cable; - } - - bool select_destination(cell_gid_type gid, - const lif_cell& cell, - const cell_tag_type& tag) const override { - return kind == cell_kind::lif; - } - - bool select_destination(cell_gid_type gid, - const benchmark_cell& cell, - const cell_tag_type& tag) const override { - return kind == cell_kind::benchmark; - } - - bool select_destination(cell_gid_type gid, - const spike_source_cell& cell, + bool select_destination(cell_kind kind, + cell_gid_type gid, const cell_tag_type& tag) const override { - return kind == cell_kind::spike_source; + return kind == select_kind; } }; - -struct network_selection_source_label_impl: - public network_selection_crtp { +struct network_selection_source_label_impl: public network_selection_impl { std::vector sorted_labels; explicit network_selection_source_label_impl(std::vector labels): @@ -271,23 +132,18 @@ struct network_selection_source_label_impl: return std::binary_search(sorted_labels.begin(), sorted_labels.end(), src.tag); } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), tag); } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } }; -struct network_selection_destination_label_impl: - public network_selection_crtp { +struct network_selection_destination_label_impl: public network_selection_impl { std::vector sorted_labels; explicit network_selection_destination_label_impl(std::vector labels): @@ -300,23 +156,18 @@ struct network_selection_destination_label_impl: return std::binary_search(sorted_labels.begin(), sorted_labels.end(), dest.tag); } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), tag); } }; -struct network_selection_source_gid_impl: - public network_selection_crtp { +struct network_selection_source_gid_impl: public network_selection_impl { std::vector sorted_gids; explicit network_selection_source_gid_impl(std::vector gids): @@ -329,23 +180,18 @@ struct network_selection_source_gid_impl: return std::binary_search(sorted_gids.begin(), sorted_gids.end(), src.gid); } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } }; -struct network_selection_destination_gid_impl: - public network_selection_crtp { +struct network_selection_destination_gid_impl: public network_selection_impl { std::vector sorted_gids; explicit network_selection_destination_gid_impl(std::vector gids): @@ -358,24 +204,18 @@ struct network_selection_destination_gid_impl: return std::binary_search(sorted_gids.begin(), sorted_gids.end(), dest.gid); } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } }; - -struct network_selection_invert_impl: - public network_selection_crtp { +struct network_selection_invert_impl: public network_selection_impl { std::shared_ptr selection; explicit network_selection_invert_impl(std::shared_ptr s): @@ -386,96 +226,77 @@ struct network_selection_invert_impl: return !selection->select_connection(src, dest); } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; // cannot exclude any because source selection cannot be inverted without // knowing selection criteria. } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; // cannot exclude any because destination selection cannot be inverted without // knowing selection criteria. } }; - -struct network_selection_inter_cell_impl: public network_selection_crtp { +struct network_selection_inter_cell_impl: public network_selection_impl { bool select_connection(const network_site_info& src, - const network_site_info& dest) const { + const network_site_info& dest) const override { return src.gid != dest.gid; } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } }; -struct network_selection_not_equal_impl: public network_selection_crtp { +struct network_selection_not_equal_impl: public network_selection_impl { bool select_connection(const network_site_info& src, - const network_site_info& dest) const { + const network_site_info& dest) const override { return src.gid != dest.gid || src.tag != dest.tag || src.location != dest.location; } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } }; -struct network_selection_custom_impl: - public network_selection_crtp { - std::function func; +struct network_selection_custom_impl: public network_selection_impl { + std::function func; - explicit network_selection_custom_impl( - std::function f): - func(std::move(f)) {} + explicit network_selection_custom_impl( + std::function f): + func(std::move(f)) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { return func(src, dest); } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } }; -struct network_selection_within_distance_impl: - public network_selection_crtp { +struct network_selection_within_distance_impl: public network_selection_impl { double distance; explicit network_selection_within_distance_impl(double distance): distance(distance) {} @@ -485,26 +306,20 @@ struct network_selection_within_distance_impl: return network_location_distance(src.location, dest.location) <= distance; } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } std::optional max_distance() const override { return distance; } }; - -struct network_selection_bernoulli_random_impl: - public network_selection_crtp { +struct network_selection_bernoulli_random_impl: public network_selection_impl { unsigned seed; double probability; @@ -517,24 +332,18 @@ struct network_selection_bernoulli_random_impl: dest.hash) < probability; } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } }; - -struct network_selection_linear_bernoulli_random_impl: - public network_selection_crtp { +struct network_selection_linear_bernoulli_random_impl: public network_selection_impl { unsigned seed; double distance_begin; double p_begin; @@ -561,36 +370,32 @@ struct network_selection_linear_bernoulli_random_impl: const network_site_info& dest) const override { const double distance = network_location_distance(src.location, dest.location); - if(distance < distance_begin || distance > distance_end) return false; + if (distance < distance_begin || distance > distance_end) return false; const double p = (p_begin * (distance_end - distance) + p_end * (distance - distance_begin)) / (distance_end - distance_begin); return uniform_rand_from_key_pair( - {unsigned(network_seed::spatial_selection_bernoulli), seed}, + {unsigned(network_seed::selection_linear_bernoulli), seed}, src.hash, dest.hash) < p; } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } std::optional max_distance() const override { return distance_end; } }; -struct network_selection_and_impl: public network_selection_crtp { +struct network_selection_and_impl: public network_selection_impl { std::shared_ptr left, right; network_selection_and_impl(std::shared_ptr l, @@ -603,19 +408,15 @@ struct network_selection_and_impl: public network_selection_crtpselect_connection(src, dest) && right->select_connection(src, dest); } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { - return left->select_source(gid, cell, tag) && right->select_source(gid, cell, tag); + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + return left->select_source(kind, gid, tag) && right->select_source(kind, gid, tag); } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { - return left->select_destination(gid, cell, tag) && - right->select_destination(gid, cell, tag); + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { + return left->select_destination(kind, gid, tag) && + right->select_destination(kind, gid, tag); } std::optional max_distance() const override { @@ -630,7 +431,7 @@ struct network_selection_and_impl: public network_selection_crtp { +struct network_selection_or_impl: public network_selection_impl { std::shared_ptr left, right; network_selection_or_impl(std::shared_ptr l, @@ -643,19 +444,15 @@ struct network_selection_or_impl: public network_selection_crtpselect_connection(src, dest) || right->select_connection(src, dest); } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { - return left->select_source(gid, cell, tag) || right->select_source(gid, cell, tag); + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + return left->select_source(kind, gid, tag) || right->select_source(kind, gid, tag); } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { - return left->select_destination(gid, cell, tag) || - right->select_destination(gid, cell, tag); + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { + return left->select_destination(kind, gid, tag) || + right->select_destination(kind, gid, tag); } std::optional max_distance() const override { @@ -668,7 +465,7 @@ struct network_selection_or_impl: public network_selection_crtp { +struct network_selection_xor_impl: public network_selection_impl { std::shared_ptr left, right; network_selection_xor_impl(std::shared_ptr l, @@ -681,17 +478,13 @@ struct network_selection_xor_impl: public network_selection_crtpselect_connection(src, dest) ^ right->select_connection(src, dest); } - template - bool select_source_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { return true; } - template - bool select_destination_impl(cell_gid_type gid, - const CellType& cell, - const cell_tag_type& tag) const { + bool select_destination(cell_kind kind, + cell_gid_type gid, + const cell_tag_type& tag) const override { return true; } @@ -705,6 +498,6 @@ struct network_selection_xor_impl: public network_selection_crtp max_distance() const { return std::nullopt; } virtual bool select_connection(const network_site_info& src, const network_site_info& dest) const = 0; - virtual bool select_source(cell_gid_type gid, - const cable_cell& cell, - const cell_tag_type& tag) const = 0; - - virtual bool select_destination(cell_gid_type gid, - const cable_cell& cell, - const cell_tag_type& tag) const = 0; - - virtual bool select_source(cell_gid_type gid, - const lif_cell& cell, - const cell_tag_type& tag) const = 0; - - virtual bool select_destination(cell_gid_type gid, - const lif_cell& cell, - const cell_tag_type& tag) const = 0; - - virtual bool select_source(cell_gid_type gid, - const spike_source_cell& cell, - const cell_tag_type& tag) const = 0; - - virtual bool select_destination(cell_gid_type gid, - const spike_source_cell& cell, - const cell_tag_type& tag) const = 0; - - virtual bool select_source(cell_gid_type gid, - const benchmark_cell& cell, + virtual bool select_source(cell_kind kind, + cell_gid_type gid, const cell_tag_type& tag) const = 0; - virtual bool select_destination(cell_gid_type gid, - const benchmark_cell& cell, + virtual bool select_destination(cell_kind kind, + cell_gid_type gid, const cell_tag_type& tag) const = 0; virtual ~network_selection_impl() = default; From 525a50d55bee3efeb093dc6931d134e433821278 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Wed, 8 Mar 2023 17:01:36 +0100 Subject: [PATCH 06/84] recv send --- arbor/communication/dry_run_context.cpp | 17 ++++ arbor/communication/mpi.hpp | 57 +++++++++++ arbor/communication/mpi_context.cpp | 53 ++++++++++ arbor/distributed_context.hpp | 84 +++++++++++++++ arbor/include/arbor/network.hpp | 1 + arbor/network_generation.cpp | 129 ++++++++++++++++++++++++ 6 files changed, 341 insertions(+) diff --git a/arbor/communication/dry_run_context.cpp b/arbor/communication/dry_run_context.cpp index fcfe4f9f56..12b8ecb386 100644 --- a/arbor/communication/dry_run_context.cpp +++ b/arbor/communication/dry_run_context.cpp @@ -109,6 +109,23 @@ struct dry_run_context_impl { return std::vector(num_ranks_, value); } + std::vector gather_all(std::size_t value) const { + return std::vector(num_ranks_, value); + } + + distributed_request send_recv_nonblocking(std::size_t dest_count, + void* dest_data, + int dest, + std::size_t source_count, + const void* source_data, + int source, + int tag) const { + throw arbor_internal_error("send_recv_nonblocking: not implemented for dry run conext."); + + return distributed_request{ + std::make_unique()}; + } + int id() const { return 0; } int size() const { return num_ranks_; } diff --git a/arbor/communication/mpi.hpp b/arbor/communication/mpi.hpp index df3eaecb85..a9fbb52cc3 100644 --- a/arbor/communication/mpi.hpp +++ b/arbor/communication/mpi.hpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include @@ -319,5 +321,60 @@ T broadcast(int root, MPI_Comm comm) { return value; } +std::vector isend(std::size_t num_bytes, + const void* data, + int dest, + int tag, + MPI_Comm comm) { + constexpr std::size_t max_msg_size = static_cast(std::numeric_limits::max()); + + std::vector requests; + + for (std::size_t idx = 0; idx < num_bytes; idx += max_msg_size) { + requests.emplace_back(); + MPI_OR_THROW(MPI_Isend, + reinterpret_cast(const_cast(data)) + idx, + static_cast(std::min(max_msg_size, num_bytes - idx)), + MPI_BYTE, + dest, + tag, + comm, + &(requests.back())); + } + + return requests; +} + +std::vector irecv(std::size_t num_bytes, + void* data, + int source, + int tag, + MPI_Comm comm) { + constexpr std::size_t max_msg_size = static_cast(std::numeric_limits::max()); + + std::vector requests; + + for (std::size_t idx = 0; idx < num_bytes; idx += max_msg_size) { + requests.emplace_back(); + MPI_OR_THROW(MPI_Irecv, + reinterpret_cast(data) + idx, + static_cast(std::min(max_msg_size, num_bytes - idx)), + MPI_BYTE, + source, + tag, + comm, + &(requests.back())); + } + + return requests; +} + +void wait_all(std::vector requests) { + if(!requests.empty()) { + MPI_OR_THROW( + MPI_Waitall, static_cast(requests.size()), requests.data(), MPI_STATUSES_IGNORE); + } +} + } // namespace mpi } // namespace arb diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index 22a2428343..9de7d37679 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -5,6 +5,7 @@ #error "build only if MPI is enabled" #endif +#include #include #include @@ -66,6 +67,58 @@ struct mpi_context_impl { return mpi::gather(value, root, comm_); } + std::vector gather_all(std::size_t value) const { + return mpi::gather_all(value, comm_); + } + + distributed_request send_recv_nonblocking(std::size_t recv_count, + void* recv_data, + int source_id, + std::size_t send_count, + const void* send_data, + int dest_id, + int tag) const { + + // Return dummy request of nothing to do + if (!recv_count && !send_count) + return distributed_request{ + std::make_unique()}; + if(recv_count && !recv_data) + throw arbor_internal_error( + "send_recv_nonblocking: recv_data is null."); + + if(send_count && !send_data) + throw arbor_internal_error( + "send_recv_nonblocking: send_data is null."); + + if (recv_data == send_data) + throw arbor_internal_error( + "send_recv_nonblocking: recv_data and send_data must not be the same."); + + auto recv_requests = mpi::irecv(recv_count, recv_data, source_id, tag, comm_); + auto send_requests = mpi::isend(send_count, send_data, dest_id, tag, comm_); + + struct mpi_send_recv_request : public distributed_request::distributed_request_interface { + std::vector recv_requests, send_requests; + + void finalize() override { + if (!recv_requests.empty()) { + mpi::wait_all(std::move(recv_requests)); + } + + if (!send_requests.empty()) { + mpi::wait_all(std::move(send_requests)); + } + }; + + ~mpi_send_recv_request() override { this->finalize(); } + }; + + return distributed_request{ + std::unique_ptr( + new mpi_send_recv_request{std::move(recv_requests), std::move(send_requests)})}; + } + std::string name() const { return "MPI"; } int id() const { return rank_; } int size() const { return size_; } diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index 58ca11c318..6c3f37056f 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -32,6 +33,20 @@ namespace arb { #define ARB_COLLECTIVE_TYPES_ float, double, int, unsigned, long, unsigned long, long long, unsigned long long +struct distributed_request { + inline void finalize() { + if (impl) impl->finalize(); + } + + struct distributed_request_interface { + virtual void finalize() {}; + + virtual ~distributed_request_interface() = default; + }; + + std::unique_ptr impl; +}; + // Defines the concept/interface for a distributed communication context. // // Uses value-semantic type erasure to define the interface, so that @@ -82,6 +97,31 @@ class distributed_context { return impl_->gather(value, root); } + std::vector gather_all(std::size_t value) const { + return impl_->gather_all(value); + } + + template + distributed_request send_recv_nonblocking(std::size_t recv_count, + T* recv_data, + int source_id, + std::size_t send_count, + const T* send_data, + int dest_id, + int tag) const { + static_assert(std::is_trivially_copyable::value, + "send_recv_nonblocking: Type T must be trivially copyable for memcpy or MPI send / " + "recv using MPI_BYTE."); + + impl_->send_recv_nonblocking(recv_count * sizeof(T), + recv_data, + source_id, + send_count * sizeof(T), + send_data, + dest_id, + tag); + } + int id() const { return impl_->id(); } @@ -114,6 +154,14 @@ class distributed_context { gather_cell_labels_and_gids(const cell_labels_and_gids& local_labels_and_gids) const = 0; virtual std::vector gather(std::string value, int root) const = 0; + virtual std::vector gather_all(std::size_t value) const = 0; + virtual distributed_request send_recv_nonblocking(std::size_t recv_count, + void* recv_data, + int source_id, + std::size_t send_count, + const void* send_data, + int dest_id, + int tag) const = 0; virtual int id() const = 0; virtual int size() const = 0; virtual void barrier() const = 0; @@ -153,6 +201,19 @@ class distributed_context { gather(std::string value, int root) const override { return wrapped.gather(value, root); } + std::vector gather_all(std::size_t value) const override { + return wrapped.gather_all(value); + } + distributed_request send_recv_nonblocking(std::size_t recv_count, + void* recv_data, + int source_id, + std::size_t send_count, + const void* send_data, + int dest_id, + int tag) const override { + return wrapped.send_recv_nonblocking( + recv_count, recv_data, source_id, send_count, send_data, dest_id, tag); + } int id() const override { return wrapped.id(); } @@ -208,6 +269,29 @@ struct local_context { return {std::move(value)}; } + std::vector gather_all(std::size_t value) const { + return std::vector({value}); + } + + distributed_request send_recv_nonblocking(std::size_t dest_count, + void* dest_data, + int dest, + std::size_t source_count, + const void* source_data, + int source, + int tag) const { + if (source != 0 || dest != 0) + throw arbor_internal_error( + "send_recv_nonblocking: source and destination id must be 0 for local context."); + if (dest_count != source_count) + throw arbor_internal_error( + "send_recv_nonblocking: dest_count not equal to source_count."); + std::memcpy(dest_data, source_data, source_count); + + return distributed_request{ + std::make_unique()}; + } + int id() const { return 0; } int size() const { return 1; } diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 8d747609cb..1d2755b9e6 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -14,6 +14,7 @@ #include #include #include +#include namespace arb { diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp index 68458989ac..a5b2ca72d9 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_generation.cpp @@ -5,7 +5,12 @@ #include #include #include + +#include #include +#include +#include +#include namespace arb { @@ -22,6 +27,130 @@ struct src_site_info { double x, y, z; network_hash_type hash; }; + + +struct site_info { + cell_gid_type gid; + cell_gid_type label_id; + cell_lid_type lid; + network_location location; +}; + +struct site_collection { + std::unordered_map label_id_mapping; + std::vector sites; + + inline void add_site(cell_gid_type gid, + const cell_tag_type& label, + cell_lid_type lid, + network_location location) { + + auto insert_it = label_id_mapping.insert({label, label_id_mapping.size()}); + + sites.emplace_back(site_info{gid, insert_it.first->second, lid, location}); + } +}; + +struct site_mapping { + std::vector sites; + std::string labels; + + site_mapping() = default; + + site_mapping(site_collection collection) { + + std::size_t totalLabelLength = 0; + for (const auto& [label, _]: collection.label_id_mapping) { + totalLabelLength += label.size(); + } + + labels.reserve(totalLabelLength + collection.label_id_mapping.size()); + std::vector label_id_to_start_idx(collection.label_id_mapping.size()); + for (const auto& [label, id]: collection.label_id_mapping) { + label_id_to_start_idx[id] = labels.size(); + labels.append(label); + labels.push_back('\0'); + } + + for(auto& si : collection.sites) { + si.label_id = label_id_to_start_idx.at(si.label_id); + } + + sites = std::move(collection.sites); + } + + std::string_view label_at_site(const site_info& si) { + return labels.c_str() + si.label_id; + } +}; + +template +void distributed_for_each_site(const distributed_context& distributed, + site_mapping mapping, + FUNC f) { + if(distributed.size() > 1) { + const auto my_rank = distributed.id(); + const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1; + const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1; + + const auto num_sites_per_rank = distributed.gather_all(mapping.sites.size()); + const auto label_string_size_per_rank = distributed.gather_all(mapping.labels.size()); + + const auto max_num_sites = + *std::max_element(num_sites_per_rank.begin(), num_sites_per_rank.end()); + const auto max_string_size = + *std::max_element(label_string_size_per_rank.begin(), label_string_size_per_rank.end()); + + mapping.sites.resize(max_num_sites); + mapping.labels.resize(max_string_size); + + site_mapping recv_mapping; + recv_mapping.sites.resize(max_num_sites); + recv_mapping.labels.resize(max_string_size); + + auto current_idx = my_rank; + + for(std::size_t step = 0; step < distributed.size() - 1; ++step) { + const auto next_idx = (current_idx + 1) % distributed.size(); + auto request_sites = distributed.send_recv_nonblocking(num_sites_per_rank[next_idx], + recv_mapping.sites.data(), + right_rank, + num_sites_per_rank[current_idx], + mapping.sites.data(), + left_rank, + 0); + + auto request_labels = + distributed.send_recv_nonblocking(label_string_size_per_rank[next_idx], + recv_mapping.labels.data(), + right_rank, + label_string_size_per_rank[current_idx], + mapping.labels.data(), + left_rank, + 1); + + for (std::size_t site_idx = 0; site_idx < num_sites_per_rank[current_idx]; ++site_idx) { + const auto& s = mapping.sites[site_idx]; + f(s, mapping.label_at_site(s)); + } + + request_sites.finalize(); + request_labels.finalize(); + + std::swap(mapping, recv_mapping); + + current_idx = next_idx; + } + + for (std::size_t site_idx = 0; site_idx < num_sites_per_rank[current_idx]; ++site_idx) { + const auto& s = mapping.sites[site_idx]; + f(s, mapping.label_at_site(s)); + } + } else { + for (const auto& s: mapping.sites) { f(s, mapping.label_at_site(s)); } + } +} + } // namespace std::vector generate_network_connections( From fae26514827556bb7345b9c631b98af4b02e2d4f Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sat, 11 Mar 2023 17:50:46 +0100 Subject: [PATCH 07/84] complete generation --- arbor/CMakeLists.txt | 1 + arbor/distributed_context.hpp | 2 +- arbor/include/arbor/network.hpp | 63 ++----- arbor/network.cpp | 303 ++++++++++++++++++++++++++------ arbor/network_generation.cpp | 247 ++++++++++++-------------- arbor/network_impl.hpp | 19 +- arbor/util/spatial_tree.hpp | 37 ++-- 7 files changed, 421 insertions(+), 251 deletions(-) diff --git a/arbor/CMakeLists.txt b/arbor/CMakeLists.txt index 948205b2ef..807266eaec 100644 --- a/arbor/CMakeLists.txt +++ b/arbor/CMakeLists.txt @@ -46,6 +46,7 @@ set(arbor_sources morph/stitch.cpp merge_events.cpp network.cpp + network_generation.cpp simulation.cpp partition_load_balance.cpp profile/clock.cpp diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index 6c3f37056f..be64142cdc 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -113,7 +113,7 @@ class distributed_context { "send_recv_nonblocking: Type T must be trivially copyable for memcpy or MPI send / " "recv using MPI_BYTE."); - impl_->send_recv_nonblocking(recv_count * sizeof(T), + return impl_->send_recv_nonblocking(recv_count * sizeof(T), recv_data, source_id, send_count * sizeof(T), diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 1d2755b9e6..ead9c748e1 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -22,16 +23,28 @@ using network_location = std::array; using network_hash_type = std::uint64_t; -struct network_site_info { +struct ARB_SYMBOL_VISIBLE network_site_info { + network_site_info(cell_gid_type gid, + cell_lid_type lid, + cell_kind kind, + std::string_view label, + mlocation location, + network_location global_location); + cell_gid_type gid; + cell_lid_type lid; cell_kind kind; - const cell_tag_type& tag; - const network_location& location; + std::string_view label; + mlocation location; + network_location global_location; network_hash_type hash; }; struct network_selection_impl; +struct network_value_impl; + + class ARB_SYMBOL_VISIBLE network_selection { public: network_selection() { *this = network_selection::all(); } @@ -93,6 +106,7 @@ class ARB_SYMBOL_VISIBLE network_selection { network_selection operator^(network_selection right) const; private: + friend const network_selection_impl& get_network_selection_impl(const network_selection& s); std::shared_ptr impl_; }; @@ -133,48 +147,9 @@ class ARB_SYMBOL_VISIBLE network_value { const network_location&, double)> func); - // inline double operator()(cell_gid_type src_gid, - // const network_location& global_src_location, - // network_hash_type src_hash, - // cell_gid_type dest_gid, - // const network_location& global_dest_location, - // network_hash_type dest_hash) const { - // return impl_->get( - // src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); - // } - private: - - struct value_impl { - // virtual double get(cell_gid_type src_gid, - // const network_location& global_src_location, - // network_hash_type src_hash, - // cell_gid_type dest_gid, - // const network_location& global_dest_location, - // network_hash_type dest_hash) const = 0; - - virtual ~value_impl() = default; - }; - - struct uniform_distribution_impl; - struct normal_distribution_impl; - struct truncated_normal_distribution_impl; - struct custom_impl; - struct uniform_impl; - - network_value(std::shared_ptr impl); - - // inline double get(cell_gid_type src_gid, - // const network_location& global_src_location, - // network_hash_type src_hash, - // cell_gid_type dest_gid, - // const network_location& global_dest_location, - // network_hash_type dest_hash) const { - // return impl_->get( - // src_gid, global_src_location, src_hash, dest_gid, global_dest_location, dest_hash); - // } - - std::shared_ptr impl_; + friend const network_value_impl& get_network_value_impl(const network_value& v); + std::shared_ptr impl_; }; struct network_description { diff --git a/arbor/network.cpp b/arbor/network.cpp index 9907c7cd45..dad5f5a495 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -24,9 +25,27 @@ enum class network_seed : unsigned { selection_linear_bernoulli = 839033, value_uniform = 48202, value_normal = 8405, - value_truncated_normal = 380237 + value_truncated_normal = 380237, + site_info = 984293 }; +// We only need minimal hash collisions and good spread over the hash range, because this will be +// used as input for random123, which then provides all desired hash properties. +// std::hash is implementation dependent, so we define our own for reproducibility. + +std::uint64_t simple_string_hash(const std::string_view& s) { + // use fnv1a hash algorithm + constexpr std::uint64_t prime = 1099511628211ull; + std::uint64_t h = 14695981039346656037ull; + + for (auto c: s) { + h ^= c; + h *= prime; + } + + return h; +} + double uniform_rand_from_key_pair(std::array seed, network_hash_type key_a, network_hash_type key_b) { @@ -38,6 +57,19 @@ double uniform_rand_from_key_pair(std::array seed, return r123::u01(gen(seed_input, key)[0]); } +double normal_rand_from_key_pair(std::array seed, + std::uint64_t key_a, + std::uint64_t key_b) { + using rand_type = r123::Threefry2x64; + const rand_type::ctr_type seed_input = {{seed[0], seed[1]}}; + + const rand_type::key_type key = {{std::min(key_a, key_b), std::max(key_a, key_b)}}; + rand_type gen; + const auto rand_num = gen(seed_input, key); + return r123::boxmuller(rand_num[0], rand_num[1]).x; +} + + double network_location_distance(const network_location& a, const network_location& b) { return std::sqrt(a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); } @@ -48,13 +80,15 @@ struct network_selection_all_impl: public network_selection_impl { return true; } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } }; @@ -66,13 +100,15 @@ struct network_selection_none_impl: public network_selection_impl { return false; } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return false; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return false; } }; @@ -87,13 +123,15 @@ struct network_selection_source_cell_kind_impl: public network_selection_impl { return src.kind == select_kind; } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return kind == select_kind; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } }; @@ -108,62 +146,68 @@ struct network_selection_destination_cell_kind_impl: public network_selection_im return src.kind == select_kind; } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return kind == select_kind; } }; struct network_selection_source_label_impl: public network_selection_impl { - std::vector sorted_labels; + std::vector sorted_labels; - explicit network_selection_source_label_impl(std::vector labels): + explicit network_selection_source_label_impl(std::vector labels): sorted_labels(std::move(labels)) { std::sort(sorted_labels.begin(), sorted_labels.end()); } bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), src.tag); + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), src.label); } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), tag); + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), label); } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } }; struct network_selection_destination_label_impl: public network_selection_impl { - std::vector sorted_labels; + std::vector sorted_labels; - explicit network_selection_destination_label_impl(std::vector labels): + explicit network_selection_destination_label_impl(std::vector labels): sorted_labels(std::move(labels)) { std::sort(sorted_labels.begin(), sorted_labels.end()); } bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), dest.tag); + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), dest.label); } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), tag); + const std::string_view& label) const override { + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), label); } }; @@ -180,13 +224,15 @@ struct network_selection_source_gid_impl: public network_selection_impl { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), src.gid); } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } }; @@ -204,13 +250,15 @@ struct network_selection_destination_gid_impl: public network_selection_impl { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), dest.gid); } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } }; @@ -226,14 +274,16 @@ struct network_selection_invert_impl: public network_selection_impl { return !selection->select_connection(src, dest); } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; // cannot exclude any because source selection cannot be inverted without // knowing selection criteria. } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; // cannot exclude any because destination selection cannot be inverted without // knowing selection criteria. } @@ -245,13 +295,15 @@ struct network_selection_inter_cell_impl: public network_selection_impl { return src.gid != dest.gid; } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } }; @@ -259,16 +311,18 @@ struct network_selection_inter_cell_impl: public network_selection_impl { struct network_selection_not_equal_impl: public network_selection_impl { bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return src.gid != dest.gid || src.tag != dest.tag || src.location != dest.location; + return src.gid != dest.gid || src.label != dest.label || src.location != dest.location; } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } }; @@ -285,13 +339,15 @@ struct network_selection_custom_impl: public network_selection_impl { return func(src, dest); } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } }; @@ -303,16 +359,18 @@ struct network_selection_within_distance_impl: public network_selection_impl { bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return network_location_distance(src.location, dest.location) <= distance; + return network_location_distance(src.global_location, dest.global_location) <= distance; } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } @@ -332,13 +390,15 @@ struct network_selection_bernoulli_random_impl: public network_selection_impl { dest.hash) < probability; } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } }; @@ -368,7 +428,7 @@ struct network_selection_linear_bernoulli_random_impl: public network_selection_ bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - const double distance = network_location_distance(src.location, dest.location); + const double distance = network_location_distance(src.global_location, dest.global_location); if (distance < distance_begin || distance > distance_end) return false; @@ -382,13 +442,15 @@ struct network_selection_linear_bernoulli_random_impl: public network_selection_ dest.hash) < p; } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } @@ -408,15 +470,17 @@ struct network_selection_and_impl: public network_selection_impl { return left->select_connection(src, dest) && right->select_connection(src, dest); } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { - return left->select_source(kind, gid, tag) && right->select_source(kind, gid, tag); + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return left->select_source(kind, gid, label) && right->select_source(kind, gid, label); } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { - return left->select_destination(kind, gid, tag) && - right->select_destination(kind, gid, tag); + const std::string_view& label) const override { + return left->select_destination(kind, gid, label) && + right->select_destination(kind, gid, label); } std::optional max_distance() const override { @@ -444,15 +508,17 @@ struct network_selection_or_impl: public network_selection_impl { return left->select_connection(src, dest) || right->select_connection(src, dest); } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { - return left->select_source(kind, gid, tag) || right->select_source(kind, gid, tag); + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { - return left->select_destination(kind, gid, tag) || - right->select_destination(kind, gid, tag); + const std::string_view& label) const override { + return left->select_destination(kind, gid, label) || + right->select_destination(kind, gid, label); } std::optional max_distance() const override { @@ -478,13 +544,15 @@ struct network_selection_xor_impl: public network_selection_impl { return left->select_connection(src, dest) ^ right->select_connection(src, dest); } - bool select_source(cell_kind kind, cell_gid_type gid, const cell_tag_type& tag) const override { + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { return true; } bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const override { + const std::string_view& label) const override { return true; } @@ -498,6 +566,133 @@ struct network_selection_xor_impl: public network_selection_impl { } }; + +struct network_value_uniform_impl : public network_value_impl{ + double value; + + network_value_uniform_impl(double v): value(v) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + return value; + } +}; + +struct network_value_uniform_distribution_impl : public network_value_impl{ + unsigned seed = 0; + std::array range; + + network_value_uniform_distribution_impl(unsigned rand_seed, const std::array& r): + seed(rand_seed), + range(r) { + if (range[0] >= range[1]) + throw std::invalid_argument("Uniform distribution: invalid range"); + } + + double get(const network_site_info& src, const network_site_info& dest) const override { + if (range[0] > range[1]) return range[1]; + + // random number between 0 and 1 + const auto rand_num = uniform_rand_from_key_pair( + {unsigned(network_seed::value_uniform), seed}, src.hash, dest.hash); + + return (range[1] - range[0]) * rand_num + range[0]; + } +}; + +struct network_value_normal_distribution_impl: public network_value_impl { + unsigned seed = 0; + double mean = 0.0; + double std_deviation = 1.0; + + network_value_normal_distribution_impl(unsigned rand_seed, double mean_, double std_deviation_): + seed(rand_seed), + mean(mean_), + std_deviation(std_deviation_) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + return mean + std_deviation * + normal_rand_from_key_pair( + {unsigned(network_seed::value_normal), seed}, src.hash, dest.hash); + } +}; + +struct network_value_truncated_normal_distribution_impl: public network_value_impl { + unsigned seed = 0; + double mean = 0.0; + double std_deviation = 1.0; + std::array range; + + network_value_truncated_normal_distribution_impl(unsigned rand_seed, + double mean_, + double std_deviation_, + const std::array& range_): + seed(rand_seed), + mean(mean_), + std_deviation(std_deviation_), + range(range_) { + if (range[0] >= range[1]) + throw std::invalid_argument("Truncated normal distribution: invalid range"); + } + + double get(const network_site_info& src, const network_site_info& dest) const override { + + const auto src_hash = src.hash; + auto dest_hash = dest.hash; + + double value = 0.0; + + do { + value = + mean + std_deviation * normal_rand_from_key_pair( + {unsigned(network_seed::value_truncated_normal), seed}, + src_hash, + dest_hash); + ++dest_hash; + } while (!(value > range[0] && value <= range[1])); + + return value; + } +}; + +struct network_value_custom_impl: public network_value_impl { + std::function func; + + network_value_custom_impl( + std::function f): + func(std::move(f)) {} + + inline double get(const network_site_info& src, const network_site_info& dest) const override { + return func(src, dest); + } +}; + } // namespace +network_site_info::network_site_info(cell_gid_type gid, + cell_lid_type lid, + cell_kind kind, + std::string_view label, + mlocation location, + network_location global_location): + gid(gid), + lid(lid), + kind(kind), + label(std::move(label)), + location(location), + global_location(global_location) { + + std::uint64_t label_hash = simple_string_hash(this->label); + static_assert(sizeof(decltype(mlocation::pos)) == sizeof(std::uint64_t)); + std::uint64_t loc_pos_hash = *reinterpret_cast(&location.pos); + + const auto seed = static_cast(network_seed::site_info); + + using rand_type = r123::Threefry4x64; + const rand_type::ctr_type seed_input = {{seed, 2 * seed, 3 * seed, 4 * seed}}; + const rand_type::key_type key = {{gid, label_hash, location.branch, loc_pos_hash}}; + + rand_type gen; + hash = gen(seed_input, key)[0]; +} + } // namespace arb diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp index a5b2ca72d9..7246910840 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_generation.cpp @@ -1,6 +1,9 @@ #include "network_generation.hpp" +#include "network_impl.hpp" #include "util/spatial_tree.hpp" +#include + #include #include #include @@ -8,87 +11,63 @@ #include #include +#include #include #include +#include #include namespace arb { namespace { -struct dest_site_info { - cell_gid_type gid; - cell_lid_type lid; - network_hash_type hash; -}; - -struct src_site_info { - cell_gid_type gid; - cell_lid_type lid; - double x, y, z; - network_hash_type hash; -}; - - -struct site_info { - cell_gid_type gid; - cell_gid_type label_id; - cell_lid_type lid; - network_location location; -}; - -struct site_collection { - std::unordered_map label_id_mapping; - std::vector sites; - - inline void add_site(cell_gid_type gid, - const cell_tag_type& label, - cell_lid_type lid, - network_location location) { - - auto insert_it = label_id_mapping.insert({label, label_id_mapping.size()}); - - sites.emplace_back(site_info{gid, insert_it.first->second, lid, location}); - } +struct distributed_site_info { + cell_gid_type gid = 0; + cell_lid_type lid = 0; + cell_kind kind = cell_kind::cable; + cell_gid_type label_start_idx = 0; + mlocation location = mlocation(); + network_location global_location = network_location(); + network_hash_type hash = 0; }; -struct site_mapping { - std::vector sites; +struct distributed_site_mapping { + std::vector sites; std::string labels; - site_mapping() = default; + distributed_site_mapping() = default; - site_mapping(site_collection collection) { - - std::size_t totalLabelLength = 0; - for (const auto& [label, _]: collection.label_id_mapping) { - totalLabelLength += label.size(); - } - - labels.reserve(totalLabelLength + collection.label_id_mapping.size()); - std::vector label_id_to_start_idx(collection.label_id_mapping.size()); - for (const auto& [label, id]: collection.label_id_mapping) { - label_id_to_start_idx[id] = labels.size(); - labels.append(label); - labels.push_back('\0'); - } + explicit distributed_site_mapping(const std::vector& net_sites) { + std::unordered_map label_to_start_idx; - for(auto& si : collection.sites) { - si.label_id = label_id_to_start_idx.at(si.label_id); + for (const auto& s: net_sites) { + const auto insert_pair = label_to_start_idx.insert({s.label, labels.size()}); + // append label if not contained in labels + if (insert_pair.second) { + labels.append(s.label); + labels.push_back('\0'); + } + sites.emplace_back(distributed_site_info{s.gid, + s.lid, + s.kind, + insert_pair.first->second, + s.location, + s.global_location, + s.hash}); } - - sites = std::move(collection.sites); } - std::string_view label_at_site(const site_info& si) { - return labels.c_str() + si.label_id; + std::string_view label_at_site(const distributed_site_info& si) { + return labels.c_str() + si.label_start_idx; } }; template void distributed_for_each_site(const distributed_context& distributed, - site_mapping mapping, + const std::vector& src_sites, FUNC f) { if(distributed.size() > 1) { + distributed_site_mapping mapping(src_sites); + const auto my_rank = distributed.id(); const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1; const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1; @@ -104,7 +83,7 @@ void distributed_for_each_site(const distributed_context& distributed, mapping.sites.resize(max_num_sites); mapping.labels.resize(max_string_size); - site_mapping recv_mapping; + distributed_site_mapping recv_mapping; recv_mapping.sites.resize(max_num_sites); recv_mapping.labels.resize(max_string_size); @@ -131,7 +110,8 @@ void distributed_for_each_site(const distributed_context& distributed, for (std::size_t site_idx = 0; site_idx < num_sites_per_rank[current_idx]; ++site_idx) { const auto& s = mapping.sites[site_idx]; - f(s, mapping.label_at_site(s)); + f(network_site_info{ + s.gid, s.lid, s.kind, mapping.label_at_site(s), s.location, s.global_location}); } request_sites.finalize(); @@ -144,27 +124,27 @@ void distributed_for_each_site(const distributed_context& distributed, for (std::size_t site_idx = 0; site_idx < num_sites_per_rank[current_idx]; ++site_idx) { const auto& s = mapping.sites[site_idx]; - f(s, mapping.label_at_site(s)); + f(network_site_info{ + s.gid, s.lid, s.kind, mapping.label_at_site(s), s.location, s.global_location}); } - } else { - for (const auto& s: mapping.sites) { f(s, mapping.label_at_site(s)); } + } + else { + for (const auto& s: src_sites) { f(s); } } } } // namespace -std::vector generate_network_connections( - const std::vector& descriptions, +std::vector generate_network_connections(const network_description& description, const recipe& rec, const distributed_context& distributed, - const domain_decomposition& dom_dec, - const label_resolution_map& source_resolution_map, - const label_resolution_map& target_resolution_map) { - if (descriptions.empty()) return {}; + const domain_decomposition& dom_dec) { - std::vector> local_src_sites(descriptions.size()); - std::vector>> local_dest_sites( - descriptions.size()); + std::vector src_sites, dest_sites; + + const auto& selection = get_network_selection_impl(description.selection); + const auto& weight = get_network_value_impl(description.weight); + const auto& delay = get_network_value_impl(description.delay); // populate network sites for source and destination for (const auto& group: dom_dec.groups()) { @@ -179,45 +159,47 @@ std::vector generate_network_connections( throw bad_cell_description(rec.get_cell_kind(gid), gid); } + auto lid_to_label = [](const std::unordered_multimap& map, + cell_lid_type lid) -> const cell_tag_type& { + for (const auto& [label, range]: map) { + if (lid >= range.begin && lid < range.end) return label; + } + throw arbor_internal_error("unkown lid"); + }; + place_pwlin location_resolver(cell.morphology()); // check all synapses of cell for potential destination - for (const auto& [name, placed_synapses]: cell.synapses()) { + + for (const auto& [_, placed_synapses]: cell.synapses()) { for (const auto& p_syn: placed_synapses) { - // TODO: compute rotation and global offset - const mpoint point = location_resolver.at(p_syn.loc); - network_location location = {point.x, point.y, point.z}; // TODO check if tag correct - const auto& tag = target_resolution_map.tag_at(gid, p_syn.lid); - - for (std::size_t i = 0; i < descriptions.size(); ++i) { - const auto& desc = descriptions[i]; - if (desc.dest_selection( - gid, cell_kind::cable, tag, p_syn.loc, location)) { - // TODO : compute hash - network_hash_type hash = 0; - local_dest_sites[i].push_back({location, {gid, p_syn.lid, hash}}); - } + const auto& label = lid_to_label(cell.synapse_ranges(), p_syn.lid); + + if (selection.select_destination(cell_kind::cable, gid, label)) { + // TODO: compute rotation and global offset + const mpoint point = location_resolver.at(p_syn.loc); + network_location global_location = {point.x, point.y, point.z}; + dest_sites.emplace_back(gid, + p_syn.lid, + cell_kind::cable, + label, + p_syn.loc, + global_location); } } } // check all detectors of cell for potential source for (const auto& p_det: cell.detectors()) { - // TODO: compute rotation and global offset - const mpoint point = location_resolver.at(p_det.loc); - network_location location = {point.x, point.y, point.z}; // TODO check if tag correct - const auto& tag = target_resolution_map.tag_at(gid, p_det.lid); - - for (std::size_t i = 0; i < descriptions.size(); ++i) { - const auto& desc = descriptions[i]; - if (desc.src_selection(gid, cell_kind::cable, tag, p_det.loc, location)) { - // TODO : compute hash - network_hash_type hash = 0; - local_src_sites[i].push_back( - {gid, p_det.lid, location[0], location[1], location[2], hash}); - } + const auto& label = lid_to_label(cell.synapse_ranges(), p_det.lid); + if (selection.select_destination(cell_kind::cable, gid, label)) { + // TODO: compute rotation and global offset + const mpoint point = location_resolver.at(p_det.loc); + network_location global_location = {point.x, point.y, point.z}; + src_sites.emplace_back( + gid, p_det.lid, cell_kind::cable, label, p_det.loc, global_location); } } } @@ -237,49 +219,42 @@ std::vector generate_network_connections( } } - // create octrees - std::vector> local_dest_trees; - local_dest_trees.reserve(descriptions.size()); - for (std::size_t i = 0; i < descriptions.size(); ++i) { - const auto& desc = descriptions[i]; - const std::size_t max_depth = desc.connection_selection.max_distance().has_value() ? 10 : 1; - local_dest_trees.emplace_back(max_depth, 100, std::move(local_dest_sites[i])); - } + // create octree + const std::size_t max_depth = selection.max_distance().has_value() ? 10 : 1; + const std::size_t max_leaf_size = 100; + spatial_tree local_dest_tree( + max_depth, max_leaf_size, std::move(dest_sites), [](const network_site_info& info) { + return info.global_location; + }); // select connections std::vector connections; - for (std::size_t i = 0; i < descriptions.size(); ++i) { - const auto& desc = descriptions[i]; - const auto& src_sites = local_src_sites[i]; - const auto& dest_tree = local_dest_trees[i]; - - for (const auto& src: src_sites) { - auto sample_dest = [&](const network_location& dest_loc, const dest_site_info& dest) { - // TODO precompute distance - if (desc.connection_selection( - src.gid, {src.x, src.y, src.z}, src.hash, dest.gid, dest_loc, dest.hash)) { - const double w = desc.weight( - src.gid, {src.x, src.y, src.z}, src.hash, dest.gid, dest_loc, dest.hash); - const double d = desc.delay( - src.gid, {src.x, src.y, src.z}, src.hash, dest.gid, dest_loc, dest.hash); - - connections.emplace_back(cell_member_type{src.gid, src.lid}, - cell_member_type{dest.gid, dest.lid}, - w, - d); - } - }; - - if(desc.connection_selection.max_distance().has_value()) { - const double d = desc.connection_selection.max_distance().value(); - dest_tree.bounding_box_for_each(network_location{src.x - d, src.y - d, src.z - d}, - network_location{src.x + d, src.y + d, src.z + d}, - sample_dest); + auto sample_destinations = [&] (const network_site_info& src) { + auto sample = [&] (const network_site_info& dest) { + if(selection.select_connection(src, dest)) { + connections.emplace_back(connection({src.gid, src.lid}, + {dest.gid, dest.lid}, + weight.get(src, dest), + delay.get(src, dest))); } - else { dest_tree.for_each(sample_dest); } + }; + + if(selection.max_distance().has_value()) { + const double d = selection.max_distance().value(); + local_dest_tree.bounding_box_for_each(network_location{src.global_location[0] - d, + src.global_location[1] - d, + src.global_location[2] - d}, + network_location{src.global_location[0] + d, + src.global_location[1] + d, + src.global_location[2] + d}, + sample); + } else { + local_dest_tree.for_each(sample); } - } + }; + + distributed_for_each_site(distributed, src_sites, sample_destinations); return connections; } diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp index 0024d72217..73fb7a81a8 100644 --- a/arbor/network_impl.hpp +++ b/arbor/network_impl.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -25,13 +26,27 @@ struct network_selection_impl { virtual bool select_source(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const = 0; + const std::string_view& tag) const = 0; virtual bool select_destination(cell_kind kind, cell_gid_type gid, - const cell_tag_type& tag) const = 0; + const std::string_view& tag) const = 0; virtual ~network_selection_impl() = default; }; +inline const network_selection_impl& get_network_selection_impl(const network_selection& s) { + return *(s.impl_); +} + +struct network_value_impl { + virtual double get(const network_site_info& src, const network_site_info& dest) const = 0; + + virtual ~network_value_impl() = default; +}; + +inline const network_value_impl& get_network_value_impl(const network_value& v) { + return *(v.impl_); +} + } // namespace arb diff --git a/arbor/util/spatial_tree.hpp b/arbor/util/spatial_tree.hpp index 6a2f29ac88..1fb5750d73 100644 --- a/arbor/util/spatial_tree.hpp +++ b/arbor/util/spatial_tree.hpp @@ -22,29 +22,35 @@ class spatial_tree { using value_type = T; using point_type = std::array; using node_data = std::vector; - using leaf_data = std::vector>; + using leaf_data = std::vector; + using location_func_type = point_type (*)(const T &); spatial_tree(): size_(0), data_(leaf_data()) {} // Create a tree of given maximum depth and target leaf size. If any leaf holds more than the // target size, it is recursively split into up to 2^DIM nodes until reaching the maximum depth. - spatial_tree(std::size_t max_depth, std::size_t leaf_size_target, leaf_data data): + spatial_tree(std::size_t max_depth, + std::size_t leaf_size_target, + leaf_data data, + location_func_type location): size_(data.size()), - data_(std::move(data)) { + data_(std::move(data)), + location_(location) { auto &leaf_d = std::get(data_); if (leaf_d.empty()) return; min_.fill(std::numeric_limits::max()); max_.fill(-std::numeric_limits::max()); - for (const auto &[p, _]: leaf_d) { + for (const auto &d: leaf_d) { + const auto p = location(d); for (std::size_t i = 0; i < DIM; ++i) { if (p[i] < min_[i]) min_[i] = p[i]; if (p[i] > max_[i]) max_[i] = p[i]; } } - value_type mid; + point_type mid; for (std::size_t i = 0; i < DIM; ++i) { mid[i] = (max_[i] - min_[i]) / 2.0 + min_[i]; } if (max_depth > 1 && leaf_d.size() > leaf_size_target) { @@ -62,14 +68,15 @@ class spatial_tree { // assign each point to sub-node std::array new_leaf_data; - for (const auto &[p, d]: leaf_d) { - new_leaf_data[sub_node_index(p)].emplace_back(p, d); + for (const auto &d: leaf_d) { + const auto p = location(d); + new_leaf_data[sub_node_index(p)].emplace_back(d); } // move data into new sub-nodes if not empty for (auto &l_d: new_leaf_data) { if (l_d.size()) - new_nodes.emplace_back(max_depth - 1, leaf_size_target, std::move(l_d)); + new_nodes.emplace_back(max_depth - 1, leaf_size_target, std::move(l_d), location); } // replace current data_ with new sub-nodes @@ -98,7 +105,7 @@ class spatial_tree { } // Iterate over all points recursively. - // func must have signature `void func(const point_type&, const T&)`. + // func must have signature `void func(const T&)`. template inline void for_each(const F &func) const { std::visit( @@ -108,14 +115,14 @@ class spatial_tree { for (const auto &node: arg) { node.for_each(func); } } if constexpr (std::is_same_v) { - for (const auto &[p, d]: arg) { func(p, d); } + for (const auto &d: arg) { func(d); } } }, data_); } // Iterate over all points within the given bounding box recursively. - // func must have signature `void func(const point_type&, const T&)`. + // func must have signature `void func(const T&)`. template inline void bounding_box_for_each(const point_type &box_min, const point_type &box_max, @@ -137,7 +144,7 @@ class spatial_tree { for (const auto &node: arg) { node.template for_each(func); } } if constexpr (std::is_same_v) { - for (const auto &[p, d]: arg) { func(p, d); } + for (const auto &d: arg) { func(d); } } } else { @@ -150,9 +157,10 @@ class spatial_tree { } } if constexpr (std::is_same_v) { - for (const auto &[p, d]: arg) { + for (const auto &d: arg) { + const auto p = location_(d); if (all_smaller_eq(p, box_max) && all_smaller_eq(box_min, p)) { - func(p, d); + func(d); } } } @@ -169,6 +177,7 @@ class spatial_tree { std::size_t size_; point_type min_, max_; std::variant data_; + location_func_type location_; }; } // namespace arb From f536c13cccfa4f6ac2ecd8c6cd2bff242f4ba460 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 12 Mar 2023 10:15:31 +0100 Subject: [PATCH 08/84] implemented construction --- arbor/include/arbor/network.hpp | 15 ++-- arbor/network.cpp | 121 ++++++++++++++++++++++++++++++-- arbor/util/spatial_tree.hpp | 1 + 3 files changed, 126 insertions(+), 11 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index ead9c748e1..f0b9938fa9 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -106,14 +106,16 @@ class ARB_SYMBOL_VISIBLE network_selection { network_selection operator^(network_selection right) const; private: + network_selection(std::shared_ptr impl); + friend const network_selection_impl& get_network_selection_impl(const network_selection& s); std::shared_ptr impl_; }; class ARB_SYMBOL_VISIBLE network_value { public: - // Uniform value - network_value(double value); + // Uniform value with conversion from double + network_value(double value) { *this = network_value::uniform(value); } // Uniform value. Will always return the same value given at construction. static network_value uniform(double value); @@ -141,13 +143,12 @@ class ARB_SYMBOL_VISIBLE network_value { // Custom value using the provided function "func". Repeated calls with the same arguments // to "func" must yield the same result. For gap junction values, // "func" must be symmetric (func(a,b) = func(b,a)). - static network_value custom(std::function func); + static network_value custom( + std::function func); private: + network_value(std::shared_ptr impl); + friend const network_value_impl& get_network_value_impl(const network_value& v); std::shared_ptr impl_; }; diff --git a/arbor/network.cpp b/arbor/network.cpp index dad5f5a495..b1c9364f22 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -160,9 +160,9 @@ struct network_selection_destination_cell_kind_impl: public network_selection_im }; struct network_selection_source_label_impl: public network_selection_impl { - std::vector sorted_labels; + std::vector sorted_labels; - explicit network_selection_source_label_impl(std::vector labels): + explicit network_selection_source_label_impl(std::vector labels): sorted_labels(std::move(labels)) { std::sort(sorted_labels.begin(), sorted_labels.end()); } @@ -186,9 +186,9 @@ struct network_selection_source_label_impl: public network_selection_impl { }; struct network_selection_destination_label_impl: public network_selection_impl { - std::vector sorted_labels; + std::vector sorted_labels; - explicit network_selection_destination_label_impl(std::vector labels): + explicit network_selection_destination_label_impl(std::vector labels): sorted_labels(std::move(labels)) { std::sort(sorted_labels.begin(), sorted_labels.end()); } @@ -695,4 +695,117 @@ network_site_info::network_site_info(cell_gid_type gid, hash = gen(seed_input, key)[0]; } +network_selection::network_selection(std::shared_ptr impl): + impl_(std::move(impl)) {} + +network_selection network_selection::operator&(network_selection right) const { + return network_selection( + std::make_shared(this->impl_, std::move(right.impl_))); +} + +network_selection network_selection::operator|(network_selection right) const { + return network_selection( + std::make_shared(this->impl_, std::move(right.impl_))); +} + +network_selection network_selection::operator^(network_selection right) const { + return network_selection( + std::make_shared(this->impl_, std::move(right.impl_))); +} + +network_selection network_selection::all() { + return network_selection(std::make_shared()); +} + +network_selection network_selection::none() { + return network_selection(std::make_shared()); +} + +network_selection network_selection::source_cell_kind(cell_kind kind) { + return network_selection(std::make_shared(kind)); +} + +network_selection network_selection::destination_cell_kind(cell_kind kind) { + return network_selection(std::make_shared(kind)); +} + +network_selection network_selection::source_label(std::vector labels) { + return network_selection(std::make_shared(std::move(labels))); +} + +network_selection network_selection::destination_label(std::vector labels) { + return network_selection(std::make_shared(std::move(labels))); +} + +network_selection network_selection::source_gid(std::vector gids) { + return network_selection(std::make_shared(std::move(gids))); +} + +network_selection network_selection::destination_gid(std::vector gids) { + return network_selection(std::make_shared(std::move(gids))); +} + +network_selection network_selection::invert(network_selection s) { + return network_selection(std::make_shared(std::move(s.impl_))); +} + +network_selection network_selection::inter_cell() { + return network_selection(std::make_shared()); +} + +network_selection network_selection::not_equal() { + return network_selection(std::make_shared()); +} + +network_selection network_selection::bernoulli_random(unsigned seed, double p) { + return network_selection(std::make_shared(seed, p)); +} + +network_selection network_selection::custom( + std::function func) { + return network_selection(std::make_shared(std::move(func))); +} + +network_selection network_selection::within_distance(double distance) { + return network_selection(std::make_shared(distance)); +} + +network_selection network_selection::linear_bernoulli_random(unsigned seed, + double distance_begin, + double p_begin, + double distance_end, + double p_end) { + return network_selection(std::make_shared( + seed, distance_begin, p_begin, distance_end, p_end)); +} + +network_value::network_value(std::shared_ptr impl): impl_(std::move(impl)) {} + +network_value network_value::uniform(double value) { + return network_value(std::make_shared(value)); +} + +network_value network_value::uniform_distribution(unsigned seed, + const std::array& range) { + return network_value(std::make_shared(seed, range)); +} + +network_value network_value::normal_distribution(unsigned seed, double mean, double std_deviation) { + return network_value( + std::make_shared(seed, mean, std_deviation)); +} + +network_value network_value::truncated_normal_distribution(unsigned seed, + double mean, + double std_deviation, + const std::array& range) { + return network_value(std::make_shared( + seed, mean, std_deviation, range)); +} + +network_value network_value::custom( + std::function func) { + return network_value(std::make_shared(std::move(func))); +} + } // namespace arb diff --git a/arbor/util/spatial_tree.hpp b/arbor/util/spatial_tree.hpp index 1fb5750d73..81e76e0449 100644 --- a/arbor/util/spatial_tree.hpp +++ b/arbor/util/spatial_tree.hpp @@ -29,6 +29,7 @@ class spatial_tree { // Create a tree of given maximum depth and target leaf size. If any leaf holds more than the // target size, it is recursively split into up to 2^DIM nodes until reaching the maximum depth. + // The "location" function type must have signature (const T&) -> point_type. spatial_tree(std::size_t max_depth, std::size_t leaf_size_target, leaf_data data, From ab02df6e7a9d2edaa02d949180512bac26547e60 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 12 Mar 2023 17:51:37 +0100 Subject: [PATCH 09/84] support for named selection / value --- arbor/include/arbor/network.hpp | 50 ++++++++++-- arbor/network.cpp | 137 +++++++++++++++++++++++++++++--- arbor/network_generation.cpp | 10 ++- arbor/network_impl.hpp | 21 +++-- 4 files changed, 192 insertions(+), 26 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index f0b9938fa9..9be75eaa30 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace arb { @@ -44,9 +45,13 @@ struct network_selection_impl; struct network_value_impl; +class ARB_SYMBOL_VISIBLE network_label_dict; class ARB_SYMBOL_VISIBLE network_selection { public: + using custom_func_type = + std::function; + network_selection() { *this = network_selection::all(); } // Select all @@ -55,6 +60,8 @@ class ARB_SYMBOL_VISIBLE network_selection { // Select none static network_selection none(); + static network_selection named(std::string name); + static network_selection source_cell_kind(cell_kind kind); static network_selection destination_cell_kind(cell_kind kind); @@ -84,8 +91,7 @@ class ARB_SYMBOL_VISIBLE network_selection { // Custom selection using the provided function "func". Repeated calls with the same arguments // to "func" must yield the same result. For gap junction selection, // "func" must be symmetric (func(a,b) = func(b,a)). - static network_selection custom( - std::function func); + static network_selection custom(custom_func_type func); // only select within given distance. This may enable more efficient sampling through an // internal spatial data structure. @@ -108,18 +114,25 @@ class ARB_SYMBOL_VISIBLE network_selection { private: network_selection(std::shared_ptr impl); - friend const network_selection_impl& get_network_selection_impl(const network_selection& s); + friend std::shared_ptr thingify(network_selection s, + const network_label_dict& dict); + std::shared_ptr impl_; }; class ARB_SYMBOL_VISIBLE network_value { public: + using custom_func_type = + std::function; + // Uniform value with conversion from double network_value(double value) { *this = network_value::uniform(value); } // Uniform value. Will always return the same value given at construction. static network_value uniform(double value); + static network_value named(std::string name); + // Uniform random value in (range[0], range[1]]. Always returns the same value for repeated // calls with the same arguments and calls are symmetric v(a, b) = v(b, a). static network_value uniform_distribution(unsigned seed, const std::array& range); @@ -143,20 +156,45 @@ class ARB_SYMBOL_VISIBLE network_value { // Custom value using the provided function "func". Repeated calls with the same arguments // to "func" must yield the same result. For gap junction values, // "func" must be symmetric (func(a,b) = func(b,a)). - static network_value custom( - std::function func); + static network_value custom(custom_func_type func); private: network_value(std::shared_ptr impl); - friend const network_value_impl& get_network_value_impl(const network_value& v); + friend std::shared_ptr thingify(network_value v, + const network_label_dict& dict); + std::shared_ptr impl_; }; +class ARB_SYMBOL_VISIBLE network_label_dict { +public: + using ns_map = std::unordered_map; + using nv_map = std::unordered_map; + + network_label_dict& set(const std::string& name, network_selection s); + + network_label_dict& set(const std::string& name, network_value v); + + std::optional selection(const std::string& name) const; + + std::optional value(const std::string& name) const; + + inline const ns_map& selections() const { return selections_; } + + inline const nv_map& values() const { return values_; } + +private: + ns_map selections_; + nv_map values_; +}; + + struct network_description { network_selection selection; network_value weight; network_value delay; + network_label_dict dict; }; } // namespace arb diff --git a/arbor/network.cpp b/arbor/network.cpp index b1c9364f22..2915592248 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -9,7 +9,9 @@ #include #include #include +#include #include +#include #include #include "network_impl.hpp" @@ -287,6 +289,52 @@ struct network_selection_invert_impl: public network_selection_impl { return true; // cannot exclude any because destination selection cannot be inverted without // knowing selection criteria. } + + void initialize(const network_label_dict& dict) override { + selection->initialize(dict); + }; +}; + + +struct network_selection_named_impl: public network_selection_impl { + using impl_pointer_type =std::shared_ptr; + + std::variant selection; + + explicit network_selection_named_impl(std::string name): selection(std::move(name)) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + if(!std::holds_alternative(selection)) + throw arbor_internal_error("Trying to use unitialized named network selection."); + return std::get(selection)->select_connection(src, dest); + } + + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + if(!std::holds_alternative(selection)) + throw arbor_internal_error("Trying to use unitialized named network selection."); + return std::get(selection)->select_source(kind, gid, label); + } + + bool select_destination(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + if(!std::holds_alternative(selection)) + throw arbor_internal_error("Trying to use unitialized named network selection."); + return std::get(selection)->select_destination(kind, gid, label); + } + + void initialize(const network_label_dict& dict) override { + if(std::holds_alternative(selection)) { + auto s = dict.selection(std::get(selection)); + if (!s.has_value()) + throw arbor_exception(std::string("Network selection with label \"") + + std::get(selection) + "\" not found."); + selection = thingify(s.value(), dict); + } + }; }; struct network_selection_inter_cell_impl: public network_selection_impl { @@ -328,10 +376,9 @@ struct network_selection_not_equal_impl: public network_selection_impl { }; struct network_selection_custom_impl: public network_selection_impl { - std::function func; + network_selection::custom_func_type func; - explicit network_selection_custom_impl( - std::function f): + explicit network_selection_custom_impl(network_selection::custom_func_type f): func(std::move(f)) {} bool select_connection(const network_site_info& src, @@ -493,6 +540,11 @@ struct network_selection_and_impl: public network_selection_impl { return std::nullopt; } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; }; struct network_selection_or_impl: public network_selection_impl { @@ -529,6 +581,11 @@ struct network_selection_or_impl: public network_selection_impl { return std::nullopt; } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; }; struct network_selection_xor_impl: public network_selection_impl { @@ -564,6 +621,11 @@ struct network_selection_xor_impl: public network_selection_impl { return std::nullopt; } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; }; @@ -655,17 +717,40 @@ struct network_value_truncated_normal_distribution_impl: public network_value_im }; struct network_value_custom_impl: public network_value_impl { - std::function func; + network_value::custom_func_type func; - network_value_custom_impl( - std::function f): - func(std::move(f)) {} + network_value_custom_impl(network_value::custom_func_type f): func(std::move(f)) {} - inline double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_site_info& src, const network_site_info& dest) const override { return func(src, dest); } }; + +struct network_value_named_impl: public network_value_impl { + using impl_pointer_type =std::shared_ptr; + + std::variant value; + + explicit network_value_named_impl(std::string name): value(std::move(name)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + if(!std::holds_alternative(value)) + throw arbor_internal_error("Trying to use unitialized named network value."); + return std::get(value)->get(src, dest); + } + + void initialize(const network_label_dict& dict) override { + if(std::holds_alternative(value)) { + auto s = dict.value(std::get(value)); + if (!s.has_value()) + throw arbor_exception(std::string("Network value with label \"") + + std::get(value) + "\" not found."); + value = thingify(s.value(), dict); + } + }; +}; + } // namespace network_site_info::network_site_info(cell_gid_type gid, @@ -721,6 +806,10 @@ network_selection network_selection::none() { return network_selection(std::make_shared()); } +network_selection network_selection::named(std::string name) { + return network_selection(std::make_shared(std::move(name))); +} + network_selection network_selection::source_cell_kind(cell_kind kind) { return network_selection(std::make_shared(kind)); } @@ -761,8 +850,7 @@ network_selection network_selection::bernoulli_random(unsigned seed, double p) { return network_selection(std::make_shared(seed, p)); } -network_selection network_selection::custom( - std::function func) { +network_selection network_selection::custom(custom_func_type func) { return network_selection(std::make_shared(std::move(func))); } @@ -803,9 +891,34 @@ network_value network_value::truncated_normal_distribution(unsigned seed, seed, mean, std_deviation, range)); } -network_value network_value::custom( - std::function func) { +network_value network_value::custom(custom_func_type func) { return network_value(std::make_shared(std::move(func))); } +network_value network_value::named(std::string name) { + return network_value(std::make_shared(std::move(name))); +} + +network_label_dict& network_label_dict::set(const std::string& name, network_selection s) { + selections_.insert_or_assign(name, std::move(s)); +} + +network_label_dict& network_label_dict::set(const std::string& name, network_value v) { + values_.insert_or_assign(name, std::move(v)); +} + +std::optional network_label_dict::selection(const std::string& name) const { + auto it = selections_.find(name); + if(it != selections_.end()) return it->second; + + return std::nullopt; +} + +std::optional network_label_dict::value(const std::string& name) const { + auto it = values_.find(name); + if(it != values_.end()) return it->second; + + return std::nullopt; +} + } // namespace arb diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp index 7246910840..5e8545e8cd 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_generation.cpp @@ -142,9 +142,13 @@ std::vector generate_network_connections(const network_description& std::vector src_sites, dest_sites; - const auto& selection = get_network_selection_impl(description.selection); - const auto& weight = get_network_value_impl(description.weight); - const auto& delay = get_network_value_impl(description.delay); + const auto& selection_ptr = thingify(description.selection, description.dict); + const auto& weight_ptr = thingify(description.weight, description.dict); + const auto& delay_ptr = thingify(description.delay, description.dict); + + const auto& selection = *selection_ptr; + const auto& weight = *weight_ptr; + const auto& delay = *delay_ptr; // populate network sites for source and destination for (const auto& group: dom_dec.groups()) { diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp index 73fb7a81a8..71c2e32a64 100644 --- a/arbor/network_impl.hpp +++ b/arbor/network_impl.hpp @@ -1,8 +1,10 @@ #pragma once +#include #include #include +#include #include #include #include @@ -32,21 +34,30 @@ struct network_selection_impl { cell_gid_type gid, const std::string_view& tag) const = 0; + virtual void initialize(const network_label_dict& dict) {}; + virtual ~network_selection_impl() = default; }; -inline const network_selection_impl& get_network_selection_impl(const network_selection& s) { - return *(s.impl_); +inline std::shared_ptr thingify(network_selection s, + const network_label_dict& dict) { + s.impl_->initialize(dict); + return s.impl_; } + struct network_value_impl { virtual double get(const network_site_info& src, const network_site_info& dest) const = 0; + virtual void initialize(const network_label_dict& dict) {}; + virtual ~network_value_impl() = default; }; -inline const network_value_impl& get_network_value_impl(const network_value& v) { - return *(v.impl_); +inline std::shared_ptr thingify(network_value v, + const network_label_dict& dict) { + v.impl_->initialize(dict); + return v.impl_; } -} // namespace arb +} // namespace arb From d1f1e41165fb11a0230d54a1a80a1924a2d20165 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 14 Mar 2023 09:04:41 +0100 Subject: [PATCH 10/84] add label parsing (incomplete) --- arbor/include/arbor/network.hpp | 12 +-- arbor/include/arbor/recipe.hpp | 5 +- arbor/network.cpp | 8 +- arbor/network_generation.cpp | 13 ++-- arbor/network_generation.hpp | 8 +- arborio/include/arborio/label_parse.hpp | 28 ++++++- arborio/label_parse.cpp | 99 +++++++++++++++++++++---- 7 files changed, 135 insertions(+), 38 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 9be75eaa30..07fed0a40d 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -52,7 +52,7 @@ class ARB_SYMBOL_VISIBLE network_selection { using custom_func_type = std::function; - network_selection() { *this = network_selection::all(); } + network_selection() { *this = network_selection::none(); } // Select all static network_selection all(); @@ -125,11 +125,13 @@ class ARB_SYMBOL_VISIBLE network_value { using custom_func_type = std::function; - // Uniform value with conversion from double - network_value(double value) { *this = network_value::uniform(value); } + network_value() { *this = network_value::scalar(0.0); } - // Uniform value. Will always return the same value given at construction. - static network_value uniform(double value); + // Scalar value with conversion from double + network_value(double value) { *this = network_value::scalar(value); } + + // Scalar value. Will always return the same value given at construction. + static network_value scalar(double value); static network_value named(std::string name); diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index 15ac14cd7f..1a63bd9e16 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -106,7 +107,9 @@ struct ARB_ARBOR_API recipe: public has_gap_junctions, has_probes, connectivity // Global property type will be specific to given cell kind. virtual std::any get_global_properties(cell_kind) const { return std::any{}; }; // Optional network descriptions for generating cell connections - virtual std::vector network_descriptions() const { return {}; }; + virtual std::optional network_description() const { + return std::nullopt; + }; virtual ~recipe() {} }; diff --git a/arbor/network.cpp b/arbor/network.cpp index 2915592248..5e4897791a 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -629,10 +629,10 @@ struct network_selection_xor_impl: public network_selection_impl { }; -struct network_value_uniform_impl : public network_value_impl{ +struct network_value_scalar_impl : public network_value_impl{ double value; - network_value_uniform_impl(double v): value(v) {} + network_value_scalar_impl(double v): value(v) {} double get(const network_site_info& src, const network_site_info& dest) const override { return value; @@ -869,8 +869,8 @@ network_selection network_selection::linear_bernoulli_random(unsigned seed, network_value::network_value(std::shared_ptr impl): impl_(std::move(impl)) {} -network_value network_value::uniform(double value) { - return network_value(std::make_shared(value)); +network_value network_value::scalar(double value) { + return network_value(std::make_shared(value)); } network_value network_value::uniform_distribution(unsigned seed, diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp index 5e8545e8cd..ac9cc3d3c5 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_generation.cpp @@ -135,16 +135,19 @@ void distributed_for_each_site(const distributed_context& distributed, } // namespace -std::vector generate_network_connections(const network_description& description, - const recipe& rec, +std::vector generate_network_connections(const recipe& rec, const distributed_context& distributed, const domain_decomposition& dom_dec) { + const auto description_opt = rec.network_description(); + if(!description_opt.has_value()) return {}; + + const auto& description = description_opt.value(); std::vector src_sites, dest_sites; - const auto& selection_ptr = thingify(description.selection, description.dict); - const auto& weight_ptr = thingify(description.weight, description.dict); - const auto& delay_ptr = thingify(description.delay, description.dict); + const auto selection_ptr = thingify(description.selection, description.dict); + const auto weight_ptr = thingify(description.weight, description.dict); + const auto delay_ptr = thingify(description.delay, description.dict); const auto& selection = *selection_ptr; const auto& weight = *weight_ptr; diff --git a/arbor/network_generation.hpp b/arbor/network_generation.hpp index 37a9fa56f8..b19c97c3fe 100644 --- a/arbor/network_generation.hpp +++ b/arbor/network_generation.hpp @@ -12,12 +12,8 @@ namespace arb { -std::vector generate_network_connections( - const std::vector& descriptions, - const connectivity& rec, +std::vector generate_network_connections(const recipe& rec, const distributed_context& distributed, - const domain_decomposition& dom_dec, - const label_resolution_map& source_resolution_map, - const label_resolution_map& target_resolution_map); + const domain_decomposition& dom_dec); } // namespace arb diff --git a/arborio/include/arborio/label_parse.hpp b/arborio/include/arborio/label_parse.hpp index d937cad149..88895401e7 100644 --- a/arborio/include/arborio/label_parse.hpp +++ b/arborio/include/arborio/label_parse.hpp @@ -4,10 +4,11 @@ #include #include -#include +#include #include +#include +#include #include -#include #include #include @@ -29,6 +30,10 @@ ARB_ARBORIO_API parse_label_hopefully parse_region_expression(const ARB_ARBORIO_API parse_label_hopefully parse_locset_expression(const std::string& s); ARB_ARBORIO_API parse_label_hopefully parse_iexpr_expression(const std::string& s); +ARB_ARBORIO_API parse_label_hopefully parse_network_selection_expression(const std::string& s); +ARB_ARBORIO_API parse_label_hopefully parse_network_value_expression( + const std::string& s); + namespace literals { struct morph_from_string { @@ -70,7 +75,22 @@ arb::region operator "" _reg(const char* s, std::size_t) { else throw r.error(); } -inline morph_from_string operator "" _morph(const char* s, std::size_t) { return {s}; } -inline morph_from_label operator "" _lab(const char* s, std::size_t) { return {s}; } +inline morph_from_string operator"" _morph(const char* s, std::size_t) { return {s}; } +inline morph_from_label operator"" _lab(const char* s, std::size_t) { return {s}; } + +inline arb::network_selection operator"" _ns(const char* s, std::size_t) { + if (auto r = parse_network_selection_expression(s)) + return *r; + else + throw r.error(); +} + +inline arb::network_value operator"" _nv(const char* s, std::size_t) { + if (auto r = parse_network_value_expression(s)) + return *r; + else + throw r.error(); +} + } // namespace literals } // namespace arborio diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index 22df96481c..2b8e000a99 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -4,9 +4,10 @@ #include #include +#include #include #include -#include +#include #include @@ -20,8 +21,9 @@ label_parse_error::label_parse_error(const std::string& msg, const arb::src_loca namespace { +using eval_map_type= std::unordered_multimap; -std::unordered_multimap eval_map { +eval_map_type eval_map { // Functions that return regions {"region-nil", make_call<>(arb::reg::nil, "'region-nil' with 0 arguments")}, @@ -184,13 +186,55 @@ std::unordered_multimap eval_map { {"div", make_conversion_fold(arb::iexpr::div, "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | double)])")}, }; -parse_label_hopefully eval(const s_expr& e); +eval_map_type network_eval_map{ + // network_selection + {"all", make_call<>(arb::network_selection::all, "network selection of all cells and labels")}, + {"none", make_call<>(arb::network_selection::none, "network selection of no cells and labels")}, + {"inter-cell", + make_call<>(arb::network_selection::inter_cell, + "network selection of inter-cell connections only")}, + {"network-selection", + make_call(arb::network_selection::named, + "network selection with 1 argument: (value:string)")}, + {"and", + make_fold( + [](arb::network_selection left, arb::network_selection right) { + return std::move(left) & std::move(right); + }, + "logical \"and\" operation of network selections with at least 2 arguments: " + "(network_selection network_selection [...network_selection])")}, + {"or", + make_fold( + [](arb::network_selection left, arb::network_selection right) { + return std::move(left) | std::move(right); + }, + "logical \"or\" operation of network selections with at least 2 arguments: " + "(network_selection network_selection [...network_selection])")}, + {"xor", + make_fold( + [](arb::network_selection left, arb::network_selection right) { + return std::move(left) ^ std::move(right); + }, + "logical \"xor\" operation of network selections with at least 2 arguments: " + "(network_selection network_selection [...network_selection])")}, + + // network_value + {"scalar", + make_call(arb::network_value::scalar, + "network value with 1 argument: (value:double)")}, + {"network-value", + make_call(arb::network_value::named, + "network value with 1 argument: (value:string)")}, + +}; -parse_label_hopefully> eval_args(const s_expr& e) { +parse_label_hopefully eval(const s_expr& e, const eval_map_type& map); + +parse_label_hopefully> eval_args(const s_expr& e, const eval_map_type& map) { if (!e) return {std::vector{}}; // empty argument list std::vector args; for (auto& h: e) { - if (auto arg=eval(h)) { + if (auto arg=eval(h, map)) { args.push_back(std::move(*arg)); } else { @@ -242,7 +286,7 @@ std::string eval_description(const char* name, const std::vector& args // a label_error_state with an error string and location. // // If there was an unexpected/fatal error, an exception will be thrown. -parse_label_hopefully eval(const s_expr& e) { +parse_label_hopefully eval(const s_expr& e, const eval_map_type& map) { if (e.is_atom()) { return eval_atom(e); } @@ -251,14 +295,14 @@ parse_label_hopefully eval(const s_expr& e) { // tail is a list of arguments. // Evaluate the arguments, and return error state if an error occurred. - auto args = eval_args(e.tail()); + auto args = eval_args(e.tail(), map); if (!args) { return util::unexpected(args.error()); } // Find all candidate functions that match the name of the function. auto& name = e.head().atom().spelling; - auto matches = eval_map.equal_range(name); + auto matches = map.equal_range(name); // Search for a candidate that matches the argument list. for (auto i=matches.first; i!=matches.second; ++i) { if (i->second.match_args(*args)) { // found a match: evaluate and return. @@ -284,14 +328,14 @@ parse_label_hopefully eval(const s_expr& e) { } // namespace ARB_ARBORIO_API parse_label_hopefully parse_label_expression(const std::string& e) { - return eval(parse_s_expr(e)); + return eval(parse_s_expr(e), eval_map); } ARB_ARBORIO_API parse_label_hopefully parse_label_expression(const s_expr& s) { - return eval(s); + return eval(s, eval_map); } ARB_ARBORIO_API parse_label_hopefully parse_region_expression(const std::string& s) { - if (auto e = eval(parse_s_expr(s))) { + if (auto e = eval(parse_s_expr(s), eval_map)) { if (e->type() == typeid(region)) { return {std::move(std::any_cast(*e))}; } @@ -308,7 +352,7 @@ ARB_ARBORIO_API parse_label_hopefully parse_region_expression(const } ARB_ARBORIO_API parse_label_hopefully parse_locset_expression(const std::string& s) { - if (auto e = eval(parse_s_expr(s))) { + if (auto e = eval(parse_s_expr(s), eval_map)) { if (e->type() == typeid(locset)) { return {std::move(std::any_cast(*e))}; } @@ -325,7 +369,7 @@ ARB_ARBORIO_API parse_label_hopefully parse_locset_expression(const } parse_label_hopefully parse_iexpr_expression(const std::string& s) { - if (auto e = eval(parse_s_expr(s))) { + if (auto e = eval(parse_s_expr(s), eval_map)) { if (e->type() == typeid(iexpr)) { return {std::move(std::any_cast(*e))}; } @@ -338,4 +382,33 @@ parse_label_hopefully parse_iexpr_expression(const std::string& s) { } } + +parse_label_hopefully parse_network_selection_expression(const std::string& s) { + if (auto e = eval(parse_s_expr(s), network_eval_map)) { + if (e->type() == typeid(arb::network_selection)) { + return {std::move(std::any_cast(*e))}; + } + return util::unexpected( + label_parse_error( + concat("Invalid iexpr description: '", s))); + } + else { + return util::unexpected(label_parse_error(std::string()+e.error().what())); + } +} + +parse_label_hopefully parse_network_value_expression(const std::string& s) { + if (auto e = eval(parse_s_expr(s), network_eval_map)) { + if (e->type() == typeid(arb::network_value)) { + return {std::move(std::any_cast(*e))}; + } + return util::unexpected( + label_parse_error( + concat("Invalid iexpr description: '", s))); + } + else { + return util::unexpected(label_parse_error(std::string()+e.error().what())); + } +} + } // namespace arborio From 5bb32441ce0b33fac41e9f70fdfbda8b13e3a2db Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 17 Mar 2023 18:11:43 +0100 Subject: [PATCH 11/84] first python interface, fix string lifetime --- arbor/distributed_context.hpp | 13 +++- arbor/include/arbor/network.hpp | 18 +++-- arbor/include/arbor/recipe.hpp | 8 +- arbor/network.cpp | 95 +++++++++++++++++------ arbor/network_generation.cpp | 133 ++++++++++++++++++-------------- python/CMakeLists.txt | 1 + python/network.cpp | 107 +++++++++++++++++++++++++ python/pyarb.cpp | 2 + python/recipe.cpp | 2 + python/recipe.hpp | 15 +++- 10 files changed, 298 insertions(+), 96 deletions(-) create mode 100644 python/network.cpp diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index be64142cdc..3b6adfe120 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -35,7 +35,10 @@ namespace arb { struct distributed_request { inline void finalize() { - if (impl) impl->finalize(); + if (impl) { + impl->finalize(); + impl.reset(); + } } struct distributed_request_interface { @@ -44,6 +47,14 @@ struct distributed_request { virtual ~distributed_request_interface() = default; }; + ~distributed_request() { + try { + finalize(); + } + catch (...) { + } + } + std::unique_ptr impl; }; diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 07fed0a40d..86d5b3cc54 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -25,6 +25,8 @@ using network_location = std::array; using network_hash_type = std::uint64_t; struct ARB_SYMBOL_VISIBLE network_site_info { + network_site_info() = default; + network_site_info(cell_gid_type gid, cell_lid_type lid, cell_kind kind, @@ -74,8 +76,16 @@ class ARB_SYMBOL_VISIBLE network_selection { static network_selection destination_gid(std::vector gids); + static network_selection intersect(network_selection left, network_selection right); + + static network_selection join(network_selection left, network_selection right); + + static network_selection symmetric_difference(network_selection left, network_selection right); + + static network_selection difference(network_selection left, network_selection right); + // Invert the selection - static network_selection invert(network_selection s); + static network_selection complement(network_selection s); // Only select connections between different cells static network_selection inter_cell(); @@ -105,12 +115,6 @@ class ARB_SYMBOL_VISIBLE network_selection { double distance_end, double p_end); - network_selection operator&(network_selection right) const; - - network_selection operator|(network_selection right) const; - - network_selection operator^(network_selection right) const; - private: network_selection(std::shared_ptr impl); diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index 1a63bd9e16..c7c9aadd9d 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -67,6 +67,10 @@ struct ARB_ARBOR_API has_gap_junctions { virtual std::vector gap_junctions_on(cell_gid_type) const { return {}; } + // Optional network descriptions for generating cell connections + virtual std::optional network_description() const { + return std::nullopt; + }; virtual ~has_gap_junctions() {} }; @@ -106,10 +110,6 @@ struct ARB_ARBOR_API recipe: public has_gap_junctions, has_probes, connectivity virtual cell_kind get_cell_kind(cell_gid_type) const = 0; // Global property type will be specific to given cell kind. virtual std::any get_global_properties(cell_kind) const { return std::any{}; }; - // Optional network descriptions for generating cell connections - virtual std::optional network_description() const { - return std::nullopt; - }; virtual ~recipe() {} }; diff --git a/arbor/network.cpp b/arbor/network.cpp index 5e4897791a..1635ee2b3d 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -265,10 +265,10 @@ struct network_selection_destination_gid_impl: public network_selection_impl { } }; -struct network_selection_invert_impl: public network_selection_impl { +struct network_selection_complement_impl: public network_selection_impl { std::shared_ptr selection; - explicit network_selection_invert_impl(std::shared_ptr s): + explicit network_selection_complement_impl(std::shared_ptr s): selection(std::move(s)) {} bool select_connection(const network_site_info& src, @@ -279,14 +279,14 @@ struct network_selection_invert_impl: public network_selection_impl { bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return true; // cannot exclude any because source selection cannot be inverted without + return true; // cannot exclude any because source selection cannot be complemented without // knowing selection criteria. } bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return true; // cannot exclude any because destination selection cannot be inverted without + return true; // cannot exclude any because destination selection cannot be complemented without // knowing selection criteria. } @@ -504,10 +504,10 @@ struct network_selection_linear_bernoulli_random_impl: public network_selection_ std::optional max_distance() const override { return distance_end; } }; -struct network_selection_and_impl: public network_selection_impl { +struct network_selection_intersect_impl: public network_selection_impl { std::shared_ptr left, right; - network_selection_and_impl(std::shared_ptr l, + network_selection_intersect_impl(std::shared_ptr l, std::shared_ptr r): left(std::move(l)), right(std::move(r)) {} @@ -547,10 +547,10 @@ struct network_selection_and_impl: public network_selection_impl { }; }; -struct network_selection_or_impl: public network_selection_impl { +struct network_selection_join_impl: public network_selection_impl { std::shared_ptr left, right; - network_selection_or_impl(std::shared_ptr l, + network_selection_join_impl(std::shared_ptr l, std::shared_ptr r): left(std::move(l)), right(std::move(r)) {} @@ -588,10 +588,10 @@ struct network_selection_or_impl: public network_selection_impl { }; }; -struct network_selection_xor_impl: public network_selection_impl { +struct network_selection_symmetric_difference_impl: public network_selection_impl { std::shared_ptr left, right; - network_selection_xor_impl(std::shared_ptr l, + network_selection_symmetric_difference_impl(std::shared_ptr l, std::shared_ptr r): left(std::move(l)), right(std::move(r)) {} @@ -604,13 +604,14 @@ struct network_selection_xor_impl: public network_selection_impl { bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return true; + return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); } bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return true; + return left->select_destination(kind, gid, label) || + right->select_destination(kind, gid, label); } std::optional max_distance() const override { @@ -629,6 +630,46 @@ struct network_selection_xor_impl: public network_selection_impl { }; +struct network_selection_difference_impl: public network_selection_impl { + std::shared_ptr left, right; + + network_selection_difference_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return left->select_connection(src, dest) && !(right->select_connection(src, dest)); + } + + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return left->select_source(kind, gid, label); + } + + bool select_destination(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return left->select_destination(kind, gid, label); + } + + std::optional max_distance() const override { + const auto d_left = left->max_distance(); + + if (d_left) return d_left.value(); + + return std::nullopt; + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; +}; + + struct network_value_scalar_impl : public network_value_impl{ double value; @@ -783,19 +824,24 @@ network_site_info::network_site_info(cell_gid_type gid, network_selection::network_selection(std::shared_ptr impl): impl_(std::move(impl)) {} -network_selection network_selection::operator&(network_selection right) const { - return network_selection( - std::make_shared(this->impl_, std::move(right.impl_))); +network_selection network_selection::intersect(network_selection left, network_selection right) { + return network_selection(std::make_shared( + std::move(left.impl_), std::move(right.impl_))); } -network_selection network_selection::operator|(network_selection right) const { - return network_selection( - std::make_shared(this->impl_, std::move(right.impl_))); +network_selection network_selection::join(network_selection left, network_selection right) { + return network_selection(std::make_shared( + std::move(left.impl_), std::move(right.impl_))); } -network_selection network_selection::operator^(network_selection right) const { - return network_selection( - std::make_shared(this->impl_, std::move(right.impl_))); +network_selection network_selection::symmetric_difference(network_selection left, network_selection right) { + return network_selection(std::make_shared( + std::move(left.impl_), std::move(right.impl_))); +} + +network_selection network_selection::difference(network_selection left, network_selection right) { + return network_selection(std::make_shared( + std::move(left.impl_), std::move(right.impl_))); } network_selection network_selection::all() { @@ -834,8 +880,9 @@ network_selection network_selection::destination_gid(std::vector return network_selection(std::make_shared(std::move(gids))); } -network_selection network_selection::invert(network_selection s) { - return network_selection(std::make_shared(std::move(s.impl_))); +network_selection network_selection::complement(network_selection s) { + return network_selection( + std::make_shared(std::move(s.impl_))); } network_selection network_selection::inter_cell() { @@ -901,10 +948,12 @@ network_value network_value::named(std::string name) { network_label_dict& network_label_dict::set(const std::string& name, network_selection s) { selections_.insert_or_assign(name, std::move(s)); + return *this; } network_label_dict& network_label_dict::set(const std::string& name, network_value v) { values_.insert_or_assign(name, std::move(v)); + return *this; } std::optional network_label_dict::selection(const std::string& name) const { diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp index ac9cc3d3c5..fd4b9b9112 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_generation.cpp @@ -30,50 +30,60 @@ struct distributed_site_info { network_hash_type hash = 0; }; -struct distributed_site_mapping { + +struct site_mapping { std::vector sites; std::string labels; + std::unordered_map label_map; - distributed_site_mapping() = default; + site_mapping() = default; - explicit distributed_site_mapping(const std::vector& net_sites) { - std::unordered_map label_to_start_idx; + inline std::size_t size() const { return sites.size(); } - for (const auto& s: net_sites) { - const auto insert_pair = label_to_start_idx.insert({s.label, labels.size()}); - // append label if not contained in labels - if (insert_pair.second) { - labels.append(s.label); - labels.push_back('\0'); - } - sites.emplace_back(distributed_site_info{s.gid, - s.lid, - s.kind, - insert_pair.first->second, - s.location, - s.global_location, - s.hash}); + void insert(const network_site_info& s) { + const auto insert_pair = label_map.insert({s.label, labels.size()}); + // append label if not contained in labels + if (insert_pair.second) { + labels.append(s.label); + labels.push_back('\0'); } + sites.emplace_back(distributed_site_info{s.gid, + s.lid, + s.kind, + insert_pair.first->second, + s.location, + s.global_location, + s.hash}); } - std::string_view label_at_site(const distributed_site_info& si) { - return labels.c_str() + si.label_start_idx; + network_site_info get_site(std::size_t idx) const { + const auto& s = this->sites.at(idx); + + network_site_info info; + info.gid = s.gid; + info.lid = s.lid; + info.kind = s.kind; + info.label = labels.c_str() + s.label_start_idx; + info.location = s.location; + info.global_location = s.global_location; + info.hash = s.hash; + + return info; } }; -template -void distributed_for_each_site(const distributed_context& distributed, - const std::vector& src_sites, - FUNC f) { - if(distributed.size() > 1) { - distributed_site_mapping mapping(src_sites); +struct distributed_site_mapping { + const distributed_context& distributed; + std::vector num_sites_per_rank, label_string_size_per_rank; + site_mapping mapping, recv_mapping; - const auto my_rank = distributed.id(); - const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1; - const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1; + explicit distributed_site_mapping(const distributed_context& distributed, site_mapping m): + distributed(distributed), + mapping(std::move(m)) { + mapping.label_map.clear(); // no longer valid after first exchange - const auto num_sites_per_rank = distributed.gather_all(mapping.sites.size()); - const auto label_string_size_per_rank = distributed.gather_all(mapping.labels.size()); + num_sites_per_rank = distributed.gather_all(mapping.sites.size()); + label_string_size_per_rank = distributed.gather_all(mapping.labels.size()); const auto max_num_sites = *std::max_element(num_sites_per_rank.begin(), num_sites_per_rank.end()); @@ -82,14 +92,18 @@ void distributed_for_each_site(const distributed_context& distributed, mapping.sites.resize(max_num_sites); mapping.labels.resize(max_string_size); - - distributed_site_mapping recv_mapping; recv_mapping.sites.resize(max_num_sites); recv_mapping.labels.resize(max_string_size); + } - auto current_idx = my_rank; + template + void for_each_site(const FUNC& f) { + const auto my_rank = distributed.id(); + const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1; + const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1; - for(std::size_t step = 0; step < distributed.size() - 1; ++step) { + auto current_idx = my_rank; + for (std::size_t step = 0; step < distributed.size() - 1; ++step) { const auto next_idx = (current_idx + 1) % distributed.size(); auto request_sites = distributed.send_recv_nonblocking(num_sites_per_rank[next_idx], recv_mapping.sites.data(), @@ -109,9 +123,7 @@ void distributed_for_each_site(const distributed_context& distributed, 1); for (std::size_t site_idx = 0; site_idx < num_sites_per_rank[current_idx]; ++site_idx) { - const auto& s = mapping.sites[site_idx]; - f(network_site_info{ - s.gid, s.lid, s.kind, mapping.label_at_site(s), s.location, s.global_location}); + f(mapping.get_site(site_idx)); } request_sites.finalize(); @@ -123,15 +135,10 @@ void distributed_for_each_site(const distributed_context& distributed, } for (std::size_t site_idx = 0; site_idx < num_sites_per_rank[current_idx]; ++site_idx) { - const auto& s = mapping.sites[site_idx]; - f(network_site_info{ - s.gid, s.lid, s.kind, mapping.label_at_site(s), s.location, s.global_location}); + f(mapping.get_site(site_idx)); } } - else { - for (const auto& s: src_sites) { f(s); } - } -} +}; } // namespace @@ -139,11 +146,11 @@ std::vector generate_network_connections(const recipe& rec, const distributed_context& distributed, const domain_decomposition& dom_dec) { const auto description_opt = rec.network_description(); - if(!description_opt.has_value()) return {}; + if (!description_opt.has_value()) return {}; const auto& description = description_opt.value(); - std::vector src_sites, dest_sites; + site_mapping src_sites, dest_sites; const auto selection_ptr = thingify(description.selection, description.dict); const auto weight_ptr = thingify(description.weight, description.dict); @@ -187,12 +194,12 @@ std::vector generate_network_connections(const recipe& rec, // TODO: compute rotation and global offset const mpoint point = location_resolver.at(p_syn.loc); network_location global_location = {point.x, point.y, point.z}; - dest_sites.emplace_back(gid, + dest_sites.insert({gid, p_syn.lid, cell_kind::cable, label, p_syn.loc, - global_location); + global_location}); } } } @@ -200,13 +207,13 @@ std::vector generate_network_connections(const recipe& rec, // check all detectors of cell for potential source for (const auto& p_det: cell.detectors()) { // TODO check if tag correct - const auto& label = lid_to_label(cell.synapse_ranges(), p_det.lid); + const auto& label = lid_to_label(cell.detector_ranges(), p_det.lid); if (selection.select_destination(cell_kind::cable, gid, label)) { // TODO: compute rotation and global offset const mpoint point = location_resolver.at(p_det.loc); network_location global_location = {point.x, point.y, point.z}; - src_sites.emplace_back( - gid, p_det.lid, cell_kind::cable, label, p_det.loc, global_location); + src_sites.insert( + {gid, p_det.lid, cell_kind::cable, label, p_det.loc, global_location}); } } } @@ -227,19 +234,24 @@ std::vector generate_network_connections(const recipe& rec, } // create octree + std::vector network_dest_sites; + network_dest_sites.reserve(dest_sites.size()); + for(std::size_t i = 0; i < dest_sites.size(); ++i) { + network_dest_sites.emplace_back(dest_sites.get_site(i)); + } const std::size_t max_depth = selection.max_distance().has_value() ? 10 : 1; const std::size_t max_leaf_size = 100; spatial_tree local_dest_tree( - max_depth, max_leaf_size, std::move(dest_sites), [](const network_site_info& info) { + max_depth, max_leaf_size, std::move(network_dest_sites), [](const network_site_info& info) { return info.global_location; }); // select connections std::vector connections; - auto sample_destinations = [&] (const network_site_info& src) { - auto sample = [&] (const network_site_info& dest) { - if(selection.select_connection(src, dest)) { + auto sample_destinations = [&](const network_site_info& src) { + auto sample = [&](const network_site_info& dest) { + if (selection.select_connection(src, dest)) { connections.emplace_back(connection({src.gid, src.lid}, {dest.gid, dest.lid}, weight.get(src, dest), @@ -247,7 +259,7 @@ std::vector generate_network_connections(const recipe& rec, } }; - if(selection.max_distance().has_value()) { + if (selection.max_distance().has_value()) { const double d = selection.max_distance().value(); local_dest_tree.bounding_box_for_each(network_location{src.global_location[0] - d, src.global_location[1] - d, @@ -256,12 +268,13 @@ std::vector generate_network_connections(const recipe& rec, src.global_location[1] + d, src.global_location[2] + d}, sample); - } else { - local_dest_tree.for_each(sample); } + else { local_dest_tree.for_each(sample); } }; - distributed_for_each_site(distributed, src_sites, sample_destinations); + distributed_site_mapping distributed_src_sites(distributed, std::move(src_sites)); + + distributed_src_sites.for_each_site(sample_destinations); return connections; } diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 69e0f8bd1c..36b5cd5bb0 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -34,6 +34,7 @@ set(pyarb_source mechanism.cpp morphology.cpp mpi.cpp + network.cpp profiler.cpp pyarb.cpp recipe.cpp diff --git a/python/network.cpp b/python/network.cpp new file mode 100644 index 0000000000..a135b48527 --- /dev/null +++ b/python/network.cpp @@ -0,0 +1,107 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include "error.hpp" + +namespace py = pybind11; + +namespace pyarb { + +void register_network(py::module& m) { + using namespace py::literals; + + py::class_ network_site_info( + m, "network_site_info", "Identifies a network site to connect to / from"); + network_site_info.def_readwrite("gid", &arb::network_site_info::gid) + .def_readwrite("lid", &arb::network_site_info::lid) + .def_readwrite("kind", &arb::network_site_info::kind) + .def_readwrite("label", &arb::network_site_info::label) + .def_readwrite("location", &arb::network_site_info::location) + .def_readwrite("global_location", &arb::network_site_info::global_location); + + py::class_ network_selection( + m, "network_selection", "Network selection."); + network_selection.def_static("custom", [](arb::network_selection::custom_func_type func) { + return arb::network_selection::custom( + [=](const arb::network_site_info& src, const arb::network_site_info& dest) { + return try_catch_pyexception( + [&]() { + pybind11::gil_scoped_acquire guard; + return func(src, dest); + }, + "Python error already thrown"); + }); + }); + + py::class_ network_value(m, "network_value", "Network value."); + network_value.def_static("custom", [](arb::network_value::custom_func_type func) { + return arb::network_value::custom( + [=](const arb::network_site_info& src, const arb::network_site_info& dest) { + return try_catch_pyexception( + [&]() { + pybind11::gil_scoped_acquire guard; + return func(src, dest); + }, + "Python error already thrown"); + }); + }); + + py::class_ network_description( + m, "network_description", "Network description."); + network_description.def( + py::init( + [](std::string selection, + std::string weight, + std::string delay, + std::unordered_map> map) { + arb::network_label_dict dict; + for (const auto& [label, v]: map) { + const auto& dict_label = label; + std::visit( + arb::util::overload( + [&](const std::string& s) { + auto sel = arborio::parse_network_selection_expression(s); + if (sel) { + dict.set(dict_label, *sel); + return; + } + + auto val = arborio::parse_network_value_expression(s); + if (val) { + dict.set(dict_label, *val); + return; + } + + throw pyarb_error( + std::string("Failed to parse \"") + dict_label + + "\" label in dict of network description. \nSelection " + "label parse error:\n" + + sel.error().what() + "\nValue label parse error:\n" + + val.error().what()); + }, + [&](const arb::network_selection& sel) { dict.set(dict_label, sel); }, + [&](const arb::network_value& val) { dict.set(dict_label, val); }), + v); + } + return arb::network_description(); + }), + "selection"_a, + "weight"_a, + "delay"_a, + "dict"_a, + "Construct network description."); +} + +} // namespace pyarb diff --git a/python/pyarb.cpp b/python/pyarb.cpp index 57e503b07a..4910afc5b2 100644 --- a/python/pyarb.cpp +++ b/python/pyarb.cpp @@ -30,6 +30,7 @@ void register_schedules(pybind11::module& m); void register_simulation(pybind11::module& m, pyarb_global_ptr); void register_single_cell(pybind11::module& m); void register_arborenv(pybind11::module& m); +void register_network(pybind11::module& m); #ifdef ARB_MPI_ENABLED void register_mpi(pybind11::module& m); @@ -58,6 +59,7 @@ PYBIND11_MODULE(_arbor, m) { pyarb::register_mechanisms(m); pyarb::register_morphology(m); pyarb::register_profiler(m); + pyarb::register_network(m); pyarb::register_recipe(m); pyarb::register_schedules(m); pyarb::register_simulation(m, global_ptr); diff --git a/python/recipe.cpp b/python/recipe.cpp index 54cef50e7a..e982f5a943 100644 --- a/python/recipe.cpp +++ b/python/recipe.cpp @@ -200,6 +200,8 @@ void register_recipe(pybind11::module& m) { .def("gap_junctions_on", &py_recipe::gap_junctions_on, "gid"_a, "A list of the gap junctions connected to gid, [] by default.") + .def("network_description", &py_recipe::network_description, + "Network description of cell connections.") .def("probes", &py_recipe::probes, "gid"_a, "The probes to allow monitoring.") diff --git a/python/recipe.hpp b/python/recipe.hpp index 1a3b44538e..c8fbc1df45 100644 --- a/python/recipe.hpp +++ b/python/recipe.hpp @@ -1,13 +1,15 @@ #pragma once #include +#include #include #include #include -#include #include +#include +#include #include #include "error.hpp" @@ -48,6 +50,9 @@ class py_recipe { virtual pybind11::object global_properties(arb::cell_kind kind) const { return pybind11::none(); }; + virtual std::optional network_description() const { + return std::nullopt; + }; }; class py_recipe_trampoline: public py_recipe { @@ -76,6 +81,10 @@ class py_recipe_trampoline: public py_recipe { PYBIND11_OVERRIDE(std::vector, py_recipe, gap_junctions_on, gid); } + std::optional network_description() const override { + PYBIND11_OVERRIDE_PURE(arb::network_description, py_recipe, network_description); + } + std::vector probes(arb::cell_gid_type gid) const override { PYBIND11_OVERRIDE(std::vector, py_recipe, probes, gid); } @@ -129,6 +138,10 @@ class py_recipe_shim: public arb::recipe { } std::any get_global_properties(arb::cell_kind kind) const override; + + std::optional network_description() const override { + return try_catch_pyexception([&]() { return impl_->network_description(); }, msg); + }; }; } // namespace pyarb From b9dcfa6afb187b4b06cbf4fa3c1399d4e54eaf1b Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 17 Mar 2023 18:12:07 +0100 Subject: [PATCH 12/84] label parse --- arborio/label_parse.cpp | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index 2b8e000a99..edcc29297a 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -196,27 +196,22 @@ eval_map_type network_eval_map{ {"network-selection", make_call(arb::network_selection::named, "network selection with 1 argument: (value:string)")}, - {"and", - make_fold( - [](arb::network_selection left, arb::network_selection right) { - return std::move(left) & std::move(right); - }, - "logical \"and\" operation of network selections with at least 2 arguments: " + {"intersect", + make_fold(arb::network_selection::intersect, + "intersection of network selections with at least 2 arguments: " "(network_selection network_selection [...network_selection])")}, - {"or", - make_fold( - [](arb::network_selection left, arb::network_selection right) { - return std::move(left) | std::move(right); - }, - "logical \"or\" operation of network selections with at least 2 arguments: " + {"join", + make_fold(arb::network_selection::join, + "join or union operation of network selections with at least 2 arguments: " "(network_selection network_selection [...network_selection])")}, - {"xor", - make_fold( - [](arb::network_selection left, arb::network_selection right) { - return std::move(left) ^ std::move(right); - }, - "logical \"xor\" operation of network selections with at least 2 arguments: " + {"symmetric_difference", + make_fold(arb::network_selection::symmetric_difference, + "symmetric difference operation between network selections with at least 2 arguments: " "(network_selection network_selection [...network_selection])")}, + {"difference", + make_call(arb::network_selection::difference, + "difference of first selection with the second one: " + "(network_selection network_selection)")}, // network_value {"scalar", From 16dc1e9f7798b1c7a15ee3ad611c1133a1a858ba Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sat, 25 Mar 2023 13:48:23 +0100 Subject: [PATCH 13/84] add rotation / translation, more selections --- arbor/include/arbor/network.hpp | 6 ++- arbor/include/arbor/recipe.hpp | 3 ++ arbor/network.cpp | 89 +++++++++++++++++++++------------ arbor/network_generation.cpp | 4 +- arborio/label_parse.cpp | 6 ++- 5 files changed, 71 insertions(+), 37 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 86d5b3cc54..ee144b0af5 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -105,7 +105,11 @@ class ARB_SYMBOL_VISIBLE network_selection { // only select within given distance. This may enable more efficient sampling through an // internal spatial data structure. - static network_selection within_distance(double distance); + static network_selection distance_lt(double distance); + + // only select if distance greater then given distance. This may enable more efficient sampling + // through an internal spatial data structure. + static network_selection distance_gt(double distance); // random bernoulli sampling with a linear interpolated probabilty based on distance. Returns // "false" for any distance outside of the interval [distance_begin, distance_end]. diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index c7c9aadd9d..24b9b2d335 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -110,6 +111,8 @@ struct ARB_ARBOR_API recipe: public has_gap_junctions, has_probes, connectivity virtual cell_kind get_cell_kind(cell_gid_type) const = 0; // Global property type will be specific to given cell kind. virtual std::any get_global_properties(cell_kind) const { return std::any{}; }; + // Global cell isometry describing rotation and translation of the cell + virtual isometry get_cell_isometry(cell_gid_type gid) const { return isometry(); }; virtual ~recipe() {} }; diff --git a/arbor/network.cpp b/arbor/network.cpp index 1635ee2b3d..061672652c 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -71,7 +71,6 @@ double normal_rand_from_key_pair(std::array seed, return r123::boxmuller(rand_num[0], rand_num[1]).x; } - double network_location_distance(const network_location& a, const network_location& b) { return std::sqrt(a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); } @@ -286,18 +285,15 @@ struct network_selection_complement_impl: public network_selection_impl { bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return true; // cannot exclude any because destination selection cannot be complemented without - // knowing selection criteria. + return true; // cannot exclude any because destination selection cannot be complemented + // without knowing selection criteria. } - void initialize(const network_label_dict& dict) override { - selection->initialize(dict); - }; + void initialize(const network_label_dict& dict) override { selection->initialize(dict); }; }; - struct network_selection_named_impl: public network_selection_impl { - using impl_pointer_type =std::shared_ptr; + using impl_pointer_type = std::shared_ptr; std::variant selection; @@ -305,7 +301,7 @@ struct network_selection_named_impl: public network_selection_impl { bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - if(!std::holds_alternative(selection)) + if (!std::holds_alternative(selection)) throw arbor_internal_error("Trying to use unitialized named network selection."); return std::get(selection)->select_connection(src, dest); } @@ -313,7 +309,7 @@ struct network_selection_named_impl: public network_selection_impl { bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - if(!std::holds_alternative(selection)) + if (!std::holds_alternative(selection)) throw arbor_internal_error("Trying to use unitialized named network selection."); return std::get(selection)->select_source(kind, gid, label); } @@ -321,13 +317,13 @@ struct network_selection_named_impl: public network_selection_impl { bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - if(!std::holds_alternative(selection)) + if (!std::holds_alternative(selection)) throw arbor_internal_error("Trying to use unitialized named network selection."); return std::get(selection)->select_destination(kind, gid, label); } void initialize(const network_label_dict& dict) override { - if(std::holds_alternative(selection)) { + if (std::holds_alternative(selection)) { auto s = dict.selection(std::get(selection)); if (!s.has_value()) throw arbor_exception(std::string("Network selection with label \"") + @@ -399,14 +395,14 @@ struct network_selection_custom_impl: public network_selection_impl { } }; -struct network_selection_within_distance_impl: public network_selection_impl { +struct network_selection_distance_lt_impl: public network_selection_impl { double distance; - explicit network_selection_within_distance_impl(double distance): distance(distance) {} + explicit network_selection_distance_lt_impl(double distance): distance(distance) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return network_location_distance(src.global_location, dest.global_location) <= distance; + return network_location_distance(src.global_location, dest.global_location) < distance; } bool select_source(cell_kind kind, @@ -424,6 +420,29 @@ struct network_selection_within_distance_impl: public network_selection_impl { std::optional max_distance() const override { return distance; } }; +struct network_selection_distance_gt_impl: public network_selection_impl { + double distance; + + explicit network_selection_distance_gt_impl(double distance): distance(distance) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return network_location_distance(src.global_location, dest.global_location) > distance; + } + + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return true; + } + + bool select_destination(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return true; + } +}; + struct network_selection_bernoulli_random_impl: public network_selection_impl { unsigned seed; double probability; @@ -475,7 +494,8 @@ struct network_selection_linear_bernoulli_random_impl: public network_selection_ bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - const double distance = network_location_distance(src.global_location, dest.global_location); + const double distance = + network_location_distance(src.global_location, dest.global_location); if (distance < distance_begin || distance > distance_end) return false; @@ -629,7 +649,6 @@ struct network_selection_symmetric_difference_impl: public network_selection_imp }; }; - struct network_selection_difference_impl: public network_selection_impl { std::shared_ptr left, right; @@ -669,8 +688,7 @@ struct network_selection_difference_impl: public network_selection_impl { }; }; - -struct network_value_scalar_impl : public network_value_impl{ +struct network_value_scalar_impl: public network_value_impl { double value; network_value_scalar_impl(double v): value(v) {} @@ -680,7 +698,7 @@ struct network_value_scalar_impl : public network_value_impl{ } }; -struct network_value_uniform_distribution_impl : public network_value_impl{ +struct network_value_uniform_distribution_impl: public network_value_impl { unsigned seed = 0; std::array range; @@ -767,22 +785,21 @@ struct network_value_custom_impl: public network_value_impl { } }; - struct network_value_named_impl: public network_value_impl { - using impl_pointer_type =std::shared_ptr; + using impl_pointer_type = std::shared_ptr; std::variant value; explicit network_value_named_impl(std::string name): value(std::move(name)) {} double get(const network_site_info& src, const network_site_info& dest) const override { - if(!std::holds_alternative(value)) + if (!std::holds_alternative(value)) throw arbor_internal_error("Trying to use unitialized named network value."); return std::get(value)->get(src, dest); } void initialize(const network_label_dict& dict) override { - if(std::holds_alternative(value)) { + if (std::holds_alternative(value)) { auto s = dict.value(std::get(value)); if (!s.has_value()) throw arbor_exception(std::string("Network value with label \"") + @@ -834,7 +851,8 @@ network_selection network_selection::join(network_selection left, network_select std::move(left.impl_), std::move(right.impl_))); } -network_selection network_selection::symmetric_difference(network_selection left, network_selection right) { +network_selection network_selection::symmetric_difference(network_selection left, + network_selection right) { return network_selection(std::make_shared( std::move(left.impl_), std::move(right.impl_))); } @@ -865,11 +883,13 @@ network_selection network_selection::destination_cell_kind(cell_kind kind) { } network_selection network_selection::source_label(std::vector labels) { - return network_selection(std::make_shared(std::move(labels))); + return network_selection( + std::make_shared(std::move(labels))); } network_selection network_selection::destination_label(std::vector labels) { - return network_selection(std::make_shared(std::move(labels))); + return network_selection( + std::make_shared(std::move(labels))); } network_selection network_selection::source_gid(std::vector gids) { @@ -877,7 +897,8 @@ network_selection network_selection::source_gid(std::vector gids) } network_selection network_selection::destination_gid(std::vector gids) { - return network_selection(std::make_shared(std::move(gids))); + return network_selection( + std::make_shared(std::move(gids))); } network_selection network_selection::complement(network_selection s) { @@ -901,8 +922,12 @@ network_selection network_selection::custom(custom_func_type func) { return network_selection(std::make_shared(std::move(func))); } -network_selection network_selection::within_distance(double distance) { - return network_selection(std::make_shared(distance)); +network_selection network_selection::distance_lt(double distance) { + return network_selection(std::make_shared(distance)); +} + +network_selection network_selection::distance_gt(double distance) { + return network_selection(std::make_shared(distance)); } network_selection network_selection::linear_bernoulli_random(unsigned seed, @@ -958,14 +983,14 @@ network_label_dict& network_label_dict::set(const std::string& name, network_val std::optional network_label_dict::selection(const std::string& name) const { auto it = selections_.find(name); - if(it != selections_.end()) return it->second; + if (it != selections_.end()) return it->second; return std::nullopt; } std::optional network_label_dict::value(const std::string& name) const { auto it = values_.find(name); - if(it != values_.end()) return it->second; + if (it != values_.end()) return it->second; return std::nullopt; } diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp index fd4b9b9112..5bafc90927 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_generation.cpp @@ -181,7 +181,7 @@ std::vector generate_network_connections(const recipe& rec, throw arbor_internal_error("unkown lid"); }; - place_pwlin location_resolver(cell.morphology()); + place_pwlin location_resolver(cell.morphology(), rec.get_cell_isometry(gid)); // check all synapses of cell for potential destination @@ -191,7 +191,6 @@ std::vector generate_network_connections(const recipe& rec, const auto& label = lid_to_label(cell.synapse_ranges(), p_syn.lid); if (selection.select_destination(cell_kind::cable, gid, label)) { - // TODO: compute rotation and global offset const mpoint point = location_resolver.at(p_syn.loc); network_location global_location = {point.x, point.y, point.z}; dest_sites.insert({gid, @@ -209,7 +208,6 @@ std::vector generate_network_connections(const recipe& rec, // TODO check if tag correct const auto& label = lid_to_label(cell.detector_ranges(), p_det.lid); if (selection.select_destination(cell_kind::cable, gid, label)) { - // TODO: compute rotation and global offset const mpoint point = location_resolver.at(p_det.loc); network_location global_location = {point.x, point.y, point.z}; src_sites.insert( diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index edcc29297a..3937ae13ae 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -209,9 +209,13 @@ eval_map_type network_eval_map{ "symmetric difference operation between network selections with at least 2 arguments: " "(network_selection network_selection [...network_selection])")}, {"difference", - make_call(arb::network_selection::difference, + make_call( + arb::network_selection::difference, "difference of first selection with the second one: " "(network_selection network_selection)")}, + {"complement", + make_call(arb::network_selection::complement, + "complement of given selection: (network_selection)")}, // network_value {"scalar", From edae2d9ad055435d01aa94a728fdd39c9d5f06b2 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sat, 25 Mar 2023 19:57:14 +0100 Subject: [PATCH 14/84] use mpoint --- arbor/include/arbor/network.hpp | 7 +++---- arbor/network.cpp | 6 +++--- arbor/network_generation.cpp | 35 +++++++++++++++------------------ 3 files changed, 22 insertions(+), 26 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index ee144b0af5..061b0899ca 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -20,8 +21,6 @@ namespace arb { -using network_location = std::array; - using network_hash_type = std::uint64_t; struct ARB_SYMBOL_VISIBLE network_site_info { @@ -32,14 +31,14 @@ struct ARB_SYMBOL_VISIBLE network_site_info { cell_kind kind, std::string_view label, mlocation location, - network_location global_location); + mpoint global_location); cell_gid_type gid; cell_lid_type lid; cell_kind kind; std::string_view label; mlocation location; - network_location global_location; + mpoint global_location; network_hash_type hash; }; diff --git a/arbor/network.cpp b/arbor/network.cpp index 061672652c..3a895a52e8 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -71,8 +71,8 @@ double normal_rand_from_key_pair(std::array seed, return r123::boxmuller(rand_num[0], rand_num[1]).x; } -double network_location_distance(const network_location& a, const network_location& b) { - return std::sqrt(a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); +double network_location_distance(const mpoint& a, const mpoint& b) { + return std::sqrt(a.x * b.x + a.y * b.y + a.z * b.z); } struct network_selection_all_impl: public network_selection_impl { @@ -816,7 +816,7 @@ network_site_info::network_site_info(cell_gid_type gid, cell_kind kind, std::string_view label, mlocation location, - network_location global_location): + mpoint global_location): gid(gid), lid(lid), kind(kind), diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp index 5bafc90927..582c30fd77 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_generation.cpp @@ -26,7 +26,7 @@ struct distributed_site_info { cell_kind kind = cell_kind::cable; cell_gid_type label_start_idx = 0; mlocation location = mlocation(); - network_location global_location = network_location(); + mpoint global_location = mpoint(); network_hash_type hash = 0; }; @@ -192,13 +192,8 @@ std::vector generate_network_connections(const recipe& rec, if (selection.select_destination(cell_kind::cable, gid, label)) { const mpoint point = location_resolver.at(p_syn.loc); - network_location global_location = {point.x, point.y, point.z}; - dest_sites.insert({gid, - p_syn.lid, - cell_kind::cable, - label, - p_syn.loc, - global_location}); + dest_sites.insert( + {gid, p_syn.lid, cell_kind::cable, label, p_syn.loc, point}); } } } @@ -209,9 +204,8 @@ std::vector generate_network_connections(const recipe& rec, const auto& label = lid_to_label(cell.detector_ranges(), p_det.lid); if (selection.select_destination(cell_kind::cable, gid, label)) { const mpoint point = location_resolver.at(p_det.loc); - network_location global_location = {point.x, point.y, point.z}; src_sites.insert( - {gid, p_det.lid, cell_kind::cable, label, p_det.loc, global_location}); + {gid, p_det.lid, cell_kind::cable, label, p_det.loc, point}); } } } @@ -239,9 +233,11 @@ std::vector generate_network_connections(const recipe& rec, } const std::size_t max_depth = selection.max_distance().has_value() ? 10 : 1; const std::size_t max_leaf_size = 100; - spatial_tree local_dest_tree( - max_depth, max_leaf_size, std::move(network_dest_sites), [](const network_site_info& info) { - return info.global_location; + spatial_tree local_dest_tree(max_depth, + max_leaf_size, + std::move(network_dest_sites), + [](const network_site_info& info) -> spatial_tree::point_type { + return {info.global_location.x, info.global_location.y, info.global_location.z}; }); // select connections @@ -259,12 +255,13 @@ std::vector generate_network_connections(const recipe& rec, if (selection.max_distance().has_value()) { const double d = selection.max_distance().value(); - local_dest_tree.bounding_box_for_each(network_location{src.global_location[0] - d, - src.global_location[1] - d, - src.global_location[2] - d}, - network_location{src.global_location[0] + d, - src.global_location[1] + d, - src.global_location[2] + d}, + local_dest_tree.bounding_box_for_each( + decltype(local_dest_tree)::point_type{src.global_location.x - d, + src.global_location.y - d, + src.global_location.z - d}, + decltype(local_dest_tree)::point_type{src.global_location.x + d, + src.global_location.y + d, + src.global_location.z + d}, sample); } else { local_dest_tree.for_each(sample); } From 6882c5936e609acfa7df14e397cbba9de59018f8 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sat, 25 Mar 2023 21:23:53 +0100 Subject: [PATCH 15/84] all labels for selection --- arbor/include/arbor/network.hpp | 12 ++--- arbor/network.cpp | 39 ++++------------ arborio/label_parse.cpp | 80 ++++++++++++++++++++++++++++++++- 3 files changed, 91 insertions(+), 40 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 061b0899ca..c70f341f7c 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -89,13 +89,9 @@ class ARB_SYMBOL_VISIBLE network_selection { // Only select connections between different cells static network_selection inter_cell(); - // Only select connections when the global labels are not equal. May select intra-cell - // connections, if the local label is not equal. - static network_selection not_equal(); - // Random selection using the bernoulli random distribution with probability "p" between 0.0 // and 1.0 - static network_selection bernoulli_random(unsigned seed, double p); + static network_selection random_bernoulli(unsigned seed, double p); // Custom selection using the provided function "func". Repeated calls with the same arguments // to "func" must yield the same result. For gap junction selection, @@ -110,9 +106,9 @@ class ARB_SYMBOL_VISIBLE network_selection { // through an internal spatial data structure. static network_selection distance_gt(double distance); - // random bernoulli sampling with a linear interpolated probabilty based on distance. Returns - // "false" for any distance outside of the interval [distance_begin, distance_end]. - static network_selection linear_bernoulli_random(unsigned seed, + // randomly selected with a probability linearly interpolated between [p_begin, p_end] based on + // the distance in the interval [distance_begin, distance_end]. + static network_selection random_linear_distance(unsigned seed, double distance_begin, double p_begin, double distance_end, diff --git a/arbor/network.cpp b/arbor/network.cpp index 3a895a52e8..13a911d592 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -352,25 +352,6 @@ struct network_selection_inter_cell_impl: public network_selection_impl { } }; -struct network_selection_not_equal_impl: public network_selection_impl { - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { - return src.gid != dest.gid || src.label != dest.label || src.location != dest.location; - } - - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { - return true; - } - - bool select_destination(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { - return true; - } -}; - struct network_selection_custom_impl: public network_selection_impl { network_selection::custom_func_type func; @@ -443,11 +424,11 @@ struct network_selection_distance_gt_impl: public network_selection_impl { } }; -struct network_selection_bernoulli_random_impl: public network_selection_impl { +struct network_selection_random_bernoulli_impl: public network_selection_impl { unsigned seed; double probability; - network_selection_bernoulli_random_impl(unsigned seed, double p): seed(seed), probability(p) {} + network_selection_random_bernoulli_impl(unsigned seed, double p): seed(seed), probability(p) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { @@ -469,14 +450,14 @@ struct network_selection_bernoulli_random_impl: public network_selection_impl { } }; -struct network_selection_linear_bernoulli_random_impl: public network_selection_impl { +struct network_selection_random_linear_distance_impl: public network_selection_impl { unsigned seed; double distance_begin; double p_begin; double distance_end; double p_end; - network_selection_linear_bernoulli_random_impl(unsigned seed_, + network_selection_random_linear_distance_impl(unsigned seed_, double distance_begin_, double p_begin_, double distance_end_, @@ -910,12 +891,8 @@ network_selection network_selection::inter_cell() { return network_selection(std::make_shared()); } -network_selection network_selection::not_equal() { - return network_selection(std::make_shared()); -} - -network_selection network_selection::bernoulli_random(unsigned seed, double p) { - return network_selection(std::make_shared(seed, p)); +network_selection network_selection::random_bernoulli(unsigned seed, double p) { + return network_selection(std::make_shared(seed, p)); } network_selection network_selection::custom(custom_func_type func) { @@ -930,12 +907,12 @@ network_selection network_selection::distance_gt(double distance) { return network_selection(std::make_shared(distance)); } -network_selection network_selection::linear_bernoulli_random(unsigned seed, +network_selection network_selection::random_linear_distance(unsigned seed, double distance_begin, double p_begin, double distance_end, double p_end) { - return network_selection(std::make_shared( + return network_selection(std::make_shared( seed, distance_begin, p_begin, distance_end, p_end)); } diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index 3937ae13ae..c9641606d6 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -1,9 +1,11 @@ #include #include +#include #include #include +#include #include #include #include @@ -21,6 +23,14 @@ label_parse_error::label_parse_error(const std::string& msg, const arb::src_loca namespace { +struct gid_list { + gid_list() = default; + + gid_list(cell_gid_type gid) : gids({gid}) {} + + std::vector gids; +}; + using eval_map_type= std::unordered_multimap; eval_map_type eval_map { @@ -187,6 +197,28 @@ eval_map_type eval_map { }; eval_map_type network_eval_map{ + // cell kind + {"cable-cell", make_call<>([]() { return arb::cell_kind::cable; }, "Cable cell kind")}, + {"lif-cell", make_call<>([]() { return arb::cell_kind::lif; }, "Lif cell kind")}, + {"benchmark-cell", + make_call<>([]() { return arb::cell_kind::benchmark; }, "Benchmark cell kind")}, + {"spike-source-cell", + make_call<>([]() { return arb::cell_kind::benchmark; }, "Spike source cell kind")}, + + // gid list + {"gid-list", + make_call([](cell_gid_type gid) { return gid_list(gid); }, + "List of global indices")}, + {"gid-list", + make_conversion_fold( + [](gid_list a, gid_list b) { + a.gids.insert(a.gids.end(), b.gids.begin(), b.gids.end()); + return a; + }, + "List of global indices with at least 2 arguments: ((gid-list | integer) (gid-list | " + "integer) [...(gid-list | " + "integer)])")}, + // network_selection {"all", make_call<>(arb::network_selection::all, "network selection of all cells and labels")}, {"none", make_call<>(arb::network_selection::none, "network selection of no cells and labels")}, @@ -215,7 +247,53 @@ eval_map_type network_eval_map{ "(network_selection network_selection)")}, {"complement", make_call(arb::network_selection::complement, - "complement of given selection: (network_selection)")}, + "complement of given selection argument: (network_selection)")}, + {"source-cell-kind", + make_call(arb::network_selection::source_cell_kind, + "all sources of cells matching given cell kind argument: (kind:cell-kind)")}, + {"destination-cell-kind", + make_call(arb::network_selection::destination_cell_kind, + "all destinations of cells matching given cell kind argument: (kind:cell-kind)")}, + {"source-gid", + make_call( + [](cell_gid_type gid) { + return arb::network_selection::source_gid(std::vector({gid})); + }, + "all sources in cell with given gid: (gid:integer)")}, + {"source-gid", + make_call( + [](gid_list list) { return arb::network_selection::source_gid(std::move(list.gids)); }, + "all sources of cells gid in list argument: (list: gid-list)")}, + {"destination-gid", + make_call( + [](cell_gid_type gid) { + return arb::network_selection::destination_gid(std::vector({gid})); + }, + "all destinations in cell with given gid: (gid:integer)")}, + {"destination-gid", + make_call( + [](gid_list list) { + return arb::network_selection::destination_gid(std::move(list.gids)); + }, + "all destinations of cells gid in list argument: (list: gid-list)")}, + {"random-bernoulli", + make_call(arb::network_selection::random_bernoulli, + "randomly selected with given seed and probability. 2 arguments: (seed:integer, " + "p:real)")}, + {"random-linear-distance", + make_call( + arb::network_selection::random_linear_distance, + "randomly selected with a probability linearly interpolated between [p_begin, p_end] " + "based on the distance in the interval [distance_begin, distance_end]. 5 arguments: " + "(seed:integer, distance_begin:real, p_begin:real, distance_end:real, p_end:real)")}, + {"distance-lt", + make_call(arb::network_selection::distance_lt, + "Select if distance between source and destination is less than given distance in " + "micro meter: (distance:real)")}, + {"distance-gt", + make_call(arb::network_selection::distance_gt, + "Select if distance between source and destination is greater than given distance in " + "micro meter: (distance:real)")}, // network_value {"scalar", From c90dd98480efc8bebe280e47adf7392909969e3a Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 26 Mar 2023 10:29:37 +0200 Subject: [PATCH 16/84] value labels --- arbor/include/arbor/network.hpp | 4 ++ arbor/network.cpp | 63 +++++++++++++++++++++++ arborio/label_parse.cpp | 91 +++++++++++++++++++++++++-------- test/unit/CMakeLists.txt | 1 + 4 files changed, 138 insertions(+), 21 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index c70f341f7c..27c30fd69d 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -73,8 +73,12 @@ class ARB_SYMBOL_VISIBLE network_selection { static network_selection source_gid(std::vector gids); + static network_selection source_gid(cell_gid_type gid_begin, cell_gid_type gid_end); + static network_selection destination_gid(std::vector gids); + static network_selection destination_gid(cell_gid_type gid_begin, cell_gid_type gid_end); + static network_selection intersect(network_selection left, network_selection right); static network_selection join(network_selection left, network_selection right); diff --git a/arbor/network.cpp b/arbor/network.cpp index 13a911d592..956667c821 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -238,6 +238,33 @@ struct network_selection_source_gid_impl: public network_selection_impl { } }; + +struct network_selection_source_gid_range_impl: public network_selection_impl { + cell_gid_type gid_begin, gid_end; + + explicit network_selection_source_gid_range_impl(cell_gid_type gid_begin, + cell_gid_type gid_end): + gid_begin(gid_begin), + gid_end(gid_end) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return src.gid >= gid_begin && src.gid < gid_end; + } + + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return gid >= gid_begin && gid < gid_end; + } + + bool select_destination(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return true; + } +}; + struct network_selection_destination_gid_impl: public network_selection_impl { std::vector sorted_gids; @@ -264,6 +291,32 @@ struct network_selection_destination_gid_impl: public network_selection_impl { } }; +struct network_selection_destination_gid_range_impl: public network_selection_impl { + cell_gid_type gid_begin, gid_end; + + explicit network_selection_destination_gid_range_impl(cell_gid_type gid_begin, + cell_gid_type gid_end): + gid_begin(gid_begin), + gid_end(gid_end) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + return dest.gid >= gid_begin && dest.gid < gid_end; + } + + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return true; + } + + bool select_destination(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return gid >= gid_begin && gid < gid_end; + } +}; + struct network_selection_complement_impl: public network_selection_impl { std::shared_ptr selection; @@ -877,11 +930,21 @@ network_selection network_selection::source_gid(std::vector gids) return network_selection(std::make_shared(std::move(gids))); } +network_selection network_selection::source_gid(cell_gid_type gid_begin, cell_gid_type gid_end) { + return network_selection(std::make_shared(gid_begin, gid_end)); +} + network_selection network_selection::destination_gid(std::vector gids) { return network_selection( std::make_shared(std::move(gids))); } +network_selection network_selection::destination_gid(cell_gid_type gid_begin, + cell_gid_type gid_end) { + return network_selection( + std::make_shared(gid_begin, gid_end)); +} + network_selection network_selection::complement(network_selection s) { return network_selection( std::make_shared(std::move(s.impl_))); diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index c9641606d6..453841d315 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -21,16 +21,26 @@ label_parse_error::label_parse_error(const std::string& msg, const arb::src_loca arb::arbor_exception(concat("error in label description: ", msg," at :", loc.line, ":", loc.column)) {} - namespace { -struct gid_list { - gid_list() = default; +struct gid_list_label { + gid_list_label() = default; - gid_list(cell_gid_type gid) : gids({gid}) {} + gid_list_label(cell_gid_type gid): gids({gid}) {} std::vector gids; }; +struct gid_range_label { + gid_range_label() = default; + + gid_range_label(cell_gid_type gid_begin, cell_gid_type gid_end): + gid_begin(gid_begin), + gid_end(gid_end) {} + + cell_gid_type gid_begin = 0; + cell_gid_type gid_end = 0; +}; + using eval_map_type= std::unordered_multimap; eval_map_type eval_map { @@ -205,13 +215,16 @@ eval_map_type network_eval_map{ {"spike-source-cell", make_call<>([]() { return arb::cell_kind::benchmark; }, "Spike source cell kind")}, - // gid list + // gid structs + {"gid-range", + make_call( + [](int gid_begin, int gid_end) { return gid_range_label(gid_begin, gid_end); }, + "Range of gids in interval [begin, end): (begin, end)")}, {"gid-list", - make_call([](cell_gid_type gid) { return gid_list(gid); }, - "List of global indices")}, + make_call([](int gid) { return gid_list_label(gid); }, "Single gid: (gid:integer)")}, {"gid-list", - make_conversion_fold( - [](gid_list a, gid_list b) { + make_conversion_fold( + [](gid_list_label a, gid_list_label b) { a.gids.insert(a.gids.end(), b.gids.begin(), b.gids.end()); return a; }, @@ -255,27 +268,43 @@ eval_map_type network_eval_map{ make_call(arb::network_selection::destination_cell_kind, "all destinations of cells matching given cell kind argument: (kind:cell-kind)")}, {"source-gid", - make_call( - [](cell_gid_type gid) { - return arb::network_selection::source_gid(std::vector({gid})); + make_call( + [](int gid) { + return arb::network_selection::source_gid( + std::vector({static_cast(gid)})); }, "all sources in cell with given gid: (gid:integer)")}, {"source-gid", - make_call( - [](gid_list list) { return arb::network_selection::source_gid(std::move(list.gids)); }, + make_call( + [](gid_list_label list) { + return arb::network_selection::source_gid(std::move(list.gids)); + }, "all sources of cells gid in list argument: (list: gid-list)")}, + {"source-gid", + make_call( + [](gid_range_label range) { + return arb::network_selection::source_gid(range.gid_begin, range.gid_end); + }, + "All sources of cells within gid range: (range: gid-range)")}, {"destination-gid", - make_call( - [](cell_gid_type gid) { - return arb::network_selection::destination_gid(std::vector({gid})); + make_call( + [](int gid) { + return arb::network_selection::destination_gid( + std::vector({static_cast(gid)})); }, "all destinations in cell with given gid: (gid:integer)")}, {"destination-gid", - make_call( - [](gid_list list) { + make_call( + [](gid_list_label list) { return arb::network_selection::destination_gid(std::move(list.gids)); }, "all destinations of cells gid in list argument: (list: gid-list)")}, + {"destination-gid", + make_call( + [](gid_range_label range) { + return arb::network_selection::destination_gid(range.gid_begin, range.gid_end); + }, + "All destinations of cells within gid range: (range: gid-range)")}, {"random-bernoulli", make_call(arb::network_selection::random_bernoulli, "randomly selected with given seed and probability. 2 arguments: (seed:integer, " @@ -298,10 +327,30 @@ eval_map_type network_eval_map{ // network_value {"scalar", make_call(arb::network_value::scalar, - "network value with 1 argument: (value:double)")}, + "network value with 1 argument: (value:real)")}, {"network-value", make_call(arb::network_value::named, "network value with 1 argument: (value:string)")}, + {"uniform_distribution", + make_call( + [](unsigned seed, double begin, double end) { + return arb::network_value::uniform_distribution(seed, {begin, end}); + }, + "Uniform random distribution within interval [begin, end): (seed:integer, begin:real, " + "end:real)")}, + {"normal_distribution", + make_call(arb::network_value::normal_distribution, + "Normal random distribution with given mean and standard deviation: (seed:integer, " + "mean:real, std_deviation:real)")}, + {"truncated_normal_distribution", + make_call( + [](unsigned seed, double mean, double std_deviation, double begin, double end) { + return arb::network_value::truncated_normal_distribution( + seed, mean, std_deviation, {begin, end}); + }, + "Truncated normal random distribution with given mean and standard deviation within " + "interval [begin, end]: (seed:integer, mean:real, std_deviation:real, begin:real, " + "end:real)")}, }; @@ -402,7 +451,7 @@ parse_label_hopefully eval(const s_expr& e, const eval_map_type& map) location(e))); } -} // namespace +} // namespace ARB_ARBORIO_API parse_label_hopefully parse_label_expression(const std::string& e) { return eval(parse_s_expr(e), eval_map); diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 5600454916..bea0c4c30f 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -108,6 +108,7 @@ set(unit_sources test_morph_primitives.cpp test_morph_stitch.cpp test_multi_event_stream.cpp + test_network.cpp test_ordered_forest.cpp test_padded.cpp test_partition.cpp From 1593e28c5d82eccc4d5075f2f0137c5542cc5a77 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 26 Mar 2023 12:25:12 +0200 Subject: [PATCH 17/84] implemented other cell types --- arbor/network_generation.cpp | 72 +++++++++++++++++++++++++++++------- arbor/network_generation.hpp | 2 +- 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp index 582c30fd77..8af44f2382 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_generation.cpp @@ -1,19 +1,25 @@ #include "network_generation.hpp" +#include "cell_group_factory.hpp" #include "network_impl.hpp" #include "util/spatial_tree.hpp" #include +#include #include #include +#include #include +#include #include #include #include #include +#include #include #include +#include #include #include @@ -143,11 +149,13 @@ struct distributed_site_mapping { } // namespace std::vector generate_network_connections(const recipe& rec, - const distributed_context& distributed, + const context& ctx, const domain_decomposition& dom_dec) { const auto description_opt = rec.network_description(); if (!description_opt.has_value()) return {}; + const distributed_context& distributed = *(ctx->distributed); + const auto& description = description_opt.value(); site_mapping src_sites, dest_sites; @@ -164,6 +172,7 @@ std::vector generate_network_connections(const recipe& rec, for (const auto& group: dom_dec.groups()) { switch (group.kind) { case cell_kind::cable: { + // We need access to morphology, so the cell is create directly cable_cell cell; for (const auto& gid: group.gids) { try { @@ -202,7 +211,7 @@ std::vector generate_network_connections(const recipe& rec, for (const auto& p_det: cell.detectors()) { // TODO check if tag correct const auto& label = lid_to_label(cell.detector_ranges(), p_det.lid); - if (selection.select_destination(cell_kind::cable, gid, label)) { + if (selection.select_source(cell_kind::cable, gid, label)) { const mpoint point = location_resolver.at(p_det.loc); src_sites.insert( {gid, p_det.lid, cell_kind::cable, label, p_det.loc, point}); @@ -210,17 +219,54 @@ std::vector generate_network_connections(const recipe& rec, } } } break; - case cell_kind::lif: { - // TODO - for (const auto& gid: group.gids) {} - } break; - case cell_kind::benchmark: { - // TODO - for (const auto& gid: group.gids) {} - } break; - case cell_kind::spike_source: { - // TODO - for (const auto& gid: group.gids) {} + default: { + // Assuming all other cell types do not have a morphology. We can use label resolution + // through factory and set local position to 0. + auto factory = cell_kind_implementation(group.kind, group.backend, *ctx, 0); + + // We only need the label ranges + cell_label_range sources, destinations; + std::ignore = factory(group.gids, rec, sources, destinations); + + std::size_t source_label_offset = 0; + std::size_t destination_label_offset = 0; + for (std::size_t i = 0; i < group.gids.size(); ++i) { + const auto gid = group.gids[i]; + const auto iso = rec.get_cell_isometry(gid); + const auto point = iso.apply(mpoint{0.0, 0.0, 0.0, 0.0}); + const auto num_source_labels = sources.sizes().at(i); + const auto num_destination_labels = destinations.sizes().at(i); + + // Iterate over each source label for current gid + for (std::size_t j = source_label_offset; + j < source_label_offset + num_source_labels; + ++j) { + const auto& label = sources.labels().at(j); + const auto& range = sources.ranges().at(j); + for (auto lid = range.begin; lid < range.end; ++lid) { + if (selection.select_source(group.kind, gid, label)) { + src_sites.insert({gid, lid, group.kind, label, {0, 0.0}, point}); + } + } + } + + // Iterate over each destination label for current gid + for (std::size_t j = destination_label_offset; + j < destination_label_offset + num_destination_labels; + ++j) { + const auto& label = destinations.labels().at(j); + const auto& range = destinations.ranges().at(j); + for (auto lid = range.begin; lid < range.end; ++lid) { + if (selection.select_destination(group.kind, gid, label)) { + dest_sites.insert({gid, lid, group.kind, label, {0, 0.0}, point}); + } + } + } + + source_label_offset += num_source_labels; + destination_label_offset += num_destination_labels; + } + } break; } } diff --git a/arbor/network_generation.hpp b/arbor/network_generation.hpp index b19c97c3fe..bff19d4ccb 100644 --- a/arbor/network_generation.hpp +++ b/arbor/network_generation.hpp @@ -13,7 +13,7 @@ namespace arb { std::vector generate_network_connections(const recipe& rec, - const distributed_context& distributed, + const context& ctx, const domain_decomposition& dom_dec); } // namespace arb From 443ea8e2664ce636a25a67c97f1926d781b5fdb4 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 26 Mar 2023 18:55:19 +0200 Subject: [PATCH 18/84] now using network description to generate connections --- arbor/communication/communicator.cpp | 40 +++++++-- arbor/communication/communicator.hpp | 7 +- arbor/include/arbor/network.hpp | 4 + arbor/include/arbor/recipe.hpp | 8 +- arbor/include/arbor/simulation.hpp | 2 +- arbor/network.cpp | 92 ++++++++++++++++++++- arbor/simulation.cpp | 8 +- arborio/label_parse.cpp | 15 +++- python/network.cpp | 7 +- python/recipe.cpp | 2 + python/recipe.hpp | 14 +++- test/unit-distributed/test_communicator.cpp | 9 +- 12 files changed, 175 insertions(+), 33 deletions(-) diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp index 9bffc2703f..ad2b5d26fc 100644 --- a/arbor/communication/communicator.cpp +++ b/arbor/communication/communicator.cpp @@ -18,21 +18,22 @@ #include "util/partition.hpp" #include "util/rangeutil.hpp" #include "util/span.hpp" +#include "network_generation.hpp" #include "communication/communicator.hpp" namespace arb { -communicator::communicator(const recipe& rec, - const domain_decomposition& dom_dec, - execution_context& ctx): num_total_cells_{rec.num_cells()}, - num_local_cells_{dom_dec.num_local_cells()}, - num_local_groups_{dom_dec.num_groups()}, - num_domains_{(cell_size_type) ctx.distributed->size()}, - distributed_{ctx.distributed}, - thread_pool_{ctx.thread_pool} {} +communicator::communicator(const recipe& rec, const domain_decomposition& dom_dec, context ctx): + num_total_cells_{rec.num_cells()}, + num_local_cells_{dom_dec.num_local_cells()}, + num_local_groups_{dom_dec.num_groups()}, + num_domains_{(cell_size_type)ctx->distributed->size()}, + ctx_(ctx), + distributed_{ctx->distributed}, + thread_pool_{ctx->thread_pool} {} -void communicator::update_connections(const connectivity& rec, +void communicator::update_connections(const recipe& rec, const domain_decomposition& dom_dec, const label_resolution_map& source_resolution_map, const label_resolution_map& target_resolution_map) { @@ -42,6 +43,10 @@ void communicator::update_connections(const connectivity& rec, index_divisions_.clear(); index_on_domain_.clear(); + + // Construct connections from high-level specification + auto generated_connections = generate_network_connections(rec, ctx_, dom_dec); + // For caching information about each cell struct gid_info { using connection_list = decltype(std::declval().connections_on(0)); @@ -76,6 +81,7 @@ void communicator::update_connections(const connectivity& rec, // Build the connection information for local cells in parallel. cell_local_size_type n_cons = util::sum_by(gid_infos, [](const gid_info& g){ return g.conns.size(); }); + n_cons += generated_connections.size(); std::vector src_domains; src_domains.reserve(n_cons); std::vector src_counts(num_domains_); @@ -91,6 +97,15 @@ void communicator::update_connections(const connectivity& rec, src_counts[src]++; } } + for (const auto& c: generated_connections) { + auto sgid = c.source.gid; + if (sgid >= num_total_cells_) { + throw arb::bad_connection_source_gid(c.source.gid, sgid, num_total_cells_); + } + const auto src = dom_dec.gid_domain(sgid); + src_domains.push_back(src); + src_counts[src]++; + } // Construct the connections. // The loop above gave the information required to construct in place @@ -111,6 +126,13 @@ void communicator::update_connections(const connectivity& rec, {c.source.gid, src_lid}, {cell.gid, tgt_lid}, c.weight, c.delay}; } } + for (const auto& c: generated_connections) { + auto offset = offsets[*src_domain]++; + ++src_domain; + connections_[offset] = c; + } + + // Build cell partition by group for passing events to cell groups index_part_ = util::make_partition(index_divisions_, diff --git a/arbor/communication/communicator.hpp b/arbor/communication/communicator.hpp index 850ce0f119..b1c83b7d67 100644 --- a/arbor/communication/communicator.hpp +++ b/arbor/communication/communicator.hpp @@ -32,9 +32,7 @@ class ARB_ARBOR_API communicator { public: communicator() = default; - explicit communicator(const recipe& rec, - const domain_decomposition& dom_dec, - execution_context& ctx); + explicit communicator(const recipe& rec, const domain_decomposition& dom_dec, context ctx); /// The range of event queues that belong to cells in group i. std::pair group_queue_range(cell_size_type i); @@ -70,7 +68,7 @@ class ARB_ARBOR_API communicator { void reset(); - void update_connections(const connectivity& rec, + void update_connections(const recipe& rec, const domain_decomposition& dom_dec, const label_resolution_map& source_resolution_map, const label_resolution_map& target_resolution_map); @@ -86,6 +84,7 @@ class ARB_ARBOR_API communicator { util::partition_view_type> index_part_; std::unordered_map index_on_domain_; + context ctx_; distributed_context_handle distributed_; task_system_handle thread_pool_; std::uint64_t num_spikes_ = 0u; diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 27c30fd69d..78136afd77 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -79,6 +79,10 @@ class ARB_SYMBOL_VISIBLE network_selection { static network_selection destination_gid(cell_gid_type gid_begin, cell_gid_type gid_end); + static network_selection ring(std::vector gids); + + static network_selection ring(cell_gid_type gid_begin, cell_gid_type gid_end); + static network_selection intersect(network_selection left, network_selection right); static network_selection join(network_selection left, network_selection right); diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index 24b9b2d335..69cccd83ec 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -68,10 +68,6 @@ struct ARB_ARBOR_API has_gap_junctions { virtual std::vector gap_junctions_on(cell_gid_type) const { return {}; } - // Optional network descriptions for generating cell connections - virtual std::optional network_description() const { - return std::nullopt; - }; virtual ~has_gap_junctions() {} }; @@ -79,6 +75,10 @@ struct ARB_ARBOR_API has_synapses { virtual std::vector connections_on(cell_gid_type) const { return {}; } + // Optional network descriptions for generating cell connections + virtual std::optional network_description() const { + return std::nullopt; + }; virtual ~has_synapses() {} }; diff --git a/arbor/include/arbor/simulation.hpp b/arbor/include/arbor/simulation.hpp index 4acdadafca..def7e2960d 100644 --- a/arbor/include/arbor/simulation.hpp +++ b/arbor/include/arbor/simulation.hpp @@ -45,7 +45,7 @@ class ARB_ARBOR_API simulation { static simulation_builder create(recipe const &); - void update(const connectivity& rec); + void update(const recipe& rec); void reset(); diff --git a/arbor/network.cpp b/arbor/network.cpp index 956667c821..6e771dcf17 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -242,7 +242,7 @@ struct network_selection_source_gid_impl: public network_selection_impl { struct network_selection_source_gid_range_impl: public network_selection_impl { cell_gid_type gid_begin, gid_end; - explicit network_selection_source_gid_range_impl(cell_gid_type gid_begin, + network_selection_source_gid_range_impl(cell_gid_type gid_begin, cell_gid_type gid_end): gid_begin(gid_begin), gid_end(gid_end) {} @@ -268,7 +268,7 @@ struct network_selection_source_gid_range_impl: public network_selection_impl { struct network_selection_destination_gid_impl: public network_selection_impl { std::vector sorted_gids; - explicit network_selection_destination_gid_impl(std::vector gids): + network_selection_destination_gid_impl(std::vector gids): sorted_gids(std::move(gids)) { std::sort(sorted_gids.begin(), sorted_gids.end()); } @@ -294,7 +294,7 @@ struct network_selection_destination_gid_impl: public network_selection_impl { struct network_selection_destination_gid_range_impl: public network_selection_impl { cell_gid_type gid_begin, gid_end; - explicit network_selection_destination_gid_range_impl(cell_gid_type gid_begin, + network_selection_destination_gid_range_impl(cell_gid_type gid_begin, cell_gid_type gid_end): gid_begin(gid_begin), gid_end(gid_end) {} @@ -317,6 +317,80 @@ struct network_selection_destination_gid_range_impl: public network_selection_im } }; +struct network_selection_ring_impl: public network_selection_impl { + std::vector gids; // preserved order of ring + std::vector sorted_gids; + network_selection_ring_impl(std::vector gids): gids(std::move(gids)) { + sorted_gids = this->gids; // copy + std::sort(sorted_gids.begin(), sorted_gids.end()); + } + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + if(gids.empty()) return false; + + // gids size always > 0 frome here on + + // First check if both are part of ring + if (!std::binary_search(sorted_gids.begin(), sorted_gids.end(), src.gid) || + !std::binary_search(sorted_gids.begin(), sorted_gids.end(), dest.gid)) + return false; + + for(std::size_t i = 0; i < gids.size() - 1; ++i) { + // return true if neighbors in gids list + if ((src.gid == gids[i] && dest.gid == gids[i + 1])) return true; + } + + // return true if front and back gid to close ring + if ((dest.gid == gids.front() && src.gid == gids.back())) return true; + + return false; + } + + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); + } + + bool select_destination(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); + } +}; + + +struct network_selection_ring_range_impl: public network_selection_impl { + cell_gid_type gid_begin, gid_end; + + network_selection_ring_range_impl(cell_gid_type gid_begin, cell_gid_type gid_end): + gid_begin(gid_begin), + gid_end(gid_end) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + if (src.gid < gid_begin || src.gid >= gid_end) return false; + if (dest.gid < gid_begin || dest.gid >= gid_end) return false; + + return src.gid + 1 == dest.gid || (dest.gid == gid_begin && src.gid == gid_end - 1); + } + + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + if (gid < gid_begin || gid >= gid_end) return false; + return true; + } + + bool select_destination(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + if (gid < gid_begin || gid >= gid_end) return false; + return true; + } +}; + struct network_selection_complement_impl: public network_selection_impl { std::shared_ptr selection; @@ -945,6 +1019,18 @@ network_selection network_selection::destination_gid(cell_gid_type gid_begin, std::make_shared(gid_begin, gid_end)); } + +network_selection network_selection::ring(std::vector gids) { + return network_selection( + std::make_shared(std::move(gids))); +} + +network_selection network_selection::ring(cell_gid_type gid_begin, + cell_gid_type gid_end) { + return network_selection( + std::make_shared(gid_begin, gid_end)); +} + network_selection network_selection::complement(network_selection s) { return network_selection( std::make_shared(std::move(s.impl_))); diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp index 887462d12e..4cc447ad8d 100644 --- a/arbor/simulation.cpp +++ b/arbor/simulation.cpp @@ -92,7 +92,7 @@ class simulation_state { public: simulation_state(const recipe& rec, const domain_decomposition& decomp, context ctx, arb_seed_type seed); - void update(const connectivity& rec); + void update(const recipe& rec); void reset(); @@ -222,12 +222,12 @@ simulation_state::simulation_state( source_resolution_map_ = label_resolution_map(std::move(global_sources)); target_resolution_map_ = label_resolution_map(std::move(local_targets)); - communicator_ = communicator(rec, ddc_, *ctx_); + communicator_ = communicator(rec, ddc_, ctx_); update(rec); epoch_.reset(); } -void simulation_state::update(const connectivity& rec) { +void simulation_state::update(const recipe& rec) { communicator_.update_connections(rec, ddc_, source_resolution_map_, target_resolution_map_); // Use half minimum delay of the network for max integration interval. t_interval_ = communicator_.min_delay()/2; @@ -530,7 +530,7 @@ void simulation::reset() { impl_->reset(); } -void simulation::update(const connectivity& rec) { impl_->update(rec); } +void simulation::update(const recipe& rec) { impl_->update(rec); } time_type simulation::run(time_type tfinal, time_type dt) { if (dt <= 0.0) { diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index 453841d315..b98a906369 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -267,6 +267,7 @@ eval_map_type network_eval_map{ {"destination-cell-kind", make_call(arb::network_selection::destination_cell_kind, "all destinations of cells matching given cell kind argument: (kind:cell-kind)")}, + //TODO source / destination label {"source-gid", make_call( [](int gid) { @@ -304,7 +305,19 @@ eval_map_type network_eval_map{ [](gid_range_label range) { return arb::network_selection::destination_gid(range.gid_begin, range.gid_end); }, - "All destinations of cells within gid range: (range: gid-range)")}, + "all destinations of cells within gid range: (range: gid-range)")}, + {"ring", + make_call( + [](gid_list_label list) { return arb::network_selection::ring(std::move(list.gids)); }, + "Only select connections between neighboring gids in list or between first and last " + "entry: (gids: gid-list)")}, + {"ring", + make_call( + [](gid_range_label range) { + return arb::network_selection::ring(range.gid_begin, range.gid_end); + }, + "Only select connections between neighboring gids in range [begin, end) or between " + "first and last range member: (range: gid-range)")}, {"random-bernoulli", make_call(arb::network_selection::random_bernoulli, "randomly selected with given seed and probability. 2 arguments: (seed:integer, " diff --git a/python/network.cpp b/python/network.cpp index a135b48527..7c7c4f1102 100644 --- a/python/network.cpp +++ b/python/network.cpp @@ -95,7 +95,12 @@ void register_network(py::module& m) { [&](const arb::network_value& val) { dict.set(dict_label, val); }), v); } - return arb::network_description(); + auto desc = arb::network_description{ + arborio::parse_network_selection_expression(selection).unwrap(), + arborio::parse_network_value_expression(weight).unwrap(), + arborio::parse_network_value_expression(delay).unwrap(), + dict}; + return desc; }), "selection"_a, "weight"_a, diff --git a/python/recipe.cpp b/python/recipe.cpp index e982f5a943..6fc8eb826d 100644 --- a/python/recipe.cpp +++ b/python/recipe.cpp @@ -202,6 +202,8 @@ void register_recipe(pybind11::module& m) { "A list of the gap junctions connected to gid, [] by default.") .def("network_description", &py_recipe::network_description, "Network description of cell connections.") + .def("cell_isometry", &py_recipe::cell_isometry, + "Isometry describing translation and rotation of cell.") .def("probes", &py_recipe::probes, "gid"_a, "The probes to allow monitoring.") diff --git a/python/recipe.hpp b/python/recipe.hpp index c8fbc1df45..2d5f28eda1 100644 --- a/python/recipe.hpp +++ b/python/recipe.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -53,6 +54,9 @@ class py_recipe { virtual std::optional network_description() const { return std::nullopt; }; + virtual arb::isometry cell_isometry(arb::cell_gid_type gid) const { + return arb::isometry(); + }; }; class py_recipe_trampoline: public py_recipe { @@ -82,9 +86,13 @@ class py_recipe_trampoline: public py_recipe { } std::optional network_description() const override { - PYBIND11_OVERRIDE_PURE(arb::network_description, py_recipe, network_description); + PYBIND11_OVERRIDE(std::optional, py_recipe, network_description); } + arb::isometry cell_isometry(arb::cell_gid_type gid) const override { + PYBIND11_OVERRIDE(arb::isometry, py_recipe, cell_isometry, gid); + }; + std::vector probes(arb::cell_gid_type gid) const override { PYBIND11_OVERRIDE(std::vector, py_recipe, probes, gid); } @@ -142,6 +150,10 @@ class py_recipe_shim: public arb::recipe { std::optional network_description() const override { return try_catch_pyexception([&]() { return impl_->network_description(); }, msg); }; + + arb::isometry get_cell_isometry(arb::cell_gid_type gid) const override { + return try_catch_pyexception([&]() { return impl_->cell_isometry(gid); }, msg); + }; }; } // namespace pyarb diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp index 5884a31142..4149ea8a50 100644 --- a/test/unit-distributed/test_communicator.cpp +++ b/test/unit-distributed/test_communicator.cpp @@ -531,7 +531,7 @@ TEST(communicator, ring) auto global_sources = g_context->distributed->gather_cell_labels_and_gids(local_sources); // construct the communicator - auto C = communicator(R, D, *g_context); + auto C = communicator(R, D, g_context); C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map(local_targets)); // every cell fires EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return true;})); @@ -638,7 +638,7 @@ TEST(communicator, all2all) auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, mc_gids}); // construct the communicator - auto C = communicator(R, D, *g_context); + auto C = communicator(R, D, g_context); C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, mc_gids})); auto connections = C.connections(); @@ -648,7 +648,6 @@ TEST(communicator, all2all) EXPECT_EQ(i, c.source.gid); EXPECT_EQ(0u, c.source.index); EXPECT_EQ(i, c.destination); - EXPECT_LT(c.index_on_domain, n_local); } } @@ -685,13 +684,13 @@ TEST(communicator, mini_network) auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, gids}); // construct the communicator - auto C = communicator(R, D, *g_context); + auto C = communicator(R, D, g_context); C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, gids})); // sort connections by source then target auto connections = C.connections(); util::sort(connections, [](const connection& lhs, const connection& rhs) { - return std::forward_as_tuple(lhs.source, lhs.index_on_domain, lhs.destination) < std::forward_as_tuple(rhs.source, rhs.index_on_domain, rhs.destination); + return std::forward_as_tuple(lhs.source, lhs.destination) < std::forward_as_tuple(rhs.source, rhs.destination); }); // Expect one set of 22 connections from every rank: these have been sorted. From 8c47e9686972a29316de102239aa0e7935d79da2 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 2 Apr 2023 10:54:32 +0200 Subject: [PATCH 19/84] chain instead of ring --- arbor/include/arbor/network.hpp | 10 ++- arbor/network.cpp | 115 +++++++++++++++++-------- arborio/label_parse.cpp | 147 +++++++++++++++----------------- 3 files changed, 154 insertions(+), 118 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 78136afd77..c850a885a2 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -73,15 +73,17 @@ class ARB_SYMBOL_VISIBLE network_selection { static network_selection source_gid(std::vector gids); - static network_selection source_gid(cell_gid_type gid_begin, cell_gid_type gid_end); + static network_selection source_gid_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); static network_selection destination_gid(std::vector gids); - static network_selection destination_gid(cell_gid_type gid_begin, cell_gid_type gid_end); + static network_selection destination_gid_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); - static network_selection ring(std::vector gids); + static network_selection chain(std::vector gids); - static network_selection ring(cell_gid_type gid_begin, cell_gid_type gid_end); + static network_selection chain_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); + + static network_selection reverse_chain_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); static network_selection intersect(network_selection left, network_selection right); diff --git a/arbor/network.cpp b/arbor/network.cpp index 6e771dcf17..85bcf7bc2e 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -240,22 +240,24 @@ struct network_selection_source_gid_impl: public network_selection_impl { struct network_selection_source_gid_range_impl: public network_selection_impl { - cell_gid_type gid_begin, gid_end; + cell_gid_type gid_begin, gid_end, step; network_selection_source_gid_range_impl(cell_gid_type gid_begin, - cell_gid_type gid_end): + cell_gid_type gid_end, + cell_gid_type step): gid_begin(gid_begin), - gid_end(gid_end) {} + gid_end(gid_end), + step(step) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return src.gid >= gid_begin && src.gid < gid_end; + return src.gid >= gid_begin && src.gid < gid_end && !((src.gid - gid_begin) % step); } bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return gid >= gid_begin && gid < gid_end; + return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); } bool select_destination(cell_kind kind, @@ -292,16 +294,18 @@ struct network_selection_destination_gid_impl: public network_selection_impl { }; struct network_selection_destination_gid_range_impl: public network_selection_impl { - cell_gid_type gid_begin, gid_end; + cell_gid_type gid_begin, gid_end, step; network_selection_destination_gid_range_impl(cell_gid_type gid_begin, - cell_gid_type gid_end): + cell_gid_type gid_end, + cell_gid_type step): gid_begin(gid_begin), - gid_end(gid_end) {} + gid_end(gid_end), + step(step) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return dest.gid >= gid_begin && dest.gid < gid_end; + return dest.gid >= gid_begin && dest.gid < gid_end && !((dest.gid - gid_begin) % step); } bool select_source(cell_kind kind, @@ -313,14 +317,14 @@ struct network_selection_destination_gid_range_impl: public network_selection_im bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return gid >= gid_begin && gid < gid_end; + return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); } }; -struct network_selection_ring_impl: public network_selection_impl { +struct network_selection_chain_impl: public network_selection_impl { std::vector gids; // preserved order of ring std::vector sorted_gids; - network_selection_ring_impl(std::vector gids): gids(std::move(gids)) { + network_selection_chain_impl(std::vector gids): gids(std::move(gids)) { sorted_gids = this->gids; // copy std::sort(sorted_gids.begin(), sorted_gids.end()); } @@ -341,9 +345,6 @@ struct network_selection_ring_impl: public network_selection_impl { if ((src.gid == gids[i] && dest.gid == gids[i + 1])) return true; } - // return true if front and back gid to close ring - if ((dest.gid == gids.front() && src.gid == gids.back())) return true; - return false; } @@ -361,33 +362,71 @@ struct network_selection_ring_impl: public network_selection_impl { }; -struct network_selection_ring_range_impl: public network_selection_impl { - cell_gid_type gid_begin, gid_end; +struct network_selection_chain_range_impl: public network_selection_impl { + cell_gid_type gid_begin, gid_end, step; - network_selection_ring_range_impl(cell_gid_type gid_begin, cell_gid_type gid_end): + network_selection_chain_range_impl(cell_gid_type gid_begin, + cell_gid_type gid_end, + cell_gid_type step): gid_begin(gid_begin), - gid_end(gid_end) {} + gid_end(gid_end), + step(step) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - if (src.gid < gid_begin || src.gid >= gid_end) return false; - if (dest.gid < gid_begin || dest.gid >= gid_end) return false; + if (src.gid < gid_begin || src.gid >= gid_end || dest.gid < gid_begin || + dest.gid >= gid_end) + return false; - return src.gid + 1 == dest.gid || (dest.gid == gid_begin && src.gid == gid_end - 1); + return src.gid + step == dest.gid && !((src.gid - gid_begin) % step); } bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { if (gid < gid_begin || gid >= gid_end) return false; - return true; + return !((gid - gid_begin) % step); } bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { if (gid < gid_begin || gid >= gid_end) return false; - return true; + return !((gid - gid_begin) % step); + } +}; + +struct network_selection_reverse_chain_range_impl: public network_selection_impl { + cell_gid_type gid_begin, gid_end, step; + + network_selection_reverse_chain_range_impl(cell_gid_type gid_begin, + cell_gid_type gid_end, + cell_gid_type step): + gid_begin(gid_begin), + gid_end(gid_end), + step(step) {} + + bool select_connection(const network_site_info& src, + const network_site_info& dest) const override { + if (src.gid < gid_begin || src.gid >= gid_end || dest.gid < gid_begin || + dest.gid >= gid_end) + return false; + + return dest.gid + step == src.gid && !((src.gid - gid_begin) % step); + } + + bool select_source(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + if (gid < gid_begin || gid >= gid_end) return false; + return !((gid - gid_begin) % step); + } + + bool select_destination(cell_kind kind, + cell_gid_type gid, + const std::string_view& label) const override { + if (gid < gid_begin || gid >= gid_end) return false; + return !((gid - gid_begin) % step); } }; @@ -1004,8 +1043,9 @@ network_selection network_selection::source_gid(std::vector gids) return network_selection(std::make_shared(std::move(gids))); } -network_selection network_selection::source_gid(cell_gid_type gid_begin, cell_gid_type gid_end) { - return network_selection(std::make_shared(gid_begin, gid_end)); +network_selection network_selection::source_gid_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step) { + return network_selection( + std::make_shared(gid_begin, gid_end, step)); } network_selection network_selection::destination_gid(std::vector gids) { @@ -1013,22 +1053,29 @@ network_selection network_selection::destination_gid(std::vector std::make_shared(std::move(gids))); } -network_selection network_selection::destination_gid(cell_gid_type gid_begin, - cell_gid_type gid_end) { +network_selection network_selection::destination_gid_range(cell_gid_type gid_begin, + cell_gid_type gid_end, cell_gid_type step) { return network_selection( - std::make_shared(gid_begin, gid_end)); + std::make_shared(gid_begin, gid_end, step)); } +network_selection network_selection::chain(std::vector gids) { + return network_selection( + std::make_shared(std::move(gids))); +} -network_selection network_selection::ring(std::vector gids) { +network_selection network_selection::chain_range(cell_gid_type gid_begin, + cell_gid_type gid_end, + cell_gid_type step) { return network_selection( - std::make_shared(std::move(gids))); + std::make_shared(gid_begin, gid_end, step)); } -network_selection network_selection::ring(cell_gid_type gid_begin, - cell_gid_type gid_end) { +network_selection network_selection::reverse_chain_range(cell_gid_type gid_begin, + cell_gid_type gid_end, + cell_gid_type step) { return network_selection( - std::make_shared(gid_begin, gid_end)); + std::make_shared(gid_begin, gid_end, step)); } network_selection network_selection::complement(network_selection s) { diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index b98a906369..de50bf6709 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -22,25 +22,6 @@ label_parse_error::label_parse_error(const std::string& msg, const arb::src_loca {} namespace { -struct gid_list_label { - gid_list_label() = default; - - gid_list_label(cell_gid_type gid): gids({gid}) {} - - std::vector gids; -}; - -struct gid_range_label { - gid_range_label() = default; - - gid_range_label(cell_gid_type gid_begin, cell_gid_type gid_end): - gid_begin(gid_begin), - gid_end(gid_end) {} - - cell_gid_type gid_begin = 0; - cell_gid_type gid_end = 0; -}; - using eval_map_type= std::unordered_multimap; eval_map_type eval_map { @@ -215,23 +196,6 @@ eval_map_type network_eval_map{ {"spike-source-cell", make_call<>([]() { return arb::cell_kind::benchmark; }, "Spike source cell kind")}, - // gid structs - {"gid-range", - make_call( - [](int gid_begin, int gid_end) { return gid_range_label(gid_begin, gid_end); }, - "Range of gids in interval [begin, end): (begin, end)")}, - {"gid-list", - make_call([](int gid) { return gid_list_label(gid); }, "Single gid: (gid:integer)")}, - {"gid-list", - make_conversion_fold( - [](gid_list_label a, gid_list_label b) { - a.gids.insert(a.gids.end(), b.gids.begin(), b.gids.end()); - return a; - }, - "List of global indices with at least 2 arguments: ((gid-list | integer) (gid-list | " - "integer) [...(gid-list | " - "integer)])")}, - // network_selection {"all", make_call<>(arb::network_selection::all, "network selection of all cells and labels")}, {"none", make_call<>(arb::network_selection::none, "network selection of no cells and labels")}, @@ -267,57 +231,80 @@ eval_map_type network_eval_map{ {"destination-cell-kind", make_call(arb::network_selection::destination_cell_kind, "all destinations of cells matching given cell kind argument: (kind:cell-kind)")}, - //TODO source / destination label - {"source-gid", - make_call( - [](int gid) { - return arb::network_selection::source_gid( - std::vector({static_cast(gid)})); + // TODO source / destination label + {"source-label", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector labels; + std::transform(vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& + x) { + return std::get(x); + }); + return arb::network_selection::source_label(std::move(labels)); }, - "all sources in cell with given gid: (gid:integer)")}, - {"source-gid", - make_call( - [](gid_list_label list) { - return arb::network_selection::source_gid(std::move(list.gids)); + "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"destination-label", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector labels; + std::transform(vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& + x) { + return std::get(x); + }); + return arb::network_selection::destination_label(std::move(labels)); }, - "all sources of cells gid in list argument: (list: gid-list)")}, + "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, {"source-gid", - make_call( - [](gid_range_label range) { - return arb::network_selection::source_gid(range.gid_begin, range.gid_end); - }, - "All sources of cells within gid range: (range: gid-range)")}, - {"destination-gid", - make_call( - [](int gid) { - return arb::network_selection::destination_gid( - std::vector({static_cast(gid)})); - }, - "all destinations in cell with given gid: (gid:integer)")}, - {"destination-gid", - make_call( - [](gid_list_label list) { - return arb::network_selection::destination_gid(std::move(list.gids)); + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector gids; + std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::source_gid(std::move(gids)); }, - "all destinations of cells gid in list argument: (list: gid-list)")}, + "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"source-gid-range", + make_call(arb::network_selection::source_gid_range, + "all sources in cell with gid range [begin, end) with given step size: (begin:integer) " + "(end:integer) (step:integer)")}, {"destination-gid", - make_call( - [](gid_range_label range) { - return arb::network_selection::destination_gid(range.gid_begin, range.gid_end); + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector gids; + std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::destination_gid(std::move(gids)); }, - "all destinations of cells within gid range: (range: gid-range)")}, - {"ring", - make_call( - [](gid_list_label list) { return arb::network_selection::ring(std::move(list.gids)); }, - "Only select connections between neighboring gids in list or between first and last " - "entry: (gids: gid-list)")}, - {"ring", - make_call( - [](gid_range_label range) { - return arb::network_selection::ring(range.gid_begin, range.gid_end); + "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"destination-gid-range", + make_call(arb::network_selection::destination_gid_range, + "all destinations in cell with gid range [begin, end) with given step size: " + "(begin:integer) (end:integer) (step:integer)")}, + {"chain", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector gids; + std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::chain(std::move(gids)); }, - "Only select connections between neighboring gids in range [begin, end) or between " - "first and last range member: (range: gid-range)")}, + "A chain of connections in the given order of gids in list, such that entry \"i\" is " + "the source and entry \"i+1\" the destination: (gid:integer) [...(gid:integer)]")}, + {"chain-range", + make_call(arb::network_selection::chain_range, + "A chain of connections for all gids in range [begin, end) with given step size. Each " + "entry \"i\" is connected as source to the destination \"i+1\": (begin:integer) " + "(end:integer) (step:integer)")}, + {"reverse-chain-range", + make_call(arb::network_selection::reverse_chain_range, + "A chain of connections for all gids in range [begin, end) with given step size. Each " + "entry \"i+1\" is connected as source to the destination \"i\". This results in " + "connection directions in reverse compared to the (chain-range ...) selection: " + "(begin:integer) " + "(end:integer) (step:integer)")}, {"random-bernoulli", make_call(arb::network_selection::random_bernoulli, "randomly selected with given seed and probability. 2 arguments: (seed:integer, " From f1b351cc1a2469a3bb61e32cc96e5eae45acb864 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 2 Apr 2023 19:58:12 +0200 Subject: [PATCH 20/84] rename selection labels --- arbor/include/arbor/network.hpp | 8 ++++---- arbor/network.cpp | 32 ++++++++++++++++---------------- arborio/label_parse.cpp | 17 ++++++++--------- python/network.cpp | 1 + 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index c850a885a2..83ab9ef984 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -71,13 +71,13 @@ class ARB_SYMBOL_VISIBLE network_selection { static network_selection destination_label(std::vector labels); - static network_selection source_gid(std::vector gids); + static network_selection source_cell(std::vector gids); - static network_selection source_gid_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); + static network_selection source_cell_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); - static network_selection destination_gid(std::vector gids); + static network_selection destination_cell(std::vector gids); - static network_selection destination_gid_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); + static network_selection destination_cell_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); static network_selection chain(std::vector gids); diff --git a/arbor/network.cpp b/arbor/network.cpp index 85bcf7bc2e..fd31eb396e 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -212,10 +212,10 @@ struct network_selection_destination_label_impl: public network_selection_impl { } }; -struct network_selection_source_gid_impl: public network_selection_impl { +struct network_selection_source_cell_impl: public network_selection_impl { std::vector sorted_gids; - explicit network_selection_source_gid_impl(std::vector gids): + explicit network_selection_source_cell_impl(std::vector gids): sorted_gids(std::move(gids)) { std::sort(sorted_gids.begin(), sorted_gids.end()); } @@ -239,10 +239,10 @@ struct network_selection_source_gid_impl: public network_selection_impl { }; -struct network_selection_source_gid_range_impl: public network_selection_impl { +struct network_selection_source_cell_range_impl: public network_selection_impl { cell_gid_type gid_begin, gid_end, step; - network_selection_source_gid_range_impl(cell_gid_type gid_begin, + network_selection_source_cell_range_impl(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step): gid_begin(gid_begin), @@ -267,10 +267,10 @@ struct network_selection_source_gid_range_impl: public network_selection_impl { } }; -struct network_selection_destination_gid_impl: public network_selection_impl { +struct network_selection_destination_cell_impl: public network_selection_impl { std::vector sorted_gids; - network_selection_destination_gid_impl(std::vector gids): + network_selection_destination_cell_impl(std::vector gids): sorted_gids(std::move(gids)) { std::sort(sorted_gids.begin(), sorted_gids.end()); } @@ -293,10 +293,10 @@ struct network_selection_destination_gid_impl: public network_selection_impl { } }; -struct network_selection_destination_gid_range_impl: public network_selection_impl { +struct network_selection_destination_cell_range_impl: public network_selection_impl { cell_gid_type gid_begin, gid_end, step; - network_selection_destination_gid_range_impl(cell_gid_type gid_begin, + network_selection_destination_cell_range_impl(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step): gid_begin(gid_begin), @@ -1039,24 +1039,24 @@ network_selection network_selection::destination_label(std::vector(std::move(labels))); } -network_selection network_selection::source_gid(std::vector gids) { - return network_selection(std::make_shared(std::move(gids))); +network_selection network_selection::source_cell(std::vector gids) { + return network_selection(std::make_shared(std::move(gids))); } -network_selection network_selection::source_gid_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step) { +network_selection network_selection::source_cell_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step) { return network_selection( - std::make_shared(gid_begin, gid_end, step)); + std::make_shared(gid_begin, gid_end, step)); } -network_selection network_selection::destination_gid(std::vector gids) { +network_selection network_selection::destination_cell(std::vector gids) { return network_selection( - std::make_shared(std::move(gids))); + std::make_shared(std::move(gids))); } -network_selection network_selection::destination_gid_range(cell_gid_type gid_begin, +network_selection network_selection::destination_cell_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step) { return network_selection( - std::make_shared(gid_begin, gid_end, step)); + std::make_shared(gid_begin, gid_end, step)); } network_selection network_selection::chain(std::vector gids) { diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index de50bf6709..769627675d 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -231,7 +231,6 @@ eval_map_type network_eval_map{ {"destination-cell-kind", make_call(arb::network_selection::destination_cell_kind, "all destinations of cells matching given cell kind argument: (kind:cell-kind)")}, - // TODO source / destination label {"source-label", make_arg_vec_call( [](const std::vector>& vec) { @@ -254,32 +253,32 @@ eval_map_type network_eval_map{ return arb::network_selection::destination_label(std::move(labels)); }, "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"source-gid", + {"source-cell", make_arg_vec_call( [](const std::vector>& vec) { std::vector gids; std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { return std::get(x); }); - return arb::network_selection::source_gid(std::move(gids)); + return arb::network_selection::source_cell(std::move(gids)); }, "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"source-gid-range", - make_call(arb::network_selection::source_gid_range, + {"source-cell-range", + make_call(arb::network_selection::source_cell_range, "all sources in cell with gid range [begin, end) with given step size: (begin:integer) " "(end:integer) (step:integer)")}, - {"destination-gid", + {"destination-cell", make_arg_vec_call( [](const std::vector>& vec) { std::vector gids; std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { return std::get(x); }); - return arb::network_selection::destination_gid(std::move(gids)); + return arb::network_selection::destination_cell(std::move(gids)); }, "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"destination-gid-range", - make_call(arb::network_selection::destination_gid_range, + {"destination-cell-range", + make_call(arb::network_selection::destination_cell_range, "all destinations in cell with gid range [begin, end) with given step size: " "(begin:integer) (end:integer) (step:integer)")}, {"chain", diff --git a/python/network.cpp b/python/network.cpp index 7c7c4f1102..242bab2ee9 100644 --- a/python/network.cpp +++ b/python/network.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include From b993f2f442d5e6ebb5fb9cfc2ff5f6411997f596 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 11 Apr 2023 13:52:39 +0200 Subject: [PATCH 21/84] gid_range added to common types --- arbor/include/arbor/common_types.hpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/arbor/include/arbor/common_types.hpp b/arbor/include/arbor/common_types.hpp index 4cf6a41d6d..4a636614b6 100644 --- a/arbor/include/arbor/common_types.hpp +++ b/arbor/include/arbor/common_types.hpp @@ -66,6 +66,19 @@ struct lid_range { begin(b), end(e) {} }; +// Global range of indices with given step size. + +struct gid_range { + cell_gid_type begin = 0; + cell_gid_type end = 0; + cell_gid_type step = 1; + gid_range() = default; + gid_range(cell_gid_type b, cell_gid_type e): + begin(b), end(e), step(1) {} + gid_range(cell_gid_type b, cell_gid_type e, cell_gid_type s): + begin(b), end(e), step(s) {} +}; + // Policy for selecting a cell_lid_type from a range of possible values. enum class lid_selection_policy { From 35dc0edb97a0d5b4c8dfe560a75c729a68a3b6b4 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Thu, 13 Apr 2023 13:10:25 +0200 Subject: [PATCH 22/84] renaming and test label parsing --- arbor/include/arbor/network.hpp | 20 +- arbor/network.cpp | 326 ++++++++++----- arbor/network_impl.hpp | 5 + arborio/include/arborio/label_parse.hpp | 3 +- arborio/label_parse.cpp | 500 +++++++++++++----------- python/network.cpp | 71 ++-- test/unit/test_s_expr.cpp | 82 ++++ 7 files changed, 658 insertions(+), 349 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 83ab9ef984..d2ad6b9dc2 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -12,12 +12,12 @@ #include #include #include +#include #include #include +#include #include #include -#include -#include namespace arb { @@ -73,17 +73,17 @@ class ARB_SYMBOL_VISIBLE network_selection { static network_selection source_cell(std::vector gids); - static network_selection source_cell_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); + static network_selection source_cell(gid_range range); static network_selection destination_cell(std::vector gids); - static network_selection destination_cell_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); + static network_selection destination_cell(gid_range range); static network_selection chain(std::vector gids); - static network_selection chain_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); + static network_selection chain(gid_range range); - static network_selection reverse_chain_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step); + static network_selection chain_reverse(gid_range range); static network_selection intersect(network_selection left, network_selection right); @@ -110,11 +110,11 @@ class ARB_SYMBOL_VISIBLE network_selection { // only select within given distance. This may enable more efficient sampling through an // internal spatial data structure. - static network_selection distance_lt(double distance); + static network_selection distance_lt(double d); // only select if distance greater then given distance. This may enable more efficient sampling // through an internal spatial data structure. - static network_selection distance_gt(double distance); + static network_selection distance_gt(double d); // randomly selected with a probability linearly interpolated between [p_begin, p_end] based on // the distance in the interval [distance_begin, distance_end]. @@ -124,6 +124,8 @@ class ARB_SYMBOL_VISIBLE network_selection { double distance_end, double p_end); + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_selection& s); + private: network_selection(std::shared_ptr impl); @@ -173,6 +175,8 @@ class ARB_SYMBOL_VISIBLE network_value { // "func" must be symmetric (func(a,b) = func(b,a)). static network_value custom(custom_func_type func); + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_value& v); + private: network_value(std::shared_ptr impl); diff --git a/arbor/network.cpp b/arbor/network.cpp index fd31eb396e..abf0fad636 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -71,10 +71,6 @@ double normal_rand_from_key_pair(std::array seed, return r123::boxmuller(rand_num[0], rand_num[1]).x; } -double network_location_distance(const mpoint& a, const mpoint& b) { - return std::sqrt(a.x * b.x + a.y * b.y + a.z * b.z); -} - struct network_selection_all_impl: public network_selection_impl { bool select_connection(const network_site_info& src, const network_site_info& dest) const override { @@ -92,6 +88,8 @@ struct network_selection_all_impl: public network_selection_impl { const std::string_view& label) const override { return true; } + + void print(std::ostream& os) const override { os << "(all)"; } }; struct network_selection_none_impl: public network_selection_impl { @@ -112,6 +110,8 @@ struct network_selection_none_impl: public network_selection_impl { const std::string_view& label) const override { return false; } + + void print(std::ostream& os) const override { os << "(none)"; } }; struct network_selection_source_cell_kind_impl: public network_selection_impl { @@ -135,6 +135,17 @@ struct network_selection_source_cell_kind_impl: public network_selection_impl { const std::string_view& label) const override { return true; } + + void print(std::ostream& os) const override { + os << "(source-cell-kind ("; + switch (select_kind) { + case arb::cell_kind::spike_source: os << "spike-source"; break; + case arb::cell_kind::cable: os << "cable"; break; + case arb::cell_kind::lif: os << "lif"; break; + case arb::cell_kind::benchmark: os << "benchmark"; break; + } + os << "-cell))"; + } }; struct network_selection_destination_cell_kind_impl: public network_selection_impl { @@ -158,6 +169,17 @@ struct network_selection_destination_cell_kind_impl: public network_selection_im const std::string_view& label) const override { return kind == select_kind; } + + void print(std::ostream& os) const override { + os << "(destination-cell-kind ("; + switch (select_kind) { + case arb::cell_kind::spike_source: os << "spike-source"; break; + case arb::cell_kind::cable: os << "cable"; break; + case arb::cell_kind::lif: os << "lif"; break; + case arb::cell_kind::benchmark: os << "benchmark"; break; + } + os << "-cell))"; + } }; struct network_selection_source_label_impl: public network_selection_impl { @@ -184,6 +206,12 @@ struct network_selection_source_label_impl: public network_selection_impl { const std::string_view& label) const override { return true; } + + void print(std::ostream& os) const override { + os << "(source-label"; + for (const auto& l: sorted_labels) { os << " \"" << l << "\""; } + os << ")"; + } }; struct network_selection_destination_label_impl: public network_selection_impl { @@ -210,6 +238,12 @@ struct network_selection_destination_label_impl: public network_selection_impl { const std::string_view& label) const override { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), label); } + + void print(std::ostream& os) const override { + os << "(destination-label"; + for (const auto& l: sorted_labels) { os << " \"" << l << "\""; } + os << ")"; + } }; struct network_selection_source_cell_impl: public network_selection_impl { @@ -236,18 +270,21 @@ struct network_selection_source_cell_impl: public network_selection_impl { const std::string_view& label) const override { return true; } -}; + void print(std::ostream& os) const override { + os << "(source-cell"; + for (const auto& g: sorted_gids) { os << " " << g; } + os << ")"; + } +}; struct network_selection_source_cell_range_impl: public network_selection_impl { cell_gid_type gid_begin, gid_end, step; - network_selection_source_cell_range_impl(cell_gid_type gid_begin, - cell_gid_type gid_end, - cell_gid_type step): - gid_begin(gid_begin), - gid_end(gid_end), - step(step) {} + network_selection_source_cell_range_impl(gid_range r): + gid_begin(r.begin), + gid_end(r.end), + step(r.step) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { @@ -265,6 +302,10 @@ struct network_selection_source_cell_range_impl: public network_selection_impl { const std::string_view& label) const override { return true; } + + void print(std::ostream& os) const override { + os << "(source-cell (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; + } }; struct network_selection_destination_cell_impl: public network_selection_impl { @@ -291,17 +332,21 @@ struct network_selection_destination_cell_impl: public network_selection_impl { const std::string_view& label) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } + + void print(std::ostream& os) const override { + os << "(destination-cell"; + for (const auto& g: sorted_gids) { os << " " << g; } + os << ")"; + } }; struct network_selection_destination_cell_range_impl: public network_selection_impl { cell_gid_type gid_begin, gid_end, step; - network_selection_destination_cell_range_impl(cell_gid_type gid_begin, - cell_gid_type gid_end, - cell_gid_type step): - gid_begin(gid_begin), - gid_end(gid_end), - step(step) {} + network_selection_destination_cell_range_impl(gid_range r): + gid_begin(r.begin), + gid_end(r.end), + step(r.step) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { @@ -319,19 +364,23 @@ struct network_selection_destination_cell_range_impl: public network_selection_i const std::string_view& label) const override { return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); } + + void print(std::ostream& os) const override { + os << "(destination-cell (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; + } }; struct network_selection_chain_impl: public network_selection_impl { - std::vector gids; // preserved order of ring + std::vector gids; // preserved order of ring std::vector sorted_gids; network_selection_chain_impl(std::vector gids): gids(std::move(gids)) { - sorted_gids = this->gids; // copy + sorted_gids = this->gids; // copy std::sort(sorted_gids.begin(), sorted_gids.end()); } bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - if(gids.empty()) return false; + if (gids.empty()) return false; // gids size always > 0 frome here on @@ -340,7 +389,7 @@ struct network_selection_chain_impl: public network_selection_impl { !std::binary_search(sorted_gids.begin(), sorted_gids.end(), dest.gid)) return false; - for(std::size_t i = 0; i < gids.size() - 1; ++i) { + for (std::size_t i = 0; i < gids.size() - 1; ++i) { // return true if neighbors in gids list if ((src.gid == gids[i] && dest.gid == gids[i + 1])) return true; } @@ -359,18 +408,21 @@ struct network_selection_chain_impl: public network_selection_impl { const std::string_view& label) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } -}; + void print(std::ostream& os) const override { + os << "(chain"; + for (const auto& g: gids) { os << " " << g; } + os << ")"; + } +}; struct network_selection_chain_range_impl: public network_selection_impl { cell_gid_type gid_begin, gid_end, step; - network_selection_chain_range_impl(cell_gid_type gid_begin, - cell_gid_type gid_end, - cell_gid_type step): - gid_begin(gid_begin), - gid_end(gid_end), - step(step) {} + network_selection_chain_range_impl(gid_range r): + gid_begin(r.begin), + gid_end(r.end), + step(r.step) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { @@ -394,17 +446,19 @@ struct network_selection_chain_range_impl: public network_selection_impl { if (gid < gid_begin || gid >= gid_end) return false; return !((gid - gid_begin) % step); } + + void print(std::ostream& os) const override { + os << "(chain (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; + } }; struct network_selection_reverse_chain_range_impl: public network_selection_impl { cell_gid_type gid_begin, gid_end, step; - network_selection_reverse_chain_range_impl(cell_gid_type gid_begin, - cell_gid_type gid_end, - cell_gid_type step): - gid_begin(gid_begin), - gid_end(gid_end), - step(step) {} + network_selection_reverse_chain_range_impl(gid_range r): + gid_begin(r.begin), + gid_end(r.end), + step(r.step) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { @@ -428,6 +482,10 @@ struct network_selection_reverse_chain_range_impl: public network_selection_impl if (gid < gid_begin || gid >= gid_end) return false; return !((gid - gid_begin) % step); } + + void print(std::ostream& os) const override { + os << "(chain-reverse (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; + } }; struct network_selection_complement_impl: public network_selection_impl { @@ -456,47 +514,56 @@ struct network_selection_complement_impl: public network_selection_impl { } void initialize(const network_label_dict& dict) override { selection->initialize(dict); }; + + void print(std::ostream& os) const override { + os << "(complement "; + selection->print(os); + os << ")"; + } }; struct network_selection_named_impl: public network_selection_impl { using impl_pointer_type = std::shared_ptr; - std::variant selection; + impl_pointer_type selection; + std::string selection_name; - explicit network_selection_named_impl(std::string name): selection(std::move(name)) {} + explicit network_selection_named_impl(std::string name): selection_name(std::move(name)) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - if (!std::holds_alternative(selection)) + if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); - return std::get(selection)->select_connection(src, dest); + return selection->select_connection(src, dest); } bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - if (!std::holds_alternative(selection)) + if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); - return std::get(selection)->select_source(kind, gid, label); + return selection->select_source(kind, gid, label); } bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - if (!std::holds_alternative(selection)) + if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); - return std::get(selection)->select_destination(kind, gid, label); + return selection->select_destination(kind, gid, label); } void initialize(const network_label_dict& dict) override { - if (std::holds_alternative(selection)) { - auto s = dict.selection(std::get(selection)); - if (!s.has_value()) - throw arbor_exception(std::string("Network selection with label \"") + - std::get(selection) + "\" not found."); - selection = thingify(s.value(), dict); - } + auto s = dict.selection(selection_name); + if (!s.has_value()) + throw arbor_exception( + std::string("Network selection with label \"") + selection_name + "\" not found."); + selection = thingify(s.value(), dict); }; + + void print(std::ostream& os) const override { + os << "(network-selection \"" << selection_name << "\")"; + } }; struct network_selection_inter_cell_impl: public network_selection_impl { @@ -516,6 +583,8 @@ struct network_selection_inter_cell_impl: public network_selection_impl { const std::string_view& label) const override { return true; } + + void print(std::ostream& os) const override { os << "(inter-cell)"; } }; struct network_selection_custom_impl: public network_selection_impl { @@ -540,16 +609,18 @@ struct network_selection_custom_impl: public network_selection_impl { const std::string_view& label) const override { return true; } + + void print(std::ostream& os) const override { os << "(custom-network-selection)"; } }; struct network_selection_distance_lt_impl: public network_selection_impl { - double distance; + double d; - explicit network_selection_distance_lt_impl(double distance): distance(distance) {} + explicit network_selection_distance_lt_impl(double d): d(d) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return network_location_distance(src.global_location, dest.global_location) < distance; + return distance(src.global_location, dest.global_location) < d; } bool select_source(cell_kind kind, @@ -564,17 +635,19 @@ struct network_selection_distance_lt_impl: public network_selection_impl { return true; } - std::optional max_distance() const override { return distance; } + std::optional max_distance() const override { return d; } + + void print(std::ostream& os) const override { os << "(distance-lt " << d << ")"; } }; struct network_selection_distance_gt_impl: public network_selection_impl { - double distance; + double d; - explicit network_selection_distance_gt_impl(double distance): distance(distance) {} + explicit network_selection_distance_gt_impl(double d): d(d) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return network_location_distance(src.global_location, dest.global_location) > distance; + return distance(src.global_location, dest.global_location) > d; } bool select_source(cell_kind kind, @@ -588,6 +661,8 @@ struct network_selection_distance_gt_impl: public network_selection_impl { const std::string_view& label) const override { return true; } + + void print(std::ostream& os) const override { os << "(distance-gt " << d << ")"; } }; struct network_selection_random_bernoulli_impl: public network_selection_impl { @@ -614,6 +689,10 @@ struct network_selection_random_bernoulli_impl: public network_selection_impl { const std::string_view& label) const override { return true; } + + void print(std::ostream& os) const override { + os << "(random-bernoulli " << seed << " " << probability << ")"; + } }; struct network_selection_random_linear_distance_impl: public network_selection_impl { @@ -641,14 +720,12 @@ struct network_selection_random_linear_distance_impl: public network_selection_i bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - const double distance = - network_location_distance(src.global_location, dest.global_location); + const double d = distance(src.global_location, dest.global_location); - if (distance < distance_begin || distance > distance_end) return false; + if (d < distance_begin || d > distance_end) return false; - const double p = - (p_begin * (distance_end - distance) + p_end * (distance - distance_begin)) / - (distance_end - distance_begin); + const double p = (p_begin * (distance_end - d) + p_end * (d - distance_begin)) / + (distance_end - distance_begin); return uniform_rand_from_key_pair( {unsigned(network_seed::selection_linear_bernoulli), seed}, @@ -669,6 +746,11 @@ struct network_selection_random_linear_distance_impl: public network_selection_i } std::optional max_distance() const override { return distance_end; } + + void print(std::ostream& os) const override { + os << "(random-linear-distance " << seed << " " << distance_begin << " " << p_begin << " " + << distance_end << " " << p_end << ")"; + } }; struct network_selection_intersect_impl: public network_selection_impl { @@ -712,6 +794,14 @@ struct network_selection_intersect_impl: public network_selection_impl { left->initialize(dict); right->initialize(dict); }; + + void print(std::ostream& os) const override { + os << "(intersect "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } }; struct network_selection_join_impl: public network_selection_impl { @@ -753,6 +843,14 @@ struct network_selection_join_impl: public network_selection_impl { left->initialize(dict); right->initialize(dict); }; + + void print(std::ostream& os) const override { + os << "(join "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } }; struct network_selection_symmetric_difference_impl: public network_selection_impl { @@ -794,6 +892,14 @@ struct network_selection_symmetric_difference_impl: public network_selection_imp left->initialize(dict); right->initialize(dict); }; + + void print(std::ostream& os) const override { + os << "(symmetric-difference "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } }; struct network_selection_difference_impl: public network_selection_impl { @@ -833,6 +939,14 @@ struct network_selection_difference_impl: public network_selection_impl { left->initialize(dict); right->initialize(dict); }; + + void print(std::ostream& os) const override { + os << "(difference "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } }; struct network_value_scalar_impl: public network_value_impl { @@ -843,6 +957,8 @@ struct network_value_scalar_impl: public network_value_impl { double get(const network_site_info& src, const network_site_info& dest) const override { return value; } + + void print(std::ostream& os) const override { os << "(scalar " << value << ")"; } }; struct network_value_uniform_distribution_impl: public network_value_impl { @@ -865,6 +981,10 @@ struct network_value_uniform_distribution_impl: public network_value_impl { return (range[1] - range[0]) * rand_num + range[0]; } + + void print(std::ostream& os) const override { + os << "(uniform-distribution " << seed << " " << range[0] << " " << range[1] << ")"; + } }; struct network_value_normal_distribution_impl: public network_value_impl { @@ -882,6 +1002,10 @@ struct network_value_normal_distribution_impl: public network_value_impl { normal_rand_from_key_pair( {unsigned(network_seed::value_normal), seed}, src.hash, dest.hash); } + + void print(std::ostream& os) const override { + os << "(normal-distribution " << seed << " " << mean << " " << std_deviation << ")"; + } }; struct network_value_truncated_normal_distribution_impl: public network_value_impl { @@ -920,6 +1044,11 @@ struct network_value_truncated_normal_distribution_impl: public network_value_im return value; } + + void print(std::ostream& os) const override { + os << "(truncated-normal-distribution " << seed << " " << mean << " " << std_deviation + << " " << range[0] << " " << range[1] << ")"; + } }; struct network_value_custom_impl: public network_value_impl { @@ -930,30 +1059,34 @@ struct network_value_custom_impl: public network_value_impl { double get(const network_site_info& src, const network_site_info& dest) const override { return func(src, dest); } + + void print(std::ostream& os) const override { os << "(custom-network-value)"; } }; struct network_value_named_impl: public network_value_impl { using impl_pointer_type = std::shared_ptr; - std::variant value; + impl_pointer_type value; + std::string value_name; - explicit network_value_named_impl(std::string name): value(std::move(name)) {} + explicit network_value_named_impl(std::string name): value_name(std::move(name)) {} double get(const network_site_info& src, const network_site_info& dest) const override { - if (!std::holds_alternative(value)) - throw arbor_internal_error("Trying to use unitialized named network value."); - return std::get(value)->get(src, dest); + if (!value) throw arbor_internal_error("Trying to use unitialized named network value."); + return value->get(src, dest); } void initialize(const network_label_dict& dict) override { - if (std::holds_alternative(value)) { - auto s = dict.value(std::get(value)); - if (!s.has_value()) - throw arbor_exception(std::string("Network value with label \"") + - std::get(value) + "\" not found."); - value = thingify(s.value(), dict); - } + auto v = dict.value(value_name); + if (!v.has_value()) + throw arbor_exception( + std::string("Network value with label \"") + value_name + "\" not found."); + value = thingify(v.value(), dict); }; + + void print(std::ostream& os) const override { + os << "(network-value \"" << value_name << "\")"; + } }; } // namespace @@ -1043,9 +1176,8 @@ network_selection network_selection::source_cell(std::vector gids return network_selection(std::make_shared(std::move(gids))); } -network_selection network_selection::source_cell_range(cell_gid_type gid_begin, cell_gid_type gid_end, cell_gid_type step) { - return network_selection( - std::make_shared(gid_begin, gid_end, step)); +network_selection network_selection::source_cell(gid_range range) { + return network_selection(std::make_shared(range)); } network_selection network_selection::destination_cell(std::vector gids) { @@ -1053,29 +1185,21 @@ network_selection network_selection::destination_cell(std::vector std::make_shared(std::move(gids))); } -network_selection network_selection::destination_cell_range(cell_gid_type gid_begin, - cell_gid_type gid_end, cell_gid_type step) { +network_selection network_selection::destination_cell(gid_range range) { return network_selection( - std::make_shared(gid_begin, gid_end, step)); + std::make_shared(range)); } network_selection network_selection::chain(std::vector gids) { - return network_selection( - std::make_shared(std::move(gids))); + return network_selection(std::make_shared(std::move(gids))); } -network_selection network_selection::chain_range(cell_gid_type gid_begin, - cell_gid_type gid_end, - cell_gid_type step) { - return network_selection( - std::make_shared(gid_begin, gid_end, step)); +network_selection network_selection::chain(gid_range range) { + return network_selection(std::make_shared(range)); } -network_selection network_selection::reverse_chain_range(cell_gid_type gid_begin, - cell_gid_type gid_end, - cell_gid_type step) { - return network_selection( - std::make_shared(gid_begin, gid_end, step)); +network_selection network_selection::chain_reverse(gid_range range) { + return network_selection(std::make_shared(range)); } network_selection network_selection::complement(network_selection s) { @@ -1095,12 +1219,12 @@ network_selection network_selection::custom(custom_func_type func) { return network_selection(std::make_shared(std::move(func))); } -network_selection network_selection::distance_lt(double distance) { - return network_selection(std::make_shared(distance)); +network_selection network_selection::distance_lt(double d) { + return network_selection(std::make_shared(d)); } -network_selection network_selection::distance_gt(double distance) { - return network_selection(std::make_shared(distance)); +network_selection network_selection::distance_gt(double d) { + return network_selection(std::make_shared(d)); } network_selection network_selection::random_linear_distance(unsigned seed, @@ -1168,4 +1292,14 @@ std::optional network_label_dict::value(const std::string& name) return std::nullopt; } +ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_selection& s) { + if (s.impl_) s.impl_->print(os); + return os; +} + +ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_value& v) { + if (v.impl_) v.impl_->print(os); + return os; +} + } // namespace arb diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp index 71c2e32a64..6a4420368f 100644 --- a/arbor/network_impl.hpp +++ b/arbor/network_impl.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -36,6 +37,8 @@ struct network_selection_impl { virtual void initialize(const network_label_dict& dict) {}; + virtual void print(std::ostream& os) const = 0; + virtual ~network_selection_impl() = default; }; @@ -51,6 +54,8 @@ struct network_value_impl { virtual void initialize(const network_label_dict& dict) {}; + virtual void print(std::ostream& os) const = 0; + virtual ~network_value_impl() = default; }; diff --git a/arborio/include/arborio/label_parse.hpp b/arborio/include/arborio/label_parse.hpp index 88895401e7..3e0f8e42f9 100644 --- a/arborio/include/arborio/label_parse.hpp +++ b/arborio/include/arborio/label_parse.hpp @@ -29,7 +29,8 @@ ARB_ARBORIO_API parse_label_hopefully parse_label_expression(const arb ARB_ARBORIO_API parse_label_hopefully parse_region_expression(const std::string& s); ARB_ARBORIO_API parse_label_hopefully parse_locset_expression(const std::string& s); ARB_ARBORIO_API parse_label_hopefully parse_iexpr_expression(const std::string& s); - +ARB_ARBORIO_API parse_label_hopefully parse_network_selection_expression(const std::string& s); +ARB_ARBORIO_API parse_label_hopefully parse_network_value_expression(const std::string& s); ARB_ARBORIO_API parse_label_hopefully parse_network_selection_expression(const std::string& s); ARB_ARBORIO_API parse_label_hopefully parse_network_value_expression( const std::string& s); diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index 769627675d..797eb724e3 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -18,105 +18,125 @@ namespace arborio { label_parse_error::label_parse_error(const std::string& msg, const arb::src_location& loc): - arb::arbor_exception(concat("error in label description: ", msg," at :", loc.line, ":", loc.column)) -{} + arb::arbor_exception( + concat("error in label description: ", msg, " at :", loc.line, ":", loc.column)) {} namespace { -using eval_map_type= std::unordered_multimap; +using eval_map_type = std::unordered_multimap; -eval_map_type eval_map { +eval_map_type eval_map{ // Functions that return regions - {"region-nil", make_call<>(arb::reg::nil, - "'region-nil' with 0 arguments")}, - {"all", make_call<>(arb::reg::all, - "'all' with 0 arguments")}, - {"tag", make_call(arb::reg::tagged, - "'tag' with 1 argment: (tag_id:integer)")}, - {"segment", make_call(arb::reg::segment, - "'segment' with 1 argment: (segment_id:integer)")}, - {"branch", make_call(arb::reg::branch, - "'branch' with 1 argument: (branch_id:integer)")}, - {"cable", make_call(arb::reg::cable, - "'cable' with 3 arguments: (branch_id:integer prox:real dist:real)")}, - {"region", make_call(arb::reg::named, - "'region' with 1 argument: (name:string)")}, - {"distal-interval", make_call(arb::reg::distal_interval, - "'distal-interval' with 2 arguments: (start:locset extent:real)")}, - {"distal-interval", make_call( - [](arb::locset ls){return arb::reg::distal_interval(std::move(ls), std::numeric_limits::max());}, - "'distal-interval' with 1 argument: (start:locset)")}, - {"proximal-interval", make_call(arb::reg::proximal_interval, - "'proximal-interval' with 2 arguments: (start:locset extent:real)")}, - {"proximal-interval", make_call( - [](arb::locset ls){return arb::reg::proximal_interval(std::move(ls), std::numeric_limits::max());}, - "'proximal_interval' with 1 argument: (start:locset)")}, - {"complete", make_call(arb::reg::complete, - "'complete' with 1 argment: (reg:region)")}, - {"radius-lt", make_call(arb::reg::radius_lt, - "'radius-lt' with 2 arguments: (reg:region radius:real)")}, - {"radius-le", make_call(arb::reg::radius_le, - "'radius-le' with 2 arguments: (reg:region radius:real)")}, - {"radius-gt", make_call(arb::reg::radius_gt, - "'radius-gt' with 2 arguments: (reg:region radius:real)")}, - {"radius-ge", make_call(arb::reg::radius_ge, - "'radius-ge' with 2 arguments: (reg:region radius:real)")}, - {"z-dist-from-root-lt", make_call(arb::reg::z_dist_from_root_lt, - "'z-dist-from-root-lt' with 1 arguments: (distance:real)")}, - {"z-dist-from-root-le", make_call(arb::reg::z_dist_from_root_le, - "'z-dist-from-root-le' with 1 arguments: (distance:real)")}, - {"z-dist-from-root-gt", make_call(arb::reg::z_dist_from_root_gt, - "'z-dist-from-root-gt' with 1 arguments: (distance:real)")}, - {"z-dist-from-root-ge", make_call(arb::reg::z_dist_from_root_ge, - "'z-dist-from-root-ge' with 1 arguments: (distance:real)")}, - {"complement", make_call(arb::complement, - "'complement' with 1 argment: (reg:region)")}, - {"difference", make_call(arb::difference, - "'difference' with 2 argments: (reg:region, reg:region)")}, - {"join", make_fold(static_cast(arb::join), - "'join' with at least 2 arguments: (region region [...region])")}, - {"intersect", make_fold(static_cast(arb::intersect), - "'intersect' with at least 2 arguments: (region region [...region])")}, + {"region-nil", make_call<>(arb::reg::nil, "'region-nil' with 0 arguments")}, + {"all", make_call<>(arb::reg::all, "'all' with 0 arguments")}, + {"tag", make_call(arb::reg::tagged, "'tag' with 1 argment: (tag_id:integer)")}, + {"segment", + make_call(arb::reg::segment, "'segment' with 1 argment: (segment_id:integer)")}, + {"branch", make_call(arb::reg::branch, "'branch' with 1 argument: (branch_id:integer)")}, + {"cable", + make_call(arb::reg::cable, + "'cable' with 3 arguments: (branch_id:integer prox:real dist:real)")}, + {"region", make_call(arb::reg::named, "'region' with 1 argument: (name:string)")}, + {"distal-interval", + make_call(arb::reg::distal_interval, + "'distal-interval' with 2 arguments: (start:locset extent:real)")}, + {"distal-interval", + make_call( + [](arb::locset ls) { + return arb::reg::distal_interval(std::move(ls), std::numeric_limits::max()); + }, + "'distal-interval' with 1 argument: (start:locset)")}, + {"proximal-interval", + make_call(arb::reg::proximal_interval, + "'proximal-interval' with 2 arguments: (start:locset extent:real)")}, + {"proximal-interval", + make_call( + [](arb::locset ls) { + return arb::reg::proximal_interval( + std::move(ls), std::numeric_limits::max()); + }, + "'proximal_interval' with 1 argument: (start:locset)")}, + {"complete", + make_call(arb::reg::complete, "'complete' with 1 argment: (reg:region)")}, + {"radius-lt", + make_call(arb::reg::radius_lt, + "'radius-lt' with 2 arguments: (reg:region radius:real)")}, + {"radius-le", + make_call(arb::reg::radius_le, + "'radius-le' with 2 arguments: (reg:region radius:real)")}, + {"radius-gt", + make_call(arb::reg::radius_gt, + "'radius-gt' with 2 arguments: (reg:region radius:real)")}, + {"radius-ge", + make_call(arb::reg::radius_ge, + "'radius-ge' with 2 arguments: (reg:region radius:real)")}, + {"z-dist-from-root-lt", + make_call(arb::reg::z_dist_from_root_lt, + "'z-dist-from-root-lt' with 1 arguments: (distance:real)")}, + {"z-dist-from-root-le", + make_call(arb::reg::z_dist_from_root_le, + "'z-dist-from-root-le' with 1 arguments: (distance:real)")}, + {"z-dist-from-root-gt", + make_call(arb::reg::z_dist_from_root_gt, + "'z-dist-from-root-gt' with 1 arguments: (distance:real)")}, + {"z-dist-from-root-ge", + make_call(arb::reg::z_dist_from_root_ge, + "'z-dist-from-root-ge' with 1 arguments: (distance:real)")}, + {"complement", + make_call(arb::complement, "'complement' with 1 argment: (reg:region)")}, + {"difference", + make_call(arb::difference, + "'difference' with 2 argments: (reg:region, reg:region)")}, + {"join", + make_fold(static_cast(arb::join), + "'join' with at least 2 arguments: (region region [...region])")}, + {"intersect", + make_fold( + static_cast(arb::intersect), + "'intersect' with at least 2 arguments: (region region [...region])")}, // Functions that return locsets - {"locset-nil", make_call<>(arb::ls::nil, - "'locset-nil' with 0 arguments")}, - {"root", make_call<>(arb::ls::root, - "'root' with 0 arguments")}, - {"location", make_call([](int bid, double pos){return arb::ls::location(arb::msize_t(bid), pos);}, - "'location' with 2 arguments: (branch_id:integer position:real)")}, - {"terminal", make_call<>(arb::ls::terminal, - "'terminal' with 0 arguments")}, - {"distal", make_call(arb::ls::most_distal, - "'distal' with 1 argument: (reg:region)")}, - {"proximal", make_call(arb::ls::most_proximal, - "'proximal' with 1 argument: (reg:region)")}, - {"distal-translate", make_call(arb::ls::distal_translate, - "'distal-translate' with 2 arguments: (ls:locset distance:real)")}, - {"proximal-translate", make_call(arb::ls::proximal_translate, - "'proximal-translate' with 2 arguments: (ls:locset distance:real)")}, - {"uniform", make_call(arb::ls::uniform, - "'uniform' with 4 arguments: (reg:region, first:int, last:int, seed:int)")}, - {"on-branches", make_call(arb::ls::on_branches, - "'on-branches' with 1 argument: (pos:double)")}, - {"on-components", make_call(arb::ls::on_components, - "'on-components' with 2 arguments: (pos:double, reg:region)")}, - {"boundary", make_call(arb::ls::boundary, - "'boundary' with 1 argument: (reg:region)")}, - {"cboundary", make_call(arb::ls::cboundary, - "'cboundary' with 1 argument: (reg:region)")}, - {"segment-boundaries", make_call<>(arb::ls::segment_boundaries, - "'segment-boundaries' with 0 arguments")}, - {"support", make_call(arb::ls::support, - "'support' with 1 argument (ls:locset)")}, - {"locset", make_call(arb::ls::named, - "'locset' with 1 argument: (name:string)")}, - {"restrict", make_call(arb::ls::restrict, - "'restrict' with 2 arguments: (ls:locset, reg:region)")}, - {"join", make_fold(static_cast(arb::join), - "'join' with at least 2 arguments: (locset locset [...locset])")}, - {"sum", make_fold(static_cast(arb::sum), - "'sum' with at least 2 arguments: (locset locset [...locset])")}, - + {"locset-nil", make_call<>(arb::ls::nil, "'locset-nil' with 0 arguments")}, + {"root", make_call<>(arb::ls::root, "'root' with 0 arguments")}, + {"location", + make_call( + [](int bid, double pos) { return arb::ls::location(arb::msize_t(bid), pos); }, + "'location' with 2 arguments: (branch_id:integer position:real)")}, + {"terminal", make_call<>(arb::ls::terminal, "'terminal' with 0 arguments")}, + {"distal", + make_call(arb::ls::most_distal, "'distal' with 1 argument: (reg:region)")}, + {"proximal", + make_call(arb::ls::most_proximal, "'proximal' with 1 argument: (reg:region)")}, + {"distal-translate", + make_call(arb::ls::distal_translate, + "'distal-translate' with 2 arguments: (ls:locset distance:real)")}, + {"proximal-translate", + make_call(arb::ls::proximal_translate, + "'proximal-translate' with 2 arguments: (ls:locset distance:real)")}, + {"uniform", + make_call(arb::ls::uniform, + "'uniform' with 4 arguments: (reg:region, first:int, last:int, seed:int)")}, + {"on-branches", + make_call(arb::ls::on_branches, "'on-branches' with 1 argument: (pos:double)")}, + {"on-components", + make_call(arb::ls::on_components, + "'on-components' with 2 arguments: (pos:double, reg:region)")}, + {"boundary", + make_call(arb::ls::boundary, "'boundary' with 1 argument: (reg:region)")}, + {"cboundary", + make_call(arb::ls::cboundary, "'cboundary' with 1 argument: (reg:region)")}, + {"segment-boundaries", + make_call<>(arb::ls::segment_boundaries, "'segment-boundaries' with 0 arguments")}, + {"support", make_call(arb::ls::support, "'support' with 1 argument (ls:locset)")}, + {"locset", make_call(arb::ls::named, "'locset' with 1 argument: (name:string)")}, + {"restrict", + make_call(arb::ls::restrict, + "'restrict' with 2 arguments: (ls:locset, reg:region)")}, + {"join", + make_fold(static_cast(arb::join), + "'join' with at least 2 arguments: (locset locset [...locset])")}, + {"sum", + make_fold(static_cast(arb::sum), + "'sum' with at least 2 arguments: (locset locset [...locset])")}, // iexpr {"iexpr", make_call(arb::iexpr::named, "iexpr with 1 argument: (value:string)")}, @@ -125,52 +145,92 @@ eval_map_type eval_map { {"pi", make_call<>(arb::iexpr::pi, "iexpr with no argument")}, - {"distance", make_call(static_cast(arb::iexpr::distance), + {"distance", + make_call( + static_cast(arb::iexpr::distance), "iexpr with 2 arguments: (scale:double, loc:locset)")}, - {"distance", make_call(static_cast(arb::iexpr::distance), + {"distance", + make_call(static_cast(arb::iexpr::distance), "iexpr with 1 argument: (loc:locset)")}, - {"distance", make_call(static_cast(arb::iexpr::distance), + {"distance", + make_call( + static_cast(arb::iexpr::distance), "iexpr with 2 arguments: (scale:double, reg:region)")}, - {"distance", make_call(static_cast(arb::iexpr::distance), + {"distance", + make_call(static_cast(arb::iexpr::distance), "iexpr with 1 argument: (reg:region)")}, - {"proximal-distance", make_call(static_cast(arb::iexpr::proximal_distance), + {"proximal-distance", + make_call( + static_cast(arb::iexpr::proximal_distance), "iexpr with 2 arguments: (scale:double, loc:locset)")}, - {"proximal-distance", make_call(static_cast(arb::iexpr::proximal_distance), + {"proximal-distance", + make_call( + static_cast(arb::iexpr::proximal_distance), "iexpr with 1 argument: (loc:locset)")}, - {"proximal-distance", make_call(static_cast(arb::iexpr::proximal_distance), + {"proximal-distance", + make_call( + static_cast(arb::iexpr::proximal_distance), "iexpr with 2 arguments: (scale:double, reg:region)")}, - {"proximal-distance", make_call(static_cast(arb::iexpr::proximal_distance), + {"proximal-distance", + make_call( + static_cast(arb::iexpr::proximal_distance), "iexpr with 1 arguments: (reg:region)")}, - {"distal-distance", make_call(static_cast(arb::iexpr::distal_distance), + {"distal-distance", + make_call( + static_cast(arb::iexpr::distal_distance), "iexpr with 2 arguments: (scale:double, loc:locset)")}, - {"distal-distance", make_call(static_cast(arb::iexpr::distal_distance), + {"distal-distance", + make_call( + static_cast(arb::iexpr::distal_distance), "iexpr with 1 argument: (loc:locset)")}, - {"distal-distance", make_call(static_cast(arb::iexpr::distal_distance), + {"distal-distance", + make_call( + static_cast(arb::iexpr::distal_distance), "iexpr with 2 arguments: (scale:double, reg:region)")}, - {"distal-distance", make_call(static_cast(arb::iexpr::distal_distance), + {"distal-distance", + make_call( + static_cast(arb::iexpr::distal_distance), "iexpr with 1 argument: (reg:region)")}, - {"interpolation", make_call(static_cast(arb::iexpr::interpolation), - "iexpr with 4 arguments: (prox_value:double, prox_list:locset, dist_value:double, dist_list:locset)")}, - {"interpolation", make_call(static_cast(arb::iexpr::interpolation), - "iexpr with 4 arguments: (prox_value:double, prox_list:region, dist_value:double, dist_list:region)")}, - - {"radius", make_call(static_cast(arb::iexpr::radius), "iexpr with 1 argument: (value:double)")}, - {"radius", make_call<>(static_cast(arb::iexpr::radius), "iexpr with no argument")}, - - {"diameter", make_call(static_cast(arb::iexpr::diameter), "iexpr with 1 argument: (value:double)")}, - {"diameter", make_call<>(static_cast(arb::iexpr::diameter), "iexpr with no argument")}, + {"interpolation", + make_call( + static_cast( + arb::iexpr::interpolation), + "iexpr with 4 arguments: (prox_value:double, prox_list:locset, dist_value:double, " + "dist_list:locset)")}, + {"interpolation", + make_call( + static_cast( + arb::iexpr::interpolation), + "iexpr with 4 arguments: (prox_value:double, prox_list:region, dist_value:double, " + "dist_list:region)")}, + + {"radius", + make_call(static_cast(arb::iexpr::radius), + "iexpr with 1 argument: (value:double)")}, + {"radius", + make_call<>(static_cast(arb::iexpr::radius), "iexpr with no argument")}, + + {"diameter", + make_call(static_cast(arb::iexpr::diameter), + "iexpr with 1 argument: (value:double)")}, + {"diameter", + make_call<>(static_cast(arb::iexpr::diameter), "iexpr with no argument")}, {"exp", make_call(arb::iexpr::exp, "iexpr with 1 argument: (value:iexpr)")}, {"exp", make_call(arb::iexpr::exp, "iexpr with 1 argument: (value:double)")}, - {"step_right", make_call(arb::iexpr::step_right, "iexpr with 1 argument: (value:iexpr)")}, - {"step_right", make_call(arb::iexpr::step_right, "iexpr with 1 argument: (value:double)")}, + {"step_right", + make_call(arb::iexpr::step_right, "iexpr with 1 argument: (value:iexpr)")}, + {"step_right", + make_call(arb::iexpr::step_right, "iexpr with 1 argument: (value:double)")}, - {"step_left", make_call(arb::iexpr::step_left, "iexpr with 1 argument: (value:iexpr)")}, - {"step_left", make_call(arb::iexpr::step_left, "iexpr with 1 argument: (value:double)")}, + {"step_left", + make_call(arb::iexpr::step_left, "iexpr with 1 argument: (value:iexpr)")}, + {"step_left", + make_call(arb::iexpr::step_left, "iexpr with 1 argument: (value:double)")}, {"step", make_call(arb::iexpr::step, "iexpr with 1 argument: (value:iexpr)")}, {"step", make_call(arb::iexpr::step, "iexpr with 1 argument: (value:double)")}, @@ -178,23 +238,44 @@ eval_map_type eval_map { {"log", make_call(arb::iexpr::log, "iexpr with 1 argument: (value:iexpr)")}, {"log", make_call(arb::iexpr::log, "iexpr with 1 argument: (value:double)")}, - {"add", make_conversion_fold(arb::iexpr::add, "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | double)])")}, - - {"sub", make_conversion_fold(arb::iexpr::sub, "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | double)])")}, - - {"mul", make_conversion_fold(arb::iexpr::mul, "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | double)])")}, - - {"div", make_conversion_fold(arb::iexpr::div, "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | double)])")}, + {"add", + make_conversion_fold(arb::iexpr::add, + "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | " + "double)])")}, + + {"sub", + make_conversion_fold(arb::iexpr::sub, + "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | " + "double)])")}, + + {"mul", + make_conversion_fold(arb::iexpr::mul, + "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | " + "double)])")}, + + {"div", + make_conversion_fold(arb::iexpr::div, + "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | " + "double)])")}, }; eval_map_type network_eval_map{ + {"gid-range", + make_call([](int begin, int end) { return arb::gid_range(begin, end); }, + "Gid range [begin, end) with step size 1: ((begin:integer) (end:integer))")}, + {"gid-range", + make_call( + [](int begin, int end, int step) { return arb::gid_range(begin, end, step); }, + "Gid range [begin, end) with step size: ((begin:integer) (end:integer) " + "(step:integer))")}, + // cell kind {"cable-cell", make_call<>([]() { return arb::cell_kind::cable; }, "Cable cell kind")}, {"lif-cell", make_call<>([]() { return arb::cell_kind::lif; }, "Lif cell kind")}, {"benchmark-cell", make_call<>([]() { return arb::cell_kind::benchmark; }, "Benchmark cell kind")}, {"spike-source-cell", - make_call<>([]() { return arb::cell_kind::benchmark; }, "Spike source cell kind")}, + make_call<>([]() { return arb::cell_kind::spike_source; }, "Spike source cell kind")}, // network_selection {"all", make_call<>(arb::network_selection::all, "network selection of all cells and labels")}, @@ -206,15 +287,15 @@ eval_map_type network_eval_map{ make_call(arb::network_selection::named, "network selection with 1 argument: (value:string)")}, {"intersect", - make_fold(arb::network_selection::intersect, + make_conversion_fold(arb::network_selection::intersect, "intersection of network selections with at least 2 arguments: " "(network_selection network_selection [...network_selection])")}, {"join", - make_fold(arb::network_selection::join, + make_conversion_fold(arb::network_selection::join, "join or union operation of network selections with at least 2 arguments: " "(network_selection network_selection [...network_selection])")}, - {"symmetric_difference", - make_fold(arb::network_selection::symmetric_difference, + {"symmetric-difference", + make_conversion_fold(arb::network_selection::symmetric_difference, "symmetric difference operation between network selections with at least 2 arguments: " "(network_selection network_selection [...network_selection])")}, {"difference", @@ -235,10 +316,10 @@ eval_map_type network_eval_map{ make_arg_vec_call( [](const std::vector>& vec) { std::vector labels; - std::transform(vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& - x) { - return std::get(x); - }); + std::transform( + vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& x) { + return std::get(x); + }); return arb::network_selection::source_label(std::move(labels)); }, "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, @@ -246,10 +327,10 @@ eval_map_type network_eval_map{ make_arg_vec_call( [](const std::vector>& vec) { std::vector labels; - std::transform(vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& - x) { - return std::get(x); - }); + std::transform( + vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& x) { + return std::get(x); + }); return arb::network_selection::destination_label(std::move(labels)); }, "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, @@ -263,10 +344,10 @@ eval_map_type network_eval_map{ return arb::network_selection::source_cell(std::move(gids)); }, "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"source-cell-range", - make_call(arb::network_selection::source_cell_range, - "all sources in cell with gid range [begin, end) with given step size: (begin:integer) " - "(end:integer) (step:integer)")}, + {"source-cell", + make_call(static_cast( + arb::network_selection::source_cell), + "all sources in cell with gid range: (range:gid-range)")}, {"destination-cell", make_arg_vec_call( [](const std::vector>& vec) { @@ -277,10 +358,11 @@ eval_map_type network_eval_map{ return arb::network_selection::destination_cell(std::move(gids)); }, "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"destination-cell-range", - make_call(arb::network_selection::destination_cell_range, - "all destinations in cell with gid range [begin, end) with given step size: " - "(begin:integer) (end:integer) (step:integer)")}, + {"destination-cell", + make_call(static_cast( + arb::network_selection::destination_cell), + "all destinations in cell with gid range: " + "(range:gid-range)")}, {"chain", make_arg_vec_call( [](const std::vector>& vec) { @@ -292,13 +374,14 @@ eval_map_type network_eval_map{ }, "A chain of connections in the given order of gids in list, such that entry \"i\" is " "the source and entry \"i+1\" the destination: (gid:integer) [...(gid:integer)]")}, - {"chain-range", - make_call(arb::network_selection::chain_range, + {"chain", + make_call( + static_cast(arb::network_selection::chain), "A chain of connections for all gids in range [begin, end) with given step size. Each " "entry \"i\" is connected as source to the destination \"i+1\": (begin:integer) " "(end:integer) (step:integer)")}, - {"reverse-chain-range", - make_call(arb::network_selection::reverse_chain_range, + {"chain-reverse", + make_call(arb::network_selection::chain_reverse, "A chain of connections for all gids in range [begin, end) with given step size. Each " "entry \"i+1\" is connected as source to the destination \"i\". This results in " "connection directions in reverse compared to the (chain-range ...) selection: " @@ -330,18 +413,18 @@ eval_map_type network_eval_map{ {"network-value", make_call(arb::network_value::named, "network value with 1 argument: (value:string)")}, - {"uniform_distribution", + {"uniform-distribution", make_call( [](unsigned seed, double begin, double end) { return arb::network_value::uniform_distribution(seed, {begin, end}); }, "Uniform random distribution within interval [begin, end): (seed:integer, begin:real, " "end:real)")}, - {"normal_distribution", + {"normal-distribution", make_call(arb::network_value::normal_distribution, "Normal random distribution with given mean and standard deviation: (seed:integer, " "mean:real, std_deviation:real)")}, - {"truncated_normal_distribution", + {"truncated-normal-distribution", make_call( [](unsigned seed, double mean, double std_deviation, double begin, double end) { return arb::network_value::truncated_normal_distribution( @@ -356,15 +439,11 @@ eval_map_type network_eval_map{ parse_label_hopefully eval(const s_expr& e, const eval_map_type& map); parse_label_hopefully> eval_args(const s_expr& e, const eval_map_type& map) { - if (!e) return {std::vector{}}; // empty argument list + if (!e) return {std::vector{}}; // empty argument list std::vector args; for (auto& h: e) { - if (auto arg=eval(h, map)) { - args.push_back(std::move(*arg)); - } - else { - return util::unexpected(std::move(arg.error())); - } + if (auto arg = eval(h, map)) { args.push_back(std::move(*arg)); } + else { return util::unexpected(std::move(arg.error())); } } return args; } @@ -378,20 +457,20 @@ parse_label_hopefully> eval_args(const s_expr& e, const ev // types (integer, real, region, locset) are inferred from the arguments. std::string eval_description(const char* name, const std::vector& args) { auto type_string = [](const std::type_info& t) -> const char* { - if (t==typeid(int)) return "integer"; - if (t==typeid(double)) return "real"; - if (t==typeid(arb::region)) return "region"; - if (t==typeid(arb::locset)) return "locset"; + if (t == typeid(int)) return "integer"; + if (t == typeid(double)) return "real"; + if (t == typeid(arb::region)) return "region"; + if (t == typeid(arb::locset)) return "locset"; return "unknown"; }; const auto nargs = args.size(); - std::string msg = concat("'", name, "' with ", nargs, "argument", nargs!=1u?"s:" : ":"); + std::string msg = concat("'", name, "' with ", nargs, "argument", nargs != 1u ? "s:" : ":"); if (nargs) { msg += " ("; bool first = true; for (auto& a: args) { - msg += concat(first?"":" ", type_string(a.type())); + msg += concat(first ? "" : " ", type_string(a.type())); first = false; } msg += ")"; @@ -412,42 +491,43 @@ std::string eval_description(const char* name, const std::vector& args // // If there was an unexpected/fatal error, an exception will be thrown. parse_label_hopefully eval(const s_expr& e, const eval_map_type& map) { - if (e.is_atom()) { - return eval_atom(e); - } + if (e.is_atom()) { return eval_atom(e); } if (e.head().is_atom()) { // This must be a function evaluation, where head is the function name, and // tail is a list of arguments. // Evaluate the arguments, and return error state if an error occurred. auto args = eval_args(e.tail(), map); - if (!args) { - return util::unexpected(args.error()); - } + if (!args) { return util::unexpected(args.error()); } // Find all candidate functions that match the name of the function. auto& name = e.head().atom().spelling; auto matches = map.equal_range(name); // Search for a candidate that matches the argument list. - for (auto i=matches.first; i!=matches.second; ++i) { - if (i->second.match_args(*args)) { // found a match: evaluate and return. + for (auto i = matches.first; i != matches.second; ++i) { + if (i->second.match_args(*args)) { // found a match: evaluate and return. return i->second.eval(*args); } } // Unable to find a match: try to return a helpful error message. const auto nc = std::distance(matches.first, matches.second); - auto msg = concat("No matches for ", eval_description(name.c_str(), *args), "\n There are ", nc, " potential candidates", nc?":":"."); + auto msg = concat("No matches for ", + eval_description(name.c_str(), *args), + "\n There are ", + nc, + " potential candidates", + nc ? ":" : "."); int count = 0; - for (auto i=matches.first; i!=matches.second; ++i) { + for (auto i = matches.first; i != matches.second; ++i) { msg += concat("\n Candidate ", ++count, " ", i->second.message); } return util::unexpected(label_parse_error(msg, location(e))); } return util::unexpected(label_parse_error( - concat("'", e, "' is not either integer, real expression of the form (op )"), - location(e))); + concat("'", e, "' is not either integer, real expression of the form (op )"), + location(e))); } } // namespace @@ -461,65 +541,47 @@ ARB_ARBORIO_API parse_label_hopefully parse_label_expression(const s_e ARB_ARBORIO_API parse_label_hopefully parse_region_expression(const std::string& s) { if (auto e = eval(parse_s_expr(s), eval_map)) { - if (e->type() == typeid(region)) { - return {std::move(std::any_cast(*e))}; - } + if (e->type() == typeid(region)) { return {std::move(std::any_cast(*e))}; } if (e->type() == typeid(std::string)) { return {reg::named(std::move(std::any_cast(*e)))}; } - return util::unexpected( - label_parse_error( - concat("Invalid region description: '", s ,"' is neither a valid region expression or region label string."))); - } - else { - return util::unexpected(label_parse_error(std::string()+e.error().what())); + return util::unexpected(label_parse_error(concat("Invalid region description: '", + s, + "' is neither a valid region expression or region label string."))); } + else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } } ARB_ARBORIO_API parse_label_hopefully parse_locset_expression(const std::string& s) { if (auto e = eval(parse_s_expr(s), eval_map)) { - if (e->type() == typeid(locset)) { - return {std::move(std::any_cast(*e))}; - } + if (e->type() == typeid(locset)) { return {std::move(std::any_cast(*e))}; } if (e->type() == typeid(std::string)) { return {ls::named(std::move(std::any_cast(*e)))}; } - return util::unexpected( - label_parse_error( - concat("Invalid region description: '", s ,"' is neither a valid locset expression or locset label string."))); - } - else { - return util::unexpected(label_parse_error(std::string()+e.error().what())); + return util::unexpected(label_parse_error(concat("Invalid region description: '", + s, + "' is neither a valid locset expression or locset label string."))); } + else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } } parse_label_hopefully parse_iexpr_expression(const std::string& s) { if (auto e = eval(parse_s_expr(s), eval_map)) { - if (e->type() == typeid(iexpr)) { - return {std::move(std::any_cast(*e))}; - } - return util::unexpected( - label_parse_error( - concat("Invalid iexpr description: '", s))); - } - else { - return util::unexpected(label_parse_error(std::string()+e.error().what())); + if (e->type() == typeid(iexpr)) { return {std::move(std::any_cast(*e))}; } + return util::unexpected(label_parse_error(concat("Invalid iexpr description: '", s))); } + else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } } - -parse_label_hopefully parse_network_selection_expression(const std::string& s) { +parse_label_hopefully parse_network_selection_expression( + const std::string& s) { if (auto e = eval(parse_s_expr(s), network_eval_map)) { if (e->type() == typeid(arb::network_selection)) { return {std::move(std::any_cast(*e))}; } - return util::unexpected( - label_parse_error( - concat("Invalid iexpr description: '", s))); - } - else { - return util::unexpected(label_parse_error(std::string()+e.error().what())); + return util::unexpected(label_parse_error(concat("Invalid iexpr description: '", s))); } + else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } } parse_label_hopefully parse_network_value_expression(const std::string& s) { @@ -527,13 +589,9 @@ parse_label_hopefully parse_network_value_expression(const s if (e->type() == typeid(arb::network_value)) { return {std::move(std::any_cast(*e))}; } - return util::unexpected( - label_parse_error( - concat("Invalid iexpr description: '", s))); - } - else { - return util::unexpected(label_parse_error(std::string()+e.error().what())); + return util::unexpected(label_parse_error(concat("Invalid iexpr description: '", s))); } + else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } } -} // namespace arborio +} // namespace arborio diff --git a/python/network.cpp b/python/network.cpp index 242bab2ee9..7331479c6f 100644 --- a/python/network.cpp +++ b/python/network.cpp @@ -14,6 +14,8 @@ #include #include "error.hpp" +#include "util.hpp" +#include "strprintf.hpp" namespace py = pybind11; @@ -29,34 +31,57 @@ void register_network(py::module& m) { .def_readwrite("kind", &arb::network_site_info::kind) .def_readwrite("label", &arb::network_site_info::label) .def_readwrite("location", &arb::network_site_info::location) - .def_readwrite("global_location", &arb::network_site_info::global_location); + .def_readwrite("global_location", &arb::network_site_info::global_location) + .def("__repr__", [](const arb::network_site_info& s) { + return util::pprintf("", + s.lid, + s.kind, + s.label, + s.location, + s.global_location); + }); py::class_ network_selection( m, "network_selection", "Network selection."); - network_selection.def_static("custom", [](arb::network_selection::custom_func_type func) { - return arb::network_selection::custom( - [=](const arb::network_site_info& src, const arb::network_site_info& dest) { - return try_catch_pyexception( - [&]() { - pybind11::gil_scoped_acquire guard; - return func(src, dest); - }, - "Python error already thrown"); - }); - }); + network_selection + .def_static("custom", + [](arb::network_selection::custom_func_type func) { + return arb::network_selection::custom( + [=](const arb::network_site_info& src, const arb::network_site_info& dest) { + return try_catch_pyexception( + [&]() { + pybind11::gil_scoped_acquire guard; + return func(src, dest); + }, + "Python error already thrown"); + }); + }) + .def("__str__", + [](const arb::network_selection& s) { + return util::pprintf("", s); + }) + .def("__repr__", [](const arb::network_selection& s) { return util::pprintf("{}", s); }); py::class_ network_value(m, "network_value", "Network value."); - network_value.def_static("custom", [](arb::network_value::custom_func_type func) { - return arb::network_value::custom( - [=](const arb::network_site_info& src, const arb::network_site_info& dest) { - return try_catch_pyexception( - [&]() { - pybind11::gil_scoped_acquire guard; - return func(src, dest); - }, - "Python error already thrown"); - }); - }); + network_value + .def_static("custom", + [](arb::network_value::custom_func_type func) { + return arb::network_value::custom( + [=](const arb::network_site_info& src, const arb::network_site_info& dest) { + return try_catch_pyexception( + [&]() { + pybind11::gil_scoped_acquire guard; + return func(src, dest); + }, + "Python error already thrown"); + }); + }) + .def("__str__", + [](const arb::network_value& v) { + return util::pprintf("", v); + }) + .def("__repr__", [](const arb::network_value& v) { return util::pprintf("{}", v); }); py::class_ network_description( m, "network_description", "Network description."); diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp index f1e9237833..1888cab937 100644 --- a/test/unit/test_s_expr.cpp +++ b/test/unit/test_s_expr.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -215,6 +216,24 @@ std::string round_trip_iexpr(const char* in) { } } +std::string round_trip_network_selection(const char* in) { + if (auto x = parse_network_selection_expression(in)) { + return util::pprintf("{}", std::any_cast(*x)); + } + else { + return x.error().what(); + } +} + +std::string round_trip_network_value(const char* in) { + if (auto x = parse_network_value_expression(in)) { + return util::pprintf("{}", std::any_cast(*x)); + } + else { + return x.error().what(); + } +} + TEST(cv_policies, round_tripping) { auto literals = {"(every-segment (tag 42))", @@ -336,6 +355,69 @@ TEST(iexpr, round_tripping) { round_trip_label("(pi)")); } +TEST(network_selection, round_tripping) { + auto network_literals = { + "(all)", + "(none)", + "(inter-cell)", + "(network-selection \"abc\")", + "(intersect (all) (none))", + "(join (all) (none))", + "(symmetric-difference (all) (none))", + "(difference (all) (none))", + "(complement (all))", + "(source-cell-kind (cable-cell))", + "(source-cell-kind (lif-cell))", + "(source-cell-kind (benchmark-cell))", + "(source-cell-kind (spike-source-cell))", + "(destination-cell-kind (cable-cell))", + "(destination-cell-kind (lif-cell))", + "(destination-cell-kind (benchmark-cell))", + "(destination-cell-kind (spike-source-cell))", + "(source-label \"abc\")", + "(source-label \"abc\" \"def\")", + "(source-label \"abc\" \"def\" \"ghi\")", + "(destination-label \"abc\")", + "(destination-label \"abc\" \"def\")", + "(destination-label \"abc\" \"def\" \"ghi\")", + "(source-cell 0 1 3 15)", + "(source-cell (gid-range 4 8 2))", + "(destination-cell 0 1 3 15)", + "(destination-cell (gid-range 4 8 2))", + "(chain 3 1 0 5 7 6)", // order should be preserved + "(chain (gid-range 2 14 3))", + "(chain-reverse (gid-range 2 14 3))", + "(random-bernoulli 42 0.1)", + "(random-linear-distance 42 2.5 0.2 5.2 0.9)", + "(distance-lt 0.5)", + "(distance-gt 0.5)", + }; + for (auto l: network_literals) { + EXPECT_EQ(l, round_trip_network_selection(l)); + } + + // test order for more than two arguments + EXPECT_EQ("(join (join (join (all) (none)) (inter-cell)) (source-cell 0))", + round_trip_network_selection("(join (all) (none) (inter-cell) (source-cell 0))")); + EXPECT_EQ("(intersect (intersect (intersect (all) (none)) (inter-cell)) (source-cell 0))", + round_trip_network_selection("(intersect (all) (none) (inter-cell) (source-cell 0))")); +} + + +TEST(network_value, round_tripping) { + auto network_literals = { + "(scalar 1.3)", + "(network-value \"abc\")", + "(uniform-distribution 42 0 0.8)", + "(normal-distribution 42 0.5 0.1)", + "(truncated-normal-distribution 42 0.5 0.1 0.3 0.7)", + }; + + for (auto l: network_literals) { + EXPECT_EQ(l, round_trip_network_value(l)); + } +} + TEST(regloc, round_tripping) { EXPECT_EQ("(cable 3 0 1)", round_trip_label("(branch 3)")); EXPECT_EQ("(intersect (tag 1) (intersect (tag 2) (tag 3)))", round_trip_label("(intersect (tag 1) (tag 2) (tag 3))")); From 7b0dbf1898a936ce5b5d1e734a20a57ce7e6aa44 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Thu, 13 Apr 2023 15:45:34 +0200 Subject: [PATCH 23/84] add math operation to network value --- arbor/include/arbor/network.hpp | 130 ++++++------ arbor/network.cpp | 352 +++++++++++++++++++++++++------- arborio/label_parse.cpp | 70 ++++++- test/unit/test_s_expr.cpp | 55 ++++- 4 files changed, 458 insertions(+), 149 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index d2ad6b9dc2..dc3478726f 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -48,6 +48,75 @@ struct network_value_impl; class ARB_SYMBOL_VISIBLE network_label_dict; +class ARB_SYMBOL_VISIBLE network_value { +public: + using custom_func_type = + std::function; + + network_value() { *this = network_value::scalar(0.0); } + + // Scalar value with conversion from double + network_value(double value) { *this = network_value::scalar(value); } + + // Scalar value. Will always return the same value given at construction. + static network_value scalar(double value); + + static network_value named(std::string name); + + static network_value distance(double scale = 1.0); + + // Uniform random value in (range[0], range[1]]. Always returns the same value for repeated + // calls with the same arguments and calls are symmetric v(a, b) = v(b, a). + static network_value uniform_distribution(unsigned seed, const std::array& range); + + // Radom value from a normal distribution with given mean and standard deviation. Always returns + // the same value for repeated calls with the same arguments and calls are symmetric v(a, b) = + // v(b, a). + static network_value normal_distribution(unsigned seed, double mean, double std_deviation); + + // Radom value from a truncated normal distribution with given mean and standard deviation (of a + // non-truncated normal distribution), where the value is always in (range[0], range[1]]. Always + // returns the same value for repeated calls with the same arguments and calls are symmetric + // v(a, b) = v(b, a). Note: Values are generated by reject-accept method from a normal + // distribution. Low acceptance rate can leed to poor performance, for example with very small + // ranges or a mean far outside the range. + static network_value truncated_normal_distribution(unsigned seed, + double mean, + double std_deviation, + const std::array& range); + + // Custom value using the provided function "func". Repeated calls with the same arguments + // to "func" must yield the same result. For gap junction values, + // "func" must be symmetric (func(a,b) = func(b,a)). + static network_value custom(custom_func_type func); + + static network_value add(network_value left, network_value right); + + static network_value sub(network_value left, network_value right); + + static network_value mul(network_value left, network_value right); + + static network_value div(network_value left, network_value right); + + static network_value exp(network_value v); + + static network_value log(network_value v); + + static network_value min(network_value left, network_value right); + + static network_value max(network_value left, network_value right); + + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_value& v); + +private: + network_value(std::shared_ptr impl); + + friend std::shared_ptr thingify(network_value v, + const network_label_dict& dict); + + std::shared_ptr impl_; +}; + class ARB_SYMBOL_VISIBLE network_selection { public: using custom_func_type = @@ -101,7 +170,7 @@ class ARB_SYMBOL_VISIBLE network_selection { // Random selection using the bernoulli random distribution with probability "p" between 0.0 // and 1.0 - static network_selection random_bernoulli(unsigned seed, double p); + static network_selection random(unsigned seed, network_value p); // Custom selection using the provided function "func". Repeated calls with the same arguments // to "func" must yield the same result. For gap junction selection, @@ -116,14 +185,6 @@ class ARB_SYMBOL_VISIBLE network_selection { // through an internal spatial data structure. static network_selection distance_gt(double d); - // randomly selected with a probability linearly interpolated between [p_begin, p_end] based on - // the distance in the interval [distance_begin, distance_end]. - static network_selection random_linear_distance(unsigned seed, - double distance_begin, - double p_begin, - double distance_end, - double p_end); - ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_selection& s); private: @@ -135,57 +196,6 @@ class ARB_SYMBOL_VISIBLE network_selection { std::shared_ptr impl_; }; -class ARB_SYMBOL_VISIBLE network_value { -public: - using custom_func_type = - std::function; - - network_value() { *this = network_value::scalar(0.0); } - - // Scalar value with conversion from double - network_value(double value) { *this = network_value::scalar(value); } - - // Scalar value. Will always return the same value given at construction. - static network_value scalar(double value); - - static network_value named(std::string name); - - // Uniform random value in (range[0], range[1]]. Always returns the same value for repeated - // calls with the same arguments and calls are symmetric v(a, b) = v(b, a). - static network_value uniform_distribution(unsigned seed, const std::array& range); - - // Radom value from a normal distribution with given mean and standard deviation. Always returns - // the same value for repeated calls with the same arguments and calls are symmetric v(a, b) = - // v(b, a). - static network_value normal_distribution(unsigned seed, double mean, double std_deviation); - - // Radom value from a truncated normal distribution with given mean and standard deviation (of a - // non-truncated normal distribution), where the value is always in (range[0], range[1]]. Always - // returns the same value for repeated calls with the same arguments and calls are symmetric - // v(a, b) = v(b, a). Note: Values are generated by reject-accept method from a normal - // distribution. Low acceptance rate can leed to poor performance, for example with very small - // ranges or a mean far outside the range. - static network_value truncated_normal_distribution(unsigned seed, - double mean, - double std_deviation, - const std::array& range); - - // Custom value using the provided function "func". Repeated calls with the same arguments - // to "func" must yield the same result. For gap junction values, - // "func" must be symmetric (func(a,b) = func(b,a)). - static network_value custom(custom_func_type func); - - ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_value& v); - -private: - network_value(std::shared_ptr impl); - - friend std::shared_ptr thingify(network_value v, - const network_label_dict& dict); - - std::shared_ptr impl_; -}; - class ARB_SYMBOL_VISIBLE network_label_dict { public: using ns_map = std::unordered_map; diff --git a/arbor/network.cpp b/arbor/network.cpp index abf0fad636..d4e59e5ebe 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -23,8 +23,7 @@ namespace { // Partial seed to use for network_value and network_selection generation. // Different seed for each type to avoid unintentional correlation. enum class network_seed : unsigned { - selection_bernoulli = 2058443, - selection_linear_bernoulli = 839033, + selection_random = 2058443, value_uniform = 48202, value_normal = 8405, value_truncated_normal = 380237, @@ -665,72 +664,21 @@ struct network_selection_distance_gt_impl: public network_selection_impl { void print(std::ostream& os) const override { os << "(distance-gt " << d << ")"; } }; -struct network_selection_random_bernoulli_impl: public network_selection_impl { +struct network_selection_random_impl: public network_selection_impl { unsigned seed; - double probability; - network_selection_random_bernoulli_impl(unsigned seed, double p): seed(seed), probability(p) {} + network_value p_value; + std::shared_ptr probability; // may be null if unitialize(...) not called - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { - return uniform_rand_from_key_pair({unsigned(network_seed::selection_bernoulli), seed}, - src.hash, - dest.hash) < probability; - } - - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { - return true; - } - - bool select_destination(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { - return true; - } - - void print(std::ostream& os) const override { - os << "(random-bernoulli " << seed << " " << probability << ")"; - } -}; - -struct network_selection_random_linear_distance_impl: public network_selection_impl { - unsigned seed; - double distance_begin; - double p_begin; - double distance_end; - double p_end; - - network_selection_random_linear_distance_impl(unsigned seed_, - double distance_begin_, - double p_begin_, - double distance_end_, - double p_end_): - seed(seed_), - distance_begin(distance_begin_), - p_begin(p_begin_), - distance_end(distance_end_), - p_end(p_end_) { - if (distance_begin > distance_end) { - std::swap(distance_begin, distance_end); - std::swap(p_begin, p_end); - } - } + network_selection_random_impl(unsigned seed, network_value p): seed(seed), p_value(std::move(p)) {} bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - const double d = distance(src.global_location, dest.global_location); - - if (d < distance_begin || d > distance_end) return false; - - const double p = (p_begin * (distance_end - d) + p_end * (d - distance_begin)) / - (distance_end - distance_begin); - - return uniform_rand_from_key_pair( - {unsigned(network_seed::selection_linear_bernoulli), seed}, + if (!probability) + throw arbor_internal_error("Trying to use unitialized named network selection."); + return uniform_rand_from_key_pair({unsigned(network_seed::selection_random), seed}, src.hash, - dest.hash) < p; + dest.hash) < probability->get(src, dest); } bool select_source(cell_kind kind, @@ -745,11 +693,14 @@ struct network_selection_random_linear_distance_impl: public network_selection_i return true; } - std::optional max_distance() const override { return distance_end; } + void initialize(const network_label_dict& dict) override { + probability = thingify(p_value, dict); + }; void print(std::ostream& os) const override { - os << "(random-linear-distance " << seed << " " << distance_begin << " " << p_begin << " " - << distance_end << " " << p_end << ")"; + os << "(random " << seed << " "; + os << p_value; + os << ")"; } }; @@ -961,6 +912,19 @@ struct network_value_scalar_impl: public network_value_impl { void print(std::ostream& os) const override { os << "(scalar " << value << ")"; } }; + +struct network_value_distance_impl: public network_value_impl { + double scale; + + network_value_distance_impl(double s): scale(s) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + return scale * distance(src.global_location, dest.global_location); + } + + void print(std::ostream& os) const override { os << "(distance " << scale << ")"; } +}; + struct network_value_uniform_distribution_impl: public network_value_impl { unsigned seed = 0; std::array range; @@ -1089,6 +1053,210 @@ struct network_value_named_impl: public network_value_impl { } }; +struct network_value_add_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_add_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + return left->get(src, dest) + right->get(src, dest); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(add "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + + +struct network_value_mul_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_mul_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + return left->get(src, dest) * right->get(src, dest); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(mul "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + + +struct network_value_sub_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_sub_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + return left->get(src, dest) - right->get(src, dest); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(sub "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_div_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_div_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + const auto v_right = right ->get(src,dest); + if (!v_right) throw arbor_exception("network_value: division by 0."); + return left->get(src, dest) / right->get(src, dest); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(div "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_max_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_max_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + return std::max(left->get(src, dest), right->get(src, dest)); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(max "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_min_impl: public network_value_impl { + std::shared_ptr left, right; + + network_value_min_impl(std::shared_ptr l, + std::shared_ptr r): + left(std::move(l)), + right(std::move(r)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + return std::min(left->get(src, dest), right->get(src, dest)); + } + + void initialize(const network_label_dict& dict) override { + left->initialize(dict); + right->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(min "; + left->print(os); + os << " "; + right->print(os); + os << ")"; + } +}; + +struct network_value_exp_impl: public network_value_impl { + std::shared_ptr value; + + network_value_exp_impl(std::shared_ptr v): + value(std::move(v)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + return std::exp(value->get(src, dest)); + } + + void initialize(const network_label_dict& dict) override { + value->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(exp "; + value->print(os); + os << ")"; + } +}; + +struct network_value_log_impl: public network_value_impl { + std::shared_ptr value; + + network_value_log_impl(std::shared_ptr v): + value(std::move(v)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + const auto v = value->get(src, dest); + if (v <= 0.0) throw arbor_exception("network_value: log of value <= 0.0."); + return std::log(value->get(src, dest)); + } + + void initialize(const network_label_dict& dict) override { + value->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(log "; + value->print(os); + os << ")"; + } +}; + } // namespace network_site_info::network_site_info(cell_gid_type gid, @@ -1211,8 +1379,9 @@ network_selection network_selection::inter_cell() { return network_selection(std::make_shared()); } -network_selection network_selection::random_bernoulli(unsigned seed, double p) { - return network_selection(std::make_shared(seed, p)); +network_selection network_selection::random(unsigned seed, network_value p) { + return network_selection( + std::make_shared(seed, std::move(p))); } network_selection network_selection::custom(custom_func_type func) { @@ -1227,21 +1396,16 @@ network_selection network_selection::distance_gt(double d) { return network_selection(std::make_shared(d)); } -network_selection network_selection::random_linear_distance(unsigned seed, - double distance_begin, - double p_begin, - double distance_end, - double p_end) { - return network_selection(std::make_shared( - seed, distance_begin, p_begin, distance_end, p_end)); -} - network_value::network_value(std::shared_ptr impl): impl_(std::move(impl)) {} network_value network_value::scalar(double value) { return network_value(std::make_shared(value)); } +network_value network_value::distance(double scale) { + return network_value(std::make_shared(scale)); +} + network_value network_value::uniform_distribution(unsigned seed, const std::array& range) { return network_value(std::make_shared(seed, range)); @@ -1278,6 +1442,44 @@ network_label_dict& network_label_dict::set(const std::string& name, network_val return *this; } +network_value network_value::add(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::sub(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::mul(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::div(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::exp(network_value v) { + return network_value(std::make_shared(std::move(v.impl_))); +} + +network_value network_value::log(network_value v) { + return network_value(std::make_shared(std::move(v.impl_))); +} + +network_value network_value::min(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + +network_value network_value::max(network_value left, network_value right) { + return network_value( + std::make_shared(std::move(left.impl_), std::move(right.impl_))); +} + std::optional network_label_dict::selection(const std::string& name) const { auto it = selections_.find(name); if (it != selections_.end()) return it->second; diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index 797eb724e3..35519d59e2 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -387,16 +387,15 @@ eval_map_type network_eval_map{ "connection directions in reverse compared to the (chain-range ...) selection: " "(begin:integer) " "(end:integer) (step:integer)")}, - {"random-bernoulli", - make_call(arb::network_selection::random_bernoulli, + {"random", + make_call(arb::network_selection::random, "randomly selected with given seed and probability. 2 arguments: (seed:integer, " "p:real)")}, - {"random-linear-distance", - make_call( - arb::network_selection::random_linear_distance, - "randomly selected with a probability linearly interpolated between [p_begin, p_end] " - "based on the distance in the interval [distance_begin, distance_end]. 5 arguments: " - "(seed:integer, distance_begin:real, p_begin:real, distance_end:real, p_end:real)")}, + {"random", + make_call(arb::network_selection::random, + "randomly selected with given seed and probability function. Any probability value is " + "clamped to [0.0, 1.0]. 2 arguments: (seed:integer, " + "p:network-value)")}, {"distance-lt", make_call(arb::network_selection::distance_lt, "Select if distance between source and destination is less than given distance in " @@ -409,10 +408,17 @@ eval_map_type network_eval_map{ // network_value {"scalar", make_call(arb::network_value::scalar, - "network value with 1 argument: (value:real)")}, + "A fixed scalar value. 1 argument: (value:real)")}, {"network-value", make_call(arb::network_value::named, - "network value with 1 argument: (value:string)")}, + "A named network value with 1 argument: (value:string)")}, + {"distance", + make_call(arb::network_value::distance, + "Distance between source and destination scaled by given value with unit [1/um]. 1 " + "argument: (scale:real)")}, + {"distance", + make_call<>([]() { return arb::network_value::distance(1.0); }, + "Distance between source and destination scaled by 1.0 with unit [1/um].")}, {"uniform-distribution", make_call( [](unsigned seed, double begin, double end) { @@ -433,7 +439,49 @@ eval_map_type network_eval_map{ "Truncated normal random distribution with given mean and standard deviation within " "interval [begin, end]: (seed:integer, mean:real, std_deviation:real, begin:real, " "end:real)")}, - + {"add", + make_conversion_fold( + arb::network_value::add, + "Sum of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"sub", + make_conversion_fold( + arb::network_value::sub, + "Subtraction of network values from the first argument with at least 2 arguments: " + "((network-value | double) (network-value | double) [...(network-value | double)])")}, + {"mul", + make_conversion_fold( + arb::network_value::mul, + "Multiplication of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"div", + make_conversion_fold( + arb::network_value::div, + "Division of the first argument by each following network value sequentially with at " + "least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"min", + make_conversion_fold( + arb::network_value::min, + "Minimum of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"max", + make_conversion_fold( + arb::network_value::max, + "Minimum of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"log", + make_call(arb::network_value::log, + "Logarithm. 1 argument: (value:real)")}, + {"log", + make_call(arb::network_value::log, + "Logarithm. 1 argument: (value:real)")}, + {"exp", + make_call(arb::network_value::exp, + "Logarithm. 1 argument: (value:real)")}, + {"exp", + make_call(arb::network_value::exp, + "Logarithm. 1 argument: (value:real)")}, }; parse_label_hopefully eval(const s_expr& e, const eval_map_type& map); diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp index 1888cab937..bae6ad209a 100644 --- a/test/unit/test_s_expr.cpp +++ b/test/unit/test_s_expr.cpp @@ -387,8 +387,8 @@ TEST(network_selection, round_tripping) { "(chain 3 1 0 5 7 6)", // order should be preserved "(chain (gid-range 2 14 3))", "(chain-reverse (gid-range 2 14 3))", - "(random-bernoulli 42 0.1)", - "(random-linear-distance 42 2.5 0.2 5.2 0.9)", + "(random 42 (scalar 0.1))", + "(random 42 (normal-distribution 43 0.5 0.1))", "(distance-lt 0.5)", "(distance-gt 0.5)", }; @@ -407,20 +407,69 @@ TEST(network_selection, round_tripping) { TEST(network_value, round_tripping) { auto network_literals = { "(scalar 1.3)", + "(distance 1.3)", "(network-value \"abc\")", "(uniform-distribution 42 0 0.8)", "(normal-distribution 42 0.5 0.1)", "(truncated-normal-distribution 42 0.5 0.1 0.3 0.7)", + "(log (scalar 1.3))", + "(exp (scalar 1.3))", }; for (auto l: network_literals) { EXPECT_EQ(l, round_trip_network_value(l)); } + + EXPECT_EQ("(log (scalar 1.3))", round_trip_network_value("(log 1.3)")); + EXPECT_EQ("(exp (scalar 1.3))", round_trip_network_value("(exp 1.3)")); + + EXPECT_EQ( + "(add (scalar -2.1) (scalar 3.1))", round_trip_network_value("(add -2.1 (scalar 3.1))")); + EXPECT_EQ("(add (add (add (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(add -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(sub (scalar -2.1) (scalar 3.1))", round_trip_network_value("(sub -2.1 (scalar 3.1))")); + EXPECT_EQ("(sub (sub (sub (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(sub -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(mul (scalar -2.1) (scalar 3.1))", round_trip_network_value("(mul -2.1 (scalar 3.1))")); + EXPECT_EQ("(mul (mul (mul (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(mul -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(div (scalar -2.1) (scalar 3.1))", round_trip_network_value("(div -2.1 (scalar 3.1))")); + EXPECT_EQ("(div (div (div (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(div -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(min (scalar -2.1) (scalar 3.1))", round_trip_network_value("(min -2.1 (scalar 3.1))")); + EXPECT_EQ("(min (min (min (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(min -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); + + EXPECT_EQ( + "(max (scalar -2.1) (scalar 3.1))", round_trip_network_value("(max -2.1 (scalar 3.1))")); + EXPECT_EQ("(max (max (max (scalar -2.1) (scalar 3.1)) (uniform-distribution 42 0 0.8)) " + "(network-value \"abc\"))", + round_trip_network_value( + "(max -2.1 (scalar 3.1) (uniform-distribution 42 0 0.8) (network-value \"abc\"))")); } TEST(regloc, round_tripping) { EXPECT_EQ("(cable 3 0 1)", round_trip_label("(branch 3)")); - EXPECT_EQ("(intersect (tag 1) (intersect (tag 2) (tag 3)))", round_trip_label("(intersect (tag 1) (tag 2) (tag 3))")); + EXPECT_EQ("(intersect (tag 1) (intersect (tag 2) (tag 3)))", + round_trip_label("(intersect (tag 1) (tag 2) (tag 3))")); auto region_literals = { "(cable 2 0.1 0.4)", "(region \"foo\")", From 343218b68f8cad3e502422172b7b5cd2ab3c6017 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Thu, 20 Apr 2023 13:13:35 +0200 Subject: [PATCH 24/84] add network unit tests --- arbor/include/arbor/network.hpp | 10 +- arbor/network.cpp | 27 +- test/unit/test_network.cpp | 805 ++++++++++++++++++++++++++++++++ 3 files changed, 827 insertions(+), 15 deletions(-) create mode 100644 test/unit/test_network.cpp diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index dc3478726f..6d9fcab003 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -132,6 +132,9 @@ class ARB_SYMBOL_VISIBLE network_selection { static network_selection named(std::string name); + // Only select connections between different cells + static network_selection inter_cell(); + static network_selection source_cell_kind(cell_kind kind); static network_selection destination_cell_kind(cell_kind kind); @@ -158,16 +161,13 @@ class ARB_SYMBOL_VISIBLE network_selection { static network_selection join(network_selection left, network_selection right); - static network_selection symmetric_difference(network_selection left, network_selection right); - static network_selection difference(network_selection left, network_selection right); + static network_selection symmetric_difference(network_selection left, network_selection right); + // Invert the selection static network_selection complement(network_selection s); - // Only select connections between different cells - static network_selection inter_cell(); - // Random selection using the bernoulli random distribution with probability "p" between 0.0 // and 1.0 static network_selection random(unsigned seed, network_value p); diff --git a/arbor/network.cpp b/arbor/network.cpp index d4e59e5ebe..acc5efd565 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -154,7 +154,7 @@ struct network_selection_destination_cell_kind_impl: public network_selection_im bool select_connection(const network_site_info& src, const network_site_info& dest) const override { - return src.kind == select_kind; + return dest.kind == select_kind; } bool select_source(cell_kind kind, @@ -399,13 +399,15 @@ struct network_selection_chain_impl: public network_selection_impl { bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); + return !sorted_gids.empty() && + std::binary_search(sorted_gids.begin(), sorted_gids.end() - 1, gid); } bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); + return !sorted_gids.empty() && + std::binary_search(sorted_gids.begin() + 1, sorted_gids.end(), gid); } void print(std::ostream& os) const override { @@ -435,14 +437,16 @@ struct network_selection_chain_range_impl: public network_selection_impl { bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - if (gid < gid_begin || gid >= gid_end) return false; + // Return false if outside range or if equal to last element, which cannot be a source + if (gid < gid_begin || gid >= gid_end - 1) return false; return !((gid - gid_begin) % step); } bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - if (gid < gid_begin || gid >= gid_end) return false; + // Return false if outside range or if equal to first element, which cannot be a destination + if (gid <= gid_begin || gid >= gid_end) return false; return !((gid - gid_begin) % step); } @@ -471,14 +475,16 @@ struct network_selection_reverse_chain_range_impl: public network_selection_impl bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - if (gid < gid_begin || gid >= gid_end) return false; + // Return false if outside range or if equal to first element, which cannot be a source + if (gid <= gid_begin || gid >= gid_end) return false; return !((gid - gid_begin) % step); } bool select_destination(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - if (gid < gid_begin || gid >= gid_end) return false; + // Return false if outside range or if equal to last element, which cannot be a destination + if (gid < gid_begin || gid >= gid_end - 1) return false; return !((gid - gid_begin) % step); } @@ -676,9 +682,10 @@ struct network_selection_random_impl: public network_selection_impl { const network_site_info& dest) const override { if (!probability) throw arbor_internal_error("Trying to use unitialized named network selection."); - return uniform_rand_from_key_pair({unsigned(network_seed::selection_random), seed}, - src.hash, - dest.hash) < probability->get(src, dest); + const auto r = uniform_rand_from_key_pair( + {unsigned(network_seed::selection_random), seed}, src.hash, dest.hash); + const auto p = (probability->get(src, dest)); + return r < p; } bool select_source(cell_kind kind, diff --git a/test/unit/test_network.cpp b/test/unit/test_network.cpp new file mode 100644 index 0000000000..148c2850cf --- /dev/null +++ b/test/unit/test_network.cpp @@ -0,0 +1,805 @@ +#include + +#include + +#include "network_impl.hpp" + +#include +#include + +using namespace arb; + +namespace { +std::vector test_sites = { + {0, 0, cell_kind::cable, "a", {1, 0.5}, {0.0, 0.0, 0.0}}, + {1, 0, cell_kind::benchmark, "b", {0, 0.0}, {1.0, 0.0, 0.0}}, + {2, 0, cell_kind::lif, "c", {0, 0.0}, {2.0, 0.0, 0.0}}, + {3, 0, cell_kind::spike_source, "d", {0, 0.0}, {3.0, 0.0, 0.0}}, + {4, 0, cell_kind::cable, "e", {0, 0.2}, {4.0, 0.0, 0.0}}, + {5, 0, cell_kind::cable, "f", {5, 0.1}, {5.0, 0.0, 0.0}}, + {6, 0, cell_kind::cable, "g", {4, 0.3}, {6.0, 0.0, 0.0}}, + {7, 0, cell_kind::cable, "h", {0, 1.0}, {7.0, 0.0, 0.0}}, + {9, 0, cell_kind::cable, "i", {0, 0.1}, {12.0, 3.0, 4.0}}, + + {10, 0, cell_kind::cable, "a", {0, 0.1}, {12.0, 15.0, 16.0}}, + {10, 1, cell_kind::cable, "b", {1, 0.1}, {13.0, 15.0, 16.0}}, + {10, 2, cell_kind::cable, "c", {1, 0.5}, {14.0, 15.0, 16.0}}, + {10, 3, cell_kind::cable, "d", {1, 1.0}, {15.0, 15.0, 16.0}}, + {10, 4, cell_kind::cable, "e", {2, 0.1}, {16.0, 15.0, 16.0}}, + {10, 5, cell_kind::cable, "f", {3, 0.1}, {16.0, 16.0, 16.0}}, + {10, 6, cell_kind::cable, "g", {4, 0.1}, {12.0, 17.0, 16.0}}, + {10, 7, cell_kind::cable, "h", {5, 0.1}, {12.0, 18.0, 16.0}}, + {10, 8, cell_kind::cable, "i", {6, 0.1}, {12.0, 19.0, 16.0}}, + + {11, 0, cell_kind::cable, "abcd", {0, 0.1}, {-2.0, -5.0, 3.0}}, + {11, 1, cell_kind::cable, "cabd", {1, 0.2}, {-2.1, -5.0, 3.0}}, + {11, 2, cell_kind::cable, "cbad", {1, 0.3}, {-2.2, -5.0, 3.0}}, + {11, 3, cell_kind::cable, "acbd", {1, 1.0}, {-2.3, -5.0, 3.0}}, + {11, 4, cell_kind::cable, "bacd", {2, 0.2}, {-2.4, -5.0, 3.0}}, + {11, 5, cell_kind::cable, "bcad", {3, 0.3}, {-2.5, -5.0, 3.0}}, + {11, 6, cell_kind::cable, "dabc", {4, 0.4}, {-2.6, -5.0, 3.0}}, + {11, 7, cell_kind::cable, "dbca", {5, 0.5}, {-2.7, -5.0, 3.0}}, + {11, 8, cell_kind::cable, "dcab", {6, 0.6}, {-2.8, -5.0, 3.0}}, +}; +} + +TEST(network_selection, all) { + const auto s = thingify(network_selection::all(), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& dest: test_sites) { EXPECT_TRUE(s->select_connection(source, dest)); } + } +} + + +TEST(network_selection, none) { + const auto s = thingify(network_selection::none(), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_FALSE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_FALSE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& dest: test_sites) { EXPECT_FALSE(s->select_connection(source, dest)); } + } +} + +TEST(network_selection, source_cell_kind) { + const auto s = + thingify(network_selection::source_cell_kind(cell_kind::benchmark), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.kind == cell_kind::benchmark, s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(source.kind == cell_kind::benchmark, s->select_connection(source, dest)); + } + } +} + + +TEST(network_selection, destination_cell_kind) { + const auto s = + thingify(network_selection::destination_cell_kind(cell_kind::benchmark), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.kind == cell_kind::benchmark, s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(dest.kind == cell_kind::benchmark, s->select_connection(source, dest)); + } + } +} + +TEST(network_selection, source_label) { + const auto s = thingify(network_selection::source_label({"b", "e"}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.label == "b" || site.label == "e", + s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& source: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ( + source.label == "b" || source.label == "e", s->select_connection(source, dest)); + } + } +} + +TEST(network_selection, destination_label) { + const auto s = thingify(network_selection::destination_label({"b", "e"}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.label == "b" || site.label == "e", + s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(dest.label == "b" || dest.label == "e", s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, source_cell_vec) { + const auto s = thingify(network_selection::source_cell({{1, 5}}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 5, s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(src.gid == 1 || src.gid == 5, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, destination_cell_vec) { + const auto s = thingify(network_selection::destination_cell({{1, 5}}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 5, s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(dest.gid == 1 || dest.gid == 5, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, source_cell_range) { + const auto s = + thingify(network_selection::source_cell(gid_range(1, 6, 4)), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 5, s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(src.gid == 1 || src.gid == 5, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, destination_cell_range) { + const auto s = + thingify(network_selection::destination_cell(gid_range(1, 6, 4)), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 5, s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(dest.gid == 1 || dest.gid == 5, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, chain) { + const auto s = + thingify(network_selection::chain({{0,2,5}}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 0 || site.gid == 2, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ( + site.gid == 2 || site.gid == 5, s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ((src.gid == 0 && dest.gid == 2) || (src.gid == 2 && dest.gid == 5), + s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, chain_range) { + const auto s = + thingify(network_selection::chain({gid_range(1,8,3)}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 1 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ( + site.gid == 4 || site.gid == 7, s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ((src.gid == 1 && dest.gid == 4) || (src.gid == 4 && dest.gid == 7), + s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, chain_range_reverse) { + const auto s = + thingify(network_selection::chain_reverse({gid_range(1,8,3)}), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ( + site.gid == 7 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ( + site.gid == 4 || site.gid == 1, s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ((src.gid == 7 && dest.gid == 4) || (src.gid == 4 && dest.gid == 1), + s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, inter_cell) { + const auto s = + thingify(network_selection::inter_cell(), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(src.gid != dest.gid, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, named) { + network_label_dict dict; + dict.set("mysel", network_selection::inter_cell()); + const auto s = + thingify(network_selection::named("mysel"), dict); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(src.gid != dest.gid, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, intersect) { + const auto s = thingify(network_selection::intersect(network_selection::source_cell({1}), + network_selection::destination_cell({2})), + network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.gid == 1, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ(site.gid == 2, s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(src.gid == 1 && dest.gid == 2, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, join) { + const auto s = thingify( + network_selection::join(network_selection::intersect(network_selection::source_cell({1}), + network_selection::destination_cell({2})), + network_selection::intersect( + network_selection::source_cell({4}), network_selection::destination_cell({5}))), + network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.gid == 1 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ(site.gid == 2 || site.gid == 5, s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ((src.gid == 1 && dest.gid == 2) || (src.gid == 4 && dest.gid == 5), + s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, difference) { + const auto s = + thingify(network_selection::difference(network_selection::source_cell({{0, 1, 2}}), + network_selection::source_cell({{1, 3}})), + network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.gid == 0 || site.gid == 1 || site.gid == 2, + s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(src.gid == 0 || src.gid == 2, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, symmetric_difference) { + const auto s = thingify( + network_selection::symmetric_difference( + network_selection::source_cell({{0, 1, 2}}), network_selection::source_cell({{1, 3}})), + network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_EQ(site.gid == 0 || site.gid == 1 || site.gid == 2 || site.gid == 3, + s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(src.gid == 0 || src.gid == 2 || src.gid == 3, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, complement) { + const auto s = thingify( + network_selection::complement(network_selection::inter_cell()), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(src.gid == dest.gid, s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, random_p_1) { + const auto s = thingify(network_selection::random(42, 1.0), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_TRUE(s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, random_p_0) { + const auto s = thingify(network_selection::random(42, 0.0), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_FALSE(s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, random_seed) { + const auto s1 = thingify(network_selection::random(42, 0.5), network_label_dict()); + const auto s2 = thingify(network_selection::random(4592304, 0.5), network_label_dict()); + + bool all_eq = true; + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + all_eq &= (s1->select_connection(src, dest) == s2->select_connection(src, dest)); + } + } + EXPECT_FALSE(all_eq); +} + +TEST(network_selection, random_reproducibility) { + const auto s = thingify(network_selection::random(42, 0.5), network_label_dict()); + + std::vector sites = { + {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, + }; + std::vector ref = {1, 1, 0, 1, 1, 0, 0, 0, 0}; + + std::size_t i = 0; + for (const auto& src: sites) { + for (const auto& dest: sites) { + EXPECT_EQ(ref.at(i), s->select_connection(src, dest)); + ++i; + } + }; +} + +TEST(network_selection, custom) { + auto inter_cell_func = [](const network_site_info& src, const network_site_info& dest) { + return src.gid != dest.gid; + }; + const auto s = thingify(network_selection::custom(inter_cell_func), network_label_dict()); + const auto s_ref = thingify(network_selection::inter_cell(), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(s->select_connection(src, dest), s_ref->select_connection(src, dest)); + } + } +} + +TEST(network_selection, distance_lt) { + const double d = 2.1; + const auto s = + thingify(network_selection::distance_lt(d), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(distance(src.global_location, dest.global_location) < d, + s->select_connection(src, dest)); + } + } +} + +TEST(network_selection, distance_gt) { + const double d = 2.1; + const auto s = + thingify(network_selection::distance_gt(d), network_label_dict()); + + for (const auto& site: test_sites) { + EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + } + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_EQ(distance(src.global_location, dest.global_location) > d, + s->select_connection(src, dest)); + } + } +} + + +TEST(network_value, scalar) { + const auto v = thingify(network_value::scalar(2.0), network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(src, dest)); } + } +} + +TEST(network_value, conversion) { + const auto v = thingify(static_cast(2.0), network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(src, dest)); } + } +} + +TEST(network_value, named) { + auto dict = network_label_dict(); + dict.set("myval", network_value::scalar(2.0)); + const auto v = thingify(network_value::named("myval"), dict); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(src, dest)); } + } +} + +TEST(network_value, distance) { + const auto v = thingify(network_value::distance(), network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_DOUBLE_EQ( + distance(src.global_location, dest.global_location), v->get(src, dest)); + } + } +} + +TEST(network_value, uniform_distribution) { + const auto v = thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); + + double mean = 0.0; + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { mean += v->get(src, dest); } + } + + mean /= test_sites.size() * test_sites.size(); + EXPECT_NEAR(mean, -1.0, 1e3); +} + +TEST(network_value, uniform_distribution_reproducibility) { + const auto v = thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); + + std::vector sites = { + {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, + }; + std::vector ref = { + 1.08007184307616289, + 0.688511962867972116, + -2.83551807417554347, + 0.688511962867972116, + 0.824599122495063064, + 1.4676501652366376, + -2.83551807417554347, + 1.4676501652366376, + -4.89687864740961487, + }; + + std::size_t i = 0; + for (const auto& src: sites) { + for (const auto& dest: sites) { + EXPECT_DOUBLE_EQ(ref.at(i), v->get(src, dest)); + ++i; + } + }; +} + +TEST(network_value, normal_distribution) { + const double mean = 5.0; + const double std_dev = 3.0; + const auto v = thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); + + double sample_mean = 0.0; + double sample_dev = 0.0; + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + const auto result = v->get(src, dest); + sample_mean += result; + sample_dev += (result - mean) * (result - mean); + } + } + + sample_mean /= test_sites.size() * test_sites.size(); + sample_dev = std::sqrt(sample_dev / (test_sites.size() * test_sites.size())); + + EXPECT_NEAR(sample_mean, mean, 1e-1); + EXPECT_NEAR(sample_dev, std_dev, 1e-1); +} + +TEST(network_value, normal_distribution_reproducibility) { + const double mean = 5.0; + const double std_dev = 3.0; + const auto v = thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); + + std::vector sites = { + {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, + }; + std::vector ref = { + 9.27330832850693909, + 6.29969914563416733, + 1.81597827782531063, + 6.29969914563416733, + 8.12362497769330183, + 1.52496785710691851, + 1.81597827782531063, + 1.52496785710691851, + 1.49089022270221472, + }; + + std::size_t i = 0; + for (const auto& src: sites) { + for (const auto& dest: sites) { + EXPECT_DOUBLE_EQ(ref.at(i), v->get(src, dest)); + ++i; + } + }; +} + +TEST(network_value, truncated_normal_distribution) { + const double mean = 5.0; + const double std_dev = 3.0; + // symmtric upper / lower bound around mean for easy check of mean + const double lower_bound = 1.0; + const double upper_bound = 9.0; + + const auto v = thingify( + network_value::truncated_normal_distribution(42, mean, std_dev, {lower_bound, upper_bound}), + network_label_dict()); + + double sample_mean = 0.0; + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + const auto result = v->get(src, dest); + EXPECT_GT(result, lower_bound); + EXPECT_LE(result, upper_bound); + sample_mean += result; + } + } + + sample_mean /= test_sites.size() * test_sites.size(); + + EXPECT_NEAR(sample_mean, mean, 1e-1); +} + +TEST(network_value, truncated_normal_distribution_reproducibility) { + const double mean = 5.0; + const double std_dev = 3.0; + + const double lower_bound = 2.0; + const double upper_bound = 9.0; + + const auto v = thingify( + network_value::truncated_normal_distribution(42, mean, std_dev, {lower_bound, upper_bound}), + network_label_dict()); + + std::vector sites = { + {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, + }; + std::vector ref = { + 2.81708378066100629, + 4.82619033891918026, + 7.82585873628304096, + 4.82619033891918026, + 3.95914976610015401, + 5.74869285185564216, + 7.82585873628304096, + 5.74869285185564216, + 5.45028211635819293, + }; + + std::size_t i = 0; + for (const auto& src: sites) { + for (const auto& dest: sites) { + EXPECT_DOUBLE_EQ(ref.at(i), v->get(src, dest)); + ++i; + } + }; +} + +TEST(network_value, custom) { + auto func = [](const network_site_info& src, const network_site_info& dest) { + return src.global_location.x + dest.global_location.x; + }; + + const auto v = thingify(network_value::custom(func), network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_DOUBLE_EQ(v->get(src, dest), src.global_location.x + dest.global_location.x); + } + } +} + +TEST(network_value, add) { + const auto v = + thingify(network_value::add(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), 5.0); } + } +} + +TEST(network_value, sub) { + const auto v = + thingify(network_value::sub(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), -1.0); } + } +} + +TEST(network_value, mul) { + const auto v = + thingify(network_value::mul(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), 6.0); } + } +} + +TEST(network_value, div) { + const auto v = + thingify(network_value::div(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), 2.0 / 3.0); } + } +} + +TEST(network_value, exp) { + const auto v = + thingify(network_value::exp(network_value::scalar(2.0)), + network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), std::exp(2.0)); } + } +} + +TEST(network_value, log) { + const auto v = thingify(network_value::log(network_value::scalar(2.0)), network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), std::log(2.0)); } + } +} + +TEST(network_value, min) { + const auto v1 = + thingify(network_value::min(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + const auto v2 = + thingify(network_value::min(network_value::scalar(3.0), network_value::scalar(2.0)), + network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_DOUBLE_EQ(v1->get(src, dest), 2.0); + EXPECT_DOUBLE_EQ(v2->get(src, dest), 2.0); + } + } +} + +TEST(network_value, max) { + const auto v1 = + thingify(network_value::max(network_value::scalar(2.0), network_value::scalar(3.0)), + network_label_dict()); + const auto v2 = + thingify(network_value::max(network_value::scalar(3.0), network_value::scalar(2.0)), + network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_DOUBLE_EQ(v1->get(src, dest), 3.0); + EXPECT_DOUBLE_EQ(v2->get(src, dest), 3.0); + } + } +} From 229da63de9f672a79d150307c5df68a30c4a673c Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 21 Apr 2023 16:44:44 +0200 Subject: [PATCH 25/84] move label parsing to new files --- arborio/CMakeLists.txt | 1 + arborio/include/arborio/label_parse.hpp | 29 +- arborio/include/arborio/networkio.hpp | 44 ++ arborio/label_parse.cpp | 680 +++++++----------------- arborio/networkio.cpp | 359 +++++++++++++ python/network.cpp | 1 + test/unit/test_s_expr.cpp | 3 +- 7 files changed, 599 insertions(+), 518 deletions(-) create mode 100644 arborio/include/arborio/networkio.hpp create mode 100644 arborio/networkio.cpp diff --git a/arborio/CMakeLists.txt b/arborio/CMakeLists.txt index 1e40f43922..c8c6b20d01 100644 --- a/arborio/CMakeLists.txt +++ b/arborio/CMakeLists.txt @@ -6,6 +6,7 @@ set(arborio-sources cv_policy_parse.cpp label_parse.cpp neuroml.cpp + networkio.cpp nml_parse_morphology.cpp) add_library(arborio ${arborio-sources}) diff --git a/arborio/include/arborio/label_parse.hpp b/arborio/include/arborio/label_parse.hpp index 3e0f8e42f9..d937cad149 100644 --- a/arborio/include/arborio/label_parse.hpp +++ b/arborio/include/arborio/label_parse.hpp @@ -4,11 +4,10 @@ #include #include -#include -#include #include -#include +#include #include +#include #include #include @@ -29,11 +28,6 @@ ARB_ARBORIO_API parse_label_hopefully parse_label_expression(const arb ARB_ARBORIO_API parse_label_hopefully parse_region_expression(const std::string& s); ARB_ARBORIO_API parse_label_hopefully parse_locset_expression(const std::string& s); ARB_ARBORIO_API parse_label_hopefully parse_iexpr_expression(const std::string& s); -ARB_ARBORIO_API parse_label_hopefully parse_network_selection_expression(const std::string& s); -ARB_ARBORIO_API parse_label_hopefully parse_network_value_expression(const std::string& s); -ARB_ARBORIO_API parse_label_hopefully parse_network_selection_expression(const std::string& s); -ARB_ARBORIO_API parse_label_hopefully parse_network_value_expression( - const std::string& s); namespace literals { @@ -76,22 +70,7 @@ arb::region operator "" _reg(const char* s, std::size_t) { else throw r.error(); } -inline morph_from_string operator"" _morph(const char* s, std::size_t) { return {s}; } -inline morph_from_label operator"" _lab(const char* s, std::size_t) { return {s}; } - -inline arb::network_selection operator"" _ns(const char* s, std::size_t) { - if (auto r = parse_network_selection_expression(s)) - return *r; - else - throw r.error(); -} - -inline arb::network_value operator"" _nv(const char* s, std::size_t) { - if (auto r = parse_network_value_expression(s)) - return *r; - else - throw r.error(); -} - +inline morph_from_string operator "" _morph(const char* s, std::size_t) { return {s}; } +inline morph_from_label operator "" _lab(const char* s, std::size_t) { return {s}; } } // namespace literals } // namespace arborio diff --git a/arborio/include/arborio/networkio.hpp b/arborio/include/arborio/networkio.hpp new file mode 100644 index 0000000000..c57f0d98a5 --- /dev/null +++ b/arborio/include/arborio/networkio.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +namespace arborio { + +struct ARB_SYMBOL_VISIBLE network_parse_error: arb::arbor_exception { + explicit network_parse_error(const std::string& msg, const arb::src_location& loc); + explicit network_parse_error(const std::string& msg): arb::arbor_exception(msg) {} +}; + +template +using parse_network_hopefully = arb::util::expected; + +ARB_ARBORIO_API parse_network_hopefully parse_network_selection_expression( + const std::string& s); +ARB_ARBORIO_API parse_network_hopefully parse_network_value_expression( + const std::string& s); + +namespace literals { +inline arb::network_selection operator"" _ns(const char* s, std::size_t) { + if (auto r = parse_network_selection_expression(s)) + return *r; + else + throw r.error(); +} + +inline arb::network_value operator"" _nv(const char* s, std::size_t) { + if (auto r = parse_network_value_expression(s)) + return *r; + else + throw r.error(); +} + +} // namespace literals +} // namespace arborio diff --git a/arborio/label_parse.cpp b/arborio/label_parse.cpp index 35519d59e2..22df96481c 100644 --- a/arborio/label_parse.cpp +++ b/arborio/label_parse.cpp @@ -1,15 +1,12 @@ #include #include -#include #include #include -#include -#include #include #include -#include +#include #include @@ -18,125 +15,105 @@ namespace arborio { label_parse_error::label_parse_error(const std::string& msg, const arb::src_location& loc): - arb::arbor_exception( - concat("error in label description: ", msg, " at :", loc.line, ":", loc.column)) {} + arb::arbor_exception(concat("error in label description: ", msg," at :", loc.line, ":", loc.column)) +{} + namespace { -using eval_map_type = std::unordered_multimap; -eval_map_type eval_map{ +std::unordered_multimap eval_map { // Functions that return regions - {"region-nil", make_call<>(arb::reg::nil, "'region-nil' with 0 arguments")}, - {"all", make_call<>(arb::reg::all, "'all' with 0 arguments")}, - {"tag", make_call(arb::reg::tagged, "'tag' with 1 argment: (tag_id:integer)")}, - {"segment", - make_call(arb::reg::segment, "'segment' with 1 argment: (segment_id:integer)")}, - {"branch", make_call(arb::reg::branch, "'branch' with 1 argument: (branch_id:integer)")}, - {"cable", - make_call(arb::reg::cable, - "'cable' with 3 arguments: (branch_id:integer prox:real dist:real)")}, - {"region", make_call(arb::reg::named, "'region' with 1 argument: (name:string)")}, - {"distal-interval", - make_call(arb::reg::distal_interval, - "'distal-interval' with 2 arguments: (start:locset extent:real)")}, - {"distal-interval", - make_call( - [](arb::locset ls) { - return arb::reg::distal_interval(std::move(ls), std::numeric_limits::max()); - }, - "'distal-interval' with 1 argument: (start:locset)")}, - {"proximal-interval", - make_call(arb::reg::proximal_interval, - "'proximal-interval' with 2 arguments: (start:locset extent:real)")}, - {"proximal-interval", - make_call( - [](arb::locset ls) { - return arb::reg::proximal_interval( - std::move(ls), std::numeric_limits::max()); - }, - "'proximal_interval' with 1 argument: (start:locset)")}, - {"complete", - make_call(arb::reg::complete, "'complete' with 1 argment: (reg:region)")}, - {"radius-lt", - make_call(arb::reg::radius_lt, - "'radius-lt' with 2 arguments: (reg:region radius:real)")}, - {"radius-le", - make_call(arb::reg::radius_le, - "'radius-le' with 2 arguments: (reg:region radius:real)")}, - {"radius-gt", - make_call(arb::reg::radius_gt, - "'radius-gt' with 2 arguments: (reg:region radius:real)")}, - {"radius-ge", - make_call(arb::reg::radius_ge, - "'radius-ge' with 2 arguments: (reg:region radius:real)")}, - {"z-dist-from-root-lt", - make_call(arb::reg::z_dist_from_root_lt, - "'z-dist-from-root-lt' with 1 arguments: (distance:real)")}, - {"z-dist-from-root-le", - make_call(arb::reg::z_dist_from_root_le, - "'z-dist-from-root-le' with 1 arguments: (distance:real)")}, - {"z-dist-from-root-gt", - make_call(arb::reg::z_dist_from_root_gt, - "'z-dist-from-root-gt' with 1 arguments: (distance:real)")}, - {"z-dist-from-root-ge", - make_call(arb::reg::z_dist_from_root_ge, - "'z-dist-from-root-ge' with 1 arguments: (distance:real)")}, - {"complement", - make_call(arb::complement, "'complement' with 1 argment: (reg:region)")}, - {"difference", - make_call(arb::difference, - "'difference' with 2 argments: (reg:region, reg:region)")}, - {"join", - make_fold(static_cast(arb::join), - "'join' with at least 2 arguments: (region region [...region])")}, - {"intersect", - make_fold( - static_cast(arb::intersect), - "'intersect' with at least 2 arguments: (region region [...region])")}, + {"region-nil", make_call<>(arb::reg::nil, + "'region-nil' with 0 arguments")}, + {"all", make_call<>(arb::reg::all, + "'all' with 0 arguments")}, + {"tag", make_call(arb::reg::tagged, + "'tag' with 1 argment: (tag_id:integer)")}, + {"segment", make_call(arb::reg::segment, + "'segment' with 1 argment: (segment_id:integer)")}, + {"branch", make_call(arb::reg::branch, + "'branch' with 1 argument: (branch_id:integer)")}, + {"cable", make_call(arb::reg::cable, + "'cable' with 3 arguments: (branch_id:integer prox:real dist:real)")}, + {"region", make_call(arb::reg::named, + "'region' with 1 argument: (name:string)")}, + {"distal-interval", make_call(arb::reg::distal_interval, + "'distal-interval' with 2 arguments: (start:locset extent:real)")}, + {"distal-interval", make_call( + [](arb::locset ls){return arb::reg::distal_interval(std::move(ls), std::numeric_limits::max());}, + "'distal-interval' with 1 argument: (start:locset)")}, + {"proximal-interval", make_call(arb::reg::proximal_interval, + "'proximal-interval' with 2 arguments: (start:locset extent:real)")}, + {"proximal-interval", make_call( + [](arb::locset ls){return arb::reg::proximal_interval(std::move(ls), std::numeric_limits::max());}, + "'proximal_interval' with 1 argument: (start:locset)")}, + {"complete", make_call(arb::reg::complete, + "'complete' with 1 argment: (reg:region)")}, + {"radius-lt", make_call(arb::reg::radius_lt, + "'radius-lt' with 2 arguments: (reg:region radius:real)")}, + {"radius-le", make_call(arb::reg::radius_le, + "'radius-le' with 2 arguments: (reg:region radius:real)")}, + {"radius-gt", make_call(arb::reg::radius_gt, + "'radius-gt' with 2 arguments: (reg:region radius:real)")}, + {"radius-ge", make_call(arb::reg::radius_ge, + "'radius-ge' with 2 arguments: (reg:region radius:real)")}, + {"z-dist-from-root-lt", make_call(arb::reg::z_dist_from_root_lt, + "'z-dist-from-root-lt' with 1 arguments: (distance:real)")}, + {"z-dist-from-root-le", make_call(arb::reg::z_dist_from_root_le, + "'z-dist-from-root-le' with 1 arguments: (distance:real)")}, + {"z-dist-from-root-gt", make_call(arb::reg::z_dist_from_root_gt, + "'z-dist-from-root-gt' with 1 arguments: (distance:real)")}, + {"z-dist-from-root-ge", make_call(arb::reg::z_dist_from_root_ge, + "'z-dist-from-root-ge' with 1 arguments: (distance:real)")}, + {"complement", make_call(arb::complement, + "'complement' with 1 argment: (reg:region)")}, + {"difference", make_call(arb::difference, + "'difference' with 2 argments: (reg:region, reg:region)")}, + {"join", make_fold(static_cast(arb::join), + "'join' with at least 2 arguments: (region region [...region])")}, + {"intersect", make_fold(static_cast(arb::intersect), + "'intersect' with at least 2 arguments: (region region [...region])")}, // Functions that return locsets - {"locset-nil", make_call<>(arb::ls::nil, "'locset-nil' with 0 arguments")}, - {"root", make_call<>(arb::ls::root, "'root' with 0 arguments")}, - {"location", - make_call( - [](int bid, double pos) { return arb::ls::location(arb::msize_t(bid), pos); }, - "'location' with 2 arguments: (branch_id:integer position:real)")}, - {"terminal", make_call<>(arb::ls::terminal, "'terminal' with 0 arguments")}, - {"distal", - make_call(arb::ls::most_distal, "'distal' with 1 argument: (reg:region)")}, - {"proximal", - make_call(arb::ls::most_proximal, "'proximal' with 1 argument: (reg:region)")}, - {"distal-translate", - make_call(arb::ls::distal_translate, - "'distal-translate' with 2 arguments: (ls:locset distance:real)")}, - {"proximal-translate", - make_call(arb::ls::proximal_translate, - "'proximal-translate' with 2 arguments: (ls:locset distance:real)")}, - {"uniform", - make_call(arb::ls::uniform, - "'uniform' with 4 arguments: (reg:region, first:int, last:int, seed:int)")}, - {"on-branches", - make_call(arb::ls::on_branches, "'on-branches' with 1 argument: (pos:double)")}, - {"on-components", - make_call(arb::ls::on_components, - "'on-components' with 2 arguments: (pos:double, reg:region)")}, - {"boundary", - make_call(arb::ls::boundary, "'boundary' with 1 argument: (reg:region)")}, - {"cboundary", - make_call(arb::ls::cboundary, "'cboundary' with 1 argument: (reg:region)")}, - {"segment-boundaries", - make_call<>(arb::ls::segment_boundaries, "'segment-boundaries' with 0 arguments")}, - {"support", make_call(arb::ls::support, "'support' with 1 argument (ls:locset)")}, - {"locset", make_call(arb::ls::named, "'locset' with 1 argument: (name:string)")}, - {"restrict", - make_call(arb::ls::restrict, - "'restrict' with 2 arguments: (ls:locset, reg:region)")}, - {"join", - make_fold(static_cast(arb::join), - "'join' with at least 2 arguments: (locset locset [...locset])")}, - {"sum", - make_fold(static_cast(arb::sum), - "'sum' with at least 2 arguments: (locset locset [...locset])")}, + {"locset-nil", make_call<>(arb::ls::nil, + "'locset-nil' with 0 arguments")}, + {"root", make_call<>(arb::ls::root, + "'root' with 0 arguments")}, + {"location", make_call([](int bid, double pos){return arb::ls::location(arb::msize_t(bid), pos);}, + "'location' with 2 arguments: (branch_id:integer position:real)")}, + {"terminal", make_call<>(arb::ls::terminal, + "'terminal' with 0 arguments")}, + {"distal", make_call(arb::ls::most_distal, + "'distal' with 1 argument: (reg:region)")}, + {"proximal", make_call(arb::ls::most_proximal, + "'proximal' with 1 argument: (reg:region)")}, + {"distal-translate", make_call(arb::ls::distal_translate, + "'distal-translate' with 2 arguments: (ls:locset distance:real)")}, + {"proximal-translate", make_call(arb::ls::proximal_translate, + "'proximal-translate' with 2 arguments: (ls:locset distance:real)")}, + {"uniform", make_call(arb::ls::uniform, + "'uniform' with 4 arguments: (reg:region, first:int, last:int, seed:int)")}, + {"on-branches", make_call(arb::ls::on_branches, + "'on-branches' with 1 argument: (pos:double)")}, + {"on-components", make_call(arb::ls::on_components, + "'on-components' with 2 arguments: (pos:double, reg:region)")}, + {"boundary", make_call(arb::ls::boundary, + "'boundary' with 1 argument: (reg:region)")}, + {"cboundary", make_call(arb::ls::cboundary, + "'cboundary' with 1 argument: (reg:region)")}, + {"segment-boundaries", make_call<>(arb::ls::segment_boundaries, + "'segment-boundaries' with 0 arguments")}, + {"support", make_call(arb::ls::support, + "'support' with 1 argument (ls:locset)")}, + {"locset", make_call(arb::ls::named, + "'locset' with 1 argument: (name:string)")}, + {"restrict", make_call(arb::ls::restrict, + "'restrict' with 2 arguments: (ls:locset, reg:region)")}, + {"join", make_fold(static_cast(arb::join), + "'join' with at least 2 arguments: (locset locset [...locset])")}, + {"sum", make_fold(static_cast(arb::sum), + "'sum' with at least 2 arguments: (locset locset [...locset])")}, + // iexpr {"iexpr", make_call(arb::iexpr::named, "iexpr with 1 argument: (value:string)")}, @@ -145,92 +122,52 @@ eval_map_type eval_map{ {"pi", make_call<>(arb::iexpr::pi, "iexpr with no argument")}, - {"distance", - make_call( - static_cast(arb::iexpr::distance), + {"distance", make_call(static_cast(arb::iexpr::distance), "iexpr with 2 arguments: (scale:double, loc:locset)")}, - {"distance", - make_call(static_cast(arb::iexpr::distance), + {"distance", make_call(static_cast(arb::iexpr::distance), "iexpr with 1 argument: (loc:locset)")}, - {"distance", - make_call( - static_cast(arb::iexpr::distance), + {"distance", make_call(static_cast(arb::iexpr::distance), "iexpr with 2 arguments: (scale:double, reg:region)")}, - {"distance", - make_call(static_cast(arb::iexpr::distance), + {"distance", make_call(static_cast(arb::iexpr::distance), "iexpr with 1 argument: (reg:region)")}, - {"proximal-distance", - make_call( - static_cast(arb::iexpr::proximal_distance), + {"proximal-distance", make_call(static_cast(arb::iexpr::proximal_distance), "iexpr with 2 arguments: (scale:double, loc:locset)")}, - {"proximal-distance", - make_call( - static_cast(arb::iexpr::proximal_distance), + {"proximal-distance", make_call(static_cast(arb::iexpr::proximal_distance), "iexpr with 1 argument: (loc:locset)")}, - {"proximal-distance", - make_call( - static_cast(arb::iexpr::proximal_distance), + {"proximal-distance", make_call(static_cast(arb::iexpr::proximal_distance), "iexpr with 2 arguments: (scale:double, reg:region)")}, - {"proximal-distance", - make_call( - static_cast(arb::iexpr::proximal_distance), + {"proximal-distance", make_call(static_cast(arb::iexpr::proximal_distance), "iexpr with 1 arguments: (reg:region)")}, - {"distal-distance", - make_call( - static_cast(arb::iexpr::distal_distance), + {"distal-distance", make_call(static_cast(arb::iexpr::distal_distance), "iexpr with 2 arguments: (scale:double, loc:locset)")}, - {"distal-distance", - make_call( - static_cast(arb::iexpr::distal_distance), + {"distal-distance", make_call(static_cast(arb::iexpr::distal_distance), "iexpr with 1 argument: (loc:locset)")}, - {"distal-distance", - make_call( - static_cast(arb::iexpr::distal_distance), + {"distal-distance", make_call(static_cast(arb::iexpr::distal_distance), "iexpr with 2 arguments: (scale:double, reg:region)")}, - {"distal-distance", - make_call( - static_cast(arb::iexpr::distal_distance), + {"distal-distance", make_call(static_cast(arb::iexpr::distal_distance), "iexpr with 1 argument: (reg:region)")}, - {"interpolation", - make_call( - static_cast( - arb::iexpr::interpolation), - "iexpr with 4 arguments: (prox_value:double, prox_list:locset, dist_value:double, " - "dist_list:locset)")}, - {"interpolation", - make_call( - static_cast( - arb::iexpr::interpolation), - "iexpr with 4 arguments: (prox_value:double, prox_list:region, dist_value:double, " - "dist_list:region)")}, - - {"radius", - make_call(static_cast(arb::iexpr::radius), - "iexpr with 1 argument: (value:double)")}, - {"radius", - make_call<>(static_cast(arb::iexpr::radius), "iexpr with no argument")}, - - {"diameter", - make_call(static_cast(arb::iexpr::diameter), - "iexpr with 1 argument: (value:double)")}, - {"diameter", - make_call<>(static_cast(arb::iexpr::diameter), "iexpr with no argument")}, + {"interpolation", make_call(static_cast(arb::iexpr::interpolation), + "iexpr with 4 arguments: (prox_value:double, prox_list:locset, dist_value:double, dist_list:locset)")}, + {"interpolation", make_call(static_cast(arb::iexpr::interpolation), + "iexpr with 4 arguments: (prox_value:double, prox_list:region, dist_value:double, dist_list:region)")}, + + {"radius", make_call(static_cast(arb::iexpr::radius), "iexpr with 1 argument: (value:double)")}, + {"radius", make_call<>(static_cast(arb::iexpr::radius), "iexpr with no argument")}, + + {"diameter", make_call(static_cast(arb::iexpr::diameter), "iexpr with 1 argument: (value:double)")}, + {"diameter", make_call<>(static_cast(arb::iexpr::diameter), "iexpr with no argument")}, {"exp", make_call(arb::iexpr::exp, "iexpr with 1 argument: (value:iexpr)")}, {"exp", make_call(arb::iexpr::exp, "iexpr with 1 argument: (value:double)")}, - {"step_right", - make_call(arb::iexpr::step_right, "iexpr with 1 argument: (value:iexpr)")}, - {"step_right", - make_call(arb::iexpr::step_right, "iexpr with 1 argument: (value:double)")}, + {"step_right", make_call(arb::iexpr::step_right, "iexpr with 1 argument: (value:iexpr)")}, + {"step_right", make_call(arb::iexpr::step_right, "iexpr with 1 argument: (value:double)")}, - {"step_left", - make_call(arb::iexpr::step_left, "iexpr with 1 argument: (value:iexpr)")}, - {"step_left", - make_call(arb::iexpr::step_left, "iexpr with 1 argument: (value:double)")}, + {"step_left", make_call(arb::iexpr::step_left, "iexpr with 1 argument: (value:iexpr)")}, + {"step_left", make_call(arb::iexpr::step_left, "iexpr with 1 argument: (value:double)")}, {"step", make_call(arb::iexpr::step, "iexpr with 1 argument: (value:iexpr)")}, {"step", make_call(arb::iexpr::step, "iexpr with 1 argument: (value:double)")}, @@ -238,260 +175,27 @@ eval_map_type eval_map{ {"log", make_call(arb::iexpr::log, "iexpr with 1 argument: (value:iexpr)")}, {"log", make_call(arb::iexpr::log, "iexpr with 1 argument: (value:double)")}, - {"add", - make_conversion_fold(arb::iexpr::add, - "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | " - "double)])")}, - - {"sub", - make_conversion_fold(arb::iexpr::sub, - "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | " - "double)])")}, - - {"mul", - make_conversion_fold(arb::iexpr::mul, - "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | " - "double)])")}, - - {"div", - make_conversion_fold(arb::iexpr::div, - "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | " - "double)])")}, -}; + {"add", make_conversion_fold(arb::iexpr::add, "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | double)])")}, + + {"sub", make_conversion_fold(arb::iexpr::sub, "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | double)])")}, + + {"mul", make_conversion_fold(arb::iexpr::mul, "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | double)])")}, -eval_map_type network_eval_map{ - {"gid-range", - make_call([](int begin, int end) { return arb::gid_range(begin, end); }, - "Gid range [begin, end) with step size 1: ((begin:integer) (end:integer))")}, - {"gid-range", - make_call( - [](int begin, int end, int step) { return arb::gid_range(begin, end, step); }, - "Gid range [begin, end) with step size: ((begin:integer) (end:integer) " - "(step:integer))")}, - - // cell kind - {"cable-cell", make_call<>([]() { return arb::cell_kind::cable; }, "Cable cell kind")}, - {"lif-cell", make_call<>([]() { return arb::cell_kind::lif; }, "Lif cell kind")}, - {"benchmark-cell", - make_call<>([]() { return arb::cell_kind::benchmark; }, "Benchmark cell kind")}, - {"spike-source-cell", - make_call<>([]() { return arb::cell_kind::spike_source; }, "Spike source cell kind")}, - - // network_selection - {"all", make_call<>(arb::network_selection::all, "network selection of all cells and labels")}, - {"none", make_call<>(arb::network_selection::none, "network selection of no cells and labels")}, - {"inter-cell", - make_call<>(arb::network_selection::inter_cell, - "network selection of inter-cell connections only")}, - {"network-selection", - make_call(arb::network_selection::named, - "network selection with 1 argument: (value:string)")}, - {"intersect", - make_conversion_fold(arb::network_selection::intersect, - "intersection of network selections with at least 2 arguments: " - "(network_selection network_selection [...network_selection])")}, - {"join", - make_conversion_fold(arb::network_selection::join, - "join or union operation of network selections with at least 2 arguments: " - "(network_selection network_selection [...network_selection])")}, - {"symmetric-difference", - make_conversion_fold(arb::network_selection::symmetric_difference, - "symmetric difference operation between network selections with at least 2 arguments: " - "(network_selection network_selection [...network_selection])")}, - {"difference", - make_call( - arb::network_selection::difference, - "difference of first selection with the second one: " - "(network_selection network_selection)")}, - {"complement", - make_call(arb::network_selection::complement, - "complement of given selection argument: (network_selection)")}, - {"source-cell-kind", - make_call(arb::network_selection::source_cell_kind, - "all sources of cells matching given cell kind argument: (kind:cell-kind)")}, - {"destination-cell-kind", - make_call(arb::network_selection::destination_cell_kind, - "all destinations of cells matching given cell kind argument: (kind:cell-kind)")}, - {"source-label", - make_arg_vec_call( - [](const std::vector>& vec) { - std::vector labels; - std::transform( - vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& x) { - return std::get(x); - }); - return arb::network_selection::source_label(std::move(labels)); - }, - "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"destination-label", - make_arg_vec_call( - [](const std::vector>& vec) { - std::vector labels; - std::transform( - vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& x) { - return std::get(x); - }); - return arb::network_selection::destination_label(std::move(labels)); - }, - "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"source-cell", - make_arg_vec_call( - [](const std::vector>& vec) { - std::vector gids; - std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { - return std::get(x); - }); - return arb::network_selection::source_cell(std::move(gids)); - }, - "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"source-cell", - make_call(static_cast( - arb::network_selection::source_cell), - "all sources in cell with gid range: (range:gid-range)")}, - {"destination-cell", - make_arg_vec_call( - [](const std::vector>& vec) { - std::vector gids; - std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { - return std::get(x); - }); - return arb::network_selection::destination_cell(std::move(gids)); - }, - "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"destination-cell", - make_call(static_cast( - arb::network_selection::destination_cell), - "all destinations in cell with gid range: " - "(range:gid-range)")}, - {"chain", - make_arg_vec_call( - [](const std::vector>& vec) { - std::vector gids; - std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { - return std::get(x); - }); - return arb::network_selection::chain(std::move(gids)); - }, - "A chain of connections in the given order of gids in list, such that entry \"i\" is " - "the source and entry \"i+1\" the destination: (gid:integer) [...(gid:integer)]")}, - {"chain", - make_call( - static_cast(arb::network_selection::chain), - "A chain of connections for all gids in range [begin, end) with given step size. Each " - "entry \"i\" is connected as source to the destination \"i+1\": (begin:integer) " - "(end:integer) (step:integer)")}, - {"chain-reverse", - make_call(arb::network_selection::chain_reverse, - "A chain of connections for all gids in range [begin, end) with given step size. Each " - "entry \"i+1\" is connected as source to the destination \"i\". This results in " - "connection directions in reverse compared to the (chain-range ...) selection: " - "(begin:integer) " - "(end:integer) (step:integer)")}, - {"random", - make_call(arb::network_selection::random, - "randomly selected with given seed and probability. 2 arguments: (seed:integer, " - "p:real)")}, - {"random", - make_call(arb::network_selection::random, - "randomly selected with given seed and probability function. Any probability value is " - "clamped to [0.0, 1.0]. 2 arguments: (seed:integer, " - "p:network-value)")}, - {"distance-lt", - make_call(arb::network_selection::distance_lt, - "Select if distance between source and destination is less than given distance in " - "micro meter: (distance:real)")}, - {"distance-gt", - make_call(arb::network_selection::distance_gt, - "Select if distance between source and destination is greater than given distance in " - "micro meter: (distance:real)")}, - - // network_value - {"scalar", - make_call(arb::network_value::scalar, - "A fixed scalar value. 1 argument: (value:real)")}, - {"network-value", - make_call(arb::network_value::named, - "A named network value with 1 argument: (value:string)")}, - {"distance", - make_call(arb::network_value::distance, - "Distance between source and destination scaled by given value with unit [1/um]. 1 " - "argument: (scale:real)")}, - {"distance", - make_call<>([]() { return arb::network_value::distance(1.0); }, - "Distance between source and destination scaled by 1.0 with unit [1/um].")}, - {"uniform-distribution", - make_call( - [](unsigned seed, double begin, double end) { - return arb::network_value::uniform_distribution(seed, {begin, end}); - }, - "Uniform random distribution within interval [begin, end): (seed:integer, begin:real, " - "end:real)")}, - {"normal-distribution", - make_call(arb::network_value::normal_distribution, - "Normal random distribution with given mean and standard deviation: (seed:integer, " - "mean:real, std_deviation:real)")}, - {"truncated-normal-distribution", - make_call( - [](unsigned seed, double mean, double std_deviation, double begin, double end) { - return arb::network_value::truncated_normal_distribution( - seed, mean, std_deviation, {begin, end}); - }, - "Truncated normal random distribution with given mean and standard deviation within " - "interval [begin, end]: (seed:integer, mean:real, std_deviation:real, begin:real, " - "end:real)")}, - {"add", - make_conversion_fold( - arb::network_value::add, - "Sum of network values with at least 2 arguments: ((network-value | double) " - "(network-value | double) [...(network-value | double)])")}, - {"sub", - make_conversion_fold( - arb::network_value::sub, - "Subtraction of network values from the first argument with at least 2 arguments: " - "((network-value | double) (network-value | double) [...(network-value | double)])")}, - {"mul", - make_conversion_fold( - arb::network_value::mul, - "Multiplication of network values with at least 2 arguments: ((network-value | double) " - "(network-value | double) [...(network-value | double)])")}, - {"div", - make_conversion_fold( - arb::network_value::div, - "Division of the first argument by each following network value sequentially with at " - "least 2 arguments: ((network-value | double) " - "(network-value | double) [...(network-value | double)])")}, - {"min", - make_conversion_fold( - arb::network_value::min, - "Minimum of network values with at least 2 arguments: ((network-value | double) " - "(network-value | double) [...(network-value | double)])")}, - {"max", - make_conversion_fold( - arb::network_value::max, - "Minimum of network values with at least 2 arguments: ((network-value | double) " - "(network-value | double) [...(network-value | double)])")}, - {"log", - make_call(arb::network_value::log, - "Logarithm. 1 argument: (value:real)")}, - {"log", - make_call(arb::network_value::log, - "Logarithm. 1 argument: (value:real)")}, - {"exp", - make_call(arb::network_value::exp, - "Logarithm. 1 argument: (value:real)")}, - {"exp", - make_call(arb::network_value::exp, - "Logarithm. 1 argument: (value:real)")}, + {"div", make_conversion_fold(arb::iexpr::div, "iexpr with at least 2 arguments: ((iexpr | double) (iexpr | double) [...(iexpr | double)])")}, }; -parse_label_hopefully eval(const s_expr& e, const eval_map_type& map); +parse_label_hopefully eval(const s_expr& e); -parse_label_hopefully> eval_args(const s_expr& e, const eval_map_type& map) { - if (!e) return {std::vector{}}; // empty argument list +parse_label_hopefully> eval_args(const s_expr& e) { + if (!e) return {std::vector{}}; // empty argument list std::vector args; for (auto& h: e) { - if (auto arg = eval(h, map)) { args.push_back(std::move(*arg)); } - else { return util::unexpected(std::move(arg.error())); } + if (auto arg=eval(h)) { + args.push_back(std::move(*arg)); + } + else { + return util::unexpected(std::move(arg.error())); + } } return args; } @@ -505,20 +209,20 @@ parse_label_hopefully> eval_args(const s_expr& e, const ev // types (integer, real, region, locset) are inferred from the arguments. std::string eval_description(const char* name, const std::vector& args) { auto type_string = [](const std::type_info& t) -> const char* { - if (t == typeid(int)) return "integer"; - if (t == typeid(double)) return "real"; - if (t == typeid(arb::region)) return "region"; - if (t == typeid(arb::locset)) return "locset"; + if (t==typeid(int)) return "integer"; + if (t==typeid(double)) return "real"; + if (t==typeid(arb::region)) return "region"; + if (t==typeid(arb::locset)) return "locset"; return "unknown"; }; const auto nargs = args.size(); - std::string msg = concat("'", name, "' with ", nargs, "argument", nargs != 1u ? "s:" : ":"); + std::string msg = concat("'", name, "' with ", nargs, "argument", nargs!=1u?"s:" : ":"); if (nargs) { msg += " ("; bool first = true; for (auto& a: args) { - msg += concat(first ? "" : " ", type_string(a.type())); + msg += concat(first?"":" ", type_string(a.type())); first = false; } msg += ")"; @@ -538,108 +242,100 @@ std::string eval_description(const char* name, const std::vector& args // a label_error_state with an error string and location. // // If there was an unexpected/fatal error, an exception will be thrown. -parse_label_hopefully eval(const s_expr& e, const eval_map_type& map) { - if (e.is_atom()) { return eval_atom(e); } +parse_label_hopefully eval(const s_expr& e) { + if (e.is_atom()) { + return eval_atom(e); + } if (e.head().is_atom()) { // This must be a function evaluation, where head is the function name, and // tail is a list of arguments. // Evaluate the arguments, and return error state if an error occurred. - auto args = eval_args(e.tail(), map); - if (!args) { return util::unexpected(args.error()); } + auto args = eval_args(e.tail()); + if (!args) { + return util::unexpected(args.error()); + } // Find all candidate functions that match the name of the function. auto& name = e.head().atom().spelling; - auto matches = map.equal_range(name); + auto matches = eval_map.equal_range(name); // Search for a candidate that matches the argument list. - for (auto i = matches.first; i != matches.second; ++i) { - if (i->second.match_args(*args)) { // found a match: evaluate and return. + for (auto i=matches.first; i!=matches.second; ++i) { + if (i->second.match_args(*args)) { // found a match: evaluate and return. return i->second.eval(*args); } } // Unable to find a match: try to return a helpful error message. const auto nc = std::distance(matches.first, matches.second); - auto msg = concat("No matches for ", - eval_description(name.c_str(), *args), - "\n There are ", - nc, - " potential candidates", - nc ? ":" : "."); + auto msg = concat("No matches for ", eval_description(name.c_str(), *args), "\n There are ", nc, " potential candidates", nc?":":"."); int count = 0; - for (auto i = matches.first; i != matches.second; ++i) { + for (auto i=matches.first; i!=matches.second; ++i) { msg += concat("\n Candidate ", ++count, " ", i->second.message); } return util::unexpected(label_parse_error(msg, location(e))); } return util::unexpected(label_parse_error( - concat("'", e, "' is not either integer, real expression of the form (op )"), - location(e))); + concat("'", e, "' is not either integer, real expression of the form (op )"), + location(e))); } -} // namespace +} // namespace ARB_ARBORIO_API parse_label_hopefully parse_label_expression(const std::string& e) { - return eval(parse_s_expr(e), eval_map); + return eval(parse_s_expr(e)); } ARB_ARBORIO_API parse_label_hopefully parse_label_expression(const s_expr& s) { - return eval(s, eval_map); + return eval(s); } ARB_ARBORIO_API parse_label_hopefully parse_region_expression(const std::string& s) { - if (auto e = eval(parse_s_expr(s), eval_map)) { - if (e->type() == typeid(region)) { return {std::move(std::any_cast(*e))}; } + if (auto e = eval(parse_s_expr(s))) { + if (e->type() == typeid(region)) { + return {std::move(std::any_cast(*e))}; + } if (e->type() == typeid(std::string)) { return {reg::named(std::move(std::any_cast(*e)))}; } - return util::unexpected(label_parse_error(concat("Invalid region description: '", - s, - "' is neither a valid region expression or region label string."))); + return util::unexpected( + label_parse_error( + concat("Invalid region description: '", s ,"' is neither a valid region expression or region label string."))); + } + else { + return util::unexpected(label_parse_error(std::string()+e.error().what())); } - else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } } ARB_ARBORIO_API parse_label_hopefully parse_locset_expression(const std::string& s) { - if (auto e = eval(parse_s_expr(s), eval_map)) { - if (e->type() == typeid(locset)) { return {std::move(std::any_cast(*e))}; } + if (auto e = eval(parse_s_expr(s))) { + if (e->type() == typeid(locset)) { + return {std::move(std::any_cast(*e))}; + } if (e->type() == typeid(std::string)) { return {ls::named(std::move(std::any_cast(*e)))}; } - return util::unexpected(label_parse_error(concat("Invalid region description: '", - s, - "' is neither a valid locset expression or locset label string."))); + return util::unexpected( + label_parse_error( + concat("Invalid region description: '", s ,"' is neither a valid locset expression or locset label string."))); } - else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } -} - -parse_label_hopefully parse_iexpr_expression(const std::string& s) { - if (auto e = eval(parse_s_expr(s), eval_map)) { - if (e->type() == typeid(iexpr)) { return {std::move(std::any_cast(*e))}; } - return util::unexpected(label_parse_error(concat("Invalid iexpr description: '", s))); + else { + return util::unexpected(label_parse_error(std::string()+e.error().what())); } - else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } } -parse_label_hopefully parse_network_selection_expression( - const std::string& s) { - if (auto e = eval(parse_s_expr(s), network_eval_map)) { - if (e->type() == typeid(arb::network_selection)) { - return {std::move(std::any_cast(*e))}; +parse_label_hopefully parse_iexpr_expression(const std::string& s) { + if (auto e = eval(parse_s_expr(s))) { + if (e->type() == typeid(iexpr)) { + return {std::move(std::any_cast(*e))}; } - return util::unexpected(label_parse_error(concat("Invalid iexpr description: '", s))); + return util::unexpected( + label_parse_error( + concat("Invalid iexpr description: '", s))); } - else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } -} - -parse_label_hopefully parse_network_value_expression(const std::string& s) { - if (auto e = eval(parse_s_expr(s), network_eval_map)) { - if (e->type() == typeid(arb::network_value)) { - return {std::move(std::any_cast(*e))}; - } - return util::unexpected(label_parse_error(concat("Invalid iexpr description: '", s))); + else { + return util::unexpected(label_parse_error(std::string()+e.error().what())); } - else { return util::unexpected(label_parse_error(std::string() + e.error().what())); } } -} // namespace arborio +} // namespace arborio diff --git a/arborio/networkio.cpp b/arborio/networkio.cpp new file mode 100644 index 0000000000..64a5c9376d --- /dev/null +++ b/arborio/networkio.cpp @@ -0,0 +1,359 @@ +#include +#include +#include + +#include + +#include +#include +#include + +#include "parse_helpers.hpp" + +namespace arborio { + +network_parse_error::network_parse_error(const std::string& msg, const arb::src_location& loc): + arb::arbor_exception( + concat("error in label description: ", msg, " at :", loc.line, ":", loc.column)) {} + +namespace { +using eval_map_type = std::unordered_multimap; + +eval_map_type network_eval_map{ + {"gid-range", + make_call([](int begin, int end) { return arb::gid_range(begin, end); }, + "Gid range [begin, end) with step size 1: ((begin:integer) (end:integer))")}, + {"gid-range", + make_call( + [](int begin, int end, int step) { return arb::gid_range(begin, end, step); }, + "Gid range [begin, end) with step size: ((begin:integer) (end:integer) " + "(step:integer))")}, + + // cell kind + {"cable-cell", make_call<>([]() { return arb::cell_kind::cable; }, "Cable cell kind")}, + {"lif-cell", make_call<>([]() { return arb::cell_kind::lif; }, "Lif cell kind")}, + {"benchmark-cell", + make_call<>([]() { return arb::cell_kind::benchmark; }, "Benchmark cell kind")}, + {"spike-source-cell", + make_call<>([]() { return arb::cell_kind::spike_source; }, "Spike source cell kind")}, + + // network_selection + {"all", make_call<>(arb::network_selection::all, "network selection of all cells and labels")}, + {"none", make_call<>(arb::network_selection::none, "network selection of no cells and labels")}, + {"inter-cell", + make_call<>(arb::network_selection::inter_cell, + "network selection of inter-cell connections only")}, + {"network-selection", + make_call(arb::network_selection::named, + "network selection with 1 argument: (value:string)")}, + {"intersect", + make_conversion_fold(arb::network_selection::intersect, + "intersection of network selections with at least 2 arguments: " + "(network_selection network_selection [...network_selection])")}, + {"join", + make_conversion_fold(arb::network_selection::join, + "join or union operation of network selections with at least 2 arguments: " + "(network_selection network_selection [...network_selection])")}, + {"symmetric-difference", + make_conversion_fold(arb::network_selection::symmetric_difference, + "symmetric difference operation between network selections with at least 2 arguments: " + "(network_selection network_selection [...network_selection])")}, + {"difference", + make_call( + arb::network_selection::difference, + "difference of first selection with the second one: " + "(network_selection network_selection)")}, + {"complement", + make_call(arb::network_selection::complement, + "complement of given selection argument: (network_selection)")}, + {"source-cell-kind", + make_call(arb::network_selection::source_cell_kind, + "all sources of cells matching given cell kind argument: (kind:cell-kind)")}, + {"destination-cell-kind", + make_call(arb::network_selection::destination_cell_kind, + "all destinations of cells matching given cell kind argument: (kind:cell-kind)")}, + {"source-label", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector labels; + std::transform( + vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::source_label(std::move(labels)); + }, + "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"destination-label", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector labels; + std::transform( + vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::destination_label(std::move(labels)); + }, + "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"source-cell", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector gids; + std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::source_cell(std::move(gids)); + }, + "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"source-cell", + make_call(static_cast( + arb::network_selection::source_cell), + "all sources in cell with gid range: (range:gid-range)")}, + {"destination-cell", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector gids; + std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::destination_cell(std::move(gids)); + }, + "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"destination-cell", + make_call(static_cast( + arb::network_selection::destination_cell), + "all destinations in cell with gid range: " + "(range:gid-range)")}, + {"chain", + make_arg_vec_call( + [](const std::vector>& vec) { + std::vector gids; + std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { + return std::get(x); + }); + return arb::network_selection::chain(std::move(gids)); + }, + "A chain of connections in the given order of gids in list, such that entry \"i\" is " + "the source and entry \"i+1\" the destination: (gid:integer) [...(gid:integer)]")}, + {"chain", + make_call( + static_cast(arb::network_selection::chain), + "A chain of connections for all gids in range [begin, end) with given step size. Each " + "entry \"i\" is connected as source to the destination \"i+1\": (begin:integer) " + "(end:integer) (step:integer)")}, + {"chain-reverse", + make_call(arb::network_selection::chain_reverse, + "A chain of connections for all gids in range [begin, end) with given step size. Each " + "entry \"i+1\" is connected as source to the destination \"i\". This results in " + "connection directions in reverse compared to the (chain-range ...) selection: " + "(begin:integer) " + "(end:integer) (step:integer)")}, + {"random", + make_call(arb::network_selection::random, + "randomly selected with given seed and probability. 2 arguments: (seed:integer, " + "p:real)")}, + {"random", + make_call(arb::network_selection::random, + "randomly selected with given seed and probability function. Any probability value is " + "clamped to [0.0, 1.0]. 2 arguments: (seed:integer, " + "p:network-value)")}, + {"distance-lt", + make_call(arb::network_selection::distance_lt, + "Select if distance between source and destination is less than given distance in " + "micro meter: (distance:real)")}, + {"distance-gt", + make_call(arb::network_selection::distance_gt, + "Select if distance between source and destination is greater than given distance in " + "micro meter: (distance:real)")}, + + // network_value + {"scalar", + make_call(arb::network_value::scalar, + "A fixed scalar value. 1 argument: (value:real)")}, + {"network-value", + make_call(arb::network_value::named, + "A named network value with 1 argument: (value:string)")}, + {"distance", + make_call(arb::network_value::distance, + "Distance between source and destination scaled by given value with unit [1/um]. 1 " + "argument: (scale:real)")}, + {"distance", + make_call<>([]() { return arb::network_value::distance(1.0); }, + "Distance between source and destination scaled by 1.0 with unit [1/um].")}, + {"uniform-distribution", + make_call( + [](unsigned seed, double begin, double end) { + return arb::network_value::uniform_distribution(seed, {begin, end}); + }, + "Uniform random distribution within interval [begin, end): (seed:integer, begin:real, " + "end:real)")}, + {"normal-distribution", + make_call(arb::network_value::normal_distribution, + "Normal random distribution with given mean and standard deviation: (seed:integer, " + "mean:real, std_deviation:real)")}, + {"truncated-normal-distribution", + make_call( + [](unsigned seed, double mean, double std_deviation, double begin, double end) { + return arb::network_value::truncated_normal_distribution( + seed, mean, std_deviation, {begin, end}); + }, + "Truncated normal random distribution with given mean and standard deviation within " + "interval [begin, end]: (seed:integer, mean:real, std_deviation:real, begin:real, " + "end:real)")}, + {"add", + make_conversion_fold( + arb::network_value::add, + "Sum of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"sub", + make_conversion_fold( + arb::network_value::sub, + "Subtraction of network values from the first argument with at least 2 arguments: " + "((network-value | double) (network-value | double) [...(network-value | double)])")}, + {"mul", + make_conversion_fold( + arb::network_value::mul, + "Multiplication of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"div", + make_conversion_fold( + arb::network_value::div, + "Division of the first argument by each following network value sequentially with at " + "least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"min", + make_conversion_fold( + arb::network_value::min, + "Minimum of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"max", + make_conversion_fold( + arb::network_value::max, + "Minimum of network values with at least 2 arguments: ((network-value | double) " + "(network-value | double) [...(network-value | double)])")}, + {"log", make_call(arb::network_value::log, "Logarithm. 1 argument: (value:real)")}, + {"log", + make_call(arb::network_value::log, "Logarithm. 1 argument: (value:real)")}, + {"exp", make_call(arb::network_value::exp, "Logarithm. 1 argument: (value:real)")}, + {"exp", + make_call(arb::network_value::exp, "Logarithm. 1 argument: (value:real)")}, +}; + +parse_network_hopefully eval(const s_expr& e, const eval_map_type& map); + +parse_network_hopefully> eval_args(const s_expr& e, + const eval_map_type& map) { + if (!e) return {std::vector{}}; // empty argument list + std::vector args; + for (auto& h: e) { + if (auto arg = eval(h, map)) { args.push_back(std::move(*arg)); } + else { return util::unexpected(std::move(arg.error())); } + } + return args; +} + +// Generate a string description of a function evaluation of the form: +// Example output: +// 'foo' with 1 argument: (real) +// 'bar' with 0 arguments +// 'cat' with 3 arguments: (locset region integer) +// Where 'foo', 'bar' and 'cat' are the name of the function, and the +// types (integer, real, region, locset) are inferred from the arguments. +std::string eval_description(const char* name, const std::vector& args) { + auto type_string = [](const std::type_info& t) -> const char* { + if (t == typeid(int)) return "integer"; + if (t == typeid(double)) return "real"; + if (t == typeid(arb::region)) return "region"; + if (t == typeid(arb::locset)) return "locset"; + return "unknown"; + }; + + const auto nargs = args.size(); + std::string msg = concat("'", name, "' with ", nargs, "argument", nargs != 1u ? "s:" : ":"); + if (nargs) { + msg += " ("; + bool first = true; + for (auto& a: args) { + msg += concat(first ? "" : " ", type_string(a.type())); + first = false; + } + msg += ")"; + } + return msg; +} + +// Evaluate an s expression. +// On success the result is wrapped in std::any, where the result is one of: +// int : an integer atom +// double : a real atom +// std::string : a string atom: to be treated as a label +// arb::region : a region +// arb::locset : a locset +// +// If there invalid input is detected, hopefully return value contains +// a label_error_state with an error string and location. +// +// If there was an unexpected/fatal error, an exception will be thrown. +parse_network_hopefully eval(const s_expr& e, const eval_map_type& map) { + if (e.is_atom()) { return eval_atom(e); } + if (e.head().is_atom()) { + // This must be a function evaluation, where head is the function name, and + // tail is a list of arguments. + + // Evaluate the arguments, and return error state if an error occurred. + auto args = eval_args(e.tail(), map); + if (!args) { return util::unexpected(args.error()); } + + // Find all candidate functions that match the name of the function. + auto& name = e.head().atom().spelling; + auto matches = map.equal_range(name); + // Search for a candidate that matches the argument list. + for (auto i = matches.first; i != matches.second; ++i) { + if (i->second.match_args(*args)) { // found a match: evaluate and return. + return i->second.eval(*args); + } + } + + // Unable to find a match: try to return a helpful error message. + const auto nc = std::distance(matches.first, matches.second); + auto msg = concat("No matches for ", + eval_description(name.c_str(), *args), + "\n There are ", + nc, + " potential candidates", + nc ? ":" : "."); + int count = 0; + for (auto i = matches.first; i != matches.second; ++i) { + msg += concat("\n Candidate ", ++count, " ", i->second.message); + } + return util::unexpected(network_parse_error(msg, location(e))); + } + + return util::unexpected(network_parse_error( + concat("'", e, "' is not either integer, real expression of the form (op )"), + location(e))); +} + +} // namespace + +parse_network_hopefully parse_network_selection_expression( + const std::string& s) { + if (auto e = eval(parse_s_expr(s), network_eval_map)) { + if (e->type() == typeid(arb::network_selection)) { + return {std::move(std::any_cast(*e))}; + } + return util::unexpected(network_parse_error(concat("Invalid iexpr description: '", s))); + } + else { return util::unexpected(network_parse_error(std::string() + e.error().what())); } +} + +parse_network_hopefully parse_network_value_expression(const std::string& s) { + if (auto e = eval(parse_s_expr(s), network_eval_map)) { + if (e->type() == typeid(arb::network_value)) { + return {std::move(std::any_cast(*e))}; + } + return util::unexpected(network_parse_error(concat("Invalid iexpr description: '", s))); + } + else { return util::unexpected(network_parse_error(std::string() + e.error().what())); } +} + +} // namespace arborio diff --git a/python/network.cpp b/python/network.cpp index 7331479c6f..c0184cca1f 100644 --- a/python/network.cpp +++ b/python/network.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp index bae6ad209a..15593575cd 100644 --- a/test/unit/test_s_expr.cpp +++ b/test/unit/test_s_expr.cpp @@ -10,9 +10,10 @@ #include #include -#include #include +#include #include +#include #include "parse_s_expr.hpp" #include "util/strprintf.hpp" From ce753852126d36600dc4a7e935978f845bda8508 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 2 May 2023 11:58:54 +0200 Subject: [PATCH 26/84] network_value operators --- arbor/include/arbor/network.hpp | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 6d9fcab003..91f810feae 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -117,6 +117,28 @@ class ARB_SYMBOL_VISIBLE network_value { std::shared_ptr impl_; }; +ARB_ARBOR_API inline network_value operator+(network_value a, network_value b) { + return network_value::add(std::move(a), std::move(b)); +} + +ARB_ARBOR_API inline network_value operator-(network_value a, network_value b) { + return network_value::sub(std::move(a), std::move(b)); +} + +ARB_ARBOR_API inline network_value operator*(network_value a, network_value b) { + return network_value::mul(std::move(a), std::move(b)); +} + +ARB_ARBOR_API inline network_value operator/(network_value a, network_value b) { + return network_value::div(std::move(a), std::move(b)); +} + +ARB_ARBOR_API inline network_value operator+(network_value a) { return a; } + +ARB_ARBOR_API inline network_value operator-(network_value a) { + return network_value::mul(-1.0, std::move(a)); +} + class ARB_SYMBOL_VISIBLE network_selection { public: using custom_func_type = @@ -226,4 +248,18 @@ struct network_description { network_label_dict dict; }; +ARB_ARBOR_API network_selection join(network_selection left, network_selection right); + +template +network_selection join(network_selection l, network_selection r, Args... args) { + return join(join(std::move(l), std::move(r)), std::move(args)...); +} + +ARB_ARBOR_API network_selection intersect(network_selection left, network_selection right); + +template +network_selection intersect(network_selection l, network_selection r, Args... args) { + return intersect(intersect(std::move(l), std::move(r)), std::move(args)...); +} + } // namespace arb From 2fec3f1f5767fe5c2e47204819c67fae781c10a5 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 2 May 2023 11:59:31 +0200 Subject: [PATCH 27/84] examples --- arbor/network.cpp | 8 + example/CMakeLists.txt | 1 + example/network_description/CMakeLists.txt | 4 + example/network_description/branch_cell.hpp | 130 +++++++ .../network_description.cpp | 337 ++++++++++++++++++ example/network_description/readme.md | 3 + python/example/network_description.py | 184 ++++++++++ 7 files changed, 667 insertions(+) create mode 100644 example/network_description/CMakeLists.txt create mode 100644 example/network_description/branch_cell.hpp create mode 100644 example/network_description/network_description.cpp create mode 100644 example/network_description/readme.md create mode 100755 python/example/network_description.py diff --git a/arbor/network.cpp b/arbor/network.cpp index acc5efd565..6c6518e250 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -1511,4 +1511,12 @@ ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_value& v) return os; } +ARB_ARBOR_API network_selection join(network_selection left, network_selection right) { + return network_selection::join(std::move(left), std::move(right)); +} + +ARB_ARBOR_API network_selection intersect(network_selection left, network_selection right) { + return network_selection::intersect(std::move(left), std::move(right)); +} + } // namespace arb diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index e97d41acab..8a32228d60 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -16,3 +16,4 @@ add_subdirectory(lfp) add_subdirectory(diffusion) add_subdirectory(v_clamp) add_subdirectory(ornstein_uhlenbeck) +add_subdirectory(network_description) diff --git a/example/network_description/CMakeLists.txt b/example/network_description/CMakeLists.txt new file mode 100644 index 0000000000..c4db25b297 --- /dev/null +++ b/example/network_description/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(network_description EXCLUDE_FROM_ALL network_description.cpp) +add_dependencies(examples network_description) + +target_link_libraries(network_description PRIVATE arbor arborio arborenv arbor-sup ${json_library_name}) diff --git a/example/network_description/branch_cell.hpp b/example/network_description/branch_cell.hpp new file mode 100644 index 0000000000..61740a9929 --- /dev/null +++ b/example/network_description/branch_cell.hpp @@ -0,0 +1,130 @@ +#pragma once + +#include +#include + +#include + +#include + +#include +#include +#include +#include + +#include +#include + +using namespace arborio::literals; + +// Parameters used to generate the random cell morphologies. +struct cell_parameters { + cell_parameters() = default; + + // Maximum number of levels in the cell (not including the soma) + unsigned max_depth = 5; + + // The following parameters are described as ranges. + // The first value is at the soma, and the last value is used on the last level. + // Values at levels in between are found by linear interpolation. + std::array branch_probs = {1.0, 0.5}; // Probability of a branch occuring. + std::array compartments = {20, 2}; // Compartment count on a branch. + std::array lengths = {200, 20}; // Length of branch in μm. + + // The number of synapses per cell. + unsigned synapses = 1; +}; + +cell_parameters parse_cell_parameters(nlohmann::json& json) { + cell_parameters params; + sup::param_from_json(params.max_depth, "depth", json); + sup::param_from_json(params.branch_probs, "branch-probs", json); + sup::param_from_json(params.compartments, "compartments", json); + sup::param_from_json(params.lengths, "lengths", json); + sup::param_from_json(params.synapses, "synapses", json); + + return params; +} + +// Helper used to interpolate in branch_cell. +template +double interp(const std::array& r, unsigned i, unsigned n) { + double p = i * 1./(n-1); + double r0 = r[0]; + double r1 = r[1]; + return r[0] + p*(r1-r0); +} + +arb::cable_cell branch_cell(arb::cell_gid_type gid, const cell_parameters& params) { + arb::segment_tree tree; + + // Add soma. + double srad = 12.6157/2.0; // soma radius + int stag = 1; // soma tag + tree.append(arb::mnpos, {0, 0,-srad, srad}, {0, 0, srad, srad}, stag); // For area of 500 μm². + + std::vector> levels; + levels.push_back({0}); + + // Standard mersenne_twister_engine seeded with gid. + std::mt19937 gen(gid); + std::uniform_real_distribution dis(0, 1); + + double drad = 0.5; // Diameter of 1 μm for each dendrite cable. + int dtag = 3; // Dendrite tag. + + double dist_from_soma = srad; // Start dendrite at the edge of the soma. + for (unsigned i=0; i sec_ids; + for (unsigned sec: levels[i]) { + for (unsigned j=0; j<2; ++j) { + if (dis(gen) 1) { + decor.place(arb::ls::uniform("dend"_lab, 0, params.synapses - 2, gid), arb::synapse("expsyn"), "extra_syns"); + } + + // Make a CV between every sample in the sample tree. + decor.set_default(arb::cv_policy_every_segment()); + + arb::cable_cell cell(arb::morphology(tree), decor, labels); + + return cell; +} diff --git a/example/network_description/network_description.cpp b/example/network_description/network_description.cpp new file mode 100644 index 0000000000..8e03fd0736 --- /dev/null +++ b/example/network_description/network_description.cpp @@ -0,0 +1,337 @@ +/* + * A miniapp that demonstrates how to use network expressions + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include "arbor/network.hpp" +#include "branch_cell.hpp" + +#ifdef ARB_MPI_ENABLED +#include +#include +#endif + +struct ring_params { + ring_params() = default; + + std::string name = "default"; + unsigned num_cells = 100; + double min_delay = 10; + double duration = 1000; + cell_parameters cell; +}; + +ring_params read_options(int argc, char** argv); +using arb::cell_gid_type; +using arb::cell_lid_type; +using arb::cell_size_type; +using arb::cell_member_type; +using arb::cell_kind; +using arb::time_type; + +// Writes voltage trace as a json file. +void write_trace_json(const arb::trace_data& trace); + +// Generate a cell. +arb::cable_cell branch_cell(arb::cell_gid_type gid, const cell_parameters& params); + +class ring_recipe: public arb::recipe { +public: + ring_recipe(unsigned num_cells, cell_parameters params, unsigned min_delay): + num_cells_(num_cells), + cell_params_(params), + min_delay_(min_delay) + { + gprop_.default_parameters = arb::neuron_parameter_defaults; + } + + cell_size_type num_cells() const override { + return num_cells_; + } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + return branch_cell(gid, cell_params_); + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable; + } + + arb::isometry get_cell_isometry(cell_gid_type gid) const override { + // place cells with equal distance on a circle + const double angle = 2 * 3.1415926535897932 * gid / num_cells_; + const double radius = 500.0; + return arb::isometry::translate(radius * std::cos(angle), radius * std::sin(angle), 0.0); + }; + + std::optional network_description() const override { + // create a chain + auto ring = arb::network_selection::chain(arb::gid_range(0, num_cells_)); + // connect front and back of chain to form ring + ring = arb::join(ring, + arb::intersect(arb::network_selection::source_cell({num_cells_ - 1}), + arb::network_selection::destination_cell({0}))); + + // Create random connections with probability proportional to the inverse distance within a + // radius + const double max_dist = 400.0; + auto probability = (max_dist - arb::network_value::distance()) / max_dist; + + // restrict to inter-cell connections and to distance within radius + auto seed = 42; + auto rand = intersect(arb::network_selection::random(seed, probability), + arb::network_selection::distance_lt(max_dist), + arb::network_selection::inter_cell()); + + // combine ring with random selection + auto s = join(ring, rand); + + // restrict to certain source and destination labels + s = arb::intersect(s, + arb::network_selection::source_label({"detector"}), + arb::network_selection::destination_label({"primary_syn"})); + + // normal distributed weight with mean 0.05 μS, standard deviation 0.02 μS + // and truncated to [0.025, 0.075] + auto w = "(truncated-normal-distribution 42 0.05 0.02 0.025 0.075)"_nv; + // note: We are using s-expressions here as an alternative for creating a network_value + + return arb::network_description{s, w, min_delay_, {}}; + }; + + // Return one event generator on gid 0. This generates a single event that will + // kick start the spiking. + std::vector event_generators(cell_gid_type gid) const override { + std::vector gens; + if (!gid) { + gens.push_back(arb::explicit_generator({"primary_syn"}, event_weight_, std::vector{1.0f})); + } + return gens; + } + + std::vector get_probes(cell_gid_type gid) const override { + // Measure membrane voltage at end of soma. + arb::mlocation loc{0, 0.0}; + return {arb::cable_probe_membrane_voltage{loc}}; + } + + std::any get_global_properties(arb::cell_kind) const override { + return gprop_; + } + + + +private: + cell_size_type num_cells_; + cell_parameters cell_params_; + double min_delay_; + float event_weight_ = 0.05; + arb::cable_cell_global_properties gprop_; +}; + +int main(int argc, char** argv) { + try { + bool root = true; + + arb::proc_allocation resources; + resources.num_threads = arbenv::default_concurrency(); + +#ifdef ARB_MPI_ENABLED + arbenv::with_mpi guard(argc, argv, false); + resources.gpu_id = arbenv::find_private_gpu(MPI_COMM_WORLD); + auto context = arb::make_context(resources, MPI_COMM_WORLD); + root = arb::rank(context) == 0; +#else + resources.gpu_id = arbenv::default_gpu(); + auto context = arb::make_context(resources); +#endif + +#ifdef ARB_PROFILE_ENABLED + arb::profile::profiler_initialize(context); +#endif + + std::cout << sup::mask_stream(root); + + // Print a banner with information about hardware configuration + std::cout << "gpu: " << (has_gpu(context)? "yes": "no") << "\n"; + std::cout << "threads: " << num_threads(context) << "\n"; + std::cout << "mpi: " << (has_mpi(context)? "yes": "no") << "\n"; + std::cout << "ranks: " << num_ranks(context) << "\n" << std::endl; + + auto params = read_options(argc, argv); + + arb::profile::meter_manager meters; + meters.start(context); + + // Create an instance of our recipe. + ring_recipe recipe(params.num_cells, params.cell, params.min_delay); + + // Construct the model. + auto decomposition = arb::partition_load_balance(recipe, context); + arb::simulation sim(recipe, context, decomposition); + + // Set up the probe that will measure voltage in the cell. + + // The id of the only probe on the cell: the cell_member type points to (cell 0, probe 0) + auto probeset_id = cell_member_type{0, 0}; + // The schedule for sampling is 10 samples every 1 ms. + auto sched = arb::regular_schedule(1); + // This is where the voltage samples will be stored as (time, value) pairs + arb::trace_vector voltage; + // Now attach the sampler at probeset_id, with sampling schedule sched, writing to voltage + sim.add_sampler(arb::one_probe(probeset_id), sched, arb::make_simple_sampler(voltage)); + + // Set up recording of spikes to a vector on the root process. + std::vector recorded_spikes; + if (root) { + sim.set_global_spike_callback( + [&recorded_spikes](const std::vector& spikes) { + recorded_spikes.insert(recorded_spikes.end(), spikes.begin(), spikes.end()); + }); + } + + meters.checkpoint("model-init", context); + + if (root) { + sim.set_epoch_callback(arb::epoch_progress_bar()); + } + std::cout << "running simulation\n" << std::endl; + // Run the simulation for 100 ms, with time steps of 0.025 ms. + sim.run(params.duration, 0.025); + + meters.checkpoint("model-run", context); + + auto ns = sim.num_spikes(); + + // Write spikes to file + if (root) { + std::cout << "\n" << ns << " spikes generated at rate of " + << params.duration/ns << " ms between spikes\n"; + std::ofstream fid("spikes.gdf"); + if (!fid.good()) { + std::cerr << "Warning: unable to open file spikes.gdf for spike output\n"; + } + else { + char linebuf[45]; + for (auto spike: recorded_spikes) { + auto n = std::snprintf( + linebuf, sizeof(linebuf), "%u %.4f\n", + unsigned{spike.source.gid}, float(spike.time)); + fid.write(linebuf, n); + } + } + } + + // Write the samples to a json file. + if (root) { + write_trace_json(voltage.at(0)); + } + + auto profile = arb::profile::profiler_summary(); + std::cout << profile << "\n"; + + auto report = arb::profile::make_meter_report(meters, context); + std::cout << report; + } + catch (std::exception& e) { + std::cerr << "exception caught in ring miniapp: " << e.what() << "\n"; + return 1; + } + + return 0; +} + +void write_trace_json(const arb::trace_data& trace) { + std::string path = "./voltages.json"; + + nlohmann::json json; + json["name"] = "ring demo"; + json["units"] = "mV"; + json["cell"] = "0.0"; + json["probe"] = "0"; + + auto& jt = json["data"]["time"]; + auto& jy = json["data"]["voltage"]; + + for (const auto& sample: trace) { + jt.push_back(sample.t); + jy.push_back(sample.v); + } + + std::ofstream file(path); + file << std::setw(1) << json << "\n"; +} + +ring_params read_options(int argc, char** argv) { + using sup::param_from_json; + + ring_params params; + if (argc<2) { + std::cout << "Using default parameters.\n"; + return params; + } + if (argc>2) { + throw std::runtime_error("More than one command line option is not permitted."); + } + + std::string fname = argv[1]; + std::cout << "Loading parameters from file: " << fname << "\n"; + std::ifstream f(fname); + + if (!f.good()) { + throw std::runtime_error("Unable to open input parameter file: "+fname); + } + + nlohmann::json json; + f >> json; + + param_from_json(params.name, "name", json); + param_from_json(params.num_cells, "num-cells", json); + param_from_json(params.duration, "duration", json); + param_from_json(params.min_delay, "min-delay", json); + params.cell = parse_cell_parameters(json); + + if (!json.empty()) { + for (auto it=json.begin(); it!=json.end(); ++it) { + std::cout << " Warning: unused input parameter: \"" << it.key() << "\"\n"; + } + std::cout << "\n"; + } + + return params; +} + diff --git a/example/network_description/readme.md b/example/network_description/readme.md new file mode 100644 index 0000000000..fb9950a54d --- /dev/null +++ b/example/network_description/readme.md @@ -0,0 +1,3 @@ +# Ring Example + +A miniapp that demonstrates how to describe how to build a simple ring network. diff --git a/python/example/network_description.py b/python/example/network_description.py new file mode 100755 index 0000000000..32d9d1960f --- /dev/null +++ b/python/example/network_description.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# This script is included in documentation. Adapt line numbers if touched. + +import arbor +import pandas # You may have to pip install these +import seaborn # You may have to pip install these +from math import sqrt +import math + +# Construct a cell with the following morphology. +# The soma (at the root of the tree) is marked 's', and +# the end of each branch i is marked 'bi'. +# +# b1 +# / +# s----b0 +# \ +# b2 + + +def make_cable_cell(gid): + # (1) Build a segment tree + tree = arbor.segment_tree() + + # Soma (tag=1) with radius 6 μm, modelled as cylinder of length 2*radius + s = tree.append( + arbor.mnpos, arbor.mpoint(-12, 0, 0, 6), arbor.mpoint(0, 0, 0, 6), tag=1 + ) + + # (b0) Single dendrite (tag=3) of length 50 μm and radius 2 μm attached to soma. + b0 = tree.append(s, arbor.mpoint(0, 0, 0, 2), arbor.mpoint(0, 0, 50, 2), tag=3) + + # Attach two dendrites (tag=3) of length 50 μm to the end of the first dendrite. + # (b1) Radius tapers from 2 to 0.5 μm over the length of the dendrite. + tree.append( + b0, + arbor.mpoint(0, 0, 50, 2), + arbor.mpoint(0, 50 / sqrt(2), 50 + 50 / sqrt(2) , 0.5), + tag=3, + ) + # (b2) Constant radius of 1 μm over the length of the dendrite. + tree.append( + b0, + arbor.mpoint(0, 0, 50, 1), + arbor.mpoint(0, -50 / sqrt(2), 50 + 50 / sqrt(2), 1), + tag=3, + ) + + # Associate labels to tags + labels = arbor.label_dict( + { + "soma": "(tag 1)", + "dend": "(tag 3)", + # (2) Mark location for synapse at the midpoint of branch 1 (the first dendrite). + "synapse_site": "(location 1 0.5)", + # Mark the root of the tree. + "root": "(root)", + } + ) + + # (3) Create a decor and a cable_cell + decor = ( + arbor.decor() + # Put hh dynamics on soma, and passive properties on the dendrites. + .paint('"soma"', arbor.density("hh")).paint('"dend"', arbor.density("pas")) + # (4) Attach a single synapse. + .place('"synapse_site"', arbor.synapse("expsyn"), "syn") + # Attach a detector with threshold of -10 mV. + .place('"root"', arbor.threshold_detector(-10), "detector") + ) + + return arbor.cable_cell(tree, decor, labels) + + +# (5) Create a recipe that generates a network of connected cells. +class random_ring_recipe(arbor.recipe): + def __init__(self, ncells): + # The base C++ class constructor must be called first, to ensure that + # all memory in the C++ class is initialized correctly. + arbor.recipe.__init__(self) + self.ncells = ncells + self.props = arbor.neuron_cable_properties() + + # (6) The num_cells method that returns the total number of cells in the model + # must be implemented. + def num_cells(self): + return self.ncells + + # (7) The cell_description method returns a cell + def cell_description(self, gid): + return make_cable_cell(gid) + + # The kind method returns the type of cell with gid. + # Note: this must agree with the type returned by cell_description. + def cell_kind(self, gid): + return arbor.cell_kind.cable + + def cell_isometry(self, gid): + # place cells with equal distance on a circle + radius = 500.0 # μm + angle = 2.0 * math.pi * gid / self.ncells + return arbor.isometry.translate(radius * math.cos(angle), radius * math.sin(angle), 0) + + def network_description(self): + seed = 42 + + # create a chain + ring = f"(chain (gid-range 0 {self.ncells}))" + # connect front and back of chain to form ring + ring = f"(join {ring} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" + + # Create random connections with probability proportional to the inverse distance within a + # radius + max_dist = 400.0 # μm + probability = f"(div (sub {max_dist} (distance)) {max_dist})" + rand = f"(intersect (random {seed} {probability}) (distance-lt {max_dist}))" + + # combine ring with random selection + s = f"(join {ring} {rand})" + # restrict to inter-cell connections and certain source / destination labels + s = f"(intersect {s} (inter-cell) (source-label \"detector\") (destination-label \"syn\"))" + + # normal distributed weight with mean 0.02 μS, standard deviation 0.01 μS + # and truncated to [0.005, 0.035] + w = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" + # fixed delay + d = "(scalar 5.0)" # ms delay + + return arbor.network_description(s, w, d, {}) + + # (9) Attach a generator to the first cell in the ring. + def event_generators(self, gid): + if gid == 0: + sched = arbor.explicit_schedule([1]) # one event at 1 ms + weight = 0.1 # 0.1 μS on expsyn + return [arbor.event_generator("syn", weight, sched)] + return [] + + # (10) Place a probe at the root of each cell. + def probes(self, gid): + return [arbor.cable_probe_membrane_voltage('"root"')] + + def global_properties(self, kind): + return self.props + + +# (11) Instantiate recipe +ncells = 4 +recipe = random_ring_recipe(ncells) + +# (12) Create an execution context using all locally available threads and simulation +ctx = arbor.context("avail_threads") +sim = arbor.simulation(recipe, ctx) + +# (13) Set spike generators to record +sim.record(arbor.spike_recording.all) + +# (14) Attach a sampler to the voltage probe on cell 0. Sample rate of 10 sample every ms. +handles = [sim.sample((gid, 0), arbor.regular_schedule(0.1)) for gid in range(ncells)] + +# (15) Run simulation for 100 ms +sim.run(100) +print("Simulation finished") + +# (16) Print spike times +print("spikes:") +for sp in sim.spikes(): + print(" ", sp) + +# (17) Plot the recorded voltages over time. +print("Plotting results ...") +df_list = [] +for gid in range(ncells): + samples, meta = sim.samples(handles[gid])[0] + df_list.append( + pandas.DataFrame( + {"t/ms": samples[:, 0], "U/mV": samples[:, 1], "Cell": f"cell {gid}"} + ) + ) + +df = pandas.concat(df_list, ignore_index=True) +seaborn.relplot( + data=df, kind="line", x="t/ms", y="U/mV", hue="Cell", errorbar=None +).savefig("network_ring_result.svg") From 2eb1c9047a598ec87cb913c665831df7792c4063 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 2 May 2023 12:00:11 +0200 Subject: [PATCH 28/84] doc update --- example/network_description/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/network_description/readme.md b/example/network_description/readme.md index fb9950a54d..97131ae1a6 100644 --- a/example/network_description/readme.md +++ b/example/network_description/readme.md @@ -1,3 +1,3 @@ # Ring Example -A miniapp that demonstrates how to describe how to build a simple ring network. +A miniapp that demonstrates how to describe how to build a simple ring network with random interconnection using the network description language. From edb90bf348440bc623237004358ed8380b9c6370 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Thu, 11 May 2023 16:49:40 +0200 Subject: [PATCH 29/84] fix mpi compilation --- arbor/communication/mpi.hpp | 6 +++--- arbor/communication/mpi_context.cpp | 5 +++++ example/network_description/network_description.cpp | 2 +- python/example/network_description.py | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/arbor/communication/mpi.hpp b/arbor/communication/mpi.hpp index a9fbb52cc3..9e4d105a66 100644 --- a/arbor/communication/mpi.hpp +++ b/arbor/communication/mpi.hpp @@ -321,7 +321,7 @@ T broadcast(int root, MPI_Comm comm) { return value; } -std::vector isend(std::size_t num_bytes, +inline std::vector isend(std::size_t num_bytes, const void* data, int dest, int tag, @@ -345,7 +345,7 @@ std::vector isend(std::size_t num_bytes, return requests; } -std::vector irecv(std::size_t num_bytes, +inline std::vector irecv(std::size_t num_bytes, void* data, int source, int tag, @@ -369,7 +369,7 @@ std::vector irecv(std::size_t num_bytes, return requests; } -void wait_all(std::vector requests) { +inline void wait_all(std::vector requests) { if(!requests.empty()) { MPI_OR_THROW( MPI_Waitall, static_cast(requests.size()), requests.data(), MPI_STATUSES_IGNORE); diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index 9de7d37679..15d75edc3b 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -101,6 +101,11 @@ struct mpi_context_impl { struct mpi_send_recv_request : public distributed_request::distributed_request_interface { std::vector recv_requests, send_requests; + mpi_send_recv_request(std::vector recv_requests, + std::vector send_requests): + recv_requests(std::move(recv_requests)), + send_requests(std::move(send_requests)) {} + void finalize() override { if (!recv_requests.empty()) { mpi::wait_all(std::move(recv_requests)); diff --git a/example/network_description/network_description.cpp b/example/network_description/network_description.cpp index 8e03fd0736..b780dc6a21 100644 --- a/example/network_description/network_description.cpp +++ b/example/network_description/network_description.cpp @@ -106,7 +106,7 @@ class ring_recipe: public arb::recipe { arb::intersect(arb::network_selection::source_cell({num_cells_ - 1}), arb::network_selection::destination_cell({0}))); - // Create random connections with probability proportional to the inverse distance within a + // Create random connections with probability inversely proportional to the distance within a // radius const double max_dist = 400.0; auto probability = (max_dist - arb::network_value::distance()) / max_dist; diff --git a/python/example/network_description.py b/python/example/network_description.py index 32d9d1960f..a166cbffd4 100755 --- a/python/example/network_description.py +++ b/python/example/network_description.py @@ -109,7 +109,7 @@ def network_description(self): # connect front and back of chain to form ring ring = f"(join {ring} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" - # Create random connections with probability proportional to the inverse distance within a + # Create random connections with probability inversely proportional to the distance within a # radius max_dist = 400.0 # μm probability = f"(div (sub {max_dist} (distance)) {max_dist})" From 2ab0d89672990e176a0d9001865734a5f2b1ceea Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 15 May 2023 15:37:08 +0200 Subject: [PATCH 30/84] Add if-else network value --- arbor/include/arbor/network.hpp | 8 +++++++ arbor/network.cpp | 41 +++++++++++++++++++++++++++++++++ arborio/networkio.cpp | 5 ++++ test/unit/test_network.cpp | 15 ++++++++++++ test/unit/test_s_expr.cpp | 1 + 5 files changed, 70 insertions(+) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 91f810feae..18f4539a6c 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -48,6 +48,8 @@ struct network_value_impl; class ARB_SYMBOL_VISIBLE network_label_dict; +class ARB_SYMBOL_VISIBLE network_selection; + class ARB_SYMBOL_VISIBLE network_value { public: using custom_func_type = @@ -106,6 +108,10 @@ class ARB_SYMBOL_VISIBLE network_value { static network_value max(network_value left, network_value right); + static network_value if_else(network_selection cond, + network_value true_value, + network_value false_value); + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_value& v); private: @@ -215,6 +221,8 @@ class ARB_SYMBOL_VISIBLE network_selection { friend std::shared_ptr thingify(network_selection s, const network_label_dict& dict); + friend class network_value; + std::shared_ptr impl_; }; diff --git a/arbor/network.cpp b/arbor/network.cpp index 6c6518e250..c6a07b7582 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -1264,6 +1264,40 @@ struct network_value_log_impl: public network_value_impl { } }; +struct network_value_if_else_impl: public network_value_impl { + std::shared_ptr cond; + std::shared_ptr true_value; + std::shared_ptr false_value; + + network_value_if_else_impl(std::shared_ptr cond, + std::shared_ptr true_value, + std::shared_ptr false_value): + cond(std::move(cond)), + true_value(std::move(true_value)), + false_value(std::move(false_value)) {} + + double get(const network_site_info& src, const network_site_info& dest) const override { + if (cond->select_connection(src, dest)) return true_value->get(src, dest); + return false_value->get(src, dest); + } + + void initialize(const network_label_dict& dict) override { + cond->initialize(dict); + true_value->initialize(dict); + false_value->initialize(dict); + }; + + void print(std::ostream& os) const override { + os << "(if-else "; + cond->print(os); + os << " "; + true_value->print(os); + os << " "; + false_value->print(os); + os << ")"; + } +}; + } // namespace network_site_info::network_site_info(cell_gid_type gid, @@ -1487,6 +1521,13 @@ network_value network_value::max(network_value left, network_value right) { std::make_shared(std::move(left.impl_), std::move(right.impl_))); } +network_value network_value::if_else(network_selection cond, + network_value true_value, + network_value false_value) { + return network_value(std::make_shared( + std::move(cond.impl_), std::move(true_value.impl_), std::move(false_value.impl_))); +} + std::optional network_label_dict::selection(const std::string& name) const { auto it = selections_.find(name); if (it != selections_.end()) return it->second; diff --git a/arborio/networkio.cpp b/arborio/networkio.cpp index 64a5c9376d..cd38dd4ac4 100644 --- a/arborio/networkio.cpp +++ b/arborio/networkio.cpp @@ -199,6 +199,11 @@ eval_map_type network_eval_map{ "Truncated normal random distribution with given mean and standard deviation within " "interval [begin, end]: (seed:integer, mean:real, std_deviation:real, begin:real, " "end:real)")}, + {"if-else", + make_call(arb::network_value::if_else, + "Returns the first network-value if a connection is the given network-selection and " + "the second network-value otherwise. 3 arguments: (sel:network-selection, " + "true_value:network-value, false_value:network_value)")}, {"add", make_conversion_fold( arb::network_value::add, diff --git a/test/unit/test_network.cpp b/test/unit/test_network.cpp index 148c2850cf..4485852981 100644 --- a/test/unit/test_network.cpp +++ b/test/unit/test_network.cpp @@ -803,3 +803,18 @@ TEST(network_value, max) { } } } + +TEST(network_value, if_else) { + const auto v1 = network_value::scalar(2.0); + const auto v2 = network_value::scalar(3.0); + + const auto s = network_selection::inter_cell(); + + const auto v = thingify(network_value::if_else(s, v1, v2), network_label_dict()); + + for (const auto& src: test_sites) { + for (const auto& dest: test_sites) { + EXPECT_DOUBLE_EQ(v->get(src, dest), src.gid != dest.gid ? 2.0 : 3.0); + } + } +} diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp index 15593575cd..dbbfbc0eb9 100644 --- a/test/unit/test_s_expr.cpp +++ b/test/unit/test_s_expr.cpp @@ -415,6 +415,7 @@ TEST(network_value, round_tripping) { "(truncated-normal-distribution 42 0.5 0.1 0.3 0.7)", "(log (scalar 1.3))", "(exp (scalar 1.3))", + "(if-else (inter-cell) (scalar 5.1) (log (scalar 1.3)))", }; for (auto l: network_literals) { From 75fa66ab983598c1c7f5e064cde4a062e9d030b5 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Thu, 18 May 2023 19:21:45 +0200 Subject: [PATCH 31/84] generalized ring exchange --- arbor/communication/distributed_for_each.hpp | 158 +++++++++++++++++++ arbor/network_generation.cpp | 136 +++++----------- 2 files changed, 198 insertions(+), 96 deletions(-) create mode 100644 arbor/communication/distributed_for_each.hpp diff --git a/arbor/communication/distributed_for_each.hpp b/arbor/communication/distributed_for_each.hpp new file mode 100644 index 0000000000..2db82c9453 --- /dev/null +++ b/arbor/communication/distributed_for_each.hpp @@ -0,0 +1,158 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "distributed_context.hpp" +#include "util/range.hpp" + +namespace arb { + +namespace impl { +template +void for_each_in_tuple(FUNC&& func, std::tuple& t, std::index_sequence) { + (func(Is, std::get(t)), ...); +} + +template +void for_each_in_tuple(FUNC&& func, std::tuple& t) { + for_each_in_tuple(func, t, std::index_sequence_for()); +} + +template +void for_each_in_tuple_pair(FUNC&& func, + std::tuple& t1, + std::tuple& t2, + std::index_sequence) { + (func(Is, std::get(t1), std::get(t2)), ...); +} + +template +void for_each_in_tuple_pair(FUNC&& func, std::tuple& t1, std::tuple& t2) { + for_each_in_tuple_pair(func, t1, t2, std::index_sequence_for()); +} + +} // namespace impl + +template +void distributed_for_each(FUNC&& func, + const distributed_context& distributed, + const std::vector&... args) { + + static_assert(sizeof...(args) > 0); + auto arg_tuple = std::forward_as_tuple(args...); + + struct vec_info { + std::size_t offset; // offset in bytes + std::size_t size; // size in bytes + }; + + std::array info; + std::size_t buffer_size = 0; + + // Compute offsets in bytes for each vector when placed in common buffer + { + std::size_t offset = info.size() * sizeof(vec_info); + impl::for_each_in_tuple( + [&](std::size_t i, auto&& vec) { + using T = typename std::remove_reference_t::value_type; + static_assert(std::is_trivially_copyable_v); + static_assert(alignof(std::max_align_t) >= alignof(T)); + static_assert(alignof(std::max_align_t) % alignof(T) == 0); + + // make sure alignment of offset fulfills requirement + const auto alignment_excess = offset % alignof(T); + offset += alignment_excess > 0 ? alignof(T) - (alignment_excess) : 0; + + const auto size_in_bytes = vec.size() * sizeof(T); + + info[i].size = size_in_bytes; + info[i].offset = offset; + + buffer_size = offset + size_in_bytes; + offset += size_in_bytes; + }, + arg_tuple); + } + + // compute maximum buffer size between ranks, such that we only allocate once + const std::size_t max_buffer_size = distributed.max(buffer_size); + + // exit if all vectors on all ranks are empty + if (max_buffer_size == info.size() * sizeof(vec_info)) return; + + // use malloc for std::max_align_t alignment + auto deleter = [](char* ptr) { std::free(ptr); }; + std::unique_ptr buffer((char*)std::malloc(max_buffer_size), deleter); + std::unique_ptr recv_buffer( + (char*)std::malloc(max_buffer_size), deleter); + + // copy offset and size info to front of buffer + std::memcpy(buffer.get(), info.data(), info.size() * sizeof(vec_info)); + + // copy each vector to each location in buffer + impl::for_each_in_tuple( + [&](std::size_t i, auto&& vec) { + using T = typename std::remove_reference_t::value_type; + std::memcpy(buffer.get() + info[i].offset, vec.data(), vec.size() * sizeof(T)); + }, + arg_tuple); + + std::tuple...> ranges; + + const auto my_rank = distributed.id(); + const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1; + const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1; + + // exchange buffer in ring pattern and apply function at each step + for (std::size_t step = 0; step < distributed.size() - 1; ++step) { + // always expect to recieve the max size but send actual size. MPI_recv only expects a max + // size, not the actual size. + const auto current_info = (const vec_info*)buffer.get(); + + auto request = distributed.send_recv_nonblocking(max_buffer_size, + recv_buffer.get(), + right_rank, + current_info[info.size() - 1].offset + current_info[info.size() - 1].size, + buffer.get(), + left_rank, + 0); + + // update ranges + impl::for_each_in_tuple_pair( + [&](std::size_t i, auto&& vec, auto&& r) { + using T = typename std::remove_reference_t::value_type; + r = util::range((T*)(buffer.get() + current_info[i].offset), + (T*)(buffer.get() + current_info[i].offset + current_info[i].size)); + }, + arg_tuple, + ranges); + + // call provided function with ranges pointing to current buffer + std::apply(func, ranges); + + request.finalize(); + buffer.swap(recv_buffer); + } + + // final step does not require any exchange + const auto current_info = (const vec_info*)buffer.get(); + impl::for_each_in_tuple_pair( + [&](std::size_t i, auto&& vec, auto&& r) { + using T = typename std::remove_reference_t::value_type; + r = util::range((T*)(buffer.get() + current_info[i].offset), + (T*)(buffer.get() + current_info[i].offset + current_info[i].size)); + }, + arg_tuple, + ranges); + + // call provided function with ranges pointing to current buffer + std::apply(func, ranges); +} + +} // namespace arb diff --git a/arbor/network_generation.cpp b/arbor/network_generation.cpp index 8af44f2382..c8ae1842f7 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_generation.cpp @@ -1,6 +1,8 @@ #include "network_generation.hpp" #include "cell_group_factory.hpp" +#include "communication/distributed_for_each.hpp" #include "network_impl.hpp" +#include "util/range.hpp" #include "util/spatial_tree.hpp" #include @@ -39,18 +41,18 @@ struct distributed_site_info { struct site_mapping { std::vector sites; - std::string labels; - std::unordered_map label_map; + std::vector labels; + std::unordered_map label_map; site_mapping() = default; inline std::size_t size() const { return sites.size(); } void insert(const network_site_info& s) { - const auto insert_pair = label_map.insert({s.label, labels.size()}); + const auto insert_pair = label_map.insert({std::string(s.label), labels.size()}); // append label if not contained in labels if (insert_pair.second) { - labels.append(s.label); + labels.insert(labels.end(), s.label.begin(), s.label.end()); labels.push_back('\0'); } sites.emplace_back(distributed_site_info{s.gid, @@ -69,7 +71,7 @@ struct site_mapping { info.gid = s.gid; info.lid = s.lid; info.kind = s.kind; - info.label = labels.c_str() + s.label_start_idx; + info.label = labels.data() + s.label_start_idx; info.location = s.location; info.global_location = s.global_location; info.hash = s.hash; @@ -78,74 +80,6 @@ struct site_mapping { } }; -struct distributed_site_mapping { - const distributed_context& distributed; - std::vector num_sites_per_rank, label_string_size_per_rank; - site_mapping mapping, recv_mapping; - - explicit distributed_site_mapping(const distributed_context& distributed, site_mapping m): - distributed(distributed), - mapping(std::move(m)) { - mapping.label_map.clear(); // no longer valid after first exchange - - num_sites_per_rank = distributed.gather_all(mapping.sites.size()); - label_string_size_per_rank = distributed.gather_all(mapping.labels.size()); - - const auto max_num_sites = - *std::max_element(num_sites_per_rank.begin(), num_sites_per_rank.end()); - const auto max_string_size = - *std::max_element(label_string_size_per_rank.begin(), label_string_size_per_rank.end()); - - mapping.sites.resize(max_num_sites); - mapping.labels.resize(max_string_size); - recv_mapping.sites.resize(max_num_sites); - recv_mapping.labels.resize(max_string_size); - } - - template - void for_each_site(const FUNC& f) { - const auto my_rank = distributed.id(); - const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1; - const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1; - - auto current_idx = my_rank; - for (std::size_t step = 0; step < distributed.size() - 1; ++step) { - const auto next_idx = (current_idx + 1) % distributed.size(); - auto request_sites = distributed.send_recv_nonblocking(num_sites_per_rank[next_idx], - recv_mapping.sites.data(), - right_rank, - num_sites_per_rank[current_idx], - mapping.sites.data(), - left_rank, - 0); - - auto request_labels = - distributed.send_recv_nonblocking(label_string_size_per_rank[next_idx], - recv_mapping.labels.data(), - right_rank, - label_string_size_per_rank[current_idx], - mapping.labels.data(), - left_rank, - 1); - - for (std::size_t site_idx = 0; site_idx < num_sites_per_rank[current_idx]; ++site_idx) { - f(mapping.get_site(site_idx)); - } - - request_sites.finalize(); - request_labels.finalize(); - - std::swap(mapping, recv_mapping); - - current_idx = next_idx; - } - - for (std::size_t site_idx = 0; site_idx < num_sites_per_rank[current_idx]; ++site_idx) { - f(mapping.get_site(site_idx)); - } - } -}; - } // namespace std::vector generate_network_connections(const recipe& rec, @@ -289,33 +223,43 @@ std::vector generate_network_connections(const recipe& rec, // select connections std::vector connections; - auto sample_destinations = [&](const network_site_info& src) { - auto sample = [&](const network_site_info& dest) { - if (selection.select_connection(src, dest)) { - connections.emplace_back(connection({src.gid, src.lid}, - {dest.gid, dest.lid}, - weight.get(src, dest), - delay.get(src, dest))); + auto sample_sources = [&](const util::range& source_range, + const util::range& label_range) { + for (const auto& s: source_range) { + network_site_info src; + src.gid = s.gid; + src.lid = s.lid; + src.kind = s.kind; + src.label = label_range.data() + s.label_start_idx; + src.location = s.location; + src.global_location = s.global_location; + src.hash = s.hash; + + auto sample = [&](const network_site_info& dest) { + if (selection.select_connection(src, dest)) { + connections.emplace_back(connection({src.gid, src.lid}, + {dest.gid, dest.lid}, + weight.get(src, dest), + delay.get(src, dest))); + } + }; + + if (selection.max_distance().has_value()) { + const double d = selection.max_distance().value(); + local_dest_tree.bounding_box_for_each( + decltype(local_dest_tree)::point_type{src.global_location.x - d, + src.global_location.y - d, + src.global_location.z - d}, + decltype(local_dest_tree)::point_type{src.global_location.x + d, + src.global_location.y + d, + src.global_location.z + d}, + sample); } - }; - - if (selection.max_distance().has_value()) { - const double d = selection.max_distance().value(); - local_dest_tree.bounding_box_for_each( - decltype(local_dest_tree)::point_type{src.global_location.x - d, - src.global_location.y - d, - src.global_location.z - d}, - decltype(local_dest_tree)::point_type{src.global_location.x + d, - src.global_location.y + d, - src.global_location.z + d}, - sample); + else { local_dest_tree.for_each(sample); } } - else { local_dest_tree.for_each(sample); } }; - distributed_site_mapping distributed_src_sites(distributed, std::move(src_sites)); - - distributed_src_sites.for_each_site(sample_destinations); + distributed_for_each(sample_sources, distributed, src_sites.sites, src_sites.labels); return connections; } From f666941696187bbcfd707eb205f5a9cffcc12918 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 2 Jul 2023 14:38:45 +0200 Subject: [PATCH 32/84] add network connection export and multi-threading --- arbor/CMakeLists.txt | 2 +- arbor/communication/communicator.cpp | 4 +- arbor/include/arbor/network.hpp | 35 +-- arbor/include/arbor/network_generation.hpp | 13 ++ arbor/include/arbor/recipe.hpp | 1 + arbor/network.cpp | 204 ++++++++--------- arbor/network_generation.hpp | 19 -- ...etwork_generation.cpp => network_impl.cpp} | 212 +++++++++++++----- arbor/network_impl.hpp | 29 ++- python/network.cpp | 52 +++-- test/unit-distributed/test_communicator.cpp | 4 +- test/unit/test_network.cpp | 18 +- 12 files changed, 360 insertions(+), 233 deletions(-) create mode 100644 arbor/include/arbor/network_generation.hpp delete mode 100644 arbor/network_generation.hpp rename arbor/{network_generation.cpp => network_impl.cpp} (51%) diff --git a/arbor/CMakeLists.txt b/arbor/CMakeLists.txt index 807266eaec..cf64e3a6cd 100644 --- a/arbor/CMakeLists.txt +++ b/arbor/CMakeLists.txt @@ -46,7 +46,7 @@ set(arbor_sources morph/stitch.cpp merge_events.cpp network.cpp - network_generation.cpp + network_impl.cpp simulation.cpp partition_load_balance.cpp profile/clock.cpp diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp index ad2b5d26fc..03e8f803c1 100644 --- a/arbor/communication/communicator.cpp +++ b/arbor/communication/communicator.cpp @@ -13,12 +13,12 @@ #include "connection.hpp" #include "distributed_context.hpp" #include "execution_context.hpp" +#include "network_impl.hpp" #include "profile/profiler_macro.hpp" #include "threading/threading.hpp" #include "util/partition.hpp" #include "util/rangeutil.hpp" #include "util/span.hpp" -#include "network_generation.hpp" #include "communication/communicator.hpp" @@ -45,7 +45,7 @@ void communicator::update_connections(const recipe& rec, // Construct connections from high-level specification - auto generated_connections = generate_network_connections(rec, ctx_, dom_dec); + auto generated_connections = generate_connections(rec, ctx_, dom_dec); // For caching information about each cell struct gid_info { diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 18f4539a6c..29180f571a 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -24,24 +25,27 @@ namespace arb { using network_hash_type = std::uint64_t; struct ARB_SYMBOL_VISIBLE network_site_info { - network_site_info() = default; - - network_site_info(cell_gid_type gid, - cell_lid_type lid, - cell_kind kind, - std::string_view label, - mlocation location, - mpoint global_location); - cell_gid_type gid; - cell_lid_type lid; cell_kind kind; - std::string_view label; + cell_tag_type label; mlocation location; mpoint global_location; - network_hash_type hash; + + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_site_info& s); }; +ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_site_info, + (a.gid, a.kind, a.label, a.location, a.global_location), + (b.gid, a.kind, b.label, b.location, b.global_location)) + +struct ARB_SYMBOL_VISIBLE network_connection_info { + network_site_info src, dest; + + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_connection_info& s); +}; + +ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_connection_info, (a.src, a.dest), (b.src, b.dest)) + struct network_selection_impl; struct network_value_impl; @@ -52,8 +56,7 @@ class ARB_SYMBOL_VISIBLE network_selection; class ARB_SYMBOL_VISIBLE network_value { public: - using custom_func_type = - std::function; + using custom_func_type = std::function; network_value() { *this = network_value::scalar(0.0); } @@ -147,8 +150,7 @@ ARB_ARBOR_API inline network_value operator-(network_value a) { class ARB_SYMBOL_VISIBLE network_selection { public: - using custom_func_type = - std::function; + using custom_func_type = std::function; network_selection() { *this = network_selection::none(); } @@ -248,7 +250,6 @@ class ARB_SYMBOL_VISIBLE network_label_dict { nv_map values_; }; - struct network_description { network_selection selection; network_value weight; diff --git a/arbor/include/arbor/network_generation.hpp b/arbor/include/arbor/network_generation.hpp new file mode 100644 index 0000000000..7a61f6e948 --- /dev/null +++ b/arbor/include/arbor/network_generation.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include + +namespace arb { + +ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, + const context& ctx, + const domain_decomposition& dom_dec); + +} // namespace arb diff --git a/arbor/include/arbor/recipe.hpp b/arbor/include/arbor/recipe.hpp index 69cccd83ec..8921c728ef 100644 --- a/arbor/include/arbor/recipe.hpp +++ b/arbor/include/arbor/recipe.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include diff --git a/arbor/network.cpp b/arbor/network.cpp index c6a07b7582..268370689a 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -27,25 +27,8 @@ enum class network_seed : unsigned { value_uniform = 48202, value_normal = 8405, value_truncated_normal = 380237, - site_info = 984293 }; -// We only need minimal hash collisions and good spread over the hash range, because this will be -// used as input for random123, which then provides all desired hash properties. -// std::hash is implementation dependent, so we define our own for reproducibility. - -std::uint64_t simple_string_hash(const std::string_view& s) { - // use fnv1a hash algorithm - constexpr std::uint64_t prime = 1099511628211ull; - std::uint64_t h = 14695981039346656037ull; - - for (auto c: s) { - h ^= c; - h *= prime; - } - - return h; -} double uniform_rand_from_key_pair(std::array seed, network_hash_type key_a, @@ -71,8 +54,8 @@ double normal_rand_from_key_pair(std::array seed, } struct network_selection_all_impl: public network_selection_impl { - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return true; } @@ -93,8 +76,8 @@ struct network_selection_all_impl: public network_selection_impl { struct network_selection_none_impl: public network_selection_impl { - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return false; } @@ -118,8 +101,8 @@ struct network_selection_source_cell_kind_impl: public network_selection_impl { explicit network_selection_source_cell_kind_impl(cell_kind k): select_kind(k) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return src.kind == select_kind; } @@ -152,8 +135,8 @@ struct network_selection_destination_cell_kind_impl: public network_selection_im explicit network_selection_destination_cell_kind_impl(cell_kind k): select_kind(k) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return dest.kind == select_kind; } @@ -189,8 +172,8 @@ struct network_selection_source_label_impl: public network_selection_impl { std::sort(sorted_labels.begin(), sorted_labels.end()); } - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), src.label); } @@ -221,8 +204,8 @@ struct network_selection_destination_label_impl: public network_selection_impl { std::sort(sorted_labels.begin(), sorted_labels.end()); } - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), dest.label); } @@ -253,8 +236,8 @@ struct network_selection_source_cell_impl: public network_selection_impl { std::sort(sorted_gids.begin(), sorted_gids.end()); } - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), src.gid); } @@ -285,8 +268,8 @@ struct network_selection_source_cell_range_impl: public network_selection_impl { gid_end(r.end), step(r.step) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return src.gid >= gid_begin && src.gid < gid_end && !((src.gid - gid_begin) % step); } @@ -315,8 +298,8 @@ struct network_selection_destination_cell_impl: public network_selection_impl { std::sort(sorted_gids.begin(), sorted_gids.end()); } - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), dest.gid); } @@ -347,8 +330,8 @@ struct network_selection_destination_cell_range_impl: public network_selection_i gid_end(r.end), step(r.step) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return dest.gid >= gid_begin && dest.gid < gid_end && !((dest.gid - gid_begin) % step); } @@ -377,8 +360,8 @@ struct network_selection_chain_impl: public network_selection_impl { std::sort(sorted_gids.begin(), sorted_gids.end()); } - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { if (gids.empty()) return false; // gids size always > 0 frome here on @@ -425,8 +408,8 @@ struct network_selection_chain_range_impl: public network_selection_impl { gid_end(r.end), step(r.step) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { if (src.gid < gid_begin || src.gid >= gid_end || dest.gid < gid_begin || dest.gid >= gid_end) return false; @@ -463,8 +446,8 @@ struct network_selection_reverse_chain_range_impl: public network_selection_impl gid_end(r.end), step(r.step) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { if (src.gid < gid_begin || src.gid >= gid_end || dest.gid < gid_begin || dest.gid >= gid_end) return false; @@ -499,8 +482,8 @@ struct network_selection_complement_impl: public network_selection_impl { explicit network_selection_complement_impl(std::shared_ptr s): selection(std::move(s)) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return !selection->select_connection(src, dest); } @@ -535,8 +518,8 @@ struct network_selection_named_impl: public network_selection_impl { explicit network_selection_named_impl(std::string name): selection_name(std::move(name)) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); return selection->select_connection(src, dest); @@ -572,8 +555,8 @@ struct network_selection_named_impl: public network_selection_impl { }; struct network_selection_inter_cell_impl: public network_selection_impl { - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return src.gid != dest.gid; } @@ -598,9 +581,14 @@ struct network_selection_custom_impl: public network_selection_impl { explicit network_selection_custom_impl(network_selection::custom_func_type f): func(std::move(f)) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { - return func(src, dest); + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { + return func({{src.gid, + src.kind, + cell_tag_type(src.label), + src.location, + src.global_location}, + {dest.gid, dest.kind, cell_tag_type(dest.label), dest.location, dest.global_location}}); } bool select_source(cell_kind kind, @@ -623,8 +611,8 @@ struct network_selection_distance_lt_impl: public network_selection_impl { explicit network_selection_distance_lt_impl(double d): d(d) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return distance(src.global_location, dest.global_location) < d; } @@ -650,8 +638,8 @@ struct network_selection_distance_gt_impl: public network_selection_impl { explicit network_selection_distance_gt_impl(double d): d(d) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return distance(src.global_location, dest.global_location) > d; } @@ -678,8 +666,8 @@ struct network_selection_random_impl: public network_selection_impl { network_selection_random_impl(unsigned seed, network_value p): seed(seed), p_value(std::move(p)) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { if (!probability) throw arbor_internal_error("Trying to use unitialized named network selection."); const auto r = uniform_rand_from_key_pair( @@ -719,8 +707,8 @@ struct network_selection_intersect_impl: public network_selection_impl { left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return left->select_connection(src, dest) && right->select_connection(src, dest); } @@ -770,8 +758,8 @@ struct network_selection_join_impl: public network_selection_impl { left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return left->select_connection(src, dest) || right->select_connection(src, dest); } @@ -819,8 +807,8 @@ struct network_selection_symmetric_difference_impl: public network_selection_imp left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return left->select_connection(src, dest) ^ right->select_connection(src, dest); } @@ -868,8 +856,8 @@ struct network_selection_difference_impl: public network_selection_impl { left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_site_info& src, - const network_site_info& dest) const override { + bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const override { return left->select_connection(src, dest) && !(right->select_connection(src, dest)); } @@ -912,7 +900,7 @@ struct network_value_scalar_impl: public network_value_impl { network_value_scalar_impl(double v): value(v) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { return value; } @@ -925,7 +913,7 @@ struct network_value_distance_impl: public network_value_impl { network_value_distance_impl(double s): scale(s) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { return scale * distance(src.global_location, dest.global_location); } @@ -943,7 +931,7 @@ struct network_value_uniform_distribution_impl: public network_value_impl { throw std::invalid_argument("Uniform distribution: invalid range"); } - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { if (range[0] > range[1]) return range[1]; // random number between 0 and 1 @@ -968,7 +956,7 @@ struct network_value_normal_distribution_impl: public network_value_impl { mean(mean_), std_deviation(std_deviation_) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { return mean + std_deviation * normal_rand_from_key_pair( {unsigned(network_seed::value_normal), seed}, src.hash, dest.hash); @@ -997,7 +985,7 @@ struct network_value_truncated_normal_distribution_impl: public network_value_im throw std::invalid_argument("Truncated normal distribution: invalid range"); } - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { const auto src_hash = src.hash; auto dest_hash = dest.hash; @@ -1027,8 +1015,13 @@ struct network_value_custom_impl: public network_value_impl { network_value_custom_impl(network_value::custom_func_type f): func(std::move(f)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { - return func(src, dest); + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { + return func({{src.gid, + src.kind, + cell_tag_type(src.label), + src.location, + src.global_location}, + {dest.gid, dest.kind, cell_tag_type(dest.label), dest.location, dest.global_location}}); } void print(std::ostream& os) const override { os << "(custom-network-value)"; } @@ -1042,7 +1035,7 @@ struct network_value_named_impl: public network_value_impl { explicit network_value_named_impl(std::string name): value_name(std::move(name)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { if (!value) throw arbor_internal_error("Trying to use unitialized named network value."); return value->get(src, dest); } @@ -1068,7 +1061,7 @@ struct network_value_add_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { return left->get(src, dest) + right->get(src, dest); } @@ -1095,7 +1088,7 @@ struct network_value_mul_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { return left->get(src, dest) * right->get(src, dest); } @@ -1122,7 +1115,7 @@ struct network_value_sub_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { return left->get(src, dest) - right->get(src, dest); } @@ -1148,7 +1141,7 @@ struct network_value_div_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { const auto v_right = right ->get(src,dest); if (!v_right) throw arbor_exception("network_value: division by 0."); return left->get(src, dest) / right->get(src, dest); @@ -1176,7 +1169,7 @@ struct network_value_max_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { return std::max(left->get(src, dest), right->get(src, dest)); } @@ -1202,7 +1195,7 @@ struct network_value_min_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { return std::min(left->get(src, dest), right->get(src, dest)); } @@ -1226,7 +1219,7 @@ struct network_value_exp_impl: public network_value_impl { network_value_exp_impl(std::shared_ptr v): value(std::move(v)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { return std::exp(value->get(src, dest)); } @@ -1247,7 +1240,7 @@ struct network_value_log_impl: public network_value_impl { network_value_log_impl(std::shared_ptr v): value(std::move(v)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { const auto v = value->get(src, dest); if (v <= 0.0) throw arbor_exception("network_value: log of value <= 0.0."); return std::log(value->get(src, dest)); @@ -1276,7 +1269,7 @@ struct network_value_if_else_impl: public network_value_impl { true_value(std::move(true_value)), false_value(std::move(false_value)) {} - double get(const network_site_info& src, const network_site_info& dest) const override { + double get(const network_full_site_info& src, const network_full_site_info& dest) const override { if (cond->select_connection(src, dest)) return true_value->get(src, dest); return false_value->get(src, dest); } @@ -1300,32 +1293,6 @@ struct network_value_if_else_impl: public network_value_impl { } // namespace -network_site_info::network_site_info(cell_gid_type gid, - cell_lid_type lid, - cell_kind kind, - std::string_view label, - mlocation location, - mpoint global_location): - gid(gid), - lid(lid), - kind(kind), - label(std::move(label)), - location(location), - global_location(global_location) { - - std::uint64_t label_hash = simple_string_hash(this->label); - static_assert(sizeof(decltype(mlocation::pos)) == sizeof(std::uint64_t)); - std::uint64_t loc_pos_hash = *reinterpret_cast(&location.pos); - - const auto seed = static_cast(network_seed::site_info); - - using rand_type = r123::Threefry4x64; - const rand_type::ctr_type seed_input = {{seed, 2 * seed, 3 * seed, 4 * seed}}; - const rand_type::key_type key = {{gid, label_hash, location.branch, loc_pos_hash}}; - - rand_type gen; - hash = gen(seed_input, key)[0]; -} network_selection::network_selection(std::shared_ptr impl): impl_(std::move(impl)) {} @@ -1552,6 +1519,25 @@ ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_value& v) return os; } +ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_site_info& s) { + + os << ""; + return os; +} + +ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_connection_info& s) { + + os << ""; + return os; +} + ARB_ARBOR_API network_selection join(network_selection left, network_selection right) { return network_selection::join(std::move(left), std::move(right)); } diff --git a/arbor/network_generation.hpp b/arbor/network_generation.hpp deleted file mode 100644 index bff19d4ccb..0000000000 --- a/arbor/network_generation.hpp +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include - -#include -#include -#include - -#include "connection.hpp" -#include "distributed_context.hpp" -#include "label_resolution.hpp" - -namespace arb { - -std::vector generate_network_connections(const recipe& rec, - const context& ctx, - const domain_decomposition& dom_dec); - -} // namespace arb diff --git a/arbor/network_generation.cpp b/arbor/network_impl.cpp similarity index 51% rename from arbor/network_generation.cpp rename to arbor/network_impl.cpp index c8ae1842f7..88b56687ca 100644 --- a/arbor/network_generation.cpp +++ b/arbor/network_impl.cpp @@ -1,7 +1,9 @@ -#include "network_generation.hpp" +#include "network_impl.hpp" #include "cell_group_factory.hpp" #include "communication/distributed_for_each.hpp" +#include "label_resolution.hpp" #include "network_impl.hpp" +#include "threading/threading.hpp" #include "util/range.hpp" #include "util/spatial_tree.hpp" @@ -10,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -18,16 +21,35 @@ #include #include #include +#include #include #include #include #include #include #include +#include namespace arb { namespace { +// We only need minimal hash collisions and good spread over the hash range, because this will be +// used as input for random123, which then provides all desired hash properties. +// std::hash is implementation dependent, so we define our own for reproducibility. + +std::uint64_t simple_string_hash(const std::string_view& s) { + // use fnv1a hash algorithm + constexpr std::uint64_t prime = 1099511628211ull; + std::uint64_t h = 14695981039346656037ull; + + for (auto c: s) { + h ^= c; + h *= prime; + } + + return h; +} + struct distributed_site_info { cell_gid_type gid = 0; cell_lid_type lid = 0; @@ -38,7 +60,6 @@ struct distributed_site_info { network_hash_type hash = 0; }; - struct site_mapping { std::vector sites; std::vector labels; @@ -48,7 +69,7 @@ struct site_mapping { inline std::size_t size() const { return sites.size(); } - void insert(const network_site_info& s) { + void insert(const network_full_site_info& s) { const auto insert_pair = label_map.insert({std::string(s.label), labels.size()}); // append label if not contained in labels if (insert_pair.second) { @@ -64,10 +85,10 @@ struct site_mapping { s.hash}); } - network_site_info get_site(std::size_t idx) const { + network_full_site_info get_site(std::size_t idx) const { const auto& s = this->sites.at(idx); - network_site_info info; + network_full_site_info info; info.gid = s.gid; info.lid = s.lid; info.kind = s.kind; @@ -80,9 +101,28 @@ struct site_mapping { } }; -} // namespace +void push_back(std::vector& vec, + const network_full_site_info& src, + const network_full_site_info& dest, + double weight, + double delay) { + vec.emplace_back(connection({src.gid, src.lid}, {dest.gid, dest.lid}, weight, delay)); +} -std::vector generate_network_connections(const recipe& rec, +void push_back(std::vector& vec, + const network_full_site_info& src, + const network_full_site_info& dest, + double weight, + double delay) { + vec.emplace_back(network_connection_info{ + network_site_info{ + src.gid, src.kind, std::string(src.label), src.location, src.global_location}, + network_site_info{ + dest.gid, dest.kind, std::string(dest.label), dest.location, dest.global_location}}); +} + +template +std::vector generate_network_connections(const recipe& rec, const context& ctx, const domain_decomposition& dom_dec) { const auto description_opt = rec.network_description(); @@ -93,6 +133,7 @@ std::vector generate_network_connections(const recipe& rec, const auto& description = description_opt.value(); site_mapping src_sites, dest_sites; + std::mutex src_sites_mutex, dest_sites_mutex; const auto selection_ptr = thingify(description.selection, description.dict); const auto weight_ptr = thingify(description.weight, description.dict); @@ -107,51 +148,54 @@ std::vector generate_network_connections(const recipe& rec, switch (group.kind) { case cell_kind::cable: { // We need access to morphology, so the cell is create directly - cable_cell cell; - for (const auto& gid: group.gids) { - try { - cell = util::any_cast(rec.get_cell_description(gid)); - } - catch (std::bad_any_cast&) { - throw bad_cell_description(rec.get_cell_kind(gid), gid); - } - - auto lid_to_label = [](const std::unordered_multimap& map, - cell_lid_type lid) -> const cell_tag_type& { - for (const auto& [label, range]: map) { - if (lid >= range.begin && lid < range.end) return label; + threading::parallel_for::apply( + 0, group.gids.size(), ctx->thread_pool.get(), [&](int i) { + const auto gid = group.gids[i]; + cable_cell cell; + try { + cell = util::any_cast(rec.get_cell_description(gid)); + } + catch (std::bad_any_cast&) { + throw bad_cell_description(rec.get_cell_kind(gid), gid); } - throw arbor_internal_error("unkown lid"); - }; - place_pwlin location_resolver(cell.morphology(), rec.get_cell_isometry(gid)); + auto lid_to_label = + [](const std::unordered_multimap& map, + cell_lid_type lid) -> const cell_tag_type& { + for (const auto& [label, range]: map) { + if (lid >= range.begin && lid < range.end) return label; + } + throw arbor_internal_error("unkown lid"); + }; + + place_pwlin location_resolver(cell.morphology(), rec.get_cell_isometry(gid)); - // check all synapses of cell for potential destination + // check all synapses of cell for potential destination - for (const auto& [_, placed_synapses]: cell.synapses()) { - for (const auto& p_syn: placed_synapses) { - // TODO check if tag correct - const auto& label = lid_to_label(cell.synapse_ranges(), p_syn.lid); + for (const auto& [_, placed_synapses]: cell.synapses()) { + for (const auto& p_syn: placed_synapses) { + const auto& label = lid_to_label(cell.synapse_ranges(), p_syn.lid); - if (selection.select_destination(cell_kind::cable, gid, label)) { - const mpoint point = location_resolver.at(p_syn.loc); - dest_sites.insert( - {gid, p_syn.lid, cell_kind::cable, label, p_syn.loc, point}); + if (selection.select_destination(cell_kind::cable, gid, label)) { + const mpoint point = location_resolver.at(p_syn.loc); + std::lock_guard guard(dest_sites_mutex); + dest_sites.insert( + {gid, p_syn.lid, cell_kind::cable, label, p_syn.loc, point}); + } } } - } - // check all detectors of cell for potential source - for (const auto& p_det: cell.detectors()) { - // TODO check if tag correct - const auto& label = lid_to_label(cell.detector_ranges(), p_det.lid); - if (selection.select_source(cell_kind::cable, gid, label)) { - const mpoint point = location_resolver.at(p_det.loc); - src_sites.insert( - {gid, p_det.lid, cell_kind::cable, label, p_det.loc, point}); + // check all detectors of cell for potential source + for (const auto& p_det: cell.detectors()) { + const auto& label = lid_to_label(cell.detector_ranges(), p_det.lid); + if (selection.select_source(cell_kind::cable, gid, label)) { + const mpoint point = location_resolver.at(p_det.loc); + std::lock_guard guard(src_sites_mutex); + src_sites.insert( + {gid, p_det.lid, cell_kind::cable, label, p_det.loc, point}); + } } - } - } + }); } break; default: { // Assuming all other cell types do not have a morphology. We can use label resolution @@ -179,6 +223,7 @@ std::vector generate_network_connections(const recipe& rec, const auto& range = sources.ranges().at(j); for (auto lid = range.begin; lid < range.end; ++lid) { if (selection.select_source(group.kind, gid, label)) { + std::lock_guard guard(src_sites_mutex); src_sites.insert({gid, lid, group.kind, label, {0, 0.0}, point}); } } @@ -192,6 +237,7 @@ std::vector generate_network_connections(const recipe& rec, const auto& range = destinations.ranges().at(j); for (auto lid = range.begin; lid < range.end; ++lid) { if (selection.select_destination(group.kind, gid, label)) { + std::lock_guard guard(dest_sites_mutex); dest_sites.insert({gid, lid, group.kind, label, {0, 0.0}, point}); } } @@ -206,27 +252,30 @@ std::vector generate_network_connections(const recipe& rec, } // create octree - std::vector network_dest_sites; + std::vector network_dest_sites; network_dest_sites.reserve(dest_sites.size()); - for(std::size_t i = 0; i < dest_sites.size(); ++i) { + for (std::size_t i = 0; i < dest_sites.size(); ++i) { network_dest_sites.emplace_back(dest_sites.get_site(i)); } const std::size_t max_depth = selection.max_distance().has_value() ? 10 : 1; const std::size_t max_leaf_size = 100; - spatial_tree local_dest_tree(max_depth, + spatial_tree local_dest_tree(max_depth, max_leaf_size, std::move(network_dest_sites), - [](const network_site_info& info) -> spatial_tree::point_type { + [](const network_full_site_info& info) + -> spatial_tree::point_type { return {info.global_location.x, info.global_location.y, info.global_location.z}; }); // select connections - std::vector connections; + std::vector connections; + std::mutex connections_mutex; auto sample_sources = [&](const util::range& source_range, const util::range& label_range) { - for (const auto& s: source_range) { - network_site_info src; + threading::parallel_for::apply(0, source_range.size(), ctx->thread_pool.get(), [&](int i) { + const auto& s = source_range[i]; + network_full_site_info src; src.gid = s.gid; src.lid = s.lid; src.kind = s.kind; @@ -235,12 +284,13 @@ std::vector generate_network_connections(const recipe& rec, src.global_location = s.global_location; src.hash = s.hash; - auto sample = [&](const network_site_info& dest) { + auto sample = [&](const network_full_site_info& dest) { if (selection.select_connection(src, dest)) { - connections.emplace_back(connection({src.gid, src.lid}, - {dest.gid, dest.lid}, - weight.get(src, dest), - delay.get(src, dest))); + const auto w = weight.get(src, dest); + const auto d = delay.get(src, dest); + + std::lock_guard guard(connections_mutex); + push_back(connections, src, dest, w, d); } }; @@ -256,7 +306,7 @@ std::vector generate_network_connections(const recipe& rec, sample); } else { local_dest_tree.for_each(sample); } - } + }); }; distributed_for_each(sample_sources, distributed, src_sites.sites, src_sites.labels); @@ -264,4 +314,54 @@ std::vector generate_network_connections(const recipe& rec, return connections; } +} // namespace + + + +network_full_site_info::network_full_site_info(cell_gid_type gid, + cell_lid_type lid, + cell_kind kind, + std::string_view label, + mlocation location, + mpoint global_location): + gid(gid), + lid(lid), + kind(kind), + label(std::move(label)), + location(location), + global_location(global_location) { + + std::uint64_t label_hash = simple_string_hash(this->label); + static_assert(sizeof(decltype(mlocation::pos)) == sizeof(std::uint64_t)); + std::uint64_t loc_pos_hash = *reinterpret_cast(&location.pos); + + // Initial seed. Changes will affect reproducibility of generated network connections. + constexpr std::uint64_t seed = 984293; + + using rand_type = r123::Threefry4x64; + const rand_type::ctr_type seed_input = {{seed, 2 * seed, 3 * seed, 4 * seed}}; + const rand_type::key_type key = {{gid, label_hash, location.branch, loc_pos_hash}}; + + rand_type gen; + hash = gen(seed_input, key)[0]; +} + +std::vector generate_connections(const recipe& rec, + const context& ctx, + const domain_decomposition& dom_dec) { + return generate_network_connections(rec, ctx, dom_dec); +} + +ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, + const context& ctx, + const domain_decomposition& dom_dec) { + auto connections = generate_network_connections(rec, ctx, dom_dec); + + // generated connections may have different orer each time due to multi-threading. + // sort before returning to user for reproducibility. + std::sort(connections.begin(), connections.end()); + + return connections; +} + } // namespace arb diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp index 6a4420368f..31480cf15b 100644 --- a/arbor/network_impl.hpp +++ b/arbor/network_impl.hpp @@ -21,11 +21,30 @@ namespace arb { +struct network_full_site_info { + network_full_site_info() = default; + + network_full_site_info(cell_gid_type gid, + cell_lid_type lid, + cell_kind kind, + std::string_view label, + mlocation location, + mpoint global_location); + + cell_gid_type gid; + cell_lid_type lid; + cell_kind kind; + std::string_view label; + mlocation location; + mpoint global_location; + network_hash_type hash; +}; + struct network_selection_impl { virtual std::optional max_distance() const { return std::nullopt; } - virtual bool select_connection(const network_site_info& src, - const network_site_info& dest) const = 0; + virtual bool select_connection(const network_full_site_info& src, + const network_full_site_info& dest) const = 0; virtual bool select_source(cell_kind kind, cell_gid_type gid, @@ -50,7 +69,7 @@ inline std::shared_ptr thingify(network_selection s, struct network_value_impl { - virtual double get(const network_site_info& src, const network_site_info& dest) const = 0; + virtual double get(const network_full_site_info& src, const network_full_site_info& dest) const = 0; virtual void initialize(const network_label_dict& dict) {}; @@ -65,4 +84,8 @@ inline std::shared_ptr thingify(network_value v, return v.impl_; } +std::vector generate_connections(const recipe& rec, + const context& ctx, + const domain_decomposition& dom_dec); + } // namespace arb diff --git a/python/network.cpp b/python/network.cpp index c0184cca1f..86f32155f2 100644 --- a/python/network.cpp +++ b/python/network.cpp @@ -4,7 +4,9 @@ #include #include +#include #include +#include #include #include #include @@ -14,9 +16,11 @@ #include #include +#include "context.hpp" #include "error.hpp" -#include "util.hpp" +#include "recipe.hpp" #include "strprintf.hpp" +#include "util.hpp" namespace py = pybind11; @@ -28,32 +32,32 @@ void register_network(py::module& m) { py::class_ network_site_info( m, "network_site_info", "Identifies a network site to connect to / from"); network_site_info.def_readwrite("gid", &arb::network_site_info::gid) - .def_readwrite("lid", &arb::network_site_info::lid) .def_readwrite("kind", &arb::network_site_info::kind) .def_readwrite("label", &arb::network_site_info::label) .def_readwrite("location", &arb::network_site_info::location) .def_readwrite("global_location", &arb::network_site_info::global_location) - .def("__repr__", [](const arb::network_site_info& s) { - return util::pprintf("", - s.lid, - s.kind, - s.label, - s.location, - s.global_location); - }); + .def("__repr__", [](const arb::network_site_info& s) { return util::pprintf("{}", s); }) + .def("__str__", [](const arb::network_site_info& s) { return util::pprintf("{}", s); }); + + py::class_ network_connection_info( + m, "network_connection_info", "Identifies a network connection"); + network_connection_info.def_readwrite("src", &arb::network_connection_info::src) + .def_readwrite("dest", &arb::network_connection_info::dest) + .def("__repr__", [](const arb::network_connection_info& c) { return util::pprintf("{}", c); }) + .def("__str__", [](const arb::network_connection_info& c) { return util::pprintf("{}", c); }); py::class_ network_selection( m, "network_selection", "Network selection."); + network_selection .def_static("custom", [](arb::network_selection::custom_func_type func) { return arb::network_selection::custom( - [=](const arb::network_site_info& src, const arb::network_site_info& dest) { + [=](const arb::network_connection_info& c) { return try_catch_pyexception( [&]() { pybind11::gil_scoped_acquire guard; - return func(src, dest); + return func(c); }, "Python error already thrown"); }); @@ -69,11 +73,11 @@ void register_network(py::module& m) { .def_static("custom", [](arb::network_value::custom_func_type func) { return arb::network_value::custom( - [=](const arb::network_site_info& src, const arb::network_site_info& dest) { + [=](const arb::network_connection_info& c) { return try_catch_pyexception( [&]() { pybind11::gil_scoped_acquire guard; - return func(src, dest); + return func(c); }, "Python error already thrown"); }); @@ -134,6 +138,24 @@ void register_network(py::module& m) { "delay"_a, "dict"_a, "Construct network description."); + + m.def( + "generate_network_connections", + [](const std::shared_ptr& rec, + std::shared_ptr ctx, + std::optional decomp) { + py_recipe_shim rec_shim(rec); + + if (!ctx) ctx = std::make_shared(arb::make_context()); + if (!decomp) decomp = arb::partition_load_balance(rec_shim, ctx->context); + + return generate_network_connections(rec_shim, ctx->context, decomp.value()); + }, + "recipe"_a, + pybind11::arg_v("context", pybind11::none(), "Execution context"), + pybind11::arg_v("decomp", pybind11::none(), "Domain decomposition"), + "Generate network connections from the network description in the recipe. Will only " + "generate connections with local gids in the domain composition as destination."); } } // namespace pyarb diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp index 4149ea8a50..b182da9af1 100644 --- a/test/unit-distributed/test_communicator.cpp +++ b/test/unit-distributed/test_communicator.cpp @@ -647,7 +647,7 @@ TEST(communicator, all2all) auto c = connections[i*n_local+j]; EXPECT_EQ(i, c.source.gid); EXPECT_EQ(0u, c.source.index); - EXPECT_EQ(i, c.destination); + EXPECT_EQ(i, c.destination.index); } } @@ -704,7 +704,7 @@ TEST(communicator, mini_network) auto c = connections[i*22 + j]; EXPECT_EQ(ex_source_gids[j], c.source.gid); EXPECT_EQ(ex_source_lids[j], c.source.index); - EXPECT_EQ(ex_target_lids[i%2][j], c.destination); + EXPECT_EQ(ex_target_lids[i%2][j], c.destination.index); } } } diff --git a/test/unit/test_network.cpp b/test/unit/test_network.cpp index 4485852981..039fa01505 100644 --- a/test/unit/test_network.cpp +++ b/test/unit/test_network.cpp @@ -10,7 +10,7 @@ using namespace arb; namespace { -std::vector test_sites = { +std::vector test_sites = { {0, 0, cell_kind::cable, "a", {1, 0.5}, {0.0, 0.0, 0.0}}, {1, 0, cell_kind::benchmark, "b", {0, 0.0}, {1.0, 0.0, 0.0}}, {2, 0, cell_kind::lif, "c", {0, 0.0}, {2.0, 0.0, 0.0}}, @@ -433,7 +433,7 @@ TEST(network_selection, random_seed) { TEST(network_selection, random_reproducibility) { const auto s = thingify(network_selection::random(42, 0.5), network_label_dict()); - std::vector sites = { + std::vector sites = { {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, @@ -450,8 +450,8 @@ TEST(network_selection, random_reproducibility) { } TEST(network_selection, custom) { - auto inter_cell_func = [](const network_site_info& src, const network_site_info& dest) { - return src.gid != dest.gid; + auto inter_cell_func = [](const network_connection_info& c) { + return c.src.gid != c.dest.gid; }; const auto s = thingify(network_selection::custom(inter_cell_func), network_label_dict()); const auto s_ref = thingify(network_selection::inter_cell(), network_label_dict()); @@ -557,7 +557,7 @@ TEST(network_value, uniform_distribution) { TEST(network_value, uniform_distribution_reproducibility) { const auto v = thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); - std::vector sites = { + std::vector sites = { {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, @@ -610,7 +610,7 @@ TEST(network_value, normal_distribution_reproducibility) { const double std_dev = 3.0; const auto v = thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); - std::vector sites = { + std::vector sites = { {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, @@ -674,7 +674,7 @@ TEST(network_value, truncated_normal_distribution_reproducibility) { network_value::truncated_normal_distribution(42, mean, std_dev, {lower_bound, upper_bound}), network_label_dict()); - std::vector sites = { + std::vector sites = { {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, @@ -701,8 +701,8 @@ TEST(network_value, truncated_normal_distribution_reproducibility) { } TEST(network_value, custom) { - auto func = [](const network_site_info& src, const network_site_info& dest) { - return src.global_location.x + dest.global_location.x; + auto func = [](const network_connection_info& c) { + return c.src.global_location.x + c.dest.global_location.x; }; const auto v = thingify(network_value::custom(func), network_label_dict()); From 15516b0596c23d3416bc7ea98057ae3d38a777c4 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 24 Jul 2023 18:36:42 +0200 Subject: [PATCH 33/84] add tree and distributed for each tests --- arbor/communication/distributed_for_each.hpp | 30 +++- arbor/network_impl.cpp | 7 +- test/unit-distributed/CMakeLists.txt | 1 + .../test_distributed_for_each.cpp | 92 +++++++++++ test/unit/CMakeLists.txt | 1 + test/unit/test_spatial_tree.cpp | 155 ++++++++++++++++++ 6 files changed, 277 insertions(+), 9 deletions(-) create mode 100644 test/unit-distributed/test_distributed_for_each.cpp create mode 100644 test/unit/test_spatial_tree.cpp diff --git a/arbor/communication/distributed_for_each.hpp b/arbor/communication/distributed_for_each.hpp index 2db82c9453..a097c2605e 100644 --- a/arbor/communication/distributed_for_each.hpp +++ b/arbor/communication/distributed_for_each.hpp @@ -1,11 +1,12 @@ #pragma once -#include -#include +#include #include -#include +#include #include #include +#include +#include #include #include "distributed_context.hpp" @@ -42,7 +43,7 @@ void for_each_in_tuple_pair(FUNC&& func, std::tuple& t1, std::tuple void distributed_for_each(FUNC&& func, const distributed_context& distributed, - const std::vector&... args) { + const util::range&... args) { static_assert(sizeof...(args) > 0); auto arg_tuple = std::forward_as_tuple(args...); @@ -83,8 +84,22 @@ void distributed_for_each(FUNC&& func, // compute maximum buffer size between ranks, such that we only allocate once const std::size_t max_buffer_size = distributed.max(buffer_size); - // exit if all vectors on all ranks are empty - if (max_buffer_size == info.size() * sizeof(vec_info)) return; + std::tuple::value_type*>...> + ranges; + + if (max_buffer_size == info.size() * sizeof(vec_info)) { + // if all empty, call function with empty ranges for each step and exit + impl::for_each_in_tuple_pair( + [&](std::size_t i, auto&& vec, auto&& r) { + using T = typename std::remove_reference_t::value_type; + r = util::range(nullptr, nullptr); + }, + arg_tuple, + ranges); + + for (std::size_t step = 0; step < distributed.size(); ++step) { std::apply(func, ranges); } + return; + } // use malloc for std::max_align_t alignment auto deleter = [](char* ptr) { std::free(ptr); }; @@ -99,11 +114,10 @@ void distributed_for_each(FUNC&& func, impl::for_each_in_tuple( [&](std::size_t i, auto&& vec) { using T = typename std::remove_reference_t::value_type; - std::memcpy(buffer.get() + info[i].offset, vec.data(), vec.size() * sizeof(T)); + std::copy(vec.begin(), vec.end(), (T*)(buffer.get() + info[i].offset)); }, arg_tuple); - std::tuple...> ranges; const auto my_rank = distributed.id(); const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1; diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 88b56687ca..6fc8328fbe 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -309,7 +309,12 @@ std::vector generate_network_connections(const recipe& rec, }); }; - distributed_for_each(sample_sources, distributed, src_sites.sites, src_sites.labels); + distributed_for_each(sample_sources, + distributed, + util::make_range(src_sites.sites.begin(), src_sites.sites.end()), + util::make_range(src_sites.labels.begin(), src_sites.labels.end())); + + // distributed_for_each(sample_sources, distributed, src_sites.sites, src_sites.labels); return connections; } diff --git a/test/unit-distributed/CMakeLists.txt b/test/unit-distributed/CMakeLists.txt index 545d18e021..8d9b0d8bb1 100644 --- a/test/unit-distributed/CMakeLists.txt +++ b/test/unit-distributed/CMakeLists.txt @@ -3,6 +3,7 @@ set(unit-distributed_sources test_domain_decomposition.cpp test_communicator.cpp test_mpi.cpp + test_distributed_for_each.cpp # unit test driver test.cpp diff --git a/test/unit-distributed/test_distributed_for_each.cpp b/test/unit-distributed/test_distributed_for_each.cpp new file mode 100644 index 0000000000..80a38274d6 --- /dev/null +++ b/test/unit-distributed/test_distributed_for_each.cpp @@ -0,0 +1,92 @@ +#include +#include "test.hpp" + +#include +#include +#include + +#include "communication/distributed_for_each.hpp" +#include "execution_context.hpp" +#include "util/range.hpp" + +using namespace arb; + +// check when all input is size 0 +TEST(distributed_for_each, all_zero) { + std::vector data; + + const int num_ranks = g_context->distributed->size(); + int call_count = 0; + + auto sample = [&](const util::range& range) { + EXPECT_EQ(0, range.size()); + ++call_count; + }; + + distributed_for_each( + sample, *g_context->distributed, util::make_range(data.begin(), data.end())); + + EXPECT_EQ(num_ranks, call_count); +} + +// check when input on one rank is size 0 +TEST(distributed_for_each, one_zero) { + const auto rank = g_context->distributed->id(); + const int num_ranks = g_context->distributed->size(); + int call_count = 0; + + // test data size is equal to rank id and vector is filled with rank id + std::vector data; + for (int i = 0; i < rank; ++i) { data.push_back(rank); } + + auto sample = [&](const util::range& range) { + const auto origin_rank = range.empty() ? 0 : range.front(); + + EXPECT_EQ(origin_rank, range.size()); + for (const auto& value: range) { EXPECT_EQ(value, origin_rank); } + ++call_count; + }; + + distributed_for_each( + sample, *g_context->distributed, util::make_range(data.begin(), data.end())); + + EXPECT_EQ(num_ranks, call_count); +} + +// check multiple types +TEST(distributed_for_each, multiple) { + const auto rank = g_context->distributed->id(); + const int num_ranks = g_context->distributed->size(); + int call_count = 0; + + std::vector data_1; + std::vector data_2; + std::vector> data_3; + // test data size is equal to rank id + 1and vector is filled with rank id + for (int i = 0; i < rank + 1; ++i) { data_1.push_back(rank); } + // test different data sizes for each type + for (std::size_t i = 0; i < 2 * data_1.size(); ++i) { data_2.push_back(rank); } + for (std::size_t i = 0; i < 3 * data_1.size(); ++i) { data_3.push_back(rank); } + + auto sample = [&](const util::range& range_1, + const util::range& range_2, + const util::range*>& range_3) { + const auto origin_rank = range_1.empty() ? 0 : range_1.front(); + + EXPECT_EQ(origin_rank + 1, range_1.size()); + EXPECT_EQ(range_2.size(), 2 * range_1.size()); + EXPECT_EQ(range_3.size(), 3 * range_1.size()); + for (const auto& value: range_1) { EXPECT_EQ(value, origin_rank); } + for (const auto& value: range_2) { EXPECT_EQ(value, double(origin_rank)); } + for (const auto& value: range_3) { EXPECT_EQ(value, std::complex(origin_rank)); } + ++call_count; + }; + + distributed_for_each(sample, + *g_context->distributed, + util::make_range(data_1.begin(), data_1.end()), + util::make_range(data_2.begin(), data_2.end()), + util::make_range(data_3.begin(), data_3.end())); + + EXPECT_EQ(num_ranks, call_count); +} diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index bea0c4c30f..e84a325c97 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -128,6 +128,7 @@ set(unit_sources test_simd.cpp test_simulation.cpp test_span.cpp + test_spatial_tree.cpp test_spike_source.cpp test_spikes.cpp test_spike_store.cpp diff --git a/test/unit/test_spatial_tree.cpp b/test/unit/test_spatial_tree.cpp new file mode 100644 index 0000000000..f32d9a04df --- /dev/null +++ b/test/unit/test_spatial_tree.cpp @@ -0,0 +1,155 @@ +#include + +#include + +#include "util/spatial_tree.hpp" + +#include +#include +#include +#include +#include +#include + +using namespace arb; + +namespace { + +template +struct data_point { + int id = 0; + std::array point; + + bool operator<(const data_point& p) const { + return id < p.id || (id == p.id && point < p.point); + } +}; + +template +struct bounding_box_data { + bounding_box_data(std::size_t seed, + std::size_t num_points, + std::array box_min, + std::array box_max): + box_min(box_min), + box_max(box_max) { + + std::minstd_rand rand_gen(seed); + + data.reserve(num_points); + for (std::size_t i = 0; i < num_points; ++i) { + data_point p; + p.id = i; + for (std::size_t d = 0; d < DIM; ++d) { + + std::uniform_real_distribution distri(box_min[d], box_max[d]); + p.point[d] = distri(rand_gen); + } + data.emplace_back(p); + } + } + + std::array box_min; + std::array box_max; + std::vector> data; +}; + +class st_test: + public ::testing::TestWithParam< + std::tuple> { +public: + void test_spatial_tree() { + switch (std::get<0>(GetParam())) { + case 1: test_spatial_tree_dim<1>(); break; + case 2: test_spatial_tree_dim<2>(); break; + case 3: test_spatial_tree_dim<3>(); break; + case 4: test_spatial_tree_dim<4>(); break; + case 5: test_spatial_tree_dim<5>(); break; + case 6: test_spatial_tree_dim<6>(); break; + default: ASSERT_TRUE(false); + } + } + +private: + template + void test_spatial_tree_dim() { + std::size_t max_depth = std::get<1>(GetParam()); + std::size_t leaf_size_target = std::get<2>(GetParam()); + std::size_t num_points = std::get<3>(GetParam()); + + std::vector> boxes; + std::array box_min, box_max; + std::vector> data; + box_min.fill(-10.0); + box_max.fill(0.0); + + for (std::size_t i = 0; i < DIM; ++i) { + boxes.emplace_back(1, num_points, box_min, box_max); + data.insert(data.end(), boxes.back().data.begin(), boxes.back().data.end()); + box_min[i] += 20.0; + box_max[i] += 20.0; + } + + spatial_tree, DIM> tree( + max_depth, leaf_size_target, data, [](const data_point& d) { return d.point; }); + + // check box without any points + tree.bounding_box_for_each( + box_min, box_max, [](const data_point& d) { ASSERT_TRUE(false); }); + + // check iteration over full tree + { + std::vector> tree_data; + tree.for_each([&](const data_point& d) { tree_data.emplace_back(d); }); + ASSERT_EQ(data.size(), tree_data.size()); + + std::sort(data.begin(), data.end()); + std::sort(tree_data.begin(), tree_data.end()); + for (std::size_t i = 0; i < data.size(); ++i) { + ASSERT_EQ(data[i].id, tree_data[i].id); + ASSERT_EQ(data[i].point, tree_data[i].point); + } + } + + // check contents within each box + for (auto& box: boxes) { + std::vector> tree_data; + tree.bounding_box_for_each(box.box_min, box.box_max, [&](const data_point& d) { + tree_data.emplace_back(d); + }); + ASSERT_EQ(box.data.size(), tree_data.size()); + + std::sort(tree_data.begin(), tree_data.end()); + std::sort(box.data.begin(), box.data.end()); + + for (std::size_t i = 0; i < box.data.size(); ++i) { + ASSERT_EQ(box.data[i].id, tree_data[i].id); + ASSERT_EQ(box.data[i].point, tree_data[i].point); + } + } + } +}; + +std::string param_type_names( + const ::testing::TestParamInfo>& + info) { + std::stringstream stream; + + stream << "dim_" << std::get<0>(info.param); + stream << "_depth_" << std::get<1>(info.param); + stream << "_leaf_" << std::get<2>(info.param); + stream << "_n_" << std::get<3>(info.param); + + return stream.str(); +} +} // namespace + +TEST_P(st_test, param) { test_spatial_tree(); } + +INSTANTIATE_TEST_SUITE_P(spatial_tree, + st_test, + ::testing::Combine(::testing::Values(1, 2, 3), + ::testing::Values(1, 10, 20), + ::testing::Values(1, 100, 1000), + ::testing::Values(0, 1, 10, 100, 1000, 2000)), + param_type_names); From 9cbabec80f2cc78f54c8fab8a208fe22c961cc27 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Wed, 26 Jul 2023 17:54:39 +0200 Subject: [PATCH 34/84] add network generation test --- arbor/include/arbor/network.hpp | 1 + arbor/network.cpp | 2 + arbor/network_impl.cpp | 4 +- python/network.cpp | 8 +- test/unit-distributed/CMakeLists.txt | 5 +- .../test_network_generation.cpp | 180 ++++++++++++++++++ 6 files changed, 195 insertions(+), 5 deletions(-) create mode 100644 test/unit-distributed/test_network_generation.cpp diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 29180f571a..f7875999ff 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -40,6 +40,7 @@ ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_site_info, struct ARB_SYMBOL_VISIBLE network_connection_info { network_site_info src, dest; + double weight, delay; ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_connection_info& s); }; diff --git a/arbor/network.cpp b/arbor/network.cpp index 268370689a..1a55427ce0 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -1534,6 +1534,8 @@ ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_connectio os << ""; return os; } diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 6fc8328fbe..7a5ddc83fd 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -118,7 +118,9 @@ void push_back(std::vector& vec, network_site_info{ src.gid, src.kind, std::string(src.label), src.location, src.global_location}, network_site_info{ - dest.gid, dest.kind, std::string(dest.label), dest.location, dest.global_location}}); + dest.gid, dest.kind, std::string(dest.label), dest.location, dest.global_location}, + weight, + delay}); } template diff --git a/python/network.cpp b/python/network.cpp index 86f32155f2..90780e6ec4 100644 --- a/python/network.cpp +++ b/python/network.cpp @@ -43,8 +43,12 @@ void register_network(py::module& m) { m, "network_connection_info", "Identifies a network connection"); network_connection_info.def_readwrite("src", &arb::network_connection_info::src) .def_readwrite("dest", &arb::network_connection_info::dest) - .def("__repr__", [](const arb::network_connection_info& c) { return util::pprintf("{}", c); }) - .def("__str__", [](const arb::network_connection_info& c) { return util::pprintf("{}", c); }); + .def_readwrite("weight", &arb::network_connection_info::weight) + .def_readwrite("delay", &arb::network_connection_info::delay) + .def("__repr__", + [](const arb::network_connection_info& c) { return util::pprintf("{}", c); }) + .def("__str__", + [](const arb::network_connection_info& c) { return util::pprintf("{}", c); }); py::class_ network_selection( m, "network_selection", "Network selection."); diff --git a/test/unit-distributed/CMakeLists.txt b/test/unit-distributed/CMakeLists.txt index 8d9b0d8bb1..96ea5abf4b 100644 --- a/test/unit-distributed/CMakeLists.txt +++ b/test/unit-distributed/CMakeLists.txt @@ -4,6 +4,7 @@ set(unit-distributed_sources test_communicator.cpp test_mpi.cpp test_distributed_for_each.cpp + test_network_generation.cpp # unit test driver test.cpp @@ -14,7 +15,7 @@ add_dependencies(tests unit-local) target_compile_options(unit-local PRIVATE ${ARB_CXX_FLAGS_TARGET_FULL}) target_compile_definitions(unit-local PRIVATE TEST_LOCAL) -target_link_libraries(unit-local PRIVATE ext-gtest arbor arborenv arbor-sup arbor-private-headers ext-tinyopt) +target_link_libraries(unit-local PRIVATE ext-gtest arbor arborenv arborio arbor-sup arbor-private-headers ext-tinyopt) if(ARB_WITH_MPI) add_executable(unit-mpi EXCLUDE_FROM_ALL ${unit-distributed_sources}) @@ -22,6 +23,6 @@ if(ARB_WITH_MPI) target_compile_options(unit-mpi PRIVATE ${ARB_CXX_FLAGS_TARGET_FULL}) target_compile_definitions(unit-mpi PRIVATE TEST_MPI) - target_link_libraries(unit-mpi PRIVATE ext-gtest arbor arborenv arbor-sup arbor-private-headers ext-tinyopt) + target_link_libraries(unit-mpi PRIVATE ext-gtest arbor arborenv arborio arbor-sup arbor-private-headers ext-tinyopt) endif() diff --git a/test/unit-distributed/test_network_generation.cpp b/test/unit-distributed/test_network_generation.cpp new file mode 100644 index 0000000000..298a77f5b8 --- /dev/null +++ b/test/unit-distributed/test_network_generation.cpp @@ -0,0 +1,180 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "test.hpp" +#include "execution_context.hpp" + +using namespace arb; +using namespace arborio::literals; + + +namespace { +// Create alternatingly a cable, lif and spike source cell with at most one source or destination +class network_test_recipe: public arb::recipe { +public: + network_test_recipe(unsigned num_cells, + network_selection selection, + network_value weight, + network_value delay): + num_cells_(num_cells), + selection_(selection), + weight_(weight), + delay_(delay) { + gprop_.default_parameters = arb::neuron_parameter_defaults; + } + + cell_size_type num_cells() const override { + return num_cells_; + } + + arb::util::unique_any get_cell_description(cell_gid_type gid) const override { + if(gid % 3 == 1) { + return lif_cell("source", "target"); + } + if(gid % 3 == 2) { + return spike_source_cell("spike_source"); + } + + // cable cell + int stag = 1; // soma tag + int dtag = 3; // Dendrite tag. + double srad = 12.6157 / 2.0; // soma radius + double drad = 0.5; // Diameter of 1 μm for each dendrite cable. + arb::segment_tree tree; + tree.append( + arb::mnpos, {0, 0, -srad, srad}, {0, 0, srad, srad}, stag); // For area of 500 μm². + tree.append(0, {0, 0, 2 * srad, drad}, dtag); + + arb::label_dict labels; + labels.set("soma", reg::tagged(stag)); + labels.set("dend", reg::tagged(dtag)); + + auto decor = arb::decor{} + .paint("soma"_lab, arb::density("hh")) + .paint("dend"_lab, arb::density("pas")) + .set_default(arb::axial_resistivity{100}) // [Ω·cm] + .place(arb::mlocation{0, 0}, arb::threshold_detector{10}, "detector") + .place(arb::mlocation{0, 0.5}, arb::synapse("expsyn"), "primary_syn"); + + return arb::cable_cell(arb::morphology(tree), decor, labels); + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + if(gid % 3 == 1) { + return cell_kind::lif; + } + if(gid % 3 == 2) { + return cell_kind::spike_source; + } + + return cell_kind::cable; + } + + arb::isometry get_cell_isometry(cell_gid_type gid) const override { + // place cells with equal distance on a circle + const double angle = 2 * 3.1415926535897932 * gid / num_cells_; + const double radius = 500.0; + return arb::isometry::translate(radius * std::cos(angle), radius * std::sin(angle), 0.0); + }; + + std::optional network_description() const override { + return arb::network_description{selection_, weight_, delay_, {}}; + }; + + std::vector event_generators(cell_gid_type gid) const override { + return {}; + } + + std::vector get_probes(cell_gid_type gid) const override { + return {}; + } + + std::any get_global_properties(arb::cell_kind) const override { return gprop_; } + +private: + cell_size_type num_cells_; + arb::cable_cell_global_properties gprop_; + network_selection selection_; + network_value weight_, delay_; +}; + +} // namespace + +TEST(network_generation, all) { + const auto& ctx = g_context; + const int num_ranks = ctx->distributed->size(); + + const auto selection = network_selection::all(); + const auto weight = 2.0; + const auto delay = 3.0; + + const auto num_cells = 3 * num_ranks; + + auto rec = network_test_recipe(num_cells, selection, weight, delay); + + const auto decomp = partition_load_balance(rec, ctx); + + const auto connections = generate_network_connections(rec, ctx, decomp); + + std::unordered_map> connections_by_dest; + + for(const auto& c : connections) { + EXPECT_EQ(c.weight, weight); + EXPECT_EQ(c.delay, delay); + connections_by_dest[c.dest.gid].emplace_back(c); + } + + for (const auto& group: decomp.groups()) { + const auto num_dest = group.kind == cell_kind::spike_source ? 0 : 1; + for(const auto gid : group.gids) { + EXPECT_EQ(connections_by_dest[gid].size(), num_cells * num_dest); + } + } +} + + +TEST(network_generation, cable_only) { + const auto& ctx = g_context; + const int num_ranks = ctx->distributed->size(); + + const auto selection = intersect(network_selection::source_cell_kind(cell_kind::cable), + network_selection::destination_cell_kind(cell_kind::cable)); + const auto weight = 2.0; + const auto delay = 3.0; + + const auto num_cells = 3 * num_ranks; + + auto rec = network_test_recipe(num_cells, selection, weight, delay); + + const auto decomp = partition_load_balance(rec, ctx); + + const auto connections = generate_network_connections(rec, ctx, decomp); + + std::unordered_map> connections_by_dest; + + for(const auto& c : connections) { + EXPECT_EQ(c.weight, weight); + EXPECT_EQ(c.delay, delay); + connections_by_dest[c.dest.gid].emplace_back(c); + } + + for (const auto& group: decomp.groups()) { + for(const auto gid : group.gids) { + // Only one third is a cable cell + EXPECT_EQ(connections_by_dest[gid].size(), + group.kind == cell_kind::cable ? num_cells / 3 : 0); + } + } +} From 508bcc687eb982d8d04fe9d51496aa071308b490 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 28 Jul 2023 09:51:41 +0200 Subject: [PATCH 35/84] more effective multi-threading --- arbor/network_impl.cpp | 51 +++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 7a5ddc83fd..8b76a5494a 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -9,6 +9,7 @@ #include +#include #include #include #include @@ -145,20 +146,28 @@ std::vector generate_network_connections(const recipe& rec, const auto& weight = *weight_ptr; const auto& delay = *delay_ptr; - // populate network sites for source and destination + std::unordered_map> gids_by_kind; + for (const auto& group: dom_dec.groups()) { - switch (group.kind) { - case cell_kind::cable: { - // We need access to morphology, so the cell is create directly + auto& gids = gids_by_kind[group.kind]; + for (const auto& gid: group.gids) { gids.emplace_back(gid); } + } + + for (const auto& [kind, gids]: gids_by_kind) { + // populate network sites for source and destination + if (kind == cell_kind::cable) { + const auto& cable_gids = gids; threading::parallel_for::apply( - 0, group.gids.size(), ctx->thread_pool.get(), [&](int i) { - const auto gid = group.gids[i]; + 0, cable_gids.size(), ctx->thread_pool.get(), [&](int i) { + const auto gid = cable_gids[i]; + const auto kind = rec.get_cell_kind(gid); + // We need access to morphology, so the cell is create directly cable_cell cell; try { cell = util::any_cast(rec.get_cell_description(gid)); } catch (std::bad_any_cast&) { - throw bad_cell_description(rec.get_cell_kind(gid), gid); + throw bad_cell_description(kind, gid); } auto lid_to_label = @@ -198,20 +207,20 @@ std::vector generate_network_connections(const recipe& rec, } } }); - } break; - default: { - // Assuming all other cell types do not have a morphology. We can use label resolution - // through factory and set local position to 0. - auto factory = cell_kind_implementation(group.kind, group.backend, *ctx, 0); + } + else { + // Assuming all other cell types do not have a morphology. We can use label + // resolution through factory and set local position to 0. + auto factory = cell_kind_implementation(kind, backend_kind::multicore, *ctx, 0); // We only need the label ranges cell_label_range sources, destinations; - std::ignore = factory(group.gids, rec, sources, destinations); + std::ignore = factory(gids, rec, sources, destinations); std::size_t source_label_offset = 0; std::size_t destination_label_offset = 0; - for (std::size_t i = 0; i < group.gids.size(); ++i) { - const auto gid = group.gids[i]; + for (std::size_t i = 0; i < gids.size(); ++i) { + const auto gid = gids[i]; const auto iso = rec.get_cell_isometry(gid); const auto point = iso.apply(mpoint{0.0, 0.0, 0.0, 0.0}); const auto num_source_labels = sources.sizes().at(i); @@ -224,9 +233,9 @@ std::vector generate_network_connections(const recipe& rec, const auto& label = sources.labels().at(j); const auto& range = sources.ranges().at(j); for (auto lid = range.begin; lid < range.end; ++lid) { - if (selection.select_source(group.kind, gid, label)) { + if (selection.select_source(kind, gid, label)) { std::lock_guard guard(src_sites_mutex); - src_sites.insert({gid, lid, group.kind, label, {0, 0.0}, point}); + src_sites.insert({gid, lid, kind, label, {0, 0.0}, point}); } } } @@ -238,9 +247,9 @@ std::vector generate_network_connections(const recipe& rec, const auto& label = destinations.labels().at(j); const auto& range = destinations.ranges().at(j); for (auto lid = range.begin; lid < range.end; ++lid) { - if (selection.select_destination(group.kind, gid, label)) { + if (selection.select_destination(kind, gid, label)) { std::lock_guard guard(dest_sites_mutex); - dest_sites.insert({gid, lid, group.kind, label, {0, 0.0}, point}); + dest_sites.insert({gid, lid, kind, label, {0, 0.0}, point}); } } } @@ -248,8 +257,6 @@ std::vector generate_network_connections(const recipe& rec, source_label_offset += num_source_labels; destination_label_offset += num_destination_labels; } - - } break; } } @@ -323,8 +330,6 @@ std::vector generate_network_connections(const recipe& rec, } // namespace - - network_full_site_info::network_full_site_info(cell_gid_type gid, cell_lid_type lid, cell_kind kind, From fc38ec899ae0f7c94ec9caddf4fb1e2aa914332a Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 15 Aug 2023 18:41:57 +0200 Subject: [PATCH 36/84] fix mpi compilation --- arbor/communication/mpi_context.cpp | 10 ++++++++++ arbor/distributed_context.hpp | 12 ------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index ad1094e543..8f10a99253 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -212,6 +212,16 @@ struct remote_context_impl { return mpi_.gather_cell_labels_and_gids(local_labels_and_gids); } + distributed_request send_recv_nonblocking(std::size_t recv_count, + void* recv_data, + int source_id, + std::size_t send_count, + const void* send_data, + int dest_id, + int tag) const { + return mpi_.send_recv_nonblocking(recv_count, recv_data, source_id, send_count, send_data, dest_id, tag); + } + template std::vector gather(T value, int root) const { return mpi_.gather(value, root); } std::string name() const { return "MPIRemote"; } int id() const { return mpi_.id(); } diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index 0bc104004b..ea67765155 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -114,10 +114,6 @@ class distributed_context { return impl_->gather(value, root); } - std::vector gather_all(std::size_t value) const { - return impl_->gather_all(value); - } - template distributed_request send_recv_nonblocking(std::size_t recv_count, T* recv_data, @@ -176,7 +172,6 @@ class distributed_context { gather_cell_labels_and_gids(const cell_labels_and_gids& local_labels_and_gids) const = 0; virtual std::vector gather(std::string value, int root) const = 0; - virtual std::vector gather_all(std::size_t value) const = 0; virtual distributed_request send_recv_nonblocking(std::size_t recv_count, void* recv_data, int source_id, @@ -229,9 +224,6 @@ class distributed_context { gather(std::string value, int root) const override { return wrapped.gather(value, root); } - std::vector gather_all(std::size_t value) const override { - return wrapped.gather_all(value); - } distributed_request send_recv_nonblocking(std::size_t recv_count, void* recv_data, int source_id, @@ -306,10 +298,6 @@ struct local_context { return {std::move(value)}; } - std::vector gather_all(std::size_t value) const { - return std::vector({value}); - } - distributed_request send_recv_nonblocking(std::size_t dest_count, void* dest_data, int dest, From a5bf351493dc030f4b39516aa9be4a904736caaf Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 22 Aug 2023 10:39:30 +0200 Subject: [PATCH 37/84] documentation --- arbor/communication/distributed_for_each.hpp | 9 + arbor/distributed_context.hpp | 15 +- arborio/networkio.cpp | 4 +- doc/concepts/interconnectivity.rst | 283 +++++++++++++++++- doc/python/recipe.rst | 11 + .../network_description.cpp | 13 +- python/example/network_description.py | 16 +- 7 files changed, 326 insertions(+), 25 deletions(-) diff --git a/arbor/communication/distributed_for_each.hpp b/arbor/communication/distributed_for_each.hpp index a097c2605e..b223d06c61 100644 --- a/arbor/communication/distributed_for_each.hpp +++ b/arbor/communication/distributed_for_each.hpp @@ -40,6 +40,15 @@ void for_each_in_tuple_pair(FUNC&& func, std::tuple& t1, std::tuple::value_type*>&...) -> void + * Given 'n' distributed ranks, the function will be called 'n' times with data from each rank. + * There is no guaranteed order. + */ template void distributed_for_each(FUNC&& func, const distributed_context& distributed, diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index ea67765155..eb9db2ae64 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -35,7 +35,16 @@ namespace arb { #define ARB_COLLECTIVE_TYPES_ float, double, int, unsigned, long, unsigned long, long long, unsigned long long + +// A helper struct, representing a request for data exchange. +// After calling finalize() or destruction, the data exchange is guaranteed to be finished. struct distributed_request { + struct distributed_request_interface { + virtual void finalize() {}; + + virtual ~distributed_request_interface() = default; + }; + inline void finalize() { if (impl) { impl->finalize(); @@ -43,12 +52,6 @@ struct distributed_request { } } - struct distributed_request_interface { - virtual void finalize() {}; - - virtual ~distributed_request_interface() = default; - }; - ~distributed_request() { try { finalize(); diff --git a/arborio/networkio.cpp b/arborio/networkio.cpp index cd38dd4ac4..136bd0b618 100644 --- a/arborio/networkio.cpp +++ b/arborio/networkio.cpp @@ -238,9 +238,9 @@ eval_map_type network_eval_map{ {"log", make_call(arb::network_value::log, "Logarithm. 1 argument: (value:real)")}, {"log", make_call(arb::network_value::log, "Logarithm. 1 argument: (value:real)")}, - {"exp", make_call(arb::network_value::exp, "Logarithm. 1 argument: (value:real)")}, + {"exp", make_call(arb::network_value::exp, "Exponential function. 1 argument: (value:real)")}, {"exp", - make_call(arb::network_value::exp, "Logarithm. 1 argument: (value:real)")}, + make_call(arb::network_value::exp, "Exponential function. 1 argument: (value:real)")}, }; parse_network_hopefully eval(const s_expr& e, const eval_map_type& map); diff --git a/doc/concepts/interconnectivity.rst b/doc/concepts/interconnectivity.rst index 8b7de20c46..50ac01af56 100644 --- a/doc/concepts/interconnectivity.rst +++ b/doc/concepts/interconnectivity.rst @@ -18,6 +18,277 @@ These sites as such are not connected yet, however the :ref:`recipe ` +, arbor supports high-level description of a cell network. It is based around a ``network_selection`` type, that represents a selection from the set of all possible connections between cells. A selection can be created based on different criteria, such as source or destination label, cell indices and also distance between source and destination. Selections can then be combined with other selections through set algebra like expressions. For distance calculations, the location of each connection point on the cell is resolved through the morphology combined with a cell isometry, which describes translation and rotation of the cell. +Each connection also requires a weight and delay value. For this purpose, a ``network_value`` type is available, that allows to mathematically describe the value calculation using common math functions, as well random distributions. + +The following example shows the relevant recipe functions, where cells are connected into a ring with additional random connections between them: + +.. code-block:: python + + def network_description(self): + seed = 42 + + # create a chain + s_chain = f"(chain (gid-range 0 {self.ncells}))" + # connect front and back of chain to form ring + s_ring = f"(join {s_chain} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" + + # Create random connections with probability inversely proportional to the distance within a + # radius + max_dist = 400.0 # μm + probability = f"(div (sub {max_dist} (distance)) {max_dist})" + s_rand = f"(intersect (random {seed} {probability}) (distance-lt {max_dist}))" + + # combine ring with random selection + s = f"(join {s_ring} {s_rand})" + # restrict to inter-cell connections and certain source / destination labels + s = f"(intersect {s} (inter-cell) (source-label \"detector\") (destination-label \"syn\"))" + + # fixed weight for connections in ring + w_ring = f"(scalar 0.01)" + # random normal distributed weight with mean 0.02 μS, standard deviation 0.01 μS + # and truncated to [0.005, 0.035] + w_rand = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" + + # combine into single weight expression + w = f"(if-else {s_ring} {w_ring} {w_rand})" + + # fixed delay + d = "(scalar 5.0)" # ms delay + + return arbor.network_description(s, w, d, {}) + + def cell_isometry(self, gid): + # place cells with equal distance on a circle + radius = 500.0 # μm + angle = 2.0 * math.pi * gid / self.ncells + return arbor.isometry.translate(radius * math.cos(angle), radius * math.sin(angle), 0) + + +The export function ``generate_network_connections`` allows the inspection of generated connections. The exported connections include the cell index, local label and location of both source and destination. + + +.. note:: + + Expressions using distance require a cell isometry to resolve the global location of connection points. + +.. note:: + + A high-level description may be used together with providing explicit connection lists for each cell, but it is up to the user to avoid multiple connections between the same source and destination. + +.. warning:: + + Generating connections always involves additional work and may increase the time spent in the simulation initialization phase. + + +.. _interconnectivity-selection-expressions: + +Network Selection Expressions +----------------------------- + +.. label:: (gid-range begin:integer end:integer) + + A range expression, representing a range of indices in the half-open interval [begin, end). + +.. label:: (gid-range begin:integer end:integer step:integer) + + A range expression, representing a range of indices in the half-open interval [begin, end) with a given step size. Step size must be positive. + +.. label:: (cable-cell) + + Cell kind expression for cable cells. + +.. label:: (lif-cell) + + Cell kind expression for lif cells. + +.. label:: (benchmark-cell) + + Cell kind expression for benchmark cells. + +.. label:: (spike-source-cell) + + Cell kind expression for spike source cells. + +.. label:: (all) + + A selection of all possible connections. + +.. label:: (none) + + A selection representing the empty set of possible connections. + +.. label:: (inter-cell) + + A selection of all connections that connect two different cells. + +.. label:: (network-selection name:string) + + A named selection within the network dictionary. + +.. label:: (intersect network-selection network-selection [...network-selection]) + + The intersection of at least two selections. + +.. label:: (join network-selection network-selection [...network-selection]) + + The union of at least two selections. + +.. label:: (symmetric-difference network-selection network-selection [...network-selection]) + + The symmetric difference of at least two selections. + +.. label:: (difference network-selection network-selection) + + The difference of two selections. + +.. label:: (difference network-selection) + + The complement or opposite of the given selection. + +.. label:: (source-cell-kind kind:cell-kind) + + All connections, where the source cell is of the given type. + +.. label:: (destination-cell-kind kind:cell-kind) + + All connections, where the destination cell is of the given type. + +.. label:: (source-label label:string) + + All connections, where the source label matches the given label. + +.. label:: (destination-label label:string) + + All connections, where the destination label matches the given label. + +.. label:: (source-cell integer [...integer]) + + All connections, where the source cell index matches one of the given integer values. + +.. label:: (source-cell range:gid-range) + + All connections, where the source cell index is contained in the given gid-range. + +.. label:: (destination-cell integer [...integer]) + + All connections, where the destination cell index matches one of the given integer values. + +.. label:: (destination-cell range:gid-range) + + All connections, where the destination cell index is contained in the given gid-range. + +.. label:: (chain integer [...integer]) + + A chain of connections between cells in the given order of in the list, such that entry "i" is the source and entry "i+1" the destination. + +.. label:: (chain range:gid-range) + + A chain of connections between cells in the given order of the gid-range, such that entry "i" is the source and entry "i+1" the destination. + +.. label:: (chain-reverse range:gid-range) + + A chain of connections between cells in reverse of the given order of the gid-range, such that entry "i+1" is the source and entry "i" the destination. + +.. label:: (random p:real) + + A random selection of connections, where each connection is selected with the given probability. + +.. label:: (random p:network-value) + + A random selection of connections, where each connection is selected with the given probability expression. + +.. label:: (random p:network-value) + + A random selection of connections, where each connection is selected with the given probability expression. + +.. label:: (distance-lt dist:real) + + All connections, where the distance between source and destination is less than the given value in micro meter. + +.. label:: (distance-gt dist:real) + + All connections, where the distance between source and destination is greater than the given value in micro meter. + + +.. _interconnectivity-value-expressions: + +Network Value Expressions +------------------------- + +.. label:: (scalar value:real) + + A scalar of given value. + +.. label:: (network-value name:string) + + A named network value in the network dictionary. + +.. label:: (distance) + + The distance between source and destination. + +.. label:: (distance value:real) + + The distance between source and destination scaled by the given value. + +.. label:: (uniform-distribution seed:integer begin:real end:real) + + Uniform random distribution within the interval [begin, end). + +.. label:: (normal-distribution seed:integer mean:real std_deviation:real) + + Normal random distribution with given mean and standard deviation. + +.. label:: (truncated-normal-distribution seed:integer mean:real std_deviation:real begin:real end:real) + + Truncated normal random distribution with given mean and standard deviation within the interval [begin, end). + +.. label:: (if-else sel:network-selection true_value:network-value false_value:network-value) + + Truncated normal random distribution with given mean and standard deviation within the interval [begin, end). + +.. label:: (add (network-value | real) (network-value | real) [... (network-value | real)]) + + Addition of at least two network values or real numbers. + +.. label:: (sub (network-value | real) (network-value | real) [... (network-value | real)]) + + Subtraction of at least two network values or real numbers. + +.. label:: (mul (network-value | real) (network-value | real) [... (network-value | real)]) + + Multiplication of at least two network values or real numbers. + +.. label:: (div (network-value | real) (network-value | real) [... (network-value | real)]) + + Division of at least two network values or real numbers. + The expression is evaluated from the left to right, dividing the first element by each divisor in turn. + +.. label:: (min (network-value | real) (network-value | real) [... (network-value | real)]) + + Minimum of at least two network values or real numbers. + +.. label:: (max (network-value | real) (network-value | real) [... (network-value | real)]) + + Maximum of at least two network values or real numbers. + +.. label:: (log (network-value | real)) + + Logarithm of a network value or real number. + +.. label:: (exp (network-value | real)) + + Exponential function of a network value or real number. + + + .. _interconnectivity-mut: Mutability @@ -37,8 +308,8 @@ connection table outside calls to `run`, for example # extend the recipe to more connections rec.add_connections() - # use `connections_on` to build a new connection table - sim.update_connections(rec) + # use updated recipe to build a new connection table + sim.update(rec) # run simulation for 0.25ms with the extended connectivity sim.run(0.5, 0.025) @@ -48,12 +319,6 @@ must be explicitly included in the updated callback. This can also be used to update connection weights and delays. Note, however, that there is currently no way to introduce new sites to the simulation, nor any changes to gap junctions. -The ``update_connections`` method accepts either a full ``recipe`` (but will -**only** use the ``connections_on`` and ``events_generators`` callbacks) or a -``connectivity``, which is a reduced recipe exposing only the relevant callbacks. -Currently ``connectivity`` is only available in C++; Python users have to pass a -full recipe. - .. warning:: The semantics of connection updates are subtle and might produce surprising @@ -78,6 +343,8 @@ full recipe. in these callbacks. This is doubly important when using models with dynamic connectivity where the temptation to store all connections is even larger and each call to ``update`` will re-evaluate the corresponding callbacks. + Alternatively, connections can be generated by Arbor using the network DSL + through the ``network_description`` callback function. .. _interconnectivitycross: diff --git a/doc/python/recipe.rst b/doc/python/recipe.rst index 6e23106dcd..09da00ba36 100644 --- a/doc/python/recipe.rst +++ b/doc/python/recipe.rst @@ -84,7 +84,12 @@ Recipe By default returns an empty list. + .. function:: network_description() + Returns a network description, consisting of a network selection, network value for + weight and delay, and a network dictionary. + + By default returns none. .. function:: gap_junctions_on(gid) @@ -122,6 +127,12 @@ Recipe By default returns an empty object. + .. function:: cell_isometry(gid) + + Returns a isometry consisting of translation and rotation, which is applied to the cell morphology for resolving global locations. + + By default returns a isometry without translation and rotation. + Cells ------ diff --git a/example/network_description/network_description.cpp b/example/network_description/network_description.cpp index b780dc6a21..24a6943410 100644 --- a/example/network_description/network_description.cpp +++ b/example/network_description/network_description.cpp @@ -125,10 +125,17 @@ class ring_recipe: public arb::recipe { arb::network_selection::source_label({"detector"}), arb::network_selection::destination_label({"primary_syn"})); - // normal distributed weight with mean 0.05 μS, standard deviation 0.02 μS + // random normal distributed weight with mean 0.05 μS, standard deviation 0.02 μS // and truncated to [0.025, 0.075] - auto w = "(truncated-normal-distribution 42 0.05 0.02 0.025 0.075)"_nv; - // note: We are using s-expressions here as an alternative for creating a network_value + auto w_rand = "(truncated-normal-distribution 42 0.05 0.02 0.025 0.075)"_nv; + // note: We are using s-expressions here as an alternative for creating a network_value. + // This alternative way is also available for network selections. + + // fixed weight for connections in ring + auto w_ring = "(scalar 0.01)"_nv; + + // combine into single weight by using the "ring" selection as condition + auto w = arb::network_value::if_else(ring, w_ring, w_rand); return arb::network_description{s, w, min_delay_, {}}; }; diff --git a/python/example/network_description.py b/python/example/network_description.py index a166cbffd4..1e0a8b8883 100755 --- a/python/example/network_description.py +++ b/python/example/network_description.py @@ -120,9 +120,15 @@ def network_description(self): # restrict to inter-cell connections and certain source / destination labels s = f"(intersect {s} (inter-cell) (source-label \"detector\") (destination-label \"syn\"))" - # normal distributed weight with mean 0.02 μS, standard deviation 0.01 μS + # fixed weight for connections in ring + w_ring = f"(scalar 0.01)" + # random normal distributed weight with mean 0.02 μS, standard deviation 0.01 μS # and truncated to [0.005, 0.035] - w = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" + w_rand = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" + + # combine into single weight expression + w = f"(if-else {ring} {w_ring} {w_rand})" + # fixed delay d = "(scalar 5.0)" # ms delay @@ -145,12 +151,10 @@ def global_properties(self, kind): # (11) Instantiate recipe -ncells = 4 +ncells = 20 recipe = random_ring_recipe(ncells) -# (12) Create an execution context using all locally available threads and simulation -ctx = arbor.context("avail_threads") -sim = arbor.simulation(recipe, ctx) +sim = arbor.simulation(recipe) # (13) Set spike generators to record sim.record(arbor.spike_recording.all) From 8dae79256b369d5b6a829411af988c9fa907191c Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Thu, 24 Aug 2023 16:57:04 +0200 Subject: [PATCH 38/84] api documentation --- arbor/include/arbor/network.hpp | 50 +++- arbor/include/arbor/network_generation.hpp | 2 + arbor/network_impl.cpp | 4 +- doc/concepts/interconnectivity.rst | 7 +- doc/cpp/interconnectivity.rst | 261 +++++++++++++++++++++ doc/python/interconnectivity.rst | 59 +++++ 6 files changed, 364 insertions(+), 19 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index f7875999ff..d3cdb9b0b5 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -67,23 +67,21 @@ class ARB_SYMBOL_VISIBLE network_value { // Scalar value. Will always return the same value given at construction. static network_value scalar(double value); + // A named value inside a network label dictionary static network_value named(std::string name); + // Distamce netweem source and destination site static network_value distance(double scale = 1.0); - // Uniform random value in (range[0], range[1]]. Always returns the same value for repeated - // calls with the same arguments and calls are symmetric v(a, b) = v(b, a). + // Uniform random value in (range[0], range[1]]. static network_value uniform_distribution(unsigned seed, const std::array& range); - // Radom value from a normal distribution with given mean and standard deviation. Always returns - // the same value for repeated calls with the same arguments and calls are symmetric v(a, b) = - // v(b, a). + // Radom value from a normal distribution with given mean and standard deviation. static network_value normal_distribution(unsigned seed, double mean, double std_deviation); // Radom value from a truncated normal distribution with given mean and standard deviation (of a - // non-truncated normal distribution), where the value is always in (range[0], range[1]]. Always - // returns the same value for repeated calls with the same arguments and calls are symmetric - // v(a, b) = v(b, a). Note: Values are generated by reject-accept method from a normal + // non-truncated normal distribution), where the value is always in (range[0], range[1]]. + // Note: Values are generated by reject-accept method from a normal // distribution. Low acceptance rate can leed to poor performance, for example with very small // ranges or a mean far outside the range. static network_value truncated_normal_distribution(unsigned seed, @@ -92,8 +90,7 @@ class ARB_SYMBOL_VISIBLE network_value { const std::array& range); // Custom value using the provided function "func". Repeated calls with the same arguments - // to "func" must yield the same result. For gap junction values, - // "func" must be symmetric (func(a,b) = func(b,a)). + // to "func" must yield the same result. static network_value custom(custom_func_type func); static network_value add(network_value left, network_value right); @@ -112,6 +109,7 @@ class ARB_SYMBOL_VISIBLE network_value { static network_value max(network_value left, network_value right); + // if contained in selection, the true_value is used and the false_value otherwise. static network_value if_else(network_selection cond, network_value true_value, network_value false_value); @@ -161,39 +159,55 @@ class ARB_SYMBOL_VISIBLE network_selection { // Select none static network_selection none(); + // Named selection in the network label dictionary static network_selection named(std::string name); // Only select connections between different cells static network_selection inter_cell(); + // Select connections with the given source cell kind static network_selection source_cell_kind(cell_kind kind); + // Select connections with the given destination cell kind static network_selection destination_cell_kind(cell_kind kind); + // Select connections with the given source label static network_selection source_label(std::vector labels); + // Select connections with the given destination label static network_selection destination_label(std::vector labels); + // Select connections with source cells matching the indices in the list static network_selection source_cell(std::vector gids); + // Select connections with source cells matching the indices in the range static network_selection source_cell(gid_range range); + // Select connections with destination cells matching the indices in the list static network_selection destination_cell(std::vector gids); + // Select connections with destination cells matching the indices in the range static network_selection destination_cell(gid_range range); + // Select connections that form a chain, such that source cell "i" is connected to the destination cell "i+1" static network_selection chain(std::vector gids); + // Select connections that form a chain, such that source cell "i" is connected to the destination cell "i+1" static network_selection chain(gid_range range); + // Select connections that form a reversed chain, such that source cell "i+1" is connected to the destination cell "i" static network_selection chain_reverse(gid_range range); + // Select connections, that are selected by both "left" and "right" static network_selection intersect(network_selection left, network_selection right); + // Select connections, that are selected by either or both "left" and "right" static network_selection join(network_selection left, network_selection right); + // Select connections, that are selected by "left", unless selected by "right" static network_selection difference(network_selection left, network_selection right); + // Select connections, that are selected by "left" or "right", but not both static network_selection symmetric_difference(network_selection left, network_selection right); // Invert the selection @@ -234,16 +248,22 @@ class ARB_SYMBOL_VISIBLE network_label_dict { using ns_map = std::unordered_map; using nv_map = std::unordered_map; + // Store a network selection under the given name network_label_dict& set(const std::string& name, network_selection s); + // Store a network value under the given name network_label_dict& set(const std::string& name, network_value v); + // Returns the stored network selection of the given name if it exists. None otherwise. std::optional selection(const std::string& name) const; + // Returns the stored network value of the given name if it exists. None otherwise. std::optional value(const std::string& name) const; + // All stored network selections inline const ns_map& selections() const { return selections_; } + // All stored network value inline const nv_map& values() const { return values_; } private: @@ -251,22 +271,30 @@ class ARB_SYMBOL_VISIBLE network_label_dict { nv_map values_; }; -struct network_description { +// A complete network description required for processing +struct ARB_SYMBOL_VISIBLE network_description { network_selection selection; network_value weight; network_value delay; network_label_dict dict; }; +// Join two network selections ARB_ARBOR_API network_selection join(network_selection left, network_selection right); +// Join three or more network selections +ARB_ARBOR_API network_selection join(network_selection left, network_selection right); template network_selection join(network_selection l, network_selection r, Args... args) { return join(join(std::move(l), std::move(r)), std::move(args)...); } +// Intersect two network selections +ARB_ARBOR_API network_selection join(network_selection left, network_selection right); ARB_ARBOR_API network_selection intersect(network_selection left, network_selection right); +// Intersect three or more network selections +ARB_ARBOR_API network_selection join(network_selection left, network_selection right); template network_selection intersect(network_selection l, network_selection r, Args... args) { return intersect(intersect(std::move(l), std::move(r)), std::move(args)...); diff --git a/arbor/include/arbor/network_generation.hpp b/arbor/include/arbor/network_generation.hpp index 7a61f6e948..bf174b3f76 100644 --- a/arbor/include/arbor/network_generation.hpp +++ b/arbor/include/arbor/network_generation.hpp @@ -6,6 +6,8 @@ namespace arb { +// Generate and return list of connections from the network description of the recipe. +// Does not include connections from the "connections_on" recipe function. ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, const context& ctx, const domain_decomposition& dom_dec); diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 624156b673..8528bc34c7 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -369,8 +369,8 @@ ARB_ARBOR_API std::vector generate_network_connections( const domain_decomposition& dom_dec) { auto connections = generate_network_connections(rec, ctx, dom_dec); - // generated connections may have different orer each time due to multi-threading. - // sort before returning to user for reproducibility. + // generated connections may have different order each time due to multi-threading. + // Sort before returning to user for reproducibility. std::sort(connections.begin(), connections.end()); return connections; diff --git a/doc/concepts/interconnectivity.rst b/doc/concepts/interconnectivity.rst index 50ac01af56..6d876e49f3 100644 --- a/doc/concepts/interconnectivity.rst +++ b/doc/concepts/interconnectivity.rst @@ -23,8 +23,7 @@ The recipe callbacks are interrogated during simulation creation. High Level Network Description ------------------------------ -As an alternative to providing a list of connections for each cell in the :ref:`recipe ` -, arbor supports high-level description of a cell network. It is based around a ``network_selection`` type, that represents a selection from the set of all possible connections between cells. A selection can be created based on different criteria, such as source or destination label, cell indices and also distance between source and destination. Selections can then be combined with other selections through set algebra like expressions. For distance calculations, the location of each connection point on the cell is resolved through the morphology combined with a cell isometry, which describes translation and rotation of the cell. +As an alternative to providing a list of connections for each cell in the :ref:`recipe `, arbor supports high-level description of a cell network. It is based around a ``network_selection`` type, that represents a selection from the set of all possible connections between cells. A selection can be created based on different criteria, such as source or destination label, cell indices and also distance between source and destination. Selections can then be combined with other selections through set algebra like expressions. For distance calculations, the location of each connection point on the cell is resolved through the morphology combined with a cell isometry, which describes translation and rotation of the cell. Each connection also requires a weight and delay value. For this purpose, a ``network_value`` type is available, that allows to mathematically describe the value calculation using common math functions, as well random distributions. The following example shows the relevant recipe functions, where cells are connected into a ring with additional random connections between them: @@ -204,10 +203,6 @@ Network Selection Expressions A random selection of connections, where each connection is selected with the given probability expression. -.. label:: (random p:network-value) - - A random selection of connections, where each connection is selected with the given probability expression. - .. label:: (distance-lt dist:real) All connections, where the distance between source and destination is less than the given value in micro meter. diff --git a/doc/cpp/interconnectivity.rst b/doc/cpp/interconnectivity.rst index 9f572b1f90..6fd34c4e35 100644 --- a/doc/cpp/interconnectivity.rst +++ b/doc/cpp/interconnectivity.rst @@ -99,3 +99,264 @@ Interconnectivity .. cpp:member:: float weight unit-less gap junction connection weight. + +.. cpp:class:: network_site_info + + A network connection site on a cell. Used for generated connections through the high-level network description. + + .. cpp:member:: cell_gid_type gid + + The cell index. + + .. cpp:member:: cell_kind kind + + The cell kind. + + .. cpp:member:: cell_tag_type label + + The associated label. + + .. cpp:member:: mlocation location + + The local location on the cell. + + .. cpp:member:: mpoint global_location + + The global location in cartesian coordinates. + + +.. cpp:class:: network_connection_info + + A network connection between cells. Used for generated connections through the high-level network description. + + .. cpp:member:: network_site_info src + + The source connection site. + + .. cpp:member:: network_site_info dest + + The destination connection site. + + +.. cpp:class:: network_value + + A network value, describing the its calculation for each connection. + + .. cpp:function:: network_value scalar(double value) + + A fixed scalar valaue. + + .. cpp:function:: network_value named(std::string name) + + A named network value in the network label dictionary. + + .. cpp:function:: network_value distance() + + The value representing the distance between source and destination. + + .. cpp:function:: network_value uniform_distribution(unsigned seed, const std::array& range) + + A uniform random distribution within [range_0, range_1) + + .. cpp:function:: network_value normal_distribution(unsigned seed, double mean, double std_deviation) + + A normal random distribution with given mean and standard deviation. + + .. cpp:function:: network_value truncated_normal_distribution(unsigned seed, double mean, double std_deviation, const std::array& range) + + A truncated normal random distribution with given mean and standard deviation. Sampled through accept-reject method to only returns values in [range_0, range_1) + + .. cpp:function:: network_value custom(custom_func_type func) + + Custom value using the provided function "func". Repeated calls with the same arguments to "func" must yield the same result. + + .. cpp:function:: network_value add(network_value left, network_value right) + + Summation of two values. + + .. cpp:function:: network_value sub(network_value left, network_value right) + + Subtraction of two values. + + .. cpp:function:: network_value mul(network_value left, network_value right) + + Multiplication of two values. + + .. cpp:function:: network_value div(network_value left, network_value right) + + Division of two values. + + .. cpp:function:: network_value min(network_value left, network_value right) + + Minimum of two values. + + .. cpp:function:: network_value max(network_value left, network_value right) + + Maximum of two values. + + .. cpp:function:: network_value exp(network_value v) + + Exponential of given value. + + .. cpp:function:: network_value log(network_value v) + + Logarithm of given value. + + .. cpp:function:: if_else(network_selection cond, network_value true_value, network_value false_value) + + if contained in selection, the true_value is used and the false_value otherwise. + + +.. cpp:class:: network_selection + + A network selection, describing a subset of all possible connections. + + .. cpp:function:: network_selection all() + + Select all + + .. cpp:function:: network_selection none(); + + Select none + + .. cpp:function:: network_selection named(std::string name); + + Named selection in the network label dictionary + + .. cpp:function:: network_selection inter_cell(); + + Only select connections between different cells + + .. cpp:function:: network_selection source_cell_kind(cell_kind kind); + + Select connections with the given source cell kind + + .. cpp:function:: network_selection destination_cell_kind(cell_kind kind); + + Select connections with the given destination cell kind + + .. cpp:function:: network_selection source_label(std::vector labels); + + Select connections with the given source label + + .. cpp:function:: network_selection destination_label(std::vector labels); + + Select connections with the given destination label + + .. cpp:function:: network_selection source_cell(std::vector gids); + + Select connections with source cells matching the indices in the list + + .. cpp:function:: network_selection source_cell(gid_range range); + + Select connections with source cells matching the indices in the range + + .. cpp:function:: network_selection destination_cell(std::vector gids); + + Select connections with destination cells matching the indices in the list + + .. cpp:function:: network_selection destination_cell(gid_range range); + + Select connections with destination cells matching the indices in the range + + .. cpp:function:: network_selection chain(std::vector gids); + + Select connections that form a chain, such that source cell "i" is connected to the destination cell "i+1" + + .. cpp:function:: network_selection chain(gid_range range); + + Select connections that form a chain, such that source cell "i" is connected to the destination cell "i+1" + + .. cpp:function:: network_selection chain_reverse(gid_range range); + + Select connections that form a reversed chain, such that source cell "i+1" is connected to the destination cell "i" + + .. cpp:function:: network_selection intersect(network_selection left, network_selection right); + + Select connections, that are selected by both "left" and "right" + + .. cpp:function:: network_selection join(network_selection left, network_selection right); + + Select connections, that are selected by either or both "left" and "right" + + .. cpp:function:: network_selection difference(network_selection left, network_selection right); + + Select connections, that are selected by "left", unless selected by "right" + + .. cpp:function:: network_selection symmetric_difference(network_selection left, network_selection right); + + Select connections, that are selected by "left" or "right", but not both + + .. cpp:function:: network_selection complement(network_selection s); + + Invert the selection + + .. cpp:function:: network_selection random(unsigned seed, network_value p); + + Random selection using the bernoulli random distribution with probability "p" between 0.0 and 1.0 + + .. cpp:function:: network_selection custom(custom_func_type func); + + Custom selection using the provided function "func". Repeated calls with the same arguments + to "func" must yield the same result. For gap junction selection, + "func" must be symmetric (func(a,b) = func(b,a)). + + .. cpp:function:: network_selection distance_lt(double d); + + Only select within given distance. This may enable more efficient sampling through an + internal spatial data structure. + + .. cpp:function:: network_selection distance_gt(double d); + + Only select if distance greater then given distance. This may enable more efficient sampling + through an internal spatial data structure. + + +.. cpp:class:: network_label_dict + + Dictionary storing named network values and selections. + + .. cpp:function:: network_label_dict& set(const std::string& name, network_selection s) + + Store a network selection under the given name + + .. cpp:function:: network_label_dict& set(const std::string& name, network_value v) + + Store a network value under the given name + + .. cpp:function:: std::optional selection(const std::string& name) const + + Returns the stored network selection of the given name if it exists. None otherwise. + + .. cpp:function:: std::optional value(const std::string& name) const + + Returns the stored network value of the given name if it exists. None otherwise. + + .. cpp:function:: const ns_map& selections() const + + All stored network selections + + .. cpp:function:: const nv_map& selections() const + + All stored network values + + +.. cpp:class:: network_description + + A complete network description required for processing. + + .. cpp:member:: network_selection selection + + Selection of connections. + + .. cpp:member:: network_value weight + + Weight of generated connections. + + .. cpp:member:: network_value delay + + Delay of generated connections. + + .. cpp:member:: network_label_dict dict + + Label dictionary for named selecations and values. diff --git a/doc/python/interconnectivity.rst b/doc/python/interconnectivity.rst index e5c40905b5..a65b68d4cd 100644 --- a/doc/python/interconnectivity.rst +++ b/doc/python/interconnectivity.rst @@ -112,3 +112,62 @@ Interconnectivity .. attribute:: threshold Voltage threshold of threshold detector [mV] + + +.. class:: network_site_info + + A network connection site on a cell. Used for generated connections through the high-level network description. + + .. attribute:: gid + + The cell index. + + .. attribute:: kind + + The cell kind. + + .. attribute:: label + + The associated label. + + .. attribute:: location + + The local location on the cell. + + .. attribute:: global_location + + The global location in cartesian coordinates. + + +.. class:: network_connection_info + + A network connection between cells. Used for generated connections through the high-level network description. + + .. attribute:: src + + The source connection site. + + .. attribute:: dest + + The destination connection site. + + +.. class:: network_description + + A complete network description required for processing. + + .. attribute:: selection + + Selection of connections. + + .. attribute:: weight + + Weight of generated connections. + + .. attribute:: delay + + Delay of generated connections. + + .. attribute:: dict + + Dictionary for named selecations and values. From 4b33622364f8163e2d216a43f1f8049e2638ecb9 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 25 Aug 2023 13:57:13 +0200 Subject: [PATCH 39/84] reformat example --- doc/concepts/interconnectivity.rst | 14 +++++++------- python/example/network_description.py | 16 +++++++++------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/doc/concepts/interconnectivity.rst b/doc/concepts/interconnectivity.rst index 6d876e49f3..8018c29102 100644 --- a/doc/concepts/interconnectivity.rst +++ b/doc/concepts/interconnectivity.rst @@ -34,20 +34,20 @@ The following example shows the relevant recipe functions, where cells are conne seed = 42 # create a chain - s_chain = f"(chain (gid-range 0 {self.ncells}))" + chain = f"(chain (gid-range 0 {self.ncells}))" # connect front and back of chain to form ring - s_ring = f"(join {s_chain} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" + ring = f"(join {chain} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" # Create random connections with probability inversely proportional to the distance within a # radius - max_dist = 400.0 # μm + max_dist = 400.0 # μm probability = f"(div (sub {max_dist} (distance)) {max_dist})" - s_rand = f"(intersect (random {seed} {probability}) (distance-lt {max_dist}))" + rand = f"(intersect (random {seed} {probability}) (distance-lt {max_dist}))" # combine ring with random selection - s = f"(join {s_ring} {s_rand})" + s = f"(join {ring} {rand})" # restrict to inter-cell connections and certain source / destination labels - s = f"(intersect {s} (inter-cell) (source-label \"detector\") (destination-label \"syn\"))" + s = f'(intersect {s} (inter-cell) (source-label "detector") (destination-label "syn"))' # fixed weight for connections in ring w_ring = f"(scalar 0.01)" @@ -56,7 +56,7 @@ The following example shows the relevant recipe functions, where cells are conne w_rand = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" # combine into single weight expression - w = f"(if-else {s_ring} {w_ring} {w_rand})" + w = f"(if-else {ring} {w_ring} {w_rand})" # fixed delay d = "(scalar 5.0)" # ms delay diff --git a/python/example/network_description.py b/python/example/network_description.py index 1e0a8b8883..c59b42d580 100755 --- a/python/example/network_description.py +++ b/python/example/network_description.py @@ -35,7 +35,7 @@ def make_cable_cell(gid): tree.append( b0, arbor.mpoint(0, 0, 50, 2), - arbor.mpoint(0, 50 / sqrt(2), 50 + 50 / sqrt(2) , 0.5), + arbor.mpoint(0, 50 / sqrt(2), 50 + 50 / sqrt(2), 0.5), tag=3, ) # (b2) Constant radius of 1 μm over the length of the dendrite. @@ -97,28 +97,30 @@ def cell_kind(self, gid): def cell_isometry(self, gid): # place cells with equal distance on a circle - radius = 500.0 # μm + radius = 500.0 # μm angle = 2.0 * math.pi * gid / self.ncells - return arbor.isometry.translate(radius * math.cos(angle), radius * math.sin(angle), 0) + return arbor.isometry.translate( + radius * math.cos(angle), radius * math.sin(angle), 0 + ) def network_description(self): seed = 42 # create a chain - ring = f"(chain (gid-range 0 {self.ncells}))" + chain = f"(chain (gid-range 0 {self.ncells}))" # connect front and back of chain to form ring - ring = f"(join {ring} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" + ring = f"(join {chain} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" # Create random connections with probability inversely proportional to the distance within a # radius - max_dist = 400.0 # μm + max_dist = 400.0 # μm probability = f"(div (sub {max_dist} (distance)) {max_dist})" rand = f"(intersect (random {seed} {probability}) (distance-lt {max_dist}))" # combine ring with random selection s = f"(join {ring} {rand})" # restrict to inter-cell connections and certain source / destination labels - s = f"(intersect {s} (inter-cell) (source-label \"detector\") (destination-label \"syn\"))" + s = f'(intersect {s} (inter-cell) (source-label "detector") (destination-label "syn"))' # fixed weight for connections in ring w_ring = f"(scalar 0.01)" From 6e44819b060eb3c9e81cfa450ef7f408efb311da Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 25 Aug 2023 14:05:03 +0200 Subject: [PATCH 40/84] fix flake8 warning --- doc/concepts/interconnectivity.rst | 2 +- python/example/network_description.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/concepts/interconnectivity.rst b/doc/concepts/interconnectivity.rst index 8018c29102..1855e9f701 100644 --- a/doc/concepts/interconnectivity.rst +++ b/doc/concepts/interconnectivity.rst @@ -50,7 +50,7 @@ The following example shows the relevant recipe functions, where cells are conne s = f'(intersect {s} (inter-cell) (source-label "detector") (destination-label "syn"))' # fixed weight for connections in ring - w_ring = f"(scalar 0.01)" + w_ring = "(scalar 0.01)" # random normal distributed weight with mean 0.02 μS, standard deviation 0.01 μS # and truncated to [0.005, 0.035] w_rand = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" diff --git a/python/example/network_description.py b/python/example/network_description.py index c59b42d580..accc5af844 100755 --- a/python/example/network_description.py +++ b/python/example/network_description.py @@ -123,7 +123,7 @@ def network_description(self): s = f'(intersect {s} (inter-cell) (source-label "detector") (destination-label "syn"))' # fixed weight for connections in ring - w_ring = f"(scalar 0.01)" + w_ring = "(scalar 0.01)" # random normal distributed weight with mean 0.02 μS, standard deviation 0.01 μS # and truncated to [0.005, 0.035] w_rand = f"(truncated-normal-distribution {seed} 0.02 0.01 0.005 0.035)" From d8ae7d241df41d69e3101517e7ff9ca25aac1cab Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 25 Aug 2023 14:29:16 +0200 Subject: [PATCH 41/84] fix unit test linking with shared library --- arbor/network_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp index 31480cf15b..a42a0de4a7 100644 --- a/arbor/network_impl.hpp +++ b/arbor/network_impl.hpp @@ -21,7 +21,7 @@ namespace arb { -struct network_full_site_info { +struct ARB_SYMBOL_VISIBLE network_full_site_info { network_full_site_info() = default; network_full_site_info(cell_gid_type gid, From d21b51d9e0a3e4ec05619cfff409179c80772533 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 28 Aug 2023 10:20:41 +0200 Subject: [PATCH 42/84] add connection print out to examples --- arbor/include/arbor/network_generation.hpp | 3 +++ arbor/network_impl.cpp | 9 +++++++++ .../network_description.cpp | 18 +++++++++++++++--- python/example/network_description.py | 13 ++++++++++--- 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/arbor/include/arbor/network_generation.hpp b/arbor/include/arbor/network_generation.hpp index bf174b3f76..515c7abbea 100644 --- a/arbor/include/arbor/network_generation.hpp +++ b/arbor/include/arbor/network_generation.hpp @@ -12,4 +12,7 @@ ARB_ARBOR_API std::vector generate_network_connections( const context& ctx, const domain_decomposition& dom_dec); + +ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec); + } // namespace arb diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 8528bc34c7..9aabb9ff4e 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -13,8 +13,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -376,4 +378,11 @@ ARB_ARBOR_API std::vector generate_network_connections( return connections; } +ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec) { + auto ctx = arb::make_context(); + auto decomp = arb::partition_load_balance(rec, ctx); + + return generate_network_connections(rec, ctx, decomp); +} + } // namespace arb diff --git a/example/network_description/network_description.cpp b/example/network_description/network_description.cpp index 24a6943410..20f923fa61 100644 --- a/example/network_description/network_description.cpp +++ b/example/network_description/network_description.cpp @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include @@ -37,7 +39,6 @@ #include #include -#include "arbor/network.hpp" #include "branch_cell.hpp" #ifdef ARB_MPI_ENABLED @@ -49,9 +50,9 @@ struct ring_params { ring_params() = default; std::string name = "default"; - unsigned num_cells = 100; + unsigned num_cells = 20; double min_delay = 10; - double duration = 1000; + double duration = 100; cell_parameters cell; }; @@ -242,6 +243,17 @@ int main(int argc, char** argv) { meters.checkpoint("model-run", context); + + // Print generated connections + if (root) { + const auto connections = arb::generate_network_connections(recipe); + std::cout << "Connections:" << std::endl; + for(const auto& c: connections) { + std::cout << "(" << c.src.gid << ", \"" << c.src.label << "\") ->"; + std::cout << "(" << c.dest.gid << ", \"" << c.dest.label << "\")" << std::endl; + } + } + auto ns = sim.num_spikes(); // Write spikes to file diff --git a/python/example/network_description.py b/python/example/network_description.py index accc5af844..449195303e 100755 --- a/python/example/network_description.py +++ b/python/example/network_description.py @@ -164,16 +164,23 @@ def global_properties(self, kind): # (14) Attach a sampler to the voltage probe on cell 0. Sample rate of 10 sample every ms. handles = [sim.sample((gid, 0), arbor.regular_schedule(0.1)) for gid in range(ncells)] -# (15) Run simulation for 100 ms +# (15) Inspect generated connections +connections = arbor.generate_network_connections(recipe) + +print("connections:") +for c in connections: + print(f"({c.src.gid}, \"{c.src.label}\") -> ({c.dest.gid}, \"{c.dest.label}\")") + +# (16) Run simulation for 100 ms sim.run(100) print("Simulation finished") -# (16) Print spike times +# (17) Print spike times print("spikes:") for sp in sim.spikes(): print(" ", sp) -# (17) Plot the recorded voltages over time. +# (18) Plot the recorded voltages over time. print("Plotting results ...") df_list = [] for gid in range(ncells): From bbdc519f4796247d32cf84c3b3886d094bd34cac Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 28 Aug 2023 10:29:42 +0200 Subject: [PATCH 43/84] add new examples to scripts --- scripts/run_cpp_examples.sh | 2 ++ scripts/run_python_examples.sh | 1 + 2 files changed, 3 insertions(+) diff --git a/scripts/run_cpp_examples.sh b/scripts/run_cpp_examples.sh index fd07affd11..e5b43c6843 100755 --- a/scripts/run_cpp_examples.sh +++ b/scripts/run_cpp_examples.sh @@ -36,6 +36,7 @@ all_examples=( "plasticity" "ou" "voltage-clamp" + "network_description" "remote" ) @@ -58,6 +59,7 @@ expected_outputs=( "" "" "" + 205 "" ) diff --git a/scripts/run_python_examples.sh b/scripts/run_python_examples.sh index 8661a8bdf0..e1af7429b3 100755 --- a/scripts/run_python_examples.sh +++ b/scripts/run_python_examples.sh @@ -35,6 +35,7 @@ runpyex network_ring.py # runpyex network_ring_mpi_plot.py # no need to test runpyex network_ring_gpu.py # by default, gpu_id=None runpyex network_two_cells_gap_junctions.py +runpyex network_ring.py runpyex diffusion.py runpyex plasticity.py runpyex v-clamp.py From 25086f155bbbf64c7903aa9be27158bac62b47ca Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 28 Aug 2023 10:47:09 +0200 Subject: [PATCH 44/84] reformat example --- python/example/network_description.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/example/network_description.py b/python/example/network_description.py index 449195303e..bfc242c5de 100755 --- a/python/example/network_description.py +++ b/python/example/network_description.py @@ -169,7 +169,7 @@ def global_properties(self, kind): print("connections:") for c in connections: - print(f"({c.src.gid}, \"{c.src.label}\") -> ({c.dest.gid}, \"{c.dest.label}\")") + print(f'({c.src.gid}, "{c.src.label}") -> ({c.dest.gid}, "{c.dest.label}")') # (16) Run simulation for 100 ms sim.run(100) From 30ccb9bb7f82da6800b02c07611f3d61d874c481 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 28 Aug 2023 13:41:01 +0200 Subject: [PATCH 45/84] fix ci script value --- arbor/communication/distributed_for_each.hpp | 2 +- scripts/run_cpp_examples.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/arbor/communication/distributed_for_each.hpp b/arbor/communication/distributed_for_each.hpp index b223d06c61..55401772c4 100644 --- a/arbor/communication/distributed_for_each.hpp +++ b/arbor/communication/distributed_for_each.hpp @@ -133,7 +133,7 @@ void distributed_for_each(FUNC&& func, const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1; // exchange buffer in ring pattern and apply function at each step - for (std::size_t step = 0; step < distributed.size() - 1; ++step) { + for (int step = 0; step < distributed.size() - 1; ++step) { // always expect to recieve the max size but send actual size. MPI_recv only expects a max // size, not the actual size. const auto current_info = (const vec_info*)buffer.get(); diff --git a/scripts/run_cpp_examples.sh b/scripts/run_cpp_examples.sh index e5b43c6843..96d498c2b0 100755 --- a/scripts/run_cpp_examples.sh +++ b/scripts/run_cpp_examples.sh @@ -59,7 +59,7 @@ expected_outputs=( "" "" "" - 205 + 37 "" ) From c8421fd326a10cbbc31ce76e804598929db8b660 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Tue, 29 Aug 2023 12:37:09 +0200 Subject: [PATCH 46/84] Switch std::string -> hashes for label resolution. --- arbor/communication/mpi_context.cpp | 12 ++-- arbor/label_resolution.cpp | 50 +++++++------ arbor/label_resolution.hpp | 19 ++--- test/unit/test_fvm_lowered.cpp | 104 ++++++++++++++++++---------- test/unit/test_label_resolution.cpp | 51 ++++++++------ 5 files changed, 135 insertions(+), 101 deletions(-) diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index 6019e76065..c2fcf64af7 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -59,13 +59,11 @@ struct mpi_context_impl { } cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { - std::vector sizes; - std::vector labels; - std::vector ranges; - sizes = mpi::gather_all(local_ranges.sizes(), comm_); - labels = mpi::gather_all(local_ranges.labels(), comm_); - ranges = mpi::gather_all(local_ranges.ranges(), comm_); - return cell_label_range(sizes, labels, ranges); + cell_label_range res; + res.sizes = mpi::gather_all(local_ranges.sizes, comm_); + res.labels = mpi::gather_all(local_ranges.labels, comm_); + res.ranges = mpi::gather_all(local_ranges.ranges, comm_); + return res; } cell_labels_and_gids gather_cell_labels_and_gids(const cell_labels_and_gids& local_labels_and_gids) const { diff --git a/arbor/label_resolution.cpp b/arbor/label_resolution.cpp index dd928098b6..b143fc6bce 100644 --- a/arbor/label_resolution.cpp +++ b/arbor/label_resolution.cpp @@ -17,39 +17,40 @@ namespace arb { cell_label_range::cell_label_range(std::vector size_vec, std::vector label_vec, std::vector range_vec): - sizes_(std::move(size_vec)), labels_(std::move(label_vec)), ranges_(std::move(range_vec)) + sizes(std::move(size_vec)), ranges(std::move(range_vec)) { + std::transform(label_vec.begin(), label_vec.end(), + std::back_inserter(labels), + internal_hash); arb_assert(check_invariant()); }; -void cell_label_range::add_cell() { - sizes_.push_back(0); -} +void cell_label_range::add_cell() { sizes.push_back(0); } void cell_label_range::add_label(cell_tag_type label, lid_range range) { - if (sizes_.empty()) throw arbor_internal_error("adding label to cell_label_range without cell"); - ++sizes_.back(); - labels_.push_back(std::move(label)); - ranges_.push_back(std::move(range)); + if (sizes.empty()) throw arbor_internal_error("adding label to cell_label_range without cell"); + ++sizes.back(); + labels.push_back(internal_hash(label)); + ranges.push_back(std::move(range)); } void cell_label_range::append(cell_label_range other) { using std::make_move_iterator; - sizes_.insert(sizes_.end(), make_move_iterator(other.sizes_.begin()), make_move_iterator(other.sizes_.end())); - labels_.insert(labels_.end(), make_move_iterator(other.labels_.begin()), make_move_iterator(other.labels_.end())); - ranges_.insert(ranges_.end(), make_move_iterator(other.ranges_.begin()), make_move_iterator(other.ranges_.end())); + sizes.insert(sizes.end(), make_move_iterator(other.sizes.begin()), make_move_iterator(other.sizes.end())); + labels.insert(labels.end(), make_move_iterator(other.labels.begin()), make_move_iterator(other.labels.end())); + ranges.insert(ranges.end(), make_move_iterator(other.ranges.begin()), make_move_iterator(other.ranges.end())); } bool cell_label_range::check_invariant() const { - const cell_size_type count = std::accumulate(sizes_.begin(), sizes_.end(), cell_size_type(0)); - return count==labels_.size() && count==ranges_.size(); + const cell_size_type count = std::accumulate(sizes.begin(), sizes.end(), cell_size_type(0)); + return count==labels.size() && count==ranges.size(); } // cell_labels_and_gids methods cell_labels_and_gids::cell_labels_and_gids(cell_label_range lr, std::vector gid): label_range(std::move(lr)), gids(std::move(gid)) { - if (label_range.sizes().size()!=gids.size()) throw arbor_internal_error("cell_label_range and gid count mismatch"); + if (label_range.sizes.size()!=gids.size()) throw arbor_internal_error("cell_label_range and gid count mismatch"); } void cell_labels_and_gids::append(cell_labels_and_gids other) { @@ -58,7 +59,7 @@ void cell_labels_and_gids::append(cell_labels_and_gids other) { } bool cell_labels_and_gids::check_invariant() const { - return label_range.check_invariant() && label_range.sizes().size()==gids.size(); + return label_range.check_invariant() && label_range.sizes.size()==gids.size(); } // label_resolution_map methods @@ -82,34 +83,35 @@ lid_hopefully label_resolution_map::range_set::at(unsigned idx) const { } const label_resolution_map::range_set& label_resolution_map::at(cell_gid_type gid, const cell_tag_type& tag) const { - return map.at(gid).at(tag); + return map.at(gid).at(internal_hash(tag)); } std::size_t label_resolution_map::count(cell_gid_type gid, const cell_tag_type& tag) const { if (!map.count(gid)) return 0u; - return map.at(gid).count(tag); + return map.at(gid).count(internal_hash(tag)); } label_resolution_map::label_resolution_map(const cell_labels_and_gids& clg) { arb_assert(clg.label_range.check_invariant()); const auto& gids = clg.gids; - const auto& labels = clg.label_range.labels(); - const auto& ranges = clg.label_range.ranges(); - const auto& sizes = clg.label_range.sizes(); + const auto& labels = clg.label_range.labels; + const auto& ranges = clg.label_range.ranges; + const auto& sizes = clg.label_range.sizes; std::vector label_divs; auto partn = util::make_partition(label_divs, sizes); for (auto i: util::count_along(partn)) { auto gid = gids[i]; - std::unordered_map m; + std::unordered_map m; for (auto label_idx: util::make_span(partn[i])) { const auto range = ranges[label_idx]; auto size = int(range.end - range.begin); if (size < 0) { throw arb::arbor_internal_error("label_resolution_map: invalid lid_range"); } - auto& range_set = m[labels[label_idx]]; + auto& label = labels[label_idx]; + auto& range_set = m[label]; range_set.ranges.push_back(range); range_set.ranges_partition.push_back(range_set.ranges_partition.back() + size); } @@ -204,10 +206,12 @@ lid_hopefully update_state(resolver::state_variant& v, cell_lid_type resolver::resolve(cell_gid_type gid, const cell_local_label_type& label) { const auto& [tag, pol] = label; + auto hash = internal_hash(tag); + if (!label_map_->count(gid, tag)) throw arb::bad_connection_label(gid, tag, "label does not exist"); const auto& range_set = label_map_->at(gid, tag); - auto& state = state_map_[gid][tag]; + auto& state = state_map_[gid][hash]; // Policy round_robin_halt: use previous state of round_robin policy, if existent if (pol == lid_selection_policy::round_robin_halt diff --git a/arbor/label_resolution.hpp b/arbor/label_resolution.hpp index 9b30a41fa3..d8823cbfe3 100644 --- a/arbor/label_resolution.hpp +++ b/arbor/label_resolution.hpp @@ -9,6 +9,7 @@ #include #include "util/partition.hpp" +#include "util/hash.hpp" namespace arb { @@ -18,8 +19,7 @@ using lid_hopefully = arb::util::expected; // `sizes` is a partitioning vector for associating a cell with a set of // (label, range) pairs in `labels`, `ranges`. // gids of the cells are unknown. -class ARB_ARBOR_API cell_label_range { -public: +struct ARB_ARBOR_API cell_label_range { cell_label_range() = default; cell_label_range(cell_label_range&&) = default; cell_label_range(const cell_label_range&) = default; @@ -36,19 +36,14 @@ class ARB_ARBOR_API cell_label_range { bool check_invariant() const; - const auto& sizes() const { return sizes_; } - const auto& labels() const { return labels_; } - const auto& ranges() const { return ranges_; } - -private: // The number of labels associated with each cell. - std::vector sizes_; + std::vector sizes; // The labels corresponding to each cell, partitioned according to sizes_. - std::vector labels_; + std::vector labels; // The lid_range corresponding to each label. - std::vector ranges_; + std::vector ranges; }; // Struct for associating each cell of `cell_label_range` with a gid. @@ -83,7 +78,7 @@ class ARB_ARBOR_API label_resolution_map { std::size_t count(cell_gid_type gid, const cell_tag_type& tag) const; private: - std::unordered_map> map; + std::unordered_map> map; }; struct ARB_ARBOR_API round_robin_state { @@ -123,6 +118,6 @@ struct ARB_ARBOR_API resolver { state_variant construct_state(lid_selection_policy pol, cell_lid_type state); const label_resolution_map* label_map_; - map>> state_map_; + map>> state_map_; }; } // namespace arb diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index 61837432ff..29a88913a3 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -993,94 +993,122 @@ TEST(fvm_lowered, label_data) { { auto clg = cell_labels_and_gids(fvm_info.target_data, gids); std::vector expected_sizes = {2, 0, 0, 2, 0, 0, 2, 0, 0, 2}; - std::vector> expected_labeled_ranges, actual_labeled_ranges; - expected_labeled_ranges = { - {"1_synapse", {4, 5}}, {"4_synapses", {0, 4}}, - {"1_synapse", {4, 5}}, {"4_synapses", {0, 4}}, - {"1_synapse", {4, 5}}, {"4_synapses", {0, 4}}, - {"1_synapse", {4, 5}}, {"4_synapses", {0, 4}} + std::vector> expected_labeled_ranges = { + {internal_hash("1_synapse"), {4, 5}}, {internal_hash("4_synapses"), {0, 4}}, + {internal_hash("1_synapse"), {4, 5}}, {internal_hash("4_synapses"), {0, 4}}, + {internal_hash("1_synapse"), {4, 5}}, {internal_hash("4_synapses"), {0, 4}}, + {internal_hash("1_synapse"), {4, 5}}, {internal_hash("4_synapses"), {0, 4}} }; + std::vector> actual_labeled_ranges; + EXPECT_EQ(clg.gids, gids); - EXPECT_EQ(clg.label_range.sizes(), expected_sizes); - EXPECT_EQ(clg.label_range.labels().size(), expected_labeled_ranges.size()); - EXPECT_EQ(clg.label_range.ranges().size(), expected_labeled_ranges.size()); + EXPECT_EQ(clg.label_range.sizes, expected_sizes); + EXPECT_EQ(clg.label_range.labels.size(), expected_labeled_ranges.size()); + EXPECT_EQ(clg.label_range.ranges.size(), expected_labeled_ranges.size()); for (unsigned i = 0; i < expected_labeled_ranges.size(); ++i) { - actual_labeled_ranges.push_back({clg.label_range.labels()[i], clg.label_range.ranges()[i]}); + actual_labeled_ranges.push_back({clg.label_range.labels[i], clg.label_range.ranges[i]}); } std::vector size_partition; auto part = util::make_partition(size_partition, expected_sizes); for (const auto& r: part) { util::sort(util::subrange_view(actual_labeled_ranges, r)); + util::sort(util::subrange_view(expected_labeled_ranges, r)); } EXPECT_EQ(actual_labeled_ranges, expected_labeled_ranges); + + // Check for hash collisions; if we have one, the hash will appear twice in a given range, + // making the set of ids smaller than expected + const auto& labels = clg.label_range.labels; + for (const auto& [beg, end]: part) { + std::unordered_set unique(labels.begin() + beg, labels.begin() + end); + EXPECT_EQ(unique.size(), end - beg); + } } // detectors { auto clg = cell_labels_and_gids(fvm_info.source_data, gids); std::vector expected_sizes = {1, 2, 2, 1, 2, 2, 1, 2, 2, 1}; - std::vector> expected_labeled_ranges, actual_labeled_ranges; - expected_labeled_ranges = { - {"1_detector", {0, 1}}, - {"2_detectors", {3, 5}}, {"3_detectors", {0, 3}}, - {"2_detectors", {3, 5}}, {"3_detectors", {0, 3}}, - {"1_detector", {0, 1}}, - {"2_detectors", {3, 5}}, {"3_detectors", {0, 3}}, - {"2_detectors", {3, 5}}, {"3_detectors", {0, 3}}, - {"1_detector", {0, 1}}, - {"2_detectors", {3, 5}}, {"3_detectors", {0, 3}}, - {"2_detectors", {3, 5}}, {"3_detectors", {0, 3}}, - {"1_detector", {0, 1}} + std::vector> expected_labeled_ranges = { + {internal_hash("1_detector"), {0, 1}}, + {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, + {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, + {internal_hash("1_detector"), {0, 1}}, + {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, + {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, + {internal_hash("1_detector"), {0, 1}}, + {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, + {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, + {internal_hash("1_detector"), {0, 1}} }; + std::vector> actual_labeled_ranges; EXPECT_EQ(clg.gids, gids); - EXPECT_EQ(clg.label_range.sizes(), expected_sizes); - EXPECT_EQ(clg.label_range.labels().size(), expected_labeled_ranges.size()); - EXPECT_EQ(clg.label_range.ranges().size(), expected_labeled_ranges.size()); + EXPECT_EQ(clg.label_range.sizes, expected_sizes); + EXPECT_EQ(clg.label_range.labels.size(), expected_labeled_ranges.size()); + EXPECT_EQ(clg.label_range.ranges.size(), expected_labeled_ranges.size()); for (unsigned i = 0; i < expected_labeled_ranges.size(); ++i) { - actual_labeled_ranges.push_back({clg.label_range.labels()[i], clg.label_range.ranges()[i]}); + actual_labeled_ranges.push_back({clg.label_range.labels[i], clg.label_range.ranges[i]}); } std::vector size_partition; auto part = util::make_partition(size_partition, expected_sizes); for (const auto& r: part) { util::sort(util::subrange_view(actual_labeled_ranges, r)); + util::sort(util::subrange_view(expected_labeled_ranges, r)); } EXPECT_EQ(actual_labeled_ranges, expected_labeled_ranges); + + // Check for hash collisions; if we have one, the hash will appear twice in a given range, + // making the set of ids smaller than expected + const auto& labels = clg.label_range.labels; + for (const auto& [beg, end]: part) { + std::unordered_set unique(labels.begin() + beg, labels.begin() + end); + EXPECT_EQ(unique.size(), end - beg); + } } // gap_junctions { auto clg = cell_labels_and_gids(fvm_info.gap_junction_data, gids); std::vector expected_sizes = {0, 2, 2, 0, 2, 2, 0, 2, 2, 0}; - std::vector> expected_labeled_ranges, actual_labeled_ranges; - expected_labeled_ranges = { - {"1_gap_junction", {2, 3}}, {"2_gap_junctions", {0, 2}}, - {"1_gap_junction", {2, 3}}, {"2_gap_junctions", {0, 2}}, - {"1_gap_junction", {2, 3}}, {"2_gap_junctions", {0, 2}}, - {"1_gap_junction", {2, 3}}, {"2_gap_junctions", {0, 2}}, - {"1_gap_junction", {2, 3}}, {"2_gap_junctions", {0, 2}}, - {"1_gap_junction", {2, 3}}, {"2_gap_junctions", {0, 2}}, + std::vector> expected_labeled_ranges = { + {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, + {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, + {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, + {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, + {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, + {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, }; EXPECT_EQ(clg.gids, gids); - EXPECT_EQ(clg.label_range.sizes(), expected_sizes); - EXPECT_EQ(clg.label_range.labels().size(), expected_labeled_ranges.size()); - EXPECT_EQ(clg.label_range.ranges().size(), expected_labeled_ranges.size()); + EXPECT_EQ(clg.label_range.sizes, expected_sizes); + EXPECT_EQ(clg.label_range.labels.size(), expected_labeled_ranges.size()); + EXPECT_EQ(clg.label_range.ranges.size(), expected_labeled_ranges.size()); + std::vector> actual_labeled_ranges; for (unsigned i = 0; i < expected_labeled_ranges.size(); ++i) { - actual_labeled_ranges.push_back({clg.label_range.labels()[i], clg.label_range.ranges()[i]}); + actual_labeled_ranges.push_back({clg.label_range.labels[i], clg.label_range.ranges[i]}); } std::vector size_partition; auto part = util::make_partition(size_partition, expected_sizes); for (const auto& r: part) { util::sort(util::subrange_view(actual_labeled_ranges, r)); + util::sort(util::subrange_view(expected_labeled_ranges, r)); } EXPECT_EQ(actual_labeled_ranges, expected_labeled_ranges); + + // Check for hash collisions; if we have one, the hash will appear twice in a given range, + // making the set of ids smaller than expected + const auto& labels = clg.label_range.labels; + for (const auto& [beg, end]: part) { + std::unordered_set unique(labels.begin() + beg, labels.begin() + end); + EXPECT_EQ(unique.size(), end - beg); + } } } diff --git a/test/unit/test_label_resolution.cpp b/test/unit/test_label_resolution.cpp index 75adfbdae2..450428a743 100644 --- a/test/unit/test_label_resolution.cpp +++ b/test/unit/test_label_resolution.cpp @@ -8,6 +8,14 @@ using namespace arb; +std::vector make_labels(const std::vector& ls) { + std::vector res; + std::transform(ls.begin(), ls.end(), + std::back_inserter(res), + internal_hash); + return res; +} + TEST(test_cell_label_range, build) { using ivec = std::vector; using svec = std::vector; @@ -16,18 +24,18 @@ TEST(test_cell_label_range, build) { // Test add_cell and add_label auto b0 = cell_label_range(); EXPECT_THROW(b0.add_label("l0", {0u, 1u}), arb::arbor_internal_error); - EXPECT_TRUE(b0.sizes().empty()); - EXPECT_TRUE(b0.labels().empty()); - EXPECT_TRUE(b0.ranges().empty()); + EXPECT_TRUE(b0.sizes.empty()); + EXPECT_TRUE(b0.labels.empty()); + EXPECT_TRUE(b0.ranges.empty()); EXPECT_TRUE(b0.check_invariant()); auto b1 = cell_label_range(); b1.add_cell(); b1.add_cell(); b1.add_cell(); - EXPECT_EQ((ivec{0u, 0u, 0u}), b1.sizes()); - EXPECT_TRUE(b1.labels().empty()); - EXPECT_TRUE(b1.ranges().empty()); + EXPECT_EQ((ivec{0u, 0u, 0u}), b1.sizes); + EXPECT_TRUE(b1.labels.empty()); + EXPECT_TRUE(b1.ranges.empty()); EXPECT_TRUE(b1.check_invariant()); auto b2 = cell_label_range(); @@ -42,9 +50,9 @@ TEST(test_cell_label_range, build) { b2.add_label("l4", {7u, 2u}); b2.add_label("l4", {7u, 2u}); b2.add_label("l2", {7u, 2u}); - EXPECT_EQ((ivec{3u, 0u, 5u}), b2.sizes()); - EXPECT_EQ((svec{"l0", "l0", "l1", "l2", "l3", "l4", "l4", "l2"}), b2.labels()); - EXPECT_EQ((lvec{{0u, 1u}, {3u, 13u}, {0u, 5u}, {6u, 8u}, {1u, 0u}, {7u, 2u}, {7u, 2u}, {7u, 2u}}), b2.ranges()); + EXPECT_EQ((ivec{3u, 0u, 5u}), b2.sizes); + EXPECT_EQ(make_labels(svec{"l0", "l0", "l1", "l2", "l3", "l4", "l4", "l2"}), b2.labels); + EXPECT_EQ((lvec{{0u, 1u}, {3u, 13u}, {0u, 5u}, {6u, 8u}, {1u, 0u}, {7u, 2u}, {7u, 2u}, {7u, 2u}}), b2.ranges); EXPECT_TRUE(b2.check_invariant()); auto b3 = cell_label_range(); @@ -52,28 +60,29 @@ TEST(test_cell_label_range, build) { b3.add_label("r0", {0u, 9u}); b3.add_label("r1", {10u, 10u}); b3.add_cell(); - EXPECT_EQ((ivec{2u, 0u}), b3.sizes()); - EXPECT_EQ((svec{"r0", "r1"}), b3.labels()); - EXPECT_EQ((lvec{{0u, 9u}, {10u, 10u}}), b3.ranges()); + EXPECT_EQ((ivec{2u, 0u}), b3.sizes); + EXPECT_EQ(make_labels + (svec{"r0", "r1"}), b3.labels); + EXPECT_EQ((lvec{{0u, 9u}, {10u, 10u}}), b3.ranges); EXPECT_TRUE(b3.check_invariant()); // Test appending b0.append(b1); - EXPECT_EQ((ivec{0u, 0u, 0u}), b0.sizes()); - EXPECT_TRUE(b0.labels().empty()); - EXPECT_TRUE(b0.ranges().empty()); + EXPECT_EQ((ivec{0u, 0u, 0u}), b0.sizes); + EXPECT_TRUE(b0.labels.empty()); + EXPECT_TRUE(b0.ranges.empty()); EXPECT_TRUE(b0.check_invariant()); b0.append(b2); - EXPECT_EQ((ivec{0u, 0u, 0u, 3u, 0u, 5u}), b0.sizes()); - EXPECT_EQ((svec{"l0", "l0", "l1", "l2", "l3", "l4", "l4", "l2"}), b0.labels()); - EXPECT_EQ((lvec{{0u, 1u}, {3u, 13u}, {0u, 5u}, {6u, 8u}, {1u, 0u}, {7u, 2u}, {7u, 2u}, {7u, 2u}}), b0.ranges()); + EXPECT_EQ((ivec{0u, 0u, 0u, 3u, 0u, 5u}), b0.sizes); + EXPECT_EQ(make_labels(svec{"l0", "l0", "l1", "l2", "l3", "l4", "l4", "l2"}), b0.labels); + EXPECT_EQ((lvec{{0u, 1u}, {3u, 13u}, {0u, 5u}, {6u, 8u}, {1u, 0u}, {7u, 2u}, {7u, 2u}, {7u, 2u}}), b0.ranges); EXPECT_TRUE(b0.check_invariant()); b0.append(b3); - EXPECT_EQ((ivec{0u, 0u, 0u, 3u, 0u, 5u, 2u, 0u}), b0.sizes()); - EXPECT_EQ((svec{"l0", "l0", "l1", "l2", "l3", "l4", "l4", "l2", "r0", "r1"}), b0.labels()); - EXPECT_EQ((lvec{{0u, 1u}, {3u, 13u}, {0u, 5u}, {6u, 8u}, {1u, 0u}, {7u, 2u}, {7u, 2u}, {7u, 2u}, {0u, 9u}, {10u, 10u}}), b0.ranges()); + EXPECT_EQ((ivec{0u, 0u, 0u, 3u, 0u, 5u, 2u, 0u}), b0.sizes); + EXPECT_EQ(make_labels(svec{"l0", "l0", "l1", "l2", "l3", "l4", "l4", "l2", "r0", "r1"}), b0.labels); + EXPECT_EQ((lvec{{0u, 1u}, {3u, 13u}, {0u, 5u}, {6u, 8u}, {1u, 0u}, {7u, 2u}, {7u, 2u}, {7u, 2u}, {0u, 9u}, {10u, 10u}}), b0.ranges); EXPECT_TRUE(b0.check_invariant()); } From d13a5b1ea7f8df4b1684ef7cddab5322be5c4214 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Tue, 29 Aug 2023 14:16:01 +0200 Subject: [PATCH 47/84] Invariant checks for hash collisions. --- arbor/fvm_lowered_cell_impl.hpp | 4 ++++ arbor/label_resolution.cpp | 12 +++++++++++ arbor/util/hash.hpp | 35 +++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+) create mode 100644 arbor/util/hash.hpp diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index 6905889510..fc10d6ff36 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -370,6 +370,10 @@ fvm_initialization_data fvm_lowered_cell_impl::initialize( } } + if (!fvm_info.target_data.check_invariant()) throw arbor_internal_error{"Building cell target data resulted in invalid state."}; + if (!fvm_info.source_data.check_invariant()) throw arbor_internal_error{"Building cell source data resulted in invalid state."}; + if (!fvm_info.gap_junction_data.check_invariant()) throw arbor_internal_error{"Building cell gj data resulted in invalid state."}; + cable_cell_global_properties global_props; try { std::any rec_props = rec.get_global_properties(cell_kind::cable); diff --git a/arbor/label_resolution.cpp b/arbor/label_resolution.cpp index b143fc6bce..aaf974c2cf 100644 --- a/arbor/label_resolution.cpp +++ b/arbor/label_resolution.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -43,6 +44,17 @@ void cell_label_range::append(cell_label_range other) { bool cell_label_range::check_invariant() const { const cell_size_type count = std::accumulate(sizes.begin(), sizes.end(), cell_size_type(0)); + size_t beg = 0; + for (auto size: sizes) { + size_t end = beg + size; + std::unordered_set seen; + for (auto idx = beg; idx < end; ++idx) { + auto hash = labels[idx]; + if (seen.count(hash)) return false; + seen.insert(hash); + } + beg = end; + } return count==labels.size() && count==ranges.size(); } diff --git a/arbor/util/hash.hpp b/arbor/util/hash.hpp new file mode 100644 index 0000000000..416cde1a84 --- /dev/null +++ b/arbor/util/hash.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace arb { +using hash_type = uint64_t; + +// Non-cryptographic hash function for mapping strings to internal +// identifiers. Concretely, FNV-1a hash function taken from +// +// http://www.isthe.com/chongo/tech/comp/fnv/index.html +// +// NOTE: It may be worth it considering different hash functions in +// the future that have better characteristic, xxHash or Murmur +// look interesting but are more complex and likely require adding +// external dependencies. +// NOTE: this is the obligatory comment on a better hash function +// that will be here until the end of time. + +constexpr hash_type offset_basis = 0xcbf29ce484222325; +constexpr hash_type prime = 0x100000001b3; + +constexpr hash_type internal_hash(std::string_view data) { + hash_type hash = offset_basis; + + for (uint8_t byte: data) { + hash = hash ^ byte; + hash = hash * prime; + } + + return hash; +} + +} From ee037fc4f374215f1cfb1de115229274ad3f455f Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Tue, 29 Aug 2023 15:58:22 +0200 Subject: [PATCH 48/84] Account for same keys on cell label range. --- arbor/fvm_lowered_cell_impl.hpp | 52 ++++++++++++++------------------- arbor/label_resolution.cpp | 13 +-------- 2 files changed, 23 insertions(+), 42 deletions(-) diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index fc10d6ff36..70fccf07ea 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -312,11 +312,26 @@ fvm_detector_info get_detector_info(arb_size_type max, return { max, std::move(cv), std::move(threshold), ctx }; } +inline cell_size_type +add_labels(cell_label_range& clr, const std::unordered_multimap& ranges) { + clr.add_cell(); + cell_size_type count = 0; + std::unordered_map hashes; + for (const auto& [label, range]: ranges) { + auto hash = internal_hash(label); + if (hashes.count(hash) && hashes.at(hash) != label) { + auto err = util::strprintf("Hash collision {} ~ {} = {}", label, hashes.at(hash), hash); + throw arbor_internal_error{err}; + } + clr.add_label(label, range); + count += (range.end - range.begin); + } + return count; +} + template -fvm_initialization_data fvm_lowered_cell_impl::initialize( - const std::vector& gids, - const recipe& rec) -{ +fvm_initialization_data fvm_lowered_cell_impl::initialize(const std::vector& gids, + const recipe& rec) { using std::any_cast; using util::count_along; using util::make_span; @@ -346,34 +361,11 @@ fvm_initialization_data fvm_lowered_cell_impl::initialize( for (auto i : util::make_span(ncell)) { auto gid = gids[i]; const auto& c = cells[i]; - - fvm_info.source_data.add_cell(); - fvm_info.target_data.add_cell(); - fvm_info.gap_junction_data.add_cell(); - - unsigned count = 0; - for (const auto& [label, range]: c.detector_ranges()) { - fvm_info.source_data.add_label(label, range); - count+=(range.end - range.begin); - } - fvm_info.num_sources[gid] = count; - - count = 0; - for (const auto& [label, range]: c.synapse_ranges()) { - fvm_info.target_data.add_label(label, range); - count+=(range.end - range.begin); - } - fvm_info.num_targets[gid] = count; - - for (const auto& [label, range]: c.junction_ranges()) { - fvm_info.gap_junction_data.add_label(label, range); - } + fvm_info.num_sources[gid] = add_labels(fvm_info.source_data, c.detector_ranges()); + fvm_info.num_targets[gid] = add_labels(fvm_info.target_data, c.synapse_ranges()); + add_labels(fvm_info.gap_junction_data, c.junction_ranges()); } - if (!fvm_info.target_data.check_invariant()) throw arbor_internal_error{"Building cell target data resulted in invalid state."}; - if (!fvm_info.source_data.check_invariant()) throw arbor_internal_error{"Building cell source data resulted in invalid state."}; - if (!fvm_info.gap_junction_data.check_invariant()) throw arbor_internal_error{"Building cell gj data resulted in invalid state."}; - cable_cell_global_properties global_props; try { std::any rec_props = rec.get_global_properties(cell_kind::cable); diff --git a/arbor/label_resolution.cpp b/arbor/label_resolution.cpp index aaf974c2cf..69c3b6b57f 100644 --- a/arbor/label_resolution.cpp +++ b/arbor/label_resolution.cpp @@ -37,24 +37,13 @@ void cell_label_range::add_label(cell_tag_type label, lid_range range) { void cell_label_range::append(cell_label_range other) { using std::make_move_iterator; - sizes.insert(sizes.end(), make_move_iterator(other.sizes.begin()), make_move_iterator(other.sizes.end())); + sizes.insert(sizes.end(), make_move_iterator(other.sizes.begin()), make_move_iterator(other.sizes.end())); labels.insert(labels.end(), make_move_iterator(other.labels.begin()), make_move_iterator(other.labels.end())); ranges.insert(ranges.end(), make_move_iterator(other.ranges.begin()), make_move_iterator(other.ranges.end())); } bool cell_label_range::check_invariant() const { const cell_size_type count = std::accumulate(sizes.begin(), sizes.end(), cell_size_type(0)); - size_t beg = 0; - for (auto size: sizes) { - size_t end = beg + size; - std::unordered_set seen; - for (auto idx = beg; idx < end; ++idx) { - auto hash = labels[idx]; - if (seen.count(hash)) return false; - seen.insert(hash); - } - beg = end; - } return count==labels.size() && count==ranges.size(); } From 1a74723a0431430ebc6937d944764160dadd3f67 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Thu, 31 Aug 2023 16:27:01 +0200 Subject: [PATCH 49/84] review --- arbor/network_impl.cpp | 8 +++----- .../unit-distributed/test_distributed_for_each.cpp | 8 ++++---- test/unit/test_s_expr.cpp | 14 +++++++------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 9aabb9ff4e..143e5d69d6 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -4,7 +4,7 @@ #include "label_resolution.hpp" #include "network_impl.hpp" #include "threading/threading.hpp" -#include "util/range.hpp" +#include "util/rangeutil.hpp" #include "util/spatial_tree.hpp" #include @@ -322,10 +322,8 @@ std::vector generate_network_connections(const recipe& rec, distributed_for_each(sample_sources, distributed, - util::make_range(src_sites.sites.begin(), src_sites.sites.end()), - util::make_range(src_sites.labels.begin(), src_sites.labels.end())); - - // distributed_for_each(sample_sources, distributed, src_sites.sites, src_sites.labels); + util::range_view(src_sites.sites), + util::range_view(src_sites.labels)); return connections; } diff --git a/test/unit-distributed/test_distributed_for_each.cpp b/test/unit-distributed/test_distributed_for_each.cpp index 80a38274d6..d3965892db 100644 --- a/test/unit-distributed/test_distributed_for_each.cpp +++ b/test/unit-distributed/test_distributed_for_each.cpp @@ -7,7 +7,7 @@ #include "communication/distributed_for_each.hpp" #include "execution_context.hpp" -#include "util/range.hpp" +#include "util/rangeutil.hpp" using namespace arb; @@ -84,9 +84,9 @@ TEST(distributed_for_each, multiple) { distributed_for_each(sample, *g_context->distributed, - util::make_range(data_1.begin(), data_1.end()), - util::make_range(data_2.begin(), data_2.end()), - util::make_range(data_3.begin(), data_3.end())); + util::range_view(data_1), + util::range_view(data_2), + util::range_view(data_3)); EXPECT_EQ(num_ranks, call_count); } diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp index c07c094a80..6797db0226 100644 --- a/test/unit/test_s_expr.cpp +++ b/test/unit/test_s_expr.cpp @@ -174,7 +174,7 @@ TEST(s_expr, iterate) { template std::string round_trip_label(const char* in) { if (auto x = parse_label_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -183,7 +183,7 @@ std::string round_trip_label(const char* in) { std::string round_trip_cv(const char* in) { if (auto x = parse_cv_policy_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -192,7 +192,7 @@ std::string round_trip_cv(const char* in) { std::string round_trip_region(const char* in) { if (auto x = parse_region_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -201,7 +201,7 @@ std::string round_trip_region(const char* in) { std::string round_trip_locset(const char* in) { if (auto x = parse_locset_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -210,7 +210,7 @@ std::string round_trip_locset(const char* in) { std::string round_trip_iexpr(const char* in) { if (auto x = parse_iexpr_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -219,7 +219,7 @@ std::string round_trip_iexpr(const char* in) { std::string round_trip_network_selection(const char* in) { if (auto x = parse_network_selection_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); @@ -228,7 +228,7 @@ std::string round_trip_network_selection(const char* in) { std::string round_trip_network_value(const char* in) { if (auto x = parse_network_value_expression(in)) { - return util::pprintf("{}", std::any_cast(*x)); + return util::to_string(std::any_cast(*x)); } else { return x.error().what(); From 71d116b97f85acd33f9b808ee6e74dbb165ec9b8 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 1 Sep 2023 08:27:14 +0200 Subject: [PATCH 50/84] added "visit_variant" alternative to std::visit for better performance --- arbor/util/spatial_tree.hpp | 67 +++++++++++++++++------------------- arbor/util/visit_variant.hpp | 41 ++++++++++++++++++++++ 2 files changed, 73 insertions(+), 35 deletions(-) create mode 100644 arbor/util/visit_variant.hpp diff --git a/arbor/util/spatial_tree.hpp b/arbor/util/spatial_tree.hpp index 81e76e0449..10f3879194 100644 --- a/arbor/util/spatial_tree.hpp +++ b/arbor/util/spatial_tree.hpp @@ -1,5 +1,7 @@ #pragma once +#include "util/visit_variant.hpp" + #include #include @@ -109,17 +111,14 @@ class spatial_tree { // func must have signature `void func(const T&)`. template inline void for_each(const F &func) const { - std::visit( - [&](auto &&arg) { - using arg_type = std::decay_t; - if constexpr (std::is_same_v) { - for (const auto &node: arg) { node.for_each(func); } - } - if constexpr (std::is_same_v) { - for (const auto &d: arg) { func(d); } - } + util::visit_variant( + data_, + [&](const node_data &data) { + for (const auto &node: data) { node.for_each(func); } }, - data_); + [&](const leaf_data &data) { + for (const auto &d: data) { func(d); } + }); } // Iterate over all points within the given bounding box recursively. @@ -134,40 +133,38 @@ class spatial_tree { return result; }; - std::visit( - [&](auto &&arg) { - using arg_type = std::decay_t; - + util::visit_variant( + data_, + [&](const node_data &data) { if (all_smaller_eq(box_min, min_) && all_smaller_eq(max_, box_max)) { // sub-nodes fully inside box -> call without further boundary // checks - if constexpr (std::is_same_v) { - for (const auto &node: arg) { node.template for_each(func); } - } - if constexpr (std::is_same_v) { - for (const auto &d: arg) { func(d); } - } + for (const auto &node: data) { node.template for_each(func); } } else { // sub-nodes partially overlap bounding box - if constexpr (std::is_same_v) { - for (const auto &node: arg) { - if (all_smaller_eq(node.min_, box_max) && - all_smaller_eq(box_min, node.max_)) - node.template bounding_box_for_each(box_min, box_max, func); - } - } - if constexpr (std::is_same_v) { - for (const auto &d: arg) { - const auto p = location_(d); - if (all_smaller_eq(p, box_max) && all_smaller_eq(box_min, p)) { - func(d); - } - } + for (const auto &node: data) { + if (all_smaller_eq(node.min_, box_max) && + all_smaller_eq(box_min, node.max_)) + node.template bounding_box_for_each(box_min, box_max, func); } } }, - data_); + [&](const leaf_data &data) { + if (all_smaller_eq(box_min, min_) && all_smaller_eq(max_, box_max)) { + // sub-nodes fully inside box -> call without further boundary + // checks + for (const auto &d: data) { func(d); } + } + else { + // sub-nodes partially overlap bounding box + for (const auto &d: data) { + const auto p = location_(d); + if (all_smaller_eq(p, box_max) && all_smaller_eq(box_min, p)) { func(d); } + } + } + }); + } inline std::size_t size() const { return size_; } diff --git a/arbor/util/visit_variant.hpp b/arbor/util/visit_variant.hpp new file mode 100644 index 0000000000..bf6e32ef16 --- /dev/null +++ b/arbor/util/visit_variant.hpp @@ -0,0 +1,41 @@ +#pragma once + + +#include +#include + + +namespace arb { +namespace util { + +namespace impl { +template +inline void visit_variant_impl(VARIANT &&v, F &&f) { + constexpr auto index = std::variant_size_v> - 1; + if (v.index() == index) f(std::get(v)); +} + +template +inline void visit_variant_impl(VARIANT &&v, F &&f, FUNCS &&...functions) { + constexpr auto index = + std::variant_size_v> - sizeof...(FUNCS) - 1; + if (v.index() == index) f(std::get(v)); + visit_variant_impl(std::forward(v), std::forward(functions)...); +} +} // namespace impl + +/* + * Similar to std::visit, call contained type with matching function. Expects a function for each + * type in variant and in the same order. More performant than std::visit through the use of + * indexing instead of function tables. + */ +template +inline void visit_variant(VARIANT &&v, FUNCS &&...functions) { + static_assert(std::variant_size_v> == + sizeof...(FUNCS), + "The first argument must be of type std::variant and the " + "number of functions must match the variant size."); + impl::visit_variant_impl(std::forward(v), std::forward(functions)...); +} +} // namespace util +} // namespace arb From d4e334ed5ba070e0a681447cddc322e2f35792ee Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sat, 2 Sep 2023 19:47:49 +0200 Subject: [PATCH 51/84] improved multithreading by lock removal --- arbor/network_impl.cpp | 122 +++++++++++++++++++++------------- arbor/threading/threading.cpp | 8 +++ arbor/threading/threading.hpp | 6 +- 3 files changed, 89 insertions(+), 47 deletions(-) diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 143e5d69d6..518c8ce5ee 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -88,6 +88,12 @@ struct site_mapping { s.hash}); } + void insert(const site_mapping& m) { + for(std::size_t idx = 0; idx < m.size(); ++idx) { + this->insert(m.get_site(idx)); + } + } + network_full_site_info get_site(std::size_t idx) const { const auto& s = this->sites.at(idx); @@ -137,9 +143,6 @@ std::vector generate_network_connections(const recipe& rec, const auto& description = description_opt.value(); - site_mapping src_sites, dest_sites; - std::mutex src_sites_mutex, dest_sites_mutex; - const auto selection_ptr = thingify(description.selection, description.dict); const auto weight_ptr = thingify(description.weight, description.dict); const auto delay_ptr = thingify(description.delay, description.dict); @@ -155,12 +158,20 @@ std::vector generate_network_connections(const recipe& rec, for (const auto& gid: group.gids) { gids.emplace_back(gid); } } + const auto num_batches = ctx->thread_pool->get_num_threads(); + std::vector src_site_batches(num_batches); + std::vector dest_site_batches(num_batches); + for (const auto& [kind, gids]: gids_by_kind) { + const auto batch_size = (gids.size() + num_batches - 1) / num_batches; // populate network sites for source and destination if (kind == cell_kind::cable) { const auto& cable_gids = gids; threading::parallel_for::apply( - 0, cable_gids.size(), ctx->thread_pool.get(), [&](int i) { + 0, cable_gids.size(), batch_size, ctx->thread_pool.get(), [&](int i) { + const auto batch_idx = ctx->thread_pool->get_current_thread_id().value(); + auto& src_sites = src_site_batches[batch_idx]; + auto& dest_sites = dest_site_batches[batch_idx]; const auto gid = cable_gids[i]; const auto kind = rec.get_cell_kind(gid); // We need access to morphology, so the cell is create directly @@ -191,7 +202,6 @@ std::vector generate_network_connections(const recipe& rec, if (selection.select_destination(cell_kind::cable, gid, label)) { const mpoint point = location_resolver.at(p_syn.loc); - std::lock_guard guard(dest_sites_mutex); dest_sites.insert( {gid, p_syn.lid, cell_kind::cable, label, p_syn.loc, point}); } @@ -203,7 +213,6 @@ std::vector generate_network_connections(const recipe& rec, const auto& label = lid_to_label(cell.detector_ranges(), p_det.lid); if (selection.select_source(cell_kind::cable, gid, label)) { const mpoint point = location_resolver.at(p_det.loc); - std::lock_guard guard(src_sites_mutex); src_sites.insert( {gid, p_det.lid, cell_kind::cable, label, p_det.loc, point}); } @@ -219,6 +228,9 @@ std::vector generate_network_connections(const recipe& rec, cell_label_range sources, destinations; std::ignore = factory(gids, rec, sources, destinations); + auto& src_sites = src_site_batches[0]; + auto& dest_sites = dest_site_batches[0]; + std::size_t source_label_offset = 0; std::size_t destination_label_offset = 0; for (std::size_t i = 0; i < gids.size(); ++i) { @@ -236,7 +248,6 @@ std::vector generate_network_connections(const recipe& rec, const auto& range = sources.ranges().at(j); for (auto lid = range.begin; lid < range.end; ++lid) { if (selection.select_source(kind, gid, label)) { - std::lock_guard guard(src_sites_mutex); src_sites.insert({gid, lid, kind, label, {0, 0.0}, point}); } } @@ -250,7 +261,6 @@ std::vector generate_network_connections(const recipe& rec, const auto& range = destinations.ranges().at(j); for (auto lid = range.begin; lid < range.end; ++lid) { if (selection.select_destination(kind, gid, label)) { - std::lock_guard guard(dest_sites_mutex); dest_sites.insert({gid, lid, kind, label, {0, 0.0}, point}); } } @@ -262,11 +272,23 @@ std::vector generate_network_connections(const recipe& rec, } } + site_mapping& src_sites = src_site_batches.front(); + + // combine src batches + for (std::size_t batch_idx = 1; batch_idx < src_site_batches.size(); ++batch_idx) { + + for (std::size_t i = 0; i < src_site_batches[batch_idx].size(); ++i) { + src_sites.insert(src_site_batches[batch_idx].get_site(i)); + } + } + // create octree std::vector network_dest_sites; - network_dest_sites.reserve(dest_sites.size()); - for (std::size_t i = 0; i < dest_sites.size(); ++i) { - network_dest_sites.emplace_back(dest_sites.get_site(i)); + network_dest_sites.reserve(dest_site_batches[0].size() * num_batches); + for (const auto& dest_sites: dest_site_batches) { + for (std::size_t i = 0; i < dest_sites.size(); ++i) { + network_dest_sites.emplace_back(dest_sites.get_site(i)); + } } const std::size_t max_depth = selection.max_distance().has_value() ? 10 : 1; const std::size_t max_leaf_size = 100; @@ -279,45 +301,47 @@ std::vector generate_network_connections(const recipe& rec, }); // select connections - std::vector connections; - std::mutex connections_mutex; + std::vector> connection_batches(num_batches); auto sample_sources = [&](const util::range& source_range, const util::range& label_range) { - threading::parallel_for::apply(0, source_range.size(), ctx->thread_pool.get(), [&](int i) { - const auto& s = source_range[i]; - network_full_site_info src; - src.gid = s.gid; - src.lid = s.lid; - src.kind = s.kind; - src.label = label_range.data() + s.label_start_idx; - src.location = s.location; - src.global_location = s.global_location; - src.hash = s.hash; - - auto sample = [&](const network_full_site_info& dest) { - if (selection.select_connection(src, dest)) { - const auto w = weight.get(src, dest); - const auto d = delay.get(src, dest); - - std::lock_guard guard(connections_mutex); - push_back(connections, src, dest, w, d); + const auto batch_size = (source_range.size() + num_batches - 1) / num_batches; + threading::parallel_for::apply( + 0, source_range.size(), batch_size, ctx->thread_pool.get(), [&](int i) { + const auto& s = source_range[i]; + const auto batch_idx = ctx->thread_pool->get_current_thread_id().value(); + auto& connections = connection_batches[batch_idx]; + network_full_site_info src; + src.gid = s.gid; + src.lid = s.lid; + src.kind = s.kind; + src.label = label_range.data() + s.label_start_idx; + src.location = s.location; + src.global_location = s.global_location; + src.hash = s.hash; + + auto sample = [&](const network_full_site_info& dest) { + if (selection.select_connection(src, dest)) { + const auto w = weight.get(src, dest); + const auto d = delay.get(src, dest); + + push_back(connections, src, dest, w, d); + } + }; + + if (selection.max_distance().has_value()) { + const double d = selection.max_distance().value(); + local_dest_tree.bounding_box_for_each( + decltype(local_dest_tree)::point_type{src.global_location.x - d, + src.global_location.y - d, + src.global_location.z - d}, + decltype(local_dest_tree)::point_type{src.global_location.x + d, + src.global_location.y + d, + src.global_location.z + d}, + sample); } - }; - - if (selection.max_distance().has_value()) { - const double d = selection.max_distance().value(); - local_dest_tree.bounding_box_for_each( - decltype(local_dest_tree)::point_type{src.global_location.x - d, - src.global_location.y - d, - src.global_location.z - d}, - decltype(local_dest_tree)::point_type{src.global_location.x + d, - src.global_location.y + d, - src.global_location.z + d}, - sample); - } - else { local_dest_tree.for_each(sample); } - }); + else { local_dest_tree.for_each(sample); } + }); }; distributed_for_each(sample_sources, @@ -325,6 +349,12 @@ std::vector generate_network_connections(const recipe& rec, util::range_view(src_sites.sites), util::range_view(src_sites.labels)); + // concatenate + auto connections = std::move(connection_batches.front()); + for (std::size_t i = 1; i < connection_batches.size(); ++i) { + connections.insert( + connections.end(), connection_batches[i].begin(), connection_batches[i].end()); + } return connections; } diff --git a/arbor/threading/threading.cpp b/arbor/threading/threading.cpp index 9b7c639ca8..44166d3396 100644 --- a/arbor/threading/threading.cpp +++ b/arbor/threading/threading.cpp @@ -1,8 +1,10 @@ #include +#include #include #include #include +#include #include "threading/threading.hpp" #include "affinity.hpp" @@ -184,3 +186,9 @@ void task_system::async(priority_task ptsk) { std::unordered_map task_system::get_thread_ids() const { return thread_ids_; }; + +std::optional task_system::get_current_thread_id() const { + const auto it = thread_ids_.find(std::this_thread::get_id()); + if(it != thread_ids_.end()) return it->second; + return std::nullopt; +} diff --git a/arbor/threading/threading.hpp b/arbor/threading/threading.hpp index c690781e89..7e2e4a2509 100644 --- a/arbor/threading/threading.hpp +++ b/arbor/threading/threading.hpp @@ -8,11 +8,12 @@ #include #include #include +#include #include #include -#include #include #include +#include #include @@ -220,6 +221,9 @@ class ARB_ARBOR_API task_system { // Returns the thread_id map std::unordered_map get_thread_ids() const; + + // Returns the calling thread id if part of the task system + std::optional get_current_thread_id() const; }; class task_group { From 76d44baefaee9023b99b484eb4f59f61b2f69ee3 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 3 Sep 2023 13:34:16 +0200 Subject: [PATCH 52/84] move mapping of gid to local domain index into domain decomposition --- arbor/communication/communicator.cpp | 29 ++++++++++---------- arbor/communication/communicator.hpp | 1 - arbor/connection.hpp | 15 ++++++---- arbor/domain_decomposition.cpp | 21 ++++++++++---- arbor/include/arbor/domain_decomposition.hpp | 6 ++-- arbor/network_impl.cpp | 10 ++++--- 6 files changed, 49 insertions(+), 33 deletions(-) diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp index a4dc168e24..937c71df25 100644 --- a/arbor/communication/communicator.cpp +++ b/arbor/communication/communicator.cpp @@ -63,7 +63,6 @@ void communicator::update_connections(const recipe& rec, connections_.clear(); connection_part_.clear(); index_divisions_.clear(); - index_on_domain_.clear(); PL(); // Construct connections from high-level specification @@ -83,7 +82,6 @@ void communicator::update_connections(const recipe& rec, PE(init:communicator:update:collect_gids); std::vector gids; gids.reserve(num_local_cells_); for (const auto& g: dom_dec.groups()) util::append(gids, g.gids); - for (const auto index: util::make_span(gids.size())) index_on_domain_.insert({gids[index], index}); PL(); // Build the connection information for local cells. @@ -153,7 +151,11 @@ void communicator::update_connections(const recipe& rec, auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target); auto offset = offsets[*src_domain]++; ++src_domain; - connections_[offset] = {{src_gid, src_lid}, {tgt_gid, tgt_lid}, conn.weight, conn.delay}; + connections_[offset] = {{src_gid, src_lid}, + tgt_lid, + conn.weight, + conn.delay, + dom_dec.index_on_domain(tgt_gid)}; } for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) { const auto& conn = gid_ext_connections[cidx]; @@ -161,7 +163,8 @@ void communicator::update_connections(const recipe& rec, auto src_gid = conn.source.rid; if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid); auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target); - ext_connections_[ext] = {src, {tgt_gid, tgt_lid}, conn.weight, conn.delay}; + ext_connections_[ext] = { + src, tgt_lid, conn.weight, conn.delay, dom_dec.index_on_domain(tgt_gid)}; ++ext; } } @@ -247,10 +250,8 @@ void communicator::remote_ctrl_send_continue(const epoch& e) { ctx_->distributed void communicator::remote_ctrl_send_done() { ctx_->distributed->remote_ctrl_send_done(); } // Internal helper to append to the event queues -template -void append_events_from_domain(C cons, - S spks, const std::unordered_map& index_on_domain, - std::vector& queues) { +template +void append_events_from_domain(C cons, S spks, std::vector& queues) { // Predicate for partitioning struct spike_pred { bool operator()(const spike& spk, const cell_member_type& src) { return spk.source < src; } @@ -272,7 +273,7 @@ void append_events_from_domain(C cons, while (cn != ce && sp != se) { auto sources = std::equal_range(sp, se, cn->source, spike_pred()); for (auto s: util::make_range(sources)) { - queues[index_on_domain.at(cn->destination.gid)].push_back(make_event(*cn, s)); + queues[cn->index_on_domain].push_back(make_event(*cn, s)); } sp = sources.first; ++cn; @@ -282,7 +283,7 @@ void append_events_from_domain(C cons, while (cn != ce && sp != se) { auto targets = std::equal_range(cn, ce, sp->source); for (auto c: util::make_range(targets)) { - queues[index_on_domain.at(c.destination.gid)].push_back(make_event(c, *sp)); + queues[c.index_on_domain].push_back(make_event(c, *sp)); } cn = targets.first; ++sp; @@ -298,9 +299,9 @@ void communicator::make_event_queues( const auto& sp = global_spikes.partition(); const auto& cp = connection_part_; for (auto dom: util::make_span(num_domains_)) { - append_events_from_domain(util::subrange_view(connections_, cp[dom], cp[dom+1]), - util::subrange_view(global_spikes.values(), sp[dom], sp[dom+1]), index_on_domain_, - queues); + append_events_from_domain(util::subrange_view(connections_, cp[dom], cp[dom + 1]), + util::subrange_view(global_spikes.values(), sp[dom], sp[dom + 1]), + queues); } num_local_events_ = util::sum_by(queues, [](const auto& q) {return q.size();}, num_local_events_); // Now that all local spikes have been processed; consume the remote events coming in. @@ -309,7 +310,7 @@ void communicator::make_event_queues( std::for_each(spikes.begin(), spikes.end(), [](auto& s) { s.source = global_cell_of(s.source); }); - append_events_from_domain(ext_connections_, spikes, index_on_domain_, queues); + append_events_from_domain(ext_connections_, spikes, queues); } std::uint64_t communicator::num_spikes() const { diff --git a/arbor/communication/communicator.hpp b/arbor/communication/communicator.hpp index 200ecec527..4423b2f93c 100644 --- a/arbor/communication/communicator.hpp +++ b/arbor/communication/communicator.hpp @@ -102,7 +102,6 @@ class ARB_ARBOR_API communicator { std::vector connection_part_; std::vector index_divisions_; util::partition_view_type> index_part_; - std::unordered_map index_on_domain_; spike_predicate remote_spike_filter_; diff --git a/arbor/connection.hpp b/arbor/connection.hpp index 4063cb0bb0..e83fa28722 100644 --- a/arbor/connection.hpp +++ b/arbor/connection.hpp @@ -10,14 +10,15 @@ namespace arb { struct connection { cell_member_type source = {0, 0}; - cell_member_type destination = {0, 0}; - double weight = 0.0f; - double delay = 0.0f; + cell_lid_type destination = 0; + float weight = 0.0f; + float delay = 0.0f; + cell_size_type index_on_domain = cell_gid_type(-1); }; inline spike_event make_event(const connection& c, const spike& s) { - return {c.destination.index, s.time + c.delay, c.weight}; + return {c.destination, s.time + c.delay, c.weight}; } // connections are sorted by source id @@ -29,6 +30,8 @@ static inline bool operator<(cell_member_type lhs, const connection& rhs) { ret } // namespace arb static inline std::ostream& operator<<(std::ostream& o, arb::connection const& con) { - return o << "con [" << con.source << " -> " << con.destination << " : weight " << con.weight - << ", delay " << con.delay << "]"; + return o << "con [" << con.source << " -> " << con.destination + << " : weight " << con.weight + << ", delay " << con.delay + << ", index " << con.index_on_domain << "]"; } diff --git a/arbor/domain_decomposition.cpp b/arbor/domain_decomposition.cpp index aa22082120..eb90a7d5d5 100644 --- a/arbor/domain_decomposition.cpp +++ b/arbor/domain_decomposition.cpp @@ -1,11 +1,13 @@ #include #include #include +#include -#include +#include +#include #include +#include #include -#include #include "execution_context.hpp" #include "util/partition.hpp" @@ -22,15 +24,18 @@ domain_decomposition::domain_decomposition( partition_gid_domain(const gathered_vector& divs, unsigned domains) { auto rank_part = util::partition_view(divs.partition()); for (auto rank: count_along(rank_part)) { + cell_size_type index_on_domain = 0; for (auto gid: util::subrange_view(divs.values(), rank_part[rank])) { - gid_map[gid] = rank; + gid_map[gid] = {rank, index_on_domain}; + ++index_on_domain; } } } - int operator()(cell_gid_type gid) const { + std::pair operator()(cell_gid_type gid) const { return gid_map.at(gid); } - std::unordered_map gid_map; + // Maps gid to domain index and cell index on domain + std::unordered_map> gid_map; }; const auto* dist = ctx->distributed.get(); @@ -85,7 +90,11 @@ domain_decomposition::domain_decomposition( } int domain_decomposition::gid_domain(cell_gid_type gid) const { - return gid_domain_(gid); + return gid_domain_(gid).first; +} + +cell_size_type domain_decomposition::index_on_domain(cell_gid_type gid) const { + return gid_domain_(gid).second; } int domain_decomposition::num_domains() const { diff --git a/arbor/include/arbor/domain_decomposition.hpp b/arbor/include/arbor/domain_decomposition.hpp index 2706cd2935..54f61ed7ce 100644 --- a/arbor/include/arbor/domain_decomposition.hpp +++ b/arbor/include/arbor/domain_decomposition.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -45,6 +46,7 @@ class ARB_ARBOR_API domain_decomposition { domain_decomposition& operator=(const domain_decomposition&) = default; int gid_domain(cell_gid_type gid) const; + cell_size_type index_on_domain(cell_gid_type gid) const; int num_domains() const; int domain_id() const; cell_size_type num_local_cells() const; @@ -54,10 +56,10 @@ class ARB_ARBOR_API domain_decomposition { const group_description& group(unsigned) const; private: - /// Return the domain id of cell with gid. + /// Return the domain id and index on domain of cell with gid. /// Supplied by the load balancing algorithm that generates the domain /// decomposition. - std::function gid_domain_; + std::function(cell_gid_type)> gid_domain_; /// Number of distributed domains int num_domains_; diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 518c8ce5ee..4f27425203 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -110,15 +110,17 @@ struct site_mapping { } }; -void push_back(std::vector& vec, +void push_back(const domain_decomposition& dom_dec, + std::vector& vec, const network_full_site_info& src, const network_full_site_info& dest, double weight, double delay) { - vec.emplace_back(connection{{src.gid, src.lid}, {dest.gid, dest.lid}, weight, delay}); + vec.emplace_back(connection{{src.gid, src.lid}, dest.lid, (float)weight, (float)delay, dom_dec.index_on_domain(dest.gid)}); } -void push_back(std::vector& vec, +void push_back(const domain_decomposition&, + std::vector& vec, const network_full_site_info& src, const network_full_site_info& dest, double weight, @@ -325,7 +327,7 @@ std::vector generate_network_connections(const recipe& rec, const auto w = weight.get(src, dest); const auto d = delay.get(src, dest); - push_back(connections, src, dest, w, d); + push_back(dom_dec, connections, src, dest, w, d); } }; From 365f9e30b1cd8a7682cf26dcd484c9f266f253de Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 3 Sep 2023 13:50:49 +0200 Subject: [PATCH 53/84] fix communicator test --- test/unit-distributed/test_communicator.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp index 43c07e1833..c282c762e5 100644 --- a/test/unit-distributed/test_communicator.cpp +++ b/test/unit-distributed/test_communicator.cpp @@ -652,7 +652,8 @@ TEST(communicator, all2all) auto c = connections[i*n_local+j]; EXPECT_EQ(i, c.source.gid); EXPECT_EQ(0u, c.source.index); - EXPECT_EQ(i, c.destination.index); + EXPECT_EQ(i, c.destination); + EXPECT_LT(c.index_on_domain, n_local); } } @@ -695,7 +696,7 @@ TEST(communicator, mini_network) // sort connections by source then target auto connections = C.connections(); util::sort(connections, [](const connection& lhs, const connection& rhs) { - return std::forward_as_tuple(lhs.source, lhs.destination) < std::forward_as_tuple(rhs.source, rhs.destination); + return std::forward_as_tuple(lhs.source, lhs.index_on_domain, lhs.destination) < std::forward_as_tuple(rhs.source, rhs.index_on_domain, rhs.destination); }); // Expect one set of 22 connections from every rank: these have been sorted. @@ -709,7 +710,7 @@ TEST(communicator, mini_network) auto c = connections[i*22 + j]; EXPECT_EQ(ex_source_gids[j], c.source.gid); EXPECT_EQ(ex_source_lids[j], c.source.index); - EXPECT_EQ(ex_target_lids[i%2][j], c.destination.index); + EXPECT_EQ(ex_target_lids[i%2][j], c.destination); } } } From 11ae9544803f2f32c6c48184737e4980ea893adf Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Sun, 3 Sep 2023 14:02:04 +0200 Subject: [PATCH 54/84] rename destination to target --- arbor/connection.hpp | 6 +- arbor/include/arbor/network.hpp | 36 +- arbor/include/arbor/network_generation.hpp | 1 - arbor/network.cpp | 415 +++++++++--------- arbor/network_impl.cpp | 105 ++--- arbor/network_impl.hpp | 18 +- arborio/networkio.cpp | 45 +- doc/concepts/interconnectivity.rst | 42 +- doc/cpp/interconnectivity.rst | 50 +-- doc/python/interconnectivity.rst | 32 +- .../network_description.cpp | 81 ++-- python/example/network_description.py | 8 +- python/network.cpp | 48 +- test/unit-distributed/test_communicator.cpp | 10 +- .../test_network_generation.cpp | 58 +-- test/unit/test_network.cpp | 394 ++++++++--------- 16 files changed, 675 insertions(+), 674 deletions(-) diff --git a/arbor/connection.hpp b/arbor/connection.hpp index e83fa28722..276ed217aa 100644 --- a/arbor/connection.hpp +++ b/arbor/connection.hpp @@ -10,7 +10,7 @@ namespace arb { struct connection { cell_member_type source = {0, 0}; - cell_lid_type destination = 0; + cell_lid_type target = 0; float weight = 0.0f; float delay = 0.0f; cell_size_type index_on_domain = cell_gid_type(-1); @@ -18,7 +18,7 @@ struct connection { inline spike_event make_event(const connection& c, const spike& s) { - return {c.destination, s.time + c.delay, c.weight}; + return {c.target, s.time + c.delay, c.weight}; } // connections are sorted by source id @@ -30,7 +30,7 @@ static inline bool operator<(cell_member_type lhs, const connection& rhs) { ret } // namespace arb static inline std::ostream& operator<<(std::ostream& o, arb::connection const& con) { - return o << "con [" << con.source << " -> " << con.destination + return o << "con [" << con.source << " -> " << con.target << " : weight " << con.weight << ", delay " << con.delay << ", index " << con.index_on_domain << "]"; diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index d3cdb9b0b5..3ea02b40cd 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -39,13 +39,16 @@ ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_site_info, (b.gid, a.kind, b.label, b.location, b.global_location)) struct ARB_SYMBOL_VISIBLE network_connection_info { - network_site_info src, dest; + network_site_info source, target; double weight, delay; - ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_connection_info& s); + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, + const network_connection_info& s); }; -ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_connection_info, (a.src, a.dest), (b.src, b.dest)) +ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_connection_info, + (a.source, a.target), + (b.source, b.target)) struct network_selection_impl; @@ -70,7 +73,7 @@ class ARB_SYMBOL_VISIBLE network_value { // A named value inside a network label dictionary static network_value named(std::string name); - // Distamce netweem source and destination site + // Distamce netweem source and target site static network_value distance(double scale = 1.0); // Uniform random value in (range[0], range[1]]. @@ -168,14 +171,14 @@ class ARB_SYMBOL_VISIBLE network_selection { // Select connections with the given source cell kind static network_selection source_cell_kind(cell_kind kind); - // Select connections with the given destination cell kind - static network_selection destination_cell_kind(cell_kind kind); + // Select connections with the given target cell kind + static network_selection target_cell_kind(cell_kind kind); // Select connections with the given source label static network_selection source_label(std::vector labels); - // Select connections with the given destination label - static network_selection destination_label(std::vector labels); + // Select connections with the given target label + static network_selection target_label(std::vector labels); // Select connections with source cells matching the indices in the list static network_selection source_cell(std::vector gids); @@ -183,19 +186,22 @@ class ARB_SYMBOL_VISIBLE network_selection { // Select connections with source cells matching the indices in the range static network_selection source_cell(gid_range range); - // Select connections with destination cells matching the indices in the list - static network_selection destination_cell(std::vector gids); + // Select connections with target cells matching the indices in the list + static network_selection target_cell(std::vector gids); - // Select connections with destination cells matching the indices in the range - static network_selection destination_cell(gid_range range); + // Select connections with target cells matching the indices in the range + static network_selection target_cell(gid_range range); - // Select connections that form a chain, such that source cell "i" is connected to the destination cell "i+1" + // Select connections that form a chain, such that source cell "i" is connected to the target + // cell "i+1" static network_selection chain(std::vector gids); - // Select connections that form a chain, such that source cell "i" is connected to the destination cell "i+1" + // Select connections that form a chain, such that source cell "i" is connected to the target + // cell "i+1" static network_selection chain(gid_range range); - // Select connections that form a reversed chain, such that source cell "i+1" is connected to the destination cell "i" + // Select connections that form a reversed chain, such that source cell "i+1" is connected to + // the target cell "i" static network_selection chain_reverse(gid_range range); // Select connections, that are selected by both "left" and "right" diff --git a/arbor/include/arbor/network_generation.hpp b/arbor/include/arbor/network_generation.hpp index 515c7abbea..aa8e9e25ea 100644 --- a/arbor/include/arbor/network_generation.hpp +++ b/arbor/include/arbor/network_generation.hpp @@ -12,7 +12,6 @@ ARB_ARBOR_API std::vector generate_network_connections( const context& ctx, const domain_decomposition& dom_dec); - ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec); } // namespace arb diff --git a/arbor/network.cpp b/arbor/network.cpp index 1a55427ce0..a6e08af8cd 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -29,7 +29,6 @@ enum class network_seed : unsigned { value_truncated_normal = 380237, }; - double uniform_rand_from_key_pair(std::array seed, network_hash_type key_a, network_hash_type key_b) { @@ -54,8 +53,8 @@ double normal_rand_from_key_pair(std::array seed, } struct network_selection_all_impl: public network_selection_impl { - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { return true; } @@ -65,7 +64,7 @@ struct network_selection_all_impl: public network_selection_impl { return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -76,8 +75,8 @@ struct network_selection_all_impl: public network_selection_impl { struct network_selection_none_impl: public network_selection_impl { - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { return false; } @@ -87,7 +86,7 @@ struct network_selection_none_impl: public network_selection_impl { return false; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return false; @@ -101,9 +100,9 @@ struct network_selection_source_cell_kind_impl: public network_selection_impl { explicit network_selection_source_cell_kind_impl(cell_kind k): select_kind(k) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return src.kind == select_kind; + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return source.kind == select_kind; } bool select_source(cell_kind kind, @@ -112,7 +111,7 @@ struct network_selection_source_cell_kind_impl: public network_selection_impl { return kind == select_kind; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -130,14 +129,14 @@ struct network_selection_source_cell_kind_impl: public network_selection_impl { } }; -struct network_selection_destination_cell_kind_impl: public network_selection_impl { +struct network_selection_target_cell_kind_impl: public network_selection_impl { cell_kind select_kind; - explicit network_selection_destination_cell_kind_impl(cell_kind k): select_kind(k) {} + explicit network_selection_target_cell_kind_impl(cell_kind k): select_kind(k) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return dest.kind == select_kind; + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return target.kind == select_kind; } bool select_source(cell_kind kind, @@ -146,14 +145,14 @@ struct network_selection_destination_cell_kind_impl: public network_selection_im return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return kind == select_kind; } void print(std::ostream& os) const override { - os << "(destination-cell-kind ("; + os << "(target-cell-kind ("; switch (select_kind) { case arb::cell_kind::spike_source: os << "spike-source"; break; case arb::cell_kind::cable: os << "cable"; break; @@ -172,9 +171,9 @@ struct network_selection_source_label_impl: public network_selection_impl { std::sort(sorted_labels.begin(), sorted_labels.end()); } - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), src.label); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), source.label); } bool select_source(cell_kind kind, @@ -183,7 +182,7 @@ struct network_selection_source_label_impl: public network_selection_impl { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), label); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -196,17 +195,17 @@ struct network_selection_source_label_impl: public network_selection_impl { } }; -struct network_selection_destination_label_impl: public network_selection_impl { +struct network_selection_target_label_impl: public network_selection_impl { std::vector sorted_labels; - explicit network_selection_destination_label_impl(std::vector labels): + explicit network_selection_target_label_impl(std::vector labels): sorted_labels(std::move(labels)) { std::sort(sorted_labels.begin(), sorted_labels.end()); } - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), dest.label); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return std::binary_search(sorted_labels.begin(), sorted_labels.end(), target.label); } bool select_source(cell_kind kind, @@ -215,14 +214,14 @@ struct network_selection_destination_label_impl: public network_selection_impl { return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), label); } void print(std::ostream& os) const override { - os << "(destination-label"; + os << "(target-label"; for (const auto& l: sorted_labels) { os << " \"" << l << "\""; } os << ")"; } @@ -236,9 +235,9 @@ struct network_selection_source_cell_impl: public network_selection_impl { std::sort(sorted_gids.begin(), sorted_gids.end()); } - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return std::binary_search(sorted_gids.begin(), sorted_gids.end(), src.gid); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), source.gid); } bool select_source(cell_kind kind, @@ -247,7 +246,7 @@ struct network_selection_source_cell_impl: public network_selection_impl { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -268,9 +267,10 @@ struct network_selection_source_cell_range_impl: public network_selection_impl { gid_end(r.end), step(r.step) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return src.gid >= gid_begin && src.gid < gid_end && !((src.gid - gid_begin) % step); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return source.gid >= gid_begin && source.gid < gid_end && + !((source.gid - gid_begin) % step); } bool select_source(cell_kind kind, @@ -279,7 +279,7 @@ struct network_selection_source_cell_range_impl: public network_selection_impl { return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -290,17 +290,17 @@ struct network_selection_source_cell_range_impl: public network_selection_impl { } }; -struct network_selection_destination_cell_impl: public network_selection_impl { +struct network_selection_target_cell_impl: public network_selection_impl { std::vector sorted_gids; - network_selection_destination_cell_impl(std::vector gids): + network_selection_target_cell_impl(std::vector gids): sorted_gids(std::move(gids)) { std::sort(sorted_gids.begin(), sorted_gids.end()); } - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return std::binary_search(sorted_gids.begin(), sorted_gids.end(), dest.gid); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return std::binary_search(sorted_gids.begin(), sorted_gids.end(), target.gid); } bool select_source(cell_kind kind, @@ -309,30 +309,31 @@ struct network_selection_destination_cell_impl: public network_selection_impl { return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } void print(std::ostream& os) const override { - os << "(destination-cell"; + os << "(target-cell"; for (const auto& g: sorted_gids) { os << " " << g; } os << ")"; } }; -struct network_selection_destination_cell_range_impl: public network_selection_impl { +struct network_selection_target_cell_range_impl: public network_selection_impl { cell_gid_type gid_begin, gid_end, step; - network_selection_destination_cell_range_impl(gid_range r): + network_selection_target_cell_range_impl(gid_range r): gid_begin(r.begin), gid_end(r.end), step(r.step) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return dest.gid >= gid_begin && dest.gid < gid_end && !((dest.gid - gid_begin) % step); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return target.gid >= gid_begin && target.gid < gid_end && + !((target.gid - gid_begin) % step); } bool select_source(cell_kind kind, @@ -341,14 +342,14 @@ struct network_selection_destination_cell_range_impl: public network_selection_i return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); } void print(std::ostream& os) const override { - os << "(destination-cell (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; + os << "(target-cell (gid-range " << gid_begin << " " << gid_end << " " << step << "))"; } }; @@ -360,20 +361,20 @@ struct network_selection_chain_impl: public network_selection_impl { std::sort(sorted_gids.begin(), sorted_gids.end()); } - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { if (gids.empty()) return false; // gids size always > 0 frome here on // First check if both are part of ring - if (!std::binary_search(sorted_gids.begin(), sorted_gids.end(), src.gid) || - !std::binary_search(sorted_gids.begin(), sorted_gids.end(), dest.gid)) + if (!std::binary_search(sorted_gids.begin(), sorted_gids.end(), source.gid) || + !std::binary_search(sorted_gids.begin(), sorted_gids.end(), target.gid)) return false; for (std::size_t i = 0; i < gids.size() - 1; ++i) { // return true if neighbors in gids list - if ((src.gid == gids[i] && dest.gid == gids[i + 1])) return true; + if ((source.gid == gids[i] && target.gid == gids[i + 1])) return true; } return false; @@ -386,7 +387,7 @@ struct network_selection_chain_impl: public network_selection_impl { std::binary_search(sorted_gids.begin(), sorted_gids.end() - 1, gid); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return !sorted_gids.empty() && @@ -408,13 +409,13 @@ struct network_selection_chain_range_impl: public network_selection_impl { gid_end(r.end), step(r.step) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - if (src.gid < gid_begin || src.gid >= gid_end || dest.gid < gid_begin || - dest.gid >= gid_end) + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + if (source.gid < gid_begin || source.gid >= gid_end || target.gid < gid_begin || + target.gid >= gid_end) return false; - return src.gid + step == dest.gid && !((src.gid - gid_begin) % step); + return source.gid + step == target.gid && !((source.gid - gid_begin) % step); } bool select_source(cell_kind kind, @@ -425,10 +426,10 @@ struct network_selection_chain_range_impl: public network_selection_impl { return !((gid - gid_begin) % step); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - // Return false if outside range or if equal to first element, which cannot be a destination + // Return false if outside range or if equal to first element, which cannot be a target if (gid <= gid_begin || gid >= gid_end) return false; return !((gid - gid_begin) % step); } @@ -446,13 +447,13 @@ struct network_selection_reverse_chain_range_impl: public network_selection_impl gid_end(r.end), step(r.step) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - if (src.gid < gid_begin || src.gid >= gid_end || dest.gid < gid_begin || - dest.gid >= gid_end) + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + if (source.gid < gid_begin || source.gid >= gid_end || target.gid < gid_begin || + target.gid >= gid_end) return false; - return dest.gid + step == src.gid && !((src.gid - gid_begin) % step); + return target.gid + step == source.gid && !((source.gid - gid_begin) % step); } bool select_source(cell_kind kind, @@ -463,10 +464,10 @@ struct network_selection_reverse_chain_range_impl: public network_selection_impl return !((gid - gid_begin) % step); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - // Return false if outside range or if equal to last element, which cannot be a destination + // Return false if outside range or if equal to last element, which cannot be a target if (gid < gid_begin || gid >= gid_end - 1) return false; return !((gid - gid_begin) % step); } @@ -482,9 +483,9 @@ struct network_selection_complement_impl: public network_selection_impl { explicit network_selection_complement_impl(std::shared_ptr s): selection(std::move(s)) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return !selection->select_connection(src, dest); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return !selection->select_connection(source, target); } bool select_source(cell_kind kind, @@ -494,10 +495,10 @@ struct network_selection_complement_impl: public network_selection_impl { // knowing selection criteria. } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return true; // cannot exclude any because destination selection cannot be complemented + return true; // cannot exclude any because target selection cannot be complemented // without knowing selection criteria. } @@ -518,11 +519,11 @@ struct network_selection_named_impl: public network_selection_impl { explicit network_selection_named_impl(std::string name): selection_name(std::move(name)) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); - return selection->select_connection(src, dest); + return selection->select_connection(source, target); } bool select_source(cell_kind kind, @@ -533,12 +534,12 @@ struct network_selection_named_impl: public network_selection_impl { return selection->select_source(kind, gid, label); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); - return selection->select_destination(kind, gid, label); + return selection->select_target(kind, gid, label); } void initialize(const network_label_dict& dict) override { @@ -555,9 +556,9 @@ struct network_selection_named_impl: public network_selection_impl { }; struct network_selection_inter_cell_impl: public network_selection_impl { - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return src.gid != dest.gid; + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return source.gid != target.gid; } bool select_source(cell_kind kind, @@ -566,7 +567,7 @@ struct network_selection_inter_cell_impl: public network_selection_impl { return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -581,14 +582,18 @@ struct network_selection_custom_impl: public network_selection_impl { explicit network_selection_custom_impl(network_selection::custom_func_type f): func(std::move(f)) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return func({{src.gid, - src.kind, - cell_tag_type(src.label), - src.location, - src.global_location}, - {dest.gid, dest.kind, cell_tag_type(dest.label), dest.location, dest.global_location}}); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return func({{source.gid, + source.kind, + cell_tag_type(source.label), + source.location, + source.global_location}, + {target.gid, + target.kind, + cell_tag_type(target.label), + target.location, + target.global_location}}); } bool select_source(cell_kind kind, @@ -597,7 +602,7 @@ struct network_selection_custom_impl: public network_selection_impl { return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -611,9 +616,9 @@ struct network_selection_distance_lt_impl: public network_selection_impl { explicit network_selection_distance_lt_impl(double d): d(d) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return distance(src.global_location, dest.global_location) < d; + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return distance(source.global_location, target.global_location) < d; } bool select_source(cell_kind kind, @@ -622,7 +627,7 @@ struct network_selection_distance_lt_impl: public network_selection_impl { return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -638,9 +643,9 @@ struct network_selection_distance_gt_impl: public network_selection_impl { explicit network_selection_distance_gt_impl(double d): d(d) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return distance(src.global_location, dest.global_location) > d; + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return distance(source.global_location, target.global_location) > d; } bool select_source(cell_kind kind, @@ -649,7 +654,7 @@ struct network_selection_distance_gt_impl: public network_selection_impl { return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -662,17 +667,19 @@ struct network_selection_random_impl: public network_selection_impl { unsigned seed; network_value p_value; - std::shared_ptr probability; // may be null if unitialize(...) not called + std::shared_ptr probability; // may be null if unitialize(...) not called - network_selection_random_impl(unsigned seed, network_value p): seed(seed), p_value(std::move(p)) {} + network_selection_random_impl(unsigned seed, network_value p): + seed(seed), + p_value(std::move(p)) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { if (!probability) throw arbor_internal_error("Trying to use unitialized named network selection."); const auto r = uniform_rand_from_key_pair( - {unsigned(network_seed::selection_random), seed}, src.hash, dest.hash); - const auto p = (probability->get(src, dest)); + {unsigned(network_seed::selection_random), seed}, source.hash, target.hash); + const auto p = (probability->get(source, target)); return r < p; } @@ -682,7 +689,7 @@ struct network_selection_random_impl: public network_selection_impl { return true; } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { return true; @@ -707,9 +714,9 @@ struct network_selection_intersect_impl: public network_selection_impl { left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return left->select_connection(src, dest) && right->select_connection(src, dest); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return left->select_connection(source, target) && right->select_connection(source, target); } bool select_source(cell_kind kind, @@ -718,11 +725,10 @@ struct network_selection_intersect_impl: public network_selection_impl { return left->select_source(kind, gid, label) && right->select_source(kind, gid, label); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return left->select_destination(kind, gid, label) && - right->select_destination(kind, gid, label); + return left->select_target(kind, gid, label) && right->select_target(kind, gid, label); } std::optional max_distance() const override { @@ -758,9 +764,9 @@ struct network_selection_join_impl: public network_selection_impl { left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return left->select_connection(src, dest) || right->select_connection(src, dest); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return left->select_connection(source, target) || right->select_connection(source, target); } bool select_source(cell_kind kind, @@ -769,11 +775,10 @@ struct network_selection_join_impl: public network_selection_impl { return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return left->select_destination(kind, gid, label) || - right->select_destination(kind, gid, label); + return left->select_target(kind, gid, label) || right->select_target(kind, gid, label); } std::optional max_distance() const override { @@ -807,9 +812,9 @@ struct network_selection_symmetric_difference_impl: public network_selection_imp left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return left->select_connection(src, dest) ^ right->select_connection(src, dest); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return left->select_connection(source, target) ^ right->select_connection(source, target); } bool select_source(cell_kind kind, @@ -818,11 +823,10 @@ struct network_selection_symmetric_difference_impl: public network_selection_imp return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return left->select_destination(kind, gid, label) || - right->select_destination(kind, gid, label); + return left->select_target(kind, gid, label) || right->select_target(kind, gid, label); } std::optional max_distance() const override { @@ -856,9 +860,10 @@ struct network_selection_difference_impl: public network_selection_impl { left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const override { - return left->select_connection(src, dest) && !(right->select_connection(src, dest)); + bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const override { + return left->select_connection(source, target) && + !(right->select_connection(source, target)); } bool select_source(cell_kind kind, @@ -867,10 +872,10 @@ struct network_selection_difference_impl: public network_selection_impl { return left->select_source(kind, gid, label); } - bool select_destination(cell_kind kind, + bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& label) const override { - return left->select_destination(kind, gid, label); + return left->select_target(kind, gid, label); } std::optional max_distance() const override { @@ -900,21 +905,22 @@ struct network_value_scalar_impl: public network_value_impl { network_value_scalar_impl(double v): value(v) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { return value; } void print(std::ostream& os) const override { os << "(scalar " << value << ")"; } }; - struct network_value_distance_impl: public network_value_impl { double scale; network_value_distance_impl(double s): scale(s) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - return scale * distance(src.global_location, dest.global_location); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + return scale * distance(source.global_location, target.global_location); } void print(std::ostream& os) const override { os << "(distance " << scale << ")"; } @@ -931,12 +937,13 @@ struct network_value_uniform_distribution_impl: public network_value_impl { throw std::invalid_argument("Uniform distribution: invalid range"); } - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { if (range[0] > range[1]) return range[1]; // random number between 0 and 1 const auto rand_num = uniform_rand_from_key_pair( - {unsigned(network_seed::value_uniform), seed}, src.hash, dest.hash); + {unsigned(network_seed::value_uniform), seed}, source.hash, target.hash); return (range[1] - range[0]) * rand_num + range[0]; } @@ -956,10 +963,12 @@ struct network_value_normal_distribution_impl: public network_value_impl { mean(mean_), std_deviation(std_deviation_) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { return mean + std_deviation * - normal_rand_from_key_pair( - {unsigned(network_seed::value_normal), seed}, src.hash, dest.hash); + normal_rand_from_key_pair({unsigned(network_seed::value_normal), seed}, + source.hash, + target.hash); } void print(std::ostream& os) const override { @@ -985,10 +994,11 @@ struct network_value_truncated_normal_distribution_impl: public network_value_im throw std::invalid_argument("Truncated normal distribution: invalid range"); } - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { - const auto src_hash = src.hash; - auto dest_hash = dest.hash; + const auto src_hash = source.hash; + auto dest_hash = target.hash; double value = 0.0; @@ -1015,13 +1025,18 @@ struct network_value_custom_impl: public network_value_impl { network_value_custom_impl(network_value::custom_func_type f): func(std::move(f)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - return func({{src.gid, - src.kind, - cell_tag_type(src.label), - src.location, - src.global_location}, - {dest.gid, dest.kind, cell_tag_type(dest.label), dest.location, dest.global_location}}); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + return func({{source.gid, + source.kind, + cell_tag_type(source.label), + source.location, + source.global_location}, + {target.gid, + target.kind, + cell_tag_type(target.label), + target.location, + target.global_location}}); } void print(std::ostream& os) const override { os << "(custom-network-value)"; } @@ -1035,9 +1050,10 @@ struct network_value_named_impl: public network_value_impl { explicit network_value_named_impl(std::string name): value_name(std::move(name)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { if (!value) throw arbor_internal_error("Trying to use unitialized named network value."); - return value->get(src, dest); + return value->get(source, target); } void initialize(const network_label_dict& dict) override { @@ -1061,8 +1077,9 @@ struct network_value_add_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - return left->get(src, dest) + right->get(src, dest); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + return left->get(source, target) + right->get(source, target); } void initialize(const network_label_dict& dict) override { @@ -1079,7 +1096,6 @@ struct network_value_add_impl: public network_value_impl { } }; - struct network_value_mul_impl: public network_value_impl { std::shared_ptr left, right; @@ -1088,8 +1104,9 @@ struct network_value_mul_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - return left->get(src, dest) * right->get(src, dest); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + return left->get(source, target) * right->get(source, target); } void initialize(const network_label_dict& dict) override { @@ -1106,7 +1123,6 @@ struct network_value_mul_impl: public network_value_impl { } }; - struct network_value_sub_impl: public network_value_impl { std::shared_ptr left, right; @@ -1115,8 +1131,9 @@ struct network_value_sub_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - return left->get(src, dest) - right->get(src, dest); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + return left->get(source, target) - right->get(source, target); } void initialize(const network_label_dict& dict) override { @@ -1141,10 +1158,11 @@ struct network_value_div_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - const auto v_right = right ->get(src,dest); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + const auto v_right = right->get(source, target); if (!v_right) throw arbor_exception("network_value: division by 0."); - return left->get(src, dest) / right->get(src, dest); + return left->get(source, target) / right->get(source, target); } void initialize(const network_label_dict& dict) override { @@ -1169,8 +1187,9 @@ struct network_value_max_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - return std::max(left->get(src, dest), right->get(src, dest)); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + return std::max(left->get(source, target), right->get(source, target)); } void initialize(const network_label_dict& dict) override { @@ -1195,8 +1214,9 @@ struct network_value_min_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - return std::min(left->get(src, dest), right->get(src, dest)); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + return std::min(left->get(source, target), right->get(source, target)); } void initialize(const network_label_dict& dict) override { @@ -1216,16 +1236,14 @@ struct network_value_min_impl: public network_value_impl { struct network_value_exp_impl: public network_value_impl { std::shared_ptr value; - network_value_exp_impl(std::shared_ptr v): - value(std::move(v)) {} + network_value_exp_impl(std::shared_ptr v): value(std::move(v)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - return std::exp(value->get(src, dest)); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + return std::exp(value->get(source, target)); } - void initialize(const network_label_dict& dict) override { - value->initialize(dict); - }; + void initialize(const network_label_dict& dict) override { value->initialize(dict); }; void print(std::ostream& os) const override { os << "(exp "; @@ -1237,18 +1255,16 @@ struct network_value_exp_impl: public network_value_impl { struct network_value_log_impl: public network_value_impl { std::shared_ptr value; - network_value_log_impl(std::shared_ptr v): - value(std::move(v)) {} + network_value_log_impl(std::shared_ptr v): value(std::move(v)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - const auto v = value->get(src, dest); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + const auto v = value->get(source, target); if (v <= 0.0) throw arbor_exception("network_value: log of value <= 0.0."); - return std::log(value->get(src, dest)); + return std::log(value->get(source, target)); } - void initialize(const network_label_dict& dict) override { - value->initialize(dict); - }; + void initialize(const network_label_dict& dict) override { value->initialize(dict); }; void print(std::ostream& os) const override { os << "(log "; @@ -1269,9 +1285,10 @@ struct network_value_if_else_impl: public network_value_impl { true_value(std::move(true_value)), false_value(std::move(false_value)) {} - double get(const network_full_site_info& src, const network_full_site_info& dest) const override { - if (cond->select_connection(src, dest)) return true_value->get(src, dest); - return false_value->get(src, dest); + double get(const network_full_site_info& source, + const network_full_site_info& target) const override { + if (cond->select_connection(source, target)) return true_value->get(source, target); + return false_value->get(source, target); } void initialize(const network_label_dict& dict) override { @@ -1293,7 +1310,6 @@ struct network_value_if_else_impl: public network_value_impl { } // namespace - network_selection::network_selection(std::shared_ptr impl): impl_(std::move(impl)) {} @@ -1334,8 +1350,8 @@ network_selection network_selection::source_cell_kind(cell_kind kind) { return network_selection(std::make_shared(kind)); } -network_selection network_selection::destination_cell_kind(cell_kind kind) { - return network_selection(std::make_shared(kind)); +network_selection network_selection::target_cell_kind(cell_kind kind) { + return network_selection(std::make_shared(kind)); } network_selection network_selection::source_label(std::vector labels) { @@ -1343,9 +1359,9 @@ network_selection network_selection::source_label(std::vector lab std::make_shared(std::move(labels))); } -network_selection network_selection::destination_label(std::vector labels) { +network_selection network_selection::target_label(std::vector labels) { return network_selection( - std::make_shared(std::move(labels))); + std::make_shared(std::move(labels))); } network_selection network_selection::source_cell(std::vector gids) { @@ -1356,14 +1372,12 @@ network_selection network_selection::source_cell(gid_range range) { return network_selection(std::make_shared(range)); } -network_selection network_selection::destination_cell(std::vector gids) { - return network_selection( - std::make_shared(std::move(gids))); +network_selection network_selection::target_cell(std::vector gids) { + return network_selection(std::make_shared(std::move(gids))); } -network_selection network_selection::destination_cell(gid_range range) { - return network_selection( - std::make_shared(range)); +network_selection network_selection::target_cell(gid_range range) { + return network_selection(std::make_shared(range)); } network_selection network_selection::chain(std::vector gids) { @@ -1388,8 +1402,7 @@ network_selection network_selection::inter_cell() { } network_selection network_selection::random(unsigned seed, network_value p) { - return network_selection( - std::make_shared(seed, std::move(p))); + return network_selection(std::make_shared(seed, std::move(p))); } network_selection network_selection::custom(custom_func_type func) { @@ -1532,8 +1545,8 @@ ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_site_info ARB_ARBOR_API std::ostream& operator<<(std::ostream& os, const network_connection_info& s) { - os << ""; diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 4f27425203..168a47be15 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -89,9 +89,7 @@ struct site_mapping { } void insert(const site_mapping& m) { - for(std::size_t idx = 0; idx < m.size(); ++idx) { - this->insert(m.get_site(idx)); - } + for (std::size_t idx = 0; idx < m.size(); ++idx) { this->insert(m.get_site(idx)); } } network_full_site_info get_site(std::size_t idx) const { @@ -112,24 +110,33 @@ struct site_mapping { void push_back(const domain_decomposition& dom_dec, std::vector& vec, - const network_full_site_info& src, - const network_full_site_info& dest, + const network_full_site_info& source, + const network_full_site_info& target, double weight, double delay) { - vec.emplace_back(connection{{src.gid, src.lid}, dest.lid, (float)weight, (float)delay, dom_dec.index_on_domain(dest.gid)}); + vec.emplace_back(connection{{source.gid, source.lid}, + target.lid, + (float)weight, + (float)delay, + dom_dec.index_on_domain(target.gid)}); } void push_back(const domain_decomposition&, std::vector& vec, - const network_full_site_info& src, - const network_full_site_info& dest, + const network_full_site_info& source, + const network_full_site_info& target, double weight, double delay) { - vec.emplace_back(network_connection_info{ - network_site_info{ - src.gid, src.kind, std::string(src.label), src.location, src.global_location}, - network_site_info{ - dest.gid, dest.kind, std::string(dest.label), dest.location, dest.global_location}, + vec.emplace_back(network_connection_info{network_site_info{source.gid, + source.kind, + std::string(source.label), + source.location, + source.global_location}, + network_site_info{target.gid, + target.kind, + std::string(target.label), + target.location, + target.global_location}, weight, delay}); } @@ -166,7 +173,7 @@ std::vector generate_network_connections(const recipe& rec, for (const auto& [kind, gids]: gids_by_kind) { const auto batch_size = (gids.size() + num_batches - 1) / num_batches; - // populate network sites for source and destination + // populate network sites for source and target if (kind == cell_kind::cable) { const auto& cable_gids = gids; threading::parallel_for::apply( @@ -196,13 +203,13 @@ std::vector generate_network_connections(const recipe& rec, place_pwlin location_resolver(cell.morphology(), rec.get_cell_isometry(gid)); - // check all synapses of cell for potential destination + // check all synapses of cell for potential target for (const auto& [_, placed_synapses]: cell.synapses()) { for (const auto& p_syn: placed_synapses) { const auto& label = lid_to_label(cell.synapse_ranges(), p_syn.lid); - if (selection.select_destination(cell_kind::cable, gid, label)) { + if (selection.select_target(cell_kind::cable, gid, label)) { const mpoint point = location_resolver.at(p_syn.loc); dest_sites.insert( {gid, p_syn.lid, cell_kind::cable, label, p_syn.loc, point}); @@ -227,20 +234,20 @@ std::vector generate_network_connections(const recipe& rec, auto factory = cell_kind_implementation(kind, backend_kind::multicore, *ctx, 0); // We only need the label ranges - cell_label_range sources, destinations; - std::ignore = factory(gids, rec, sources, destinations); + cell_label_range sources, targets; + std::ignore = factory(gids, rec, sources, targets); auto& src_sites = src_site_batches[0]; auto& dest_sites = dest_site_batches[0]; std::size_t source_label_offset = 0; - std::size_t destination_label_offset = 0; + std::size_t target_label_offset = 0; for (std::size_t i = 0; i < gids.size(); ++i) { const auto gid = gids[i]; const auto iso = rec.get_cell_isometry(gid); const auto point = iso.apply(mpoint{0.0, 0.0, 0.0, 0.0}); const auto num_source_labels = sources.sizes().at(i); - const auto num_destination_labels = destinations.sizes().at(i); + const auto num_target_labels = targets.sizes().at(i); // Iterate over each source label for current gid for (std::size_t j = source_label_offset; @@ -255,28 +262,28 @@ std::vector generate_network_connections(const recipe& rec, } } - // Iterate over each destination label for current gid - for (std::size_t j = destination_label_offset; - j < destination_label_offset + num_destination_labels; + // Iterate over each target label for current gid + for (std::size_t j = target_label_offset; + j < target_label_offset + num_target_labels; ++j) { - const auto& label = destinations.labels().at(j); - const auto& range = destinations.ranges().at(j); + const auto& label = targets.labels().at(j); + const auto& range = targets.ranges().at(j); for (auto lid = range.begin; lid < range.end; ++lid) { - if (selection.select_destination(kind, gid, label)) { + if (selection.select_target(kind, gid, label)) { dest_sites.insert({gid, lid, kind, label, {0, 0.0}, point}); } } } source_label_offset += num_source_labels; - destination_label_offset += num_destination_labels; + target_label_offset += num_target_labels; } } } site_mapping& src_sites = src_site_batches.front(); - // combine src batches + // combine source batches for (std::size_t batch_idx = 1; batch_idx < src_site_batches.size(); ++batch_idx) { for (std::size_t i = 0; i < src_site_batches[batch_idx].size(); ++i) { @@ -313,33 +320,33 @@ std::vector generate_network_connections(const recipe& rec, const auto& s = source_range[i]; const auto batch_idx = ctx->thread_pool->get_current_thread_id().value(); auto& connections = connection_batches[batch_idx]; - network_full_site_info src; - src.gid = s.gid; - src.lid = s.lid; - src.kind = s.kind; - src.label = label_range.data() + s.label_start_idx; - src.location = s.location; - src.global_location = s.global_location; - src.hash = s.hash; - - auto sample = [&](const network_full_site_info& dest) { - if (selection.select_connection(src, dest)) { - const auto w = weight.get(src, dest); - const auto d = delay.get(src, dest); - - push_back(dom_dec, connections, src, dest, w, d); + network_full_site_info source; + source.gid = s.gid; + source.lid = s.lid; + source.kind = s.kind; + source.label = label_range.data() + s.label_start_idx; + source.location = s.location; + source.global_location = s.global_location; + source.hash = s.hash; + + auto sample = [&](const network_full_site_info& target) { + if (selection.select_connection(source, target)) { + const auto w = weight.get(source, target); + const auto d = delay.get(source, target); + + push_back(dom_dec, connections, source, target, w, d); } }; if (selection.max_distance().has_value()) { const double d = selection.max_distance().value(); local_dest_tree.bounding_box_for_each( - decltype(local_dest_tree)::point_type{src.global_location.x - d, - src.global_location.y - d, - src.global_location.z - d}, - decltype(local_dest_tree)::point_type{src.global_location.x + d, - src.global_location.y + d, - src.global_location.z + d}, + decltype(local_dest_tree)::point_type{source.global_location.x - d, + source.global_location.y - d, + source.global_location.z - d}, + decltype(local_dest_tree)::point_type{source.global_location.x + d, + source.global_location.y + d, + source.global_location.z + d}, sample); } else { local_dest_tree.for_each(sample); } diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp index a42a0de4a7..480ff7d3d8 100644 --- a/arbor/network_impl.hpp +++ b/arbor/network_impl.hpp @@ -1,9 +1,9 @@ #pragma once #include -#include -#include #include +#include +#include #include #include @@ -43,18 +43,18 @@ struct ARB_SYMBOL_VISIBLE network_full_site_info { struct network_selection_impl { virtual std::optional max_distance() const { return std::nullopt; } - virtual bool select_connection(const network_full_site_info& src, - const network_full_site_info& dest) const = 0; + virtual bool select_connection(const network_full_site_info& source, + const network_full_site_info& target) const = 0; virtual bool select_source(cell_kind kind, cell_gid_type gid, const std::string_view& tag) const = 0; - virtual bool select_destination(cell_kind kind, + virtual bool select_target(cell_kind kind, cell_gid_type gid, const std::string_view& tag) const = 0; - virtual void initialize(const network_label_dict& dict) {}; + virtual void initialize(const network_label_dict& dict){}; virtual void print(std::ostream& os) const = 0; @@ -67,11 +67,11 @@ inline std::shared_ptr thingify(network_selection s, return s.impl_; } - struct network_value_impl { - virtual double get(const network_full_site_info& src, const network_full_site_info& dest) const = 0; + virtual double get(const network_full_site_info& source, + const network_full_site_info& target) const = 0; - virtual void initialize(const network_label_dict& dict) {}; + virtual void initialize(const network_label_dict& dict){}; virtual void print(std::ostream& os) const = 0; diff --git a/arborio/networkio.cpp b/arborio/networkio.cpp index 136bd0b618..72ca1f172f 100644 --- a/arborio/networkio.cpp +++ b/arborio/networkio.cpp @@ -69,9 +69,9 @@ eval_map_type network_eval_map{ {"source-cell-kind", make_call(arb::network_selection::source_cell_kind, "all sources of cells matching given cell kind argument: (kind:cell-kind)")}, - {"destination-cell-kind", - make_call(arb::network_selection::destination_cell_kind, - "all destinations of cells matching given cell kind argument: (kind:cell-kind)")}, + {"target-cell-kind", + make_call(arb::network_selection::target_cell_kind, + "all targets of cells matching given cell kind argument: (kind:cell-kind)")}, {"source-label", make_arg_vec_call( [](const std::vector>& vec) { @@ -83,7 +83,7 @@ eval_map_type network_eval_map{ return arb::network_selection::source_label(std::move(labels)); }, "all sources in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"destination-label", + {"target-label", make_arg_vec_call( [](const std::vector>& vec) { std::vector labels; @@ -91,9 +91,9 @@ eval_map_type network_eval_map{ vec.begin(), vec.end(), std::back_inserter(labels), [](const auto& x) { return std::get(x); }); - return arb::network_selection::destination_label(std::move(labels)); + return arb::network_selection::target_label(std::move(labels)); }, - "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + "all targets in cell with gid in list: (gid:integer) [...(gid:integer)]")}, {"source-cell", make_arg_vec_call( [](const std::vector>& vec) { @@ -108,20 +108,20 @@ eval_map_type network_eval_map{ make_call(static_cast( arb::network_selection::source_cell), "all sources in cell with gid range: (range:gid-range)")}, - {"destination-cell", + {"target-cell", make_arg_vec_call( [](const std::vector>& vec) { std::vector gids; std::transform(vec.begin(), vec.end(), std::back_inserter(gids), [](const auto& x) { return std::get(x); }); - return arb::network_selection::destination_cell(std::move(gids)); + return arb::network_selection::target_cell(std::move(gids)); }, - "all destinations in cell with gid in list: (gid:integer) [...(gid:integer)]")}, - {"destination-cell", + "all targets in cell with gid in list: (gid:integer) [...(gid:integer)]")}, + {"target-cell", make_call(static_cast( - arb::network_selection::destination_cell), - "all destinations in cell with gid range: " + arb::network_selection::target_cell), + "all targets in cell with gid range: " "(range:gid-range)")}, {"chain", make_arg_vec_call( @@ -133,17 +133,17 @@ eval_map_type network_eval_map{ return arb::network_selection::chain(std::move(gids)); }, "A chain of connections in the given order of gids in list, such that entry \"i\" is " - "the source and entry \"i+1\" the destination: (gid:integer) [...(gid:integer)]")}, + "the source and entry \"i+1\" the target: (gid:integer) [...(gid:integer)]")}, {"chain", make_call( static_cast(arb::network_selection::chain), "A chain of connections for all gids in range [begin, end) with given step size. Each " - "entry \"i\" is connected as source to the destination \"i+1\": (begin:integer) " + "entry \"i\" is connected as source to the target \"i+1\": (begin:integer) " "(end:integer) (step:integer)")}, {"chain-reverse", make_call(arb::network_selection::chain_reverse, "A chain of connections for all gids in range [begin, end) with given step size. Each " - "entry \"i+1\" is connected as source to the destination \"i\". This results in " + "entry \"i+1\" is connected as source to the target \"i\". This results in " "connection directions in reverse compared to the (chain-range ...) selection: " "(begin:integer) " "(end:integer) (step:integer)")}, @@ -158,11 +158,11 @@ eval_map_type network_eval_map{ "p:network-value)")}, {"distance-lt", make_call(arb::network_selection::distance_lt, - "Select if distance between source and destination is less than given distance in " + "Select if distance between source and target is less than given distance in " "micro meter: (distance:real)")}, {"distance-gt", make_call(arb::network_selection::distance_gt, - "Select if distance between source and destination is greater than given distance in " + "Select if distance between source and target is greater than given distance in " "micro meter: (distance:real)")}, // network_value @@ -174,11 +174,11 @@ eval_map_type network_eval_map{ "A named network value with 1 argument: (value:string)")}, {"distance", make_call(arb::network_value::distance, - "Distance between source and destination scaled by given value with unit [1/um]. 1 " + "Distance between source and target scaled by given value with unit [1/um]. 1 " "argument: (scale:real)")}, {"distance", make_call<>([]() { return arb::network_value::distance(1.0); }, - "Distance between source and destination scaled by 1.0 with unit [1/um].")}, + "Distance between source and target scaled by 1.0 with unit [1/um].")}, {"uniform-distribution", make_call( [](unsigned seed, double begin, double end) { @@ -238,9 +238,12 @@ eval_map_type network_eval_map{ {"log", make_call(arb::network_value::log, "Logarithm. 1 argument: (value:real)")}, {"log", make_call(arb::network_value::log, "Logarithm. 1 argument: (value:real)")}, - {"exp", make_call(arb::network_value::exp, "Exponential function. 1 argument: (value:real)")}, {"exp", - make_call(arb::network_value::exp, "Exponential function. 1 argument: (value:real)")}, + make_call(arb::network_value::exp, + "Exponential function. 1 argument: (value:real)")}, + {"exp", + make_call(arb::network_value::exp, + "Exponential function. 1 argument: (value:real)")}, }; parse_network_hopefully eval(const s_expr& e, const eval_map_type& map); diff --git a/doc/concepts/interconnectivity.rst b/doc/concepts/interconnectivity.rst index 1855e9f701..ee78cc3b39 100644 --- a/doc/concepts/interconnectivity.rst +++ b/doc/concepts/interconnectivity.rst @@ -23,7 +23,7 @@ The recipe callbacks are interrogated during simulation creation. High Level Network Description ------------------------------ -As an alternative to providing a list of connections for each cell in the :ref:`recipe `, arbor supports high-level description of a cell network. It is based around a ``network_selection`` type, that represents a selection from the set of all possible connections between cells. A selection can be created based on different criteria, such as source or destination label, cell indices and also distance between source and destination. Selections can then be combined with other selections through set algebra like expressions. For distance calculations, the location of each connection point on the cell is resolved through the morphology combined with a cell isometry, which describes translation and rotation of the cell. +As an alternative to providing a list of connections for each cell in the :ref:`recipe `, arbor supports high-level description of a cell network. It is based around a ``network_selection`` type, that represents a selection from the set of all possible connections between cells. A selection can be created based on different criteria, such as source or target label, cell indices and also distance between source and target. Selections can then be combined with other selections through set algebra like expressions. For distance calculations, the location of each connection point on the cell is resolved through the morphology combined with a cell isometry, which describes translation and rotation of the cell. Each connection also requires a weight and delay value. For this purpose, a ``network_value`` type is available, that allows to mathematically describe the value calculation using common math functions, as well random distributions. The following example shows the relevant recipe functions, where cells are connected into a ring with additional random connections between them: @@ -36,7 +36,7 @@ The following example shows the relevant recipe functions, where cells are conne # create a chain chain = f"(chain (gid-range 0 {self.ncells}))" # connect front and back of chain to form ring - ring = f"(join {chain} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" + ring = f"(join {chain} (intersect (source-cell {self.ncells - 1}) (target-cell 0)))" # Create random connections with probability inversely proportional to the distance within a # radius @@ -46,8 +46,8 @@ The following example shows the relevant recipe functions, where cells are conne # combine ring with random selection s = f"(join {ring} {rand})" - # restrict to inter-cell connections and certain source / destination labels - s = f'(intersect {s} (inter-cell) (source-label "detector") (destination-label "syn"))' + # restrict to inter-cell connections and certain source / target labels + s = f'(intersect {s} (inter-cell) (source-label "detector") (target-label "syn"))' # fixed weight for connections in ring w_ring = "(scalar 0.01)" @@ -70,7 +70,7 @@ The following example shows the relevant recipe functions, where cells are conne return arbor.isometry.translate(radius * math.cos(angle), radius * math.sin(angle), 0) -The export function ``generate_network_connections`` allows the inspection of generated connections. The exported connections include the cell index, local label and location of both source and destination. +The export function ``generate_network_connections`` allows the inspection of generated connections. The exported connections include the cell index, local label and location of both source and target. .. note:: @@ -79,7 +79,7 @@ The export function ``generate_network_connections`` allows the inspection of ge .. note:: - A high-level description may be used together with providing explicit connection lists for each cell, but it is up to the user to avoid multiple connections between the same source and destination. + A high-level description may be used together with providing explicit connection lists for each cell, but it is up to the user to avoid multiple connections between the same source and target. .. warning:: @@ -155,17 +155,17 @@ Network Selection Expressions All connections, where the source cell is of the given type. -.. label:: (destination-cell-kind kind:cell-kind) +.. label:: (target-cell-kind kind:cell-kind) - All connections, where the destination cell is of the given type. + All connections, where the target cell is of the given type. .. label:: (source-label label:string) All connections, where the source label matches the given label. -.. label:: (destination-label label:string) +.. label:: (target-label label:string) - All connections, where the destination label matches the given label. + All connections, where the target label matches the given label. .. label:: (source-cell integer [...integer]) @@ -175,25 +175,25 @@ Network Selection Expressions All connections, where the source cell index is contained in the given gid-range. -.. label:: (destination-cell integer [...integer]) +.. label:: (target-cell integer [...integer]) - All connections, where the destination cell index matches one of the given integer values. + All connections, where the target cell index matches one of the given integer values. -.. label:: (destination-cell range:gid-range) +.. label:: (target-cell range:gid-range) - All connections, where the destination cell index is contained in the given gid-range. + All connections, where the target cell index is contained in the given gid-range. .. label:: (chain integer [...integer]) - A chain of connections between cells in the given order of in the list, such that entry "i" is the source and entry "i+1" the destination. + A chain of connections between cells in the given order of in the list, such that entry "i" is the source and entry "i+1" the target. .. label:: (chain range:gid-range) - A chain of connections between cells in the given order of the gid-range, such that entry "i" is the source and entry "i+1" the destination. + A chain of connections between cells in the given order of the gid-range, such that entry "i" is the source and entry "i+1" the target. .. label:: (chain-reverse range:gid-range) - A chain of connections between cells in reverse of the given order of the gid-range, such that entry "i+1" is the source and entry "i" the destination. + A chain of connections between cells in reverse of the given order of the gid-range, such that entry "i+1" is the source and entry "i" the target. .. label:: (random p:real) @@ -205,11 +205,11 @@ Network Selection Expressions .. label:: (distance-lt dist:real) - All connections, where the distance between source and destination is less than the given value in micro meter. + All connections, where the distance between source and target is less than the given value in micro meter. .. label:: (distance-gt dist:real) - All connections, where the distance between source and destination is greater than the given value in micro meter. + All connections, where the distance between source and target is greater than the given value in micro meter. .. _interconnectivity-value-expressions: @@ -227,11 +227,11 @@ Network Value Expressions .. label:: (distance) - The distance between source and destination. + The distance between source and target. .. label:: (distance value:real) - The distance between source and destination scaled by the given value. + The distance between source and target scaled by the given value. .. label:: (uniform-distribution seed:integer begin:real end:real) diff --git a/doc/cpp/interconnectivity.rst b/doc/cpp/interconnectivity.rst index 6fd34c4e35..8e034ed89c 100644 --- a/doc/cpp/interconnectivity.rst +++ b/doc/cpp/interconnectivity.rst @@ -8,11 +8,11 @@ Interconnectivity .. cpp:class:: cell_connection Describes a connection between two cells: a pre-synaptic source and a - post-synaptic destination. The source is typically a threshold detector on - a cell or a spike source. The destination is a synapse on the post-synaptic cell. + post-synaptic target. The source is typically a threshold detector on + a cell or a spike source. The target is a synapse on the post-synaptic cell. - The :cpp:member:`dest` does not include the gid of a cell, this is because a - :cpp:class:`cell_connection` is bound to the destination cell which means that the gid + The :cpp:member:`target` does not include the gid of a cell, this is because a + :cpp:class:`cell_connection` is bound to the target cell which means that the gid is implicitly known. .. cpp:member:: cell_global_label_type source @@ -20,7 +20,7 @@ Interconnectivity Source end point, represented by a :cpp:type:`cell_global_label_type` which packages a cell gid, label of a group of sources on the cell, and source selection policy. - .. cpp:member:: cell_local_label_type dest + .. cpp:member:: cell_local_label_type target Destination end point on the cell, represented by a :cpp:type:`cell_local_label_type` which packages a label of a group of targets on the cell and a selection policy. @@ -41,11 +41,11 @@ Interconnectivity .. cpp:class:: ext_cell_connection Describes a connection between two cells: a pre-synaptic source and a - post-synaptic destination. The source is typically a threshold detector on - a cell or a spike source. The destination is a synapse on the post-synaptic cell. + post-synaptic target. The source is typically a threshold detector on + a cell or a spike source. The target is a synapse on the post-synaptic cell. - The :cpp:member:`dest` does not include the gid of a cell, this is because a - :cpp:class:`ext_cell_connection` is bound to the destination cell which means that the gid + The :cpp:member:`target` does not include the gid of a cell, this is because a + :cpp:class:`ext_cell_connection` is bound to the target cell which means that the gid is implicitly known. .. cpp:member:: cell_remote_label_type source @@ -53,7 +53,7 @@ Interconnectivity Source end point, represented by a :cpp:type:`cell_remote_label_type` which packages a cell gid, integral tag of a group of sources on the cell, and source selection policy. - .. cpp:member:: cell_local_label_type dest + .. cpp:member:: cell_local_label_type target Destination end point on the cell, represented by a :cpp:type:`cell_local_label_type` which packages a label of a group of targets on the cell and a selection policy. @@ -129,13 +129,13 @@ Interconnectivity A network connection between cells. Used for generated connections through the high-level network description. - .. cpp:member:: network_site_info src + .. cpp:member:: network_site_info source The source connection site. - .. cpp:member:: network_site_info dest + .. cpp:member:: network_site_info target - The destination connection site. + The target connection site. .. cpp:class:: network_value @@ -152,7 +152,7 @@ Interconnectivity .. cpp:function:: network_value distance() - The value representing the distance between source and destination. + The value representing the distance between source and target. .. cpp:function:: network_value uniform_distribution(unsigned seed, const std::array& range) @@ -231,17 +231,17 @@ Interconnectivity Select connections with the given source cell kind - .. cpp:function:: network_selection destination_cell_kind(cell_kind kind); + .. cpp:function:: network_selection target_cell_kind(cell_kind kind); - Select connections with the given destination cell kind + Select connections with the given target cell kind .. cpp:function:: network_selection source_label(std::vector labels); Select connections with the given source label - .. cpp:function:: network_selection destination_label(std::vector labels); + .. cpp:function:: network_selection target_label(std::vector labels); - Select connections with the given destination label + Select connections with the given target label .. cpp:function:: network_selection source_cell(std::vector gids); @@ -251,25 +251,25 @@ Interconnectivity Select connections with source cells matching the indices in the range - .. cpp:function:: network_selection destination_cell(std::vector gids); + .. cpp:function:: network_selection target_cell(std::vector gids); - Select connections with destination cells matching the indices in the list + Select connections with target cells matching the indices in the list - .. cpp:function:: network_selection destination_cell(gid_range range); + .. cpp:function:: network_selection target_cell(gid_range range); - Select connections with destination cells matching the indices in the range + Select connections with target cells matching the indices in the range .. cpp:function:: network_selection chain(std::vector gids); - Select connections that form a chain, such that source cell "i" is connected to the destination cell "i+1" + Select connections that form a chain, such that source cell "i" is connected to the target cell "i+1" .. cpp:function:: network_selection chain(gid_range range); - Select connections that form a chain, such that source cell "i" is connected to the destination cell "i+1" + Select connections that form a chain, such that source cell "i" is connected to the target cell "i+1" .. cpp:function:: network_selection chain_reverse(gid_range range); - Select connections that form a reversed chain, such that source cell "i+1" is connected to the destination cell "i" + Select connections that form a reversed chain, such that source cell "i+1" is connected to the target cell "i" .. cpp:function:: network_selection intersect(network_selection left, network_selection right); diff --git a/doc/python/interconnectivity.rst b/doc/python/interconnectivity.rst index a65b68d4cd..62d1299b24 100644 --- a/doc/python/interconnectivity.rst +++ b/doc/python/interconnectivity.rst @@ -7,15 +7,15 @@ Interconnectivity .. class:: connection - Describes a connection between two cells, defined by source and destination end points (that is pre-synaptic and + Describes a connection between two cells, defined by source and target end points (that is pre-synaptic and post-synaptic respectively), a connection weight and a delay time. - The :attr:`dest` does not include the gid of a cell, this is because a :class:`arbor.connection` is bound to the - destination cell which means that the gid is implicitly known. + The :attr:`target` does not include the gid of a cell, this is because a :class:`arbor.connection` is bound to the + target cell which means that the gid is implicitly known. - .. function:: connection(source, destination, weight, delay) + .. function:: connection(source, target, weight, delay) - Construct a connection between the :attr:`source` and the :attr:`dest` with a :attr:`weight` and :attr:`delay`. + Construct a connection between the :attr:`source` and the :attr:`target` with a :attr:`weight` and :attr:`delay`. .. attribute:: source @@ -23,10 +23,10 @@ Interconnectivity (gid, label) or a (gid, (label, policy)) tuple. If the policy is not indicated, the default :attr:`arbor.selection_policy.univalent` is used). - .. attribute:: dest + .. attribute:: target - The destination end point of the connection (type: :class:`arbor.cell_local_label` representing the label of the - destination on the cell, which can be initialized with just a label, in which case the default + The target end point of the connection (type: :class:`arbor.cell_local_label` representing the label of the + target on the cell, which can be initialized with just a label, in which case the default :attr:`arbor.selection_policy.univalent` is used, or a (label, policy) tuple). The gid of the cell is implicitly known. @@ -63,18 +63,18 @@ Interconnectivity def connections_on(gid): # construct a connection from the "detector" source label on cell 2 # to the "syn" target label on cell gid with weight 0.01 and delay of 10 ms. - src = (2, "detector") # gid and locset label of the source - dest = "syn" # gid of the destination is determined by the argument to `connections_on`. + source = (2, "detector") # gid and locset label of the source + target = "syn" # gid of the target is determined by the argument to `connections_on`. w = 0.01 # weight of the connection. Correspondes to 0.01 μS on expsyn mechanisms d = 10 # delay in ms - return [arbor.connection(src, dest, w, d)] + return [arbor.connection(source, target, w, d)] .. class:: gap_junction_connection Describes a gap junction between two gap junction sites. The :attr:`local` site does not include the gid of a cell, this is because a :class:`arbor.gap_junction_connection` - is bound to the destination cell which means that the gid is implicitly known. + is bound to the target cell which means that the gid is implicitly known. .. note:: @@ -96,7 +96,7 @@ Interconnectivity .. attribute:: local The gap junction site: the local half of the gap junction connection (type: :class:`arbor.cell_local_label` - representing the label of the destination on the cell, which can be initialized with just a label, in which case + representing the label of the target on the cell, which can be initialized with just a label, in which case the default :attr:`arbor.selection_policy.univalent` is used, or a (label, policy) tuple). The gid of the cell is implicitly known. @@ -143,13 +143,13 @@ Interconnectivity A network connection between cells. Used for generated connections through the high-level network description. - .. attribute:: src + .. attribute:: source The source connection site. - .. attribute:: dest + .. attribute:: target - The destination connection site. + The target connection site. .. class:: network_description diff --git a/example/network_description/network_description.cpp b/example/network_description/network_description.cpp index 20f923fa61..d2344b3292 100644 --- a/example/network_description/network_description.cpp +++ b/example/network_description/network_description.cpp @@ -23,14 +23,14 @@ #include #include #include +#include +#include #include #include #include #include #include #include -#include -#include #include #include @@ -58,10 +58,10 @@ struct ring_params { ring_params read_options(int argc, char** argv); using arb::cell_gid_type; +using arb::cell_kind; using arb::cell_lid_type; -using arb::cell_size_type; using arb::cell_member_type; -using arb::cell_kind; +using arb::cell_size_type; using arb::time_type; // Writes voltage trace as a json file. @@ -75,22 +75,17 @@ class ring_recipe: public arb::recipe { ring_recipe(unsigned num_cells, cell_parameters params, unsigned min_delay): num_cells_(num_cells), cell_params_(params), - min_delay_(min_delay) - { + min_delay_(min_delay) { gprop_.default_parameters = arb::neuron_parameter_defaults; } - cell_size_type num_cells() const override { - return num_cells_; - } + cell_size_type num_cells() const override { return num_cells_; } arb::util::unique_any get_cell_description(cell_gid_type gid) const override { return branch_cell(gid, cell_params_); } - cell_kind get_cell_kind(cell_gid_type gid) const override { - return cell_kind::cable; - } + cell_kind get_cell_kind(cell_gid_type gid) const override { return cell_kind::cable; } arb::isometry get_cell_isometry(cell_gid_type gid) const override { // place cells with equal distance on a circle @@ -105,10 +100,10 @@ class ring_recipe: public arb::recipe { // connect front and back of chain to form ring ring = arb::join(ring, arb::intersect(arb::network_selection::source_cell({num_cells_ - 1}), - arb::network_selection::destination_cell({0}))); + arb::network_selection::target_cell({0}))); - // Create random connections with probability inversely proportional to the distance within a - // radius + // Create random connections with probability inversely proportional to the distance within + // a radius const double max_dist = 400.0; auto probability = (max_dist - arb::network_value::distance()) / max_dist; @@ -121,10 +116,10 @@ class ring_recipe: public arb::recipe { // combine ring with random selection auto s = join(ring, rand); - // restrict to certain source and destination labels + // restrict to certain source and target labels s = arb::intersect(s, arb::network_selection::source_label({"detector"}), - arb::network_selection::destination_label({"primary_syn"})); + arb::network_selection::target_label({"primary_syn"})); // random normal distributed weight with mean 0.05 μS, standard deviation 0.02 μS // and truncated to [0.025, 0.075] @@ -146,7 +141,8 @@ class ring_recipe: public arb::recipe { std::vector event_generators(cell_gid_type gid) const override { std::vector gens; if (!gid) { - gens.push_back(arb::explicit_generator({"primary_syn"}, event_weight_, std::vector{1.0f})); + gens.push_back( + arb::explicit_generator({"primary_syn"}, event_weight_, std::vector{1.0f})); } return gens; } @@ -157,11 +153,7 @@ class ring_recipe: public arb::recipe { return {arb::cable_probe_membrane_voltage{loc}}; } - std::any get_global_properties(arb::cell_kind) const override { - return gprop_; - } - - + std::any get_global_properties(arb::cell_kind) const override { return gprop_; } private: cell_size_type num_cells_; @@ -195,9 +187,9 @@ int main(int argc, char** argv) { std::cout << sup::mask_stream(root); // Print a banner with information about hardware configuration - std::cout << "gpu: " << (has_gpu(context)? "yes": "no") << "\n"; + std::cout << "gpu: " << (has_gpu(context) ? "yes" : "no") << "\n"; std::cout << "threads: " << num_threads(context) << "\n"; - std::cout << "mpi: " << (has_mpi(context)? "yes": "no") << "\n"; + std::cout << "mpi: " << (has_mpi(context) ? "yes" : "no") << "\n"; std::cout << "ranks: " << num_ranks(context) << "\n" << std::endl; auto params = read_options(argc, argv); @@ -234,23 +226,20 @@ int main(int argc, char** argv) { meters.checkpoint("model-init", context); - if (root) { - sim.set_epoch_callback(arb::epoch_progress_bar()); - } + if (root) { sim.set_epoch_callback(arb::epoch_progress_bar()); } std::cout << "running simulation\n" << std::endl; // Run the simulation for 100 ms, with time steps of 0.025 ms. sim.run(params.duration, 0.025); meters.checkpoint("model-run", context); - // Print generated connections if (root) { const auto connections = arb::generate_network_connections(recipe); std::cout << "Connections:" << std::endl; - for(const auto& c: connections) { - std::cout << "(" << c.src.gid << ", \"" << c.src.label << "\") ->"; - std::cout << "(" << c.dest.gid << ", \"" << c.dest.label << "\")" << std::endl; + for (const auto& c: connections) { + std::cout << "(" << c.source.gid << ", \"" << c.source.label << "\") ->"; + std::cout << "(" << c.target.gid << ", \"" << c.target.label << "\")" << std::endl; } } @@ -258,8 +247,9 @@ int main(int argc, char** argv) { // Write spikes to file if (root) { - std::cout << "\n" << ns << " spikes generated at rate of " - << params.duration/ns << " ms between spikes\n"; + std::cout << "\n" + << ns << " spikes generated at rate of " << params.duration / ns + << " ms between spikes\n"; std::ofstream fid("spikes.gdf"); if (!fid.good()) { std::cerr << "Warning: unable to open file spikes.gdf for spike output\n"; @@ -267,18 +257,18 @@ int main(int argc, char** argv) { else { char linebuf[45]; for (auto spike: recorded_spikes) { - auto n = std::snprintf( - linebuf, sizeof(linebuf), "%u %.4f\n", - unsigned{spike.source.gid}, float(spike.time)); + auto n = std::snprintf(linebuf, + sizeof(linebuf), + "%u %.4f\n", + unsigned{spike.source.gid}, + float(spike.time)); fid.write(linebuf, n); } } } // Write the samples to a json file. - if (root) { - write_trace_json(voltage.at(0)); - } + if (root) { write_trace_json(voltage.at(0)); } auto profile = arb::profile::profiler_summary(); std::cout << profile << "\n"; @@ -319,11 +309,11 @@ ring_params read_options(int argc, char** argv) { using sup::param_from_json; ring_params params; - if (argc<2) { + if (argc < 2) { std::cout << "Using default parameters.\n"; return params; } - if (argc>2) { + if (argc > 2) { throw std::runtime_error("More than one command line option is not permitted."); } @@ -331,9 +321,7 @@ ring_params read_options(int argc, char** argv) { std::cout << "Loading parameters from file: " << fname << "\n"; std::ifstream f(fname); - if (!f.good()) { - throw std::runtime_error("Unable to open input parameter file: "+fname); - } + if (!f.good()) { throw std::runtime_error("Unable to open input parameter file: " + fname); } nlohmann::json json; f >> json; @@ -345,7 +333,7 @@ ring_params read_options(int argc, char** argv) { params.cell = parse_cell_parameters(json); if (!json.empty()) { - for (auto it=json.begin(); it!=json.end(); ++it) { + for (auto it = json.begin(); it != json.end(); ++it) { std::cout << " Warning: unused input parameter: \"" << it.key() << "\"\n"; } std::cout << "\n"; @@ -353,4 +341,3 @@ ring_params read_options(int argc, char** argv) { return params; } - diff --git a/python/example/network_description.py b/python/example/network_description.py index bfc242c5de..09a855df45 100755 --- a/python/example/network_description.py +++ b/python/example/network_description.py @@ -109,7 +109,7 @@ def network_description(self): # create a chain chain = f"(chain (gid-range 0 {self.ncells}))" # connect front and back of chain to form ring - ring = f"(join {chain} (intersect (source-cell {self.ncells - 1}) (destination-cell 0)))" + ring = f"(join {chain} (intersect (source-cell {self.ncells - 1}) (target-cell 0)))" # Create random connections with probability inversely proportional to the distance within a # radius @@ -119,8 +119,8 @@ def network_description(self): # combine ring with random selection s = f"(join {ring} {rand})" - # restrict to inter-cell connections and certain source / destination labels - s = f'(intersect {s} (inter-cell) (source-label "detector") (destination-label "syn"))' + # restrict to inter-cell connections and certain source / target labels + s = f'(intersect {s} (inter-cell) (source-label "detector") (target-label "syn"))' # fixed weight for connections in ring w_ring = "(scalar 0.01)" @@ -169,7 +169,7 @@ def global_properties(self, kind): print("connections:") for c in connections: - print(f'({c.src.gid}, "{c.src.label}") -> ({c.dest.gid}, "{c.dest.label}")') + print(f'({c.source.gid}, "{c.source.label}") -> ({c.target.gid}, "{c.target.label}")') # (16) Run simulation for 100 ms sim.run(100) diff --git a/python/network.cpp b/python/network.cpp index 90780e6ec4..c267211535 100644 --- a/python/network.cpp +++ b/python/network.cpp @@ -1,8 +1,8 @@ +#include #include #include #include #include -#include #include #include @@ -11,10 +11,10 @@ #include #include -#include +#include #include +#include #include -#include #include "context.hpp" #include "error.hpp" @@ -41,8 +41,8 @@ void register_network(py::module& m) { py::class_ network_connection_info( m, "network_connection_info", "Identifies a network connection"); - network_connection_info.def_readwrite("src", &arb::network_connection_info::src) - .def_readwrite("dest", &arb::network_connection_info::dest) + network_connection_info.def_readwrite("source", &arb::network_connection_info::source) + .def_readwrite("target", &arb::network_connection_info::target) .def_readwrite("weight", &arb::network_connection_info::weight) .def_readwrite("delay", &arb::network_connection_info::delay) .def("__repr__", @@ -56,15 +56,14 @@ void register_network(py::module& m) { network_selection .def_static("custom", [](arb::network_selection::custom_func_type func) { - return arb::network_selection::custom( - [=](const arb::network_connection_info& c) { - return try_catch_pyexception( - [&]() { - pybind11::gil_scoped_acquire guard; - return func(c); - }, - "Python error already thrown"); - }); + return arb::network_selection::custom([=](const arb::network_connection_info& c) { + return try_catch_pyexception( + [&]() { + pybind11::gil_scoped_acquire guard; + return func(c); + }, + "Python error already thrown"); + }); }) .def("__str__", [](const arb::network_selection& s) { @@ -76,15 +75,14 @@ void register_network(py::module& m) { network_value .def_static("custom", [](arb::network_value::custom_func_type func) { - return arb::network_value::custom( - [=](const arb::network_connection_info& c) { - return try_catch_pyexception( - [&]() { - pybind11::gil_scoped_acquire guard; - return func(c); - }, - "Python error already thrown"); - }); + return arb::network_value::custom([=](const arb::network_connection_info& c) { + return try_catch_pyexception( + [&]() { + pybind11::gil_scoped_acquire guard; + return func(c); + }, + "Python error already thrown"); + }); }) .def("__str__", [](const arb::network_value& v) { @@ -130,7 +128,7 @@ void register_network(py::module& m) { [&](const arb::network_value& val) { dict.set(dict_label, val); }), v); } - auto desc = arb::network_description{ + auto desc = arb::network_description{ arborio::parse_network_selection_expression(selection).unwrap(), arborio::parse_network_value_expression(weight).unwrap(), arborio::parse_network_value_expression(delay).unwrap(), @@ -159,7 +157,7 @@ void register_network(py::module& m) { pybind11::arg_v("context", pybind11::none(), "Execution context"), pybind11::arg_v("decomp", pybind11::none(), "Domain decomposition"), "Generate network connections from the network description in the recipe. Will only " - "generate connections with local gids in the domain composition as destination."); + "generate connections with local gids in the domain composition as target."); } } // namespace pyarb diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp index c282c762e5..03add15bba 100644 --- a/test/unit-distributed/test_communicator.cpp +++ b/test/unit-distributed/test_communicator.cpp @@ -217,7 +217,7 @@ namespace { std::vector connections_on(cell_gid_type gid) const override { // a single connection from the preceding cell, i.e. a ring - // weight is the destination gid + // weight is the target gid // delay is 1 cell_global_label_type src = {gid==0? size_-1: gid-1, "src"}; cell_local_label_type dst = {"tgt"}; @@ -291,7 +291,7 @@ namespace { for (auto sid: util::make_span(0, size_)) { cell_connection con( {sid, {"src", arb::lid_selection_policy::round_robin}}, // source - {"tgt", arb::lid_selection_policy::round_robin}, // destination + {"tgt", arb::lid_selection_policy::round_robin}, // target float(gid+sid), // weight 1.0f); // delay cons.push_back(con); @@ -652,7 +652,7 @@ TEST(communicator, all2all) auto c = connections[i*n_local+j]; EXPECT_EQ(i, c.source.gid); EXPECT_EQ(0u, c.source.index); - EXPECT_EQ(i, c.destination); + EXPECT_EQ(i, c.target); EXPECT_LT(c.index_on_domain, n_local); } } @@ -696,7 +696,7 @@ TEST(communicator, mini_network) // sort connections by source then target auto connections = C.connections(); util::sort(connections, [](const connection& lhs, const connection& rhs) { - return std::forward_as_tuple(lhs.source, lhs.index_on_domain, lhs.destination) < std::forward_as_tuple(rhs.source, rhs.index_on_domain, rhs.destination); + return std::forward_as_tuple(lhs.source, lhs.index_on_domain, lhs.target) < std::forward_as_tuple(rhs.source, rhs.index_on_domain, rhs.target); }); // Expect one set of 22 connections from every rank: these have been sorted. @@ -710,7 +710,7 @@ TEST(communicator, mini_network) auto c = connections[i*22 + j]; EXPECT_EQ(ex_source_gids[j], c.source.gid); EXPECT_EQ(ex_source_lids[j], c.source.index); - EXPECT_EQ(ex_target_lids[i%2][j], c.destination); + EXPECT_EQ(ex_target_lids[i%2][j], c.target); } } } diff --git a/test/unit-distributed/test_network_generation.cpp b/test/unit-distributed/test_network_generation.cpp index 298a77f5b8..5e7c78f68f 100644 --- a/test/unit-distributed/test_network_generation.cpp +++ b/test/unit-distributed/test_network_generation.cpp @@ -3,25 +3,24 @@ #include #include -#include -#include #include #include +#include #include #include #include #include +#include #include -#include "test.hpp" #include "execution_context.hpp" +#include "test.hpp" using namespace arb; using namespace arborio::literals; - namespace { -// Create alternatingly a cable, lif and spike source cell with at most one source or destination +// Create alternatingly a cable, lif and spike source cell with at most one source or target class network_test_recipe: public arb::recipe { public: network_test_recipe(unsigned num_cells, @@ -35,19 +34,13 @@ class network_test_recipe: public arb::recipe { gprop_.default_parameters = arb::neuron_parameter_defaults; } - cell_size_type num_cells() const override { - return num_cells_; - } + cell_size_type num_cells() const override { return num_cells_; } arb::util::unique_any get_cell_description(cell_gid_type gid) const override { - if(gid % 3 == 1) { - return lif_cell("source", "target"); - } - if(gid % 3 == 2) { - return spike_source_cell("spike_source"); - } - - // cable cell + if (gid % 3 == 1) { return lif_cell("source", "target"); } + if (gid % 3 == 2) { return spike_source_cell("spike_source"); } + + // cable cell int stag = 1; // soma tag int dtag = 3; // Dendrite tag. double srad = 12.6157 / 2.0; // soma radius @@ -72,12 +65,8 @@ class network_test_recipe: public arb::recipe { } cell_kind get_cell_kind(cell_gid_type gid) const override { - if(gid % 3 == 1) { - return cell_kind::lif; - } - if(gid % 3 == 2) { - return cell_kind::spike_source; - } + if (gid % 3 == 1) { return cell_kind::lif; } + if (gid % 3 == 2) { return cell_kind::spike_source; } return cell_kind::cable; } @@ -97,9 +86,7 @@ class network_test_recipe: public arb::recipe { return {}; } - std::vector get_probes(cell_gid_type gid) const override { - return {}; - } + std::vector get_probes(cell_gid_type gid) const override { return {}; } std::any get_global_properties(arb::cell_kind) const override { return gprop_; } @@ -130,27 +117,26 @@ TEST(network_generation, all) { std::unordered_map> connections_by_dest; - for(const auto& c : connections) { + for (const auto& c: connections) { EXPECT_EQ(c.weight, weight); EXPECT_EQ(c.delay, delay); - connections_by_dest[c.dest.gid].emplace_back(c); + connections_by_dest[c.target.gid].emplace_back(c); } for (const auto& group: decomp.groups()) { const auto num_dest = group.kind == cell_kind::spike_source ? 0 : 1; - for(const auto gid : group.gids) { - EXPECT_EQ(connections_by_dest[gid].size(), num_cells * num_dest); - } + for (const auto gid: group.gids) { + EXPECT_EQ(connections_by_dest[gid].size(), num_cells * num_dest); + } } } - TEST(network_generation, cable_only) { const auto& ctx = g_context; const int num_ranks = ctx->distributed->size(); const auto selection = intersect(network_selection::source_cell_kind(cell_kind::cable), - network_selection::destination_cell_kind(cell_kind::cable)); + network_selection::target_cell_kind(cell_kind::cable)); const auto weight = 2.0; const auto delay = 3.0; @@ -164,15 +150,15 @@ TEST(network_generation, cable_only) { std::unordered_map> connections_by_dest; - for(const auto& c : connections) { + for (const auto& c: connections) { EXPECT_EQ(c.weight, weight); EXPECT_EQ(c.delay, delay); - connections_by_dest[c.dest.gid].emplace_back(c); + connections_by_dest[c.target.gid].emplace_back(c); } for (const auto& group: decomp.groups()) { - for(const auto gid : group.gids) { - // Only one third is a cable cell + for (const auto gid: group.gids) { + // Only one third is a cable cell EXPECT_EQ(connections_by_dest[gid].size(), group.kind == cell_kind::cable ? num_cells / 3 : 0); } diff --git a/test/unit/test_network.cpp b/test/unit/test_network.cpp index 039fa01505..09e80b1563 100644 --- a/test/unit/test_network.cpp +++ b/test/unit/test_network.cpp @@ -48,25 +48,24 @@ TEST(network_selection, all) { for (const auto& site: test_sites) { EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } for (const auto& source: test_sites) { - for (const auto& dest: test_sites) { EXPECT_TRUE(s->select_connection(source, dest)); } + for (const auto& target: test_sites) { EXPECT_TRUE(s->select_connection(source, target)); } } } - TEST(network_selection, none) { const auto s = thingify(network_selection::none(), network_label_dict()); for (const auto& site: test_sites) { EXPECT_FALSE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_FALSE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_FALSE(s->select_target(site.kind, site.gid, site.label)); } for (const auto& source: test_sites) { - for (const auto& dest: test_sites) { EXPECT_FALSE(s->select_connection(source, dest)); } + for (const auto& target: test_sites) { EXPECT_FALSE(s->select_connection(source, target)); } } } @@ -77,30 +76,29 @@ TEST(network_selection, source_cell_kind) { for (const auto& site: test_sites) { EXPECT_EQ( site.kind == cell_kind::benchmark, s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } for (const auto& source: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(source.kind == cell_kind::benchmark, s->select_connection(source, dest)); + for (const auto& target: test_sites) { + EXPECT_EQ(source.kind == cell_kind::benchmark, s->select_connection(source, target)); } } } - -TEST(network_selection, destination_cell_kind) { +TEST(network_selection, target_cell_kind) { const auto s = - thingify(network_selection::destination_cell_kind(cell_kind::benchmark), network_label_dict()); + thingify(network_selection::target_cell_kind(cell_kind::benchmark), network_label_dict()); for (const auto& site: test_sites) { EXPECT_EQ( - site.kind == cell_kind::benchmark, s->select_destination(site.kind, site.gid, site.label)); + site.kind == cell_kind::benchmark, s->select_target(site.kind, site.gid, site.label)); EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); } for (const auto& source: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(dest.kind == cell_kind::benchmark, s->select_connection(source, dest)); + for (const auto& target: test_sites) { + EXPECT_EQ(target.kind == cell_kind::benchmark, s->select_connection(source, target)); } } } @@ -111,29 +109,30 @@ TEST(network_selection, source_label) { for (const auto& site: test_sites) { EXPECT_EQ(site.label == "b" || site.label == "e", s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } for (const auto& source: test_sites) { - for (const auto& dest: test_sites) { + for (const auto& target: test_sites) { EXPECT_EQ( - source.label == "b" || source.label == "e", s->select_connection(source, dest)); + source.label == "b" || source.label == "e", s->select_connection(source, target)); } } } -TEST(network_selection, destination_label) { - const auto s = thingify(network_selection::destination_label({"b", "e"}), network_label_dict()); +TEST(network_selection, target_label) { + const auto s = thingify(network_selection::target_label({"b", "e"}), network_label_dict()); for (const auto& site: test_sites) { EXPECT_EQ(site.label == "b" || site.label == "e", - s->select_destination(site.kind, site.gid, site.label)); + s->select_target(site.kind, site.gid, site.label)); EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(dest.label == "b" || dest.label == "e", s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ( + target.label == "b" || target.label == "e", s->select_connection(source, target)); } } } @@ -144,28 +143,28 @@ TEST(network_selection, source_cell_vec) { for (const auto& site: test_sites) { EXPECT_EQ( site.gid == 1 || site.gid == 5, s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(src.gid == 1 || src.gid == 5, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 1 || source.gid == 5, s->select_connection(source, target)); } } } -TEST(network_selection, destination_cell_vec) { - const auto s = thingify(network_selection::destination_cell({{1, 5}}), network_label_dict()); +TEST(network_selection, target_cell_vec) { + const auto s = thingify(network_selection::target_cell({{1, 5}}), network_label_dict()); for (const auto& site: test_sites) { EXPECT_EQ( - site.gid == 1 || site.gid == 5, s->select_destination(site.kind, site.gid, site.label)); + site.gid == 1 || site.gid == 5, s->select_target(site.kind, site.gid, site.label)); EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(dest.gid == 1 || dest.gid == 5, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(target.gid == 1 || target.gid == 5, s->select_connection(source, target)); } } } @@ -177,102 +176,99 @@ TEST(network_selection, source_cell_range) { for (const auto& site: test_sites) { EXPECT_EQ( site.gid == 1 || site.gid == 5, s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(src.gid == 1 || src.gid == 5, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 1 || source.gid == 5, s->select_connection(source, target)); } } } -TEST(network_selection, destination_cell_range) { +TEST(network_selection, target_cell_range) { const auto s = - thingify(network_selection::destination_cell(gid_range(1, 6, 4)), network_label_dict()); + thingify(network_selection::target_cell(gid_range(1, 6, 4)), network_label_dict()); for (const auto& site: test_sites) { EXPECT_EQ( - site.gid == 1 || site.gid == 5, s->select_destination(site.kind, site.gid, site.label)); + site.gid == 1 || site.gid == 5, s->select_target(site.kind, site.gid, site.label)); EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(dest.gid == 1 || dest.gid == 5, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(target.gid == 1 || target.gid == 5, s->select_connection(source, target)); } } } TEST(network_selection, chain) { - const auto s = - thingify(network_selection::chain({{0,2,5}}), network_label_dict()); + const auto s = thingify(network_selection::chain({{0, 2, 5}}), network_label_dict()); for (const auto& site: test_sites) { EXPECT_EQ( site.gid == 0 || site.gid == 2, s->select_source(site.kind, site.gid, site.label)); EXPECT_EQ( - site.gid == 2 || site.gid == 5, s->select_destination(site.kind, site.gid, site.label)); + site.gid == 2 || site.gid == 5, s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ((src.gid == 0 && dest.gid == 2) || (src.gid == 2 && dest.gid == 5), - s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ((source.gid == 0 && target.gid == 2) || (source.gid == 2 && target.gid == 5), + s->select_connection(source, target)); } } } TEST(network_selection, chain_range) { - const auto s = - thingify(network_selection::chain({gid_range(1,8,3)}), network_label_dict()); + const auto s = thingify(network_selection::chain({gid_range(1, 8, 3)}), network_label_dict()); for (const auto& site: test_sites) { EXPECT_EQ( site.gid == 1 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); EXPECT_EQ( - site.gid == 4 || site.gid == 7, s->select_destination(site.kind, site.gid, site.label)); + site.gid == 4 || site.gid == 7, s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ((src.gid == 1 && dest.gid == 4) || (src.gid == 4 && dest.gid == 7), - s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ((source.gid == 1 && target.gid == 4) || (source.gid == 4 && target.gid == 7), + s->select_connection(source, target)); } } } TEST(network_selection, chain_range_reverse) { const auto s = - thingify(network_selection::chain_reverse({gid_range(1,8,3)}), network_label_dict()); + thingify(network_selection::chain_reverse({gid_range(1, 8, 3)}), network_label_dict()); for (const auto& site: test_sites) { EXPECT_EQ( site.gid == 7 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); EXPECT_EQ( - site.gid == 4 || site.gid == 1, s->select_destination(site.kind, site.gid, site.label)); + site.gid == 4 || site.gid == 1, s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ((src.gid == 7 && dest.gid == 4) || (src.gid == 4 && dest.gid == 1), - s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ((source.gid == 7 && target.gid == 4) || (source.gid == 4 && target.gid == 1), + s->select_connection(source, target)); } } } TEST(network_selection, inter_cell) { - const auto s = - thingify(network_selection::inter_cell(), network_label_dict()); + const auto s = thingify(network_selection::inter_cell(), network_label_dict()); for (const auto& site: test_sites) { EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(src.gid != dest.gid, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid != target.gid, s->select_connection(source, target)); } } } @@ -280,34 +276,33 @@ TEST(network_selection, inter_cell) { TEST(network_selection, named) { network_label_dict dict; dict.set("mysel", network_selection::inter_cell()); - const auto s = - thingify(network_selection::named("mysel"), dict); + const auto s = thingify(network_selection::named("mysel"), dict); for (const auto& site: test_sites) { EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(src.gid != dest.gid, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid != target.gid, s->select_connection(source, target)); } } } TEST(network_selection, intersect) { const auto s = thingify(network_selection::intersect(network_selection::source_cell({1}), - network_selection::destination_cell({2})), + network_selection::target_cell({2})), network_label_dict()); for (const auto& site: test_sites) { EXPECT_EQ(site.gid == 1, s->select_source(site.kind, site.gid, site.label)); - EXPECT_EQ(site.gid == 2, s->select_destination(site.kind, site.gid, site.label)); + EXPECT_EQ(site.gid == 2, s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(src.gid == 1 && dest.gid == 2, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 1 && target.gid == 2, s->select_connection(source, target)); } } } @@ -315,20 +310,22 @@ TEST(network_selection, intersect) { TEST(network_selection, join) { const auto s = thingify( network_selection::join(network_selection::intersect(network_selection::source_cell({1}), - network_selection::destination_cell({2})), + network_selection::target_cell({2})), network_selection::intersect( - network_selection::source_cell({4}), network_selection::destination_cell({5}))), + network_selection::source_cell({4}), network_selection::target_cell({5}))), network_label_dict()); for (const auto& site: test_sites) { - EXPECT_EQ(site.gid == 1 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); - EXPECT_EQ(site.gid == 2 || site.gid == 5, s->select_destination(site.kind, site.gid, site.label)); + EXPECT_EQ( + site.gid == 1 || site.gid == 4, s->select_source(site.kind, site.gid, site.label)); + EXPECT_EQ( + site.gid == 2 || site.gid == 5, s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ((src.gid == 1 && dest.gid == 2) || (src.gid == 4 && dest.gid == 5), - s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ((source.gid == 1 && target.gid == 2) || (source.gid == 4 && target.gid == 5), + s->select_connection(source, target)); } } } @@ -342,12 +339,12 @@ TEST(network_selection, difference) { for (const auto& site: test_sites) { EXPECT_EQ(site.gid == 0 || site.gid == 1 || site.gid == 2, s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(src.gid == 0 || src.gid == 2, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 0 || source.gid == 2, s->select_connection(source, target)); } } } @@ -361,12 +358,13 @@ TEST(network_selection, symmetric_difference) { for (const auto& site: test_sites) { EXPECT_EQ(site.gid == 0 || site.gid == 1 || site.gid == 2 || site.gid == 3, s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(src.gid == 0 || src.gid == 2 || src.gid == 3, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == 0 || source.gid == 2 || source.gid == 3, + s->select_connection(source, target)); } } } @@ -377,12 +375,12 @@ TEST(network_selection, complement) { for (const auto& site: test_sites) { EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(src.gid == dest.gid, s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(source.gid == target.gid, s->select_connection(source, target)); } } } @@ -392,13 +390,11 @@ TEST(network_selection, random_p_1) { for (const auto& site: test_sites) { EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_TRUE(s->select_connection(src, dest)); - } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_TRUE(s->select_connection(source, target)); } } } @@ -407,13 +403,11 @@ TEST(network_selection, random_p_0) { for (const auto& site: test_sites) { EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_FALSE(s->select_connection(src, dest)); - } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_FALSE(s->select_connection(source, target)); } } } @@ -422,9 +416,10 @@ TEST(network_selection, random_seed) { const auto s2 = thingify(network_selection::random(4592304, 0.5), network_label_dict()); bool all_eq = true; - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - all_eq &= (s1->select_connection(src, dest) == s2->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + all_eq &= + (s1->select_connection(source, target) == s2->select_connection(source, target)); } } EXPECT_FALSE(all_eq); @@ -441,9 +436,9 @@ TEST(network_selection, random_reproducibility) { std::vector ref = {1, 1, 0, 1, 1, 0, 0, 0, 0}; std::size_t i = 0; - for (const auto& src: sites) { - for (const auto& dest: sites) { - EXPECT_EQ(ref.at(i), s->select_connection(src, dest)); + for (const auto& source: sites) { + for (const auto& target: sites) { + EXPECT_EQ(ref.at(i), s->select_connection(source, target)); ++i; } }; @@ -451,73 +446,71 @@ TEST(network_selection, random_reproducibility) { TEST(network_selection, custom) { auto inter_cell_func = [](const network_connection_info& c) { - return c.src.gid != c.dest.gid; + return c.source.gid != c.target.gid; }; const auto s = thingify(network_selection::custom(inter_cell_func), network_label_dict()); const auto s_ref = thingify(network_selection::inter_cell(), network_label_dict()); for (const auto& site: test_sites) { EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(s->select_connection(src, dest), s_ref->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ( + s->select_connection(source, target), s_ref->select_connection(source, target)); } } } TEST(network_selection, distance_lt) { const double d = 2.1; - const auto s = - thingify(network_selection::distance_lt(d), network_label_dict()); + const auto s = thingify(network_selection::distance_lt(d), network_label_dict()); for (const auto& site: test_sites) { EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(distance(src.global_location, dest.global_location) < d, - s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(distance(source.global_location, target.global_location) < d, + s->select_connection(source, target)); } } } TEST(network_selection, distance_gt) { const double d = 2.1; - const auto s = - thingify(network_selection::distance_gt(d), network_label_dict()); + const auto s = thingify(network_selection::distance_gt(d), network_label_dict()); for (const auto& site: test_sites) { EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); - EXPECT_TRUE(s->select_destination(site.kind, site.gid, site.label)); + EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_EQ(distance(src.global_location, dest.global_location) > d, - s->select_connection(src, dest)); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_EQ(distance(source.global_location, target.global_location) > d, + s->select_connection(source, target)); } } } - TEST(network_value, scalar) { const auto v = thingify(network_value::scalar(2.0), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(src, dest)); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(source, target)); } } } TEST(network_value, conversion) { const auto v = thingify(static_cast(2.0), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(src, dest)); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(source, target)); } } } @@ -526,28 +519,29 @@ TEST(network_value, named) { dict.set("myval", network_value::scalar(2.0)); const auto v = thingify(network_value::named("myval"), dict); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(src, dest)); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(2.0, v->get(source, target)); } } } TEST(network_value, distance) { const auto v = thingify(network_value::distance(), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ( - distance(src.global_location, dest.global_location), v->get(src, dest)); + distance(source.global_location, target.global_location), v->get(source, target)); } } } TEST(network_value, uniform_distribution) { - const auto v = thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); + const auto v = + thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); double mean = 0.0; - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { mean += v->get(src, dest); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { mean += v->get(source, target); } } mean /= test_sites.size() * test_sites.size(); @@ -555,7 +549,8 @@ TEST(network_value, uniform_distribution) { } TEST(network_value, uniform_distribution_reproducibility) { - const auto v = thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); + const auto v = + thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); std::vector sites = { {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, @@ -575,9 +570,9 @@ TEST(network_value, uniform_distribution_reproducibility) { }; std::size_t i = 0; - for (const auto& src: sites) { - for (const auto& dest: sites) { - EXPECT_DOUBLE_EQ(ref.at(i), v->get(src, dest)); + for (const auto& source: sites) { + for (const auto& target: sites) { + EXPECT_DOUBLE_EQ(ref.at(i), v->get(source, target)); ++i; } }; @@ -586,13 +581,14 @@ TEST(network_value, uniform_distribution_reproducibility) { TEST(network_value, normal_distribution) { const double mean = 5.0; const double std_dev = 3.0; - const auto v = thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); + const auto v = + thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); double sample_mean = 0.0; double sample_dev = 0.0; - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - const auto result = v->get(src, dest); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + const auto result = v->get(source, target); sample_mean += result; sample_dev += (result - mean) * (result - mean); } @@ -608,7 +604,8 @@ TEST(network_value, normal_distribution) { TEST(network_value, normal_distribution_reproducibility) { const double mean = 5.0; const double std_dev = 3.0; - const auto v = thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); + const auto v = + thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); std::vector sites = { {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, @@ -628,9 +625,9 @@ TEST(network_value, normal_distribution_reproducibility) { }; std::size_t i = 0; - for (const auto& src: sites) { - for (const auto& dest: sites) { - EXPECT_DOUBLE_EQ(ref.at(i), v->get(src, dest)); + for (const auto& source: sites) { + for (const auto& target: sites) { + EXPECT_DOUBLE_EQ(ref.at(i), v->get(source, target)); ++i; } }; @@ -649,9 +646,9 @@ TEST(network_value, truncated_normal_distribution) { double sample_mean = 0.0; - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - const auto result = v->get(src, dest); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + const auto result = v->get(source, target); EXPECT_GT(result, lower_bound); EXPECT_LE(result, upper_bound); sample_mean += result; @@ -692,9 +689,9 @@ TEST(network_value, truncated_normal_distribution_reproducibility) { }; std::size_t i = 0; - for (const auto& src: sites) { - for (const auto& dest: sites) { - EXPECT_DOUBLE_EQ(ref.at(i), v->get(src, dest)); + for (const auto& source: sites) { + for (const auto& target: sites) { + EXPECT_DOUBLE_EQ(ref.at(i), v->get(source, target)); ++i; } }; @@ -702,14 +699,15 @@ TEST(network_value, truncated_normal_distribution_reproducibility) { TEST(network_value, custom) { auto func = [](const network_connection_info& c) { - return c.src.global_location.x + c.dest.global_location.x; + return c.source.global_location.x + c.target.global_location.x; }; const auto v = thingify(network_value::custom(func), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_DOUBLE_EQ(v->get(src, dest), src.global_location.x + dest.global_location.x); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ( + v->get(source, target), source.global_location.x + target.global_location.x); } } } @@ -719,8 +717,8 @@ TEST(network_value, add) { thingify(network_value::add(network_value::scalar(2.0), network_value::scalar(3.0)), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), 5.0); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(v->get(source, target), 5.0); } } } @@ -729,8 +727,8 @@ TEST(network_value, sub) { thingify(network_value::sub(network_value::scalar(2.0), network_value::scalar(3.0)), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), -1.0); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(v->get(source, target), -1.0); } } } @@ -739,8 +737,8 @@ TEST(network_value, mul) { thingify(network_value::mul(network_value::scalar(2.0), network_value::scalar(3.0)), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), 6.0); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { EXPECT_DOUBLE_EQ(v->get(source, target), 6.0); } } } @@ -749,26 +747,30 @@ TEST(network_value, div) { thingify(network_value::div(network_value::scalar(2.0), network_value::scalar(3.0)), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), 2.0 / 3.0); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v->get(source, target), 2.0 / 3.0); + } } } TEST(network_value, exp) { - const auto v = - thingify(network_value::exp(network_value::scalar(2.0)), - network_label_dict()); + const auto v = thingify(network_value::exp(network_value::scalar(2.0)), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), std::exp(2.0)); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v->get(source, target), std::exp(2.0)); + } } } TEST(network_value, log) { const auto v = thingify(network_value::log(network_value::scalar(2.0)), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { EXPECT_DOUBLE_EQ(v->get(src, dest), std::log(2.0)); } + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v->get(source, target), std::log(2.0)); + } } } @@ -780,10 +782,10 @@ TEST(network_value, min) { thingify(network_value::min(network_value::scalar(3.0), network_value::scalar(2.0)), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_DOUBLE_EQ(v1->get(src, dest), 2.0); - EXPECT_DOUBLE_EQ(v2->get(src, dest), 2.0); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v1->get(source, target), 2.0); + EXPECT_DOUBLE_EQ(v2->get(source, target), 2.0); } } } @@ -796,10 +798,10 @@ TEST(network_value, max) { thingify(network_value::max(network_value::scalar(3.0), network_value::scalar(2.0)), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_DOUBLE_EQ(v1->get(src, dest), 3.0); - EXPECT_DOUBLE_EQ(v2->get(src, dest), 3.0); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v1->get(source, target), 3.0); + EXPECT_DOUBLE_EQ(v2->get(source, target), 3.0); } } } @@ -812,9 +814,9 @@ TEST(network_value, if_else) { const auto v = thingify(network_value::if_else(s, v1, v2), network_label_dict()); - for (const auto& src: test_sites) { - for (const auto& dest: test_sites) { - EXPECT_DOUBLE_EQ(v->get(src, dest), src.gid != dest.gid ? 2.0 : 3.0); + for (const auto& source: test_sites) { + for (const auto& target: test_sites) { + EXPECT_DOUBLE_EQ(v->get(source, target), source.gid != target.gid ? 2.0 : 3.0); } } } From 7e1aaf1f2451b9c668d4ce9f4a34180aff26cd8f Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 4 Sep 2023 09:18:01 +0200 Subject: [PATCH 55/84] fix test --- test/unit/test_s_expr.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/unit/test_s_expr.cpp b/test/unit/test_s_expr.cpp index 6797db0226..3668084cc4 100644 --- a/test/unit/test_s_expr.cpp +++ b/test/unit/test_s_expr.cpp @@ -371,20 +371,20 @@ TEST(network_selection, round_tripping) { "(source-cell-kind (lif-cell))", "(source-cell-kind (benchmark-cell))", "(source-cell-kind (spike-source-cell))", - "(destination-cell-kind (cable-cell))", - "(destination-cell-kind (lif-cell))", - "(destination-cell-kind (benchmark-cell))", - "(destination-cell-kind (spike-source-cell))", + "(target-cell-kind (cable-cell))", + "(target-cell-kind (lif-cell))", + "(target-cell-kind (benchmark-cell))", + "(target-cell-kind (spike-source-cell))", "(source-label \"abc\")", "(source-label \"abc\" \"def\")", "(source-label \"abc\" \"def\" \"ghi\")", - "(destination-label \"abc\")", - "(destination-label \"abc\" \"def\")", - "(destination-label \"abc\" \"def\" \"ghi\")", + "(target-label \"abc\")", + "(target-label \"abc\" \"def\")", + "(target-label \"abc\" \"def\" \"ghi\")", "(source-cell 0 1 3 15)", "(source-cell (gid-range 4 8 2))", - "(destination-cell 0 1 3 15)", - "(destination-cell (gid-range 4 8 2))", + "(target-cell 0 1 3 15)", + "(target-cell (gid-range 4 8 2))", "(chain 3 1 0 5 7 6)", // order should be preserved "(chain (gid-range 2 14 3))", "(chain-reverse (gid-range 2 14 3))", From 3dd836c937c87f79b37d0a56dad39f9b445ccea5 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 4 Sep 2023 09:18:22 +0200 Subject: [PATCH 56/84] revert to float type in spike event --- arbor/include/arbor/spike_event.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arbor/include/arbor/spike_event.hpp b/arbor/include/arbor/spike_event.hpp index e318f02115..cf27cfc06f 100644 --- a/arbor/include/arbor/spike_event.hpp +++ b/arbor/include/arbor/spike_event.hpp @@ -15,7 +15,7 @@ namespace arb { struct spike_event { cell_lid_type target; time_type time; - double weight; + float weight; friend bool operator==(const spike_event& l, const spike_event& r) { return l.target==r.target && l.time==r.time && l.weight==r.weight; From 55a376f75394010327e954fe5640796ea5d43af1 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 4 Sep 2023 09:20:52 +0200 Subject: [PATCH 57/84] python reformatting --- python/example/network_description.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/example/network_description.py b/python/example/network_description.py index 09a855df45..1b39c7aac4 100755 --- a/python/example/network_description.py +++ b/python/example/network_description.py @@ -169,7 +169,9 @@ def global_properties(self, kind): print("connections:") for c in connections: - print(f'({c.source.gid}, "{c.source.label}") -> ({c.target.gid}, "{c.target.label}")') + print( + f'({c.source.gid}, "{c.source.label}") -> ({c.target.gid}, "{c.target.label}")' + ) # (16) Run simulation for 100 ms sim.run(100) From 384dff7095b2180bd892fa45ce65b91a81e2dfa3 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Wed, 6 Sep 2023 09:36:29 +0200 Subject: [PATCH 58/84] remove reference from string_view --- arbor/network.cpp | 192 +++++++++++------------------------------ arbor/network_impl.hpp | 8 +- 2 files changed, 50 insertions(+), 150 deletions(-) diff --git a/arbor/network.cpp b/arbor/network.cpp index a6e08af8cd..8295b51cf2 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -58,15 +58,11 @@ struct network_selection_all_impl: public network_selection_impl { return true; } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -80,15 +76,11 @@ struct network_selection_none_impl: public network_selection_impl { return false; } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return false; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return false; } @@ -105,15 +97,11 @@ struct network_selection_source_cell_kind_impl: public network_selection_impl { return source.kind == select_kind; } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return kind == select_kind; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -139,15 +127,11 @@ struct network_selection_target_cell_kind_impl: public network_selection_impl { return target.kind == select_kind; } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return kind == select_kind; } @@ -176,15 +160,11 @@ struct network_selection_source_label_impl: public network_selection_impl { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), source.label); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), label); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -208,15 +188,11 @@ struct network_selection_target_label_impl: public network_selection_impl { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), target.label); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return std::binary_search(sorted_labels.begin(), sorted_labels.end(), label); } @@ -240,15 +216,11 @@ struct network_selection_source_cell_impl: public network_selection_impl { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), source.gid); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -273,15 +245,11 @@ struct network_selection_source_cell_range_impl: public network_selection_impl { !((source.gid - gid_begin) % step); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -303,15 +271,11 @@ struct network_selection_target_cell_impl: public network_selection_impl { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), target.gid); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } @@ -336,15 +300,11 @@ struct network_selection_target_cell_range_impl: public network_selection_impl { !((target.gid - gid_begin) % step); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); } @@ -380,16 +340,12 @@ struct network_selection_chain_impl: public network_selection_impl { return false; } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return !sorted_gids.empty() && std::binary_search(sorted_gids.begin(), sorted_gids.end() - 1, gid); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return !sorted_gids.empty() && std::binary_search(sorted_gids.begin() + 1, sorted_gids.end(), gid); } @@ -418,17 +374,13 @@ struct network_selection_chain_range_impl: public network_selection_impl { return source.gid + step == target.gid && !((source.gid - gid_begin) % step); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { // Return false if outside range or if equal to last element, which cannot be a source if (gid < gid_begin || gid >= gid_end - 1) return false; return !((gid - gid_begin) % step); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { // Return false if outside range or if equal to first element, which cannot be a target if (gid <= gid_begin || gid >= gid_end) return false; return !((gid - gid_begin) % step); @@ -456,17 +408,13 @@ struct network_selection_reverse_chain_range_impl: public network_selection_impl return target.gid + step == source.gid && !((source.gid - gid_begin) % step); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { // Return false if outside range or if equal to first element, which cannot be a source if (gid <= gid_begin || gid >= gid_end) return false; return !((gid - gid_begin) % step); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { // Return false if outside range or if equal to last element, which cannot be a target if (gid < gid_begin || gid >= gid_end - 1) return false; return !((gid - gid_begin) % step); @@ -488,16 +436,12 @@ struct network_selection_complement_impl: public network_selection_impl { return !selection->select_connection(source, target); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; // cannot exclude any because source selection cannot be complemented without // knowing selection criteria. } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; // cannot exclude any because target selection cannot be complemented // without knowing selection criteria. } @@ -526,17 +470,13 @@ struct network_selection_named_impl: public network_selection_impl { return selection->select_connection(source, target); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); return selection->select_source(kind, gid, label); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); return selection->select_target(kind, gid, label); @@ -561,15 +501,11 @@ struct network_selection_inter_cell_impl: public network_selection_impl { return source.gid != target.gid; } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -596,15 +532,11 @@ struct network_selection_custom_impl: public network_selection_impl { target.global_location}}); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -621,15 +553,11 @@ struct network_selection_distance_lt_impl: public network_selection_impl { return distance(source.global_location, target.global_location) < d; } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -648,15 +576,11 @@ struct network_selection_distance_gt_impl: public network_selection_impl { return distance(source.global_location, target.global_location) > d; } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -683,15 +607,11 @@ struct network_selection_random_impl: public network_selection_impl { return r < p; } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return true; } @@ -719,15 +639,11 @@ struct network_selection_intersect_impl: public network_selection_impl { return left->select_connection(source, target) && right->select_connection(source, target); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return left->select_source(kind, gid, label) && right->select_source(kind, gid, label); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return left->select_target(kind, gid, label) && right->select_target(kind, gid, label); } @@ -769,15 +685,11 @@ struct network_selection_join_impl: public network_selection_impl { return left->select_connection(source, target) || right->select_connection(source, target); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return left->select_target(kind, gid, label) || right->select_target(kind, gid, label); } @@ -817,15 +729,11 @@ struct network_selection_symmetric_difference_impl: public network_selection_imp return left->select_connection(source, target) ^ right->select_connection(source, target); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return left->select_target(kind, gid, label) || right->select_target(kind, gid, label); } @@ -866,15 +774,11 @@ struct network_selection_difference_impl: public network_selection_impl { !(right->select_connection(source, target)); } - bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return left->select_source(kind, gid, label); } - bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { return left->select_target(kind, gid, label); } diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp index 480ff7d3d8..b38a0b23c5 100644 --- a/arbor/network_impl.hpp +++ b/arbor/network_impl.hpp @@ -46,13 +46,9 @@ struct network_selection_impl { virtual bool select_connection(const network_full_site_info& source, const network_full_site_info& target) const = 0; - virtual bool select_source(cell_kind kind, - cell_gid_type gid, - const std::string_view& tag) const = 0; + virtual bool select_source(cell_kind kind, cell_gid_type gid, std::string_view tag) const = 0; - virtual bool select_target(cell_kind kind, - cell_gid_type gid, - const std::string_view& tag) const = 0; + virtual bool select_target(cell_kind kind, cell_gid_type gid, std::string_view tag) const = 0; virtual void initialize(const network_label_dict& dict){}; From f008f72420c8daeac15e7a6f9d85847df9966a1c Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Tue, 12 Sep 2023 14:18:21 +0200 Subject: [PATCH 59/84] Push hashes up into decor. --- arbor/benchmark_cell_group.cpp | 4 ++-- arbor/cable_cell.cpp | 18 +++++++++++------ arbor/cable_cell_param.cpp | 9 ++++++++- arbor/fvm_lowered_cell_impl.hpp | 7 +------ arbor/include/arbor/cable_cell.hpp | 10 +++++----- arbor/include/arbor/cable_cell_param.hpp | 5 ++++- arbor/include/arbor/common_types.hpp | 4 ++++ arbor/include/arbor/cv_policy.hpp | 2 +- arbor/label_resolution.cpp | 13 ++++++++++-- arbor/label_resolution.hpp | 3 ++- arbor/lif_cell_group.cpp | 4 ++-- arbor/spike_source_cell_group.cpp | 2 +- arbor/util/hash.hpp | 4 ++-- arborio/cableio.cpp | 12 +++++++----- test/unit/test_cable_cell.cpp | 25 ++++++++++++------------ test/unit/test_label_resolution.cpp | 22 ++++++++++----------- 16 files changed, 86 insertions(+), 58 deletions(-) diff --git a/arbor/benchmark_cell_group.cpp b/arbor/benchmark_cell_group.cpp index bdcb1789a6..d2ac067182 100644 --- a/arbor/benchmark_cell_group.cpp +++ b/arbor/benchmark_cell_group.cpp @@ -40,8 +40,8 @@ benchmark_cell_group::benchmark_cell_group(const std::vector& gid for (const auto& c: cells_) { cg_sources.add_cell(); cg_targets.add_cell(); - cg_sources.add_label(c.source, {0, 1}); - cg_targets.add_label(c.target, {0, 1}); + cg_sources.add_label(internal_hash(c.source), {0, 1}); + cg_targets.add_label(internal_hash(c.target), {0, 1}); } benchmark_cell_group::reset(); diff --git a/arbor/cable_cell.cpp b/arbor/cable_cell.cpp index 3afbc82400..210fdd3cea 100644 --- a/arbor/cable_cell.cpp +++ b/arbor/cable_cell.cpp @@ -88,7 +88,7 @@ struct cable_cell_impl { decor decorations; // The placeable label to lid_range map - dynamic_typed_map>::type> labeled_lid_ranges; + dynamic_typed_map>::type> labeled_lid_ranges; cable_cell_impl(const arb::morphology& m, const label_dict& labels, const decor& decorations): provider(m, labels), @@ -120,7 +120,7 @@ struct cable_cell_impl { } template - void place(const locset& ls, const Item& item, const cell_tag_type& label) { + void place(const locset& ls, const Item& item, const hash_type& label) { auto& mm = get_location_map(item); cell_lid_type& lid = placed_count.get(); cell_lid_type first = lid; @@ -226,7 +226,8 @@ void cable_cell_impl::init(const decor& d) { for (const auto& p: d.placements()) { auto& where = std::get<0>(p); auto& label = std::get<2>(p); - std::visit([this, &where, &label] (auto&& what) {return this->place(where, what, label);}, std::get<1>(p)); + std::visit([this, &where, &label] (auto&& what) {return this->place(where, what, label); }, + std::get<1>(p)); } } @@ -280,16 +281,21 @@ const cable_cell_parameter_set& cable_cell::default_parameters() const { return impl_->decorations.defaults(); } -const std::unordered_multimap& cable_cell::detector_ranges() const { +const cable_cell::lid_range_map& cable_cell::detector_ranges() const { return impl_->labeled_lid_ranges.get(); } -const std::unordered_multimap& cable_cell::synapse_ranges() const { +const cable_cell::lid_range_map& cable_cell::synapse_ranges() const { return impl_->labeled_lid_ranges.get(); } -const std::unordered_multimap& cable_cell::junction_ranges() const { +const cable_cell::lid_range_map& cable_cell::junction_ranges() const { return impl_->labeled_lid_ranges.get(); } +cell_tag_type decor::tag_of(hash_type hash) const { + if (!hashes_.count(hash)) throw arbor_internal_error{util::pprintf("Unknown hash for {}.", std::to_string(hash))}; + return hashes_.at(hash); +} + } // namespace arb diff --git a/arbor/cable_cell_param.cpp b/arbor/cable_cell_param.cpp index 909ec7faba..55c4cdd079 100644 --- a/arbor/cable_cell_param.cpp +++ b/arbor/cable_cell_param.cpp @@ -10,7 +10,9 @@ #include #include +#include "util/hash.hpp" #include "util/maputil.hpp" +#include "util/strprintf.hpp" namespace arb { @@ -120,7 +122,12 @@ decor& decor::paint(region where, paintable what) { } decor& decor::place(locset where, placeable what, cell_tag_type label) { - placements_.emplace_back(std::move(where), std::move(what), std::move(label)); + auto hash = internal_hash(label); + if (hashes_.count(hash) && hashes_.at(hash) != label) { + throw arbor_internal_error{util::strprintf("Hash collision {} ./. {}", label, hashes_.at(hash))}; + } + placements_.emplace_back(std::move(where), std::move(what), hash); + hashes_.emplace(hash, label); return *this; } diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index 70fccf07ea..5206635229 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -313,16 +313,11 @@ fvm_detector_info get_detector_info(arb_size_type max, } inline cell_size_type -add_labels(cell_label_range& clr, const std::unordered_multimap& ranges) { +add_labels(cell_label_range& clr, const cable_cell::lid_range_map& ranges) { clr.add_cell(); cell_size_type count = 0; std::unordered_map hashes; for (const auto& [label, range]: ranges) { - auto hash = internal_hash(label); - if (hashes.count(hash) && hashes.at(hash) != label) { - auto err = util::strprintf("Hash collision {} ~ {} = {}", label, hashes.at(hash), hash); - throw arbor_internal_error{err}; - } clr.add_label(label, range); count += (range.end - range.begin); } diff --git a/arbor/include/arbor/cable_cell.hpp b/arbor/include/arbor/cable_cell.hpp index 221c967d14..e119508a49 100644 --- a/arbor/include/arbor/cable_cell.hpp +++ b/arbor/include/arbor/cable_cell.hpp @@ -246,8 +246,8 @@ using cable_cell_location_map = static_typed_map; // High-level abstract representation of a cell. -class ARB_SYMBOL_VISIBLE cable_cell { -public: +struct ARB_SYMBOL_VISIBLE cable_cell { + using lid_range_map = std::unordered_multimap; using index_type = cell_lid_type; using size_type = cell_local_size_type; using value_type = double; @@ -311,9 +311,9 @@ class ARB_SYMBOL_VISIBLE cable_cell { const cable_cell_parameter_set& default_parameters() const; // The labeled lid_ranges of sources, targets and gap_junctions on the cell; - const std::unordered_multimap& detector_ranges() const; - const std::unordered_multimap& synapse_ranges() const; - const std::unordered_multimap& junction_ranges() const; + const lid_range_map& detector_ranges() const; + const lid_range_map& synapse_ranges() const; + const lid_range_map& junction_ranges() const; private: std::unique_ptr impl_; diff --git a/arbor/include/arbor/cable_cell_param.hpp b/arbor/include/arbor/cable_cell_param.hpp index 46ed13e152..62ed0532ed 100644 --- a/arbor/include/arbor/cable_cell_param.hpp +++ b/arbor/include/arbor/cable_cell_param.hpp @@ -313,8 +313,9 @@ struct ARB_ARBOR_API cable_cell_parameter_set { // are to be applied to a morphology in a cable_cell. class ARB_ARBOR_API decor { std::vector> paintings_; - std::vector> placements_; + std::vector> placements_; cable_cell_parameter_set defaults_; + std::unordered_map hashes_; public: const auto& paintings() const {return paintings_; } @@ -324,6 +325,8 @@ class ARB_ARBOR_API decor { decor& paint(region, paintable); decor& place(locset, placeable, cell_tag_type); decor& set_default(defaultable); + + cell_tag_type tag_of(hash_type) const; }; ARB_ARBOR_API extern cable_cell_parameter_set neuron_parameter_defaults; diff --git a/arbor/include/arbor/common_types.hpp b/arbor/include/arbor/common_types.hpp index 34a4b2dddd..1d5fa262e8 100644 --- a/arbor/include/arbor/common_types.hpp +++ b/arbor/include/arbor/common_types.hpp @@ -19,6 +19,10 @@ namespace arb { +// Internal hashes use this 64bit id + +using hash_type = std::uint64_t; + // For identifying cells globally. using cell_gid_type = std::uint32_t; diff --git a/arbor/include/arbor/cv_policy.hpp b/arbor/include/arbor/cv_policy.hpp index a1a5077cf1..a56a5c51d9 100644 --- a/arbor/include/arbor/cv_policy.hpp +++ b/arbor/include/arbor/cv_policy.hpp @@ -59,7 +59,7 @@ namespace arb { -class cable_cell; +struct cable_cell; struct cv_policy_base { virtual locset cv_boundary_points(const cable_cell& cell) const = 0; diff --git a/arbor/label_resolution.cpp b/arbor/label_resolution.cpp index 69c3b6b57f..5506a79e6c 100644 --- a/arbor/label_resolution.cpp +++ b/arbor/label_resolution.cpp @@ -26,12 +26,21 @@ cell_label_range::cell_label_range(std::vector size_vec, arb_assert(check_invariant()); }; +cell_label_range::cell_label_range(std::vector size_vec, + std::vector label_vec, + std::vector range_vec): + sizes(std::move(size_vec)), labels(std::move(label_vec)), ranges(std::move(range_vec)) +{ + arb_assert(check_invariant()); +}; + + void cell_label_range::add_cell() { sizes.push_back(0); } -void cell_label_range::add_label(cell_tag_type label, lid_range range) { +void cell_label_range::add_label(hash_type label, lid_range range) { if (sizes.empty()) throw arbor_internal_error("adding label to cell_label_range without cell"); ++sizes.back(); - labels.push_back(internal_hash(label)); + labels.push_back(label); ranges.push_back(std::move(range)); } diff --git a/arbor/label_resolution.hpp b/arbor/label_resolution.hpp index d8823cbfe3..92d0270fe1 100644 --- a/arbor/label_resolution.hpp +++ b/arbor/label_resolution.hpp @@ -27,10 +27,11 @@ struct ARB_ARBOR_API cell_label_range { cell_label_range& operator=(cell_label_range&&) = default; cell_label_range(std::vector size_vec, std::vector label_vec, std::vector range_vec); + cell_label_range(std::vector size_vec, std::vector label_vec, std::vector range_vec); void add_cell(); - void add_label(cell_tag_type label, lid_range range); + void add_label(hash_type label, lid_range range); void append(cell_label_range other); diff --git a/arbor/lif_cell_group.cpp b/arbor/lif_cell_group.cpp index e655744483..ee4e881673 100644 --- a/arbor/lif_cell_group.cpp +++ b/arbor/lif_cell_group.cpp @@ -26,8 +26,8 @@ lif_cell_group::lif_cell_group(const std::vector& gids, // tell our caller about this cell's connections cg_sources.add_cell(); cg_targets.add_cell(); - cg_sources.add_label(cell.source, {0, 1}); - cg_targets.add_label(cell.target, {0, 1}); + cg_sources.add_label(internal_hash(cell.source), {0, 1}); + cg_targets.add_label(internal_hash(cell.target), {0, 1}); // insert probes where needed auto probes = rec.get_probes(gid); for (const auto lid: util::count_along(probes)) { diff --git a/arbor/spike_source_cell_group.cpp b/arbor/spike_source_cell_group.cpp index 92ea2799b7..41bc9705c0 100644 --- a/arbor/spike_source_cell_group.cpp +++ b/arbor/spike_source_cell_group.cpp @@ -33,7 +33,7 @@ spike_source_cell_group::spike_source_cell_group( try { auto cell = util::any_cast(rec.get_cell_description(gid)); time_sequences_.emplace_back(cell.seqs); - cg_sources.add_label(cell.source, {0, 1}); + cg_sources.add_label(internal_hash(cell.source), {0, 1}); } catch (std::bad_any_cast& e) { throw bad_cell_description(cell_kind::spike_source, gid); diff --git a/arbor/util/hash.hpp b/arbor/util/hash.hpp index 416cde1a84..988fec2a12 100644 --- a/arbor/util/hash.hpp +++ b/arbor/util/hash.hpp @@ -3,9 +3,9 @@ #include #include -namespace arb { -using hash_type = uint64_t; +#include +namespace arb { // Non-cryptographic hash function for mapping strings to internal // identifiers. Concretely, FNV-1a hash function taken from // diff --git a/arborio/cableio.cpp b/arborio/cableio.cpp index 2a15455e79..b0383b74a9 100644 --- a/arborio/cableio.cpp +++ b/arborio/cableio.cpp @@ -135,8 +135,10 @@ s_expr mksexp(const decor& d) { { return slist("paint"_symbol, round_trip(p.first), mksexp(x)); }, p.second)); } for (const auto& p: d.placements()) { - decorations.push_back(std::visit([&](auto& x) - { return slist("place"_symbol, round_trip(std::get<0>(p)), mksexp(x), s_expr(std::get<2>(p))); }, std::get<1>(p))); + decorations.push_back(std::visit([&](auto& x) { + auto lbl = d.tag_of(std::get<2>(p)); + return slist("place"_symbol, round_trip(std::get<0>(p)), mksexp(x), s_expr(lbl)); + }, std::get<1>(p))); } return {"decor"_symbol, slist_range(decorations)}; } @@ -282,9 +284,9 @@ decor make_decor(const std::vector(p), std::get<1>(p), std::get<2>(p)); }, - [&](const paint_pair & p) { d.paint(p.first, p.second); }, - [&](const defaultable & p){ d.set_default(p); }); + [&](const place_tuple& p) { d.place(std::get<0>(p), std::get<1>(p), std::get<2>(p)); }, + [&](const paint_pair& p) { d.paint(p.first, p.second); }, + [&](const defaultable& p){ d.set_default(p); }); std::visit(decor_visitor, a); } return d; diff --git a/test/unit/test_cable_cell.cpp b/test/unit/test_cable_cell.cpp index 1e59691bd2..c1066e9827 100644 --- a/test/unit/test_cable_cell.cpp +++ b/test/unit/test_cable_cell.cpp @@ -1,5 +1,6 @@ #include #include "../common_cells.hpp" +#include "util/hash.hpp" #include #include @@ -45,20 +46,20 @@ TEST(cable_cell, lid_ranges) { const auto& src_ranges = cell.detector_ranges(); const auto& tgt_ranges = cell.synapse_ranges(); - EXPECT_EQ(1u, tgt_ranges.count("t0")); - EXPECT_EQ(1u, tgt_ranges.count("t1")); - EXPECT_EQ(1u, src_ranges.count("s0")); - EXPECT_EQ(1u, tgt_ranges.count("t2")); - EXPECT_EQ(1u, src_ranges.count("s1")); - EXPECT_EQ(2u, tgt_ranges.count("t3")); + EXPECT_EQ(1u, tgt_ranges.count(internal_hash("t0"))); + EXPECT_EQ(1u, tgt_ranges.count(internal_hash("t1"))); + EXPECT_EQ(1u, src_ranges.count(internal_hash("s0"))); + EXPECT_EQ(1u, tgt_ranges.count(internal_hash("t2"))); + EXPECT_EQ(1u, src_ranges.count(internal_hash("s1"))); + EXPECT_EQ(2u, tgt_ranges.count(internal_hash("t3"))); - auto r1 = tgt_ranges.equal_range("t0").first->second; - auto r2 = tgt_ranges.equal_range("t1").first->second; - auto r3 = src_ranges.equal_range("s0").first->second; - auto r4 = tgt_ranges.equal_range("t2").first->second; - auto r5 = src_ranges.equal_range("s1").first->second; + auto r1 = tgt_ranges.equal_range(internal_hash("t0")).first->second; + auto r2 = tgt_ranges.equal_range(internal_hash("t1")).first->second; + auto r3 = src_ranges.equal_range(internal_hash("s0")).first->second; + auto r4 = tgt_ranges.equal_range(internal_hash("t2")).first->second; + auto r5 = src_ranges.equal_range(internal_hash("s1")).first->second; - auto r6_range = tgt_ranges.equal_range("t3"); + auto r6_range = tgt_ranges.equal_range(internal_hash("t3")); auto r6_0 = r6_range.first; auto r6_1 = std::next(r6_range.first); if (r6_0->second.begin != 4u) { diff --git a/test/unit/test_label_resolution.cpp b/test/unit/test_label_resolution.cpp index 450428a743..2dacbdd607 100644 --- a/test/unit/test_label_resolution.cpp +++ b/test/unit/test_label_resolution.cpp @@ -23,7 +23,7 @@ TEST(test_cell_label_range, build) { // Test add_cell and add_label auto b0 = cell_label_range(); - EXPECT_THROW(b0.add_label("l0", {0u, 1u}), arb::arbor_internal_error); + EXPECT_THROW(b0.add_label(internal_hash("l0"), {0u, 1u}), arb::arbor_internal_error); EXPECT_TRUE(b0.sizes.empty()); EXPECT_TRUE(b0.labels.empty()); EXPECT_TRUE(b0.ranges.empty()); @@ -40,16 +40,16 @@ TEST(test_cell_label_range, build) { auto b2 = cell_label_range(); b2.add_cell(); - b2.add_label("l0", {0u, 1u}); - b2.add_label("l0", {3u, 13u}); - b2.add_label("l1", {0u, 5u}); + b2.add_label(internal_hash("l0"), {0u, 1u}); + b2.add_label(internal_hash("l0"), {3u, 13u}); + b2.add_label(internal_hash("l1"), {0u, 5u}); b2.add_cell(); b2.add_cell(); - b2.add_label("l2", {6u, 8u}); - b2.add_label("l3", {1u, 0u}); - b2.add_label("l4", {7u, 2u}); - b2.add_label("l4", {7u, 2u}); - b2.add_label("l2", {7u, 2u}); + b2.add_label(internal_hash("l2"), {6u, 8u}); + b2.add_label(internal_hash("l3"), {1u, 0u}); + b2.add_label(internal_hash("l4"), {7u, 2u}); + b2.add_label(internal_hash("l4"), {7u, 2u}); + b2.add_label(internal_hash("l2"), {7u, 2u}); EXPECT_EQ((ivec{3u, 0u, 5u}), b2.sizes); EXPECT_EQ(make_labels(svec{"l0", "l0", "l1", "l2", "l3", "l4", "l4", "l2"}), b2.labels); EXPECT_EQ((lvec{{0u, 1u}, {3u, 13u}, {0u, 5u}, {6u, 8u}, {1u, 0u}, {7u, 2u}, {7u, 2u}, {7u, 2u}}), b2.ranges); @@ -57,8 +57,8 @@ TEST(test_cell_label_range, build) { auto b3 = cell_label_range(); b3.add_cell(); - b3.add_label("r0", {0u, 9u}); - b3.add_label("r1", {10u, 10u}); + b3.add_label(internal_hash("r0"), {0u, 9u}); + b3.add_label(internal_hash("r1"), {10u, 10u}); b3.add_cell(); EXPECT_EQ((ivec{2u, 0u}), b3.sizes); EXPECT_EQ(make_labels From d8546fc6003629bcf6b961883fd3be2ec94c152a Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:40:49 +0200 Subject: [PATCH 60/84] Map back to labels. --- python/cells.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cells.cpp b/python/cells.cpp index f46bcf8aaf..c1c446b53c 100644 --- a/python/cells.cpp +++ b/python/cells.cpp @@ -829,7 +829,7 @@ void register_cells(pybind11::module& m) { [](arb::decor& dec) { std::vector> result; for (const auto& [k, v, t]: dec.placements()) { - result.emplace_back(to_string(k), v, t); + result.emplace_back(to_string(k), v, dec.tag_of(t)); } return result; }, From c94f1c2c239858c6ec85f39c9b5079da408bdf1e Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 19 Sep 2023 17:13:24 +0200 Subject: [PATCH 61/84] use hashes --- arbor/include/arbor/network.hpp | 19 +- arbor/network.cpp | 386 ++++++++++++++++---------------- arbor/network_impl.cpp | 277 +++++++---------------- arbor/network_impl.hpp | 30 +-- python/network.cpp | 8 +- test/unit/test_network.cpp | 115 +++++----- 6 files changed, 354 insertions(+), 481 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index 3ea02b40cd..d5717d4d3c 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -25,9 +25,20 @@ namespace arb { using network_hash_type = std::uint64_t; struct ARB_SYMBOL_VISIBLE network_site_info { + // network_site_info(cell_gid_type gid, + // cell_kind kind, + // hash_type label, + // mlocation location, + // mpoint global_location): + // gid(gid), + // kind(kind), + // label(label), + // location(location), + // global_location(global_location) {} + cell_gid_type gid; cell_kind kind; - cell_tag_type label; + hash_type label; mlocation location; mpoint global_location; @@ -60,7 +71,8 @@ class ARB_SYMBOL_VISIBLE network_selection; class ARB_SYMBOL_VISIBLE network_value { public: - using custom_func_type = std::function; + using custom_func_type = + std::function; network_value() { *this = network_value::scalar(0.0); } @@ -152,7 +164,8 @@ ARB_ARBOR_API inline network_value operator-(network_value a) { class ARB_SYMBOL_VISIBLE network_selection { public: - using custom_func_type = std::function; + using custom_func_type = + std::function; network_selection() { *this = network_selection::none(); } diff --git a/arbor/network.cpp b/arbor/network.cpp index 8295b51cf2..5c31b819cf 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -13,8 +13,10 @@ #include #include #include +#include #include "network_impl.hpp" +#include "util/hash.hpp" namespace arb { @@ -29,40 +31,49 @@ enum class network_seed : unsigned { value_truncated_normal = 380237, }; -double uniform_rand_from_key_pair(std::array seed, - network_hash_type key_a, - network_hash_type key_b) { - using rand_type = r123::Threefry2x64; - const rand_type::ctr_type seed_input = {{seed[0], seed[1]}}; +std::uint64_t location_hash(const mlocation& loc) { + const double l = static_cast(loc.branch) + loc.pos; + return *reinterpret_cast(&l); +} + +double uniform_rand(std::array seed, + const network_site_info& source, + const network_site_info& target) { + using rand_type = r123::Threefry4x64; + const rand_type::ctr_type seed_input = {{seed[0], seed[1], seed[2], seed[3]}}; - const rand_type::key_type key = {{std::min(key_a, key_b), std::max(key_a, key_b)}}; + const rand_type::key_type key = { + {source.gid, location_hash(source.location), target.gid, location_hash(target.location)}}; rand_type gen; return r123::u01(gen(seed_input, key)[0]); } -double normal_rand_from_key_pair(std::array seed, - std::uint64_t key_a, - std::uint64_t key_b) { - using rand_type = r123::Threefry2x64; - const rand_type::ctr_type seed_input = {{seed[0], seed[1]}}; +double normal_rand(std::array seed, + const network_site_info& source, + const network_site_info& target) { - const rand_type::key_type key = {{std::min(key_a, key_b), std::max(key_a, key_b)}}; + using rand_type = r123::Threefry4x64; + const rand_type::ctr_type seed_input = {{seed[0], seed[1], seed[2], seed[3]}}; + + const rand_type::key_type key = { + {source.gid, location_hash(source.location), target.gid, location_hash(target.location)}}; rand_type gen; const auto rand_num = gen(seed_input, key); + return r123::boxmuller(rand_num[0], rand_num[1]).x; } struct network_selection_all_impl: public network_selection_impl { - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return true; } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } @@ -71,16 +82,16 @@ struct network_selection_all_impl: public network_selection_impl { struct network_selection_none_impl: public network_selection_impl { - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return false; } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return false; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return false; } @@ -92,16 +103,16 @@ struct network_selection_source_cell_kind_impl: public network_selection_impl { explicit network_selection_source_cell_kind_impl(cell_kind k): select_kind(k) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return source.kind == select_kind; } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return kind == select_kind; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } @@ -122,16 +133,16 @@ struct network_selection_target_cell_kind_impl: public network_selection_impl { explicit network_selection_target_cell_kind_impl(cell_kind k): select_kind(k) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return target.kind == select_kind; } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return kind == select_kind; } @@ -148,57 +159,66 @@ struct network_selection_target_cell_kind_impl: public network_selection_impl { }; struct network_selection_source_label_impl: public network_selection_impl { - std::vector sorted_labels; + std::vector labels; + std::vector sorted_hashes; + + explicit network_selection_source_label_impl(std::vector labels_): + labels(std::move(labels_)) { + sorted_hashes.reserve(labels.size()); + for(const auto& l : labels) sorted_hashes.emplace_back(internal_hash(l)); - explicit network_selection_source_label_impl(std::vector labels): - sorted_labels(std::move(labels)) { - std::sort(sorted_labels.begin(), sorted_labels.end()); + std::sort(sorted_hashes.begin(), sorted_hashes.end()); } - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), source.label); + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return std::binary_search(sorted_hashes.begin(), sorted_hashes.end(), source.label); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), label); + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return std::binary_search(sorted_hashes.begin(), sorted_hashes.end(), label); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } void print(std::ostream& os) const override { os << "(source-label"; - for (const auto& l: sorted_labels) { os << " \"" << l << "\""; } + for (const auto& l: labels) { os << " \"" << l << "\""; } os << ")"; } }; struct network_selection_target_label_impl: public network_selection_impl { - std::vector sorted_labels; + std::vector labels; + std::vector sorted_hashes; - explicit network_selection_target_label_impl(std::vector labels): - sorted_labels(std::move(labels)) { - std::sort(sorted_labels.begin(), sorted_labels.end()); + + explicit network_selection_target_label_impl(std::vector labels_): + labels(std::move(labels_)) { + sorted_hashes.reserve(labels.size()); + for(const auto& l : labels) sorted_hashes.emplace_back(internal_hash(l)); + + std::sort(sorted_hashes.begin(), sorted_hashes.end()); } - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), target.label); + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return std::binary_search(sorted_hashes.begin(), sorted_hashes.end(), target.label); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { - return std::binary_search(sorted_labels.begin(), sorted_labels.end(), label); + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { + return std::binary_search(sorted_hashes.begin(), sorted_hashes.end(), label); } void print(std::ostream& os) const override { os << "(target-label"; - for (const auto& l: sorted_labels) { os << " \"" << l << "\""; } + for (const auto& l: labels) { os << " \"" << l << "\""; } os << ")"; } }; @@ -211,16 +231,16 @@ struct network_selection_source_cell_impl: public network_selection_impl { std::sort(sorted_gids.begin(), sorted_gids.end()); } - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), source.gid); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } @@ -239,17 +259,17 @@ struct network_selection_source_cell_range_impl: public network_selection_impl { gid_end(r.end), step(r.step) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return source.gid >= gid_begin && source.gid < gid_end && !((source.gid - gid_begin) % step); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } @@ -266,16 +286,16 @@ struct network_selection_target_cell_impl: public network_selection_impl { std::sort(sorted_gids.begin(), sorted_gids.end()); } - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), target.gid); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return std::binary_search(sorted_gids.begin(), sorted_gids.end(), gid); } @@ -294,17 +314,17 @@ struct network_selection_target_cell_range_impl: public network_selection_impl { gid_end(r.end), step(r.step) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return target.gid >= gid_begin && target.gid < gid_end && !((target.gid - gid_begin) % step); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return gid >= gid_begin && gid < gid_end && !((gid - gid_begin) % step); } @@ -321,8 +341,8 @@ struct network_selection_chain_impl: public network_selection_impl { std::sort(sorted_gids.begin(), sorted_gids.end()); } - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { if (gids.empty()) return false; // gids size always > 0 frome here on @@ -340,12 +360,12 @@ struct network_selection_chain_impl: public network_selection_impl { return false; } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return !sorted_gids.empty() && std::binary_search(sorted_gids.begin(), sorted_gids.end() - 1, gid); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return !sorted_gids.empty() && std::binary_search(sorted_gids.begin() + 1, sorted_gids.end(), gid); } @@ -365,8 +385,8 @@ struct network_selection_chain_range_impl: public network_selection_impl { gid_end(r.end), step(r.step) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { if (source.gid < gid_begin || source.gid >= gid_end || target.gid < gid_begin || target.gid >= gid_end) return false; @@ -374,13 +394,13 @@ struct network_selection_chain_range_impl: public network_selection_impl { return source.gid + step == target.gid && !((source.gid - gid_begin) % step); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { // Return false if outside range or if equal to last element, which cannot be a source if (gid < gid_begin || gid >= gid_end - 1) return false; return !((gid - gid_begin) % step); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { // Return false if outside range or if equal to first element, which cannot be a target if (gid <= gid_begin || gid >= gid_end) return false; return !((gid - gid_begin) % step); @@ -399,8 +419,8 @@ struct network_selection_reverse_chain_range_impl: public network_selection_impl gid_end(r.end), step(r.step) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { if (source.gid < gid_begin || source.gid >= gid_end || target.gid < gid_begin || target.gid >= gid_end) return false; @@ -408,13 +428,13 @@ struct network_selection_reverse_chain_range_impl: public network_selection_impl return target.gid + step == source.gid && !((source.gid - gid_begin) % step); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { // Return false if outside range or if equal to first element, which cannot be a source if (gid <= gid_begin || gid >= gid_end) return false; return !((gid - gid_begin) % step); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { // Return false if outside range or if equal to last element, which cannot be a target if (gid < gid_begin || gid >= gid_end - 1) return false; return !((gid - gid_begin) % step); @@ -431,17 +451,17 @@ struct network_selection_complement_impl: public network_selection_impl { explicit network_selection_complement_impl(std::shared_ptr s): selection(std::move(s)) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return !selection->select_connection(source, target); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; // cannot exclude any because source selection cannot be complemented without // knowing selection criteria. } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; // cannot exclude any because target selection cannot be complemented // without knowing selection criteria. } @@ -463,20 +483,20 @@ struct network_selection_named_impl: public network_selection_impl { explicit network_selection_named_impl(std::string name): selection_name(std::move(name)) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); return selection->select_connection(source, target); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); return selection->select_source(kind, gid, label); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { if (!selection) throw arbor_internal_error("Trying to use unitialized named network selection."); return selection->select_target(kind, gid, label); @@ -496,16 +516,16 @@ struct network_selection_named_impl: public network_selection_impl { }; struct network_selection_inter_cell_impl: public network_selection_impl { - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return source.gid != target.gid; } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } @@ -518,25 +538,16 @@ struct network_selection_custom_impl: public network_selection_impl { explicit network_selection_custom_impl(network_selection::custom_func_type f): func(std::move(f)) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { - return func({{source.gid, - source.kind, - cell_tag_type(source.label), - source.location, - source.global_location}, - {target.gid, - target.kind, - cell_tag_type(target.label), - target.location, - target.global_location}}); - } - - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { + return func(source, target); + } + + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } @@ -548,16 +559,16 @@ struct network_selection_distance_lt_impl: public network_selection_impl { explicit network_selection_distance_lt_impl(double d): d(d) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return distance(source.global_location, target.global_location) < d; } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } @@ -571,16 +582,16 @@ struct network_selection_distance_gt_impl: public network_selection_impl { explicit network_selection_distance_gt_impl(double d): d(d) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return distance(source.global_location, target.global_location) > d; } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } @@ -597,21 +608,21 @@ struct network_selection_random_impl: public network_selection_impl { seed(seed), p_value(std::move(p)) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { if (!probability) throw arbor_internal_error("Trying to use unitialized named network selection."); - const auto r = uniform_rand_from_key_pair( - {unsigned(network_seed::selection_random), seed}, source.hash, target.hash); + const auto r = uniform_rand( + {unsigned(network_seed::selection_random), seed, seed + 1, seed + 2}, source, target); const auto p = (probability->get(source, target)); return r < p; } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return true; } @@ -634,16 +645,16 @@ struct network_selection_intersect_impl: public network_selection_impl { left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return left->select_connection(source, target) && right->select_connection(source, target); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return left->select_source(kind, gid, label) && right->select_source(kind, gid, label); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return left->select_target(kind, gid, label) && right->select_target(kind, gid, label); } @@ -680,16 +691,16 @@ struct network_selection_join_impl: public network_selection_impl { left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return left->select_connection(source, target) || right->select_connection(source, target); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return left->select_target(kind, gid, label) || right->select_target(kind, gid, label); } @@ -724,16 +735,16 @@ struct network_selection_symmetric_difference_impl: public network_selection_imp left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return left->select_connection(source, target) ^ right->select_connection(source, target); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return left->select_source(kind, gid, label) || right->select_source(kind, gid, label); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return left->select_target(kind, gid, label) || right->select_target(kind, gid, label); } @@ -768,17 +779,17 @@ struct network_selection_difference_impl: public network_selection_impl { left(std::move(l)), right(std::move(r)) {} - bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const override { + bool select_connection(const network_site_info& source, + const network_site_info& target) const override { return left->select_connection(source, target) && !(right->select_connection(source, target)); } - bool select_source(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_source(cell_kind kind, cell_gid_type gid, hash_type label) const override { return left->select_source(kind, gid, label); } - bool select_target(cell_kind kind, cell_gid_type gid, std::string_view label) const override { + bool select_target(cell_kind kind, cell_gid_type gid, hash_type label) const override { return left->select_target(kind, gid, label); } @@ -809,8 +820,8 @@ struct network_value_scalar_impl: public network_value_impl { network_value_scalar_impl(double v): value(v) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { return value; } @@ -822,8 +833,8 @@ struct network_value_distance_impl: public network_value_impl { network_value_distance_impl(double s): scale(s) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { return scale * distance(source.global_location, target.global_location); } @@ -841,13 +852,13 @@ struct network_value_uniform_distribution_impl: public network_value_impl { throw std::invalid_argument("Uniform distribution: invalid range"); } - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { if (range[0] > range[1]) return range[1]; // random number between 0 and 1 - const auto rand_num = uniform_rand_from_key_pair( - {unsigned(network_seed::value_uniform), seed}, source.hash, target.hash); + const auto rand_num = uniform_rand( + {unsigned(network_seed::value_uniform), seed, seed + 1, seed + 2}, source, target); return (range[1] - range[0]) * rand_num + range[0]; } @@ -867,12 +878,13 @@ struct network_value_normal_distribution_impl: public network_value_impl { mean(mean_), std_deviation(std_deviation_) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { - return mean + std_deviation * - normal_rand_from_key_pair({unsigned(network_seed::value_normal), seed}, - source.hash, - target.hash); + double get(const network_site_info& source, + const network_site_info& target) const override { + return mean + + std_deviation * + normal_rand({unsigned(network_seed::value_normal), seed, seed + 1, seed + 2}, + source, + target); } void print(std::ostream& os) const override { @@ -898,21 +910,21 @@ struct network_value_truncated_normal_distribution_impl: public network_value_im throw std::invalid_argument("Truncated normal distribution: invalid range"); } - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { - - const auto src_hash = source.hash; - auto dest_hash = target.hash; + double get(const network_site_info& source, + const network_site_info& target) const override { double value = 0.0; + auto dynamic_seed = seed; do { value = - mean + std_deviation * normal_rand_from_key_pair( - {unsigned(network_seed::value_truncated_normal), seed}, - src_hash, - dest_hash); - ++dest_hash; + mean + std_deviation * normal_rand({unsigned(network_seed::value_truncated_normal), + dynamic_seed, + dynamic_seed + 1, + dynamic_seed + 2}, + source, + target); + ++dynamic_seed; } while (!(value > range[0] && value <= range[1])); return value; @@ -929,18 +941,9 @@ struct network_value_custom_impl: public network_value_impl { network_value_custom_impl(network_value::custom_func_type f): func(std::move(f)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { - return func({{source.gid, - source.kind, - cell_tag_type(source.label), - source.location, - source.global_location}, - {target.gid, - target.kind, - cell_tag_type(target.label), - target.location, - target.global_location}}); + double get(const network_site_info& source, + const network_site_info& target) const override { + return func(source, target); } void print(std::ostream& os) const override { os << "(custom-network-value)"; } @@ -954,8 +957,8 @@ struct network_value_named_impl: public network_value_impl { explicit network_value_named_impl(std::string name): value_name(std::move(name)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { if (!value) throw arbor_internal_error("Trying to use unitialized named network value."); return value->get(source, target); } @@ -981,8 +984,8 @@ struct network_value_add_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { return left->get(source, target) + right->get(source, target); } @@ -1008,8 +1011,8 @@ struct network_value_mul_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { return left->get(source, target) * right->get(source, target); } @@ -1035,8 +1038,8 @@ struct network_value_sub_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { return left->get(source, target) - right->get(source, target); } @@ -1062,8 +1065,8 @@ struct network_value_div_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { const auto v_right = right->get(source, target); if (!v_right) throw arbor_exception("network_value: division by 0."); return left->get(source, target) / right->get(source, target); @@ -1091,8 +1094,8 @@ struct network_value_max_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { return std::max(left->get(source, target), right->get(source, target)); } @@ -1118,8 +1121,8 @@ struct network_value_min_impl: public network_value_impl { left(std::move(l)), right(std::move(r)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { return std::min(left->get(source, target), right->get(source, target)); } @@ -1142,8 +1145,8 @@ struct network_value_exp_impl: public network_value_impl { network_value_exp_impl(std::shared_ptr v): value(std::move(v)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { return std::exp(value->get(source, target)); } @@ -1161,8 +1164,8 @@ struct network_value_log_impl: public network_value_impl { network_value_log_impl(std::shared_ptr v): value(std::move(v)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, + const network_site_info& target) const override { const auto v = value->get(source, target); if (v <= 0.0) throw arbor_exception("network_value: log of value <= 0.0."); return std::log(value->get(source, target)); @@ -1189,8 +1192,7 @@ struct network_value_if_else_impl: public network_value_impl { true_value(std::move(true_value)), false_value(std::move(false_value)) {} - double get(const network_full_site_info& source, - const network_full_site_info& target) const override { + double get(const network_site_info& source, const network_site_info& target) const override { if (cond->select_connection(source, target)) return true_value->get(source, target); return false_value->get(source, target); } diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 168a47be15..07a2fe2480 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -36,109 +36,26 @@ namespace arb { namespace { -// We only need minimal hash collisions and good spread over the hash range, because this will be -// used as input for random123, which then provides all desired hash properties. -// std::hash is implementation dependent, so we define our own for reproducibility. - -std::uint64_t simple_string_hash(const std::string_view& s) { - // use fnv1a hash algorithm - constexpr std::uint64_t prime = 1099511628211ull; - std::uint64_t h = 14695981039346656037ull; - - for (auto c: s) { - h ^= c; - h *= prime; - } - - return h; -} - -struct distributed_site_info { - cell_gid_type gid = 0; - cell_lid_type lid = 0; - cell_kind kind = cell_kind::cable; - cell_gid_type label_start_idx = 0; - mlocation location = mlocation(); - mpoint global_location = mpoint(); - network_hash_type hash = 0; -}; - -struct site_mapping { - std::vector sites; - std::vector labels; - std::unordered_map label_map; - - site_mapping() = default; - - inline std::size_t size() const { return sites.size(); } +struct network_site_info_extended { + network_site_info_extended(network_site_info info, cell_lid_type lid): + info(std::move(info)), + lid(lid) {} - void insert(const network_full_site_info& s) { - const auto insert_pair = label_map.insert({std::string(s.label), labels.size()}); - // append label if not contained in labels - if (insert_pair.second) { - labels.insert(labels.end(), s.label.begin(), s.label.end()); - labels.push_back('\0'); - } - sites.emplace_back(distributed_site_info{s.gid, - s.lid, - s.kind, - insert_pair.first->second, - s.location, - s.global_location, - s.hash}); - } - - void insert(const site_mapping& m) { - for (std::size_t idx = 0; idx < m.size(); ++idx) { this->insert(m.get_site(idx)); } - } - - network_full_site_info get_site(std::size_t idx) const { - const auto& s = this->sites.at(idx); - - network_full_site_info info; - info.gid = s.gid; - info.lid = s.lid; - info.kind = s.kind; - info.label = labels.data() + s.label_start_idx; - info.location = s.location; - info.global_location = s.global_location; - info.hash = s.hash; - - return info; - } + network_site_info info; + cell_lid_type lid; }; void push_back(const domain_decomposition& dom_dec, std::vector& vec, - const network_full_site_info& source, - const network_full_site_info& target, + const network_site_info_extended& source, + const network_site_info_extended& target, double weight, double delay) { - vec.emplace_back(connection{{source.gid, source.lid}, + vec.emplace_back(connection{{source.info.gid, source.lid}, target.lid, (float)weight, (float)delay, - dom_dec.index_on_domain(target.gid)}); -} - -void push_back(const domain_decomposition&, - std::vector& vec, - const network_full_site_info& source, - const network_full_site_info& target, - double weight, - double delay) { - vec.emplace_back(network_connection_info{network_site_info{source.gid, - source.kind, - std::string(source.label), - source.location, - source.global_location}, - network_site_info{target.gid, - target.kind, - std::string(target.label), - target.location, - target.global_location}, - weight, - delay}); + dom_dec.index_on_domain(target.info.gid)}); } template @@ -168,8 +85,8 @@ std::vector generate_network_connections(const recipe& rec, } const auto num_batches = ctx->thread_pool->get_num_threads(); - std::vector src_site_batches(num_batches); - std::vector dest_site_batches(num_batches); + std::vector> src_site_batches(num_batches); + std::vector> tgt_site_batches(num_batches); for (const auto& [kind, gids]: gids_by_kind) { const auto batch_size = (gids.size() + num_batches - 1) / num_batches; @@ -180,7 +97,7 @@ std::vector generate_network_connections(const recipe& rec, 0, cable_gids.size(), batch_size, ctx->thread_pool.get(), [&](int i) { const auto batch_idx = ctx->thread_pool->get_current_thread_id().value(); auto& src_sites = src_site_batches[batch_idx]; - auto& dest_sites = dest_site_batches[batch_idx]; + auto& tgt_sites = tgt_site_batches[batch_idx]; const auto gid = cable_gids[i]; const auto kind = rec.get_cell_kind(gid); // We need access to morphology, so the cell is create directly @@ -193,8 +110,8 @@ std::vector generate_network_connections(const recipe& rec, } auto lid_to_label = - [](const std::unordered_multimap& map, - cell_lid_type lid) -> const cell_tag_type& { + [](const std::unordered_multimap& map, + cell_lid_type lid) -> hash_type { for (const auto& [label, range]: map) { if (lid >= range.begin && lid < range.end) return label; } @@ -211,8 +128,10 @@ std::vector generate_network_connections(const recipe& rec, if (selection.select_target(cell_kind::cable, gid, label)) { const mpoint point = location_resolver.at(p_syn.loc); - dest_sites.insert( - {gid, p_syn.lid, cell_kind::cable, label, p_syn.loc, point}); + tgt_sites.emplace_back( + network_site_info{ + gid, cell_kind::cable, label, p_syn.loc, point}, + p_syn.lid); } } } @@ -222,8 +141,9 @@ std::vector generate_network_connections(const recipe& rec, const auto& label = lid_to_label(cell.detector_ranges(), p_det.lid); if (selection.select_source(cell_kind::cable, gid, label)) { const mpoint point = location_resolver.at(p_det.loc); - src_sites.insert( - {gid, p_det.lid, cell_kind::cable, label, p_det.loc, point}); + src_sites.emplace_back( + network_site_info{gid, cell_kind::cable, label, p_det.loc, point}, + p_det.lid); } } }); @@ -238,7 +158,7 @@ std::vector generate_network_connections(const recipe& rec, std::ignore = factory(gids, rec, sources, targets); auto& src_sites = src_site_batches[0]; - auto& dest_sites = dest_site_batches[0]; + auto& tgt_sites = tgt_site_batches[0]; std::size_t source_label_offset = 0; std::size_t target_label_offset = 0; @@ -246,18 +166,19 @@ std::vector generate_network_connections(const recipe& rec, const auto gid = gids[i]; const auto iso = rec.get_cell_isometry(gid); const auto point = iso.apply(mpoint{0.0, 0.0, 0.0, 0.0}); - const auto num_source_labels = sources.sizes().at(i); - const auto num_target_labels = targets.sizes().at(i); + const auto num_source_labels = sources.sizes.at(i); + const auto num_target_labels = targets.sizes.at(i); // Iterate over each source label for current gid for (std::size_t j = source_label_offset; j < source_label_offset + num_source_labels; ++j) { - const auto& label = sources.labels().at(j); - const auto& range = sources.ranges().at(j); + const auto& label = sources.labels.at(j); + const auto& range = sources.ranges.at(j); for (auto lid = range.begin; lid < range.end; ++lid) { if (selection.select_source(kind, gid, label)) { - src_sites.insert({gid, lid, kind, label, {0, 0.0}, point}); + src_sites.emplace_back( + network_site_info{gid, kind, label, mlocation{0, 0.0}, point}, lid); } } } @@ -266,11 +187,12 @@ std::vector generate_network_connections(const recipe& rec, for (std::size_t j = target_label_offset; j < target_label_offset + num_target_labels; ++j) { - const auto& label = targets.labels().at(j); - const auto& range = targets.ranges().at(j); + const auto& label = targets.labels.at(j); + const auto& range = targets.ranges.at(j); for (auto lid = range.begin; lid < range.end; ++lid) { if (selection.select_target(kind, gid, label)) { - dest_sites.insert({gid, lid, kind, label, {0, 0.0}, point}); + tgt_sites.emplace_back( + network_site_info{gid, kind, label, mlocation{0, 0.0}, point}, lid); } } } @@ -281,58 +203,44 @@ std::vector generate_network_connections(const recipe& rec, } } - site_mapping& src_sites = src_site_batches.front(); - // combine source batches - for (std::size_t batch_idx = 1; batch_idx < src_site_batches.size(); ++batch_idx) { + auto src_sites = std::move(src_site_batches.back()); + src_site_batches.pop_back(); + for (const auto& batch: src_site_batches) + src_sites.insert(src_sites.end(), batch.begin(), batch.end()); - for (std::size_t i = 0; i < src_site_batches[batch_idx].size(); ++i) { - src_sites.insert(src_site_batches[batch_idx].get_site(i)); - } - } + auto tgt_sites = std::move(tgt_site_batches.back()); + tgt_site_batches.pop_back(); + for (const auto& batch: tgt_site_batches) + tgt_sites.insert(tgt_sites.end(), batch.begin(), batch.end()); // create octree - std::vector network_dest_sites; - network_dest_sites.reserve(dest_site_batches[0].size() * num_batches); - for (const auto& dest_sites: dest_site_batches) { - for (std::size_t i = 0; i < dest_sites.size(); ++i) { - network_dest_sites.emplace_back(dest_sites.get_site(i)); - } - } const std::size_t max_depth = selection.max_distance().has_value() ? 10 : 1; const std::size_t max_leaf_size = 100; - spatial_tree local_dest_tree(max_depth, + spatial_tree local_tgt_tree(max_depth, max_leaf_size, - std::move(network_dest_sites), - [](const network_full_site_info& info) - -> spatial_tree::point_type { - return {info.global_location.x, info.global_location.y, info.global_location.z}; + std::move(tgt_sites), + [](const network_site_info_extended& ex) + -> spatial_tree::point_type { + return { + ex.info.global_location.x, ex.info.global_location.y, ex.info.global_location.z}; }); // select connections std::vector> connection_batches(num_batches); - auto sample_sources = [&](const util::range& source_range, - const util::range& label_range) { + auto sample_sources = [&](const util::range& source_range) { const auto batch_size = (source_range.size() + num_batches - 1) / num_batches; threading::parallel_for::apply( 0, source_range.size(), batch_size, ctx->thread_pool.get(), [&](int i) { - const auto& s = source_range[i]; + const auto& source = source_range[i]; const auto batch_idx = ctx->thread_pool->get_current_thread_id().value(); auto& connections = connection_batches[batch_idx]; - network_full_site_info source; - source.gid = s.gid; - source.lid = s.lid; - source.kind = s.kind; - source.label = label_range.data() + s.label_start_idx; - source.location = s.location; - source.global_location = s.global_location; - source.hash = s.hash; - - auto sample = [&](const network_full_site_info& target) { - if (selection.select_connection(source, target)) { - const auto w = weight.get(source, target); - const auto d = delay.get(source, target); + + auto sample = [&](const network_site_info_extended& target) { + if (selection.select_connection(source.info, target.info)) { + const auto w = weight.get(source.info, target.info); + const auto d = delay.get(source.info, target.info); push_back(dom_dec, connections, source, target, w, d); } @@ -340,23 +248,20 @@ std::vector generate_network_connections(const recipe& rec, if (selection.max_distance().has_value()) { const double d = selection.max_distance().value(); - local_dest_tree.bounding_box_for_each( - decltype(local_dest_tree)::point_type{source.global_location.x - d, - source.global_location.y - d, - source.global_location.z - d}, - decltype(local_dest_tree)::point_type{source.global_location.x + d, - source.global_location.y + d, - source.global_location.z + d}, + local_tgt_tree.bounding_box_for_each( + decltype(local_tgt_tree)::point_type{source.info.global_location.x - d, + source.info.global_location.y - d, + source.info.global_location.z - d}, + decltype(local_tgt_tree)::point_type{source.info.global_location.x + d, + source.info.global_location.y + d, + source.info.global_location.z + d}, sample); } - else { local_dest_tree.for_each(sample); } + else { local_tgt_tree.for_each(sample); } }); }; - distributed_for_each(sample_sources, - distributed, - util::range_view(src_sites.sites), - util::range_view(src_sites.labels)); + distributed_for_each(sample_sources, distributed, util::range_view(src_sites)); // concatenate auto connections = std::move(connection_batches.front()); @@ -369,57 +274,29 @@ std::vector generate_network_connections(const recipe& rec, } // namespace -network_full_site_info::network_full_site_info(cell_gid_type gid, - cell_lid_type lid, - cell_kind kind, - std::string_view label, - mlocation location, - mpoint global_location): - gid(gid), - lid(lid), - kind(kind), - label(std::move(label)), - location(location), - global_location(global_location) { - - std::uint64_t label_hash = simple_string_hash(this->label); - static_assert(sizeof(decltype(mlocation::pos)) == sizeof(std::uint64_t)); - std::uint64_t loc_pos_hash = *reinterpret_cast(&location.pos); - - // Initial seed. Changes will affect reproducibility of generated network connections. - constexpr std::uint64_t seed = 984293; - - using rand_type = r123::Threefry4x64; - const rand_type::ctr_type seed_input = {{seed, 2 * seed, 3 * seed, 4 * seed}}; - const rand_type::key_type key = {{gid, label_hash, location.branch, loc_pos_hash}}; - - rand_type gen; - hash = gen(seed_input, key)[0]; -} - std::vector generate_connections(const recipe& rec, const context& ctx, const domain_decomposition& dom_dec) { return generate_network_connections(rec, ctx, dom_dec); } -ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, - const context& ctx, - const domain_decomposition& dom_dec) { - auto connections = generate_network_connections(rec, ctx, dom_dec); +// ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, +// const context& ctx, +// const domain_decomposition& dom_dec) { +// auto connections = generate_network_connections(rec, ctx, dom_dec); - // generated connections may have different order each time due to multi-threading. - // Sort before returning to user for reproducibility. - std::sort(connections.begin(), connections.end()); +// // generated connections may have different order each time due to multi-threading. +// // Sort before returning to user for reproducibility. +// std::sort(connections.begin(), connections.end()); - return connections; -} +// return connections; +// } -ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec) { - auto ctx = arb::make_context(); - auto decomp = arb::partition_load_balance(rec, ctx); +// ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec) { +// auto ctx = arb::make_context(); +// auto decomp = arb::partition_load_balance(rec, ctx); - return generate_network_connections(rec, ctx, decomp); -} +// return generate_network_connections(rec, ctx, decomp); +// } } // namespace arb diff --git a/arbor/network_impl.hpp b/arbor/network_impl.hpp index b38a0b23c5..65c62767ef 100644 --- a/arbor/network_impl.hpp +++ b/arbor/network_impl.hpp @@ -21,34 +21,15 @@ namespace arb { -struct ARB_SYMBOL_VISIBLE network_full_site_info { - network_full_site_info() = default; - - network_full_site_info(cell_gid_type gid, - cell_lid_type lid, - cell_kind kind, - std::string_view label, - mlocation location, - mpoint global_location); - - cell_gid_type gid; - cell_lid_type lid; - cell_kind kind; - std::string_view label; - mlocation location; - mpoint global_location; - network_hash_type hash; -}; - struct network_selection_impl { virtual std::optional max_distance() const { return std::nullopt; } - virtual bool select_connection(const network_full_site_info& source, - const network_full_site_info& target) const = 0; + virtual bool select_connection(const network_site_info& source, + const network_site_info& target) const = 0; - virtual bool select_source(cell_kind kind, cell_gid_type gid, std::string_view tag) const = 0; + virtual bool select_source(cell_kind kind, cell_gid_type gid, hash_type tag) const = 0; - virtual bool select_target(cell_kind kind, cell_gid_type gid, std::string_view tag) const = 0; + virtual bool select_target(cell_kind kind, cell_gid_type gid, hash_type tag) const = 0; virtual void initialize(const network_label_dict& dict){}; @@ -64,8 +45,7 @@ inline std::shared_ptr thingify(network_selection s, } struct network_value_impl { - virtual double get(const network_full_site_info& source, - const network_full_site_info& target) const = 0; + virtual double get(const network_site_info& source, const network_site_info& target) const = 0; virtual void initialize(const network_label_dict& dict){}; diff --git a/python/network.cpp b/python/network.cpp index c267211535..50d3d1269f 100644 --- a/python/network.cpp +++ b/python/network.cpp @@ -56,11 +56,11 @@ void register_network(py::module& m) { network_selection .def_static("custom", [](arb::network_selection::custom_func_type func) { - return arb::network_selection::custom([=](const arb::network_connection_info& c) { + return arb::network_selection::custom([=](const arb::network_site_info& source, const arb::network_site_info& target) { return try_catch_pyexception( [&]() { pybind11::gil_scoped_acquire guard; - return func(c); + return func(source, target); }, "Python error already thrown"); }); @@ -75,11 +75,11 @@ void register_network(py::module& m) { network_value .def_static("custom", [](arb::network_value::custom_func_type func) { - return arb::network_value::custom([=](const arb::network_connection_info& c) { + return arb::network_value::custom([=](const arb::network_site_info& source, const arb::network_site_info& target) { return try_catch_pyexception( [&]() { pybind11::gil_scoped_acquire guard; - return func(c); + return func(source, target); }, "Python error already thrown"); }); diff --git a/test/unit/test_network.cpp b/test/unit/test_network.cpp index 09e80b1563..f04669333d 100644 --- a/test/unit/test_network.cpp +++ b/test/unit/test_network.cpp @@ -3,6 +3,7 @@ #include #include "network_impl.hpp" +#include "util/hash.hpp" #include #include @@ -10,36 +11,36 @@ using namespace arb; namespace { -std::vector test_sites = { - {0, 0, cell_kind::cable, "a", {1, 0.5}, {0.0, 0.0, 0.0}}, - {1, 0, cell_kind::benchmark, "b", {0, 0.0}, {1.0, 0.0, 0.0}}, - {2, 0, cell_kind::lif, "c", {0, 0.0}, {2.0, 0.0, 0.0}}, - {3, 0, cell_kind::spike_source, "d", {0, 0.0}, {3.0, 0.0, 0.0}}, - {4, 0, cell_kind::cable, "e", {0, 0.2}, {4.0, 0.0, 0.0}}, - {5, 0, cell_kind::cable, "f", {5, 0.1}, {5.0, 0.0, 0.0}}, - {6, 0, cell_kind::cable, "g", {4, 0.3}, {6.0, 0.0, 0.0}}, - {7, 0, cell_kind::cable, "h", {0, 1.0}, {7.0, 0.0, 0.0}}, - {9, 0, cell_kind::cable, "i", {0, 0.1}, {12.0, 3.0, 4.0}}, - - {10, 0, cell_kind::cable, "a", {0, 0.1}, {12.0, 15.0, 16.0}}, - {10, 1, cell_kind::cable, "b", {1, 0.1}, {13.0, 15.0, 16.0}}, - {10, 2, cell_kind::cable, "c", {1, 0.5}, {14.0, 15.0, 16.0}}, - {10, 3, cell_kind::cable, "d", {1, 1.0}, {15.0, 15.0, 16.0}}, - {10, 4, cell_kind::cable, "e", {2, 0.1}, {16.0, 15.0, 16.0}}, - {10, 5, cell_kind::cable, "f", {3, 0.1}, {16.0, 16.0, 16.0}}, - {10, 6, cell_kind::cable, "g", {4, 0.1}, {12.0, 17.0, 16.0}}, - {10, 7, cell_kind::cable, "h", {5, 0.1}, {12.0, 18.0, 16.0}}, - {10, 8, cell_kind::cable, "i", {6, 0.1}, {12.0, 19.0, 16.0}}, - - {11, 0, cell_kind::cable, "abcd", {0, 0.1}, {-2.0, -5.0, 3.0}}, - {11, 1, cell_kind::cable, "cabd", {1, 0.2}, {-2.1, -5.0, 3.0}}, - {11, 2, cell_kind::cable, "cbad", {1, 0.3}, {-2.2, -5.0, 3.0}}, - {11, 3, cell_kind::cable, "acbd", {1, 1.0}, {-2.3, -5.0, 3.0}}, - {11, 4, cell_kind::cable, "bacd", {2, 0.2}, {-2.4, -5.0, 3.0}}, - {11, 5, cell_kind::cable, "bcad", {3, 0.3}, {-2.5, -5.0, 3.0}}, - {11, 6, cell_kind::cable, "dabc", {4, 0.4}, {-2.6, -5.0, 3.0}}, - {11, 7, cell_kind::cable, "dbca", {5, 0.5}, {-2.7, -5.0, 3.0}}, - {11, 8, cell_kind::cable, "dcab", {6, 0.6}, {-2.8, -5.0, 3.0}}, +std::vector test_sites = { + {0, cell_kind::cable, internal_hash("a"), {1, 0.5}, {0.0, 0.0, 0.0}}, + {1, cell_kind::benchmark, internal_hash("b"), {0, 0.0}, {1.0, 0.0, 0.0}}, + {2, cell_kind::lif, internal_hash("c"), {0, 0.0}, {2.0, 0.0, 0.0}}, + {3, cell_kind::spike_source, internal_hash("d"), {0, 0.0}, {3.0, 0.0, 0.0}}, + {4, cell_kind::cable, internal_hash("e"), {0, 0.2}, {4.0, 0.0, 0.0}}, + {5, cell_kind::cable, internal_hash("f"), {5, 0.1}, {5.0, 0.0, 0.0}}, + {6, cell_kind::cable, internal_hash("g"), {4, 0.3}, {6.0, 0.0, 0.0}}, + {7, cell_kind::cable, internal_hash("h"), {0, 1.0}, {7.0, 0.0, 0.0}}, + {9, cell_kind::cable, internal_hash("i"), {0, 0.1}, {12.0, 3.0, 4.0}}, + + {10, cell_kind::cable, internal_hash("a"), {0, 0.1}, {12.0, 15.0, 16.0}}, + {10, cell_kind::cable, internal_hash("b"), {1, 0.1}, {13.0, 15.0, 16.0}}, + {10, cell_kind::cable, internal_hash("c"), {1, 0.5}, {14.0, 15.0, 16.0}}, + {10, cell_kind::cable, internal_hash("d"), {1, 1.0}, {15.0, 15.0, 16.0}}, + {10, cell_kind::cable, internal_hash("e"), {2, 0.1}, {16.0, 15.0, 16.0}}, + {10, cell_kind::cable, internal_hash("f"), {3, 0.1}, {16.0, 16.0, 16.0}}, + {10, cell_kind::cable, internal_hash("g"), {4, 0.1}, {12.0, 17.0, 16.0}}, + {10, cell_kind::cable, internal_hash("h"), {5, 0.1}, {12.0, 18.0, 16.0}}, + {10, cell_kind::cable, internal_hash("i"), {6, 0.1}, {12.0, 19.0, 16.0}}, + + {11, cell_kind::cable, internal_hash("abcd"), {0, 0.1}, {-2.0, -5.0, 3.0}}, + {11, cell_kind::cable, internal_hash("cabd"), {1, 0.2}, {-2.1, -5.0, 3.0}}, + {11, cell_kind::cable, internal_hash("cbad"), {1, 0.3}, {-2.2, -5.0, 3.0}}, + {11, cell_kind::cable, internal_hash("acbd"), {1, 1.0}, {-2.3, -5.0, 3.0}}, + {11, cell_kind::cable, internal_hash("bacd"), {2, 0.2}, {-2.4, -5.0, 3.0}}, + {11, cell_kind::cable, internal_hash("bcad"), {3, 0.3}, {-2.5, -5.0, 3.0}}, + {11, cell_kind::cable, internal_hash("dabc"), {4, 0.4}, {-2.6, -5.0, 3.0}}, + {11, cell_kind::cable, internal_hash("dbca"), {5, 0.5}, {-2.7, -5.0, 3.0}}, + {11, cell_kind::cable, internal_hash("dcab"), {6, 0.6}, {-2.8, -5.0, 3.0}}, }; } @@ -107,15 +108,15 @@ TEST(network_selection, source_label) { const auto s = thingify(network_selection::source_label({"b", "e"}), network_label_dict()); for (const auto& site: test_sites) { - EXPECT_EQ(site.label == "b" || site.label == "e", + EXPECT_EQ(site.label == internal_hash("b") || site.label == internal_hash("e"), s->select_source(site.kind, site.gid, site.label)); EXPECT_TRUE(s->select_target(site.kind, site.gid, site.label)); } for (const auto& source: test_sites) { for (const auto& target: test_sites) { - EXPECT_EQ( - source.label == "b" || source.label == "e", s->select_connection(source, target)); + EXPECT_EQ(source.label == internal_hash("b") || source.label == internal_hash("e"), + s->select_connection(source, target)); } } } @@ -124,15 +125,15 @@ TEST(network_selection, target_label) { const auto s = thingify(network_selection::target_label({"b", "e"}), network_label_dict()); for (const auto& site: test_sites) { - EXPECT_EQ(site.label == "b" || site.label == "e", + EXPECT_EQ(site.label == internal_hash("b") || site.label == internal_hash("e"), s->select_target(site.kind, site.gid, site.label)); EXPECT_TRUE(s->select_source(site.kind, site.gid, site.label)); } for (const auto& source: test_sites) { for (const auto& target: test_sites) { - EXPECT_EQ( - target.label == "b" || target.label == "e", s->select_connection(source, target)); + EXPECT_EQ(target.label == internal_hash("b") || target.label == internal_hash("e"), + s->select_connection(source, target)); } } } @@ -428,10 +429,10 @@ TEST(network_selection, random_seed) { TEST(network_selection, random_reproducibility) { const auto s = thingify(network_selection::random(42, 0.5), network_label_dict()); - std::vector sites = { - {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, - {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, - {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, + std::vector sites = { + {0, cell_kind::cable, internal_hash("a"), {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, cell_kind::cable, internal_hash("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, cell_kind::benchmark, internal_hash("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, }; std::vector ref = {1, 1, 0, 1, 1, 0, 0, 0, 0}; @@ -445,8 +446,8 @@ TEST(network_selection, random_reproducibility) { } TEST(network_selection, custom) { - auto inter_cell_func = [](const network_connection_info& c) { - return c.source.gid != c.target.gid; + auto inter_cell_func = [](const network_site_info& source, const network_site_info& target) { + return source.gid != target.gid; }; const auto s = thingify(network_selection::custom(inter_cell_func), network_label_dict()); const auto s_ref = thingify(network_selection::inter_cell(), network_label_dict()); @@ -552,10 +553,10 @@ TEST(network_value, uniform_distribution_reproducibility) { const auto v = thingify(network_value::uniform_distribution(42, {-5.0, 3.0}), network_label_dict()); - std::vector sites = { - {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, - {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, - {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, + std::vector sites = { + {0, cell_kind::cable, internal_hash("a"), {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, cell_kind::cable, internal_hash("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, cell_kind::benchmark, internal_hash("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, }; std::vector ref = { 1.08007184307616289, @@ -598,7 +599,7 @@ TEST(network_value, normal_distribution) { sample_dev = std::sqrt(sample_dev / (test_sites.size() * test_sites.size())); EXPECT_NEAR(sample_mean, mean, 1e-1); - EXPECT_NEAR(sample_dev, std_dev, 1e-1); + EXPECT_NEAR(sample_dev, std_dev, 1.5e-1); } TEST(network_value, normal_distribution_reproducibility) { @@ -607,10 +608,10 @@ TEST(network_value, normal_distribution_reproducibility) { const auto v = thingify(network_value::normal_distribution(42, mean, std_dev), network_label_dict()); - std::vector sites = { - {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, - {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, - {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, + std::vector sites = { + {0, cell_kind::cable, internal_hash("a"), {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, cell_kind::cable, internal_hash("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, cell_kind::benchmark, internal_hash("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, }; std::vector ref = { 9.27330832850693909, @@ -671,10 +672,10 @@ TEST(network_value, truncated_normal_distribution_reproducibility) { network_value::truncated_normal_distribution(42, mean, std_dev, {lower_bound, upper_bound}), network_label_dict()); - std::vector sites = { - {0, 0, cell_kind::cable, "a", {1, 0.5}, {1.2, 2.3, 3.4}}, - {0, 1, cell_kind::cable, "b", {0, 0.1}, {-1.0, 0.5, 0.7}}, - {1, 0, cell_kind::benchmark, "c", {0, 0.0}, {20.5, -59.5, 5.0}}, + std::vector sites = { + {0, cell_kind::cable, internal_hash("a"), {1, 0.5}, {1.2, 2.3, 3.4}}, + {0, cell_kind::cable, internal_hash("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, + {1, cell_kind::benchmark, internal_hash("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, }; std::vector ref = { 2.81708378066100629, @@ -698,8 +699,8 @@ TEST(network_value, truncated_normal_distribution_reproducibility) { } TEST(network_value, custom) { - auto func = [](const network_connection_info& c) { - return c.source.global_location.x + c.target.global_location.x; + auto func = [](const network_site_info& source, const network_site_info& target) { + return source.global_location.x + target.global_location.x; }; const auto v = thingify(network_value::custom(func), network_label_dict()); From cf2a5f89e33be74bf4ce993ff1fd27dc3da3a7f5 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Tue, 17 Oct 2023 20:59:38 +0200 Subject: [PATCH 62/84] Add internal_hash to hash_def and integrate w/ hash_value. --- arbor/cable_cell_param.cpp | 2 +- arbor/include/arbor/common_types.hpp | 2 +- arbor/include/arbor/util/hash_def.hpp | 38 ++++++++++++++++++++++++--- arbor/label_resolution.cpp | 6 ++--- arbor/label_resolution.hpp | 2 +- arbor/util/hash.hpp | 35 ------------------------ test/unit/test_cable_cell.cpp | 2 +- test/unit/test_label_resolution.cpp | 2 +- 8 files changed, 42 insertions(+), 47 deletions(-) delete mode 100644 arbor/util/hash.hpp diff --git a/arbor/cable_cell_param.cpp b/arbor/cable_cell_param.cpp index 55c4cdd079..6f47a0cca8 100644 --- a/arbor/cable_cell_param.cpp +++ b/arbor/cable_cell_param.cpp @@ -10,7 +10,7 @@ #include #include -#include "util/hash.hpp" +#include #include "util/maputil.hpp" #include "util/strprintf.hpp" diff --git a/arbor/include/arbor/common_types.hpp b/arbor/include/arbor/common_types.hpp index 1d5fa262e8..529f700a3a 100644 --- a/arbor/include/arbor/common_types.hpp +++ b/arbor/include/arbor/common_types.hpp @@ -21,7 +21,7 @@ namespace arb { // Internal hashes use this 64bit id -using hash_type = std::uint64_t; +using hash_type = std::size_t; // For identifying cells globally. diff --git a/arbor/include/arbor/util/hash_def.hpp b/arbor/include/arbor/util/hash_def.hpp index fc6770cc41..304fe4c58d 100644 --- a/arbor/include/arbor/util/hash_def.hpp +++ b/arbor/include/arbor/util/hash_def.hpp @@ -17,16 +17,46 @@ */ #include -#include +#include // Helpers for forming hash values of compounds objects. namespace arb { -inline std::size_t hash_value_combine(std::size_t n) { - return n; +// Non-cryptographic hash function for mapping strings to internal +// identifiers. Concretely, FNV-1a hash function taken from +// +// http://www.isthe.com/chongo/tech/comp/fnv/index.html +// +// NOTE: It may be worth it considering different hash functions in +// the future that have better characteristic, xxHash or Murmur +// look interesting but are more complex and likely require adding +// external dependencies. +// NOTE: this is the obligatory comment on a better hash function +// that will be here until the end of time. + +template +inline constexpr std::size_t internal_hash(T&& data) { + if constexpr (std::is_convertible_v) { + constexpr std::size_t prime = 0x100000001b3; + constexpr std::size_t offset_basis=0xcbf29ce484222325; + + std::size_t hash = offset_basis; + + for (uint8_t byte: std::string_view{data}) { + hash = hash ^ byte; + hash = hash * prime; + } + + return hash; + } else { + return std::hash>{}(data); + } } +inline +std::size_t hash_value_combine(std::size_t n) { return n; } + template std::size_t hash_value_combine(std::size_t n, std::size_t m, T... tail) { constexpr std::size_t prime2 = 54517; @@ -36,7 +66,7 @@ std::size_t hash_value_combine(std::size_t n, std::size_t m, T... tail) { template std::size_t hash_value(const T&... t) { constexpr std::size_t prime1 = 93481; - return hash_value_combine(prime1, std::hash{}(t)...); + return hash_value_combine(prime1, internal_hash(t)...); } } diff --git a/arbor/label_resolution.cpp b/arbor/label_resolution.cpp index 5506a79e6c..14786144f2 100644 --- a/arbor/label_resolution.cpp +++ b/arbor/label_resolution.cpp @@ -1,15 +1,15 @@ #include #include -#include +#include #include #include #include #include +#include #include "label_resolution.hpp" #include "util/partition.hpp" -#include "util/rangeutil.hpp" #include "util/span.hpp" namespace arb { @@ -22,7 +22,7 @@ cell_label_range::cell_label_range(std::vector size_vec, { std::transform(label_vec.begin(), label_vec.end(), std::back_inserter(labels), - internal_hash); + internal_hash); arb_assert(check_invariant()); }; diff --git a/arbor/label_resolution.hpp b/arbor/label_resolution.hpp index 92d0270fe1..508c9817e9 100644 --- a/arbor/label_resolution.hpp +++ b/arbor/label_resolution.hpp @@ -9,7 +9,7 @@ #include #include "util/partition.hpp" -#include "util/hash.hpp" +#include namespace arb { diff --git a/arbor/util/hash.hpp b/arbor/util/hash.hpp deleted file mode 100644 index 988fec2a12..0000000000 --- a/arbor/util/hash.hpp +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include -#include - -#include - -namespace arb { -// Non-cryptographic hash function for mapping strings to internal -// identifiers. Concretely, FNV-1a hash function taken from -// -// http://www.isthe.com/chongo/tech/comp/fnv/index.html -// -// NOTE: It may be worth it considering different hash functions in -// the future that have better characteristic, xxHash or Murmur -// look interesting but are more complex and likely require adding -// external dependencies. -// NOTE: this is the obligatory comment on a better hash function -// that will be here until the end of time. - -constexpr hash_type offset_basis = 0xcbf29ce484222325; -constexpr hash_type prime = 0x100000001b3; - -constexpr hash_type internal_hash(std::string_view data) { - hash_type hash = offset_basis; - - for (uint8_t byte: data) { - hash = hash ^ byte; - hash = hash * prime; - } - - return hash; -} - -} diff --git a/test/unit/test_cable_cell.cpp b/test/unit/test_cable_cell.cpp index c1066e9827..2757de30df 100644 --- a/test/unit/test_cable_cell.cpp +++ b/test/unit/test_cable_cell.cpp @@ -1,6 +1,6 @@ #include #include "../common_cells.hpp" -#include "util/hash.hpp" +#include #include #include diff --git a/test/unit/test_label_resolution.cpp b/test/unit/test_label_resolution.cpp index 2dacbdd607..d4f5ade50c 100644 --- a/test/unit/test_label_resolution.cpp +++ b/test/unit/test_label_resolution.cpp @@ -12,7 +12,7 @@ std::vector make_labels(const std::vector& ls) { std::vector res; std::transform(ls.begin(), ls.end(), std::back_inserter(res), - internal_hash); + internal_hash); return res; } From 682a6aacf8812a1b0212bbb89c4f89901e4a5e59 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Wed, 22 Nov 2023 10:20:12 +0100 Subject: [PATCH 63/84] Add some testing, treat pointers. --- arbor/include/arbor/util/hash_def.hpp | 51 +++++++++++++++++++-------- test/unit/CMakeLists.txt | 1 + 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/arbor/include/arbor/util/hash_def.hpp b/arbor/include/arbor/util/hash_def.hpp index 304fe4c58d..c321c945a9 100644 --- a/arbor/include/arbor/util/hash_def.hpp +++ b/arbor/include/arbor/util/hash_def.hpp @@ -18,6 +18,9 @@ #include #include +#include + +#include // Helpers for forming hash values of compounds objects. @@ -37,36 +40,56 @@ namespace arb { template inline constexpr std::size_t internal_hash(T&& data) { + using D = std::decay_t; + constexpr std::size_t prime = 0x100000001b3; + constexpr std::size_t offset_basis = 0xcbf29ce484222325; + static_assert(!std::is_pointer_v || std::is_same_v || std::is_convertible_v, + "Pointer types except void* will not be hashed."); if constexpr (std::is_convertible_v) { - constexpr std::size_t prime = 0x100000001b3; - constexpr std::size_t offset_basis=0xcbf29ce484222325; - std::size_t hash = offset_basis; - for (uint8_t byte: std::string_view{data}) { hash = hash ^ byte; hash = hash * prime; } - return hash; - } else { - return std::hash>{}(data); } + if constexpr (std::is_integral_v) { + unsigned long long bytes = data; + std::size_t hash = offset_basis; + for (int ix = 0; ix < sizeof(data); ++ix) { + uint8_t byte = bytes & 255; + bytes >>= 8; + hash = hash ^ byte; + hash = hash * prime; + } + return hash; + } + if constexpr (std::is_pointer_v) { + unsigned long long bytes = reinterpret_cast(data); + std::size_t hash = offset_basis; + for (int ix = 0; ix < sizeof(data); ++ix) { + uint8_t byte = bytes & 255; + bytes >>= 8; + hash = hash ^ byte; + hash = hash * prime; + } + return hash; + } + return std::hash{}(data); } inline std::size_t hash_value_combine(std::size_t n) { return n; } -template -std::size_t hash_value_combine(std::size_t n, std::size_t m, T... tail) { - constexpr std::size_t prime2 = 54517; - return hash_value_combine(prime2*n + m, tail...); +template +std::size_t hash_value_combine(std::size_t n, const T& head, const Ts&... tail) { + constexpr std::size_t prime = 54517; + return hash_value_combine(prime*n + internal_hash(head), tail...); } template -std::size_t hash_value(const T&... t) { - constexpr std::size_t prime1 = 93481; - return hash_value_combine(prime1, internal_hash(t)...); +std::size_t hash_value(const T&... ts) { + return hash_value_combine(0, ts...); } } diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 1a85d77360..b68a6de6d5 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -80,6 +80,7 @@ set(unit_sources test_forest.cpp test_fvm_layout.cpp test_fvm_lowered.cpp + test_hash.cpp test_diffusion.cpp test_iexpr.cpp test_index.cpp From ef792c93e45b2eb73b7e48324fae0939de1fc174 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Wed, 22 Nov 2023 14:43:42 +0100 Subject: [PATCH 64/84] Merge part two. --- arbor/fvm_lowered_cell_impl.hpp | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index fdfb504784..94716b21c5 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -327,8 +327,37 @@ add_labels(cell_label_range& clr, const cable_cell::lid_range_map& ranges) { return count; } -fvm_initialization_data fvm_lowered_cell_impl::initialize(const std::vector& gids, - const recipe& rec) { +template void +fvm_lowered_cell_impl::add_probes(const std::vector& gids, + const std::vector& cells, + const recipe& rec, + const fvm_cv_discretization& D, + const std::unordered_map& mechptr_by_name, + const fvm_mechanism_data& mech_data, + const std::vector& target_handles, + probe_association_map& probe_map) { + auto ncell = gids.size(); + + std::vector probe_data; + for (auto cell_idx: util::make_span(ncell)) { + cell_gid_type gid = gids[cell_idx]; + const auto& rec_probes = rec.get_probes(gid); + for (const auto& pi: rec_probes) { + resolve_probe_address(probe_data, cells, cell_idx, pi.address, D, mech_data, target_handles, mechptr_by_name); + if (!probe_data.empty()) { + cell_address_type addr{gid, pi.tag}; + if (probe_map.count(addr)) throw dup_cell_probe(cell_kind::cable, gid, pi.tag); + for (auto& data: probe_data) { + probe_map.insert(addr, std::move(data)); + } + } + } + } +} + +template fvm_initialization_data +fvm_lowered_cell_impl::initialize(const std::vector& gids, + const recipe& rec) { using std::any_cast; using util::count_along; using util::make_span; From 97ca7a4e6ef2e15715b3812d48c4ad5abac28c3a Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Wed, 22 Nov 2023 14:45:29 +0100 Subject: [PATCH 65/84] The missing test. --- test/unit/test_hash.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 test/unit/test_hash.cpp diff --git a/test/unit/test_hash.cpp b/test/unit/test_hash.cpp new file mode 100644 index 0000000000..444830cc42 --- /dev/null +++ b/test/unit/test_hash.cpp @@ -0,0 +1,24 @@ +#include + +#include + +#include + +TEST(hash, string_eq) { + ASSERT_EQ(arb::hash_value("foobar"), arb::hash_value(std::string{"foobar"})); + ASSERT_EQ(arb::hash_value("foobar"), arb::internal_hash("foobar")); + ASSERT_NE(arb::hash_value("foobar"), arb::internal_hash("barfoo")); +} + +TEST(hash, doesnt_compile) { + double foo = 42; + // Sadly we cannot check static assertions... this shoudln't compile + // EXPECT_ANY_THROW(arb::hash_value(&foo)); + // this should + arb::hash_value((void*) &foo); +} + +// check that we do not fall into the trap of the STL... +TEST(hash, integral_is_not_identity) { + ASSERT_NE(arb::hash_value(42), 42); +} From c583b2f25699e5c50fbc29075053a6b22a47f3d8 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Wed, 29 Nov 2023 09:46:40 +0100 Subject: [PATCH 66/84] Shuffle internal hash and combine into a detail namespace. --- arbor/benchmark_cell_group.cpp | 5 ++- arbor/cable_cell_param.cpp | 5 +-- arbor/include/arbor/cable_cell_param.hpp | 2 -- arbor/include/arbor/util/hash_def.hpp | 12 +++---- arbor/label_resolution.cpp | 8 ++--- arbor/label_resolution.hpp | 3 +- arbor/lif_cell_group.cpp | 4 +-- arbor/spike_source_cell_group.cpp | 4 +-- test/unit/test_cable_cell.cpp | 27 ++++++++------- test/unit/test_fvm_lowered.cpp | 44 +++++++++++------------- test/unit/test_hash.cpp | 4 +-- test/unit/test_label_resolution.cpp | 24 ++++++------- 12 files changed, 65 insertions(+), 77 deletions(-) diff --git a/arbor/benchmark_cell_group.cpp b/arbor/benchmark_cell_group.cpp index d2ac067182..36887fc8ea 100644 --- a/arbor/benchmark_cell_group.cpp +++ b/arbor/benchmark_cell_group.cpp @@ -1,5 +1,4 @@ #include -#include #include #include @@ -40,8 +39,8 @@ benchmark_cell_group::benchmark_cell_group(const std::vector& gid for (const auto& c: cells_) { cg_sources.add_cell(); cg_targets.add_cell(); - cg_sources.add_label(internal_hash(c.source), {0, 1}); - cg_targets.add_label(internal_hash(c.target), {0, 1}); + cg_sources.add_label(hash_value(c.source), {0, 1}); + cg_targets.add_label(hash_value(c.target), {0, 1}); } benchmark_cell_group::reset(); diff --git a/arbor/cable_cell_param.cpp b/arbor/cable_cell_param.cpp index 6f47a0cca8..e3b425de7a 100644 --- a/arbor/cable_cell_param.cpp +++ b/arbor/cable_cell_param.cpp @@ -1,7 +1,4 @@ -#include #include -#include -#include #include #include #include @@ -122,7 +119,7 @@ decor& decor::paint(region where, paintable what) { } decor& decor::place(locset where, placeable what, cell_tag_type label) { - auto hash = internal_hash(label); + auto hash = hash_value(label); if (hashes_.count(hash) && hashes_.at(hash) != label) { throw arbor_internal_error{util::strprintf("Hash collision {} ./. {}", label, hashes_.at(hash))}; } diff --git a/arbor/include/arbor/cable_cell_param.hpp b/arbor/include/arbor/cable_cell_param.hpp index 62ed0532ed..59bb5bb014 100644 --- a/arbor/include/arbor/cable_cell_param.hpp +++ b/arbor/include/arbor/cable_cell_param.hpp @@ -1,12 +1,10 @@ #pragma once #include -#include #include #include #include #include -#include #include #include diff --git a/arbor/include/arbor/util/hash_def.hpp b/arbor/include/arbor/util/hash_def.hpp index c321c945a9..933a3237e2 100644 --- a/arbor/include/arbor/util/hash_def.hpp +++ b/arbor/include/arbor/util/hash_def.hpp @@ -20,12 +20,11 @@ #include #include -#include - // Helpers for forming hash values of compounds objects. - namespace arb { +namespace detail { + // Non-cryptographic hash function for mapping strings to internal // identifiers. Concretely, FNV-1a hash function taken from // @@ -87,10 +86,11 @@ std::size_t hash_value_combine(std::size_t n, const T& head, const Ts&... tail) return hash_value_combine(prime*n + internal_hash(head), tail...); } -template -std::size_t hash_value(const T&... ts) { - return hash_value_combine(0, ts...); } + +// User facing API +template +std::size_t hash_value(const T&... ts) { return detail::hash_value_combine(0, ts...); } } #define ARB_DEFINE_HASH(type,...)\ diff --git a/arbor/label_resolution.cpp b/arbor/label_resolution.cpp index 14786144f2..29c8b76454 100644 --- a/arbor/label_resolution.cpp +++ b/arbor/label_resolution.cpp @@ -22,7 +22,7 @@ cell_label_range::cell_label_range(std::vector size_vec, { std::transform(label_vec.begin(), label_vec.end(), std::back_inserter(labels), - internal_hash); + hash_value); arb_assert(check_invariant()); }; @@ -93,12 +93,12 @@ lid_hopefully label_resolution_map::range_set::at(unsigned idx) const { } const label_resolution_map::range_set& label_resolution_map::at(cell_gid_type gid, const cell_tag_type& tag) const { - return map.at(gid).at(internal_hash(tag)); + return map.at(gid).at(hash_value(tag)); } std::size_t label_resolution_map::count(cell_gid_type gid, const cell_tag_type& tag) const { if (!map.count(gid)) return 0u; - return map.at(gid).count(internal_hash(tag)); + return map.at(gid).count(hash_value(tag)); } label_resolution_map::label_resolution_map(const cell_labels_and_gids& clg) { @@ -216,7 +216,7 @@ lid_hopefully update_state(resolver::state_variant& v, cell_lid_type resolver::resolve(cell_gid_type gid, const cell_local_label_type& label) { const auto& [tag, pol] = label; - auto hash = internal_hash(tag); + auto hash = hash_value(tag); if (!label_map_->count(gid, tag)) throw arb::bad_connection_label(gid, tag, "label does not exist"); const auto& range_set = label_map_->at(gid, tag); diff --git a/arbor/label_resolution.hpp b/arbor/label_resolution.hpp index 508c9817e9..6f0b14d47b 100644 --- a/arbor/label_resolution.hpp +++ b/arbor/label_resolution.hpp @@ -8,7 +8,6 @@ #include #include -#include "util/partition.hpp" #include namespace arb { @@ -26,7 +25,7 @@ struct ARB_ARBOR_API cell_label_range { cell_label_range& operator=(const cell_label_range&) = default; cell_label_range& operator=(cell_label_range&&) = default; - cell_label_range(std::vector size_vec, std::vector label_vec, std::vector range_vec); + cell_label_range(std::vector size_vec, std::vector label_vec, std::vector rapfnge_vec); cell_label_range(std::vector size_vec, std::vector label_vec, std::vector range_vec); void add_cell(); diff --git a/arbor/lif_cell_group.cpp b/arbor/lif_cell_group.cpp index 8fdee746b3..6698d0edef 100644 --- a/arbor/lif_cell_group.cpp +++ b/arbor/lif_cell_group.cpp @@ -24,8 +24,8 @@ lif_cell_group::lif_cell_group(const std::vector& gids, // tell our caller about this cell's connections cg_sources.add_cell(); cg_targets.add_cell(); - cg_sources.add_label(internal_hash(cell.source), {0, 1}); - cg_targets.add_label(internal_hash(cell.target), {0, 1}); + cg_sources.add_label(hash_value(cell.source), {0, 1}); + cg_targets.add_label(hash_value(cell.target), {0, 1}); // insert probes where needed auto probes = rec.get_probes(gid); for (const auto& probe: probes) { diff --git a/arbor/spike_source_cell_group.cpp b/arbor/spike_source_cell_group.cpp index 41bc9705c0..651980aa44 100644 --- a/arbor/spike_source_cell_group.cpp +++ b/arbor/spike_source_cell_group.cpp @@ -1,5 +1,3 @@ -#include - #include #include #include @@ -33,7 +31,7 @@ spike_source_cell_group::spike_source_cell_group( try { auto cell = util::any_cast(rec.get_cell_description(gid)); time_sequences_.emplace_back(cell.seqs); - cg_sources.add_label(internal_hash(cell.source), {0, 1}); + cg_sources.add_label(hash_value(cell.source), {0, 1}); } catch (std::bad_any_cast& e) { throw bad_cell_description(cell_kind::spike_source, gid); diff --git a/test/unit/test_cable_cell.cpp b/test/unit/test_cable_cell.cpp index 2757de30df..8b0a4d460d 100644 --- a/test/unit/test_cable_cell.cpp +++ b/test/unit/test_cable_cell.cpp @@ -1,7 +1,8 @@ #include + #include "../common_cells.hpp" -#include +#include #include #include @@ -46,20 +47,20 @@ TEST(cable_cell, lid_ranges) { const auto& src_ranges = cell.detector_ranges(); const auto& tgt_ranges = cell.synapse_ranges(); - EXPECT_EQ(1u, tgt_ranges.count(internal_hash("t0"))); - EXPECT_EQ(1u, tgt_ranges.count(internal_hash("t1"))); - EXPECT_EQ(1u, src_ranges.count(internal_hash("s0"))); - EXPECT_EQ(1u, tgt_ranges.count(internal_hash("t2"))); - EXPECT_EQ(1u, src_ranges.count(internal_hash("s1"))); - EXPECT_EQ(2u, tgt_ranges.count(internal_hash("t3"))); + EXPECT_EQ(1u, tgt_ranges.count(hash_value("t0"))); + EXPECT_EQ(1u, tgt_ranges.count(hash_value("t1"))); + EXPECT_EQ(1u, src_ranges.count(hash_value("s0"))); + EXPECT_EQ(1u, tgt_ranges.count(hash_value("t2"))); + EXPECT_EQ(1u, src_ranges.count(hash_value("s1"))); + EXPECT_EQ(2u, tgt_ranges.count(hash_value("t3"))); - auto r1 = tgt_ranges.equal_range(internal_hash("t0")).first->second; - auto r2 = tgt_ranges.equal_range(internal_hash("t1")).first->second; - auto r3 = src_ranges.equal_range(internal_hash("s0")).first->second; - auto r4 = tgt_ranges.equal_range(internal_hash("t2")).first->second; - auto r5 = src_ranges.equal_range(internal_hash("s1")).first->second; + auto r1 = tgt_ranges.equal_range(hash_value("t0")).first->second; + auto r2 = tgt_ranges.equal_range(hash_value("t1")).first->second; + auto r3 = src_ranges.equal_range(hash_value("s0")).first->second; + auto r4 = tgt_ranges.equal_range(hash_value("t2")).first->second; + auto r5 = src_ranges.equal_range(hash_value("s1")).first->second; - auto r6_range = tgt_ranges.equal_range(internal_hash("t3")); + auto r6_range = tgt_ranges.equal_range(hash_value("t3")); auto r6_0 = r6_range.first; auto r6_1 = std::next(r6_range.first); if (r6_0->second.begin != 4u) { diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index f06c22c4af..28cc19562e 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -1,5 +1,4 @@ #include -#include #include #include @@ -25,11 +24,8 @@ #include "backends/multicore/fvm.hpp" #include "fvm_lowered_cell.hpp" #include "fvm_lowered_cell_impl.hpp" -#include "util/meta.hpp" -#include "util/maputil.hpp" #include "util/rangeutil.hpp" #include "util/span.hpp" -#include "util/transform.hpp" #include "common.hpp" #include "mech_private_field_access.hpp" @@ -994,10 +990,10 @@ TEST(fvm_lowered, label_data) { auto clg = cell_labels_and_gids(fvm_info.target_data, gids); std::vector expected_sizes = {2, 0, 0, 2, 0, 0, 2, 0, 0, 2}; std::vector> expected_labeled_ranges = { - {internal_hash("1_synapse"), {4, 5}}, {internal_hash("4_synapses"), {0, 4}}, - {internal_hash("1_synapse"), {4, 5}}, {internal_hash("4_synapses"), {0, 4}}, - {internal_hash("1_synapse"), {4, 5}}, {internal_hash("4_synapses"), {0, 4}}, - {internal_hash("1_synapse"), {4, 5}}, {internal_hash("4_synapses"), {0, 4}} + {hash_value("1_synapse"), {4, 5}}, {hash_value("4_synapses"), {0, 4}}, + {hash_value("1_synapse"), {4, 5}}, {hash_value("4_synapses"), {0, 4}}, + {hash_value("1_synapse"), {4, 5}}, {hash_value("4_synapses"), {0, 4}}, + {hash_value("1_synapse"), {4, 5}}, {hash_value("4_synapses"), {0, 4}} }; std::vector> actual_labeled_ranges; @@ -1033,16 +1029,16 @@ TEST(fvm_lowered, label_data) { auto clg = cell_labels_and_gids(fvm_info.source_data, gids); std::vector expected_sizes = {1, 2, 2, 1, 2, 2, 1, 2, 2, 1}; std::vector> expected_labeled_ranges = { - {internal_hash("1_detector"), {0, 1}}, - {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, - {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, - {internal_hash("1_detector"), {0, 1}}, - {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, - {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, - {internal_hash("1_detector"), {0, 1}}, - {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, - {internal_hash("2_detectors"), {3, 5}}, {internal_hash("3_detectors"), {0, 3}}, - {internal_hash("1_detector"), {0, 1}} + {hash_value("1_detector"), {0, 1}}, + {hash_value("2_detectors"), {3, 5}}, {hash_value("3_detectors"), {0, 3}}, + {hash_value("2_detectors"), {3, 5}}, {hash_value("3_detectors"), {0, 3}}, + {hash_value("1_detector"), {0, 1}}, + {hash_value("2_detectors"), {3, 5}}, {hash_value("3_detectors"), {0, 3}}, + {hash_value("2_detectors"), {3, 5}}, {hash_value("3_detectors"), {0, 3}}, + {hash_value("1_detector"), {0, 1}}, + {hash_value("2_detectors"), {3, 5}}, {hash_value("3_detectors"), {0, 3}}, + {hash_value("2_detectors"), {3, 5}}, {hash_value("3_detectors"), {0, 3}}, + {hash_value("1_detector"), {0, 1}} }; std::vector> actual_labeled_ranges; @@ -1077,12 +1073,12 @@ TEST(fvm_lowered, label_data) { auto clg = cell_labels_and_gids(fvm_info.gap_junction_data, gids); std::vector expected_sizes = {0, 2, 2, 0, 2, 2, 0, 2, 2, 0}; std::vector> expected_labeled_ranges = { - {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, - {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, - {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, - {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, - {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, - {internal_hash("1_gap_junction"), {2, 3}}, {internal_hash("2_gap_junctions"), {0, 2}}, + {hash_value("1_gap_junction"), {2, 3}}, {hash_value("2_gap_junctions"), {0, 2}}, + {hash_value("1_gap_junction"), {2, 3}}, {hash_value("2_gap_junctions"), {0, 2}}, + {hash_value("1_gap_junction"), {2, 3}}, {hash_value("2_gap_junctions"), {0, 2}}, + {hash_value("1_gap_junction"), {2, 3}}, {hash_value("2_gap_junctions"), {0, 2}}, + {hash_value("1_gap_junction"), {2, 3}}, {hash_value("2_gap_junctions"), {0, 2}}, + {hash_value("1_gap_junction"), {2, 3}}, {hash_value("2_gap_junctions"), {0, 2}}, }; EXPECT_EQ(clg.gids, gids); diff --git a/test/unit/test_hash.cpp b/test/unit/test_hash.cpp index 444830cc42..bcab44a6d4 100644 --- a/test/unit/test_hash.cpp +++ b/test/unit/test_hash.cpp @@ -6,8 +6,8 @@ TEST(hash, string_eq) { ASSERT_EQ(arb::hash_value("foobar"), arb::hash_value(std::string{"foobar"})); - ASSERT_EQ(arb::hash_value("foobar"), arb::internal_hash("foobar")); - ASSERT_NE(arb::hash_value("foobar"), arb::internal_hash("barfoo")); + ASSERT_EQ(arb::hash_value("foobar"), arb::hash_value("foobar")); + ASSERT_NE(arb::hash_value("foobar"), arb::hash_value("barfoo")); } TEST(hash, doesnt_compile) { diff --git a/test/unit/test_label_resolution.cpp b/test/unit/test_label_resolution.cpp index d4f5ade50c..32c262ee3d 100644 --- a/test/unit/test_label_resolution.cpp +++ b/test/unit/test_label_resolution.cpp @@ -12,7 +12,7 @@ std::vector make_labels(const std::vector& ls) { std::vector res; std::transform(ls.begin(), ls.end(), std::back_inserter(res), - internal_hash); + hash_value); return res; } @@ -23,7 +23,7 @@ TEST(test_cell_label_range, build) { // Test add_cell and add_label auto b0 = cell_label_range(); - EXPECT_THROW(b0.add_label(internal_hash("l0"), {0u, 1u}), arb::arbor_internal_error); + EXPECT_THROW(b0.add_label(hash_value("l0"), {0u, 1u}), arb::arbor_internal_error); EXPECT_TRUE(b0.sizes.empty()); EXPECT_TRUE(b0.labels.empty()); EXPECT_TRUE(b0.ranges.empty()); @@ -40,16 +40,16 @@ TEST(test_cell_label_range, build) { auto b2 = cell_label_range(); b2.add_cell(); - b2.add_label(internal_hash("l0"), {0u, 1u}); - b2.add_label(internal_hash("l0"), {3u, 13u}); - b2.add_label(internal_hash("l1"), {0u, 5u}); + b2.add_label(hash_value("l0"), {0u, 1u}); + b2.add_label(hash_value("l0"), {3u, 13u}); + b2.add_label(hash_value("l1"), {0u, 5u}); b2.add_cell(); b2.add_cell(); - b2.add_label(internal_hash("l2"), {6u, 8u}); - b2.add_label(internal_hash("l3"), {1u, 0u}); - b2.add_label(internal_hash("l4"), {7u, 2u}); - b2.add_label(internal_hash("l4"), {7u, 2u}); - b2.add_label(internal_hash("l2"), {7u, 2u}); + b2.add_label(hash_value("l2"), {6u, 8u}); + b2.add_label(hash_value("l3"), {1u, 0u}); + b2.add_label(hash_value("l4"), {7u, 2u}); + b2.add_label(hash_value("l4"), {7u, 2u}); + b2.add_label(hash_value("l2"), {7u, 2u}); EXPECT_EQ((ivec{3u, 0u, 5u}), b2.sizes); EXPECT_EQ(make_labels(svec{"l0", "l0", "l1", "l2", "l3", "l4", "l4", "l2"}), b2.labels); EXPECT_EQ((lvec{{0u, 1u}, {3u, 13u}, {0u, 5u}, {6u, 8u}, {1u, 0u}, {7u, 2u}, {7u, 2u}, {7u, 2u}}), b2.ranges); @@ -57,8 +57,8 @@ TEST(test_cell_label_range, build) { auto b3 = cell_label_range(); b3.add_cell(); - b3.add_label(internal_hash("r0"), {0u, 9u}); - b3.add_label(internal_hash("r1"), {10u, 10u}); + b3.add_label(hash_value("r0"), {0u, 9u}); + b3.add_label(hash_value("r1"), {10u, 10u}); b3.add_cell(); EXPECT_EQ((ivec{2u, 0u}), b3.sizes); EXPECT_EQ(make_labels From 1b3bbb839268a5be6faae2ffcc81efbe8b7b1c0a Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Wed, 20 Dec 2023 12:57:53 +0100 Subject: [PATCH 67/84] use cbprng random generator definition --- arbor/include/arbor/network.hpp | 33 +++++++++++------- arbor/network.cpp | 8 ++--- arbor/network_impl.cpp | 35 +++++++++++-------- test/unit/test_network.cpp | 62 +++++++++++++++------------------ 4 files changed, 73 insertions(+), 65 deletions(-) diff --git a/arbor/include/arbor/network.hpp b/arbor/include/arbor/network.hpp index d5717d4d3c..262f318bdf 100644 --- a/arbor/include/arbor/network.hpp +++ b/arbor/include/arbor/network.hpp @@ -25,16 +25,16 @@ namespace arb { using network_hash_type = std::uint64_t; struct ARB_SYMBOL_VISIBLE network_site_info { - // network_site_info(cell_gid_type gid, - // cell_kind kind, - // hash_type label, - // mlocation location, - // mpoint global_location): - // gid(gid), - // kind(kind), - // label(label), - // location(location), - // global_location(global_location) {} + network_site_info(cell_gid_type gid, + cell_kind kind, + hash_type label, + mlocation location, + mpoint global_location): + gid(gid), + kind(kind), + label(label), + location(location), + global_location(global_location) {} cell_gid_type gid; cell_kind kind; @@ -53,13 +53,22 @@ struct ARB_SYMBOL_VISIBLE network_connection_info { network_site_info source, target; double weight, delay; + network_connection_info(network_site_info source, + network_site_info target, + double weight, + double delay): + source(source), + target(target), + weight(weight), + delay(delay) {} + ARB_ARBOR_API friend std::ostream& operator<<(std::ostream& os, const network_connection_info& s); }; ARB_DEFINE_LEXICOGRAPHIC_ORDERING(network_connection_info, - (a.source, a.target), - (b.source, b.target)) + (a.source, a.target, a.weight, a.delay), + (b.source, b.target, b.weight, b.delay)) struct network_selection_impl; diff --git a/arbor/network.cpp b/arbor/network.cpp index 42e18ebe11..ba3ece380c 100644 --- a/arbor/network.cpp +++ b/arbor/network.cpp @@ -15,6 +15,7 @@ #include #include +#include "backends/rand_impl.hpp" #include "network_impl.hpp" namespace arb { @@ -38,12 +39,11 @@ std::uint64_t location_hash(const mlocation& loc) { double uniform_rand(std::array seed, const network_site_info& source, const network_site_info& target) { - using rand_type = r123::Threefry4x64; - const rand_type::ctr_type seed_input = {{seed[0], seed[1], seed[2], seed[3]}}; + const cbprng::array_type seed_input = {{seed[0], seed[1], seed[2], seed[3]}}; - const rand_type::key_type key = { + const cbprng::array_type key = { {source.gid, location_hash(source.location), target.gid, location_hash(target.location)}}; - rand_type gen; + cbprng::generator gen; return r123::u01(gen(seed_input, key)[0]); } diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index 82ded2f274..b19c940dbc 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -58,6 +58,15 @@ void push_back(const domain_decomposition& dom_dec, dom_dec.index_on_domain(target.info.gid)}); } +void push_back(const domain_decomposition& dom_dec, + std::vector& vec, + const network_site_info_extended& source, + const network_site_info_extended& target, + double weight, + double delay) { + vec.emplace_back(source.info, target.info, weight, delay); +} + template std::vector generate_network_connections(const recipe& rec, const context& ctx, @@ -109,9 +118,8 @@ std::vector generate_network_connections(const recipe& rec, throw bad_cell_description(kind, gid); } - auto lid_to_label = - [](const std::unordered_multimap& map, - cell_lid_type lid) -> hash_type { + auto lid_to_label = [](const std::unordered_multimap& map, + cell_lid_type lid) -> hash_type { for (const auto& [label, range]: map) { if (lid >= range.begin && lid < range.end) return label; } @@ -203,16 +211,15 @@ std::vector generate_network_connections(const recipe& rec, } } - auto src_sites = std::move(src_site_batches.back()); src_site_batches.pop_back(); for (const auto& batch: src_site_batches) - src_sites.insert(src_sites.end(), batch.begin(), batch.end()); + src_sites.insert(src_sites.end(), batch.begin(), batch.end()); auto tgt_sites = std::move(tgt_site_batches.back()); tgt_site_batches.pop_back(); for (const auto& batch: tgt_site_batches) - tgt_sites.insert(tgt_sites.end(), batch.begin(), batch.end()); + tgt_sites.insert(tgt_sites.end(), batch.begin(), batch.end()); // create octree const std::size_t max_depth = selection.max_distance().has_value() ? 10 : 1; @@ -283,22 +290,20 @@ std::vector generate_connections(const recipe& rec, ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, const context& ctx, const domain_decomposition& dom_dec) { - // auto connections = generate_network_connections(rec, ctx, dom_dec); + auto connections = generate_network_connections(rec, ctx, dom_dec); - // // generated connections may have different order each time due to multi-threading. - // // Sort before returning to user for reproducibility. - // std::sort(connections.begin(), connections.end()); + // generated connections may have different order each time due to multi-threading. + // Sort before returning to user for reproducibility. + std::sort(connections.begin(), connections.end()); - // return connections; - return {}; + return connections; } ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec) { auto ctx = arb::make_context(); - // auto decomp = arb::partition_load_balance(rec, ctx); + auto decomp = arb::partition_load_balance(rec, ctx); - // return generate_network_connections(rec, ctx, decomp); - return {}; + return generate_network_connections(rec, ctx, decomp); } } // namespace arb diff --git a/test/unit/test_network.cpp b/test/unit/test_network.cpp index 760cef7c7b..5005d76775 100644 --- a/test/unit/test_network.cpp +++ b/test/unit/test_network.cpp @@ -434,7 +434,7 @@ TEST(network_selection, random_reproducibility) { {0, cell_kind::cable, hash_value("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, {1, cell_kind::benchmark, hash_value("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, }; - std::vector ref = {1, 1, 0, 1, 1, 0, 0, 0, 0}; + std::vector ref = {0, 1, 1, 0, 1, 1, 1, 1, 1}; std::size_t i = 0; for (const auto& source: sites) { @@ -558,17 +558,15 @@ TEST(network_value, uniform_distribution_reproducibility) { {0, cell_kind::cable, hash_value("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, {1, cell_kind::benchmark, hash_value("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, }; - std::vector ref = { - 1.08007184307616289, - 0.688511962867972116, - -2.83551807417554347, - 0.688511962867972116, - 0.824599122495063064, - 1.4676501652366376, - -2.83551807417554347, - 1.4676501652366376, - -4.89687864740961487, - }; + std::vector ref = {0.152358748168055058, + -4.499410763769494004, + 2.208818591778559437, + -4.615620548394118394, + -2.883165846887783879, + -1.227842167463327083, + -3.938243119645829182, + -0.032436439374857962, + -3.392091783670958982}; std::size_t i = 0; for (const auto& source: sites) { @@ -613,17 +611,15 @@ TEST(network_value, normal_distribution_reproducibility) { {0, cell_kind::cable, hash_value("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, {1, cell_kind::benchmark, hash_value("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, }; - std::vector ref = { - 9.27330832850693909, - 6.29969914563416733, - 1.81597827782531063, - 6.29969914563416733, - 8.12362497769330183, - 1.52496785710691851, - 1.81597827782531063, - 1.52496785710691851, - 1.49089022270221472, - }; + std::vector ref = {1.719220750899862038, + 3.792930460082558852, + 2.040797389626836544, + 4.690543724504090406, + 6.048018986729678304, + 3.468450499834405676, + 2.641602074572110492, + 4.045110924716160739, + 4.619102745858998382}; std::size_t i = 0; for (const auto& source: sites) { @@ -677,17 +673,15 @@ TEST(network_value, truncated_normal_distribution_reproducibility) { {0, cell_kind::cable, hash_value("b"), {0, 0.1}, {-1.0, 0.5, 0.7}}, {1, cell_kind::benchmark, hash_value("c"), {0, 0.0}, {20.5, -59.5, 5.0}}, }; - std::vector ref = { - 2.81708378066100629, - 4.82619033891918026, - 7.82585873628304096, - 4.82619033891918026, - 3.95914976610015401, - 5.74869285185564216, - 7.82585873628304096, - 5.74869285185564216, - 5.45028211635819293, - }; + std::vector ref = {6.933077952929343368, + 3.822103684855993055, + 3.081517892090295696, + 3.238387276739735476, + 3.739312586647523418, + 8.589787762424691664, + 7.554985027779592244, + 2.924644471896214348, + 3.085597042676768265}; std::size_t i = 0; for (const auto& source: sites) { From a96cbf1cf1d4de7030e1453e8c772eb7762e7b57 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 22 Dec 2023 14:17:49 +0100 Subject: [PATCH 68/84] refactor --- arbor/include/arbor/network_generation.hpp | 3 +++ arbor/network_impl.cpp | 9 ++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/arbor/include/arbor/network_generation.hpp b/arbor/include/arbor/network_generation.hpp index aa8e9e25ea..262cc426b6 100644 --- a/arbor/include/arbor/network_generation.hpp +++ b/arbor/include/arbor/network_generation.hpp @@ -8,10 +8,13 @@ namespace arb { // Generate and return list of connections from the network description of the recipe. // Does not include connections from the "connections_on" recipe function. +// Only returns connections with local cell targets as described in the domain decomposition. ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, const context& ctx, const domain_decomposition& dom_dec); +// Generate and return list of ALL connections from the network description of the recipe. +// Does not include connections from the "connections_on" recipe function. ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec); } // namespace arb diff --git a/arbor/network_impl.cpp b/arbor/network_impl.cpp index b19c940dbc..11d138f089 100644 --- a/arbor/network_impl.cpp +++ b/arbor/network_impl.cpp @@ -68,7 +68,7 @@ void push_back(const domain_decomposition& dom_dec, } template -std::vector generate_network_connections(const recipe& rec, +std::vector generate_network_connections_impl(const recipe& rec, const context& ctx, const domain_decomposition& dom_dec) { const auto description_opt = rec.network_description(); @@ -284,13 +284,13 @@ std::vector generate_network_connections(const recipe& rec, std::vector generate_connections(const recipe& rec, const context& ctx, const domain_decomposition& dom_dec) { - return generate_network_connections(rec, ctx, dom_dec); + return generate_network_connections_impl(rec, ctx, dom_dec); } ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec, const context& ctx, const domain_decomposition& dom_dec) { - auto connections = generate_network_connections(rec, ctx, dom_dec); + auto connections = generate_network_connections_impl(rec, ctx, dom_dec); // generated connections may have different order each time due to multi-threading. // Sort before returning to user for reproducibility. @@ -302,8 +302,7 @@ ARB_ARBOR_API std::vector generate_network_connections( ARB_ARBOR_API std::vector generate_network_connections(const recipe& rec) { auto ctx = arb::make_context(); auto decomp = arb::partition_load_balance(rec, ctx); - - return generate_network_connections(rec, ctx, decomp); + return generate_network_connections(rec, ctx, decomp); } } // namespace arb From dec80fa212e92b90b9c9f76159e3628928297991 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 22 Dec 2023 14:28:22 +0100 Subject: [PATCH 69/84] fix example --- arbor/simulation.cpp | 3 ++- example/network_description/network_description.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp index accace2377..6cfd7cbdc0 100644 --- a/arbor/simulation.cpp +++ b/arbor/simulation.cpp @@ -2,13 +2,14 @@ #include #include -#include #include #include #include +#include #include #include #include +#include #include #include "epoch.hpp" diff --git a/example/network_description/network_description.cpp b/example/network_description/network_description.cpp index d2344b3292..2830eff60f 100644 --- a/example/network_description/network_description.cpp +++ b/example/network_description/network_description.cpp @@ -150,7 +150,7 @@ class ring_recipe: public arb::recipe { std::vector get_probes(cell_gid_type gid) const override { // Measure membrane voltage at end of soma. arb::mlocation loc{0, 0.0}; - return {arb::cable_probe_membrane_voltage{loc}}; + return {{arb::cable_probe_membrane_voltage{loc}, "Um"}}; } std::any get_global_properties(arb::cell_kind) const override { return gprop_; } @@ -207,7 +207,7 @@ int main(int argc, char** argv) { // Set up the probe that will measure voltage in the cell. // The id of the only probe on the cell: the cell_member type points to (cell 0, probe 0) - auto probeset_id = cell_member_type{0, 0}; + auto probeset_id = arb::cell_address_type{0, "Um"}; // The schedule for sampling is 10 samples every 1 ms. auto sched = arb::regular_schedule(1); // This is where the voltage samples will be stored as (time, value) pairs From 0e756151a5c68fc1d5c21a0a5f892b24801c6062 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 23 Jan 2024 11:14:04 +0100 Subject: [PATCH 70/84] fix python example --- python/example/network_description.py | 149 ++++++++++++-------------- scripts/run_cpp_examples.sh | 2 +- 2 files changed, 67 insertions(+), 84 deletions(-) diff --git a/python/example/network_description.py b/python/example/network_description.py index 1b39c7aac4..a0cf01a863 100755 --- a/python/example/network_description.py +++ b/python/example/network_description.py @@ -1,53 +1,48 @@ #!/usr/bin/env python3 # This script is included in documentation. Adapt line numbers if touched. -import arbor +import arbor as A +from arbor import units as U import pandas # You may have to pip install these import seaborn # You may have to pip install these from math import sqrt -import math - -# Construct a cell with the following morphology. -# The soma (at the root of the tree) is marked 's', and -# the end of each branch i is marked 'bi'. -# -# b1 -# / -# s----b0 -# \ -# b2 def make_cable_cell(gid): # (1) Build a segment tree - tree = arbor.segment_tree() - + # The dendrite (dend) attaches to the soma and has two simple segments + # attached. + # + # left + # / + # soma - dend + # \ + # right + tree = A.segment_tree() + root = A.mnpos # Soma (tag=1) with radius 6 μm, modelled as cylinder of length 2*radius - s = tree.append( - arbor.mnpos, arbor.mpoint(-12, 0, 0, 6), arbor.mpoint(0, 0, 0, 6), tag=1 - ) - - # (b0) Single dendrite (tag=3) of length 50 μm and radius 2 μm attached to soma. - b0 = tree.append(s, arbor.mpoint(0, 0, 0, 2), arbor.mpoint(0, 0, 50, 2), tag=3) - + soma = tree.append(root, (-12, 0, 0, 6), (0, 0, 0, 6), tag=1) + # Single dendrite (tag=3) of length 50 μm and radius 2 μm attached to soma. + dend = tree.append(soma, (0, 0, 0, 2), (50, 0, 0, 2), tag=3) # Attach two dendrites (tag=3) of length 50 μm to the end of the first dendrite. - # (b1) Radius tapers from 2 to 0.5 μm over the length of the dendrite. - tree.append( - b0, - arbor.mpoint(0, 0, 50, 2), - arbor.mpoint(0, 50 / sqrt(2), 50 + 50 / sqrt(2), 0.5), + # Radius tapers from 2 to 0.5 μm over the length of the dendrite. + l = 50 / sqrt(2) + _ = tree.append( + dend, + (50, 0, 0, 2), + (50 + l, l, 0, 0.5), tag=3, ) - # (b2) Constant radius of 1 μm over the length of the dendrite. - tree.append( - b0, - arbor.mpoint(0, 0, 50, 1), - arbor.mpoint(0, -50 / sqrt(2), 50 + 50 / sqrt(2), 1), + # Constant radius of 1 μm over the length of the dendrite. + _ = tree.append( + dend, + (50, 0, 0, 1), + (50 + l, -l, 0, 1), tag=3, ) # Associate labels to tags - labels = arbor.label_dict( + labels = A.label_dict( { "soma": "(tag 1)", "dend": "(tag 3)", @@ -60,29 +55,27 @@ def make_cable_cell(gid): # (3) Create a decor and a cable_cell decor = ( - arbor.decor() + A.decor() # Put hh dynamics on soma, and passive properties on the dendrites. - .paint('"soma"', arbor.density("hh")).paint('"dend"', arbor.density("pas")) + .paint('"soma"', A.density("hh")).paint('"dend"', A.density("pas")) # (4) Attach a single synapse. - .place('"synapse_site"', arbor.synapse("expsyn"), "syn") + .place('"synapse_site"', A.synapse("expsyn"), "syn") # Attach a detector with threshold of -10 mV. - .place('"root"', arbor.threshold_detector(-10), "detector") + .place('"root"', A.threshold_detector(-10 * U.mV), "detector") ) - return arbor.cable_cell(tree, decor, labels) + return A.cable_cell(tree, decor, labels) # (5) Create a recipe that generates a network of connected cells. -class random_ring_recipe(arbor.recipe): +class random_ring_recipe(A.recipe): def __init__(self, ncells): - # The base C++ class constructor must be called first, to ensure that - # all memory in the C++ class is initialized correctly. - arbor.recipe.__init__(self) + # Base class constructor must be called first for proper initialization. + A.recipe.__init__(self) self.ncells = ncells - self.props = arbor.neuron_cable_properties() + self.props = A.neuron_cable_properties() - # (6) The num_cells method that returns the total number of cells in the model - # must be implemented. + # (6) Returns the total number of cells in the model; must be implemented. def num_cells(self): return self.ncells @@ -90,19 +83,11 @@ def num_cells(self): def cell_description(self, gid): return make_cable_cell(gid) - # The kind method returns the type of cell with gid. - # Note: this must agree with the type returned by cell_description. - def cell_kind(self, gid): - return arbor.cell_kind.cable - - def cell_isometry(self, gid): - # place cells with equal distance on a circle - radius = 500.0 # μm - angle = 2.0 * math.pi * gid / self.ncells - return arbor.isometry.translate( - radius * math.cos(angle), radius * math.sin(angle), 0 - ) + # Return the type of cell; must be implemented and match cell_description. + def cell_kind(self, _): + return A.cell_kind.cable + # (8) Descripe network def network_description(self): seed = 42 @@ -134,66 +119,64 @@ def network_description(self): # fixed delay d = "(scalar 5.0)" # ms delay - return arbor.network_description(s, w, d, {}) + return A.network_description(s, w, d, {}) # (9) Attach a generator to the first cell in the ring. def event_generators(self, gid): if gid == 0: - sched = arbor.explicit_schedule([1]) # one event at 1 ms + sched = A.explicit_schedule([1 * U.ms]) # one event at 1 ms weight = 0.1 # 0.1 μS on expsyn - return [arbor.event_generator("syn", weight, sched)] + return [A.event_generator("syn", weight, sched)] return [] # (10) Place a probe at the root of each cell. def probes(self, gid): - return [arbor.cable_probe_membrane_voltage('"root"')] + return [A.cable_probe_membrane_voltage('"root"', "Um")] - def global_properties(self, kind): + def global_properties(self, _): return self.props # (11) Instantiate recipe -ncells = 20 +ncells = 4 recipe = random_ring_recipe(ncells) -sim = arbor.simulation(recipe) +# (12) Create a simulation using the default settings: +# - Use all threads available +# - Use round-robin distribution of cells across groups with one cell per group +# - Use GPU if present +# - No MPI +# Other constructors of simulation can be used to change all of these. +sim = A.simulation(recipe) # (13) Set spike generators to record -sim.record(arbor.spike_recording.all) +sim.record(A.spike_recording.all) # (14) Attach a sampler to the voltage probe on cell 0. Sample rate of 10 sample every ms. -handles = [sim.sample((gid, 0), arbor.regular_schedule(0.1)) for gid in range(ncells)] - -# (15) Inspect generated connections -connections = arbor.generate_network_connections(recipe) +handles = [ + sim.sample((gid, "Um"), A.regular_schedule(0.1 * U.ms)) for gid in range(ncells) +] -print("connections:") -for c in connections: - print( - f'({c.source.gid}, "{c.source.label}") -> ({c.target.gid}, "{c.target.label}")' - ) - -# (16) Run simulation for 100 ms -sim.run(100) +# (15) Run simulation for 100 ms +sim.run(100 * U.ms) print("Simulation finished") -# (17) Print spike times +# (16) Print spike times print("spikes:") for sp in sim.spikes(): print(" ", sp) -# (18) Plot the recorded voltages over time. +# (17) Plot the recorded voltages over time. print("Plotting results ...") -df_list = [] +dfs = [] for gid in range(ncells): samples, meta = sim.samples(handles[gid])[0] - df_list.append( + dfs.append( pandas.DataFrame( {"t/ms": samples[:, 0], "U/mV": samples[:, 1], "Cell": f"cell {gid}"} ) ) - -df = pandas.concat(df_list, ignore_index=True) +df = pandas.concat(dfs, ignore_index=True) seaborn.relplot( data=df, kind="line", x="t/ms", y="U/mV", hue="Cell", errorbar=None -).savefig("network_ring_result.svg") +).savefig("network_description_result.svg") diff --git a/scripts/run_cpp_examples.sh b/scripts/run_cpp_examples.sh index 96d498c2b0..b916c23594 100755 --- a/scripts/run_cpp_examples.sh +++ b/scripts/run_cpp_examples.sh @@ -59,7 +59,7 @@ expected_outputs=( "" "" "" - 37 + 46 "" ) From 20902f16781abd55916296165c28e12924adf92b Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 5 Apr 2024 17:46:14 +0200 Subject: [PATCH 71/84] change back pybind11 submodule --- ext/pybind11 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/pybind11 b/ext/pybind11 index 80dc998efc..8a099e44b3 160000 --- a/ext/pybind11 +++ b/ext/pybind11 @@ -1 +1 @@ -Subproject commit 80dc998efced8ceb2be59756668a7e90e8bef917 +Subproject commit 8a099e44b3d5f85b20f05828d919d2332a8de841 From 1e07cddd62665a0ad8db26b08493baa03ab17cce Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Fri, 5 Apr 2024 18:17:26 +0200 Subject: [PATCH 72/84] add noexcept to spatial_tree --- arbor/util/spatial_tree.hpp | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/arbor/util/spatial_tree.hpp b/arbor/util/spatial_tree.hpp index 10f3879194..e3bba73dfe 100644 --- a/arbor/util/spatial_tree.hpp +++ b/arbor/util/spatial_tree.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -89,20 +90,28 @@ class spatial_tree { spatial_tree(const spatial_tree &) = default; - spatial_tree(spatial_tree &&t) { *this = std::move(t); } + spatial_tree(spatial_tree &&t) noexcept(std::is_nothrow_move_assignable_v) { + *this = std::move(t); + } spatial_tree &operator=(const spatial_tree &) = default; - spatial_tree &operator=(spatial_tree &&t) { - data_ = std::move(t.data_); + spatial_tree &operator=(spatial_tree &&t) noexcept( + noexcept(std::swap(data_, t.data_)) && + std::is_nothrow_default_constructible_v && + std::is_nothrow_move_assignable_v) { + + std::swap(data_, t.data_); size_ = t.size_; - min_ = t.min_; - max_ = t.max_; + min_ = std::move(t.min_); + max_ = std::move(t.max_); + location_ = t.location_; t.data_ = leaf_data(); t.size_ = 0; t.min_ = point_type(); t.max_ = point_type(); + t.location_ = nullptr; return *this; } @@ -167,9 +176,9 @@ class spatial_tree { } - inline std::size_t size() const { return size_; } + inline std::size_t size() const noexcept { return size_; } - inline bool empty() const { return !size_; } + inline bool empty() const noexcept { return !size_; } private: std::size_t size_; From fc444f689c5c0d0b245ae86ecf94f391aca99d08 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 8 Apr 2024 17:24:57 +0200 Subject: [PATCH 73/84] improved error message --- arborio/networkio.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arborio/networkio.cpp b/arborio/networkio.cpp index 72ca1f172f..194ff41c9e 100644 --- a/arborio/networkio.cpp +++ b/arborio/networkio.cpp @@ -272,11 +272,13 @@ std::string eval_description(const char* name, const std::vector& args if (t == typeid(double)) return "real"; if (t == typeid(arb::region)) return "region"; if (t == typeid(arb::locset)) return "locset"; + if (t == typeid(arb::network_selection)) return "network_selection"; + if (t == typeid(arb::network_value)) return "network_value"; return "unknown"; }; const auto nargs = args.size(); - std::string msg = concat("'", name, "' with ", nargs, "argument", nargs != 1u ? "s:" : ":"); + std::string msg = concat("'", name, "' with ", nargs, " argument", nargs != 1u ? "s:" : ":"); if (nargs) { msg += " ("; bool first = true; From 654d93594f85c73d875e3de1c52dbc808ec93fb4 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 8 Apr 2024 17:53:08 +0200 Subject: [PATCH 74/84] fix doc for random selection --- doc/concepts/interconnectivity.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/concepts/interconnectivity.rst b/doc/concepts/interconnectivity.rst index ee78cc3b39..cb48327ba0 100644 --- a/doc/concepts/interconnectivity.rst +++ b/doc/concepts/interconnectivity.rst @@ -195,11 +195,11 @@ Network Selection Expressions A chain of connections between cells in reverse of the given order of the gid-range, such that entry "i+1" is the source and entry "i" the target. -.. label:: (random p:real) +.. label:: (random seed:integer p:real) A random selection of connections, where each connection is selected with the given probability. -.. label:: (random p:network-value) +.. label:: (random seed:integer p:network-value) A random selection of connections, where each connection is selected with the given probability expression. From 55eff364a2aeca8b6a52f55a1e7546c7c37bf0cb Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 8 Apr 2024 19:55:42 +0200 Subject: [PATCH 75/84] replace pybind arg_v --- python/network.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/network.cpp b/python/network.cpp index 0deb551ab2..d1bf3858b5 100644 --- a/python/network.cpp +++ b/python/network.cpp @@ -154,8 +154,8 @@ void register_network(py::module& m) { return generate_network_connections(rec_shim, ctx->context, decomp.value()); }, "recipe"_a, - pybind11::arg_v("context", pybind11::none(), "Execution context"), - pybind11::arg_v("decomp", pybind11::none(), "Domain decomposition"), + "context"_a = pybind11::none(), + "decomp"_a = pybind11::none(), "Generate network connections from the network description in the recipe. Will only " "generate connections with local gids in the domain composition as target."); } From 43de6a1597aa166adf9f50ef1af1e690ffbd8422 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 8 Apr 2024 20:00:14 +0200 Subject: [PATCH 76/84] doc --- doc/concepts/interconnectivity.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/concepts/interconnectivity.rst b/doc/concepts/interconnectivity.rst index cb48327ba0..bdb93f3a29 100644 --- a/doc/concepts/interconnectivity.rst +++ b/doc/concepts/interconnectivity.rst @@ -23,7 +23,7 @@ The recipe callbacks are interrogated during simulation creation. High Level Network Description ------------------------------ -As an alternative to providing a list of connections for each cell in the :ref:`recipe `, arbor supports high-level description of a cell network. It is based around a ``network_selection`` type, that represents a selection from the set of all possible connections between cells. A selection can be created based on different criteria, such as source or target label, cell indices and also distance between source and target. Selections can then be combined with other selections through set algebra like expressions. For distance calculations, the location of each connection point on the cell is resolved through the morphology combined with a cell isometry, which describes translation and rotation of the cell. +As an additional option to providing a list of connections for each cell in the :ref:`recipe `, arbor supports high-level description of a cell network. It is based around a ``network_selection`` type, that represents a selection from the set of all possible connections between cells. A selection can be created based on different criteria, such as source or target label, cell indices and also distance between source and target. Selections can then be combined with other selections through set algebra like expressions. For distance calculations, the location of each connection point on the cell is resolved through the morphology combined with a cell isometry, which describes translation and rotation of the cell. Each connection also requires a weight and delay value. For this purpose, a ``network_value`` type is available, that allows to mathematically describe the value calculation using common math functions, as well random distributions. The following example shows the relevant recipe functions, where cells are connected into a ring with additional random connections between them: From 1596219f4958800abb96b9132b1e972f9a1dffaa Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 9 Apr 2024 10:52:34 +0200 Subject: [PATCH 77/84] distributed for each doc --- arbor/communication/distributed_for_each.hpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/arbor/communication/distributed_for_each.hpp b/arbor/communication/distributed_for_each.hpp index 55401772c4..7024a59f7b 100644 --- a/arbor/communication/distributed_for_each.hpp +++ b/arbor/communication/distributed_for_each.hpp @@ -41,13 +41,17 @@ void for_each_in_tuple_pair(FUNC&& func, std::tuple& t1, std::tuple::value_type*>&...) -> void - * Given 'n' distributed ranks, the function will be called 'n' times with data from each rank. - * There is no guaranteed order. + * Collective operation, calling func on args supplied by each rank exactly once. The order of calls + * is unspecified. Requires + * + * - Item = util::range::value_type to be identical across all ranks + * - Item is trivially_copyable + * - Alignment of Item must not exceed std::max_align_t + * - func to be a callable type with signature + * void func(util::range...) + * - func must not modify contents of range + * - All ranks in distributed must call this collectively. */ template void distributed_for_each(FUNC&& func, From d178b07fcf88fe5dcb000bcc6ac7a0331c5d39f1 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 9 Apr 2024 13:33:19 +0200 Subject: [PATCH 78/84] move call out of inner loop --- arbor/communication/communicator.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp index 4753f5eac7..e1d3de887c 100644 --- a/arbor/communication/communicator.cpp +++ b/arbor/communication/communicator.cpp @@ -143,6 +143,7 @@ void communicator::update_connections(const recipe& rec, auto target_resolver = resolver(&target_resolution_map); for (const auto index: util::make_span(num_local_cells_)) { const auto tgt_gid = gids[index]; + const auto iod = dom_dec.index_on_domain(tgt_gid); auto source_resolver = resolver(&source_resolution_map); for (const auto cidx: util::make_span(part_connections[index], part_connections[index+1])) { const auto& conn = gid_connections[cidx]; @@ -152,11 +153,7 @@ void communicator::update_connections(const recipe& rec, auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target); auto offset = offsets[*src_domain]++; ++src_domain; - connections[offset] = {{src_gid, src_lid}, - tgt_lid, - conn.weight, - conn.delay, - dom_dec.index_on_domain(tgt_gid)}; + connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, iod}; } for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) { const auto& conn = gid_ext_connections[cidx]; @@ -164,8 +161,7 @@ void communicator::update_connections(const recipe& rec, auto src_gid = conn.source.rid; if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid); auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target); - ext_connections[ext] = { - src, tgt_lid, conn.weight, conn.delay, dom_dec.index_on_domain(tgt_gid)}; + ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, iod}; ++ext; } } From 23e9659215c0ec09705e12b70b9a047f5ee34040 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 9 Apr 2024 17:21:58 +0200 Subject: [PATCH 79/84] fix warning --- arbor/communication/distributed_for_each.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arbor/communication/distributed_for_each.hpp b/arbor/communication/distributed_for_each.hpp index 7024a59f7b..2a1399639b 100644 --- a/arbor/communication/distributed_for_each.hpp +++ b/arbor/communication/distributed_for_each.hpp @@ -110,7 +110,7 @@ void distributed_for_each(FUNC&& func, arg_tuple, ranges); - for (std::size_t step = 0; step < distributed.size(); ++step) { std::apply(func, ranges); } + for (int step = 0; step < distributed.size(); ++step) { std::apply(func, ranges); } return; } From 953b1cebcd086f98a075eb70363fd2fc823e1123 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 9 Apr 2024 17:22:58 +0200 Subject: [PATCH 80/84] noexcept fix attempt --- arbor/util/spatial_tree.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arbor/util/spatial_tree.hpp b/arbor/util/spatial_tree.hpp index e3bba73dfe..5bd31d4a56 100644 --- a/arbor/util/spatial_tree.hpp +++ b/arbor/util/spatial_tree.hpp @@ -97,7 +97,7 @@ class spatial_tree { spatial_tree &operator=(const spatial_tree &) = default; spatial_tree &operator=(spatial_tree &&t) noexcept( - noexcept(std::swap(data_, t.data_)) && + noexcept(std::swap(this->data_, t.data_)) && std::is_nothrow_default_constructible_v && std::is_nothrow_move_assignable_v) { From 6ad6c308129b092ec04376c3686bd93257f75a05 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Tue, 9 Apr 2024 17:31:07 +0200 Subject: [PATCH 81/84] fix dummy context test --- test/unit/test_domain_decomposition.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/unit/test_domain_decomposition.cpp b/test/unit/test_domain_decomposition.cpp index 915a346ae9..ac336616e6 100644 --- a/test/unit/test_domain_decomposition.cpp +++ b/test/unit/test_domain_decomposition.cpp @@ -163,6 +163,15 @@ struct dummy_context { cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { throw unimplemented{__FUNCTION__}; } cell_labels_and_gids gather_cell_labels_and_gids(const cell_labels_and_gids& local_labels_and_gids) const { throw unimplemented{__FUNCTION__}; } template std::vector gather(T value, int) const { throw unimplemented{__FUNCTION__}; } + distributed_request send_recv_nonblocking(std::size_t dest_count, + void* dest_data, + int dest, + std::size_t source_count, + const void* source_data, + int source, + int tag) const { + throw unimplemented{__FUNCTION__}; + } int id() const { return id_; } int size() const { return size_; } From 340da63196fbb8b81d38c19ad2bba69c1912aa2d Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Wed, 10 Apr 2024 08:52:42 +0200 Subject: [PATCH 82/84] revert std::swap usage --- arbor/util/spatial_tree.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/arbor/util/spatial_tree.hpp b/arbor/util/spatial_tree.hpp index 5bd31d4a56..5f84b7a91a 100644 --- a/arbor/util/spatial_tree.hpp +++ b/arbor/util/spatial_tree.hpp @@ -7,7 +7,6 @@ #include #include -#include #include #include #include @@ -97,14 +96,15 @@ class spatial_tree { spatial_tree &operator=(const spatial_tree &) = default; spatial_tree &operator=(spatial_tree &&t) noexcept( - noexcept(std::swap(this->data_, t.data_)) && std::is_nothrow_default_constructible_v && - std::is_nothrow_move_assignable_v) { + std::is_nothrow_move_assignable_v && + std::is_nothrow_default_constructible_v && + std::is_nothrow_move_assignable_vdata_)>) { - std::swap(data_, t.data_); + data_ = std::move(t.data_); size_ = t.size_; - min_ = std::move(t.min_); - max_ = std::move(t.max_); + min_ = t.min_; + max_ = t.max_; location_ = t.location_; t.data_ = leaf_data(); From 534c559c76d34157860c35552dfc840ecd8037e5 Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Wed, 10 Apr 2024 08:53:55 +0200 Subject: [PATCH 83/84] use lowest() --- arbor/util/spatial_tree.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arbor/util/spatial_tree.hpp b/arbor/util/spatial_tree.hpp index 5f84b7a91a..f572781f87 100644 --- a/arbor/util/spatial_tree.hpp +++ b/arbor/util/spatial_tree.hpp @@ -43,7 +43,7 @@ class spatial_tree { if (leaf_d.empty()) return; min_.fill(std::numeric_limits::max()); - max_.fill(-std::numeric_limits::max()); + max_.fill(std::numeric_limits::lowest()); for (const auto &d: leaf_d) { const auto p = location(d); From 31fbbe41d0801ade184b8e8e14224d8e3d23cf0b Mon Sep 17 00:00:00 2001 From: Simon Frasch Date: Mon, 13 May 2024 15:23:39 +0200 Subject: [PATCH 84/84] add documentation for generate_network_connections function --- doc/cpp/interconnectivity.rst | 12 ++++++++++++ doc/python/interconnectivity.rst | 8 ++++++++ example/network_description/network_description.cpp | 4 ++-- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/doc/cpp/interconnectivity.rst b/doc/cpp/interconnectivity.rst index d6cff1821f..bfd9557921 100644 --- a/doc/cpp/interconnectivity.rst +++ b/doc/cpp/interconnectivity.rst @@ -360,3 +360,15 @@ Interconnectivity .. cpp:member:: network_label_dict dict Label dictionary for named selecations and values. + + +.. function:: generate_network_connections(recipe, context, decomp) + + Generate network connections from the network description in the recipe. Only generates connections + with local gids in the domain composition as target. Does not include connections from + the "connections_on" recipe function. + +.. function:: generate_network_connections(recipe) + + Generate network connections from the network description in the recipe. Returns all generated connections on every process. + Does not include connections from the "connections_on" recipe function. diff --git a/doc/python/interconnectivity.rst b/doc/python/interconnectivity.rst index 9b19c52487..28edc9c114 100644 --- a/doc/python/interconnectivity.rst +++ b/doc/python/interconnectivity.rst @@ -171,3 +171,11 @@ Interconnectivity .. attribute:: dict Dictionary for named selecations and values. + + +.. function:: generate_network_connections(recipe, context = None, decomp = None) + + Generate network connections from the network description in the recipe. A distributed context and + domain decomposition can optionally be provided. Only generates connections with local gids in the + domain composition as target. Will return all connections on every process, if no context and domain + decomposition are provided. Does not include connections from the "connections_on" recipe function. diff --git a/example/network_description/network_description.cpp b/example/network_description/network_description.cpp index cb6d5a94b9..8f80814be7 100644 --- a/example/network_description/network_description.cpp +++ b/example/network_description/network_description.cpp @@ -96,9 +96,9 @@ class ring_recipe: public arb::recipe { std::optional network_description() const override { // create a chain - auto ring = arb::network_selection::chain(arb::gid_range(0, num_cells_)); + auto chain = arb::network_selection::chain(arb::gid_range(0, num_cells_)); // connect front and back of chain to form ring - ring = arb::join(ring, + auto ring = arb::join(chain, arb::intersect(arb::network_selection::source_cell({num_cells_ - 1}), arb::network_selection::target_cell({0})));