From ebbdbd9f9960d32534e8be2e82e6dd86631d4ed1 Mon Sep 17 00:00:00 2001
From: fandaoyi
Date: Thu, 7 Dec 2023 11:05:56 +0800
Subject: [PATCH] fix test err on 2.1

---
 .../autogen_diopi_wrapper/diopi_functions.yaml     |  8 ++++----
 .../individual_scripts/test_op_benchmark.py        |  5 +++--
 dipu/tests/run_nv_tests.sh                         | 17 ++++++++++++-----
 .../CustomFallbackFunctionsForAmpGradScaler.cpp    |  7 ++++---
 .../csrc_dipu/runtime/distributed/c10dOps.cpp      |  3 ++-
 .../csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp    |  2 +-
 6 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 8ae9f31ca..9b5258cc2 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -643,18 +643,18 @@
     ::diopiSize_t diopi_size = toDiopiSize(dim);
   interface: diopiMean(ctx, out, self_dtype_diopi, diopi_size);
 
-- schema: "std.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor"
+- schema: "std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor"
   custom_code_at_the_beginning: |
     std::vector<int64_t> output_shape = infer_reduce_op_shape(self.sizes(), dim.value_or(std::vector<int64_t>()), keepdim);
     auto out = at::empty(output_shape, self.options());
-    bool unbiased = correction.value_or(1) == 1;
+    bool unbiased = correction.value_or(1).toLong() == 1;
     ::diopiSize_t diopi_size = toDiopiSize(dim);
   interface: diopiStd(ctx, out, self, diopi_size, unbiased);
 
-- schema: "std.correction_out(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)"
+- schema: "std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)"
   custom_code_at_the_beginning: |
     ::diopiSize_t diopi_size = toDiopiSize(dim);
-    bool unbiased = correction.value_or(1) == 1;
+    bool unbiased = correction.value_or(1).toLong() == 1;
   interface: diopiStd(ctx, out, self, diopi_size, unbiased);
 
 - schema: "linear_backward(Tensor input, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)"
diff --git a/dipu/tests/python/individual_scripts/test_op_benchmark.py b/dipu/tests/python/individual_scripts/test_op_benchmark.py
index b34f409b2..f1aaa5dba 100644
--- a/dipu/tests/python/individual_scripts/test_op_benchmark.py
+++ b/dipu/tests/python/individual_scripts/test_op_benchmark.py
@@ -58,8 +58,9 @@ def batched_dot_bmm(a, b):
     # description is the column
     label = "Batched dot"
     sub_label = f"[{b}, {n}]"
-    x = torch.ones((b, n))
-    for num_threads in [1, 4, 16, 32]:
+    x = torch.ones((b, n)).cuda()
+    # cuda tensor: real workloads do not use this many dispatch threads
+    for num_threads in [1, 4]:
         results.append(
             benchmark.Timer(
                 stmt="batched_dot_mul_sum(x, x)",
diff --git a/dipu/tests/run_nv_tests.sh b/dipu/tests/run_nv_tests.sh
index c1c15c224..2f08e6a54 100644
--- a/dipu/tests/run_nv_tests.sh
+++ b/dipu/tests/run_nv_tests.sh
@@ -7,8 +7,15 @@ function run_dipu_tests {
   unset DIPU_DUMP_OP_ARGS
   export PYTHONPATH=${DIPU_ROOT}/../:${PYTHONPATH}
   ${CDIR}/python/run_tests.sh
+  echo "fill_.Scalar" >> .dipu_force_fallback_op_list.config
+  run_test "${PYTORCH_DIR}/test/test_tensor_creation_ops.py" "$@" -v -f TestTensorCreationDIPU # --locals -f
+  echo "" > .dipu_force_fallback_op_list.config
+  # run_test "${PYTORCH_DIR}/test/test_reductions.py" "$@" -v -f TestReductionsDIPU
+  run_test "${PYTORCH_TEST_DIR}/test/nn/test_convolution.py" -v TestConvolutionNNDeviceTypeDIPU
 
   # run_test "${PYTORCH_TEST_DIR}/test/test_linalg.py" "$@" -v TestLinalgDIPU
+
+  # mock cuda causes a test-count error; temporarily skipped
   # run_test "${PYTORCH_TEST_DIR}/test/test_testing.py" "$@" -v TestTestParametrizationDeviceTypeDIPU TestTestingDIPU
   run_test "${PYTORCH_TEST_DIR}/test/test_type_hints.py" "$@" -v
   run_test "${PYTORCH_TEST_DIR}/test/test_type_info.py" "$@" -v
@@ -17,14 +24,14 @@ function run_dipu_tests {
   # run_test "${PYTORCH_TEST_DIR}/test/test_binary_ufuncs.py" "$@" -v TestBinaryUfuncsDIPU
   # run_test "${PYTORCH_TEST_DIR}/test/test_torch.py" "$@" -v TestTorchDeviceTypeDIPU #--subprocess
   #run_test "${PYTORCH_TEST_DIR}/test/test_indexing.py" "$@" -v TestIndexingDIPU
-  #run_test "${PYTORCH_TEST_DIR}/test/test_indexing.py" "$@" -v NumpyTestsDIPU
-  # run_test "${PYTORCH_TEST_DIR}/test/test_view_ops.py" "$@" -v TestViewOpsDIPU
+  run_test "${PYTORCH_TEST_DIR}/test/test_indexing.py" "$@" -v NumpyTestsDIPU
+  run_test "${PYTORCH_TEST_DIR}/test/test_view_ops.py" "$@" -v TestViewOpsDIPU
   # run_test "${PYTORCH_TEST_DIR}/test/test_type_promotion.py" "$@" -v TestTypePromotionDIPU
   # run_test "${PYTORCH_TEST_DIR}/test/test_nn.py" "$@" -v TestNN
-  # run_test "${PYTORCH_TEST_DIR}/test/test_ops_fwd_gradients.py" "$@" -v TestFwdGradientsDIPU
-  # run_test "${PYTORCH_TEST_DIR}/test/test_ops_gradients.py" "$@" -v TestBwdGradientsDIPU
+  run_test "${PYTORCH_TEST_DIR}/test/test_ops_fwd_gradients.py" "$@" -v TestFwdGradientsDIPU
+  run_test "${PYTORCH_TEST_DIR}/test/test_ops_gradients.py" "$@" -v TestBwdGradientsDIPU
   # run_test "${PYTORCH_TEST_DIR}/test/test_ops.py" "$@" -v
-  # run_test "${PYTORCH_TEST_DIR}/test/test_shape_ops.py" "$@" -v TestShapeOpsDIPU
+  run_test "${PYTORCH_TEST_DIR}/test/test_shape_ops.py" "$@" -v TestShapeOpsDIPU
 }
 
 if [ "$LOGFILE" != "" ]; then
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForAmpGradScaler.cpp b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForAmpGradScaler.cpp
index 2514e1e16..9c94370e6 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForAmpGradScaler.cpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/CustomFallbackFunctionsForAmpGradScaler.cpp
@@ -87,16 +87,17 @@ at::Tensor& custom_fallback_dipu__amp_update_scale_(at::Tensor& current_scale,
               "found_inf must be a float tensor.");
   if (static_cast<bool>(found_inf.item<float>())) {
     current_scale *= backoff_factor;
-    growth_tracker[0] = 0;
+    growth_tracker.fill_(c10::Scalar(0));
   } else {
     // Entering this branch means we just carried out a successful step,
     // so growth_tracker is incremented before comparing to growth_interval.
     auto successful = growth_tracker.item<int>() + 1;
     if (successful == growth_interval) {
       current_scale *= growth_factor;
-      growth_tracker[0] = 0;
+      growth_tracker.fill_(c10::Scalar(0));
     } else {
-      growth_tracker[0] = successful;
+      // growth_tracker is a scalar tensor in torch 2.1; in 2.0 it is a size=1 tensor.
+      growth_tracker.fill_(c10::Scalar(successful));
     }
   }
   return current_scale;
diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/c10dOps.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/c10dOps.cpp
index bdaf28fe7..b380438e7 100644
--- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/c10dOps.cpp
+++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/c10dOps.cpp
@@ -62,7 +62,8 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_dipu_(
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_dipu_(
     at::TensorList tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
-    const c10::intrusive_ptr<ReduceOp>& reduce_op, int64_t timeout) {
+    const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    const c10::optional<at::Tensor>& sparse_indices, int64_t timeout) {
   auto tensor_vec = tensors.vec();
   auto work =
       process_group->getBackend(dipu::DIPU_DEVICE_TYPE)
diff --git a/dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp
index bd65992ff..31ca91c0d 100644
--- a/dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp
+++ b/dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp
@@ -5,7 +5,7 @@
 
 namespace dipu {
 
-static const size_t states_size = 200 * sizeof(4120);
+static const size_t states_size = 0;  // 200 * sizeof(4120);
 static const size_t seed_size = sizeof(uint64_t);
 static const size_t offset_size = sizeof(int64_t);
 static const size_t total_size = states_size + seed_size + offset_size;