From 9473049721adb9854afd59035661568321ead587 Mon Sep 17 00:00:00 2001 From: Hiyu Shintani Date: Thu, 20 Jun 2019 16:01:36 +0200 Subject: [PATCH] Array2d / Sparse2d interop --- lib/include/tick/array/array2d.h | 78 ++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/lib/include/tick/array/array2d.h b/lib/include/tick/array/array2d.h index 641ab80fd..e5abb5bfa 100644 --- a/lib/include/tick/array/array2d.h +++ b/lib/include/tick/array/array2d.h @@ -13,6 +13,9 @@ class SArray2d; template class SparseArray2d; +template +class SSparseArray2d; + /*! \class Array2d * \brief Template class for basic non sparse 2d arrays of type `T`. * @@ -51,6 +54,81 @@ class Array2d : public BaseArray2d { using AbstractArray1d2d::is_sparse; using AbstractArray1d2d::init_to_zero; + // implement this properly + explicit Array2d(std::vector> data) { + allocVector2D_Data(data); + } + + Array2d& operator=(std::vector> data) { + allocVector2D_Data(data); + return *this; + } + + void allocVector2D_Data(std::vector> data) { + if (data.size() == 0) { + TICK_ERROR("data empty"); + return; + } + + if (is_data_allocation_owned) + TICK_PYTHON_FREE(_data); + + _n_cols = data[0].size(); + _n_rows = data.size(); + _size = _n_cols * _n_rows; + + is_data_allocation_owned = true; + TICK_PYTHON_MALLOC(_data, T, _size); + + ulong index = 0; + for (std::vector vec : data) { + if (vec.size() != _n_cols) + TICK_ERROR("non consistent column length"); + memcpy(_data + index, vec.data(), _n_cols * sizeof(T)); + index+=_n_cols * sizeof(T) / sizeof(void*); + } + } + + std::shared_ptr> toSSparseArray2dPtr() { + std::vector data; + std::vector row_idx(_n_rows + 1); + std::vector col; + + ulong nnz = 0; + + row_idx[0] = 0; + for (uint r = 0; r < _n_rows; r++) { + int nnz_row = 0; + for (uint c = 0; c < _n_cols; c++) { + if (operator()(r, c) != (T)0) { + T val = operator()(r, c); + nnz++; + nnz_row++; + data.push_back(val); + col.push_back(c); + } + } + row_idx[r + 1] = row_idx[r] + nnz_row; + } + + uint* row_ptr = new uint[row_idx.size()]; + uint* col_ptr = new uint[col.size()]; + T* data_ptr = new T[data.size()]; + + memcpy(row_ptr, row_idx.data(), row_idx.size() * sizeof(uint)); + memcpy(col_ptr, col.data(), col.size() * sizeof(uint)); + memcpy(data_ptr, data.data(), data.size() * sizeof(T)); + + std::shared_ptr> arrayptr = + SSparseArray2d::new_ptr(0, 0, 0); + + arrayptr->set_data_indices_rowindices(data_ptr, col_ptr, row_ptr, _n_rows, _n_cols); + return arrayptr; + } + + + // + //! @brief Constructor for an empty array. Array2d() : BaseArray2d(true) {}