From 63a4aee915b259124e55489e0253476a4eb22d2d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 9 Oct 2023 17:55:07 -0400 Subject: [PATCH] implement dy_dtwo Signed-off-by: Jinzhe Zeng --- source/lib/include/tabulate.h | 2 ++ source/lib/src/gpu/tabulate.cu | 21 ++++++++++++++++++--- source/lib/src/tabulate.cc | 23 +++++++++++++++++++++-- source/lib/tests/test_tabulate_se_a.cc | 15 ++++++++------- source/op/tabulate_multi_device.cc | 26 ++++++++++++++------------ 5 files changed, 63 insertions(+), 24 deletions(-) diff --git a/source/lib/include/tabulate.h b/source/lib/include/tabulate.h index 93992cea5b..52bc06e2d4 100644 --- a/source/lib/include/tabulate.h +++ b/source/lib/include/tabulate.h @@ -18,6 +18,7 @@ void tabulate_fusion_se_a_cpu(FPTYPE* out, template void tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* table_info, const FPTYPE* em_x, @@ -125,6 +126,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out, template void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* table_info, const FPTYPE* em_x, diff --git a/source/lib/src/gpu/tabulate.cu b/source/lib/src/gpu/tabulate.cu index 09d02bdf2c..323f89a26b 100644 --- a/source/lib/src/gpu/tabulate.cu +++ b/source/lib/src/gpu/tabulate.cu @@ -253,6 +253,7 @@ template __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* em_x, const FPTYPE* em, @@ -307,6 +308,7 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; + FPTYPE oldres = res; FPTYPE t; if (enable_se_atten) { t = two_embed[block_idx * nnei * last_layer_size + @@ -330,6 +332,13 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( xx) * xx) * (enable_se_atten ? 
res * t + res : res); + if (enable_se_atten) { + // from ii to ii + (nnei - breakpoint) + for (int ii2 = ii; ii2 < ii + nnei - breakpoint; ii2++) { + dy_dtwo[block_idx * nnei * last_layer_size + ii2 * last_layer_size + + jj] = oldres * res; + } + } } GpuSyncThreads(); for (int kk = 0; kk < MTILE; kk++) { @@ -764,6 +773,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out, template void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* table_info, const FPTYPE* em_x, @@ -781,12 +791,15 @@ void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x, DPErrcheck(gpuDeviceSynchronize()); DPErrcheck(gpuMemset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei)); DPErrcheck(gpuMemset(dy_dem, 0, sizeof(FPTYPE) * nloc * nnei * 4)); + if (two_embed != nullptr) { + DPErrcheck(gpuMemset(dy_dtwo, 0, sizeof(FPTYPE) * nloc * nnei * last_layer_size)); + } tabulate_fusion_se_a_grad_fifth_order_polynomial <<>>( - dy_dem_x, dy_dem, table, em_x, em, two_embed, dy, table_info[0], - table_info[1], table_info[2], table_info[3], table_info[4], nnei, - last_layer_size, is_sorted); + dy_dem_x, dy_dem, dy_dtwo, table, em_x, em, two_embed, dy, + table_info[0], table_info[1], table_info[2], table_info[3], + table_info[4], nnei, last_layer_size, is_sorted); DPErrcheck(gpuGetLastError()); DPErrcheck(gpuDeviceSynchronize()); } @@ -990,6 +1003,7 @@ template void tabulate_fusion_se_a_gpu(double* out, const bool is_sorted); template void tabulate_fusion_se_a_grad_gpu(float* dy_dem_x, float* dy_dem, + float* dy_dtwo, const float* table, const float* table_info, const float* em_x, @@ -1002,6 +1016,7 @@ template void tabulate_fusion_se_a_grad_gpu(float* dy_dem_x, const bool is_sorted); template void tabulate_fusion_se_a_grad_gpu(double* dy_dem_x, double* dy_dem, + double* dy_dtwo, const double* table, const double* table_info, const double* em_x, diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc index 1f49cf0daa..6c2068b262 100644 --- 
a/source/lib/src/tabulate.cc +++ b/source/lib/src/tabulate.cc @@ -158,6 +158,7 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out, template void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* table_info, const FPTYPE* em_x, @@ -171,6 +172,9 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, bool enable_se_atten = two_embed != nullptr; memset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei); memset(dy_dem, 0, sizeof(FPTYPE) * nloc * nnei * 4); + if (enable_se_atten) { + memset(dy_dtwo, 0, sizeof(FPTYPE) * nloc * nnei * last_layer_size); + } FPTYPE const lower = table_info[0]; FPTYPE const upper = table_info[1]; FPTYPE const _max = table_info[2]; @@ -212,6 +216,7 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; FPTYPE g = (a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx); + FPTYPE resold = res; if (enable_se_atten) { FPTYPE t = two_embed[ii * nnei * last_layer_size + jj * last_layer_size + kk]; @@ -219,18 +224,30 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, g += t * g; } + FPTYPE dotllrr = dot(ll, rr); if (unloop) { - grad += g * dot(ll, rr) * (nnei - jj); + grad += g * dotllrr * (nnei - jj); dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0] * (nnei - jj); dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1] * (nnei - jj); dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2] * (nnei - jj); dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3] * (nnei - jj); + if (enable_se_atten) { + // fill from jj to nnei + for (int jj2 = jj; jj2 < nnei; jj2++) { + dy_dtwo[ii * nnei * last_layer_size + jj2 * last_layer_size + + kk] += resold * dotllrr; + } + } } else { - grad += g * dot(ll, rr); + grad += g * dotllrr; dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0]; dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1]; dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2]; dy_dem[ii * nnei * 4 + jj * 4 
+ 3] += res * rr[3]; + if (enable_se_atten) { + dy_dtwo[ii * nnei * last_layer_size + jj * last_layer_size + kk] += + resold * dotllrr; + } } } dy_dem_x[ii * nnei + jj] = grad; @@ -660,6 +677,7 @@ template void deepmd::tabulate_fusion_se_a_cpu( template void deepmd::tabulate_fusion_se_a_grad_cpu( float* dy_dem_x, float* dy_dem, + float* dy_dtwo, const float* table, const float* table_info, const float* em_x, @@ -673,6 +691,7 @@ template void deepmd::tabulate_fusion_se_a_grad_cpu( double* dy_dem_x, double* dy_dem, + double* dy_dtwo, const double* table, const double* table_info, const double* em_x, diff --git a/source/lib/tests/test_tabulate_se_a.cc b/source/lib/tests/test_tabulate_se_a.cc index fc0fd04980..b272740584 100644 --- a/source/lib/tests/test_tabulate_se_a.cc +++ b/source/lib/tests/test_tabulate_se_a.cc @@ -726,9 +726,10 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) { std::vector dy_dem_x(em_x.size()); std::vector dy_dem(em.size()); std::vector dy(nloc * nnei * last_layer_size, 1.0); + FPTYPE *dy_dtwo = nullptr; deepmd::tabulate_fusion_se_a_grad_cpu( - &dy_dem_x[0], &dy_dem[0], &table[0], &info[0], &em_x[0], &em[0], nullptr, - &dy[0], nloc, nnei, last_layer_size); + &dy_dem_x[0], &dy_dem[0], dy_dtwo, &table[0], &info[0], &em_x[0], &em[0], + nullptr, &dy[0], nloc, nnei, last_layer_size); EXPECT_EQ(dy_dem_x.size(), nloc * nnei); EXPECT_EQ(dy_dem.size(), nloc * nnei * 4); EXPECT_EQ(dy_dem_x.size(), expected_dy_dem_x.size()); @@ -741,7 +742,7 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) { } deepmd::tabulate_fusion_se_a_grad_cpu( - &dy_dem_x[0], &dy_dem[0], &table[0], &info[0], &em_x[0], &em[0], + &dy_dem_x[0], &dy_dem[0], dy_dtwo, &table[0], &info[0], &em_x[0], &em[0], &two_embed[0], &dy[0], nloc, nnei, last_layer_size); EXPECT_EQ(dy_dem_x.size(), nloc * nnei); EXPECT_EQ(dy_dem.size(), nloc * nnei * 4); @@ -804,7 +805,7 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) { 
std::vector dy(nloc * nnei * last_layer_size, 1.0); double *dy_dem_x_dev = NULL, *dy_dem_dev = NULL, *table_dev = NULL, - *em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL; + *em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL, *dy_dtwo = nullptr; deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); deepmd::malloc_device_memory_sync(table_dev, table); @@ -812,8 +813,8 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) { deepmd::malloc_device_memory_sync(em_dev, em); deepmd::malloc_device_memory_sync(dy_dev, dy); deepmd::tabulate_fusion_se_a_grad_gpu( - dy_dem_x_dev, dy_dem_dev, table_dev, &info[0], em_x_dev, em_dev, nullptr, - dy_dev, nloc, nnei, last_layer_size); + dy_dem_x_dev, dy_dem_dev, dy_dtwo, table_dev, &info[0], em_x_dev, em_dev, + nullptr, dy_dev, nloc, nnei, last_layer_size); deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem); @@ -833,7 +834,7 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) { deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); deepmd::tabulate_fusion_se_a_grad_gpu( - dy_dem_x_dev, dy_dem_dev, table_dev, &info[0], em_x_dev, em_dev, + dy_dem_x_dev, dy_dem_dev, dy_dtwo, table_dev, &info[0], em_x_dev, em_dev, two_embed_dev, dy_dev, nloc, nnei, last_layer_size); deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem); diff --git a/source/op/tabulate_multi_device.cc b/source/op/tabulate_multi_device.cc index 488a99bd7d..ceac954d0d 100644 --- a/source/op/tabulate_multi_device.cc +++ b/source/op/tabulate_multi_device.cc @@ -261,6 +261,7 @@ class TabulateFusionSeAGradOp : public OpKernel { // flat the tensors FPTYPE* dy_dem_x = dy_dem_x_tensor->flat().data(); FPTYPE* dy_dem = dy_dem_tensor->flat().data(); + FPTYPE* dy_dtwo = nullptr; const FPTYPE* descriptor = descriptor_tensor.flat().data(); const FPTYPE* table 
= table_tensor.flat().data(); @@ -275,14 +276,14 @@ class TabulateFusionSeAGradOp : public OpKernel { if (device == "GPU") { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, table, table_info, - em_x, em, two_embed, dy, nloc, nnei, - last_layer_size); + deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, dy_dtwo, table, + table_info, em_x, em, two_embed, dy, + nloc, nnei, last_layer_size); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } else if (device == "CPU") { - deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, table, table_info, - em_x, em, two_embed, dy, nloc, nnei, - last_layer_size); + deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, dy_dtwo, table, + table_info, em_x, em, two_embed, dy, + nloc, nnei, last_layer_size); } } @@ -468,6 +469,7 @@ class TabulateFusionSeAttenGradOp : public OpKernel { // flat the tensors FPTYPE* dy_dem_x = dy_dem_x_tensor->flat().data(); FPTYPE* dy_dem = dy_dem_tensor->flat().data(); + FPTYPE* dy_dtwo = dy_dtwo_tensor->flat().data(); const FPTYPE* descriptor = descriptor_tensor.flat().data(); const FPTYPE* table = table_tensor.flat().data(); @@ -482,14 +484,14 @@ class TabulateFusionSeAttenGradOp : public OpKernel { if (device == "GPU") { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, table, table_info, - em_x, em, two_embed, dy, nloc, nnei, - last_layer_size, is_sorted); + deepmd::tabulate_fusion_se_a_grad_gpu( + dy_dem_x, dy_dem, dy_dtwo, table, table_info, em_x, em, two_embed, dy, + nloc, nnei, last_layer_size, is_sorted); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } else if (device == "CPU") { - deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, table, table_info, - em_x, em, two_embed, dy, nloc, nnei, - last_layer_size, is_sorted); + deepmd::tabulate_fusion_se_a_grad_cpu( + dy_dem_x, dy_dem, dy_dtwo, table, table_info, em_x, em, two_embed, dy, + nloc, nnei, last_layer_size, is_sorted); } }