diff --git a/sionna/rt/solver_base.py b/sionna/rt/solver_base.py index eb0f7790..9e99830c 100644 --- a/sionna/rt/solver_base.py +++ b/sionna/rt/solver_base.py @@ -816,12 +816,22 @@ def _swap_edges(self, edges): # 1. norm(p1) >= norm(p0) # 2. azimuth(p1) >= azimuth(p0) # 3. elevation (p1) >= elevation(p0) - needs_swap_1 = r0 > r1 - not_disc_1 = tf.experimental.numpy.isclose(r0, r1) - needs_swap_2 = tf.logical_and(not_disc_1, phi0 > phi1) - not_disc_2 = tf.experimental.numpy.isclose(phi0, phi1) - not_disc_12 = tf.logical_and(not_disc_1, not_disc_2) - needs_swap_3 = tf.logical_and(not_disc_12, theta0 > theta1) + + # More details of the algorithm: + # needs_swap 1: !r_equal and r0 > r1 + # needs_swap 2: r_equal and !phi_equal and phi0 > phi1 + # needs_swap 3: r_equal and phi_equal and theta0 > theta1 + # Note: case when all three coordinates are equal is not considered + + r_equal = tf.experimental.numpy.isclose(r0, r1) + phi_equal = tf.experimental.numpy.isclose(phi0, phi1) + case_2 = tf.logical_and(r_equal, tf.logical_not(phi_equal)) + case_3 = tf.logical_and(r_equal, phi_equal) + + needs_swap_1 = tf.logical_and(tf.logical_not(r_equal), r0 > r1) + needs_swap_2 = tf.logical_and(case_2, phi0 > phi1) + needs_swap_3 = tf.logical_and(case_3, theta0 > theta1) + needs_swap = tf.reduce_any(tf.stack([needs_swap_1, needs_swap_2, needs_swap_3], axis=1),