Skip to content

Commit

Permalink
Adding parallelized tracking from AT (#42)
Browse files Browse the repository at this point in the history
Used only for beam transmission
  • Loading branch information
lmalina authored Sep 26, 2023
1 parent 1c7e002 commit a6c8d3e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pySC/core/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pySC.core.simulated_commissioning import SimulatedCommissioning
from pySC.core.constants import TRACK_ORB, TRACK_PORB, TRACK_TBT
from pySC.utils.sc_tools import SCrandnc
from pySC.utils.at_wrapper import atgetfieldvalues, atpass, findorbit6, findspos
from pySC.utils.at_wrapper import atgetfieldvalues, atpass, findorbit6, findspos, patpass
import warnings
from pySC.utils import logging_tools

Expand Down Expand Up @@ -130,7 +130,7 @@ def beam_transmission(SC: SimulatedCommissioning, nParticles: int = None, nTurns
if nTurns is None:
nTurns = SC.INJ.nTurns
LOGGER.debug(f'Calculating maximum beam transmission for {nParticles} particles and {nTurns} turns: ')
T = atpass(SC.RING, generate_bunches(SC, nParticles=nParticles), nTurns, np.array([len(SC.RING)]), keep_lattice=False)
T = patpass(SC.RING, generate_bunches(SC, nParticles=nParticles), nTurns, np.array([len(SC.RING)]), keep_lattice=False)
fraction_survived = np.mean(~np.isnan(T[0, :, :, :]), axis=(0, 1))
max_turns = np.sum(fraction_survived > 1 - SC.INJ.beamLostAt)
if plot:
Expand Down
5 changes: 5 additions & 0 deletions pySC/utils/at_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def atpass(ring: Lattice, init_pos: ndarray, nturns: int, refpts: ndarray, keep_
keep_lattice=keep_lattice)


def patpass(ring: Lattice, init_pos: ndarray, nturns: int, refpts: ndarray, keep_lattice: bool = False):
return at.patpass(lattice=ring.copy(), r_in=init_pos.copy(), nturns=nturns, refpts=refpts,
keep_lattice=keep_lattice)


def atgetfieldvalues(ring: Lattice, refpts: ndarray, attrname: str, index: int = None):
return at.get_value_refpts(ring, refpts, attrname, index)

Expand Down
13 changes: 12 additions & 1 deletion tests/test_at_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy.testing import assert_equal
import at
from at import Lattice
from pySC.utils.at_wrapper import findspos, atgetfieldvalues, atpass, findorbit6, findorbit4
from pySC.utils.at_wrapper import findspos, atgetfieldvalues, atpass, patpass, findorbit6, findorbit4



Expand Down Expand Up @@ -37,6 +37,17 @@ def test_atpass(at_lattice):
assert_equal(indices, np.arange(11, 450, 22, dtype=int))


def test_patpass(at_lattice):
lattice_copy = copy.deepcopy(at_lattice)
indices = np.arange(11, 450, 22, dtype=int)
initial_pos = np.random.randn(6)
copy_initial_pos = copy.deepcopy(initial_pos)
tracking = patpass(at_lattice, initial_pos, 3, indices,)
assert tracking.shape == (6, 1, 20, 3)
assert_equal(initial_pos, copy_initial_pos)
assert at_lattice.__repr__() == lattice_copy.__repr__()
assert_equal(indices, np.arange(11, 450, 22, dtype=int))

def test_findorbit6(at_lattice):
lattice_copy = copy.deepcopy(at_lattice)
indices = np.arange(11, 450, 22, dtype=int)
Expand Down

0 comments on commit a6c8d3e

Please sign in to comment.