Skip to content

Commit

Permalink
update and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
loiccoyle committed Jun 26, 2024
1 parent e1ef48d commit 15d8863
Show file tree
Hide file tree
Showing 21 changed files with 1,271 additions and 1,135 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,27 @@ name: tests

on:
push:
branches: [ master ]
branches: [main]
pull_request:
branches: [ master ]
branches: [main]

jobs:
ci:
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
poetry-version: [1.1.2]
python-version: [3.9, 3.12]
poetry-version: [1.8.3]
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Run image
uses: abatilo/actions-poetry@v2.0.0
uses: abatilo/actions-poetry@v2
with:
poetry-version: ${{ matrix.poetry-version }}

Expand Down
2,074 changes: 1,107 additions & 967 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions pyaccelerator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Python package to build simple toy accelerator.
"""
"""Python package to build simple toy accelerator."""

import logging

from . import elements
Expand All @@ -11,7 +11,6 @@
TargetTwiss,
TargetTwissSolution,
)
from .elements import *
from .lattice import Lattice

__all__ = [
Expand Down
49 changes: 21 additions & 28 deletions pyaccelerator/beam.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Accelerator Beam"""

from typing import Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -92,7 +93,7 @@ def gamma_relativistic(self):

@property
def beta_relativistic(self):
return np.sqrt(1 - 1 / self.gamma_relativistic ** 2)
return np.sqrt(1 - 1 / self.gamma_relativistic**2)

@property
def geo_emittance_h(self):
Expand All @@ -105,20 +106,20 @@ def geo_emittance_v(self):
@property
def p(self):
# in MeV/c
return np.sqrt(self.energy ** 2 + 2 * self.energy * self.mass)
return np.sqrt(self.energy**2 + 2 * self.energy * self.mass)

@property
def sigma_p(self):
absolute_sigma_e = self.sigma_energy * self.energy
# in MeV/c
return np.sqrt(absolute_sigma_e ** 2 + 2 * absolute_sigma_e * self.mass)
return np.sqrt(absolute_sigma_e**2 + 2 * absolute_sigma_e * self.mass)

def ellipse(
self,
twiss_h: Sequence[float],
twiss_v: Optional[Sequence[float]] = None,
twiss_h: Union[Sequence[float], np.ndarray],
twiss_v: Optional[Union[Sequence[float], np.ndarray]] = None,
closure_tol: float = 1e-10,
n_angles: int = 1e3,
n_angles: int = 1000,
) -> PhasespaceDistribution:
"""Compute the beam's phase space ellipse given the twiss parameters.
Expand All @@ -137,17 +138,13 @@ def ellipse(
Position, angle phase and dp/p space coordrinates of the ellipse.
Note, dp/p will be set to 0.
"""
twiss_h = to_twiss(twiss_h)
if twiss_v is None:
# if no vertical twiss provided use the same as the horizontal
twiss_v = twiss_h
else:
twiss_v = to_twiss(twiss_v)
beta_h, alpha_h, _ = twiss_h.T[0] # pylint: disable=unsubscriptable-object
beta_v, alpha_v, _ = twiss_v.T[0] # pylint: disable=unsubscriptable-object
twiss_h_ = to_twiss(twiss_h)
twiss_v_ = twiss_h_ if twiss_v is None else to_twiss(twiss_v)
beta_h, alpha_h, _ = twiss_h_.T[0]
beta_v, alpha_v, _ = twiss_v_.T[0]
# check the twiss parameters
for twiss in (twiss_h, twiss_v):
closure = compute_twiss_clojure(twiss)
for twiss in (twiss_h_, twiss_v_):
closure = compute_twiss_clojure(twiss) # type: ignore
if not -closure_tol <= closure - 1 <= closure_tol:
raise ValueError(
f"Closure condition not met for {twiss}: beta * gamma - alpha**2 = {closure} != 1"
Expand All @@ -169,8 +166,8 @@ def ellipse(

def match(
self,
twiss_h: Sequence[float],
twiss_v: Optional[Sequence[float]] = None,
twiss_h: Union[Sequence[float], np.ndarray],
twiss_v: Optional[Union[Sequence[float], np.ndarray]] = None,
closure_tol: float = 1e-10,
) -> PhasespaceDistribution:
"""Generate a matched beam phase space distribution to the provided
Expand All @@ -188,17 +185,13 @@ def match(
Returns:
Position, angle and dp/p phase space coordinates.
"""
twiss_h = to_twiss(twiss_h)
if twiss_v is None:
# if no vertical twiss provided use the same as the horizontal
twiss_v = twiss_h
else:
twiss_v = to_twiss(twiss_v)
beta_h, alpha_h, _ = twiss_h.T[0] # pylint: disable=unsubscriptable-object
beta_v, alpha_v, _ = twiss_v.T[0] # pylint: disable=unsubscriptable-object
twiss_h_ = to_twiss(twiss_h)
twiss_v_ = twiss_h_ if twiss_v is None else to_twiss(twiss_v)
beta_h, alpha_h, _ = twiss_h_.T[0]
beta_v, alpha_v, _ = twiss_v_.T[0]
# check the twiss parameters
for twiss in (twiss_h, twiss_v):
closure = compute_twiss_clojure(twiss)
for twiss in (twiss_h_, twiss_v_):
closure = compute_twiss_clojure(twiss) # type: ignore
if not -closure_tol <= closure - 1 <= closure_tol:
raise ValueError(
f"Closure condition not met: beta * gamma - alpha**2 = {closure} != 1"
Expand Down
26 changes: 13 additions & 13 deletions pyaccelerator/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union

import numpy as np
from scipy.optimize import minimize
from scipy.optimize import OptimizeResult, minimize

from .elements.base import BaseElement
from .utils import PLANE_INDICES

if TYPE_CHECKING: # pragma: no cover
from scipy.optimize import OptimizeResult

if TYPE_CHECKING:
from .lattice import Lattice


class BaseTarget:
"""Base target."""

@abstractmethod
def loss(self, lattice: "Lattice"):
def loss(self, lattice: "Lattice") -> float:
"""Compute the loss for this target."""


Expand Down Expand Up @@ -83,8 +81,8 @@ def loss(self, lattice: "Lattice") -> float:
transported_rows = [
i for i, value in enumerate(self.value) if value is not None
]
result = transported[transported_rows, transported_columns]
return abs(result - self.value[transported_rows])
result: float = transported[transported_rows, transported_columns] # type: ignore
return abs(result - self.value[transported_rows]) # type: ignore

def __repr__(self) -> str:
args = ["element", "value", "initial"]
Expand Down Expand Up @@ -136,8 +134,8 @@ def loss(self, lattice: "Lattice") -> float:
transported_rows = [
i for i, value in enumerate(self.value) if value is not None
]
result = transported[transported_rows, transported_columns]
return abs(result - self.value[transported_rows])
result: float = transported[transported_rows, transported_columns] # type: ignore
return abs(result - self.value[transported_rows]) # type: ignore

def __repr__(self) -> str:
args = ["element", "value", "plane"]
Expand Down Expand Up @@ -185,13 +183,13 @@ def loss(self, lattice: "Lattice") -> float:
except ValueError:
return np.inf

transported_columns = -1 # Use Twiss values at end of lattice
transported_columns = -1 # Use Twiss values at end of lattice
transported_rows = [
i for i, value in enumerate(self.value) if value is not None
]

result = transported[transported_rows, transported_columns]
return abs(result - self.value[transported_rows])
result: float = transported[transported_rows, transported_columns] # type: ignore
return abs(result - self.value[transported_rows]) # type: ignore

def __repr__(self) -> str:
args = ["value", "plane"]
Expand Down Expand Up @@ -440,7 +438,9 @@ def match_function(new_settings):
if res.fun > 1e-1:
# as this is a minimzation algorithm, it can find a minimum
# but the matching could still be off.
self._logger.warning("Loss is high:%f, double check the matching.", res.fun)
self._logger.warning(
"Loss is high:%f, double check the matching.", res.fun
)
return lattice, res

def _set_parameters(self, new_settings: Sequence[float], lattice: "Lattice"):
Expand Down
1 change: 1 addition & 0 deletions pyaccelerator/elements/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Accelerator elements."""

from .custom import CustomThin
from .dipole import Dipole, DipoleThin
from .drift import Drift
Expand Down
29 changes: 21 additions & 8 deletions pyaccelerator/elements/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from abc import abstractmethod
from typing import Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional

import numpy as np
from matplotlib.patches import Patch

from ..transfer_matrix import TransferMatrix

if TYPE_CHECKING:
from ..lattice import Lattice

class BaseElement:
"""Base class of a lattice element.

Args:
*instance_args: Arguments required to make the instance of this
class's subclasses.
"""
class BaseElement:
name: str

def __init__(self, *instance_args):
"""Base class of a lattice element.
Args:
*instance_args: Arguments required to make the instance of this
class's subclasses.
"""
# args of the subclass instance.
self._instance_args = instance_args

Expand All @@ -37,7 +41,7 @@ def _get_length(self) -> float: # pragma: no cover
pass

@abstractmethod
def _get_patch(self, s: float) -> Patch:
def _get_patch(self, s: float) -> Optional[Patch]:
"""Generate a ``matplotlib.patches.Patch`` object to represent the
element when plotting the lattice.
Expand All @@ -48,6 +52,15 @@ def _get_patch(self, s: float) -> Patch:
``matplotlib.patches.Patch`` which represents the element.
"""

@abstractmethod
def slice(self, n_slices: int) -> "Lattice":
"""Slice the element into many smaller elements.
Args:
n_slices: Number of elements to slice the element into.
Returns:
A list of sliced elements.
"""

def _transport(self, phase_coords: np.ndarray) -> np.ndarray:
return (self._get_transfer_matrix() @ phase_coords) + self._non_linear_term(
phase_coords
Expand Down
8 changes: 4 additions & 4 deletions pyaccelerator/elements/custom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import count
from typing import Optional, Tuple, Union
from typing import Optional

import numpy as np
from matplotlib import patches
Expand All @@ -8,7 +8,6 @@


class CustomThin(BaseElement):

_instance_count = count(0)

def __init__(
Expand Down Expand Up @@ -44,7 +43,7 @@ def _get_length(self) -> float:
def _get_transfer_matrix(self) -> np.ndarray:
return self.transfer_matrix

def _get_patch(self, s: float) -> Union[None, patches.Patch]:
def _get_patch(self, s: float) -> patches.Patch:
label = self.name
colour = "black"

Expand All @@ -59,6 +58,7 @@ def _get_patch(self, s: float) -> Union[None, patches.Patch]:

@staticmethod
def _dxztheta_ds(
theta: float, d_s: float # pylint: disable=unused-argument
theta: float,
d_s: float, # pylint: disable=unused-argument
) -> np.ndarray:
return np.array([0, 0, 0])
12 changes: 5 additions & 7 deletions pyaccelerator/elements/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,18 @@ def _get_transfer_matrix(self) -> np.ndarray:
out[4, 4] = 1
return out

def slice(self, n_dipoles: int) -> Lattice:
def slice(self, n_slices: int) -> Lattice:
"""Slice the element into a many smaller elements.
Args:
n_dipoles: Number of :py:class:`Dipole` elements.
n_slices: Number of :py:class:`Dipole` elements.
Returns:
:py:class:`~accelerator.lattice.Lattice` of sliced :py:class:`Dipole` elements.
"""
out = [
Dipole(self.rho, self.theta / n_dipoles, name=f"{self.name}_slice_{i}")
for i in range(n_dipoles)
Dipole(self.rho, self.theta / n_slices, name=f"{self.name}_slice_{i}")
for i in range(n_slices)
]
return Lattice(out)

Expand Down Expand Up @@ -136,7 +136,5 @@ def _get_patch(self, s: float) -> patches.Patch:
facecolor="lightcoral",
)

def _dxztheta_ds(
self, theta: float, d_s: float # pylint: disable=unused-argument
) -> np.ndarray:
def _dxztheta_ds(self, theta: float, d_s: float) -> np.ndarray:
return np.array([0, 0, self.theta])
8 changes: 4 additions & 4 deletions pyaccelerator/elements/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ def _get_transfer_matrix(self) -> np.ndarray:
out[2, 3] = self.length
return out

def slice(self, n_drifts: int) -> Lattice:
def slice(self, n_slices: int) -> Lattice:
"""Slice the element into a many smaller elements.
Args:
n_drifts: Number of :py:class:`Drift` elements.
n_slices: Number of :py:class:`Drift` elements.
Returns:
:py:class:`~accelerator.lattice.Lattice` of sliced :py:class:`Drift`
elements.
"""
out = [
Drift(self.length / n_drifts, name=f"{self.name}_slice_{i}")
for i in range(n_drifts)
Drift(self.length / n_slices, name=f"{self.name}_slice_{i}")
for i in range(n_slices)
]
return Lattice(out)

Expand Down
Loading

0 comments on commit 15d8863

Please sign in to comment.