[TensorV2] Add dot product support
This PR adds the ability to compute the dot product between two tensors.

**Changes proposed in this PR:**
- Added a new method dot() to the TensorV2 class that computes the dot product between two tensors.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
djeong20 authored and jijoongmoon committed Feb 23, 2024
1 parent dc8d02d commit ec43650
Showing 8 changed files with 466 additions and 0 deletions.
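For orientation, here is a minimal usage sketch of the new API. The tensor construction below (header name, dimension-list constructor) is assumed for illustration and is not taken from this diff; the `dot(input, output, trans, trans_in, beta)` signature mirrors the `@copydoc TensorV2::dot(...)` references in the changed headers.

```cpp
// Hypothetical usage sketch of TensorV2::dot (constructor form and header
// name are assumptions, not part of this commit).
#include <tensor_v2.h>

void dot_example() {
  nntrainer::TensorV2 a({1, 1, 2, 3}); // logically a 2x3 matrix
  nntrainer::TensorV2 b({1, 1, 3, 4}); // logically a 3x4 matrix
  nntrainer::TensorV2 out;             // sized by dot() when left empty

  // out = a * b, with no transposes and beta = 0 (ignore existing output).
  a.dot(b, out, false, false, 0.0f);
}
```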
60 changes: 60 additions & 0 deletions nntrainer/tensor/float_tensor.cpp
@@ -574,6 +574,66 @@ TensorV2 &FloatTensor::erf(TensorV2 &output) const {
return output;
}

TensorV2 &FloatTensor::dot(TensorV2 const &input, TensorV2 &output, bool trans,
bool trans_in, float beta) const {
// Commented out with the intention to support the calculation w.r.t. the
// batch and height directions. It assumes this->dim is [ BxCxH, W ] and
// input.dim is [ BxCxH, W ] as well.
// if (input.dim.rank() > 2) {
//   throw exception::not_supported("Error: support only for rank of dot "
//                                  "matrix <= 2");
// }

// The strict rank check is relaxed to a warning here, with the intention to
// support the calculation w.r.t. the batch and height directions of this
// tensor. This is fine as long as the input is 2D.
if (trans && dim.rank() > 2) {
ml_logw("Warning: support only for rank of dot matrix <= 2 with trans");
}
unsigned int first_three_flat, last_axis, input_first_three_flat,
input_last_axis, M, N, K, lda, ldb, ldc;

calculateFlattenDot(input, output, trans, trans_in, first_three_flat,
last_axis, input_first_three_flat, input_last_axis, M, N,
K, lda, ldb, ldc);

const float *data = (float *)getData();
const float *mdata = input.getData<float>();
float *rdata = output.getData<float>();
const float alpha = 1.0f;
enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
enum CBLAS_TRANSPOSE transB = trans_in ? CblasTrans : CblasNoTrans;

/// Shortcut handling for the vector case.
/// For a vector, (1 * K) == (K * 1) in the current memory layout.
/// Please note that N, K, and M are fixed placeholders determined after
/// considering the transpose flags.
/// For example, there is no case like (1 * K) X (1 * K), while
/// (1 * K) X (1 * M) can be a case.
/// case1: (1 * K) X (K * 1)
if (M == 1 && N == 1) {
*rdata = sdot(K, data, 1, mdata, 1) + beta * (*rdata);
}
/// case2: (M * K) X (K * 1)
else if (N == 1) {
sgemv(CblasRowMajor, transA, first_three_flat, last_axis, alpha, data, lda,
mdata, 1, beta, rdata, 1);
}
/// case3: (1 * K) X (K * N) = (1 * N) = R
/// R^T = ((K * N)^T) * ((1 * K)^T) = (N * K) * (K * 1) = (N * K) * (1 * K)
/// Effectively an sgemv with the operands swapped and the input's transpose
/// flag flipped.
else if (M == 1) {
transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
sgemv(CblasRowMajor, transB, input_first_three_flat, input_last_axis, alpha,
mdata, ldb, data, 1, beta, rdata, 1);
}
/// other cases: use sgemm
else {
sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, data, lda, mdata, ldb,
beta, rdata, ldc);
}

return output;
}

void FloatTensor::print(std::ostream &out) const {
printInstance(out, this);
const float *data = (float *)getData();
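For readers unfamiliar with the BLAS routines behind the three fast paths above, the standalone sketch below reproduces the same case analysis with raw `cblas_*` calls on made-up data; it is illustrative only and not part of this change (the diff itself goes through nntrainer's `sdot`/`sgemv`/`sgemm` wrappers).

```cpp
// Illustration of the dispatch in FloatTensor::dot using raw CBLAS calls.
// Shapes and values are made up for this sketch.
#include <cblas.h>
#include <cstdio>

int main() {
  // case1: (1 x K) X (K x 1) -> scalar, via a plain dot product
  const float x[3] = {1, 2, 3}, y[3] = {4, 5, 6};
  float r = cblas_sdot(3, x, 1, y, 1); // 1*4 + 2*5 + 3*6 = 32

  // case2: (M x K) X (K x 1) -> (M x 1), via a matrix-vector product
  const float A[2 * 3] = {1, 0, 0, 0, 1, 0}; // 2x3, row-major
  float v[2] = {0, 0};
  cblas_sgemv(CblasRowMajor, CblasNoTrans, 2, 3, 1.0f, A, 3, x, 1, 0.0f, v, 1);

  // other cases: (M x K) X (K x N) -> (M x N), via a matrix-matrix product
  const float B[3 * 2] = {1, 0, 0, 1, 0, 0}; // 3x2, row-major
  float C[2 * 2] = {0, 0, 0, 0};
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, 2, 2, 3, 1.0f, A, 3,
              B, 2, 0.0f, C, 2);

  std::printf("r=%.0f  v=[%.0f %.0f]  C=[%.0f %.0f; %.0f %.0f]\n", r, v[0],
              v[1], C[0], C[1], C[2], C[3]);
  return 0;
}
```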
7 changes: 7 additions & 0 deletions nntrainer/tensor/float_tensor.h
@@ -268,6 +268,13 @@ class FloatTensor : public TensorBase {
*/
TensorV2 &erf(TensorV2 &output) const override;

/**
* @copydoc TensorV2::dot(TensorV2 const &input, TensorV2 &output, bool
* trans, bool trans_in, float beta)
*/
TensorV2 &dot(TensorV2 const &input, TensorV2 &output, bool trans,
bool trans_in, float beta) const override;

/**
* @copydoc TensorV2::copy(const TensorV2 &from)
*/
60 changes: 60 additions & 0 deletions nntrainer/tensor/half_tensor.cpp
@@ -547,6 +547,66 @@ TensorV2 &HalfTensor::erf(TensorV2 &output) const {
return output;
}

TensorV2 &HalfTensor::dot(TensorV2 const &input, TensorV2 &output, bool trans,
bool trans_in, float beta) const {
// Commented out with the intention to support the calculation w.r.t. the
// batch and height directions. It assumes this->dim is [ BxCxH, W ] and
// input.dim is [ BxCxH, W ] as well.
// if (input.dim.rank() > 2) {
//   throw exception::not_supported("Error: support only for rank of dot "
//                                  "matrix <= 2");
// }

// The strict rank check is relaxed to a warning here, with the intention to
// support the calculation w.r.t. the batch and height directions of this
// tensor. This is fine as long as the input is 2D.
if (trans && dim.rank() > 2) {
ml_logw("Warning: support only for rank of dot matrix <= 2 with trans");
}
unsigned int first_three_flat, last_axis, input_first_three_flat,
input_last_axis, M, N, K, lda, ldb, ldc;

calculateFlattenDot(input, output, trans, trans_in, first_three_flat,
last_axis, input_first_three_flat, input_last_axis, M, N,
K, lda, ldb, ldc);

const _FP16 *data = (_FP16 *)getData();
const _FP16 *mdata = input.getData<_FP16>();
_FP16 *rdata = output.getData<_FP16>();
const float alpha = 1.0f;
enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
enum CBLAS_TRANSPOSE transB = trans_in ? CblasTrans : CblasNoTrans;

/// Shortcut handling for the vector case.
/// For a vector, (1 * K) == (K * 1) in the current memory layout.
/// Please note that N, K, and M are fixed placeholders determined after
/// considering the transpose flags.
/// For example, there is no case like (1 * K) X (1 * K), while
/// (1 * K) X (1 * M) can be a case.
/// case1: (1 * K) X (K * 1)
if (M == 1 && N == 1) {
*rdata = sdot(K, data, 1, mdata, 1) + static_cast<_FP16>(beta) * (*rdata);
}
/// case2: (M * K) X (K * 1)
else if (N == 1) {
sgemv(CblasRowMajor, transA, first_three_flat, last_axis, alpha, data, lda,
mdata, 1, beta, rdata, 1);
}
/// case3: (1 * K) X (K * N) = (1 * N) = R
/// R^T = ((K * N)^T) * ((1 * K)^T) = (N * K) * (K * 1) = (N * K) * (1 * K)
/// Effectively an sgemv with the operands swapped and the input's transpose
/// flag flipped.
else if (M == 1) {
transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
sgemv(CblasRowMajor, transB, input_first_three_flat, input_last_axis, alpha,
mdata, ldb, data, 1, beta, rdata, 1);
}
/// other cases: use sgemm
else {
sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, data, lda, mdata, ldb,
beta, rdata, ldc);
}

return output;
}

void HalfTensor::print(std::ostream &out) const {
printInstance(out, this);
const _FP16 *data = (_FP16 *)getData();
7 changes: 7 additions & 0 deletions nntrainer/tensor/half_tensor.h
@@ -267,6 +267,13 @@ class HalfTensor : public TensorBase {
*/
TensorV2 &erf(TensorV2 &output) const override;

/**
* @copydoc TensorV2::dot(TensorV2 const &input, TensorV2 &output, bool
* trans, bool trans_in, float beta)
*/
TensorV2 &dot(TensorV2 const &input, TensorV2 &output, bool trans,
bool trans_in, float beta) const override;

/**
* @copydoc TensorV2::copy(const TensorV2 &from)
*/
88 changes: 88 additions & 0 deletions nntrainer/tensor/tensor_base.cpp
@@ -213,4 +213,92 @@ TensorBase::computeBroadcastInfo(const TensorV2 &m) const {
return e;
}

void TensorBase::calculateFlattenDot(
TensorV2 const &input, TensorV2 &output, bool trans, bool trans_in,
unsigned int &first_three_flat, unsigned int &last_axis,
unsigned int &input_first_three_flat, unsigned int &input_last_axis,
unsigned int &M, unsigned int &N, unsigned int &K, unsigned int &lda,
unsigned int &ldb, unsigned int &ldc) const {

if (trans && dim.rank() > 2) {
ml_logw("Warning: support only for rank of dot matrix <= 2 with trans");
}

if (getFormat() == Tformat::NHWC) {
first_three_flat = batch() * height() * width();
last_axis = channel();
input_first_three_flat = input.batch() * input.height() * input.width();
input_last_axis = input.channel();
} else {
first_three_flat = batch() * channel() * height();
last_axis = width();
input_first_three_flat = input.batch() * input.channel() * input.height();
input_last_axis = input.width();
}

if (!trans && !trans_in) {
if (last_axis != input_first_three_flat)
throw std::runtime_error(
"Error: incompatible dimensions for dot product");
K = input_first_three_flat; /** == last_axis */
N = input_last_axis;
M = first_three_flat;
if (getFormat() == Tformat::NHWC) {
CREATE_V2_IF_EMPTY_DIMS(output, batch(), N, height(), width(),
getTensorType()); // NHWC Result Tensor
} else {
CREATE_V2_IF_EMPTY_DIMS(output, batch(), channel(), height(), N,
getTensorType());
}

// We do not zero-initialize the output for performance reasons. As a
// result, the output may contain garbage values such as NaN. When the
// output is used as C in C = alpha*A*B + beta*C, callers must make sure
// the pre-existing contents of C do not affect the result.

} else if (!trans && trans_in) {
if (last_axis != input_last_axis)
throw std::runtime_error(
"Error: incompatible dimensions for dot product");
K = input_last_axis; /** == last_axis */
N = input_first_three_flat;
M = first_three_flat;
if (getFormat() == Tformat::NHWC) {
CREATE_V2_IF_EMPTY_DIMS(output, batch(), N, height(), width(),
getTensorType());
} else {
CREATE_V2_IF_EMPTY_DIMS(output, batch(), channel(), height(), N,
getTensorType());
}
} else if (trans && !trans_in) {
if (first_three_flat != input_first_three_flat)
throw std::runtime_error(
"Error: incompatible dimensions for dot product");
K = input_first_three_flat; /** == first_three_flat */
N = input_last_axis;
M = last_axis;
if (getFormat() == Tformat::NHWC) {
CREATE_V2_IF_EMPTY_DIMS(output, 1, N, M, 1, getTensorType());
} else {
CREATE_V2_IF_EMPTY_DIMS(output, 1, 1, M, N, getTensorType());
}
} else {
if (first_three_flat != input_last_axis)
throw std::runtime_error(
"Error: incompatible dimensions for dot product");
K = input_last_axis; /** == first_three_flat */
N = input_first_three_flat;
M = last_axis;
if (getFormat() == Tformat::NHWC) {
CREATE_V2_IF_EMPTY_DIMS(output, 1, N, M, 1, getTensorType());
} else {
CREATE_V2_IF_EMPTY_DIMS(output, 1, 1, M, N, getTensorType());
}
}

lda = last_axis;
ldb = input_last_axis;
ldc = (getFormat() == Tformat::NHWC) ? output.channel() : output.width();
}

} // namespace nntrainer
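As a sanity check on the mapping above, here is how the no-transpose case works out for a pair of concrete NCHW shapes. The numbers are chosen purely for illustration and simply mirror the assignments made by `calculateFlattenDot`.

```cpp
// Worked example of calculateFlattenDot's no-transpose mapping (NCHW layout).
// Shapes are illustrative only.
#include <cassert>

int main() {
  // this->dim = [B=2, C=3, H=4, W=5] -> viewed as a (2*3*4) x 5 matrix
  // input.dim = [B=1, C=1, H=5, W=7] -> viewed as a (1*1*5) x 7 matrix
  unsigned int first_three_flat = 2 * 3 * 4, last_axis = 5;
  unsigned int input_first_three_flat = 1 * 1 * 5, input_last_axis = 7;

  // trans == false, trans_in == false
  assert(last_axis == input_first_three_flat); // inner dimensions must match

  unsigned int M = first_three_flat;       // 24: rows of this and the output
  unsigned int N = input_last_axis;        // 7:  columns of input and output
  unsigned int K = input_first_three_flat; // 5:  contracted dimension
  unsigned int lda = last_axis;            // 5:  row-major stride of this
  unsigned int ldb = input_last_axis;      // 7:  row-major stride of input
  unsigned int ldc = N;                    // 7:  row-major stride of output

  assert(M == 24 && N == 7 && K == 5 && lda == 5 && ldb == 7 && ldc == 7);
  return 0;
}
```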
42 changes: 42 additions & 0 deletions nntrainer/tensor/tensor_base.h
@@ -256,6 +256,20 @@ class TensorBase {
*/
virtual TensorV2 &erf(TensorV2 &output) const = 0;

/**
* @brief Dot product of tensors (equivalent to a matrix-matrix multiplication)
* @details This contracts the last dimension of this tensor with the
* second-to-last dimension of the passed tensor input.
* @param[in] input input tensor
* @param[in] output output tensor
* @param[in] trans whether to transpose this tensor
* @param[in] trans_in whether to transpose the input tensor
* @param[in] beta scale applied to the existing output (as in C = A*B + beta*C)
* @retval calculated output tensor
*/
virtual TensorV2 &dot(TensorV2 const &input, TensorV2 &output, bool trans,
bool trans_in, float beta) const = 0;

/**
* @copydoc TensorV2::print(std::ostream &out)
*/
@@ -498,6 +512,34 @@ class TensorBase {
* @return BroadcastInfo Loopinfo needed to run external loop
*/
BroadcastInfoV2 computeBroadcastInfo(const TensorV2 &m) const;

/**
* @brief Calculates the variables needed to perform a flattened tensor dot
* product
*
* @param[in] input input tensor
* @param[in] output output tensor
* @param[in] trans whether to transpose this tensor
* @param[in] trans_in whether to transpose the input tensor
* @param[out] first_three_flat flattened size of this tensor's first 3 axes
* @param[out] last_axis size of this tensor's last axis
* @param[out] input_first_three_flat flattened size of the input's first 3 axes
* @param[out] input_last_axis size of the input's last axis
* @param[out] M number of rows of op(this) and of the output
* @param[out] N number of columns of op(input) and of the output
* @param[out] K number of columns of op(this) and rows of op(input)
* @param[out] lda leading dimension of this
* @param[out] ldb leading dimension of input
* @param[out] ldc leading dimension of output
*
* @note op(X) is one of X or X**T
*/
void calculateFlattenDot(TensorV2 const &input, TensorV2 &output, bool trans,
bool trans_in, unsigned int &first_three_flat,
unsigned int &last_axis,
unsigned int &input_first_three_flat,
unsigned int &input_last_axis, unsigned int &M,
unsigned int &N, unsigned int &K, unsigned int &lda,
unsigned int &ldb, unsigned int &ldc) const;
};

/**