implement dy_dtwo
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz committed Oct 9, 2023
1 parent 39bc5c7 commit 63a4aee
Showing 5 changed files with 63 additions and 24 deletions.
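
A note on what the new output stores. With se_atten enabled, the kernels below rescale the tabulated polynomial value as res = res * t + res, where t is the two-body embedding. Under the usual se_a contraction y_{imk} = Σ_j e_{ijm} · r̃_{ijk} (my indices, not the source's notation: i atom, j neighbor, m embedding row 0..3, k last-layer column), the gradient this commit routes into dy_dtwo is

    \tilde{r}_{ijk} = r_{ijk}\,(1 + t_{ijk})
    \qquad\Longrightarrow\qquad
    \frac{\partial L}{\partial t_{ijk}}
      = r_{ijk} \sum_{m=0}^{3} \frac{\partial L}{\partial y_{imk}}\, e_{ijm}

i.e. the pre-attention value resold times dot(ll, rr), which is exactly the accumulation added to the CPU path in tabulate.cc below.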
2 changes: 2 additions & 0 deletions source/lib/include/tabulate.h
@@ -18,6 +18,7 @@ void tabulate_fusion_se_a_cpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -125,6 +126,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
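
A calling-convention note the header alone doesn't state: dy_dtwo is only zeroed and written when two_embed is non-null (see the guards added in the implementations below), so non-attention callers keep passing null for both. A minimal sketch of such a call; the wrapper function and its argument names are illustrative, not from the repository:

#include <vector>
#include "tabulate.h"

// Sketch: se_a gradient without attention. dy_dtwo and two_embed stay null,
// which keeps the new code path inert (see the enable_se_atten guards below).
void grad_cpu_no_attention(const double* table, const double* table_info,
                           const double* em_x, const double* em,
                           const double* dy, int nloc, int nnei,
                           int last_layer_size) {
  std::vector<double> dy_dem_x(nloc * nnei), dy_dem(nloc * nnei * 4);
  deepmd::tabulate_fusion_se_a_grad_cpu<double>(
      dy_dem_x.data(), dy_dem.data(), /*dy_dtwo=*/nullptr, table, table_info,
      em_x, em, /*two_embed=*/nullptr, dy, nloc, nnei, last_layer_size);
}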
21 changes: 18 additions & 3 deletions source/lib/src/gpu/tabulate.cu
@@ -253,6 +253,7 @@ template <typename FPTYPE, int MTILE, int KTILE>
__global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* em_x,
const FPTYPE* em,
@@ -307,6 +308,7 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
(var[1] +
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
xx;
FPTYPE oldres = res;
FPTYPE t;
if (enable_se_atten) {
t = two_embed[block_idx * nnei * last_layer_size +
@@ -330,6 +332,13 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
xx) *
xx) *
(enable_se_atten ? res * t + res : res);
if (enable_se_atten) {
// from ii to ii + (nnei - breakpoint)
for (int ii2 = ii; ii2 < ii + nnei - breakpoint; ii2++) {
dy_dtwo[block_idx * nnei * last_layer_size + ii2 * last_layer_size +
jj] = oldres * res;
}
}
}
GpuSyncThreads();
for (int kk = 0; kk < MTILE; kk++) {
@@ -764,6 +773,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -781,12 +791,15 @@ void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
DPErrcheck(gpuDeviceSynchronize());
DPErrcheck(gpuMemset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei));
DPErrcheck(gpuMemset(dy_dem, 0, sizeof(FPTYPE) * nloc * nnei * 4));
if (two_embed != nullptr) {
DPErrcheck(gpuMemset(dy_dtwo, 0, sizeof(FPTYPE) * nloc * nnei * last_layer_size));
}

tabulate_fusion_se_a_grad_fifth_order_polynomial<FPTYPE, MM, KK>
<<<nloc, KK * WARP_SIZE, sizeof(FPTYPE) * MM * last_layer_size>>>(
dy_dem_x, dy_dem, table, em_x, em, two_embed, dy, table_info[0],
table_info[1], table_info[2], table_info[3], table_info[4], nnei,
last_layer_size, is_sorted);
dy_dem_x, dy_dem, dy_dtwo, table, em_x, em, two_embed, dy,
table_info[0], table_info[1], table_info[2], table_info[3],
table_info[4], nnei, last_layer_size, is_sorted);
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());
}
@@ -990,6 +1003,7 @@ template void tabulate_fusion_se_a_gpu<double>(double* out,
const bool is_sorted);
template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
float* dy_dem,
float* dy_dtwo,
const float* table,
const float* table_info,
const float* em_x,
@@ -1002,6 +1016,7 @@ template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
const bool is_sorted);
template void tabulate_fusion_se_a_grad_gpu<double>(double* dy_dem_x,
double* dy_dem,
double* dy_dtwo,
const double* table,
const double* table_info,
const double* em_x,
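
On the GPU side the wrapper now zeroes dy_dtwo and forwards it to the kernel, so attention-path callers must hand it a device buffer of nloc * nnei * last_layer_size elements. A sketch built from the same helpers the tests below use (malloc_device_memory_sync, memcpy_device_to_host, delete_device_memory); the *_dev names mirror those tests, and the buffer size is my reading of the indexing above:

#include <vector>
#include "device.h"
#include "tabulate.h"

// Sketch: drive the attention-enabled GPU gradient with a real dy_dtwo
// buffer. All *_dev pointers are assumed to be device memory prepared as in
// the tests below.
void grad_gpu_with_attention(double* dy_dem_x_dev, double* dy_dem_dev,
                             const double* table_dev, const double* info,
                             const double* em_x_dev, const double* em_dev,
                             const double* two_embed_dev, const double* dy_dev,
                             int nloc, int nnei, int last_layer_size) {
  std::vector<double> dy_dtwo_host(nloc * nnei * last_layer_size, 0.0);
  double* dy_dtwo_dev = nullptr;
  deepmd::malloc_device_memory_sync(dy_dtwo_dev, dy_dtwo_host);
  deepmd::tabulate_fusion_se_a_grad_gpu<double>(
      dy_dem_x_dev, dy_dem_dev, dy_dtwo_dev, table_dev, info, em_x_dev,
      em_dev, two_embed_dev, dy_dev, nloc, nnei, last_layer_size);
  deepmd::memcpy_device_to_host(dy_dtwo_dev, dy_dtwo_host);
  deepmd::delete_device_memory(dy_dtwo_dev);
}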
23 changes: 21 additions & 2 deletions source/lib/src/tabulate.cc
@@ -158,6 +158,7 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out,
template <typename FPTYPE>
void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -171,6 +172,9 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
bool enable_se_atten = two_embed != nullptr;
memset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei);
memset(dy_dem, 0, sizeof(FPTYPE) * nloc * nnei * 4);
if (enable_se_atten) {
memset(dy_dtwo, 0, sizeof(FPTYPE) * nloc * nnei * last_layer_size);
}
FPTYPE const lower = table_info[0];
FPTYPE const upper = table_info[1];
FPTYPE const _max = table_info[2];
@@ -212,25 +216,38 @@
a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;
FPTYPE g =
(a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx);
FPTYPE resold = res;
if (enable_se_atten) {
FPTYPE t = two_embed[ii * nnei * last_layer_size +
jj * last_layer_size + kk];
res = res * t + res;
g += t * g;
}

FPTYPE dotllrr = dot(ll, rr);
if (unloop) {
grad += g * dot(ll, rr) * (nnei - jj);
grad += g * dotllrr * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0] * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1] * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2] * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3] * (nnei - jj);
if (enable_se_atten) {
// fill from jj to nnei
for (int jj2 = jj; jj2 < nnei; jj2++) {
dy_dtwo[ii * nnei * last_layer_size + jj2 * last_layer_size +
kk] += res * dotllrr;
}
}
} else {
grad += g * dot(ll, rr);
grad += g * dotllrr;
dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0];
dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1];
dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2];
dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3];
if (enable_se_atten) {
dy_dtwo[ii * nnei * last_layer_size + jj * last_layer_size + kk] +=
resold * dotllrr;
}
}
}
dy_dem_x[ii * nnei + jj] = grad;
@@ -660,6 +677,7 @@ template void deepmd::tabulate_fusion_se_a_cpu<double>(
template void deepmd::tabulate_fusion_se_a_grad_cpu<float>(
float* dy_dem_x,
float* dy_dem,
float* dy_dtwo,
const float* table,
const float* table_info,
const float* em_x,
@@ -673,6 +691,7 @@ template void deepmd::tabulate_fusion_se_a_grad_cpu<float>(
template void deepmd::tabulate_fusion_se_a_grad_cpu<double>(
double* dy_dem_x,
double* dy_dem,
double* dy_dtwo,
const double* table,
const double* table_info,
const double* em_x,
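
The accumulation added above is compact; restated in isolation it is one fused multiply per (atom, neighbor, last-layer column). A self-contained sketch with my own names, not the library's: resold is the polynomial value before the res = res * t + res rescaling, ll the incoming gradient 4-vector for column kk, rr the neighbor's 4-component embedding row.

#include <cstddef>

// Dot product over the 4 embedding rows, as dot(ll, rr) does above.
static inline double dot4(const double* a, const double* b) {
  return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
}

// One dy_dtwo update, mirroring the non-unloop branch added above.
static inline void accumulate_dy_dtwo(double* dy_dtwo, const double* ll,
                                      const double* rr, double resold,
                                      std::size_t ii, std::size_t jj,
                                      std::size_t kk, std::size_t nnei,
                                      std::size_t last_layer_size) {
  dy_dtwo[(ii * nnei + jj) * last_layer_size + kk] += resold * dot4(ll, rr);
}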
15 changes: 8 additions & 7 deletions source/lib/tests/test_tabulate_se_a.cc
@@ -726,9 +726,10 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) {
std::vector<double> dy_dem_x(em_x.size());
std::vector<double> dy_dem(em.size());
std::vector<double> dy(nloc * nnei * last_layer_size, 1.0);
double *dy_dtwo = nullptr;
deepmd::tabulate_fusion_se_a_grad_cpu<double>(
&dy_dem_x[0], &dy_dem[0], &table[0], &info[0], &em_x[0], &em[0], nullptr,
&dy[0], nloc, nnei, last_layer_size);
&dy_dem_x[0], &dy_dem[0], dy_dtwo, &table[0], &info[0], &em_x[0], &em[0],
nullptr, &dy[0], nloc, nnei, last_layer_size);
EXPECT_EQ(dy_dem_x.size(), nloc * nnei);
EXPECT_EQ(dy_dem.size(), nloc * nnei * 4);
EXPECT_EQ(dy_dem_x.size(), expected_dy_dem_x.size());
@@ -741,7 +742,7 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) {
}

deepmd::tabulate_fusion_se_a_grad_cpu<double>(
&dy_dem_x[0], &dy_dem[0], &table[0], &info[0], &em_x[0], &em[0],
&dy_dem_x[0], &dy_dem[0], dy_dtwo, &table[0], &info[0], &em_x[0], &em[0],
&two_embed[0], &dy[0], nloc, nnei, last_layer_size);
EXPECT_EQ(dy_dem_x.size(), nloc * nnei);
EXPECT_EQ(dy_dem.size(), nloc * nnei * 4);
@@ -804,16 +805,16 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) {
std::vector<double> dy(nloc * nnei * last_layer_size, 1.0);

double *dy_dem_x_dev = NULL, *dy_dem_dev = NULL, *table_dev = NULL,
*em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL;
*em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL, *dy_dtwo = nullptr;
deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x);
deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem);
deepmd::malloc_device_memory_sync(table_dev, table);
deepmd::malloc_device_memory_sync(em_x_dev, em_x);
deepmd::malloc_device_memory_sync(em_dev, em);
deepmd::malloc_device_memory_sync(dy_dev, dy);
deepmd::tabulate_fusion_se_a_grad_gpu<double>(
dy_dem_x_dev, dy_dem_dev, table_dev, &info[0], em_x_dev, em_dev, nullptr,
dy_dev, nloc, nnei, last_layer_size);
dy_dem_x_dev, dy_dem_dev, dy_dtwo, table_dev, &info[0], em_x_dev, em_dev,
nullptr, dy_dev, nloc, nnei, last_layer_size);
deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x);
deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem);

@@ -833,7 +834,7 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) {
deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x);
deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem);
deepmd::tabulate_fusion_se_a_grad_gpu<double>(
dy_dem_x_dev, dy_dem_dev, table_dev, &info[0], em_x_dev, em_dev,
dy_dem_x_dev, dy_dem_dev, dy_dtwo, table_dev, &info[0], em_x_dev, em_dev,
two_embed_dev, dy_dev, nloc, nnei, last_layer_size);
deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x);
deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem);
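
One caveat for anyone extending these tests: both two_embed-enabled calls above still pass a null dy_dtwo, while the wrappers now memset and write that pointer whenever two_embed is non-null, so in that branch it needs real storage. A hypothetical allocation sized to match the CPU memset (the buffer name is mine):

// Hypothetical: back dy_dtwo with real storage before the two_embed calls.
std::vector<double> dy_dtwo_buf(nloc * nnei * last_layer_size, 0.0);
double* dy_dtwo = dy_dtwo_buf.data();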
26 changes: 14 additions & 12 deletions source/op/tabulate_multi_device.cc
@@ -261,6 +261,7 @@ class TabulateFusionSeAGradOp : public OpKernel {
// flat the tensors
FPTYPE* dy_dem_x = dy_dem_x_tensor->flat<FPTYPE>().data();
FPTYPE* dy_dem = dy_dem_tensor->flat<FPTYPE>().data();
FPTYPE* dy_dtwo = nullptr;

const FPTYPE* descriptor = descriptor_tensor.flat<FPTYPE>().data();
const FPTYPE* table = table_tensor.flat<FPTYPE>().data();
@@ -275,14 +276,14 @@ class TabulateFusionSeAGradOp : public OpKernel {

if (device == "GPU") {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, table, table_info,
em_x, em, two_embed, dy, nloc, nnei,
last_layer_size);
deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, dy_dtwo, table,
table_info, em_x, em, two_embed, dy,
nloc, nnei, last_layer_size);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} else if (device == "CPU") {
deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, table, table_info,
em_x, em, two_embed, dy, nloc, nnei,
last_layer_size);
deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, dy_dtwo, table,
table_info, em_x, em, two_embed, dy,
nloc, nnei, last_layer_size);
}
}

@@ -468,6 +469,7 @@ class TabulateFusionSeAttenGradOp : public OpKernel {
// flat the tensors
FPTYPE* dy_dem_x = dy_dem_x_tensor->flat<FPTYPE>().data();
FPTYPE* dy_dem = dy_dem_tensor->flat<FPTYPE>().data();
FPTYPE* dy_dtwo = dy_dtwo_tensor->flat<FPTYPE>().data();

const FPTYPE* descriptor = descriptor_tensor.flat<FPTYPE>().data();
const FPTYPE* table = table_tensor.flat<FPTYPE>().data();
@@ -482,14 +484,14 @@ class TabulateFusionSeAttenGradOp : public OpKernel {

if (device == "GPU") {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, table, table_info,
em_x, em, two_embed, dy, nloc, nnei,
last_layer_size, is_sorted);
deepmd::tabulate_fusion_se_a_grad_gpu(
dy_dem_x, dy_dem, dy_dtwo, table, table_info, em_x, em, two_embed, dy,
nloc, nnei, last_layer_size, is_sorted);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} else if (device == "CPU") {
deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, table, table_info,
em_x, em, two_embed, dy, nloc, nnei,
last_layer_size, is_sorted);
deepmd::tabulate_fusion_se_a_grad_cpu(
dy_dem_x, dy_dem, dy_dtwo, table, table_info, em_x, em, two_embed, dy,
nloc, nnei, last_layer_size, is_sorted);
}
}

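
The two ops differ only in where dy_dtwo comes from: TabulateFusionSeAGradOp pins it to nullptr (no attention), while TabulateFusionSeAttenGradOp flattens a dy_dtwo_tensor allocated in its (elided) setup. A hypothetical sketch of such an allocation with stock TensorFlow calls; the output index and shape derivation are my assumptions, not the op's actual registration:

// Hypothetical allocation of the dy_dtwo output in TabulateFusionSeAttenGradOp;
// the real op's output index and shape may differ.
Tensor* dy_dtwo_tensor = nullptr;
TensorShape dy_dtwo_shape({nloc, nnei, last_layer_size});
OP_REQUIRES_OK(context,
               context->allocate_output(2, dy_dtwo_shape, &dy_dtwo_tensor));
FPTYPE* dy_dtwo = dy_dtwo_tensor->flat<FPTYPE>().data();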
