fix: make the se attn v2 descriptor energy conservative. #2905

Merged · 15 commits · Oct 11, 2023
28 changes: 24 additions & 4 deletions deepmd/descriptor/se_atten.py
@@ -564,6 +564,8 @@
self.filter_precision,
)
self.negative_mask = -(2 << 32) * (1.0 - self.nmask)
# hard coding the magnitude of attention weight shift
self.smth_attn_w_shift = 20.0
# only used when tensorboard was set as true
tf.summary.histogram("descrpt", self.descrpt)
tf.summary.histogram("rij", self.rij)
@@ -599,7 +601,9 @@
)
self.recovered_r = (
tf.reshape(
tf.slice(tf.reshape(self.descrpt, [-1, 4]), [0, 0], [-1, 1]),
tf.slice(
tf.reshape(self.descrpt_reshape, [-1, 4]), [0, 0], [-1, 1]
),
[-1, natoms[0], self.sel_all_a[0]],
)
* self.std_looked_up
@@ -865,10 +869,26 @@
save_weights=True,
):
attn = tf.matmul(Q / temperature, K, transpose_b=True)
attn *= self.nmask
attn += self.negative_mask
if self.smooth:
# (nb x nloc) x nsel
nsel = self.sel_all_a[0]
attn = (attn + self.smth_attn_w_shift) * tf.reshape(
self.recovered_switch, [-1, 1, nsel]
) * tf.reshape(
self.recovered_switch, [-1, nsel, 1]
) - self.smth_attn_w_shift
else:
attn *= self.nmask
attn += self.negative_mask
attn = tf.nn.softmax(attn, axis=-1)
attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
if self.smooth:
attn = (
attn
* tf.reshape(self.recovered_switch, [-1, 1, nsel])
* tf.reshape(self.recovered_switch, [-1, nsel, 1])
)
else:
attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
if save_weights:
self.attn_weight[layer] = attn[0] # atom 0
if dotr:
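Note on the Python change above: in the smooth branch, the hard 0/1 neighbor mask (nmask together with the large negative_mask shift) is replaced by the recovered switch function, applied to both the query and key sides of the attention matrix before the softmax and again after it, so every attention weight decays continuously to zero as a neighbor approaches the cutoff. That continuity is what makes the resulting energy conservative. A minimal NumPy sketch of the idea for a single atom; the names logits, sw and shift are illustrative, not the descriptor's actual tensors:

import numpy as np

def smooth_attention_weights(logits, sw, shift=20.0):
    # logits: (nsel, nsel) raw attention scores for one atom.
    # sw: (nsel,) switch values in [0, 1] that go smoothly to 0 at the cutoff
    #     (recovered_switch in the diff above).
    # Shift the scores up, damp both the row and column side with the switch,
    # then shift back: fully switched-off neighbors land at -shift.
    scores = (logits + shift) * sw[None, :] * sw[:, None] - shift
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    # Damp the normalized weights once more, mirroring the post-softmax step,
    # so any weight involving an out-of-range neighbor vanishes smoothly.
    return w * sw[None, :] * sw[:, None]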
2 changes: 1 addition & 1 deletion deepmd/op/_tabulate_grad.py
@@ -55,7 +55,7 @@ def _tabulate_fusion_se_atten_grad_cc(op, dy):
op.outputs[0],
is_sorted=op.get_attr("is_sorted"),
)
return [None, None, dy_dx, dy_df, None]
return [None, None, dy_dx, dy_df, dy_dtwo]


@ops.RegisterGradient("TabulateFusionSeAttenGrad")
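The one-line change in _tabulate_grad.py matters because TensorFlow treats a None entry in the list returned by a registered gradient as "no gradient for this input": the compressed se_atten path previously dropped the gradient with respect to the attention (two_embed) input, so that part of the network was invisible to backpropagation. A toy tf.custom_gradient sketch of the same pattern; the function below is an illustration, not the actual DeePMD-kit op:

import tensorflow as tf

@tf.custom_gradient
def scaled_by_attention(x, two):
    y = x * (1.0 + two)  # mirrors res = res * t + res in the tabulate kernels
    def grad(dy):
        dy_dx = dy * (1.0 + two)
        # Returning None here instead of dy_dtwo would silently zero the
        # gradient flowing into the attention weights, which is what the old
        # return [None, None, dy_dx, dy_df, None] effectively did.
        dy_dtwo = dy * x
        return dy_dx, dy_dtwo
    return y, grad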
2 changes: 2 additions & 0 deletions source/lib/include/tabulate.h
@@ -18,6 +18,7 @@ void tabulate_fusion_se_a_cpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -125,6 +126,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
21 changes: 18 additions & 3 deletions source/lib/src/gpu/tabulate.cu
@@ -253,6 +253,7 @@ template <typename FPTYPE, int MTILE, int KTILE>
__global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* em_x,
const FPTYPE* em,
@@ -307,6 +308,7 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
(var[1] +
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
xx;
FPTYPE oldres = res;
FPTYPE t;
if (enable_se_atten) {
t = two_embed[block_idx * nnei * last_layer_size +
@@ -330,6 +332,13 @@
xx) *
xx) *
(enable_se_atten ? res * t + res : res);
if (enable_se_atten) {
// from ii to ii + (nnei - breakpoint)
for (int ii2 = ii; ii2 < ii + nnei - breakpoint; ii2++) {
dy_dtwo[block_idx * nnei * last_layer_size + ii2 * last_layer_size +
jj] = oldres * res;
}
}
}
GpuSyncThreads();
for (int kk = 0; kk < MTILE; kk++) {
@@ -764,6 +773,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -781,12 +791,15 @@ void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
DPErrcheck(gpuDeviceSynchronize());
DPErrcheck(gpuMemset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei));
DPErrcheck(gpuMemset(dy_dem, 0, sizeof(FPTYPE) * nloc * nnei * 4));
if (two_embed != nullptr) {
DPErrcheck(gpuMemset(dy_dtwo, 0, sizeof(FPTYPE) * nloc * nnei));
}

tabulate_fusion_se_a_grad_fifth_order_polynomial<FPTYPE, MM, KK>
<<<nloc, KK * WARP_SIZE, sizeof(FPTYPE) * MM * last_layer_size>>>(
dy_dem_x, dy_dem, table, em_x, em, two_embed, dy, table_info[0],
table_info[1], table_info[2], table_info[3], table_info[4], nnei,
last_layer_size, is_sorted);
dy_dem_x, dy_dem, dy_dtwo, table, em_x, em, two_embed, dy,
table_info[0], table_info[1], table_info[2], table_info[3],
table_info[4], nnei, last_layer_size, is_sorted);
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());
}
@@ -990,6 +1003,7 @@ template void tabulate_fusion_se_a_gpu<double>(double* out,
const bool is_sorted);
template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
float* dy_dem,
float* dy_dtwo,
const float* table,
const float* table_info,
const float* em_x,
@@ -1002,6 +1016,7 @@ template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
const bool is_sorted);
template void tabulate_fusion_se_a_grad_gpu<double>(double* dy_dem_x,
double* dy_dem,
double* dy_dtwo,
const double* table,
const double* table_info,
const double* em_x,
23 changes: 21 additions & 2 deletions source/lib/src/tabulate.cc
@@ -158,6 +158,7 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out,
template <typename FPTYPE>
void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -171,6 +172,9 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
bool enable_se_atten = two_embed != nullptr;
memset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei);
memset(dy_dem, 0, sizeof(FPTYPE) * nloc * nnei * 4);
if (enable_se_atten) {
memset(dy_dtwo, 0, sizeof(FPTYPE) * nloc * nnei * last_layer_size);
}
FPTYPE const lower = table_info[0];
FPTYPE const upper = table_info[1];
FPTYPE const _max = table_info[2];
@@ -212,25 +216,38 @@
a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;
FPTYPE g =
(a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx);
FPTYPE resold = res;
if (enable_se_atten) {
FPTYPE t = two_embed[ii * nnei * last_layer_size +
jj * last_layer_size + kk];
res = res * t + res;
g += t * g;
}

FPTYPE dotllrr = dot(ll, rr);
if (unloop) {
grad += g * dot(ll, rr) * (nnei - jj);
grad += g * dotllrr * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0] * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1] * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2] * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3] * (nnei - jj);
if (enable_se_atten) {
// fill from jj to nnei
for (int jj2 = jj; jj2 < nnei; jj2++) {
dy_dtwo[ii * nnei * last_layer_size + jj2 * last_layer_size +
kk] += resold * dotllrr;
}
}
} else {
grad += g * dot(ll, rr);
grad += g * dotllrr;
dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0];
dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1];
dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2];
dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3];
if (enable_se_atten) {
dy_dtwo[ii * nnei * last_layer_size + jj * last_layer_size + kk] +=
resold * dotllrr;
}
}
}
dy_dem_x[ii * nnei + jj] = grad;
@@ -660,6 +677,7 @@ template void deepmd::tabulate_fusion_se_a_cpu<double>(
template void deepmd::tabulate_fusion_se_a_grad_cpu<float>(
float* dy_dem_x,
float* dy_dem,
float* dy_dtwo,
const float* table,
const float* table_info,
const float* em_x,
@@ -673,6 +691,7 @@ template void deepmd::tabulate_fusion_se_a_grad_cpu<float>(
template void deepmd::tabulate_fusion_se_a_grad_cpu<double>(
double* dy_dem_x,
double* dy_dem,
double* dy_dtwo,
const double* table,
const double* table_info,
const double* em_x,
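The new dy_dtwo term in the CPU kernel follows from the chain rule: the forward pass scales each tabulated value by (1 + t), where t is the attention (two_embed) entry, so the derivative of a neighbor's contribution with respect to t is the unscaled table value times the embedding components, and contracting that with the upstream gradient gives exactly resold * dot(ll, rr), as accumulated above. A small NumPy check of this identity; the shapes are illustrative, for a single (jj, kk) pair, not the library's real layouts:

import numpy as np

rng = np.random.default_rng(0)
em = rng.normal(size=4)   # ll in the C++ code: the 4 embedding components
dy = rng.normal(size=4)   # rr in the C++ code: upstream gradient for this kk
res, t = 0.8, 0.3         # tabulated value and the attention (two_embed) entry

def contribution(t_):
    # Forward contribution of one neighbor/output pair: em[m] * res * (1 + t).
    return em * (res * t_ + res)

analytic = res * np.dot(em, dy)  # what the kernel now accumulates into dy_dtwo
eps = 1e-6
numeric = np.dot(dy, (contribution(t + eps) - contribution(t - eps)) / (2 * eps))
assert np.isclose(analytic, numeric)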
22 changes: 13 additions & 9 deletions source/lib/tests/test_tabulate_se_a.cc
@@ -726,9 +726,10 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) {
std::vector<double> dy_dem_x(em_x.size());
std::vector<double> dy_dem(em.size());
std::vector<double> dy(nloc * nnei * last_layer_size, 1.0);
std::vector<double> dy_dtwo(nloc * nnei * last_layer_size);
deepmd::tabulate_fusion_se_a_grad_cpu<double>(
&dy_dem_x[0], &dy_dem[0], &table[0], &info[0], &em_x[0], &em[0], nullptr,
&dy[0], nloc, nnei, last_layer_size);
&dy_dem_x[0], &dy_dem[0], &dy_dtwo[0], &table[0], &info[0], &em_x[0],
&em[0], nullptr, &dy[0], nloc, nnei, last_layer_size);
EXPECT_EQ(dy_dem_x.size(), nloc * nnei);
EXPECT_EQ(dy_dem.size(), nloc * nnei * 4);
EXPECT_EQ(dy_dem_x.size(), expected_dy_dem_x.size());
@@ -741,8 +742,8 @@
}

deepmd::tabulate_fusion_se_a_grad_cpu<double>(
&dy_dem_x[0], &dy_dem[0], &table[0], &info[0], &em_x[0], &em[0],
&two_embed[0], &dy[0], nloc, nnei, last_layer_size);
&dy_dem_x[0], &dy_dem[0], &dy_dtwo[0], &table[0], &info[0], &em_x[0],
&em[0], &two_embed[0], &dy[0], nloc, nnei, last_layer_size);
EXPECT_EQ(dy_dem_x.size(), nloc * nnei);
EXPECT_EQ(dy_dem.size(), nloc * nnei * 4);
EXPECT_EQ(dy_dem_x.size(), expected_dy_dem_x.size());
@@ -802,18 +803,20 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) {
std::vector<double> dy_dem_x(em_x.size(), 0.0);
std::vector<double> dy_dem(em.size(), 0.0);
std::vector<double> dy(nloc * nnei * last_layer_size, 1.0);
std::vector<double> dy_dtwo(nloc * nnei * last_layer_size, 0.0);

double *dy_dem_x_dev = NULL, *dy_dem_dev = NULL, *table_dev = NULL,
*em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL;
*em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL,
*dy_dtwo_dev = nullptr;
deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x);
deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem);
deepmd::malloc_device_memory_sync(table_dev, table);
deepmd::malloc_device_memory_sync(em_x_dev, em_x);
deepmd::malloc_device_memory_sync(em_dev, em);
deepmd::malloc_device_memory_sync(dy_dev, dy);
deepmd::tabulate_fusion_se_a_grad_gpu<double>(
dy_dem_x_dev, dy_dem_dev, table_dev, &info[0], em_x_dev, em_dev, nullptr,
dy_dev, nloc, nnei, last_layer_size);
dy_dem_x_dev, dy_dem_dev, dy_dtwo_dev, table_dev, &info[0], em_x_dev, em_dev,
nullptr, dy_dev, nloc, nnei, last_layer_size);
deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x);
deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem);

@@ -832,9 +835,10 @@
deepmd::malloc_device_memory_sync(two_embed_dev, two_embed);
deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x);
deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem);
deepmd::malloc_device_memory_sync(dy_dtwo_dev, dy_dtwo);
deepmd::tabulate_fusion_se_a_grad_gpu<double>(
dy_dem_x_dev, dy_dem_dev, table_dev, &info[0], em_x_dev, em_dev,
two_embed_dev, dy_dev, nloc, nnei, last_layer_size);
dy_dem_x_dev, dy_dem_dev, dy_dtwo_dev, table_dev, &info[0], em_x_dev,
em_dev, two_embed_dev, dy_dev, nloc, nnei, last_layer_size);
deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x);
deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem);
for (int jj = 0; jj < dy_dem_x.size(); ++jj) {
26 changes: 14 additions & 12 deletions source/op/tabulate_multi_device.cc
@@ -261,6 +261,7 @@ class TabulateFusionSeAGradOp : public OpKernel {
// flat the tensors
FPTYPE* dy_dem_x = dy_dem_x_tensor->flat<FPTYPE>().data();
FPTYPE* dy_dem = dy_dem_tensor->flat<FPTYPE>().data();
FPTYPE* dy_dtwo = nullptr;

const FPTYPE* descriptor = descriptor_tensor.flat<FPTYPE>().data();
const FPTYPE* table = table_tensor.flat<FPTYPE>().data();
@@ -275,14 +276,14 @@

if (device == "GPU") {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, table, table_info,
em_x, em, two_embed, dy, nloc, nnei,
last_layer_size);
deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, dy_dtwo, table,
table_info, em_x, em, two_embed, dy,
nloc, nnei, last_layer_size);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} else if (device == "CPU") {
deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, table, table_info,
em_x, em, two_embed, dy, nloc, nnei,
last_layer_size);
deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, dy_dtwo, table,
table_info, em_x, em, two_embed, dy,
nloc, nnei, last_layer_size);
}
}

@@ -468,6 +469,7 @@ class TabulateFusionSeAttenGradOp : public OpKernel {
// flat the tensors
FPTYPE* dy_dem_x = dy_dem_x_tensor->flat<FPTYPE>().data();
FPTYPE* dy_dem = dy_dem_tensor->flat<FPTYPE>().data();
FPTYPE* dy_dtwo = dy_dtwo_tensor->flat<FPTYPE>().data();

const FPTYPE* descriptor = descriptor_tensor.flat<FPTYPE>().data();
const FPTYPE* table = table_tensor.flat<FPTYPE>().data();
@@ -482,14 +484,14 @@

if (device == "GPU") {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, table, table_info,
em_x, em, two_embed, dy, nloc, nnei,
last_layer_size, is_sorted);
deepmd::tabulate_fusion_se_a_grad_gpu(
dy_dem_x, dy_dem, dy_dtwo, table, table_info, em_x, em, two_embed, dy,
nloc, nnei, last_layer_size, is_sorted);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} else if (device == "CPU") {
deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, table, table_info,
em_x, em, two_embed, dy, nloc, nnei,
last_layer_size, is_sorted);
deepmd::tabulate_fusion_se_a_grad_cpu(
dy_dem_x, dy_dem, dy_dtwo, table, table_info, em_x, em, two_embed, dy,
nloc, nnei, last_layer_size, is_sorted);
}
}

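Taken together, these changes make the smooth-attention (se_atten_v2) descriptor a smooth function of the coordinates and route its full gradient through the compressed tables, so the predicted forces are again consistent with the energy. A generic finite-difference consistency check of the kind one could run on a trained model; model.eval(coords) returning (energy, forces) is an assumed interface here, not the DeePMD-kit API:

import numpy as np

def check_force_consistency(model, coords, eps=1e-4, atol=1e-5):
    # Verify that forces equal -dE/dx by central finite differences.
    energy, forces = model.eval(coords)
    for i in range(coords.shape[0]):
        for d in range(3):
            shift = np.zeros_like(coords)
            shift[i, d] = eps
            e_plus, _ = model.eval(coords + shift)
            e_minus, _ = model.eval(coords - shift)
            fd_force = -(e_plus - e_minus) / (2.0 * eps)
            assert abs(fd_force - forces[i, d]) < atol, (i, d, fd_force)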