# Source code for pyaccelerator.constraints

from abc import ABC, abstractmethod
from collections.abc import Iterable
from logging import getLogger
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union

import numpy as np
from scipy.optimize import minimize

from .elements.base import BaseElement
from .utils import PLANE_INDICES

if TYPE_CHECKING:  # pragma: no cover
    from scipy.optimize import OptimizeResult

    from .lattice import Lattice


class BaseTarget(ABC):
    """Base target.

    Subclasses must implement :py:meth:`loss`. Deriving from ``ABC`` makes the
    ``@abstractmethod`` decorator effective: ``BaseTarget`` itself, or a
    subclass that forgets to implement ``loss``, can no longer be instantiated.
    """

    @abstractmethod
    def loss(self, lattice: "Lattice"):
        """Compute the loss for this target.

        Args:
            lattice: Lattice on which to evaluate the target.

        Returns:
            A scalar or an iterable of loss values. ``Constraints.match``
            accepts either form and combines all losses with a 2-norm.
        """
class TargetPhasespace(BaseTarget):
    """Target phase space coordinates.

    Args:
        element: Element name pattern or element instance at which the target
            coordinates should be achieved.
        x (optional): Target x coordinate.
        x_prime (optional): Target x_prime coordinate.
        y (optional): Target y coordinate.
        y_prime (optional): Target y_prime coordinate.
        dp (optional): Target dp coordinate (probably breaks everything).
        initial (optional): Initial phase space coordinates with which to
            start the transport. If None will use the closed orbit solution.

    Raises:
        ValueError: If all the phase space coordinates are None.
    """

    def __init__(
        self,
        element: Union[str, "BaseElement"],
        x: Optional[float] = None,
        x_prime: Optional[float] = None,
        y: Optional[float] = None,
        y_prime: Optional[float] = None,
        dp: Optional[float] = None,
        initial: Optional[Sequence[float]] = None,
    ):
        value = [x, x_prime, y, y_prime, dp]
        if all(v is None for v in value):
            raise ValueError("All phase space coords are None.")
        if initial is not None:
            initial = tuple(initial)
        if isinstance(element, BaseElement):
            element = element.name
        self.element = element
        # Coordinates left as None are not matched (object array keeps None).
        self.value = np.array(value)
        self.initial = initial
        # as of yet dp doesn't change, use it to set the closed orbit
        if dp is None:
            dp = 0
        self._dp = dp

    def _transport(self, lattice: "Lattice") -> np.ndarray:
        """Transport the initial coordinates through ``lattice``.

        The first item returned by ``lattice.transport`` is discarded and the
        remaining ones are stacked as rows. ``loss`` indexes these rows in the
        same order as ``self.value`` (x, x_prime, y, y_prime, dp).
        """
        if self.initial is None:
            init = lattice.closed_orbit_solution(self._dp)
        else:
            init = self.initial
        _, *transported = lattice.transport(init)
        return np.vstack(transported)

    def loss(self, lattice: "Lattice") -> Union[float, np.ndarray]:
        """Absolute deviation from the target coordinates.

        Args:
            lattice: Lattice through which to transport.

        Returns:
            ``np.inf`` if the transport raises ``ValueError``. Otherwise a flat
            array with one entry per (targeted coordinate, matching element)
            pair.
        """
        try:
            transported = self._transport(lattice)
        except ValueError:
            return np.inf
        # +1 offset: presumably column 0 holds the initial coordinates.
        columns = [i + 1 for i in lattice.search(self.element)]
        rows = [i for i, value in enumerate(self.value) if value is not None]
        # np.ix_ selects every (row, column) combination. Plain fancy indexing
        # would pair rows and columns element-wise, which compares the wrong
        # entries (or fails) when several coordinates are targeted at several
        # matching elements.
        result = transported[np.ix_(rows, columns)]
        target = self.value[rows].astype(float)[:, np.newaxis]
        return np.abs(result - target).ravel()

    def __repr__(self) -> str:
        args = ["element", "value", "initial"]
        arg_string = ", ".join([arg + "=" + repr(getattr(self, arg)) for arg in args])
        return f"TargetPhasespace({arg_string})"
class TargetTwiss(BaseTarget):
    """Target twiss parameters.

    Args:
        element: Element name pattern or element instance at which the target
            `beta`, `alpha`, `gamma` should be achieved.
        beta (optional): Target beta value at the location of the `element`.
        alpha (optional): Target alpha value at the location of the `element`.
        gamma (optional): Target gamma value at the location of the `element`.
        plane: Plane of interest, either "h" or "v".

    Raises:
        ValueError: If all the twiss parameters are None.
    """

    def __init__(
        self,
        element: Union[str, "BaseElement"],
        beta: Optional[float] = None,
        alpha: Optional[float] = None,
        gamma: Optional[float] = None,
        plane: str = "h",
    ):
        plane = plane.lower()
        value = [beta, alpha, gamma]
        if all(v is None for v in value):
            raise ValueError("All twiss parameters are None.")
        if isinstance(element, BaseElement):
            element = element.name
        self.element = element
        # Parameters left as None are not matched (object array keeps None).
        self.value = np.array(value)
        self.plane = plane

    def _transport(self, lattice: "Lattice") -> np.ndarray:
        """Twiss parameters along ``lattice`` for ``self.plane``.

        The first item returned by ``lattice.twiss`` is discarded and the
        remaining ones are stacked as rows. ``loss`` indexes these rows in the
        same order as ``self.value`` (beta, alpha, gamma).
        """
        _, *twiss = lattice.twiss(plane=self.plane)
        return np.vstack(twiss)

    def loss(self, lattice: "Lattice") -> Union[float, np.ndarray]:
        """Absolute deviation from the target twiss parameters.

        Args:
            lattice: Lattice on which to compute the twiss parameters.

        Returns:
            ``np.inf`` if the twiss computation raises ``ValueError``.
            Otherwise a flat array with one entry per (targeted parameter,
            matching element) pair.
        """
        try:
            transported = self._transport(lattice)
        except ValueError:
            return np.inf
        # +1 offset: presumably column 0 holds the values at the lattice start.
        columns = [i + 1 for i in lattice.search(self.element)]
        rows = [i for i, value in enumerate(self.value) if value is not None]
        # np.ix_ takes every (row, column) combination instead of pairing the
        # two index lists element-wise.
        result = transported[np.ix_(rows, columns)]
        target = self.value[rows].astype(float)[:, np.newaxis]
        return np.abs(result - target).ravel()

    def __repr__(self) -> str:
        args = ["element", "value", "plane"]
        arg_string = ", ".join([arg + "=" + repr(getattr(self, arg)) for arg in args])
        return f"TargetTwiss({arg_string})"
class TargetTwissSolution(BaseTarget):
    """Target periodic twiss solution.

    Targets the twiss parameters at the beginning and end of the lattice.
    This is useful when a lattice needs some coaxing into having a periodic
    twiss solution, i.e. when `lattice.twiss` fails to find a solution.

    Note:
        Only one of the twiss arguments can be omitted.

    Args:
        beta (optional): Target beta value at beginning and end of lattice.
        alpha (optional): Target alpha value at beginning and end of lattice.
        gamma (optional): Target gamma value at beginning and end of lattice.
        plane: Plane of interest, either "h" or "v".

    Raises:
        ValueError: If all twiss parameters are None, or if more than one of
            them is omitted.
    """

    def __init__(
        self,
        beta: Optional[float] = None,
        alpha: Optional[float] = None,
        gamma: Optional[float] = None,
        plane: str = "h",
    ):
        plane = plane.lower()
        value = [beta, alpha, gamma]
        if all(v is None for v in value):
            raise ValueError("All twiss parameters are None.")
        # Enforce the documented contract: at most one parameter may be
        # omitted, so that the initial twiss values are fully specified.
        if sum(v is None for v in value) > 1:
            raise ValueError("Only one of the twiss parameters can be omitted.")
        self.value = np.array(value)
        self.plane = plane

    def _transport(self, lattice: "Lattice") -> np.ndarray:
        """Transport ``self.value`` through ``lattice`` for ``self.plane``.

        The first item returned by ``lattice.transport_twiss`` is discarded and
        the remaining ones are stacked as rows, indexed like ``self.value``.
        """
        _, *twiss = lattice.transport_twiss(self.value, plane=self.plane)
        return np.vstack(twiss)

    def loss(self, lattice: "Lattice") -> Union[float, np.ndarray]:
        """Absolute deviation of the end-of-lattice twiss from the target.

        Args:
            lattice: Lattice through which to transport the twiss parameters.

        Returns:
            ``np.inf`` if the transport raises ``ValueError``. Otherwise an
            array with one entry per given (non-None) twiss parameter.
        """
        try:
            transported = self._transport(lattice)
        except ValueError:
            return np.inf
        # Use the twiss values at the end of the lattice.
        rows = [i for i, value in enumerate(self.value) if value is not None]
        result = transported[rows, -1]
        return np.abs(result - self.value[rows].astype(float))

    def __repr__(self) -> str:
        args = ["value", "plane"]
        arg_string = ", ".join([arg + "=" + repr(getattr(self, arg)) for arg in args])
        return f"TargetTwissSolution({arg_string})"
class TargetDispersion(BaseTarget):
    """Target dispersion function.

    Args:
        element: Element name pattern or element instance at which the `value`
            should be achieved.
        value: Target value of dispersion function at the given `element`.
        plane: Plane of interest, either "h" or "v".
        **solver_kwargs: Forwarded as keyword arguments to
            ``lattice.dispersion``.
    """

    def __init__(
        self,
        element: Union[str, "BaseElement"],
        value: float,
        plane: str = "h",
        **solver_kwargs,
    ):
        lowered_plane = plane.lower()
        self.element = element.name if isinstance(element, BaseElement) else element
        self.value = value
        self.plane = lowered_plane
        self.solver_kwargs = solver_kwargs

    def _transport(self, lattice: "Lattice"):
        """Dispersion along ``lattice`` for ``self.plane``.

        Discards the first item returned by ``lattice.dispersion`` and picks,
        from the rest, the entry selected by ``PLANE_INDICES[self.plane][0]``.
        """
        _, *per_coordinate = lattice.dispersion(**self.solver_kwargs)
        coordinate_index = PLANE_INDICES[self.plane][0]
        return per_coordinate[coordinate_index]

    def loss(self, lattice: "Lattice") -> float:
        """Absolute deviation from the target dispersion at ``self.element``.

        Returns ``np.inf`` when the dispersion computation raises
        ``ValueError``.
        """
        try:
            dispersion = self._transport(lattice)
        except ValueError:
            return np.inf
        positions = [index + 1 for index in lattice.search(self.element)]
        return abs(dispersion[positions] - self.value)

    def __repr__(self) -> str:
        fields = ("element", "value", "plane", "solver_kwargs")
        joined = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
        return f"TargetDispersion({joined})"
class TargetGlobal(BaseTarget):
    """Target global lattice attribute.

    The loss is the absolute difference between
    ``getattr(lattice, method)(**method_kwargs)`` and ``value``.

    Args:
        method: lattice method to match.
        value: Target value.
        **method_kwargs: additional method kwargs.

    Examples:
        Create a constraint on the horizontal tune value:

            >>> TargetGlobal("tune", 0.23, plane='h', n_turns=512)
            TargetGlobal(method='tune', value=0.23, method_kwargs={'plane': 'h', 'n_turns': 512})
    """

    def __init__(self, method: str, value: float, **method_kwargs):
        self.method, self.value, self.method_kwargs = method, value, method_kwargs

    def loss(self, lattice: "Lattice"):
        """Absolute deviation of the lattice method's output from the target."""
        lattice_method = getattr(lattice, self.method)
        computed = lattice_method(**self.method_kwargs)
        return abs(computed - self.value)

    def __repr__(self) -> str:
        fields = ("method", "value", "method_kwargs")
        joined = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
        return f"TargetGlobal({joined})"
class FreeParameter:
    """Constraint free parameter.

    Args:
        element: Element name pattern or element instance for which the
            provided `attribute` will be considered a free parameter. An
            element instance is stored by its ``name``.
        attribute: attribute of `element`.
    """

    def __init__(self, element: Union[str, "BaseElement"], attribute: str):
        self.element = element.name if isinstance(element, BaseElement) else element
        self.attribute = attribute

    def __repr__(self) -> str:
        parts = [f"{name}={getattr(self, name)!r}" for name in ("element", "attribute")]
        return "FreeParameter(" + ", ".join(parts) + ")"
class Constraints:
    """Match a lattice to constraints.

    Args:
        lattice: :py:class:`~accelerator.lattice.Lattice` instance on which to
            match.

    Examples:
        Compute :py:class:`~accelerator.elements.drift.Drift` length to reach
        a x coord of 10 meters:

            >>> lat = Lattice([Drift(1)])
            >>> lat.constraints.add_free_parameter(element="drift", attribute="l")
            >>> target = TargetPhasespace("drift", x=10, initial=[0, 1, 0, 0, 0])
            >>> lat.constraints.add_target(target)
            >>> matched_lat, _ = lat.constraints.match()
            >>> matched_lat
            Lattice([Drift(length=10, name='drift_0')])

        Compute :py:class:`~accelerator.elements.drift.Drift` length to reach
        a x coord of 5 meters after the first Drift:

            >>> lat = Lattice([Drift(1), Drift(1)])
            >>> lat.constraints.add_free_parameter("drift_0", "l")
            >>> target = TargetPhasespace("drift_0", x=5, initial=[0, 1, 0, 0, 0])
            >>> lat.constraints.add_target(target)
            >>> matched_lat, _ = lat.constraints.match()
            >>> matched_lat
            Lattice([Drift(length=5, name='drift_0'), Drift(length=1, name='drift_1')])

        Compute :py:class:`~accelerator.elements.drift.Drift` length to reach
        a x coord of 5 meters after the second Drift with equal lengths of
        both Drifts:

            >>> lat = Lattice([Drift(1), Drift(1)])
            >>> lat.constraints.add_free_parameter("drift", "l")
            >>> target = TargetPhasespace("drift_1", x=5, initial=[0, 1, 0, 0, 0])
            >>> lat.constraints.add_target(target)
            >>> matched_lat, _ = lat.constraints.match()
            >>> matched_lat
            Lattice([Drift(length=2.5, name='drift_0'), Drift(length=2.5, name='drift_1')])

        Compute the :py:class:`~accelerator.elements.quadrupole.Quadrupole`
        strengths of a FODO cell to achieve a minimum beta of 0.5 meters:

            >>> lat = Lattice([QuadrupoleThin(1.6, name='quad_f'), Drift(1),
            ...                QuadrupoleThin(-0.8, name='quad_d'), Drift(1),
            ...                QuadrupoleThin(1.6, name='quad_f')])
            >>> lat.constraints.add_free_parameter("quad_f", "f")
            >>> lat.constraints.add_free_parameter("quad_d", "f")
            >>> target = TargetTwiss("quad_d", beta=0.5, plane="h")
            >>> lat.constraints.add_target(target)
            >>> matched_lat, _ = lat.constraints.match()
            >>> matched_lat
            Lattice([QuadrupoleThin(f=1.319, name='quad_f'), Drift(length=1, name='drift_0'), QuadrupoleThin(f=-0.918, name='quad_d'), Drift(length=1, name='drift_1'), QuadrupoleThin(f=1.319, name='quad_f')])
    """

    def __init__(self, lattice: "Lattice"):
        self._lattice = lattice
        self.targets = []
        self.free_parameters = []
        self._logger = getLogger(__name__)

    def add_target(self, target: "BaseTarget"):
        """Add a constraint target.

        Args:
            target: An instance of either ``TargetPhasespace``,
                ``TargetDispersion``, ``TargetTwiss``, ``TargetTwissSolution``
                or ``TargetGlobal``.
        """
        self.targets.append(target)

    def add_free_parameter(self, element: Union[str, "BaseElement"], attribute: str):
        """Add a free parameter.

        Args:
            element: Element name pattern or element instance for which the
                provided `attribute` will be considered a free parameter.
            attribute: attribute of `element`.

        Examples:
            Setting a :py:class:`~accelerator.element.drift.Drift`'s length as
            a free parameter:

                >>> drift = Drift(1)
                >>> lat = Lattice([drift])
                >>> lat
                Lattice([Drift(l=1, name='drift_0')])
                >>> lat.constraints.add_free_parameter("drift_0", "l")
                ... # or lat.constraints.add_free_parameter(drift, "l")
        """
        self.free_parameters.append(FreeParameter(element, attribute))

    def clear(self):
        """Clear the targets and free parameters."""
        self.targets.clear()
        self.free_parameters.clear()

    def match(self, *args, **kwargs) -> Tuple["Lattice", "OptimizeResult"]:
        """Match lattice properties to constraints using ``scipy.optimize.minimize``.

        The original lattice is left untouched; matching is done on a copy.

        Args:
            *args: Passed positionally to ``scipy.optimize.minimize`` after
                ``x0`` (i.e. ``args``, ``method``, ``jac``, ...).
            **kwargs: Passed to ``scipy.optimize.minimize``. ``method``
                defaults to ``"Nelder-Mead"`` unless a method or constraints
                are provided.

        Raises:
            ValueError: If no targets or free parameters are specified, or if
                a free parameter matches no element.

        Returns:
            New matched :py:class:`~accelerator.lattice.Lattice` instance and
            ``scipy.optimize.OptimizeResult``.
        """
        # minimize(fun, x0, args, method, ...): args[1] is a positional
        # ``method``, so only apply the default when none was given at all.
        if len(args) < 2 and "method" not in kwargs and "constraints" not in kwargs:
            kwargs["method"] = "Nelder-Mead"
        if not self.targets:
            raise ValueError("No targets specified.")
        if not self.free_parameters:
            raise ValueError("No free parameters specified.")
        lattice = self._lattice.copy()
        root_start = self._get_initial(lattice)

        def match_function(new_settings):
            """Combine all target losses into a single 2-norm."""
            self._set_parameters(new_settings, lattice)
            out = []
            for target in self.targets:
                loss = target.loss(lattice)
                if not isinstance(loss, Iterable):
                    loss = [loss]
                out.extend(loss)
            return np.linalg.norm(out, 2)

        # a root finder would be better suited but scipy doesn't have one that
        # works with arbitrary number of inputs & outputs, that I could find.
        # x0 is passed positionally so that ``*args`` fill the parameters that
        # follow it instead of colliding with ``x0``.
        res = minimize(match_function, root_start, *args, **kwargs)
        # The last evaluated point is not necessarily the best one, so always
        # leave the lattice at the returned solution.
        self._set_parameters(res.x, lattice)
        if not res.success:
            self._logger.warning("Minimization did not converge: %s", res.message)
        if res.fun > 1e-1:
            # as this is a minimization algorithm, it can find a minimum
            # but the matching could still be off.
            self._logger.warning("Loss is high:%f, double check the matching.", res.fun)
        return lattice, res

    def _set_parameters(self, new_settings: Sequence[float], lattice: "Lattice"):
        """Set each free parameter's value on every element it matches."""
        for param, value in zip(self.free_parameters, new_settings):
            for i in lattice.search(param.element):
                setattr(lattice[i], param.attribute, value)
        # TODO: decide if we keep the caching of the one turn matrices?
        lattice._clear_cache()

    def _get_initial(self, lattice: "Lattice") -> Sequence[float]:
        """Get the starting point for the minimization algorithm.

        Each free parameter starts at the mean of its attribute over all
        matching elements.

        Raises:
            ValueError: If a free parameter matches no element. Without this
                check the mean of an empty list would yield NaN.
        """
        out = []
        for param in self.free_parameters:
            indices = lattice.search(param.element)
            if len(indices) == 0:
                raise ValueError(
                    f"No element matches free parameter element {param.element!r}."
                )
            # this mean might cause issues, maybe switch to the median ?
            out.append(np.mean([getattr(lattice[i], param.attribute) for i in indices]))
        return out

    def __repr__(self):
        return f"Free Parameters: {repr(self.free_parameters)}\nTargets: {repr(self.targets)}"