Skip to content

Commit

Permalink
Array2d / Sparse2d interop
Browse files Browse the repository at this point in the history
  • Loading branch information
andro2157 authored and PhilipDeegan committed Jul 6, 2019
1 parent dcfe2d9 commit c933478
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions lib/include/tick/array/array2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class SArray2d;
template <typename T, typename MAJ = RowMajor>
class SparseArray2d;

template <typename T, typename MAJ>
class SSparseArray2d;

/*! \class Array2d
* \brief Template class for basic non sparse 2d arrays of type `T`.
*
Expand Down Expand Up @@ -51,6 +54,81 @@ class Array2d : public BaseArray2d<T, MAJ> {
using AbstractArray1d2d<T, MAJ>::is_sparse;
using AbstractArray1d2d<T, MAJ>::init_to_zero;

// implement this properly
explicit Array2d(std::vector<std::vector<T>> data) {
allocVector2D_Data(data);
}

Array2d<T>& operator=(std::vector<std::vector<T>> data) {
allocVector2D_Data(data);
return *this;
}

void allocVector2D_Data(std::vector<std::vector<T>> data) {
if (data.size() == 0) {
TICK_ERROR("data empty");
return;
}

if (is_data_allocation_owned)
TICK_PYTHON_FREE(_data);

_n_cols = data[0].size();
_n_rows = data.size();
_size = _n_cols * _n_rows;

is_data_allocation_owned = true;
TICK_PYTHON_MALLOC(_data, T, _size);

ulong index = 0;
for (std::vector<T> vec : data) {
if (vec.size() != _n_cols)
TICK_ERROR("non consistent column length");
memcpy(_data + index, vec.data(), _n_cols * sizeof(T));
index+=_n_cols * sizeof(T) / sizeof(void*);
}
}

std::shared_ptr<SSparseArray2d<T, MAJ>> toSSparseArray2dPtr() {
std::vector<T> data;
std::vector<uint> row_idx(_n_rows + 1);
std::vector<uint> col;

ulong nnz = 0;

row_idx[0] = 0;
for (uint r = 0; r < _n_rows; r++) {
int nnz_row = 0;
for (uint c = 0; c < _n_cols; c++) {
if (operator()(r, c) != (T)0) {
T val = operator()(r, c);
nnz++;
nnz_row++;
data.push_back(val);
col.push_back(c);
}
}
row_idx[r + 1] = row_idx[r] + nnz_row;
}

uint* row_ptr = new uint[row_idx.size()];
uint* col_ptr = new uint[col.size()];
T* data_ptr = new T[data.size()];

memcpy(row_ptr, row_idx.data(), row_idx.size() * sizeof(uint));
memcpy(col_ptr, col.data(), col.size() * sizeof(uint));
memcpy(data_ptr, data.data(), data.size() * sizeof(T));

std::shared_ptr<SSparseArray2d<T, MAJ>> arrayptr =
SSparseArray2d<T, MAJ>::new_ptr(0, 0, 0);

arrayptr->set_data_indices_rowindices(data_ptr, col_ptr, row_ptr, _n_rows, _n_cols);
return arrayptr;
}


//

//! @brief Constructor for an empty array.
Array2d() : BaseArray2d<T, MAJ>(true) {}

Expand Down

0 comments on commit c933478

Please sign in to comment.