diff --git a/pySC/core/beam.py b/pySC/core/beam.py index bc57b19..b13712e 100644 --- a/pySC/core/beam.py +++ b/pySC/core/beam.py @@ -13,7 +13,7 @@ from pySC.core.simulated_commissioning import SimulatedCommissioning from pySC.core.constants import TRACK_ORB, TRACK_PORB, TRACK_TBT from pySC.utils.sc_tools import SCrandnc -from pySC.utils.at_wrapper import atgetfieldvalues, atpass, findorbit6, findspos +from pySC.utils.at_wrapper import atgetfieldvalues, atpass, findorbit6, findspos, patpass import warnings from pySC.utils import logging_tools @@ -130,7 +130,7 @@ def beam_transmission(SC: SimulatedCommissioning, nParticles: int = None, nTurns if nTurns is None: nTurns = SC.INJ.nTurns LOGGER.debug(f'Calculating maximum beam transmission for {nParticles} particles and {nTurns} turns: ') - T = atpass(SC.RING, generate_bunches(SC, nParticles=nParticles), nTurns, np.array([len(SC.RING)]), keep_lattice=False) + T = patpass(SC.RING, generate_bunches(SC, nParticles=nParticles), nTurns, np.array([len(SC.RING)]), keep_lattice=False) fraction_survived = np.mean(~np.isnan(T[0, :, :, :]), axis=(0, 1)) max_turns = np.sum(fraction_survived > 1 - SC.INJ.beamLostAt) if plot: diff --git a/pySC/utils/at_wrapper.py b/pySC/utils/at_wrapper.py index 46f3360..430ae18 100644 --- a/pySC/utils/at_wrapper.py +++ b/pySC/utils/at_wrapper.py @@ -26,6 +26,11 @@ def atpass(ring: Lattice, init_pos: ndarray, nturns: int, refpts: ndarray, keep_ keep_lattice=keep_lattice) +def patpass(ring: Lattice, init_pos: ndarray, nturns: int, refpts: ndarray, keep_lattice: bool = False): + return at.patpass(lattice=ring.copy(), r_in=init_pos.copy(), nturns=nturns, refpts=refpts, + keep_lattice=keep_lattice) + + def atgetfieldvalues(ring: Lattice, refpts: ndarray, attrname: str, index: int = None): return at.get_value_refpts(ring, refpts, attrname, index) diff --git a/tests/test_at_wrapper.py b/tests/test_at_wrapper.py index 209c0e2..0ee172c 100644 --- a/tests/test_at_wrapper.py +++ b/tests/test_at_wrapper.py @@ -5,7 +5,7 @@ from numpy.testing import assert_equal import at from at import Lattice -from pySC.utils.at_wrapper import findspos, atgetfieldvalues, atpass, findorbit6, findorbit4 +from pySC.utils.at_wrapper import findspos, atgetfieldvalues, atpass, patpass, findorbit6, findorbit4 @@ -37,6 +37,17 @@ def test_atpass(at_lattice): assert_equal(indices, np.arange(11, 450, 22, dtype=int)) +def test_patpass(at_lattice): + lattice_copy = copy.deepcopy(at_lattice) + indices = np.arange(11, 450, 22, dtype=int) + initial_pos = np.random.randn(6) + copy_initial_pos = copy.deepcopy(initial_pos) + tracking = patpass(at_lattice, initial_pos, 3, indices,) + assert tracking.shape == (6, 1, 20, 3) + assert_equal(initial_pos, copy_initial_pos) + assert at_lattice.__repr__() == lattice_copy.__repr__() + assert_equal(indices, np.arange(11, 450, 22, dtype=int)) + def test_findorbit6(at_lattice): lattice_copy = copy.deepcopy(at_lattice) indices = np.arange(11, 450, 22, dtype=int)