Skip to content

Commit

Permalink
Add min_bler and callback to sim_ber (#257)
Browse files Browse the repository at this point in the history
* Add min_ber and callback to sim_ber

* Add documentation to previous commit for Add min_ber and callback to sim_ber

Signed-off-by: Neal Becker <[email protected]>

* adds support for early stoping from callback

---------

Signed-off-by: Neal Becker <[email protected]>
Co-authored-by: SebastianCa <[email protected]>
  • Loading branch information
nbecker and SebastianCa authored Nov 13, 2023
1 parent f5ea373 commit 105a6fe
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions sionna/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ def sim_ber(mc_fun,
graph_mode=None,
verbose=True,
forward_keyboard_interrupt=True,
dtype=tf.complex64):
dtype=tf.complex64,
callback=None):
"""Simulates until target number of errors is reached and returns BER/BLER.
The simulation continues with the next SNR point if either
Expand All @@ -420,7 +421,7 @@ def sim_ber(mc_fun,
Input
-----
mc_fun:
mc_fun: callable
Callable that yields the transmitted bits `b` and the
receiver's estimate `b_hat` for a given ``batch_size`` and
``ebno_db``. If ``soft_estimates`` is True, b_hat is interpreted as
Expand Down Expand Up @@ -470,6 +471,22 @@ def sim_ber(mc_fun,
dtype: tf.complex64
Datatype of the model / function to be used (``mc_fun``).
callback: callable
Defaults to `None`. If specified, ``callback``
will be called after each Monte-Carlo step. Can be used for
logging or advanced early stopping.
Input signature of ``callback`` must match `callback(mc_iter,
ebno_dbs, bit_errors, block_errors, nb_bits, nb_blocks)` where
``mc_iter`` denotes the number of processed batches for the current
SNR, ``ebno_dbs`` is the current SNR point, ``bit_errors`` the number
of bit errors, ``block_errors`` the number of block errors, ``nb_bits``
the number of simulated bits, ``nb_blocks`` the number of simulated
blocks. If ``callable`` returns `sim_ber.CALLBACK_NEXT_SNR`, early
stopping is detected and the simulation will continue with the next SNR
point. If ``callable`` returns `sim_ber.CALLBACK_STOP`, the simulation
is stopped immediately. For `sim_ber.CALLBACK_CONTINUE` continues with
the simulation.
Output
------
(ber, bler) :
Expand Down Expand Up @@ -567,7 +584,8 @@ def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None):
"reached max iter ", # status=1; spacing for impr. layout
"no errors - early stop", # status=2
"reached target bit errors", # status=3
"reached target block errors"] # status=4
"reached target block errors", # status=4
"callback triggered stopping"] # status=5

# check inputs for consistency
assert isinstance(early_stop, bool), "early_stop must be bool."
Expand Down Expand Up @@ -658,6 +676,18 @@ def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None):
nb_blocks = tf.tensor_scatter_nd_add( nb_blocks, [[i]],
tf.cast([block_n], tf.int64))

cb_state = sim_ber.CALLBACK_CONTINUE
if callback is not None:
cb_state = callback (ii, ebno_dbs[i], bit_errors[i],
block_errors[i], nb_bits[i],
nb_blocks[i])
if cb_state in (sim_ber.CALLBACK_STOP,
sim_ber.CALLBACK_NEXT_SNR):
# stop runtime timer
runtime[i] = time.perf_counter() - runtime[i]
status[i] = 5 # change internal status for summary
break # stop for this SNR point have been simulated

# print progress summary
if verbose:
# print summary header during first iteration
Expand Down Expand Up @@ -714,6 +744,14 @@ def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None):
print("\nSimulation stopped as no error occurred " \
f"@ EbNo = {ebno_dbs[i].numpy():.1f} dB.\n")
break
# allow callback to end the entire simulation
if cb_state is sim_ber.CALLBACK_STOP:
# stop runtime timer
status[i] = 5 # change internal status for summary
if verbose:
print("\nSimulation stopped by callback funtion " \
f"@ EbNo = {ebno_dbs[i].numpy():.1f} dB.\n")
break

# Stop if KeyboardInterrupt is detected and set remaining SNR points to -1
except KeyboardInterrupt as e:
Expand Down Expand Up @@ -745,6 +783,9 @@ def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None):

return ber, bler

sim_ber.CALLBACK_CONTINUE = None
sim_ber.CALLBACK_STOP = 2
sim_ber.CALLBACK_NEXT_SNR = 1

def complex_normal(shape, var=1.0, dtype=tf.complex64):
r"""Generates a tensor of complex normal random variables.
Expand Down

0 comments on commit 105a6fe

Please sign in to comment.