diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index 9091e0b1e2..ed3668598d 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -52,7 +52,7 @@ InitLayerContext::InitLayerContext(
     prefix(prefix_),
     tensor_type(tensor_type_),
     loss_scale(loss_scale_),
-    mode(mode_){
+    mode(mode_) {
   NNTR_THROW_IF(!validate(), std::invalid_argument)
     << "Invalid init context name: " << name
     << " num inputs: " << getNumInputs();
@@ -292,7 +292,6 @@ const Tensor RunLayerContext::getIncomingDerivative(unsigned int idx) const {
   return getOutputGrad(idx);
 }
 
-
 /**
  * @brief Get the Input tensor object
  *
diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp
index 4a2838d05e..f77a012b49 100644
--- a/nntrainer/tensor/manager.cpp
+++ b/nntrainer/tensor/manager.cpp
@@ -455,8 +455,8 @@ std::vector<Weight *> Manager::requestWeights(
    * reduce the memory.
    */
   bool is_wgrad = true;
-  if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm))
-    is_wgrad = false;
+  // if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm))
+  //   is_wgrad = false;
   grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim_g,
                              grad_exec_order, grad_ls,
                              Tensor::Initializer::ZEROS, is_wgrad);
diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp
index 827ba7e979..cfab9bb488 100644
--- a/nntrainer/tensor/tensor.cpp
+++ b/nntrainer/tensor/tensor.cpp
@@ -3319,13 +3319,17 @@ void Tensor::setValue(float val) {
 void Tensor::setZero() {
   if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
     if (contiguous)
-      sscal(size(), 0, getData(), 1);
+      // sscal(size(), 0, getData(), 1);
+      /// @note We cannot use sscal to set zero; if the data contains Inf or
+      /// NaN, the Inf or NaN would still remain after scaling by zero.
+      memset(getData<float>(), 0, sizeof(float) * size());
     else
       apply_i([](float val) -> float { return 0; });
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
     if (contiguous)
-      sscal(size(), 0, getData<_FP16>(), 1);
+      // sscal(size(), 0, getData<_FP16>(), 1);
+      memset(getData<_FP16>(), 0, sizeof(_FP16) * size());
     else
       apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; });
 #else
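
As a standalone illustration of the setZero() change above: under IEEE-754, 0.0f * Inf and 0.0f * NaN both evaluate to NaN, so a BLAS-style sscal(n, 0, x, 1) leaves NaN in the buffer, whereas memset writes the all-zero bit pattern, which is exactly +0.0f. The sketch below only emulates the scal behavior with a plain loop; it is illustrative and not the nntrainer implementation.

// Standalone sketch: scaling by zero does not clear Inf/NaN, memset does.
#include <cstring>
#include <iostream>
#include <limits>

int main() {
  float buf[3] = {1.0f, std::numeric_limits<float>::infinity(),
                  std::numeric_limits<float>::quiet_NaN()};

  // Emulates sscal(3, 0, buf, 1): x[i] = 0 * x[i] for every element.
  for (float &v : buf)
    v *= 0.0f;
  std::cout << buf[0] << ' ' << buf[1] << ' ' << buf[2] << '\n'; // 0 nan nan

  // memset writes the all-zero bit pattern, i.e. +0.0f for IEEE-754 floats.
  std::memset(buf, 0, sizeof(buf));
  std::cout << buf[0] << ' ' << buf[1] << ' ' << buf[2] << '\n'; // 0 0 0
  return 0;
}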