Skip to content

Commit

Permalink
Merge pull request #420 from mtao/mtao/default_attribute_initialization
Browse files Browse the repository at this point in the history
Default attribute initialization
  • Loading branch information
mtao authored Oct 5, 2023
2 parents b93f12a + c4c8422 commit 3d77912
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 23 deletions.
10 changes: 5 additions & 5 deletions src/wmtk/Mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ void Mesh::serialize(MeshWriter& writer)

template <typename T>
MeshAttributeHandle<T>
Mesh::register_attribute(const std::string& name, PrimitiveType ptype, long size, bool replace)
Mesh::register_attribute(const std::string& name, PrimitiveType ptype, long size, bool replace, T default_value)
{
return m_attribute_manager.register_attribute<T>(name, ptype, size, replace);
return m_attribute_manager.register_attribute<T>(name, ptype, size, replace, default_value);
}

std::vector<long> Mesh::request_simplex_indices(PrimitiveType type, long count)
Expand Down Expand Up @@ -250,11 +250,11 @@ attribute::AttributeScopeHandle Mesh::create_scope()
}

template MeshAttributeHandle<char>
Mesh::register_attribute(const std::string&, PrimitiveType, long, bool);
Mesh::register_attribute(const std::string&, PrimitiveType, long, bool, char);
template MeshAttributeHandle<long>
Mesh::register_attribute(const std::string&, PrimitiveType, long, bool);
Mesh::register_attribute(const std::string&, PrimitiveType, long, bool, long);
template MeshAttributeHandle<double>
Mesh::register_attribute(const std::string&, PrimitiveType, long, bool);
Mesh::register_attribute(const std::string&, PrimitiveType, long, bool, double);

Tuple Mesh::switch_tuples(
const Tuple& tuple,
Expand Down
4 changes: 3 additions & 1 deletion src/wmtk/Mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class Mesh : public std::enable_shared_from_this<Mesh>
const std::string& name,
PrimitiveType type,
long size,
bool replace = false);
bool replace = false,
T default_value = T(0)
);

template <typename T>
MeshAttributeHandle<T> get_attribute_handle(
Expand Down
14 changes: 5 additions & 9 deletions src/wmtk/attribute/Attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,16 @@ void Attribute<T>::serialize(const std::string& name, const int dim, MeshWriter&
writer.write(name, dim, dimension(), m_data);
}


template <typename T>
Attribute<T>::Attribute(long dimension)
Attribute<T>::Attribute(long dimension, T default_value, long size)
: m_scope_stacks(new PerThreadAttributeScopeStacks<T>())
, m_dimension(dimension)
, m_default_value(default_value)
{
assert(m_dimension > 0);
}

template <typename T>
Attribute<T>::Attribute(long dimension, long size)
: Attribute(dimension)
{
if (size > 0) {
m_data = std::vector<T>(size * dimension, T(0));
m_data = std::vector<T>(size * dimension, m_default_value);
}
}

Expand Down Expand Up @@ -66,7 +62,7 @@ template <typename T>
void Attribute<T>::reserve(const long size)
{
if (size > (m_data.size() / m_dimension)) {
m_data.resize(m_dimension * size, T(0));
m_data.resize(m_dimension * size, m_default_value);
}
}
template <typename T>
Expand Down
4 changes: 2 additions & 2 deletions src/wmtk/attribute/Attribute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class Attribute
void serialize(const std::string& name, const int dim, MeshWriter& writer) const;

// if size < 0 then the internal data is not initialized
Attribute(long dimension, long size);
Attribute(long dimension);
Attribute(long dimension, T default_value = T(0), long size = 0);

Attribute(const Attribute& o);
Attribute(Attribute&& o);
Expand Down Expand Up @@ -67,6 +66,7 @@ class Attribute
std::vector<T> m_data;
std::unique_ptr<PerThreadAttributeScopeStacks<T>> m_scope_stacks;
long m_dimension = -1;
T m_default_value = T(0);
};
} // namespace attribute
} // namespace wmtk
10 changes: 7 additions & 3 deletions src/wmtk/attribute/AttributeManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ struct AttributeManager
const std::string& name,
PrimitiveType type,
long size,
bool replace = false);
bool replace = false,
T default_value = T(0)
);
template <typename T>
MeshAttributes<T>& get(PrimitiveType ptype);

Expand Down Expand Up @@ -121,14 +123,16 @@ MeshAttributeHandle<T> AttributeManager::register_attribute(
const std::string& name,
PrimitiveType ptype,
long size,
bool replace)
bool replace,
T default_value
)
{
// return MeshAttributeHandle<T>{
// .m_base_handle = get_mesh_attributes<T>(ptype).register_attribute(name, size),
// .m_primitive_type = ptype};

MeshAttributeHandle<T> r;
r.m_base_handle = get<T>(ptype).register_attribute(name, size, replace),
r.m_base_handle = get<T>(ptype).register_attribute(name, size, replace, default_value),
r.m_primitive_type = ptype;
return r;
}
Expand Down
4 changes: 2 additions & 2 deletions src/wmtk/attribute/MeshAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void MeshAttributes<T>::clear_current_scope()
}
template <typename T>
AttributeHandle
MeshAttributes<T>::register_attribute(const std::string& name, long dimension, bool replace)
MeshAttributes<T>::register_attribute(const std::string& name, long dimension, bool replace, T default_value)
{
assert(replace || m_handles.find(name) == m_handles.end());

Expand All @@ -66,7 +66,7 @@ MeshAttributes<T>::register_attribute(const std::string& name, long dimension, b
handle.index = it->second.index;
} else {
handle.index = m_attributes.size();
m_attributes.emplace_back(dimension, reserved_size());
m_attributes.emplace_back(dimension, default_value, reserved_size());
}
m_handles[name] = handle;

Expand Down
2 changes: 1 addition & 1 deletion src/wmtk/attribute/MeshAttributes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class MeshAttributes
void serialize(const int dim, MeshWriter& writer) const;

[[nodiscard]] AttributeHandle
register_attribute(const std::string& name, long dimension, bool replace = false);
register_attribute(const std::string& name, long dimension, bool replace = false, T default_value = T(0));

long reserved_size() const;
void reserve(const long size);
Expand Down
17 changes: 17 additions & 0 deletions tests/test_accessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,26 @@ TEST_CASE("test_accessor_basic")
auto long_handle = m.register_attribute<long>("long", wmtk::PrimitiveType::Vertex, 1);
auto double_handle = m.register_attribute<double>("double", wmtk::PrimitiveType::Vertex, 3);

auto char_def1_handle = m.register_attribute<char>("char1", wmtk::PrimitiveType::Vertex, 1, false, 1);
auto long_def1_handle = m.register_attribute<long>("long1", wmtk::PrimitiveType::Vertex, 1, false, 1);
auto double_def1_handle = m.register_attribute<double>("double1", wmtk::PrimitiveType::Vertex, 3, false, 1);

REQUIRE(m.get_attribute_dimension(char_handle) == 1);
REQUIRE(m.get_attribute_dimension(long_handle) == 1);
REQUIRE(m.get_attribute_dimension(double_handle) == 3);

REQUIRE(m.get_attribute_dimension(char_def1_handle) == 1);
REQUIRE(m.get_attribute_dimension(long_def1_handle) == 1);
REQUIRE(m.get_attribute_dimension(double_def1_handle) == 3);

auto char_acc = m.create_accessor(char_handle);
auto long_acc = m.create_accessor(long_handle);
auto double_acc = m.create_accessor(double_handle);

auto char_def1_acc = m.create_accessor(char_def1_handle);
auto long_def1_acc = m.create_accessor(long_def1_handle);
auto double_def1_acc = m.create_accessor(double_def1_handle);

auto char_bacc = m.create_base_accessor(char_handle);
auto long_bacc = m.create_base_accessor(long_handle);
auto double_bacc = m.create_base_accessor(double_handle);
Expand All @@ -87,6 +99,11 @@ TEST_CASE("test_accessor_basic")
CHECK(char_acc.const_scalar_attribute(tup) == 0);
CHECK(long_acc.const_scalar_attribute(tup) == 0);
CHECK((double_acc.const_vector_attribute(tup).array() == 0).all());

// checking that default initialization of 1 worked
CHECK(char_def1_acc.const_scalar_attribute(tup) == 1);
CHECK(long_def1_acc.const_scalar_attribute(tup) == 1);
CHECK((double_def1_acc.const_vector_attribute(tup).array() == 1).all());
}

// use global set to force all values
Expand Down

0 comments on commit 3d77912

Please sign in to comment.