diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index 90cc86c4d9..8f3be40596 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -564,6 +564,8 @@ def build( self.filter_precision, ) self.negative_mask = -(2 << 32) * (1.0 - self.nmask) + # hard coding the magnitude of attention weight shift + self.smth_attn_w_shift = 20.0 # only used when tensorboard was set as true tf.summary.histogram("descrpt", self.descrpt) tf.summary.histogram("rij", self.rij) @@ -599,7 +601,9 @@ def build( ) self.recovered_r = ( tf.reshape( - tf.slice(tf.reshape(self.descrpt, [-1, 4]), [0, 0], [-1, 1]), + tf.slice( + tf.reshape(self.descrpt_reshape, [-1, 4]), [0, 0], [-1, 1] + ), [-1, natoms[0], self.sel_all_a[0]], ) * self.std_looked_up @@ -865,10 +869,26 @@ def _scaled_dot_attn( save_weights=True, ): attn = tf.matmul(Q / temperature, K, transpose_b=True) - attn *= self.nmask - attn += self.negative_mask + if self.smooth: + # (nb x nloc) x nsel + nsel = self.sel_all_a[0] + attn = (attn + self.smth_attn_w_shift) * tf.reshape( + self.recovered_switch, [-1, 1, nsel] + ) * tf.reshape( + self.recovered_switch, [-1, nsel, 1] + ) - self.smth_attn_w_shift + else: + attn *= self.nmask + attn += self.negative_mask attn = tf.nn.softmax(attn, axis=-1) - attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1]) + if self.smooth: + attn = ( + attn + * tf.reshape(self.recovered_switch, [-1, 1, nsel]) + * tf.reshape(self.recovered_switch, [-1, nsel, 1]) + ) + else: + attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1]) if save_weights: self.attn_weight[layer] = attn[0] # atom 0 if dotr: diff --git a/deepmd/op/_tabulate_grad.py b/deepmd/op/_tabulate_grad.py index 9076ee3213..8ad8908d7e 100644 --- a/deepmd/op/_tabulate_grad.py +++ b/deepmd/op/_tabulate_grad.py @@ -55,7 +55,7 @@ def _tabulate_fusion_se_atten_grad_cc(op, dy): op.outputs[0], is_sorted=op.get_attr("is_sorted"), ) - return [None, None, dy_dx, dy_df, None] + return [None, None, dy_dx, dy_df, dy_dtwo] @ops.RegisterGradient("TabulateFusionSeAttenGrad") @@ -68,6 +68,7 @@ def _tabulate_fusion_se_atten_grad_grad_cc(op, dy, dy_, dy_dtwo): op.inputs[4], dy, dy_, + dy_dtwo, op.inputs[6], is_sorted=op.get_attr("is_sorted"), ) diff --git a/source/lib/include/tabulate.h b/source/lib/include/tabulate.h index 93992cea5b..47c3062449 100644 --- a/source/lib/include/tabulate.h +++ b/source/lib/include/tabulate.h @@ -18,6 +18,7 @@ void tabulate_fusion_se_a_cpu(FPTYPE* out, template void tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* table_info, const FPTYPE* em_x, @@ -38,6 +39,7 @@ void tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy, const FPTYPE* two_embed, const FPTYPE* dz_dy_dem_x, const FPTYPE* dz_dy_dem, + const FPTYPE* dz_dy_dtwo, const int nloc, const int nnei, const int last_layer_size, @@ -125,6 +127,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out, template void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* table_info, const FPTYPE* em_x, @@ -145,6 +148,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy, const FPTYPE* two_embed, const FPTYPE* dz_dy_dem_x, const FPTYPE* dz_dy_dem, + const FPTYPE* dz_dy_dtwo, const int nloc, const int nnei, const int last_layer_size, diff --git a/source/lib/src/gpu/tabulate.cu b/source/lib/src/gpu/tabulate.cu index 09d02bdf2c..a22742ae19 100644 --- a/source/lib/src/gpu/tabulate.cu +++ b/source/lib/src/gpu/tabulate.cu @@ -253,6 +253,7 @@ 
template __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* em_x, const FPTYPE* em, @@ -307,6 +308,7 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; + FPTYPE oldres = res; FPTYPE t; if (enable_se_atten) { t = two_embed[block_idx * nnei * last_layer_size + @@ -330,6 +332,13 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( xx) * xx) * (enable_se_atten ? res * t + res : res); + if (enable_se_atten) { + // from ii to ii + (nnei - breakpoint) + for (int ii2 = ii; ii2 < ii + nnei - breakpoint; ii2++) { + dy_dtwo[block_idx * nnei * last_layer_size + ii2 * last_layer_size + + jj] = oldres * res; + } + } } GpuSyncThreads(); for (int kk = 0; kk < MTILE; kk++) { @@ -357,6 +366,7 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial( const FPTYPE* two_embed, const FPTYPE* dz_dy_dem_x, const FPTYPE* dz_dy_dem, + const FPTYPE* dz_dy_dtwo, const FPTYPE lower, const FPTYPE upper, const FPTYPE max, @@ -404,9 +414,15 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial( ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) * xx) * xx; + FPTYPE two_grad = 0.; if (enable_se_atten) { FPTYPE t = two_embed[block_idx * nnei * last_layer_size + ii * last_layer_size + thread_idx]; + // dz_dy_dtwo * res * em + // res above should be used instead of res + res * t below + two_grad = dz_dy_dtwo[block_idx * nnei * last_layer_size + + ii * last_layer_size + thread_idx] * + res; res += res * t; res_grad += res_grad * t; } @@ -434,8 +450,8 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial( for (int kk = 0; kk < MTILE; kk++) { int em_index = block_idx * nnei * MTILE + ii * MTILE + kk; iteratorC[kk * last_layer_size + thread_idx] += - (nnei - breakpoint) * - (em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * res); + (nnei - breakpoint) * (em[em_index] * (res_grad * dz_xx + two_grad) + + dz_dy_dem[em_index] * res); } mark_table_idx = table_idx; if (unloop) { @@ -764,6 +780,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out, template void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* table_info, const FPTYPE* em_x, @@ -784,9 +801,9 @@ void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x, tabulate_fusion_se_a_grad_fifth_order_polynomial <<>>( - dy_dem_x, dy_dem, table, em_x, em, two_embed, dy, table_info[0], - table_info[1], table_info[2], table_info[3], table_info[4], nnei, - last_layer_size, is_sorted); + dy_dem_x, dy_dem, dy_dtwo, table, em_x, em, two_embed, dy, + table_info[0], table_info[1], table_info[2], table_info[3], + table_info[4], nnei, last_layer_size, is_sorted); DPErrcheck(gpuGetLastError()); DPErrcheck(gpuDeviceSynchronize()); } @@ -800,6 +817,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy, const FPTYPE* two_embed, const FPTYPE* dz_dy_dem_x, const FPTYPE* dz_dy_dem, + const FPTYPE* dz_dy_dtwo, const int nloc, const int nnei, const int last_layer_size, @@ -812,7 +830,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy, DPErrcheck(gpuMemset(dz_dy, 0, sizeof(FPTYPE) * nloc * 4 * last_layer_size)); tabulate_fusion_se_a_grad_grad_fifth_order_polynomial <<>>( - dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, + dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, dz_dy_dtwo, table_info[0], table_info[1], table_info[2], table_info[3], 
table_info[4], nnei, last_layer_size, is_sorted); DPErrcheck(gpuGetLastError()); @@ -990,6 +1008,7 @@ template void tabulate_fusion_se_a_gpu(double* out, const bool is_sorted); template void tabulate_fusion_se_a_grad_gpu(float* dy_dem_x, float* dy_dem, + float* dy_dtwo, const float* table, const float* table_info, const float* em_x, @@ -1002,6 +1021,7 @@ template void tabulate_fusion_se_a_grad_gpu(float* dy_dem_x, const bool is_sorted); template void tabulate_fusion_se_a_grad_gpu(double* dy_dem_x, double* dy_dem, + double* dy_dtwo, const double* table, const double* table_info, const double* em_x, @@ -1021,6 +1041,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu( const float* two_embed, const float* dz_dy_dem_x, const float* dz_dy_dem, + const float* dz_dy_dtwo, const int nloc, const int nnei, const int last_layer_size, @@ -1034,6 +1055,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu( const double* two_embed, const double* dz_dy_dem_x, const double* dz_dy_dem, + const double* dz_dy_dtwo, const int nloc, const int nnei, const int last_layer_size, diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc index 1f49cf0daa..3e2a1bec62 100644 --- a/source/lib/src/tabulate.cc +++ b/source/lib/src/tabulate.cc @@ -158,6 +158,7 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out, template void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, FPTYPE* dy_dem, + FPTYPE* dy_dtwo, const FPTYPE* table, const FPTYPE* table_info, const FPTYPE* em_x, @@ -171,6 +172,9 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, bool enable_se_atten = two_embed != nullptr; memset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei); memset(dy_dem, 0, sizeof(FPTYPE) * nloc * nnei * 4); + if (enable_se_atten) { + memset(dy_dtwo, 0, sizeof(FPTYPE) * nloc * nnei * last_layer_size); + } FPTYPE const lower = table_info[0]; FPTYPE const upper = table_info[1]; FPTYPE const _max = table_info[2]; @@ -212,6 +216,7 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; FPTYPE g = (a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx); + FPTYPE resold = res; if (enable_se_atten) { FPTYPE t = two_embed[ii * nnei * last_layer_size + jj * last_layer_size + kk]; @@ -219,18 +224,30 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, g += t * g; } + FPTYPE dotllrr = dot(ll, rr); if (unloop) { - grad += g * dot(ll, rr) * (nnei - jj); + grad += g * dotllrr * (nnei - jj); dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0] * (nnei - jj); dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1] * (nnei - jj); dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2] * (nnei - jj); dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3] * (nnei - jj); + if (enable_se_atten) { + // fill from jj to nnei + for (int jj2 = jj; jj2 < nnei; jj2++) { + dy_dtwo[ii * nnei * last_layer_size + jj2 * last_layer_size + + kk] += resold * dotllrr; + } + } } else { - grad += g * dot(ll, rr); + grad += g * dotllrr; dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0]; dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1]; dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2]; dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3]; + if (enable_se_atten) { + dy_dtwo[ii * nnei * last_layer_size + jj * last_layer_size + kk] += + resold * dotllrr; + } } } dy_dem_x[ii * nnei + jj] = grad; @@ -250,6 +267,7 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy, const FPTYPE* two_embed, const FPTYPE* dz_dy_dem_x, const FPTYPE* dz_dy_dem, + const FPTYPE* dz_dy_dtwo, 
const int nloc, const int nnei, const int last_layer_size, @@ -300,9 +318,15 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy, ((FPTYPE)3. * a3 + ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * xx) * xx; + FPTYPE two_grad = 0.; if (enable_se_atten) { FPTYPE t = two_embed[ii * nnei * last_layer_size + jj * last_layer_size + kk]; + // dz_dy_dtwo * var * ll + // var above should be used instead of var + var * t below + two_grad = dz_dy_dtwo[ii * nnei * last_layer_size + + jj * last_layer_size + kk] * + var; var += var * t; var_grad += var_grad * t; } @@ -329,22 +353,26 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy, */ if (unloop) { dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] += - (nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]); + (nnei - jj) * + (var * hh[0] + (dz_xx * var_grad + two_grad) * ll[0]); dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] += - (nnei - jj) * (var * hh[1] + dz_xx * var_grad * ll[1]); + (nnei - jj) * + (var * hh[1] + (dz_xx * var_grad + two_grad) * ll[1]); dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] += - (nnei - jj) * (var * hh[2] + dz_xx * var_grad * ll[2]); + (nnei - jj) * + (var * hh[2] + (dz_xx * var_grad + two_grad) * ll[2]); dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] += - (nnei - jj) * (var * hh[3] + dz_xx * var_grad * ll[3]); + (nnei - jj) * + (var * hh[3] + (dz_xx * var_grad + two_grad) * ll[3]); } else { dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] += - var * hh[0] + dz_xx * var_grad * ll[0]; + var * hh[0] + (dz_xx * var_grad + two_grad) * ll[0]; dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] += - var * hh[1] + dz_xx * var_grad * ll[1]; + var * hh[1] + (dz_xx * var_grad + two_grad) * ll[1]; dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] += - var * hh[2] + dz_xx * var_grad * ll[2]; + var * hh[2] + (dz_xx * var_grad + two_grad) * ll[2]; dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] += - var * hh[3] + dz_xx * var_grad * ll[3]; + var * hh[3] + (dz_xx * var_grad + two_grad) * ll[3]; } } if (unloop) { @@ -660,6 +688,7 @@ template void deepmd::tabulate_fusion_se_a_cpu( template void deepmd::tabulate_fusion_se_a_grad_cpu( float* dy_dem_x, float* dy_dem, + float* dy_dtwo, const float* table, const float* table_info, const float* em_x, @@ -673,6 +702,7 @@ template void deepmd::tabulate_fusion_se_a_grad_cpu( template void deepmd::tabulate_fusion_se_a_grad_cpu( double* dy_dem_x, double* dy_dem, + double* dy_dtwo, const double* table, const double* table_info, const double* em_x, @@ -692,6 +722,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu( const float* two_embed, const float* dz_dy_dem_x, const float* dz_dy_dem, + const float* dz_dy_dtwo, const int nloc, const int nnei, const int last_layer_size, @@ -705,6 +736,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu( const double* two_embed, const double* dz_dy_dem_x, const double* dz_dy_dem, + const double* dz_dy_dtwo, const int nloc, const int nnei, const int last_layer_size, diff --git a/source/lib/tests/test_tabulate_se_a.cc b/source/lib/tests/test_tabulate_se_a.cc index fc0fd04980..ce2defb22c 100644 --- a/source/lib/tests/test_tabulate_se_a.cc +++ b/source/lib/tests/test_tabulate_se_a.cc @@ -726,9 +726,10 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) { std::vector dy_dem_x(em_x.size()); std::vector dy_dem(em.size()); std::vector dy(nloc * nnei * last_layer_size, 1.0); + std::vector dy_dtwo(nloc * nnei * last_layer_size); 
deepmd::tabulate_fusion_se_a_grad_cpu( - &dy_dem_x[0], &dy_dem[0], &table[0], &info[0], &em_x[0], &em[0], nullptr, - &dy[0], nloc, nnei, last_layer_size); + &dy_dem_x[0], &dy_dem[0], &dy_dtwo[0], &table[0], &info[0], &em_x[0], + &em[0], nullptr, &dy[0], nloc, nnei, last_layer_size); EXPECT_EQ(dy_dem_x.size(), nloc * nnei); EXPECT_EQ(dy_dem.size(), nloc * nnei * 4); EXPECT_EQ(dy_dem_x.size(), expected_dy_dem_x.size()); @@ -741,8 +742,8 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) { } deepmd::tabulate_fusion_se_a_grad_cpu( - &dy_dem_x[0], &dy_dem[0], &table[0], &info[0], &em_x[0], &em[0], - &two_embed[0], &dy[0], nloc, nnei, last_layer_size); + &dy_dem_x[0], &dy_dem[0], &dy_dtwo[0], &table[0], &info[0], &em_x[0], + &em[0], &two_embed[0], &dy[0], nloc, nnei, last_layer_size); EXPECT_EQ(dy_dem_x.size(), nloc * nnei); EXPECT_EQ(dy_dem.size(), nloc * nnei * 4); EXPECT_EQ(dy_dem_x.size(), expected_dy_dem_x.size()); @@ -802,9 +803,11 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) { std::vector dy_dem_x(em_x.size(), 0.0); std::vector dy_dem(em.size(), 0.0); std::vector dy(nloc * nnei * last_layer_size, 1.0); + std::vector dy_dtwo(nloc * nnei * last_layer_size, 0.0); double *dy_dem_x_dev = NULL, *dy_dem_dev = NULL, *table_dev = NULL, - *em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL; + *em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL, + *dy_dtwo_dev = nullptr; deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); deepmd::malloc_device_memory_sync(table_dev, table); @@ -812,8 +815,8 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) { deepmd::malloc_device_memory_sync(em_dev, em); deepmd::malloc_device_memory_sync(dy_dev, dy); deepmd::tabulate_fusion_se_a_grad_gpu( - dy_dem_x_dev, dy_dem_dev, table_dev, &info[0], em_x_dev, em_dev, nullptr, - dy_dev, nloc, nnei, last_layer_size); + dy_dem_x_dev, dy_dem_dev, dy_dtwo_dev, table_dev, &info[0], em_x_dev, + em_dev, nullptr, dy_dev, nloc, nnei, last_layer_size); deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem); @@ -832,9 +835,10 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) { deepmd::malloc_device_memory_sync(two_embed_dev, two_embed); deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); + deepmd::malloc_device_memory_sync(dy_dtwo_dev, dy_dtwo); deepmd::tabulate_fusion_se_a_grad_gpu( - dy_dem_x_dev, dy_dem_dev, table_dev, &info[0], em_x_dev, em_dev, - two_embed_dev, dy_dev, nloc, nnei, last_layer_size); + dy_dem_x_dev, dy_dem_dev, dy_dtwo_dev, table_dev, &info[0], em_x_dev, + em_dev, two_embed_dev, dy_dev, nloc, nnei, last_layer_size); deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem); for (int jj = 0; jj < dy_dem_x.size(); ++jj) { diff --git a/source/op/tabulate_multi_device.cc b/source/op/tabulate_multi_device.cc index 488a99bd7d..6a70f60a96 100644 --- a/source/op/tabulate_multi_device.cc +++ b/source/op/tabulate_multi_device.cc @@ -100,6 +100,7 @@ REGISTER_OP("TabulateFusionSeAttenGradGrad") .Input("two_embed: T") .Input("dz_dy_dem_x: T") .Input("dz_dy_dem: T") + .Input("dz_dy_dtwo: T") .Input("descriptor: T") .Output("dz_dy: T") .Attr("is_sorted: bool = true"); @@ -261,6 +262,7 @@ class TabulateFusionSeAGradOp : public OpKernel { // flat the tensors FPTYPE* dy_dem_x = dy_dem_x_tensor->flat().data(); FPTYPE* dy_dem = dy_dem_tensor->flat().data(); + FPTYPE* dy_dtwo = nullptr; const 
FPTYPE* descriptor = descriptor_tensor.flat().data(); const FPTYPE* table = table_tensor.flat().data(); @@ -275,14 +277,14 @@ class TabulateFusionSeAGradOp : public OpKernel { if (device == "GPU") { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, table, table_info, - em_x, em, two_embed, dy, nloc, nnei, - last_layer_size); + deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, dy_dtwo, table, + table_info, em_x, em, two_embed, dy, + nloc, nnei, last_layer_size); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } else if (device == "CPU") { - deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, table, table_info, - em_x, em, two_embed, dy, nloc, nnei, - last_layer_size); + deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, dy_dtwo, table, + table_info, em_x, em, two_embed, dy, + nloc, nnei, last_layer_size); } } @@ -328,6 +330,7 @@ class TabulateFusionSeAGradGradOp : public OpKernel { const FPTYPE* two_embed = nullptr; const FPTYPE* dz_dy_dem_x = dz_dy_dem_x_tensor.flat().data(); const FPTYPE* dz_dy_dem = dz_dy_dem_tensor.flat().data(); + const FPTYPE* dz_dy_dtwo = nullptr; const int nloc = em_tensor.shape().dim_size(0); const int nnei = em_tensor.shape().dim_size(1); const int last_layer_size = descriptor_tensor.shape().dim_size(2); @@ -336,7 +339,7 @@ class TabulateFusionSeAGradGradOp : public OpKernel { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM deepmd::tabulate_fusion_se_a_grad_grad_gpu( dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, - nloc, nnei, last_layer_size, is_sorted); + dz_dy_dtwo, nloc, nnei, last_layer_size, is_sorted); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM OP_REQUIRES(context, (last_layer_size <= 1024), errors::InvalidArgument( @@ -345,7 +348,7 @@ class TabulateFusionSeAGradGradOp : public OpKernel { } else if (device == "CPU") { deepmd::tabulate_fusion_se_a_grad_grad_cpu( dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, - nloc, nnei, last_layer_size, is_sorted); + dz_dy_dtwo, nloc, nnei, last_layer_size, is_sorted); } } @@ -468,6 +471,7 @@ class TabulateFusionSeAttenGradOp : public OpKernel { // flat the tensors FPTYPE* dy_dem_x = dy_dem_x_tensor->flat().data(); FPTYPE* dy_dem = dy_dem_tensor->flat().data(); + FPTYPE* dy_dtwo = dy_dtwo_tensor->flat().data(); const FPTYPE* descriptor = descriptor_tensor.flat().data(); const FPTYPE* table = table_tensor.flat().data(); @@ -482,14 +486,14 @@ class TabulateFusionSeAttenGradOp : public OpKernel { if (device == "GPU") { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - deepmd::tabulate_fusion_se_a_grad_gpu(dy_dem_x, dy_dem, table, table_info, - em_x, em, two_embed, dy, nloc, nnei, - last_layer_size, is_sorted); + deepmd::tabulate_fusion_se_a_grad_gpu( + dy_dem_x, dy_dem, dy_dtwo, table, table_info, em_x, em, two_embed, dy, + nloc, nnei, last_layer_size, is_sorted); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } else if (device == "CPU") { - deepmd::tabulate_fusion_se_a_grad_cpu(dy_dem_x, dy_dem, table, table_info, - em_x, em, two_embed, dy, nloc, nnei, - last_layer_size, is_sorted); + deepmd::tabulate_fusion_se_a_grad_cpu( + dy_dem_x, dy_dem, dy_dtwo, table, table_info, em_x, em, two_embed, dy, + nloc, nnei, last_layer_size, is_sorted); } } @@ -520,6 +524,7 @@ class TabulateFusionSeAttenGradGradOp : public OpKernel { const Tensor& two_embed_tensor = context->input(context_input_index++); const Tensor& dz_dy_dem_x_tensor = context->input(context_input_index++); const Tensor& dz_dy_dem_tensor = context->input(context_input_index++); + 
const Tensor& dz_dy_dtwo_tensor = context->input(context_input_index++); const Tensor& descriptor_tensor = context->input(context_input_index++); // set size of the sample OP_REQUIRES(context, (dz_dy_dem_x_tensor.shape().dims() == 2), @@ -542,6 +547,7 @@ class TabulateFusionSeAttenGradGradOp : public OpKernel { const FPTYPE* two_embed = two_embed_tensor.flat().data(); const FPTYPE* dz_dy_dem_x = dz_dy_dem_x_tensor.flat().data(); const FPTYPE* dz_dy_dem = dz_dy_dem_tensor.flat().data(); + const FPTYPE* dz_dy_dtwo = dz_dy_dtwo_tensor.flat().data(); const int nloc = em_tensor.shape().dim_size(0); const int nnei = em_tensor.shape().dim_size(1); const int last_layer_size = descriptor_tensor.shape().dim_size(2); @@ -550,7 +556,7 @@ class TabulateFusionSeAttenGradGradOp : public OpKernel { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM deepmd::tabulate_fusion_se_a_grad_grad_gpu( dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, - nloc, nnei, last_layer_size, is_sorted); + dz_dy_dtwo, nloc, nnei, last_layer_size, is_sorted); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM OP_REQUIRES(context, (last_layer_size <= 1024), errors::InvalidArgument( @@ -559,7 +565,7 @@ class TabulateFusionSeAttenGradGradOp : public OpKernel { } else if (device == "CPU") { deepmd::tabulate_fusion_se_a_grad_grad_cpu( dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, - nloc, nnei, last_layer_size, is_sorted); + dz_dy_dtwo, nloc, nnei, last_layer_size, is_sorted); } } diff --git a/source/tests/common.py b/source/tests/common.py index f8ed23df03..9af324896f 100644 --- a/source/tests/common.py +++ b/source/tests/common.py @@ -530,6 +530,85 @@ def strerch_box(old_coord, old_box, new_box): return ncoord.reshape(old_coord.shape) +def finite_difference_fv(sess, energy, feed_dict, t_coord, t_box, delta=1e-6): + """For energy models, compute f, v by finite difference.""" + base_dict = feed_dict.copy() + coord0 = base_dict.pop(t_coord) + box0 = base_dict.pop(t_box) + fdf = -finite_difference( + lambda coord: sess.run( + energy, feed_dict={**base_dict, t_coord: coord, t_box: box0} + ).reshape(-1), + coord0, + delta=delta, + ).reshape(-1) + fdv = -( + finite_difference( + lambda box: sess.run( + energy, + feed_dict={ + **base_dict, + t_coord: strerch_box(coord0, box0, box), + t_box: box, + }, + ).reshape(-1), + box0, + delta=delta, + ) + .reshape([-1, 3, 3]) + .transpose(0, 2, 1) + @ box0.reshape(3, 3) + ).reshape(-1) + return fdf, fdv + + +def check_continuity(f, cc, rcut, delta): + """coord[0:2] to [[0, 0, 0], [rcut+-.5*delta, 0, 0]].""" + cc = cc.reshape([-1, 3]) + cc0 = np.copy(cc) + cc1 = np.copy(cc) + cc0[:2, :] = np.array( + [ + 0.0, + 0.0, + 0.0, + rcut - 0.5 * delta, + 0.0, + 0.0, + ] + ).reshape([-1, 3]) + cc1[:2, :] = np.array( + [ + 0.0, + 0.0, + 0.0, + rcut + 0.5 * delta, + 0.0, + 0.0, + ] + ).reshape([-1, 3]) + return f(cc0.reshape(-1)), f(cc1.reshape(-1)) + + +def check_smooth_efv(sess, energy, force, virial, feed_dict, t_coord, rcut, delta=1e-5): + """Check the smoothness of e, f and v + the returned values are de, df, dv + de[0] are supposed to be closed to de[1] + df[0] are supposed to be closed to df[1] + dv[0] are supposed to be closed to dv[1]. 
+ """ + base_dict = feed_dict.copy() + coord0 = base_dict.pop(t_coord) + [fe, ff, fv] = [ + lambda coord: sess.run(ii, feed_dict={**base_dict, t_coord: coord}).reshape(-1) + for ii in [energy, force, virial] + ] + [de, df, dv] = [ + check_continuity(ii, coord0, rcut, delta=delta) for ii in [fe, ff, fv] + ] + return de, df, dv + + def run_dp(cmd: str) -> int: """Run DP directly from the entry point instead of the subprocess. diff --git a/source/tests/test_model_se_atten.py b/source/tests/test_model_se_atten.py index 445959ceb2..ad037e2931 100644 --- a/source/tests/test_model_se_atten.py +++ b/source/tests/test_model_se_atten.py @@ -5,6 +5,8 @@ import numpy as np from common import ( DataSystem, + check_smooth_efv, + finite_difference_fv, gen_data, j_loader, ) @@ -726,3 +728,147 @@ def test_stripped_type_embedding_exclude_types(self): np.testing.assert_almost_equal(des[:, 0:2], 0.0, 10) with self.assertRaises(AssertionError): np.testing.assert_almost_equal(des[:, 2:6], 0.0, 10) + + def test_smoothness_of_stripped_type_embedding_smooth_model(self): + """test: auto-diff, continuity of e,f,v.""" + jfile = "water_se_atten.json" + jdata = j_loader(jfile) + + systems = j_must_have(jdata, "systems") + set_pfx = j_must_have(jdata, "set_prefix") + batch_size = j_must_have(jdata, "batch_size") + test_size = j_must_have(jdata, "numb_test") + batch_size = 1 + test_size = 1 + stop_batch = j_must_have(jdata, "stop_batch") + rcut = j_must_have(jdata["model"]["descriptor"], "rcut") + + data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None) + + test_data = data.get_test() + numb_test = 1 + + jdata["model"]["descriptor"].pop("type", None) + jdata["model"]["descriptor"]["ntypes"] = 2 + jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["smooth_type_embdding"] = True + jdata["model"]["descriptor"]["attn_layer"] = 1 + descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) + jdata["model"]["fitting_net"]["descrpt"] = descrpt + fitting = EnerFitting(**jdata["model"]["fitting_net"], uniform_seed=True) + typeebd_param = jdata["model"]["type_embedding"] + typeebd = TypeEmbedNet( + neuron=typeebd_param["neuron"], + activation_function=None, + resnet_dt=typeebd_param["resnet_dt"], + seed=typeebd_param["seed"], + uniform_seed=True, + padding=True, + ) + model = EnerModel(descrpt, fitting, typeebd) + + # model._compute_dstats([test_data['coord']], [test_data['box']], [test_data['type']], [test_data['natoms_vec']], [test_data['default_mesh']]) + input_data = { + "coord": [test_data["coord"]], + "box": [test_data["box"]], + "type": [test_data["type"]], + "natoms_vec": [test_data["natoms_vec"]], + "default_mesh": [test_data["default_mesh"]], + } + model._compute_input_stat(input_data) + model.descrpt.bias_atom_e = data.compute_energy_shift() + + t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c") + t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy") + t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_force") + t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_virial") + t_atom_ener = tf.placeholder( + GLOBAL_TF_FLOAT_PRECISION, [None], name="t_atom_ener" + ) + t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord") + t_type = tf.placeholder(tf.int32, [None], name="i_type") + t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms") + t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box") + t_mesh = 
tf.placeholder(tf.int32, [None], name="i_mesh") + is_training = tf.placeholder(tf.bool) + t_fparam = None + inputs_dict = {} + + model_pred = model.build( + t_coord, + t_type, + t_natoms, + t_box, + t_mesh, + inputs_dict, + suffix=self.filename + + "-" + + inspect.stack()[0][3] + + "test_model_se_atten_model_compressible", + reuse=False, + ) + energy = model_pred["energy"] + force = model_pred["force"] + virial = model_pred["virial"] + atom_ener = model_pred["atom_ener"] + + feed_dict_test = { + t_prop_c: test_data["prop_c"], + t_energy: test_data["energy"][:numb_test], + t_force: np.reshape(test_data["force"][:numb_test, :], [-1]), + t_virial: np.reshape(test_data["virial"][:numb_test, :], [-1]), + t_atom_ener: np.reshape(test_data["atom_ener"][:numb_test, :], [-1]), + t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]), + t_box: test_data["box"][:numb_test, :], + t_type: np.reshape(test_data["type"][:numb_test, :], [-1]), + t_natoms: test_data["natoms_vec"], + t_mesh: test_data["default_mesh"], + is_training: False, + } + sess = self.cached_session().__enter__() + sess.run(tf.global_variables_initializer()) + [pe, pf, pv] = sess.run([energy, force, virial], feed_dict=feed_dict_test) + pf, pv = pf.reshape(-1), pv.reshape(-1) + + eps = 1e-4 + delta = 1e-5 + fdf, fdv = finite_difference_fv( + sess, energy, feed_dict_test, t_coord, t_box, delta=eps + ) + np.testing.assert_allclose(pf, fdf, delta) + np.testing.assert_allclose(pv, fdv, delta) + + tested_eps = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7] + for eps in tested_eps: + deltae = eps + deltad = eps + de, df, dv = check_smooth_efv( + sess, + energy, + force, + virial, + feed_dict_test, + t_coord, + jdata["model"]["descriptor"]["rcut"], + delta=eps, + ) + np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae) + np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad) + np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad) + + for eps in tested_eps: + deltae = 5.0 * eps + deltad = 5.0 * eps + de, df, dv = check_smooth_efv( + sess, + energy, + force, + virial, + feed_dict_test, + t_coord, + jdata["model"]["descriptor"]["rcut_smth"], + delta=eps, + ) + np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae) + np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad) + np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)
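
Note on the smoothed attention masking added in deepmd/descriptor/se_atten.py above: instead of hard-masking the attention logits with nmask / negative_mask (a huge negative constant for absent neighbors), the patch shifts the logits by smth_attn_w_shift (hard-coded to 20.0), damps them by the smooth switch value of both the query and the key neighbor (recovered_switch), and shifts back, so the logit of a neighbor leaving the cutoff decays continuously toward -20 instead of jumping; after the softmax the weights are damped by the same two switch factors once more. The following is a standalone NumPy sketch of that scheme for a single atom; the function name and the switch argument are illustrative stand-ins, only shift=20.0 and the algebra follow the diff.

import numpy as np

def smooth_masked_attention(logits, switch, shift=20.0):
    # logits: (nsel, nsel) scaled dot-product scores Q K^T / temperature
    # switch: (nsel,) smooth switch per neighbor; ~1 well inside rcut_smth,
    #         decays to 0 at rcut, exactly 0 for padded neighbors
    sw_row = switch[:, None]
    sw_col = switch[None, :]
    # shift up, damp by both switch factors, shift back: entries whose
    # switch is 0 end up at -shift and receive a vanishing softmax weight,
    # and the transition is as smooth as the switch function itself
    masked = (logits + shift) * sw_row * sw_col - shift
    w = np.exp(masked - masked.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    # damp the normalized weights again so they also vanish smoothly
    return w * sw_row * sw_col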
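
Note on the consistency checks added in source/tests/common.py: finite_difference_fv verifies the analytic force against -dE/dcoord and the virial against the box-strain derivative, while check_smooth_efv moves one neighbor to rcut -/+ 0.5*delta and requires energy, force, and virial to stay continuous. A minimal sketch of the force part is below, assuming a scalar energy_fn stand-in for sess.run(energy, ...) with box and types held fixed; it is not the test suite's own finite_difference helper, just the same central-difference identity.

import numpy as np

def finite_difference_force(energy_fn, coord, delta=1e-6):
    # Central-difference check of F = -dE/dcoord, the identity the new
    # test asserts for the smooth stripped-type-embedding model.
    coord = np.asarray(coord, dtype=float).reshape(-1)
    force = np.zeros_like(coord)
    for i in range(coord.size):
        cp, cm = coord.copy(), coord.copy()
        cp[i] += 0.5 * delta
        cm[i] -= 0.5 * delta
        force[i] = -(energy_fn(cp) - energy_fn(cm)) / delta
    return force

A tolerance on the order of the step size, as in the new test's assert_allclose calls, is the natural acceptance criterion for such an O(delta**2) check.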