"""Accelerator lattice"""
import json
import logging
import os
import re
from typing import TYPE_CHECKING, List, Sequence, Tuple, Type, Union
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import root
from .constraints import Constraints
from .harmonic_analysis import HarmonicAnalysis
from .transfer_matrix import TransferMatrix
from .utils import (
PLANE_INDICES,
PLANE_SLICES,
TransportedPhasespace,
TransportedTwiss,
compute_one_turn,
compute_twiss_solution,
to_twiss,
)
if TYPE_CHECKING: # pragma: no cover
from .elements.base import BaseElement
class Lattice(list):
    """A lattice of accelerator elements.

    Looks like a list, smells like a list and tastes like a list.
    Is in fact an accelerator lattice.

    The one-turn transfer matrix (:py:attr:`m`) is cached. The cache is
    cleared by the list-mutating methods overridden in this class (``append``,
    ``extend``, item assignment/deletion, ``+=``, ...). Modifying an element
    object in place is not detected by this class.

    Examples:
        Create a simple lattice.

        >>> Lattice([Drift(1), QuadrupoleThin(0.8)])
        Lattice([Drift(l=1, name="drift_0"), QuadrupoleThin(f=0.8, name="quadrupole_thin_0")])
    """

    @classmethod
    def load(cls, path: os.PathLike) -> "Lattice":
        """Load a lattice from a file.

        Args:
            path: File path.

        Returns:
            Loaded :py:class:`Lattice` instance.

        Examples:
            Save and load a lattice:

            >>> lat = Lattice([Drift(1)])
            >>> lat.save("drift.json")
            >>> lat_loaded = Lattice.load("drift.json")
        """
        # non top level import to avoid circular imports
        from .elements.utils import deserialize

        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding
        with open(path, "r", encoding="utf-8") as fp:
            serialized = json.load(fp)
        return cls([deserialize(element) for element in serialized])

    def __init__(self, *args):
        super().__init__(*args)
        self._m = None
        self.plot = Plotter(self)
        self.constraints = Constraints(self)
        self._log = logging.getLogger(__name__)

    @property
    def m(self):
        """One-turn transfer matrix of the lattice.

        Built lazily from the elements' ``m`` matrices via ``compute_one_turn``
        and cached until the lattice is mutated through a list method.
        """
        if self._m is None:
            self._m = TransferMatrix(compute_one_turn([element.m for element in self]))
        return self._m

    def _clear_cache(self):
        """Drop the cached one-turn transfer matrix."""
        self._m = None

    def closed_orbit(self, dp: float, **solver_kwargs) -> TransportedPhasespace:
        """Compute the closed orbit for a given dp/p.

        Args:
            dp: dp/p for which to compute the closed orbit.
            **solver_kwargs: passed to `scipy.root`.

        Returns:
            Closed orbit solution transported through the lattice.
        """
        return self.transport(self.closed_orbit_solution(dp, **solver_kwargs))

    def closed_orbit_solution(self, dp: float, **solver_kwargs) -> np.ndarray:
        """Compute the closed orbit solution for a given dp/p.

        Args:
            dp: dp/p for which to compute the closed orbit.
            **solver_kwargs: passed to `scipy.root`.

        Returns:
            Closed orbit solution, ``[x, x_prime, y, y_prime, dp]``.

        Raises:
            ValueError: If the root finder does not converge.
        """

        def try_solve(x_x_prime_y_y_prime):
            # residual between the starting coords and the coords after one
            # pass through the lattice; zero for a closed orbit
            init = np.zeros(5)
            init[:4] = x_x_prime_y_y_prime
            init[4] = dp
            _, *transported = self.transport(init)
            out = [point[-1] for point in transported]
            return (init - out)[:4]

        opt_res = root(try_solve, [0, 0, 0, 0], **solver_kwargs)
        self._log.info("Closed orbit optimization:\n %s", opt_res)
        if not opt_res.success:
            raise ValueError("Failed to compute closed orbit solution.")
        solution = np.zeros(5)
        solution[:4] = opt_res.x
        solution[4] = dp
        return solution

    def dispersion(self, **solver_kwargs) -> TransportedPhasespace:
        """Compute the dispersion, i.e. the closed orbit per unit dp/p.

        The closed orbit is computed at dp/p = 1e-3 and its transverse
        coordinates are divided by that dp/p. The ``dp`` field of the result
        is left unscaled, i.e. it holds the dp/p actually used.

        Args:
            **solver_kwargs: passed to `scipy.root`.

        Returns:
            Dispersion solution transported through the lattice.
        """
        dp = 1e-3
        out = self.closed_orbit(dp=dp, **solver_kwargs)
        x = out.x / dp
        y = out.y / dp
        x_prime = out.x_prime / dp
        y_prime = out.y_prime / dp
        return TransportedPhasespace(out.s, x, x_prime, y, y_prime, out.dp)

    def dispersion_solution(self, **solver_kwargs):
        """Compute the dispersion solution.

        The closed orbit solution at dp/p = 1e-3 is divided (all five
        components, including dp) by that dp/p.

        Args:
            **solver_kwargs: passed to `scipy.root`.

        Returns:
            Dispersion solution.
        """
        dp = 1e-3
        out = self.closed_orbit_solution(dp=dp, **solver_kwargs)
        out /= dp
        return out

    def twiss(self, plane="h") -> TransportedTwiss:
        """Compute the twiss parameters through the lattice for a given plane.

        Args:
            plane: plane of interest, either "h" or "v" (case-insensitive).

        Returns:
            Twiss parameters through the lattice.
        """
        plane = plane.lower()
        return self.transport_twiss(self.twiss_solution(plane=plane), plane=plane)

    def twiss_solution(self, plane: str = "h") -> np.ndarray:
        """Compute the twiss periodic solution.

        Args:
            plane: plane of interest, either "h" or "v" (case-insensitive).

        Returns:
            Twiss periodic solution.
        """
        plane = plane.lower()
        return compute_twiss_solution(self.m[PLANE_SLICES[plane], PLANE_SLICES[plane]])

    def tune(
        self, plane: str = "h", n_turns: int = 1024, dp: float = 0, tol=1e-4
    ) -> float:
        """Compute the fractional part of the tune.

        Note: the whole tune value would be Q = n + q or Q = n + (1 - q) with q
        the fractional part of the tune returned by this method and n an integer.

        Args:
            plane: plane of interest, either "h" or "v" (case-insensitive).
            n_turns: number of turns for which to track the particle, higher
                values lead to more precise values at the expense of computation
                time.
            dp: dp/p value of the tracked particle.
            tol: numerical tolerance for DC component.

        Returns:
            The fractional part of the tune.
        """
        # normalise like twiss()/twiss_solution() so "H"/"V" also work
        plane = plane.lower()
        init = np.zeros(5)
        init[PLANE_INDICES[plane]] = [1e-6, 0]
        init[4] = dp
        out_turns = [init]
        # track for n_turns
        for _ in range(n_turns - 1):
            _, *transported = self.transport(out_turns[-1])
            out_turns.append([point[-1] for point in transported])
        out_turns = np.array(out_turns)
        # get the frequency with the highest amplitude
        position = out_turns[:, PLANE_INDICES[plane][0]]
        angle = out_turns[:, PLANE_INDICES[plane][1]]
        # normalised phase space coordinates from the periodic twiss solution
        beta, alpha, _ = self.twiss_solution(plane=plane)
        sqrt_beta = np.sqrt(beta)
        norm_eta = position / sqrt_beta
        norm_eta_prime = position * alpha / sqrt_beta + sqrt_beta * angle
        complex_signal = norm_eta - 1j * norm_eta_prime
        tune, _ = HarmonicAnalysis(complex_signal).laskar_method(2)
        self._log.info("Harmonics: %s", tune)
        if abs(tune[0]) < tol:
            self._log.info("Dropped DC component.")
            # if there is a DC component to the signal then the tune will be the
            # second harmonic
            return tune[1]
        # if not then the tune will be the first harmonic
        return tune[0]

    def chromaticity(self, plane: str = "h", delta_dp=1e-3, **kwargs) -> float:
        """Compute the chromaticity.

        Tracks 2 particles with different dp/p and computes the chromaticity
        from the tune change.

        Args:
            plane: plane of interest, either "h" or "v".
            delta_dp: dp/p difference between the 2 particles.
            **kwargs: passed to the compute tune method.

        Returns:
            Chromaticity value.
        """
        tune_0 = self.tune(plane=plane, dp=0, **kwargs)
        tune_1 = self.tune(plane=plane, dp=delta_dp, **kwargs)
        return (tune_1 - tune_0) / delta_dp

    def slice(self, element_type: Type["BaseElement"], n_element: int) -> "Lattice":
        """Slice the `element_type` elements of the lattice into `n_element`.

        Elements with zero length are kept as they are.

        Args:
            element_type: Element class to slice.
            n_element: Slice `element_type` into `n_element` smaller elements.

        Returns:
            Sliced :py:class:`Lattice`.

        Examples:
            Slice the :py:class:`~accelerator.elements.drift.Drift` elements
            into 2:

            >>> lat = Lattice([Drift(1), QuadrupoleThin(0.8)])
            >>> lat.slice(Drift, 2)
            Lattice([Drift(l=0.5, name="drift_0_slice_0"),
                     Drift(l=0.5, name="drift_0_slice_1"),
                     QuadrupoleThin(f=0.8, name="quadrupole_thin_0")])
        """
        new_lattice = []
        for element in self:
            if isinstance(element, element_type) and element.length > 0:
                new_lattice.extend(element.slice(n_element))
            else:
                new_lattice.append(element)
        return Lattice(new_lattice)

    def transport(
        self,
        initial: Sequence[Union[float, np.ndarray]],
    ) -> TransportedPhasespace:
        """Transport phase space coordinates along the lattice.

        Args:
            initial: phase space coords ``[x, x_prime, y, y_prime, dp]`` to
                transport through the lattice; each entry may be a scalar or
                an array of particle coordinates.

        Returns:
            Transported phase space coords through the lattice.

        Examples:
            Transport phase space coords through a
            :py:class:`~accelerator.elements.drift.Drift`:

            >>> lat = Lattice([Drift(1)])
            >>> lat.transport([1, 1, 0, 0, 0])
            TransportedPhasespace(s=array([0, 1]), x=array([1., 2.]), x_prime=array([1., 1.]), y=array([0, 0]), y_prime=array([0, 0]), dp=array([0., 0.]))

            Transport a distribution of phase space coordinates through the
            lattice:

            >>> beam = Beam()
            >>> lat = Lattice([Drift(1)])
            >>> transported = lat.transport(beam.match([1, 0, 1]))
            >>> plt.plot(transported.s, transported.x)
            ...

            Transport a phase space ellipse's coordinates through the lattice:

            >>> beam = Beam()
            >>> lat = Lattice([Drift(1)])
            >>> transported = lat.transport(beam.ellipse([1, 0, 1]))
            >>> plt.plot(transported.x, transported.x_prime)
            ...
        """
        if not isinstance(initial, np.ndarray):
            initial = np.array(initial)
        out = [initial]
        s_coords = [0]
        for element in self:
            out.append(element._transport(out[-1]))
            s_coords.append(s_coords[-1] + element.length)
        # regroup the per-element coordinate vectors into one array per
        # coordinate, shaped (n_particles, n_points) or (n_points,)
        x_coords, x_prime_coords, y_coords, y_prime_coords, dp_coords = (
            np.vstack(coords).squeeze().T for coords in zip(*out)
        )
        return TransportedPhasespace(
            np.array(s_coords),
            x_coords,
            x_prime_coords,
            y_coords,
            y_prime_coords,
            dp_coords,
        )

    def transport_twiss(
        self,
        twiss: Sequence[float],
        plane: str = "h",
    ) -> TransportedTwiss:
        """Transport the given twiss parameters along the lattice.

        Args:
            twiss: list of twiss parameters, beta[m], alpha[rad], and
                gamma[m^-1], one twiss parameter can be None.
            plane: plane of interest, either "h" or "v".

        Returns:
            Named tuple containing the twiss parameters along the lattice and
            the coordinates along the ring.
        """
        twiss = to_twiss(twiss)
        out = [twiss]
        s_coords = [0]
        transfer_ms = [element.m.twiss(plane=plane) for element in self]
        for element, m in zip(self, transfer_ms):
            out.append(m @ out[-1])
            s_coords.append(s_coords[-1] + element.length)
        out = np.hstack(out)
        return TransportedTwiss(np.array(s_coords), *out)

    def search(self, pattern: str, *args, **kwargs) -> List[int]:
        """Search the lattice for elements with `name` matching the pattern.

        Args:
            pattern: RegEx pattern.
            *args: Passed to ``re.compile`` (i.e. the ``flags`` argument).
            **kwargs: Passed to ``re.compile``, e.g. ``flags=re.IGNORECASE``.

        Raises:
            ValueError: If no elements match the provided pattern.

        Returns:
            List of indexes in the lattice where the element's name matches the pattern.
        """
        # flags must be applied at compile time: ``re.search`` refuses a flags
        # argument when given an already compiled pattern
        regex = re.compile(pattern, *args, **kwargs)
        out = [i for i, element in enumerate(self) if regex.search(element.name)]
        if not out:
            raise ValueError(f"'{pattern}' does not match with any elements in {self}")
        return out

    # Very ugly way of clearing cached one turn matrices on in place
    # modification of the sequence.
    def append(self, *args, **kwargs):
        self._clear_cache()
        return super().append(*args, **kwargs)

    def clear(self, *args, **kwargs):
        self._clear_cache()
        return super().clear(*args, **kwargs)

    def extend(self, *args, **kwargs):
        self._clear_cache()
        return super().extend(*args, **kwargs)

    def insert(self, *args, **kwargs):
        self._clear_cache()
        return super().insert(*args, **kwargs)

    def pop(self, *args, **kwargs):
        self._clear_cache()
        return super().pop(*args, **kwargs)

    def remove(self, *args, **kwargs):
        self._clear_cache()
        return super().remove(*args, **kwargs)

    def reverse(self, *args, **kwargs):
        self._clear_cache()
        return super().reverse(*args, **kwargs)

    def __setitem__(self, *args, **kwargs):
        # ``lat[i] = element`` / ``lat[a:b] = [...]``
        self._clear_cache()
        return super().__setitem__(*args, **kwargs)

    def __delitem__(self, *args, **kwargs):
        # ``del lat[i]`` / ``del lat[a:b]``
        self._clear_cache()
        return super().__delitem__(*args, **kwargs)

    def __iadd__(self, other):
        # ``lat += [...]`` mutates in place; list.__iadd__ returns self
        self._clear_cache()
        return super().__iadd__(other)

    def __imul__(self, other):
        # ``lat *= n`` mutates in place; list.__imul__ returns self
        self._clear_cache()
        return super().__imul__(other)

    # Disable sorting
    # TODO: is there a way to remove the method altogether?
    def sort(self, *args, **kwargs):
        """DISABLED."""

    def __add__(self, other):
        return Lattice(list.__add__(self, other))

    def __mul__(self, other):
        return Lattice(list.__mul__(self, other))

    def __rmul__(self, other):
        # ``n * lat`` should also produce a Lattice, like ``lat * n``
        return self.__mul__(other)

    def __getitem__(self, item):
        result = list.__getitem__(self, item)
        # slicing yields a sub-lattice; integer indexing yields the element
        if isinstance(item, slice):
            return Lattice(result)
        return result

    def save(self, path: os.PathLike):
        """Save a lattice to file.

        Args:
            path: File path.

        Examples:
            Save a lattice:

            >>> lat = Lattice([Drift(1)])
            >>> lat.save('drift.json')
        """
        serializable = [element._serialize() for element in self]
        with open(path, "w", encoding="utf-8") as fp:
            json.dump(serializable, fp, indent=4)

    def copy(self, deep=True) -> "Lattice":
        """Create a copy of the lattice.

        Args:
            deep: If True create copies of the elements themselves.

        Returns:
            A copy of the lattice.
        """
        if deep:
            return Lattice([element.copy() for element in self])
        return Lattice(self)

    def __repr__(self):
        return f"Lattice({super().__repr__()})"
class Plotter:
    """Lattice plotter.

    Args:
        lattice: :py:class:`Lattice` instance.

    Examples:
        Plot a lattice:

        >>> lat = Lattice([QuadrupoleThin(-0.6), Drift(1), QuadrupoleThin(0.6)])
        >>> lat.plot.layout()  # or lat.plot("layout")
        ...

        Plot the top down view of the lattice:

        >>> lat = Lattice([Drift(1), Dipole(1, np.pi/2)])
        >>> lat.plot.top_down()  # or lat.plot("top_down")
        ...
    """

    def __init__(self, lattice: Lattice):
        self._lattice = lattice

    def top_down(
        self,
        n_s_per_element: int = int(1e3),
    ) -> Tuple[plt.Figure, plt.Axes]:
        """Plot the s coordinate in the horizontal plane of the lattice.

        Args:
            n_s_per_element: Number of steps along the s coordinate for each
                element in the lattice.

        Returns:
            Plotted ``plt.Figure`` and ``plt.Axes``.
        """
        # state is (x, z, theta), starting at the origin heading along +z
        xztheta = [np.array([0, 0, np.pi / 2])]
        for element in self._lattice:
            if element.length == 0:
                # thin elements don't waste time on slicing them and running
                # this many times
                xztheta.append(xztheta[-1] + element._dxztheta_ds(xztheta[-1][2], 0))
                continue
            d_s = element.length / n_s_per_element
            for _ in range(n_s_per_element):
                xztheta.append(
                    xztheta[-1] + element._dxztheta_ds(xztheta[-1][2], d_s)
                )
        xztheta = np.vstack(xztheta)
        fig, ax = plt.subplots(1, 1)
        ax.plot(xztheta[:, 0], xztheta[:, 1], label="s")
        # forcefully adding margins, this might cause issues
        if xztheta[:, 0].max() - xztheta[:, 0].min() < 0.1:
            ax.set_xlim((-1, 1))
        if xztheta[:, 1].max() - xztheta[:, 1].min() < 0.1:
            ax.set_ylim((-1, 1))
        ax.set_aspect("equal")
        ax.margins(0.05)
        ax.set_xlabel("x [m]")
        ax.set_ylabel("z [m]")
        ax.legend()
        return fig, ax

    def layout(self) -> Tuple[plt.Figure, plt.Axes]:
        """Plot the lattice.

        Returns:
            Plotted ``plt.Figure`` and ``plt.Axes``.
        """
        fig, ax = plt.subplots(1, 1)
        s_coord = 0
        for element in self._lattice:
            patch = element._get_patch(s_coord)
            s_coord += element.length
            # skip elements which don't have a defined patch
            if patch is None:
                continue
            ax.add_patch(patch)
        ax.hlines(0, 0, s_coord, color="tab:gray", ls="dashed")
        ax.axes.yaxis.set_visible(False)
        ax.margins(0.05)
        ax.set_xlabel("s [m]")
        # remove duplicates from the legend: keep the first handle of each
        # label, in order of first appearance (single pass)
        handles, labels = ax.get_legend_handles_labels()
        unique = {}
        for handle, label in zip(handles, labels):
            unique.setdefault(label, handle)
        ax.legend(
            handles=list(unique.values()),
            labels=list(unique.keys()),
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )
        return fig, ax

    def __call__(self, *args, plot_type="layout", **kwargs):
        """Dispatch to a plotting method by name.

        The plot type may be given as the ``plot_type`` keyword or as the
        first positional argument (e.g. ``lat.plot("top_down")``); a leading
        positional string takes precedence. The remaining arguments are
        forwarded to the plotting method.
        """
        # no plotting method takes a string positional argument, so a leading
        # string can only be the plot type
        if args and isinstance(args[0], str):
            plot_type, *args = args
        return getattr(self, plot_type)(*args, **kwargs)

    def __repr__(self):
        return f"Plotter({repr(self._lattice)})"